import torch
from torch import nn
from torch import optim

import numpy as np

import time
import optparse


import utils
import config
from data import Dataset
from model import PerformanceRNN
from sequence import NoteSeq, EventSeq, ControlSeq
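
# Trains PerformanceRNN on a preprocessed event dataset, resuming from and
# periodically checkpointing a session file (see load_session / save_model).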

# pylint: disable=E1102
# pylint: disable=E1101

#========================================================================
# Settings
#========================================================================

def get_options():
    parser = optparse.OptionParser()

    parser.add_option('-s', '--session',
                      dest='sess_path',
                      type='string',
                      default='save/train.sess')

    parser.add_option('-d', '--dataset',
                      dest='data_path',
                      type='string',
                      default='dataset/processed/')

    parser.add_option('-i', '--saving-interval',
                      dest='saving_interval',
                      type='float',
                      default=60.)

    parser.add_option('-b', '--batch-size',
                      dest='batch_size',
                      type='int',
                      default=config.train['batch_size'])

    parser.add_option('-l', '--learning-rate',
                      dest='learning_rate',
                      type='float',
                      default=config.train['learning_rate'])

    parser.add_option('-w', '--window-size',
                      dest='window_size',
                      type='int',
                      default=config.train['window_size'])

    parser.add_option('-S', '--stride-size',
                      dest='stride_size',
                      type='int',
                      default=config.train['stride_size'])

    parser.add_option('-c', '--control-ratio',
                      dest='control_ratio',
                      type='float',
                      default=config.train['control_ratio'])

    parser.add_option('-T', '--teacher-forcing-ratio',
                      dest='teacher_forcing_ratio',
                      type='float',
                      default=config.train['teacher_forcing_ratio'])

    parser.add_option('-t', '--use-transposition',
                      dest='use_transposition',
                      action='store_true',
                      default=config.train['use_transposition'])

    parser.add_option('-p', '--model-params',
                      dest='model_params',
                      type='string',
                      default='')
                      
    parser.add_option('-r', '--reset-optimizer',
                      dest='reset_optimizer',
                      action='store_true',
                      default=False)
                      
    parser.add_option('-L', '--enable-logging',
                      dest='enable_logging',
                      action='store_true',
                      default=False)

    return parser.parse_args()[0]
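
# Example invocation (the script filename here is assumed; unspecified flags
# fall back to the defaults above, most of which come from config.train):
#   python3 train.py -s save/train.sess -d dataset/processed/ -b 64 -L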

options = get_options()

#------------------------------------------------------------------------

sess_path = options.sess_path
data_path = options.data_path
saving_interval = options.saving_interval

learning_rate = options.learning_rate
batch_size = options.batch_size
window_size = options.window_size
stride_size = options.stride_size
use_transposition = options.use_transposition
control_ratio = options.control_ratio
teacher_forcing_ratio = options.teacher_forcing_ratio
reset_optimizer = options.reset_optimizer
enable_logging = options.enable_logging

event_dim = EventSeq.dim()
control_dim = ControlSeq.dim()
model_config = config.model
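# -p/--model-params overrides entries of config.model; its keys must match the
# keyword arguments accepted by PerformanceRNN (parsed by utils.params2dict).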
model_params = utils.params2dict(options.model_params)
model_config.update(model_params)
device = config.device

print('-' * 70)

print('Session path:', sess_path)
print('Dataset path:', data_path)
print('Saving interval:', saving_interval, 'sec')
print('-' * 70)

print('Hyperparameters:', utils.dict2params(model_config))
print('Learning rate:', learning_rate)
print('Batch size:', batch_size)
print('Window size:', window_size)
print('Stride size:', stride_size)
print('Control ratio:', control_ratio)
print('Teacher forcing ratio:', teacher_forcing_ratio)
print('Random transposition:', use_transposition)
print('Reset optimizer:', reset_optimizer)
print('Enable logging:', enable_logging)
print('Device:', device)
print('-' * 70)


#========================================================================
# Load session and dataset
#========================================================================

def load_session():
    global sess_path, model_config, device, learning_rate, reset_optimizer
    try:
        # map_location lets a checkpoint saved on another device load here
        sess = torch.load(sess_path, map_location=device)
        if 'model_config' in sess and sess['model_config'] != model_config:
            model_config = sess['model_config']
            print('Using the session model config instead:')
            print(utils.dict2params(model_config))
        model_state = sess['model_state']
        optimizer_state = sess['model_optimizer_state']
        print('Session is loaded from', sess_path)
        sess_loaded = True
    except Exception:  # missing or unreadable session file: start fresh
        print('New session')
        sess_loaded = False
    model = PerformanceRNN(**model_config).to(device)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    if sess_loaded:
        model.load_state_dict(model_state)
        if not reset_optimizer:
            optimizer.load_state_dict(optimizer_state)
    return model, optimizer

def load_dataset():
    global data_path
    dataset = Dataset(data_path, verbose=True)
    dataset_size = len(dataset.samples)
    assert dataset_size > 0
    return dataset


print('Loading session')
model, optimizer = load_session()
print(model)

print('-' * 70)

print('Loading dataset')
dataset = load_dataset()
print(dataset)

print('-' * 70)

#------------------------------------------------------------------------

def save_model():
    global model, optimizer, model_config, sess_path
    print('Saving to', sess_path)
    torch.save({'model_config': model_config,
                'model_state': model.state_dict(),
                'model_optimizer_state': optimizer.state_dict()}, sess_path)
    print('Done saving')


#========================================================================
# Training
#========================================================================

if enable_logging:
    from torch.utils.tensorboard import SummaryWriter
    writer = SummaryWriter()

last_saving_time = time.time()
loss_function = nn.CrossEntropyLoss()

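# Main training loop: each batch yields [window_size, batch_size] windows of
# events and controls; KeyboardInterrupt (Ctrl-C) checkpoints before exiting.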
try:
    batch_gen = dataset.batches(batch_size, window_size, stride_size)

    for iteration, (events, controls) in enumerate(batch_gen):
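        # Optional data augmentation: transpose the whole batch by a random
        # offset in [-6, 5] semitones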
        if use_transposition:
            offset = np.random.randint(-6, 6)
            events, controls = utils.transposition(events, controls, offset)

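        # Event indices as a LongTensor of shape [window_size, batch_size]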
        events = torch.LongTensor(events).to(device)
        assert events.shape[0] == window_size

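        # Condition on control signals only control_ratio of the time, so the
        # model also learns to generate unconditioned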
        if np.random.random() < control_ratio:
            controls = torch.FloatTensor(controls).to(device)
            assert controls.shape[0] == window_size
        else:
            controls = None

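        # Random init vector for generation; teacher forcing feeds the
        # ground-truth event at each step with probability teacher_forcing_ratio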
        init = torch.randn(batch_size, model.init_dim).to(device)
        outputs = model.generate(init, window_size, events=events[:-1], controls=controls,
                                 teacher_forcing_ratio=teacher_forcing_ratio, output_type='logit')
        assert outputs.shape[:2] == events.shape[:2]

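        # Cross-entropy over every timestep: logits flattened to
        # [window_size * batch_size, event_dim] against flat target indices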
        loss = loss_function(outputs.view(-1, event_dim), events.view(-1))
        model.zero_grad()
        loss.backward()

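        # Record the pre-clip gradient norm for logging, then clip to 1.0 to
        # keep RNN updates stable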
        norm = utils.compute_gradient_norm(model.parameters())
        nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        
        optimizer.step()

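        # Per-iteration TensorBoard scalars (only with -L/--enable-logging)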
        if enable_logging:
            writer.add_scalar('model/loss', loss.item(), iteration)
            writer.add_scalar('model/norm', norm.item(), iteration)

        print(f'iter {iteration}, loss: {loss.item():.5f}')

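        # Periodic checkpoint once saving_interval seconds have elapsed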
        if time.time() - last_saving_time > saving_interval:
            save_model()
            last_saving_time = time.time()

except KeyboardInterrupt:
    save_model()
finally:
    if enable_logging:
        writer.close()  # flush buffered TensorBoard events before exit