#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

import unittest

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchdp import PrivacyEngine
from torchvision import transforms
from torchvision.datasets import FakeData


class SampleConvNet(nn.Module):
    """Small conv/linear network used as the model under test.

    ``convf`` is frozen (``requires_grad = False``), so the model mixes
    trainable and non-trainable parameters. The layer sizes only line up for
    single-channel 35x35 inputs (see the shape comments in ``forward``).
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 16, 8, 3)
        self.conv2 = nn.Conv1d(16, 32, 3, 1)
        self.convf = nn.Conv1d(32, 32, 1, 1)
        # Freeze convf so the tests also cover parameters without gradients.
        for p in self.convf.parameters():
            p.requires_grad = False
        self.fc1 = nn.Linear(23, 17)
        self.fc2 = nn.Linear(32 * 17, 10)

    def forward(self, x):
        """Map a batch of images to 10 class logits."""
        # x of shape [B, 1, 35, 35]
        x = F.relu(self.conv1(x))  # -> [B, 16, 10, 10]
        x = F.max_pool2d(x, 2, 2)  # -> [B, 16, 5, 5]
        x = x.view(x.shape[0], x.shape[1], x.shape[2] * x.shape[3])  # -> [B, 16, 25]
        x = F.relu(self.conv2(x))  # -> [B, 32, 23]
        x = self.convf(x)  # -> [B, 32, 23]
        x = self.fc1(x)  # -> [B, 32, 17]
        x = x.view(-1, x.shape[-2] * x.shape[-1])  # -> [B, 32 * 17]
        x = self.fc2(x)  # -> [B, 10]
        return x

    def name(self):
        """Return the model's display name."""
        return "SampleConvNet"


