Java Code Examples for org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator#next()

The following examples show how to use `org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator#next()`. You can vote each example up or down, and follow the link above each example to its original project or source file. Related API usage is listed in the sidebar.
Example 1
Source file: TrainUtil.java, from FancyBing (GNU General Public License v3.0) — 6 votes
/**
 * Evaluates a two-output {@link ComputationGraph} on a test set and logs the results.
 *
 * <p>Output 0 is compared against label array 0 with an {@link Evaluation} (top-N aware).
 * Output 1 is compared against label array 1 with a single-column {@link RegressionEvaluation}.
 * Only the forward pass is timed. The iterator is reset after the pass.
 *
 * @param model     the model to evaluate; it is cast to {@link ComputationGraph}
 * @param outputNum number of classes passed to {@code createLabels} for the classification output
 * @param testData  iterator over the test set
 * @param topN      N for top-N accuracy in the classification evaluation
 * @param batchSize kept for backward compatibility only; the example count is now read from
 *                  each minibatch, so a short final minibatch is counted correctly
 * @return the accuracy reported by the classification {@link Evaluation}
 * @throws ClassCastException if {@code model} is not a {@link ComputationGraph}
 */
public static double evaluate(Model model, int outputNum, MultiDataSetIterator testData, int topN, int batchSize) {
    log.info("Evaluate model....");
    // Loop-invariant cast, hoisted out of the evaluation loop.
    ComputationGraph graph = (ComputationGraph) model;
    Evaluation clsEval = new Evaluation(createLabels(outputNum), topN);
    RegressionEvaluation valueRegEval1 = new RegressionEvaluation(1);
    long numExamples = 0;
    long consumeNanos = 0;

    while (testData.hasNext()) {
        MultiDataSet ds = testData.next();
        // Time only the forward pass, not data loading or metric bookkeeping.
        long begin = System.nanoTime();
        INDArray[] output = graph.output(false, ds.getFeatures());
        consumeNanos += System.nanoTime() - begin;
        clsEval.eval(ds.getLabels(0), output[0]);
        valueRegEval1.eval(ds.getLabels(1), output[1]);
        // Count the real examples: the last minibatch may hold fewer than batchSize.
        numExamples += ds.getFeatures(0).size(0);
    }

    String stats = clsEval.stats();
    // Trim everything before the first "===" marker; keep the full text if the marker is
    // missing (substring(-1) would throw StringIndexOutOfBoundsException).
    int pos = stats.indexOf("===");
    stats = "\n" + (pos >= 0 ? stats.substring(pos) : stats);
    log.info(stats);
    log.info(valueRegEval1.stats());
    testData.reset();
    // Average is nanoseconds / examples / 1000, i.e. microseconds per example (NaN if no examples).
    float averageMicros = (float) consumeNanos / numExamples / 1000;
    log.info("Evaluate time: " + consumeNanos + " count: " + numExamples + " average: " + averageMicros);
    return clsEval.accuracy();
}
 
Example 2
Source file: AbstractMultiDataSetNormalizer.java, from nd4j (Apache License 2.0) — 6 votes
/**
 * Fits this normalizer's statistics with one complete pass over {@code iterator}.
 *
 * <p>The iterator is reset first. Every {@code MultiDataSet} it yields is fed to
 * {@code fitPartial}. Feature statistics are always built afterwards; label statistics are
 * built only when {@code isFitLabel()} returns true.
 *
 * @param iterator the source of data to fit on
 */
public void fit(@NonNull MultiDataSetIterator iterator) {
    List<S.Builder> featureBuilders = new ArrayList<>();
    List<S.Builder> labelBuilders = new ArrayList<>();

    iterator.reset();
    while (iterator.hasNext()) {
        fitPartial(iterator.next(), featureBuilders, labelBuilders);
    }

    featureStats = buildList(featureBuilders);
    if (!isFitLabel()) {
        return;
    }
    labelStats = buildList(labelBuilders);
}
 
