import math
import torch
import torch.nn.functional as F
from torch import nn

from .inference import make_fcos_postprocessor
from .loss import make_fcos_loss_evaluator

from maskrcnn_benchmark.layers import Scale


class CascadeFCOSHead(torch.nn.Module):
    """
    FCOS prediction head with one classification branch per cascade area
    threshold.

    A shared classification tower feeds one ``cls_logits_{N}%`` conv per entry
    of ``cfg.MODEL.FCOS.CASCADE_AREA_TH`` (plus the optional centerness conv);
    a separate box tower feeds the 4-channel box-regression conv.
    """

    def __init__(self, cfg, in_channels):
        """
        Arguments:
            cfg: config node; reads ``cfg.MODEL.FCOS.NUM_CLASSES``,
                ``CASCADE_AREA_TH``, ``CASCADE_NO_CENTERNESS``, ``NUM_CONVS``
                and ``PRIOR_PROB``.
            in_channels (int): number of channels of the input feature

        Raises:
            ValueError: if two entries of ``CASCADE_AREA_TH`` produce the same
                ``cls_logits_{N}%`` module name.
        """
        super(CascadeFCOSHead, self).__init__()
        # TODO: Implement the sigmoid version first.
        num_classes = cfg.MODEL.FCOS.NUM_CLASSES - 1
        cascade_area_th = cfg.MODEL.FCOS.CASCADE_AREA_TH
        self.no_centerness = no_centerness = cfg.MODEL.FCOS.CASCADE_NO_CENTERNESS

        cls_tower = []
        bbox_tower = []
        for _ in range(cfg.MODEL.FCOS.NUM_CONVS):
            cls_tower.append(
                nn.Conv2d(
                    in_channels,
                    in_channels,
                    kernel_size=3,
                    stride=1,
                    padding=1
                )
            )
            cls_tower.append(nn.GroupNorm(32, in_channels))
            cls_tower.append(nn.ReLU())
            bbox_tower.append(
                nn.Conv2d(
                    in_channels,
                    in_channels,
                    kernel_size=3,
                    stride=1,
                    padding=1
                )
            )
            bbox_tower.append(nn.GroupNorm(32, in_channels))
            bbox_tower.append(nn.ReLU())

        self.add_module('cls_tower', nn.Sequential(*cls_tower))
        self.add_module('bbox_tower', nn.Sequential(*bbox_tower))
        self.cls_logits_set = nn.ModuleDict()
        for area_th in cascade_area_th:
            # NOTE: int() truncates (0.29 * 100 == 28.999... -> "28%"). The
            # naming is left unchanged because these names are also the
            # state_dict keys of existing checkpoints.
            name = "cls_logits_{}%".format(int(area_th * 100))
            if name in self.cls_logits_set:
                # add_module would otherwise silently replace the earlier head.
                raise ValueError(
                    "CASCADE_AREA_TH value {} maps to duplicate head name "
                    "'{}'".format(area_th, name)
                )
            self.cls_logits_set.add_module(name, nn.Conv2d(
                in_channels, num_classes, kernel_size=3, stride=1, padding=1
            ))
        self.bbox_pred = nn.Conv2d(
            in_channels, 4, kernel_size=3, stride=1,
            padding=1
        )
        if not no_centerness:
            self.centerness = nn.Conv2d(
                in_channels, 1, kernel_size=3, stride=1,
                padding=1
            )

        # initialization: every conv gets N(0, 0.01) weights and zero bias.
        init_modules = [self.cls_tower, self.bbox_tower, self.bbox_pred]
        init_modules.extend(self.cls_logits_set.values())
        if not no_centerness:
            # The centerness conv only exists when enabled; it previously kept
            # PyTorch's default init. Appended last so the init order of the
            # other convs is unchanged.
            init_modules.append(self.centerness)
        for modules in init_modules:
            for l in modules.modules():
                if isinstance(l, nn.Conv2d):
                    torch.nn.init.normal_(l.weight, std=0.01)
                    torch.nn.init.constant_(l.bias, 0)

        # initialize the bias for focal loss so that the initial sigmoid
        # output of every classification conv equals PRIOR_PROB
        prior_prob = cfg.MODEL.FCOS.PRIOR_PROB
        bias_value = -math.log((1 - prior_prob) / prior_prob)
        for m in self.cls_logits_set.values():
            torch.nn.init.constant_(m.bias, bias_value)

        # One learnable regression scale per feature level (5 levels).
        self.scales = nn.ModuleList([Scale(init_value=1.0) for _ in range(5)])

    def forward(self, x):
        """
        Arguments:
            x (list[Tensor]): feature maps, one per level (at most 5, matching
                ``self.scales``).

        Returns:
            logits_set (dict[str, list[Tensor]]): classification logits per
                cascade head, keyed like ``cls_logits_set``, one tensor per level.
            bbox_reg (list[Tensor]): per-level ``exp(scale * bbox_pred(...))``.
            centerness (list[Tensor] or None): per-level centerness logits, or
                None when centerness is disabled.
        """
        logits_set = {name: [] for name in self.cls_logits_set}
        bbox_reg = []
        centerness = []
        for l, feature in enumerate(x):
            cls_tower = self.cls_tower(feature)
            for name, cls_logits in self.cls_logits_set.items():
                logits_set[name].append(cls_logits(cls_tower))
            if not self.no_centerness:
                centerness.append(self.centerness(cls_tower))
            bbox_reg.append(torch.exp(self.scales[l](
                self.bbox_pred(self.bbox_tower(feature))
            )))
        if len(centerness) == 0:
            centerness = None
        return logits_set, bbox_reg, centerness


