Python torch.zeros() Examples
The following are 30 code examples of torch.zeros().
You can vote up the ones you like or vote down the ones you don't like,
and go to the original project or source file by following the links above each example.
You may also want to check out all available functions/classes of the module torch, or try the search function.
Example #1
Source File: 33_gru_raw.py From deep-learning-note with MIT License | 6 votes |
def get_params():
    """Create all learnable parameters for a GRU built from raw tensor ops.

    Relies on module-level ``num_inputs``, ``num_hiddens``, ``num_outputs``
    and ``device`` — assumed to be defined before this is called.
    """
    def _one(shape):
        # Gaussian init (mean 0, std 0.01) wrapped as a trainable Parameter.
        ts = torch.tensor(np.random.normal(0, 0.01, size=shape), device=device, dtype=torch.float32)
        return torch.nn.Parameter(ts, requires_grad=True)

    def _three():
        # One gate's triple: input->hidden weight, hidden->hidden weight, bias.
        return (_one((num_inputs, num_hiddens)),
                _one((num_hiddens, num_hiddens)),
                torch.nn.Parameter(torch.zeros(num_hiddens, device=device, dtype=torch.float32), requires_grad=True))

    W_xz, W_hz, b_z = _three()  # update gate parameters
    W_xr, W_hr, b_r = _three()  # reset gate parameters
    W_xh, W_hh, b_h = _three()  # candidate hidden state parameters
    # Output layer parameters
    W_hq = _one((num_hiddens, num_outputs))
    b_q = torch.nn.Parameter(torch.zeros(num_outputs, device=device, dtype=torch.float32), requires_grad=True)
    return nn.ParameterList([W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q])
Example #2
Source File: roi_pool.py From Collaborative-Learning-for-Weakly-Supervised-Object-Detection with MIT License | 6 votes |
def forward(self, features, rois):
    """RoI max pooling over `features` for the given `rois`.

    Args:
        features: (N, C, H, W) feature map (CPU or CUDA).
        rois: RoI boxes; only rois.size(0) is read here.

    Returns:
        (num_rois, C, pooled_height, pooled_width) pooled tensor. Also caches
        output/argmax/rois/feature_size on self for the backward pass.
    """
    batch_size, num_channels, data_height, data_width = features.size()
    num_rois = rois.size()[0]
    output = torch.zeros(num_rois, num_channels, self.pooled_height, self.pooled_width)
    # argmax records the winning input index per output cell (used in backward).
    argmax = torch.IntTensor(num_rois, num_channels, self.pooled_height, self.pooled_width).zero_()
    if not features.is_cuda:
        # CPU kernel expects channels-last layout; argmax is not filled here.
        _features = features.permute(0, 2, 3, 1)
        roi_pooling.roi_pooling_forward(self.pooled_height, self.pooled_width, self.spatial_scale, _features, rois, output)
        # output = output.cuda()
    else:
        output = output.cuda()
        argmax = argmax.cuda()
        roi_pooling.roi_pooling_forward_cuda(self.pooled_height, self.pooled_width, self.spatial_scale, features, rois, output, argmax)
    self.output = output
    self.argmax = argmax
    self.rois = rois
    self.feature_size = features.size()
    return output
Example #3
Source File: visualize.py From Random-Erasing with Apache License 2.0 | 6 votes |
def colorize(x):
    '''
    Converts a one-channel grayscale image to a color heatmap image.

    Accepts a 2-D (H, W), 3-D (1, H, W) or 4-D (N, 1, H, W) tensor and
    returns a 3-channel tensor of matching batch shape.
    '''
    if x.dim() == 2:
        # Bug fix: the original `torch.unsqueeze(x, 0, out=x)` is invalid in
        # modern PyTorch — `out=` cannot change the tensor's shape. Rebind
        # instead; behavior (promote to 3-D) is unchanged.
        x = torch.unsqueeze(x, 0)
    if x.dim() == 3:
        cl = torch.zeros([3, x.size(1), x.size(2)])
        # `gauss` is defined elsewhere in this file; the three channels are
        # gaussian colormap responses.
        cl[0] = gauss(x, .5, .6, .2) + gauss(x, 1, .8, .3)
        cl[1] = gauss(x, 1, .5, .3)
        cl[2] = gauss(x, 1, .2, .3)
        cl[cl.gt(1)] = 1  # clamp overshoot to 1
    elif x.dim() == 4:
        cl = torch.zeros([x.size(0), 3, x.size(2), x.size(3)])
        cl[:, 0, :, :] = gauss(x, .5, .6, .2) + gauss(x, 1, .8, .3)
        cl[:, 1, :, :] = gauss(x, 1, .5, .3)
        cl[:, 2, :, :] = gauss(x, 1, .2, .3)
    return cl
