import os
from typing import Union, Tuple, List, NamedTuple

import torch
from torch import nn, Tensor
from torch.nn import functional as F
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler

from backbone.base import Base as BackboneBase
from bbox import BBox
from nms.nms import NMS
from roi.wrapper import Wrapper as ROIWrapper
from rpn.region_proposal_network import RegionProposalNetwork


class Model(nn.Module):
    """FPN-style two-stage detector: backbone + feature pyramid + RPN + ROI detection head.

    `forward` processes a single image per call. In training mode it returns the four
    losses (`ForwardOutput.Train`); in eval mode it returns per-class NMS-filtered
    detections (`ForwardOutput.Eval`).
    """

    class ForwardInput(object):
        class Train(NamedTuple):
            # Single image without a batch dimension; `forward` adds one with `unsqueeze(dim=0)`
            # and reads dims 2/3 of the result as height/width.
            image: Tensor
            # Label per ground-truth box; `Detection.sample` counts labels > 0 as foreground.
            gt_classes: Tensor
            # Ground-truth boxes; presumably (x1, y1, x2, y2) like the proposals -- TODO confirm.
            gt_bboxes: Tensor

        class Eval(NamedTuple):
            image: Tensor

    class ForwardOutput(object):
        class Train(NamedTuple):
            anchor_objectness_loss: Tensor
            anchor_transformer_loss: Tensor
            proposal_class_loss: Tensor
            proposal_transformer_loss: Tensor

        class Eval(NamedTuple):
            detection_bboxes: Tensor
            detection_classes: Tensor
            detection_probs: Tensor

    def __init__(self, backbone: BackboneBase, num_classes: int, pooling_mode: ROIWrapper.Mode,
                 anchor_ratios: List[Tuple[int, int]], anchor_scales: List[int], rpn_pre_nms_top_n: int, rpn_post_nms_top_n: int):
        """Assemble the detector from the layers supplied by `backbone`.

        Args:
            backbone: object whose `features()` returns
                (conv_layers, lateral_layers, dealiasing_layers, num_features_out).
            num_classes: number of classes including the background class at index 0.
            pooling_mode: ROI pooling mode forwarded to `ROIWrapper.apply`.
            anchor_ratios, anchor_scales, rpn_pre_nms_top_n, rpn_post_nms_top_n:
                forwarded unchanged to `RegionProposalNetwork`.
        """
        super().__init__()

        conv_layers, lateral_layers, dealiasing_layers, num_features_out = backbone.features()
        self.conv1, self.conv2, self.conv3, self.conv4, self.conv5 = conv_layers
        self.lateral_c2, self.lateral_c3, self.lateral_c4, self.lateral_c5 = lateral_layers
        self.dealiasing_p2, self.dealiasing_p3, self.dealiasing_p4 = dealiasing_layers

        # Every BatchNorm2d inside the backbone / FPN layers, so `forward` can keep them frozen.
        # A plain list (not nn.ModuleList) on purpose: the modules are already registered via the
        # layers above, so this does not add state_dict entries.
        bn_sources = [self.conv1, self.conv2, self.conv3, self.conv4, self.conv5,
                      self.lateral_c2, self.lateral_c3, self.lateral_c4, self.lateral_c5,
                      self.dealiasing_p2, self.dealiasing_p3, self.dealiasing_p4]
        self._bn_modules = [it for source in bn_sources for it in source.modules()
                            if isinstance(it, nn.BatchNorm2d)]

        self.num_classes = num_classes

        self.rpn = RegionProposalNetwork(num_features_out, anchor_ratios, anchor_scales, rpn_pre_nms_top_n, rpn_post_nms_top_n)
        self.detection = Model.Detection(pooling_mode, self.num_classes)

    def forward(self, forward_input: Union[ForwardInput.Train, ForwardInput.Eval]) -> Union[ForwardOutput.Train, ForwardOutput.Eval]:
        """Run the detector on one image.

        Returns `ForwardOutput.Train` (RPN and detection-head losses) when `self.training`,
        otherwise `ForwardOutput.Eval` (detections after per-class NMS).
        """
        # Freeze batch normalization on every call, in case the model was switched to `train` at any time.
        for bn_module in self._bn_modules:
            bn_module.eval()
            for parameter in bn_module.parameters():
                parameter.requires_grad = False

        image = forward_input.image.unsqueeze(dim=0)
        image_height, image_width = image.shape[2], image.shape[3]

        p2, p3, p4, p5, p6 = self._build_pyramid(image)
        anchor_objectnesses, anchor_transformers, anchor_bboxes, proposal_bboxes = \
            self._propose_regions((p2, p3, p4, p5, p6), image_width, image_height)

        if self.training:
            forward_input: Model.ForwardInput.Train

            anchor_sample_fg_indices, anchor_sample_selected_indices, gt_anchor_objectnesses, gt_anchor_transformers = \
                self.rpn.sample(anchor_bboxes, forward_input.gt_bboxes, image_width, image_height)
            anchor_objectnesses = anchor_objectnesses[anchor_sample_selected_indices]
            anchor_transformers = anchor_transformers[anchor_sample_fg_indices]
            anchor_objectness_loss, anchor_transformer_loss = self.rpn.loss(anchor_objectnesses, anchor_transformers,
                                                                            gt_anchor_objectnesses, gt_anchor_transformers)

            _, proposal_sample_selected_indices, gt_proposal_classes, gt_proposal_transformers = \
                self.detection.sample(proposal_bboxes, forward_input.gt_classes, forward_input.gt_bboxes)
            proposal_bboxes = proposal_bboxes[proposal_sample_selected_indices]
            proposal_classes, proposal_transformers = self.detection(p2, p3, p4, p5, proposal_bboxes, image_width, image_height)
            proposal_class_loss, proposal_transformer_loss = self.detection.loss(proposal_classes, proposal_transformers,
                                                                                 gt_proposal_classes, gt_proposal_transformers)

            forward_output = Model.ForwardOutput.Train(anchor_objectness_loss, anchor_transformer_loss, proposal_class_loss, proposal_transformer_loss)
        else:
            proposal_classes, proposal_transformers = self.detection(p2, p3, p4, p5, proposal_bboxes, image_width, image_height)
            detection_bboxes, detection_classes, detection_probs = self.detection.generate_detections(
                proposal_bboxes, proposal_classes, proposal_transformers, image_width, image_height)
            forward_output = Model.ForwardOutput.Eval(detection_bboxes, detection_classes, detection_probs)

        return forward_output

    def _build_pyramid(self, image: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
        """Compute feature-pyramid levels (p2, p3, p4, p5, p6) for a batched image."""
        # Bottom-up pathway
        c1 = self.conv1(image)
        c2 = self.conv2(c1)
        c3 = self.conv3(c2)
        c4 = self.conv4(c3)
        c5 = self.conv5(c4)

        # Top-down pathway and lateral connections: upsample the coarser level to the finer level's size and add.
        p5 = self.lateral_c5(c5)
        p4 = self.lateral_c4(c4) + F.interpolate(input=p5, size=(c4.shape[2], c4.shape[3]), mode='nearest')
        p3 = self.lateral_c3(c3) + F.interpolate(input=p4, size=(c3.shape[2], c3.shape[3]), mode='nearest')
        p2 = self.lateral_c2(c2) + F.interpolate(input=p3, size=(c2.shape[2], c2.shape[3]), mode='nearest')

        # Reduce the aliasing effect introduced by the upsampling
        p4 = self.dealiasing_p4(p4)
        p3 = self.dealiasing_p3(p3)
        p2 = self.dealiasing_p2(p2)

        # P6 is P5 subsampled with stride 2 (kernel 1 => pure subsampling, no pooling window).
        p6 = F.max_pool2d(input=p5, kernel_size=1, stride=2)

        return p2, p3, p4, p5, p6

    def _propose_regions(self, pyramid: Tuple[Tensor, ...], image_width: int, image_height: int) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
        """Run the RPN on every pyramid level; return the per-level outputs concatenated along dim 0.

        Returns (anchor_objectnesses, anchor_transformers, anchor_bboxes, proposal_bboxes).
        """
        # NOTE: We define the anchors to have areas of {32^2, 64^2, 128^2, 256^2, 512^2} pixels on {P2, P3, P4, P5, P6} respectively
        anchor_sizes = (32, 64, 128, 256, 512)

        anchor_objectnesses = []
        anchor_transformers = []
        anchor_bboxes = []
        proposal_bboxes = []

        for p, anchor_size in zip(pyramid, anchor_sizes):
            p_anchor_objectnesses, p_anchor_transformers = self.rpn.forward(features=p, image_width=image_width, image_height=image_height)
            p_anchor_bboxes = self.rpn.generate_anchors(image_width, image_height,
                                                        num_x_anchors=p.shape[3], num_y_anchors=p.shape[2],
                                                        anchor_size=anchor_size).cuda()
            p_proposal_bboxes = self.rpn.generate_proposals(p_anchor_bboxes, p_anchor_objectnesses, p_anchor_transformers,
                                                            image_width, image_height)
            anchor_objectnesses.append(p_anchor_objectnesses)
            anchor_transformers.append(p_anchor_transformers)
            anchor_bboxes.append(p_anchor_bboxes)
            proposal_bboxes.append(p_proposal_bboxes)

        return (torch.cat(anchor_objectnesses, dim=0),
                torch.cat(anchor_transformers, dim=0),
                torch.cat(anchor_bboxes, dim=0),
                torch.cat(proposal_bboxes, dim=0))

    def save(self, path_to_checkpoints_dir: str, step: int, optimizer: Optimizer, scheduler: _LRScheduler) -> str:
        """Write model, optimizer and scheduler state to `<dir>/model-<step>.pth` and return that path."""
        path_to_checkpoint = os.path.join(path_to_checkpoints_dir, f'model-{step}.pth')
        checkpoint = {
            'state_dict': self.state_dict(),
            'step': step,
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict()
        }
        torch.save(checkpoint, path_to_checkpoint)
        return path_to_checkpoint

    def load(self, path_to_checkpoint: str, optimizer: Optimizer = None, scheduler: _LRScheduler = None) -> int:
        """Restore state written by `save`; optionally restore optimizer/scheduler too.

        Returns:
            The training step stored in the checkpoint.
        """
        # NOTE(security): torch.load unpickles the file -- only load checkpoints from trusted sources.
        checkpoint = torch.load(path_to_checkpoint)
        self.load_state_dict(checkpoint['state_dict'])
        step = checkpoint['step']
        if optimizer is not None:
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        if scheduler is not None:
            scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        return step

    class Detection(nn.Module):
        """Second-stage head: ROI-pools proposals from P2-P5, then predicts class scores and
        per-class box transformers (num_classes * 4 values per proposal)."""

        def __init__(self, pooling_mode: ROIWrapper.Mode, num_classes: int):
            super().__init__()
            self._pooling_mode = pooling_mode
            # Input size 256 * 7 * 7: each flattened pooled ROI must have this many elements
            # (presumably 256 channels x 7 x 7 -- TODO confirm against ROIWrapper).
            self._hidden = nn.Sequential(
                nn.Linear(256 * 7 * 7, 1024),
                nn.ReLU(),
                nn.Linear(1024, 1024),
                nn.ReLU()
            )
            self.num_classes = num_classes
            self._class = nn.Linear(1024, num_classes)
            self._transformer = nn.Linear(1024, num_classes * 4)
            # Regression targets are normalized as (t - mean) / std in `sample`;
            # `generate_detections` applies the inverse t * std + mean.
            self._transformer_normalize_mean = torch.tensor([0., 0., 0., 0.], dtype=torch.float).cuda()
            self._transformer_normalize_std = torch.tensor([.1, .1, .2, .2], dtype=torch.float).cuda()

        def forward(self, p2: Tensor, p3: Tensor, p4: Tensor, p5: Tensor, proposal_bboxes: Tensor, image_width: int, image_height: int) -> Tuple[Tensor, Tensor]:
            """Pool each proposal from its assigned pyramid level and run the head.

            Returns (classes, transformers) with one row per proposal, in the same order as
            `proposal_bboxes`.
            """
            w = proposal_bboxes[:, 2] - proposal_bboxes[:, 0]
            h = proposal_bboxes[:, 3] - proposal_bboxes[:, 1]
            # Level assignment k = floor(k0 + log2(sqrt(w * h) / 224)), clamped to [2, 5].
            # Clamp while still floating point: a zero-area box yields log2(0) = -inf, and casting
            # -inf to an integer is undefined; clamping first maps it to level 2.
            k0 = 4
            k = torch.floor(k0 + torch.log2(torch.sqrt(w * h) / 224))
            k = torch.clamp(k, min=2, max=5).long()

            k_to_p_dict = {2: p2, 3: p3, 4: p4, 5: p5}
            unique_k = torch.unique(k)

            # NOTE: `picked_indices` is for recording the order of selection from `proposal_bboxes`
            #       so that `pools` can be then restored to make it have a consistent correspondence
            #       with `proposal_bboxes`. For example:
            #
            #           proposal_bboxes =>  B0  B1  B2
            #            picked_indices =>   1   2   0
            #                     pools => BP1 BP2 BP0
            #            sorted_indices =>   2   0   1
            #                     pools => BP0 BP1 BP2

            pools = []
            picked_indices = []

            for uk in unique_k:
                uk = uk.item()
                p = k_to_p_dict[uk]
                uk_indices = (k == uk).nonzero().view(-1)
                uk_proposal_bboxes = proposal_bboxes[uk_indices]
                pool = ROIWrapper.apply(p, uk_proposal_bboxes, mode=self._pooling_mode, image_width=image_width, image_height=image_height)
                pools.append(pool)
                picked_indices.append(uk_indices)

            pools = torch.cat(pools, dim=0)
            picked_indices = torch.cat(picked_indices, dim=0)

            _, sorted_indices = torch.sort(picked_indices)
            pools = pools[sorted_indices]

            pools = pools.view(pools.shape[0], -1)
            hidden = self._hidden(pools)
            classes = self._class(hidden)
            transformers = self._transformer(hidden)
            return classes, transformers

        def sample(self, proposal_bboxes: Tensor, gt_classes: Tensor, gt_bboxes: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
            """Label proposals by IoU with ground truth and draw up to 128 training samples.

            A proposal whose best IoU is >= 0.5 takes that ground truth's class; otherwise it is
            background (0). At most 32 foreground samples are drawn; the rest is background.

            Returns:
                (fg_indices, selected_indices, gt_classes, gt_transformers). Both index tensors
                index into `proposal_bboxes`; the two target tensors align with `selected_indices`.
                Transformer targets are normalized by the head's mean/std.
            """
            sample_fg_indices = torch.arange(end=len(proposal_bboxes), dtype=torch.long)
            sample_selected_indices = torch.arange(end=len(proposal_bboxes), dtype=torch.long)

            # find labels for each `proposal_bboxes`
            labels = torch.ones(len(proposal_bboxes), dtype=torch.long).cuda() * -1
            ious = BBox.iou(proposal_bboxes, gt_bboxes)
            proposal_max_ious, proposal_assignments = ious.max(dim=1)
            labels[proposal_max_ious < 0.5] = 0
            labels[proposal_max_ious >= 0.5] = gt_classes[proposal_assignments[proposal_max_ious >= 0.5]]

            # select 128 samples
            fg_indices = (labels > 0).nonzero().view(-1)
            bg_indices = (labels == 0).nonzero().view(-1)
            fg_indices = fg_indices[torch.randperm(len(fg_indices))[:min(len(fg_indices), 32)]]
            bg_indices = bg_indices[torch.randperm(len(bg_indices))[:128 - len(fg_indices)]]
            selected_indices = torch.cat([fg_indices, bg_indices])
            selected_indices = selected_indices[torch.randperm(len(selected_indices))]

            proposal_bboxes = proposal_bboxes[selected_indices]
            gt_proposal_transformers = BBox.calc_transformer(proposal_bboxes, gt_bboxes[proposal_assignments[selected_indices]])
            gt_proposal_classes = labels[selected_indices]

            gt_proposal_transformers = (gt_proposal_transformers - self._transformer_normalize_mean) / self._transformer_normalize_std

            gt_proposal_transformers = gt_proposal_transformers.cuda()
            gt_proposal_classes = gt_proposal_classes.cuda()

            sample_fg_indices = sample_fg_indices[fg_indices]
            sample_selected_indices = sample_selected_indices[selected_indices]

            return sample_fg_indices, sample_selected_indices, gt_proposal_classes, gt_proposal_transformers

        def loss(self, proposal_classes: Tensor, proposal_transformers: Tensor, gt_proposal_classes: Tensor, gt_proposal_transformers: Tensor) -> Tuple[Tensor, Tensor]:
            """Return (classification cross-entropy, box smooth-L1 loss) for the sampled proposals.

            The box loss uses only the transformer predicted for each proposal's ground-truth
            class and only foreground proposals (class != 0).
            """
            cross_entropy = F.cross_entropy(input=proposal_classes, target=gt_proposal_classes)

            proposal_transformers = proposal_transformers.view(-1, self.num_classes, 4)
            proposal_transformers = proposal_transformers[torch.arange(end=len(proposal_transformers), dtype=torch.long).cuda(), gt_proposal_classes]

            fg_indices = gt_proposal_classes.nonzero().view(-1)

            # NOTE: The default `reduction` (`elementwise_mean`, renamed `mean` in later PyTorch) divides by N x 4
            #       (number of all elements); here we divide by N, the number of sampled proposals
            #       (foreground and background), which the original authors found to perform better.
            smooth_l1_loss = F.smooth_l1_loss(input=proposal_transformers[fg_indices], target=gt_proposal_transformers[fg_indices], reduction='sum')
            smooth_l1_loss /= len(gt_proposal_transformers)

            return cross_entropy, smooth_l1_loss

        def generate_detections(self, proposal_bboxes: Tensor, proposal_classes: Tensor, proposal_transformers: Tensor, image_width: int, image_height: int) -> Tuple[Tensor, Tensor, Tensor]:
            """Decode per-class boxes, clip them to the image, and apply per-class NMS (IoU 0.3).

            Class 0 (background) is skipped. Returns (bboxes, classes, probs) concatenated over
            classes 1..num_classes-1; boxes and probabilities are moved to CPU before NMS.
            """
            proposal_transformers = proposal_transformers.view(-1, self.num_classes, 4)
            mean = self._transformer_normalize_mean.repeat(1, self.num_classes, 1)
            std = self._transformer_normalize_std.repeat(1, self.num_classes, 1)

            # Undo the target normalization `(t - mean) / std` applied in `sample`.
            proposal_transformers = proposal_transformers * std + mean
            proposal_bboxes = proposal_bboxes.view(-1, 1, 4).repeat(1, self.num_classes, 1)
            detection_bboxes = BBox.apply_transformer(proposal_bboxes.view(-1, 4), proposal_transformers.view(-1, 4))

            detection_bboxes = detection_bboxes.view(-1, self.num_classes, 4)

            detection_bboxes[:, :, [0, 2]] = detection_bboxes[:, :, [0, 2]].clamp(min=0, max=image_width)
            detection_bboxes[:, :, [1, 3]] = detection_bboxes[:, :, [1, 3]].clamp(min=0, max=image_height)

            proposal_probs = F.softmax(proposal_classes, dim=1)

            detection_bboxes = detection_bboxes.cpu()
            proposal_probs = proposal_probs.cpu()

            generated_bboxes = []
            generated_classes = []
            generated_probs = []

            for c in range(1, self.num_classes):
                detection_class_bboxes = detection_bboxes[:, c, :]
                proposal_class_probs = proposal_probs[:, c]

                # NMS expects boxes ordered by descending score.
                _, sorted_indices = proposal_class_probs.sort(descending=True)
                detection_class_bboxes = detection_class_bboxes[sorted_indices]
                proposal_class_probs = proposal_class_probs[sorted_indices]

                kept_indices = NMS.suppress(detection_class_bboxes.cuda(), threshold=0.3)
                detection_class_bboxes = detection_class_bboxes[kept_indices]
                proposal_class_probs = proposal_class_probs[kept_indices]

                generated_bboxes.append(detection_class_bboxes)
                generated_classes.append(torch.ones(len(kept_indices), dtype=torch.int) * c)
                generated_probs.append(proposal_class_probs)

            generated_bboxes = torch.cat(generated_bboxes, dim=0)
            generated_classes = torch.cat(generated_classes, dim=0)
            generated_probs = torch.cat(generated_probs, dim=0)
            return generated_bboxes, generated_classes, generated_probs