org.datavec.image.recordreader.ImageRecordReader Java Examples

The following examples show how to use org.datavec.image.recordreader.ImageRecordReader. You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may check out the related API usage on the sidebar.
Example #1
Source File: ImageInstanceIterator.java    From wekaDeeplearning4j with GNU General Public License v3.0 7 votes vote down vote up
/**
 * Builds the dataset iterator for the images referenced by {@code data}.
 *
 * <p>The batch size is capped at the number of instances in {@code data}, the data is
 * validated, and an {@link ImagePreProcessingScaler} configured with the range (0, 1) is fitted
 * on the iterator and installed as its pre-processor.
 *
 * @param data the dataset to use
 * @param seed the seed for the random number generator (not read by this method)
 * @param batchSize the requested batch size
 * @return the iterator with the scaler attached
 * @throws Exception propagated from validation, reader creation or iterator setup
 */
@Override
public DataSetIterator getDataSetIterator(Instances data, int seed, int batchSize)
    throws Exception {

  // A batch can never hold more examples than the dataset provides
  final int effectiveBatchSize = Math.min(data.numInstances(), batchSize);
  validate(data);

  final ImageRecordReader imageReader = getImageRecordReader(data);
  // Required for supporting channels-last models (currently only EfficientNet)
  if (getChannelsLast()) {
    imageReader.setNchw_channels_first(false);
  }

  // Use an explicit label index position
  final int labelPosition = 1;
  final int classCount = data.numClasses();
  final DataSetIterator iterator =
      new RecordReaderDataSetIterator(imageReader, effectiveBatchSize, labelPosition, classCount);

  final DataNormalization normalizer = new ImagePreProcessingScaler(0, 1);
  normalizer.fit(iterator);
  iterator.setPreProcessor(normalizer);
  return iterator;
}
 
Example #2
Source File: ConvolutionLayerSetupTest.java    From deeplearning4j with Apache License 2.0 6 votes vote down vote up
/**
 * Builds the incomplete LRN configuration with a 28x28x3 convolutional input type and checks
 * that the convolution layer at configuration index 3 reports 6 inputs.
 */
@Test
public void testLRN() throws Exception {
    final List<String> classNames = new ArrayList<>(Arrays.asList("Zico", "Ziwang_Xu"));

    // Copy the test images from the classpath into a fresh temporary folder
    final File imageDir = testDir.newFolder();
    new ClassPathResource("lfwtest/").copyDirectory(imageDir);
    final String rootPath = imageDir.getAbsolutePath();

    final RecordReader imageReader = new ImageRecordReader(28, 28, 3);
    imageReader.initialize(new FileSplit(new File(rootPath)));
    // The iterator is only constructed here; it is not consumed by the assertions below
    final DataSetIterator iterator =
            new RecordReaderDataSetIterator(imageReader, 10, 1, classNames.size());
    classNames.remove("lfwtest");

    final NeuralNetConfiguration.ListBuilder listBuilder =
            (NeuralNetConfiguration.ListBuilder) incompleteLRN();
    listBuilder.setInputType(InputType.convolutional(28, 28, 3));
    final MultiLayerConfiguration configuration = listBuilder.build();

    final ConvolutionLayer convLayer = (ConvolutionLayer) configuration.getConf(3).getLayer();
    assertEquals(6, convLayer.getNIn());
}
 
Example #3
Source File: ResizeImageInstanceIterator.java    From wekaDeeplearning4j with GNU General Public License v3.0 6 votes vote down vote up
/**
 * Creates and initializes an image record reader that is given a {@link ResizeImageTransform}
 * built from the configured width and height.
 *
 * <p>The images location is passed through {@link Environment#substitute(String)}; if that
 * call fails the unsubstituted location is used.
 *
 * @param data the dataset to use
 * @return the initialized image record reader
 * @throws Exception propagated from label generation or reader initialization
 */
@Override
protected ImageRecordReader getImageRecordReader(Instances data) throws Exception {
  final Environment environment = Environment.getSystemWide();
  String imagesPath = getImagesLocation().toString();
  try {
    imagesPath = environment.substitute(imagesPath);
  } catch (Exception ignored) {
    // Substitution is best effort: keep the raw location when it fails
  }

  final ArffMetaDataLabelGenerator labels = new ArffMetaDataLabelGenerator(data, imagesPath);
  final ResizeImageTransform resize = new ResizeImageTransform(getWidth(), getHeight());
  final ImageRecordReader imageReader =
      new ImageRecordReader(getHeight(), getWidth(), getNumChannels(), labels, resize);
  imageReader.initialize(new CollectionInputSplit(labels.getPathURIs()));
  return imageReader;
}
 
Example #4
Source File: ImageInstanceIterator.java    From wekaDeeplearning4j with GNU General Public License v3.0 6 votes vote down vote up
/**
 * Creates and initializes the image record reader for the given dataset.
 *
 * <p>The images location is passed through {@link Environment#substitute(String)}; if that
 * call fails the unsubstituted location is used.
 *
 * @param data the dataset to use
 * @return the initialized image record reader
 * @throws Exception propagated from label generation or reader initialization
 */
