/*
 * Copyright (C) 2017-2019 Dremio Corporation
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.dremio.flight;

import java.io.ByteArrayInputStream;
import java.io.File;
import java.io.FileWriter;
import java.io.IOException;
import java.io.InputStream;
import java.io.StringWriter;
import java.net.InetAddress;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.security.GeneralSecurityException;
import java.security.KeyStore;
import java.security.cert.Certificate;
import java.util.Enumeration;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.atomic.AtomicInteger;

import org.apache.arrow.flight.FlightClient;
import org.apache.arrow.flight.FlightDescriptor;
import org.apache.arrow.flight.FlightEndpoint;
import org.apache.arrow.flight.FlightInfo;
import org.apache.arrow.flight.FlightStream;
import org.apache.arrow.flight.Location;
import org.apache.arrow.flight.auth.BasicClientAuthHandler;
import org.apache.arrow.memory.BufferAllocator;
import org.bouncycastle.openssl.jcajce.JcaPEMWriter;
import org.junit.AfterClass;
import org.junit.Assert;
import org.junit.BeforeClass;
import org.junit.ClassRule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.dremio.BaseTestQuery;
import com.dremio.config.DremioConfig;
import com.dremio.exec.rpc.ssl.SSLConfigurator;
import com.dremio.service.users.SystemUser;
import com.dremio.ssl.SSLConfig;
import com.google.common.base.Preconditions;
import com.google.common.collect.Lists;

import io.protostuff.LinkedBuffer;

/**
 * Flight endpoint test over TLS.
 *
 * <p>The Flight service is initialized with the {@code dremio.flight.use-ssl} and
 * {@code dremio.flight.enabled} system properties set to {@code true} and a {@link DremioConfig}
 * that enables web SSL with an auto-generated certificate. The test then connects with a
 * TLS-enabled {@link FlightClient} that trusts the certificates found in the configured trust
 * store, runs a query and reads every returned endpoint through a small thread pool.
 */
public class TestSslFlightEndpoint extends BaseTestQuery {

  /** Port the test connects to on the local host. */
  private static final int FLIGHT_PORT = 47470;
  /** Number of worker threads used to read flight endpoints. */
  private static final int READER_THREADS = 4;

  private static FlightInitializer fi;
  private static final LinkedBuffer buffer = LinkedBuffer.allocate();
  // Created in init() and shut down in shutdown() so that a re-run of this class in the same JVM
  // gets a fresh pool and no worker threads outlive the test.
  private static ExecutorService tp;
  private static final Logger logger = LoggerFactory.getLogger(TestSslFlightEndpoint.class);
  private static DremioConfig dremioConfig;

  // Counters used only to make the debug log output easier to follow.
  private static final AtomicInteger endpointsSubmitted = new AtomicInteger();
  private static final AtomicInteger endpointsWaitingOn = new AtomicInteger();
  private static final AtomicInteger endpointsReceived = new AtomicInteger();

  @ClassRule
  public static final TemporaryFolder tempFolder = new TemporaryFolder();

  /**
   * Enables the SSL flight endpoint, binds an SSL-enabled {@link DremioConfig} whose local write
   * path is the temporary folder, and starts the {@link FlightInitializer}.
   */
  @BeforeClass
  public static void init() throws Exception {
    tp = Executors.newFixedThreadPool(READER_THREADS);
    System.setProperty("dremio.flight.use-ssl", "true");
    System.setProperty("dremio.flight.enabled", "true");
    dremioConfig = DremioConfig.create()
      .withValue(
        DremioConfig.WEB_SSL_PREFIX + DremioConfig.SSL_ENABLED,
        true)
      .withValue(
        DremioConfig.WEB_SSL_PREFIX + DremioConfig.SSL_AUTO_GENERATED_CERTIFICATE,
        true)
      .withValue(
        DremioConfig.LOCAL_WRITE_PATH_STRING, tempFolder.getRoot().getAbsolutePath());
    getBindingCreator().bind(DremioConfig.class, dremioConfig);
    fi = new FlightInitializer();
    fi.initialize(getBindingProvider());
  }

  /**
   * Closes the flight initializer and stops the reader thread pool. Both are null-guarded because
   * JUnit still invokes this method when {@link #init()} fails part-way.
   */
  @AfterClass
  public static void shutdown() throws Exception {
    try {
      if (fi != null) {
        fi.close();
      }
    } finally {
      if (tp != null) {
        tp.shutdownNow();
      }
    }
  }

