import os
import numpy as np

# For reproducibility 
np.random.seed(42)

# force cuda device (empty for CPU)
os.environ["CUDA_VISIBLE_DEVICES"]=""

import sys
from keras import backend as K

# keras imports
from keras.models import Model
from keras.layers import Dense, Reshape, Input
from keras.layers.merge import  add, concatenate, multiply
from keras.layers.core import Activation, Lambda
from keras.layers.recurrent import GRU
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.normalization import BatchNormalization
from keras.layers.convolutional import Convolution1D
from keras.optimizers import SGD, adam

import argparse
import math

# sklearn imports
from sklearn import preprocessing
from sklearn.externals import joblib

# netcdf for reading packaged data
from scipy.io import netcdf

import matplotlib as mpl
mpl.use('Agg') # no need for X-server
from matplotlib import pyplot as plt

from models import fft_model, time_glot_model, discriminator, generator, gan_container
from data_utils import nc_data_provider, norm_stats

# Edge smoothing window: a Hann taper on both ends with a flat (all-ones)
# middle section, 400 samples long in total.
# NOTE(review): the three widths presumably mirror the generator's
# convolution filter widths in models.py -- confirm.
gen_filtwidths = np.asarray([15, 15, 15])
# Total taper length: every filter contributes (width - 1) samples.
edgelen = sum(width - 1 for width in gen_filtwidths)
hannwin = np.hanning(edgelen)
# Rising half of the Hann window, flat top, then the falling half.
smoothwin = np.concatenate((
    hannwin[:edgelen // 2],
    np.ones(400 - edgelen),
    hannwin[edgelen // 2:],
))

def plot_feats(generated_feats, epoch, index, ext='', fig_dir="./figures", fig_type=""):
    """Plot every row of ``generated_feats`` as a line and save the figure.

    The image is written to
    ``<fig_dir>/<fig_type>_epoch<epoch>_index<index><ext>.png``.

    Args:
        generated_feats: iterable of 1-D sequences; each one is drawn as a
            separate line in the same axes.
        epoch: epoch number, used only in the file name.
        index: batch index, used only in the file name.
        ext: extra suffix inserted before ``.png``.
        fig_dir: output directory; created if it does not exist yet.
        fig_type: file name prefix.
    """
    # Training plots at the very first batch, so a missing directory would
    # otherwise abort the run before any learning happens.
    if not os.path.isdir(fig_dir):
        os.makedirs(fig_dir)

    fig_name = fig_type + '_epoch{}_index{}'.format(epoch, index) + ext + '.png'

    fig = plt.figure()
    try:
        for row in generated_feats:
            plt.plot(row)
        plt.savefig(os.path.join(fig_dir, fig_name))
    finally:
        # Always release the figure, even if plotting or saving fails.
        plt.close(fig)

def _wave_spec_losses(losses):
    """Return the per-output (wave, spec) losses as a float array.

    For a model with two outputs Keras reports ``[total, wave, spec]``.
    Taking the last two entries yields the per-output losses and also works
    if only ``[wave, spec]`` is reported.
    """
    return np.asarray(losses[-2:], dtype=np.float64)


def train_pls_model(BATCH_SIZE, data_dir, file_list, context_len=32, max_files=30):
    """Pre-train the time-domain pulse model with an MSE loss.

    The model is fed ``data[1]`` and trained to reproduce ``data[0]`` (and
    its FFT as computed by ``fft_model``; that second loss has weight 0).
    The first chunk yielded by ``nc_data_provider`` in every epoch is held
    out for validation.  Training stops after 20 epochs or after 5 epochs
    without improvement of the validation wave loss; the best weights are
    written to ``./pls.model``.

    Args:
        BATCH_SIZE: number of examples per training batch.
        data_dir: directory passed on to ``nc_data_provider``.
        file_list: file names passed on to ``nc_data_provider``.
        context_len: passed as ``timesteps`` to the model and as
            ``context_len`` to the data provider.
        max_files: passed on to ``nc_data_provider``.

    Raises:
        ValueError: if the data provider yields no data at all.
    """
    no_epochs = 20
    max_epochs_no_improvement = 5

    timesteps = context_len

    optim = adam(lr=0.0001)
    pls_model = time_glot_model(timesteps=timesteps)
    # disregard fft loss
    pls_model.compile(loss=['mse', 'mse'], loss_weights=[1.0, 0.0], optimizer=optim)

    fft_mod = fft_model()

    patience = max_epochs_no_improvement
    best_val_loss = 1e20
    for epoch in range(no_epochs):
        print("Pre-train epoch is", epoch)
        # Running sum of (wave, spec) batch losses.  This must be an array:
        # `+=` on a plain list would extend it instead of adding elementwise.
        epoch_error = np.zeros(2)
        total_batches = 0
        val_data = None
        for data in nc_data_provider(file_list, data_dir,
                                     max_files=max_files, context_len=timesteps):

            # The first chunk of every epoch is reserved for validation.
            if val_data is None:
                val_data = data
                print("using data subset for validation")
                continue

            X_train = data[0]
            Y_train = data[1]

            # Only full batches are used; a trailing partial batch is dropped.
            no_batches = X_train.shape[0] // BATCH_SIZE
            print("Number of batches", no_batches)

            # shuffle data
            ind = np.random.permutation(X_train.shape[0])
            X_train = X_train[ind]
            Y_train = Y_train[ind]
            for index in range(no_batches):
                start = index * BATCH_SIZE
                stop = start + BATCH_SIZE
                x_feats_batch = X_train[start:stop]
                y_feats_batch = Y_train[start:stop]

                x_feats_batch_fft = fft_mod.predict(x_feats_batch)

                d = pls_model.train_on_batch([y_feats_batch],
                                             [x_feats_batch, x_feats_batch_fft])
                batch_losses = _wave_spec_losses(d)
                epoch_error += batch_losses

                batch_no = index + total_batches
                if batch_no % 500 == 0:

                    print("pre-training batch %d, wave loss: %f, spec loss %f" %
                          (batch_no, batch_losses[0], batch_losses[1]))

                    wave, spec = pls_model.predict([y_feats_batch])

                    # Plot reference vs. generated for the first batch item.
                    wav_gen = wave[0, :]
                    wav_ref = x_feats_batch[0, :]
                    wavs = np.array([wav_ref, wav_gen])
                    plot_feats(wavs, epoch, batch_no, fig_type='mse', ext='.wave-pls')

                    spec_gen = spec[0, :]
                    spec_ref = x_feats_batch_fft[0, :]
                    specs = np.array([spec_ref, spec_gen])
                    plot_feats(specs, epoch, batch_no, fig_type='mse', ext='.spec-pls')

            total_batches += no_batches

        if val_data is None:
            raise ValueError("nc_data_provider returned no data; nothing to train on")

        # Avoid a division by zero when only the validation chunk was available.
        if total_batches > 0:
            epoch_error /= total_batches

        val_spec = fft_mod.predict(val_data[0])
        val_loss = _wave_spec_losses(
            pls_model.evaluate([val_data[1]],
                               [val_data[0], val_spec],
                               batch_size=BATCH_SIZE))

        print("epoch %d validation wave loss: %f ,spec loss %f \n" %
              (epoch, val_loss[0], val_loss[1]))

        print("epoch %d training wave loss: %f, spec loss %f \n" %
              (epoch, epoch_error[0], epoch_error[1]))

        # only on wave loss
        if val_loss[0] < best_val_loss:
            best_val_loss = val_loss[0]
            patience = max_epochs_no_improvement
            pls_model.save_weights('./pls.model')
        else:
            patience -= 1

        if patience == 0:
            break

    print("Finished training")


def train_noise_model(BATCH_SIZE, data_dir, file_list, save_weights=False,
                      context_len=32, max_files=30, stats=None):
    """Train the noise generator adversarially on top of the pulse model.

    Loads the pre-trained pulse model weights from ``./pls.model``, then for
    15 epochs alternates between updating the generator (through the frozen
    discriminator in the joint model) and updating the discriminator on
    real and generated pulses.  Generator weights are written to
    ``./models/noise_gen_epoch<N>.model`` after every epoch.

    Args:
        BATCH_SIZE: number of examples per training batch.
        data_dir: directory passed on to ``nc_data_provider``.
        file_list: file names passed on to ``nc_data_provider``.
        save_weights: accepted for interface compatibility; not used.
        context_len: passed as ``timesteps`` to the pulse model and as
            ``context_len`` to the data provider.
        max_files: passed on to ``nc_data_provider``.
        stats: accepted for interface compatibility; not used.
    """
    no_epochs = 15

    timesteps = context_len

    optim_container = adam(lr=1e-4)
    optim_discriminator = SGD(lr=1e-5)

    fft_mod = fft_model()
    pls_model = time_glot_model(timesteps=timesteps)

    pls_model.compile(loss=['mse', 'mse'], loss_weights=[1.0, 1.0], optimizer='adam')
    pls_model.load_weights("./pls.model")

    disc_model = discriminator()
    gen_model = generator()
    disc_on_gen = gan_container(gen_model, disc_model)

    gen_model.compile(loss='mse', optimizer="adam")

    # use peek adversarial and peek mse loss for training generator
    disc_model.trainable = False
    disc_on_gen.compile(loss=['mse', 'mse'], loss_weights=[1.0, 1.0], optimizer=optim_container)

    # don't use peek loss for discriminator
    disc_model.trainable = True
    disc_model.compile(loss=['mse', 'mse'], loss_weights=[1.0, 0.0], optimizer=optim_discriminator)

    # summary() prints the table itself and returns None, so it must not be
    # wrapped in print().
    print("Discriminator model:")
    disc_model.summary()
    print("Generator model:")
    gen_model.summary()
    print("Joint model:")
    disc_on_gen.summary()

    # Make sure the checkpoint directory exists up front instead of failing
    # after the first full epoch of training.
    if not os.path.isdir('./models'):
        os.makedirs('./models')

    label_fake = np.zeros((BATCH_SIZE, 1), dtype=np.float32)
    label_real = np.ones((BATCH_SIZE, 1), dtype=np.float32)

    # train residual GAN with FFT
    for epoch in range(no_epochs):
        print("Epoch is", epoch)

        total_batches = 0

        for data in nc_data_provider(file_list, data_dir,
                                     max_files=max_files, context_len=timesteps):

            X_train = data[0]
            Y_train = data[1]

            pls_len = X_train.shape[1]

            # Only full batches are used, so the fixed-size label arrays
            # above always match the batch.
            no_batches = X_train.shape[0] // BATCH_SIZE

            # shuffle data
            ind = np.random.permutation(X_train.shape[0])
            X_train = X_train[ind]
            Y_train = Y_train[ind]
            for index in range(no_batches):
                start = index * BATCH_SIZE
                stop = start + BATCH_SIZE
                x_feats_batch = X_train[start:stop]
                y_feats_batch = Y_train[start:stop]

                x_pred_batch, x_pred_batch_fft = pls_model.predict([y_feats_batch])

                pls_pred = x_pred_batch
                pls_real = x_feats_batch

                # smoothing windows to prevent edge effects
                # (in place: pls_real is a view into the shuffled copy of
                # the chunk, and every batch is visited once per chunk)
                pls_pred *= smoothwin
                pls_real *= smoothwin

                # evaluate target fft
                fft_real = fft_mod.predict(pls_real)

                noise = np.random.randn(BATCH_SIZE, pls_len)

                # train generator through discriminator
                # NOTE(review): the trainable toggles below presumably have
                # no effect after compile(); kept to preserve behavior.
                _, peek_real = disc_model.predict([pls_real, fft_real])
                disc_model.trainable = False
                loss_g = disc_on_gen.train_on_batch([pls_pred, noise], [label_real, peek_real])

                noise = np.random.randn(BATCH_SIZE, pls_len)

                # train discriminator with real data
                disc_model.trainable = True
                loss_dr = disc_model.train_on_batch([pls_real, fft_real], [label_real, peek_real])

                # train discriminator with fake data
                pls_fake, fft_fake = gen_model.predict([pls_pred, noise])
                loss_df = disc_model.train_on_batch([pls_fake, fft_fake], [label_fake, peek_real])

                batch_no = index + total_batches
                if batch_no % 500 == 0:

                    print("training batch %d, G loss: %f, D loss (real): %f, D loss (fake): %f" %
                          (batch_no, loss_g[0], loss_dr[0], loss_df[0]))

                    # Plot the first item of the batch: reference, pulse
                    # model output and generator output.
                    wav_ref = pls_real[0, :]
                    wav_gen = pls_pred[0, :]
                    wav_noised = pls_fake[0, :]
                    wavs = np.array([wav_ref, wav_gen, wav_noised])
                    plot_feats(wavs, epoch, batch_no, fig_type='gan', ext='.wave')
                    spec_ref = fft_real[0, :]
                    spec_noised = fft_fake[0, :]
                    specs = np.array([spec_ref, spec_noised])
                    plot_feats(specs, epoch, batch_no, fig_type='gan', ext='.spec')

            total_batches += no_batches

        gen_model.save_weights('./models/noise_gen_epoch' + str(epoch) + '.model')

    print("Finished noise model training")

def generate(file_list, data_dir, output_dir, context_len=32, stats=None,
             base_model_path='./pls.model', gan_model_path='./noise_gen.model'):
    """Generate pulses for every input file and write them to ``output_dir``.

    For each item yielded by the data provider two raw float32 files are
    written: ``<fname>.pls`` (pulse model output passed through the noise
    generator) and ``<fname>.pls_nonoise`` (pulse model output only).

    Args:
        file_list: file names passed on to ``nc_data_provider``.
        data_dir: directory passed on to ``nc_data_provider``.
        output_dir: directory for the generated files; created if missing.
        context_len: passed as ``timesteps`` to the pulse model and as
            ``context_len`` to the data provider.
        stats: accepted for interface compatibility; not used.
        base_model_path: weights file for the pulse model.
        gan_model_path: weights file for the noise generator; ``None``
            selects the default ``./noise_gen.model``.
    """
    # The command line passes None when --gan_model is not given.
    if gan_model_path is None:
        gan_model_path = './noise_gen.model'

    pulse_model = time_glot_model(timesteps=context_len)
    gan_model = generator()

    pulse_model.compile(loss='mse', optimizer="adam")
    gan_model.compile(loss='mse', optimizer="adam")

    pulse_model.load_weights(base_model_path)
    gan_model.load_weights(gan_model_path)

    if not os.path.isdir(output_dir):
        os.makedirs(output_dir)

    for data in nc_data_provider(file_list, data_dir, input_only=True,
                                 context_len=context_len):
        # items() works on both Python 2 and 3 (iteritems() is Python 2 only).
        for fname, ac_data in data.items():
            print(fname)

            pls_pred, _ = pulse_model.predict([ac_data])
            noise = np.random.randn(pls_pred.shape[0], pls_pred.shape[1])
            pls_gan, _ = gan_model.predict([pls_pred, noise])

            # Use the output_dir argument, not the script-level `args`, so
            # the function also works when imported from another module.
            out_file = os.path.join(output_dir, fname + '.pls')
            pls_gan.astype(np.float32).tofile(out_file)

            out_file = os.path.join(output_dir, fname + '.pls_nonoise')
            pls_pred.astype(np.float32).tofile(out_file)
    
def get_args():
    """Parse the script's command-line options from ``sys.argv``.

    Returns:
        The ``argparse.Namespace`` with ``mode``, ``batch_size``,
        ``data_dir``, ``testdata_dir``, ``output_dir``, ``rnn_context_len``,
        ``max_files``, ``gan_model`` and the fixed attribute ``nice=False``.
    """
    # (flag, value type, default) for every supported option, in help order.
    option_table = (
        ("--mode", str, None),
        ("--batch_size", int, 128),
        ("--data_dir", str, "./traindata"),
        ("--testdata_dir", str, "./testdata"),
        ("--output_dir", str, "./output"),
        ("--rnn_context_len", int, 64),
        ("--max_files", int, 100),
        ("--gan_model", str, None),
    )

    cli = argparse.ArgumentParser()
    for flag, value_type, fallback in option_table:
        cli.add_argument(flag, type=value_type, default=fallback)
    # No command-line flag sets this; it is always present as False.
    cli.set_defaults(nice=False)
    return cli.parse_args()

if __name__ == "__main__":
    # Script entry point: dispatch on --mode.
    #   train             -> pulse model pre-training, then noise model
    #   train_pulse_model -> pulse model pre-training only
    #   train_noise_model -> noise model training only
    #   generate          -> run both models on the test data
    args = get_args()

    if args.mode in ("train", "train_pulse_model", "train_noise_model"):

        if args.mode == "train_pulse_model":
            print("MODE: Training time domain pulse model")
        elif args.mode == "train_noise_model":
            print("MODE: Training noise model")

        file_list = os.listdir(args.data_dir)
        if not file_list:
            sys.exit("No data files found in " + args.data_dir)

        if args.mode != "train_noise_model":
            train_pls_model(BATCH_SIZE=args.batch_size, data_dir=args.data_dir,
                            file_list=file_list, max_files=args.max_files,
                            context_len=args.rnn_context_len)

        if args.mode != "train_pulse_model":
            stats = norm_stats(file_list[0], args.data_dir)

            train_noise_model(BATCH_SIZE=args.batch_size, data_dir=args.data_dir,
                              file_list=file_list, max_files=args.max_files,
                              context_len=args.rnn_context_len,
                              stats=stats)

    elif args.mode == "generate":

        test_dir = args.testdata_dir
        file_list = os.listdir(test_dir)
        if not file_list:
            sys.exit("No data files found in " + test_dir)

        stats = norm_stats(file_list[0], test_dir)

        # Only override the generator weights path when --gan_model was
        # given; passing None would replace generate()'s default path.
        generate_kwargs = {}
        if args.gan_model is not None:
            generate_kwargs["gan_model_path"] = args.gan_model

        generate(data_dir=test_dir, file_list=file_list,
                 output_dir=args.output_dir,
                 context_len=args.rnn_context_len, stats=stats,
                 **generate_kwargs)

    else:
        # Previously an unknown or missing mode silently did nothing.
        sys.exit("Unknown --mode %r; expected one of: train, train_pulse_model, "
                 "train_noise_model, generate" % (args.mode,))