import argparse
import json
from pathlib import Path
from typing import Dict, Tuple

import cv2
import numpy as np
import torch
import torch.backends.cudnn
import utils
from PIL import Image
from torch import nn
from torch.optim import Adam
from torch.utils.data import DataLoader, Dataset
from unet_models import Loss, UNet11

# Default image dimensions in pixels. Neither name is referenced again in the
# visible part of this module.
# NOTE(review): the values match the default of the ``--size`` option in
# ``main`` ('1280x1920'); which one is height and which is width should be
# confirmed before relying on the ``cols``/``rows`` naming.
img_cols, img_rows = 1280, 1920

# Type alias for a pair of integers.
# NOTE(review): presumably an image size -- confirm the axis order with users.
Size = Tuple[int, int]


class CarvanaDataset(Dataset):
    """Dataset of image/mask pairs read from ``root/images`` and ``root/masks``.

    Images and masks are matched purely by their position in the two sorted
    file listings. The constructor therefore checks that both listings have
    the same length and fails early instead of silently mis-pairing samples.
    """

    def __init__(self, root: Path, to_augment: bool = False):
        """Collect the image and mask paths below *root*.

        Args:
            root: Directory containing an ``images`` sub-directory with
                ``*.jpg`` files and a ``masks`` sub-directory.
            to_augment: When true, every sample is passed through ``augment``
                before being converted to tensors.

        Raises:
            ValueError: If the number of images differs from the number of
                masks, because index-based pairing would then be wrong.
        """
        self.image_paths = sorted(root.joinpath('images').glob('*.jpg'))
        self.mask_paths = sorted(root.joinpath('masks').glob('*'))

        # Pairing is positional, so a stray or missing file on either side
        # would shift every following pair. Refuse to continue in that case.
        # NOTE(review): equal counts do not prove the names correspond; this
        # still relies on both listings sorting in the same order.
        if len(self.image_paths) != len(self.mask_paths):
            raise ValueError(
                'Found {} images but {} masks under {}'.format(
                    len(self.image_paths), len(self.mask_paths), root))

        self.to_augment = to_augment

    def __len__(self) -> int:
        """Return the number of images found."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load, optionally augment, and tensorize the sample at *idx*.

        Returns:
            A pair ``(image, mask)``: the image after ``utils.img_transform``
            and the mask as a tensor with a new leading axis of size 1.
        """
        img = load_image(self.image_paths[idx])
        mask = load_mask(self.mask_paths[idx])

        if self.to_augment:
            img, mask = augment(img, mask)

        return utils.img_transform(img), torch.from_numpy(np.expand_dims(mask, 0))


def grayscale_aug(img, mask):
    """Return a copy of *img* with the masked region rendered in grayscale.

    The mask is expanded to three channels and multiplied with the image,
    the product is converted to grayscale and back to three channels, and
    every element of that result that is greater than zero overwrites the
    corresponding element in a copy of *img*. The input image is not
    modified.

    NOTE(review): presumably *mask* is 1 inside the object and 0 elsewhere,
    and *img* is an RGB ``uint8`` array -- confirm against the callers.
    """
    mask_rgb = cv2.cvtColor(mask, cv2.COLOR_GRAY2RGB)
    masked_pixels = (mask_rgb * img).astype(np.uint8)

    # Round-trip through a single channel so all three channels end up equal.
    gray_rgb = cv2.cvtColor(
        cv2.cvtColor(masked_pixels, cv2.COLOR_RGB2GRAY),
        cv2.COLOR_GRAY2RGB)

    # Only non-zero gray values are copied; everything else keeps its colour.
    replace = gray_rgb > 0
    result = img.copy()
    result[replace] = gray_rgb[replace]
    return result


def augment(img, mask):
    """Apply random training augmentations to an image/mask pair.

    With probability 0.5 both arrays are flipped along axis 1. Independently,
    with probability 0.5 the image colours are altered: half of those times
    by ``random_hue_saturation_value`` and otherwise by ``grayscale_aug``.
    The mask is never altered by the colour step.

    Returns:
        ``(img, mask)`` as fresh copies, so the flipped views become regular
        arrays that do not alias the inputs.
    """
    flip = np.random.random() < 0.5
    if flip:
        img, mask = np.flip(img, axis=1), np.flip(mask, axis=1)

    # Second draw gates the colour augmentation as a whole.
    if np.random.random() >= 0.5:
        return img.copy(), mask.copy()

    # Third draw picks which colour augmentation to apply.
    use_hsv_shift = np.random.random() < 0.5
    if use_hsv_shift:
        img = random_hue_saturation_value(
            img,
            hue_shift_limit=(-50, 50),
            sat_shift_limit=(-5, 5),
            val_shift_limit=(-15, 15),
        )
    else:
        img = grayscale_aug(img, mask)

    return img.copy(), mask.copy()


