Java Code Examples for org.datavec.api.io.labels.ParentPathLabelGenerator

The following examples show how to use org.datavec.api.io.labels.ParentPathLabelGenerator. They are extracted from open source projects; where known, each example is preceded by a line naming its source project, source file, and license.
Example 1
Source Project: DataVec   Source File: TestImageRecordReader.java   License: Apache License 2.0
@Test
public void testMetaData() throws IOException {
    // The 0.jpg resource sits two directory levels below the test image root;
    // walk up to that root so every class sub-directory is read.
    File imageRoot = new ClassPathResource("/testimages/class0/0.jpg").getFile().getParentFile().getParentFile();

    ImageRecordReader reader = new ImageRecordReader(32, 32, 3, new ParentPathLabelGenerator());
    reader.initialize(new FileSplit(imageRoot));

    // First pass: plain next(); every record carries exactly two writables
    List<List<Writable>> firstPass = new ArrayList<>();
    while (reader.hasNext()) {
        List<Writable> record = reader.next();
        firstPass.add(record);
        assertEquals(2, record.size());
    }
    assertEquals(6, firstPass.size());

    // Second pass: nextRecord(), additionally capturing the metadata of each record
    reader.reset();
    List<List<Writable>> secondPass = new ArrayList<>();
    List<Record> records = new ArrayList<>();
    List<RecordMetaData> metaData = new ArrayList<>();
    while (reader.hasNext()) {
        Record rec = reader.nextRecord();
        secondPass.add(rec.getRecord());
        records.add(rec);
        metaData.add(rec.getMetaData());
    }
    assertEquals(firstPass, secondPass);

    // Reloading from the collected metadata must reproduce the very same records
    assertEquals(records, reader.loadFromMetaData(metaData));
}
 
Example 2
Source Project: DataVec   Source File: TestImageRecordReader.java   License: Apache License 2.0
@Test
public void testImageRecordReaderLabelsOrder() throws Exception {
    // Two readers receive the same two files in opposite orders; the label list
    // each of them builds must nevertheless be identical.
    File class0Image = new ClassPathResource("/testimages/class0/0.jpg").getFile();
    File class1Image = new ClassPathResource("/testimages/class1/A.jpg").getFile();

    List<URI> forwardOrder = Arrays.asList(class0Image.toURI(), class1Image.toURI());
    List<URI> reversedOrder = Arrays.asList(class1Image.toURI(), class0Image.toURI());

    ImageRecordReader forwardReader = new ImageRecordReader(32, 32, 3, new ParentPathLabelGenerator());
    forwardReader.initialize(new CollectionInputSplit(forwardOrder));

    ImageRecordReader reversedReader = new ImageRecordReader(32, 32, 3, new ParentPathLabelGenerator());
    reversedReader.initialize(new CollectionInputSplit(reversedOrder));

    assertEquals(forwardReader.getLabels(), reversedReader.getLabels());
}
 
Example 3
Source Project: DataVec   Source File: TestImageRecordReader.java   License: Apache License 2.0
@Test
public void testListenerInvocationBatch() throws IOException {
    ImageRecordReader reader = new ImageRecordReader(32, 32, 3, new ParentPathLabelGenerator());
    File classDir = new ClassPathResource("/testimages/class0").getFile();
    int expectedRecords = classDir.list().length;
    reader.initialize(new FileSplit(classDir));

    CountingListener listener = new CountingListener(new LogRecordListener());
    reader.setListeners(listener);

    // Request one more record than exists: the listener count must equal the
    // number of records actually read, not the number requested.
    reader.next(expectedRecords + 1);
    assertEquals(expectedRecords, listener.getCount());
}
 
Example 4
Source Project: DataVec   Source File: TestImageRecordReader.java   License: Apache License 2.0
@Test
public void testListenerInvocationSingle() throws IOException {
    ImageRecordReader reader = new ImageRecordReader(32, 32, 3, new ParentPathLabelGenerator());
    File classDir = new ClassPathResource("/testimages/class0").getFile();
    int expectedRecords = classDir.list().length;
    reader.initialize(new FileSplit(classDir));

    CountingListener listener = new CountingListener(new LogRecordListener());
    reader.setListeners(listener);

    // Drain the reader one record at a time; the listener should see each one
    while (reader.hasNext()) {
        reader.next();
    }
    assertEquals(expectedRecords, listener.getCount());
}
 
