/*
 * Copyright 2015 The gRPC Authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package io.grpc.internal;

import static com.google.common.truth.Truth.assertThat;
import static io.grpc.internal.ClientStreamListener.RpcProgress.PROCESSED;
import static io.grpc.internal.GrpcUtil.DEFAULT_MAX_MESSAGE_SIZE;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertSame;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
import static org.mockito.AdditionalAnswers.delegatesTo;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.same;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;

import io.grpc.Attributes;
import io.grpc.CallOptions;
import io.grpc.Codec;
import io.grpc.Deadline;
import io.grpc.Metadata;
import io.grpc.Status;
import io.grpc.Status.Code;
import io.grpc.StreamTracer;
import io.grpc.internal.AbstractClientStream.TransportState;
import io.grpc.internal.ClientStreamListener.RpcProgress;
import io.grpc.internal.MessageFramerTest.ByteWritableBuffer;
import io.grpc.internal.testing.TestClientStreamTracer;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.util.concurrent.TimeUnit;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
import org.mockito.ArgumentCaptor;
import org.mockito.Matchers;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;

/**
 * Test for {@link AbstractClientStream}.  This class tries to test functionality in
 * AbstractClientStream, but not in any super classes.
 */
@RunWith(JUnit4.class)
public class AbstractClientStreamTest {

  // Lets individual tests declare the exception type they expect the code under test to throw.
  @Rule public final ExpectedException thrown = ExpectedException.none();

  // Default no-op stats context; tests that inspect tracer events build their own context.
  private final StatsTraceContext statsTraceCtx = StatsTraceContext.NOOP;
  // Tracer instance handed to the streams and transport states created by these tests.
  private final TransportTracer transportTracer = new TransportTracer();
  // Mock listener, initialized in setUp(), used to verify callbacks delivered by the stream.
  @Mock private ClientStreamListener mockListener;

  /**
   * Initializes the annotated mocks and stubs {@code mockListener.messagesAvailable} so that
   * the supplied producer is polled until it returns {@code null}.
   */
  @Before
  public void setUp() {
    MockitoAnnotations.initMocks(this);

    Answer<Void> drainProducer = new Answer<Void>() {
      @Override
      public Void answer(InvocationOnMock invocation) throws Throwable {
        Object firstArgument = invocation.getArguments()[0];
        StreamListener.MessageProducer producer = (StreamListener.MessageProducer) firstArgument;
        // Keep pulling until the producer signals exhaustion with null.
        while (true) {
          if (producer.next() == null) {
            return null;
          }
        }
      }
    };
    doAnswer(drainProducer)
        .when(mockListener)
        .messagesAvailable(Matchers.<StreamListener.MessageProducer>any());
  }

  // Allocator handed to the streams under test; every allocation returns a new
  // ByteWritableBuffer constructed with the requested capacity hint.
  private final WritableBufferAllocator allocator = new WritableBufferAllocator() {
    @Override
    public WritableBuffer allocate(int capacityHint) {
      return new ByteWritableBuffer(capacityHint);
    }
  };

  /**
   * Cancelling with any non-OK status code is accepted, whereas cancelling with {@code OK}
   * must be rejected with an {@link IllegalArgumentException}. Each code gets a fresh stream.
   */
  @Test
  public void cancel_doNotAcceptOk() {
    for (Code code : Code.values()) {
      AbstractClientStream stream =
          new BaseAbstractClientStream(allocator, statsTraceCtx, transportTracer);
      stream.start(new NoopClientStreamListener());
      Status cancelStatus = Status.fromCodeValue(code.value());

      if (code == Code.OK) {
        try {
          stream.cancel(cancelStatus);
          fail();
        } catch (IllegalArgumentException expected) {
          // ignore
        }
        continue;
      }

      stream.cancel(cancelStatus);
    }
  }

  /**
   * Passing a {@code null} status to {@code cancel} on a started stream is expected to raise
   * a {@link NullPointerException}.
   */
  @Test
  public void cancel_failsOnNull() {
    ClientStreamListener noopListener = new NoopClientStreamListener();
    AbstractClientStream startedStream =
        new BaseAbstractClientStream(allocator, statsTraceCtx, transportTracer);
    startedStream.start(noopListener);

    thrown.expect(NullPointerException.class);
    startedStream.cancel(null);
  }