Example #4
Source File: CRF.py From pytorch_NER_BiLSTM_CNN_CRF with Apache License 2.0 | 6 votes |
def __init__(self, **kwargs):
    """
    kwargs:
        target_size: int, target size
        device: str, device
    """
    super(CRF, self).__init__()
    # Copy every kwarg onto self (expects at least target_size and device).
    for k in kwargs:
        self.__setattr__(k, kwargs[k])
    device = self.device
    # init transitions
    # Virtual START/STOP tags occupy the last two rows/cols of the matrix.
    self.START_TAG, self.STOP_TAG = -2, -1
    init_transitions = torch.zeros(self.target_size + 2, self.target_size + 2, device=device)
    # Large negative score forbids transitioning INTO start and OUT OF stop.
    init_transitions[:, self.START_TAG] = -10000.0
    init_transitions[self.STOP_TAG, :] = -10000.0
    self.transitions = nn.Parameter(init_transitions)
Example #5
Source File: MessageFunction.py From nmp_qc with MIT License | 6 votes |
def m_ggnn(self, h_v, h_w, e_vw, opt={}):
    """GG-NN message function: for each neighbor w, m = A_{e_vw} @ h_w, where
    the matrix A is selected by the edge label.

    NOTE(review): mutable default `opt={}` is shared across calls; it is
    unused here, but callers should not rely on it.
    """
    m = Variable(torch.zeros(h_w.size(0), h_w.size(1), self.args['out']).type_as(h_w.data))
    # Iterate over the neighbor slots of node v.
    for w in range(h_w.size(1)):
        # Skip neighbor slots with all-zero edge features (padding).
        if torch.nonzero(e_vw[:, w, :].data).size():
            for i, el in enumerate(self.args['e_label']):
                # Mask of batch entries whose edge label equals `el`.
                ind = (el == e_vw[:, w, :]).type_as(self.learn_args[0][i])
                # Broadcast this label's learned matrix over the batch.
                parameter_mat = self.learn_args[0][i][None, ...].expand(h_w.size(0), self.learn_args[0][i].size(0), self.learn_args[0][i].size(1))
                m_w = torch.transpose(torch.bmm(torch.transpose(parameter_mat, 1, 2), torch.transpose(torch.unsqueeze(h_w[:, w, :], 1), 1, 2)), 1, 2)
                m_w = torch.squeeze(m_w)
                # Zero out batch entries whose edge label does not match.
                m[:, w, :] = ind.expand_as(m_w) * m_w
    return m
Example #6
Source File: functions.py From comet-commonsense with Apache License 2.0 | 6 votes |
def set_conceptnet_inputs(input_event, relation, text_encoder, max_e1, max_r, force):
    """Build the (sequences, attention_mask) batch for one ConceptNet example.

    Returns:
        (batch, abort): abort is True when the tokenized event exceeds
        max_e1 and `force` is False; batch is then empty.
    """
    abort = False

    e1_tokens, rel_tokens, _ = data.conceptnet_data.do_example(text_encoder, input_event, relation, None)

    if len(e1_tokens) > max_e1:
        if force:
            # Oversized event: widen the buffer to fit the whole event.
            # NOTE(review): the relation is still written at offset max_e1
            # below, which overlaps the event tokens in this branch — confirm
            # this is intended.
            XMB = torch.zeros(1, len(e1_tokens) + max_r).long().to(cfg.device)
        else:
            XMB = torch.zeros(1, max_e1 + max_r).long().to(cfg.device)
            return {}, True
    else:
        XMB = torch.zeros(1, max_e1 + max_r).long().to(cfg.device)

    # Event tokens first, relation tokens starting at offset max_e1.
    XMB[:, :len(e1_tokens)] = torch.LongTensor(e1_tokens)
    XMB[:, max_e1:max_e1 + len(rel_tokens)] = torch.LongTensor(rel_tokens)

    batch = {}
    batch["sequences"] = XMB
    batch["attention_mask"] = data.conceptnet_data.make_attention_mask(XMB)

    return batch, abort
Example #7
Source File: conv_ws.py From mmdetection with Apache License 2.0 | 6 votes |
def __init__(self,
             in_channels,
             out_channels,
             kernel_size,
             stride=1,
             padding=0,
             dilation=1,
             groups=1,
             bias=True):
    """Conv layer with weight standardization.

    Standard conv arguments are forwarded to the parent class; in addition,
    per-output-channel gamma/beta factors for the standardized weight are
    registered as buffers (saved in state_dict, not trained).
    """
    super().__init__(
        in_channels,
        out_channels,
        kernel_size,
        stride=stride,
        padding=padding,
        dilation=dilation,
        groups=groups,
        bias=bias)
    # One scalar per output channel, broadcastable over the weight tensor.
    self.register_buffer('weight_gamma', torch.ones(self.out_channels, 1, 1, 1))
    self.register_buffer('weight_beta', torch.zeros(self.out_channels, 1, 1, 1))
