package org.datavec.spark.transform;

import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.Function2;
import org.apache.spark.sql.Column;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.functions;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import org.nd4j.linalg.primitives.Pair;
import org.datavec.api.transform.schema.Schema;
import org.datavec.api.writable.*;
import org.datavec.spark.transform.sparkfunction.SequenceToRows;
import org.datavec.spark.transform.sparkfunction.ToRecord;
import org.datavec.spark.transform.sparkfunction.ToRow;
import org.datavec.spark.transform.sparkfunction.sequence.DataFrameToSequenceCreateCombiner;
import org.datavec.spark.transform.sparkfunction.sequence.DataFrameToSequenceMergeCombiner;
import org.datavec.spark.transform.sparkfunction.sequence.DataFrameToSequenceMergeValue;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

import java.util.ArrayList;
import java.util.List;

import static org.apache.spark.sql.functions.avg;
import static org.apache.spark.sql.functions.col;
import static org.datavec.spark.transform.DataRowsFacade.dataRows;


/**
 * Utilities for interop between DataVec
 * records and Spark DataFrames
 *
 * @author Adam Gibson
 */
public class DataFrames {

    public static final String SEQUENCE_UUID_COLUMN = "__SEQ_UUID";
    public static final String SEQUENCE_INDEX_COLUMN = "__SEQ_IDX";

    private DataFrames() {}

    /**
     * Standard deviation for a column
     *
     * @param dataFrame  the dataframe to
     *                   get the column from
     * @param columnName the name of the column to get the standard
     *                   deviation for
     * @return the column that represents the standard deviation
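     * <p>
     * Usage sketch (the facade {@code df} and the column name are hypothetical;
     * assumes the column is numeric):
     * <pre>{@code
     * Column stdCol = DataFrames.std(df, "myNumericColumn");
     * }</pre>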
     */
    public static Column std(DataRowsFacade dataFrame, String columnName) {
        return functions.sqrt(var(dataFrame, columnName));
    }


    /**
     * Variance for a column
     *
     * @param dataFrame  the dataframe to
     *                   get the column from
     * @param columnName the name of the column to get the variance for
     * @return the column that represents the variance
     */
    public static Column var(DataRowsFacade dataFrame, String columnName) {
        return dataFrame.get().groupBy(columnName).agg(functions.variance(columnName)).col(columnName);
    }

    /**
     * Min for a column
     *
     * @param dataFrame  the dataframe to
     *                   get the column from
     * @param columnName the name of the column to get the min for
     * @return the column that represents the min
     */
    public static Column min(DataRowsFacade dataFrame, String columnName) {
        return dataFrame.get().groupBy(columnName).agg(functions.min(columnName)).col(columnName);
    }

    /**
     * Max for a column
     *
     * @param dataFrame  the dataframe to
     *                   get the column from
     * @param columnName the name of the column
     *                   to get the max for
     * @return the column that represents the max
     */
    public static Column max(DataRowsFacade dataFrame, String columnName) {
        return dataFrame.get().groupBy(columnName).agg(functions.max(columnName)).col(columnName);
    }

    /**
     * Mean for a column
     *
     * @param dataFrame  the dataframe to
     *                   get the column from
     * @param columnName the name of the column to get the mean for
     * @return the column that represents the mean
     */
    public static Column mean(DataRowsFacade dataFrame, String columnName) {
        return dataFrame.get().groupBy(columnName).agg(avg(columnName)).col(columnName);
    }

    /**
     * Convert a DataVec schema to a
     * Spark StructType
     *
     * @param schema the schema to convert
     * @return the equivalent Spark struct type
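     * <p>
     * A minimal sketch (the column names are hypothetical):
     * <pre>{@code
     * Schema schema = new Schema.Builder()
     *         .addColumnDouble("sensorValue")
     *         .addColumnInteger("sensorId")
     *         .build();
     * StructType structType = DataFrames.fromSchema(schema);
     * }</pre>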
     */
    public static StructType fromSchema(Schema schema) {
        StructField[] structFields = new StructField[schema.numColumns()];
        for (int i = 0; i < structFields.length; i++) {
            switch (schema.getColumnTypes().get(i)) {
                case Double:
                    structFields[i] = new StructField(schema.getName(i), DataTypes.DoubleType, false, Metadata.empty());
                    break;
                case Integer:
                    structFields[i] =
                                    new StructField(schema.getName(i), DataTypes.IntegerType, false, Metadata.empty());
                    break;
                case Long:
                    structFields[i] = new StructField(schema.getName(i), DataTypes.LongType, false, Metadata.empty());
                    break;
                case Float:
                    structFields[i] = new StructField(schema.getName(i), DataTypes.FloatType, false, Metadata.empty());
                    break;
                default:
                    throw new IllegalStateException(
                                    "This API should not be used with strings, binary data or NDArrays. This is only for columnar data");
            }
        }
        return new StructType(structFields);
    }

    /**
     * Convert the DataVec sequence schema to a StructType for Spark, for example for use in
     * {@link #toDataFrameSequence(Schema, JavaRDD)}.
     * <b>Note</b>: as per {@link #toDataFrameSequence(Schema, JavaRDD)}, the StructType has two additional columns added to it:<br>
     * - Column 0: Sequence UUID (name: {@link #SEQUENCE_UUID_COLUMN}) - a UUID for the original sequence<br>
     * - Column 1: Sequence index (name: {@link #SEQUENCE_INDEX_COLUMN}) - an index (integer, starting at 0) for the position
     * of this record in the original time series.<br>
     * These two columns are required if the data is to be converted back into a sequence at a later point, for example
     * using {@link #toRecordsSequence(DataRowsFacade)}.
     *
     * @param schema Schema to convert
     * @return StructType for the schema
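     * <p>
     * For example, a two-column schema produces a four-field StructType:
     * {@code [__SEQ_UUID, __SEQ_IDX, <column 0>, <column 1>]}.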
     */
    public static StructType fromSchemaSequence(Schema schema) {
        StructField[] structFields = new StructField[schema.numColumns() + 2];

        structFields[0] = new StructField(SEQUENCE_UUID_COLUMN, DataTypes.StringType, false, Metadata.empty());
        structFields[1] = new StructField(SEQUENCE_INDEX_COLUMN, DataTypes.IntegerType, false, Metadata.empty());

        for (int i = 0; i < schema.numColumns(); i++) {
            switch (schema.getColumnTypes().get(i)) {
                case Double:
                    structFields[i + 2] =
                                    new StructField(schema.getName(i), DataTypes.DoubleType, false, Metadata.empty());
                    break;
                case Integer:
                    structFields[i + 2] =
                                    new StructField(schema.getName(i), DataTypes.IntegerType, false, Metadata.empty());
                    break;
                case Long:
                    structFields[i + 2] =
                                    new StructField(schema.getName(i), DataTypes.LongType, false, Metadata.empty());
                    break;
                case Float:
                    structFields[i + 2] =
                                    new StructField(schema.getName(i), DataTypes.FloatType, false, Metadata.empty());
                    break;
                default:
                    throw new IllegalStateException(
                                    "This API should not be used with strings, binary data or NDArrays. This is only for columnar data");
            }
        }
        return new StructType(structFields);
    }


    /**
     * Create a datavec schema
     * from a struct type
     *
     * @param structType the struct type to create the schema from
     * @return the created schema
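     * <p>
     * Usage sketch (assumes {@code df} is an existing {@link DataRowsFacade}):
     * <pre>{@code
     * Schema schema = DataFrames.fromStructType(df.get().schema());
     * }</pre>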
     */
    public static Schema fromStructType(StructType structType) {
        Schema.Builder builder = new Schema.Builder();
        StructField[] fields = structType.fields();
        String[] fieldNames = structType.fieldNames();
        for (int i = 0; i < fields.length; i++) {
            String name = fields[i].dataType().typeName().toLowerCase();
            switch (name) {
                case "double":
                    builder.addColumnDouble(fieldNames[i]);
                    break;
                case "float":
                    builder.addColumnFloat(fieldNames[i]);
                    break;
                case "long":
                    builder.addColumnLong(fieldNames[i]);
                    break;
                case "int":
                case "integer":
                    builder.addColumnInteger(fieldNames[i]);
                    break;
                case "string":
                    builder.addColumnString(fieldNames[i]);
                    break;
                default:
                    throw new RuntimeException("Unknown type: " + name);
            }
        }

        return builder.build();
    }


    /**
     * Convert a DataFrame to a DataVec-compatible
     * schema and RDD of writables
     *
     * @param dataFrame the dataframe to convert
     * @return the converted schema and RDD of writables
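     * <p>
     * Usage sketch ({@code df} is a hypothetical facade over an existing DataFrame):
     * <pre>{@code
     * Pair<Schema, JavaRDD<List<Writable>>> converted = DataFrames.toRecords(df);
     * Schema schema = converted.getFirst();
     * JavaRDD<List<Writable>> records = converted.getSecond();
     * }</pre>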
     */
    public static Pair<Schema, JavaRDD<List<Writable>>> toRecords(DataRowsFacade dataFrame) {
        Schema schema = fromStructType(dataFrame.get().schema());
        return new Pair<>(schema, dataFrame.get().javaRDD().map(new ToRecord(schema)));
    }

    /**
     * Convert the given DataFrame to a sequence<br>
     * <b>Note</b>: It is assumed here that the DataFrame has been created by {@link #toDataFrameSequence(Schema, JavaRDD)}.
     * In particular:<br>
     * - the first column is a UUID for the original sequence the row is from<br>
     * - the second column is a time step index: where the row appeared in the original sequence<br>
     * <p>
     * Typical use: Normalization via the {@link Normalization} static methods
     *
     * @param dataFrame Data frame to convert
     * @return Data in sequence (i.e., {@code List<List<Writable>>}) form
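     * <p>
     * Round-trip sketch (assumes {@code schema} and a sequence RDD {@code sequences} already exist):
     * <pre>{@code
     * DataRowsFacade df = DataFrames.toDataFrameSequence(schema, sequences);
     * JavaRDD<List<List<Writable>>> restored = DataFrames.toRecordsSequence(df).getSecond();
     * }</pre>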
     */
    public static Pair<Schema, JavaRDD<List<List<Writable>>>> toRecordsSequence(DataRowsFacade dataFrame) {

        //Need to convert from flattened to sequence data...
        //First: Group by the Sequence UUID (first column)
        JavaPairRDD<String, Iterable<Row>> grouped = dataFrame.get().javaRDD().groupBy(new Function<Row, String>() {
            @Override
            public String call(Row row) throws Exception {
                return row.getString(0);
            }
        });


        Schema schema = fromStructType(dataFrame.get().schema());

        //Group by sequence UUID, and sort each row within the sequences using the time step index
        Function<Iterable<Row>, List<List<Writable>>> createCombiner = new DataFrameToSequenceCreateCombiner(schema); //Function to create the initial combiner
        Function2<List<List<Writable>>, Iterable<Row>, List<List<Writable>>> mergeValue =
                        new DataFrameToSequenceMergeValue(schema); //Function to add a row
        Function2<List<List<Writable>>, List<List<Writable>>, List<List<Writable>>> mergeCombiners =
                        new DataFrameToSequenceMergeCombiner(); //Function to merge existing sequence writables

        JavaRDD<List<List<Writable>>> sequences =
                        grouped.combineByKey(createCombiner, mergeValue, mergeCombiners).values();

        //We no longer want/need the sequence UUID and sequence time step columns - extract those out
        JavaRDD<List<List<Writable>>> out = sequences.map(new Function<List<List<Writable>>, List<List<Writable>>>() {
            @Override
            public List<List<Writable>> call(List<List<Writable>> v1) throws Exception {
                List<List<Writable>> out = new ArrayList<>(v1.size());
                for (List<Writable> l : v1) {
                    List<Writable> subset = new ArrayList<>();
                    for (int i = 2; i < l.size(); i++) {
                        subset.add(l.get(i));
                    }
                    out.add(subset);
                }
                return out;
            }
        });

        return new Pair<>(schema, out);
    }

    /**
     * Creates a DataFrame from an RDD
     * of records, given a schema
     *
     * @param schema the schema to use
     * @param data   the data to convert
     * @return the dataframe object
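     * <p>
     * Usage sketch (assumes {@code schema} and a matching {@code records} RDD already exist):
     * <pre>{@code
     * DataRowsFacade df = DataFrames.toDataFrame(schema, records);
     * df.get().show();
     * }</pre>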
     */
    public static DataRowsFacade toDataFrame(Schema schema, JavaRDD<List<Writable>> data) {
        JavaSparkContext sc = new JavaSparkContext(data.context());
        SQLContext sqlContext = new SQLContext(sc);
        JavaRDD<Row> rows = data.map(new ToRow(schema));
        return dataRows(sqlContext.createDataFrame(rows, fromSchema(schema)));
    }


    /**
     * Convert the given sequence data set to a DataFrame.<br>
     * <b>Note</b>: The resulting DataFrame has two additional columns added to it:<br>
     * - Column 0: Sequence UUID (name: {@link #SEQUENCE_UUID_COLUMN}) - a UUID for the original sequence<br>
     * - Column 1: Sequence index (name: {@link #SEQUENCE_INDEX_COLUMN}) - an index (integer, starting at 0) for the position
     * of this record in the original time series.<br>
     * These two columns are required if the data is to be converted back into a sequence at a later point, for example
     * using {@link #toRecordsSequence(DataRowsFacade)}.
     *
     * @param schema Schema for the data
     * @param data   Sequence data to convert to a DataFrame
     * @return The dataframe object
     */
    public static DataRowsFacade toDataFrameSequence(Schema schema, JavaRDD<List<List<Writable>>> data) {
        JavaSparkContext sc = new JavaSparkContext(data.context());

        SQLContext sqlContext = new SQLContext(sc);
        JavaRDD<Row> rows = data.flatMap(new SequenceToRows(schema));
        return dataRows(sqlContext.createDataFrame(rows, fromSchemaSequence(schema)));
    }

    /**
     * Convert a given Row to a list of writables, given the specified Schema
     *
     * @param schema Schema for the data
     * @param row    Row of data
     * @return the row's values as a list of writables, in column order
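     * <p>
     * A minimal sketch (the values and the two-double-column schema are hypothetical,
     * and must match each other):
     * <pre>{@code
     * Row row = RowFactory.create(1.0, 2.0);
     * List<Writable> writables = DataFrames.rowToWritables(twoDoubleColumnSchema, row);
     * }</pre>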
     */
    public static List<Writable> rowToWritables(Schema schema, Row row) {
        List<Writable> ret = new ArrayList<>();
        for (int i = 0; i < row.size(); i++) {
            switch (schema.getType(i)) {
                case Double:
                    ret.add(new DoubleWritable(row.getDouble(i)));
                    break;
                case Float:
                    ret.add(new FloatWritable(row.getFloat(i)));
                    break;
                case Integer:
                    ret.add(new IntWritable(row.getInt(i)));
                    break;
                case Long:
                    ret.add(new LongWritable(row.getLong(i)));
                    break;
                case String:
                    ret.add(new Text(row.getString(i)));
                    break;
                default:
                    throw new IllegalStateException("Illegal type");
            }
        }
        return ret;
    }

    /**
     * Convert a string array into a list
     * @param input the input to create the list from
     * @return the created list
     */
    public static List<String> toList(String[] input) {
        List<String> ret = new ArrayList<>();
        for (int i = 0; i < input.length; i++)
            ret.add(input[i]);
        return ret;
    }


    /**
     * Convert a string list into an array
     * @param list the input to create the array from
     * @return the created array
     */
    public static String[] toArray(List<String> list) {
        String[] ret = new String[list.size()];
        for (int i = 0; i < ret.length; i++)
            ret[i] = list.get(i);
        return ret;
    }

    /**
     * Convert a list of rows to a matrix
     * @param rows the list of rows to convert
     * @return the converted matrix
     */
    public static INDArray toMatrix(List<Row> rows) {
        INDArray ret = Nd4j.create(rows.size(), rows.get(0).size());
        for (int i = 0; i < ret.rows(); i++) {
            for (int j = 0; j < ret.columns(); j++) {
                ret.putScalar(i, j, rows.get(i).getDouble(j));
            }
        }
        return ret;
    }


    /**
     * Convert a list of string names
     * to columns
     * @param columns the columns to convert
     * @return the resulting column list
     */
    public static List<Column> toColumn(List<String> columns) {
        List<Column> ret = new ArrayList<>();
        for (String s : columns)
            ret.add(col(s));
        return ret;
    }

    /**
     * Convert an array of string names
     * to columns
     * @param columns the columns to convert
     * @return the converted columns
     */
    public static Column[] toColumns(String... columns) {
        Column[] ret = new Column[columns.length];
        for (int i = 0; i < columns.length; i++)
            ret[i] = col(columns[i]);
        return ret;
    }

}