import torch
import torch.nn as nn
import torch.nn.functional as F
from a2c_ppo_acktr.utils import init

import os
import numpy as np
from torch.utils.data import RandomSampler, BatchSampler
from .trainer import Trainer
from .utils import EarlyStopping


class Unflatten(nn.Module):
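    """Reshapes a flat feature tensor into `new_shape` (the inverse of a Flatten layer)."""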
    def __init__(self, new_shape):
        super().__init__()
        self.new_shape = new_shape

    def forward(self, x):
        x_uf = x.view(-1, *self.new_shape)
        return x_uf


class Decoder(nn.Module):
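    """Transposed-convolution decoder that mirrors the Nature CNN encoder.

    Maps a feature vector back to an image; the final Sigmoid keeps pixel
    values in [0, 1], matching inputs that are scaled by 1/255.
    """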
    def __init__(self, feature_size, final_conv_size, final_conv_shape, num_input_channels, encoder_type="Nature"):
        super().__init__()
        self.feature_size = feature_size
        self.final_conv_size = final_conv_size
        self.final_conv_shape = final_conv_shape
        self.num_input_channels = num_input_channels
        init_ = lambda m: init(m,
                               nn.init.orthogonal_,
                               lambda x: nn.init.constant_(x, 0),
                               nn.init.calculate_gain('relu'))
        if encoder_type == "Nature":
            self.main = nn.Sequential(
                nn.Linear(in_features=self.feature_size,
                          out_features=self.final_conv_size),
                nn.ReLU(),
                Unflatten(self.final_conv_shape),

                init_(nn.ConvTranspose2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=0)),
                nn.ReLU(),
                init_(nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=4, stride=2, padding=0)),
                nn.ReLU(),
                init_(nn.ConvTranspose2d(in_channels=64, out_channels=32, kernel_size=4, stride=2, padding=0,
                                         output_padding=1)),
                nn.ReLU(),
                init_(nn.ConvTranspose2d(in_channels=32, out_channels=num_input_channels,
                                         kernel_size=8, stride=4, output_padding=(2, 0))),
                nn.Sigmoid()
            )

    def forward(self, f):
        im = self.main(f)
        return im
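
# Shape sanity check for the decoder (hypothetical sizes, consistent with a
# 210x160 Atari frame passed through the Nature CNN encoder this mirrors):
#
#   dec = Decoder(feature_size=256, final_conv_size=64 * 9 * 6,
#                 final_conv_shape=(64, 9, 6), num_input_channels=1)
#   x_hat = dec(torch.randn(8, 256))  # -> torch.Size([8, 1, 210, 160])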


class VAE(nn.Module):
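    """VAE whose mean path is the wrapped encoder; a separate linear head
    predicts the log-variance from the encoder's pre-head conv features."""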
    def __init__(self, encoder):
        super().__init__()
        self.encoder = encoder
        self.feature_size = self.encoder.feature_size
        self.final_conv_size = self.encoder.final_conv_size
        self.final_conv_shape = self.encoder.final_conv_shape
        self.input_channels = self.encoder.input_channels

        self.logvar_fc = nn.Linear(in_features=self.final_conv_size,
                                   out_features=self.feature_size)

        self.decoder = Decoder(feature_size=self.feature_size,
                               final_conv_size=self.final_conv_size,
                               final_conv_shape=self.final_conv_shape,
                               num_input_channels=self.input_channels)

    def reparametrize(self, mu, logvar):
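        """Reparameterization trick: sample z = mu + sigma * eps with
        eps ~ N(0, I), so gradients flow through mu and logvar. At eval
        time the deterministic mean is returned."""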
        if self.training:
            eps = torch.randn_like(logvar)
            std = torch.exp(0.5 * logvar)
            z = mu + eps * std
        else:
            z = mu
        return z

    def forward(self, x):
        mu = self.encoder(x)
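        # encoder.main[:-1] drops the encoder's final linear head, exposing the
        # flattened conv features (size final_conv_size) that feed logvar_fc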
        logvar = self.logvar_fc(self.encoder.main[:-1](x))
        z = self.reparametrize(mu, logvar)
        x_hat = self.decoder(z)
        return x_hat, mu, logvar


class VAELoss(object):
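    """beta-VAE objective: sum-reduced reconstruction error plus a
    beta-weighted KL divergence to the standard-normal prior."""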
    def __init__(self, beta=1.0):
        self.beta = beta

    def __call__(self, x, x_hat, mu, logvar):
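        # Closed-form KL( N(mu, sigma^2) || N(0, I) )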
        kldiv = -0.5 * torch.sum(1 + logvar - mu ** 2 - torch.exp(logvar))
        rec = F.mse_loss(x_hat, x, reduction='sum')
        loss = rec + self.beta * kldiv
        return loss


class VAETrainer(Trainer):
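    """Trains a VAE on single frames sampled uniformly from episodes,
    early-stopping on the validation loss."""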
    # TODO: Make this work for all modes; right now it defaults to pcl.
    def __init__(self, encoder, config, device=torch.device('cpu'), wandb=None):
        super().__init__(encoder, wandb, device)
        self.config = config
        self.patience = self.config["patience"]
        self.VAE = VAE(encoder).to(device)
        self.epochs = config['epochs']
        self.batch_size = config['batch_size']
        self.device = device
        self.optimizer = torch.optim.Adam(self.VAE.parameters(),
                                          lr=config['lr'], eps=1e-5)
        self.loss_fn = VAELoss(beta=self.config["beta"])
        self.early_stopper = EarlyStopping(patience=self.patience, verbose=False, wandb=self.wandb, name="encoder")

    def generate_batch(self, episodes):
        total_steps = sum([len(e) for e in episodes])
        print('Total Steps: {}'.format(total_steps))
        # Episode sampler
        # Draw `total_steps` episode indices with replacement, then group them into batches of `self.batch_size`
        sampler = BatchSampler(RandomSampler(range(len(episodes)),
                                             replacement=True, num_samples=total_steps),
                               self.batch_size, drop_last=True)
        for indices in sampler:
            episodes_batch = [episodes[x] for x in indices]
            x_t = []
            for episode in episodes_batch:
                # Pick one random timestep from this episode
                t = np.random.randint(0, len(episode))
                x_t.append(episode[t])
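            # Scale pixels to [0, 1] so targets match the decoder's Sigmoid output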
            yield torch.stack(x_t).float().to(self.device) / 255.

    def do_one_epoch(self, epoch, episodes):
        mode = "train" if self.VAE.training else "val"
        epoch_loss, steps = 0., 0
        data_generator = self.generate_batch(episodes)
        for x_t in data_generator:
            with torch.set_grad_enabled(mode == 'train'):
                x_hat, mu, logvar = self.VAE(x_t)
                loss = self.loss_fn(x_t, x_hat, mu, logvar)

            if mode == "train":
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

            epoch_loss += loss.detach().item()
            steps += 1
        self.log_results(epoch, epoch_loss / steps, prefix=mode)
        if mode == "val":
            self.early_stopper(-epoch_loss / steps, self.encoder)

    def train(self, tr_eps, val_eps):
        for e in range(self.epochs):
            self.VAE.train()
            self.do_one_epoch(e, tr_eps)

            self.VAE.eval()
            self.do_one_epoch(e, val_eps)

            if self.early_stopper.early_stop:
                break
        torch.save(self.encoder.state_dict(), os.path.join(self.wandb.run.dir, self.config['env_name'] + '.pt'))

    def log_results(self, epoch_idx, epoch_loss, prefix=""):
        print("{} Epoch: {}, Epoch Loss: {}".format(prefix.capitalize(), epoch_idx, epoch_loss))
        self.wandb.log({prefix + '_loss': epoch_loss})
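

# Example usage (a sketch; the encoder object, episode tensors, and config
# values below are assumptions about the surrounding pipeline):
#
#   config = {"patience": 15, "epochs": 100, "batch_size": 64,
#             "lr": 3e-4, "beta": 1.0, "env_name": "pong"}
#   trainer = VAETrainer(encoder, config, device=torch.device("cuda"), wandb=wandb)
#   trainer.train(tr_episodes, val_episodes)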