import torch from torchvision import models import math import unittest from onecyclelr import OneCycleLR class TestOneCycleLR(unittest.TestCase): def setUp(self): self.model = models.resnet18() self.optimizer = torch.optim.SGD(self.model.parameters(), lr=0.1, momentum=0.1) self.scheduler = OneCycleLR( self.optimizer, num_steps=1000, lr_range=(0.1, 1.), momentum_range=(0.85, 0.95), annihilation_frac=0.1, reduce_factor=0.01, last_step=-1 ) def test_internals(self): assert self.scheduler.num_cycle_steps == 900 assert math.isclose(self.scheduler.final_lr, 0.1 * 0.01) assert math.isclose(self.scheduler.get_lr(), 0.1) assert math.isclose(self.scheduler.get_momentum(), 0.95) def test_step(self): # Scale up for i in range(450): self.scheduler.step() assert self.scheduler.last_step == 450 assert math.isclose(self.scheduler.get_lr(), 1.) assert math.isclose(self.scheduler.get_momentum(), 0.85) # Scale down for i in range(450): self.scheduler.step() assert self.scheduler.last_step == 900 assert math.isclose(self.scheduler.get_lr(), 0.1) assert math.isclose(self.scheduler.get_momentum(), 0.95) for i in range(100): self.scheduler.step() assert self.scheduler.last_step == 1000 assert math.isclose(self.scheduler.get_lr(), 0.001) assert math.isclose(self.scheduler.get_momentum(), 0.95) # Go beyond the given num of steps: check if it works okay for i in range(50): self.scheduler.step() assert math.isclose(self.scheduler.get_lr(), 0.001) assert math.isclose(self.scheduler.get_momentum(), 0.95)