org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator Java Examples

The following examples show how to use org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator. You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may check out the related API usage on the sidebar.
Example #1
Source File: SparkAMDSI.java    From deeplearning4j with Apache License 2.0 6 votes vote down vote up
/**
 * Sets up the iterator state and starts a daemon {@code SparkPrefetchThread} that is handed
 * the queue, the wrapped iterator and the terminator.
 *
 * @param iterator     iterator to wrap; it is reset up front when it supports reset
 * @param queueSize    requested prefetch size; anything below 2 is raised to 2
 * @param queue        queue stored as the buffer and passed to the prefetch thread
 * @param useWorkspace value stored in the {@code useWorkspaces} flag
 * @param callback     callback stored on this instance
 * @param deviceId     device id stored on this instance; note that the prefetch thread itself is
 *                     created with the device the affinity manager reports for the current thread
 */
public SparkAMDSI(MultiDataSetIterator iterator, int queueSize, BlockingQueue<MultiDataSet> queue,
                boolean useWorkspace, DataSetCallback callback, Integer deviceId) {
    this();

    this.callback = callback;
    this.buffer = queue;
    this.backedIterator = iterator;
    this.useWorkspaces = useWorkspace;
    // never prefetch fewer than two elements
    this.prefetchSize = Math.max(queueSize, 2);
    this.workspaceId = "SAMDSI_ITER-" + java.util.UUID.randomUUID().toString();
    this.deviceId = deviceId;

    if (iterator.resetSupported()) {
        this.backedIterator.reset();
    }

    this.thread = new SparkPrefetchThread(buffer, iterator, terminator,
                    Nd4j.getAffinityManager().getDeviceForCurrentThread());

    context = TaskContext.get();

    thread.setDaemon(true);
    thread.start();
}
 
Example #2
Source File: BaseEarlyStoppingTrainer.java    From deeplearning4j with Apache License 2.0 6 votes vote down vote up
/**
 * Stores the configuration, model, listener and training data. Each non-null training iterator
 * that reports async support is wrapped in its async counterpart before being stored.
 *
 * @param earlyStoppingConfiguration early stopping configuration
 * @param model                      model to train
 * @param train                      DataSet training iterator, may be null
 * @param trainMulti                 MultiDataSet training iterator, may be null
 * @param listener                   early stopping listener
 */
protected BaseEarlyStoppingTrainer(EarlyStoppingConfiguration<T> earlyStoppingConfiguration, T model,
                                   DataSetIterator train, MultiDataSetIterator trainMulti, EarlyStoppingListener<T> listener) {
    DataSetIterator effectiveTrain = train;
    if (effectiveTrain != null && effectiveTrain.asyncSupported()) {
        effectiveTrain = new AsyncDataSetIterator(effectiveTrain);
    }

    MultiDataSetIterator effectiveMulti = trainMulti;
    if (effectiveMulti != null && effectiveMulti.asyncSupported()) {
        effectiveMulti = new AsyncMultiDataSetIterator(effectiveMulti);
    }

    this.esConfig = earlyStoppingConfiguration;
    this.model = model;
    this.listener = listener;
    this.train = effectiveTrain;
    this.trainMulti = effectiveMulti;
    // the DataSet iterator wins when both are supplied
    this.iterator = (effectiveTrain == null ? effectiveMulti : effectiveTrain);
}
 
Example #3
Source File: MultiDataSetSplitterTests.java    From deeplearning4j with Apache License 2.0 6 votes vote down vote up
@Test
public void testUnorderedSplitter_3() {
    val back = new MultiDataSetGenerator(1000, new int[]{32, 100}, new int[]{32, 5});
    val splitter = new MultiDataSetIteratorSplitter(back, new int[]{10});
    List<MultiDataSetIterator> parts = splitter.getIterators();

    // Draw as many random partition numbers as there are partitions (repeats are possible).
    Random rnd = new Random();
    final int partCount = parts.size();
    int[] picks = new int[partCount];
    for (int p = 0; p < partCount; p++) {
        picks[p] = rnd.nextInt(partCount);
    }

    // Visit the partitions in that random order and check the first value of every feature array.
    for (int partNumber : picks) {
        MultiDataSetIterator part = parts.get(partNumber);
        for (int cnt = 0; part.hasNext(); cnt++) {
            val features = part.next().getFeatures();
            for (int f = 0; f < features.length; f++) {
                assertEquals("Train failed on iteration " + cnt, (float) (500 * partNumber + cnt),
                        features[f].getFloat(0), 1e-5);
            }
        }
    }
}
 
Example #4
Source File: BaseNetScoreFunction.java    From deeplearning4j with Apache License 2.0 6 votes vote down vote up
/**
 * Dispatches to the typed {@code score} overload matching the runtime types of the arguments.
 * A {@code MultiLayerNetwork} model is handled first; any other model is cast to
 * {@code ComputationGraph}. Iterator factories are asked to create their iterator.
 *
 * @param model    the network to score
 * @param testData a DataSetIterator, MultiDataSetIterator or DataSetIteratorFactory
 * @return the score computed by the selected overload
 * @throws RuntimeException if {@code testData} is none of the supported types
 */