Example 5
Source Project: DataVec   Source File: LabelGeneratorTest.java   License: Apache License 2.0
@Test
public void testParentPathLabelGenerator() throws Exception {
    //https://github.com/deeplearning4j/DataVec/issues/273
    File sourceImage = new ClassPathResource("testimages/class0/0.jpg").getFile();
    final int numDirs = 3;
    final int filesPerDir = 4;

    // Directory names with and without a '.' must both survive intact as labels
    for (String prefix : new String[]{"m.", "m"}) {
        File root = testDir.newFolder();

        List<String> expectedLabels = new ArrayList<>();
        for (int d = 0; d < numDirs; d++) {
            String labelName = prefix + d;
            expectedLabels.add(labelName);

            File labelDir = new File(root, labelName);
            labelDir.mkdirs();
            for (int k = 0; k < filesPerDir; k++) {
                File copy = new File(labelDir, "myImg_" + k + ".jpg");
                FileUtils.copyFile(sourceImage, copy);
                assertTrue(copy.exists());
            }
        }

        ImageRecordReader reader = new ImageRecordReader(28, 28, 1, new ParentPathLabelGenerator());
        reader.initialize(new FileSplit(root));
        assertEquals(expectedLabels, reader.getLabels());

        int recordsRead = 0;
        for (; reader.hasNext(); recordsRead++) {
            reader.next();
        }
        assertEquals(numDirs * filesPerDir, recordsRead);
    }
}
 
Example 6
Source Project: deeplearning4j   Source File: TestImageRecordReader.java   License: Apache License 2.0
@Test
public void testMetaData() throws IOException {
    // Work on a private copy of the bundled test image tree
    File imageRoot = testDir.newFolder();
    new ClassPathResource("datavec-data-image/testimages/").copyDirectory(imageRoot);

    ImageRecordReader reader = new ImageRecordReader(32, 32, 3, new ParentPathLabelGenerator());
    reader.initialize(new FileSplit(imageRoot));

    // First pass: plain next(); every record carries exactly two writables
    List<List<Writable>> firstPass = new ArrayList<>();
    while (reader.hasNext()) {
        List<Writable> record = reader.next();
        firstPass.add(record);
        assertEquals(2, record.size());
    }
    assertEquals(6, firstPass.size());

    // Second pass: nextRecord(), additionally capturing the metadata of each record
    reader.reset();
    List<List<Writable>> secondPass = new ArrayList<>();
    List<Record> records = new ArrayList<>();
    List<RecordMetaData> metaData = new ArrayList<>();
    while (reader.hasNext()) {
        Record rec = reader.nextRecord();
        secondPass.add(rec.getRecord());
        records.add(rec);
        metaData.add(rec.getMetaData());
    }
    assertEquals(firstPass, secondPass);

    // Reloading from the collected metadata must reproduce the very same records
    assertEquals(records, reader.loadFromMetaData(metaData));
}
 
Example 7
Source Project: deeplearning4j   Source File: TestImageRecordReader.java   License: Apache License 2.0
@Test
public void testImageRecordReaderLabelsOrder() throws Exception {
    // Labels order should be consistent, regardless of the order in which the
    // input files are supplied to the reader.
    File f = testDir.newFolder();
    new ClassPathResource("datavec-data-image/testimages/").copyDirectory(f);

    // Child paths are given as relative paths: java.io.File converts an absolute
    // child ("/class0/...") into a relative one in a system-dependent way.
    File f0 = new File(f, "class0/0.jpg");
    File f1 = new File(f, "class1/A.jpg");

    List<URI> order0 = Arrays.asList(f0.toURI(), f1.toURI());
    List<URI> order1 = Arrays.asList(f1.toURI(), f0.toURI());

    ParentPathLabelGenerator labelMaker0 = new ParentPathLabelGenerator();
    ImageRecordReader rr0 = new ImageRecordReader(32, 32, 3, labelMaker0);
    rr0.initialize(new CollectionInputSplit(order0));

    ParentPathLabelGenerator labelMaker1 = new ParentPathLabelGenerator();
    ImageRecordReader rr1 = new ImageRecordReader(32, 32, 3, labelMaker1);
    rr1.initialize(new CollectionInputSplit(order1));

    List<String> labels0 = rr0.getLabels();
    List<String> labels1 = rr1.getLabels();

    assertEquals(labels0, labels1);
}
 
