package software.amazon.kinesis.retrieval.fanout;

import com.google.common.util.concurrent.ThreadFactoryBuilder;
import io.netty.handler.timeout.ReadTimeoutException;
import io.reactivex.Flowable;
import io.reactivex.Scheduler;
import io.reactivex.schedulers.Schedulers;
import io.reactivex.subscribers.SafeSubscriber;
import lombok.Data;
import lombok.RequiredArgsConstructor;
import lombok.Setter;
import lombok.extern.slf4j.Slf4j;
import org.hamcrest.Description;
import org.hamcrest.Matcher;
import org.hamcrest.TypeSafeDiagnosingMatcher;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.ArgumentCaptor;
import org.mockito.Mock;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.runners.MockitoJUnitRunner;
import org.mockito.stubbing.Answer;
import org.reactivestreams.Subscriber;
import org.reactivestreams.Subscription;
import software.amazon.awssdk.core.SdkBytes;
import software.amazon.awssdk.core.async.SdkPublisher;
import software.amazon.awssdk.core.exception.SdkClientException;
import software.amazon.awssdk.services.kinesis.KinesisAsyncClient;
import software.amazon.awssdk.services.kinesis.model.Record;
import software.amazon.awssdk.services.kinesis.model.ResourceNotFoundException;
import software.amazon.awssdk.services.kinesis.model.ShardIteratorType;
import software.amazon.awssdk.services.kinesis.model.StartingPosition;
import software.amazon.awssdk.services.kinesis.model.SubscribeToShardEvent;
import software.amazon.awssdk.services.kinesis.model.SubscribeToShardEventStream;
import software.amazon.awssdk.services.kinesis.model.SubscribeToShardRequest;
import software.amazon.kinesis.common.InitialPositionInStream;
import software.amazon.kinesis.common.InitialPositionInStreamExtended;
import software.amazon.kinesis.lifecycle.ShardConsumerNotifyingSubscriber;
import software.amazon.kinesis.lifecycle.events.ProcessRecordsInput;
import software.amazon.kinesis.retrieval.BatchUniqueIdentifier;
import software.amazon.kinesis.retrieval.KinesisClientRecord;
import software.amazon.kinesis.retrieval.RecordsRetrieved;
import software.amazon.kinesis.retrieval.RetryableRetrievalException;
import software.amazon.kinesis.retrieval.kpl.ExtendedSequenceNumber;
import software.amazon.kinesis.utils.SubscribeToShardRequestMatcher;

import java.nio.ByteBuffer;
import java.time.Instant;
import java.util.ArrayList;
import java.util.LinkedList;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.RejectedExecutionException;
import java.util.concurrent.Semaphore;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Consumer;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;

import static org.hamcrest.CoreMatchers.equalTo;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertThat;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.argThat;
import static org.mockito.Matchers.eq;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.doNothing;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.reset;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;

@RunWith(MockitoJUnitRunner.class)
@Slf4j
public class FanOutRecordsPublisherTest {

    // Fixed shard / consumer identifiers used when constructing the publisher under test.
    private static final String SHARD_ID = "Shard-001";
    private static final String CONSUMER_ARN = "arn:consumer";

    // Upstream Kinesis async client; tests verify subscribeToShard interactions against it.
    @Mock
    private KinesisAsyncClient kinesisClient;
    // Event-stream publisher handed to the RecordFlow via onEventStream(...) in tests.
    @Mock
    private SdkPublisher<SubscribeToShardEventStream> publisher;
    // Reactive-streams subscription passed to the captured RecordSubscription; request(n)
    // calls are verified on this mock to check demand propagation.
    @Mock
    private Subscription subscription;
    // Unused directly in the visible tests; available for verifying subscriber callbacks.
    @Mock
    private Subscriber<RecordsRetrieved> subscriber;

    // Reusable SubscribeToShardEvent built per-test.
    private SubscribeToShardEvent batchEvent;

    /**
     * Happy-path test: three SubscribeToShardEvents pushed through the captured
     * RecordSubscription are delivered to the subscriber as three ProcessRecordsInputs,
     * each containing records matching the originals, with one upstream request(1)
     * per delivered event (plus the initial request).
     */
    @Test
    public void simpleTest() throws Exception {
        FanOutRecordsPublisher source = new FanOutRecordsPublisher(kinesisClient, SHARD_ID, CONSUMER_ARN);

        // Captors for the internal subscription/flow objects the publisher creates,
        // so the test can drive events into them directly.
        ArgumentCaptor<FanOutRecordsPublisher.RecordSubscription> captor = ArgumentCaptor
                .forClass(FanOutRecordsPublisher.RecordSubscription.class);
        ArgumentCaptor<FanOutRecordsPublisher.RecordFlow> flowCaptor = ArgumentCaptor
                .forClass(FanOutRecordsPublisher.RecordFlow.class);

        // Capture the RecordSubscription when the publisher subscribes to the event stream.
        doNothing().when(publisher).subscribe(captor.capture());

        source.start(ExtendedSequenceNumber.LATEST,
                InitialPositionInStreamExtended.newInitialPosition(InitialPositionInStream.LATEST));

        List<ProcessRecordsInput> receivedInput = new ArrayList<>();

        // Subscriber that records every delivered batch and requests the next one.
        source.subscribe(new ShardConsumerNotifyingSubscriber(new Subscriber<RecordsRetrieved>() {
            Subscription subscription;

            @Override public void onSubscribe(Subscription s) {
                subscription = s;
                subscription.request(1);
            }

            @Override public void onNext(RecordsRetrieved input) {
                receivedInput.add(input.processRecordsInput());
                subscription.request(1);
            }

            @Override public void onError(Throwable t) {
                log.error("Caught throwable in subscriber", t);
                fail("Caught throwable in subscriber");
            }

            @Override public void onComplete() {
                fail("OnComplete called when not expected");
            }
        }, source));

        // Wire the captured flow/subscription so events can be injected by the test.
        verify(kinesisClient).subscribeToShard(any(SubscribeToShardRequest.class), flowCaptor.capture());
        flowCaptor.getValue().onEventStream(publisher);
        captor.getValue().onSubscribe(subscription);

        List<Record> records = Stream.of(1, 2, 3).map(this::makeRecord).collect(Collectors.toList());
        List<KinesisClientRecordMatcher> matchers = records.stream().map(KinesisClientRecordMatcher::new)
                .collect(Collectors.toList());

        batchEvent = SubscribeToShardEvent.builder().millisBehindLatest(100L).records(records).build();

        // Push the same batch three times through the captured subscription.
        captor.getValue().onNext(batchEvent);
        captor.getValue().onNext(batchEvent);
        captor.getValue().onNext(batchEvent);

        // 1 initial request + 1 per delivered event = 4 total.
        verify(subscription, times(4)).request(1);
        assertThat(receivedInput.size(), equalTo(3));

        // Every delivered batch must contain records equivalent to the originals, in order.
        receivedInput.stream().map(ProcessRecordsInput::records).forEach(clientRecordsList -> {
            assertThat(clientRecordsList.size(), equalTo(matchers.size()));
            for (int i = 0; i < clientRecordsList.size(); ++i) {
                assertThat(clientRecordsList.get(i), matchers.get(i));
            }
        });

    }

    /**
     * Verifies that when the delivery executor accepts every task (blocking executor,
     * no rejections), all three events flow through the Flowable pipeline to the
     * subscriber, and the publisher's current sequence number advances to the last
     * continuation sequence number ("3000").
     */
    @Test
    public void testIfAllEventsReceivedWhenNoTasksRejectedByExecutor() throws Exception {
        FanOutRecordsPublisher source = new FanOutRecordsPublisher(kinesisClient, SHARD_ID, CONSUMER_ARN);

        ArgumentCaptor<FanOutRecordsPublisher.RecordSubscription> captor = ArgumentCaptor
                .forClass(FanOutRecordsPublisher.RecordSubscription.class);
        ArgumentCaptor<FanOutRecordsPublisher.RecordFlow> flowCaptor = ArgumentCaptor
                .forClass(FanOutRecordsPublisher.RecordFlow.class);

        doNothing().when(publisher).subscribe(captor.capture());

        source.start(ExtendedSequenceNumber.LATEST,
                InitialPositionInStreamExtended.newInitialPosition(InitialPositionInStream.LATEST));

        List<ProcessRecordsInput> receivedInput = new ArrayList<>();

        Subscriber<RecordsRetrieved> shardConsumerSubscriber = new ShardConsumerNotifyingSubscriber(
                new Subscriber<RecordsRetrieved>() {
                    Subscription subscription;

                    @Override public void onSubscribe(Subscription s) {
                        subscription = s;
                        subscription.request(1);
                    }

                    @Override public void onNext(RecordsRetrieved input) {
                        receivedInput.add(input.processRecordsInput());
                        subscription.request(1);
                    }

                    @Override public void onError(Throwable t) {
                        log.error("Caught throwable in subscriber", t);
                        fail("Caught throwable in subscriber");
                    }

                    @Override public void onComplete() {
                        fail("OnComplete called when not expected");
                    }
                }, source);

        // Executor that blocks until each submitted task runs, so no task is rejected.
        Scheduler testScheduler = getScheduler(getBlockingExecutor(getSpiedExecutor(getTestExecutor())));
        int bufferSize = 8;

        Flowable.fromPublisher(source).subscribeOn(testScheduler).observeOn(testScheduler, true, bufferSize)
                .subscribe(shardConsumerSubscriber);

        verify(kinesisClient).subscribeToShard(any(SubscribeToShardRequest.class), flowCaptor.capture());
        flowCaptor.getValue().onEventStream(publisher);
        captor.getValue().onSubscribe(subscription);

        List<Record> records = Stream.of(1, 2, 3).map(this::makeRecord).collect(Collectors.toList());
        List<KinesisClientRecordMatcher> matchers = records.stream().map(KinesisClientRecordMatcher::new)
                .collect(Collectors.toList());

        // Push three events with distinct continuation sequence numbers.
        Stream.of("1000", "2000", "3000")
                .map(contSeqNum ->
                        SubscribeToShardEvent.builder()
                                .millisBehindLatest(100L)
                                .continuationSequenceNumber(contSeqNum)
                                .records(records).build())
                .forEach(batchEvent -> captor.getValue().onNext(batchEvent));

        // 1 initial request + 1 per delivered event = 4 total.
        verify(subscription, times(4)).request(1);
        assertThat(receivedInput.size(), equalTo(3));

        receivedInput.stream().map(ProcessRecordsInput::records).forEach(clientRecordsList -> {
            assertThat(clientRecordsList.size(), equalTo(matchers.size()));
            for (int i = 0; i < clientRecordsList.size(); ++i) {
                assertThat(clientRecordsList.get(i), matchers.get(i));
            }
        });

        // All events delivered, so the tracked sequence number is the last one pushed.
        assertThat(source.getCurrentSequenceNumber(), equalTo("3000"));

    }

