org.deeplearning4j.util.ModelSerializer Java Examples

The following examples show how to use org.deeplearning4j.util.ModelSerializer. Each example is taken from an open-source project; the line above it names the source file, the originating project, its license, and the number of votes the example received.
Example #1
Source File: TestUtils.java — from deeplearning4j (Apache License 2.0) — 6 votes
/**
 * Round-trips a MultiLayerNetwork (including updater state) through ModelSerializer in memory,
 * asserts that configuration and parameters are unchanged, and checks that the configuration
 * also survives plain Java serialization.
 *
 * @param net the network to test
 * @return the network restored from the serialized bytes
 */
public static MultiLayerNetwork testModelSerialization(MultiLayerNetwork net){
    final MultiLayerNetwork roundTripped;
    try {
        ByteArrayOutputStream sink = new ByteArrayOutputStream();
        ModelSerializer.writeModel(net, sink, true);

        ByteArrayInputStream source = new ByteArrayInputStream(sink.toByteArray());
        roundTripped = ModelSerializer.restoreMultiLayerNetwork(source, true);

        assertEquals(net.getLayerWiseConfigurations(), roundTripped.getLayerWiseConfigurations());
        assertEquals(net.params(), roundTripped.params());
    } catch (IOException ioe) {
        // Streams are purely in-memory, so an IOException is not expected here
        throw new RuntimeException(ioe);
    }

    // MultiLayerConfiguration must be Java-serializable (required by Spark etc)
    serializeDeserializeJava(net.getLayerWiseConfigurations());

    return roundTripped;
}
 
Example #2
Source File: TestUtils.java — from deeplearning4j (Apache License 2.0) — 6 votes
/**
 * Round-trips a ComputationGraph (including updater state) through ModelSerializer in memory,
 * asserts that configuration and parameters are unchanged, and checks that the configuration
 * also survives plain Java serialization.
 *
 * @param net the graph to test
 * @return the graph restored from the serialized bytes
 */
public static ComputationGraph testModelSerialization(ComputationGraph net){
    ComputationGraph copy;
    try {
        ByteArrayOutputStream out = new ByteArrayOutputStream();
        ModelSerializer.writeModel(net, out, true);
        ByteArrayInputStream in = new ByteArrayInputStream(out.toByteArray());
        copy = ModelSerializer.restoreComputationGraph(in, true);

        assertEquals(net.getConfiguration(), copy.getConfiguration());
        assertEquals(net.params(), copy.params());
    } catch (IOException ex) {
        // In-memory streams only; this is not expected to happen
        throw new RuntimeException(ex);
    }

    // ComputationGraphConfiguration must be Java-serializable (required by Spark etc)
    serializeDeserializeJava(net.getConfiguration());
    return copy;
}
 
Example #3
Source File: YOLOModel.java — from java-ml-projects (Apache License 2.0) — 6 votes
/**
 * Loads the YOLO network and prepares image loading.
 * Without a modelPath the pretrained YOLO2 graph with COCO classes is used; otherwise the graph
 * at modelPath is restored, validated to end in a Yolo2OutputLayer, and the comma-separated
 * {@code classes} string supplies the class names.
 */
public void init() {
	try {
		if (modelPath != null) {
			yoloModel = ModelSerializer.restoreComputationGraph(modelPath);
			boolean hasYoloOutput = yoloModel.getOutputLayer(0) instanceof Yolo2OutputLayer;
			if (!hasYoloOutput) {
				throw new Error("The model is not an YOLO model (output layer is not Yolo2OutputLayer)");
			}
			setModelClasses(classes.split("\\,"));
		} else {
			yoloModel = (ComputationGraph) YOLO2.builder().build().initPretrained();
			setModelClasses(COCO_CLASSES);
		}

		ColorConversionTransform bgrToRgb = new ColorConversionTransform(COLOR_BGR2RGB);
		imageLoader = new NativeImageLoader(getInputWidth(), getInputHeight(), getInputChannels(), bgrToRgb);
		loadInputParameters();
	} catch (IOException loadFailure) {
		throw new Error("Not able to init the model", loadFailure);
	}
}
 
Example #4
Source File: DLModel.java — from java-ml-projects (Apache License 2.0) — 6 votes
/**
 * Loads a serialized DL4J model, first as a ComputationGraph and, if that fails, as a
 * MultiLayerNetwork.
 *
 * @param file the serialized model file
 * @return a DLModel wrapping whichever model type loaded successfully
 * @throws Exception the ComputationGraph load failure when both attempts fail; the
 *                   MultiLayerNetwork load failure is attached to it as a suppressed exception
 */
