Python torch.sqrt() Examples

The following are 30 code examples of torch.sqrt(). You may also want to check out all available functions/classes of the module torch, or try the search function.
Example #1
Source File:    From AerialDetection with Apache License 2.0
def map_roi_levels(self, rois, num_levels):
    """Map rotated RoIs (RRoIs) to feature-pyramid levels by their scale.

    The scale of an RRoI is ``sqrt(w * h)``. Its level is
    ``floor(log2(scale / finest_scale))``, clamped to ``[0, num_levels - 1]``.
    With ``num_levels == 4`` this gives:

    - scale < finest_scale * 2: level 0
    - finest_scale * 2 <= scale < finest_scale * 4: level 1
    - finest_scale * 4 <= scale < finest_scale * 8: level 2
    - scale >= finest_scale * 8: level 3

    Args:
        rois (Tensor): Input RRoIs, shape (k, 6), laid out as
            (index, x, y, w, h, angle).
        num_levels (int): Total level number.

    Returns:
        Tensor: Level index (0-based, ``torch.long``) of each RoI, shape (k, ).
    """
    scale = torch.sqrt(rois[:, 3] * rois[:, 4])
    # The 1e-6 avoids log2(0) = -inf for zero-area boxes.
    target_lvls = torch.floor(torch.log2(scale / self.finest_scale + 1e-6))
    target_lvls = target_lvls.clamp(min=0, max=num_levels - 1).long()
    return target_lvls
Example #2
Source File:    From cascade-rcnn_Pytorch with MIT License
def forward(self, input1):
    """Turn per-pixel affine coefficients into a normalized spherical grid.

    ``self.grid3d`` is broadcast over the batch. It is combined with three
    4-wide slices of ``input1``'s last dimension (``[0:4]``, ``[4:8]``,
    ``[8:]``) to form points (x, y, z). Each point is returned as
    ``(theta, phi)``, both in [-1, 1]:
    ``theta = acos(z / r) / (pi / 2) - 1`` and ``phi = atan2(y, x) / pi``.

    Assumes (TODO confirm against the caller) that ``self.grid3d`` is
    (H, W, 4) and ``input1`` is (B, H, W, 12).

    Side effect: stores the batched grid in ``self.batchgrid3d``.

    Returns:
        Tensor of shape (B, H, W, 2) with ``theta`` and ``phi`` on the last dim.
    """
    batch_size = input1.size(0)
    # Broadcast the fixed grid over the batch, on the input's device/dtype.
    self.batchgrid3d = self.grid3d.unsqueeze(0).expand(
        torch.Size([batch_size]) + self.grid3d.size()).to(input1)

    # keepdim=True keeps a trailing size-1 axis, so theta/phi can be
    # concatenated along dim 3 below (old torch kept it by default).
    x = torch.sum(self.batchgrid3d * input1[:, :, :, 0:4], 3, keepdim=True)
    y = torch.sum(self.batchgrid3d * input1[:, :, :, 4:8], 3, keepdim=True)
    z = torch.sum(self.batchgrid3d * input1[:, :, :, 8:], 3, keepdim=True)
    # +1e-5 avoids dividing by zero for points at the origin.
    r = torch.sqrt(x ** 2 + y ** 2 + z ** 2) + 1e-5

    theta = torch.acos(z / r) / (np.pi / 2) - 1
    # atan2 performs the quadrant correction that the original
    # ``atan(y / x) + pi * (...)`` expression implemented by hand.
    phi = torch.atan2(y, x) / np.pi

    return torch.cat([theta, phi], 3)
Example #3
Source File:    From pruning_yolov3 with GNU General Public License v3.0
def fuse_conv_and_bn(conv, bn):
    """Fold a BatchNorm2d (using its running statistics) into the preceding Conv2d.

    Each output channel of ``conv`` is scaled by
    ``bn.weight / sqrt(bn.running_var + bn.eps)``. The bias becomes
    ``scale * conv_bias + (bn.bias - scale * bn.running_mean)``.

    Args:
        conv: the ``torch.nn.Conv2d`` to fuse. Its stride, padding, dilation,
            groups and padding mode are preserved.
        bn: the ``torch.nn.BatchNorm2d`` that follows ``conv``.

    Returns:
        A new ``torch.nn.Conv2d`` with a bias, on ``conv``'s device and dtype.
    """
    with torch.no_grad():
        # Same geometry as `conv`, but always with a bias: BN contributes one
        # even when the conv had none.
        fusedconv = torch.nn.Conv2d(conv.in_channels,
                                    conv.out_channels,
                                    kernel_size=conv.kernel_size,
                                    stride=conv.stride,
                                    padding=conv.padding,
                                    dilation=conv.dilation,
                                    groups=conv.groups,
                                    padding_mode=conv.padding_mode,
                                    bias=True).to(device=conv.weight.device,
                                                  dtype=conv.weight.dtype)

        # prepare filters: per-output-channel scaling as a diagonal matmul
        w_conv = conv.weight.clone().view(conv.out_channels, -1)
        w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))
        fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.size()))

        # prepare spatial bias
        if conv.bias is not None:
            b_conv = conv.bias
        else:
            b_conv = torch.zeros(conv.weight.size(0), device=conv.weight.device,
                                 dtype=conv.weight.dtype)
        b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps))
        # The conv bias passes through the BN scale too, so it must be scaled.
        fusedconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn)

        return fusedconv