    /**
     * Verifies that when the delivery executor rejects tasks after the first one
     * (overwhelmed executor), only the first event reaches the subscriber and the
     * publisher's sequence number stops at the first continuation number ("1000").
     * The subscriber is wrapped in a SafeSubscriber so the rejection does not
     * propagate as a raw throwable.
     */
    @Test
    public void testIfEventsAreNotDeliveredToShardConsumerWhenPreviousEventDeliveryTaskGetsRejected() throws Exception {
        FanOutRecordsPublisher source = new FanOutRecordsPublisher(kinesisClient, SHARD_ID, CONSUMER_ARN);

        ArgumentCaptor<FanOutRecordsPublisher.RecordSubscription> captor = ArgumentCaptor
                .forClass(FanOutRecordsPublisher.RecordSubscription.class);
        ArgumentCaptor<FanOutRecordsPublisher.RecordFlow> flowCaptor = ArgumentCaptor
                .forClass(FanOutRecordsPublisher.RecordFlow.class);

        doNothing().when(publisher).subscribe(captor.capture());

        source.start(ExtendedSequenceNumber.LATEST,
                InitialPositionInStreamExtended.newInitialPosition(InitialPositionInStream.LATEST));

        List<ProcessRecordsInput> receivedInput = new ArrayList<>();

        Subscriber<RecordsRetrieved> shardConsumerSubscriber = new ShardConsumerNotifyingSubscriber(
                new Subscriber<RecordsRetrieved>() {
                    Subscription subscription;

                    @Override public void onSubscribe(Subscription s) {
                        subscription = s;
                        subscription.request(1);
                    }

                    @Override public void onNext(RecordsRetrieved input) {
                        receivedInput.add(input.processRecordsInput());
                        subscription.request(1);
                    }

                    @Override public void onError(Throwable t) {
                        log.error("Caught throwable in subscriber", t);
                        fail("Caught throwable in subscriber");
                    }

                    @Override public void onComplete() {
                        fail("OnComplete called when not expected");
                    }
                }, source);

        // Executor configured to reject tasks after it is saturated.
        Scheduler testScheduler = getScheduler(getOverwhelmedBlockingExecutor(getSpiedExecutor(getTestExecutor())));
        int bufferSize = 8;

        Flowable.fromPublisher(source).subscribeOn(testScheduler).observeOn(testScheduler, true, bufferSize)
                .subscribe(new SafeSubscriber<>(shardConsumerSubscriber));

        verify(kinesisClient).subscribeToShard(any(SubscribeToShardRequest.class), flowCaptor.capture());
        flowCaptor.getValue().onEventStream(publisher);
        captor.getValue().onSubscribe(subscription);

        List<Record> records = Stream.of(1, 2, 3).map(this::makeRecord).collect(Collectors.toList());
        List<KinesisClientRecordMatcher> matchers = records.stream().map(KinesisClientRecordMatcher::new)
                .collect(Collectors.toList());

        Stream.of("1000", "2000", "3000")
                .map(contSeqNum ->
                        SubscribeToShardEvent.builder()
                                .millisBehindLatest(100L)
                                .continuationSequenceNumber(contSeqNum)
                                .records(records).build())
                .forEach(batchEvent -> captor.getValue().onNext(batchEvent));

        // Only the first event was delivered: 1 initial request + 1 follow-up = 2.
        verify(subscription, times(2)).request(1);
        assertThat(receivedInput.size(), equalTo(1));

        receivedInput.stream().map(ProcessRecordsInput::records).forEach(clientRecordsList -> {
            assertThat(clientRecordsList.size(), equalTo(matchers.size()));
            for (int i = 0; i < clientRecordsList.size(); ++i) {
                assertThat(clientRecordsList.get(i), matchers.get(i));
            }
        });

        // Delivery stopped after the first event, so the sequence number is "1000".
        assertThat(source.getCurrentSequenceNumber(), equalTo("1000"));

    }

    /**
     * Verifies that a backpressure-adhering service publisher delivering 1000 events
     * results in all 1000 being received in order: the subscriber asserts that each
     * event's continuation sequence number is exactly one greater than the last.
     *
     * Fix: the boolean result of {@code CountDownLatch.await(timeout)} was previously
     * ignored, so a timed-out wait fell through to the size assertion and produced a
     * confusing secondary failure; the await result is now asserted explicitly.
     */
    @Test
    public void testIfStreamOfEventsAreDeliveredInOrderWithBackpressureAdheringServicePublisher() throws Exception {
        FanOutRecordsPublisher source = new FanOutRecordsPublisher(kinesisClient, SHARD_ID, CONSUMER_ARN);

        ArgumentCaptor<FanOutRecordsPublisher.RecordSubscription> captor = ArgumentCaptor
                .forClass(FanOutRecordsPublisher.RecordSubscription.class);
        ArgumentCaptor<FanOutRecordsPublisher.RecordFlow> flowCaptor = ArgumentCaptor
                .forClass(FanOutRecordsPublisher.RecordFlow.class);

        List<Record> records = Stream.of(1, 2, 3).map(this::makeRecord).collect(Collectors.toList());

        // Action executed by the service publisher for each demanded event: push an event
        // whose continuation sequence number is the (1-based) event index.
        Consumer<Integer> servicePublisherAction = contSeqNum -> captor.getValue().onNext(
                SubscribeToShardEvent.builder()
                        .millisBehindLatest(100L)
                        .continuationSequenceNumber(contSeqNum + "")
                        .records(records)
                        .build());

        CountDownLatch servicePublisherTaskCompletionLatch = new CountDownLatch(2);
        int totalServicePublisherEvents = 1000;
        int initialDemand = 0;
        BackpressureAdheringServicePublisher servicePublisher =
                new BackpressureAdheringServicePublisher(servicePublisherAction, totalServicePublisherEvents, servicePublisherTaskCompletionLatch, initialDemand);

        doNothing().when(publisher).subscribe(captor.capture());

        source.start(ExtendedSequenceNumber.LATEST,
                InitialPositionInStreamExtended.newInitialPosition(InitialPositionInStream.LATEST));

        List<ProcessRecordsInput> receivedInput = new ArrayList<>();

        Subscriber<RecordsRetrieved> shardConsumerSubscriber = new ShardConsumerNotifyingSubscriber(
                new Subscriber<RecordsRetrieved>() {
                    private Subscription subscription;
                    private int lastSeenSeqNum = 0;

                    @Override public void onSubscribe(Subscription s) {
                        subscription = s;
                        subscription.request(1);
                        servicePublisher.request(1);
                    }

                    @Override public void onNext(RecordsRetrieved input) {
                        receivedInput.add(input.processRecordsInput());
                        // In-order delivery check: sequence numbers must be strictly consecutive.
                        assertEquals("" + ++lastSeenSeqNum, ((FanOutRecordsPublisher.FanoutRecordsRetrieved)input).continuationSequenceNumber());
                        subscription.request(1);
                        servicePublisher.request(1);
                        if(receivedInput.size() == totalServicePublisherEvents) {
                            servicePublisherTaskCompletionLatch.countDown();
                        }
                    }

                    @Override public void onError(Throwable t) {
                        log.error("Caught throwable in subscriber", t);
                        fail("Caught throwable in subscriber");
                    }

                    @Override public void onComplete() {
                        fail("OnComplete called when not expected");
                    }
                }, source);

        ExecutorService executorService = getTestExecutor();
        Scheduler testScheduler = getScheduler(getInitiallyBlockingExecutor(getSpiedExecutor(executorService)));
        int bufferSize = 8;

        Flowable.fromPublisher(source).subscribeOn(testScheduler).observeOn(testScheduler, true, bufferSize)
                .subscribe(shardConsumerSubscriber);

        verify(kinesisClient).subscribeToShard(any(SubscribeToShardRequest.class), flowCaptor.capture());
        flowCaptor.getValue().onEventStream(publisher);
        captor.getValue().onSubscribe(subscription);

        List<KinesisClientRecordMatcher> matchers = records.stream().map(KinesisClientRecordMatcher::new)
                .collect(Collectors.toList());

        executorService.submit(servicePublisher);
        // Fail fast and clearly if the pipeline stalls instead of falling through to
        // the size assertion with a partial count.
        assertTrue("Timed out waiting for all events to be delivered",
                servicePublisherTaskCompletionLatch.await(5000, TimeUnit.MILLISECONDS));

        assertThat(receivedInput.size(), equalTo(totalServicePublisherEvents));

        receivedInput.stream().map(ProcessRecordsInput::records).forEach(clientRecordsList -> {
            assertThat(clientRecordsList.size(), equalTo(matchers.size()));
            for (int i = 0; i < clientRecordsList.size(); ++i) {
                assertThat(clientRecordsList.get(i), matchers.get(i));
            }
        });

        assertThat(source.getCurrentSequenceNumber(), equalTo(totalServicePublisherEvents + ""));

    }