protected ImageRecordReader getImageRecordReader(Instances data) throws Exception {
  final Environment environment = Environment.getSystemWide();
  String imagesPath = getImagesLocation().toString();
  try {
    imagesPath = environment.substitute(getImagesLocation().toString());
  } catch (Exception ignored) {
    // Substitution is best effort: keep the raw location when it fails
  }

  final ArffMetaDataLabelGenerator labels = new ArffMetaDataLabelGenerator(data, imagesPath);
  final ImageRecordReader imageReader =
      new ImageRecordReader(getHeight(), getWidth(), getNumChannels(), labels);
  imageReader.initialize(new CollectionInputSplit(labels.getPathURIs()));
  return imageReader;
}
 
Example #5
Source File: ImageInstanceIteratorTest.java    From wekaDeeplearning4j with GNU General Public License v3.0 6 votes vote down vote up
/**
 * Checks that the reader returned by {@code getImageRecordReader} yields, record by record, the
 * same label as the meta data, and that exactly ten distinct labels are seen and reported.
 */
@Test
public void testGetImageRecordReader() throws Exception {
  final Instances meta = DatasetLoader.loadMiniMnistMeta();

  // Look the method up reflectively and make it callable from the test
  final Method readerFactory =
      ImageInstanceIterator.class.getDeclaredMethod("getImageRecordReader", Instances.class);
  readerFactory.setAccessible(true);
  this.idi.setTrainBatchSize(1);
  final ImageRecordReader reader = (ImageRecordReader) readerFactory.invoke(this.idi, meta);

  final Set<String> seenLabels = new HashSet<>();
  for (Instance instance : meta) {
    final String expected = instance.stringValue(1);
    final String actual = reader.next().get(1).toString();
    Assert.assertEquals(expected, actual);
    seenLabels.add(expected);
  }

  // The label sets must match in both directions
  Assert.assertEquals(10, seenLabels.size());
  Assert.assertTrue(seenLabels.containsAll(reader.getLabels()));
  Assert.assertTrue(reader.getLabels().containsAll(seenLabels));
}
 
Example #6
Source File: ImageInstanceIteratorTest.java    From wekaDeeplearning4j with GNU General Public License v3.0 6 votes vote down vote up
/**
 * Checks that the reader returned by {@code getImageRecordReader} yields, record by record, the
 * same label as the meta data, and that exactly ten distinct labels are seen and reported.
 */
@Test
public void testGetImageRecordReader() throws Exception {
  final Instances meta = DatasetLoader.loadMiniMnistMeta();

  // Look the method up reflectively and make it callable from the test
  final Method readerFactory =
      ImageInstanceIterator.class.getDeclaredMethod("getImageRecordReader", Instances.class);
  readerFactory.setAccessible(true);
  this.idi.setTrainBatchSize(1);
  final ImageRecordReader reader = (ImageRecordReader) readerFactory.invoke(this.idi, meta);

  final Set<String> seenLabels = new HashSet<>();
  for (Instance instance : meta) {
    final String expected = instance.stringValue(1);
    final String actual = reader.next().get(1).toString();
    Assert.assertEquals(expected, actual);
    seenLabels.add(expected);
  }

  // The label sets must match in both directions
  Assert.assertEquals(10, seenLabels.size());
  Assert.assertTrue(seenLabels.containsAll(reader.getLabels()));
  Assert.assertTrue(reader.getLabels().containsAll(seenLabels));
}
 
Example #7
Source File: ImageInstanceIterator.java    From wekaDeeplearning4j with GNU General Public License v3.0 6 votes vote down vote up
/**
 * Builds the dataset iterator for the images referenced by {@code data}.
 *
 * <p>The batch size is capped at the number of instances in {@code data}, the data is
 * validated, and an {@link ImagePreProcessingScaler} configured with the range (0, 1) is fitted
 * on the iterator and installed as its pre-processor.
 *
 * @param data the dataset to use
 * @param seed the seed for the random number generator (not read by this method)
 * @param batchSize the requested batch size
 * @return the iterator with the scaler attached
 * @throws Exception propagated from validation, reader creation or iterator setup
 */
@Override
public DataSetIterator getDataSetIterator(Instances data, int seed, int batchSize)
    throws Exception {

  // A batch can never hold more examples than the dataset provides
  final int effectiveBatchSize = Math.min(data.numInstances(), batchSize);
  validate(data);

  final ImageRecordReader imageReader = getImageRecordReader(data);
  // Required for supporting channels-last models (currently only EfficientNet)
  if (getChannelsLast()) {
    imageReader.setNchw_channels_first(false);
  }

  // Use an explicit label index position
  final int labelPosition = 1;
  final int classCount = data.numClasses();
  final DataSetIterator iterator =
      new RecordReaderDataSetIterator(imageReader, effectiveBatchSize, labelPosition, classCount);

  final DataNormalization normalizer = new ImagePreProcessingScaler(0, 1);
  normalizer.fit(iterator);
  iterator.setPreProcessor(normalizer);
  return iterator;
}
 
