from nlp_toolkit.models import Base_Model
from nlp_toolkit.modules.logits import tc_output_logits
from nlp_toolkit.modules.token_embedders import Token_Embedding
from keras.layers import Input, Dense, Flatten, Dropout
from keras.layers import Conv1D, MaxPooling1D
from keras.layers import concatenate
from keras.models import Model
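# NOTE: written against the standalone Keras 2.x API; under tf.keras the
# same layers are available via `from tensorflow.keras.layers import ...`.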


class textCNN(Base_Model):
    """
    The known Kim CNN model used in text classification.
    It use mulit-channel CNN to encode texts
    """

    def __init__(self, nb_classes, nb_tokens, maxlen,
                 embedding_dim=256, embeddings=None, embed_l2=1E-6,
                 conv_kernel_size=[3, 4, 5], pool_size=[2, 2, 2],
                 nb_filters=128, fc_size=128,
                 embed_dropout_rate=0.25, final_dropout_rate=0.5):
        super(textCNN, self).__init__()
        self.nb_classes = nb_classes
        self.nb_tokens = nb_tokens
        self.maxlen = maxlen
        self.embedding_dim = embedding_dim
        self.nb_filters = nb_filters
        self.pool_size = pool_size
        self.conv_kernel_size = conv_kernel_size
        self.fc_size = fc_size
        self.final_dropout_rate = final_dropout_rate
        self.embed_dropout_rate = embed_dropout_rate

        # core layers: parallel conv-pool branches, one per kernel size,
        # capturing n-gram features of different widths
        self.cnn_list = [
            Conv1D(nb_filters, f, padding='same', name='conv_%d' % k)
            for k, f in enumerate(conv_kernel_size)]
        self.pool_list = [
            MaxPooling1D(p, name='pool_%d' % k)
            for k, p in enumerate(pool_size)]
        self.fc = Dense(fc_size, activation='relu',
                        kernel_initializer='he_normal')
        if embeddings is not None:
            self.token_embeddings = [embeddings]
        else:
            self.token_embeddings = None
        # layer-holding attributes, presumably skipped by Base_Model when it
        # serialises hyper-parameters
        self.invalid_params = {'cnn_list', 'pool_list', 'fc'}

    def forward(self):
        model_input = Input(shape=(self.maxlen,), dtype='int32', name='token')
        x = Token_Embedding(model_input, self.nb_tokens, self.embedding_dim,
                            self.token_embeddings, False, self.maxlen,
                            self.embed_dropout_rate, name='token_embeddings')
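        # x: (batch, maxlen, embedding_dim) after lookup plus embedding dropout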
        # run the embedding through every conv-pool branch; padding='same'
        # plus equal pool sizes keep the temporal lengths identical, so the
        # branch outputs can be concatenated along the feature axis
        cnn_combine = []
        for i in range(len(self.conv_kernel_size)):
            cnn = self.cnn_list[i](x)
            pool = self.pool_list[i](cnn)
            cnn_combine.append(pool)
        x = concatenate(cnn_combine, axis=-1)

        x = Flatten()(x)
        x = Dropout(self.final_dropout_rate)(x)
        x = self.fc(x)

        # toolkit output head: maps the fc features to nb_classes scores,
        # applying dropout beforehand (judging by the final_dropout_rate arg)
        outputs = tc_output_logits(x, self.nb_classes, self.final_dropout_rate)

        self.model = Model(inputs=model_input,
                           outputs=outputs, name="TextCNN")

    def get_loss(self):
        # binary tasks pair with a sigmoid output, multi-class with softmax
        if self.nb_classes == 2:
            return 'binary_crossentropy'
        elif self.nb_classes > 2:
            return 'categorical_crossentropy'
        else:
            raise ValueError('nb_classes must be at least 2')

    def get_metrics(self):
        return ['acc']
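

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the library): assumes the usual
# flow of calling forward() to build self.model and then compiling with
# get_loss()/get_metrics(); the random arrays below only demonstrate the
# expected input and label shapes.
if __name__ == '__main__':
    import numpy as np

    nb_classes, nb_tokens, maxlen = 4, 5000, 100
    clf = textCNN(nb_classes=nb_classes, nb_tokens=nb_tokens, maxlen=maxlen)
    clf.forward()  # builds clf.model

    clf.model.compile(optimizer='adam',
                      loss=clf.get_loss(),
                      metrics=clf.get_metrics())
    clf.model.summary()

    # random token ids in [0, nb_tokens) and one-hot labels
    x_fake = np.random.randint(0, nb_tokens, size=(32, maxlen))
    y_fake = np.eye(nb_classes)[np.random.randint(0, nb_classes, size=32)]
    clf.model.fit(x_fake, y_fake, epochs=1, batch_size=8)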