# -*- coding: utf-8 -*-
'''ResNet50 model for Keras.
# Reference:
- [Deep Residual Learning for Image Recognition](https://arxiv.org/abs/1512.03385)
Adapted from code contributed by BigMoyan.
'''
from __future__ import print_function
from __future__ import absolute_import

from keras.layers import merge, Input
from keras.layers import Dense, Activation, Flatten
from keras.layers import Convolution2D, MaxPooling2D, ZeroPadding2D, AveragePooling2D, TimeDistributed
from keras import backend as K

from keras_frcnn.RoiPoolingConv import RoiPoolingConv
from keras_frcnn.FixedBatchNormalization import FixedBatchNormalization

import h5py

# NOTE(review): this constant is not referenced anywhere in the visible part
# of this file. Presumably it was meant as the Keras 1.x BatchNormalization
# `mode` value - confirm whether anything still reads it before relying on
# it or removing it.
bn_mode = 0

def identity_block(input_tensor, kernel_size, filters, stage, block, trainable=True):
    '''Build a residual block whose shortcut is the unmodified input.

    The main path stacks three Convolution2D layers (1x1, then
    kernel_size x kernel_size with border_mode='same', then 1x1). Each one
    is followed by a FixedBatchNormalization layer that is always created
    with trainable=False. The main path output is summed with input_tensor
    and passed through a final ReLU.

    # Arguments
            input_tensor: input tensor
            kernel_size: kernel size of the middle conv layer on the main path
            filters: sequence of three integers, the nb_filters of the three
                conv layers on the main path
            stage: integer, current stage label, used for generating layer names
            block: 'a','b'..., current block label, used for generating layer names
            trainable: value forwarded as `trainable` to every Convolution2D

    # Returns
            Output tensor of the block.
    '''
    nb_filter1, nb_filter2, nb_filter3 = filters
    bn_axis = 3 if K.image_dim_ordering() == 'tf' else 1

    # Layers are named 'res<stage><block>_branch<tag>' / 'bn<stage><block>_branch<tag>'.
    name_suffix = str(stage) + block + '_branch'
    conv_prefix = 'res' + name_suffix
    bn_prefix = 'bn' + name_suffix

    # (branch tag, nb_filter, kernel size, extra conv kwargs, ReLU after BN?)
    main_path = (
        ('2a', nb_filter1, 1, {}, True),
        ('2b', nb_filter2, kernel_size, {'border_mode': 'same'}, True),
        # The last ReLU is applied only after the shortcut has been added.
        ('2c', nb_filter3, 1, {}, False),
    )

    out = input_tensor
    for tag, nb_filter, size, conv_kwargs, apply_relu in main_path:
        out = Convolution2D(nb_filter, size, size, name=conv_prefix + tag,
                            trainable=trainable, **conv_kwargs)(out)
        out = FixedBatchNormalization(trainable=False, axis=bn_axis,
                                      name=bn_prefix + tag)(out)
        if apply_relu:
            out = Activation('relu')(out)

    out = merge([out, input_tensor], mode='sum')
    return Activation('relu')(out)


def identity_block_td(input_tensor, kernel_size, filters, stage, block, trainable=True):
    '''TimeDistributed variant of the identity block (no conv on the shortcut).

    Every Convolution2D (created with init='normal') and every
    FixedBatchNormalization (always created with trainable=False) is wrapped
    in a TimeDistributed layer; the layer name is given to the wrapper, not
    to the wrapped layer. The main path output is summed with input_tensor
    and passed through a final ReLU.

    # Arguments
            input_tensor: input tensor
            kernel_size: kernel size of the middle conv layer on the main path
            filters: sequence of three integers, the nb_filters of the three
                conv layers on the main path
            stage: integer, current stage label, used for generating layer names
            block: 'a','b'..., current block label, used for generating layer names
            trainable: value forwarded as `trainable` to every Convolution2D

    # Returns
            Output tensor of the block.
    '''
    nb_filter1, nb_filter2, nb_filter3 = filters

    bn_axis = 1
    if K.image_dim_ordering() == 'tf':
        bn_axis = 3

    label = str(stage) + block + '_branch'

    def conv_bn(tensor, nb_filter, size, branch, **conv_kwargs):
        # One time-distributed convolution followed by its fixed batch norm.
        conv = Convolution2D(nb_filter, size, size, trainable=trainable,
                             init='normal', **conv_kwargs)
        tensor = TimeDistributed(conv, name='res' + label + branch)(tensor)
        norm = FixedBatchNormalization(trainable=False, axis=bn_axis)
        return TimeDistributed(norm, name='bn' + label + branch)(tensor)

    out = conv_bn(input_tensor, nb_filter1, 1, '2a')
    out = Activation('relu')(out)

    out = conv_bn(out, nb_filter2, kernel_size, '2b', border_mode='same')
    out = Activation('relu')(out)

    # No ReLU here: it is applied after the shortcut has been added.
    out = conv_bn(out, nb_filter3, 1, '2c')

    out = merge([out, input_tensor], mode='sum')
    return Activation('relu')(out)