public static DLModel fromFile(File file) throws Exception {
	Model model;
	try {
		System.out.println("Trying to load file as computation graph: " + file);
		model = ModelSerializer.restoreComputationGraph(file);
		System.out.println("Loaded Computation Graph.");
	} catch (Exception e) {
		try {
			System.out.println("Failed to load computation graph. Trying to load model.");
			model = ModelSerializer.restoreMultiLayerNetwork(file);
			System.out.println("Loaded Multilayernetwork");
		} catch (Exception e1) {
			System.out.println("Give up trying to load file: " + file);
			// Keep the second failure visible instead of discarding it
			if (e1 != e) {
				e.addSuppressed(e1);
			}
			throw e;
		}
	}
	return new DLModel(model);
}
 
Example #5
Source File: TestUtils.java — from deeplearning4j (Apache License 2.0) — 6 votes
/**
 * Serializes the network (with updater) to a byte array, restores it, and asserts that the
 * restored configuration and parameters match; also Java-serializes the configuration.
 *
 * @param net the network under test
 * @return the restored copy of {@code net}
 */
public static MultiLayerNetwork testModelSerialization(MultiLayerNetwork net){
    MultiLayerNetwork result;
    try {
        ByteArrayOutputStream buffer = new ByteArrayOutputStream();
        ModelSerializer.writeModel(net, buffer, true);
        byte[] serialized = buffer.toByteArray();

        result = ModelSerializer.restoreMultiLayerNetwork(new ByteArrayInputStream(serialized), true);

        assertEquals(net.getLayerWiseConfigurations(), result.getLayerWiseConfigurations());
        assertEquals(net.params(), result.params());
    } catch (IOException cause) {
        // Not expected: no real I/O takes place on byte-array streams
        throw new RuntimeException(cause);
    }

    // Spark and similar frameworks require the configuration to be Java-serializable
    MultiLayerConfiguration layerConf = net.getLayerWiseConfigurations();
    serializeDeserializeJava(layerConf);

    return result;
}
 
Example #6
Source File: InferenceExecutionerFactoryTests.java — from konduit-serving (Apache License 2.0) — 6 votes
/**
 * Saves a trained network as a ComputationGraph zip and checks that the DL4J executioner factory
 * produces a MultiComputationGraphInferenceExecutioner with a model and a model loader.
 */
@Test
public void testComputationGraph() throws Exception {
    Pair<MultiLayerNetwork, DataNormalization> trained = TrainUtils.getTrainedNetwork();
    ComputationGraph graphToSave = trained.getLeft().toComputationGraph();

    File modelZip = new File(testDir.newFolder(), "dl4j_cg_model.zip");
    modelZip.deleteOnExit();
    ModelSerializer.writeModel(graphToSave, modelZip, true);

    ModelStep step = Dl4jStep.builder()
            .inputName("default")
            .outputName("output")
            .path(modelZip.getAbsolutePath())
            .build();

    InitializedInferenceExecutionerConfig config = new Dl4jInferenceExecutionerFactory().create(step);
    MultiComputationGraphInferenceExecutioner executioner =
            (MultiComputationGraphInferenceExecutioner) config.getInferenceExecutioner();

    assertNotNull(executioner);
    assertNotNull(executioner.model());
    assertNotNull(executioner.modelLoader());
}
 
Example #7
Source File: InferenceExecutionerFactoryTests.java — from konduit-serving (Apache License 2.0) — 6 votes
/**
 * Saves a trained MultiLayerNetwork zip and checks that the DL4J executioner factory produces a
 * MultiLayerNetworkInferenceExecutioner with a model and a model loader.
 */
@Test
public void testMultiLayerNetwork() throws Exception {
    Pair<MultiLayerNetwork, DataNormalization> trained = TrainUtils.getTrainedNetwork();
    MultiLayerNetwork networkToSave = trained.getLeft();

    File modelZip = new File(testDir.newFolder(), "dl4j_mln_model.zip");
    modelZip.deleteOnExit();
    ModelSerializer.writeModel(networkToSave, modelZip, true);

    ModelStep step = Dl4jStep.builder()
            .inputName("default")
            .outputName("output")
            .path(modelZip.getAbsolutePath())
            .build();

    InitializedInferenceExecutionerConfig config = new Dl4jInferenceExecutionerFactory().create(step);
    MultiLayerNetworkInferenceExecutioner executioner =
            (MultiLayerNetworkInferenceExecutioner) config.getInferenceExecutioner();

    assertNotNull(executioner);
    assertNotNull(executioner.model());
    assertNotNull(executioner.modelLoader());
}
 
Example #8
Source File: TestUtils.java — from deeplearning4j (Apache License 2.0) — 6 votes
/**
 * Checks that a MultiLayerNetwork survives a ModelSerializer round trip (with updater state)
 * and that its configuration is Java-serializable.
 *
 * @param net the network under test
 * @return the network deserialized from the written bytes
 */
