"""
    BiSeNet for CelebAMask-HQ, implemented in TensorFlow.
    Original paper: 'BiSeNet: Bilateral Segmentation Network for Real-time Semantic Segmentation,'
    https://arxiv.org/abs/1808.00897.
"""

__all__ = ['BiSeNet', 'bisenet_resnet18_celebamaskhq']

import os
import tensorflow as tf
import tensorflow.keras.layers as nn
from .common import conv1x1, conv1x1_block, conv3x3_block, InterpolationBlock, MultiOutputSequential, get_channel_axis,\
    get_im_size, is_channels_first
from .resnet import resnet18


class PyramidPoolingZeroBranch(nn.Layer):
    """
    Pyramid pooling zero branch.

    Parameters:
    ----------
    in_channels : int
        Number of input channels.
    out_channels : int
        Number of output channels.
    in_size : tuple of 2 int or None
        Spatial size of the output tensor for the upsampling operation (inferred from the input when None).
    data_format : str, default 'channels_last'
        The ordering of the dimensions in tensors.
    """
    def __init__(self,
                 in_channels,
                 out_channels,
                 in_size,
                 data_format="channels_last",
                 **kwargs):
        super(PyramidPoolingZeroBranch, self).__init__(**kwargs)
        self.in_size = in_size
        self.data_format = data_format

        self.pool = nn.GlobalAveragePooling2D(
            data_format=data_format,
            name="pool")
        self.conv = conv1x1_block(
            in_channels=in_channels,
            out_channels=out_channels,
            data_format=data_format,
            name="conv")
        self.up = InterpolationBlock(
            scale_factor=None,
            interpolation="bilinear",
            data_format=data_format,
            name="up")

    def call(self, x, training=None):
        in_size = self.in_size if self.in_size is not None else get_im_size(x, data_format=self.data_format)
        x = self.pool(x)
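        # Global average pooling collapses the spatial dims to (batch, channels); two singleton
        # spatial dims are restored so the 1x1 conv and the upsampling see a 4D tensor
        # ((B, C, 1, 1) for channels_first, (B, 1, 1, C) for channels_last).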
        axis = -1 if is_channels_first(self.data_format) else 1
        x = tf.expand_dims(tf.expand_dims(x, axis=axis), axis=axis)
        x = self.conv(x, training=training)
        x = self.up(x, size=in_size)
        return x


class AttentionRefinementBlock(nn.Layer):
    """
    Attention refinement block.

    Parameters:
    ----------
    in_channels : int
        Number of input channels.
    out_channels : int
        Number of output channels.
    data_format : str, default 'channels_last'
        The ordering of the dimensions in tensors.
    """
    def __init__(self,
                 in_channels,
                 out_channels,
                 data_format="channels_last",
                 **kwargs):
        super(AttentionRefinementBlock, self).__init__(**kwargs)
        self.data_format = data_format

        self.conv1 = conv3x3_block(
            in_channels=in_channels,
            out_channels=out_channels,
            data_format=data_format,
            name="conv1")
        self.pool = nn.GlobalAveragePooling2D(
            data_format=data_format,
            name="pool")
        self.conv2 = conv1x1_block(
            in_channels=out_channels,
            out_channels=out_channels,
            activation="sigmoid",
            data_format=data_format,
            name="conv2")

    def call(self, x, training=None):
        x = self.conv1(x, training=training)
        w = self.pool(x)
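        # Same squeeze-and-expand pattern as in PyramidPoolingZeroBranch; the sigmoid-activated
        # 1x1 conv then yields per-channel attention weights that gate the feature map.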
        axis = -1 if is_channels_first(self.data_format) else 1
        w = tf.expand_dims(tf.expand_dims(w, axis=axis), axis=axis)
        w = self.conv2(w, training=training)
        x = x * w
        return x


class PyramidPoolingMainBranch(nn.Layer):
    """
    Pyramid pooling main branch.

    Parameters:
    ----------
    in_channels : int
        Number of input channels.
    out_channels : int
        Number of output channels.
    scale_factor : float
        Multiplier for spatial size.
    data_format : str, default 'channels_last'
        The ordering of the dimensions in tensors.
    """
    def __init__(self,
                 in_channels,
                 out_channels,
                 scale_factor,
                 data_format="channels_last",
                 **kwargs):
        super(PyramidPoolingMainBranch, self).__init__(**kwargs)
        self.att = AttentionRefinementBlock(
            in_channels=in_channels,
            out_channels=out_channels,
            data_format=data_format,
            name="att")
        self.up = InterpolationBlock(
            scale_factor=scale_factor,
            interpolation="bilinear",
            data_format=data_format,
            name="up")
        self.conv = conv3x3_block(
            in_channels=out_channels,
            out_channels=out_channels,
            data_format=data_format,
            name="conv")

    def call(self, x, y, training=None):
        x = self.att(x, training=training)
        x = x + y
        x = self.up(x)
        x = self.conv(x, training=training)
        return x


class FeatureFusion(nn.Layer):
    """
    Feature fusion block.

    Parameters:
    ----------
    in_channels : int
        Number of input channels.
    out_channels : int
        Number of output channels.
    reduction : int, default 4
        Squeeze reduction value.
    data_format : str, default 'channels_last'
        The ordering of the dimensions in tensors.
    """
    def __init__(self,
                 in_channels,
                 out_channels,
                 reduction=4,
                 data_format="channels_last",
                 **kwargs):
        super(FeatureFusion, self).__init__(**kwargs)
        self.data_format = data_format
        mid_channels = out_channels // reduction

        self.conv_merge = conv1x1_block(
            in_channels=in_channels,
            out_channels=out_channels,
            data_format=data_format,
            name="conv_merge")
        self.pool = nn.GlobalAveragePooling2D(
            data_format=data_format,
            name="pool")
        self.conv1 = conv1x1(
            in_channels=out_channels,
            out_channels=mid_channels,
            data_format=data_format,
            name="conv1")
        self.activ = nn.ReLU()
        self.conv2 = conv1x1(
            in_channels=mid_channels,
            out_channels=out_channels,
            data_format=data_format,
            name="conv2")
        self.sigmoid = tf.nn.sigmoid

    def call(self, x, y, training=None):
        x = tf.concat([x, y], axis=get_channel_axis(self.data_format))
        x = self.conv_merge(x, training=training)
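        # Squeeze-and-excitation style gating: per-channel weights re-scale the merged features,
        # and the result is added back as a residual (x + x * w).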
        w = self.pool(x)
        axis = -1 if is_channels_first(self.data_format) else 1
        w = tf.expand_dims(tf.expand_dims(w, axis=axis), axis=axis)
        w = self.conv1(w)
        w = self.activ(w)
        w = self.conv2(w)
        w = self.sigmoid(w)
        x_att = x * w
        x = x + x_att
        return x


class PyramidPooling(nn.Layer):
    """
    Pyramid Pooling module.

    Parameters:
    ----------
    x16_in_channels : int
        Number of input channels for x16.
    x32_in_channels : int
        Number of input channels for x32.
    y_out_channels : int
        Number of output channels for y-outputs.
    y32_out_size : tuple of 2 int or None
        Spatial size of the y32 tensor (inferred at runtime when None).
    data_format : str, default 'channels_last'
        The ordering of the dimensions in tensors.
    """
    def __init__(self,
                 x16_in_channels,
                 x32_in_channels,
                 y_out_channels,
                 y32_out_size,
                 data_format="channels_last",
                 **kwargs):
        super(PyramidPooling, self).__init__(**kwargs)
        z_out_channels = 2 * y_out_channels
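        # The fusion block concatenates the spatial-path features (x8, expected to carry
        # y_out_channels channels) with the context-path features (y8), doubling the channel count.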

        self.pool32 = PyramidPoolingZeroBranch(
            in_channels=x32_in_channels,
            out_channels=y_out_channels,
            in_size=y32_out_size,
            data_format=data_format,
            name="pool32")
        self.pool16 = PyramidPoolingMainBranch(
            in_channels=x32_in_channels,
            out_channels=y_out_channels,
            scale_factor=2,
            data_format=data_format,
            name="pool16")
        self.pool8 = PyramidPoolingMainBranch(
            in_channels=x16_in_channels,
            out_channels=y_out_channels,
            scale_factor=2,
            data_format=data_format,
            name="pool8")
        self.fusion = FeatureFusion(
            in_channels=z_out_channels,
            out_channels=z_out_channels,
            data_format=data_format,
            name="fusion")

    def call(self, x8, x16, x32, training=None):
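        # Top-down context aggregation: global context (y32) is combined with the
        # attention-refined x32 and x16 features and progressively upsampled, then fused
        # with the high-resolution x8 features.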
        y32 = self.pool32(x32, training=training)
        y16 = self.pool16(x32, y32, training=training)
        y8 = self.pool8(x16, y16, training=training)
        z8 = self.fusion(x8, y8, training=training)
        return z8, y8, y16


class BiSeHead(nn.Layer):
    """
    BiSeNet head (final) block.

    Parameters:
    ----------
    in_channels : int
        Number of input channels.
    mid_channels : int
        Number of middle channels.
    out_channels : int
        Number of output channels.
    data_format : str, default 'channels_last'
        The ordering of the dimensions in tensors.
    """
    def __init__(self,
                 in_channels,
                 mid_channels,
                 out_channels,
                 data_format="channels_last",
                 **kwargs):
        super(BiSeHead, self).__init__(**kwargs)
        self.conv1 = conv3x3_block(
            in_channels=in_channels,
            out_channels=mid_channels,
            data_format=data_format,
            name="conv1")
        self.conv2 = conv1x1(
            in_channels=mid_channels,
            out_channels=out_channels,
            data_format=data_format,
            name="conv2")

    def call(self, x, training=None):
        x = self.conv1(x, training=training)
        x = self.conv2(x)
        return x


class BiSeNet(tf.keras.Model):
    """
    BiSeNet model from 'BiSeNet: Bilateral Segmentation Network for Real-time Semantic Segmentation,'
    https://arxiv.org/abs/1808.00897.

    Parameters:
    ----------
    backbone : func -> (nn.Sequential, list of int)
        Feature extractor constructor, returning the feature module and its stage output channel counts.
    aux : bool, default True
        Whether to output auxiliary results.
    fixed_size : bool, default True
        Whether to expect fixed spatial size of input image.
    in_channels : int, default 3
        Number of input channels.
    in_size : tuple of 2 int, default (640, 480)
        Spatial size of the expected input image.
    classes : int, default 19
        Number of segmentation classes.
    data_format : str, default 'channels_last'
        The ordering of the dimensions in tensors.
    """
    def __init__(self,
                 backbone,
                 aux=True,
                 fixed_size=True,
                 in_channels=3,
                 in_size=(640, 480),
                 classes=19,
                 data_format="channels_last",
                 **kwargs):
        super(BiSeNet, self).__init__(**kwargs)
        assert (in_channels == 3)
        self.in_size = in_size
        self.classes = classes
        self.data_format = data_format
        self.aux = aux
        self.fixed_size = fixed_size

        self.backbone, backbone_out_channels = backbone(
            data_format=data_format,
            name="backbone")

        y_out_channels = backbone_out_channels[0]
        z_out_channels = 2 * y_out_channels
        y32_out_size = (self.in_size[0] // 32, self.in_size[1] // 32) if fixed_size else None
        self.pool = PyramidPooling(
            x16_in_channels=backbone_out_channels[1],
            x32_in_channels=backbone_out_channels[2],
            y_out_channels=y_out_channels,
            y32_out_size=y32_out_size,
            data_format=data_format,
            name="pool")
        self.head_z8 = BiSeHead(
            in_channels=z_out_channels,
            mid_channels=z_out_channels,
            out_channels=classes,
            data_format=data_format,
            name="head_z8")
        self.up8 = InterpolationBlock(
            scale_factor=(8 if fixed_size else None),
            data_format=data_format,
            name="up8")

        if self.aux:
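            # Auxiliary heads on the intermediate y8/y16 features provide deep supervision
            # during training; at inference only the main z8 output is typically used.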
            mid_channels = y_out_channels // 2
            self.head_y8 = BiSeHead(
                in_channels=y_out_channels,
                mid_channels=mid_channels,
                out_channels=classes,
                data_format=data_format,
                name="head_y8")
            self.head_y16 = BiSeHead(
                in_channels=y_out_channels,
                mid_channels=mid_channels,
                out_channels=classes,
                data_format=data_format,
                name="head_y16")
            self.up16 = InterpolationBlock(
                scale_factor=(16 if fixed_size else None),
                data_format=data_format,
                name="up16")

    def call(self, x, training=None):
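        # The backbone downsamples by a factor of 32, so both spatial dims of the input
        # must be multiples of 32.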
        assert is_channels_first(self.data_format) or ((x.shape[1] % 32 == 0) and (x.shape[2] % 32 == 0))
        assert (not is_channels_first(self.data_format)) or ((x.shape[2] % 32 == 0) and (x.shape[3] % 32 == 0))

        x8, x16, x32 = self.backbone(x, training=training)
        z8, y8, y16 = self.pool(x8, x16, x32, training=training)

        z8 = self.head_z8(z8, training=training)
        z8 = self.up8(z8)

        if self.aux:
            y8 = self.head_y8(y8, training=training)
            y16 = self.head_y16(y16, training=training)
            y8 = self.up8(y8)
            y16 = self.up16(y16)
            return z8, y8, y16
        else:
            return z8


def get_bisenet(model_name=None,
                pretrained=False,
                root=os.path.join("~", ".tensorflow", "models"),
                **kwargs):
    """
    Create BiSeNet model with specific parameters.

    Parameters:
    ----------
    model_name : str or None, default None
        Model name for loading pretrained model.
    pretrained : bool, default False
        Whether to load the pretrained weights for model.
    root : str, default '~/.tensorflow/models'
        Location for keeping the model parameters.
    """
    net = BiSeNet(
        **kwargs)

    if pretrained:
        if not model_name:
            raise ValueError("Parameter `model_name` should be properly initialized for loading pretrained model.")
        from .model_store import get_model_file
        in_channels = kwargs.get("in_channels", 3)
        input_shape = (1,) + (in_channels,) + net.in_size if net.data_format == "channels_first" else\
            (1,) + net.in_size + (in_channels,)
        net.build(input_shape=input_shape)
        net.load_weights(
            filepath=get_model_file(
                model_name=model_name,
                local_model_store_dir_path=root))

    return net


def bisenet_resnet18_celebamaskhq(pretrained_backbone=False, classes=19, **kwargs):
    """
    BiSeNet model with a ResNet-18 backbone for face segmentation on CelebAMask-HQ, from 'BiSeNet: Bilateral
    Segmentation Network for Real-time Semantic Segmentation,' https://arxiv.org/abs/1808.00897.

    Parameters:
    ----------
    pretrained_backbone : bool, default False
        Whether to load the pretrained weights for feature extractor.
    classes : int, default 19
        Number of classes.
    pretrained : bool, default False
        Whether to load the pretrained weights for model.
    root : str, default '~/.tensorflow/models'
        Location for keeping the model parameters.
    """
    def backbone(**bb_kwargs):
        features_raw = resnet18(pretrained=pretrained_backbone, **bb_kwargs).features
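        # Drop the final global pooling layer; only the intermediate feature maps are needed.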
        features_raw._layers.pop()
        features = MultiOutputSequential(return_last=False, name="backbone")
        features.add(features_raw._layers[0])
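        # Mark every residual stage except the first as an output so that the backbone
        # returns the x8, x16 and x32 feature maps.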
        for i, stage in enumerate(features_raw._layers[1:]):
            if i != 0:
                stage.do_output = True
            features.add(stage)
        out_channels = [128, 256, 512]
        return features, out_channels
    return get_bisenet(backbone=backbone, classes=classes, model_name="bisenet_resnet18_celebamaskhq", **kwargs)


def _test():
    import numpy as np
    import tensorflow.keras.backend as K

    data_format = "channels_last"
    # data_format = "channels_first"
    in_size = (640, 480)
    aux = True
    pretrained = False

    models = [
        bisenet_resnet18_celebamaskhq,
    ]

    for model in models:
        net = model(pretrained=pretrained, in_size=in_size, aux=aux, data_format=data_format)

        batch = 14
        x = tf.random.normal((batch, 3, in_size[0], in_size[1]) if is_channels_first(data_format) else
                             (batch, in_size[0], in_size[1], 3))
        ys = net(x)
        y = ys[0] if aux else ys
        assert (y.shape[0] == x.shape[0])
        if is_channels_first(data_format):
            assert ((y.shape[1] == 19) and (y.shape[2] == x.shape[2]) and (y.shape[3] == x.shape[3]))
        else:
            assert ((y.shape[3] == 19) and (y.shape[1] == x.shape[1]) and (y.shape[2] == x.shape[2]))

        weight_count = sum([np.prod(K.get_value(w).shape) for w in net.trainable_weights])
        print("m={}, {}".format(model.__name__, weight_count))
        if aux:
            assert (model != bisenet_resnet18_celebamaskhq or weight_count == 13300416)
        else:
            assert (model != bisenet_resnet18_celebamaskhq or weight_count == 13150272)


if __name__ == "__main__":
    _test()