class CascadeFCOSModule(torch.nn.Module):
    """
    FCOS computation built on ``CascadeFCOSHead``.

    Runs the head on the backbone feature maps and computes the image-plane
    location of every feature-map cell. It then returns either the training
    losses or the post-processed detections. Only tested with FPN features so
    far.
    """

    def __init__(self, cfg, in_channels):
        super(CascadeFCOSModule, self).__init__()
        self.head = CascadeFCOSHead(cfg, in_channels)
        self.box_selector_test = make_fcos_postprocessor(cfg)
        self.loss_evaluator = make_fcos_loss_evaluator(cfg)
        fcos_cfg = cfg.MODEL.FCOS
        self.fpn_strides = fcos_cfg.FPN_STRIDES
        # When set, forward() hands images, targets and predictions to show_image.
        self.vis_labels = fcos_cfg.DEBUG.VIS_LABELS

    def forward(self, images, features, targets=None):
        """
        Arguments:
            images (ImageList): images for which we want to compute the predictions
            features (list[Tensor]): features computed from the images, one
                tensor per feature level
            targets (list[BoxList]): ground-truth boxes present in the image (optional)

        Returns:
            boxes (list[BoxList] or None): the predicted boxes, one BoxList per
                image, at inference time; None during training.
            losses (dict[Tensor]): the losses during training; an empty dict
                at inference time.
        """
        box_cls_set, box_regression, centerness = self.head(features)
        locations = self.compute_locations(features)

        if not self.training:
            outputs = self._forward_test(
                locations, box_cls_set, box_regression, centerness,
                images.image_sizes, images=images, targets=targets,
            )
        else:
            outputs = self._forward_train(
                locations, box_cls_set, box_regression, centerness, targets
            )

        if self.vis_labels:
            # Debug plot; outputs[0] is None while training.
            show_image(images, targets, outputs[0])
        return outputs

    def _forward_train(self, locations, box_cls_set, box_regression, centerness, targets):
        """Return ``(None, losses)``; the centerness loss is added only when it is a tensor."""
        cls_loss, reg_loss, ctr_loss = self.loss_evaluator(
            locations, box_cls_set, box_regression, centerness, targets
        )
        losses = dict(loss_cls=cls_loss, loss_reg=reg_loss)
        if isinstance(ctr_loss, torch.Tensor):
            losses["loss_centerness"] = ctr_loss
        return None, losses

    def _forward_test(self, locations, box_cls_set, box_regression, centerness, image_sizes, **kwargs):
        """Return ``(detections, {})`` produced by the test-time box selector."""
        detections = self.box_selector_test(
            locations, box_cls_set, box_regression, centerness, image_sizes, **kwargs
        )
        return detections, {}

    def compute_locations(self, features):
        """Return, for each feature level, the ``(x, y)`` locations of its cells."""
        per_level = []
        for level, feature in enumerate(features):
            height, width = feature.shape[-2:]
            stride = self.fpn_strides[level]
            per_level.append(
                self.compute_locations_per_level(height, width, stride, feature.device)
            )
        return per_level

    def compute_locations_per_level(self, h, w, stride, device):
        """
        Return an ``(h * w, 2)`` float32 tensor of ``(x, y)`` coordinates.

        Coordinates are multiples of ``stride`` offset by ``stride // 2``,
        listed row by row (x varies fastest).
        """
        xs = torch.arange(0, w * stride, step=stride, dtype=torch.float32, device=device)
        ys = torch.arange(0, h * stride, step=stride, dtype=torch.float32, device=device)
        grid_y, grid_x = torch.meshgrid(ys, xs)
        points = torch.stack((grid_x.reshape(-1), grid_y.reshape(-1)), dim=1)
        return points + stride // 2


