import torch
import torch.nn as nn

from ..registry import registry
from .modeling_utils import ProteinConfig
from .modeling_utils import ProteinModel

URL_PREFIX = "https://s3.amazonaws.com/proteindata/pytorch-models/"

# Download locations for the pretrained trRosetta weights and their configs,
# keyed 'xaa' through 'xae'. Both maps are built from the same key sequence,
# so they always list the same names in the same order.
TRROSETTA_PRETRAINED_MODEL_ARCHIVE_MAP = {
    name: URL_PREFIX + "trRosetta-" + name + "-pytorch_model.bin"
    for name in ('xaa', 'xab', 'xac', 'xad', 'xae')
}
TRROSETTA_PRETRAINED_CONFIG_ARCHIVE_MAP = {
    name: URL_PREFIX + "trRosetta-" + name + "-config.json"
    for name in TRROSETTA_PRETRAINED_MODEL_ARCHIVE_MAP
}


class TRRosettaConfig(ProteinConfig):
    """Hyperparameters for the trRosetta feature extractor and predictor.

    Args:
        num_features: Channel width of the residual network.
        kernel_size: Convolution kernel size used in each residual block.
        num_layers: Number of dilated residual blocks.
        dropout: Dropout probability used in the stem and the residual blocks.
        msa_cutoff: Fraction of the sequence length; two MSA rows sharing more
            identical positions than ``msa_cutoff * seqlen`` are counted as
            neighbours when reweighting sequences.
        penalty_coeff: Scale of the diagonal regularizer added to the MSA
            covariance before it is inverted.
        initializer_range: Standard deviation of the normal initialization
            applied to ``nn.Linear`` weights.
        **kwargs: Forwarded unchanged to ``ProteinConfig``.
    """

    pretrained_config_archive_map = TRROSETTA_PRETRAINED_CONFIG_ARCHIVE_MAP

    def __init__(self,
                 num_features: int = 64,
                 kernel_size: int = 3,
                 num_layers: int = 61,
                 dropout: float = 0.15,
                 msa_cutoff: float = 0.8,
                 penalty_coeff: float = 4.5,
                 initializer_range: float = 0.02,
                 **kwargs):
        super().__init__(**kwargs)
        # Assigned in a fixed order so the instance attribute order is stable.
        hyperparameters = (
            ('num_features', num_features),
            ('kernel_size', kernel_size),
            ('num_layers', num_layers),
            ('dropout', dropout),
            ('msa_cutoff', msa_cutoff),
            ('penalty_coeff', penalty_coeff),
            ('initializer_range', initializer_range),
        )
        for attr_name, attr_value in hyperparameters:
            setattr(self, attr_name, attr_value)


