import os
import shutil
import sys
from io import StringIO
from contextlib import contextmanager

import torch
from torch import Tensor
import numpy as np
import unittest

from neural_pipeline.utils import FileStructManager, CheckpointsManager, dict_recursive_bypass
from neural_pipeline.utils.fsm import FolderRegistrable

from tests.common import UseFileStructure

__all__ = ['UtilsTest', 'FileStructManagerTest', 'CheckpointsManagerTests']


@contextmanager
def captured_output():
    """Temporarily redirect ``sys.stdout`` and ``sys.stderr`` to in-memory buffers.

    Yields:
        The ``(stdout, stderr)`` pair of ``StringIO`` objects that receive the output.

    The original streams are restored on exit, even if the managed body raises.
    """
    new_out, new_err = StringIO(), StringIO()
    old_out, old_err = sys.stdout, sys.stderr
    try:
        sys.stdout, sys.stderr = new_out, new_err
        yield sys.stdout, sys.stderr
    finally:
        sys.stdout, sys.stderr = old_out, old_err


class UtilsTest(unittest.TestCase):
    """Tests for free-standing helpers of ``neural_pipeline.utils``."""

    def test_dict_recursive_bypass(self):
        """Every leaf of a nested dict must be converted by the supplied callable."""
        d = {'data': np.array([1]), 'target': {'a': np.array([1]), 'b': np.array([1])}}
        d = dict_recursive_bypass(d, lambda v: torch.from_numpy(v))

        self.assertIsInstance(d['data'], Tensor)
        self.assertIsInstance(d['target']['a'], Tensor)
        self.assertIsInstance(d['target']['b'], Tensor)


class FileStructManagerTest(UseFileStructure):
    """Tests for ``FileStructManager`` creation and directory registration."""

    class TestObj(FolderRegistrable):
        """Minimal ``FolderRegistrable`` that reports a fixed directory and name."""

        def __init__(self, m: 'FileStructManager', dir: str, name: str):
            super().__init__(m)
            self.dir = dir
            self.name = name

        # NOTE(review): the method name `_get_gir` is kept as-is; it looks like the
        # hook name the FolderRegistrable base expects - confirm against the base class.
        def _get_gir(self) -> str:
            return self.dir

        def _get_name(self) -> str:
            return self.name

    def test_creation(self):
        """Constructing a manager must not fail for a missing, empty or foreign-populated base dir."""
        # Start from a clean state: the assertion below requires that the base
        # directory itself does not exist, so that is the path to remove.
        if os.path.exists(self.base_dir):
            shutil.rmtree(self.base_dir, ignore_errors=True)

        try:
            FileStructManager(base_dir=self.base_dir, is_continue=False)
        except FileStructManager.FSMException as err:
            self.fail("Raise error when base directory exists: [{}]".format(err))

        # Constructing the manager alone is expected not to create the directory.
        self.assertFalse(os.path.exists(self.base_dir))

        try:
            FileStructManager(base_dir=self.base_dir, is_continue=False)
        except FileStructManager.FSMException as err:
            self.fail("Raise error when base directory exists but empty: [{}]".format(err))

        os.makedirs(os.path.join(self.base_dir, 'new_dir'))
        try:
            FileStructManager(base_dir=self.base_dir, is_continue=False)
        except Exception:
            self.fail("Error initialize when exists non-registered folders in base directory")

        shutil.rmtree(self.base_dir, ignore_errors=True)

    def test_object_registration(self):
        """Exercise ``register_dir`` with duplicate dirs/names and with the checks disabled."""
        fsm = FileStructManager(base_dir=self.base_dir, is_continue=False)
        fsm_exist_ok = FileStructManager(base_dir=self.base_dir, is_continue=False, exists_ok=True)
        o = self.TestObj(fsm, 'test_dir', 'test_name')
        fsm.register_dir(o)

        expected_path = os.path.join(self.base_dir, 'test_dir')
        self.assertFalse(os.path.exists(expected_path))
        self.assertEqual(fsm.get_path(o), expected_path)

        # Re-using an already registered directory must be rejected...
        with self.assertRaises(FileStructManager.FSMException):
            fsm.register_dir(self.TestObj(fsm, 'test_dir', 'another_name'))
        # ...unless the directory check is explicitly disabled.
        try:
            fsm.register_dir(self.TestObj(fsm, 'test_dir', 'another_name'), check_dir_registered=False)
            fsm_exist_ok.register_dir(self.TestObj(fsm, 'test_dir', 'another_name'))
            fsm_exist_ok.register_dir(self.TestObj(fsm, 'test_dir', 'another_name2'), check_dir_registered=False)
        except Exception:
            self.fail("Folder registrable test fail when it's disabled")

        # Each `with` passes as soon as either of its two calls raises.
        with self.assertRaises(FileStructManager.FSMException):
            fsm.register_dir(self.TestObj(fsm, 'another_dir', 'test_name'))
            fsm.register_dir(self.TestObj(fsm, 'another_dir', 'another_name'))
        with self.assertRaises(FileStructManager.FSMException):
            fsm_exist_ok.register_dir(self.TestObj(fsm, 'another_dir', 'test_name'))
            fsm_exist_ok.register_dir(self.TestObj(fsm, 'another_dir', 'another_name'))

        # Re-using a name is expected to be accepted once the name check is disabled.
        try:
            fsm.register_dir(self.TestObj(fsm, 'another_dir2', 'test_name'), check_name_registered=False)
            fsm_exist_ok.register_dir(self.TestObj(fsm, 'another_dir2', 'test_name'), check_name_registered=False)
        except Exception:
            self.fail("Folder registrable test fail when it's disabled")

        # A directory that already exists on disk and is non-empty.
        os.makedirs(os.path.join(self.base_dir, 'dir_dir', 'dir'))
        with self.assertRaises(FileStructManager.FSMException):
            fsm.register_dir(self.TestObj(fsm, 'dir_dir', 'name_name'))

        try:
            fsm_exist_ok.register_dir(self.TestObj(fsm, 'dir_dir', 'name_name'))
        except Exception:
            self.fail("Folder registrable test fail when exists_ok=True")
