#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

import unittest

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchdp import PerSampleGradientClipper
from torchdp.utils import ConstantFlatClipper, ConstantPerLayerClipper
from torchvision import transforms
from torchvision.datasets import FakeData


class SampleConvNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 16, 8, 3)
        self.conv2 = nn.Conv1d(16, 32, 3, 1)
        self.convf = nn.Conv1d(32, 32, 1, 1)
        for p in self.convf.parameters():
            p.requires_grad = False
        self.fc1 = nn.Linear(23, 17)
        self.fc2 = nn.Linear(32 * 17, 10)

    def forward(self, x):
        # x of shape [B, 1, 35, 35]
        x = F.relu(self.conv1(x))  # -> [B, 16, 10, 10]
        x = F.max_pool2d(x, 2, 2)  # -> [B, 16, 5, 5]
        x = x.view(x.shape[0], x.shape[1], x.shape[2] * x.shape[3])  # -> [B, 16, 25]
        x = F.relu(self.conv2(x))  # -> [B, 32, 23]
        x = self.convf(x)  # -> [B, 32, 23]
        x = self.fc1(x)  # -> [B, 32, 17]
        x = x.view(-1, x.shape[-2] * x.shape[-1])  # -> [B, 32 * 17]
        x = self.fc2(x)  # -> [B, 10]
        return x

    def name(self):
        return "SampleConvNet"


class PerSampleGradientClipper_test(unittest.TestCase):
    def setUp(self):
        self.DATA_SIZE = 64
        self.criterion = nn.CrossEntropyLoss()

        self.setUp_data()
        self.setUp_original_model()
        self.setUp_clipped_model(clip_value=0.003, run_clipper_step=True)

    def setUp_data(self):
        self.ds = FakeData(
            size=self.DATA_SIZE,
            image_size=(1, 35, 35),
            num_classes=10,
            transform=transforms.Compose(
                [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
            ),
        )
        self.dl = DataLoader(self.ds, batch_size=self.DATA_SIZE)

    def setUp_original_model(self):
        self.original_model = SampleConvNet()
        for x, y in self.dl:
            logits = self.original_model(x)
            loss = self.criterion(logits, y)
            loss.backward()  # puts grad in self.original_model.parameters()
        self.original_grads_norms = torch.stack(
            [
                p.grad.norm()
                for p in self.original_model.parameters()
                if p.requires_grad
            ],
            dim=-1,
        )

    def setUp_clipped_model(self, clip_value=0.003, run_clipper_step=True):
        # Deep copy
        self.clipped_model = SampleConvNet()  # create the structure
        self.clipped_model.load_state_dict(self.original_model.state_dict())  # fill it

        # Intentionally clipping to a very small value
        norm_clipper = (
            ConstantFlatClipper(clip_value)
            if not isinstance(clip_value, list)
            else ConstantPerLayerClipper(clip_value)
        )
        self.clipper = PerSampleGradientClipper(self.clipped_model, norm_clipper)

        for x, y in self.dl:
            logits = self.clipped_model(x)
            loss = self.criterion(logits, y)
            loss.backward()  # puts grad in self.clipped_model.parameters()
            if run_clipper_step:
                self.clipper.clip_and_accumulate()
                self.clipper.pre_step()
        self.clipped_grads_norms = torch.stack(
            [p.grad.norm() for p in self.clipped_model.parameters() if p.requires_grad],
            dim=-1,
        )

    def test_clipped_grad_norm_is_smaller(self):
        """
        Test that grads are clipped and their values change
        """
        for original_layer_norm, clipped_layer_norm in zip(
            self.original_grads_norms, self.clipped_grads_norms
        ):
            self.assertLess(float(clipped_layer_norm), float(original_layer_norm))

    def test_clipped_grad_norm_is_smaller_perlayer(self):
        """
        Test that grads are clipped and their values change
        """
        # there are 8 parameter sets: [bias, weight] * [conv1, conv2, fc1, fc2]
        # (convf is frozen, so it contributes no trainable parameters)
        self.setUp_clipped_model(clip_value=[0.001] * 8)
        for original_layer_norm, clipped_layer_norm in zip(
            self.original_grads_norms, self.clipped_grads_norms
        ):
            self.assertLess(float(clipped_layer_norm), float(original_layer_norm))

    def test_clipped_grad_norms_not_zero(self):
        """
        Test that grads aren't killed by clipping
        """
        allzeros = torch.zeros_like(self.clipped_grads_norms)
        self.assertFalse(torch.allclose(self.clipped_grads_norms, allzeros))

    def test_clipped_grad_norms_not_zero_per_layer(self):
        """
        Test that grads aren't killed by clipping
        """
        # there are 8 parameter sets: [bias, weight] * [conv1, conv2, fc1, fc2]
        self.setUp_clipped_model(clip_value=[0.001] * 8)
        allzeros = torch.zeros_like(self.clipped_grads_norms)
        self.assertFalse(torch.allclose(self.clipped_grads_norms, allzeros))

    def test_clipping_to_high_value_does_nothing(self):
        self.setUp_clipped_model(
            clip_value=9999, run_clipper_step=True
        )  # should be a no-op
        self.assertTrue(
            torch.allclose(self.original_grads_norms, self.clipped_grads_norms)
        )

    def test_clipping_to_high_value_does_nothing_per_layer(self):
        self.setUp_clipped_model(
            clip_value=[9999] * 8, run_clipper_step=True
        )  # should be a no-op
        self.assertTrue(
            torch.allclose(self.original_grads_norms, self.clipped_grads_norms)
        )

    def test_grad_norms_untouched_without_clip_step(self):
        """
        Test that grads are not clipped until the clipper step
        (clip_and_accumulate + pre_step) is run
        """
        self.setUp_clipped_model(clip_value=0.003, run_clipper_step=False)
        self.assertTrue(
            torch.allclose(self.original_grads_norms, self.clipped_grads_norms)
        )
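

# Optional entry point added for convenience; this is an assumption, as the
# original suite is presumably run through unittest/pytest test discovery.
if __name__ == "__main__":
    unittest.main()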