#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time    : 2020/3/27 17:42
# @Author  : JackyLUO
# @E-mail  : lingluo@stumail.neu.edu.cn
# @Site    : 
# @File    : lednet.py
# @Software: PyCharm
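"""Keras implementation of LEDNet (Wang et al., arXiv:1905.02423), a
lightweight encoder-decoder network for real-time semantic segmentation:
an encoder built from split-shuffle non-bottleneck (SS-nbt) blocks and an
attention pyramid network (APN) decoder."""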

# Use the Keras bundled with TensorFlow so that raw TF ops (tf.split in
# channel_split below) can be mixed into Lambda layers safely.
from tensorflow.keras import layers, models
import tensorflow as tf


class LEDNet:
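    """LEDNet semantic segmentation model.

    groups: number of channel groups used by the split/shuffle operations
        (the two-branch SS-nbt block assumes 2).
    classes: number of output classes.
    input_shape: (height, width, channels) of the input images.
    """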
    def __init__(self, groups, classes, input_shape):
        self.groups = groups
        self.classes = classes
        self.input_shape = input_shape

    def ss_bt(self, x, dilation, strides=(1, 1), padding='same'):
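        """Split-shuffle non-bottleneck (SS-nbt) residual block.

        The input is split channel-wise into two branches of factorized
        (3x1 / 1x3) convolutions, with the second pair in each branch
        dilated; the branches are concatenated, added back to the input,
        and channel-shuffled. The residual add assumes strides == (1, 1).
        """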
        x1, x2 = self.channel_split(x)
        # Each branch keeps 1/groups of the input channels.
        filters = int(x.shape[-1]) // self.groups
        x1 = layers.Conv2D(filters, kernel_size=(3, 1), strides=strides, padding=padding)(x1)
        x1 = layers.Activation('relu')(x1)
        x1 = layers.Conv2D(filters, kernel_size=(1, 3), strides=strides, padding=padding)(x1)
        x1 = layers.BatchNormalization()(x1)
        x1 = layers.Activation('relu')(x1)
        x1 = layers.Conv2D(filters, kernel_size=(3, 1), strides=strides,
                           padding=padding, dilation_rate=(dilation, 1))(x1)
        x1 = layers.Activation('relu')(x1)
        x1 = layers.Conv2D(filters, kernel_size=(1, 3), strides=strides,
                           padding=padding, dilation_rate=(1, dilation))(x1)
        x1 = layers.BatchNormalization()(x1)
        x1 = layers.Activation('relu')(x1)

        x2 = layers.Conv2D(filters, kernel_size=(1, 3), strides=strides, padding=padding)(x2)
        x2 = layers.Activation('relu')(x2)
        x2 = layers.Conv2D(filters, kernel_size=(3, 1), strides=strides, padding=padding)(x2)
        x2 = layers.BatchNormalization()(x2)
        x2 = layers.Activation('relu')(x2)
        x2 = layers.Conv2D(filters, kernel_size=(1, 3), strides=strides,
                           padding=padding, dilation_rate=(1, dilation))(x2)
        x2 = layers.Activation('relu')(x2)
        x2 = layers.Conv2D(filters, kernel_size=(3, 1), strides=strides,
                           padding=padding, dilation_rate=(dilation, 1))(x2)
        x2 = layers.BatchNormalization()(x2)
        x2 = layers.Activation('relu')(x2)
        x_concat = layers.concatenate([x1, x2], axis=-1)
        x_add = layers.add([x, x_concat])
        output = self.channel_shuffle(x_add)
        return output

    def channel_shuffle(self, x):
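        """ShuffleNet-style channel shuffle: reshape to (h, w, groups,
        c / groups), swap the group and channel axes, and flatten back so
        that information mixes across the SS-nbt branches."""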
        _, h, w, c = x.shape.as_list()  # batch dimension is unused
        x_reshaped = layers.Reshape([h, w, self.groups, int(c // self.groups)])(x)
        x_transposed = layers.Permute((1, 2, 4, 3))(x_reshaped)
        output = layers.Reshape([h, w, c])(x_transposed)
        return output

    def channel_split(self, x):
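        """Split the input into `groups` equal chunks along the channel axis."""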
        def splitter(y):
            # Slicing the tensor in two separate Lambda layers breaks model
            # saving in some Keras versions, so a single tf.split Lambda is
            # used instead (it returns a list of `groups` tensors).
            return tf.split(y, num_or_size_splits=self.groups, axis=-1)

        return layers.Lambda(splitter)(x)

    def down_sample(self, x, filters):
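        """ENet-style downsampling unit: a strided 3x3 convolution and a 2x2
        max-pool run in parallel and are concatenated, giving exactly
        `filters` output channels at half the spatial resolution."""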
        x_filters = int(x.shape[-1])
        x_conv = layers.Conv2D(filters - x_filters, kernel_size=3, strides=(2, 2), padding='same')(x)
        x_pool = layers.MaxPool2D()(x)
        x = layers.concatenate([x_conv, x_pool], axis=-1)
        x = layers.BatchNormalization()(x)
        x = layers.Activation('relu')(x)
        return x

    def apn_module(self, x):
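        """Attention pyramid network (APN) head.

        A cascade of strided 7x7/5x5/3x3 convolutions builds a coarse-to-fine
        attention map that is multiplied with a 1x1 projection of the input;
        a pooled context branch is then added back.
        """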

        def right(x):
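            # Context branch: 2x average pooling (the original APN uses
            # global pooling), 1x1 projection to `classes` channels, then
            # bilinear upsampling back to the input resolution.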
            x = layers.AveragePooling2D()(x)
            x = layers.Conv2D(self.classes, kernel_size=1, padding='same')(x)
            x = layers.BatchNormalization()(x)
            x = layers.Activation('relu')(x)
            x = layers.UpSampling2D(interpolation='bilinear')(x)
            return x

        def conv(x, filters, kernel_size, stride):
            x = layers.Conv2D(filters, kernel_size=kernel_size, strides=(stride, stride), padding='same')(x)
            x = layers.BatchNormalization()(x)
            x = layers.Activation('relu')(x)
            return x

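        # Pyramid path: three strided convs at 1/2, 1/4 and 1/8 resolution.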
        x_7 = conv(x, int(x.shape[-1]), 7, stride=2)
        x_5 = conv(x_7, int(x.shape[-1]), 5, stride=2)
        x_3 = conv(x_5, int(x.shape[-1]), 3, stride=2)

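        # Coarse-to-fine merge: each level is projected to `classes` channels,
        # upsampled 2x, and added to the next finer level.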
        x_3_1 = conv(x_3, self.classes, 3, stride=1)
        x_3_1_up = layers.UpSampling2D(interpolation='bilinear')(x_3_1)
        x_5_1 = conv(x_5, self.classes, 5, stride=1)
        x_3_5 = layers.add([x_5_1, x_3_1_up])
        x_3_5_up = layers.UpSampling2D(interpolation='bilinear')(x_3_5)
        x_7_1 = conv(x_7, self.classes, 7, stride=1)  # 7x7 to match the 3/5/7 pyramid
        x_3_5_7 = layers.add([x_7_1, x_3_5_up])
        x_3_5_7_up = layers.UpSampling2D(interpolation='bilinear')(x_3_5_7)

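        # Attention: weight a 1x1 projection of the input by the pyramid map.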
        x_middle = conv(x, self.classes, 1, stride=1)
        x_middle = layers.multiply([x_3_5_7_up, x_middle])

        x_right = right(x)
        x_middle = layers.add([x_middle, x_right])
        return x_middle

    def encoder(self, x):
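        """Encoder: three downsampling stages (32, 64 and 128 channels, down
        to 1/8 resolution) interleaved with SS-nbt blocks; the final stage
        uses growing dilation rates to enlarge the receptive field."""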
        x = self.down_sample(x, filters=32)
        for _ in range(3):
            x = self.ss_bt(x, dilation=1)

        x = self.down_sample(x, filters=64)
        for _ in range(2):
            x = self.ss_bt(x, dilation=1)

        x = self.down_sample(x, filters=128)

        dilation_rate = [1, 2, 5, 9, 2, 5, 9, 17]
        for dilation in dilation_rate:
            x = self.ss_bt(x, dilation=dilation)
        return x

    def decoder(self, x):
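        """Decoder: the APN head emits per-class maps at 1/8 resolution,
        which are upsampled 8x and refined by a final conv + softmax."""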
        x = self.apn_module(x)
        x = layers.UpSampling2D(size=8, interpolation='bilinear')(x)
        x = layers.Conv2D(self.classes, kernel_size=3, padding='same')(x)
        x = layers.BatchNormalization()(x)
        x = layers.Activation('softmax')(x)
        return x

    def model(self):
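        """Build and return the complete LEDNet Keras model."""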
        inputs = layers.Input(shape=self.input_shape)
        encoder_out = self.encoder(inputs)
        output = self.decoder(encoder_out)
        return models.Model(inputs, output)


if __name__ == '__main__':
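    # `flops` is a local helper module in this repo for counting model FLOPs.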
    from flops import get_flops

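    # Example: 2 shuffle groups, 3 output classes, 256x256 RGB inputs.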
    model = LEDNet(2, 3, (256, 256, 3)).model()
    model.summary()

    get_flops(model)