Java Code Examples for org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator

The following examples show how to use org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator. These examples are extracted from open source projects. You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may check out the related API usage on the sidebar.
Example 1
Source Project: FederatedAndroidTrainer   Source File: IrisFileDataSource.java    License: MIT License 6 votes vote down vote up
/**
 * Reads the iris CSV from {@code dataFile}, splits the first batch 80/20 into
 * {@code trainingData} and {@code testData}, and standardizes both parts using statistics
 * fitted on the training part only.
 *
 * @throws IOException declared by the signature; propagated from reading the data
 * @throws InterruptedException declared by the signature; propagated from reader initialization
 */
private void createDataSource() throws IOException, InterruptedException {
    // CSV parsing: skip no leading lines, fields are comma separated.
    final RecordReader csvReader = new CSVRecordReader(0, ",");
    csvReader.initialize(new InputStreamInputSplit(dataFile));

    // Each iris row has 4 input features followed by an integer class index (0, 1 or 2),
    // so the label sits in column 4 and there are 3 possible classes.
    final int labelColumn = 4;
    final int classCount = 3;
    final DataSetIterator batches =
            new RecordReaderDataSetIterator(csvReader, batchSize, labelColumn, classCount);

    // Only the first batch is consumed; it is shuffled before being split.
    final DataSet firstBatch = batches.next();
    firstBatch.shuffle();

    // 80% of the rows go to training, the remainder to testing.
    final SplitTestAndTrain split = firstBatch.splitTestAndTrain(0.80);
    trainingData = split.getTrain();
    testData = split.getTest();

    // Standardize (mean 0, unit variance). Statistics come from the training rows only and
    // are then applied to both the training and the test rows.
    final DataNormalization standardizer = new NormalizerStandardize();
    standardizer.fit(trainingData);
    standardizer.transform(trainingData);
    standardizer.transform(testData);
}
 
Example 2
Source Project: FederatedAndroidTrainer   Source File: DiabetesFileDataSource.java    License: MIT License 6 votes vote down vote up
/**
 * Reads the CSV from {@code dataFile}, splits the first batch 80/20 into
 * {@code trainingData} and {@code testData}, and standardizes both parts using statistics
 * fitted on the training part only.
 *
 * @throws IOException declared by the signature; propagated from reading the data
 * @throws InterruptedException declared by the signature; propagated from reader initialization
 */
private void createDataSource() throws IOException, InterruptedException {
    // CSV parsing: skip no leading lines, fields are comma separated.
    final RecordReader csvReader = new CSVRecordReader(0, ",");
    csvReader.initialize(new InputStreamInputSplit(dataFile));

    // Column 11 is used as both the first and the last label column, i.e. a single target.
    // NOTE(review): the trailing `true` presumably selects regression mode - confirm against
    // the RecordReaderDataSetIterator constructor documentation.
    final int targetColumn = 11;
    final DataSetIterator batches =
            new RecordReaderDataSetIterator(csvReader, batchSize, targetColumn, targetColumn, true);

    // Only the first batch is consumed (it is not shuffled here).
    final DataSet firstBatch = batches.next();

    // 80% of the rows go to training, the remainder to testing.
    final SplitTestAndTrain split = firstBatch.splitTestAndTrain(0.80);
    trainingData = split.getTrain();
    testData = split.getTest();

    // Standardize (mean 0, unit variance). Statistics come from the training rows only and
    // are then applied to both the training and the test rows.
    final DataNormalization standardizer = new NormalizerStandardize();
    standardizer.fit(trainingData);
    standardizer.transform(trainingData);
    standardizer.transform(testData);
}
 
Example 3
/**
 * Builds the batched image iterator for the given instances. An
 * {@code ImagePreProcessingScaler(0, 1)} is attached as pre-processor, so all intensity values
 * are scaled (divided by 255).
 *
 * @param data the dataset to use
 * @param seed the seed for the random number generator (not read by this method)
 * @param batchSize the requested batch size; capped at the number of instances in {@code data}
 * @return the iterator with the scaler installed as pre-processor
 * @throws Exception declared by the signature; propagated from validation or reader creation
 */
