import numpy as np
import scipy.ndimage

import torch
import torch.nn.functional as F

import archs

# Use the first CUDA GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

def load_model(casm_path):
    """Load the classifier/decoder pair stored in the checkpoint at ``casm_path``.

    Args:
        casm_path: '/'-separated path to the checkpoint file. The checkpoint
            must contain the keys 'state_dict_classifier' and
            'state_dict_decoder'.

    Returns:
        dict with keys:
            'classifier': ``archs.resnet50shared()`` with the loaded weights,
                in eval mode on ``device``.
            'decoder': ``archs.decoder()`` with the loaded weights, in eval
                mode on ``device``.
            'name': the checkpoint file name with a trailing '.chk' removed.
    """
    name = casm_path.split('/')[-1]
    # Strip only a trailing extension; str.replace would also delete '.chk'
    # appearing elsewhere in the file name.
    if name.endswith('.chk'):
        name = name[:-len('.chk')]

    print("\n=> Loading model localized in '{}'".format(casm_path))
    classifier = archs.resnet50shared()
    # map_location makes checkpoints saved on a GPU loadable when `device`
    # fell back to the CPU (no CUDA available).
    checkpoint = torch.load(casm_path, map_location=device)

    classifier.load_state_dict(checkpoint['state_dict_classifier'])
    classifier.eval().to(device)

    decoder = archs.decoder()
    decoder.load_state_dict(checkpoint['state_dict_decoder'])
    decoder.eval().to(device)
    print("=> Model loaded.")

    return {'classifier': classifier, 'decoder': decoder, 'name': name}

def get_masks_and_check_predictions(input, target, model):
    """Compute masks, rectangular masks and prediction correctness for a batch.

    Args:
        input: batch of model inputs; anything ``torch.tensor`` accepts.
        target: ground-truth class indices for the batch; anything
            ``torch.tensor`` accepts.
        model: dict with 'classifier' and 'decoder' entries, as returned by
            ``load_model``.

    Returns:
        Tuple of numpy arrays ``(mask, rectangular, is_correct)``:
            mask: the decoder's masks.
            rectangular: each sample's binarized mask replaced by the bounding
                box of its largest connected component. Samples whose
                binarized mask is empty stay all zeros.
            is_correct: per-sample result of ``target == argmax(output, 1)``.
        ``mask`` and ``rectangular`` are ``squeeze()``d, so singleton
        dimensions are dropped. This includes the batch dimension when the
        batch size is 1.
    """
    with torch.no_grad():
        input, target = torch.tensor(input), torch.tensor(target)
        mask, output = get_mask(input, model, get_output=True)

        rectangular = binarize_mask(mask.clone())

        for idx in range(mask.size(0)):
            # An empty binarized mask has no component to put a box around.
            if rectangular[idx].sum() == 0:
                continue
            box = get_rectangular_mask(rectangular[idx].squeeze().numpy())
            # get_rectangular_mask builds its result on `device`. Move it
            # explicitly to wherever `rectangular` lives before writing.
            rectangular[idx] = box.to(rectangular.device)

        target = target.to(device)
        _, max_indexes = output.max(1)
        is_correct = target.eq(max_indexes)

    return (mask.squeeze().cpu().numpy(),
            rectangular.squeeze().cpu().numpy(),
            is_correct.cpu().numpy())

def get_mask(input, model, get_output=False):
    """Run ``input`` through the classifier, then the decoder, and return the mask.

    Args:
        input: tensor batch; it is moved to ``device`` before the forward pass.
        model: dict whose 'classifier' returns a pair ``(output, layers)``
            and whose 'decoder' maps ``layers`` to a mask.
        get_output: when True, also return the classifier's ``output``.

    Returns:
        The decoder result, or ``(decoder result, classifier output)`` when
        ``get_output`` is True.
    """
    with torch.no_grad():
        cls_output, features = model['classifier'](input.to(device))
        decoded = model['decoder'](features)
        return (decoded, cls_output) if get_output else decoded