Example #8
Source File: vocab.py From hgraph2graph with MIT License | 6 votes |
def __init__(self, smiles_pairs, cuda=True):
    """Vocabulary over (fragment, attachment-SMILES) pairs.

    Builds hvocab (unique fragment classes), vocab (all pairs), and an
    hvocab x vocab additive mask that is 0 for valid (h, s) pairs and
    -1000 everywhere else.
    """
    cls = list(zip(*smiles_pairs))[0]
    self.hvocab = sorted( list(set(cls)) )
    self.hmap = {x: i for i, x in enumerate(self.hvocab)}

    self.vocab = [tuple(x) for x in smiles_pairs]  # copy
    # Number of attachment points per pair (count_inters defined elsewhere).
    self.inter_size = [count_inters(x[1]) for x in self.vocab]
    self.vmap = {x: i for i, x in enumerate(self.vocab)}

    # Mark valid (h, s) combinations with 1000 ...
    self.mask = torch.zeros(len(self.hvocab), len(self.vocab))
    for h, s in smiles_pairs:
        hid = self.hmap[h]
        idx = self.vmap[(h, s)]
        self.mask[hid, idx] = 1000.0

    if cuda:
        self.mask = self.mask.cuda()
    # ... then shift so valid entries become 0 and invalid ones -1000.
    self.mask = self.mask - 1000.0
Example #9
Source File: encoder.py From hgraph2graph with MIT License | 6 votes |
def embed_sub_tree(self, tree_tensors, hinput, subtree, is_inter_layer):
    """Embed the nodes and messages of one sub-tree layer.

    Args:
        tree_tensors: (fnode, fmess, agraph, bgraph, cgraph, scope) tensors.
        hinput: hidden input from the previous layer.
        subtree: (subnode, submess) index tensors selecting this layer.
        is_inter_layer: True -> inter layer (E_i/W_i, gather via cgraph);
            False -> cluster layer (E_c/W_c, select via subnode).

    Returns:
        (hnode, hmess, agraph, bgraph) for the sub-tree.
    """
    subnode, submess = subtree
    num_nodes = tree_tensors[0].size(0)
    fnode, fmess, agraph, bgraph, cgraph, _ = self.get_sub_tensor(tree_tensors, subtree)

    if is_inter_layer:
        finput = self.E_i(fnode[:, 1])
        # Sum neighbor hidden states gathered through cgraph.
        hinput = index_select_ND(hinput, 0, cgraph).sum(dim=1)
        hnode = self.W_i( torch.cat([finput, hinput], dim=-1) )
    else:
        finput = self.E_c(fnode[:, 0])
        hinput = hinput.index_select(0, subnode)
        hnode = self.W_c( torch.cat([finput, hinput], dim=-1) )

    if len(submess) == 0:
        hmess = fmess
    else:
        # Scatter node states into a full-size buffer, then gather the
        # source-node state for each message.
        node_buf = torch.zeros(num_nodes, self.hidden_size, device=fmess.device)
        node_buf = index_scatter(hnode, node_buf, subnode)
        hmess = node_buf.index_select(index=fmess[:, 0], dim=0)
        # Append positional (child-index) embeddings to each message.
        pos_vecs = self.E_pos.index_select(0, fmess[:, 2])
        hmess = torch.cat( [hmess, pos_vecs], dim=-1 )
    return hnode, hmess, agraph, bgraph
Example #10
Source File: encoder.py From hgraph2graph with MIT License | 6 votes |
def embed_sub_tree(self, tree_tensors, hinput, subtree, is_inter_layer):
    """Embed the nodes and messages of one sub-tree layer.

    Args:
        tree_tensors: (fnode, fmess, agraph, bgraph, cgraph, scope) tensors.
        hinput: hidden input from the previous layer.
        subtree: (subnode, submess) index tensors selecting this layer.
        is_inter_layer: True -> inter layer (E_i/W_i, gather via cgraph);
            False -> cluster layer (E_c/W_c, select via subnode).

    Returns:
        (hnode, hmess, agraph, bgraph) for the sub-tree.
    """
    subnode, submess = subtree
    num_nodes = tree_tensors[0].size(0)
    fnode, fmess, agraph, bgraph, cgraph, _ = self.get_sub_tensor(tree_tensors, subtree)

    if is_inter_layer:
        finput = self.E_i(fnode[:, 1])
        # Sum neighbor hidden states gathered through cgraph.
        hinput = index_select_ND(hinput, 0, cgraph).sum(dim=1)
        hnode = self.W_i( torch.cat([finput, hinput], dim=-1) )
    else:
        finput = self.E_c(fnode[:, 0])
        hinput = hinput.index_select(0, subnode)
        hnode = self.W_c( torch.cat([finput, hinput], dim=-1) )

    if len(submess) == 0:
        hmess = fmess
    else:
        # Scatter node states into a full-size buffer, then gather the
        # source-node state for each message.
        node_buf = torch.zeros(num_nodes, self.hidden_size, device=fmess.device)
        node_buf = index_scatter(hnode, node_buf, subnode)
        hmess = node_buf.index_select(index=fmess[:, 0], dim=0)
        # Append positional (child-index) embeddings to each message.
        pos_vecs = self.E_pos.index_select(0, fmess[:, 2])
        hmess = torch.cat( [hmess, pos_vecs], dim=-1 )
    return hnode, hmess, agraph, bgraph