  /**
   * Loads every certificate from the trust store described by the web SSL configuration, writes
   * them to {@code certs.pem} in the temporary folder and returns them as a PEM stream.
   *
   * @return PEM-encoded certificates suitable for {@code trustedCertificates(...)}
   * @throws GeneralSecurityException if the trust store cannot be created or read
   * @throws IOException if the trust store file or the PEM output cannot be read/written
   * @throws IllegalStateException if no SSL configuration or no trust store path is available
   */
  private static InputStream certs() throws GeneralSecurityException, IOException {
    final InetAddress ip = InetAddress.getLocalHost();
    final SSLConfigurator configurator = new SSLConfigurator(dremioConfig, DremioConfig.WEB_SSL_PREFIX, "web");
    final Optional<SSLConfig> sslConfigOption = configurator.getSSLConfig(true, ip.getHostName());
    Preconditions.checkState(sslConfigOption.isPresent(), "SSL configuration is not available");
    final SSLConfig sslConfig = sslConfigOption.get();

    // UNSPECIFIED is compared by identity on purpose (sentinel value). Without a trust store there
    // is nothing to read, so fail with a clear message instead of a NullPointerException.
    //noinspection StringEquality
    Preconditions.checkState(sslConfig.getTrustStorePath() != SSLConfig.UNSPECIFIED,
      "SSL configuration does not specify a trust store path");

    final KeyStore trustStore = KeyStore.getInstance(sslConfig.getTrustStoreType());
    try (InputStream stream = Files.newInputStream(Paths.get(sslConfig.getTrustStorePath()))) {
      trustStore.load(stream, sslConfig.getTrustStorePassword().toCharArray());
    }

    // Collect the certificate of every alias in the trust store.
    final List<Certificate> certs = Lists.newArrayList();
    final Enumeration<String> aliases = trustStore.aliases();
    while (aliases.hasMoreElements()) {
      certs.add(trustStore.getCertificate(aliases.nextElement()));
    }
    toFile(certs);
    return certsToStream(certs);
  }

  /**
   * Writes the given certificates in PEM format to {@code certs.pem} in the temporary folder.
   *
   * @param certificates certificates to write
   * @throws IOException if the file cannot be written
   */
  private static void toFile(List<Certificate> certificates) throws IOException {
    final File root = tempFolder.getRoot();
    final Path certFile = Paths.get(root.getAbsolutePath(), "certs.pem");
    // try-with-resources so the file handle is released even if writing a certificate fails
    try (JcaPEMWriter writer = new JcaPEMWriter(new FileWriter(certFile.toFile()))) {
      for (final Certificate c : certificates) {
        writer.writeObject(c);
      }
    }
  }

  /**
   * Encodes the given certificates as PEM and returns the result as an in-memory stream.
   *
   * @param certs certificates to encode
   * @return stream over the UTF-8 bytes of the PEM text
   * @throws IOException if PEM encoding fails
   */
  private static InputStream certsToStream(List<Certificate> certs) throws IOException {
    final StringWriter writer = new StringWriter();
    try (JcaPEMWriter pemWriter = new JcaPEMWriter(writer)) {
      for (final Certificate cert : certs) {
        pemWriter.writeObject(cert);
      }
    }
    // The PEM writer is closed (and therefore flushed) before the text is read back.
    return new ByteArrayInputStream(writer.toString().getBytes(StandardCharsets.UTF_8));
  }

  /**
   * Builds a TLS-enabled flight client for {@code location} that trusts the certificates returned
   * by {@link #certs()}.
   *
   * @param allocator allocator handed to the client
   * @param location server location to connect to
   * @return a new client; the caller is responsible for closing it
   * @throws RuntimeException wrapping any security or I/O failure while loading certificates
   */
  private static FlightClient flightClient(BufferAllocator allocator, Location location) {
    try {
      final InputStream certStream = certs();
      return FlightClient.builder()
        .allocator(allocator)
        .location(location)
        .useTls()
        .trustedCertificates(certStream)
        .build();
    } catch (GeneralSecurityException | IOException e) {
      throw new RuntimeException(e);
    }
  }

