/*
 * Copyright (C) 2019 Ryan Murray
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.arrow.flight.spark;

import java.nio.charset.StandardCharsets;
import java.util.Iterator;
import java.util.List;
import java.util.Optional;
import java.util.function.Consumer;

import org.apache.arrow.flight.Action;
import org.apache.arrow.flight.FlightDescriptor;
import org.apache.arrow.flight.FlightEndpoint;
import org.apache.arrow.flight.FlightInfo;
import org.apache.arrow.flight.FlightServer;
import org.apache.arrow.flight.FlightTestUtil;
import org.apache.arrow.flight.Location;
import org.apache.arrow.flight.NoOpFlightProducer;
import org.apache.arrow.flight.Result;
import org.apache.arrow.flight.Ticket;
import org.apache.arrow.flight.auth.ServerAuthHandler;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.util.AutoCloseables;
import org.apache.arrow.vector.BigIntVector;
import org.apache.arrow.vector.Float8Vector;
import org.apache.arrow.vector.VarCharVector;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.types.Types;
import org.apache.arrow.vector.types.pojo.Field;
import org.apache.arrow.vector.types.pojo.Schema;
import org.apache.arrow.vector.util.Text;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.junit.AfterClass;
import org.junit.Assert;
import org.junit.BeforeClass;
import org.junit.Test;

import com.google.common.collect.ImmutableList;

/**
 * Integration tests for the Flight Spark connector.
 *
 * <p>A Flight server backed by {@link TestProducer} is started once for the whole class, and a
 * local-mode Spark context is pointed at it through the {@code spark.flight.*} configuration keys.
 * Every test then reads through the {@link FlightSparkContext} wrapper.
 */
public class TestConnector {
  /**
   * Command for which {@link TestProducer} serves the three-column (bid, ask, symbol) schema; every
   * other command gets the full five-column schema. Compared byte-for-byte with incoming commands.
   */
  // NOTE(review): the trailing "))" looks unbalanced; kept verbatim because it has to equal the
  // command the connector actually sends for the projection in testProject — confirm.
  private static final String PROJECTED_COMMAND =
      "select \"bid\", \"ask\", \"symbol\" from (select * from test.table))";

