package datasources.utils;

import edb.client.DBClient;
import edb.common.*;

import org.apache.log4j.Logger;
import org.apache.spark.sql.catalyst.expressions.GenericRow;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

import java.util.ArrayList;
import java.util.List;

public class DBClientWrapper {

    static Logger log = Logger.getLogger(DBClientWrapper.class.getName());

    public DBClientWrapper(String host, int port) {
        _host = host;
        _port = port;
    }


    public void bulkInsertFromTables(String destination,
                                     boolean truncateDestination,
                                     List<String> sources) {
        try {
            _client.bulkInsertFromTables(destination, truncateDestination, sources);
        } catch (UnknownTableException ute) {
            throw new RuntimeException(ute);
        }
    }

    public static org.apache.spark.sql.Row dbToSparkRow(Row dbRow) {
        int fieldCount = dbRow.getFieldCount();
        Object[] values = new Object[fieldCount];
        for (int i = 0; i < fieldCount; i++) {
            Row.Field f = dbRow.getField(i);
            if (f instanceof Row.Int64Field) {
                Row.Int64Field i64f = (Row.Int64Field) f;
                values[i] = i64f.getValue();
            } else if (f instanceof Row.DoubleField) {
                Row.DoubleField df = (Row.DoubleField) f;
                values[i] = df.getValue();
            } else if (f instanceof Row.StringField) {
                Row.StringField df = (Row.StringField) f;
                values[i] = df.getValue();
            }
        }
        GenericRow row = new GenericRow(values);
        return row;
    }

    public String getHost() { return _host; }

    public int getPort() { return _port; }

    public void connect() {
        log.info("Connecting to DB [" + _host + ":" + _port + "]");
        _client = new DBClient(_host, _port);
    }

    public void createTable(String name, Schema schema) {
        try {
            _client.createTable(name, schema);
        } catch (ExistingTableException ete) {
            throw new RuntimeException(ete);
        }
    }

    public boolean tableExists(String name) {
        try {
            _client.getTableSchema(name);
            return true;
        } catch (UnknownTableException ute) {
            return false;
        }
    }

    public boolean tableHasCompatibleSchema(String name, Schema required) {
        try {
            Schema actual =_client.getTableSchema(name);
            return actual.isCompatible(required);
        } catch (UnknownTableException ute) {
            throw new RuntimeException(ute);
        }
    }

    public String saveToTempTable(List<edb.common.Row> rows, Schema schema) {
        String tableName = _client.createTemporaryTable(schema);

        try {
            _client.bulkInsert(tableName, rows);
        } catch (UnknownTableException ute) {
            // can't happen since we just created it as a temp table
            throw new RuntimeException(ute);
        }
        return tableName;

    }

    public DBTableReader getTableReader(String tableName, String[] columnNames)
    throws UnknownTableException
    {
        return new DBTableReader(tableName, _client, columnNames);
    }

    public DBTableReader getTableReader(String tableName, String[] columnNames, Split split)
            throws UnknownTableException
    {
        return new DBTableReader(tableName, _client, columnNames, split);
    }

    public void disconnect() {

    }

    public static Schema sparkToDbSchema(StructType st) {
        Schema schema = new Schema();
        for (StructField sf: st.fields()) {
            if (sf.dataType() == DataTypes.StringType) {
                schema.addColumn(sf.name(), Schema.ColumnType.STRING);
            } else if (sf.dataType() == DataTypes.DoubleType) {
                schema.addColumn(sf.name(), Schema.ColumnType.DOUBLE);
            } else if (sf.dataType() == DataTypes.LongType) {
                schema.addColumn(sf.name(), Schema.ColumnType.INT64);
            } else {
                // TODO: type leakage
            }
        }
        return schema;
    }

