from collections import namedtuple
import torch.nn as nn
from ...utils import *
from ...learners import AbstractBaseLearner

import torch
from torch.autograd import Variable, grad
import torch.cuda as cuda
from torchvision import utils as vutils
import os
from time import time
import numpy as np
import matplotlib.pyplot as plt
from import DataLoader
from tensorboardX import SummaryWriter
import torch.distributions as distribution
import torch.nn.functional as F
from torch.optim.lr_scheduler import ReduceLROnPlateau
import torch.onnx as onnx
from math import ceil

r"""Base learner for training, evaluating and performing inference with Generative Adversarial Networks
(Goodfellow et al. 2014, https://arxiv.org/abs/1406.2661).
All custom GAN learners that use a single generator and a single discriminator should subclass this class.

    Args:
        gen_model (Module):  the generator module.
        disc_model (Module):  the discriminator module.
        use_cuda_if_available (boolean): if True, training is done on a GPU when one is available"""

class BaseGanLearner(AbstractBaseLearner):
    def __init__(self, gen_model,disc_model,use_cuda_if_available=True):
        """Store both networks and reset optimizer, loss and logging state.

        Args:
            gen_model: the generator module.
            disc_model: the discriminator module.
            use_cuda_if_available (bool): when True and ``torch.cuda`` reports an
                available device, ``self.cuda`` is set to True.
        """
        # Run the base initializer first: the training loop and the image
        # helpers read attributes (``epoch_start_funcs``, ``batch_end_funcs``,
        # ``fp16_mode`` ...) that no class in this file assigns.
        # NOTE(review): assumes AbstractBaseLearner.__init__ takes no required
        # arguments - confirm against the base class.
        super(BaseGanLearner, self).__init__()

        self.model_dir = os.getcwd()
        self.gen_model = gen_model
        self.disc_model = disc_model
        self.cuda = bool(use_cuda_if_available and cuda.is_available())

        # Per-epoch history, keyed by metric name.
        self.__train_history__ = {}

        # Filled in by the training loop.
        self.gen_optimizer = None
        self.disc_optimizer = None
        self.gen_running_loss = None
        self.disc_running_loss = None

        self.visdom_log = None
        self.tensorboard_log = None

    def load_generator(self, path):
        r"""Initialize generator model weights using pre-trained weights from the filepath

            path (str): path to a compatible pre-defined model
        """
        # Mirrors load_discriminator, applied to the generator.
        load_model(self.gen_model, path)

    def load_discriminator(self, path):
        r"""Initialize discriminator model weights using pre-trained weights from the filepath

            path (str): path to a compatible pre-defined model
        """
        load_model(self.disc_model, path)

    def save_generator(self, path,save_architecture=False):
        r"""Saves the generator model to the path specified

            path (str): path to save model
            save_architecture (boolean): if True, both weights and architecture will be saved, default is False
        """
        # Mirrors save_discriminator, applied to the generator.
        save_model(self.gen_model, path, save_architecture)

    def save_discriminator(self, path,save_architecture=False):
        r"""Saves the discriminator model to the path specified

            path (str): path to save model
            save_architecture (boolean): if True, both weights and architecture will be saved, default is False
        """
        save_model(self.disc_model, path, save_architecture)

    # Generic training entry point; StandardBaseGanLearner overrides it with explicit arguments.
    # NOTE(review): no body is visible for this method in this view - confirm what it delegates to.
    def train(self,*args):

    def __train_loop__(self, train_loader,gen_optimizer,disc_optimizer,num_epochs=10, disc_steps=1,gen_lr_scheduler=None,disc_lr_scheduler=None, model_dir=os.getcwd(),save_model_interval=1, save_outputs_interval=100,display_outputs=True,notebook_mode=False,batch_log=True,save_logs=None,display_metrics=False,save_metrics=False,visdom_log=None,tensorboard_log=None,save_architecture=False):
        """Run the epoch/batch loop shared by the GAN learners in this file.

        NOTE(review): several statements of this method are not visible in this
        view (bodies of some ``if`` blocks, ``else`` branches, parts of
        expressions). The comments below mark those places; the code itself is
        left untouched.

        :param train_loader: iterable of batches; ``len(train_loader.dataset)`` must exceed ``disc_steps``
        :param gen_optimizer: stored on ``self.gen_optimizer``
        :param disc_optimizer: stored on ``self.disc_optimizer``
        :param num_epochs: number of passes over ``train_loader``
        :param disc_steps: the generator branch runs on every batch whose 1-based index is a multiple of this value
        :param gen_lr_scheduler: optional scheduler, checked against ``ReduceLROnPlateau`` after each epoch
        :param disc_lr_scheduler: optional scheduler, checked against ``ReduceLROnPlateau`` after each epoch
        :param model_dir: output directory; the ``gen_models`` and ``disc_models`` paths are derived from it.
            The default is evaluated once, when the function is defined.
        :param save_model_interval: both models are saved every this many epochs
        :param save_outputs_interval: number of batch iterations between output dumps
        :param display_outputs: checked each time the output interval is hit
        :param notebook_mode: together with ``batch_log``, selects ``tqdm_notebook`` instead of ``tqdm``
        :param batch_log: enables the per-batch progress bar
        :param save_logs: optional file path; the epoch summary line is written to it in append mode
        :param display_metrics: passed to ``visualize`` as ``display``
        :param save_metrics: when set, an ``epoch_{n}_loss.png`` path under ``model_dir`` is computed each epoch
        :param visdom_log: optional logger exposing ``plot_line``
        :param tensorboard_log: optional sub-path of ``model_dir`` handed to ``SummaryWriter``
        :param save_architecture: forwarded to ``save_generator`` / ``save_discriminator``
        """

        # NOTE(review): assert is stripped under ``python -O``; consider raising ValueError instead.
        assert(disc_steps < len(train_loader.dataset))

        self.gen_optimizer = gen_optimizer
        self.disc_optimizer = disc_optimizer
        self.tensorboard_log = tensorboard_log
        self.visdom_log = visdom_log

        # NOTE(review): body not visible - presumably creates model_dir; confirm.
        if not os.path.exists(model_dir):

        self.model_dir = model_dir
        models_gen = os.path.join(model_dir, "gen_models")
        models_disc = os.path.join(model_dir, "disc_models")

        # NOTE(review): bodies of the next two checks are not visible - presumably they create the directories.
        if not os.path.exists(models_gen):

        if not os.path.exists(models_disc):

        # Counts batches across all epochs; drives save_outputs_interval.
        iterations = 0

        from tqdm import tqdm_notebook
        from tqdm import tqdm

        train_start_time = time()

        for e in range(num_epochs):
            print("Epoch {} of {}".format(e + 1, num_epochs))

            # Callbacks receive the 1-based epoch number.
            for func in self.epoch_start_funcs:
                func(e + 1)

            # Running sums are reset every epoch; the per-batch train functions add to them.
            self.gen_running_loss = torch.Tensor([0.0])
            self.disc_running_loss = torch.Tensor([0.0])
            gen_loss = 0
            disc_loss = 0

            # Number of samples seen so far by each network in this epoch.
            gen_data_len = 0
            disc_data_len = 0

            if notebook_mode and batch_log:
                progress_ = tqdm_notebook(enumerate(train_loader))
            elif batch_log:
                progress_ = tqdm(enumerate(train_loader))
                # NOTE(review): an ``else:`` appears to be missing above the next line; as written it
                # would replace the tqdm wrapper that was just created.
                progress_ = enumerate(train_loader)

            # Largest batch seen so far; used to estimate the batch count for the progress text.
            max_batch_size = 0

            init_time = time()

            for i,t in progress_:

                for func in self.batch_start_funcs:
                    func(e + 1,i + 1)

                batch_size = get_batch_size(t)
                disc_data_len += batch_size

                if max_batch_size < batch_size:
                    max_batch_size = batch_size


                # NOTE(review): the numerator is missing here, and no call that trains the
                # discriminator on ``t`` is visible before it.
                disc_loss = / disc_data_len

                if (i+1) % disc_steps == 0:
                    gen_data_len += batch_size

                    # NOTE(review): numerator missing; no generator training call is visible in this branch.
                    gen_loss = / gen_data_len

                if batch_log:
                     # NOTE(review): the dict is built but no use of it is visible here.
                     progress_dict = {"Gen Loss": gen_loss,"Disc Loss":disc_loss}

                iterations += 1

                if iterations % save_outputs_interval == 0:
                    # NOTE(review): the statements of this block (and of the nested ``if``) are not visible.
                    if display_outputs:

                if batch_log:

                    # NOTE(review): the second format argument is truncated in this view.
                    progress_.set_description("{}/{} batches ".format(int(ceil(disc_data_len / max_batch_size)),
                                                                          train_loader.dataset) / max_batch_size))))
                    progress_dict = {"Disc Loss": disc_loss,"Gen Loss":gen_loss}


                batch_info = {"gen_loss":gen_loss , "disc_loss": disc_loss}

                for func in self.batch_end_funcs:
                    func(e + 1,i + 1,batch_info)

            # NOTE(review): body of this ``if`` is not visible.
            if self.cuda:
            duration = time() - init_time

            # NOTE(review): here and for "disc_loss"/"gen_loss" below, the append branch and the
            # ``else:`` are not visible; as written an existing history would be overwritten.
            if "duration" in self.__train_history__:
                self.__train_history__["duration"] = [duration]

            # NOTE(review): the scheduler step calls are not visible in either block.
            if gen_lr_scheduler is not None:
                if isinstance(gen_lr_scheduler, ReduceLROnPlateau):

            if disc_lr_scheduler is not None:
                if isinstance(disc_lr_scheduler, ReduceLROnPlateau):

            if "disc_loss" in self.__train_history__:
                self.__train_history__["disc_loss"] = [disc_loss]

            if "gen_loss" in self.__train_history__:
                self.__train_history__["gen_loss"] = [gen_loss]

            if "epoch" in self.__train_history__:
                self.__train_history__["epoch"].append(e + 1)
                # NOTE(review): an ``else:`` appears to be missing above the next line.
                self.__train_history__["epoch"] = [e + 1]

            if (e+1) % save_model_interval == 0:

                model_file = os.path.join(models_gen, "gen_model_{}.pth".format(e + 1))
                self.save_generator(model_file, save_architecture)

                print("New Generator model saved at {}".format(model_file))

                model_file = os.path.join(models_disc, "disc_model_{}.pth".format(e + 1))
                self.save_discriminator(model_file, save_architecture)

                print("New Discriminator model saved at {}".format(model_file))

            print("Epoch: {}, Duration: {} , Gen Loss: {} Disc Loss: {}".format(e+1, duration, gen_loss,disc_loss))

            if save_logs is not None:
                # NOTE(review): no close() for this handle is visible; a ``with`` block would avoid a leak.
                # The written line also carries no trailing newline.
                logfile = open(save_logs, "a")
                logfile.write("Epoch: {}, Duration: {} , Gen Loss: {} Disc Loss: {}".format(e+1, duration, gen_loss,disc_loss))

            epoch_arr = self.__train_history__["epoch"]
            epoch_arr_tensor = torch.LongTensor(epoch_arr)

            if display_metrics or save_metrics:

                save_path = None

                if save_metrics:
                    save_path = os.path.join(model_dir, "epoch_{}_loss.png".format(e+1))

                # NOTE(review): both curves use color="red", and the end of this call is not visible.
                visualize(epoch_arr, [PlotInput(value=self.__train_history__["gen_loss"], name="Generator Loss", color="red"),
                                      PlotInput(value=self.__train_history__["disc_loss"], name="Discriminator Loss", color="red")],display=display_metrics,

            if visdom_log is not None:
                visdom_log.plot_line(torch.FloatTensor(self.__train_history__["gen_loss"]), epoch_arr_tensor, win="gen_loss",
                                     title="Generator Loss")

                visdom_log.plot_line(torch.FloatTensor(self.__train_history__["disc_loss"]), epoch_arr_tensor, win="disc_loss",
                                     title="Discriminator Loss")

            if tensorboard_log is not None:
                # NOTE(review): a writer is created every epoch and no close() is visible.
                writer = SummaryWriter(os.path.join(model_dir, tensorboard_log))
                writer.add_scalar("logs/gen_loss", gen_loss, global_step=e+1)
                writer.add_scalar("logs/disc_loss", disc_loss, global_step=e+1)


            epoch_info = {"gen_loss": gen_loss, "disc_loss": disc_loss, "duration": duration}
            for func in self.epoch_end_funcs:
                func(e + 1, epoch_info)

        # Despite its name this holds the total elapsed training time.
        train_end_time = time() - train_start_time
        train_info = {"train_duration": train_end_time}
        # NOTE(review): loop body not visible - presumably each callback receives ``train_info``.
        for func in self.train_completed_funcs:

    """ Abstract function containing the training logic for the generator, 
     all custom trainers must override this.

            data: a single batch of data from the train set
    def __gen_train_func__(self,data):
        raise NotImplementedError()

    """ Abstract function containing the training logic for the discriminator, 
         all custom trainers must override this.

                data: a single batch of data from the train set
    def __disc_train_func__(self,data):
        raise NotImplementedError()

    """ Abstract function containing logic to save the outputs of the generator, 
         all custom trainers must override this.

                iterations: total number of batch iterations since the start of training
    def __save__(self,iterations):
        raise NotImplementedError()

    """ Abstract function containing logic to display the outputs of the generator, 
             all custom trainers must override this.

                    iterations: total number of batch iterations since the start of training

    def  evaluate(self, *args):
    def validate(self, *args):

    def __show__(self,iterations):
        raise NotImplementedError()

    """ Returns predictions for a given input tensor or an instance of a DataLoader
                inputs: input Tensor or DataLoader
    """
    def predict(self, inputs):

        # A DataLoader is consumed batch by batch through __predict_func__.
        if isinstance(inputs, DataLoader):
            predictions = []
            for i, data in enumerate(inputs):
                batch_pred = self.__predict_func__(data)
                # NOTE(review): the body of this inner loop and the return of the collected
                # predictions are not visible in this view.
                for pred in batch_pred:

            # NOTE(review): an ``else:`` appears to be missing above the next line (single-input path).
            pred = self.__predict_func__(inputs)

            # squeeze(0) removes the leading dimension when it has size 1.
            return pred.squeeze(0)

    """ Abstract function containing custom logic for performing inference, must return the output,
            all custom trainers should implement this.
            input: An input tensor or a batch of input tensors
    def __predict_func__(self, input):
        raise NotImplementedError()

    r""" Returns a dictionary containing the values of epochs and loss during training.
    def get_train_history(self):
        return self.__train_history__

class BaseGanCore(BaseGanLearner):
    def __init__(self, gen_model, disc_model, use_cuda_if_available=True):
        """Initialise the base learner and the default sampling settings.

        Args:
            gen_model: the generator module.
            disc_model: the discriminator module.
            use_cuda_if_available (bool): forwarded to the parent initializer.
        """
        super().__init__(gen_model, disc_model, use_cuda_if_available)

        # Latent sampling: standard normal noise; size and fixed batch are set later.
        self.dist = distribution.Normal(0, 1)
        self.latent_size = None
        self.fixed_source = None

        # Conditional generation is off until a subclass enables it.
        self.num_samples = 5
        self.classes = 0
        self.conditional = False

    def __disc_train_func__(self, data):

        for params in self.disc_model.parameters():
            params.requires_grad = True

        for params in self.gen_model.parameters():
            params.requires_grad = False

        if isinstance(data, list) or isinstance(data, tuple):
            x = data[0]

            if self.conditional:
                raise ValueError("Conditional mode is invalid for inputs with no labels, set num_classes to None or provide labels")
            x = data

        batch_size = x.size(0)
        if isinstance(self.latent_size, int):
            source_size = list([self.latent_size])
            source_size = list(self.latent_size)
        cond_samples_size = [self.num_samples] + source_size
        source_size = [batch_size] + source_size

        if self.fixed_source is None:

            if self.conditional:
                self.fixed_source = self.dist.sample(tuple(cond_samples_size))

                self.fixed_source = self.dist.sample(tuple(source_size))

            if self.cuda:
                self.fixed_source = self.fixed_source.cuda()

            self.fixed_source = Variable(self.fixed_source)

    def __gen_train_func__(self, data):

        for params in self.gen_model.parameters():
            params.requires_grad = True

        for params in self.disc_model.parameters():
            params.requires_grad = False

    def __save__(self, iteration):
        """Generate images from the fixed latent batch and write them to disk.

        Images go to ``<model_dir>/gen_images`` (one ``class_<i>`` sub-folder per
        class in conditional mode). When configured, the image grid is also sent
        to tensorboard and visdom.

            iteration: batch iteration count, used in file names and as the tensorboard step
        """
        save_dir = os.path.join(self.model_dir, "gen_images")
        os.makedirs(save_dir, exist_ok=True)

        writer = None
        if self.tensorboard_log is not None:
            writer = SummaryWriter(self.tensorboard_log)

        try:
            if self.conditional:
                for i in range(self.classes):
                    class_path = os.path.join(save_dir, "class_{}".format(i))
                    os.makedirs(class_path, exist_ok=True)

                    # One label per fixed sample, all equal to the current class index.
                    class_labels = torch.full((self.num_samples, 1), i, dtype=torch.long)
                    if self.cuda:
                        class_labels = class_labels.cuda()

                    outputs = self.gen_model(self.fixed_source, class_labels)
                    if self.fp16_mode:
                        outputs = outputs.float()

                    images_file = os.path.join(class_path, "iteration{}.png".format(iteration))

                    images = vutils.make_grid(outputs.cpu().data, normalize=True)
                    vutils.save_image(outputs.cpu().data, images_file, normalize=True)

                    if writer is not None:
                        writer.add_image("logs/gen_images/class_{}".format(i), images, global_step=iteration)

                    if self.visdom_log is not None:
                        self.visdom_log.log_images(images, win="class_{}".format(i), title="Class {}".format(i))
            else:
                outputs = self.gen_model(self.fixed_source)
                if self.fp16_mode:
                    outputs = outputs.float()
                images_file = os.path.join(save_dir, "image_{}.png".format(iteration))

                images = vutils.make_grid(outputs.cpu().data, normalize=True)
                vutils.save_image(outputs.cpu().data, images_file, normalize=True)

                if writer is not None:
                    writer.add_image("logs/gen_images", images, global_step=iteration)

                if self.visdom_log is not None:
                    self.visdom_log.log_images(images, win="gen_images", title="Generated Images")
        finally:
            # Close the writer even if generation or saving fails.
            if writer is not None:
                writer.close()

    def __show__(self, iteration):
        # Builds image grids from the fixed latent batch for on-screen display.
        # ``iteration`` is not read in the visible code.
        # NOTE(review): the calls that actually draw/show the images are not visible in this view.

        if self.conditional:
            for i in range(self.classes):
                # Label tensor of shape (num_samples, 1) filled with the class index i.
                class_labels = torch.randn((self.num_samples, 1)).type(torch.LongTensor).fill_(i)

                if self.cuda:
                    class_labels = class_labels.cuda()
                outputs = self.gen_model(self.fixed_source, class_labels)

                if self.fp16_mode:
                    outputs = outputs.float()

                images = vutils.make_grid(outputs.cpu().data, normalize=True)
                # Reorder axes (0, 1, 2) -> (1, 2, 0) for plotting.
                images = np.transpose(images.numpy(), (1, 2, 0))
                # One subplot row per class.
                plt.subplot(self.classes, 1, i + 1)
                # plt.title("class {}".format(i))


            # NOTE(review): an ``else:`` appears to be missing above the next line (unconditional path).
            outputs = self.gen_model(self.fixed_source)
            if self.fp16_mode:
                outputs = outputs.float()

            images = vutils.make_grid(outputs.cpu().data, normalize=True)

            images = np.transpose(images.numpy(), (1, 2, 0))

    def __predict_func__(self, data):
        labels = None
        if isinstance(data, list) or isinstance(data, tuple):
            source = data[0]
            labels = data[1]
            source = data

        if self.cuda:
            source = source.cuda()
            if labels is not None:
                labels = labels.cuda()

        if labels is not None:
            labels = labels.unsqueeze(1) if len(labels.size()) == 1 else labels
            outputs = self.gen_model(source, labels)
            outputs = self.gen_model(source)

        return outputs

    def gen_summary(self, input_size, label=None, input_type=torch.FloatTensor, item_length=26, tensorboard_log=None):
        """Summarise the generator by passing it one random sample.

            input_size: size of a single input, without the batch dimension
            label: optional class value; when given, a (1, 1) long tensor holding it is passed as well
            input_type: tensor type the random sample is converted to
            item_length: forwarded to ``get_model_summary``
            tensorboard_log: forwarded to ``get_model_summary``

        Returns whatever ``get_model_summary`` returns.
        """
        sample = torch.randn(input_size).type(input_type).unsqueeze(0)
        inputs = sample.cuda() if self.cuda else sample
        if label is not None:
            label = torch.empty(1, 1).fill_(label).long()
            label = label.cuda() if self.cuda else label
            return get_model_summary(self.gen_model, inputs, label, item_length=item_length,
                                     tensorboard_log=tensorboard_log)
        else:
            return get_model_summary(self.gen_model, inputs, item_length=item_length,
                                     tensorboard_log=tensorboard_log)

    def disc_summary(self, input_size, label=None, input_type=torch.FloatTensor, item_length=26, tensorboard_log=None):
        """Summarise the discriminator by passing it one random sample.

            input_size: size of a single input, without the batch dimension
            label: optional class value; when given, a (1, 1) long tensor holding it is passed as well
            input_type: tensor type the random sample is converted to
            item_length: forwarded to ``get_model_summary``
            tensorboard_log: forwarded to ``get_model_summary``

        Returns whatever ``get_model_summary`` returns.
        """
        sample = torch.randn(input_size).type(input_type).unsqueeze(0)
        inputs = sample.cuda() if self.cuda else sample
        # The summary targets the discriminator; the previous code passed the generator here.
        if label is not None:
            label = torch.empty(1, 1).fill_(label).long()
            label = label.cuda() if self.cuda else label
            return get_model_summary(self.disc_model, inputs, label, item_length=item_length,
                                     tensorboard_log=tensorboard_log)
        else:
            return get_model_summary(self.disc_model, inputs, item_length=item_length,
                                     tensorboard_log=tensorboard_log)

    def gen_to_onnx(self, path, input_size, label=None, input_type=torch.FloatTensor, **kwargs):
        """Export the generator to ONNX using one random sample as the example input.

            path: destination passed to the exporter as ``f``
            input_size: size of a single input, without the batch dimension
            label: optional class value; when given, a (1, 1) tensor holding it is exported as a second input
            input_type: tensor type the random sample is converted to
            **kwargs: forwarded to the exporter

        Returns the exporter's result.
        """
        sample = torch.randn(input_size).type(input_type).unsqueeze(0)
        sample = Variable(sample.cuda() if self.cuda else sample)
        if label is None:
            inputs = sample
        else:
            label = torch.empty(1, 1).fill_(label)
            inputs = [sample, label.cuda() if self.cuda else label]

        # NOTE(review): ``_export`` is a private torch.onnx function and may change between releases.
        return onnx._export(self.gen_model, inputs, f=path, **kwargs)

    def disc_to_onnx(self, path, input_size, label=None, input_type=torch.FloatTensor, **kwargs):
        """Export the discriminator to ONNX using one random sample as the example input.

            path: destination passed to the exporter as ``f``
            input_size: size of a single input, without the batch dimension
            label: optional class value; when given, a (1, 1) tensor holding it is exported as a second input
            input_type: tensor type the random sample is converted to
            **kwargs: forwarded to the exporter

        Returns the exporter's result.
        """
        sample = torch.randn(input_size).type(input_type).unsqueeze(0)
        sample = Variable(sample.cuda() if self.cuda else sample)
        if label is None:
            inputs = sample
        else:
            label = torch.empty(1, 1).fill_(label)
            inputs = [sample, label.cuda() if self.cuda else label]

        # NOTE(review): ``_export`` is a private torch.onnx function and may change between releases.
        return onnx._export(self.disc_model, inputs, f=path, **kwargs)

class StandardBaseGanLearner(BaseGanCore):

    def train(self, train_loader, gen_optimizer, disc_optimizer, latent_size,relative_mode=True, dist=distribution.Normal(0, 1),
                  num_classes=0, num_samples=5,**kwargs):
        """Store the sampling settings and start the shared training loop.

            train_loader: batches to train on
            gen_optimizer: optimizer for the generator
            disc_optimizer: optimizer for the discriminator
            latent_size: an int or a sequence giving the size of one latent sample
            relative_mode: when True the generator step also scores real data with the discriminator
            dist: distribution the latent samples are drawn from
            num_classes: number of classes; a value above 0 enables conditional mode, ``None`` is treated as 0
            num_samples: number of fixed samples generated per class in conditional mode
            **kwargs: forwarded to ``__train_loop__``
        """
        # The unlabeled-data error raised by BaseGanCore tells users to
        # "set num_classes to None"; normalise it so ``None > 0`` is never evaluated.
        if num_classes is None:
            num_classes = 0

        self.latent_size = latent_size
        self.dist = dist
        self.classes = num_classes
        self.num_samples = num_samples
        self.conditional = num_classes > 0
        self.relative_mode = relative_mode

        super().__train_loop__(train_loader, gen_optimizer, disc_optimizer, **kwargs)

    def __disc_train_func__(self, data):
        # Discriminator step: scores a real batch and a generated batch, asks the subclass for
        # the loss through __update_discriminator_loss__ and adds it to the running sum.
        # NOTE(review): no gradient reset, backward pass or optimizer step is visible in this view.



        if isinstance(data, list) or isinstance(data, tuple):
            x = data[0]
            if self.conditional:
                class_labels = data[1]
            # NOTE(review): an ``else:`` appears to be missing above the next line.
            x = data

        batch_size = x.size(0)
        if isinstance(self.latent_size, int):
            source_size = list([self.latent_size])
            # NOTE(review): an ``else:`` appears to be missing above the next line.
            source_size = list(self.latent_size)
        source_size = [batch_size] + source_size

        # Fresh latent noise for this batch.
        source = self.dist.sample(tuple(source_size))

        if self.cuda:
            x = x.cuda()
            source = source.cuda()

            if self.conditional:
                class_labels = class_labels.cuda()

        x = Variable(x)
        source = Variable(source)
        if self.conditional:
            # One-dimensional labels are reshaped to a column.
            class_labels = class_labels.unsqueeze(1) if len(class_labels.size())==1 else class_labels

        if self.conditional:
            outputs = self.disc_model(x, class_labels)
            # NOTE(review): an ``else:`` appears to be missing above the next line.
            outputs = self.disc_model(x)

        if self.conditional:
            # Random class per generated sample, drawn from [0, self.classes).
            random_labels = torch.from_numpy(np.random.randint(0, self.classes, size=(batch_size, 1))).long()
            if self.cuda:
                random_labels = random_labels.cuda()


            generated = self.gen_model(source, random_labels)
            # detach() keeps the discriminator loss from propagating into the generator.
            gen_outputs = self.disc_model(generated.detach(), random_labels)

            # NOTE(review): an ``else:`` appears to be missing above the next line.
            generated = self.gen_model(source)
            gen_outputs = self.disc_model(generated.detach())

        loss = self.__update_discriminator_loss__(x,generated,outputs,gen_outputs)
        # NOTE(review): the body of this ``if`` is not visible.
        if self.fp16_mode:
        # Weighted by batch size; the training loop divides by the number of samples seen.
        self.disc_running_loss = self.disc_running_loss + (loss.cpu().item() * batch_size)

    def __gen_train_func__(self, data):
        # Generator step: scores generated images (and, in relative mode, the real batch) with the
        # discriminator, asks the subclass for the loss through __update_generator_loss__ and adds
        # it to the running sum.
        # NOTE(review): no gradient reset, backward pass or optimizer step is visible in this view.



        if isinstance(data, list) or isinstance(data, tuple):
            x = data[0]
            if self.conditional and self.relative_mode:
                # NOTE(review): ``len(data[1]) == 1`` tests the batch length, whereas the
                # discriminator step tests the number of dimensions (``len(x.size()) == 1``);
                # 1-D labels from batches larger than one are therefore not reshaped - likely a bug.
                class_labels = data[1].unsqueeze(1) if len(data[1]) == 1 else data[1]
            # NOTE(review): an ``else:`` appears to be missing above the next line.
            x = data
        batch_size = x.size(0)

        if isinstance(self.latent_size, int):
            source_size = list([self.latent_size])
            # NOTE(review): an ``else:`` appears to be missing above the next line.
            source_size = list(self.latent_size)

        source_size = [batch_size] + source_size

        source = self.dist.sample(tuple(source_size))
        if self.conditional:
            # Random class per generated sample, drawn from [0, self.classes).
            random_labels = torch.from_numpy(np.random.randint(0, self.classes, size=(batch_size, 1))).long()
            if self.cuda:
                random_labels = random_labels.cuda()


        if self.cuda:
            source = source.cuda()

            x = x.cuda()
            if self.conditional and self.relative_mode:
                class_labels = class_labels.cuda()

        source = Variable(source)
        x = Variable(x)

        if self.conditional:
            fake_images = self.gen_model(source, random_labels)
            outputs = self.disc_model(fake_images, random_labels)
            if self.relative_mode:
                real_outputs = self.disc_model(x, class_labels)
            # NOTE(review): an ``else:`` appears to be missing above the next line.
            fake_images = self.gen_model(source)
            outputs = self.disc_model(fake_images)
            if self.relative_mode:
                real_outputs = self.disc_model(x)

        # Outside relative mode the loss hook receives no predictions for real data.
        if not self.relative_mode:
            real_outputs = None

        loss = self.__update_generator_loss__(x,fake_images,real_outputs,outputs)

        # NOTE(review): the body of this ``if`` is not visible.
        if self.fp16_mode:

        # Weighted by batch size; the training loop divides by the number of samples seen.
        self.gen_running_loss = self.gen_running_loss + (loss.cpu().item() * batch_size)

    def __update_generator_loss__(self,real_images,gen_images,real_preds,gen_preds):
        raise NotImplementedError()

    def __update_discriminator_loss__(self,real_images,gen_images,real_preds,gen_preds):
        raise NotImplementedError()

class StandardGanLearner(StandardBaseGanLearner):
    def __init__(self, gen_model, disc_model, use_cuda_if_available=True):
        """Forward all arguments unchanged to the parent initializer."""
        super().__init__(gen_model, disc_model, use_cuda_if_available)

    def train(self,train_loader, gen_optimizer,disc_optimizer,latent_size=100,gen_loss_fn=nn.BCELoss(), disc_loss_fn=nn.BCELoss(), real_labels=1,fake_labels=0,**kwargs):
        r"""Training function

            train_loader (DataLoader): an instance of DataLoader containing the training set
            gen_optimizer (Optimizer): an optimizer for updating parameters of the generator
            disc_optimizer (Optimizer): an optimizer for updating parameters of the discriminator
            latent_size (int): the size of the latent variable to be fed into the generator
            gen_loss_fn: the generator loss function
            disc_loss_fn: the discriminator loss function
            real_labels (int): The value to be used for the real images
            fake_labels (int): The value to be used for the generated images
            **kwargs: forwarded to ``StandardBaseGanLearner.train``, which reads
                dist, num_classes, num_samples and relative_mode and passes the rest
                to the training loop (num_epochs, disc_steps, gen_lr_scheduler,
                disc_lr_scheduler, model_dir, save_model_interval,
                save_outputs_interval, display_outputs, notebook_mode, batch_log,
                save_logs, display_metrics, save_metrics, visdom_log,
                tensorboard_log, save_architecture)
        """
        self.gen_loss_fn = gen_loss_fn
        self.disc_loss_fn = disc_loss_fn
        self.real_labels = real_labels
        self.fake_labels = fake_labels

        # The generator loss of this learner never reads the predictions for real
        # data, so the extra discriminator pass of relative mode is off by default.
        kwargs.setdefault("relative_mode", False)

        # Without this call the method only stored settings and never trained.
        super().train(train_loader, gen_optimizer, disc_optimizer, latent_size, **kwargs)

    def __update_discriminator_loss__(self,x,gen_images,real_preds,gen_preds):

        batch_size = x.size(0)
        real_labels = torch.randn(batch_size, 1).fill_(self.real_labels)
        fake_labels = torch.randn(batch_size, 1).fill_(self.fake_labels)

        real_labels = real_labels.cuda() if self.cuda else real_labels
        fake_labels = fake_labels.cuda() if self.cuda else fake_labels

        real_loss = self.disc_loss_fn(real_preds,real_labels)

        gen_loss = self.gen_loss_fn(gen_preds,fake_labels)

        return real_loss + gen_loss

    def __update_generator_loss__(self,x,gen_images,real_preds,gen_preds):
        batch_size = x.size(0)
        real_labels = torch.ones(batch_size, 1)

        real_labels = real_labels.cuda() if self.cuda else real_labels

        loss = self.gen_loss_fn(gen_preds,real_labels)

        return loss

""" Standard GANs that use the relativistic discriminator.
    See Alexia Jolicoeur-Martineau, 2018 (https://arxiv.org/abs/1807.00734)

    Args:
        gen_model:  the generator module.
        disc_model:  the discriminator module.
        use_cuda_if_available: if True, training is done on a GPU when one is available
"""

class RStandardGanLearner(StandardBaseGanLearner):
    def __init__(self, gen_model, disc_model, use_cuda_if_available=True):
        """Forward all arguments unchanged to the parent initializer."""
        super().__init__(gen_model, disc_model, use_cuda_if_available)

    def train(self,train_loader, gen_optimizer,disc_optimizer,latent_size=100,gen_loss_fn=nn.BCEWithLogitsLoss(), disc_loss_fn=nn.BCEWithLogitsLoss(),**kwargs):
        r"""Training function

            train_loader (DataLoader): an instance of DataLoader containing the training set
            gen_optimizer (Optimizer): an optimizer for updating parameters of the generator
            disc_optimizer (Optimizer): an optimizer for updating parameters of the discriminator
            latent_size (int): the size of the latent variable to be fed into the generator
            gen_loss_fn: the generator loss function
            disc_loss_fn: the discriminator loss function
            **kwargs: forwarded to ``StandardBaseGanLearner.train``, which reads
                dist, num_classes, num_samples and relative_mode and passes the rest
                to the training loop (num_epochs, disc_steps, gen_lr_scheduler,
                disc_lr_scheduler, model_dir, save_model_interval,
                save_outputs_interval, display_outputs, notebook_mode, batch_log,
                save_logs, display_metrics, save_metrics, visdom_log,
                tensorboard_log, save_architecture)
        """
        self.gen_loss_fn = gen_loss_fn
        self.disc_loss_fn = disc_loss_fn

        # Both losses of this learner read the predictions for real data, so the
        # parent's default relative_mode=True is left in place.
        # Without this call the method only stored settings and never trained.
        super().train(train_loader, gen_optimizer, disc_optimizer, latent_size, **kwargs)

    def __update_discriminator_loss__(self,x,gen_images,real_preds,gen_preds):

        batch_size = x.size(0)
        real_labels = torch.ones(batch_size, 1)

        real_labels = real_labels.cuda() if self.cuda else real_labels

        loss = self.disc_loss_fn(real_preds - gen_preds,real_labels)

        return loss

    def __update_generator_loss__(self,x,gen_images,real_preds,gen_preds):
        batch_size = x.size(0)
        real_labels = torch.ones(batch_size, 1)

        real_labels = real_labels.cuda() if self.cuda else real_labels

        loss = self.gen_loss_fn(gen_preds - real_preds,real_labels)

        return loss

""" Standard GANs that use the average relativistic discriminator.
    See Alexia Jolicoeur-Martineau, 2018 (https://arxiv.org/abs/1807.00734)

    Args:
        gen_model:  the generator module.
        disc_model:  the discriminator module.
        use_cuda_if_available: if True, training is done on a GPU when one is available
"""
class RAvgStandardGanLearner(StandardBaseGanLearner):
    """Standard GAN learner whose losses use the average relativistic discriminator.

    Each prediction is offset by the batch mean of the opposite kind of
    prediction (real vs. generated) before the loss function is applied.

    Args:
        gen_model: the generator module.
        disc_model: the discriminator module.
        use_cuda_if_available: If set to true, training would be done on a gpu if any is available
    """

    def __init__(self, gen_model, disc_model, use_cuda_if_available=True):
        """Forward the models and the CUDA preference to the base learner."""
        super(RAvgStandardGanLearner, self).__init__(gen_model, disc_model, use_cuda_if_available)

    def train(self, train_loader, gen_optimizer, disc_optimizer, latent_size=100,
              gen_loss_fn=nn.BCEWithLogitsLoss(), disc_loss_fn=nn.BCEWithLogitsLoss(), **kwargs):
        """Training function.

        The visible body only stores the two loss functions on the instance.

        NOTE(review): no training loop is started here and ``**kwargs`` is not
        forwarded anywhere - confirm whether a call into the base learner is
        expected after the assignments.

        Args:
            train_loader (DataLoader): an instance of DataLoader containing the training set
            gen_optimizer (Optimizer): an optimizer for updating parameters of the generator
            disc_optimizer (Optimizer): an optimizer for updating parameters of the discriminator
            latent_size (int): the size of the latent variable to be fed into the generator
            gen_loss_fn: the generator loss function
            disc_loss_fn: the discriminator loss function
            **kwargs: options listed by the original documentation:
                num_epochs (int): The maximum number of training epochs
                disc_steps (int): The number of times to train the discriminator before training generator
                save_models (str): If all, the model is saved at the end of each epoch while the best models based
                    on the test set are also saved in best_models folder
                    If 'best', only the best models are saved,  test_loader and test_metrics must be provided
                gen_lr_scheduler (_LRScheduler): Learning rate scheduler for the generator
                disc_lr_scheduler (_LRScheduler): Learning rate scheduler for the discriminator
                model_dir (str): a path in which to save the models
                save_model_interval (int): saves the models after every n epoch
                save_outputs_interval (int): saves sample outputs from the generator after every n iterations
                notebook_mode (boolean): Optimizes the progress bar for either jupyter notebooks or consoles
                display_metrics (boolean): Enables display of metrics and loss visualizations at the end of each epoch.
                save_metrics (boolean): Enables saving of metrics and loss visualizations at the end of each epoch.
                batch_log (boolean): Enables printing of logs at every batch iteration
                save_logs (str): Specifies a filepath in which to permanently save logs at every epoch
                visdom_log (VisdomLogger): Logs outputs and metrics to the visdom server
                tensorboard_log (str): Logs outputs and metrics to the filepath for visualization in tensorboard
                save_architecture (boolean): Saves the architecture as well as weights during model saving
                dist: A distribution from torch.distribution, used as the source of the latent vector fed to the generator
                real_labels (int): The value to be used for the real images
                fake_labels (int): The value to be used for the generated images
                num_classes (int): The number of classes for conditional generation
                num_samples (int): The number of samples to be generated per class in the conditional setting
        """
        self.gen_loss_fn = gen_loss_fn
        self.disc_loss_fn = disc_loss_fn

    def __update_discriminator_loss__(self, x, gen_images, real_preds, gen_preds):
        """Return the average-relativistic discriminator loss.

        Args:
            x: batch of real samples; only its first dimension (batch size) is read.
            gen_images: generated samples; not used by this method.
            real_preds: discriminator outputs for the real samples.
            gen_preds: discriminator outputs for the generated samples.

        Returns:
            The mean of two ``self.disc_loss_fn`` terms: real predictions offset
            by the mean generated prediction against all-ones targets, and
            generated predictions offset by the mean real prediction against
            all-zeros targets.
        """
        batch_size = x.size(0)
        real_labels = torch.ones(batch_size, 1)
        fake_labels = torch.zeros(batch_size, 1)
        if self.cuda:
            real_labels = real_labels.cuda()
            fake_labels = fake_labels.cuda()

        # Both terms use the discriminator's loss function. The previous code
        # used gen_loss_fn for the second term, which silently mixed the two
        # criteria whenever the caller supplied different loss functions.
        real_term = self.disc_loss_fn(real_preds - torch.mean(gen_preds), real_labels)
        fake_term = self.disc_loss_fn(gen_preds - torch.mean(real_preds), fake_labels)
        return (real_term + fake_term) / 2

    def __update_generator_loss__(self, x, gen_images, real_preds, gen_preds):
        """Return the average-relativistic generator loss.

        Args:
            x: batch of real samples; only its first dimension (batch size) is read.
            gen_images: generated samples; not used by this method.
            real_preds: discriminator outputs for the real samples.
            gen_preds: discriminator outputs for the generated samples.

        Returns:
            The mean of two ``self.gen_loss_fn`` terms with the targets swapped
            relative to the discriminator loss: real predictions against
            all-zeros targets and generated predictions against all-ones targets.
        """
        batch_size = x.size(0)
        real_labels = torch.ones(batch_size, 1)
        fake_labels = torch.zeros(batch_size, 1)
        if self.cuda:
            real_labels = real_labels.cuda()
            fake_labels = fake_labels.cuda()

        # Both terms use the generator's loss function (the previous code used
        # disc_loss_fn for the second term).
        real_term = self.gen_loss_fn(real_preds - torch.mean(gen_preds), fake_labels)
        fake_term = self.gen_loss_fn(gen_preds - torch.mean(real_preds), real_labels)
        return (real_term + fake_term) / 2

""" GANs that use the hinge loss
    See Ruohan Wang, 2018.
        gen_model:  the generator module.
        disc_model:  the discriminator module.
        use_cuda_if_available: If set to true, training would be done on a gpu if any is available
"""

class HingeGanLearner(StandardBaseGanLearner):
    """GAN learner that trains with the hinge loss.

    Args:
        gen_model: the generator module.
        disc_model: the discriminator module.
        use_cuda_if_available: If set to true, training would be done on a gpu if any is available
    """

    def __init__(self, gen_model, disc_model, use_cuda_if_available=True):
        """Forward the models and the CUDA preference to the base learner."""
        super(HingeGanLearner, self).__init__(gen_model, disc_model, use_cuda_if_available)

    def train(self, train_loader, gen_optimizer, disc_optimizer, latent_size, dist=distribution.Normal(0, 1),
              num_classes=0, num_samples=5, **kwargs):
        """Training function; delegates to ``StandardBaseGanLearner.train``.

        Args:
            train_loader (DataLoader): an instance of DataLoader containing the training set
            gen_optimizer (Optimizer): an optimizer for updating parameters of the generator
            disc_optimizer (Optimizer): an optimizer for updating parameters of the discriminator
            latent_size (int): the size of the latent variable to be fed into the generator
            dist: A distribution from torch.distribution, used as the source of the latent vector fed to the generator
            num_classes (int): The number of classes for conditional generation
            num_samples (int): The number of samples to be generated per class in the conditional setting
            **kwargs: forwarded unchanged to the base ``train``. Options listed by the original documentation:
                num_epochs (int): The maximum number of training epochs
                disc_steps (int): The number of times to train the discriminator before training generator
                save_models (str): If all, the model is saved at the end of each epoch while the best models based
                    on the test set are also saved in best_models folder
                    If 'best', only the best models are saved,  test_loader and test_metrics must be provided
                gen_lr_scheduler (_LRScheduler): Learning rate scheduler for the generator
                disc_lr_scheduler (_LRScheduler): Learning rate scheduler for the discriminator
                model_dir (str): a path in which to save the models
                save_model_interval (int): saves the models after every n epoch
                save_outputs_interval (int): saves sample outputs from the generator after every n iterations
                notebook_mode (boolean): Optimizes the progress bar for either jupyter notebooks or consoles
                display_metrics (boolean): Enables display of metrics and loss visualizations at the end of each epoch.
                save_metrics (boolean): Enables saving of metrics and loss visualizations at the end of each epoch.
                batch_log (boolean): Enables printing of logs at every batch iteration
                save_logs (str): Specifies a filepath in which to permanently save logs at every epoch
                visdom_log (VisdomLogger): Logs outputs and metrics to the visdom server
                tensorboard_log (str): Logs outputs and metrics to the filepath for visualization in tensorboard
                save_architecture (boolean): Saves the architecture as well as weights during model saving
        """
        # NOTE(review): the fifth positional argument (False) is presumably the
        # base learner's relativistic-mode flag - confirm against
        # StandardBaseGanLearner.train.
        super(HingeGanLearner, self).train(train_loader, gen_optimizer, disc_optimizer, latent_size, False, dist,
                                           num_classes, num_samples, **kwargs)

    def __update_discriminator_loss__(self, x, gen_images, real_preds, gen_preds):
        """Return the hinge loss for the discriminator.

        Computes ``mean(relu(1 - real_preds)) + mean(relu(1 + gen_preds))``.
        ``x`` and ``gen_images`` are not used.
        """
        real_term = torch.mean(F.relu(1 - real_preds))
        fake_term = torch.mean(F.relu(1 + gen_preds))
        return real_term + fake_term

    def __update_generator_loss__(self, x, gen_images, real_preds, gen_preds):
        """Return the generator loss ``-mean(gen_preds)``.

        ``x``, ``gen_images`` and ``real_preds`` are not used.
        """
        return -torch.mean(gen_preds)

""" Hinge GANs that use the relativistic discriminator
    See Alexia Jolicoeur-Martineau, 2018, "The relativistic discriminator: a key element missing from standard GAN".
        gen_model:  the generator module.
        disc_model:  the discriminator module.
        use_cuda_if_available: If set to true, training would be done on a gpu if any is available
"""
class RHingeGanLearner(StandardBaseGanLearner):
    """Hinge-loss GAN learner for the relativistic discriminator setting.

    Args:
        gen_model: the generator module.
        disc_model: the discriminator module.
        use_cuda_if_available: If set to true, training would be done on a gpu if any is available
    """

    def __init__(self, gen_model, disc_model, use_cuda_if_available=True):
        """Forward the models and the CUDA preference to the base learner."""
        super(RHingeGanLearner, self).__init__(gen_model, disc_model, use_cuda_if_available)

    def train(self, train_loader, gen_optimizer, disc_optimizer, latent_size, dist=distribution.Normal(0, 1),
              num_classes=0, num_samples=5, **kwargs):
        """Training function; delegates to ``StandardBaseGanLearner.train``.

        Args:
            train_loader (DataLoader): an instance of DataLoader containing the training set
            gen_optimizer (Optimizer): an optimizer for updating parameters of the generator
            disc_optimizer (Optimizer): an optimizer for updating parameters of the discriminator
            latent_size (int): the size of the latent variable to be fed into the generator
            dist: A distribution from torch.distribution, used as the source of the latent vector fed to the generator
            num_classes (int): The number of classes for conditional generation
            num_samples (int): The number of samples to be generated per class in the conditional setting
            **kwargs: forwarded unchanged to the base ``train``. Options listed by the original documentation:
                num_epochs (int): The maximum number of training epochs
                disc_steps (int): The number of times to train the discriminator before training generator
                save_models (str): If all, the model is saved at the end of each epoch while the best models based
                    on the test set are also saved in best_models folder
                    If 'best', only the best models are saved,  test_loader and test_metrics must be provided
                gen_lr_scheduler (_LRScheduler): Learning rate scheduler for the generator
                disc_lr_scheduler (_LRScheduler): Learning rate scheduler for the discriminator
                model_dir (str): a path in which to save the models
                save_model_interval (int): saves the models after every n epoch
                save_outputs_interval (int): saves sample outputs from the generator after every n iterations
                notebook_mode (boolean): Optimizes the progress bar for either jupyter notebooks or consoles
                display_metrics (boolean): Enables display of metrics and loss visualizations at the end of each epoch.
                save_metrics (boolean): Enables saving of metrics and loss visualizations at the end of each epoch.
                batch_log (boolean): Enables printing of logs at every batch iteration
                save_logs (str): Specifies a filepath in which to permanently save logs at every epoch
                visdom_log (VisdomLogger): Logs outputs and metrics to the visdom server
                tensorboard_log (str): Logs outputs and metrics to the filepath for visualization in tensorboard
                save_architecture (boolean): Saves the architecture as well as weights during model saving
        """
        # NOTE(review): the fifth positional argument (True) is presumably the
        # base learner's relativistic-mode flag - confirm against
        # StandardBaseGanLearner.train.
        super(RHingeGanLearner, self).train(train_loader, gen_optimizer, disc_optimizer, latent_size, True, dist,
                                            num_classes, num_samples, **kwargs)

    def __update_discriminator_loss__(self, x, gen_images, real_preds, gen_preds):
        """Return ``mean(relu(1 - real_preds)) + mean(relu(1 + gen_preds))``.

        ``x`` and ``gen_images`` are not used.

        NOTE(review): the predictions are used exactly as received; whether the
        base learner has already made them relative to each other is not visible
        here - confirm.
        """
        real_term = torch.mean(F.relu(1 - real_preds))
        fake_term = torch.mean(F.relu(1 + gen_preds))
        return real_term + fake_term

    def __update_generator_loss__(self, x, gen_images, real_preds, gen_preds):
        """Return ``mean(relu(1 + real_preds)) + mean(relu(1 - gen_preds))``.

        This mirrors the discriminator loss with the signs swapped.
        ``x`` and ``gen_images`` are not used.
        """
        real_term = torch.mean(F.relu(1 + real_preds))
        fake_term = torch.mean(F.relu(1 - gen_preds))
        return real_term + fake_term

""" Hinge GANs that use the average relativistic discriminator
    See Alexia Jolicoeur-Martineau, 2018, "The relativistic discriminator: a key element missing from standard GAN".
        gen_model:  the generator module.
        disc_model:  the discriminator module.
        use_cuda_if_available: If set to true, training would be done on a gpu if any is available
"""

class RAvgHingeGanLearner(StandardBaseGanLearner):
    """Hinge-loss GAN learner using the average relativistic discriminator.

    Each prediction is offset by the batch mean of the opposite kind of
    prediction (real vs. generated) before the hinge is applied.

    Args:
        gen_model: the generator module.
        disc_model: the discriminator module.
        use_cuda_if_available: If set to true, training would be done on a gpu if any is available
    """

    def __init__(self, gen_model, disc_model, use_cuda_if_available=True):
        """Forward the models and the CUDA preference to the base learner."""
        super(RAvgHingeGanLearner, self).__init__(gen_model, disc_model, use_cuda_if_available)

    def train(self, train_loader, gen_optimizer, disc_optimizer, latent_size, dist=distribution.Normal(0, 1),
              num_classes=0, num_samples=5, **kwargs):
        """Training function; delegates to ``StandardBaseGanLearner.train``.

        Args:
            train_loader (DataLoader): an instance of DataLoader containing the training set
            gen_optimizer (Optimizer): an optimizer for updating parameters of the generator
            disc_optimizer (Optimizer): an optimizer for updating parameters of the discriminator
            latent_size (int): the size of the latent variable to be fed into the generator
            dist: A distribution from torch.distribution, used as the source of the latent vector fed to the generator
            num_classes (int): The number of classes for conditional generation
            num_samples (int): The number of samples to be generated per class in the conditional setting
            **kwargs: forwarded unchanged to the base ``train``. Options listed by the original documentation:
                num_epochs (int): The maximum number of training epochs
                disc_steps (int): The number of times to train the discriminator before training generator
                save_models (str): If all, the model is saved at the end of each epoch while the best models based
                    on the test set are also saved in best_models folder
                    If 'best', only the best models are saved,  test_loader and test_metrics must be provided
                gen_lr_scheduler (_LRScheduler): Learning rate scheduler for the generator
                disc_lr_scheduler (_LRScheduler): Learning rate scheduler for the discriminator
                model_dir (str): a path in which to save the models
                save_model_interval (int): saves the models after every n epoch
                save_outputs_interval (int): saves sample outputs from the generator after every n iterations
                notebook_mode (boolean): Optimizes the progress bar for either jupyter notebooks or consoles
                display_metrics (boolean): Enables display of metrics and loss visualizations at the end of each epoch.
                save_metrics (boolean): Enables saving of metrics and loss visualizations at the end of each epoch.
                batch_log (boolean): Enables printing of logs at every batch iteration
                save_logs (str): Specifies a filepath in which to permanently save logs at every epoch
                visdom_log (VisdomLogger): Logs outputs and metrics to the visdom server
                tensorboard_log (str): Logs outputs and metrics to the filepath for visualization in tensorboard
                save_architecture (boolean): Saves the architecture as well as weights during model saving
        """
        # NOTE(review): the fifth positional argument (True) is presumably the
        # base learner's relativistic-mode flag - confirm against
        # StandardBaseGanLearner.train.
        super(RAvgHingeGanLearner, self).train(train_loader, gen_optimizer, disc_optimizer, latent_size, True, dist,
                                               num_classes, num_samples, **kwargs)

    def __update_discriminator_loss__(self, x, gen_images, real_preds, gen_preds):
        """Return the average-relativistic hinge loss for the discriminator.

        Computes the mean of ``mean(relu(1 - (real_preds - mean(gen_preds))))``
        and ``mean(relu(1 + (gen_preds - mean(real_preds))))``.
        ``x`` and ``gen_images`` are not used.
        """
        real_vs_fake = real_preds - torch.mean(gen_preds)
        fake_vs_real = gen_preds - torch.mean(real_preds)
        return (torch.mean(F.relu(1 - real_vs_fake)) + torch.mean(F.relu(1 + fake_vs_real))) / 2

    def __update_generator_loss__(self, x, gen_images, real_preds, gen_preds):
        """Return the average-relativistic hinge loss for the generator.

        Same offsets as the discriminator loss with the hinge signs swapped:
        the mean of ``mean(relu(1 + (real_preds - mean(gen_preds))))`` and
        ``mean(relu(1 - (gen_preds - mean(real_preds))))``.
        ``x`` and ``gen_images`` are not used.
        """
        real_vs_fake = real_preds - torch.mean(gen_preds)
        fake_vs_real = gen_preds - torch.mean(real_preds)
        return (torch.mean(F.relu(1 + real_vs_fake)) + torch.mean(F.relu(1 - fake_vs_real))) / 2

""" GANs that use the Improved Wasserstein Distance
    See Gulrajani et al. 2017, "Improved Training of Wasserstein GANs",
    based on earlier work by Arjovsky et al. 2017, "Wasserstein GAN".
        gen_model:  the generator module.
        disc_model:  the discriminator module.
        use_cuda_if_available: If set to true, training would be done on a gpu if any is available
"""

class WGanLearner(BaseGanCore):

    r"""Training function

                           Args:
                           train_loader (DataLoader): an instance of DataLoader containing the training set
                           gen_optimizer (Optimizer): an optimizer for updating parameters of the generator
                           disc_optimizer (Optimizer): an optimizer for updating parameters of the discriminator
                           latent_size (int): the size of the latent variable to be fed into the generator
                           num_epochs (int): The maximum number of training epochs
                           disc_steps (int): The number of times to train the discriminator before training generator
                           save_models (str): If all, the model is saved at the end of each epoch while the best models based
                               on the test set are also saved in best_models folder
                               If 'best', only the best models are saved,  test_loader and test_metrics must be provided
                           gen_lr_scheduler (_LRScheduler): Learning rate scheduler for the generator
                           disc_lr_scheduler (_LRScheduler): Learning rate scheduler for the discriminator
                           model_dir (str): a path in which to save the models
                           save_model_interval (int): saves the models after every n epoch
                           save_outputs_interval (int): saves sample outputs from the generator after every n iterations
                           notebook_mode (boolean): Optimizes the progress bar for either jupyter notebooks or consoles
                           display_metrics (boolean): Enables display of metrics and loss visualizations at the end of each epoch.
                           save_metrics (boolean): Enables saving of metrics and loss visualizations at the end of each epoch.
                           batch_log (boolean): Enables printing of logs at every batch iteration
                           save_logs (str): Specifies a filepath in which to permanently save logs at every epoch
                           visdom_log (VisdomLogger): Logs outputs and metrics to the visdom server
                           tensorboard_log (str): Logs outputs and metrics to the filepath for visualization in tensorboard
                           save_architecture (boolean): Saves the architecture as well as weights during model saving
                           dist: A distribution from torch.distribution, used as the source of the latent vector fed to the generator
                           num_classes (int): The number of classes for conditional generation
                           num_samples (int): The number of samples to be generated per class in the conditional setting
    """

    def train(self, train_loader, gen_optimizer, disc_optimizer, latent_size, dist=distribution.Normal(0, 1),
              num_classes=0, num_samples=5, lambda_=0.25, **kwargs):
        """Store the settings read by the discriminator and generator train steps.

        NOTE(review): the visible body only records settings on the instance;
        ``train_loader``, the optimizers and ``**kwargs`` are not used and no
        training loop is started - confirm whether a call into the base learner
        is expected after the assignments.

        Args:
            train_loader: not used by the visible body.
            gen_optimizer: not used by the visible body.
            disc_optimizer: not used by the visible body.
            latent_size: an int or an iterable of ints giving the per-sample
                shape of the latent vector; stored as ``self.latent_size``.
            dist: object whose ``sample(shape)`` yields latent vectors; stored
                as ``self.dist``.
            num_classes: number of classes for conditional generation; ``0`` or
                ``None`` selects the unconditional path.
            num_samples: stored as ``self.num_samples``.
            lambda_: weight of the gradient penalty; stored as ``self.lambda_``.
            **kwargs: not used by the visible body.
        """
        self.lambda_ = lambda_
        self.latent_size = latent_size
        self.dist = dist

        # The default num_classes is 0, so testing only "is not None" enabled the
        # conditional path by default, where drawing labels with
        # np.random.randint(0, 0, ...) raises. Require at least one class.
        self.conditional = num_classes is not None and num_classes > 0
        self.classes = num_classes
        self.num_samples = num_samples


    def __disc_train_func__(self, data):
        """Run one discriminator (critic) update with the gradient-penalty loss.

        The loss is ``mean(D(G(z))) - mean(D(x)) + lambda_ * penalty``, where the
        penalty is the mean squared deviation from 1 of the per-sample L2 norm of
        the discriminator's gradient at points interpolated between real and
        generated samples.

        NOTE(review): the source as received had lost its ``else:`` branches, the
        real-sample term of the interpolation, the tail of the ``grad(...)`` call
        and the body of the ``fp16_mode`` branch. The ``zero_grad`` / backward /
        ``step`` calls below are a reconstruction (the loss is local and nothing
        is returned, so the update has to happen here) - confirm against the
        upstream file.

        Args:
            data: either a tensor of real samples, or a list/tuple whose first
                item is that tensor and whose second item holds the class labels
                (read only in conditional mode).

        Raises:
            ValueError: in conditional mode when ``data`` carries no labels.
        """
        self.disc_optimizer.zero_grad()

        class_labels = None
        if isinstance(data, (list, tuple)):
            x = data[0]
            if self.conditional:
                class_labels = data[1]
        else:
            x = data

        if self.conditional and class_labels is None:
            raise ValueError("conditional training requires (inputs, class_labels) batches")

        batch_size = x.size(0)
        if isinstance(self.latent_size, int):
            source_size = [self.latent_size]
        else:
            source_size = list(self.latent_size)
        source = self.dist.sample(tuple([batch_size] + source_size))

        if self.cuda:
            x = x.cuda()
            source = source.cuda()
            if self.conditional:
                class_labels = class_labels.cuda()

        if self.conditional:
            # Labels are passed to the models as a column, one row per sample.
            if len(class_labels.size()) == 1:
                class_labels = class_labels.unsqueeze(1)
            outputs = self.disc_model(x, class_labels)

            random_labels = torch.from_numpy(
                np.random.randint(0, self.classes, size=(batch_size, 1))).long()
            if self.cuda:
                random_labels = random_labels.cuda()

            generated = self.gen_model(source, random_labels)
            # detach(): the generator must not receive gradients from this step.
            gen_outputs = self.disc_model(generated.detach(), random_labels)
        else:
            outputs = self.disc_model(x)
            generated = self.gen_model(source)
            gen_outputs = self.disc_model(generated.detach())

        gen_loss = torch.mean(gen_outputs)
        real_loss = -torch.mean(outputs)

        # Random mixing coefficients in [0, 1), one per element of x, created
        # with x's device and dtype.
        eps = torch.rand_like(x)

        # Interpolate between real and generated samples. The received source
        # read "eps * + (1.0 - eps) * generated", i.e. the real-sample term
        # was missing.
        x__ = (eps * x.detach() + (1.0 - eps) * generated.detach()).requires_grad_(True)

        if self.conditional:
            pred__ = self.disc_model(x__, class_labels)
        else:
            pred__ = self.disc_model(x__)

        # torch.autograd.grad returns a tuple with one entry per input.
        # create_graph=True keeps the penalty differentiable for backward().
        gradients = grad(outputs=pred__, inputs=x__, grad_outputs=torch.ones_like(pred__),
                         create_graph=True, retain_graph=True)[0]

        gradient_penalty = self.lambda_ * ((gradients.view(gradients.size(0), -1).norm(2, 1) - 1) ** 2).mean()

        loss = gen_loss + real_loss + gradient_penalty
        if self.fp16_mode:
            # NOTE(review): assumes an FP16 optimizer wrapper that exposes
            # backward(loss) - confirm against the optimizer actually in use.
            self.disc_optimizer.backward(loss)
        else:
            loss.backward()
        self.disc_optimizer.step()

        self.disc_running_loss = self.disc_running_loss + (loss.item() * batch_size)

    def __gen_train_func__(self, data):
        """Run one generator update with the loss ``-mean(D(G(z)))``.

        NOTE(review): the source as received had lost its ``else:`` branches and
        the body of the ``fp16_mode`` branch. The ``zero_grad`` / backward /
        ``step`` calls below are a reconstruction (the loss is local and nothing
        is returned, so the update has to happen here) - confirm against the
        upstream file.

        Args:
            data: either a tensor of real samples, or a list/tuple whose first
                item is that tensor. Only the batch size is read from it.
        """
        self.gen_optimizer.zero_grad()

        x = data[0] if isinstance(data, (list, tuple)) else data
        batch_size = x.size(0)

        if isinstance(self.latent_size, int):
            source_size = [self.latent_size]
        else:
            source_size = list(self.latent_size)
        source = self.dist.sample(tuple([batch_size] + source_size))
        if self.cuda:
            source = source.cuda()

        if self.conditional:
            random_labels = torch.from_numpy(
                np.random.randint(0, self.classes, size=(batch_size, 1))).long()
            if self.cuda:
                random_labels = random_labels.cuda()

            fake_images = self.gen_model(source, random_labels)
            outputs = self.disc_model(fake_images, random_labels)
        else:
            fake_images = self.gen_model(source)
            outputs = self.disc_model(fake_images)

        loss = -torch.mean(outputs)

        if self.fp16_mode:
            # NOTE(review): assumes an FP16 optimizer wrapper that exposes
            # backward(loss) - confirm against the optimizer actually in use.
            self.gen_optimizer.backward(loss)
        else:
            loss.backward()
        self.gen_optimizer.step()

        # The received source assigned to "gen_running_Loss" (capital L), so the
        # running generator loss was never actually accumulated.
        self.gen_running_loss = self.gen_running_loss + (loss.item() * batch_size)