org.datavec.api.split.NumberedFileInputSplit Java Examples

The following examples show how to use org.datavec.api.split.NumberedFileInputSplit. You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may check out the related API usage on the sidebar.
Example #1
Source File: SameDiffRNNTestCases.java    From deeplearning4j with Apache License 2.0 6 votes vote down vote up
/**
 * Builds the raw (un-normalized) training iterator for the UCI sequence data.
 *
 * <p>The feature and label CSV files are copied from the test resources into two fresh
 * temporary directories and then read as numbered files {@code 0.csv} .. {@code 449.csv}.
 *
 * @return the training data wrapped as a {@link MultiDataSetIterator}
 * @throws Exception if the resources cannot be copied or the readers cannot be initialized
 */
protected MultiDataSetIterator getTrainingDataUnnormalized() throws Exception {
    final int batchSize = 10;
    final int classCount = 6;

    //Stage the resource directories on disk so they can be addressed by a numbered path pattern
    File featureDir = Files.createTempDir();
    File labelDir = Files.createTempDir();
    Resources.copyDirectory("dl4j-integration-tests/data/uci_seq/train/features/", featureDir);
    Resources.copyDirectory("dl4j-integration-tests/data/uci_seq/train/labels/", labelDir);

    String featurePattern = featureDir.getAbsolutePath() + "/%d.csv";
    String labelPattern = labelDir.getAbsolutePath() + "/%d.csv";

    SequenceRecordReader featureReader = new CSVSequenceRecordReader();
    featureReader.initialize(new NumberedFileInputSplit(featurePattern, 0, 449));
    SequenceRecordReader labelReader = new CSVSequenceRecordReader();
    labelReader.initialize(new NumberedFileInputSplit(labelPattern, 0, 449));

    DataSetIterator sequences = new SequenceRecordReaderDataSetIterator(
            featureReader, labelReader, batchSize, classCount, false,
            SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_END);

    return new MultiDataSetIteratorAdapter(sequences);
}
 
Example #2
Source File: JacksonRecordReaderTest.java    From DataVec with Apache License 2.0 6 votes vote down vote up
/**
 * Reads three numbered JSON files through a {@link JacksonRecordReader} and checks the
 * extracted values via {@code testJacksonRecordReader}.
 */
@Test
public void testReadingJson() throws Exception {
    //Load 3 values from 3 JSON files
    //structure: a:value, b:value, c:x:value, c:y:value
    //And we want to load only a:value, b:value and c:x:value
    //For first JSON file: all values are present
    //For second JSON file: b:value is missing
    //For third JSON file: c:x:value is missing

    ClassPathResource cpr = new ClassPathResource("json/json_test_0.txt");
    String absolutePath = cpr.getFile().getAbsolutePath();
    //Substitute only the file index (the last '0' in the path). A blanket replace("0", "%d")
    //would also rewrite any '0' appearing in a parent directory name and corrupt the pattern.
    int indexPos = absolutePath.lastIndexOf('0');
    String path = absolutePath.substring(0, indexPos) + "%d" + absolutePath.substring(indexPos + 1);

    InputSplit is = new NumberedFileInputSplit(path, 0, 2);

    RecordReader rr = new JacksonRecordReader(getFieldSelection(), new ObjectMapper(new JsonFactory()));
    rr.initialize(is);

    testJacksonRecordReader(rr);
}
 
Example #3
Source File: RNNTestCases.java    From deeplearning4j with Apache License 2.0 6 votes vote down vote up
/**
 * Builds the raw (un-normalized) training iterator for the UCI sequence data.
 *
 * <p>The feature and label CSV files are copied from the classpath into two fresh temporary
 * directories and then read as numbered files {@code 0.csv} .. {@code 449.csv}.
 *
 * @return the training data wrapped as a {@link MultiDataSetIterator}
 * @throws Exception if the resources cannot be copied or the readers cannot be initialized
 */
protected MultiDataSetIterator getTrainingDataUnnormalized() throws Exception {
    final int batchSize = 10;
    final int classCount = 6;

    //Stage the classpath directories on disk so they can be addressed by a numbered path pattern
    File featureDir = Files.createTempDir();
    File labelDir = Files.createTempDir();
    new ClassPathResource("dl4j-integration-tests/data/uci_seq/train/features/").copyDirectory(featureDir);
    new ClassPathResource("dl4j-integration-tests/data/uci_seq/train/labels/").copyDirectory(labelDir);

    String featurePattern = featureDir.getAbsolutePath() + "/%d.csv";
    String labelPattern = labelDir.getAbsolutePath() + "/%d.csv";

    SequenceRecordReader featureReader = new CSVSequenceRecordReader();
    featureReader.initialize(new NumberedFileInputSplit(featurePattern, 0, 449));
    SequenceRecordReader labelReader = new CSVSequenceRecordReader();
    labelReader.initialize(new NumberedFileInputSplit(labelPattern, 0, 449));

    DataSetIterator sequences = new SequenceRecordReaderDataSetIterator(
            featureReader, labelReader, batchSize, classCount, false,
            SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_END);

    return new MultiDataSetIteratorAdapter(sequences);
}
 
Example #4
Source File: CSVSequenceRecordReaderTest.java    From deeplearning4j with Apache License 2.0 6 votes vote down vote up
/**
 * Sanity check: a {@link CSVSequenceRecordReader} initialized with a
 * {@link NumberedFileInputSplit} over {@code csvsequence_0.txt} .. {@code csvsequence_2.txt}
 * can be iterated to exhaustion without throwing.
 */
@Test
public void testCsvSeqAndNumberedFileSplit() throws Exception {
    File baseDir = tempDir.newFolder();
    //Extract the three sequence files into the temp folder so they share one directory
    for (int i = 0; i < 3; i++) {
        new ClassPathResource(String.format("csvsequence_%d.txt", i)).getTempFileFromArchive(baseDir);
    }

    //Numbered pattern pointing at the extracted files
    String featuresPath = new File(baseDir, "csvsequence_%d.txt").getAbsolutePath();

    SequenceRecordReader featureReader = new CSVSequenceRecordReader(1, ",");
    featureReader.initialize(new NumberedFileInputSplit(featuresPath, 0, 2));

    //No assertions on content: the test passes as long as every sequence can be read
    while (featureReader.hasNext()) {
        featureReader.nextSequence();
    }
}
 
Example #5
Source File: CSVSequenceRecordReaderTest.java    From DataVec with Apache License 2.0 6 votes vote down vote up
/**
 * Sanity check: a {@link CSVSequenceRecordReader} initialized with a
 * {@link NumberedFileInputSplit} over {@code csvsequence_0.txt} .. {@code csvsequence_2.txt}
 * can be iterated to exhaustion without throwing.
 */
@Test
public void testCsvSeqAndNumberedFileSplit() throws Exception {
    File baseDir = tempDir.newFolder();
    //Extract the three sequence files into the temp folder so they share one directory
    for (int i = 0; i < 3; i++) {
        new org.nd4j.linalg.io.ClassPathResource(String.format("csvsequence_%d.txt", i)).getTempFileFromArchive(baseDir);
    }

    //Numbered pattern pointing at the extracted files
    String featuresPath = new File(baseDir, "csvsequence_%d.txt").getAbsolutePath();

    SequenceRecordReader featureReader = new CSVSequenceRecordReader(1, ",");
    featureReader.initialize(new NumberedFileInputSplit(featuresPath, 0, 2));

    //No assertions on content: the test passes as long as every sequence can be read
    while (featureReader.hasNext()) {
        featureReader.nextSequence();
    }
}
 
Example #6
Source File: JacksonRecordReaderTest.java    From deeplearning4j with Apache License 2.0 6 votes vote down vote up
/**
 * Reads three numbered JSON files through a {@link JacksonRecordReader} and checks the
 * extracted values via {@code testJacksonRecordReader}.
 */
@Test
public void testReadingJson() throws Exception {
    //Three JSON files, each with the structure a:value, b:value, c:x:value, c:y:value.
    //Only a:value, b:value and c:x:value are selected.
    //  file 0: all values are present
    //  file 1: b:value is missing
    //  file 2: c:x:value is missing

    ClassPathResource jsonResources = new ClassPathResource("datavec-api/json/");
    File jsonDir = testDir.newFolder();
    jsonResources.copyDirectory(jsonDir);
    String pattern = new File(jsonDir, "json_test_%d.txt").getAbsolutePath();

    InputSplit numberedFiles = new NumberedFileInputSplit(pattern, 0, 2);

    ObjectMapper jsonMapper = new ObjectMapper(new JsonFactory());
    RecordReader reader = new JacksonRecordReader(getFieldSelection(), jsonMapper);
    reader.initialize(numberedFiles);

    testJacksonRecordReader(reader);
}
 
