Python torch.eq() Examples

The following are 30 code examples of torch.eq(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module torch, or try the search function.
Example #1
Source File: test_context_conditioned_policy.py    From garage with MIT License 6 votes vote down vote up
def test_update_context(self):
        """Test update_context.

        After ten updates with an all-ones time step, the module context is
        expected to be a ``(10, encoder_input_dim)`` tensor of ones.
        """
        step = TimeStep(env_spec=self.env_spec,
                        observation=np.ones(self.obs_dim),
                        next_observation=np.ones(self.obs_dim),
                        action=np.ones(self.action_dim),
                        reward=1.0,
                        terminal=False,
                        env_info={},
                        agent_info={})
        n_updates = 10
        for _ in range(n_updates):
            self.module.update_context(step)
        context = self.module.context
        expected = torch.ones(n_updates, self.encoder_input_dim)
        assert torch.all(torch.eq(context, expected))
Example #2
Source File: test_wrappers.py    From torchbearer with MIT License 6 votes vote down vote up
def test_train(self):
        """Check what the wrapped metric function receives in train mode.

        Each ``process`` call must forward the expected (float, long) tensor
        pair, and ``process_final`` must forward the expected 5-element pair.
        """
        self._metric.train()
        expected_calls = [[torch.FloatTensor([0.0]), torch.LongTensor([0])],
                          [torch.FloatTensor([0.0, 0.1, 0.2, 0.3]), torch.LongTensor([0, 1, 2, 3])]]
        for state in self._states:
            self._metric.process(state)
        self.assertEqual(2, len(self._metric_function.call_args_list))
        # `.all()` must actually be called: the bare bound method `.all` is
        # always truthy and would make every assertTrue below vacuous.
        for call, (expected_first, expected_second) in zip(self._metric_function.call_args_list,
                                                           expected_calls):
            args = call[0]
            self.assertTrue(torch.eq(args[0], expected_first).all())
            self.assertTrue(torch.lt(torch.abs(torch.add(args[1], -expected_second)), 1e-12).all())
        self._metric_function.reset_mock()
        self._metric.process_final({})

        self.assertEqual(self._metric_function.call_count, 1)
        final_args = self._metric_function.call_args_list[0][0]
        self.assertTrue(torch.eq(final_args[1], torch.LongTensor([0, 1, 2, 3, 4])).all())
        self.assertTrue(torch.lt(torch.abs(torch.add(final_args[0], -torch.FloatTensor([0.0, 0.1, 0.2, 0.3, 0.4]))), 1e-12).all())
Example #3
Source File: sequence_labeling.py    From GraphIE with GNU General Public License v3.0 6 votes vote down vote up
def decode(self, input_word_orig, input_word, input_char, adjs, target=None, mask=None, length=None, hx=None,
               leading_symbolic=0, graph_types=None):
        """Decode tag sequences with the CRF on top of the GCN output.

        Args:
            graph_types: graph types forwarded to ``_get_gcn_output``;
                ``None`` (the default) means ``['coref']``.

        Returns:
            ``(preds, None)`` when ``target`` is None, otherwise
            ``(preds, num_correct)``. ``num_correct`` counts positions where
            ``preds`` equals ``target``, weighted by ``sent_mask`` when the
            caller passed a ``mask``.
        """
        # Build a fresh list per call; a list literal as the default value
        # would be one object shared (and possibly mutated) across calls.
        if graph_types is None:
            graph_types = ['coref']

        # output from rnn [batch, length, tag_space]
        output, target, sent_mask, length, _ = self._get_gcn_output(input_word_orig, input_word, input_char, adjs,
                                                                    target,
                                                                    mask=mask, length=length, hx=hx,
                                                                    leading_symbolic=leading_symbolic,
                                                                    graph_types=graph_types)

        preds = self.crf.decode(output, mask=sent_mask, leading_symbolic=leading_symbolic)
        if target is None:
            return preds, None

        correct = torch.eq(preds, target).float()
        if mask is not None:
            # The test is on the caller's `mask`; the weights applied are `sent_mask`.
            correct = correct * sent_mask
        return preds, correct.sum()
Example #4
Source File: sequence_labeling.py    From GraphIE with GNU General Public License v3.0 6 votes vote down vote up
def decode(self, input_word_orig, input_word, input_char, _, target=None, mask=None, length=None, hx=None,
               leading_symbolic=0):
        """Decode tag sequences with the CRF and, given ``target``, count correct tags.

        Returns ``(preds, None)`` when ``target`` is None. Otherwise returns
        ``(preds, num_correct)``, where ``num_correct`` counts positions whose
        prediction equals the (length-trimmed) target, weighted by the mask
        returned from ``_get_rnn_output`` when that mask is not None.
        """
        if input_word.dim() == 3:
            # input_word is the packed sents [n_sent, sent_len]
            # NOTE(review): sent_mask and doc_n_sent are not used below — confirm intended.
            input_word, input_char, target, sent_mask, length, doc_n_sent = self._doc2sent(
                input_word, input_char, target)
        # output from rnn [batch, length, tag_space]
        output, _, mask, length = self._get_rnn_output(input_word_orig, input_word, input_char, mask=mask,
                                                       length=length, hx=hx)

        if target is None:
            return self.crf.decode(output, mask=mask, leading_symbolic=leading_symbolic), None

        if length is not None:
            # drop target columns beyond the longest sequence in the batch
            target = target[:, :length.max()]

        preds = self.crf.decode(output, mask=mask, leading_symbolic=leading_symbolic)
        hits = torch.eq(preds, target).float()
        if mask is not None:
            hits = hits * mask
        return preds, hits.sum()
Example #5
Source File: som.py    From USIP with GNU General Public License v3.0 6 votes vote down vote up
def query(self, x):
        '''
        Assign each input point to its nearest SOM node.

        :param x: input data CxN tensor
        :return: (mask, mask_row_max)
            mask: Nxnode_num float tensor, 1 at each point's nearest node
            mask_row_max: node_num, 1 where at least one point picked the node
        '''
        node_count = self.rows * self.cols
        # broadcast nodes and points to CxNxnode_num
        node_grid = self.node.unsqueeze(1).expand(x.size(0), x.size(1), node_count)
        point_grid = x.unsqueeze(2).expand_as(node_grid)

        # squared euclidean distance, Nxnode_num
        sq_dist = ((point_grid - node_grid) ** 2).sum(dim=0)

        # index of the closest node per point, N
        nearest = torch.min(sq_dist, dim=1)[1]
        nearest_grid = nearest.unsqueeze(1).expand(nearest.size()[0], node_count).float()  # Nxnode_num

        id_grid = self.node_idx_list.unsqueeze(0).expand_as(nearest_grid)  # Nxnode_num
        mask = torch.eq(nearest_grid, id_grid).float()
        mask_row_max = torch.max(mask, dim=0)[0]
        return mask, mask_row_max
Example #6
Source File: test_dataset.py    From kge with MIT License 6 votes vote down vote up
def test_data_pickle_correctness(self):
        """Triples re-loaded from the pickle files must equal the freshly created ones."""
        def create_dataset():
            return Dataset.create(
                config=self.config, folder=self.dataset_folder, preload_data=True
            )

        # first creation writes new pickle files for train, valid, test
        fresh = create_dataset()
        # second creation loads the triples from the stored pckl files
        reloaded = create_dataset()
        for split in fresh._triples.keys():
            reloaded_split = reloaded.split(split)
            fresh_split = fresh.split(split)
            self.assertTrue(torch.all(torch.eq(reloaded_split, fresh_split)))
        self.assertEqual(fresh._meta, reloaded._meta)
Example #7
Source File: test_continuous_mlp_q_function.py    From garage with MIT License 6 votes vote down vote up
def test_forward(self, hidden_sizes):
        """Forward pass with all-ones weights and no hidden nonlinearity.

        For all-ones observation and action inputs the test expects the single
        Q-value ``(obs_dim + act_dim) * prod(hidden_sizes)``.
        """
        env_spec = GarageEnv(DummyBoxEnv())
        obs_dim = env_spec.observation_space.flat_dim
        act_dim = env_spec.action_space.flat_dim
        obs, act = (torch.ones(dim, dtype=torch.float32).unsqueeze(0)
                    for dim in (obs_dim, act_dim))

        qf = ContinuousMLPQFunction(env_spec=env_spec,
                                    hidden_nonlinearity=None,
                                    hidden_sizes=hidden_sizes,
                                    hidden_w_init=nn.init.ones_,
                                    output_w_init=nn.init.ones_)

        q_value = qf(obs, act)
        expected_value = (obs_dim + act_dim) * np.prod(hidden_sizes)
        expected_q = torch.full([1, 1], fill_value=expected_value, dtype=torch.float32)
        assert torch.eq(q_value, expected_q)

    # yapf: disable
Example #8
Source File: test_continuous_mlp_q_function.py    From garage with MIT License 6 votes vote down vote up
def test_is_pickleable(self, hidden_sizes):
        """A pickled-and-restored Q-function must produce the same output."""
        env_spec = GarageEnv(DummyBoxEnv())
        obs_dim = env_spec.observation_space.flat_dim
        act_dim = env_spec.action_space.flat_dim
        obs, act = (torch.ones(dim, dtype=torch.float32).unsqueeze(0)
                    for dim in (obs_dim, act_dim))

        qf = ContinuousMLPQFunction(env_spec=env_spec,
                                    hidden_nonlinearity=None,
                                    hidden_sizes=hidden_sizes,
                                    hidden_w_init=nn.init.ones_,
                                    output_w_init=nn.init.ones_)

        before = qf(obs, act)
        restored = pickle.loads(pickle.dumps(qf))
        after = restored(obs, act)

        assert torch.eq(before, after)
Example #9
Source File: precision.py    From UnsupervisedGeometryAwareRepresentationLearning with GNU General Public License v3.0 6 votes vote down vote up
def update(self, output):
        """Accumulate per-class predicted-positive and true-positive counts (precision).

        ``output`` is ``(y_pred, y)``. The prediction is the argmax of
        ``y_pred`` over dim 1, and ``y`` holds the target class indices.
        """
        y_pred, y = output
        n_classes = y_pred.size(1)
        predicted = y_pred.max(1)[1]
        hit = torch.eq(predicted, y)
        all_positives = to_onehot(predicted, n_classes).sum(dim=0)
        if hit.any():
            true_positives = to_onehot(predicted[hit], n_classes).sum(dim=0)
        else:
            # nothing predicted correctly: no true positives for any class
            true_positives = torch.zeros_like(all_positives)

        if self._all_positives is None:
            self._all_positives = all_positives
            self._true_positives = true_positives
            return
        self._all_positives += all_positives
        self._true_positives += true_positives
Example #10
Source File: recall.py    From UnsupervisedGeometryAwareRepresentationLearning with GNU General Public License v3.0 6 votes vote down vote up
def update(self, output):
        """Accumulate per-class actual-positive and true-positive counts (recall).

        ``output`` is ``(y_pred, y)``. The prediction is the argmax of
        ``y_pred`` over dim 1, and ``y`` holds the target class indices.
        """
        y_pred, y = output
        n_classes = y_pred.size(1)
        predicted = y_pred.max(1)[1]
        hit = torch.eq(predicted, y)
        actual = to_onehot(y, n_classes).sum(dim=0)
        if hit.any():
            true_positives = to_onehot(predicted[hit], n_classes).sum(dim=0)
        else:
            # nothing predicted correctly: no true positives for any class
            true_positives = torch.zeros_like(actual)

        if self._actual is None:
            self._actual = actual
            self._true_positives = true_positives
            return
        self._actual += actual
        self._true_positives += true_positives
Example #11
Source File: test_coding.py    From nn-compression with MIT License 6 votes vote down vote up
def test_encode_param():
    """Round-trip a pruned, quantized tensor through each EncodedParam method."""
    param = torch.rand(256, 128, 3, 3)
    prune_vanilla_elementwise(sparsity=0.7, param=param)
    quantize_linear_fix_zeros(param, k=16)

    def encode_and_check(**options):
        # encode `param`, print the stats and require a lossless decode
        encoded = EncodedParam(param=param, **options)
        print(encoded.stats)
        assert torch.eq(param, encoded.data).all()
        return encoded

    huffman = encode_and_check(method='huffman',
                               encode_indices=True, bit_length_zero_run_length=4)
    saved_state = huffman.state_dict()
    restored = EncodedParam()
    restored.load_state_dict(saved_state)
    assert torch.eq(param, restored.data).all()

    encode_and_check(method='vanilla',
                     encode_indices=True, bit_length_zero_run_length=4)

    quantize_fixed_point(param=param, bit_length=4, bit_length_integer=0)
    encode_and_check(method='fixed_point',
                     bit_length=4, bit_length_integer=0,
                     encode_indices=True, bit_length_zero_run_length=4)
Example #12
Source File: hinge.py    From dfw with MIT License 6 votes vote down vote up
def _compute_xi(self, s, aug, y):
        """Return the per-class weights ``xi`` for each sample.

        Without smoothing, ``xi`` is the one-hot encoding (over ``self._range``)
        of the argmax of ``aug``. With ``MultiClassHingeLoss.smooth`` set, it is
        ``softmax(s)`` for samples whose ``sum(softmax(s) * aug)`` is >= 0, and
        the one-hot argmax for the others.
        NOTE(review): ``y`` is not used here.
        """
        winners = aug.max(1)[1]
        xi_max = torch.eq(winners[:, None], self._range).float()
        if not MultiClassHingeLoss.smooth:
            return xi_max

        xi_smooth = nn.functional.softmax(s, dim=1)
        # 1.0 for samples with a non-negative smoothed contribution, else 0.0
        keep_smooth = torch.ge(torch.sum(xi_smooth * aug, 1), 0).float()[:, None]
        return keep_smooth * xi_smooth + (1 - keep_smooth) * xi_max
Example #13
Source File: model.py    From VSE-C with MIT License 6 votes vote down vote up
def forward(self, feed_dict):
        """Predict an embedding from forward/backward sentence features and the image.

        In training mode returns ``(loss, {}, {})``, where ``loss`` is the mean
        cosine loss against the embedded label. Otherwise returns a dict with
        ``pred`` and, when ``label`` is present, ``top1``/``top10``/``top100``/
        ``top1000`` hit rates.
        """
        feed_dict = GView(feed_dict)
        sent_f = self._extract_sent_feature(feed_dict.sent_f, feed_dict.sent_f_length, self.gru_f)
        sent_b = self._extract_sent_feature(feed_dict.sent_b, feed_dict.sent_b_length, self.gru_b)
        predict = self.predict(torch.cat([sent_f, sent_b, feed_dict.image], dim=1))

        if self.training:
            loss = cosine_loss(predict, self.embedding(feed_dict.label)).mean()
            return loss, {}, {}

        output_dict = dict(pred=predict)
        if 'label' in feed_dict:
            dis = cosine_distance(predict, self.embedding.weight)
            # NOTE(review): topk keeps the *largest* values of something named a
            # distance; confirm cosine_distance returns a similarity-like score.
            _, topk = dis.topk(1000, dim=1, sorted=True)
            hits = torch.eq(topk, feed_dict.label.unsqueeze(-1))
            for k in (1, 10, 100, 1000):
                output_dict['top{}'.format(k)] = hits[:, :k].float().sum(dim=1).mean()
        return output_dict
Example #14
Source File: accuracy.py    From LaSO with BSD 3-Clause "New" or "Revised" License 6 votes vote down vote up
def update(self, output):
        """Add a batch's correct/total counts according to ``self._type``.

        - "binary": predictions are cast to the target type and compared element-wise.
        - "multiclass": the argmax over dim 1 is compared with the targets.
        - "multilabel": a sample is correct only if every class matches.
        """
        y_pred, y = self._check_shape(output)
        self._check_type((y_pred, y))

        kind = self._type
        if kind == "binary":
            correct = torch.eq(y_pred.type(y.type()), y).view(-1)
        elif kind == "multiclass":
            correct = torch.eq(y_pred.max(dim=1)[1], y).view(-1)
        elif kind == "multilabel":
            # move the class axis last and flatten: (N, C, ...) -> (N x ..., C)
            n_classes = y_pred.size(1)
            last_axis = y_pred.ndimension() - 1
            flat_pred = y_pred.transpose(1, last_axis).reshape(-1, n_classes)
            flat_true = y.transpose(1, last_axis).reshape(-1, n_classes)
            correct = (flat_true == flat_pred.type_as(flat_true)).all(dim=-1)

        self._num_correct += correct.sum().item()
        self._num_examples += correct.shape[0]
Example #15
Source File: metrics.py    From LaSO with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
def update(self, output):
        """Count element-wise correct binary predictions.

        ``y_pred`` is passed through a sigmoid, and values above 0.5 count as 1.
        """
        scores, y = output
        predicted = (torch.sigmoid(scores) > 0.5).float()
        hits = torch.eq(predicted, y).view(-1)
        self._num_correct += hits.sum().item()
        self._num_examples += hits.shape[0]
Example #16
Source File: categorical_accuracy.py    From argus with MIT License 5 votes vote down vote up
def update(self, step_output: dict):
        """Add the number of argmax-correct predictions and the batch size to the totals."""
        predicted = step_output['prediction'].max(dim=1)[1]
        hits = torch.eq(predicted, step_output['target']).view(-1)
        self.correct += hits.sum().item()
        self.count += hits.shape[0]
Example #17
Source File: loader.py    From AGGCN with MIT License 5 votes vote down vote up
def __getitem__(self, key):
        """ Get a batch with index.

        Raises TypeError for a non-int key and IndexError for an out-of-range
        key. The batch must have 9 fields (n-ary layout with three position fields).
        """
        if not isinstance(key, int):
            raise TypeError
        if not 0 <= key < len(self.data):
            raise IndexError
        batch = self.data[key]
        batch_size = len(batch)
        fields = list(zip(*batch))

        # for nary dataset
        assert len(fields) == 9

        # sort all fields by lens for easy RNN operations
        fields, orig_idx = sort_all(fields, [len(sent) for sent in fields[0]])

        # word dropout is applied only outside eval mode
        if self.eval:
            words = fields[0]
        else:
            words = [word_dropout(sent, self.opt['word_dropout']) for sent in fields[0]]

        # convert to tensors; masks is True where the word id is 0 (presumably padding)
        words = get_long_tensor(words, batch_size)
        masks = torch.eq(words, 0)
        pos, deprel, head, first_positions, second_positions, third_positions = (
            get_long_tensor(fields[i], batch_size) for i in range(1, 7))
        cross = fields[7]
        rels = torch.LongTensor(fields[8])

        return (words, masks, pos, deprel, head, first_positions, second_positions, third_positions, cross, rels, orig_idx)
Example #18
Source File: loader.py    From AGGCN with MIT License 5 votes vote down vote up
def __getitem__(self, key):
        """ Get a batch with index.

        Raises TypeError for a non-int key and IndexError for an out-of-range
        key. The batch must have 10 fields when ``dataset == 'dataset/tacred'``,
        otherwise 7.
        NOTE(review): ``dataset`` is not defined in this method — presumably a
        module-level name; confirm. With 10 fields only the first seven are used
        and rels is read from fields[6]; confirm this matches that layout.
        """
        if not isinstance(key, int):
            raise TypeError
        if not 0 <= key < len(self.data):
            raise IndexError
        batch = self.data[key]
        batch_size = len(batch)
        fields = list(zip(*batch))
        assert len(fields) == (10 if dataset == 'dataset/tacred' else 7)

        # sort all fields by lens for easy RNN operations
        fields, orig_idx = sort_all(fields, [len(sent) for sent in fields[0]])

        # word dropout is applied only outside eval mode
        if self.eval:
            words = fields[0]
        else:
            words = [word_dropout(sent, self.opt['word_dropout']) for sent in fields[0]]

        # convert to tensors; masks is True where the word id is 0 (presumably padding)
        words = get_long_tensor(words, batch_size)
        masks = torch.eq(words, 0)
        pos, deprel, head, subj_positions, obj_positions = (
            get_long_tensor(fields[i], batch_size) for i in range(1, 6))
        rels = torch.LongTensor(fields[6])
        return (words, masks, pos, deprel, head, subj_positions, obj_positions, rels, orig_idx)
Example #19
Source File: loader.py    From AGGCN with MIT License 5 votes vote down vote up
def __getitem__(self, key):
        """ Get a batch with index.

        Raises TypeError for a non-int key and IndexError for an out-of-range
        key. The batch must have 8 fields (n-ary layout with two position fields).
        """
        if not isinstance(key, int):
            raise TypeError
        if not 0 <= key < len(self.data):
            raise IndexError
        batch = self.data[key]
        batch_size = len(batch)
        fields = list(zip(*batch))

        # for nary dataset
        assert len(fields) == 8

        # sort all fields by lens for easy RNN operations
        fields, orig_idx = sort_all(fields, [len(sent) for sent in fields[0]])

        # word dropout is applied only outside eval mode
        if self.eval:
            words = fields[0]
        else:
            words = [word_dropout(sent, self.opt['word_dropout']) for sent in fields[0]]

        # convert to tensors; masks is True where the word id is 0 (presumably padding)
        words = get_long_tensor(words, batch_size)
        masks = torch.eq(words, 0)
        pos, deprel, head, first_positions, second_positions = (
            get_long_tensor(fields[i], batch_size) for i in range(1, 6))
        cross = fields[6]
        rels = torch.LongTensor(fields[7])

        return (words, masks, pos, deprel, head, first_positions, second_positions, cross, rels, orig_idx)
Example #20
Source File: loader.py    From AGGCN with MIT License 5 votes vote down vote up
def __getitem__(self, key):
        """ Get a batch with index.

        Raises TypeError for a non-int key and IndexError for an out-of-range
        key. The batch must have 10 fields (including NER and entity-type fields).
        """
        if not isinstance(key, int):
            raise TypeError
        if not 0 <= key < len(self.data):
            raise IndexError
        batch = self.data[key]
        batch_size = len(batch)
        fields = list(zip(*batch))
        assert len(fields) == 10

        # sort all fields by lens for easy RNN operations
        fields, orig_idx = sort_all(fields, [len(sent) for sent in fields[0]])

        # word dropout is applied only outside eval mode
        if self.eval:
            words = fields[0]
        else:
            words = [word_dropout(sent, self.opt['word_dropout']) for sent in fields[0]]

        # convert to tensors; masks is True where the word id is 0 (presumably padding)
        words = get_long_tensor(words, batch_size)
        masks = torch.eq(words, 0)
        (pos, ner, deprel, head, subj_positions, obj_positions,
         subj_type, obj_type) = (get_long_tensor(fields[i], batch_size) for i in range(1, 9))

        rels = torch.LongTensor(fields[9])

        return (words, masks, pos, ner, deprel, head, subj_positions, obj_positions, subj_type, obj_type, rels, orig_idx)
Example #21
Source File: som.py    From USIP with GNU General Public License v3.0 5 votes vote down vote up
def query_topk(node, x, M, k):
    '''
    Assign every input point to its k nearest SOM nodes.

    :param node: SOM node of BxCxM tensor
    :param x: input data BxCxN tensor
    :param M: number of SOM nodes
    :param k: number of nearest nodes selected per point (topk, unsorted)
    :return: (mask, mask_row_max, min_idx)
        mask: BxkNxM int tensor; row j*N + n is one-hot at the j-th selected node of point n
        mask_row_max: BxM, 1 where at least one point selected the node
        min_idx: BxkN selected node indices, in the same j*N + n order
    '''
    # keep x, the nodes and the index list on the same device
    device = x.device
    node = node.to(device)
    ids = torch.from_numpy(np.arange(M).astype(np.int64)).to(device)  # M LongTensor

    # broadcast to BxCxNxM and take squared distances, BxNxM
    node_grid = node.unsqueeze(2).expand(x.size(0), x.size(1), x.size(2), M)
    point_grid = x.unsqueeze(3).expand_as(node_grid)
    sq_dist = ((point_grid - node_grid) ** 2).sum(dim=1)

    # k closest nodes per point, BxNxk
    nearest = torch.topk(sq_dist, k=k, dim=2, largest=False, sorted=False)[1]
    B, N = nearest.size(0), nearest.size(1)
    nearest_grid = nearest.unsqueeze(2).expand(B, N, M, k)  # BxNxMxk
    id_grid = ids.view(1, 1, M, 1).expand_as(nearest_grid).long()  # BxNxMxk
    hits = torch.eq(nearest_grid, id_grid).int()

    # BxNxMxk -> BxMxkxN -> BxMxkN -> BxkNxM
    mask = hits.permute(0, 2, 3, 1).contiguous().view(B, M, k * N).permute(0, 2, 1).contiguous()
    min_idx = nearest.permute(0, 2, 1).contiguous().view(B, k * N)

    mask_row_max = torch.max(mask, dim=1)[0]
    return mask, mask_row_max, min_idx
Example #22
Source File: transforms.py    From PyTorch-ENet with MIT License 5 votes vote down vote up
def __call__(self, tensor):
        """Performs the conversion from ``torch.LongTensor`` to a ``PIL image``

        Keyword arguments:
        - tensor (``torch.LongTensor``): class-index map of shape ``(H, W)`` or
        ``(1, H, W)``. Label ``i`` is painted with the colour of the i-th entry
        of ``self.rgb_encoding``. The input tensor is not modified.

        Returns:
        A ``PIL.Image``. Pixels whose label has no entry in the encoding are
        black (0, 0, 0).

        Raises:
        TypeError if ``tensor`` is not a ``torch.LongTensor`` or if
        ``self.rgb_encoding`` is not an ``OrderedDict``.

        """
        # Check if label_tensor is a LongTensor
        if not isinstance(tensor, torch.LongTensor):
            raise TypeError("label_tensor should be torch.LongTensor. Got {}"
                            .format(type(tensor)))
        # Check if encoding is a ordered dictionary
        if not isinstance(self.rgb_encoding, OrderedDict):
            raise TypeError("encoding should be an OrderedDict. Got {}".format(
                type(self.rgb_encoding)))

        # label_tensor might be an image without a channel dimension; add it
        # out-of-place so the caller's tensor is left untouched
        if len(tensor.size()) == 2:
            tensor = tensor.unsqueeze(0)

        # zero-initialised so unlabelled pixels are black rather than garbage
        color_tensor = torch.zeros(3, tensor.size(1), tensor.size(2), dtype=torch.uint8)

        for index, (_class_name, color) in enumerate(self.rgb_encoding.items()):
            # HxW mask of pixels with this label; squeeze only the channel axis
            # so images with H == 1 or W == 1 keep their shape
            mask = torch.eq(tensor, index).squeeze(0)
            # Fill color_tensor with corresponding colors
            for channel, color_value in enumerate(color):
                color_tensor[channel].masked_fill_(mask, color_value)

        return ToPILImage()(color_tensor)
Example #23
Source File: BayesianConvs.py    From UCB with MIT License 5 votes vote down vote up
def prune_module(self, mask):
        """Multiply ``weight_mu`` element-wise by ``mask`` in place and keep the result.

        Sets ``mask_flag``. Because ``mul_`` works on ``weight_mu.data``,
        ``pruned_weight_mu`` shares storage with the live weight.
        """
        self.mask_flag = True
        masked = self.weight_mu.data
        masked.mul_(mask)
        self.pruned_weight_mu = masked
        # self.pruned_weight_rho=self.weight_rho.data.mul_(mask)
        # pruning_mask = torch.eq(mask, torch.zeros_like(mask))
Example #24
Source File: sequence_labeling.py    From GraphIE with GNU General Public License v3.0 5 votes vote down vote up
def loss(self, input_word_orig, input_word, input_char, target, mask=None, length=None, hx=None, leading_symbolic=0,
             show_net=False):
        """Compute the token-level NLL loss, the number of correct tags and the predictions.

        The loss is averaged over the tokens selected by the mask returned from
        ``forward`` when there is one, otherwise over all ``batch * length``
        tokens. ``show_net`` is accepted for compatibility and not used here.

        Returns:
            ``(loss, num_correct, preds)`` with ``preds`` of shape ``[batch, length]``.
        """
        # [batch, length, tag_space]
        output, mask, length = self.forward(input_word_orig, input_word, input_char, mask=mask, length=length, hx=hx)
        # [batch, length, num_labels]
        output = self.dense_softmax(output)
        # preds = [batch, length]; labels below `leading_symbolic` are never predicted
        _, preds = torch.max(output[:, :, leading_symbolic:], dim=2)
        preds += leading_symbolic

        batch_size, seq_len, num_labels = output.size()
        num_tokens = batch_size * seq_len
        # [batch * length, num_labels]
        output = output.view(num_tokens, num_labels)

        if length is not None:
            # Trim padded targets to the decoded length. Use the mask width when
            # there is a mask, else the prediction width, so mask=None is handled.
            decoded_len = mask.size(1) if mask is not None else preds.size(1)
            if target.size(1) != decoded_len:
                max_len = length.max()
                target = target[:, :max_len].contiguous()

        token_nll = self.nll_loss(self.logsoftmax(output), target.view(-1))
        if mask is not None:
            loss_value = (token_nll * mask.contiguous().view(-1)).sum() / mask.sum()
            num_correct = (torch.eq(preds, target).type_as(mask) * mask).sum()
        else:
            # divide by the token count (batch * length), not batch * length * num_labels
            loss_value = token_nll.sum() / num_tokens
            num_correct = torch.eq(preds, target).type_as(output).sum()
        return loss_value, num_correct, preds
Example #25
Source File: test_dataset.py    From kge with MIT License 5 votes vote down vote up
def assertEqualTorch(self, first, second, msg=None):
        """Compares first and second using ==, except for PyTorch tensors,
        where `torch.eq` is used.

        Dicts and lists are compared recursively, element by element. For
        ``KvsAllIndex`` objects, both sides must expose the same non-dunder
        attribute names, and every non-callable attribute value is compared
        recursively. Plain tensors must match in every element.
        """

        # TODO factor out to utility class
        self.assertEqual(type(first), type(second), msg=msg)
        if isinstance(first, dict):
            self.assertEqual(len(first), len(second), msg=msg)
            for key in first.keys():
                self.assertTrue(key in second, msg=msg)
                self.assertEqualTorch(first[key], second[key], msg=msg)
        elif isinstance(first, list):
            self.assertEqual(len(first), len(second), msg=msg)
            for i in range(len(first)):
                self.assertEqualTorch(first[i], second[i], msg=msg)
        elif isinstance(first, KvsAllIndex):
            first_attributes = [a for a in dir(first) if not a.startswith("__")]
            second_attributes = [a for a in dir(second) if not a.startswith("__")]
            # same attribute set on both sides (zip alone would silently truncate)
            self.assertEqual(first_attributes, second_attributes, msg=msg)
            for attribute in first_attributes:
                first_value = getattr(first, attribute)
                if callable(first_value):
                    # bound methods differ per instance, so they can never compare equal
                    continue
                # compare the attribute *values*, not just their names
                self.assertEqualTorch(first_value, getattr(second, attribute), msg=msg)
        else:
            if type(first) is torch.Tensor:
                self.assertTrue(torch.all(torch.eq(first, second)), msg=msg)
            else:
                self.assertEqual(first, second, msg=msg)
Example #26
Source File: test_coding.py    From nn-compression with MIT License 5 votes vote down vote up
def test_codec():
    """Encode a pruned and quantized model, decode it into a fresh model, and compare weights."""
    quantize_rule = [
        ('0.weight', 'k-means', 4, 'k-means++'),
        ('1.weight', 'fixed_point', 6, 1),
    ]

    def make_model():
        return torch.nn.Sequential(torch.nn.Conv2d(256, 128, 3, bias=True),
                                   torch.nn.Conv2d(128, 512, 1, bias=False))

    model = make_model()
    mask_dict = {name: prune_vanilla_elementwise(sparsity=0.6, param=p.data)
                 for name, p in model.named_parameters()}
    Quantizer(rule=quantize_rule, fix_zeros=True).quantize(model, update_labels=False, verbose=True)

    codec = Codec(rule=[
        ('0.weight', 'huffman', 0, 0, 4),
        ('1.weight', 'fixed_point', 6, 1, 4)
    ])
    encoded_module = codec.encode(model)
    print(codec.stats)
    state_dict = encoded_module.state_dict()
    decoded = Codec.decode(make_model(), state_dict)

    # only multi-dimensional parameters (the weights) are compared
    for original, restored in zip(model.parameters(), decoded.parameters()):
        if original.dim() <= 1:
            continue
        assert torch.eq(original, restored).all()
Example #27
Source File: linear.py    From nn-compression with MIT License 5 votes vote down vote up
def quantize_linear_fix_zeros(param, k=16, **unused):
    """
    linearly quantize ``param`` in place while keeping its exact zeros at zero

    Values are clipped to the range between the kth-smallest and the
    kth-largest element, with ``kth = ceil(numel * magic_percentile)``. They are
    then rounded onto ``k - 1`` evenly spaced levels spanning that range.
    Entries that were exactly zero beforehand are reset to zero.
    :param param: torch.(cuda.)tensor, modified in place
    :param k: int, codebook size (zero plus k - 1 linear levels), default=16; must be >= 3
    :param unused: unused options
    :return:
        dict, {'cluster_centers_': torch.tensor of length k, 'method': 'linear'},
        codebook of quantization; entry 0 is zero, entries 1..k-1 are the levels
    :raises ValueError: if k < 3 (the level spacing divides by k - 2)
    """
    if k < 3:
        raise ValueError("k must be at least 3, got {}".format(k))
    zero_mask = torch.eq(param, 0.0)  # get zero mask
    num_param = param.numel()
    kth = int(math.ceil(num_param * magic_percentile))
    # reshape instead of view so non-contiguous tensors work; only read below
    param_flatten = param.reshape(num_param)
    param_min, _ = torch.topk(param_flatten, kth, dim=0, largest=False, sorted=False)
    param_min = param_min.max()
    param_max, _ = torch.topk(param_flatten, kth, dim=0, largest=True, sorted=False)
    param_max = param_max.min()
    step = (param_max - param_min) / (k - 2)
    if step == 0:
        # Degenerate range: dividing by a zero step would turn every value into
        # NaN. Clamping alone maps everything onto the single available level.
        param.clamp_(param_min, param_max)
    else:
        param.clamp_(param_min, param_max).sub_(param_min).div_(step).round_().mul_(step).add_(param_min)
    param.masked_fill_(zero_mask, 0)  # recover zeros
    codebook = {'cluster_centers_': torch.zeros(k),
                'method': 'linear',
                }
    codebook['cluster_centers_'][1:] = torch.linspace(param_min, param_max, k - 1)
    return codebook
Example #28
Source File: group.py    From pose-ae-train with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
def nms(self, det):
        """Keep only entries of ``det`` equal to their pooled value; zero the rest.

        ``det`` is assumed to be a tensor accepted by ``self.pool``.
        """
        is_peak = torch.eq(self.pool(det), det).float()
        return det * is_peak
Example #29
Source File: few_shot.py    From cactus-protonets with MIT License 5 votes vote down vote up
def loss(self, sample):
        """Prototypical-network loss and accuracy for one few-shot episode.

        ``sample['xs']`` (support) and ``sample['xq']`` (query) are indexed
        ``[class, example, ...]`` and must have the same number of classes.
        Class prototypes are the mean support embeddings. Queries are scored by
        ``-euclidean_dist`` to each prototype.

        Returns:
            ``(loss_val, {'loss': float, 'acc': float})``
        """
        xs = Variable(sample['xs']) # support
        xq = Variable(sample['xq']) # query

        n_class = xs.size(0)
        assert xq.size(0) == n_class
        n_support = xs.size(1)
        n_query = xq.size(1)

        # [n_class, n_query, 1]: every query of class i has target label i
        target_inds = torch.arange(0, n_class).view(n_class, 1, 1).expand(n_class, n_query, 1).long()
        target_inds = Variable(target_inds, requires_grad=False)

        if xq.is_cuda:
            target_inds = target_inds.cuda()

        x = torch.cat([xs.view(n_class * n_support, *xs.size()[2:]),
                       xq.view(n_class * n_query, *xq.size()[2:])], 0)

        z = self.encoder.forward(x)
        z_dim = z.size(-1)

        z_proto = z[:n_class*n_support].view(n_class, n_support, z_dim).mean(1)
        zq = z[n_class*n_support:]

        dists = euclidean_dist(zq, z_proto)

        log_p_y = F.log_softmax(-dists, dim=1).view(n_class, n_query, -1)

        loss_val = -log_p_y.gather(2, target_inds).squeeze().view(-1).mean()

        _, y_hat = log_p_y.max(2)
        # Squeeze only the trailing axis. A bare squeeze() also drops the query
        # axis when n_query == 1, and the comparison then broadcasts to
        # [n_class, n_class], giving a wrong accuracy.
        acc_val = torch.eq(y_hat, target_inds.squeeze(2)).float().mean()

        return loss_val, {
            'loss': loss_val.item(),
            'acc': acc_val.item()
        }
Example #30
Source File: netmath.py    From ibeis with Apache License 2.0 5 votes vote down vote up
def _siamese_metrics(output, label, margin=1):
        """Compute accuracy and mean positive/negative distances for a siamese batch.

        Args:
            output: predicted distances, one per pair (assumed 1-D, same length as ``label`` — TODO confirm).
            label: pair labels; 1 marks a positive pair, anything else is treated as negative.
            margin: pairs with distance ``<= margin`` are predicted positive.

        Returns:
            dict with ``accuracy`` (fraction of pairs predicted correctly),
            ``pos_dist`` and ``neg_dist`` (mean distance of positive / negative
            pairs, or ``0`` when there are none).
        """
        l2_dist_tensor = torch.from_numpy(output.data.cpu().numpy())
        label_tensor = torch.from_numpy(label.data.cpu().numpy())

        # Distance
        POS_LABEL = 1
        NEG_LABEL = 0
        # Keep the masks boolean: `~` on a uint8 mask is bitwise NOT
        # (0 -> 255, 1 -> 254) and would select every element.
        is_pos = torch.eq(label_tensor, POS_LABEL)  # y==1
        pos_dists = l2_dist_tensor[is_pos]
        neg_dists = l2_dist_tensor[~is_pos]
        pos_dist = 0 if len(pos_dists) == 0 else pos_dists.mean()
        neg_dist = 0 if len(neg_dists) == 0 else neg_dists.mean()

        # accuracy
        pred_pos_flags = torch.le(l2_dist_tensor, margin)  # y==1's idx

        cur_score = torch.FloatTensor(label.size(0))
        cur_score.fill_(NEG_LABEL)
        cur_score[pred_pos_flags] = POS_LABEL

        label_tensor_ = label_tensor.type(torch.FloatTensor)
        # float division so the result is a fraction, never an integer-truncated count
        accuracy = torch.eq(cur_score, label_tensor_).sum().float() / label_tensor.size(0)

        metrics = {
            'accuracy': accuracy,
            'pos_dist': pos_dist,
            'neg_dist': neg_dist,
        }
        return metrics