class MSAFeatureExtractor(nn.Module):
    """Compute trRosetta-style pairwise input features from a one-hot MSA.

    For every residue pair (i, j) the output stacks the 1D features of
    residue i (42 channels), the 1D features of residue j (42 channels) and
    the pairwise DCA features (442 channels), for 526 channels in total.

    Args:
        config: Supplies ``msa_cutoff`` (identity fraction used when
            reweighting sequences) and ``penalty_coeff`` (scale of the
            diagonal regularizer added to the covariance before inversion).
    """

    def __init__(self, config: 'TRRosettaConfig'):
        super().__init__()
        self.msa_cutoff = config.msa_cutoff
        self.penalty_coeff = config.penalty_coeff

    def forward(self, msa1hot):
        """Return features of shape ``(batch, 526, seqlen, seqlen)``.

        Args:
            msa1hot: One-hot MSA indexed ``(batch, num_alignments, seqlen, 21)``.
                The first alignment row supplies the sequence one-hot features.

        Returns:
            A contiguous channels-first tensor. If ``msa1hot`` has a
            floating-point dtype (e.g. half), the result is cast back to that
            dtype; otherwise (integer/bool one-hot input) it stays float32.
        """
        # These transforms aren't well suited to half precision, so all work is
        # done in float32 and the caller's dtype is restored at the end.
        initial_type = msa1hot.dtype

        msa1hot = msa1hot.float()
        seqlen = msa1hot.size(2)

        weights = self.reweight(msa1hot)
        features_1d = self.extract_features_1d(msa1hot, weights)
        features_2d = self.extract_features_2d(msa1hot, weights)

        # Tile per-residue features so entry (i, j) holds residue i's features
        # followed by residue j's features.
        left = features_1d.unsqueeze(2).repeat(1, 1, seqlen, 1)
        right = features_1d.unsqueeze(1).repeat(1, seqlen, 1, 1)
        features = torch.cat((left, right, features_2d), -1)
        # Casting real-valued features back to an integer/bool dtype would
        # truncate them, so only floating-point input dtypes are restored.
        if initial_type.is_floating_point:
            features = features.type(initial_type)
        features = features.permute(0, 3, 1, 2)
        features = features.contiguous()
        return features

    def reweight(self, msa1hot, eps=1e-9):
        """Return per-sequence weights of shape ``(batch, num_alignments)``.

        Each row is weighted by the reciprocal of the number of rows sharing
        more than ``msa_cutoff * seqlen`` identical positions with it (for an
        exact one-hot MSA this count includes the row itself).
        """
        seqlen = msa1hot.size(2)
        id_min = seqlen * self.msa_cutoff
        # Pairwise count of identical positions between rows, per batch element.
        id_mtx = torch.stack([torch.tensordot(el, el, [[1, 2], [1, 2]]) for el in msa1hot], 0)
        id_mask = id_mtx > id_min
        weights = 1.0 / (id_mask.type_as(msa1hot).sum(-1) + eps)
        return weights

    def extract_features_1d(self, msa1hot, weights):
        """Return per-residue features of shape ``(batch, seqlen, 42)``.

        Channels: the first row's one-hot over the first 20 symbols, the
        weighted 21-symbol frequency profile, and that profile's entropy.
        """
        f1d_seq = msa1hot[:, 0, :, :20]
        batch_size = msa1hot.size(0)
        seqlen = msa1hot.size(2)

        # msa2pssm. The effective sequence count is computed per batch element
        # so one MSA's weights never leak into another MSA's profile.
        beff = weights.sum(1)
        f_i = (weights[:, :, None, None] * msa1hot).sum(1) / beff[:, None, None] + 1e-9
        h_i = (-f_i * f_i.log()).sum(2, keepdim=True)
        f1d_pssm = torch.cat((f_i, h_i), dim=2)
        f1d = torch.cat((f1d_seq, f1d_pssm), dim=2)
        f1d = f1d.view(batch_size, seqlen, 42)
        return f1d

    def extract_features_2d(self, msa1hot, weights):
        """Return pairwise features of shape ``(batch, seqlen, seqlen, 442)``.

        Channels: the 21x21 block of the regularized inverse covariance for
        each residue pair, plus one APC-corrected coupling score. With only
        one MSA row, an all-zero tensor is returned instead.
        """
        batch_size = msa1hot.size(0)
        num_alignments = msa1hot.size(1)
        seqlen = msa1hot.size(2)
        num_symbols = 21

        if num_alignments == 1:
            # No alignments, predict from sequence alone
            f2d_dca = torch.zeros(
                batch_size, seqlen, seqlen, 442,
                dtype=torch.float,
                device=msa1hot.device)
            return f2d_dca

        # Weighted covariance over flattened (position, symbol) columns.
        x = msa1hot.view(batch_size, num_alignments, seqlen * num_symbols)
        num_points = weights.sum(1) - weights.mean(1).sqrt()
        mean = (x * weights.unsqueeze(2)).sum(1, keepdim=True) / num_points[:, None, None]
        x = (x - mean) * weights[:, :, None].sqrt()
        cov = torch.matmul(x.transpose(-1, -2), x) / num_points[:, None, None]

        # Inverse of the diagonally regularized covariance, one matrix per sample.
        reg = torch.eye(seqlen * num_symbols,
                        device=weights.device,
                        dtype=weights.dtype)[None]
        reg = reg * self.penalty_coeff / weights.sum(1, keepdim=True).sqrt().unsqueeze(2)
        cov_reg = cov + reg
        inv_cov = torch.stack([torch.inverse(cr) for cr in cov_reg.unbind(0)], 0)

        x1 = inv_cov.view(batch_size, seqlen, num_symbols, seqlen, num_symbols)
        x2 = x1.permute(0, 1, 3, 2, 4)
        features = x2.reshape(batch_size, seqlen, seqlen, num_symbols * num_symbols)

        # Frobenius norm of each pair's block (last symbol excluded), zero diagonal.
        x3 = (x1[:, :, :-1, :, :-1] ** 2).sum((2, 4)).sqrt() * (
            1 - torch.eye(seqlen, device=weights.device, dtype=weights.dtype)[None])
        # Average product correction: subtract row_sum * col_sum / total.
        apc = x3.sum(1, keepdim=True) * x3.sum(2, keepdim=True) / x3.sum(
            (1, 2), keepdim=True)
        contacts = (x3 - apc) * (1 - torch.eye(
            seqlen, device=x3.device, dtype=x3.dtype).unsqueeze(0))

        f2d_dca = torch.cat([features, contacts[:, :, :, None]], dim=3)
        return f2d_dca

    @property
    def feature_size(self) -> int:
        """Number of output channels: 42 + 42 (1D, tiled) + 442 (2D)."""
        return 526


