org.deeplearning4j.earlystopping.trainer.EarlyStoppingTrainer Java Examples

The following examples show how to use org.deeplearning4j.earlystopping.trainer.EarlyStoppingTrainer. You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may check out the related API usage on the sidebar.
Example #1
Source File: TestEarlyStopping.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
@Test
public void testEarlyStoppingEveryNEpoch() {
    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                    .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                    .updater(new Sgd(0.01)).weightInit(WeightInit.XAVIER).list()
                    .layer(0, new OutputLayer.Builder().nIn(4).nOut(3)
                            .activation(Activation.SOFTMAX)
                                    .lossFunction(LossFunctions.LossFunction.MCXENT).build())
                    .build();
    MultiLayerNetwork net = new MultiLayerNetwork(conf);
    net.setListeners(new ScoreIterationListener(1));

    DataSetIterator irisIter = new IrisDataSetIterator(150, 150);
    EarlyStoppingModelSaver<MultiLayerNetwork> saver = new InMemoryModelSaver<>();
    EarlyStoppingConfiguration<MultiLayerNetwork> esConf =
                    new EarlyStoppingConfiguration.Builder<MultiLayerNetwork>()
                                    .epochTerminationConditions(new MaxEpochsTerminationCondition(5))
                                    .scoreCalculator(new DataSetLossCalculator(irisIter, true))
                                    .evaluateEveryNEpochs(2).modelSaver(saver).build();

    IEarlyStoppingTrainer<MultiLayerNetwork> trainer = new EarlyStoppingTrainer(esConf, net, irisIter);

    EarlyStoppingResult<MultiLayerNetwork> result = trainer.fit();
    System.out.println(result);

    assertEquals(5, result.getTotalEpochs());
    assertEquals(EarlyStoppingResult.TerminationReason.EpochTerminationCondition, result.getTerminationReason());
}
 
Example #2
Source File: TestEarlyStopping.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
@Test
public void testBadTuning() {
    //Test poor tuning (high LR): should terminate on MaxScoreIterationTerminationCondition

    Nd4j.getRandom().setSeed(12345);
    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345)
                    .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                    .updater(new Sgd(5.0)) //Intentionally huge LR
                    .weightInit(WeightInit.XAVIER).list()
                    .layer(0, new OutputLayer.Builder().nIn(4).nOut(3).activation(Activation.SOFTMAX)
                                    .lossFunction(LossFunctions.LossFunction.MCXENT).build())
                    .build();
    MultiLayerNetwork net = new MultiLayerNetwork(conf);
    net.setListeners(new ScoreIterationListener(1));

    DataSetIterator irisIter = new IrisDataSetIterator(150, 150);
    EarlyStoppingModelSaver<MultiLayerNetwork> saver = new InMemoryModelSaver<>();
    EarlyStoppingConfiguration<MultiLayerNetwork> esConf =
                    new EarlyStoppingConfiguration.Builder<MultiLayerNetwork>()
                                    .epochTerminationConditions(new MaxEpochsTerminationCondition(5000))
                                    .iterationTerminationConditions(
                                                    new MaxTimeIterationTerminationCondition(1, TimeUnit.MINUTES),
                                                    new MaxScoreIterationTerminationCondition(10)) //Initial score is ~2.5
                                    .scoreCalculator(new DataSetLossCalculator(irisIter, true)).modelSaver(saver)
                                    .build();

    IEarlyStoppingTrainer trainer = new EarlyStoppingTrainer(esConf, net, irisIter);
    EarlyStoppingResult result = trainer.fit();

    assertTrue(result.getTotalEpochs() < 5);
    assertEquals(EarlyStoppingResult.TerminationReason.IterationTerminationCondition,
                    result.getTerminationReason());
    String expDetails = new MaxScoreIterationTerminationCondition(10).toString();
    assertEquals(expDetails, result.getTerminationDetails());

    assertEquals(0, result.getBestModelEpoch());
    assertNotNull(result.getBestModel());
}
 
Example #3
Source File: TestEarlyStopping.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
@Test
public void testNoImprovementNEpochsTermination() {
    //Idea: terminate training if score (test set loss) does not improve for 5 consecutive epochs
    //Simulate this by setting LR = 0.0

    Nd4j.getRandom().setSeed(12345);
    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345)
                    .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                    .updater(new Sgd(0.0)).weightInit(WeightInit.XAVIER).list()
                    .layer(0, new OutputLayer.Builder().nIn(4).nOut(3)
                            .activation(Activation.SOFTMAX)
                                    .lossFunction(LossFunctions.LossFunction.MCXENT).build())
                    .build();
    MultiLayerNetwork net = new MultiLayerNetwork(conf);
    net.setListeners(new ScoreIterationListener(1));

    DataSetIterator irisIter = new IrisDataSetIterator(150, 150);

    EarlyStoppingModelSaver<MultiLayerNetwork> saver = new InMemoryModelSaver<>();
    EarlyStoppingConfiguration<MultiLayerNetwork> esConf =
                    new EarlyStoppingConfiguration.Builder<MultiLayerNetwork>()
                                    .epochTerminationConditions(new MaxEpochsTerminationCondition(100),
                                                    new ScoreImprovementEpochTerminationCondition(5))
                                    .iterationTerminationConditions(
                                                    new MaxTimeIterationTerminationCondition(1, TimeUnit.MINUTES),
                                                    new MaxScoreIterationTerminationCondition(50)) //Initial score is ~8
                                    .scoreCalculator(new DataSetLossCalculator(irisIter, true)).modelSaver(saver)
                                    .build();

    IEarlyStoppingTrainer trainer = new EarlyStoppingTrainer(esConf, net, irisIter);
    EarlyStoppingResult result = trainer.fit();

    //Expect no score change due to 0 LR -> terminate after 6 total epochs
    assertEquals(6, result.getTotalEpochs());
    assertEquals(0, result.getBestModelEpoch());
    assertEquals(EarlyStoppingResult.TerminationReason.EpochTerminationCondition, result.getTerminationReason());
    String expDetails = new ScoreImprovementEpochTerminationCondition(5).toString();
    assertEquals(expDetails, result.getTerminationDetails());
}
 
