__author__ = 'yawli'


from decimal import Decimal

import torch
import torch.optim as optim
from tqdm import tqdm

import utility
from trainer import Trainer


class TrainerFT(Trainer):
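    """Trainer for the fine-tuning model.

    Reuses the training/testing loop of the base ``Trainer`` but, when
    ``args.model`` is ``'finetune'``, replaces the optimizer with one that
    trains the ``body_ft`` and ``tail_ft`` sub-modules at the full learning
    rate while the remaining parameters use a ten times smaller one
    (see ``make_optimizer``). The loaders are expected to yield
    ``(lr, nl, mk, hr, filename, idx_scale)`` tuples, matching the call
    ``self.model(idx_scale, lr, nl, mk)``.
    """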
    def __init__(self, args, loader, my_model, my_loss, ckp):
        super(TrainerFT, self).__init__(args, loader, my_model, my_loss, ckp)
        if self.args.model.lower() == 'finetune':
            self.optimizer = self.make_optimizer(args, self.model)

    def train(self):
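        """Train for one epoch on ``self.loader_train``.

        A batch is skipped (no backward pass) when its loss exceeds
        ``args.skip_threshold`` times the loss recorded for the previous
        epoch, to avoid updates from diverging batches.
        """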
        self.scheduler.step()
        self.loss.step()
        epoch = self.scheduler.last_epoch + 1
        lr = self.scheduler.get_lr()[0]

        self.ckp.write_log(
            '[Epoch {}]\tLearning rate: {:.2e}'.format(epoch, Decimal(lr))
        )
        self.loss.start_log()
        self.model.train()
        timer_data, timer_model = utility.timer(), utility.timer()
        for batch, (lr, nl, mk, hr, _, idx_scale) in enumerate(self.loader_train):
            lr, nl, mk, hr = self.prepare([lr, nl, mk, hr])
            timer_data.hold()
            timer_model.tic()

            self.optimizer.zero_grad()
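            # Forward pass: the model receives the scale index together with
            # the LR batch and the auxiliary nl/mk tensors from the loader.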
            sr = self.model(idx_scale, lr, nl, mk)
            loss = self.loss(sr, hr)
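            # Only back-propagate when the loss is within the skip threshold
            # relative to the previous epoch's loss.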
            if loss.item() < self.args.skip_threshold * self.error_last:
                loss.backward()
                self.optimizer.step()
            else:
                print('Skip this batch {}! (Loss: {})'.format(
                    batch + 1, loss.item()
                ))

            timer_model.hold()

            if (batch + 1) % self.args.print_every == 0:
                self.ckp.write_log('[{}/{}]\t{}\t{:.3f}+{:.3f}s'.format(
                    (batch + 1) * self.args.batch_size,
                    len(self.loader_train.dataset),
                    self.loss.display_loss(batch),
                    timer_model.release(),
                    timer_data.release()))

            timer_data.tic()
        self.loss.end_log(len(self.loader_train))
        self.error_last = self.loss.log[-1, -1]

    def test(self):
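        """Evaluate on ``self.loader_test`` for every scale in ``self.scale``.

        PSNR is computed against the HR ground truth (cropped to the SR
        output size) whenever it is available; results are optionally saved
        and, unless ``args.test_only`` is set, the checkpoint is updated.
        """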
        epoch = self.scheduler.last_epoch + 1
        self.ckp.write_log('\nEvaluation:')
        self.ckp.add_log(torch.zeros(1, len(self.scale)))
        self.model.eval()

        timer_test = utility.timer()
        with torch.no_grad():
            for idx_scale, scale in enumerate(self.scale):
                eval_acc = 0
                self.loader_test.dataset.set_scale(idx_scale)
                tqdm_test = tqdm(self.loader_test, ncols=80)
                for idx_img, (lr, nl, mk, hr, filename, _) in enumerate(tqdm_test):
                    filename = filename[0]
                    no_eval = (hr.nelement() == 1)
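                    # no_eval: the dataset provided no HR ground truth (hr is
                    # a one-element placeholder), so PSNR cannot be computed.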
                    if not no_eval:
                        lr, nl, mk, hr = self.prepare([lr, nl, mk, hr])
                    else:
                        lr, nl, mk = self.prepare([lr, nl, mk])

                    sr = self.model(idx_scale, lr, nl, mk)
                    sr = utility.quantize(sr, self.args.rgb_range)
                    _, _, h, w = sr.shape
                    save_list = [sr]
                    if not no_eval:
                        hr = hr[:, :, :h, :w]
                        eval_acc += utility.calc_psnr(
                            sr, hr, scale, self.args.rgb_range,
                            benchmark=self.loader_test.dataset.benchmark
                        )
                        save_list.extend([lr, hr])

                    if self.args.save_results:
                        self.ckp.save_results(filename, save_list, scale)

                self.ckp.log[-1, idx_scale] = eval_acc / len(self.loader_test)
                best = self.ckp.log.max(0)
                self.ckp.write_log(
                    '[{} x{}]\tPSNR: {:.3f} (Best: {:.3f} @epoch {})'.format(
                        self.args.data_test,
                        scale,
                        self.ckp.log[-1, idx_scale],
                        best[0][idx_scale],
                        best[1][idx_scale] + 1
                    )
                )

        self.ckp.write_log(
            'Total time: {:.2f}s\n'.format(timer_test.toc()), refresh=True
        )
        if not self.args.test_only:
            self.ckp.save(self, epoch, is_best=(best[1][0] + 1 == epoch))

    def make_optimizer(self, args, model):
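        """Create an optimizer with two parameter groups.

        Parameters of ``model.model.body_ft`` and ``model.model.tail_ft`` are
        trained with the full ``args.lr``; all other trainable parameters use
        ``0.1 * args.lr``. The optimizer type and its hyper-parameters are
        taken from ``args``.
        """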
        trainable = [p for p in model.model.parameters() if p.requires_grad]
        finetune_id = set(map(id, model.model.body_ft.parameters())) \
                      | set(map(id, model.model.tail_ft.parameters()))
        # finetune_id |= set(map(id, model.model.tail_ft2.parameters()))
        base_params = [p for p in trainable if id(p) not in finetune_id]
        finetune_params = [p for p in trainable if id(p) in finetune_id]
        if args.optimizer == 'SGD':
            optimizer_function = optim.SGD
            kwargs = {'momentum': args.momentum}
        elif args.optimizer == 'ADAM':
            optimizer_function = optim.Adam
            kwargs = {
                'betas': (args.beta1, args.beta2),
                'eps': args.epsilon
            }
        elif args.optimizer == 'RMSprop':
            optimizer_function = optim.RMSprop
            kwargs = {'eps': args.epsilon}
        else:
            raise ValueError('Unsupported optimizer: {}'.format(args.optimizer))

        kwargs['lr'] = args.lr * 0.1
        kwargs['weight_decay'] = args.weight_decay
        return optimizer_function(
            [
                {'params': base_params},
                {'params': finetune_params, 'lr': args.lr}
            ],
            **kwargs
        )
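
# Usage sketch (assumption: this class is driven by an EDSR-style entry
# script; `args`, `loader`, `model`, `loss` and `ckp` are created elsewhere
# in the project and are only placeholders here):
#
#     trainer = TrainerFT(args, loader, model, loss, ckp)
#     for _ in range(args.epochs):
#         trainer.train()
#         trainer.test()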