Java Code Examples for org.nd4j.linalg.factory.Nd4j#defaultFloatingPointType()

The following examples show how to use org.nd4j.linalg.factory.Nd4j#defaultFloatingPointType(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You can also check out related API usage in the sidebar.
Example 1
Source File: OrthogonalDistribution.java — from deeplearning4j (Apache License 2.0), 5 votes
/**
 * Samples an array of the given shape whose 2D flattening is taken from an SVD factor of a
 * standard-normal random matrix, scaled by {@code gain}.
 *
 * <p>All dimensions except the last are collapsed into rows and the last dimension is used as
 * the column count; the selected factor slice is reshaped back to {@code shape}.
 *
 * @param shape target shape; must contain at least one dimension
 * @return an array of the requested shape
 * @throws IllegalArgumentException if {@code shape} is null or empty
 * @throws UnsupportedOperationException if {@code gains} is non-null (only the {@code gain} path is implemented)
 */
@Override
public INDArray sample(long[] shape){
    if (shape == null || shape.length == 0) {
        throw new IllegalArgumentException("Shape must have at least one dimension, got: "
                + (shape == null ? "null" : "[]"));
    }
    // Checked up front so no random numbers are drawn and no SVD is computed for an unsupported configuration.
    if (gains != null) {
        throw new UnsupportedOperationException("OrthogonalDistribution.sample(): non-null gains are not supported");
    }

    // Collapse all leading dimensions into rows; the last dimension becomes the column count.
    long numRows = 1;
    for (int i = 0; i < shape.length - 1; i++) {
        numRows *= shape[i];
    }
    long numCols = shape[shape.length - 1];

    val dtype = Nd4j.defaultFloatingPointType();

    // Fill a numRows x numCols matrix with Gaussian(mean 0.0, std 1.0) samples using `random`.
    val flatShape = new long[]{numRows, numCols};
    val flatRng = Nd4j.getExecutioner().exec(new GaussianDistribution(Nd4j.createUninitialized(dtype, flatShape, Nd4j.order()), 0.0, 1.0), random);

    val m = flatRng.rows();
    val n = flatRng.columns();

    // Svd output buffers: s has min(m, n) entries, u is m x m, v is n x n in 'f' order.
    val s = Nd4j.create(dtype, Math.min(m, n));
    val u = Nd4j.create(dtype, m, m);
    val v = Nd4j.create(dtype, new long[] {n, n}, 'f');

    Nd4j.exec(new Svd(flatRng, true, s, u, v));

    // u is m x m with m == numRows, so it can supply numCols columns only when numRows >= numCols;
    // otherwise v (numCols x numCols) is sliced instead.
    INDArray factor = (u.rows() >= numRows && u.columns() >= numCols) ? u : v;
    return factor.get(NDArrayIndex.interval(0, numRows), NDArrayIndex.interval(0, numCols)).mul(gain).reshape(shape);
}
 
Example 2
Source File: BaseTransformFloatOp.java — from deeplearning4j (Apache License 2.0), 5 votes
/**
 * Returns the input's own data type when the input is present and floating-point;
 * otherwise falls back to the global default floating-point type.
 */
@Override
public DataType resultType() {
    INDArray input = this.x();
    boolean inputIsFloating = input != null && input.isR();
    return inputIsFloating ? input.dataType() : Nd4j.defaultFloatingPointType();
}
 
Example 3
Source File: BaseTransformFloatOp.java — from deeplearning4j (Apache License 2.0), 5 votes
/**
 * Returns the data type of the first input when it is floating-point, otherwise the global
 * default floating-point type.
 *
 * @param oc op context to read the first input from; when null, this op's own {@code x()} is used
 *           (matching the other {@code resultType(OpContext)} implementations)
 */
@Override
public DataType resultType(OpContext oc) {
    // Previously oc was dereferenced unconditionally, so a null context threw an NPE.
    INDArray x = oc != null ? oc.getInputArray(0) : x();
    if (x != null && x.isR())
        return x.dataType();

    return Nd4j.defaultFloatingPointType();
}
 
Example 4
Source File: Variance.java — from deeplearning4j (Apache License 2.0), 5 votes
/**
 * Chooses the output data type: the first input's type if it is floating-point, else the
 * type of {@code arg()} if present, else the global default floating-point type.
 *
 * @param oc op context supplying the first input; when null, {@code x()} is used instead
 */