Example #4
Source File: TestEarlyStopping.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
@Test
public void testEarlyStoppingGetBestModel() {
    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                    .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                    .updater(new Sgd(0.001)).weightInit(WeightInit.XAVIER).list()
                    .layer(0, new OutputLayer.Builder().nIn(4).nOut(3)
                            .activation(Activation.SOFTMAX)
                                    .lossFunction(LossFunctions.LossFunction.MCXENT).build())
                    .build();
    MultiLayerNetwork net = new MultiLayerNetwork(conf);
    net.setListeners(new ScoreIterationListener(1));

    DataSetIterator irisIter = new IrisDataSetIterator(150, 150);
    MultipleEpochsIterator mIter = new MultipleEpochsIterator(10, irisIter);

    EarlyStoppingModelSaver<MultiLayerNetwork> saver = new InMemoryModelSaver<>();
    EarlyStoppingConfiguration<MultiLayerNetwork> esConf =
                    new EarlyStoppingConfiguration.Builder<MultiLayerNetwork>()
                                    .epochTerminationConditions(new MaxEpochsTerminationCondition(5))
                                    .iterationTerminationConditions(
                                                    new MaxTimeIterationTerminationCondition(1, TimeUnit.MINUTES))
                                    .scoreCalculator(new DataSetLossCalculator(irisIter, true)).modelSaver(saver)
                                    .build();

    IEarlyStoppingTrainer<MultiLayerNetwork> trainer = new EarlyStoppingTrainer(esConf, net, mIter);

    EarlyStoppingResult<MultiLayerNetwork> result = trainer.fit();
    System.out.println(result);

    MultiLayerNetwork mln = result.getBestModel();

    assertEquals(net.getnLayers(), mln.getnLayers());
    assertEquals(net.conf().getOptimizationAlgo(), mln.conf().getOptimizationAlgo());
    BaseLayer bl = (BaseLayer) net.conf().getLayer();
    assertEquals(bl.getActivationFn().toString(), ((BaseLayer) mln.conf().getLayer()).getActivationFn().toString());
    assertEquals(bl.getIUpdater(), ((BaseLayer) mln.conf().getLayer()).getIUpdater());
}
 
Example #5
Source File: TestEarlyStopping.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
@Test
public void testListeners() {
    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                    .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                    .updater(new Sgd(0.001)).weightInit(WeightInit.XAVIER).list()
                    .layer(0, new OutputLayer.Builder().nIn(4).nOut(3)
                            .activation(Activation.SOFTMAX)
                                    .lossFunction(LossFunctions.LossFunction.MCXENT).build())
                    .build();
    MultiLayerNetwork net = new MultiLayerNetwork(conf);
    net.setListeners(new ScoreIterationListener(1));

    DataSetIterator irisIter = new IrisDataSetIterator(150, 150);
    EarlyStoppingModelSaver<MultiLayerNetwork> saver = new InMemoryModelSaver<>();
    EarlyStoppingConfiguration<MultiLayerNetwork> esConf =
                    new EarlyStoppingConfiguration.Builder<MultiLayerNetwork>()
                                    .epochTerminationConditions(new MaxEpochsTerminationCondition(5))
                                    .iterationTerminationConditions(
                                                    new MaxTimeIterationTerminationCondition(1, TimeUnit.MINUTES))
                                    .scoreCalculator(new DataSetLossCalculator(irisIter, true)).modelSaver(saver)
                                    .build();

    LoggingEarlyStoppingListener listener = new LoggingEarlyStoppingListener();

    IEarlyStoppingTrainer trainer = new EarlyStoppingTrainer(esConf, net, irisIter, listener);

    trainer.fit();

    assertEquals(1, listener.onStartCallCount);
    assertEquals(5, listener.onEpochCallCount);
    assertEquals(1, listener.onCompletionCallCount);
}
 
Example #6
Source File: TestEarlyStopping.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
@Test
public void testEarlyStoppingListeners() {
    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
            .updater(new Sgd(0.001)).weightInit(WeightInit.XAVIER).list()
            .layer(0, new OutputLayer.Builder().nIn(4).nOut(3)
                    .activation(Activation.SOFTMAX)
                    .lossFunction(LossFunctions.LossFunction.MCXENT).build())
            .build();
    MultiLayerNetwork net = new MultiLayerNetwork(conf);

    TestListener tl = new TestListener();
    net.setListeners(tl);

    DataSetIterator irisIter = new IrisDataSetIterator(50, 150);
    EarlyStoppingModelSaver<MultiLayerNetwork> saver = new InMemoryModelSaver<>();
    EarlyStoppingConfiguration<MultiLayerNetwork> esConf =
            new EarlyStoppingConfiguration.Builder<MultiLayerNetwork>()
                    .epochTerminationConditions(new MaxEpochsTerminationCondition(5))
                    .iterationTerminationConditions(
                            new MaxTimeIterationTerminationCondition(1, TimeUnit.MINUTES))
                    .scoreCalculator(new DataSetLossCalculator(irisIter, true)).modelSaver(saver)
                    .build();

    IEarlyStoppingTrainer<MultiLayerNetwork> trainer = new EarlyStoppingTrainer(esConf, net, irisIter);

    EarlyStoppingResult<MultiLayerNetwork> result = trainer.fit();

    assertEquals(5, tl.countEpochStart);
    assertEquals(5, tl.countEpochEnd);
    assertEquals(5 * 150/50, tl.iterCount);

    assertEquals(4, tl.maxEpochStart);
    assertEquals(4, tl.maxEpochEnd);
}
 
