Java Code Examples for org.nd4j.linalg.dataset.DataSet

The following examples show how to use org.nd4j.linalg.dataset.DataSet. These examples are extracted from open source projects. You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may check out the related API usage in the sidebar.
Example 1
Source Project: deeplearning4j   Source File: PathSparkDataSetIterator.java    License: Apache License 2.0 7 votes vote down vote up
@Override
public DataSet next() {
    // Consume a preloaded dataset if one is waiting; otherwise load from the next path.
    DataSet ds;
    if (preloadedDataSet == null) {
        ds = load(iter.next());
    } else {
        ds = preloadedDataSet;
        preloadedDataSet = null;
    }

    // Labels may be null during layerwise pretraining — report 0 outcomes in that case.
    totalOutcomes = ds.getLabels() == null ? 0 : (int) ds.getLabels().size(1);
    inputColumns = (int) ds.getFeatures().size(1);
    batch = ds.numExamples();

    // Apply the configured preprocessor (if any) before handing the dataset out.
    if (preprocessor != null) {
        preprocessor.preProcess(ds);
    }
    return ds;
}
 
Example 2
Source Project: deeplearning4j   Source File: DataSetExportFunction.java    License: Apache License 2.0 6 votes vote down vote up
@Override
public void call(Iterator<DataSet> iter) throws Exception {
    // Unique per-thread/per-JVM prefix so concurrent exporters never collide on file names.
    String jvmuid = UIDProvider.getJVMUID();
    uid = Thread.currentThread().getId() + jvmuid.substring(0, Math.min(8, jvmuid.length()));

    // Fall back to the default Hadoop configuration when none was broadcast.
    Configuration c = (conf != null) ? conf.getValue().getConfiguration() : DefaultHadoopConfig.get();

    while (iter.hasNext()) {
        DataSet next = iter.next();

        String filename = "dataset_" + uid + "_" + (outputCount++) + ".bin";

        String base = outputDir.getPath();
        boolean endsWithSeparator = base.endsWith("/") || base.endsWith("\\");
        URI uri = new URI(base + (endsWithSeparator ? "" : "/") + filename);

        // Write one serialized DataSet per file; stream is closed by try-with-resources.
        FileSystem fs = FileSystem.get(uri, c);
        try (FSDataOutputStream out = fs.create(new Path(uri))) {
            next.save(out);
        }
    }
}
 
Example 3
/**
 * Test getDataSetIterator
 *
 * <p>For a nominal-class dataset, checks that the label produced by the iterator
 * (argmax of the one-hot label row) matches each instance's class value, and that
 * the iterator's reported label set equals the set of labels observed in the data.
 */
@Test
public void testGetIteratorNominalClass() throws Exception {
  final Instances data = DatasetLoader.loadReutersMinimal();
  final int batchSize = 1;
  final DataSetIterator it = this.cteii.getDataSetIterator(data, SEED, batchSize);

  Set<Integer> labels = new HashSet<>();
  for (Instance inst : data) {
    // Class values are numeric strings here; parse to compare with the argmax index.
    int label = Integer.parseInt(inst.stringValue(data.classIndex()));
    final DataSet next = Utils.getNext(it);
    int itLabel = next.getLabels().argMax().getInt(0);
    Assert.assertEquals(label, itLabel);
    labels.add(label);
  }
  // The iterator's label list must cover exactly the labels seen (binary task: 2).
  final Set<Integer> collect =
      it.getLabels().stream().map(s -> Double.valueOf(s).intValue()).collect(Collectors.toSet());
  Assert.assertEquals(2, labels.size());
  Assert.assertTrue(labels.containsAll(collect));
  Assert.assertTrue(collect.containsAll(labels));
}
 
