import keras
from pj1_clf_news.model import *
import pj1_clf_news.config as conf
from pj1_clf_news.dataset import Dataset
from keras.callbacks import ModelCheckpoint, Callback, EarlyStopping
from sklearn.metrics import f1_score

name_model = {'lstm': model_lstm, 'bilstm': model_bilstm, 'rcnn': model_rcnn,
              'rcnn_res': model_rcnn_res, 'cnn_res': model_cnn_res, 'cnn': model_cnn}
model_name = 'cnn'


class MyCallback(Callback):
    def on_epoch_begin(self, epoch, logs=None):
        # print('epoch: ===', epoch)
        pass

    def on_epoch_end(self, epoch, logs=None):
        # print('=========epoch: ', epoch, 'loss: ', logs.get('loss'), 'acc: ', logs.get('acc'))
        pass

    def on_batch_end(self, batch, logs=None):
        if (batch + 1) % 10 == 0:
            print('batch: ', batch + 1, 'loss: ', logs.get('loss'), 'acc: ', logs.get('acc'))


class F1(Callback):
    def __init__(self, validation_generate, steps_per_epoch):
        super(Callback, self).__init__()
        self.validation_generate = validation_generate
        self.steps_per_epoch = steps_per_epoch

    def on_epoch_end(self, epoch, logs=None):
        y_trues = []
        y_preds = []
        for i in range(self.steps_per_epoch):
            x_val, y_val = next(self.validation_generate)
            y_pred = self.model.predict(x_val, verbose=0)
            y_trues.extend(np.argmax(y_val, axis=1).tolist())
            y_preds.extend(np.argmax(y_pred, axis=1).tolist())
        score = f1_score(np.array(y_trues), np.array(y_preds), average='micro')
        print('\n f1 - epoch:%d - score:%.6f \n' % (epoch + 1, score))


def train():
    # load data
    train_dataset = Dataset(training=True)
    dev_dataset = Dataset(training=False)

    # model
    MODEL = name_model[model_name]
    model = MODEL(train_dataset.vocab_size, conf.n_classes, train_dataset.emb_mat)

    # callback
    my_callback = MyCallback()
    f1 = F1(dev_dataset.gen_batch_data(), dev_dataset.steps_per_epoch)
    checkpointer = ModelCheckpoint('data/{}.hdf5'.format(model_name), save_best_only=True)
    early_stop = EarlyStopping(monitor='val_loss', patience=5, verbose=0, mode='auto')

    # train
    model.compile(optimizer=keras.optimizers.Adam(),
                  loss=keras.losses.categorical_crossentropy, metrics=['acc'])
    model.fit_generator(train_dataset.gen_batch_data(),
                        steps_per_epoch=train_dataset.steps_per_epoch,
                        verbose=0,
                        epochs=conf.epochs, callbacks=[my_callback, checkpointer, early_stop, f1])
    keras.models.save_model(model, conf.model_path.format(model_name))


if __name__ == '__main__':
    train()