org.nd4j.linalg.dataset.api.DataSet Java Examples

The following examples show how to use org.nd4j.linalg.dataset.api.DataSet. You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may check out the related API usage on the sidebar.
Example #1
Source File: ImageFlatteningDataSetPreProcessor.java    From deeplearning4j with Apache License 2.0 6 votes vote down vote up
/**
 * Flattens image-style feature arrays into a 2D matrix of shape [miniBatch, product of the other dims].
 *
 * <p>Rank 2 features are left untouched. Features of rank 3 or higher (typically rank 4 CNN
 * activations [miniBatch, depthOut, outH, outW]) are reshaped in 'c' (row-major) order so that every
 * dimension after the first is collapsed into one.
 *
 * @param toPreProcess the DataSet whose features are replaced by the flattened array
 * @throws IllegalStateException if the features have rank 0 or 1
 */
@Override
public void preProcess(DataSet toPreProcess) {
    INDArray input = toPreProcess.getFeatures();
    if (input.rank() == 2) {
        return; //No op: should usually never happen in a properly configured data pipeline
    }
    if (input.rank() < 2) {
        // Previously this surfaced as an ArrayIndexOutOfBoundsException from shape indexing
        throw new IllegalStateException("Cannot flatten features of rank " + input.rank()
                + ": expected rank 2 (no-op) or higher");
    }

    //We require input to be in c order. But c order (as declared in array order) isn't enough; also need strides to be correct
    if (input.ordering() != 'c' || !Shape.strideDescendingCAscendingF(input)) {
        input = input.dup('c');
    }

    long[] inShape = input.shape(); //e.g. [miniBatch,depthOut,outH,outW] for CNN activations
    long flattenedSize = 1;
    for (int d = 1; d < inShape.length; d++) {
        flattenedSize *= inShape[d];
    }

    INDArray reshaped = input.reshape('c', new long[] {inShape[0], flattenedSize});
    toPreProcess.setFeatures(reshaped);
}
 
Example #2
Source File: ImageFlatteningDataSetPreProcessor.java    From nd4j with Apache License 2.0 6 votes vote down vote up
/**
 * Flattens image-style feature arrays into a 2D matrix of shape [miniBatch, product of the other dims].
 *
 * <p>Rank 2 features are left untouched. Features of rank 3 or higher (typically rank 4 CNN
 * activations [miniBatch, depthOut, outH, outW]) are reshaped in 'c' (row-major) order so that every
 * dimension after the first is collapsed into one.
 *
 * @param toPreProcess the DataSet whose features are replaced by the flattened array
 * @throws IllegalStateException if the features have rank 0 or 1
 */
@Override
public void preProcess(DataSet toPreProcess) {
    INDArray input = toPreProcess.getFeatures();
    if (input.rank() == 2) {
        return; //No op: should usually never happen in a properly configured data pipeline
    }
    if (input.rank() < 2) {
        // Previously this surfaced as an ArrayIndexOutOfBoundsException from shape indexing
        throw new IllegalStateException("Cannot flatten features of rank " + input.rank()
                + ": expected rank 2 (no-op) or higher");
    }

    //We require input to be in c order. But c order (as declared in array order) isn't enough; also need strides to be correct
    if (input.ordering() != 'c' || !Shape.strideDescendingCAscendingF(input)) {
        input = input.dup('c');
    }

    long[] inShape = input.shape(); //e.g. [miniBatch,depthOut,outH,outW] for CNN activations
    long flattenedSize = 1;
    for (int d = 1; d < inShape.length; d++) {
        flattenedSize *= inShape[d];
    }

    INDArray reshaped = input.reshape('c', new long[] {inShape[0], flattenedSize});
    toPreProcess.setFeatures(reshaped);
}
 
Example #3
Source File: TestRecordReaders.java    From deeplearning4j with Apache License 2.0 6 votes vote down vote up
/**
 * Class index 2 appears in the label column while the iterator is built with 2 possible labels
 * (constructor args presumably batchSize=2, labelIndex=1, numPossibleLabels=2). Fetching the batch
 * must fail with an error mentioning the one-hot conversion.
 */