Example 3
Source file: AbstractMultiDataSetNormalizer.java, from deeplearning4j (Apache License 2.0) — 6 votes
/**
 * Fits this normalizer's statistics with one complete pass over {@code iterator}.
 *
 * <p>The iterator is reset first. Every {@code MultiDataSet} it yields is fed to
 * {@code fitPartial}. Feature statistics are always built afterwards; label statistics are
 * built only when {@code isFitLabel()} returns true.
 *
 * @param iterator the source of data to fit on
 */
public void fit(@NonNull MultiDataSetIterator iterator) {
    List<S.Builder> featureBuilders = new ArrayList<>();
    List<S.Builder> labelBuilders = new ArrayList<>();

    iterator.reset();
    while (iterator.hasNext()) {
        fitPartial(iterator.next(), featureBuilders, labelBuilders);
    }

    featureStats = buildList(featureBuilders);
    if (!isFitLabel()) {
        return;
    }
    labelStats = buildList(labelBuilders);
}
 
Example 4
Source file: ScoreUtil.java, from deeplearning4j (Apache License 2.0) — 6 votes
/**
 * Scores a model on a test set using the model's loss function.
 *
 * <p>Each minibatch score from {@code model.score(ds)} is multiplied by that minibatch's example
 * count, taken as dimension 0 of the first feature array. This assumes {@code score} returns a
 * per-example average — TODO confirm. The iterator is consumed from its current position and
 * is not reset.
 *
 * @param model    the model to score with
 * @param testData the test data to score
 * @param average  if true, divide the weighted sum by the total example count; if false, return
 *                 the weighted sum itself
 * @return the score for the given test set; when {@code average} is true and the iterator yields
 *         no examples, the result is {@code NaN} (0.0 / 0)
 */
public static double score(ComputationGraph model, MultiDataSetIterator testData, boolean average) {
    //TODO: do this properly taking into account division by N, L1/L2 etc
    double sumScore = 0.0;
    // long rather than int: numExamples is a long, and "int += long" silently narrows and can overflow.
    long totalExamples = 0;
    while (testData.hasNext()) {
        MultiDataSet ds = testData.next();
        long numExamples = ds.getFeatures(0).size(0);
        sumScore += numExamples * model.score(ds);
        totalExamples += numExamples;
    }

    if (!average)
        return sumScore;
    return sumScore / totalExamples;
}
 
Example 5
Source file: Main.java, from twse-captcha-solver-dl4j (MIT License) — 5 votes
/**
 * Runs the captcha model over every minibatch of {@code iterator}. Each predicted character
 * sequence is logged next to its true sequence, and the total and correct counts are printed.
 *
 * <p>Model output {@code i} and label array {@code i} are treated as the class scores for
 * character position {@code i}. Each character is decoded by taking the arg-max index into the
 * 36-symbol list 0-9, A-Z. A prediction counts as correct only if every position matches. The
 * iterator is reset afterwards.
 *
 * @param model    the trained multi-output captcha model
 * @param iterator the evaluation data; its label arrays must line up with the model outputs
 */
public static void modelPredict(ComputationGraph model, MultiDataSetIterator iterator) {
  int sumCount = 0;
  int correctCount = 0;

  List<String> labelList =
      Arrays.asList(
          "0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "A", "B", "C", "D", "E", "F", "G",
          "H", "I", "J", "K", "L", "M", "N", "O", "P", "Q", "R", "S", "T", "U", "V", "W", "X",
          "Y", "Z");

  while (iterator.hasNext()) {
    MultiDataSet mds = iterator.next();
    INDArray[] output = model.output(mds.getFeatures());
    INDArray[] labels = mds.getLabels();
    // Use the actual minibatch row count. Capping it at a configured batchSize skipped rows
    // whenever a minibatch was larger than that value.
    int dataNum = output[0].rows();
    // One model output per character position (previously hard-coded to 5).
    int numDigits = output.length;
    for (int dataIndex = 0; dataIndex < dataNum; dataIndex++) {
      StringBuilder realBuilder = new StringBuilder(numDigits);
      StringBuilder predictedBuilder = new StringBuilder(numDigits);
      for (int digit = 0; digit < numDigits; digit++) {
        INDArray preOutput = output[digit].getRow(dataIndex);
        predictedBuilder.append(labelList.get(Nd4j.argMax(preOutput, 1).getInt(0)));

        INDArray realLabel = labels[digit].getRow(dataIndex);
        realBuilder.append(labelList.get(Nd4j.argMax(realLabel, 1).getInt(0)));
      }
      String reLabel = realBuilder.toString();
      String peLabel = predictedBuilder.toString();
      boolean correct = peLabel.equals(reLabel);
      if (correct) {
        correctCount++;
      }
      sumCount++;
      logger.info("real image {}  prediction {} status {}", reLabel, peLabel, correct);
    }
  }
  iterator.reset();
  System.out.println(
      "validate result : sum count =" + sumCount + " correct count=" + correctCount);
}
 
