Java Code Examples for org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator#reset()

The following examples show how to use org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator#reset(). You can vote up the examples you like or vote down the ones you don't, and you can follow the links above each example to the original project or source file. Related API usage is listed in the sidebar.
Example 1
Source File: TrainUtil.java · From FancyBing (GNU General Public License v3.0) · 6 votes
/**
 * Evaluates a two-output {@link ComputationGraph} on every batch of {@code testData}.
 *
 * <p>Output 0 is scored with an {@link Evaluation} (classification) and output 1 with a
 * single-column {@link RegressionEvaluation}. Both stats reports are logged, together with the
 * total inference time in nanoseconds and an average per example.
 *
 * <p>{@code testData} is reset before this method returns, including when evaluation fails
 * part-way through, so the caller can reuse it.
 *
 * @param model     model to evaluate; must be a {@link ComputationGraph}
 * @param outputNum number of classes passed to {@code createLabels} for the classification output
 * @param testData  iterator over the evaluation batches
 * @param topN      top-N setting passed to {@link Evaluation}
 * @param batchSize batch size, used only to estimate the example count in the timing log line
 * @return classification accuracy of output 0
 */
public static double evaluate(Model model, int outputNum, MultiDataSetIterator testData, int topN, int batchSize) {
    log.info("Evaluate model....");
    ComputationGraph graph = (ComputationGraph) model;
    Evaluation clsEval = new Evaluation(createLabels(outputNum), topN);
    RegressionEvaluation valueRegEval1 = new RegressionEvaluation(1);
    int count = 0;
    long consume = 0;

    try {
        while (testData.hasNext()) {
            MultiDataSet ds = testData.next();
            // Time only the forward pass, not data loading or metric bookkeeping
            long begin = System.nanoTime();
            INDArray[] output = graph.output(false, ds.getFeatures());
            consume += System.nanoTime() - begin;
            clsEval.eval(ds.getLabels(0), output[0]);
            valueRegEval1.eval(ds.getLabels(1), output[1]);
            count++;
        }
    } finally {
        // Always rewind so a failed evaluation does not leave the iterator exhausted
        testData.reset();
    }

    String stats = clsEval.stats();
    int pos = stats.indexOf("===");
    // Skip any preamble before the first "===" section; if the marker is absent
    // (e.g. no data was evaluated), keep the whole report instead of throwing
    stats = "\n" + (pos >= 0 ? stats.substring(pos) : stats);
    log.info(stats);
    log.info(valueRegEval1.stats());

    // Estimate only: the final batch may be smaller than batchSize
    long examples = (long) count * batchSize;
    float averageMicros = examples > 0 ? (float) consume / examples / 1000 : 0f;
    log.info("Evaluate time: " + consume + " count: " + examples + " average: " + averageMicros);
    return clsEval.accuracy();
}
 
Example 2
Source File: AbstractMultiDataSetNormalizer.java · From nd4j (Apache License 2.0) · 6 votes
/**
 * Fits this normalizer by making one full pass over {@code iterator}.
 *
 * <p>The iterator is reset before the pass. Each batch is fed to {@code fitPartial}, which
 * accumulates into the feature and label builder lists. The feature statistics are always
 * rebuilt; the label statistics are rebuilt only when {@code isFitLabel()} is true. The iterator
 * is left exhausted when this method returns.
 *
 * @param iterator for the data to iterate over
 */
public void fit(@NonNull MultiDataSetIterator iterator) {
    List<S.Builder> featureBuilders = new ArrayList<>();
    List<S.Builder> labelBuilders = new ArrayList<>();

    for (iterator.reset(); iterator.hasNext(); ) {
        fitPartial(iterator.next(), featureBuilders, labelBuilders);
    }

    featureStats = buildList(featureBuilders);
    if (!isFitLabel()) {
        return;
    }
    labelStats = buildList(labelBuilders);
}
 
Example 3
Source File: AbstractMultiDataSetNormalizer.java · From deeplearning4j (Apache License 2.0) · 6 votes
/**
 * Fits this normalizer by making one full pass over {@code iterator}.
 *
 * <p>The iterator is reset before the pass. Each batch is fed to {@code fitPartial}, which
 * accumulates into the feature and label builder lists. The feature statistics are always
 * rebuilt; the label statistics are rebuilt only when {@code isFitLabel()} is true. The iterator
 * is left exhausted when this method returns.
 *
 * @param iterator for the data to iterate over
 */
