org.datavec.api.split.InputStreamInputSplit Java Examples

The following examples show how to use org.datavec.api.split.InputStreamInputSplit, which wraps an InputStream (optionally with a path label) so that DataVec record readers can consume data that is not backed by a file on disk. Each example is taken from an open-source project; the source file and license are noted above each snippet.
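Before the project examples, here is a minimal, self-contained sketch of the pattern they all share: wrap any InputStream in an InputStreamInputSplit and pass it to a RecordReader's initialize method. The class name and the /iris.txt classpath resource below are hypothetical, used only for illustration; note also that InputStreamInputSplit offers constructors taking an additional path label, as Example #9 shows.

import java.io.InputStream;

import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
import org.datavec.api.split.InputStreamInputSplit;

public class InputStreamInputSplitSketch {
    public static void main(String[] args) throws Exception {
        //Any InputStream will do: a classpath resource, a socket stream, a GZIPInputStream, etc.
        InputStream stream = InputStreamInputSplitSketch.class.getResourceAsStream("/iris.txt"); //hypothetical resource

        RecordReader reader = new CSVRecordReader(0, ',');
        reader.initialize(new InputStreamInputSplit(stream));

        while (reader.hasNext()) {
            System.out.println(reader.next());    //each record is a List<Writable>
        }
        //Streams are single-pass: resetSupported() returns false and reset() throws,
        //as the stream-reset tests in Examples #6 and #7 verify.
    }
}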
Example #1
Source File: IrisFileDataSource.java    From FederatedAndroidTrainer with MIT License
private void createDataSource() throws IOException, InterruptedException {
    //First: get the dataset using the record reader. CSVRecordReader handles loading/parsing
    int numLinesToSkip = 0;
    String delimiter = ",";
    RecordReader recordReader = new CSVRecordReader(numLinesToSkip, delimiter);
    recordReader.initialize(new InputStreamInputSplit(dataFile));

    //Second: the RecordReaderDataSetIterator handles conversion to DataSet objects, ready for use in a neural network
    int labelIndex = 4;     //5 values in each row of the iris.txt CSV: 4 input features followed by an integer class label in the 5th column (index 4)
    int numClasses = 3;     //3 classes (types of iris flowers) in the iris data set. Classes have integer values 0, 1 or 2

    DataSetIterator iterator = new RecordReaderDataSetIterator(recordReader, batchSize, labelIndex, numClasses);
    DataSet allData = iterator.next();
    allData.shuffle();

    SplitTestAndTrain testAndTrain = allData.splitTestAndTrain(0.80);  //Use 80% of data for training

    trainingData = testAndTrain.getTrain();
    testData = testAndTrain.getTest();

    //We need to normalize our data. We'll use NormalizerStandardize (which gives us mean 0, unit variance):
    DataNormalization normalizer = new NormalizerStandardize();
    normalizer.fit(trainingData);           //Collect the statistics (mean/stdev) from the training data. This does not modify the input data
    normalizer.transform(trainingData);     //Apply normalization to the training data
    normalizer.transform(testData);         //Apply normalization to the test data. This is using statistics calculated from the *training* set
}
 
Example #2
Source File: FeatureRecordReader.java    From FancyBing with GNU General Public License v3.0
protected Iterator<List<Writable>> getIterator(int location) {
    Iterator<List<Writable>> iterator = null;

    if (inputSplit instanceof InputStreamInputSplit) {
        InputStream is = ((InputStreamInputSplit) inputSplit).getIs();
        if (is != null) {
            iterator = lineIterator(new InputStreamReader(is));
        }
    } else {
        this.locations = inputSplit.locations();
        if (locations != null && locations.length > 0) {
            InputStream inputStream;
            try {
                inputStream = locations[location].toURL().openStream();
                onLocationOpen(locations[location]);
            } catch (IOException e) {
                throw new RuntimeException(e);
            }
            iterator = lineIterator(new InputStreamReader(inputStream));
        }
    }
    if (iterator == null)
        throw new UnsupportedOperationException("Unknown input split: " + inputSplit);
    return iterator;
}
 
Example #3
Source File: DiabetesFileDataSource.java    From FederatedAndroidTrainer with MIT License
private void createDataSource() throws IOException, InterruptedException {
    //First: get the dataset using the record reader. CSVRecordReader handles loading/parsing
    int numLinesToSkip = 0;
    String delimiter = ",";
    RecordReader recordReader = new CSVRecordReader(numLinesToSkip, delimiter);
    recordReader.initialize(new InputStreamInputSplit(dataFile));

    //Second: the RecordReaderDataSetIterator handles conversion to DataSet objects, ready for use in a neural network
    int labelIndex = 11;    //Regression target is the value in column index 11 of each row

    DataSetIterator iterator = new RecordReaderDataSetIterator(recordReader, batchSize, labelIndex, labelIndex, true);
    DataSet allData = iterator.next();

    SplitTestAndTrain testAndTrain = allData.splitTestAndTrain(0.80);  //Use 80% of data for training

    trainingData = testAndTrain.getTrain();
    testData = testAndTrain.getTest();

    //We need to normalize our data. We'll use NormalizerStandardize (which gives us mean 0, unit variance):
    DataNormalization normalizer = new NormalizerStandardize();
    normalizer.fit(trainingData);           //Collect the statistics (mean/stdev) from the training data. This does not modify the input data
    normalizer.transform(trainingData);     //Apply normalization to the training data
    normalizer.transform(testData);         //Apply normalization to the test data. This is using statistics calculated from the *training* set
}
 
Example #4
Source File: IOTiming.java    From DataVec with Apache License 2.0
/**
 * Times the reading of a single record from a stream, and the
 * creation of an INDArray from that record.
 *
 * @param reader the record reader used to read the record
 * @param inputStream the stream to read the record from
 * @param function the function that builds an INDArray from the record
 * @return timing statistics for record reading and array creation
 * @throws Exception if reading the record or creating the array fails
 */
public static TimingStatistics timeNDArrayCreation(RecordReader reader,
                                                   InputStream inputStream,
                                                   INDArrayCreationFunction function) throws Exception {


    reader.initialize(new InputStreamInputSplit(inputStream));
    long startNanos = System.nanoTime();
    List<Writable> next = reader.next();
    long endNanos = System.nanoTime();
    long etlDiff = endNanos - startNanos;
    long startArrCreation = System.nanoTime();
    INDArray arr = function.createFromRecord(next);
    long endArrCreation = System.nanoTime();
    long endCreationDiff = endArrCreation - startArrCreation;
    Map<Integer, Map<MemcpyDirection, Long>> currentBandwidth = PerformanceTracker.getInstance().getCurrentBandwidth();
    val bw = currentBandwidth.get(0).get(MemcpyDirection.HOST_TO_DEVICE);
    val deviceToHost = currentBandwidth.get(0).get(MemcpyDirection.DEVICE_TO_HOST);

    return TimingStatistics.builder()
            .diskReadingTimeNanos(etlDiff)
            .bandwidthNanosHostToDevice(bw)
            .bandwidthDeviceToHost(deviceToHost)
            .ndarrayCreationTimeNanos(endCreationDiff)
            .build();
}
 
Example #5
Source File: LineRecordReader.java    From DataVec with Apache License 2.0
protected Iterator<String> getIterator(int location) {
    Iterator<String> iterator = null;
    if (inputSplit instanceof StringSplit) {
        StringSplit stringSplit = (StringSplit) inputSplit;
        iterator = Collections.singletonList(stringSplit.getData()).listIterator();
    } else if (inputSplit instanceof InputStreamInputSplit) {
        InputStream is = ((InputStreamInputSplit) inputSplit).getIs();
        if (is != null) {
            iterator = IOUtils.lineIterator(new InputStreamReader(is));
        }
    } else {
        this.locations = inputSplit.locations();
        if (locations != null && locations.length > 0) {
            InputStream inputStream;
            try {
                inputStream = locations[location].toURL().openStream();
            } catch (IOException e) {
                throw new RuntimeException(e);
            }
            iterator = IOUtils.lineIterator(new InputStreamReader(inputStream));
        }
    }
    if (iterator == null)
        throw new UnsupportedOperationException("Unknown input split: " + inputSplit);
    return iterator;
}
 
Example #6
Source File: CSVRecordReaderTest.java    From DataVec with Apache License 2.0
@Test
public void testStreamReset() throws Exception {
    CSVRecordReader rr = new CSVRecordReader(0, ',');
    rr.initialize(new InputStreamInputSplit(new ClassPathResource("iris.dat").getInputStream()));

    int count = 0;
    while(rr.hasNext()){
        assertNotNull(rr.next());
        count++;
    }
    assertEquals(150, count);

    assertFalse(rr.resetSupported());

    try{
        rr.reset();
        fail("Expected exception");
    } catch (Exception e){
        e.printStackTrace();
    }
}
 
Example #7
Source File: CSVRecordReaderTest.java    From deeplearning4j with Apache License 2.0
@Test
public void testStreamReset() throws Exception {
    CSVRecordReader rr = new CSVRecordReader(0, ',');
    rr.initialize(new InputStreamInputSplit(new ClassPathResource("datavec-api/iris.dat").getInputStream()));

    int count = 0;
    while(rr.hasNext()){
        assertNotNull(rr.next());
        count++;
    }
    assertEquals(150, count);

    assertFalse(rr.resetSupported());

    try{
        rr.reset();
        fail("Expected exception");
    } catch (Exception e){
        String msg = e.getMessage();
        String msg2 = e.getCause().getMessage();
        assertTrue(msg, msg.contains("Error during LineRecordReader reset"));
        assertTrue(msg2, msg2.contains("Reset not supported from streams"));
    }
}
 
Example #8
Source File: LineReaderTest.java    From deeplearning4j with Apache License 2.0
@Test
public void testLineReaderWithInputStreamInputSplit() throws Exception {
    File tmpdir = testDir.newFolder();

    File tmp1 = new File(tmpdir, "tmp1.txt.gz");

    OutputStream os = new GZIPOutputStream(new FileOutputStream(tmp1, false));
    IOUtils.writeLines(Arrays.asList("1", "2", "3", "4", "5", "6", "7", "8", "9"), null, os);
    os.flush();
    os.close();

    InputSplit split = new InputStreamInputSplit(new GZIPInputStream(new FileInputStream(tmp1)));

    RecordReader reader = new LineRecordReader();
    reader.initialize(split);

    int count = 0;
    while (reader.hasNext()) {
        assertEquals(1, reader.next().size());
        count++;
    }

    assertEquals(9, count);
}
 
Example #9
Source File: MLLibUtil.java    From deeplearning4j with Apache License 2.0
/**
 * Convert a traditional sc.binaryFiles
 * into something usable for machine learning
 * @param binaryFiles the binary files to convert
 * @param reader the reader to use
 * @return the labeled points based on the given RDD
 */
public static JavaRDD<LabeledPoint> fromBinary(JavaPairRDD<String, PortableDataStream> binaryFiles,
                final RecordReader reader) {
    JavaRDD<Collection<Writable>> records =
                    binaryFiles.map(new Function<Tuple2<String, PortableDataStream>, Collection<Writable>>() {
                        @Override
                        public Collection<Writable> call(
                                        Tuple2<String, PortableDataStream> stringPortableDataStreamTuple2)
                                        throws Exception {
                            reader.initialize(new InputStreamInputSplit(stringPortableDataStreamTuple2._2().open(),
                                            stringPortableDataStreamTuple2._1()));
                            return reader.next();
                        }
                    });

    JavaRDD<LabeledPoint> ret = records.map(new Function<Collection<Writable>, LabeledPoint>() {
        @Override
        public LabeledPoint call(Collection<Writable> writables) throws Exception {
            return pointOf(writables);
        }
    });
    return ret;
}
 
Example #10
Source File: ArrowBinaryInputAdapter.java    From konduit-serving with Apache License 2.0
@Override
public ArrowWritableRecordBatch convert(Buffer input, ConverterArgs parameters, Map<String, Object> contextData) {
    ArrowRecordReader arrowRecordReader = new ArrowRecordReader();
    arrowRecordReader.initialize(new InputStreamInputSplit(new ByteArrayInputStream(input.getBytes())));
    arrowRecordReader.next();
    return arrowRecordReader.getCurrentBatch();
}
 
Example #11
Source File: BaseImageRecordReader.java    From DataVec with Apache License 2.0
@Override
public boolean hasNext() {
    if(inputSplit instanceof InputStreamInputSplit) {
        //A stream split yields a single record: there is a next record until the stream has been consumed
        return !finishedInputStreamSplit;
    }

    if (iter != null) {
        return iter.hasNext();
    } else if (record != null) {
        return !hitImage;
    }
    throw new IllegalStateException("Indeterminate state: record must not be null, or a file iterator must exist");
}
 
Example #12
Source File: LineReaderTest.java    From DataVec with Apache License 2.0
@Test
public void testLineReaderWithInputStreamInputSplit() throws Exception {
    String tempDir = System.getProperty("java.io.tmpdir");
    File tmpdir = new File(tempDir, "tmpdir");
    tmpdir.mkdir();

    File tmp1 = new File(tmpdir, "tmp1.txt.gz");

    OutputStream os = new GZIPOutputStream(new FileOutputStream(tmp1, false));
    IOUtils.writeLines(Arrays.asList("1", "2", "3", "4", "5", "6", "7", "8", "9"), null, os);
    os.flush();
    os.close();

    InputSplit split = new InputStreamInputSplit(new GZIPInputStream(new FileInputStream(tmp1)));

    RecordReader reader = new LineRecordReader();
    reader.initialize(split);

    int count = 0;
    while (reader.hasNext()) {
        assertEquals(1, reader.next().size());
        count++;
    }

    assertEquals(9, count);

    try {
        FileUtils.deleteDirectory(tmpdir);
    } catch (Exception e) {
        e.printStackTrace();
    }
}
 
Example #13
Source File: BaseImageRecordReader.java    From deeplearning4j with Apache License 2.0
@Override
public boolean hasNext() {
    if(inputSplit instanceof InputStreamInputSplit) {
        //A stream split yields a single record: there is a next record until the stream has been consumed
        return !finishedInputStreamSplit;
    }

    if (iter != null) {
        return iter.hasNext();
    } else if (record != null) {
        return !hitImage;
    }
    throw new IllegalStateException("Indeterminate state: record must not be null, or a file iterator must exist");
}
 
Example #14
Source File: LineRecordReader.java    From deeplearning4j with Apache License 2.0
@Override
public void initialize(InputSplit split) throws IOException, InterruptedException {
    super.initialize(split);
    if(!(inputSplit instanceof StringSplit || inputSplit instanceof InputStreamInputSplit)){
        final ArrayList<URI> uris = new ArrayList<>();
        final Iterator<URI> uriIterator = inputSplit.locationsIterator();
        while(uriIterator.hasNext()) uris.add(uriIterator.next());

        this.locations = uris.toArray(new URI[0]);
    }
    this.iter = getIterator(0);
    this.initialized = true;
}
 
Example #15
Source File: RecordReaderDataSetiteratorTest.java    From deeplearning4j with Apache License 2.0
@Test
public void testReadingFromStream() throws Exception {

    for(boolean b : new boolean[]{false, true}) {
        int batchSize = 1;
        int labelIndex = 4;
        int numClasses = 3;
        InputStream dataFile = Resources.asStream("iris.txt");
        RecordReader recordReader = new CSVRecordReader(0, ',');
        recordReader.initialize(new InputStreamInputSplit(dataFile));

        assertTrue(recordReader.hasNext());
        assertFalse(recordReader.resetSupported());

        DataSetIterator iterator;
        if(b){
            iterator = new RecordReaderDataSetIterator.Builder(recordReader, batchSize)
                    .classification(labelIndex, numClasses)
                    .build();
        } else {
            iterator = new RecordReaderDataSetIterator(recordReader, batchSize, labelIndex, numClasses);
        }
        assertFalse(iterator.resetSupported());

        int count = 0;
        while (iterator.hasNext()) {
            assertNotNull(iterator.next());
            count++;
        }

        assertEquals(150, count);

        try {
            iterator.reset();
            fail("Expected exception");
        } catch (Exception e) {
            //expected
        }
    }
}