Example #7
Source File: JacksonRecordReaderTest.java    From deeplearning4j with Apache License 2.0 6 votes vote down vote up
/**
 * Same check as the JSON variant, but the three numbered input files are YAML and are parsed
 * with a {@link YAMLFactory}-backed {@link ObjectMapper}.
 */
@Test
public void testReadingYaml() throws Exception {
    ClassPathResource yamlResources = new ClassPathResource("datavec-api/yaml/");
    File yamlDir = testDir.newFolder();
    yamlResources.copyDirectory(yamlDir);
    String pattern = new File(yamlDir, "yaml_test_%d.txt").getAbsolutePath();

    InputSplit numberedFiles = new NumberedFileInputSplit(pattern, 0, 2);

    ObjectMapper yamlMapper = new ObjectMapper(new YAMLFactory());
    RecordReader reader = new JacksonRecordReader(getFieldSelection(), yamlMapper);
    reader.initialize(numberedFiles);

    testJacksonRecordReader(reader);
}
 
Example #8
Source File: JacksonRecordReaderTest.java    From deeplearning4j with Apache License 2.0 6 votes vote down vote up
/**
 * Same check as the JSON variant, but the three numbered input files are XML and are parsed
 * with an {@link XmlFactory}-backed {@link ObjectMapper}.
 */
@Test
public void testReadingXml() throws Exception {
    ClassPathResource xmlResources = new ClassPathResource("datavec-api/xml/");
    File xmlDir = testDir.newFolder();
    xmlResources.copyDirectory(xmlDir);
    String pattern = new File(xmlDir, "xml_test_%d.txt").getAbsolutePath();

    InputSplit numberedFiles = new NumberedFileInputSplit(pattern, 0, 2);

    ObjectMapper xmlMapper = new ObjectMapper(new XmlFactory());
    RecordReader reader = new JacksonRecordReader(getFieldSelection(), xmlMapper);
    reader.initialize(numberedFiles);

    testJacksonRecordReader(reader);
}
 
Example #9
Source File: JacksonRecordReaderTest.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
/**
 * Checks that a label-appending {@link JacksonRecordReader} yields the same records through
 * {@code next()}, {@code nextRecord()} and {@code loadFromMetaData(...)}.
 */
@Test
public void testAppendingLabelsMetaData() throws Exception {
    ClassPathResource jsonResources = new ClassPathResource("datavec-api/json/");
    File jsonDir = testDir.newFolder();
    jsonResources.copyDirectory(jsonDir);
    String pattern = new File(jsonDir, "json_test_%d.txt").getAbsolutePath();

    InputSplit numberedFiles = new NumberedFileInputSplit(pattern, 0, 2);

    //Insert at the end:
    ObjectMapper jsonMapper = new ObjectMapper(new JsonFactory());
    RecordReader reader = new JacksonRecordReader(getFieldSelection(), jsonMapper, false, -1, new LabelGen());
    reader.initialize(numberedFiles);

    //First pass: plain values only
    List<List<Writable>> firstPass = new ArrayList<>();
    while (reader.hasNext()) {
        firstPass.add(reader.next());
    }
    assertEquals(3, firstPass.size());

    reader.reset();

    //Second pass: full records, keeping their metadata for the reload check below
    List<List<Writable>> secondPass = new ArrayList<>();
    List<Record> records = new ArrayList<>();
    List<RecordMetaData> metaData = new ArrayList<>();
    while (reader.hasNext()) {
        Record current = reader.nextRecord();
        secondPass.add(current.getRecord());
        records.add(current);
        metaData.add(current.getMetaData());
    }

    assertEquals(firstPass, secondPass);

    List<Record> reloaded = reader.loadFromMetaData(metaData);
    assertEquals(records, reloaded);
}
 
Example #10
Source File: SameDiffRNNTestCases.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
/**
 * Builds the evaluation iterator for the UCI sequence test data.
 *
 * <p>The test feature and label CSV files ({@code 0.csv} .. {@code 149.csv}) are copied into
 * fresh temporary directories. A pre-processor chain is attached that first applies
 * {@code getNormalizer()} and then reduces the labels to their last index along dimension 2,
 * clearing the labels mask.
 *
 * @return the pre-processed evaluation data
 * @throws Exception if the resources cannot be copied or the readers cannot be initialized
 */
@Override
public MultiDataSetIterator getEvaluationTestData() throws Exception {
    final int batchSize = 10;
    final int classCount = 6;

    File featureDir = Files.createTempDir();
    File labelDir = Files.createTempDir();
    Resources.copyDirectory("dl4j-integration-tests/data/uci_seq/test/features/", featureDir);
    Resources.copyDirectory("dl4j-integration-tests/data/uci_seq/test/labels/", labelDir);

    String featurePattern = featureDir.getAbsolutePath() + "/%d.csv";
    String labelPattern = labelDir.getAbsolutePath() + "/%d.csv";

    SequenceRecordReader featureReader = new CSVSequenceRecordReader();
    featureReader.initialize(new NumberedFileInputSplit(featurePattern, 0, 149));
    SequenceRecordReader labelReader = new CSVSequenceRecordReader();
    labelReader.initialize(new NumberedFileInputSplit(labelPattern, 0, 149));

    DataSetIterator sequences = new SequenceRecordReaderDataSetIterator(
            featureReader, labelReader, batchSize, classCount, false,
            SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_END);

    MultiDataSetIterator evaluationData = new MultiDataSetIteratorAdapter(sequences);

    //Keep only the final index of dimension 2 of the labels and drop the labels mask
    MultiDataSetPreProcessor lastStepLabels = mds -> {
        INDArray allSteps = mds.getLabels(0);
        INDArray lastStep = allSteps.get(NDArrayIndex.all(), NDArrayIndex.all(),
                NDArrayIndex.point(allSteps.size(2) - 1));
        mds.setLabels(0, lastStep);
        mds.setLabelsMaskArray(0, null);
    };

    evaluationData.setPreProcessor(new CompositeMultiDataSetPreProcessor(getNormalizer(), lastStepLabels));

    return evaluationData;
}
 
Example #11
Source File: RNNTestCases.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
/**
 * Builds the evaluation iterator for the UCI sequence test data.
 *
 * <p>The test feature and label CSV files ({@code 0.csv} .. {@code 149.csv}) are copied from
 * the classpath into fresh temporary directories. A pre-processor chain is attached that first
 * applies {@code getNormalizer()} and then reduces the labels to their last index along
 * dimension 2, clearing the labels mask.
 *
 * @return the pre-processed evaluation data
 * @throws Exception if the resources cannot be copied or the readers cannot be initialized
 */
@Override
public MultiDataSetIterator getEvaluationTestData() throws Exception {
    final int batchSize = 10;
    final int classCount = 6;

    File featureDir = Files.createTempDir();
    File labelDir = Files.createTempDir();
    new ClassPathResource("dl4j-integration-tests/data/uci_seq/test/features/").copyDirectory(featureDir);
    new ClassPathResource("dl4j-integration-tests/data/uci_seq/test/labels/").copyDirectory(labelDir);

    String featurePattern = featureDir.getAbsolutePath() + "/%d.csv";
    String labelPattern = labelDir.getAbsolutePath() + "/%d.csv";

    SequenceRecordReader featureReader = new CSVSequenceRecordReader();
    featureReader.initialize(new NumberedFileInputSplit(featurePattern, 0, 149));
    SequenceRecordReader labelReader = new CSVSequenceRecordReader();
    labelReader.initialize(new NumberedFileInputSplit(labelPattern, 0, 149));

    DataSetIterator sequences = new SequenceRecordReaderDataSetIterator(
            featureReader, labelReader, batchSize, classCount, false,
            SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_END);

    MultiDataSetIterator evaluationData = new MultiDataSetIteratorAdapter(sequences);

    //Keep only the final index of dimension 2 of the labels and drop the labels mask
    MultiDataSetPreProcessor lastStepLabels = mds -> {
        INDArray allSteps = mds.getLabels(0);
        INDArray lastStep = allSteps.get(NDArrayIndex.all(), NDArrayIndex.all(),
                NDArrayIndex.point(allSteps.size(2) - 1));
        mds.setLabels(0, lastStep);
        mds.setLabelsMaskArray(0, null);
    };

    evaluationData.setPreProcessor(new CompositeMultiDataSetPreProcessor(getNormalizer(), lastStepLabels));

    return evaluationData;
}
 
Example #12
Source File: RecordReaderMultiDataSetIteratorTest.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
/**
 * Checks that every {@link MultiDataSet} produced by a {@link RecordReaderMultiDataSetIterator}
 * over split CSV sequences can be reloaded from its collected metadata and equals the original.
 */
