import pytest
import numpy as np
from numpy.testing import assert_allclose

import keras
from keras.utils.test_utils import layer_test
from keras.utils.test_utils import keras_test
from keras.layers import recurrent
from keras.layers import embeddings
from keras.models import Sequential
from keras.models import Model
from keras.engine.topology import Input
from keras.layers.core import Masking
from keras import regularizers
from keras import backend as K

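# Toy dimensions shared by all tests below; `embedding_num` is the
# vocabulary size used for the Embedding layers.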
num_samples, timesteps, embedding_dim, units = 2, 5, 4, 3
embedding_num = 12


def rnn_test(f):
    """
    All the recurrent layers share the same interface,
    so we can run through them with a single function.
    """
    f = keras_test(f)
    return pytest.mark.parametrize('layer_class', [
        recurrent.SimpleRNN,
        recurrent.GRU,
        recurrent.LSTM
    ])(f)


def rnn_cell_test(f):
    """
    All the recurrent cells share the same interface,
    so we can run through them with a single function.
    """
    f = keras_test(f)
    return pytest.mark.parametrize('cell_class', [
        recurrent.SimpleRNNCell,
        recurrent.GRUCell,
        recurrent.LSTMCell
    ])(f)


@rnn_test
def test_return_sequences(layer_class):
    layer_test(layer_class,
               kwargs={'units': units,
                       'return_sequences': True},
               input_shape=(num_samples, timesteps, embedding_dim))


@rnn_test
def test_dynamic_behavior(layer_class):
    layer = layer_class(units, input_shape=(None, embedding_dim))
    model = Sequential()
    model.add(layer)
    model.compile('sgd', 'mse')
    x = np.random.random((num_samples, timesteps, embedding_dim))
    y = np.random.random((num_samples, units))
    model.train_on_batch(x, y)


@rnn_test
def test_stateful_invalid_use(layer_class):
    layer = layer_class(units,
                        stateful=True,
                        batch_input_shape=(num_samples,
                                           timesteps,
                                           embedding_dim))
    model = Sequential()
    model.add(layer)
    model.compile('sgd', 'mse')
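    # Both the dataset size and the predict batch size below disagree with
    # the fixed batch_input_shape, which a stateful layer must reject.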
    x = np.random.random((num_samples * 2, timesteps, embedding_dim))
    y = np.random.random((num_samples * 2, units))
    with pytest.raises(ValueError):
        model.fit(x, y)
    with pytest.raises(ValueError):
        model.predict(x, batch_size=num_samples + 1)


@rnn_test
@pytest.mark.skipif(K.backend() == 'theano',
                    reason='Not supported.')
def test_dropout(layer_class):
    for unroll in [True, False]:
        layer_test(layer_class,
                   kwargs={'units': units,
                           'dropout': 0.1,
                           'recurrent_dropout': 0.1,
                           'unroll': unroll},
                   input_shape=(num_samples, timesteps, embedding_dim))

        # Test that dropout is applied during training
        x = K.ones((num_samples, timesteps, embedding_dim))
        layer = layer_class(units, dropout=0.5, recurrent_dropout=0.5,
                            input_shape=(timesteps, embedding_dim))
        y = layer(x)
        assert y._uses_learning_phase

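        # With training=True the dropout mask is always applied, so the
        # output no longer depends on the learning phase.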
        y = layer(x, training=True)
        assert not y._uses_learning_phase

        # Test that dropout is not applied during testing
        x = np.random.random((num_samples, timesteps, embedding_dim))
        layer = layer_class(units, dropout=0.5, recurrent_dropout=0.5,
                            unroll=unroll,
                            input_shape=(timesteps, embedding_dim))
        model = Sequential([layer])
        assert model.uses_learning_phase
        y1 = model.predict(x)
        y2 = model.predict(x)
        assert_allclose(y1, y2)


@rnn_test
def test_statefulness(layer_class):
    model = Sequential()
    model.add(embeddings.Embedding(embedding_num, embedding_dim,
                                   mask_zero=True,
                                   input_length=timesteps,
                                   batch_input_shape=(num_samples, timesteps)))
    layer = layer_class(units, return_sequences=False,
                        stateful=True,
                        weights=None)
    model.add(layer)
    model.compile(optimizer='sgd', loss='mse')
    out1 = model.predict(np.ones((num_samples, timesteps)))
    assert out1.shape == (num_samples, units)

    # train once so that the states change
    model.train_on_batch(np.ones((num_samples, timesteps)),
                         np.ones((num_samples, units)))
    out2 = model.predict(np.ones((num_samples, timesteps)))

    # if the state is not reset, output should be different
    assert out1.max() != out2.max()

    # check that output changes after states are reset
    # (even though the model itself didn't change)
    layer.reset_states()
    out3 = model.predict(np.ones((num_samples, timesteps)))
    assert out2.max() != out3.max()

    # check that container-level reset_states() works
    model.reset_states()
    out4 = model.predict(np.ones((num_samples, timesteps)))
    assert_allclose(out3, out4, atol=1e-5)

    # check that the call to `predict` updated the states
    out5 = model.predict(np.ones((num_samples, timesteps)))
    assert out4.max() != out5.max()


@rnn_test
def test_masking_correctness(layer_class):
    # Check masking: output with left padding and right padding
    # should be the same.
    model = Sequential()
    model.add(embeddings.Embedding(embedding_num, embedding_dim,
                                   mask_zero=True,
                                   input_length=timesteps,
                                   batch_input_shape=(num_samples, timesteps)))
    layer = layer_class(units, return_sequences=False)
    model.add(layer)
    model.compile(optimizer='sgd', loss='mse')

    left_padded_input = np.ones((num_samples, timesteps))
    left_padded_input[0, :1] = 0
    left_padded_input[1, :2] = 0
    out6 = model.predict(left_padded_input)

    right_padded_input = np.ones((num_samples, timesteps))
    right_padded_input[0, -1:] = 0
    right_padded_input[1, -2:] = 0
    out7 = model.predict(right_padded_input)

    assert_allclose(out7, out6, atol=1e-5)


@rnn_test
def test_implementation_mode(layer_class):
    for mode in [1, 2]:
        # Without dropout
        layer_test(layer_class,
                   kwargs={'units': units,
                           'implementation': mode},
                   input_shape=(num_samples, timesteps, embedding_dim))
        # With dropout
        layer_test(layer_class,
                   kwargs={'units': units,
                           'implementation': mode,
                           'dropout': 0.1,
                           'recurrent_dropout': 0.1},
                   input_shape=(num_samples, timesteps, embedding_dim))
        # Without bias
        layer_test(layer_class,
                   kwargs={'units': units,
                           'implementation': mode,
                           'use_bias': False},
                   input_shape=(num_samples, timesteps, embedding_dim))


@rnn_test
def test_regularizer(layer_class):
    layer = layer_class(units, return_sequences=False, weights=None,
                        input_shape=(timesteps, embedding_dim),
                        kernel_regularizer=regularizers.l1(0.01),
                        recurrent_regularizer=regularizers.l1(0.01),
                        bias_regularizer='l2')
    layer.build((None, None, embedding_dim))
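    # One loss per regularized weight: kernel, recurrent kernel and bias.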
    assert len(layer.losses) == 3
    assert len(layer.cell.losses) == 3

    layer = layer_class(units, return_sequences=False, weights=None,
                        input_shape=(timesteps, embedding_dim),
                        activity_regularizer='l2')
    assert layer.activity_regularizer
    x = K.variable(np.ones((num_samples, timesteps, embedding_dim)))
    layer(x)
    assert len(layer.cell.get_losses_for(x)) == 0
    assert len(layer.get_losses_for(x)) == 1


@rnn_test
def test_trainability(layer_class):
    layer = layer_class(units)
    layer.build((None, None, embedding_dim))
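    # Every layer class here owns a kernel, a recurrent kernel and a bias.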
    assert len(layer.weights) == 3
    assert len(layer.trainable_weights) == 3
    assert len(layer.non_trainable_weights) == 0
    layer.trainable = False
    assert len(layer.weights) == 3
    assert len(layer.trainable_weights) == 0
    assert len(layer.non_trainable_weights) == 3
    layer.trainable = True
    assert len(layer.weights) == 3
    assert len(layer.trainable_weights) == 3
    assert len(layer.non_trainable_weights) == 0


@keras_test
def test_masking_layer():
    '''This test is based on a previously failing issue:
    https://github.com/keras-team/keras/issues/1567
    '''
    inputs = np.random.random((6, 3, 4))
    targets = np.abs(np.random.random((6, 3, 5)))
    targets /= targets.sum(axis=-1, keepdims=True)

    model = Sequential()
    model.add(Masking(input_shape=(3, 4)))
    model.add(recurrent.SimpleRNN(units=5, return_sequences=True, unroll=False))
    model.compile(loss='categorical_crossentropy', optimizer='adam')
    model.fit(inputs, targets, epochs=1, batch_size=100, verbose=1)

    model = Sequential()
    model.add(Masking(input_shape=(3, 4)))
    model.add(recurrent.SimpleRNN(units=5, return_sequences=True, unroll=True))
    model.compile(loss='categorical_crossentropy', optimizer='adam')
    model.fit(inputs, targets, epochs=1, batch_size=100, verbose=1)


@rnn_test
def test_from_config(layer_class):
    stateful_flags = (False, True)
    for stateful in stateful_flags:
        l1 = layer_class(units=1, stateful=stateful)
        l2 = layer_class.from_config(l1.get_config())
        assert l1.get_config() == l2.get_config()


@rnn_test
def test_specify_initial_state_keras_tensor(layer_class):
    num_states = 2 if layer_class is recurrent.LSTM else 1

    # Test with Keras tensor
    inputs = Input((timesteps, embedding_dim))
    initial_state = [Input((units,)) for _ in range(num_states)]
    layer = layer_class(units)
    if len(initial_state) == 1:
        output = layer(inputs, initial_state=initial_state[0])
    else:
        output = layer(inputs, initial_state=initial_state)
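    # Keras-tensor initial states should be registered as inputs of the node.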
    assert initial_state[0] in layer._inbound_nodes[0].input_tensors

    model = Model([inputs] + initial_state, output)
    model.compile(loss='categorical_crossentropy', optimizer='adam')

    inputs = np.random.random((num_samples, timesteps, embedding_dim))
    initial_state = [np.random.random((num_samples, units))
                     for _ in range(num_states)]
    targets = np.random.random((num_samples, units))
    model.fit([inputs] + initial_state, targets)


@rnn_test
def test_specify_initial_state_non_keras_tensor(layer_class):
    num_states = 2 if layer_class is recurrent.LSTM else 1

    # Test with non-Keras tensor
    inputs = Input((timesteps, embedding_dim))
    initial_state = [K.random_normal_variable((num_samples, units), 0, 1)
                     for _ in range(num_states)]
    layer = layer_class(units)
    output = layer(inputs, initial_state=initial_state)

    model = Model(inputs, output)
    model.compile(loss='categorical_crossentropy', optimizer='adam')

    inputs = np.random.random((num_samples, timesteps, embedding_dim))
    targets = np.random.random((num_samples, units))
    model.fit(inputs, targets)


@rnn_test
def test_reset_states_with_values(layer_class):
    num_states = 2 if layer_class is recurrent.LSTM else 1

    layer = layer_class(units, stateful=True)
    layer.build((num_samples, timesteps, embedding_dim))
    layer.reset_states()
    assert len(layer.states) == num_states
    assert layer.states[0] is not None
    np.testing.assert_allclose(K.eval(layer.states[0]),
                               np.zeros(K.int_shape(layer.states[0])),
                               atol=1e-4)
    state_shapes = [K.int_shape(state) for state in layer.states]
    values = [np.ones(shape) for shape in state_shapes]
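    # `reset_states` also accepts a single array when there is one state.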
    if len(values) == 1:
        values = values[0]
    layer.reset_states(values)
    np.testing.assert_allclose(K.eval(layer.states[0]),
                               np.ones(K.int_shape(layer.states[0])),
                               atol=1e-4)

    # Resetting with the wrong number of state values should raise
    with pytest.raises(ValueError):
        layer.reset_states([1] * (len(layer.states) + 1))


@rnn_test
def test_initial_states_as_other_inputs(layer_class):
    num_states = 2 if layer_class is recurrent.LSTM else 1

    # Test with Keras tensor
    main_inputs = Input((timesteps, embedding_dim))
    initial_state = [Input((units,)) for _ in range(num_states)]
    inputs = [main_inputs] + initial_state

    layer = layer_class(units)
    output = layer(inputs)
    assert initial_state[0] in layer._inbound_nodes[0].input_tensors

    model = Model(inputs, output)
    model.compile(loss='categorical_crossentropy', optimizer='adam')

    main_inputs = np.random.random((num_samples, timesteps, embedding_dim))
    initial_state = [np.random.random((num_samples, units))
                     for _ in range(num_states)]
    targets = np.random.random((num_samples, units))
    model.train_on_batch([main_inputs] + initial_state, targets)


@rnn_test
def test_specify_state_with_masking(layer_class):
    '''This test is based on a previously failing issue:
    https://github.com/keras-team/keras/issues/1567
    '''
    num_states = 2 if layer_class is recurrent.LSTM else 1

    inputs = Input((timesteps, embedding_dim))
    _ = Masking()(inputs)
    initial_state = [Input((units,)) for _ in range(num_states)]
    output = layer_class(units)(inputs, initial_state=initial_state)

    model = Model([inputs] + initial_state, output)
    model.compile(loss='categorical_crossentropy', optimizer='adam')

    inputs = np.random.random((num_samples, timesteps, embedding_dim))
    initial_state = [np.random.random((num_samples, units))
                     for _ in range(num_states)]
    targets = np.random.random((num_samples, units))
    model.fit([inputs] + initial_state, targets)


@rnn_test
def test_return_state(layer_class):
    num_states = 2 if layer_class is recurrent.LSTM else 1

    inputs = Input(batch_shape=(num_samples, timesteps, embedding_dim))
    layer = layer_class(units, return_state=True, stateful=True)
    outputs = layer(inputs)
    output, state = outputs[0], outputs[1:]
    assert len(state) == num_states
    model = Model(inputs, state[0])

    inputs = np.random.random((num_samples, timesteps, embedding_dim))
    state = model.predict(inputs)
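    # `predict` updates the stateful layer, so its internal state should
    # now match the state tensor returned by the model.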
    np.testing.assert_allclose(K.eval(layer.states[0]), state, atol=1e-4)


@rnn_test
def test_state_reuse(layer_class):
    inputs = Input(batch_shape=(num_samples, timesteps, embedding_dim))
    layer = layer_class(units, return_state=True, return_sequences=True)
    outputs = layer(inputs)
    output, state = outputs[0], outputs[1:]
    output = layer_class(units)(output, initial_state=state)
    model = Model(inputs, output)

    inputs = np.random.random((num_samples, timesteps, embedding_dim))
    outputs = model.predict(inputs)


@rnn_test
@pytest.mark.skipif(K.backend() == 'theano',
                    reason='Not supported.')
def test_state_reuse_with_dropout(layer_class):
    input1 = Input(batch_shape=(num_samples, timesteps, embedding_dim))
    layer = layer_class(units, return_state=True,
                        return_sequences=True, dropout=0.2)
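    # `layer(input1)` returns [output, state_0, ...]; keep only the states.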
    state = layer(input1)[1:]

    input2 = Input(batch_shape=(num_samples, timesteps, embedding_dim))
    output = layer_class(units)(input2, initial_state=state)
    model = Model([input1, input2], output)

    inputs = [np.random.random((num_samples, timesteps, embedding_dim)),
              np.random.random((num_samples, timesteps, embedding_dim))]
    outputs = model.predict(inputs)


@keras_test
def test_minimal_rnn_cell_non_layer():

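    # A cell passed to `RNN` only needs a `call` method and a `state_size`
    # attribute; it does not have to be a Keras Layer.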
    class MinimalRNNCell(object):

        def __init__(self, units, input_dim):
            self.units = units
            self.state_size = units
            self.kernel = keras.backend.variable(
                np.random.random((input_dim, units)))

        def call(self, inputs, states):
            prev_output = states[0]
            output = keras.backend.dot(inputs, self.kernel) + prev_output
            return output, [output]

    # Basic test case.
    cell = MinimalRNNCell(32, 5)
    x = keras.Input((None, 5))
    layer = recurrent.RNN(cell)
    y = layer(x)
    model = keras.models.Model(x, y)
    model.compile(optimizer='rmsprop', loss='mse')
    model.train_on_batch(np.zeros((6, 5, 5)), np.zeros((6, 32)))

    # Test stacking.
    cells = [MinimalRNNCell(8, 5),
             MinimalRNNCell(32, 8),
             MinimalRNNCell(32, 32)]
    layer = recurrent.RNN(cells)
    y = layer(x)
    model = keras.models.Model(x, y)
    model.compile(optimizer='rmsprop', loss='mse')
    model.train_on_batch(np.zeros((6, 5, 5)), np.zeros((6, 32)))


@keras_test
def test_minimal_rnn_cell_non_layer_multiple_states():

    class MinimalRNNCell(object):

        def __init__(self, units, input_dim):
            self.units = units
            self.state_size = (units, units)
            self.kernel = keras.backend.variable(
                np.random.random((input_dim, units)))

        def call(self, inputs, states):
            prev_output_1 = states[0]
            prev_output_2 = states[1]
            output = keras.backend.dot(inputs, self.kernel)
            output += prev_output_1
            output -= prev_output_2
            return output, [output * 2, output * 3]

    # Basic test case.
    cell = MinimalRNNCell(32, 5)
    x = keras.Input((None, 5))
    layer = recurrent.RNN(cell)
    y = layer(x)
    model = keras.models.Model(x, y)
    model.compile(optimizer='rmsprop', loss='mse')
    model.train_on_batch(np.zeros((6, 5, 5)), np.zeros((6, 32)))

    # Test stacking.
    cells = [MinimalRNNCell(8, 5),
             MinimalRNNCell(16, 8),
             MinimalRNNCell(32, 16)]
    layer = recurrent.RNN(cells)
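    # Stacked cells report their state sizes with the last cell first.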
    assert layer.cell.state_size == (32, 32, 16, 16, 8, 8)
    y = layer(x)
    model = keras.models.Model(x, y)
    model.compile(optimizer='rmsprop', loss='mse')
    model.train_on_batch(np.zeros((6, 5, 5)), np.zeros((6, 32)))


@keras_test
def test_minimal_rnn_cell_layer():

    class MinimalRNNCell(keras.layers.Layer):

        def __init__(self, units, **kwargs):
            self.units = units
            self.state_size = units
            super(MinimalRNNCell, self).__init__(**kwargs)

        def build(self, input_shape):
            # no time axis in the input shape passed to RNN cells
            assert len(input_shape) == 2

            self.kernel = self.add_weight(shape=(input_shape[-1], self.units),
                                          initializer='uniform',
                                          name='kernel')
            self.recurrent_kernel = self.add_weight(
                shape=(self.units, self.units),
                initializer='uniform',
                name='recurrent_kernel')
            self.built = True

        def call(self, inputs, states):
            prev_output = states[0]
            h = keras.backend.dot(inputs, self.kernel)
            output = h + keras.backend.dot(prev_output, self.recurrent_kernel)
            return output, [output]

        def get_config(self):
            config = {'units': self.units}
            base_config = super(MinimalRNNCell, self).get_config()
            return dict(list(base_config.items()) + list(config.items()))

    # Test basic case.
    x = keras.Input((None, 5))
    cell = MinimalRNNCell(32)
    layer = recurrent.RNN(cell)
    y = layer(x)
    model = keras.models.Model(x, y)
    model.compile(optimizer='rmsprop', loss='mse')
    model.train_on_batch(np.zeros((6, 5, 5)), np.zeros((6, 32)))

    # Test basic case serialization.
    x_np = np.random.random((6, 5, 5))
    y_np = model.predict(x_np)
    weights = model.get_weights()
    config = layer.get_config()
    with keras.utils.CustomObjectScope({'MinimalRNNCell': MinimalRNNCell}):
        layer = recurrent.RNN.from_config(config)
    y = layer(x)
    model = keras.models.Model(x, y)
    model.set_weights(weights)
    y_np_2 = model.predict(x_np)
    assert_allclose(y_np, y_np_2, atol=1e-4)

    # Test stacking.
    cells = [MinimalRNNCell(8),
             MinimalRNNCell(12),
             MinimalRNNCell(32)]
    layer = recurrent.RNN(cells)
    y = layer(x)
    model = keras.models.Model(x, y)
    model.compile(optimizer='rmsprop', loss='mse')
    model.train_on_batch(np.zeros((6, 5, 5)), np.zeros((6, 32)))

    # Test stacked RNN serialization.
    x_np = np.random.random((6, 5, 5))
    y_np = model.predict(x_np)
    weights = model.get_weights()
    config = layer.get_config()
    with keras.utils.CustomObjectScope({'MinimalRNNCell': MinimalRNNCell}):
        layer = recurrent.RNN.from_config(config)
    y = layer(x)
    model = keras.models.Model(x, y)
    model.set_weights(weights)
    y_np_2 = model.predict(x_np)
    assert_allclose(y_np, y_np_2, atol=1e-4)


@rnn_cell_test
def test_builtin_rnn_cell_layer(cell_class):
    # Test basic case.
    x = keras.Input((None, 5))
    cell = cell_class(32)
    layer = recurrent.RNN(cell)
    y = layer(x)
    model = keras.models.Model(x, y)
    model.compile(optimizer='rmsprop', loss='mse')
    model.train_on_batch(np.zeros((6, 5, 5)), np.zeros((6, 32)))

    # Test basic case serialization.
    x_np = np.random.random((6, 5, 5))
    y_np = model.predict(x_np)
    weights = model.get_weights()
    config = layer.get_config()
    layer = recurrent.RNN.from_config(config)
    y = layer(x)
    model = keras.models.Model(x, y)
    model.set_weights(weights)
    y_np_2 = model.predict(x_np)
    assert_allclose(y_np, y_np_2, atol=1e-4)

    # Test stacking.
    cells = [cell_class(8),
             cell_class(12),
             cell_class(32)]
    layer = recurrent.RNN(cells)
    y = layer(x)
    model = keras.models.Model(x, y)
    model.compile(optimizer='rmsprop', loss='mse')
    model.train_on_batch(np.zeros((6, 5, 5)), np.zeros((6, 32)))

    # Test stacked RNN serialization.
    x_np = np.random.random((6, 5, 5))
    y_np = model.predict(x_np)
    weights = model.get_weights()
    config = layer.get_config()
    layer = recurrent.RNN.from_config(config)
    y = layer(x)
    model = keras.models.Model(x, y)
    model.set_weights(weights)
    y_np_2 = model.predict(x_np)
    assert_allclose(y_np, y_np_2, atol=1e-4)


@keras_test
@pytest.mark.skipif(K.backend() in ['cntk', 'theano'],
                    reason='Not supported.')
def test_stacked_rnn_dropout():
    cells = [recurrent.LSTMCell(3, dropout=0.1, recurrent_dropout=0.1),
             recurrent.LSTMCell(3, dropout=0.1, recurrent_dropout=0.1)]
    layer = recurrent.RNN(cells)

    x = keras.Input((None, 5))
    y = layer(x)
    model = keras.models.Model(x, y)
    model.compile('sgd', 'mse')
    x_np = np.random.random((6, 5, 5))
    y_np = np.random.random((6, 3))
    model.train_on_batch(x_np, y_np)


@keras_test
def test_stacked_rnn_attributes():
    cells = [recurrent.LSTMCell(3),
             recurrent.LSTMCell(3, kernel_regularizer='l2')]
    layer = recurrent.RNN(cells)
    layer.build((None, None, 5))

    # Test regularization losses
    assert len(layer.losses) == 1

    # Test weights: each LSTMCell owns a kernel, a recurrent kernel
    # and a bias, so two cells contribute six weights.
    assert len(layer.trainable_weights) == 6
    cells[0].trainable = False
    assert len(layer.trainable_weights) == 3
    assert len(layer.non_trainable_weights) == 3

    # Test `get_losses_for`
    x = keras.Input((None, 5))
    y = K.sum(x)
    cells[0].add_loss(y, inputs=x)
    assert layer.get_losses_for(x) == [y]


@keras_test
def test_stacked_rnn_compute_output_shape():
    cells = [recurrent.LSTMCell(3),
             recurrent.LSTMCell(6)]
    layer = recurrent.RNN(cells, return_state=True, return_sequences=True)
    output_shape = layer.compute_output_shape((None, timesteps, embedding_dim))
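    # The full sequence output comes first, then the states of each cell,
    # last cell first.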
    expected_output_shape = [(None, timesteps, 6),
                             (None, 6),
                             (None, 6),
                             (None, 3),
                             (None, 3)]
    assert output_shape == expected_output_shape


@rnn_test
def test_batch_size_equal_one(layer_class):
    inputs = Input(batch_shape=(1, timesteps, embedding_dim))
    layer = layer_class(units)
    outputs = layer(inputs)
    model = Model(inputs, outputs)
    model.compile('sgd', 'mse')
    x = np.random.random((1, timesteps, embedding_dim))
    y = np.random.random((1, units))
    model.train_on_batch(x, y)


@keras_test
def test_rnn_cell_with_constants_layer():

    class RNNCellWithConstants(keras.layers.Layer):

        def __init__(self, units, **kwargs):
            self.units = units
            self.state_size = units
            super(RNNCellWithConstants, self).__init__(**kwargs)

        def build(self, input_shape):
            if not isinstance(input_shape, list):
                raise TypeError('expects constants shape')
            [input_shape, constant_shape] = input_shape
            # will (and should) raise if more than one constant is passed

            self.input_kernel = self.add_weight(
                shape=(input_shape[-1], self.units),
                initializer='uniform',
                name='kernel')
            self.recurrent_kernel = self.add_weight(
                shape=(self.units, self.units),
                initializer='uniform',
                name='recurrent_kernel')
            self.constant_kernel = self.add_weight(
                shape=(constant_shape[-1], self.units),
                initializer='uniform',
                name='constant_kernel')
            self.built = True

        def call(self, inputs, states, constants):
            [prev_output] = states
            [constant] = constants
            h_input = keras.backend.dot(inputs, self.input_kernel)
            h_state = keras.backend.dot(prev_output, self.recurrent_kernel)
            h_const = keras.backend.dot(constant, self.constant_kernel)
            output = h_input + h_state + h_const
            return output, [output]

        def get_config(self):
            config = {'units': self.units}
            base_config = super(RNNCellWithConstants, self).get_config()
            return dict(list(base_config.items()) + list(config.items()))

    # Test basic case.
    x = keras.Input((None, 5))
    c = keras.Input((3,))
    cell = RNNCellWithConstants(32)
    layer = recurrent.RNN(cell)
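    # The constant tensor is fed unchanged to the cell at every timestep.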
    y = layer(x, constants=c)
    model = keras.models.Model([x, c], y)
    model.compile(optimizer='rmsprop', loss='mse')
    model.train_on_batch(
        [np.zeros((6, 5, 5)), np.zeros((6, 3))],
        np.zeros((6, 32))
    )

    # Test basic case serialization.
    x_np = np.random.random((6, 5, 5))
    c_np = np.random.random((6, 3))
    y_np = model.predict([x_np, c_np])
    weights = model.get_weights()
    config = layer.get_config()
    custom_objects = {'RNNCellWithConstants': RNNCellWithConstants}
    with keras.utils.CustomObjectScope(custom_objects):
        layer = recurrent.RNN.from_config(config.copy())
    y = layer(x, constants=c)
    model = keras.models.Model([x, c], y)
    model.set_weights(weights)
    y_np_2 = model.predict([x_np, c_np])
    assert_allclose(y_np, y_np_2, atol=1e-4)

    # test flat list inputs
    with keras.utils.CustomObjectScope(custom_objects):
        layer = recurrent.RNN.from_config(config.copy())
    y = layer([x, c])
    model = keras.models.Model([x, c], y)
    model.set_weights(weights)
    y_np_3 = model.predict([x_np, c_np])
    assert_allclose(y_np, y_np_3, atol=1e-4)

    # Test stacking.
    cells = [recurrent.GRUCell(8),
             RNNCellWithConstants(12),
             RNNCellWithConstants(32)]
    layer = recurrent.RNN(cells)
    y = layer(x, constants=c)
    model = keras.models.Model([x, c], y)
    model.compile(optimizer='rmsprop', loss='mse')
    model.train_on_batch(
        [np.zeros((6, 5, 5)), np.zeros((6, 3))],
        np.zeros((6, 32))
    )

    # Test stacked RNN serialization.
    x_np = np.random.random((6, 5, 5))
    c_np = np.random.random((6, 3))
    y_np = model.predict([x_np, c_np])
    weights = model.get_weights()
    config = layer.get_config()
    with keras.utils.CustomObjectScope(custom_objects):
        layer = recurrent.RNN.from_config(config.copy())
    y = layer(x, constants=c)
    model = keras.models.Model([x, c], y)
    model.set_weights(weights)
    y_np_2 = model.predict([x_np, c_np])
    assert_allclose(y_np, y_np_2, atol=1e-4)


@keras_test
def test_rnn_cell_with_constants_layer_passing_initial_state():

    class RNNCellWithConstants(keras.layers.Layer):

        def __init__(self, units, **kwargs):
            self.units = units
            self.state_size = units
            super(RNNCellWithConstants, self).__init__(**kwargs)

        def build(self, input_shape):
            if not isinstance(input_shape, list):
                raise TypeError('expects constants shape')
            [input_shape, constant_shape] = input_shape
            # will (and should) raise if more than one constant is passed

            self.input_kernel = self.add_weight(
                shape=(input_shape[-1], self.units),
                initializer='uniform',
                name='kernel')
            self.recurrent_kernel = self.add_weight(
                shape=(self.units, self.units),
                initializer='uniform',
                name='recurrent_kernel')
            self.constant_kernel = self.add_weight(
                shape=(constant_shape[-1], self.units),
                initializer='uniform',
                name='constant_kernel')
            self.built = True

        def call(self, inputs, states, constants):
            [prev_output] = states
            [constant] = constants
            h_input = keras.backend.dot(inputs, self.input_kernel)
            h_state = keras.backend.dot(prev_output, self.recurrent_kernel)
            h_const = keras.backend.dot(constant, self.constant_kernel)
            output = h_input + h_state + h_const
            return output, [output]

        def get_config(self):
            config = {'units': self.units}
            base_config = super(RNNCellWithConstants, self).get_config()
            return dict(list(base_config.items()) + list(config.items()))

    # Test basic case.
    x = keras.Input((None, 5))
    c = keras.Input((3,))
    s = keras.Input((32,))
    cell = RNNCellWithConstants(32)
    layer = recurrent.RNN(cell)
    y = layer(x, initial_state=s, constants=c)
    model = keras.models.Model([x, s, c], y)
    model.compile(optimizer='rmsprop', loss='mse')
    model.train_on_batch(
        [np.zeros((6, 5, 5)), np.zeros((6, 32)), np.zeros((6, 3))],
        np.zeros((6, 32))
    )

    # Test basic case serialization.
    x_np = np.random.random((6, 5, 5))
    s_np = np.random.random((6, 32))
    c_np = np.random.random((6, 3))
    y_np = model.predict([x_np, s_np, c_np])
    weights = model.get_weights()
    config = layer.get_config()
    custom_objects = {'RNNCellWithConstants': RNNCellWithConstants}
    with keras.utils.CustomObjectScope(custom_objects):
        layer = recurrent.RNN.from_config(config.copy())
    y = layer(x, initial_state=s, constants=c)
    model = keras.models.Model([x, s, c], y)
    model.set_weights(weights)
    y_np_2 = model.predict([x_np, s_np, c_np])
    assert_allclose(y_np, y_np_2, atol=1e-4)

    # verify that state is used
    y_np_2_different_s = model.predict([x_np, s_np + 10., c_np])
    with pytest.raises(AssertionError):
        assert_allclose(y_np, y_np_2_different_s, atol=1e-4)

    # test flat list inputs
    with keras.utils.CustomObjectScope(custom_objects):
        layer = recurrent.RNN.from_config(config.copy())
    y = layer([x, s, c])
    model = keras.models.Model([x, s, c], y)
    model.set_weights(weights)
    y_np_3 = model.predict([x_np, s_np, c_np])
    assert_allclose(y_np, y_np_3, atol=1e-4)


if __name__ == '__main__':
    pytest.main([__file__])