@Override
public DataType resultType(OpContext oc){
    INDArray input = (oc == null) ? x() : oc.getInputArray(0);
    if (input != null && input.isR()) {
        return input.dataType();
    }

    // Non-floating or missing input: defer to the op's first argument when one exists.
    if (this.arg() != null) {
        return this.arg().dataType();
    }
    return Nd4j.defaultFloatingPointType();
}
 
Example 5
Source File: BaseReduceFloatOp.java — from deeplearning4j (Apache License 2.0), 5 votes
@Override
public DataType resultType(OpContext oc) {
    INDArray x = oc != null ? oc.getInputArray(0) : x();
    if (x != null && x.isR())
        return x.dataType();

    return Nd4j.defaultFloatingPointType();
}
 
Example 6
Source File: BaseReduceFloatOp.java — from deeplearning4j (Apache License 2.0), 5 votes
@Override
public List<LongShapeDescriptor> calculateOutputShape(OpContext oc) {
    INDArray x = oc != null ? oc.getInputArray(0) : x();

    if(x == null)
        return Collections.emptyList();

    //Calculate reduction shape. Note that reduction on scalar - returns a scalar
    long[] reducedShape = x.rank() == 0 ? x.shape() : Shape.getReducedShape(x.shape(),dimensions, isKeepDims());
    DataType retType = arg().dataType();
    if(!retType.isFPType())
        retType = Nd4j.defaultFloatingPointType();
    return Collections.singletonList(LongShapeDescriptor.fromShape(reducedShape, retType));
}
 
Example 7
Source File: EvaluationCalibrationTest.java — from deeplearning4j (Apache License 2.0), 4 votes
/**
 * Checks the reliability diagram of EvaluationCalibration on hand-built binned data for every
 * combination of global default dtype and label/prediction dtype, and that stats() and the
 * internal histograms agree with the first run across dtypes.
 */
