org.datavec.api.records.reader.impl.csv.CSVRecordReader Java Examples

The following examples show how to use org.datavec.api.records.reader.impl.csv.CSVRecordReader. Each example is drawn from an open-source project; the source file, project, and license are noted above each listing.
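All of the examples share the same basic pattern: construct the reader (optionally with a number of lines to skip and a delimiter), initialize it with an InputSplit, then iterate over the parsed records. The short sketch below shows that pattern in isolation. It is illustrative only: the file name data.csv and its header row are assumptions, and depending on the DataVec version the delimiter argument is a char or a String.

import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
import org.datavec.api.split.FileSplit;
import org.datavec.api.writable.Writable;

import java.io.File;
import java.util.List;

public class CSVRecordReaderSketch {
    public static void main(String[] args) throws Exception {
        int numLinesToSkip = 1;    //assumes the hypothetical data.csv starts with a header row
        char delimiter = ',';
        RecordReader reader = new CSVRecordReader(numLinesToSkip, delimiter);
        reader.initialize(new FileSplit(new File("data.csv")));

        while (reader.hasNext()) {
            List<Writable> record = reader.next();    //one CSV row, parsed into a list of Writable values
            System.out.println(record);
        }
        reader.close();
    }
}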
Example #1
Source File: IrisFileDataSource.java    From FederatedAndroidTrainer with MIT License
private void createDataSource() throws IOException, InterruptedException {
    //First: get the dataset using the record reader. CSVRecordReader handles loading/parsing
    int numLinesToSkip = 0;
    String delimiter = ",";
    RecordReader recordReader = new CSVRecordReader(numLinesToSkip, delimiter);
    recordReader.initialize(new InputStreamInputSplit(dataFile));

    //Second: the RecordReaderDataSetIterator handles conversion to DataSet objects, ready for use in neural network
    int labelIndex = 4;     //5 values in each row of the iris.txt CSV: 4 input features followed by an integer class label at index 4
    int numClasses = 3;     //3 classes (types of iris flowers) in the iris data set. Classes have integer values 0, 1 or 2

    DataSetIterator iterator = new RecordReaderDataSetIterator(recordReader, batchSize, labelIndex, numClasses);
    DataSet allData = iterator.next();
    allData.shuffle();

    SplitTestAndTrain testAndTrain = allData.splitTestAndTrain(0.80);  //Use 80% of data for training

    trainingData = testAndTrain.getTrain();
    testData = testAndTrain.getTest();

    //We need to normalize our data. We'll use NormalizerStandardize (which gives us mean 0, unit variance):
    DataNormalization normalizer = new NormalizerStandardize();
    normalizer.fit(trainingData);           //Collect the statistics (mean/stdev) from the training data. This does not modify the input data
    normalizer.transform(trainingData);     //Apply normalization to the training data
    normalizer.transform(testData);         //Apply normalization to the test data. This is using statistics calculated from the *training* set
}
 
Example #2
Source File: RecordReaderDataSetiteratorTest.java    From deeplearning4j with Apache License 2.0
@Test
public void testRRDSIwithAsync() throws Exception {
    RecordReader csv = new CSVRecordReader();
    csv.initialize(new FileSplit(Resources.asFile("iris.txt")));

    int batchSize = 10;
    int labelIdx = 4;
    int numClasses = 3;

    RecordReaderDataSetIterator rrdsi = new RecordReaderDataSetIterator(csv, batchSize, labelIdx, numClasses);
    AsyncDataSetIterator adsi = new AsyncDataSetIterator(rrdsi, 8, true);
    while (adsi.hasNext()) {
        adsi.next();    //drain the iterator; the test passes if async prefetching completes without error
    }

}
 
Example #3
Source File: DiabetesFileDataSource.java    From FederatedAndroidTrainer with MIT License
private void createDataSource() throws IOException, InterruptedException {
    //First: get the dataset using the record reader. CSVRecordReader handles loading/parsing
    int numLinesToSkip = 0;
    String delimiter = ",";
    RecordReader recordReader = new CSVRecordReader(numLinesToSkip, delimiter);
    recordReader.initialize(new InputStreamInputSplit(dataFile));

    //Second: the RecordReaderDataSetIterator handles conversion to DataSet objects, ready for use in neural network
    int labelIndex = 11;    //The regression target is the value at index 11 in each row

    //Passing labelIndex as both 'from' and 'to' with regression = true configures a single continuous label
    DataSetIterator iterator = new RecordReaderDataSetIterator(recordReader, batchSize, labelIndex, labelIndex, true);
    DataSet allData = iterator.next();

    SplitTestAndTrain testAndTrain = allData.splitTestAndTrain(0.80);  //Use 80% of data for training

    trainingData = testAndTrain.getTrain();
    testData = testAndTrain.getTest();

    //We need to normalize our data. We'll use NormalizerStandardize (which gives us mean 0, unit variance):
    DataNormalization normalizer = new NormalizerStandardize();
    normalizer.fit(trainingData);           //Collect the statistics (mean/stdev) from the training data. This does not modify the input data
    normalizer.transform(trainingData);     //Apply normalization to the training data
    normalizer.transform(testData);         //Apply normalization to the test data. This is using statistics calculated from the *training* set
}
 
Example #4
Source File: BatchInputParserMultiRecordTest.java    From konduit-serving with Apache License 2.0
@Test(timeout = 60000)
public void runAdd(TestContext testContext) throws Exception {
    BatchInputArrowParserVerticle verticleRef = (BatchInputArrowParserVerticle) verticle;
    Schema irisInputSchema = TrainUtils.getIrisInputSchema();
    ArrowRecordWriter arrowRecordWriter = new ArrowRecordWriter(irisInputSchema);
    CSVRecordReader reader = new CSVRecordReader();
    reader.initialize(new FileSplit(new ClassPathResource("iris.txt").getFile()));
    List<List<Writable>> writables = reader.next(150);

    File tmpFile = new File(temporary.getRoot(), "tmp.arrow");
    FileSplit fileSplit = new FileSplit(tmpFile);
    arrowRecordWriter.initialize(fileSplit, new NumberOfRecordsPartitioner());
    arrowRecordWriter.writeBatch(writables);

    given().port(port)
            .multiPart("input1", tmpFile)
            .when().post("/")
            .then().statusCode(200);

    testContext.assertNotNull(verticleRef.getBatch(), "Inputs were null. This means parsing failed.");
    testContext.assertTrue(verticleRef.getBatch().length >= 1);
    testContext.assertEquals(150, verticleRef.getBatch().length);
}
 
Example #5
Source File: HyperParameterTuningArbiterUiExample.java    From Java-Deep-Learning-Cookbook with MIT License
public RecordReader dataPreprocess() throws IOException, InterruptedException {
    //Schema Definitions
    Schema schema = new Schema.Builder()
            .addColumnsString("RowNumber")
            .addColumnInteger("CustomerId")
            .addColumnString("Surname")
            .addColumnInteger("CreditScore")
            .addColumnCategorical("Geography", Arrays.asList("France","Spain","Germany"))
            .addColumnCategorical("Gender",Arrays.asList("Male","Female"))
            .addColumnsInteger("Age","Tenure","Balance","NumOfProducts","HasCrCard","IsActiveMember","EstimatedSalary","Exited").build();

    //Schema Transformation
    TransformProcess transformProcess = new TransformProcess.Builder(schema)
            .removeColumns("RowNumber","Surname","CustomerId")
            .categoricalToInteger("Gender")
            .categoricalToOneHot("Geography")
            .removeColumns("Geography[France]")
            .build();

    //CSVReader - Reading from file and applying transformation
    RecordReader reader = new CSVRecordReader(1,',');
    reader.initialize(new FileSplit(new ClassPathResource("Churn_Modelling.csv").getFile()));
    RecordReader transformProcessRecordReader = new TransformProcessRecordReader(reader,transformProcess);
    return transformProcessRecordReader;
}
 
Example #6
Source File: HyperParameterTuning.java    From Java-Deep-Learning-Cookbook with MIT License
public RecordReader dataPreprocess() throws IOException, InterruptedException {
    //Schema Definitions
    Schema schema = new Schema.Builder()
            .addColumnsString("RowNumber")
            .addColumnInteger("CustomerId")
            .addColumnString("Surname")
            .addColumnInteger("CreditScore")
            .addColumnCategorical("Geography",Arrays.asList("France","Spain","Germany"))
            .addColumnCategorical("Gender",Arrays.asList("Male","Female"))
            .addColumnsInteger("Age","Tenure","Balance","NumOfProducts","HasCrCard","IsActiveMember","EstimatedSalary","Exited").build();

    //Schema Transformation
    TransformProcess transformProcess = new TransformProcess.Builder(schema)
            .removeColumns("RowNumber","Surname","CustomerId")
            .categoricalToInteger("Gender")
            .categoricalToOneHot("Geography")
            .removeColumns("Geography[France]")
            .build();

    //CSVReader - Reading from file and applying transformation
    RecordReader reader = new CSVRecordReader(1,',');
    reader.initialize(new FileSplit(new ClassPathResource("Churn_Modelling.csv").getFile()));
    RecordReader transformProcessRecordReader = new TransformProcessRecordReader(reader,transformProcess);
    return transformProcessRecordReader;
}
 
Example #7
Source File: ArrowBinaryInputAdapterTest.java    From konduit-serving with Apache License 2.0
@Test(timeout = 60000)
public void testArrowBinary() throws Exception {
    Schema irisInputSchema = TrainUtils.getIrisInputSchema();
    ArrowRecordWriter arrowRecordWriter = new ArrowRecordWriter(irisInputSchema);
    CSVRecordReader reader = new CSVRecordReader();
    reader.initialize(new FileSplit(new ClassPathResource("iris.txt").getFile()));
    List<List<Writable>> writables = reader.next(150);

    File tmpFile = new File(temporary.getRoot(), "tmp.arrow");
    FileSplit fileSplit = new FileSplit(tmpFile);
    arrowRecordWriter.initialize(fileSplit, new NumberOfRecordsPartitioner());
    arrowRecordWriter.writeBatch(writables);
    byte[] arrowBytes = FileUtils.readFileToByteArray(tmpFile);

    Buffer buffer = Buffer.buffer(arrowBytes);
    ArrowBinaryInputAdapter arrowBinaryInputAdapter = new ArrowBinaryInputAdapter();
    ArrowWritableRecordBatch convert = arrowBinaryInputAdapter.convert(buffer, ConverterArgs.builder().schema(irisInputSchema).build(), null);
    assertEquals(writables.size(), convert.size());
}
 
Example #8
Source File: DataSetIteratorTest.java    From deeplearning4j with Apache License 2.0
@Test
public void testMnist() throws Exception {
    ClassPathResource cpr = new ClassPathResource("mnist_first_200.txt");
    CSVRecordReader rr = new CSVRecordReader(0, ',');
    rr.initialize(new FileSplit(cpr.getTempFileFromArchive()));
    RecordReaderDataSetIterator dsi = new RecordReaderDataSetIterator(rr, 10, 0, 10);

    MnistDataSetIterator iter = new MnistDataSetIterator(10, 200, false, true, false, 0);

    while (dsi.hasNext()) {
        DataSet dsExp = dsi.next();
        DataSet dsAct = iter.next();

        INDArray fExp = dsExp.getFeatures();
        fExp.divi(255);
        INDArray lExp = dsExp.getLabels();

        INDArray fAct = dsAct.getFeatures();
        INDArray lAct = dsAct.getLabels();

        assertEquals(fExp, fAct.castTo(fExp.dataType()));
        assertEquals(lExp, lAct.castTo(lExp.dataType()));
    }
    assertFalse(iter.hasNext());
}
 
Example #9
Source File: TransformProcessRecordReaderTests.java    From DataVec with Apache License 2.0
@Test
public void simpleTransformTest() throws Exception {
    Schema schema = new Schema.Builder()
            .addColumnsDouble("%d", 0, 4)
            .build();
    TransformProcess transformProcess = new TransformProcess.Builder(schema).removeColumns("0").build();
    CSVRecordReader csvRecordReader = new CSVRecordReader();
    csvRecordReader.initialize(new FileSplit(new ClassPathResource("iris.dat").getFile()));
    TransformProcessRecordReader rr =
                    new TransformProcessRecordReader(csvRecordReader, transformProcess);
    int count = 0;
    List<List<Writable>> all = new ArrayList<>();
    while(rr.hasNext()){
        List<Writable> next = rr.next();
        assertEquals(4, next.size());
        count++;
        all.add(next);
    }
    assertEquals(150, count);

    //Test batch:
    assertTrue(rr.resetSupported());
    rr.reset();
    List<List<Writable>> batch = rr.next(150);
    assertEquals(all, batch);
}
 
Example #10
Source File: TestConcatenatingRecordReader.java    From DataVec with Apache License 2.0
@Test
public void test() throws Exception {

    CSVRecordReader rr = new CSVRecordReader(0, ',');
    rr.initialize(new FileSplit(new ClassPathResource("iris.dat").getFile()));

    CSVRecordReader rr2 = new CSVRecordReader(0, ',');
    rr2.initialize(new FileSplit(new ClassPathResource("iris.dat").getFile()));

    RecordReader rrC = new ConcatenatingRecordReader(rr, rr2);

    int count = 0;
    while(rrC.hasNext()){
        rrC.next();
        count++;
    }

    assertEquals(300, count);
}
 
Example #11
Source File: CSVRecordReaderTest.java    From DataVec with Apache License 2.0
@Test
public void testReset() throws Exception {
    CSVRecordReader rr = new CSVRecordReader(0, ',');
    rr.initialize(new FileSplit(new ClassPathResource("iris.dat").getFile()));

    int nResets = 5;
    for (int i = 0; i < nResets; i++) {

        int lineCount = 0;
        while (rr.hasNext()) {
            List<Writable> line = rr.next();
            assertEquals(5, line.size());
            lineCount++;
        }
        assertFalse(rr.hasNext());
        assertEquals(150, lineCount);
        rr.reset();
    }
}
 
Example #12
Source File: CSVRecordReaderTest.java    From DataVec with Apache License 2.0
@Test
public void testResetWithSkipLines() throws Exception {
    CSVRecordReader rr = new CSVRecordReader(10, ',');
    rr.initialize(new FileSplit(new ClassPathResource("iris.dat").getFile()));
    int lineCount = 0;
    while (rr.hasNext()) {
        rr.next();
        ++lineCount;
    }
    assertEquals(140, lineCount);
    rr.reset();
    lineCount = 0;
    while (rr.hasNext()) {
        rr.next();
        ++lineCount;
    }
    assertEquals(140, lineCount);
}
 
Example #13
Source File: CSVRecordReaderTest.java    From DataVec with Apache License 2.0
@Test(expected = NoSuchElementException.class)
public void testCsvSkipAllLines() throws IOException, InterruptedException {
    final int numLines = 4;
    String header = ",one,two,three";
    List<String> lines = new ArrayList<>();
    for (int i = 0; i < numLines; i++)
        lines.add(Integer.toString(i) + header);
    File tempFile = File.createTempFile("csvSkipLines", ".csv");
    FileUtils.writeLines(tempFile, lines);

    CSVRecordReader rr = new CSVRecordReader(numLines, ',');
    rr.initialize(new FileSplit(tempFile));
    rr.reset();
    assertFalse(rr.hasNext());
    rr.next();
}
 
Example #14
Source File: CSVRecordReaderTest.java    From DataVec with Apache License 2.0
@Test
public void testStreamReset() throws Exception {
    CSVRecordReader rr = new CSVRecordReader(0, ',');
    rr.initialize(new InputStreamInputSplit(new ClassPathResource("iris.dat").getInputStream()));

    int count = 0;
    while(rr.hasNext()){
        assertNotNull(rr.next());
        count++;
    }
    assertEquals(150, count);

    assertFalse(rr.resetSupported());

    try{
        rr.reset();
        fail("Expected exception");
    } catch (Exception e){
        //expected: CSVRecordReader cannot reset when reading from an InputStream
    }
}
 
Example #15
Source File: TestKryoSerialization.java    From DataVec with Apache License 2.0
@Test
public void testCsvRecordReader() throws Exception {
    SerializerInstance si = sc.env().serializer().newInstance();
    assertTrue(si instanceof KryoSerializerInstance);

    RecordReader r1 = new CSVRecordReader(1,'\t');
    RecordReader r2 = serDe(r1, si);

    File f = new ClassPathResource("iris_tab_delim.txt").getFile();
    r1.initialize(new FileSplit(f));
    r2.initialize(new FileSplit(f));

    while(r1.hasNext()){
        assertEquals(r1.next(), r2.next());
    }
    assertFalse(r2.hasNext());
}
 
Example #16
Source File: TransformProcessRecordReaderTests.java    From deeplearning4j with Apache License 2.0
@Test
public void simpleTransformTest() throws Exception {
    Schema schema = new Schema.Builder()
            .addColumnsDouble("%d", 0, 4)
            .build();
    TransformProcess transformProcess = new TransformProcess.Builder(schema).removeColumns("0").build();
    CSVRecordReader csvRecordReader = new CSVRecordReader();
    csvRecordReader.initialize(new FileSplit(new ClassPathResource("datavec-api/iris.dat").getFile()));
    TransformProcessRecordReader rr =
                    new TransformProcessRecordReader(csvRecordReader, transformProcess);
    int count = 0;
    List<List<Writable>> all = new ArrayList<>();
    while(rr.hasNext()){
        List<Writable> next = rr.next();
        assertEquals(4, next.size());
        count++;
        all.add(next);
    }
    assertEquals(150, count);

    //Test batch:
    assertTrue(rr.resetSupported());
    rr.reset();
    List<List<Writable>> batch = rr.next(150);
    assertEquals(all, batch);
}
 
Example #17
Source File: TestConcatenatingRecordReader.java    From deeplearning4j with Apache License 2.0
@Test
public void test() throws Exception {

    CSVRecordReader rr = new CSVRecordReader(0, ',');
    rr.initialize(new FileSplit(new ClassPathResource("datavec-api/iris.dat").getFile()));

    CSVRecordReader rr2 = new CSVRecordReader(0, ',');
    rr2.initialize(new FileSplit(new ClassPathResource("datavec-api/iris.dat").getFile()));

    RecordReader rrC = new ConcatenatingRecordReader(rr, rr2);

    int count = 0;
    while(rrC.hasNext()){
        rrC.next();
        count++;
    }

    assertEquals(300, count);
}
 
Example #18
Source File: RecordReaderDataSetiteratorTest.java    From deeplearning4j with Apache License 2.0
@Test
public void testNormalizerPrefetchReset() throws Exception {
    //Check NPE fix for: https://github.com/deeplearning4j/deeplearning4j/issues/4214
    RecordReader csv = new CSVRecordReader();
    csv.initialize(new FileSplit(Resources.asFile("iris.txt")));

    int batchSize = 3;

    DataSetIterator iter = new RecordReaderDataSetIterator(csv, batchSize, 4, 4, true);

    DataNormalization normalizer = new NormalizerMinMaxScaler(0, 1);
    normalizer.fit(iter);
    iter.setPreProcessor(normalizer);

    iter.inputColumns();    //Prefetch
    iter.totalOutcomes();
    iter.hasNext();
    iter.reset();
    iter.next();
}
 
Example #19
Source File: CSVRecordReaderTest.java    From deeplearning4j with Apache License 2.0
@Test
public void testReset() throws Exception {
    CSVRecordReader rr = new CSVRecordReader(0, ',');
    rr.initialize(new FileSplit(new ClassPathResource("datavec-api/iris.dat").getFile()));

    int nResets = 5;
    for (int i = 0; i < nResets; i++) {

        int lineCount = 0;
        while (rr.hasNext()) {
            List<Writable> line = rr.next();
            assertEquals(5, line.size());
            lineCount++;
        }
        assertFalse(rr.hasNext());
        assertEquals(150, lineCount);
        rr.reset();
    }
}
 
Example #20
Source File: RecordReaderMultiDataSetIteratorTest.java    From deeplearning4j with Apache License 2.0
@Test
public void testsBasicMeta() throws Exception {
    //As per testBasic - but also loading metadata
    RecordReader rr2 = new CSVRecordReader(0, ',');
    rr2.initialize(new FileSplit(Resources.asFile("iris.txt")));

    RecordReaderMultiDataSetIterator rrmdsi = new RecordReaderMultiDataSetIterator.Builder(10)
                    .addReader("reader", rr2).addInput("reader", 0, 3).addOutputOneHot("reader", 4, 3).build();

    rrmdsi.setCollectMetaData(true);

    int count = 0;
    while (rrmdsi.hasNext()) {
        MultiDataSet mds = rrmdsi.next();
        MultiDataSet fromMeta = rrmdsi.loadFromMetaData(mds.getExampleMetaData(RecordMetaData.class));
        assertEquals(mds, fromMeta);
        count++;
    }
    assertEquals(150 / 10, count);
}
 
Example #21
Source File: MultipleEpochsIteratorTest.java    From deeplearning4j with Apache License 2.0
@Test
public void testLoadBatchDataSet() throws Exception {
    int epochs = 2;

    RecordReader rr = new CSVRecordReader();
    rr.initialize(new FileSplit(new ClassPathResource("iris.txt").getFile()));
    DataSetIterator iter = new RecordReaderDataSetIterator(rr, 150, 4, 3);
    DataSet ds = iter.next(20);
    assertEquals(20, ds.getFeatures().size(0));
    MultipleEpochsIterator multiIter = new MultipleEpochsIterator(epochs, ds);

    while (multiIter.hasNext()) {
        DataSet path = multiIter.next(10);
        assertNotNull(path);
        assertEquals(10, path.numExamples(), 0.0);
    }

    assertEquals(epochs, multiIter.epochs);
}
 
Example #22
Source File: CSVRecordReaderTest.java    From deeplearning4j with Apache License 2.0
@Test(expected = NoSuchElementException.class)
public void testCsvSkipAllLines() throws IOException, InterruptedException {
    final int numLines = 4;
    String header = ",one,two,three";
    List<String> lines = new ArrayList<>();
    for (int i = 0; i < numLines; i++)
        lines.add(Integer.toString(i) + header);
    File tempFile = File.createTempFile("csvSkipLines", ".csv");
    FileUtils.writeLines(tempFile, lines);

    CSVRecordReader rr = new CSVRecordReader(numLines, ',');
    rr.initialize(new FileSplit(tempFile));
    rr.reset();
    assertFalse(rr.hasNext());
    rr.next();
}
 
Example #23
Source File: CSVRecordReaderTest.java    From deeplearning4j with Apache License 2.0
@Test
public void testCsvSkipAllButOneLine() throws IOException, InterruptedException {
    final int numLines = 4;
    final List<Writable> lineList = Arrays.<Writable>asList(new Text(Integer.toString(numLines - 1)),
            new Text("one"), new Text("two"), new Text("three"));
    String header = ",one,two,three";
    List<String> lines = new ArrayList<>();
    for (int i = 0; i < numLines; i++)
        lines.add(Integer.toString(i) + header);
    File tempFile = File.createTempFile("csvSkipLines", ".csv");
    FileUtils.writeLines(tempFile, lines);

    CSVRecordReader rr = new CSVRecordReader(numLines - 1, ',');
    rr.initialize(new FileSplit(tempFile));
    rr.reset();
    assertTrue(rr.hasNext());
    assertEquals(rr.next(), lineList);
}
 
Example #24
Source File: CSVRecordReaderTest.java    From deeplearning4j with Apache License 2.0
@Test
public void testStreamReset() throws Exception {
    CSVRecordReader rr = new CSVRecordReader(0, ',');
    rr.initialize(new InputStreamInputSplit(new ClassPathResource("datavec-api/iris.dat").getInputStream()));

    int count = 0;
    while(rr.hasNext()){
        assertNotNull(rr.next());
        count++;
    }
    assertEquals(150, count);

    assertFalse(rr.resetSupported());

    try{
        rr.reset();
        fail("Expected exception");
    } catch (Exception e){
        String msg = e.getMessage();
        String msg2 = e.getCause().getMessage();
        assertTrue(msg, msg.contains("Error during LineRecordReader reset"));
        assertTrue(msg2, msg2.contains("Reset not supported from streams"));
    }
}
 
Example #25
Source File: MultipleEpochsIteratorTest.java    From deeplearning4j with Apache License 2.0
@Test
public void testNextAndReset() throws Exception {
    int epochs = 3;

    RecordReader rr = new CSVRecordReader();
    rr.initialize(new FileSplit(Resources.asFile("iris.txt")));
    DataSetIterator iter = new RecordReaderDataSetIterator(rr, 150);
    MultipleEpochsIterator multiIter = new MultipleEpochsIterator(epochs, iter);

    assertTrue(multiIter.hasNext());
    while (multiIter.hasNext()) {
        DataSet path = multiIter.next();
        assertNotNull(path);
    }
    assertEquals(epochs, multiIter.epochs);
}
 
Example #26
Source File: TestKryoSerialization.java    From deeplearning4j with Apache License 2.0
@Test
public void testCsvRecordReader() throws Exception {
    SerializerInstance si = sc.env().serializer().newInstance();
    assertTrue(si instanceof KryoSerializerInstance);

    RecordReader r1 = new CSVRecordReader(1,'\t');
    RecordReader r2 = serDe(r1, si);

    File f = new ClassPathResource("iris_tab_delim.txt").getFile();
    r1.initialize(new FileSplit(f));
    r2.initialize(new FileSplit(f));

    while(r1.hasNext()){
        assertEquals(r1.next(), r2.next());
    }
    assertFalse(r2.hasNext());
}
 
Example #27
Source File: CSVRecordReaderTest.java    From DataVec with Apache License 2.0
@Test
public void testCsvSkipAllButOneLine() throws IOException, InterruptedException {
    final int numLines = 4;
    final List<Writable> lineList = Arrays.<Writable>asList(new Text(Integer.toString(numLines - 1)),
            new Text("one"), new Text("two"), new Text("three"));
    String header = ",one,two,three";
    List<String> lines = new ArrayList<>();
    for (int i = 0; i < numLines; i++)
        lines.add(Integer.toString(i) + header);
    File tempFile = File.createTempFile("csvSkipLines", ".csv");
    FileUtils.writeLines(tempFile, lines);

    CSVRecordReader rr = new CSVRecordReader(numLines - 1, ',');
    rr.initialize(new FileSplit(tempFile));
    rr.reset();
    assertTrue(rr.hasNext());
    assertEquals(rr.next(), lineList);
}
 
Example #28
Source File: CSVInputFormat.java    From deeplearning4j with Apache License 2.0
@Override
public RecordReader createReader(InputSplit split) throws IOException, InterruptedException {
    CSVRecordReader ret = new CSVRecordReader();
    ret.initialize(split);
    return ret;
}
 
Example #29
Source File: TestAnalyzeLocal.java    From deeplearning4j with Apache License 2.0
@Test
public void testAnalysisBasic() throws Exception {

    RecordReader rr = new CSVRecordReader();
    rr.initialize(new FileSplit(new ClassPathResource("iris.txt").getFile()));

    Schema s = new Schema.Builder()
            .addColumnsDouble("0", "1", "2", "3")
            .addColumnInteger("label")
            .build();

    DataAnalysis da = AnalyzeLocal.analyze(s, rr);

    System.out.println(da);

    //Compare:
    List<List<Writable>> list = new ArrayList<>();
    rr.reset();
    while(rr.hasNext()){
        list.add(rr.next());
    }

    INDArray arr = RecordConverter.toMatrix(DataType.DOUBLE, list);
    INDArray mean = arr.mean(0);
    INDArray std = arr.std(0);

    for( int i=0; i<5; i++ ){
        double m = ((NumericalColumnAnalysis)da.getColumnAnalysis().get(i)).getMean();
        double stddev = ((NumericalColumnAnalysis)da.getColumnAnalysis().get(i)).getSampleStdev();
        assertEquals(mean.getDouble(i), m, 1e-3);
        assertEquals(std.getDouble(i), stddev, 1e-3);
    }

}