Example #4
Source File:    From JEM with Apache License 2.0
def forward(self, x, y):
    """Class-conditional instance normalization.

    ``x`` is normalized with ``self.instance_norm``. The per-sample spatial
    means of ``x`` (standardized across channels) are added back, scaled by
    ``alpha``. The result is then modulated by ``gamma`` and, when
    ``self.bias`` is set, shifted by ``beta``. ``gamma``/``alpha``(/``beta``)
    come from chunking ``self.embed(y)`` along its last dimension.

    Assumes (TODO confirm) that ``x`` is (B, C, H, W) with
    ``C == self.num_features``.
    """
    means = torch.mean(x, dim=(2, 3))
    m = torch.mean(means, dim=-1, keepdim=True)
    v = torch.var(means, dim=-1, keepdim=True)
    means = (means - m) / (torch.sqrt(v + 1e-5))
    h = self.instance_norm(x)

    if self.bias:
        gamma, alpha, beta = self.embed(y).chunk(3, dim=-1)
        h = h + means[..., None, None] * alpha[..., None, None]
        out = gamma.view(-1, self.num_features, 1, 1) * h + beta.view(-1, self.num_features, 1, 1)
    else:
        gamma, alpha = self.embed(y).chunk(2, dim=-1)
        h = h + means[..., None, None] * alpha[..., None, None]
        out = gamma.view(-1, self.num_features, 1, 1) * h
    return out
Example #5
Source File:    From torch-toolbox with BSD 3-Clause "New" or "Revised" License
def evo_norm(x, prefix, running_var, v, weight, bias,
             training, momentum, eps=0.1, groups=32):
    """Apply an EvoNorm-style normalization-activation.

    ``prefix == 'b0'`` uses the batch-variance formula
    ``x / max(sqrt(var + eps), v * x + instance_std(x, eps)) * weight + bias``.
    During training it also updates ``running_var`` in place as an exponential
    moving average: ``running_var = momentum * running_var + (1 - momentum) * var``.
    Any other prefix uses the sample-based formula (presumably EvoNorm-S0)
    ``x * sigmoid(v * x) / group_std(x, groups, eps) * weight + bias``.
    When ``v`` is None both variants reduce to the affine ``x * weight + bias``.

    ``instance_std`` / ``group_std`` are helpers defined elsewhere.
    """
    if prefix == 'b0':
        if training:
            var = torch.var(x, dim=(0, 2, 3), keepdim=True)
            # EMA update outside autograd, so the buffer never holds a graph
            # reference. The old value must decay by `momentum`; otherwise the
            # buffer just accumulates.
            with torch.no_grad():
                running_var.mul_(momentum)
                running_var.add_((1 - momentum) * var)
        else:
            var = running_var
        if v is not None:
            den = torch.max((var + eps).sqrt(), v * x + instance_std(x, eps))
            x = x / den * weight + bias
        else:
            x = x * weight + bias
    else:
        if v is not None:
            x = x * torch.sigmoid(v * x) / group_std(x, groups, eps) * weight + bias
        else:
            x = x * weight + bias

    return x
Example #6
Source File:    From deep-learning-note with MIT License
def batch_norm(is_training, X, gamma, beta, moving_mean, moving_var, eps, momentum):
    """Batch normalization for 2-D (fully connected) or 4-D (convolutional) input.

    In inference mode ``X`` is normalized with ``moving_mean``/``moving_var``.
    In training mode it uses the batch statistics, and the moving averages
    are updated as ``momentum * old + (1 - momentum) * batch_stat``.

    Returns:
        ``(Y, moving_mean, moving_var)`` where ``Y = gamma * X_hat + beta``.
    """
    # Training and inference modes use different statistics.
    if not is_training:
        # Inference: normalize with the moving averages passed in.
        X_hat = (X - moving_mean) / torch.sqrt(moving_var + eps)
    else:
        assert len(X.shape) in (2, 4)
        if len(X.shape) == 2:
            # Fully connected layer: per-feature mean and variance over the batch.
            mean = X.mean(dim=0)
            var = ((X - mean) ** 2).mean(dim=0)
        else:
            # 2-D convolution (4-D input): per-channel (dim 1) statistics over the
            # batch and spatial dims; keepdim so they broadcast against X.
            mean = X.mean(dim=(0, 2, 3), keepdim=True)
            var = ((X - mean) ** 2).mean(dim=(0, 2, 3), keepdim=True)
        # Training: normalize with the current batch statistics.
        X_hat = (X - mean) / torch.sqrt(var + eps)
        # Update the moving averages of mean and variance.
        moving_mean = momentum * moving_mean + (1.0 - momentum) * mean
        moving_var = momentum * moving_var + (1.0 - momentum) * var
    Y = gamma * X_hat + beta  # scale and shift
    return Y, moving_mean, moving_var