Example 8
Source Project: deeplearning4j   Source File: TestImageRecordReader.java   License: Apache License 2.0
@Test
public void testListenerInvocationBatch() throws IOException {
    ImageRecordReader reader = new ImageRecordReader(32, 32, 3, new ParentPathLabelGenerator());
    File imageRoot = testDir.newFolder();
    new ClassPathResource("datavec-data-image/testimages/").copyDirectory(imageRoot);

    // Total number of images in the copied test image tree
    int totalImages = 6;
    reader.initialize(new FileSplit(imageRoot));

    CountingListener listener = new CountingListener(new LogRecordListener());
    reader.setListeners(listener);

    // Request more records than exist; only the ones actually read are counted
    reader.next(totalImages + 1);
    assertEquals(totalImages, listener.getCount());
}
 
Example 9
Source Project: deeplearning4j   Source File: TestImageRecordReader.java   License: Apache License 2.0
@Test
public void testListenerInvocationSingle() throws IOException {
    ImageRecordReader reader = new ImageRecordReader(32, 32, 3, new ParentPathLabelGenerator());
    File classDir = testDir.newFolder();
    new ClassPathResource("datavec-data-image/testimages/class0/").copyDirectory(classDir);
    int expectedRecords = classDir.list().length;
    reader.initialize(new FileSplit(classDir));

    CountingListener listener = new CountingListener(new LogRecordListener());
    reader.setListeners(listener);

    // Drain the reader one record at a time; the listener should see each one
    while (reader.hasNext()) {
        reader.next();
    }
    assertEquals(expectedRecords, listener.getCount());
}
 
Example 10
Source Project: deeplearning4j   Source File: LabelGeneratorTest.java   License: Apache License 2.0
@Test
public void testParentPathLabelGenerator() throws Exception {
    //https://github.com/deeplearning4j/DataVec/issues/273
    File sourceImage = new ClassPathResource("datavec-data-image/testimages/class0/0.jpg").getFile();
    final int numDirs = 3;
    final int filesPerDir = 4;

    // Directory names with and without a '.' must both survive intact as labels
    for (String prefix : new String[]{"m.", "m"}) {
        File root = testDir.newFolder();

        List<String> expectedLabels = new ArrayList<>();
        for (int d = 0; d < numDirs; d++) {
            String labelName = prefix + d;
            expectedLabels.add(labelName);

            File labelDir = new File(root, labelName);
            labelDir.mkdirs();
            for (int k = 0; k < filesPerDir; k++) {
                File copy = new File(labelDir, "myImg_" + k + ".jpg");
                FileUtils.copyFile(sourceImage, copy);
                assertTrue(copy.exists());
            }
        }

        ImageRecordReader reader = new ImageRecordReader(28, 28, 1, new ParentPathLabelGenerator());
        reader.initialize(new FileSplit(root));
        assertEquals(expectedLabels, reader.getLabels());

        int recordsRead = 0;
        for (; reader.hasNext(); recordsRead++) {
            reader.next();
        }
        assertEquals(numDirs * filesPerDir, recordsRead);
    }
}
 
Example 11
Source Project: DataVec   Source File: TestImageRecordReader.java   License: Apache License 2.0
@Test
public void testImageRecordReaderRandomization() throws Exception {
    // With a Random-seeded FileSplit, reset() should change the iteration order,
    // and a different seed should change the initial order.
    File imageRoot = new ClassPathResource("/testimages/").getFile();

    FileSplit seededSplit = new FileSplit(imageRoot, new Random(12345));
    ImageRecordReader reader = new ImageRecordReader(32, 32, 3, new ParentPathLabelGenerator());
    reader.initialize(seededSplit);

    List<List<Writable>> firstRecords = new ArrayList<>();
    List<File> firstOrder = new ArrayList<>();
    while (reader.hasNext()) {
        firstRecords.add(reader.next());
        firstOrder.add(reader.getCurrentFile());
    }
    assertEquals(6, firstRecords.size());
    assertEquals(6, firstOrder.size());

    // Same reader after reset(): same content count, different order
    reader.reset();
    List<List<Writable>> secondRecords = new ArrayList<>();
    List<File> secondOrder = new ArrayList<>();
    while (reader.hasNext()) {
        secondRecords.add(reader.next());
        secondOrder.add(reader.getCurrentFile());
    }
    assertEquals(6, secondRecords.size());
    assertEquals(6, secondOrder.size());

    assertNotEquals(firstRecords, secondRecords);
    assertNotEquals(firstOrder, secondOrder);

    // A fresh reader with a different seed must start in a different order
    FileSplit otherSeedSplit = new FileSplit(imageRoot, new Random(999999999));
    ImageRecordReader otherReader = new ImageRecordReader(32, 32, 3, new ParentPathLabelGenerator());
    otherReader.initialize(otherSeedSplit);

    List<File> otherOrder = new ArrayList<>();
    while (otherReader.hasNext()) {
        otherReader.next();
        otherOrder.add(otherReader.getCurrentFile());
    }
    assertEquals(6, otherOrder.size());

    assertNotEquals(firstOrder, otherOrder);
}
 
