package org.deeplearning4j;

import org.deeplearning4j.datasets.iterator.impl.EmnistDataSetIterator;
import org.deeplearning4j.earlystopping.EarlyStoppingConfiguration;
import org.deeplearning4j.earlystopping.EarlyStoppingResult;
import org.deeplearning4j.earlystopping.saver.LocalFileModelSaver;
import org.deeplearning4j.earlystopping.scorecalc.DataSetLossCalculator;
import org.deeplearning4j.earlystopping.termination.MaxEpochsTerminationCondition;
import org.deeplearning4j.earlystopping.termination.MaxTimeIterationTerminationCondition;
import org.deeplearning4j.earlystopping.trainer.EarlyStoppingTrainer;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.eval.ROCMultiClass;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.lossfunctions.LossFunctions;

import java.util.concurrent.TimeUnit;

public class Trainer {
    static int batchSize = 16; // how many examples to simultaneously train in the network
    static EmnistDataSetIterator.Set emnistSet = EmnistDataSetIterator.Set.BALANCED;
    static int rngSeed = 123;
    static int numRows = 28;
    static int numColumns = 28;
    static int reportingInterval = 5;

    public static void main(String... args) throws java.io.IOException {
        // create the data iterators for emnist
        DataSetIterator emnistTrain = new EmnistDataSetIterator(emnistSet, batchSize, true);
        DataSetIterator emnistTest = new EmnistDataSetIterator(emnistSet, batchSize, false);

        int outputNum = EmnistDataSetIterator.numLabels(emnistSet);

        // network configuration (not yet initialized)
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .seed(rngSeed)
                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                .updater(new Adam())
                .l2(1e-4)
                .list()
                .layer(new DenseLayer.Builder()
                        .nIn(numRows * numColumns) // Number of input datapoints.
                        .nOut(1000) // Number of output datapoints.
                        .activation(Activation.RELU) // Activation function.
                        .weightInit(WeightInit.XAVIER) // Weight initialization.
                        .build())
                .layer(new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                        .nIn(1000)
                        .nOut(outputNum)
                        .activation(Activation.SOFTMAX)
                        .weightInit(WeightInit.XAVIER)
                        .build())
                .pretrain(false).backprop(true)
                .build();

        // create the MLN
        MultiLayerNetwork network = new MultiLayerNetwork(conf);
        network.init();

        // pass a training listener that reports score every N iterations
        network.addListeners(new ScoreIterationListener(reportingInterval));

        // here we set up an early stopping trainer
        // early stopping is useful when your trainer runs for
        // a long time or you need to programmatically stop training
        EarlyStoppingConfiguration esConf = new EarlyStoppingConfiguration.Builder()
                .epochTerminationConditions(new MaxEpochsTerminationCondition(5))
                .iterationTerminationConditions(new MaxTimeIterationTerminationCondition(20, TimeUnit.MINUTES))
                .scoreCalculator(new DataSetLossCalculator(emnistTest, true))
                .evaluateEveryNEpochs(1)
                .modelSaver(new LocalFileModelSaver(System.getProperty("user.dir")))
                .build();

        // training
        EarlyStoppingTrainer trainer = new EarlyStoppingTrainer(esConf, network, emnistTrain);
        EarlyStoppingResult result = trainer.fit();

        // print out early stopping results
        System.out.println("Termination reason: " + result.getTerminationReason());
        System.out.println("Termination details: " + result.getTerminationDetails());
        System.out.println("Total epochs: " + result.getTotalEpochs());
        System.out.println("Best epoch number: " + result.getBestModelEpoch());
        System.out.println("Score at best epoch: " + result.getBestModelScore());

        // evaluate basic performance
        Evaluation eval = network.evaluate(emnistTest);
        System.out.println(eval.accuracy());
        System.out.println(eval.precision());
        System.out.println(eval.recall());

        // evaluate ROC and calculate the Area Under Curve
        ROCMultiClass roc = network.evaluateROCMultiClass(emnistTest);
        System.out.println(roc.calculateAverageAUC());

        // calculate AUC for a single class
        int classIndex = 0;
        System.out.println(roc.calculateAUC(classIndex));

        // optionally, you can print all stats from the evaluations
        System.out.println(eval.stats());
        System.out.println(roc.stats());
    }
}