@Override
public DataSetIterator getDataSetIterator(Instances data, int seed, int batchSize)
    throws Exception {

  // Never ask for more rows per batch than the dataset contains.
  batchSize = Math.min(data.numInstances(), batchSize);
  validate(data);

  final ImageRecordReader imageReader = getImageRecordReader(data);
  if (getChannelsLast()) {
    // Required for supporting channels-last models (currently only EfficientNet).
    imageReader.setNchw_channels_first(false);
  }

  // The label position is given explicitly; the class count comes from the dataset.
  final int labelColumn = 1;
  final int classCount = data.numClasses();
  final DataSetIterator iterator =
      new RecordReaderDataSetIterator(imageReader, batchSize, labelColumn, classCount);

  final DataNormalization intensityScaler = new ImagePreProcessingScaler(0, 1);
  intensityScaler.fit(iterator);
  iterator.setPreProcessor(intensityScaler);
  return iterator;
}
 
Example 4
/**
 * Creates a batched iterator over the image instances and installs an
 * {@code ImagePreProcessingScaler(0, 1)} on it, which scales all intensity values (divides
 * them by 255).
 *
 * @param data the dataset to use
 * @param seed the seed for the random number generator (not read by this method)
 * @param batchSize the requested batch size; reduced to the instance count when larger
 * @return the iterator, with the scaler set as its pre-processor
 * @throws Exception declared by the signature; propagated from validation or reader creation
 */
@Override
public DataSetIterator getDataSetIterator(Instances data, int seed, int batchSize)
    throws Exception {

  // A batch can never be larger than the dataset itself.
  batchSize = Math.min(data.numInstances(), batchSize);
  validate(data);

  final ImageRecordReader recordReader = getImageRecordReader(data);
  if (getChannelsLast()) {
    // Required for supporting channels-last models (currently only EfficientNet).
    recordReader.setNchw_channels_first(false);
  }

  // Explicit label index position, class count taken from the dataset.
  final int labelPosition = 1;
  final int labelCount = data.numClasses();
  final DataSetIterator result =
      new RecordReaderDataSetIterator(recordReader, batchSize, labelPosition, labelCount);

  final DataNormalization pixelScaler = new ImagePreProcessingScaler(0, 1);
  pixelScaler.fit(result);
  result.setPreProcessor(pixelScaler);
  return result;
}
 
Example 5
Source Project: deeplearning4j   Source File: TestRecordReaders.java    License: Apache License 2.0 6 votes vote down vote up
/**
 * Checks that a label value outside the declared number of classes makes
 * {@code RecordReaderDataSetIterator.next()} fail with a message mentioning "to one-hot".
 */
@Test
public void testClassIndexOutsideOfRangeRRDSI() {
    // Two records of (feature, label). The iterator below is told there are 2 possible
    // labels, while the second record carries label value 2.
    Collection<Collection<Writable>> records = new ArrayList<>();
    records.add(Arrays.<Writable>asList(new DoubleWritable(0.5), new IntWritable(0)));
    records.add(Arrays.<Writable>asList(new DoubleWritable(1.0), new IntWritable(2)));

    CollectionRecordReader reader = new CollectionRecordReader(records);

    // Batch size 2, label in column 1, 2 possible labels.
    RecordReaderDataSetIterator iter = new RecordReaderDataSetIterator(reader, 2, 1, 2);

    try {
        iter.next();
        fail("Expected exception");
    } catch (Exception e) {
        String message = e.getMessage();
        assertTrue(message, message.contains("to one-hot"));
    }
}
 
Example 6
Source Project: deeplearning4j   Source File: ConvolutionLayerSetupTest.java    License: Apache License 2.0 6 votes vote down vote up
/**
 * Builds the {@code incompleteLRN()} configuration with a 28x28x3 convolutional input type and
 * checks that the layer at index 3 ends up with 6 inputs.
 */
@Test
public void testLRN() throws Exception {
    List<String> labels = new ArrayList<>(Arrays.asList("Zico", "Ziwang_Xu"));

    // Copy the test images into a fresh temporary folder.
    File imageDir = testDir.newFolder();
    new ClassPathResource("lfwtest/").copyDirectory(imageDir);

    RecordReader reader = new ImageRecordReader(28, 28, 3);
    reader.initialize(new FileSplit(new File(imageDir.getAbsolutePath())));

    // The iterator is constructed but not consumed by this test.
    DataSetIterator recordReader = new RecordReaderDataSetIterator(reader, 10, 1, labels.size());
    labels.remove("lfwtest");

    NeuralNetConfiguration.ListBuilder builder =
            (NeuralNetConfiguration.ListBuilder) incompleteLRN();
    builder.setInputType(InputType.convolutional(28, 28, 3));
    MultiLayerConfiguration conf = builder.build();

    ConvolutionLayer convLayer = (ConvolutionLayer) conf.getConf(3).getLayer();
    assertEquals(6, convLayer.getNIn());
}
 
