import tensorflow as tf
from tensorflow.contrib.learn import ModeKeys


def dot_semantic_nn(context, utterance, tng_mode):
    """
    Creates a two channel network that encodes two different inputs and calculates
    a dot product to measure how close they are.

    :param context: Summed word embeddings for the context batch (batch_size x emb_dim)
    :param utterance: Summed word embeddings for the utterance batch (batch_size x emb_dim)
    :param tng_mode: tf ModeKeys key
    :return: (mean_loss, context_channel, utterance_channel, similarity) where each
        channel output is a (batch_size x 500) encoding and similarity is the
        (batch_size x batch_size) matrix of pairwise dot products
    """

    # dropout is only applied during training; disable it (rate 0.0) for eval/inference
    is_training = tng_mode == ModeKeys.TRAIN
    dropout_rate = 0.5 if is_training else 0.0

    # -----------------------------
    # CONTEXT CHANNEL
    # encodes the context into a (batch_size x 500) matrix
    context_channel = _network_channel(network_name='context_channel', net_input=context,
                                       dropout_rate=dropout_rate, is_training=is_training)

    # -----------------------------
    # UTTERANCE CHANNEL
    # encodes the utterance into a (batch_size x 500) matrix
    utterance_channel = _network_channel(network_name='utterance_channel', net_input=utterance,
                                         dropout_rate=dropout_rate, is_training=is_training)

    # -----------------------------
    # LOSS
    # negative log probability while using K-1 examples in the batch
    # as negative samples
    mean_loss = _negative_log_probability_loss(context_channel, utterance_channel)
    # pairwise dot products between every context and utterance in the batch
    # (batch_size x batch_size); used for scoring/ranking at inference time
    similarity = tf.matmul(context_channel, utterance_channel, transpose_b=True)

    # return the loss, the encoding from each channel, and the similarity matrix
    return mean_loss, context_channel, utterance_channel, similarity
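

# The encoders above only define the loss graph; to actually train them an
# optimizer is needed, plus the batch-norm update ops created inside each
# channel. The helper below is an illustrative sketch only: the optimizer
# choice (Adam) and learning rate are assumptions, not part of the original
# model definition.
def _example_train_op(mean_loss, learning_rate=1e-3):
    """Builds a train op for the mean loss; illustrative sketch only."""
    # batch_normalization puts its moving-average updates in UPDATE_OPS,
    # so they are attached as control dependencies of the optimizer step
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        return tf.train.AdamOptimizer(learning_rate).minimize(mean_loss)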


def _negative_log_probability_loss(context_channel, utterance_channel):
    r"""
    Implements the negative log probability with in-batch negative sampling:
    for each pair (x_i, y_i), the other K-1 utterances in the batch (j != i)
    are treated as negative samples, where K is the batch size.

    The overall loss formula is:
    $$L(x, y, \theta) = -\frac{1}{K}\sum_{i=1}^{K}\left[ f(x_i, y_i) - \log \sum_{j=1}^{K}{e^{f(x_i, y_j)}}\right]$$

    :param context_channel: (batch_size x emb_size) context encodings
    :param utterance_channel: (batch_size x emb_size) utterance encodings
    :return: scalar mean loss over the batch
    """
    # dot product between every context and utterance encoding in the batch
    # (bs x bs)
    K = tf.matmul(context_channel, utterance_channel, transpose_b=True)

    # get the diagonal, which holds f(x_i, y_i):
    # the similarity score between each matching context x_i and utterance y_i
    # out = (bs x 1)
    S = tf.diag_part(K)
    S = tf.reshape(S, [-1, 1])

    # calculate log sum_j e^{f(x_i, y_j)} for each row i
    # each row contains the positive pair plus the K-1 in-batch negatives
    # in = (bs x bs), out = (bs x 1)
    K = tf.reduce_logsumexp(K, axis=1, keep_dims=True)

    # average the per-example loss -(f(x_i, y_i) - log sum_j e^{f(x_i, y_j)})
    # over the batch
    return -tf.reduce_mean(S - K)
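

# For reference: the loss above is the same as a softmax cross-entropy over each
# row of the (bs x bs) similarity matrix, with the positive pair sitting on the
# diagonal. The function below is an illustrative sketch of that equivalence;
# it is not called by dot_semantic_nn but should produce the same value as
# _negative_log_probability_loss.
def _negative_log_probability_loss_xent(context_channel, utterance_channel):
    """Equivalent loss via sparse softmax cross-entropy; illustrative only."""
    # (bs x bs) pairwise dot products, positives on the diagonal
    logits = tf.matmul(context_channel, utterance_channel, transpose_b=True)
    # label i selects the matching utterance y_i for context x_i
    labels = tf.range(tf.shape(logits)[0])
    losses = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=logits)
    return tf.reduce_mean(losses)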


def _network_channel(network_name, net_input, dropout_rate, is_training):
    """
    Generates a 3 layer dense network that encodes the inputs
    into a 500 dimensional space
    :param network_name: variable scope / collection prefix for this branch
    :param net_input: (batch_size x emb_dim) input embeddings
    :param dropout_rate: fraction of units to drop (0.0 disables dropout)
    :param is_training: True when building the training graph
    :return: (batch_size x 500) encoding of the inputs
    """
    with tf.variable_scope(network_name):
        predict_opt_name = '{}_branch_predict'.format(network_name)

        # use 3 dense layers for this network branch
        # NOTE: batch_normalization registers its moving-average updates in
        # tf.GraphKeys.UPDATE_OPS; they must be run alongside the train op
        with tf.variable_scope('dense_branch'):
            dense_0 = tf.layers.dense(net_input, units=300, activation=tf.nn.tanh)
            dense_0 = tf.layers.batch_normalization(dense_0, training=is_training)
            dense_0 = tf.layers.dropout(inputs=dense_0, rate=dropout_rate, training=is_training)

            dense_1 = tf.layers.dense(dense_0, units=300, activation=tf.nn.tanh)
            dense_1 = tf.layers.batch_normalization(dense_1, training=is_training)
            dense_1 = tf.layers.dropout(inputs=dense_1, rate=dropout_rate, training=is_training)

            dense_2 = tf.layers.dense(dense_1, units=500, activation=tf.nn.tanh, name=predict_opt_name)
            tf.add_to_collection('{}_embed_opt'.format(network_name), dense_2)

        return dense_2
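

# -----------------------------------------------------------------------------
# Minimal smoke-test sketch, run with `python <this_module>.py`. The embedding
# size (300) and batch size (8) below are illustrative assumptions, not values
# taken from the original training setup; the random placeholders stand in for
# real summed word embeddings.
if __name__ == '__main__':
    import numpy as np

    emb_dim, batch_size = 300, 8
    context_ph = tf.placeholder(tf.float32, [None, emb_dim], name='context_input')
    utterance_ph = tf.placeholder(tf.float32, [None, emb_dim], name='utterance_input')

    loss, context_enc, utterance_enc, similarity = dot_semantic_nn(
        context_ph, utterance_ph, ModeKeys.TRAIN)
    train_op = _example_train_op(loss)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        feed = {
            context_ph: np.random.randn(batch_size, emb_dim),
            utterance_ph: np.random.randn(batch_size, emb_dim),
        }
        _, loss_val, sim_val = sess.run([train_op, loss, similarity], feed_dict=feed)
        print('loss: {:.4f}, similarity shape: {}'.format(loss_val, sim_val.shape))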