package org.datavec.arrow.recordreader; import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.memory.RootAllocator; import org.apache.arrow.vector.FieldVector; import org.datavec.api.transform.schema.Schema; import org.datavec.api.writable.DoubleWritable; import org.datavec.api.writable.IntWritable; import org.datavec.api.writable.Text; import org.datavec.api.writable.Writable; import org.datavec.arrow.ArrowConverter; import org.junit.Ignore; import org.junit.Test; import java.util.ArrayList; import java.util.Arrays; import java.util.List; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; public class ArrowWritableRecordTimeSeriesBatchTests { private static BufferAllocator bufferAllocator = new RootAllocator(Long.MAX_VALUE); @Test public void testBasicIndexing() { Schema.Builder schema = new Schema.Builder(); for(int i = 0; i < 3; i++) { schema.addColumnInteger(String.valueOf(i)); } List<List<Writable>> timeStep = Arrays.asList( Arrays.<Writable>asList(new IntWritable(0),new IntWritable(1),new IntWritable(2)), Arrays.<Writable>asList(new IntWritable(1),new IntWritable(2),new IntWritable(3)), Arrays.<Writable>asList(new IntWritable(4),new IntWritable(5),new IntWritable(6)) ); int numTimeSteps = 5; List<List<List<Writable>>> timeSteps = new ArrayList<>(numTimeSteps); for(int i = 0; i < numTimeSteps; i++) { timeSteps.add(timeStep); } List<FieldVector> fieldVectors = ArrowConverter.toArrowColumnsTimeSeries(bufferAllocator, schema.build(), timeSteps); assertEquals(3,fieldVectors.size()); for(FieldVector fieldVector : fieldVectors) { for(int i = 0; i < fieldVector.getValueCount(); i++) { assertFalse("Index " + i + " was null for field vector " + fieldVector, fieldVector.isNull(i)); } } ArrowWritableRecordTimeSeriesBatch arrowWritableRecordTimeSeriesBatch = new ArrowWritableRecordTimeSeriesBatch(fieldVectors,schema.build(),timeStep.size() * timeStep.get(0).size()); assertEquals(timeSteps,arrowWritableRecordTimeSeriesBatch.toArrayList()); } @Test //not worried about this till after next release @Ignore public void testVariableLengthTS() { Schema.Builder schema = new Schema.Builder() .addColumnString("str") .addColumnInteger("int") .addColumnDouble("dbl"); List<List<Writable>> firstSeq = Arrays.asList( Arrays.<Writable>asList(new Text("00"),new IntWritable(0),new DoubleWritable(2.0)), Arrays.<Writable>asList(new Text("01"),new IntWritable(1),new DoubleWritable(2.1)), Arrays.<Writable>asList(new Text("02"),new IntWritable(2),new DoubleWritable(2.2))); List<List<Writable>> secondSeq = Arrays.asList( Arrays.<Writable>asList(new Text("10"),new IntWritable(10),new DoubleWritable(12.0)), Arrays.<Writable>asList(new Text("11"),new IntWritable(11),new DoubleWritable(12.1))); List<List<List<Writable>>> sequences = Arrays.asList(firstSeq, secondSeq); List<FieldVector> fieldVectors = ArrowConverter.toArrowColumnsTimeSeries(bufferAllocator, schema.build(), sequences); assertEquals(3,fieldVectors.size()); int timeSeriesStride = -1; //Can't sequences of different length... ArrowWritableRecordTimeSeriesBatch arrowWritableRecordTimeSeriesBatch = new ArrowWritableRecordTimeSeriesBatch(fieldVectors,schema.build(),timeSeriesStride); List<List<List<Writable>>> asList = arrowWritableRecordTimeSeriesBatch.toArrayList(); assertEquals(sequences, asList); } }