# coding: utf8
import os
import pickle
from datetime import datetime

import keras.backend as K
import numpy as np
from keras.engine.topology import Container
from keras.optimizers import Adam

from began.config import BEGANConfig
from began.generate_image import generate
from began.models import build_model, load_model_weight

ESC = chr(0x1b)
UP = ESC + "[1A"  # ANSI escape: move the cursor up one line


def main():
    config = BEGANConfig()
    training(config, epochs=500)


def training(config: BEGANConfig, epochs=3):
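    """Train the BEGAN generator and discriminator.

    Each epoch samples fresh latent vectors, shuffles the dataset, then
    alternates one discriminator batch (real + generated images, target =
    input) with one generator batch. The learning rate decays whenever the
    epoch-averaged M_global stops improving, and weights are checkpointed
    at the end of every epoch.
    """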
    # load the dataset and scale pixel values into the range 0.0-1.0
    with open(config.dataset_filename, 'rb') as f:
        dataset = pickle.load(f)  # type: np.ndarray
        dataset = dataset / 255.0
        info(dataset.shape)

    batch_size = config.batch_size

    # build the models and load saved weights (if checkpoint files exist)
    autoencoder, generator, discriminator = build_model(config)
    load_model_weight(autoencoder, config.autoencoder_weight_filename)
    load_model_weight(generator, config.generator_weight_filename)
    load_model_weight(discriminator, config.discriminator_weight_filename)

    loss_d = DiscriminatorLoss(config.initial_k)  # stateful loss object for BEGAN (tracks k and M_global)
    discriminator.compile(optimizer=Adam(), loss=loss_d)
    generator.compile(optimizer=Adam(), loss=create_generator_loss(autoencoder))
    lr_decay_step = 0
    last_m_global = np.Inf
    log_recorder = LogRecorder(config.training_log)

    # print("Generator Update Variables")
    # print_model_updates(generator)
    #
    # print("Discriminator Update Variables")
    # print_model_updates(discriminator)

    for ep in range(1, epochs+1):
        np.random.seed(ep * 100)
        # sample latent vectors z for the discriminator and generator passes
        zd = np.random.uniform(-1, 1, (len(dataset), config.hidden_size))
        zg = np.random.uniform(-1, 1, (len(dataset), config.hidden_size))
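        # zd feeds generator.predict for the discriminator's fake half;
        # zg feeds the generator's own training step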

        # shuffle dataset indices
        index_order = np.arange(len(dataset))
        np.random.shuffle(index_order)

        # set learning rate: exponential decay, floored at min_lr
        lr = max(config.initial_lr * (config.lr_decay_rate ** lr_decay_step), config.min_lr)
        K.set_value(generator.optimizer.lr, lr)
        K.set_value(discriminator.optimizer.lr, lr)
        m_global_history = []
        info("LearningRate=%.7f" % lr)
        batch_len = len(dataset)//batch_size

        for b_idx in range(batch_len):
            index_list = index_order[b_idx*batch_size:(b_idx+1)*batch_size]

            # train the discriminator
            in_x1 = dataset[index_list]  # (bs, row, col, ch)
            in_x2 = generator.predict_on_batch(zd[index_list])
            in_x = np.concatenate([in_x1, in_x2], axis=-1)  # (bs, row, col, ch*2)
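            # real and generated images share one forward/backward pass:
            # DiscriminatorLoss splits the channel axis back into the two halves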
            loss_discriminator = discriminator.train_on_batch(in_x, in_x)

            # train the generator
            in_x1 = zg[index_list]
            loss_generator = generator.train_on_batch(in_x1, np.zeros_like(in_x2))  # y_true is ignored by generator_loss; any array of matching shape works

            # record M-Global
            m_global_history.append(loss_d.m_global)
            if b_idx > 0:
                print(UP + UP)  # move the cursor up so this iteration's log overwrites the last
            log_info = dict(
                epoch=ep,
                batch_index=b_idx,
                batch_len=batch_len,
                m_global=loss_d.m_global,
                loss_discriminator=loss_discriminator,
                loss_generator=loss_generator,
                loss_real_x=loss_d.loss_real_x,
                loss_gen_x=loss_d.loss_gen_x,
                k=loss_d.k,
                lr=lr,
            )
            info("ep=%(epoch)s, b_idx=%(batch_index)s/%(batch_len)s, MGlobal=%(m_global).5f, "
                 "Loss(D)=%(loss_discriminator).5f, Loss(G)=%(loss_generator).5f, Loss(X)=%(loss_real_x).5f, "
                 "Loss(G(Zd))=%(loss_gen_x).5f, K=%(k).6f" % log_info)
            log_recorder.write(**log_info)

        m_global = np.average(m_global_history)
        if last_m_global <= m_global:  # M_global did not improve: decay the learning rate
            lr_decay_step += 1
        last_m_global = m_global

        # save model weights at the end of each epoch
        autoencoder.save_weights(config.autoencoder_weight_filename)
        generator.save_weights(config.generator_weight_filename)
        discriminator.save_weights(config.discriminator_weight_filename)

        # generate sample images each epoch for visual inspection
        generate(config, "ep%03d" % ep, generator)


def create_generator_loss(autoencoder: Container):
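    """Return the BEGAN generator loss.

    The generator minimizes the discriminator's (autoencoder's) per-pixel L1
    reconstruction error on its own output: L_G = |G(z) - AE(G(z))|.
    y_true is ignored, which is why training passes a dummy target.
    """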
    def generator_loss(y_true, y_pred):
        y_pred_dash = autoencoder(y_pred)
        return K.mean(K.abs(y_pred - y_pred_dash), axis=[1, 2, 3])
    return generator_loss


class DiscriminatorLoss:
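    """BEGAN discriminator loss with stateful control variables.

    Computes L_D = L(x) - k_t * L(G(z)), where L is the autoencoder's
    per-pixel L1 reconstruction error. k_t and the convergence measure
    M_global live in Keras variables and are refreshed through this
    object's `updates` list rather than through the loss value itself.
    """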
    __name__ = 'discriminator_loss'  # Keras expects loss callables to expose __name__

    def __init__(self, initial_k=0, lambda_k=0.001, gamma=0.5):
        self.lambda_k = lambda_k
        self.gamma = gamma
        self.k_var = K.variable(initial_k, dtype=K.floatx(), name="discriminator_k")
        self.m_global_var = K.variable(0, dtype=K.floatx(), name="m_global")
        self.loss_real_x_var = K.variable(0, name="loss_real_x")  # for observation
        self.loss_gen_x_var = K.variable(0, name="loss_gen_x")    # for observation
        self.updates = []

    def __call__(self, y_true, y_pred):  # y_true, y_pred shape: (BS, row, col, ch * 2)
        data_true, generator_true = y_true[:, :, :, 0:3], y_true[:, :, :, 3:6]
        data_pred, generator_pred = y_pred[:, :, :, 0:3], y_pred[:, :, :, 3:6]
        loss_data = K.mean(K.abs(data_true - data_pred), axis=[1, 2, 3])
        loss_generator = K.mean(K.abs(generator_true - generator_pred), axis=[1, 2, 3])
        ret = loss_data - self.k_var * loss_generator

        # side-effect updates: Keras runs everything in `updates` on every
        # training batch; the DiscriminatorModel collects this loss object's
        # `updates` attribute
        mean_loss_data = K.mean(loss_data)
        mean_loss_gen = K.mean(loss_generator)

        # update k: k <- clip(k + lambda_k * (gamma * L(x) - L(G(z))), 0, 1)
        new_k = self.k_var + self.lambda_k * (self.gamma * mean_loss_data - mean_loss_gen)
        new_k = K.clip(new_k, 0, 1)
        self.updates.append(K.update(self.k_var, new_k))

        # convergence measure: M_global = L(x) + |gamma * L(x) - L(G(z))|
        m_global = mean_loss_data + K.abs(self.gamma * mean_loss_data - mean_loss_gen)
        self.updates.append(K.update(self.m_global_var, m_global))

        # record the mean real-image reconstruction loss (for observation)
        self.updates.append(K.update(self.loss_real_x_var, mean_loss_data))

        # record the mean generated-image reconstruction loss (for observation)
        self.updates.append(K.update(self.loss_gen_x_var, mean_loss_gen))

        return ret

    @property
    def k(self):
        return K.get_value(self.k_var)

    @property
    def m_global(self):
        return K.get_value(self.m_global_var)

    @property
    def loss_real_x(self):
        return K.get_value(self.loss_real_x_var)

    @property
    def loss_gen_x(self):
        return K.get_value(self.loss_gen_x_var)


def info(msg):
    now = datetime.now()
    print("%s: %s" % (now, msg))


class LogRecorder:
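    """Minimal CSV logger: writes a header from the first record's sorted
    keys, then one comma-separated row per write() call, flushing each row."""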
    def __init__(self, log_filename):
        log_dir = os.path.dirname(log_filename)
        if log_dir and not os.path.exists(log_dir):
            os.makedirs(log_dir)
        self.file_out = open(log_filename, "wt")
        self.columns = None

    def write(self, **kwargs):
        if not self.columns:
            self.columns = list(sorted(kwargs.keys()))
            self.file_out.write(",".join(self.columns) + "\n")
        values = [str(kwargs.get(x, "")) for x in self.columns]
        self.file_out.write(",".join(values) + "\n")
        self.file_out.flush()


def print_model_updates(model):
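    """Debug helper: print the variables a model updates during training.

    Note: this relies on an older Keras optimizer API in which get_updates
    takes (params, constraints, loss).
    """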
    training_updates = model.optimizer.get_updates(
        model._collected_trainable_weights,
        model.constraints,
        model.total_loss)
    updates = model.updates + training_updates
    print("\n".join(sorted([str(x[0]) for x in updates])))


if __name__ == '__main__':
    main()