/* * Copyright 2019 Amazon.com, Inc. or its affiliates. * Licensed under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package software.amazon.kinesis.lifecycle; import static org.hamcrest.CoreMatchers.allOf; import static org.hamcrest.CoreMatchers.equalTo; import static org.hamcrest.CoreMatchers.instanceOf; import static org.hamcrest.CoreMatchers.not; import static org.hamcrest.CoreMatchers.notNullValue; import static org.hamcrest.CoreMatchers.nullValue; import static org.hamcrest.beans.HasPropertyWithValue.hasProperty; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertThat; import static org.mockito.Matchers.any; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; import java.io.ByteArrayOutputStream; import java.math.BigInteger; import java.nio.ByteBuffer; import java.security.MessageDigest; import java.time.Instant; import java.time.temporal.ChronoUnit; import java.util.ArrayList; import java.util.Collections; import java.util.Date; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Random; import java.util.UUID; import java.util.concurrent.ThreadLocalRandom; import java.util.concurrent.TimeUnit; import org.hamcrest.Description; import org.hamcrest.Matcher; import org.hamcrest.TypeSafeDiagnosingMatcher; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; import org.mockito.ArgumentCaptor; import org.mockito.Mock; import org.mockito.runners.MockitoJUnitRunner; import com.google.protobuf.ByteString; import lombok.Data; import lombok.Getter; import software.amazon.awssdk.services.kinesis.model.HashKeyRange; import software.amazon.awssdk.services.kinesis.model.Shard; import software.amazon.kinesis.checkpoint.ShardRecordProcessorCheckpointer; import software.amazon.kinesis.leases.ShardDetector; import software.amazon.kinesis.leases.ShardInfo; import software.amazon.kinesis.lifecycle.events.ProcessRecordsInput; import software.amazon.kinesis.metrics.NullMetricsFactory; import software.amazon.kinesis.processor.Checkpointer; import software.amazon.kinesis.processor.ShardRecordProcessor; import software.amazon.kinesis.retrieval.AggregatorUtil; import software.amazon.kinesis.retrieval.KinesisClientRecord; import software.amazon.kinesis.retrieval.ThrottlingReporter; import software.amazon.kinesis.retrieval.kpl.ExtendedSequenceNumber; import software.amazon.kinesis.retrieval.kpl.Messages; import software.amazon.kinesis.retrieval.kpl.Messages.AggregatedRecord; @RunWith(MockitoJUnitRunner.class) public class ProcessTaskTest { private static final long IDLE_TIME_IN_MILLISECONDS = 100L; private boolean shouldCallProcessRecordsEvenForEmptyRecordList = true; private boolean skipShardSyncAtWorkerInitializationIfLeasesExist = true; private ShardInfo shardInfo; @Mock private ProcessRecordsInput processRecordsInput; @Mock private ShardDetector shardDetector; private static final byte[] TEST_DATA = new byte[] { 1, 2, 3, 4 }; private final String shardId = "shard-test"; private final long taskBackoffTimeMillis = 1L; @Mock private ShardRecordProcessor shardRecordProcessor; @Mock private ShardRecordProcessorCheckpointer checkpointer; @Mock private ThrottlingReporter throttlingReporter; private ProcessTask processTask; @Before public void setUpProcessTask() { when(checkpointer.checkpointer()).thenReturn(mock(Checkpointer.class)); shardInfo = new ShardInfo(shardId, null, null, null); } private ProcessTask makeProcessTask(ProcessRecordsInput processRecordsInput) { return makeProcessTask(processRecordsInput, new AggregatorUtil(), skipShardSyncAtWorkerInitializationIfLeasesExist); } private ProcessTask makeProcessTask(ProcessRecordsInput processRecordsInput, AggregatorUtil aggregatorUtil, boolean skipShardSync) { return new ProcessTask(shardInfo, shardRecordProcessor, checkpointer, taskBackoffTimeMillis, skipShardSync, shardDetector, throttlingReporter, processRecordsInput, shouldCallProcessRecordsEvenForEmptyRecordList, IDLE_TIME_IN_MILLISECONDS, aggregatorUtil, new NullMetricsFactory()); } @Test public void testProcessTaskWithShardEndReached() { processTask = makeProcessTask(processRecordsInput); when(processRecordsInput.isAtShardEnd()).thenReturn(true); TaskResult result = processTask.call(); assertThat(result, shardEndTaskResult(true)); } private KinesisClientRecord makeKinesisClientRecord(String partitionKey, String sequenceNumber, Instant arrival) { return KinesisClientRecord.builder().partitionKey(partitionKey).sequenceNumber(sequenceNumber) .approximateArrivalTimestamp(arrival).data(ByteBuffer.wrap(TEST_DATA)).build(); } @Test public void testNonAggregatedKinesisRecord() { final String sqn = new BigInteger(128, new Random()).toString(); final String pk = UUID.randomUUID().toString(); final Date ts = new Date(System.currentTimeMillis() - TimeUnit.MILLISECONDS.convert(4, TimeUnit.HOURS)); final KinesisClientRecord r = makeKinesisClientRecord(pk, sqn, ts.toInstant()); ShardRecordProcessorOutcome outcome = testWithRecord(r); assertEquals(1, outcome.getProcessRecordsCall().records().size()); KinesisClientRecord pr = outcome.getProcessRecordsCall().records().get(0); assertEquals(pk, pr.partitionKey()); assertEquals(ts.toInstant(), pr.approximateArrivalTimestamp()); byte[] b = pr.data().array(); assertThat(b, equalTo(TEST_DATA)); assertEquals(sqn, outcome.getCheckpointCall().sequenceNumber()); assertEquals(0, outcome.getCheckpointCall().subSequenceNumber()); } @Data static class ShardRecordProcessorOutcome { final ProcessRecordsInput processRecordsCall; final ExtendedSequenceNumber checkpointCall; } @Test public void testDeaggregatesRecord() { final String sqn = new BigInteger(128, new Random()).toString(); final String pk = UUID.randomUUID().toString(); final Instant ts = Instant.now().minus(4, ChronoUnit.HOURS); KinesisClientRecord record = KinesisClientRecord.builder().partitionKey("-").data(generateAggregatedRecord(pk)) .sequenceNumber(sqn).approximateArrivalTimestamp(ts).build(); processTask = makeProcessTask(processRecordsInput); ShardRecordProcessorOutcome outcome = testWithRecord(record); List<KinesisClientRecord> actualRecords = outcome.getProcessRecordsCall().records(); assertEquals(3, actualRecords.size()); for (KinesisClientRecord pr : actualRecords) { assertThat(pr, instanceOf(KinesisClientRecord.class)); assertEquals(pk, pr.partitionKey()); assertEquals(ts, pr.approximateArrivalTimestamp()); byte[] actualData = new byte[pr.data().limit()]; pr.data().get(actualData); assertThat(actualData, equalTo(TEST_DATA)); } assertEquals(sqn, outcome.getCheckpointCall().sequenceNumber()); assertEquals(actualRecords.size() - 1, outcome.getCheckpointCall().subSequenceNumber()); } @Test public void testDeaggregatesRecordWithNoArrivalTimestamp() { final String sqn = new BigInteger(128, new Random()).toString(); final String pk = UUID.randomUUID().toString(); KinesisClientRecord record = KinesisClientRecord.builder().partitionKey("-").data(generateAggregatedRecord(pk)) .sequenceNumber(sqn).build(); processTask = makeProcessTask(processRecordsInput); ShardRecordProcessorOutcome outcome = testWithRecord(record); List<KinesisClientRecord> actualRecords = outcome.getProcessRecordsCall().records(); assertEquals(3, actualRecords.size()); for (KinesisClientRecord actualRecord : actualRecords) { assertThat(actualRecord.partitionKey(), equalTo(pk)); assertThat(actualRecord.approximateArrivalTimestamp(), nullValue()); } } @Test public void testLargestPermittedCheckpointValue() { // Some sequence number value from previous processRecords call to mock. final BigInteger previousCheckpointSqn = new BigInteger(128, new Random()); // Values for this processRecords call. final int numberOfRecords = 104; // Start these batch of records's sequence number that is greater than previous checkpoint value. final BigInteger startingSqn = previousCheckpointSqn.add(BigInteger.valueOf(10)); final List<KinesisClientRecord> records = generateConsecutiveRecords(numberOfRecords, "-", ByteBuffer.wrap(TEST_DATA), new Date(), startingSqn); processTask = makeProcessTask(processRecordsInput); ShardRecordProcessorOutcome outcome = testWithRecords(records, new ExtendedSequenceNumber(previousCheckpointSqn.toString()), new ExtendedSequenceNumber(previousCheckpointSqn.toString())); final ExtendedSequenceNumber expectedLargestPermittedEsqn = new ExtendedSequenceNumber( startingSqn.add(BigInteger.valueOf(numberOfRecords - 1)).toString()); assertEquals(expectedLargestPermittedEsqn, outcome.getCheckpointCall()); } @Test public void testLargestPermittedCheckpointValueWithEmptyRecords() { // Some sequence number value from previous processRecords call. final BigInteger baseSqn = new BigInteger(128, new Random()); final ExtendedSequenceNumber lastCheckpointEspn = new ExtendedSequenceNumber(baseSqn.toString()); final ExtendedSequenceNumber largestPermittedEsqn = new ExtendedSequenceNumber( baseSqn.add(BigInteger.valueOf(100)).toString()); processTask = makeProcessTask(processRecordsInput); ShardRecordProcessorOutcome outcome = testWithRecords(Collections.emptyList(), lastCheckpointEspn, largestPermittedEsqn); // Make sure that even with empty records, largest permitted sequence number does not change. assertEquals(largestPermittedEsqn, outcome.getCheckpointCall()); } @Test public void testFilterBasedOnLastCheckpointValue() { // Explanation of setup: // * Assume in previous processRecord call, user got 3 sub-records that all belonged to one // Kinesis record. So sequence number was X, and sub-sequence numbers were 0, 1, 2. // * 2nd sub-record was checkpointed (extended sequnce number X.1). // * Worker crashed and restarted. So now DDB has checkpoint value of X.1. // Test: // * Now in the subsequent processRecords call, KCL should filter out X.0 and X.1. BigInteger previousCheckpointSqn = new BigInteger(128, new Random()); long previousCheckpointSsqn = 1; // Values for this processRecords call. String startingSqn = previousCheckpointSqn.toString(); String pk = UUID.randomUUID().toString(); KinesisClientRecord record = KinesisClientRecord.builder().partitionKey("-").data(generateAggregatedRecord(pk)) .sequenceNumber(startingSqn).build(); processTask = makeProcessTask(processRecordsInput); ShardRecordProcessorOutcome outcome = testWithRecords(Collections.singletonList(record), new ExtendedSequenceNumber(previousCheckpointSqn.toString(), previousCheckpointSsqn), new ExtendedSequenceNumber(previousCheckpointSqn.toString(), previousCheckpointSsqn)); List<KinesisClientRecord> actualRecords = outcome.getProcessRecordsCall().records(); // First two records should be dropped - and only 1 remaining records should be there. assertThat(actualRecords.size(), equalTo(1)); // Verify user record's extended sequence number and other fields. KinesisClientRecord actualRecord = actualRecords.get(0); assertThat(actualRecord.partitionKey(), equalTo(pk)); assertThat(actualRecord.sequenceNumber(), equalTo(startingSqn)); assertThat(actualRecord.subSequenceNumber(), equalTo(previousCheckpointSsqn + 1)); assertThat(actualRecord.approximateArrivalTimestamp(), nullValue()); // Expected largest permitted sequence number will be last sub-record sequence number. final ExtendedSequenceNumber expectedLargestPermittedEsqn = new ExtendedSequenceNumber( previousCheckpointSqn.toString(), 2L); assertEquals(expectedLargestPermittedEsqn, outcome.getCheckpointCall()); } @Test public void testDiscardReshardedKplData() throws Exception { BigInteger sequenceNumber = new BigInteger(120, ThreadLocalRandom.current()); String lowHashKey = BigInteger.ONE.shiftLeft(60).toString(); String highHashKey = BigInteger.ONE.shiftLeft(68).toString(); ControlledHashAggregatorUtil aggregatorUtil = new ControlledHashAggregatorUtil(lowHashKey, highHashKey); AggregatedRecord.Builder aggregatedRecord = AggregatedRecord.newBuilder(); Instant approximateArrivalTime = Instant.now(); int recordIndex = 0; sequenceNumber = sequenceNumber.add(BigInteger.ONE); for (int i = 0; i < 5; ++i) { KinesisClientRecord expectedRecord = createAndRegisterAggregatedRecord(sequenceNumber, aggregatedRecord, recordIndex, approximateArrivalTime); aggregatorUtil.addInRange(expectedRecord); recordIndex++; } sequenceNumber = sequenceNumber.add(BigInteger.ONE); for (int i = 0; i < 5; ++i) { KinesisClientRecord expectedRecord = createAndRegisterAggregatedRecord(sequenceNumber, aggregatedRecord, recordIndex, approximateArrivalTime); aggregatorUtil.addBelowRange(expectedRecord); recordIndex++; } sequenceNumber = sequenceNumber.add(BigInteger.ONE); for (int i = 0; i < 5; ++i) { KinesisClientRecord expectedRecord = createAndRegisterAggregatedRecord(sequenceNumber, aggregatedRecord, recordIndex, approximateArrivalTime); aggregatorUtil.addAboveRange(expectedRecord); recordIndex++; } byte[] payload = aggregatedRecord.build().toByteArray(); ByteArrayOutputStream bos = new ByteArrayOutputStream(); bos.write(new byte[] { -13, -119, -102, -62 }); bos.write(payload); bos.write(md5(payload)); ByteBuffer rawRecordData = ByteBuffer.wrap(bos.toByteArray()); KinesisClientRecord rawRecord = KinesisClientRecord.builder().data(rawRecordData) .approximateArrivalTimestamp(approximateArrivalTime).partitionKey("p-01") .sequenceNumber(sequenceNumber.toString()).build(); when(shardDetector.shard(any())).thenReturn(Shard.builder().shardId("Shard-01") .hashKeyRange(HashKeyRange.builder().startingHashKey(lowHashKey).endingHashKey(highHashKey).build()) .build()); when(processRecordsInput.records()).thenReturn(Collections.singletonList(rawRecord)); ProcessTask processTask = makeProcessTask(processRecordsInput, aggregatorUtil, false); ShardRecordProcessorOutcome outcome = testWithRecords(processTask, new ExtendedSequenceNumber(sequenceNumber.subtract(BigInteger.valueOf(100)).toString(), 0L), new ExtendedSequenceNumber(sequenceNumber.toString(), recordIndex + 1L)); assertThat(outcome.processRecordsCall.records().size(), equalTo(0)); } @Test public void testAllInShardKplData() throws Exception { BigInteger sequenceNumber = new BigInteger(120, ThreadLocalRandom.current()); String lowHashKey = BigInteger.ONE.shiftLeft(60).toString(); String highHashKey = BigInteger.ONE.shiftLeft(68).toString(); ControlledHashAggregatorUtil aggregatorUtil = new ControlledHashAggregatorUtil(lowHashKey, highHashKey); List<KinesisClientRecord> expectedRecords = new ArrayList<>(); List<KinesisClientRecord> rawRecords = new ArrayList<>(); for (int i = 0; i < 3; ++i) { AggregatedRecord.Builder aggregatedRecord = AggregatedRecord.newBuilder(); Instant approximateArrivalTime = Instant.now().minus(i + 4, ChronoUnit.SECONDS); sequenceNumber = sequenceNumber.add(BigInteger.ONE); for (int j = 0; j < 2; ++j) { KinesisClientRecord expectedRecord = createAndRegisterAggregatedRecord(sequenceNumber, aggregatedRecord, j, approximateArrivalTime); aggregatorUtil.addInRange(expectedRecord); expectedRecords.add(expectedRecord); } byte[] payload = aggregatedRecord.build().toByteArray(); ByteArrayOutputStream bos = new ByteArrayOutputStream(); bos.write(AggregatorUtil.AGGREGATED_RECORD_MAGIC); bos.write(payload); bos.write(md5(payload)); ByteBuffer rawRecordData = ByteBuffer.wrap(bos.toByteArray()); KinesisClientRecord rawRecord = KinesisClientRecord.builder().data(rawRecordData) .approximateArrivalTimestamp(approximateArrivalTime).partitionKey("pa-" + i) .sequenceNumber(sequenceNumber.toString()).build(); rawRecords.add(rawRecord); } when(shardDetector.shard(any())).thenReturn(Shard.builder().shardId("Shard-01") .hashKeyRange(HashKeyRange.builder().startingHashKey(lowHashKey).endingHashKey(highHashKey).build()) .build()); when(processRecordsInput.records()).thenReturn(rawRecords); ProcessTask processTask = makeProcessTask(processRecordsInput, aggregatorUtil, false); ShardRecordProcessorOutcome outcome = testWithRecords(processTask, new ExtendedSequenceNumber(sequenceNumber.subtract(BigInteger.valueOf(100)).toString(), 0L), new ExtendedSequenceNumber(sequenceNumber.toString(), 0L)); assertThat(outcome.processRecordsCall.records(), equalTo(expectedRecords)); } private KinesisClientRecord createAndRegisterAggregatedRecord(BigInteger sequenceNumber, AggregatedRecord.Builder aggregatedRecord, int i, Instant approximateArrivalTime) { byte[] dataArray = new byte[1024]; ThreadLocalRandom.current().nextBytes(dataArray); ByteBuffer data = ByteBuffer.wrap(dataArray); KinesisClientRecord expectedRecord = KinesisClientRecord.builder().partitionKey("p-" + i) .sequenceNumber(sequenceNumber.toString()).approximateArrivalTimestamp(approximateArrivalTime) .data(data).subSequenceNumber(i).aggregated(true).build(); Messages.Record kplRecord = Messages.Record.newBuilder().setData(ByteString.copyFrom(dataArray)) .setPartitionKeyIndex(i).build(); aggregatedRecord.addPartitionKeyTable(expectedRecord.partitionKey()).addRecords(kplRecord); return expectedRecord; } private enum RecordRangeState { BELOW_RANGE, IN_RANGE, ABOVE_RANGE } @Getter private static class ControlledHashAggregatorUtil extends AggregatorUtil { private final BigInteger lowHashKey; private final BigInteger highHashKey; private final long width; private final Map<String, RecordRangeState> recordRanges = new HashMap<>(); ControlledHashAggregatorUtil(String lowHashKey, String highHashKey) { this.lowHashKey = new BigInteger(lowHashKey); this.highHashKey = new BigInteger(highHashKey); this.width = this.highHashKey.subtract(this.lowHashKey).mod(BigInteger.valueOf(Long.MAX_VALUE)).longValue() - 1; } void add(KinesisClientRecord record, RecordRangeState recordRangeState) { recordRanges.put(record.partitionKey(), recordRangeState); } void addInRange(KinesisClientRecord record) { add(record, RecordRangeState.IN_RANGE); } void addBelowRange(KinesisClientRecord record) { add(record, RecordRangeState.BELOW_RANGE); } void addAboveRange(KinesisClientRecord record) { add(record, RecordRangeState.ABOVE_RANGE); } @Override protected BigInteger effectiveHashKey(String partitionKey, String explicitHashKey) { RecordRangeState rangeState = recordRanges.get(partitionKey); assertThat(rangeState, not(nullValue())); switch (rangeState) { case BELOW_RANGE: return lowHashKey.subtract(BigInteger.valueOf(ThreadLocalRandom.current().nextInt()).abs()); case IN_RANGE: return lowHashKey.add(BigInteger.valueOf(ThreadLocalRandom.current().nextLong(width))); case ABOVE_RANGE: return highHashKey.add(BigInteger.ONE) .add(BigInteger.valueOf(ThreadLocalRandom.current().nextInt()).abs()); default: throw new IllegalStateException("Unknown range state: " + rangeState); } } } private ShardRecordProcessorOutcome testWithRecord(KinesisClientRecord record) { return testWithRecords(Collections.singletonList(record), ExtendedSequenceNumber.TRIM_HORIZON, ExtendedSequenceNumber.TRIM_HORIZON); } private ShardRecordProcessorOutcome testWithRecords(List<KinesisClientRecord> records, ExtendedSequenceNumber lastCheckpointValue, ExtendedSequenceNumber largestPermittedCheckpointValue) { return testWithRecords(records, lastCheckpointValue, largestPermittedCheckpointValue, new AggregatorUtil()); } private ShardRecordProcessorOutcome testWithRecords(List<KinesisClientRecord> records, ExtendedSequenceNumber lastCheckpointValue, ExtendedSequenceNumber largestPermittedCheckpointValue, AggregatorUtil aggregatorUtil) { when(processRecordsInput.records()).thenReturn(records); return testWithRecords( makeProcessTask(processRecordsInput, aggregatorUtil, skipShardSyncAtWorkerInitializationIfLeasesExist), lastCheckpointValue, largestPermittedCheckpointValue); } private ShardRecordProcessorOutcome testWithRecords(ProcessTask processTask, ExtendedSequenceNumber lastCheckpointValue, ExtendedSequenceNumber largestPermittedCheckpointValue) { when(checkpointer.lastCheckpointValue()).thenReturn(lastCheckpointValue); when(checkpointer.largestPermittedCheckpointValue()).thenReturn(largestPermittedCheckpointValue); processTask.call(); verify(throttlingReporter).success(); verify(throttlingReporter, never()).throttled(); ArgumentCaptor<ProcessRecordsInput> recordsCaptor = ArgumentCaptor.forClass(ProcessRecordsInput.class); verify(shardRecordProcessor).processRecords(recordsCaptor.capture()); ArgumentCaptor<ExtendedSequenceNumber> esnCaptor = ArgumentCaptor.forClass(ExtendedSequenceNumber.class); verify(checkpointer).largestPermittedCheckpointValue(esnCaptor.capture()); return new ShardRecordProcessorOutcome(recordsCaptor.getValue(), esnCaptor.getValue()); } /** * See the KPL documentation on GitHub for more details about the binary format. * * @param pk * Partition key to use. All the records will have the same partition key. * @return ByteBuffer containing the serialized form of the aggregated record, along with the necessary header and * footer. */ private static ByteBuffer generateAggregatedRecord(String pk) { ByteBuffer bb = ByteBuffer.allocate(1024); bb.put(new byte[] { -13, -119, -102, -62 }); Messages.Record r = Messages.Record.newBuilder().setData(ByteString.copyFrom(TEST_DATA)).setPartitionKeyIndex(0) .build(); byte[] payload = AggregatedRecord.newBuilder().addPartitionKeyTable(pk).addRecords(r).addRecords(r) .addRecords(r).build().toByteArray(); bb.put(payload); bb.put(md5(payload)); bb.limit(bb.position()); bb.rewind(); return bb; } private static List<KinesisClientRecord> generateConsecutiveRecords(int numberOfRecords, String partitionKey, ByteBuffer data, Date arrivalTimestamp, BigInteger startSequenceNumber) { List<KinesisClientRecord> records = new ArrayList<>(); for (int i = 0; i < numberOfRecords; ++i) { String seqNum = startSequenceNumber.add(BigInteger.valueOf(i)).toString(); KinesisClientRecord record = KinesisClientRecord.builder().partitionKey(partitionKey).data(data) .sequenceNumber(seqNum).approximateArrivalTimestamp(arrivalTimestamp.toInstant()).build(); records.add(record); } return records; } private static byte[] md5(byte[] b) { try { MessageDigest md = MessageDigest.getInstance("MD5"); return md.digest(b); } catch (Exception e) { throw new RuntimeException(e); } } private static TaskResultMatcher shardEndTaskResult(boolean isAtShardEnd) { TaskResult expected = new TaskResult(null, isAtShardEnd); return taskResult(expected); } private static TaskResultMatcher exceptionTaskResult(Exception ex) { TaskResult expected = new TaskResult(ex, false); return taskResult(expected); } private static TaskResultMatcher taskResult(TaskResult expected) { return new TaskResultMatcher(expected); } private static class TaskResultMatcher extends TypeSafeDiagnosingMatcher<TaskResult> { Matcher<TaskResult> matchers; TaskResultMatcher(TaskResult expected) { if (expected == null) { matchers = nullValue(TaskResult.class); } else { matchers = allOf(notNullValue(TaskResult.class), hasProperty("shardEndReached", equalTo(expected.isShardEndReached())), hasProperty("exception", equalTo(expected.getException()))); } } @Override protected boolean matchesSafely(TaskResult item, Description mismatchDescription) { if (!matchers.matches(item)) { matchers.describeMismatch(item, mismatchDescription); return false; } return true; } @Override public void describeTo(Description description) { description.appendDescriptionOf(matchers); } } }