@Test
public void testClassIndexOutsideOfRangeRRDSI() {
    Collection<Collection<Writable>> c = new ArrayList<>();
    c.add(Arrays.<Writable>asList(new DoubleWritable(0.5), new IntWritable(0)));
    c.add(Arrays.<Writable>asList(new DoubleWritable(1.0), new IntWritable(2)));

    CollectionRecordReader crr = new CollectionRecordReader(c);

    RecordReaderDataSetIterator iter = new RecordReaderDataSetIterator(crr, 2, 1, 2);

    try {
        iter.next();
        fail("Expected exception");
    } catch (Exception e) {
        // Null-safe: an exception without a message should fail the assertion, not throw an NPE here
        String message = e.getMessage();
        assertTrue(String.valueOf(message), message != null && message.contains("to one-hot"));
    }
}
 
Example #4
Source File: ParameterServerTrainer.java    From deeplearning4j with Apache License 2.0 6 votes vote down vote up
/**
 * Fits the current model on the given DataSet, then pushes the model's updated parameters to the
 * parameter server.
 *
 * @param dataSet the minibatch to train on
 * @param time    not used by this implementation
 * @throws IllegalStateException if the model is neither a ComputationGraph nor a MultiLayerNetwork
 */
@Override
public void feedDataSet(@NonNull DataSet dataSet, long time) {
    // FIXME: this is wrong, and should be fixed. Training should happen within run() loop

    if (getModel() instanceof ComputationGraph) {
        ComputationGraph computationGraph = (ComputationGraph) getModel();
        computationGraph.fit(dataSet);
    } else if (getModel() instanceof MultiLayerNetwork) {
        MultiLayerNetwork multiLayerNetwork = (MultiLayerNetwork) getModel();
        log.info("Calling fit on multi layer network");
        multiLayerNetwork.fit(dataSet);
    } else {
        // Previously an unconditional cast produced an opaque ClassCastException (or NPE for null)
        throw new IllegalStateException("Unsupported model type for parameter server training: "
                + (getModel() == null ? "null" : getModel().getClass().getName()));
    }

    log.info("About to send params in");
    //send the updated params
    parameterServerClient.pushNDArray(getModel().params());
    log.info("Sent params");
}
 
Example #5
Source File: DoubleDQNTest.java    From deeplearning4j with Apache License 2.0 6 votes vote down vote up
/**
 * For a single terminal transition, the label for the taken action (0) must equal the reward alone.
 * The other action keeps the value 2.2 taken from the observation.
 */
@Test
public void when_isTerminal_expect_rewardValueAtIdx0() {

    // Assemble
    // The target network mock returns its input unchanged
    when(targetQNetworkMock.output(any(INDArray.class))).thenAnswer(i -> i.getArguments()[0]);

    // Plain list instead of double-brace initialization (which creates an anonymous ArrayList
    // subclass holding a hidden reference to the test instance)
    List<Transition<Integer>> transitions = new ArrayList<>();
    transitions.add(builtTransition(buildObservation(new double[]{1.1, 2.2}),
            0, 1.0, true, buildObservation(new double[]{11.0, 22.0})));

    DoubleDQN sut = new DoubleDQN(qNetworkMock, targetQNetworkMock, 0.5);

    // Act
    DataSet result = sut.compute(transitions);

    // Assert
    INDArray evaluatedQValues = result.getLabels();
    assertEquals(1.0, evaluatedQValues.getDouble(0, 0), 0.0001);
    assertEquals(2.2, evaluatedQValues.getDouble(0, 1), 0.0001);
}
 
Example #6
Source File: ComputationGraphUtil.java    From deeplearning4j with Apache License 2.0 6 votes vote down vote up
/**
 * Converts a DataSet into the equivalent single-input, single-output MultiDataSet.
 *
 * <p>Each non-null array (features, labels, features mask, labels mask) is wrapped in a one-element
 * array; null arrays stay null. The example metadata is carried over unchanged.
 *
 * @param dataSet the DataSet to convert
 * @return a MultiDataSet holding the same arrays and metadata
 */
