/*
 * Copyright 2015 The gRPC Authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package io.grpc.internal;

import static com.google.common.base.Charsets.UTF_8;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.isA;
import static org.mockito.Matchers.same;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

import com.google.common.io.CharStreams;
import io.grpc.CompressorRegistry;
import io.grpc.Context;
import io.grpc.DecompressorRegistry;
import io.grpc.InternalChannelz.ServerStats;
import io.grpc.Metadata;
import io.grpc.MethodDescriptor;
import io.grpc.MethodDescriptor.Marshaller;
import io.grpc.MethodDescriptor.MethodType;
import io.grpc.ServerCall;
import io.grpc.Status;
import io.grpc.internal.ServerCallImpl.ServerStreamListenerImpl;
import io.grpc.internal.testing.SingleMessageProducer;
import java.io.ByteArrayInputStream;
import java.io.InputStream;
import java.io.InputStreamReader;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
import org.mockito.ArgumentCaptor;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;

/**
 * Unit tests for {@link ServerCallImpl} and its nested {@link ServerStreamListenerImpl}.
 *
 * <p>The underlying {@link ServerStream} and the application {@link ServerCall.Listener} are
 * Mockito mocks, so each test checks only how the call forwards to (or shields) those
 * collaborators.
 */
@RunWith(JUnit4.class)
public class ServerCallImplTest {
  @Rule public final ExpectedException thrown = ExpectedException.none();
  @Mock private ServerStream stream;
  @Mock private ServerCall.Listener<Long> callListener;

  private final CallTracer serverCallTracer = CallTracer.getDefaultFactory().create();
  private ServerCallImpl<Long, Long> call;
  private Context.CancellableContext context;

  /** Unary method descriptor; the server is expected to send exactly one response. */
  private static final MethodDescriptor<Long, Long> UNARY_METHOD =
      MethodDescriptor.<Long, Long>newBuilder()
          .setType(MethodType.UNARY)
          .setFullMethodName("service/method")
          .setRequestMarshaller(new LongMarshaller())
          .setResponseMarshaller(new LongMarshaller())
          .build();

  /**
   * Client-streaming method descriptor. It must be typed {@code CLIENT_STREAMING} (not
   * {@code UNARY}) so that the {@code *_clientStreaming} tests exercise a different method type
   * than their {@code *_unary} counterparts.
   */
  private static final MethodDescriptor<Long, Long> CLIENT_STREAMING_METHOD =
      MethodDescriptor.<Long, Long>newBuilder()
          .setType(MethodType.CLIENT_STREAMING)
          .setFullMethodName("service/method")
          .setRequestMarshaller(new LongMarshaller())
          .setResponseMarshaller(new LongMarshaller())
          .build();

  private final Metadata requestHeaders = new Metadata();

  /** Initializes the mocks, a fresh cancellable context and the default unary call under test. */
  @Before
  public void setUp() {
    MockitoAnnotations.initMocks(this);
    context = Context.ROOT.withCancellation();
    call = newCall(UNARY_METHOD, serverCallTracer);
  }

  /**
   * Builds a {@link ServerCallImpl} over the mocked stream, the shared request headers and the
   * per-test context, using the default compressor and decompressor registries.
   *
   * @param method descriptor of the method the call serves
   * @param tracer tracer that the call reports its statistics to
   */
  private ServerCallImpl<Long, Long> newCall(
      MethodDescriptor<Long, Long> method, CallTracer tracer) {
    return new ServerCallImpl<Long, Long>(stream, method, requestHeaders, context,
        DecompressorRegistry.getDefaultInstance(), CompressorRegistry.getDefaultInstance(),
        tracer);
  }

  /** Builds a stream listener that bridges the current {@link #call} to the mocked listener. */
  private ServerStreamListenerImpl<Long> newStreamListener() {
    return new ServerCallImpl.ServerStreamListenerImpl<Long>(call, callListener, context);
  }

  @Test
  public void callTracer_success() {
    callTracer0(Status.OK);
  }

  @Test
  public void callTracer_failure() {
    callTracer0(Status.UNKNOWN);
  }