@Test
public void testReliabilityDiagram() {
    // Snapshot BOTH global defaults: the loop below changes the default data type as well as the
    // default floating-point type, so restoring only one of them would leak state to later tests.
    DataType defaultDtypeBefore = Nd4j.dataType();
    DataType fpDtypeBefore = Nd4j.defaultFloatingPointType();
    EvaluationCalibration first = null;
    String sFirst = null;
    try {
        for (DataType globalDtype : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF, DataType.INT}) {
            // Non-floating global types (INT) are paired with DOUBLE as the floating-point default.
            Nd4j.setDefaultDataTypes(globalDtype, globalDtype.isFPType() ? globalDtype : DataType.DOUBLE);
            for (DataType lpDtype : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF}) {
                // Lower precision tolerance for FP16 labels/predictions.
                double eps = lpDtype == DataType.HALF ? 1e-3 : 1e-5;

                //Test using 5 bins - format: binary softmax-style output
                //Note: no values fall in fourth bin

                //[0, 0.2)
                INDArray bin0Probs = Nd4j.create(new double[][]{{1.0, 0.0}, {0.9, 0.1}, {0.85, 0.15}}).castTo(lpDtype);
                INDArray bin0Labels = Nd4j.create(new double[][]{{1.0, 0.0}, {1.0, 0.0}, {0.0, 1.0}}).castTo(lpDtype);

                //[0.2, 0.4)
                INDArray bin1Probs = Nd4j.create(new double[][]{{0.80, 0.20}, {0.7, 0.3}, {0.65, 0.35}}).castTo(lpDtype);
                INDArray bin1Labels = Nd4j.create(new double[][]{{1.0, 0.0}, {0.0, 1.0}, {1.0, 0.0}}).castTo(lpDtype);

                //[0.4, 0.6)
                INDArray bin2Probs = Nd4j.create(new double[][]{{0.59, 0.41}, {0.5, 0.5}, {0.45, 0.55}}).castTo(lpDtype);
                INDArray bin2Labels = Nd4j.create(new double[][]{{1.0, 0.0}, {0.0, 1.0}, {0.0, 1.0}}).castTo(lpDtype);

                //[0.6, 0.8)
                //Empty

                //[0.8, 1.0]
                INDArray bin4Probs = Nd4j.create(new double[][]{{0.0, 1.0}, {0.1, 0.9}}).castTo(lpDtype);
                INDArray bin4Labels = Nd4j.create(new double[][]{{0.0, 1.0}, {0.0, 1.0}}).castTo(lpDtype);


                INDArray probs = Nd4j.vstack(bin0Probs, bin1Probs, bin2Probs, bin4Probs);
                INDArray labels = Nd4j.vstack(bin0Labels, bin1Labels, bin2Labels, bin4Labels);

                EvaluationCalibration ec = new EvaluationCalibration(5, 5);
                ec.eval(labels, probs);

                // NOTE(review): with a bound of 1 only class 0 is checked; the else-branch below
                // (class 1 expectations) is currently unreachable.
                for (int i = 0; i < 1; i++) {
                    double[] avgBinProbsClass;
                    double[] fracPos;
                    if (i == 0) {
                        //Class 0: needs to be handled a little differently, due to threshold/edge cases (0.8, etc)
                        avgBinProbsClass = new double[]{0.05, (0.59 + 0.5 + 0.45) / 3, (0.65 + 0.7) / 2.0,
                                (0.8 + 0.85 + 0.9 + 1.0) / 4};
                        fracPos = new double[]{0.0 / 2.0, 1.0 / 3, 1.0 / 2, 3.0 / 4};
                    } else {
                        avgBinProbsClass = new double[]{bin0Probs.getColumn(i).meanNumber().doubleValue(),
                                bin1Probs.getColumn(i).meanNumber().doubleValue(),
                                bin2Probs.getColumn(i).meanNumber().doubleValue(),
                                bin4Probs.getColumn(i).meanNumber().doubleValue()};

                        fracPos = new double[]{bin0Labels.getColumn(i).sumNumber().doubleValue() / bin0Labels.size(0),
                                bin1Labels.getColumn(i).sumNumber().doubleValue() / bin1Labels.size(0),
                                bin2Labels.getColumn(i).sumNumber().doubleValue() / bin2Labels.size(0),
                                bin4Labels.getColumn(i).sumNumber().doubleValue() / bin4Labels.size(0)};
                    }

                    org.nd4j.evaluation.curves.ReliabilityDiagram rd = ec.getReliabilityDiagram(i);

                    double[] x = rd.getMeanPredictedValueX();
                    double[] y = rd.getFractionPositivesY();

                    assertArrayEquals(avgBinProbsClass, x, 1e-3);
                    assertArrayEquals(fracPos, y, 1e-3);

                    String s = ec.stats();
                    if(first == null) {
                        first = ec;
                        sFirst = s;
                    } else {
//                        assertEquals(first, ec);
                        assertEquals(sFirst, s);
                        assertTrue(first.getRDiagBinPosCount().equalsWithEps(ec.getRDiagBinPosCount(), eps));
                        assertTrue(first.getRDiagBinTotalCount().equalsWithEps(ec.getRDiagBinTotalCount(), eps));
                        assertTrue(first.getRDiagBinSumPredictions().equalsWithEps(ec.getRDiagBinSumPredictions(), eps));
                        assertArrayEquals(first.getLabelCountsEachClass(), ec.getLabelCountsEachClass());
                        assertArrayEquals(first.getPredictionCountsEachClass(), ec.getPredictionCountsEachClass());
                        assertTrue(first.getProbHistogramOverall().equalsWithEps(ec.getProbHistogramOverall(), eps));
                        assertTrue(first.getProbHistogramByLabelClass().equalsWithEps(ec.getProbHistogramByLabelClass(), eps));
                    }
                }
            }
        }
    } finally {
        Nd4j.setDefaultDataTypes(defaultDtypeBefore, fpDtypeBefore);
    }
}
 
Example 8
Source File: EvalTest.java — from deeplearning4j (Apache License 2.0), 4 votes
/**
 * Checks per-class counts and accuracy of Evaluation for a fixed 2-class confusion matrix,
 * for every combination of global default dtype and label/prediction dtype, and that
 * stats()/equals agree with the first run across dtypes.
 */
