from ctcdecode import CTCBeamDecoder
import torch
import numpy as np
import editdistance


class Decoder:
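    """CTC decoder with greedy (best-path) and beam-search decoding, plus
    batch CER/WER metrics. ``labels`` is the model's output alphabet without
    the blank; logits are expected time-major, i.e. shaped (T, B, C)."""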
    def __init__(self, labels, lm_path=None, alpha=1, beta=1.5, cutoff_top_n=40, cutoff_prob=0.99, beam_width=200, num_processes=24, blank_id=0):
        self.vocab_list = ['_'] + labels  # NOTE: index 0 is the CTC blank symbol
        # NOTE: the whitespace symbol (assumed to be labels[0]) is replaced with
        # an '@' symbol for explicit modeling in char-based LMs; blank stays at index 0
        self._decoder = CTCBeamDecoder(['_', '@'] + labels[1:], lm_path, alpha, beta,
                                       cutoff_top_n, cutoff_prob, beam_width,
                                       num_processes, blank_id)

    def convert_to_string(self, tokens, seq_len=None):
        # Collapse a greedy CTC path: drop blanks (index 0) and merge
        # consecutive repeats of the same token
        if seq_len is None:
            seq_len = tokens.size(0)
        out = []
        for i in range(seq_len):
            token = int(tokens[i])
            if token != 0 and (i == 0 or token != int(tokens[i - 1])):
                out.append(token)
        return ''.join(self.vocab_list[t] for t in out)
    
    def decode_beam(self, logits, seq_lens):
        # (T, B, C) -> (B, T, C), as CTCBeamDecoder expects batch-first input
        tlogits = logits.transpose(0, 1)
        # The decoder consumes per-frame probabilities, hence the softmax
        beam_results, beam_scores, timesteps, out_seq_len = self._decoder.decode(tlogits.softmax(-1), seq_lens)
        decoded = []
        for i in range(tlogits.size(0)):
            # Keep the top-scoring beam (index 0), cut at its reported length
            best = beam_results[i][0][:out_seq_len[i][0]]
            decoded.append(''.join(self.vocab_list[x] for x in best))
        return decoded

    def decode_greedy(self, logits, seq_lens):
        # Best-path decoding: argmax over classes per frame, then CTC-collapse
        tlogits = logits.transpose(0, 1)
        _, tokens = torch.max(tlogits, 2)
        decoded = []
        for i in range(tlogits.size(0)):
            decoded.append(self.convert_to_string(tokens[i], seq_lens[i]))
        return decoded
    
    def get_mean(self, decoded, gt, individual_length, func):
        # Average the per-sample distances, each normalized by the mean
        # ground-truth length of the batch
        total_norm = 0.0
        length = len(decoded)
        for i in range(length):
            val = float(func(decoded[i], gt[i]))
            total_norm += val / individual_length
        return total_norm / length

    def wer(self, r, h):
        # Levenshtein distance between reference r and hypothesis h (lists of
        # tokens). np.int32 avoids the overflow np.uint8 would hit once a
        # length or distance exceeds 255.
        # e.g. wer(['a', 'b', 'c'], ['a', 'x', 'c']) == 1 (one substitution)
        d = np.zeros((len(r) + 1, len(h) + 1), dtype=np.int32)

        # initialisation: distance from/to the empty prefix
        d[0, :] = np.arange(len(h) + 1)
        d[:, 0] = np.arange(len(r) + 1)

        # computation
        for i in range(1, len(r) + 1):
            for j in range(1, len(h) + 1):
                if r[i - 1] == h[j - 1]:
                    d[i][j] = d[i - 1][j - 1]
                else:
                    substitution = d[i - 1][j - 1] + 1
                    insertion    = d[i][j - 1] + 1
                    deletion     = d[i - 1][j] + 1
                    d[i][j] = min(substitution, insertion, deletion)

        return d[len(r)][len(h)]

    def wer_sentence(self, r, h):
        # Word-level edit distance between two sentences
        return self.wer(r.split(), h.split())
    
    def cer_batch(self, decoded, gt):
        # Character error rate: char-level edit distance normalized by the
        # mean ground-truth length in characters
        assert len(decoded) == len(gt), 'batch size mismatch: {}!={}'.format(len(decoded), len(gt))
        mean_indiv_len = np.mean([len(s) for s in gt])
        return self.get_mean(decoded, gt, mean_indiv_len, editdistance.eval)

    def wer_batch(self, decoded, gt):
        # Word error rate: word-level edit distance normalized by the
        # mean ground-truth length in words
        assert len(decoded) == len(gt), 'batch size mismatch: {}!={}'.format(len(decoded), len(gt))
        mean_indiv_len = np.mean([len(s.split()) for s in gt])
        return self.get_mean(decoded, gt, mean_indiv_len, self.wer_sentence)
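

# Minimal usage sketch. The label set, tensor shapes, and ground-truth strings
# below are illustrative assumptions, not part of this module; the random
# logits stand in for real network output.
if __name__ == '__main__':
    import string

    labels = [' '] + list(string.ascii_lowercase)  # hypothetical alphabet, space first
    decoder = Decoder(labels)  # no LM path: plain beam search over the softmax

    T, B, C = 50, 2, len(labels) + 1  # time steps, batch size, classes incl. blank
    logits = torch.randn(T, B, C)
    seq_lens = torch.full((B,), T, dtype=torch.int32)

    greedy = decoder.decode_greedy(logits, seq_lens)
    beam = decoder.decode_beam(logits, seq_lens)
    gt = ['hello world', 'ctc decoding']  # made-up references for the metrics
    print('greedy:', greedy)
    print('beam:  ', beam)
    print('CER:   ', decoder.cer_batch(greedy, gt))
    print('WER:   ', decoder.wer_batch(greedy, gt))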