    /**
     * Verifies that when the service flow completes (non-shard-end) after the 200th event,
     * the first 200 events are delivered in order, onComplete is NOT propagated to the
     * subscriber, and the publisher re-subscribes to the service (second subscribeToShard
     * call observed via the latch-counting answer).
     *
     * Fix: both {@code CountDownLatch.await(timeout)} results were previously ignored,
     * so a timed-out wait fell through to the subsequent assertions and produced
     * confusing secondary failures; the await results are now asserted explicitly.
     */
    @Test
    public void testIfStreamOfEventsAndOnCompleteAreDeliveredInOrderWithBackpressureAdheringServicePublisher() throws Exception {

        // Counts subscribeToShard calls: initial subscription + expected re-subscription.
        CountDownLatch onS2SCallLatch = new CountDownLatch(2);

        doAnswer(new Answer() {
            @Override public Object answer(InvocationOnMock invocation) throws Throwable {
                onS2SCallLatch.countDown();
                return null;
            }
        }).when(kinesisClient).subscribeToShard(any(SubscribeToShardRequest.class), any());

        FanOutRecordsPublisher source = new FanOutRecordsPublisher(kinesisClient, SHARD_ID, CONSUMER_ARN);

        ArgumentCaptor<FanOutRecordsPublisher.RecordSubscription> captor = ArgumentCaptor
                .forClass(FanOutRecordsPublisher.RecordSubscription.class);
        ArgumentCaptor<FanOutRecordsPublisher.RecordFlow> flowCaptor = ArgumentCaptor
                .forClass(FanOutRecordsPublisher.RecordFlow.class);

        List<Record> records = Stream.of(1, 2, 3).map(this::makeRecord).collect(Collectors.toList());

        Consumer<Integer> servicePublisherAction = contSeqNum -> captor.getValue().onNext(
                SubscribeToShardEvent.builder()
                        .millisBehindLatest(100L)
                        .continuationSequenceNumber(contSeqNum + "")
                        .records(records)
                        .build());

        CountDownLatch servicePublisherTaskCompletionLatch = new CountDownLatch(2);
        int totalServicePublisherEvents = 1000;
        int initialDemand = 9;
        int triggerCompleteAtNthEvent = 200;
        BackpressureAdheringServicePublisher servicePublisher = new BackpressureAdheringServicePublisher(
                servicePublisherAction, totalServicePublisherEvents, servicePublisherTaskCompletionLatch,
                initialDemand);
        // Complete the flow (non-shard-end) after the Nth event.
        servicePublisher.setCompleteTrigger(triggerCompleteAtNthEvent, () -> flowCaptor.getValue().complete());

        doNothing().when(publisher).subscribe(captor.capture());

        source.start(ExtendedSequenceNumber.LATEST,
                InitialPositionInStreamExtended.newInitialPosition(InitialPositionInStream.LATEST));

        List<ProcessRecordsInput> receivedInput = new ArrayList<>();

        Subscriber<RecordsRetrieved> shardConsumerSubscriber = new ShardConsumerNotifyingSubscriber(
                new Subscriber<RecordsRetrieved>() {
                    private Subscription subscription;
                    private int lastSeenSeqNum = 0;

                    @Override public void onSubscribe(Subscription s) {
                        subscription = s;
                        subscription.request(1);
                        servicePublisher.request(1);
                    }

                    @Override public void onNext(RecordsRetrieved input) {
                        receivedInput.add(input.processRecordsInput());
                        // In-order delivery check: sequence numbers must be strictly consecutive.
                        assertEquals("" + ++lastSeenSeqNum, ((FanOutRecordsPublisher.FanoutRecordsRetrieved)input).continuationSequenceNumber());
                        subscription.request(1);
                        servicePublisher.request(1);
                        if(receivedInput.size() == triggerCompleteAtNthEvent) {
                            servicePublisherTaskCompletionLatch.countDown();
                        }
                    }

                    @Override public void onError(Throwable t) {
                        log.error("Caught throwable in subscriber", t);
                        fail("Caught throwable in subscriber");
                    }

                    @Override public void onComplete() {
                        // Non-shard-end completion must not reach the shard consumer.
                        fail("OnComplete called when not expected");
                    }
                }, source);

        ExecutorService executorService = getTestExecutor();
        Scheduler testScheduler = getScheduler(getInitiallyBlockingExecutor(getSpiedExecutor(executorService)));
        int bufferSize = 8;

        Flowable.fromPublisher(source).subscribeOn(testScheduler).observeOn(testScheduler, true, bufferSize)
                .subscribe(shardConsumerSubscriber);

        verify(kinesisClient, times(1)).subscribeToShard(any(SubscribeToShardRequest.class), flowCaptor.capture());

        flowCaptor.getValue().onEventStream(publisher);
        captor.getValue().onSubscribe(subscription);

        List<KinesisClientRecordMatcher> matchers = records.stream().map(KinesisClientRecordMatcher::new)
                .collect(Collectors.toList());

        executorService.submit(servicePublisher);
        // Fail fast and clearly if delivery stalls before the complete trigger fires.
        assertTrue("Timed out waiting for events up to the complete trigger",
                servicePublisherTaskCompletionLatch.await(5000, TimeUnit.MILLISECONDS));

        assertThat(receivedInput.size(), equalTo(triggerCompleteAtNthEvent));

        receivedInput.stream().map(ProcessRecordsInput::records).forEach(clientRecordsList -> {
            assertThat(clientRecordsList.size(), equalTo(matchers.size()));
            for (int i = 0; i < clientRecordsList.size(); ++i) {
                assertThat(clientRecordsList.get(i), matchers.get(i));
            }
        });

        assertThat(source.getCurrentSequenceNumber(), equalTo(triggerCompleteAtNthEvent + ""));
        // In non-shard end cases, upon successful completion, the publisher would re-subscribe to service.
        // Let's wait for sometime to allow the publisher to re-subscribe
        assertTrue("Timed out waiting for the publisher to re-subscribe",
                onS2SCallLatch.await(5000, TimeUnit.MILLISECONDS));
        verify(kinesisClient, times(2)).subscribeToShard(any(SubscribeToShardRequest.class), flowCaptor.capture());

    }

    /**
     * Verifies shard-end behavior: after the Nth event the service publisher pushes an
     * event with a null continuation sequence number (shard end) and then completes the
     * flow. The subscriber must receive onComplete, and the publisher's current sequence
     * number must be null.
     *
     * Fixes: (1) both {@code CountDownLatch.await(timeout)} results were previously
     * ignored, so a timed-out wait fell through to later assertions with confusing
     * secondary failures — the await results are now asserted explicitly;
     * (2) removed the unused {@code lastSeenSeqNum} field from the anonymous subscriber
     * (this test does not verify per-event sequence ordering).
     */
    @Test
    public void testIfShardEndEventAndOnCompleteAreDeliveredInOrderWithBackpressureAdheringServicePublisher() throws Exception {
        FanOutRecordsPublisher source = new FanOutRecordsPublisher(kinesisClient, SHARD_ID, CONSUMER_ARN);

        ArgumentCaptor<FanOutRecordsPublisher.RecordSubscription> captor = ArgumentCaptor
                .forClass(FanOutRecordsPublisher.RecordSubscription.class);
        ArgumentCaptor<FanOutRecordsPublisher.RecordFlow> flowCaptor = ArgumentCaptor
                .forClass(FanOutRecordsPublisher.RecordFlow.class);

        List<Record> records = Stream.of(1, 2, 3).map(this::makeRecord).collect(Collectors.toList());

        Consumer<Integer> servicePublisherAction = contSeqNum -> captor.getValue().onNext(
                SubscribeToShardEvent.builder()
                        .millisBehindLatest(100L)
                        .continuationSequenceNumber(contSeqNum + "")
                        .records(records)
                        .build());

        // Shard-end event: null continuation sequence number signals the end of the shard.
        Consumer<Integer> servicePublisherShardEndAction = contSeqNum -> captor.getValue().onNext(
                SubscribeToShardEvent.builder()
                        .millisBehindLatest(100L)
                        .continuationSequenceNumber(null)
                        .records(records)
                        .build());

        CountDownLatch servicePublisherTaskCompletionLatch = new CountDownLatch(2);
        CountDownLatch onCompleteLatch = new CountDownLatch(1);
        int totalServicePublisherEvents = 1000;
        int initialDemand = 9;
        int triggerCompleteAtNthEvent = 200;
        BackpressureAdheringServicePublisher servicePublisher = new BackpressureAdheringServicePublisher(
                servicePublisherAction, totalServicePublisherEvents, servicePublisherTaskCompletionLatch,
                initialDemand);

        servicePublisher
                .setShardEndAndCompleteTrigger(triggerCompleteAtNthEvent, () -> flowCaptor.getValue().complete(),
                        servicePublisherShardEndAction);

        doNothing().when(publisher).subscribe(captor.capture());

        source.start(ExtendedSequenceNumber.LATEST,
                InitialPositionInStreamExtended.newInitialPosition(InitialPositionInStream.LATEST));

        List<ProcessRecordsInput> receivedInput = new ArrayList<>();

        // Single-element array so the anonymous subscriber can flag completion.
        final boolean[] isOnCompleteTriggered = { false };

        Subscriber<RecordsRetrieved> shardConsumerSubscriber = new ShardConsumerNotifyingSubscriber(
                new Subscriber<RecordsRetrieved>() {
                    private Subscription subscription;

                    @Override public void onSubscribe(Subscription s) {
                        subscription = s;
                        subscription.request(1);
                        servicePublisher.request(1);
                    }

                    @Override public void onNext(RecordsRetrieved input) {
                        receivedInput.add(input.processRecordsInput());
                        subscription.request(1);
                        servicePublisher.request(1);
                        if(receivedInput.size() == triggerCompleteAtNthEvent) {
                            servicePublisherTaskCompletionLatch.countDown();
                        }
                    }

                    @Override public void onError(Throwable t) {
                        log.error("Caught throwable in subscriber", t);
                        fail("Caught throwable in subscriber");
                    }

                    @Override public void onComplete() {
                        // Shard-end completion must be propagated to the shard consumer.
                        isOnCompleteTriggered[0] = true;
                        onCompleteLatch.countDown();
                    }
                }, source);

        ExecutorService executorService = getTestExecutor();
        Scheduler testScheduler = getScheduler(getInitiallyBlockingExecutor(getSpiedExecutor(executorService)));
        int bufferSize = 8;

        Flowable.fromPublisher(source).subscribeOn(testScheduler).observeOn(testScheduler, true, bufferSize)
                .subscribe(shardConsumerSubscriber);

        verify(kinesisClient, times(1)).subscribeToShard(any(SubscribeToShardRequest.class), flowCaptor.capture());

        flowCaptor.getValue().onEventStream(publisher);
        captor.getValue().onSubscribe(subscription);

        List<KinesisClientRecordMatcher> matchers = records.stream().map(KinesisClientRecordMatcher::new)
                .collect(Collectors.toList());

        executorService.submit(servicePublisher);
        // Fail fast and clearly if delivery stalls before the shard-end trigger fires.
        assertTrue("Timed out waiting for events up to the shard-end trigger",
                servicePublisherTaskCompletionLatch.await(5000, TimeUnit.MILLISECONDS));

        assertThat(receivedInput.size(), equalTo(triggerCompleteAtNthEvent));

        receivedInput.stream().map(ProcessRecordsInput::records).forEach(clientRecordsList -> {
            assertThat(clientRecordsList.size(), equalTo(matchers.size()));
            for (int i = 0; i < clientRecordsList.size(); ++i) {
                assertThat(clientRecordsList.get(i), matchers.get(i));
            }
        });

        // Shard end resets the tracked sequence number.
        assertNull(source.getCurrentSequenceNumber());
        // With shard end event, onComplete must be propagated to the subscriber.
        assertTrue("Timed out waiting for onComplete to be propagated",
                onCompleteLatch.await(5000, TimeUnit.MILLISECONDS));
        assertTrue("OnComplete should be triggered", isOnCompleteTriggered[0]);

    }

