Java Code Examples for org.deeplearning4j.util.ModelSerializer

The following examples show how to use org.deeplearning4j.util.ModelSerializer. They are extracted from open source projects; the source project, file, and license are listed above each example where available.
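At its simplest, ModelSerializer writes a network (optionally together with its updater state) to a file or stream and restores it later. The sketch below, which assumes a recent Deeplearning4j version, shows that round trip for a MultiLayerNetwork; the tiny one-layer configuration and the file name model.zip are illustrative choices, not taken from the examples that follow.

import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.lossfunctions.LossFunctions;

import java.io.File;
import java.io.IOException;

public class ModelSerializerRoundTrip {
    public static void main(String[] args) throws IOException {
        // A deliberately tiny network - just enough structure to have something to serialize.
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .seed(12345)
                .list()
                .layer(0, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                        .activation(Activation.SOFTMAX).nIn(4).nOut(3).build())
                .build();
        MultiLayerNetwork net = new MultiLayerNetwork(conf);
        net.init();

        // true = also save the updater state (e.g. momentum/RMSProp history),
        // which is only needed if training will be resumed after loading.
        File modelFile = new File("model.zip");
        ModelSerializer.writeModel(net, modelFile, true);

        // true = also load the updater state saved above.
        MultiLayerNetwork restored = ModelSerializer.restoreMultiLayerNetwork(modelFile, true);
        System.out.println("Parameters equal after round trip: " + net.params().equals(restored.params()));
    }
}

The same pattern works with a ComputationGraph via ModelSerializer.restoreComputationGraph, and with streams instead of files, as several of the examples below demonstrate.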
Example 1
Source Project: deeplearning4j   Source File: TestUtils.java    License: Apache License 2.0
public static MultiLayerNetwork testModelSerialization(MultiLayerNetwork net){
    MultiLayerNetwork restored;
    try {
        ByteArrayOutputStream baos = new ByteArrayOutputStream();
        ModelSerializer.writeModel(net, baos, true);
        byte[] bytes = baos.toByteArray();

        ByteArrayInputStream bais = new ByteArrayInputStream(bytes);
        restored = ModelSerializer.restoreMultiLayerNetwork(bais, true);

        assertEquals(net.getLayerWiseConfigurations(), restored.getLayerWiseConfigurations());
        assertEquals(net.params(), restored.params());
    } catch (IOException e){
        //Should never happen
        throw new RuntimeException(e);
    }

    //Also check the MultiLayerConfiguration is serializable (required by Spark etc)
    MultiLayerConfiguration conf = net.getLayerWiseConfigurations();
    serializeDeserializeJava(conf);

    return restored;
}
 
Example 2
@Test
public void testMultiLayerNetwork() throws Exception {
    Pair<MultiLayerNetwork, DataNormalization> trainedNetwork = TrainUtils.getTrainedNetwork();
    MultiLayerNetwork save = trainedNetwork.getLeft();
    File dir = testDir.newFolder();
    File tmpZip = new File(dir, "dl4j_mln_model.zip");
    tmpZip.deleteOnExit();
    ModelSerializer.writeModel(save, tmpZip, true);

    ModelStep modelPipelineStep = Dl4jStep.builder()
            .inputName("default")
            .outputName("output")
            .path(tmpZip.getAbsolutePath())
            .build();

    Dl4jInferenceExecutionerFactory factory = new Dl4jInferenceExecutionerFactory();
    InitializedInferenceExecutionerConfig initializedInferenceExecutionerConfig = factory.create(modelPipelineStep);
    MultiLayerNetworkInferenceExecutioner multiLayerNetworkInferenceExecutioner = (MultiLayerNetworkInferenceExecutioner) initializedInferenceExecutionerConfig.getInferenceExecutioner();
    assertNotNull(multiLayerNetworkInferenceExecutioner);
    assertNotNull(multiLayerNetworkInferenceExecutioner.model());
    assertNotNull(multiLayerNetworkInferenceExecutioner.modelLoader());
}
 
Example 3
@Test
public void testComputationGraph() throws Exception {
    Pair<MultiLayerNetwork, DataNormalization> trainedNetwork = TrainUtils.getTrainedNetwork();
    ComputationGraph save = trainedNetwork.getLeft().toComputationGraph();
    File dir = testDir.newFolder();
    File tmpZip = new File(dir, "dl4j_cg_model.zip");
    tmpZip.deleteOnExit();
    ModelSerializer.writeModel(save, tmpZip, true);

    ModelStep modelPipelineStep = Dl4jStep.builder()
            .inputName("default")
            .outputName("output")
            .path(tmpZip.getAbsolutePath())
            .build();

    Dl4jInferenceExecutionerFactory factory = new Dl4jInferenceExecutionerFactory();
    InitializedInferenceExecutionerConfig initializedInferenceExecutionerConfig = factory.create(modelPipelineStep);
    MultiComputationGraphInferenceExecutioner multiComputationGraphInferenceExecutioner = (MultiComputationGraphInferenceExecutioner) initializedInferenceExecutionerConfig.getInferenceExecutioner();
    assertNotNull(multiComputationGraphInferenceExecutioner);
    assertNotNull(multiComputationGraphInferenceExecutioner.model());
    assertNotNull(multiComputationGraphInferenceExecutioner.modelLoader());
}
 
Example 4
Source Project: deeplearning4j   Source File: TestUtils.java    License: Apache License 2.0
public static ComputationGraph testModelSerialization(ComputationGraph net){
    ComputationGraph restored;
    try {
        ByteArrayOutputStream baos = new ByteArrayOutputStream();
        ModelSerializer.writeModel(net, baos, true);
        byte[] bytes = baos.toByteArray();

        ByteArrayInputStream bais = new ByteArrayInputStream(bytes);
        restored = ModelSerializer.restoreComputationGraph(bais, true);

        assertEquals(net.getConfiguration(), restored.getConfiguration());
        assertEquals(net.params(), restored.params());
    } catch (IOException e){
        //Should never happen
        throw new RuntimeException(e);
    }

    //Also check the ComputationGraphConfiguration is serializable (required by Spark etc)
    ComputationGraphConfiguration conf = net.getConfiguration();
    serializeDeserializeJava(conf);

    return restored;
}
 
Example 5
Source Project: java-ml-projects   Source File: YOLOModel.java    License: Apache License 2.0
public void init() {
	try {
		if (Objects.isNull(modelPath)) {

			yoloModel = (ComputationGraph) YOLO2.builder().build().initPretrained();
			setModelClasses(COCO_CLASSES);
		} else {
			yoloModel = ModelSerializer.restoreComputationGraph(modelPath);
			if (!(yoloModel.getOutputLayer(0) instanceof Yolo2OutputLayer)) {
				throw new Error("The model is not an YOLO model (output layer is not Yolo2OutputLayer)");
			}
			setModelClasses(classes.split("\\,"));
		}
		imageLoader = new NativeImageLoader(getInputWidth(), getInputHeight(), getInputChannels(),
				new ColorConversionTransform(COLOR_BGR2RGB));
		loadInputParameters();
	} catch (IOException e) {
		throw new Error("Not able to init the model", e);
	}
}
 
Example 6
Source Project: java-ml-projects   Source File: DLModel.java    License: Apache License 2.0
public static DLModel fromFile(File file) throws Exception {
	Model model = null;
	try {
		System.out.println("Trying to load file as computation graph: " + file);
		model = ModelSerializer.restoreComputationGraph(file);
		System.out.println("Loaded Computation Graph.");
	} catch (Exception e) {
		try {
			System.out.println("Failed to load computation graph. Trying to load model.");
			model = ModelSerializer.restoreMultiLayerNetwork(file);
			System.out.println("Loaded Multilayernetwork");
		} catch (Exception e1) {
			System.out.println("Give up trying to load file: " + file);
			throw e;
		}
	}
	return new DLModel(model);
}
 
Example 7
Source Project: scava   Source File: Vasttext.java    License: Eclipse Public License 2.0
public void storeModel(File file) throws IOException
{
	HashMap<String, Object> configuration = new HashMap<String, Object>();
	
	//We do not store the updaters
	if(vasttextText!=null)
	{
		ModelSerializer.writeModel(vasttextText, file, false);
		configuration.put("typeVasttext", "onlyText");
	}
	else if(vasttextTextAndNumeric!=null)
	{
		ModelSerializer.writeModel(vasttextTextAndNumeric, file, false);
		configuration.put("typeVasttext", "textAndNumeric");
	}
	else
		throw new UnsupportedOperationException("Train before store model");
	
	configuration.put("multiLabelActivation", multiLabelActivation);
	configuration.put("multiLabel", multiLabel);
	configuration.put("dictionary", vectorizer.getDictionary());
	
	ModelSerializer.addObjectToFile(file, "vasttext.config", configuration);
}
 
Example 8
public static INDArray generateOutput(File inputFile, String modelFilePath) throws IOException, InterruptedException {
    final File modelFile = new File(modelFilePath);
    final MultiLayerNetwork network = ModelSerializer.restoreMultiLayerNetwork(modelFile);
    final RecordReader recordReader = generateReader(inputFile);
    //final INDArray array = RecordConverter.toArray(recordReader.next());
    final NormalizerStandardize normalizerStandardize = ModelSerializer.restoreNormalizerFromFile(modelFile);
    //normalizerStandardize.transform(array);
    final DataSetIterator dataSetIterator = new RecordReaderDataSetIterator.Builder(recordReader,1).build();
    normalizerStandardize.fit(dataSetIterator);
    dataSetIterator.setPreProcessor(normalizerStandardize);
    return network.output(dataSetIterator);

}
 
Example 9
Source Project: konduit-serving   Source File: KerasDl4jHandler.java    License: Apache License 2.0
@Override
public void handle(RoutingContext event) {
    File kerasFile = getTmpFileWithContext(event);
    ModelType type = getTypeFromContext(event);
    try {
        ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream();
        switch (type) {
            case FUNCTIONAL:
                ComputationGraph computationGraph = KerasModelImport.importKerasModelAndWeights(kerasFile.getAbsolutePath());
                ModelSerializer.writeModel(computationGraph, byteArrayOutputStream, true);
                break;
            case SEQUENTIAL:
                MultiLayerNetwork multiLayerConfiguration = KerasModelImport.importKerasSequentialModelAndWeights(kerasFile.getAbsolutePath());
                ModelSerializer.writeModel(multiLayerConfiguration, byteArrayOutputStream, true);
                break;
        }

        Buffer buffer = Buffer.buffer(byteArrayOutputStream.toByteArray());
        File newFile = new File("tmpFile-" + UUID.randomUUID().toString() + ".xml");
        FileUtils.writeByteArrayToFile(newFile, buffer.getBytes());
        event.response().sendFile(newFile.getAbsolutePath(), resultHandler -> {
            if (resultHandler.failed()) {
                resultHandler.cause().printStackTrace();
                event.response().setStatusCode(HttpStatus.SC_INTERNAL_SERVER_ERROR);

            } else {
                event.response().setStatusCode(200);
            }
        });

        event.response().exceptionHandler(Throwable::printStackTrace);

    } catch (Exception e) {
        event.response().setStatusCode(HttpStatus.SC_INTERNAL_SERVER_ERROR);
        event.response().setStatusMessage("Error importing model " + e.getMessage());
    }
}
 
Example 10
@Override
public Buffer saveModel(ComputationGraph model) {
    ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream();
    try {
        ModelSerializer.writeModel(model, byteArrayOutputStream, true);
    } catch (IOException e) {
        throw new RuntimeException(e);
    }

    return Buffer.buffer(byteArrayOutputStream.toByteArray());
}
 
Example 11
Source Project: deeplearning4j   Source File: MultiLayerNetwork.java    License: Apache License 2.0
protected void update(Task task) {
    if (!initDone) {
        initDone = true;
        Heartbeat heartbeat = Heartbeat.getInstance();
        task = ModelSerializer.taskByModel(this);
        Environment env = EnvironmentUtils.buildEnvironment();
        heartbeat.reportEvent(Event.STANDALONE, env, task);
    }
}
 
Example 12
@Override
public Buffer saveModel(@NonNull MultiLayerNetwork model) {
    ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream();
    try {
        ModelSerializer.writeModel(model, byteArrayOutputStream, true);
    } catch (IOException e) {
        throw new RuntimeException(e);
    }

    return Buffer.buffer(byteArrayOutputStream.toByteArray());
}
 
Example 13
Source Project: konduit-serving   Source File: ModelGuesser.java    License: Apache License 2.0
/**
 * Loads a dl4j zip file (either computation graph or multi layer network)
 *
 * @param path the path to the file to load
 * @return a loaded dl4j model
 * @throws Exception if loading a dl4j model fails
 */
public static Model loadDl4jGuess(String path) throws Exception {
    if (isZipFile(new File(path))) {
        log.debug("Loading file " + path);
        boolean compGraph = false;
        try (ZipFile zipFile = new ZipFile(path)) {
            List<String> collect = zipFile.stream().map(ZipEntry::getName)
                    .collect(Collectors.toList());
            log.debug("Entries " + collect);
            if (collect.contains(ModelSerializer.COEFFICIENTS_BIN) && collect.contains(ModelSerializer.CONFIGURATION_JSON)) {
                ZipEntry entry = zipFile.getEntry(ModelSerializer.CONFIGURATION_JSON);
                log.debug("Loaded configuration");
                try (InputStream is = zipFile.getInputStream(entry)) {
                    String configJson = IOUtils.toString(is, StandardCharsets.UTF_8);
                    JSONObject jsonObject = new JSONObject(configJson);
                    if (jsonObject.has("vertexInputs")) {
                        log.debug("Loading computation graph.");
                        compGraph = true;
                    } else {
                        log.debug("Loading multi layer network.");
                    }

                }
            }
        }

        if (compGraph) {
            return ModelSerializer.restoreComputationGraph(new File(path));
        } else {
            return ModelSerializer.restoreMultiLayerNetwork(new File(path));
        }
    }

    return null;
}
 
Example 14
Source Project: konduit-serving   Source File: ModelGuesser.java    License: Apache License 2.0
/**
 * Load the model from the given file path
 *
 * @param path the path of the file to "guess"
 * @return the loaded model
 * @throws Exception if every model load attempt fails
 */
public static Model loadModelGuess(String path) throws Exception {
    try {
        return ModelSerializer.restoreMultiLayerNetwork(new File(path), true);
    } catch (Exception e) {
        log.warn("Tried multi layer network");
        try {
            return ModelSerializer.restoreComputationGraph(new File(path), true);
        } catch (Exception e1) {
            log.warn("Tried computation graph");
            try {
                return ModelSerializer.restoreMultiLayerNetwork(new File(path), false);
            } catch (Exception e4) {
                try {
                    return ModelSerializer.restoreComputationGraph(new File(path), false);
                } catch (Exception e5) {
                    try {
                        return KerasModelImport.importKerasModelAndWeights(path);
                    } catch (Exception e2) {
                        log.warn("Tried multi layer network keras");
                        try {
                            return KerasModelImport.importKerasSequentialModelAndWeights(path);

                        } catch (Exception e3) {
                            throw new ModelGuesserException("Unable to load model from path " + path
                                    + " (invalid model file or not a known model type)");
                        }
                    }
                }
            }
        }
    }
}
 
Example 15
Source Project: konduit-serving   Source File: TestUtils.java    License: Apache License 2.0
public static InferenceConfiguration getConfig(TemporaryFolder trainDir) throws Exception {
    Pair<MultiLayerNetwork, DataNormalization> multiLayerNetwork = TrainUtils.getTrainedNetwork();
    File modelSave = trainDir.newFile("model.zip");
    ModelSerializer.writeModel(multiLayerNetwork.getFirst(), modelSave, false);

    Schema.Builder schemaBuilder = new Schema.Builder();
    schemaBuilder.addColumnDouble("petal_length")
            .addColumnDouble("petal_width")
            .addColumnDouble("sepal_width")
            .addColumnDouble("sepal_height");
    Schema inputSchema = schemaBuilder.build();

    Schema.Builder outputSchemaBuilder = new Schema.Builder();
    outputSchemaBuilder.addColumnDouble("setosa");
    outputSchemaBuilder.addColumnDouble("versicolor");
    outputSchemaBuilder.addColumnDouble("virginica");
    Schema outputSchema = outputSchemaBuilder.build();

    ServingConfig servingConfig = ServingConfig.builder()
            .createLoggingEndpoints(true)
            .build();

    Dl4jStep modelPipelineStep = Dl4jStep.builder()
            .inputName("default")
            .inputColumnName("default", SchemaTypeUtils.columnNames(inputSchema))
            .inputSchema("default", SchemaTypeUtils.typesForSchema(inputSchema))
            .outputSchema("default", SchemaTypeUtils.typesForSchema(outputSchema))
            .path(modelSave.getAbsolutePath())
            .outputColumnName("default", SchemaTypeUtils.columnNames(outputSchema))
            .build();

    return InferenceConfiguration.builder()
            .servingConfig(servingConfig)
            .step(modelPipelineStep)
            .build();
}
 
Example 16
Source Project: deeplearning4j   Source File: RegressionTest050.java    License: Apache License 2.0
@Test
public void regressionTestMLP2() throws Exception {

    File f = Resources.asFile("regression_testing/050/050_ModelSerializer_Regression_MLP_2.zip");

    MultiLayerNetwork net = ModelSerializer.restoreMultiLayerNetwork(f, true);

    MultiLayerConfiguration conf = net.getLayerWiseConfigurations();
    assertEquals(2, conf.getConfs().size());

    DenseLayer l0 = (DenseLayer) conf.getConf(0).getLayer();
    assertTrue(l0.getActivationFn() instanceof ActivationLReLU);
    assertEquals(3, l0.getNIn());
    assertEquals(4, l0.getNOut());
    assertEquals(new WeightInitDistribution(new NormalDistribution(0.1, 1.2)), l0.getWeightInitFn());
    assertEquals(new RmsProp(0.15, 0.96, RmsProp.DEFAULT_RMSPROP_EPSILON), l0.getIUpdater());
    assertEquals(0.15, ((RmsProp)l0.getIUpdater()).getLearningRate(), 1e-6);
    assertEquals(new Dropout(0.6), l0.getIDropout());
    assertEquals(0.1, TestUtils.getL1(l0), 1e-6);
    assertEquals(new WeightDecay(0.2, false), TestUtils.getWeightDecayReg(l0));

    OutputLayer l1 = (OutputLayer) conf.getConf(1).getLayer();
    assertEquals("identity", l1.getActivationFn().toString());
    assertTrue(l1.getLossFn() instanceof LossMSE);
    assertEquals(4, l1.getNIn());
    assertEquals(5, l1.getNOut());
    assertEquals(new WeightInitDistribution(new NormalDistribution(0.1, 1.2)), l1.getWeightInitFn());
    assertEquals(new RmsProp(0.15, 0.96, RmsProp.DEFAULT_RMSPROP_EPSILON), l1.getIUpdater());
    assertEquals(0.15, ((RmsProp)l1.getIUpdater()).getLearningRate(), 1e-6);
    assertEquals(new Dropout(0.6), l1.getIDropout());
    assertEquals(0.1, TestUtils.getL1(l1), 1e-6);
    assertEquals(new WeightDecay(0.2, false), TestUtils.getWeightDecayReg(l1));

    int numParams = (int)net.numParams();
    assertEquals(Nd4j.linspace(1, numParams, numParams, Nd4j.dataType()).reshape(1,numParams), net.params());
    int updaterSize = (int) new RmsProp().stateSize(numParams);
    assertEquals(Nd4j.linspace(1, updaterSize, updaterSize, Nd4j.dataType()).reshape(1,numParams), net.getUpdater().getStateViewArray());
}
 
Example 17
Source Project: konduit-serving   Source File: BaseDl4JVerticalTest.java    License: Apache License 2.0
@Override
public JsonObject getConfigObject() throws Exception {
    Pair<MultiLayerNetwork, DataNormalization> multiLayerNetwork = getTrainedNetwork();
    File modelSave = new File(temporary.getRoot(), "model.zip");
    ModelSerializer.writeModel(multiLayerNetwork.getFirst(), modelSave, true);

    Schema.Builder schemaBuilder = new Schema.Builder();
    schemaBuilder.addColumnDouble("petal_length")
            .addColumnDouble("petal_width")
            .addColumnDouble("sepal_width")
            .addColumnDouble("sepal_height");
    Schema inputSchema = schemaBuilder.build();

    Schema.Builder outputSchemaBuilder = new Schema.Builder();
    outputSchemaBuilder.addColumnDouble("setosa");
    outputSchemaBuilder.addColumnDouble("versicolor");
    outputSchemaBuilder.addColumnDouble("virginica");
    Schema outputSchema = outputSchemaBuilder.build();

    Nd4j.getRandom().setSeed(42);

    ServingConfig servingConfig = ServingConfig.builder()
            .httpPort(port)
            .build();

    ModelStep modelPipelineStep = Dl4jStep.builder()
            .path(modelSave.getAbsolutePath())
            .build()
            .setInput(inputSchema)
            .setOutput(outputSchema);

    InferenceConfiguration inferenceConfiguration = InferenceConfiguration.builder()
            .servingConfig(servingConfig)
            .step(modelPipelineStep)
            .build();

    return new JsonObject(inferenceConfiguration.toJson());
}
 
Example 18
@Override
public JsonObject getConfigObject() throws Exception {
    Pair<MultiLayerNetwork, DataNormalization> multiLayerNetwork = getTrainedNetwork();
    File modelSave = new File(temporary.getRoot(), "model.zip");
    ModelSerializer.writeModel(multiLayerNetwork.getFirst(), modelSave, true);

    inputSchema = TrainUtils.getIrisInputSchema();
    Schema outputSchema = getIrisOutputSchema();
    Nd4j.getRandom().setSeed(42);

    TransformProcess.Builder transformProcessBuilder = new TransformProcess.Builder(inputSchema);
    for (int i = 0; i < inputSchema.numColumns(); i++) {
        transformProcessBuilder.convertToDouble(inputSchema.getName(i));
    }

    TransformProcess transformProcess = transformProcessBuilder.build();

    TransformProcessStep transformStep = new TransformProcessStep(transformProcess, outputSchema);

    ServingConfig servingConfig = ServingConfig.builder()
            .httpPort(port)
            .build();

    ModelStep modelStepConfig = Dl4jStep.builder().path(modelSave.getAbsolutePath()).build()
            .setInput(inputSchema)
            .setOutput(outputSchema);

    InferenceConfiguration inferenceConfiguration = InferenceConfiguration.builder()
            .servingConfig(servingConfig)
            .step(transformStep)
            .step(modelStepConfig)
            .build();

    System.out.println(inferenceConfiguration.toJson());
    return new JsonObject(inferenceConfiguration.toJson());
}
 
Example 19
Source Project: FancyBing   Source File: PolicyNetUtil.java    License: GNU General Public License v3.0
public static ComputationGraph loadComputationGraph(String fn) throws Exception {
    System.err.println("Loading model...");
//    File locationToSave = new File(System.getProperty("user.dir") + "/model/" + fn);
    File locationToSave = new File("D:\\workspace\\fancybing-train\\model\\" + fn);
    ComputationGraph model = ModelSerializer.restoreComputationGraph(locationToSave);

    return model;
}
 
Example 20
@Test
public void testIterationCountAndPersistence() throws IOException {
    Nd4j.getRandom().setSeed(123);
    ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).seed(123)
            .graphBuilder().addInputs("in")
            .addLayer("0", new DenseLayer.Builder().nIn(4).nOut(3).weightInit(WeightInit.XAVIER)
                    .activation(Activation.TANH).build(), "in")
            .addLayer("1", new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(
                            LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(3).nOut(3)
                            .build(),
                    "0")
            .setOutputs("1").build();


    ComputationGraph network = new ComputationGraph(conf);
    network.init();

    DataSetIterator iter = new IrisDataSetIterator(50, 150);

    assertEquals(0, network.getConfiguration().getIterationCount());
    network.fit(iter);
    assertEquals(3, network.getConfiguration().getIterationCount());
    iter.reset();
    network.fit(iter);
    assertEquals(6, network.getConfiguration().getIterationCount());
    iter.reset();
    network.fit(iter.next());
    assertEquals(7, network.getConfiguration().getIterationCount());

    ByteArrayOutputStream baos = new ByteArrayOutputStream();
    ModelSerializer.writeModel(network, baos, true);
    byte[] asBytes = baos.toByteArray();

    ByteArrayInputStream bais = new ByteArrayInputStream(asBytes);
    ComputationGraph net = ModelSerializer.restoreComputationGraph(bais, true);
    assertEquals(7, net.getConfiguration().getIterationCount());
}
 
Example 21
Source Project: FancyBing   Source File: PolicyNetService.java    License: GNU General Public License v3.0
private static ComputationGraph loadComputationGraph(String fn) throws Exception {
    File f = new File(System.getProperty("user.dir") + "/model/" + fn);
    System.out.println("Loading model " + f);
    ComputationGraph model = ModelSerializer.restoreComputationGraph(f);

    return model;
}
 
Example 22
Source Project: FancyBing   Source File: TrainUtil.java    License: GNU General Public License v3.0
public static String saveModel(String name, Model model, int index, int accuracy) throws Exception {
    System.err.println("Saving model, don't shutdown...");
    try {
        String fn = name + "_idx_" + index + "_" + accuracy + ".zip";
        File locationToSave = new File(System.getProperty("user.dir") + "/model/" + fn);
        boolean saveUpdater = true; //Updater: i.e., the state for Momentum, RMSProp, Adagrad etc. Save this if you want to train your network more in the future
        ModelSerializer.writeModel(model, locationToSave, saveUpdater);
        System.err.println("Model saved");
        return fn;
    } catch (IOException e) {
        System.err.println("Save model failed");
        e.printStackTrace();
        throw e;
    }
}
 
Example 23
Source Project: FancyBing   Source File: TrainUtil.java    License: GNU General Public License v3.0
public static MultiLayerNetwork loadNetwork(String fn, double learningRate) throws Exception {
    System.err.println("Loading model...");
    File locationToSave = new File(System.getProperty("user.dir") + "/model/" + fn);
    MultiLayerNetwork model = ModelSerializer.restoreMultiLayerNetwork(locationToSave);

    int numLayers = model.getnLayers();
    for (int i = 0; i < numLayers; i++) {
        model.getLayer(i).conf().setLearningRateByParam("W", learningRate);
        model.getLayer(i).conf().setLearningRateByParam("b", learningRate);
    }
    return model;
}
 
Example 24
Source Project: FancyBing   Source File: TrainUtil.java    License: GNU General Public License v3.0
public static ComputationGraph loadComputationGraph(String fn, double learningRate) throws Exception {
    System.err.println("Loading model...");
    File locationToSave = new File(System.getProperty("user.dir") + "/model/" + fn);
    ComputationGraph model = ModelSerializer.restoreComputationGraph(locationToSave);

    int numLayers = model.getNumLayers();
    for (int i = 0; i < numLayers; i++) {
        model.getLayer(i).conf().setLearningRateByParam("W", learningRate);
        model.getLayer(i).conf().setLearningRateByParam("b", learningRate);
    }

    return model;
}
 
Example 25
Source Project: deeplearning4j   Source File: RegressionTest060.java    License: Apache License 2.0
@Test
public void regressionTestMLP1() throws Exception {

    File f = Resources.asFile("regression_testing/060/060_ModelSerializer_Regression_MLP_1.zip");

    MultiLayerNetwork net = ModelSerializer.restoreMultiLayerNetwork(f, true);

    MultiLayerConfiguration conf = net.getLayerWiseConfigurations();
    assertEquals(2, conf.getConfs().size());

    DenseLayer l0 = (DenseLayer) conf.getConf(0).getLayer();
    assertEquals("relu", l0.getActivationFn().toString());
    assertEquals(3, l0.getNIn());
    assertEquals(4, l0.getNOut());
    assertEquals(new WeightInitXavier(), l0.getWeightInitFn());
    assertEquals(new Nesterovs(0.15, 0.9), l0.getIUpdater());
    assertEquals(0.15, ((Nesterovs)l0.getIUpdater()).getLearningRate(), 1e-6);

    OutputLayer l1 = (OutputLayer) conf.getConf(1).getLayer();
    assertEquals("softmax", l1.getActivationFn().toString());
    assertTrue(l1.getLossFn() instanceof LossMCXENT);
    assertEquals(4, l1.getNIn());
    assertEquals(5, l1.getNOut());
    assertEquals(new WeightInitXavier(), l1.getWeightInitFn());
    assertEquals(new Nesterovs(0.15, 0.9), l1.getIUpdater());
    assertEquals(0.9, ((Nesterovs)l1.getIUpdater()).getMomentum(), 1e-6);
    assertEquals(0.15, ((Nesterovs)l1.getIUpdater()).getLearningRate(), 1e-6);

    int numParams = (int)net.numParams();
    assertEquals(Nd4j.linspace(1, numParams, numParams, Nd4j.dataType()).reshape(1,numParams), net.params());
    int updaterSize = (int) new Nesterovs().stateSize(numParams);
    assertEquals(Nd4j.linspace(1, updaterSize, updaterSize, Nd4j.dataType()).reshape(1,numParams), net.getUpdater().getStateViewArray());
}