Python torch.var() Examples
The following are 30 code examples of torch.var(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions and classes of the torch module, or try the search function.
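Before the project examples, here is a minimal, self-contained sketch of torch.var() itself (the tensor values and variable names are illustrative, not taken from any of the projects below). By default torch.var() returns the unbiased sample variance; a dim argument restricts the reduction to specific dimensions, and unbiased=False yields the biased population variance that several of the instance-norm examples below reconstruct by rescaling with (n - 1) / n.

import torch

x = torch.tensor([[1.0, 2.0, 3.0],
                  [4.0, 5.0, 6.0]])

# Unbiased (sample) variance over all elements -- the default divides by N - 1.
v_all = torch.var(x)

# Variance along a dimension, keeping the reduced dim for broadcasting.
v_rows = torch.var(x, dim=1, keepdim=True)  # shape [2, 1]

# Biased (population) variance, divided by N instead of N - 1.
v_biased = torch.var(x, dim=1, unbiased=False)

# The two conventions differ by a factor of (n - 1) / n.
n = x.size(1)
assert torch.allclose(v_biased, torch.var(x, dim=1) * (n - 1) / n)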
Example #1
Source File: test_crypten.py From CrypTen with MIT License | 6 votes |
def test_rand(self):
    """Tests uniform random variable generation on [0, 1)"""
    for size in [(10,), (10, 10), (10, 10, 10)]:
        randvec = crypten.rand(*size)
        self.assertTrue(randvec.size() == size, "Incorrect size")
        tensor = randvec.get_plain_text()
        self.assertTrue(
            (tensor >= 0).all() and (tensor < 1).all(), "Invalid values"
        )

    randvec = crypten.rand(int(1e6)).get_plain_text()
    mean = torch.mean(randvec)
    var = torch.var(randvec)
    self.assertTrue(torch.isclose(mean, torch.Tensor([0.5]), rtol=1e-3, atol=1e-3))
    self.assertTrue(
        torch.isclose(var, torch.Tensor([1.0 / 12]), rtol=1e-3, atol=1e-3)
    )
Example #2
Source File: functional.py From torch-toolbox with BSD 3-Clause "New" or "Revised" License | 6 votes |
def evo_norm(x, prefix, running_var, v, weight, bias, training, momentum, eps=0.1, groups=32):
    if prefix == 'b0':
        if training:
            var = torch.var(x, dim=(0, 2, 3), keepdim=True)
            running_var.mul_(momentum)
            running_var.add_((1 - momentum) * var)
        else:
            var = running_var
        if v is not None:
            den = torch.max((var + eps).sqrt(), v * x + instance_std(x, eps))
            x = x / den * weight + bias
        else:
            x = x * weight + bias
    else:
        if v is not None:
            x = x * torch.sigmoid(v * x) / group_std(x, groups, eps) * weight + bias
        else:
            x = x * weight + bias
    return x
Example #3
Source File: test_distance_weighted_miner.py From pytorch-metric-learning with MIT License | 6 votes |
def test_distance_weighted_miner(self):
    embedding_angles = torch.arange(0, 180)
    embeddings = torch.tensor([c_f.angle_to_coord(a) for a in embedding_angles],
                              requires_grad=True, dtype=torch.float)  # 2D embeddings
    labels = torch.randint(low=0, high=2, size=(180,))
    a, _, n = lmu.get_all_triplets_indices(labels)
    all_an_dist = torch.nn.functional.pairwise_distance(embeddings[a], embeddings[n], 2)
    min_an_dist = torch.min(all_an_dist)
    for non_zero_cutoff_int in range(5, 15):
        non_zero_cutoff = (float(non_zero_cutoff_int) / 10.) - 0.01
        miner = DistanceWeightedMiner(0, non_zero_cutoff)
        a, p, n = miner(embeddings, labels)
        anchors, positives, negatives = embeddings[a], embeddings[p], embeddings[n]
        an_dist = torch.nn.functional.pairwise_distance(anchors, negatives, 2)
        self.assertTrue(torch.max(an_dist) <= non_zero_cutoff)
        an_dist_var = torch.var(an_dist)
        an_dist_mean = torch.mean(an_dist)
        target_var = ((non_zero_cutoff - min_an_dist) ** 2) / 12  # variance formula for uniform distribution
        target_mean = (non_zero_cutoff - min_an_dist) / 2
        self.assertTrue(torch.abs(an_dist_var - target_var) / target_var < 0.1)
        self.assertTrue(torch.abs(an_dist_mean - target_mean) / target_mean < 0.1)
Example #4
Source File: norms.py From JEM with Apache License 2.0 | 6 votes |
def forward(self, x, y):
    means = torch.mean(x, dim=(2, 3))
    m = torch.mean(means, dim=-1, keepdim=True)
    v = torch.var(means, dim=-1, keepdim=True)
    means = (means - m) / (torch.sqrt(v + 1e-5))
    h = self.instance_norm(x)

    if self.bias:
        gamma, alpha, beta = self.embed(y).chunk(3, dim=-1)
        h = h + means[..., None, None] * alpha[..., None, None]
        out = gamma.view(-1, self.num_features, 1, 1) * h + beta.view(-1, self.num_features, 1, 1)
    else:
        gamma, alpha = self.embed(y).chunk(2, dim=-1)
        h = h + means[..., None, None] * alpha[..., None, None]
        out = gamma.view(-1, self.num_features, 1, 1) * h
    return out
Example #5
Source File: cond_refinenet_dilated.py From ncsn with GNU General Public License v3.0 | 6 votes |
def forward(self, x, y):
    means = torch.mean(x, dim=(2, 3))
    m = torch.mean(means, dim=-1, keepdim=True)
    v = torch.var(means, dim=-1, keepdim=True)
    means = (means - m) / (torch.sqrt(v + 1e-5))
    h = self.instance_norm(x)

    if self.bias:
        gamma, alpha, beta = self.embed(y).chunk(3, dim=-1)
        h = h + means[..., None, None] * alpha[..., None, None]
        out = gamma.view(-1, self.num_features, 1, 1) * h + beta.view(-1, self.num_features, 1, 1)
    else:
        gamma, alpha = self.embed(y).chunk(2, dim=-1)
        h = h + means[..., None, None] * alpha[..., None, None]
        out = gamma.view(-1, self.num_features, 1, 1) * h
    return out
Example #6
Source File: normalization.py From nsf with MIT License | 6 votes |
def forward(self, inputs, context=None):
    if inputs.dim() != 2:
        raise ValueError('Expected 2-dim inputs, got inputs of shape: {}'.format(inputs.shape))

    if self.training:
        mean, var = inputs.mean(0), inputs.var(0)
        self.running_mean.mul_(1 - self.momentum).add_(mean * self.momentum)
        self.running_var.mul_(1 - self.momentum).add_(var * self.momentum)
    else:
        mean, var = self.running_mean, self.running_var

    outputs = self.weight * ((inputs - mean) / torch.sqrt((var + self.eps))) + self.bias

    logabsdet_ = torch.log(self.weight) - 0.5 * torch.log(var + self.eps)
    logabsdet = torch.sum(logabsdet_) * torch.ones(inputs.shape[0])

    return outputs, logabsdet
Example #7
Source File: linear.py From biva-pytorch with MIT License | 6 votes |
def init_parameters(self, x, init_scale=0.05, eps=1e-8):
    if self.weightnorm:
        # initial values
        self.linear._parameters['weight_v'].data.normal_(mean=0, std=init_scale)
        self.linear._parameters['weight_g'].data.fill_(1.)
        self.linear._parameters['bias'].data.fill_(0.)
        init_scale = .01
        # data dependent init
        x = self.linear(x)
        m_init, v_init = torch.mean(x, 0), torch.var(x, 0)
        scale_init = init_scale / torch.sqrt(v_init + eps)
        self.linear._parameters['weight_g'].data = self.linear._parameters['weight_g'].data * scale_init.view(
            self.linear._parameters['weight_g'].data.size())
        self.linear._parameters['bias'].data = self.linear._parameters['bias'].data - m_init * scale_init

        self.initialized = True + self.initialized

        return scale_init[None, :] * (x - m_init[None, :])
Example #8
Source File: linear.py From biva-pytorch with MIT License | 6 votes |
def init_parameters(self, x, init_scale=0.05, eps=1e-8):
    if self.weightnorm:
        # initial values
        self.linear._parameters['weight_v'].data.normal_(mean=0, std=init_scale)
        self.linear._parameters['weight_g'].data.fill_(1.)
        self.linear._parameters['bias'].data.fill_(0.)
        init_scale = .01
        # data dependent init
        x = self.linear(x)
        m_init, v_init = torch.mean(x, 0), torch.var(x, 0)
        scale_init = init_scale / torch.sqrt(v_init + eps)
        self.linear._parameters['weight_g'].data = self.linear._parameters['weight_g'].data * scale_init.view(
            self.linear._parameters['weight_g'].data.size())
        self.linear._parameters['bias'].data = self.linear._parameters['bias'].data - m_init * scale_init

        self.initialized = True + self.initialized

        return scale_init[None, :] * (x - m_init[None, :])
Example #9
Source File: ppo.py From surreal with MIT License | 6 votes |
def _value_loss(self, obs, returns):
    """
    Computes the loss with current data. Also returns a dictionary of statistics
    which includes value loss and explained variance.
    return: surreal.utils.pytorch.GPUVariable, dict

    Args:
        obs: batch of observations in form of (batch_size, obs_dim)
        returns: batch of N-step return estimate (batch_size,)

    Returns:
        loss: Variable for loss
        stats: dictionary of recorded statistics
    """
    values = self.model.forward_critic(obs, self.cells)
    if len(values.size()) == 3:
        values = values.squeeze(2)
    explained_var = 1 - torch.var(returns - values) / torch.var(returns)
    loss = (values - returns).pow(2).mean()
    stats = {
        '_val_loss': loss.item(),
        '_val_explained_var': explained_var.item()
    }
    return loss, stats
Example #10
Source File: virtualbatchnorm.py From torchgan with MIT License | 6 votes |
def _normalize(self, x, mu, var):
    r"""Normalizes the tensor ``x`` using the statistics ``mu`` and ``var``.

    Args:
        x (torch.Tensor): The Tensor to be normalized.
        mu (torch.Tensor): Mean using which the Tensor is to be normalized.
        var (torch.Tensor): Variance used in the normalization of ``x``.

    Returns:
        Normalized Tensor ``x``.
    """
    std = torch.sqrt(self.eps + var)
    x = (x - mu) / std
    sizes = list(x.size())
    for dim, i in enumerate(x.size()):
        if dim != 1:
            sizes[dim] = 1
    scale = self.scale.view(*sizes)
    bias = self.bias.view(*sizes)
    return x * scale + bias
Example #11
Source File: WKPooling.py From sentence-transformers with Apache License 2.0 | 6 votes |
def unify_sentence(self, sentence_feature, one_sentence_embedding):
    """
    Unify Sentence By Token Importance
    """
    sent_len = one_sentence_embedding.size()[0]
    var_token = torch.zeros(sent_len, device=one_sentence_embedding.device)
    for token_index in range(sent_len):
        token_feature = sentence_feature[:, token_index, :]
        sim_map = self.cosine_similarity_torch(token_feature)
        var_token[token_index] = torch.var(sim_map.diagonal(-1))

    var_token = var_token / torch.sum(var_token)
    sentence_embedding = torch.mv(one_sentence_embedding.t(), var_token)
    return sentence_embedding
Example #12
Source File: test_ops.py From tntorch with GNU Lesser General Public License v3.0 | 5 votes |
def test_stats():

    def check():
        x = t.torch()
        assert tn.relative_error(tn.mean(t), torch.mean(x)) <= 1e-3
        assert tn.relative_error(tn.var(t), torch.var(x)) <= 1e-3
        assert tn.relative_error(tn.norm(t), torch.norm(x)) <= 1e-3

    shape = [8] * 4
    for i in range(100):
        t = random_format(shape)
        check()
Example #13
Source File: test.py From pytorch-mono-depth with MIT License | 5 votes |
def __init__(self, mean, var):
    self.mean = mean
    self.variance = var
Example #14
Source File: test_operators.py From onnx-fb-universe with MIT License | 5 votes |
def test_symbolic_override(self):
    """Lifted from fast-neural-style: custom implementation of instance norm
    to be mapped to ONNX operator"""

    class CustomInstanceNorm(torch.nn.Module):
        def __init__(self, dim, eps=1e-9):
            super(CustomInstanceNorm, self).__init__()
            self.scale = nn.Parameter(torch.FloatTensor(dim).uniform_())
            self.shift = nn.Parameter(torch.FloatTensor(dim).zero_())
            self.eps = eps

        def forward(self, x):
            return self._run_forward(x, self.scale, self.shift, eps=self.eps)

        @staticmethod
        @torch.onnx.symbolic_override(
            lambda g, x, scale, shift, eps: g.op(
                'InstanceNormalization', x, scale, shift, epsilon_f=eps)
        )
        def _run_forward(x, scale, shift, eps):
            # since we hand-roll instance norm it doesn't perform well all in fp16
            n = x.size(2) * x.size(3)
            t = x.view(x.size(0), x.size(1), n)
            mean = torch.mean(t, 2).unsqueeze(2).unsqueeze(3).expand_as(x)
            # Calculate the biased var. torch.var returns unbiased var
            var = torch.var(t, 2).unsqueeze(2).unsqueeze(3).expand_as(x) * ((float(n) - 1) / float(n))
            scale_broadcast = scale.unsqueeze(1).unsqueeze(1).unsqueeze(0)
            scale_broadcast = scale_broadcast.expand_as(x)
            shift_broadcast = shift.unsqueeze(1).unsqueeze(1).unsqueeze(0)
            shift_broadcast = shift_broadcast.expand_as(x)
            out = (x - mean) / torch.sqrt(var + eps)
            out = out * scale_broadcast + shift_broadcast
            return out

    instnorm = CustomInstanceNorm(10)
    x = Variable(torch.randn(2, 10, 32, 32))
    self.assertONNX(instnorm, x)
Example #15
Source File: test_operators.py From onnx-fb-universe with MIT License | 5 votes |
def assertONNX(self, f, args, params=tuple(), **kwargs):
    if isinstance(f, nn.Module):
        m = f
    else:
        m = FuncModule(f, params)
    onnx_model_pb = export_to_string(m, args, **kwargs)
    model_def = self.assertONNXExpected(onnx_model_pb)
    if _onnx_test:
        test_function = inspect.stack()[1][0].f_code.co_name
        test_name = test_function[0:4] + "_operator" + test_function[4:]
        output_dir = os.path.join(test_onnx_common.pytorch_operator_dir, test_name)
        # Assume:
        #     1) the old test should be deleted before the test.
        #     2) only one assertONNX in each test, otherwise it will override the data.
        assert not os.path.exists(output_dir), "{} should not exist!".format(output_dir)
        os.makedirs(output_dir)
        with open(os.path.join(output_dir, "model.onnx"), 'wb') as file:
            file.write(model_def.SerializeToString())
        data_dir = os.path.join(output_dir, "test_data_set_0")
        os.makedirs(data_dir)
        if isinstance(args, Variable):
            args = (args,)
        for index, var in enumerate(flatten(args)):
            tensor = numpy_helper.from_array(var.data.numpy())
            with open(os.path.join(data_dir, "input_{}.pb".format(index)), 'wb') as file:
                file.write(tensor.SerializeToString())
        outputs = m(*args)
        if isinstance(outputs, Variable):
            outputs = (outputs,)
        for index, var in enumerate(flatten(outputs)):
            tensor = numpy_helper.from_array(var.data.numpy())
            with open(os.path.join(data_dir, "output_{}.pb".format(index)), 'wb') as file:
                file.write(tensor.SerializeToString())
Example #16
Source File: convolution.py From biva-pytorch with MIT License | 5 votes |
def init_parameters(self, x, init_scale=0.05, eps=1e-8):
    self.initialized = True + self.initialized

    if self.weightnorm:
        # initial values
        self.conv._parameters['weight_v'].data.normal_(mean=0, std=init_scale)
        self.conv._parameters['weight_g'].data.fill_(1.)
        self.conv._parameters['bias'].data.fill_(0.)
        init_scale = .01
        # data dependent init
        x = self.conv(x)
        t = x.view(x.size()[0], x.size()[1], -1)
        t = t.permute(0, 2, 1).contiguous()
        t = t.view(-1, t.size()[-1])
        m_init, v_init = torch.mean(t, 0), torch.var(t, 0)
        scale_init = init_scale / torch.sqrt(v_init + eps)
        if self.conv.transposed:
            self.conv._parameters['weight_g'].data = self.conv._parameters['weight_g'].data * scale_init[None, :].view(
                self.conv._parameters['weight_g'].data.size())
            self.conv._parameters['bias'].data = self.conv._parameters['bias'].data - m_init * scale_init
        else:
            self.conv._parameters['weight_g'].data = self.conv._parameters['weight_g'].data * scale_init[:, None].view(
                self.conv._parameters['weight_g'].data.size())
            self.conv._parameters['bias'].data = self.conv._parameters['bias'].data - m_init * scale_init

        return scale_init[None, :, None, None] * (x - m_init[None, :, None, None]) if len(
            self._input_shp) > 3 else scale_init[None, :, None] * (x - m_init[None, :, None])
Example #17
Source File: normalization.py From ffjord with MIT License | 5 votes |
def stable_var(x, mean=None, dim=1):
    if mean is None:
        mean = x.mean(dim, keepdim=True)
    mean = mean.view(-1, 1)
    res = torch.pow(x - mean, 2)
    max_sqr = torch.max(res, dim, keepdim=True)[0]
    var = torch.mean(res / max_sqr, 1, keepdim=True) * max_sqr
    var = var.view(-1)
    # change nan to zero
    var[var != var] = 0
    return var
Example #18
Source File: normalization.py From ffjord with MIT License | 5 votes |
def _forward(self, x, logpx=None):
    c = x.size(1)
    used_mean = self.running_mean.clone().detach()
    used_var = self.running_var.clone().detach()

    if self.training:
        # compute batch statistics
        x_t = x.transpose(0, 1).contiguous().view(c, -1)
        batch_mean = torch.mean(x_t, dim=1)
        batch_var = torch.var(x_t, dim=1)

        # moving average
        if self.bn_lag > 0:
            used_mean = batch_mean - (1 - self.bn_lag) * (batch_mean - used_mean.detach())
            used_mean /= (1. - self.bn_lag**(self.step[0] + 1))
            used_var = batch_var - (1 - self.bn_lag) * (batch_var - used_var.detach())
            used_var /= (1. - self.bn_lag**(self.step[0] + 1))

        # update running estimates
        self.running_mean -= self.decay * (self.running_mean - batch_mean.data)
        self.running_var -= self.decay * (self.running_var - batch_var.data)
        self.step += 1

    # perform normalization
    used_mean = used_mean.view(*self.shape).expand_as(x)
    used_var = used_var.view(*self.shape).expand_as(x)

    y = (x - used_mean) * torch.exp(-0.5 * torch.log(used_var + self.eps))

    if self.affine:
        weight = self.weight.view(*self.shape).expand_as(x)
        bias = self.bias.view(*self.shape).expand_as(x)
        y = y * torch.exp(weight) + bias

    if logpx is None:
        return y
    else:
        return y, logpx - self._logdetgrad(x, used_var).view(x.size(0), -1).sum(1, keepdim=True)
Example #19
Source File: mini_batch_stddev_module.py From pytorch_GAN_zoo with BSD 3-Clause "New" or "Revised" License | 5 votes |
def miniBatchStdDev(x, subGroupSize=4):
    r"""
    Add a minibatch standard deviation channel to the current layer.
    In other words:
        1) Compute the standard deviation of the feature map over the minibatch
        2) Get the mean, over all pixels and all channels, of this value
        3) Expand the layer and concatenate it with the input

    Args:
        - x (tensor): previous layer
        - subGroupSize (int): size of the mini-batches on which the standard deviation
          should be computed
    """
    size = x.size()
    subGroupSize = min(size[0], subGroupSize)
    if size[0] % subGroupSize != 0:
        subGroupSize = size[0]
    G = int(size[0] / subGroupSize)
    if subGroupSize > 1:
        y = x.view(-1, subGroupSize, size[1], size[2], size[3])
        y = torch.var(y, 1)
        y = torch.sqrt(y + 1e-8)
        y = y.view(G, -1)
        y = torch.mean(y, 1).view(G, 1)
        y = y.expand(G, size[2] * size[3]).view((G, 1, 1, size[2], size[3]))
        y = y.expand(G, subGroupSize, -1, -1, -1)
        y = y.contiguous().view((-1, 1, size[2], size[3]))
    else:
        y = torch.zeros(x.size(0), 1, x.size(2), x.size(3), device=x.device)

    return torch.cat([x, y], dim=1)
Example #20
Source File: networks.py From graphx-conv with MIT License | 5 votes |
def transform(self, pc_feat, img_feat, fc):
    pc_feat = (pc_feat - T.mean(pc_feat, -1, keepdim=True)) / T.sqrt(T.var(pc_feat, -1, keepdim=True) + 1e-8)
    mean, var = T.mean(img_feat, (2, 3)), T.var(T.flatten(img_feat, 2), 2)
    output = (pc_feat + nnt.utils.dimshuffle(mean, (0, 'x', 1))) * T.sqrt(
        nnt.utils.dimshuffle(var, (0, 'x', 1)) + 1e-8)
    return fc(output)
Example #21
Source File: misc.py From pydlt with BSD 3-Clause Clear License | 5 votes |
def sub_var(x, width=5):
    """Calculates variance of a one dimensional Tensor or Array every `width` elements.

    Args:
        x (Tensor or Array): 1D Tensor or array.
        width (int, optional): Width of the kernel.
    """
    if len(x) >= width:
        if is_array(x):
            return np.var(slide_window_(x, width, width), -1)
        else:
            return torch.var(slide_window_(x, width, width), -1)
    else:
        return x.var()
Example #22
Source File: signal_to_noise_ratio_losses.py From pytorch-metric-learning with MIT License | 5 votes |
def SNR_dist(x, y, dim):
    return torch.var(x - y, dim=dim) / torch.var(x, dim=dim)
Example #23
Source File: test_signal_to_noise_ratio_losses.py From pytorch-metric-learning with MIT License | 5 votes |
def test_snr_contrastive_loss(self):
    pos_margin, neg_margin, regularizer_weight = 0, 0.1, 0.1
    loss_func = SignalToNoiseRatioContrastiveLoss(pos_margin=pos_margin,
                                                  neg_margin=neg_margin,
                                                  regularizer_weight=regularizer_weight)

    embedding_angles = [0, 20, 40, 60, 80]
    embeddings = torch.tensor([c_f.angle_to_coord(a) for a in embedding_angles],
                              requires_grad=True, dtype=torch.float)  # 2D embeddings
    labels = torch.LongTensor([0, 0, 1, 1, 2])

    loss = loss_func(embeddings, labels)
    loss.backward()

    pos_pairs = [(0, 1), (1, 0), (2, 3), (3, 2)]
    neg_pairs = [(0, 2), (0, 3), (0, 4), (1, 2), (1, 3), (1, 4), (2, 0), (2, 1),
                 (2, 4), (3, 0), (3, 1), (3, 4), (4, 0), (4, 1), (4, 2), (4, 3)]

    correct_pos_loss = 0
    correct_neg_loss = 0
    num_non_zero = 0
    for a, p in pos_pairs:
        anchor, positive = embeddings[a], embeddings[p]
        curr_loss = torch.relu(torch.var(anchor - positive) / torch.var(anchor) - pos_margin)
        correct_pos_loss += curr_loss
        if curr_loss > 0:
            num_non_zero += 1
    if num_non_zero > 0:
        correct_pos_loss /= num_non_zero

    num_non_zero = 0
    for a, n in neg_pairs:
        anchor, negative = embeddings[a], embeddings[n]
        curr_loss = torch.relu(neg_margin - torch.var(anchor - negative) / torch.var(anchor))
        correct_neg_loss += curr_loss
        if curr_loss > 0:
            num_non_zero += 1
    if num_non_zero > 0:
        correct_neg_loss /= num_non_zero

    reg_loss = torch.mean(torch.abs(torch.sum(embeddings, dim=1)))

    correct_total = correct_pos_loss + correct_neg_loss + regularizer_weight * reg_loss
    self.assertTrue(torch.isclose(loss, correct_total))
Example #24
Source File: metrics.py From vel with MIT License | 5 votes |
def _value_function(self, batch_info):
    values = batch_info['values']
    rewards = batch_info['rewards']

    explained_variance = 1 - torch.var(rewards - values) / torch.var(rewards)
    return explained_variance.item()
Example #25
Source File: functions.py From vel with MIT License | 5 votes |
def explained_variance(returns, values):
    """ Calculate how much of the variance in returns the values explain """
    exp_var = 1 - torch.var(returns - values) / torch.var(returns)
    return exp_var.item()
Example #26
Source File: transformer_net.py From fast-neural-style with MIT License | 5 votes |
def forward(self, x):
    n = x.size(2) * x.size(3)
    t = x.view(x.size(0), x.size(1), n)
    mean = torch.mean(t, 2).unsqueeze(2).unsqueeze(3).expand_as(x)
    # Calculate the biased var. torch.var returns unbiased var
    var = torch.var(t, 2).unsqueeze(2).unsqueeze(3).expand_as(x) * ((n - 1) / float(n))
    scale_broadcast = self.scale.unsqueeze(1).unsqueeze(1).unsqueeze(0)
    scale_broadcast = scale_broadcast.expand_as(x)
    shift_broadcast = self.shift.unsqueeze(1).unsqueeze(1).unsqueeze(0)
    shift_broadcast = shift_broadcast.expand_as(x)
    out = (x - mean) / torch.sqrt(var + self.eps)
    out = out * scale_broadcast + shift_broadcast
    return out
Example #27
Source File: norms.py From asteroid with MIT License | 5 votes |
def forward(self, x):
    """ Applies forward pass. Works for any input size > 2D.

    Args:
        x (:class:`torch.Tensor`): Shape `[batch, chan, *]`

    Returns:
        :class:`torch.Tensor`: gLN_x `[batch, chan, *]`
    """
    dims = list(range(1, len(x.shape)))
    mean = x.mean(dim=dims, keepdim=True)
    var = torch.pow(x - mean, 2).mean(dim=dims, keepdim=True)
    return self.apply_gain_and_bias((x - mean) / (var + EPS).sqrt())
Example #28
Source File: norms.py From asteroid with MIT License | 5 votes |
def forward(self, x):
    """ Applies forward pass. Works for any input size > 2D.

    Args:
        x (:class:`torch.Tensor`): `[batch, chan, *]`

    Returns:
        :class:`torch.Tensor`: chanLN_x `[batch, chan, *]`
    """
    mean = torch.mean(x, dim=1, keepdim=True)
    var = torch.var(x, dim=1, keepdim=True, unbiased=False)
    return self.apply_gain_and_bias((x - mean) / (var + EPS).sqrt())
Example #29
Source File: tasnet.py From demucs with MIT License | 5 votes |
def forward(self, y):
    """
    Args:
        y: [M, N, K], M is batch size, N is channel size, K is length
    Returns:
        cLN_y: [M, N, K]
    """
    mean = torch.mean(y, dim=1, keepdim=True)  # [M, 1, K]
    var = torch.var(y, dim=1, keepdim=True, unbiased=False)  # [M, 1, K]
    cLN_y = self.gamma * (y - mean) / torch.pow(var + EPS, 0.5) + self.beta
    return cLN_y
Example #30
Source File: tasnet.py From demucs with MIT License | 5 votes |
def forward(self, y):
    """
    Args:
        y: [M, N, K], M is batch size, N is channel size, K is length
    Returns:
        gLN_y: [M, N, K]
    """
    # TODO: in torch 1.0, torch.mean() supports a dim list
    mean = y.mean(dim=1, keepdim=True).mean(dim=2, keepdim=True)  # [M, 1, 1]
    var = (torch.pow(y - mean, 2)).mean(dim=1, keepdim=True).mean(dim=2, keepdim=True)
    gLN_y = self.gamma * (y - mean) / torch.pow(var + EPS, 0.5) + self.beta
    return gLN_y