#!/usr/bin/env python3
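"""Training and evaluation for the RMSE, weighted-RMSE, and task-loss
("task net") models; scheduling decisions come from
model_classes.SolveScheduling."""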

import copy
import os
import numpy as np

import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.optim as optim

import model_classes


def task_loss(Y_sched, Y_actual, params):
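    """Scheduling task loss, averaged over the batch (dim 0).

    Under-scheduling is penalized at gamma_under per unit, over-scheduling
    at gamma_over, plus a symmetric quadratic term. For example, with
    gamma_under=2 and gamma_over=1, scheduling 1.0 against an actual 2.0
    costs 2*1 + 0 + 0.5*1**2 = 2.5.
    """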
    return (params["gamma_under"] * torch.clamp(Y_actual - Y_sched, min=0) + 
            params["gamma_over"] * torch.clamp(Y_sched - Y_actual, min=0) + 
            0.5 * (Y_sched - Y_actual)**2).mean(0)

def task_loss_no_mean(Y_sched, Y_actual, params):
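    """Element-wise task loss (no batch mean); used to derive per-sample
    weights in run_weighted_rmse_net_helper."""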
    return (params["gamma_under"] * torch.clamp(Y_actual - Y_sched, min=0) + 
        params["gamma_over"] * torch.clamp(Y_sched - Y_actual, min=0) + 
        0.5 * (Y_sched - Y_actual)**2)

def rmse_loss(mu_pred, Y_actual):
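    """Per-output RMSE over the batch, returned as a numpy array."""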
    return ((mu_pred - Y_actual)**2).mean(dim=0).sqrt().data.cpu().numpy()

def rmse_loss_weighted(mu_pred, Y_actual, weights):
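    """Weighted RMSE: weighted mean squared error per output, square-rooted,
    then summed over outputs. This is the minibatch objective minimized in
    batch_train_weightrmse."""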
    return ((weights * (mu_pred - Y_actual)**2).mean(dim=0).sqrt()).sum()


# TODO: minibatching
def run_rmse_net(model, variables, X_train, Y_train):
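    """Train `model` with full-batch MSE for 1000 Adam steps, then set its
    sigma estimate on the training data via model.set_sig. (X_train and
    Y_train are unused here; the tensors in `variables` are used directly.)"""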
    opt = optim.Adam(model.parameters(), lr=1e-3)

    for i in range(1000):
        opt.zero_grad()
        model.train()
        train_loss = nn.MSELoss()(
            model(variables['X_train_'])[0], variables['Y_train_'])
        train_loss.backward()
        opt.step()

        model.eval()
        test_loss = nn.MSELoss()(
            model(variables['X_test_'])[0], variables['Y_test_'])

        print(i, train_loss.item(), test_loss.item())

    model.eval()
    model.set_sig(variables['X_train_'], variables['Y_train_'])

    return model


def run_weighted_rmse_net(X_train, Y_train, X_test, Y_test, params):
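    """Alternate (10 rounds) between training a weighted-RMSE model and
    re-deriving per-sample weights from the task loss of the schedule the
    model induces."""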
    weights = Variable(torch.ones(Y_train.shape), requires_grad=False).cuda()
    for i in range(10):
        model, weights2 = run_weighted_rmse_net_helper(X_train, Y_train, X_test, Y_test, params, weights, i)
        weights = Variable(weights2.data, requires_grad=False)
    return model

def run_weighted_rmse_net_helper(X_train, Y_train, X_test, Y_test, params, weights, i):
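    """One round of weighted-RMSE training (100 minibatch epochs with Adam),
    followed by weight rebalancing: the new weights are the element-wise task
    loss of the schedule computed from the trained model's predictions."""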
    X_train_ = Variable(torch.Tensor(X_train[:,:-1])).cuda()
    Y_train_ = Variable(torch.Tensor(Y_train)).cuda()
    X_test_ = Variable(torch.Tensor(X_test[:,:-1])).cuda()
    Y_test_ = Variable(torch.Tensor(Y_test)).cuda()

    model = model_classes.Net(X_train[:,:-1], Y_train, [200, 200])
    model.cuda()
    opt = optim.Adam(model.parameters(), lr=1e-3)
    solver = model_classes.SolveScheduling(params)
    for j in range(100):

        model.train()
        batch_train_weightrmse(100, i*100 + j, X_train_.data, Y_train_.data, model, opt, weights.data)

    # Rebalance weights
    model.eval()
    mu_pred_train, sig_pred_train = model(X_train_)
    Y_sched_train = solver(mu_pred_train.double(), sig_pred_train.double())
    weights2 = task_loss_no_mean(
        Y_sched_train.float(), Y_train_, params).cuda()
    model.set_sig(X_train_, Y_train_)

    return model, weights2

def batch_train_weightrmse(batch_sz, epoch, X_train_t, Y_train_t, model, opt, weights_t):
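    """One epoch of minibatch training on the weighted RMSE objective."""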

    batch_data_ = Variable(torch.Tensor(batch_sz, X_train_t.size(1)), 
        requires_grad=False).cuda()
    batch_targets_ = Variable(torch.Tensor(batch_sz, Y_train_t.size(1)), 
        requires_grad=False).cuda()
    batch_weights_ = Variable(torch.Tensor(batch_sz, weights_t.size(1)),
        requires_grad=False).cuda()

    size = batch_sz

    for i in range(0, X_train_t.size(0), batch_sz):

        # Deal with potentially incomplete (last) batch
        if i + batch_sz > X_train_t.size(0):
            size = X_train_t.size(0) - i
            batch_data_ = Variable(torch.Tensor(size, X_train_t.size(1)), requires_grad=False).cuda()
            batch_targets_ = Variable(torch.Tensor(size, Y_train_t.size(1)), requires_grad=False).cuda()
            batch_weights_ = Variable(torch.Tensor(size, weights_t.size(1)), requires_grad=False).cuda()

        batch_data_.data[:] = X_train_t[i:i+size]
        batch_targets_.data[:] = Y_train_t[i:i+size]
        batch_weights_.data[:] = weights_t[i:i+size]

        opt.zero_grad()
        preds = model(batch_data_)[0]

        # Weighted RMSE objective; identical to rmse_loss_weighted above
        rmse_loss_weighted(preds, batch_targets_, batch_weights_).backward()

        opt.step()

        print('Epoch: {}, {}/{}'.format(epoch, i + size, X_train_t.size(0)))
       

# TODO: minibatching
def run_task_net(model, variables, params, X_train, Y_train, args):
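    """Train the model end-to-end through the scheduling solver on the task
    loss, with early stopping on hold-out cost: every num_stop_rounds
    iterations, return the best saved model if the minimum hold-out cost has
    not improved since the last check."""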
    opt = optim.Adam(model.parameters(), lr=1e-4)
    solver = model_classes.SolveScheduling(params)

    # For early stopping
    prev_min = 0
    hold_costs = []
    model_states = []
    num_stop_rounds = 20

    for i in range(1000):
        opt.zero_grad()
        model.train()
        mu_pred_train, sig_pred_train = model(variables['X_train_'])
        Y_sched_train = solver(mu_pred_train.double(), sig_pred_train.double())
        train_loss = task_loss(
            Y_sched_train.float(), variables['Y_train_'], params)
        train_loss.sum().backward()

        model.eval()
        mu_pred_test, sig_pred_test = model(variables['X_test_'])
        Y_sched_test = solver(mu_pred_test.double(), sig_pred_test.double())
        test_loss = task_loss(
            Y_sched_test.float(), variables['Y_test_'], params)

        mu_pred_hold, sig_pred_hold = model(variables['X_hold_'])
        Y_sched_hold = solver(mu_pred_hold.double(), sig_pred_hold.double())
        hold_loss = task_loss(
            Y_sched_hold.float(), variables['Y_hold_'], params)

        opt.step()

        print(i, train_loss.sum().item(), test_loss.sum().item(),
            hold_loss.sum().item())


        # Early stopping
        hold_costs.append(hold_loss.sum().item())
        # state_dict() returns references to live tensors; deep-copy so saved
        # states aren't mutated by subsequent training steps
        model_states.append(copy.deepcopy(model.state_dict()))
        if i > 0 and i % num_stop_rounds == 0:
            idx = hold_costs.index(min(hold_costs))
            if prev_min == hold_costs[idx]:
                model.eval()
                best_model = model_classes.Net(
                    X_train[:,:-1], Y_train, [200, 200])
                best_model.load_state_dict(model_states[idx])
                best_model.cuda()
                return best_model
            else:
                prev_min = hold_costs[idx]
                hold_costs = [prev_min]
                model_states = [model_states[idx]]

    return model


# TODO: minibatching
def eval_net(which, model, variables, params, save_folder):
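    """Evaluate the model's RMSE and solver-induced task loss on the train
    and test sets (plus the hold-out set when which == "task_net"), saving
    each result under save_folder."""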
    solver = model_classes.SolveScheduling(params)

    model.eval()
    mu_pred_train, sig_pred_train = model(variables['X_train_'])
    mu_pred_test, sig_pred_test = model(variables['X_test_'])

    if which == "task_net":
        mu_pred_hold, sig_pred_hold = model(variables['X_hold_'])

    # Eval model on rmse
    train_rmse = rmse_loss(mu_pred_train, variables['Y_train_'])
    test_rmse = rmse_loss(mu_pred_test, variables['Y_test_'])

    if which == "task_net":
        hold_rmse = rmse_loss(mu_pred_hold, variables['Y_hold_'])

    with open(
        os.path.join(save_folder, '{}_train_rmse'.format(which)), 'wb') as f:
        np.save(f, train_rmse)

    with open(
        os.path.join(save_folder, '{}_test_rmse'.format(which)), 'wb') as f:
        np.save(f, test_rmse)

    if which == "task_net":
        with open(
            os.path.join(save_folder, '{}_hold_rmse'.format(which)), 'wb') as f:
            np.save(f, hold_rmse)

    # Eval model on task loss
    Y_sched_train = solver(mu_pred_train.double(), sig_pred_train.double())
    train_loss_task = task_loss(
        Y_sched_train.float(), variables['Y_train_'], params)

    Y_sched_test = solver(mu_pred_test.double(), sig_pred_test.double())
    test_loss_task = task_loss(
        Y_sched_test.float(), variables['Y_test_'], params)

    if which == "task_net":
        Y_sched_hold = solver(mu_pred_hold.double(), sig_pred_hold.double())
        hold_loss_task = task_loss(
            Y_sched_hold.float(), variables['Y_hold_'], params)

    torch.save(train_loss_task.data, 
        os.path.join(save_folder, '{}_train_task'.format(which)))
    torch.save(test_loss_task.data, 
        os.path.join(save_folder, '{}_test_task'.format(which)))

    if which == "task_net":
        torch.save(hold_loss_task.data, 
            os.path.join(save_folder, '{}_hold_task'.format(which)))
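

# A minimal usage sketch (load_data, the "rmse_net" label, the gamma values,
# and the output folder are all hypothetical; `params` must additionally
# carry whatever model_classes.SolveScheduling expects):
#
#   X_train, Y_train, X_test, Y_test = load_data()
#   variables = {
#       'X_train_': Variable(torch.Tensor(X_train[:, :-1])).cuda(),
#       'Y_train_': Variable(torch.Tensor(Y_train)).cuda(),
#       'X_test_': Variable(torch.Tensor(X_test[:, :-1])).cuda(),
#       'Y_test_': Variable(torch.Tensor(Y_test)).cuda(),
#   }
#   params = {'gamma_under': 50.0, 'gamma_over': 0.5}
#   model = model_classes.Net(X_train[:, :-1], Y_train, [200, 200]).cuda()
#   model = run_rmse_net(model, variables, X_train, Y_train)
#   eval_net('rmse_net', model, variables, params, 'results')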