public static MultiDataSet toMultiDataSet(DataSet dataSet) {
    INDArray features = dataSet.getFeatures();
    INDArray labels = dataSet.getLabels();
    INDArray featuresMask = dataSet.getFeaturesMaskArray();
    INDArray labelsMask = dataSet.getLabelsMaskArray();
    List<Serializable> metaData = dataSet.getExampleMetaData();

    org.nd4j.linalg.dataset.MultiDataSet converted = new org.nd4j.linalg.dataset.MultiDataSet(
            features == null ? null : new INDArray[] {features},
            labels == null ? null : new INDArray[] {labels},
            featuresMask == null ? null : new INDArray[] {featuresMask},
            labelsMask == null ? null : new INDArray[] {labelsMask});
    converted.setExampleMetaData(metaData);
    return converted;
}
 
Example #7
Source File: RecordConverter.java    From deeplearning4j with Apache License 2.0 6 votes vote down vote up
/**
 * Builds one record per example. Each record is the list produced by toRecord for the example's
 * feature row, followed by every value of its label row as a DoubleWritable.
 *
 * @param dataSet the regression DataSet to convert
 * @return one list of writables per example, in example order
 */
private static List<List<Writable>> getRegressionWritableMatrix(DataSet dataSet) {
    INDArray features = dataSet.getFeatures();
    INDArray labels = dataSet.getLabels();
    int numExamples = dataSet.numExamples();
    List<List<Writable>> writableMatrix = new ArrayList<>(numExamples);

    for (int i = 0; i < numExamples; i++) {
        List<Writable> writables = toRecord(features.getRow(i));
        INDArray labelRow = labels.getRow(i);

        // length() works whether getRow returns a rank-1 vector or a [1, n] row vector;
        // shape()[1] throws ArrayIndexOutOfBoundsException in the rank-1 case
        for (long j = 0; j < labelRow.length(); j++) {
            writables.add(new DoubleWritable(labelRow.getDouble(j)));
        }

        writableMatrix.add(writables);
    }

    return writableMatrix;
}
 
Example #8
Source File: CropAndResizeDataSetPreProcessor.java    From deeplearning4j with Apache License 2.0 6 votes vote down vote up
/**
 * Replaces the DataSet's features with the output of the native "crop_and_resize" op. The op is
 * configured from this preprocessor's boxes, indices, resize, method and resizedShape fields.
 * Empty DataSets are left untouched.
 *
 * NOTE: The data format must be NHWC
 */
@Override
public void preProcess(DataSet dataSet) {
    Preconditions.checkNotNull(dataSet, "Encountered null dataSet");
    if (dataSet.isEmpty()) {
        return;
    }

    INDArray features = dataSet.getFeatures();
    // Second argument false: presumably allocates without zero-initialization, since the op writes the output
    INDArray resized = Nd4j.create(LongShapeDescriptor.fromShape(resizedShape, features.dataType()), false);

    CustomOp cropAndResize = DynamicCustomOp.builder("crop_and_resize")
            .addInputs(features, boxes, indices, resize)
            .addIntegerArguments(method)
            .addOutputs(resized)
            .build();
    Nd4j.getExecutioner().exec(cropAndResize);

    dataSet.setFeatures(resized);
}
 
Example #9
Source File: CompositeDataSetPreProcessor.java    From deeplearning4j with Apache License 2.0 6 votes vote down vote up
/**
 * Runs every configured preprocessor on the DataSet, in order. When stopOnEmptyDataSet is set,
 * nothing further is applied once the DataSet is, or becomes, empty.
 *
 * @param dataSet the DataSet to preprocess; must not be null
 */
@Override
public void preProcess(DataSet dataSet) {
    Preconditions.checkNotNull(dataSet, "Encountered null dataSet");
    if (stopOnEmptyDataSet && dataSet.isEmpty()) {
        return;
    }

    for (DataSetPreProcessor preProcessor : preProcessors) {
        preProcessor.preProcess(dataSet);
        boolean becameEmpty = stopOnEmptyDataSet && dataSet.isEmpty();
        if (becameEmpty) {
            break;
        }
    }
}
 
Example #10
Source File: DefaultCallback.java    From deeplearning4j with Apache License 2.0 6 votes vote down vote up
/**
 * Asks the ND4J affinity manager to place each non-null array of the DataSet on the DEVICE location.
 * The arrays are the features, labels, features mask and labels mask. A null DataSet is ignored.
 *
 * @param dataSet the DataSet whose arrays should be relocated; may be null
 */
