Java Code Examples for org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator#hasNext()

The following examples show how to use org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator#hasNext(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You can also explore related API usage in the sidebar.
Example 1
Source File: TrainUtil.java    From FancyBing with GNU General Public License v3.0 6 votes vote down vote up
/**
 * Runs the model over every batch of {@code testData}, logs classification and regression
 * statistics together with inference timing, and returns the classification accuracy.
 *
 * <p>The model is cast to {@code ComputationGraph}. Its first output is evaluated against label
 * array 0 with a top-N {@code Evaluation}; its second output is evaluated against label array 1
 * with a single-column {@code RegressionEvaluation}. Only the time spent inside
 * {@code output(...)} is accumulated. The iterator is reset after the pass, not before it.
 *
 * @param model     model to evaluate; cast to {@code ComputationGraph} for each batch
 * @param outputNum passed to {@code createLabels} to build the label list for the evaluation
 * @param testData  batches to evaluate
 * @param topN      top-N setting handed to the {@code Evaluation}
 * @param batchSize nominal examples per batch; used only for the logged example count and average
 * @return the accuracy reported by the classification evaluation
 */
public static double evaluate(Model model, int outputNum, MultiDataSetIterator testData, int topN, int batchSize) {
    log.info("Evaluate model....");
    Evaluation clsEval = new Evaluation(createLabels(outputNum), topN);
    RegressionEvaluation valueRegEval1 = new RegressionEvaluation(1);
    int count = 0;

    // Total nanoseconds spent in the forward pass only (data loading and evaluation excluded).
    long consume = 0;
    while (testData.hasNext()) {
        MultiDataSet ds = testData.next();
        long begin = System.nanoTime();
        INDArray[] output = ((ComputationGraph) model).output(false, ds.getFeatures());
        consume += System.nanoTime() - begin;
        clsEval.eval(ds.getLabels(0), output[0]);
        valueRegEval1.eval(ds.getLabels(1), output[1]);
        count++;
    }

    String stats = clsEval.stats();
    int pos = stats.indexOf("===");
    // Trim everything before the first "===" marker; if the marker is absent, log the whole
    // report instead of failing on substring(-1).
    stats = "\n" + (pos >= 0 ? stats.substring(pos) : stats);
    log.info(stats);
    log.info(valueRegEval1.stats());
    testData.reset();

    // Assumes every batch held exactly batchSize examples; a short final batch is over-counted.
    int examples = count * batchSize;
    // Nanoseconds per example divided by 1000. With zero batches this is 0f/0 and prints NaN.
    float average = (float) consume / examples / 1000;
    log.info("Evaluate time: " + consume + " count: " + examples + " average: " + average);
    return clsEval.accuracy();
}
 
Example 2
Source File: AbstractMultiDataSetNormalizer.java    From nd4j with Apache License 2.0 6 votes vote down vote up
/**
 * Gathers normalization statistics from every {@code MultiDataSet} the iterator yields.
 * The iterator is reset before the pass begins. Label statistics are only built when
 * {@code isFitLabel()} is true.
 *
 * @param iterator for the data to iterate over
 */
public void fit(@NonNull MultiDataSetIterator iterator) {
    List<S.Builder> featureBuilders = new ArrayList<>();
    List<S.Builder> labelBuilders = new ArrayList<>();

    // Rewind first, then feed each batch into the running builders.
    for (iterator.reset(); iterator.hasNext(); ) {
        fitPartial(iterator.next(), featureBuilders, labelBuilders);
    }

    featureStats = buildList(featureBuilders);
    if (!isFitLabel()) {
        return;
    }
    labelStats = buildList(labelBuilders);
}
 
Example 3
Source File: AsyncMultiDataSetIterator.java    From deeplearning4j with Apache License 2.0 6 votes vote down vote up
public AsyncMultiDataSetIterator(MultiDataSetIterator iterator, int queueSize, BlockingQueue<MultiDataSet> queue,
                                 boolean useWorkspace, DataSetCallback callback, Integer deviceId) {

    if (queueSize < 2)
        queueSize = 2;

    this.callback = callback;
    this.buffer = queue;
    this.backedIterator = iterator;
    this.useWorkspaces = useWorkspace;
    this.prefetchSize = queueSize;
    this.workspaceId = "AMDSI_ITER-" + java.util.UUID.randomUUID().toString();
    this.deviceId = deviceId;

    if (iterator.resetSupported() && !iterator.hasNext())
        this.backedIterator.reset();

    this.thread = new AsyncPrefetchThread(buffer, iterator, terminator, deviceId);

    thread.setDaemon(true);
    thread.start();
}
 
