# Copyright (c) Facebook, Inc. and its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Unit tests for higher top level functions."""

import unittest
import copy
from collections import OrderedDict

import numpy as np
import torch
from torch import nn, optim
from torch.nn import functional as F

import higher


class _ReferenceNet(nn.Module):
    """Hand-written functional twin of ``_TargetNet``.

    With ``params=None`` the wrapped modules are run directly. Otherwise the
    same conv / batch-norm / relu / max-pool stack is re-implemented with
    ``torch.nn.functional`` ops, reading every weight from ``params``. The
    tests use this as ground truth for what a monkey-patched ``_TargetNet``
    should compute.
    """

    def __init__(self, features, fc):
        super().__init__()
        self.features = features
        self.add_module('fc', fc)

    def batch_norm(
        self,
        inputs,
        weight=None,
        bias=None,
        running_mean=None,
        running_var=None,
        training=True,
        eps=1e-5,
        momentum=0.1
    ):
        """Functional batch norm with freshly created running statistics.

        The ``running_mean`` / ``running_var`` arguments are ignored: a new
        zero-mean / unit-variance buffer sized to ``inputs.size(1)`` is built
        on every call, so no statistics are carried between calls.
        """
        num_features = np.prod(np.array(inputs.data.size()[1]))
        running_mean = torch.zeros(num_features)
        running_var = torch.ones(num_features)
        return F.batch_norm(
            inputs, running_mean, running_var, weight, bias, training,
            momentum, eps
        )

    def maxpool(self, input, kernel_size, stride=None):
        """Thin wrapper around ``F.max_pool2d``."""
        return F.max_pool2d(input, kernel_size, stride)

    def forward(self, x, params=None):
        """Run the network, optionally with externally supplied weights.

        Args:
            x: input batch; after the three conv blocks it is flattened to
                ``(x.size(0), 64)`` before the linear layer.
            params: optional mapping from parameter names (as produced by
                ``named_parameters``) to tensors, used in place of the
                module's own parameters.

        Returns:
            The output of the final linear layer.
        """
        if params is None:
            x = self.features(x).view(x.size(0), 64)
            return self.fc(x)

        # Mirrors the conv{i} -> bn{i} -> relu -> pool{i} stages of the
        # ``features`` Sequential built in ``TestCorrectness.setUp``.
        for i in range(1, 4):
            x = F.conv2d(
                x,
                params['features.conv{}.weight'.format(i)],
                params['features.conv{}.bias'.format(i)]
            )
            x = self.batch_norm(
                x,
                weight=params['features.bn{}.weight'.format(i)],
                bias=params['features.bn{}.bias'.format(i)],
                momentum=1
            )
            x = F.relu(x)
            x = self.maxpool(x, kernel_size=2, stride=2)
        x = x.view(x.size(0), 64)
        return F.linear(x, params['fc.weight'], params['fc.bias'])

    def get_fast_weights(self):
        """Return an ordered ``name -> parameter`` mapping of this module."""
        return OrderedDict(self.named_parameters())


class _TargetNet(nn.Module):
    """Plain module-based network that ``higher`` monkey-patches in tests."""

    def __init__(self, features, fc):
        super().__init__()
        self.features = features
        self.add_module('fc', fc)

    def forward(self, x):
        x = self.features(x).view(x.size(0), 64)
        x = self.fc(x)
        return x


