import torch
from torchtext.data import TabularDataset, Field, Iterator
from torchtext.vocab import Vectors
from sklearn.metrics import f1_score, accuracy_score, recall_score, precision_score

from ..base.tool import Tool
from ..utils.log import logger
from .config import DEVICE, DEFAULT_CONFIG

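# fix random seeds for reproducible runs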
seed = 2019
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)


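# Entities and relations are atomic symbols, so the "tokenizer" simply wraps
# the whole string in a single-element list instead of splitting it.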
def light_tokenize(text):
    return [text]


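# torchtext Fields for the three columns of a (head, rel, tail) triple file;
# ENTITY is shared by the head and tail columns so both use one vocabulary.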
ENTITY = Field(tokenize=light_tokenize, batch_first=True)
RELATION = Field(tokenize=light_tokenize, batch_first=True)
Fields = [
    ('head', ENTITY),
    ('rel', RELATION),
    ('tail', ENTITY)
]


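# Data-handling and evaluation helpers for (head, rel, tail) triple datasets.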
class RLTool(Tool):
    def get_dataset(self, path: str, fields=Fields, file_type='csv', skip_header=False):
        logger.info('loading dataset from {}'.format(path))
        rl_dataset = TabularDataset(path, format=file_type, fields=fields, skip_header=skip_header)
        logger.info('successfully loaded dataset')
        return rl_dataset

    def get_vocab(self, *dataset):
        logger.info('building entity vocab...')
        ENTITY.build_vocab(*dataset)
        logger.info('successfully built entity vocab')
        logger.info('building relation vocab...')
        RELATION.build_vocab(*dataset)
        logger.info('successfully built relation vocab')
        return ENTITY.vocab, RELATION.vocab

    def get_vectors(self, path: str):
        logger.info('loading vectors from {}'.format(path))
        vectors = Vectors(path)
        logger.info('successfully loaded vectors')
        return vectors

    def get_iterator(self, dataset, batch_size=DEFAULT_CONFIG['batch_size'], device=DEVICE,
                     sort_key=lambda x: len(x.head)):
        # sort batches by the head field; the examples expose head/rel/tail, not text
        return Iterator(dataset, batch_size=batch_size, device=device, sort_key=sort_key)

    def get_score(self, model, texts, labels, score_type='f1'):
        metrics_map = {
            'f1': f1_score,
            'p': precision_score,
            'r': recall_score,
            'acc': accuracy_score
        }
        metric_func = metrics_map[score_type] if score_type in metrics_map else metrics_map['f1']
        assert len(texts) == len(labels)
        vec_predict = model(texts)
        soft_predict = torch.softmax(vec_predict, dim=1)
        predict_prob, predict_index = torch.max(soft_predict.cpu().data, dim=1)
        predict_index = predict_index.numpy()
        labels = labels.view(-1).cpu().data.numpy()
        # sklearn metrics take (y_true, y_pred); accuracy_score has no 'average' argument
        if metric_func is accuracy_score:
            return metric_func(labels, predict_index)
        return metric_func(labels, predict_index, average='micro')


rl_tool = RLTool()
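

# Usage sketch (illustrative only; 'train.csv' and 'vec.txt' are hypothetical paths):
#
#   train_data = rl_tool.get_dataset('train.csv')
#   entity_vocab, relation_vocab = rl_tool.get_vocab(train_data)
#   vectors = rl_tool.get_vectors('vec.txt')
#   train_iter = rl_tool.get_iterator(train_data, batch_size=DEFAULT_CONFIG['batch_size'])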