Java Code Examples for org.nd4j.evaluation.classification.Evaluation#stats()

The following examples show how to use org.nd4j.evaluation.classification.Evaluation#stats(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may check out the related API usage on the sidebar.
Example 1
Source File: EvalTest.java (from deeplearning4j, Apache License 2.0) - 6 votes
/**
 * Verifies that {@link Evaluation#stats()} renders confusion-matrix rows with the expected counts.
 * Two kinds of misclassification are recorded, and the report must contain the corresponding rows.
 */
@Test
public void testConfusionMatrixStats() {

    Evaluation evaluation = new Evaluation();

    // One-hot rows, one per class
    INDArray oneHot0 = Nd4j.create(new double[] {1, 0, 0}, new long[]{1, 3});
    INDArray oneHot1 = Nd4j.create(new double[] {0, 1, 0}, new long[]{1, 3});
    INDArray oneHot2 = Nd4j.create(new double[] {0, 0, 1}, new long[]{1, 3});

    // Class 2 predicted while the actual class is 0, recorded 3 times
    apply(evaluation, 3, oneHot2, oneHot0);
    // Class 0 predicted while the actual class is 1, recorded 2 times
    apply(evaluation, 2, oneHot0, oneHot1);

    String report = evaluation.stats();

    // Expected rows: per-predicted-class counts, then "| <actual index> = <actual label>"
    // (layout inferred from the expected strings - confirm against Evaluation#stats)
    String[] expectedRows = {
            " 0 0 3 | 0 = 0",   // actual 0: predicted 2, three times
            " 2 0 0 | 1 = 1"    // actual 1: predicted 0, twice
    };
    for (String row : expectedRows) {
        assertTrue(report, report.contains(row));
    }
}
 
Example 2
Source File: EvalTest.java (from deeplearning4j, Apache License 2.0) - 5 votes
/**
 * Checks an {@link Evaluation} built with a single output column.
 * Counts are kept for both class 0 and class 1 (see the truePositives assertions).
 * Several rounds are run, separated by {@code reset()}, to confirm that reset clears all accumulated state.
 */
@Test
public void testSingleClassBinaryClassification() {

    Evaluation eval = new Evaluation(1);

    // Loop-invariant 1x1 inputs: "zero" is class 0, "one" is class 1
    INDArray zero = Nd4j.create(1, 1);
    INDArray one = Nd4j.ones(1, 1);

    for (int round = 0; round < 3; round++) {
        // eval(labels, predictions): one incorrect (label 1 predicted as 0), three correct
        eval.eval(one, zero);
        eval.eval(one, one);
        eval.eval(one, one);
        eval.eval(zero, zero);

        // Smoke check: rendering the report must not throw
        eval.stats();

        assertEquals(0.75, eval.accuracy(), 1e-6);
        assertEquals(4, eval.getNumRowCounter());

        assertEquals(1, (int) eval.truePositives().get(0));
        assertEquals(2, (int) eval.truePositives().get(1));
        assertEquals(1, (int) eval.falseNegatives().get(1));

        // Each round must start from a clean state
        eval.reset();
    }
}
 
Example 3
Source File: EvalTest.java (from deeplearning4j, Apache License 2.0) - 5 votes
/**
 * Checks that the report for a 5-class evaluation contains no U+FFFD replacement character.
 * Only classes 0 and 1 are ever observed in this evaluation.
 * Classes with no data presumably produce undefined metrics - confirm how stats() formats them.
 */
@Test
public void testEvalInvalid() {
    Evaluation e = new Evaluation(5);
    // eval(int, int) argument order: predicted, actual
    e.eval(0, 1);
    e.eval(1, 0);
    e.eval(1, 1);

    // Render once; pass the report as the message so a failure shows what was produced
    String stats = e.stats();
    assertFalse(stats, stats.contains("\uFFFD"));
}
 
Example 4
Source File: EvalTest.java (from deeplearning4j, Apache License 2.0) - 5 votes
/**
 * Checks that {@code reset()} on an {@link Evaluation} with label names clears accumulated counts.
 * The label-name mapping is kept, so the same input sequence produces the same report before and after the reset.
 */