public void fit(@NonNull MultiDataSetIterator iterator) {
    List<S.Builder> featureBuilders = new ArrayList<>();
    List<S.Builder> labelBuilders = new ArrayList<>();

    for (iterator.reset(); iterator.hasNext(); ) {
        fitPartial(iterator.next(), featureBuilders, labelBuilders);
    }

    featureStats = buildList(featureBuilders);
    if (!isFitLabel()) {
        return;
    }
    labelStats = buildList(labelBuilders);
}
 
Example 4
Source File: Main.java · From twse-captcha-solver-dl4j (MIT License) · 5 votes
/**
 * Runs {@code model} over every batch of {@code iterator} and compares each predicted captcha
 * string with the expected one.
 *
 * <p>Each model output and each label array corresponds to one character position. The
 * character for a position is the arg-max index mapped through {@code labelList}. Every
 * comparison is logged, and the totals are printed to standard output. The iterator is reset
 * before returning, including when prediction fails part-way through.
 *
 * @param model    trained graph with one output per character position
 * @param iterator validation data; labels presumably one-hot per position (confirm with producer)
 */
public static void modelPredict(ComputationGraph model, MultiDataSetIterator iterator) {
  int sumCount = 0;
  int correctCount = 0;

  List<String> labelList =
      Arrays.asList(
          "0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "A", "B", "C", "D", "E", "F", "G",
          "H", "I", "J", "K", "L", "M", "N", "O", "P", "Q", "R", "S", "T", "U", "V", "W", "X",
          "Y", "Z");

  try {
    while (iterator.hasNext()) {
      MultiDataSet mds = iterator.next();
      INDArray[] output = model.output(mds.getFeatures());
      INDArray[] labels = mds.getLabels();
      // The final batch may hold fewer rows than batchSize
      int dataNum = Math.min(batchSize, output[0].rows());
      for (int dataIndex = 0; dataIndex < dataNum; dataIndex++) {
        StringBuilder realBuilder = new StringBuilder();
        StringBuilder predictedBuilder = new StringBuilder();
        // One model output per character position (previously hard-coded to 5)
        for (int digit = 0; digit < output.length; digit++) {
          INDArray preOutput = output[digit].getRow(dataIndex);
          predictedBuilder.append(labelList.get(Nd4j.argMax(preOutput, 1).getInt(0)));

          INDArray realLabel = labels[digit].getRow(dataIndex);
          realBuilder.append(labelList.get(Nd4j.argMax(realLabel, 1).getInt(0)));
        }
        String reLabel = realBuilder.toString();
        String peLabel = predictedBuilder.toString();
        boolean correct = peLabel.equals(reLabel);
        if (correct) {
          correctCount++;
        }
        sumCount++;
        logger.info("real image {}  prediction {} status {}", reLabel, peLabel, correct);
      }
    }
  } finally {
    // Rewind even on failure so the iterator can be reused by the caller
    iterator.reset();
  }
  System.out.println(
      "validate result : sum count =" + sumCount + " correct count=" + correctCount);
}
 
Example 5
Source File: MultiNormalizerHybrid.java · From nd4j (Apache License 2.0) · 5 votes
/**
 * Collects normalization statistics by making one full pass over {@code iterator}.
 *
 * <p>The iterator is reset first. Each batch is handed to {@code fitPartial}, which accumulates
 * into two {@code Integer}-keyed builder maps, one for inputs and one for outputs. Both maps are
 * then turned into the stored input and output statistics via {@code buildAllStats}.
 *
 * @param iterator the iterator to use for collecting statistics
 */
@Override
public void fit(@NonNull MultiDataSetIterator iterator) {
    Map<Integer, NormalizerStats.Builder> inputBuilders = new HashMap<>();
    Map<Integer, NormalizerStats.Builder> outputBuilders = new HashMap<>();

    for (iterator.reset(); iterator.hasNext(); ) {
        fitPartial(iterator.next(), inputBuilders, outputBuilders);
    }

    inputStats = buildAllStats(inputBuilders);
    outputStats = buildAllStats(outputBuilders);
}
 