@Test
public void testEval2() {
    // Snapshot BOTH global defaults: the loop changes the default data type as well as the
    // default floating-point type, so both must be restored exactly in the finally block.
    DataType defaultDtypeBefore = Nd4j.dataType();
    DataType fpDtypeBefore = Nd4j.defaultFloatingPointType();
    Evaluation first = null;
    String sFirst = null;
    try {
        for (DataType globalDtype : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF, DataType.INT}) {
            // Non-floating global types (INT) are paired with DOUBLE as the floating-point default.
            Nd4j.setDefaultDataTypes(globalDtype, globalDtype.isFPType() ? globalDtype : DataType.DOUBLE);
            for (DataType lpDtype : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF}) {

                //Confusion matrix:
                //actual 0      20      3
                //actual 1      10      5

                Evaluation evaluation = new Evaluation(Arrays.asList("class0", "class1"));
                INDArray predicted0 = Nd4j.create(new double[]{1, 0}, new long[]{1, 2}).castTo(lpDtype);
                INDArray predicted1 = Nd4j.create(new double[]{0, 1}, new long[]{1, 2}).castTo(lpDtype);
                INDArray actual0 = Nd4j.create(new double[]{1, 0}, new long[]{1, 2}).castTo(lpDtype);
                INDArray actual1 = Nd4j.create(new double[]{0, 1}, new long[]{1, 2}).castTo(lpDtype);
                for (int i = 0; i < 20; i++) {
                    evaluation.eval(actual0, predicted0);
                }

                for (int i = 0; i < 3; i++) {
                    evaluation.eval(actual0, predicted1);
                }

                for (int i = 0; i < 10; i++) {
                    evaluation.eval(actual1, predicted0);
                }

                for (int i = 0; i < 5; i++) {
                    evaluation.eval(actual1, predicted1);
                }

                assertEquals(20, evaluation.truePositives().get(0), 0);
                assertEquals(3, evaluation.falseNegatives().get(0), 0);
                assertEquals(10, evaluation.falsePositives().get(0), 0);
                assertEquals(5, evaluation.trueNegatives().get(0), 0);

                assertEquals((20.0 + 5) / (20 + 3 + 10 + 5), evaluation.accuracy(), 1e-6);

                String s = evaluation.stats();

                if(first == null) {
                    first = evaluation;
                    sFirst = s;
                } else {
                    assertEquals(first, evaluation);
                    assertEquals(sFirst, s);
                }
            }
        }
    } finally {
        Nd4j.setDefaultDataTypes(defaultDtypeBefore, fpDtypeBefore);
    }
}
 
Example 9
Source File: RegressionEvalTest.java — from deeplearning4j (Apache License 2.0), 4 votes
/**
 * Checks RegressionEvaluation metrics (MSE, MAE, RMSE, RSE, Pearson correlation, R^2) against
 * precomputed values, for every combination of global default dtype and label/prediction dtype.
 * Each combination is evaluated twice with a fresh evaluation object.
 */
@Test
public void testKnownValues() {
    // Snapshot BOTH global defaults: the loop changes the default data type as well as the
    // default floating-point type, so both must be restored exactly in the finally block.
    DataType defaultDtypeBefore = Nd4j.dataType();
    DataType fpDtypeBefore = Nd4j.defaultFloatingPointType();
    RegressionEvaluation first = null;
    String sFirst = null;
    try {
        for (DataType globalDtype : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF, DataType.INT}) {
            // Non-floating global types (INT) are paired with DOUBLE as the floating-point default.
            Nd4j.setDefaultDataTypes(globalDtype, globalDtype.isFPType() ? globalDtype : DataType.DOUBLE);
            for (DataType lpDtype : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF}) {
                // Looser tolerance for FP16 labels/predictions.
                double eps = lpDtype == DataType.HALF ? 1e-2 : 1e-4;

                double[][] labelsD = new double[][]{{1, 2, 3}, {0.1, 0.2, 0.3}, {6, 5, 4}};
                double[][] predictedD = new double[][]{{2.5, 3.2, 3.8}, {2.15, 1.3, -1.2}, {7, 4.5, 3}};

                double[] expMSE = {2.484166667, 0.966666667, 1.296666667};
                double[] expMAE = {1.516666667, 0.933333333, 1.1};
                double[] expRSE = {0.368813923, 0.246598639, 0.530937216};
                double[] expCorrs = {0.997013483, 0.968619605, 0.915603032};
                double[] expR2 = {0.63118608, 0.75340136, 0.46906278};

                INDArray labels = Nd4j.create(labelsD).castTo(lpDtype);
                INDArray predicted = Nd4j.create(predictedD).castTo(lpDtype);

                RegressionEvaluation eval = new RegressionEvaluation(3);

                for (int xe = 0; xe < 2; xe++) {
                    eval.eval(labels, predicted);

                    for (int col = 0; col < 3; col++) {
                        assertEquals(expMSE[col], eval.meanSquaredError(col), eps);
                        assertEquals(expMAE[col], eval.meanAbsoluteError(col), eps);
                        assertEquals(Math.sqrt(expMSE[col]), eval.rootMeanSquaredError(col), eps);
                        assertEquals(expRSE[col], eval.relativeSquaredError(col), eps);
                        assertEquals(expCorrs[col], eval.pearsonCorrelation(col), eps);
                        assertEquals(expR2[col], eval.rSquared(col), eps);
                    }

                    String s = eval.stats();
                    if(first == null) {
                        first = eval;
                        sFirst = s;
                    } else if(lpDtype != DataType.HALF) {   //Precision issues with FP16
                        assertEquals(sFirst, s);
                        assertEquals(first, eval);
                    }

                    // Fresh instance so the second pass starts from an empty evaluation.
                    eval = new RegressionEvaluation(3);
                }
            }
        }
    } finally {
        Nd4j.setDefaultDataTypes(defaultDtypeBefore, fpDtypeBefore);
    }
}
 
