import tensorflow as tf


class FSMN(object):
    """Feedforward Sequential Memory Network (FSMN) layer.

    For each batch element and time step ``t`` the layer computes::

        h_hat[t] = sum_{i=0}^{memory_size-1} a[i] * x[t - i]   (terms with t - i < 0 dropped)
        h[t]     = x[t] @ W1 + h_hat[t] @ W2 + bias

    where ``a`` is the trainable ``memory_weights`` vector (initialised to 1.0),
    ``W1``/``W2`` are ``[input_size, output_size]`` matrices (truncated-normal
    init, stddev 5e-2) and ``bias`` is ``[output_size]`` (initialised to 0).

    NOTE: variables are created through ``tf.get_variable`` with fixed names
    ("fsmnn_w1", "fsmnn_w2", "fsmnn_bias", "memory_weights"); two instances
    built under the same variable scope will collide (or share weights when
    reuse is enabled), so give each layer its own ``tf.variable_scope``.
    """

    def __init__(self, memory_size, input_size, output_size, dtype=tf.float32):
        """Create the layer and its variables.

        Args:
            memory_size: number of memory taps: the current frame plus
                ``memory_size - 1`` previous frames.
            input_size: size of the last dimension of the input.
            output_size: size of the last dimension of the output.
            dtype: dtype of all variables; the input must match it.
        """
        self._memory_size = memory_size
        self._output_size = output_size
        self._input_size = input_size
        self._dtype = dtype
        self._build_graph()

    def _build_graph(self):
        """Create the projection, bias and memory-tap variables."""
        self._W1 = tf.get_variable(
            "fsmnn_w1", [self._input_size, self._output_size],
            initializer=tf.truncated_normal_initializer(stddev=5e-2, dtype=self._dtype))
        self._W2 = tf.get_variable(
            "fsmnn_w2", [self._input_size, self._output_size],
            initializer=tf.truncated_normal_initializer(stddev=5e-2, dtype=self._dtype))
        self._bias = tf.get_variable(
            "fsmnn_bias", [self._output_size],
            initializer=tf.constant_initializer(0.0, dtype=self._dtype))
        self._memory_weights = tf.get_variable(
            "memory_weights", [self._memory_size],
            initializer=tf.constant_initializer(1.0, dtype=self._dtype))

    def _memory_matrix(self, num_steps):
        """Build the ``[num_steps, num_steps]`` banded memory matrix.

        Entry ``[t, j]`` is ``memory_weights[t - j]`` when
        ``0 <= t - j < memory_size`` and zero otherwise, so that
        ``memory_matrix @ x`` yields ``h_hat`` for a single sequence ``x``.
        """
        # One trailing zero gives out-of-band cells a valid index to gather.
        padded_weights = tf.pad(self._memory_weights, [[0, 1]])
        zero_index = self._memory_size
        indices = [[t - j if 0 <= t - j < self._memory_size else zero_index
                    for j in range(num_steps)]
                   for t in range(num_steps)]
        # A single gather replaces num_steps slice/pad ops plus a concat and
        # still lets gradients flow back into memory_weights.
        return tf.gather(padded_weights, indices)

    def __call__(self, input_data):
        """Apply the layer to a batch of sequences.

        Args:
            input_data: rank-3 tensor ``[batch_size, num_steps, input_size]``
                of dtype ``dtype``. ``num_steps`` must be statically known;
                ``batch_size`` may be dynamic (``None``).

        Returns:
            Tensor of shape ``[batch_size, num_steps, output_size]``.

        Raises:
            ValueError: if ``num_steps`` is not statically known.
        """
        # as_list() works whether dims are Dimension objects or plain ints.
        num_steps = input_data.get_shape().as_list()[1]
        if num_steps is None:
            raise ValueError("FSMN requires a statically known number of time steps.")

        memory_matrix = self._memory_matrix(num_steps)

        # h_hatt[b] = memory_matrix @ input_data[b] for every batch element.
        # Put time first and fold batch into the columns so one 2-D matmul
        # covers the whole batch without knowing its size in advance.
        steps_first = tf.transpose(input_data, [1, 0, 2])
        folded = tf.reshape(steps_first, [num_steps, -1])
        h_hatt = tf.reshape(tf.matmul(memory_matrix, folded),
                            [num_steps, -1, self._input_size])
        h_hatt = tf.transpose(h_hatt, [1, 0, 2])

        # Project on [batch * num_steps, input_size] views, then restore rank 3.
        inputs_2d = tf.reshape(input_data, [-1, self._input_size])
        h_hatt_2d = tf.reshape(h_hatt, [-1, self._input_size])
        h = tf.matmul(inputs_2d, self._W1)
        h += tf.matmul(h_hatt_2d, self._W2) + self._bias
        return tf.reshape(h, [-1, num_steps, self._output_size])