@Override
public void call(DataSet dataSet) {
    if (dataSet == null) {
        return;
    }

    INDArray[] arrays = {dataSet.getFeatures(), dataSet.getLabels(),
            dataSet.getFeaturesMaskArray(), dataSet.getLabelsMaskArray()};
    for (INDArray array : arrays) {
        if (array != null) {
            Nd4j.getAffinityManager().ensureLocation(array, AffinityManager.Location.DEVICE);
        }
    }
}
 
Example #11
Source File: RecordConverter.java    From DataVec with Apache License 2.0 5 votes vote down vote up
/**
 * Convert a DataSet to a matrix of records. The classification or regression layout is chosen by
 * isClassificationDataSet.
 *
 * @param dataSet the DataSet to convert
 * @return the matrix for the records
 */
public static List<List<Writable>> toRecords(DataSet dataSet) {
    return isClassificationDataSet(dataSet)
            ? getClassificationWritableMatrix(dataSet)
            : getRegressionWritableMatrix(dataSet);
}
 
Example #12
Source File: UnderSamplingByMaskingPreProcessor.java    From nd4j with Apache License 2.0 5 votes vote down vote up
/**
 * Replaces the label mask with the one returned by adjustMasks. adjustMasks receives the current
 * labels and label mask plus this preprocessor's minorityLabel and targetMinorityDist.
 */
@Override
public void preProcess(DataSet toPreProcess) {
    INDArray adjustedMask = adjustMasks(toPreProcess.getLabels(), toPreProcess.getLabelsMaskArray(),
            minorityLabel, targetMinorityDist);
    toPreProcess.setLabelsMaskArray(adjustedMask);
}
 
Example #13
Source File: JointParallelDataSetIteratorTest.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
/**
 * Testing relocate. Two sources of unequal length are joined with RELOCATE handling.
 * The test expects 300 datasets in total. For the first 200, the expected value advances every
 * second dataset; after that it advances on every dataset.
 *
 * @throws Exception
 */
@Test
public void testJointIterator3() throws Exception {
    DataSetIterator iteratorA = new SimpleVariableGenerator(119, 200, 32, 100, 10);
    DataSetIterator iteratorB = new SimpleVariableGenerator(119, 100, 32, 100, 10);

    JointParallelDataSetIterator jpdsi = new JointParallelDataSetIterator.Builder(InequalityHandling.RELOCATE)
                    .addSourceIterator(iteratorA).addSourceIterator(iteratorB).build();

    int iteration = 0;
    int expectedValue = 0;
    while (jpdsi.hasNext()) {
        DataSet ds = jpdsi.next();
        String message = "Failed on iteration " + iteration;
        assertNotNull(message, ds);

        assertEquals(message, (double) expectedValue, ds.getFeatures().meanNumber().doubleValue(), 0.001);
        assertEquals(message, (double) expectedValue + 0.5, ds.getLabels().meanNumber().doubleValue(), 0.001);

        iteration++;
        // Past 200 only the longer source remains, so every dataset advances the expected value
        if (iteration >= 200 || iteration % 2 == 0) {
            expectedValue++;
        }
    }

    assertEquals(300, iteration);
    assertEquals(200, expectedValue);
}
 
Example #14
Source File: EvaluationToolsTests.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
/**
 * Trains a small network on Iris and renders a multi-class ROC chart to HTML, for both 20
 * threshold steps and 0 steps. The generated HTML must be non-empty.
 */
@Test
public void testRocMultiToHtml() throws Exception {
    DataSetIterator iter = new IrisDataSetIterator(150, 150);

    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().weightInit(WeightInit.XAVIER).list()
                    .layer(0, new DenseLayer.Builder().nIn(4).nOut(4).activation(Activation.TANH).build()).layer(1,
                                    new OutputLayer.Builder().nIn(4).nOut(3).activation(Activation.SOFTMAX)
                                                    .lossFunction(LossFunctions.LossFunction.MCXENT).build())
                    .build();
    MultiLayerNetwork net = new MultiLayerNetwork(conf);
    net.init();

    NormalizerStandardize ns = new NormalizerStandardize();
    DataSet ds = iter.next();
    ns.fit(ds);
    ns.transform(ds);

    for (int i = 0; i < 30; i++) {
        net.fit(ds);
    }

    // iter is not read again after the single next() above, so the former per-loop iter.reset()
    // had no effect and is dropped. The network output is the same for both ROC variants.
    INDArray f = ds.getFeatures();
    INDArray l = ds.getLabels();
    INDArray out = net.output(f);

    for (int numSteps : new int[] {20, 0}) {
        ROCMultiClass roc = new ROCMultiClass(numSteps);
        roc.eval(l, out);

        String str = EvaluationTools.rocChartToHtml(roc, Arrays.asList("setosa", "versicolor", "virginica"));
        // Previously the result was discarded, so the test could never fail on bad output
        if (str == null || str.isEmpty()) {
            throw new AssertionError("Expected non-empty ROC HTML for numSteps=" + numSteps);
        }
    }
}
 
