import os
import unittest

import torch
import torch.nn.init as init
from torch.nn import Module, Linear
from torch.utils.data import TensorDataset

import torchbearer
from torchbearer import callbacks as c


def _remove_if_exists(path):
    """Delete the file at ``path`` if it exists; do nothing otherwise.

    Registered as a test cleanup so that files written into the working
    directory are removed even when a test fails part-way through.
    """
    if os.path.exists(path):
        os.remove(path)


class Net(Module):
    """Toy model whose single parameter tensor is the point being optimised.

    ``forward`` ignores its input and returns the objective ``f`` evaluated
    at the current parameters, so minimising the model output moves the
    parameters towards the minimum of ``f``.

    Args:
        x: initial parameter values; ``f`` indexes elements 0..2, so this is
            expected to be a 3-element float tensor.
    """

    def __init__(self, x):
        super(Net, self).__init__()
        self.pars = torch.nn.Parameter(x)

    def f(self):
        """Evaluate the objective at the current parameters.

        function to be minimised:
        f(x) = (x[0]-5)^2 + x[1]^2 + (x[2]-1)^2
        Solution:
        x = [5,0,1]
        """
        out = torch.zeros_like(self.pars)
        out[0] = self.pars[0] - 5
        out[1] = self.pars[1]
        out[2] = self.pars[2] - 1
        return torch.sum(out ** 2)

    def forward(self, _):
        return self.f()


class NetWithState(Net):
    """``Net`` variant whose forward pass refuses to run without ``state``.

    ``test_basic_opt`` trains this model, so it only passes if the trial
    supplies a ``state`` argument to ``forward``.
    """

    def forward(self, _, state=None):
        if state is None:
            raise ValueError('NetWithState.forward expects a state argument')
        return super(NetWithState, self).forward(_)


def loss(y_pred, y_true):
    """Pass-through criterion: the model output already *is* the loss.

    ``y_true`` is ignored.
    """
    return y_pred


