import shutil
import tempfile
from unittest import TestCase, mock

from lineflow import download
from lineflow.datasets.imdb import Imdb, _imdb_loader, get_imdb


class ImdbTestCase(TestCase):
    """Tests for the IMDB dataset helpers in ``lineflow.datasets.imdb``.

    The lineflow cache root is redirected to a throwaway temporary
    directory for the lifetime of the class and restored afterwards.
    """

    @classmethod
    def setUpClass(cls):
        """Remember the current cache root and switch to a temp directory."""
        cls.default_cache_root = download.get_cache_root()
        cls.temp_dir = tempfile.mkdtemp()
        download.set_cache_root(cls.temp_dir)

    @classmethod
    def tearDownClass(cls):
        """Restore the original cache root and delete the temp directory."""
        try:
            download.set_cache_root(cls.default_cache_root)
        finally:
            # Remove the temp directory even if restoring the cache root
            # fails, so a failed teardown does not leave files behind.
            shutil.rmtree(cls.temp_dir)

    def test_get_imdb(self):
        """``get_imdb`` returns 'train' and 'test' entries of 25,000 items."""
        raw = get_imdb()
        # train
        self.assertIn('train', raw)
        self.assertEqual(len(raw['train']), 25_000)
        # test
        self.assertIn('test', raw)
        self.assertEqual(len(raw['test']), 25_000)

    def test_get_imdb_twice(self):
        """A second ``get_imdb`` call loads via pickle and never dumps."""
        # First call runs unpatched; only the second call is observed.
        get_imdb()
        with mock.patch('lineflow.datasets.imdb.pickle', autospec=True) as mock_pickle:
            get_imdb()
        mock_pickle.dump.assert_not_called()
        self.assertEqual(mock_pickle.load.call_count, 1)

    @mock.patch('lineflow.datasets.imdb.io.open', autospec=True)
    def test_imdb_loader(self, mock_open):
        """``_imdb_loader`` yields label 0 for 'pos' and 1 for 'neg'."""
        # ``io.open`` is patched for the duration of this test; only the
        # returned label is checked, the returned text is ignored.
        for path in ('pos', 'neg'):
            with self.subTest(path=path):
                _text, label = _imdb_loader(path)
                self.assertEqual(label, 0 if path == 'pos' else 1)

    def test_loads_each_split(self):
        """``Imdb`` exposes 25,000 items for both 'train' and 'test'."""
        train = Imdb(split='train')
        self.assertEqual(len(train), 25_000)
        test = Imdb(split='test')
        self.assertEqual(len(test), 25_000)

    def test_raises_value_error_with_invalid_split(self):
        """An unknown split name makes ``Imdb`` raise ``ValueError``."""
        with self.assertRaises(ValueError):
            Imdb(split='invalid_split')