public static MultiLayerNetwork testModelSerialization(MultiLayerNetwork net){
    MultiLayerNetwork reloaded;
    try {
        ByteArrayOutputStream written = new ByteArrayOutputStream();
        ModelSerializer.writeModel(net, written, true);

        ByteArrayInputStream toRead = new ByteArrayInputStream(written.toByteArray());
        reloaded = ModelSerializer.restoreMultiLayerNetwork(toRead, true);

        assertEquals(net.getLayerWiseConfigurations(), reloaded.getLayerWiseConfigurations());
        assertEquals(net.params(), reloaded.params());
    } catch (IOException unexpected) {
        throw new RuntimeException(unexpected);
    }

    // Configuration must also round-trip through Java serialization (Spark requirement)
    serializeDeserializeJava(net.getLayerWiseConfigurations());
    return reloaded;
}
 
Example #9
Source File: TestUtils.java — from deeplearning4j (Apache License 2.0) — 6 votes
/**
 * Checks that a ComputationGraph survives a ModelSerializer round trip (with updater state)
 * and that its configuration is Java-serializable.
 *
 * @param net the graph under test
 * @return the graph deserialized from the written bytes
 */
public static ComputationGraph testModelSerialization(ComputationGraph net){
    ComputationGraph reloaded;
    try {
        ByteArrayOutputStream written = new ByteArrayOutputStream();
        ModelSerializer.writeModel(net, written, true);

        ByteArrayInputStream toRead = new ByteArrayInputStream(written.toByteArray());
        reloaded = ModelSerializer.restoreComputationGraph(toRead, true);

        assertEquals(net.getConfiguration(), reloaded.getConfiguration());
        assertEquals(net.params(), reloaded.params());
    } catch (IOException unexpected) {
        throw new RuntimeException(unexpected);
    }

    // Configuration must also round-trip through Java serialization (Spark requirement)
    ComputationGraphConfiguration graphConf = net.getConfiguration();
    serializeDeserializeJava(graphConf);
    return reloaded;
}
 
Example #10
Source File: Vasttext.java — from scava (Eclipse Public License 2.0) — 6 votes
/**
 * Saves the trained network (without updater state) to {@code file}, then adds a
 * "vasttext.config" entry holding the settings needed to rebuild this instance.
 *
 * @param file destination model file
 * @throws IOException if writing the model or the config entry fails
 * @throws UnsupportedOperationException if no model has been trained yet
 */
public void storeModel(File file) throws IOException {
	final String typeVasttext;
	// Updater state is deliberately not stored
	if (vasttextText != null) {
		ModelSerializer.writeModel(vasttextText, file, false);
		typeVasttext = "onlyText";
	} else if (vasttextTextAndNumeric != null) {
		ModelSerializer.writeModel(vasttextTextAndNumeric, file, false);
		typeVasttext = "textAndNumeric";
	} else {
		throw new UnsupportedOperationException("Train before store model");
	}

	HashMap<String, Object> configuration = new HashMap<>();
	configuration.put("typeVasttext", typeVasttext);
	configuration.put("multiLabelActivation", multiLabelActivation);
	configuration.put("multiLabel", multiLabel);
	configuration.put("dictionary", vectorizer.getDictionary());

	ModelSerializer.addObjectToFile(file, "vasttext.config", configuration);
}
 
Example #11
Source File: TestUtils.java — from deeplearning4j (Apache License 2.0) — 6 votes
/**
 * Serializes the graph (with updater) into memory, restores it, and asserts that the
 * configuration and parameters are preserved; also Java-serializes the configuration.
 *
 * @param net the graph under test
 * @return the restored copy of {@code net}
 */
public static ComputationGraph testModelSerialization(ComputationGraph net){
    ComputationGraph result;
    try {
        ByteArrayOutputStream buffer = new ByteArrayOutputStream();
        ModelSerializer.writeModel(net, buffer, true);
        byte[] serialized = buffer.toByteArray();

        result = ModelSerializer.restoreComputationGraph(new ByteArrayInputStream(serialized), true);

        assertEquals(net.getConfiguration(), result.getConfiguration());
        assertEquals(net.params(), result.params());
    } catch (IOException cause) {
        // Not expected: byte-array streams perform no real I/O
        throw new RuntimeException(cause);
    }

    // Spark and similar frameworks require the configuration to be Java-serializable
    serializeDeserializeJava(net.getConfiguration());

    return result;
}
 
Example #12
Source File: MultiLayerNetwork.java — from deeplearning4j (Apache License 2.0) — 5 votes
/**
 * Custom Java deserialization hook.
 *
 * Restores a network (with updater) from the stream via ModelSerializer, then copies its
 * configurations, parameters, workspace configs and, if present, updater state into this
 * instance.
 */