Example 4
Source File: AbstractMultiDataSetNormalizer.java    From deeplearning4j with Apache License 2.0 6 votes vote down vote up
/**
 * Fits this normalizer on all data produced by the iterator. The iterator is reset first,
 * each batch is passed to {@code fitPartial}, and the accumulated builders are then turned
 * into feature statistics and, if {@code isFitLabel()} is true, label statistics.
 *
 * @param iterator for the data to iterate over
 */
public void fit(@NonNull MultiDataSetIterator iterator) {
    List<S.Builder> featureAccumulators = new ArrayList<>();
    List<S.Builder> labelAccumulators = new ArrayList<>();

    iterator.reset();
    for (; iterator.hasNext(); ) {
        MultiDataSet batch = iterator.next();
        fitPartial(batch, featureAccumulators, labelAccumulators);
    }

    featureStats = buildList(featureAccumulators);
    boolean labelsToo = isFitLabel();
    if (labelsToo) {
        labelStats = buildList(labelAccumulators);
    }
}
 
Example 5
Source File: ScoreUtil.java    From deeplearning4j with Apache License 2.0 6 votes vote down vote up
/**
 * Score based on the loss function.
 *
 * <p>Each batch's score is weighted by its example count, taken as the size of dimension 0 of
 * the first feature array, and the weighted scores are summed. The iterator is consumed as-is;
 * it is neither reset before nor after the pass.
 *
 * @param model the model to score with
 * @param testData the test data to score
 * @param average whether to divide the summed score by the total number of examples
 * @return the summed weighted score, or that sum divided by the total example count when
 *         {@code average} is true (NaN if averaging is requested and no examples were seen)
 */
public static double score(ComputationGraph model, MultiDataSetIterator testData, boolean average) {
    //TODO: do this properly taking into account division by N, L1/L2 etc
    double sumScore = 0.0;
    // Kept as long: the per-batch count is a long, and an int accumulator would be silently
    // narrowed by the compound assignment.
    long totalExamples = 0L;
    while (testData.hasNext()) {
        MultiDataSet ds = testData.next();
        long numExamples = ds.getFeatures(0).size(0);
        sumScore += numExamples * model.score(ds);
        totalExamples += numExamples;
    }

    if (!average) {
        return sumScore;
    }
    return sumScore / totalExamples;
}
 
Example 6
Source File: Main.java    From twse-captcha-solver-dl4j with MIT License 5 votes vote down vote up
/**
 * Runs the model over every batch of the iterator, decodes five output arrays per example
 * into a five-character string, and compares it with the string decoded from the labels.
 * Each example is logged, the iterator is reset at the end, and the totals are printed to
 * standard output.
 *
 * @param model    graph used to produce the predictions
 * @param iterator batches to predict on; reset once the pass is complete
 */
public static void modelPredict(ComputationGraph model, MultiDataSetIterator iterator) {
  // Index-to-character table used to decode the arg-max index of each output/label row.
  List<String> alphabet =
      Arrays.asList(
          "0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "A", "B", "C", "D", "E", "F", "G",
          "H", "I", "J", "K", "L", "M", "N", "O", "P", "Q", "R", "S", "T", "U", "V", "W", "X",
          "Y", "Z");

  int total = 0;
  int matches = 0;

  while (iterator.hasNext()) {
    MultiDataSet batch = iterator.next();
    INDArray[] predictions = model.output(batch.getFeatures());
    INDArray[] targets = batch.getLabels();
    // Never read past the rows actually present in this batch.
    int rowCount = Math.min(batchSize, predictions[0].rows());

    for (int row = 0; row < rowCount; row++) {
      StringBuilder predictedText = new StringBuilder();
      StringBuilder actualText = new StringBuilder();
      for (int digit = 0; digit < 5; digit++) {
        INDArray predictedRow = predictions[digit].getRow(row);
        predictedText.append(alphabet.get(Nd4j.argMax(predictedRow, 1).getInt(0)));

        INDArray actualRow = targets[digit].getRow(row);
        actualText.append(alphabet.get(Nd4j.argMax(actualRow, 1).getInt(0)));
      }

      String peLabel = predictedText.toString();
      String reLabel = actualText.toString();
      boolean same = peLabel.equals(reLabel);
      if (same) {
        matches++;
      }
      total++;
      logger.info("real image {}  prediction {} status {}", reLabel, peLabel, same);
    }
  }

  iterator.reset();
  System.out.println("validate result : sum count =" + total + " correct count=" + matches);
}
 