Example 6
Source file: RandomDataSetIteratorTest.java, from deeplearning4j (Apache License 2.0) — 5 votes
/**
 * Checks that RandomMultiDataSetIterator yields 5 minibatches. Each must have the requested
 * feature/label counts and shapes and the requested value distributions: integers in [0, 100],
 * binary values, and all-zero labels.
 */
@Test
public void testMDSI(){
    Nd4j.getRandom().setSeed(12345);
    MultiDataSetIterator iter = new RandomMultiDataSetIterator.Builder(5)
            .addFeatures(new long[]{3,4}, RandomMultiDataSetIterator.Values.INTEGER_0_100)
            .addFeatures(new long[]{3,5}, RandomMultiDataSetIterator.Values.BINARY)
            .addLabels(new long[]{3,6}, RandomMultiDataSetIterator.Values.ZEROS)
            .build();

    int numBatches = 0;
    while (iter.hasNext()) {
        numBatches++;
        MultiDataSet batch = iter.next();

        // Structure: two feature arrays and one label array, each with the requested shape
        assertEquals(2, batch.numFeatureArrays());
        assertEquals(1, batch.numLabelsArrays());
        assertArrayEquals(new long[]{3,4}, batch.getFeatures(0).shape());
        assertArrayEquals(new long[]{3,5}, batch.getFeatures(1).shape());
        assertArrayEquals(new long[]{3,6}, batch.getLabels(0).shape());

        // Integer features lie in [0, 100] and are not all tiny
        boolean intsInRange = batch.getFeatures(0).minNumber().doubleValue() >= 0
                && batch.getFeatures(0).maxNumber().doubleValue() <= 100.0
                && batch.getFeatures(0).maxNumber().doubleValue() > 2.0;
        assertTrue(intsInRange);

        // Binary features contain exactly the values 0 and 1
        boolean binaryValues = batch.getFeatures(1).minNumber().doubleValue() == 0.0
                && batch.getFeatures(1).maxNumber().doubleValue() == 1.0;
        assertTrue(binaryValues);

        // Labels are all zeros
        assertEquals(0.0, batch.getLabels(0).sumNumber().doubleValue(), 0.0);
    }
    assertEquals(5, numBatches);
}
 
Example 7
Source file: LoaderIteratorTests.java, from deeplearning4j (Apache License 2.0) — 5 votes
/**
 * Checks MultiDataSetLoaderIterator over the paths "3", "0", "1" with and without an RNG.
 *
 * <p>Without an RNG, values must come back in list order. In both modes, every path must be
 * returned exactly once. After {@code reset()}, the iterator must have data again.
 */
@Test
public void testMDSLoaderIter(){

    for(boolean r : new boolean[]{false, true}) {
        List<String> l = Arrays.asList("3", "0", "1");
        Random rng = r ? new Random(12345) : null;
        // Pass rng through to the iterator. It was previously hard-coded to null, so the
        // r == true case never exercised shuffling.
        MultiDataSetIterator iter = new MultiDataSetLoaderIterator(l, rng, new Loader<MultiDataSet>() {
            @Override
            public MultiDataSet load(Source source) throws IOException {
                // Each "path" is just an integer; load it as a scalar feature and label.
                INDArray i = Nd4j.scalar(Integer.valueOf(source.getPath()));
                return new org.nd4j.linalg.dataset.MultiDataSet(i, i);
            }
        }, new LocalFileSourceFactory());

        int count = 0;
        int[] exp = {3, 0, 1};
        int[] seen = new int[exp.length];
        while (iter.hasNext()) {
            MultiDataSet ds = iter.next();
            assertTrue(count < exp.length);
            int value = ds.getFeatures()[0].getInt(0);
            if(!r) {
                assertEquals(exp[count], value);
            }
            seen[count] = value;
            count++;
        }
        assertEquals(3, count);
        // Ordered or shuffled, every path must be visited exactly once.
        Arrays.sort(seen);
        assertTrue(Arrays.equals(new int[]{0, 1, 3}, seen));

        iter.reset();
        assertTrue(iter.hasNext());
    }
}
 