@Test
public void testLabelReset(){

    Map<Integer,String> m = new HashMap<>();
    m.put(0, "False");
    m.put(1, "True");

    Evaluation e1 = new Evaluation(m);
    INDArray zero = Nd4j.create(new double[]{1,0}).reshape(1,2);
    INDArray one = Nd4j.create(new double[]{0,1}).reshape(1,2);

    // The identical (labels, predictions) sequence is fed before and after reset()
    Runnable feedSequence = () -> {
        e1.eval(zero, zero);
        e1.eval(zero, zero);
        e1.eval(one, zero);
        e1.eval(one, one);
        e1.eval(one, one);
        e1.eval(one, one);
    };

    feedSequence.run();
    String s1 = e1.stats();

    e1.reset();
    // Nothing should remain counted after a reset
    assertEquals(0, e1.getNumRowCounter());

    feedSequence.run();
    String s2 = e1.stats();
    assertEquals(s1, s2);
}
 
Example 5
Source File: EvalTest.java (from deeplearning4j, Apache License 2.0) - 4 votes
/**
 * Checks that binary {@link Evaluation} results do not depend on the ND4J default dtypes or on the dtype of the inputs.
 * The results checked are the counts, the accuracy, equality of the Evaluation objects, and the rendered stats() text.
 * The global default data types are restored afterwards.
 */
@Test
public void testEval2() {

    // Capture both global defaults (data type and floating-point type) so that
    // the finally block restores exactly the state this test started with
    DataType dtypeBefore = Nd4j.dataType();
    DataType fpDtypeBefore = Nd4j.defaultFloatingPointType();
    Evaluation first = null;
    String sFirst = null;
    try {
        for (DataType globalDtype : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF, DataType.INT}) {
            // A non-floating global type (INT) falls back to DOUBLE as the floating-point default
            Nd4j.setDefaultDataTypes(globalDtype, globalDtype.isFPType() ? globalDtype : DataType.DOUBLE);
            for (DataType lpDtype : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF}) {
                // Identifies the failing dtype combination in assertion messages
                String ctx = "globalDtype=" + globalDtype + ", lpDtype=" + lpDtype;

                //Confusion matrix:
                //            predicted 0   predicted 1
                //actual 0        20             3
                //actual 1        10             5

                Evaluation evaluation = new Evaluation(Arrays.asList("class0", "class1"));
                INDArray predicted0 = Nd4j.create(new double[]{1, 0}, new long[]{1, 2}).castTo(lpDtype);
                INDArray predicted1 = Nd4j.create(new double[]{0, 1}, new long[]{1, 2}).castTo(lpDtype);
                INDArray actual0 = Nd4j.create(new double[]{1, 0}, new long[]{1, 2}).castTo(lpDtype);
                INDArray actual1 = Nd4j.create(new double[]{0, 1}, new long[]{1, 2}).castTo(lpDtype);
                for (int i = 0; i < 20; i++) {
                    evaluation.eval(actual0, predicted0);
                }

                for (int i = 0; i < 3; i++) {
                    evaluation.eval(actual0, predicted1);
                }

                for (int i = 0; i < 10; i++) {
                    evaluation.eval(actual1, predicted0);
                }

                for (int i = 0; i < 5; i++) {
                    evaluation.eval(actual1, predicted1);
                }

                // Counts from the perspective of class 0
                assertEquals(ctx, 20, evaluation.truePositives().get(0), 0);
                assertEquals(ctx, 3, evaluation.falseNegatives().get(0), 0);
                assertEquals(ctx, 10, evaluation.falsePositives().get(0), 0);
                assertEquals(ctx, 5, evaluation.trueNegatives().get(0), 0);

                assertEquals(ctx, (20.0 + 5) / (20 + 3 + 10 + 5), evaluation.accuracy(), 1e-6);

                // Every dtype combination must yield an equal Evaluation and an identical report
                String s = evaluation.stats();

                if (first == null) {
                    first = evaluation;
                    sFirst = s;
                } else {
                    assertEquals(ctx, first, evaluation);
                    assertEquals(ctx, sFirst, s);
                }
            }
        }
    } finally {
        Nd4j.setDefaultDataTypes(dtypeBefore, fpDtypeBefore);
    }
}
 