Example 4
Source Project: deeplearning4j   Source File: RandomDataSetIteratorTest.java    License: Apache License 2.0 6 votes vote down vote up
@Test
public void testDSI(){
    // Random iterator: 5 minibatches of uniform [0,1] features (3x4) and one-hot labels (3x5).
    DataSetIterator iter = new RandomDataSetIterator(5, new long[]{3,4}, new long[]{3,5}, RandomDataSetIterator.Values.RANDOM_UNIFORM,
            RandomDataSetIterator.Values.ONE_HOT);

    int seen = 0;
    while (iter.hasNext()) {
        DataSet ds = iter.next();
        seen++;

        assertArrayEquals(new long[]{3,4}, ds.getFeatures().shape());
        assertArrayEquals(new long[]{3,5}, ds.getLabels().shape());

        // Uniform features must stay inside [0, 1].
        double min = ds.getFeatures().minNumber().doubleValue();
        double max = ds.getFeatures().maxNumber().doubleValue();
        assertTrue(min >= 0.0 && max <= 1.0);

        // One-hot labels: each of the 3 label rows sums to exactly 1.
        assertEquals(Nd4j.ones(3), ds.getLabels().sum(1));
    }
    assertEquals(5, seen);
}
 
Example 5
Source Project: deeplearning4j   Source File: TransferLearningHelper.java    License: Apache License 2.0 6 votes vote down vote up
/**
 * During training frozen vertices/layers can be treated as "featurizing" the input
 * The forward pass through these frozen layer/vertices can be done in advance and the dataset saved to disk to iterate
 * quickly on the smaller unfrozen part of the model
 * Currently does not support datasets with feature masks
 *
 * @param input dataset to feed into the computation graph (or MultiLayerNetwork) with frozen layer vertices
 * @return a dataset with input features that are the outputs of the frozen layer vertices and the original labels.
 */
public DataSet featurize(DataSet input) {
    if (isGraph) {
        //trying to featurize for a computation graph
        if (origGraph.getNumInputArrays() > 1 || origGraph.getNumOutputArrays() > 1) {
            throw new IllegalArgumentException(
                            "Input or output size to a computation graph is greater than one. Requires use of a MultiDataSet.");
        } else {
            if (input.getFeaturesMaskArray() != null) {
                throw new IllegalArgumentException(
                                "Currently cannot support featurizing datasets with feature masks");
            }
            // Wrap the single-input DataSet as a MultiDataSet (feature masks rejected above).
            MultiDataSet inbW = new MultiDataSet(new INDArray[] {input.getFeatures()},
                            new INDArray[] {input.getLabels()}, null, new INDArray[] {input.getLabelsMaskArray()});
            MultiDataSet ret = featurize(inbW);
            // NOTE(review): ret.getLabelsMaskArrays()[0] is passed in the features-mask
            // position of DataSet(features, labels, featuresMask, labelsMask) — looks
            // intentional (mask of the featurized output), but worth confirming.
            return new DataSet(ret.getFeatures()[0], input.getLabels(), ret.getLabelsMaskArrays()[0],
                            input.getLabelsMaskArray());
        }
    } else {
        if (input.getFeaturesMaskArray() != null)
            throw new UnsupportedOperationException("Feature masks not supported with featurizing currently");
        // Forward through the frozen layers; activations at frozenInputLayer + 1 become
        // the new features, labels and label mask are carried through unchanged.
        return new DataSet(origMLN.feedForwardToLayer(frozenInputLayer + 1, input.getFeatures(), false)
                        .get(frozenInputLayer + 1), input.getLabels(), null, input.getLabelsMaskArray());
    }
}
 
Example 6
Source Project: Canova   Source File: DrawMnist.java    License: Apache License 2.0 6 votes vote down vote up
public static void drawMnist(DataSet mnist,INDArray reconstruct) throws InterruptedException {
	for(int j = 0; j < mnist.numExamples(); j++) {
		INDArray draw1 = mnist.get(j).getFeatureMatrix().mul(255);
		INDArray reconstructed2 = reconstruct.getRow(j);
		INDArray draw2 = Sampling.binomial(reconstructed2, 1, new MersenneTwister(123)).mul(255);

		DrawReconstruction d = new DrawReconstruction(draw1);
		d.title = "REAL";
		d.draw();
		DrawReconstruction d2 = new DrawReconstruction(draw2,1000,1000);
		d2.title = "TEST";
		
		d2.draw();
		Thread.sleep(1000);
		d.frame.dispose();
		d2.frame.dispose();

	}
}
 
