#
# This file is part of Dragonfly.
# (c) Copyright 2019 by David Zurow
# Licensed under the LGPL.
#
#   Dragonfly is free software: you can redistribute it and/or modify it
#   under the terms of the GNU Lesser General Public License as published
#   by the Free Software Foundation, either version 3 of the License, or
#   (at your option) any later version.
#
#   Dragonfly is distributed in the hope that it will be useful, but
#   WITHOUT ANY WARRANTY; without even the implied warranty of
#   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
#   Lesser General Public License for more details.
#
#   You should have received a copy of the GNU Lesser General Public
#   License along with Dragonfly.  If not, see
#   <http://www.gnu.org/licenses/>.
#

"""
Compiler classes for Kaldi backend
"""

import collections, logging, os.path, re, subprocess, types

from .testing                   import debug_timer
from .dictation                 import AlternativeDictation, DefaultDictation
from ..base                     import CompilerBase, CompilerError
from ...grammar                 import elements as elements_

from kaldi_active_grammar import WFST, KaldiRule
from kaldi_active_grammar import Compiler as KaldiAGCompiler

import six
from six import text_type
from six.moves import map, range

_log = logging.getLogger("engine.compiler")


#---------------------------------------------------------------------------
# Utilities

_trace_level = 0  # current nesting depth of traced compile calls (only used when tracing is enabled)

# Set to True *before* the compiler class body executes to wrap every method
# decorated with @trace_compile in a logging tracer. While False (the default)
# the decorator is a no-op that returns the method unchanged.
_trace_enabled = False

def trace_compile(func):
    """Debugging decorator for element-compiler methods.

    While ``_trace_enabled`` is false (the default) *func* is returned as-is.
    Otherwise *func* must have the element-compiler signature
    ``(self, element, src_state, dst_state, grammar, kaldi_rule, fst)``; the
    returned wrapper logs an indented line through ``grammar._log_load.error``
    before and after each call, so nested compilation can be followed. The
    grammar passed in must therefore provide ``name`` and ``_log_load``.
    """
    if not _trace_enabled:
        return func

    import functools

    @functools.wraps(func)
    def dec(self, element, src_state, dst_state, grammar, kaldi_rule, fst):
        global _trace_level
        s = '%s %s: compiling %s' % (grammar.name, '==='*_trace_level, element)
        s += ' ' * max(0, 140 - len(s)) + '| %-20s %s -> %s' % (id(fst), src_state, dst_state)
        grammar._log_load.error(s)
        _trace_level += 1
        try:
            return func(self, element, src_state, dst_state, grammar, kaldi_rule, fst)
        finally:
            # Restore the depth even if compilation raises, so later traces stay aligned.
            _trace_level -= 1
            grammar._log_load.error('%s %s: compiling %s.' % (grammar.name, '...'*_trace_level, element))
    return dec

# Minimal record types used internally by the compiler (e.g. for the engine's
# own internal grammar).
InternalGrammar = collections.namedtuple('InternalGrammar', ['name'])
InternalRule = collections.namedtuple('InternalRule', ['name', 'gstring'])
# Stand-in exposing only the ``words`` attribute of a literal element, so list
# items can be compiled as literals.
MockLiteral = collections.namedtuple('MockLiteral', ['words'])


#---------------------------------------------------------------------------