  /**
   * Runs a complete call against a fresh tracer and checks the recorded statistics: nothing is
   * recorded before the call exists, one started call afterwards, and one succeeded or failed
   * call depending on {@code status}.
   *
   * @param status status the call is closed with
   */
  private void callTracer0(Status status) {
    CallTracer tracer = CallTracer.getDefaultFactory().create();
    ServerStats.Builder beforeBuilder = new ServerStats.Builder();
    tracer.updateBuilder(beforeBuilder);
    ServerStats before = beforeBuilder.build();
    assertEquals(0, before.callsStarted);
    assertEquals(0, before.lastCallStartedNanos);

    call = newCall(UNARY_METHOD, tracer);

    // required boilerplate
    call.sendHeaders(new Metadata());
    call.sendMessage(123L);
    // end: required boilerplate

    call.close(status, new Metadata());
    ServerStats.Builder afterBuilder = new ServerStats.Builder();
    tracer.updateBuilder(afterBuilder);
    ServerStats after = afterBuilder.build();
    assertEquals(1, after.callsStarted);
    if (status.isOk()) {
      assertEquals(1, after.callsSucceeded);
    } else {
      assertEquals(1, after.callsFailed);
    }
  }

  /** request() is forwarded to the stream with the same message count. */
  @Test
  public void request() {
    call.request(10);

    verify(stream).request(10);
  }

  /** The first sendHeaders() writes the given headers to the stream. */
  @Test
  public void sendHeader_firstCall() {
    Metadata headers = new Metadata();

    call.sendHeaders(headers);

    verify(stream).writeHeaders(headers);
  }

  @Test
  public void sendHeader_failsOnSecondCall() {
    call.sendHeaders(new Metadata());
    thrown.expect(IllegalStateException.class);
    thrown.expectMessage("sendHeaders has already been called");

    call.sendHeaders(new Metadata());
  }

  @Test
  public void sendHeader_failsOnClosed() {
    call.close(Status.CANCELLED, new Metadata());

    thrown.expect(IllegalStateException.class);
    thrown.expectMessage("call is closed");

    call.sendHeaders(new Metadata());
  }

  /** A message sent after the headers is written to the stream and flushed. */
  @Test
  public void sendMessage() {
    call.sendHeaders(new Metadata());
    call.sendMessage(1234L);

    verify(stream).writeMessage(isA(InputStream.class));
    verify(stream).flush();
  }

  @Test
  public void sendMessage_failsOnClosed() {
    call.sendHeaders(new Metadata());
    call.close(Status.CANCELLED, new Metadata());

    thrown.expect(IllegalStateException.class);
    thrown.expectMessage("call is closed");

    call.sendMessage(1234L);
  }

  @Test
  public void sendMessage_failsIfheadersUnsent() {
    thrown.expect(IllegalStateException.class);
    thrown.expectMessage("sendHeaders has not been called");

    call.sendMessage(1234L);
  }

  /** If writing the message throws, the call closes the stream instead of propagating. */
  @Test
  public void sendMessage_closesOnFailure() {
    call.sendHeaders(new Metadata());
    doThrow(new RuntimeException("bad")).when(stream).writeMessage(isA(InputStream.class));

    call.sendMessage(1234L);

    verify(stream).close(isA(Status.class), isA(Metadata.class));
  }

  @Test
  public void sendMessage_serverSendsOne_closeOnSecondCall_unary() {
    sendMessage_serverSendsOne_closeOnSecondCall(UNARY_METHOD);
  }

  @Test
  public void sendMessage_serverSendsOne_closeOnSecondCall_clientStreaming() {
    sendMessage_serverSendsOne_closeOnSecondCall(CLIENT_STREAMING_METHOD);
  }

  /**
   * For a method whose server sends a single response, the first message is written normally and
   * a second one is dropped while the stream is cancelled with INTERNAL / TOO_MANY_RESPONSES.
   *
   * @param method descriptor of the single-response method under test
   */
  private void sendMessage_serverSendsOne_closeOnSecondCall(
      MethodDescriptor<Long, Long> method) {
    ServerCallImpl<Long, Long> serverCall = newCall(method, serverCallTracer);
    serverCall.sendHeaders(new Metadata());
    serverCall.sendMessage(1L);
    verify(stream, times(1)).writeMessage(any(InputStream.class));
    verify(stream, never()).close(any(Status.class), any(Metadata.class));

    // trying to send a second message causes gRPC to cancel the underlying stream
    serverCall.sendMessage(1L);
    // still exactly one write: the second message never reaches the stream
    verify(stream, times(1)).writeMessage(any(InputStream.class));
    ArgumentCaptor<Status> statusCaptor = ArgumentCaptor.forClass(Status.class);
    verify(stream, times(1)).cancel(statusCaptor.capture());
    assertEquals(Status.Code.INTERNAL, statusCaptor.getValue().getCode());
    assertEquals(ServerCallImpl.TOO_MANY_RESPONSES, statusCaptor.getValue().getDescription());
  }

