import re
import math
import torch
from collections.abc import Iterable


def prune_vanilla_elementwise(param, sparsity, fn_importance=lambda x: x.abs()):
    """
    element-wise vanilla pruning
    :param param: torch.(cuda.)Tensor, weight of conv/fc layer
    :param sparsity: float, pruning sparsity
    :param fn_importance: function, takes 'param' and returns the importance of
                                    each element of 'param',
                                    default=lambda x: x.abs()
    :return:
        torch.(cuda.)ByteTensor, mask for zeros
    """
    sparsity = min(max(0.0, sparsity), 1.0)
    if sparsity == 1.0:
        # prune everything: zero the whole tensor and flag every position in the mask
        param.fill_(0)
        return torch.ones_like(param).byte()
    num_el = param.numel()
    importance = fn_importance(param)
    num_pruned = int(math.ceil(num_el * sparsity))
    num_stayed = num_el - num_pruned
    if sparsity <= 0.5:
        # few elements to prune: select the num_pruned smallest-importance positions directly
        _, topk_indices = torch.topk(importance.view(num_el), k=num_pruned,
                                     dim=0, largest=False, sorted=False)
        mask = torch.zeros_like(param).byte()
        param.view(num_el).index_fill_(0, topk_indices, 0)
        mask.view(num_el).index_fill_(0, topk_indices, 1)
    else:
        # many elements to prune: take the smallest of the num_stayed largest importances
        # as a threshold and prune everything below it
        thr = torch.min(torch.topk(importance.view(num_el), k=num_stayed,
                                   dim=0, largest=True, sorted=False)[0])
        mask = torch.lt(importance, thr)
        param.masked_fill_(mask, 0)
    return mask
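

# A minimal usage sketch for prune_vanilla_elementwise; the tensor shape and
# sparsity value below are illustrative only:
#
#   weight = torch.randn(64, 3, 3, 3)                # e.g. a conv layer weight
#   mask = prune_vanilla_elementwise(weight, sparsity=0.7)
#   # 'weight' is zeroed in place at the pruned positions and mask == 1 marks them,
#   # so float(mask.sum()) / mask.numel() is ~0.7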


def prune_vanilla_kernelwise(param, sparsity, fn_importance=lambda x: x.norm(1, -1)):
    """
    kernel-wise vanilla pruning; kernel importance is determined by the L1 norm by default
    :param param: torch.(cuda.)Tensor, weight of conv/fc layer
    :param sparsity: float, pruning sparsity
    :param fn_importance: function, takes 'param' reshaped to (param.size(0) * param.size(1), -1) and
                                    returns the importance of each kernel in 'param',
                                    default=lambda x: x.norm(1, -1)
    :return:
        torch.(cuda.)ByteTensor, mask for zeros
    """
    assert param.dim() >= 3
    sparsity = min(max(0.0, sparsity), 1.0)
    if sparsity == 1.0:
        # prune everything: zero the whole tensor and flag every position in the mask
        param.fill_(0)
        return torch.ones_like(param).byte()
    num_kernels = param.size(0) * param.size(1)
    param_k = param.view(num_kernels, -1)
    param_importance = fn_importance(param_k)
    num_pruned = int(math.ceil(num_kernels * sparsity))
    _, topk_indices = torch.topk(param_importance, k=num_pruned,
                                 dim=0, largest=False, sorted=False)
    mask = torch.zeros_like(param).byte()
    mask_k = mask.view(num_kernels, -1)
    param_k.index_fill_(0, topk_indices, 0)
    mask_k.index_fill_(0, topk_indices, 1)
    return mask
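

# A minimal sketch for prune_vanilla_kernelwise; the shape is illustrative.
# Whole 2-D kernels (one per (out_channel, in_channel) pair) are zeroed, not single weights:
#
#   weight = torch.randn(64, 32, 3, 3)               # conv weight
#   mask = prune_vanilla_kernelwise(weight, sparsity=0.5)
#   # half of the 64 * 32 = 2048 kernels are now entirely zero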


def prune_vanilla_filterwise(param, sparsity, fn_importance=lambda x: x.norm(1, -1)):
    """
    filter-wise vanilla pruning; filter importance is determined by the L1 norm by default
    :param param: torch.(cuda.)Tensor, weight of conv/fc layer
    :param sparsity: float, pruning sparsity
    :param fn_importance: function, takes 'param' reshaped to (param.size(0), -1) and
                                returns the importance of each filter in 'param',
                                default=lambda x: x.norm(1, -1)
    :return:
        torch.(cuda.)ByteTensor, mask for zeros
    """
    assert param.dim() >= 3
    sparsity = min(max(0.0, sparsity), 1.0)
    if sparsity == 1.0:
        # prune everything: zero the whole tensor and flag every position in the mask
        param.fill_(0)
        return torch.ones_like(param).byte()
    num_filters = param.size(0)
    param_k = param.view(num_filters, -1)
    param_importance = fn_importance(param_k)
    num_pruned = int(math.ceil(num_filters * sparsity))
    _, topk_indices = torch.topk(param_importance, k=num_pruned,
                                 dim=0, largest=False, sorted=False)
    mask = torch.zeros_like(param).byte()
    mask_k = mask.view(num_filters, -1)
    param_k.index_fill_(0, topk_indices, 0)
    mask_k.index_fill_(0, topk_indices, 1)
    return mask
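

# A minimal sketch for prune_vanilla_filterwise; the shape is illustrative.
# Whole output filters (slices along dim 0) are zeroed:
#
#   weight = torch.randn(64, 32, 3, 3)
#   mask = prune_vanilla_filterwise(weight, sparsity=0.25)
#   # 16 of the 64 filters are now entirely zero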


class VanillaPruner(object):

    def __init__(self, rule=None):
        """
        Pruner Class for Vanilla Pruning Method
        :param rule: str, path to the rule file, where each line is formatted as
                          'param_name granularity sparsity_stage_0,sparsity_stage_1,...'
                     or list of tuples, [(param_name(str), granularity(str),
                                          sparsity(float) or [sparsity_stage_0(float), sparsity_stage_1, ...],
                                          fn_importance(optional, str or function))]
                     'granularity': str, choose from ['element', 'kernel', 'filter']
                     'fn_importance': str, choose from ['abs', 'l1norm', 'l2norm'], or a custom function
        """
        if rule:
            if isinstance(rule, str):
                # parse the rule file: each line is 'param_name granularity sparsity_0,sparsity_1,...'
                with open(rule) as f:
                    content = map(lambda x: x.split(), f.readlines())
                content = filter(lambda x: len(x) == 3, content)
                rule = list(map(lambda x: [x[0], x[1], list(map(float, x[2].split(',')))], content))
            # rules may be given as tuples; convert to lists so they can be extended in place
            rule = [list(r) for r in rule]
            for r in rule:
                if not isinstance(r[2], Iterable):
                    assert isinstance(r[2], (float, int))
                    r[2] = [float(r[2])]
                if len(r) == 3:
                    r.append('default')
                granularity = r[1]
                if granularity == 'element':
                    r.append(prune_vanilla_elementwise)
                elif granularity == 'kernel':
                    r.append(prune_vanilla_kernelwise)
                elif granularity == 'filter':
                    r.append(prune_vanilla_filterwise)
                else:
                    raise NotImplementedError

        self.rule = rule

        self.masks = dict()

        print("=" * 89)
        if self.rule:
            print("Initializing Vanilla Pruner with rules:")
            for r in self.rule:
                print(r[:-1])
        else:
            print("Initializing Vanilla Pruner WITHOUT rules")
        print("=" * 89)

    def load_state_dict(self, state_dict, replace_rule=True):
        """
        Recover Pruner
        :param state_dict: dict, a dictionary containing a whole state of the Pruner
        :param replace_rule: bool, whether to use rule settings in 'state_dict'
        :return: void
        """
        if replace_rule:
            self.rule = state_dict['rule']
            for r in self.rule:
                granularity = r[1]
                if granularity == 'element':
                    r.append(prune_vanilla_elementwise)
                elif granularity == 'kernel':
                    r.append(prune_vanilla_kernelwise)
                elif granularity == 'filter':
                    r.append(prune_vanilla_filterwise)
                else:
                    raise NotImplementedError
        self.masks = state_dict['masks']
        print("=" * 89)
        print("Customizing Vanilla Pruner with rules:")
        for r in self.rule:
            print(r[:-1])
        print("=" * 89)

    def state_dict(self):
        """
        Returns a dictionary containing a whole state of the Pruner
        :return: dict, a dictionary containing a whole state of the Pruner
        """
        state_dict = dict()
        state_dict['rule'] = [r[:-1] for r in self.rule]
        state_dict['masks'] = self.masks
        return state_dict
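
    # A hedged checkpointing sketch (the file name is illustrative): state_dict()
    # drops the bound prune functions, and load_state_dict() rebinds them from the
    # stored granularity strings.
    #
    #   torch.save(pruner.state_dict(), 'pruner.pth')
    #   ...
    #   pruner = VanillaPruner()
    #   pruner.load_state_dict(torch.load('pruner.pth'))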

    def prune_param(self, param, param_name, stage=0, verbose=False):
        """
        prune parameter
        :param param: torch.(cuda.)Tensor, weight of conv/fc layer
        :param param_name: str, name of param
        :param stage: int, the pruning stage, default=0
        :param verbose: bool, whether to print the pruning details
        :return:
            torch.(cuda.)ByteTensor, mask for zeros
        """
        rule_id = -1
        for idx, r in enumerate(self.rule):
            # the rule's param_name is treated as a regex and must match the full parameter name
            m = re.match(r[0], param_name)
            if m is not None and len(param_name) == m.span()[1]:
                rule_id = idx
                break
        if rule_id > -1:
            sparsity = self.rule[rule_id][2][stage]
            fn_prune = self.rule[rule_id][-1]
            fn_importance = self.rule[rule_id][3]
            if verbose:
                print("{param_name:^30} | {stage:5d} | {spars:.3f}".
                      format(param_name=param_name, stage=stage, spars=sparsity))
            if fn_importance is None or fn_importance == 'default':
                mask = fn_prune(param=param, sparsity=sparsity)
            elif fn_importance == 'abs':
                mask = fn_prune(param=param, sparsity=sparsity, fn_importance=lambda x: x.abs())
            elif fn_importance == 'l1norm':
                mask = fn_prune(param=param, sparsity=sparsity, fn_importance=lambda x: x.norm(1, -1))
            elif fn_importance == 'l2norm':
                mask = fn_prune(param=param, sparsity=sparsity, fn_importance=lambda x: x.norm(2, -1))
            else:
                mask = fn_prune(param=param, sparsity=sparsity, fn_importance=fn_importance)
            return mask
        else:
            if verbose:
                print("{param_name:^30} | skipping".format(param_name=param_name))
            return None

    def prune(self, model, stage=0, update_masks=False, verbose=False):
        """
        prune the model's parameters in place
        :param model: torch.nn.Module
        :param stage: int, the pruning stage, default=0
        :param update_masks: bool, whether to (re)compute the masks
        :param verbose: bool, whether to print the pruning details
        :return:
            void
        """
        update_masks = update_masks or len(self.masks) == 0
        if verbose:
            print("=" * 89)
            print("Pruning Models")
            if len(self.masks) == 0:
                print("Initializing Masks")
            elif update_masks:
                print("Updating Masks")
            print("=" * 89)
            print("{name:^30} | stage | sparsity".format(name='param_name'))
        for param_name, param in model.named_parameters():
            # skip auxiliary-classifier params, e.g. 'AuxLogits' in GoogLeNet/Inception
            if 'AuxLogits' not in param_name:
                # only prune multi-dimensional weights; biases and other 1-D params are left alone
                if param.dim() > 1:
                    if update_masks:
                        mask = self.prune_param(param=param.data, param_name=param_name,
                                                stage=stage, verbose=verbose)
                        if mask is not None:
                            self.masks[param_name] = mask
                    else:
                        if param_name in self.masks:
                            mask = self.masks[param_name]
                            param.data.masked_fill_(mask, 0)
        if verbose:
            print("=" * 89)