  private static final BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE);
  private static Location location;
  private static FlightServer server;
  private static SparkConf conf;
  private static JavaSparkContext sc;
  private static FlightSparkContext csc;
  /** The producer behind {@link #server}; kept so tests can reset its mutable state. */
  private static TestProducer producer;

  /**
   * Starts the Flight server with an auth handler that accepts every client, then builds the Spark
   * context and the Flight-aware wrapper that the tests read through.
   */
  @BeforeClass
  public static void setUp() throws Exception {
    producer = new TestProducer();
    // The lambda parameter is named "loc" so it does not shadow the static "location" field.
    server = FlightTestUtil.getStartedServer(loc -> FlightServer.builder(allocator, loc, producer).authHandler(
      new ServerAuthHandler() {
        @Override
        public Optional<String> isValid(byte[] token) {
          // Every token is accepted and mapped to the same identity.
          return Optional.of("xxx");
        }

        @Override
        public boolean authenticate(ServerAuthSender outgoing, Iterator<byte[]> incoming) {
          // Consume the client's handshake message and answer with an empty token.
          incoming.next();
          outgoing.send(new byte[0]);
          return true;
        }
      }).build()
    );
    location = server.getLocation();
    conf = new SparkConf()
      .setAppName("flightTest")
      .setMaster("local[*]")
      .set("spark.driver.allowMultipleContexts", "true")
      .set("spark.flight.endpoint.host", location.getUri().getHost())
      .set("spark.flight.endpoint.port", Integer.toString(location.getUri().getPort()))
      .set("spark.flight.auth.username", "xxx")
      .set("spark.flight.auth.password", "yyy");
    sc = new JavaSparkContext(conf);
    csc = FlightSparkContext.flightContext(sc);
  }

  /**
   * Releases class-level resources. The Spark context is stopped first so no task is still reading
   * from the server, then the server is stopped, and the allocator it draws from is closed last.
   */
  @AfterClass
  public static void tearDown() throws Exception {
    AutoCloseables.close(sc, server, allocator);
  }

  /** Passes as long as creating the dataset does not throw. */
  @Test
  public void testConnect() {
    csc.read("test.table");
  }

  /** One endpoint streaming two 10-row batches should yield 20 rows. */
  @Test
  public void testRead() {
    long count = csc.read("test.table").count();
    Assert.assertEquals(20, count);
  }

  /** Same as {@link #testRead()} but through the SQL entry point. */
  @Test
  public void testSql() {
    long count = csc.readSql("select * from test.table").count();
    Assert.assertEquals(20, count);
  }

  /** Filtering on the symbol column must return fewer rows than the unfiltered query. */
  @Test
  public void testFilter() {
    Dataset<Row> df = csc.readSql("select * from test.table");
    long count = df.filter(df.col("symbol").equalTo("USDCAD")).count();
    long countOriginal = csc.readSql("select * from test.table").count();
    Assert.assertTrue(count < countOriginal);
  }

  /** Counts the rows it sees and remembers the width of the last row. */
  private static class SizeConsumer implements Consumer<Row> {
    private int length = 0;
    private int width = 0;

    @Override
    public void accept(Row row) {
      length += 1;
      width = row.length();
    }
  }

  /** Selecting three columns must produce rows narrower than the full table's column count. */
  @Test
  public void testProject() {
    Dataset<Row> df = csc.readSql("select * from test.table");
    SizeConsumer c = new SizeConsumer();
    df.select("bid", "ask", "symbol").toLocalIterator().forEachRemaining(c);
    long count = c.width;
    long countOriginal = csc.readSql("select * from test.table").columns().length;
    Assert.assertTrue(count < countOriginal);
  }

  /**
   * A parallel read should hit both advertised endpoints: 2 endpoints x 20 rows = 40 rows, each
   * with the 5 columns of the full schema.
   */
  @Test
  public void testParallel() {
    String easySql = "select * from \"@dremio\".tpch_spark limit 100000";
    SizeConsumer c = new SizeConsumer();
    try {
      csc.readSql(easySql, true).toLocalIterator().forEachRemaining(c);
    } finally {
      // doAction flips the producer into two-endpoint mode permanently; reset it so tests that
      // happen to run after this one still see a single endpoint (JUnit 4 order is not guaranteed).
      producer.parallel = false;
    }
    long width = c.width;
    long length = c.length;
    Assert.assertEquals(5, width);
    Assert.assertEquals(40, length);
  }

  /**
   * In-memory Flight producer serving synthetic quote data.
   *
   * <p>{@link #getFlightInfo} advertises one endpoint, or two after any action has been received.
   * {@link #getStream} sends two batches per ticket, so each endpoint yields 20 rows by default.
   */
  private static class TestProducer extends NoOpFlightProducer {
    /** Number of record batches sent for every ticket. */
    private static final int BATCHES_PER_STREAM = 2;

    private static final Schema PROJECTED_SCHEMA = new Schema(ImmutableList.of(
      Field.nullable("bid", Types.MinorType.FLOAT8.getType()),
      Field.nullable("ask", Types.MinorType.FLOAT8.getType()),
      Field.nullable("symbol", Types.MinorType.VARCHAR.getType()))
    );

    private static final Schema FULL_SCHEMA = new Schema(ImmutableList.of(
      Field.nullable("bid", Types.MinorType.FLOAT8.getType()),
      Field.nullable("ask", Types.MinorType.FLOAT8.getType()),
      Field.nullable("symbol", Types.MinorType.VARCHAR.getType()),
      Field.nullable("bidsize", Types.MinorType.BIGINT.getType()),
      Field.nullable("asksize", Types.MinorType.BIGINT.getType()))
    );

    /**
     * Set by {@link #doAction} and read by {@link #getFlightInfo}. Those are separate server calls
     * that may run on different threads, hence {@code volatile}.
     */
    private volatile boolean parallel = false;

    /** Any action switches the producer to two-endpoint mode and replies with a single "ok". */
    @Override
    public void doAction(CallContext context, Action action, StreamListener<Result> listener) {
      parallel = true;
      listener.onNext(new Result("ok".getBytes(StandardCharsets.UTF_8)));
      listener.onCompleted();
    }

    /**
     * Describes the command: the projected schema for {@link #PROJECTED_COMMAND}, the full schema
     * otherwise. Each endpoint's ticket carries the raw command bytes.
     */
    @Override
    public FlightInfo getFlightInfo(CallContext context, FlightDescriptor descriptor) {
      final byte[] command = descriptor.getCommand();
      final List<FlightEndpoint> endpoints = parallel
          ? ImmutableList.of(endpoint(command), endpoint(command))
          : ImmutableList.of(endpoint(command));
      final Schema schema = PROJECTED_COMMAND.equals(decode(command)) ? PROJECTED_SCHEMA : FULL_SCHEMA;
      return new FlightInfo(schema, descriptor, endpoints, 1000000, 10);
    }

    /**
     * Streams {@link #BATCHES_PER_STREAM} identical batches for the ticket. Batches have 5 rows when
     * the ticket text contains "USDCAD" and 10 rows otherwise. The column set matches the schema
     * {@link #getFlightInfo} advertises for the same command.
     */
    @Override
    public void getStream(CallContext context, Ticket ticket, ServerStreamListener listener) {
      final String command = decode(ticket.getBytes());
      final int size = command.contains("USDCAD") ? 5 : 10;
      final boolean projected = PROJECTED_COMMAND.equals(command);

      final Float8Vector bid = new Float8Vector("bid", allocator);
      final Float8Vector ask = new Float8Vector("ask", allocator);
      final VarCharVector symbol = new VarCharVector("symbol", allocator);
      // The size columns exist only for the full schema.
      final BigIntVector bidSize = projected ? null : new BigIntVector("bidsize", allocator);
      final BigIntVector askSize = projected ? null : new BigIntVector("asksize", allocator);

      // try-with-resources frees the vector buffers even if sending a batch fails.
      try (VectorSchemaRoot root = projected
          ? VectorSchemaRoot.of(bid, ask, symbol)
          : VectorSchemaRoot.of(bid, ask, symbol, bidSize, askSize)) {
        listener.start(root);
        for (int batch = 0; batch < BATCHES_PER_STREAM; batch++) {
          root.allocateNew();
          fillQuotes(bid, ask, symbol, size);
          if (!projected) {
            fillSizes(bidSize, askSize, size);
          }
          root.setRowCount(size);
          listener.putNext();
        }
        listener.completed();
      }
    }

    /** Builds an endpoint at the server's location whose ticket is the given command. */
    private static FlightEndpoint endpoint(byte[] command) {
      return new FlightEndpoint(new Ticket(command), location);
    }

    /** Decodes command/ticket bytes as UTF-8 rather than the platform default charset. */
    private static String decode(byte[] bytes) {
      return new String(bytes, StandardCharsets.UTF_8);
    }

    /**
     * Writes rows 0..size-1: bid = ask = row index; symbol is "USDCAD" on even rows and "EURUSD"
     * on odd rows.
     */
    private static void fillQuotes(Float8Vector bid, Float8Vector ask, VarCharVector symbol, int size) {
      for (int i = 0; i < size; i++) {
        bid.set(i, (double) i);
        ask.set(i, (double) i);
        symbol.set(i, (i % 2 == 0) ? new Text("USDCAD") : new Text("EURUSD"));
      }
      bid.setValueCount(size);
      ask.setValueCount(size);
      symbol.setValueCount(size);
    }

    /** Writes rows 0..size-1 with bidsize = asksize = row index. */
    private static void fillSizes(BigIntVector bidSize, BigIntVector askSize, int size) {
      for (int i = 0; i < size; i++) {
        bidSize.set(i, (long) i);
        askSize.set(i, (long) i);
      }
      bidSize.setValueCount(size);
      askSize.setValueCount(size);
    }
  }
}