Example #15
Source File: JointParallelDataSetIteratorTest.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
/**
 * Simple test, checking datasets alignment: both sources must yield the same data for the same cycle.
 * Two equally long sources are joined with STOP_EVERYONE handling. The test expects 200 datasets,
 * with the expected value advancing after every pair.
 *
 * @throws Exception
 */
@Test
public void testJointIterator1() throws Exception {
    DataSetIterator iteratorA = new SimpleVariableGenerator(119, 100, 32, 100, 10);
    DataSetIterator iteratorB = new SimpleVariableGenerator(119, 100, 32, 100, 10);

    JointParallelDataSetIterator jpdsi = new JointParallelDataSetIterator.Builder(InequalityHandling.STOP_EVERYONE)
                    .addSourceIterator(iteratorA).addSourceIterator(iteratorB).build();

    int iteration = 0;
    int expectedValue = 0;
    while (jpdsi.hasNext()) {
        DataSet ds = jpdsi.next();
        String message = "Failed on iteration " + iteration;
        assertNotNull(message, ds);

        assertEquals(message, (double) expectedValue, ds.getFeatures().meanNumber().doubleValue(), 0.001);
        assertEquals(message, (double) expectedValue + 0.5, ds.getLabels().meanNumber().doubleValue(), 0.001);

        iteration++;
        if (iteration % 2 == 0) {
            expectedValue++;
        }
    }

    assertEquals(100, expectedValue);
    assertEquals(200, iteration);
}
 
Example #16
Source File: RPTreeTest.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
/**
 * Builds an RPForest on a standardized MNIST batch and checks that no leaf holds more indices than
 * the forest's configured maximum size.
 */
@Test
public void testRpTreeMaxNodes() throws Exception {
    DataSetIterator mnist = new MnistDataSetIterator(150,150);
    RPForest rpForest = new RPForest(4,4,"euclidean");
    DataSet d = mnist.next();

    // The normalizer used to be fit but never applied, which left dead work and unnormalized data
    NormalizerStandardize normalizerStandardize = new NormalizerStandardize();
    normalizerStandardize.fit(d);
    normalizerStandardize.transform(d);

    rpForest.fit(d.getFeatures());
    for (RPTree tree : rpForest.getTrees()) {
        for (RPNode node : tree.getLeaves()) {
            assertTrue(node.getIndices().size() <= rpForest.getMaxSize());
        }
    }
}
 
Example #17
Source File: DummyBlockDataSetIteratorTests.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
/**
 * DummyBlockDataSetIterator must hand out the wrapped generator's datasets in blocks of at most 3
 * (expected block sizes 3, 3, 2), preserving order. Dataset e must carry feature value e and label
 * value e + 0.5.
 */
@Test
public void testBlock_1() throws Exception {
    val simpleIterator = new SimpleVariableGenerator(123, 8, 3, 3, 3);

    val iterator = new DummyBlockDataSetIterator(simpleIterator);

    assertTrue(iterator.hasAnything());
    val list = new ArrayList<DataSet>(8);

    // Request blocks of 3; the final block is short because only 8 datasets are available
    int[] expectedBlockSizes = {3, 3, 2};
    for (int expectedSize : expectedBlockSizes) {
        DataSet[] datasets = iterator.next(3);
        assertNotNull(datasets);
        assertEquals(expectedSize, datasets.length);
        list.addAll(Arrays.asList(datasets));
    }

    for (int e = 0; e < list.size(); e++) {
        val dataset = list.get(e);

        // Tolerance comparison instead of an (int) cast, which truncated e.g. 2.9999 to 2
        assertEquals((double) e, dataset.getFeatures().getDouble(0), 1e-3);
        assertEquals(e + 0.5, dataset.getLabels().getDouble(0), 1e-3);
    }
}
 
