# -*- coding: utf-8 -*-
"""
@author:XuMing<xuming624@qq.com>
@description: Default corpus pre-processor: token and label numerization for classification.
"""

from typing import List, Optional

import numpy as np

from .base_processor import BaseProcessor, get_list_subset


class DefaultProcessor(BaseProcessor):
    """
    Corpus pre-processor for classification targets.

    Builds the label vocabulary (``label2idx`` / ``idx2label``) and converts
    tokens and labels to their index form. In single-label mode targets are
    one-hot encoded; when ``multi_label`` is true every sample carries a
    collection of labels and targets are multi-hot encoded with scikit-learn's
    ``MultiLabelBinarizer``.
    """

    def __init__(self, multi_label=False, **kwargs):
        """
        Args:
            multi_label: whether each sample's target is a collection of labels
                instead of a single label.
            **kwargs: forwarded unchanged to ``BaseProcessor.__init__``.
        """
        super().__init__(**kwargs)
        self.multi_label = multi_label
        # Set by `_build_label_dict`, or lazily by `_get_multi_label_binarizer`
        # when the label dict was populated some other way.
        self.multi_label_binarizer = None

    def info(self):
        """Return the parent's info dict extended with the ``multi_label`` flag."""
        info = super().info()
        info['multi_label'] = self.multi_label
        return info

    def _new_multi_label_binarizer(self):
        """Create a ``MultiLabelBinarizer`` fitted on the labels in ``label2idx``."""
        # Imported lazily, as in the rest of this class, so that importing the
        # module does not require scikit-learn.
        from sklearn.preprocessing import MultiLabelBinarizer

        binarizer = MultiLabelBinarizer(classes=list(self.label2idx))
        # The class set is fixed through ``classes``; fitting here only makes
        # the binarizer usable for `inverse_transform` right away.
        binarizer.fit([])
        return binarizer

    def _get_multi_label_binarizer(self):
        """Return the binarizer, building it from ``label2idx`` if it is missing."""
        if self.multi_label_binarizer is None:
            self.multi_label_binarizer = self._new_multi_label_binarizer()
        return self.multi_label_binarizer

    @staticmethod
    def _one_hot(indices, num_classes):
        """
        One-hot encode a flat sequence of class indices.

        Args:
            indices: sequence of int class indices.
            num_classes: width of the encoding.

        Returns:
            float32 array of shape ``(len(indices), num_classes)``.
        """
        index_array = np.asarray(indices, dtype='int64').reshape(-1)
        sample_count = index_array.shape[0]
        # float32 matches the default dtype of Keras' `to_categorical`,
        # which this helper replaces.
        encoded = np.zeros((sample_count, num_classes), dtype='float32')
        encoded[np.arange(sample_count), index_array] = 1.0
        return encoded

    def _build_label_dict(self,
                          labels: List[str]):
        """
        Build ``label2idx`` / ``idx2label`` from the training labels.

        Labels are indexed in sorted order. Also records the number of labels
        in ``dataset_info['label_count']`` and (re)creates the multi-label
        binarizer for the new label set.

        Args:
            labels: one label per sample, or, in multi-label mode, one
                iterable of labels per sample.
        """
        if self.multi_label:
            label_set = set()
            for sample_labels in labels:
                # In-place update avoids copying the growing set per sample.
                label_set.update(sample_labels)
        else:
            label_set = set(labels)

        self.label2idx = {label: idx for idx, label in enumerate(sorted(label_set))}
        self.idx2label = {idx: label for label, idx in self.label2idx.items()}
        self.dataset_info['label_count'] = len(self.label2idx)
        self.multi_label_binarizer = self._new_multi_label_binarizer()

    def process_y_dataset(self,
                          data: List[str],
                          max_len: Optional[int] = None,
                          subset: Optional[List[int]] = None) -> np.ndarray:
        """
        Convert target labels into a one-hot / multi-hot matrix.

        Args:
            data: targets; one label per sample, or one collection of labels
                per sample in multi-label mode.
            max_len: not used by this processor.
            subset: optional selection passed to ``get_list_subset`` to pick
                the entries of ``data`` to encode (presumably sample indices;
                verify against ``get_list_subset``).

        Returns:
            Array with one row per encoded sample and one column per label.

        Raises:
            KeyError: in single-label mode, if a label is not in ``label2idx``.
        """
        target = data if subset is None else get_list_subset(data, subset)
        if self.multi_label:
            # `fit_transform` is kept so results are unchanged for binarizers
            # created elsewhere; the class set stays fixed by ``classes``.
            return self._get_multi_label_binarizer().fit_transform(target)
        label_indices = self.numerize_label_sequences(target)
        return self._one_hot(label_indices, len(self.label2idx))

    def numerize_token_sequences(self,
                                 sequences: List[List[str]]):
        """
        Convert token sequences to token-index sequences.

        When ``add_bos_eos`` is set, every sequence is wrapped in the BOS/EOS
        tokens first. Tokens missing from ``token2idx`` map to the index of
        ``token_unk``.

        Args:
            sequences: token sequences, list of list of str

        Returns:
            token-index sequences, list of list of int
        """
        unk_index = self.token2idx[self.token_unk]
        result = []
        for seq in sequences:
            if self.add_bos_eos:
                # Unpacking accepts any iterable of tokens, not only lists.
                seq = [self.token_bos, *seq, self.token_eos]
            result.append([self.token2idx.get(token, unk_index) for token in seq])
        return result

    def numerize_label_sequences(self,
                                 sequences: List[str]) -> List[int]:
        """
        Convert label sequence to label-index sequence
        ``['O', 'O', 'B-ORG'] -> [0, 0, 2]``

        Args:
            sequences: label sequence, list of str

        Returns:
            label-index sequence, list of int

        Raises:
            KeyError: if a label is not in ``label2idx``.
        """
        return [self.label2idx[label] for label in sequences]

    def reverse_numerize_label_sequences(self, sequences, **kwargs):
        """
        Convert encoded targets back to labels.

        Args:
            sequences: in multi-label mode, an indicator matrix accepted by
                ``MultiLabelBinarizer.inverse_transform``; otherwise an
                iterable of label indices.
            **kwargs: accepted for interface compatibility; not used here.

        Returns:
            In multi-label mode, the result of
            ``MultiLabelBinarizer.inverse_transform``; otherwise a list of
            labels looked up in ``idx2label``.
        """
        if self.multi_label:
            return self._get_multi_label_binarizer().inverse_transform(sequences)
        return [self.idx2label[label] for label in sequences]