Example #8
Source File: ImageInstanceIterator.java    From wekaDeeplearning4j with GNU General Public License v3.0 6 votes vote down vote up
/**
 * Creates and initializes the image record reader for the given dataset.
 *
 * <p>The images location is passed through {@link Environment#substitute(String)}; if that
 * call fails the unsubstituted location is used.
 *
 * @param data the dataset to use
 * @return the initialized image record reader
 * @throws Exception propagated from label generation or reader initialization
 */
protected ImageRecordReader getImageRecordReader(Instances data) throws Exception {
  final Environment environment = Environment.getSystemWide();
  String imagesPath = getImagesLocation().toString();
  try {
    imagesPath = environment.substitute(getImagesLocation().toString());
  } catch (Exception ignored) {
    // Substitution is best effort: keep the raw location when it fails
  }

  final ArffMetaDataLabelGenerator labels = new ArffMetaDataLabelGenerator(data, imagesPath);
  final ImageRecordReader imageReader =
      new ImageRecordReader(getHeight(), getWidth(), getNumChannels(), labels);
  imageReader.initialize(new CollectionInputSplit(labels.getPathURIs()));
  return imageReader;
}
 
Example #9
Source File: ResizeImageInstanceIterator.java    From wekaDeeplearning4j with GNU General Public License v3.0 6 votes vote down vote up
/**
 * Creates and initializes an image record reader that is given a {@link ResizeImageTransform}
 * built from the configured width and height.
 *
 * <p>The images location is passed through {@link Environment#substitute(String)}; if that
 * call fails the unsubstituted location is used.
 *
 * @param data the dataset to use
 * @return the initialized image record reader
 * @throws Exception propagated from label generation or reader initialization
 */
@Override
protected ImageRecordReader getImageRecordReader(Instances data) throws Exception {
  final Environment environment = Environment.getSystemWide();
  String imagesPath = getImagesLocation().toString();
  try {
    imagesPath = environment.substitute(imagesPath);
  } catch (Exception ignored) {
    // Substitution is best effort: keep the raw location when it fails
  }

  final ArffMetaDataLabelGenerator labels = new ArffMetaDataLabelGenerator(data, imagesPath);
  final ResizeImageTransform resize = new ResizeImageTransform(getWidth(), getHeight());
  final ImageRecordReader imageReader =
      new ImageRecordReader(getHeight(), getWidth(), getNumChannels(), labels, resize);
  imageReader.initialize(new CollectionInputSplit(labels.getPathURIs()));
  return imageReader;
}
 
Example #10
Source File: LoaderTests.java    From DataVec with Apache License 2.0 5 votes vote down vote up
/**
 * Checks that a manually configured {@link ImageRecordReader} over a sampled LFW split and the
 * reader produced by {@code LFWLoader} report the same first label.
 *
 * <p>The images are read from {@code lfw-a/lfw} below the user's home directory.
 */
@Test
public void testLfwReader() throws Exception {
    final String relativeDir = "lfw-a/lfw";
    final File lfwRoot = new File(FilenameUtils.concat(System.getProperty("user.home"), relativeDir));

    // Sample the files through a path filter, using fixed seeds for both random sources
    final FileSplit allFiles = new FileSplit(lfwRoot, LFWLoader.ALLOWED_FORMATS, new Random(42));
    final BalancedPathFilter balancedFilter =
            new BalancedPathFilter(new Random(42), LFWLoader.LABEL_PATTERN, 1, 1, 1);
    final InputSplit[] sampled = allFiles.sample(balancedFilter, 1);

    final RecordReader manualReader = new ImageRecordReader(250, 250, 3, LFWLoader.LABEL_PATTERN);
    manualReader.initialize(sampled[0]);
    final List<String> expectedLabels = manualReader.getLabels();

    final RecordReader loaderReader =
            new LFWLoader(new int[] {250, 250, 3}, true).getRecordReader(1, 1, 1, new Random(42));

    assertEquals(expectedLabels.get(0), loaderReader.getLabels().get(0));
}
 
Example #11
Source File: LabelGeneratorTest.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
/**
 * Checks that {@link ParentPathLabelGenerator} derives the labels from the parent directory
 * names, both with and without a dot in the name, and that the reader returns every copied
 * image exactly once.
 */