Example #7
Source File:    From JEM with Apache License 2.0
def cond_samples(f, replay_buffer, args, device, fresh=False):
    """Classify replay-buffer samples and save image grids grouped by class.

    Samples are classified in chunks of 100 with ``f.classify``. The count per
    predicted class (0-9) is printed. Then, for each i in 0..99, the i-th
    slice of up to 10 samples from every class is concatenated and saved to
    ``{args.save_dir}/samples_{i}.png`` (empty grids are skipped).
    With ``fresh=True`` a new buffer comes from ``uncond_samples`` instead.
    """
    sqrt = lambda x: int(t.sqrt(t.Tensor([x])))
    plot = lambda p, x: tv.utils.save_image(t.clamp(x, -1, 1), p, normalize=True, nrow=sqrt(x.size(0)))

    if fresh:
        replay_buffer = uncond_samples(f, args, device, save=False)
    n_it = replay_buffer.size(0) // 100
    all_y = []
    for i in range(n_it):
        x = replay_buffer[i * 100: (i + 1) * 100].to(device)
        y = f.classify(x).max(1)[1]
        all_y.append(y)

    # Labels were produced on `device`; move them to the buffer's device so
    # they can be used as a boolean mask on it.
    all_y = t.cat(all_y, 0).to(replay_buffer.device)
    # Only the first n_it * 100 samples were classified, so mask exactly
    # those rows (the buffer may have a partial trailing chunk).
    classified = replay_buffer[:all_y.size(0)]
    each_class = [classified[all_y == cls] for cls in range(10)]
    print([len(c) for c in each_class])
    for i in range(100):
        this_im = []
        for cls in range(10):
            this_l = each_class[cls][i * 10: (i + 1) * 10]
            this_im.append(this_l)
        this_im = t.cat(this_im, 0)
        if this_im.size(0) > 0:
            plot('{}/samples_{}.png'.format(args.save_dir, i), this_im)
Example #8
Source File:    From H3DNet with MIT License
def forward(ctx, unknown, known):
    # type: (Any, torch.Tensor, torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]
    r"""
    Find the three nearest neighbors of unknown in known

    Parameters
    ----------
    unknown : torch.Tensor
        (B, n, 3) tensor of query (unknown) points
    known : torch.Tensor
        (B, m, 3) tensor of reference (known) points

    Returns
    -------
    dist : torch.Tensor
        (B, n, 3) l2 distance to the three nearest neighbors
    idx : torch.Tensor
        (B, n, 3) index of 3 nearest neighbors
    """
    # The extension returns squared distances (presumably; hence the sqrt).
    dist2, idx = _ext.three_nn(unknown, known)

    return torch.sqrt(dist2), idx
Example #9
Source File:    From DIB-R with MIT License
def sample(verts, faces, num=10000, ret_choice = False):
    """Sample points on a triangle mesh, uniformly with respect to surface area.

    Args:
        verts: vertex tensor with 3 coordinate columns, indexed by ``faces``.
        faces: integer tensor whose first three columns index ``verts``.
        num: number of points to draw.
        ret_choice: when True, also return the sampled face index per point.

    Returns:
        ``points`` of shape (num, 3), or ``(points, choices)`` if ``ret_choice``.
    """
    device = verts.device
    # Uniform(0, 1) on the mesh's device/dtype (previously hard-coded to CUDA).
    dist_uni = torch.distributions.Uniform(
        torch.tensor([0.0], device=device, dtype=verts.dtype),
        torch.tensor([1.0], device=device, dtype=verts.dtype))
    # Two edge vectors per face; half the norm of their cross product is the area.
    x1, x2, x3 = torch.split(torch.index_select(verts, 0, faces[:, 0]) - torch.index_select(verts, 0, faces[:, 1]), 1, dim=1)
    y1, y2, y3 = torch.split(torch.index_select(verts, 0, faces[:, 1]) - torch.index_select(verts, 0, faces[:, 2]), 1, dim=1)
    a = (x2 * y3 - x3 * y2) ** 2
    b = (x3 * y1 - x1 * y3) ** 2
    c = (x1 * y2 - x2 * y1) ** 2
    Areas = torch.sqrt(a + b + c) / 2
    Areas = Areas / torch.sum(Areas)
    # Pick faces with probability proportional to their area.
    cat_dist = torch.distributions.Categorical(Areas.view(-1))
    choices = cat_dist.sample((num,))
    select_faces = faces[choices]
    xs = torch.index_select(verts, 0, select_faces[:, 0])
    ys = torch.index_select(verts, 0, select_faces[:, 1])
    zs = torch.index_select(verts, 0, select_faces[:, 2])
    # sqrt(u) makes the barycentric sample uniform over each triangle.
    u = torch.sqrt(dist_uni.sample((num,)))
    v = dist_uni.sample((num,))
    points = (1 - u) * xs + (u * (1 - v)) * ys + u * v * zs
    if ret_choice:
        return points, choices
    return points