def random_hue_saturation_value(image,
                                hue_shift_limit=(-180, 180),
                                sat_shift_limit=(-255, 255),
                                val_shift_limit=(-255, 255)):
    """Randomly shift the hue, saturation and value of an image.

    The image is converted with ``cv2.COLOR_BGR2HSV``, each channel is
    shifted by an offset drawn uniformly from its ``(low, high)`` limit pair,
    and the result is converted back with ``cv2.COLOR_HSV2BGR``.

    Hue is an angle, so for 8-bit images the shifted hue wraps around modulo
    180 (OpenCV stores 8-bit hue as degrees / 2). Saturation and value are
    shifted with ``cv2.add`` and therefore saturate at the ends of their
    range.

    NOTE(review): ``augment`` passes images produced by ``load_image``, which
    are RGB, while the conversion codes here are the BGR ones -- confirm that
    treating red and blue as swapped is acceptable.

    Args:
        image: Three-channel image array accepted by ``cv2.cvtColor``.
        hue_shift_limit: ``(low, high)`` bounds of the hue offset.
        sat_shift_limit: ``(low, high)`` bounds of the saturation offset.
        val_shift_limit: ``(low, high)`` bounds of the value offset.

    Returns:
        The colour-shifted image in the same channel layout as the input.
    """
    image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
    h, s, v = cv2.split(image)

    hue_shift = np.random.uniform(hue_shift_limit[0], hue_shift_limit[1])
    if h.dtype == np.uint8:
        # cv2.add saturates, which would clamp every hue below |shift| to 0
        # for negative offsets. Wrap around the colour circle instead, using
        # a wider integer type so the intermediate sum cannot overflow.
        shifted = h.astype(np.int32) + int(round(hue_shift))
        h = (shifted % 180).astype(np.uint8)
    else:
        # Non-8-bit inputs use a different hue scale; keep the plain add.
        h = cv2.add(h, hue_shift)

    sat_shift = np.random.uniform(sat_shift_limit[0], sat_shift_limit[1])
    s = cv2.add(s, sat_shift)

    val_shift = np.random.uniform(val_shift_limit[0], val_shift_limit[1])
    v = cv2.add(v, val_shift)

    image = cv2.merge((h, s, v))
    image = cv2.cvtColor(image, cv2.COLOR_HSV2BGR)

    return image


def load_image(path: Path):
    """Read an image file and return it as a padded RGB ``uint8`` array.

    One reflected column (``cv2.BORDER_REFLECT_101``) is added on the left
    and on the right, so the result is two pixels wider than the file; the
    height is unchanged. ``load_mask`` applies the same padding.

    NOTE(review): presumably the padding brings the width to a multiple of
    32 as required by the ``--size`` help text -- confirm for the dataset.

    Args:
        path: Location of the image file.

    Returns:
        The image with channels in RGB order.

    Raises:
        OSError: If OpenCV cannot read or decode the file.
    """
    img = cv2.imread(str(path))
    if img is None:
        # cv2.imread reports failure by returning None rather than raising;
        # surface the offending path instead of a later opaque OpenCV error.
        raise OSError('Cannot read image: {}'.format(path))

    img = cv2.copyMakeBorder(img, 0, 0, 1, 1, cv2.BORDER_REFLECT_101)

    # OpenCV decodes to BGR; the rest of the pipeline works on RGB.
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    return img.astype(np.uint8)


def load_mask(path):
    """Read a mask file and return it as a padded binary ``float32`` array.

    The file is opened with PIL and thresholded: when ``'.gif'`` occurs in
    the path every value greater than 0 counts as foreground, otherwise the
    cut-off is half of 255. The binary mask then receives one reflected
    column on the left and right (same padding as ``load_image``) and is
    returned with values 0.0 and 1.0.
    """
    # The threshold depends only on the path text, not on the file content.
    threshold = 0 if '.gif' in str(path) else 255 * 0.5

    with open(path, 'rb') as f, Image.open(f) as pil_mask:
        # The comparison materializes a new array while the file is open.
        binary = np.asarray(pil_mask) > threshold

    padded = cv2.copyMakeBorder(
        binary.astype(np.uint8), 0, 0, 1, 1, cv2.BORDER_REFLECT_101)
    return padded.astype(np.float32)


