import tensorflow as tf
# NOTE: this module targets the TensorFlow 1.x API (tf.train.AdamOptimizer,
# tf.count_nonzero, sessions). Under TensorFlow 2.x, import it via
# `import tensorflow.compat.v1 as tf` and call `tf.disable_v2_behavior()`.

# Small constant guarding against division by zero in precision/recall/F1.
TINY = 1e-5
class BaseGAttN:
    @staticmethod
    def loss(logits, labels, nb_classes, class_weights):
        """Class-weighted softmax cross-entropy over integer labels."""
        # Per-sample weight: the class weight of each sample's true class.
        sample_wts = tf.reduce_sum(tf.multiply(tf.one_hot(labels, nb_classes), class_weights), axis=-1)
        xentropy = tf.multiply(tf.nn.sparse_softmax_cross_entropy_with_logits(
                labels=labels, logits=logits), sample_wts)
        return tf.reduce_mean(xentropy, name='xentropy_mean')
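
    # Usage sketch (illustrative, not part of the original API): with
    # nb_classes = 3 and class_weights = tf.constant([0.5, 1.0, 2.0]), a
    # sample labelled 2 contributes four times the loss of one labelled 0:
    #   xent = BaseGAttN.loss(log_resh, lab_resh, 3, tf.constant([0.5, 1.0, 2.0]))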

    @staticmethod
    def training(loss, lr, l2_coef):
        # weight decay: exclude biases and normalisation parameters, matching
        # on the variable's base name (v.name looks like 'scope/bias:0')
        variables = tf.trainable_variables()
        lossL2 = tf.add_n([tf.nn.l2_loss(v) for v in variables
                           if v.name.split('/')[-1].split(':')[0]
                           not in ['bias', 'gamma', 'b', 'g', 'beta']]) * l2_coef

        # optimizer
        opt = tf.train.AdamOptimizer(learning_rate=lr)

        # training op
        train_op = opt.minimize(loss + lossL2)

        return train_op
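
    # Wiring sketch (hypothetical; the learning rate and L2 coefficient are
    # illustrative values, not prescribed by this class):
    #   loss = BaseGAttN.masked_softmax_cross_entropy(log_resh, lab_resh, msk_resh)
    #   train_op = BaseGAttN.training(loss, lr=0.005, l2_coef=0.0005)
    #   sess.run(train_op, feed_dict=...)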

    @staticmethod
    def preshape(logits, labels, nb_classes):
        """Flatten logits to [-1, nb_classes] and labels to [-1]."""
        new_sh_lab = [-1]
        new_sh_log = [-1, nb_classes]
        log_resh = tf.reshape(logits, new_sh_log)
        lab_resh = tf.reshape(labels, new_sh_lab)
        return log_resh, lab_resh

    @staticmethod
    def confmat(logits, labels):
        """Confusion matrix of argmax predictions against integer labels."""
        preds = tf.argmax(logits, axis=1)
        return tf.confusion_matrix(labels, preds)

    ##########################
    # Adapted from tkipf/gcn #
    ##########################

    @staticmethod
    def masked_softmax_cross_entropy(logits, labels, mask):
        """Softmax cross-entropy loss with masking."""
        loss = tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=labels)
        mask = tf.cast(mask, dtype=tf.float32)
        # Normalising by the mask's mean makes the result equal to
        # sum(loss * mask) / sum(mask), i.e. the mean loss over masked-in
        # nodes, independent of how many nodes the mask keeps.
        mask /= tf.reduce_mean(mask)
        loss *= mask
        return tf.reduce_mean(loss)

    @staticmethod
    def masked_sigmoid_cross_entropy(logits, labels, mask):
        """Sigmoid cross-entropy loss with masking (multi-label)."""
        labels = tf.cast(labels, dtype=tf.float32)
        loss = tf.nn.sigmoid_cross_entropy_with_logits(logits=logits, labels=labels)
        # Average over the label dimension to get one loss value per node.
        loss = tf.reduce_mean(loss, axis=1)
        mask = tf.cast(mask, dtype=tf.float32)
        mask /= tf.reduce_mean(mask)
        loss *= mask
        return tf.reduce_mean(loss)
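
    # Note: masked_softmax_cross_entropy suits single-label (one-hot) targets,
    # while masked_sigmoid_cross_entropy suits multi-label (multi-hot) targets.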

    @staticmethod
    def masked_accuracy(logits, labels, mask):
        """Accuracy with masking."""
        correct_prediction = tf.equal(tf.argmax(logits, 1), tf.argmax(labels, 1))
        accuracy_all = tf.cast(correct_prediction, tf.float32)
        mask = tf.cast(mask, dtype=tf.float32)
        mask /= tf.reduce_mean(mask)
        accuracy_all *= mask
        return tf.reduce_mean(accuracy_all)

    @staticmethod
    def micro_f1(logits, labels, mask):
        """Micro-averaged F1 with masking (multi-label, sigmoid outputs)."""
        predicted = tf.round(tf.nn.sigmoid(logits))

        # Use integers to avoid any nasty FP behaviour
        predicted = tf.cast(predicted, dtype=tf.int32)
        labels = tf.cast(labels, dtype=tf.int32)
        mask = tf.cast(mask, dtype=tf.int32)

        # expand the mask so that broadcasting works ([nb_nodes, 1])
        mask = tf.expand_dims(mask, -1)
        
        # Count true positives, true negatives, false positives and false negatives.
        tp = tf.cast(tf.count_nonzero(predicted * labels * mask), tf.float32)
        tn = tf.cast(tf.count_nonzero((predicted - 1) * (labels - 1) * mask), tf.float32)
        fp = tf.cast(tf.count_nonzero(predicted * (labels - 1) * mask), tf.float32)
        fn = tf.cast(tf.count_nonzero((predicted - 1) * labels * mask), tf.float32)
        
        # Calculate precision, recall and F1 score; TINY guards against
        # division by zero when there are no positives.
        precision = tp / (tp + fp + TINY)
        recall = tp / (tp + fn + TINY)
        fmeasure = (2 * precision * recall) / (precision + recall + TINY)
        return fmeasure
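
    # Micro-averaging pools TP/FP/FN over every (node, class) cell before
    # computing one precision/recall pair, so frequent classes dominate the
    # score. Worked example: predicted = [[1,0],[1,1]], labels = [[1,0],[0,1]],
    # mask = [1,1] gives tp=2, fp=1, fn=0, so precision ~ 2/3, recall ~ 1,
    # and F1 ~ 0.8.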


    @staticmethod
    def micro_f1_onelabel(logits, labels, mask):
        """Micro-averaged F1 with masking (single-label, softmax outputs)."""
        predicted = tf.argmax(tf.nn.softmax(logits), axis=1)
        pre = tf.one_hot(predicted, depth=tf.shape(labels)[-1], dtype=tf.int32)
        labels = tf.cast(labels, dtype=tf.int32)
        # expand the mask so that broadcasting works ([nb_nodes, 1])
        mask = tf.expand_dims(tf.cast(mask, dtype=tf.int32), -1)
        # Count true positives, true negatives, false positives and false negatives.
        tp = tf.cast(tf.count_nonzero(pre * labels * mask), tf.float32)
        tn = tf.cast(tf.count_nonzero((pre - 1) * (labels - 1) * mask), tf.float32)
        fp = tf.cast(tf.count_nonzero(pre * (labels - 1) * mask), tf.float32)
        fn = tf.cast(tf.count_nonzero((pre - 1) * labels * mask), tf.float32)
        
        # Calculate precision, recall and F1 score.
        precision = tp / (tp + fp + TINY)
        recall = tp / (tp + fn + TINY)
        fmeasure = (2 * precision * recall) / (precision + recall + TINY)
        return fmeasure


    @staticmethod
    def macro_f1(logits, labels, mask):
        """Macro-averaged F1 with masking (multi-label, sigmoid outputs)."""
        predicted = tf.round(tf.nn.sigmoid(logits))

        # Use integers to avoid any nasty FP behaviour
        predicted = tf.cast(predicted, dtype=tf.int32)
        labels = tf.cast(labels, dtype=tf.int32)
        mask = tf.cast(mask, dtype=tf.int32)

        # expand the mask so that broadcasting works ([nb_nodes, 1])
        mask = tf.expand_dims(mask, -1)
        
        # Count true positives, true negatives, false positives and false negatives.
        tp = tf.cast(tf.count_nonzero(predicted * labels * mask, axis=0), tf.float32)
        tn = tf.cast(tf.count_nonzero((predicted - 1) * (labels - 1) * mask, axis=0), tf.float32)
        fp = tf.cast(tf.count_nonzero(predicted * (labels - 1) * mask, axis=0), tf.float32)
        fn = tf.cast(tf.count_nonzero((predicted - 1) * labels * mask, axis=0), tf.float32)
        
        # Macro-average precision and recall over classes, then combine into a
        # single F1 (note: this is the F1 of macro-averaged precision/recall,
        # not the mean of per-class F1 scores).
        precision = tf.reduce_mean(tf.divide(tp, tp + fp + TINY))
        recall = tf.reduce_mean(tf.divide(tp, tp + fn + TINY))
        fmeasure = (2 * precision * recall) / (precision + recall + TINY)
        return fmeasure


    @staticmethod
    def macro_f1_onelabel(logits, labels, mask):
        """Macro-averaged F1 with masking (single-label, softmax outputs)."""
        predicted = tf.argmax(tf.nn.softmax(logits), axis=1)
        pre = tf.one_hot(predicted, depth=tf.shape(labels)[-1], dtype=tf.int32)
        labels = tf.cast(labels, dtype=tf.int32)
        # expand the mask so that broadcasting works ([nb_nodes, 1])
        mask = tf.expand_dims(tf.cast(mask, dtype=tf.int32), -1)
        # Count true positives, true negatives, false positives and false negatives.
        tp = tf.cast(tf.count_nonzero(pre * labels * mask, axis=0), tf.float32)
        tn = tf.cast(tf.count_nonzero((pre - 1) * (labels - 1) * mask, axis=0), tf.float32)
        fp = tf.cast(tf.count_nonzero(pre * (labels - 1) * mask, axis=0), tf.float32)
        fn = tf.cast(tf.count_nonzero((pre - 1) * labels * mask, axis=0), tf.float32)
        
        # Macro-average precision and recall over classes, then combine into a
        # single F1 (same caveat as macro_f1 above).
        precision = tf.reduce_mean(tf.divide(tp, tp + fp + TINY))
        recall = tf.reduce_mean(tf.divide(tp, tp + fn + TINY))
        fmeasure = (2 * precision * recall) / (precision + recall + TINY)
        return fmeasure
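

# ---------------------------------------------------------------------------
# Minimal smoke test; a sketch only. Shapes and values below are made up, and
# the session API assumes TensorFlow 1.x (or tf.compat.v1), matching the code
# above.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    import numpy as np

    logits = tf.constant(np.random.randn(4, 3), dtype=tf.float32)
    labels = tf.constant(np.eye(3)[[0, 1, 2, 0]], dtype=tf.float32)  # one-hot
    mask = tf.constant([1.0, 1.0, 1.0, 0.0])  # hold the last node out

    acc = BaseGAttN.masked_accuracy(logits, labels, mask)
    f1 = BaseGAttN.micro_f1(logits, labels, mask)

    with tf.Session() as sess:
        print(sess.run([acc, f1]))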