class DilatedResidualBlock(nn.Module):
    """Residual block built from two same-padded dilated convolutions.

    Layout: conv -> instance norm -> ELU -> dropout -> conv -> instance norm,
    after which the block input is added back and a final ELU is applied.
    For odd ``kernel_size`` the chosen padding keeps height and width unchanged.

    Args:
        num_features: Input and output channel count.
        kernel_size: Square convolution kernel size.
        dilation: Dilation used by both convolutions.
        dropout: Dropout probability applied between the two convolutions.
    """

    def __init__(self, num_features: int, kernel_size: int, dilation: int, dropout: float):
        super().__init__()
        conv_kwargs = dict(padding=self._get_padding(kernel_size, dilation), dilation=dilation)
        # Construction order is kept fixed: it determines RNG consumption and
        # the state_dict layout.
        self.conv1 = nn.Conv2d(num_features, num_features, kernel_size, **conv_kwargs)
        self.norm1 = nn.InstanceNorm2d(num_features, affine=True, eps=1e-6)
        self.actv1 = nn.ELU(inplace=True)
        self.dropout = nn.Dropout(dropout)
        self.conv2 = nn.Conv2d(num_features, num_features, kernel_size, **conv_kwargs)
        self.norm2 = nn.InstanceNorm2d(num_features, affine=True, eps=1e-6)
        self.actv2 = nn.ELU(inplace=True)
        self.apply(self._init_weights)
        # With norm2's scale at zero, the residual branch initially outputs only
        # norm2's bias.
        nn.init.constant_(self.norm2.weight, 0)

    def _get_padding(self, kernel_size: int, dilation: int) -> int:
        """Return the padding that preserves spatial size for odd kernel sizes."""
        # A dilated kernel spans (kernel_size - 1) * dilation + 1 pixels.
        return (kernel_size - 1) * dilation // 2

    def _init_weights(self, module):
        """Kaiming-initialize Conv2d weights (fan_out, ReLU gain) and zero their biases."""
        if not isinstance(module, nn.Conv2d):
            return
        nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
        if module.bias is not None:
            module.bias.data.zero_()

    def forward(self, features):
        residual = self.dropout(self.actv1(self.norm1(self.conv1(features))))
        residual = self.norm2(self.conv2(residual))
        return self.actv2(residual + features)


class TRRosettaAbstractModel(ProteinModel):
    """Shared base for trRosetta models: config class, checkpoint map and weight init."""

    config_class = TRRosettaConfig
    base_model_prefix = 'trrosetta'
    pretrained_model_archive_map = TRROSETTA_PRETRAINED_MODEL_ARCHIVE_MAP

    def __init__(self, config: TRRosettaConfig):
        super().__init__(config)

    def _init_weights(self, module):
        """Initialize ``module`` in place according to its type.

        - ``DilatedResidualBlock``: zero the scale of its second instance norm.
        - ``nn.Linear``: normal(0, ``config.initializer_range``) weights.
        - ``nn.Conv2d``: Kaiming-normal weights (fan_out, ReLU gain).
        Linear and Conv2d biases, when present, are zeroed. Any other module
        is left untouched.
        """
        if isinstance(module, DilatedResidualBlock):
            nn.init.constant_(module.norm2.weight, 0)
            return
        if not isinstance(module, (nn.Linear, nn.Conv2d)):
            return

        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        else:
            nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
        if module.bias is not None:
            module.bias.data.zero_()