  /**
   * Cancelling the same stream twice must result in exactly one {@code closed} callback on the
   * listener. The sink used here reports the status back to the transport state whenever it is
   * asked to cancel.
   */
  @Test
  public void cancel_notifiesOnlyOnce() {
    final BaseTransportState transportState =
        new BaseTransportState(statsTraceCtx, transportTracer);
    BaseSink reportingSink = new BaseSink() {
      @Override
      public void cancel(Status errorStatus) {
        // Cancel should eventually result in a transportReportStatus on the transport thread
        transportState.transportReportStatus(errorStatus, true/*stop delivery*/, new Metadata());
      }
    };
    AbstractClientStream stream = new BaseAbstractClientStream(
        allocator, transportState, reportingSink, statsTraceCtx, transportTracer);
    stream.start(mockListener);

    Status deadlineExceeded = Status.DEADLINE_EXCEEDED;
    stream.cancel(deadlineExceeded);
    stream.cancel(deadlineExceeded);

    // verify(...) defaults to times(1), so a second notification would fail here.
    verify(mockListener).closed(any(Status.class), same(PROCESSED), any(Metadata.class));
  }

  /** Starting a stream with a {@code null} listener is expected to raise NullPointerException. */
  @Test
  public void startFailsOnNullListener() {
    AbstractClientStream unstartedStream =
        new BaseAbstractClientStream(allocator, statsTraceCtx, transportTracer);

    // Register the expectation only after construction so that only start() may satisfy it.
    thrown.expect(NullPointerException.class);
    unstartedStream.start(null);
  }

  /** A second call to {@code start} on the same stream is expected to raise IllegalStateException. */
  @Test
  public void cantCallStartTwice() {
    AbstractClientStream alreadyStarted =
        new BaseAbstractClientStream(allocator, statsTraceCtx, transportTracer);
    alreadyStarted.start(mockListener);

    // Only the second start() below may satisfy the expectation.
    thrown.expect(IllegalStateException.class);
    alreadyStarted.start(mockListener);
  }

  /**
   * Delivering a {@code null} frame to the transport state of a started stream is expected to
   * raise a {@link NullPointerException}.
   */
  @Test
  public void inboundDataReceived_failsOnNullFrame() {
    ClientStreamListener noopListener = new NoopClientStreamListener();
    AbstractClientStream startedStream =
        new BaseAbstractClientStream(allocator, statsTraceCtx, transportTracer);
    startedStream.start(noopListener);
    TransportState transportState = startedStream.transportState();

    thrown.expect(NullPointerException.class);
    transportState.inboundDataReceived(null);
  }

  /** Headers handed to the transport state must be forwarded to the listener's headersRead. */
  @Test
  public void inboundHeadersReceived_notifiesListener() {
    Metadata receivedHeaders = new Metadata();
    AbstractClientStream startedStream =
        new BaseAbstractClientStream(allocator, statsTraceCtx, transportTracer);
    startedStream.start(mockListener);

    TransportState transportState = startedStream.transportState();
    transportState.inboundHeadersReceived(receivedHeaders);

    verify(mockListener).headersRead(receivedHeaders);
  }

  /**
   * Once a status has been reported through the transport state, receiving headers afterwards
   * is expected to raise an {@link IllegalStateException}.
   */
  @Test
  public void inboundHeadersReceived_failsIfStatusReported() {
    AbstractClientStream startedStream =
        new BaseAbstractClientStream(allocator, statsTraceCtx, transportTracer);
    startedStream.start(mockListener);

    // transportState() on the test stream always hands back the same instance.
    TransportState transportState = startedStream.transportState();
    transportState.transportReportStatus(Status.CANCELLED, false, new Metadata());

    thrown.expect(IllegalStateException.class);
    transportState.inboundHeadersReceived(new Metadata());
  }

  /**
   * With full-stream decompression enabled, headers carrying a {@code gzip} content encoding
   * must be delivered to the listener.
   */
  @Test
  public void inboundHeadersReceived_acceptsGzipContentEncoding() {
    Metadata gzipHeaders = new Metadata();
    gzipHeaders.put(GrpcUtil.CONTENT_ENCODING_KEY, "gzip");

    AbstractClientStream startedStream =
        new BaseAbstractClientStream(allocator, statsTraceCtx, transportTracer);
    startedStream.start(mockListener);
    startedStream.setFullStreamDecompression(true);

    startedStream.transportState().inboundHeadersReceived(gzipHeaders);

    verify(mockListener).headersRead(gzipHeaders);
  }

