"""
Pystruct-compatible models.
"""

# Author: Vlad Niculae <vlad@vene.ro>
# License: BSD 3-clause

# AD3 is (c) Andre F. T. Martins, LGPLv3.0: http://www.cs.cmu.edu/~ark/AD3/

import warnings
import numpy as np

from sklearn.utils import compute_class_weight
from sklearn.utils.extmath import safe_sparse_dot
from sklearn.metrics import f1_score
from sklearn.preprocessing import LabelEncoder, label_binarize
from pystruct.models import StructuredModel

from marseille.inference import loss_augment_unaries, CDCP_ILLEGAL_LINKS
from marseille.argdoc import DocLabel
from marseille.custom_logging import logging

from itertools import permutations

from ad3 import factor_graph as fg


def _binary_2d(y):
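    """Ensure a label-indicator matrix has one column per class.

    For binary problems, label_binarize returns a single column (the
    positive class); stack its complement in front so downstream code
    can always index columns by class.
    """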
    if y.shape[1] == 1:
        y = np.column_stack([1 - y, y])
    return y


def arg_f1_scores(Y_true, Y_pred, **kwargs):
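    """Macro- and micro-averaged F1 over a collection of documents.

    The macro score averages one F1 per document; the micro score pools
    all labels across documents before scoring.  Extra kwargs are passed
    through to sklearn's f1_score.  For example, with a single document,
    arg_f1_scores([[1, 0]], [[1, 1]], average='binary') returns
    approximately (0.67, 0.67).
    """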

    macro = []
    micro_true = []
    micro_pred = []

    for y_true, y_pred in zip(Y_true, Y_pred):
        macro.append(f1_score(y_true, y_pred, **kwargs))
        micro_true.extend(y_true)
        micro_pred.extend(y_pred)

    return np.mean(macro), f1_score(micro_true, micro_pred, **kwargs)


class BaseArgumentMixin(object):

    def initialize_labels(self, Y):
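        """Fit label encoders and class weights from the training labels.

        Nodes (propositions) get uniform weights; link weights follow
        self.class_weight and are normalized so the smallest weight is 1.
        """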

        y_nodes_flat = [y_val for y in Y for y_val in y.nodes]
        y_links_flat = [y_val for y in Y for y_val in y.links]
        self.prop_encoder_ = LabelEncoder().fit(y_nodes_flat)
        self.link_encoder_ = LabelEncoder().fit(y_links_flat)

        self.n_prop_states = len(self.prop_encoder_.classes_)
        self.n_link_states = len(self.link_encoder_.classes_)

        self.prop_cw_ = np.ones_like(self.prop_encoder_.classes_,
                                     dtype=np.double)
        self.link_cw_ = compute_class_weight(self.class_weight,
                                             self.link_encoder_.classes_,
                                             y_links_flat)

        self.link_cw_ /= self.link_cw_.min()

        logging.info('Setting node class weights {}'.format(", ".join(
            "{}: {}".format(lbl, cw) for lbl, cw in zip(
                self.prop_encoder_.classes_, self.prop_cw_))))

        logging.info('Setting link class weights {}'.format(", ".join(
            "{}: {}".format(lbl, cw) for lbl, cw in zip(
                self.link_encoder_.classes_, self.link_cw_))))

    def _round(self, prop_marg, link_marg, prop_unary=None, link_unary=None,
               inverse_transform=True):
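        """Round (near-integral) marginals to a discrete DocLabel.

        If unary potentials are given, they are shifted and rescaled,
        then added to the nonzero marginals, so that argmax ties are
        broken in favor of higher unary scores.
        """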

        # ensure ties are broken according to unary scores
        if prop_unary is not None:
            prop_unary = prop_unary.copy()
            prop_unary -= np.min(prop_unary)
            prop_unary /= np.max(prop_unary) * np.max(prop_marg)
            prop_marg[prop_marg > 1e-9] += prop_unary[prop_marg > 1e-9]
        if link_unary is not None:
            link_unary = link_unary.copy()
            link_unary -= np.min(link_unary)
            link_unary /= np.max(link_unary) * np.max(link_marg)
            link_marg[link_marg > 1e-9] += link_unary[link_marg > 1e-9]

        y_hat_props = np.argmax(prop_marg, axis=1)
        y_hat_links = np.argmax(link_marg, axis=1)
        if inverse_transform:
            y_hat_props = self.prop_encoder_.inverse_transform(y_hat_props)
            y_hat_links = self.link_encoder_.inverse_transform(y_hat_links)
        return DocLabel(y_hat_props, y_hat_links)

    def loss(self, y, y_hat):
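        """Weighted Hamming loss between y and a discrete prediction.

        Delegates to continuous_loss when y_hat is a relaxed solution.
        """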

        if not isinstance(y_hat, DocLabel):
            return self.continuous_loss(y, y_hat)

        y_nodes = self.prop_encoder_.transform(y.nodes)
        y_links = self.link_encoder_.transform(y.links)

        node_loss = np.sum(self.prop_cw_[y_nodes] * (y.nodes != y_hat.nodes))
        link_loss = np.sum(self.link_cw_[y_links] * (y.links != y_hat.links))

        return node_loss + link_loss

    def max_loss(self, y):
        y_nodes = self.prop_encoder_.transform(y.nodes)
        y_links = self.link_encoder_.transform(y.links)
        return np.sum(self.prop_cw_[y_nodes]) + np.sum(self.link_cw_[y_links])

    def continuous_loss(self, y, y_hat):
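        """Weighted Hamming loss of y against fractional marginals."""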

        if isinstance(y_hat, DocLabel):
            raise ValueError("continuous loss on discrete input")

        if isinstance(y_hat[0], tuple):
            y_hat = y_hat[0]

        prop_marg, link_marg = y_hat
        y_nodes = self.prop_encoder_.transform(y.nodes)
        y_links = self.link_encoder_.transform(y.links)

        prop_ix = np.indices(y.nodes.shape)
        link_ix = np.indices(y.links.shape)

        # relies on prop_marg and link_marg summing to 1 row-wise
        prop_loss = np.sum(self.prop_cw_[y_nodes] *
                           (1 - prop_marg[prop_ix, y_nodes]))

        link_loss = np.sum(self.link_cw_[y_links] *
                           (1 - link_marg[link_ix, y_links]))

        loss = prop_loss + link_loss
        return loss

    def _marg_rounded(self, x, y):
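        """Expand a discrete DocLabel into indicator "marginals".

        Returns one-hot node and link matrices, the implied compatibility
        tensor, and indicator values for any active second-order factors.
        """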
        y_node = y.nodes
        y_link = y.links
        Y_node = label_binarize(y_node, self.prop_encoder_.classes_)
        Y_link = label_binarize(y_link, self.link_encoder_.classes_)

        # XXX can this be avoided?
        Y_node, Y_link = map(_binary_2d, (Y_node, Y_link))

        src_type = Y_node[x.link_to_prop[:, 0]]
        trg_type = Y_node[x.link_to_prop[:, 1]]

        if self.compat_features:
            pw = np.einsum('...j,...k,...l->...jkl',
                           src_type, trg_type, Y_link)
            compat = np.tensordot(x.X_compat.T, pw, axes=[1, 0])
        else:
            # equivalent to using compat features X_compat = np.ones((n_links, 1))

            compat = np.einsum('ij,ik,il->jkl', src_type, trg_type, Y_link)

        second_order = []

        if self.coparents_ or self.grandparents_ or self.siblings_:
            link = {(a, b): k for k, (a, b) in enumerate(x.link_to_prop)}
            if self.coparents_:
                second_order.extend(y_link[link[a, b]] & y_link[link[c, b]]
                                    for a, b, c in x.second_order)
            if self.grandparents_:
                second_order.extend(y_link[link[a, b]] & y_link[link[b, c]]
                                    for a, b, c in x.second_order)
            if self.siblings_:
                second_order.extend(y_link[link[b, a]] & y_link[link[b, c]]
                                    for a, b, c in x.second_order)
        second_order = np.array(second_order)

        return Y_node, Y_link, compat, second_order

    def _marg_fractional(self, x, y):
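        """Unpack the relaxed (fractional) marginals returned by AD3."""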
        (prop_marg, link_marg), (compat_marg, second_order_marg) = y

        if self.compat_features:
            compat_marg = np.tensordot(x.X_compat.T, compat_marg, axes=[1, 0])
        else:
            compat_marg = compat_marg.sum(axis=0)

        return prop_marg, link_marg, compat_marg, second_order_marg

    def _inference(self, x, potentials, exact=False, relaxed=True,
                   return_energy=False, constraints=None,
                   eta=0.1, adapt=True, max_iter=5000,
                   verbose=False):
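        """MAP inference over the argument factor graph with AD3.

        Builds multi-state variables for propositions and links, dense
        trigram compatibility factors, optional second-order (coparent,
        grandparent, sibling) pair factors, and the requested constraint
        set, then solves exactly (branch and bound) if `exact`, else the
        LP relaxation.  Returns the prediction (relaxed marginals or a
        rounded DocLabel), the solver status, and optionally the energy.
        """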

        (prop_potentials,
         link_potentials,
         compat_potentials,
         coparent_potentials,
         grandparent_potentials,
         sibling_potentials) = potentials

        n_props, n_prop_classes = prop_potentials.shape
        n_links, n_link_classes = link_potentials.shape

        g = fg.PFactorGraph()
        g.set_verbosity(verbose)

        prop_vars = [g.create_multi_variable(n_prop_classes)
                     for _ in range(n_props)]
        link_vars = [g.create_multi_variable(n_link_classes)
                     for _ in range(n_links)]

        for var, scores in zip(prop_vars, prop_potentials):
            for state, score in enumerate(scores):
                var.set_log_potential(state, score)

        for var, scores in zip(link_vars, link_potentials):
            for state, score in enumerate(scores):
                var.set_log_potential(state, score)

        # compatibility trigram factors
        compat_factors = []
        link_vars_dict = {}

        link_on, link_off = self.link_encoder_.transform([True, False])

        # account for compat features
        if self.compat_features:
            assert compat_potentials.shape[0] == n_links
            compats = compat_potentials
        else:
            compats = (compat_potentials for _ in range(n_links))

        for (src, trg), link_v, compat in zip(x.link_to_prop,
                                              link_vars,
                                              compats):
            src_v = prop_vars[src]
            trg_v = prop_vars[trg]
            compat_factors.append(g.create_factor_dense([src_v, trg_v, link_v],
                                                        compat.ravel()))

            # keep track of binary link variables, for constraints.
            # we need .get_state() to get the underlying PBinaryVariable
            link_vars_dict[src, trg] = link_v.get_state(link_on)

        # second-order factors
        coparent_factors = []
        grandparent_factors = []
        sibling_factors = []

        for score, (a, b, c) in zip(coparent_potentials, x.second_order):
            # a -> b <- c
            vars = [link_vars_dict[a, b], link_vars_dict[c, b]]
            coparent_factors.append(g.create_factor_pair(vars, score))

        for score, (a, b, c) in zip(grandparent_potentials, x.second_order):
            # a -> b -> c
            vars = [link_vars_dict[a, b], link_vars_dict[b, c]]
            grandparent_factors.append(g.create_factor_pair(vars, score))

        for score, (a, b, c) in zip(sibling_potentials, x.second_order):
            # a <- b -> c
            vars = [link_vars_dict[b, a], link_vars_dict[b, c]]
            sibling_factors.append(g.create_factor_pair(vars, score))

        # domain-specific constraints
        if constraints and 'cdcp' in constraints:

            # antisymmetry: if a -> b, then b cannot -> a
            for src in range(n_props):
                for trg in range(src):
                    fwd_link_v = link_vars_dict[src, trg]
                    rev_link_v = link_vars_dict[trg, src]
                    g.create_factor_logic('ATMOSTONE',
                                          [fwd_link_v, rev_link_v],
                                          [False, False])

            # transitivity.
            # forall a != b != c: a->b and b->c imply a->c
            for a, b, c in permutations(range(n_props), 3):
                ab_link_v = link_vars_dict[a, b]
                bc_link_v = link_vars_dict[b, c]
                ac_link_v = link_vars_dict[a, c]
                g.create_factor_logic('IMPLY',
                                      [ab_link_v, bc_link_v, ac_link_v],
                                      [False, False, False])

            # strict model: disallow links between illegal type pairs
            if 'strict' in constraints:
                for src, trg in x.link_to_prop:
                    src_v = prop_vars[src]
                    trg_v = prop_vars[trg]

                    for types in CDCP_ILLEGAL_LINKS:
                        src_ix, trg_ix = self.prop_encoder_.transform(types)
                        g.create_factor_logic('IMPLY',
                                              [src_v.get_state(src_ix),
                                               trg_v.get_state(trg_ix),
                                               link_vars_dict[src, trg]],
                                              [False, False, True])

        elif constraints and 'ukp' in constraints:

            # Tree constraints using AD3 MST factor for each paragraph.
            # First, identify paragraphs
            prop_para = np.array(x.prop_para)
            link_para = prop_para[x.link_to_prop[:, 0]]
            tree_factors = []

            for para_ix in np.unique(link_para):
                props = np.where(prop_para == para_ix)[0]
                offset = props.min()
                para_vars = []
                para_arcs = []  # call them arcs, semantics differ from links

                # add an arc (with a fresh binary variable) from the
                # artificial root node 0 to every proposition
                for relative_ix, prop_ix in enumerate(props, 1):
                    para_vars.append(g.create_binary_variable())
                    para_arcs.append((0, relative_ix))

                # add an MST arc for each link
                for src, trg in x.link_to_prop[link_para == para_ix]:
                    relative_src = src - offset + 1
                    relative_trg = trg - offset + 1
                    para_vars.append(link_vars_dict[src, trg])
                    # MST arcs have opposite direction from argument links!
                    # because each prop can have multiple supports but not
                    # the other way around
                    para_arcs.append((relative_trg, relative_src))

                tree = fg.PFactorTree()
                g.declare_factor(tree, para_vars, True)
                tree.initialize(1 + len(props), para_arcs)
                tree_factors.append(tree)

            if 'strict' in constraints:
                # further domain-specific constraints
                mclaim_ix, claim_ix, premise_ix = self.prop_encoder_.transform(
                    ['MajorClaim', 'Claim', 'Premise'])

                # a -> b implies a = 'premise'
                for (src, trg), link_v in zip(x.link_to_prop, link_vars):
                    src_v = prop_vars[src]
                    g.create_factor_logic('IMPLY',
                                          [link_v.get_state(link_on),
                                           src_v.get_state(premise_ix)],
                                          [False, False])

        g.fix_multi_variables_without_factors()
        g.set_eta_ad3(eta)
        g.adapt_eta_ad3(adapt)
        g.set_max_iterations_ad3(max_iter)

        if exact:
            val, posteriors, additionals, status = g.solve_exact_map_ad3()
        else:
            val, posteriors, additionals, status = g.solve_lp_map_ad3()

        status = ["integer", "fractional", "infeasible", "not solved"][status]

        prop_marg = posteriors[:n_props * n_prop_classes]
        prop_marg = np.array(prop_marg).reshape(n_props, -1)

        link_marg = posteriors[n_props * n_prop_classes:]
        # remaining posteriors are for artificial root nodes for MST factors
        link_marg = link_marg[:n_links * n_link_classes]
        link_marg = np.array(link_marg).reshape(n_links, -1)

        n_compat = n_links * n_link_classes * n_prop_classes ** 2
        compat_marg = additionals[:n_compat]
        compat_marg = np.array(compat_marg).reshape((n_links,
                                                     n_prop_classes,
                                                     n_prop_classes,
                                                     n_link_classes))

        second_order_marg = np.array(additionals[n_compat:])

        posteriors = (prop_marg, link_marg)
        additionals = (compat_marg, second_order_marg)

        if relaxed:
            y_hat = posteriors, additionals
        else:
            y_hat = self._round(prop_marg, link_marg, prop_potentials,
                                link_potentials)

        if return_energy:
            return y_hat, status, -val
        else:
            return y_hat, status

    def _score(self, Y_true, Y_pred):
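        """Whole-structure accuracy plus macro/micro link and node F1."""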

        acc = sum(1 for y_true, y_pred in zip(Y_true, Y_pred)
                  if np.all(y_true.links == y_pred.links) and
                  np.all(y_true.nodes == y_pred.nodes))
        acc /= len(Y_true)

        with warnings.catch_warnings():
            warnings.simplefilter('ignore')
            link_macro, link_micro = arg_f1_scores(
                (y.links for y in Y_true),
                (y.links for y in Y_pred),
                average='binary',
                pos_label=True,
                labels=self.link_encoder_.classes_
            )

            node_macro, node_micro = arg_f1_scores(
                (y.nodes for y in Y_true),
                (y.nodes for y in Y_pred),
                average='macro',
                labels=self.prop_encoder_.classes_
            )

        return link_macro, link_micro, node_macro, node_micro, acc


class ArgumentGraphCRF(BaseArgumentMixin, StructuredModel):
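    """Joint CRF over the propositions and links of an argument document.

    Unary potentials score proposition types and link presence; dense
    trigram factors score (source type, target type, link state)
    combinations; optional second-order factors and hard constraints
    encode domain structure.  Implements the pystruct StructuredModel
    interface, so it can be trained with any pystruct SSVM learner.

    Illustrative usage (a sketch only; `docs` and `labels` stand for
    vectorized documents and DocLabel sequences in the format described
    in `initialize`):

        from pystruct.learners import FrankWolfeSSVM

        crf = ArgumentGraphCRF(class_weight='balanced',
                               constraints='cdcp')
        learner = FrankWolfeSSVM(model=crf, C=0.1, max_iter=100)
        learner.fit(docs, labels)
        Y_pred = learner.predict(docs)
    """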
    def __init__(self, class_weight=None, link_node_weight_ratio=1,
                 exact=False, constraints=None, compat_features=False,
                 coparents=False, grandparents=False, siblings=False):
        self.class_weight = class_weight
        self.link_node_weight_ratio = link_node_weight_ratio
        self.exact = exact
        self.constraints = constraints
        self.compat_features = compat_features
        self.coparents = coparents
        self.grandparents = grandparents
        self.siblings = siblings

        self.n_second_order_factors_ = coparents + grandparents + siblings
        self.n_prop_states = None
        self.n_link_states = None
        self.n_prop_features = None
        self.n_link_features = None

        self.n_second_order_features_ = None
        self.n_compat_features_ = None
        self.inference_calls = 0

        super(ArgumentGraphCRF, self).__init__()

    def initialize(self, X, Y):
        # each x in X is a vectorized doc exposing scipy.sparse CSR matrices
        # x.X_prop and x.X_link, and optionally x.X_compat and x.X_sec_ord;
        # each y in Y exposes label lists y.nodes and y.links

        x = X[0]
        self.n_prop_features = x.X_prop.shape[1]
        self.n_link_features = x.X_link.shape[1]

        if self.compat_features:
            self.n_compat_features_ = x.X_compat.shape[1]

        if self.n_second_order_factors_:
            self.n_second_order_features_ = x.X_sec_ord.shape[1]
        else:
            self.n_second_order_features_ = 0

        self.initialize_labels(Y)
        self._set_size_joint_feature()

        self.coparents_ = self.coparents
        self.grandparents_ = self.grandparents
        self.siblings_ = self.siblings

    def _set_size_joint_feature(self):  # includes compat and second-order blocks
        compat_size = self.n_prop_states ** 2 * self.n_link_states
        if self.compat_features:
            compat_size *= self.n_compat_features_

        total_n_second_order = (self.n_second_order_features_ *
                                self.n_second_order_factors_)

        self.size_joint_feature = (self.n_prop_features * self.n_prop_states +
                                   self.n_link_features * self.n_link_states +
                                   compat_size + total_n_second_order)

        logging.info("Joint feature size: {}".format(self.size_joint_feature))

    def joint_feature(self, x, y):
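        """Joint feature (sufficient statistics) vector of (x, y).

        Accepts either a discrete DocLabel or relaxed marginals, and
        concatenates the node, link, compatibility and second-order
        blocks in the same layout that _get_potentials unpacks.
        """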

        if isinstance(y, DocLabel):
            Y_prop, Y_link, compat, second_order = self._marg_rounded(x, y)
        else:
            Y_prop, Y_link, compat, second_order = self._marg_fractional(x, y)

        prop_acc = safe_sparse_dot(Y_prop.T, x.X_prop)  # node_cls * node_feats
        link_acc = safe_sparse_dot(Y_link.T, x.X_link)  # link_cls * link_feats

        f_sec_ord = []

        if len(second_order):
            second_order = second_order.reshape(-1, len(x.second_order))
            if self.coparents:
                f_sec_ord.append(safe_sparse_dot(second_order[0], x.X_sec_ord))
                second_order = second_order[1:]

            if self.grandparents:
                f_sec_ord.append(safe_sparse_dot(second_order[0], x.X_sec_ord))
                second_order = second_order[1:]

            if self.siblings:
                f_sec_ord.append(safe_sparse_dot(second_order[0], x.X_sec_ord))

        elif self.n_second_order_factors_:
            # document has no second order factors so the joint feature
            # must be filled with zeros manually
            f_sec_ord = [np.zeros(self.n_second_order_features_)
                         for _ in range(self.n_second_order_factors_)]

        jf = np.concatenate([prop_acc.ravel(), link_acc.ravel(),
                             compat.ravel()] + f_sec_ord)

        return jf

    # inverse of joint_feature: unpack the flat weight vector into potentials
    def _get_potentials(self, x, w):
        # check sizes?
        n_node_coefs = self.n_prop_states * self.n_prop_features
        n_link_coefs = self.n_link_states * self.n_link_features
        n_compat_coefs = self.n_prop_states ** 2 * self.n_link_states

        if self.compat_features:
            n_compat_coefs *= self.n_compat_features_

        assert w.size == (n_node_coefs + n_link_coefs + n_compat_coefs +
                          self.n_second_order_features_ *
                          self.n_second_order_factors_)

        w_node = w[:n_node_coefs]

        w_node = w_node.reshape(self.n_prop_states, self.n_prop_features)

        w_link = w[n_node_coefs:n_node_coefs + n_link_coefs]
        w_link = w_link.reshape(self.n_link_states, self.n_link_features)

        # for readability, consume w by slicing; this only rebinds the
        # local name and never modifies the caller's array
        w = w[n_node_coefs + n_link_coefs:]
        w_compat = w[:n_compat_coefs]

        if self.compat_features:
            w_compat = w_compat.reshape((self.n_compat_features_, -1))
            w_compat = np.dot(x.X_compat, w_compat)
            compat_potentials = w_compat.reshape((-1,
                                                  self.n_prop_states,
                                                  self.n_prop_states,
                                                  self.n_link_states))
        else:
            compat_potentials = w_compat.reshape(self.n_prop_states,
                                                 self.n_prop_states,
                                                 self.n_link_states)

        w = w[n_compat_coefs:]

        coparent_potentials = grandparent_potentials = sibling_potentials = []

        if self.coparents:
            w_coparent = w[:self.n_second_order_features_]
            coparent_potentials = safe_sparse_dot(x.X_sec_ord, w_coparent)
            w = w[self.n_second_order_features_:]

        if self.grandparents:
            w_grandparent = w[:self.n_second_order_features_]
            grandparent_potentials = safe_sparse_dot(x.X_sec_ord,
                                                     w_grandparent)
            w = w[self.n_second_order_features_:]

        if self.siblings:
            w_sibling = w[:self.n_second_order_features_]
            sibling_potentials = safe_sparse_dot(x.X_sec_ord, w_sibling)

        prop_potentials = safe_sparse_dot(x.X_prop, w_node.T)
        link_potentials = safe_sparse_dot(x.X_link, w_link.T)

        return (prop_potentials, link_potentials, compat_potentials,
                coparent_potentials, grandparent_potentials,
                sibling_potentials)

    def inference(self, x, w, relaxed=False, return_energy=False):
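        """Predict the highest-scoring structure for x under weights w."""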

        self.inference_calls += 1

        potentials = self._get_potentials(x, w)
        out = self._inference(x, potentials, exact=self.exact,
                              relaxed=relaxed, return_energy=return_energy,
                              constraints=self.constraints)
        if return_energy:
            return out[0], out[-1]
        else:
            return out[0]

    def loss_augmented_inference(self, x, y, w, relaxed=None):
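        """Inference with loss-augmented unary potentials.

        Node and link unaries are augmented by the class-weighted
        Hamming loss, so that inference returns the most violating
        structure, as required for SSVM training.
        """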

        self.inference_calls += 1

        potentials = self._get_potentials(x, w)

        (prop_potentials,
         link_potentials,
         compat_potentials,
         coparent_potentials,
         grandparent_potentials,
         sibling_potentials) = potentials

        y_prop = self.prop_encoder_.transform(y.nodes)
        y_link = self.link_encoder_.transform(y.links)
        loss_augment_unaries(prop_potentials, y_prop, self.prop_cw_)
        loss_augment_unaries(link_potentials, y_link, self.link_cw_)

        potentials = (prop_potentials,
                      link_potentials,
                      compat_potentials,
                      coparent_potentials,
                      grandparent_potentials,
                      sibling_potentials)

        out = self._inference(x, potentials, exact=self.exact,
                              relaxed=relaxed, constraints=self.constraints)
        return out[0]