/*
 * Copyright 2015-2018 the original author or authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package io.rsocket.core;

import static io.rsocket.core.PayloadValidationUtils.INVALID_PAYLOAD_ERROR_MESSAGE;
import static io.rsocket.frame.FrameHeaderCodec.frameType;
import static io.rsocket.frame.FrameType.*;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.greaterThanOrEqualTo;
import static org.hamcrest.Matchers.hasSize;
import static org.hamcrest.Matchers.is;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.verify;

import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufAllocator;
import io.netty.buffer.ByteBufUtil;
import io.netty.buffer.Unpooled;
import io.netty.util.CharsetUtil;
import io.netty.util.IllegalReferenceCountException;
import io.netty.util.ReferenceCountUtil;
import io.netty.util.ReferenceCounted;
import io.rsocket.Payload;
import io.rsocket.RSocket;
import io.rsocket.TestScheduler;
import io.rsocket.exceptions.ApplicationErrorException;
import io.rsocket.exceptions.CustomRSocketException;
import io.rsocket.exceptions.RejectedSetupException;
import io.rsocket.frame.CancelFrameCodec;
import io.rsocket.frame.ErrorFrameCodec;
import io.rsocket.frame.FrameHeaderCodec;
import io.rsocket.frame.FrameLengthCodec;
import io.rsocket.frame.FrameType;
import io.rsocket.frame.PayloadFrameCodec;
import io.rsocket.frame.RequestChannelFrameCodec;
import io.rsocket.frame.RequestFireAndForgetFrameCodec;
import io.rsocket.frame.RequestNFrameCodec;
import io.rsocket.frame.RequestResponseFrameCodec;
import io.rsocket.frame.RequestStreamFrameCodec;
import io.rsocket.frame.decoder.PayloadDecoder;
import io.rsocket.internal.subscriber.AssertSubscriber;
import io.rsocket.lease.RequesterLeaseHandler;
import io.rsocket.test.util.TestSubscriber;
import io.rsocket.util.ByteBufPayload;
import io.rsocket.util.DefaultPayload;
import io.rsocket.util.EmptyPayload;
import java.nio.channels.ClosedChannelException;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.BiConsumer;
import java.util.function.BiFunction;
import java.util.function.Function;
import java.util.stream.Stream;
import org.assertj.core.api.Assertions;
import org.assertj.core.api.Assumptions;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.Timeout;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.junit.runners.model.Statement;
import org.reactivestreams.Publisher;
import org.reactivestreams.Subscriber;
import org.reactivestreams.Subscription;
import reactor.core.publisher.BaseSubscriber;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Hooks;
import reactor.core.publisher.Mono;
import reactor.core.publisher.MonoProcessor;
import reactor.core.publisher.UnicastProcessor;
import reactor.core.scheduler.Schedulers;
import reactor.test.StepVerifier;
import reactor.test.publisher.TestPublisher;
import reactor.test.util.RaceTestUtils;

public class RSocketRequesterTest {

  ClientSocketRule rule;

  /**
   * Per-test fixture. Installs global Reactor hooks (values dropped via {@code onNext} are
   * released with {@link ReferenceCountUtil#safeRelease}; dropped errors are ignored), then
   * creates a fresh {@link ClientSocketRule}, applies it to a no-op {@link Statement} and evaluates
   * the resulting statement.
   *
   * <p>NOTE(review): presumably evaluating the applied statement is what initializes the rule's
   * socket and connection — confirm against {@code ClientSocketRule}.
   */
  @BeforeEach
  public void setUp() throws Throwable {
    Hooks.onNextDropped(ReferenceCountUtil::safeRelease);
    Hooks.onErrorDropped(ignored -> {});

    rule = new ClientSocketRule();
    Statement noOp =
        new Statement() {
          @Override
          public void evaluate() {}
        };
    Statement applied = rule.apply(noOp, null);
    applied.evaluate();
  }

  /** Removes the global Reactor dropped-value and dropped-error hooks installed by setUp. */
  @AfterEach
  public void tearDown() {
    Hooks.resetOnNextDropped();
    Hooks.resetOnErrorDropped();
  }

  /**
   * A REQUEST_N frame addressed to stream 0 must not dispose the requester socket, and handling it
   * must not leak buffers.
   */
  @Test
  @Timeout(value = 2_000, unit = TimeUnit.MILLISECONDS) // JUnit 5 @Timeout defaults to SECONDS
  public void testInvalidFrameOnStream0ShouldNotTerminateRSocket() {
    rule.connection.addToReceivedBuffer(RequestNFrameCodec.encode(rule.alloc(), 0, 10));
    Assertions.assertThat(rule.socket.isDisposed()).isFalse();
    rule.assertHasNoLeaks();
  }

  /**
   * requestStream must send nothing until the subscriber requests. The first {@code request(5)}
   * must then produce exactly one REQUEST_STREAM frame whose initial request N is 5.
   */
  @Test
  @Timeout(value = 2_000, unit = TimeUnit.MILLISECONDS) // JUnit 5 @Timeout defaults to SECONDS
  public void testStreamInitialN() {
    Flux<Payload> stream = rule.socket.requestStream(EmptyPayload.INSTANCE);

    BaseSubscriber<Payload> subscriber =
        new BaseSubscriber<Payload>() {
          @Override
          protected void hookOnSubscribe(Subscription subscription) {
            // don't request here
          }
        };
    stream.subscribe(subscriber);

    // subscribing alone (no demand yet) must not put anything on the wire
    Assertions.assertThat(rule.connection.getSent()).isEmpty();

    subscriber.request(5);

    List<ByteBuf> sent = new ArrayList<>(rule.connection.getSent());

    assertThat("sent frame count", sent.size(), is(1));

    ByteBuf f = sent.get(0);

    assertThat("initial frame", frameType(f), is(REQUEST_STREAM));
    assertThat("initial request n", RequestStreamFrameCodec.initialRequestN(f), is(5L));
    assertThat("should be released", f.release(), is(true));
    rule.assertHasNoLeaks();
  }

  /**
   * An ERROR frame on stream 0 carrying a {@link RejectedSetupException} must terminate the
   * socket: {@code onClose()} has to fail with that exception type.
   */
  @Test
  @Timeout(value = 2_000, unit = TimeUnit.MILLISECONDS) // JUnit 5 @Timeout defaults to SECONDS
  public void testHandleSetupException() {
    rule.connection.addToReceivedBuffer(
        ErrorFrameCodec.encode(rule.alloc(), 0, new RejectedSetupException("boom")));
    Assertions.assertThatThrownBy(() -> rule.socket.onClose().block())
        .isInstanceOf(RejectedSetupException.class);
    rule.assertHasNoLeaks();
  }

  /**
   * An ERROR frame received on the request-response stream must reach the (Mockito-verified)
   * response subscriber as an {@link ApplicationErrorException}. Only the REQUEST_RESPONSE frame
   * may have been sent.
   */
  @Test
  @Timeout(value = 2_000, unit = TimeUnit.MILLISECONDS) // JUnit 5 @Timeout defaults to SECONDS
  public void testHandleApplicationException() {
    rule.connection.clearSendReceiveBuffers();
    Publisher<Payload> response = rule.socket.requestResponse(EmptyPayload.INSTANCE);
    Subscriber<Payload> responseSub = TestSubscriber.create();
    response.subscribe(responseSub);

    int streamId = rule.getStreamIdForRequestType(REQUEST_RESPONSE);
    rule.connection.addToReceivedBuffer(
        ErrorFrameCodec.encode(rule.alloc(), streamId, new ApplicationErrorException("error")));

    verify(responseSub).onError(any(ApplicationErrorException.class));

    Assertions.assertThat(rule.connection.getSent())
        // requestResponseFrame
        .hasSize(1)
        .allMatch(ReferenceCounted::release);

    rule.assertHasNoLeaks();
  }

  /**
   * A NEXT payload frame received on the request-response stream must complete the response
   * (onComplete is verified on the mock subscriber). Exactly one frame must have been sent.
   */
  @Test
  @Timeout(value = 2_000, unit = TimeUnit.MILLISECONDS) // JUnit 5 @Timeout defaults to SECONDS
  public void testHandleValidFrame() {
    Publisher<Payload> response = rule.socket.requestResponse(EmptyPayload.INSTANCE);
    Subscriber<Payload> sub = TestSubscriber.create();
    response.subscribe(sub);

    int streamId = rule.getStreamIdForRequestType(REQUEST_RESPONSE);
    rule.connection.addToReceivedBuffer(
        PayloadFrameCodec.encodeNextReleasingPayload(
            rule.alloc(), streamId, EmptyPayload.INSTANCE));

    verify(sub).onComplete();
    Assertions.assertThat(rule.connection.getSent()).hasSize(1).allMatch(ReferenceCounted::release);
    rule.assertHasNoLeaks();
  }

  /**
   * When a blocking request-response times out without a reply, the requester must have sent a
   * REQUEST_RESPONSE frame followed by a CANCEL frame, and nothing else.
   */
  @Test
  @Timeout(value = 2_000, unit = TimeUnit.MILLISECONDS) // JUnit 5 @Timeout defaults to SECONDS
  public void testRequestReplyWithCancel() {
    Mono<Payload> response = rule.socket.requestResponse(EmptyPayload.INSTANCE);

    try {
      response.block(Duration.ofMillis(100));
    } catch (IllegalStateException ignored) {
      // Expected: no reply is ever fed in, so block(Duration) gives up with an
      // IllegalStateException. Giving up cancels the request, which is what should emit CANCEL.
    }

    List<ByteBuf> sent = new ArrayList<>(rule.connection.getSent());

    // Check the count before indexing so a missing frame fails as an assertion, not an IOOBE.
    Assertions.assertThat(sent).hasSize(2);
    assertThat(
        "Unexpected frame sent on the connection.", frameType(sent.get(0)), is(REQUEST_RESPONSE));
    assertThat("Unexpected frame sent on the connection.", frameType(sent.get(1)), is(CANCEL));
    Assertions.assertThat(sent).allMatch(ReferenceCounted::release);
    rule.assertHasNoLeaks();
  }

  /**
   * Disabled ("invalid"). With the connection's availability set to 0, the test currently checks
   * only two things: {@code onClose()} completes within 100 ms, and the response subscriber
   * received {@code onSubscribe}. The intended error-propagation check is still a TODO below.
   */
  @Test
  @Disabled("invalid")
  @Timeout(value = 2_000, unit = TimeUnit.MILLISECONDS) // JUnit 5 @Timeout defaults to SECONDS
  public void testRequestReplyErrorOnSend() {
    rule.connection.setAvailability(0); // Fails send
    Mono<Payload> response = rule.socket.requestResponse(EmptyPayload.INSTANCE);
    Subscriber<Payload> responseSub = TestSubscriber.create(10);
    response.subscribe(responseSub);

    this.rule
        .socket
        .onClose()
        .as(StepVerifier::create)
        .expectComplete()
        .verify(Duration.ofMillis(100));

    verify(responseSub).onSubscribe(any(Subscription.class));

    rule.assertHasNoLeaks();
    // TODO this should get the error reported through the response subscription
    //    verify(responseSub).onError(any(RuntimeException.class));
  }

  /**
   * Disposing the subscription to a requestChannel response must cancel the outbound request
   * Flux within one second. Otherwise the delayed error branch wins and {@code blockFirst}
   * throws "Channel request not cancelled".
   */
  @Test
  @Timeout(value = 2_000, unit = TimeUnit.MILLISECONDS) // JUnit 5 @Timeout defaults to SECONDS
  public void testChannelRequestCancellation() {
    MonoProcessor<Void> cancelled = MonoProcessor.create();
    Flux<Payload> request = Flux.<Payload>never().doOnCancel(cancelled::onComplete);
    rule.socket.requestChannel(request).subscribe().dispose();
    Flux.first(
            cancelled,
            Flux.error(new IllegalStateException("Channel request not cancelled"))
                .delaySubscription(Duration.ofSeconds(1)))
        .blockFirst();
    rule.assertHasNoLeaks();
  }

  /**
   * Same as testChannelRequestCancellation, but with a finite outbound Flux: {@code just} plus
   * {@code repeat(259)} gives 260 empty payloads. Disposing the response must cancel the outbound
   * Flux within one second, and every frame sent so far must be releasable without leaks.
   */
  @Test
  @Timeout(value = 2_000, unit = TimeUnit.MILLISECONDS) // JUnit 5 @Timeout defaults to SECONDS
  public void testChannelRequestCancellation2() {
    MonoProcessor<Void> cancelled = MonoProcessor.create();
    Flux<Payload> request =
        Flux.<Payload>just(EmptyPayload.INSTANCE).repeat(259).doOnCancel(cancelled::onComplete);
    rule.socket.requestChannel(request).subscribe().dispose();
    Flux.first(
            cancelled,
            Flux.error(new IllegalStateException("Channel request not cancelled"))
                .delaySubscription(Duration.ofSeconds(1)))
        .blockFirst();
    Assertions.assertThat(rule.connection.getSent()).allMatch(ReferenceCounted::release);
    rule.assertHasNoLeaks();
  }

  /**
   * Server-side cancellation of a channel: after an inbound CANCEL and then a COMPLETE on the
   * channel's stream, three things must hold. The response must signal within one second, the
   * outbound {@link UnicastProcessor} must be disposed, and only the initial REQUEST_CHANNEL frame
   * may have been sent.
   */
  @Test
  public void testChannelRequestServerSideCancellation() {
    MonoProcessor<Payload> responseSignalled = MonoProcessor.create();
    UnicastProcessor<Payload> outbound = UnicastProcessor.create();
    outbound.onNext(EmptyPayload.INSTANCE);

    rule.socket.requestChannel(outbound).subscribe(responseSignalled);

    int channelStreamId = rule.getStreamIdForRequestType(REQUEST_CHANNEL);
    rule.connection.addToReceivedBuffer(CancelFrameCodec.encode(rule.alloc(), channelStreamId));
    rule.connection.addToReceivedBuffer(
        PayloadFrameCodec.encodeComplete(rule.alloc(), channelStreamId));

    // fails the test if the response has not signalled within one second
    Flux<Payload> watchdog =
        Flux.<Payload>error(new IllegalStateException("Channel request not cancelled"))
            .delaySubscription(Duration.ofSeconds(1));
    Flux.first(responseSignalled, watchdog).blockFirst();

    Assertions.assertThat(outbound.isDisposed()).isTrue();
    Assertions.assertThat(rule.connection.getSent())
        .hasSize(1)
        .first()
        .matches(frame -> frameType(frame) == REQUEST_CHANNEL)
        .matches(ReferenceCounted::release);
    rule.assertHasNoLeaks();
  }

  /**
   * The first outbound channel element ("0") is held back until {@code firstElementGate}
   * completes. Meanwhile the subscriber requests 1 and then Long.MAX_VALUE.
   *
   * <p>Exactly one frame must be sent: the REQUEST_CHANNEL frame carrying data "0" with an
   * initial request N of Long.MAX_VALUE. No separate REQUEST_N frame may follow it.
   */
  @Test
  public void testCorrectFrameOrder() {
    MonoProcessor<Object> firstElementGate = MonoProcessor.create();
    BaseSubscriber<Payload> manualSubscriber =
        new BaseSubscriber<Payload>() {
          @Override
          protected void hookOnSubscribe(Subscription subscription) {
            // demand is issued explicitly by the test body
          }
        };
    Flux<Payload> outbound =
        Flux.concat(Flux.just(0).delayUntil(i -> firstElementGate), Flux.range(1, 999))
            .map(i -> DefaultPayload.create(String.valueOf(i)));
    rule.socket.requestChannel(outbound).subscribe(manualSubscriber);

    manualSubscriber.request(1);
    manualSubscriber.request(Long.MAX_VALUE);
    firstElementGate.onComplete();

    Iterator<ByteBuf> sentFrames = rule.connection.getSent().iterator();
    ByteBuf channelFrame = sentFrames.next();

    Assertions.assertThat(FrameHeaderCodec.frameType(channelFrame)).isEqualTo(REQUEST_CHANNEL);
    Assertions.assertThat(RequestChannelFrameCodec.initialRequestN(channelFrame))
        .isEqualTo(Long.MAX_VALUE);
    Assertions.assertThat(RequestChannelFrameCodec.data(channelFrame).toString(CharsetUtil.UTF_8))
        .isEqualTo("0");
    Assertions.assertThat(channelFrame.release()).isTrue();

    Assertions.assertThat(sentFrames.hasNext()).isFalse();
    rule.assertHasNoLeaks();
  }

  /**
   * Every interaction from {@link #prepareCalls()} must reject an oversized payload. The data and
   * metadata are each {@code FrameLengthCodec.FRAME_LENGTH_MASK} bytes, and the test name says
   * fragmentation is off. Each interaction must fail with an {@link IllegalArgumentException}
   * carrying {@code INVALID_PAYLOAD_ERROR_MESSAGE}, without leaking buffers.
   */
  @Test
  public void shouldThrownExceptionIfGivenPayloadIsExitsSizeAllowanceWithNoFragmentation() {
    // The oversized content does not depend on the interaction under test. Allocate and fill it
    // once, instead of allocating two ~16 MiB arrays per interaction. Each interaction still gets
    // its own fresh Payload.
    byte[] metadata = new byte[FrameLengthCodec.FRAME_LENGTH_MASK];
    byte[] data = new byte[FrameLengthCodec.FRAME_LENGTH_MASK];
    ThreadLocalRandom.current().nextBytes(metadata);
    ThreadLocalRandom.current().nextBytes(data);

    prepareCalls()
        .forEach(
            generator -> {
              StepVerifier.create(
                      generator.apply(rule.socket, DefaultPayload.create(data, metadata)))
                  .expectSubscription()
                  .expectErrorSatisfies(
                      t ->
                          Assertions.assertThat(t)
                              .isInstanceOf(IllegalArgumentException.class)
                              .hasMessage(INVALID_PAYLOAD_ERROR_MESSAGE))
                  .verify();
              rule.assertHasNoLeaks();
            });
  }

  /**
   * Every requester interaction adapted to take a single payload. requestChannel wraps the payload
   * in a one-element Flux.
   */
  static Stream<BiFunction<RSocket, Payload, Publisher<?>>> prepareCalls() {
    BiFunction<RSocket, Payload, Publisher<?>> fireAndForget = RSocket::fireAndForget;
    BiFunction<RSocket, Payload, Publisher<?>> requestResponse = RSocket::requestResponse;
    BiFunction<RSocket, Payload, Publisher<?>> requestStream = RSocket::requestStream;
    BiFunction<RSocket, Payload, Publisher<?>> requestChannel =
        (socket, p) -> socket.requestChannel(Flux.just(p));
    BiFunction<RSocket, Payload, Publisher<?>> metadataPush = RSocket::metadataPush;
    return Stream.of(fireAndForget, requestResponse, requestStream, requestChannel, metadataPush);
  }

  /**
   * requestChannel variant of the oversized-payload check. The first, empty element goes out
   * with the REQUEST_CHANNEL frame. Once REQUEST_N(2) arrives for that stream, the oversized
   * second element must fail the interaction with an {@link IllegalArgumentException} carrying
   * {@code INVALID_PAYLOAD_ERROR_MESSAGE}. Exactly a REQUEST_CHANNEL frame and a CANCEL frame must
   * have been sent.
   */
  @Test
  public void
      shouldThrownExceptionIfGivenPayloadIsExitsSizeAllowanceWithNoFragmentationForRequestChannelCase() {
    byte[] oversizedMetadata = new byte[FrameLengthCodec.FRAME_LENGTH_MASK];
    byte[] oversizedData = new byte[FrameLengthCodec.FRAME_LENGTH_MASK];
    ThreadLocalRandom random = ThreadLocalRandom.current();
    random.nextBytes(oversizedMetadata);
    random.nextBytes(oversizedData);

    Flux<Payload> outbound =
        Flux.just(EmptyPayload.INSTANCE, DefaultPayload.create(oversizedData, oversizedMetadata));

    StepVerifier.create(rule.socket.requestChannel(outbound))
        .expectSubscription()
        .then(
            () -> {
              ByteBuf requestN =
                  RequestNFrameCodec.encode(
                      rule.alloc(), rule.getStreamIdForRequestType(REQUEST_CHANNEL), 2);
              rule.connection.addToReceivedBuffer(requestN);
            })
        .expectErrorSatisfies(
            error ->
                Assertions.assertThat(error)
                    .isInstanceOf(IllegalArgumentException.class)
                    .hasMessage(INVALID_PAYLOAD_ERROR_MESSAGE))
        .verify();

    // RequestChannelFrame, then CancelFrame
    Assertions.assertThat(rule.connection.getSent())
        .hasSize(2)
        .allMatch(ReferenceCounted::release);
    rule.assertHasNoLeaks();
  }

  /**
   * Runs one racing scenario from {@link #racingCases()} 10,000 times, each time on a freshly
   * initialized {@link ClientSocketRule}. On each iteration:
   *
   * <ul>
   *   <li>the initiator starts an interaction;
   *   <li>every received payload is released as it arrives;
   *   <li>the runner performs the race;
   *   <li>every sent frame is released and the rule is checked for leaks.
   * </ul>
   *
   * @param initiator starts the interaction under test on the given rule
   * @param runner drives the subscriber and connection for that interaction, including the race
   */
  @ParameterizedTest
  @MethodSource("racingCases")
  public void checkNoLeaksOnRacing(
      Function<ClientSocketRule, Publisher<Payload>> initiator,
      BiConsumer<AssertSubscriber<Payload>, ClientSocketRule> runner) {
    for (int i = 0; i < 10000; i++) {
      ClientSocketRule clientSocketRule = new ClientSocketRule();
      try {
        clientSocketRule
            .apply(
                new Statement() {
                  @Override
                  public void evaluate() {}
                },
                null)
            .evaluate();
      } catch (Throwable throwable) {
        // Continuing with a rule that failed to initialize only produces unrelated, confusing
        // failures later in this iteration. Fail here and keep the original cause.
        throw new AssertionError(
            "Failed to initialize ClientSocketRule on iteration " + i, throwable);
      }

      Publisher<Payload> payloadP = initiator.apply(clientSocketRule);
      AssertSubscriber<Payload> assertSubscriber = AssertSubscriber.create(0);

      if (payloadP instanceof Flux) {
        ((Flux<Payload>) payloadP).doOnNext(Payload::release).subscribe(assertSubscriber);
      } else {
        ((Mono<Payload>) payloadP).doOnNext(Payload::release).subscribe(assertSubscriber);
      }

      runner.accept(assertSubscriber, clientSocketRule);

      Assertions.assertThat(clientSocketRule.connection.getSent())
          .allMatch(ReferenceCounted::release);

      clientSocketRule.assertHasNoLeaks();
    }
  }

  /**
   * Race scenarios for {@link #checkNoLeaksOnRacing}. Each {@link Arguments} pairs two lambdas:
   *
   * <ul>
   *   <li>an initiator, which starts an interaction on a fresh {@link ClientSocketRule};
   *   <li>a runner, which drives the subscriber and races two operations against each other.
   * </ul>
   */
  private static Stream<Arguments> racingCases() {
    return Stream.of(
        // requestStream: subscriber cancel races an inbound payload frame
        Arguments.of(
            (Function<ClientSocketRule, Publisher<Payload>>)
                (rule) -> rule.socket.requestStream(EmptyPayload.INSTANCE),
            (BiConsumer<AssertSubscriber<Payload>, ClientSocketRule>)
                (as, rule) -> {
                  ByteBufAllocator allocator = rule.alloc();
                  ByteBuf metadata = allocator.buffer();
                  metadata.writeCharSequence("abc", CharsetUtil.UTF_8);
                  ByteBuf data = allocator.buffer();
                  data.writeCharSequence("def", CharsetUtil.UTF_8);
                  as.request(1);
                  int streamId = rule.getStreamIdForRequestType(REQUEST_STREAM);
                  ByteBuf frame =
                      PayloadFrameCodec.encode(
                          allocator, streamId, false, false, true, metadata, data);

                  RaceTestUtils.race(as::cancel, () -> rule.connection.addToReceivedBuffer(frame));
                }),
        // requestChannel: subscriber cancel races an inbound payload frame
        Arguments.of(
            (Function<ClientSocketRule, Publisher<Payload>>)
                (rule) -> rule.socket.requestChannel(Flux.just(EmptyPayload.INSTANCE)),
            (BiConsumer<AssertSubscriber<Payload>, ClientSocketRule>)
                (as, rule) -> {
                  ByteBufAllocator allocator = rule.alloc();
                  ByteBuf metadata = allocator.buffer();
                  metadata.writeCharSequence("abc", CharsetUtil.UTF_8);
                  ByteBuf data = allocator.buffer();
                  data.writeCharSequence("def", CharsetUtil.UTF_8);
                  as.request(1);
                  int streamId = rule.getStreamIdForRequestType(REQUEST_CHANNEL);
                  ByteBuf frame =
                      PayloadFrameCodec.encode(
                          allocator, streamId, false, false, true, metadata, data);

                  RaceTestUtils.race(as::cancel, () -> rule.connection.addToReceivedBuffer(frame));
                }),
        // requestStream with a ByteBuf payload: request(1) races cancel. If anything was sent, it
        // must be exactly REQUEST_STREAM followed by CANCEL.
        Arguments.of(
            (Function<ClientSocketRule, Publisher<Payload>>)
                (rule) -> {
                  ByteBufAllocator allocator = rule.alloc();
                  ByteBuf metadata = allocator.buffer();
                  metadata.writeCharSequence("metadata", CharsetUtil.UTF_8);
                  ByteBuf data = allocator.buffer();
                  data.writeCharSequence("data", CharsetUtil.UTF_8);
                  final Payload payload = ByteBufPayload.create(data, metadata);

                  return rule.socket.requestStream(payload);
                },
            (BiConsumer<AssertSubscriber<Payload>, ClientSocketRule>)
                (as, rule) -> {
                  RaceTestUtils.race(() -> as.request(1), as::cancel);
                  // ensures proper frames order
                  if (rule.connection.getSent().size() > 0) {
                    Assertions.assertThat(rule.connection.getSent()).hasSize(2);
                    assertSentFrameTypeAt(rule, 0, REQUEST_STREAM);
                    assertSentFrameTypeAt(rule, 1, CANCEL);
                  }
                }),
        // requestChannel with a single generated ByteBuf payload: request(1) races cancel. If
        // anything was sent, REQUEST_CHANNEL must come first and CANCEL second.
        Arguments.of(
            (Function<ClientSocketRule, Publisher<Payload>>)
                (rule) -> {
                  ByteBufAllocator allocator = rule.alloc();
                  return rule.socket.requestChannel(
                      Flux.generate(
                          () -> 1L,
                          (index, sink) -> {
                            ByteBuf metadata = allocator.buffer();
                            metadata.writeCharSequence("metadata", CharsetUtil.UTF_8);
                            ByteBuf data = allocator.buffer();
                            data.writeCharSequence("data", CharsetUtil.UTF_8);
                            final Payload payload = ByteBufPayload.create(data, metadata);
                            sink.next(payload);
                            sink.complete();
                            return ++index;
                          }));
                },
            (BiConsumer<AssertSubscriber<Payload>, ClientSocketRule>)
                (as, rule) -> {
                  RaceTestUtils.race(() -> as.request(1), as::cancel);
                  // ensures proper frames order
                  if (rule.connection.getSent().size() > 0) {
                    // NOTE(review): unlike the requestStream case, the total frame count is not
                    // pinned here (a hasSize(2) check was commented out) — confirm whether extra
                    // frames are legitimately possible for channels.
                    assertSentFrameTypeAt(rule, 0, REQUEST_CHANNEL);
                    assertSentFrameTypeAt(rule, 1, CANCEL);
                  }
                }),
        // infinite requestChannel: request(Long.MAX_VALUE) races an inbound CANCEL frame
        Arguments.of(
            (Function<ClientSocketRule, Publisher<Payload>>)
                (rule) ->
                    rule.socket.requestChannel(
                        Flux.generate(
                            () -> 1L,
                            (index, sink) -> {
                              ByteBuf data = rule.alloc().buffer();
                              data.writeCharSequence("d" + index, CharsetUtil.UTF_8);
                              ByteBuf metadata = rule.alloc().buffer();
                              metadata.writeCharSequence("m" + index, CharsetUtil.UTF_8);
                              final Payload payload = ByteBufPayload.create(data, metadata);
                              sink.next(payload);
                              return ++index;
                            })),
            (BiConsumer<AssertSubscriber<Payload>, ClientSocketRule>)
                (as, rule) -> {
                  ByteBufAllocator allocator = rule.alloc();
                  as.request(1);
                  int streamId = rule.getStreamIdForRequestType(REQUEST_CHANNEL);
                  ByteBuf frame = CancelFrameCodec.encode(allocator, streamId);

                  RaceTestUtils.race(
                      () -> as.request(Long.MAX_VALUE),
                      () -> rule.connection.addToReceivedBuffer(frame));
                }),
        // infinite requestChannel: request(Long.MAX_VALUE) races an inbound ERROR frame
        Arguments.of(
            (Function<ClientSocketRule, Publisher<Payload>>)
                (rule) ->
                    rule.socket.requestChannel(
                        Flux.generate(
                            () -> 1L,
                            (index, sink) -> {
                              ByteBuf data = rule.alloc().buffer();
                              data.writeCharSequence("d" + index, CharsetUtil.UTF_8);
                              ByteBuf metadata = rule.alloc().buffer();
                              metadata.writeCharSequence("m" + index, CharsetUtil.UTF_8);
                              final Payload payload = ByteBufPayload.create(data, metadata);
                              sink.next(payload);
                              return ++index;
                            })),
            (BiConsumer<AssertSubscriber<Payload>, ClientSocketRule>)
                (as, rule) -> {
                  ByteBufAllocator allocator = rule.alloc();
                  as.request(1);
                  int streamId = rule.getStreamIdForRequestType(REQUEST_CHANNEL);
                  ByteBuf frame =
                      ErrorFrameCodec.encode(allocator, streamId, new RuntimeException("test"));

                  RaceTestUtils.race(
                      () -> as.request(Long.MAX_VALUE),
                      () -> rule.connection.addToReceivedBuffer(frame));
                }),
        // requestResponse: subscriber cancel races an inbound payload frame
        Arguments.of(
            (Function<ClientSocketRule, Publisher<Payload>>)
                (rule) -> rule.socket.requestResponse(EmptyPayload.INSTANCE),
            (BiConsumer<AssertSubscriber<Payload>, ClientSocketRule>)
                (as, rule) -> {
                  ByteBufAllocator allocator = rule.alloc();
                  ByteBuf metadata = allocator.buffer();
                  metadata.writeCharSequence("abc", CharsetUtil.UTF_8);
                  ByteBuf data = allocator.buffer();
                  data.writeCharSequence("def", CharsetUtil.UTF_8);
                  as.request(Long.MAX_VALUE);
                  int streamId = rule.getStreamIdForRequestType(REQUEST_RESPONSE);
                  ByteBuf frame =
                      PayloadFrameCodec.encode(
                          allocator, streamId, false, false, true, metadata, data);

                  RaceTestUtils.race(as::cancel, () -> rule.connection.addToReceivedBuffer(frame));
                }));
  }

  /**
   * Asserts that the frame at position {@code index} in the rule's sent-frame queue has type
   * {@code expected}. On failure, the message reports the position and the type actually found.
   */
  private static void assertSentFrameTypeAt(
      ClientSocketRule rule, int index, FrameType expected) {
    Assertions.assertThat(rule.connection.getSent())
        .element(index)
        .matches(
            bb -> frameType(bb) == expected,
            "Expected frame #"
                + index
                + " matches {"
                + expected
                + "} but was {"
                + frameType(rule.connection.getSent().stream().skip(index).findFirst().get())
                + "}");
  }

  /**
   * Pushes three ByteBuf payloads into a requestChannel while the response side has requested 1,
   * then cancels. Every sent frame must be releasable, and the rule must report no leaks.
   */
  @Test
  public void simpleOnDiscardRequestChannelTest() {
    AssertSubscriber<Payload> downstream = AssertSubscriber.create(1);
    TestPublisher<Payload> upstream = TestPublisher.create();

    rule.socket.requestChannel(upstream).subscribe(downstream);

    upstream.next(
        ByteBufPayload.create("d", "m"),
        ByteBufPayload.create("d1", "m1"),
        ByteBufPayload.create("d2", "m2"));

    downstream.cancel();

    Assertions.assertThat(rule.connection.getSent()).allMatch(ByteBuf::release);
    rule.assertHasNoLeaks();
  }

  /**
   * Like simpleOnDiscardRequestChannelTest, but the channel is terminated by an inbound ERROR
   * frame (custom code 0x404) instead of a local cancel. Payloads pushed after the first must not
   * leak.
   */
  @Test
  public void simpleOnDiscardRequestChannelTest2() {
    ByteBufAllocator alloc = rule.alloc();
    AssertSubscriber<Payload> downstream = AssertSubscriber.create(1);
    TestPublisher<Payload> upstream = TestPublisher.create();

    rule.socket.requestChannel(upstream).subscribe(downstream);

    upstream.next(ByteBufPayload.create("d", "m"));

    int channelStreamId = rule.getStreamIdForRequestType(REQUEST_CHANNEL);
    upstream.next(ByteBufPayload.create("d1", "m1"), ByteBufPayload.create("d2", "m2"));

    ByteBuf errorFrame =
        ErrorFrameCodec.encode(
            alloc, channelStreamId, new CustomRSocketException(0x00000404, "test"));
    rule.connection.addToReceivedBuffer(errorFrame);

    Assertions.assertThat(rule.connection.getSent()).allMatch(ByteBuf::release);
    rule.assertHasNoLeaks();
  }

  /**
   * For each interaction type, the requester sends data-only payloads and receives data-only
   * response frames. No sent frame may carry the metadata flag, and no decoded response payload
   * may report metadata.
   *
   * @param frameType interaction under test
   * @param framesCnt number of frames the requester is expected to send. For REQUEST_CHANNEL this
   *     is the initial frame plus the follow-up payloads unlocked by an inbound REQUEST_N.
   * @param responsesCnt number of response frames fed back. The last one is encoded with the
   *     second flag set (presumably COMPLETE; the subscriber is expected to terminate).
   */
  @ParameterizedTest
  @MethodSource("encodeDecodePayloadCases")
  public void verifiesThatFrameWithNoMetadataHasDecodedCorrectlyIntoPayload(
      FrameType frameType, int framesCnt, int responsesCnt) {
    ByteBufAllocator allocator = rule.alloc();
    AssertSubscriber<Payload> assertSubscriber = AssertSubscriber.create(responsesCnt);
    TestPublisher<Payload> testPublisher = TestPublisher.create();

    Publisher<Payload> response;

    switch (frameType) {
      case REQUEST_FNF:
        response =
            testPublisher.mono().flatMap(p -> rule.socket.fireAndForget(p).then(Mono.empty()));
        break;
      case REQUEST_RESPONSE:
        response = testPublisher.mono().flatMap(p -> rule.socket.requestResponse(p));
        break;
      case REQUEST_STREAM:
        response = testPublisher.mono().flatMapMany(p -> rule.socket.requestStream(p));
        break;
      case REQUEST_CHANNEL:
        response = rule.socket.requestChannel(testPublisher.flux());
        break;
      default:
        throw new UnsupportedOperationException("illegal case");
    }

    response.subscribe(assertSubscriber);
    testPublisher.next(ByteBufPayload.create("d"));

    int streamId = rule.getStreamIdForRequestType(frameType);

    if (responsesCnt > 0) {
      for (int i = 0; i < responsesCnt - 1; i++) {
        rule.connection.addToReceivedBuffer(
            PayloadFrameCodec.encode(
                allocator,
                streamId,
                false,
                false,
                true,
                null,
                Unpooled.wrappedBuffer(("rd" + (i + 1)).getBytes(CharsetUtil.UTF_8))));
      }

      // last response: same shape, but with the second flag set so the interaction completes
      rule.connection.addToReceivedBuffer(
          PayloadFrameCodec.encode(
              allocator,
              streamId,
              false,
              true,
              true,
              null,
              Unpooled.wrappedBuffer(("rd" + responsesCnt).getBytes(CharsetUtil.UTF_8))));
    }

    // grant the responder-side demand needed for the remaining outbound (channel) payloads
    if (framesCnt > 1) {
      rule.connection.addToReceivedBuffer(
          RequestNFrameCodec.encode(allocator, streamId, framesCnt));
    }

    for (int i = 1; i < framesCnt; i++) {
      testPublisher.next(ByteBufPayload.create("d" + i));
    }

    Assertions.assertThat(rule.connection.getSent())
        .describedAs(
            "Interaction Type :[%s]. Expected to observe %s frames sent", frameType, framesCnt)
        .hasSize(framesCnt)
        .allMatch(bb -> !FrameHeaderCodec.hasMetadata(bb))
        .allMatch(ByteBuf::release);

    Assertions.assertThat(assertSubscriber.isTerminated())
        .describedAs("Interaction Type :[%s]. Expected to be terminated", frameType)
        .isTrue();

    Assertions.assertThat(assertSubscriber.values())
        .describedAs(
            "Interaction Type :[%s]. Expected to observe %s frames received",
            frameType, responsesCnt)
        .hasSize(responsesCnt)
        .allMatch(p -> !p.hasMetadata())
        .allMatch(Payload::release);

    rule.assertHasNoLeaks();
    rule.connection.clearSendReceiveBuffers();
  }

  /**
   * Cases for verifiesThatFrameWithNoMetadataHasDecodedCorrectlyIntoPayload. Each case is:
   * interaction type, frames the requester is expected to send, and responses fed back to it.
   */
  static Stream<Arguments> encodeDecodePayloadCases() {
    Arguments fireAndForget = Arguments.of(REQUEST_FNF, 1, 0);
    Arguments requestResponse = Arguments.of(REQUEST_RESPONSE, 1, 1);
    Arguments requestStream = Arguments.of(REQUEST_STREAM, 1, 5);
    Arguments requestChannel = Arguments.of(REQUEST_CHANNEL, 5, 5);
    return Stream.of(fireAndForget, requestResponse, requestStream, requestChannel);
  }

  /**
   * Each interaction handed an already-released payload must fail with an
   * {@link IllegalReferenceCountException}. The failure must arrive within 100 ms, even with zero
   * initial demand.
   *
   * @param sourceProducer builds the interaction from the (released) payload and the socket
   */
  @ParameterizedTest
  @MethodSource("refCntCases")
  public void ensureSendsErrorOnIllegalRefCntPayload(
      BiFunction<Payload, RSocket, Publisher<?>> sourceProducer) {
    Payload released = ByteBufPayload.create("test", "test");
    released.release();

    StepVerifier.create(sourceProducer.apply(released, rule.socket), 0)
        .expectError(IllegalReferenceCountException.class)
        .verify(Duration.ofMillis(100));
  }

  /**
   * Interactions for ensureSendsErrorOnIllegalRefCntPayload. The last case puts the released
   * payload second in a channel, after an empty payload, and requests 1 on subscription.
   */
  private static Stream<BiFunction<Payload, RSocket, Publisher<?>>> refCntCases() {
    BiFunction<Payload, RSocket, Publisher<?>> fireAndForget =
        (payload, socket) -> socket.fireAndForget(payload);
    BiFunction<Payload, RSocket, Publisher<?>> requestResponse =
        (payload, socket) -> socket.requestResponse(payload);
    BiFunction<Payload, RSocket, Publisher<?>> requestStream =
        (payload, socket) -> socket.requestStream(payload);
    BiFunction<Payload, RSocket, Publisher<?>> channelFirstElement =
        (payload, socket) -> socket.requestChannel(Mono.just(payload));
    BiFunction<Payload, RSocket, Publisher<?>> channelSecondElement =
        (payload, socket) ->
            socket.requestChannel(
                Flux.just(EmptyPayload.INSTANCE, payload).doOnSubscribe(s -> s.request(1)));
    return Stream.of(
        fireAndForget, requestResponse, requestStream, channelFirstElement, channelSecondElement);
  }

  /**
   * Creating fireAndForget Monos must send nothing. Frames and stream ids are assigned only on
   * subscription: the Mono subscribed first (fnf2) gets stream id 1, and the one subscribed second
   * (fnf1) gets stream id 3, regardless of creation order.
   */
  @Test
  public void ensuresThatNoOpsMustHappenUntilSubscriptionInCaseOfFnfCall() {
    Payload payload1 = ByteBufPayload.create("abc1");
    Mono<Void> fnf1 = rule.socket.fireAndForget(payload1);

    Payload payload2 = ByteBufPayload.create("abc2");
    Mono<Void> fnf2 = rule.socket.fireAndForget(payload2);

    Assertions.assertThat(rule.connection.getSent()).isEmpty();

    // checks that fnf2 should have id 1 even though it was generated later than fnf1
    AssertSubscriber<Void> voidAssertSubscriber2 = fnf2.subscribeWith(AssertSubscriber.create(0));
    voidAssertSubscriber2.assertTerminated().assertNoError();
    Assertions.assertThat(rule.connection.getSent())
        .hasSize(1)
        .first()
        .matches(bb -> frameType(bb) == REQUEST_FNF)
        .matches(bb -> FrameHeaderCodec.streamId(bb) == 1)
        // ensures that this is fnf2 with abc2 data
        .matches(
            bb ->
                ByteBufUtil.equals(
                    RequestFireAndForgetFrameCodec.data(bb),
                    Unpooled.wrappedBuffer("abc2".getBytes(CharsetUtil.UTF_8))))
        .matches(ReferenceCounted::release);

    rule.connection.clearSendReceiveBuffers();

    // checks that fnf1 should have id 3 even though it was generated earlier
    AssertSubscriber<Void> voidAssertSubscriber1 = fnf1.subscribeWith(AssertSubscriber.create(0));
    voidAssertSubscriber1.assertTerminated().assertNoError();
    Assertions.assertThat(rule.connection.getSent())
        .hasSize(1)
        .first()
        .matches(bb -> frameType(bb) == REQUEST_FNF)
        .matches(bb -> FrameHeaderCodec.streamId(bb) == 3)
        // ensures that this is fnf1 with abc1 data
        .matches(
            bb ->
                ByteBufUtil.equals(
                    RequestFireAndForgetFrameCodec.data(bb),
                    Unpooled.wrappedBuffer("abc1".getBytes(CharsetUtil.UTF_8))))
        .matches(ReferenceCounted::release);
  }

  /**
   * Verifies two things for request-response, request-stream and request-channel:
   *
   * <ul>
   *   <li>nothing is sent on subscription alone;
   *   <li>stream IDs are issued in the order of the first {@code request(n)} calls, not in
   *       creation or subscription order.
   * </ul>
   *
   * @param frameType the request frame type the interaction is expected to send
   * @param interaction creates the interaction publisher for a given rule and payload
   */
  @ParameterizedTest
  @MethodSource("requestNInteractions")
  public void ensuresThatNoOpsMustHappenUntilFirstRequestN(
      FrameType frameType, BiFunction<ClientSocketRule, Payload, Publisher<Payload>> interaction) {
    Payload payload1 = ByteBufPayload.create("abc1");
    Publisher<Payload> interaction1 = interaction.apply(rule, payload1);

    Payload payload2 = ByteBufPayload.create("abc2");
    Publisher<Payload> interaction2 = interaction.apply(rule, payload2);

    Assertions.assertThat(rule.connection.getSent()).isEmpty();

    AssertSubscriber<Payload> assertSubscriber1 = AssertSubscriber.create(0);
    interaction1.subscribe(assertSubscriber1);
    AssertSubscriber<Payload> assertSubscriber2 = AssertSubscriber.create(0);
    interaction2.subscribe(assertSubscriber2);
    assertSubscriber1.assertNotTerminated().assertNoError();
    assertSubscriber2.assertNotTerminated().assertNoError();
    // even though we subscribed, nothing should happen until the first requestN
    Assertions.assertThat(rule.connection.getSent()).isEmpty();

    // request on the second interaction first, to ensure that the stream ID is issued on the
    // first requestN rather than on creation or subscription
    assertSubscriber2.request(1);
    assertSingleRequestFrameSent(frameType, 1, "abc2");

    rule.connection.clearSendReceiveBuffers();

    assertSubscriber1.request(1);
    assertSingleRequestFrameSent(frameType, 3, "abc1");
  }

  /**
   * Asserts the sent request frame and then releases it.
   *
   * <p>Exactly one frame must have been sent. It must be of {@code frameType}, use {@code
   * expectedStreamId}, and carry {@code expectedData} as its data.
   */
  private void assertSingleRequestFrameSent(
      FrameType frameType, int expectedStreamId, String expectedData) {
    Assertions.assertThat(rule.connection.getSent())
        .hasSize(1)
        .first()
        .matches(bb -> frameType(bb) == frameType)
        .matches(
            bb -> FrameHeaderCodec.streamId(bb) == expectedStreamId,
            "Expected to have stream ID {"
                + expectedStreamId
                + "} but got {"
                + FrameHeaderCodec.streamId(rule.connection.getSent().iterator().next())
                + "}")
        .matches(bb -> requestFrameHasData(frameType, bb, expectedData))
        .matches(ReferenceCounted::release);
  }

  /**
   * Checks a request frame's data section against an expected string.
   *
   * <p>Returns whether the data of {@code frame}, decoded with the codec for {@code frameType},
   * equals the UTF-8 bytes of {@code expected}. Returns false for any frame type not handled
   * here.
   */
  private static boolean requestFrameHasData(FrameType frameType, ByteBuf frame, String expected) {
    ByteBuf expectedData = Unpooled.wrappedBuffer(expected.getBytes(CharsetUtil.UTF_8));
    switch (frameType) {
      case REQUEST_RESPONSE:
        return ByteBufUtil.equals(RequestResponseFrameCodec.data(frame), expectedData);
      case REQUEST_STREAM:
        return ByteBufUtil.equals(RequestStreamFrameCodec.data(frame), expectedData);
      case REQUEST_CHANNEL:
        return ByteBufUtil.equals(RequestChannelFrameCodec.data(frame), expectedData);
      default:
        return false;
    }
  }

  /**
   * Pairs each demand-driven interaction type with a factory that starts that interaction on the
   * rule's socket using the given payload.
   */
  private static Stream<Arguments> requestNInteractions() {
    BiFunction<ClientSocketRule, Payload, Publisher<Payload>> requestResponse =
        (socketRule, payload) -> socketRule.socket.requestResponse(payload);
    BiFunction<ClientSocketRule, Payload, Publisher<Payload>> requestStream =
        (socketRule, payload) -> socketRule.socket.requestStream(payload);
    BiFunction<ClientSocketRule, Payload, Publisher<Payload>> requestChannel =
        (socketRule, payload) -> socketRule.socket.requestChannel(Flux.just(payload));

    return Stream.of(
        Arguments.of(REQUEST_RESPONSE, requestResponse),
        Arguments.of(REQUEST_STREAM, requestStream),
        Arguments.of(REQUEST_CHANNEL, requestChannel));
  }

  /**
   * Verifies stream-ID ordering when two interactions race to subscribe.
   *
   * <p>Two interactions are created from the same payload and subscribed concurrently. Whichever
   * subscription wins the race, the sent frames must carry consecutive client stream IDs in send
   * order. Cases involving METADATA_PUSH are skipped via assumptions.
   */
  @ParameterizedTest
  @MethodSource("streamRacingCases")
  public void ensuresCorrectOrderOfStreamIdIssuingInCaseOfRacing(
      BiFunction<ClientSocketRule, Payload, Publisher<?>> interaction1,
      BiFunction<ClientSocketRule, Payload, Publisher<?>> interaction2,
      FrameType interactionType1,
      FrameType interactionType2) {
    Assumptions.assumeThat(interactionType1).isNotEqualTo(METADATA_PUSH);
    Assumptions.assumeThat(interactionType2).isNotEqualTo(METADATA_PUSH);
    // each round issues two stream IDs two apart, so the next round starts four IDs later
    for (int firstId = 1; firstId < 10000; firstId += 4) {
      int secondId = firstId + 2;
      Payload sharedPayload = DefaultPayload.create("test", "test");
      Publisher<?> first = interaction1.apply(rule, sharedPayload);
      Publisher<?> second = interaction2.apply(rule, sharedPayload);
      RaceTestUtils.race(
          () -> first.subscribe(AssertSubscriber.create()),
          () -> second.subscribe(AssertSubscriber.create()));

      Assertions.assertThat(rule.connection.getSent())
          .extracting(FrameHeaderCodec::streamId)
          .containsExactly(firstId, secondId);
      rule.connection.getSent().forEach(ByteBuf::release);
      rule.connection.getSent().clear();
    }
  }

  /**
   * Supplies the racing interaction pairs used by the stream-ID ordering and dispose-race tests.
   *
   * <p>Each entry is (first interaction, second interaction, first frame type, second frame
   * type).
   */
  public static Stream<Arguments> streamRacingCases() {
    BiFunction<ClientSocketRule, Payload, Publisher<?>> fireAndForget =
        (socketRule, payload) -> socketRule.socket.fireAndForget(payload);
    BiFunction<ClientSocketRule, Payload, Publisher<?>> requestResponse =
        (socketRule, payload) -> socketRule.socket.requestResponse(payload);
    BiFunction<ClientSocketRule, Payload, Publisher<?>> requestStream =
        (socketRule, payload) -> socketRule.socket.requestStream(payload);
    BiFunction<ClientSocketRule, Payload, Publisher<?>> requestChannel =
        (socketRule, payload) -> requestChannelReleasingUnconsumedPayload(socketRule, payload);
    BiFunction<ClientSocketRule, Payload, Publisher<?>> metadataPush =
        (socketRule, payload) -> socketRule.socket.metadataPush(payload);

    return Stream.of(
        Arguments.of(fireAndForget, requestResponse, REQUEST_FNF, REQUEST_RESPONSE),
        Arguments.of(requestResponse, requestStream, REQUEST_RESPONSE, REQUEST_STREAM),
        Arguments.of(requestStream, requestChannel, REQUEST_STREAM, REQUEST_CHANNEL),
        Arguments.of(requestChannel, fireAndForget, REQUEST_CHANNEL, REQUEST_FNF),
        Arguments.of(metadataPush, fireAndForget, METADATA_PUSH, REQUEST_FNF));
  }

  /**
   * Starts a request-channel interaction over a single-element flux of {@code payload}.
   *
   * <p>If that flux has not been subscribed by the time the interaction terminates, {@code
   * payload} is released here (presumably so the test does not leak it).
   */
  private static Flux<Payload> requestChannelReleasingUnconsumedPayload(
      ClientSocketRule socketRule, Payload payload) {
    AtomicBoolean subscribed = new AtomicBoolean();
    Flux<Payload> source = Flux.just(payload).doOnSubscribe((__) -> subscribed.set(true));
    return socketRule
        .socket
        .requestChannel(source)
        .doFinally(
            __ -> {
              if (!subscribed.get()) {
                payload.release();
              }
            });
  }

  /**
   * Races {@code rule.socket.dispose()} against two concurrently subscribed interactions.
   *
   * <p>Verifies that:
   *
   * <ul>
   *   <li>both interactions terminate;
   *   <li>non-fire-and-forget interactions fail with {@link ClosedChannelException};
   *   <li>every sent frame and both payloads end up fully released.
   * </ul>
   *
   * <p>NOTE(review): {@code rule.socket} is disposed on every iteration, and nothing in this loop
   * re-creates it. Iterations after the first may therefore run against an already-disposed
   * requester. Confirm whether the rule should be re-initialised per iteration.
   */
  @ParameterizedTest
  @MethodSource("streamRacingCases")
  public void shouldTerminateAllStreamsIfThereRacingBetweenDisposeAndRequests(
      BiFunction<ClientSocketRule, Payload, Publisher<?>> interaction1,
      BiFunction<ClientSocketRule, Payload, Publisher<?>> interaction2,
      FrameType interactionType1,
      FrameType interactionType2) {
    for (int i = 1; i < 10000; i++) {
      Payload payload1 = ByteBufPayload.create("test", "test");
      Payload payload2 = ByteBufPayload.create("test", "test");
      // Subscriber<Object> is accepted by any Publisher<?>, so no raw types are needed
      AssertSubscriber<Object> assertSubscriber1 = AssertSubscriber.create();
      AssertSubscriber<Object> assertSubscriber2 = AssertSubscriber.create();
      Publisher<?> publisher1 = interaction1.apply(rule, payload1);
      Publisher<?> publisher2 = interaction2.apply(rule, payload2);
      RaceTestUtils.race(
          () -> rule.socket.dispose(),
          () ->
              RaceTestUtils.race(
                  () -> publisher1.subscribe(assertSubscriber1),
                  () -> publisher2.subscribe(assertSubscriber2),
                  Schedulers.parallel()),
          Schedulers.parallel());

      assertTerminatedAfterDisposeRace(assertSubscriber1, interactionType1);
      assertTerminatedAfterDisposeRace(assertSubscriber2, interactionType2);

      Assertions.assertThat(rule.connection.getSent()).allMatch(ReferenceCounted::release);
      rule.connection.getSent().clear();

      Assertions.assertThat(payload1.refCnt()).isZero();
      Assertions.assertThat(payload2.refCnt()).isZero();
    }
  }

  /**
   * Waits for {@code subscriber} to terminate and checks how it terminated.
   *
   * <p>Interactions other than fire-and-forget must fail with {@link ClosedChannelException}. A
   * fire-and-forget may instead have completed normally.
   */
  private static void assertTerminatedAfterDisposeRace(
      AssertSubscriber<Object> subscriber, FrameType interactionType) {
    subscriber.await().assertTerminated();
    if (interactionType != REQUEST_FNF) {
      subscriber.assertError(ClosedChannelException.class);
      return;
    }
    try {
      subscriber.assertError(ClosedChannelException.class);
    } catch (AssertionError notClosedChannelError) {
      // fnf call may be completed
      subscriber.assertComplete();
    }
  }

  /**
   * Regression test for https://github.com/rsocket/rsocket-java/issues/858.
   *
   * <p>After an in-flight request-response on stream 1 receives an ERROR frame:
   *
   * <ul>
   *   <li>exactly one REQUEST_RESPONSE frame has been sent;
   *   <li>the requester is not disposed;
   *   <li>the rule reports no buffer leaks.
   * </ul>
   */
  @Test
  public void testWorkaround858() {
    ByteBuf data = rule.alloc().buffer();
    data.writeCharSequence("test", CharsetUtil.UTF_8);
    Payload requestPayload = ByteBufPayload.create(data);
    rule.socket.requestResponse(requestPayload).subscribe();

    ByteBuf errorFrame = ErrorFrameCodec.encode(rule.alloc(), 1, new RuntimeException("test"));
    rule.connection.addToReceivedBuffer(errorFrame);

    Assertions.assertThat(rule.connection.getSent())
        .hasSize(1)
        .first()
        .matches(sent -> frameType(sent) == REQUEST_RESPONSE)
        .matches(ByteBuf::release);

    Assertions.assertThat(rule.socket.isDisposed()).isFalse();

    rule.assertHasNoLeaks();
  }

  /**
   * Test rule that builds an {@link RSocketRequester} over the rule's connection.
   *
   * <p>The requester uses the client stream-ID supplier, the zero-copy payload decoder and {@code
   * TestScheduler.INSTANCE}.
   */
  public static class ClientSocketRule extends AbstractSocketRule<RSocketRequester> {
    @Override
    protected RSocketRequester newRSocket() {
      return new RSocketRequester(
          connection,
          PayloadDecoder.ZERO_COPY,
          StreamIdSupplier.clientSupplier(),
          0,
          Integer.MAX_VALUE,
          Integer.MAX_VALUE,
          null,
          RequesterLeaseHandler.None,
          TestScheduler.INSTANCE);
    }

    /**
     * Returns the stream ID of the first sent frame whose type is {@code expectedFrameType}.
     *
     * @param expectedFrameType the frame type to look for among the sent frames
     * @return the stream ID carried by the first matching frame
     * @throws AssertionError if nothing was sent, or no sent frame has the requested type
     */
    public int getStreamIdForRequestType(FrameType expectedFrameType) {
      assertThat("Unexpected frames sent.", connection.getSent(), hasSize(greaterThanOrEqualTo(1)));
      List<FrameType> otherTypes = new ArrayList<>();
      for (ByteBuf sent : connection.getSent()) {
        FrameType sentType = frameType(sent);
        if (sentType != expectedFrameType) {
          otherTypes.add(sentType);
          continue;
        }
        return FrameHeaderCodec.streamId(sent);
      }
      String message =
          "No frames sent with frame type: " + expectedFrameType + ", frames found: " + otherTypes;
      throw new AssertionError(message);
    }
  }
}