import h5py
import shutil

import torch
import torch.nn as nn
from torch.nn import init
from torch.autograd import Variable
from torch.optim.optimizer import Optimizer


class ReduceLROnPlateau(object):

    def __init__(self, optimizer, mode='min', factor=0.5, patience=5,
                 verbose=True, epsilon=1E-3, min_lr=0.):

        if not 0.0 < factor < 1.0:
            raise ValueError('ReduceLROnPlateau '
                             'requires a factor in the open interval (0, 1).')
        self.factor = factor
        self.min_lr = min_lr
        self.epsilon = epsilon
        self.patience = patience
        self.verbose = verbose
        self.mode = mode
        assert isinstance(optimizer, Optimizer)
        self.optimizer = optimizer
        self.reset()

    def reset(self):
        """Resets wait counter and cooldown counter.
        """
        if self.mode not in ['min', 'max']:
            raise RuntimeError(
                'Learning Rate Plateau Reducing mode %s is unknown!' %
                self.mode)
        if self.mode == 'min':
            self.monitor_op = lambda a, b: a < (b - self.epsilon)
            self.best = 1E12
        else:
            self.monitor_op = lambda a, b: a > (b + self.epsilon)
            self.best = -1E12
        self.wait = 0
        self.lr_epsilon = self.min_lr * 1E-4

    def step(self, metric, epoch):
        # `epoch` is accepted for interface compatibility but is not used.
        if self.monitor_op(metric, self.best):
            self.best = metric
            self.wait = 0

        elif self.wait >= self.patience:
            for param_group in self.optimizer.param_groups:
                old_lr = float(param_group['lr'])
                if old_lr > (self.min_lr + self.lr_epsilon):
                    new_lr = max(old_lr * self.factor, self.min_lr)
                    param_group['lr'] = new_lr
                    if self.verbose:
                        print('Reducing learning rate to %s.' % new_lr)
                    self.wait = 0
        else:
            self.wait += 1
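
# Illustrative usage of ReduceLROnPlateau (the names `model`, `encoder`,
# `decoder`, `train_loader`, `val_loader`, `dtype` and `num_epochs` below are
# placeholders, not defined in this module):
#
#     optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
#     scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5,
#                                   patience=3)
#     for epoch in range(num_epochs):
#         train_model(train_loader, encoder, decoder, optimizer, dtype)
#         val_loss = validate_model(val_loader, encoder, decoder, dtype)
#         scheduler.step(val_loss, epoch)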


class Flatten(nn.Module):

    def forward(self, x):
        size = x.size()  # read in N, C, H, W
        return x.view(size[0], -1)
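
# Flatten keeps the batch dimension and collapses everything else, e.g. an
# (N, C, H, W) activation becomes (N, C * H * W); a minimal sketch:
#
#     x = torch.randn(8, 3, 4, 4)
#     Flatten()(x).size()  # torch.Size([8, 48])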


class Repeat(nn.Module):

    def __init__(self, rep):
        super(Repeat, self).__init__()

        self.rep = rep

    def forward(self, x):
        size = tuple(x.size())
        size = (size[0], 1) + size[1:]
        x_expanded = x.view(*size)
        n = [1 for _ in size]
        n[1] = self.rep
        return x_expanded.repeat(*n)
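
# Repeat inserts a new axis after the batch dimension and tiles the input
# along it, e.g. (N, D) -> (N, rep, D); a minimal sketch:
#
#     x = torch.randn(8, 16)
#     Repeat(5)(x).size()  # torch.Size([8, 5, 16])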


class TimeDistributed(nn.Module):

    def __init__(self, module, batch_first=True):
        super(TimeDistributed, self).__init__()
        self.module = module
        self.batch_first = batch_first

    def forward(self, x):

        if len(x.size()) <= 2:
            return self.module(x)

        # Squash samples and timesteps into a single axis
        # (samples * timesteps, input_size)
        x_reshape = x.contiguous().view(-1, x.size(-1))

        y = self.module(x_reshape)

        # We have to reshape Y
        if self.batch_first:
            # (samples, timesteps, output_size)
            y = y.contiguous().view(x.size(0), -1, y.size(-1))
        else:
            # (timesteps, samples, output_size)
            y = y.view(-1, x.size(1), y.size(-1))

        return y
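
# TimeDistributed applies the wrapped module independently at every timestep,
# e.g. a Linear(16, 4) over a (N, T, 16) input gives (N, T, 4); a minimal
# sketch:
#
#     td = TimeDistributed(nn.Linear(16, 4), batch_first=True)
#     x = torch.randn(8, 10, 16)
#     td(x).size()  # torch.Size([8, 10, 4])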


def reset(m):
    if hasattr(m, 'reset_parameters'):
        m.reset_parameters()


def train_model(train_loader, encoder, decoder, optimizer, dtype,
                print_every=100):
    encoder.train()
    decoder.train()

    for t, (x, y) in enumerate(train_loader):
        x_var = Variable(x.type(dtype))

        y_var = encoder(x_var)
        z_var = decoder(y_var)

        loss = encoder.vae_loss(z_var, x_var)
        if (t + 1) % print_every == 0:
            print('t = %d, loss = %.4f' % (t + 1, loss.item()))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
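
# train_model and validate_model assume a VAE-style encoder/decoder pair in
# which the encoder exposes a `vae_loss(reconstruction, input)` method; an
# illustrative call (loaders, modules and optimizer are assumed to exist):
#
#     dtype = (torch.cuda.FloatTensor if torch.cuda.is_available()
#              else torch.FloatTensor)
#     train_model(train_loader, encoder, decoder, optimizer, dtype)
#     val_loss = validate_model(val_loader, encoder, decoder, dtype)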


def validate_model(val_loader, encoder, decoder, dtype):
    encoder.eval()
    decoder.eval()

    avg_val_loss = 0.
    for t, (x, y) in enumerate(val_loader):
        x_var = Variable(x.type(dtype))

        y_var = encoder(x_var)
        z_var = decoder(y_var)

        avg_val_loss += encoder.vae_loss(z_var, x_var).item()
    avg_val_loss /= (t + 1)
    print('average validation loss: %.4f' % avg_val_loss)
    return avg_val_loss


def load_dataset(filename, split=True):
    h5f = h5py.File(filename, 'r')
    if split:
        data_train = h5f['data_train'][:]
    else:
        data_train = None
    data_test = h5f['data_test'][:]
    charset = h5f['charset'][:]
    h5f.close()
    if split:
        return (data_train, data_test, charset)
    else:
        return (data_test, charset)
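
# load_dataset expects an HDF5 file containing 'data_test' and 'charset'
# datasets, plus 'data_train' when split=True; an illustrative call with a
# hypothetical file name:
#
#     data_train, data_test, charset = load_dataset('processed.h5')
#     data_test, charset = load_dataset('processed.h5', split=False)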


def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, 'model_best.pth.tar')
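
# save_checkpoint writes `state` to disk and, when `is_best` is set, copies it
# to 'model_best.pth.tar'; an illustrative call (the dictionary keys are just
# a convention, not enforced by this helper):
#
#     save_checkpoint({'epoch': epoch,
#                      'encoder': encoder.state_dict(),
#                      'decoder': decoder.state_dict(),
#                      'optimizer': optimizer.state_dict()},
#                     is_best=val_loss < best_val_loss)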


def initialize_weights(m):
    if isinstance(m, (nn.Linear, nn.Conv1d)):
        init.xavier_uniform(m.weight.data)
    elif isinstance(m, nn.GRU):
        # GRU parameters are grouped per layer; only weight matrices
        # (ndim > 1) get Xavier init, biases keep their default values.
        for weights in m.all_weights:
            for weight in weights:
                if len(weight.size()) > 1:
                    init.xavier_uniform(weight.data)
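
# Both reset and initialize_weights are meant to be used with
# nn.Module.apply, which visits every submodule, e.g. (with `model` standing
# in for any nn.Module):
#
#     model.apply(initialize_weights)  # Xavier init for Linear/Conv1d/GRU
#     model.apply(reset)               # re-randomize all parameters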