Example #11
Source File: mol_graph.py From hgraph2graph with MIT License | 6 votes |
def tensorize(mol_batch, vocab, avocab):
    """Tensorize a batch of SMILES into batched tree and graph structures.

    Args:
        mol_batch: list of SMILES strings.
        vocab: motif vocabulary for the tree.
        avocab: atom vocabulary for the graph.

    Returns:
        ((tree_batchG, graph_batchG), (tree_tensors, graph_tensors), all_orders)
    """
    mol_batch = [MolGraph(x) for x in mol_batch]
    tree_tensors, tree_batchG = MolGraph.tensorize_graph([x.mol_tree for x in mol_batch], vocab)
    graph_tensors, graph_batchG = MolGraph.tensorize_graph([x.mol_graph for x in mol_batch], avocab)
    tree_scope = tree_tensors[-1]
    graph_scope = graph_tensors[-1]

    # cgraph[v] lists node v's cluster atom indices (in batched-graph
    # coordinates), zero-padded to the largest cluster in the batch.
    max_cls_size = max( [len(c) for x in mol_batch for c in x.clusters] )
    cgraph = torch.zeros(len(tree_batchG) + 1, max_cls_size).int()
    for v, attr in tree_batchG.nodes(data=True):
        bid = attr['batch_id']
        offset = graph_scope[bid][0]
        # Shift per-molecule atom indices by the molecule's batch offset.
        tree_batchG.nodes[v]['inter_label'] = inter_label = [(x + offset, y) for x, y in attr['inter_label']]
        tree_batchG.nodes[v]['cluster'] = cls = [x + offset for x in attr['cluster']]
        tree_batchG.nodes[v]['assm_cands'] = [add(x, offset) for x in attr['assm_cands']]
        cgraph[v, :len(cls)] = torch.IntTensor(cls)

    all_orders = []
    for i, hmol in enumerate(mol_batch):
        offset = tree_scope[i][0]
        # Decoding order; the final step has no target node (None).
        order = [(x + offset, y + offset, z) for x, y, z in hmol.order[:-1]] + [(hmol.order[-1][0] + offset, None, 0)]
        all_orders.append(order)

    tree_tensors = tree_tensors[:4] + (cgraph, tree_scope)
    return (tree_batchG, graph_batchG), (tree_tensors, graph_tensors), all_orders
Example #12
Source File: vocab.py From hgraph2graph with MIT License | 6 votes |
def __init__(self, smiles_pairs, cuda=True):
    """Vocabulary over (fragment, attachment-SMILES) pairs.

    Builds hvocab (unique fragment classes), vocab (all pairs), and an
    hvocab x vocab additive mask that is 0 for valid (h, s) pairs and
    -1000 everywhere else.
    """
    cls = list(zip(*smiles_pairs))[0]
    self.hvocab = sorted( list(set(cls)) )
    self.hmap = {x: i for i, x in enumerate(self.hvocab)}

    self.vocab = [tuple(x) for x in smiles_pairs]  # copy
    # Number of attachment points per pair (count_inters defined elsewhere).
    self.inter_size = [count_inters(x[1]) for x in self.vocab]
    self.vmap = {x: i for i, x in enumerate(self.vocab)}

    # Mark valid (h, s) combinations with 1000 ...
    self.mask = torch.zeros(len(self.hvocab), len(self.vocab))
    for h, s in smiles_pairs:
        hid = self.hmap[h]
        idx = self.vmap[(h, s)]
        self.mask[hid, idx] = 1000.0

    if cuda:
        self.mask = self.mask.cuda()
    # ... then shift so valid entries become 0 and invalid ones -1000.
    self.mask = self.mask - 1000.0
Example #13
Source File: 25_batch_normalization_raw.py From deep-learning-note with MIT License | 5 votes |
def __init__(self, num_features, num_dims):
    """Manual BatchNorm state.

    num_dims == 2 -> fully connected input (N, C); otherwise convolutional
    input (N, C, H, W), with per-channel parameters broadcast over H and W.
    """
    super(BatchNorm, self).__init__()
    if num_dims == 2:  # fully connected
        shape = (1, num_features)
    else:  # convolutional
        shape = (1, num_features, 1, 1)
    # Learnable scale and shift.
    self.gamma = nn.Parameter(torch.ones(shape))
    self.beta = nn.Parameter(torch.zeros(shape))
    # Running statistics used at inference time.
    # NOTE(review): these are plain tensors, not registered buffers — they
    # will not appear in state_dict or follow .to(device); confirm intended.
    self.moving_mean = torch.zeros(shape)
    self.moving_var = torch.zeros(shape)