Example #7
Source File: TestEarlyStopping.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
@Test
public void testClassificationScoreFunctionSimple() throws Exception {

    for(Evaluation.Metric metric : Evaluation.Metric.values()) {
        log.info("Metric: " + metric);

        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .list()
                .layer(new DenseLayer.Builder().nIn(784).nOut(32).build())
                .layer(new OutputLayer.Builder().nIn(32).nOut(10).activation(Activation.SOFTMAX).build())
                .build();

        MultiLayerNetwork net = new MultiLayerNetwork(conf);
        net.init();

        DataSetIterator iter = new MnistDataSetIterator(32, false, 12345);

        List<DataSet> l = new ArrayList<>();
        for( int i=0; i<10; i++ ){
            DataSet ds = iter.next();
            l.add(ds);
        }

        iter = new ExistingDataSetIterator(l);

        EarlyStoppingModelSaver<MultiLayerNetwork> saver = new InMemoryModelSaver<>();
        EarlyStoppingConfiguration<MultiLayerNetwork> esConf =
                new EarlyStoppingConfiguration.Builder<MultiLayerNetwork>()
                        .epochTerminationConditions(new MaxEpochsTerminationCondition(5))
                        .iterationTerminationConditions(
                                new MaxTimeIterationTerminationCondition(1, TimeUnit.MINUTES))
                        .scoreCalculator(new ClassificationScoreCalculator(metric, iter)).modelSaver(saver)
                        .build();

        EarlyStoppingTrainer trainer = new EarlyStoppingTrainer(esConf, net, iter);
        EarlyStoppingResult<MultiLayerNetwork> result = trainer.fit();

        assertNotNull(result.getBestModel());
    }
}
 
Example #8
Source File: TestEarlyStopping.java    From deeplearning4j with Apache License 2.0 4 votes vote down vote up
@Test
public void testAEScoreFunctionSimple() throws Exception {

    for(Metric metric : new Metric[]{Metric.MSE,
            Metric.MAE}) {
        log.info("Metric: " + metric);

        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .list()
                .layer(new AutoEncoder.Builder().nIn(784).nOut(32).build())

                .build();

        MultiLayerNetwork net = new MultiLayerNetwork(conf);
        net.init();

        DataSetIterator iter = new MnistDataSetIterator(32, false, 12345);

        List<DataSet> l = new ArrayList<>();
        for( int i=0; i<10; i++ ){
            DataSet ds = iter.next();
            l.add(new DataSet(ds.getFeatures(), ds.getFeatures()));
        }

        iter = new ExistingDataSetIterator(l);

        EarlyStoppingModelSaver<MultiLayerNetwork> saver = new InMemoryModelSaver<>();
        EarlyStoppingConfiguration<MultiLayerNetwork> esConf =
                new EarlyStoppingConfiguration.Builder<MultiLayerNetwork>()
                        .epochTerminationConditions(new MaxEpochsTerminationCondition(5))
                        .iterationTerminationConditions(
                                new MaxTimeIterationTerminationCondition(1, TimeUnit.MINUTES))
                        .scoreCalculator(new AutoencoderScoreCalculator(metric, iter)).modelSaver(saver)
                        .build();

        EarlyStoppingTrainer trainer = new EarlyStoppingTrainer(esConf, net, iter);
        EarlyStoppingResult<MultiLayerNetwork> result = trainer.pretrain();

        assertNotNull(result.getBestModel());
        assertTrue(result.getBestModelScore() > 0.0);
    }
}
 
Example #9
Source File: LearnIrisBackprop.java    From aifh with Apache License 2.0 4 votes vote down vote up
/**
 * The main method.
 * @param args Not used.
 */