private void readObject(ObjectInputStream ois) throws ClassNotFoundException, IOException {
    val mln = ModelSerializer.restoreMultiLayerNetwork(ois, true);

    this.defaultConfiguration = mln.defaultConfiguration.clone();
    this.layerWiseConfigurations = mln.layerWiseConfigurations.clone();
    this.init();
    this.flattenedParams.assign(mln.flattenedParams);

    int numWorkingMem = 2 * (layerWiseConfigurations.getConfs().size() + layerWiseConfigurations.getInputPreProcessors().size());
    WS_LAYER_WORKING_MEM_CONFIG = getLayerWorkingMemWSConfig(numWorkingMem);
    WS_LAYER_ACT_X_CONFIG = getLayerActivationWSConfig(layerWiseConfigurations.getConfs().size());

    // Query the restored updater without initializing one. The no-arg getUpdater() is the
    // initializing variant in DL4J, which would make this null check ineffective.
    val restoredUpdater = mln.getUpdater(false);
    if (restoredUpdater != null && restoredUpdater.getStateViewArray() != null)
        this.getUpdater(true).getStateViewArray().assign(restoredUpdater.getStateViewArray());
}
 
Example #13
Source File: TestUtils.java — from konduit-serving (Apache License 2.0) — 5 votes
/**
 * Builds an InferenceConfiguration for a trained Iris-style classifier.
 * The model is saved without updater to {@code trainDir/model.zip}, with four double input
 * columns and three double output columns under the "default" name.
 *
 * @param trainDir temporary folder receiving the saved model
 * @return serving configuration with logging endpoints enabled and a single DL4J step
 */
public static InferenceConfiguration getConfig(TemporaryFolder trainDir) throws Exception {
    Pair<MultiLayerNetwork, DataNormalization> trained = TrainUtils.getTrainedNetwork();
    File modelFile = trainDir.newFile("model.zip");
    ModelSerializer.writeModel(trained.getFirst(), modelFile, false);

    Schema inputSchema = new Schema.Builder()
            .addColumnDouble("petal_length")
            .addColumnDouble("petal_width")
            .addColumnDouble("sepal_width")
            .addColumnDouble("sepal_height")
            .build();

    Schema outputSchema = new Schema.Builder()
            .addColumnDouble("setosa")
            .addColumnDouble("versicolor")
            .addColumnDouble("virginica")
            .build();

    ServingConfig servingConfig = ServingConfig.builder()
            .createLoggingEndpoints(true)
            .build();

    Dl4jStep step = Dl4jStep.builder()
            .inputName("default")
            .inputColumnName("default", SchemaTypeUtils.columnNames(inputSchema))
            .inputSchema("default", SchemaTypeUtils.typesForSchema(inputSchema))
            .outputSchema("default", SchemaTypeUtils.typesForSchema(outputSchema))
            .path(modelFile.getAbsolutePath())
            .outputColumnName("default", SchemaTypeUtils.columnNames(outputSchema))
            .build();

    return InferenceConfiguration.builder()
            .servingConfig(servingConfig)
            .step(step)
            .build();
}
 
Example #14
Source File: ModelGuesser.java — from konduit-serving (Apache License 2.0) — 5 votes
/**
 * Load the model from the given file path by trying each supported format in turn:
 * <ol>
 *   <li>DL4J MultiLayerNetwork, then ComputationGraph, both with updater state;</li>
 *   <li>the same two without updater state;</li>
 *   <li>{@code KerasModelImport.importKerasModelAndWeights};</li>
 *   <li>{@code KerasModelImport.importKerasSequentialModelAndWeights}.</li>
 * </ol>
 *
 * @param path the path of the file to "guess"
 * @return the loaded model
 * @throws Exception if every model load attempt fails. A {@link ModelGuesserException} is
 *                   thrown that carries each attempt's failure as a suppressed exception.
 */
public static Model loadModelGuess(String path) throws Exception {
    try {
        return ModelSerializer.restoreMultiLayerNetwork(new File(path), true);
    } catch (Exception e) {
        log.warn("Tried multi layer network");
        try {
            return ModelSerializer.restoreComputationGraph(new File(path), true);
        } catch (Exception e1) {
            log.warn("Tried computation graph");
            try {
                return ModelSerializer.restoreMultiLayerNetwork(new File(path), false);
            } catch (Exception e4) {
                try {
                    return ModelSerializer.restoreComputationGraph(new File(path), false);
                } catch (Exception e5) {
                    try {
                        return KerasModelImport.importKerasModelAndWeights(path);
                    } catch (Exception e2) {
                        log.warn("Tried multi layer network keras");
                        try {
                            return KerasModelImport.importKerasSequentialModelAndWeights(path);

                        } catch (Exception e3) {
                            ModelGuesserException failure = new ModelGuesserException("Unable to load model from path " + path
                                    + " (invalid model file or not a known model type)");
                            // Preserve every attempt's failure (in attempt order) so the real
                            // cause can be diagnosed instead of being silently dropped.
                            for (Exception attempt : new Exception[]{e, e1, e4, e5, e2, e3}) {
                                failure.addSuppressed(attempt);
                            }
                            throw failure;
                        }
                    }
                }
            }
        }
    }
}
 