Example 8
Source file: ScoreFlatMapFunctionCGMultiDataSet.java, from deeplearning4j (Apache License 2.0) — 5 votes
/**
 * Scores one Spark partition of MultiDataSets with the broadcast network.
 *
 * <p>Each output element is (exampleCount, meanScore * exampleCount) for one minibatch. An
 * empty partition yields a single (0, 0.0) element.
 *
 * @throws IllegalStateException if the broadcast parameter vector length does not match the network
 */
@Override
public Iterator<Tuple2<Long, Double>> call(Iterator<MultiDataSet> dataSetIterator) throws Exception {
    boolean emptyPartition = !dataSetIterator.hasNext();
    if (emptyPartition) {
        return Collections.singletonList(new Tuple2<>(0L, 0.0)).iterator();
    }

    // Wraps the partition and re-batches to minibatchSize where appropriate
    MultiDataSetIterator batches = new IteratorMultiDataSetIterator(dataSetIterator, minibatchSize);

    ComputationGraph graph = new ComputationGraph(ComputationGraphConfiguration.fromJson(json));
    graph.init();
    // .value() is shared by every executor on the same machine. Using it without a full copy is
    // acceptable because scoring never modifies the parameters.
    INDArray paramVector = params.value().unsafeDuplication();
    boolean sizeMismatch = paramVector.length() != graph.numParams(false);
    if (sizeMismatch) {
        throw new IllegalStateException(
                        "Network did not have same number of parameters as the broadcast set parameters");
    }
    graph.setParams(paramVector);

    List<Tuple2<Long, Double>> perBatch = new ArrayList<>();
    while (batches.hasNext()) {
        MultiDataSet minibatch = batches.next();
        double meanScore = graph.score(minibatch, false);
        long exampleCount = minibatch.getFeatures(0).size(0);
        perBatch.add(new Tuple2<>(exampleCount, meanScore * exampleCount));
    }

    Nd4j.getExecutioner().commit();
    return perBatch.iterator();
}
 
Example 9
Source file: RecordReaderMultiDataSetIteratorTest.java, from deeplearning4j (Apache License 2.0) — 4 votes
/**
 * Splits Iris into two inputs and two outputs with RecordReaderMultiDataSetIterator. Each array
 * is checked against the matching slice of the single-array DataSet from
 * RecordReaderDataSetIterator.
 *
 * <p>Inputs: column 0, and columns 1-2. Outputs: column 3, and column 4 one-hot encoded over 3 classes.
 */
