import numpy as np
import torch
from torch import nn

from .utils import balance_sampling


class AvgMeter:
    """Running per-sample average of a scalar accumulated over a stream.

    `update` is fed a *summed* value covering `size` samples; the meter
    keeps the mean per sample across everything seen since the last reset.
    """

    def __init__(self):
        self.average = 0
        self.num_averaged = 0

    def update(self, loss, size):
        """Fold a summed value covering `size` samples into the mean."""
        seen = self.num_averaged
        total = seen + size
        self.average = (seen * self.average + float(loss)) / total
        self.num_averaged = total

    def reset(self):
        """Discard all accumulated statistics."""
        self.average = 0
        self.num_averaged = 0


class DetectionCriterion(nn.Module):
    """
    The loss for the Tiny Faces detector.

    Combines a per-template soft-margin classification loss with a
    Smooth-L1 box-regression loss. Labels go through online hard negative
    mining (easy examples dropped) and positive/negative balance sampling
    before the classification loss is masked and summed.
    """

    def __init__(self, n_templates=25, reg_weight=1, pos_fraction=0.5,
                 hard_neg_threshold=0.03):
        """
        :param n_templates: number of anchor templates; the first
            `n_templates` output channels are classification scores, the
            remaining `4 * n_templates` are box-regression offsets.
        :param reg_weight: weight of the regression term in the total loss.
        :param pos_fraction: target fraction of positives kept by
            balance sampling.
        :param hard_neg_threshold: classification-loss value below which a
            label is considered "easy" and ignored (previously a hard-coded
            0.03; the default preserves the old behavior).
        """
        super().__init__()

        # We don't want per element averaging.
        # We want to normalize over the batch or positive samples.
        self.regression_criterion = nn.SmoothL1Loss(reduction='none')
        self.classification_criterion = nn.SoftMarginLoss(reduction='none')
        self.n_templates = n_templates
        self.reg_weight = reg_weight
        self.pos_fraction = pos_fraction
        self.hard_neg_threshold = hard_neg_threshold

        # Running per-sample averages of each loss term, for logging.
        self.class_average = AvgMeter()
        self.reg_average = AvgMeter()

        # Populated on every forward pass so callers can inspect the
        # individual (masked) loss maps of the latest batch.
        self.masked_class_loss = None
        self.masked_reg_loss = None
        self.total_loss = None

    def balance_sample(self, class_map):
        """Rebalance positive/negative labels toward `pos_fraction`.

        Runs the numpy `balance_sampling` helper image-by-image on a CPU
        copy of the label map, then moves the result back to the original
        device.

        NOTE(review): when `class_map` already lives on the CPU,
        `.cpu().numpy()` shares memory with it, so the input tensor is
        modified in place. This is safe on the `forward` path because
        `hard_negative_mining` hands over a fresh tensor, but beware when
        calling this method directly.
        """
        device = class_map.device
        label_class_np = class_map.cpu().numpy()
        # iterate through batch
        for idx in range(label_class_np.shape[0]):
            label_class_np[idx, ...] = balance_sampling(label_class_np[idx, ...],
                                                        pos_fraction=self.pos_fraction)

        class_map = torch.from_numpy(label_class_np).to(device)

        return class_map

    def hard_negative_mining(self, classification, class_map):
        """Zero out labels whose classification loss is already tiny.

        Scores are detached so mining never contributes gradients.
        Operates on a clone so the caller's label tensor is not mutated —
        the previous version modified `class_map` in place, corrupting
        ground-truth labels that may be cached and reused across epochs.
        """
        loss_class_map = nn.functional.soft_margin_loss(classification.detach(),
                                                        class_map,
                                                        reduction='none')
        class_map = class_map.clone()
        class_map[loss_class_map < self.hard_neg_threshold] = 0
        return class_map

    def forward(self, output, class_map, regression_map):
        """Compute the combined detection loss for one batch.

        :param output: network output; channels [0, n_templates) are
            classification scores, the rest are regression offsets.
        :param class_map: per-template labels (positive / negative / 0 =
            ignore), same spatial shape as the classification slice.
        :param regression_map: regression targets, same shape as the
            regression slice of `output`.
        :returns: scalar total loss
            (classification_sum + reg_weight * regression_sum).
        """
        classification = output[:, 0:self.n_templates, :, :]
        regression = output[:, self.n_templates:, :, :]

        # Online hard negative mining: drop labels the classifier already
        # gets right. Returns a fresh tensor; caller's labels stay intact.
        class_map = self.hard_negative_mining(classification, class_map)
        # Balance sampling: keep roughly `pos_fraction` positives.
        class_map = self.balance_sample(class_map)

        class_loss = self.classification_criterion(classification, class_map)

        # Weights used to mask out invalid regions, i.e. where the label is 0.
        class_mask = (class_map != 0).type(output.dtype)
        # Mask the classification loss
        self.masked_class_loss = class_mask * class_loss

        reg_loss = self.regression_criterion(regression, regression_map)
        # Regression applies only at positive locations; expand the mask to
        # cover the 4 box coordinates of every template.
        reg_mask = (class_map > 0).repeat(1, 4, 1, 1).type(output.dtype)

        self.masked_reg_loss = reg_mask * reg_loss  # / reg_loss.size(0)

        # Sum each term once and reuse the results for the running meters.
        class_loss_sum = self.masked_class_loss.sum()
        reg_loss_sum = self.masked_reg_loss.sum()
        self.total_loss = class_loss_sum + self.reg_weight * reg_loss_sum

        batch_size = output.size(0)
        self.class_average.update(class_loss_sum, batch_size)
        self.reg_average.update(reg_loss_sum, batch_size)

        return self.total_loss

    def reset(self):
        """Reset the running loss averages (e.g. at the start of an epoch)."""
        self.class_average.reset()
        self.reg_average.reset()