"""Module implementing EUNN Cell.
"""
import tensorflow as tf
import numpy as np
from tensorflow.python.framework import ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_math_ops
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.ops.rnn_cell_impl import RNNCell
from baselineModels.modrelu import modrelu
from termcolor import colored


def _eunn_param(hidden_size, capacity=2, fft=False, comp=True, name="eunn"):
    """
    Create parameters and do the initial preparations

    Creates the trainable rotation angles (plus phases in complex mode) and
    expands them into one "diagonal" (cosine) row and one "off-diagonal"
    (sine) row per layer, each `hidden_size` wide.

    Args:
        hidden_size: width of every coefficient row.
        capacity: number of layers in tunable mode. Ignored when `fft` is
            True, where ceil(log2(hidden_size)) layers are used instead.
        fft: if True, partners inside a layer are 2**k positions apart;
            otherwise adjacent coordinates are paired, with odd layers
            shifted by one position.
        comp: if True, build complex coefficients (angle + phase variables)
            and an additional unit-modulus diagonal from `omega`;
            otherwise build real coefficients only.
        name: prefix prepended to every variable name. None falls back to
            "eunn" because the prefix is concatenated with string suffixes.

    Returns:
        Tuple `(diag_vec, off_vec, diag, capacity)`: two TensorArrays with
        one row per layer, the complex diagonal (None when `comp` is False)
        and the number of layers actually built.
    """
    if name is None:
        name = "eunn"

    theta_phi_initializer = init_ops.random_uniform_initializer(-np.pi, np.pi)
    if fft:
        capacity = int(np.ceil(np.log2(hidden_size)))

        def _layer_sizes(layer):
            # size: log2 of the block width used by this layer.
            # normal_size: number of pairs living in complete blocks.
            # extra_size: number of pairs in the trailing incomplete block.
            size = capacity - layer
            normal_size = (hidden_size // (2 ** size)) * (2 ** (size - 1))
            extra_size = max(0, (hidden_size % (2 ** size)) - (2 ** (size - 1)))
            return size, normal_size, extra_size

        diag_list_0 = []
        off_list_0 = []

        # One angle per pair, summed over all layers.
        varsize = 0
        for layer in range(capacity):
            _, normal_size, extra_size = _layer_sizes(layer)
            varsize += normal_size + extra_size

        params_theta = vs.get_variable(
            name + "theta_0", [varsize], initializer=theta_phi_initializer)
        cos_theta = math_ops.cos(params_theta)
        sin_theta = math_ops.sin(params_theta)

        if comp:
            params_phi = vs.get_variable(
                name + "phi_0", [varsize], initializer=theta_phi_initializer)
            cos_phi = math_ops.cos(params_phi)
            sin_phi = math_ops.sin(params_phi)

            # First partner of a pair: cos(theta) and sin(theta).
            # Second partner: cos(theta)*e^{i phi} and -sin(theta)*e^{i phi}.
            cos_list_0 = math_ops.complex(
                cos_theta, array_ops.zeros_like(cos_theta))
            cos_list_1 = math_ops.complex(
                math_ops.multiply(cos_theta, cos_phi),
                math_ops.multiply(cos_theta, sin_phi))
            sin_list_0 = math_ops.complex(
                sin_theta, array_ops.zeros_like(sin_theta))
            sin_list_1 = math_ops.complex(
                -math_ops.multiply(sin_theta, cos_phi),
                -math_ops.multiply(sin_theta, sin_phi))

        last = 0
        for layer in range(capacity):
            size, normal_size, extra_size = _layer_sizes(layer)
            # Coordinates of the incomplete block that have no partner.
            gap_size = hidden_size - 2 * normal_size - 2 * extra_size

            if comp:
                # sin_list_1 already carries the minus sign, so it is used
                # as-is. Negating it again (as this code used to) gives the
                # block [[c, +s e^{i phi}], [s, c e^{i phi}]], which is not
                # unitary and disagrees with the tunable complex branch.
                cos_list_normal = array_ops.concat([
                    array_ops.slice(cos_list_0, [last], [normal_size]),
                    array_ops.slice(cos_list_1, [last], [normal_size])], 0)
                sin_list_normal = array_ops.concat([
                    array_ops.slice(sin_list_0, [last], [normal_size]),
                    array_ops.slice(sin_list_1, [last], [normal_size])], 0)
                last += normal_size

                cos_list_extra = array_ops.concat([
                    array_ops.slice(cos_list_0, [last], [extra_size]),
                    math_ops.complex(
                        tf.ones([gap_size]), tf.zeros([gap_size])),
                    array_ops.slice(cos_list_1, [last], [extra_size])], 0)
                sin_list_extra = array_ops.concat([
                    array_ops.slice(sin_list_0, [last], [extra_size]),
                    math_ops.complex(
                        tf.zeros([gap_size]), tf.zeros([gap_size])),
                    array_ops.slice(sin_list_1, [last], [extra_size])], 0)
                last += extra_size

            else:
                cos_list_normal = array_ops.slice(
                    cos_theta, [last], [normal_size])
                cos_list_normal = array_ops.concat(
                    [cos_list_normal, cos_list_normal], 0)
                cos_list_extra = array_ops.slice(
                    cos_theta, [last + normal_size], [extra_size])
                cos_list_extra = array_ops.concat(
                    [cos_list_extra, tf.ones([gap_size]), cos_list_extra], 0)

                sin_list_normal = array_ops.slice(
                    sin_theta, [last], [normal_size])
                sin_list_normal = array_ops.concat(
                    [sin_list_normal, -sin_list_normal], 0)
                sin_list_extra = array_ops.slice(
                    sin_theta, [last + normal_size], [extra_size])
                sin_list_extra = array_ops.concat(
                    [sin_list_extra, tf.zeros([gap_size]), -sin_list_extra], 0)

                last += normal_size + extra_size

            if normal_size != 0:
                # Rearrange [all first partners | all second partners] so
                # that each block holds its first partners followed by its
                # second partners.
                blocks = 2 * normal_size // (2 ** size)
                cos_list_normal = array_ops.reshape(array_ops.transpose(
                    array_ops.reshape(cos_list_normal, [-1, blocks])), [-1])
                sin_list_normal = array_ops.reshape(array_ops.transpose(
                    array_ops.reshape(sin_list_normal, [-1, blocks])), [-1])

            cos_list = array_ops.concat([cos_list_normal, cos_list_extra], 0)
            sin_list = array_ops.concat([sin_list_normal, sin_list_extra], 0)
            diag_list_0.append(cos_list)
            off_list_0.append(sin_list)

        diag_vec = array_ops.stack(diag_list_0, 0)
        off_vec = array_ops.stack(off_list_0, 0)

    else:
        # Layers alternate: "a" layers pair (0,1),(2,3),...; "b" layers pair
        # (1,2),(3,4),... and leave coordinate 0 untouched.
        capacity_b = capacity // 2
        capacity_a = capacity - capacity_b

        hidden_size_a = hidden_size // 2
        hidden_size_b = (hidden_size - 1) // 2

        # True when the pairing leaves one trailing coordinate without a
        # partner; that coordinate must pass through (cos = 1, sin = 0).
        pad_tail_a = hidden_size_a * 2 != hidden_size
        pad_tail_b = hidden_size_b * 2 != hidden_size - 1

        params_theta_0 = vs.get_variable(
            name + "theta_0", [capacity_a, hidden_size_a], initializer=theta_phi_initializer)
        cos_theta_0 = array_ops.reshape(
            math_ops.cos(params_theta_0), [capacity_a, -1, 1])
        sin_theta_0 = array_ops.reshape(
            math_ops.sin(params_theta_0), [capacity_a, -1, 1])

        params_theta_1 = vs.get_variable(
            name + "theta_1", [capacity_b, hidden_size_b], initializer=theta_phi_initializer)
        cos_theta_1 = array_ops.reshape(
            math_ops.cos(params_theta_1), [capacity_b, -1, 1])
        sin_theta_1 = array_ops.reshape(
            math_ops.sin(params_theta_1), [capacity_b, -1, 1])

        if comp:
            params_phi_0 = vs.get_variable(
                name + "phi_0", [capacity_a, hidden_size_a], initializer=theta_phi_initializer)
            cos_phi_0 = array_ops.reshape(
                math_ops.cos(params_phi_0), [capacity_a, -1, 1])
            sin_phi_0 = array_ops.reshape(
                math_ops.sin(params_phi_0), [capacity_a, -1, 1])

            # Concatenating on axis 2 and flattening interleaves the two
            # partners of each pair along the hidden axis.
            cos_list_0_re = array_ops.reshape(array_ops.concat(
                [cos_theta_0, math_ops.multiply(cos_theta_0, cos_phi_0)], 2), [capacity_a, -1])
            cos_list_0_im = array_ops.reshape(array_ops.concat(
                [array_ops.zeros_like(cos_theta_0),
                 math_ops.multiply(cos_theta_0, sin_phi_0)], 2), [capacity_a, -1])
            if pad_tail_a:
                cos_list_0_re = array_ops.concat(
                    [cos_list_0_re, tf.ones([capacity_a, 1])], 1)
                cos_list_0_im = array_ops.concat(
                    [cos_list_0_im, tf.zeros([capacity_a, 1])], 1)
            cos_list_0 = math_ops.complex(cos_list_0_re, cos_list_0_im)

            sin_list_0_re = array_ops.reshape(array_ops.concat(
                [sin_theta_0, -math_ops.multiply(sin_theta_0, cos_phi_0)], 2), [capacity_a, -1])
            sin_list_0_im = array_ops.reshape(array_ops.concat(
                [array_ops.zeros_like(sin_theta_0),
                 -math_ops.multiply(sin_theta_0, sin_phi_0)], 2), [capacity_a, -1])
            if pad_tail_a:
                sin_list_0_re = array_ops.concat(
                    [sin_list_0_re, tf.zeros([capacity_a, 1])], 1)
                sin_list_0_im = array_ops.concat(
                    [sin_list_0_im, tf.zeros([capacity_a, 1])], 1)
            sin_list_0 = math_ops.complex(sin_list_0_re, sin_list_0_im)

            params_phi_1 = vs.get_variable(
                name + "phi_1", [capacity_b, hidden_size_b], initializer=theta_phi_initializer)
            cos_phi_1 = array_ops.reshape(
                math_ops.cos(params_phi_1), [capacity_b, -1, 1])
            sin_phi_1 = array_ops.reshape(
                math_ops.sin(params_phi_1), [capacity_b, -1, 1])

            cos_list_1_re = array_ops.reshape(array_ops.concat(
                [cos_theta_1, math_ops.multiply(cos_theta_1, cos_phi_1)], 2), [capacity_b, -1])
            cos_list_1_re = array_ops.concat(
                [tf.ones((capacity_b, 1)), cos_list_1_re], 1)
            cos_list_1_im = array_ops.reshape(array_ops.concat(
                [array_ops.zeros_like(cos_theta_1),
                 math_ops.multiply(cos_theta_1, sin_phi_1)], 2), [capacity_b, -1])
            cos_list_1_im = array_ops.concat(
                [tf.zeros((capacity_b, 1)), cos_list_1_im], 1)
            if pad_tail_b:
                cos_list_1_re = array_ops.concat(
                    [cos_list_1_re, tf.ones([capacity_b, 1])], 1)
                cos_list_1_im = array_ops.concat(
                    [cos_list_1_im, tf.zeros([capacity_b, 1])], 1)
            cos_list_1 = math_ops.complex(cos_list_1_re, cos_list_1_im)

            sin_list_1_re = array_ops.reshape(array_ops.concat(
                [sin_theta_1, -math_ops.multiply(sin_theta_1, cos_phi_1)], 2), [capacity_b, -1])
            sin_list_1_re = array_ops.concat(
                [tf.zeros((capacity_b, 1)), sin_list_1_re], 1)
            sin_list_1_im = array_ops.reshape(array_ops.concat(
                [array_ops.zeros_like(sin_theta_1),
                 -math_ops.multiply(sin_theta_1, sin_phi_1)], 2), [capacity_b, -1])
            sin_list_1_im = array_ops.concat(
                [tf.zeros((capacity_b, 1)), sin_list_1_im], 1)
            if pad_tail_b:
                sin_list_1_re = array_ops.concat(
                    [sin_list_1_re, tf.zeros([capacity_b, 1])], 1)
                sin_list_1_im = array_ops.concat(
                    [sin_list_1_im, tf.zeros([capacity_b, 1])], 1)
            sin_list_1 = math_ops.complex(sin_list_1_re, sin_list_1_im)
        else:
            cos_list_0 = array_ops.reshape(array_ops.concat(
                [cos_theta_0, cos_theta_0], 2), [capacity_a, -1])
            sin_list_0 = array_ops.reshape(array_ops.concat(
                [sin_theta_0, -sin_theta_0], 2), [capacity_a, -1])
            if pad_tail_a:
                cos_list_0 = array_ops.concat(
                    [cos_list_0, tf.ones([capacity_a, 1])], 1)
                sin_list_0 = array_ops.concat(
                    [sin_list_0, tf.zeros([capacity_a, 1])], 1)

            cos_list_1 = array_ops.reshape(array_ops.concat(
                [cos_theta_1, cos_theta_1], 2), [capacity_b, -1])
            cos_list_1 = array_ops.concat(
                [tf.ones((capacity_b, 1)), cos_list_1], 1)
            sin_list_1 = array_ops.reshape(array_ops.concat(
                [sin_theta_1, -sin_theta_1], 2), [capacity_b, -1])
            sin_list_1 = array_ops.concat(
                [tf.zeros((capacity_b, 1)), sin_list_1], 1)
            if pad_tail_b:
                # The unpaired last coordinate must pass through unchanged,
                # so its cosine is 1 (a 0 here would erase that hidden unit
                # in every odd layer). This matches the complex branch.
                cos_list_1 = array_ops.concat(
                    [cos_list_1, tf.ones([capacity_b, 1])], 1)
                sin_list_1 = array_ops.concat(
                    [sin_list_1, tf.zeros([capacity_b, 1])], 1)

        if capacity_b != capacity_a:
            # Odd capacity: add a dummy "b" row so both stacks have
            # capacity_a rows; the dummy is sliced away again below.
            if comp:
                cos_list_1 = array_ops.concat([cos_list_1, math_ops.complex(
                    tf.zeros([1, hidden_size]), tf.zeros([1, hidden_size]))], 0)
                sin_list_1 = array_ops.concat([sin_list_1, math_ops.complex(
                    tf.zeros([1, hidden_size]), tf.zeros([1, hidden_size]))], 0)
            else:
                cos_list_1 = array_ops.concat(
                    [cos_list_1, tf.zeros([1, hidden_size])], 0)
                sin_list_1 = array_ops.concat(
                    [sin_list_1, tf.zeros([1, hidden_size])], 0)

        # Column-wise concat followed by reshape interleaves the rows as
        # a0, b0, a1, b1, ...
        diag_vec = tf.reshape(tf.concat([cos_list_0, cos_list_1], 1), [
                              capacity_a * 2, hidden_size])
        off_vec = tf.reshape(tf.concat([sin_list_0, sin_list_1], 1), [
                             capacity_a * 2, hidden_size])

        if capacity_b != capacity_a:
            diag_vec = tf.slice(diag_vec, [0, 0], [capacity, hidden_size])
            off_vec = tf.slice(off_vec, [0, 0], [capacity, hidden_size])

    def _toTensorArray(elems):
        # Unstack along axis 0 so each layer row can be read by index.
        elems = ops.convert_to_tensor(elems)
        n = array_ops.shape(elems)[0]
        elems_ta = tensor_array_ops.TensorArray(
            dtype=elems.dtype, size=n, dynamic_size=False, infer_shape=True, clear_after_read=False)
        elems_ta = elems_ta.unstack(elems)
        return elems_ta

    diag_vec = _toTensorArray(diag_vec)
    off_vec = _toTensorArray(off_vec)
    if comp:
        omega = vs.get_variable(
            name + "omega", [hidden_size], initializer=theta_phi_initializer)
        diag = math_ops.complex(math_ops.cos(omega), math_ops.sin(omega))
    else:
        diag = None

    return diag_vec, off_vec, diag, capacity


def _eunn_loop(state, capacity, diag_vec_list, off_vec_list, diag, fft):
    """
    EUNN main loop, applying unitary matrix on input tensor

    Runs `capacity` layers inside a `while_loop`. Each layer reads one row
    from `diag_vec_list` and one from `off_vec_list`, multiplies the state
    by both, moves every "off" product onto its partner coordinate and adds
    the two results.

    Args:
        state: 2-D tensor; its second dimension must be statically known
            because it is converted to a Python int.
        capacity: number of layers to apply (the loop bound).
        diag_vec_list: TensorArray holding one diagonal row per layer.
        off_vec_list: TensorArray holding one off-diagonal row per layer.
        diag: optional tensor multiplied onto the final result; skipped
            when None.
        fft: if True use the FFT-style partner layout, otherwise pair
            adjacent coordinates.

    Returns:
        The transformed state tensor.
    """
    i = 0

    def _swap_pairs(values, width):
        # Exchange the two members of every adjacent column pair. `width`
        # is a Python int, so parity is resolved while the graph is built
        # rather than with a graph-level cond; that avoids constructing a
        # dead branch whose reshape can be rejected by shape inference.
        if width < 2:
            # Nothing to swap (also avoids an ambiguous zero-size reshape).
            return values
        if width % 2 != 0:
            # Odd width: the last column has no partner and stays put.
            body, last_col = array_ops.split(values, [width - 1, 1], 1)
            return array_ops.concat(
                [_swap_pairs(body, width - 1), last_col], 1)
        pairs = array_ops.reshape(values, [-1, width // 2, 2])
        return array_ops.reshape(array_ops.reverse(pairs, [2]), [-1, width])

    def _swap_pairs_skip_first(values, width):
        # Odd layers leave column 0 alone and pair (1,2), (3,4), ...
        first_col, rest = array_ops.split(values, [1, width - 1], 1)
        return array_ops.concat(
            [first_col, _swap_pairs(rest, width - 1)], 1)

    def layer_tunable(x, i):

        diag_vec = diag_vec_list.read(i)
        off_vec = off_vec_list.read(i)

        kept = math_ops.multiply(x, diag_vec)
        moved = math_ops.multiply(x, off_vec)

        size = int(moved.get_shape()[1])
        # The layer index is a loop tensor, so this choice stays in-graph.
        is_even_layer = gen_math_ops.equal(gen_math_ops.mod(i, 2), 0)
        swapped = control_flow_ops.cond(
            is_even_layer,
            lambda: _swap_pairs(moved, size),
            lambda: _swap_pairs_skip_first(moved, size))

        layer_output = kept + swapped
        i += 1

        return layer_output, i

    def layer_fft(state, i):

        diag_vec = diag_vec_list.read(i)
        off_vec = off_vec_list.read(i)
        kept = math_ops.multiply(state, diag_vec)
        off = math_ops.multiply(state, off_vec)

        hidden_size = int(off.get_shape()[1])
        # Blocks are 2**dist wide; partners sit 2**(dist - 1) apart.
        dist = capacity - i
        normal_size = (hidden_size // (2**dist)) * (2**(dist - 1))
        normal_size *= 2
        extra_size = tf.maximum(0, (hidden_size % (2**dist)) - (2**(dist - 1)))
        # Width of the trailing incomplete block.
        rest_size = hidden_size - normal_size

        def modify(off_normal, dist, normal_size):
            # Swap the two halves of every complete block.
            off_normal = array_ops.reshape(array_ops.reverse(array_ops.reshape(
                off_normal, [-1, normal_size // (2**dist), 2, (2**(dist - 1))]), [2]), [-1, normal_size])
            return off_normal

        off_normal, off_extra = array_ops.split(
            off, [normal_size, rest_size], 1)
        off_normal = control_flow_ops.cond(
            gen_math_ops.equal(normal_size, 0),
            lambda: off_normal,
            lambda: modify(off_normal, dist, normal_size))

        # The incomplete block is laid out as [head | middle | tail]: head
        # and tail are `extra_size` wide and partner each other, while the
        # middle coordinates have no partner (their off coefficient is 0).
        # Exchange head and tail and keep the middle where it is. The
        # previous two-way rotation [tail, head + middle] shifted head and
        # middle onto the wrong coordinates whenever extra_size > 0.
        middle_size = rest_size - 2 * extra_size
        head, middle, tail = array_ops.split(
            off_extra, [extra_size, middle_size, extra_size], 1)
        off_extra = array_ops.concat([tail, middle, head], 1)
        off = array_ops.concat([off_normal, off_extra], 1)

        layer_output = kept + off
        i += 1

        return layer_output, i

    if fft:
        layer_function = layer_fft
    else:
        layer_function = layer_tunable
    output, _ = control_flow_ops.while_loop(
        lambda state, i: gen_math_ops.less(i, capacity), layer_function, [state, i])

    if diag is not None:
        output = math_ops.multiply(output, diag)

    return output


class EUNNCell(RNNCell):
    """Efficient Unitary Network Cell
    The implementation is based on: http://arxiv.org/abs/1612.05231.

    The recurrent transform is built once in the constructor by
    `_eunn_param`; each call applies it to the previous state with
    `_eunn_loop`, adds a learned linear projection of the input and feeds
    the sum through `activation`.
    """

    def __init__(self, hidden_size, capacity=2, fft=False, comp=False, activation=modrelu, name=None):
        """
        Args:
            hidden_size: size of the state and of the output.
            capacity: requested number of rotation layers; replaced by the
                value `_eunn_param` returns (it differs in FFT mode).
            fft: forwarded to `_eunn_param` / `_eunn_loop` to select the
                FFT-style layout.
            comp: if True use complex parameters and a complex input
                projection.
            activation: callable invoked as
                `activation(pre_activation, bias, comp)`.
            name: prefix for the rotation variable names; None uses "eunn".
        """
        super(EUNNCell, self).__init__()
        self._hidden_size = hidden_size
        self._activation = activation
        self._capacity = capacity
        self._fft = fft
        self._comp = comp
        self._name = name

        # `_eunn_param` builds variable names as `name + suffix`, so passing
        # the default None through would raise TypeError. Fall back to the
        # helper's own default prefix instead.
        param_prefix = "eunn" if name is None else name
        self.diag_vec, self.off_vec, self.diag, self._capacity = _eunn_param(
            hidden_size, capacity, fft, comp, param_prefix)

    @property
    def state_size(self):
        return self._hidden_size

    @property
    def output_size(self):
        return self._hidden_size

    @property
    def capacity(self):
        return self._capacity

    def __call__(self, inputs, state, scope=None):
        """Run one step and return `(output, new_state)`, which are the same tensor."""
        with vs.variable_scope(scope or "eunn_cell"):

            # Recurrent part: apply the rotation layers to the old state.
            state = _eunn_loop(state, self._capacity, self.diag_vec,
                               self.off_vec, self.diag, self._fft)

            input_matrix_init = init_ops.random_uniform_initializer(
                -0.01, 0.01)
            input_dim = inputs.get_shape()[-1]
            if self._comp:
                # Separate real and imaginary projections of the input.
                input_matrix_re = vs.get_variable(
                    "U_re", [input_dim, self._hidden_size], initializer=input_matrix_init)
                input_matrix_im = vs.get_variable(
                    "U_im", [input_dim, self._hidden_size], initializer=input_matrix_init)
                inputs_re = math_ops.matmul(inputs, input_matrix_re)
                inputs_im = math_ops.matmul(inputs, input_matrix_im)
                inputs = math_ops.complex(inputs_re, inputs_im)
            else:
                input_matrix = vs.get_variable(
                    "U", [input_dim, self._hidden_size], initializer=input_matrix_init)
                inputs = math_ops.matmul(inputs, input_matrix)

            bias = vs.get_variable(
                "modReLUBias", [self._hidden_size], initializer=init_ops.constant_initializer())
            output = self._activation((inputs + state), bias, self._comp)

        return output, output