@Test
public void testParentPathLabelGenerator() throws Exception {
    //https://github.com/deeplearning4j/DataVec/issues/273
    final File sourceImage = new ClassPathResource("datavec-data-image/testimages/class0/0.jpg").getFile();
    final int dirCount = 3;
    final int imagesPerDir = 4;

    for (String prefix : new String[]{"m.", "m"}) {
        final File root = testDir.newFolder();

        // Create <prefix>0 .. <prefix>2, each holding four copies of the source image
        for (int dirIdx = 0; dirIdx < dirCount; dirIdx++) {
            final File labelDir = new File(root, prefix + dirIdx);
            labelDir.mkdirs();
            for (int imgIdx = 0; imgIdx < imagesPerDir; imgIdx++) {
                final File copy = new File(labelDir, "myImg_" + imgIdx + ".jpg");
                FileUtils.copyFile(sourceImage, copy);
                assertTrue(copy.exists());
            }
        }

        final ImageRecordReader reader = new ImageRecordReader(28, 28, 1, new ParentPathLabelGenerator());
        reader.initialize(new FileSplit(root));

        final List<String> actualLabels = reader.getLabels();
        final List<String> expectedLabels = Arrays.asList(prefix + "0", prefix + "1", prefix + "2");
        assertEquals(expectedLabels, actualLabels);

        // Drain the reader and count the records
        final int expectedRecords = dirCount * imagesPerDir;
        int recordCount = 0;
        while (reader.hasNext()) {
            reader.next();
            recordCount++;
        }
        assertEquals(expectedRecords, recordCount);
    }
}
 
Example #12
Source File: LFWLoader.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
/**
 * Loads the data and returns an image record reader over the training or the test split.
 *
 * <p>If initializing the reader fails, the error is logged and the (uninitialized) reader is
 * still returned. When the failure is an {@link InterruptedException} the thread's interrupt
 * status is restored so callers can still observe the interruption.
 *
 * @param batchSize forwarded to {@code load}
 * @param numExamples forwarded to {@code load}
 * @param imgDim the first three entries are passed to the {@link ImageRecordReader}
 *        constructor (presumably height, width, channels - confirm against the caller)
 * @param numLabels forwarded to {@code load}
 * @param labelGenerator label generator used for loading and by the reader
 * @param train {@code true} selects {@code inputSplit[0]}, {@code false} selects
 *        {@code inputSplit[1]}
 * @param splitTrainTest forwarded to {@code load}
 * @param rng forwarded to {@code load}
 * @return the record reader
 */
public RecordReader getRecordReader(long batchSize, long numExamples, long[] imgDim, long numLabels,
                PathLabelGenerator labelGenerator, boolean train, double splitTrainTest, Random rng) {
    load(batchSize, numExamples, numLabels, labelGenerator, splitTrainTest, rng);
    RecordReader recordReader =
                    new ImageRecordReader(imgDim[0], imgDim[1], imgDim[2], labelGenerator, imageTransform);

    try {
        InputSplit data = train ? inputSplit[0] : inputSplit[1];
        recordReader.initialize(data);
    } catch (IOException e) {
        log.error("",e);
    } catch (InterruptedException e) {
        // Catching InterruptedException clears the flag; set it again before continuing
        Thread.currentThread().interrupt();
        log.error("",e);
    }
    return recordReader;
}
 
Example #13
Source File: LFWLoader.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
/**
 * Loads the data and returns an image record reader over the training or the test split.
 *
 * <p>If initializing the reader fails, the error is logged and the (uninitialized) reader is
 * still returned. When the failure is an {@link InterruptedException} the thread's interrupt
 * status is restored so callers can still observe the interruption.
 *
 * @param batchSize forwarded to {@code load}
 * @param numExamples forwarded to {@code load}
 * @param imgDim the first three entries are passed to the {@link ImageRecordReader}
 *        constructor (presumably height, width, channels - confirm against the caller)
 * @param numLabels forwarded to {@code load}
 * @param labelGenerator label generator used for loading and by the reader
 * @param train {@code true} selects {@code inputSplit[0]}, {@code false} selects
 *        {@code inputSplit[1]}
 * @param splitTrainTest forwarded to {@code load}
 * @param rng forwarded to {@code load}
 * @return the record reader
 */
public RecordReader getRecordReader(long batchSize, long numExamples, int[] imgDim, long numLabels,
                PathLabelGenerator labelGenerator, boolean train, double splitTrainTest, Random rng) {
    load(batchSize, numExamples, numLabels, labelGenerator, splitTrainTest, rng);
    RecordReader recordReader =
                    new ImageRecordReader(imgDim[0], imgDim[1], imgDim[2], labelGenerator, imageTransform);

    try {
        InputSplit data = train ? inputSplit[0] : inputSplit[1];
        recordReader.initialize(data);
    } catch (IOException e) {
        log.error("",e);
    } catch (InterruptedException e) {
        // Catching InterruptedException clears the flag; set it again before continuing
        Thread.currentThread().interrupt();
        log.error("",e);
    }
    return recordReader;
}
 
Example #14
Source File: LabelGeneratorTest.java    From DataVec with Apache License 2.0 5 votes vote down vote up
/**
 * Checks that {@link ParentPathLabelGenerator} derives the labels from the parent directory
 * names, both with and without a dot in the name, and that the reader returns every copied
 * image exactly once.
 */