Example 7
Source Project: deeplearning4j   Source File: ModelSerializerTest.java    License: Apache License 2.0 6 votes vote down vote up
@Test
public void testJavaSerde_1() throws Exception {
    // Verifies that a ComputationGraph survives a plain Java serialization round-trip.
    int nIn = 5;
    int nOut = 6;

    ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345).l1(0.01)
            .graphBuilder()
            .addInputs("in")
            .layer("0", new OutputLayer.Builder().nIn(nIn).nOut(nOut).build(), "in")
            .setOutputs("0")
            .validateOutputLayerConfig(false)
            .build();

    ComputationGraph net = new ComputationGraph(conf);
    net.init();

    // NOTE(review): 'norm' is fitted but never serialized or asserted on — appears to
    // be leftover setup; the test only round-trips the network itself.
    DataSet dataSet = trivialDataSet();
    NormalizerStandardize norm = new NormalizerStandardize();
    norm.fit(dataSet);

    val b = SerializationUtils.serialize(net);

    ComputationGraph restored = SerializationUtils.deserialize(b);

    // Relies on ComputationGraph.equals comparing configuration and parameters.
    assertEquals(net, restored);
}
 
Example 8
@Override
public DataSet next() {
    MultiDataSet mds = iterator.next();
    // Conversion is only defined for single-input / single-output MultiDataSets.
    if (mds.getFeatures().length > 1 || mds.getLabels().length > 1)
        throw new UnsupportedOperationException(
                        "This iterator is able to convert MultiDataSet with number of inputs/outputs of 1");

    INDArray f = mds.getFeatures()[0];
    // When no labels are present, the features stand in as labels —
    // presumably for unsupervised/pretraining use; confirm against callers.
    INDArray l = (mds.getLabels() == null) ? f : mds.getLabels()[0];
    INDArray fm = (mds.getFeaturesMaskArrays() == null) ? null : mds.getFeaturesMaskArrays()[0];
    INDArray lm = (mds.getLabelsMaskArrays() == null) ? null : mds.getLabelsMaskArrays()[0];

    DataSet converted = new DataSet(f, l, fm, lm);

    if (preProcessor != null)
        preProcessor.preProcess(converted);

    return converted;
}
 
Example 9
/**
 * Test getDataSetIterator
 *
 * <p>For the mini-MNIST image dataset, checks that the label produced by the iterator
 * (argmax of the one-hot label row) matches each instance's class value, and that the
 * iterator's reported labels equal the labels observed in the metadata.
 */
@Test
public void testGetIterator() throws Exception {
  final Instances metaData = DatasetLoader.loadMiniMnistMeta();
  this.idi.setImagesLocation(new File("datasets/nominal/mnist-minimal"));
  final int batchSize = 1;
  final DataSetIterator it = this.idi.getDataSetIterator(metaData, SEED, batchSize);

  Set<Integer> labels = new HashSet<>();
  for (Instance inst : metaData) {
    // Attribute 1 holds the numeric class label as a string.
    int label = Integer.parseInt(inst.stringValue(1));
    final DataSet next = Utils.getNext(it);
    int itLabel = next.getLabels().argMax().getInt(0);
    Assert.assertEquals(label, itLabel);
    labels.add(label);
  }
  // All ten digit classes must be present, and the iterator must report the same set.
  final List<Integer> collect =
      it.getLabels().stream().map(Integer::valueOf).collect(Collectors.toList());
  Assert.assertEquals(10, labels.size());
  Assert.assertTrue(labels.containsAll(collect));
  Assert.assertTrue(collect.containsAll(labels));
}
 