  /**
   * A mixed-case spelling of the gzip content encoding must be accepted just like the lowercase
   * one. See https://tools.ietf.org/html/rfc7231#section-3.1.2.1
   */
  @Test
  public void inboundHeadersReceived_contentEncodingIsCaseInsensitive() {
    Metadata mixedCaseHeaders = new Metadata();
    mixedCaseHeaders.put(GrpcUtil.CONTENT_ENCODING_KEY, "gZIp");

    AbstractClientStream startedStream =
        new BaseAbstractClientStream(allocator, statsTraceCtx, transportTracer);
    startedStream.start(mockListener);
    startedStream.setFullStreamDecompression(true);

    startedStream.transportState().inboundHeadersReceived(mixedCaseHeaders);

    verify(mockListener).headersRead(mixedCaseHeaders);
  }

  /**
   * With full-stream decompression enabled, an unknown content encoding must not reach the
   * listener; instead the transport state records an INTERNAL deframing failure.
   */
  @Test
  public void inboundHeadersReceived_failsOnUnrecognizedContentEncoding() {
    Metadata unknownEncodingHeaders = new Metadata();
    unknownEncodingHeaders.put(GrpcUtil.CONTENT_ENCODING_KEY, "not-a-real-compression-method");

    AbstractClientStream startedStream =
        new BaseAbstractClientStream(allocator, statsTraceCtx, transportTracer);
    startedStream.start(mockListener);
    startedStream.setFullStreamDecompression(true);
    startedStream.transportState().inboundHeadersReceived(unknownEncodingHeaders);

    verifyNoMoreInteractions(mockListener);
    BaseTransportState transportState = (BaseTransportState) startedStream.transportState();
    Status failure = Status.fromThrowable(transportState.getDeframeFailedCause());
    assertEquals(Status.INTERNAL.getCode(), failure.getCode());
    String description = failure.getDescription();
    assertTrue(
        "unexpected deframe failed description",
        description.startsWith("Can't find full stream decompressor for"));
  }

  /**
   * With full-stream decompression enabled, headers that set both a content encoding and a gRPC
   * message encoding must not reach the listener; the transport state records an INTERNAL
   * deframing failure with a fixed description.
   */
  @Test
  public void inboundHeadersReceived_disallowsContentAndMessageEncoding() {
    AbstractClientStream stream =
        new BaseAbstractClientStream(allocator, statsTraceCtx, transportTracer);
    stream.start(mockListener);
    Metadata headers = new Metadata();
    headers.put(GrpcUtil.CONTENT_ENCODING_KEY, "gzip");
    headers.put(GrpcUtil.MESSAGE_ENCODING_KEY, new Codec.Gzip().getMessageEncoding());

    stream.setFullStreamDecompression(true);
    stream.transportState().inboundHeadersReceived(headers);

    verifyNoMoreInteractions(mockListener);
    Throwable t = ((BaseTransportState) stream.transportState()).getDeframeFailedCause();
    Status failure = Status.fromThrowable(t);
    assertEquals(Status.INTERNAL.getCode(), failure.getCode());
    // assertEquals reports the actual description on mismatch and tolerates a null description,
    // unlike assertTrue(description.equals(...)).
    assertEquals(
        "unexpected deframe failed description",
        "Full stream and gRPC message encoding cannot both be set",
        failure.getDescription());
  }

  /** Headers announcing the gzip message encoding must be delivered to the listener. */
  @Test
  public void inboundHeadersReceived_acceptsGzipMessageEncoding() {
    Metadata gzipHeaders = new Metadata();
    String gzipEncoding = new Codec.Gzip().getMessageEncoding();
    gzipHeaders.put(GrpcUtil.MESSAGE_ENCODING_KEY, gzipEncoding);

    AbstractClientStream startedStream =
        new BaseAbstractClientStream(allocator, statsTraceCtx, transportTracer);
    startedStream.start(mockListener);
    startedStream.transportState().inboundHeadersReceived(gzipHeaders);

    verify(mockListener).headersRead(gzipHeaders);
  }

  /** Headers announcing the identity message encoding must be delivered to the listener. */
  @Test
  public void inboundHeadersReceived_acceptsIdentityMessageEncoding() {
    Metadata identityHeaders = new Metadata();
    String identityEncoding = Codec.Identity.NONE.getMessageEncoding();
    identityHeaders.put(GrpcUtil.MESSAGE_ENCODING_KEY, identityEncoding);

    AbstractClientStream startedStream =
        new BaseAbstractClientStream(allocator, statsTraceCtx, transportTracer);
    startedStream.start(mockListener);
    startedStream.transportState().inboundHeadersReceived(identityHeaders);

    verify(mockListener).headersRead(identityHeaders);
  }