Example #14
Source File: data.py From VSE-C with MIT License | 5 votes |
def collate_fn(data):
    """Collate (image, caption, id, img_id) samples into padded batch tensors.

    Args:
        data: list of (image, caption, id, img_id) tuples.
            - image: torch tensor of shape (3, 256, 256).
            - caption: torch tensor of shape (?); variable length.

    Returns:
        images: torch tensor of shape (batch_size, 3, 256, 256).
        targets: torch tensor of shape (batch_size, padded_length).
        lengths: list; valid length for each padded caption.
        ids: tuple of sample ids, in sorted order.
    """
    # Longest captions first (as required by packed-sequence consumers).
    data.sort(key=lambda sample: len(sample[1]), reverse=True)
    images, captions, ids, img_ids = list(zip(*data))

    # Stack the 3-D image tensors into a single 4-D batch tensor.
    images = torch.stack(images, 0)

    # Zero-pad the variable-length captions into one 2-D LongTensor.
    lengths = [len(cap) for cap in captions]
    targets = torch.zeros(len(captions), max(lengths)).long()
    for row, cap in enumerate(captions):
        n = lengths[row]
        targets[row, :n] = cap[:n]
    return images, targets, lengths, ids
Example #15
Source File: data.py From VSE-C with MIT License | 5 votes |
def __getitem__(self, index):
    """Return one sample: (image, target caption, index, img_id, negatives).

    `negatives` is a list of exactly `self.num_negative` caption tensors read
    from the per-index extension file, padded with all-zero captions when too
    few exist. Multiple captions share one image (`im_div` captions/image).
    """
    # handle the image redundancy
    img_id = index // self.im_div
    image = torch.Tensor(self.images[img_id])
    target = self.convert_to_tensor(self.captions[index])
    # Bug fix: use a context manager so the file handle is closed instead of
    # being leaked (the original called open(...).readlines() directly).
    with open(pjoin(self.extended_path, str(index) + '.txt'), 'rb') as f:
        extended_captions = f.readlines()
    if self.data_split == 'train':
        # Randomize which negatives are picked each epoch.
        random.shuffle(extended_captions)
    extended_captions = [bytes.decode(x).strip() if type(x) == bytes else x.strip()
                         for x in extended_captions]
    extended_captions = [self.convert_to_tensor(st) for st in extended_captions[:self.num_negative]]
    if len(extended_captions) < self.num_negative:
        # Pad with all-zero captions when too few negatives are available.
        extended_captions.extend([torch.zeros(target.size()).long()] *
                                 (self.num_negative - len(extended_captions)))
    return image, target, index, img_id, extended_captions
Example #16
Source File: model.py From VSE-C with MIT License | 5 votes |
def forward(self, x, lengths):
    """Encode token ids `x` (batch, seq) into an L2-normalized caption vector."""
    # Concatenate pretrained GloVe vectors with learned embeddings.
    x_glove = self.glove.index_select(0, x.view(-1)).view(x.size(0), x.size(1), -1)
    x_semantic = self.embed(x)
    x = torch.cat((x_semantic, x_glove), dim=2)
    # (batch, seq, dim) -> (batch, dim, seq) for Conv1d.
    x = torch.transpose(x, 1, 2).contiguous()
    # Zero-pad the time axis to a fixed length of 60.
    # NOTE(review): assumes seq length <= 60 — a longer caption makes the pad
    # size negative; confirm upstream truncation.
    x = torch.cat((x, Variable(torch.zeros(x.size(0), x.size(1), 60 - x.size(2)))), dim=2)
    conv1 = functional.relu(self.conv1(x))
    conv2 = functional.relu(self.conv2(conv1))
    conv3 = functional.relu(self.conv3(conv2))
    rep = self.conv4(conv3).view(x.size(0), -1)
    # l2-norm
    rep = l2norm(rep)
    if self.use_abs:
        rep = torch.abs(rep)
    return rep
Example #17
Source File: anchor_generator.py From mmdetection with Apache License 2.0 | 5 votes |
def single_level_valid_flags(self,
                             featmap_size,
                             valid_size,
                             num_base_anchors,
                             device='cuda'):
    """Generate the valid flags of anchor in a single feature map.

    Args:
        featmap_size (tuple[int]): The size of feature maps.
        valid_size (tuple[int]): The valid size of the feature maps.
        num_base_anchors (int): The number of base anchors.
        device (str, optional): Device where the flags will be put on.
            Defaults to 'cuda'.

    Returns:
        torch.Tensor: The valid flags of each anchor in a single level
            feature map.
    """
    feat_h, feat_w = featmap_size
    valid_h, valid_w = valid_size
    assert valid_h <= feat_h and valid_w <= feat_w
    # Per-axis masks: True inside the valid region, False in the padding.
    x_mask = torch.zeros(feat_w, dtype=torch.bool, device=device)
    y_mask = torch.zeros(feat_h, dtype=torch.bool, device=device)
    x_mask[:valid_w] = 1
    y_mask[:valid_h] = 1
    # A location is valid iff both its x and y coordinates are valid.
    xx, yy = self._meshgrid(x_mask, y_mask)
    per_location = xx & yy
    # Replicate each location's flag for all base anchors at that location.
    expanded = per_location[:, None].expand(per_location.size(0),
                                            num_base_anchors)
    return expanded.contiguous().view(-1)