protected double score(Object model, Object testData){
    if (model instanceof MultiLayerNetwork) {
        final MultiLayerNetwork net = (MultiLayerNetwork) model;
        if (testData instanceof DataSetIterator) {
            return score(net, (DataSetIterator) testData);
        }
        if (testData instanceof MultiDataSetIterator) {
            return score(net, (MultiDataSetIterator) testData);
        }
        if (testData instanceof DataSetIteratorFactory) {
            return score(net, ((DataSetIteratorFactory) testData).create());
        }
        throw new RuntimeException("Unknown type of data: " + testData.getClass());
    }

    // Everything else is treated as a ComputationGraph; the cast happens only once a data type matched.
    if (testData instanceof DataSetIterator) {
        return score((ComputationGraph) model, (DataSetIterator) testData);
    }
    if (testData instanceof DataSetIteratorFactory) {
        return score((ComputationGraph) model, ((DataSetIteratorFactory) testData).create());
    }
    if (testData instanceof MultiDataSetIterator) {
        return score((ComputationGraph) model, (MultiDataSetIterator) testData);
    }
    throw new RuntimeException("Unknown type of data: " + testData.getClass());
}
 
Example #5
Source File: AbstractMultiDataSetNormalizer.java    From nd4j with Apache License 2.0 6 votes vote down vote up
/**
 * Fits this normalizer on everything the iterator yields. The iterator is reset first, each
 * element is passed to {@code fitPartial}, and the collected builders are turned into the
 * feature statistics (and the label statistics when {@code isFitLabel()} is true).
 *
 * @param iterator the data to iterate over
 */
public void fit(@NonNull MultiDataSetIterator iterator) {
    final List<S.Builder> featureBuilders = new ArrayList<>();
    final List<S.Builder> labelBuilders = new ArrayList<>();

    for (iterator.reset(); iterator.hasNext(); ) {
        fitPartial(iterator.next(), featureBuilders, labelBuilders);
    }

    featureStats = buildList(featureBuilders);
    if (!isFitLabel()) {
        return;
    }
    labelStats = buildList(labelBuilders);
}
 
Example #6
Source File: SameDiffRNNTestCases.java    From deeplearning4j with Apache License 2.0 6 votes vote down vote up
/**
 * Copies the UCI sequence training features and labels into fresh temp directories, reads
 * files 0..449 from each as CSV sequences, and returns them as a MultiDataSet iterator
 * (mini-batch size 10, 6 label classes, sequences aligned at the end).
 *
 * @return the training data without any pre-processor attached
 * @throws Exception if copying or reading the data fails
 */
protected MultiDataSetIterator getTrainingDataUnnormalized() throws Exception {
    final int batch = 10;
    final int classes = 6;

    File featuresDir = Files.createTempDir();
    File labelsDir = Files.createTempDir();
    Resources.copyDirectory("dl4j-integration-tests/data/uci_seq/train/features/", featuresDir);
    Resources.copyDirectory("dl4j-integration-tests/data/uci_seq/train/labels/", labelsDir);

    SequenceRecordReader featureReader = new CSVSequenceRecordReader();
    featureReader.initialize(new NumberedFileInputSplit(featuresDir.getAbsolutePath() + "/%d.csv", 0, 449));

    SequenceRecordReader labelReader = new CSVSequenceRecordReader();
    labelReader.initialize(new NumberedFileInputSplit(labelsDir.getAbsolutePath() + "/%d.csv", 0, 449));

    DataSetIterator sequences = new SequenceRecordReaderDataSetIterator(featureReader, labelReader, batch, classes,
            false, SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_END);

    return new MultiDataSetIteratorAdapter(sequences);
}
 
Example #7
Source File: AbstractMultiDataSetNormalizer.java    From deeplearning4j with Apache License 2.0 6 votes vote down vote up
/**
 * Resets the iterator, feeds every element it yields to {@code fitPartial}, and then builds the
 * feature statistics from the accumulated builders. Label statistics are built only when
 * {@code isFitLabel()} returns true.
 *
 * @param iterator the data to iterate over
 */
public void fit(@NonNull MultiDataSetIterator iterator) {
    final List<S.Builder> featureAccumulators = new ArrayList<>();
    final List<S.Builder> labelAccumulators = new ArrayList<>();

    iterator.reset();
    for (; iterator.hasNext(); ) {
        final MultiDataSet batch = iterator.next();
        fitPartial(batch, featureAccumulators, labelAccumulators);
    }

    featureStats = buildList(featureAccumulators);
    if (!isFitLabel()) {
        return;
    }
    labelStats = buildList(labelAccumulators);
}
 
