from keras.layers import Dense, Permute, Reshape, Input
from keras.layers.wrappers import Bidirectional
from keras.layers.recurrent import LSTM
from keras.models import Model, load_model
from keras.regularizers import l2
from keras.applications.inception_v3 import InceptionV3

NAME = "Inceptionv3 CRNN"

def create_model(input_shape, config):


    input_tensor = Input(shape=input_shape)  # this assumes K.image_dim_ordering() == 'tf'
    inception_model = InceptionV3(include_top=False, weights=None, input_tensor=input_tensor)
    # inception_model.load_weights("logs/2016-12-18-13-56-44/weights.21.model", by_name=True)

    for layer in inception_model.layers:
        layer.trainable = False

    x = inception_model.output
    #x = GlobalAveragePooling2D()(x)

    # (bs, y, x, c) --> (bs, x, y, c)
    x = Permute((2, 1, 3))(x)

    # (bs, x, y, c) --> (bs, x, y * c)
    _x, _y, _c = [int(s) for s in x._shape[1:]]
    x = Reshape((_x, _y*_c))(x)
    x = Bidirectional(LSTM(512, return_sequences=False), merge_mode="concat")(x)

    predictions = Dense(config["num_classes"], activation='softmax')(x)

    model = Model(input=inception_model.input, output=predictions)
    model.load_weights("logs/2017-01-02-13-39-41/weights.06.model")

    return model