Example 7
Source Project: deeplearning4j   Source File: MultipleEpochsIteratorTest.java    License: Apache License 2.0 6 votes vote down vote up
/**
 * Drains a {@code MultipleEpochsIterator} wrapping the iris iterator and checks that every
 * batch is non-null and that the epoch counter ends at the configured value.
 */
@Test
public void testNextAndReset() throws Exception {
    final int epochs = 3;

    RecordReader csvReader = new CSVRecordReader();
    csvReader.initialize(new FileSplit(Resources.asFile("iris.txt")));
    DataSetIterator irisIterator = new RecordReaderDataSetIterator(csvReader, 150);
    MultipleEpochsIterator multiIter = new MultipleEpochsIterator(epochs, irisIterator);

    // There must be at least one batch before draining starts.
    assertTrue(multiIter.hasNext());
    do {
        DataSet batch = multiIter.next();
        assertNotNull(batch);
    } while (multiIter.hasNext());

    assertEquals(epochs, multiIter.epochs);
}
 
Example 8
Source Project: deeplearning4j   Source File: MultipleEpochsIteratorTest.java    License: Apache License 2.0 6 votes vote down vote up
/**
 * Wraps a 50-example {@code DataSet} in a {@code MultipleEpochsIterator} and checks that it
 * yields that data set exactly once per epoch.
 */
@Test
public void testLoadFullDataSet() throws Exception {
    final int epochs = 3;

    RecordReader csvReader = new CSVRecordReader();
    csvReader.initialize(new FileSplit(Resources.asFile("iris.txt")));
    DataSetIterator irisIterator = new RecordReaderDataSetIterator(csvReader, 150);

    // Pull only the first 50 examples and verify the row count.
    DataSet firstFifty = irisIterator.next(50);
    assertEquals(50, firstFifty.getFeatures().size(0));

    MultipleEpochsIterator multiIter = new MultipleEpochsIterator(epochs, firstFifty);
    assertTrue(multiIter.hasNext());

    // Count how many data sets are produced in total.
    int produced = 0;
    for (; multiIter.hasNext(); produced++) {
        DataSet epochData = multiIter.next();
        assertNotNull(epochData);
        assertEquals(50, epochData.numExamples(), 0);
    }

    assertEquals(epochs, produced);
    assertEquals(epochs, multiIter.epochs);
}
 
Example 9
Source Project: deeplearning4j   Source File: MultipleEpochsIteratorTest.java    License: Apache License 2.0 6 votes vote down vote up
/**
 * Wraps a 20-example {@code DataSet} in a {@code MultipleEpochsIterator}, reads it in batches
 * of 10 and checks the batch sizes and the final epoch counter.
 */
@Test
public void testLoadBatchDataSet() throws Exception {
    final int epochs = 2;

    RecordReader csvReader = new CSVRecordReader();
    csvReader.initialize(new FileSplit(new ClassPathResource("iris.txt").getFile()));

    // Batch size 150, label in column 4, 3 possible labels.
    DataSetIterator irisIterator = new RecordReaderDataSetIterator(csvReader, 150, 4, 3);

    // Pull only the first 20 examples and verify the row count.
    DataSet firstTwenty = irisIterator.next(20);
    assertEquals(20, firstTwenty.getFeatures().size(0));

    MultipleEpochsIterator multiIter = new MultipleEpochsIterator(epochs, firstTwenty);
    while (multiIter.hasNext()) {
        DataSet batch = multiIter.next(10);
        assertNotNull(batch);
        assertEquals(10, batch.numExamples(), 0.0);
    }

    assertEquals(epochs, multiIter.epochs);
}
 
Example 10
Source Project: deeplearning4j   Source File: DataSetIteratorTest.java    License: Apache License 2.0 6 votes vote down vote up
/**
 * Compares batches read from the {@code mnist_first_200.txt} CSV against batches produced by
 * {@code MnistDataSetIterator}: after dividing the CSV features by 255, features and labels
 * must be equal, and both sources must be exhausted together.
 */
