/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.codahale.grpcproxy;

import com.codahale.grpcproxy.helloworld.GreeterGrpc;
import com.codahale.grpcproxy.helloworld.HelloRequest;
import com.codahale.grpcproxy.stats.Recorder;
import com.codahale.grpcproxy.stats.Snapshot;
import com.codahale.grpcproxy.util.Netty;
import com.codahale.grpcproxy.util.TlsContext;
import io.airlift.airline.Command;
import io.airlift.airline.Option;
import io.grpc.ManagedChannel;
import io.grpc.StatusRuntimeException;
import io.grpc.netty.NettyChannelBuilder;
import io.netty.channel.EventLoopGroup;
import java.time.Duration;
import java.time.Instant;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import javax.net.ssl.SSLException;
import net.logstash.logback.marker.Markers;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/** A gRPC client. This could be in any language. */
class HelloWorldClient {

  private static final Logger LOGGER = LoggerFactory.getLogger(HelloWorldClient.class);
  private final EventLoopGroup eventLoopGroup;
  private final ManagedChannel channel;
  private final GreeterGrpc.GreeterBlockingStub blockingStub;

  private HelloWorldClient(String host, int port, TlsContext tls) throws SSLException {
    this.eventLoopGroup = Netty.newWorkerEventLoopGroup();
    this.channel =
        NettyChannelBuilder.forAddress(host, port)
            .eventLoopGroup(eventLoopGroup)
            .channelType(Netty.clientChannelType())
            .sslContext(tls.toClientContext())
            .build();
    this.blockingStub = GreeterGrpc.newBlockingStub(channel);
  }

  private void shutdown() throws InterruptedException {
    channel.shutdown().awaitTermination(5, TimeUnit.SECONDS);
    eventLoopGroup.shutdownGracefully(0, 0, TimeUnit.SECONDS);
  }

  private String greet(int i) {
    try {
      final HelloRequest request = HelloRequest.newBuilder().setName("world " + i).build();
      return blockingStub.sayHello(request).getMessage();
    } catch (StatusRuntimeException e) {
      LOGGER.warn("RPC failed: {}", e.getStatus());
      return null;
    }
  }

  @Command(name = "client", description = "Runs a bunch of HelloWorld client calls.")
  public static class Cmd implements Runnable {

    @Option(
      name = {"-h", "--hostname"},
      description = "the hostname of the gRPC server"
    )
    private String hostname = "localhost";

    @Option(
      name = {"-p", "--port"},
      description = "the port of the gRPC server"
    )
    private int port = 50051;

    @Option(
      name = {"-n", "--requests"},
      description = "the number of requests to make"
    )
    private int requests = 1_000_000;

    @Option(
      name = {"-c", "--threads"},
      description = "the number of threads to use"
    )
    private int threads = 10;

    @Option(name = "--ca-certs")
    private String trustedCertsPath = "cert.crt";

    @Option(name = "--cert")
    private String certPath = "cert.crt";

    @Option(name = "--key")
    private String keyPath = "cert.key";

    @Override
    public void run() {
      try {
        final TlsContext tls = new TlsContext(trustedCertsPath, certPath, keyPath);
        final HelloWorldClient client = new HelloWorldClient(hostname, port, tls);
        try {
          final Recorder recorder =
              new Recorder(
                  500,
                  TimeUnit.MINUTES.toMicros(1),
                  TimeUnit.MILLISECONDS.toMicros(10),
                  TimeUnit.MICROSECONDS);
          LOGGER.info("Initial request: {}", client.greet(requests));
          LOGGER.info("Sending {} requests from {} threads", requests, threads);

          final ExecutorService threadPool = Executors.newFixedThreadPool(threads);
          final Instant start = Instant.now();
          for (int i = 0; i < threads; i++) {
            threadPool.execute(
                () -> {
                  for (int j = 0; j < requests / threads; j++) {
                    final long t = System.nanoTime();
                    client.greet(j);
                    recorder.record(t);
                  }
                });
          }
          threadPool.shutdown();
          threadPool.awaitTermination(20, TimeUnit.MINUTES);

          final Snapshot stats = recorder.interval();
          final Duration duration = Duration.between(start, Instant.now());
          LOGGER.info(
              Markers.append("stats", stats).and(Markers.append("duration", duration.toString())),
              "{} requests in {} ({} req/sec)",
              stats.count(),
              duration,
              stats.throughput());
        } finally {
          client.shutdown();
        }
      } catch (SSLException | InterruptedException e) {
        LOGGER.error("Error running command", e);
      }
    }
  }
}