def validation(model: nn.Module, criterion, valid_loader) -> Dict[str, float]:
    """Evaluate *model* on *valid_loader* and report mean loss and Dice.

    The model is switched to evaluation mode. For every batch the criterion
    value and the Dice score of the outputs thresholded at 0.5 are recorded;
    their means over all batches are printed and returned.

    Args:
        model: Network to evaluate.
        criterion: Callable taking ``(outputs, targets)`` and returning the
            loss.
        valid_loader: Iterable yielding ``(inputs, targets)`` batches.

    Returns:
        A dict with ``'valid_loss'`` and ``'dice_loss'``. Despite its name,
        the second entry holds the mean value returned by ``get_dice``.
    """
    model.eval()

    batch_losses = []
    batch_dice = []
    for inputs, targets in valid_loader:
        inputs = utils.variable(inputs, volatile=True)
        targets = utils.variable(targets)
        outputs = model(inputs)

        batch_losses.append(criterion(outputs, targets).data[0])

        # Binarize the raw outputs before scoring the overlap.
        binarized = (outputs > 0.5).float()
        batch_dice.append(get_dice(targets, binarized).data[0])

    valid_loss = np.mean(batch_losses)  # type: float
    valid_dice = np.mean(batch_dice)

    print('Valid loss: {:.5f}, dice: {:.5f}'.format(valid_loss, valid_dice))
    return {'valid_loss': valid_loss, 'dice_loss': valid_dice}


def get_dice(y_true, y_pred):
    """Return the Dice score of *y_pred* against *y_true*.

    Both tensors are summed over their last two dimensions. For each
    remaining entry the ratio ``intersection / (sum(y_true) + sum(y_pred))``
    is formed, with a tiny epsilon in the denominator to avoid division by
    zero, and twice the mean of those ratios is returned.
    """
    def spatial_sum(tensor):
        # Collapse the last two dimensions one after the other.
        return tensor.sum(dim=-2).sum(dim=-1)

    eps = 1e-15
    overlap = spatial_sum(y_pred * y_true)
    total = spatial_sum(y_true) + spatial_sum(y_pred) + eps

    return 2 * (overlap / total).mean()


def main():
    """Parse command-line options and train a ``UNet11`` model on one fold.

    The function creates the model output directory, wraps the network in
    ``nn.DataParallel`` on CUDA, builds the training and validation loaders
    for the selected fold, stores the parsed options in ``params.json`` and
    hands everything to ``utils.train``.
    """
    parser = argparse.ArgumentParser()
    arg = parser.add_argument
    # NOTE(review): --dice-weight, --nll-weights and --size are not read in
    # this function; they only reach params.json and utils.train via ``args``.
    arg('--dice-weight', type=float)
    arg('--nll-weights', action='store_true')
    arg('--device-ids', type=str, help='For example 0,1 to run on two GPUs')
    arg('--fold', type=int, help='fold', default=0)
    arg('--size', type=str, default='1280x1920', help='Input size, for example 288x384. Must be multiples of 32')
    # Presumably registers the shared options used below (workers,
    # batch_size, ...) -- defined in the project's utils module.
    utils.add_args(parser)
    args = parser.parse_args()

    model_name = 'unet_11'

    # The output directory is always derived from the model name.
    args.root = str(utils.MODEL_PATH / model_name)

    root = Path(args.root)
    root.mkdir(exist_ok=True, parents=True)

    model = UNet11()

    if args.device_ids:
        device_ids = [int(device) for device in args.device_ids.split(',')]
    else:
        # --device-ids was omitted: None makes DataParallel use all devices
        # instead of crashing on ``None.split``.
        device_ids = None
    model = nn.DataParallel(model, device_ids=device_ids).cuda()

    loss = Loss()

    def make_loader(ds_root: Path, to_augment=False, shuffle=False):
        """Build a DataLoader over the dataset stored under *ds_root*."""
        return DataLoader(
            dataset=CarvanaDataset(ds_root, to_augment=to_augment),
            shuffle=shuffle,
            num_workers=args.workers,
            batch_size=args.batch_size,
            pin_memory=True
        )

    train_root = utils.DATA_ROOT / str(args.fold) / 'train'
    valid_root = utils.DATA_ROOT / str(args.fold) / 'val'

    # Validation data is neither augmented nor shuffled.
    valid_loader = make_loader(valid_root)
    train_loader = make_loader(train_root, to_augment=True, shuffle=True)

    # Persist the effective options next to the model output.
    root.joinpath('params.json').write_text(
        json.dumps(vars(args), indent=True, sort_keys=True))

    utils.train(
        init_optimizer=lambda lr: Adam(model.parameters(), lr=lr),
        args=args,
        model=model,
        criterion=loss,
        train_loader=train_loader,
        valid_loader=valid_loader,
        validation=validation,
        fold=args.fold
    )


# Start training only when the file is executed as a script, not on import.
if __name__ == '__main__':
    main()