import time
import numpy as np
import torch
from copy import deepcopy
import utils

########################################################################################################################

class Appr(object):
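    """Hard-attention-to-the-task (HAT)-style continual learning approach.

    A per-task embedding is passed through a scaled sigmoid to produce
    near-binary attention masks over unit activations; `lamb` weighs the
    mask-sparsity regularizer and `smax` is the maximum gate hardness.
    """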

    def __init__(self,model,nepochs=100,sbatch=64,lr=0.05,lr_min=1e-4,lr_factor=3,lr_patience=5,clipgrad=10000,lamb=0.75,smax=400,args=None):
        self.model=model

        self.nepochs=nepochs
        self.sbatch=sbatch
        self.lr=lr
        self.lr_min=lr_min
        self.lr_factor=lr_factor
        self.lr_patience=lr_patience
        self.clipgrad=clipgrad

        self.ce=torch.nn.CrossEntropyLoss()
        self.optimizer=self._get_optimizer()

        self.lamb=lamb
        self.smax=smax
        self.logpath = None
        self.single_task = False
        # Optional comma-separated overrides in args.parameter, by position:
        # lamb (number), smax (number), log path (any non-number), single_task flag
        if len(args.parameter)>=1:
            params=args.parameter.split(',')
            print('Setting parameters to',params)
            if len(params)>1:
                if utils.is_number(params[0]):
                    self.lamb=float(params[0])
                else:
                    self.logpath=params[0]
                if utils.is_number(params[1]):
                    self.smax=float(params[1])
                else:
                    self.logpath=params[1]
                if len(params)>2 and not utils.is_number(params[2]):
                    self.logpath=params[2]
                if len(params)>3 and utils.is_number(params[3]):
                    self.single_task=bool(int(params[3]))
            else:
                self.logpath=args.parameter

        if self.logpath is not None:
            self.logs={
                'train_loss':{},'train_acc':{},'train_reg':{},
                'valid_loss':{},'valid_acc':{},'valid_reg':{},
                'mask':{},'mask_pre':{},
            }
        else:
            self.logs=None

        self.mask_pre=None
        self.mask_back=None

        return

    def _get_optimizer(self,lr=None):
        if lr is None: lr=self.lr
        return torch.optim.SGD(self.model.parameters(),lr=lr)

    def train(self,t,xtrain,ytrain,xvalid,yvalid):
        best_loss=np.inf
        best_model=utils.get_model(self.model)
        lr=self.lr
        patience=self.lr_patience
        self.optimizer=self._get_optimizer(lr)

        # Logging buffers
        losses_train=[]
        losses_valid=[]
        acc_train=[]
        acc_valid=[]
        reg_train=[]
        reg_valid=[]

        # Snapshot the initial activation masks for this task (keyed as epoch -1);
        # guarded so that disabling logging (logpath=None) does not crash here
        if self.logs is not None:
            task=torch.autograd.Variable(torch.LongTensor([t]).cuda(),volatile=False)
            bmask=self.model.mask(task,s=self.smax)
            self.logs['mask'][t]={}
            self.logs['mask_pre'][t]={}
            for i in range(len(bmask)):
                self.logs['mask'][t][i]={}
                self.logs['mask'][t][i][-1]=deepcopy(bmask[i].data.cpu().numpy().astype(np.float32))
                if t==0:
                    self.logs['mask_pre'][t][i]=deepcopy((0*bmask[i]).data.cpu().numpy().astype(np.float32))
                else:
                    self.logs['mask_pre'][t][i]=deepcopy(self.mask_pre[i].data.cpu().numpy().astype(np.float32))

        # In single-task mode, only the first task is actually trained
        if not self.single_task or t==0:
            # Loop epochs
            try:
                for e in range(self.nepochs):
                    # Train
                    clock0=time.time()
                    self.train_epoch(t,xtrain,ytrain)
                    clock1=time.time()
                    train_loss,train_acc,train_reg=self.eval_withreg(t,xtrain,ytrain)
                    clock2=time.time()
                    print('| Epoch {:3d}, time={:5.1f}ms/{:5.1f}ms | Train: loss={:.3f}, acc={:5.1f}% |'.format(e+1,
                        1000*self.sbatch*(clock1-clock0)/xtrain.size(0),1000*self.sbatch*(clock2-clock1)/xtrain.size(0),train_loss,100*train_acc),end='')
                    # Valid
                    valid_loss,valid_acc,valid_reg=self.eval_withreg(t,xvalid,yvalid)
                    print(' Valid: loss={:.3f}, acc={:5.1f}% |'.format(valid_loss,100*valid_acc),end='')

                    #log
                    losses_train.append(train_loss)
                    acc_train.append(train_acc)
                    reg_train.append(train_reg)
                    losses_valid.append(valid_loss)
                    acc_valid.append(valid_acc)
                    reg_valid.append(valid_reg)

                    # Adapt lr
                    if valid_loss<best_loss:
                        best_loss=valid_loss
                        best_model=utils.get_model(self.model)
                        patience=self.lr_patience
                        print(' *',end='')
                    else:
                        patience-=1
                        if patience<=0:
                            lr/=self.lr_factor
                            print(' lr={:.1e}'.format(lr),end='')
                            if lr<self.lr_min:
                                print()
                                break
                            patience=self.lr_patience
                            self.optimizer=self._get_optimizer(lr)
                    print()

                    # Log activation masks after this epoch
                    if self.logs is not None:
                        task=torch.autograd.Variable(torch.LongTensor([t]).cuda(),volatile=False)
                        bmask=self.model.mask(task,s=self.smax)
                        for i in range(len(bmask)):
                            self.logs['mask'][t][i][e]=deepcopy(bmask[i].data.cpu().numpy().astype(np.float32))

                # Log losses
                if self.logs is not None:
                    self.logs['train_loss'][t] = np.array(losses_train)
                    self.logs['train_acc'][t] = np.array(acc_train)
                    self.logs['train_reg'][t] = np.array(reg_train)
                    self.logs['valid_loss'][t] = np.array(losses_valid)
                    self.logs['valid_acc'][t] = np.array(acc_valid)
                    self.logs['valid_reg'][t] = np.array(reg_valid)
            except KeyboardInterrupt:
                print()

        # Restore best validation model
        utils.set_model_(self.model,best_model)

        # Activations mask: recompute the (near-binary) gates for this task and
        # fold them into the cumulative mask over all tasks via elementwise max
        task=torch.autograd.Variable(torch.LongTensor([t]).cuda(),volatile=False)
        mask=self.model.mask(task,s=self.smax)
        for i in range(len(mask)):
            mask[i]=torch.autograd.Variable(mask[i].data.clone(),requires_grad=False)
        if t==0:
            self.mask_pre=mask
        else:
            for i in range(len(self.mask_pre)):
                self.mask_pre[i]=torch.max(self.mask_pre[i],mask[i])

        # Weights mask: get_view_for expands the cumulative unit mask to each
        # parameter's shape, so mask_back[n]=1-view zeroes gradients of weights
        # serving units claimed by previous tasks
        self.mask_back={}
        for n,_ in self.model.named_parameters():
            vals=self.model.get_view_for(n,self.mask_pre)
            if vals is not None:
                self.mask_back[n]=1-vals

        return

    def train_epoch(self,t,x,y,thres_cosh=50,thres_emb=6):
        self.model.train()

        r=np.arange(x.size(0))
        np.random.shuffle(r)
        r=torch.LongTensor(r).cuda()

        # Loop batches
        for i in range(0,len(r),self.sbatch):
            if i+self.sbatch<=len(r): b=r[i:i+self.sbatch]
            else: b=r[i:]
            images=torch.autograd.Variable(x[b],volatile=False)
            targets=torch.autograd.Variable(y[b],volatile=False)
            task=torch.autograd.Variable(torch.LongTensor([t]).cuda(),volatile=False)
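            # Anneal gate hardness s linearly from 1/smax (soft gates) at the
            # start of the epoch to ~smax (near-binary) at the end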
            s=(self.smax-1/self.smax)*i/len(r)+1/self.smax

            # Forward
            outputs,masks=self.model.forward(task,images,s=s)
            output=outputs[t]
            loss,_=self.criterion(output,targets,masks)

            # Backward
            self.optimizer.zero_grad()
            loss.backward()

            # Restrict gradients in backprop: block updates to weights that
            # serve units attended to by previous tasks (see mask_back)
            if t>0:
                for n,p in self.model.named_parameters():
                    if n in self.mask_back:
                        p.grad.data*=self.mask_back[n]

            # Compensate embedding gradients: divide out the current annealed-
            # sigmoid slope so embeddings receive comparable updates regardless
            # of the current s (cosh is clamped for numerical stability)
            for n,p in self.model.named_parameters():
                if n.startswith('e'):
                    num=torch.cosh(torch.clamp(s*p.data,-thres_cosh,thres_cosh))+1
                    den=torch.cosh(p.data)+1
                    p.grad.data*=self.smax/s*num/den

            # Apply step
            torch.nn.utils.clip_grad_norm(self.model.parameters(),self.clipgrad)
            self.optimizer.step()

            # Constrain embeddings
            for n,p in self.model.named_parameters():
                if n.startswith('e'):
                    p.data=torch.clamp(p.data,-thres_emb,thres_emb)


        return

    def eval(self,t,x,y):
        # Identical to eval_withreg, but drops the regularization term
        total_loss,total_acc,_=self.eval_withreg(t,x,y)
        return total_loss,total_acc

    def eval_withreg(self,t,x,y):
        total_loss=0
        total_acc=0
        total_num=0
        self.model.eval()

        total_reg=0

        r=np.arange(x.size(0))
        r=torch.LongTensor(r).cuda()

        # Loop batches
        for i in range(0,len(r),self.sbatch):
            if i+self.sbatch<=len(r): b=r[i:i+self.sbatch]
            else: b=r[i:]
            images=torch.autograd.Variable(x[b],volatile=True)
            targets=torch.autograd.Variable(y[b],volatile=True)
            task=torch.autograd.Variable(torch.LongTensor([t]).cuda(),volatile=True)

            # Forward (single-task mode uses a very large s to binarize the gates)
            factor=1
            if self.single_task: factor=10000
            outputs,masks=self.model.forward(task,images,s=factor*self.smax)
            output=outputs[t]
            loss,reg=self.criterion(output,targets,masks)
            _,pred=output.max(1)
            hits=(pred==targets).float()

            # Log
            total_loss+=loss.data.cpu().numpy().item()*len(b)
            total_acc+=hits.sum().data.cpu().numpy().item()
            total_num+=len(b)
            total_reg+=reg.data.cpu().numpy().item()*len(b)

        print('  {:.3f}  '.format(total_reg/total_num),end='')

        return total_loss/total_num,total_acc/total_num,total_reg/total_num

    def criterion(self,outputs,targets,masks):
        # Cross-entropy plus a sparsity regularizer on the attention masks,
        # normalized over (and applied only to) capacity not already reserved
        # by previous tasks (weighted by 1-mask_pre)
        reg=0
        count=0
        if self.mask_pre is not None:
            for m,mp in zip(masks,self.mask_pre):
                aux=1-mp
                reg+=(m*aux).sum()
                count+=aux.sum()
        else:
            for m in masks:
                reg+=m.sum()
                count+=np.prod(m.size()).item()
        reg/=count
        return self.ce(outputs,targets)+self.lamb*reg,reg

########################################################################################################################
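# Example usage (a sketch, not part of the original scripts; assumes a model
# exposing mask(task,s), forward(task,x,s) -> (outputs,masks), and
# get_view_for(name,masks), plus CUDA tensors for the data splits):
#
#   appr=Appr(model,nepochs=100,sbatch=64,args=args)
#   for t in range(num_tasks):
#       appr.train(t,xtrain[t],ytrain[t],xvalid[t],yvalid[t])
#       loss,acc=appr.eval(t,xvalid[t],yvalid[t])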