#!/usr/bin/python # encoding: utf-8 import sys import unittest import torch from torch.autograd import Variable import collections origin_path = sys.path sys.path.append("..") import utils sys.path = origin_path def equal(a, b): if isinstance(a, torch.Tensor): return a.equal(b) elif isinstance(a, str): return a == b elif isinstance(a, collections.Iterable): res = True for (x, y) in zip(a, b): res = res & equal(x, y) return res else: return a == b class utilsTestCase(unittest.TestCase): def checkConverter(self): encoder = utils.strLabelConverter('abcdefghijklmnopqrstuvwxyz') # Encode # trivial mode result = encoder.encode('efa') target = (torch.IntTensor([5, 6, 1]), torch.IntTensor([3])) self.assertTrue(equal(result, target)) # batch mode result = encoder.encode(['efa', 'ab']) target = (torch.IntTensor([5, 6, 1, 1, 2]), torch.IntTensor([3, 2])) self.assertTrue(equal(result, target)) # Decode # trivial mode result = encoder.decode( torch.IntTensor([5, 6, 1]), torch.IntTensor([3])) target = 'efa' self.assertTrue(equal(result, target)) # replicate mode result = encoder.decode( torch.IntTensor([5, 5, 0, 1]), torch.IntTensor([4])) target = 'ea' self.assertTrue(equal(result, target)) # raise AssertionError def f(): result = encoder.decode( torch.IntTensor([5, 5, 0, 1]), torch.IntTensor([3])) self.assertRaises(AssertionError, f) # batch mode result = encoder.decode( torch.IntTensor([5, 6, 1, 1, 2]), torch.IntTensor([3, 2])) target = ['efa', 'ab'] self.assertTrue(equal(result, target)) def checkOneHot(self): v = torch.LongTensor([1, 2, 1, 2, 0]) v_length = torch.LongTensor([2, 3]) v_onehot = utils.oneHot(v, v_length, 4) target = torch.FloatTensor([[[0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 0]], [[0, 1, 0, 0], [0, 0, 1, 0], [1, 0, 0, 0]]]) assert target.equal(v_onehot) def checkAverager(self): acc = utils.averager() acc.add(Variable(torch.Tensor([1, 2]))) acc.add(Variable(torch.Tensor([[5, 6]]))) assert acc.val() == 3.5 acc = utils.averager() acc.add(torch.Tensor([1, 2])) acc.add(torch.Tensor([[5, 6]])) assert acc.val() == 3.5 def checkAssureRatio(self): img = torch.Tensor([[1], [3]]).view(1, 1, 2, 1) img = Variable(img) img = utils.assureRatio(img) assert torch.Size([1, 1, 2, 2]) == img.size() def _suite(): suite = unittest.TestSuite() suite.addTest(utilsTestCase("checkConverter")) suite.addTest(utilsTestCase("checkOneHot")) suite.addTest(utilsTestCase("checkAverager")) suite.addTest(utilsTestCase("checkAssureRatio")) return suite if __name__ == "__main__": suite = _suite() runner = unittest.TextTestRunner() runner.run(suite)