Example #15
Source File: OCNNOutputLayerTest.java — from deeplearning4j (Apache License 2.0) — 5 votes
/**
 * Trains a one-class network on classes 0 and 1 of the first normalized batch, then checks:
 * <ul>
 *   <li>that the count of negative outputs on class-2 samples equals the count of zero outputs
 *       from {@code output(features, false)};</li>
 *   <li>that parameters survive a ModelSerializer file round trip.</li>
 * </ul>
 */
@Test
public void testLabelProbabilities() throws Exception {
    Nd4j.getRandom().setSeed(42);
    DataSetIterator dataSetIterator = getNormalizedIterator();
    MultiLayerNetwork network = getSingleLayer();
    DataSet next = dataSetIterator.next();

    // Train on classes 0 and 1 only; class 2 is used below as the anomaly set
    DataSet filtered = next.filterBy(new int[]{0, 1});
    for (int epoch = 0; epoch < 10; epoch++) {
        network.setEpochCount(epoch);
        network.getLayerWiseConfigurations().setEpochCount(epoch);
        network.fit(filtered);
    }

    DataSet anomalies = next.filterBy(new int[] {2});
    INDArray output = network.output(anomalies.getFeatures());
    INDArray normalOutput = network.output(anomalies.getFeatures(), false);
    double negativeCount = output.lt(0.0).castTo(Nd4j.defaultFloatingPointType()).sumNumber().doubleValue();
    double zeroCount = normalOutput.eq(0.0).castTo(Nd4j.defaultFloatingPointType()).sumNumber().doubleValue();
    assertEquals(negativeCount, zeroCount, 1e-1);

    INDArray normalProbs = network.output(filtered.getFeatures());
    INDArray outputForNormalSamples = network.output(filtered.getFeatures(), false);
    System.out.println("Normal probabilities " + normalProbs);
    System.out.println("Normal raw output " + outputForNormalSamples);

    // Save/restore round trip must keep the parameters intact
    File tmpFile = new File(testDir.getRoot(), "tmp-file-" + UUID.randomUUID().toString());
    ModelSerializer.writeModel(network, tmpFile, true);
    tmpFile.deleteOnExit();

    MultiLayerNetwork restored = ModelSerializer.restoreMultiLayerNetwork(tmpFile);
    assertEquals(network.params(), restored.params());
    assertEquals(network.numParams(), restored.numParams());
}
 
Example #16
Source File: TrainCifar10Model.java — from Java-Machine-Learning-for-Computer-Vision (MIT License) — 5 votes
/**
 * Restores a saved CIFAR-10 ComputationGraph from {@code MODEL_SAVE_PATH + preTrainedCifarModel}
 * into {@code cifar10Transfer} and logs its summary.
 *
 * @param preTrainedCifarModel file name appended to MODEL_SAVE_PATH
 * @throws IOException if the model cannot be read
 */
public void loadTrainedModel(String preTrainedCifarModel) throws IOException {
    String fullPath = MODEL_SAVE_PATH + preTrainedCifarModel;
    File modelFile = new File(fullPath);
    log.info("loading model " + modelFile);
    cifar10Transfer = ModelSerializer.restoreComputationGraph(modelFile);
    log.info(cifar10Transfer.summary());
}
 
Example #17
Source File: TransferLearningVGG16.java — from Java-Machine-Learning-for-Computer-Vision (MIT License) — 5 votes
/**
 * Every SAVING_INTERVAL iterations (never at iteration 0), writes a checkpoint without updater
 * state to {@code SAVING_PATH + iIteration + "_epoch_" + iEpoch + ".zip"} and evaluates the model
 * on the dev iterator.
 */
private void saveProgressEveryConfiguredInterval(ComputationGraph vgg16Transfer, int iEpoch, int iIteration) throws IOException {
    // Guard clause: same condition as "multiple of the interval AND not the first iteration"
    if (iIteration % SAVING_INTERVAL != 0 || iIteration == 0) {
        return;
    }
    File checkpoint = new File(SAVING_PATH + iIteration + "_epoch_" + iEpoch + ".zip");
    ModelSerializer.writeModel(vgg16Transfer, checkpoint, false);
    evalOn(vgg16Transfer, neuralNetworkTrainingData.getDevIterator(), iIteration);
}
 
Example #18
Source File: Model.java — from gluon-samples (BSD 3-Clause "New" or "Revised" License) — 5 votes
/**
 * Loads the trained network bundled as classpath resource "/mymodel.zip" into {@code nnModel}.
 * Any failure is printed and otherwise ignored (best effort, as before).
 */
private void loadModelLocal() {
    System.out.println("******LOAD TRAINED MODEL (local)******");
    try {
        MultiLayerNetwork network;
        // try-with-resources closes the stream even when restoring throws
        try (InputStream is = Model.class.getResourceAsStream("/mymodel.zip")) {
            if (is == null) {
                throw new IllegalStateException("Classpath resource /mymodel.zip not found");
            }
            network = ModelSerializer.restoreMultiLayerNetwork(is);
        }
        nnModel.set(network);

    } catch (Throwable t) {
        t.printStackTrace();
    }
}
 