    @Test
    public void testIfStreamOfEventsAndOnErrorAreDeliveredInOrderWithBackpressureAdheringServicePublisher() throws Exception {
        FanOutRecordsPublisher source = new FanOutRecordsPublisher(kinesisClient, SHARD_ID, CONSUMER_ARN);

        ArgumentCaptor<FanOutRecordsPublisher.RecordSubscription> captor = ArgumentCaptor
                .forClass(FanOutRecordsPublisher.RecordSubscription.class);
        ArgumentCaptor<FanOutRecordsPublisher.RecordFlow> flowCaptor = ArgumentCaptor
                .forClass(FanOutRecordsPublisher.RecordFlow.class);

        List<Record> records = Stream.of(1, 2, 3).map(this::makeRecord).collect(Collectors.toList());

        Consumer<Integer> servicePublisherAction = contSeqNum -> captor.getValue().onNext(
                SubscribeToShardEvent.builder()
                        .millisBehindLatest(100L)
                        .continuationSequenceNumber(contSeqNum + "")
                        .records(records)
                        .build());

        CountDownLatch servicePublisherTaskCompletionLatch = new CountDownLatch(2);
        CountDownLatch onErrorReceiveLatch = new CountDownLatch(1);
        int totalServicePublisherEvents = 1000;
        int initialDemand = 9;
        int triggerErrorAtNthEvent = 241;
        BackpressureAdheringServicePublisher servicePublisher = new BackpressureAdheringServicePublisher(
                servicePublisherAction, totalServicePublisherEvents, servicePublisherTaskCompletionLatch,
                initialDemand);
        servicePublisher.setErrorTrigger(triggerErrorAtNthEvent,
                () -> flowCaptor.getValue().exceptionOccurred(new RuntimeException("Service Exception")));

        doNothing().when(publisher).subscribe(captor.capture());

        source.start(ExtendedSequenceNumber.LATEST,
                InitialPositionInStreamExtended.newInitialPosition(InitialPositionInStream.LATEST));

        final boolean[] isOnErrorThrown = { false };

        List<ProcessRecordsInput> receivedInput = new ArrayList<>();

        Subscriber<RecordsRetrieved> shardConsumerSubscriber = new ShardConsumerNotifyingSubscriber(
                new Subscriber<RecordsRetrieved>() {
                    private Subscription subscription;
                    private int lastSeenSeqNum = 0;

                    @Override public void onSubscribe(Subscription s) {
                        subscription = s;
                        subscription.request(1);
                        servicePublisher.request(1);
                    }

                    @Override public void onNext(RecordsRetrieved input) {
                        receivedInput.add(input.processRecordsInput());
                        assertEquals("" + ++lastSeenSeqNum, ((FanOutRecordsPublisher.FanoutRecordsRetrieved)input).continuationSequenceNumber());
                        subscription.request(1);
                        servicePublisher.request(1);
                        if(receivedInput.size() == triggerErrorAtNthEvent) {
                            servicePublisherTaskCompletionLatch.countDown();
                        }
                    }

                    @Override public void onError(Throwable t) {
                        log.error("Caught throwable in subscriber", t);
                        isOnErrorThrown[0] = true;
                        onErrorReceiveLatch.countDown();
                    }

                    @Override public void onComplete() {
                        fail("OnComplete called when not expected");
                    }
                }, source);

        ExecutorService executorService = getTestExecutor();
        Scheduler testScheduler = getScheduler(getInitiallyBlockingExecutor(getSpiedExecutor(executorService)));
        int bufferSize = 8;

        Flowable.fromPublisher(source).subscribeOn(testScheduler).observeOn(testScheduler, true, bufferSize)
                .subscribe(shardConsumerSubscriber);

        verify(kinesisClient).subscribeToShard(any(SubscribeToShardRequest.class), flowCaptor.capture());
        flowCaptor.getValue().onEventStream(publisher);
        captor.getValue().onSubscribe(subscription);

        List<KinesisClientRecordMatcher> matchers = records.stream().map(KinesisClientRecordMatcher::new)
                .collect(Collectors.toList());

        executorService.submit(servicePublisher);
        servicePublisherTaskCompletionLatch.await(5000, TimeUnit.MILLISECONDS);

        assertThat(receivedInput.size(), equalTo(triggerErrorAtNthEvent));

        receivedInput.stream().map(ProcessRecordsInput::records).forEach(clientRecordsList -> {
            assertThat(clientRecordsList.size(), equalTo(matchers.size()));
            for (int i = 0; i < clientRecordsList.size(); ++i) {
                assertThat(clientRecordsList.get(i), matchers.get(i));
            }
        });

        assertThat(source.getCurrentSequenceNumber(), equalTo(triggerErrorAtNthEvent + ""));
        onErrorReceiveLatch.await(5000, TimeUnit.MILLISECONDS);
        assertTrue("OnError should have been thrown", isOnErrorThrown[0]);

    }

    /**
     * Verifies that a backpressure-adhering service publisher delivers a long stream of
     * events in order when its initial burst (demand of 9) stays within the downstream
     * buffer limit, so every event is received and no error is raised.
     */
    @Test
    public void testIfStreamOfEventsAreDeliveredInOrderWithBackpressureAdheringServicePublisherHavingInitialBurstWithinLimit() throws Exception {
        FanOutRecordsPublisher source = new FanOutRecordsPublisher(kinesisClient, SHARD_ID, CONSUMER_ARN);

        // Capture the internal subscription/flow objects handed to the mocked SDK client.
        ArgumentCaptor<FanOutRecordsPublisher.RecordSubscription> captor = ArgumentCaptor
                .forClass(FanOutRecordsPublisher.RecordSubscription.class);
        ArgumentCaptor<FanOutRecordsPublisher.RecordFlow> flowCaptor = ArgumentCaptor
                .forClass(FanOutRecordsPublisher.RecordFlow.class);

        List<Record> records = Stream.of(1, 2, 3).map(this::makeRecord).collect(Collectors.toList());

        // Each published event carries the same three records and a monotonically
        // increasing continuation sequence number that the subscriber asserts on.
        Consumer<Integer> servicePublisherAction = contSeqNum -> captor.getValue().onNext(
                SubscribeToShardEvent.builder()
                        .millisBehindLatest(100L)
                        .continuationSequenceNumber(contSeqNum + "")
                        .records(records)
                        .build());

        // Counted down once by the publisher task and once by the subscriber after the
        // last event, so assertions only run when both sides have finished.
        CountDownLatch servicePublisherTaskCompletionLatch = new CountDownLatch(2);
        int totalServicePublisherEvents = 1000;
        int initialDemand = 9; // within the downstream buffer limit -- no overflow expected
        BackpressureAdheringServicePublisher servicePublisher =
                new BackpressureAdheringServicePublisher(servicePublisherAction, totalServicePublisherEvents, servicePublisherTaskCompletionLatch, initialDemand);

        doNothing().when(publisher).subscribe(captor.capture());

        source.start(ExtendedSequenceNumber.LATEST,
                InitialPositionInStreamExtended.newInitialPosition(InitialPositionInStream.LATEST));

        List<ProcessRecordsInput> receivedInput = new ArrayList<>();

        Subscriber<RecordsRetrieved> shardConsumerSubscriber = new ShardConsumerNotifyingSubscriber(
                new Subscriber<RecordsRetrieved>() {
                    private Subscription subscription;
                    private int lastSeenSeqNum = 0;

                    @Override public void onSubscribe(Subscription s) {
                        subscription = s;
                        subscription.request(1);
                        servicePublisher.request(1);
                    }

                    @Override public void onNext(RecordsRetrieved input) {
                        receivedInput.add(input.processRecordsInput());
                        // Events must arrive strictly in sequence order.
                        assertEquals("" + ++lastSeenSeqNum, ((FanOutRecordsPublisher.FanoutRecordsRetrieved)input).continuationSequenceNumber());
                        subscription.request(1);
                        servicePublisher.request(1);
                        if(receivedInput.size() == totalServicePublisherEvents) {
                            servicePublisherTaskCompletionLatch.countDown();
                        }
                    }

                    @Override public void onError(Throwable t) {
                        log.error("Caught throwable in subscriber", t);
                        fail("Caught throwable in subscriber");
                    }

                    @Override public void onComplete() {
                        fail("OnComplete called when not expected");
                    }
                }, source);

        // First two executor tasks run inline, later ones on the real pool (see
        // getInitiallyBlockingExecutor), exercising both delivery paths.
        ExecutorService executorService = getTestExecutor();
        Scheduler testScheduler = getScheduler(getInitiallyBlockingExecutor(getSpiedExecutor(executorService)));
        int bufferSize = 8;

        Flowable.fromPublisher(source).subscribeOn(testScheduler).observeOn(testScheduler, true, bufferSize)
                .subscribe(shardConsumerSubscriber);

        verify(kinesisClient).subscribeToShard(any(SubscribeToShardRequest.class), flowCaptor.capture());
        flowCaptor.getValue().onEventStream(publisher);
        captor.getValue().onSubscribe(subscription);

        List<KinesisClientRecordMatcher> matchers = records.stream().map(KinesisClientRecordMatcher::new)
                .collect(Collectors.toList());

        executorService.submit(servicePublisher);
        // Fail fast with a clear message if delivery stalls, instead of falling through
        // to a confusing size assertion on a partially-filled list.
        assertTrue("Timed out waiting for service publisher and subscriber to complete",
                servicePublisherTaskCompletionLatch.await(5000, TimeUnit.MILLISECONDS));

        assertThat(receivedInput.size(), equalTo(totalServicePublisherEvents));

        receivedInput.stream().map(ProcessRecordsInput::records).forEach(clientRecordsList -> {
            assertThat(clientRecordsList.size(), equalTo(matchers.size()));
            for (int i = 0; i < clientRecordsList.size(); ++i) {
                assertThat(clientRecordsList.get(i), matchers.get(i));
            }
        });

        assertThat(source.getCurrentSequenceNumber(), equalTo(totalServicePublisherEvents + ""));

    }

