/*
 * Copyright 2018 The gRPC Authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package io.grpc.alts.internal;

import static com.google.common.base.Preconditions.checkNotNull;
import static com.google.common.truth.Truth.assertThat;
import static java.nio.charset.StandardCharsets.UTF_8;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertTrue;

import io.grpc.Attributes;
import io.grpc.Channel;
import io.grpc.Grpc;
import io.grpc.InternalChannelz;
import io.grpc.ManagedChannel;
import io.grpc.SecurityLevel;
import io.grpc.alts.internal.AltsProtocolNegotiator.LazyChannel;
import io.grpc.alts.internal.AltsProtocolNegotiator.ServerAltsProtocolNegotiator;
import io.grpc.alts.internal.TsiFrameProtector.Consumer;
import io.grpc.alts.internal.TsiPeer.Property;
import io.grpc.internal.FixedObjectPool;
import io.grpc.internal.GrpcAttributes;
import io.grpc.internal.ObjectPool;
import io.grpc.netty.GrpcHttp2ConnectionHandler;
import io.grpc.netty.InternalProtocolNegotiationEvent;
import io.grpc.netty.NettyChannelBuilder;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufAllocator;
import io.netty.buffer.CompositeByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelDuplexHandler;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.ChannelPromise;
import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.handler.codec.http2.DefaultHttp2Connection;
import io.netty.handler.codec.http2.DefaultHttp2ConnectionDecoder;
import io.netty.handler.codec.http2.DefaultHttp2ConnectionEncoder;
import io.netty.handler.codec.http2.DefaultHttp2FrameReader;
import io.netty.handler.codec.http2.DefaultHttp2FrameWriter;
import io.netty.handler.codec.http2.Http2Connection;
import io.netty.handler.codec.http2.Http2ConnectionDecoder;
import io.netty.handler.codec.http2.Http2ConnectionEncoder;
import io.netty.handler.codec.http2.Http2FrameReader;
import io.netty.handler.codec.http2.Http2FrameWriter;
import io.netty.handler.codec.http2.Http2Settings;
import io.netty.util.ReferenceCountUtil;
import io.netty.util.ReferenceCounted;
import java.nio.ByteBuffer;
import java.security.GeneralSecurityException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.Future;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;

/** Tests for {@link AltsProtocolNegotiator}. */
@RunWith(JUnit4.class)
@SuppressWarnings("FutureReturnValueIgnored")
public class AltsProtocolNegotiatorTest {

  private final CapturingGrpcHttp2ConnectionHandler grpcHandler = capturingGrpcHandler();

  private final List<ReferenceCounted> references = new ArrayList<>();
  private final LinkedBlockingQueue<InterceptingProtector> protectors = new LinkedBlockingQueue<>();

  private EmbeddedChannel channel;
  private Throwable caughtException;

  private TsiPeer mockedTsiPeer = new TsiPeer(Collections.<Property<?>>emptyList());
  private AltsAuthContext mockedAltsContext =
      new AltsAuthContext(
          HandshakerResult.newBuilder()
              .setPeerRpcVersions(RpcProtocolVersionsUtil.getRpcProtocolVersions())
              .build());
  private final TsiHandshaker mockHandshaker =
      new DelegatingTsiHandshaker(FakeTsiHandshaker.newFakeHandshakerServer()) {
        @Override
        public TsiPeer extractPeer() {
          return mockedTsiPeer;
        }

        @Override
        public Object extractPeerObject() {
          return mockedAltsContext;
        }
      };
  private final NettyTsiHandshaker serverHandshaker = new NettyTsiHandshaker(mockHandshaker);

  @Before
  public void setup() throws Exception {
    ChannelHandler uncaughtExceptionHandler =
        new ChannelDuplexHandler() {
          @Override
          public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
            caughtException = cause;
            super.exceptionCaught(ctx, cause);
            ctx.close();
          }
        };

