# coding=utf-8

import torch.nn as nn
import math, os, time, random
import torch.utils.model_zoo as model_zoo
import argparse
import torch
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
from tqdm import tqdm

# Command-line options. Every flag has a default, so the script also runs
# with no arguments at all.
parser = argparse.ArgumentParser(description='hashNet')
parser.add_argument('--batch-size', type=int, default=64)
parser.add_argument('--test-batch-size', type=int, default=1000)
parser.add_argument('--epochs', type=int, default=10)
parser.add_argument('--lr', type=float, default=0.01)
parser.add_argument('--momentum', type=float, default=0.4)
parser.add_argument('--no-cuda', action='store_true', default=False)
parser.add_argument('--seed', type=int, default=1)
parser.add_argument('--log-interval', type=int, default=3)

args = parser.parse_args()
# Use the GPU only when it was not disabled on the command line and CUDA is
# actually available on this machine.
args.cuda = not args.no_cuda and torch.cuda.is_available()

# Seed the CPU generator (and the CUDA one when the GPU is in use).
torch.manual_seed(args.seed)
if args.cuda:
    torch.cuda.manual_seed(args.seed)

# Extra keyword arguments, presumably meant to be forwarded to a
# torch.utils.data.DataLoader (TODO confirm against the consumer). The
# DataLoader keyword is spelled ``num_workers``; the previous ``num_worker``
# key would be rejected as an unexpected keyword argument.
kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}


class HashNet(nn.Module):
    """Single linear layer mapping feature vectors to real-valued codes.

    Every input is normalised with a softmax over its last dimension,
    projected by the shared linear layer ``fc`` and passed through SELU.

    Args:
        in_channel: size of the last dimension of each input.
        hashLength: size of the last dimension of each produced code.
    """

    def __init__(self, in_channel=1536, hashLength=1024):
        super(HashNet, self).__init__()
        self.fc = nn.Linear(in_channel, hashLength)
        # Not used by forward(); kept so the attribute stays available.
        self.sm = nn.Sigmoid()
        # The dimension is given explicitly: omitting it relies on the
        # deprecated implicit choice and emits a warning on every call.
        # The last dimension equals that implicit choice for 1-D and 2-D
        # inputs and is the dimension ``fc`` consumes.
        self.sma = nn.Softmax(dim=-1)
        self.initLinear()
        print(self.fc.weight.data)

    def _encode(self, features):
        """Return ``selu(fc(softmax(features)))`` for one input tensor."""
        return F.selu(self.fc(self.sma(features)))

    def forward(self, x1, x2, y):
        """Encode the three inputs with the same weights.

        Args:
            x1: first input tensor.
            x2: second input tensor.
            y: third input tensor.

        Returns:
            Tuple ``(code_x1, code_x2, code_y)`` in the order of the inputs.
        """
        return self._encode(x1), self._encode(x2), self._encode(y)

    def initLinear(self):
        """Re-initialise ``fc``: weights ~ N(1.0, 0.33), bias set to 0.1."""
        self.fc.weight.data.normal_(1.0, 0.33)
        self.fc.bias.data.fill_(0.1)


# Module-level training objects used by train(): the network, the Euclidean
# (p=2) pairwise distance and the SGD optimizer over the network parameters.
model = HashNet(hashLength=8192, in_channel=1536)
if args.cuda:
    # Module.cuda() returns the module itself, so rebinding is equivalent.
    model = model.cuda()

pdist = nn.PairwiseDistance(p=2)
optimizer = optim.SGD(
    model.parameters(),
    momentum=args.momentum,
    lr=args.lr,
)


def _load_features(path):
    """Load one serialized feature collection and close the file afterwards."""
    # NOTE: torch.load unpickles its input; only point this at trusted files.
    with open(path, 'rb') as handle:
        return torch.load(handle)


def _as_row(sample):
    """Turn one stored sample into a (1, N) float Variable on the active device."""
    row = torch.FloatTensor(sample).view(1, -1)
    if args.cuda:
        row = row.cuda()
    return Variable(row)


def train(epochs, margin=10):
    """Train the module-level ``model`` one triplet at a time.

    For every index the samples ``pos1[i]``, ``pos2[i]`` and ``neg[i]`` are
    encoded by ``model`` and the hinge loss
    ``max(0, margin - d(x1, y) + d(x1, x2))`` is minimised with the
    module-level ``optimizer``; ``d`` is the module-level ``pdist``.
    The model state dict is saved after every epoch.

    Args:
        epochs: number of passes over the triplets.
        margin: margin of the hinge loss (defaults to the former constant 10).
    """
    model.train()
    pos1 = _load_features('../../feature/generated_dataset/pos1_fea.pt')
    pos2 = _load_features('../../feature/generated_dataset/pos2_fea.pt')
    neg = _load_features('../../feature/generated_dataset/neg_fea.pt')

    # Only the length of pos1 drives the loop; pos2 and neg are assumed to
    # hold at least as many samples (TODO confirm against the dataset).
    length = len(pos1)
    print(length)

    for epoch in range(epochs):
        for index in range(length):
            x1 = _as_row(pos1[index])
            x2 = _as_row(pos2[index])
            y = _as_row(neg[index])

            optimizer.zero_grad()
            hash_x1, hash_x2, hash_y = model(x1, x2, y)
            loss1 = pdist(hash_x1, hash_x2)  # distance between the two positives
            loss2 = pdist(hash_x1, hash_y)   # distance from positive to negative
            loss = F.relu(margin - loss2 + loss1)
            loss.backward()
            optimizer.step()

            if index % args.log_interval == 0:
                # .item() extracts the Python number from a one-element
                # tensor whatever its shape; indexing with [0][0] fails
                # when the distance tensor is 1-D.
                print("==============================================")
                print("total loss=", loss.item())
                print("x1==", x1.data, 'x2==', x2.data, 'y==', y.data)
                print("hashx1=", hash_x1, "hashx2=", hash_x2, 'hashy=', hash_y)
                print('loss1==', loss1.item(), 'loss2==', loss2.item())
        torch.save(model.state_dict(), '../../model/hashNetInceptionv4-epoch' + str(epoch) + '.pth')


if __name__ == "__main__":
    # Executed as a script: run a single training epoch.
    train(epochs=1)