package com.javadeeplearningcookbook.examples; import org.datavec.api.records.reader.RecordReader; import org.datavec.api.records.reader.impl.csv.CSVRecordReader; import org.datavec.api.records.reader.impl.transform.TransformProcessRecordReader; import org.datavec.api.split.FileSplit; import org.datavec.api.transform.TransformProcess; import org.datavec.api.transform.schema.Schema; import org.datavec.api.util.ndarray.RecordConverter; import org.deeplearning4j.api.storage.StatsStorage; import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator; import org.deeplearning4j.datasets.iterator.DataSetIteratorSplitter; import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.layers.DenseLayer; import org.deeplearning4j.nn.conf.layers.OutputLayer; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.optimize.listeners.ScoreIterationListener; import org.deeplearning4j.ui.api.UIServer; import org.deeplearning4j.ui.stats.StatsListener; import org.deeplearning4j.ui.storage.InMemoryStatsStorage; import org.deeplearning4j.util.ModelSerializer; import org.nd4j.evaluation.classification.Evaluation; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization; import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.io.ClassPathResource; import org.nd4j.linalg.learning.config.Adam; import org.nd4j.linalg.lossfunctions.impl.LossMCXENT; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import java.io.File; import java.io.IOException; import java.util.Arrays; public class CustomerRetentionPredictionExample { private static final Logger log = LoggerFactory.getLogger("com.javadeeplearningcookbook.examples.CustomerLossPrediction.class"); private static Schema generateSchema(){ final Schema schema = new Schema.Builder() .addColumnString("RowNumber") .addColumnInteger("CustomerId") .addColumnString("Surname") .addColumnInteger("CreditScore") .addColumnCategorical("Geography", Arrays.asList("France","Germany","Spain")) .addColumnCategorical("Gender", Arrays.asList("Male","Female")) .addColumnsInteger("Age", "Tenure") .addColumnDouble("Balance") .addColumnsInteger("NumOfProducts","HasCrCard","IsActiveMember") .addColumnDouble("EstimatedSalary") .addColumnInteger("Exited") .build(); return schema; } private static RecordReader applyTransform(RecordReader recordReader, Schema schema){ final TransformProcess transformProcess = new TransformProcess.Builder(schema) .removeColumns("RowNumber","CustomerId","Surname") .categoricalToInteger("Gender") .categoricalToOneHot("Geography") .removeColumns("Geography[France]") .build(); final TransformProcessRecordReader transformProcessRecordReader = new TransformProcessRecordReader(recordReader,transformProcess); return transformProcessRecordReader; } private static RecordReader generateReader(File file) throws IOException, InterruptedException { final RecordReader recordReader = new CSVRecordReader(1,','); recordReader.initialize(new FileSplit(file)); final RecordReader transformProcessRecordReader=applyTransform(recordReader,generateSchema()); return transformProcessRecordReader; } public static void main(String[] args) throws IOException, InterruptedException { final int labelIndex=11; final int batchSize=8; final int numClasses=2; final INDArray weightsArray = Nd4j.create(new double[]{0.57, 0.75}); final RecordReader recordReader = generateReader(new ClassPathResource("Churn_Modelling.csv").getFile()); final DataSetIterator dataSetIterator = new RecordReaderDataSetIterator.Builder(recordReader,batchSize) .classification(labelIndex,numClasses) .build(); final DataNormalization dataNormalization = new NormalizerStandardize(); dataNormalization.fit(dataSetIterator); dataSetIterator.setPreProcessor(dataNormalization); final DataSetIteratorSplitter dataSetIteratorSplitter = new DataSetIteratorSplitter(dataSetIterator,1250,0.8); log.info("Building Model------------------->>>>>>>>>"); final MultiLayerConfiguration configuration = new NeuralNetConfiguration.Builder() .weightInit(WeightInit.RELU_UNIFORM) .updater(new Adam(0.015D)) .list() .layer(new DenseLayer.Builder().nIn(11).nOut(6).activation(Activation.RELU).dropOut(0.9).build()) .layer(new DenseLayer.Builder().nIn(6).nOut(6).activation(Activation.RELU).dropOut(0.9).build()) .layer(new DenseLayer.Builder().nIn(6).nOut(4).activation(Activation.RELU).dropOut(0.9).build()) .layer(new OutputLayer.Builder(new LossMCXENT(weightsArray)).nIn(4).nOut(2).activation(Activation.SOFTMAX).build()) .build(); final UIServer uiServer = UIServer.getInstance(); final StatsStorage statsStorage = new InMemoryStatsStorage(); final MultiLayerNetwork multiLayerNetwork = new MultiLayerNetwork(configuration); multiLayerNetwork.init(); multiLayerNetwork.setListeners(new ScoreIterationListener(100), new StatsListener(statsStorage)); uiServer.attach(statsStorage); multiLayerNetwork.fit(dataSetIteratorSplitter.getTrainIterator(),100); final Evaluation evaluation = multiLayerNetwork.evaluate(dataSetIteratorSplitter.getTestIterator(),Arrays.asList("0","1")); System.out.println(evaluation.stats()); final File file = new File("model.zip"); ModelSerializer.writeModel(multiLayerNetwork,file,true); ModelSerializer.addNormalizerToModel(file,dataNormalization); } }