class KaldiCompiler(CompilerBase, KaldiAGCompiler):

    def __init__(self, model_dir, tmp_dir, auto_add_to_user_lexicon=None, lazy_compilation=None, **kwargs):
        """Initialize both base compilers, then this compiler's own bookkeeping.

        :param model_dir: forwarded to KaldiAGCompiler
        :param tmp_dir: forwarded to KaldiAGCompiler
        :param auto_add_to_user_lexicon: if truthy, handle_oov_word tries to
            add unknown words to the lexicon automatically
        :param lazy_compilation: if truthy, passed as ``lazy`` to KaldiRule.compile()
        :param kwargs: any further KaldiAGCompiler options
        """
        CompilerBase.__init__(self)
        KaldiAGCompiler.__init__(self, model_dir=model_dir, tmp_dir=tmp_dir, **kwargs)

        # Option flags, normalized to plain booleans.
        self.auto_add_to_user_lexicon = bool(auto_add_to_user_lexicon)
        self.lazy_compilation = bool(lazy_compilation)

        # Rule -> KaldiRule for every exported rule compiled by this compiler.
        self.kaldi_rule_by_rule_dict = collections.OrderedDict()
        # FIXME: disabled! (caching of compiled rule states is currently not used)
        self._grammar_rule_states_dict = {}
        # id(list) -> set of KaldiRules referencing that list, used by update_list().
        self.kaldi_rules_by_listreflist_dict = collections.defaultdict(set)
        # Set when a word was added to the lexicon and lexicon files need regenerating.
        self.added_word = False
        self.internal_grammar = InternalGrammar('!kaldi_engine_internal')

    impossible_word = property(lambda self: self._longest_word.lower())  # FIXME
    unknown_word = '<unk>'

    #-----------------------------------------------------------------------
    # Methods for handling lexicon translation.

    # Substring rewrites applied to grammar words by translate_words() before
    # lexicon lookup (each key is replaced by its value). Currently empty.
    translation_dict = {
    }

    # Reverse mapping applied to recognizer output by untranslate_output().
    # NOTE: it is built from translation_dict *before* the update() below, so
    # any entries added in that update are translated one-way only and are
    # never reversed in the output.
    untranslation_dict = { v: k for k, v in translation_dict.items() }
    translation_dict.update({
    })

    def untranslate_output(self, output):
        for old, new in six.iteritems(self.untranslation_dict):
            output = output.replace(old, new)
        return output

    def translate_words(self, words):
        # Unused
        if self.translation_dict:
            new_words = []
            for word in words:
                for old, new in six.iteritems(self.translation_dict):
                    word = word.replace(old, new)
                new_words.extend(word.split())
            words = new_words
        words = [word.lower() for word in words]
        for i in range(len(words)):
            if words[i] not in self.lexicon_words:
                words[i] = self.handle_oov_word(words[i])
        return words

    def handle_oov_word(self, word):
        if self.auto_add_to_user_lexicon:
            try:
                pronunciations = self.model.add_word(word, lazy_compilation=True)
                self.added_word = True
            except Exception as e:
                self._log.exception("%s: exception automatically adding word %r" % (self, word))
            else:
                for phones in pronunciations:
                    self._log.warning("%s: Word not in lexicon (generated automatic pronunciation): %r [%s]" % (self, word, ' '.join(phones)))
                return word

        self._log.warning("%s: Word %r not in lexicon (will NOT be recognized; see documentation about user lexicon and auto_add_to_user_lexicon)" % (self, word))
        word = self.impossible_word
        return word

    #-----------------------------------------------------------------------
    # Methods for compiling grammars.

    def compile_grammar(self, grammar, engine):
        self._log.debug("%s: Compiling grammar %s." % (self, grammar.name))

        kaldi_rule_by_rule_dict = collections.OrderedDict()
        for rule in grammar.rules:
            if rule.exported:
                if rule.element is None:
                    # We cannot deal with an empty rule (could be fixed by refactoring)
                    raise CompilerError("Invalid None element for %s in %s" % (rule, grammar))

                kaldi_rule = KaldiRule(self,
                    name='%s::%s' % (grammar.name, rule.name),
                    has_dictation=bool((rule.element is not None) and ('<Dictation()>' in rule.gstring())))  # FIXME
                kaldi_rule.parent_grammar = grammar
                kaldi_rule.parent_rule = rule
                kaldi_rule_by_rule_dict[rule] = kaldi_rule

                try:
                    self._compile_rule_root(rule, grammar, kaldi_rule)
                except Exception as e:
                    kaldi_rule.destroy()
                    raise

        self.kaldi_rule_by_rule_dict.update(kaldi_rule_by_rule_dict)
        return kaldi_rule_by_rule_dict

    def _compile_rule_root(self, rule, grammar, kaldi_rule):
        self._compile_rule(rule, grammar, kaldi_rule, kaldi_rule.fst)
        if self.added_word:
            self.model.generate_lexicon_files()
            self.model.load_words()
            self.decoder.load_lexicon()
            self.added_word = False
        kaldi_rule.compile(lazy=self.lazy_compilation)

    def _compile_rule(self, rule, grammar, kaldi_rule, fst, export=True):
        """ :param export: whether rule is exported (a root rule) """
        # Determine whether this rule has already been compiled.
        # if (grammar.name, rule.name) in self._grammar_rule_states_dict:
        #     self._log.debug("%s: Already compiled rule %s%s." % (self, rule.name, ' [EXPORTED]' if export else ''))
        #     return self._grammar_rule_states_dict[(grammar.name, rule.name)]
        # else:
        self._log.debug("%s: Compiling rule %s%s." % (self, rule.name, ' [EXPORTED]' if export else ''))

        if export:
            # Root rule, so must handle grammar's weight, in addition to this rule's weight
            weight = self.get_weight(grammar) * self.get_weight(rule)
            outer_src_state = fst.add_state(initial=True)
            inner_src_state = fst.add_state()
            fst.add_arc(outer_src_state, inner_src_state, None, weight=weight)
            dst_state = fst.add_state(final=True)

        else:
            # Only handle this rule's weight
            weight = self.get_weight(rule)
            outer_src_state = fst.add_state()
            inner_src_state = fst.add_state()
            fst.add_arc(outer_src_state, inner_src_state, None, weight=weight)
            dst_state = fst.add_state()

        self.compile_element(rule.element, inner_src_state, dst_state, grammar, kaldi_rule, fst)
        # self._grammar_rule_states_dict[(grammar.name, rule.name)] = (src_state, dst_state)
        return (outer_src_state, dst_state)

    def unload_grammar(self, grammar, rules, engine):
        for rule in rules:
            kaldi_rule = self.kaldi_rule_by_rule_dict[rule]
            # Unload kaldi_rule: destroy() handles KaldiAGCompiler stuff; we must handle ours
            kaldi_rule.destroy()
            del self.kaldi_rule_by_rule_dict[rule]
            for kaldi_rules_set in self.kaldi_rules_by_listreflist_dict.values():
                kaldi_rules_set.discard(kaldi_rule)
            # NOTE: the kaldi_rule_by_rule_dict we returned from compile_grammar() is not updated, but it should be dropped upon unload anyway!

    def update_list(self, lst, grammar):
        # Note: we update all rules in all grammars that reference this list (unlike WSR/natlink?)
        lst_kaldi_rules = self.kaldi_rules_by_listreflist_dict[id(lst)]
        for kaldi_rule in lst_kaldi_rules:
            with kaldi_rule.reload():
                self._compile_rule_root(kaldi_rule.parent_rule, grammar, kaldi_rule)

    #-----------------------------------------------------------------------
    # Methods for compiling elements.

    _eps_like_nonterms = frozenset(('#nonterm:dictation', '#nonterm:dictation_cloud'))

    def compile_element(self, element, *args, **kwargs):
        """Compile element in FST (from src_state to dst_state) and return result."""
        # Look for a compiler method to handle the given element.
        for element_type, compiler in self.element_compilers:
            if isinstance(element, element_type):
                return compiler(self, element, *args, **kwargs)
        # Didn't find a compiler method for this element type.
        raise NotImplementedError("Compiler %s not implemented for element type %s." % (self, element))

    # @trace_compile
    def _compile_sequence(self, element, src_state, dst_state, grammar, kaldi_rule, fst):
        """Compile a Sequence (or Repetition) element into fst between src_state and dst_state.

        After any weight linkage for the element's own weight:
        - no children: a single arc with a None label from src to dst;
        - one child: the child is compiled directly between the two states and
          whatever compile_element returns is returned;
        - two or more children: the children are chained through freshly added
          intermediate states. The exception is a Repetition with ``optimize``
          set, which is compiled as a loop over its first child (see below)
          unless that child can match the empty string.
        """
        src_state = self.add_weight_linkage(src_state, dst_state, self.get_weight(element), fst)
        children = element.children
        # Optimize for special lengths
        if len(children) == 0:
            fst.add_arc(src_state, dst_state, None)
            return

        elif len(children) == 1:
            return self.compile_element(children[0], src_state, dst_state, grammar, kaldi_rule, fst)

        else:  # len(children) >= 2:
            # Handle Repetition elements differently as a special case
            is_repetition = isinstance(element, elements_.Repetition)
            if is_repetition and element.optimize:
                # Repetition...
                # Insert new states, so back arc only affects child
                # Layout: src_state -> s1 -[children[0]]-> s2 -> dst_state, plus an s2 -> s1
                # back arc, i.e. one or more matches of children[0]. The remaining children are
                # not compiled on this path (presumably they repeat the same child -- confirm).
                s1 = fst.add_state()
                s2 = fst.add_state()
                fst.add_arc(src_state, s1, None)
                # NOTE: to avoid creating an un-decodable epsilon loop, we must not allow an all-epsilon child here (compile_graph_agf should check this)
                self.compile_element(children[0], s1, s2, grammar, kaldi_rule, fst)
                # Dictation nonterminals (_eps_like_nonterms) also count as epsilon for this check.
                if not fst.has_eps_path(s1, s2, self._eps_like_nonterms):
                    fst.add_arc(s2, s1, fst.eps_disambig, fst.eps)  # back arc (labels: eps_disambig, eps)
                    fst.add_arc(s2, dst_state, None)
                    return

                else:
                    # Cannot do optimize path, because of epsilon loop, so finish up with Sequence path
                    # children[0] is already compiled s1 -> s2, so chain children[1:] onwards from s2
                    # (states[0] is not used by the loop below).
                    self._log.warning("%s: Cannot optimize Repetition element, because its child element can match empty string;"
                        " falling back to inefficient non-optimize path. (this is not that bad)" % self)
                    states = [src_state, s2] + [fst.add_state() for i in range(len(children)-2)] + [dst_state]
                    for i, child in enumerate(children[1:], start=1):
                        s1 = states[i]
                        s2 = states[i + 1]
                        self.compile_element(child, s1, s2, grammar, kaldi_rule, fst)
                    return

            else:
                # Sequence, not Repetition...
                # Insert new states for individual children elements
                # (len(children)-1 intermediate states; child i spans states[i] -> states[i+1])
                states = [src_state] + [fst.add_state() for i in range(len(children)-1)] + [dst_state]
                for i, child in enumerate(children):
                    s1 = states[i]
                    s2 = states[i + 1]
                    self.compile_element(child, s1, s2, grammar, kaldi_rule, fst)
                return

    # @trace_compile
    def _compile_alternative(self, element, src_state, dst_state, grammar, kaldi_rule, fst):
        src_state = self.add_weight_linkage(src_state, dst_state, self.get_weight(element), fst)
        for child in element.children:
            self.compile_element(child, src_state, dst_state, grammar, kaldi_rule, fst)

    # @trace_compile
    def _compile_optional(self, element, src_state, dst_state, grammar, kaldi_rule, fst):
        src_state = self.add_weight_linkage(src_state, dst_state, self.get_weight(element), fst)
        self.compile_element(element.children[0], src_state, dst_state, grammar, kaldi_rule, fst)
        fst.add_arc(src_state, dst_state, None)

    # @trace_compile
    def _compile_literal(self, element, src_state, dst_state, grammar, kaldi_rule, fst):
        weight = self.get_weight(element)  # Handle weight internally below, without adding a state
        words = element.words
        words = list(map(text_type, words))
        # words = self.translate_words(words)

        # Special case optimize single-word literal
        if len(words) == 1:
            word = words[0].lower()
            if word not in self.lexicon_words:
                word = self.handle_oov_word(word)
            fst.add_arc(src_state, dst_state, word, weight=weight)

        else:
            words = [word.lower() for word in words]
            for i in range(len(words)):
                if words[i] not in self.lexicon_words:
                    words[i] = self.handle_oov_word(words[i])
            # "Insert" new states for individual words
            states = [src_state] + [fst.add_state() for i in range(len(words)-1)] + [dst_state]
            for i, word in enumerate(words):
                fst.add_arc(states[i], states[i + 1], word, weight=weight)
                weight = None  # Only need to set weight on first arc

    # @trace_compile
    def _compile_rule_ref(self, element, src_state, dst_state, grammar, kaldi_rule, fst):
        weight = self.get_weight(element)  # Handle weight internally below without adding a state
        rule_src_state, rule_dst_state = self._compile_rule(element.rule, grammar, kaldi_rule, fst, export=False)
        fst.add_arc(src_state, rule_src_state, None, weight=weight)
        fst.add_arc(rule_dst_state, dst_state, None)

    # @trace_compile
    def _compile_list_ref(self, element, src_state, dst_state, grammar, kaldi_rule, fst):
        src_state = self.add_weight_linkage(src_state, dst_state, self.get_weight(element), fst)
        # list_rule_name = "__list_%s" % element.list.name
        if element.list not in grammar.lists:
            # Should only happen during initial compilation; during updates, we must skip this
            grammar.add_list(element.list)
        self.kaldi_rules_by_listreflist_dict[id(element.list)].add(kaldi_rule)
        for child_str in element.list.get_list_items():
            self._compile_literal(MockLiteral(child_str.split()), src_state, dst_state, grammar, kaldi_rule, fst)

    # @trace_compile
    def _compile_dictation(self, element, src_state, dst_state, grammar, kaldi_rule, fst):
        src_state = self.add_weight_linkage(src_state, dst_state, self.get_weight(element), fst)
        # fst.add_arc(src_state, dst_state, '#nonterm:dictation', olabel=WFST.eps)
        extra_state = fst.add_state()
        cloud_dictation = isinstance(element, (AlternativeDictation, DefaultDictation)) and element.cloud
        dictation_nonterm = '#nonterm:dictation_cloud' if cloud_dictation else '#nonterm:dictation'
        fst.add_arc(src_state, extra_state, '#nonterm:dictation', dictation_nonterm)
        # Accepts zero or more words
        fst.add_arc(extra_state, dst_state, WFST.eps, '#nonterm:end')
        # fst.add_arc(extra_state, dst_state, '!SIL', '#nonterm:end')  # Causes problems with lack of phones during decoding

    # @trace_compile
    def _compile_impossible(self, element, src_state, dst_state, grammar, kaldi_rule, fst):
        # FIXME: not impossible enough (lower probability?)
        # Note: setting weight=0 breaks compilation!
        fst.add_arc(src_state, dst_state, self.impossible_word, weight=1e-10)

    # @trace_compile
    def _compile_empty(self, element, src_state, dst_state, grammar, kaldi_rule, fst):
        src_state = self.add_weight_linkage(src_state, dst_state, self.get_weight(element), fst)
        fst.add_arc(src_state, dst_state, WFST.eps)

    #-----------------------------------------------------------------------
    # Utility methods.

    def get_weight(self, obj, name='weight'):
        """ Gets the weight of given grammar or rule, checking for invalid values. """
        weight = getattr(obj, name, 1)
        try:
            weight = float(weight)
        except TypeError:
            # Ignore crazy string method handling on Dictation elements; otherwise error
            if not (isinstance(obj, elements_.Dictation) and isinstance(weight, types.FunctionType)):
                self._log.error("%s: Weight must be a numeric, but %s %s is %s" % (self, obj, name, weight))
                import pdb; pdb.set_trace()
            weight = 1
        if weight <= 0:
            self._log.error("%s: Weight cannot be negative or 0, but %s %s is %s" % (self, obj, name, weight))
            weight = 1e-9
        return weight

    def add_weight_linkage(self, outer_src_state, dst_state, weight, fst):
        """ Returns new source state, to be used by the caller as the effective source state. Only modifies if weight is non-default. """
        if (weight is None) or (weight == 1):
            return outer_src_state
        # self._log.debug("%s: Adding weight linkage for weight=%s" % (self, weight))
        inner_src_state = fst.add_state()
        fst.add_arc(outer_src_state, inner_src_state, None, weight=weight)
        return inner_src_state