Example #19
Source File: TestComputationGraphNetwork.java — from deeplearning4j (Apache License 2.0) — 5 votes
/**
 * Verifies that each fit over the Iris iterator increments the graph's epoch counter by one,
 * and that the counter (4 after four fits) survives a ModelSerializer round trip.
 */
@Test
public void testEpochCounter() throws Exception {
    ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
            .graphBuilder()
            .addInputs("in")
            .addLayer("out", new OutputLayer.Builder().nIn(4).nOut(3).activation(Activation.SOFTMAX).build(), "in")
            .setOutputs("out")
            .build();

    ComputationGraph net = new ComputationGraph(conf);
    net.init();
    assertEquals(0, net.getConfiguration().getEpochCount());

    DataSetIterator iter = new IrisDataSetIterator(150, 150);
    for (int epoch = 0; epoch < 4; epoch++) {
        assertEquals(epoch, net.getConfiguration().getEpochCount());
        net.fit(iter);
        assertEquals(epoch + 1, net.getConfiguration().getEpochCount());
    }
    assertEquals(4, net.getConfiguration().getEpochCount());

    ByteArrayOutputStream buffer = new ByteArrayOutputStream();
    ModelSerializer.writeModel(net, buffer, true);
    ComputationGraph restored = ModelSerializer.restoreComputationGraph(
            new ByteArrayInputStream(buffer.toByteArray()), true);
    assertEquals(4, restored.getConfiguration().getEpochCount());
}
 
Example #20
Source File: DeepAutoEncoderExample.java — from Java-for-Data-Science (MIT License) — 5 votes
/**
 * Points {@code modelFile} at "savedModel" and restores a MultiLayerNetwork from it.
 * I/O failures are printed, not thrown.
 */
public void retrieveModel() {
    try {
        modelFile = new File("savedModel");
        // NOTE(review): the restored network is kept only in this local and is discarded when
        // the method returns. If a field was meant to receive it, assign it there instead.
        MultiLayerNetwork restoredModel = ModelSerializer.restoreMultiLayerNetwork(modelFile);
    } catch (IOException loadFailure) {
        loadFailure.printStackTrace();
    }
}
 
Example #21
Source File: ModelSavingCallback.java — from deeplearning4j (Apache License 2.0) — 5 votes
/**
 * Saves the model, including its updater state, to the given file.
 *
 * @param model    model to persist
 * @param filename destination file path
 * @throws RuntimeException wrapping the underlying IOException if writing fails; the message
 *                          names the destination so the failure is diagnosable
 */
protected void save(Model model, String filename) {
    try {
        ModelSerializer.writeModel(model, filename, true);
    } catch (IOException e) {
        throw new RuntimeException("Failed to save model to " + filename, e);
    }
}
 
Example #22
Source File: ParallelInferenceTest.java — from deeplearning4j (Apache License 2.0) — 5 votes
/**
 * Loads the LeNet MNIST test model (with updater) and an MNIST iterator (batch size 1,
 * test set, seed 12345). Loading is skipped when {@code model} is already set.
 */
@Before
public void setUp() throws Exception {
    if (model != null) {
        return;
    }
    File lenetZip = Resources.asFile("models/LenetMnistMLN.zip");
    model = ModelSerializer.restoreMultiLayerNetwork(lenetZip, true);

    iterator = new MnistDataSetIterator(1, false, 12345);
}
 
Example #23
Source File: ModelGuesser.java — from deeplearning4j (Apache License 2.0) — 5 votes
/**
 * A facade for {@link ModelSerializer#restoreNormalizerFromFile(File)}.
 *
 * @param path the path to the file
 * @return the loaded normalizer
 * @throws RuntimeException wrapping the underlying IOException if the normalizer cannot be
 *                          read; the message names the offending path
 */
public static Normalizer<?> loadNormalizer(String path) {
    try {
        return ModelSerializer.restoreNormalizerFromFile(new File(path));
    } catch (IOException e){
        throw new RuntimeException("Failed to load normalizer from " + path, e);
    }
}
 
Example #24
Source File: Classifier.java — from java-ml-projects (Apache License 2.0) — 5 votes
/**
 * Initializes the static classifier state from Properties: the labels, a NativeImageLoader
 * built from the three-element input format, and the ComputationGraph restored from the
 * configured model path.
 *
 * @throws IOException if the model cannot be restored
 */
public static void init() throws IOException {
	String path = Properties.classifierModelPath();
	labels = Properties.classifierLabels();
	int[] inputFormat = Properties.classifierInputFormat();
	// NOTE(review): the three entries are passed positionally to NativeImageLoader; confirm
	// their order matches that constructor's (height, width, channels) expectation.
	loader = new NativeImageLoader(inputFormat[0], inputFormat[1], inputFormat[2]);
	model = ModelSerializer.restoreComputationGraph(path);
	model.init();
}
 