  /**
   * An unknown gRPC message encoding must not reach the listener; instead the transport state
   * records an INTERNAL deframing failure.
   */
  @Test
  public void inboundHeadersReceived_failsOnUnrecognizedMessageEncoding() {
    Metadata unknownEncodingHeaders = new Metadata();
    unknownEncodingHeaders.put(GrpcUtil.MESSAGE_ENCODING_KEY, "not-a-real-compression-method");

    AbstractClientStream startedStream =
        new BaseAbstractClientStream(allocator, statsTraceCtx, transportTracer);
    startedStream.start(mockListener);
    startedStream.transportState().inboundHeadersReceived(unknownEncodingHeaders);

    verifyNoMoreInteractions(mockListener);
    BaseTransportState transportState = (BaseTransportState) startedStream.transportState();
    Status failure = Status.fromThrowable(transportState.getDeframeFailedCause());
    assertEquals(Status.INTERNAL.getCode(), failure.getCode());
    String description = failure.getDescription();
    assertTrue(
        "unexpected deframe failed description",
        description.startsWith("Can't find decompressor for"));
  }

  /**
   * A status reported by the transport while a message is only partially received must close
   * the listener with that same status code and description.
   */
  @Test
  public void rstStreamClosesStream() {
    AbstractClientStream startedStream =
        new BaseAbstractClientStream(allocator, statsTraceCtx, transportTracer);
    startedStream.start(mockListener);
    TransportState transportState = startedStream.transportState();

    // The application will call request when waiting for a message, which will in turn call this
    // on the transport thread.
    transportState.requestMessagesFromDeframer(1);
    // Send first byte of 2 byte message
    byte[] partialFrame = {0, 0, 0, 0, 2, 1};
    transportState.deframe(ReadableBuffers.wrap(partialFrame));
    // Simulate getting a reset
    Status resetStatus = Status.INTERNAL.withDescription("rst___stream");
    transportState.transportReportStatus(resetStatus, false /*stop delivery*/, new Metadata());

    ArgumentCaptor<Status> closedWith = ArgumentCaptor.forClass(Status.class);
    verify(mockListener)
        .closed(closedWith.capture(), any(RpcProgress.class), any(Metadata.class));
    Status reported = closedWith.getValue();
    assertSame(Status.Code.INTERNAL, reported.getCode());
    assertEquals("rst___stream", reported.getDescription());
  }

  /**
   * OK trailers arriving while a message is only partially received must close the listener
   * with an INTERNAL status describing the mid-frame end of stream.
   */
  @Test
  public void trailerOkWithTruncatedMessage() {
    AbstractClientStream startedStream =
        new BaseAbstractClientStream(allocator, statsTraceCtx, transportTracer);
    startedStream.start(mockListener);
    TransportState transportState = startedStream.transportState();

    transportState.requestMessagesFromDeframer(1);
    // Frame header announces a 2 byte message, but only one payload byte follows.
    byte[] partialFrame = {0, 0, 0, 0, 2, 1};
    transportState.deframe(ReadableBuffers.wrap(partialFrame));
    transportState.inboundTrailersReceived(new Metadata(), Status.OK);

    ArgumentCaptor<Status> closedWith = ArgumentCaptor.forClass(Status.class);
    verify(mockListener)
        .closed(closedWith.capture(), any(RpcProgress.class), any(Metadata.class));
    Status reported = closedWith.getValue();
    assertSame(Status.Code.INTERNAL, reported.getCode());
    assertEquals("Encountered end-of-stream mid-frame", reported.getDescription());
  }

  /**
   * Non-OK trailers arriving while a message is only partially received must close the listener
   * with the status carried by the trailers, not with a truncation error.
   */
  @Test
  public void trailerNotOkWithTruncatedMessage() {
    AbstractClientStream startedStream =
        new BaseAbstractClientStream(allocator, statsTraceCtx, transportTracer);
    startedStream.start(mockListener);
    TransportState transportState = startedStream.transportState();

    transportState.requestMessagesFromDeframer(1);
    // Frame header announces a 2 byte message, but only one payload byte follows.
    byte[] partialFrame = {0, 0, 0, 0, 2, 1};
    transportState.deframe(ReadableBuffers.wrap(partialFrame));
    Status trailerStatus = Status.DATA_LOSS.withDescription("data___loss");
    transportState.inboundTrailersReceived(new Metadata(), trailerStatus);

    ArgumentCaptor<Status> closedWith = ArgumentCaptor.forClass(Status.class);
    verify(mockListener)
        .closed(closedWith.capture(), any(RpcProgress.class), any(Metadata.class));
    Status reported = closedWith.getValue();
    assertSame(Status.Code.DATA_LOSS, reported.getCode());
    assertEquals("data___loss", reported.getDescription());
  }
  