Example 7
Source File: MultiNormalizerHybrid.java    From nd4j with Apache License 2.0 5 votes vote down vote up
/**
 * Resets the iterator, passes every element to {@code fitPartial} to accumulate
 * per-index statistics builders, and then builds the input and output statistics.
 *
 * @param iterator the iterator to use for collecting statistics
 */
@Override
public void fit(@NonNull MultiDataSetIterator iterator) {
    Map<Integer, NormalizerStats.Builder> inputBuilders = new HashMap<>();
    Map<Integer, NormalizerStats.Builder> outputBuilders = new HashMap<>();

    // Rewind, then accumulate one batch at a time.
    for (iterator.reset(); iterator.hasNext(); ) {
        fitPartial(iterator.next(), inputBuilders, outputBuilders);
    }

    inputStats = buildAllStats(inputBuilders);
    outputStats = buildAllStats(outputBuilders);
}
 
Example 8
Source File: MultiNormalizerHybrid.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
/**
 * Walks the whole dataset once, from a freshly reset iterator, accumulating statistics
 * builders keyed by integer index, then builds the input and output statistics from them.
 *
 * @param iterator the iterator to use for collecting statistics
 */
@Override
public void fit(@NonNull MultiDataSetIterator iterator) {
    Map<Integer, NormalizerStats.Builder> featureAccumulators = new HashMap<>();
    Map<Integer, NormalizerStats.Builder> labelAccumulators = new HashMap<>();

    iterator.reset();
    for (; iterator.hasNext(); ) {
        fitPartial(iterator.next(), featureAccumulators, labelAccumulators);
    }

    inputStats = buildAllStats(featureAccumulators);
    outputStats = buildAllStats(labelAccumulators);
}
 
Example 9
Source File: TestComputationGraphNetwork.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
/**
 * Fits a small two-layer computation graph on iris.txt twice: once by handing the whole
 * iterator to {@code fit}, then, after resetting the record reader, by fitting batch by batch.
 * The test contains no assertions; it passes when no exception is thrown within the timeout.
 */
@Test(timeout = 300000)
public void testIrisFitMultiDataSetIterator() throws Exception {

    RecordReader reader = new CSVRecordReader(0, ',');
    reader.initialize(new FileSplit(Resources.asFile("iris.txt")));

    MultiDataSetIterator wholeSet = new RecordReaderMultiDataSetIterator.Builder(10)
            .addReader("iris", reader)
            .addInput("iris", 0, 3)
            .addOutputOneHot("iris", 4, 3)
            .build();

    ComputationGraphConfiguration graphConfig = new NeuralNetConfiguration.Builder()
            .updater(new Sgd(0.1))
            .graphBuilder()
            .addInputs("in")
            .addLayer("dense",
                    new DenseLayer.Builder().nIn(4).nOut(2).build(),
                    "in")
            .addLayer("out",
                    new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                            .activation(Activation.SOFTMAX)
                            .nIn(2)
                            .nOut(3)
                            .build(),
                    "dense")
            .setOutputs("out")
            .build();

    ComputationGraph graph = new ComputationGraph(graphConfig);
    graph.init();

    // First pass: fit on the iterator as a whole.
    graph.fit(wholeSet);

    // Second pass: rewind the reader and fit one MultiDataSet at a time.
    reader.reset();
    MultiDataSetIterator batchByBatch = new RecordReaderMultiDataSetIterator.Builder(10)
            .addReader("iris", reader)
            .addInput("iris", 0, 3)
            .addOutputOneHot("iris", 4, 3)
            .build();
    for (; batchByBatch.hasNext(); ) {
        graph.fit(batchByBatch.next());
    }
}
 
