import sys import random import os.path as osp import shutil import torch from torch_geometric.nn import RENet from torch_geometric.datasets.icews import EventDataset from torch_geometric.data import DataLoader class MyTestEventDataset(EventDataset): def __init__(self, root, seq_len): super(MyTestEventDataset, self).__init__( root, pre_transform=RENet.pre_transform(seq_len)) self.data, self.slices = torch.load(self.processed_paths[0]) @property def num_nodes(self): return 16 @property def num_rels(self): return 8 @property def processed_file_names(self): return 'data.pt' def _download(self): pass def process_events(self): sub = torch.randint(self.num_nodes, (64, ), dtype=torch.long) rel = torch.randint(self.num_rels, (64, ), dtype=torch.long) obj = torch.randint(self.num_nodes, (64, ), dtype=torch.long) t = torch.arange(8, dtype=torch.long).view(-1, 1).repeat(1, 8).view(-1) return torch.stack([sub, rel, obj, t], dim=1) def process(self): data_list = super(MyTestEventDataset, self).process() torch.save(self.collate(data_list), self.processed_paths[0]) def test_re_net(): root = osp.join('/', 'tmp', str(random.randrange(sys.maxsize))) dataset = MyTestEventDataset(root, seq_len=4) loader = DataLoader(dataset, 2, follow_batch=['h_sub', 'h_obj']) model = RENet( dataset.num_nodes, dataset.num_rels, hidden_channels=16, seq_len=4) logits = torch.randn(6, 6) y = torch.tensor([0, 1, 2, 3, 4, 5]) mrr, hits1, hits3, hits10 = model.test(logits, y) assert 0.15 < mrr <= 1 assert hits1 <= hits3 and hits3 <= hits10 and hits10 == 1 for data in loader: log_prob_obj, log_prob_sub = model(data) model.test(log_prob_obj, data.obj) model.test(log_prob_sub, data.sub) shutil.rmtree(root)