# ==============================================================================
# Copyright (c) 2018, Yamagishi Laboratory, National Institute of Informatics
# Author: Yusuke Yasuda (yasuda@nii.ac.jp)
# All rights reserved.
# ==============================================================================
"""  """


import tensorflow as tf
from tensorflow.contrib.rnn import RNNCell
from collections import namedtuple
from functools import reduce
from abc import abstractmethod
from typing import Tuple
from tacotron2.tacotron.modules import PreNet


class TransparentRNNCellLike:
    '''
    An RNNCell-like base class whose subclasses do not create variable scopes.
    '''

    @property
    @abstractmethod
    def state_size(self):
        pass

    @property
    @abstractmethod
    def output_size(self):
        pass

    @abstractmethod
    def zero_state(self, batch_size, dtype):
        pass

    @abstractmethod
    def __call__(self, inputs, state):
        pass


class RNNStateHistoryWrapperState(
    namedtuple("RNNStateHistoryWrapperState", ["rnn_state", "rnn_state_history", "time"])):
    """State of RNNStateHistoryWrapper: the wrapped cell's state, the
    [batch, time, output_size] output history, and the current time step."""


class RNNStateHistoryWrapper(TransparentRNNCellLike):
    """Wraps an RNNCell and accumulates every per-step output into a
    [batch, time, output_size] history tensor carried in the state."""

    def __init__(self, cell: RNNCell, max_iter):
        self._cell = cell
        self._max_iter = max_iter

    @property
    def state_size(self):
        return RNNStateHistoryWrapperState(self._cell.state_size,
                                           tf.TensorShape([None, None, self.output_size]),
                                           tf.TensorShape([]))

    @property
    def output_size(self):
        return self._cell.output_size

    def zero_state(self, batch_size, dtype):
        rnn_state = self._cell.zero_state(batch_size, dtype)
        history = tf.zeros(shape=[batch_size, 0, self.output_size], dtype=dtype)
        # avoid Tensor#set_shape, which merges an unknown shape with a known shape
        history._shape_val = tf.TensorShape([None, None, self.output_size])  # pylint: disable=protected-access
        time = tf.zeros([], dtype=tf.int32)
        return RNNStateHistoryWrapperState(rnn_state, history, time)

    def compute_output_shape(self, input_shape):
        return tf.TensorShape([input_shape[0], input_shape[1], self.output_size])

    def __call__(self, inputs, state: RNNStateHistoryWrapperState):
        output, new_rnn_state = self._cell(inputs, state.rnn_state)
        new_history = tf.concat([state.rnn_state_history,
                                 tf.expand_dims(output, axis=1)], axis=1)
        new_history.set_shape([None, None, self.output_size])
        new_state = RNNStateHistoryWrapperState(new_rnn_state, new_history, state.time + 1)
        return output, new_state
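

# A minimal usage sketch (illustration only, not used elsewhere in this module);
# the cell choice and sizes below are hypothetical:
#
#   cell = RNNStateHistoryWrapper(tf.nn.rnn_cell.GRUCell(256), max_iter=1000)
#   state = cell.zero_state(batch_size=4, dtype=tf.float32)
#   inputs = tf.zeros([4, 80])
#   output, state = cell(inputs, state)
#   # state.rnn_state_history is now [4, 1, 256]; every further call appends
#   # one more frame along the time axis.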


class TransformerWrapperState(namedtuple("TransformerWrapperState", ["rnn_state", "alignments"])):
    pass


class TransformerWrapper(TransparentRNNCellLike):
    """Applies a stack of self-attention (transformer) layers to the output
    history recorded by RNNStateHistoryWrapper and emits the transformed frame
    for the current step."""

    def __init__(self, cell: RNNStateHistoryWrapper, transformers, memory_sequence_length):
        self._cell = cell
        self._transformers = transformers
        self._memory_sequence_length = memory_sequence_length

    @property
    def state_size(self):
        return TransformerWrapperState(self._cell.state_size, [(None, None) for _ in self._transformers])

    @property
    def output_size(self):
        return TransformerWrapperState(self._cell.output_size, [(None, None) for _ in self._transformers])

    def zero_state(self, batch_size, dtype):
        def initial_alignment(num_heads):
            ia = tf.zeros([batch_size, 0, 0], dtype)
            # avoid Tensor#set_shape, which merges an unknown shape with a known shape
            ia._shape_val = tf.TensorShape([None, None, None])  # pylint: disable=protected-access
            return [ia] * num_heads

        # two initial alignments per self-attention layer: each transformer is
        # assumed to have two attention heads
        return TransformerWrapperState(self._cell.zero_state(batch_size, dtype),
                                       [ia for ia in initial_alignment(2) for _ in self._transformers])

    def __call__(self, inputs, state: TransformerWrapperState):
        output, new_rnn_state = self._cell(inputs, state.rnn_state)
        history = new_rnn_state.rnn_state_history

        def self_attend(x, alignments, layer):
            # apply one self-attention layer over the whole output history and
            # collect its per-head alignments
            y, alignment = layer(x, memory_sequence_length=self._memory_sequence_length)
            return y, alignments + alignment

        # thread the history through every self-attention layer in turn
        transformed, alignments = reduce(lambda acc, sa: self_attend(acc[0], acc[1], sa),
                                         self._transformers,
                                         (history, []))
        # emit only the transformed frame for the current time step
        output_element = transformed[:, -1, :]
        new_state = TransformerWrapperState(new_rnn_state, alignments)
        return output_element, new_state
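

# Sketch of the interface this wrapper assumes (inferred from the call above,
# not defined in this module): each element of `transformers` is a callable
# taking a [batch, time, units] tensor and a `memory_sequence_length` keyword
# argument, and returning (transformed tensor, list of per-head alignments).
# The layer and length names below are hypothetical:
#
#   history_cell = RNNStateHistoryWrapper(decoder_cell, max_iter=1000)
#   attended = TransformerWrapper(history_cell, [self_attention_layer],
#                                 memory_sequence_length=input_lengths)
#   output, state = attended(inputs, attended.zero_state(4, tf.float32))
#   # `output` is the self-attended frame for the current step;
#   # state.alignments holds the accumulated per-head alignments.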


class OutputMgcLf0AndStopTokenWrapper(RNNCell):
    """Projects the wrapped cell's output to MGC (mel-generalized cepstrum),
    log F0, and stop-token streams."""

    def __init__(self, cell, mgc_out_units, lf0_out_units, dtype=None):
        super(OutputMgcLf0AndStopTokenWrapper, self).__init__()
        self._mgc_out_units = mgc_out_units
        self._lf0_out_units = lf0_out_units
        self._cell = cell
        self.mgc_out_projection1 = tf.layers.Dense(cell.output_size, activation=tf.nn.tanh, dtype=dtype)
        self.mgc_out_projection2 = tf.layers.Dense(mgc_out_units, dtype=dtype)
        self.lf0_out_projection = tf.layers.Dense(lf0_out_units, dtype=dtype)
        self.stop_token_projection = tf.layers.Dense(1, dtype=dtype)

    @property
    def state_size(self):
        return self._cell.state_size

    @property
    def output_size(self):
        return (self._mgc_out_units, self._lf0_out_units, 1)

    def zero_state(self, batch_size, dtype):
        return self._cell.zero_state(batch_size, dtype)

    def call(self, inputs, state):
        output, res_state = self._cell(inputs, state)
        mgc_output = self.mgc_out_projection2(self.mgc_out_projection1(output))
        lf0_output = self.lf0_out_projection(output)
        stop_token = self.stop_token_projection(output)
        return (mgc_output, lf0_output, stop_token), res_state
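

# A minimal sketch with hypothetical sizes, showing how one decoder output is
# split into three streams:
#
#   wrapped = OutputMgcLf0AndStopTokenWrapper(decoder_cell, mgc_out_units=60,
#                                             lf0_out_units=1, dtype=tf.float32)
#   (mgc, lf0, stop), state = wrapped(inputs, wrapped.zero_state(4, tf.float32))
#   # mgc:  [4, 60] via a tanh layer followed by a linear projection
#   # lf0:  [4, 1]  via a linear projection
#   # stop: [4, 1]  stop-token logit via a linear projection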


class DecoderMgcLf0PreNetWrapper(RNNCell):
    """Runs the MGC and lf0 halves of the input through their own PreNet stacks
    and feeds the concatenated result to the wrapped cell."""

    def __init__(self, cell: RNNCell, mgc_prenets: Tuple[PreNet], lf0_prenets: Tuple[PreNet]):
        super(DecoderMgcLf0PreNetWrapper, self).__init__()
        self._cell = cell
        self.mgc_prenets = mgc_prenets
        self.lf0_prenets = lf0_prenets

    @property
    def state_size(self):
        return self._cell.state_size

    @property
    def output_size(self):
        return self._cell.output_size

    def zero_state(self, batch_size, dtype):
        return self._cell.zero_state(batch_size, dtype)

    def compute_output_shape(self, input_shape):
        return tf.TensorShape([input_shape[0], input_shape[1], self.output_size])

    def call(self, inputs, state):
        mgc_input, lf0_input = inputs
        mgc_prenet_output = reduce(lambda acc, pn: pn(acc), self.mgc_prenets, mgc_input)
        lf0_prenet_output = reduce(lambda acc, pn: pn(acc), self.lf0_prenets, lf0_input)
        prenet_output = tf.concat([mgc_prenet_output, lf0_prenet_output], axis=-1)
        return self._cell(prenet_output, state)
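

# Sketch (the PreNet constructor arguments shown are assumptions about
# tacotron2.tacotron.modules.PreNet, and `decoder_cell` is hypothetical):
# the input is a (mgc_frame, lf0_frame) pair; each half runs through its own
# PreNet stack before the concatenation is fed to the inner cell.
#
#   cell = DecoderMgcLf0PreNetWrapper(decoder_cell,
#                                     mgc_prenets=(PreNet(256, is_training),
#                                                  PreNet(128, is_training)),
#                                     lf0_prenets=(PreNet(256, is_training),
#                                                  PreNet(128, is_training)))
#   output, state = cell((mgc_frame, lf0_frame), state)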


class OutputAndStopTokenTransparentWrapper(TransparentRNNCellLike):
    """Scope-transparent wrapper that projects the wrapped cell's output to an
    output frame and a stop-token logit with externally supplied projection
    layers."""

    def __init__(self, cell, out_units, out_projection, stop_token_projection):
        self._out_units = out_units
        self._cell = cell
        self.out_projection = out_projection
        self.stop_token_projection = stop_token_projection

    @property
    def state_size(self):
        return self._cell.state_size

    @property
    def output_size(self):
        return (self._out_units, 1)

    def zero_state(self, batch_size, dtype):
        return self._cell.zero_state(batch_size, dtype)

    def compute_output_shape(self, input_shape):
        # the cell outputs an (output, stop_token) pair, so return a pair of shapes
        return (tf.TensorShape([input_shape[0], input_shape[1], self._out_units]),
                tf.TensorShape([input_shape[0], input_shape[1], 1]))

    def __call__(self, inputs, state):
        output, res_state = self._cell(inputs, state)
        mel_output = self.out_projection(output)
        stop_token = self.stop_token_projection(output)
        return (mel_output, stop_token), res_state
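

# Sketch (projection layers below are hypothetical examples): unlike
# OutputMgcLf0AndStopTokenWrapper, the projections are supplied by the caller,
# so they can be created and reused outside this wrapper:
#
#   cell = OutputAndStopTokenTransparentWrapper(
#       decoder_cell, out_units=80,
#       out_projection=tf.layers.Dense(80),
#       stop_token_projection=tf.layers.Dense(1))
#   (mel, stop), state = cell(inputs, cell.zero_state(4, tf.float32))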


class OutputMgcLf0AndStopTokenTransparentWrapper(TransparentRNNCellLike):
    """Scope-transparent variant of OutputMgcLf0AndStopTokenWrapper that uses
    externally supplied projection layers."""

    def __init__(self, cell, mgc_out_units, lf0_out_units, mgc_out_projection, lf0_out_projection, stop_token_projection):
        self._mgc_out_units = mgc_out_units
        self._lf0_out_units = lf0_out_units
        self._cell = cell
        self.mgc_out_projection = mgc_out_projection
        self.lf0_out_projection = lf0_out_projection
        self.stop_token_projection = stop_token_projection

    @property
    def state_size(self):
        return self._cell.state_size

    @property
    def output_size(self):
        return (self._mgc_out_units, self._lf0_out_units, 1)

    def zero_state(self, batch_size, dtype):
        return self._cell.zero_state(batch_size, dtype)

    def compute_output_shape(self, input_shape):
        # the cell outputs an (mgc, lf0, stop_token) triple, so return a triple of shapes
        return (tf.TensorShape([input_shape[0], input_shape[1], self._mgc_out_units]),
                tf.TensorShape([input_shape[0], input_shape[1], self._lf0_out_units]),
                tf.TensorShape([input_shape[0], input_shape[1], 1]))

    def __call__(self, inputs, state):
        output, res_state = self._cell(inputs, state)
        mgc_output = self.mgc_out_projection(output)
        lf0_output = self.lf0_out_projection(output)
        stop_token = self.stop_token_projection(output)
        return (mgc_output, lf0_output, stop_token), res_state
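

# Sketch (hypothetical projections), mirroring the wrapper above but with
# separate MGC and lf0 streams:
#
#   cell = OutputMgcLf0AndStopTokenTransparentWrapper(
#       decoder_cell, mgc_out_units=60, lf0_out_units=1,
#       mgc_out_projection=tf.layers.Dense(60),
#       lf0_out_projection=tf.layers.Dense(1),
#       stop_token_projection=tf.layers.Dense(1))
#   (mgc, lf0, stop), state = cell(inputs, cell.zero_state(4, tf.float32))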