Example 10
@Test
@Ignore
public void specialRRTest4() throws Exception {
    // Synthetic image reader: 25000 examples, 10 labels, 3-channel 224x224 images,
    // consumed in minibatches of 128.
    RecordReader rr = new SpecialImageRecordReader(25000, 10, 3, 224, 224);
    RecordReaderDataSetIterator rrdsi = new RecordReaderDataSetIterator(rr, 128);

    int cnt = 0;
    int examples = 0;
    while (rrdsi.hasNext()) {
        DataSet ds = rrdsi.next();
        // NOTE(review): 25000 is not a multiple of 128, so the final batch would be
        // smaller and this assert would fail on it — likely why the test is @Ignore'd.
        assertEquals(128, ds.numExamples());
        for (int i = 0; i < ds.numExamples(); i++) {
            // Copy out example i's image tensor (dims 1,2,3 — presumably C,H,W; confirm).
            INDArray example = ds.getFeatures().tensorAlongDimension(i, 1, 2, 3).dup();
            //                assertEquals("Failed on DataSet [" + cnt + "], example [" + i + "]", (double) examples, example.meanNumber().doubleValue(), 0.01);

            //                assertEquals("Failed on DataSet [" + cnt + "], example [" + i + "]", (double) examples, ds.getLabels().getRow(i).meanNumber().doubleValue(), 0.01);
            examples++;
        }
        cnt++;
    }

}
 
Example 11
Source Project: deeplearning4j   Source File: ScoreUtil.java    License: Apache License 2.0 6 votes vote down vote up
/**
 * Score based on the loss function
 * @param model the model to score with
 * @param testData the test data to score
 * @param average whether to average the score
 *                for the whole batch or not
 * @return the score for the given test set; 0.0 when the iterator yields no
 *         examples and averaging is requested (previously this divided 0 by 0
 *         and returned NaN)
 */
public static double score(ComputationGraph model, DataSetIterator testData, boolean average) {
    //TODO: do this properly taking into account division by N, L1/L2 etc
    double sumScore = 0.0;
    int totalExamples = 0;
    while (testData.hasNext()) {
        DataSet ds = testData.next();
        int numExamples = ds.numExamples();

        // Weight each minibatch's score by its size so uneven batches average correctly.
        sumScore += numExamples * model.score(ds);
        totalExamples += numExamples;
    }

    if (!average)
        return sumScore;
    // Guard against an empty iterator: 0.0 / 0 would produce NaN.
    return totalExamples == 0 ? 0.0 : sumScore / totalExamples;
}
 
Example 12
Source Project: deeplearning4j   Source File: DataSetIteratorTest.java    License: Apache License 2.0 5 votes vote down vote up
@Test
public void testLfwIterator() throws Exception {
    // Request a single 28x28x1 example from the LFW dataset and sanity-check shapes.
    int numExamples = 1;
    int rows = 28;
    int cols = 28;
    int channels = 1;

    LFWDataSetIterator iter = new LFWDataSetIterator(numExamples, new int[] {rows, cols, channels}, true);
    assertTrue(iter.hasNext());

    DataSet data = iter.next();
    // Label minibatch size matches the request; feature dim 2 is the row count.
    assertEquals(numExamples, data.getLabels().size(0));
    assertEquals(rows, data.getFeatures().size(2));
}
 
Example 13
Source Project: deeplearning4j   Source File: SplitDataSetsFunction.java    License: Apache License 2.0 5 votes vote down vote up
@Override
public Iterator<DataSet> call(Iterator<DataSet> dataSetIterator) throws Exception {
    // Flatten every incoming (possibly multi-example) DataSet into single-example DataSets.
    List<DataSet> singles = new ArrayList<>();
    while (dataSetIterator.hasNext()) {
        DataSet batch = dataSetIterator.next();
        singles.addAll(batch.asList());
    }
    return singles.iterator();
}
 
Example 14
Source Project: Canova   Source File: LFWLoader.java    License: Apache License 2.0 5 votes vote down vote up
public DataSet convertListPairs(List<DataSet> images) {
    // Stack each example's features and labels row-wise into two matrices.
    int count = images.size();
    INDArray inputs = Nd4j.create(count, numPixelColumns);
    INDArray outputs = Nd4j.create(count, numNames);

    int row = 0;
    for (DataSet image : images) {
        inputs.putRow(row, image.getFeatureMatrix());
        outputs.putRow(row, image.getLabels());
        row++;
    }
    return new DataSet(inputs, outputs);
}
 
Example 15
Source Project: inception   Source File: DL4JSequenceRecommender.java    License: Apache License 2.0 5 votes vote down vote up
/**
 * Trains a sequence-labeling network on the given samples.
 *
 * @param aTrainingData the samples to train on
 * @param aTagset mapping from tag strings to label indices, used during vectorization
 * @return the trained network
 * @throws IOException if vectorization fails
 */