def conv_block(input_tensor, kernel_size, filters, stage, block, strides=(2, 2), trainable=True):
    '''Build a residual block that has a conv layer on the shortcut.

    The main path stacks three Convolution2D layers (1x1 with
    subsample=strides, then kernel_size x kernel_size with
    border_mode='same', then 1x1). The shortcut is a 1x1 Convolution2D with
    subsample=strides and nb_filter3 filters. Every convolution is followed
    by a FixedBatchNormalization layer that is always created with
    trainable=False. Main path and shortcut are summed and passed through a
    final ReLU.

    # Arguments
            input_tensor: input tensor
            kernel_size: kernel size of the middle conv layer on the main path
            filters: sequence of three integers, the nb_filters of the three
                conv layers on the main path
            stage: integer, current stage label, used for generating layer names
            block: 'a','b'..., current block label, used for generating layer names
            strides: subsample used by the first main-path conv and by the
                shortcut conv
            trainable: value forwarded as `trainable` to every Convolution2D

    # Returns
            Output tensor of the block.
    '''
    nb_filter1, nb_filter2, nb_filter3 = filters
    bn_axis = 3 if K.image_dim_ordering() == 'tf' else 1
    label = str(stage) + block + '_branch'

    def conv(nb_filter, size, branch, **kwargs):
        # Named 'res<stage><block>_branch<branch>'.
        return Convolution2D(nb_filter, size, size, name='res' + label + branch,
                             trainable=trainable, **kwargs)

    def batch_norm(branch):
        # Named 'bn<stage><block>_branch<branch>'.
        return FixedBatchNormalization(trainable=False, axis=bn_axis,
                                       name='bn' + label + branch)

    # Main path.
    main = conv(nb_filter1, 1, '2a', subsample=strides)(input_tensor)
    main = batch_norm('2a')(main)
    main = Activation('relu')(main)

    main = conv(nb_filter2, kernel_size, '2b', border_mode='same')(main)
    main = batch_norm('2b')(main)
    main = Activation('relu')(main)

    main = conv(nb_filter3, 1, '2c')(main)
    main = batch_norm('2c')(main)

    # Shortcut path: uses the same subsample as the first main-path conv.
    residual = conv(nb_filter3, 1, '1', subsample=strides)(input_tensor)
    residual = batch_norm('1')(residual)

    merged = merge([main, residual], mode='sum')
    return Activation('relu')(merged)

