package datasources;

import datasources.utils.DBClientWrapper;
import datasources.utils.DBTableReader;
import edb.common.Split;
import edb.common.UnknownTableException;
import org.apache.log4j.Logger;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.sources.v2.DataSourceOptions;
import org.apache.spark.sql.sources.v2.DataSourceV2;
import org.apache.spark.sql.sources.v2.ReadSupport;
import org.apache.spark.sql.sources.v2.reader.DataReader;
import org.apache.spark.sql.sources.v2.reader.DataReaderFactory;
import org.apache.spark.sql.sources.v2.reader.DataSourceReader;
import org.apache.spark.sql.sources.v2.reader.SupportsReportPartitioning;
import org.apache.spark.sql.sources.v2.reader.partitioning.ClusteredDistribution;
import org.apache.spark.sql.sources.v2.reader.partitioning.Distribution;
import org.apache.spark.sql.sources.v2.reader.partitioning.Partitioning;
import org.apache.spark.sql.types.StructType;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/**
 * This DataSource also supports parallel reads (i.e., on multiple executors)
 * from the ExampleDB.
 * The interesting feature of this example is that it supports informing the
 * Spark SQL optimizer whether the table is partitioned in the right way to avoid shuffles
 * in certain queries. One example is grouping queries, where shuffles can be avoided if the
 * table is clustered in such a way that each group (cluster) is fully contained in a
 * single partition. Since ExampleDB only supports clustered indexes on single columns,
 * in practice a shuffle can be avoided if the table is clustered on one of the grouping
 * columns. (In ExampleDB clustered tables, splits always respect clustering.)
 * It gets a table name from its configuration and infers a schema from
 * that table. If a number of partitions is specified in properties, it is used. Otherwise,
 * the table's default partition count (always 4 in ExampleDB) is used.
 */
public class PartitioningRowDataSource implements DataSourceV2, ReadSupport {

    static Logger log = Logger.getLogger(PartitioningRowDataSource.class.getName());

     * Spark calls this to create the reader. Notice how it pulls the host and port
     * on which ExampleDB is listening, as well as a table name, from the supplied options.
     * @param options
     * @return
    public DataSourceReader createReader(DataSourceOptions options) {
        String host = options.get("host").orElse("localhost");
        int port = options.getInt("port", -1);
        String table = options.get("table").orElse("unknownTable"); // TODO: throw
        int partitions = Integer.parseInt(options.get("partitions").orElse("0"));
        return new Reader(host, port, table, partitions);

     * This is how Spark discovers the source table's schema by requesting it from ExmapleDB,
     * and how it obtains the reader factories to be used by the executors to create readers.
     * Notice that one factory is created for each partition.
    static class Reader implements SupportsReportPartitioning {

        static Logger log = Logger.getLogger(Reader.class.getName());

        public Reader(String host, int port, String table, int partitions) {
            _host = host;
            _port = port;
            _table = table;
            _requestedPartitions = partitions;

        private String _host;
        private int _port;
        private String _table;
        private int _requestedPartitions;

        // dynamic properties inferred from database

        private boolean _initialized = false;
        private StructType _schema;
        private String _clusteredColumn;
        private List<Split> _splits;

        private void initialize() {
            if (!_initialized) {
                DBClientWrapper db = new DBClientWrapper(_host, _port);
                try {
                    _schema = db.getSparkSchema(_table);
                    _clusteredColumn = db.getClusteredIndexColumn(_table);
                    if (_requestedPartitions == 0)
                        _splits = db.getSplits(_table);
                        _splits = db.getSplits(_table, _requestedPartitions);
                } catch (UnknownTableException ute) {
                    throw new RuntimeException(ute);
                } finally {
                _initialized = true;

        public StructType readSchema() {
  "schema requested for table [" + _table + "]");
            return _schema;

        public List<DataReaderFactory<Row>> createDataReaderFactories() {
  "reader factories requested for table [" + _table + "]");
            List<DataReaderFactory<Row>> factories = new ArrayList<>();
            for (Split split : _splits) {
                DataReaderFactory<Row> factory =
                        new SplitDataReaderFactory(_host, _port, _table, readSchema(), split);
            return factories;

        public Partitioning outputPartitioning() {
  "output partitioning requested for table [" + _table + "]");
            return new SingleClusteredColumnPartitioning(
                    _clusteredColumn, _splits.size());

    static class SingleClusteredColumnPartitioning implements Partitioning {

        static Logger log = Logger.getLogger(SingleClusteredColumnPartitioning.class.getName());

        public SingleClusteredColumnPartitioning(String columnName, int partitions) {
            _columnName = columnName;
            _partitions = partitions;

        public int numPartitions() {
  "asked for numPartitions");
            return _partitions;

        public boolean satisfy(Distribution distribution) {
            // Since Spark may add other Distribution policies in the future, we can't assume
            // it's always a ClusteredDistribution

            if (distribution instanceof ClusteredDistribution) {

                String[] clusteredCols = ((ClusteredDistribution) distribution).clusteredColumns;
                StringBuilder logEntryBuilder = new StringBuilder();
                logEntryBuilder.append("asked to satisfy ClusteredDistribution on columns ");
                if (clusteredCols.length > 0) {
                    for (String col : clusteredCols) {
                        logEntryBuilder.append("] ");
                if (_columnName == null) {
          "no cluster column so does not satisfy");
                    return false;
                } else {
                    boolean satisfies = Arrays.asList(clusteredCols).contains(_columnName);
          "based on cluster column: " + satisfies);
                    return satisfies;
  "asked to satisfy unknown distribution of type [" +
                    distribution.getClass().getCanonicalName() + "]");
            return false;

        private String _columnName;
        private int _partitions;

     * This is used by each executor to read from ExampleDB. It uses the Split to know
     * which data to read.
     * Also note that when DBClientWrapper's getTableReader() method is called
     * it reads ALL the data in its own Split eagerly.
    static class TaskDataReader implements DataReader<Row> {

        static Logger log = Logger.getLogger(TaskDataReader.class.getName());

        public TaskDataReader(String host, int port, String table,
                              StructType schema, Split split)
                throws UnknownTableException {
  "Task reading from [" + host + ":" + port + "]" );
            _db = new DBClientWrapper(host, port);
            _reader = _db.getTableReader(table, schema.fieldNames(), split);

        private DBClientWrapper _db;

        private DBTableReader _reader;

        public boolean next() {

        public Row get() {
            return _reader.get();

        public void close() throws IOException {

     * Note that this has to be serializable. Each instance is sent to an executor,
     * which uses it to create a reader for its own use.
    static class SplitDataReaderFactory implements DataReaderFactory<Row> {

        static Logger log = Logger.getLogger(SplitDataReaderFactory.class.getName());

        public SplitDataReaderFactory(String host, int port,
                                       String table, StructType schema,
                                       Split split) {
            _host = host;
            _port = port;
            _table = table;
            _schema = schema;
            _split = split;

        private String _host;
        private int _port;
        private String _table;
        private StructType _schema;
        private Split _split;

        public DataReader<Row> createDataReader() {
  "Factory creating reader for [" + _host + ":" + _port + "]" );
            try {
                return new TaskDataReader(_host, _port, _table, _schema, _split);
            } catch (UnknownTableException ute) {
                throw new RuntimeException(ute);