    TsiHandshakerFactory handshakerFactory =
        new DelegatingTsiHandshakerFactory(FakeTsiHandshaker.clientHandshakerFactory()) {
          @Override
          public TsiHandshaker newHandshaker(String authority) {
            return new DelegatingTsiHandshaker(super.newHandshaker(authority)) {
              @Override
              public TsiPeer extractPeer() throws GeneralSecurityException {
                return mockedTsiPeer;
              }

              @Override
              public Object extractPeerObject() throws GeneralSecurityException {
                return mockedAltsContext;
              }
            };
          }
        };
    ManagedChannel fakeChannel = NettyChannelBuilder.forTarget("localhost:8080").build();
    ObjectPool<Channel> fakeChannelPool = new FixedObjectPool<Channel>(fakeChannel);
    LazyChannel lazyFakeChannel = new LazyChannel(fakeChannelPool);
    ChannelHandler altsServerHandler = new ServerAltsProtocolNegotiator(
        handshakerFactory, lazyFakeChannel)
        .newHandler(grpcHandler);
    // On real server, WBAEH fires default ProtocolNegotiationEvent. KickNH provides this behavior.
    ChannelHandler handler = new KickNegotiationHandler(altsServerHandler);
    channel = new EmbeddedChannel(uncaughtExceptionHandler, handler);
  }

  @After
  public void teardown() throws Exception {
    if (channel != null) {
      @SuppressWarnings("unused") // go/futurereturn-lsc
      Future<?> possiblyIgnoredError = channel.close();
    }

    for (ReferenceCounted reference : references) {
      ReferenceCountUtil.safeRelease(reference);
    }
  }

  @Test
  public void handshakeShouldBeSuccessful() throws Exception {
    doHandshake();
  }

  @Test
  @SuppressWarnings("unchecked") // List cast
  public void protectShouldRoundtrip() throws Exception {
    doHandshake();

    // Write the message 1 character at a time. The message should be buffered
    // and not interfere with the handshake.
    final AtomicInteger writeCount = new AtomicInteger();
    String message = "hello";
    for (int ix = 0; ix < message.length(); ++ix) {
      ByteBuf in = Unpooled.copiedBuffer(message, ix, 1, UTF_8);
      @SuppressWarnings("unused") // go/futurereturn-lsc
      Future<?> possiblyIgnoredError =
          channel
              .write(in)
              .addListener(
                  new ChannelFutureListener() {
                    @Override
                    public void operationComplete(ChannelFuture future) throws Exception {
                      if (future.isSuccess()) {
                        writeCount.incrementAndGet();
                      }
                    }
                  });
    }
    channel.flush();

    // Capture the protected data written to the wire.
    assertEquals(1, channel.outboundMessages().size());
    ByteBuf protectedData = channel.readOutbound();
    assertEquals(message.length(), writeCount.get());

    // Read the protected message at the server and verify it matches the original message.
    TsiFrameProtector serverProtector = serverHandshaker.createFrameProtector(channel.alloc());
    List<ByteBuf> unprotected = new ArrayList<>();
    serverProtector.unprotect(protectedData, (List<Object>) (List<?>) unprotected, channel.alloc());
    // We try our best to remove the HTTP2 handler as soon as possible, but just by constructing it
    // a settings frame is written (and an HTTP2 preface).  This is hard coded into Netty, so we
    // have to remove it here.  See {@code Http2ConnectionHandler.PrefaceDecode.sendPreface}.
    int settingsFrameLength = 9;

    CompositeByteBuf unprotectedAll =
        new CompositeByteBuf(channel.alloc(), false, unprotected.size() + 1, unprotected);
    ByteBuf unprotectedData = unprotectedAll.slice(settingsFrameLength, message.length());
    assertEquals(message, unprotectedData.toString(UTF_8));

    // Protect the same message at the server.
    final AtomicReference<ByteBuf> newlyProtectedData = new AtomicReference<>();
    serverProtector.protectFlush(
        Collections.singletonList(unprotectedData),
        new Consumer<ByteBuf>() {
          @Override
          public void accept(ByteBuf buf) {
            newlyProtectedData.set(buf);
          }
        },
        channel.alloc());

    // Read the protected message at the client and verify that it matches the original message.
    channel.writeInbound(newlyProtectedData.get());
    assertEquals(1, channel.inboundMessages().size());
    assertEquals(message, channel.<ByteBuf>readInbound().toString(UTF_8));
  }

  @Test
  public void unprotectLargeIncomingFrame() throws Exception {

    // We use a server frameprotector with twice the standard frame size.
    int serverFrameSize = 4096 * 2;
    // This should fit into one frame.
    byte[] unprotectedBytes = new byte[serverFrameSize - 500];
    Arrays.fill(unprotectedBytes, (byte) 7);
    ByteBuf unprotectedData = Unpooled.wrappedBuffer(unprotectedBytes);
    unprotectedData.writerIndex(unprotectedBytes.length);

    // Perform handshake.
    doHandshake();

    // Protect the message on the server.
    TsiFrameProtector serverProtector =
        serverHandshaker.createFrameProtector(serverFrameSize, channel.alloc());
    serverProtector.protectFlush(
        Collections.singletonList(unprotectedData),
        new Consumer<ByteBuf>() {
          @Override
          public void accept(ByteBuf buf) {
            channel.writeInbound(buf);
          }
        },
        channel.alloc());
    channel.flushInbound();

    // Read the protected message at the client and verify that it matches the original message.
    assertEquals(1, channel.inboundMessages().size());

    ByteBuf receivedData1 = channel.readInbound();
    int receivedLen1 = receivedData1.readableBytes();
    byte[] receivedBytes = new byte[receivedLen1];
    receivedData1.readBytes(receivedBytes, 0, receivedLen1);

    assertThat(unprotectedBytes.length).isEqualTo(receivedBytes.length);
    assertThat(unprotectedBytes).isEqualTo(receivedBytes);
  }

  @Test
  public void flushShouldFailAllPromises() throws Exception {
    doHandshake();

    channel
        .pipeline()
        .addFirst(
            new ChannelDuplexHandler() {
              @Override
              public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise)
                  throws Exception {
                throw new Exception("Fake exception");
              }
            });

    // Write the message 1 character at a time.
    String message = "hello";
    final AtomicInteger failures = new AtomicInteger();
    for (int ix = 0; ix < message.length(); ++ix) {
      ByteBuf in = Unpooled.copiedBuffer(message, ix, 1, UTF_8);
      @SuppressWarnings("unused") // go/futurereturn-lsc
      Future<?> possiblyIgnoredError =
          channel
              .write(in)
              .addListener(
                  new ChannelFutureListener() {
                    @Override
                    public void operationComplete(ChannelFuture future) throws Exception {
                      if (!future.isSuccess()) {
                        failures.incrementAndGet();
                      }
                    }
                  });
    }
    channel.flush();

    // Verify that the promises fail.
    assertEquals(message.length(), failures.get());
  }

  @Test
  public void doNotFlushEmptyBuffer() throws Exception {
    doHandshake();
    assertEquals(1, protectors.size());
    InterceptingProtector protector = protectors.poll();

    String message = "hello";
    ByteBuf in = Unpooled.copiedBuffer(message, UTF_8);

    assertEquals(0, protector.flushes.get());
    Future<?> done = channel.write(in);
    channel.flush();
    done.get(5, TimeUnit.SECONDS);
    assertEquals(1, protector.flushes.get());

    done = channel.write(Unpooled.EMPTY_BUFFER);
    channel.flush();
    done.get(5, TimeUnit.SECONDS);
    assertEquals(1, protector.flushes.get());
  }

  @Test
  public void peerPropagated() throws Exception {
    doHandshake();

    assertThat(grpcHandler.attrs.get(AltsProtocolNegotiator.TSI_PEER_KEY)).isEqualTo(mockedTsiPeer);
    assertThat(grpcHandler.attrs.get(AltsProtocolNegotiator.AUTH_CONTEXT_KEY))
        .isEqualTo(mockedAltsContext);
    assertThat(grpcHandler.attrs.get(Grpc.TRANSPORT_ATTR_REMOTE_ADDR).toString())
        .isEqualTo("embedded");
    assertThat(grpcHandler.attrs.get(Grpc.TRANSPORT_ATTR_LOCAL_ADDR).toString())
        .isEqualTo("embedded");
    assertThat(grpcHandler.attrs.get(GrpcAttributes.ATTR_SECURITY_LEVEL))
        .isEqualTo(SecurityLevel.PRIVACY_AND_INTEGRITY);
  }

  private void doHandshake() throws Exception {
    // Capture the client frame and add to the server.
    assertEquals(1, channel.outboundMessages().size());
    ByteBuf clientFrame = channel.readOutbound();
    assertTrue(serverHandshaker.processBytesFromPeer(clientFrame));

    // Get the server response handshake frames.
    ByteBuf serverFrame = channel.alloc().buffer();
    serverHandshaker.getBytesToSendToPeer(serverFrame);
    channel.writeInbound(serverFrame);

    // Capture the next client frame and add to the server.
    assertEquals(1, channel.outboundMessages().size());
    clientFrame = channel.readOutbound();
    assertTrue(serverHandshaker.processBytesFromPeer(clientFrame));

    // Get the server response handshake frames.
    serverFrame = channel.alloc().buffer();
    serverHandshaker.getBytesToSendToPeer(serverFrame);
    channel.writeInbound(serverFrame);

    // Ensure that both sides have confirmed that the handshake has completed.
    assertFalse(serverHandshaker.isInProgress());

    if (caughtException != null) {
      throw new RuntimeException(caughtException);
    }
    assertNotNull(grpcHandler.attrs);
  }

  private CapturingGrpcHttp2ConnectionHandler capturingGrpcHandler() {
    // Netty Boilerplate.  We don't really need any of this, but there is a tight coupling
    // between an Http2ConnectionHandler and its dependencies.
    Http2Connection connection = new DefaultHttp2Connection(true);
    Http2FrameWriter frameWriter = new DefaultHttp2FrameWriter();
    Http2FrameReader frameReader = new DefaultHttp2FrameReader(false);
    DefaultHttp2ConnectionEncoder encoder =
        new DefaultHttp2ConnectionEncoder(connection, frameWriter);
    DefaultHttp2ConnectionDecoder decoder =
        new DefaultHttp2ConnectionDecoder(connection, encoder, frameReader);

    return new CapturingGrpcHttp2ConnectionHandler(decoder, encoder, new Http2Settings());
  }

  private final class CapturingGrpcHttp2ConnectionHandler extends GrpcHttp2ConnectionHandler {

    private Attributes attrs;

    private CapturingGrpcHttp2ConnectionHandler(
        Http2ConnectionDecoder decoder,
        Http2ConnectionEncoder encoder,
        Http2Settings initialSettings) {
      super(null, decoder, encoder, initialSettings);
    }

    @Override
    public void handleProtocolNegotiationCompleted(
        Attributes attrs,
        @SuppressWarnings("UnusedVariable") InternalChannelz.Security securityInfo) {
      // If we are added to the pipeline, we need to remove ourselves.  The HTTP2 handler
      channel.pipeline().remove(this);
      this.attrs = attrs;
    }
  }

  private static class DelegatingTsiHandshakerFactory implements TsiHandshakerFactory {

    private TsiHandshakerFactory delegate;

    DelegatingTsiHandshakerFactory(TsiHandshakerFactory delegate) {
      this.delegate = delegate;
    }

    @Override
    public TsiHandshaker newHandshaker(String authority) {
      return delegate.newHandshaker(authority);
    }
  }

  private class DelegatingTsiHandshaker implements TsiHandshaker {

    private final TsiHandshaker delegate;

    DelegatingTsiHandshaker(TsiHandshaker delegate) {
      this.delegate = delegate;
    }

    @Override
    public void getBytesToSendToPeer(ByteBuffer bytes) throws GeneralSecurityException {
      delegate.getBytesToSendToPeer(bytes);
    }

    @Override
    public boolean processBytesFromPeer(ByteBuffer bytes) throws GeneralSecurityException {
      return delegate.processBytesFromPeer(bytes);
    }

    @Override
    public boolean isInProgress() {
      return delegate.isInProgress();
    }

    @Override
    public TsiPeer extractPeer() throws GeneralSecurityException {
      return delegate.extractPeer();
    }

    @Override
    public Object extractPeerObject() throws GeneralSecurityException {
      return delegate.extractPeerObject();
    }

    @Override
    public TsiFrameProtector createFrameProtector(ByteBufAllocator alloc) {
      InterceptingProtector protector =
          new InterceptingProtector(delegate.createFrameProtector(alloc));
      protectors.add(protector);
      return protector;
    }

    @Override
    public TsiFrameProtector createFrameProtector(int maxFrameSize, ByteBufAllocator alloc) {
      InterceptingProtector protector =
          new InterceptingProtector(delegate.createFrameProtector(maxFrameSize, alloc));
      protectors.add(protector);
      return protector;
    }

    @Override
    public void close() {
      delegate.close();
    }
  }

  private static class InterceptingProtector implements TsiFrameProtector {

    private final TsiFrameProtector delegate;
    final AtomicInteger flushes = new AtomicInteger();

    InterceptingProtector(TsiFrameProtector delegate) {
      this.delegate = delegate;
    }

    @Override
    public void protectFlush(
        List<ByteBuf> unprotectedBufs, Consumer<ByteBuf> ctxWrite, ByteBufAllocator alloc)
        throws GeneralSecurityException {
      flushes.incrementAndGet();
      delegate.protectFlush(unprotectedBufs, ctxWrite, alloc);
    }

    @Override
    public void unprotect(ByteBuf in, List<Object> out, ByteBufAllocator alloc)
        throws GeneralSecurityException {
      delegate.unprotect(in, out, alloc);
    }

    @Override
    public void destroy() {
      delegate.destroy();
    }
  }

  /** Kicks off negotiation of the server. */
  private static final class KickNegotiationHandler extends ChannelInboundHandlerAdapter {

    private final ChannelHandler next;

    KickNegotiationHandler(ChannelHandler next) {
      this.next = checkNotNull(next, "next");
    }

    @Override
    public void handlerAdded(ChannelHandlerContext ctx) throws Exception {
      super.handlerAdded(ctx);
      ctx.pipeline().replace(ctx.name(), /*newName= */ null, next);
      ctx.pipeline().fireUserEventTriggered(InternalProtocolNegotiationEvent.getDefault());
    }
  }
}