Example #8
Source File: MultiDataSetIteratorSplitter.java    From deeplearning4j with Apache License 2.0 6 votes vote down vote up
/**
 * Creates a splitter that divides the batches of the underlying iterator into a train part and a
 * test part according to a single ratio.
 *
 * @param baseIterator underlying iterator; it must support reset
 * @param totalBatches total number of batches in the underlying iterator. This value is used to
 *                     determine the number of train/test batches; it must not be negative
 * @param ratio        split ratio, strictly between 0.0 and 1.0 (0.0 &lt; X &lt; 1.0). E.g. with 0.7,
 *                     70% of the batches are used for training and the remaining 30% for testing
 * @throws ND4JIllegalStateException if the ratio is out of range (or NaN), if totalBatches is
 *                                   negative, or if the underlying iterator does not support reset
 */
public MultiDataSetIteratorSplitter(@NonNull MultiDataSetIterator baseIterator, long totalBatches, double ratio) {
    // written as a negated conjunction so that NaN is rejected as well
    if (!(ratio > 0.0 && ratio < 1.0))
        throw new ND4JIllegalStateException("Ratio value should be in range of 0.0 < X < 1.0, got " + ratio);

    if (totalBatches < 0)
        throw new ND4JIllegalStateException("totalBatches should be a non-negative value, got " + totalBatches);

    if (!baseIterator.resetSupported())
        throw new ND4JIllegalStateException("Underlying iterator doesn't support reset, so it can't be used for runtime-split");


    this.backedIterator = baseIterator;
    this.totalExamples = totalBatches;
    this.ratio = ratio;
    // the train share is truncated; the test part receives the remainder
    this.numTrain = (long) (totalExamples * ratio);
    this.numTest = totalExamples - numTrain;
    // the arbitrary-split state is unused in single-ratio mode
    this.ratios = null;
    this.numArbitrarySets = 0;
    this.splits = null;

    log.warn("IteratorSplitter is used: please ensure you don't use randomization/shuffle in underlying iterator!");
}
 
Example #9
Source File: IEvaluateMDSPathsFlatMapFunction.java    From deeplearning4j with Apache License 2.0 6 votes vote down vote up
/**
 * Evaluates the MultiDataSets found at the given paths. When a DataSet loader is configured the
 * paths are loaded as DataSets and adapted; otherwise the MultiDataSet loader is used. The work
 * is submitted to the shared {@code EvaluationRunner} and this call blocks on its result.
 *
 * @param paths paths of the data to evaluate
 * @return an empty iterator when there are no paths or the runner returns null, otherwise an
 *         iterator over the single evaluation array
 */
@Override
public Iterator<IEvaluation[]> call(Iterator<String> paths) throws Exception {
    if (!paths.hasNext()) {
        return Collections.emptyIterator();
    }

    final MultiDataSetIterator iter;
    if (dsLoader == null) {
        iter = new MultiDataSetLoaderIterator(paths, mdsLoader, new RemoteFileSourceFactory(conf));
    } else {
        iter = new MultiDataSetIteratorAdapter(
                new DataSetLoaderIterator(paths, dsLoader, new RemoteFileSourceFactory(conf)));
    }

    final IEvaluation[] result = EvaluationRunner.getInstance()
            .execute(evaluations, evalNumWorkers, evalBatchSize, null, iter, true, json, params)
            .get();

    if (result != null) {
        return Collections.singletonList(result).iterator();
    }
    return Collections.emptyIterator();
}
 
Example #10
Source File: TrainUtil.java    From FancyBing with GNU General Public License v3.0 6 votes vote down vote up
/**
 * Runs the model over all of {@code testData}, evaluating output 0 as a classification against
 * labels 0 and output 1 as a regression against labels 1, logs the statistics and the time spent
 * in the forward pass, and resets the iterator afterwards.
 *
 * @param model     model to evaluate; it is cast to {@code ComputationGraph} for every batch
 * @param outputNum number of classes, passed to {@code createLabels}
 * @param testData  data to evaluate on; reset after evaluation
 * @param topN      top-N value passed to the classification evaluation
 * @param batchSize batch size, used only to scale the logged sample count and average time
 * @return the classification accuracy
 */
public static double evaluate(Model model, int outputNum, MultiDataSetIterator testData, int topN, int batchSize) {
    log.info("Evaluate model....");
    Evaluation clsEval = new Evaluation(createLabels(outputNum), topN);
    RegressionEvaluation valueRegEval1 = new RegressionEvaluation(1);
    int count = 0;
    long consume = 0;

    while (testData.hasNext()) {
        MultiDataSet ds = testData.next();
        // only the forward pass is timed
        long begin = System.nanoTime();
        INDArray[] output = ((ComputationGraph) model).output(false, ds.getFeatures());
        consume += System.nanoTime() - begin;
        clsEval.eval(ds.getLabels(0), output[0]);
        valueRegEval1.eval(ds.getLabels(1), output[1]);
        count++;
    }

    String stats = clsEval.stats();
    // Log from the first "===" marker on. indexOf returns -1 when the marker is absent, and
    // substring(-1) would throw, so fall back to the full text in that case.
    int pos = stats.indexOf("===");
    stats = "\n" + (pos >= 0 ? stats.substring(pos) : stats);
    log.info(stats);
    log.info(valueRegEval1.stats());
    testData.reset();

    // Widen before multiplying to avoid int overflow, and avoid 0/0 (NaN) for an empty iterator.
    long samples = (long) count * batchSize;
    float average = samples != 0 ? (float) consume / samples / 1000 : 0f;
    log.info("Evaluate time: " + consume + " count: " + samples + " average: " + average);
    return clsEval.accuracy();
}
 