@Test
public void testSplittingCSVSequenceMeta() throws Exception {
    //Idea: take CSV sequences, and split "csvsequence_i.txt" into two separate inputs; keep "csvsequencelabels_i.txt"
    // as standard one-hot output
    //need to manually extract
    File rootDir = temporaryFolder.newFolder();
    for (int i = 0; i < 3; i++) {
        new ClassPathResource(String.format("csvsequence_%d.txt", i)).getTempFileFromArchive(rootDir);
        new ClassPathResource(String.format("csvsequencelabels_%d.txt", i)).getTempFileFromArchive(rootDir);
        new ClassPathResource(String.format("csvsequencelabelsShort_%d.txt", i)).getTempFileFromArchive(rootDir);
    }

    String featuresPath = FilenameUtils.concat(rootDir.getAbsolutePath(), "csvsequence_%d.txt");
    String labelsPath = FilenameUtils.concat(rootDir.getAbsolutePath(), "csvsequencelabels_%d.txt");

    //Only one pair of readers is needed here: the previously declared first pair was
    //initialized but never consumed by anything in this test.
    SequenceRecordReader featureReader2 = new CSVSequenceRecordReader(1, ",");
    SequenceRecordReader labelReader2 = new CSVSequenceRecordReader(1, ",");
    featureReader2.initialize(new NumberedFileInputSplit(featuresPath, 0, 2));
    labelReader2.initialize(new NumberedFileInputSplit(labelsPath, 0, 2));

    RecordReaderMultiDataSetIterator srrmdsi = new RecordReaderMultiDataSetIterator.Builder(1)
                    .addSequenceReader("seq1", featureReader2).addSequenceReader("seq2", labelReader2)
                    .addInput("seq1", 0, 1).addInput("seq1", 2, 2).addOutputOneHot("seq2", 0, 4).build();

    srrmdsi.setCollectMetaData(true);

    //One MultiDataSet per numbered file is expected (3 files, minibatch size 1)
    int count = 0;
    while (srrmdsi.hasNext()) {
        MultiDataSet mds = srrmdsi.next();
        MultiDataSet fromMeta = srrmdsi.loadFromMetaData(mds.getExampleMetaData(RecordMetaData.class));
        assertEquals(mds, fromMeta);
        count++;
    }
    assertEquals(3, count);
}
 
Example #13
Source File: RecordReaderDataSetiteratorTest.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
/**
 * Checks that a {@link SequenceRecordReaderDataSetIterator} can be reset repeatedly and yields
 * the same number of examples, with the same array shapes, on every pass.
 */
@Test
public void testSequenceRecordReaderReset() throws Exception {
    File rootDir = temporaryFolder.newFolder();
    //need to manually extract
    for (int fileIdx = 0; fileIdx < 3; fileIdx++) {
        String featureName = String.format("csvsequence_%d.txt", fileIdx);
        String labelName = String.format("csvsequencelabels_%d.txt", fileIdx);
        FileUtils.copyFile(Resources.asFile(featureName), new File(rootDir, featureName));
        FileUtils.copyFile(Resources.asFile(labelName), new File(rootDir, labelName));
    }
    String rootPath = rootDir.getAbsolutePath();
    String featurePattern = FilenameUtils.concat(rootPath, "csvsequence_%d.txt");
    String labelPattern = FilenameUtils.concat(rootPath, "csvsequencelabels_%d.txt");

    SequenceRecordReader features = new CSVSequenceRecordReader(1, ",");
    SequenceRecordReader labels = new CSVSequenceRecordReader(1, ",");
    features.initialize(new NumberedFileInputSplit(featurePattern, 0, 2));
    labels.initialize(new NumberedFileInputSplit(labelPattern, 0, 2));

    SequenceRecordReaderDataSetIterator iterator =
                    new SequenceRecordReaderDataSetIterator(features, labels, 1, 4, false);

    assertEquals(3, iterator.inputColumns());
    assertEquals(4, iterator.totalOutcomes());

    //Every pass after a reset must see all 3 files again
    final int passes = 5;
    for (int pass = 0; pass < passes; pass++) {
        iterator.reset();
        int seen = 0;
        while (iterator.hasNext()) {
            DataSet example = iterator.next();
            INDArray featureArray = example.getFeatures();
            INDArray labelArray = example.getLabels();
            assertArrayEquals(new long[] {1, 3, 4}, featureArray.shape());
            assertArrayEquals(new long[] {1, 4, 4}, labelArray.shape());
            seen++;
        }
        assertEquals(3, seen);
    }
}
 
Example #14
Source File: RecordReaderDataSetiteratorTest.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
/**
 * Checks that each {@link DataSet} produced by a {@link SequenceRecordReaderDataSetIterator}
 * with metadata collection enabled can be reloaded from its metadata and equals the original.
 */
@Test
public void testSequenceRecordReaderMeta() throws Exception {
    File rootDir = temporaryFolder.newFolder();
    //need to manually extract
    for (int fileIdx = 0; fileIdx < 3; fileIdx++) {
        String featureName = String.format("csvsequence_%d.txt", fileIdx);
        String labelName = String.format("csvsequencelabels_%d.txt", fileIdx);
        FileUtils.copyFile(Resources.asFile(featureName), new File(rootDir, featureName));
        FileUtils.copyFile(Resources.asFile(labelName), new File(rootDir, labelName));
    }
    String rootPath = rootDir.getAbsolutePath();
    String featurePattern = FilenameUtils.concat(rootPath, "csvsequence_%d.txt");
    String labelPattern = FilenameUtils.concat(rootPath, "csvsequencelabels_%d.txt");

    SequenceRecordReader features = new CSVSequenceRecordReader(1, ",");
    SequenceRecordReader labels = new CSVSequenceRecordReader(1, ",");
    features.initialize(new NumberedFileInputSplit(featurePattern, 0, 2));
    labels.initialize(new NumberedFileInputSplit(labelPattern, 0, 2));

    SequenceRecordReaderDataSetIterator iterator =
                    new SequenceRecordReaderDataSetIterator(features, labels, 1, 4, false);

    iterator.setCollectMetaData(true);

    assertEquals(3, iterator.inputColumns());
    assertEquals(4, iterator.totalOutcomes());

    //Round-trip every example through its metadata
    while (iterator.hasNext()) {
        DataSet original = iterator.next();
        List<RecordMetaData> metaData = original.getExampleMetaData(RecordMetaData.class);
        DataSet reloaded = iterator.loadFromMetaData(metaData);

        assertEquals(original, reloaded);
    }
}
 
Example #15
Source File: RegexRecordReaderTest.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
/**
 * Checks that a {@link RegexSequenceRecordReader} over two numbered log files yields the same
 * sequences through {@code sequenceRecord()}, {@code nextSequence()} and
 * {@code loadSequenceFromMetaData(...)}.
 */
@Test
public void testRegexSequenceRecordReaderMeta() throws Exception {
    String regex = "(\\d{4}-\\d{2}-\\d{2} \\d{2}:\\d{2}:\\d{2}\\.\\d{3}) (\\d+) ([A-Z]+) (.*)";

    ClassPathResource logResources = new ClassPathResource("datavec-api/logtestdata/");
    File logDir = testDir.newFolder();
    logResources.copyDirectory(logDir);
    String pattern = new File(logDir, "logtestfile%d.txt").getAbsolutePath();

    InputSplit numberedFiles = new NumberedFileInputSplit(pattern, 0, 1);

    SequenceRecordReader reader = new RegexSequenceRecordReader(regex, 1);
    reader.initialize(numberedFiles);

    //First pass: raw sequences only
    List<List<List<Writable>>> firstPass = new ArrayList<>();
    while (reader.hasNext()) {
        firstPass.add(reader.sequenceRecord());
    }

    assertEquals(2, firstPass.size());

    //Second pass: full sequence records plus their metadata
    List<List<List<Writable>>> secondPass = new ArrayList<>();
    List<SequenceRecord> sequenceRecords = new ArrayList<>();
    List<RecordMetaData> metaData = new ArrayList<>();
    reader.reset();
    while (reader.hasNext()) {
        SequenceRecord current = reader.nextSequence();
        secondPass.add(current.getSequenceRecord());
        sequenceRecords.add(current);
        metaData.add(current.getMetaData());
    }

    List<SequenceRecord> reloaded = reader.loadSequenceFromMetaData(metaData);

    assertEquals(firstPass, secondPass);
    assertEquals(sequenceRecords, reloaded);
}
 
Example #16
Source File: RegexRecordReaderTest.java    From DataVec with Apache License 2.0 5 votes vote down vote up
/**
 * Checks that a {@link RegexSequenceRecordReader} over two numbered log files yields the same
 * sequences through {@code sequenceRecord()}, {@code nextSequence()} and
 * {@code loadSequenceFromMetaData(...)}.
 */