Example #18
Source File: AbstractDataSetNormalizer.java    From nd4j with Apache License 2.0 5 votes vote down vote up
/**
 * Fit a dataset (only compute based on the statistics from this dataset).
 * Feature statistics are always rebuilt. Label statistics are rebuilt only when isFitLabel() is
 * true; otherwise the existing labelStats are left as they were.
 *
 * @param dataSet the dataset to compute on
 */
@Override
public void fit(DataSet dataSet) {
    S freshFeatureStats = (S) newBuilder().addFeatures(dataSet).build();
    featureStats = freshFeatureStats;
    if (!isFitLabel()) {
        return;
    }
    labelStats = (S) newBuilder().addLabels(dataSet).build();
}
 
Example #19
Source File: DummyBlockDataSetIterator.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
/**
 * Pulls up to maxDatasets datasets from the wrapped iterator.
 *
 * @param maxDatasets upper bound on the number of datasets returned
 * @return the datasets fetched, possibly fewer than requested (or none) if the iterator ran out
 */
@Override
public DataSet[] next(int maxDatasets) {
    val collected = new ArrayList<DataSet>(maxDatasets);
    for (int taken = 0; iterator.hasNext() && taken < maxDatasets; taken++) {
        collected.add(iterator.next());
    }
    return collected.toArray(new DataSet[collected.size()]);
}
 
Example #20
Source File: LossLayer.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
/**
 * Return predicted label names
 *
 * @param dataSet to predict
 * @return the label name of each predicted class index from predict(features), in the same order
 */
@Override
public List<String> predict(DataSet dataSet) {
    int[] intRet = predict(dataSet.getFeatures());
    List<String> ret = new ArrayList<>(intRet.length);
    for (int i : intRet) {
        // Append: i is a predicted class index, not a position in the result list.
        // The former ret.add(i, ...) threw IndexOutOfBoundsException whenever i > ret.size()
        ret.add(dataSet.getLabelName(i));
    }
    return ret;
}
 
Example #21
Source File: StandardDQNTest.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
/**
 * With three transitions (two non-terminal, one terminal), every sample's label row must be
 * evaluated. The taken action gets reward + 0.5 * (largest next-observation value) when
 * non-terminal, and the reward alone when terminal. The 0.5 is the value passed to the
 * constructor, presumably the discount factor. The untaken action keeps its observation value.
 */
@Test
public void when_batchHasMoreThanOne_expect_everySampleEvaluated() {

    // Assemble
    // Plain list instead of double-brace initialization (which creates an anonymous ArrayList
    // subclass holding a hidden reference to the test instance)
    List<Transition<Integer>> transitions = new ArrayList<>();
    transitions.add(buildTransition(buildObservation(new double[]{1.1, 2.2}),
            0, 1.0, false, buildObservation(new double[]{11.0, 22.0})));
    transitions.add(buildTransition(buildObservation(new double[]{3.3, 4.4}),
            1, 2.0, false, buildObservation(new double[]{33.0, 44.0})));
    transitions.add(buildTransition(buildObservation(new double[]{5.5, 6.6}),
            0, 3.0, true, buildObservation(new double[]{55.0, 66.0})));

    StandardDQN sut = new StandardDQN(qNetworkMock, targetQNetworkMock, 0.5);

    // Act
    DataSet result = sut.compute(transitions);

    // Assert
    INDArray evaluatedQValues = result.getLabels();
    assertEquals((1.0 + 0.5 * 22.0), evaluatedQValues.getDouble(0, 0), 0.0001);
    assertEquals(2.2, evaluatedQValues.getDouble(0, 1), 0.0001);

    assertEquals(3.3, evaluatedQValues.getDouble(1, 0), 0.0001);
    assertEquals((2.0 + 0.5 * 44.0), evaluatedQValues.getDouble(1, 1), 0.0001);

    assertEquals(3.0, evaluatedQValues.getDouble(2, 0), 0.0001); // terminal: reward only
    assertEquals(6.6, evaluatedQValues.getDouble(2, 1), 0.0001);
}
 
