import numpy as np
import keras.backend as K
from keras.layers import Layer, merge, RepeatVector


class InteractionRNN(Layer):
    """Run an RNN (and optional DNN head) over the concatenation of two inputs.

    The two input tensors are concatenated along their last axis. The
    result is repeated ``num_steps`` times with ``RepeatVector`` to form a
    sequence, and that sequence is passed through ``RNN``. If ``DNN`` is
    given, it is applied to the RNN output.

    # Arguments
        RNN: layer (callable on a Keras tensor) applied to the repeated
            sequence.
        num_steps: int, number of times the concatenated vector is
            repeated, i.e. the sequence length fed to ``RNN``.
        DNN: optional layer/model applied to the output of ``RNN``.
        **kwargs: forwarded to ``keras.layers.Layer``.

    The wrapped ``RNN``/``DNN`` are built lazily inside ``call``. Their
    weights, updates, regularizers and constraints are therefore
    re-exported on this layer after every call, so that the enclosing
    model actually trains them.
    """

    def __init__(self, RNN, num_steps, DNN=None, **kwargs):
        self.RNN = RNN
        self.num_steps = num_steps
        self.DNN = DNN
        super(InteractionRNN, self).__init__(**kwargs)

    def get_output_shape_for(self, input_shape):
        """Return the output shape reported to Keras.

        NOTE(review): hard-coded to ``(None, 1)``. This is only accurate
        when the final wrapped layer (``DNN`` if given, otherwise ``RNN``)
        emits exactly one unit per sample and does not return sequences.
        Confirm this matches how the layer is configured.
        """
        return (None, 1)

    def _sublayers(self):
        """Return the wrapped layers that ``call`` invokes, in call order."""
        return [layer for layer in (self.RNN, self.DNN) if layer is not None]

    def _export_sublayer_state(self):
        """Publish the wrapped layers' weights and training state on self.

        This mirrors what Keras 1's ``Wrapper.build`` does for its single
        wrapped layer, extended to both ``RNN`` and ``DNN``.
        """
        def unique(items):
            # De-duplicate by identity: backend variables may overload
            # ``==``, so membership tests on them are not reliable.
            seen = set()
            result = []
            for item in items:
                if id(item) not in seen:
                    seen.add(id(item))
                    result.append(item)
            return result

        trainable, non_trainable = [], []
        updates, regularizers = [], []
        constraints = {}
        for layer in self._sublayers():
            weights = list(getattr(layer, 'trainable_weights', None) or [])
            if getattr(layer, 'trainable', True):
                trainable.extend(weights)
            else:
                # A frozen sub-layer's weights must not be touched by the
                # optimizer, so report them as non-trainable.
                non_trainable.extend(weights)
            non_trainable.extend(
                getattr(layer, 'non_trainable_weights', None) or [])
            updates.extend(getattr(layer, 'updates', None) or [])
            regularizers.extend(getattr(layer, 'regularizers', None) or [])
            constraints.update(getattr(layer, 'constraints', None) or {})

        self.trainable_weights = unique(trainable)
        self.non_trainable_weights = unique(non_trainable)
        self.updates = unique(updates)
        self.regularizers = unique(regularizers)
        self.constraints = constraints

    def call(self, x, mask=None):
        """Concatenate the two inputs, repeat them, and run the wrapped layers.

        # Arguments
            x: sequence whose first two entries are the input tensors ``U``
                and ``V``; any further entries are ignored. ``RepeatVector``
                needs a 2D input, so U and V are presumably
                (batch, features) tensors (TODO confirm at call sites).
            mask: ignored.

        # Returns
            The output of ``DNN`` if it was given, otherwise the output of
            ``RNN``.
        """
        U, V = x[0], x[1]
        # 'concat' joins along ``concat_axis`` (the last axis here).
        # ``dot_axes`` is only consulted by the 'dot'/'cos' modes, so it is
        # deliberately not passed.
        merged = merge([U, V], mode='concat', concat_axis=-1)
        # Repeat the concatenated vector num_steps times so the RNN sees a
        # sequence of identical timesteps.
        sequence = RepeatVector(self.num_steps)(merged)
        response = self.RNN(sequence)
        if self.DNN is not None:
            response = self.DNN(response)
        # The calls above build the wrapped layers lazily, so their weights
        # only exist now. Publish them on this layer so the enclosing
        # model's optimizer, regularizers and constraints see them.
        self._export_sublayer_state()
        return response