@Test
public void testSplittingCSV() throws Exception {
    RecordReader referenceReader = new CSVRecordReader(0, ',');
    referenceReader.initialize(new FileSplit(Resources.asFile("iris.txt")));
    RecordReaderDataSetIterator referenceIter = new RecordReaderDataSetIterator(referenceReader, 10, 4, 3);

    RecordReader splitReader = new CSVRecordReader(0, ',');
    splitReader.initialize(new FileSplit(Resources.asFile("iris.txt")));

    MultiDataSetIterator splitIter = new RecordReaderMultiDataSetIterator.Builder(10).addReader("reader", splitReader)
                    .addInput("reader", 0, 0).addInput("reader", 1, 2).addOutput("reader", 3, 3)
                    .addOutputOneHot("reader", 4, 3).build();

    while (referenceIter.hasNext()) {
        DataSet reference = referenceIter.next();
        INDArray refFeatures = reference.getFeatures();
        INDArray refLabels = reference.getLabels();

        MultiDataSet split = splitIter.next();
        assertEquals(2, split.getFeatures().length);
        assertEquals(2, split.getLabels().length);
        assertNull(split.getFeaturesMaskArrays());
        assertNull(split.getLabelsMaskArrays());

        INDArray[] splitFeatures = split.getFeatures();
        INDArray[] splitLabels = split.getLabels();
        assertNotNull(splitFeatures);
        assertNotNull(splitLabels);
        for (INDArray arr : splitFeatures)
            assertNotNull(arr);
        for (INDArray arr : splitLabels)
            assertNotNull(arr);

        // Expected arrays are column slices of the original single-array Iris data
        INDArray expIn1 = refFeatures.get(all(), interval(0, 0, true));
        INDArray expIn2 = refFeatures.get(all(), interval(1, 2, true));
        INDArray expOut1 = refFeatures.get(all(), interval(3, 3, true));

        assertEquals(expIn1, splitFeatures[0]);
        assertEquals(expIn2, splitFeatures[1]);
        assertEquals(expOut1, splitLabels[0]);
        assertEquals(refLabels, splitLabels[1]);
    }
    assertFalse(splitIter.hasNext());
}
 
Example 10
Source file: RecordReaderMultiDataSetIteratorTest.java, from deeplearning4j (Apache License 2.0) — 4 votes
/**
 * Splits each "csvsequence_%d.txt" sequence into two inputs: columns 0-1 and column 2. The
 * matching "csvsequencelabels_%d.txt" stays a standard one-hot output over 4 classes. Every array
 * is checked against the corresponding slice from SequenceRecordReaderDataSetIterator.
 */
@Test
public void testSplittingCSVSequence() throws Exception {
    File rootDir = temporaryFolder.newFolder();
    String[] resourcePatterns = {"csvsequence_%d.txt", "csvsequencelabels_%d.txt", "csvsequencelabelsShort_%d.txt"};
    for (int i = 0; i < 3; i++) {
        for (String pattern : resourcePatterns) {
            new ClassPathResource(String.format(pattern, i)).getTempFileFromArchive(rootDir);
        }
    }

    String featuresPath = FilenameUtils.concat(rootDir.getAbsolutePath(), "csvsequence_%d.txt");
    String labelsPath = FilenameUtils.concat(rootDir.getAbsolutePath(), "csvsequencelabels_%d.txt");

    // Reference: the single-input sequence iterator
    SequenceRecordReader refFeatureReader = new CSVSequenceRecordReader(1, ",");
    SequenceRecordReader refLabelReader = new CSVSequenceRecordReader(1, ",");
    refFeatureReader.initialize(new NumberedFileInputSplit(featuresPath, 0, 2));
    refLabelReader.initialize(new NumberedFileInputSplit(labelsPath, 0, 2));

    SequenceRecordReaderDataSetIterator referenceIter =
                    new SequenceRecordReaderDataSetIterator(refFeatureReader, refLabelReader, 1, 4, false);

    // Under test: the multi-input iterator over the same files
    SequenceRecordReader splitFeatureReader = new CSVSequenceRecordReader(1, ",");
    SequenceRecordReader splitLabelReader = new CSVSequenceRecordReader(1, ",");
    splitFeatureReader.initialize(new NumberedFileInputSplit(featuresPath, 0, 2));
    splitLabelReader.initialize(new NumberedFileInputSplit(labelsPath, 0, 2));

    MultiDataSetIterator splitIter = new RecordReaderMultiDataSetIterator.Builder(1)
                    .addSequenceReader("seq1", splitFeatureReader).addSequenceReader("seq2", splitLabelReader)
                    .addInput("seq1", 0, 1).addInput("seq1", 2, 2).addOutputOneHot("seq2", 0, 4).build();

    while (referenceIter.hasNext()) {
        DataSet reference = referenceIter.next();
        INDArray refFeatures = reference.getFeatures();
        INDArray refLabels = reference.getLabels();

        MultiDataSet split = splitIter.next();
        assertEquals(2, split.getFeatures().length);
        assertEquals(1, split.getLabels().length);
        assertNull(split.getFeaturesMaskArrays());
        assertNull(split.getLabelsMaskArrays());

        INDArray[] splitFeatures = split.getFeatures();
        INDArray[] splitLabels = split.getLabels();
        assertNotNull(splitFeatures);
        assertNotNull(splitLabels);
        for (INDArray arr : splitFeatures)
            assertNotNull(arr);
        for (INDArray arr : splitLabels)
            assertNotNull(arr);

        INDArray expIn1 = refFeatures.get(all(), NDArrayIndex.interval(0, 1, true), all());
        INDArray expIn2 = refFeatures.get(all(), NDArrayIndex.interval(2, 2, true), all());

        assertEquals(expIn1, splitFeatures[0]);
        assertEquals(expIn2, splitFeatures[1]);
        assertEquals(refLabels, splitLabels[0]);
    }
    assertFalse(splitIter.hasNext());
}
 