class TestCorrectness(unittest.TestCase):
    """Test case for package-level functions for correctness."""

    def setUp(self):
        self.num_in_channels = num_in_channels = 3
        self.num_classes = num_classes = 5
        self.batch_size = 7
        # 28 -> conv 26 -> pool 13 -> conv 11 -> pool 5 -> conv 3 -> pool 1,
        # so the feature extractor ends at 64 channels x 1 x 1 = 64 values.
        self.in_h = self.in_w = 28
        features = nn.Sequential(
            OrderedDict(
                [
                    ('conv1', nn.Conv2d(num_in_channels, 64, 3)),
                    ('bn1', nn.BatchNorm2d(64, momentum=1, affine=True)),
                    ('relu1', nn.ReLU(inplace=True)),
                    ('pool1', nn.MaxPool2d(2, 2)),
                    ('conv2', nn.Conv2d(64, 64, 3)),
                    ('bn2', nn.BatchNorm2d(64, momentum=1, affine=True)),
                    ('relu2', nn.ReLU(inplace=True)),
                    ('pool2', nn.MaxPool2d(2, 2)),
                    ('conv3', nn.Conv2d(64, 64, 3)),
                    ('bn3', nn.BatchNorm2d(64, momentum=1, affine=True)),
                    ('relu3', nn.ReLU(inplace=True)),
                    ('pool3', nn.MaxPool2d(2, 2))
                ]
            )
        )
        fc = nn.Linear(64, num_classes)
        self.target_net = _TargetNet(features, fc)
        # Deep copies: identical initial weights, but no shared tensors.
        self.reference_net = _ReferenceNet(
            copy.deepcopy(features), copy.deepcopy(fc)
        )
        self.lr = .01
        self.opt = optim.SGD(self.target_net.parameters(), lr=self.lr)

    def _assert_named_params_aligned(
        self, ref_named_params, target_named_params
    ):
        """Assert two ``(name, tensor)`` lists match in length, order, value."""
        self.assertEqual(
            len(ref_named_params),
            len(target_named_params),
            msg=(
                "Length mismatched between reference net parameter count "
                "({}) and target ({}).".format(
                    len(ref_named_params), len(target_named_params)
                )
            )
        )
        for (ref_name, ref_p), (target_name, target_p) in zip(
            ref_named_params, target_named_params
        ):
            self.assertEqual(
                ref_name,
                target_name,
                msg="Name mismatch or parameter misalignment ('{}' vs '{}')"
                .format(ref_name, target_name)
            )
            self.assertTrue(
                torch.equal(ref_p, target_p),
                msg="Parameter value inequality for {}".format(ref_name)
            )

    def _sample_batch(self):
        """Draw a random ``(inputs, labels)`` pair (inputs drawn first)."""
        inputs = torch.rand(
            self.batch_size, self.num_in_channels, self.in_h, self.in_w
        )
        labels = torch.rand(self.batch_size, self.num_classes)
        return inputs, labels

    def testSameInitialWeightsPrePatch(self):
        """Check that reference and unpatched target net have equal weights.

        This is mostly a sanity check for the purpose of the other unit tests.
        """
        self._assert_named_params_aligned(
            list(self.reference_net.named_parameters()),
            list(self.target_net.named_parameters())
        )

    def testSameInitialWeightsPostPatch(self):
        """Verify fast weight alignment/equality after monkey patching."""
        ref_named_params = list(self.reference_net.get_fast_weights().items())
        ref_params = [p for (_, p) in ref_named_params]
        with higher.innerloop_ctx(self.target_net, self.opt) as (fnet, _):
            target_named_params = list(fnet.named_parameters())
            target_params = fnet.parameters()
            self._assert_named_params_aligned(
                ref_named_params, target_named_params
            )
            zipped = zip(ref_params, target_params)
            for i, (ref_p, target_p) in enumerate(zipped):
                self.assertTrue(
                    torch.equal(ref_p, target_p),
                    msg="Parameter misalignment in position {}.".format(i)
                )

    def testRandomForwards(self):
        """Test reference and patched net forward equivalence.

        Test if, given rand fast weights, patched net and reference forwards
        match up given random inputs.
        """
        with higher.innerloop_ctx(self.target_net, self.opt) as (fnet, _):
            for _ in range(10):
                fast_named_weights = OrderedDict(
                    (name, torch.rand(p.shape, requires_grad=True))
                    for name, p in self.reference_net.named_parameters()
                )
                fast_weights = list(fast_named_weights.values())
                inputs = torch.rand(
                    self.batch_size, self.num_in_channels, self.in_h,
                    self.in_w
                )
                self.assertTrue(
                    torch.equal(
                        self.reference_net(inputs, params=fast_named_weights),
                        fnet(inputs, params=fast_weights)
                    )
                )

    def testUnrollEqualityForward(self):
        """Check if unrolled patched and reference nets produce same meta loss.
        """
        for test_it in range(5):
            with higher.innerloop_ctx(self.target_net,
                                      self.opt) as (fnet, diffopt):
                ref_out, target_out = self._joint_inner_loop(
                    fnet, diffopt=diffopt, num_steps=10
                )
                (ref_meta_loss, ref_fast_weights, ref_train_losses,
                 ref_train_grads) = ref_out
                (target_meta_loss, target_fast_weights, target_train_losses,
                 target_train_grads) = target_out

                # Check final losses match
                self.assertTrue(
                    torch.allclose(ref_meta_loss, target_meta_loss),
                    msg=(
                        "Ref ({}) and target ({}) metaloss differed on test_it "
                        "{} (mse {})"
                    ).format(
                        ref_meta_loss.item(), target_meta_loss.item(), test_it,
                        (ref_meta_loss - target_meta_loss).pow(2).item()
                    )
                )

                # Check that training losses align
                for rl, tl in zip(ref_train_losses, target_train_losses):
                    torch.testing.assert_allclose(rl, tl)

                # Check that fast weights align
                for rw, tw in zip(ref_fast_weights, target_fast_weights):
                    torch.testing.assert_allclose(rw, tw)

                # Check that grads align
                for rgs, tgs in zip(ref_train_grads, target_train_grads):
                    for rg, tg in zip(rgs, tgs):
                        torch.testing.assert_allclose(rg, tg)

    def _assert_metagrads_match(self, use_diffopt):
        """Unroll both nets and compare meta-gradients w.r.t. initial weights.

        Args:
            use_diffopt: if True the patched net is stepped by ``higher``'s
                differentiable optimizer; otherwise SGD is unrolled manually.
        """
        for _ in range(5):
            with higher.innerloop_ctx(self.target_net,
                                      self.opt) as (fnet, diffopt):
                ref_out, target_out = self._joint_inner_loop(
                    fnet,
                    diffopt=diffopt if use_diffopt else None,
                    num_steps=10
                )
                ref_meta_loss = ref_out[0]
                target_meta_loss = target_out[0]

                ref_metagrads = torch.autograd.grad(
                    ref_meta_loss, self.reference_net.parameters()
                )
                # time=0 selects the patched net's initial fast weights.
                target_metagrads = torch.autograd.grad(
                    target_meta_loss, fnet.parameters(time=0)
                )

                # Guard against zip silently truncating a misaligned pair.
                self.assertEqual(len(ref_metagrads), len(target_metagrads))
                # Check that metagrads align
                for rg, tg in zip(ref_metagrads, target_metagrads):
                    torch.testing.assert_allclose(rg, tg)

    def testUnrollEqualityBackward(self):
        """Check if metagrads match for target/ref net."""
        self._assert_metagrads_match(use_diffopt=True)

    def testUnrollEqualityBackwardManualUnroll(self):
        """Check if metagrads match for target/ref net.

        A differentiable optimizer is not used (manual inner loop SGD).
        """
        self._assert_metagrads_match(use_diffopt=False)

    def _joint_inner_loop(self, fnet, diffopt=None, num_steps=1):
        """Unroll the same SGD inner loop on the reference and patched nets.

        Both nets see identical random batches. The reference net is updated
        by hand with ``param - lr * grad``. The patched net is updated by
        ``diffopt`` when given, otherwise by the same manual rule applied to
        an explicit list of fast weights passed via ``params=``.

        Returns:
            ``(ref_out, target_out)``, each a 4-tuple
            ``(meta_loss, final_fast_weights, train_losses, train_grads)``.
        """
        ref_fast_weights = self.reference_net.get_fast_weights()
        target_fast_weights = None
        if diffopt is None:
            # If diffopt not provided we manually will update these
            target_fast_weights = list(fnet.parameters())

        # Things we want to track
        ref_train_losses = []
        ref_train_grads = []
        target_train_losses = []
        target_train_grads = []

        for _ in range(num_steps):
            inputs, labels = self._sample_batch()

            # Do inner loop step for reference net
            ref_preds = self.reference_net(inputs, params=ref_fast_weights)
            ref_loss = F.mse_loss(ref_preds, labels)
            ref_grads = torch.autograd.grad(
                ref_loss, ref_fast_weights.values(), create_graph=True
            )
            ref_train_losses.append(ref_loss)
            ref_train_grads.append(ref_grads)
            ref_fast_weights = OrderedDict(
                (name, param - self.lr * grad)
                for ((name, param), grad)
                in zip(ref_fast_weights.items(), ref_grads)
            )

            # Do inner loop step for target net
            if diffopt is None:
                target_preds = fnet(inputs, params=target_fast_weights)
            else:
                target_preds = fnet(inputs)
            target_loss = F.mse_loss(target_preds, labels)
            if diffopt is None:
                target_grads = torch.autograd.grad(
                    target_loss, target_fast_weights, create_graph=True
                )
                target_fast_weights = [
                    w - (self.lr * g)
                    for w, g in zip(target_fast_weights, target_grads)
                ]
            else:
                # Grads are computed separately only so they can be
                # compared; diffopt.step does its own update.
                target_grads = torch.autograd.grad(
                    target_loss, list(fnet.parameters()), create_graph=True
                )
                diffopt.step(target_loss)
            target_train_losses.append(target_loss)
            target_train_grads.append(target_grads)

        # metaval
        inputs, labels = self._sample_batch()

        ref_preds = self.reference_net(inputs, params=ref_fast_weights)
        ref_meta_loss = F.mse_loss(ref_preds, labels)

        if diffopt is None:
            target_preds = fnet(inputs, params=target_fast_weights)
        else:
            target_preds = fnet(inputs)
        target_meta_loss = F.mse_loss(target_preds, labels)

        ref_fast_weights = ref_fast_weights.values()
        if diffopt is not None:
            # diffopt.step advanced fnet itself, so its current parameters
            # are the final fast weights. In manual mode fnet is never
            # stepped; target_fast_weights already holds the last manual
            # update and must not be replaced by fnet's initial weights.
            target_fast_weights = fnet.parameters()

        packed = (
            (
                ref_meta_loss, ref_fast_weights, ref_train_losses,
                ref_train_grads
            ),
            (
                target_meta_loss, target_fast_weights, target_train_losses,
                target_train_grads
            )
        )
        return packed