@Test
public void testRegexSequenceRecordReaderMeta() throws Exception {
    String regex = "(\\d{4}-\\d{2}-\\d{2} \\d{2}:\\d{2}:\\d{2}\\.\\d{3}) (\\d+) ([A-Z]+) (.*)";

    String path = new ClassPathResource("/logtestdata/logtestfile0.txt").getFile().toURI().toString();
    //Substitute only the file index (the last '0' in the URI). A blanket replace("0", "%d")
    //would also rewrite any '0' appearing in a parent directory name and corrupt the pattern.
    int indexPos = path.lastIndexOf('0');
    path = path.substring(0, indexPos) + "%d" + path.substring(indexPos + 1);
    InputSplit is = new NumberedFileInputSplit(path, 0, 1);

    SequenceRecordReader rr = new RegexSequenceRecordReader(regex, 1);
    rr.initialize(is);

    //First pass: raw sequences only
    List<List<List<Writable>>> out = new ArrayList<>();
    while (rr.hasNext()) {
        out.add(rr.sequenceRecord());
    }

    assertEquals(2, out.size());

    //Second pass: full sequence records plus their metadata
    List<List<List<Writable>>> out2 = new ArrayList<>();
    List<SequenceRecord> out3 = new ArrayList<>();
    List<RecordMetaData> meta = new ArrayList<>();
    rr.reset();
    while (rr.hasNext()) {
        SequenceRecord seqr = rr.nextSequence();
        out2.add(seqr.getSequenceRecord());
        out3.add(seqr);
        meta.add(seqr.getMetaData());
    }

    List<SequenceRecord> fromMeta = rr.loadSequenceFromMetaData(meta);

    assertEquals(out, out2);
    assertEquals(out3, fromMeta);
}
 
Example #17
Source File: RegexRecordReaderTest.java    From DataVec with Apache License 2.0 5 votes vote down vote up
/**
 * Checks that a {@link RegexSequenceRecordReader} parses two numbered log files into the
 * expected timestamp / id / level / message columns, both before and after a reset.
 */
@Test
public void testRegexSequenceRecordReader() throws Exception {
    String regex = "(\\d{4}-\\d{2}-\\d{2} \\d{2}:\\d{2}:\\d{2}\\.\\d{3}) (\\d+) ([A-Z]+) (.*)";

    String path = new ClassPathResource("/logtestdata/logtestfile0.txt").getFile().toURI().toString();
    //Substitute only the file index (the last '0' in the URI). A blanket replace("0", "%d")
    //would also rewrite any '0' appearing in a parent directory name and corrupt the pattern.
    int indexPos = path.lastIndexOf('0');
    path = path.substring(0, indexPos) + "%d" + path.substring(indexPos + 1);

    InputSplit is = new NumberedFileInputSplit(path, 0, 1);

    SequenceRecordReader rr = new RegexSequenceRecordReader(regex, 1);
    rr.initialize(is);

    //Expected content of the first log file
    List<List<Writable>> exp0 = new ArrayList<>();
    exp0.add(Arrays.asList((Writable) new Text("2016-01-01 23:59:59.001"), new Text("1"), new Text("DEBUG"),
                    new Text("First entry message!")));
    exp0.add(Arrays.asList((Writable) new Text("2016-01-01 23:59:59.002"), new Text("2"), new Text("INFO"),
                    new Text("Second entry message!")));
    exp0.add(Arrays.asList((Writable) new Text("2016-01-01 23:59:59.003"), new Text("3"), new Text("WARN"),
                    new Text("Third entry message!")));


    //Expected content of the second log file
    List<List<Writable>> exp1 = new ArrayList<>();
    exp1.add(Arrays.asList((Writable) new Text("2016-01-01 23:59:59.011"), new Text("11"), new Text("DEBUG"),
                    new Text("First entry message!")));
    exp1.add(Arrays.asList((Writable) new Text("2016-01-01 23:59:59.012"), new Text("12"), new Text("INFO"),
                    new Text("Second entry message!")));
    exp1.add(Arrays.asList((Writable) new Text("2016-01-01 23:59:59.013"), new Text("13"), new Text("WARN"),
                    new Text("Third entry message!")));

    assertEquals(exp0, rr.sequenceRecord());
    assertEquals(exp1, rr.sequenceRecord());
    assertFalse(rr.hasNext());

    //Test resetting:
    rr.reset();
    assertEquals(exp0, rr.sequenceRecord());
    assertEquals(exp1, rr.sequenceRecord());
    assertFalse(rr.hasNext());
}
 
Example #18
Source File: JacksonRecordReaderTest.java    From DataVec with Apache License 2.0 5 votes vote down vote up
/**
 * Checks that a label-appending {@link JacksonRecordReader} yields the same records through
 * {@code next()}, {@code nextRecord()} and {@code loadFromMetaData(...)}.
 */
@Test
public void testAppendingLabelsMetaData() throws Exception {
    ClassPathResource cpr = new ClassPathResource("json/json_test_0.txt");
    String absolutePath = cpr.getFile().getAbsolutePath();
    //Substitute only the file index (the last '0' in the path). A blanket replace("0", "%d")
    //would also rewrite any '0' appearing in a parent directory name and corrupt the pattern.
    int indexPos = absolutePath.lastIndexOf('0');
    String path = absolutePath.substring(0, indexPos) + "%d" + absolutePath.substring(indexPos + 1);

    InputSplit is = new NumberedFileInputSplit(path, 0, 2);

    //Insert at the end:
    RecordReader rr = new JacksonRecordReader(getFieldSelection(), new ObjectMapper(new JsonFactory()), false, -1,
                    new LabelGen());
    rr.initialize(is);

    //First pass: plain values only
    List<List<Writable>> out = new ArrayList<>();
    while (rr.hasNext()) {
        out.add(rr.next());
    }
    assertEquals(3, out.size());

    rr.reset();

    //Second pass: full records, keeping their metadata for the reload check below
    List<List<Writable>> out2 = new ArrayList<>();
    List<Record> outRecord = new ArrayList<>();
    List<RecordMetaData> meta = new ArrayList<>();
    while (rr.hasNext()) {
        Record r = rr.nextRecord();
        out2.add(r.getRecord());
        outRecord.add(r);
        meta.add(r.getMetaData());
    }

    assertEquals(out, out2);

    List<Record> fromMeta = rr.loadFromMetaData(meta);
    assertEquals(outRecord, fromMeta);
}
 
Example #19
Source File: JacksonRecordReaderTest.java    From DataVec with Apache License 2.0 5 votes vote down vote up
/**
 * Same check as the JSON variant, but the three numbered input files are YAML and are parsed
 * with a {@link YAMLFactory}-backed {@link ObjectMapper}.
 */
@Test
public void testReadingYaml() throws Exception {
    //Exact same information as JSON format, but in YAML format

    ClassPathResource cpr = new ClassPathResource("yaml/yaml_test_0.txt");
    String absolutePath = cpr.getFile().getAbsolutePath();
    //Substitute only the file index (the last '0' in the path). A blanket replace("0", "%d")
    //would also rewrite any '0' appearing in a parent directory name and corrupt the pattern.
    int indexPos = absolutePath.lastIndexOf('0');
    String path = absolutePath.substring(0, indexPos) + "%d" + absolutePath.substring(indexPos + 1);

    InputSplit is = new NumberedFileInputSplit(path, 0, 2);

    RecordReader rr = new JacksonRecordReader(getFieldSelection(), new ObjectMapper(new YAMLFactory()));
    rr.initialize(is);

    testJacksonRecordReader(rr);
}
 
Example #20
Source File: JacksonRecordReaderTest.java    From DataVec with Apache License 2.0 5 votes vote down vote up
/**
 * Same check as the JSON variant, but the three numbered input files are XML and are parsed
 * with an {@link XmlFactory}-backed {@link ObjectMapper}.
 */
@Test
public void testReadingXml() throws Exception {
    //Exact same information as JSON format, but in XML format

    ClassPathResource cpr = new ClassPathResource("xml/xml_test_0.txt");
    String absolutePath = cpr.getFile().getAbsolutePath();
    //Substitute only the file index (the last '0' in the path). A blanket replace("0", "%d")
    //would also rewrite any '0' appearing in a parent directory name and corrupt the pattern.
    int indexPos = absolutePath.lastIndexOf('0');
    String path = absolutePath.substring(0, indexPos) + "%d" + absolutePath.substring(indexPos + 1);

    InputSplit is = new NumberedFileInputSplit(path, 0, 2);

    RecordReader rr = new JacksonRecordReader(getFieldSelection(), new ObjectMapper(new XmlFactory()));
    rr.initialize(is);

    testJacksonRecordReader(rr);
}
 
Example #21
Source File: RecordReaderMultiDataSetIteratorTest.java    From deeplearning4j with Apache License 2.0 4 votes vote down vote up
/**
 * Compares a {@link RecordReaderMultiDataSetIterator} that splits the feature columns into two
 * inputs against a plain {@link SequenceRecordReaderDataSetIterator} over the same files.
 */
