#!/usr/bin/python
# -*- coding: utf-8 -*-

"""
graph_reader.py: Reads graph datasets.

Usage:
    python graph_reader.py --dataset <name> [--dir <data_dir>] [--subdir <subdir>]
    (--subdir is required for the 'gwhist' and 'qm9' datasets)
"""

import numpy as np
import networkx as nx
import random
import argparse

from rdkit import Chem
from rdkit.Chem import ChemicalFeatures
from rdkit import RDConfig

import os
from os import listdir
from os.path import isfile, join
import xml.etree.ElementTree as ET

__author__ = "Pau Riba, Anjan Dutta"
__email__ = "priba@cvc.uab.cat, adutta@cvc.uab.cat"

random.seed(2)
np.random.seed(2)


def load_dataset(directory, dataset, subdir='01_Keypoint'):

    if dataset == 'enzymes':

        file_path = join(directory, dataset)
        files = [f for f in listdir(file_path) if isfile(join(file_path, f))]

        classes = []
        graphs = []
        for i in range(len(files)):
            g, c = create_graph_enzymes(join(directory, dataset, files[i]))
            graphs += [g]
            classes += [c]

        train_graphs, train_classes, valid_graphs, valid_classes, test_graphs, test_classes = \
            divide_datasets(graphs, classes)

    elif dataset == 'mutag':

        file_path = join(directory, dataset)
        files = [f for f in listdir(file_path) if isfile(join(file_path, f))]

        classes = []
        graphs = []
        for i in range(len(files)):
            g, c = create_graph_mutag(join(directory, dataset, files[i]))
            graphs += [g]
            classes += [c]

        train_graphs, train_classes, valid_graphs, valid_classes, test_graphs, test_classes = \
            divide_datasets(graphs, classes)

    elif dataset in ('MUTAG', 'ENZYMES', 'NCI1', 'NCI109', 'DD'):

        label_file = dataset + '.label'
        list_file = dataset + '.list'

        label_file_path = join(directory, dataset, label_file)
        list_file_path = join(directory, dataset, list_file)

        with open(label_file_path, 'r') as f:
            l = f.read()
            classes = [int(s) for s in l.split() if s.isdigit()]

        with open(list_file_path, 'r') as f:
            files = f.read().splitlines()

        graphs = load_graphml(join(directory, dataset), files)

        train_graphs, train_classes, valid_graphs, valid_classes, test_graphs, test_classes = \
            divide_datasets(graphs, classes)

    elif dataset == 'gwhist':

        train_classes, train_files = read_2cols_set_files(join(directory, 'Set/Train.txt'))
        test_classes, test_files = read_2cols_set_files(join(directory, 'Set/Test.txt'))
        valid_classes, valid_files = read_2cols_set_files(join(directory, 'Set/Valid.txt'))

        train_classes, valid_classes, test_classes = \
            create_numeric_classes(train_classes, valid_classes, test_classes)

        data_dir = join(directory, 'Data/Word_Graphs/01_Skew', subdir)

        train_graphs = load_gwhist(data_dir, train_files)
        valid_graphs = load_gwhist(data_dir, valid_files)
        test_graphs = load_gwhist(data_dir, test_files)

    elif dataset == 'qm9':

        file_path = join(directory, dataset, subdir)
        files = [f for f in listdir(file_path) if isfile(join(file_path, f))]

        data_dir = join(directory, dataset, subdir)

        graphs, labels = load_qm9(data_dir, files)

        # Split into train, valid and test sets (10k validation, 10k test, rest train)
        idx = np.random.permutation(len(labels))

        valid_graphs = [graphs[i] for i in idx[0:10000]]
        valid_classes = [labels[i] for i in idx[0:10000]]
        test_graphs = [graphs[i] for i in idx[10000:20000]]
        test_classes = [labels[i] for i in idx[10000:20000]]
        train_graphs = [graphs[i] for i in idx[20000:]]
        train_classes = [labels[i] for i in idx[20000:]]

    return train_graphs, train_classes, valid_graphs, valid_classes, test_graphs, test_classes
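
# Minimal usage sketch for calling load_dataset directly (the data directory
# below is illustrative; it must contain the dataset in the layout expected above):
#
#   train_g, train_c, valid_g, valid_c, test_g, test_c = \
#       load_dataset('../data/', 'MUTAG')
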
def create_numeric_classes(train_classes, valid_classes, test_classes):

    # Map each original class label to its index in the sorted list of unique labels
    classes = train_classes + valid_classes + test_classes
    uniq_classes = sorted(list(set(classes)))

    train_classes_ = [0] * len(train_classes)
    valid_classes_ = [0] * len(valid_classes)
    test_classes_ = [0] * len(test_classes)

    for ix in range(len(uniq_classes)):
        idx = [i for i, c in enumerate(train_classes) if c == uniq_classes[ix]]
        for i in idx:
            train_classes_[i] = ix
        idx = [i for i, c in enumerate(valid_classes) if c == uniq_classes[ix]]
        for i in idx:
            valid_classes_[i] = ix
        idx = [i for i, c in enumerate(test_classes) if c == uniq_classes[ix]]
        for i in idx:
            test_classes_[i] = ix

    return train_classes_, valid_classes_, test_classes_


def load_gwhist(data_dir, files):

    graphs = []
    for i in range(len(files)):
        g = create_graph_gwhist(join(data_dir, files[i]))
        graphs += [g]
    return graphs


def load_graphml(data_dir, files):

    graphs = []
    for i in range(len(files)):
        g = nx.read_graphml(join(data_dir, files[i]))
        graphs += [g]
    return graphs


def load_qm9(data_dir, files):

    graphs = []
    labels = []
    for i in range(len(files)):
        g, l = xyz_graph_reader(join(data_dir, files[i]))
        graphs += [g]
        labels.append(l)
    return graphs, labels


def read_2cols_set_files(file):

    with open(file, 'r') as fp:
        lines = fp.read().splitlines()

    classes = []
    files = []
    for line in lines:
        c, f = line.split(' ')[:2]
        classes += [c]
        files += [f + '.gxl']

    return classes, files


def read_cxl(file):

    files = []
    classes = []
    tree_cxl = ET.parse(file)
    root_cxl = tree_cxl.getroot()
    for f in root_cxl.iter('print'):
        files += [f.get('file')]
        classes += [f.get('class')]
    return classes, files


def divide_datasets(graphs, classes):

    uc = list(set(classes))

    tr_idx = []
    va_idx = []
    te_idx = []

    # Stratified 80/10/10 split per class
    for c in uc:
        idx = [i for i, x in enumerate(classes) if x == c]
        tr_idx += sorted(np.random.choice(idx, int(0.8 * len(idx)), replace=False))
        va_idx += sorted(np.random.choice([x for x in idx if x not in tr_idx],
                                          int(0.1 * len(idx)), replace=False))
        te_idx += sorted(np.random.choice([x for x in idx if x not in tr_idx and x not in va_idx],
                                          int(0.1 * len(idx)), replace=False))

    train_graphs = [graphs[i] for i in tr_idx]
    valid_graphs = [graphs[i] for i in va_idx]
    test_graphs = [graphs[i] for i in te_idx]

    train_classes = [classes[i] for i in tr_idx]
    valid_classes = [classes[i] for i in va_idx]
    test_classes = [classes[i] for i in te_idx]

    return train_graphs, train_classes, valid_graphs, valid_classes, test_graphs, test_classes


def create_graph_enzymes(file):

    with open(file, 'r') as fp:
        lines = fp.read().splitlines()

    # Get the indices of the vertex labels, adjacency list and class
    idx_vertex = lines.index("#v - vertex labels")
    idx_adj_list = lines.index("#a - adjacency list")
    idx_clss = lines.index("#c - Class")

    # Node labels
    vl = [int(ivl) for ivl in lines[idx_vertex + 1:idx_adj_list]]

    adj_list = lines[idx_adj_list + 1:idx_clss]
    sources = list(range(1, len(adj_list) + 1))

    # Prepend the source node id to every adjacency row so that isolated
    # nodes are kept when parsing the adjacency list
    for i in range(len(adj_list)):
        if not adj_list[i]:
            adj_list[i] = str(sources[i])
        else:
            adj_list[i] = str(sources[i]) + "," + adj_list[i]

    g = nx.parse_adjlist(adj_list, nodetype=int, delimiter=",")

    for i in range(1, g.number_of_nodes() + 1):
        g.nodes[i]['labels'] = np.array(vl[i - 1])

    c = int(lines[idx_clss + 1])

    return g, c


def create_graph_mutag(file):

    with open(file, 'r') as fp:
        lines = fp.read().splitlines()

    # Get the indices of the vertex labels, edge labels and class
    idx_vertex = lines.index("#v - vertex labels")
    idx_edge = lines.index("#e - edge labels")
    idx_clss = lines.index("#c - Class")

    # Node labels
    vl = [int(ivl) for ivl in lines[idx_vertex + 1:idx_edge]]

    edge_list = lines[idx_edge + 1:idx_clss]

    g = nx.parse_edgelist(edge_list, nodetype=int, data=(('weight', float),), delimiter=",")

    for i in range(1, g.number_of_nodes() + 1):
        g.nodes[i]['labels'] = np.array(vl[i - 1])

    c = int(lines[idx_clss + 1])

    return g, c
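
# Illustrative sketch (values made up) of the plain-text layout that
# create_graph_enzymes / create_graph_mutag expect, inferred from the section
# markers they search for; mutag files carry an "#e - edge labels" section with
# "source,target,weight" rows, enzymes files an "#a - adjacency list" section
# of comma-separated neighbour ids instead:
#
#   #v - vertex labels
#   3
#   1
#   #e - edge labels
#   1,2,0
#   #c - Class
#   1
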
def create_graph_gwhist(file):

    tree_gxl = ET.parse(file)
    root_gxl = tree_gxl.getroot()

    vl = []
    for node in root_gxl.iter('node'):
        for attr in node.iter('attr'):
            if attr.get('name') == 'x':
                x = float(attr.find('float').text)
            elif attr.get('name') == 'y':
                y = float(attr.find('float').text)
        vl += [[x, y]]

    g = nx.Graph()
    for edge in root_gxl.iter('edge'):
        s = edge.get('from')
        s = int(s.split('_')[1])
        t = edge.get('to')
        t = int(t.split('_')[1])
        g.add_edge(s, t)

    # Add isolated nodes and attach the (x, y) coordinates as node labels
    for i in range(g.number_of_nodes()):
        if i not in g:
            g.add_node(i)
        g.nodes[i]['labels'] = np.array(vl[i])

    return g


def isfloat(value):
    try:
        float(value)
        return True
    except ValueError:
        return False


def create_graph_grec(file):

    tree_gxl = ET.parse(file)
    root_gxl = tree_gxl.getroot()

    vl = []
    switch_node = {'circle': 0, 'corner': 1, 'endpoint': 2, 'intersection': 3}
    switch_edge = {'arc': 0, 'arcarc': 1, 'line': 2, 'linearc': 3}

    for node in root_gxl.iter('node'):
        for attr in node.iter('attr'):
            if attr.get('name') == 'x':
                x = int(attr.find('Integer').text)
            elif attr.get('name') == 'y':
                y = int(attr.find('Integer').text)
            elif attr.get('name') == 'type':
                t = switch_node.get(attr.find('String').text, 4)
        vl += [[x, y, t]]

    g = nx.Graph()
    for edge in root_gxl.iter('edge'):
        s = int(edge.get('from'))
        t = int(edge.get('to'))
        for attr in edge.iter('attr'):
            if attr.get('name') == 'frequency':
                f = attr.find('Integer').text
            elif attr.get('name') == 'type0':
                ta = switch_edge.get(attr.find('String').text)
            elif attr.get('name') == 'angle0':
                a = attr.find('String').text
                if isfloat(a):
                    a = float(a)
                else:
                    # An unparsable angle string is replaced with 0.0
                    a = 0.0
        g.add_edge(s, t, frequency=f, type=ta, angle=a)

    # Add isolated nodes and attach the (x, y, type) attributes as node labels
    for i in range(len(vl)):
        if i not in g:
            g.add_node(i)
        g.nodes[i]['labels'] = np.array(vl[i][:3])

    return g


def create_graph_letter(file):

    tree_gxl = ET.parse(file)
    root_gxl = tree_gxl.getroot()

    vl = []
    for node in root_gxl.iter('node'):
        for attr in node.iter('attr'):
            if attr.get('name') == 'x':
                x = float(attr.find('float').text)
            elif attr.get('name') == 'y':
                y = float(attr.find('float').text)
        vl += [[x, y]]

    g = nx.Graph()
    for edge in root_gxl.iter('edge'):
        s = int(edge.get('from').split('_')[1])
        t = int(edge.get('to').split('_')[1])
        g.add_edge(s, t)

    # Add isolated nodes and attach the (x, y) coordinates as node labels
    for i in range(len(vl)):
        if i not in g:
            g.add_node(i)
        g.nodes[i]['labels'] = np.array(vl[i][:2])

    return g


# Initialization of a graph for the QM9 dataset: the properties line is parsed
# into graph attributes, and the 12 regression targets are returned as a list.
def init_graph(prop):

    prop = prop.split()
    g_tag = prop[0]
    g_index = int(prop[1])
    g_A = float(prop[2])
    g_B = float(prop[3])
    g_C = float(prop[4])
    g_mu = float(prop[5])
    g_alpha = float(prop[6])
    g_homo = float(prop[7])
    g_lumo = float(prop[8])
    g_gap = float(prop[9])
    g_r2 = float(prop[10])
    g_zpve = float(prop[11])
    g_U0 = float(prop[12])
    g_U = float(prop[13])
    g_H = float(prop[14])
    g_G = float(prop[15])
    g_Cv = float(prop[16])

    labels = [g_mu, g_alpha, g_homo, g_lumo, g_gap, g_r2, g_zpve, g_U0, g_U, g_H, g_G, g_Cv]

    return nx.Graph(tag=g_tag, index=g_index, A=g_A, B=g_B, C=g_C, mu=g_mu, alpha=g_alpha,
                    homo=g_homo, lumo=g_lumo, gap=g_gap, r2=g_r2, zpve=g_zpve, U0=g_U0,
                    U=g_U, H=g_H, G=g_G, Cv=g_Cv), labels
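
# Layout of a QM9 .xyz file as assumed by xyz_graph_reader below (derived from
# what the reader parses; field order follows init_graph):
#
#   line 1:          number of atoms, na
#   line 2:          properties line: tag index A B C mu alpha homo lumo gap r2 zpve U0 U H G Cv
#   lines 3..na+2:   per-atom rows: element symbol, x, y, z, partial charge
#   next line:       vibrational frequencies (skipped)
#   next line:       SMILES (only the first token is used)
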
# XYZ file reader for the QM9 dataset
def xyz_graph_reader(graph_file):

    with open(graph_file, 'r') as f:
        # Number of atoms
        na = int(f.readline())

        # Graph properties
        properties = f.readline()
        g, l = init_graph(properties)

        atom_properties = []
        # Atom properties; QM9 writes exponents as '*^', which is rewritten to 'e'
        for i in range(na):
            a_properties = f.readline()
            a_properties = a_properties.replace('.*^', 'e').replace('*^', 'e')
            a_properties = a_properties.split()
            atom_properties.append(a_properties)

        # Frequencies
        f.readline()

        # SMILES
        smiles = f.readline()
        smiles = smiles.split()
        smiles = smiles[0]

        m = Chem.MolFromSmiles(smiles)
        m = Chem.AddHs(m)

        fdef_name = os.path.join(RDConfig.RDDataDir, 'BaseFeatures.fdef')
        factory = ChemicalFeatures.BuildFeatureFactory(fdef_name)
        feats = factory.GetFeaturesForMol(m)

        # Create nodes
        for i in range(0, m.GetNumAtoms()):
            atom_i = m.GetAtomWithIdx(i)

            g.add_node(i, a_type=atom_i.GetSymbol(), a_num=atom_i.GetAtomicNum(), acceptor=0,
                       donor=0, aromatic=atom_i.GetIsAromatic(),
                       hybridization=atom_i.GetHybridization(), num_h=atom_i.GetTotalNumHs(),
                       coord=np.array(atom_properties[i][1:4]).astype(float),
                       pc=float(atom_properties[i][4]))

        # Mark hydrogen-bond donors and acceptors
        for i in range(0, len(feats)):
            if feats[i].GetFamily() == 'Donor':
                node_list = feats[i].GetAtomIds()
                for n in node_list:
                    g.nodes[n]['donor'] = 1
            elif feats[i].GetFamily() == 'Acceptor':
                node_list = feats[i].GetAtomIds()
                for n in node_list:
                    g.nodes[n]['acceptor'] = 1

        # Read edges: every atom pair gets an edge, labelled with the bond type
        # (None for unbonded pairs) and the Euclidean distance between the atoms
        for i in range(0, m.GetNumAtoms()):
            for j in range(0, m.GetNumAtoms()):
                e_ij = m.GetBondBetweenAtoms(i, j)
                if e_ij is not None:
                    g.add_edge(i, j, b_type=e_ij.GetBondType(),
                               distance=np.linalg.norm(g.nodes[i]['coord'] - g.nodes[j]['coord']))
                else:
                    # Unbonded
                    g.add_edge(i, j, b_type=None,
                               distance=np.linalg.norm(g.nodes[i]['coord'] - g.nodes[j]['coord']))

    return g, l


if __name__ == '__main__':

    g1 = create_graph_grec('/home/adutta/Workspace/Datasets/Graphs/GREC/data/image1_1.gxl')
    g2 = create_graph_letter('/home/adutta/Workspace/Datasets/STDGraphs/Letter/LOW/AP1_0000.gxl')

    # Parse options for reading the specified directory, dataset and subdirectory
    parser = argparse.ArgumentParser(description='Read the specified directory, dataset and subdirectory.')
    # Optional arguments
    parser.add_argument('--dataset', default=['GREC'], nargs=1, help='Specify a dataset.')
    parser.add_argument('--dir', nargs=1, help='Specify the data directory.', default=['../data/'])
    parser.add_argument('--subdir', nargs=1, help='Specify a subdirectory.')

    args = parser.parse_args()

    directory = args.dir[0]
    dataset = args.dataset[0]

    if dataset in ('gwhist', 'qm9'):
        if args.subdir is None:
            print('Error: no subdirectory specified for the dataset')
            quit()
        else:
            subdir = args.subdir[0]
    else:
        subdir = None

    print(dataset)

    train_graphs, train_classes, valid_graphs, valid_classes, test_graphs, test_classes = \
        load_dataset(directory, dataset, subdir)

    print(len(train_graphs), len(valid_graphs), len(test_graphs))
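
# Example invocations (illustrative; the data directory and the QM9 subdirectory
# depend on the local dataset layout):
#
#   python graph_reader.py --dataset MUTAG --dir ../data/
#   python graph_reader.py --dataset gwhist --dir ../data/ --subdir 01_Keypoint
#   python graph_reader.py --dataset qm9 --dir ../data/ --subdir <qm9_subdir>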