Example 6
Source File: EvalTest.java (from deeplearning4j, Apache License 2.0) - 4 votes
/**
 * Checks that masked time-series evaluation matches unmasked evaluation of the same data.
 * The original [miniBatch, nOut, tsLength] data is padded with one extra time step on each side.
 * The padding steps are masked out, so evalTimeSeries with the mask must give the same metrics as without padding.
 */
@Test
    public void testEvalMasking() {
        int miniBatch = 5;
        int nOut = 3;
        int tsLength = 6;

        INDArray labels = Nd4j.zeros(miniBatch, nOut, tsLength);
        INDArray predicted = Nd4j.zeros(miniBatch, nOut, tsLength);

        // Both RNGs are seeded; the order of rand/nextInt calls below determines the data
        Nd4j.getRandom().setSeed(12345);
        Random r = new Random(12345);
        for (int i = 0; i < miniBatch; i++) {
            for (int j = 0; j < tsLength; j++) {
                // Random prediction row normalized so its entries sum to 1
                INDArray rand = Nd4j.rand(1, nOut);
                rand.divi(rand.sumNumber());
                predicted.put(new INDArrayIndex[] {NDArrayIndex.point(i), all(), NDArrayIndex.point(j)},
                                rand);
                // One-hot label: a single random class set to 1.0 for example i, time step j
                int idx = r.nextInt(nOut);
                labels.putScalar(new int[] {i, idx, j}, 1.0);
            }
        }

        //Create a longer labels/predicted with mask for first and last time step
        //Expect masked evaluation to be identical to original evaluation
        // The original data is copied into time steps starting at 1 (interval(1, tsLength + 1), presumably end-exclusive)
        INDArray labels2 = Nd4j.zeros(miniBatch, nOut, tsLength + 2);
        labels2.put(new INDArrayIndex[] {all(), all(),
                        interval(1, tsLength + 1)}, labels);
        INDArray predicted2 = Nd4j.zeros(miniBatch, nOut, tsLength + 2);
        predicted2.put(new INDArrayIndex[] {all(), all(),
                        interval(1, tsLength + 1)}, predicted);

        // Mask is 1 everywhere except the padded first (0) and last (tsLength + 1) time steps
        INDArray labelsMask = Nd4j.ones(miniBatch, tsLength + 2);
        for (int i = 0; i < miniBatch; i++) {
            labelsMask.putScalar(new int[] {i, 0}, 0.0);
            labelsMask.putScalar(new int[] {i, tsLength + 1}, 0.0);
        }

        Evaluation evaluation = new Evaluation();
        evaluation.evalTimeSeries(labels, predicted);

        Evaluation evaluation2 = new Evaluation();
        evaluation2.evalTimeSeries(labels2, predicted2, labelsMask);

        // Smoke check: rendering both reports must not throw
        evaluation.stats();
        evaluation2.stats();

        assertEquals(evaluation.accuracy(), evaluation2.accuracy(), 1e-12);
        assertEquals(evaluation.f1(), evaluation2.f1(), 1e-12);

        assertMapEquals(evaluation.falsePositives(), evaluation2.falsePositives());
        assertMapEquals(evaluation.falseNegatives(), evaluation2.falseNegatives());
        assertMapEquals(evaluation.truePositives(), evaluation2.truePositives());
        assertMapEquals(evaluation.trueNegatives(), evaluation2.trueNegatives());

        for (int i = 0; i < nOut; i++)
            assertEquals(evaluation.classCount(i), evaluation2.classCount(i));
    }
 
Example 7
Source File: EvalTest.java (from deeplearning4j, Apache License 2.0) - 4 votes
/**
 * Checks that eval(int, int) and eval(INDArray, INDArray) accumulate identical results.
 * Argument order differs between the two overloads:
 * eval(INDArray actual, INDArray predicted) and eval(int predicted, int actual).
 */