public static void main(String[] args) {
    try {
        int seed = 43;
        double learningRate = 0.1;
        int splitTrainNum = (int) (150 * .75);

        int numInputs = 4;
        int numOutputs = 3;
        int numHiddenNodes = 50;

        // Setup training data.
        final InputStream istream = LearnIrisBackprop.class.getResourceAsStream("/iris.csv");
        if( istream==null ) {
            System.out.println("Cannot access data set, make sure the resources are available.");
            System.exit(1);
        }
        final NormalizeDataSet ds = NormalizeDataSet.load(istream);
        final CategoryMap species = ds.encodeOneOfN(4); // species is column 4
        istream.close();

        DataSet next = ds.extractSupervised(0, 4, 4, 3);
        next.shuffle();

        // Training and validation data split
        SplitTestAndTrain testAndTrain = next.splitTestAndTrain(splitTrainNum, new Random(seed));
        DataSet trainSet = testAndTrain.getTrain();
        DataSet validationSet = testAndTrain.getTest();

        DataSetIterator trainSetIterator = new ListDataSetIterator(trainSet.asList(), trainSet.numExamples());

        DataSetIterator validationSetIterator = new ListDataSetIterator(validationSet.asList(), validationSet.numExamples());

        // Create neural network.
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .seed(seed)
                .iterations(1)
                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                .learningRate(learningRate)
                .updater(Updater.NESTEROVS).momentum(0.9)
                .list(2)
                .layer(0, new DenseLayer.Builder().nIn(numInputs).nOut(numHiddenNodes)
                        .weightInit(WeightInit.XAVIER)
                        .activation("relu")
                        .build())
                .layer(1, new OutputLayer.Builder(LossFunction.NEGATIVELOGLIKELIHOOD)
                        .weightInit(WeightInit.XAVIER)
                        .activation("softmax")
                        .nIn(numHiddenNodes).nOut(numOutputs).build())
                .pretrain(false).backprop(true).build();


        MultiLayerNetwork model = new MultiLayerNetwork(conf);
        model.init();
        model.setListeners(new ScoreIterationListener(1));

        // Define when we want to stop training.
        EarlyStoppingModelSaver saver = new InMemoryModelSaver();
        EarlyStoppingConfiguration esConf = new EarlyStoppingConfiguration.Builder()
                .epochTerminationConditions(new MaxEpochsTerminationCondition(500)) //Max of 50 epochs
                .epochTerminationConditions(new ScoreImprovementEpochTerminationCondition(25))
                .evaluateEveryNEpochs(1)
                .scoreCalculator(new DataSetLossCalculator(validationSetIterator, true))     //Calculate test set score
                .modelSaver(saver)
                .build();
        EarlyStoppingTrainer trainer = new EarlyStoppingTrainer(esConf, conf, trainSetIterator);

        // Train and display result.
        EarlyStoppingResult result = trainer.fit();
        System.out.println("Termination reason: " + result.getTerminationReason());
        System.out.println("Termination details: " + result.getTerminationDetails());
        System.out.println("Total epochs: " + result.getTotalEpochs());
        System.out.println("Best epoch number: " + result.getBestModelEpoch());
        System.out.println("Score at best epoch: " + result.getBestModelScore());

        model = saver.getBestModel();

        // Evaluate
        Evaluation eval = new Evaluation(numOutputs);
        validationSetIterator.reset();

        for (int i = 0; i < validationSet.numExamples(); i++) {
            DataSet t = validationSet.get(i);
            INDArray features = t.getFeatureMatrix();
            INDArray labels = t.getLabels();
            INDArray predicted = model.output(features, false);
            System.out.println(features + ":Prediction("+findSpecies(labels,species)
                    +"):Actual("+findSpecies(predicted,species)+")" + predicted );
            eval.eval(labels, predicted);
        }

        //Print the evaluation statistics
        System.out.println(eval.stats());
    } catch(Exception ex) {
        ex.printStackTrace();
    }
}
 
Example #10
Source File: LearnDigitsDropout.java    From aifh with Apache License 2.0 4 votes vote down vote up
/**
 * The main method.
 * @param args Not used.
 */
public static void main(String[] args) {
    try {
        int seed = 43;
        double learningRate = 1e-2;
        int nEpochs = 50;
        int batchSize = 500;

        // Setup training data.
        System.out.println("Please wait, reading MNIST training data.");
        String dir = System.getProperty("user.dir");
        MNISTReader trainingReader = MNIST.loadMNIST(dir, true);
        MNISTReader validationReader = MNIST.loadMNIST(dir, false);

        DataSet trainingSet = trainingReader.getData();
        DataSet validationSet = validationReader.getData();

        DataSetIterator trainSetIterator = new ListDataSetIterator(trainingSet.asList(), batchSize);
        DataSetIterator validationSetIterator = new ListDataSetIterator(validationSet.asList(), validationReader.getNumRows());

        System.out.println("Training set size: " + trainingReader.getNumImages());
        System.out.println("Validation set size: " + validationReader.getNumImages());

        System.out.println(trainingSet.get(0).getFeatures().size(1));
        System.out.println(validationSet.get(0).getFeatures().size(1));

        int numInputs = trainingReader.getNumCols()*trainingReader.getNumRows();
        int numOutputs = 10;
        int numHiddenNodes = 100;

        // Create neural network.
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .seed(seed)
                .iterations(1)
                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                .learningRate(learningRate)
                .updater(Updater.NESTEROVS).momentum(0.9)
                .list(2)
                .layer(0, new DenseLayer.Builder().nIn(numInputs).nOut(numHiddenNodes)
                        .weightInit(WeightInit.XAVIER)
                        .activation("relu")
                        .build())
                .layer(1, new OutputLayer.Builder(LossFunction.NEGATIVELOGLIKELIHOOD)
                        .weightInit(WeightInit.XAVIER)
                        .activation("softmax")
                        .nIn(numHiddenNodes).nOut(numOutputs).build())
                .pretrain(false).backprop(true).build();


        MultiLayerNetwork model = new MultiLayerNetwork(conf);
        model.init();
        model.setListeners(new ScoreIterationListener(1));

        // Define when we want to stop training.
        EarlyStoppingModelSaver saver = new InMemoryModelSaver();
        EarlyStoppingConfiguration esConf = new EarlyStoppingConfiguration.Builder()
                //.epochTerminationConditions(new MaxEpochsTerminationCondition(10))
                .epochTerminationConditions(new ScoreImprovementEpochTerminationCondition(5))
                .evaluateEveryNEpochs(1)
                .scoreCalculator(new DataSetLossCalculator(validationSetIterator, true))     //Calculate test set score
                .modelSaver(saver)
                .build();
        EarlyStoppingTrainer trainer = new EarlyStoppingTrainer(esConf, conf, trainSetIterator);

        // Train and display result.
        EarlyStoppingResult result = trainer.fit();
        System.out.println("Termination reason: " + result.getTerminationReason());
        System.out.println("Termination details: " + result.getTerminationDetails());
        System.out.println("Total epochs: " + result.getTotalEpochs());
        System.out.println("Best epoch number: " + result.getBestModelEpoch());
        System.out.println("Score at best epoch: " + result.getBestModelScore());

        model = saver.getBestModel();

        // Evaluate
        Evaluation eval = new Evaluation(numOutputs);
        validationSetIterator.reset();

        for (int i = 0; i < validationSet.numExamples(); i++) {
            DataSet t = validationSet.get(i);
            INDArray features = t.getFeatureMatrix();
            INDArray labels = t.getLabels();
            INDArray predicted = model.output(features, false);
            eval.eval(labels, predicted);
        }

        //Print the evaluation statistics
        System.out.println(eval.stats());
    } catch(Exception ex) {
        ex.printStackTrace();
    }

}
 