  /**
   * For a stream created with {@code useGet == true}, writing a message must not emit headers
   * right away: headers (carrying a non-null payload) are written on half-close, no frame is
   * ever written, and the tracer still sees the single one-byte outbound message.
   */
  @Test
  public void getRequest() {
    AbstractClientStream.Sink sink = mock(AbstractClientStream.Sink.class);
    final TestClientStreamTracer tracer = new TestClientStreamTracer();
    StatsTraceContext statsTraceCtx = new StatsTraceContext(new StreamTracer[]{tracer});
    AbstractClientStream stream = new BaseAbstractClientStream(
        allocator,
        new BaseTransportState(statsTraceCtx, transportTracer),
        sink,
        statsTraceCtx,
        transportTracer,
        true);
    stream.start(mockListener);
    stream.writeMessage(new ByteArrayInputStream(new byte[1]));
    // writeHeaders will be delayed since we're sending a GET request.
    verify(sink, never()).writeHeaders(any(Metadata.class), any(byte[].class));
    // halfClose will trigger writeHeaders.
    stream.halfClose();
    ArgumentCaptor<byte[]> payloadCaptor = ArgumentCaptor.forClass(byte[].class);
    verify(sink).writeHeaders(any(Metadata.class), payloadCaptor.capture());
    // Truth's isNotNull() produces a descriptive failure, unlike assertTrue(x != null).
    assertThat(payloadCaptor.getValue()).isNotNull();
    // GET requests don't have BODY.
    verify(sink, never())
        .writeFrame(
            any(WritableBuffer.class), any(Boolean.class), any(Boolean.class), any(Integer.class));
    assertThat(tracer.nextOutboundEvent()).isEqualTo("outboundMessage(0)");
    assertThat(tracer.nextOutboundEvent()).matches("outboundMessageSent\\(0, [0-9]+, [0-9]+\\)");
    assertNull(tracer.nextOutboundEvent());
    assertNull(tracer.nextInboundEvent());
    assertEquals(1, tracer.getOutboundWireSize());
    assertEquals(1, tracer.getOutboundUncompressedSize());
  }

  /**
   * {@code writeMessage} must close the input stream it was given. The stream is a Mockito
   * mock delegating to a real one-byte {@link ByteArrayInputStream} so {@code close()} can be
   * verified.
   */
  @Test
  public void writeMessage_closesStream() throws IOException {
    AbstractClientStream.Sink mockSink = mock(AbstractClientStream.Sink.class);
    final TestClientStreamTracer tracer = new TestClientStreamTracer();
    StatsTraceContext tracingCtx = new StatsTraceContext(new StreamTracer[] {tracer});
    BaseTransportState transportState = new BaseTransportState(tracingCtx, transportTracer);
    AbstractClientStream stream = new BaseAbstractClientStream(
        allocator, transportState, mockSink, tracingCtx, transportTracer, true);
    stream.start(mockListener);

    ByteArrayInputStream oneByteMessage = new ByteArrayInputStream(new byte[1]);
    InputStream observedInput = mock(InputStream.class, delegatesTo(oneByteMessage));
    stream.writeMessage(observedInput);

    verify(observedInput).close();
  }

  /**
   * A deadline set before {@code start} must show up in the headers written to the sink under
   * {@code GrpcUtil.TIMEOUT_KEY}, with a value below one second and above 600 milliseconds
   * (both expressed in nanoseconds).
   */
  @Test
  public void deadlineTimeoutPopulatedToHeaders() {
    AbstractClientStream.Sink mockSink = mock(AbstractClientStream.Sink.class);
    BaseTransportState transportState = new BaseTransportState(statsTraceCtx, transportTracer);
    ClientStream stream = new BaseAbstractClientStream(
        allocator, transportState, mockSink, statsTraceCtx, transportTracer);

    stream.setDeadline(Deadline.after(1, TimeUnit.SECONDS));
    stream.start(mockListener);

    ArgumentCaptor<Metadata> writtenHeaders = ArgumentCaptor.forClass(Metadata.class);
    verify(mockSink).writeHeaders(writtenHeaders.capture(), any(byte[].class));

    Metadata headers = writtenHeaders.getValue();
    assertTrue(headers.containsKey(GrpcUtil.TIMEOUT_KEY));
    long timeoutNanos = headers.get(GrpcUtil.TIMEOUT_KEY).longValue();
    assertThat(timeoutNanos).isLessThan(TimeUnit.SECONDS.toNanos(1));
    assertThat(timeoutNanos).isGreaterThan(TimeUnit.MILLISECONDS.toNanos(600));
  }

