from mpi4py import MPI
from solver import Solver
import torch
import os
import time
import warnings
import datetime
import numpy as np
from tqdm import tqdm
from misc.utils import color, get_fake, get_labels, get_loss_value
from misc.utils import split, TimeNow, to_var
from misc.losses import _compute_loss_smooth, _GAN_LOSS
import torch.utils.data.distributed
from misc.utils import horovod
hvd = horovod()
comm = MPI.COMM_WORLD
warnings.filterwarnings('ignore')


class Train(Solver):
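    """SMIT training loop.

    Alternates one discriminator update and one generator update per
    batch; every random style draw is seeded from `count_seed` so that
    style sampling is deterministic and resumable from checkpoints.
    """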
    def __init__(self, config, data_loader):
        super(Train, self).__init__(config, data_loader)
        self.count_seed = 0
        self.step_seed = 4  # seeds consumed per iteration: 1 disc + 3 gen
        self.run()

    # ============================================================#
    # ============================================================#
    def update_lr(self, g_lr, d_lr):
        for param_group in self.g_optimizer.param_groups:
            param_group['lr'] = g_lr
        for param_group in self.d_optimizer.param_groups:
            param_group['lr'] = d_lr

    # ============================================================#
    # ============================================================#
    def reset_grad(self):
        self.g_optimizer.zero_grad()
        self.d_optimizer.zero_grad()

    # ============================================================#
    # ============================================================#
    def update_loss(self, loss, value):
        # Accumulate per-iteration values, creating the list on first use.
        self.LOSS.setdefault(loss, []).append(value)

    # ============================================================#
    # ============================================================#
    def get_labels(self):
        return get_labels(
            self.config.image_size,
            self.config.dataset_fake,
            attr=self.data_loader.dataset)

    # ============================================================#
    # ============================================================#
    def debug_vars(self, start):
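        # Collect a fixed ~16-image batch, its labels, and one style draw;
        # the same inputs are reused at every save point so that samples
        # stay comparable across epochs.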
        fixed_x = []
        fixed_label = []
        for i, (images, labels, _) in enumerate(self.data_loader):
            fixed_x.append(images)
            fixed_label.append(labels)
            if i == max(1, int(16 / self.config.batch_size)):
                break
        fixed_x = torch.cat(fixed_x, dim=0)
        fixed_label = torch.cat(fixed_label, dim=0)
        fixed_style = self.random_style(fixed_x, seed=self.count_seed)

        if start == 0:
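            # On fresh runs, save a baseline translation before any training.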
            self.generate_SMIT(
                fixed_x,
                self.output_sample(0, 0),
                Multimodal=1,
                label=fixed_label,
                training=True,
                fixed_style=fixed_style)
            if self.config.image_size == 256:
                self.generate_SMIT(
                    fixed_x,
                    self.output_sample(0, 0),
                    label=fixed_label,
                    training=True)

        return fixed_x, fixed_label, fixed_style

    # ============================================================#
    # ============================================================#
    def _GAN_LOSS(self, real_x, fake_x, label):
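        # Single-label datasets train the domain classifier with
        # cross-entropy over the argmax class; multi-label datasets
        # keep the full binary label vector.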
        cross_entropy = self.config.dataset_fake in [
            'painters_14', 'Animals', 'Image2Weather', 'Image2Season',
            'Image2Edges', 'RafD', 'BP4D_idt'
        ]
        if cross_entropy:
            label = torch.max(label, dim=1)[1]
        return _GAN_LOSS(
            self.D, real_x, fake_x, label, cross_entropy=cross_entropy)

    # ============================================================#
    # ============================================================#
    def INFO(self, epoch, _iter):
        # Print log info every `log_step` iterations (and at the very start).
        if self.verbose:
            if (_iter + 1) % self.config.log_step == 0 or _iter + epoch == 0:
                self.loss = {
                    key: get_loss_value(value)
                    for key, value in self.loss.items()
                }
                color(self.loss, 'Gatm', 'blue')
                self.progress_bar.set_postfix(**self.loss)
            if (_iter + 1) == len(self.data_loader):
                self.progress_bar.set_postfix()

    # ============================================================#
    # ============================================================#
    def Decay_lr(self, current_epoch=0):
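        # Linear decay: spread the initial lr evenly over the epochs that
        # follow `num_epochs_decay`, reaching ~0 at `num_epochs`.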
        self.d_lr -= (
            self.config.d_lr /
            float(self.config.num_epochs - self.config.num_epochs_decay))
        self.g_lr -= (
            self.config.g_lr /
            float(self.config.num_epochs - self.config.num_epochs_decay))
        self.update_lr(self.g_lr, self.d_lr)
        if self.verbose and current_epoch % self.config.save_epoch == 0:
            self.PRINT('Decay learning rate to g_lr: {}, d_lr: {}.'.format(
                self.g_lr, self.d_lr))

    # ============================================================#
    # ============================================================#
    def RESUME_INFO(self):
        if not self.config.pretrained_model:
            return 0, 0
        start = int(self.config.pretrained_model.split('_')[0]) + 1
        total_iter = start * int(self.config.pretrained_model.split('_')[1])
        # Restore the style-seed counter: each iteration consumes
        # `step_seed` random-style draws.
        self.count_seed = total_iter * self.step_seed
        # Replay the lr-decay schedule up to the resumed epoch.
        for e in range(start):
            if e > self.config.num_epochs_decay:
                self.Decay_lr(e)
        return start, total_iter

    # ============================================================#
    # ============================================================#
    def MISC(self, epoch, _iter):
        if epoch % self.config.save_epoch == 0 and self.verbose:
            # Save Weights
            self.save(epoch, _iter + 1)

            # Save Translation
            self.generate_SMIT(
                self.fixed_x,
                self.output_sample(epoch, _iter + 1),
                Multimodal=1,
                label=self.fixed_label,
                training=True,
                fixed_style=self.fixed_style)
            if self.config.image_size == 256:
                self.generate_SMIT(
                    self.fixed_x,
                    self.output_sample(epoch, _iter + 1),
                    Multimodal=1,
                    label=self.fixed_label,
                    training=True)
                self.generate_SMIT(
                    self.fixed_x,
                    self.output_sample(epoch, _iter + 1),
                    label=self.fixed_label,
                    training=True)

            # Debug INFO
            elapsed = time.time() - self.start_time
            elapsed = str(datetime.timedelta(seconds=elapsed))
            log = '-> %s | Elapsed [Iter: %d] (%d/%d) : %s | %s\nTrain' % (
                TimeNow(), self.total_iter, epoch, self.config.num_epochs,
                elapsed, self.Log)
            for tag, value in sorted(self.LOSS.items()):
                log += ", {}: {:.4f}".format(tag, np.array(value).mean())
            self.PRINT(log)

        comm.Barrier()
        # Decay learning rate
        if epoch > self.config.num_epochs_decay:
            self.Decay_lr(epoch)

    # ============================================================#
    # ============================================================#
    def reset_losses(self):
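        # Fresh container for this iteration's raw (tensor) losses.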
        return {}

    # ============================================================#
    # ============================================================#
    def current_losses(self, mode, **kwargs):
        # Sum every raw loss whose key contains `mode` ('D' or 'G'),
        # recording each scalar value for the epoch log.
        loss = 0
        for key in kwargs:
            if mode in key:
                loss += self.loss[key]
                self.update_loss(key, get_loss_value(self.loss[key]))
        return loss

    # ============================================================#
    # ============================================================#
    def to_var(self, *args):
        return [to_var(arg) for arg in args]

    # ============================================================#
    # ============================================================#
    def train_model(self, generator=False, discriminator=False):
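        # Toggle which subnetwork receives gradients; unwrap DataParallel
        # when running multi-GPU without Horovod. The try/except keeps
        # compatibility with old PyTorch versions lacking requires_grad_().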
        if torch.cuda.device_count() > 1 and hvd.size() == 1:
            G = self.G.module
        else:
            G = self.G
        for p in G.generator.parameters():
            try:
                p.requires_grad_(generator)
            except AttributeError:
                p.requires_grad = generator
        for p in self.D.parameters():
            try:
                p.requires_grad_(discriminator)
            except AttributeError:
                p.requires_grad = discriminator

    # ============================================================#
    # ============================================================#

    def Dis_update(self, real_x, real_c, fake_c):
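        # One discriminator step: translate real_x with a fresh style code,
        # then optimize the adversarial and domain-classification losses
        # on D only.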
        self.train_model(discriminator=True)
        real_x, real_c, fake_c = self.to_var(real_x, real_c, fake_c)
        style_fake = to_var(self.random_style(real_x, seed=self.count_seed))
        self.count_seed += 1
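        # self.G(...) returns the translated image at [0] and the attention
        # mask at [1]; D only scores the image.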
        fake_x = self.G(real_x, fake_c, style_fake)[0]
        d_loss_src, d_loss_cls = self._GAN_LOSS(real_x, fake_x, real_c)

        self.loss['Dsrc'] = d_loss_src
        self.loss['Dcls'] = d_loss_cls * self.config.lambda_cls
        d_loss = self.current_losses('D', **self.loss)
        self.reset_grad()
        d_loss.backward()
        self.d_optimizer.step()

    # ============================================================#
    # ============================================================#
    def Gen_update(self, real_x, real_c, fake_c):
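        # One generator step: adversarial/classification, cycle
        # reconstruction, attention regularizers, and optional identity
        # loss, each branch drawing its own seeded style code.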
        self.train_model(generator=True)
        real_x, real_c, fake_c = self.to_var(real_x, real_c, fake_c)
        criterion_l1 = torch.nn.L1Loss()
        style_fake = to_var(self.random_style(real_x, seed=self.count_seed))
        style_rec = to_var(self.random_style(real_x, seed=self.count_seed + 1))
        style_identity = to_var(
            self.random_style(real_x, seed=self.count_seed + 2))
        self.count_seed += 3

        fake_x = self.G(real_x, fake_c, style_fake)

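        # Relativistic loss: fake/real argument order is swapped with
        # respect to the discriminator step.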
        g_loss_src, g_loss_cls = self._GAN_LOSS(fake_x[0], real_x, fake_c)
        self.loss['Gsrc'] = g_loss_src
        self.loss['Gcls'] = g_loss_cls * self.config.lambda_cls

        # REC LOSS
        rec_x = self.G(fake_x[0], real_c, style_rec)
        g_loss_rec = criterion_l1(rec_x[0], real_x)
        self.loss['Grec'] = self.config.lambda_rec * g_loss_rec

        # ========== Attention Part ==========#
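        # Keep attention masks sparse (mean penalty) and spatially smooth.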
        self.loss['Gatm'] = self.config.lambda_mask * (
            torch.mean(rec_x[1]) + torch.mean(fake_x[1]))
        self.loss['Gats'] = self.config.lambda_mask_smooth * (
            _compute_loss_smooth(rec_x[1]) + _compute_loss_smooth(fake_x[1]))

        # ========== Identity Part ==========#
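        # Translating an image to its own domain should reproduce it.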
        if self.config.Identity:
            idt_x = self.G(real_x, real_c, style_identity)[0]
            g_loss_idt = criterion_l1(idt_x, real_x)
            self.loss['Gidt'] = self.config.lambda_idt * g_loss_idt

        g_loss = self.current_losses('G', **self.loss)
        self.reset_grad()
        g_loss.backward()
        self.g_optimizer.step()

    # ============================================================#
    # ============================================================#
    def run(self):
        # lr cache for decaying
        self.g_lr = self.config.g_lr
        self.d_lr = self.config.d_lr
        self.PRINT('Training with learning rate g_lr: {}, d_lr: {}.'.format(
            self.g_optimizer.param_groups[0]['lr'],
            self.d_optimizer.param_groups[0]['lr']))

        # Start with trained info if exists
        start, self.total_iter = self.RESUME_INFO()

        # Fixed inputs, target domain labels, and style for debugging
        self.fixed_x, self.fixed_label, self.fixed_style = self.debug_vars(
            start)

        self.PRINT("Current time: " + TimeNow())
        self.PRINT("Debug Log txt: " + os.path.realpath(self.config.log.name))

        # Log info
        # RaGAN uses different data for Dis and Gen
        self.Log = self.PRINT_LOG(self.config.batch_size // 2)

        self.start_time = time.time()

        # Start training
        for epoch in range(start, self.config.num_epochs):
            self.D.train()
            self.G.train()
            self.LOSS = {}
            desc_bar = '[Iter: %d] Epoch: %d/%d' % (self.total_iter, epoch,
                                                    self.config.num_epochs)
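            # Nonzero when the bar should be hidden: any epoch that is not
            # a save epoch, except epoch 0.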
            epoch_verbose = (epoch % self.config.save_epoch) and epoch != 0
            self.progress_bar = tqdm(
                enumerate(self.data_loader),
                unit_scale=True,
                total=len(self.data_loader),
                desc=desc_bar,
                disable=not self.verbose or epoch_verbose,
                ncols=5)
            for _iter, (real_x, real_c, _) in self.progress_bar:
                self.loss = self.reset_losses()
                # Every worker advances one batch per step.
                self.total_iter += hvd.size()
                # RaGAN uses different data for Dis and Gen
                real_x0, real_x1 = split(real_x)
                real_c0, real_c1 = split(real_c)
                fake_c = get_fake(real_c, seed=_iter)
                fake_c0, fake_c1 = split(fake_c)

                # ============================================================#
                # ======================== Train D ===========================#
                # ============================================================#
                self.Dis_update(real_x0, real_c0, fake_c0)

                # ============================================================#
                # ======================== Train G ===========================#
                # ============================================================#
                self.Gen_update(real_x1, real_c1, fake_c1)

                # ====================== DEBUG =====================#
                self.INFO(epoch, _iter)

            # ============================================================#
            # ====================== MISCELLANEOUS =======================#
            # ============================================================#
            # Shuffling dataset each epoch
            self.data_loader.dataset.shuffle(epoch)
            self.MISC(epoch, _iter)