@Test
public void testEvalMethods() {

    Evaluation e1 = new Evaluation(4);
    Evaluation e2 = new Evaluation(4);

    INDArray i0 = Nd4j.create(new double[] {1, 0, 0, 0}, new long[]{1, 4});
    INDArray i1 = Nd4j.create(new double[] {0, 1, 0, 0}, new long[]{1, 4});
    INDArray i2 = Nd4j.create(new double[] {0, 0, 1, 0}, new long[]{1, 4});
    INDArray i3 = Nd4j.create(new double[] {0, 0, 0, 1}, new long[]{1, 4});

    e1.eval(i0, i0); //order: actual, predicted
    e2.eval(0, 0); //order: predicted, actual
    e1.eval(i0, i2);
    e2.eval(2, 0);
    e1.eval(i0, i2);
    e2.eval(2, 0);
    e1.eval(i1, i2);
    e2.eval(2, 1);
    e1.eval(i3, i3);
    e2.eval(3, 3);
    e1.eval(i3, i0);
    e2.eval(0, 3);
    e1.eval(i3, i0);
    e2.eval(0, 3);

    // Both evaluations must produce the expected confusion matrix (getCount takes actual, predicted)
    for (Evaluation e : new Evaluation[]{e1, e2}) {
        org.nd4j.evaluation.classification.ConfusionMatrix<Integer> cm = e.getConfusionMatrix();
        assertEquals(1, cm.getCount(0, 0));
        assertEquals(2, cm.getCount(0, 2));
        assertEquals(1, cm.getCount(1, 2));
        assertEquals(1, cm.getCount(3, 3));
        assertEquals(2, cm.getCount(3, 0));
    }

    // Identical accumulated state must render identical reports
    assertEquals(e1.stats(), e2.stats());
}
 
Example 8
Source File: EvalTest.java (from deeplearning4j, Apache License 2.0) - 4 votes
/**
 * Checks top-N accuracy with N = 3.
 * A prediction is top-N correct when the true class has one of the 3 highest probabilities.
 * Accuracy, top-N accuracy and the raw top-N counts are verified after every step.
 */
