"""Graph Attention Mechanism.""" import glob import json import torch import random import numpy as np import pandas as pd import networkx as nx from tqdm import tqdm, trange from utils import calculate_reward, calculate_predictive_loss from utils import read_node_labels, create_logs, create_features, create_batches class StepNetworkLayer(torch.nn.Module): """ Step Network Layer Class for selecting next node to move. """ def __init__(self, args, identifiers): """ Initializing the layer. :param args: Arguments object. :param identifiers: Node type -- id hash map. """ super(StepNetworkLayer, self).__init__() self.identifiers = identifiers self.args = args self.setup_attention() self.create_parameters() def setup_attention(self): """ Initial attention generation with uniform attention scores. """ self.attention = torch.ones((len(self.identifiers)))/len(self.identifiers) def create_parameters(self): """ Creating trainable weights and initlaizing them. """ self.theta_step_1 = torch.nn.Parameter(torch.Tensor(len(self.identifiers), self.args.step_dimensions)) self.theta_step_2 = torch.nn.Parameter(torch.Tensor(len(self.identifiers), self.args.step_dimensions)) self.theta_step_3 = torch.nn.Parameter(torch.Tensor(2*self.args.step_dimensions, self.args.combined_dimensions)) torch.nn.init.uniform_(self.theta_step_1, -1, 1) torch.nn.init.uniform_(self.theta_step_2, -1, 1) torch.nn.init.uniform_(self.theta_step_3, -1, 1) def sample_node_label(self, orig_neighbors, graph, features): """ Sampling a label from the neighbourhood. :param original_neighbors: Neighbours of the source node. :param graph: NetworkX graph. :param features: Node feature matrix. :return label: Label sampled from the neighbourhood with attention. """ neighbor_vector = torch.tensor([1.0 if n in orig_neighbors else 0.0 for n in graph.nodes()]) neighbor_features = torch.mm(neighbor_vector.view(1, -1), features) attention_spread = self.attention * neighbor_features normalized_attention_spread = attention_spread / attention_spread.sum() normalized_attention_spread = normalized_attention_spread.detach().numpy().reshape(-1) label = np.random.choice(np.arange(len(self.identifiers)), p=normalized_attention_spread) return label def make_step(self, node, graph, features, labels, inverse_labels): """ :param node: Source node for step. :param graph: NetworkX graph. :param features: Feature matrix. :param labels: Node labels hash table. :param inverse_labels: Inverse node label hash table. """ orig_neighbors = set(nx.neighbors(graph, node)) label = self.sample_node_label(orig_neighbors, graph, features) labels = list(set(orig_neighbors).intersection(set(inverse_labels[str(label)]))) new_node = random.choice(labels) new_node_attributes = torch.zeros((len(self.identifiers), 1)) new_node_attributes[label, 0] = 1.0 attention_score = self.attention[label] return new_node_attributes, new_node, attention_score def forward(self, data, graph, features, node): """ Making a forward propagation step. :param data: Data hash table. :param graph: NetworkX graph object. :param features: Feature matrix of the graph. :param node: Base node where the step is taken from. :return state: State vector. :return node: New node to move to. :return attention_score: Attention score of chosen node. """ feature_row, node, attention_score = self.make_step(node, graph, features, data["labels"], data["inverse_labels"]) hidden_attention = torch.mm(self.attention.view(1, -1), self.theta_step_1) hidden_node = torch.mm(torch.t(feature_row), self.theta_step_2) combined_hidden_representation = torch.cat((hidden_attention, hidden_node), dim=1) state = torch.mm(combined_hidden_representation, self.theta_step_3) state = state.view(1, 1, self.args.combined_dimensions) return state, node, attention_score class DownStreamNetworkLayer(torch.nn.Module): """ Neural network layer for attention update and node label assignment. """ def __init__(self, args, target_number, identifiers): """ :param args: :param target_number: :param identifiers: """ super(DownStreamNetworkLayer, self).__init__() self.args = args self.target_number = target_number self.identifiers = identifiers self.create_parameters() def create_parameters(self): """ Defining and initializing the classification and attention update weights. """ self.theta_classification = torch.nn.Parameter(torch.Tensor(self.args.combined_dimensions, self.target_number)) self.theta_rank = torch.nn.Parameter(torch.Tensor(self.args.combined_dimensions, len(self.identifiers))) torch.nn.init.xavier_normal_(self.theta_classification) torch.nn.init.xavier_normal_(self.theta_rank) def forward(self, hidden_state): """ Making a forward propagation pass with the input from the LSTM layer. :param hidden_state: LSTM state used for labeling and attention update. """ predictions = torch.mm(hidden_state.view(1, -1), self.theta_classification) attention = torch.mm(hidden_state.view(1, -1), self.theta_rank) attention = torch.nn.functional.softmax(attention, dim=1) return predictions, attention class GAM(torch.nn.Module): """ Graph Attention Machine class. """ def __init__(self, args): """ Initializing the machine. :param args: Arguments object. """ super(GAM, self).__init__() self.args = args self.identifiers, self.class_number = read_node_labels(self.args) self.step_block = StepNetworkLayer(self.args, self.identifiers) self.recurrent_block = torch.nn.LSTM(self.args.combined_dimensions, self.args.combined_dimensions, 1) self.down_block = DownStreamNetworkLayer(self.args, self.class_number, self.identifiers) self.reset_attention() def reset_attention(self): """ Resetting the attention and hidden states. """ self.step_block.attention = torch.ones((len(self.identifiers)))/len(self.identifiers) self.lstm_h_0 = torch.randn(1, 1, self.args.combined_dimensions) self.lstm_c_0 = torch.randn(1, 1, self.args.combined_dimensions) def forward(self, data, graph, features, node): """ Doing a forward pass on a graph from a given node. :param data: Data dictionary. :param graph: NetworkX graph. :param features: Feature tensor. :param node: Source node identifier. :return label_predictions: Label prediction. :return node: New node to move to. :return attention_score: Attention score on selected node. """ self.state, node, attention_score = self.step_block(data, graph, features, node) lstm_output, (self.h0, self.c0) = self.recurrent_block(self.state, (self.lstm_h_0, self.lstm_c_0)) label_predictions, attention = self.down_block(lstm_output) self.step_block.attention = attention.view(-1) label_predictions = torch.nn.functional.log_softmax(label_predictions, dim=1) return label_predictions, node, attention_score class GAMTrainer(object): """ Object to train a GAM model. """ def __init__(self, args): self.args = args self.model = GAM(args) self.setup_graphs() self.logs = create_logs(self.args) def setup_graphs(self): """ Listing the training and testing graphs in the source folders. """ self.training_graphs = glob.glob(self.args.train_graph_folder + "*.json") self.test_graphs = glob.glob(self.args.test_graph_folder + "*.json") def process_graph(self, graph_path, batch_loss): """ Reading a graph and doing a forward pass on a graph with a time budget. :param graph_path: Location of the graph to process. :param batch_loss: Loss on the graphs processed so far in the batch. :return batch_loss: Incremented loss on the current batch being processed. """ data = json.load(open(graph_path)) graph, features = create_features(data, self.model.identifiers) node = random.choice(list(graph.nodes())) attention_loss = 0 for t in range(self.args.time): predictions, node, attention_score = self.model(data, graph, features, node) target, prediction_loss = calculate_predictive_loss(data, predictions) batch_loss = batch_loss + prediction_loss if t < self.args.time-2: attention_loss += (self.args.gamma**(self.args.time-t))*torch.log(attention_score) reward = calculate_reward(target, predictions) batch_loss = batch_loss-reward*attention_loss self.model.reset_attention() return batch_loss def process_batch(self, batch): """ Forward and backward propagation on a batch of graphs. :param batch: Batch if graphs. :return loss_value: Value of loss on batch. """ self.optimizer.zero_grad() batch_loss = 0 for graph_path in batch: batch_loss = self.process_graph(graph_path, batch_loss) batch_loss.backward(retain_graph=True) self.optimizer.step() loss_value = batch_loss.item() self.optimizer.zero_grad() return loss_value def update_log(self): """ Adding the end of epoch loss to the log. """ average_loss = self.epoch_loss/self.nodes_processed self.logs["losses"].append(average_loss) def fit(self): """ Fitting a model on the training dataset. """ print("\nTraining started.\n") self.model.train() self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.args.learning_rate, weight_decay=self.args.weight_decay) self.optimizer.zero_grad() epoch_range = trange(self.args.epochs, desc="Epoch: ", leave=True) for _ in epoch_range: random.shuffle(self.training_graphs) batches = create_batches(self.training_graphs, self.args.batch_size) self.epoch_loss = 0 self.nodes_processed = 0 batch_range = trange(len(batches)) for batch in batch_range: self.epoch_loss = self.epoch_loss + self.process_batch(batches[batch]) self.nodes_processed = self.nodes_processed + len(batches[batch]) loss_score = round(self.epoch_loss/self.nodes_processed, 4) batch_range.set_description("(Loss=%g)" % loss_score) self.update_log() def score_graph(self, data, prediction): """ Scoring the prediction on the graph. :param data: Data hash table of graph. :param prediction: Label prediction. """ target = data["target"] is_it_right = (target == prediction) self.predictions.append(is_it_right) def score(self): """ Scoring the test set graphs. """ print("\n") print("\nScoring the test set.\n") self.model.eval() self.predictions = [] for data in tqdm(self.test_graphs): data = json.load(open(data)) graph, features = create_features(data, self.model.identifiers) node_predictions = [] for _ in range(self.args.repetitions): node = random.choice(list(graph.nodes())) for _ in range(self.args.time): prediction, node, _ = self.model(data, graph, features, node) node_predictions.append(np.argmax(prediction.detach())) self.model.reset_attention() prediction = max(set(node_predictions), key=node_predictions.count) self.score_graph(data, prediction) self.accuracy = float(np.mean(self.predictions)) print("\nThe test set accuracy is: "+str(round(self.accuracy, 4))+".\n") def save_predictions_and_logs(self): """ Saving the predictions as a csv file and logs as a JSON. """ self.logs["test_accuracy"] = self.accuracy with open(self.args.log_path, "w") as f: json.dump(self.logs, f) cols = ["graph_id", "predicted_label"] predictions = [[self.test_graphs[i], self.predictions[i].item()] for i in range(len(self.test_graphs))] self.output_data = pd.DataFrame(predictions, columns=cols) self.output_data.to_csv(self.args.prediction_path, index=None)