Example 12
public void evaluateModel(MultiLayerNetwork model, boolean invertColors) throws IOException {
    LOGGER.info("******EVALUATE MODEL******");

    // Labels are derived from each image's parent directory name
    ParentPathLabelGenerator labelMaker = new ParentPathLabelGenerator();
    ImageRecordReader recordReader = new ImageRecordReader(height, width, channels, labelMaker);

    // Held-out test images: the model was trained on the training split, and it is
    // now evaluated against images it has not seen.
    File testData = new File(DATA_PATH + "/mnist_png/testing");
    FileSplit test = new FileSplit(testData, NativeImageLoader.ALLOWED_FORMATS, randNumGen);
    recordReader.initialize(test);

    // NOTE(review): ImagePreProcessingScaler(min, max) presumably maps pixels to [0, 1],
    // or to the flipped range [1, 0] when invertColors is set — confirm against DL4J docs.
    DataNormalization scaler = new ImagePreProcessingScaler(invertColors ? 1 : 0, invertColors ? 0 : 1);
    DataSetIterator testIter = new RecordReaderDataSetIterator(recordReader, batchSize, 1, outputNum);
    scaler.fit(testIter);
    // Depending on the normalizer, fit() may traverse the iterator; rewind it so the
    // evaluation loop below always starts from the first batch.
    testIter.reset();
    testIter.setPreProcessor(scaler);

    // Log the label order for later use. Older versions produced a consistent but
    // random label order; current versions order labels lexicographically, so
    // preserving the RecordReader label order is no longer needed. It is kept here
    // for demonstration purposes.
    LOGGER.info(recordReader.getLabels().toString());

    // One evaluation slot per output class
    Evaluation eval = new Evaluation(outputNum);

    while (testIter.hasNext()) {
        DataSet next = testIter.next();
        // getFeatures() replaces the deprecated getFeatureMatrix()
        INDArray output = model.output(next.getFeatures());
        // Compare the network's predictions with the labels from the RecordReader
        eval.eval(next.getLabels(), output);
    }

    LOGGER.info(eval.stats());
}
 
Example 13
Source Project: deeplearning4j   Source File: TestImageRecordReader.java   License: Apache License 2.0
@Test
public void testImageRecordReaderRandomization() throws Exception {
    // With a Random-seeded FileSplit, reset() should change the iteration order,
    // and a different seed should change the initial order.
    File imageRoot = testDir.newFolder();
    new ClassPathResource("datavec-data-image/testimages/").copyDirectory(imageRoot);

    FileSplit seededSplit = new FileSplit(imageRoot, new Random(12345));
    ImageRecordReader reader = new ImageRecordReader(32, 32, 3, new ParentPathLabelGenerator());
    reader.initialize(seededSplit);

    List<List<Writable>> firstRecords = new ArrayList<>();
    List<File> firstOrder = new ArrayList<>();
    while (reader.hasNext()) {
        firstRecords.add(reader.next());
        firstOrder.add(reader.getCurrentFile());
    }
    assertEquals(6, firstRecords.size());
    assertEquals(6, firstOrder.size());

    // Same reader after reset(): same content count, different order
    reader.reset();
    List<List<Writable>> secondRecords = new ArrayList<>();
    List<File> secondOrder = new ArrayList<>();
    while (reader.hasNext()) {
        secondRecords.add(reader.next());
        secondOrder.add(reader.getCurrentFile());
    }
    assertEquals(6, secondRecords.size());
    assertEquals(6, secondOrder.size());

    assertNotEquals(firstRecords, secondRecords);
    assertNotEquals(firstOrder, secondOrder);

    // A fresh reader with a different seed must start in a different order
    FileSplit otherSeedSplit = new FileSplit(imageRoot, new Random(999999999));
    ImageRecordReader otherReader = new ImageRecordReader(32, 32, 3, new ParentPathLabelGenerator());
    otherReader.initialize(otherSeedSplit);

    List<File> otherOrder = new ArrayList<>();
    while (otherReader.hasNext()) {
        otherReader.next();
        otherOrder.add(otherReader.getCurrentFile());
    }
    assertEquals(6, otherOrder.size());

    assertNotEquals(firstOrder, otherOrder);
}
 
