# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
from __future__ import division
import torch
import torch.nn as nn
from torch.autograd import Variable
import torchvision.models as models


class VA(nn.Module):
    """The layer for transforming the skeleton to the observed viewpoints"""
    def __init__(self,num_classes = 60):
        super(VA, self).__init__()
        self.num_classes = num_classes
        self.conv1 = nn.Conv2d(3, 128, kernel_size=5, stride=2,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(128)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(128, 128, kernel_size=5, stride=2,
                               bias=False)
        self.bn2 = nn.BatchNorm2d(128)
        self.relu2 = nn.ReLU(inplace=True)
        self.avepool = nn.MaxPool2d(7)  # NOTE: despite the name, this is a 7x7 max pooling layer
        self.fc = nn.Linear(6272, 6)  # 128 * 7 * 7 features -> 3 rotation + 3 translation parameters
        self.classifier = models.resnet50(pretrained=True)  # ImageNet-pretrained ResNet-50 backbone
        self.init_weight()

    def forward(self, x1, maxmin):
        # View-adaptation subnetwork: predict a 6-dof transform
        # (3 rotation angles + 3 translations) from the skeleton image.
        x = self.conv1(x1)
        x = self.bn1(x)
        x = self.relu1(x)
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu2(x)
        x = self.avepool(x)

        x = x.view(x.size(0), -1)
        trans = self.fc(x)

        # Keep a CPU copy of the predicted transform parameters for inspection.
        temp1 = trans.cpu()
        # Re-render the skeleton image under the predicted viewpoint.
        x = _transform(x1, trans, maxmin)

        # Keep a CPU copy of the transformed skeleton image as well.
        temp = x.cpu()
        # Classify the transformed image with the ResNet-50 backbone.
        x = self.classifier(x)
        return x, temp.data.numpy(), temp1.data.numpy()

    def init_weight(self):
        for layer in [self.conv1, self.conv2]:
            for name, param in layer.named_parameters():
                if 'weight' in name:
                    nn.init.xavier_uniform_(param)
                if 'bias' in name:
                    param.data.zero_()
        for layer in [self.bn1, self.bn2]:
            layer.weight.data.fill_(1)
            layer.bias.data.fill_(0)
            layer.momentum = 0.99
            layer.eps = 1e-3

        # Zero-initialise the transform regressor so training starts from the
        # identity view (no rotation, no translation).
        self.fc.bias.data.zero_()
        self.fc.weight.data.zero_()

        # Replace the ImageNet classification head with one for our classes.
        num_ftrs = self.classifier.fc.in_features
        self.classifier.fc = nn.Linear(num_ftrs, self.num_classes)

# Build the translation (3x4) and rotation (3x3) matrices from the
# 6 predicted parameters.
def _trans_rot(trans, rot):
    cos_r, sin_r = rot.cos(), rot.sin()
    zeros = Variable(rot.data.new(rot.size()[:1] + (1,)).zero_())
    ones = Variable(rot.data.new(rot.size()[:1] + (1,)).fill_(1))

    # Rotation about the x-axis.
    r1 = torch.stack((ones, zeros, zeros), dim=-1)
    rx2 = torch.stack((zeros, cos_r[:, 0:1], sin_r[:, 0:1]), dim=-1)
    rx3 = torch.stack((zeros, -sin_r[:, 0:1], cos_r[:, 0:1]), dim=-1)
    rx = torch.cat((r1, rx2, rx3), dim=1)

    # Rotation about the y-axis.
    ry1 = torch.stack((cos_r[:, 1:2], zeros, -sin_r[:, 1:2]), dim=-1)
    r2 = torch.stack((zeros, ones, zeros), dim=-1)
    ry3 = torch.stack((sin_r[:, 1:2], zeros, cos_r[:, 1:2]), dim=-1)
    ry = torch.cat((ry1, r2, ry3), dim=1)

    # Rotation about the z-axis.
    rz1 = torch.stack((cos_r[:, 2:3], sin_r[:, 2:3], zeros), dim=-1)
    r3 = torch.stack((zeros, zeros, ones), dim=-1)
    rz2 = torch.stack((-sin_r[:, 2:3], cos_r[:, 2:3], zeros), dim=-1)
    rz = torch.cat((rz1, rz2, r3), dim=1)

    # Compose the full rotation R = Rz * Ry * Rx.
    rot = rz.matmul(ry).matmul(rx)

    # Homogeneous translation matrix [I | t] of shape (N, 3, 4).
    rt1 = torch.stack((ones, zeros, zeros, trans[:, 0:1]), dim=-1)
    rt2 = torch.stack((zeros, ones, zeros, trans[:, 1:2]), dim=-1)
    rt3 = torch.stack((zeros, zeros, ones, trans[:, 2:3]), dim=-1)
    trans = torch.cat((rt1, rt2, rt3), dim=1)

    return trans, rot

# Transform the skeleton image with the predicted rotation/translation and
# map it back to the 0-255 image range.
def _transform(x, mat, maxmin):
    rot = mat[:, 0:3]
    trans = mat[:, 3:6]

    # Flatten each 3 x 224 x 224 image into a 3 x (224*224) coordinate matrix.
    x = x.contiguous().view(-1, x.size()[1], x.size()[2] * x.size()[3])

    max_val, min_val = maxmin[:, 0], maxmin[:, 1]
    max_val, min_val = max_val.contiguous().view(-1, 1), min_val.contiguous().view(-1, 1)
    max_val, min_val = max_val.repeat(1, 3), min_val.repeat(1, 3)
    trans, rot = _trans_rot(trans, rot)

    # Rotate the coordinates.
    x1 = torch.matmul(rot, x)
    # Translate the per-channel minimum in homogeneous coordinates.
    min_val1 = torch.cat((min_val, Variable(min_val.data.new(min_val.size()[0], 1).fill_(1))), dim=-1)
    min_val1 = min_val1.unsqueeze(-1)
    min_val1 = torch.matmul(trans, min_val1)

    # Per-channel offset: (rot @ translated_min - min) / (max - min), scaled to 0-255.
    min_val = torch.div(torch.add(torch.matmul(rot, min_val1).squeeze(-1), - min_val), torch.add(max_val, - min_val))

    min_val = min_val.mul_(255)
    x = torch.add(x1, min_val.unsqueeze(-1))

    x = x.contiguous().view(-1, 3, 224, 224)

    return x
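

# A minimal usage sketch (an illustrative addition, not part of the original
# training code): the input shapes and the layout of `maxmin` below are
# assumptions inferred from the code above -- the network expects a batch of
# skeleton "images" of size 3x224x224 plus, for each sample, the max/min
# coordinate values used when the skeleton was mapped to the 0-255 range.
# Note that constructing VA() downloads ImageNet-pretrained ResNet-50 weights.
if __name__ == '__main__':
    # Sanity check: all-zero parameters (the state produced by the
    # zero-initialised fc layer) give an identity rotation and a [I | 0]
    # translation matrix, i.e. the identity view.
    t, r = _trans_rot(torch.zeros(1, 3), torch.zeros(1, 3))
    print(r)  # identity 3x3 rotation
    print(t)  # 3x4 matrix [I | 0]

    model = VA(num_classes=60)
    model.eval()
    dummy_x = torch.randn(2, 3, 224, 224)       # two dummy skeleton images
    dummy_maxmin = torch.tensor([[1.0, -1.0],
                                 [1.0, -1.0]])  # assumed [max, min] per sample
    with torch.no_grad():
        logits, transformed, params = model(dummy_x, dummy_maxmin)
    print(logits.shape)       # torch.Size([2, 60])
    print(transformed.shape)  # (2, 3, 224, 224) transformed skeleton images
    print(params.shape)       # (2, 6): 3 rotation + 3 translation parameters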