import random

import cv2
import numpy as np
import torch
from torch.utils.data import Dataset
import lib.augmentations as aug


def gen_random_image(patch_size):
    """Generate one random synthetic image and its segmentation mask.

    The image has a uniform dark background, a single lighter filled ellipse
    and sparse random-colour noise on top. The mask is 1 inside the ellipse
    and 0 elsewhere; noise pixels do not alter the mask.

    All randomness is derived from the ``random`` module, so calling
    ``random.seed(...)`` beforehand makes the result reproducible.

    :param patch_size: side length, in pixels, of the square image
    :return: tuple ``(img, mask)`` where ``img`` is a uint8 array of shape
        ``(patch_size, patch_size, 3)`` and ``mask`` is a uint8 array of
        shape ``(patch_size, patch_size)``
    """
    img = np.zeros((patch_size, patch_size, 3), dtype=np.uint8)
    mask = np.zeros((patch_size, patch_size), dtype=np.uint8)

    # Background: one dark value (0..100) per channel, broadcast over the image.
    dark_color = [random.randint(0, 100) for _ in range(3)]
    img[:] = dark_color

    # Object: each channel is strictly brighter than the background channel.
    light_color = tuple(random.randint(dark + 1, 255) for dark in dark_color)
    # randint is inclusive on both ends, so the centre may land one pixel
    # past the last row/column; the ellipse is simply clipped by the image.
    center = (random.randint(0, patch_size), random.randint(0, patch_size))
    # NOTE(review): the 10..56 radius range does not scale with patch_size --
    # presumably tuned for one particular patch size; confirm.
    axes = (random.randint(10, 56), random.randint(10, 56))
    cv2.ellipse(img, center, axes, 0, 0, 360, light_color, -1)
    cv2.ellipse(mask, center, axes, 0, 0, 360, 1, -1)

    # White noise: each pixel is independently replaced by a uniformly random
    # colour with probability `density`.
    density = random.uniform(0, 0.1)
    # Drawn in one vectorised step instead of a per-pixel Python loop. The
    # NumPy generator is seeded from `random` so seeding `random` still
    # controls the noise.
    noise_rng = np.random.RandomState(random.getrandbits(32))
    noisy = noise_rng.random_sample((patch_size, patch_size)) < density
    img[noisy] = noise_rng.randint(
        0, 256, size=(np.count_nonzero(noisy), 3), dtype=np.uint8
    )

    return img, mask


class ShapesDataset(Dataset):
    """Dataset of synthetic ellipse images generated on the fly.

    Every ``__getitem__`` call produces a new random sample with
    ``gen_random_image``; the requested index is ignored.
    """

    def __init__(self, steps, patch_size, transform=None):
        """
        :param steps: number of samples reported by ``__len__``
        :param patch_size: side length passed to ``gen_random_image``
        :param transform: callable invoked as ``transform(image, mask)`` and
            returning the transformed ``(image, mask)`` pair; when ``None``,
            ``aug.ImageOnly(aug.NormalizeImage())`` is used
        """
        if transform is None:
            # Built per instance rather than as a default argument, so the
            # transform is not created at class-definition time and is not
            # shared between datasets.
            transform = aug.ImageOnly(aug.NormalizeImage())
        self.transform = transform
        self.patch_size = patch_size
        self.steps = steps

    def __len__(self):
        """Return the configured number of samples."""
        return self.steps

    def __getitem__(self, item):
        """Generate a fresh random sample; ``item`` is not used.

        :return: tuple ``(image, mask)`` of a float tensor with the image's
            last axis moved to the front, and a long tensor holding the mask
            with a leading axis of size 1
        """
        image, mask = gen_random_image(self.patch_size)
        image, mask = self.transform(image, mask)

        # Move the last axis first (assumes the transform keeps the image
        # channels-last -- confirm) and copy into a fresh contiguous buffer.
        image = torch.from_numpy(np.moveaxis(image, -1, 0).copy()).float()
        mask = torch.from_numpy(np.expand_dims(mask, 0)).long()
        return image, mask


def SHAPES(patch_size, train_steps=1024, val_steps=128):
    """
    Build the training and validation synthetic shapes datasets.

    Reference for the generator:
    https://github.com/ZFTurbo/ZF_UNET_224_Pretrained_Model/blob/master/train_infinite_generator.py

    :param patch_size: side length, in pixels, of the generated square images
    :param train_steps: length of the training dataset (default 1024)
    :param val_steps: length of the validation dataset (default 128)
    :return: tuple ``(train_dataset, val_dataset, 1)``; the trailing ``1`` is
        presumably the number of target classes -- confirm against callers
    """
    return ShapesDataset(train_steps, patch_size), ShapesDataset(val_steps, patch_size), 1