class TRRosettaPredictor(TRRosettaAbstractModel):
    """Dilated residual network predicting binned inter-residue geometry.

    Consumes 526-channel pairwise features and produces per-pair softmax
    distributions over distance (37 bins), theta (25), omega (25) and
    phi (13). When targets are supplied, cross-entropy losses are returned too.
    """

    def __init__(self, config: TRRosettaConfig):
        super().__init__(config)
        width = config.num_features

        # Stem: 1x1 projection of the 526 input channels down to `width`.
        stem = [
            nn.Conv2d(526, width, 1),
            nn.InstanceNorm2d(width, affine=True, eps=1e-6),
            nn.ELU(),
            nn.Dropout(config.dropout),
        ]
        # Residual trunk; the dilation cycles 1, 2, 4, 8, 16, 1, 2, ...
        trunk = [
            DilatedResidualBlock(width, config.kernel_size, 2 ** (depth % 5), config.dropout)
            for depth in range(config.num_layers)
        ]
        self.resnet = nn.Sequential(*(stem + trunk))

        # 1x1 output heads; each head's channel count is its number of bins.
        self.predict_theta = nn.Conv2d(width, 25, 1)
        self.predict_phi = nn.Conv2d(width, 13, 1)
        self.predict_dist = nn.Conv2d(width, 37, 1)
        self.predict_bb = nn.Conv2d(width, 3, 1)
        self.predict_omega = nn.Conv2d(width, 25, 1)

        self.init_weights()

    def init_weights(self):
        """Apply the shared initializer, then zero the weight of every output head."""
        self.apply(self._init_weights)
        heads = (self.predict_theta, self.predict_phi, self.predict_dist,
                 self.predict_bb, self.predict_omega)
        for head in heads:
            nn.init.constant_(head.weight, 0)

    def forward(self,
                features,
                theta=None,
                phi=None,
                dist=None,
                omega=None):
        """Predict geometry distributions and, optionally, their losses.

        Args:
            features: Pairwise features indexed ``(batch, channels, seqlen, seqlen)``.
            theta, phi, dist, omega: Optional integer bin targets with one entry
                per residue pair; each one provided adds a cross-entropy term.

        Returns:
            ``(probs,)`` where ``probs`` maps ``'p_dist'``, ``'p_theta'``,
            ``'p_omega'`` and ``'p_phi'`` to channels-last softmax outputs, or
            ``((total_loss, metrics), probs)`` when any target was given.
        """
        batch_size, seqlen = features.size(0), features.size(2)
        embedding = self.resnet(features)
        # Average the (i, j) and (j, i) entries; dist and omega read this,
        # theta and phi read the raw embedding.
        sym_embedding = 0.5 * (embedding + embedding.transpose(-1, -2))

        def channels_last(logits):
            return logits.permute(0, 2, 3, 1).contiguous()

        logits_theta = channels_last(self.predict_theta(embedding))
        logits_phi = channels_last(self.predict_phi(embedding))
        logits_dist = channels_last(self.predict_dist(sym_embedding))
        logits_omega = channels_last(self.predict_omega(sym_embedding))
        # predict_bb (beta-strand pairings) is constructed but not used here;
        # presumably kept so pretrained checkpoints load cleanly.

        softmax = nn.Softmax(-1)
        probs = {
            'p_dist': softmax(logits_dist),
            'p_theta': softmax(logits_theta),
            'p_omega': softmax(logits_omega),
            'p_phi': softmax(logits_phi),
        }

        # (metric name, logits, target, ignore_index): dist ignores -1, the
        # angle targets ignore bin 0.
        loss_specs = (
            ('dist', logits_dist, dist, -1),
            ('theta', logits_theta, theta, 0),
            ('omega', logits_omega, omega, 0),
            ('phi', logits_phi, phi, 0),
        )
        metrics = {}
        total_loss = 0
        for name, logits, target, ignore_index in loss_specs:
            if target is None:
                continue
            flat_logits = logits.reshape(batch_size * seqlen * seqlen, logits.size(-1))
            loss = nn.CrossEntropyLoss(ignore_index=ignore_index)(flat_logits, target.view(-1))
            metrics[name] = loss
            total_loss += loss

        outputs = (probs,)
        if metrics:
            outputs = ((total_loss, metrics),) + outputs
        return outputs


@registry.register_task_model('trrosetta', 'trrosetta')
class TRRosetta(TRRosettaAbstractModel):
    """End-to-end model: featurize a one-hot MSA, then run the trRosetta predictor."""

    def __init__(self, config: TRRosettaConfig):
        super().__init__(config)
        self.extract_features = MSAFeatureExtractor(config)
        self.trrosetta = TRRosettaPredictor(config)

    def forward(self,
                msa1hot,
                theta=None,
                phi=None,
                dist=None,
                omega=None):
        """Featurize ``msa1hot`` and return the predictor's outputs unchanged.

        The geometry targets are forwarded positionally to the predictor.
        """
        pair_features = self.extract_features(msa1hot)
        predictor = self.trrosetta
        return predictor(pair_features, theta, phi, dist, omega)