import torch
import torch.nn as nn
from collections import OrderedDict
from models.resnet import _weights_init
from utils.kfac_utils import (ComputeCovA,
                              ComputeCovAPatch,
                              ComputeCovG,
                              fetch_mat_weights,
                              mat_to_weight_and_bias)
from utils.common_utils import (tensor_to_list, PresetLRScheduler)
from utils.prune_utils import (filter_indices,
                               get_threshold,
                               update_indices,
                               normalize_factors)
from utils.network_utils import stablize_bn
from tqdm import tqdm


class KFACFullPruner:

        def __init__(self,
                     model,
                     builder,
                     config,
                     writer,
                     logger,
                     prune_ratio_limit,
                     network,
                     batch_averaged=True,
                     use_patch=False,
                     fix_layers=0):
            print('Using patch: %s' % use_patch)
            self.iter = 0
            self.logger = logger
            self.writer = writer
            self.config = config
            self.prune_ratio_limit = prune_ratio_limit
            self.network = network
            self.CovAHandler = ComputeCovA() if not use_patch else ComputeCovAPatch()
            self.CovGHandler = ComputeCovG()
            self.batch_averaged = batch_averaged
            self.known_modules = {'Linear', 'Conv2d'}
            self.modules = []
            self.model = model
            self.builder = builder
            self.fix_layers = fix_layers
            # self._prepare_model()
            self.steps = 0
            self.use_patch = False  # patch mode is intentionally disabled (see the assertion in _do_surgery)
            self.m_aa, self.m_gg = {}, {}
            self.Q_a, self.Q_g = {}, {}
            self.d_a, self.d_g = {}, {}
            self.W_pruned = {}
            self.S_l = None

            self.importances = {}
            self._inversed = False
            self._cfgs = {}
            self._indices = {}

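        # The two hooks below accumulate the Kronecker factors of the layer-wise
        # Fisher approximation used by K-FAC: the forward pre-hook builds the input
        # (activation) covariance m_aa and the backward hook builds the covariance
        # m_gg of the gradients w.r.t. the layer outputs, so that each layer's
        # Fisher block is approximated by a Kronecker product of the two factors.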
        def _save_input(self, module, input):
            # Accumulate the input (activation) covariance factor for this layer
            aa = self.CovAHandler(input[0].data, module)
            # Initialize the buffer with a zero matrix on the first step
            if self.steps == 0:
                self.m_aa[module] = torch.diag(aa.new(aa.size(0)).fill_(0))
            self.m_aa[module] += aa

        def _save_grad_output(self, module, grad_input, grad_output):
            # Accumulate statistics for Fisher matrices
            gg = self.CovGHandler(grad_output[0].data, module, self.batch_averaged)
            # Initialize buffers
            if self.steps == 0:
                self.m_gg[module] = torch.diag(gg.new(gg.size(0)).fill_(0))
            self.m_gg[module] += gg

        def make_pruned_model(self, dataloader, criterion, device, fisher_type, prune_ratio, normalize=True, re_init=False):
            self._prepare_model()
            self.init_step()

            self._compute_fisher(dataloader, criterion, device, fisher_type)
            self._update_inv()  # eigen decomposition of fisher

            self._get_unit_importance(normalize)
            self._do_prune(prune_ratio, re_init)
            if not re_init:
                self._do_surgery()
            self._build_pruned_model(re_init)

            self._rm_hooks()
            self._clear_buffer()
            print(self.model)
            return str(self.model)

        def _prepare_model(self):
            count = 0
            print(self.model)
            print("=> We keep following layers in KFACPruner. ")
            for module in self.model.modules():
                classname = module.__class__.__name__
                if classname in self.known_modules:
                    self.modules.append(module)
                    module.register_forward_pre_hook(self._save_input)
                    # register_backward_hook is deprecated in newer PyTorch releases
                    # (register_full_backward_hook is its replacement); it is kept here
                    # to preserve the original hook behaviour.
                    module.register_backward_hook(self._save_grad_output)
                    print('(%s): %s' % (count, module))
                    count += 1
            self.modules = self.modules[self.fix_layers:]

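        # One pass over the data to populate m_aa / m_gg. With fisher_type='true',
        # labels are sampled from the model's own softmax output (a Monte Carlo
        # estimate of the true Fisher); otherwise the ground-truth targets are used
        # (empirical Fisher).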
        def _compute_fisher(self, dataloader, criterion, device='cuda', fisher_type='true'):
            self.mode = 'basis'
            self.model = self.model.eval()
            self.init_step()
            for batch_idx, (inputs, targets) in tqdm(enumerate(dataloader), total=len(dataloader)):
                inputs, targets = inputs.to(device), targets.to(device)
                outputs = self.model(inputs)
                if fisher_type == 'true':
                    sampled_y = torch.multinomial(torch.nn.functional.softmax(outputs.cpu().data, dim=1),
                                                  1).squeeze().to(device)
                    loss_sample = criterion(outputs, sampled_y)
                    loss_sample.backward()
                else:
                    loss = criterion(outputs, targets)
                    loss.backward()
                self.step()
            self.mode = 'quiet'

        def _update_inv(self):
            assert self.steps > 0, 'At least one step is required before updating the inverse!'
            eps = 1e-15
            for idx, m in enumerate(self.modules):
                # m_aa, m_gg = normalize_factors(self.m_aa[m], self.m_gg[m])
                m_aa, m_gg = self.m_aa[m], self.m_gg[m]
                # torch.symeig is deprecated/removed in recent PyTorch; torch.linalg.eigh
                # returns the same ascending eigenvalues and eigenvectors for symmetric matrices.
                self.d_a[m], self.Q_a[m] = torch.linalg.eigh(m_aa / self.steps)
                self.d_g[m], self.Q_g[m] = torch.linalg.eigh(m_gg / self.steps)
                # Clamp numerically negative / tiny eigenvalues to zero
                self.d_a[m].mul_((self.d_a[m] > eps).float())
                self.d_g[m].mul_((self.d_g[m] > eps).float())

            self._inversed = True
            self.iter += 1

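        # Per-unit importance scores. Given the eigendecompositions of the two
        # Kronecker factors, the diagonal of the (approximate) inverse Fisher
        # factorizes as an outer product of diag(G^{-1}) and diag(A^{-1}), giving
        # the OBD/OBS-style score w_ij^2 / [F^{-1}]_(ij,ij) for every weight;
        # scores are then summed over the incoming weights of each output unit.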
        def _get_unit_importance(self, normalize):
            eps = 1e-10
            assert self._inversed, 'Not inversed.'
            with torch.no_grad():
                for m in self.modules:
                    w = fetch_mat_weights(m, False)  # output_dim * input_dim
                    # (Q_a ⊗ Q_g) vec(W) = Q_g.t() @ W @ Q_a
                    if self.S_l is None:
                        A_inv = self.Q_a[m] @ (torch.diag(1.0 / (self.d_a[m] + eps))) @ self.Q_a[m].t()
                        G_inv = self.Q_g[m] @ (torch.diag(1.0 / (self.d_g[m] + eps))) @ self.Q_g[m].t()
                        A_inv_diag = torch.diag(A_inv)
                        G_inv_diag = torch.diag(G_inv)
                        w_imp = w ** 2 / (G_inv_diag.unsqueeze(1) @ A_inv_diag.unsqueeze(0))
                    else:
                        Q_a, Q_g = self.Q_a[m], self.Q_g[m]
                        S_l = self.S_l[m]
                        S_l_inv = 1.0 / (S_l + eps)
                        H_inv_diag = (Q_g ** 2) @ S_l_inv @ (Q_a.t() ** 2)  # output_dim * input_dim
                        w_imp = w ** 2 / H_inv_diag
                    self.W_pruned[m] = w
                    out_neuron_imp = w_imp.sum(1)  # aggregate importance over the incoming weights of each output unit
                    if not normalize:
                        out_imps = out_neuron_imp
                    else:
                        out_imps = out_neuron_imp / out_neuron_imp.sum()
                    self.importances[m] = (tensor_to_list(out_imps), out_neuron_imp.size(0))

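        # OBS-style weight compensation: build a correction that keeps only the
        # pruned rows of W, scaled by 1 / diag(F^{-1}), and apply roughly
        # delta_w ≈ -F^{-1} @ correction, so the surviving weights absorb the
        # removed ones (F^{-1} is expressed in the Kronecker eigenbasis above).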
        def _do_surgery(self):
            eps = 1e-10
            assert not self.use_patch, 'Will never use patch'
            with torch.no_grad():
                for idx, m in enumerate(self.modules):
                    w = fetch_mat_weights(m, False)  # output_dim * input_dim
                    if w.size(0) == len(m.out_indices):
                        # nothing was pruned in this layer; no compensation needed
                        continue
                    if self.S_l is None:
                        A_inv = self.Q_a[m] @ (torch.diag(1.0 / (self.d_a[m] + eps))) @ self.Q_a[m].t()
                        G_inv = self.Q_g[m] @ (torch.diag(1.0 / (self.d_g[m] + eps))) @ self.Q_g[m].t()
                        A_inv_diag = torch.diag(A_inv)
                        G_inv_diag = torch.diag(G_inv)
                        coeff = w / (G_inv_diag.unsqueeze(1) @ A_inv_diag.unsqueeze(0))
                        coeff[m.out_indices, :] = 0
                        delta_theta = -G_inv @ coeff @ A_inv
                    else:
                        Q_a, Q_g = self.Q_a[m], self.Q_g[m]
                        S_l = self.S_l[m]
                        S_l_inv = 1.0 / (S_l + eps)
                        H_inv_diag = (Q_g ** 2) @ S_l_inv @ (Q_a.t() ** 2)  # output_dim * input_dim
                        coeff = w / H_inv_diag
                        coeff[m.out_indices, :] = 0
                        delta_theta = (Q_g.t() @ coeff @ Q_a)/S_l_inv
                        delta_theta = Q_g @ delta_theta @ Q_a.t()
                    # ==== update weights and bias ======
                    dw, dbias = mat_to_weight_and_bias(delta_theta, m)
                    m.weight += dw
                    if m.bias is not None:
                        m.bias += dbias

        def _do_prune(self, prune_ratio, re_init):
            # get threshold
            all_importances = []
            for m in self.modules:
                imp_m = self.importances[m]
                imps = imp_m[0]
                all_importances += imps
            all_importances = sorted(all_importances)
            idx = int(prune_ratio * len(all_importances))
            threshold = all_importances[idx]

            threshold_recompute = get_threshold(all_importances, prune_ratio)
            idx_recomputed = len(filter_indices(all_importances, threshold))
            print('=> The threshold is: %.5f (%d), computed by function is: %.5f (%d).' % (threshold,
                                                                                           idx,
                                                                                           threshold_recompute,
                                                                                           idx_recomputed))

            # do pruning
            print('=> Conducting network pruning. Max: %.5f, Min: %.5f, Threshold: %.5f' % (max(all_importances),
                                                                                            min(all_importances),
                                                                                            threshold))
            self.logger.info("[Weight Importances] Max: %.5f, Min: %.5f, Threshold: %.5f." % (max(all_importances),
                                                                                              min(all_importances),
                                                                                              threshold))

            for idx, m in enumerate(self.modules):
                imp_m = self.importances[m]
                n_r = imp_m[1]
                row_imps = imp_m[0]
                row_indices = filter_indices(row_imps, threshold)
                r_ratio = 1 - len(row_indices) / n_r

                # compute row indices (out neurons)
                if r_ratio > self.prune_ratio_limit:
                    r_threshold = get_threshold(row_imps, self.prune_ratio_limit)
                    row_indices = filter_indices(row_imps, r_threshold)
                    print('* Layer %d: prune ratio %.4f exceeds the limit %.4f; clamped to the limit.' % (idx, r_ratio, self.prune_ratio_limit))
                if isinstance(m, nn.Linear) and idx == len(self.modules) - 1:
                    # never prune output units of the final classifier layer
                    row_indices = list(range(self.W_pruned[m].size(0)))

                m.out_indices = row_indices
                m.in_indices = None
            update_indices(self.model, self.network)

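        # Shrink the layers in place: every Conv2d / Linear keeps only the rows in
        # m.out_indices and the columns in m.in_indices (filled in by update_indices),
        # and each BatchNorm2d is sliced to match the channels of its input.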
        def _build_pruned_model(self, re_init):
            for m in self.model.modules():
                # m.grad = None
                if isinstance(m, nn.BatchNorm2d):
                    idxs = m.in_indices
                    m.num_features = len(idxs)
                    m.weight.data = m.weight.data[idxs]
                    m.bias.data = m.bias.data[idxs].clone()
                    m.running_mean = m.running_mean[idxs].clone()
                    m.running_var = m.running_var[idxs].clone()
                    # m.in_indices = None
                    # m.out_indices = None
                    m.weight.grad = None
                    m.bias.grad = None
                elif isinstance(m, nn.Conv2d):
                    in_indices = m.in_indices
                    if m.in_indices is None:
                        in_indices = list(range(m.weight.size(1)))
                    m.weight.data = m.weight.data[m.out_indices, :, :, :][:, in_indices, :, :].clone()
                    if m.bias is not None:
                        m.bias.data = m.bias.data[m.out_indices]
                        m.bias.grad = None
                    m.in_channels = len(in_indices)
                    m.out_channels = len(m.out_indices)
                    # m.in_indices = None
                    # m.out_indices = None
                    m.weight.grad = None
                    
                elif isinstance(m, nn.Linear):
                    in_indices = m.in_indices
                    if m.in_indices is None:
                        in_indices = list(range(m.weight.size(1)))
                    m.weight.data = m.weight.data[m.out_indices, :][:, in_indices].clone()
                    if m.bias is not None:
                        m.bias.data = m.bias.data[m.out_indices].clone()
                        m.bias.grad = None
                    m.in_features = len(in_indices)
                    m.out_features = len(m.out_indices)
                    # m.in_indices = None
                    # m.out_indices = None
                    m.weight.grad = None
            if re_init:
                self.model.apply(_weights_init)
            # import pdb; pdb.set_trace()

        def init_step(self):
            self.steps = 0

        def step(self):
            self.steps += 1

        def _rm_hooks(self):
            for m in self.model.modules():
                classname = m.__class__.__name__
                if classname in self.known_modules:
                    m._backward_hooks = OrderedDict()
                    m._forward_pre_hooks = OrderedDict()

        def _clear_buffer(self):
            self.m_aa = {}
            self.m_gg = {}
            self.d_a = {}
            self.d_g = {}
            self.Q_a = {}
            self.Q_g = {}
            self.modules = []
            if self.S_l is not None:
                self.S_l = {}

        def fine_tune_model(self, trainloader, testloader, criterion, optim, learning_rate, weight_decay, nepochs=10,
                            device='cuda'):
            self.model = self.model.train()
            # `optim` is expected to be the torch.optim module (or a compatible namespace)
            optimizer = optim.SGD(self.model.parameters(), lr=learning_rate, momentum=0.9, weight_decay=weight_decay)
            # optimizer = optim.Adam(self.model.parameters(), weight_decay=5e-4)
            lr_schedule = {0: learning_rate, int(nepochs * 0.5): learning_rate * 0.1,
                           int(nepochs * 0.75): learning_rate * 0.01}
            lr_scheduler = PresetLRScheduler(lr_schedule)
            best_test_acc, best_test_loss = 0, 100
            iterations = 0
            for epoch in range(nepochs):
                self.model = self.model.train()
                correct = 0
                total = 0
                all_loss = 0
                lr_scheduler(optimizer, epoch)
                desc = ('[LR: %.5f] Loss: %.3f | Acc: %.3f%% (%d/%d)' %
                        (lr_scheduler.get_lr(optimizer), 0, 0, correct, total))
                prog_bar = tqdm(enumerate(trainloader), total=len(trainloader), desc=desc, leave=True)
                for batch_idx, (inputs, targets) in prog_bar:
                    optimizer.zero_grad()
                    inputs, targets = inputs.to(device), targets.to(device)
                    outputs = self.model(inputs)
                    loss = criterion(outputs, targets)
                    self.writer.add_scalar('train_%d/loss' % self.iter, loss.item(), iterations)
                    iterations += 1
                    all_loss += loss.item()
                    loss.backward()
                    optimizer.step()
                    _, predicted = outputs.max(1)
                    total += targets.size(0)
                    correct += predicted.eq(targets).sum().item()
                    desc = ('[%d][LR: %.5f, WD: %.5f] Loss: %.3f | Acc: %.3f%% (%d/%d)' %
                            (epoch, lr_scheduler.get_lr(optimizer), weight_decay, all_loss / (batch_idx + 1),
                             100. * correct / total, correct, total))
                    prog_bar.set_description(desc, refresh=True)
                test_loss, test_acc = self.test_model(testloader, criterion, device)
                best_test_loss = best_test_loss if best_test_acc > test_acc else test_loss
                best_test_acc = max(test_acc, best_test_acc)
            print('** Finetuning finished. Stabilizing batch norm and test again!')
            stablize_bn(self.model, trainloader)
            test_loss, test_acc = self.test_model(testloader, criterion, device)
            best_test_loss = best_test_loss if best_test_acc > test_acc else test_loss
            best_test_acc = max(test_acc, best_test_acc)
            return best_test_loss, best_test_acc

        def test_model(self, dataloader, criterion, device='cuda'):
            self.model = self.model.eval()
            correct = 0
            total = 0
            all_loss = 0
            desc = ('Loss: %.3f | Acc: %.3f%% (%d/%d)' % (0, 0, correct, total))
            prog_bar = tqdm(enumerate(dataloader), total=len(dataloader), desc=desc, leave=True)
            for batch_idx, (inputs, targets) in prog_bar:
                inputs, targets = inputs.to(device), targets.to(device)
                outputs = self.model(inputs)
                loss = criterion(outputs, targets)
                all_loss += loss.item()
                _, predicted = outputs.max(1)
                total += targets.size(0)
                correct += predicted.eq(targets).sum().item()
                desc = ('Loss: %.3f | Acc: %.3f%% (%d/%d)' %
                        (all_loss / (batch_idx + 1), 100. * correct / total, correct, total))
                prog_bar.set_description(desc, refresh=True)
            return all_loss / (batch_idx + 1), 100. * correct / total
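

# A minimal, self-contained sketch of the Kronecker-factored importance computed in
# _get_unit_importance, using synthetic weights and factors in place of the
# statistics gathered by the hooks (shapes, names and values here are illustrative only).
if __name__ == '__main__':
    torch.manual_seed(0)
    out_dim, in_dim = 4, 6
    W = torch.randn(out_dim, in_dim)              # stand-in for fetch_mat_weights(m, False)
    A = torch.randn(in_dim, in_dim)
    A = A @ A.t() + in_dim * torch.eye(in_dim)    # synthetic SPD input factor (m_aa)
    G = torch.randn(out_dim, out_dim)
    G = G @ G.t() + out_dim * torch.eye(out_dim)  # synthetic SPD gradient factor (m_gg)
    d_a, Q_a = torch.linalg.eigh(A)               # eigendecomposition, as in _update_inv
    d_g, Q_g = torch.linalg.eigh(G)
    A_inv_diag = torch.diag(Q_a @ torch.diag(1.0 / d_a) @ Q_a.t())
    G_inv_diag = torch.diag(Q_g @ torch.diag(1.0 / d_g) @ Q_g.t())
    # Elementwise OBS-style score: W_ij^2 / ([G^{-1}]_ii * [A^{-1}]_jj)
    w_imp = W ** 2 / (G_inv_diag.unsqueeze(1) @ A_inv_diag.unsqueeze(0))
    print('per-output-unit importances:', w_imp.sum(1).tolist())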