private MultiLayerNetwork train(List<Sample> aTrainingData, Object2IntMap<String> aTagset)
    throws IOException
{
    // Configure the neural network
    MultiLayerNetwork model = createConfiguredNetwork(traits, wordVectors.dimensions());

    final int limit = traits.getTrainingSetSizeLimit();
    final int batchSize = traits.getBatchSize();

    // First vectorizing all sentences and then passing them to the model would consume
    // huge amounts of memory. Thus, every sentence is vectorized and then immediately
    // passed on to the model.
    nextEpoch: for (int epoch = 0; epoch < traits.getnEpochs(); epoch++) {
        int sentNum = 0;
        Iterator<Sample> sampleIterator = aTrainingData.iterator();
        while (sampleIterator.hasNext()) {
            // Fill one minibatch, bounded by both the batch size and the training-set limit.
            List<DataSet> batch = new ArrayList<>();
            while (sampleIterator.hasNext() && batch.size() < batchSize && sentNum < limit) {
                Sample sample = sampleIterator.next();
                DataSet trainingData = vectorize(asList(sample), aTagset, true);
                batch.add(trainingData);
                sentNum++;
            }
            
            model.fit(new ListDataSetIterator<DataSet>(batch, batch.size()));
            log.trace("Epoch {}: processed {} of {} sentences", epoch, sentNum,
                    aTrainingData.size());
            
            // Once the per-epoch sentence limit is hit, skip straight to the next epoch.
            if (sentNum >= limit) {
                continue nextEpoch;
            }
        }
    }

    return model;
}
 
Example 16
@Test
public void when_dataSetIsEmpty_expect_emptyDataSet() {
    // Assemble: a preprocessor with no steps, applied to a dataset with no data.
    DataSet emptyDataSet = new DataSet(null, null);
    CompositeDataSetPreProcessor processor = new CompositeDataSetPreProcessor();

    // Act
    processor.preProcess(emptyDataSet);

    // Assert: preprocessing must leave an empty dataset empty.
    assertTrue(emptyDataSet.isEmpty());
}
 
Example 17
Source Project: deeplearning4j   Source File: IrisUtils.java    License: Apache License 2.0 5 votes vote down vote up
/**
 * Loads rows [from, to) of the Iris dataset as one-example DataSets,
 * downloading the data file on first use.
 *
 * @param from first row index (inclusive) in the raw CSV
 * @param to   last row index (exclusive)
 * @return one DataSet per row, with 4 features and a 3-class one-hot label
 * @throws IOException if the file cannot be downloaded or read
 */
public static List<DataSet> loadIris(int from, int to) throws IOException {
    File rootDir = DL4JResources.getDirectory(ResourceType.DATASET, "iris");
    File irisData = new File(rootDir, "iris.dat");
    if(!irisData.exists()){
        URL url = DL4JResources.getURL(IRIS_RELATIVE_URL);
        Downloader.download("Iris", url, irisData, MD5, 3);
    }

    @SuppressWarnings("unchecked")
    List<String> lines;
    try(InputStream is = new FileInputStream(irisData)){
        lines = IOUtils.readLines(is);
    }
    List<DataSet> list = new ArrayList<>();
    // Feature matrix has one row per requested example; outcomes is indexed by the
    // absolute row index i (only slots [from, to) are ever filled).
    INDArray ret = Nd4j.ones(Math.abs(to - from), 4);
    double[][] outcomes = new double[lines.size()][3];
    int putCount = 0;

    for (int i = from; i < to; i++) {
        String line = lines.get(i);
        String[] split = line.split(",");

        addRow(ret, putCount++, split);

        // Last CSV column is the class id (0-2); expand to a one-hot vector.
        String outcome = split[split.length - 1];
        double[] rowOutcome = new double[3];
        rowOutcome[Integer.parseInt(outcome)] = 1;
        outcomes[i] = rowOutcome;
    }

    // Pair feature row i with outcomes[from + i] — the same absolute index used above.
    for (int i = 0; i < ret.rows(); i++) {
        DataSet add = new DataSet(ret.getRow(i, true), Nd4j.create(outcomes[from + i], new long[]{1,3}));
        list.add(add);
    }
    return list;
}
 