Example #11
Source File: LearnDigitsBackprop.java    From aifh with Apache License 2.0 4 votes vote down vote up
/**
 * The main method.
 * @param args Not used.
 */
public static void main(String[] args) {
    try {
        int seed = 43;
        double learningRate = 1e-2;
        int nEpochs = 50;
        int batchSize = 500;

        // Setup training data.
        System.out.println("Please wait, reading MNIST training data.");
        String dir = System.getProperty("user.dir");
        MNISTReader trainingReader = MNIST.loadMNIST(dir, true);
        MNISTReader validationReader = MNIST.loadMNIST(dir, false);

        DataSet trainingSet = trainingReader.getData();
        DataSet validationSet = validationReader.getData();

        DataSetIterator trainSetIterator = new ListDataSetIterator(trainingSet.asList(), batchSize);
        DataSetIterator validationSetIterator = new ListDataSetIterator(validationSet.asList(), validationReader.getNumRows());

        System.out.println("Training set size: " + trainingReader.getNumImages());
        System.out.println("Validation set size: " + validationReader.getNumImages());

        System.out.println(trainingSet.get(0).getFeatures().size(1));
        System.out.println(validationSet.get(0).getFeatures().size(1));

        int numInputs = trainingReader.getNumCols()*trainingReader.getNumRows();
        int numOutputs = 10;
        int numHiddenNodes = 200;

        // Create neural network.
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .seed(seed)
                .iterations(1)
                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                .learningRate(learningRate)
                .updater(Updater.NESTEROVS).momentum(0.9)
                .regularization(true).dropOut(0.50)
                .list(2)
                .layer(0, new DenseLayer.Builder().nIn(numInputs).nOut(numHiddenNodes)
                        .weightInit(WeightInit.XAVIER)
                        .activation("relu")
                        .build())
                .layer(1, new OutputLayer.Builder(LossFunction.NEGATIVELOGLIKELIHOOD)
                        .weightInit(WeightInit.XAVIER)
                        .activation("softmax")
                        .nIn(numHiddenNodes).nOut(numOutputs).build())
                .pretrain(false).backprop(true).build();


        MultiLayerNetwork model = new MultiLayerNetwork(conf);
        model.init();
        model.setListeners(new ScoreIterationListener(1));

        // Define when we want to stop training.
        EarlyStoppingModelSaver saver = new InMemoryModelSaver();
        EarlyStoppingConfiguration esConf = new EarlyStoppingConfiguration.Builder()
                //.epochTerminationConditions(new MaxEpochsTerminationCondition(10))
                .epochTerminationConditions(new ScoreImprovementEpochTerminationCondition(5))
                .evaluateEveryNEpochs(1)
                .scoreCalculator(new DataSetLossCalculator(validationSetIterator, true))     //Calculate test set score
                .modelSaver(saver)
                .build();
        EarlyStoppingTrainer trainer = new EarlyStoppingTrainer(esConf, conf, trainSetIterator);

        // Train and display result.
        EarlyStoppingResult result = trainer.fit();
        System.out.println("Termination reason: " + result.getTerminationReason());
        System.out.println("Termination details: " + result.getTerminationDetails());
        System.out.println("Total epochs: " + result.getTotalEpochs());
        System.out.println("Best epoch number: " + result.getBestModelEpoch());
        System.out.println("Score at best epoch: " + result.getBestModelScore());

        model = saver.getBestModel();

        // Evaluate
        Evaluation eval = new Evaluation(numOutputs);
        validationSetIterator.reset();

        for (int i = 0; i < validationSet.numExamples(); i++) {
            DataSet t = validationSet.get(i);
            INDArray features = t.getFeatureMatrix();
            INDArray labels = t.getLabels();
            INDArray predicted = model.output(features, false);
            eval.eval(labels, predicted);
        }

        //Print the evaluation statistics
        System.out.println(eval.stats());
    } catch(Exception ex) {
        ex.printStackTrace();
    }

}
 