Example 10
Source File: EvaluationBinaryTest.java — from deeplearning4j (Apache License 2.0), 4 votes
/**
 * Compares per-output metrics of EvaluationBinary against a single-column Evaluation for each
 * output, for every combination of global default dtype and label/prediction dtype, and checks
 * that stats()/equals agree with the first run across dtypes.
 */
@Test
public void testEvaluationBinary() {
    //Compare EvaluationBinary to Evaluation class
    // Snapshot BOTH global defaults: the loop changes the default data type as well as the
    // default floating-point type, so both must be restored exactly in the finally block.
    DataType defaultDtypeBefore = Nd4j.dataType();
    DataType fpDtypeBefore = Nd4j.defaultFloatingPointType();
    EvaluationBinary first = null;
    String sFirst = null;
    try {
        for (DataType globalDtype : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF, DataType.INT}) {
            // Non-floating global types (INT) are paired with DOUBLE as the floating-point default.
            Nd4j.setDefaultDataTypes(globalDtype, globalDtype.isFPType() ? globalDtype : DataType.DOUBLE);
            for (DataType lpDtype : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF}) {

                Nd4j.getRandom().setSeed(12345);

                int nExamples = 50;
                int nOut = 4;
                long[] shape = {nExamples, nOut};

                INDArray labels = Nd4j.getExecutioner().exec(new BernoulliDistribution(Nd4j.createUninitialized(lpDtype, shape), 0.5));

                INDArray predicted = Nd4j.rand(lpDtype, shape);
                INDArray binaryPredicted = predicted.gt(0.5);

                EvaluationBinary eb = new EvaluationBinary();
                eb.eval(labels, predicted);

                double eps = 1e-6;
                for (int i = 0; i < nOut; i++) {
                    INDArray lCol = labels.getColumn(i,true);
                    INDArray pCol = predicted.getColumn(i,true);
                    INDArray bpCol = binaryPredicted.getColumn(i,true);

                    int countCorrect = 0;
                    int tpCount = 0;
                    int tnCount = 0;
                    for (int j = 0; j < lCol.length(); j++) {
                        if (lCol.getDouble(j) == bpCol.getDouble(j)) {
                            countCorrect++;
                            if (lCol.getDouble(j) == 1) {
                                tpCount++;
                            } else {
                                tnCount++;
                            }
                        }
                    }
                    double acc = countCorrect / (double) lCol.length();

                    Evaluation e = new Evaluation();
                    e.eval(lCol, pCol);

                    assertEquals(acc, eb.accuracy(i), eps);
                    assertEquals(e.accuracy(), eb.scoreForMetric(ACCURACY, i), eps);
                    assertEquals(e.precision(1), eb.scoreForMetric(PRECISION, i), eps);
                    assertEquals(e.recall(1), eb.scoreForMetric(RECALL, i), eps);
                    assertEquals(e.f1(1), eb.scoreForMetric(F1, i), eps);
                    assertEquals(e.falseAlarmRate(), eb.scoreForMetric(FAR, i), eps);
                    assertEquals(e.falsePositiveRate(1), eb.falsePositiveRate(i), eps);


                    assertEquals(tpCount, eb.truePositives(i));
                    assertEquals(tnCount, eb.trueNegatives(i));

                    assertEquals((int) e.truePositives().get(1), eb.truePositives(i));
                    assertEquals((int) e.trueNegatives().get(1), eb.trueNegatives(i));
                    assertEquals((int) e.falsePositives().get(1), eb.falsePositives(i));
                    assertEquals((int) e.falseNegatives().get(1), eb.falseNegatives(i));

                    assertEquals(nExamples, eb.totalCount(i));
                }

                // The column loop only reads from eb (assumed side-effect-free queries), so the
                // cross-dtype comparison is done once per dtype pair rather than once per column.
                String s = eb.stats();
                if(first == null) {
                    first = eb;
                    sFirst = s;
                } else {
                    assertEquals(first, eb);
                    assertEquals(sFirst, s);
                }
            }
        }
    } finally {
        Nd4j.setDefaultDataTypes(defaultDtypeBefore, fpDtypeBefore);
    }
}
 