Example 14
Source Project: deeplearning4j   Source File: FileBatchRecordReaderTest.java   License: Apache License 2.0
@Test
public void testCsv() throws Exception {
    File extractedSourceDir = testDir.newFolder();
    new ClassPathResource("datavec-data-image/testimages").copyDirectory(extractedSourceDir);
    File baseDir = testDir.newFolder();

    List<File> c = new ArrayList<>(FileUtils.listFiles(extractedSourceDir, null, true));
    assertEquals(6, c.size());

    // Sort by path so the batch content order (and thus the expected record order
    // below) does not depend on file-system listing order.
    Collections.sort(c, new Comparator<File>() {
        @Override
        public int compare(File o1, File o2) {
            return o1.getPath().compareTo(o2.getPath());
        }
    });

    // Round-trip the batch through a zip file before reading it back
    FileBatch fb = FileBatch.forFiles(c);
    File saveFile = new File(baseDir, "saved.zip");
    fb.writeAsZip(saveFile);
    fb = FileBatch.readFromZip(saveFile);

    PathLabelGenerator labelMaker = new ParentPathLabelGenerator();
    ImageRecordReader rr = new ImageRecordReader(32, 32, 1, labelMaker);
    rr.setLabels(Arrays.asList("class0", "class1"));
    FileBatchRecordReader fbrr = new FileBatchRecordReader(rr, fb);

    // Expected record order, matching the path-sorted file list; the first three
    // files are in class0 (label 0), the remaining three in class1 (label 1).
    String[] expectedFiles = {
            "class0/0.jpg", "class0/1.png", "class0/2.jpg",
            "class1/A.jpg", "class1/B.png", "class1/C.jpg"};

    NativeImageLoader il = new NativeImageLoader(32, 32, 1);
    // Several passes verify that reset() restores the full sequence
    for (int test = 0; test < 3; test++) {
        for (int i = 0; i < expectedFiles.length; i++) {
            assertTrue(fbrr.hasNext());
            List<Writable> next = fbrr.next();
            assertEquals(2, next.size());

            INDArray exp = il.asMatrix(new File(extractedSourceDir, expectedFiles[i]));
            Writable expLabel = new IntWritable(i < 3 ? 0 : 1);

            // JUnit's assertEquals takes (expected, actual)
            assertEquals(exp, ((NDArrayWritable) next.get(0)).get());
            assertEquals(expLabel, next.get(1));
        }
        assertFalse(fbrr.hasNext());
        assertTrue(fbrr.resetSupported());
        fbrr.reset();
    }
}
 