Example #10
Source File:    From prunnable-layers-pytorch with GNU General Public License v3.0
def forward(self, x):
    """Run ``self.features`` and then ``self.classifier`` on ``x``.

    When ``self.convert_to_onnx`` is set, the classifier's layers are applied
    one by one and the BatchNorm1d at index 1 is expanded manually, using its
    running statistics, weight and bias. This sidesteps a resize that caffe2
    export requires. Assumes (TODO confirm) that ``self.classifier`` has
    exactly five layers, with the BatchNorm1d at index 1.
    """
    x = self.features(x)
    x = x.view(x.size(0), -1)
    if self.convert_to_onnx:
        x = self.classifier[0](x)

        # manually perform 1d batchnorm, caffe2 currently requires a resize,
        # which is hard to squeeze into the exported network
        bn_1d = self.classifier[1]
        numerator = x - bn_1d.running_mean.detach()
        denominator = torch.sqrt(bn_1d.running_var + bn_1d.eps).detach()
        x = numerator / denominator * bn_1d.weight.detach() + bn_1d.bias.detach()

        x = self.classifier[2](x)
        x = self.classifier[3](x)
        x = self.classifier[4](x)
        return x
    return self.classifier(x)
Example #11
Source File:    From connecting_the_dots with MIT License
def tforward(self, disp, edge=None):
    """Gradient-based loss on ``disp``, optionally edge-aware.

    ``grad`` is the magnitude of the two channels returned by ``self.sobel(disp)``.
    With ``edge``, the loss is the mean negative log of a mixture of two
    exponential densities in ``grad`` (scales ``self.b0`` / ``self.b1``),
    weighted by ``1 - edge`` and ``edge``. The pdf is clamped at 1e-4.
    Without ``edge``, it is the mean of ``grad`` clamped to [0, 1].
    """
    grad = self.sobel(disp)
    # +1e-8 keeps the sqrt differentiable where the gradient is zero.
    grad = torch.sqrt(grad[:, 0:1, ...] ** 2 + grad[:, 1:2, ...] ** 2 + 1e-8)
    if edge is not None:
        pdf = (1 - edge) / self.b0 * torch.exp(-torch.abs(grad) / self.b0) + \
            edge / self.b1 * torch.exp(-torch.abs(grad) / self.b1)
        val = torch.mean(-torch.log(pdf.clamp(min=1e-4)))
    else:
        # on qifeng's data we don't have ambient info
        # therefore we supress edge everywhere
        grad = torch.clamp(grad, 0, 1.0)
        val = torch.mean(grad)

    return val
Example #12
Source File:    From mmdetection with Apache License 2.0
def map_roi_levels(self, rois, num_levels):
    """Map rois to corresponding feature levels by scales.

    The scale of a RoI is ``sqrt(width * height)``. Its level is
    ``floor(log2(scale / finest_scale))``, clamped to ``[0, num_levels - 1]``.
    With ``num_levels == 4``:

    - scale < finest_scale * 2: level 0
    - finest_scale * 2 <= scale < finest_scale * 4: level 1
    - finest_scale * 4 <= scale < finest_scale * 8: level 2
    - scale >= finest_scale * 8: level 3

    Args:
        rois (Tensor): Input RoIs, shape (k, 5), laid out as
            (batch_index, x1, y1, x2, y2).
        num_levels (int): Total level number.

    Returns:
        Tensor: Level index (0-based, ``torch.long``) of each RoI, shape (k, ).
    """
    scale = torch.sqrt(
        (rois[:, 3] - rois[:, 1]) * (rois[:, 4] - rois[:, 2]))
    # The 1e-6 avoids log2(0) = -inf for zero-area boxes.
    target_lvls = torch.floor(torch.log2(scale / self.finest_scale + 1e-6))
    target_lvls = target_lvls.clamp(min=0, max=num_levels - 1).long()
    return target_lvls
Example #13
Source File:    From mmdetection with Apache License 2.0
def centerness_target(self, anchors, bbox_targets):
    """Return the centerness of each anchor centre inside its decoded GT box.

    The ground-truth boxes come from ``self.bbox_coder.decode(anchors,
    bbox_targets)``. For the anchor centre, the distances to the left/right
    and top/bottom box edges give two min/max ratios. The result is the
    square root of their product. Only positive samples should be passed;
    the trailing assert fires on any NaN.
    """
    decoded = self.bbox_coder.decode(anchors, bbox_targets)
    center_x = (anchors[:, 2] + anchors[:, 0]) / 2
    center_y = (anchors[:, 3] + anchors[:, 1]) / 2

    horizontal = torch.stack(
        [center_x - decoded[:, 0], decoded[:, 2] - center_x], dim=1)
    vertical = torch.stack(
        [center_y - decoded[:, 1], decoded[:, 3] - center_y], dim=1)

    ratio_x = horizontal.min(dim=-1).values / horizontal.max(dim=-1).values
    ratio_y = vertical.min(dim=-1).values / vertical.max(dim=-1).values
    centerness = torch.sqrt(ratio_x * ratio_y)
    assert not torch.isnan(centerness).any()
    return centerness