    /**
     * Verifies that when the service publisher's initial burst (demand of 11) exceeds
     * what the downstream buffer can absorb, the overflow surfaces as onError on the
     * subscriber, while every event delivered before the failure is intact and in order.
     */
    @Test
    public void testIfStreamOfEventsAreDeliveredInOrderWithBackpressureAdheringServicePublisherHavingInitialBurstOverLimit() throws Exception {
        FanOutRecordsPublisher source = new FanOutRecordsPublisher(kinesisClient, SHARD_ID, CONSUMER_ARN);

        // Capture the internal subscription/flow objects handed to the mocked SDK client.
        ArgumentCaptor<FanOutRecordsPublisher.RecordSubscription> captor = ArgumentCaptor
                .forClass(FanOutRecordsPublisher.RecordSubscription.class);
        ArgumentCaptor<FanOutRecordsPublisher.RecordFlow> flowCaptor = ArgumentCaptor
                .forClass(FanOutRecordsPublisher.RecordFlow.class);

        List<Record> records = Stream.of(1, 2, 3).map(this::makeRecord).collect(Collectors.toList());

        // Each published event carries the same three records and a monotonically
        // increasing continuation sequence number that the subscriber asserts on.
        Consumer<Integer> servicePublisherAction = contSeqNum -> captor.getValue().onNext(
                SubscribeToShardEvent.builder()
                        .millisBehindLatest(100L)
                        .continuationSequenceNumber(contSeqNum + "")
                        .records(records)
                        .build());

        // Released by onError; the happy path is not expected to finish this stream.
        CountDownLatch servicePublisherTaskCompletionLatch = new CountDownLatch(1);
        int totalServicePublisherEvents = 1000;
        int initialDemand = 11; // deliberately larger than the downstream buffer allows
        BackpressureAdheringServicePublisher servicePublisher =
                new BackpressureAdheringServicePublisher(servicePublisherAction, totalServicePublisherEvents, servicePublisherTaskCompletionLatch, initialDemand);

        doNothing().when(publisher).subscribe(captor.capture());

        source.start(ExtendedSequenceNumber.LATEST,
                InitialPositionInStreamExtended.newInitialPosition(InitialPositionInStream.LATEST));

        List<ProcessRecordsInput> receivedInput = new ArrayList<>();
        AtomicBoolean onErrorSet = new AtomicBoolean(false);

        Subscriber<RecordsRetrieved> shardConsumerSubscriber = new ShardConsumerNotifyingSubscriber(
                new Subscriber<RecordsRetrieved>() {
                    private Subscription subscription;
                    private int lastSeenSeqNum = 0;

                    @Override public void onSubscribe(Subscription s) {
                        subscription = s;
                        subscription.request(1);
                        servicePublisher.request(1);
                    }

                    @Override public void onNext(RecordsRetrieved input) {
                        receivedInput.add(input.processRecordsInput());
                        // Events must arrive strictly in sequence order.
                        assertEquals("" + ++lastSeenSeqNum, ((FanOutRecordsPublisher.FanoutRecordsRetrieved)input).continuationSequenceNumber());
                        subscription.request(1);
                        servicePublisher.request(1);
                    }

                    @Override public void onError(Throwable t) {
                        log.error("Caught throwable in subscriber", t);
                        onErrorSet.set(true);
                        servicePublisherTaskCompletionLatch.countDown();
                    }

                    @Override public void onComplete() {
                        fail("OnComplete called when not expected");
                    }
                }, source);

        ExecutorService executorService = getTestExecutor();
        Scheduler testScheduler = getScheduler(getInitiallyBlockingExecutor(getSpiedExecutor(executorService)));
        int bufferSize = 8;

        Flowable.fromPublisher(source).subscribeOn(testScheduler).observeOn(testScheduler, true, bufferSize)
                .subscribe(shardConsumerSubscriber);

        verify(kinesisClient).subscribeToShard(any(SubscribeToShardRequest.class), flowCaptor.capture());
        flowCaptor.getValue().onEventStream(publisher);
        captor.getValue().onSubscribe(subscription);

        List<KinesisClientRecordMatcher> matchers = records.stream().map(KinesisClientRecordMatcher::new)
                .collect(Collectors.toList());

        executorService.submit(servicePublisher);
        // Fail fast with a clear message if onError never fires, instead of asserting
        // on a flag that was never given the chance to be set.
        assertTrue("Timed out waiting for onError to release the latch",
                servicePublisherTaskCompletionLatch.await(5000, TimeUnit.MILLISECONDS));

        assertTrue("onError should have triggered", onErrorSet.get());

        receivedInput.stream().map(ProcessRecordsInput::records).forEach(clientRecordsList -> {
            assertThat(clientRecordsList.size(), equalTo(matchers.size()));
            for (int i = 0; i < clientRecordsList.size(); ++i) {
                assertThat(clientRecordsList.get(i), matchers.get(i));
            }
        });
    }

    /** Adapts the supplied {@link ExecutorService} into an RxJava {@link Scheduler}. */
    private Scheduler getScheduler(ExecutorService executorService) {
        final Scheduler scheduler = Schedulers.from(executorService);
        return scheduler;
    }

    /**
     * Creates the 8-thread daemon pool that drives the test service publishers;
     * threads are named test-fanout-record-publisher-NNNN for log readability.
     */
    private ExecutorService getTestExecutor() {
        return Executors.newFixedThreadPool(8,
                new ThreadFactoryBuilder()
                        .setDaemon(true)
                        .setNameFormat("test-fanout-record-publisher-%04d")
                        .build());
    }

    /** Wraps the executor in a Mockito spy so individual execute() calls can be stubbed. */
    private ExecutorService getSpiedExecutor(ExecutorService executorService) {
        final ExecutorService spiedExecutor = spy(executorService);
        return spiedExecutor;
    }

    /** Stubs the spied executor so every execute() runs its task inline on the calling thread. */
    private ExecutorService getBlockingExecutor(ExecutorService executorService) {
        doAnswer(this::directlyExecuteRunnable).when(executorService).execute(any());
        return executorService;
    }

    /**
     * Stubs the spied executor so the first two execute() calls run inline on the
     * calling thread, and all subsequent calls fall through to the real pool.
     */
    private ExecutorService getInitiallyBlockingExecutor(ExecutorService executorService) {
        doAnswer(this::directlyExecuteRunnable)
                .doAnswer(this::directlyExecuteRunnable)
                .doCallRealMethod()
                .when(executorService).execute(any());
        return executorService;
    }

    /**
     * Stubs the spied executor so the first three execute() calls run inline, the fourth
     * is rejected (simulating an overwhelmed pool), and the fifth onward runs inline again.
     */
    private ExecutorService getOverwhelmedBlockingExecutor(ExecutorService executorService) {
        doAnswer(this::directlyExecuteRunnable)
                .doAnswer(this::directlyExecuteRunnable)
                .doAnswer(this::directlyExecuteRunnable)
                .doThrow(new RejectedExecutionException())
                .doAnswer(this::directlyExecuteRunnable)
                .when(executorService).execute(any());
        return executorService;
    }

    /**
     * Mockito Answer body: runs the Runnable passed to execute() synchronously on the
     * invoking thread. Always returns null, since execute() is void.
     */
    private Object directlyExecuteRunnable(InvocationOnMock invocation) {
        ((Runnable) invocation.getArguments()[0]).run();
        return null;
    }

    /**
     * Requests three events up front (demand of 3) and verifies that all three delivered
     * batches arrive intact, each containing the expected records.
     */
    @Test
    public void largeRequestTest() throws Exception {
        FanOutRecordsPublisher source = new FanOutRecordsPublisher(kinesisClient, SHARD_ID, CONSUMER_ARN);

        // Capture the internal subscription/flow objects handed to the mocked SDK client.
        ArgumentCaptor<FanOutRecordsPublisher.RecordSubscription> captor = ArgumentCaptor
                .forClass(FanOutRecordsPublisher.RecordSubscription.class);
        ArgumentCaptor<FanOutRecordsPublisher.RecordFlow> flowCaptor = ArgumentCaptor
                .forClass(FanOutRecordsPublisher.RecordFlow.class);

        doNothing().when(publisher).subscribe(captor.capture());

        source.start(ExtendedSequenceNumber.LATEST,
                InitialPositionInStreamExtended.newInitialPosition(InitialPositionInStream.LATEST));

        List<ProcessRecordsInput> receivedInput = new ArrayList<>();

        source.subscribe(new ShardConsumerNotifyingSubscriber(new Subscriber<RecordsRetrieved>() {
            Subscription subscription;

            @Override public void onSubscribe(Subscription s) {
                subscription = s;
                // Larger-than-one initial request exercises the publisher's demand tracking.
                subscription.request(3);
            }

            @Override public void onNext(RecordsRetrieved input) {
                receivedInput.add(input.processRecordsInput());
                subscription.request(1);
            }

            @Override public void onError(Throwable t) {
                log.error("Caught throwable in subscriber", t);
                fail("Caught throwable in subscriber");
            }

            @Override public void onComplete() {
                fail("OnComplete called when not expected");
            }
        }, source));

        verify(kinesisClient).subscribeToShard(any(SubscribeToShardRequest.class), flowCaptor.capture());
        flowCaptor.getValue().onEventStream(publisher);
        captor.getValue().onSubscribe(subscription);

        List<Record> records = Stream.of(1, 2, 3).map(this::makeRecord).collect(Collectors.toList());
        List<KinesisClientRecordMatcher> matchers = records.stream().map(KinesisClientRecordMatcher::new)
                .collect(Collectors.toList());

        batchEvent = SubscribeToShardEvent.builder().millisBehindLatest(100L).records(records).build();

        // Deliver three identical batches from the mocked service side.
        captor.getValue().onNext(batchEvent);
        captor.getValue().onNext(batchEvent);
        captor.getValue().onNext(batchEvent);

        // NOTE(review): times(4) presumably counts the publisher's request(1) calls on
        // the mocked upstream SDK subscription -- confirm against RecordSubscription.
        verify(subscription, times(4)).request(1);
        assertThat(receivedInput.size(), equalTo(3));

        // Every received batch must contain the same three records, in order.
        receivedInput.stream().map(ProcessRecordsInput::records).forEach(clientRecordsList -> {
            assertThat(clientRecordsList.size(), equalTo(matchers.size()));
            for (int i = 0; i < clientRecordsList.size(); ++i) {
                assertThat(clientRecordsList.get(i), matchers.get(i));
            }
        });

    }

    @Test
    public void testResourceNotFoundForShard() {
        // A ResourceNotFoundException wrapped by the flow means the shard no longer
        // exists: the publisher should emit a final empty at-shard-end batch and then
        // complete normally, rather than signalling an error.
        FanOutRecordsPublisher source = new FanOutRecordsPublisher(kinesisClient, SHARD_ID, CONSUMER_ARN);

        ArgumentCaptor<FanOutRecordsPublisher.RecordFlow> recordFlowCaptor =
                ArgumentCaptor.forClass(FanOutRecordsPublisher.RecordFlow.class);
        ArgumentCaptor<RecordsRetrieved> retrievedCaptor = ArgumentCaptor.forClass(RecordsRetrieved.class);

        source.subscribe(subscriber);

        verify(kinesisClient).subscribeToShard(any(SubscribeToShardRequest.class), recordFlowCaptor.capture());
        recordFlowCaptor.getValue()
                .exceptionOccurred(new RuntimeException(ResourceNotFoundException.builder().build()));

        verify(subscriber).onSubscribe(any());
        verify(subscriber, never()).onError(any());
        verify(subscriber).onNext(retrievedCaptor.capture());
        verify(subscriber).onComplete();

        ProcessRecordsInput lastBatch = retrievedCaptor.getValue().processRecordsInput();
        assertThat(lastBatch.isAtShardEnd(), equalTo(true));
        assertThat(lastBatch.records().isEmpty(), equalTo(true));
    }

