"""Tests for ``lineflow.datasets.text_classification``."""
import shutil
import string
import sys
import tempfile
from unittest import TestCase, mock

from lineflow import download
from lineflow.datasets import text_classification
from lineflow.datasets.text_classification import (
    get_text_classification_dataset, urls)


class TextClassificationTestCaseBase(TestCase):
    """Shared fixtures and check helpers for the text classification datasets.

    ``names`` and ``sizes`` are parallel lists: ``sizes[i]`` is the expected
    ``(train_size, test_size)`` for the dataset ``names[i]``. ``names`` follows
    the key order of ``urls``.

    The non-``test_`` methods are helpers that concrete subclasses call with
    their own dataset name and sizes.
    """

    names = list(urls.keys())
    # Expected (train, test) split lengths, in the same order as ``names``.
    sizes = [(120_000, 7_600),
             (450_000, 60_000),
             (560_000, 70_000),
             (560_000, 38_000),
             (650_000, 50_000),
             (1_400_000, 60_000),
             (3_600_000, 400_000),
             (3_000_000, 650_000)]

    @classmethod
    def setUpClass(cls):
        # Point the download cache at a throw-away directory for this class,
        # remembering the previous root so it can be restored afterwards.
        cls.default_cache_root = download.get_cache_root()
        cls.temp_dir = tempfile.mkdtemp()
        download.set_cache_root(cls.temp_dir)
        # The patch target resolves to the shared ``sys`` module object, so
        # ``sys.maxsize`` is replaced process-wide until the patcher stops.
        # NOTE(review): presumably the dataset module derives a limit from
        # sys.maxsize that needs this larger value — confirm.
        cls.patcher = mock.patch(
            'lineflow.datasets.text_classification.sys.maxsize',
            int(sys.float_info.max))
        cls.patcher.start()

    @classmethod
    def tearDownClass(cls):
        # Always stop the patcher, even if restoring the cache root or
        # removing the temporary directory fails. Otherwise the patched
        # sys.maxsize would leak into every test that runs afterwards.
        try:
            download.set_cache_root(cls.default_cache_root)
            shutil.rmtree(cls.temp_dir)
        finally:
            cls.patcher.stop()

    def name2class(self, name):
        """Return the dataset class in ``text_classification`` for ``name``.

        The snake_case name is converted to CamelCase, e.g.
        ``'ag_news'`` -> ``'AgNews'``.
        """
        return getattr(text_classification,
                       string.capwords(name, '_').replace('_', ''))

    def get_text_classification_dataset(self, name, train_size, test_size):
        """Check that the loader returns 'train'/'test' splits of the given sizes."""
        # Inside the method body this name refers to the module-level function
        # imported above, not to this method.
        raw = get_text_classification_dataset(name)
        # train
        self.assertIn('train', raw)
        self.assertEqual(len(raw['train']), train_size)
        # test
        self.assertIn('test', raw)
        self.assertEqual(len(raw['test']), test_size)

    def get_text_classification_dataset_twice(self, name):
        """Check that a second load reads the pickled cache instead of rewriting it."""
        # The first call populates the cache in the temporary cache root.
        get_text_classification_dataset(name)
        with mock.patch('lineflow.datasets.text_classification.pickle',
                        autospec=True) as mock_pickle:
            get_text_classification_dataset(name)
        mock_pickle.dump.assert_not_called()
        self.assertEqual(mock_pickle.load.call_count, 1)

    def loads_each_split(self, name, train_size, test_size):
        """Check that the dataset class loads each split with the expected length."""
        train = self.name2class(name)(split='train')
        self.assertEqual(len(train), train_size)
        test = self.name2class(name)(split='test')
        self.assertEqual(len(test), test_size)

    def test_raises_key_error_with_invalid_name(self):
        with self.assertRaises(KeyError):
            get_text_classification_dataset('invalid_name')

    def raises_value_error_with_invalid_split(self, name):
        """Check that an unknown split name raises ``ValueError``."""
        with self.assertRaises(ValueError):
            self.name2class(name)(split='invalid_split')


class AgNewsTestCase(TextClassificationTestCaseBase):
    """Run the shared checks against the dataset at index 0 of ``names``.

    Subclasses reuse every test below and only override ``setUp`` to select a
    different index.
    """

    def setUp(self):
        super().setUp()
        self.name = self.names[0]
        self.size = self.sizes[0]

    def test_get_text_classification_dataset(self):
        self.get_text_classification_dataset(self.name, *self.size)

    def test_get_text_classification_dataset_twice(self):
        self.get_text_classification_dataset_twice(self.name)

    def test_loads_each_split(self):
        self.loads_each_split(self.name, *self.size)

    def test_raises_value_error_with_invalid_split(self):
        self.raises_value_error_with_invalid_split(self.name)


class SogouNewsTestCase(AgNewsTestCase):
    """Run the shared checks against the dataset at index 1 of ``names``."""

    def setUp(self):
        super().setUp()
        self.name = self.names[1]
        self.size = self.sizes[1]


class DbpediaTestCase(AgNewsTestCase):
    """Run the shared checks against the dataset at index 2 of ``names``."""

    def setUp(self):
        super().setUp()
        self.name = self.names[2]
        self.size = self.sizes[2]


class YelpReviewPolarityTestCase(AgNewsTestCase):
    """Run the shared checks against the dataset at index 3 of ``names``."""

    def setUp(self):
        super().setUp()
        self.name = self.names[3]
        self.size = self.sizes[3]


class YelpReviewFullTestCase(AgNewsTestCase):
    """Run the shared checks against the dataset at index 4 of ``names``."""

    def setUp(self):
        super().setUp()
        self.name = self.names[4]
        self.size = self.sizes[4]


class YahooAnswersTestCase(AgNewsTestCase):
    """Run the shared checks against the dataset at index 5 of ``names``."""

    def setUp(self):
        super().setUp()
        self.name = self.names[5]
        self.size = self.sizes[5]


class AmazonReviewPolarityTestCase(AgNewsTestCase):
    """Run the shared checks against the dataset at index 6 of ``names``."""

    def setUp(self):
        super().setUp()
        self.name = self.names[6]
        self.size = self.sizes[6]


class AmazonReviewFullTestCase(AgNewsTestCase):
    """Run the shared checks against the dataset at index 7 of ``names``."""

    def setUp(self):
        super().setUp()
        self.name = self.names[7]
        self.size = self.sizes[7]