Example #25
Source File: RegressionTest071.java — from deeplearning4j (Apache License 2.0) — 5 votes
/**
 * Regression test: a 3-layer recurrent network saved with DL4J 0.7.1 must still load with the
 * same layer types, sizes, activations, gradient clipping and loss function.
 */
@Test
public void regressionTestLSTM1() throws Exception {
    File modelZip = Resources.asFile("regression_testing/071/071_ModelSerializer_Regression_LSTM_1.zip");
    MultiLayerNetwork restored = ModelSerializer.restoreMultiLayerNetwork(modelZip, true);

    MultiLayerConfiguration conf = restored.getLayerWiseConfigurations();
    assertEquals(3, conf.getConfs().size());

    // Layer 0: GravesLSTM 3 -> 4, tanh, element-wise clipping at 1.5
    GravesLSTM lstm = (GravesLSTM) conf.getConf(0).getLayer();
    assertEquals("tanh", lstm.getActivationFn().toString());
    assertEquals(3, lstm.getNIn());
    assertEquals(4, lstm.getNOut());
    assertEquals(GradientNormalization.ClipElementWiseAbsoluteValue, lstm.getGradientNormalization());
    assertEquals(1.5, lstm.getGradientNormalizationThreshold(), 1e-5);

    // Layer 1: GravesBidirectionalLSTM 4 -> 4, softsign, same clipping
    GravesBidirectionalLSTM biLstm = (GravesBidirectionalLSTM) conf.getConf(1).getLayer();
    assertEquals("softsign", biLstm.getActivationFn().toString());
    assertEquals(4, biLstm.getNIn());
    assertEquals(4, biLstm.getNOut());
    assertEquals(GradientNormalization.ClipElementWiseAbsoluteValue, biLstm.getGradientNormalization());
    assertEquals(1.5, biLstm.getGradientNormalizationThreshold(), 1e-5);

    // Layer 2: RnnOutputLayer 4 -> 5, softmax with MCXENT loss
    RnnOutputLayer rnnOut = (RnnOutputLayer) conf.getConf(2).getLayer();
    assertEquals(4, rnnOut.getNIn());
    assertEquals(5, rnnOut.getNOut());
    assertEquals("softmax", rnnOut.getActivationFn().toString());
    assertTrue(rnnOut.getLossFn() instanceof LossMCXENT);
}
 
Example #26
Source File: RegressionTest071.java — from deeplearning4j (Apache License 2.0) — 5 votes
/**
 * Regression test: a 2-layer MLP saved with DL4J 0.7.1 must still load with the same layer
 * configuration (activations, sizes, weight init, Nesterovs updater), parameters and updater state.
 */
@Test
public void regressionTestMLP1() throws Exception {

    File f = Resources.asFile("regression_testing/071/071_ModelSerializer_Regression_MLP_1.zip");

    MultiLayerNetwork net = ModelSerializer.restoreMultiLayerNetwork(f, true);

    MultiLayerConfiguration conf = net.getLayerWiseConfigurations();
    assertEquals(2, conf.getConfs().size());

    DenseLayer l0 = (DenseLayer) conf.getConf(0).getLayer();
    assertEquals("relu", l0.getActivationFn().toString());
    assertEquals(3, l0.getNIn());
    assertEquals(4, l0.getNOut());
    assertEquals(new WeightInitXavier(), l0.getWeightInitFn());
    assertEquals(new Nesterovs(0.15, 0.9), l0.getIUpdater());
    assertEquals(0.15, ((Nesterovs)l0.getIUpdater()).getLearningRate(), 1e-6);

    OutputLayer l1 = (OutputLayer) conf.getConf(1).getLayer();
    assertEquals("softmax", l1.getActivationFn().toString());
    assertTrue(l1.getLossFn() instanceof LossMCXENT);
    assertEquals(4, l1.getNIn());
    assertEquals(5, l1.getNOut());
    assertEquals(new WeightInitXavier(), l1.getWeightInitFn());
    assertEquals(0.9, ((Nesterovs)l1.getIUpdater()).getMomentum(), 1e-6);
    assertEquals(0.15, ((Nesterovs)l1.getIUpdater()).getLearningRate(), 1e-6);

    // Parameters and updater state of the saved model are expected to hold 1..N.
    // numParams stays a long: casting to int would silently truncate large models.
    long numParams = net.numParams();
    assertEquals(Nd4j.linspace(1, numParams, numParams).reshape(1,numParams), net.params());
    int updaterSize = (int) new Nesterovs().stateSize(numParams);
    // Shape the expected updater state by its own size rather than by numParams
    assertEquals(Nd4j.linspace(1, updaterSize, updaterSize).reshape(1,updaterSize), net.getUpdater().getStateViewArray());
}
 