    @Test
    public void testReadTimeoutExceptionForShard() {
        // A ReadTimeoutException from the connection is retryable: the publisher must
        // surface it as RetryableRetrievalException, emitting no records and never completing.
        FanOutRecordsPublisher source = new FanOutRecordsPublisher(kinesisClient, SHARD_ID, CONSUMER_ARN);

        ArgumentCaptor<FanOutRecordsPublisher.RecordFlow> recordFlowCaptor =
                ArgumentCaptor.forClass(FanOutRecordsPublisher.RecordFlow.class);

        source.subscribe(subscriber);

        verify(kinesisClient).subscribeToShard(any(SubscribeToShardRequest.class), recordFlowCaptor.capture());
        recordFlowCaptor.getValue().exceptionOccurred(new RuntimeException(ReadTimeoutException.INSTANCE));

        verify(subscriber).onSubscribe(any());
        verify(subscriber).onError(any(RetryableRetrievalException.class));
        verify(subscriber, never()).onNext(any());
        verify(subscriber, never()).onComplete();
    }

    /**
     * Verifies that after a flow completes with continuation sequence number "3", the
     * publisher automatically resubscribes AFTER_SEQUENCE_NUMBER 3 and continues
     * delivering records on the new flow without losing or reordering batches.
     */
    @Test
    public void testContinuesAfterSequence() {
        FanOutRecordsPublisher source = new FanOutRecordsPublisher(kinesisClient, SHARD_ID, CONSUMER_ARN);

        ArgumentCaptor<FanOutRecordsPublisher.RecordSubscription> captor = ArgumentCaptor
                .forClass(FanOutRecordsPublisher.RecordSubscription.class);
        ArgumentCaptor<FanOutRecordsPublisher.RecordFlow> flowCaptor = ArgumentCaptor
                .forClass(FanOutRecordsPublisher.RecordFlow.class);

        doNothing().when(publisher).subscribe(captor.capture());

        source.start(new ExtendedSequenceNumber("0"),
                InitialPositionInStreamExtended.newInitialPosition(InitialPositionInStream.LATEST));

        NonFailingSubscriber nonFailingSubscriber = new NonFailingSubscriber();

        source.subscribe(new ShardConsumerNotifyingSubscriber(nonFailingSubscriber, source));

        // The first subscription must start AT the configured sequence number "0".
        SubscribeToShardRequest expected = SubscribeToShardRequest.builder().consumerARN(CONSUMER_ARN).shardId(SHARD_ID)
                .startingPosition(StartingPosition.builder().sequenceNumber("0")
                        .type(ShardIteratorType.AT_SEQUENCE_NUMBER).build())
                .build();

        verify(kinesisClient).subscribeToShard(argThat(new SubscribeToShardRequestMatcher(expected)), flowCaptor.capture());

        flowCaptor.getValue().onEventStream(publisher);
        captor.getValue().onSubscribe(subscription);

        List<Record> records = Stream.of(1, 2, 3).map(this::makeRecord).collect(Collectors.toList());
        List<KinesisClientRecordMatcher> matchers = records.stream().map(KinesisClientRecordMatcher::new)
                .collect(Collectors.toList());

        batchEvent = SubscribeToShardEvent.builder().millisBehindLatest(100L).records(records)
                .continuationSequenceNumber("3").build();

        // Deliver one batch, then complete the first flow at sequence number 3.
        captor.getValue().onNext(batchEvent);
        captor.getValue().onComplete();
        flowCaptor.getValue().complete();

        // Fresh captors for the second (resubscribed) flow.
        ArgumentCaptor<FanOutRecordsPublisher.RecordSubscription> nextSubscribeCaptor = ArgumentCaptor
                .forClass(FanOutRecordsPublisher.RecordSubscription.class);
        ArgumentCaptor<FanOutRecordsPublisher.RecordFlow> nextFlowCaptor = ArgumentCaptor
                .forClass(FanOutRecordsPublisher.RecordFlow.class);

        // The resubscription must start AFTER the continuation sequence number "3".
        SubscribeToShardRequest nextExpected = SubscribeToShardRequest.builder().consumerARN(CONSUMER_ARN)
                .shardId(SHARD_ID).startingPosition(StartingPosition.builder().sequenceNumber("3")
                        .type(ShardIteratorType.AFTER_SEQUENCE_NUMBER).build())
                .build();

        verify(kinesisClient).subscribeToShard(argThat(new SubscribeToShardRequestMatcher(nextExpected)), nextFlowCaptor.capture());
        reset(publisher);
        doNothing().when(publisher).subscribe(nextSubscribeCaptor.capture());

        nextFlowCaptor.getValue().onEventStream(publisher);
        nextSubscribeCaptor.getValue().onSubscribe(subscription);

        List<Record> nextRecords = Stream.of(4, 5, 6).map(this::makeRecord).collect(Collectors.toList());
        List<KinesisClientRecordMatcher> nextMatchers = nextRecords.stream().map(KinesisClientRecordMatcher::new)
                .collect(Collectors.toList());

        batchEvent = SubscribeToShardEvent.builder().millisBehindLatest(100L).records(nextRecords)
                .continuationSequenceNumber("6").build();
        nextSubscribeCaptor.getValue().onNext(batchEvent);

        verify(subscription, times(4)).request(1);

        // Exactly one batch from each flow, with the expected record contents.
        assertThat(nonFailingSubscriber.received.size(), equalTo(2));

        verifyRecords(nonFailingSubscriber.received.get(0).records(), matchers);
        verifyRecords(nonFailingSubscriber.received.get(1).records(), nextMatchers);

    }

    @Test
    public void testIfBufferingRecordsWithinCapacityPublishesOneEvent() {
        // With no acks to drive eviction, buffering up to capacity should deliver only
        // the first event; the remainder stay queued awaiting an acknowledgement.
        FanOutRecordsPublisher fanOutRecordsPublisher = new FanOutRecordsPublisher(kinesisClient, SHARD_ID, CONSUMER_ARN);
        RecordsRetrieved recordsRetrieved = ProcessRecordsInput.builder()::build;
        FanOutRecordsPublisher.RecordFlow recordFlow =
                new FanOutRecordsPublisher.RecordFlow(fanOutRecordsPublisher, Instant.now(), "shard-001-001");
        final int[] deliveredCount = { 0 };
        fanOutRecordsPublisher.subscribe(new Subscriber<RecordsRetrieved>() {
            @Override public void onSubscribe(Subscription subscription) {}
            @Override public void onNext(RecordsRetrieved recordsRetrieved) {
                deliveredCount[0]++;
            }
            @Override public void onError(Throwable throwable) {}
            @Override public void onComplete() {}
        });
        for (int i = 1; i <= 10; i++) {
            fanOutRecordsPublisher.bufferCurrentEventAndScheduleIfRequired(recordsRetrieved, recordFlow);
        }
        assertEquals(1, deliveredCount[0]);
    }

    @Test
    public void testIfBufferingRecordsOverCapacityPublishesOneEventAndThrows() {
        // Buffering past capacity should still deliver only the first event and then
        // reject further buffering with an IllegalStateException("Queue full").
        FanOutRecordsPublisher fanOutRecordsPublisher = new FanOutRecordsPublisher(kinesisClient, SHARD_ID, CONSUMER_ARN);
        RecordsRetrieved recordsRetrieved = ProcessRecordsInput.builder()::build;
        FanOutRecordsPublisher.RecordFlow recordFlow =
                new FanOutRecordsPublisher.RecordFlow(fanOutRecordsPublisher, Instant.now(), "shard-001");
        final int[] deliveredCount = { 0 };
        fanOutRecordsPublisher.subscribe(new Subscriber<RecordsRetrieved>() {
            @Override public void onSubscribe(Subscription subscription) {}
            @Override public void onNext(RecordsRetrieved recordsRetrieved) {
                deliveredCount[0]++;
            }
            @Override public void onError(Throwable throwable) {}
            @Override public void onComplete() {}
        });
        try {
            for (int i = 1; i <= 12; i++) {
                fanOutRecordsPublisher.bufferCurrentEventAndScheduleIfRequired(recordsRetrieved, recordFlow);
            }
            fail("Should throw Queue full exception");
        } catch (IllegalStateException e) {
            assertEquals("Queue full", e.getMessage());
        }
        assertEquals(1, deliveredCount[0]);
    }

    /**
     * Verifies that the publisher delivers every buffered event when each one is acked
     * immediately, i.e. the delivery queue never holds more than one element.
     */
    @Test
    public void testIfPublisherAlwaysPublishesWhenQueueIsEmpty() {
        FanOutRecordsPublisher fanOutRecordsPublisher = new FanOutRecordsPublisher(kinesisClient, SHARD_ID, CONSUMER_ARN);
        FanOutRecordsPublisher.RecordFlow recordFlow =
                new FanOutRecordsPublisher.RecordFlow(fanOutRecordsPublisher, Instant.now(), "shard-001");
        final int[] totalRecordsRetrieved = { 0 };
        fanOutRecordsPublisher.subscribe(new Subscriber<RecordsRetrieved>() {
            @Override public void onSubscribe(Subscription subscription) {}
            @Override public void onNext(RecordsRetrieved recordsRetrieved) {
                totalRecordsRetrieved[0]++;
                // This makes sure the queue is immediately made empty, so that the next event enqueued will
                // be the only element in the queue.
                fanOutRecordsPublisher
                        .evictAckedEventAndScheduleNextEvent(() -> recordsRetrieved.batchUniqueIdentifier());
            }
            @Override public void onError(Throwable throwable) {}
            @Override public void onComplete() {}
        });
        // 137 is an arbitrary event count; every buffered event should be delivered.
        IntStream.rangeClosed(1, 137).forEach(i -> fanOutRecordsPublisher.bufferCurrentEventAndScheduleIfRequired(
                new FanOutRecordsPublisher.FanoutRecordsRetrieved(ProcessRecordsInput.builder().build(), i + "", recordFlow.getSubscribeToShardId()),
                recordFlow));
        assertEquals(137, totalRecordsRetrieved[0]);
    }

