# -*- coding: utf-8 -*-
# @Author       : William
# @Project      : TextGAN-william
# @FileName     : JSDGAN_instructor.py
# @Time         : Created at 2019/11/25
# @Blog         : http://zhiweil.ml/
# @Description  : Instructor for JSDGAN training on real datasets
# Copyright (C) 2018. All Rights Reserved.

import torch
import torch.optim as optim

import config as cfg
from instructor.real_data.instructor import BasicInstructor
from models.JSDGAN_G import JSDGAN_G


class JSDGANInstructor(BasicInstructor):
    def __init__(self, opt):
        super(JSDGANInstructor, self).__init__(opt)

        # generator
        self.gen = JSDGAN_G(cfg.mem_slots, cfg.num_heads, cfg.head_size, cfg.gen_embed_dim, cfg.gen_hidden_dim,
                            cfg.vocab_size, cfg.max_seq_len, cfg.padding_idx, gpu=cfg.CUDA)
        self.init_model()

        # Optimizer
        self.gen_opt = optim.Adam(self.gen.parameters(), lr=cfg.gen_lr)
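        # JSDGAN has no discriminator, so the generator's optimizer is the
        # only one needed (see adv_train_generator below)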

    def init_model(self):
        if cfg.gen_pretrain:
            self.log.info('Load MLE pretrained generator gen: {}'.format(cfg.pretrained_gen_path))
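            # map_location remaps checkpoint tensors onto the configured GPU,
            # even if the checkpoint was saved from a different device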
            self.gen.load_state_dict(torch.load(cfg.pretrained_gen_path, map_location='cuda:{}'.format(cfg.device)))

        if cfg.CUDA:
            self.gen = self.gen.cuda()

    def _run(self):
        # ===PRE-TRAINING===
        # TRAIN GENERATOR
        self.log.info('Starting Generator MLE Training...')
        self.pretrain_generator(cfg.MLE_train_epoch)

        # ===ADVERSARIAL TRAINING===
        self.log.info('Starting Adversarial Training...')

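        # Unlike discriminator-based GAN instructors, each adversarial epoch
        # here consists of generator updates only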
        for adv_epoch in range(cfg.ADV_train_epoch):
            g_loss = self.adv_train_generator(cfg.ADV_g_step)  # Generator

            if adv_epoch % cfg.adv_log_step == 0:
                self.log.info('[ADV] epoch %d: g_loss = %.4f, %s' % (adv_epoch, g_loss, self.cal_metrics(fmt_str=True)))

                if cfg.if_save and not cfg.if_test:
                    self._save('ADV', adv_epoch)

    def _test(self):
        print('>>> Begin test...')

        self._run()

    def pretrain_generator(self, epochs):
        """
        Max Likelihood Pre-training for the generator
        """
        for epoch in range(epochs):
            self.sig.update()
            if self.sig.pre_sig:
                pre_loss = self.train_gen_epoch(self.gen, self.train_data.loader, self.mle_criterion, self.gen_opt)

                # ===Test===
                if epoch % cfg.pre_log_step == 0 or epoch == epochs - 1:
                    self.log.info(
                        '[MLE-GEN] epoch %d : pre_loss = %.4f, %s' % (epoch, pre_loss, self.cal_metrics(fmt_str=True)))
                    if cfg.if_save and not cfg.if_test:
                        self._save('MLE', epoch)
            else:
                self.log.info('>>> Stop by pre signal, skip to adversarial training...')
                break

    def adv_train_generator(self, g_step):
        """
        The gen is trained using policy gradients, using the reward from the discriminator.
        Training is done for num_batches batches.
        """
        global inp, target
        total_loss = 0
        for step in range(g_step):
            for i, data in enumerate(self.train_data.loader):
                inp, target = data['input'], data['target']
                if cfg.CUDA:
                    inp, target = inp.cuda(), target.cuda()

                # ===Train===
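                # JSD_loss is the model's Jensen-Shannon divergence estimate
                # between generated and real sequences; per the JSDGAN
                # formulation, the optimal discriminator is expressed in
                # closed form through the generator's own probabilities, so
                # no discriminator network is trained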
                adv_loss = self.gen.JSD_loss(inp, target)
                self.optimize(self.gen_opt, adv_loss, self.gen)
                total_loss += adv_loss.item()

        return total_loss
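
# Minimal usage sketch (hypothetical; in this project the instructor is
# normally constructed and launched from the top-level training script):
#
#   opt = ...  # parsed command-line options
#   inst = JSDGANInstructor(opt)
#   inst._test() if cfg.if_test else inst._run()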