Python torch.equal() Examples

The following are 30 code examples of torch.equal(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module torch, or try the search function.
Example #1
Source File: test_anchor.py    From mmdetection with Apache License 2.0 6 votes vote down vote up
def test_strides():
    """Check generated anchor grids for a square stride and for per-axis strides."""
    from mmdet.core import AnchorGenerator

    # Square strides
    square_gen = AnchorGenerator([10], [1.], [1.], [10])
    square_anchors = square_gen.grid_anchors([(2, 2)], device='cpu')
    square_expected = torch.tensor([
        [-5., -5., 5., 5.],
        [5., -5., 15., 5.],
        [-5., 5., 5., 15.],
        [5., 5., 15., 15.],
    ])
    assert torch.equal(square_anchors[0], square_expected)

    # Different strides in x and y direction
    rect_gen = AnchorGenerator([(10, 20)], [1.], [1.], [10])
    rect_anchors = rect_gen.grid_anchors([(2, 2)], device='cpu')
    rect_expected = torch.tensor([
        [-5., -5., 5., 5.],
        [5., -5., 15., 5.],
        [-5., 15., 5., 25.],
        [5., 15., 15., 25.],
    ])
    assert torch.equal(rect_anchors[0], rect_expected)
Example #2
Source File: polybeast_learn_function_test.py    From torchbeast with Apache License 2.0 6 votes vote down vote up
def test_gradients_update(self):
    """Check that gradients get updated after one iteration.

    Both models are reset to their initial weights and one
    ``polybeast.learn`` step is run. Afterwards every learner parameter must
    hold a gradient that is not all zeros, while the actor model must still
    hold no gradients.
    """
    # Reset models.
    self.model.load_state_dict(self.initial_model_dict)
    self.actor_model.load_state_dict(self.initial_actor_model_dict)

    # No gradient has been computed yet for either model.
    for net in (self.model, self.actor_model):
        for param in net.parameters():
            self.assertIsNone(param.grad)

    polybeast.learn(*self.learn_args)

    # Every learner parameter now has a gradient with at least one
    # non-zero entry.
    for param in self.model.parameters():
        grad = param.grad
        self.assertIsNotNone(grad)
        self.assertFalse(torch.equal(grad, torch.zeros_like(grad)))

    # The actor model must still have no gradients associated with it.
    for param in self.actor_model.parameters():
        self.assertIsNone(param.grad)
Example #3
Source File: test_wrappers.py    From mmdetection with Apache License 2.0 6 votes vote down vote up
def test_max_pool_2d():
    """Compare the ``MaxPool2d`` wrapper with ``nn.MaxPool2d``.

    For every combination of the parameters in ``test_cases``, the wrapper
    must accept an empty batch (batch size 0) and produce an output whose
    non-batch shape matches the reference op. On a non-empty batch, its
    output must equal the reference op's output exactly.
    """
    # Keys are listed in the same order as the loop variables below, so each
    # variable receives the values of the key with the same name (listing
    # 'in_w' first would silently swap height and width).
    test_cases = OrderedDict([('in_h', [10, 20]), ('in_w', [10, 20]),
                              ('in_channel', [1, 3]), ('out_channel', [1, 3]),
                              ('kernel_size', [3, 5]), ('stride', [1, 2]),
                              ('padding', [0, 1]), ('dilation', [1, 2])])

    # ``_out_cha`` is not used by pooling; it only repeats each case.
    for in_h, in_w, in_cha, _out_cha, k, s, p, d in product(
            *test_cases.values()):
        # wrapper op with an empty batch (batch size 0)
        x_empty = torch.randn(0, in_cha, in_h, in_w, requires_grad=True)
        wrapper = MaxPool2d(k, stride=s, padding=p, dilation=d)
        wrapper_out = wrapper(x_empty)

        # torch op with a batch of 3 as shape reference
        x_normal = torch.randn(3, in_cha, in_h, in_w)
        ref = nn.MaxPool2d(k, stride=s, padding=p, dilation=d)
        ref_out = ref(x_normal)

        assert wrapper_out.shape[0] == 0
        assert wrapper_out.shape[1:] == ref_out.shape[1:]

        assert torch.equal(wrapper(x_normal), ref_out)
Example #4
Source File: utils.py    From PyTorch-NLP with BSD 3-Clause "New" or "Revised" License 6 votes vote down vote up
def torch_equals_ignore_index(tensor, tensor_other, ignore_index=None):
    """
    Compute ``torch.equal`` with the optional mask parameter.

    Args:
        tensor (torch.Tensor): First tensor. When ``ignore_index`` is given,
            every position where this tensor equals ``ignore_index`` is
            dropped from both tensors before comparing.
        tensor_other (torch.Tensor): Second tensor. Must have the same size
            as ``tensor`` when ``ignore_index`` is given.
        ignore_index (int, optional): Value in ``tensor`` whose positions
            are ignored.

    Returns:
        (bool) ``True`` if the (optionally masked) tensors are equal.
    """
    if ignore_index is None:
        return torch.equal(tensor, tensor_other)

    assert tensor.size() == tensor_other.size()
    keep = tensor.ne(ignore_index)
    return torch.equal(tensor.masked_select(keep),
                       tensor_other.masked_select(keep))
Example #5
Source File: model.py    From lightNLP with Apache License 2.0 6 votes vote down vote up
def forward(self, left, right):
    """Encode ``left`` and ``right`` with the shared LSTM and compare them.

    Both inputs are embedded with ``self.embedding`` and run through
    ``self.lstm``, starting from the same hidden state returned by
    ``self.init_hidden``. That state is also stored on ``self.hidden``.
    Returns ``self.manhattan_distance`` of index ``[0]`` of the two final
    hidden-state tensors.
    """
    left_emb = self.embedding(left.to(DEVICE)).to(DEVICE)
    right_emb = self.embedding(right.to(DEVICE)).to(DEVICE)

    # NOTE(review): the batch size is read from dim 1, so inputs are
    # presumably sequence-first (seq_len, batch); confirm against callers.
    self.hidden = self.init_hidden(batch_size=left.size(1))

    _, (left_h, _) = self.lstm(left_emb, self.hidden)
    _, (right_h, _) = self.lstm(right_emb, self.hidden)

    return self.manhattan_distance(left_h[0], right_h[0])
Example #6
Source File: test_model_scorers.py    From translate with BSD 3-Clause "New" or "Revised" License 6 votes vote down vote up
def test_r2l_scorer_prepare_inputs(self):
    """Check the encoder inputs built by ``R2LModelScorer.prepare_inputs``.

    With two source sentences and four hypotheses, each source row is
    expected to appear twice, in order, in the encoder inputs.
    """
    eos = self.task.tgt_dict.eos()
    src_tokens = torch.tensor([[6, 7, 8], [1, 2, 3]], dtype=torch.int)
    first_hypo = [12, 13, 14, eos]
    second_hypo = [22, 23, eos]
    hypos = [
        {"tokens": torch.tensor(token_ids, dtype=torch.int)}
        for token_ids in (first_hypo, second_hypo, first_hypo, second_hypo)
    ]

    with patch(
        "pytorch_translate.utils.load_diverse_ensemble_for_inference",
        return_value=([self.model], self.args, self.task),
    ):
        scorer = R2LModelScorer(self.args, "/tmp/model_path.txt", None, self.task)
        encoder_inputs, _tgt_tokens = scorer.prepare_inputs(src_tokens, hypos)
        # Test encoder inputs
        expected_src = torch.tensor(
            [[6, 7, 8], [6, 7, 8], [1, 2, 3], [1, 2, 3]], dtype=torch.int
        )
        assert torch.equal(
            encoder_inputs[0], expected_src
        ), "Encoder inputs are not as expected"
Example #7
Source File: text_encoder.py    From PyTorch-NLP with BSD 3-Clause "New" or "Revised" License 6 votes vote down vote up
def decode(self, encoded):
    """ Decodes an object.

    When ``self.enforce_reversible`` is set, this first checks that
    ``self.encode(self.decode(encoded))`` reproduces ``encoded``. The flag is
    switched off during that check so the nested ``self.decode`` call skips
    the check instead of recursing into it. The flag is restored even if the
    nested calls raise.

    Args:
        encoded (object): Encoded object.

    Returns:
        object: ``encoded``, returned unchanged. NOTE(review): presumably
        subclasses override this method to perform the actual decoding;
        confirm.

    Raises:
        ValueError: If ``enforce_reversible`` is set and re-encoding the
            decoded value does not reproduce ``encoded``.
    """
    if self.enforce_reversible:
        self.enforce_reversible = False
        try:
            decoded_encoded = self.encode(self.decode(encoded))
        finally:
            # Without this, an exception in encode/decode would leave the
            # reversibility check permanently disabled on this instance.
            self.enforce_reversible = True
        if not torch.equal(decoded_encoded, encoded):
            raise ValueError('Decoding is not reversible for "%s"' % encoded)

    return encoded
Example #8
Source File: test_optimizer.py    From mmcv with Apache License 2.0 6 votes vote down vote up
def check_default_optimizer(optimizer, model, prefix=''):
    """Assert that ``optimizer`` is the default SGD set-up for ``model``.

    The SGD defaults are compared against the ``base_lr``, ``momentum`` and
    ``base_wd`` names defined elsewhere in this module. The first param group
    must hold exactly the expected parameters of ``model``, in the expected
    order. Each expected name is looked up with ``prefix`` prepended.
    """
    assert isinstance(optimizer, torch.optim.SGD)
    defaults = optimizer.defaults
    assert defaults['lr'] == base_lr
    assert defaults['momentum'] == momentum
    assert defaults['weight_decay'] == base_wd

    group_params = optimizer.param_groups[0]['params']
    expected_names = [
        'param1', 'conv1.weight', 'conv2.weight', 'conv2.bias', 'bn.weight',
        'bn.bias', 'sub.param1', 'sub.conv1.weight', 'sub.conv1.bias',
        'sub.gn.weight', 'sub.gn.bias'
    ]
    named = dict(model.named_parameters())
    assert len(group_params) == len(expected_names)
    # Lengths are equal (asserted above), so zip visits every pair.
    for param, name in zip(group_params, expected_names):
        assert torch.equal(param, named[prefix + name])
Example #9
Source File: evaluation.py    From dgl with Apache License 2.0 6 votes vote down vote up
def recommend(self, full_graph, K, h_user, h_item):
    """
    Return a (n_user, K) matrix of recommended items for each user

    For each user, the embedding of one selected item (see below) is
    multiplied with every item embedding. Items the user is already connected
    to are set to ``-inf``, and the indices of the top ``K`` scores are kept.
    Users are processed in chunks of ``self.batch_size``. ``h_user`` is not
    used by this implementation.
    """
    user_item = full_graph.edge_type_subgraph([self.user_to_item_etype])
    num_users = full_graph.number_of_nodes(self.user_ntype)
    # Keep one outgoing edge per user, chosen by the ``self.timestamp`` edge
    # feature (presumably the most recent interaction).
    latest = dgl.sampling.select_topk(user_item, 1, self.timestamp, edge_dir='out')
    src_users, latest_items = latest.all_edges(form='uv', order='srcdst')
    # each user should have at least one "latest" interaction
    assert torch.equal(src_users, torch.arange(num_users))

    top_k_batches = []
    for users in torch.arange(num_users).split(self.batch_size):
        item_ids = latest_items[users].to(device=h_item.device)
        scores = h_item[item_ids] @ h_item.t()
        # exclude items that are already interacted
        for row, user_id in enumerate(users.tolist()):
            seen = full_graph.successors(user_id, etype=self.user_to_item_etype)
            scores[row, seen] = -np.inf
        top_k_batches.append(scores.topk(K, 1)[1])

    return torch.cat(top_k_batches, 0)
Example #10
Source File: models_resnext3d_test.py    From ClassyVision with MIT License 6 votes vote down vote up
def test_set_classy_state_weight_inflation(self):
    """Inflate 2D ResNet conv weights into a 3D model via ``set_classy_state``.

    This is an advanced use of the ``set_classy_state`` method. A 2D model's
    state is loaded into a 3D model. Every 5-D trunk weight whose temporal
    size (dim 2) grows from 1 must equal the 2D weight repeated along dim 2
    and divided by the new temporal size.
    """
    config_2d, config_3d = self._get_model_config_weight_inflation()
    state_2d = build_model(config_2d).get_classy_state()

    model_3d = build_model(config_3d)
    model_3d.set_classy_state(state_2d)
    trunk_3d = model_3d.get_classy_state()["model"]["trunk"]

    for name, w2d in state_2d["model"]["trunk"].items():
        w3d = trunk_3d[name]
        # inflation only applies to conv weights
        if w2d.dim() != 5:
            continue
        self.assertEqual(w3d.dim(), 5)
        depth = w3d.shape[2]
        if w2d.shape[2] == 1 and depth > 1:
            expected = w2d.repeat(1, 1, depth, 1, 1) / depth
            self.assertTrue(torch.equal(w3d, expected))
Example #11
Source File: test_patch.py    From higher with Apache License 2.0 6 votes vote down vote up
def testSubModuleDirectCall(self):
    """Check that patched submodules can be called directly.

    The patched module's submodule ``f`` is called on its own, and its
    output must equal the output of calling the whole patched module.
    NOTE(review): this assumes ``_NestedEnc`` stores the wrapped layer as
    ``f`` and simply forwards to it; confirm against its definition.
    """
    # A local ``Module`` class used to be defined here but was never used;
    # it has been removed as dead code.
    module = _NestedEnc(nn.Linear(3, 4))
    fmodule = higher.monkeypatch(module)

    xs = torch.randn(2, 3)
    fsubmodule = fmodule.f

    self.assertTrue(torch.equal(fmodule(xs), fsubmodule(xs)))
Example #12
Source File: test_higher.py    From higher with Apache License 2.0 6 votes vote down vote up
def testRandomForwards(self):
    """Test reference and patched net forward equivalence.

    The test runs ten trials. Each trial draws random fast weights and random
    inputs. The reference net receives the weights by name, and the patched
    net receives the same tensors as a list in ``named_parameters`` order.
    Both outputs must be exactly equal.
    """
    with higher.innerloop_ctx(self.target_net, self.opt) as (fnet, _):
        for _trial in range(10):
            named_weights = OrderedDict(
                (name, torch.rand(param.shape, requires_grad=True))
                for name, param in self.reference_net.named_parameters()
            )
            weight_list = list(named_weights.values())
            inputs = torch.rand(
                self.batch_size, self.num_in_channels, self.in_h, self.in_w
            )
            ref_out = self.reference_net(inputs, params=named_weights)
            patched_out = fnet(inputs, params=weight_list)
            self.assertTrue(torch.equal(ref_out, patched_out))
Example #13
Source File: test_network_utils.py    From cortex with BSD 3-Clause "New" or "Revised" License 6 votes vote down vote up
def test_apply_nonlinearity(simple_tensor):
    """Check that ``apply_nonlinearity`` applies the requested function.

    Args:
        simple_tensor(@pytest.fixture): torch.Tensor

    Asserts: the result of ``apply_nonlinearity(simple_tensor, 'tanh')``
    equals ``torch.tanh(simple_tensor)`` exactly.
    """
    nonlinearity_args = {}
    nonlinear = 'tanh'

    # ``torch.nn.functional.tanh`` is deprecated in favour of ``torch.tanh``,
    # which computes the same element-wise hyperbolic tangent.
    expected_output = torch.tanh(simple_tensor)
    applied_nonlinearity = apply_nonlinearity(simple_tensor, nonlinear,
                                              **nonlinearity_args)

    assert torch.equal(expected_output, applied_nonlinearity)
Example #14
Source File: test_utils.py    From mmcv with Apache License 2.0 6 votes vote down vote up
def test_set_random_seed():
    """Re-seeding with the same seed must reproduce every random stream.

    The first call (``set_random_seed(0)``) must leave cudnn non-deterministic
    and set ``PYTHONHASHSEED``. The second call (``set_random_seed(0, True)``)
    must make cudnn deterministic. Both calls must disable cudnn benchmarking.
    """
    def _draw():
        # One sample each from Python's, NumPy's and torch's generators.
        return random.randint(0, 10), np.random.rand(2, 2), torch.rand(2, 2)

    set_random_seed(0)
    first_py, first_np, first_torch = _draw()
    assert torch.backends.cudnn.deterministic is False
    assert torch.backends.cudnn.benchmark is False
    assert os.environ['PYTHONHASHSEED'] == str(0)

    set_random_seed(0, True)
    second_py, second_np, second_torch = _draw()
    assert torch.backends.cudnn.deterministic is True
    assert torch.backends.cudnn.benchmark is False

    assert first_py == second_py
    assert np.array_equal(first_np, second_np)
    assert torch.equal(first_torch, second_torch)
Example #15
Source File: test_gaussian_mlp_module.py    From garage with MIT License 6 votes vote down vote up
def test_std_share_network_output_values(input_dim, output_dim, hidden_sizes):
    """Check mean/variance of ``GaussianMLPTwoHeadedModule`` with unit weights.

    The setup uses no hidden nonlinearity, all weights initialised to one, an
    all-ones input and the ``'exp'`` std parameterization. Under it the test
    expects every mean entry to be ``input_dim * prod(hidden_sizes)`` and
    every variance entry to be ``exp(input_dim * prod(hidden_sizes)) ** 2``.
    """
    module = GaussianMLPTwoHeadedModule(
        input_dim=input_dim,
        output_dim=output_dim,
        hidden_sizes=hidden_sizes,
        hidden_nonlinearity=None,
        std_parameterization='exp',
        hidden_w_init=nn.init.ones_,
        output_w_init=nn.init.ones_)

    dist = module(torch.ones(input_dim))

    out_shape = (output_dim, )
    # Keep these as default-dtype torch ops (not Python float math): the
    # comparisons below use exact equality.
    hidden_prod = torch.Tensor(hidden_sizes).prod()
    exp_mean = torch.full(out_shape, input_dim * hidden_prod.item())
    exp_variance = (input_dim * hidden_prod).exp().pow(2).item()

    assert dist.mean.equal(exp_mean)
    assert dist.variance.equal(torch.full(out_shape, exp_variance))
    assert dist.rsample().shape == out_shape
Example #16
Source File: test_gaussian_mlp_module.py    From garage with MIT License 6 votes vote down vote up
def test_std_share_network_output_values_with_batch(input_dim, output_dim,
                                                    hidden_sizes):
    """Batched variant of the two-headed unit-weight mean/variance check.

    The module receives a batch of 5 all-ones inputs. Every mean entry is
    expected to be ``input_dim * prod(hidden_sizes)`` and every variance
    entry ``exp(input_dim * prod(hidden_sizes)) ** 2``.
    """
    module = GaussianMLPTwoHeadedModule(
        input_dim=input_dim,
        output_dim=output_dim,
        hidden_sizes=hidden_sizes,
        hidden_nonlinearity=None,
        std_parameterization='exp',
        hidden_w_init=nn.init.ones_,
        output_w_init=nn.init.ones_)

    batch_size = 5
    dist = module(torch.ones([batch_size, input_dim]))

    out_shape = (batch_size, output_dim)
    # Keep these as default-dtype torch ops (not Python float math): the
    # comparisons below use exact equality.
    hidden_prod = torch.Tensor(hidden_sizes).prod()
    exp_mean = torch.full(out_shape, input_dim * hidden_prod.item())
    exp_variance = (input_dim * hidden_prod).exp().pow(2).item()

    assert dist.mean.equal(exp_mean)
    assert dist.variance.equal(torch.full(out_shape, exp_variance))
    assert dist.rsample().shape == out_shape
Example #17
Source File: test_gaussian_mlp_module.py    From garage with MIT License 6 votes vote down vote up
def test_std_network_output_values(input_dim, output_dim, hidden_sizes):
    """Check mean/variance of ``GaussianMLPModule`` with a fixed ``init_std``.

    With unit weights, no hidden nonlinearity and an all-ones input, every
    mean entry is expected to be ``input_dim * prod(hidden_sizes)``. Every
    variance entry is expected to be ``init_std ** 2``.
    """
    init_std = 2.

    module = GaussianMLPModule(
        input_dim=input_dim,
        output_dim=output_dim,
        hidden_sizes=hidden_sizes,
        init_std=init_std,
        hidden_nonlinearity=None,
        std_parameterization='exp',
        hidden_w_init=nn.init.ones_,
        output_w_init=nn.init.ones_)

    dist = module(torch.ones(input_dim))

    out_shape = (output_dim, )
    hidden_prod = torch.Tensor(hidden_sizes).prod().item()
    exp_mean = torch.full(out_shape, input_dim * hidden_prod)
    exp_variance = init_std**2

    assert dist.mean.equal(exp_mean)
    assert dist.variance.equal(torch.full(out_shape, exp_variance))
    assert dist.rsample().shape == out_shape
Example #18
Source File: test_gaussian_mlp_module.py    From garage with MIT License 6 votes vote down vote up
def test_std_adaptive_network_output_values(input_dim, output_dim,
                                            hidden_sizes, std_hidden_sizes):
    """Check mean/variance of ``GaussianMLPIndependentStdModule``.

    Both the mean and the std networks use unit weights and no hidden
    nonlinearity. For an all-ones input the test expects every mean entry to
    be ``input_dim * prod(hidden_sizes)`` and every variance entry to be
    ``exp(input_dim * prod(hidden_sizes)) ** 2``.
    """
    module = GaussianMLPIndependentStdModule(
        input_dim=input_dim,
        output_dim=output_dim,
        hidden_sizes=hidden_sizes,
        std_hidden_sizes=std_hidden_sizes,
        hidden_nonlinearity=None,
        hidden_w_init=nn.init.ones_,
        output_w_init=nn.init.ones_,
        std_hidden_nonlinearity=None,
        std_hidden_w_init=nn.init.ones_,
        std_output_w_init=nn.init.ones_)

    dist = module(torch.ones(input_dim))

    out_shape = (output_dim, )
    # Keep these as default-dtype torch ops (not Python float math): the
    # comparisons below use exact equality.
    hidden_prod = torch.Tensor(hidden_sizes).prod()
    exp_mean = torch.full(out_shape, input_dim * hidden_prod.item())
    exp_variance = (input_dim * hidden_prod).exp().pow(2).item()

    assert dist.mean.equal(exp_mean)
    assert dist.variance.equal(torch.full(out_shape, exp_variance))
    assert dist.rsample().shape == out_shape
Example #19
Source File: test_gaussian_mlp_module.py    From garage with MIT License 6 votes vote down vote up
def test_exp_min_std(input_dim, output_dim, hidden_sizes):
    """With ``min_std`` (10) above ``init_std`` (1), variance is ``min_std ** 2``."""
    min_value = 10.

    module = GaussianMLPModule(
        input_dim=input_dim,
        output_dim=output_dim,
        hidden_sizes=hidden_sizes,
        init_std=1.,
        min_std=min_value,
        hidden_nonlinearity=None,
        std_parameterization='exp',
        hidden_w_init=nn.init.zeros_,
        output_w_init=nn.init.zeros_)

    dist = module(torch.ones(input_dim))

    expected = torch.full((output_dim, ), min_value**2)
    assert dist.variance.equal(expected)
Example #20
Source File: test_gaussian_mlp_module.py    From garage with MIT License 6 votes vote down vote up
def test_exp_max_std(input_dim, output_dim, hidden_sizes):
    """With ``max_std`` (1) below ``init_std`` (10), variance is ``max_std ** 2``."""
    max_value = 1.

    module = GaussianMLPModule(
        input_dim=input_dim,
        output_dim=output_dim,
        hidden_sizes=hidden_sizes,
        init_std=10.,
        max_std=max_value,
        hidden_nonlinearity=None,
        std_parameterization='exp',
        hidden_w_init=nn.init.zeros_,
        output_w_init=nn.init.zeros_)

    dist = module(torch.ones(input_dim))

    expected = torch.full((output_dim, ), max_value**2)
    assert dist.variance.equal(expected)
Example #21
Source File: test_gaussian_mlp_module.py    From garage with MIT License 6 votes vote down vote up
def test_softplus_min_std(input_dim, output_dim, hidden_sizes):
    """Softplus parameterization with ``min_std`` above ``init_std``.

    The expected variance is ``log(exp(min_std) + 1) ** 2``, i.e.
    ``softplus(min_std) ** 2``, computed with torch ops below.
    """
    min_value = 2.

    module = GaussianMLPModule(
        input_dim=input_dim,
        output_dim=output_dim,
        hidden_sizes=hidden_sizes,
        init_std=1.,
        min_std=min_value,
        hidden_nonlinearity=None,
        std_parameterization='softplus',
        hidden_w_init=nn.init.zeros_,
        output_w_init=nn.init.zeros_)

    dist = module(torch.ones(input_dim))

    softplus_sq = torch.Tensor([min_value]).exp().add(1.).log()**2
    expected = torch.full((output_dim, ), softplus_sq[0])
    assert dist.variance.equal(expected)
Example #22
Source File: test_gaussian_mlp_module.py    From garage with MIT License 6 votes vote down vote up
def test_softplus_max_std(input_dim, output_dim, hidden_sizes):
    """Softplus parameterization with ``max_std`` below ``init_std``.

    The expected variance is ``log(exp(max_std) + 1) ** 2``, i.e.
    ``softplus(max_std) ** 2``, computed with torch ops below.
    """
    max_value = 1.

    module = GaussianMLPModule(
        input_dim=input_dim,
        output_dim=output_dim,
        hidden_sizes=hidden_sizes,
        init_std=10,
        max_std=max_value,
        hidden_nonlinearity=None,
        std_parameterization='softplus',
        hidden_w_init=nn.init.ones_,
        output_w_init=nn.init.ones_)

    dist = module(torch.ones(input_dim))

    softplus_sq = torch.Tensor([max_value]).exp().add(1.).log()**2
    expected = torch.full((output_dim, ), softplus_sq[0])
    assert dist.variance.equal(expected)
Example #23
Source File: test_convnets.py    From cortex with BSD 3-Clause "New" or "Revised" License 6 votes vote down vote up
def test_simple_conv_encoder_forward(simple_conv_encoder_image_classification,
                                     simple_tensor_conv2d):
    """Check the shape contract of ``SimpleConvEncoder.forward``.

    Args:
        simple_conv_encoder_image_classification (@pytest.fixture): SimpleConvEncoder
        simple_tensor_conv2d (@pytest.fixture): torch.Tensor

    Asserts: the input is 4-D, the output is 2-D, and the output is not
    identical to the input.
    """
    input_dim = simple_tensor_conv2d.dim()
    output = simple_conv_encoder_image_classification.forward(
        simple_tensor_conv2d)
    output_dim = output.dim()
    assert input_dim == 4
    assert output_dim == 2
    # NOTE: ``torch.equal`` is False whenever the shapes differ. With a 4-D
    # input and a 2-D output this check therefore cannot fail, and it does
    # not compare values element-wise. It only guards against the encoder
    # returning its input unchanged if the dimension expectations change.
    assert not torch.equal(simple_tensor_conv2d, output)
Example #24
Source File: test_label_encoder.py    From PyTorch-NLP with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
def test_label_encoder_batch_encoding(label_encoder):
    """Batch-encoding the whole vocabulary yields ``0 .. vocab_size - 1`` in order."""
    encoded = label_encoder.batch_encode(label_encoder.vocab)
    # ``torch.arange`` is already 1-D, so no reshape is needed.
    expected = torch.arange(label_encoder.vocab_size)
    assert torch.equal(encoded, expected)
Example #25
Source File: test_hook_args.py    From PySyft with Apache License 2.0 5 votes vote down vote up
def test_list_as_index(workers):
    """Index a tensor sent to worker ``bob`` with a list and an int, then ``.get()`` it."""
    remote = torch.tensor([10, 20, 30, -2, 3]).send(workers["bob"])
    expected = torch.tensor([10, 20, 30, 3])

    picked = remote[[0, 1, 2, 4]].get()
    single = remote[2].get()

    assert torch.equal(expected, picked)
    assert torch.equal(torch.tensor(30), single)
Example #26
Source File: test_higher.py    From higher with Apache License 2.0 5 votes vote down vote up
def testSameInitialWeightsPrePatch(self):
    """Check that reference and unpatched target net have equal weights.

    This is mostly a sanity check for the purpose of the other unit tests.
    Parameter counts, names (pairwise, in ``named_parameters`` order) and
    values must all match.
    """
    ref_params = list(self.reference_net.named_parameters())
    target_params = list(self.target_net.named_parameters())
    n_ref, n_target = len(ref_params), len(target_params)
    self.assertEqual(
        n_ref,
        n_target,
        msg=("Length mismatched between reference net parameter count "
             "({}) and target ({}).".format(n_ref, n_target)),
    )
    for (ref_name, ref_p), (target_name, target_p) in zip(ref_params,
                                                          target_params):
        name_msg = "Name mismatch or parameter misalignment ('{}' vs '{}')"
        self.assertEqual(
            ref_name,
            target_name,
            msg=name_msg.format(ref_name, target_name),
        )
        self.assertTrue(
            torch.equal(ref_p, target_p),
            msg="Parameter value inequality for {}".format(ref_name),
        )
Example #27
Source File: test_gaussian_mlp_module.py    From garage with MIT License 5 votes vote down vote up
def test_std_network_output_values_with_batch(input_dim, output_dim,
                                              hidden_sizes):
    """Batched variant of the fixed-``init_std`` mean/variance check.

    The module receives a batch of 5 all-ones inputs. Every mean entry is
    expected to be ``input_dim * prod(hidden_sizes)`` and every variance
    entry ``init_std ** 2``.
    """
    init_std = 2.

    module = GaussianMLPModule(
        input_dim=input_dim,
        output_dim=output_dim,
        hidden_sizes=hidden_sizes,
        init_std=init_std,
        hidden_nonlinearity=None,
        std_parameterization='exp',
        hidden_w_init=nn.init.ones_,
        output_w_init=nn.init.ones_)

    batch_size = 5
    dist = module(torch.ones([batch_size, input_dim]))

    out_shape = (batch_size, output_dim)
    hidden_prod = torch.Tensor(hidden_sizes).prod().item()
    exp_mean = torch.full(out_shape, input_dim * hidden_prod)
    exp_variance = init_std**2

    assert dist.mean.equal(exp_mean)
    assert dist.variance.equal(torch.full(out_shape, exp_variance))
    assert dist.rsample().shape == out_shape
Example #28
Source File: model_builder.py    From Detectron.pytorch with MIT License 5 votes vote down vote up
def compare_state_dict(sa, sb):
    """Return True if two state dicts have the same keys and equal tensors.

    Keys are compared as sets (insertion order is ignored). Values are
    compared with ``torch.equal``, so they must match in shape and in every
    element.
    """
    return sa.keys() == sb.keys() and all(
        torch.equal(value, sb[key]) for key, value in sa.items()
    )
Example #29
Source File: model.py    From lightNLP with Apache License 2.0 5 votes vote down vote up
def lstm_forward(self, sentence, sent_lengths):
    """Embed ``sentence``, run it through the packed LSTM and project to labels.

    Args:
        sentence: Token ids. NOTE(review): presumably (seq_len, batch), since
            ``pack_padded_sequence`` is called with its default
            ``batch_first=False``; confirm against callers.
        sent_lengths: Length of each sequence in the batch (a list or tensor;
            ``len(sent_lengths)`` is used as the batch size).

    Returns:
        The output of ``self.hidden2label`` on the padded LSTM outputs, moved
        to ``DEVICE``. The final LSTM state is stored on ``self.hidden``.
    """
    x = self.embedding(sentence.to(DEVICE)).to(DEVICE)
    x = pack_padded_sequence(x, sent_lengths)
    self.hidden = self.init_hidden(batch_size=len(sent_lengths))
    lstm_out, self.hidden = self.lstm(x, self.hidden)
    lstm_out, unpacked_lengths = pad_packed_sequence(lstm_out)
    # ``pad_packed_sequence`` returns the lengths on the CPU, and
    # ``pack_padded_sequence`` requires CPU lengths when given a tensor.
    # Compare on the returned lengths' device and dtype. Moving them to
    # DEVICE, as before, made ``torch.equal`` fail with a device mismatch
    # whenever DEVICE is a GPU.
    expected_lengths = torch.as_tensor(sent_lengths,
                                       dtype=unpacked_lengths.dtype,
                                       device=unpacked_lengths.device)
    assert torch.equal(expected_lengths, unpacked_lengths)
    y = self.hidden2label(lstm_out.to(DEVICE))
    return y.to(DEVICE)
Example #30
Source File: test_rescorer.py    From translate with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
def test_batch_computation(self):
    """Rescoring must give the same score with and without batching.

    A two-sentence batch is scored, then the first sentence is scored on its
    own. The first score of each run must be exactly equal.
    """
    import os
    import tempfile

    # Save the model into a private temporary directory that is removed
    # afterwards. A fixed /tmp path is left behind, can collide between
    # concurrent runs, and does not exist on every platform.
    with tempfile.TemporaryDirectory() as tmp_dir:
        test_args = test_utils.ModelParamsDict("transformer")
        test_args.enable_rescoring = True
        # Assigned once (it used to be set to 1 and then overwritten by 1.0).
        test_args.length_penalty = 1.0
        test_args.l2r_model_path = os.path.join(
            tmp_dir, "test_rescorer_model.pt"
        )
        test_args.l2r_model_weight = 1.0
        test_args.r2l_model_weight = 0.0
        test_args.reverse_model_weight = 0.0
        test_args.cloze_transformer_weight = 1.0
        test_args.lm_model_weight = 0.0

        _, src_dict, tgt_dict = test_utils.prepare_inputs(test_args)
        task = tasks.PytorchTranslateTask(test_args, src_dict, tgt_dict)
        model = task.build_model(test_args)
        torch.save(model, test_args.l2r_model_path)
        with patch(
            "pytorch_translate.utils.load_diverse_ensemble_for_inference",
            return_value=([model], test_args, task),
        ):
            rescorer = Rescorer(test_args)
            src_tokens = torch.tensor([[1, 3, 3, 4, 2], [1, 3, 2, 0, 0]])
            hypos = [
                {"tokens": torch.tensor([1, 5, 2])},
                {"tokens": torch.tensor([6, 3, 5, 2])},
                {"tokens": torch.tensor([1, 2])},
                {"tokens": torch.tensor([1, 5, 6, 2])},
            ]
            scores = rescorer.score(src_tokens, hypos)

            src_tokens = torch.tensor([[1, 3, 3, 4, 2]])
            hypos = [
                {"tokens": torch.tensor([1, 5, 2])},
                {"tokens": torch.tensor([6, 3, 5, 2])},
            ]
            scores_single = rescorer.score(src_tokens, hypos)

            assert torch.equal(scores[0], scores_single[0])