Example 18
/**
     * Basically all we want from this test - being able to finish without exceptions.
     */
    @Test
    public void testIterator1() throws Exception {

        File inputFile = Resources.asFile("big/raw_sentences.txt");
        SentenceIterator iter = ParagraphVectorsTest.getIterator(isIntegrationTests(), inputFile);
//        SentenceIterator iter = new BasicLineIterator(inputFile.getAbsolutePath());

        TokenizerFactory t = new DefaultTokenizerFactory();
        t.setTokenPreProcessor(new CommonPreprocessor());

        Word2Vec vec = new Word2Vec.Builder().minWordFrequency(10) // we make sure we'll have some missing words
                        .iterations(1).learningRate(0.025).layerSize(150).seed(42).sampling(0).negativeSample(0)
                        .useHierarchicSoftmax(true).windowSize(5).modelUtils(new BasicModelUtils<VocabWord>())
                        .useAdaGrad(false).iterate(iter).workers(8).tokenizerFactory(t)
                        .elementsLearningAlgorithm(new CBOW<VocabWord>()).build();

        vec.fit();

        List<String> labels = new ArrayList<>();
        labels.add("positive");
        labels.add("negative");

        // Every minibatch must have the same feature shape as the first one.
        Word2VecDataSetIterator iterator = new Word2VecDataSetIterator(vec, getLASI(iter, labels), labels, 1);
        INDArray array = iterator.next().getFeatures();
        int count = 0;
        while (iterator.hasNext()) {
            DataSet ds = iterator.next();

            assertArrayEquals(array.shape(), ds.getFeatures().shape());

            if(!isIntegrationTests() && count++ > 20)
                break;  //raw_sentences.txt is 2.81 MB, takes quite some time to process. We'll only first 20 minibatches when doing unit tests
        }
    }
 
Example 19
Source Project: DataVec   Source File: SingleCSVRecordTest.java    License: Apache License 2.0 5 votes vote down vote up
@Test
public void testVectorRegression() {
    // Two examples with two features each; labels are all ones.
    INDArray features = Nd4j.create(2, 2);
    INDArray labels = Nd4j.create(new double[][] {{1, 1}, {1, 1}});
    DataSet dataSet = new DataSet(features, labels);

    // The first row flattens to 2 feature values + 2 label values = 4 CSV columns.
    SingleCSVRecord singleCsvRecord = SingleCSVRecord.fromRow(dataSet.get(0));
    assertEquals(4, singleCsvRecord.getValues().size());
}
 
Example 20
Source Project: deeplearning4j   Source File: SamplingDataSetIterator.java    License: Apache License 2.0 5 votes vote down vote up
/**
 * @param sampleFrom         the dataset to sample from
 * @param batchSize          the batch size to sample
 * @param totalNumberSamples the sample size
 * @param replace            whether sampling is done with replacement
 *                           (stored as-is; consumed by the sampling logic elsewhere)
 */
public SamplingDataSetIterator(DataSet sampleFrom, int batchSize, int totalNumberSamples, boolean replace) {
    super();
    this.sampleFrom = sampleFrom;
    this.batchSize = batchSize;
    this.totalNumberSamples = totalNumberSamples;
    this.replace = replace;
}
 
Example 21
/**
 * Test getDataSetIterator
 *
 * <p>For the mini-MNIST ARFF dataset, checks label correctness (argmax of the one-hot
 * label row equals the instance's class value), verifies every feature value survives
 * the conversion, and checks the iterator's reported labels equal those observed.
 */