Example #11
Source File: ScrollableMultiDataSetIterator.java    From deeplearning4j with Apache License 2.0 6 votes vote down vote up
/**
 * Creates the iterator for one part of a shared underlying iterator.
 *
 * @param num            number of this part
 * @param backedIterator underlying iterator
 * @param counter        counter stored on this instance
 * @param firstTrain     stored as {@code firstMultiTrain}
 * @param itemsPerPart   two-element array: index 0 is the lower bound ({@code bottom}) of this
 *                       part, index 1 is its size; {@code top} is their sum
 */
public ScrollableMultiDataSetIterator(int num, MultiDataSetIterator backedIterator, AtomicLong counter,
                                 MultiDataSet firstTrain,  int[] itemsPerPart) {
    this.thisPart = num;
    this.bottom = itemsPerPart[0];
    this.top = bottom + itemsPerPart[1];
    // the field holds the upper bound of this part, not the incoming array
    this.itemsPerPart = top;

    this.counter = counter;
    this.firstTrain = null;
    this.firstMultiTrain = firstTrain;
    this.current = 0;
    this.backedIterator = backedIterator;
    // The former "this.resetPending = resetPending;" was removed: there is no such parameter, so
    // it assigned the field to itself.
}
 
Example #12
Source File: ScoreUtil.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
/**
 * Get a {@link MultiDataSetIterator} from the given object. The object may be a
 * {@link MultiDataSetIterator} (returned as is), a {@link MultiDataSetIteratorFactory}
 * (asked to create one), a {@link DataSetIterator} or a {@link DataSetIteratorFactory}
 * (both wrapped in a {@link MultiDataSetIteratorAdapter}). Any other type, including
 * {@code null}, will throw an {@link IllegalArgumentException}.
 *
 * @param o the object to get the iterator from
 * @return the multi data set iterator obtained from the given object
 * @throws IllegalArgumentException if {@code o} is none of the supported types
 */
public static MultiDataSetIterator getMultiIterator(Object o) {
    if (o instanceof MultiDataSetIterator) {
        return (MultiDataSetIterator) o;
    } else if (o instanceof MultiDataSetIteratorFactory) {
        MultiDataSetIteratorFactory factory = (MultiDataSetIteratorFactory) o;
        return factory.create();
    } else if( o instanceof DataSetIterator ){
        return new MultiDataSetIteratorAdapter((DataSetIterator)o);
    } else if( o instanceof DataSetIteratorFactory ){
        return new MultiDataSetIteratorAdapter(((DataSetIteratorFactory)o).create());
    }

    // name all four accepted types and report what was actually passed (null-safe)
    throw new IllegalArgumentException("Type must be one of MultiDataSetIterator, MultiDataSetIteratorFactory, "
                    + "DataSetIterator or DataSetIteratorFactory, got: "
                    + (o == null ? "null" : o.getClass().getName()));
}
 
Example #13
Source File: MultiNormalizerStandardizeTest.java    From nd4j with Apache License 2.0 5 votes vote down vote up
@Test
public void testFullyMaskedData() {
    // First example: feature 1, label 2, no masks.
    INDArray[] plainFeatures = {Nd4j.create(new float[] {1}).reshape(1, 1, 1)};
    INDArray[] plainLabels = {Nd4j.create(new float[] {2}).reshape(1, 1, 1)};
    MultiDataSet plain = new MultiDataSet(plainFeatures, plainLabels);

    // Second example: feature 2, label 4, with a label mask of 0.
    INDArray[] maskedFeatures = {Nd4j.create(new float[] {2}).reshape(1, 1, 1)};
    INDArray[] maskedLabels = {Nd4j.create(new float[] {4}).reshape(1, 1, 1)};
    INDArray[] labelMask = {Nd4j.create(new float[] {0}).reshape(1, 1)};
    MultiDataSet masked = new MultiDataSet(maskedFeatures, maskedLabels, null, labelMask);

    MultiDataSetIterator iter = new TestMultiDataSetIterator(1, plain, masked);

    SUT.fit(iter);

    // The label mean should be 2, as the second row with 4 is masked.
    assertEquals(2f, SUT.getLabelMean(0).getFloat(0), 1e-6);
}
 
