from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, UpSampling2D, SpatialDropout2D, concatenate
from tensorflow.keras.models import Model

from groupnorm import GroupNormalization


def down_stage(inputs, filters, kernel_size=3,
               activation="relu", padding="SAME"):
    """Build one encoder stage: two conv + group-norm pairs, then pooling.

    Args:
        inputs: Tensor fed into the first convolution.
        filters: Number of output channels of both convolutions.
        kernel_size: Kernel size of both convolutions.
        activation: Activation applied inside each convolution.
        padding: Padding mode of both convolutions.

    Returns:
        A ``(features, pooled)`` tuple: the normalized output of the second
        convolution, and the same tensor after ``MaxPooling2D``.
    """
    features = inputs
    # Two identical conv -> group-norm pairs, applied back to back.
    for _ in range(2):
        features = Conv2D(filters, kernel_size,
                          activation=activation, padding=padding)(features)
        features = GroupNormalization()(features)

    pooled = MaxPooling2D()(features)
    return features, pooled


def up_stage(inputs, skip, filters, kernel_size=3,
             activation="relu", padding="SAME"):
    """Build one decoder stage: upsample, merge with a skip tensor, refine.

    Args:
        inputs: Tensor coming from the previous (coarser) stage.
        skip: Tensor to concatenate with the upsampled input.
        filters: Number of output channels of every convolution in the stage.
        kernel_size: Kernel size of the two refinement convolutions.
        activation: Activation applied inside each convolution.
        padding: Padding mode of every convolution in the stage.

    Returns:
        The stage output after spatial dropout.
    """
    # Upsample, then a kernel-size-2 convolution to set the channel count.
    upsampled = UpSampling2D()(inputs)
    upsampled = Conv2D(filters, 2,
                       activation=activation, padding=padding)(upsampled)
    upsampled = GroupNormalization()(upsampled)

    # Skip tensor goes first in the concatenation.
    features = concatenate([skip, upsampled])
    features = GroupNormalization()(features)

    # Two identical conv -> group-norm pairs refine the merged features.
    for _ in range(2):
        features = Conv2D(filters, kernel_size,
                          activation=activation, padding=padding)(features)
        features = GroupNormalization()(features)

    # Called with training=True, so dropout is applied regardless of the
    # mode the surrounding model is run in.
    return SpatialDropout2D(0.5)(features, training=True)


def end_stage(inputs, kernel_size=3, activation="relu", padding="SAME"):
    """Build the output head that reduces the features to a single channel.

    Args:
        inputs: Tensor fed into the first convolution.
        kernel_size: Kernel size of the first convolution.
        activation: Activation of the first convolution.
        padding: Padding mode of the first convolution.

    Returns:
        The one-channel output of the final 1x1 sigmoid convolution.
    """
    # Honour the caller's padding choice; it used to be hard-coded to "SAME",
    # which silently ignored the ``padding`` argument.
    conv = Conv2D(1, kernel_size,
                  activation=activation, padding=padding)(inputs)
    conv = Conv2D(1, 1, activation="sigmoid")(conv)

    return conv


def dropout_unet(input_shape=(280, 280, 1), kernel_size=3,
                 activation="relu", padding="SAME", **kwargs):
    """Assemble a U-Net whose spatial dropout layers are always active.

    The encoder runs four ``down_stage`` blocks (16, 32, 64, 128 filters),
    the decoder three ``up_stage`` blocks (64, 32, 16 filters) that reuse the
    encoder outputs as skip connections, and ``end_stage`` produces the
    final output.

    Args:
        input_shape: Shape passed to the Keras ``Input`` layer.
        kernel_size: Kernel size forwarded to every stage.
        activation: Activation forwarded to every stage.
        padding: Padding mode forwarded to every stage.
        **kwargs: Accepted but not used by this function.

    Returns:
        A Keras ``Model`` mapping ``inputs`` to the ``end_stage`` output.
    """
    stage_options = dict(kernel_size=kernel_size,
                         activation=activation,
                         padding=padding)

    inputs = Input(input_shape)

    # Encoder: keep each stage's pre-pooling features for the skip links.
    skips = []
    x = inputs
    for filters in (16, 32, 64):
        features, x = down_stage(x, filters, **stage_options)
        skips.append(features)

    # Deepest stage: its pooled output is discarded, only the features are
    # used, followed by a dropout that is applied in every mode.
    bottleneck, _ = down_stage(x, 128, **stage_options)
    x = SpatialDropout2D(0.5)(bottleneck, training=True)

    # Decoder: pair each stage with the matching encoder output, deepest
    # skip first.
    for filters, skip in zip((64, 32, 16), reversed(skips)):
        x = up_stage(x, skip, filters, **stage_options)

    outputs = end_stage(x, **stage_options)

    return Model(inputs=inputs, outputs=outputs)