    /**
     * Verifies that acks carrying an unknown (stale) flow identifier are ignored:
     * delivery continues on the active flow and every event is still received.
     */
    @Test
    public void testIfPublisherIgnoresStaleEventsAndContinuesWithNextFlow() {
        FanOutRecordsPublisher fanOutRecordsPublisher = new FanOutRecordsPublisher(kinesisClient, SHARD_ID, CONSUMER_ARN);
        FanOutRecordsPublisher.RecordFlow recordFlow =
                new FanOutRecordsPublisher.RecordFlow(fanOutRecordsPublisher, Instant.now(), "shard-001");
        final int[] totalRecordsRetrieved = { 0 };
        fanOutRecordsPublisher.subscribe(new Subscriber<RecordsRetrieved>() {
            @Override public void onSubscribe(Subscription subscription) {}
            @Override public void onNext(RecordsRetrieved recordsRetrieved) {
                totalRecordsRetrieved[0]++;
                // This makes sure the queue is immediately made empty, so that the next event enqueued will
                // be the only element in the queue.
                fanOutRecordsPublisher
                        .evictAckedEventAndScheduleNextEvent(() -> recordsRetrieved.batchUniqueIdentifier());
                // Send stale event periodically
                if(totalRecordsRetrieved[0] % 10 == 0) {
                    fanOutRecordsPublisher.evictAckedEventAndScheduleNextEvent(
                            () -> new BatchUniqueIdentifier("some_uuid_str", "some_old_flow"));
                }
            }
            @Override public void onError(Throwable throwable) {}
            @Override public void onComplete() {}
        });
        // Stale acks are interleaved every 10th event; the delivered count must be unaffected.
        IntStream.rangeClosed(1, 100).forEach(i -> fanOutRecordsPublisher.bufferCurrentEventAndScheduleIfRequired(
                new FanOutRecordsPublisher.FanoutRecordsRetrieved(ProcessRecordsInput.builder().build(), i + "", recordFlow.getSubscribeToShardId()),
                recordFlow));
        assertEquals(100, totalRecordsRetrieved[0]);
    }

    /**
     * Verifies that stale acks (for an unknown flow) are ignored even while the delivery
     * queue still holds pending events: each good ack evicts its event and schedules the
     * next one, and all 10 buffered events are ultimately delivered.
     */
    @Test
    public void testIfPublisherIgnoresStaleEventsAndContinuesWithNextFlowWhenDeliveryQueueIsNotEmpty()
            throws InterruptedException {
        FanOutRecordsPublisher fanOutRecordsPublisher = new FanOutRecordsPublisher(kinesisClient, SHARD_ID, CONSUMER_ARN);
        FanOutRecordsPublisher.RecordFlow recordFlow =
                new FanOutRecordsPublisher.RecordFlow(fanOutRecordsPublisher, Instant.now(), "shard-001");
        final int[] totalRecordsRetrieved = { 0 };
        BlockingQueue<BatchUniqueIdentifier> ackQueue = new LinkedBlockingQueue<>();
        fanOutRecordsPublisher.subscribe(new Subscriber<RecordsRetrieved>() {
            @Override public void onSubscribe(Subscription subscription) {}
            @Override public void onNext(RecordsRetrieved recordsRetrieved) {
                totalRecordsRetrieved[0]++;
                // Enqueue the ack instead of acking inline, so the delivery queue fills up.
                ackQueue.add(recordsRetrieved.batchUniqueIdentifier());
            }
            @Override public void onError(Throwable throwable) {}
            @Override public void onComplete() {}
        });
        // Buffer 10 events; only the first is delivered until acks start arriving.
        IntStream.rangeClosed(1, 10).forEach(i -> fanOutRecordsPublisher.bufferCurrentEventAndScheduleIfRequired(
                new FanOutRecordsPublisher.FanoutRecordsRetrieved(ProcessRecordsInput.builder().build(), i + "", recordFlow.getSubscribeToShardId()),
                recordFlow));
        BatchUniqueIdentifier batchUniqueIdentifierQueued;
        int count = 0;
        // Now that we allowed upto 10 elements queued up, send a pair of good and stale ack to verify records
        // delivered as expected. Use a bounded poll rather than take() so a delivery bug
        // fails this test with a clear assertion instead of hanging the suite forever.
        while(count++ < 10
                && (batchUniqueIdentifierQueued = ackQueue.poll(1000, TimeUnit.MILLISECONDS)) != null) {
            final BatchUniqueIdentifier batchUniqueIdentifierFinal = batchUniqueIdentifierQueued;
            fanOutRecordsPublisher
                    .evictAckedEventAndScheduleNextEvent(() -> batchUniqueIdentifierFinal);
            fanOutRecordsPublisher.evictAckedEventAndScheduleNextEvent(
                    () -> new BatchUniqueIdentifier("some_uuid_str", "some_old_flow"));
        }
        assertEquals(10, totalRecordsRetrieved[0]);
    }

    /**
     * Verifies that an ack whose batch identifier does not match the head of the queue,
     * but which claims to belong to the ACTIVE flow, is treated as an illegal state
     * rather than being silently ignored like a stale-flow ack.
     */
    @Test(expected = IllegalStateException.class)
    public void testIfPublisherThrowsWhenMismatchAckforActiveFlowSeen() throws InterruptedException {
        FanOutRecordsPublisher fanOutRecordsPublisher = new FanOutRecordsPublisher(kinesisClient, SHARD_ID, CONSUMER_ARN);
        FanOutRecordsPublisher.RecordFlow recordFlow =
                new FanOutRecordsPublisher.RecordFlow(fanOutRecordsPublisher, Instant.now(), "Shard-001-1");
        final int[] totalRecordsRetrieved = { 0 };
        BlockingQueue<BatchUniqueIdentifier> ackQueue = new LinkedBlockingQueue<>();
        fanOutRecordsPublisher.subscribe(new Subscriber<RecordsRetrieved>() {
            @Override public void onSubscribe(Subscription subscription) {}
            @Override public void onNext(RecordsRetrieved recordsRetrieved) {
                totalRecordsRetrieved[0]++;
                // Enqueue the ack for bursty delivery
                ackQueue.add(recordsRetrieved.batchUniqueIdentifier());
                // Send stale event periodically
            }
            @Override public void onError(Throwable throwable) {}
            @Override public void onComplete() {}
        });
        // Buffer 10 events so the delivery queue is non-empty when the bad ack arrives.
        IntStream.rangeClosed(1, 10).forEach(i -> fanOutRecordsPublisher.bufferCurrentEventAndScheduleIfRequired(
                new FanOutRecordsPublisher.FanoutRecordsRetrieved(ProcessRecordsInput.builder().build(), i + "", recordFlow.getSubscribeToShardId()),
                recordFlow));
        BatchUniqueIdentifier batchUniqueIdentifierQueued;
        int count = 0;
        // Now that we allowed upto 10 elements queued up, send a pair of good and stale ack to verify records
        // delivered as expected.
        // The ack reuses the ACTIVE flow identifier with a mismatched UUID -- must throw.
        while(count++ < 2 && (batchUniqueIdentifierQueued = ackQueue.poll(1000, TimeUnit.MILLISECONDS)) != null) {
            final BatchUniqueIdentifier batchUniqueIdentifierFinal = batchUniqueIdentifierQueued;
            fanOutRecordsPublisher.evictAckedEventAndScheduleNextEvent(
                    () -> new BatchUniqueIdentifier("some_uuid_str", batchUniqueIdentifierFinal.getFlowIdentifier()));
        }
    }

    @Test
    public void acquireTimeoutTriggersLogMethodForActiveFlow() {
        // Flipped by the overridden hook below when the acquire-timeout message is logged.
        final AtomicBoolean timeoutMessageLogged = new AtomicBoolean(false);
        FanOutRecordsPublisher source = new FanOutRecordsPublisher(kinesisClient, SHARD_ID, CONSUMER_ARN) {
            @Override
            protected void logAcquireTimeoutMessage(Throwable t) {
                super.logAcquireTimeoutMessage(t);
                timeoutMessageLogged.set(true);
            }
        };

        final ArgumentCaptor<FanOutRecordsPublisher.RecordSubscription> subscriptionCaptor =
                ArgumentCaptor.forClass(FanOutRecordsPublisher.RecordSubscription.class);
        final ArgumentCaptor<FanOutRecordsPublisher.RecordFlow> flowCaptor =
                ArgumentCaptor.forClass(FanOutRecordsPublisher.RecordFlow.class);

        doNothing().when(publisher).subscribe(subscriptionCaptor.capture());

        // Start the publisher and attach a subscriber so a flow becomes active.
        source.start(ExtendedSequenceNumber.LATEST,
                InitialPositionInStreamExtended.newInitialPosition(InitialPositionInStream.LATEST));
        RecordingSubscriber subscriber = new RecordingSubscriber();
        source.subscribe(subscriber);

        verify(kinesisClient).subscribeToShard(any(SubscribeToShardRequest.class), flowCaptor.capture());

        // The SDK's acquire-timeout diagnostic text; the publisher matches on this message.
        final String acquireTimeoutDetail =
                "Acquire operation took longer than the configured maximum time. This indicates that a " +
                        "request cannot get a connection from the pool within the specified maximum time. " +
                        "This can be due to high request rate.\n" +
                        "Consider taking any of the following actions to mitigate the issue: increase max " +
                        "connections, increase acquire timeout, or slowing the request rate.\n" +
                        "Increasing the max connections can increase client throughput (unless the network " +
                        "interface is already fully utilized), but can eventually start to hit operation " +
                        "system limitations on the number of file descriptors used by the process. " +
                        "If you already are fully utilizing your network interface or cannot further " +
                        "increase your connection count, increasing the acquire timeout gives extra time " +
                        "for requests to acquire a connection before timing out. " +
                        "If the connections doesn't free up, the subsequent requests will still timeout.\n" +
                        "If the above mechanisms are not able to fix the issue, try smoothing out your " +
                        "requests so that large traffic bursts cannot overload the client, being more " +
                        "efficient with the number of times you need to call AWS, or by increasing the " +
                        "number of hosts sending requests.";
        Throwable exception = new CompletionException(
                "software.amazon.awssdk.core.exception.SdkClientException",
                SdkClientException.create(null, new Throwable(acquireTimeoutDetail)));

        // Fail the active flow with the acquire-timeout exception.
        flowCaptor.getValue().exceptionOccurred(exception);

        // The error must reach the subscriber AND the dedicated log hook must have fired.
        Optional<OnErrorEvent> propagatedError = subscriber.events.stream()
                .filter(e -> e instanceof OnErrorEvent)
                .map(e -> (OnErrorEvent) e)
                .findFirst();

        assertThat(propagatedError, equalTo(Optional.of(new OnErrorEvent(exception))));
        assertThat(timeoutMessageLogged.get(), equalTo(true));
    }