def build_cascade_fcos(cfg, in_channels):
    """
    Factory for the cascade FCOS module.

    Arguments:
        cfg: config node, forwarded unchanged to ``CascadeFCOSModule``.
        in_channels (int): number of channels of the input feature maps.

    Returns:
        CascadeFCOSModule: the constructed module.
    """
    module = CascadeFCOSModule(cfg, in_channels)
    return module


class ResultShower(object):
    """
    Debug visualiser for images, ground-truth boxes and detections.

        1. plot image
        2. plot list of bboxes, bboxes can be ground-truth or detection results
        3. show score text for detection result
        4. show detection location as a point, score as point size

    Each non-skipped call shows the figure and then blocks on ``input()``.
    """

    def __init__(self, image_mean=None, show_iter=1):
        """
        Arguments:
            image_mean (array-like or None): per-channel mean added back to the
                image tensor before display. ``None`` selects the default
                ``[102.9801, 115.9465, 122.7717]``.
            show_iter (int): only every ``show_iter``-th call actually plots.
        """
        import numpy as np

        if image_mean is None:
            # Fresh array per instance instead of one shared default object.
            image_mean = np.array([102.9801, 115.9465, 122.7717])
        self.score_th = None  # fixed score cut-off; None -> top-k rule below
        self.show_score_topk = 6  # number of best-scoring detections to draw
        self.image_mean = image_mean
        self.point_size = 100
        self.plot = self.plot2
        self.show_iter = show_iter
        self.counter = 0

    def __call__(self, images, *targets_list):
        """
        Plot the first image of ``images`` with the boxes of every targets list.

        Arguments:
            images (ImageList): batch whose first image is displayed.
            *targets_list (list[BoxList] or None): for each non-None entry the
                first BoxList is drawn. Entries without a ``scores`` field are
                drawn as ground truth; the others as score-filtered detections.
        """
        self.counter += 1
        if self.counter % self.show_iter != 0:
            return
        # Imported only when actually plotting, so skipped iterations stay cheap.
        import matplotlib.pyplot as plt
        import seaborn as sbn

        colors = sbn.color_palette(n_colors=len(targets_list))
        # CHW -> HWC, add the mean back, then reverse the channel order for display.
        img = images.tensors[0].permute((1, 2, 0)).cpu().numpy() + self.image_mean
        img = img[:, :, [2, 1, 0]]
        plt.imshow(img / 255)
        title = "boxes:"
        for ci, targets in enumerate(targets_list):
            if targets is None:
                continue
            boxlist = targets[0]
            fields = boxlist.extra_fields
            bboxes = boxlist.bbox.cpu().numpy().tolist()
            scores = fields['scores'].cpu() if 'scores' in fields else None
            locations = fields['det_locations'].cpu() if 'det_locations' in fields else None
            labels = fields['labels'].cpu()
            if scores is None:
                self.plot1(bboxes, scores, locations, labels, None, (1, 0, 0))  # ground-truth
                count = len(boxlist.bbox)
            else:
                score_th = self._score_threshold(scores)
                self.plot(bboxes, scores, locations, labels, score_th, colors[ci])
                # ">=" matches the filter used by plot1/plot2.
                count = int((scores >= score_th).sum())
            title += "{}({}) ".format(count, len(boxlist.bbox))
        plt.title(title)
        plt.show()
        input()

    def _score_threshold(self, scores):
        """
        Return the minimum score a detection needs to be drawn.

        Uses ``self.score_th`` when set. Otherwise returns the
        ``show_score_topk``-th highest score, with k clamped to the number of
        detections so ``torch.kthvalue`` does not fail on short lists. Returns
        ``inf`` (nothing drawn) when there are no scores.
        """
        if self.score_th is not None:
            return self.score_th
        k = min(self.show_score_topk, scores.numel())
        if k == 0:
            return float("inf")
        return -torch.kthvalue(-scores, k)[0]

    def plot2(self, bboxes, scores, locations, labels, score_th, color=None):
        """
        Draw boxes, linking each box to its location by color (no dashed lines).

        Different boxes get different colors; a box and its location point
        share the same color.

        Arguments:
            bboxes (list[list[float]]): boxes as ``[x1, y1, x2, y2]``.
            scores (Tensor or None): per-box scores. Box, label and score text
                are drawn only when ``scores[i] >= score_th``.
            locations (Tensor or None): per-box ``(x, y)`` points, drawn with a
                marker size proportional to the score.
            labels (Tensor): per-box labels.
            score_th: minimum score for a box to be drawn.
            color: unused; each box gets its own palette color.
        """
        import matplotlib.pyplot as plt
        import seaborn as sbn

        if scores is not None:
            # Sort by descending score. bboxes must be permuted together with
            # the other per-box arrays, otherwise box i would be drawn with the
            # score, label and location of a different detection.
            scores, idx = (-scores).sort()
            scores = -scores
            bboxes = [bboxes[j] for j in idx.tolist()]
            labels = labels[idx]
            if locations is not None:
                locations = locations[idx]

        colors = sbn.color_palette(n_colors=len(bboxes))
        for i, (x1, y1, x2, y2) in enumerate(bboxes):
            w = x2 - x1 + 1
            h = y2 - y1 + 1
            color = colors[i]
            if scores is not None:
                if scores[i] >= score_th:
                    plt.text(x1, y1, '{}:{:.2f}'.format(labels[i], scores[i]), color=(1, 0, 0))
                    rect = plt.Rectangle((x1, y1), w, h, fill=False, color=color, linewidth=1.5)
                    # gca(): plt.axes() adds a new Axes instead of the image's one.
                    plt.gca().add_patch(rect)
                if locations is not None:
                    lx, ly = locations[i]
                    plt.scatter(lx, ly, color=color, s=self.point_size * scores[i])
            else:
                plt.text(x2, y2, '{}'.format(labels[i]), color=(1, 0, 0))
                rect = plt.Rectangle((x1, y1), w, h, fill=False, color=color, linewidth=1.5)
                plt.gca().add_patch(rect)

        print(scores)
        print(labels)
        print(locations)

    def plot1(self, bboxes, scores, locations, labels, score_th, color):
        """
        Draw boxes in ``color``, using dashed lines to link each box to its location.

        With ``scores`` None (ground truth) every box is drawn. Otherwise only
        boxes with ``scores[i] >= score_th`` are drawn (with score text and
        dashed lines), while every location is drawn as a red point sized by
        its score. ``labels`` is accepted for signature parity and not used.
        """
        import matplotlib.pyplot as plt
        for i, (x1, y1, x2, y2) in enumerate(bboxes):
            w = x2 - x1 + 1
            h = y2 - y1 + 1
            if scores is not None:
                if scores[i] >= score_th:
                    plt.text(x1, y1, '{:.2f}'.format(scores[i]), color=(1, 0, 0))
                    rect = plt.Rectangle((x1, y1), w, h, fill=False, color=color, linewidth=1.5)
                    # gca(): plt.axes() adds a new Axes instead of the image's one.
                    plt.gca().add_patch(rect)
                    if locations is not None:
                        lx, ly = locations[i]
                        plt.plot([lx, lx, lx], [y2, ly, y1], '--', color=color)
                        plt.plot([x2, lx, x1], [ly, ly, ly], '--', color=color)
                if locations is not None:
                    lx, ly = locations[i]
                    plt.scatter(lx, ly, color='r', s=self.point_size * scores[i])
            else:
                rect = plt.Rectangle((x1, y1), w, h, fill=False, color=color, linewidth=1.5)
                plt.gca().add_patch(rect)


# Shared debug visualiser called by CascadeFCOSModule.forward when
# cfg.MODEL.FCOS.DEBUG.VIS_LABELS is enabled; show_iter=1 plots on every call.
show_image = ResultShower(show_iter=1)