import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from scipy.ndimage import label
from .modules import SoftProposal

# helper functions

def hook_spn(model):
    """Attach forward hooks so that inference exposes the model's class response maps."""
    if not (hasattr(model, 'sp_hook') and hasattr(model, 'fc_hook')):
        # remember the original training flag so unhook_spn can restore it
        model._training = model.training
        model.train(False)

        def _sp_hook(module, input, output):
            # cache the soft proposal output on the parent model
            module.parent_modules[0].class_response_maps = output

        def _fc_hook(module, input, output):
            if hasattr(module.parent_modules[0], 'class_response_maps'):
                # apply the classifier weights as a 1x1 convolution to turn the
                # cached feature maps into per-class response maps
                module.parent_modules[0].class_response_maps = F.conv2d(
                    module.parent_modules[0].class_response_maps,
                    module.weight.unsqueeze(-1).unsqueeze(-1))
            else:
                raise RuntimeError('The SPN is broken, please recreate it.')
                
        sp_layer = None
        fc_layer = None
        for mod in model.modules():
            if isinstance(mod, SoftProposal):
                sp_layer = mod
            elif isinstance(mod, torch.nn.Linear):
                fc_layer = mod
        
        if sp_layer is None or fc_layer is None:
            raise RuntimeError('Invalid SPN model')
        else:
            sp_layer.parent_modules = [model]
            fc_layer.parent_modules = [model]
            model.sp_hook = sp_layer.register_forward_hook(_sp_hook)
            model.fc_hook = fc_layer.register_forward_hook(_fc_hook)
    
    return model

def unhook_spn(model):
    """Remove the hooks installed by hook_spn and restore the original training mode."""
    try:
        model.sp_hook.remove()
        model.fc_hook.remove()
        del model.sp_hook
        del model.fc_hook
        model.train(model._training)
        return model
    except AttributeError:
        raise RuntimeError('The model has not been hooked!')
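
# A minimal usage sketch (the variable names below are illustrative, not part of
# this module): hook the SPN once, run a forward pass to populate the cached
# response maps, then unhook to restore the original training mode.
#
#   model = hook_spn(spn_model)          # spn_model: any model containing a SoftProposal layer
#   scores = model(image.unsqueeze(0))   # forward pass fills model.class_response_maps
#   maps = model.class_response_maps     # tensor of shape (1, num_classes, H, W)
#   unhook_spn(model)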

def compute_iou(box_a, box_b):
    """Intersection over union of two (xmin, ymin, xmax, ymax) boxes in inclusive pixel coordinates."""
    x_a = max(box_a[0], box_b[0])
    y_a = max(box_a[1], box_b[1])
    x_b = min(box_a[2], box_b[2])
    y_b = min(box_a[3], box_b[3])
    inter_area = max(x_b - x_a + 1, 0) * max(y_b - y_a + 1, 0)
    box_a_area = (box_a[2] - box_a[0] + 1) * (box_a[3] - box_a[1] + 1)
    box_b_area = (box_b[2] - box_b[0] + 1) * (box_b[3] - box_b[1] + 1)
    return inter_area / float(box_a_area + box_b_area - inter_area)
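
# Worked example: in this inclusive-coordinate convention the boxes (0, 0, 9, 9)
# and (5, 5, 14, 14) are both 10x10 and overlap in a 5x5 region, so
# compute_iou((0, 0, 9, 9), (5, 5, 14, 14)) == 25 / (100 + 100 - 25) ≈ 0.143.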

def bbox_nms(bbox_list, threshold=0.5):
    """Greedy per-class non-maximum suppression over (class_idx, xmin, ymin, xmax, ymax, score) tuples."""
    bbox_list = sorted(bbox_list, key=lambda x: x[-1], reverse=True)
    selected_bboxes = []
    while len(bbox_list) > 0:
        # keep the highest-scoring remaining box
        obj = bbox_list.pop(0)
        selected_bboxes.append(obj)
        def iou_filter(x):
            # filter() drops the None entries, i.e. same-class boxes that overlap
            # the kept box by at least `threshold`
            iou = compute_iou(obj[1:5], x[1:5])
            if x[0] == obj[0] and iou >= threshold:
                return None
            else:
                return x
        bbox_list = list(filter(iou_filter, bbox_list))
    return selected_bboxes

def gen_filter(bbox_threshold=(0., 50)):
    """Return a predicate that keeps boxes scoring above bbox_threshold[0] with width and height of at least bbox_threshold[1] pixels."""
    def _filter(x):
        xmin, ymin, xmax, ymax = x[1:5]
        w, h = (xmax - xmin), (ymax - ymin)
        if x[-1] > bbox_threshold[0] and w >= bbox_threshold[1] and h >= bbox_threshold[1]:
            return x
        else:
            return None
    return _filter

def extract_bbox_from_map(input):
    """Tight (xmin, ymin, xmax, ymax) bounding box around the non-zero region of a 2-D map."""
    assert input.ndim == 2, 'Invalid input shape'
    rows = np.any(input, axis=1)
    cols = np.any(input, axis=0)
    ymin, ymax = np.where(rows)[0][[0, -1]]
    xmin, xmax = np.where(cols)[0][[0, -1]]
    return xmin, ymin, xmax, ymax

def extract_point_from_map(input):
    """(x, y) coordinates of the maximum of a 2-D map."""
    assert input.ndim == 2, 'Invalid input shape'
    cols = input.shape[1]
    index = np.argmax(input)
    return index % cols, index // cols

def localize_from_map(class_response_map, class_idx=0, location_type='bbox', threshold_ratio=1, multi_objects=True):
    """Turn a 2-D class response map into a list of (class_idx, location..., score) tuples."""
    assert location_type in ('bbox', 'point'), 'Unknown location type'
    # pixels whose response exceeds threshold_ratio * mean are treated as foreground
    foreground_map = class_response_map >= (class_response_map.mean() * threshold_ratio)
    if multi_objects:
        # one prediction per connected component of the foreground
        objects, count = label(foreground_map)
        res = []
        for obj_idx in range(count):
            obj = objects == (obj_idx + 1)
            if location_type == 'bbox':
                score = class_response_map[obj].mean()
                extraction = extract_bbox_from_map
            elif location_type == 'point':
                obj = class_response_map * obj.astype(float)
                score = np.max(obj)
                extraction = extract_point_from_map
            res.append((class_idx,) + extraction(obj) + (score,))
        return res
    else:
        if location_type == 'bbox':
            return [(class_idx,) + extract_bbox_from_map(foreground_map) + (class_response_map.mean(),), ]
        elif location_type == 'point':
            return [(class_idx,) + extract_point_from_map(class_response_map) + (class_response_map.max(),), ]

def object_localization(models, input, **kwargs):
    """Run (multi-scale) SPN inference on a single image and return localized objects with their predicted labels."""
    # multi-scale detection
    scales = sorted(kwargs.pop('scales', models.keys()), reverse=True)
    # localize with/without prediction
    pred_labels = kwargs.pop('gt_labels', None)
    # classification threshold
    threshold = kwargs.pop('threshold', 0)
    # switch to inference mode
    force_inference = kwargs.pop('force_inference', True)
    # NMS threshold
    nms_threshold = kwargs.pop('nms_threshold', 0.)
    # type of localization
    location_type = kwargs.get('location_type', 'bbox')

    assert len(models) == 1 or set(scales).issubset(set(models.keys())), 'Invalid scales'
    
    if input.ndimension() == 3: 
        input = input.unsqueeze(0)
    assert input.size(0) == 1, 'Batch processing is currently not supported'

    # enable spn inference mode
    if force_inference: 
        models = {k:hook_spn(v) for k, v in models.items()}

    # localize objects
    predictions = []
    for size in scales:
        model = models[size] if len(models) > 1 else next(iter(models.values()))
        # F.upsample is deprecated; F.interpolate is the current equivalent
        class_scores = model(F.interpolate(input, size=(size, size), mode='bilinear', align_corners=False))
        if pred_labels is None:
            # keep every class whose score exceeds the classification threshold
            pred_labels = torch.nonzero(class_scores.data.squeeze() > threshold).flatten()
        for class_idx in pred_labels:
            kwargs['class_idx'] = class_idx
            # resize the class response map back to the input resolution
            class_response_map = F.interpolate(model.class_response_maps[0, class_idx].unsqueeze(0).unsqueeze(0),
                                               size=(input.size(2), input.size(3)), mode='bilinear', align_corners=False)
            predictions += localize_from_map(class_response_map.squeeze().data.cpu().numpy(), **kwargs)
    
    # non-maximum suppression across scales
    if location_type == 'bbox' and len(models) > 1:
        predictions = list(filter(gen_filter(), bbox_nms(predictions, nms_threshold))) 

    return predictions, pred_labels
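
# A minimal end-to-end sketch, assuming two hypothetical SPN models trained at
# input sizes 224 and 448 (the names, sizes and thresholds are illustrative only):
#
#   models = {224: spn_model_224, 448: spn_model_448}
#   input = transform(image)                  # 3 x H x W tensor for a single image
#   predictions, pred_labels = object_localization(
#       models, input, threshold=0, nms_threshold=0.5, location_type='bbox')
#   # each prediction is a (class_idx, xmin, ymin, xmax, ymax, score) tuple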