@Test
public void testGetIterator() throws Exception {
  final int batchSize = 1;
  final DataSetIterator it = this.cii.getDataSetIterator(mnistMiniArff, SEED, batchSize);

  Set<Integer> labels = new HashSet<>();
  for (int i = 0; i < mnistMiniArff.size(); i++) {
    Instance inst = mnistMiniArff.get(i);
    // Last attribute is the class; parse its numeric string value.
    int instLabel = Integer.parseInt(inst.stringValue(inst.numAttributes() - 1));
    final DataSet next = Utils.getNext(it);
    int dsLabel = next.getLabels().argMax().getInt(0);
    Assert.assertEquals(instLabel, dsLabel);
    labels.add(instLabel);

    // Flatten the features to one row so attributes line up positionally.
    INDArray reshaped = next.getFeatures().reshape(1, inst.numAttributes() - 1);

    // Compare each attribute value
    for (int j = 0; j < inst.numAttributes() - 1; j++) {
      double instVal = inst.value(j);
      double dsVal = reshaped.getDouble(j);
      Assert.assertEquals(instVal, dsVal, 10e-8);
    }
  }

  // All ten digit classes must appear, and the iterator must report the same set.
  final List<Integer> collect =
      it.getLabels().stream().map(Integer::valueOf).collect(Collectors.toList());
  Assert.assertEquals(10, labels.size());
  Assert.assertTrue(labels.containsAll(collect));
  Assert.assertTrue(collect.containsAll(labels));
}
 
Example 22
Source Project: deeplearning4j   Source File: FileSplitDataSetIterator.java    License: Apache License 2.0 5 votes vote down vote up
@Override
public DataSet next() {
    // Load the dataset for the next file; the callback may legitimately return null.
    DataSet loaded = callback.call(files.get(counter.getAndIncrement()));

    if (loaded != null && preProcessor != null)
        preProcessor.preProcess(loaded);

    return loaded;
}
 
Example 23
Source Project: deeplearning4j   Source File: MLLibUtil.java    License: Apache License 2.0 5 votes vote down vote up
/**
 * Converts a continuous JavaRDD LabeledPoint to a JavaRDD DataSet.
 * @param data JavaRdd LabeledPoint
 * @param preCache boolean pre-cache rdd before operation
 * @return the converted RDD of DataSets
 */
public static JavaRDD<DataSet> fromContinuousLabeledPoint(JavaRDD<LabeledPoint> data, boolean preCache) {
    // Only cache when requested and not already memory-resident.
    if (preCache && !data.getStorageLevel().useMemory()) {
        data.cache();
    }
    // Anonymous (serializable) Spark Function — one DataSet per LabeledPoint.
    return data.map(new Function<LabeledPoint, DataSet>() {
        @Override
        public DataSet call(LabeledPoint lp) {
            return convertToDataset(lp);
        }
    });
}
 
Example 24
Source Project: deeplearning4j   Source File: BaseDataFetcher.java    License: Apache License 2.0 5 votes vote down vote up
/**
 * Initializes this data transform fetcher from the passed in datasets
 *
 * @param examples the examples to use
 */
protected void initializeCurrFromList(List<DataSet> examples) {

    if (examples.isEmpty())
        log.warn("Warning: empty dataset from the fetcher");

    INDArray inputs = createInputMatrix(examples.size());
    INDArray labels = createOutputMatrix(examples.size());
    for (int i = 0; i < examples.size(); i++) {
        inputs.putRow(i, examples.get(i).getFeatures());
        labels.putRow(i, examples.get(i).getLabels());
    }
    curr = new DataSet(inputs, labels);

}
 
Example 25
Source Project: deeplearning4j   Source File: MultiLayerTest.java    License: Apache License 2.0 5 votes vote down vote up
@Test
public void testOutput() throws Exception {
    Nd4j.getRandom().setSeed(12345);
    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                    .weightInit(WeightInit.XAVIER).seed(12345L).list()
                    .layer(0, new DenseLayer.Builder().nIn(784).nOut(50).activation(Activation.RELU).build())
                    .layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                                    .activation(Activation.SOFTMAX).nIn(50).nOut(10).build())
                    .setInputType(InputType.convolutional(28, 28, 1)).build();

    MultiLayerNetwork net = new MultiLayerNetwork(conf);
    net.init();

    DataSetIterator fullData = new MnistDataSetIterator(1, 2);
    net.fit(fullData);


    fullData.reset();
    DataSet expectedSet = fullData.next(2);
    INDArray expectedOut = net.output(expectedSet.getFeatures(), false);

    fullData.reset();

    INDArray actualOut = net.output(fullData);

    assertEquals(expectedOut, actualOut);
}
 