@Test
public void testMnist() throws Exception {
    ClassPathResource csvResource = new ClassPathResource("mnist_first_200.txt");
    CSVRecordReader csvReader = new CSVRecordReader(0, ',');
    csvReader.initialize(new FileSplit(csvResource.getTempFileFromArchive()));

    // Batch size 10, label in column 0, 10 possible labels.
    RecordReaderDataSetIterator csvIterator = new RecordReaderDataSetIterator(csvReader, 10, 0, 10);
    MnistDataSetIterator mnistIterator = new MnistDataSetIterator(10, 200, false, true, false, 0);

    while (csvIterator.hasNext()) {
        DataSet csvBatch = csvIterator.next();
        DataSet mnistBatch = mnistIterator.next();

        // The CSV features are divided by 255 in place before the comparison.
        INDArray csvFeatures = csvBatch.getFeatures();
        csvFeatures.divi(255);
        INDArray csvLabels = csvBatch.getLabels();

        INDArray mnistFeatures = mnistBatch.getFeatures();
        INDArray mnistLabels = mnistBatch.getLabels();

        // Cast to the CSV side's data type so the comparison is like-for-like.
        assertEquals(csvFeatures, mnistFeatures.castTo(csvFeatures.dataType()));
        assertEquals(csvLabels, mnistLabels.castTo(csvLabels.dataType()));
    }

    assertFalse(mnistIterator.hasNext());
}
 
Example 11
/**
 * Converts the buffered records into a single {@code DataSet}, saves it to a new ".bin" file
 * under {@code outputDir} and clears the buffer. Nothing happens when the buffer is empty, or
 * when it holds fewer than {@code batchSize} records and this is not the final record.
 *
 * @param list buffered records; cleared after a successful write
 * @param finalRecord whether the last record has been reached, forcing a partial batch out
 * @throws Exception declared by the signature; propagated from conversion or file output
 */
private void processBatchIfRequired(List<List<Writable>> list, boolean finalRecord) throws Exception {
    // Wait for a full batch unless this is the final flush; never write an empty batch.
    if (list.isEmpty() || (!finalRecord && list.size() < batchSize)) {
        return;
    }

    RecordReader bufferReader = new CollectionRecordReader(list);
    RecordReaderDataSetIterator batchIterator = new RecordReaderDataSetIterator(
            bufferReader, null, batchSize, labelIndex, labelIndex, numPossibleLabels, -1, regression);
    DataSet batch = batchIterator.next();

    // File names combine the uid with a running output counter.
    String filename = "dataset_" + uid + "_" + (outputCount++) + ".bin";
    URI target = new URI(outputDir.getPath() + "/" + filename);

    // Use the default Hadoop configuration when none was provided.
    Configuration hadoopConfig = conf == null ? DefaultHadoopConfig.get() : conf.getValue().getConfiguration();
    FileSystem fileSystem = FileSystem.get(target, hadoopConfig);
    try (FSDataOutputStream out = fileSystem.create(new Path(target))) {
        batch.save(out);
    }

    list.clear();
}
 
Example 12
Source Project: Java-Deep-Learning-Cookbook   Source File: ImageClassifierAPI.java    License: MIT License 5 votes vote down vote up
/**
 * Restores the saved network and its image scaler from {@code modelFileLocation} and returns
 * the network output for the records read from {@code inputFile}.
 *
 * @param inputFile file handed to {@code generateReader} to build the record reader
 * @param modelFileLocation path of the saved model file, which also holds the scaler
 * @return the network output for the input records
 * @throws IOException declared by the signature; propagated from the calls below
 * @throws InterruptedException declared by the signature; propagated from the calls below
 */
public static INDArray generateOutput(File inputFile, String modelFileLocation) throws IOException, InterruptedException {
    // Retrieve the saved model and the scaler stored in the same file.
    final File savedModel = new File(modelFileLocation);
    final MultiLayerNetwork network = ModelSerializer.restoreMultiLayerNetwork(savedModel);
    final RecordReader reader = generateReader(inputFile);
    final ImagePreProcessingScaler imageScaler = ModelSerializer.restoreNormalizerFromFile(savedModel);

    // One record per batch, no label configured.
    final DataSetIterator inputIterator = new RecordReaderDataSetIterator.Builder(reader, 1).build();

    // NOTE(review): fit() on a restored ImagePreProcessingScaler is presumably a no-op - confirm.
    imageScaler.fit(inputIterator);
    inputIterator.setPreProcessor(imageScaler);
    return network.output(inputIterator);
}
 
