# -*- coding: utf-8 -*-
# Author: XuMing <xuming624@qq.com>
# Brief:
import time

from sklearn.metrics import classification_report, confusion_matrix

import config
from models.evaluate import cal_multiclass_lr_predict
from models.feature import Feature
from models.reader import data_reader
from models.xgboost_lr_model import XGBLR
from utils.data_utils import load_pkl, load_vocab, save_predict_result, load_dict
from utils.logger import get_logger

logger = get_logger(__name__)


def infer_classic(model_type='xgboost_lr',
                  model_save_path='',
                  label_vocab_path='',
                  test_data_path='',
                  pred_save_path='',
                  feature_vec_path='',
                  col_sep='\t',
                  feature_type='tfidf_word'):
    # load data content
    data_set, true_labels = data_reader(test_data_path, col_sep)
    # init feature
    feature = Feature(data=data_set, feature_type=feature_type,
                      feature_vec_path=feature_vec_path, is_infer=True)
    # get data feature
    data_feature = feature.get_feature()
    # load model
    if model_type == 'xgboost_lr':
        model = XGBLR(model_save_path)
    else:
        model = load_pkl(model_save_path)

    # predict
    pred_label_probs = model.predict_proba(data_feature)

    # label id map
    label_id = load_vocab(label_vocab_path)
    id_label = {v: k for k, v in label_id.items()}

    pred_labels = [id_label[prob.argmax()] for prob in pred_label_probs]
    pred_output = [id_label[prob.argmax()] + col_sep + str(prob.max()) for prob in pred_label_probs]
    logger.info("save infer label and prob result to:%s" % pred_save_path)
    save_predict_result(pred_output, ture_labels=None, pred_save_path=pred_save_path, data_set=data_set)

    # evaluate
    if true_labels:
        try:
            print(classification_report(true_labels, pred_labels))
            print(confusion_matrix(true_labels, pred_labels))
        except UnicodeEncodeError:
            true_labels_id = [label_id[i] for i in true_labels]
            pred_labels_id = [label_id[i] for i in pred_labels]
            print(classification_report(true_labels_id, pred_labels_id))
            print(confusion_matrix(true_labels_id, pred_labels_id))
        except Exception:
            print("error. no true labels")

    # analysis lr model
    if model_type == "logistic_regression":
        feature_weight_dict = load_dict(config.lr_feature_weight_path)
        pred_labels = cal_multiclass_lr_predict(data_set, feature_weight_dict, id_label)
        print(pred_labels[:5])


def infer_deep_model(model_type='cnn',
                     data_path='',
                     model_save_path='',
                     label_vocab_path='',
                     max_len=300,
                     batch_size=128,
                     col_sep='\t',
                     pred_save_path=None):
    from keras.models import load_model
    # load data content
    data_set, true_labels = data_reader(data_path, col_sep)
    # init feature
    # han model need [doc sentence dim] feature(shape 3); others is [sentence dim] feature(shape 2)
    if model_type == 'han':
        feature_type = 'doc_vectorize'
    else:
        feature_type = 'vectorize'
    feature = Feature(data_set, feature_type=feature_type, is_infer=True, max_len=max_len)
    # get data feature
    data_feature = feature.get_feature()

    # load model
    model = load_model(model_save_path)
    # predict, in keras, predict_proba same with predict
    pred_label_probs = model.predict(data_feature, batch_size=batch_size)

    # label id map
    label_id = load_vocab(label_vocab_path)
    id_label = {v: k for k, v in label_id.items()}
    pred_labels = [prob.argmax() for prob in pred_label_probs]
    pred_labels = [id_label[i] for i in pred_labels]
    pred_output = [id_label[prob.argmax()] + col_sep + str(prob.max()) for prob in pred_label_probs]
    logger.info("save infer label and prob result to: %s" % pred_save_path)
    save_predict_result(pred_output, ture_labels=None, pred_save_path=pred_save_path, data_set=data_set)
    if true_labels:
        # evaluate
        assert len(pred_labels) == len(true_labels)
        for label, prob in zip(true_labels, pred_label_probs):
            logger.debug('label_true:%s\tprob_label:%s\tprob:%s' % (label, id_label[prob.argmax()], prob.max()))

        print('total eval:')
        try:
            print(classification_report(true_labels, pred_labels))
            print(confusion_matrix(true_labels, pred_labels))
        except UnicodeEncodeError:
            true_labels_id = [label_id[i] for i in true_labels]
            pred_labels_id = [label_id[i] for i in pred_labels]
            print(classification_report(true_labels_id, pred_labels_id))
            print(confusion_matrix(true_labels_id, pred_labels_id))


if __name__ == "__main__":
    start_time = time.time()
    if config.model_type in ['fasttext', 'cnn', 'rnn', 'han']:
        infer_deep_model(model_type=config.model_type,
                         data_path=config.test_seg_path,
                         model_save_path=config.model_save_path,
                         label_vocab_path=config.label_vocab_path,
                         max_len=config.max_len,
                         batch_size=config.batch_size,
                         col_sep=config.col_sep,
                         pred_save_path=config.pred_save_path)
    else:
        infer_classic(model_type=config.model_type,
                      model_save_path=config.model_save_path,
                      label_vocab_path=config.label_vocab_path,
                      test_data_path=config.test_seg_path,
                      pred_save_path=config.pred_save_path,
                      feature_vec_path=config.feature_vec_path,
                      col_sep=config.col_sep,
                      feature_type=config.feature_type)
    logger.info("spend time %ds." % (time.time() - start_time))
    logger.info("finish predict.")