from __future__ import division

import numpy as np

import copy
import inspect
import types as python_types
import marshal
import sys
import warnings

from keras import activations, initializations, regularizers, constraints
from keras import backend as K
from keras.engine import InputSpec, Layer
from keras.layers.core import Dense, Flatten
from keras_extensions.initializations import glorot_uniform_sigm
from keras_extensions.activations import nrlu

class RBM(Layer):

	def __init__(self, hidden_dim, init='glorot_uniform',
		weights=None,
		Wrbm_regularizer=None, bx_regularizer=None, bh_regularizer=None,
		activity_regularizer=None,
		Wrbm_constraint=None, bx_constraint=None, bh_constraint=None,
		input_dim=None, nb_gibbs_steps=1, persistent=False, batch_size=1,
		scaling_h_given_x=1.0, scaling_x_given_h=1.0,
		dropout=0.0,
		hidden_unit_type='binary',
		visible_unit_type='binary',
		Wrbm=None, bh=None, bx=None,
		**kwargs):
		"""Configure the RBM and create (or adopt) its parameters.

		# Arguments
			hidden_dim: number of hidden units (second axis of `Wrbm`,
				length of `bh`).
			init: identifier resolved with `initializations.get`; used as
				the initializer of `Wrbm` when it is created here.
			weights: stored as `self.initial_weights` for later use.
			Wrbm_regularizer, bx_regularizer, bh_regularizer,
			activity_regularizer: identifiers resolved with
				`regularizers.get`.
			Wrbm_constraint, bx_constraint, bh_constraint: identifiers
				resolved with `constraints.get`.
			input_dim: number of visible units. When truthy it is also
				forwarded to the base class as `input_shape=(input_dim,)`.
				NOTE(review): the weights and the persistent chain are
				sized from `input_dim` right here, so it looks effectively
				required unless `Wrbm` and `bx` are supplied - confirm.
			nb_gibbs_steps, persistent: honoured only when
				`hidden_unit_type` is neither 'softmax' nor 'nrlu';
				for those two types they are forced to 1 and False.
			batch_size: number of rows of the persistent chain buffer.
			scaling_h_given_x, scaling_x_given_h: stored as attributes of
				the same name.
			dropout: stored as `self.p`; a value strictly between 0 and 1
				marks the layer as using the learning phase.
			hidden_unit_type: 'softmax' selects the softmax activation,
				'nrlu' selects relu, anything else selects sigmoid.
			visible_unit_type: stored as `self.visible_unit_type`.
			Wrbm, bh, bx: existing variables to use instead of creating
				new weights (e.g. to share parameters).
			**kwargs: forwarded to `Layer.__init__`.
		"""

		self.p = dropout
		if 0.0 < self.p < 1.0:
			self.uses_learning_phase = True
			self.supports_masking = True

		if hidden_unit_type == 'softmax':
			activation = 'softmax'
			self.is_persistent = False
			self.nb_gibbs_steps = 1
		elif hidden_unit_type == 'nrlu':
			activation = 'relu'
			self.is_persistent = False
			self.nb_gibbs_steps = 1
		else:
			activation = 'sigmoid'
			self.is_persistent = persistent
			self.nb_gibbs_steps = nb_gibbs_steps

		self.updates = []
		self.init = initializations.get(init)
		self.activation = activations.get(activation)
		self.hidden_dim = hidden_dim
		self.input_dim = input_dim
		self.batch_size = batch_size
		self.hidden_unit_type = hidden_unit_type
		self.visible_unit_type = visible_unit_type

		self.scaling_h_given_x = scaling_h_given_x
		self.scaling_x_given_h = scaling_x_given_h

		self.Wrbm_regularizer = regularizers.get(Wrbm_regularizer)
		self.bx_regularizer = regularizers.get(bx_regularizer)
		self.bh_regularizer = regularizers.get(bh_regularizer)
		self.activity_regularizer = regularizers.get(activity_regularizer)

		self.Wrbm_constraint = constraints.get(Wrbm_constraint)
		self.bx_constraint = constraints.get(bx_constraint)
		self.bh_constraint = constraints.get(bh_constraint)

		self.initial_weights = weights
		self.input_spec = [InputSpec(ndim='2+')]

		if self.input_dim:
			kwargs['input_shape'] = (self.input_dim,)

		super(RBM, self).__init__(**kwargs)

		# Identity checks (`is None`) are required here: a supplied array or
		# tensor may define `==` elementwise, which must not be used as the
		# "was this argument given?" test.
		if Wrbm is None:
			self.Wrbm = self.add_weight((input_dim, self.hidden_dim),
									initializer=self.init,
									name='{}_Wrbm'.format(self.name),
									regularizer=self.Wrbm_regularizer,
									constraint=self.Wrbm_constraint)
		else:
			self.Wrbm = Wrbm

		if bx is None:
			self.bx = self.add_weight((self.input_dim,),
								initializer='zero',
								name='{}_bx'.format(self.name),
								regularizer=self.bx_regularizer,
								constraint=self.bx_constraint)
		else:
			self.bx = bx

		if bh is None:
			self.bh = self.add_weight((self.hidden_dim,),
								initializer='zero',
								name='{}_bh'.format(self.name),
								regularizer=self.bh_regularizer,
								constraint=self.bh_constraint)
		else:
			self.bh = bh

		if self.is_persistent:
			# Zero-initialised state of the persistent Gibbs chain.
			self.persistent_chain = K.variable(np.zeros((self.batch_size, self.input_dim),
												dtype=K.floatx()))

	def _get_noise_shape(self, x):
		"""Return the dropout noise shape for `x`; this layer always uses None."""
		noise_shape = None
		return noise_shape


	def build(self, input_shape):
		"""Finalize the layer for a 2-D input shape.

		Resets `input_spec` to the backend float type and, when initial
		weights were given to the constructor, loads them with
		`set_weights`. The `initial_weights` attribute is removed afterwards.
		"""
		assert len(input_shape) == 2
		self.input_spec = [InputSpec(dtype=K.floatx(), ndim='2+')]

		pending = self.initial_weights
		if pending is not None:
			self.set_weights(pending)
		del self.initial_weights

	def call(self, x, mask=None):
		"""Forward pass: return the input multiplied by 1.0; `mask` is unused."""
		passthrough = 1.0 * x
		return passthrough

	def get_output_shape_for(self, input_shape):
		"""Return `(input_shape[0], self.input_dim)` as the output shape."""
		leading_dim = input_shape[0]
		return (leading_dim, self.input_dim)

	def get_config(self):
		"""Return the layer configuration as a plain dict.

		The RBM-specific keys use the constructor's keyword names
		(`hidden_dim`, `persistent`, `dropout`, ...) so the result can be
		passed back to the constructor. They are merged on top of the base
		`Layer` configuration.

		# Returns
			dict combining the base-class config with the RBM settings.
		"""
		def _sub_config(obj):
			# Regularizers/constraints are optional; serialize only if set.
			return obj.get_config() if obj else None

		config = {
			'hidden_dim': self.hidden_dim,
			'init': self.init.__name__,
			'Wrbm_regularizer': _sub_config(self.Wrbm_regularizer),
			'bx_regularizer': _sub_config(self.bx_regularizer),
			'bh_regularizer': _sub_config(self.bh_regularizer),
			'activity_regularizer': _sub_config(self.activity_regularizer),
			'Wrbm_constraint': _sub_config(self.Wrbm_constraint),
			'bx_constraint': _sub_config(self.bx_constraint),
			'bh_constraint': _sub_config(self.bh_constraint),
			'input_dim': self.input_dim,
			'nb_gibbs_steps': self.nb_gibbs_steps,
			'persistent': self.is_persistent,
			'batch_size': self.batch_size,
			'scaling_h_given_x': self.scaling_h_given_x,
			'scaling_x_given_h': self.scaling_x_given_h,
			'dropout': self.p,
			'hidden_unit_type': self.hidden_unit_type,
			'visible_unit_type': self.visible_unit_type,
		}

		# Must name RBM here: super(Dense, self) raises TypeError because
		# RBM is not a subclass of Dense.
		base_config = super(RBM, self).get_config()
		return dict(list(base_config.items()) + list(config.items()))

    # -------------
    # RBM internals
    # -------------

	def free_energy(self, x):
		"""Compute the free energy of each row of `x`.

		With `wx_b = x . Wrbm + bh` the result is, per row,
		`-sum(softplus(wx_b)) + 0.5 * sum((x - bx)**2)` for gaussian visible
		units and `-sum(softplus(wx_b)) - x . bx` otherwise.

		# Arguments
			x: 2-D tensor of visible configurations (reduced over axis 1).

		# Returns
			Tensor with one free-energy value per row of `x`.
		"""
		wx_b = K.dot(x, self.Wrbm) + self.bh

		# softplus(z) == log(1 + exp(z)) but does not overflow for large z.
		hidden_term = K.sum(K.softplus(wx_b), axis=1)

		if self.visible_unit_type == 'gaussian':
			vbias_term = 0.5 * K.sum((x - self.bx) ** 2, axis=1)
			return -hidden_term + vbias_term

		# Row-wise dot product with the 1-D visible bias, written as a
		# broadcasted multiply-and-sum so it does not rely on how the
		# backend's dot handles a 1-D operand.
		vbias_term = K.sum(x * self.bx, axis=1)
		return -hidden_term - vbias_term


	def sample_h_given_x(self, x):
		"""Sample hidden units given visible units `x`.

		# Arguments
			x: tensor of visible configurations.

		# Returns
			Tuple `(h_samp, h_pre, h_sigm)`:
			`h_pre` is `x . Wrbm + bh`;
			`h_sigm` is the layer activation applied to
			`scaling_h_given_x * h_pre`;
			`h_samp` is a Bernoulli sample with probability `h_sigm` for
			'binary' hidden units, `nrlu(h_pre)` for 'nrlu' units (note:
			the unscaled pre-activation is used), and `h_sigm` itself
			otherwise. When `0 < self.p < 1`, dropout is applied to
			`h_samp` in the training phase.
		"""
		h_pre = K.dot(x, self.Wrbm) + self.bh
		h_sigm = self.activation(self.scaling_h_given_x * h_pre)

		if self.hidden_unit_type == 'binary':
			# Bernoulli draw: h = 1 with probability p(h=1|x), else 0.
			# K.shape gives the symbolic shape on every backend, unlike
			# the tensor's own `.shape` attribute.
			h_samp = K.random_binomial(shape=K.shape(h_sigm), p=h_sigm)
		elif self.hidden_unit_type == 'nrlu':
			h_samp = nrlu(h_pre)
		else:
			h_samp = h_sigm

		if 0.0 < self.p < 1.0:
			noise_shape = self._get_noise_shape(h_samp)
			h_samp = K.in_train_phase(K.dropout(h_samp, self.p, noise_shape), h_samp)

		return h_samp, h_pre, h_sigm


	def sample_x_given_h(self, h):
		"""Sample visible units given hidden units `h`.

		# Arguments
			h: tensor of hidden configurations.

		# Returns
			For gaussian visible units, the scaled pre-activation
			`scaling_x_given_h * (h . Wrbm^T + bx)` repeated three times.
			Otherwise the tuple `(x_samp, x_pre, x_sigm)` where `x_pre` is
			the unscaled pre-activation, `x_sigm` is the sigmoid of the
			scaled pre-activation and `x_samp` is a Bernoulli sample with
			probability `x_sigm`.
		"""
		# K.transpose / K.shape are used instead of the tensor's `.T` /
		# `.shape` attributes so the code does not depend on one backend.
		x_pre = K.dot(h, K.transpose(self.Wrbm)) + self.bx

		if self.visible_unit_type == 'gaussian':
			x_samp = self.scaling_x_given_h * x_pre
			return x_samp, x_samp, x_samp

		x_sigm = K.sigmoid(self.scaling_x_given_h * x_pre)
		x_samp = K.random_binomial(shape=K.shape(x_sigm), p=x_sigm)
		return x_samp, x_pre, x_sigm


	def gibbs_xhx(self, x0):
		"""Run one Gibbs step visible -> hidden -> visible starting at `x0`.

		Returns the three values produced by `sample_x_given_h` for the
		hidden sample drawn from `x0`.
		"""
		hidden_sample, _hidden_pre, _hidden_prob = self.sample_h_given_x(x0)
		visible, visible_pre, visible_prob = self.sample_x_given_h(hidden_sample)
		return visible, visible_pre, visible_prob


	def mcmc_chain(self, x, nb_gibbs_steps):
		"""Run `nb_gibbs_steps` successive Gibbs steps starting from `x`.

		# Arguments
			x: tensor of visible configurations used as the chain start.
			nb_gibbs_steps: number of `gibbs_xhx` steps; must be >= 1.

		# Returns
			Tuple `(x_rec, x_rec_pre, x_rec_sigm)` from the last step, with
			the gradient stopped through `x_rec`.

		# Raises
			ValueError: if `nb_gibbs_steps` is smaller than 1.
		"""
		# Without at least one step there is no pre-activation/probability
		# to return; fail with a clear message instead of an unbound local.
		if nb_gibbs_steps < 1:
			raise ValueError('nb_gibbs_steps must be >= 1, got {}'.format(nb_gibbs_steps))

		xi = x
		for _ in range(nb_gibbs_steps):
			xi, xi_pre, xi_sigm = self.gibbs_xhx(xi)

		# The chain end is treated as a constant sample: no gradient flows
		# back through the sampling steps.
		x_rec = K.stop_gradient(xi)

		return x_rec, xi_pre, xi_sigm


	def contrastive_divergence_loss(self, y_true, y_pred):
		"""Contrastive-divergence loss in the Keras `(y_true, y_pred)` form.

		`y_pred` is used as the data batch; `y_true` is not used. The loss
		is the mean free energy of the data minus the mean free energy of
		the sample reached after `self.nb_gibbs_steps` Gibbs steps. The
		chain starts from the persistent chain when `self.is_persistent`
		is set (and the chain variable is then scheduled to be updated with
		the new sample through `self.updates`), otherwise from the data.
		"""
		data = y_pred
		start = self.persistent_chain if self.is_persistent else data

		negative_sample, _, _ = self.mcmc_chain(start, self.nb_gibbs_steps)
		positive_energy = K.mean(self.free_energy(data))
		negative_energy = K.mean(self.free_energy(negative_sample))
		cd = positive_energy - negative_energy

		if self.is_persistent:
			self.updates = [(self.persistent_chain, negative_sample)]

		return cd


	def reconstruction_loss(self, y_true, y_pred):
		"""Reconstruction error in the Keras `(y_true, y_pred)` form.

		`y_pred` is used as the data batch `x`; `y_true` is not used. The
		data is passed through `self.nb_gibbs_steps` Gibbs steps and
		compared with its reconstruction:

		- gaussian visible units: mean squared difference between `x` and
		  the reconstructed sample;
		- otherwise: cross-entropy between `x` and the sigmoid of the
		  visible pre-activation, summed over axis 1 and averaged.
		"""
		x = y_pred

		if self.visible_unit_type == 'gaussian':
			x_rec, _, _ = self.mcmc_chain(x, self.nb_gibbs_steps)
			# Square (not square-root) the signed residual: the sqrt of a
			# negative difference is undefined and yields NaN.
			return K.mean(K.square(x - x_rec))

		_, pre, _ = self.mcmc_chain(x, self.nb_gibbs_steps)
		x_prob = K.sigmoid(pre)
		cross_entropy_loss = -K.mean(K.sum(x * K.log(x_prob) +
								(1 - x) * K.log(1 - x_prob), axis=1))
		return cross_entropy_loss


	def free_energy_gap(self, x_train, x_test):
		"""Return mean free energy of `x_train` minus that of `x_test`."""
		train_energy = K.mean(self.free_energy(x_train))
		test_energy = K.mean(self.free_energy(x_test))
		return train_energy - test_energy


	def get_h_given_x_layer(self, as_initial_layer=False):
		"""Build a `Dense` layer mapping visible to hidden units.

		The layer has `hidden_dim` outputs, uses this RBM's activation and
		is initialised with the current values of `Wrbm` and `bh`.

		# Arguments
			as_initial_layer: when true, `input_dim` is also passed to the
				`Dense` constructor.

		# Returns
			The new `Dense` layer.
		"""
		# K.get_value reads the variable through the backend API rather
		# than the Theano-only `.get_value()` method.
		layer_kwargs = {
			'output_dim': self.hidden_dim,
			'activation': self.activation,
			'weights': [K.get_value(self.Wrbm), K.get_value(self.bh)],
		}
		if as_initial_layer:
			layer_kwargs['input_dim'] = self.input_dim
		return Dense(**layer_kwargs)


	def get_x_given_h_layer(self, as_initial_layer=False):
		"""Build a `Dense` layer mapping hidden to visible units.

		The layer has `input_dim` outputs and is initialised with the
		transpose of the current `Wrbm` value and the current `bx` value.
		Its activation is 'linear' for gaussian visible units and
		'sigmoid' otherwise.

		# Arguments
			as_initial_layer: when true, `input_dim=self.hidden_dim` is
				also passed to the `Dense` constructor.

		# Returns
			The new `Dense` layer.
		"""
		act = 'linear' if self.visible_unit_type == 'gaussian' else 'sigmoid'

		# K.get_value reads the variable through the backend API rather
		# than the Theano-only `.get_value()` method; the result is a
		# numpy array, so `.T` is a plain array transpose.
		layer_kwargs = {
			'output_dim': self.input_dim,
			'activation': act,
			'weights': [K.get_value(self.Wrbm).T, K.get_value(self.bx)],
		}
		if as_initial_layer:
			layer_kwargs['input_dim'] = self.hidden_dim
		return Dense(**layer_kwargs)


	def return_reconstruction_data(self, x):
		"""Return the sample reached from `x` after `self.nb_gibbs_steps` Gibbs steps."""
		reconstruction, _pre, _prob = self.mcmc_chain(x, self.nb_gibbs_steps)
		return reconstruction