import shutil

import torch.nn as nn
import torch.optim
from tensorboardX import SummaryWriter
from torch.autograd import Variable
from torch.optim.rmsprop import RMSprop
from tqdm import tqdm

from utils import AverageTracker


class Train:
    """Training / evaluation driver for a classification model.

    The constructor builds the loss and optimizer, then makes a best-effort
    attempt to load pretrained weights and to resume from a checkpoint, and
    finally opens a tensorboard summary writer.

    Attributes read from ``args`` by this class: ``cuda``, ``async_loading``,
    ``num_epochs``, ``test_every``, ``learning_rate``, ``learning_rate_decay``,
    ``momentum``, ``weight_decay``, ``checkpoint_dir``, ``pretrained_path``,
    ``resume_from`` and ``summary_dir``.

    NOTE: checkpoint paths are built by plain string concatenation
    (``args.checkpoint_dir + filename``), so ``checkpoint_dir`` has to end
    with a path separator.
    """

    def __init__(self, model, trainloader, valloader, args):
        """Set up optimization, restore any saved state and open the writer.

        :param model: the network to train; called as ``model(batch)``.
        :param trainloader: iterable yielding ``(data, target)`` batches.
        :param valloader: iterable yielding ``(data, target)`` batches, or a
            falsy value to disable the periodic evaluation.
        :param args: configuration object (see the class docstring).
        """
        self.model = model
        self.trainloader = trainloader
        self.valloader = valloader
        self.args = args
        self.start_epoch = 0
        self.best_top1 = 0.0

        # Loss function and Optimizer
        self.loss = None
        self.optimizer = None
        self.create_optimization()

        # Model Loading (the optimizer must exist before a checkpoint can
        # restore its state, hence the ordering).
        self.load_pretrained_model()
        self.load_checkpoint(self.args.resume_from)

        # Tensorboard Writer
        self.summary_writer = SummaryWriter(log_dir=args.summary_dir)

    def train(self):
        """Run the training loop from ``start_epoch`` up to ``args.num_epochs``.

        Each epoch: sets the learning rate, trains over ``trainloader``,
        writes the epoch averages to tensorboard, optionally evaluates on
        ``valloader`` and saves a checkpoint.
        """
        for cur_epoch in range(self.start_epoch, self.args.num_epochs):

            # Initialize tqdm
            tqdm_batch = tqdm(self.trainloader,
                              desc="Epoch-" + str(cur_epoch) + "-")

            # Learning rate adjustment
            self.adjust_learning_rate(self.optimizer, cur_epoch)

            # Meters for tracking the average values
            # (AverageTracker comes from utils; only update()/avg are used.)
            loss, top1, top5 = AverageTracker(), AverageTracker(), AverageTracker()

            # Set the model to be in training mode (for dropout and batchnorm)
            self.model.train()

            for data, target in tqdm_batch:

                if self.args.cuda:
                    # `async` became a reserved word in Python 3.7; PyTorch
                    # renamed the argument to `non_blocking`.
                    data = data.cuda(non_blocking=self.args.async_loading)
                    target = target.cuda(non_blocking=self.args.async_loading)

                # Forward pass
                output = self.model(data)
                cur_loss = self.loss(output, target)

                # Optimization step
                self.optimizer.zero_grad()
                cur_loss.backward()
                self.optimizer.step()

                # Top-1 and Top-5 Accuracy Calculation
                cur_acc1, cur_acc5 = self.compute_accuracy(output.detach(), target, topk=(1, 5))
                # .item() hands plain Python floats to the trackers so that
                # no tensors (or autograd graphs) are kept alive and the
                # string slicing in the console print below works.
                loss.update(cur_loss.item())
                top1.update(cur_acc1.item())
                top5.update(cur_acc5.item())

            # Summary Writing
            self.summary_writer.add_scalar("epoch-loss", loss.avg, cur_epoch)
            self.summary_writer.add_scalar("epoch-top-1-acc", top1.avg, cur_epoch)
            self.summary_writer.add_scalar("epoch-top-5-acc", top5.avg, cur_epoch)

            # Print in console
            tqdm_batch.close()
            print("Epoch-" + str(cur_epoch) + " | " + "loss: " + str(
                loss.avg) + " - acc-top1: " + str(
                top1.avg)[:7] + "- acc-top5: " + str(top5.avg)[:7])

            # Evaluate on Validation Set
            if cur_epoch % self.args.test_every == 0 and self.valloader:
                self.test(self.valloader, cur_epoch)

            # Checkpointing
            # NOTE(review): "best" is decided on the *training* top-1
            # average, not on the validation result -- confirm this is
            # intended.
            is_best = top1.avg > self.best_top1
            self.best_top1 = max(top1.avg, self.best_top1)
            self.save_checkpoint({
                'epoch': cur_epoch + 1,
                'state_dict': self.model.state_dict(),
                'best_top1': self.best_top1,
                'optimizer': self.optimizer.state_dict(),
            }, is_best)

    def test(self, testloader, cur_epoch=-1):
        """Evaluate the model on ``testloader`` and print the averages.

        :param testloader: iterable yielding ``(data, target)`` batches.
        :param cur_epoch: epoch index used as the tensorboard step; with the
            default of ``-1`` nothing is written to tensorboard.

        The model is left in evaluation mode when this method returns.
        """
        loss, top1, top5 = AverageTracker(), AverageTracker(), AverageTracker()

        # Set the model to be in testing mode (for dropout and batchnorm)
        self.model.eval()

        # torch.no_grad() replaces Variable(..., volatile=True), which has
        # been a no-op since PyTorch 0.4: without it every evaluation batch
        # would build an autograd graph.
        with torch.no_grad():
            for data, target in testloader:
                if self.args.cuda:
                    data = data.cuda(non_blocking=self.args.async_loading)
                    target = target.cuda(non_blocking=self.args.async_loading)

                # Forward pass
                output = self.model(data)
                cur_loss = self.loss(output, target)

                # Top-1 and Top-5 Accuracy Calculation
                cur_acc1, cur_acc5 = self.compute_accuracy(output, target, topk=(1, 5))
                loss.update(cur_loss.item())
                top1.update(cur_acc1.item())
                top5.update(cur_acc5.item())

        if cur_epoch != -1:
            # Summary Writing
            self.summary_writer.add_scalar("test-loss", loss.avg, cur_epoch)
            self.summary_writer.add_scalar("test-top-1-acc", top1.avg, cur_epoch)
            self.summary_writer.add_scalar("test-top-5-acc", top5.avg, cur_epoch)

        print("Test Results" + " | " + "loss: " + str(loss.avg) + " - acc-top1: " + str(
            top1.avg)[:7] + "- acc-top5: " + str(top5.avg)[:7])

    def save_checkpoint(self, state, is_best, filename='checkpoint.pth.tar'):
        """Serialize ``state`` into the checkpoint directory.

        :param state: object passed to ``torch.save``.
        :param is_best: when true, the saved file is also copied to
            ``model_best.pth.tar`` in the same directory.
        :param filename: file name appended to ``args.checkpoint_dir``.
        """
        checkpoint_path = self.args.checkpoint_dir + filename
        torch.save(state, checkpoint_path)
        if is_best:
            shutil.copyfile(checkpoint_path,
                            self.args.checkpoint_dir + 'model_best.pth.tar')

    def compute_accuracy(self, output, target, topk=(1,)):
        """Computes the accuracy@k for the specified values of k.

        :param output: 2-D score tensor; ``topk`` is taken along dimension 1.
        :param target: 1-D tensor of class indices, one per row of ``output``.
        :param topk: iterable of ``k`` values.
        :return: list with one single-element float tensor per ``k`` holding
            the fraction (0..1) of rows whose target is among the ``k``
            highest-scoring columns.
        """
        # Clamp so that e.g. top-5 on a model with fewer than five outputs
        # does not make ``topk`` raise; such a k then counts every class.
        maxk = min(max(topk), output.size(1))
        batch_size = target.size(0)

        _, idx = output.topk(maxk, 1, True, True)
        idx = idx.t()
        correct = idx.eq(target.view(1, -1).expand_as(idx))

        acc_arr = []
        for k in topk:
            # reshape() instead of view(): ``correct`` derives from a
            # transposed tensor and may be non-contiguous, in which case
            # view(-1) raises on recent PyTorch releases.
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            acc_arr.append(correct_k.mul_(1.0 / batch_size))
        return acc_arr

    def adjust_learning_rate(self, optimizer, epoch):
        """Apply exponential learning-rate decay for the given epoch.

        Every parameter group of ``optimizer`` gets
        ``args.learning_rate * args.learning_rate_decay ** epoch``.
        """
        learning_rate = self.args.learning_rate * (self.args.learning_rate_decay ** epoch)
        for param_group in optimizer.param_groups:
            param_group['lr'] = learning_rate

    def create_optimization(self):
        """Create the cross-entropy loss and the RMSprop optimizer.

        The loss is moved to the GPU when ``args.cuda`` is set; the optimizer
        is configured from ``args.learning_rate``, ``args.momentum`` and
        ``args.weight_decay``.
        """
        self.loss = nn.CrossEntropyLoss()

        if self.args.cuda:
            self.loss.cuda()

        self.optimizer = RMSprop(self.model.parameters(), self.args.learning_rate,
                                 momentum=self.args.momentum,
                                 weight_decay=self.args.weight_decay)

    def load_pretrained_model(self):
        """Best-effort load of the weights stored at ``args.pretrained_path``.

        Never raises for ordinary errors: a missing file is reported as "no
        weights exist", any other failure is reported together with its
        cause, and the model keeps whatever weights it already had if the
        load did not succeed.
        """
        print("Loading ImageNet pretrained weights...")
        try:
            # Without CUDA, remap storages so that weights saved from a GPU
            # run can still be read.
            pretrained_dict = torch.load(
                self.args.pretrained_path,
                map_location=None if self.args.cuda else "cpu")
            self.model.load_state_dict(pretrained_dict)
        except FileNotFoundError:
            print("No ImageNet pretrained weights exist. Skipping...\n")
        except Exception as err:  # deliberate best-effort: report, don't crash
            print("Failed to load ImageNet pretrained weights ({}). Skipping...\n".format(err))
        else:
            print("ImageNet pretrained weights loaded successfully.\n")

    def load_checkpoint(self, filename):
        """Best-effort resume from ``args.checkpoint_dir + filename``.

        On success the model and optimizer states are restored and
        ``start_epoch`` / ``best_top1`` are taken from the checkpoint. On any
        failure a message is printed and ``start_epoch`` / ``best_top1`` are
        left untouched.

        :param filename: checkpoint file name; a falsy value (``None`` or
            ``""``) means there is nothing to resume from.
        """
        if not filename:
            print("No checkpoint exists from '{}'. Skipping...\n".format(self.args.checkpoint_dir))
            return

        filename = self.args.checkpoint_dir + filename
        print("Loading checkpoint '{}'".format(filename))
        try:
            checkpoint = torch.load(
                filename, map_location=None if self.args.cuda else "cpu")
            start_epoch = checkpoint['epoch']
            best_top1 = checkpoint['best_top1']
            self.model.load_state_dict(checkpoint['state_dict'])
            self.optimizer.load_state_dict(checkpoint['optimizer'])
        except FileNotFoundError:
            print("No checkpoint exists from '{}'. Skipping...\n".format(self.args.checkpoint_dir))
            return
        except Exception as err:  # deliberate best-effort: report, don't crash
            print("Failed to load checkpoint '{}' ({}). Skipping...\n".format(filename, err))
            return

        # Only commit the bookkeeping once everything above succeeded, so a
        # broken checkpoint cannot shift start_epoch on its own.
        self.start_epoch = start_epoch
        self.best_top1 = best_top1
        print("Checkpoint loaded successfully from '{}' at (epoch {})\n"
              .format(self.args.checkpoint_dir, start_epoch))