'''Create a recurrent neural network to compute a control policy.
'''

from keras.models import Sequential
from keras.layers.core import Dense, Dropout, Activation, Flatten
from keras.layers.recurrent import SimpleRNN


def create_rnn():
    """Build and compile the recurrent network used as a control policy.

    Architecture (Keras 1.x API):
      * ``SimpleRNN`` with 3 units, ``stateful=True`` and a fixed
        ``batch_input_shape`` of ``(1, 1, 3)``: batches of exactly one
        sample, one timestep, three features per step.
      * ``Dense`` read-out layer with 3 outputs and no activation argument,
        so Keras' default (linear) applies.

    The model is compiled with mean-squared-error loss and the RMSprop
    optimizer.

    Because the recurrent layer is stateful, Keras keeps its hidden state
    between successive ``predict`` calls until ``reset_states()`` is
    called. NOTE(review): confirm callers reset state between episodes.

    Reference:
    Koutnik, Jan, Jurgen Schmidhuber, and Faustino Gomez. "Evolving deep
    unsupervised convolutional networks for vision-based reinforcement
    learning." Proceedings of the 2014 conference on Genetic and
    evolutionary computation. ACM, 2014.

    Returns:
        The compiled ``Sequential`` model.
    """
    network = Sequential()

    recurrent_layer = SimpleRNN(
        output_dim=3,
        stateful=True,
        batch_input_shape=(1, 1, 3),
    )
    # input_dim is redundant for a non-first layer; kept to mirror the
    # original layer configuration exactly.
    readout_layer = Dense(input_dim=3, output_dim=3)

    for layer in (recurrent_layer, readout_layer):
        network.add(layer)

    network.compile(loss='mse', optimizer='rmsprop')
    return network


def calculate_rnn_output(model, input, multiple=False):
    """Run ``model.predict`` on ``input`` and reshape the prediction.

    Args:
        model: Any object exposing ``predict(input)`` that returns a numpy
            array, such as the model built by ``create_rnn``.
        input: Passed unchanged to ``model.predict``. (The name shadows the
            builtin ``input``; it is kept so keyword callers keep working.)
        multiple: Pass ``True`` when the prediction covers several samples.
            The result is then reshaped to ``(shape[0], shape[1])`` of the
            raw prediction: a no-op for a 2-D prediction, and a way to drop
            trailing singleton axes otherwise. Leave ``False`` (the default)
            for a single sample; the result is then flattened to a 1-D
            vector of length ``shape[1]``.

    Returns:
        numpy array of shape ``(n_samples, n_outputs)`` when ``multiple`` is
        true, otherwise ``(n_outputs,)``.

    Raises:
        ValueError: (from numpy ``reshape``) if the prediction's total size
            does not match the target shape. For example, this happens with
            ``multiple=False`` when the prediction holds more than one sample.
    """
    prediction = model.predict(input)
    if multiple:
        # Keep the (sample, output) axes; any extra axes must have size 1.
        return prediction.reshape(prediction.shape[0], prediction.shape[1])
    # Single sample: drop the batch axis, leaving only the output vector.
    return prediction.reshape(prediction.shape[1])