Example #14
Source File:    From mmdetection with Apache License 2.0
def centerness_target(self, pos_bbox_targets):
    """Compute centerness targets.

    Centerness is ``sqrt(min(l, r) / max(l, r) * min(t, b) / max(t, b))``,
    with columns 0/2 as the left/right distances and 1/3 as top/bottom.

    Args:
        pos_bbox_targets (Tensor): BBox targets of positive bboxes in shape
            (num_pos, 4)

    Returns:
        Tensor: Centerness target, shape (num_pos, ).
    """
    # only calculate pos centerness targets, otherwise there may be nan
    left_right = pos_bbox_targets[:, [0, 2]]
    top_bottom = pos_bbox_targets[:, [1, 3]]
    centerness_targets = (
        left_right.min(dim=-1)[0] / left_right.max(dim=-1)[0]) * (
            top_bottom.min(dim=-1)[0] / top_bottom.max(dim=-1)[0])
    return torch.sqrt(centerness_targets)
Example #15
Source File:    From comet-commonsense with Apache License 2.0
def _attn(self, q, k, v, sequence_mask):
        w = torch.matmul(q, k)
        if self.scale:
            w = w / math.sqrt(v.size(-1))

        b_subset = self.b[:, :, :w.size(-2), :w.size(-1)]

        if sequence_mask is not None:
            b_subset = b_subset * sequence_mask.view(
                sequence_mask.size(0), 1, -1)
            b_subset = b_subset.permute(1, 0, 2, 3)

        w = w * b_subset + -1e9 * (1 - b_subset)
        w = nn.Softmax(dim=-1)(w)
        w = self.attn_dropout(w)
        return torch.matmul(w, v) 
Example #16
Source File:    From AerialDetection with Apache License 2.0
def map_roi_levels(self, rois, num_levels):
    """Map rois to corresponding feature levels by scales.

    The scale of a RoI is ``sqrt((x2 - x1 + 1) * (y2 - y1 + 1))``; the +1
    treats the coordinates as inclusive pixel indices. Its level is
    ``floor(log2(scale / finest_scale))``, clamped to ``[0, num_levels - 1]``.
    With ``num_levels == 4``:

    - scale < finest_scale * 2: level 0
    - finest_scale * 2 <= scale < finest_scale * 4: level 1
    - finest_scale * 4 <= scale < finest_scale * 8: level 2
    - scale >= finest_scale * 8: level 3

    Args:
        rois (Tensor): Input RoIs, shape (k, 5), laid out as
            (batch_index, x1, y1, x2, y2).
        num_levels (int): Total level number.

    Returns:
        Tensor: Level index (0-based, ``torch.long``) of each RoI, shape (k, ).
    """
    scale = torch.sqrt(
        (rois[:, 3] - rois[:, 1] + 1) * (rois[:, 4] - rois[:, 2] + 1))
    # The 1e-6 avoids log2(0) = -inf for degenerate boxes.
    target_lvls = torch.floor(torch.log2(scale / self.finest_scale + 1e-6))
    target_lvls = target_lvls.clamp(min=0, max=num_levels - 1).long()
    return target_lvls
Example #17
Source File:    From Semantic-Aware-Scene-Recognition with MIT License
def step(self, closure):
        """Take one optimizer step whose size is derived from the loss value.

        The step size is ``obj / (sqrt(grad_sqrd_norm) + self.eps)``, where
        ``obj`` is the closure's loss plus a weight-decay term. Each parameter
        is then updated through a momentum buffer ``self.state[p]['v']``, and
        the step size is stored in ``self.gamma``.

        NOTE(review): several operands were lost when this snippet was
        extracted from its source (the lines ending in ``* ** 2``, ``+= wd *``,
        ``+= ** 2``, ``step_size / L *`` and the lines starting with ``+=``),
        so it does not parse as written. TODO restore them from upstream.
        """

        obj = float(closure())

        # Weight decay: add a 0.5 * wd * (lost operand) ** 2 term to the
        # objective; presumably also adds wd * param to each gradient.
        for group in self.param_groups:
            wd = group['weight_decay']
            if wd:
                for p in group['params']:
                    # NOTE(review): operand before `** 2` missing (presumably a parameter norm).
                    obj += 0.5 * wd * ** 2
                    # NOTE(review): target and right operand missing (presumably grad += wd * param).
           += wd *

        # Accumulate a squared norm over all params (operand lost; presumably each gradient's norm).
        grad_sqrd_norm = 0
        for group in self.param_groups:
            for p in group['params']:
                grad_sqrd_norm += ** 2

        # Loss-scaled step size; eps guards against a zero gradient norm.
        step_size = float(obj / (torch.sqrt(grad_sqrd_norm) + self.eps))

        # Momentum update per group: v <- mu * v - (step_size / L) * <lost operand>,
        # then (presumably) v is added to the parameter.
        for group in self.param_groups:
            L = group['L']
            mu = group['momentum']
            for p in group['params']:
                v = self.state[p]['v']
                v *= mu
                v -= step_size / L *
       += v

        self.gamma = step_size 
Example #18
Source File:    From FormulaNet with BSD 3-Clause "New" or "Revised" License
def l2norm2d(inputs, k):
    """L2-normalize ``inputs`` along dimension ``k``.

    Args:
        inputs: tensor to normalize.
        k: dimension to normalize over.

    Returns:
        ``inputs`` divided by its L2 norm along ``k``. A 1e-12 is added to the
        norm to avoid division by zero.
    """
    # keepdim=True keeps the reduced axis so the norm broadcasts back over
    # `inputs`. Without it, expand_as aligns the wrong axes: it errors, or
    # silently mis-scales when the sizes happen to match.
    norm = torch.sqrt(torch.sum(inputs * inputs, k, keepdim=True)) + 1e-12
    return inputs / norm.expand_as(inputs)