Example 11
Source file: RecordReaderMultiDataSetIteratorTest.java, from deeplearning4j (Apache License 2.0) — 4 votes
/**
 * Builds a RecordReaderMultiDataSetIterator with two image inputs (10x10 and 5x5, one channel)
 * and a one-hot label. For two single-image class directories, each minibatch must match the
 * output of the equivalent single-input RecordReaderDataSetIterators.
 */
@Test
public void testImagesRRDMSI() throws Exception {
    File parentDir = temporaryFolder.newFolder();
    // NOTE(review): File.deleteOnExit only removes empty directories, so this likely has no
    // effect; the temporaryFolder rule presumably handles cleanup — confirm before removing.
    parentDir.deleteOnExit();
    String str1 = FilenameUtils.concat(parentDir.getAbsolutePath(), "Zico/");
    String str2 = FilenameUtils.concat(parentDir.getAbsolutePath(), "Ziwang_Xu/");

    File f1 = new File(str1);
    File f2 = new File(str2);
    f1.mkdirs();
    f2.mkdirs();

    TestUtils.writeStreamToFile(new File(FilenameUtils.concat(f1.getPath(), "Zico_0001.jpg")),
                    new ClassPathResource("lfwtest/Zico/Zico_0001.jpg").getInputStream());
    TestUtils.writeStreamToFile(new File(FilenameUtils.concat(f2.getPath(), "Ziwang_Xu_0001.jpg")),
                    new ClassPathResource("lfwtest/Ziwang_Xu/Ziwang_Xu_0001.jpg").getInputStream());


    int outputNum = 2;
    // (Removed an unused "Random r = new Random(12345)" local.)
    ParentPathLabelGenerator labelMaker = new ParentPathLabelGenerator();

    ImageRecordReader rr1 = new ImageRecordReader(10, 10, 1, labelMaker);
    ImageRecordReader rr1s = new ImageRecordReader(5, 5, 1, labelMaker);

    rr1.initialize(new FileSplit(parentDir));
    rr1s.initialize(new FileSplit(parentDir));


    MultiDataSetIterator trainDataIterator = new RecordReaderMultiDataSetIterator.Builder(1).addReader("rr1", rr1)
                    .addReader("rr1s", rr1s).addInput("rr1", 0, 0).addInput("rr1s", 0, 0)
                    .addOutputOneHot("rr1s", 1, outputNum).build();

    //Now, do the same thing with ImageRecordReader, and check we get the same results:
    ImageRecordReader rr1_b = new ImageRecordReader(10, 10, 1, labelMaker);
    ImageRecordReader rr1s_b = new ImageRecordReader(5, 5, 1, labelMaker);
    rr1_b.initialize(new FileSplit(parentDir));
    rr1s_b.initialize(new FileSplit(parentDir));

    DataSetIterator dsi1 = new RecordReaderDataSetIterator(rr1_b, 1, 1, 2);
    DataSetIterator dsi2 = new RecordReaderDataSetIterator(rr1s_b, 1, 1, 2);

    // Two images with minibatch size 1 gives two minibatches to compare
    for (int i = 0; i < 2; i++) {
        MultiDataSet mds = trainDataIterator.next();

        DataSet d1 = dsi1.next();
        DataSet d2 = dsi2.next();

        assertEquals(d1.getFeatures(), mds.getFeatures(0));
        assertEquals(d2.getFeatures(), mds.getFeatures(1));
        assertEquals(d1.getLabels(), mds.getLabels(0));
    }
}
 