Example #18
Source File: utils.py From deep-learning-note with MIT License | 5 votes |
def train_opt(optimizer_fn, states, hyperparams, features, labels, batch_size=10, num_epochs=2):
    """Train linear regression with a hand-written optimizer and plot the loss.

    Unlike the book, the first argument here is the optimizer *function*
    rather than the optimizer's name, e.g. hyperparams={"lr": 0.05}.
    """
    # Initialize the model
    net, loss = linreg, squared_loss
    w = torch.nn.Parameter(torch.tensor(np.random.normal(0, 0.01, size=(features.shape[1], 1)), dtype=torch.float32), requires_grad=True)
    b = torch.nn.Parameter(torch.zeros(1, dtype=torch.float32), requires_grad=True)

    def eval_loss():
        # Mean loss over the full training set.
        return loss(net(features, w, b), labels).mean().item()

    ls = [eval_loss()]
    data_iter = torch.utils.data.DataLoader(
        torch.utils.data.TensorDataset(features, labels), batch_size, shuffle=True)

    for _ in range(num_epochs):
        start = time.time()
        for batch_i, (X, y) in enumerate(data_iter):
            l = loss(net(X, w, b), y).mean()  # use the mean loss
            # Zero the gradients before backprop
            if w.grad is not None:
                w.grad.data.zero_()
                b.grad.data.zero_()
            l.backward()
            optimizer_fn([w, b], states, hyperparams)  # update model parameters
            if (batch_i + 1) * batch_size % 100 == 0:
                ls.append(eval_loss())  # record training loss every 100 samples
    # Print the result and plot the loss curve
    print('loss: %f, %f sec per epoch' % (ls[-1], time.time() - start))
    plt.plot(np.linspace(0, num_epochs, len(ls)), ls)
    plt.xlabel('epoch')
    plt.ylabel('loss')
    plt.show()
Example #19
Source File: ghm_loss.py From mmdetection with Apache License 2.0 | 5 votes |
def __init__(self, mu=0.02, bins=10, momentum=0, loss_weight=1.0):
    """GHM-R (gradient harmonizing) regression loss.

    Args:
        mu: ASL1 transition parameter.
        bins: number of gradient-density bins.
        momentum: EMA momentum for per-bin counts (0 disables the EMA).
        loss_weight: global scale applied to the loss.
    """
    super(GHMR, self).__init__()
    self.mu = mu
    self.bins = bins
    # Bin edges over [0, 1]; the last edge is pushed far out so every
    # gradient falls into some bin.
    edges = torch.arange(bins + 1).float() / bins
    self.register_buffer('edges', edges)
    self.edges[-1] = 1e3
    self.momentum = momentum
    if momentum > 0:
        # Running (EMA) count of samples per bin.
        acc_sum = torch.zeros(bins)
        self.register_buffer('acc_sum', acc_sum)
    self.loss_weight = loss_weight
    # TODO: support reduction parameter
Example #20
Source File: ghm_loss.py From mmdetection with Apache License 2.0 | 5 votes |
def __init__(self, bins=10, momentum=0, use_sigmoid=True, loss_weight=1.0):
    """GHM-C (gradient harmonizing) classification loss.

    Args:
        bins: number of gradient-density bins.
        momentum: EMA momentum for per-bin counts (0 disables the EMA).
        use_sigmoid: only the sigmoid variant is implemented.
        loss_weight: global scale applied to the loss.
    """
    super(GHMC, self).__init__()
    self.bins = bins
    self.momentum = momentum
    # Bin edges over [0, 1]; the last edge is nudged up so gradients equal
    # to 1.0 still fall into the final bin.
    edges = torch.arange(bins + 1).float() / bins
    self.register_buffer('edges', edges)
    self.edges[-1] += 1e-6
    if momentum > 0:
        # Running (EMA) count of samples per bin.
        acc_sum = torch.zeros(bins)
        self.register_buffer('acc_sum', acc_sum)
    self.use_sigmoid = use_sigmoid
    if not self.use_sigmoid:
        raise NotImplementedError
    self.loss_weight = loss_weight
