"""Smoke tests for ``GNNModel`` trained through a ``fackel.TorchContainer``.

Importing this module has side effects: it loads the GloVe word embeddings
and a small WebQSP unit-test dataset from disk, both of which are shared by
all tests below.
"""
import pytest
import json

import numpy as np
from torch import nn as nn

import fackel

from questionanswering.construction import sentence
from questionanswering import _utils
from questionanswering.models.modules import ConvWordsEncoder
from questionanswering.models.gnn import GNNModel
from questionanswering.models import vectorization as V

# Both locations are built by plain string concatenation onto
# ``_utils.RESOURCES_FOLDER``, exactly as the project lays them out.
_EMBEDDINGS_PATH = _utils.RESOURCES_FOLDER + "../../resources/embeddings/glove/glove.6B.100d.txt"
_DATASET_PATH = (
    _utils.RESOURCES_FOLDER
    + "../data/generated/webqsp.examples.train.silvergraphs.02-12.el.unittests.json"
)

# Embedding matrix and word -> index mapping, extended with the special tokens
# that the vectorization module adds.
wordembeddings, word2idx = V.extend_embeddings_with_special_tokens(
    *_utils.load_word_embeddings(_EMBEDDINGS_PATH)
)

# Explicit encoding so that decoding does not depend on the platform locale.
with open(_DATASET_PATH, encoding="utf-8") as f:
    training_dataset = json.load(f, object_hook=sentence.sentence_object_hook)


def _build_model(**model_kwargs):
    """Return a ``GNNModel`` on top of a word encoder initialised from the loaded embeddings.

    Args:
        **model_kwargs: Extra keyword arguments forwarded to ``GNNModel``
            (e.g. ``hp_dropout`` or ``hp_gated``).
    """
    encoder = ConvWordsEncoder(*wordembeddings.shape)
    encoder.load_word_embeddings_from_numpy(wordembeddings)
    return GNNModel(encoder, **model_kwargs)


def _build_container(net, **container_kwargs):
    """Wrap ``net`` in a ``fackel.TorchContainer`` with the settings shared by all tests.

    Args:
        net: The model to be trained.
        **container_kwargs: Extra keyword arguments forwarded to
            ``fackel.TorchContainer`` (e.g. ``save_to_dir``).
    """
    criterion = nn.MultiMarginLoss(margin=0.5)
    return fackel.TorchContainer(
        torch_model=net,
        batch_size=8,
        max_epochs=5,
        model_checkpoint=False,
        early_stopping=5,
        criterion=criterion,
        init_model_weights=True,
        lr_decay=2,
        **container_kwargs
    )


def _train_on_dataset(container):
    """Encode the shared dataset and run ``container.train`` on it."""
    # Keep only index 0 along the second-to-last axis of the encoded questions.
    # NOTE(review): presumably this selects the first encoding variant of each
    # question - confirm against V.encode_batch_questions.
    train_questions = V.encode_batch_questions(training_dataset, word2idx)[..., 0, :]
    train_graphs = V.encode_batch_graph_structure(training_dataset, word2idx)
    # Every target is class index 0.
    # NOTE(review): presumably the correct graph is always the first candidate
    # in this dataset - confirm.
    targets = np.zeros(len(training_dataset), dtype=np.int32)
    container.train(train=(train_questions, *train_graphs), train_targets=targets)


def test_encode_structure():
    """Encode the graph structure of the dataset and print the first element of the result."""
    train_questions = V.encode_batch_graph_structure(training_dataset, word2idx)
    print(train_questions[0])


def test_load_parameters():
    """Check that the dropout hyper-parameter is intact after a save / reload round trip."""
    net = _build_model(hp_dropout=0.2)
    container = _build_container(net, save_to_dir="../trainedmodels/")
    container.save_model()
    container.reload_from_saved()
    assert container._model._gnn._prop_model._dropout.p == 0.2


def test_ggnn():
    """Train a ``GNNModel`` with its default hyper-parameters on the shared dataset."""
    _train_on_dataset(_build_container(_build_model()))


def test_gnn():
    """Train a ``GNNModel`` with ``hp_gated=False`` on the shared dataset."""
    _train_on_dataset(_build_container(_build_model(hp_gated=False)))


if __name__ == '__main__':
    # Running the file directly executes only the non-gated training test.
    # To run the whole module through pytest instead, use:
    #     pytest.main(['-v', __file__])
    test_gnn()