Example #14
Source File: MultiNormalizerMinMaxScalerTest.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
@Test
public void testFullyMaskedData() {
    // Unmasked example: feature 1, label 2.
    INDArray[] visibleFeatures = {Nd4j.create(new float[] {1}).reshape(1, 1, 1)};
    INDArray[] visibleLabels = {Nd4j.create(new float[] {2}).reshape(1, 1, 1)};
    MultiDataSet visible = new MultiDataSet(visibleFeatures, visibleLabels);

    // Example whose label (4) carries a label mask of 0.
    INDArray[] hiddenFeatures = {Nd4j.create(new float[] {2}).reshape(1, 1, 1)};
    INDArray[] hiddenLabels = {Nd4j.create(new float[] {4}).reshape(1, 1, 1)};
    INDArray[] hiddenLabelMask = {Nd4j.create(new float[] {0}).reshape(1, 1)};
    MultiDataSet hidden = new MultiDataSet(hiddenFeatures, hiddenLabels, null, hiddenLabelMask);

    MultiDataSetIterator iter = new TestMultiDataSetIterator(1, visible, hidden);

    SUT.fit(iter);

    // The label min value should be 2, as the second row with 4 is masked.
    assertEquals(2f, SUT.getLabelMin(0).getFloat(0), 1e-6);
}
 
Example #15
Source File: SameDiffRNNTestCases.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
/**
 * Returns the unnormalized training data with a composite pre-processor attached: the
 * normalizer from {@code getNormalizer()} followed by a step that replaces label array 0 with
 * its slice at the last index of dimension 2 and clears the label mask.
 */
@Override
public MultiDataSetIterator getTrainingData() throws Exception {
    final MultiDataSetIterator data = getTrainingDataUnnormalized();

    final MultiDataSetPreProcessor keepLastStep = mds -> {
        final INDArray labels = mds.getLabels(0);
        mds.setLabels(0,
                labels.get(NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.point(labels.size(2) - 1)));
        mds.setLabelsMaskArray(0, null);
    };

    data.setPreProcessor(new CompositeMultiDataSetPreProcessor(getNormalizer(), keepLastStep));
    return data;
}
 
Example #16
Source File: EarlyTerminationMultiDataSetIterator.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
/**
 * Constructor takes the iterator to wrap and the number of minibatches after which the call to hasNext()
 * will return false
 * @param underlyingIterator, iterator to wrap
 * @param terminationPoint, minibatches after which hasNext() will return false
 */
public EarlyTerminationMultiDataSetIterator(MultiDataSetIterator underlyingIterator, int terminationPoint) {
    if (terminationPoint <= 0)
        throw new IllegalArgumentException(
                        "Termination point (the number of calls to .next() or .next(num)) has to be > 0");
    this.underlyingIterator = underlyingIterator;
    this.terminationPoint = terminationPoint;
}
 
Example #17
Source File: MultiDataSetIteratorSplitter.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
/**
 * Creates a splitter that divides the batches of the underlying iterator into an arbitrary
 * number of parts, one per entry of {@code ratios}.
 *
 * @param baseIterator underlying iterator; it must support reset
 * @param totalBatches total number of batches in the underlying iterator; must not be negative
 * @param ratios       share of the batches for each part; every value must be strictly between
 *                     0.0 and 1.0. The size of part {@code i} is {@code totalBatches * ratios[i]},
 *                     truncated to an int
 * @throws ND4JIllegalStateException if any ratio is out of range (or NaN), if totalBatches is
 *                                   negative, or if the underlying iterator does not support reset
 */
public MultiDataSetIteratorSplitter(@NonNull MultiDataSetIterator baseIterator, long totalBatches, double[] ratios) {
    for (double ratio : ratios) {
        // written as a negated conjunction so that NaN is rejected as well
        if (!(ratio > 0.0 && ratio < 1.0))
            throw new ND4JIllegalStateException("Ratio value should be in range of 0.0 < X < 1.0, got " + ratio);
    }

    if (totalBatches < 0)
        throw new ND4JIllegalStateException("totalBatches should be a non-negative value, got " + totalBatches);

    if (!baseIterator.resetSupported())
        throw new ND4JIllegalStateException("Underlying iterator doesn't support reset, so it can't be used for runtime-split");


    this.backedIterator = baseIterator;
    this.totalExamples = totalBatches;
    // the single-ratio state is unused in this mode: ratio is 0.0, so numTrain is 0
    this.ratio = 0.0;
    this.numTrain = (long) (totalExamples * ratio);
    this.numTest = totalExamples - numTrain;
    // Store the argument. The field was previously set to null and then dereferenced right below,
    // which made this constructor always fail with a NullPointerException.
    this.ratios = ratios;
    this.numArbitrarySets = ratios.length;

    this.splits = new int[this.ratios.length];
    for (int i = 0; i < this.splits.length; ++i) {
        this.splits[i] = (int)(totalExamples * ratios[i]);
    }

    log.warn("IteratorSplitter is used: please ensure you don't use randomization/shuffle in underlying iterator!");
}
 