@Test
public void testParentPathLabelGenerator() throws Exception {
    //https://github.com/deeplearning4j/DataVec/issues/273
    final File sourceImage = new ClassPathResource("testimages/class0/0.jpg").getFile();
    final int dirCount = 3;
    final int imagesPerDir = 4;

    for (String prefix : new String[]{"m.", "m"}) {
        final File root = testDir.newFolder();

        // Create <prefix>0 .. <prefix>2, each holding four copies of the source image
        for (int dirIdx = 0; dirIdx < dirCount; dirIdx++) {
            final File labelDir = new File(root, prefix + dirIdx);
            labelDir.mkdirs();
            for (int imgIdx = 0; imgIdx < imagesPerDir; imgIdx++) {
                final File copy = new File(labelDir, "myImg_" + imgIdx + ".jpg");
                FileUtils.copyFile(sourceImage, copy);
                assertTrue(copy.exists());
            }
        }

        final ImageRecordReader reader = new ImageRecordReader(28, 28, 1, new ParentPathLabelGenerator());
        reader.initialize(new FileSplit(root));

        final List<String> actualLabels = reader.getLabels();
        final List<String> expectedLabels = Arrays.asList(prefix + "0", prefix + "1", prefix + "2");
        assertEquals(expectedLabels, actualLabels);

        // Drain the reader and count the records
        final int expectedRecords = dirCount * imagesPerDir;
        int recordCount = 0;
        while (reader.hasNext()) {
            reader.next();
            recordCount++;
        }
        assertEquals(expectedRecords, recordCount);
    }
}
 
Example #15
Source File: LFWLoader.java    From DataVec with Apache License 2.0 5 votes vote down vote up
/**
 * Loads the data and returns an image record reader over the training or the test split.
 *
 * <p>If initializing the reader fails, the stack trace is printed and the (uninitialized)
 * reader is still returned. When the failure is an {@link InterruptedException} the thread's
 * interrupt status is restored so callers can still observe the interruption.
 *
 * @param batchSize forwarded to {@code load}
 * @param numExamples forwarded to {@code load}
 * @param imgDim the first three entries are passed to the {@link ImageRecordReader}
 *        constructor (presumably height, width, channels - confirm against the caller)
 * @param numLabels forwarded to {@code load}
 * @param labelGenerator label generator used for loading and by the reader
 * @param train {@code true} selects {@code inputSplit[0]}, {@code false} selects
 *        {@code inputSplit[1]}
 * @param splitTrainTest forwarded to {@code load}
 * @param rng forwarded to {@code load}
 * @return the record reader
 */
public RecordReader getRecordReader(int batchSize, int numExamples, int[] imgDim, int numLabels,
                PathLabelGenerator labelGenerator, boolean train, double splitTrainTest, Random rng) {
    load(batchSize, numExamples, numLabels, labelGenerator, splitTrainTest, rng);
    RecordReader recordReader =
                    new ImageRecordReader(imgDim[0], imgDim[1], imgDim[2], labelGenerator, imageTransform);

    try {
        InputSplit data = train ? inputSplit[0] : inputSplit[1];
        recordReader.initialize(data);
    } catch (IOException e) {
        e.printStackTrace();
    } catch (InterruptedException e) {
        // Catching InterruptedException clears the flag; set it again before continuing
        Thread.currentThread().interrupt();
        e.printStackTrace();
    }
    return recordReader;
}
 
Example #16
Source File: DataStorage.java    From Java-Machine-Learning-for-Computer-Vision with MIT License 5 votes vote down vote up
/**
 * Creates a dataset iterator over the given split, with a {@link VGG16ImagePreProcessor}
 * installed as pre-processor.
 *
 * @param sample the input split the image reader is initialized with
 * @return the configured iterator
 * @throws IOException if the reader cannot be initialized
 */
default DataSetIterator getDataSetIterator(InputSplit sample) throws IOException {
    final ImageRecordReader reader = new ImageRecordReader(HEIGHT, WIDTH, CHANNELS, LABEL_GENERATOR_MAKER);
    reader.initialize(sample);

    // Label index 1, with NUM_POSSIBLE_LABELS classes
    final DataSetIterator result = new RecordReaderDataSetIterator(reader, BATCH_SIZE, 1, NUM_POSSIBLE_LABELS);
    result.setPreProcessor(new VGG16ImagePreProcessor());
    return result;
}
 
Example #17
Source File: ImageUtils.java    From Java-Machine-Learning-for-Computer-Vision with MIT License 5 votes vote down vote up
/**
 * Creates a dataset iterator over the images found under {@code sample}, with a
 * {@link CifarImagePreProcessor} installed as pre-processor.
 *
 * @param sample the file or directory the image reader is initialized with
 * @param numLabels the number of possible labels
 * @param batchSize the batch size of the iterator
 * @return the configured iterator
 * @throws IOException if the reader cannot be initialized
 */
public static DataSetIterator createDataSetIterator(File sample,int numLabels,int batchSize) throws IOException {
    final ImageRecordReader reader = new ImageRecordReader(HEIGHT, WIDTH, CHANNELS, LABEL_GENERATOR_MAKER);
    reader.initialize(new FileSplit(sample));

    // Label index 1
    final DataSetIterator result = new RecordReaderDataSetIterator(reader, batchSize, 1, numLabels);
    result.setPreProcessor(new CifarImagePreProcessor());
    return result;
}
 