Example #21
Source File: standard_roi_head.py From mmdetection with Apache License 2.0 | 5 votes |
def _mask_forward_train(self, x, sampling_results, bbox_feats, gt_masks,
                        img_metas):
    """Run forward function and calculate loss for mask head in training."""
    if not self.share_roi_extractor:
        # Separate extractor: pool features only for the positive RoIs.
        pos_rois = bbox2roi([res.pos_bboxes for res in sampling_results])
        if pos_rois.shape[0] == 0:
            # No positive samples in the batch -> no mask loss.
            return dict(loss_mask=None)
        mask_results = self._mask_forward(x, pos_rois)
    else:
        # Shared extractor: reuse bbox_feats, selecting positives via a
        # 1/0 indicator built in sampling order (positives then negatives
        # per image).
        pos_inds = []
        device = bbox_feats.device
        for res in sampling_results:
            pos_inds.append(
                torch.ones(
                    res.pos_bboxes.shape[0],
                    device=device,
                    dtype=torch.uint8))
            pos_inds.append(
                torch.zeros(
                    res.neg_bboxes.shape[0],
                    device=device,
                    dtype=torch.uint8))
        pos_inds = torch.cat(pos_inds)
        if pos_inds.shape[0] == 0:
            return dict(loss_mask=None)
        mask_results = self._mask_forward(
            x, pos_inds=pos_inds, bbox_feats=bbox_feats)

    mask_targets = self.mask_head.get_targets(sampling_results, gt_masks,
                                              self.train_cfg)
    pos_labels = torch.cat([res.pos_gt_labels for res in sampling_results])
    loss_mask = self.mask_head.loss(mask_results['mask_pred'], mask_targets,
                                    pos_labels)

    mask_results.update(loss_mask=loss_mask, mask_targets=mask_targets)
    return mask_results
Example #22
Source File: test.py From mmdetection with Apache License 2.0 | 5 votes |
def collect_results_gpu(result_part, size): rank, world_size = get_dist_info() # dump result part to tensor with pickle part_tensor = torch.tensor( bytearray(pickle.dumps(result_part)), dtype=torch.uint8, device='cuda') # gather all result part tensor shape shape_tensor = torch.tensor(part_tensor.shape, device='cuda') shape_list = [shape_tensor.clone() for _ in range(world_size)] dist.all_gather(shape_list, shape_tensor) # padding result part tensor to max length shape_max = torch.tensor(shape_list).max() part_send = torch.zeros(shape_max, dtype=torch.uint8, device='cuda') part_send[:shape_tensor[0]] = part_tensor part_recv_list = [ part_tensor.new_zeros(shape_max) for _ in range(world_size) ] # gather all result part dist.all_gather(part_recv_list, part_send) if rank == 0: part_list = [] for recv, shape in zip(part_recv_list, shape_list): part_list.append( pickle.loads(recv[:shape[0]].cpu().numpy().tobytes())) # sort the results ordered_results = [] for res in zip(*part_list): ordered_results.extend(list(res)) # the dataloader may pad some samples ordered_results = ordered_results[:size] return ordered_results
Example #23
Source File: utils.py From deep-learning-note with MIT License | 5 votes |
def load_pretrained_embedding(words, pretrained_vocab):
    """Extract the embedding vectors for ``words`` from a pretrained vocab.

    Args:
        words: sequence of tokens to look up.
        pretrained_vocab: object with ``stoi`` (token -> index) and
            ``vectors`` (index -> embedding tensor), e.g. a torchtext Vocab.

    Returns:
        FloatTensor of shape (len(words), embed_dim); rows for
        out-of-vocabulary words remain zero.
    """
    embed = torch.zeros(len(words), pretrained_vocab.vectors[0].shape[0])  # initialized to 0
    oov_count = 0  # out of vocabulary
    for i, word in enumerate(words):
        try:
            idx = pretrained_vocab.stoi[word]
            embed[i, :] = pretrained_vocab.vectors[idx]
        except KeyError:
            # Bug fix: was `oov_count += 0`, which never counted OOV words.
            oov_count += 1
    if oov_count > 0:
        # Bug fix: the format string was missing its `% oov_count` argument.
        print("There are %d oov words." % oov_count)
    return embed
Example #24
Source File: encoder.py From hgraph2graph with MIT License | 5 votes |
def forward(self, tensors, h, num_nodes, subset):
    """One message-passing step over a (sub)graph.

    Args:
        tensors: (fnode, fmess, agraph, bgraph) tensors.
        h: current message hidden states.
        num_nodes: total node count (size of the output buffer).
        subset: (subnode, submess) indices restricting the update.

    Returns:
        (node_hiddens, h)
    """
    fnode, fmess, agraph, bgraph = tensors
    subnode, submess = subset

    if len(submess) > 0:
        # Update only the selected messages.
        h = self.rnn.sparse_forward(h, fmess, submess, bgraph)

    # Aggregate incoming messages per node, combine with node features.
    nei_message = index_select_ND(self.rnn.get_hidden_state(h), 0, agraph)
    nei_message = nei_message.sum(dim=1)
    node_hiddens = torch.cat([fnode, nei_message], dim=1)
    node_hiddens = self.W_o(node_hiddens)

    # Scatter the updated node states into a full-size buffer.
    node_buf = torch.zeros(num_nodes, self.hidden_size, device=fmess.device)
    node_hiddens = index_scatter(node_hiddens, node_buf, subnode)
    return node_hiddens, h