Example #12
Source File: TestEarlyStopping.java    From deeplearning4j with Apache License 2.0 4 votes vote down vote up
@Test
public void testEarlyStoppingMaximizeScore() throws Exception {
    Nd4j.getRandom().setSeed(12345);

    int outputs = 2;

    DataSet ds = new DataSet(
            Nd4j.rand(new int[]{3, 10, 50}),
            TestUtils.randomOneHotTimeSeries(3, outputs, 50, 12345));
    DataSetIterator train = new ExistingDataSetIterator(
            Arrays.asList(ds, ds, ds, ds, ds, ds, ds, ds, ds, ds));
    DataSetIterator test = new SingletonDataSetIterator(ds);


    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
            .seed(123)
            .weightInit(WeightInit.XAVIER)
            .updater(new Adam(0.1))
            .activation(Activation.ELU)
            .l2(1e-5)
            .gradientNormalization(GradientNormalization
                    .ClipElementWiseAbsoluteValue)
            .gradientNormalizationThreshold(1.0)
            .list()
            .layer(0, new LSTM.Builder()
                    .nIn(10)
                    .nOut(10)
                    .activation(Activation.TANH)
                    .gateActivationFunction(Activation.SIGMOID)
                    .dropOut(0.5)
                    .build())
            .layer(1, new RnnOutputLayer.Builder()
                    .nIn(10)
                    .nOut(outputs)
                    .activation(Activation.SOFTMAX)
                    .lossFunction(LossFunctions.LossFunction.MCXENT)
                    .build())
            .build();

    File f = testDir.newFolder();
    EarlyStoppingModelSaver<MultiLayerNetwork> saver = new LocalFileModelSaver(f.getAbsolutePath());
    EarlyStoppingConfiguration<MultiLayerNetwork> esConf =
            new EarlyStoppingConfiguration.Builder<MultiLayerNetwork>()
                    .epochTerminationConditions(
                            new MaxEpochsTerminationCondition(10),
                            new ScoreImprovementEpochTerminationCondition(1))
                    .iterationTerminationConditions(
                            new MaxTimeIterationTerminationCondition(10, TimeUnit.MINUTES))
                    .scoreCalculator(new ClassificationScoreCalculator(Evaluation.Metric.F1, test))
                    .modelSaver(saver)
                    .saveLastModel(true)
                    .build();

    MultiLayerNetwork net = new MultiLayerNetwork(conf);
    net.init();

    EarlyStoppingTrainer t = new EarlyStoppingTrainer(esConf, net, train);
    EarlyStoppingResult<MultiLayerNetwork> result = t.fit();

    Map<Integer,Double> map = result.getScoreVsEpoch();
    for( int i=1; i<map.size(); i++ ){
        if(i == map.size() - 1){
            assertTrue(map.get(i) <+ map.get(i-1));
        } else {
            assertTrue(map.get(i) > map.get(i-1));
        }
    }
}
 
Example #13
Source File: TestEarlyStopping.java    From deeplearning4j with Apache License 2.0 4 votes vote down vote up
@Test
public void testVAEScoreFunctionReconstructionProbSimple() throws Exception {

    for(boolean logProb : new boolean[]{false, true}) {

        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .list()
                .layer(new VariationalAutoencoder.Builder()
                        .nIn(784).nOut(32)
                        .encoderLayerSizes(64)
                        .decoderLayerSizes(64)
                        .reconstructionDistribution(new BernoulliReconstructionDistribution(Activation.SIGMOID))
                        .build())

                .build();

        MultiLayerNetwork net = new MultiLayerNetwork(conf);
        net.init();

        DataSetIterator iter = new MnistDataSetIterator(32, false, 12345);

        List<DataSet> l = new ArrayList<>();
        for (int i = 0; i < 10; i++) {
            DataSet ds = iter.next();
            l.add(new DataSet(ds.getFeatures(), ds.getFeatures()));
        }

        iter = new ExistingDataSetIterator(l);

        EarlyStoppingModelSaver<MultiLayerNetwork> saver = new InMemoryModelSaver<>();
        EarlyStoppingConfiguration<MultiLayerNetwork> esConf =
                new EarlyStoppingConfiguration.Builder<MultiLayerNetwork>()
                        .epochTerminationConditions(new MaxEpochsTerminationCondition(5))
                        .iterationTerminationConditions(
                                new MaxTimeIterationTerminationCondition(1, TimeUnit.MINUTES))
                        .scoreCalculator(new VAEReconProbScoreCalculator(iter, 20, logProb)).modelSaver(saver)
                        .build();

        EarlyStoppingTrainer trainer = new EarlyStoppingTrainer(esConf, net, iter);
        EarlyStoppingResult<MultiLayerNetwork> result = trainer.pretrain();

        assertNotNull(result.getBestModel());
        assertTrue(result.getBestModelScore() > 0.0);
    }
}
 
Example #14
Source File: TestEarlyStopping.java    From deeplearning4j with Apache License 2.0 4 votes vote down vote up
@Test
public void testVAEScoreFunctionSimple() throws Exception {

    for(Metric metric : new Metric[]{Metric.MSE,
            Metric.MAE}) {
        log.info("Metric: " + metric);

        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .list()
                .layer(new VariationalAutoencoder.Builder()
                        .nIn(784).nOut(32)
                        .encoderLayerSizes(64)
                        .decoderLayerSizes(64)
                        .build())

                .build();

        MultiLayerNetwork net = new MultiLayerNetwork(conf);
        net.init();

        DataSetIterator iter = new MnistDataSetIterator(32, false, 12345);

        List<DataSet> l = new ArrayList<>();
        for( int i=0; i<10; i++ ){
            DataSet ds = iter.next();
            l.add(new DataSet(ds.getFeatures(), ds.getFeatures()));
        }

        iter = new ExistingDataSetIterator(l);

        EarlyStoppingModelSaver<MultiLayerNetwork> saver = new InMemoryModelSaver<>();
        EarlyStoppingConfiguration<MultiLayerNetwork> esConf =
                new EarlyStoppingConfiguration.Builder<MultiLayerNetwork>()
                        .epochTerminationConditions(new MaxEpochsTerminationCondition(5))
                        .iterationTerminationConditions(
                                new MaxTimeIterationTerminationCondition(1, TimeUnit.MINUTES))
                        .scoreCalculator(new VAEReconErrorScoreCalculator(metric, iter)).modelSaver(saver)
                        .build();

        EarlyStoppingTrainer trainer = new EarlyStoppingTrainer(esConf, net, iter);
        EarlyStoppingResult<MultiLayerNetwork> result = trainer.pretrain();

        assertNotNull(result.getBestModel());
        assertTrue(result.getBestModelScore() > 0.0);
    }
}
 