Example 15
@Test
public void testImagesRRDSI() throws Exception {
    File parentDir = temporaryFolder.newFolder();
    parentDir.deleteOnExit();
    // One sub-directory per person; ParentPathLabelGenerator uses the directory name as the label
    String str1 = FilenameUtils.concat(parentDir.getAbsolutePath(), "Zico/");
    String str2 = FilenameUtils.concat(parentDir.getAbsolutePath(), "Ziwang_Xu/");

    File f2 = new File(str2);
    File f1 = new File(str1);
    f1.mkdirs();
    f2.mkdirs();

    TestUtils.writeStreamToFile(new File(FilenameUtils.concat(f1.getPath(), "Zico_0001.jpg")),
            new ClassPathResource("lfwtest/Zico/Zico_0001.jpg").getInputStream());
    TestUtils.writeStreamToFile(new File(FilenameUtils.concat(f2.getPath(), "Ziwang_Xu_0001.jpg")),
            new ClassPathResource("lfwtest/Ziwang_Xu/Ziwang_Xu_0001.jpg").getInputStream());

    ParentPathLabelGenerator labelMaker = new ParentPathLabelGenerator();

    ImageRecordReader rr1 = new ImageRecordReader(28, 28, 3, labelMaker);
    rr1.initialize(new FileSplit(parentDir));

    // Two images in total, so one minibatch of 2 covers everything
    RecordReaderDataSetIterator rrdsi = new RecordReaderDataSetIterator(rr1, 2);
    DataSet ds = rrdsi.next();
    // Features: 2 examples x 3 channels x 28 x 28; labels: 2 examples x 2 classes
    assertArrayEquals(new long[]{2, 3, 28, 28}, ds.getFeatures().shape());
    assertArrayEquals(new long[]{2, 2}, ds.getLabels().shape());

    //Check the same thing via the builder:
    rr1.reset();
    rrdsi = new RecordReaderDataSetIterator.Builder(rr1, 2)
            .classification(1, 2)
            .build();

    ds = rrdsi.next();
    assertArrayEquals(new long[]{2, 3, 28, 28}, ds.getFeatures().shape());
    assertArrayEquals(new long[]{2, 2}, ds.getLabels().shape());
}
 
Example 16
@Test
public void testImagesRRDMSI() throws Exception {
    File parentDir = temporaryFolder.newFolder();
    parentDir.deleteOnExit();
    // One sub-directory per person; ParentPathLabelGenerator uses the directory name as the label
    String str1 = FilenameUtils.concat(parentDir.getAbsolutePath(), "Zico/");
    String str2 = FilenameUtils.concat(parentDir.getAbsolutePath(), "Ziwang_Xu/");

    File f1 = new File(str1);
    File f2 = new File(str2);
    f1.mkdirs();
    f2.mkdirs();

    TestUtils.writeStreamToFile(new File(FilenameUtils.concat(f1.getPath(), "Zico_0001.jpg")),
                    new ClassPathResource("lfwtest/Zico/Zico_0001.jpg").getInputStream());
    TestUtils.writeStreamToFile(new File(FilenameUtils.concat(f2.getPath(), "Ziwang_Xu_0001.jpg")),
                    new ClassPathResource("lfwtest/Ziwang_Xu/Ziwang_Xu_0001.jpg").getInputStream());

    int outputNum = 2;
    ParentPathLabelGenerator labelMaker = new ParentPathLabelGenerator();

    // Two readers over the same files at different resolutions (10x10 and 5x5, one channel)
    ImageRecordReader rr1 = new ImageRecordReader(10, 10, 1, labelMaker);
    ImageRecordReader rr1s = new ImageRecordReader(5, 5, 1, labelMaker);

    rr1.initialize(new FileSplit(parentDir));
    rr1s.initialize(new FileSplit(parentDir));

    // Inputs: column 0 (image) of each reader; output: column 1 of rr1s, one-hot encoded
    MultiDataSetIterator trainDataIterator = new RecordReaderMultiDataSetIterator.Builder(1).addReader("rr1", rr1)
                    .addReader("rr1s", rr1s).addInput("rr1", 0, 0).addInput("rr1s", 0, 0)
                    .addOutputOneHot("rr1s", 1, outputNum).build();

    //Now, do the same thing with ImageRecordReader, and check we get the same results:
    ImageRecordReader rr1_b = new ImageRecordReader(10, 10, 1, labelMaker);
    ImageRecordReader rr1s_b = new ImageRecordReader(5, 5, 1, labelMaker);
    rr1_b.initialize(new FileSplit(parentDir));
    rr1s_b.initialize(new FileSplit(parentDir));

    DataSetIterator dsi1 = new RecordReaderDataSetIterator(rr1_b, 1, 1, 2);
    DataSetIterator dsi2 = new RecordReaderDataSetIterator(rr1s_b, 1, 1, 2);

    // One example per batch, two images in total
    for (int i = 0; i < 2; i++) {
        MultiDataSet mds = trainDataIterator.next();

        DataSet d1 = dsi1.next();
        DataSet d2 = dsi2.next();

        assertEquals(d1.getFeatures(), mds.getFeatures(0));
        assertEquals(d2.getFeatures(), mds.getFeatures(1));
        assertEquals(d1.getLabels(), mds.getLabels(0));
    }
}
 