Example 6
Source File: MultiNormalizerHybrid.java · From deeplearning4j (Apache License 2.0) · 5 votes
/**
 * Collects normalization statistics by making one full pass over {@code iterator}.
 *
 * <p>The iterator is reset first. Each batch is handed to {@code fitPartial}, which accumulates
 * into two {@code Integer}-keyed builder maps, one for inputs and one for outputs. Both maps are
 * then turned into the stored input and output statistics via {@code buildAllStats}.
 *
 * @param iterator the iterator to use for collecting statistics
 */
@Override
public void fit(@NonNull MultiDataSetIterator iterator) {
    Map<Integer, NormalizerStats.Builder> inputBuilders = new HashMap<>();
    Map<Integer, NormalizerStats.Builder> outputBuilders = new HashMap<>();

    for (iterator.reset(); iterator.hasNext(); ) {
        fitPartial(iterator.next(), inputBuilders, outputBuilders);
    }

    inputStats = buildAllStats(inputBuilders);
    outputStats = buildAllStats(outputBuilders);
}
 
Example 7
Source File: EarlyTerminationMultiDataSetIteratorTest.java · From deeplearning4j (Apache License 2.0) · 5 votes
/**
 * Checks that the early-termination wrapper enforces its own minibatch limit.
 *
 * <p>After the single permitted {@code next(10)} call, the wrapped iterator is reset so that it
 * has data again. A further {@code next(10)} on the wrapper is still expected to throw a
 * {@link RuntimeException}.
 */
@Test
public void testCallstoNextNotAllowed() throws IOException {
    final int maxMinibatches = 1;

    MultiDataSetIterator underlying =
                    new MultiDataSetIteratorAdapter(new MnistDataSetIterator(minibatchSize, numExamples));
    EarlyTerminationMultiDataSetIterator limited =
                    new EarlyTerminationMultiDataSetIterator(underlying, maxMinibatches);

    // Use up the only minibatch the wrapper allows
    limited.next(10);
    // Rewind the source; the wrapper's limit must still apply afterwards
    underlying.reset();
    exception.expect(RuntimeException.class);
    limited.next(10);
}
 
Example 8
Source File: LoaderIteratorTests.java · From deeplearning4j (Apache License 2.0) · 5 votes
/**
 * Checks that {@link MultiDataSetLoaderIterator} loads each of the three paths exactly once, and
 * that it reports more data after {@code reset()}.
 *
 * <p>The test runs two passes:
 * <ul>
 *   <li>Without a {@link Random}, the values must arrive in the listed order.</li>
 *   <li>With a seeded {@link Random} (now actually passed to the iterator), the order is not
 *       asserted, but the set of loaded values must still be {0, 1, 3}.</li>
 * </ul>
 */
@Test
public void testMDSLoaderIter(){

    for(boolean r : new boolean[]{false, true}) {
        List<String> l = Arrays.asList("3", "0", "1");
        // Previously created but never used (null was hard-coded); presumably enables shuffling
        Random rng = r ? new Random(12345) : null;
        MultiDataSetIterator iter = new MultiDataSetLoaderIterator(l, rng, new Loader<MultiDataSet>() {
            @Override
            public MultiDataSet load(Source source) throws IOException {
                INDArray i = Nd4j.scalar(Integer.valueOf(source.getPath()));
                return new org.nd4j.linalg.dataset.MultiDataSet(i, i);
            }
        }, new LocalFileSourceFactory());

        int count = 0;
        int[] exp = {3, 0, 1};
        int[] seen = new int[exp.length];
        while (iter.hasNext()) {
            MultiDataSet ds = iter.next();
            int value = ds.getFeatures()[0].getInt(0);
            if(!r) {
                assertEquals(exp[count], value);
            }
            seen[count] = value;
            count++;
        }
        assertEquals(3, count);
        // Order-independent check so the shuffled pass also verifies content
        Arrays.sort(seen);
        assertTrue(Arrays.equals(new int[]{0, 1, 3}, seen));

        iter.reset();
        assertTrue(iter.hasNext());
    }
}