def conv_block_td(input_tensor, kernel_size, filters, stage, block, strides=(2, 2), trainable=True):
    '''TimeDistributed variant of conv_block (conv layer on the shortcut).

    Every Convolution2D (created with init='normal') and every
    FixedBatchNormalization (always created with trainable=False) is wrapped
    in a TimeDistributed layer; the layer name is given to the wrapper, not
    to the wrapped layer. The first main-path conv and the shortcut conv
    both use subsample=strides. Main path and shortcut are summed and passed
    through a final ReLU.

    # Arguments
            input_tensor: input tensor
            kernel_size: kernel size of the middle conv layer on the main path
            filters: sequence of three integers, the nb_filters of the three
                conv layers on the main path
            stage: integer, current stage label, used for generating layer names
            block: 'a','b'..., current block label, used for generating layer names
            strides: subsample used by the first main-path conv and by the
                shortcut conv
            trainable: value forwarded as `trainable` to every Convolution2D

    # Returns
            Output tensor of the block.
    '''
    nb_filter1, nb_filter2, nb_filter3 = filters
    if K.image_dim_ordering() == 'tf':
        bn_axis = 3
    else:
        bn_axis = 1

    conv_name_base = 'res' + str(stage) + block + '_branch'
    bn_name_base = 'bn' + str(stage) + block + '_branch'

    x = TimeDistributed(Convolution2D(nb_filter1, 1, 1, subsample=strides, trainable=trainable, init='normal'), name=conv_name_base + '2a')(input_tensor)
    x = TimeDistributed(FixedBatchNormalization(trainable=False, axis=bn_axis), name=bn_name_base + '2a')(x)
    x = Activation('relu')(x)

    x = TimeDistributed(Convolution2D(nb_filter2, kernel_size, kernel_size, border_mode='same', trainable=trainable, init='normal'), name=conv_name_base + '2b')(x)
    x = TimeDistributed(FixedBatchNormalization(trainable=False, axis=bn_axis), name=bn_name_base + '2b')(x)
    x = Activation('relu')(x)

    # `trainable` is set on the wrapped Convolution2D (it used to be passed to
    # the TimeDistributed wrapper here), matching every other convolution in
    # this block.
    x = TimeDistributed(Convolution2D(nb_filter3, 1, 1, trainable=trainable, init='normal'), name=conv_name_base + '2c')(x)
    x = TimeDistributed(FixedBatchNormalization(trainable=False, axis=bn_axis), name=bn_name_base + '2c')(x)

    # Shortcut path: uses the same subsample as the first main-path conv.
    shortcut = TimeDistributed(Convolution2D(nb_filter3, 1, 1, subsample=strides, trainable=trainable, init='normal'), name=conv_name_base + '1')(input_tensor)
    shortcut = TimeDistributed(FixedBatchNormalization(trainable=False, axis=bn_axis), name=bn_name_base + '1')(shortcut)

    x = merge([x, shortcut], mode='sum')
    x = Activation('relu')(x)
    return x

def nn_base(input_tensor=None, trainable = False):
    '''Build the shared convolutional base: the conv1 stem plus stages 2 to 4.

    # Arguments
            input_tensor: optional input. When None a new Input with a
                variable spatial size and 3 channels is created. When it is
                not already a Keras tensor it is wrapped in an Input layer.
                Otherwise it is used as is.
            trainable: value forwarded as `trainable` to conv1 and to every
                conv_block / identity_block. The FixedBatchNormalization
                layers are always created with trainable=False.

    # Returns
            Output tensor of the last identity block of stage 4 (block 'f').
    '''
    dim_ordering = K.image_dim_ordering()

    # Channels first for 'th', channels last otherwise; spatial size is free.
    input_shape = (3, None, None) if dim_ordering == 'th' else (None, None, 3)

    if input_tensor is None:
        img_input = Input(shape=input_shape)
    elif K.is_keras_tensor(input_tensor):
        img_input = input_tensor
    else:
        img_input = Input(tensor=input_tensor, shape=input_shape)

    bn_axis = 3 if dim_ordering == 'tf' else 1

    # Stem.
    x = ZeroPadding2D((3, 3))(img_input)
    x = Convolution2D(64, 7, 7, subsample=(2, 2), name='conv1', trainable=trainable)(x)
    x = FixedBatchNormalization(trainable=False, axis=bn_axis, name='bn_conv1')(x)
    x = Activation('relu')(x)
    x = MaxPooling2D((3, 3), strides=(2, 2))(x)

    # (stage, filters, block labels, strides of the leading conv_block)
    stage_specs = (
        (2, [64, 64, 256], 'abc', (1, 1)),
        (3, [128, 128, 512], 'abcd', (2, 2)),
        (4, [256, 256, 1024], 'abcdef', (2, 2)),
    )
    for stage, filters, blocks, strides in stage_specs:
        # The first block of a stage has a conv shortcut, the rest are identity blocks.
        x = conv_block(x, 3, filters, stage=stage, block=blocks[0],
                       strides=strides, trainable=trainable)
        for block in blocks[1:]:
            x = identity_block(x, 3, filters, stage=stage, block=block,
                               trainable=trainable)

    return x