  /**
   * Connects over TLS, runs {@code select * from sys.options} and checks that more than one row
   * is returned across all endpoints.
   */
  @Test
  public void connect() throws Exception {
    final InetAddress ip = InetAddress.getLocalHost();
    final Location location = Location.forGrpcTls(ip.getHostName(), FLIGHT_PORT);
    try (FlightClient c = flightClient(getAllocator(), location)) {
      c.authenticate(new BasicClientAuthHandler(SystemUser.SYSTEM_USERNAME, null));
      final String sql = "select * from sys.options";
      final FlightInfo info = c.getInfo(FlightDescriptor.command(sql.getBytes(StandardCharsets.UTF_8)));

      // Submit every endpoint before waiting on any of them. A single lazy stream pipeline of
      // map(submit).map(get) would process one endpoint at a time and never use the pool.
      final List<Future<Long>> pending = Lists.newArrayList();
      for (final FlightEndpoint endpoint : info.getEndpoints()) {
        pending.add(submit(endpoint));
      }
      long total = 0;
      for (final Future<Long> future : pending) {
        total += get(future);
      }

      Assert.assertTrue(total > 1);
      logger.info("total rows read: {}", total);
    }
  }

  /**
   * Schedules a reader for the given endpoint on the thread pool.
   *
   * @param e endpoint to read
   * @return future yielding the number of rows read from the endpoint
   */
  private Future<Long> submit(FlightEndpoint e) {
    final int thisEndpoint = endpointsSubmitted.incrementAndGet();
    final String ticket = new String(e.getTicket().getBytes(), StandardCharsets.UTF_8);
    final Location target = e.getLocations().get(0);
    logger.debug("submitting flight endpoint {} with ticket {} to {}", thisEndpoint, ticket, target.getUri());
    final Future<Long> f = tp.submit(new RunnableReader(allocator, e));
    logger.debug("submitted flight endpoint {} with ticket {} to {}", thisEndpoint, ticket, target.getUri());
    return f;
  }

  /**
   * Waits for a reader to finish and returns its row count.
   *
   * @param r future returned by {@link #submit(FlightEndpoint)}
   * @return number of rows the reader counted
   * @throws RuntimeException wrapping any failure while waiting; the interrupt flag is restored
   *     if the wait was interrupted
   */
  private static Long get(Future<Long> r) {
    try {
      logger.debug("starting wait on future {} of {}", endpointsWaitingOn.incrementAndGet(), endpointsSubmitted.get());
      final Long f = r.get();
      logger.debug("returned future {} of {} with value {}", endpointsReceived.incrementAndGet(), endpointsSubmitted.get(), f);
      return f;
    } catch (InterruptedException e) {
      Thread.currentThread().interrupt();
      throw new RuntimeException(e);
    } catch (Throwable t) {
      throw new RuntimeException(t);
    }
  }


  /**
   * Reads the stream of a single flight endpoint and counts the rows it delivers.
   */
  private static final class RunnableReader implements Callable<Long> {
    private final BufferAllocator allocator;
    private final FlightEndpoint endpoint;

    private RunnableReader(BufferAllocator allocator, FlightEndpoint endpoint) {
      this.allocator = allocator;
      this.endpoint = endpoint;
    }

    /**
     * Opens a TLS client to the endpoint's first location, authenticates and drains the stream.
     *
     * @return number of rows read; if reading fails the error is logged and the rows counted so
     *     far are returned
     */
    @Override
    public Long call() {
      long count = 0;
      int readIndex = 0;
      final String ticket = new String(endpoint.getTicket().getBytes(), StandardCharsets.UTF_8);
      final Location target = endpoint.getLocations().get(0);
      logger.debug("starting work on flight endpoint with ticket {} to {}", ticket, target.getUri());
      try (FlightClient c = flightClient(allocator, target)) {
        c.authenticate(new BasicClientAuthHandler(SystemUser.SYSTEM_USERNAME, null));
        logger.debug("trying to get stream for flight endpoint with ticket {} to {}", ticket, target.getUri());
        // Close the stream as well as the client so its buffers are released.
        try (FlightStream fs = c.getStream(endpoint.getTicket())) {
          logger.debug("got stream for flight endpoint with ticket {} to {}. Will now try and read", ticket, target.getUri());
          while (fs.next()) {
            final long thisCount = fs.getRoot().getRowCount();
            count += thisCount;
            logger.debug("got results from stream for flight endpoint with ticket {} to {}. This is read {} and we got {} rows back for a total of {}", ticket, target.getUri(), ++readIndex, thisCount, count);
            fs.getRoot().clear();
          }
        }
      } catch (InterruptedException e) {
        // Preserve the interrupt for the executor instead of silently dropping it.
        Thread.currentThread().interrupt();
        logger.warn("Interrupted during stream fetch", e);
      } catch (Throwable t) {
        logger.error("Error in stream fetch", t);
      }
      logger.debug("got all results from stream for flight endpoint with ticket {} to {}. We read {} batches and we got {} rows back", ticket, target.getUri(), readIndex, count);
      return count;
    }
  }
}