from keras.layers import K, RepeatVector, Lambda


def repeat_vector(inputs):
    """
    Repeat a 2D tensor along a new axis whose length is read from another tensor.

    Use this function within a Lambda layer to get a repeated layer with a variable 1-st dimension (seq_len).
    May be useful to further feed it to a Concatenate layer.

    The repetition is done with the backend op ``K.repeat`` instead of instantiating a ``RepeatVector`` layer:
    ``RepeatVector`` is documented to take an integer ``n``, while the repetition count here is a symbolic
    scalar that is only known at run time.

    inputs == (layer_for_repeat, layer_for_getting_rep_num):
        layer_for_repeat:           shape == (batch_size, vector_dim)
        layer_for_getting_rep_num:  shape == (batch_size, seq_len, ...)
    :return:
        repeated layer_for_repeat, shape == (batch_size, seq_len, vector_dim)
    """
    layer_for_repeat, layer_for_getting_rep_num = inputs
    # dynamic sequence length: a scalar tensor, not a Python int
    seq_len = K.shape(layer_for_getting_rep_num)[1]
    repeated_vector = K.repeat(layer_for_repeat, seq_len)
    # shape == (batch_size, seq_len, vector_dim)
    return repeated_vector


def softmax_with_temperature(logits, temperature):
    """
    Scale logits by a per-sample temperature and turn them into probabilities with softmax.

    The result is wrapped in a Lambda layer named 'softmax_with_temperature' which propagates
    the mask of its first input (logits).

    :param logits:      shape == (batch_size, seq_len, vocab_size), float32
    :param temperature: shape == (batch_size, 1), float32; used as a divisor, so it must be non-zero
                        (positive for the usual temperature semantics)
    :return:
        transformed tokens probs, shape == (batch_size, seq_len, vocab_size), float32
    """

    def softmax_with_temp(args):
        step_logits, sample_temperature = args
        # Add a seq_len axis of size 1 and rely on broadcasting instead of physically repeating
        # the temperature seq_len times (RepeatVector expects an integer n, not a symbolic length).
        temperature_expanded = K.expand_dims(sample_temperature, axis=1)
        # shape == (batch_size, 1, 1)
        scaled_logits = step_logits / temperature_expanded
        # shape == (batch_size, seq_len, vocab_size)

        # for numerical stability (e.g. for low temperatures):
        scaled_logits = scaled_logits - K.max(scaled_logits, axis=2, keepdims=True)
        # shape == (batch_size, seq_len, vocab_size)
        transformed_probs = K.softmax(scaled_logits)
        # shape == (batch_size, seq_len, vocab_size)
        return transformed_probs

    def first_input_mask(inputs, inputs_masks):
        # function to get mask of the first input (logits)
        # NOTE(review): some Keras versions appear to pass None instead of a list when
        # no input carries a mask - guard against it rather than fail on None[0].
        if inputs_masks is None:
            return None
        return inputs_masks[0]

    # wrap transformation in Lambda to turn the result to Keras layer
    transformed_probs = Lambda(
        function=softmax_with_temp,
        mask=first_input_mask,
        name='softmax_with_temperature')([logits, temperature])
    # output shape == (batch_size, seq_len, vocab_size)

    return transformed_probs