# -*- coding: utf-8 -*-
# @Author  : DevinYang(pistonyang@gmail.com)
__all__ = ['summary']

from collections import OrderedDict
import torch
import torch.nn as nn
import numpy as np


def _flops_str(flops):
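    """Format a raw count with a T/G/M/K suffix, e.g. 1.5e9 -> '1.5G'."""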
    preset = [(1e12, 'T'), (1e9, 'G'), (1e6, 'M'), (1e3, 'K')]

    for p in preset:
        if flops // p[0] > 0:
            N = flops / p[0]
            ret = "%.1f%s" % (N, p[1])
            return ret
    ret = "%.1f" % flops
    return ret


def _cac_grad_params(p, w):
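    """Split a parameter count ``p`` into (trainable, non-trainable)
    depending on whether tensor ``w`` requires gradients."""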
    t, n = 0, 0
    if w.requires_grad:
        t += p
    else:
        n += p
    return t, n


def _cac_conv(layer, input, output):
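    """Count parameters and FLOPs of an ``nn.Conv2d`` layer.

    FLOPs are reported as multiplies + adds (M+A) for a single input
    sample; the batch dimension is not included.
    """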
    # bs, ic, ih, iw = input[0].shape
    oh, ow = output.shape[-2:]
    kh, kw = layer.kernel_size
    ic, oc = layer.in_channels, layer.out_channels
    g = layer.groups

    tb_params = 0
    ntb_params = 0
    flops = 0
    if hasattr(layer, 'weight') and hasattr(layer.weight, 'shape'):
        params = int(np.prod(layer.weight.shape))
        t, n = _cac_grad_params(params, layer.weight)
        tb_params += t
        ntb_params += n
        # Each output element takes ic/g * kh * kw multiplies and
        # ic/g * kh * kw - 1 adds.
        flops += (2 * (ic // g) * kh * kw - 1) * oh * ow * oc
    if hasattr(layer, 'bias') and hasattr(layer.bias, 'shape'):
        params = int(np.prod(layer.bias.shape))
        t, n = _cac_grad_params(params, layer.bias)
        tb_params += t
        ntb_params += n
        # One extra add per output element for the bias.
        flops += oh * ow * oc
    return tb_params, ntb_params, flops


def _cac_xx_norm(layer, input, output):
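    """Count parameters and FLOPs of a normalization layer such as
    ``nn.BatchNorm2d`` or ``nn.GroupNorm``.

    Running statistics are counted as non-trainable parameters; FLOPs are
    one operation per input element, doubled when ``affine`` is set.
    """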
    tb_params = 0
    ntb_params = 0
    if hasattr(layer, 'weight') and hasattr(layer.weight, 'shape'):
        params = int(np.prod(layer.weight.shape))
        t, n = _cac_grad_params(params, layer.weight)
        tb_params += t
        ntb_params += n
    if hasattr(layer, 'bias') and hasattr(layer.bias, 'shape'):
        params = int(np.prod(layer.bias.shape))
        t, n = _cac_grad_params(params, layer.bias)
        tb_params += t
        ntb_params += n
    # Running statistics are buffers, so they count as non-trainable.
    if hasattr(layer, 'running_mean') and hasattr(layer.running_mean, 'shape'):
        ntb_params += int(np.prod(layer.running_mean.shape))
    if hasattr(layer, 'running_var') and hasattr(layer.running_var, 'shape'):
        ntb_params += int(np.prod(layer.running_var.shape))
    in_tensor = input[0]
    # Normalization touches every element once; the affine scale/shift
    # roughly doubles that.
    flops = int(np.prod(in_tensor.shape))
    if layer.affine:
        flops *= 2
    return tb_params, ntb_params, flops


def _cac_linear(layer, input, output):
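    """Count parameters and FLOPs (multiplies + adds) of an ``nn.Linear``
    layer for a single input sample."""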
    ic, oc = layer.in_features, layer.out_features

    tb_params = 0
    ntb_params = 0
    flops = 0
    if hasattr(layer, 'weight') and hasattr(layer.weight, 'shape'):
        params = int(np.prod(layer.weight.shape))
        t, n = _cac_grad_params(params, layer.weight)
        tb_params += t
        ntb_params += n
        # Each output feature takes ic multiplies and ic - 1 adds.
        flops += (2 * ic - 1) * oc
    if hasattr(layer, 'bias') and hasattr(layer.bias, 'shape'):
        params = int(np.prod(layer.bias.shape))
        t, n = _cac_grad_params(params, layer.bias)
        tb_params += t
        ntb_params += n
        flops += oc
    return tb_params, ntb_params, flops


@torch.no_grad()
def summary(model, x, return_results=False):
    """

    Args:
        model (nn.Module): model to summary
        x (torch.Tensor): input data
        return_results (bool): return results

    Returns:

    """
    # Switch to eval mode so BatchNorm uses running statistics.
    model.eval()

    def register_hook(layer):

        def hook(layer, input, output):
            layer_name = str(layer.__class__.__name__)
            module_idx = len(model_summary)
            s_key = '{}-{}'.format(layer_name, module_idx + 1)
            model_summary[s_key] = OrderedDict()
            model_summary[s_key]['input_shape'] = list(input[0].shape)
            if isinstance(output, (tuple, list)):
                model_summary[s_key]['output_shape'] = [
                    list(o.shape) for o in output
                ]
            else:
                model_summary[s_key]['output_shape'] = list(output.shape)
            tb_params = 0
            ntb_params = 0
            flops = 0

            if isinstance(layer, nn.Conv2d):
                tb_params, ntb_params, flops = _cac_conv(layer, input, output)
            elif isinstance(layer, (nn.BatchNorm2d, nn.GroupNorm)):
                tb_params, ntb_params, flops = _cac_xx_norm(
                    layer, input, output)
            elif isinstance(layer, nn.Linear):
                tb_params, ntb_params, flops = _cac_linear(
                    layer, input, output)

            model_summary[s_key]['trainable_params'] = tb_params
            model_summary[s_key]['non_trainable_params'] = ntb_params
            model_summary[s_key]['params'] = tb_params + ntb_params
            model_summary[s_key]['flops'] = flops
            model_summary[s_key]['flops'] = flops

        # Skip pure container modules and the root module itself.
        if (not isinstance(layer, (nn.Sequential, nn.ModuleList,
                                   nn.Identity, nn.ModuleDict))
                and layer is not model):
            hooks.append(layer.register_forward_hook(hook))

    model_summary = OrderedDict()
    hooks = []
    model.apply(register_hook)
    model(x)
    for h in hooks:
        h.remove()

    print('-' * 80)
    line_new = "{:>20}  {:>25} {:>15} {:>15}".format(
        "Layer (type)", "Output Shape", "Params", "FLOPs(M+A) #")
    print(line_new)
    print('=' * 80)
    total_params = 0
    trainable_params = 0
    total_flops = 0
    for layer in model_summary:
        line_new = "{:>20}  {:>25} {:>15} {:>15}".format(
            layer,
            str(model_summary[layer]['output_shape']),
            model_summary[layer]['params'],
            model_summary[layer]['flops'],
        )
        print(line_new)
        total_params += model_summary[layer]['params']
        trainable_params += model_summary[layer]['trainable_params']
        total_flops += model_summary[layer]['flops']

    param_str = _flops_str(total_params)
    flop_str = _flops_str(total_flops)
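    # FLOPs(M+A) counts multiplies and adds separately; FLOPs(M) reports
    # roughly half of that (multiply-accumulate style counting).
    # Parameter size assumes 4 bytes (float32) per parameter.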
    flop_str_m = _flops_str(total_flops // 2)
    param_size = total_params * 4 / (1024 ** 2)
    print('=' * 80)
    print('        Total parameters: {:,}  {}'.format(total_params, param_str))
    print('    Trainable parameters: {:,}'.format(trainable_params))
    print(
        'Non-trainable parameters: {:,}'.format(total_params - trainable_params))
    print('Total flops(M)  : {:,}  {}'.format(total_flops // 2, flop_str_m))
    print('Total flops(M+A): {:,}  {}'.format(total_flops, flop_str))
    print('-' * 80)
    print('Parameters size (MB): {:.2f}'.format(param_size))
    if return_results:
        return total_params, total_flops
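

# Minimal usage sketch (illustrative only, not part of this module's API):
# build a tiny Conv-BN-Linear model and print its summary for one
# 3x32x32 input.
if __name__ == '__main__':
    demo_model = nn.Sequential(
        nn.Conv2d(3, 16, 3, padding=1),
        nn.BatchNorm2d(16),
        nn.ReLU(inplace=True),
        nn.AdaptiveAvgPool2d(1),
        nn.Flatten(),
        nn.Linear(16, 10),
    )
    summary(demo_model, torch.rand(1, 3, 32, 32))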