Example 10
Source File: RandomDataSetIteratorTest.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
/**
 * Builds a {@code RandomMultiDataSetIterator} with five batches (two feature arrays, one label
 * array) and checks, for each batch, the array counts, the shapes, and the value ranges
 * implied by the requested value kinds.
 */
@Test
public void testMDSI(){
    Nd4j.getRandom().setSeed(12345);
    MultiDataSetIterator iterator = new RandomMultiDataSetIterator.Builder(5)
            .addFeatures(new long[]{3,4}, RandomMultiDataSetIterator.Values.INTEGER_0_100)
            .addFeatures(new long[]{3,5}, RandomMultiDataSetIterator.Values.BINARY)
            .addLabels(new long[]{3,6}, RandomMultiDataSetIterator.Values.ZEROS)
        .build();

    int batches = 0;
    for (; iterator.hasNext(); batches++) {
        MultiDataSet batch = iterator.next();

        assertEquals(2, batch.numFeatureArrays());
        assertEquals(1, batch.numLabelsArrays());
        assertArrayEquals(new long[]{3,4}, batch.getFeatures(0).shape());
        assertArrayEquals(new long[]{3,5}, batch.getFeatures(1).shape());
        assertArrayEquals(new long[]{3,6}, batch.getLabels(0).shape());

        // INTEGER_0_100 features: within [0, 100], and the maximum must exceed 2.
        double intMin = batch.getFeatures(0).minNumber().doubleValue();
        double intMax = batch.getFeatures(0).maxNumber().doubleValue();
        assertTrue(intMin >= 0 && intMax <= 100.0 && intMax > 2.0);

        // BINARY features: both 0 and 1 must be present.
        double binMin = batch.getFeatures(1).minNumber().doubleValue();
        double binMax = batch.getFeatures(1).maxNumber().doubleValue();
        assertTrue(binMin == 0.0 && binMax == 1.0);

        // ZEROS labels sum to exactly zero.
        assertEquals(0.0, batch.getLabels(0).sumNumber().doubleValue(), 0.0);
    }
    assertEquals(5, batches);
}
 
Example 11
Source File: LoaderIteratorTests.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
/**
 * Checks that a {@code MultiDataSetLoaderIterator} over three paths yields three elements,
 * that they come back in the given order in the first loop pass, and that the iterator has
 * elements again after {@code reset()}.
 */
@Test
public void testMDSLoaderIter(){

    for(boolean randomized : new boolean[]{false, true}) {
        List<String> paths = Arrays.asList("3", "0", "1");
        // NOTE(review): rng is created but never handed to the iterator below (null is passed
        // in its place), so both passes appear to run unshuffled; confirm whether that is intended.
        Random rng = randomized ? new Random(12345) : null;
        MultiDataSetIterator iterator = new MultiDataSetLoaderIterator(paths, null, new Loader<MultiDataSet>() {
            @Override
            public MultiDataSet load(Source source) throws IOException {
                // The "path" is just the integer value to wrap as both feature and label.
                INDArray value = Nd4j.scalar(Integer.valueOf(source.getPath()));
                return new org.nd4j.linalg.dataset.MultiDataSet(value, value);
            }
        }, new LocalFileSourceFactory());

        int[] expected = {3, 0, 1};
        int seen = 0;
        for (; iterator.hasNext(); seen++) {
            MultiDataSet element = iterator.next();
            if (randomized) {
                continue;
            }
            assertEquals(expected[seen], element.getFeatures()[0].getInt(0));
        }
        assertEquals(3, seen);

        iterator.reset();
        assertTrue(iterator.hasNext());
    }
}
 
Example 12
Source File: ScoreFlatMapFunctionCGMultiDataSet.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
/**
 * Scores every minibatch from the incoming iterator with a network rebuilt from the JSON
 * configuration and the broadcast parameters.
 *
 * @param dataSetIterator the MultiDataSets to score
 * @return an iterator of (example count, score * example count) pairs, one per minibatch;
 *         a single (0, 0.0) pair when the input is empty
 * @throws IllegalStateException if the broadcast parameter length differs from the network's
 *                               parameter count
 */
