import torch
import numpy
import math
import torch.nn as nn
from options.train_options import TrainOptions

opt = TrainOptions().parse()
# Bit widths taken from the command-line options: weights (bitsW),
# input activations (bitsI) and gradients (bitsG).
bitsW = opt.bits_w
bitsI = opt.bits_i
bitsG = opt.bits_g


# Scale
def S(bits):
    return 2.0 ** (bits - 1)


# Clip function
def C(x, bits=32):
    if bits > 15 or bits == 1 or bits == 2:
        delta = 0.
    else:
        delta = 1. / S(bits)
    MAX = +1 - delta
    MIN = -1 + delta
    x = torch.clamp(x, MIN, MAX)
    return x


# Quantization function Q(x, k)
def Q(x, bits):
    if bits > 15:
        return x
    elif bits == 1:
        return torch.sign(x)
    elif bits == 2:
        return torch.round(x)
    else:
        SCALE = S(bits)
        return torch.round(x * SCALE) / SCALE


### QuanInput2d ###
class QuanInput(torch.autograd.Function):
    '''
    Quantize the input activations. The backward pass is a straight-through
    estimator: gradients are passed through unchanged except where the input
    lies outside [-1, 1], where they are zeroed.
    '''
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return Q(C(x, bitsI), bitsI)

    @staticmethod
    def backward(ctx, grad_output):
        x, = ctx.saved_tensors
        grad_input = grad_output.clone()
        grad_input[x.ge(1)] = 0
        grad_input[x.le(-1)] = 0
        return grad_input


class QuanInput2d(nn.Module):
    def __init__(self):
        super(QuanInput2d, self).__init__()
        self.layer_type = 'QuanInput2d'

    def forward(self, x):
        x = QuanInput.apply(x)
        return x


### QuanOp() ###
class QuanOp():
    def __init__(self, model):
        # count the number of Conv2d
        count_Conv2d = 0
        for m in model.modules():
            if isinstance(m, nn.Conv2d):
                count_Conv2d = count_Conv2d + 1

        start_range = 1
        end_range = count_Conv2d - 2  # leave out the first and the last conv2d
        self.bin_range = numpy.linspace(start_range, end_range, end_range - start_range + 1)\
            .astype('int').tolist()
        self.num_of_params = len(self.bin_range)
        self.saved_params = []
        self.target_params = []
        self.target_modules = []
        index = -1
        for m in model.modules():
            if isinstance(m, nn.Conv2d):
                index = index + 1
                if index in self.bin_range:
                    tmp = m.weight.data.clone()
                    self.saved_params.append(tmp)
                    self.target_modules.append(m.weight)

    def quantization(self):
        self.meancenterConvParams()
        self.clampConvParams()
        self.save_params()
        self.quantizeConvParams()

    def meancenterConvParams(self):
        # subtract the per-filter mean over the input-channel dimension
        for index in range(self.num_of_params):
            s = self.target_modules[index].data.size()
            negMean = self.target_modules[index].data.mean(1, True).\
                mul(-1).expand_as(self.target_modules[index].data)
            self.target_modules[index].data = self.target_modules[index].data.add(negMean)

    def clampConvParams(self):
        for index in range(self.num_of_params):
            self.target_modules[index].data = C(self.target_modules[index].data, bitsG)

    def save_params(self):
        for index in range(self.num_of_params):
            self.saved_params[index].copy_(Q(self.target_modules[index].data, bitsG))

    def quantizeConvParams(self):
        for index in range(self.num_of_params):
            if bitsW == 1:
                # binarize: sign(W) scaled by the per-filter mean absolute value
                n = self.target_modules[index].data[0].nelement()
                s = self.target_modules[index].data.size()
                m = self.target_modules[index].data.norm(1, 3, True)\
                    .sum(2, True).sum(1, True).div(n).expand(s)
                m = Q(m, bitsG)
                self.target_modules[index].data = self.target_modules[index].data.sign()\
                    .mul(m)
            elif bitsW == 2:
                # ternarize: threshold at 0.7 * per-filter mean absolute value
                w = self.target_modules[index].data
                n = self.target_modules[index].data[0].nelement()
                s = self.target_modules[index].data.size()
                d = self.target_modules[index].data.norm(1, 3, True)\
                    .sum(2, True).sum(1, True).div(n).mul(0.7)
                wt = w
                for col in range(s[0]):
                    d_col = d[col, 0, 0, 0]
                    wt_neg = w[col, :, :, :].lt(-1.0 * d_col).float().mul(-1)
                    wt_pos = w[col, :, :, :].gt(1.0 * d_col).float()
                    wt[col, :, :, :] = wt_pos.add(wt_neg)
                self.target_modules[index].data = wt.mul(1)
            else:
                self.target_modules[index].data = Q(C(self.target_modules[index].data, bitsW), bitsW)

    def restore(self):
        for index in range(self.num_of_params):
            self.target_modules[index].data.copy_(self.saved_params[index])

    # for gradient
    def updateQuanGradWeight(self):
        for index in range(self.num_of_params):
            if bitsW == 1:
                # gradient for binarized weights: scale-factor term plus sign term
                weight = self.target_modules[index].data
                n = weight[0].nelement()
                s = weight.size()
                m = weight.norm(1, 3, True)\
                    .sum(2, True).sum(1, True).div(n).expand(s)
                m[weight.lt(-1.0)] = 0
                m[weight.gt(1.0)] = 0
                m = Q(m, bitsG)
                m = m.mul(self.target_modules[index].grad.data)
                m_add = weight.sign().mul(self.target_modules[index].grad.data)
                m_add = m_add.sum(3, True)\
                    .sum(2, True).sum(1, True).div(n).expand(s)
                m_add = m_add.mul(weight.sign())
                self.target_modules[index].grad.data = m.add(m_add).mul(1.0 - 1.0 / s[1]).mul(n)
                self.target_modules[index].grad.data = Q(C(self.target_modules[index].grad.data, bitsG), bitsG)
            else:
                self.target_modules[index].grad.data = Q(C(self.target_modules[index].grad.data, bitsG), bitsG)
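

# ---------------------------------------------------------------------------
# Illustrative usage sketch (an assumption, not part of the original training
# code): it shows the intended call order around one training step -- quantize
# the convolution weights before the forward pass, then restore the saved
# full-precision copies and quantize the gradients before the optimizer step.
# `model`, `criterion`, `optimizer`, `data` and `target` are hypothetical
# objects supplied by the caller.
def _example_quantized_train_step(quan_op, model, criterion, optimizer, data, target):
    quan_op.quantization()            # mean-center, clamp, save and quantize conv weights
    output = model(data)              # forward pass runs with quantized weights
    loss = criterion(output, target)
    optimizer.zero_grad()
    loss.backward()
    quan_op.restore()                 # put the saved full-precision weights back
    quan_op.updateQuanGradWeight()    # quantize the accumulated gradients
    optimizer.step()                  # update the full-precision weights
    return loss.item()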