  @Test
  public void sendMessage_serverSendsOne_closeOnSecondCall_appRunToCompletion_unary() {
    sendMessage_serverSendsOne_closeOnSecondCall_appRunToCompletion(UNARY_METHOD);
  }

  @Test
  public void sendMessage_serverSendsOne_closeOnSecondCall_appRunToCompletion_clientStreaming() {
    sendMessage_serverSendsOne_closeOnSecondCall_appRunToCompletion(CLIENT_STREAMING_METHOD);
  }

  /**
   * After the stream has been cancelled for sending too many responses, the application may keep
   * sending and may close once without error; only a second close() is rejected.
   *
   * @param method descriptor of the single-response method under test
   */
  private void sendMessage_serverSendsOne_closeOnSecondCall_appRunToCompletion(
      MethodDescriptor<Long, Long> method) {
    ServerCallImpl<Long, Long> serverCall = newCall(method, serverCallTracer);
    serverCall.sendHeaders(new Metadata());
    serverCall.sendMessage(1L);
    serverCall.sendMessage(1L);
    verify(stream, times(1)).writeMessage(any(InputStream.class));
    verify(stream, times(1)).cancel(any(Status.class));

    // App runs to completion but everything is ignored
    serverCall.sendMessage(1L);
    serverCall.close(Status.OK, new Metadata());
    try {
      serverCall.close(Status.OK, new Metadata());
      fail("calling a second time should still cause an error");
    } catch (IllegalStateException expected) {
      // noop
    }
  }

  @Test
  public void serverSendsOne_okFailsOnMissingResponse_unary() {
    serverSendsOne_okFailsOnMissingResponse(UNARY_METHOD);
  }

  @Test
  public void serverSendsOne_okFailsOnMissingResponse_clientStreaming() {
    serverSendsOne_okFailsOnMissingResponse(CLIENT_STREAMING_METHOD);
  }

  /**
   * Closing a single-response call with OK before any message was sent cancels the stream with
   * INTERNAL / MISSING_RESPONSE.
   *
   * @param method descriptor of the single-response method under test
   */
  private void serverSendsOne_okFailsOnMissingResponse(
      MethodDescriptor<Long, Long> method) {
    ServerCallImpl<Long, Long> serverCall = newCall(method, serverCallTracer);
    serverCall.close(Status.OK, new Metadata());
    ArgumentCaptor<Status> statusCaptor = ArgumentCaptor.forClass(Status.class);
    verify(stream, times(1)).cancel(statusCaptor.capture());
    assertEquals(Status.Code.INTERNAL, statusCaptor.getValue().getCode());
    assertEquals(ServerCallImpl.MISSING_RESPONSE, statusCaptor.getValue().getDescription());
  }

  /** A non-OK close without a response passes the same status and trailers to the stream. */
  @Test
  public void serverSendsOne_canErrorWithoutResponse() {
    final String description = "test description";
    final Status status = Status.RESOURCE_EXHAUSTED.withDescription(description);
    final Metadata metadata = new Metadata();
    call.close(status, metadata);
    verify(stream, times(1)).close(same(status), same(metadata));
  }

  @Test
  public void isReady() {
    when(stream.isReady()).thenReturn(true);

    assertTrue(call.isReady());
  }

  @Test
  public void getAuthority() {
    when(stream.getAuthority()).thenReturn("fooapi.googleapis.com");
    assertEquals("fooapi.googleapis.com", call.getAuthority());
    verify(stream).getAuthority();
  }

  @Test
  public void getNullAuthority() {
    when(stream.getAuthority()).thenReturn(null);
    assertNull(call.getAuthority());
    verify(stream).getAuthority();
  }

  @Test
  public void setMessageCompression() {
    call.setMessageCompression(true);

    verify(stream).setMessageCompression(true);
  }

