from __future__ import print_function
from constant import *
from models import *
from dataset import *
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from datetime import datetime
from pytorch_pretrained_bert.optimization import BertAdam


from sklearn.metrics import f1_score,precision_score,recall_score
os.environ["CUDA_VISIBLE_DEVICES"]="2"

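# Adversarial setup (both classes defined in models.py): the Generator acts
# as an instance selector whose confidence scores weight the losses below,
# and the Discriminator is the event classifier that gets evaluated and saved.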
TestSet=Dataset("TestA")
selector=Generator().cuda()
discriminator=Discriminator().cuda()

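# BertAdam applies linear warmup over the first 10% of t_total steps; note
# that t_total counts optimizer steps and is set here to the Epoch constant.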
def buildOptimizer(model,lr):
    # Group parameters so that bias and LayerNorm (gamma/beta) weights are
    # excluded from L2 weight decay, as in standard BERT fine-tuning.
    no_decay=['bias','gamma','beta']
    named=list(model.named_parameters())
    grouped=[
        {'params':[p for n,p in named if not any(nd in n for nd in no_decay)],'weight_decay_rate':0.001},
        {'params':[p for n,p in named if any(nd in n for nd in no_decay)],'weight_decay_rate':0.0}
    ]
    return BertAdam(grouped,lr=lr,warmup=0.1,t_total=Epoch)

sOpt=buildOptimizer(selector,sLr)
dOpt=buildOptimizer(discriminator,dLr)

nonNAindex=list(range(1,dimE))  # every event class except NA (index 0)
bst=0.0

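# Evaluate the discriminator on a dataset: print error breakdowns, report
# micro precision/F1 over the non-NA classes, and checkpoint to Dmodel.tar
# whenever the test F1 beats the best seen so far.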
def test(dataset):
    discriminator.eval()
    preds=[]
    labels=[]
    with torch.no_grad():
        for words,inMask,maskL,maskR,label in dataset.batchs():
            scores=discriminator(words.cuda(),inMask.cuda(),maskL.cuda(),maskR.cuda())
            pred=torch.argmax(scores,dim=1).cpu().numpy()
            preds.append(pred)
            labels.append(label.numpy())
    preds=np.concatenate(preds,0)
    labels=np.concatenate(labels,0)
    cnt=0      # misclassified instances
    cnt1=0     # gold-NA instances
    FN=0       # event instances predicted as NA
    FP=0       # NA instances predicted as an event
    for i in range(0,preds.shape[0]):
        if labels[i]==0:
            cnt1+=1
        if preds[i]!=labels[i]:
            cnt+=1
            if preds[i]==0 and labels[i]!=0:
                FN+=1
            if preds[i]!=0 and labels[i]==0:
                FP+=1
    print("EVAL %s #Wrong %d #NegToPos %d #PosToNeg %d #All %d #Negs %d"%("Test",cnt,FP,FN,len(preds),cnt1))
    # Micro-averaged precision and F1 over the 33 non-NA event classes.
    prec=precision_score(labels,preds,labels=list(range(1,34)),average="micro")
    f1=f1_score(labels,preds,labels=list(range(1,34)),average="micro")
    print("Precision %f F1 %f"%(prec,f1))
    global bst
    if f1>bst and dataset==TestSet:
        print("BST %f"%(bst))
        torch.save(discriminator.state_dict(),"Dmodel.tar")
        bst=f1
    return f1

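# Adversarial training scheme implemented below:
#   1. trainGen updates the selector so it up-weights instances that the
#      discriminator currently assigns high probability to.
#   2. disConfLoss fits the discriminator to the confident set, while
#      disUnconfLoss pushes it away from selector-weighted noisy labels.
#   3. testCnt tracks prediction/label agreement so joinSet.dump can move
#      instances between the confident and unconfident sets.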
def genMask(idx):
    # One-hot byte mask of shape (batch,dimE): row i is 1 only at column
    # idx[i], so masked_select pulls out each instance's score for its label.
    # e.g. genMask(torch.LongTensor([2,0])) marks columns 2 and 0.
    res=torch.zeros(idx.size(0),dimE,dtype=torch.uint8)
    res[torch.arange(idx.size(0)),idx.long()]=1
    return res
def Dscore_G(nwords,nMask,nmaskL,nmaskR,nlabel,uwords,uMask,umaskL,umaskR,ulabel):
    # Discriminator scores fed to the generator: the first batch is scored
    # by its mean sigmoid probability over the non-NA classes, the second by
    # the sigmoid probability of its own (noisy) label.
    nScores=torch.sigmoid(discriminator(nwords.cuda(),nMask.cuda(),nmaskL.cuda(),nmaskR.cuda()))
    nScores=nScores[:,nonNAindex]
    nScores=torch.mean(nScores,dim=1)
    if uwords.size(0)==0:
        return nScores
    uScores=torch.sigmoid(discriminator(uwords.cuda(),uMask.cuda(),umaskL.cuda(),umaskR.cuda()))
    umask=genMask(ulabel).cuda()
    uScores=torch.masked_select(uScores,umask)
    return torch.cat((nScores,uScores),dim=0)
def genLoss(nwords,nMask,nmaskL,nmaskR,nlabel,uwords,uMask,umaskL,umaskR,ulabel):
    dScores=Dscore_G(nwords,nMask,nmaskL,nmaskR,nlabel,uwords,uMask,umaskL,umaskR,ulabel)
    words=torch.cat((nwords,uwords),0)
    inMask=torch.cat((nMask,uMask),0)
    maskL=torch.cat((nmaskL,umaskL),0)
    maskR=torch.cat((nmaskR,umaskR),0)
    # Selector confidence over the whole batch: alpha-sharpened softmax.
    cScores=selector(words.cuda(),inMask.cuda(),maskL.cuda(),maskR.cuda())
    cScores=torch.pow(cScores,alpha)
    cScores=F.softmax(cScores,dim=0)
    # Generator loss: confidence-weighted negative log discriminator score.
    return -torch.dot(cScores,torch.log(dScores))
def trainGen(unconfIter):
    sOpt.zero_grad()
    nwords,nMask,nmaskL,nmaskR,nlabel,uwords,uMask,umaskL,umaskR,ulabel=next(unconfIter)
    sLoss=genLoss(nwords,nMask,nmaskL,nmaskR,nlabel,uwords,uMask,umaskL,umaskR,ulabel)
    sLoss.backward()
    sOpt.step()
    return sLoss.item()
def disConfLoss(words,inMask,maskL,maskR,label):
    # Confident instances: plain negative log-likelihood of the gold label.
    dScores=torch.sigmoid(discriminator(words.cuda(),inMask.cuda(),maskL.cuda(),maskR.cuda()))
    mask=genMask(label).cuda()
    dScores=torch.masked_select(dScores,mask)
    return -torch.mean(torch.log(dScores))
def disUnconfLoss(words,inMask,maskL,maskR,label):
    # Selector confidence weights how strongly the discriminator is pushed
    # away from each noisy label; detach() keeps this discriminator update
    # from building gradients through the selector.
    cScores=selector(words.cuda(),inMask.cuda(),maskL.cuda(),maskR.cuda())
    cScores=torch.pow(cScores,alpha)
    cScores=F.softmax(cScores,dim=0).detach()
    dScores=torch.sigmoid(discriminator(words.cuda(),inMask.cuda(),maskL.cuda(),maskR.cuda()))
    mask=genMask(label).cuda()
    dScores=torch.masked_select(dScores,mask)
    return -torch.dot(cScores,torch.log(1.0-dScores))
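# One epoch: for every unconfident batch from joinSet, take one generator
# step (on a batch from unconfSet), then one discriminator step combining a
# confident batch with that unconfident batch.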
def oneEpoch(e,joinSet,unconfSet):
    unconfIter=unconfSet.batchs()
    confIter=joinSet.conf_batch()
    cnt=0
    for uwords,uMask,umaskL,umaskR,ulabel,utimes in joinSet.unconf_batch():
        sLoss=trainGen(unconfIter)
        cwords,cMask,cmaskL,cmaskR,clabel=next(confIter)
        dLoss=disConfLoss(cwords,cMask,cmaskL,cmaskR,clabel)
        dLoss=dLoss+disUnconfLoss(uwords,uMask,umaskL,umaskR,ulabel)
        dOpt.zero_grad()
        dLoss.backward()
        dOpt.step()
        cnt+=1
        #print("Epoch %d Batch %d s_loss %f d_loss %f"%(e,cnt,sLoss,dLoss))
def testCnt(joinSet):
    # Count, per unconfident instance, how often the discriminator's
    # prediction matches the noisy label; these counts feed joinSet.dump,
    # which re-splits the confident/unconfident sets by Threshold.
    discriminator.eval()
    with torch.no_grad():
        for words,inMask,maskL,maskR,label,bound in joinSet.unconf_batch():
            scores=discriminator(words.cuda(),inMask.cuda(),maskL.cuda(),maskR.cuda())
            pred=torch.argmax(scores,dim=1).cpu().numpy()
            joinSet.utimes[bound[0]:bound[1]]+=np.equal(pred,label.numpy())
def train():
    unconfSet=uDataset("TrainA_unconf")
    joinSet=joinDataset("TrainA_conf","TrainA_unconf")
    selector.train()
    discriminator.train()
    #testCnt(joinSet)
    for e in range(0,Epoch):
        discriminator.train()
        oneEpoch(e,joinSet,unconfSet)
        testCnt(joinSet)
        test(TestSet)
    test(TestSet)
    joinSet.dump(Threshold,"TrainA_conf","TrainA_unconf")
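# Warm start: the discriminator loads model.tar and the selector's encoder
# loads encoder.tar (pre-trained checkpoints expected in the working
# directory); each train() call then re-dumps the TrainA_conf/TrainA_unconf
# split before the next iteration.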
if __name__=='__main__':
    discriminator.load_state_dict(torch.load("model.tar"))
    selector.encoder.load_state_dict(torch.load("encoder.tar"))
    test(TestSet)
    for i in range(0,ItemTimes):
        train()