Example #19
Source File:    From connecting_the_dots with MIT License
def tforward(self, x):
    """Apply the x/y gradient filters to ``x`` after replicate-padding it by 2.

    Returns the gradient magnitude ``sqrt(gx^2 + gy^2 + 1e-8)`` when
    ``self.norm`` is set. Otherwise returns ``gx`` and ``gy`` concatenated
    along dim 1.
    """
    x = F.pad(x, (2, 2, 2, 2), "replicate")
    gx = self.conv_x(x)
    gy = self.conv_y(x)
    if self.norm:
        return torch.sqrt(gx ** 2 + gy ** 2 + 1e-8)
    return torch.cat((gx, gy), dim=1)
Example #20
Source File:    From connecting_the_dots with MIT License
def tforward(self, data):
    """Local contrast normalization with a (2 * radius + 1)^2 box window.

    The local mean and std come from ``self.box_conv`` applied to ``data``
    and ``data ** 2``. ``self.epsilon`` is added to the std.

    Returns:
        ``((data - local_mean) / std, std)``.
    """
    window = (2 * self.radius + 1) ** 2
    boxs = self.box_conv(data)
    avgs = boxs / window
    boxs_2n = self.box_conv(data ** 2)

    # E[x^2] - E[x]^2 can dip just below zero from floating-point
    # cancellation on flat regions; clamp so the sqrt never returns NaN.
    var = (boxs_2n / window - avgs ** 2).clamp(min=0)
    stds = torch.sqrt(var + 1e-6)
    stds = stds + self.epsilon

    return (data - avgs) / stds, stds
Example #21
Source File:    From Res2Net-maskrcnn with MIT License
def __call__(self, boxlists):
            boxlists (list[BoxList])
        # Compute level ids
        s = torch.sqrt(cat([boxlist.area() for boxlist in boxlists]))

        # Eqn.(1) in FPN paper
        target_lvls = torch.floor(self.lvl0 + torch.log2(s / self.s0 + self.eps))
        target_lvls = torch.clamp(target_lvls, min=self.k_min, max=self.k_max)
        return - self.k_min 
Example #22
Source File:    From AdaptiveWingLoss with Apache License 2.0
def forward(self, input_tensor):
    """Concatenate normalized coordinate channels (CoordConv) to the input.

    Args:
        input_tensor: shape(batch, channel, x_dim, y_dim)

    Returns:
        Tensor of shape (batch, channel + 2, x_dim, y_dim): the input, then the
        x and y coordinates scaled to [-1, 1]. When ``self.with_r`` is set, a
        radius channel is appended as well (channel + 3).
    """
    batch_size, _, x_dim, y_dim = input_tensor.size()

    xx_channel = torch.arange(x_dim).repeat(1, y_dim, 1)
    yy_channel = torch.arange(y_dim).repeat(1, x_dim, 1).transpose(1, 2)

    # .float() forces true division on the integer arange on every torch
    # version; max(..., 1) avoids 0/0 = NaN for a size-1 axis.
    xx_channel = xx_channel.float() / max(x_dim - 1, 1)
    yy_channel = yy_channel.float() / max(y_dim - 1, 1)

    xx_channel = xx_channel * 2 - 1
    yy_channel = yy_channel * 2 - 1

    # Match the input's dtype and device up front so every tensor
    # concatenated below (including the radius channel) agrees with it.
    xx_channel = xx_channel.repeat(batch_size, 1, 1, 1).transpose(2, 3).to(input_tensor)
    yy_channel = yy_channel.repeat(batch_size, 1, 1, 1).transpose(2, 3).to(input_tensor)

    ret = torch.cat([input_tensor, xx_channel, yy_channel], dim=1)

    if self.with_r:
        # NOTE(review): the coordinates are centred on 0, so the -0.5 offset
        # puts the radius origin off-centre; kept as in the original.
        rr = torch.sqrt(torch.pow(xx_channel - 0.5, 2) + torch.pow(yy_channel - 0.5, 2))
        ret = torch.cat([ret, rr], dim=1)

    return ret
