Java Code Examples for org.nd4j.linalg.dataset.DataSet#numExamples()

The following examples show how to use org.nd4j.linalg.dataset.DataSet#numExamples(). They are drawn from open source projects; the source file and project for each example are listed above its code.
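Before the examples, a minimal, self-contained sketch of what the method reports: numExamples() is the size of the first (batch) dimension of the feature array. The class name and toy shapes below are illustrative, not from any of the listed projects.

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;

public class NumExamplesDemo {
    public static void main(String[] args) {
        // 10 examples with 4 features each, plus 3 one-hot label columns
        INDArray features = Nd4j.rand(10, 4);
        INDArray labels = Nd4j.zeros(10, 3);
        DataSet ds = new DataSet(features, labels);

        // numExamples() reports the size of the first (batch) dimension
        System.out.println(ds.numExamples()); // prints 10
    }
}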
Example 1
Source File: PathSparkDataSetIterator.java    From deeplearning4j with Apache License 2.0
@Override
public DataSet next() {
    DataSet ds;
    if (preloadedDataSet != null) {
        ds = preloadedDataSet;
        preloadedDataSet = null;
    } else {
        ds = load(iter.next());
    }

    totalOutcomes = ds.getLabels() == null ? 0 : (int) ds.getLabels().size(1); //May be null for layerwise pretraining
    inputColumns = (int) ds.getFeatures().size(1);
    batch = ds.numExamples();

    if (preprocessor != null)
        preprocessor.preProcess(ds);
    return ds;
}
 
Example 2
Source File: KFoldIterator.java    From nd4j with Apache License 2.0
/**
 * Create an iterator given the dataset and a value of k (optional, defaults to 10).
 * If the number of samples in the dataset is not a multiple of k, the last fold will have fewer samples, with the rest having the same number of samples.
 *
 * @param k number of folds (optional, defaults to 10)
 * @param singleFold DataSet to split into k folds
 */
public KFoldIterator(int k, DataSet singleFold) {
    this.k = k;
    this.singleFold = singleFold.copy();
    if (k <= 1)
        throw new IllegalArgumentException("k must be greater than 1, got: " + k);
    if (singleFold.numExamples() % k != 0) {
        if (k != 2) {
            this.batch = singleFold.numExamples() / (k - 1);
            this.lastBatch = singleFold.numExamples() % (k - 1);
        } else {
            this.lastBatch = singleFold.numExamples() / 2;
            this.batch = this.lastBatch + 1;
        }
    } else {
        this.batch = singleFold.numExamples() / k;
        this.lastBatch = singleFold.numExamples() / k;
    }
}
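As a quick check of the branchy arithmetic above (toy values, not from the original source): with N = 7 samples and k = 3 folds, the constructor takes the non-multiple, k != 2 branch, giving two folds of 3 samples and a last fold of 1, consistent with the javadoc.

public class KFoldArithmeticCheck {
    public static void main(String[] args) {
        int n = 7, k = 3; // toy values
        int batch, lastBatch;
        if (n % k != 0 && k != 2) {
            batch = n / (k - 1);      // 7 / 2 = 3
            lastBatch = n % (k - 1);  // 7 % 2 = 1
        } else if (n % k != 0) {
            lastBatch = n / 2;
            batch = lastBatch + 1;
        } else {
            batch = lastBatch = n / k;
        }
        System.out.println(batch + " / " + lastBatch); // prints 3 / 1
    }
}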
 
Example 3
Source File: ScoreUtil.java    From deeplearning4j with Apache License 2.0
/**
 * Score the given test data
 * with the given multi-layer network
 * @param model model to use
 * @param testData the test data to test with
 * @param average whether to average the score or not
 * @return the score for the given test data given the model
 */
public static double score(MultiLayerNetwork model, DataSetIterator testData, boolean average) {
    //TODO: do this properly taking into account division by N, L1/L2 etc
    double sumScore = 0.0;
    int totalExamples = 0;
    while (testData.hasNext()) {
        DataSet ds = testData.next();
        int numExamples = ds.numExamples();

        sumScore += numExamples * model.score(ds);
        totalExamples += numExamples;
    }

    if (!average)
        return sumScore;
    return sumScore / totalExamples;
}
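The weighting by numExamples() matters when the iterator yields unequally sized batches, e.g. a short final batch. A toy check (numbers invented for illustration):

public class WeightedScoreCheck {
    public static void main(String[] args) {
        // Two batches: 3 examples scoring 2.0 each, then 1 example scoring 6.0
        double sumScore = 3 * 2.0 + 1 * 6.0; // 12.0
        int totalExamples = 3 + 1;

        // Example-weighted average, as computed by score(..., true)
        System.out.println(sumScore / totalExamples); // prints 3.0
        // A naive mean of the two batch scores would give (2.0 + 6.0) / 2 = 4.0
    }
}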
 
Example 4
Source File: DrawMnist.java    From Canova with Apache License 2.0
public static void drawMnist(DataSet mnist,INDArray reconstruct) throws InterruptedException {
	for(int j = 0; j < mnist.numExamples(); j++) {
		INDArray draw1 = mnist.get(j).getFeatureMatrix().mul(255);
		INDArray reconstructed2 = reconstruct.getRow(j);
		INDArray draw2 = Sampling.binomial(reconstructed2, 1, new MersenneTwister(123)).mul(255);

		DrawReconstruction d = new DrawReconstruction(draw1);
		d.title = "REAL";
		d.draw();
		DrawReconstruction d2 = new DrawReconstruction(draw2,1000,1000);
		d2.title = "TEST";
		
		d2.draw();
		Thread.sleep(1000);
		d.frame.dispose();
		d2.frame.dispose();

	}
}
 
Example 5
Source File: KFoldIterator.java    From deeplearning4j with Apache License 2.0
/**
 * Create an iterator given the dataset and a number of train-test splits k.
 * N samples are split into k batches. The first (N%k) batches contain (N/k)+1 samples, while the remaining batches contain (N/k) samples.
 * If the number of samples N in the dataset is a multiple of k, all batches contain (N/k) samples.
 * @param k number of folds (optional, defaults to 10)
 * @param allData DataSet to split into k folds
 */
public KFoldIterator(int k, DataSet allData) {
    if (k <= 1) {
        throw new IllegalArgumentException("k must be greater than 1, got: " + k);
    }
    this.k = k;
    this.N = allData.numExamples();
    this.allData = allData;
    
    // generate index interval boundaries of test folds
    int baseBatchSize = N / k;
    int numIncrementedBatches = N % k;

    this.intervalBoundaries = new int[k+1];
    intervalBoundaries[0] = 0;
    for (int i = 1; i <= k; i++) {
        if (i <= numIncrementedBatches) {
            intervalBoundaries[i] = intervalBoundaries[i-1] + (baseBatchSize + 1);
        } else {
            intervalBoundaries[i] = intervalBoundaries[i-1] + baseBatchSize;
        }
    }
}
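A hedged usage sketch of this constructor: with N = 10 and k = 3, the boundaries come out as [0, 4, 7, 10], i.e. one fold of 4 samples followed by two folds of 3. The import path and the testFold() accessor are assumed from the deeplearning4j/nd4j sources and may differ between versions.

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.KFoldIterator;
import org.nd4j.linalg.factory.Nd4j;

public class KFoldUsageSketch {
    public static void main(String[] args) {
        // N = 10, k = 3: one fold of (N/k)+1 = 4 samples, then two folds of 3
        INDArray features = Nd4j.rand(10, 4);
        INDArray labels = Nd4j.zeros(10, 2);
        KFoldIterator iter = new KFoldIterator(3, new DataSet(features, labels));

        while (iter.hasNext()) {
            DataSet train = iter.next();    // the k-1 folds merged for training
            DataSet test = iter.testFold(); // the held-out fold
            System.out.println(train.numExamples() + " train / " + test.numExamples() + " test");
        }
    }
}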
 
Example 6
Source File: Upsampling1DTest.java    From deeplearning4j with Apache License 2.0
public INDArray getData() throws Exception {
    DataSetIterator data = new MnistDataSetIterator(5, 5);
    DataSet mnist = data.next();
    nExamples = mnist.numExamples();
    INDArray features = mnist.getFeatures().reshape(nExamples, nChannelsIn, inputLength, inputLength);
    return features.slice(0, 3);
}
 
Example 7
Source File: LoadAndDraw.java    From Canova with Apache License 2.0
/**
 * @param args
 */
public static void main(String[] args) throws Exception {
	MnistDataSetIterator iter = new MnistDataSetIterator(60,60000);
	ObjectInputStream ois = new ObjectInputStream(new FileInputStream(args[0]));
	
	BasePretrainNetwork network = (BasePretrainNetwork) ois.readObject();

	DataSet test;
	while(iter.hasNext()) {
		test = iter.next();
		INDArray reconstructed = network.transform(test.getFeatureMatrix());
		for(int i = 0; i < test.numExamples(); i++) {
			INDArray draw1 = test.get(i).getFeatureMatrix().mul(255);
			INDArray reconstructed2 = reconstructed.getRow(i);
			INDArray draw2 = Sampling.binomial(reconstructed2, 1, new MersenneTwister(123)).mul(255);

			DrawReconstruction d = new DrawReconstruction(draw1);
			d.title = "REAL";
			d.draw();
			DrawReconstruction d2 = new DrawReconstruction(draw2,100,100);
			d2.title = "TEST";
			d2.draw();
			Thread.sleep(10000);
			d.frame.dispose();
			d2.frame.dispose();
		}
	}
}
 
Example 8
Source File: BatchCSVRecord.java    From deeplearning4j with Apache License 2.0
/**
 * Return a batch record based on a dataset
 * @param dataSet the dataset to get the batch record for
 * @return the batch record
 */
public static BatchCSVRecord fromDataSet(DataSet dataSet) {
    BatchCSVRecord batchCSVRecord = new BatchCSVRecord();
    for (int i = 0; i < dataSet.numExamples(); i++) {
        batchCSVRecord.add(SingleCSVRecord.fromRow(dataSet.get(i)));
    }

    return batchCSVRecord;
}
 
Example 9
Source File: ConvolutionLayerTest.java    From deeplearning4j with Apache License 2.0
public INDArray getMnistData() throws Exception {
    int inputWidth = 28;
    int inputHeight = 28;
    int nChannelsIn = 1;
    int nExamples = 5;

    DataSetIterator data = new MnistDataSetIterator(nExamples, nExamples);
    DataSet mnist = data.next();
    nExamples = mnist.numExamples();
    return mnist.getFeatures().reshape(nExamples, nChannelsIn, inputHeight, inputWidth);
}
 
Example 10
Source File: DataSetOrdering.java    From deeplearning4j with Apache License 2.0
@Override
public boolean lteq(DataSet dataSet, DataSet t1) {
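    // Note the inverted comparison (>=): under this Ordering, a DataSet with more
    // examples ranks as "smaller", so sorting places larger DataSets first.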
    return dataSet.numExamples() >= t1.numExamples();
}
 
Example 11
Source File: FileDataSetIterator.java    From deeplearning4j with Apache License 2.0
@Override
protected long sizeOf(DataSet of) {
    return of.numExamples();
}
 
Example 12
Source File: SubsamplingLayerTest.java    From deeplearning4j with Apache License 2.0
public INDArray getData() throws Exception {
    DataSetIterator data = new MnistDataSetIterator(5, 5);
    DataSet mnist = data.next();
    nExamples = mnist.numExamples();
    return mnist.getFeatures().reshape(nExamples, nChannelsIn, inputWidth, inputHeight);
}
 
Example 13
Source File: Upsampling2DTest.java    From deeplearning4j with Apache License 2.0
public INDArray getData() throws Exception {
    DataSetIterator data = new MnistDataSetIterator(5, 5);
    DataSet mnist = data.next();
    nExamples = mnist.numExamples();
    return mnist.getFeatures().reshape(nExamples, nChannelsIn, inputWidth, inputHeight);
}
 
Example 14
Source File: Convolution3DTest.java    From deeplearning4j with Apache License 2.0
public INDArray getData() throws Exception {
    DataSetIterator data = new MnistDataSetIterator(5, 5);
    DataSet mnist = data.next();
    nExamples = mnist.numExamples();
    return mnist.getFeatures().reshape(nExamples, nChannelsIn, inputDepth, inputHeight, inputWidth);
}
 
Example 15
Source File: LearnDigitsBackprop.java    From aifh with Apache License 2.0
/**
 * The main method.
 * @param args Not used.
 */
public static void main(String[] args) {
    try {
        int seed = 43;
        double learningRate = 1e-2;
        int nEpochs = 50;
        int batchSize = 500;

        // Setup training data.
        System.out.println("Please wait, reading MNIST training data.");
        String dir = System.getProperty("user.dir");
        MNISTReader trainingReader = MNIST.loadMNIST(dir, true);
        MNISTReader validationReader = MNIST.loadMNIST(dir, false);

        DataSet trainingSet = trainingReader.getData();
        DataSet validationSet = validationReader.getData();

        DataSetIterator trainSetIterator = new ListDataSetIterator(trainingSet.asList(), batchSize);
        DataSetIterator validationSetIterator = new ListDataSetIterator(validationSet.asList(), validationReader.getNumRows());

        System.out.println("Training set size: " + trainingReader.getNumImages());
        System.out.println("Validation set size: " + validationReader.getNumImages());

        System.out.println(trainingSet.get(0).getFeatures().size(1));
        System.out.println(validationSet.get(0).getFeatures().size(1));

        int numInputs = trainingReader.getNumCols()*trainingReader.getNumRows();
        int numOutputs = 10;
        int numHiddenNodes = 200;

        // Create neural network.
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .seed(seed)
                .iterations(1)
                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                .learningRate(learningRate)
                .updater(Updater.NESTEROVS).momentum(0.9)
                .regularization(true).dropOut(0.50)
                .list(2)
                .layer(0, new DenseLayer.Builder().nIn(numInputs).nOut(numHiddenNodes)
                        .weightInit(WeightInit.XAVIER)
                        .activation("relu")
                        .build())
                .layer(1, new OutputLayer.Builder(LossFunction.NEGATIVELOGLIKELIHOOD)
                        .weightInit(WeightInit.XAVIER)
                        .activation("softmax")
                        .nIn(numHiddenNodes).nOut(numOutputs).build())
                .pretrain(false).backprop(true).build();


        MultiLayerNetwork model = new MultiLayerNetwork(conf);
        model.init();
        model.setListeners(new ScoreIterationListener(1));

        // Define when we want to stop training.
        EarlyStoppingModelSaver saver = new InMemoryModelSaver();
        EarlyStoppingConfiguration esConf = new EarlyStoppingConfiguration.Builder()
                //.epochTerminationConditions(new MaxEpochsTerminationCondition(10))
                .epochTerminationConditions(new ScoreImprovementEpochTerminationCondition(5))
                .evaluateEveryNEpochs(1)
                .scoreCalculator(new DataSetLossCalculator(validationSetIterator, true))     //Calculate test set score
                .modelSaver(saver)
                .build();
        EarlyStoppingTrainer trainer = new EarlyStoppingTrainer(esConf, conf, trainSetIterator);

        // Train and display result.
        EarlyStoppingResult result = trainer.fit();
        System.out.println("Termination reason: " + result.getTerminationReason());
        System.out.println("Termination details: " + result.getTerminationDetails());
        System.out.println("Total epochs: " + result.getTotalEpochs());
        System.out.println("Best epoch number: " + result.getBestModelEpoch());
        System.out.println("Score at best epoch: " + result.getBestModelScore());

        model = saver.getBestModel();

        // Evaluate
        Evaluation eval = new Evaluation(numOutputs);
        validationSetIterator.reset();

        for (int i = 0; i < validationSet.numExamples(); i++) {
            DataSet t = validationSet.get(i);
            INDArray features = t.getFeatureMatrix();
            INDArray labels = t.getLabels();
            INDArray predicted = model.output(features, false);
            eval.eval(labels, predicted);
        }

        //Print the evaluation statistics
        System.out.println(eval.stats());
    } catch(Exception ex) {
        ex.printStackTrace();
    }

}
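The evaluation loop above scores one example at a time. A sketch of an equivalent batched pass, reusing the variables from the example and the same Evaluation API, avoids the per-example overhead:

        // Hedged alternative to the per-example loop: evaluate batch-wise via the iterator.
        Evaluation eval = new Evaluation(numOutputs);
        validationSetIterator.reset();
        while (validationSetIterator.hasNext()) {
            DataSet batch = validationSetIterator.next();
            INDArray predicted = model.output(batch.getFeatures(), false);
            eval.eval(batch.getLabels(), predicted);
        }
        System.out.println(eval.stats());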
 
Example 16
Source File: DataSetOrdering.java    From deeplearning4j with Apache License 2.0
@Override
public boolean equiv(DataSet dataSet, DataSet t1) {
    return dataSet.numExamples() == t1.numExamples();
}
 
Example 17
Source File: RnnTextEmbeddingDataSetIterator.java    From wekaDeeplearning4j with GNU General Public License v3.0
@Override
public DataSet next(int num) {
  // Check if next() call is valid - throws appropriate exceptions
  checkIfNextIsValid();

  // Collect
  List<String> sentences = new ArrayList<>(num);
  List<Double> labelsRaw = new ArrayList<>(num);
  collectData(num, sentences, labelsRaw);
  final int numDocuments = sentences.size();

  // Tokenize sentences
  List<List<String>> tokenizedSentences = tokenizeSentences(sentences);

  // Get longest sentence length
  int maxSentenceLength = tokenizedSentences.stream().mapToInt(List::size).max().getAsInt();

  // Truncate to truncateLength (also used as a fallback when all sentences are empty)
  if (maxSentenceLength > truncateLength || maxSentenceLength == 0) {
    maxSentenceLength = truncateLength;
  }

  // Init feature/label arrays
  int[] featureShape = {numDocuments, wordVectorSize, maxSentenceLength};
  int[] labelShape = {numDocuments, data.numClasses(), maxSentenceLength};
  INDArray features = Nd4j.create(featureShape, 'f');
  INDArray labels = Nd4j.create(labelShape, 'f');
  INDArray featuresMask = Nd4j.zeros(numDocuments, maxSentenceLength);
  INDArray labelsMask = Nd4j.zeros(numDocuments, maxSentenceLength);

  for (int i = 0; i < numDocuments; i++) {
    List<String> tokens = tokenizedSentences.get(i);

    // Check for empty document
    if (tokens.isEmpty()) {
      continue;
    }

    // Get the last index of the current document (truncated)
    int lastIdx = Math.min(tokens.size(), maxSentenceLength);

    // Get all wordvectors in batch
    List<String> truncatedTokenList = tokens.subList(0, lastIdx);
    final INDArray vectors = wordVectors.getWordVectors(truncatedTokenList).transpose();

    /*
     * Put wordvectors into features array at the following indices:
     * 1) Document (i)
     * 2) All vector elements which is equal to NDArrayIndex.interval(0, vectorSize)
     * 3) All elements between 0 and the length of the current sequence
     */
    INDArrayIndex[] indices = {point(i), all(), interval(0, lastIdx)};
    features.put(indices, vectors);

    // Assign "1" to each position where a feature is present, that is, in the interval of
    // [0, lastIdx)
    featuresMask.get(point(i), interval(0, lastIdx)).assign(1);

    // Put the labels in the labels and labelsMask arrays
    // Differ between classification and regression task
    if (data.numClasses() == 1) { // Regression
      double val = labelsRaw.get(i);
      labels.putScalar(new int[]{i, 0, lastIdx - 1}, val);
    } else if (data.numClasses() > 1) { // Classification
      // One-Hot-Encoded class
      int idx = labelsRaw.get(i).intValue();
      // Set label
      labels.putScalar(new int[]{i, idx, lastIdx - 1}, 1.0);
    } else {
      throw new RuntimeException("Could not detect classification or regression task.");
    }

    // Set final timestep for this example to 1.0 to show that an output exists here
    int[] lastTimestepIndex = {i, lastIdx - 1};
    labelsMask.putScalar(lastTimestepIndex, 1.0);
  }

  // Cache the dataset
  final DataSet ds = new DataSet(features, labels, featuresMask, labelsMask);

  // Move cursor
  cursor += ds.numExamples();
  return ds;
}
 
Example 18
Source File: LearnIrisBackprop.java    From aifh with Apache License 2.0
/**
 * The main method.
 * @param args Not used.
 */
public static void main(String[] args) {
    try {
        int seed = 43;
        double learningRate = 0.1;
        int splitTrainNum = (int) (150 * .75);

        int numInputs = 4;
        int numOutputs = 3;
        int numHiddenNodes = 50;

        // Setup training data.
        final InputStream istream = LearnIrisBackprop.class.getResourceAsStream("/iris.csv");
        if( istream==null ) {
            System.out.println("Cannot access data set, make sure the resources are available.");
            System.exit(1);
        }
        final NormalizeDataSet ds = NormalizeDataSet.load(istream);
        final CategoryMap species = ds.encodeOneOfN(4); // species is column 4
        istream.close();

        DataSet next = ds.extractSupervised(0, 4, 4, 3);
        next.shuffle();

        // Training and validation data split
        SplitTestAndTrain testAndTrain = next.splitTestAndTrain(splitTrainNum, new Random(seed));
        DataSet trainSet = testAndTrain.getTrain();
        DataSet validationSet = testAndTrain.getTest();

        DataSetIterator trainSetIterator = new ListDataSetIterator(trainSet.asList(), trainSet.numExamples());

        DataSetIterator validationSetIterator = new ListDataSetIterator(validationSet.asList(), validationSet.numExamples());

        // Create neural network.
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .seed(seed)
                .iterations(1)
                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                .learningRate(learningRate)
                .updater(Updater.NESTEROVS).momentum(0.9)
                .list(2)
                .layer(0, new DenseLayer.Builder().nIn(numInputs).nOut(numHiddenNodes)
                        .weightInit(WeightInit.XAVIER)
                        .activation("relu")
                        .build())
                .layer(1, new OutputLayer.Builder(LossFunction.NEGATIVELOGLIKELIHOOD)
                        .weightInit(WeightInit.XAVIER)
                        .activation("softmax")
                        .nIn(numHiddenNodes).nOut(numOutputs).build())
                .pretrain(false).backprop(true).build();


        MultiLayerNetwork model = new MultiLayerNetwork(conf);
        model.init();
        model.setListeners(new ScoreIterationListener(1));

        // Define when we want to stop training.
        EarlyStoppingModelSaver saver = new InMemoryModelSaver();
        EarlyStoppingConfiguration esConf = new EarlyStoppingConfiguration.Builder()
                //Note: successive epochTerminationConditions(...) calls replace earlier ones,
                //so both conditions are passed in a single call here.
                .epochTerminationConditions(new MaxEpochsTerminationCondition(500), //Max of 500 epochs
                        new ScoreImprovementEpochTerminationCondition(25))
                .evaluateEveryNEpochs(1)
                .scoreCalculator(new DataSetLossCalculator(validationSetIterator, true))     //Calculate test set score
                .modelSaver(saver)
                .build();
        EarlyStoppingTrainer trainer = new EarlyStoppingTrainer(esConf, conf, trainSetIterator);

        // Train and display result.
        EarlyStoppingResult result = trainer.fit();
        System.out.println("Termination reason: " + result.getTerminationReason());
        System.out.println("Termination details: " + result.getTerminationDetails());
        System.out.println("Total epochs: " + result.getTotalEpochs());
        System.out.println("Best epoch number: " + result.getBestModelEpoch());
        System.out.println("Score at best epoch: " + result.getBestModelScore());

        model = saver.getBestModel();

        // Evaluate
        Evaluation eval = new Evaluation(numOutputs);
        validationSetIterator.reset();

        for (int i = 0; i < validationSet.numExamples(); i++) {
            DataSet t = validationSet.get(i);
            INDArray features = t.getFeatureMatrix();
            INDArray labels = t.getLabels();
            INDArray predicted = model.output(features, false);
            System.out.println(features + ":Prediction("+findSpecies(predicted,species)
                    +"):Actual("+findSpecies(labels,species)+")" + predicted );
            eval.eval(labels, predicted);
        }

        //Print the evaluation statistics
        System.out.println(eval.stats());
    } catch(Exception ex) {
        ex.printStackTrace();
    }
}
 
Example 19
Source File: LearnDigitsDropout.java    From aifh with Apache License 2.0
/**
 * The main method.
 * @param args Not used.
 */
public static void main(String[] args) {
    try {
        int seed = 43;
        double learningRate = 1e-2;
        int nEpochs = 50;
        int batchSize = 500;

        // Setup training data.
        System.out.println("Please wait, reading MNIST training data.");
        String dir = System.getProperty("user.dir");
        MNISTReader trainingReader = MNIST.loadMNIST(dir, true);
        MNISTReader validationReader = MNIST.loadMNIST(dir, false);

        DataSet trainingSet = trainingReader.getData();
        DataSet validationSet = validationReader.getData();

        DataSetIterator trainSetIterator = new ListDataSetIterator(trainingSet.asList(), batchSize);
        DataSetIterator validationSetIterator = new ListDataSetIterator(validationSet.asList(), validationReader.getNumRows());

        System.out.println("Training set size: " + trainingReader.getNumImages());
        System.out.println("Validation set size: " + validationReader.getNumImages());

        System.out.println(trainingSet.get(0).getFeatures().size(1));
        System.out.println(validationSet.get(0).getFeatures().size(1));

        int numInputs = trainingReader.getNumCols()*trainingReader.getNumRows();
        int numOutputs = 10;
        int numHiddenNodes = 100;

        // Create neural network.
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .seed(seed)
                .iterations(1)
                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                .learningRate(learningRate)
                .updater(Updater.NESTEROVS).momentum(0.9)
                .list(2)
                .layer(0, new DenseLayer.Builder().nIn(numInputs).nOut(numHiddenNodes)
                        .weightInit(WeightInit.XAVIER)
                        .activation("relu")
                        .build())
                .layer(1, new OutputLayer.Builder(LossFunction.NEGATIVELOGLIKELIHOOD)
                        .weightInit(WeightInit.XAVIER)
                        .activation("softmax")
                        .nIn(numHiddenNodes).nOut(numOutputs).build())
                .pretrain(false).backprop(true).build();


        MultiLayerNetwork model = new MultiLayerNetwork(conf);
        model.init();
        model.setListeners(new ScoreIterationListener(1));

        // Define when we want to stop training.
        EarlyStoppingModelSaver saver = new InMemoryModelSaver();
        EarlyStoppingConfiguration esConf = new EarlyStoppingConfiguration.Builder()
                //.epochTerminationConditions(new MaxEpochsTerminationCondition(10))
                .epochTerminationConditions(new ScoreImprovementEpochTerminationCondition(5))
                .evaluateEveryNEpochs(1)
                .scoreCalculator(new DataSetLossCalculator(validationSetIterator, true))     //Calculate test set score
                .modelSaver(saver)
                .build();
        EarlyStoppingTrainer trainer = new EarlyStoppingTrainer(esConf, conf, trainSetIterator);

        // Train and display result.
        EarlyStoppingResult result = trainer.fit();
        System.out.println("Termination reason: " + result.getTerminationReason());
        System.out.println("Termination details: " + result.getTerminationDetails());
        System.out.println("Total epochs: " + result.getTotalEpochs());
        System.out.println("Best epoch number: " + result.getBestModelEpoch());
        System.out.println("Score at best epoch: " + result.getBestModelScore());

        model = saver.getBestModel();

        // Evaluate
        Evaluation eval = new Evaluation(numOutputs);
        validationSetIterator.reset();

        for (int i = 0; i < validationSet.numExamples(); i++) {
            DataSet t = validationSet.get(i);
            INDArray features = t.getFeatureMatrix();
            INDArray labels = t.getLabels();
            INDArray predicted = model.output(features, false);
            eval.eval(labels, predicted);
        }

        //Print the evaluation statistics
        System.out.println(eval.stats());
    } catch(Exception ex) {
        ex.printStackTrace();
    }

}
 
Example 20
Source File: ScoreExamplesFunction.java    From deeplearning4j with Apache License 2.0
@Override
public Iterator<Double> call(Iterator<DataSet> iterator) throws Exception {
    if (!iterator.hasNext()) {
        return Collections.emptyIterator();
    }

    MultiLayerNetwork network = new MultiLayerNetwork(MultiLayerConfiguration.fromJson(jsonConfig.getValue()));
    network.init();
    INDArray val = params.value().unsafeDuplication();
    if (val.length() != network.numParams(false))
        throw new IllegalStateException(
                        "Network did not have same number of parameters as the broadcast set parameters");
    network.setParameters(val);

    List<Double> ret = new ArrayList<>();

    List<DataSet> collect = new ArrayList<>(batchSize);
    int totalCount = 0;
    while (iterator.hasNext()) {
        collect.clear();
        int nExamples = 0;
        while (iterator.hasNext() && nExamples < batchSize) {
            DataSet ds = iterator.next();
            int n = ds.numExamples();
            collect.add(ds);
            nExamples += n;
        }
        totalCount += nExamples;

        DataSet data = DataSet.merge(collect);


        INDArray scores = network.scoreExamples(data, addRegularization);
        double[] doubleScores = scores.data().asDouble();

        for (double doubleScore : doubleScores) {
            ret.add(doubleScore);
        }
    }

    Nd4j.getExecutioner().commit();

    if (log.isDebugEnabled()) {
        log.debug("Scored {} examples ", totalCount);
    }

    return ret.iterator();
}
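For reference, DataSet.merge used above concatenates examples along the batch dimension, so numExamples() of the merged set is the sum of its parts. A toy check (shapes invented for illustration):

import java.util.Arrays;

import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;

public class MergeCheck {
    public static void main(String[] args) {
        DataSet a = new DataSet(Nd4j.rand(3, 4), Nd4j.zeros(3, 2));
        DataSet b = new DataSet(Nd4j.rand(5, 4), Nd4j.zeros(5, 2));

        // merge concatenates along the batch (first) dimension
        DataSet merged = DataSet.merge(Arrays.asList(a, b));
        System.out.println(merged.numExamples()); // prints 8
    }
}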