package datasources;

import datasources.utils.DBClientWrapper;
import datasources.utils.DBTableReader;
import edb.common.Split;
import edb.common.UnknownTableException;
import org.apache.log4j.Logger;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.sources.v2.DataSourceOptions;
import org.apache.spark.sql.sources.v2.DataSourceV2;
import org.apache.spark.sql.sources.v2.ReadSupport;
import org.apache.spark.sql.sources.v2.reader.DataReader;
import org.apache.spark.sql.sources.v2.reader.DataReaderFactory;
import org.apache.spark.sql.sources.v2.reader.DataSourceReader;
import org.apache.spark.sql.sources.v2.reader.SupportsReportPartitioning;
import org.apache.spark.sql.sources.v2.reader.partitioning.ClusteredDistribution;
import org.apache.spark.sql.sources.v2.reader.partitioning.Distribution;
import org.apache.spark.sql.sources.v2.reader.partitioning.Partitioning;
import org.apache.spark.sql.types.StructType;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/**
 * This DataSource also supports parallel reads (i.e., on multiple executors)
 * from ExampleDB.
 *
 * The interesting feature of this example is that it supports informing the
 * Spark SQL optimizer whether the table is partitioned in the right way to avoid shuffles
 * in certain queries. One example is grouping queries, where shuffles can be avoided if the
 * table is clustered in such a way that each group (cluster) is fully contained in a
 * single partition. Since ExampleDB only supports clustered indexes on single columns,
 * in practice a shuffle can be avoided if the table is clustered on one of the grouping
 * columns. (In ExampleDB clustered tables, splits always respect clustering.)
 *
 * It gets a table name from its options and infers a schema from
 * that table. If a partition count is specified in the options, it is used; otherwise,
 * the table's default split count (always 4 in ExampleDB) is used.
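 *
 * For illustration, a usage sketch (the table name, port number, and grouping column
 * below are assumptions, not taken from this source):
 * <pre>{@code
 * Dataset<Row> df = spark.read()
 *     .format("datasources.PartitioningRowDataSource")
 *     .option("host", "localhost")
 *     .option("port", 9999)        // assumed ExampleDB port
 *     .option("table", "myTable")  // assumed table name
 *     .option("partitions", 4)     // optional; 0 or absent means "use the table's own splits"
 *     .load();
 * // If "myTable" is clustered on column "u", this grouping query can avoid a shuffle.
 * df.groupBy("u").count().show();
 * }</pre>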
 */
public class PartitioningRowDataSource implements DataSourceV2, ReadSupport {

    static Logger log = Logger.getLogger(PartitioningRowDataSource.class.getName());

    /**
     * Spark calls this to create the reader. Notice how it pulls the host and port
     * on which ExampleDB is listening, as well as a table name and an optional
     * partition count, from the supplied options.
     * @param options options passed via the DataFrameReader (host, port, table, partitions)
     * @return a DataSourceReader for the requested table
     */
    @Override
    public DataSourceReader createReader(DataSourceOptions options) {
        String host = options.get("host").orElse("localhost");
        int port = options.getInt("port", -1);
        String table = options.get("table").orElse("unknownTable"); // TODO: throw
        int partitions = Integer.parseInt(options.get("partitions").orElse("0"));
        return new Reader(host, port, table, partitions);
    }

    /**
     * This is how Spark discovers the source table's schema by requesting it from ExampleDB,
     * and how it obtains the reader factories to be used by the executors to create readers.
     * Notice that one factory is created for each partition.
     */
    static class Reader implements SupportsReportPartitioning {

        static Logger log = Logger.getLogger(Reader.class.getName());

        public Reader(String host, int port, String table, int partitions) {
            _host = host;
            _port = port;
            _table = table;
            _requestedPartitions = partitions;
        }

        private String _host;
        private int _port;
        private String _table;
        private int _requestedPartitions;

        //
        // dynamic properties inferred from database
        //

        private boolean _initialized = false;
        private StructType _schema;
        private String _clusteredColumn;
        private List<Split> _splits;


        private void initialize() {
            if (!_initialized) {
                log.info("initializing");
                DBClientWrapper db = new DBClientWrapper(_host, _port);
                db.connect();
                try {
                    _schema = db.getSparkSchema(_table);
                    _clusteredColumn = db.getClusteredIndexColumn(_table);
                    if (_requestedPartitions == 0)
                        _splits = db.getSplits(_table);
                    else
                        _splits = db.getSplits(_table, _requestedPartitions);
                } catch (UnknownTableException ute) {
                    throw new RuntimeException(ute);
                } finally {
                    db.disconnect();
                }
                _initialized = true;
                log.info("initialized");
            }
        }

        @Override
        public StructType readSchema() {
            log.info("schema requested for table [" + _table + "]");
            initialize();
            return _schema;
        }

        @Override
        public List<DataReaderFactory<Row>> createDataReaderFactories() {
            log.info("reader factories requested for table [" + _table + "]");
            initialize();
            List<DataReaderFactory<Row>> factories = new ArrayList<>();
            for (Split split : _splits) {
                DataReaderFactory<Row> factory =
                        new SplitDataReaderFactory(_host, _port, _table, readSchema(), split);
                factories.add(factory);
            }
            return factories;
        }

        @Override
        public Partitioning outputPartitioning() {
            log.info("output partitioning requested for table [" + _table + "]");
            initialize();
            return new SingleClusteredColumnPartitioning(
                    _clusteredColumn, _splits.size());
        }
    }

    static class SingleClusteredColumnPartitioning implements Partitioning {

        static Logger log = Logger.getLogger(SingleClusteredColumnPartitioning.class.getName());

        public SingleClusteredColumnPartitioning(String columnName, int partitions) {
            _columnName = columnName;
            _partitions = partitions;
        }

        @Override
        public int numPartitions() {
            log.info("asked for numPartitions");
            return _partitions;
        }

        @Override
        public boolean satisfy(Distribution distribution) {
            //
            // Since Spark may add other Distribution policies in the future, we can't assume
            // it's always a ClusteredDistribution
            //

            if (distribution instanceof ClusteredDistribution) {

                String[] clusteredCols = ((ClusteredDistribution) distribution).clusteredColumns;
                StringBuilder logEntryBuilder = new StringBuilder();
                logEntryBuilder.append("asked to satisfy ClusteredDistribution on columns ");
                for (String col : clusteredCols) {
                    logEntryBuilder.append("[").append(col).append("] ");
                }
                log.info(logEntryBuilder.toString());
                if (_columnName == null) {
                    log.info("no cluster column so does not satisfy");
                    return false;
                } else {
                    boolean satisfies = Arrays.asList(clusteredCols).contains(_columnName);
                    log.info("based on cluster column: " + satisfies);
                    return satisfies;
                }
            }
            log.info("asked to satisfy unknown distribution of type [" +
                    distribution.getClass().getCanonicalName() + "]");
            return false;
        }

        private String _columnName;
        private int _partitions;
    }

    /**
     * This is used by each executor to read from ExampleDB. It uses the Split to know
     * which data to read.
     * Note also that when DBClientWrapper's getTableReader() method is called,
     * it eagerly reads ALL the data in its Split.
     */
    static class TaskDataReader implements DataReader<Row> {

        static Logger log = Logger.getLogger(TaskDataReader.class.getName());

        public TaskDataReader(String host, int port, String table,
                              StructType schema, Split split)
                throws UnknownTableException {
            log.info("Task reading from [" + host + ":" + port + "]" );
            _db = new DBClientWrapper(host, port);
            _db.connect();
            _reader = _db.getTableReader(table, schema.fieldNames(), split);
        }

        private DBClientWrapper _db;

        private DBTableReader _reader;

        @Override
        public boolean next() {
            return _reader.next();
        }

        @Override
        public Row get() {
            return _reader.get();
        }

        @Override
        public void close() throws IOException {
            _db.disconnect();
        }
    }

    /**
     * Note that this has to be serializable. Each instance is sent to an executor,
     * which uses it to create a reader for its own use.
     */
    static class SplitDataReaderFactory implements DataReaderFactory<Row> {

        static Logger log = Logger.getLogger(SplitDataReaderFactory.class.getName());

        public SplitDataReaderFactory(String host, int port,
                                       String table, StructType schema,
                                       Split split) {
            _host = host;
            _port = port;
            _table = table;
            _schema = schema;
            _split = split;
        }

        private String _host;
        private int _port;
        private String _table;
        private StructType _schema;
        private Split _split;

        @Override
        public DataReader<Row> createDataReader() {
            log.info("Factory creating reader for [" + _host + ":" + _port + "]" );
            try {
                return new TaskDataReader(_host, _port, _table, _schema, _split);
            } catch (UnknownTableException ute) {
                throw new RuntimeException(ute);
            }
        }

    }

}