Example 12
Source file: RecordReaderMultiDataSetIteratorTest.java, from deeplearning4j (Apache License 2.0) — 4 votes
/**
 * Same comparison as the unbatched image test, but both images go into one minibatch of size 2.
 * Also checks that the one-hot labels match the order in which the two class directories were read.
 */
@Test
public void testImagesRRDMSI_Batched() throws Exception {
    // Two class directories ("Zico", "Ziwang_Xu"), one image each, under a fresh temp folder
    File parentDir = temporaryFolder.newFolder();
    parentDir.deleteOnExit();
    File zicoDir = new File(FilenameUtils.concat(parentDir.getAbsolutePath(), "Zico/"));
    File ziwangDir = new File(FilenameUtils.concat(parentDir.getAbsolutePath(), "Ziwang_Xu/"));
    zicoDir.mkdirs();
    ziwangDir.mkdirs();

    TestUtils.writeStreamToFile(new File(FilenameUtils.concat(zicoDir.getPath(), "Zico_0001.jpg")),
                    new ClassPathResource("lfwtest/Zico/Zico_0001.jpg").getInputStream());
    TestUtils.writeStreamToFile(new File(FilenameUtils.concat(ziwangDir.getPath(), "Ziwang_Xu_0001.jpg")),
                    new ClassPathResource("lfwtest/Ziwang_Xu/Ziwang_Xu_0001.jpg").getInputStream());

    int outputNum = 2;
    ParentPathLabelGenerator labelMaker = new ParentPathLabelGenerator();

    ImageRecordReader fullSize = new ImageRecordReader(10, 10, 1, labelMaker);
    ImageRecordReader smallSize = new ImageRecordReader(5, 5, 1, labelMaker);

    // Both readers are initialised from the same URI array, presumably so they visit images in the same order
    URI[] uris = new FileSplit(parentDir).locations();
    fullSize.initialize(new CollectionInputSplit(uris));
    smallSize.initialize(new CollectionInputSplit(uris));

    MultiDataSetIterator multiIter = new RecordReaderMultiDataSetIterator.Builder(2).addReader("rr1", fullSize)
                    .addReader("rr1s", smallSize).addInput("rr1", 0, 0).addInput("rr1s", 0, 0)
                    .addOutputOneHot("rr1s", 1, outputNum).build();

    // Reference: equivalent single-input iterators, also with minibatch size 2
    ImageRecordReader fullSizeRef = new ImageRecordReader(10, 10, 1, labelMaker);
    ImageRecordReader smallSizeRef = new ImageRecordReader(5, 5, 1, labelMaker);
    fullSizeRef.initialize(new FileSplit(parentDir));
    smallSizeRef.initialize(new FileSplit(parentDir));

    DataSetIterator fullIter = new RecordReaderDataSetIterator(fullSizeRef, 2, 1, 2);
    DataSetIterator smallIter = new RecordReaderDataSetIterator(smallSizeRef, 2, 1, 2);

    MultiDataSet mds = multiIter.next();
    DataSet fullDs = fullIter.next();
    DataSet smallDs = smallIter.next();

    assertEquals(fullDs.getFeatures(), mds.getFeatures(0));
    assertEquals(smallDs.getFeatures(), mds.getFeatures(1));
    assertEquals(fullDs.getLabels(), mds.getLabels(0));

    // Label check: the expected one-hot rows depend on which image the reference reader read
    // last (presumably what getCurrentFile() reports after the batch — confirm).
    boolean zicoReadLast = fullSizeRef.getCurrentFile().getAbsolutePath().contains("Zico");
    INDArray expLabels = zicoReadLast
            ? Nd4j.create(new double[][] {{0, 1}, {1, 0}})
            : Nd4j.create(new double[][] {{1, 0}, {0, 1}});

    assertEquals(expLabels, fullDs.getLabels());
    assertEquals(expLabels, smallDs.getLabels());
}