class CheckpointsManagerTests(UseFileStructure):
    """Tests for ``CheckpointsManager``: initialisation, pack, unpack and file clearing."""

    @staticmethod
    def _write_file(path: str, content: str = '') -> None:
        """Create (or truncate) ``path`` and write ``content``, always closing the handle."""
        with open(path, 'w') as out:
            out.write(content)

    def test_initialisation(self):
        """A second manager on the same ``FileStructManager`` must be rejected."""
        fsm = FileStructManager(base_dir=self.base_dir, is_continue=False)

        try:
            cm = CheckpointsManager(fsm)
        except Exception as err:
            self.fail("Fail init CheckpointsManager; err: ['{}']".format(err))

        with self.assertRaises(FileStructManager.FSMException):
            CheckpointsManager(fsm)

        # Still rejected once the manager's directory has content.
        os.mkdir(os.path.join(fsm.get_path(cm), 'test_dir'))
        with self.assertRaises(FileStructManager.FSMException):
            CheckpointsManager(fsm)

    def test_pack(self):
        """``pack`` must fail without proper files, archive them, and keep the previous archive."""
        fsm = FileStructManager(base_dir=self.base_dir, is_continue=False)
        cm = CheckpointsManager(fsm)

        # Nothing to pack yet.
        with self.assertRaises(CheckpointsManager.SMException):
            cm.pack()

        # Directories in place of the expected files must be rejected as well.
        os.mkdir(cm.weights_file())
        os.mkdir(cm.optimizer_state_file())
        with self.assertRaises(CheckpointsManager.SMException):
            cm.pack()

        shutil.rmtree(cm.weights_file())
        shutil.rmtree(cm.optimizer_state_file())

        for path in (cm.weights_file(), cm.optimizer_state_file(), cm.trainer_file()):
            self._write_file(path)

        try:
            cm.pack()
        except Exception as err:
            self.fail('Exception on packing files: [{}]'.format(err))

        for path in (cm.weights_file(), cm.optimizer_state_file()):
            if os.path.isfile(path):
                self.fail("File '{}' doesn't remove after pack".format(path))

        checkpoints_dir = fsm.get_path(cm, check=False, create_if_non_exists=False)
        self.assertTrue(os.path.isfile(os.path.join(checkpoints_dir, 'last_checkpoint.zip')))

        # Packing a second time is expected to keep the previous archive as '.old'.
        for path in (cm.weights_file(), cm.optimizer_state_file(), cm.trainer_file()):
            self._write_file(path)

        try:
            cm.pack()
        except Exception:
            self.fail('Fail to pack with existing previous state file')

        # Asserted outside the try so that a missing '.old' archive is reported
        # as an assertion failure rather than being relabelled as a packing error.
        checkpoints_dir = fsm.get_path(cm, check=False, create_if_non_exists=False)
        self.assertTrue(os.path.isfile(os.path.join(checkpoints_dir, 'last_checkpoint.zip.old')))

    def test_unpack(self):
        """``unpack`` must restore every packed file with its original content."""
        fsm = FileStructManager(base_dir=self.base_dir, is_continue=False)
        cm = CheckpointsManager(fsm)

        # Every file is closed (and therefore flushed) before packing.
        self._write_file(cm.weights_file(), '1')
        self._write_file(cm.optimizer_state_file(), '2')
        self._write_file(cm.trainer_file(), '3')

        cm.pack()

        try:
            cm.unpack()
        except Exception:
            self.fail('Exception on unpacking')

        restored = [cm.weights_file(), cm.optimizer_state_file(), cm.trainer_file()]
        for expected, path in enumerate(restored, start=1):
            if not os.path.isfile(path):
                self.fail("File '{}' doesn't exist after unpack".format(path))
            with open(path, 'r') as file:
                if file.read() != str(expected):
                    self.fail("File content corrupted")

    def test_clear_files(self):
        """``clear_files`` must remove the weights and optimizer state files."""
        fsm = FileStructManager(base_dir=self.base_dir, is_continue=False)
        cm = CheckpointsManager(fsm)

        self._write_file(cm.weights_file())
        self._write_file(cm.optimizer_state_file())

        cm.clear_files()

        for path in (cm.weights_file(), cm.optimizer_state_file()):
            if os.path.isfile(path):
                self.fail("File '{}' doesn't remove after clear_files".format(path))