from collections import Counter
from typing import Iterable, List, Optional

from nltk import FreqDist

# Type alias for a single tokenized sample: a list of token strings.
Tokens = List[str]


class Vocab:
    """Frequency-based vocabulary mapping tokens to integer ids and back.

    Construction counts every non-special token of the corpus; ``fit`` then
    keeps the most frequent tokens and builds the lookup tables.

    Attributes:
        vocab: ``FreqDist`` holding the count of every non-special token.
        cdf: Share of all counted occurrences covered by the tokens kept by
            the most recent ``fit`` call (``0.0`` before any call).
        itos: Index-to-token list; empty until ``fit`` is called.
        stoi: Token-to-index mapping; empty until ``fit`` is called.
    """

    def __init__(
        self,
        tokens: Iterable[Iterable[str]],
        special_symbols: Optional[Iterable[str]] = None,
    ):
        """Count token occurrences across all samples.

        Prints the total number of counted occurrences and the number of
        distinct tokens once counting is done.

        Args:
            tokens: Tokenized samples; each sample is an iterable of token
                strings.
            special_symbols: Extra symbols to leave out of the counts, in
                addition to the built-in markers listed below.
        """
        # A set keeps the per-token membership test O(1) on large corpora.
        excluded = set(special_symbols or ())
        excluded.update(("<eot>", "<response>", "<eos>", "<unk>", "<pad>", "<bos>"))

        self.vocab = FreqDist()
        self.cdf = 0.
        for sample in tokens:
            self.vocab.update(token for token in sample if token not in excluded)

        print(f"total samples in vocab: {self.vocab.N()}, total tokens in vocab: {self.vocab.B()}")
        self.itos = []
        self.stoi = {}

    def fit(self, num_tokens: Optional[int] = 15000) -> None:
        """Keep the ``num_tokens`` most frequent tokens and build the lookups.

        Sets ``cdf`` to the fraction of counted occurrences covered by the
        kept tokens (``0.0`` when nothing was counted), prints it, and
        rebuilds ``itos`` / ``stoi`` with the four reserved symbols occupying
        indices 0-3.

        Args:
            num_tokens: Maximum number of distinct tokens to keep.
        """
        # Computed once and reused for both the coverage and the index tables.
        kept = self.vocab.most_common(num_tokens)
        total = self.vocab.N()
        covered = sum(count for _, count in kept)
        # An empty (or fully filtered) corpus has no coverage to report;
        # guard the division instead of raising ZeroDivisionError.
        self.cdf = covered / total if total else 0.
        print(f"cdf of the {num_tokens} most common tokens in vocab {self.cdf}")

        self.itos = ["<unk>", "<pad>", "<eos>", "<bos>"] + [token for token, _ in kept]
        # A Counter returns 0 for missing keys, and index 0 is "<unk>", so
        # looking up an out-of-vocabulary token yields the unknown id.
        self.stoi = Counter({token: index for index, token in enumerate(self.itos)})