Example #15
Source File: DigitRecognizerConvolutionalNeuralNetwork.java    From Java-Machine-Learning-for-Computer-Vision with MIT License 4 votes vote down vote up
public void train() throws IOException {

        MnistDataSetIterator mnistTrain = new MnistDataSetIterator(MINI_BATCH_SIZE, true, 12345);

        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .seed(SEED)
                .learningRate(LEARNING_RATE)
                .weightInit(WeightInit.XAVIER)
                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                .updater(Updater.NESTEROVS)
                .list()
                .layer(0, new ConvolutionLayer.Builder(5, 5)
                        .nIn(CHANNELS)
                        .stride(1, 1)
                        .nOut(20)
                        .activation(Activation.IDENTITY)
                        .build())
                .layer(1, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
                        .kernelSize(2, 2)
                        .stride(2, 2)
                        .build())
                .layer(2, new ConvolutionLayer.Builder(5, 5)
                        .nIn(20)
                        .stride(1, 1)
                        .nOut(50)
                        .activation(Activation.IDENTITY)
                        .build())
                .layer(3, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
                        .kernelSize(2, 2)
                        .stride(2, 2)
                        .build())
                .layer(4, new DenseLayer.Builder().activation(Activation.RELU)
                        .nIn(800)
                        .nOut(128).build())
                .layer(5, new DenseLayer.Builder().activation(Activation.RELU)
                        .nIn(128)
                        .nOut(64).build())
                .layer(6, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                        .nOut(OUTPUT)
                        .activation(Activation.SOFTMAX)
                        .build())
                .setInputType(InputType.convolutionalFlat(28, 28, 1))
                .backprop(true).pretrain(false).build();

        EarlyStoppingConfiguration earlyStoppingConfiguration = new EarlyStoppingConfiguration.Builder()
                .epochTerminationConditions(new MaxEpochsTerminationCondition(MAX_EPOCHS))
                .scoreCalculator(new AccuracyCalculator(new MnistDataSetIterator(MINI_BATCH_SIZE, false, 12345)))
                .evaluateEveryNEpochs(1)
                .modelSaver(new LocalFileModelSaver(OUT_DIR))
                .build();

        EarlyStoppingTrainer trainer = new EarlyStoppingTrainer(earlyStoppingConfiguration, conf, mnistTrain);

        EarlyStoppingResult<MultiLayerNetwork> result = trainer.fit();

        log.info("Termination reason: " + result.getTerminationReason());
        log.info("Termination details: " + result.getTerminationDetails());
        log.info("Total epochs: " + result.getTotalEpochs());
        log.info("Best epoch number: " + result.getBestModelEpoch());
        log.info("Score at best epoch: " + result.getBestModelScore());
    }
 
Example #16
Source File: TestEarlyStopping.java    From deeplearning4j with Apache License 2.0 4 votes vote down vote up
@Test
public void testRegressionScoreFunctionSimple() throws Exception {

    for(Metric metric : new Metric[]{Metric.MSE,
            Metric.MAE}) {
        log.info("Metric: " + metric);

        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .list()
                .layer(new DenseLayer.Builder().nIn(784).nOut(32).build())
                .layer(new OutputLayer.Builder().nIn(32).nOut(784).activation(Activation.SIGMOID).lossFunction(LossFunctions.LossFunction.MSE).build())
                .build();

        MultiLayerNetwork net = new MultiLayerNetwork(conf);
        net.init();

        DataSetIterator iter = new MnistDataSetIterator(32, false, 12345);

        List<DataSet> l = new ArrayList<>();
        for( int i=0; i<10; i++ ){
            DataSet ds = iter.next();
            l.add(new DataSet(ds.getFeatures(), ds.getFeatures()));
        }

        iter = new ExistingDataSetIterator(l);

        EarlyStoppingModelSaver<MultiLayerNetwork> saver = new InMemoryModelSaver<>();
        EarlyStoppingConfiguration<MultiLayerNetwork> esConf =
                new EarlyStoppingConfiguration.Builder<MultiLayerNetwork>()
                        .epochTerminationConditions(new MaxEpochsTerminationCondition(5))
                        .iterationTerminationConditions(
                                new MaxTimeIterationTerminationCondition(1, TimeUnit.MINUTES))
                        .scoreCalculator(new RegressionScoreCalculator(metric, iter)).modelSaver(saver)
                        .build();

        EarlyStoppingTrainer trainer = new EarlyStoppingTrainer(esConf, net, iter);
        EarlyStoppingResult<MultiLayerNetwork> result = trainer.fit();

        assertNotNull(result.getBestModel());
        assertTrue(result.getBestModelScore() > 0.0);
    }
}
 