def classifier_layers(x, trainable=False, last_block_trainable=True):
    '''Build the TimeDistributed stage 5 blocks followed by average pooling.

    Stage 5 consists of conv_block_td 'a' (strides=(1, 1)) and
    identity_block_td 'b' and 'c', all with filters [512, 512, 2048],
    followed by a TimeDistributed AveragePooling2D((7, 7)) named 'avg_pool'.

    # Arguments
            x: input tensor
            trainable: value forwarded as `trainable` to blocks 'a' and 'b'
            last_block_trainable: value forwarded as `trainable` to block 'c'.
                This used to be hard-coded to True, so block 'c' ignored
                `trainable`. The default keeps that behaviour; pass the same
                value as `trainable` to treat the whole stage uniformly.

    # Returns
            Output tensor of the 'avg_pool' layer.
    '''
    stage_filters = [512, 512, 2048]

    x = conv_block_td(x, 3, stage_filters, stage=5, block='a', strides=(1, 1), trainable=trainable)
    x = identity_block_td(x, 3, stage_filters, stage=5, block='b', trainable=trainable)
    x = identity_block_td(x, 3, stage_filters, stage=5, block='c', trainable=last_block_trainable)

    x = TimeDistributed(AveragePooling2D((7, 7)), name='avg_pool')(x)

    return x

def rpn(base_layers,num_anchors):
    '''Build the region proposal network heads on top of base_layers.

    A shared 3x3 Convolution2D with 512 filters ('rpn_conv1', ReLU,
    border_mode='same') feeds two 1x1 Convolution2D heads.

    # Arguments
            base_layers: input tensor the RPN is attached to
            num_anchors: number of filters of the class head; the regression
                head uses num_anchors * 4 filters

    # Returns
            A list [x_class, x_regr]: the sigmoid output of 'rpn_out_class'
            and the linear output of 'rpn_out_regr'.
    '''
    shared = Convolution2D(512, 3, 3, border_mode='same', activation='relu',
                           init='normal', name='rpn_conv1')(base_layers)

    # (nb_filter, activation, layer name) for each 1x1 head, class head first.
    head_specs = (
        (num_anchors, 'sigmoid', 'rpn_out_class'),
        (num_anchors * 4, 'linear', 'rpn_out_regr'),
    )

    outputs = []
    for nb_filter, activation, layer_name in head_specs:
        head = Convolution2D(nb_filter, 1, 1, activation=activation,
                             init='normal', name=layer_name)
        outputs.append(head(shared))

    return outputs

def classifier(base_layers,input_rois,num_rois,nb_classes = 21):
    '''Build the per-ROI classification and regression heads.

    base_layers and input_rois are passed to a RoiPoolingConv layer built
    with a pooling size of 7 and num_rois. Its output goes through
    classifier_layers (called with its default arguments) and a
    TimeDistributed Flatten, and is then fed to two TimeDistributed Dense
    heads.

    # Arguments
            base_layers: feature tensor given to RoiPoolingConv
            input_rois: ROI tensor given to RoiPoolingConv
            num_rois: second argument of RoiPoolingConv
                (presumably the number of ROIs per batch - confirm against
                RoiPoolingConv)
            nb_classes: number of units of the softmax class head

    # Returns
            A list [out_class, out_regr]: the softmax output of
            'dense_class_<nb_classes>' and the linear 4-unit output of
            'dense_regr'.
    '''
    pooling_regions = 7

    roi_pool = RoiPoolingConv(pooling_regions, num_rois)
    features = classifier_layers(roi_pool([base_layers, input_rois]))
    features = TimeDistributed(Flatten(), name='td_flatten')(features)

    class_head = TimeDistributed(Dense(nb_classes, activation='softmax'),
                                 name='dense_class_{}'.format(nb_classes))
    out_class = class_head(features)

    regr_head = TimeDistributed(Dense(4, activation='linear'), name='dense_regr')
    out_regr = regr_head(features)

    return [out_class, out_regr]