#Copyright 2018 Google LLC
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
#    https://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.


"""Inference step for joint node classification and link prediction models."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function


from models.base_models import NodeEdgeModel
from models.edge_models import Gae
from models.node_models import Gat
from models.node_models import Gcn
import tensorflow as tf
from utils.model_utils import compute_adj
from utils.model_utils import gat_module
from utils.model_utils import gcn_module
from utils.model_utils import get_sp_topk
from utils.model_utils import mask_edges


class GaeGat(NodeEdgeModel):
  """GAE for link prediction and GAT for node classification."""

  def __init__(self, config):
    """Initializes GaeGat model.

    Args:
      config: configuration object forwarded to NodeEdgeModel, Gae and Gat.
    """
    super(GaeGat, self).__init__(config)
    self.edge_model = Gae(config)
    self.node_model = Gat(config)

  def compute_inference(self, node_features_in, sp_adj_matrix, is_training):
    """Predicts edges with GAE, then classifies nodes with GAT on them.

    Args:
      node_features_in: node features, fed to both the edge and node models.
      sp_adj_matrix: sparse adjacency matrix of the input graph.
      is_training: training flag forwarded to both sub-models.

    Returns:
      Tuple (logits, adj_matrix_pred): the node-model output and the edge
      model's adjacency prediction before the sigmoid applied below.
    """
    adj_matrix_pred = self.edge_model.compute_inference(
        node_features_in, sp_adj_matrix, is_training)
    self.adj_matrix_pred = adj_matrix_pred
    # Edge-selection mask computed by get_sp_topk from the predictions and the
    # input adjacency (presumably the top-k edges; see that helper).
    adj_mask = get_sp_topk(adj_matrix_pred, sp_adj_matrix, self.nb_nodes,
                           self.topk)
    self.adj_mask = adj_mask
    # Squash predictions with a sigmoid before masking so the surviving
    # entries act as edge weights for the node model.
    masked_adj_matrix_pred = mask_edges(tf.nn.sigmoid(adj_matrix_pred),
                                        adj_mask)
    # dense_to_sparse drops zero entries, so only the unmasked edges remain.
    sp_adj_pred = tf.contrib.layers.dense_to_sparse(masked_adj_matrix_pred)
    logits = self.node_model.compute_inference(node_features_in, sp_adj_pred,
                                               is_training)
    return logits, adj_matrix_pred


class GaeGcn(NodeEdgeModel):
  """GAE for link prediction and GCN for node classification."""

  def __init__(self, config):
    """Initializes GaeGcn model.

    Args:
      config: configuration object forwarded to NodeEdgeModel, Gae and Gcn.
    """
    super(GaeGcn, self).__init__(config)
    self.edge_model = Gae(config)
    self.node_model = Gcn(config)

  def compute_inference(self, node_features_in, sp_adj_matrix, is_training):
    """Predicts edges with GAE, then classifies nodes with GCN on them.

    The selected edge scores go through a leaky ReLU and a sparse softmax
    over the innermost dimension. The result serves as the weighted
    adjacency for the GCN.

    Args:
      node_features_in: node features, fed to both the edge and node models.
      sp_adj_matrix: sparse adjacency matrix of the input graph.
      is_training: training flag forwarded to both sub-models.

    Returns:
      Tuple (logits, adj_matrix_pred): the node-model output and the edge
      model's unmodified adjacency prediction.
    """
    adj_matrix_pred = self.edge_model.compute_inference(
        node_features_in, sp_adj_matrix, is_training)
    self.adj_matrix_pred = adj_matrix_pred
    # Edge-selection mask computed by get_sp_topk (presumably top-k edges).
    adj_mask = get_sp_topk(adj_matrix_pred, sp_adj_matrix, self.nb_nodes,
                           self.topk)
    # Zero out unselected entries; dense_to_sparse then drops every zero.
    # NOTE(review): a selected edge whose leaky_relu score is exactly 0 is
    # dropped as well.
    sp_adj_pred = tf.contrib.layers.dense_to_sparse(
        tf.multiply(adj_mask, tf.nn.leaky_relu(adj_matrix_pred)))
    # Normalize the kept scores per row into attention-like weights.
    sp_adj_pred = tf.sparse_softmax(sp_adj_pred)
    logits = self.node_model.compute_inference(node_features_in, sp_adj_pred,
                                               is_training)
    return logits, adj_matrix_pred


############################ EXPERIMENTAL MODELS #############################


class GatGraphite(NodeEdgeModel):
  """GAT for link prediction and GAT for node classification.

  The edge model's latent node embeddings feed the node model. By default
  they are concatenated with the input node features first.
  """

  # When True, the node model consumes [edge latents, input node features]
  # (a sparse concat, so node_features_in is expected to be sparse). When
  # False, it consumes the dense edge latents alone. Subclasses may override.
  concat_features = True

  def compute_inference(self, node_features_in, sp_adj_matrix, is_training):
    """Runs the GAT edge model, then a GAT node model on its latents.

    Args:
      node_features_in: node features fed to the edge model (and
        concatenated into the node-model input when concat_features is set).
      sp_adj_matrix: sparse adjacency matrix of the input graph, used by both
        GAT modules.
      is_training: training flag forwarded to the modules.

    Returns:
      Tuple (logits, adj_matrix_pred): node-model output and the adjacency
      prediction from compute_adj.
    """
    with tf.variable_scope('edge-model'):
      z_latent = gat_module(
          node_features_in,
          sp_adj_matrix,
          self.n_hidden_edge,
          self.n_att_edge,
          self.p_drop_edge,
          is_training,
          self.input_dim,
          self.sparse_features,
          average_last=False)
      adj_matrix_pred = compute_adj(z_latent, self.att_mechanism,
                                    self.p_drop_edge, is_training)
      self.adj_matrix_pred = adj_matrix_pred
    with tf.variable_scope('node-model'):
      # Width of the edge-model latents. With average_last=False the last
      # layer's heads are not averaged, which gives hidden_size * num_heads.
      latent_dim = self.n_hidden_edge[-1] * self.n_att_edge[-1]
      if self.concat_features:
        z_latent = tf.sparse_concat(
            axis=1,
            sp_inputs=[
                tf.contrib.layers.dense_to_sparse(z_latent), node_features_in
            ],
        )
        sparse_features = True
        input_dim = latent_dim + self.input_dim
      else:
        sparse_features = False
        input_dim = latent_dim
      logits = gat_module(
          z_latent,
          sp_adj_matrix,
          self.n_hidden_node,
          self.n_att_node,
          self.p_drop_node,
          is_training,
          input_dim,
          sparse_features=sparse_features,
          average_last=False)

    return logits, adj_matrix_pred


class GaeGatConcat(NodeEdgeModel):
  """GCN-based edge model for link prediction and GAT for node classification.

  The edge model's latent node embeddings are concatenated with the input
  node features to form the node-model input.
  """

  def __init__(self, config):
    """Initializes GaeGatConcat model.

    Args:
      config: configuration object forwarded to NodeEdgeModel, Gae and Gat.
    """
    super(GaeGatConcat, self).__init__(config)
    # NOTE(review): these sub-models are not used by compute_inference below,
    # which builds its layers directly via gcn_module / gat_module. They are
    # kept so the attributes remain available to any external callers.
    self.edge_model = Gae(config)
    self.node_model = Gat(config)

  def compute_inference(self, node_features_in, sp_adj_matrix, is_training):
    """Runs a GCN edge model, then a GAT node model on [latents, features].

    Args:
      node_features_in: node features. They are concatenated with a sparse
        concat, so they are expected to be sparse.
      sp_adj_matrix: sparse adjacency matrix of the input graph.
      is_training: training flag forwarded to the modules.

    Returns:
      Tuple (logits, adj_matrix_pred): node-model output and the adjacency
      prediction from compute_adj.
    """
    with tf.variable_scope('edge-model'):
      z_latent = gcn_module(node_features_in, sp_adj_matrix, self.n_hidden_edge,
                            self.p_drop_edge, is_training, self.input_dim,
                            self.sparse_features)
      adj_matrix_pred = compute_adj(z_latent, self.att_mechanism,
                                    self.p_drop_edge, is_training)
      self.adj_matrix_pred = adj_matrix_pred
    with tf.variable_scope('node-model'):
      z_latent = tf.sparse_concat(
          axis=1,
          sp_inputs=[
              tf.contrib.layers.dense_to_sparse(z_latent), node_features_in
          ])
      sparse_features = True
      input_dim = self.n_hidden_edge[-1] + self.input_dim
      # Same sparsity pattern as the input adjacency but with all weights set
      # to 1, so the GAT attends over the unweighted input graph.
      sp_adj_train = tf.SparseTensor(
          indices=sp_adj_matrix.indices,
          values=tf.ones_like(sp_adj_matrix.values),
          dense_shape=sp_adj_matrix.dense_shape)
      logits = gat_module(
          z_latent,
          sp_adj_train,
          self.n_hidden_node,
          self.n_att_node,
          self.p_drop_node,
          is_training,
          input_dim,
          sparse_features=sparse_features,
          average_last=True)
    return logits, adj_matrix_pred


class GaeGcnConcat(NodeEdgeModel):
  """EGCN for link prediction and GCN for node classification.

  A GCN edge model produces latent node embeddings. These are concatenated
  with the raw node features and fed to a second GCN for node labels.
  """

  def compute_inference(self, node_features_in, sp_adj_matrix, is_training):
    """Builds the edge GCN, then the node GCN on [latents, features].

    Args:
      node_features_in: node features. They go through a sparse concat, so
        they are expected to be sparse.
      sp_adj_matrix: sparse adjacency matrix of the input graph, shared by
        both GCNs.
      is_training: training flag forwarded to the modules.

    Returns:
      Tuple (logits, adj_matrix_pred): node GCN output and the adjacency
      prediction from compute_adj.
    """
    with tf.variable_scope('edge-model'):
      edge_latent = gcn_module(
          node_features_in,
          sp_adj_matrix,
          self.n_hidden_edge,
          self.p_drop_edge,
          is_training,
          self.input_dim,
          self.sparse_features)
      adj_matrix_pred = compute_adj(
          edge_latent, self.att_mechanism, self.p_drop_edge, is_training)
      self.adj_matrix_pred = adj_matrix_pred

    with tf.variable_scope('node-model'):
      sp_edge_latent = tf.contrib.layers.dense_to_sparse(edge_latent)
      node_inputs = tf.sparse_concat(
          axis=1, sp_inputs=[sp_edge_latent, node_features_in])
      # Latent width (last edge hidden size) plus the raw feature width.
      concat_dim = self.n_hidden_edge[-1] + self.input_dim
      logits = gcn_module(
          node_inputs,
          sp_adj_matrix,
          self.n_hidden_node,
          self.p_drop_node,
          is_training,
          concat_dim,
          sparse_features=True)

    return logits, adj_matrix_pred


class Gcat(NodeEdgeModel):
  """1 iteration Graph Convolution Attention Model."""

  def __init__(self, config):
    """Initializes GCAT model.

    Args:
      config: configuration object forwarded to NodeEdgeModel, Gae and Gcn.
    """
    super(Gcat, self).__init__(config)
    self.edge_model = Gae(config)
    self.node_model = Gcn(config)

  def compute_inference(self, node_features_in, sp_adj_matrix, is_training):
    """Forward pass for GCAT model.

    The GAE edge scores are restricted to the input graph's edges, passed
    through a leaky ReLU, and softmax-normalized over the innermost
    dimension. The result serves as the attention-weighted adjacency for the
    GCN node model.

    Args:
      node_features_in: node features, fed to both sub-models.
      sp_adj_matrix: sparse adjacency matrix of the input graph
        (tf.SparseTensor; its indices/values/dense_shape are read here).
      is_training: training flag forwarded to both sub-models.

    Returns:
      Tuple (logits, adj_matrix_pred): node-model output and the edge model's
      unmodified adjacency prediction.
    """
    adj_matrix_pred = self.edge_model.compute_inference(
        node_features_in, sp_adj_matrix, is_training)
    # Record the prediction on the model, as the sibling models do.
    self.adj_matrix_pred = adj_matrix_pred
    # Binary mask with the input graph's sparsity pattern.
    sp_adj_mask = tf.SparseTensor(
        indices=sp_adj_matrix.indices,
        values=tf.ones_like(sp_adj_matrix.values),
        dense_shape=sp_adj_matrix.dense_shape)
    # Sparse * dense multiplies component-wise and keeps only the sparse
    # indices, i.e. it gathers the predicted scores on the input edges.
    sp_adj_att = sp_adj_mask * adj_matrix_pred
    sp_adj_att = tf.SparseTensor(
        indices=sp_adj_att.indices,
        values=tf.nn.leaky_relu(sp_adj_att.values),
        dense_shape=sp_adj_att.dense_shape)
    sp_adj_att = tf.sparse_softmax(sp_adj_att)
    logits = self.node_model.compute_inference(node_features_in, sp_adj_att,
                                               is_training)
    return logits, adj_matrix_pred