package packt.dl4jexamples; import java.io.File; import java.io.IOException; import org.deeplearning4j.datasets.fetchers.MnistDataFetcher; import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator; import org.deeplearning4j.nn.api.OptimizationAlgorithm; import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.layers.OutputLayer; import org.deeplearning4j.nn.conf.layers.RBM; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.optimize.api.IterationListener; import org.deeplearning4j.optimize.listeners.ScoreIterationListener; import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.nd4j.linalg.lossfunctions.LossFunctions; import java.util.Collections; import org.deeplearning4j.util.ModelSerializer; public class DeepAutoEncoderExample { private MultiLayerNetwork model; private File modelFile; private DataSetIterator iterator; private final int numberOfRows = 28; private final int numberOfColumns = 28; public DeepAutoEncoderExample() { try { int seed = 123; int numberOfIterations = 1; iterator = new MnistDataSetIterator(1000, MnistDataFetcher.NUM_EXAMPLES, true); MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() .seed(seed) .iterations(numberOfIterations) .optimizationAlgo(OptimizationAlgorithm.LINE_GRADIENT_DESCENT) .list() .layer(0, new RBM.Builder().nIn(numberOfRows * numberOfColumns) .nOut(1000) .lossFunction(LossFunctions.LossFunction.RMSE_XENT).build()) .layer(1, new RBM.Builder().nIn(1000).nOut(500) .lossFunction(LossFunctions.LossFunction.RMSE_XENT).build()) .layer(2, new RBM.Builder().nIn(500).nOut(250) .lossFunction(LossFunctions.LossFunction.RMSE_XENT).build()) .layer(3, new RBM.Builder().nIn(250).nOut(100) .lossFunction(LossFunctions.LossFunction.RMSE_XENT).build()) .layer(4, new RBM.Builder().nIn(100).nOut(30) .lossFunction(LossFunctions.LossFunction.RMSE_XENT).build()) //encoding stops .layer(5, new RBM.Builder().nIn(30).nOut(100) .lossFunction(LossFunctions.LossFunction.RMSE_XENT).build()) //decoding starts .layer(6, new RBM.Builder().nIn(100).nOut(250) .lossFunction(LossFunctions.LossFunction.RMSE_XENT).build()) .layer(7, new RBM.Builder().nIn(250).nOut(500) .lossFunction(LossFunctions.LossFunction.RMSE_XENT).build()) .layer(8, new RBM.Builder().nIn(500).nOut(1000) .lossFunction(LossFunctions.LossFunction.RMSE_XENT).build()) .layer(9, new OutputLayer.Builder( LossFunctions.LossFunction.RMSE_XENT).nIn(1000) .nOut(numberOfRows * numberOfColumns).build()) .pretrain(true).backprop(true) .build(); model = new MultiLayerNetwork(conf); model.init(); model.setListeners(Collections.singletonList( (IterationListener) new ScoreIterationListener())); while (iterator.hasNext()) { DataSet dataSet = iterator.next(); model.fit(new DataSet(dataSet.getFeatureMatrix(), dataSet.getFeatureMatrix())); } modelFile = new File("savedModel"); ModelSerializer.writeModel(model, modelFile, true); } catch (IOException ex) { ex.printStackTrace(); } } public void retrieveModel() { try { modelFile = new File("savedModel"); MultiLayerNetwork model = ModelSerializer.restoreMultiLayerNetwork(modelFile); } catch (IOException ex) { ex.printStackTrace(); } } public static void main(String[] args) throws Exception { new DeepAutoEncoderExample(); } }