def binarize_mask(mask):
    """Binarize each mask in a batch at a per-sample, mean-dependent threshold.

    For each sample with ``N`` values and mean ``avg``, the threshold is the
    ``k``-th smallest value, where ``k = 1 + round((N - 1) * (1 - avg))``
    (clamped to ``[1, N]``). That value is then clamped to
    ``[1e-6, 1 - 1e-6]``. Entries strictly greater than the threshold become
    1 and all others 0. Roughly, a fraction ``avg`` of the entries is kept.

    Args:
        mask: tensor expected to be ``(batch, 1, H, W)``, on any device.
            Values are presumably in ``[0, 1]``; TODO confirm against the
            decoder output.

    Returns:
        A CPU tensor with the same shape and dtype as ``mask`` containing 0/1.
    """
    with torch.no_grad():
        flat_mask = mask.cpu().reshape(mask.size(0), -1)
        num_values = flat_mask.size(1)
        # Per-sample mean, computed directly. A fixed 224-kernel avg_pool2d
        # followed by squeeze() broke for batch size 1 (the batch dimension
        # was squeezed away) and for any spatial size other than 224x224.
        averages = flat_mask.mean(1).tolist()
        binarized_mask = torch.zeros_like(flat_mask)
        for i, avg in enumerate(averages):
            kth = 1 + int((num_values - 1) * (1 - avg) + 0.5)
            # Keep kth a valid 1-based rank even if values leave [0, 1].
            kth = min(max(kth, 1), num_values)
            th, _ = torch.kthvalue(flat_mask[i], kth)
            th = th.clamp(1e-6, 1 - 1e-6)
            binarized_mask[i] = flat_mask[i].gt(th).float()

        return binarized_mask.view(mask.size())

def get_largest_connected(m):
    """Keep only the connected component of ``m`` with the largest total weight.

    ``m`` is labeled with ``scipy.ndimage.label`` (default structuring
    element). Each label's pixels are summed using ``m`` itself as weights,
    which for a 0/1 mask is the component's pixel count. Label 0 covers
    exactly the zeros of ``m``, so its weight is 0. If ``m`` has no nonzero
    entries, label 0 wins and the result is all True.

    Args:
        m: numpy array to segment.

    Returns:
        Boolean numpy array, same shape as ``m``, True on the winning label.
    """
    labels = scipy.ndimage.label(m)[0]
    totals = np.bincount(labels.reshape(-1), weights=m.reshape(-1))
    winner = np.argmax(totals)
    return labels == winner

def get_bounding_box(m):
    """Return a filled, axis-aligned box covering every nonzero entry of ``m``.

    Args:
        m: 2-D numpy array (boolean or numeric). Nonzero entries are
            treated as foreground.

    Returns:
        A float tensor on ``device`` with the same shape as ``m``. It is 1
        inside the tight bounding box of the foreground (inclusive bounds)
        and 0 elsewhere. An all-zero ``m`` yields an all-zero tensor.
    """
    # Size the output from m rather than assuming 224x224 inputs.
    box_mask = torch.zeros(m.shape, device=device)
    rows = np.flatnonzero(m.any(1))
    cols = np.flatnonzero(m.any(0))
    if rows.size == 0:
        # No foreground means no box. argmax-based bounds would have marked a
        # spurious pixel at (0, 0).
        return box_mask
    box_mask[int(rows[0]):int(rows[-1]) + 1, int(cols[0]):int(cols[-1]) + 1] = 1

    return box_mask

def get_rectangular_mask(m):
    """Replace ``m`` by the bounding box of its largest connected component.

    ``get_largest_connected`` selects the component, and ``get_bounding_box``
    returns the box as a tensor.
    """
    largest_component = get_largest_connected(m)
    rect = get_bounding_box(largest_component)
    return rect