'''
Objects as Points
MXNet Gluon adaptation of the official PyTorch implementation (CenterNet).

Author: Guanghan Ning
Date: August, 2019
'''

from mxnet import gluon, init, nd
from mxnet.gluon import nn

"""
1. Basic Re-usable blocks
"""

class convolution(nn.Block):
    """Conv -> (optional) BatchNorm -> ReLU."""
    def __init__(self, kernel_size, channels_out, channels_in=0, strides=1, with_bn=True, **kwargs):
        super(convolution, self).__init__(**kwargs)
        padding = (kernel_size - 1) // 2  # "same" padding keeps the resolution unchanged at stride 1
        with self.name_scope():
            # channels_in=0 lets Gluon infer the input channels on the first forward pass;
            # the conv bias is redundant when BatchNorm follows, so it is dropped in that case
            self.conv = nn.Conv2D(channels_out, kernel_size, strides, padding, in_channels=channels_in, use_bias=not with_bn)
            # an empty Sequential acts as the identity when BatchNorm is disabled
            self.bn = nn.BatchNorm(in_channels=channels_out) if with_bn else nn.Sequential()

    def forward(self, X):
        conv = self.conv(X)
        bn   = self.bn(conv)
        return nd.relu(bn)


def test_convolution_shape():
    blk = convolution(kernel_size=3, channels_out=128)
    blk.initialize()
    X = nd.random.uniform(shape=(1, 64, 128, 128))
    Y = blk(X)
    print("\t Input shape: ", X.shape)
    print("\t Output shape:", Y.shape)


class fully_connected(nn.Block):
    """Dense -> (optional) BatchNorm -> ReLU; the Dense layer flattens its input."""
    def __init__(self, channels_out, channels_in=0, with_bn=True, **kwargs):
        super(fully_connected, self).__init__(**kwargs)
        with self.name_scope():
            self.with_bn = with_bn
            # without channels_in, Dense infers in_units on the first forward pass
            self.linear = nn.Dense(channels_out, in_units=channels_in) if channels_in else nn.Dense(channels_out)
            if self.with_bn:
                self.bn = nn.BatchNorm(in_channels=channels_out)

    def forward(self, X):
        linear = self.linear(X)
        bn = self.bn(linear) if self.with_bn else linear
        return nd.relu(bn)


def test_fully_connected_shape():
    blk = fully_connected(channels_out=128)
    blk.initialize()
    X   = nd.random.uniform(shape=(1, 2, 32, 32))
    Y   = blk(X)
    print("\t Input shape: ", X.shape)
    print("\t Output shape:", Y.shape)


class residual(nn.Block):
    """Residual block: two 3x3 convs plus a skip connection (1x1 projection when
    the shape changes). kernel_size is accepted only for interface compatibility
    with `convolution`; the kernels are always 3x3. The stride argument is named
    `strides` to match `convolution`, so factories such as make_hg_layer can pass
    strides=2 to either layer type."""
    def __init__(self, kernel_size, channels_out, channels_in, strides=1, with_bn=True, **kwargs):
        super(residual, self).__init__(**kwargs)
        with self.name_scope():
            self.conv1 = nn.Conv2D(channels_out, kernel_size=(3,3), strides=(strides, strides), padding=(1,1), in_channels=channels_in, use_bias=False)
            self.bn1   = nn.BatchNorm(in_channels=channels_out)

            self.conv2 = nn.Conv2D(channels_out, kernel_size=(3,3), strides=(1, 1), padding=(1,1), in_channels=channels_out, use_bias=False)
            self.bn2   = nn.BatchNorm(in_channels=channels_out)

            # identity skip, or 1x1 projection when resolution or channel count changes
            self.skip = nn.Sequential()
            if strides != 1 or channels_in != channels_out:
                self.skip.add( nn.Conv2D(channels_out, kernel_size=(1,1), strides=(strides, strides), in_channels=channels_in, use_bias=False),
                               nn.BatchNorm(in_channels=channels_out)
                )

    def forward(self, X):
        conv1 = self.conv1(X)
        bn1   = self.bn1(conv1)
        relu1 = nd.relu(bn1)

        conv2 = self.conv2(relu1)
        bn2   = self.bn2(conv2)

        skip  = self.skip(X)
        return nd.relu(bn2 + skip)


def test_residual():
    blk = residual(kernel_size=3, channels_out=32, channels_in=64)
    blk.initialize()
    X   = nd.random.uniform(shape=(1, 64, 128, 128))
    Y   = blk(X)
    print("\t Input shape: ", X.shape)
    print("\t Output shape:", Y.shape)


class bilinear_upsample(nn.Block):
    def __init__(self, scale_factor=2, **kwargs):
        super(bilinear_upsample, self).__init__(**kwargs)
        self.scale_factor = scale_factor

    def forward(self, X):
        # BilinearResize2D needs an explicit target size, so the input shape is read
        # at runtime; this dynamic shape access is why nn.Block/nn.Sequential are
        # used throughout instead of their Hybrid counterparts
        height, width = X.shape[2:4]
        return nd.contrib.BilinearResize2D(X, height=height*self.scale_factor, width=width*self.scale_factor)


def test_bilinear_upsample():
    blk = bilinear_upsample(scale_factor=2)
    blk.initialize()
    X   = nd.random.uniform(shape=(1, 2, 128, 128))
    Y   = blk(X)
    print("\t Input shape: ", X.shape)
    print("\t Output shape:", Y.shape)

"""
2. Factories that assemble the basic blocks into repeated patterns
"""
def make_repeat_layers(kernel_size, channels_out, channels_in, num_modules, layer=convolution, **kwargs):
    """Stack num_modules layers: the first maps channels_in -> channels_out, the rest keep channels_out."""
    layers = [layer(kernel_size, channels_out, channels_in, **kwargs)]
    for _ in range(1, num_modules):
        layers.append(layer(kernel_size, channels_out, channels_out, **kwargs))
    sequential = nn.Sequential()
    sequential.add(*layers)
    return sequential

def make_repeat_layers_reverse(kernel_size, channels_out, channels_in, num_modules, layer=convolution, **kwargs):
    """Mirror of make_repeat_layers: the channel change happens in the last layer instead of the first."""
    layers = [layer(kernel_size, channels_in, channels_in, **kwargs) for _ in range(num_modules-1)]
    layers.append(layer(kernel_size, channels_out, channels_in, **kwargs))
    sequential = nn.Sequential()
    sequential.add(*layers)
    return sequential
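
# Shape check for the two repeat-layer factories (an addition in the spirit of the
# other per-block tests): with the default `convolution` layer, the channel change
# happens at the first block in make_repeat_layers and at the last block in
# make_repeat_layers_reverse, while the spatial size stays fixed.
def test_make_repeat_layers():
    for factory in (make_repeat_layers, make_repeat_layers_reverse):
        blk = factory(kernel_size=3, channels_out=128, channels_in=64, num_modules=2)
        blk.initialize()
        X = nd.random.uniform(shape=(1, 64, 32, 32))
        Y = blk(X)
        print("\t {}: {} -> {}".format(factory.__name__, X.shape, Y.shape))  # expect (1, 128, 32, 32)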

class MergeUp(nn.Block):
    def forward(self, up1, up2):
        return up1 + up2

def test_MergeUp():
    blk = MergeUp()
    blk.initialize()
    X1   = nd.random.uniform(shape=(1, 64, 128, 128))
    X2   = nd.random.uniform(shape=(1, 64, 128, 128))
    Y   = blk(X1, X2)
    print("\t Input_1 shape: ", X1.shape)
    print("\t Input_2 shape: ", X2.shape)
    print("\t Output shape:", Y.shape)


def make_merge_layer():
    return MergeUp()

def make_pool_layer():
    return nn.MaxPool2D(pool_size=2)
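
def make_identity_pool_layer():
    # The official CenterNet hourglass replaces pooling with an identity (an empty
    # Sequential): downsampling is done by the stride-2 layer in make_hg_layer
    # instead. This helper (the name is ours) lets HourglassNet opt into that
    # behavior without downsampling twice per level.
    return nn.Sequential()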

def make_unpool_layer():
    return bilinear_upsample(scale_factor=2)

def make_keypoint_layer(channels_out, channels_intermediate, channels_in):
    """Prediction head: a 3x3 conv (no BN) followed by a 1x1 conv that emits channels_out maps."""
    sequential = nn.Sequential()
    sequential.add(convolution(kernel_size=3, channels_out=channels_intermediate, channels_in=channels_in, with_bn=False))
    sequential.add(nn.Conv2D(channels_out, kernel_size=1))
    return sequential
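
# Shape check for the prediction head (an addition): the head keeps the spatial
# size and emits channels_out maps; 80 below is just an example class count.
def test_make_keypoint_layer():
    blk = make_keypoint_layer(channels_out=80, channels_intermediate=256, channels_in=256)
    blk.initialize()
    X = nd.random.uniform(shape=(1, 256, 128, 128))
    Y = blk(X)
    print("\t Input shape: ", X.shape)
    print("\t Output shape:", Y.shape)  # expect (1, 80, 128, 128)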

def make_inter_layer(channels):
    return residual(3, channels, channels)

def make_conv_layer(channels_out, channels_in):
    return convolution(3, channels_out, channels_in)

def make_hg_layer(kernel_size, channels_out, channels_in, num_modules, layer=convolution, **kwargs):
    """Downsampling branch of an hourglass level: the first layer halves the
    resolution with strides=2, as in the official CenterNet implementation.
    Pair it with an identity pool layer (make_identity_pool_layer) so each
    level downsamples exactly once."""
    layers = [layer(kernel_size, channels_out, channels_in, strides=2)]
    layers += [layer(kernel_size, channels_out, channels_out) for _ in range(num_modules-1)]
    sequential = nn.Sequential()
    sequential.add(*layers)
    return sequential
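
# Shape check for the stride-2 hourglass layer (an addition): with the shared
# `strides` keyword, both layer types should halve the spatial size.
def test_make_hg_layer():
    for layer in (convolution, residual):
        blk = make_hg_layer(3, channels_out=384, channels_in=256, num_modules=2, layer=layer)
        blk.initialize()
        X = nd.random.uniform(shape=(1, 256, 64, 64))
        Y = blk(X)
        print("\t {}: {} -> {}".format(layer.__name__, X.shape, Y.shape))  # expect (1, 384, 32, 32)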


"""
3. Structures that are higher-level than basic blocks
"""
class keypoint_struct(nn.Block):
    """One hourglass level, defined recursively: an `up` branch keeps the current
    resolution while a `low` branch downsamples, recurses (or bottoms out at
    level 1), upsamples again, and is merged back with the `up` branch."""
    def __init__(self,
                 level, dims, num_blocks,
                 layer= residual,
                 make_up_layer = make_repeat_layers, make_low_layer=make_repeat_layers,
                 make_hg_layer = make_repeat_layers, make_hg_layer_reverse=make_repeat_layers_reverse,
                 make_pool_layer = make_pool_layer, make_unpool_layer=make_unpool_layer,
                 make_merge_layer = make_merge_layer, **kwargs):
        super(keypoint_struct, self).__init__()

        self.level = level
        print("\t Level({})".format(self.level), dims)

        curr_num_blocks = num_blocks[0]
        next_num_blocks = num_blocks[1]

        curr_dim = dims[0]
        next_dim = dims[1]

        with self.name_scope():
            self.up1  = make_up_layer(kernel_size=3, channels_out=curr_dim, channels_in=curr_dim, num_modules=curr_num_blocks, layer=layer, **kwargs)
            self.max1 = make_pool_layer()
            self.low1 = make_hg_layer(3, next_dim, curr_dim, curr_num_blocks, layer=layer, **kwargs)
            self.low2 = keypoint_struct(
                level-1, dims[1:], num_blocks[1:], layer=layer, **kwargs
            ) if self.level > 1 else \
                make_low_layer(
                    3, next_dim, next_dim, next_num_blocks, layer=layer, **kwargs
                )
            self.low3 = make_hg_layer_reverse(
                3, curr_dim, next_dim, curr_num_blocks, layer=layer, **kwargs
            )
            self.up2 = make_unpool_layer()
            self.merge = make_merge_layer()

    def forward(self, X):
        up1  = self.up1(X)
        max1 = self.max1(X)
        low1 = self.low1(max1)
        low2 = self.low2(low1)
        low3 = self.low3(low2)
        up2  = self.up2(low3)
        return self.merge(up1, up2)


def test_keypoint_struct():
    level       = 5
    channels    = [256, 256, 384, 384, 384, 512]
    num_blocks  = [2, 2, 2, 2, 2, 4]

    blk = keypoint_struct(level, channels, num_blocks)
    blk.initialize()
    X   = nd.random.uniform(shape=(1, 256, 384, 384))
    Y   = blk(X)
    print("\t Input shape: ", X.shape)
    print("\t Output shape:", Y.shape)

"""
4. Stacked Hourglass Network
"""
class stacked_hourglass(nn.Block):
    def __init__(self, level, num_stacks,
                 dims, num_blocks, heads,
                 pre=None, conv_dim=256,
                 make_conv_layer = make_conv_layer,     make_heat_layer    = make_keypoint_layer,
                 make_tag_layer  = make_keypoint_layer, make_regress_layer = make_keypoint_layer,
                 make_up_layer   = make_repeat_layers,  make_low_layer     = make_repeat_layers,
                 make_hg_layer   = make_repeat_layers,  make_hg_layer_reverse = make_repeat_layers_reverse,
                 make_pool_layer = make_pool_layer,     make_unpool_layer     = make_unpool_layer,
                 make_merge_layer= make_merge_layer,    make_inter_layer      = make_inter_layer,
                 kp_layer = residual
    ):
        super(stacked_hourglass, self).__init__()

        self.num_stacks = num_stacks
        self.heads      = heads

        curr_dim = dims[0]

        if pre is None:
            # stem: 7x7 stride-2 conv followed by a stride-2 residual (4x downsampling)
            self.pre = nn.Sequential()
            with self.name_scope():
                self.pre.add(
                    convolution(7, 128, 3, strides=2),
                    residual(3, 256, 128, strides=2)
                )
        else:
            self.pre = pre

        self.kpts = nn.Sequential()
        with self.name_scope():
            for _ in range(num_stacks):
                self.kpts.add(
                    keypoint_struct(level, dims, num_blocks,
                                    layer = kp_layer,  # forward kp_layer so it is not silently ignored
                                    make_up_layer = make_up_layer,
                                    make_low_layer = make_low_layer,
                                    make_hg_layer = make_hg_layer,
                                    make_hg_layer_reverse = make_hg_layer_reverse,
                                    make_pool_layer = make_pool_layer,
                                    make_unpool_layer = make_unpool_layer,
                                    make_merge_layer = make_merge_layer
                    )
                )

        self.convs = nn.Sequential()
        with self.name_scope():
            for _ in range(num_stacks):
                self.convs.add(
                    make_conv_layer(conv_dim, curr_dim)
                )

        self.inters = nn.Sequential()
        with self.name_scope():
            for _ in range(num_stacks):
                self.inters.add(
                    make_inter_layer(curr_dim)
                )

        self.inters_ = nn.Sequential()
        with self.name_scope():
            for _ in range(num_stacks-1):
                seq = nn.Sequential()
                seq.add(
                    # applied to `inter`, which carries curr_dim channels
                    nn.Conv2D(curr_dim, (1,1), use_bias=False, in_channels=curr_dim),
                    nn.BatchNorm()
                )
                self.inters_.add(seq)

        self.convs_ = nn.Sequential()
        with self.name_scope():
            for _ in range(num_stacks-1):
                seq = nn.Sequential()
                seq.add(
                    # applied to `conv`, which carries conv_dim channels
                    nn.Conv2D(curr_dim, (1,1), use_bias=False, in_channels=conv_dim),
                    nn.BatchNorm()
                )
                self.convs_.add(seq)

        # prediction heads: one module per stack for each head
        for head in heads.keys():
            if "hm" in head:
                # keypoint heatmap heads
                module = nn.Sequential()
                with self.name_scope():
                    for _ in range(num_stacks):
                        module.add(
                            make_heat_layer(channels_out=heads[head], channels_intermediate=curr_dim, channels_in=conv_dim)
                        )
                setattr(self, head, module)
                # Bias the final 1x1 conv toward -2.19, the focal-loss prior used by the
                # official implementation. Gluon parameters do not exist yet at
                # construction time, so set the parameter's initializer instead of
                # writing values directly; net.initialize() then fills in the constant.
                for heat in module:
                    heat[-1].bias.init = init.Constant(-2.19)
            else:
                # regression heads (e.g. box size, center offset)
                module = nn.Sequential()
                with self.name_scope():
                    for _ in range(num_stacks):
                        module.add(
                            make_regress_layer(channels_out=heads[head], channels_intermediate=curr_dim, channels_in=conv_dim)
                        )
                setattr(self, head, module)

    def forward(self, img):
        inter = self.pre(img)
        #print("\t inter shape: ", inter.shape)
        outs = []  # one dict per stack, mapping head name -> prediction map

        for ind in range(self.num_stacks):
            kp_, conv_ = self.kpts[ind], self.convs[ind]
            kp = kp_(inter)
            conv = conv_(kp)
            #print("\t conv shape: ", conv.shape)

            out = {}
            for head in self.heads:
                layer = getattr(self, head)[ind]
                y = layer(conv)
                out[head] = y
            outs.append(out)

            if ind < self.num_stacks - 1:
                # fuse the previous intermediate features with the current stack's
                # output, then refine with a residual block before the next stack
                inter = self.inters_[ind](inter) + self.convs_[ind](conv)
                inter = nd.relu(inter)
                inter = self.inters[ind](inter)
                #print("\t inter shape: ", inter.shape)
        return outs

def test_stacked_hourglass():
    level       = 5
    channels    = [256, 256, 384, 384, 384, 512]
    num_blocks  = [2, 2, 2, 2, 2, 4]
    num_stacks  = 2

    import sys
    sys.path.insert(0, "/export/guanghan/CenterNet-Gluon/")
    sys.path.insert(0, "/Users/guanghan.ning/Desktop/dev/CenterNet-Gluon/")
    from opts import opts
    opt = opts().init()
    print(opt.arch)
    print(opt.heads)

    blk = stacked_hourglass(level, num_stacks, channels, num_blocks, opt.heads)
    blk.initialize()
    X   = nd.random.uniform(shape=(1, 3, 512, 512))
    Y   = blk(X)
    print("\t Input shape: ", X.shape)
    print("\t Output len:", len(Y))

"""
5. Network with specifications
"""
class HourglassNet(stacked_hourglass):
    def __init__(self, heads, num_stacks=2):
        level       = 5
        channels    = [256, 256, 384, 384, 384, 512]
        num_blocks  = [2, 2, 2, 2, 2, 4]

        super(HourglassNet, self).__init__(
            level, num_stacks, channels, num_blocks, heads,
            # identity pool: make_hg_layer already downsamples with a stride-2
            # layer, so pooling on top of it would halve the resolution twice
            make_pool_layer = make_identity_pool_layer,
            make_hg_layer = make_hg_layer,
            kp_layer= residual,
            conv_dim= 256
        )

"""
6. Constructor & interface for outside call
"""
def get_hourglass_net(num_layers, heads, head_conv, ctx):
    # num_layers, head_conv, and ctx are accepted to match the common factory
    # signature; the hourglass variant does not use them
    model = HourglassNet(heads, 2)
    return model
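
# Minimal usage sketch (an addition). The head configuration follows the usual
# CenterNet convention -- 'hm' with one channel per class, 'wh' and 'reg' with two
# channels each -- but the exact dict here is an assumed example, not taken from opts.
def demo_hourglass_net():
    heads = {'hm': 80, 'wh': 2, 'reg': 2}  # assumed COCO-style heads
    net = get_hourglass_net(num_layers=0, heads=heads, head_conv=256, ctx=None)
    net.initialize()
    X = nd.random.uniform(shape=(1, 3, 512, 512))
    outs = net(X)  # list with one dict per stack
    for ind, out in enumerate(outs):
        for head, y in out.items():
            print("\t stack {} {}: {}".format(ind, head, y.shape))  # e.g. (1, 80, 128, 128) for 'hm'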


# test utils
def peek_network(net, input_shape=(224, 224)):
    # walk a Sequential-style net layer by layer with a single-channel dummy input
    w, h = input_shape
    X = nd.random.uniform(shape=(1, 1, w, h))
    net.initialize()
    for layer in net:
        X = layer(X)
        print(layer.name, 'output shape: {}'.format(X.shape))
    return


def test_all():
    funcs = [test_convolution_shape,
             test_fully_connected_shape,
             test_residual,
             test_bilinear_upsample,
             test_MergeUp,
             test_keypoint_struct,
             test_stacked_hourglass
             ]
    for func in funcs:
        print("Testing routine: {}".format(func.__name__))
        func()


if __name__ == "__main__":
    test_all()