Example #18
Source File: RecordReaderDataSetiteratorTest.java    From deeplearning4j with Apache License 2.0 4 votes vote down vote up
/**
 * Checks that a {@link RecordReaderDataSetIterator} over two labelled images yields features
 * of shape [2, 3, 28, 28] and labels of shape [2, 2], both when created through the
 * constructor and when created through the builder.
 */
@Test
public void testImagesRRDSI() throws Exception {
    final File root = temporaryFolder.newFolder();
    root.deleteOnExit();
    final String zicoPath = FilenameUtils.concat(root.getAbsolutePath(), "Zico/");
    final String ziwangPath = FilenameUtils.concat(root.getAbsolutePath(), "Ziwang_Xu/");

    // One directory per label
    final File ziwangDir = new File(ziwangPath);
    final File zicoDir = new File(zicoPath);
    zicoDir.mkdirs();
    ziwangDir.mkdirs();

    // Copy one image into each label directory
    final File zicoImage = new File(FilenameUtils.concat(zicoDir.getPath(), "Zico_0001.jpg"));
    TestUtils.writeStreamToFile(zicoImage,
            new ClassPathResource("lfwtest/Zico/Zico_0001.jpg").getInputStream());
    final File ziwangImage = new File(FilenameUtils.concat(ziwangDir.getPath(), "Ziwang_Xu_0001.jpg"));
    TestUtils.writeStreamToFile(ziwangImage,
            new ClassPathResource("lfwtest/Ziwang_Xu/Ziwang_Xu_0001.jpg").getInputStream());

    // Not referenced below; kept as in the original test
    final Random r = new Random(12345);
    final ParentPathLabelGenerator labelGenerator = new ParentPathLabelGenerator();

    final ImageRecordReader reader = new ImageRecordReader(28, 28, 3, labelGenerator);
    reader.initialize(new FileSplit(root));

    final long[] expectedFeatureShape = {2, 3, 28, 28};
    final long[] expectedLabelShape = {2, 2};

    RecordReaderDataSetIterator iterator = new RecordReaderDataSetIterator(reader, 2);
    DataSet batch = iterator.next();
    assertArrayEquals(expectedFeatureShape, batch.getFeatures().shape());
    assertArrayEquals(expectedLabelShape, batch.getLabels().shape());

    //Check the same thing via the builder:
    reader.reset();
    iterator = new RecordReaderDataSetIterator.Builder(reader, 2)
            .classification(1, 2)
            .build();

    batch = iterator.next();
    assertArrayEquals(expectedFeatureShape, batch.getFeatures().shape());
    assertArrayEquals(expectedLabelShape, batch.getLabels().shape());
}
 
Example #19
Source File: ModelUtils.java    From gluon-samples with BSD 3-Clause "New" or "Revised" License 4 votes vote down vote up
/**
 * Evaluates the given model on the images below {@code DATA_PATH + "/mnist_png/testing"} and
 * logs the reader's label order and the evaluation statistics.
 *
 * @param model the trained network to evaluate
 * @param invertColors if {@code true} the scaler is created with the range (1, 0) instead of
 *        (0, 1)
 * @throws IOException if the record reader cannot be initialized
 */
public void evaluateModel(MultiLayerNetwork model, boolean invertColors) throws IOException {
    LOGGER.info("******EVALUATE MODEL******");

    final ParentPathLabelGenerator labelGenerator = new ParentPathLabelGenerator();
    final ImageRecordReader reader = new ImageRecordReader(height, width, channels, labelGenerator);

    // The model is evaluated against the test images, not the ones used in this method's
    // caller for training
    final File testDirectory = new File(DATA_PATH + "/mnist_png/testing");
    final FileSplit testSplit = new FileSplit(testDirectory, NativeImageLoader.ALLOWED_FORMATS, randNumGen);
    reader.initialize(testSplit);

    // Swap the bounds of the scaler when colors are to be inverted
    final int scalerMin;
    final int scalerMax;
    if (invertColors) {
        scalerMin = 1;
        scalerMax = 0;
    } else {
        scalerMin = 0;
        scalerMax = 1;
    }
    final DataNormalization normalizer = new ImagePreProcessingScaler(scalerMin, scalerMax);
    final DataSetIterator testIterator = new RecordReaderDataSetIterator(reader, batchSize, 1, outputNum);
    normalizer.fit(testIterator);
    testIterator.setPreProcessor(normalizer);

    // Log the order of the labels as reported by the record reader (demonstration only)
    LOGGER.info(reader.getLabels().toString());

    // Evaluation object sized by the number of output classes
    final Evaluation evaluation = new Evaluation(outputNum);

    // Feed every test batch through the network and compare its output with the labels
    while (testIterator.hasNext()) {
        final DataSet batch = testIterator.next();
        final INDArray predictions = model.output(batch.getFeatureMatrix());
        evaluation.eval(batch.getLabels(), predictions);
    }

    LOGGER.info(evaluation.stats());
}
 