Example 11
Source File: ROCBinaryTest.java — from deeplearning4j (Apache License 2.0), 4 votes
/**
 * Compares per-output results of ROCBinary (AUC, actual positive/negative counts, precision-recall
 * curve) against a single-column ROC for each output, with 30 threshold steps and with exact mode
 * (0), for every combination of global default dtype and label/prediction dtype. Also checks that
 * stats()/equals agree with the first run across non-FP16 dtypes.
 */
@Test
public void testROCBinary() {
    //Compare ROCBinary to ROC class

    // Snapshot BOTH global defaults: the loop changes the default data type as well as the
    // default floating-point type, so both must be restored exactly in the finally block.
    DataType defaultDtypeBefore = Nd4j.dataType();
    DataType fpDtypeBefore = Nd4j.defaultFloatingPointType();
    ROCBinary first30 = null;
    ROCBinary first0 = null;
    String sFirst30 = null;
    String sFirst0 = null;
    try {
        for (DataType globalDtype : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF, DataType.INT}) {
            // Non-floating global types (INT) are paired with DOUBLE as the floating-point default.
            Nd4j.setDefaultDataTypes(globalDtype, globalDtype.isFPType() ? globalDtype : DataType.DOUBLE);
            for (DataType lpDtype : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF}) {
                String msg = "globalDtype=" + globalDtype + ", labelPredictionsDtype=" + lpDtype;
                // Looser AUC tolerance for FP16 labels/predictions.
                double eps = lpDtype == DataType.HALF ? 1e-2 : 1e-6;

                int nExamples = 50;
                int nOut = 4;
                long[] shape = {nExamples, nOut};

                for (int thresholdSteps : new int[]{30, 0}) { //0 == exact

                    // Generated in DOUBLE under a fixed seed and then cast, so every lpDtype
                    // starts from the same DOUBLE values.
                    Nd4j.getRandom().setSeed(12345);
                    INDArray labels =
                            Nd4j.getExecutioner().exec(new BernoulliDistribution(Nd4j.createUninitialized(DataType.DOUBLE, shape), 0.5)).castTo(lpDtype);

                    Nd4j.getRandom().setSeed(12345);
                    INDArray predicted = Nd4j.rand(DataType.DOUBLE, shape).castTo(lpDtype);

                    ROCBinary rb = new ROCBinary(thresholdSteps);

                    for (int xe = 0; xe < 2; xe++) {
                        rb.eval(labels, predicted);

                        for (int i = 0; i < nOut; i++) {
                            INDArray lCol = labels.getColumn(i, true);
                            INDArray pCol = predicted.getColumn(i, true);

                            ROC r = new ROC(thresholdSteps);
                            r.eval(lCol, pCol);

                            double aucExp = r.calculateAUC();
                            double auc = rb.calculateAUC(i);
                            assertEquals(msg, aucExp, auc, eps);

                            // (message, expected, actual) order so failure output is not inverted.
                            long apExp = r.getCountActualPositive();
                            long ap = rb.getCountActualPositive(i);
                            assertEquals(msg, apExp, ap);

                            long anExp = r.getCountActualNegative();
                            long an = rb.getCountActualNegative(i);
                            assertEquals(msg, anExp, an);

                            PrecisionRecallCurve pExp = r.getPrecisionRecallCurve();
                            PrecisionRecallCurve p = rb.getPrecisionRecallCurve(i);
                            assertEquals(msg, pExp, p);
                        }

                        String s = rb.stats();

                        if(thresholdSteps == 0){
                            if(first0 == null) {
                                first0 = rb;
                                sFirst0 = s;
                            } else if(lpDtype != DataType.HALF) {   //Precision issues with FP16
                                assertEquals(msg, sFirst0, s);
                                assertEquals(msg, first0, rb);
                            }
                        } else {
                            if(first30 == null) {
                                first30 = rb;
                                sFirst30 = s;
                            } else if(lpDtype != DataType.HALF) {   //Precision issues with FP16
                                assertEquals(msg, sFirst30, s);
                                assertEquals(msg, first30, rb);
                            }
                        }

                        // Fresh instance so the second pass (xe == 1) evaluates from an empty state.
                        rb = new ROCBinary(thresholdSteps);
                    }
                }
            }
        }
    } finally {
        Nd4j.setDefaultDataTypes(defaultDtypeBefore, fpDtypeBefore);
    }
}
 
