import torch
from torch.autograd import Variable
import torch.nn as nn
import prunable_nn as pnn
import torch.utils.model_zoo as model_zoo
from torchvision import models
from operator import itemgetter


class VGG(models.VGG):
    """torchvision VGG extended with Taylor-estimate driven feature-map pruning.

    ``prune`` expects ``self.features`` to contain ``pnn.PConv2d`` layers, each
    followed immediately by a layer exposing ``drop_input_channel`` (a prunable
    batch norm), and ``self.classifier[0]`` to expose ``drop_inputs``.
    """

    def __init__(self, features, num_classes=1000):
        super().__init__(features, num_classes)

    def pruning(self, flag):
        """Forward ``flag`` to every prunable conv that still has more than one output channel."""
        prunable = [module for module in self.features
                    if getattr(module, "prune_feature_map", False) and module.out_channels > 1]
        for p in prunable:
            p.pruning(flag)

    def _next_p_conv2d(self, module_idx):
        """Return the first ``pnn.PConv2d`` after position ``module_idx`` in ``self.features``, or None."""
        for idx in range(module_idx + 1, len(self.features)):
            if isinstance(self.features[idx], pnn.PConv2d):
                return self.features[idx]
        return None

    def prune(self):
        """Prune the single feature map with the smallest Taylor estimate.

        Candidates are the ``pnn.PConv2d`` modules in ``self.features`` that
        still have more than one output channel. After output map ``i`` of the
        chosen conv is pruned, input channel ``i`` is dropped from the layer
        directly after it (the batch norm) and from the layer consuming its
        output: the next ``pnn.PConv2d`` in ``self.features`` or, when the
        chosen conv is the last one, ``self.classifier[0]``.

        The consumer is located structurally rather than by hard-coded indices,
        so any conv/bn/activation(/pool) layout is handled.

        Raises:
            ValueError: if no conv has more than one output channel
                (``min`` over an empty sequence).
        """
        # (estimate, map index, module index) for every candidate feature map,
        # in module order then map order, so ties resolve to the earliest map.
        candidates = [(estimate, map_idx, module_idx)
                      for module_idx, module in enumerate(self.features)
                      if isinstance(module, pnn.PConv2d) and module.out_channels > 1
                      for map_idx, estimate in enumerate(module.taylor_estimates)]
        _, min_map_idx, min_module_idx = min(candidates, key=itemgetter(0))

        self.features[min_module_idx].prune_feature_map(min_map_idx)
        # the batch norm sits directly after its conv
        self.features[min_module_idx + 1].drop_input_channel(min_map_idx)

        next_p_conv2d = self._next_p_conv2d(min_module_idx)
        if next_p_conv2d is None:
            # last conv: its maps feed the flattened classifier input,
            # assumed to be C x 7 x 7
            first_p_linear = self.classifier[0]
            shape = (first_p_linear.in_features // 49, 7, 7)
            first_p_linear.drop_inputs(shape, min_map_idx)
        else:
            next_p_conv2d.drop_input_channel(min_map_idx)

def vgg_model(num_classes):
    """Build a VGG-11 (batch norm) initialised from torchvision's pretrained weights.

    The pretrained classifier is replaced by a freshly initialised one whose
    final layer has ``num_classes`` outputs.

    Args:
        num_classes: number of outputs of the new final ``nn.Linear``.

    Returns:
        A ``(model, 'vgg11_bn')`` tuple.
    """

    def make_layers(cfg, batch_norm=False):
        # Build the feature extractor: an int ``v`` adds a 3x3 conv with ``v``
        # output channels (+ optional BatchNorm2d) followed by ReLU; 'M' adds a
        # 2x2 / stride-2 max pool.
        # NOTE(review): these are plain nn.Conv2d / nn.BatchNorm2d layers, while
        # VGG.prune only considers pnn.PConv2d modules -- confirm whether
        # prunable layers were intended here.
        layers = []
        in_channels = 3
        for v in cfg:
            if v == 'M':
                layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
            else:
                conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
                if batch_norm:
                    layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
                else:
                    layers += [conv2d, nn.ReLU(inplace=True)]
                in_channels = v
        # Return only after the whole cfg has been consumed.
        return nn.Sequential(*layers)

    cfg = {'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],}
    model_url = 'https://download.pytorch.org/models/vgg11_bn-6002323d.pth'
    model = VGG(make_layers(cfg['A'], batch_norm=True))
    # strict=False: missing/unexpected keys are ignored instead of raising, so a
    # layout mismatch would silently leave layers at their random init.
    model.load_state_dict(model_zoo.load_url(model_url), strict=False)

    # Replace the pretrained classifier; only the final layer's width changes.
    model.classifier = nn.Sequential(
        nn.Linear(512 * 7 * 7, 4096),
        nn.ReLU(True),
        nn.Dropout(),
        nn.Linear(4096, 4096),
        nn.ReLU(True),
        nn.Dropout(),
        nn.Linear(4096, num_classes),
    )
    return model, 'vgg11_bn'


def chinese_model(num_classes):
    """Create a ChineseNet with its default layer configuration.

    Returns a ``(model, name)`` pair whose name is ``'chinese_net'``.
    """
    net = ChineseNet(num_classes)
    return net, 'chinese_net'


def chinese_pruned_80(num_classes):
    """Create a ChineseNet using the narrower 'chinese_net_80' layer widths.

    Returns a ``(model, name)`` pair whose name is ``'chinese_net_80'``.
    """
    layer_config = [26, 'M', 39, 'M', 52, 'M', 75, 93, 'M', 88, 95, 'M']
    net = ChineseNet(num_classes, config=layer_config)
    return net, 'chinese_net_80'


def chinese_pruned_90(num_classes):
    """Create a ChineseNet using the narrower 'chinese_net_90' layer widths.

    Returns a ``(model, name)`` pair whose name is ``'chinese_net_90'``.
    """
    layer_config = [15, 'M', 14, 'M', 20, 'M', 27, 31, 'M', 28, 30, 'M']
    net = ChineseNet(num_classes, config=layer_config)
    return net, 'chinese_net_90'


class ChineseNet(nn.Module):
    """Prunable CNN classifier for single-channel images.

    Inspired by https://arxiv.org/abs/1702.07975 (used there for Chinese OCR).

    ``config`` describes the feature extractor: an int ``v`` adds a
    ``PConv2d(v) -> PBatchNorm2d(v) -> PReLU`` triple and ``'M'`` adds a
    3x3 / stride-2 max pool.

    Args:
        num_classes: number of output logits.
        config: optional layer layout; defaults to
            ``[96, 'M', 128, 'M', 160, 'M', 256, 256, 'M', 384, 384, 'M']``.
    """

    def __init__(self, num_classes, config=None):
        super(ChineseNet, self).__init__()
        self.config = [96, 'M', 128, 'M', 160, 'M', 256, 256, 'M', 384, 384, 'M'] if config is None else config
        self.features = self.make_layers()
        self.classifier = nn.Sequential(
            # input is 96x96, output from features section should always be 2x2.
            # config[-2] is the last conv width when config ends with 'M'.
            pnn.PLinear(self.config[-2]*2*2, 1024),
            nn.BatchNorm1d(1024),
            nn.PReLU(),
            nn.Dropout(),
            nn.Linear(1024, num_classes)
        )
        # when True, forward() evaluates classifier[1] by hand (see forward)
        self.convert_to_onnx = False
        # last flag passed to pruning(); not read elsewhere in this class
        self.__pruning = False

    def pruning(self, flag):
        """Record ``flag`` and forward it to every prunable conv with more than one output channel."""
        self.__pruning = flag
        prunable = [module for module in self.features
                    if getattr(module, "prune_feature_map", False) and module.out_channels > 1]
        for p in prunable:
            p.pruning(flag)

    def make_layers(self):
        """Build the feature extractor ``nn.Sequential`` from ``self.config`` (input has 1 channel)."""
        layers = []
        in_channels = 1
        for v in self.config:
            if v == 'M':
                layers += [nn.MaxPool2d(kernel_size=3, stride=2)]
            else:
                conv2d = pnn.PConv2d(in_channels, v, kernel_size=3, padding=1)
                layers += [conv2d, pnn.PBatchNorm2d(v), nn.PReLU()]
                in_channels = v
        return nn.Sequential(*layers)

    def forward(self, x):
        """Run features, flatten per sample, then the classifier.

        With ``convert_to_onnx`` set, the classifier layers are applied one by
        one and the 1d batch norm is computed manually from its running
        statistics, weight and bias.
        """
        x = self.features(x)
        x = x.view(x.size(0), -1)
        if self.convert_to_onnx:
            x = self.classifier[0](x)

            # manually perform 1d batchnorm, caffe2 currently requires a resize,
            # which is hard to squeeze into the exported network
            bn_1d = self.classifier[1]
            numerator = (x - Variable(bn_1d.running_mean))
            denominator = Variable(torch.sqrt(bn_1d.running_var + bn_1d.eps))
            x = numerator/denominator*Variable(bn_1d.weight.data) + Variable(bn_1d.bias.data)

            x = self.classifier[2](x)
            x = self.classifier[3](x)
            x = self.classifier[4](x)
            return x
        else:
            x = self.classifier(x)
            return x

    def _next_prunable_conv(self, module_idx):
        """Return the first module after position ``module_idx`` exposing ``prune_feature_map``, or None."""
        for idx in range(module_idx + 1, len(self.features)):
            if getattr(self.features[idx], "prune_feature_map", False):
                return self.features[idx]
        return None

    def prune(self):
        """Prune the single feature map with the smallest Taylor estimate.

        Candidates are the modules in ``self.features`` exposing
        ``prune_feature_map`` that still have more than one output channel.
        After output map ``i`` of the chosen conv is pruned, input channel ``i``
        is dropped from the layer directly after it (its PBatchNorm2d) and from
        the layer consuming its output: the next prunable conv in
        ``self.features`` or, when the chosen conv is the last one,
        ``self.classifier[0]``.

        The consumer is located structurally rather than by hard-coded indices,
        so any ``config`` layout is handled, not just the default one.

        Raises:
            ValueError: if no conv has more than one output channel
                (``min`` over an empty sequence).
        """
        # (estimate, map index, module index) for every candidate feature map,
        # in module order then map order, so ties resolve to the earliest map.
        candidates = [(estimate, map_idx, module_idx)
                      for module_idx, module in enumerate(self.features)
                      if getattr(module, "prune_feature_map", False) and module.out_channels > 1
                      for map_idx, estimate in enumerate(module.taylor_estimates)]
        _, min_map_idx, min_module_idx = min(candidates, key=itemgetter(0))

        self.features[min_module_idx].prune_feature_map(min_map_idx)
        # make_layers always places the batch norm directly after its conv
        self.features[min_module_idx + 1].drop_input_channel(min_map_idx)

        next_p_conv2d = self._next_prunable_conv(min_module_idx)
        if next_p_conv2d is None:
            # last conv: its maps feed the flattened classifier input (?x2x2)
            first_p_linear = self.classifier[0]
            shape = (first_p_linear.in_features // 4, 2, 2)
            first_p_linear.drop_inputs(shape, min_map_idx)
        else:
            next_p_conv2d.drop_input_channel(min_map_idx)