import os

import torch
import torch.nn as nn

from torchreid.utils.nuc_norm import nuclear_norm

from .cross_entropy_loss import CrossEntropyLoss


class SpectralLoss(nn.Module):
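    """Cross-entropy loss augmented with a spectral penalty: the nuclear norm of
    the graph Laplacian built from selected intermediate feature maps, scaled by
    ``beta``."""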

    def __init__(self, num_classes, *, use_gpu=True, label_smooth=True, beta=None, penalty_position='before'):
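        """
        Args:
            num_classes (int): number of identity classes for the cross-entropy term.
            use_gpu (bool): forwarded to CrossEntropyLoss.
            label_smooth (bool): forwarded to CrossEntropyLoss.
            beta (float, optional): weight of the spectral penalty; the
                'spec_beta' and 'beta' environment variables take precedence.
            penalty_position (str): comma-separated keys of the feature dict to
                penalize, e.g. 'layer4,layer5' (keys are illustrative and depend
                on the model).
        """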
        super().__init__()

        # Resolve beta: the 'spec_beta' environment variable takes precedence,
        # then the generic 'beta' environment variable, then the constructor
        # argument.
        os_beta = None

        spec_beta = os.environ.get('spec_beta')
        if spec_beta is not None:
            try:
                os_beta = float(spec_beta)
            except (ValueError, TypeError):
                pass

        if os_beta is None:
            env_beta = os.environ.get('beta')
            if env_beta is not None:
                try:
                    os_beta = float(env_beta)
                except (ValueError, TypeError):
                    pass

        self.beta = os_beta if os_beta is not None else beta
        if self.beta is None:
            raise RuntimeError('No beta specified. ABORTED.')

        print('USE_GPU', use_gpu)
        print('beta', self.beta)
        self.xent_loss = CrossEntropyLoss(num_classes=num_classes, use_gpu=use_gpu, label_smooth=label_smooth)
        self.penalty_position = frozenset(penalty_position.split(','))

    def get_laplacian_nuc_norm(self, A: torch.Tensor) -> torch.Tensor:
        """Mean nuclear norm of the graph Laplacian built from A of shape N x C x S."""

        N, C, _ = A.size()
        # Channel affinity (Gram) matrix: A A^T is N x C x C.
        AAT = torch.bmm(A, A.permute(0, 2, 1))
        # Degree matrix: row sums of the affinity matrix placed on the diagonal.
        ones = torch.ones((N, C, 1), device=A.device, dtype=A.dtype)
        D = torch.bmm(AAT, ones).view(N, C)
        D = torch.diag_embed(D)

        # Nuclear norm of the symmetric Laplacian L = D - A A^T, averaged over the batch.
        return nuclear_norm(D - AAT, sym=True).sum() / N

    def apply_penalty(self, k, x):
        """Spectral penalty for the feature map(s) registered under position k."""

        if isinstance(x, tuple):
            # Average the penalty over parallel branches.
            return sum(self.apply_penalty(k, xx) for xx in x) / len(x)

        batches, channels, height, width = x.size()
        # Flatten each C x H x W feature map to C x (H * W).
        W = x.view(batches, channels, -1)

        penalty = self.get_laplacian_nuc_norm(W)

        if k == 'layer5':
            # Down-weight the penalty for the layer5 feature map.
            penalty *= 0.01

        return penalty * self.beta

    def forward(self, inputs, pids):

        _, y, _, feature_dict = inputs

        existing_positions = frozenset(feature_dict.keys())
        missing = self.penalty_position - existing_positions
        if missing:
            raise RuntimeError('Cannot apply spectral loss, as positions {!r} are missing.'.format(list(missing)))

        penalty = sum(self.apply_penalty(k, x) for k, x in feature_dict.items() if k in self.penalty_position)

        xloss = self.xent_loss(y, pids)
        # logger.debug(str(penalty))
        return penalty + xloss
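

# Usage sketch (illustrative only): a self-contained check of the penalty
# computed above, assuming `nuclear_norm(M, sym=True)` returns the sum of
# singular values of each matrix in the batch. The `svdvals` fallback and the
# tensor shapes below are assumptions for illustration, not part of torchreid.
if __name__ == '__main__':
    torch.manual_seed(0)

    def _nuclear_norm_fallback(M):
        # Per-sample sum of singular values; stands in for
        # torchreid.utils.nuc_norm.nuclear_norm in this sketch.
        return torch.linalg.svdvals(M).sum(dim=-1)

    x = torch.randn(4, 8, 6, 6)                  # N x C x H x W feature map
    A = x.view(4, 8, -1)                         # N x C x S
    AAT = torch.bmm(A, A.permute(0, 2, 1))       # channel affinity, N x C x C
    L = torch.diag_embed(AAT.sum(dim=2)) - AAT   # batched graph Laplacian
    print('mean Laplacian nuclear norm:', (_nuclear_norm_fallback(L).sum() / 4).item())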