Example #20
Source File: RecordReaderMultiDataSetIteratorTest.java    From deeplearning4j with Apache License 2.0 4 votes vote down vote up
/**
 * Checks that a {@link RecordReaderMultiDataSetIterator} fed by two image readers (10x10 and
 * 5x5) yields the same features and labels as two separate
 * {@link RecordReaderDataSetIterator}s over the same images, one example at a time.
 */
@Test
public void testImagesRRDMSI() throws Exception {
    final File root = temporaryFolder.newFolder();
    root.deleteOnExit();
    final String zicoPath = FilenameUtils.concat(root.getAbsolutePath(), "Zico/");
    final String ziwangPath = FilenameUtils.concat(root.getAbsolutePath(), "Ziwang_Xu/");

    // One directory per label
    final File zicoDir = new File(zicoPath);
    final File ziwangDir = new File(ziwangPath);
    zicoDir.mkdirs();
    ziwangDir.mkdirs();

    // Copy one image into each label directory
    final File zicoImage = new File(FilenameUtils.concat(zicoDir.getPath(), "Zico_0001.jpg"));
    TestUtils.writeStreamToFile(zicoImage,
                    new ClassPathResource("lfwtest/Zico/Zico_0001.jpg").getInputStream());
    final File ziwangImage = new File(FilenameUtils.concat(ziwangDir.getPath(), "Ziwang_Xu_0001.jpg"));
    TestUtils.writeStreamToFile(ziwangImage,
                    new ClassPathResource("lfwtest/Ziwang_Xu/Ziwang_Xu_0001.jpg").getInputStream());

    final int classCount = 2;
    // Not referenced below; kept as in the original test
    final Random r = new Random(12345);
    final ParentPathLabelGenerator labelGenerator = new ParentPathLabelGenerator();

    final ImageRecordReader largeReader = new ImageRecordReader(10, 10, 1, labelGenerator);
    final ImageRecordReader smallReader = new ImageRecordReader(5, 5, 1, labelGenerator);
    largeReader.initialize(new FileSplit(root));
    smallReader.initialize(new FileSplit(root));

    // Two image inputs, one-hot output taken from the small reader
    final MultiDataSetIterator multiIterator = new RecordReaderMultiDataSetIterator.Builder(1)
                    .addReader("rr1", largeReader)
                    .addReader("rr1s", smallReader)
                    .addInput("rr1", 0, 0)
                    .addInput("rr1s", 0, 0)
                    .addOutputOneHot("rr1s", 1, classCount)
                    .build();

    //Now, do the same thing with ImageRecordReader, and check we get the same results:
    final ImageRecordReader largeReference = new ImageRecordReader(10, 10, 1, labelGenerator);
    final ImageRecordReader smallReference = new ImageRecordReader(5, 5, 1, labelGenerator);
    largeReference.initialize(new FileSplit(root));
    smallReference.initialize(new FileSplit(root));

    final DataSetIterator largeIterator = new RecordReaderDataSetIterator(largeReference, 1, 1, 2);
    final DataSetIterator smallIterator = new RecordReaderDataSetIterator(smallReference, 1, 1, 2);

    for (int example = 0; example < 2; example++) {
        final MultiDataSet combined = multiIterator.next();

        final DataSet large = largeIterator.next();
        final DataSet small = smallIterator.next();

        assertEquals(large.getFeatures(), combined.getFeatures(0));
        assertEquals(small.getFeatures(), combined.getFeatures(1));
        assertEquals(large.getLabels(), combined.getLabels(0));
    }
}
 
Example #21
Source File: RecordReaderMultiDataSetIteratorTest.java    From deeplearning4j with Apache License 2.0 4 votes vote down vote up
/**
 * Batched variant of the multi-dataset image test: with a batch size of 2 the
 * {@link RecordReaderMultiDataSetIterator} must yield the same features and labels as two
 * separate {@link RecordReaderDataSetIterator}s, and the labels must match the expected
 * one-hot assignment.
 */