@Test
public void testTopNAccuracy() {

    // No label names; top-N with N = 3
    Evaluation e = new Evaluation(null, 3);

    INDArray i0 = Nd4j.create(new double[] {1, 0, 0, 0, 0}, new long[]{1, 5});
    INDArray i1 = Nd4j.create(new double[] {0, 1, 0, 0, 0}, new long[]{1, 5});

    INDArray p0_0 = Nd4j.create(new double[] {0.8, 0.05, 0.05, 0.05, 0.05}, new long[]{1, 5}); //class 0: highest prob
    INDArray p0_1 = Nd4j.create(new double[] {0.4, 0.45, 0.05, 0.05, 0.05}, new long[]{1, 5}); //class 0: 2nd highest prob
    INDArray p0_2 = Nd4j.create(new double[] {0.1, 0.45, 0.35, 0.05, 0.05}, new long[]{1, 5}); //class 0: 3rd highest prob
    INDArray p0_3 = Nd4j.create(new double[] {0.1, 0.40, 0.30, 0.15, 0.05}, new long[]{1, 5}); //class 0: 4th highest prob

    INDArray p1_0 = Nd4j.create(new double[] {0.05, 0.80, 0.05, 0.05, 0.05}, new long[]{1, 5}); //class 1: highest prob
    INDArray p1_1 = Nd4j.create(new double[] {0.45, 0.40, 0.05, 0.05, 0.05}, new long[]{1, 5}); //class 1: 2nd highest prob
    INDArray p1_2 = Nd4j.create(new double[] {0.35, 0.10, 0.45, 0.05, 0.05}, new long[]{1, 5}); //class 1: 3rd highest prob
    INDArray p1_3 = Nd4j.create(new double[] {0.40, 0.10, 0.30, 0.15, 0.05}, new long[]{1, 5}); //class 1: 4th highest prob

    // Running totals after each step:
    //                 Correct   TopNCorrect   Total
    e.eval(i0, p0_0); //  1           1           1
    assertEquals(1.0, e.accuracy(), 1e-6);
    assertEquals(1.0, e.topNAccuracy(), 1e-6);
    assertEquals(1, e.getTopNCorrectCount());
    assertEquals(1, e.getTopNTotalCount());
    e.eval(i0, p0_1); //  1           2           2
    assertEquals(0.5, e.accuracy(), 1e-6);
    assertEquals(1.0, e.topNAccuracy(), 1e-6);
    assertEquals(2, e.getTopNCorrectCount());
    assertEquals(2, e.getTopNTotalCount());
    e.eval(i0, p0_2); //  1           3           3
    assertEquals(1.0 / 3, e.accuracy(), 1e-6);
    assertEquals(1.0, e.topNAccuracy(), 1e-6);
    assertEquals(3, e.getTopNCorrectCount());
    assertEquals(3, e.getTopNTotalCount());
    e.eval(i0, p0_3); //  1           3           4
    assertEquals(0.25, e.accuracy(), 1e-6);
    assertEquals(0.75, e.topNAccuracy(), 1e-6);
    assertEquals(3, e.getTopNCorrectCount());
    assertEquals(4, e.getTopNTotalCount());

    e.eval(i1, p1_0); //  2           4           5
    assertEquals(2.0 / 5, e.accuracy(), 1e-6);
    assertEquals(4.0 / 5, e.topNAccuracy(), 1e-6);
    assertEquals(4, e.getTopNCorrectCount());
    assertEquals(5, e.getTopNTotalCount());
    e.eval(i1, p1_1); //  2           5           6
    assertEquals(2.0 / 6, e.accuracy(), 1e-6);
    assertEquals(5.0 / 6, e.topNAccuracy(), 1e-6);
    assertEquals(5, e.getTopNCorrectCount());
    assertEquals(6, e.getTopNTotalCount());
    e.eval(i1, p1_2); //  2           6           7
    assertEquals(2.0 / 7, e.accuracy(), 1e-6);
    assertEquals(6.0 / 7, e.topNAccuracy(), 1e-6);
    assertEquals(6, e.getTopNCorrectCount());
    assertEquals(7, e.getTopNTotalCount());
    e.eval(i1, p1_3); //  2           6           8
    assertEquals(2.0 / 8, e.accuracy(), 1e-6);
    assertEquals(6.0 / 8, e.topNAccuracy(), 1e-6);
    assertEquals(6, e.getTopNCorrectCount());
    assertEquals(8, e.getTopNTotalCount());

    // Smoke check: rendering the report (which includes top-N results) must not throw
    e.stats();
}
 
Example 9
Source File: EvalTest.java (from deeplearning4j, Apache License 2.0) - 4 votes
@Test
public void testEvalStatsBinaryCase(){
    //Make sure we report class 1 precision/recall/f1 not macro averaged, for binary case
    // This applies both to the accessor methods and to the text rendered by stats()

    Evaluation e = new Evaluation();

    INDArray l0 = Nd4j.createFromArray(new double[]{1,0}).reshape(1,2);
    INDArray l1 = Nd4j.createFromArray(new double[]{0,1}).reshape(1,2);

    // eval(labels, predictions), classified from the perspective of class 1
    e.eval(l1, l1);   // TP
    e.eval(l1, l1);   // TP
    e.eval(l1, l1);   // TP
    e.eval(l0, l0);   // TN
    e.eval(l1, l0);   // FN
    e.eval(l1, l0);   // FN
    e.eval(l0, l1);   // FP

    double tp = 3;
    double fp = 1;
    double fn = 2;

    double prec = tp / (tp + fp);
    double rec = tp / (tp + fn);
    double f1 = 2 * prec * rec / (prec + rec);

    assertEquals(prec, e.precision(), 1e-6);
    assertEquals(rec, e.recall(), 1e-6);
    assertEquals(f1, e.f1(), 1e-6);

    DecimalFormat df = new DecimalFormat("0.0000");

    String stats = e.stats();

    // Collapse runs of spaces so the checks below do not depend on column padding
    String stats2 = stats.replaceAll("( )+", " ");

    String recS = " Recall: " + df.format(rec);
    String preS = " Precision: " + df.format(prec);
    String f1S = "F1 Score: " + df.format(f1);

    assertTrue(stats2, stats2.contains(recS));
    assertTrue(stats2, stats2.contains(preS));
    assertTrue(stats2, stats2.contains(f1S));
}