# -*- coding: utf-8 -*-

# Copyright (c) 2015-2016 MIT Probabilistic Computing Project

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#    http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import itertools

from math import isinf

from cgpm.network import helpers as hu
from cgpm.utils import general as gu


class ImportanceNetwork(object):
    """Querier for a Composite CGpm.

    Answers `simulate` and `logpdf` queries over a collection of cgpms by
    drawing `accuracy` weighted samples per query (see `weighted_sample`),
    visiting the cgpms in the order given by `self.topo`.
    """

    def __init__(self, cgpms, accuracy=None, rng=None):
        """Validate `cgpms` and precompute the network structure.

        cgpms -- collection of cgpms, passed through `hu.validate_cgpms`.
        accuracy -- number of weighted samples drawn per query (None means 1).
        rng -- random state handed to `gu.log_pflip`; a falsy value is
            replaced by `gu.gen_rng(1)`.
        """
        if accuracy is None:
            accuracy = 1
        self.rng = rng if rng else gu.gen_rng(1)
        self.cgpms = hu.validate_cgpms(cgpms)
        self.accuracy = accuracy
        # Structure derived from the cgpms by helpers in cgpm.network.helpers;
        # `topo` is used below as an ordered sequence of keys into self.cgpms.
        self.v_to_c = hu.retrieve_variable_to_cgpm(self.cgpms)
        self.adjacency = hu.retrieve_adjacency_list(self.cgpms, self.v_to_c)
        self.extraneous = hu.retrieve_extraneous_inputs(self.cgpms, self.v_to_c)
        self.topo = hu.topological_sort(self.adjacency)

    @gu.simulate_many
    def simulate(self, rowid, targets, constraints=None, inputs=None, N=None):
        """Return a dict {target: value} taken from one weighted sample.

        Draws `self.accuracy` weighted samples and selects one of them with
        `gu.log_pflip` over their log weights. `N` is not read in this body;
        presumably it is consumed by the `gu.simulate_many` decorator --
        TODO confirm.

        Raises ValueError if every sample has an infinite log weight.
        """
        if constraints is None:
            constraints = {}
        if inputs is None:
            inputs = {}
        samples, weights = zip(*[
            self.weighted_sample(rowid, targets, constraints, inputs)
            for _i in range(self.accuracy)
        ])
        if all(isinf(l) for l in weights):
            raise ValueError('Zero density constraints: %s' % (constraints,))
        # Skip an expensive random choice if there is only one option.
        if self.accuracy == 1:
            index = 0
        else:
            index = gu.log_pflip(weights, rng=self.rng)
        chosen = samples[index]
        return {q: chosen[q] for q in targets}

    def logpdf(self, rowid, targets, constraints=None, inputs=None):
        """Return an estimate of the log density of targets given constraints.

        The estimate is logmeanexp of the weights obtained by constraining on
        `gu.merged(targets, constraints)`, minus logmeanexp of the weights
        obtained by constraining on `constraints` alone (taken to be 0 when
        there are no constraints).

        Raises ValueError if every constraints-only weight is infinite.
        """
        if constraints is None:
            constraints = {}
        if inputs is None:
            inputs = {}
        # Compute joint probability; only the weights are needed.
        _samples, weights_joint = zip(*[
            self.weighted_sample(
                rowid, [], gu.merged(targets, constraints), inputs)
            for _i in range(self.accuracy)
        ])
        logp_joint = gu.logmeanexp(weights_joint)
        # Compute marginal probability; trivially log(1) = 0 if unconstrained.
        if constraints:
            _samples, weights_marginal = zip(*[
                self.weighted_sample(rowid, [], constraints, inputs)
                for _i in range(self.accuracy)
            ])
        else:
            weights_marginal = [0.]
        if all(isinf(l) for l in weights_marginal):
            raise ValueError('Zero density constraints: %s' % (constraints,))
        logp_constraints = gu.logmeanexp(weights_marginal)
        # Return log ratio.
        return logp_joint - logp_constraints

    def weighted_sample(self, rowid, targets, constraints, inputs):
        """Return (sample, weight) for one pass over the network.

        `sample` is a dict holding the constraints plus a simulated value for
        every target and for every variable reported by
        `retrieve_required_inputs`; `weight` is the sum of the per-cgpm
        values returned by `invoke_cgpm`.
        """
        targets_required = self.retrieve_required_inputs(targets, constraints)
        # list() lets callers pass any iterable of targets, not only a list.
        targets_all = list(targets) + targets_required
        sample = dict(constraints)
        weight = 0
        for l in self.topo:
            # The running sample is passed as the constraints, so values
            # produced by earlier cgpms are visible to later ones.
            sl, wl = self.invoke_cgpm(
                rowid, self.cgpms[l], targets_all, sample, inputs)
            sample.update(sl)
            weight += wl
        assert set(sample) == set(constraints).union(targets_all)
        return sample, weight

    def invoke_cgpm(self, rowid, cgpm, targets, constraints, inputs):
        """Return (sample, weight) contributed by a single cgpm.

        `weight` is `cgpm.logpdf` of the constraints that are outputs of this
        cgpm (0 if there are none); `sample` is `cgpm.simulate` of the targets
        that are outputs of this cgpm (empty dict if there are none).
        """
        # Values for this cgpm's inputs; on a duplicate key the entry from
        # `constraints` wins because it is chained last.
        cgpm_inputs = {
            e: x for e, x in
                itertools.chain(inputs.items(), constraints.items())
            if e in cgpm.inputs
        }
        cgpm_constraints = {
            e: x for e, x in constraints.items()
            if e in cgpm.outputs
        }
        cgpm_targets = [q for q in targets if q in cgpm.outputs]
        if cgpm_constraints or cgpm_targets:
            # A cgpm that takes part in the query needs all of its inputs.
            assert all(i in cgpm_inputs for i in cgpm.inputs)
        weight = 0
        if cgpm_constraints:
            weight = cgpm.logpdf(
                rowid,
                targets=cgpm_constraints,
                constraints=None,
                inputs=cgpm_inputs)
        sample = {}
        if cgpm_targets:
            sample = cgpm.simulate(
                rowid,
                targets=cgpm_targets,
                constraints=cgpm_constraints,
                inputs=cgpm_inputs)
        return sample, weight

    def retrieve_required_inputs(self, targets, constraints):
        """Return list of inputs required to answer query.

        Walks the cgpms in reverse `self.topo` order; whenever a cgpm has an
        output that is required so far or constrained, all of its inputs
        become required too. The result excludes anything already in
        `targets`, `constraints` or `self.extraneous`.
        """
        required_all = set(targets)
        for l in reversed(self.topo):
            cgpm = self.cgpms[l]
            active = any(
                o in required_all or o in constraints for o in cgpm.outputs)
            if active:
                required_all.update(cgpm.inputs)
        return [
            c for c in required_all
            if c not in targets
            and c not in constraints
            and c not in self.extraneous
        ]