from keras import backend as K
from keras.losses import binary_crossentropy

import tensorflow as tf


def _reset_invalid_joints(y_true, y_pred):
    """Zero out invalid entries in both tensors and count the valid ones.

    An entry is considered valid when its value in y_true is strictly greater
    than zero; the resulting float mask is applied to y_true and y_pred alike.

    Returns a tuple (y_true, y_pred, num_joints) where the first two are the
    masked tensors and num_joints is the number of valid entries summed over
    the last two axes, with a lower bound of 1.
    """
    valid_mask = K.cast(K.greater(y_true, 0.), 'float32')
    valid_count = K.sum(valid_mask, axis=(-1, -2))
    # Lower bound of 1 so that the callers' division by this count stays
    # defined when no entry is valid.
    num_joints = K.clip(valid_count, 1, None)
    return valid_mask * y_true, valid_mask * y_pred, num_joints


def elasticnet_loss_on_valid_joints(y_true, y_pred):
    """Return the sum of the L1 and L2 errors restricted to valid entries.

    Entries whose y_true value is not greater than zero are zeroed in both
    tensors by `_reset_invalid_joints`. Each error term is summed over the
    last two axes and divided by the (lower-bounded) count of valid entries.
    """
    target, estimate, valid_count = _reset_invalid_joints(y_true, y_pred)
    residual = estimate - target
    abs_term = K.sum(K.abs(residual), axis=(-1, -2)) / valid_count
    sq_term = K.sum(K.square(residual), axis=(-1, -2)) / valid_count
    return abs_term + sq_term


def elasticnet_bincross_loss_on_valid_joints(y_true, y_pred):
    """Return L1 + L2 + 0.01 * binary cross-entropy on valid entries only.

    An entry is valid when its y_true value is strictly greater than zero.
    The per-entry loss is selected with `tf.where` instead of multiplying by
    a mask, then summed over the last two axes and divided by the number of
    valid entries (lower-bounded by 1).
    """
    is_valid = K.greater(y_true, 0.)
    valid_count = K.clip(
        K.sum(K.cast(is_valid, 'float32'), axis=(-1, -2)), 1, None)

    residual = y_pred - y_true
    per_entry = (K.abs(residual) + K.square(residual)
                 + 0.01 * K.binary_crossentropy(y_true, y_pred))
    # Value used in place of the loss wherever the entry is invalid.
    placeholder = 0. * y_pred

    selected = tf.where(is_valid, per_entry, placeholder)
    return K.sum(selected, axis=(-1, -2)) / valid_count


def l1_loss_on_valid_joints(y_true, y_pred):
    """Return the absolute error averaged over valid entries.

    Invalid entries (y_true not greater than zero) are zeroed in both tensors
    first; the absolute differences are then summed over the last two axes
    and divided by the (lower-bounded) count of valid entries.
    """
    target, estimate, valid_count = _reset_invalid_joints(y_true, y_pred)
    abs_error = K.abs(estimate - target)
    return K.sum(abs_error, axis=(-1, -2)) / valid_count


def l2_loss_on_valid_joints(y_true, y_pred):
    """Return the squared error averaged over valid entries.

    Invalid entries (y_true not greater than zero) are zeroed in both tensors
    first; the squared differences are then summed over the last two axes
    and divided by the (lower-bounded) count of valid entries.
    """
    target, estimate, valid_count = _reset_invalid_joints(y_true, y_pred)
    sq_error = K.square(estimate - target)
    return K.sum(sq_error, axis=(-1, -2)) / valid_count


def pose_regression_loss(pose_loss, visibility_weight):
    """Build a loss function combining a pose term and a visibility term.

    pose_loss selects the pose error: 'l1l2', 'l1', 'l2' or 'l1l2bincross'.
    visibility_weight scales the visibility binary cross-entropy before it is
    added to the pose error.

    The returned function takes (y_true, y_pred). Along the last axis, every
    channel except the final one is treated as pose and the final channel as
    visibility. When y_true has 4 dimensions, both terms are additionally
    averaged over their last remaining axis.
    NOTE(review): the expected layouts look like (batch, joints, dim + 1)
    and (batch, frames, joints, dim + 1) -- confirm against the model.

    Raises ValueError if pose_loss is not one of the supported options.
    """
    # Resolve the option while building the loss, so that an unsupported name
    # fails here instead of later, when the loss is first evaluated.
    if pose_loss == 'l1l2':
        pose_fn = elasticnet_loss_on_valid_joints
    elif pose_loss == 'l1':
        pose_fn = l1_loss_on_valid_joints
    elif pose_loss == 'l2':
        pose_fn = l2_loss_on_valid_joints
    elif pose_loss == 'l1l2bincross':
        pose_fn = elasticnet_bincross_loss_on_valid_joints
    else:
        raise ValueError('Invalid pose_loss option ({})'.format(pose_loss))

    def _pose_regression_loss(y_true, y_pred):
        # The model was time-distributed, so there is one additional
        # dimension.
        video_clip = K.ndim(y_true) == 4

        # Ellipsis indexing splits the last axis the same way for both the
        # 3-D and the 4-D layouts.
        p_true = y_true[..., :-1]
        p_pred = y_pred[..., :-1]
        v_true = y_true[..., -1]
        v_pred = y_pred[..., -1]

        ploss = pose_fn(p_true, p_pred)
        vloss = binary_crossentropy(v_true, v_pred)

        if video_clip:
            # If time-distributed, average the error on video frames.
            vloss = K.mean(vloss, axis=-1)
            ploss = K.mean(ploss, axis=-1)

        return ploss + visibility_weight * vloss

    return _pose_regression_loss