import sys
import random
import os.path as osp
import shutil

import torch
from torch_geometric.nn import RENet
from torch_geometric.datasets.icews import EventDataset
from torch_geometric.data import DataLoader


class MyTestEventDataset(EventDataset):
    """Small randomly generated event dataset for exercising ``RENet``.

    Args:
        root: Directory handed to ``EventDataset`` as the dataset root.
        seq_len: History length forwarded to ``RENet.pre_transform``.
    """

    def __init__(self, root, seq_len):
        history_transform = RENet.pre_transform(seq_len)
        super().__init__(root, pre_transform=history_transform)
        # Read back the (data, slices) pair stored at the path that
        # ``process`` saves to.
        saved = torch.load(self.processed_paths[0])
        self.data, self.slices = saved

    @property
    def num_nodes(self):
        """Number of distinct entities (upper bound for sub/obj ids)."""
        return 16

    @property
    def num_rels(self):
        """Number of distinct relation types (upper bound for rel ids)."""
        return 8

    @property
    def processed_file_names(self):
        """Name of the file holding the processed dataset."""
        return 'data.pt'

    def _download(self):
        """Do nothing: this test dataset has no files to fetch."""

    def process_events(self):
        """Return a random ``(64, 4)`` long tensor of events.

        The columns are ``sub``, ``rel``, ``obj`` and ``t``. ``t`` holds the
        values 0..7, each repeated for 8 consecutive rows.
        """
        num_steps, per_step = 8, 8
        shape = (num_steps * per_step, )

        # Keep the draw order (sub, rel, obj) so seeded runs are unchanged.
        sub = torch.randint(self.num_nodes, shape, dtype=torch.long)
        rel = torch.randint(self.num_rels, shape, dtype=torch.long)
        obj = torch.randint(self.num_nodes, shape, dtype=torch.long)

        steps = torch.arange(num_steps, dtype=torch.long)
        t = steps.view(-1, 1).repeat(1, per_step).view(-1)

        columns = [sub, rel, obj, t]
        return torch.stack(columns, dim=1)

    def process(self):
        """Collate the samples from the parent ``process`` and save them."""
        samples = super().process()
        collated = self.collate(samples)
        torch.save(collated, self.processed_paths[0])


def test_re_net():
    """Smoke-test ``RENet`` on a small synthetic event dataset.

    Checks the ranking metrics returned by ``RENet.test`` on random logits,
    then runs the model forward over every batch of the dataset and feeds
    both outputs through ``RENet.test``.

    The dataset is written to a randomly named scratch directory that is
    removed when the test finishes, whether it passes or fails.
    """
    root = osp.join('/', 'tmp', str(random.randrange(sys.maxsize)))
    try:
        dataset = MyTestEventDataset(root, seq_len=4)
        loader = DataLoader(dataset, 2, follow_batch=['h_sub', 'h_obj'])

        model = RENet(
            dataset.num_nodes, dataset.num_rels, hidden_channels=16, seq_len=4)

        logits = torch.randn(6, 6)
        y = torch.tensor([0, 1, 2, 3, 4, 5])

        mrr, hits1, hits3, hits10 = model.test(logits, y)
        # With 6 candidates the worst possible reciprocal rank is 1/6, so
        # the mean must exceed 0.15; every target is inside the top 10.
        assert 0.15 < mrr <= 1
        assert hits1 <= hits3 and hits3 <= hits10 and hits10 == 1

        for data in loader:
            log_prob_obj, log_prob_sub = model(data)
            model.test(log_prob_obj, data.obj)
            model.test(log_prob_sub, data.sub)
    finally:
        # Always clean up, so a failing assertion or model error does not
        # leave the scratch directory behind. ``ignore_errors`` keeps a
        # failure that happened before the directory was created from being
        # masked by a FileNotFoundError raised here.
        shutil.rmtree(root, ignore_errors=True)