# # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information # regarding copyright ownership. The ASF licenses this file # to you under the Apache License, Version 2.0 (the # "License"); you may not use this file except in compliance # with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. # import unittest import numpy as np import tensorflow as tf from unittest.mock import Mock from unittest.mock import patch from kglib.kgcn.models.embedding import embed_type, embed_attribute from kglib.utils.test.utils import get_call_args class TestTypeEmbedding(unittest.TestCase): def setUp(self): tf.enable_eager_execution() def test_embedding_output_shape_as_expected(self): features = np.array([[1, 0, 0.7], [1, 2, 0.7], [0, 1, 0.5]], dtype=np.float32) type_embedding_dim = 5 output = embed_type(features, 3, type_embedding_dim) np.testing.assert_array_equal(np.array([3, 6]), output.shape) class TestAttributeEmbedding(unittest.TestCase): def setUp(self): tf.enable_eager_execution() def test_embedding_is_typewise(self): features = np.array([[1, 0, 0.7], [1, 2, 0.7], [0, 1, 0.5]]) mock_instance = Mock(return_value=tf.convert_to_tensor(np.array([[1, 0.7], [1, 0.7], [0, 0.5]]))) mock = Mock(return_value=mock_instance) patcher = patch('kglib.kgcn.models.embedding.TypewiseEncoder', spec=True, new=mock) mock_class = patcher.start() attr_encoders = Mock() attr_embedding_dim = Mock() embed_attribute(features, attr_encoders, attr_embedding_dim) # Function under test mock_class.assert_called_once_with(attr_encoders, attr_embedding_dim) call_args = get_call_args(mock_instance) np.testing.assert_array_equal([[np.array([[0, 0.7], [2, 0.7], [1, 0.5]])]], call_args) patcher.stop() if __name__ == "__main__": unittest.main()