Example #18
Source File: MultiDataSetIteratorSplitter.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
/**
 * Builds one {@code ScrollableMultiDataSetIterator} per entry of {@code splits}. Part {@code i}
 * is given its index and the pair {start offset, size}, where the start offset is the sum of
 * the sizes of all preceding parts.
 *
 * @return the part iterators, in split order
 */
public List<MultiDataSetIterator> getIterators() {
    final List<MultiDataSetIterator> parts = new ArrayList<>();
    final int[] sizes = splits;
    int offset = 0;
    for (int part = 0; part < sizes.length; part++) {
        final int size = sizes[part];
        parts.add(new ScrollableMultiDataSetIterator(part, backedIterator, counter, firstTrain,
                new int[]{offset, size}));
        offset += size;
    }
    return parts;
}
 
Example #19
Source File: Main.java    From twse-captcha-solver-dl4j with MIT License 5 votes vote down vote up
/**
 * Runs the model over the iterator and compares, example by example, the 5-character string
 * decoded from the first five model outputs with the one decoded from the first five label
 * arrays. Each comparison is logged, the totals are printed, and the iterator is reset at the end.
 *
 * @param model    the graph producing one output array per character position
 * @param iterator data to validate on; reset after the run
 */
public static void modelPredict(ComputationGraph model, MultiDataSetIterator iterator) {
  // index i of this alphabet is the character for class i
  final List<String> alphabet =
      Arrays.asList(
          "0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "A", "B", "C", "D", "E", "F", "G",
          "H", "I", "J", "K", "L", "M", "N", "O", "P", "Q", "R", "S", "T", "U", "V", "W", "X",
          "Y", "Z");

  int total = 0;
  int hits = 0;

  while (iterator.hasNext()) {
    final MultiDataSet mds = iterator.next();
    final INDArray[] output = model.output(mds.getFeatures());
    final INDArray[] labels = mds.getLabels();

    // never look past the rows actually present in this batch
    final int rowsToCheck = Math.min(batchSize, output[0].rows());
    for (int row = 0; row < rowsToCheck; row++) {
      final StringBuilder predicted = new StringBuilder();
      final StringBuilder expected = new StringBuilder();
      for (int digit = 0; digit < 5; digit++) {
        final INDArray predictedRow = output[digit].getRow(row);
        predicted.append(alphabet.get(Nd4j.argMax(predictedRow, 1).getInt(0)));

        final INDArray expectedRow = labels[digit].getRow(row);
        expected.append(alphabet.get(Nd4j.argMax(expectedRow, 1).getInt(0)));
      }

      final String peLabel = predicted.toString();
      final String reLabel = expected.toString();
      final boolean matches = peLabel.equals(reLabel);
      if (matches) {
        hits++;
      }
      total++;
      logger.info("real image {}  prediction {} status {}", reLabel, peLabel, matches);
    }
  }
  iterator.reset();
  System.out.println("validate result : sum count =" + total + " correct count=" + hits);
}
 
Example #20
Source File: RNNTestCases.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
@Override
        /**
         * Copies the UCI sequence test features and labels from the classpath into fresh temp
         * directories, reads files 0..149 from each as CSV sequences (mini-batch size 10,
         * 6 label classes, aligned at the end) and returns them as a MultiDataSet iterator.
         * The attached pre-processor applies the normalizer and then replaces label array 0 with
         * its slice at the last index of dimension 2, clearing the label mask.
         */
        public MultiDataSetIterator getEvaluationTestData() throws Exception {
            final int batch = 10;
            final int classes = 6;

            File featuresDir = Files.createTempDir();
            File labelsDir = Files.createTempDir();
            new ClassPathResource("dl4j-integration-tests/data/uci_seq/test/features/").copyDirectory(featuresDir);
            new ClassPathResource("dl4j-integration-tests/data/uci_seq/test/labels/").copyDirectory(labelsDir);

            SequenceRecordReader featureReader = new CSVSequenceRecordReader();
            featureReader.initialize(new NumberedFileInputSplit(featuresDir.getAbsolutePath() + "/%d.csv", 0, 149));

            SequenceRecordReader labelReader = new CSVSequenceRecordReader();
            labelReader.initialize(new NumberedFileInputSplit(labelsDir.getAbsolutePath() + "/%d.csv", 0, 149));

            DataSetIterator sequences = new SequenceRecordReaderDataSetIterator(featureReader, labelReader, batch, classes,
                    false, SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_END);

            MultiDataSetIterator result = new MultiDataSetIteratorAdapter(sequences);

            MultiDataSetPreProcessor keepLastStep = mds -> {
                INDArray labels = mds.getLabels(0);
                mds.setLabels(0,
                        labels.get(NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.point(labels.size(2)-1)));
                mds.setLabelsMaskArray(0, null);
            };

            result.setPreProcessor(new CompositeMultiDataSetPreProcessor(getNormalizer(),keepLastStep));
            return result;
        }
 