Example #23
Source File:    From pointnet-registration-framework with MIT License
def log(g):
    """Logarithm map of a batch of 3x3 rotation matrices.

    ``g`` is viewed as ``(-1, 3, 3)``. The angle is ``t = acos((trace - 1) / 2)``.
    Where ``|sinc1(t)| > eps`` the log matrix is ``(R - R^T) / (2 * sinc1(t))``.
    Otherwise (the ``t ~ pi`` case) the vector is rebuilt from the diagonal of
    ``(R + I) * t^2 / 2``, with signs taken from its off-diagonal entries.
    The result is passed through ``vec``.

    NOTE(review): relies on helpers defined elsewhere. Presumably ``btrace`` is
    a batched trace, ``sinc1`` is sin(t)/t, ``mat`` is vector -> skew matrix and
    ``vec`` is skew matrix -> vector; confirm against their definitions.
    """
    eps = 1.0e-7
    R = g.view(-1, 3, 3)
    tr = btrace(R)
    # Rounding can push (tr - 1) / 2 just outside [-1, 1], where acos is NaN.
    c = ((tr - 1) / 2).clamp(min=-1.0, max=1.0)
    t = torch.acos(c)
    sc = sinc1(t)
    idx0 = (torch.abs(sc) <= eps)
    idx1 = (torch.abs(sc) > eps)
    sc = sc.view(-1, 1, 1)

    X = torch.zeros_like(R)
    if idx1.any():
        X[idx1] = (R[idx1] - R[idx1].transpose(1, 2)) / (2 * sc[idx1])

    if idx0.any():
        # t[idx0] == math.pi: here (R + I) * t^2 / 2 is ideally w w^T, so its
        # diagonal holds w_i^2 and the off-diagonals give the relative signs.
        t2 = t[idx0] ** 2
        A = (R[idx0] + torch.eye(3).type_as(R).unsqueeze(0)) * t2.view(-1, 1, 1) / 2
        # Clamp tiny negative rounding errors so sqrt does not return NaN.
        aw1 = torch.sqrt(A[:, 0, 0].clamp(min=0))
        aw2 = torch.sqrt(A[:, 1, 1].clamp(min=0))
        aw3 = torch.sqrt(A[:, 2, 2].clamp(min=0))
        sgn_3 = torch.sign(A[:, 0, 2])
        sgn_3[sgn_3 == 0] = 1
        sgn_23 = torch.sign(A[:, 1, 2])
        sgn_23[sgn_23 == 0] = 1
        sgn_2 = sgn_23 * sgn_3
        w1 = aw1
        w2 = aw2 * sgn_2
        w3 = aw3 * sgn_3
        w = torch.stack((w1, w2, w3), dim=-1)
        W = mat(w)
        X[idx0] = W

    x = vec(X.view_as(g))
    return x
Example #24
Source File:    From audio with BSD 2-Clause "Simplified" License
def create_dct(
        n_mfcc: int,
        n_mels: int,
        norm: Optional[str]
) -> Tensor:
    r"""Create a DCT transformation matrix with shape (``n_mels``, ``n_mfcc``),
    normalized depending on norm.

    Args:
        n_mfcc (int): Number of mfc coefficients to retain
        n_mels (int): Number of mel filterbanks
        norm (str or None): Norm to use (either 'ortho' or None)

    Returns:
        Tensor: The transformation matrix, to be right-multiplied to
        row-wise data of size (``n_mels``, ``n_mfcc``).
    """
    # DCT-II basis: cos(pi / N * (n + 0.5) * k)
    n = torch.arange(float(n_mels))
    k = torch.arange(float(n_mfcc)).unsqueeze(1)
    dct = torch.cos(math.pi / float(n_mels) * (n + 0.5) * k)  # size (n_mfcc, n_mels)
    if norm is None:
        dct *= 2.0
    else:
        assert norm == "ortho"
        # Orthonormal scaling: the k = 0 row gets an extra 1/sqrt(2).
        dct[0] *= 1.0 / math.sqrt(2.0)
        dct *= math.sqrt(2.0 / float(n_mels))
    return dct.t()
Example #25
Source File:    From JEM with Apache License 2.0
def forward(self, x, y):
    """Class-conditional ActNorm with data-dependent initialization.

    After initialization it returns ``x * scale + bias``, with per-class
    ``scale``/``bias`` taken from the two halves of ``self.embed(y)``.

    On the first call (``self.init`` falsy), every class's embedding is
    initialized so that this batch's output has per-channel zero mean and
    unit variance. It then sets ``self.init = True`` and re-runs itself.
    """
    if self.init:
        scale, bias = self.embed(y).chunk(2, dim=-1)
        return x * scale[:, :, None, None] + bias[:, :, None, None]

    with torch.no_grad():
        m, v = torch.mean(x, dim=(0, 2, 3)), torch.var(x, dim=(0, 2, 3))
        std = torch.sqrt(v + 1e-5)
        scale_init = 1. / std
        bias_init = -1. * m / std
        # First num_features columns hold the scale, the rest the bias,
        # matching the chunk(2) split above.
        self.embed.weight[:, :self.num_features] = scale_init[None].repeat(self.num_classes, 1)
        self.embed.weight[:, self.num_features:] = bias_init[None].repeat(self.num_classes, 1)
    self.init = True
    return self(x, y)
Example #26
Source File:    From cmrc2019 with Creative Commons Attribution Share Alike 4.0 International
def forward(self, x):
    """Layer-normalize ``x`` over its last dimension, then apply weight and bias.

    Uses the biased variance (mean of squared deviations), with
    ``self.variance_epsilon`` added inside the square root.
    """
    centered = x - x.mean(-1, keepdim=True)
    variance = centered.pow(2).mean(-1, keepdim=True)
    normalized = centered / torch.sqrt(variance + self.variance_epsilon)
    return self.weight * normalized + self.bias