Example 26
Source Project: deeplearning4j   Source File: SingletonDataSetIterator.java    License: Apache License 2.0 5 votes vote down vote up
@Override
public DataSet next() {
    if (!hasNext) {
        throw new NoSuchElementException("No elements remaining");
    }
    hasNext = false;
    if (preProcessor != null && !preprocessed) {
        preProcessor.preProcess(dataSet);
        preprocessed = true;
    }
    return dataSet;
}
 
Example 27
Source Project: deeplearning4j   Source File: SparkADSI.java    License: Apache License 2.0 5 votes vote down vote up
public SparkADSI(DataSetIterator iterator, int queueSize, BlockingQueue<DataSet> queue, boolean useWorkspace,
                DataSetCallback callback, Integer deviceId) {
    this();

    if (queueSize < 2)
        queueSize = 2;

    this.deviceId = deviceId;
    this.callback = callback;
    this.useWorkspace = useWorkspace;
    this.buffer = queue;
    this.prefetchSize = queueSize;
    this.backedIterator = iterator;
    this.workspaceId = "SADSI_ITER-" + java.util.UUID.randomUUID().toString();

    if (iterator.resetSupported())
        this.backedIterator.reset();

    context = TaskContext.get();

    this.thread = new SparkPrefetchThread(buffer, iterator, terminator, null, Nd4j.getAffinityManager().getDeviceForCurrentThread());

    /**
     * We want to ensure, that background thread will have the same thread->device affinity, as master thread
     */

    thread.setDaemon(true);
    thread.start();
}
 
Example 28
Source Project: deeplearning4j   Source File: MLLibUtil.java    License: Apache License 2.0 5 votes vote down vote up
/**
 * Convert an rdd of data set in to labeled point.
 * @param data the dataset to convert
 * @param preCache boolean pre-cache rdd before operation
 * @return an rdd of labeled point
 */
public static JavaRDD<LabeledPoint> fromDataSet(JavaRDD<DataSet> data, boolean preCache) {
    if (preCache && !data.getStorageLevel().useMemory()) {
        data.cache();
    }
    return data.map(new Function<DataSet, LabeledPoint>() {
        @Override
        public LabeledPoint call(DataSet dataSet) {
            return toLabeledPoint(dataSet);
        }
    });
}
 
Example 29
Source Project: deeplearning4j   Source File: EvaluativeListener.java    License: Apache License 2.0 5 votes vote down vote up
public EvaluativeListener(@NonNull DataSet dataSet, int frequency, @NonNull InvocationType type,
                IEvaluation... evaluations) {
    this.ds = dataSet;
    this.frequency = frequency;
    this.evaluations = evaluations;

    this.invocationType = type;
}
 
Example 30
Source Project: deeplearning4j   Source File: TestMasking.java    License: Apache License 2.0 5 votes vote down vote up
@Test
public void testCompGraphEvalWithMask() {
    int minibatch = 3;
    int layerSize = 6;
    int nIn = 5;
    int nOut = 4;

    ComputationGraphConfiguration conf2 = new NeuralNetConfiguration.Builder().updater(new NoOp())
                    .dist(new NormalDistribution(0, 1)).seed(12345)
                    .graphBuilder().addInputs("in")
                    .addLayer("0", new DenseLayer.Builder().nIn(nIn).nOut(layerSize).activation(Activation.TANH)
                                    .build(), "in")
                    .addLayer("1", new OutputLayer.Builder().nIn(layerSize).nOut(nOut)
                                    .lossFunction(LossFunctions.LossFunction.XENT).activation(Activation.SIGMOID)
                                    .build(), "0")
                    .setOutputs("1").build();

    ComputationGraph graph = new ComputationGraph(conf2);
    graph.init();

    INDArray f = Nd4j.create(minibatch, nIn);
    INDArray l = Nd4j.create(minibatch, nOut);
    INDArray lMask = Nd4j.ones(minibatch, nOut);

    DataSet ds = new DataSet(f, l, null, lMask);
    DataSetIterator iter = new ExistingDataSetIterator(Collections.singletonList(ds).iterator());

    EvaluationBinary eb = new EvaluationBinary();
    graph.doEvaluation(iter, eb);
}