    /**
     * Asserts that the retrieved records correspond one-to-one, in order, with the given matchers.
     */
    private void verifyRecords(List<KinesisClientRecord> clientRecordsList, List<KinesisClientRecordMatcher> matchers) {
        // Size check first so a length mismatch fails with a clear message before any pairwise check.
        assertThat(clientRecordsList.size(), equalTo(matchers.size()));
        int position = 0;
        for (KinesisClientRecord record : clientRecordsList) {
            assertThat(record, matchers.get(position));
            position++;
        }
    }

    // Marker interface for the signals captured by RecordingSubscriber (subscribe/next/error/
    // complete/request); carries no behavior of its own.
    private interface SubscriberEvent {

    }

    // Recorded when onSubscribe is delivered; holds the Subscription that was handed over.
    @Data
    private static class SubscribeEvent implements SubscriberEvent {
        final Subscription subscription;
    }

    // Recorded for each onNext delivery; holds the batch that was received.
    @Data
    private static class OnNextEvent implements SubscriberEvent {
        final RecordsRetrieved recordsRetrieved;
    }

    // Recorded when onError is delivered; Lombok equals() compares throwables by reference
    // equality of the Throwable field, which the tests rely on when matching the exact exception.
    @Data
    private static class OnErrorEvent implements SubscriberEvent {
        final Throwable throwable;
    }

    // Recorded when onComplete is delivered; stateless, all instances are equal.
    @Data
    private static class OnCompleteEvent implements SubscriberEvent {

    }

    // Recorded each time the subscriber issues demand upstream; holds the requested amount.
    @Data
    private static class RequestEvent implements SubscriberEvent {
        final long requested;
    }

    /**
     * Subscriber that records every Reactive Streams signal it receives as a
     * {@link SubscriberEvent}, requesting exactly one element at a time and recording each
     * request as a {@link RequestEvent} so tests can assert on the full signal sequence.
     */
    private static class RecordingSubscriber implements Subscriber<RecordsRetrieved> {

        final List<SubscriberEvent> events = new LinkedList<>();

        Subscription subscription;

        @Override
        public void onSubscribe(Subscription s) {
            events.add(new SubscribeEvent(s));
            subscription = s;
            requestOneAndRecord();
        }

        @Override
        public void onNext(RecordsRetrieved recordsRetrieved) {
            events.add(new OnNextEvent(recordsRetrieved));
            requestOneAndRecord();
        }

        @Override
        public void onError(Throwable t) {
            events.add(new OnErrorEvent(t));
        }

        @Override
        public void onComplete() {
            events.add(new OnCompleteEvent());
        }

        // Issues single-item demand upstream, then records that the demand was made
        // (request before record, matching the original event ordering).
        private void requestOneAndRecord() {
            subscription.request(1);
            events.add(new RequestEvent(1));
        }
    }

    /**
     * Subscriber that accumulates delivered batches and treats any terminal signal
     * (error or completion) as a test failure. Demand is issued one element at a time.
     */
    private static class NonFailingSubscriber implements Subscriber<RecordsRetrieved> {
        final List<ProcessRecordsInput> received = new ArrayList<>();
        Subscription subscription;

        @Override
        public void onSubscribe(Subscription s) {
            this.subscription = s;
            this.subscription.request(1);
        }

        @Override
        public void onNext(RecordsRetrieved input) {
            // Keep only the payload; tests assert on the ProcessRecordsInput sequence.
            received.add(input.processRecordsInput());
            this.subscription.request(1);
        }

        @Override
        public void onError(Throwable t) {
            // Log first so the stack trace is available in the test output, then fail.
            log.error("Caught throwable in subscriber", t);
            fail("Caught throwable in subscriber");
        }

        @Override
        public void onComplete() {
            fail("OnComplete called when not expected");
        }
    }

    /**
     * Simulates the service pushing {@code numOfTimes} events while honoring Reactive Streams
     * backpressure: each delivery consumes one permit from {@code demandNotifier}, and tests grant
     * further demand via {@link #request(int)}. Optional triggers inject a completion (optionally
     * via a shard-end action) or an error after a configured event number, after which publishing
     * stops. {@code taskCompletionLatch} is counted down when the run loop exits.
     */
    @RequiredArgsConstructor
    private static class BackpressureAdheringServicePublisher implements Runnable {

        private final Consumer<Integer> action;
        private final Integer numOfTimes;
        private final CountDownLatch taskCompletionLatch;
        private final Semaphore demandNotifier;
        // Default the trigger positions to "never". The Lombok-generated constructor (Semaphore
        // variant) previously left these null, so run() would NPE on unboxing in the
        // `i == sendCompletionAt` comparison whenever no trigger was configured.
        private Integer sendCompletionAt = Integer.MAX_VALUE;
        private Runnable completeAction;
        private Integer sendErrorAt = Integer.MAX_VALUE;
        private Runnable errorAction;
        private Consumer<Integer> shardEndAction;

        BackpressureAdheringServicePublisher(Consumer<Integer> action, Integer numOfTimes,
                CountDownLatch taskCompletionLatch, Integer initialDemand) {
            // Field initializers already set the trigger positions to Integer.MAX_VALUE,
            // so both construction paths start with triggers disabled.
            this(action, numOfTimes, taskCompletionLatch, new Semaphore(initialDemand));
        }

        /** Grants demand for {@code n} more events by releasing {@code n} permits. */
        public void request(int n) {
            demandNotifier.release(n);
        }

        public void run() {
            for (int i = 1; i <= numOfTimes; ) {
                // Block until the consumer has demanded at least one more event.
                demandNotifier.acquireUninterruptibly();
                if (i == sendCompletionAt) {
                    // Deliver the final event (via the shard-end action when configured),
                    // signal completion, and stop publishing.
                    if (shardEndAction != null) {
                        shardEndAction.accept(i++);
                    } else {
                        action.accept(i++);
                    }
                    completeAction.run();
                    break;
                }
                if (i == sendErrorAt) {
                    // Deliver the event at the error position, then signal the error and stop.
                    action.accept(i++);
                    errorAction.run();
                    break;
                }
                action.accept(i++);
            }
            taskCompletionLatch.countDown();
        }

        /** Arranges for {@code completeAction} to run after the event numbered {@code sendCompletionAt}. */
        public void setCompleteTrigger(Integer sendCompletionAt, Runnable completeAction) {
            this.sendCompletionAt = sendCompletionAt;
            this.completeAction = completeAction;
        }

        /** Like {@link #setCompleteTrigger}, but the final event is delivered via {@code shardEndAction}. */
        public void setShardEndAndCompleteTrigger(Integer sendCompletionAt, Runnable completeAction, Consumer<Integer> shardEndAction) {
            setCompleteTrigger(sendCompletionAt, completeAction);
            this.shardEndAction = shardEndAction;
        }

        /** Arranges for {@code errorAction} to run after the event numbered {@code sendErrorAt}. */
        public void setErrorTrigger(Integer sendErrorAt, Runnable errorAction) {
            this.sendErrorAt = sendErrorAt;
            this.errorAction = errorAction;
        }
    }

    // Convenience overload: accepts the sequence number as a decimal string and delegates.
    private Record makeRecord(String sequenceNumber) {
        final int parsedSequenceNumber = Integer.parseInt(sequenceNumber);
        return makeRecord(parsedSequenceNumber);
    }

    // Builds a minimal Kinesis record: fixed 3-byte payload, partition key "A",
    // the given sequence number, and an arrival timestamp of "now".
    private Record makeRecord(int sequenceNumber) {
        final SdkBytes payload = SdkBytes.fromByteArray(new byte[] { 1, 2, 3 });
        return Record.builder()
                .partitionKey("A")
                .sequenceNumber(Integer.toString(sequenceNumber))
                .approximateArrivalTimestamp(Instant.now())
                .data(payload)
                .build();
    }

    /**
     * Hamcrest matcher that compares a {@link KinesisClientRecord} field-by-field against one
     * derived from an expected {@link Record}, reporting EVERY mismatching field in the
     * diagnostic output rather than stopping at the first.
     */
    private static class KinesisClientRecordMatcher extends TypeSafeDiagnosingMatcher<KinesisClientRecord> {

        private final KinesisClientRecord expected;
        private final Matcher<String> partitionKeyMatcher;
        private final Matcher<String> sequenceNumberMatcher;
        private final Matcher<Instant> approximateArrivalMatcher;
        private final Matcher<ByteBuffer> dataMatcher;

        public KinesisClientRecordMatcher(Record record) {
            expected = KinesisClientRecord.fromRecord(record);
            partitionKeyMatcher = equalTo(expected.partitionKey());
            sequenceNumberMatcher = equalTo(expected.sequenceNumber());
            approximateArrivalMatcher = equalTo(expected.approximateArrivalTimestamp());
            dataMatcher = equalTo(expected.data());
        }

        @Override
        protected boolean matchesSafely(KinesisClientRecord item, Description mismatchDescription) {
            // Accumulate with &= (no short-circuit) so that every failing field contributes
            // to the mismatch description, not just the first one encountered.
            boolean allFieldsMatch = true;
            allFieldsMatch &= matchAndDescribe(partitionKeyMatcher, item.partitionKey(),
                    "partitionKey", mismatchDescription);
            allFieldsMatch &= matchAndDescribe(sequenceNumberMatcher, item.sequenceNumber(),
                    "sequenceNumber", mismatchDescription);
            allFieldsMatch &= matchAndDescribe(approximateArrivalMatcher, item.approximateArrivalTimestamp(),
                    "approximateArrivalTimestamp", mismatchDescription);
            allFieldsMatch &= matchAndDescribe(dataMatcher, item.data(),
                    "data", mismatchDescription);
            return allFieldsMatch;
        }

        // Runs one field matcher; on mismatch, appends "<field>: <diagnosis>" and returns false.
        private <T> boolean matchAndDescribe(Matcher<T> matcher, T value, String field,
                Description mismatchDescription) {
            if (matcher.matches(value)) {
                return true;
            }
            mismatchDescription.appendText(field).appendText(": ");
            matcher.describeMismatch(value, mismatchDescription);
            return false;
        }

        @Override
        public void describeTo(Description description) {
            description.appendText("A kinesis client record with: ").appendText("PartitionKey: ")
                    .appendDescriptionOf(partitionKeyMatcher).appendText(" SequenceNumber: ")
                    .appendDescriptionOf(sequenceNumberMatcher).appendText(" Approximate Arrival Time: ")
                    .appendDescriptionOf(approximateArrivalMatcher).appendText(" Data: ")
                    .appendDescriptionOf(dataMatcher);
        }

    }

}