Example #25
Source File: inc_graph.py From hgraph2graph with MIT License | 5 votes |
def get_mess_feature(self, atom, bond_type, nth_child):
    """One-hot message feature: [atom type | bond type | child position].

    Returns a CUDA tensor of length avocab.size() + |BOND_LIST| + MAX_POS.
    """
    f1 = torch.zeros(self.avocab.size())
    f2 = torch.zeros(len(MolGraph.BOND_LIST))
    f3 = torch.zeros(MolGraph.MAX_POS)
    # Atom identity is the (symbol, formal charge) pair.
    symbol, charge = atom.GetSymbol(), atom.GetFormalCharge()
    f1[ self.avocab[(symbol, charge)] ] = 1
    f2[ MolGraph.BOND_LIST.index(bond_type) ] = 1
    f3[ nth_child ] = 1
    return torch.cat( [f1, f2, f3], dim=-1 ).cuda()
Example #26
Source File: 44_adadelta.py From deep-learning-note with MIT License | 5 votes |
def init_adadelta_states():
    """Create zero-initialized AdaDelta state.

    Returns ((s_w, delta_w), (s_b, delta_b)): squared-gradient and
    squared-update accumulators for the weight and the bias. Reads the
    module-level ``features`` tensor to size the weight state.
    """
    weight_shape = (features.shape[1], 1)
    s_w = torch.zeros(weight_shape, dtype=torch.float32)
    delta_w = torch.zeros(weight_shape, dtype=torch.float32)
    s_b = torch.zeros(1, dtype=torch.float32)
    delta_b = torch.zeros(1, dtype=torch.float32)
    return ((s_w, delta_w), (s_b, delta_b))
Example #27
Source File: inc_graph.py From hgraph2graph with MIT License | 5 votes |
def get_atom_feature(self, atom):
    """One-hot atom feature over (symbol, formal charge) pairs, on CUDA."""
    f = torch.zeros(self.avocab.size())
    symbol, charge = atom.GetSymbol(), atom.GetFormalCharge()
    f[ self.avocab[(symbol, charge)] ] = 1
    return f.cuda()
Example #28
Source File: 45_adam.py From deep-learning-note with MIT License | 5 votes |
def init_adam_states():
    """Create zero-initialized Adam state.

    Returns ((v_w, s_w), (v_b, s_b)): first- and second-moment accumulators
    for the weight and the bias. Reads the module-level ``features`` tensor
    to size the weight state.
    """
    weight_shape = (features.shape[1], 1)
    v_w = torch.zeros(weight_shape, dtype=torch.float32)
    s_w = torch.zeros(weight_shape, dtype=torch.float32)
    v_b = torch.zeros(1, dtype=torch.float32)
    s_b = torch.zeros(1, dtype=torch.float32)
    return ((v_w, s_w), (v_b, s_b))
Example #29
Source File: 18_conv2d.py From deep-learning-note with MIT License | 5 votes |
def corr2d(X, K):
    """2-D cross-correlation of X with kernel K (no padding, stride 1).

    Args:
        X: 2-D input tensor of shape (H, W).
        K: 2-D kernel tensor of shape (kh, kw).

    Returns:
        Tensor of shape (H - kh + 1, W - kw + 1).
    """
    kh, kw = K.shape
    out_h = X.shape[0] - kh + 1
    out_w = X.shape[1] - kw + 1
    out = torch.zeros((out_h, out_w))
    for r in range(out_h):
        for c in range(out_w):
            # Elementwise product of the window with the kernel, then sum.
            window = X[r:r + kh, c:c + kw]
            out[r, c] = torch.sum(window * K)
    return out
Example #30
Source File: detectron2pytorch.py From mmdetection with Apache License 2.0 | 5 votes |
def convert_bn(blobs, state_dict, caffe_name, torch_name, converted_names):
    """Convert a Detectron affine-channel layer into PyTorch BN entries.

    Detectron replaces batch norm with an affine channel layer whose '_s'
    blob is the scale (torch weight) and '_b' blob the shift (torch bias).
    Writes bias/weight/running_mean/running_var into ``state_dict`` under
    ``torch_name`` and records the consumed blob names in ``converted_names``.
    """
    shift_blob = blobs[caffe_name + '_b']
    scale_blob = blobs[caffe_name + '_s']
    state_dict[torch_name + '.bias'] = torch.from_numpy(shift_blob)
    state_dict[torch_name + '.weight'] = torch.from_numpy(scale_blob)
    # Affine layers carry no running stats; synthesize identity statistics.
    bn_size = state_dict[torch_name + '.weight'].size()
    state_dict[torch_name + '.running_mean'] = torch.zeros(bn_size)
    state_dict[torch_name + '.running_var'] = torch.ones(bn_size)
    converted_names.add(caffe_name + '_b')
    converted_names.add(caffe_name + '_s')