@Test
public void testSplittingCSVSequence() throws Exception {
    //Idea: take CSV sequences, and split "csvsequence_i.txt" into two separate inputs; keep "csvsequencelabels_i.txt"
    // as standard one-hot output
    //need to manually extract
    File rootDir = temporaryFolder.newFolder();
    for (int fileIdx = 0; fileIdx < 3; fileIdx++) {
        new ClassPathResource(String.format("csvsequence_%d.txt", fileIdx)).getTempFileFromArchive(rootDir);
        new ClassPathResource(String.format("csvsequencelabels_%d.txt", fileIdx)).getTempFileFromArchive(rootDir);
        new ClassPathResource(String.format("csvsequencelabelsShort_%d.txt", fileIdx)).getTempFileFromArchive(rootDir);
    }

    String featurePattern = FilenameUtils.concat(rootDir.getAbsolutePath(), "csvsequence_%d.txt");
    String labelPattern = FilenameUtils.concat(rootDir.getAbsolutePath(), "csvsequencelabels_%d.txt");

    //Reference iterator: all feature columns in a single array
    SequenceRecordReader refFeatures = new CSVSequenceRecordReader(1, ",");
    SequenceRecordReader refLabels = new CSVSequenceRecordReader(1, ",");
    refFeatures.initialize(new NumberedFileInputSplit(featurePattern, 0, 2));
    refLabels.initialize(new NumberedFileInputSplit(labelPattern, 0, 2));

    SequenceRecordReaderDataSetIterator reference =
                    new SequenceRecordReaderDataSetIterator(refFeatures, refLabels, 1, 4, false);

    //Iterator under test: columns 0-1 and column 2 exposed as two separate inputs
    SequenceRecordReader splitFeatures = new CSVSequenceRecordReader(1, ",");
    SequenceRecordReader splitLabels = new CSVSequenceRecordReader(1, ",");
    splitFeatures.initialize(new NumberedFileInputSplit(featurePattern, 0, 2));
    splitLabels.initialize(new NumberedFileInputSplit(labelPattern, 0, 2));

    MultiDataSetIterator splitIterator = new RecordReaderMultiDataSetIterator.Builder(1)
                    .addSequenceReader("seq1", splitFeatures).addSequenceReader("seq2", splitLabels)
                    .addInput("seq1", 0, 1).addInput("seq1", 2, 2).addOutputOneHot("seq2", 0, 4).build();

    while (reference.hasNext()) {
        DataSet expected = reference.next();
        INDArray expectedFeatures = expected.getFeatures();
        INDArray expectedLabels = expected.getLabels();

        MultiDataSet actual = splitIterator.next();
        assertEquals(2, actual.getFeatures().length);
        assertEquals(1, actual.getLabels().length);
        assertNull(actual.getFeaturesMaskArrays());
        assertNull(actual.getLabelsMaskArrays());
        INDArray[] actualFeatures = actual.getFeatures();
        INDArray[] actualLabels = actual.getLabels();

        assertNotNull(actualFeatures);
        assertNotNull(actualLabels);
        for (INDArray featureArray : actualFeatures) {
            assertNotNull(featureArray);
        }
        for (INDArray labelArray : actualLabels) {
            assertNotNull(labelArray);
        }

        //Slice the reference features the same way the builder was configured
        INDArray expectedFirstInput = expectedFeatures.get(all(), NDArrayIndex.interval(0, 1, true), all());
        INDArray expectedSecondInput = expectedFeatures.get(all(), NDArrayIndex.interval(2, 2, true), all());

        assertEquals(expectedFirstInput, actualFeatures[0]);
        assertEquals(expectedSecondInput, actualFeatures[1]);
        assertEquals(expectedLabels, actualLabels[0]);
    }
    assertFalse(splitIterator.hasNext());
}
 
Example #22
Source File: LstmTimeSeriesExample.java    From Java-Deep-Learning-Cookbook with MIT License 4 votes vote down vote up
/**
 * Trains an LSTM computation graph on numbered CSV sequence files and prints ROC statistics.
 *
 * <p>Files {@code 0.csv} .. {@code 3199.csv} under {@code FEATURE_DIR} / {@code LABEL_DIR} are
 * used for training and {@code 3200.csv} .. {@code 3999.csv} for evaluation.
 *
 * @param args unused
 * @throws IOException if a reader cannot be initialized; a {@link FileNotFoundException} is
 *         thrown when either directory constant still holds its placeholder value
 * @throws InterruptedException if reader initialization is interrupted
 */
public static void main(String[] args) throws IOException, InterruptedException {
    // Prefix match so the guard fires whether or not the placeholder carries its closing brace.
    // The original label check compared against "{PATH-TO-PHYSIONET-LABELS" (no '}'), which
    // would not match a "{PATH-TO-PHYSIONET-LABELS}" placeholder.
    if(FEATURE_DIR.startsWith("{PATH-TO-PHYSIONET-FEATURES") || LABEL_DIR.startsWith("{PATH-TO-PHYSIONET-LABELS")){
        System.out.println("Please provide proper directory path in place of: PATH-TO-PHYSIONET-FEATURES && PATH-TO-PHYSIONET-LABELS");
        throw new FileNotFoundException("FEATURE_DIR and/or LABEL_DIR still contain placeholder values");
    }

    // Training split: files 0..3199
    SequenceRecordReader trainFeaturesReader = new CSVSequenceRecordReader(1, ",");
    trainFeaturesReader.initialize(new NumberedFileInputSplit(FEATURE_DIR+"/%d.csv",0,3199));
    SequenceRecordReader trainLabelsReader = new CSVSequenceRecordReader();
    trainLabelsReader.initialize(new NumberedFileInputSplit(LABEL_DIR+"/%d.csv",0,3199));
    DataSetIterator trainDataSetIterator = new SequenceRecordReaderDataSetIterator(trainFeaturesReader,trainLabelsReader,100,2,false, SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_END);

    // Evaluation split: files 3200..3999
    SequenceRecordReader testFeaturesReader = new CSVSequenceRecordReader(1, ",");
    testFeaturesReader.initialize(new NumberedFileInputSplit(FEATURE_DIR+"/%d.csv",3200,3999));
    SequenceRecordReader testLabelsReader = new CSVSequenceRecordReader();
    testLabelsReader.initialize(new NumberedFileInputSplit(LABEL_DIR+"/%d.csv",3200,3999));
    DataSetIterator testDataSetIterator = new SequenceRecordReaderDataSetIterator(testFeaturesReader,testLabelsReader,100,2,false, SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_END);

    // Single LSTM layer (86 -> 200) feeding a 2-class softmax RNN output layer
    ComputationGraphConfiguration configuration = new NeuralNetConfiguration.Builder()
                                                    .seed(RANDOM_SEED)
                                                    .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                                                    .weightInit(WeightInit.XAVIER)
                                                    .updater(new Adam())
                                                    .dropOut(0.9)
                                                    .graphBuilder()
                                                    .addInputs("trainFeatures")
                                                    .setOutputs("predictMortality")
                                                    .addLayer("L1", new LSTM.Builder()
                                                                                    .nIn(86)
                                                                                    .nOut(200)
                                                                                    .forgetGateBiasInit(1)
                                                                                    .activation(Activation.TANH)
                                                                                    .build(),"trainFeatures")
                                                    .addLayer("predictMortality", new RnnOutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                                                                                        .activation(Activation.SOFTMAX)
                                                                                        .nIn(200).nOut(2).build(),"L1")
                                                    .build();

    ComputationGraph model = new ComputationGraph(configuration);

    // One pass over the training data; the iterator is reset after each pass
    for(int i=0;i<1;i++){
       model.fit(trainDataSetIterator);
       trainDataSetIterator.reset();
    }

    // Accumulate ROC statistics over every evaluation batch
    ROC evaluation = new ROC(100);
    while (testDataSetIterator.hasNext()) {
        DataSet batch = testDataSetIterator.next();
        INDArray[] output = model.output(batch.getFeatures());
        evaluation.evalTimeSeries(batch.getLabels(), output[0]);
    }

    System.out.println(evaluation.calculateAUC());
    System.out.println(evaluation.stats());
}
 
Example #23
Source File: NumberedFileInputSplitExample.java    From Java-Deep-Learning-Cookbook with MIT License 4 votes vote down vote up
/**
 * Prints every location produced by a {@link NumberedFileInputSplit} built from the
 * pattern {@code numberedfiles/file%d.txt} with index bounds 1 and 4.
 *
 * @param args unused
 */
public static void main(String[] args) {
    NumberedFileInputSplit split = new NumberedFileInputSplit("numberedfiles/file%d.txt", 1, 4);
    // Walk the generated locations and echo each one to standard output.
    split.locationsIterator().forEachRemaining(location -> System.out.println(location));
}
 
Example #24
Source File: LstmTimeSeriesExample.java    From Java-Deep-Learning-Cookbook with MIT License 4 votes vote down vote up
/**
 * Trains a single-LSTM-layer computation graph on numbered sequence CSV files and prints ROC
 * metrics computed on a held-out range of the same numbered files.
 *
 * <p>Files {@code 0.csv} to {@code 3199.csv} under {@code FEATURE_DIR} / {@code LABEL_DIR} feed
 * training; files {@code 3200.csv} to {@code 3999.csv} feed evaluation.
 *
 * @param args unused
 * @throws IOException a {@link FileNotFoundException} if either directory constant still holds
 *         its template placeholder; also declared for the record readers' initialization
 * @throws InterruptedException declared for the record readers' initialization
 */