Example #27
Source File: SaveLoadMultiLayerNetwork.java — from dl4j-tutorials (MIT License) — 5 votes
/**
 * Demonstrates saving a small MultiLayerNetwork to model/MyMultiLayerNetwork.zip (with updater
 * state) and loading it back.
 * Prints whether the restored parameters and configuration equal the originals.
 */
public static void main(String[] args) throws Exception {
    //Define a simple MultiLayerNetwork:
    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
        .weightInit(WeightInit.XAVIER)
            .updater(new Nesterovs(0.01, 0.9))
        .list()
        .layer(0, new DenseLayer.Builder().nIn(4).nOut(3).activation(Activation.TANH).build())
        .layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD).activation(Activation.SOFTMAX).nIn(3).nOut(3).build())
        .backprop(true).pretrain(false).build();

    MultiLayerNetwork net = new MultiLayerNetwork(conf);
    net.init();


    //Save the model
    File locationToSave = new File("model/MyMultiLayerNetwork.zip");      //Where to save the network. Note: the file is in .zip format - can be opened externally
    // Make sure the "model" directory exists before writing; a missing parent directory would
    // otherwise make the write fail. mkdirs() returning false just means it already exists
    // (or could not be created, in which case writeModel reports the error).
    locationToSave.getAbsoluteFile().getParentFile().mkdirs();
    // Whether to save the model's updater state:
    //  - true if you plan to keep training after saving, so training can continue
    //    incrementally on further data;
    //  - false if training is finished and the model is final.
    boolean saveUpdater = true;                                             //Updater: i.e., the state for Momentum, RMSProp, Adagrad etc. Save this if you want to train your network more in the future
    ModelSerializer.writeModel(net, locationToSave, saveUpdater);

    //Load the model
    MultiLayerNetwork restored = ModelSerializer.restoreMultiLayerNetwork(locationToSave);


    System.out.println("Saved and loaded parameters are equal:      " + net.params().equals(restored.params()));
    System.out.println("Saved and loaded configurations are equal:  " + net.getLayerWiseConfigurations().equals(restored.getLayerWiseConfigurations()));
}
 
Example #28
Source File: CheckpointListener.java — from deeplearning4j (Apache License 2.0) — 5 votes
/**
 * Load a MultiLayerNetwork (with updater state) for the given checkpoint number.
 *
 * @param rootDir       The directory that the checkpoint resides in
 * @param checkpointNum Checkpoint model to load
 * @return The loaded model
 * @throws RuntimeException wrapping the underlying IOException if the checkpoint cannot be
 *                          read; the message names the checkpoint number and file
 */
public static MultiLayerNetwork loadCheckpointMLN(File rootDir, int checkpointNum){
    File f = getFileForCheckpoint(rootDir, checkpointNum);
    try {
        return ModelSerializer.restoreMultiLayerNetwork(f, true);
    } catch (IOException e){
        throw new RuntimeException("Failed to load MultiLayerNetwork checkpoint " + checkpointNum + " from " + f, e);
    }
}
 
Example #29
Source File: Gan11Exemple.java — from dl4j-tutorials (MIT License) — 5 votes
/**
 * Writes {@code net} (including updater state) to the file {@code modelName}.
 * Failures are printed rather than propagated.
 *
 * @param modelName destination file path
 */
public void saveModel(String modelName) {
    try {
        ModelSerializer.writeModel(net, modelName, true);
    } catch (IOException writeError) {
        // Best effort: report and continue
        writeError.printStackTrace();
    }
}
 
Example #30
Source File: Main.java — from gluon-samples (BSD 3-Clause "New" or "Revised" License) — 5 votes
/**
 * Restores the model from savedModelLocation if that file exists; otherwise creates, trains
 * and saves a new one.
 * In both cases the model is evaluated, then runTests and an image-correction step run, each
 * followed by another evaluation.
 */
public static void main(String[] args) throws Exception {
    MultiLayerNetwork model;
    if (new File(savedModelLocation).exists()) {
        LOGGER.info("Model exists, restore it");
        model = ModelSerializer.restoreMultiLayerNetwork(savedModelLocation);
    } else {
        LOGGER.info("Create model");
        model = utils.createModel();
        LOGGER.info("Train model");
        utils.trainModel(model, true, null, -1);
        LOGGER.info("Save model");
        utils.saveModel(model, savedModelLocation);
        LOGGER.info("Eval model");
    }
    // Both branches end with an evaluation of the freshly obtained model
    utils.evaluateModel(model);

    LOGGER.info("Run tests");
    runTests(model);
    LOGGER.info("Evaluate model after tests");
    utils.evaluateModel(model);

    LOGGER.info("Correct Image");
    correctImage(model, Main.class.getResourceAsStream("/mytestdata/3b.png"), 3);
    utils.evaluateModel(model);
}