from collections import OrderedDict

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.hub import load_state_dict_from_url

from dropblock import DropBlockScheduled, DropBlock2D


def swish(x):
    # swish / SiLU activation: x * sigmoid(x) (not used by the model below)
    return x * x.sigmoid()


def hard_sigmoid(x, inplace=False):
    # relu6(x + 3) / 6; `inplace` only affects the temporary (x + 3), so x itself is not modified
    return F.relu6(x + 3, inplace=inplace) / 6


def hard_swish(x, inplace=False):
    return x * hard_sigmoid(x, inplace)


class HardSigmoid(nn.Module):
    def __init__(self, inplace=False):
        super(HardSigmoid, self).__init__()
        self.inplace = inplace

    def forward(self, x):
        return hard_sigmoid(x, inplace=self.inplace)


class HardSwish(nn.Module):
    def __init__(self, inplace=False):
        super(HardSwish, self).__init__()
        self.inplace = inplace

    def forward(self, x):
        return hard_swish(x, inplace=self.inplace)
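
# Illustrative values (a reference sketch, not part of the model): hard_sigmoid is 0
# for x <= -3, 1 for x >= 3, and linear in between, while hard_swish multiplies that
# gate by x, e.g.
#   hard_sigmoid(torch.tensor([-4., 0., 4.]))  -> approx. tensor([0.0, 0.5, 1.0])
#   hard_swish(torch.tensor([-4., 0., 4.]))    -> approx. tensor([0.0, 0.0, 4.0])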


def _make_divisible(v, divisor, min_value=None):
    """
    This function is taken from the original tf repo.
    It ensures that all layers have a channel number that is divisible by 8
    It can be seen here:
    https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
    :param v:
    :param divisor:
    :param min_value:
    :return:
    """
    if min_value is None:
        min_value = divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    # Make sure that round down does not go down by more than 10%.
    if new_v < 0.9 * v:
        new_v += divisor
    return new_v
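
# Worked example (illustrative): with divisor 8, scaling 16 channels by 0.35 gives
# 5.6, which becomes 8, and 120 * 0.35 = 42 becomes 40; the final check adds one
# divisor whenever rounding down would lose more than 10% of the original value.
#   _make_divisible(16 * 0.35, 8)   -> 8
#   _make_divisible(120 * 0.35, 8)  -> 40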


# https://github.com/jonnedtc/Squeeze-Excitation-PyTorch/blob/master/networks.py
class SqEx(nn.Module):

    def __init__(self, n_features, reduction=4):
        super(SqEx, self).__init__()

        if n_features % reduction != 0:
            raise ValueError('n_features must be divisible by reduction (default = 4)')

        self.linear1 = nn.Linear(n_features, n_features // reduction, bias=True)
        self.nonlin1 = nn.ReLU(inplace=True)
        self.linear2 = nn.Linear(n_features // reduction, n_features, bias=True)
        self.nonlin2 = HardSigmoid(inplace=True)

    def forward(self, x):
        # squeeze: global average pool to (N, C, 1, 1)
        y = F.avg_pool2d(x, kernel_size=x.size()[2:4])
        # channels-last so the Linear layers act on the channel dimension
        y = y.permute(0, 2, 3, 1)
        y = self.nonlin1(self.linear1(y))
        y = self.nonlin2(self.linear2(y))
        # back to (N, C, 1, 1); excite: scale the input channel-wise
        y = y.permute(0, 3, 1, 2)
        y = x * y
        return y
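
# Illustrative usage of SqEx (shapes only, values hypothetical): the block is
# shape-preserving and gates each channel of its input, e.g.
#   se = SqEx(n_features=64)
#   se(torch.randn(2, 64, 28, 28)).shape  -> torch.Size([2, 64, 28, 28])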


class LinearBottleneck(nn.Module):
    def __init__(self, inplanes, outplanes, expplanes, k=3, stride=1, drop_prob=0, num_steps=3e5, start_step=0,
                 activation=nn.ReLU, act_params={"inplace": True}, SE=False):
        super(LinearBottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, expplanes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(expplanes)
        self.db1 = DropBlockScheduled(DropBlock2D(drop_prob=0, block_size=7), start_value=0.,
                                      stop_value=drop_prob, nr_steps=num_steps, start_step=start_step)
        self.act1 = activation(**act_params)  # the expansion conv keeps its activation, as in MobileNetV2

        self.conv2 = nn.Conv2d(expplanes, expplanes, kernel_size=k, stride=stride, padding=k // 2, bias=False,
                               groups=expplanes)
        self.bn2 = nn.BatchNorm2d(expplanes)
        self.db2 = DropBlockScheduled(DropBlock2D(drop_prob=drop_prob, block_size=7), start_value=0.,
                                      stop_value=drop_prob, nr_steps=num_steps, start_step=start_step)
        self.act2 = activation(**act_params)

        self.se = SqEx(expplanes) if SE else nn.Identity()

        self.conv3 = nn.Conv2d(expplanes, outplanes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(outplanes)
        self.db3 = DropBlockScheduled(DropBlock2D(drop_prob=drop_prob, block_size=7), start_value=0.,
                                      stop_value=drop_prob, nr_steps=num_steps, start_step=start_step)
        # self.act3 = activation(**act_params)  # works worse

        self.stride = stride
        self.expplanes = expplanes
        self.inplanes = inplanes
        self.outplanes = outplanes

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.db1(out)
        out = self.act1(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.db2(out)
        out = self.act2(out)

        out = self.se(out)

        out = self.conv3(out)
        out = self.bn3(out)
        out = self.db3(out)
        # out = self.act3(out)

        if self.stride == 1 and self.inplanes == self.outplanes:  # TODO: or add 1x1?
            out += residual  # No inplace if there is in-place activation before

        return out
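
# Illustrative use of LinearBottleneck (a sketch with hand-picked values): a stride-2
# block expands 16 channels to 64, applies a 3x3 depthwise conv, and projects to 24
# channels; the residual shortcut is skipped because stride != 1 and inplanes != outplanes.
#   block = LinearBottleneck(inplanes=16, outplanes=24, expplanes=64, k=3, stride=2)
#   block(torch.randn(1, 16, 56, 56)).shape  -> torch.Size([1, 24, 28, 28])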


class LastBlockLarge(nn.Module):
    def __init__(self, inplanes, num_classes, expplanes1, expplanes2):
        super(LastBlockLarge, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, expplanes1, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(expplanes1)
        self.act1 = HardSwish(inplace=True)

        self.avgpool = nn.AdaptiveAvgPool2d(1)

        self.conv2 = nn.Conv2d(expplanes1, expplanes2, kernel_size=1, stride=1)
        self.act2 = HardSwish(inplace=True)

        self.dropout = nn.Dropout(p=0.2, inplace=True)
        self.fc = nn.Linear(expplanes2, num_classes)

        self.expplanes1 = expplanes1
        self.expplanes2 = expplanes2
        self.inplanes = inplanes
        self.num_classes = num_classes

    def forward(self, x):
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.act1(out)

        out = self.avgpool(out)

        out = self.conv2(out)
        out = self.act2(out)

        # flatten for input to fully-connected layer
        out = out.view(out.size(0), -1)
        out = self.fc(self.dropout(out))

        return out


class LastBlockSmall(nn.Module):
    def __init__(self, inplanes, num_classes, expplanes1, expplanes2):
        super(LastBlockSmall, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, expplanes1, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(expplanes1)
        self.act1 = HardSwish(inplace=True)

        self.se = SqEx(expplanes1)

        self.avgpool = nn.AdaptiveAvgPool2d(1)

        self.conv2 = nn.Conv2d(expplanes1, expplanes2, kernel_size=1, stride=1, bias=False)
        self.act2 = HardSwish(inplace=True)

        self.dropout = nn.Dropout(p=0.2, inplace=True)
        self.fc = nn.Linear(expplanes2, num_classes)

        self.expplanes1 = expplanes1
        self.expplanes2 = expplanes2
        self.inplanes = inplanes
        self.num_classes = num_classes

    def forward(self, x):
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.act1(out)

        out = self.se(out)
        out = self.avgpool(out)

        out = self.conv2(out)
        out = self.act2(out)

        # flatten for input to fully-connected layer
        out = out.view(out.size(0), -1)
        out = self.fc(self.dropout(out))

        return out
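
# Illustrative head usage (a sketch; 96/576/1280 match the small model at scale 1):
# both heads pool the final feature map to 1x1, apply a 1x1 conv acting as a fully
# connected layer, then dropout and the final classifier.
#   head = LastBlockSmall(inplanes=96, num_classes=1000, expplanes1=576, expplanes2=1280)
#   head(torch.randn(1, 96, 7, 7)).shape  -> torch.Size([1, 1000])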


class MobileNetV3(nn.Module):
    """MobileNetV3 implementation.
    """

    def __init__(self, num_classes=1000, scale=1., in_channels=3, drop_prob=0.0, num_steps=3e5, start_step=0,
                 small=False):
        super(MobileNetV3, self).__init__()

        self.num_steps = num_steps
        self.start_step = start_step
        self.scale = scale
        self.num_classes = num_classes
        self.small = small

        # setting of bottlenecks blocks
        self.bottlenecks_setting_large = [
            # in, exp, out, s, k,         dp,    se,      act
            [16, 16, 16, 1, 3, 0, False, nn.ReLU],  # -> 112x112
            [16, 64, 24, 2, 3, 0, False, nn.ReLU],  # -> 56x56
            [24, 72, 24, 1, 3, 0, False, nn.ReLU],  # -> 56x56
            [24, 72, 40, 2, 5, 0, True, nn.ReLU],  # -> 28x28
            [40, 120, 40, 1, 5, 0, True, nn.ReLU],  # -> 28x28
            [40, 120, 40, 1, 5, 0, True, nn.ReLU],  # -> 28x28
            [40, 240, 80, 2, 3, drop_prob, False, HardSwish],  # -> 14x14
            [80, 200, 80, 1, 3, drop_prob, False, HardSwish],  # -> 14x14
            [80, 184, 80, 1, 3, drop_prob, False, HardSwish],  # -> 14x14
            [80, 184, 80, 1, 3, drop_prob, False, HardSwish],  # -> 14x14
            [80, 480, 112, 1, 3, drop_prob, True, HardSwish],  # -> 14x14
            [112, 672, 112, 1, 3, drop_prob, True, HardSwish],  # -> 14x14
            [112, 672, 160, 2, 5, drop_prob, True, HardSwish],  # -> 7x7
            [160, 960, 160, 1, 5, drop_prob, True, HardSwish],  # -> 7x7
            [160, 960, 160, 1, 5, drop_prob, True, HardSwish],  # -> 7x7
        ]
        self.bottlenecks_setting_small = [
            # in, exp, out, s, k,         dp,    se,      act
            [16, 64, 16, 2, 3, 0, True, nn.ReLU],  # -> 56x56
            [16, 72, 24, 2, 3, 0, False, nn.ReLU],  # -> 28x28
            [24, 88, 24, 1, 3, 0, False, nn.ReLU],  # -> 28x28
            [24, 96, 40, 2, 5, 0, True, HardSwish],  # -> 14x14
            [40, 240, 40, 1, 5, drop_prob, True, HardSwish],  # -> 14x14
            [40, 240, 40, 1, 5, drop_prob, True, HardSwish],  # -> 14x14
            [40, 120, 48, 1, 5, drop_prob, True, HardSwish],  # -> 14x14
            [48, 144, 96, 1, 5, drop_prob, True, HardSwish],  # -> 14x14
            [96, 288, 96, 2, 5, drop_prob, True, HardSwish],  # -> 7x7
            [96, 576, 96, 1, 5, drop_prob, True, HardSwish],  # -> 7x7
            [96, 576, 96, 1, 5, drop_prob, True, HardSwish],  # -> 7x7
        ]

        self.bottlenecks_setting = self.bottlenecks_setting_small if small else self.bottlenecks_setting_large
        for l in self.bottlenecks_setting:
            l[0] = _make_divisible(l[0] * self.scale, 8)
            l[1] = _make_divisible(l[1] * self.scale, 8)
            l[2] = _make_divisible(l[2] * self.scale, 8)

        self.conv1 = nn.Conv2d(in_channels, self.bottlenecks_setting[0][0], kernel_size=3, bias=False, stride=2,
                               padding=1)
        self.bn1 = nn.BatchNorm2d(self.bottlenecks_setting[0][0])
        self.act1 = HardSwish(inplace=True)
        self.bottlenecks = self._make_bottlenecks()

        # Last convolution has 1280 output channels for scale <= 1
        self.last_exp2 = 1280 if self.scale <= 1 else _make_divisible(1280 * self.scale, 8)
        if small:
            self.last_exp1 = _make_divisible(576 * self.scale, 8)
            self.last_block = LastBlockSmall(self.bottlenecks_setting[-1][2], num_classes, self.last_exp1,
                                             self.last_exp2)
        else:
            self.last_exp1 = _make_divisible(960 * self.scale, 8)
            self.last_block = LastBlockLarge(self.bottlenecks_setting[-1][2], num_classes, self.last_exp1,
                                             self.last_exp2)

    def _make_bottlenecks(self):
        modules = OrderedDict()
        stage_name = "Bottleneck"

        # add LinearBottleneck
        for i, setup in enumerate(self.bottlenecks_setting):
            name = stage_name + "_{}".format(i)
            module = LinearBottleneck(setup[0], setup[2], setup[1], k=setup[4], stride=setup[3], drop_prob=setup[5],
                                      num_steps=self.num_steps, start_step=self.start_step, activation=setup[7],
                                      act_params={"inplace": True}, SE=setup[6])
            modules[name] = module

        return nn.Sequential(modules)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.act1(x)

        x = self.bottlenecks(x)
        x = self.last_block(x)
        return x


# TODO
model_urls = {
    'mobilenetv3_large_1.0_224': 'https://github.com/Randl/MobileNetV3-pytorch/blob/master/results/mobilenetv3large-v1/model_best0-ec869f9b.pth',
}


def mobilenetv3(input_size=224, num_classes=1000, scale=1., in_channels=3, drop_prob=0.0, num_steps=3e5, start_step=0,
                small=False, get_weights=True, progress=True):
    model = MobileNetV3(num_classes=num_classes, scale=scale, in_channels=in_channels, drop_prob=drop_prob,
                        num_steps=num_steps, start_step=start_step, small=small)
    name = 'mobilenetv3_{}_{}_{}'.format('small' if small else 'large', scale, input_size)
    if get_weights:
        if name in model_urls:
            state_dict = load_state_dict_from_url(model_urls[name], progress=progress, map_location='cpu')
            model.load_state_dict(state_dict)
        else:
            raise ValueError('No pretrained weights are available for {!r}'.format(name))
    return model

if __name__ == "__main__":
    """Testing
    """
    model1 = MobileNetV3()
    print(model1)
    model2 = MobileNetV3(scale=0.35)
    print(model2)
    model3 = MobileNetV3(in_channels=2, num_classes=10)
    print(model3)
    x = torch.randn(1, 2, 224, 224)
    print(model3(x))
    model4_size = 32 * 10
    model4 = MobileNetV3(num_classes=10)
    print(model4)
    x2 = torch.randn(1, 3, model4_size, model4_size)
    print(model4(x2))
    model5 = MobileNetV3(scale=0.35, small=True)
    print(model5)
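
    # Hedged sketch of the factory function: build an untrained small model; passing
    # get_weights=False skips the pretrained download (the URL table above is still a TODO).
    model6 = mobilenetv3(scale=1., small=True, get_weights=False)
    print(model6)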