Example #21
Source File: LoaderIteratorTests.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
/**
 * Checks MultiDataSetLoaderIterator once without and once with a random number generator:
 * all three paths must be loaded, the original order must be kept when no generator is used,
 * and the iterator must have elements again after a reset.
 */
@Test
public void testMDSLoaderIter(){

    for(boolean r : new boolean[]{false, true}) {
        List<String> l = Arrays.asList("3", "0", "1");
        Random rng = r ? new Random(12345) : null;
        // Pass the generator on. It used to be replaced by a hard-coded null, so the r == true
        // pass never exercised the randomized path.
        MultiDataSetIterator iter = new MultiDataSetLoaderIterator(l, rng, new Loader<MultiDataSet>() {
            @Override
            public MultiDataSet load(Source source) throws IOException {
                // the "path" is just a number that becomes both the feature and the label
                INDArray i = Nd4j.scalar(Integer.valueOf(source.getPath()));
                return new org.nd4j.linalg.dataset.MultiDataSet(i, i);
            }
        }, new LocalFileSourceFactory());

        int count = 0;
        int[] exp = {3, 0, 1};
        while (iter.hasNext()) {
            MultiDataSet ds = iter.next();
            // the order is only deterministic when no generator is used
            if(!r) {
                assertEquals(exp[count], ds.getFeatures()[0].getInt(0));
            }
            count++;
        }
        assertEquals(3, count);

        iter.reset();
        assertTrue(iter.hasNext());
    }
}
 
Example #22
Source File: SameDiffRNNTestCases.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
@Override
        /**
         * Copies the UCI sequence test features and labels into fresh temp directories, reads
         * files 0..149 from each as CSV sequences (mini-batch size 10, 6 label classes, aligned
         * at the end) and returns them as a MultiDataSet iterator. The attached pre-processor
         * applies the normalizer and then replaces label array 0 with its slice at the last
         * index of dimension 2, clearing the label mask.
         */
        public MultiDataSetIterator getEvaluationTestData() throws Exception {
            final int batch = 10;
            final int classes = 6;

            File featuresDir = Files.createTempDir();
            File labelsDir = Files.createTempDir();
            Resources.copyDirectory("dl4j-integration-tests/data/uci_seq/test/features/", featuresDir);
            Resources.copyDirectory("dl4j-integration-tests/data/uci_seq/test/labels/", labelsDir);

            SequenceRecordReader featureReader = new CSVSequenceRecordReader();
            featureReader.initialize(new NumberedFileInputSplit(featuresDir.getAbsolutePath() + "/%d.csv", 0, 149));

            SequenceRecordReader labelReader = new CSVSequenceRecordReader();
            labelReader.initialize(new NumberedFileInputSplit(labelsDir.getAbsolutePath() + "/%d.csv", 0, 149));

            DataSetIterator sequences = new SequenceRecordReaderDataSetIterator(featureReader, labelReader, batch, classes,
                    false, SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_END);

            MultiDataSetIterator result = new MultiDataSetIteratorAdapter(sequences);

            MultiDataSetPreProcessor keepLastStep = mds -> {
                INDArray labels = mds.getLabels(0);
                mds.setLabels(0,
                        labels.get(NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.point(labels.size(2) - 1)));
                mds.setLabelsMaskArray(0, null);
            };

            result.setPreProcessor(new CompositeMultiDataSetPreProcessor(getNormalizer(), keepLastStep));
            return result;
        }
 
Example #23
Source File: MultiNormalizerHybrid.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
/**
 * Resets the iterator and passes every element it yields to {@code fitPartial}, accumulating
 * per-index statistics builders for inputs and outputs, then builds the final input and output
 * statistics from them.
 *
 * @param iterator the iterator to use for collecting statistics
 */
@Override
public void fit(@NonNull MultiDataSetIterator iterator) {
    final Map<Integer, NormalizerStats.Builder> inputBuilders = new HashMap<>();
    final Map<Integer, NormalizerStats.Builder> outputBuilders = new HashMap<>();

    for (iterator.reset(); iterator.hasNext(); ) {
        fitPartial(iterator.next(), inputBuilders, outputBuilders);
    }

    inputStats = buildAllStats(inputBuilders);
    outputStats = buildAllStats(outputBuilders);
}
 
Example #24
Source File: EvaluativeListener.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
/**
 * Creates a listener that evaluates on data from a MultiDataSet iterator.
 * Evaluation will be launched after each *frequency* iteration.
 *
 * @param iterator    iterator providing the data for evaluation
 * @param frequency   frequency (in number of iterations/epochs, according to the invocation type)
 *                    at which to perform evaluation
 * @param type        what 'frequency' counts - iteration end, epoch end, etc
 * @param evaluations the evaluations to perform
 */
