"""A set of wrappers usefull for tacotron 2 architecture
All notations and variable names were used in concordance with originial tensorflow implementation
"""
import collections
import numpy as np
import tensorflow as tf
from tensorflow.contrib.rnn import RNNCell
from tensorflow.python.framework import ops
from tensorflow.python.ops import rnn_cell_impl
from tensorflow.python.ops import check_ops
from tensorflow.python.util import nest
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.framework import tensor_shape
from tacotron.models.attention import _compute_attention

_zero_state_tensors = rnn_cell_impl._zero_state_tensors



class TacotronEncoderCell(RNNCell):
	"""Tacotron 2 Encoder Cell
	Passes inputs through a stack of convolutional layers then through a bidirectional LSTM
	layer to predict the hidden representation vector (or memory)
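
	Example (a minimal sketch; `EncoderConvolutions` and `EncoderRNN` are assumed
	to be this repository's encoder building blocks and may differ in your setup):
		encoder_cell = TacotronEncoderCell(
			EncoderConvolutions(is_training, hparams),
			EncoderRNN(is_training, size=256))
		encoder_outputs = encoder_cell(inputs, input_lengths)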
	"""

	def __init__(self, convolutional_layers, lstm_layer):
		"""Initialize encoder parameters

		Args:
			convolutional_layers: Encoder convolutional block class
			lstm_layer: Encoder bidirectional LSTM layer class
		"""
		super(TacotronEncoderCell, self).__init__()
		#Initialize encoder layers
		self._convolutions = convolutional_layers
		self._cell = lstm_layer

	def __call__(self, inputs, input_lengths=None):
		#Pass input sequence through a stack of convolutional layers
		conv_output = self._convolutions(inputs)

		#Extract hidden representation from encoder lstm cells
		hidden_representation = self._cell(conv_output, input_lengths)

		#For shape visualization
		self.conv_output_shape = conv_output.shape
		return hidden_representation


class TacotronDecoderCellState(
	collections.namedtuple("TacotronDecoderCellState",
	 ("cell_state", "attention", "time", "alignments",
	  "alignment_history", "finished"))):
	"""`namedtuple` storing the state of a `TacotronDecoderCell`.
	Contains:
	  - `cell_state`: The state of the wrapped `RNNCell` at the previous time
		step.
	  - `attention`: The attention emitted at the previous time step.
	  - `time`: int32 scalar containing the current time step.
	  - `alignments`: A single or tuple of `Tensor`(s) containing the alignments
		 emitted at the previous time step for each attention mechanism.
	  - `alignment_history`: a single or tuple of `TensorArray`(s)
		 containing alignment matrices from all time steps for each attention
		 mechanism. Call `stack()` on each to convert to a `Tensor`.
	  - `finished`: A float32 tensor of shape [batch_size, 1] flagging, per batch
		 entry, whether the <stop_token> has already been emitted.
	"""
	def replace(self, **kwargs):
		"""Clones the current state while overwriting components provided by kwargs.
		"""
		return super(TacotronDecoderCellState, self)._replace(**kwargs)

class TacotronDecoderCell(RNNCell):
	"""Tactron 2 Decoder Cell
	Decodes encoder output and previous mel frames into next r frames

	Decoder Step i:
		1) Prenet to compress last output information
		2) Concat compressed inputs with previous context vector (input feeding) *
		3) Decoder RNN (actual decoding) to predict current state s_{i} *
		4) Compute new context vector c_{i} based on s_{i} and a cumulative sum of previous alignments *
		5) Predict new output y_{i} using s_{i} and c_{i} (concatenated)
		6) Predict <stop_token> output ys_{i} using s_{i} and c_{i} (concatenated)

	* : Steps (2), (3) and (4) are typically covered by taking a vanilla LSTM,
	wrapping it with tensorflow's attention wrapper, and wrapping that in turn
	with the prenet (for input feeding) and with the prediction layer that uses
	RNN states to project onto the output space. These steps could be replaced
	by a single attention wrapper call if it used cumulative alignments instead
	of previous alignments only.
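
	Example (a minimal sketch; `Prenet`, `LocationSensitiveAttention`, `DecoderRNN`,
	`FrameProjection` and `StopProjection` are assumed to be this repository's
	decoder building blocks and may differ in your setup):
		decoder_cell = TacotronDecoderCell(
			prenet=Prenet(is_training),
			attention_mechanism=LocationSensitiveAttention(attention_dim, encoder_outputs),
			rnn_cell=DecoderRNN(is_training),
			frame_projection=FrameProjection(num_mels * outputs_per_step),
			stop_projection=StopProjection(is_training),
			mask_finished=False)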
	"""

	def __init__(self, prenet, attention_mechanism, rnn_cell, frame_projection, stop_projection, mask_finished=False):
		"""Initialize decoder parameters

		Args:
			prenet: A tensorflow fully connected layer acting as the decoder pre-net
			attention_mechanism: A _BaseAttentionMechanism instance, useful to
				learn encoder-decoder alignments
			rnn_cell: Instance of RNNCell, main body of the decoder
			frame_projection: tensorflow fully connected layer with r * num_mels output units
			stop_projection: tensorflow fully connected layer, expected to project to a scalar
				and pass through a sigmoid activation
			mask_finished: Boolean, whether to mask decoder frames after the <stop_token>
		"""
		super(TacotronDecoderCell, self).__init__()
		#Initialize decoder layers
		self._prenet = prenet
		self._attention_mechanism = attention_mechanism
		self._cell = rnn_cell
		self._frame_projection = frame_projection
		self._stop_projection = stop_projection

		self._mask_finished = mask_finished
		self._attention_layer_size = self._attention_mechanism.values.get_shape()[-1].value

	def _batch_size_checks(self, batch_size, error_message):
		return [check_ops.assert_equal(batch_size,
		  self._attention_mechanism.batch_size,
		  message=error_message)]

	@property
	def output_size(self):
		return self._frame_projection.shape

	@property
	def state_size(self):
		"""The `state_size` property of `TacotronDecoderCell`.

		Returns:
		  A `TacotronDecoderCellState` tuple containing shapes used by this object.
		"""
		return TacotronDecoderCellState(
			cell_state=self._cell._cell.state_size,
			time=tensor_shape.TensorShape([]),
			attention=self._attention_layer_size,
			alignments=self._attention_mechanism.alignments_size,
			alignment_history=(),
			finished=())

	def zero_state(self, batch_size, dtype):
		"""Return an initial (zero) state tuple for this `AttentionWrapper`.
		
		Args:
		  batch_size: `0D` integer tensor: the batch size.
		  dtype: The internal state data type.
		Returns:
		  A `TacotronDecoderCellState` tuple containing zeroed out tensors and,
		  possibly, empty `TensorArray` objects.
		Raises:
		  ValueError: (or, possibly at runtime, InvalidArgument), if
			`batch_size` does not match the batch size of the encoder output
			passed to the wrapper object at initialization time.
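
		Example (a minimal sketch, assuming `encoder_outputs` is the memory the
		attention mechanism was built with):
			initial_state = decoder_cell.zero_state(
				batch_size=tf.shape(encoder_outputs)[0], dtype=tf.float32)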
		"""
		with ops.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]):
			cell_state = self._cell._cell.zero_state(batch_size, dtype)
			error_message = (
				"When calling zero_state of TacotronDecoderCell %s: " % self._base_name +
				"Non-matching batch sizes between the memory "
				"(encoder output) and the requested batch size.")
			with ops.control_dependencies(
				self._batch_size_checks(batch_size, error_message)):
				cell_state = nest.map_structure(
					lambda s: array_ops.identity(s, name="checked_cell_state"),
					cell_state)
			return TacotronDecoderCellState(
				cell_state=cell_state,
				time=array_ops.zeros([], dtype=tf.int32),
				attention=_zero_state_tensors(self._attention_layer_size, batch_size,
				  dtype),
				alignments=self._attention_mechanism.initial_alignments(batch_size, dtype),
				alignment_history=tensor_array_ops.TensorArray(dtype=dtype, size=0,
					dynamic_size=True),
				finished=tf.reshape(tf.tile([0.0], [batch_size]), [-1, 1]))

	def __call__(self, inputs, state):
		#Information bottleneck (essential for learning attention)
		prenet_output = self._prenet(inputs)

		#Concat context vector and prenet output to form LSTM cells input (input feeding)
		LSTM_input = tf.concat([prenet_output, state.attention], axis=-1)

		#Unidirectional LSTM layers
		LSTM_output, next_cell_state = self._cell(LSTM_input, state.cell_state)

		#Compute the attention (context) vector and alignments using
		#the new decoder cell hidden state as the query vector
		#and the cumulative alignments to extract location features.
		#The use of the new cell hidden state (s_{i}) of the last
		#decoder RNN cell as the query follows Luong et al. (2015):
		#https://arxiv.org/pdf/1508.04025.pdf
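		#Schematically (location-sensitive attention, Chorowski et al. 2015), with
		#h_j the encoder outputs (memory) and f_{i,j} location features convolved
		#from the cumulative alignments (the exact form lives in _compute_attention):
		#	e_{i,j} = v_a^T tanh(W s_i + V h_j + U f_{i,j} + b)
		#	a_i = softmax(e_i)
		#	c_i = sum_j a_{i,j} h_j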
		previous_alignments = state.alignments
		previous_alignment_history = state.alignment_history
		context_vector, alignments, cumulated_alignments = _compute_attention(
			self._attention_mechanism,
			LSTM_output,
			previous_alignments,
			attention_layer=None)

		#Concat LSTM outputs and context vector to form projections inputs
		projections_input = tf.concat([LSTM_output, context_vector], axis=-1)

		#Compute predicted frames and predicted <stop_token>
		cell_outputs = self._frame_projection(projections_input)
		stop_tokens = self._stop_projection(projections_input)

		#Mask the alignments computed at decoding steps where the sequence is already
		#finished. This is purely for visualization purposes and does not affect the
		#training of the model: the alignments over output paddings are irrelevant
		#once the decoder outputs beyond the end of sequence are imputed.
		if self._mask_finished:
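			#state.finished has shape [batch_size, 1]; multiplying by a ones tensor
			#broadcasts it to the alignments shape before the boolean cast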
			finished = tf.cast(state.finished * tf.ones(tf.shape(alignments)), tf.bool)
			mask = tf.zeros(tf.shape(alignments))
			masked_alignments = tf.where(finished, mask, alignments)
		else:
			masked_alignments = alignments

		#Save alignment history
		alignment_history = previous_alignment_history.write(state.time, masked_alignments)

		#Prepare next decoder state
		next_state = TacotronDecoderCellState(
			time=state.time + 1,
			cell_state=next_cell_state,
			attention=context_vector,
			alignments=cumulated_alignments,
			alignment_history=alignment_history,
			finished=state.finished)

		return (cell_outputs, stop_tokens), next_state
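

#Example (a minimal sketch, not part of the original implementation): driving the
#decoder cell one step by hand. `decoder_cell` is a TacotronDecoderCell built as
#in the class docstring, and `prev_output` is the previous decoder output frame:
#
#	initial_state = decoder_cell.zero_state(batch_size, tf.float32)
#	(frames, stop_tokens), next_state = decoder_cell(prev_output, initial_state)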