Example 13
/**
 * Builds a classification iterator over {@code Churn_Modelling.csv}, attaches a
 * {@code NormalizerStandardize} fitted on that iterator, and wraps it in a
 * {@code DataSetIteratorSplitter}.
 *
 * @return the splitter constructed with the arguments {@code 1250} and {@code 0.8}
 * @throws IOException declared by the signature; propagated from resource/reader access
 * @throws InterruptedException declared by the signature; propagated from reader creation
 */
private static DataSetIteratorSplitter createDataSetSplitter() throws IOException, InterruptedException {
    final File csvFile = new ClassPathResource("Churn_Modelling.csv").getFile();
    final RecordReader reader = DataSetIteratorHelper.generateReader(csvFile);

    final DataSetIterator iterator = new RecordReaderDataSetIterator.Builder(reader, batchSize)
            .classification(labelIndex, numClasses)
            .build();

    // Fit the standardizer on the full iterator, then apply it to every batch served.
    final DataNormalization standardizer = new NormalizerStandardize();
    standardizer.fit(iterator);
    iterator.setPreProcessor(standardizer);

    // NOTE(review): 1250 and 0.8 are presumably the total batch count and the train ratio - confirm.
    return new DataSetIteratorSplitter(iterator, 1250, 0.8);
}
 
Example 14
/**
 * Restores the saved network and its normalizer from {@code modelFilePath} and returns the
 * network output for the records read from {@code inputFile}.
 *
 * <p>The restored normalizer is applied as-is. It is deliberately not re-fitted on the input
 * records: fitting would replace the statistics saved with the model by statistics of the
 * data being scored.
 *
 * @param inputFile file handed to {@code generateReader} to build the record reader
 * @param modelFilePath path of the saved model file, which also holds the normalizer
 * @return the network output for the input records
 * @throws IOException declared by the signature; propagated from the calls below
 * @throws InterruptedException declared by the signature; propagated from the calls below
 */
public static INDArray generateOutput(File inputFile, String modelFilePath) throws IOException, InterruptedException {
    final File modelFile = new File(modelFilePath);
    final MultiLayerNetwork network = ModelSerializer.restoreMultiLayerNetwork(modelFile);
    final RecordReader recordReader = generateReader(inputFile);
    final NormalizerStandardize normalizerStandardize = ModelSerializer.restoreNormalizerFromFile(modelFile);

    // One record per batch, no label configured.
    final DataSetIterator dataSetIterator = new RecordReaderDataSetIterator.Builder(recordReader,1).build();

    // Use the restored statistics unchanged; do not call fit() here.
    dataSetIterator.setPreProcessor(normalizerStandardize);
    return network.output(dataSetIterator);
}
 
Example 15
@Override
public Object trainData() {
    try{
        DataSetIterator iterator = new RecordReaderDataSetIterator(dataPreprocess(),minibatchSize,labelIndex,numClasses);
        return dataSplit(iterator).getTestIterator();
    }
    catch(Exception e){
        throw new RuntimeException();
    }
}
 
Example 16
@Override
public Object testData() {
    try{
        DataSetIterator iterator = new RecordReaderDataSetIterator(dataPreprocess(),minibatchSize,labelIndex,numClasses);
        return dataSplit(iterator).getTestIterator();
    }
    catch(Exception e){
        throw new RuntimeException();
    }
}
 
Example 17
Source Project: Java-Deep-Learning-Cookbook   Source File: HyperParameterTuning.java    License: MIT License 5 votes vote down vote up
/**
 * Supplies the training data: builds a classification iterator over the pre-processed
 * records and returns the training side of the split.
 *
 * @return the training iterator of the split produced by {@code dataSplit}
 * @throws RuntimeException wrapping any exception raised while preparing or splitting the data
 */
@Override
public Object trainData() {
    try {
        DataSetIterator iterator = new RecordReaderDataSetIterator(dataPreprocess(), minibatchSize, labelIndex, numClasses);
        // Training accessor must hand out the training split, not the test split.
        return dataSplit(iterator).getTrainIterator();
    } catch (Exception e) {
        // Keep the original failure as the cause so it is not lost to the caller.
        throw new RuntimeException(e);
    }
}
 