  /**
   * Minimal concrete {@link AbstractClientStream} for testing. It is constructed with empty
   * headers and {@code CallOptions.DEFAULT}, hands back the transport state and sink it was
   * given, and implements the remaining stream operations as no-ops.
   */
  private static class BaseAbstractClientStream extends AbstractClientStream {
    private final TransportState transportState;
    private final Sink streamSink;

    /** Primary constructor; every other constructor funnels into this one. */
    public BaseAbstractClientStream(
        WritableBufferAllocator allocator,
        TransportState state,
        Sink sink,
        StatsTraceContext statsTraceCtx,
        TransportTracer transportTracer,
        boolean useGet) {
      super(allocator, statsTraceCtx, transportTracer, new Metadata(), CallOptions.DEFAULT, useGet);
      this.transportState = state;
      this.streamSink = sink;
    }

    /** Creates a stream with the given state and sink and {@code useGet} disabled. */
    public BaseAbstractClientStream(
        WritableBufferAllocator allocator,
        TransportState state,
        Sink sink,
        StatsTraceContext statsTraceCtx,
        TransportTracer transportTracer) {
      this(allocator, state, sink, statsTraceCtx, transportTracer, false);
    }

    /** Creates a stream backed by a new {@link BaseTransportState} and a new {@link BaseSink}. */
    public BaseAbstractClientStream(
        WritableBufferAllocator allocator,
        StatsTraceContext statsTraceCtx,
        TransportTracer transportTracer) {
      this(
          allocator,
          new BaseTransportState(statsTraceCtx, transportTracer),
          new BaseSink(),
          statsTraceCtx,
          transportTracer);
    }

    @Override
    public TransportState transportState() {
      return transportState;
    }

    @Override
    protected Sink abstractClientStreamSink() {
      return streamSink;
    }

    @Override
    public Attributes getAttributes() {
      return Attributes.EMPTY;
    }

    @Override
    public void setAuthority(String authority) {
      // Intentionally a no-op.
    }

    @Override
    public void setMaxInboundMessageSize(int maxSize) {
      // Intentionally a no-op.
    }

    @Override
    public void setMaxOutboundMessageSize(int maxSize) {
      // Intentionally a no-op.
    }
  }

  /**
   * {@link AbstractClientStream.Sink} whose operations all do nothing. Tests subclass it to
   * override just the callback they care about.
   */
  private static class BaseSink implements AbstractClientStream.Sink {
    @Override
    public void writeHeaders(Metadata headers, byte[] payload) {
      // Intentionally a no-op.
    }

    @Override
    public void writeFrame(
        WritableBuffer frame, boolean endOfStream, boolean flush, int numMessages) {
      // Intentionally a no-op.
    }

    @Override
    public void request(int numMessages) {
      // Intentionally a no-op.
    }

    @Override
    public void cancel(Status reason) {
      // Intentionally a no-op.
    }
  }

  /**
   * Transport state for testing. It remembers the cause passed to {@code deframeFailed}, ignores
   * {@code bytesRead}, and runs transport-thread tasks inline on the calling thread.
   */
  private static class BaseTransportState extends AbstractClientStream.TransportState {
    private Throwable recordedDeframeFailure;

    public BaseTransportState(StatsTraceContext statsTraceCtx, TransportTracer transportTracer) {
      super(DEFAULT_MAX_MESSAGE_SIZE, statsTraceCtx, transportTracer);
    }

    /** Returns the cause recorded by {@code deframeFailed}, or {@code null} if none was. */
    private Throwable getDeframeFailedCause() {
      return recordedDeframeFailure;
    }

    @Override
    public void deframeFailed(Throwable cause) {
      // A second failure report would overwrite the first one, so treat it as a test error.
      assertNull("deframeFailed already called", recordedDeframeFailure);
      recordedDeframeFailure = cause;
    }

    @Override
    public void runOnTransportThread(Runnable r) {
      // There is no separate transport thread in these tests; execute immediately.
      r.run();
    }

    @Override
    public void bytesRead(int processedBytes) {
      // Intentionally a no-op.
    }
  }
}