Example #22
Source File: TestRecordReaders.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
/**
 * A sequence label reader contains class index 2 while the iterator is built for 2 possible labels.
 * Fetching the batch must fail with an error mentioning the one-hot conversion.
 */
@Test
public void testClassIndexOutsideOfRangeRRMDSI_MultipleReaders() {

    Collection<Collection<Collection<Writable>>> c1 = new ArrayList<>();
    Collection<Collection<Writable>> seq1 = new ArrayList<>();
    seq1.add(Arrays.<Writable>asList(new DoubleWritable(0.0)));
    seq1.add(Arrays.<Writable>asList(new DoubleWritable(0.0)));
    c1.add(seq1);

    Collection<Collection<Writable>> seq2 = new ArrayList<>();
    seq2.add(Arrays.<Writable>asList(new DoubleWritable(0.0)));
    seq2.add(Arrays.<Writable>asList(new DoubleWritable(0.0)));
    c1.add(seq2);

    Collection<Collection<Collection<Writable>>> c2 = new ArrayList<>();
    Collection<Collection<Writable>> seq1a = new ArrayList<>();
    seq1a.add(Arrays.<Writable>asList(new IntWritable(0)));
    seq1a.add(Arrays.<Writable>asList(new IntWritable(1)));
    c2.add(seq1a);

    Collection<Collection<Writable>> seq2a = new ArrayList<>();
    seq2a.add(Arrays.<Writable>asList(new IntWritable(0)));
    seq2a.add(Arrays.<Writable>asList(new IntWritable(2)));
    c2.add(seq2a);

    CollectionSequenceRecordReader csrr = new CollectionSequenceRecordReader(c1);
    CollectionSequenceRecordReader csrrLabels = new CollectionSequenceRecordReader(c2);
    DataSetIterator dsi = new SequenceRecordReaderDataSetIterator(csrr, csrrLabels, 2, 2);

    try {
        dsi.next();
        fail("Expected exception");
    } catch (Exception e) {
        // Null-safe: an exception without a message should fail the assertion, not throw an NPE here
        String message = e.getMessage();
        assertTrue(String.valueOf(message), message != null && message.contains("to one-hot"));
    }
}
 
Example #23
Source File: RGBtoGrayscaleDataSetPreProcessor.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
/**
 * Converts RGB image features to grayscale.
 *
 * <p>Features are read as [minibatch, channels, height, width]. Channel planes 0, 1 and 2 of each
 * example are weighted by RED_RATIO, GREEN_RATIO and BLUE_RATIO and summed. Any further channels
 * are ignored. The result has shape [minibatch, height, width] and replaces the DataSet's features.
 * The original feature array is not modified.
 *
 * @param dataSet the DataSet to convert; empty DataSets are left unchanged; null is rejected by
 *                Preconditions.checkNotNull
 * @throws IllegalArgumentException if the features are not rank 4 with at least 3 channels
 */
@Override
public void preProcess(DataSet dataSet) {
    Preconditions.checkNotNull(dataSet, "Encountered null dataSet");

    if(dataSet.isEmpty()) {
        return;
    }

    INDArray originalFeatures = dataSet.getFeatures();
    long[] originalShape = originalFeatures.shape();
    if (originalShape.length != 4 || originalShape[1] < 3) {
        throw new IllegalArgumentException(
                "Expected RGB features of shape [minibatch, channels >= 3, height, width], got shape "
                        + java.util.Arrays.toString(originalShape));
    }

    // result shape is NHW
    INDArray result = Nd4j.create(originalShape[0], originalShape[2], originalShape[3]);

    for(long n = 0, numExamples = originalShape[0]; n < numExamples; ++n) {
        // Extract channels
        INDArray itemFeatures = originalFeatures.slice(n, 0); // shape is CHW
        INDArray R = itemFeatures.slice(0, 0);  // shape is HW
        INDArray G = itemFeatures.slice(1, 0);
        INDArray B = itemFeatures.slice(2, 0);

        // Convert with non-in-place multiplies: ND4J slices are views of originalFeatures, and the
        // former muli/addi calls overwrote the caller's input array
        INDArray gray = R.mul(RED_RATIO).addi(G.mul(GREEN_RATIO)).addi(B.mul(BLUE_RATIO));

        result.putSlice((int)n, gray);
    }

    dataSet.setFeatures(result);
}
 