public static void main(String[] args) throws IOException, InterruptedException {
    // Refuse to run while either directory constant still holds its template placeholder.
    // The label check uses startsWith so that it matches the placeholder both with and without
    // its closing brace (the original literal omitted the brace and could miss the real placeholder).
    if(FEATURE_DIR.equals("{PATH-TO-PHYSIONET-FEATURES}") || LABEL_DIR.startsWith("{PATH-TO-PHYSIONET-LABELS")){
        System.out.println("Please provide proper directory path in place of: PATH-TO-PHYSIONET-FEATURES && PATH-TO-PHYSIONET-LABELS");
        throw new FileNotFoundException("FEATURE_DIR and LABEL_DIR must be set to existing data directories");
    }

    // Training data: feature and label sequences are paired by file index.
    SequenceRecordReader trainFeaturesReader = new CSVSequenceRecordReader(1, ",");
    trainFeaturesReader.initialize(new NumberedFileInputSplit(FEATURE_DIR+"/%d.csv",0,3199));
    SequenceRecordReader trainLabelsReader = new CSVSequenceRecordReader();
    trainLabelsReader.initialize(new NumberedFileInputSplit(LABEL_DIR+"/%d.csv",0,3199));
    // Mini-batches of 100 sequences, 2 label classes, classification (regression flag false);
    // feature and label sequences are aligned at their last time step.
    DataSetIterator trainDataSetIterator = new SequenceRecordReaderDataSetIterator(trainFeaturesReader,trainLabelsReader,100,2,false, SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_END);

    // Evaluation data: same layout, next index range.
    SequenceRecordReader testFeaturesReader = new CSVSequenceRecordReader(1, ",");
    testFeaturesReader.initialize(new NumberedFileInputSplit(FEATURE_DIR+"/%d.csv",3200,3999));
    SequenceRecordReader testLabelsReader = new CSVSequenceRecordReader();
    testLabelsReader.initialize(new NumberedFileInputSplit(LABEL_DIR+"/%d.csv",3200,3999));
    DataSetIterator testDataSetIterator = new SequenceRecordReaderDataSetIterator(testFeaturesReader,testLabelsReader,100,2,false, SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_END);

    // Graph: input "trainFeatures" (86 values per step) -> LSTM "L1" (200 units)
    // -> softmax RNN output layer "predictMortality" (2 classes).
    ComputationGraphConfiguration configuration = new NeuralNetConfiguration.Builder()
            .seed(RANDOM_SEED)
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
            .weightInit(WeightInit.XAVIER)
            .updater(new Adam())
            .dropOut(0.9)
            .graphBuilder()
            .addInputs("trainFeatures")
            .setOutputs("predictMortality")
            .addLayer("L1", new LSTM.Builder()
                    .nIn(86)
                    .nOut(200)
                    .forgetGateBiasInit(1)
                    .activation(Activation.TANH)
                    .build(),"trainFeatures")
            .addLayer("predictMortality", new RnnOutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                    .activation(Activation.SOFTMAX)
                    .nIn(200).nOut(2).build(),"L1")
            .build();

    ComputationGraph model = new ComputationGraph(configuration);

    // Single training pass; raise the loop bound to train for more epochs.
    for(int i=0;i<1;i++){
       model.fit(trainDataSetIterator);
       trainDataSetIterator.reset();
    }

    // Accumulate ROC statistics over every evaluation batch, treating each output as a time series.
    ROC evaluation = new ROC(100);
    while (testDataSetIterator.hasNext()) {
        DataSet batch = testDataSetIterator.next();
        INDArray[] output = model.output(batch.getFeatures());
        evaluation.evalTimeSeries(batch.getLabels(), output[0]);
    }

    System.out.println(evaluation.calculateAUC());
    System.out.println(evaluation.stats());
}
 
Example #25
Source File: TestDataVecDataSetFunctions.java    From deeplearning4j with Apache License 2.0 4 votes vote down vote up
/**
 * Writes paired CSV sequence files to a Hadoop sequence file via Spark, reads them back into
 * DataSets, and checks that the result matches what a local
 * {@link SequenceRecordReaderDataSetIterator} produces from the same files.
 */
@Test
public void testDataVecSequencePairDataSetFunction() throws Exception {
    JavaSparkContext sparkContext = getContext();

    File sourceDir = testDir.newFolder();
    new ClassPathResource("dl4j-spark/csvsequence/").copyDirectory(sourceDir);
    String globPath = sourceDir.getAbsolutePath() + "/*";

    // The same glob is used for both sides of the pair, so features and labels come from the same files.
    PathToKeyConverter keyConverter = new PathToKeyConverterFilename();
    JavaPairRDD<Text, BytesPairWritable> pairedFiles =
                    DataVecSparkUtil.combineFilesForSequenceFile(sparkContext, globPath, globPath, keyConverter);

    Path outputRoot = testDir.newFolder("dl4j_testSeqPairFn").toPath();
    outputRoot.toFile().deleteOnExit();
    String sequenceFilePath = outputRoot.toString() + "/out";
    new File(sequenceFilePath).deleteOnExit();
    pairedFiles.saveAsNewAPIHadoopFile(sequenceFilePath, Text.class, BytesPairWritable.class, SequenceFileOutputFormat.class);

    // Read the sequence file back and decode each byte pair into two writable sequences.
    JavaPairRDD<Text, BytesPairWritable> reloaded = sparkContext.sequenceFile(sequenceFilePath, Text.class, BytesPairWritable.class);
    JavaRDD<Tuple2<List<List<Writable>>, List<List<Writable>>>> sequencePairs = reloaded.map(
                    new PairSequenceRecordReaderBytesFunction(new CSVSequenceRecordReader(1, ","),
                                    new CSVSequenceRecordReader(1, ",")));

    // Convert each decoded pair to a DataSet and pull the results to the driver.
    List<DataSet> sparkData = sequencePairs.map(new DataVecSequencePairDataSetFunction()).collect();

    // Local reference: read the same three files with a SequenceRecordReaderDataSetIterator.
    String featuresPath = FilenameUtils.concat(sourceDir.getAbsolutePath(), "csvsequence_%d.txt");

    SequenceRecordReader featureReader = new CSVSequenceRecordReader(1, ",");
    featureReader.initialize(new NumberedFileInputSplit(featuresPath, 0, 2));
    SequenceRecordReader labelReader = new CSVSequenceRecordReader(1, ",");
    labelReader.initialize(new NumberedFileInputSplit(featuresPath, 0, 2));

    SequenceRecordReaderDataSetIterator localIter =
                    new SequenceRecordReaderDataSetIterator(featureReader, labelReader, 1, -1, true);

    List<DataSet> localData = new ArrayList<>(3);
    while (localIter.hasNext()) {
        localData.add(localIter.next());
    }

    assertEquals(3, sparkData.size());
    assertEquals(3, localData.size());

    // Shape check only at this stage: element order may differ between Spark and local.
    long[] expectedShape = {1, 3, 4}; // 1 example, 3 values, 4 time steps
    for (int idx = 0; idx < 3; idx++) {
        DataSet fromSpark = sparkData.get(idx);
        DataSet fromLocal = localData.get(idx);

        assertNull(fromSpark.getFeaturesMaskArray());
        assertNull(fromSpark.getLabelsMaskArray());

        assertArrayEquals(expectedShape, fromSpark.getFeatures().shape());
        assertArrayEquals(expectedShape, fromLocal.getFeatures().shape());
        assertArrayEquals(expectedShape, fromSpark.getLabels().shape());
        assertArrayEquals(expectedShape, fromLocal.getLabels().shape());
    }

    // Content check, order notwithstanding: every Spark DataSet must equal exactly one local DataSet.
    boolean[] matched = new boolean[3];
    for (int sparkIdx = 0; sparkIdx < 3; sparkIdx++) {
        DataSet candidate = sparkData.get(sparkIdx);
        int matchIdx = -1;
        for (int localIdx = 0; localIdx < 3; localIdx++) {
            if (!candidate.equals(localData.get(localIdx))) {
                continue;
            }
            if (matchIdx != -1) {
                fail(); // this Spark value equals two or more local values (should not happen)
            }
            matchIdx = localIdx;
            if (matched[matchIdx]) {
                fail(); // an earlier Spark value already matched this local value -> duplicates in the Spark list
            }
            matched[matchIdx] = true;
        }
    }

    int matchCount = 0;
    for (boolean hit : matched) {
        if (hit) {
            matchCount++;
        }
    }
    assertEquals(3, matchCount); // all 3, and exactly 3, pairwise matches between Spark and local
}
 
