from typing import List  # isort:skip

import numpy as np
from skimage.color import label2rgb

import torch


def encode_mask_with_color(
    semantic_masks: torch.Tensor, threshold: float = 0.5
) -> List[np.ndarray]:
    """
    Binarize each channel of `semantic_masks` with `threshold`
    and collapse the channels into a single integer label map.

    Channel ``k`` (0-based) is encoded with the label ``k + 1``; pixels
    where no channel exceeds the threshold keep the label ``0``.
    When several channels exceed the threshold at the same pixel,
    the channel with the highest index wins.

    Args:
        semantic_masks (torch.Tensor): semantic mask batch tensor,
            iterated as batch -> channels -> mask; tensors that
            require grad or live on a non-CPU device are accepted,
            as are NumPy arrays
        threshold (float): threshold for masks binarization
            (strict comparison: ``value > threshold``)

    Returns:
        List[np.ndarray]: list of ``np.int32`` label maps,
        one per batch element
    """
    batch = []
    for observation in semantic_masks:
        # Convert explicitly: implicit tensor -> ndarray conversion fails
        # for tensors that require grad or are stored on an accelerator.
        if isinstance(observation, torch.Tensor):
            observation = observation.detach().cpu().numpy()
        else:
            observation = np.asarray(observation)

        # Take the shape from the trailing dims instead of `observation[0]`
        # so that an observation with zero channels yields an all-zero map.
        result = np.zeros(observation.shape[1:], dtype=np.int32)
        for i, ch in enumerate(observation, start=1):
            # Later channels intentionally overwrite earlier ones.
            result[ch > threshold] = i

        batch.append(result)

    return batch


def mask_to_overlay_image(
    image: np.ndarray, mask: np.ndarray, mask_strength: float
) -> np.ndarray:
    """Blend a colorized label mask on top of an image.

    The label mask is converted to an RGB picture with ``label2rgb``
    (label ``0`` is treated as background) and mixed with ``image``
    as ``image * (1 - mask_strength) + colors * mask_strength``.
    The blend is then scaled by 255, clipped to ``[0, 255]``, rounded
    and cast to ``np.uint8``.

    Args:
        image (np.ndarray): RGB image used as underlay for masks;
            NOTE(review): the ``* 255`` scaling suggests a float image
            in the ``[0, 1]`` range is expected - confirm with callers
        mask (np.ndarray): integer label mask to draw
        mask_strength (float): opacity of colorized masks

    Returns:
        np.ndarray: HxWx3 ``np.uint8`` image with overlay
    """
    colorized = label2rgb(mask, bg_label=0)

    # Weighted sum of the underlay and the colorized labels.
    underlay_part = image * (1 - mask_strength)
    overlay_part = colorized * mask_strength
    blended = underlay_part + overlay_part

    # Scale to the 8-bit range and quantize.
    scaled = np.clip(blended * 255, 0, 255)
    return np.round(scaled).astype(np.uint8)