Example 12
Source File: MaxOut.java — from deeplearning4j (Apache License 2.0), 4 votes
/**
 * Returns the global default floating-point type; the input is not consulted.
 */
@Override
public DataType resultType() {
    DataType fpType = Nd4j.defaultFloatingPointType();
    return fpType;
}
 
Example 13
Source File: MaxOut.java — from deeplearning4j (Apache License 2.0), 4 votes
/**
 * Returns the global default floating-point type; the op context is ignored.
 *
 * @param oc unused
 */
@Override
public DataType resultType(OpContext oc) {
    DataType fpType = Nd4j.defaultFloatingPointType();
    return fpType;
}
 
Example 14
Source File: JaccardDistance.java — from deeplearning4j (Apache License 2.0), 4 votes
/**
 * Returns the global default floating-point type, independent of the input arrays.
 */
@Override
public DataType resultType() {
    DataType fpType = Nd4j.defaultFloatingPointType();
    return fpType;
}
 
Example 15
Source File: BaseReduce3Op.java — from deeplearning4j (Apache License 2.0), 4 votes
/**
 * Returns the input's data type when the input is set and floating-point, otherwise the global
 * default floating-point type.
 */
@Override
public DataType resultType() {
    // Null guard added: previously an unset x threw an NPE, unlike the other resultType implementations.
    if (x != null && x.dataType().isFPType()) {
        return x.dataType();
    }
    return Nd4j.defaultFloatingPointType();
}
 
Example 16
Source File: TestEarlyStoppingCompGraph.java — from deeplearning4j (Apache License 2.0), 4 votes
/**
 * Pretrains a single-AutoEncoder ComputationGraph with early stopping, scored by an
 * AutoencoderScoreCalculator (MSE and MAE), on 10 MNIST minibatches whose labels are the features.
 * It checks that a best model is produced and that its score is positive.
 */
@Test
public void testAEScoreFunctionSimple() throws Exception {
    // (Removed an unused local that only read Nd4j.defaultFloatingPointType().)
    for(Metric metric : new Metric[]{Metric.MSE,
            Metric.MAE}) {
        log.info("Metric: " + metric);

        ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
                .graphBuilder()
                .addInputs("in")
                .layer("0", new AutoEncoder.Builder().nIn(784).nOut(32).build(), "in")
                .setOutputs("0")
                .build();

        ComputationGraph net = new ComputationGraph(conf);
        net.init();

        DataSetIterator iter = new MnistDataSetIterator(32, false, 12345);

        // Autoencoder setup: each example's features are also used as its labels.
        List<DataSet> l = new ArrayList<>();
        for (int i = 0; i < 10; i++) {
            DataSet ds = iter.next();
            l.add(new DataSet(ds.getFeatures(), ds.getFeatures()));
        }

        iter = new ExistingDataSetIterator(l);

        EarlyStoppingModelSaver<ComputationGraph> saver = new InMemoryModelSaver<>();
        EarlyStoppingConfiguration<ComputationGraph> esConf =
                new EarlyStoppingConfiguration.Builder<ComputationGraph>()
                        .epochTerminationConditions(new MaxEpochsTerminationCondition(5))
                        .iterationTerminationConditions(
                                new MaxTimeIterationTerminationCondition(1, TimeUnit.MINUTES))
                        .scoreCalculator(new AutoencoderScoreCalculator(metric, iter)).modelSaver(saver)
                        .build();

        EarlyStoppingGraphTrainer trainer = new EarlyStoppingGraphTrainer(esConf, net, iter);
        EarlyStoppingResult<ComputationGraph> result = trainer.pretrain();

        assertNotNull(result.getBestModel());
        assertTrue(result.getBestModelScore() > 0.0);
    }
}