@Test
public void testImagesRRDMSI_Batched() throws Exception {
    final File root = temporaryFolder.newFolder();
    root.deleteOnExit();
    final String zicoPath = FilenameUtils.concat(root.getAbsolutePath(), "Zico/");
    final String ziwangPath = FilenameUtils.concat(root.getAbsolutePath(), "Ziwang_Xu/");

    // One directory per label
    final File zicoDir = new File(zicoPath);
    final File ziwangDir = new File(ziwangPath);
    zicoDir.mkdirs();
    ziwangDir.mkdirs();

    // Copy one image into each label directory
    final File zicoImage = new File(FilenameUtils.concat(zicoDir.getPath(), "Zico_0001.jpg"));
    TestUtils.writeStreamToFile(zicoImage,
                    new ClassPathResource("lfwtest/Zico/Zico_0001.jpg").getInputStream());
    final File ziwangImage = new File(FilenameUtils.concat(ziwangDir.getPath(), "Ziwang_Xu_0001.jpg"));
    TestUtils.writeStreamToFile(ziwangImage,
                    new ClassPathResource("lfwtest/Ziwang_Xu/Ziwang_Xu_0001.jpg").getInputStream());

    final int classCount = 2;
    final ParentPathLabelGenerator labelGenerator = new ParentPathLabelGenerator();

    final ImageRecordReader largeReader = new ImageRecordReader(10, 10, 1, labelGenerator);
    final ImageRecordReader smallReader = new ImageRecordReader(5, 5, 1, labelGenerator);

    // Both readers are initialized from the same array of locations
    final URI[] locations = new FileSplit(root).locations();
    largeReader.initialize(new CollectionInputSplit(locations));
    smallReader.initialize(new CollectionInputSplit(locations));

    final MultiDataSetIterator multiIterator = new RecordReaderMultiDataSetIterator.Builder(2)
                    .addReader("rr1", largeReader)
                    .addReader("rr1s", smallReader)
                    .addInput("rr1", 0, 0)
                    .addInput("rr1s", 0, 0)
                    .addOutputOneHot("rr1s", 1, classCount)
                    .build();

    //Now, do the same thing with ImageRecordReader, and check we get the same results:
    final ImageRecordReader largeReference = new ImageRecordReader(10, 10, 1, labelGenerator);
    final ImageRecordReader smallReference = new ImageRecordReader(5, 5, 1, labelGenerator);
    largeReference.initialize(new FileSplit(root));
    smallReference.initialize(new FileSplit(root));

    final DataSetIterator largeIterator = new RecordReaderDataSetIterator(largeReference, 2, 1, 2);
    final DataSetIterator smallIterator = new RecordReaderDataSetIterator(smallReference, 2, 1, 2);

    final MultiDataSet combined = multiIterator.next();

    final DataSet large = largeIterator.next();
    final DataSet small = smallIterator.next();

    assertEquals(large.getFeatures(), combined.getFeatures(0));
    assertEquals(small.getFeatures(), combined.getFeatures(1));
    assertEquals(large.getLabels(), combined.getLabels(0));

    //Check label assignment:
    // The expected row order depends on the file the reference reader currently reports
    final File currentFile = largeReference.getCurrentFile();
    final boolean currentIsZico = currentFile.getAbsolutePath().contains("Zico");
    final double[][] expectedRows = currentIsZico
                    ? new double[][] {{0, 1}, {1, 0}}
                    : new double[][] {{1, 0}, {0, 1}};
    final INDArray expectedLabels = Nd4j.create(expectedRows);

    assertEquals(expectedLabels, large.getLabels());
    assertEquals(expectedLabels, small.getLabels());
}
 
Example #22
Source File: ImageInputFormat.java    From deeplearning4j with Apache License 2.0 4 votes vote down vote up
/**
 * Creates an {@link ImageRecordReader} and initializes it with the given configuration and
 * input split.
 *
 * @param split the input split to read from
 * @param conf the configuration passed to the reader
 * @return the initialized reader
 * @throws IOException propagated from reader initialization
 * @throws InterruptedException propagated from reader initialization
 */
@Override
public RecordReader createReader(InputSplit split, Configuration conf) throws IOException, InterruptedException {
    final RecordReader imageReader = new ImageRecordReader();
    imageReader.initialize(conf, split);
    return imageReader;
}
 
Example #23
Source File: ImageClassifierAPI.java    From Java-Deep-Learning-Cookbook with MIT License 4 votes vote down vote up
/**
 * Creates a 30x30 image record reader with 3 channels and initializes it with a
 * {@link FileSplit} over the given file.
 *
 * @param file the file or directory to read from
 * @return the initialized reader
 * @throws IOException propagated from reader initialization
 * @throws InterruptedException propagated from reader initialization
 */
private static RecordReader generateReader(File file) throws IOException, InterruptedException {
    final RecordReader imageReader = new ImageRecordReader(30, 30, 3);
    imageReader.initialize(new FileSplit(file));
    return imageReader;
}
 
Example #24
Source File: ImageInputFormat.java    From DataVec with Apache License 2.0 4 votes vote down vote up
/**
 * Creates an {@link ImageRecordReader} and initializes it with the given configuration and
 * input split.
 *
 * @param split the input split to read from
 * @param conf the configuration passed to the reader
 * @return the initialized reader
 * @throws IOException propagated from reader initialization
 * @throws InterruptedException propagated from reader initialization
 */
@Override
public RecordReader createReader(InputSplit split, Configuration conf) throws IOException, InterruptedException {
    final RecordReader imageReader = new ImageRecordReader();
    imageReader.initialize(conf, split);
    return imageReader;
}
 
Example #25
Source File: ImageClassifierAPI.java    From Java-Deep-Learning-Cookbook with MIT License 4 votes vote down vote up
/**
 * Creates a 30x30 image record reader with 3 channels and initializes it with a
 * {@link FileSplit} over the given file.
 *
 * @param file the file or directory to read from
 * @return the initialized reader
 * @throws IOException propagated from reader initialization
 * @throws InterruptedException propagated from reader initialization
 */
private static RecordReader generateReader(File file) throws IOException, InterruptedException {
    final RecordReader imageReader = new ImageRecordReader(30, 30, 3);
    imageReader.initialize(new FileSplit(file));
    return imageReader;
}