Example #17
Source File: TestEarlyStopping.java    From deeplearning4j with Apache License 2.0 4 votes vote down vote up
@Test
public void testTimeTermination() {
    //test termination after max time

    Nd4j.getRandom().setSeed(12345);
    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345)
                    .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                    .updater(new Sgd(1e-6)).weightInit(WeightInit.XAVIER).list()
                    .layer(0, new OutputLayer.Builder().nIn(4).nOut(3)
                                .activation(Activation.SOFTMAX)
                                .lossFunction(LossFunctions.LossFunction.MCXENT).build())
                    .build();
    MultiLayerNetwork net = new MultiLayerNetwork(conf);
    net.setListeners(new ScoreIterationListener(1));

    DataSetIterator irisIter = new IrisDataSetIterator(150, 150);

    EarlyStoppingModelSaver<MultiLayerNetwork> saver = new InMemoryModelSaver<>();
    EarlyStoppingConfiguration<MultiLayerNetwork> esConf =
                    new EarlyStoppingConfiguration.Builder<MultiLayerNetwork>()
                                    .epochTerminationConditions(new MaxEpochsTerminationCondition(10000))
                                    .iterationTerminationConditions(
                                                    new MaxTimeIterationTerminationCondition(3, TimeUnit.SECONDS),
                                                    new MaxScoreIterationTerminationCondition(50)) //Initial score is ~8
                                    .scoreCalculator(new DataSetLossCalculator(irisIter, true))
                                    .modelSaver(saver).build();

    IEarlyStoppingTrainer<MultiLayerNetwork> trainer = new EarlyStoppingTrainer(esConf, net, irisIter);
    long startTime = System.currentTimeMillis();
    EarlyStoppingResult result = trainer.fit();
    long endTime = System.currentTimeMillis();
    int durationSeconds = (int) (endTime - startTime) / 1000;

    assertTrue(durationSeconds >= 3);
    assertTrue(durationSeconds <= 12);

    assertEquals(EarlyStoppingResult.TerminationReason.IterationTerminationCondition,
                    result.getTerminationReason());
    String expDetails = new MaxTimeIterationTerminationCondition(3, TimeUnit.SECONDS).toString();
    assertEquals(expDetails, result.getTerminationDetails());
}
 
Example #18
Source File: Trainer.java    From dl4j-quickstart with Apache License 2.0 4 votes vote down vote up
public static void main(String... args) throws java.io.IOException {
    // create the data iterators for emnist
    DataSetIterator emnistTrain = new EmnistDataSetIterator(emnistSet, batchSize, true);
    DataSetIterator emnistTest = new EmnistDataSetIterator(emnistSet, batchSize, false);

    int outputNum = EmnistDataSetIterator.numLabels(emnistSet);

    // network configuration (not yet initialized)
    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
            .seed(rngSeed)
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
            .updater(new Adam())
            .l2(1e-4)
            .list()
            .layer(new DenseLayer.Builder()
                    .nIn(numRows * numColumns) // Number of input datapoints.
                    .nOut(1000) // Number of output datapoints.
                    .activation(Activation.RELU) // Activation function.
                    .weightInit(WeightInit.XAVIER) // Weight initialization.
                    .build())
            .layer(new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                    .nIn(1000)
                    .nOut(outputNum)
                    .activation(Activation.SOFTMAX)
                    .weightInit(WeightInit.XAVIER)
                    .build())
            .pretrain(false).backprop(true)
            .build();

    // create the MLN
    MultiLayerNetwork network = new MultiLayerNetwork(conf);
    network.init();

    // pass a training listener that reports score every N iterations
    network.addListeners(new ScoreIterationListener(reportingInterval));

    // here we set up an early stopping trainer
    // early stopping is useful when your trainer runs for
    // a long time or you need to programmatically stop training
    EarlyStoppingConfiguration esConf = new EarlyStoppingConfiguration.Builder()
            .epochTerminationConditions(new MaxEpochsTerminationCondition(5))
            .iterationTerminationConditions(new MaxTimeIterationTerminationCondition(20, TimeUnit.MINUTES))
            .scoreCalculator(new DataSetLossCalculator(emnistTest, true))
            .evaluateEveryNEpochs(1)
            .modelSaver(new LocalFileModelSaver(System.getProperty("user.dir")))
            .build();

    // training
    EarlyStoppingTrainer trainer = new EarlyStoppingTrainer(esConf, network, emnistTrain);
    EarlyStoppingResult result = trainer.fit();

    // print out early stopping results
    System.out.println("Termination reason: " + result.getTerminationReason());
    System.out.println("Termination details: " + result.getTerminationDetails());
    System.out.println("Total epochs: " + result.getTotalEpochs());
    System.out.println("Best epoch number: " + result.getBestModelEpoch());
    System.out.println("Score at best epoch: " + result.getBestModelScore());

    // evaluate basic performance
    Evaluation eval = network.evaluate(emnistTest);
    System.out.println(eval.accuracy());
    System.out.println(eval.precision());
    System.out.println(eval.recall());

    // evaluate ROC and calculate the Area Under Curve
    ROCMultiClass roc = network.evaluateROCMultiClass(emnistTest);
    System.out.println(roc.calculateAverageAUC());

    // calculate AUC for a single class
    int classIndex = 0;
    System.out.println(roc.calculateAUC(classIndex));

    // optionally, you can print all stats from the evaluations
    System.out.println(eval.stats());
    System.out.println(roc.stats());
}