Python allennlp.nn.util.batched_index_select() Examples

The following are 15 code examples of allennlp.nn.util.batched_index_select(), taken from open-source projects. You may also want to check out all available functions/classes of the module allennlp.nn.util.
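The snippets below are excerpted from test and library source files, so they omit their imports, and a few reference helpers defined on the surrounding test class (e.g. assert_array_equal_with_mask) that are not shown here. A rough setup that the AllenNLP-style test examples assume is sketched below; module paths follow AllenNLP's layout, while the magnitude and AntNRE examples use their own packages' equivalents.

import numpy
import pytest
import torch

from allennlp.common.checks import ConfigurationError
from allennlp.nn import util
from allennlp.nn.util import batched_index_select
from allennlp.modules.span_extractors import EndpointSpanExtractor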
Example #1
Source File: endpoint_span_extractor_test.py    From allennlp with Apache License 2.0
def test_correct_sequence_elements_are_embedded(self):
        sequence_tensor = torch.randn([2, 5, 7])
        # Concatenate start and end points together to form our representation.
        extractor = EndpointSpanExtractor(7, "x,y")

        indices = torch.LongTensor([[[1, 3], [2, 4]], [[0, 2], [3, 4]]])
        span_representations = extractor(sequence_tensor, indices)

        assert list(span_representations.size()) == [2, 2, 14]
        assert extractor.get_output_dim() == 14
        assert extractor.get_input_dim() == 7

        start_indices, end_indices = indices.split(1, -1)
        # We just concatenated the start and end embeddings together, so
        # we can check they match the original indices if we split them apart.
        start_embeddings, end_embeddings = span_representations.split(7, -1)

        correct_start_embeddings = batched_index_select(sequence_tensor, start_indices.squeeze())
        correct_end_embeddings = batched_index_select(sequence_tensor, end_indices.squeeze())
        numpy.testing.assert_array_equal(
            start_embeddings.data.numpy(), correct_start_embeddings.data.numpy()
        )
        numpy.testing.assert_array_equal(
            end_embeddings.data.numpy(), correct_end_embeddings.data.numpy()
        ) 
Example #2
Source File: util_test.py    From allennlp with Apache License 2.0
def test_masked_topk_selects_top_scored_items_and_respects_masking(self):
        items = torch.randn([3, 4, 5]).clamp(min=0.0, max=1.0)
        items[0, :2, :] = 1
        items[1, 2:, :] = 1
        items[2, 2:, :] = 1

        scores = items.sum(-1)

        mask = torch.ones([3, 4]).bool()
        mask[1, 0] = 0
        mask[1, 3] = 0

        pruned_scores, pruned_mask, pruned_indices = util.masked_topk(scores, mask, 2)

        # Second element in the batch would have indices 2, 3, but
        # 3 and 0 are masked, so instead it has 1, 2.
        numpy.testing.assert_array_equal(
            pruned_indices.data.numpy(), numpy.array([[0, 1], [1, 2], [2, 3]])
        )
        numpy.testing.assert_array_equal(pruned_mask.data.numpy(), numpy.ones([3, 2]))

        # scores should be the result of index_selecting the pruned_indices.
        correct_scores = util.batched_index_select(scores.unsqueeze(-1), pruned_indices).squeeze(-1)
        self.assert_array_equal_with_mask(correct_scores, pruned_scores, pruned_mask) 
Example #3
Source File: util_test.py    From magnitude with MIT License
def test_batched_index_select(self):
        indices = numpy.array([[[1, 2],
                                [3, 4]],
                               [[5, 6],
                                [7, 8]]])
        # Each element is a vector of its index.
        targets = torch.ones([2, 10, 3]).cumsum(1) - 1
        # Make the second batch double its index so they're different.
        targets[1, :, :] *= 2
        indices = torch.tensor(indices, dtype=torch.long)
        selected = util.batched_index_select(targets, indices)

        assert list(selected.size()) == [2, 2, 2, 3]
        ones = numpy.ones([3])
        numpy.testing.assert_array_equal(selected[0, 0, 0, :].data.numpy(), ones)
        numpy.testing.assert_array_equal(selected[0, 0, 1, :].data.numpy(), ones * 2)
        numpy.testing.assert_array_equal(selected[0, 1, 0, :].data.numpy(), ones * 3)
        numpy.testing.assert_array_equal(selected[0, 1, 1, :].data.numpy(), ones * 4)

        numpy.testing.assert_array_equal(selected[1, 0, 0, :].data.numpy(), ones * 10)
        numpy.testing.assert_array_equal(selected[1, 0, 1, :].data.numpy(), ones * 12)
        numpy.testing.assert_array_equal(selected[1, 1, 0, :].data.numpy(), ones * 14)
        numpy.testing.assert_array_equal(selected[1, 1, 1, :].data.numpy(), ones * 16) 
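For reference, here is a rough sketch of what batched_index_select computes: for each batch entry it selects rows of the target tensor at the given index positions. This is an illustrative reimplementation based on a flattened lookup, not the library code, and it assumes a 3-D target of shape (batch_size, sequence_length, embedding_dim).

import torch

def naive_batched_index_select(targets: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
    # targets: (batch_size, sequence_length, embedding_dim)
    # indices: (batch_size, ...) of positions into the sequence dimension.
    batch_size, sequence_length, embedding_dim = targets.size()
    flat_targets = targets.reshape(batch_size * sequence_length, embedding_dim)
    # Offset each batch's indices so they point into the flattened tensor.
    offsets = torch.arange(batch_size).view(batch_size, *([1] * (indices.dim() - 1)))
    flat_indices = (indices + offsets * sequence_length).reshape(-1)
    selected = flat_targets[flat_indices]
    return selected.view(*indices.size(), embedding_dim)

# Same targets and indices as the test above.
targets = torch.ones([2, 10, 3]).cumsum(1) - 1
targets[1, :, :] *= 2
indices = torch.tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
selected = naive_batched_index_select(targets, indices)
assert list(selected.size()) == [2, 2, 2, 3]
assert torch.equal(selected[1, 1, 1], torch.full([3], 16.0))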
Example #4
Source File: endpoint_span_extractor_test.py    From allennlp with Apache License 2.0
def test_masked_indices_are_handled_correctly(self):
        sequence_tensor = torch.randn([2, 5, 7])
        # concatenate start and end points together to form our representation.
        extractor = EndpointSpanExtractor(7, "x,y")

        indices = torch.LongTensor([[[1, 3], [2, 4]], [[0, 2], [3, 4]]])
        span_representations = extractor(sequence_tensor, indices)

        # Make a mask with the second batch element completely masked.
        indices_mask = torch.tensor([[True, True], [False, False]])

        span_representations = extractor(sequence_tensor, indices, span_indices_mask=indices_mask)
        start_embeddings, end_embeddings = span_representations.split(7, -1)
        start_indices, end_indices = indices.split(1, -1)

        correct_start_embeddings = batched_index_select(
            sequence_tensor, start_indices.squeeze()
        ).data
        # Completely masked second batch element, so it should all be zero.
        correct_start_embeddings[1, :, :].fill_(0)
        correct_end_embeddings = batched_index_select(sequence_tensor, end_indices.squeeze()).data
        correct_end_embeddings[1, :, :].fill_(0)
        numpy.testing.assert_array_equal(
            start_embeddings.data.numpy(), correct_start_embeddings.numpy()
        )
        numpy.testing.assert_array_equal(
            end_embeddings.data.numpy(), correct_end_embeddings.numpy()
        ) 
Example #5
Source File: endpoint_span_extractor_test.py    From allennlp with Apache License 2.0
def test_masked_indices_are_handled_correctly_with_exclusive_indices(self):
        sequence_tensor = torch.randn([2, 5, 8])
        # concatenate start and end points together to form our representation
        # for both the forward and backward directions.
        extractor = EndpointSpanExtractor(8, "x,y", use_exclusive_start_indices=True)
        indices = torch.LongTensor([[[1, 3], [2, 4]], [[0, 2], [0, 1]]])
        sequence_mask = torch.tensor(
            [[True, True, True, True, True], [True, True, True, False, False]]
        )

        span_representations = extractor(sequence_tensor, indices, sequence_mask=sequence_mask)

        # We just concatenated the start and end embeddings together, so
        # we can check they match the original indices if we split them apart.
        start_embeddings, end_embeddings = span_representations.split(8, -1)

        correct_start_indices = torch.LongTensor([[0, 1], [-1, -1]])
        # These exclusive start indices should be -1, so they'll be replaced with a sentinel.
        # Here, we set them to a valid value so we can index_select with them and substitute
        # the sentinel embeddings afterwards.
        correct_start_indices[1, 0] = 1
        correct_start_indices[1, 1] = 1

        correct_end_indices = torch.LongTensor([[3, 4], [2, 1]])

        correct_start_embeddings = batched_index_select(
            sequence_tensor.contiguous(), correct_start_indices
        )
        # This element had a sequence_tensor index of 0, so its exclusive index is the start sentinel.
        correct_start_embeddings[1, 0] = extractor._start_sentinel.data
        correct_start_embeddings[1, 1] = extractor._start_sentinel.data
        numpy.testing.assert_array_equal(
            start_embeddings.data.numpy(), correct_start_embeddings.data.numpy()
        )

        correct_end_embeddings = batched_index_select(
            sequence_tensor.contiguous(), correct_end_indices
        )
        numpy.testing.assert_array_equal(
            end_embeddings.data.numpy(), correct_end_embeddings.data.numpy()
        ) 
Example #6
Source File: util_test.py    From allennlp with Apache License 2.0
def test_batched_index_select(self):
        indices = numpy.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
        # Each element is a vector of its index.
        targets = torch.ones([2, 10, 3]).cumsum(1) - 1
        # Make the second batch double its index so they're different.
        targets[1, :, :] *= 2
        indices = torch.tensor(indices, dtype=torch.long)
        selected = util.batched_index_select(targets, indices)

        assert list(selected.size()) == [2, 2, 2, 3]
        ones = numpy.ones([3])
        numpy.testing.assert_array_equal(selected[0, 0, 0, :].data.numpy(), ones)
        numpy.testing.assert_array_equal(selected[0, 0, 1, :].data.numpy(), ones * 2)
        numpy.testing.assert_array_equal(selected[0, 1, 0, :].data.numpy(), ones * 3)
        numpy.testing.assert_array_equal(selected[0, 1, 1, :].data.numpy(), ones * 4)

        numpy.testing.assert_array_equal(selected[1, 0, 0, :].data.numpy(), ones * 10)
        numpy.testing.assert_array_equal(selected[1, 0, 1, :].data.numpy(), ones * 12)
        numpy.testing.assert_array_equal(selected[1, 1, 0, :].data.numpy(), ones * 14)
        numpy.testing.assert_array_equal(selected[1, 1, 1, :].data.numpy(), ones * 16)

        indices = numpy.array([[[1, 11], [3, 4]], [[5, 6], [7, 8]]])
        indices = torch.tensor(indices, dtype=torch.long)
        with pytest.raises(ConfigurationError):
            util.batched_index_select(targets, indices)

        indices = numpy.array([[[1, -1], [3, 4]], [[5, 6], [7, 8]]])
        indices = torch.tensor(indices, dtype=torch.long)
        with pytest.raises(ConfigurationError):
            util.batched_index_select(targets, indices) 
Example #7
Source File: util_test.py    From allennlp with Apache License 2.0
def test_masked_topk_selects_top_scored_items_and_respects_masking_different_num_items(self):
        items = torch.randn([3, 4, 5]).clamp(min=0.0, max=1.0)
        items[0, 0, :] = 1.5
        items[0, 1, :] = 2
        items[0, 3, :] = 1
        items[1, 1:3, :] = 1
        items[2, 0, :] = 1
        items[2, 1, :] = 2
        items[2, 2, :] = 1.5

        scores = items.sum(-1)

        mask = torch.ones([3, 4]).bool()
        mask[1, 3] = 0

        k = torch.tensor([3, 2, 1], dtype=torch.long)

        pruned_scores, pruned_mask, pruned_indices = util.masked_topk(scores, mask, k)

        # First element keeps its top three indices (0, 1, 3). Second element has index 3 masked
        # and only two items requested, so it keeps indices 1 and 2, with the last index repeated
        # as padding. Third element requests a single item, index 1, with the rest being padding.
        numpy.testing.assert_array_equal(
            pruned_indices.data.numpy(), numpy.array([[0, 1, 3], [1, 2, 2], [1, 2, 2]])
        )
        numpy.testing.assert_array_equal(
            pruned_mask.data.numpy(), numpy.array([[1, 1, 1], [1, 1, 0], [1, 0, 0]])
        )

        # scores should be the result of index_selecting the pruned_indices.
        correct_scores = util.batched_index_select(scores.unsqueeze(-1), pruned_indices).squeeze(-1)
        self.assert_array_equal_with_mask(correct_scores, pruned_scores, pruned_mask) 
Example #8
Source File: util_test.py    From allennlp with Apache License 2.0
def test_masked_topk_works_for_row_with_no_items_requested(self):
        # Case where `num_items_to_keep` is a tensor rather than an int. Make sure it does the right
        # thing when no items are requested for one of the rows.

        items = torch.randn([3, 4, 5]).clamp(min=0.0, max=1.0)
        items[0, :3, :] = 1
        items[1, 2:, :] = 1
        items[2, 2:, :] = 1

        scores = items.sum(-1)

        mask = torch.ones([3, 4]).bool()
        mask[1, 0] = 0
        mask[1, 3] = 0

        k = torch.tensor([3, 2, 0], dtype=torch.long)

        pruned_scores, pruned_mask, pruned_indices = util.masked_topk(scores, mask, k)

        # First element just picks the top three entries. Second would pick entries 2 and 3, but 0 and 3
        # are masked, so it takes 1 and 2 (repeating the second index as padding). The third element
        # has no items requested, so its row is entirely padding and its mask is all zeros.
        numpy.testing.assert_array_equal(
            pruned_indices.data.numpy(), numpy.array([[0, 1, 2], [1, 2, 2], [3, 3, 3]])
        )
        numpy.testing.assert_array_equal(
            pruned_mask.data.numpy(), numpy.array([[1, 1, 1], [1, 1, 0], [0, 0, 0]])
        )

        # scores should be the result of index_selecting the pruned_indices.
        correct_scores = util.batched_index_select(scores.unsqueeze(-1), pruned_indices).squeeze(-1)
        self.assert_array_equal_with_mask(correct_scores, pruned_scores, pruned_mask) 
Example #9
Source File: span_pruner_test.py    From magnitude with MIT License
def test_span_pruner_selects_top_scored_spans_and_respects_masking(self):
        # Really simple scorer - sum up the embedding_dim.
        scorer = lambda tensor: tensor.sum(-1).unsqueeze(-1)
        pruner = SpanPruner(scorer=scorer)

        spans = torch.randn([3, 4, 5]).clamp(min=0.0, max=1.0)
        spans[0, :2, :] = 1
        spans[1, 2:, :] = 1
        spans[2, 2:, :] = 1

        mask = torch.ones([3, 4])
        mask[1, 0] = 0
        mask[1, 3] = 0
        pruned_embeddings, pruned_mask, pruned_indices, pruned_scores = pruner(spans, mask, 2)

        # Second element in the batch would have indices 2, 3, but
        # 3 and 0 are masked, so instead it has 1, 2.
        numpy.testing.assert_array_equal(pruned_indices.data.numpy(), numpy.array([[0, 1],
                                                                                   [1, 2],
                                                                                   [2, 3]]))
        numpy.testing.assert_array_equal(pruned_mask.data.numpy(), numpy.ones([3, 2]))

        # embeddings should be the result of index_selecting the pruned_indices.
        correct_embeddings = batched_index_select(spans, pruned_indices)
        numpy.testing.assert_array_equal(correct_embeddings.data.numpy(),
                                         pruned_embeddings.data.numpy())
        # scores should be the sum of the correct embedding elements.
        numpy.testing.assert_array_equal(correct_embeddings.sum(-1).unsqueeze(-1).data.numpy(),
                                         pruned_scores.data.numpy()) 
Example #10
Source File: span_pruner_test.py    From magnitude with MIT License
def test_span_scorer_works_for_completely_masked_rows(self):
        # Really simple scorer - sum up the embedding_dim.
        scorer = lambda tensor: tensor.sum(-1).unsqueeze(-1)
        pruner = SpanPruner(scorer=scorer) # type: ignore

        spans = torch.randn([3, 4, 5]).clamp(min=0.0, max=1.0)
        spans[0, :2, :] = 1
        spans[1, 2:, :] = 1
        spans[2, 2:, :] = 1

        mask = torch.ones([3, 4])
        mask[1, 0] = 0
        mask[1, 3] = 0
        mask[2, :] = 0 # fully masked last batch element.

        pruned_embeddings, pruned_mask, pruned_indices, pruned_scores = pruner(spans, mask, 2)

        # We can't check the last row here, because it's completely masked.
        # Instead we'll check that the scores for these elements are -inf.
        numpy.testing.assert_array_equal(pruned_indices[:2].data.numpy(), numpy.array([[0, 1],
                                                                                       [1, 2]]))
        numpy.testing.assert_array_equal(pruned_mask.data.numpy(), numpy.array([[1, 1],
                                                                                [1, 1],
                                                                                [0, 0]]))
        # embeddings should be the result of index_selecting the pruned_indices.
        correct_embeddings = batched_index_select(spans, pruned_indices)
        numpy.testing.assert_array_equal(correct_embeddings.data.numpy(),
                                         pruned_embeddings.data.numpy())
        # scores should be the sum of the correct embedding elements, with
        # masked elements equal to -inf.
        correct_scores = correct_embeddings.sum(-1).unsqueeze(-1).data.numpy()
        correct_scores[2, :] = float(u"-inf")
        numpy.testing.assert_array_equal(correct_scores,
                                         pruned_scores.data.numpy()) 
Example #11
Source File: endpoint_span_extractor_test.py    From magnitude with MIT License
def test_correct_sequence_elements_are_embedded(self):
        sequence_tensor = torch.randn([2, 5, 7])
        # Concatenate start and end points together to form our representation.
        extractor = EndpointSpanExtractor(7, u"x,y")

        indices = torch.LongTensor([[[1, 3],
                                     [2, 4]],
                                    [[0, 2],
                                     [3, 4]]])
        span_representations = extractor(sequence_tensor, indices)

        assert list(span_representations.size()) == [2, 2, 14]
        assert extractor.get_output_dim() == 14
        assert extractor.get_input_dim() == 7

        start_indices, end_indices = indices.split(1, -1)
        # We just concatenated the start and end embeddings together, so
        # we can check they match the original indices if we split them apart.
        start_embeddings, end_embeddings = span_representations.split(7, -1)

        correct_start_embeddings = batched_index_select(sequence_tensor, start_indices.squeeze())
        correct_end_embeddings = batched_index_select(sequence_tensor, end_indices.squeeze())
        numpy.testing.assert_array_equal(start_embeddings.data.numpy(),
                                         correct_start_embeddings.data.numpy())
        numpy.testing.assert_array_equal(end_embeddings.data.numpy(),
                                         correct_end_embeddings.data.numpy()) 
Example #12
Source File: endpoint_span_extractor_test.py    From magnitude with MIT License
def test_masked_indices_are_handled_correctly(self):
        sequence_tensor = torch.randn([2, 5, 7])
        # concatenate start and end points together to form our representation.
        extractor = EndpointSpanExtractor(7, u"x,y")

        indices = torch.LongTensor([[[1, 3],
                                     [2, 4]],
                                    [[0, 2],
                                     [3, 4]]])
        span_representations = extractor(sequence_tensor, indices)

        # Make a mask with the second batch element completely masked.
        indices_mask = torch.LongTensor([[1, 1], [0, 0]])

        span_representations = extractor(sequence_tensor, indices, span_indices_mask=indices_mask)
        start_embeddings, end_embeddings = span_representations.split(7, -1)
        start_indices, end_indices = indices.split(1, -1)

        correct_start_embeddings = batched_index_select(sequence_tensor, start_indices.squeeze()).data
        # Completely masked second batch element, so it should all be zero.
        correct_start_embeddings[1, :, :].fill_(0)
        correct_end_embeddings = batched_index_select(sequence_tensor, end_indices.squeeze()).data
        correct_end_embeddings[1, :, :].fill_(0)
        numpy.testing.assert_array_equal(start_embeddings.data.numpy(),
                                         correct_start_embeddings.numpy())
        numpy.testing.assert_array_equal(end_embeddings.data.numpy(),
                                         correct_end_embeddings.numpy()) 
Example #13
Source File: endpoint_span_extractor.py    From magnitude with MIT License
def forward(self,
                sequence_tensor: torch.FloatTensor,
                span_indices: torch.LongTensor,
                sequence_mask: torch.LongTensor = None,
                span_indices_mask: torch.LongTensor = None) -> torch.FloatTensor:
        # shape (batch_size, num_spans)
        span_starts, span_ends = [index.squeeze(-1) for index in span_indices.split(1, dim=-1)]

        if span_indices_mask is not None:
            # It's not strictly necessary to multiply the span indices by the mask here,
            # but it's possible that the span representation was padded with something other
            # than 0 (such as -1, which would be an invalid index), so we do so anyway to
            # be safe.
            span_starts = span_starts * span_indices_mask
            span_ends = span_ends * span_indices_mask

        if not self._use_exclusive_start_indices:
            start_embeddings = util.batched_index_select(sequence_tensor, span_starts)
            end_embeddings = util.batched_index_select(sequence_tensor, span_ends)

        else:
            # We want `exclusive` span starts, so we remove 1 from the forward span starts
            # as the AllenNLP ``SpanField`` is inclusive.
            # shape (batch_size, num_spans)
            exclusive_span_starts = span_starts - 1
            # shape (batch_size, num_spans, 1)
            start_sentinel_mask = (exclusive_span_starts == -1).long().unsqueeze(-1)
            exclusive_span_starts = exclusive_span_starts * (1 - start_sentinel_mask.squeeze(-1))

            # We'll check the indices here at runtime, because it's difficult to debug
            # if this goes wrong and it's tricky to get right.
            if (exclusive_span_starts < 0).any():
                raise ValueError("Adjusted span indices must lie inside the the sequence tensor, "
                                 "but found: exclusive_span_starts: {exclusive_span_starts}.")

            start_embeddings = util.batched_index_select(sequence_tensor, exclusive_span_starts)
            end_embeddings = util.batched_index_select(sequence_tensor, span_ends)

            # We're using sentinels, so we need to replace all the elements which were
            # outside the dimensions of the sequence_tensor with the start sentinel.
            float_start_sentinel_mask = start_sentinel_mask.float()
            start_embeddings = start_embeddings * (1 - float_start_sentinel_mask)\
                                        + float_start_sentinel_mask * self._start_sentinel

        combined_tensors = util.combine_tensors(self._combination, [start_embeddings, end_embeddings])
        if self._span_width_embedding is not None:
            # Embed the span widths and concatenate to the rest of the representations.
            if self._bucket_widths:
                span_widths = util.bucket_values(span_ends - span_starts,
                                                 num_total_buckets=self._num_width_embeddings)
            else:
                span_widths = span_ends - span_starts

            span_width_embeddings = self._span_width_embedding(span_widths)
            return torch.cat([combined_tensors, span_width_embeddings], -1)

        if span_indices_mask is not None:
            return combined_tensors * span_indices_mask.unsqueeze(-1).float()
        return combined_tensors 
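A minimal usage sketch of this forward pass, using the AllenNLP version of EndpointSpanExtractor (same interface as the magnitude port above); the span starting at index 0 exercises the exclusive-start sentinel branch:

import torch
from allennlp.modules.span_extractors import EndpointSpanExtractor

sequence_tensor = torch.randn([1, 5, 4])
extractor = EndpointSpanExtractor(4, "x,y", use_exclusive_start_indices=True)

# One span starting at index 0 (its exclusive start becomes the sentinel) and one starting at 2.
span_indices = torch.LongTensor([[[0, 2], [2, 4]]])
span_representations = extractor(sequence_tensor, span_indices)

# "x,y" concatenates start and end embeddings, so the output dim is 2 * 4.
assert list(span_representations.size()) == [1, 2, 8]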
Example #14
Source File: sum_span_extractor.py    From AntNRE with Apache License 2.0
def forward(self,
                sequence_tensor: torch.FloatTensor,
                span_indices: torch.LongTensor,
                span_indices_mask: torch.LongTensor = None) -> torch.FloatTensor:
        # both of shape (batch_size, num_spans, 1)
        span_starts, span_ends = span_indices.split(1, dim=-1)

        # shape (batch_size, num_spans, 1)
        # These span widths are off by 1, because the span ends are `inclusive`.
        span_widths = span_ends - span_starts

        # We need to know the maximum span width so we can
        # generate indices to extract the spans from the sequence tensor.
        # These indices will then get masked below, such that if the length
        # of a given span is smaller than the max, the rest of the values
        # are masked.
        max_batch_span_width = span_widths.max().item() + 1

        # Shape: (1, 1, max_batch_span_width)
        max_span_range_indices = util.get_range_vector(max_batch_span_width,
                                                       util.get_device_of(sequence_tensor)).view(1, 1, -1)
        # Shape: (batch_size, num_spans, max_batch_span_width)
        # This is a broadcasted comparison - for each span we are considering,
        # we are creating a range vector of size max_span_width, but masking values
        # which are greater than the actual length of the span.
        #
        # We're using <= here (and for the mask below) because the span ends are
        # inclusive, so we want to include indices which are equal to span_widths rather
        # than using it as a non-inclusive upper bound.
        span_mask = (max_span_range_indices <= span_widths).float()
        raw_span_indices = span_ends - max_span_range_indices
        # We also don't want to include span indices which are less than zero,
        # which happens because some spans near the beginning of the sequence
        # have an end index < max_batch_span_width, so we add this to the mask here.
        span_mask = span_mask * (raw_span_indices >= 0).float()
        span_indices = torch.nn.functional.relu(raw_span_indices.float()).long()

        # Shape: (batch_size * num_spans * max_batch_span_width)
        flat_span_indices = util.flatten_and_batch_shift_indices(span_indices, sequence_tensor.size(1))

        # Shape: (batch_size, num_spans, max_batch_span_width, embedding_dim)
        span_embeddings = util.batched_index_select(sequence_tensor, span_indices, flat_span_indices)
        text_embeddings = span_embeddings * span_mask.unsqueeze(-1)
        sum_text_embeddings = text_embeddings.sum(dim=2)
        return sum_text_embeddings 
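The three-argument call above passes precomputed flattened indices into batched_index_select. A small sketch of why that is just an optimization: under the same AllenNLP util module, the result should match the plain two-argument call.

import torch
from allennlp.nn import util

sequence_tensor = torch.randn([2, 6, 4])
span_indices = torch.LongTensor([[[0, 1, 2], [3, 4, 4]],
                                 [[1, 2, 2], [2, 3, 5]]])

# Precompute the flattened, batch-shifted indices once and reuse them.
flat_indices = util.flatten_and_batch_shift_indices(span_indices, sequence_tensor.size(1))
with_flat = util.batched_index_select(sequence_tensor, span_indices, flat_indices)
without_flat = util.batched_index_select(sequence_tensor, span_indices)

assert torch.equal(with_flat, without_flat)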
Example #15
Source File: mean_span_extractor.py    From AntNRE with Apache License 2.0
def forward(self,
                sequence_tensor: torch.FloatTensor,
                span_indices: torch.LongTensor,
                span_indices_mask: torch.LongTensor = None) -> torch.FloatTensor:
        # both of shape (batch_size, num_spans, 1)
        span_starts, span_ends = span_indices.split(1, dim=-1)

        # shape (batch_size, num_spans, 1)
        # These span widths are off by 1, because the span ends are `inclusive`.
        span_widths = span_ends - span_starts

        # We need to know the maximum span width so we can
        # generate indices to extract the spans from the sequence tensor.
        # These indices will then get masked below, such that if the length
        # of a given span is smaller than the max, the rest of the values
        # are masked.
        max_batch_span_width = span_widths.max().item() + 1

        # Shape: (1, 1, max_batch_span_width)
        max_span_range_indices = util.get_range_vector(max_batch_span_width,
                                                       util.get_device_of(sequence_tensor)).view(1, 1, -1)
        # Shape: (batch_size, num_spans, max_batch_span_width)
        # This is a broadcasted comparison - for each span we are considering,
        # we are creating a range vector of size max_span_width, but masking values
        # which are greater than the actual length of the span.
        #
        # We're using <= here (and for the mask below) because the span ends are
        # inclusive, so we want to include indices which are equal to span_widths rather
        # than using it as a non-inclusive upper bound.
        span_mask = (max_span_range_indices <= span_widths).float()
        raw_span_indices = span_ends - max_span_range_indices
        # We also don't want to include span indices which are less than zero,
        # which happens because some spans near the beginning of the sequence
        # have an end index < max_batch_span_width, so we add this to the mask here.
        span_mask = span_mask * (raw_span_indices >= 0).float()
        span_indices = torch.nn.functional.relu(raw_span_indices.float()).long()

        # Shape: (batch_size * num_spans * max_batch_span_width)
        flat_span_indices = util.flatten_and_batch_shift_indices(span_indices, sequence_tensor.size(1))

        # Shape: (batch_size, num_spans, max_batch_span_width, embedding_dim)
        span_embeddings = util.batched_index_select(sequence_tensor, span_indices, flat_span_indices)
        text_embeddings = span_embeddings * span_mask.unsqueeze(-1)
        sum_text_embeddings = text_embeddings.sum(dim=2)

        span_num = span_mask.unsqueeze(-1).sum(dim=2)
        mean_text_embeddings = sum_text_embeddings / span_num 
        return mean_text_embeddings


#  sequence_tensor = torch.randn(2, 5, 5)
#  span_indices = torch.LongTensor([[[0, 1]], [[1, 3]]])
#  extractor = MeanSpanExtractor(5)
#  print(extractor(sequence_tensor, span_indices))
#  print("====")
#  print((sequence_tensor[0][0] + sequence_tensor[0][1]) / 2)

#  print((sequence_tensor[1][1] + sequence_tensor[1][2] + sequence_tensor[1][3])/3 )
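The commented-out check at the end can be turned into runnable assertions along these lines, assuming the MeanSpanExtractor class defined in this file is in scope:

import torch

sequence_tensor = torch.randn(2, 5, 5)
span_indices = torch.LongTensor([[[0, 1]], [[1, 3]]])
extractor = MeanSpanExtractor(5)
means = extractor(sequence_tensor, span_indices)

# Span [0, 1] of the first batch element covers two vectors; span [1, 3] of the second covers three.
assert torch.allclose(means[0, 0], (sequence_tensor[0][0] + sequence_tensor[0][1]) / 2)
assert torch.allclose(means[1, 0], (sequence_tensor[1][1] + sequence_tensor[1][2] + sequence_tensor[1][3]) / 3)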