from unittest import TestCase from mock import patch, Mock import warnings import torchbearer from torchbearer.callbacks import TorchScheduler, LambdaLR, StepLR, MultiStepLR, ExponentialLR, CosineAnnealingLR,\ ReduceLROnPlateau, CyclicLR class TestTorchScheduler(TestCase): def setUp(self): super(TestTorchScheduler, self).setUp() warnings.filterwarnings('always') def tearDown(self): super(TestTorchScheduler, self).tearDown() warnings.filterwarnings('default') def test_torch_scheduler_on_batch_with_monitor(self): state = {torchbearer.EPOCH: 1, torchbearer.METRICS: {'test': 101}, torchbearer.OPTIMIZER: 'optimizer'} mock_scheduler = Mock() mock_scheduler.return_value = mock_scheduler torch_scheduler = TorchScheduler(lambda opt: mock_scheduler(opt), monitor='test', step_on_batch=True) torch_scheduler.on_start(state) mock_scheduler.assert_called_once_with('optimizer') mock_scheduler.reset_mock() torch_scheduler.on_start_training(state) mock_scheduler.assert_not_called() mock_scheduler.reset_mock() torch_scheduler.on_sample(state) mock_scheduler.assert_not_called() mock_scheduler.reset_mock() torch_scheduler.on_step_training(state) mock_scheduler.step.assert_called_once_with(101) mock_scheduler.reset_mock() torch_scheduler.on_end_epoch(state) mock_scheduler.assert_not_called() mock_scheduler.reset_mock() def test_torch_scheduler_on_epoch_with_monitor(self): state = {torchbearer.EPOCH: 1, torchbearer.METRICS: {'test': 101}, torchbearer.OPTIMIZER: 'optimizer', torchbearer.DATA: None} mock_scheduler = Mock() mock_scheduler.return_value = mock_scheduler torch_scheduler = TorchScheduler(lambda opt: mock_scheduler(opt), monitor='test', step_on_batch=False) torch_scheduler.on_start(state) mock_scheduler.assert_called_once_with('optimizer') mock_scheduler.reset_mock() torch_scheduler.on_start_training(state) mock_scheduler.assert_not_called() mock_scheduler.reset_mock() torch_scheduler.on_sample(state) mock_scheduler.assert_not_called() mock_scheduler.reset_mock() torch_scheduler.on_step_training(state) mock_scheduler.assert_not_called() mock_scheduler.reset_mock() torch_scheduler.on_end_epoch(state) mock_scheduler.step.assert_called_once_with(101, epoch=1) mock_scheduler.reset_mock() def test_torch_scheduler_on_batch_no_monitor(self): state = {torchbearer.EPOCH: 1, torchbearer.OPTIMIZER: 'optimizer'} mock_scheduler = Mock() mock_scheduler.return_value = mock_scheduler torch_scheduler = TorchScheduler(lambda opt: mock_scheduler(opt), monitor=None, step_on_batch=True) torch_scheduler.on_start(state) mock_scheduler.assert_called_once_with('optimizer') mock_scheduler.reset_mock() torch_scheduler.on_start_training(state) mock_scheduler.assert_not_called() mock_scheduler.reset_mock() torch_scheduler.on_sample(state) mock_scheduler.step.assert_called_once_with() mock_scheduler.reset_mock() torch_scheduler.on_step_training(state) mock_scheduler.assert_not_called() mock_scheduler.reset_mock() torch_scheduler.on_end_epoch(state) mock_scheduler.assert_not_called() mock_scheduler.reset_mock() def test_torch_scheduler_on_epoch_no_monitor(self): state = {torchbearer.EPOCH: 1, torchbearer.OPTIMIZER: 'optimizer', torchbearer.METRICS: {}} mock_scheduler = Mock() mock_scheduler.return_value = mock_scheduler torch_scheduler = TorchScheduler(lambda opt: mock_scheduler(opt), monitor=None, step_on_batch=False) torch_scheduler.on_start(state) mock_scheduler.assert_called_once_with('optimizer') mock_scheduler.reset_mock() torch_scheduler.on_sample(state) mock_scheduler.assert_not_called() mock_scheduler.reset_mock() torch_scheduler.on_step_training(state) mock_scheduler.assert_not_called() mock_scheduler.reset_mock() torch_scheduler.on_end_epoch(state) mock_scheduler.assert_not_called() mock_scheduler.reset_mock() def test_monitor_not_found(self): state = {torchbearer.EPOCH: 1, torchbearer.OPTIMIZER: 'optimizer', torchbearer.METRICS: {'not_test': 1.}} mock_scheduler = Mock() mock_scheduler.return_value = mock_scheduler torch_scheduler = TorchScheduler(lambda opt: mock_scheduler(opt), monitor='test', step_on_batch=False) torch_scheduler.on_start(state) with warnings.catch_warnings(record=True) as w: torch_scheduler.on_start_validation(state) self.assertTrue(len(w) == 0) with warnings.catch_warnings(record=True) as w: torch_scheduler.on_end_epoch(state) self.assertTrue('Failed to retrieve key `test`' in str(w[0].message)) def test_monitor_found(self): state = {torchbearer.EPOCH: 1, torchbearer.OPTIMIZER: 'optimizer', torchbearer.METRICS: {'test': 1.}} mock_scheduler = Mock() mock_scheduler.return_value = mock_scheduler torch_scheduler = TorchScheduler(lambda opt: mock_scheduler(opt), monitor='test', step_on_batch=False) torch_scheduler.on_start(state) with warnings.catch_warnings(record=True) as w: torch_scheduler.on_start_training(state) self.assertTrue(len(w) == 0) with warnings.catch_warnings(record=True) as w: torch_scheduler.on_start_validation(state) self.assertTrue(len(w) == 0) with warnings.catch_warnings(record=True) as w: torch_scheduler.on_end_epoch(state) self.assertTrue(len(w) == 0) def test_batch_monitor_not_found(self): state = {torchbearer.EPOCH: 1, torchbearer.OPTIMIZER: 'optimizer', torchbearer.METRICS: {'not_test': 1.}} mock_scheduler = Mock() mock_scheduler.return_value = mock_scheduler torch_scheduler = TorchScheduler(lambda opt: mock_scheduler(opt), monitor='test', step_on_batch=True) torch_scheduler.on_start(state) with warnings.catch_warnings(record=True) as w: torch_scheduler.on_step_training(state) self.assertTrue('Failed to retrieve key `test`' in str(w[0].message)) def test_batch_monitor_found(self): state = {torchbearer.EPOCH: 1, torchbearer.OPTIMIZER: 'optimizer', torchbearer.METRICS: {'test': 1.}} mock_scheduler = Mock() mock_scheduler.return_value = mock_scheduler torch_scheduler = TorchScheduler(lambda opt: mock_scheduler(opt), monitor='test', step_on_batch=True) torch_scheduler.on_start(state) with warnings.catch_warnings(record=True) as w: torch_scheduler.on_step_training(state) self.assertTrue(len(w) == 0) class TestLambdaLR(TestCase): @patch('torch.optim.lr_scheduler.LambdaLR') def test_lambda_lr(self, lr_mock): state = {torchbearer.OPTIMIZER: 'optimizer'} scheduler = LambdaLR(0.1, last_epoch=-4, step_on_batch='batch') scheduler.on_start(state) lr_mock.assert_called_once_with('optimizer', 0.1, last_epoch=-4) self.assertTrue(scheduler._step_on_batch == 'batch') class TestStepLR(TestCase): @patch('torch.optim.lr_scheduler.StepLR') def test_lambda_lr(self, lr_mock): state = {torchbearer.OPTIMIZER: 'optimizer'} scheduler = StepLR(10, gamma=0.4, last_epoch=-4, step_on_batch='batch') scheduler.on_start(state) lr_mock.assert_called_once_with('optimizer', 10, gamma=0.4, last_epoch=-4) self.assertTrue(scheduler._step_on_batch == 'batch') class TestMultiStepLR(TestCase): @patch('torch.optim.lr_scheduler.MultiStepLR') def test_lambda_lr(self, lr_mock): state = {torchbearer.OPTIMIZER: 'optimizer'} scheduler = MultiStepLR(10, gamma=0.4, last_epoch=-4, step_on_batch='batch') scheduler.on_start(state) lr_mock.assert_called_once_with('optimizer', 10, gamma=0.4, last_epoch=-4) self.assertTrue(scheduler._step_on_batch == 'batch') class TestExponentialLR(TestCase): @patch('torch.optim.lr_scheduler.ExponentialLR') def test_lambda_lr(self, lr_mock): state = {torchbearer.OPTIMIZER: 'optimizer'} scheduler = ExponentialLR(0.4, last_epoch=-4, step_on_batch='batch') scheduler.on_start(state) lr_mock.assert_called_once_with('optimizer', 0.4, last_epoch=-4) self.assertTrue(scheduler._step_on_batch == 'batch') class TestCosineAnnealingLR(TestCase): @patch('torch.optim.lr_scheduler.CosineAnnealingLR') def test_lambda_lr(self, lr_mock): state = {torchbearer.OPTIMIZER: 'optimizer'} scheduler = CosineAnnealingLR(4, eta_min=10, last_epoch=-4, step_on_batch='batch') scheduler.on_start(state) lr_mock.assert_called_once_with('optimizer', 4, eta_min=10, last_epoch=-4) self.assertTrue(scheduler._step_on_batch == 'batch') class TestReduceLROnPlateau(TestCase): @patch('torch.optim.lr_scheduler.ReduceLROnPlateau') def test_lambda_lr(self, lr_mock): state = {torchbearer.OPTIMIZER: 'optimizer'} scheduler = ReduceLROnPlateau(monitor='test', mode='max', factor=0.2, patience=100, verbose=True, threshold=10, threshold_mode='thresh', cooldown=5, min_lr=0.1, eps=1e-4, step_on_batch='batch') scheduler.on_start(state) lr_mock.assert_called_once_with('optimizer', mode='max', factor=0.2, patience=100, verbose=True, threshold=10, threshold_mode='thresh', cooldown=5, min_lr=0.1, eps=1e-4) self.assertTrue(scheduler._step_on_batch == 'batch') self.assertTrue(scheduler._monitor == 'test') class TestCyclicLR(TestCase): def test_lambda_lr(self): from distutils.version import LooseVersion import torch version = torch.__version__ if str(torch.__version__) is torch.__version__ else "0.4.0" if LooseVersion(version) > LooseVersion("1.0.0"): # CyclicLR is implemented with patch('torch.optim.lr_scheduler.CyclicLR') as lr_mock: state = {torchbearer.OPTIMIZER: 'optimizer'} scheduler = CyclicLR(0.01, 0.1, monitor='test', step_size_up=200, step_size_down=None, mode='triangular', gamma=2., scale_fn=None, scale_mode='cycle', cycle_momentum=False, base_momentum=0.7, max_momentum=0.9, last_epoch=-1, step_on_batch='batch') scheduler.on_start(state) lr_mock.assert_called_once_with('optimizer', 0.01, 0.1, step_size_up=200, step_size_down=None, mode='triangular', gamma=2., scale_fn=None, scale_mode='cycle', cycle_momentum=False, base_momentum=0.7, max_momentum=0.9, last_epoch=-1) self.assertTrue(scheduler._step_on_batch == 'batch') self.assertTrue(scheduler._monitor == 'test') else: self.assertRaises(NotImplementedError, lambda: CyclicLR(0.01, 0.1, monitor='test', step_size_up=200, step_size_down=None, mode='triangular', gamma=2., scale_fn=None, scale_mode='cycle', cycle_momentum=False, base_momentum=0.7, max_momentum=0.9, last_epoch=-1, step_on_batch='batch'))