Python torch.chunk() Examples
The following are 30 code examples of torch.chunk().
You can vote up the ones you like or vote down the ones you don't like,
and go to the original project or source file by following the links above each example.
You may also want to check out all available functions/classes of the module torch, or try the search function.
Example #1
Source File: nnet.py From pase with MIT License | 6 votes |
def forward(self, x):
    """Encode ``x`` with the frontend and RNN, then classify with ``self.model``.

    The frontend output is optionally detached (when ``self.ft_fe`` is false)
    and optionally normalised by ``self.z_bnorm``. Its dims 1 and 2 are then
    swapped before it is fed to ``self.rnn``.

    Pooling of the RNN output:
    * ``self.return_sequence``: the whole sequence, with dims 1/2 swapped back.
    * bidirectional (``not self.uni``): the feature axis is split into two
      halves. The first half is read at the last time step and the second
      half at the first time step; both are concatenated on dim 1.
    * unidirectional: the last time step only.
    """
    feats_in = self.frontend(x)
    if not self.ft_fe:
        # Keep gradients from reaching the frontend.
        feats_in = feats_in.detach()
    if hasattr(self, 'z_bnorm'):
        feats_in = self.z_bnorm(feats_in)

    seq_out, _ = self.rnn(feats_in.transpose(1, 2))

    if self.return_sequence:
        pooled = seq_out.transpose(1, 2)
    elif self.uni:
        # Single direction: the final step summarises the sequence.
        pooled = seq_out[:, -1, :].unsqueeze(2)
    else:
        # Two directions packed in the feature axis: view it as (2, feats/2).
        bsz, slen, nfeat = seq_out.size()
        halves = torch.chunk(seq_out.view(bsz, slen, 2, nfeat // 2), 2, dim=2)
        last_fwd = halves[0][:, -1, 0, :].unsqueeze(2)
        first_bwd = halves[1][:, 0, 0, :].unsqueeze(2)
        pooled = torch.cat((last_fwd, first_bwd), dim=1)

    return self.model(pooled)
Example #2
Source File: modules.py From pase with MIT License | 6 votes |
def format_frontend_chunk(batch, device='cpu'):
    """Collapse a (possibly dict-shaped) batch into a single input tensor.

    Args:
        batch: either a tensor, or a dict holding at least ``'chunk'`` and
            optionally ``'chunk_ctxt'``, ``'chunk_rand'`` and ``'cchunk'``.
        device: device the returned tensor is moved to (dict input only).

    Returns:
        ``(x, data_fmt)``, where ``data_fmt`` records the layout:
        ``0`` for a non-dict batch, returned unchanged;
        ``1`` for a dict where only ``batch['chunk']`` is used;
        ``N > 1`` for a dict whose ``N`` chunk tensors were concatenated along
        dim 0, in the order chunk, chunk_ctxt, chunk_rand, cchunk.
    """
    if not isinstance(batch, dict):
        return batch, 0
    # Test each key separately. ``'chunk_ctxt' and 'chunk_rand' in batch``
    # only checks 'chunk_rand', because a non-empty string literal is truthy.
    if 'chunk_ctxt' in batch and 'chunk_rand' in batch:
        keys = ('chunk', 'chunk_ctxt', 'chunk_rand', 'cchunk')
        # 'cchunk' is optional, so gather only the keys actually present.
        batches = [batch[k] for k in keys if k in batch]
        x = torch.cat(batches, dim=0).to(device)
        # The number of condensed batches doubles as the format tag.
        return x, len(batches)
    return batch['chunk'].to(device), 1
Example #3
Source File: revnet.py From imgclsmob with MIT License | 6 votes |
def forward(ctx, x, fm, gm, *params):
    """Forward pass of a reversible coupling block (custom autograd Function).

    Splits ``x`` in two along dim 1 and computes ``y1 = x1 + fm(x2)`` and
    ``y2 = x2 + gm(y1)``, returning ``cat((y1, y2), dim=1)``. Only ``x`` and
    ``y`` are saved for backward; ``fm`` and ``gm`` are stashed on ``ctx``.

    Args:
        ctx: autograd context.
        x: input tensor, split in two along dim 1 with torch.chunk.
        fm: first coupling sub-network (callable).
        gm: second coupling sub-network (callable).
        *params: not used in this body. Presumably the sub-networks'
            parameters, passed so autograd tracks them (TODO confirm).

    Returns:
        The coupling output ``y``.
    """
    # No autograd graph is built for the forward computation.
    with torch.no_grad():
        x1, x2 = torch.chunk(x, chunks=2, dim=1)
        x1 = x1.contiguous()
        x2 = x2.contiguous()
        y1 = x1 + fm(x2)
        y2 = x2 + gm(y1)
        y = torch.cat((y1, y2), dim=1)
        # set_() with no arguments swaps each tensor onto empty storage,
        # presumably to release intermediate memory early (TODO confirm).
        x1.set_()
        x2.set_()
        y1.set_()
        y2.set_()
        del x1, x2, y1, y2
    ctx.save_for_backward(x, y)
    ctx.fm = fm
    ctx.gm = gm
    return y
Example #4
Source File: shufflenetv2b.py From imgclsmob with MIT License | 6 votes |
def forward(self, x):
    """Two-branch unit: branches are concatenated on dim 1, then shuffled.

    Without downsampling, ``x`` is split in half along dim 1. The first half
    passes through untouched and the second feeds the main branch. With
    downsampling, the shortcut branch is ``shortcut_conv(shortcut_dconv(x))``
    and the main branch sees all of ``x``.

    The main branch is conv1 -> dconv -> conv2, then optionally ``self.se``.
    When ``self.use_residual`` is set (non-downsampling only), the branch
    input is added back.
    """
    if not self.downsample:
        left, right = torch.chunk(x, chunks=2, dim=1)
    else:
        left = self.shortcut_conv(self.shortcut_dconv(x))
        right = x

    branch = self.conv2(self.dconv(self.conv1(right)))
    if self.use_se:
        branch = self.se(branch)
    if self.use_residual and not self.downsample:
        branch = branch + right

    merged = torch.cat((left, branch), dim=1)
    return self.c_shuffle(merged)
Example #5
Source File: esim.py From video_captioning_rl with MIT License | 6 votes |
def seq2seq_cross_entropy(logits, label, l, chuck=None, sos_truncate=True):
    """Cross-entropy summed over all valid tokens, divided by ``sum(l)``.

    :param logits: [exB, V] : exB = sum(l)
    :param label: padded label batch; sequence i contributes
        ``label[i, :l[i]]`` (packed via ``pack_sequence_for_linear``)
    :param l: [B] : lengths of each sequence
    :param chuck: if truthy, the number of chunks to split logits/labels into
        along dim 0; each chunk's summed loss is accumulated
    :param sos_truncate: not used in this body
    :return: the summed loss multiplied by ``1 / sum(l)``
    """
    packed_label = pack_sequence_for_linear(label, l)
    # ``size_average=False`` is deprecated; ``reduction='sum'`` is its
    # equivalent (sum over all elements, no averaging).
    cross_entropy_loss = functools.partial(F.cross_entropy, reduction='sum')
    total = sum(l)
    assert total == logits.size(0) or packed_label.size(0) == logits.size(0), \
        "logits length mismatch with label length."
    if chuck:
        logits_losses = 0
        for x, y in zip(torch.chunk(logits, chuck, dim=0),
                        torch.chunk(packed_label, chuck, dim=0)):
            logits_losses += cross_entropy_loss(x, y)
        return logits_losses * (1 / total)
    return cross_entropy_loss(logits, packed_label) * (1 / total)
Example #6
Source File: perceptual_loss.py From PerceptualGAN with GNU General Public License v3.0 | 6 votes |
def forward(self, input):
    """Preprocess ``input`` and return the activations of every block.

    ``input`` is cloned, passed through ``self.preprocess_range`` and then
    normalised according to ``self.preprocessing_type``:

    * ``'caffe'``: the three channel slices along dim 1 are reordered from
      (r, g, b) to (b, g, r), scaled by 255, and ``self.vgg_mean`` is subtracted.
    * ``'pytorch'``: ``self.vgg_mean`` is subtracted and the result is divided
      by ``self.vgg_std``.

    Returns:
        A list with the output of each module in ``self.blocks``, applied
        sequentially.
    """
    input = input.clone()
    input = self.preprocess_range(input)

    if self.preprocessing_type == 'caffe':
        r, g, b = torch.chunk(input, 3, dim=1)
        bgr = torch.cat([b, g, r], 1)
        # The Caffe-normalised tensor must replace ``input`` so it actually
        # reaches the blocks; a throwaway local would be silently ignored.
        input = bgr * 255 - self.vgg_mean
    elif self.preprocessing_type == 'pytorch':
        input = input - self.vgg_mean
        input = input / self.vgg_std

    outputs = []
    output = input
    for block in self.blocks:
        output = block(output)
        outputs.append(output)
    return outputs
Example #7
Source File: esim.py From video_captioning_rl with MIT License | 6 votes |
def pack_sequence_for_linear(inputs, lengths, batch_first=True):
    """Concatenate the valid (unpadded) prefix of every sequence along dim 0.

    :param inputs: [B, T, D] if batch_first
    :param lengths: [B] length of each sequence; only ``inputs[i, :lengths[i]]``
        is kept
    :param batch_first: only ``True`` is supported
    :return: tensor of shape [sum(lengths), D]
    :raises NotImplementedError: if ``batch_first`` is false
    """
    if not batch_first:
        # ``raise NotImplemented()`` would fail with a TypeError, because
        # NotImplemented is a constant, not an exception class.
        raise NotImplementedError()
    return torch.cat([inputs[i, :l] for i, l in enumerate(lengths)], 0)
Example #8
Source File: grassdata.py From grass_pytorch with Apache License 2.0 | 6 votes |
def __init__(self, dir, transform=None):
    """Load GRASS training trees from three ``.mat`` files in ``dir``.

    Reads ``box_data.mat`` (key ``boxes``), ``op_data.mat`` (key ``ops``) and
    ``sym_data.mat`` (key ``syms``). Each matrix is split along dim 1 into
    ``num_examples`` pieces, where ``num_examples`` is the number of columns
    of the op matrix. The i-th pieces, transposed, are wrapped into ``Tree``
    objects stored in ``self.trees``.

    Args:
        dir: directory containing the ``.mat`` files.
        transform: stored on ``self.transform``; not applied here.
    """
    self.dir = dir
    box_data = torch.from_numpy(loadmat(self.dir+u'/box_data.mat')[u'boxes']).float()
    op_data = torch.from_numpy(loadmat(self.dir+u'/op_data.mat')[u'ops']).int()
    sym_data = torch.from_numpy(loadmat(self.dir+u'/sym_data.mat')[u'syms']).float()
    num_examples = op_data.size()[1]
    # One chunk per example along the column axis.
    box_data = torch.chunk(box_data, num_examples, 1)
    op_data = torch.chunk(op_data, num_examples, 1)
    sym_data = torch.chunk(sym_data, num_examples, 1)
    self.transform = transform
    self.trees = []
    # ``range`` works on both Python 2 and 3; ``xrange`` is a NameError on 3.
    for i in range(len(op_data)):
        boxes = torch.t(box_data[i])
        ops = torch.t(op_data[i])
        syms = torch.t(sym_data[i])
        tree = Tree(boxes, ops, syms)
        self.trees.append(tree)
Example #9
Source File: classifiers.py From pase with MIT License | 6 votes |
def forward(self, x):
    """Classify ``x`` through (optional frontend) -> DRN -> RNN -> MLP.

    The input is expected as [B, F, T] (per the original comment). Only the
    RNN output at the final step along dim 1 is passed to the MLP.
    """
    # FORWARD THROUGH FRONTEND (optional)
    if self.frontend is not None:
        x = self.frontend(x)
        if not self.ft_fe:
            # Freeze the frontend: stop gradients flowing into it.
            x = x.detach()
    # Pad the last axis by 4 before and 5 after.
    x = F.pad(x, (4, 5))
    # FORWARD THROUGH DRN
    x = self.drn(x)
    # FORWARD THROUGH RNN (swap dims 1 and 2 so time is on dim 1)
    x = x.transpose(1, 2)
    x, _ = self.rnn(x)
    # Keep only the last step as a length-1 slice. This equals the last
    # element of torch.chunk(x, x.shape[1], dim=1) without building T views.
    x = x[:, -1:].transpose(1, 2)
    # FORWARD THROUGH MLP
    return self.mlp(x)
Example #10
Source File: grassdata.py From grass_pytorch with Apache License 2.0 | 6 votes |
def __init__(self, dir, transform=None):
    """Build ``self.trees`` from the box/op/sym ``.mat`` files in ``dir``.

    Each of the three matrices is cut along dim 1 into as many pieces as the
    op matrix has columns. Piece ``i`` of each matrix, transposed, forms one
    ``Tree``. ``transform`` is only stored on ``self.transform``.
    """
    self.dir = dir

    def _load(suffix, key):
        # Read ``self.dir + suffix`` and convert entry ``key`` to a tensor.
        return torch.from_numpy(loadmat(self.dir + suffix)[key])

    box_data = _load('/box_data.mat', 'boxes').float()
    op_data = _load('/op_data.mat', 'ops').int()
    sym_data = _load('/sym_data.mat', 'syms').float()

    num_examples = op_data.size()[1]
    box_chunks, op_chunks, sym_chunks = (
        torch.chunk(mat, num_examples, 1) for mat in (box_data, op_data, sym_data)
    )

    self.transform = transform
    self.trees = []
    for idx, ops_piece in enumerate(op_chunks):
        self.trees.append(Tree(torch.t(box_chunks[idx]),
                               torch.t(ops_piece),
                               torch.t(sym_chunks[idx])))
Example #11
Source File: frontend.py From pase with MIT License | 6 votes |
def forward(self, batch, device=None, mode=None):
    """Run the sinc/conv/resnet frontend on a batch and format its output.

    ``batch`` is flattened by ``format_frontend_chunk``, which also returns a
    layout tag. That tag is handed to ``format_frontend_output`` together
    with the features. Pipeline: sinc -> insert dim 1 -> conv1 -> resnet
    -> conv2 -> squeeze dim 2.
    """
    x, data_fmt = format_frontend_chunk(batch, device)
    feats = self.sinc(x).unsqueeze(1)
    feats = self.resnet(self.conv1(feats))
    h = self.conv2(feats).squeeze(2)
    return format_frontend_output(h, data_fmt, mode)
Example #12
Source File: treelstm.py From treenet with GNU General Public License v3.0 | 6 votes |
def forward(self, inputs, children, arities):
    """Tree-LSTM cell with per-child-position weights.

    Args:
        inputs: node input features, fed to ``wi/wo/wu/wf_net``.
        children: re-iterable sequence of child states. Each is split in
            half along dim 1 into ``(h, c)``.
        arities: not used in this body.

    Returns:
        ``cat([h, c], dim=1)`` for this node.
    """
    i = self.wi_net(inputs)
    o = self.wo_net(inputs)
    u = self.wu_net(inputs)
    f_base = self.wf_net(inputs)
    fc_sum = inputs.new_zeros(self.memory_size)
    # Split every child state once instead of once per (k, l) pair.
    split_children = [torch.chunk(child, 2, dim=1) for child in children]
    for k, (child_h, child_c) in enumerate(split_children):
        i.add_(self.ui_nets[k](child_h))
        o.add_(self.uo_nets[k](child_h))
        u.add_(self.uu_nets[k](child_h))
        # Child k's forget gate sees the hidden state of every child.
        f = f_base
        for l, (other_child_h, _) in enumerate(split_children):
            f = f.add(self.uf_nets[k][l](other_child_h))
        # Rebind rather than call bare ``fc_sum.add(...)``: ``Tensor.add``
        # returns a new tensor, so its result would be discarded. In-place
        # ``add_`` is not an option either, since fc_sum starts as
        # ``new_zeros(self.memory_size)`` and may need to broadcast up.
        fc_sum = fc_sum + torch.sigmoid(f) * child_c
    c = torch.sigmoid(i) * torch.tanh(u) + fc_sum
    h = torch.sigmoid(o) * torch.tanh(c)
    return torch.cat([h, c], dim=1)
Example #13
Source File: train.py From dgl with Apache License 2.0 | 6 votes |
def step(i, j, g, lg, deg_g, deg_lg, pm_pd):
    """Run one optimisation step and report loss, overlap and timings.

    ``i`` and ``j`` are not used in this body. The step relies on the
    module-level names ``dev``, ``model``, ``args``, ``y_list``,
    ``compute_overlap`` and ``optimizer``.

    Returns:
        ``(loss, overlap, forward_seconds, backward_seconds)``.
    """
    deg_g, deg_lg, pm_pd = (t.to(dev) for t in (deg_g, deg_lg, pm_pd))

    start = time.time()
    z = model(g, lg, deg_g, deg_lg, pm_pd)
    t_forward = time.time() - start

    per_graph = th.chunk(z, args.batch_size, 0)
    # Each chunk is scored by its minimum cross-entropy over the candidate
    # labelings in y_list; the result is averaged over the batch size.
    loss = sum(
        min(F.cross_entropy(logits, y) for y in y_list) for logits in per_graph
    ) / args.batch_size
    overlap = compute_overlap(per_graph)

    optimizer.zero_grad()
    start = time.time()
    loss.backward()
    t_backward = time.time() - start
    optimizer.step()
    return loss, overlap, t_forward, t_backward
Example #14
Source File: blow.py From blow with Apache License 2.0 | 6 votes |
def forward(self, h, emb):
    """Apply an ``emb``-conditioned per-example grouped conv, then ``self.net``.

    ``self.adapt_w(emb)`` and ``self.adapt_b(emb)`` supply per-example
    kernels and biases. All examples are convolved in a single ``conv1d``
    call by folding the batch into the channel axis with
    ``groups=sbatch*nsq``.

    Returns:
        ``(s, m)``: the two halves of ``self.net``'s output along dim 1, with
        ``s`` mapped through ``sigmoid(s + 2) + 1e-7``.
    """
    sbatch, nsq, lchunk = h.size()
    h = h.contiguous()
    kernels = self.adapt_w(emb).view(-1, 1, self.kw)
    biases = self.adapt_b(emb).view(-1)
    # Fold the batch into channels so each group pairs one example-channel
    # with its own kernel.
    folded = h.view(1, -1, lchunk)
    conv = torch.nn.functional.conv1d(folded, kernels, bias=biases,
                                      padding=self.kw // 2,
                                      groups=sbatch * nsq)
    h = self.net.forward(conv.view(sbatch, self.ncha, lchunk))
    s, m = torch.chunk(h, 2, dim=1)
    return torch.sigmoid(s + 2) + 1e-7, m
Example #15
Source File: simple.py From kge with MIT License | 6 votes |
def score_emb(self, s_emb, p_emb, o_emb, combine: str):
    """Average two trilinear products over half-split embeddings.

    Every embedding is split in two along dim 1. The "forward" term pairs the
    subject head, predicate first half and object tail. The "backward" term
    pairs the subject tail, predicate second half and object head.

    ``combine`` selects the layout:
    * ``"spo"``: one score per aligned triple.
    * ``"sp_"``: each (s, p) row against every object.
    * ``"_po"``: each (p, o) row against every subject.

    Any other value is delegated to the parent class.
    """
    n = p_emb.size(0)
    s_head, s_tail = torch.chunk(s_emb, 2, dim=1)
    p_fwd, p_bwd = torch.chunk(p_emb, 2, dim=1)
    o_head, o_tail = torch.chunk(o_emb, 2, dim=1)

    if combine == "spo":
        fwd = (s_head * p_fwd * o_tail).sum(dim=1)
        bwd = (s_tail * p_bwd * o_head).sum(dim=1)
    elif combine == "sp_":
        fwd = (s_head * p_fwd).mm(o_tail.transpose(0, 1))
        bwd = (s_tail * p_bwd).mm(o_head.transpose(0, 1))
    elif combine == "_po":
        fwd = (o_tail * p_fwd).mm(s_head.transpose(0, 1))
        bwd = (o_head * p_bwd).mm(s_tail.transpose(0, 1))
    else:
        return super().score_emb(s_emb, p_emb, o_emb, combine)

    combined = fwd + bwd
    return combined.view(n, -1) / 2.0
Example #16
Source File: transformer.py From naru with Apache License 2.0 | 6 votes |
def forward(self, x, query_input=None):
    """Multi-head attention over the column axis.

    Args:
        x: [bs, num cols, d_model] tensor that supplies keys and values, and
            also the queries unless ``query_input`` is given.
        query_input: optional tensor whose projection replaces the queries.

    Returns:
        Tensor of the same shape as ``x``, [bs, num cols, d_model].
    """
    assert x.dim() == 3, x.size()
    bs, ncols, _ = x.size()

    def project(t):
        # [bs, num cols, 3 * H * d_state] -> three per-head tensors (q, k, v).
        return [self._split_heads(part)
                for part in torch.chunk(self.qkv_linear(t), 3, dim=-1)]

    qs, ks, vs = project(x)
    if query_input is not None:
        # TODO: the key/value projections of query_input are wasted work.
        qs = project(query_input)[0]

    attended = self._do_attention(qs, ks, vs, mask=self.attn_mask.to(x.device))
    # [bs, H, num cols, d_state] -> [bs, num cols, H * d_state]
    merged = attended.transpose(1, 2).contiguous().view(bs, ncols, -1)
    return self.linear(merged)
Example #17
Source File: modulated_deform_conv.py From mmcv with Apache License 2.0 | 5 votes |
def forward(self, x):
    """Modulated deformable convolution with predicted offsets and mask.

    ``self.conv_offset(x)`` is split with torch.chunk into three parts along
    dim 1. The first two are concatenated into the sampling offsets; the
    third, passed through a sigmoid, is the modulation mask.
    """
    offset_a, offset_b, raw_mask = torch.chunk(self.conv_offset(x), 3, dim=1)
    offset = torch.cat((offset_a, offset_b), dim=1)
    mask = torch.sigmoid(raw_mask)
    return modulated_deform_conv2d(
        x, offset, mask,
        self.weight, self.bias,
        self.stride, self.padding, self.dilation,
        self.groups, self.deform_groups,
    )
Example #18
Source File: segan.py From tfm-franroldan-wav2pix with GNU General Public License v3.0 | 5 votes |
def forward(self, x):
    """Discriminator pass returning logits plus intermediate activations.

    The input runs through each ``self.disc`` layer in turn; each layer
    returns a pair whose first item is kept. The result is pooled according
    to ``self.pool_type`` (``'rnn'``, ``'conv'`` or ``'none'``) and fed to
    ``self.fc``.

    Returns:
        ``(y, int_act)``: the logits, and a dict of recorded tensors keyed
        ``'h_<i>'``, ``'ln_conv'``, ``'rnn_h'``, ``'avg_conv_h'`` and
        ``'logit'`` (only those produced on the taken path).
    """
    acts = {}
    h = x
    for idx, layer in enumerate(self.disc):
        h, _ = layer(h)
        acts['h_{}'.format(idx)] = h

    pool = self.pool_type
    if pool == 'rnn':
        if hasattr(self, 'ln'):
            h = self.ln(h)
            acts['ln_conv'] = h
        _, state = self.rnn(h.transpose(1, 2))
        # Split state[0] in two along dim 0 (fwd, bwd per the original
        # comment) and join the halves on the feature axis.
        h_fwd, h_bwd = torch.chunk(state[0], 2, 0)
        h = torch.cat((h_fwd, h_bwd), dim=2).squeeze(0)
        acts['rnn_h'] = h
    elif pool == 'conv':
        h = self.pool_conv(h)
        h = h.view(h.size(0), -1)
        acts['avg_conv_h'] = h
    elif pool == 'none':
        h = h.view(h.size(0), -1)

    y = self.fc(h)
    acts['logit'] = y
    return y, acts
Example #19
Source File: deform_conv_module.py From DetNAS with MIT License | 5 votes |
def forward(self, input):
    """Modulated deformable convolution driven by a learned offset/mask head.

    ``self.conv_offset_mask(input)`` is split with torch.chunk into three
    parts along dim 1. The first two are concatenated into the offset tensor;
    the third, squashed by a sigmoid, becomes the modulation mask.
    """
    head_out = self.conv_offset_mask(input)
    part_a, part_b, mask_logits = torch.chunk(head_out, 3, dim=1)
    offset = torch.cat((part_a, part_b), dim=1)
    mask = torch.sigmoid(mask_logits)
    conv_args = (self.weight, self.bias, self.stride, self.padding,
                 self.dilation, self.groups, self.deformable_groups)
    return modulated_deform_conv(input, offset, mask, *conv_args)
Example #20
Source File: treelstm.py From treenet with GNU General Public License v3.0 | 5 votes |
def forward(self, *args, **kwargs):
    """Return only the first half (along dim 1) of the parent cell's output.

    The parent ``forward`` presumably returns ``cat([h, c], dim=1)``, in
    which case this yields the hidden state ``h`` (TODO confirm).
    """
    hidden, _cell = super(TreeLSTM, self).forward(*args, **kwargs).chunk(2, dim=1)
    return hidden
Example #21
Source File: irevnet.py From imgclsmob with MIT License | 5 votes |
def forward(self, x, _):
    """Split ``x`` into two pieces along dim 1; the second argument is ignored.

    Returns:
        ``(x1, x2)`` as produced by ``chunk(2, dim=1)``.
    """
    first, second = x.chunk(2, dim=1)
    return first, second
Example #22
Source File: toy.py From vae-lagging-encoder with MIT License | 5 votes |
def plot_multiple(model, plot_data, grid_z, iter_, args):
    """Aggregate posterior and inference means over plot data and pickle them.

    The plot data is processed in about ``args.num_plot / args.batch_size``
    chunks, and always at least one. For each chunk, the KL, the
    mutual-information estimate, the model posterior mean on ``grid_z`` and
    the inference-network mean are collected.

    Writes ``<args.plot_dir>/aggr<args.aggressive>_iter<iter_>_multiple.pickle``
    holding the ``posterior`` and ``inference`` arrays and the per-sample
    ``kl`` and ``mi`` averages.

    Args:
        model: object exposing ``KL``, ``calc_mi_q``,
            ``calc_model_posterior_mean`` and ``calc_infer_mean``.
        plot_data: ``(data, sents_len)`` pair; only ``data`` is used.
        grid_z: grid passed to ``model.calc_model_posterior_mean``.
        iter_: iteration number used in the output file name.
        args: namespace with ``num_plot``, ``batch_size``, ``plot_dir`` and
            ``aggressive``.
    """
    plot_data, _sents_len = plot_data
    # round() yields 0 when num_plot < batch_size / 2, and torch.chunk
    # rejects a chunk count of 0, so always use at least one chunk.
    num_chunks = max(1, round(args.num_plot / args.batch_size))
    plot_data_list = torch.chunk(plot_data, num_chunks)
    infer_posterior_mean = []
    report_loss_kl = report_mi = report_num_sample = 0
    for data in plot_data_list:
        report_loss_kl += model.KL(data).sum().item()
        report_num_sample += data.size(0)
        report_mi += model.calc_mi_q(data) * data.size(0)
        # [batch, 1] (per the original comment)
        posterior_mean = model.calc_model_posterior_mean(data, grid_z)
        infer_mean = model.calc_infer_mean(data)
        # [*, 2] (per the original comment)
        infer_posterior_mean.append(torch.cat([posterior_mean, infer_mean], 1))
    infer_posterior_mean = torch.cat(infer_posterior_mean, 0)
    save_path = os.path.join(args.plot_dir, 'aggr%d_iter%d_multiple.pickle' % (args.aggressive, iter_))
    save_data = {'posterior': infer_posterior_mean[:,0].cpu().numpy(),
                 'inference': infer_posterior_mean[:,1].cpu().numpy(),
                 'kl': report_loss_kl / report_num_sample,
                 'mi': report_mi / report_num_sample
                 }
    # A context manager closes (and flushes) the file even if pickling fails.
    with open(save_path, 'wb') as f:
        pickle.dump(save_data, f)
Example #23
Source File: revnet.py From imgclsmob with MIT License | 5 votes |
def backward(ctx, grad_y):
    """Backward pass of the reversible coupling block.

    The input halves are reconstructed from the saved output ``y`` by
    inverting the coupling (``x2 = y2 - gm(y1)``, ``x1 = y1 - fm(x2)``). The
    coupling is then re-run with grad enabled on fresh leaf tensors, and
    ``torch.autograd.grad`` yields gradients for both input halves and for
    the parameters of ``gm`` and ``fm``.

    Args:
        ctx: autograd context holding ``fm``, ``gm`` and the saved ``(x, y)``.
        grad_y: gradient w.r.t. the block output ``y``.

    Returns:
        ``(grad_x, None, None) + fm_params_grads + gm_params_grads``.
        NOTE(review): this ordering assumes forward's ``*params`` lists fm's
        parameters before gm's; confirm at the call site.
    """
    fm = ctx.fm
    gm = ctx.gm
    # NOTE(review): ``saved_variables`` looks like a deprecated alias of
    # ``saved_tensors`` in PyTorch; consider switching.
    x, y = ctx.saved_variables
    y1, y2 = torch.chunk(y, chunks=2, dim=1)
    y1 = y1.contiguous()
    y2 = y2.contiguous()
    # Invert the coupling without building a graph to recover x1, x2.
    with torch.no_grad():
        y1_z = Variable(y1.data, requires_grad=True)
        x2 = y2 - gm(y1_z)
        x1 = y1 - fm(x2)
    # Re-run the coupling from fresh leaves so autograd can differentiate it.
    with set_grad_enabled(True):
        x1_ = Variable(x1.data, requires_grad=True)
        x2_ = Variable(x2.data, requires_grad=True)
        y1_ = x1_ + fm.forward(x2_)
        y2_ = x2_ + gm(y1_)
        y = torch.cat((y1_, y2_), dim=1)
        # Gradient tuple order: x1_, x2_, gm params..., fm params...
        dd = torch.autograd.grad(y, (x1_, x2_) + tuple(gm.parameters()) + tuple(fm.parameters()), grad_y)
        gm_params_len = len([p for p in gm.parameters()])
        gm_params_grads = dd[2:2 + gm_params_len]
        fm_params_grads = dd[2 + gm_params_len:]
        grad_x = torch.cat((dd[0], dd[1]), dim=1)
        y1_.detach_()
        y2_.detach_()
        del y1_, y2_
    # Write the reconstructed input back into x's storage, presumably
    # because the caller released it after forward to save memory (TODO confirm).
    x.data.set_(torch.cat((x1, x2), dim=1).data.contiguous())
    return (grad_x, None, None) + fm_params_grads + gm_params_grads
Example #24
Source File: modulated_dcn.py From openseg.pytorch with MIT License | 5 votes |
def forward(self, input):
    """Modulated deformable convolution via ``ModulatedDeformConvFunction``.

    The offset/mask head output is split with torch.chunk into three parts
    along dim 1. The first two form the offsets; the sigmoid of the third is
    the modulation mask.
    """
    predicted = self.conv_offset_mask(input)
    first, second, mask_logits = torch.chunk(predicted, 3, dim=1)
    offset = torch.cat((first, second), dim=1)
    mask = torch.sigmoid(mask_logits)
    deform_fn = ModulatedDeformConvFunction(
        self.stride, self.padding, self.dilation, self.deformable_groups)
    return deform_fn(input, offset, mask, self.weight, self.bias)
Example #25
Source File: deform_conv.py From GCNet with Apache License 2.0 | 5 votes |
def forward(self, x):
    """Modulated deformable convolution with a learned offset/mask head.

    torch.chunk splits the head output into three parts along dim 1: two
    offset parts, concatenated, and one mask part that is passed through a
    sigmoid.
    """
    off_1, off_2, mask_raw = torch.chunk(self.conv_offset_mask(x), 3, dim=1)
    return modulated_deform_conv(
        x,
        torch.cat((off_1, off_2), dim=1),
        torch.sigmoid(mask_raw),
        self.weight,
        self.bias,
        self.stride,
        self.padding,
        self.dilation,
        self.groups,
        self.deformable_groups,
    )
Example #26
Source File: deform_conv_module.py From Clothing-Detection with GNU General Public License v3.0 | 5 votes |
def forward(self, input):
    """Modulated deformable convolution with predicted offsets and mask.

    The output of ``self.conv_offset_mask`` is chunked into three along
    dim 1. The first two chunks make up the offsets and the third, after a
    sigmoid, the modulation mask.
    """
    head = self.conv_offset_mask(input)
    off_lo, off_hi, gate = torch.chunk(head, 3, dim=1)
    offset = torch.cat((off_lo, off_hi), dim=1)
    gate = torch.sigmoid(gate)
    return modulated_deform_conv(input, offset, gate,
                                 self.weight, self.bias, self.stride,
                                 self.padding, self.dilation, self.groups,
                                 self.deformable_groups)
Example #27
Source File: test_additive_shared.py From PySyft with Apache License 2.0 | 5 votes |
def test_chunk(workers):
    """torch.chunk on an additively shared tensor matches plaintext chunks."""
    bob, alice, james = workers["bob"], workers["alice"], workers["james"]
    plain = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]])
    shared = plain.share(bob, alice, crypto_provider=james)

    results = {dim: torch.chunk(shared, 2, dim=dim) for dim in (0, 1)}
    expected = {
        0: [torch.tensor([[1, 2, 3, 4]]), torch.tensor([[5, 6, 7, 8]])],
        1: [torch.tensor([[1, 2], [5, 6]]), torch.tensor([[3, 4], [7, 8]])],
    }
    for dim in (0, 1):
        got, want = results[dim], expected[dim]
        assert all((got[i].get() == want[i]).all() for i in range(2))