Example #27
Source File:    From audio with BSD 2-Clause "Simplified" License
def treble_biquad(
        waveform: Tensor,
        sample_rate: int,
        gain: float,
        central_freq: float = 3000,
        Q: float = 0.707
) -> Tensor:
    r"""Design a treble tone-control effect.  Similar to SoX implementation.

    Args:
        waveform (Tensor): audio waveform of dimension of `(..., time)`
        sample_rate (int): sampling rate of the waveform, e.g. 44100 (Hz)
        gain (float): desired gain at the boost (or attenuation) in dB.
        central_freq (float, optional): central frequency (in Hz). (Default: ``3000``)
        Q (float, optional): (Default: ``0.707``).

    Returns:
        Tensor: Waveform of dimension of `(..., time)`
    """
    # High-shelf biquad coefficients (RBJ Audio-EQ-Cookbook form).
    w0 = 2 * math.pi * central_freq / sample_rate
    alpha = math.sin(w0) / 2 / Q
    A = math.exp(gain / 40 * math.log(10))  # 10 ** (gain / 40)

    temp1 = 2 * math.sqrt(A) * alpha
    temp2 = (A - 1) * math.cos(w0)
    temp3 = (A + 1) * math.cos(w0)

    b0 = A * ((A + 1) + temp2 + temp1)
    b1 = -2 * A * ((A - 1) + temp3)
    b2 = A * ((A + 1) + temp2 - temp1)
    a0 = (A + 1) - temp2 + temp1
    a1 = 2 * ((A - 1) - temp3)
    a2 = (A + 1) - temp2 - temp1

    # Normalize by a0, consistent with bass_biquad.
    return biquad(waveform, b0 / a0, b1 / a0, b2 / a0, a0 / a0, a1 / a0, a2 / a0)
Example #28
Source File:    From audio with BSD 2-Clause "Simplified" License
def bass_biquad(
        waveform: Tensor,
        sample_rate: int,
        gain: float,
        central_freq: float = 100,
        Q: float = 0.707
) -> Tensor:
    r"""Design a bass tone-control effect.  Similar to SoX implementation.

    Args:
        waveform (Tensor): audio waveform of dimension of `(..., time)`
        sample_rate (int): sampling rate of the waveform, e.g. 44100 (Hz)
        gain (float): desired gain at the boost (or attenuation) in dB.
        central_freq (float, optional): central frequency (in Hz). (Default: ``100``)
        Q (float, optional): (Default: ``0.707``).

    Returns:
        Tensor: Waveform of dimension of `(..., time)`
    """
    # Low-shelf biquad coefficients (RBJ Audio-EQ-Cookbook form).
    w0 = 2 * math.pi * central_freq / sample_rate
    alpha = math.sin(w0) / 2 / Q
    A = math.exp(gain / 40 * math.log(10))  # 10 ** (gain / 40)

    temp1 = 2 * math.sqrt(A) * alpha
    temp2 = (A - 1) * math.cos(w0)
    temp3 = (A + 1) * math.cos(w0)

    b0 = A * ((A + 1) - temp2 + temp1)
    b1 = 2 * A * ((A - 1) - temp3)
    b2 = A * ((A + 1) - temp2 - temp1)
    a0 = (A + 1) + temp2 + temp1
    a1 = -2 * ((A - 1) + temp3)
    a2 = (A + 1) + temp2 - temp1

    # Coefficients are normalized so that a0 == 1.
    return biquad(waveform, b0 / a0, b1 / a0, b2 / a0, a0 / a0, a1 / a0, a2 / a0)
Example #29
Source File:    From JEM with Apache License 2.0
def uncond_samples(f, args, device, save=True):
    """Run ``sample_q`` for ``args.n_sample_steps`` steps from uniform noise.

    The replay buffer starts as ``args.buffer_size`` images of shape 3x32x32,
    drawn uniformly from [-1, 1]. When ``save`` is set, the samples of every
    ``args.print_every``-th step are written to
    ``{args.save_dir}/samples_{step}.png``.

    Returns:
        The replay buffer tensor.
    """
    def sqrt(x):
        return int(t.sqrt(t.Tensor([x])))

    def plot(p, x):
        return tv.utils.save_image(t.clamp(x, -1, 1), p, normalize=True, nrow=sqrt(x.size(0)))

    replay_buffer = t.FloatTensor(args.buffer_size, 3, 32, 32).uniform_(-1, 1)
    for step in range(args.n_sample_steps):
        samples = sample_q(args, device, f, replay_buffer)
        if step % args.print_every == 0 and save:
            plot('{}/samples_{}.png'.format(args.save_dir, step), samples)
    return replay_buffer
Example #30
Source File:    From prunnable-layers-pytorch with GNU General Public License v3.0
def __estimate_taylor_importance(self, _, grad_input, grad_output):
        # skip dim=1, its the dim for depth
        n_batch, _, n_x, n_y = self.__recent_activations.size()
        n_parameters = n_batch * n_x * n_y

        estimates = self.__recent_activations.mul_(grad_output[0]) \
            .sum(dim=3) \
            .sum(dim=2) \
            .sum(dim=0) \

        # normalization
        self.taylor_estimates = torch.abs(estimates) / torch.sqrt(torch.sum(estimates * estimates))
        del estimates, self.__recent_activations
        self.__recent_activations = None