Java Code Examples for org.deeplearning4j.nn.multilayer.MultiLayerNetwork#doEvaluation()

The following examples show how to use org.deeplearning4j.nn.multilayer.MultiLayerNetwork#doEvaluation(). The source project and file for each example are noted above it.
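Before the full examples, here is a minimal sketch of the basic call pattern, assuming a trained MultiLayerNetwork (net) and a test-set DataSetIterator (testIter); these names, and the wrapper method evaluateOnTestSet, are placeholders rather than part of the examples below. doEvaluation(DataSetIterator, IEvaluation...) streams the data through the network once, fills in every evaluation object passed to it, and returns the same instances as an array.

import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.evaluation.IEvaluation;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.evaluation.classification.ROCMultiClass;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;

// Illustrative sketch only: `net` is assumed to be a trained network and
// `testIter` an iterator over held-out data.
private static void evaluateOnTestSet(MultiLayerNetwork net, DataSetIterator testIter) {
    // A single pass over the iterator populates all evaluations passed as varargs.
    IEvaluation[] results = net.doEvaluation(testIter,
            new Evaluation(),        // accuracy, precision, recall, F1, confusion matrix
            new ROCMultiClass());    // one-vs-all ROC curves per class
    Evaluation eval = (Evaluation) results[0];
    ROCMultiClass roc = (ROCMultiClass) results[1];
    System.out.println(eval.stats());
    System.out.println("Average AUC: " + roc.calculateAverageAUC());
}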
Example 1
Source File: EvaluationRunner.java    From deeplearning4j with Apache License 2.0
private static void doEval(Model m, IEvaluation[] e, Iterator<DataSet> ds, Iterator<MultiDataSet> mds, int evalBatchSize){
    if(m instanceof MultiLayerNetwork){
        MultiLayerNetwork mln = (MultiLayerNetwork)m;
        if(ds != null){
            mln.doEvaluation(new IteratorDataSetIterator(ds, evalBatchSize), e);
        } else {
            mln.doEvaluation(new IteratorMultiDataSetIterator(mds, evalBatchSize), e);
        }
    } else {
        ComputationGraph cg = (ComputationGraph)m;
        if(ds != null){
            cg.doEvaluation(new IteratorDataSetIterator(ds, evalBatchSize), e);
        } else {
            cg.doEvaluation(new IteratorMultiDataSetIterator(mds, evalBatchSize), e);
        }
    }
}
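This helper dispatches on the model type: both MultiLayerNetwork and ComputationGraph expose doEvaluation() overloads for DataSetIterator and MultiDataSetIterator, so the same IEvaluation array can be populated regardless of which kind of data the caller supplies.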
 
Example 2
Source File: ROCScoreFunction.java    From deeplearning4j with Apache License 2.0
@Override
public double score(MultiLayerNetwork net, DataSetIterator iterator) {
    switch (type){
        case ROC:
            ROC r = net.evaluateROC(iterator);
            return metric == Metric.AUC ? r.calculateAUC() : r.calculateAUCPR();
        case BINARY:
            ROCBinary r2 = net.doEvaluation(iterator, new ROCBinary())[0];
            return metric == Metric.AUC ? r2.calculateAverageAuc() : r2.calculateAverageAUCPR();
        case MULTICLASS:
            ROCMultiClass r3 = net.evaluateROCMultiClass(iterator);
            return metric == Metric.AUC ? r3.calculateAverageAUC() : r3.calculateAverageAUCPR();
        default:
            throw new RuntimeException("Unknown type: " + type);
    }
}
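Only the BINARY branch calls doEvaluation() directly, reading the populated ROCBinary back out of the returned array; the ROC and MULTICLASS branches use the evaluateROC() and evaluateROCMultiClass() convenience methods instead.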
 
Example 3
Source File: EvalTest.java    From deeplearning4j with Apache License 2.0
    @Test
    public void testEvalSplitting(){
        //Test for "tbptt-like" functionality

        for(WorkspaceMode ws : WorkspaceMode.values()) {
            System.out.println("Starting test for workspace mode: " + ws);

            int nIn = 4;
            int layerSize = 5;
            int nOut = 6;
            int tbpttLength = 10;
            int tsLength = 5 * tbpttLength + tbpttLength / 2;

            MultiLayerConfiguration conf1 = new NeuralNetConfiguration.Builder()
                    .seed(12345)
                    .trainingWorkspaceMode(ws)
                    .inferenceWorkspaceMode(ws)
                    .list()
                    .layer(new LSTM.Builder().nIn(nIn).nOut(layerSize).build())
                    .layer(new RnnOutputLayer.Builder().nIn(layerSize).nOut(nOut)
                            .activation(Activation.SOFTMAX)
                            .build())
                    .build();

            MultiLayerConfiguration conf2 = new NeuralNetConfiguration.Builder()
                    .seed(12345)
                    .trainingWorkspaceMode(ws)
                    .inferenceWorkspaceMode(ws)
                    .list()
                    .layer(new LSTM.Builder().nIn(nIn).nOut(layerSize).build())
                    .layer(new RnnOutputLayer.Builder().nIn(layerSize).nOut(nOut)
                            .activation(Activation.SOFTMAX).build())
                    .tBPTTLength(10)
                    .backpropType(BackpropType.TruncatedBPTT)
                    .build();

            MultiLayerNetwork net1 = new MultiLayerNetwork(conf1);
            net1.init();

            MultiLayerNetwork net2 = new MultiLayerNetwork(conf2);
            net2.init();

            net2.setParams(net1.params());

            for(boolean useMask : new boolean[]{false, true}) {

                INDArray in1 = Nd4j.rand(new int[]{3, nIn, tsLength});
                INDArray out1 = TestUtils.randomOneHotTimeSeries(3, nOut, tsLength);

                INDArray in2 = Nd4j.rand(new int[]{5, nIn, tsLength});
                INDArray out2 = TestUtils.randomOneHotTimeSeries(5, nOut, tsLength);

                INDArray lMask1 = null;
                INDArray lMask2 = null;
                if(useMask){
                    lMask1 = Nd4j.create(3, tsLength);
                    lMask2 = Nd4j.create(5, tsLength);
                    Nd4j.getExecutioner().exec(new BernoulliDistribution(lMask1, 0.5));
                    Nd4j.getExecutioner().exec(new BernoulliDistribution(lMask2, 0.5));
                }

                List<DataSet> l = Arrays.asList(new DataSet(in1, out1, null, lMask1), new DataSet(in2, out2, null, lMask2));
                DataSetIterator iter = new ExistingDataSetIterator(l);

//                System.out.println("Net 1 eval");
                org.nd4j.evaluation.IEvaluation[] e1 = net1.doEvaluation(iter, new org.nd4j.evaluation.classification.Evaluation(), new org.nd4j.evaluation.classification.ROCMultiClass(), new org.nd4j.evaluation.regression.RegressionEvaluation());
//                System.out.println("Net 2 eval");
                org.nd4j.evaluation.IEvaluation[] e2 = net2.doEvaluation(iter, new org.nd4j.evaluation.classification.Evaluation(), new org.nd4j.evaluation.classification.ROCMultiClass(), new org.nd4j.evaluation.regression.RegressionEvaluation());

                assertEquals(e1[0], e2[0]);
                assertEquals(e1[1], e2[1]);
                assertEquals(e1[2], e2[2]);
            }
        }
    }
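The test builds two networks with identical parameters, one using standard backpropagation and one configured for truncated BPTT, then asserts that doEvaluation() produces identical Evaluation, ROCMultiClass, and RegressionEvaluation results for both, with and without label masks; in other words, the internal splitting of long time series during evaluation does not change the reported statistics.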