Example 18
Source Project: Java-Deep-Learning-Cookbook   Source File: HyperParameterTuning.java    License: MIT License 5 votes vote down vote up
/**
 * Supplies the test data: builds a classification iterator over the pre-processed records and
 * returns the test side of the split.
 *
 * @return the test iterator of the split produced by {@code dataSplit}
 * @throws RuntimeException wrapping any exception raised while preparing or splitting the data
 */
@Override
public Object testData() {
    try {
        DataSetIterator iterator = new RecordReaderDataSetIterator(dataPreprocess(), minibatchSize, labelIndex, numClasses);
        return dataSplit(iterator).getTestIterator();
    } catch (Exception e) {
        // Keep the original failure as the cause so it is not lost to the caller.
        throw new RuntimeException(e);
    }
}
 
Example 19
Source Project: Java-Deep-Learning-Cookbook   Source File: ImageClassifierAPI.java    License: MIT License 5 votes vote down vote up
/**
 * Loads the saved network together with its image scaler from {@code modelFileLocation} and
 * returns the network output for the records read from {@code inputFile}.
 *
 * @param inputFile file handed to {@code generateReader} to build the record reader
 * @param modelFileLocation path of the saved model file, which also holds the scaler
 * @return the network output for the input records
 * @throws IOException declared by the signature; propagated from the calls below
 * @throws InterruptedException declared by the signature; propagated from the calls below
 */
public static INDArray generateOutput(File inputFile, String modelFileLocation) throws IOException, InterruptedException {
    // Retrieve the saved model and the scaler stored in the same file.
    final File modelArchive = new File(modelFileLocation);
    final MultiLayerNetwork restoredModel = ModelSerializer.restoreMultiLayerNetwork(modelArchive);
    final RecordReader imageReader = generateReader(inputFile);
    final ImagePreProcessingScaler scaler = ModelSerializer.restoreNormalizerFromFile(modelArchive);

    // One record per batch, no label configured.
    final DataSetIterator images = new RecordReaderDataSetIterator.Builder(imageReader, 1).build();

    // NOTE(review): fit() on a restored ImagePreProcessingScaler is presumably a no-op - confirm.
    scaler.fit(images);
    images.setPreProcessor(scaler);
    return restoredModel.output(images);
}
 
Example 20
/**
 * Creates the splitter over {@code Churn_Modelling.csv}: a classification iterator is built
 * from the file, a {@code NormalizerStandardize} fitted on it is installed as pre-processor,
 * and the iterator is wrapped in a {@code DataSetIteratorSplitter}.
 *
 * @return the splitter constructed with the arguments {@code 1250} and {@code 0.8}
 * @throws IOException declared by the signature; propagated from resource/reader access
 * @throws InterruptedException declared by the signature; propagated from reader creation
 */
private static DataSetIteratorSplitter createDataSetSplitter() throws IOException, InterruptedException {
    final File churnCsv = new ClassPathResource("Churn_Modelling.csv").getFile();
    final RecordReader churnReader = DataSetIteratorHelper.generateReader(churnCsv);

    final RecordReaderDataSetIterator.Builder iteratorBuilder =
            new RecordReaderDataSetIterator.Builder(churnReader, batchSize);
    final DataSetIterator churnIterator = iteratorBuilder
            .classification(labelIndex, numClasses)
            .build();

    // Fit the standardizer on the full iterator, then apply it to every batch served.
    final DataNormalization normalization = new NormalizerStandardize();
    normalization.fit(churnIterator);
    churnIterator.setPreProcessor(normalization);

    // NOTE(review): 1250 and 0.8 are presumably the total batch count and the train ratio - confirm.
    return new DataSetIteratorSplitter(churnIterator, 1250, 0.8);
}
 
Example 21
/**
 * Restores the saved network and its normalizer from {@code modelFilePath} and returns the
 * network output for the records read from {@code inputFile}.
 *
 * <p>The restored normalizer is applied as-is. It is deliberately not re-fitted on the input
 * records: fitting would replace the statistics saved with the model by statistics of the
 * data being scored.
 *
 * @param inputFile file handed to {@code generateReader} to build the record reader
 * @param modelFilePath path of the saved model file, which also holds the normalizer
 * @return the network output for the input records
 * @throws IOException declared by the signature; propagated from the calls below
 * @throws InterruptedException declared by the signature; propagated from the calls below
 */
