package eu.amidst.sparklink.core.data;

import eu.amidst.core.datastream.Attribute;
import eu.amidst.core.datastream.Attributes;
import eu.amidst.core.datastream.DataInstance;
import eu.amidst.core.datastream.DataOnMemory;
import eu.amidst.core.datastream.filereaders.DataInstanceFromDataRow;
import eu.amidst.core.variables.StateSpaceType;
import eu.amidst.core.variables.stateSpaceTypes.FiniteStateSpace;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;

import java.util.ArrayList;
import java.util.Iterator;

import static eu.amidst.core.variables.StateSpaceTypeEnum.REAL;

/**
 * Utility methods for converting between Spark {@link DataFrame}s and AMIDST
 * {@link DataInstance} RDDs.
 *
 * Created by jarias on 21/06/16.
 */
class DataFrameOps {

    /**
     * Converts each {@link Row} of a {@link DataFrame} into a {@link DataInstance},
     * translating categorical state names into their numeric state indices.
     */
    static JavaRDD<DataInstance> toDataInstanceRDD(DataFrame data, Attributes attributes) {

        JavaRDD<double[]> rawRDD = data.rdd()
                .toJavaRDD()
                .map(row -> transformRow2DataInstance(row, attributes));

        return rawRDD.map(v -> new DataInstanceFromDataRow(new DataRowSpark(v, attributes)));
    }

    /**
     * Groups the instances of each partition into in-memory batches of at most
     * {@code batchSize} elements.
     */
    static JavaRDD<DataOnMemory<DataInstance>> toBatchedRDD(JavaRDD<DataInstance> instanceRDD,
                                                            Attributes attributes, int batchSize) {

        return instanceRDD.mapPartitions(partition -> partition2Batches(partition, attributes, batchSize));
    }

    private static double[] transformRow2DataInstance(Row row, Attributes attributes) throws Exception {

        double[] instance = new double[row.length()];

        for (int i = 0; i < row.length(); i++) {

            Attribute att = attributes.getFullListOfAttributes().get(i);
            StateSpaceType space = att.getStateSpaceType();

            switch (space.getStateSpaceTypeEnum()) {

                case REAL:
                    instance[i] = row.getDouble(i);
                    break;

                case FINITE_SET:
                    // Categorical columns are stored as state names; map each name to its index.
                    String state = row.getString(i);
                    instance[i] = ((FiniteStateSpace) space).getIndexOfState(state);
                    break;

                default:
                    // Should be unreachable: every attribute is either REAL or FINITE_SET.
                    throw new Exception("Unsupported state space type: " + space.getStateSpaceTypeEnum());
            }
        }

        return instance;
    }

    private static Iterable<DataOnMemory<DataInstance>> partition2Batches(Iterator<DataInstance> partition,
                                                                          Attributes attributes, int batchSize) {

        ArrayList<DataOnMemory<DataInstance>> batches = new ArrayList<>();
        ArrayList<DataInstance> batch = new ArrayList<>();
        int currentSize = 0;

        while (partition.hasNext()) {
            batch.add(partition.next());
            currentSize++;

            if (currentSize >= batchSize) {
                currentSize = 0;
                batches.add(new DataOnMemoryListContainerSerializable<>(attributes, batch));
                batch = new ArrayList<>();
            }
        }

        // Add the last, possibly smaller, batch if any instances remain:
        if (currentSize > 0)
            batches.add(new DataOnMemoryListContainerSerializable<>(attributes, batch));

        return batches;
    }

    /**
     * Converts a {@link DataInstance} RDD back into Spark {@link Row}s.
     */
    static JavaRDD<Row> toRowRDD(JavaRDD<DataInstance> rawRDD, Attributes atts) {
        // FIXME: Categorical values should be inserted with their corresponding state name
        return rawRDD.map(v -> transformArray2RowAttributes(v, atts));
    }

    private static Row transformArray2RowAttributes(DataInstance inst, Attributes atts) {

        double[] values = inst.toArray();
        Object[] rowValues = new Object[values.length];

        for (int a = 0; a < atts.getNumberOfAttributes(); a++) {

            Attribute attribute = atts.getFullListOfAttributes().get(a);
            StateSpaceType domain = attribute.getStateSpaceType();

            if (domain.getStateSpaceTypeEnum() == REAL)
                rowValues[a] = Double.valueOf(values[a]);
            else
                rowValues[a] = domain.stringValue(values[a]);
        }

        return RowFactory.create(rowValues);
    }
}
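
/*
 * Minimal usage sketch (illustrative only, not part of the original class).
 * Assumes a Spark 1.x DataFrame whose schema matches an AMIDST Attributes
 * object; the names `df`, `attributes`, and `batchSize` below are hypothetical
 * placeholders, not API defined by this file.
 *
 *     DataFrame df = ...;            // e.g. loaded via a Spark 1.x SQLContext
 *     Attributes attributes = ...;   // must describe df's columns in order
 *
 *     // Row-by-row conversion, then grouping each partition into
 *     // in-memory batches of at most 100 instances:
 *     JavaRDD<DataInstance> instances = DataFrameOps.toDataInstanceRDD(df, attributes);
 *     JavaRDD<DataOnMemory<DataInstance>> batches =
 *             DataFrameOps.toBatchedRDD(instances, attributes, 100);
 */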