    public static StructType dbToSparkSchema(Schema schema) {
        List<StructField> fields = new ArrayList<>();
        for (int i = 0; i < schema.getColumnCount(); i++) {
            String name = schema.getColumnName(i);
            switch (schema.getColumnType(i)) {
                case INT64:
                    fields.add(DataTypes.createStructField(name,
                            DataTypes.LongType, true));
                    break;
                case DOUBLE:
                    fields.add(DataTypes.createStructField(name,
                            DataTypes.DoubleType, true));
                    break;
                case STRING:
                    fields.add(DataTypes.createStructField(name,
                            DataTypes.StringType, true));
                    break;
                default:
            }
        }
        return DataTypes.createStructType(fields);
    }

    public static org.apache.spark.sql.Row dbToSparkRow(edb.common.Row dbRow, String[] colNames) {
        int fieldCount = dbRow.getFieldCount();
        Object[] values = new Object[fieldCount];
        for (int i = 0; i < colNames.length; i++) {
            edb.common.Row.Field f = dbRow.getField(colNames[i]);
            if (f instanceof edb.common.Row.Int64Field) {
                edb.common.Row.Int64Field i64f = (edb.common.Row.Int64Field) f;
                values[i] = i64f.getValue();
            } else if (f instanceof edb.common.Row.DoubleField) {
                edb.common.Row.DoubleField df = (edb.common.Row.DoubleField) f;
                values[i] = df.getValue();
            } else if (f instanceof edb.common.Row.StringField) {
                edb.common.Row.StringField df = (edb.common.Row.StringField) f;
                values[i] = df.getValue();
            }
        }
        GenericRow row = new GenericRow(values);
        return row;
    }

    public static edb.common.Row sparkToDBRow(org.apache.spark.sql.Row row, StructType type) {
        edb.common.Row dbRow = new edb.common.Row();
        StructField[] fields = type.fields();
        for (int i = 0; i < type.size(); i++) {
            StructField sf = fields[i];
            if (sf.dataType() == DataTypes.StringType) {
                dbRow.addField(new edb.common.Row.StringField(sf.name(), row.getString(i)));
            } else if (sf.dataType() == DataTypes.DoubleType) {
                dbRow.addField(new edb.common.Row.DoubleField(sf.name(), row.getDouble(i)));
            } else if (sf.dataType() == DataTypes.LongType) {
                dbRow.addField(new edb.common.Row.Int64Field(sf.name(), row.getLong(i)));
            } else {
                // TODO: type leakage
            }
        }

        return dbRow;
    }

    public static edb.common.Row sparkToDBRow(org.apache.spark.sql.Row row, Schema schema) {
        edb.common.Row dbRow = new edb.common.Row();
        for (int i = 0; i < schema.getColumnCount(); i++) {
            if (schema.getColumnType(i) ==  Schema.ColumnType.STRING) {
                dbRow.addField(new edb.common.Row.StringField(schema.getColumnName(i),
                        row.getString(i)));
            } else if (schema.getColumnType(i) ==  Schema.ColumnType.DOUBLE) {
                dbRow.addField(new edb.common.Row.DoubleField(schema.getColumnName(i),
                        row.getDouble(i)));
            } else if (schema.getColumnType(i) ==  Schema.ColumnType.INT64) {
                dbRow.addField(new edb.common.Row.Int64Field(schema.getColumnName(i),
                        row.getLong(i)));
            } else {
                // TODO: type leakage
            }
        }

        return dbRow;
    }

    public StructType getSparkSchema(String table) throws UnknownTableException
    {

        Schema schema =_client.getTableSchema(table);
        return dbToSparkSchema(schema);
    }

    public Schema getDBSchema(String table) throws UnknownTableException
    {
        Schema schema =_client.getTableSchema(table);
        return schema;
    }

    public String getClusteredIndexColumn(String table) throws UnknownTableException {
        return _client.getTableClusteredIndexColumn(table);
    }

    public List<Split> getSplits(String table, int count) throws UnknownTableException {
        return _client.getSplits(table, count);
    }

    public List<Split> getSplits(String table) throws UnknownTableException {
        return _client.getSplits(table);
    }

    private String _host;
    private int _port;
    private DBClient _client;
    private String _tableName;

}