@Override
public Iterator<Tuple2<Long, Double>> call(Iterator<MultiDataSet> dataSetIterator) throws Exception {
    boolean nothingToScore = !dataSetIterator.hasNext();
    if (nothingToScore) {
        return Collections.singletonList(new Tuple2<>(0L, 0.0)).iterator();
    }

    MultiDataSetIterator batches = new IteratorMultiDataSetIterator(dataSetIterator, minibatchSize); //Does batching where appropriate

    ComputationGraph graph = new ComputationGraph(ComputationGraphConfiguration.fromJson(json));
    graph.init();
    INDArray broadcastParams = params.value().unsafeDuplication(); //.value() is shared by all executors on single machine -> OK, as params are not changed in score function
    if (broadcastParams.length() != graph.numParams(false)) {
        throw new IllegalStateException(
                        "Network did not have same number of parameters as the broadcast set parameters");
    }
    graph.setParams(broadcastParams);

    List<Tuple2<Long, Double>> results = new ArrayList<>();
    for (; batches.hasNext(); ) {
        MultiDataSet batch = batches.next();
        double batchScore = graph.score(batch, false);

        // Example count is the size of dimension 0 of the first feature array.
        long exampleCount = batch.getFeatures(0).size(0);
        results.add(new Tuple2<>(exampleCount, batchScore * exampleCount));
    }

    Nd4j.getExecutioner().commit();

    return results.iterator();
}
 
Example 13
Source File: TestSparkComputationGraph.java    From deeplearning4j with Apache License 2.0 4 votes vote down vote up
/**
 * Trains a small computation graph on Spark, first from an RDD of {@code MultiDataSet}
 * read from iris.txt and then from an RDD of {@code DataSet} from the iris iterator.
 * The test contains no assertions; it passes when both fits complete without an exception.
 */
@Test
public void testBasic() throws Exception {

    JavaSparkContext sc = this.sc;

    RecordReader reader = new CSVRecordReader(0, ',');
    reader.initialize(new FileSplit(new ClassPathResource("iris.txt").getTempFileFromArchive()));
    MultiDataSetIterator irisMds = new RecordReaderMultiDataSetIterator.Builder(1)
                    .addReader("iris", reader)
                    .addInput("iris", 0, 3)
                    .addOutputOneHot("iris", 4, 3)
                    .build();

    // Drain the iterator into a list so it can be parallelized.
    List<MultiDataSet> multiDataSets = new ArrayList<>(150);
    for (; irisMds.hasNext(); ) {
        multiDataSets.add(irisMds.next());
    }

    ComputationGraphConfiguration graphConfig = new NeuralNetConfiguration.Builder()
                    .updater(new Sgd(0.1))
                    .graphBuilder()
                    .addInputs("in")
                    .addLayer("dense",
                                    new DenseLayer.Builder().nIn(4).nOut(2).build(),
                                    "in")
                    .addLayer("out",
                                    new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).nIn(2).nOut(3)
                                                    .build(),
                                    "dense")
                    .setOutputs("out")
                    .build();

    ComputationGraph graph = new ComputationGraph(graphConfig);
    graph.init();

    TrainingMaster trainingMaster = new ParameterAveragingTrainingMaster(true, numExecutors(), 1, 10, 1, 0);

    SparkComputationGraph sparkGraph = new SparkComputationGraph(sc, graph, trainingMaster);
    sparkGraph.setListeners(Collections.singleton((TrainingListener) new ScoreIterationListener(5)));

    JavaRDD<MultiDataSet> multiDataSetRdd = sc.parallelize(multiDataSets);
    sparkGraph.fitMultiDataSet(multiDataSetRdd);

    //Try: fitting using DataSet
    DataSetIterator irisDs = new IrisDataSetIterator(1, 150);
    List<DataSet> dataSets = new ArrayList<>();
    for (; irisDs.hasNext(); ) {
        dataSets.add(irisDs.next());
    }
    JavaRDD<DataSet> dataSetRdd = sc.parallelize(dataSets);

    sparkGraph.fit(dataSetRdd);
}