public static INDArray generateOutput(File inputFile, String modelFilePath) throws IOException, InterruptedException {
    final File modelFile = new File(modelFilePath);
    final MultiLayerNetwork network = ModelSerializer.restoreMultiLayerNetwork(modelFile);
    final RecordReader recordReader = generateReader(inputFile);
    final NormalizerStandardize normalizerStandardize = ModelSerializer.restoreNormalizerFromFile(modelFile);

    // One record per batch, no label configured.
    final DataSetIterator dataSetIterator = new RecordReaderDataSetIterator.Builder(recordReader,1).build();

    // Use the restored statistics unchanged; do not call fit() here.
    dataSetIterator.setPreProcessor(normalizerStandardize);
    return network.output(dataSetIterator);
}
 
Example 22
/**
 * Supplies the training data: builds a classification iterator over the pre-processed
 * records and returns the training side of the split.
 *
 * @return the training iterator of the split produced by {@code dataSplit}
 * @throws RuntimeException wrapping any exception raised while preparing or splitting the data
 */
@Override
public Object trainData() {
    try {
        DataSetIterator iterator = new RecordReaderDataSetIterator(dataPreprocess(), minibatchSize, labelIndex, numClasses);
        // Training accessor must hand out the training split, not the test split.
        return dataSplit(iterator).getTrainIterator();
    } catch (Exception e) {
        // Keep the original failure as the cause so it is not lost to the caller.
        throw new RuntimeException(e);
    }
}
 
Example 23
@Override
public Object testData() {
    try{
        DataSetIterator iterator = new RecordReaderDataSetIterator(dataPreprocess(),minibatchSize,labelIndex,numClasses);
        return dataSplit(iterator).getTestIterator();
    }
    catch(Exception e){
        throw new RuntimeException();
    }
}
 
Example 24
Source Project: Java-Deep-Learning-Cookbook   Source File: HyperParameterTuning.java    License: MIT License 5 votes vote down vote up
/**
 * Supplies the training data: builds a classification iterator over the pre-processed
 * records and returns the training side of the split.
 *
 * @return the training iterator of the split produced by {@code dataSplit}
 * @throws RuntimeException wrapping any exception raised while preparing or splitting the data
 */
@Override
public Object trainData() {
    try {
        DataSetIterator iterator = new RecordReaderDataSetIterator(dataPreprocess(), minibatchSize, labelIndex, numClasses);
        // Training accessor must hand out the training split, not the test split.
        return dataSplit(iterator).getTrainIterator();
    } catch (Exception e) {
        // Keep the original failure as the cause so it is not lost to the caller.
        throw new RuntimeException(e);
    }
}
 
Example 25
Source Project: Java-Deep-Learning-Cookbook   Source File: HyperParameterTuning.java    License: MIT License 5 votes vote down vote up
/**
 * Supplies the test data: builds a classification iterator over the pre-processed records and
 * returns the test side of the split.
 *
 * @return the test iterator of the split produced by {@code dataSplit}
 * @throws RuntimeException wrapping any exception raised while preparing or splitting the data
 */
@Override
public Object testData() {
    try {
        DataSetIterator iterator = new RecordReaderDataSetIterator(dataPreprocess(), minibatchSize, labelIndex, numClasses);
        return dataSplit(iterator).getTestIterator();
    } catch (Exception e) {
        // Keep the original failure as the cause so it is not lost to the caller.
        throw new RuntimeException(e);
    }
}
 
Example 26
/**
 * Creates a batched image iterator over {@code sample} with a {@code CifarImagePreProcessor}
 * installed as pre-processor.
 *
 * @param sample file or directory wrapped in a {@code FileSplit} for the image reader
 * @param numLabels number of possible labels passed to the iterator
 * @param batchSize number of examples per batch
 * @return the configured iterator
 * @throws IOException declared by the signature; propagated from reader initialization
 */
public static DataSetIterator createDataSetIterator(File sample,int numLabels,int batchSize) throws IOException {
    final ImageRecordReader reader =
            new ImageRecordReader(HEIGHT, WIDTH, CHANNELS, LABEL_GENERATOR_MAKER);
    reader.initialize(new FileSplit(sample));

    // Label index 1 is passed explicitly along with the caller's label count.
    final int labelIndex = 1;
    final DataSetIterator result =
            new RecordReaderDataSetIterator(reader, batchSize, labelIndex, numLabels);
    result.setPreProcessor(new CifarImagePreProcessor());
    return result;
}
 