Example #28
Source File: average_attn.py From ITDD with MIT License | 5 votes |
def forward(self, inputs, mask=None, layer_cache=None, step=None):
    """Average-attention layer with an input/forget gate.

    Args:
        inputs (`FloatTensor`): `[batch_size x input_len x model_dim]`
        mask: not used in this body.
        layer_cache: forwarded to ``cumulative_average``. When it is given,
            ``step`` is passed instead of the cumulative-average mask.
        step: decoding step, used only together with ``layer_cache``.

    Returns:
        (`FloatTensor`, `FloatTensor`):
        * gating_outputs: sigmoid(input_gate) * inputs
          + sigmoid(forget_gate) * average_outputs
        * average_outputs: the averaged (and ``average_layer``-projected)
          inputs
    """
    batch_size, inputs_len = inputs.size(0), inputs.size(1)
    if layer_cache is None:
        mask_or_step = self.cumulative_average_mask(
            batch_size, inputs_len).to(inputs.device).float()
    else:
        mask_or_step = step
    averaged = self.cumulative_average(inputs, mask_or_step, layer_cache=layer_cache)
    average_outputs = self.average_layer(averaged)

    gates = self.gating_layer(torch.cat((inputs, average_outputs), -1))
    input_gate, forget_gate = torch.chunk(gates, 2, dim=2)
    gating_outputs = (torch.sigmoid(input_gate) * inputs
                      + torch.sigmoid(forget_gate) * average_outputs)
    return gating_outputs, average_outputs