class TestEndToEnd(unittest.TestCase):
    """End-to-end tests driving ``torchbearer.Trial`` through run/predict/evaluate."""

    def _cleanup_files(self, *paths):
        """Schedule ``paths`` for deletion after the current test, pass or fail."""
        for path in paths:
            self.addCleanup(_remove_if_exists, path)

    def test_basic_opt(self):
        p = torch.tensor([2.0, 1.0, 10.0])
        training_steps = 1000

        model = NetWithState(p)
        optim = torch.optim.SGD(model.parameters(), lr=0.01)

        trial = (torchbearer.Trial(model, optim, loss)
                 .for_train_steps(training_steps)
                 .for_val_steps(1)
                 .for_test_steps(1))
        trial.run()
        trial.predict()
        trial.evaluate()

        # Parameters should converge to the known minimiser [5, 0, 1].
        self.assertAlmostEqual(model.pars[0].item(), 5.0, places=4)
        self.assertAlmostEqual(model.pars[1].item(), 0.0, places=4)
        self.assertAlmostEqual(model.pars[2].item(), 1.0, places=4)

    def test_callbacks(self):
        # Register cleanup before anything is written so a failure cannot
        # leave stale files behind for later tests.
        self._cleanup_files('test.pt', 'test_csv.pt')

        traingen = TensorDataset(torch.rand(10, 1, 3), torch.rand(10, 1))
        valgen = TensorDataset(torch.rand(10, 1, 3), torch.rand(10, 1))
        testgen = TensorDataset(torch.rand(10, 1, 3), torch.rand(10, 1))

        model = torch.nn.Linear(3, 1)
        optim = torch.optim.SGD(model.parameters(), lr=0.01)
        cbs = [
            c.EarlyStopping(),
            # Materialise the parameters: a generator can only be iterated
            # once, so a callback that walks it on every step would clip
            # nothing after the first.
            c.GradientClipping(10, list(model.parameters())),
            c.Best('test.pt'),
            c.MostRecent('test.pt'),
            c.ReduceLROnPlateau(),
            c.CosineAnnealingLR(0.1, 0.01),
            c.ExponentialLR(1),
            c.Interval('test.pt'),
            c.CSVLogger('test_csv.pt'),
            c.L1WeightDecay(),
            c.L2WeightDecay(),
            c.TerminateOnNaN(monitor='fail_metric'),
        ]

        trial = torchbearer.Trial(model, optim, torch.nn.MSELoss(),
                                  metrics=['loss'], callbacks=cbs)
        trial = trial.with_generators(traingen, valgen, testgen)
        trial.run(2)
        trial.predict()
        trial.evaluate(data_key=torchbearer.TEST_DATA)
        trial.evaluate()

        # The checkpointing and CSV-logging callbacks must have produced output.
        self.assertTrue(os.path.exists('test.pt'))
        self.assertTrue(os.path.exists('test_csv.pt'))

    def test_zero_model(self):
        model = Linear(3, 1)
        init.constant_(model.weight, 0)
        init.constant_(model.bias, 0)
        optim = torch.optim.SGD(model.parameters(), lr=0.01)

        trial = torchbearer.Trial(model, optim, loss)
        trial.with_test_data(torch.rand(10, 3), batch_size=3)
        preds = trial.predict()

        # All-zero weights and bias: every prediction must be zero.
        for pred in preds:
            self.assertAlmostEqual(pred.item(), 0)

    def test_basic_checkpoint(self):
        self._cleanup_files('test.pt')

        p = torch.tensor([2.0, 1.0, 10.0])
        training_steps = 500

        model = Net(p)
        optim = torch.optim.SGD(model.parameters(), lr=0.01)

        trial = (torchbearer.Trial(model, optim, loss,
                                   callbacks=[torchbearer.callbacks.MostRecent(filepath='test.pt')])
                 .for_train_steps(training_steps)
                 .for_val_steps(1))
        trial.run(2)  # Simulate 2 'epochs'

        # Reload into a freshly initialised model and check the saved state
        # (history and optimised parameters) is restored.
        p = torch.tensor([2.0, 1.0, 10.0])
        model = Net(p)
        optim = torch.optim.SGD(model.parameters(), lr=0.01)

        trial = (torchbearer.Trial(model, optim, loss,
                                   callbacks=[torchbearer.callbacks.MostRecent(filepath='test.pt')])
                 .for_train_steps(training_steps))
        trial.load_state_dict(torch.load('test.pt'))
        self.assertEqual(len(trial.state[torchbearer.HISTORY]), 2)
        self.assertAlmostEqual(model.pars[0].item(), 5.0, places=4)
        self.assertAlmostEqual(model.pars[1].item(), 0.0, places=4)
        self.assertAlmostEqual(model.pars[2].item(), 1.0, places=4)

    def test_with_loader(self):
        self._cleanup_files('test.pt')

        p = torch.tensor([2.0, 1.0, 10.0])
        training_steps = 2

        model = Net(p)
        optim = torch.optim.SGD(model.parameters(), lr=0.01)
        test_var = {'loaded': False}

        def custom_loader(state):
            # Net ignores its input, so empty inputs/targets are sufficient.
            state[torchbearer.X], state[torchbearer.Y_TRUE] = None, None
            test_var['loaded'] = True

        trial = (torchbearer.Trial(model, optim, loss,
                                   callbacks=[torchbearer.callbacks.MostRecent(filepath='test.pt')])
                 .for_train_steps(training_steps)
                 .for_val_steps(1))
        trial.with_loader(custom_loader)
        # Registering the loader must not invoke it; only running should.
        self.assertFalse(test_var['loaded'])
        trial.run(1)
        self.assertTrue(test_var['loaded'])
        self.assertTrue(os.path.exists('test.pt'))

    def test_only_model(self):
        p = torch.tensor([2.0, 1.0, 10.0])
        model = Net(p)

        trial = torchbearer.Trial(model)
        self.assertListEqual(trial.run(), [])

    def test_no_model(self):
        trial = torchbearer.Trial(None)
        trial.run()
        self.assertIsNone(torchbearer.trial.MockModel()(torch.rand(1)))

    def test_no_train_steps(self):
        trial = torchbearer.Trial(None)
        trial.for_val_steps(10)
        trial.run()