Example 27
/**
 * Creates a batched image iterator over {@code sample} with a {@code VGG16ImagePreProcessor}
 * installed as pre-processor.
 *
 * @param sample input split the image reader is initialized with
 * @return the configured iterator
 * @throws IOException declared by the signature; propagated from reader initialization
 */
default DataSetIterator getDataSetIterator(InputSplit sample) throws IOException {
    final ImageRecordReader reader =
            new ImageRecordReader(HEIGHT, WIDTH, CHANNELS, LABEL_GENERATOR_MAKER);
    reader.initialize(sample);

    // Label index 1 is passed explicitly; batch size and label count come from constants.
    final int labelIndex = 1;
    final DataSetIterator result =
            new RecordReaderDataSetIterator(reader, BATCH_SIZE, labelIndex, NUM_POSSIBLE_LABELS);
    result.setPreProcessor(new VGG16ImagePreProcessor());
    return result;
}
 
Example 28
Source Project: neo4j-ml-procedures   Source File: DL4JMLModel.java    License: Apache License 2.0 5 votes vote down vote up
@Override
    protected Object doPredict(List<String> line) {
        // Feeds the single input line through the model and maps the raw output to a result
        // according to the declared type of the output column: the first output value for
        // _float, the index of the largest of the first two output values for _class.
        // Any failure is rethrown wrapped in a RuntimeException.
        try {
            // Wrap the line as a one-record data source for the iterator (batch size 1).
            ListStringSplit singleLine = new ListStringSplit(Collections.singletonList(line));
            ListStringRecordReader reader = new ListStringRecordReader();
            reader.initialize(singleLine);
            DataSetIterator oneRecord = new RecordReaderDataSetIterator(reader, 1);

            DataSet input = oneRecord.next();
            INDArray scores = model.output(input.getFeatures());

            DataType outputType = types.get(this.output);
            switch (outputType) {
                case _float:
                    return scores.getDouble(0);
                case _class: {
                    // NOTE(review): the class count is hard-coded to 2, and the result stays -1
                    // when no score is greater than 0 - confirm both are intended.
                    final int classCount = 2;
                    double best = 0;
                    int bestIndex = -1;
                    for (int candidate = 0; candidate < classCount; candidate++) {
                        double score = scores.getDouble(candidate);
                        if (score > best) {
                            bestIndex = candidate;
                            best = score;
                        }
                    }
                    return bestIndex;
                }
                default:
                    throw new IllegalArgumentException("Output type not yet supported "+outputType);
            }
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }
 
Example 29
/**
 * Creates a batched iterator over a CSV file.
 *
 * @param csvFileClasspath path of the CSV file; it is opened through {@code new File(...)}
 * @param BATCH_SIZE number of examples per batch
 * @param LABEL_INDEX column index of the label
 * @param numClasses number of possible label classes
 * @return the iterator over the CSV records
 * @throws IOException declared by the signature; propagated from reader initialization
 * @throws InterruptedException declared by the signature; propagated from reader initialization
 */
private static DataSetIterator readCSVDataset(String csvFileClasspath, int BATCH_SIZE, int LABEL_INDEX, int numClasses)
        throws IOException, InterruptedException {

    final File csvFile = new File(csvFileClasspath);
    final RecordReader csvReader = new CSVRecordReader();
    csvReader.initialize(new FileSplit(csvFile));

    return new RecordReaderDataSetIterator(csvReader, BATCH_SIZE, LABEL_INDEX, numClasses);
}
 
Example 30
/**
 * Loads one {@code DataSet} from a zipped file batch.
 *
 * @param source source whose input stream holds the zipped {@code FileBatch}
 * @return the first data set produced by the iterator over the file batch
 * @throws IOException declared by the signature; propagated from reading the source
 */
@Override
public DataSet load(Source source) throws IOException {
    final FileBatch fileBatch = FileBatch.readFromZip(source.getInputStream());

    // Wrap the file batch in a record reader, then in a data set iterator.
    final RecordReader batchReader = new FileBatchRecordReader(recordReader, fileBatch);
    final RecordReaderDataSetIterator iterator = new RecordReaderDataSetIterator(
            batchReader, null, batchSize, labelIndexFrom, labelIndexTo, numPossibleLabels, -1, regression);

    // The pre-processor is optional.
    if (preProcessor != null) {
        iterator.setPreProcessor(preProcessor);
    }

    return iterator.next();
}