Example #29
Source File: dcn_v2.py From centerpose with MIT License | 5 votes |
def forward(self, input):
    """Deformable convolution (DCNv2) with learned offsets and modulation mask.

    torch.chunk splits the offset/mask head output into three parts along
    dim 1. The first two are concatenated as offsets and the third goes
    through a sigmoid to form the mask. ``dcn_v2_conv`` is called without a
    ``groups`` argument.
    """
    offset_mask = self.conv_offset_mask(input)
    o_first, o_second, m_logits = torch.chunk(offset_mask, 3, dim=1)
    offset = torch.cat((o_first, o_second), dim=1)
    modulation = torch.sigmoid(m_logits)
    conv_cfg = (self.stride, self.padding, self.dilation, self.deformable_groups)
    return dcn_v2_conv(input, offset, modulation, self.weight, self.bias, *conv_cfg)
Example #30
Source File: score_fun.py From dgl with Apache License 2.0 | 5 votes |
def edge_func(self, edges):
    """Score edges by rotating head embeddings in the complex plane.

    Node embeddings are split in half along the last axis into (real,
    imaginary) parts. Relation embeddings are scaled to phases by
    ``emb_init / pi``. The head is multiplied by ``cos + i*sin`` of the
    phase and the tail is subtracted. The per-dimension modulus is summed
    and subtracted from ``self.gamma``.

    Returns:
        ``{'score': gamma - sum(|rotated_head - tail|)}``.
    """
    re_head, im_head = th.chunk(edges.src['emb'], 2, dim=-1)
    re_tail, im_tail = th.chunk(edges.dst['emb'], 2, dim=-1)
    phase_rel = edges.data['emb'] / (self.emb_init / np.pi)
    cos_r, sin_r = th.cos(phase_rel), th.sin(phase_rel)
    # Complex product (re_head + i*im_head) * (cos_r + i*sin_r), minus tail.
    diff_re = re_head * cos_r - im_head * sin_r - re_tail
    diff_im = re_head * sin_r + im_head * cos_r - im_tail
    # Modulus of the complex difference in each dimension.
    dist = th.stack([diff_re, diff_im], dim=0).norm(dim=0)
    return {'score': self.gamma - dist.sum(-1)}