Example #26
Source File: RecordReaderMultiDataSetIteratorTest.java    From deeplearning4j with Apache License 2.0 4 votes vote down vote up
/**
 * Checks that, for both ALIGN_START and ALIGN_END, every MultiDataSet returned by a
 * {@link RecordReaderMultiDataSetIterator} can be reloaded from its example metadata and
 * equals the originally returned instance.
 */
@Test
public void testVariableLengthTSMeta() throws Exception {
    // Resources have to be extracted manually into a temporary directory.
    File rootDir = temporaryFolder.newFolder();
    String[] resourcePatterns = {"csvsequence_%d.txt", "csvsequencelabels_%d.txt", "csvsequencelabelsShort_%d.txt"};
    for (int fileIdx = 0; fileIdx < 3; fileIdx++) {
        for (String pattern : resourcePatterns) {
            new ClassPathResource(String.format(pattern, fileIdx)).getTempFileFromArchive(rootDir);
        }
    }

    String featuresPath = FilenameUtils.concat(rootDir.getAbsolutePath(), "csvsequence_%d.txt");
    String labelsPath = FilenameUtils.concat(rootDir.getAbsolutePath(), "csvsequencelabelsShort_%d.txt");

    // Readers backing the ALIGN_START iterator.
    SequenceRecordReader startFeatures = new CSVSequenceRecordReader(1, ",");
    SequenceRecordReader startLabels = new CSVSequenceRecordReader(1, ",");
    startFeatures.initialize(new NumberedFileInputSplit(featuresPath, 0, 2));
    startLabels.initialize(new NumberedFileInputSplit(labelsPath, 0, 2));

    // Readers backing the ALIGN_END iterator.
    SequenceRecordReader endFeatures = new CSVSequenceRecordReader(1, ",");
    SequenceRecordReader endLabels = new CSVSequenceRecordReader(1, ",");
    endFeatures.initialize(new NumberedFileInputSplit(featuresPath, 0, 2));
    endLabels.initialize(new NumberedFileInputSplit(labelsPath, 0, 2));

    RecordReaderMultiDataSetIterator alignStartIter = new RecordReaderMultiDataSetIterator.Builder(1)
                    .addSequenceReader("in", startFeatures)
                    .addSequenceReader("out", startLabels)
                    .addInput("in")
                    .addOutputOneHot("out", 0, 4)
                    .sequenceAlignmentMode(RecordReaderMultiDataSetIterator.AlignmentMode.ALIGN_START)
                    .build();

    RecordReaderMultiDataSetIterator alignEndIter = new RecordReaderMultiDataSetIterator.Builder(1)
                    .addSequenceReader("in", endFeatures)
                    .addSequenceReader("out", endLabels)
                    .addInput("in")
                    .addOutputOneHot("out", 0, 4)
                    .sequenceAlignmentMode(RecordReaderMultiDataSetIterator.AlignmentMode.ALIGN_END)
                    .build();

    alignStartIter.setCollectMetaData(true);
    alignEndIter.setCollectMetaData(true);

    int batches = 0;
    while (alignStartIter.hasNext()) {
        MultiDataSet fromStart = alignStartIter.next();
        MultiDataSet fromEnd = alignEndIter.next();

        // Round-trip each batch through its recorded metadata.
        MultiDataSet startReloaded =
                        alignStartIter.loadFromMetaData(fromStart.getExampleMetaData(RecordMetaData.class));
        MultiDataSet endReloaded =
                        alignEndIter.loadFromMetaData(fromEnd.getExampleMetaData(RecordMetaData.class));

        assertEquals(fromStart, startReloaded);
        assertEquals(fromEnd, endReloaded);

        batches++;
    }

    // Both iterators must be exhausted after the same three batches.
    assertFalse(alignStartIter.hasNext());
    assertFalse(alignEndIter.hasNext());
    assertEquals(3, batches);
}
 
Example #27
Source File: RecordReaderMultiDataSetIteratorTest.java    From deeplearning4j with Apache License 2.0 4 votes vote down vote up
/**
 * Compares {@link RecordReaderMultiDataSetIterator} against
 * {@link SequenceRecordReaderDataSetIterator} on feature/label sequences of different lengths,
 * for both ALIGN_START and ALIGN_END: features, labels and label masks must agree batch by batch.
 */
@Test
public void testVariableLengthTS() throws Exception {
    // Resources have to be extracted manually into a temporary directory.
    File rootDir = temporaryFolder.newFolder();
    String[] resourcePatterns = {"csvsequence_%d.txt", "csvsequencelabels_%d.txt", "csvsequencelabelsShort_%d.txt"};
    for (int fileIdx = 0; fileIdx < 3; fileIdx++) {
        for (String pattern : resourcePatterns) {
            new ClassPathResource(String.format(pattern, fileIdx)).getTempFileFromArchive(rootDir);
        }
    }

    String featuresPath = FilenameUtils.concat(rootDir.getAbsolutePath(), "csvsequence_%d.txt");
    String labelsPath = FilenameUtils.concat(rootDir.getAbsolutePath(), "csvsequencelabelsShort_%d.txt");

    // Reference iterators: SequenceRecordReaderDataSetIterator with each alignment mode.
    SequenceRecordReader refStartFeatures = new CSVSequenceRecordReader(1, ",");
    SequenceRecordReader refStartLabels = new CSVSequenceRecordReader(1, ",");
    refStartFeatures.initialize(new NumberedFileInputSplit(featuresPath, 0, 2));
    refStartLabels.initialize(new NumberedFileInputSplit(labelsPath, 0, 2));

    SequenceRecordReader refEndFeatures = new CSVSequenceRecordReader(1, ",");
    SequenceRecordReader refEndLabels = new CSVSequenceRecordReader(1, ",");
    refEndFeatures.initialize(new NumberedFileInputSplit(featuresPath, 0, 2));
    refEndLabels.initialize(new NumberedFileInputSplit(labelsPath, 0, 2));

    SequenceRecordReaderDataSetIterator refAlignStart = new SequenceRecordReaderDataSetIterator(
                    refStartFeatures, refStartLabels, 1, 4, false,
                    SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_START);

    SequenceRecordReaderDataSetIterator refAlignEnd = new SequenceRecordReaderDataSetIterator(
                    refEndFeatures, refEndLabels, 1, 4, false,
                    SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_END);

    // Iterators under test: RecordReaderMultiDataSetIterator with each alignment mode.
    SequenceRecordReader multiStartFeatures = new CSVSequenceRecordReader(1, ",");
    SequenceRecordReader multiStartLabels = new CSVSequenceRecordReader(1, ",");
    multiStartFeatures.initialize(new NumberedFileInputSplit(featuresPath, 0, 2));
    multiStartLabels.initialize(new NumberedFileInputSplit(labelsPath, 0, 2));

    SequenceRecordReader multiEndFeatures = new CSVSequenceRecordReader(1, ",");
    SequenceRecordReader multiEndLabels = new CSVSequenceRecordReader(1, ",");
    multiEndFeatures.initialize(new NumberedFileInputSplit(featuresPath, 0, 2));
    multiEndLabels.initialize(new NumberedFileInputSplit(labelsPath, 0, 2));

    RecordReaderMultiDataSetIterator multiAlignStart = new RecordReaderMultiDataSetIterator.Builder(1)
                    .addSequenceReader("in", multiStartFeatures)
                    .addSequenceReader("out", multiStartLabels)
                    .addInput("in")
                    .addOutputOneHot("out", 0, 4)
                    .sequenceAlignmentMode(RecordReaderMultiDataSetIterator.AlignmentMode.ALIGN_START)
                    .build();

    RecordReaderMultiDataSetIterator multiAlignEnd = new RecordReaderMultiDataSetIterator.Builder(1)
                    .addSequenceReader("in", multiEndFeatures)
                    .addSequenceReader("out", multiEndLabels)
                    .addInput("in")
                    .addOutputOneHot("out", 0, 4)
                    .sequenceAlignmentMode(RecordReaderMultiDataSetIterator.AlignmentMode.ALIGN_END)
                    .build();

    while (refAlignStart.hasNext()) {
        DataSet expectedStart = refAlignStart.next();
        DataSet expectedEnd = refAlignEnd.next();

        MultiDataSet actualStart = multiAlignStart.next();
        MultiDataSet actualEnd = multiAlignEnd.next();

        // One input array, one output array, one label mask per batch.
        // Feature mask arrays are not checked: features are always the longer sequence here,
        // so no mask is needed for them.
        assertEquals(1, actualStart.getFeatures().length);
        assertEquals(1, actualStart.getLabels().length);
        assertEquals(1, actualStart.getLabelsMaskArrays().length);

        assertEquals(1, actualEnd.getFeatures().length);
        assertEquals(1, actualEnd.getLabels().length);
        assertEquals(1, actualEnd.getLabelsMaskArrays().length);

        assertEquals(expectedStart.getFeatures(), actualStart.getFeatures(0));
        assertEquals(expectedStart.getLabels(), actualStart.getLabels(0));
        assertEquals(expectedStart.getLabelsMaskArray(), actualStart.getLabelsMaskArray(0));

        assertEquals(expectedEnd.getFeatures(), actualEnd.getFeatures(0));
        assertEquals(expectedEnd.getLabels(), actualEnd.getLabels(0));
        assertEquals(expectedEnd.getLabelsMaskArray(), actualEnd.getLabelsMaskArray(0));
    }

    // The iterators under test must be exhausted together with the reference iterators.
    assertFalse(multiAlignStart.hasNext());
    assertFalse(multiAlignEnd.hasNext());
}
 
