from tensorflow.keras.layers import (
    Input,
    GlobalAveragePooling2D,
    Dense,
    UpSampling2D,
    Conv2D,
    BatchNormalization,
    Activation,
    Add,
    Multiply,
    Reshape,
    Lambda
)
from tensorflow.keras.models import Model
from tensorflow.keras.regularizers import l2
import tensorflow as tf


class SE_HRNet(object):
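    """HRNet-style multi-resolution backbone with squeeze-and-excitation blocks.

    Maintains parallel branches at progressively lower resolutions, repeatedly
    fuses them, and collapses them with a downsample-and-add classification head.
    """
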
    def __init__(self, blocks=3, reduction_ratio=4, init_filters=64, expansion=4, training=True):
        self.blocks = blocks
        self.training = training
        self.reduction_ratio = reduction_ratio
        self.init_filters = init_filters
        self.expansion = expansion
        print('Note: the bottleneck expansion factor (`expansion`) is stored but not implemented.')

    def first_layer(self, x, scope):
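        """Stem: two stride-2 3x3 convs that reduce the input resolution by 4x."""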
        with tf.name_scope(scope):
            x = Conv2D(filters=self.init_filters, kernel_size=(3, 3), strides=(2, 2), padding='same', kernel_regularizer=l2(1e-4))(x)
            x = BatchNormalization(axis=-1)(x, training=self.training)
            x = Activation("relu")(x)  # activation between the two stem convs, as in the standard HRNet stem
            x = Conv2D(filters=self.init_filters, kernel_size=(3, 3), strides=(2, 2), padding='same', kernel_regularizer=l2(1e-4))(x)
            x = BatchNormalization(axis=-1)(x, training=self.training)
            return Activation("relu")(x)

    def bottleneck_block(self, x, filters, strides, scope):
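        """1x1 -> 3x3 -> 1x1 conv stack, all `filters` wide (no expansion).

        The final BN output is returned without activation; the caller applies
        SE scaling and the residual add before the ReLU.
        """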
        with tf.name_scope(scope):
            x = Conv2D(filters=filters, kernel_size=(1, 1), strides=(1, 1), padding='same', kernel_regularizer=l2(1e-4))(x)
            x = BatchNormalization(axis=-1)(x, training=self.training)
            x = Activation("relu")(x)

            x = Conv2D(filters=filters, kernel_size=(3, 3), strides=strides, padding='same', kernel_regularizer=l2(1e-4))(x)
            x = BatchNormalization(axis=-1)(x, training=self.training)
            x = Activation("relu")(x)

            x = Conv2D(filters=filters, kernel_size=(1, 1), strides=(1, 1), padding='same', kernel_regularizer=l2(1e-4))(x)
            x = BatchNormalization(axis=-1)(x, training=self.training)

            return x

    def squeeze_excitation_layer(self, input_x, out_dim, ratio, layer_name):
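        """Squeeze-and-excitation: global-pool to a channel descriptor, pass it
        through a bottlenecked pair of Dense layers, and rescale the input
        feature map channel-wise with the resulting sigmoid gates."""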
        with tf.name_scope(layer_name):

            squeeze = GlobalAveragePooling2D()(input_x)

            excitation = Dense(units=out_dim // ratio)(squeeze)  # integer division: Dense units must be an int
            excitation = Activation("relu")(excitation)
            excitation = Dense(units=out_dim)(excitation)
            excitation = Activation("sigmoid")(excitation)
            excitation = Reshape([1, 1, out_dim])(excitation)
            scale = Multiply()([input_x, excitation])

            return scale

    def residual_layer(self, out_dim, scope, first_layer_stride=(2, 2), res_block=None):
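        """A stack of SE bottleneck blocks; the first block downsamples by
        `first_layer_stride` and skips the residual add (shape mismatch)."""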
        if res_block is None:
            res_block = self.blocks

        def f(input_x):
            # transform (bottleneck) -> squeeze-excitation -> residual merge
            for i in range(res_block):
                # Only the first block downsamples; later blocks keep resolution
                strides = first_layer_stride if i == 0 else (1, 1)
                x = self.bottleneck_block(input_x, filters=out_dim, strides=strides, scope='bottleneck_' + str(i))
                x = self.squeeze_excitation_layer(x, out_dim=x.get_shape().as_list()[-1], ratio=self.reduction_ratio, layer_name='squeeze_layer_' + str(i))
                if i != 0:  # skip the residual add on the first block: its stride/filter change makes the shapes unequal
                    x = Add()([input_x, x])
                x = Activation('relu')(x)
                input_x = x
            return input_x
        return f

    def multi_resolution_concat(self, maps, filter_list, scope='multi_resolution_concat'):
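        """HRNet-style fusion: for each target resolution, downsample
        higher-resolution maps with strided 3x3 convs, upsample
        lower-resolution maps with 1x1 projections + 2x upsampling,
        then sum all branches."""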
        fuse_layers = []
        print('Input', maps)
        with tf.name_scope(scope):
            for idx, _ in enumerate(maps):
                fuse_list = []
                for j in range(len(maps)):
                    x = maps[j]
                    if j < idx:
                        # Higher-resolution map: downsample with repeated strided 3x3 convs
                        for k in range(idx - j):
                            x = Conv2D(filter_list[idx], kernel_size=(3, 3), strides=(2, 2), padding='same')(x)
                            x = BatchNormalization(axis=-1)(x, training=self.training)
                            if k == idx - j - 1:
                                x = Activation("relu")(x)
                    elif j == idx:
                        # Same resolution: pass the feature map through unchanged
                        pass
                    else:
                        # Lower-resolution map: project channels with a 1x1 conv,
                        # then upsample 2x until the resolutions match
                        for k in range(j - idx):
                            x = Conv2D(filter_list[idx], kernel_size=(1, 1), strides=(1, 1), padding='same')(x)
                            x = BatchNormalization(axis=-1)(x, training=self.training)
                            x = UpSampling2D(size=(2, 2))(x)
                    fuse_list.append(x)
                    print(idx, j, maps[j])
                    print(filter_list[idx], x)
                if len(fuse_list) > 1:
                    fused = Add()(fuse_list)
                    x = Activation("relu")(fused)
                    fuse_layers.append(x)
                else:
                    fuse_layers.append(fuse_list[0])
            print('Out', fuse_layers)
        return fuse_layers

    def extract_multi_resolution_feature(self, repetitions=3):
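        """Builds `repetitions` stages: each stage appends a new lower-resolution
        branch via residual_layer, then fuses all branches with
        multi_resolution_concat."""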

        def f(input_x):
            x = self.first_layer(input_x, scope='first_layer')
            features = []
            filters = self.init_filters
            self.filter_list = [filters]
            # The stem in first_layer already consumed one downsampling stage
            for i in range(repetitions):
                print('\nBuilding ... %d/%d' % (i + 1, repetitions))
                # Get Downsample
                scope = 'stage_%d' % (i + 1)
                if i == 0:
                    down_x = self.residual_layer(filters, scope=scope, first_layer_stride=(2, 2), res_block=self.blocks)(x)
                else:
                    down_x = self.residual_layer(filters, scope=scope, first_layer_stride=(2, 2), res_block=self.blocks)(features[-1])
                features.append(down_x)
                # Get concatenated feature maps
                out_maps = self.multi_resolution_concat(features, self.filter_list)
                features = []
                print('Identity Mapping:')
                # Identity mapping: pass each fused map through unchanged
                for idx, fm in enumerate(out_maps):
                    x = Lambda(lambda t: t)(fm)
                    print(idx, x)
                    features.append(x)
                filters *= 2
                self.filter_list.append(filters)
            return features

        return f

    def make_classification_head(self, feature_maps, filter_list):
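        """Collapses the multi-resolution maps: repeatedly downsample the
        higher-resolution map by 2x, project it to the next branch's width,
        and add it to the next lower-resolution map."""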
        previous_fm = None
        for idx, fm in enumerate(feature_maps):
            if previous_fm is None:
                previous_fm = fm
                continue
            print(previous_fm.get_shape().as_list(), fm.get_shape().as_list(), filter_list[idx], filter_list)
            x = Conv2D(filter_list[idx], kernel_size=(3, 3), strides=(2, 2), padding='same')(previous_fm)
            x = BatchNormalization(axis=-1)(x, training=self.training)
            x = Activation("relu")(x)
            previous_fm = Add()([fm, x])
        return previous_fm

    def build(self, input_shape, num_output, repetitions=3):
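        """Assembles the full model: multi-resolution backbone, classification
        head, a channel-doubling 1x1 conv, global pooling, and a softmax
        output layer."""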
        input_x = Input(shape=input_shape)

        feature_maps = self.extract_multi_resolution_feature(repetitions=repetitions)(input_x)
        x = self.make_classification_head(feature_maps, self.filter_list)

        x = Conv2D(filters=x.get_shape().as_list()[-1] * 2, kernel_size=(1, 1), strides=(1, 1), padding='same', kernel_regularizer=l2(1e-4))(x)
        x = BatchNormalization(axis=-1)(x, training=self.training)
        x = Activation("relu")(x)
        x = GlobalAveragePooling2D()(x)  # already flat: (batch, channels), so no Flatten is needed

        x = Dense(units=num_output,
                  name='final_fully_connected',
                  kernel_initializer="he_normal",
                  kernel_regularizer=l2(1e-4),
                  activation='softmax')(x)

        return Model(inputs=input_x, outputs=x)


if __name__ == '__main__':
    model = SE_HRNet(blocks=3, reduction_ratio=4, init_filters=32, training=True
                     ).build(input_shape=(2048, 2048, 3), num_output=15, repetitions=4)
    model.summary()
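
    # A minimal training sketch (hypothetical: `x_train`/`y_train` are
    # placeholders for image and one-hot label arrays, not part of this module):
    # model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
    # model.fit(x_train, y_train, batch_size=2, epochs=1)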