package org.datavec.spark.transform; import lombok.AllArgsConstructor; import lombok.Getter; import lombok.val; import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.memory.RootAllocator; import org.apache.arrow.vector.FieldVector; import org.datavec.api.transform.TransformProcess; import org.datavec.api.util.ndarray.RecordConverter; import org.datavec.api.writable.Writable; import org.datavec.arrow.ArrowConverter; import org.datavec.arrow.recordreader.ArrowWritableRecordBatch; import org.datavec.arrow.recordreader.ArrowWritableRecordTimeSeriesBatch; import org.datavec.local.transforms.LocalTransformExecutor; import org.datavec.spark.transform.model.Base64NDArrayBody; import org.datavec.spark.transform.model.BatchCSVRecord; import org.datavec.spark.transform.model.SequenceBatchCSVRecord; import org.datavec.spark.transform.model.SingleCSVRecord; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.serde.base64.Nd4jBase64; import java.io.IOException; import java.util.Arrays; import java.util.List; import static org.datavec.arrow.ArrowConverter.*; import static org.datavec.local.transforms.LocalTransformExecutor.execute; import static org.datavec.local.transforms.LocalTransformExecutor.executeSequenceToSequence; import static org.datavec.local.transforms.LocalTransformExecutor.executeToSequence; /** * CSVSpark Transform runs * the actual {@link TransformProcess} * * @author Adan Gibson */ @AllArgsConstructor public class CSVSparkTransform { @Getter private TransformProcess transformProcess; private static BufferAllocator bufferAllocator = new RootAllocator(Long.MAX_VALUE); /** * Convert a raw record via * the {@link TransformProcess} * to a base 64ed ndarray * @param batch the record to convert * @return teh base 64ed ndarray * @throws IOException */ public Base64NDArrayBody toArray(BatchCSVRecord batch) throws IOException { List<List<Writable>> converted = execute(toArrowWritables(toArrowColumnsString( bufferAllocator,transformProcess.getInitialSchema(), batch.getRecordsAsString()), transformProcess.getInitialSchema()),transformProcess); ArrowWritableRecordBatch arrowRecordBatch = (ArrowWritableRecordBatch) converted; INDArray convert = ArrowConverter.toArray(arrowRecordBatch); return new Base64NDArrayBody(Nd4jBase64.base64String(convert)); } /** * Convert a raw record via * the {@link TransformProcess} * to a base 64ed ndarray * @param record the record to convert * @return the base 64ed ndarray * @throws IOException */ public Base64NDArrayBody toArray(SingleCSVRecord record) throws IOException { List<Writable> record2 = toArrowWritablesSingle( toArrowColumnsStringSingle(bufferAllocator, transformProcess.getInitialSchema(),record.getValues()), transformProcess.getInitialSchema()); List<Writable> finalRecord = execute(Arrays.asList(record2),transformProcess).get(0); INDArray convert = RecordConverter.toArray(finalRecord); return new Base64NDArrayBody(Nd4jBase64.base64String(convert)); } /** * Runs the transform process * @param batch the record to transform * @return the transformed record */ public BatchCSVRecord transform(BatchCSVRecord batch) { BatchCSVRecord batchCSVRecord = new BatchCSVRecord(); List<List<Writable>> converted = execute(toArrowWritables(toArrowColumnsString( bufferAllocator,transformProcess.getInitialSchema(), batch.getRecordsAsString()), transformProcess.getInitialSchema()),transformProcess); int numCols = converted.get(0).size(); for (int row = 0; row < converted.size(); row++) { String[] values = new String[numCols]; for (int i = 0; i < values.length; i++) values[i] = converted.get(row).get(i).toString(); batchCSVRecord.add(new SingleCSVRecord(values)); } return batchCSVRecord; } /** * Runs the transform process * @param record the record to transform * @return the transformed record */ public SingleCSVRecord transform(SingleCSVRecord record) { List<Writable> record2 = toArrowWritablesSingle( toArrowColumnsStringSingle(bufferAllocator, transformProcess.getInitialSchema(),record.getValues()), transformProcess.getInitialSchema()); List<Writable> finalRecord = execute(Arrays.asList(record2),transformProcess).get(0); String[] values = new String[finalRecord.size()]; for (int i = 0; i < values.length; i++) values[i] = finalRecord.get(i).toString(); return new SingleCSVRecord(values); } /** * * @param transform * @return */ public SequenceBatchCSVRecord transformSequenceIncremental(BatchCSVRecord transform) { /** * Sequence schema? */ List<List<List<Writable>>> converted = executeToSequence( toArrowWritables(toArrowColumnsStringTimeSeries( bufferAllocator, transformProcess.getInitialSchema(), Arrays.asList(transform.getRecordsAsString())), transformProcess.getInitialSchema()), transformProcess); SequenceBatchCSVRecord batchCSVRecord = new SequenceBatchCSVRecord(); for (int i = 0; i < converted.size(); i++) { BatchCSVRecord batchCSVRecord1 = BatchCSVRecord.fromWritables(converted.get(i)); batchCSVRecord.add(Arrays.asList(batchCSVRecord1)); } return batchCSVRecord; } /** * * @param batchCSVRecordSequence * @return */ public SequenceBatchCSVRecord transformSequence(SequenceBatchCSVRecord batchCSVRecordSequence) { List<List<List<String>>> recordsAsString = batchCSVRecordSequence.getRecordsAsString(); boolean allSameLength = true; Integer length = null; for(List<List<String>> record : recordsAsString) { if(length == null) { length = record.size(); } else if(record.size() != length) { allSameLength = false; } } if(allSameLength) { List<FieldVector> fieldVectors = toArrowColumnsStringTimeSeries(bufferAllocator, transformProcess.getInitialSchema(), recordsAsString); ArrowWritableRecordTimeSeriesBatch arrowWritableRecordTimeSeriesBatch = new ArrowWritableRecordTimeSeriesBatch(fieldVectors, transformProcess.getInitialSchema(), recordsAsString.get(0).get(0).size()); val transformed = LocalTransformExecutor.executeSequenceToSequence(arrowWritableRecordTimeSeriesBatch,transformProcess); return SequenceBatchCSVRecord.fromWritables(transformed); } else { val transformed = LocalTransformExecutor.executeSequenceToSequence(LocalTransformExecutor.convertStringInputTimeSeries(batchCSVRecordSequence.getRecordsAsString(),transformProcess.getInitialSchema()),transformProcess); return SequenceBatchCSVRecord.fromWritables(transformed); } } /** * TODO: optimize * @param batchCSVRecordSequence * @return */ public Base64NDArrayBody transformSequenceArray(SequenceBatchCSVRecord batchCSVRecordSequence) { List<List<List<String>>> strings = batchCSVRecordSequence.getRecordsAsString(); boolean allSameLength = true; Integer length = null; for(List<List<String>> record : strings) { if(length == null) { length = record.size(); } else if(record.size() != length) { allSameLength = false; } } if(allSameLength) { List<FieldVector> fieldVectors = toArrowColumnsStringTimeSeries(bufferAllocator, transformProcess.getInitialSchema(), strings); ArrowWritableRecordTimeSeriesBatch arrowWritableRecordTimeSeriesBatch = new ArrowWritableRecordTimeSeriesBatch(fieldVectors,transformProcess.getInitialSchema(),strings.get(0).get(0).size()); val transformed = LocalTransformExecutor.executeSequenceToSequence(arrowWritableRecordTimeSeriesBatch,transformProcess); INDArray arr = RecordConverter.toTensor(transformed).reshape(strings.size(),strings.get(0).get(0).size(),strings.get(0).size()); try { return new Base64NDArrayBody(Nd4jBase64.base64String(arr)); } catch (IOException e) { throw new IllegalStateException(e); } } else { val transformed = LocalTransformExecutor.executeSequenceToSequence(LocalTransformExecutor.convertStringInputTimeSeries(batchCSVRecordSequence.getRecordsAsString(),transformProcess.getInitialSchema()),transformProcess); INDArray arr = RecordConverter.toTensor(transformed).reshape(strings.size(),strings.get(0).get(0).size(),strings.get(0).size()); try { return new Base64NDArrayBody(Nd4jBase64.base64String(arr)); } catch (IOException e) { throw new IllegalStateException(e); } } } /** * * @param singleCsvRecord * @return */ public Base64NDArrayBody transformSequenceArrayIncremental(BatchCSVRecord singleCsvRecord) { List<List<List<Writable>>> converted = executeToSequence(toArrowWritables(toArrowColumnsString( bufferAllocator,transformProcess.getInitialSchema(), singleCsvRecord.getRecordsAsString()), transformProcess.getInitialSchema()),transformProcess); ArrowWritableRecordTimeSeriesBatch arrowWritableRecordBatch = (ArrowWritableRecordTimeSeriesBatch) converted; INDArray arr = RecordConverter.toTensor(arrowWritableRecordBatch); try { return new Base64NDArrayBody(Nd4jBase64.base64String(arr)); } catch (IOException e) { e.printStackTrace(); } return null; } public SequenceBatchCSVRecord transform(SequenceBatchCSVRecord batchCSVRecord) { List<List<List<String>>> strings = batchCSVRecord.getRecordsAsString(); boolean allSameLength = true; Integer length = null; for(List<List<String>> record : strings) { if(length == null) { length = record.size(); } else if(record.size() != length) { allSameLength = false; } } if(allSameLength) { List<FieldVector> fieldVectors = toArrowColumnsStringTimeSeries(bufferAllocator, transformProcess.getInitialSchema(), strings); ArrowWritableRecordTimeSeriesBatch arrowWritableRecordTimeSeriesBatch = new ArrowWritableRecordTimeSeriesBatch(fieldVectors,transformProcess.getInitialSchema(),strings.get(0).get(0).size()); val transformed = LocalTransformExecutor.executeSequenceToSequence(arrowWritableRecordTimeSeriesBatch,transformProcess); return SequenceBatchCSVRecord.fromWritables(transformed); } else { val transformed = LocalTransformExecutor.executeSequenceToSequence(LocalTransformExecutor.convertStringInputTimeSeries(batchCSVRecord.getRecordsAsString(),transformProcess.getInitialSchema()),transformProcess); return SequenceBatchCSVRecord.fromWritables(transformed); } } }