# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
# -*- coding: utf-8 -*-

import itertools
import unittest
from typing import List, Tuple

import numpy as np
import torch
from fvcore.nn import update_bn_stats
from torch import nn


class TestPreciseBN(unittest.TestCase):
    """Tests for ``update_bn_stats`` on 1d, 2d and 3d batch-norm layers."""

    def setUp(self) -> None:
        # Seed the global RNG so the randomly generated inputs are the same
        # on every run. ``torch.manual_seed`` already seeds the default
        # generator, so no extra get_state/set_rng_state round trip is needed.
        torch.manual_seed(42)

    @staticmethod
    def compute_bn_stats(
        tensors: List[torch.Tensor], dims: List[int]
    ) -> Tuple[np.ndarray, np.ndarray]:
        """
        Given a list of random initialized tensors, compute the mean and
        variance.

        The statistics are computed per tensor over ``dims`` and then averaged
        across the tensors in the list.

        Args:
            tensors (list): list of randomly initialized tensors.
            dims (list): list of dimensions to compute the mean and variance.

        Returns:
            tuple: ``(mean, var)`` as numpy arrays, each being the average over
            ``tensors`` of the per-tensor statistic reduced over ``dims``.
        """
        mean = (
            torch.stack([tensor.mean(dim=dims) for tensor in tensors])
            .mean(dim=0)
            .numpy()
        )
        var = (
            # pyre-ignore
            torch.stack([tensor.var(dim=dims) for tensor in tensors])
            .mean(dim=0)
            .numpy()
        )
        return mean, var

    def test_precise_bn(self) -> None:
        """
        Check that, after ``update_bn_stats``, the running statistics of each
        batch-norm variant match the reference statistics computed by
        ``compute_bn_stats`` and that the affine weight is left untouched.
        """
        # Number of batches to test.
        num_batches = 8
        _bn_types = [nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d]
        # Every dimension except the channel dimension (dim 1).
        _stats_dims = [[0, 2], [0, 2, 3], [0, 2, 3, 4]]
        _input_dims = [(16, 8, 24), (16, 8, 24, 8), (16, 8, 4, 12, 6)]
        # Guard against the fixture lists drifting out of sync; ``zip`` would
        # otherwise silently drop the unmatched entries.
        self.assertEqual(len(_bn_types), len(_stats_dims))
        self.assertEqual(len(_bn_types), len(_input_dims))

        for bn, stats_dim, input_dim in zip(_bn_types, _stats_dims, _input_dims):
            with self.subTest(bn=bn.__name__):
                model = bn(input_dim[1])
                model.train()
                tensors = [torch.randn(input_dim) for _ in range(num_batches)]
                mean, var = TestPreciseBN.compute_bn_stats(tensors, stats_dim)

                # ``Tensor.numpy()`` shares storage with the tensor, so take a
                # clone first. Without it ``old_weight`` would alias the live
                # parameter and the "weight unchanged" check below could never
                # fail.
                old_weight = model.weight.detach().clone().numpy()

                update_bn_stats(
                    model, itertools.cycle(tensors), num_batches * 100
                )

                self.assertTrue(np.allclose(model.running_mean.numpy(), mean))
                self.assertTrue(np.allclose(model.running_var.numpy(), var))
                self.assertTrue(
                    np.allclose(model.weight.detach().numpy(), old_weight)
                )

    def test_precise_bn_insufficient_data(self) -> None:
        """
        Check that ``update_bn_stats`` raises ``AssertionError`` when asked for
        20 iterations from a data source that only yields 10 batches.
        """
        input_dim = (16, 32, 24, 24)
        model = nn.BatchNorm2d(input_dim[1])
        model.train()
        tensor = torch.randn(input_dim)
        with self.assertRaises(AssertionError):
            update_bn_stats(model, itertools.repeat(tensor, 10), 20)