Example #28
Source File: JacksonRecordReaderTest.java    From deeplearning4j with Apache License 2.0 4 votes vote down vote up
/**
 * Reads three numbered JSON files with a {@link JacksonRecordReader} that adds a generated
 * label to each record, first appended at the end and then inserted at position 0, and checks
 * the resulting records (including the placeholders for missing fields).
 */
@Test
public void testAppendingLabels() throws Exception {

    ClassPathResource jsonResources = new ClassPathResource("datavec-api/json/");
    File extractedDir = testDir.newFolder();
    jsonResources.copyDirectory(extractedDir);
    String pathPattern = new File(extractedDir, "json_test_%d.txt").getAbsolutePath();

    InputSplit is = new NumberedFileInputSplit(pathPattern, 0, 2);

    // Label appended as the last value of every record.
    RecordReader rr = new JacksonRecordReader(getFieldSelection(), new ObjectMapper(new JsonFactory()), false, -1,
                    new LabelGen());
    rr.initialize(is);

    List<List<Writable>> expectedLabelLast = Arrays.asList(
                    Arrays.<Writable>asList(new Text("aValue0"), new Text("bValue0"), new Text("cxValue0"),
                                    new IntWritable(0)),
                    Arrays.<Writable>asList(new Text("aValue1"), new Text("MISSING_B"), new Text("cxValue1"),
                                    new IntWritable(1)),
                    Arrays.<Writable>asList(new Text("aValue2"), new Text("bValue2"), new Text("MISSING_CX"),
                                    new IntWritable(2)));
    for (List<Writable> expected : expectedLabelLast) {
        assertEquals(expected, rr.next());
    }

    // Label inserted as the first value of every record.
    rr = new JacksonRecordReader(getFieldSelection(), new ObjectMapper(new JsonFactory()), false, -1,
                    new LabelGen(), 0);
    rr.initialize(is);

    List<List<Writable>> expectedLabelFirst = Arrays.asList(
                    Arrays.<Writable>asList(new IntWritable(0), new Text("aValue0"), new Text("bValue0"),
                                    new Text("cxValue0")),
                    Arrays.<Writable>asList(new IntWritable(1), new Text("aValue1"), new Text("MISSING_B"),
                                    new Text("cxValue1")),
                    Arrays.<Writable>asList(new IntWritable(2), new Text("aValue2"), new Text("bValue2"),
                                    new Text("MISSING_CX")));
    for (List<Writable> expected : expectedLabelFirst) {
        assertEquals(expected, rr.next());
    }
}
 
Example #29
Source File: JacksonRecordReaderTest.java    From DataVec with Apache License 2.0 4 votes vote down vote up
/**
 * Reads three numbered JSON files with a {@link JacksonRecordReader} that adds a generated
 * label to each record, first appended at the end and then inserted at position 0, and checks
 * the resulting records (including the placeholders for missing fields).
 */
@Test
public void testAppendingLabels() throws Exception {
    ClassPathResource cpr = new ClassPathResource("json/json_test_0.txt");
    // Build the "%d" pattern from the resource's parent directory. Replacing every "0" in the
    // absolute path would also rewrite zeros in directory names, giving a pattern with several
    // "%d" specifiers that no longer points at the json_test_N.txt files.
    String path = new java.io.File(cpr.getFile().getParentFile(), "json_test_%d.txt").getAbsolutePath();

    InputSplit is = new NumberedFileInputSplit(path, 0, 2);

    //Insert at the end:
    RecordReader rr = new JacksonRecordReader(getFieldSelection(), new ObjectMapper(new JsonFactory()), false, -1,
                    new LabelGen());
    rr.initialize(is);

    List<Writable> exp0 = Arrays.asList((Writable) new Text("aValue0"), new Text("bValue0"), new Text("cxValue0"),
                    new IntWritable(0));
    assertEquals(exp0, rr.next());

    // Second file has no b:value, so the "MISSING_B" placeholder is expected.
    List<Writable> exp1 = Arrays.asList((Writable) new Text("aValue1"), new Text("MISSING_B"), new Text("cxValue1"),
                    new IntWritable(1));
    assertEquals(exp1, rr.next());

    // Third file has no c:x:value, so the "MISSING_CX" placeholder is expected.
    List<Writable> exp2 = Arrays.asList((Writable) new Text("aValue2"), new Text("bValue2"), new Text("MISSING_CX"),
                    new IntWritable(2));
    assertEquals(exp2, rr.next());

    //Insert at position 0:
    rr = new JacksonRecordReader(getFieldSelection(), new ObjectMapper(new JsonFactory()), false, -1,
                    new LabelGen(), 0);
    rr.initialize(is);

    exp0 = Arrays.asList((Writable) new IntWritable(0), new Text("aValue0"), new Text("bValue0"),
                    new Text("cxValue0"));
    assertEquals(exp0, rr.next());

    exp1 = Arrays.asList((Writable) new IntWritable(1), new Text("aValue1"), new Text("MISSING_B"),
                    new Text("cxValue1"));
    assertEquals(exp1, rr.next());

    exp2 = Arrays.asList((Writable) new IntWritable(2), new Text("aValue2"), new Text("bValue2"),
                    new Text("MISSING_CX"));
    assertEquals(exp2, rr.next());
}
 
Example #30
Source File: RecordReaderDataSetiteratorTest.java    From deeplearning4j with Apache License 2.0 4 votes vote down vote up
/**
 * Reads three CSV sequence files through a single-reader
 * {@link SequenceRecordReaderDataSetIterator} in regression mode, where column 0 is the feature
 * and columns 1 and 2 are the labels, and checks shapes, values and reset behaviour.
 */
@Test
public void testSequenceRecordReaderMultiRegression() throws Exception {
    File rootDir = temporaryFolder.newFolder();
    // Resources have to be copied manually into the temporary directory.
    for (int fileIdx = 0; fileIdx < 3; fileIdx++) {
        String fileName = String.format("csvsequence_%d.txt", fileIdx);
        FileUtils.copyFile(Resources.asFile(fileName), new File(rootDir, fileName));
    }
    String featuresPath = FilenameUtils.concat(rootDir.getAbsolutePath(), "csvsequence_%d.txt");

    SequenceRecordReader reader = new CSVSequenceRecordReader(1, ",");
    reader.initialize(new NumberedFileInputSplit(featuresPath, 0, 2));

    SequenceRecordReaderDataSetIterator iter =
            new SequenceRecordReaderDataSetIterator(reader, 1, 2, 1, true);

    assertEquals(1, iter.inputColumns());
    assertEquals(2, iter.totalOutcomes());

    List<DataSet> batches = new ArrayList<>();
    while (iter.hasNext()) {
        batches.add(iter.next());
    }

    assertEquals(3, batches.size()); // one DataSet per file
    for (int fileIdx = 0; fileIdx < 3; fileIdx++) {
        DataSet batch = batches.get(fileIdx);
        INDArray features = batch.getFeatures();
        INDArray labels = batch.getLabels();
        assertArrayEquals(new long[] {1, 1, 4}, features.shape()); // 1 example, 1 value, 4 time steps
        assertArrayEquals(new long[] {1, 2, 4}, labels.shape());

        // Drop the example dimension and transpose so that rows are time steps.
        INDArray f2d = features.get(point(0), all(), all()).transpose();
        INDArray l2d = labels.get(point(0), all(), all()).transpose();

        // File k holds rows (b, b+1, b+2), (b+10, b+11, b+12), ... with b = 100 * k.
        double base = 100.0 * fileIdx;
        INDArray expectedFeatures = Nd4j.create(new double[]{base, base + 10, base + 20, base + 30},
                new int[]{4,1}).castTo(DataType.FLOAT);
        assertEquals(expectedFeatures, f2d);

        INDArray expectedLabels = Nd4j.create(new double[][]{
                {base + 1, base + 2},
                {base + 11, base + 12},
                {base + 21, base + 22},
                {base + 31, base + 32}}).castTo(DataType.FLOAT);
        assertEquals(expectedLabels, l2d);
    }

    // After a reset the iterator must serve the same three batches again.
    iter.reset();
    int count = 0;
    for (; iter.hasNext(); count++) {
        iter.next();
    }
    assertEquals(3, count);
}