#!/usr/bin/env python 

from __future__ import division

import logging
from six import iteritems, add_metaclass
from six.moves import xrange
from abc import ABCMeta, abstractmethod
from collections import OrderedDict

import numpy as np

import theano 
import theano.tensor as T
from theano.printing import Print
from theano.tensor.shared_randomstreams import RandomStreams
from theano.sandbox.rng_mrg import MRG_RandomStreams

from learning.model import Model
from learning.utils.datalog  import dlog

_logger = logging.getLogger(__name__)

floatX = theano.config.floatX

theano.config.exception_verbosity = 'high'
theano_rng = MRG_RandomStreams(seed=2341)


def enumerate_pairs(start, end):
    """ Return the list of consecutive index pairs (i, i+1) for start <= i < end-1 """
    return [(i, i+1) for i in xrange(start, end-1)]
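
# Example: enumerate_pairs(0, 4) -> [(0, 1), (1, 2), (2, 3)]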

#=============================================================================

def f_replicate_batch(A, repeat):
    """Extend the given 2d Tensor by repeating reach line *repeat* times.

    With A.shape == (rows, cols), this function will return an array with
    shape (rows*repeat, cols).

    Parameters
    ----------
    A : T.tensor
        Each row of this 2d-Tensor will be replicated *repeat* times
    repeat : int

    Returns
    -------
    B : T.tensor
    """
    A_ = A.dimshuffle((0, 'x', 1))
    A_ = A_ + T.zeros((A.shape[0], repeat, A.shape[1]), dtype=floatX)
    A_ = A_.reshape( [A_.shape[0]*repeat, A.shape[1]] )
    return A_
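
# Example (shapes only, illustration): for A.shape == (2, 3) and repeat == 4,
# f_replicate_batch(A, 4) returns an (8, 3) tensor in which row 0 of A appears
# 4 times, followed by row 1 repeated 4 times.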

def f_logsumexp(A, axis=None):
    """Numerically stable log( sum( exp(A) ) ) """
    A_max = T.max(A, axis=axis, keepdims=True)
    B = T.log(T.sum(T.exp(A-A_max), axis=axis, keepdims=True))+A_max
    B = T.sum(B, axis=axis)
    return B
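
# NumPy reference for the identity used above (illustration only):
#   A_max = A.max(axis=axis, keepdims=True)
#   logsumexp = np.log(np.exp(A - A_max).sum(axis=axis)) + A.max(axis=axis)
# Subtracting A_max before exponentiating avoids overflow for large entries.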

#=============================================================================

@add_metaclass(ABCMeta)
class TopModule(Model):

    def __init__(self):
        super(TopModule, self).__init__()
        self.register_hyper_param('clamp_sigmoid', default=False)

    def setup(self):
        pass

    def sigmoid(self, x):
        """ Compute the element wise sigmoid function of x 

        Depending on the *clamp_sigmoid* hyperparameter, this might
        return a saturated sigmoid T.nnet.sigmoid(x)*0.9999 + 0.000005
        """
        if self.clamp_sigmoid:
            return T.nnet.sigmoid(x)*0.9999 + 0.000005
        else:
            return T.nnet.sigmoid(x)

    @abstractmethod
    def sample(self, n_samples):
        """ Sample from this toplevel module and return X ~ P(X), log(P(X))

        Parameters
        ----------
        n_samples:
            number of samples to draw

        Returns
        -------
        X:      T.tensor
            samples from this module
        log_p:  T.tensor
            log-probabilities for the samples returned in X
        """
        X, log_p = None, None
        return X, log_p

    @abstractmethod
    def log_prob(self, X):
        """ Calculate the log-probabilities for the samples in X 

        Parameters
        ----------
        X:      T.tensor
            samples to evaluate

        Returns
        -------
        log_p:  T.tensor
        """
        log_p = None
        return log_p
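
# Illustrative sketch (not part of the original code): a minimal TopModule
# with a fixed, factorized Bernoulli(0.5) prior over *n_X* binary units.
# It only relies on the interface documented above; a trainable top layer
# would additionally register its parameters with the Model machinery so
# that get_model_params() exposes them to the gradient computations below.
class ExampleBernoulliTop(TopModule):
    def __init__(self, **hyper_params):
        super(ExampleBernoulliTop, self).__init__()
        self.register_hyper_param('n_X', help='no. binary variables', default=8)
        self.set_hyper_params(hyper_params)

    def sample(self, n_samples):
        # Fixed activation probability of 0.5 for every unit
        prob = 0.5 * T.ones((n_samples, self.n_X), dtype=floatX)
        U = theano_rng.uniform((n_samples, self.n_X))
        X = T.cast(U <= prob, floatX)
        log_p = T.sum(X*T.log(prob) + (1.-X)*T.log(1.-prob), axis=1)
        return X, log_p

    def log_prob(self, X):
        prob = 0.5
        log_p = T.sum(X*T.log(prob) + (1.-X)*T.log(1.-prob), axis=1)
        return log_p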

@add_metaclass(ABCMeta)
class Module(Model):

    def __init__(self):
        super(Module, self).__init__()

        self.register_hyper_param('clamp_sigmoid', default=False)

    def setup(self):
        pass

    def sigmoid(self, x):
        """ Compute the element wise sigmoid function of x 

        Depending on the *clamp_sigmoid* hyperparameter, this might
        return a saturated sigmoid T.nnet.sigmoid(x)*0.9999 + 0.000005
        """
        if self.clamp_sigmoid:
            return T.nnet.sigmoid(x)*0.9999 + 0.000005
        else:
            return T.nnet.sigmoid(x)

    @abstractmethod
    def sample(self, Y):
        """ Given samples from the upper layer Y, sample values from X
            and return then together with their log probability.

        Parameters
        ----------
        Y:      T.tensor
            samples from the upper layer

        Returns
        -------
        X:      T.tensor
            samples from the lower layer
        log_p:  T.tensor
            conditional log-probability log P(X | Y) for the samples returned in X
        """
        X, log_p = None, None
        return X, log_p

    @abstractmethod
    def log_prob(self, X, Y):
        """ Evaluate the log-probability for the given samples.

        Parameters
        ----------
        Y:      T.tensor
            samples from the upper layer
        X:      T.tensor
            samples from the lower layer

        Returns
        -------
        log_p:  T.tensor
            log-probabilities for the samples in X and Y
        """
        log_p = None
        return log_p
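
# Illustrative sketch (not part of the original code): a minimal sigmoid
# belief network layer implementing the Module interface above.  Given
# upper-layer samples Y it parameterises factorized Bernoulli units with
# activation probability sigmoid(Y.W + b); set clamp_sigmoid=True to keep the
# log-probabilities finite.  For brevity W and b are plain shared variables
# created in setup(); a real layer would register them with the Model
# machinery so that get_model_params() exposes them to training.
class ExampleSBNModule(Module):
    def __init__(self, **hyper_params):
        super(ExampleSBNModule, self).__init__()
        self.register_hyper_param('n_X', help='no. lower-layer units', default=8)
        self.register_hyper_param('n_Y', help='no. upper-layer units', default=8)
        self.set_hyper_params(hyper_params)

    def setup(self):
        # Small random weights so the initial activation probabilities are ~0.5
        W_init = np.random.normal(0., 0.01, size=(self.n_Y, self.n_X)).astype(floatX)
        self.W = theano.shared(W_init, name='W')
        self.b = theano.shared(np.zeros(self.n_X, dtype=floatX), name='b')

    def sample(self, Y):
        prob_X = self.sigmoid(T.dot(Y, self.W) + self.b)
        U = theano_rng.uniform(size=prob_X.shape)
        X = T.cast(U <= prob_X, floatX)
        log_p = T.sum(X*T.log(prob_X) + (1.-X)*T.log(1.-prob_X), axis=1)
        return X, log_p

    def log_prob(self, X, Y):
        prob_X = self.sigmoid(T.dot(Y, self.W) + self.b)
        log_p = T.sum(X*T.log(prob_X) + (1.-X)*T.log(1.-prob_X), axis=1)
        return log_p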

#=============================================================================

class LayerStack(Model):
    def __init__(self, **hyper_params):     
        super(LayerStack, self).__init__()

        # Hyper parameters
        self.register_hyper_param('p_layers', help='STBP P layers', default=[])
        self.register_hyper_param('q_layers', help='STBP Q layers', default=[])
        self.register_hyper_param('n_samples', help='no. of samples to use', default=10)

        self.set_hyper_params(hyper_params)

    def setup(self):
        p_layers = self.p_layers
        q_layers = self.q_layers
        n_layers = len(p_layers)
        
        assert len(p_layers) == len(q_layers)+1
        assert isinstance(p_layers[-1], TopModule)

        
        self.n_X = p_layers[0].n_X

        for l in xrange(0, n_layers-1):
            assert isinstance(p_layers[l], Module)
            assert isinstance(q_layers[l], Module)
            assert p_layers[l].n_Y == p_layers[l+1].n_X
            assert p_layers[l].n_Y == q_layers[l].n_X 

            p_layers[l].setup()
            q_layers[l].setup()


        p_layers[-1].setup()

    def sample_p(self, n_samples):
        """ Draw *n_samples* drawn from the P-model.
        
        This method returns a list with the samples values on all layers
        and the correesponding log_p.
        """
        p_layers = self.p_layers
        q_layers = self.q_layers
        n_layers = len(p_layers)

        # Generate samples from the generative model
        samples = [None]*n_layers

        samples[-1], log_prob = p_layers[-1].sample(n_samples)
        for l in xrange(n_layers-1, 0, -1):
            samples[l-1], log_p_l = p_layers[l-1].sample(samples[l])
            log_prob += log_p_l
        
        return samples, log_prob

    def sample_q(self, X, Y=None):
        """ Given a set of observed X, samples from q(H | X) and calculate 
            both P(X, H) and Q(H | X)
        """
        p_layers = self.p_layers
        q_layers = self.q_layers
        n_layers = len(p_layers)

        size = X.shape[0]

        # Prepare input for layers
        samples = [None]*n_layers
        log_q   = [None]*n_layers
        log_p   = [None]*n_layers

        samples[0] = X
        log_q[0]   = T.zeros([size]) 

        # Generate samples (feed-forward)
        for l in xrange(n_layers-1):
            samples[l+1], log_q[l+1] = q_layers[l].sample(samples[l])
        
        # Get log_probs from generative model
        log_p[n_layers-1] = p_layers[n_layers-1].log_prob(samples[n_layers-1])
        for l in xrange(n_layers-1, 0, -1):
            log_p[l-1] = p_layers[l-1].log_prob(samples[l-1], samples[l])

        return samples, log_p, log_q
 
    def log_likelihood(self, X, Y=None, n_samples=None):
        p_layers = self.p_layers
        q_layers = self.q_layers
        n_layers = len(p_layers)

        if n_samples is None:
            n_samples = self.n_samples

        batch_size = X.shape[0]

        # Get samples
        X = f_replicate_batch(X, n_samples)
        samples, log_p, log_q = self.sample_q(X, None)

        # Reshape and sum
        log_p_all = T.zeros((batch_size, n_samples))
        log_q_all = T.zeros((batch_size, n_samples))
        for l in xrange(n_layers):
            samples[l] = samples[l].reshape((batch_size, n_samples, p_layers[l].n_X))
            log_q[l] = log_q[l].reshape((batch_size, n_samples))
            log_p[l] = log_p[l].reshape((batch_size, n_samples))
            log_p_all += log_p[l]   # aggregate all layers
            log_q_all += log_q[l]   # aggregate all layers

        # Approximate log P(X)
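        # Importance-sampling estimate with K = n_samples proposals per datum:
        #   log p(X) ~= logsumexp_k[ log p(X, h^(k)) - log q(h^(k) | X) ] - log(K)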
        log_px = f_logsumexp(log_p_all-log_q_all, axis=1) - T.log(n_samples)
        
        # Calculate sampling weights
        log_pq = (log_p_all-log_q_all-T.log(n_samples))
        w_norm = f_logsumexp(log_pq, axis=1)
        log_w = log_pq-T.shape_padright(w_norm)
        w = T.exp(log_w)
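        # w holds the self-normalized importance weights: along axis=1 (the
        # sample axis) they are non-negative and sum to one.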

        # Calculate KL(P|Q), Hp, Hq
        KL = [None]*n_layers
        Hp = [None]*n_layers
        Hq = [None]*n_layers
        for l in xrange(n_layers):
            KL[l] = T.sum(w*(log_p[l]-log_q[l]), axis=1)
            Hp[l] = f_logsumexp(log_w+log_p[l], axis=1)
            Hq[l] = T.sum(w*log_q[l], axis=1)

        return log_px, w, log_p_all, log_q_all, KL, Hp, Hq

    def get_gradients(self, X, Y, lr_p, lr_q, n_samples):
        """ return log_PX and an OrderedDict with parameter gradients """
        log_PX, w, log_p, log_q, KL, Hp, Hq = self.log_likelihood(X, Y, n_samples=n_samples)
        
        batch_log_PX = T.sum(log_PX)
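        # Wake-phase objectives: per-layer log-probabilities weighted by the
        # normalized importance weights w.  Because w is treated as a constant
        # when differentiating (consider_constant below), the resulting
        # gradients are the importance-weighted estimates used in
        # (reweighted) wake-sleep style training.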
        cost_p = T.sum(T.sum(log_p*w, axis=1))
        cost_q = T.sum(T.sum(log_q*w, axis=1))

        gradients = OrderedDict()
        for nl, layer in enumerate(self.p_layers):
            for name, shvar in iteritems(layer.get_model_params()):
                gradients[shvar] = lr_p[nl] * T.grad(cost_p, shvar, consider_constant=[w])

        for nl, layer in enumerate(self.q_layers):
            for name, shvar in iteritems(layer.get_model_params()):
                gradients[shvar] = lr_q[nl] * T.grad(cost_q, shvar, consider_constant=[w])

        return batch_log_PX, gradients

    def get_sleep_gradients(self, lr_s=1., n_dreams=100):
        """ Return the sleep-phase log_q and an OrderedDict with Q-parameter gradients """
        p_layers = self.p_layers
        q_layers = self.q_layers
        n_layers = len(p_layers)

        # Accept either a single scalar learning rate or one rate per Q layer
        if not isinstance(lr_s, (list, tuple, np.ndarray)):
            lr_s = [lr_s] * n_layers

        p, log_p = self.sample_p(n_dreams)
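        # p now holds "dream" samples drawn top-down from the generative model;
        # below, each Q layer is trained to invert its P counterpart by
        # maximising log q(h_upper | h_lower) on these dreams (sleep phase).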

        log_q = T.zeros((n_dreams,))
        for i, j in enumerate_pairs(0, n_layers):
            log_q += q_layers[i].log_prob(p[i+1], p[i])

        cost_q = T.sum(log_q)

        gradients = OrderedDict()
        for nl, layer in enumerate(self.q_layers):
            for name, shvar in iteritems(layer.get_model_params()):
                gradients[shvar] = lr_s[nl] * T.grad(cost_q, shvar)

        return log_q, gradients
        
    #------------------------------------------------------------------------
    def get_p_params(self):
        params = OrderedDict()
        for l in self.p_layers:
            params.update( l.get_model_params() )
        return params

    def get_q_params(self):
        params = OrderedDict()
        for l in self.q_layers:
            params.update( l.get_model_params() )
        return params

    def model_params_to_dict(self):
        vals = {}
        for n,l in enumerate(self.p_layers):
            for pname, shvar in iteritems(l.get_model_params()):
                key = "L%d.P.%s" % (n, pname)
                vals[key] = shvar.get_value()
        for n,l in enumerate(self.q_layers):                
            for pname, shvar in iteritems(l.get_model_params()):
                key = "L%d.Q.%s" % (n, pname)
                vals[key] = shvar.get_value()
        return vals

    def model_params_from_dict(self, vals):
        for n,l in enumerate(self.p_layers):
            for pname, shvar in iteritems(l.get_model_params()):
                key = "L%d.P.%s" % (n, pname)
                value = vals[key]
                shvar.set_value(value)
        for n,l in enumerate(self.q_layers):                
            for pname, shvar in iteritems(l.get_model_params()):
                key = "L%d.Q.%s" % (n, pname)
                value = vals[key]
                shvar.set_value(value)

    def model_params_to_dlog(self, dlog):
        vals = self.model_params_to_dict()
        dlog.append_all(vals)

    def model_params_from_dlog(self, dlog, row=-1):
        for n,l in enumerate(self.p_layers):
            for pname, shvar in iteritems(l.get_model_params()):
                key = "L%d.P.%s" % (n, pname)
                value = dlog.load(key)
                shvar.set_value(value)
        for n,l in enumerate(self.q_layers):                
            for pname, shvar in iteritems(l.get_model_params()):
                key = "L%d.Q.%s" % (n, pname)
                value = dlog.load(key)
                shvar.set_value(value)
 
    def model_params_from_h5(self, h5, row=-1, basekey="model."):
        for n,l in enumerate(self.p_layers):
            try:
                for pname, shvar in iteritems(l.get_model_params()):
                    key = "%sL%d.P.%s" % (basekey, n, pname)
                    value = h5[key][row]
                    shvar.set_value(value)
            except KeyError:
                if n >= len(self.p_layers)-2:
                    _logger.warning("Unable to load top P-layer params %s[%d]... continuing" % (key, row))
                    continue
                else:
                    _logger.error("Unable to load %s[%d] from %s" % (key, row, h5.filename))
                    raise

        for n,l in enumerate(self.q_layers):                
            try:
                for pname, shvar in iteritems(l.get_model_params()):
                    key = "%sL%d.Q.%s" % (basekey, n, pname)
                    value = h5[key][row]
                    shvar.set_value(value)
            except KeyError:
                if n == len(self.q_layers)-1:
                    _logger.warning("Unable to load top Q-layer params %s[%d]... continuing" % (key, row))
                    continue
                _logger.error("Unable to load %s[%d] from %s" % (key, row, h5.filename))
                raise
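
#=============================================================================
# Usage sketch (assumes the illustrative ExampleBernoulliTop / ExampleSBNModule
# classes defined above; a real model would use the library's trainable layer
# modules instead):
#
#   p0 = ExampleSBNModule(n_X=16, n_Y=8)    # p(x | h), 16 visible units
#   p1 = ExampleBernoulliTop(n_X=8)         # p(h),      8 latent units
#   q0 = ExampleSBNModule(n_X=8, n_Y=16)    # q(h | x)
#
#   stack = LayerStack(p_layers=[p0, p1], q_layers=[q0], n_samples=10)
#   stack.setup()
#
#   X = T.matrix('X')
#   log_px = stack.log_likelihood(X)[0]
#   f_ll = theano.function([X], log_px, allow_input_downcast=True)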