public EvaluativeListener(@NonNull MultiDataSetIterator iterator, int frequency, @NonNull InvocationType type,
                IEvaluation... evaluations) {
    this.mdsIterator = iterator;
    this.evaluations = evaluations;

    this.invocationType = type;
    this.frequency = frequency;
}
 
Example #25
Source File: AsyncMultiDataSetIterator.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
/**
 * Creates the prefetch thread as a daemon named "AMDSI prefetch thread"; the thread is not
 * started here.
 *
 * @param queue      queue stored on this thread
 * @param iterator   iterator stored on this thread
 * @param terminator terminator element stored on this thread
 * @param deviceId   device id stored on this thread
 */
protected AsyncPrefetchThread(@NonNull BlockingQueue<MultiDataSet> queue,
                              @NonNull MultiDataSetIterator iterator, @NonNull MultiDataSet terminator, int deviceId) {
    setDaemon(true);
    setName("AMDSI prefetch thread");

    this.deviceId = deviceId;
    this.terminator = terminator;
    this.iterator = iterator;
    this.queue = queue;
}
 
Example #26
Source File: FitConfig.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
/**
 * Set the validation data from a {@code DataSetIterator}. A non-null iterator is wrapped in a
 * {@code MultiDataSetIteratorAdapter}; null is passed on as a null MultiDataSetIterator.
 *
 * @param validationData the validation data, may be null
 * @return the result of the MultiDataSetIterator overload
 */
public FitConfig validate(DataSetIterator validationData) {
    final MultiDataSetIterator adapted =
            (validationData != null) ? new MultiDataSetIteratorAdapter(validationData) : null;
    return validate(adapted);
}
 
Example #27
Source File: TestComputationGraphNetwork.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
/**
 * Fits a small dense + softmax graph on iris.txt twice: first by handing the whole
 * MultiDataSet iterator to {@code fit}, then, after resetting the reader, batch by batch.
 */
@Test(timeout = 300000)
public void testIrisFitMultiDataSetIterator() throws Exception {
    RecordReader irisReader = new CSVRecordReader(0, ',');
    irisReader.initialize(new FileSplit(Resources.asFile("iris.txt")));

    // columns 0..3 are the inputs, column 4 is the class (3 classes, one-hot)
    MultiDataSetIterator wholeIterator = new RecordReaderMultiDataSetIterator.Builder(10)
            .addReader("iris", irisReader)
            .addInput("iris", 0, 3)
            .addOutputOneHot("iris", 4, 3)
            .build();

    ComputationGraphConfiguration graphConf = new NeuralNetConfiguration.Builder()
            .updater(new Sgd(0.1))
            .graphBuilder()
            .addInputs("in")
            .addLayer("dense", new DenseLayer.Builder().nIn(4).nOut(2).build(), "in")
            .addLayer("out",
                    new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                            .activation(Activation.SOFTMAX).nIn(2).nOut(3).build(),
                    "dense")
            .setOutputs("out")
            .build();

    ComputationGraph graph = new ComputationGraph(graphConf);
    graph.init();

    // first pass: fit on the iterator as a whole
    graph.fit(wholeIterator);

    // second pass: rebuild the iterator on the reset reader and fit one batch at a time
    irisReader.reset();
    MultiDataSetIterator batchIterator = new RecordReaderMultiDataSetIterator.Builder(10)
            .addReader("iris", irisReader)
            .addInput("iris", 0, 3)
            .addOutputOneHot("iris", 4, 3)
            .build();
    for (; batchIterator.hasNext(); ) {
        graph.fit(batchIterator.next());
    }
}
 
Example #28
Source File: ClassificationEvaluator.java    From deeplearning4j with Apache License 2.0 4 votes vote down vote up
/**
 * @return the two supported data types: {@code DataSetIterator} and {@code MultiDataSetIterator}
 */
@Override
public List<Class<?>> getSupportedDataTypes() {
    final List<Class<?>> supported =
            Arrays.<Class<?>>asList(DataSetIterator.class, MultiDataSetIterator.class);
    return supported;
}
 
Example #29
Source File: ROCScoreCalculator.java    From deeplearning4j with Apache License 2.0 4 votes vote down vote up
/**
 * Creates a calculator for the given ROC type and metric.
 *
 * @param type     ROC type stored on this calculator
 * @param metric   metric stored on this calculator
 * @param iterator iterator passed to the superclass constructor
 */
public ROCScoreCalculator(ROCType type, Metric metric, MultiDataSetIterator iterator) {
    super(iterator);
    this.metric = metric;
    this.type = type;
}
 
Example #30
Source File: ROCScoreCalculator.java    From deeplearning4j with Apache License 2.0 4 votes vote down vote up
/**
 * Creates a calculator for the given ROC type, using {@code Metric.AUC} as the metric.
 *
 * @param type     ROC type, passed on to the three-argument constructor
 * @param iterator iterator, passed on to the three-argument constructor
 */
public ROCScoreCalculator(ROCType type, MultiDataSetIterator iterator){
    this(type, Metric.AUC, iterator);
}