package ml;

import org.datavec.api.records.metadata.RecordMetaData;
import org.datavec.api.records.reader.impl.collection.ListStringRecordReader;
import org.datavec.api.split.ListStringSplit;
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.eval.meta.Prediction;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.Updater;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.SplitTestAndTrain;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.preprocessor.NormalizerMinMaxScaler;
import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction;
import org.neo4j.graphdb.Label;
import org.neo4j.graphdb.Node;
import org.neo4j.helpers.collection.MapUtil;
import result.VirtualNode;

import java.util.*;
import java.util.stream.Collectors;

/**
 * DL4J-backed implementation of {@link MLModel}: rows are kept as lists of strings,
 * training fits a small feed-forward classifier (one hidden layer, softmax output),
 * and prediction runs single rows through the trained network.
 *
 * @author mh
 * @since 23.07.17
 */
public class DL4JMLModel extends MLModel<List<String>> {

    private MultiLayerNetwork model;

    public DL4JMLModel(String name, Map<String, String> types, String output, Map<String, Object> config) {
        super(name, types, output, config);
    }
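    // Illustrative example of the row layout asRow produces below (column names and
    // values are made up for illustration): with offsets = {width: 0, height: 1, label: 2}
    // and inputs = {width: 2.0, height: 3.5}, a training row looks like
    // ["2.0", "3.5", "1"] (the label lands at its offset), while a prediction row
    // omits the label entirely: ["2.0", "3.5"].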
    @Override
    protected List<String> asRow(Map<String, Object> inputs, Object output) {
        // Builds one CSV-like row. Note: this relies on the map's iteration order matching
        // the column offsets (i.e. an order-preserving map); otherwise add(index, ...)
        // can throw IndexOutOfBoundsException on the still-partial list.
        List<String> row = new ArrayList<>(inputs.size() + (output == null ? 0 : 1));
        for (String k : inputs.keySet()) {
            row.add(offsets.get(k), inputs.get(k).toString());
        }
        if (output != null) {
            row.add(offsets.get(this.output), output.toString());
        }
        return row;
    }

    @Override
    protected Object doPredict(List<String> line) {
        try {
            // Wrap the single row in the same DataVec pipeline used for training.
            ListStringSplit input = new ListStringSplit(Collections.singletonList(line));
            ListStringRecordReader rr = new ListStringRecordReader();
            rr.initialize(input);
            DataSetIterator iterator = new RecordReaderDataSetIterator(rr, 1);
            DataSet ds = iterator.next();
            // Note: unlike training, no normalizer is applied here; for consistent results
            // the scaler fitted in doTrain would have to be retained and applied to this row.
            INDArray prediction = model.output(ds.getFeatures());
            DataType outputType = types.get(this.output);
            switch (outputType) {
                case _float:
                    return prediction.getDouble(0);
                case _class: {
                    // Manual argmax over the class probabilities.
                    // TODO: derive the class count from the trained model instead of hard-coding two.
                    int numClasses = 2;
                    double max = 0;
                    int maxIndex = -1;
                    for (int i = 0; i < numClasses; i++) {
                        if (prediction.getDouble(i) > max) {
                            maxIndex = i;
                            max = prediction.getDouble(i);
                        }
                    }
                    return maxIndex;
                }
                default:
                    throw new IllegalArgumentException("Output type not yet supported " + outputType);
            }
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }
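    // Sketch of an alternative to the manual argmax loop in doPredict, using ND4J's
    // built-in reduction (Nd4j.argMax from org.nd4j.linalg.factory.Nd4j, which is on the
    // classpath whenever INDArray is):
    //
    //   int maxIndex = Nd4j.argMax(prediction, 1).getInt(0);
    //
    // This would also remove the hard-coded two-class assumption, since it scans the whole row.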
    @Override
    protected void doTrain() {
        try {
            long seed = config.seed.get();
            double learningRate = config.learningRate.get();
            int nEpochs = config.epochs.get();
            int numOutputs = 1;
            int numInputs = types.size() - numOutputs;
            int outputOffset = offsets.get(output); // column index of the label
            int numHiddenNodes = config.hidden.get();
            double trainPercent = config.trainPercent.get();
            int batchSize = rows.size(); // full dataset in a single batch

            // Collect the distinct values of every categorical column to determine class counts.
            Map<String, Set<String>> classes = new HashMap<>();
            types.entrySet().stream()
                    .filter(e -> e.getValue() == DataType._class)
                    .map(e -> new AbstractMap.SimpleEntry<>(e.getKey(), offsets.get(e.getKey())))
                    .forEach(e -> classes.put(e.getKey(),
                            rows.parallelStream().map(r -> r.get(e.getValue())).distinct().collect(Collectors.toSet())));
            int numberOfClasses = classes.get(output).size();
            System.out.println("labels = " + classes);

            ListStringSplit input = new ListStringSplit(rows);
            ListStringRecordReader rr = new ListStringRecordReader();
            rr.initialize(input);
            RecordReaderDataSetIterator iterator = new RecordReaderDataSetIterator(rr, batchSize, outputOffset, numberOfClasses);
            // Instruct the iterator to collect metadata and store it in the DataSet objects,
            // so evaluation errors can later be traced back to their source rows.
            iterator.setCollectMetaData(true);

            DataSet allData = iterator.next();
            allData.shuffle(seed);
            SplitTestAndTrain testAndTrain = allData.splitTestAndTrain(trainPercent); // e.g. 0.65 keeps 65% for training
            DataSet trainingData = testAndTrain.getTrain();
            DataSet testData = testAndTrain.getTest();

            // Scale features (and labels) to [0,1]. The min/max statistics are collected
            // from the training data only, then applied to both splits; fit() does not
            // modify the data itself. (NormalizerStandardize would be an alternative.)
            NormalizerMinMaxScaler normalizer = new NormalizerMinMaxScaler();
            normalizer.fitLabel(true);
            normalizer.fit(trainingData);
            normalizer.transform(trainingData);
            normalizer.transform(testData);

            // Per-example metadata for the test set; used below to report which rows were misclassified.
            List<RecordMetaData> testMetaData = testData.getExampleMetaData(RecordMetaData.class);

            MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                    .seed(seed)
                    .iterations(1)
                    .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                    .learningRate(learningRate)
                    .updater(Updater.NESTEROVS).momentum(0.9)
                    .list()
                    .layer(0, new DenseLayer.Builder().nIn(numInputs).nOut(numHiddenNodes)
                            .weightInit(WeightInit.XAVIER)
                            .activation(Activation.RELU)
                            .build())
                    .layer(1, new OutputLayer.Builder(LossFunction.NEGATIVELOGLIKELIHOOD)
                            .weightInit(WeightInit.XAVIER)
                            .activation(Activation.SOFTMAX)
                            .nIn(numHiddenNodes).nOut(numberOfClasses)
                            .build())
                    .pretrain(false).backprop(true)
                    .build();

            MultiLayerNetwork model = new MultiLayerNetwork(conf);
            model.init();
            model.setListeners(new ScoreIterationListener(10)); // print the score every 10 parameter updates

            for (int n = 0; n < nEpochs; n++) {
                model.fit(trainingData);
            }

            System.out.println("Evaluate model....");
            INDArray output = model.output(testData.getFeatureMatrix(), false);
            Evaluation eval = new Evaluation(numberOfClasses);
            eval.eval(testData.getLabels(), output, testMetaData); // pass the test-set metadata through

            List<Prediction> predictionErrors = eval.getPredictionErrors();
            System.out.println("\n\n+++++ Prediction Errors +++++");
            for (Prediction p : predictionErrors) {
                System.out.printf("Predicted class: %d, Actual class: %d\t%s%n",
                        p.getPredictedClass(), p.getActualClass(), p.getRecordMetaData(RecordMetaData.class));
            }
            System.out.println(eval.stats());

            this.model = model;
            this.state = State.ready;
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }
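    // A possible extension (sketch only, not wired into this class): persist the trained
    // network with DL4J's ModelSerializer so it survives a restart, e.g.
    //
    //   ModelSerializer.writeModel(model, new File("model.zip"), true);
    //   MultiLayerNetwork restored = ModelSerializer.restoreMultiLayerNetwork(new File("model.zip"));
    //
    // (org.deeplearning4j.util.ModelSerializer and java.io.File would need to be imported;
    // the file name is illustrative.)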
    @Override
    List<Node> show() {
        if (state != State.ready) throw new IllegalStateException("Model not trained yet");
        List<Node> result = new ArrayList<>();
        for (Layer layer : model.getLayers()) {
            Node node = node("Layer",
                    "type", layer.type().name(),
                    "index", layer.getIndex(),
                    "pretrainLayer", layer.isPretrainLayer(),
                    "miniBatchSize", layer.getInputMiniBatchSize(),
                    "numParams", layer.numParams());
            // The hyperparameters live on the layer *configuration*, not on the runtime
            // layer instance, so look them up via conf().getLayer(); a plain
            // instanceof check on the runtime layer would never match the conf class.
            org.deeplearning4j.nn.conf.layers.Layer confLayer = layer.conf().getLayer();
            if (confLayer instanceof DenseLayer) {
                DenseLayer dl = (DenseLayer) confLayer;
                node.addLabel(Label.label("DenseLayer"));
                node.setProperty("activation", dl.getActivationFn().toString()); // todo parameters
                node.setProperty("biasInit", dl.getBiasInit());
                node.setProperty("biasLearningRate", dl.getBiasLearningRate());
                node.setProperty("l1", dl.getL1());
                node.setProperty("l1Bias", dl.getL1Bias());
                node.setProperty("l2", dl.getL2());
                node.setProperty("l2Bias", dl.getL2Bias());
                // dist is only set for distribution-based weight init, so guard against null
                if (dl.getDist() != null) node.setProperty("distribution", dl.getDist().toString());
                node.setProperty("in", dl.getNIn());
                node.setProperty("out", dl.getNOut());
            }
            result.add(node);
        }
        return result;
    }

    private Node node(String label, Object... keyValues) {
        return new VirtualNode(new Label[]{Label.label(label)}, MapUtil.map(keyValues), null);
    }
}
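// End-to-end usage sketch (hypothetical: how rows reach the model and how training is
// triggered is defined by the MLModel superclass, which is not shown here; the config
// keys below are assumptions inferred from the config.* accessors used in doTrain):
//
//   Map<String, String> types = MapUtil.stringMap("width", "float", "height", "float", "label", "class");
//   Map<String, Object> config = MapUtil.map("seed", 123L, "learningRate", 0.01,
//           "epochs", 30, "hidden", 20, "trainPercent", 0.65);
//   DL4JMLModel model = new DL4JMLModel("sizes", types, "label", config);
//   // ... add training rows through the MLModel API, then train and query:
//   // model.doTrain();  model.doPredict(Arrays.asList("2.0", "3.5"));  model.show();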