  @Test
  public void streamListener_halfClosed() {
    ServerStreamListenerImpl<Long> streamListener = newStreamListener();

    streamListener.halfClosed();

    verify(callListener).onHalfClose();
  }

  @Test
  public void streamListener_halfClosed_onlyOnce() {
    ServerStreamListenerImpl<Long> streamListener = newStreamListener();
    streamListener.halfClosed();
    // canceling the call should short circuit future halfClosed() calls.
    streamListener.closed(Status.CANCELLED);

    streamListener.halfClosed();

    // verify() defaults to times(1): only the pre-cancellation call reached the listener
    verify(callListener).onHalfClose();
  }

  /** Closing with OK notifies onComplete() and cancels the context without a cause. */
  @Test
  public void streamListener_closedOk() {
    ServerStreamListenerImpl<Long> streamListener = newStreamListener();

    streamListener.closed(Status.OK);

    verify(callListener).onComplete();
    assertTrue(context.isCancelled());
    assertNull(context.cancellationCause());
  }

  /** Closing with CANCELLED notifies onCancel() and cancels the context without a cause. */
  @Test
  public void streamListener_closedCancelled() {
    ServerStreamListenerImpl<Long> streamListener = newStreamListener();

    streamListener.closed(Status.CANCELLED);

    verify(callListener).onCancel();
    assertTrue(context.isCancelled());
    assertNull(context.cancellationCause());
  }

  @Test
  public void streamListener_onReady() {
    ServerStreamListenerImpl<Long> streamListener = newStreamListener();

    streamListener.onReady();

    verify(callListener).onReady();
  }

  @Test
  public void streamListener_onReady_onlyOnce() {
    ServerStreamListenerImpl<Long> streamListener = newStreamListener();
    streamListener.onReady();
    // canceling the call should short circuit future onReady() calls.
    streamListener.closed(Status.CANCELLED);

    streamListener.onReady();

    // verify() defaults to times(1): only the pre-cancellation call reached the listener
    verify(callListener).onReady();
  }

  @Test
  public void streamListener_messageRead() {
    ServerStreamListenerImpl<Long> streamListener = newStreamListener();
    streamListener.messagesAvailable(new SingleMessageProducer(UNARY_METHOD.streamRequest(1234L)));

    verify(callListener).onMessage(1234L);
  }

  @Test
  public void streamListener_messageRead_onlyOnce() {
    ServerStreamListenerImpl<Long> streamListener = newStreamListener();
    streamListener.messagesAvailable(new SingleMessageProducer(UNARY_METHOD.streamRequest(1234L)));
    // canceling the call should short circuit future messagesAvailable() calls.
    streamListener.closed(Status.CANCELLED);

    streamListener.messagesAvailable(new SingleMessageProducer(UNARY_METHOD.streamRequest(1234L)));

    // verify() defaults to times(1): only the pre-cancellation message reached the listener
    verify(callListener).onMessage(1234L);
  }

  /** A RuntimeException thrown by the application listener propagates out of the callback. */
  @Test
  public void streamListener_unexpectedRuntimeException() {
    ServerStreamListenerImpl<Long> streamListener = newStreamListener();
    doThrow(new RuntimeException("unexpected exception"))
        .when(callListener)
        .onMessage(any(Long.class));

    InputStream inputStream = UNARY_METHOD.streamRequest(1234L);

    thrown.expect(RuntimeException.class);
    thrown.expectMessage("unexpected exception");
    streamListener.messagesAvailable(new SingleMessageProducer(inputStream));
  }

  /** Marshaller that encodes a {@link Long} as the UTF-8 bytes of its decimal string form. */
  private static class LongMarshaller implements Marshaller<Long> {
    @Override
    public InputStream stream(Long value) {
      return new ByteArrayInputStream(value.toString().getBytes(UTF_8));
    }

    /**
     * Reads the whole stream as UTF-8 text and parses it as a long.
     *
     * @throws RuntimeException wrapping any read or number-format failure
     */
    @Override
    public Long parse(InputStream stream) {
      try {
        return Long.parseLong(CharStreams.toString(new InputStreamReader(stream, UTF_8)));
      } catch (Exception e) {
        throw new RuntimeException(e);
      }
    }
  }
}