from tensorflow.keras.models import Model
from tensorflow.keras.layers import Lambda, Activation, Conv2D, MaxPooling2D, Add, Input, BatchNormalization, UpSampling2D, Concatenate
from tensorflow.keras.layers import concatenate, add
from tensorflow.keras.regularizers import l2


TOP_DOWN_PYRAMID_SIZE = 256

"""
Implementation of Resnext FPN 
"""


def resnext_fpn(input_shape, nb_labels, depth=(3, 4, 6, 3), cardinality=32, width=4, weight_decay=5e-4, batch_norm=True,
                batch_momentum=0.9):
    """
    TODO: add dilated convolutions as well
    Resnext-50 is defined by (3, 4, 6, 3) [default]
    Resnext-101 is defined by (3, 4, 23, 3)
    Resnext-152 is defined by (3, 8, 23, 3)
    :param input_shape:
    :param nb_labels:
    :param depth:
    :param cardinality:
    :param width:
    :param weight_decay:
    :param batch_norm:
    :param batch_momentum:
    :return:
    """
    nb_rows, nb_cols, _ = input_shape
    input_tensor = Input(shape=input_shape)

    bn_axis = 3
    x = Conv2D(64, (7, 7), strides=(2, 2), padding='same', name='conv1', kernel_regularizer=l2(weight_decay))(input_tensor)
    if batch_norm:
        x = BatchNormalization(axis=bn_axis, name='bn_conv1', momentum=batch_momentum)(x)
    x = Activation('relu')(x)
    x = MaxPooling2D((3, 3), strides=(2, 2), padding='same')(x)
    stage_1 = x

    # filters are cardinality * width * 2 for each depth level
    for i in range(depth[0]):
        x = bottleneck_block(x, 128, cardinality, strides=1, weight_decay=weight_decay)
    stage_2 = x

    # this can be done with a for loop but is more explicit this way
    x = bottleneck_block(x, 256, cardinality, strides=2, weight_decay=weight_decay)
    for idx in range(1, depth[1]):
        x = bottleneck_block(x, 256, cardinality, strides=1, weight_decay=weight_decay)
    stage_3 = x

    x = bottleneck_block(x, 512, cardinality, strides=2, weight_decay=weight_decay)
    for idx in range(1, depth[2]):
        x = bottleneck_block(x, 512, cardinality, strides=1, weight_decay=weight_decay)
    stage_4 = x

    x = bottleneck_block(x, 1024, cardinality, strides=2, weight_decay=weight_decay)
    for idx in range(1, depth[3]):
        x = bottleneck_block(x, 1024, cardinality, strides=1, weight_decay=weight_decay)
    stage_5 = x

    P5 = Conv2D(TOP_DOWN_PYRAMID_SIZE, (1, 1), name='fpn_c5p5')(stage_5)
    P4 = Add(name="fpn_p4add")([UpSampling2D(size=(2, 2), name="fpn_p5upsampled")(P5),
                                Conv2D(TOP_DOWN_PYRAMID_SIZE, (1, 1), name='fpn_c4p4', padding='same')(stage_4)])
    P3 = Add(name="fpn_p3add")([UpSampling2D(size=(2, 2), name="fpn_p4upsampled")(P4),
                                Conv2D(TOP_DOWN_PYRAMID_SIZE, (1, 1), name='fpn_c3p3')(stage_3)])
    P2 = Add(name="fpn_p2add")([UpSampling2D(size=(2, 2), name="fpn_p3upsampled")(P3),
                                Conv2D(TOP_DOWN_PYRAMID_SIZE, (1, 1), name='fpn_c2p2', padding='same')(stage_2)])
    # Attach 3x3 conv to all P layers to get the final feature maps. --> Reduce aliasing effect of upsampling
    P2 = Conv2D(TOP_DOWN_PYRAMID_SIZE, (3, 3), padding="SAME", name="fpn_p2")(P2)
    P3 = Conv2D(TOP_DOWN_PYRAMID_SIZE, (3, 3), padding="SAME", name="fpn_p3")(P3)
    P4 = Conv2D(TOP_DOWN_PYRAMID_SIZE, (3, 3), padding="SAME", name="fpn_p4")(P4)
    P5 = Conv2D(TOP_DOWN_PYRAMID_SIZE, (3, 3), padding="SAME", name="fpn_p5")(P5)

    head1 = Conv2D(TOP_DOWN_PYRAMID_SIZE // 2, (3, 3), padding="SAME", name="head1_conv")(P2)
    head1 = Conv2D(TOP_DOWN_PYRAMID_SIZE // 2, (3, 3), padding="SAME", name="head1_conv_2")(head1)

    head2 = Conv2D(TOP_DOWN_PYRAMID_SIZE // 2, (3, 3), padding="SAME", name="head2_conv")(P3)
    head2 = Conv2D(TOP_DOWN_PYRAMID_SIZE // 2, (3, 3), padding="SAME", name="head2_conv_2")(head2)

    head3 = Conv2D(TOP_DOWN_PYRAMID_SIZE // 2, (3, 3), padding="SAME", name="head3_conv")(P4)
    head3 = Conv2D(TOP_DOWN_PYRAMID_SIZE // 2, (3, 3), padding="SAME", name="head3_conv_2")(head3)

    head4 = Conv2D(TOP_DOWN_PYRAMID_SIZE // 2, (3, 3), padding="SAME", name="head4_conv")(P5)
    head4 = Conv2D(TOP_DOWN_PYRAMID_SIZE // 2, (3, 3), padding="SAME", name="head4_conv_2")(head4)

    f_p2 = UpSampling2D(size=(8, 8), name="pre_cat_2")(head4)
    f_p3 = UpSampling2D(size=(4, 4), name="pre_cat_3")(head3)
    f_p4 = UpSampling2D(size=(2, 2), name="pre_cat_4")(head2)
    f_p5 = head1

    x = Concatenate(axis=-1)([f_p2, f_p3, f_p4, f_p5])
    x = Conv2D(nb_labels, (3, 3), padding="SAME", name="final_conv", kernel_initializer='he_normal',
               activation='linear')(x)
    x = UpSampling2D(size=(4, 4), name="final_upsample")(x)
    x = Activation('sigmoid')(x)

    model = Model(input_tensor, x)

    return model


def grouped_convolution_block(input, grouped_channels, cardinality, strides, weight_decay=5e-4):
    init = input
    group_list = []

    if cardinality == 1:
        # with cardinality 1, it is a standard convolution
        x = Conv2D(grouped_channels, (3, 3), padding='same', use_bias=False, strides=(strides, strides),
                   kernel_initializer='he_normal', kernel_regularizer=l2(weight_decay))(init)
        x = BatchNormalization(axis=3)(x)
        x = Activation('relu')(x)
        return x

    for c in range(cardinality):
        x = Lambda(lambda z: z[:, :, :, c * grouped_channels:(c + 1) * grouped_channels])(input)
        x = Conv2D(grouped_channels, (3, 3), padding='same', use_bias=False, strides=(strides, strides),
                   kernel_initializer='he_normal', kernel_regularizer=l2(weight_decay))(x)
        group_list.append(x)

    group_merge = concatenate(group_list, axis=3)
    x = BatchNormalization(axis=3)(group_merge)
    x = Activation('relu')(x)
    return x


def bottleneck_block(input, filters=64, cardinality=8, strides=1, weight_decay=5e-4):
    init = input
    grouped_channels = int(filters / cardinality)

    if init.shape[-1] != 2 * filters:
        init = Conv2D(filters * 2, (1, 1), padding='same', strides=(strides, strides),
                      use_bias=False, kernel_initializer='he_normal', kernel_regularizer=l2(weight_decay))(init)
        init = BatchNormalization(axis=3)(init)

    x = Conv2D(filters, (1, 1), padding='same', use_bias=False,
               kernel_initializer='he_normal', kernel_regularizer=l2(weight_decay))(input)
    x = BatchNormalization(axis=3)(x)
    x = Activation('relu')(x)

    x = grouped_convolution_block(x, grouped_channels, cardinality, strides, weight_decay)
    x = Conv2D(filters * 2, (1, 1), padding='same', use_bias=False, kernel_initializer='he_normal',
               kernel_regularizer=l2(weight_decay))(x)
    x = BatchNormalization(axis=3)(x)

    x = add([init, x])
    x = Activation('relu')(x)
    return x


A=resnext_fpn((256,256,3), 10, depth=(3, 4, 6, 3), cardinality=32, width=4, weight_decay=5e-4, batch_norm=True,
                batch_momentum=0.9)
A.summary()