class GradientAccumulation_test(unittest.TestCase):
    """Check gradient accumulation with a ``PrivacyEngine`` attached.

    ``setUp`` computes a reference gradient over the whole dataset *before*
    the engine is attached. The tests then accumulate gradients in different
    ways with the engine attached (no noise, very large clipping norm) and
    compare against that reference.
    """

    def setUp(self):
        self.DATA_SIZE = 64
        self.BATCH_SIZE = 16
        self.LR = 0  # we want to call optimizer.step() without modifying the model
        self.ALPHAS = [1 + x / 10.0 for x in range(1, 100, 10)]
        self.criterion = nn.CrossEntropyLoss()

        self.setUp_data()
        self.setUp_model_and_optimizer()

    def setUp_data(self):
        """Create the fake dataset and a loader yielding BATCH_SIZE batches."""
        self.ds = FakeData(
            size=self.DATA_SIZE,
            image_size=(1, 35, 35),
            num_classes=10,
            transform=transforms.Compose(
                [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
            ),
        )
        self.dl = DataLoader(self.ds, batch_size=self.BATCH_SIZE)

    def setUp_model_and_optimizer(self):
        """Build model and optimizer and record the reference gradient.

        The reference (``self.effective_batch_grad``) is the ``.grad``
        accumulated over every mini-batch, rescaled by
        ``BATCH_SIZE / DATA_SIZE`` to turn the sum of per-batch mean
        gradients into the mean over the whole dataset.
        """
        self.model = SampleConvNet()
        self.optimizer = torch.optim.SGD(
            self.model.parameters(), lr=self.LR, momentum=0
        )

        self.optimizer.zero_grad()

        # accumulate .grad over the entire dataset
        for x, y in self.dl:
            self._forward_backward(x, y)

        self.effective_batch_grad = self._flat_grad() * (
            self.BATCH_SIZE / self.DATA_SIZE
        )

        self.optimizer.zero_grad()

    def setUp_privacy_engine(self, batch_size):
        """Attach a noise-free ``PrivacyEngine`` expecting ``batch_size`` samples.

        ``noise_multiplier=0`` and ``max_grad_norm=999`` are chosen so that
        the resulting gradients can be compared with the plain reference.
        """
        self.privacy_engine = PrivacyEngine(
            self.model,
            batch_size=batch_size,
            sample_size=self.DATA_SIZE,
            alphas=self.ALPHAS,
            noise_multiplier=0,
            max_grad_norm=999,
        )
        self.privacy_engine.attach(self.optimizer)

    def _forward_backward(self, x, y):
        """Run one forward pass on ``(x, y)`` and backpropagate the loss."""
        logits = self.model(x)
        loss = self.criterion(logits, y)
        loss.backward()

    def _flat_grad(self):
        """Return ``.grad`` of all trainable parameters as one 1-D tensor."""
        return torch.cat(
            [p.grad.reshape(-1) for p in self.model.parameters() if p.requires_grad]
        )

    def test_grad_sample_accumulation(self):
        """
        Calling loss.backward() multiple times should sum up the gradients in
        .grad and accumulate all the individual gradients in .grad_sample
        """
        self.setUp_privacy_engine(self.DATA_SIZE)

        for x, y in self.dl:
            # should accumulate grads in .grad and .grad_sample
            self._forward_backward(x, y)

        # the accumulated per-sample gradients
        per_sample_grads = torch.cat(
            [
                p.grad_sample.reshape(self.DATA_SIZE, -1)
                for p in self.model.parameters()
                if p.requires_grad
            ],
            dim=-1,
        )
        # average up all the per-sample gradients
        accumulated_grad = torch.mean(per_sample_grads, dim=0)

        # the full data gradient accumulated in .grad
        grad = self._flat_grad() * (self.BATCH_SIZE / self.DATA_SIZE)

        self.optimizer.step()

        # the accumulated gradients in .grad without any hooks
        orig_grad = self.effective_batch_grad

        self.assertTrue(
            torch.allclose(accumulated_grad, orig_grad, atol=10e-5, rtol=10e-3)
        )
        self.assertTrue(torch.allclose(grad, orig_grad, atol=10e-5, rtol=10e-3))

    def test_clipper_accumulation(self):
        """
        Calling optimizer.virtual_step() should accumulate clipped gradients
        to form one large batch.
        """
        self.setUp_privacy_engine(self.DATA_SIZE)

        for x, y in self.dl:
            self._forward_backward(x, y)
            self.optimizer.virtual_step()

        # the full data gradient accumulated in .grad
        grad = self._flat_grad() * (self.BATCH_SIZE / self.DATA_SIZE)

        self.optimizer.step()

        # .grad should contain the average gradient over the entire dataset
        accumulated_grad = self._flat_grad()

        # the accumulated gradients in .grad without any hooks
        orig_grad = self.effective_batch_grad

        self.assertTrue(
            torch.allclose(accumulated_grad, orig_grad, atol=10e-5, rtol=10e-3)
        )
        self.assertTrue(torch.allclose(grad, orig_grad, atol=10e-5, rtol=10e-3))

    def test_mixed_accumulation(self):
        """
        Calling loss.backward() multiple times aggregates all per-sample
        gradients in .grad_sample. Then, calling optimizer.virtual_step()
        should clip all gradients and aggregate them into one large batch.
        """
        self.setUp_privacy_engine(self.DATA_SIZE)

        for idx, (x, y) in enumerate(self.dl):
            self._forward_backward(x, y)

            # accumulate per-sample grads for two mini batches
            # and then aggregate them in the clipper
            if (idx + 1) % 2 == 0:
                self.optimizer.virtual_step()

        # the full data gradient accumulated in .grad
        grad = self._flat_grad() * (self.BATCH_SIZE / self.DATA_SIZE)

        self.optimizer.step()

        # .grad should contain the average gradient over the entire dataset
        accumulated_grad = self._flat_grad()

        # the accumulated gradients in .grad without any hooks
        orig_grad = self.effective_batch_grad

        self.assertTrue(
            torch.allclose(accumulated_grad, orig_grad, atol=10e-5, rtol=10e-3)
        )
        self.assertTrue(torch.allclose(grad, orig_grad, atol=10e-5, rtol=10e-3))

    def test_grad_sample_erased(self):
        """
        Calling optimizer.step() should erase any accumulated per-sample
        gradients.
        """
        self.setUp_privacy_engine(2 * self.BATCH_SIZE)

        for idx, (x, y) in enumerate(self.dl):
            self._forward_backward(x, y)

            # accumulate per-sample gradients for two mini-batches to form an
            # effective batch of size `2*BATCH_SIZE`. Once an effective batch
            # has been accumulated, we call `optimizer.step()` to clip and
            # average the per-sample gradients. This should erase the
            # `grad_sample` fields for each parameter
            if (idx + 1) % 2 == 0:
                # take a step. The clipper will compute the mean gradient
                # for the entire effective batch and populate each parameter's
                # `.grad` field.
                self.optimizer.step()

                for param_name, param in self.model.named_parameters():
                    if param.requires_grad:
                        self.assertFalse(
                            hasattr(param, "grad_sample"),
                            "Per-sample gradients haven't been erased "
                            f"for {param_name}",
                        )
            else:
                # no step yet: the per-sample gradients must still be there
                for param_name, param in self.model.named_parameters():
                    if param.requires_grad:
                        self.assertTrue(
                            hasattr(param, "grad_sample"),
                            "Per-sample gradients aren't accumulated "
                            f"for {param_name}",
                        )

    def test_summed_grad_erased(self):
        """
        Calling optimizer.step() should erase any accumulated clipped
        gradients.
        """
        self.setUp_privacy_engine(2 * self.BATCH_SIZE)

        for idx, (x, y) in enumerate(self.dl):
            self._forward_backward(x, y)

            # perform a virtual step for each mini-batch
            # this will accumulate clipped gradients in each parameter's
            # `summed_grad` field.
            self.optimizer.virtual_step()

            # accumulate gradients for two mini-batches to form an
            # effective batch of size `2*BATCH_SIZE`. Once an effective batch
            # has been accumulated, we call `optimizer.step()` to compute the
            # average gradient for the entire batch. This should erase the
            # `summed_grad` fields for each parameter.
            if (idx + 1) % 2 == 0:
                # take a step. The clipper will compute the mean gradient
                # for the entire effective batch and populate each parameter's
                # `.grad` field.
                self.optimizer.step()

                for param_name, param in self.model.named_parameters():
                    if param.requires_grad:
                        self.assertFalse(
                            hasattr(param, "summed_grad"),
                            "Accumulated clipped gradients haven't been erased "
                            f"for {param_name}",
                        )
            else:
                # no step yet: the clipped gradients must still be there
                for param_name, param in self.model.named_parameters():
                    if param.requires_grad:
                        self.assertTrue(
                            hasattr(param, "summed_grad"),
                            "Clipped gradients aren't accumulated "
                            f"for {param_name}",
                        )

    def test_throws_wrong_batch_size(self):
        """
        If we accumulate the wrong number of gradients and feed this batch to
        the privacy engine, we expect a failure.
        """
        self.setUp_privacy_engine(2 * self.BATCH_SIZE)

        # one mini-batch only: half of the expected effective batch
        x, y = next(iter(self.dl))
        self._forward_backward(x, y)
        self.optimizer.virtual_step()

        # consuming a batch that is smaller than expected should work,
        # but is expected to emit a warning
        with self.assertWarns(Warning):
            self.optimizer.step()
        self.optimizer.zero_grad()

        # three mini-batches: more samples than the expected effective batch
        # (each pass re-creates the iterator, so the first batch is reused)
        for _ in range(3):
            x, y = next(iter(self.dl))
            self._forward_backward(x, y)
            self.optimizer.virtual_step()

        # consuming a larger batch than expected should fail
        with self.assertRaises(ValueError):
            self.optimizer.step()