Example #24
Source File: RPTreeTest.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
/**
 * After min-max scaling an MNIST batch, querying the RPForest with each of the first 10 examples
 * must return that same example as the first result.
 */
@Test
public void testFindSelf() throws Exception {
    DataSetIterator mnist = new MnistDataSetIterator(100, 6000);
    NormalizerMinMaxScaler minMaxNormalizer = new NormalizerMinMaxScaler(0, 1);
    minMaxNormalizer.fit(mnist);

    DataSet d = mnist.next();
    INDArray features = d.getFeatures();
    minMaxNormalizer.transform(features);

    RPForest rpForest = new RPForest(100, 100, "euclidean");
    rpForest.fit(features);

    for (int query = 0; query < 10; query++) {
        INDArray nearest = rpForest.queryAll(features.slice(query), 10);
        assertEquals(query, nearest.getInt(0));
    }
}
 
Example #25
Source File: UnderSamplingByMaskingPreProcessor.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
/**
 * Replaces the label mask with the one returned by adjustMasks. adjustMasks receives the current
 * labels and label mask plus this preprocessor's minorityLabel and targetMinorityDist.
 */
@Override
public void preProcess(DataSet toPreProcess) {
    INDArray adjustedMask = adjustMasks(toPreProcess.getLabels(), toPreProcess.getLabelsMaskArray(),
            minorityLabel, targetMinorityDist);
    toPreProcess.setLabelsMaskArray(adjustedMask);
}
 
Example #26
Source File: BaseOutputLayer.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
/**
 * Return predicted label names
 *
 * @param dataSet to predict
 * @return the label name of each predicted class index from predict(features), in the same order
 */
@Override
public List<String> predict(DataSet dataSet) {
    int[] intRet = predict(dataSet.getFeatures());
    List<String> ret = new ArrayList<>(intRet.length);
    for (int i : intRet) {
        // Append: i is a predicted class index, not a position in the result list.
        // The former ret.add(i, ...) threw IndexOutOfBoundsException whenever i > ret.size()
        ret.add(dataSet.getLabelName(i));
    }
    return ret;
}
 
Example #27
Source File: SharedTrainingWorker.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
/**
 * Not supported: this worker does not train through per-minibatch calls.
 *
 * @throws UnsupportedOperationException always
 */
@Override
public SharedTrainingResult processMinibatch(DataSet dataSet, MultiLayerNetwork network, boolean isLast) {
    /*
        We're not really going to use this method for training.
        Partitions will be mapped to ParallelWorker threads dynamically, wrt thread/device affinity.
        So plan is simple: we're going to use individual partitions to feed main worker
     */
    throw new UnsupportedOperationException(
            "processMinibatch is not used for shared training: partitions are fed to ParallelWorker threads instead");
}
 
Example #28
Source File: TransformProcessTest.java    From deeplearning4j with Apache License 2.0 4 votes vote down vote up
/**
 * Wraps the integer as a single-value feature array in a DataSet with null labels.
 *
 * @param input the value to wrap; must not be null (unboxing)
 * @return a DataSet whose features contain {@code input} as a double
 */
@Override
public DataSet transform(Integer input) {
    double[] featureValues = {input.doubleValue()};
    return new org.nd4j.linalg.dataset.DataSet(Nd4j.create(featureValues), null);
}
 
Example #29
Source File: AbstractDataSetNormalizer.java    From nd4j with Apache License 2.0 4 votes vote down vote up
/**
 * Revert the data to what it was before transform: the features first, then the labels. Each
 * revert receives its matching mask array.
 *
 * @param data the dataset to revert back
 */
@Override
public void revert(DataSet data) {
    this.revertFeatures(data.getFeatures(), data.getFeaturesMaskArray());
    this.revertLabels(data.getLabels(), data.getLabelsMaskArray());
}
 
Example #30
Source File: SameDiffOutputLayer.java    From deeplearning4j with Apache License 2.0 4 votes vote down vote up
/**
 * Label-name prediction is not available for this layer type.
 *
 * @throws UnsupportedOperationException always
 */
@Override
public List<String> predict(DataSet dataSet) {
    String reason = "Not supported";
    throw new UnsupportedOperationException(reason);
}