Example 17
@Test
public void testImagesRRDMSI_Batched() throws Exception {
    File root = temporaryFolder.newFolder();
    root.deleteOnExit();

    // One sub-directory per person, holding a single image each
    File zicoDir = new File(FilenameUtils.concat(root.getAbsolutePath(), "Zico/"));
    File xuDir = new File(FilenameUtils.concat(root.getAbsolutePath(), "Ziwang_Xu/"));
    zicoDir.mkdirs();
    xuDir.mkdirs();

    TestUtils.writeStreamToFile(new File(FilenameUtils.concat(zicoDir.getPath(), "Zico_0001.jpg")),
                    new ClassPathResource("lfwtest/Zico/Zico_0001.jpg").getInputStream());
    TestUtils.writeStreamToFile(new File(FilenameUtils.concat(xuDir.getPath(), "Ziwang_Xu_0001.jpg")),
                    new ClassPathResource("lfwtest/Ziwang_Xu/Ziwang_Xu_0001.jpg").getInputStream());

    int outputNum = 2;
    ParentPathLabelGenerator labelMaker = new ParentPathLabelGenerator();

    // Multi-input iterator: both resolutions fed from the same explicit URI list
    ImageRecordReader bigReader = new ImageRecordReader(10, 10, 1, labelMaker);
    ImageRecordReader smallReader = new ImageRecordReader(5, 5, 1, labelMaker);
    URI[] locations = new FileSplit(root).locations();
    bigReader.initialize(new CollectionInputSplit(locations));
    smallReader.initialize(new CollectionInputSplit(locations));

    MultiDataSetIterator multiIter = new RecordReaderMultiDataSetIterator.Builder(2).addReader("rr1", bigReader)
                    .addReader("rr1s", smallReader).addInput("rr1", 0, 0).addInput("rr1s", 0, 0)
                    .addOutputOneHot("rr1s", 1, outputNum).build();

    // Reference: the same data through two single-input iterators
    ImageRecordReader bigRef = new ImageRecordReader(10, 10, 1, labelMaker);
    ImageRecordReader smallRef = new ImageRecordReader(5, 5, 1, labelMaker);
    bigRef.initialize(new FileSplit(root));
    smallRef.initialize(new FileSplit(root));

    DataSetIterator bigIter = new RecordReaderDataSetIterator(bigRef, 2, 1, 2);
    DataSetIterator smallIter = new RecordReaderDataSetIterator(smallRef, 2, 1, 2);

    MultiDataSet multi = multiIter.next();
    DataSet bigBatch = bigIter.next();
    DataSet smallBatch = smallIter.next();

    assertEquals(bigBatch.getFeatures(), multi.getFeatures(0));
    assertEquals(smallBatch.getFeatures(), multi.getFeatures(1));
    assertEquals(bigBatch.getLabels(), multi.getLabels(0));

    // Check label assignment: assuming lexicographic label order (Zico = 0, Ziwang_Xu = 1),
    // a Zico image read last means the batch order was (Ziwang_Xu, Zico).
    File lastFile = bigRef.getCurrentFile();
    INDArray expectedLabels = lastFile.getAbsolutePath().contains("Zico")
            ? Nd4j.create(new double[][] {{0, 1}, {1, 0}})
            : Nd4j.create(new double[][] {{1, 0}, {0, 1}});

    assertEquals(expectedLabels, bigBatch.getLabels());
    assertEquals(expectedLabels, smallBatch.getLabels());
}
 
Example 18
Source Project: deeplearning4j   Source File: LFWDataSetIterator.java   License: Apache License 2.0
/**
 * Loads a subset of images (LFWLoader.SUB_NUM_IMAGES / SUB_NUM_LABELS) with the given
 * image dimensions, labelling each image by its parent directory.
 *
 * @param imgDim image dimensions, passed through to the delegated constructor
 */
public LFWDataSetIterator(int[] imgDim) {
    // new Random() picks a seed very likely distinct per invocation, whereas seeding with
    // System.currentTimeMillis() repeats for iterators created within the same millisecond.
    this(LFWLoader.SUB_NUM_IMAGES, LFWLoader.SUB_NUM_IMAGES, imgDim, LFWLoader.SUB_NUM_LABELS, false,
                    new ParentPathLabelGenerator(), true, 1, null, new Random());
}