Java Code Examples for org.deeplearning4j.util.ModelSerializer#writeModel()

The following examples show how to use `org.deeplearning4j.util.ModelSerializer#writeModel()`. You can vote up the examples you find useful or vote down the ones you don't, and follow the link above each example to its original project or source file. Related API usage is listed in the sidebar.
Example 1
Source File: KerasCustomLayerTest.java — from deeplearning4j (Apache License 2.0), 6 votes
/**
 * Imports a Keras GoogLeNet (with custom PoolHelper/LRN layers) from an HDF5 file, saves it with
 * ModelSerializer (no updater state), and restores it as a ComputationGraph.
 */
@Ignore
@Test
public void testCustomLayerImport() throws Exception {
    // file paths
    String kerasWeightsAndConfigUrl = DL4JResources.getURLString("googlenet_keras_weightsandconfig.h5");
    // NOTE(review): testDir is presumably a JUnit TemporaryFolder, whose newFile() creates an EMPTY
    // file. exists() is therefore always true right after this call, so the download below must
    // also be triggered when the file has no content.
    File cachedKerasFile = testDir.newFile("googlenet_keras_weightsandconfig.h5");
    String outputPath = testDir.newFile("googlenet_dl4j_inference.zip").getAbsolutePath();

    // Keras layers without a built-in DL4J mapping must be registered before importing.
    KerasLayer.registerCustomLayer("PoolHelper", KerasPoolHelper.class);
    KerasLayer.registerCustomLayer("LRN", KerasLRN.class);

    // download file (missing, or created empty by newFile above)
    if (!cachedKerasFile.exists() || cachedKerasFile.length() == 0) {
        log.info("Downloading model to " + cachedKerasFile.toString());
        FileUtils.copyURLToFile(new URL(kerasWeightsAndConfigUrl), cachedKerasFile);
        cachedKerasFile.deleteOnExit();
    }

    org.deeplearning4j.nn.api.Model importedModel =
            KerasModelImport.importKerasModelAndWeights(cachedKerasFile.getAbsolutePath());
    // saveUpdater = false: only the configuration and parameters are needed for inference.
    ModelSerializer.writeModel(importedModel, outputPath, false);

    ComputationGraph serializedModel = ModelSerializer.restoreComputationGraph(outputPath);
    log.info(serializedModel.summary());
}
 
Example 2
Source File: Vasttext.java — from scava (Eclipse Public License 2.0), 6 votes
/**
 * Writes the trained network to {@code file} (without updater state) and then attaches the
 * Vasttext configuration map to the same file under the key "vasttext.config".
 *
 * @param file destination file
 * @throws IOException if writing the network or the configuration fails
 * @throws UnsupportedOperationException if neither network variant has been trained yet
 */
public void storeModel(File file) throws IOException
{
	if(vasttextText==null && vasttextTextAndNumeric==null)
		throw new UnsupportedOperationException("Train before store model");

	//Updaters are deliberately not stored (saveUpdater = false)
	final String modelKind;
	if(vasttextText!=null)
	{
		ModelSerializer.writeModel(vasttextText, file, false);
		modelKind = "onlyText";
	}
	else
	{
		ModelSerializer.writeModel(vasttextTextAndNumeric, file, false);
		modelKind = "textAndNumeric";
	}

	//Entries are inserted in the same order as before: model kind first, then the settings
	HashMap<String, Object> settings = new HashMap<String, Object>();
	settings.put("typeVasttext", modelKind);
	settings.put("multiLabelActivation", multiLabelActivation);
	settings.put("multiLabel", multiLabel);
	settings.put("dictionary", vectorizer.getDictionary());

	ModelSerializer.addObjectToFile(file, "vasttext.config", settings);
}
 
Example 3
Source File: InferenceExecutionerFactoryTests.java — from konduit-serving (Apache License 2.0), 6 votes
/**
 * Saves a trained MultiLayerNetwork (with updater state) to a zip and checks that the DL4J
 * executioner factory can load it into a MultiLayerNetworkInferenceExecutioner.
 */
@Test
public void testMultiLayerNetwork() throws Exception {
    // Train a small network and persist it to a throwaway zip in a fresh temp folder.
    MultiLayerNetwork network = TrainUtils.getTrainedNetwork().getLeft();
    File modelZip = new File(testDir.newFolder(), "dl4j_mln_model.zip");
    modelZip.deleteOnExit();
    ModelSerializer.writeModel(network, modelZip, true);

    // Point a DL4J pipeline step at the saved model and let the factory load it.
    ModelStep step = Dl4jStep.builder()
            .inputName("default")
            .outputName("output")
            .path(modelZip.getAbsolutePath())
            .build();
    InitializedInferenceExecutionerConfig loaded = new Dl4jInferenceExecutionerFactory().create(step);

    MultiLayerNetworkInferenceExecutioner executioner =
            (MultiLayerNetworkInferenceExecutioner) loaded.getInferenceExecutioner();
    assertNotNull(executioner);
    assertNotNull(executioner.model());
    assertNotNull(executioner.modelLoader());
}
 
Example 4
Source File: TestUtils.java — from deeplearning4j (Apache License 2.0), 6 votes
/**
 * Round-trips {@code net} (including updater state) through an in-memory ModelSerializer stream,
 * asserts that the configuration and parameters survive, checks that the configuration is
 * Java-serializable, and returns the restored copy.
 */
public static MultiLayerNetwork testModelSerialization(MultiLayerNetwork net){
    final MultiLayerNetwork restored;
    try {
        ByteArrayOutputStream buffer = new ByteArrayOutputStream();
        ModelSerializer.writeModel(net, buffer, true);
        restored = ModelSerializer.restoreMultiLayerNetwork(
                new ByteArrayInputStream(buffer.toByteArray()), true);
    } catch (IOException e){
        //Only in-memory streams are involved, so this should never happen
        throw new RuntimeException(e);
    }

    assertEquals(net.getLayerWiseConfigurations(), restored.getLayerWiseConfigurations());
    assertEquals(net.params(), restored.params());

    //The MultiLayerConfiguration must also be Java-serializable (required by Spark etc)
    serializeDeserializeJava(net.getLayerWiseConfigurations());

    return restored;
}
 
Example 5
Source File: Dl4jMlpClassifier.java — from wekaDeeplearning4j (GNU General Public License v3.0), 6 votes
/**
 * Custom serialization method.
 *
 * <p>Output order: the default serializable fields (after {@code modelSize} has been set to the
 * byte length of the serialized network, or 0 if initialization has not finished), then an array
 * of layer configuration strings, then the network itself (only if initialization has finished).
 *
 * @param oos the object output stream
 */
private void writeObject(ObjectOutputStream oos) throws IOException {
  // Dry-run the network serialization into a byte-counting sink. modelSize is assigned before
  // defaultWriteObject(), presumably so the size is part of the default fields — keep this order.
  CountingOutputStream sizeProbe = new CountingOutputStream(new NullOutputStream());
  if (isInitializationFinished) {
    ModelSerializer.writeModel(model, sizeProbe, false);
  }
  modelSize = sizeProbe.getByteCount();

  oos.defaultWriteObject();

  // One entry per layer: "<layer class name>::<joined weka options>"
  String[] layerConfigs = new String[layers.length];
  for (int idx = 0; idx < layerConfigs.length; idx++) {
    String className = layers[idx].getClass().getName();
    String options = weka.core.Utils.joinOptions(layers[idx].getOptions());
    layerConfigs[idx] = className + "::" + options;
  }
  oos.writeObject(layerConfigs);

  // Finally the network itself (no updater state)
  if (isInitializationFinished) {
    ModelSerializer.writeModel(model, oos, false);
  }
}
 
Example 6
Source File: TestUtils.java — from deeplearning4j (Apache License 2.0), 6 votes
/**
 * Serializes {@code net} (with updater state) to memory and restores it, asserting equal
 * configuration and parameters; also verifies the configuration survives Java serialization.
 *
 * @return the restored network
 */
public static MultiLayerNetwork testModelSerialization(MultiLayerNetwork net){
    final MultiLayerNetwork copy;
    try {
        ByteArrayOutputStream sink = new ByteArrayOutputStream();
        ModelSerializer.writeModel(net, sink, true);
        ByteArrayInputStream source = new ByteArrayInputStream(sink.toByteArray());
        copy = ModelSerializer.restoreMultiLayerNetwork(source, true);
    } catch (IOException e){
        //Purely in-memory I/O: should never happen
        throw new RuntimeException(e);
    }

    assertEquals(net.getLayerWiseConfigurations(), copy.getLayerWiseConfigurations());
    assertEquals(net.params(), copy.params());

    //Java-serializability of the MultiLayerConfiguration is required by Spark etc
    MultiLayerConfiguration layerConfs = net.getLayerWiseConfigurations();
    serializeDeserializeJava(layerConfs);

    return copy;
}
 
Example 7
Source File: MultiLayerTest.java — from deeplearning4j (Apache License 2.0), 5 votes
/**
 * Checks that the iteration counter stored in the layer-wise configuration advances with each
 * fit and survives a ModelSerializer save/restore round trip.
 */
@Test
public void testIterationCountAndPersistence() throws IOException {
    Nd4j.getRandom().setSeed(123);
    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
            .seed(123)
            .list()
            .layer(0, new DenseLayer.Builder()
                    .nIn(4).nOut(3)
                    .weightInit(WeightInit.XAVIER)
                    .activation(Activation.TANH)
                    .build())
            .layer(1, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(
                    LossFunctions.LossFunction.MCXENT)
                    .activation(Activation.SOFTMAX)
                    .nIn(3).nOut(3)
                    .build())
            .build();

    MultiLayerNetwork original = new MultiLayerNetwork(conf);
    original.init();

    // Each full pass is asserted to add 3 iterations (presumably 150 examples / batch size 50).
    DataSetIterator irisIter = new IrisDataSetIterator(50, 150);

    assertEquals(0, original.getLayerWiseConfigurations().getIterationCount());
    original.fit(irisIter);
    assertEquals(3, original.getLayerWiseConfigurations().getIterationCount());
    irisIter.reset();
    original.fit(irisIter);
    assertEquals(6, original.getLayerWiseConfigurations().getIterationCount());
    irisIter.reset();
    // Fitting a single DataSet adds exactly one iteration.
    original.fit(irisIter.next());
    assertEquals(7, original.getLayerWiseConfigurations().getIterationCount());

    // The counter must be preserved by serialization.
    ByteArrayOutputStream serialized = new ByteArrayOutputStream();
    ModelSerializer.writeModel(original, serialized, true);
    MultiLayerNetwork restored = ModelSerializer.restoreMultiLayerNetwork(
            new ByteArrayInputStream(serialized.toByteArray()), true);
    assertEquals(7, restored.getLayerWiseConfigurations().getIterationCount());
}
 
Example 8
Source File: InMemoryMultiLayernetworkModelLoader.java — from konduit-serving (Apache License 2.0), 5 votes
/**
 * Serializes {@code model}, including its updater state, into an in-memory Buffer.
 *
 * @param model the network to serialize
 * @return a Buffer holding the bytes written by {@code ModelSerializer.writeModel}
 * @throws java.io.UncheckedIOException if serialization fails
 */
@Override
public Buffer saveModel(@NonNull MultiLayerNetwork model) {
    ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream();
    try {
        ModelSerializer.writeModel(model, byteArrayOutputStream, true);
    } catch (IOException e) {
        // UncheckedIOException is still a RuntimeException (callers are unaffected), but it keeps the
        // failure identifiable as an I/O problem and adds context.
        throw new java.io.UncheckedIOException("Failed to serialize MultiLayerNetwork", e);
    }

    return Buffer.buffer(byteArrayOutputStream.toByteArray());
}
 
Example 9
Source File: TestUtils.java — from konduit-serving (Apache License 2.0), 5 votes
/**
 * Trains a small network, saves it to {@code model.zip} in {@code trainDir} (without updater
 * state) and returns an InferenceConfiguration whose single DL4J step serves that model with a
 * four-double-column input schema and a three-double-column output schema.
 */
public static InferenceConfiguration getConfig(TemporaryFolder trainDir) throws Exception {
    Pair<MultiLayerNetwork, DataNormalization> trained = TrainUtils.getTrainedNetwork();
    File modelFile = trainDir.newFile("model.zip");
    ModelSerializer.writeModel(trained.getFirst(), modelFile, false);

    // NOTE(review): these names differ from the usual Iris feature names (e.g. "sepal_height");
    // kept verbatim since other code may rely on them.
    Schema inputSchema = new Schema.Builder()
            .addColumnDouble("petal_length")
            .addColumnDouble("petal_width")
            .addColumnDouble("sepal_width")
            .addColumnDouble("sepal_height")
            .build();

    Schema outputSchema = new Schema.Builder()
            .addColumnDouble("setosa")
            .addColumnDouble("versicolor")
            .addColumnDouble("virginica")
            .build();

    ServingConfig serving = ServingConfig.builder()
            .createLoggingEndpoints(true)
            .build();

    Dl4jStep step = Dl4jStep.builder()
            .inputName("default")
            .inputColumnName("default", SchemaTypeUtils.columnNames(inputSchema))
            .inputSchema("default", SchemaTypeUtils.typesForSchema(inputSchema))
            .outputSchema("default", SchemaTypeUtils.typesForSchema(outputSchema))
            .path(modelFile.getAbsolutePath())
            .outputColumnName("default", SchemaTypeUtils.columnNames(outputSchema))
            .build();

    return InferenceConfiguration.builder()
            .servingConfig(serving)
            .step(step)
            .build();
}
 
Example 10
Source File: PortsTest.java — from konduit-serving (Apache License 2.0), 5 votes
/**
 * Trains a small network and saves it (without updater state) as {@code model.zip} in the shared
 * temporary folder.
 *
 * @return absolute path of the saved model file
 */
private static String trainAndSaveModel() throws Exception {
    MultiLayerNetwork trained = TrainUtils.getTrainedNetwork().getFirst();
    File target = folder.newFile("model.zip");
    ModelSerializer.writeModel(trained, target, false);
    return target.getAbsolutePath();
}
 
Example 11
Source File: TrainUtil.java — from FancyBing (GNU General Public License v3.0), 5 votes
/**
 * Saves {@code model}, including its updater state, to
 * {@code <user.dir>/model/<name>_idx_<index>_<accuracy>.zip}, creating the {@code model}
 * directory if it does not exist yet.
 *
 * @return the file name (not the full path) of the saved model
 * @throws IOException if the model directory cannot be created or the model cannot be written
 */
public static String saveModel(String name, Model model, int index, int accuracy) throws Exception {
    System.err.println("Saving model, don't shutdown...");
    try {
        String fn = name + "_idx_" + index + "_" + accuracy + ".zip";
        File locationToSave = new File(System.getProperty("user.dir") + "/model/" + fn);
        // Writing into a missing directory fails, so make sure "<user.dir>/model" exists first.
        // (mkdirs() may return false if another thread created it, hence the isDirectory() check.)
        File modelDir = locationToSave.getParentFile();
        modelDir.mkdirs();
        if (!modelDir.isDirectory()) {
            throw new IOException("Could not create model directory " + modelDir.getAbsolutePath());
        }
        // Updater: i.e., the state for Momentum, RMSProp, Adagrad etc. Save this if you want to
        // train your network more in the future
        boolean saveUpdater = true;
        ModelSerializer.writeModel(model, locationToSave, saveUpdater);
        System.err.println("Model saved");
        return fn;
    } catch (IOException e) {
        System.err.println("Save model failed");
        e.printStackTrace();
        throw e;
    }
}
 
Example 12
Source File: ActorCriticCompGraph.java — from deeplearning4j (Apache License 2.0), 4 votes
/**
 * Writes the computation graph {@code cg}, including its updater state, to {@code stream}.
 */
public void save(OutputStream stream) throws IOException {
    final boolean includeUpdater = true;
    ModelSerializer.writeModel(cg, stream, includeUpdater);
}
 
Example 13
Source File: LocalFileModelSaver.java — from deeplearning4j (Apache License 2.0), 4 votes
/**
 * Writes {@code net}, including its updater state, to the file path {@code modelName}.
 */
private void save(MultiLayerNetwork net, String modelName) throws IOException {
    final boolean includeUpdater = true;
    ModelSerializer.writeModel(net, modelName, includeUpdater);
}
 
Example 14
Source File: DeepAutoEncoderExample.java — from Java-for-Data-Science (MIT License), 4 votes
/**
 * Builds a 10-layer deep auto-encoder (RBM stack narrowing to 30 units and widening back), trains
 * it on MNIST batches using each batch's features as both input and target, and saves the result
 * with updater state to the file "savedModel". Any IOException is printed and otherwise ignored.
 */
public DeepAutoEncoderExample() {
    try {
        final int seed = 123;
        final int numberOfIterations = 1;
        iterator = new MnistDataSetIterator(1000, MnistDataFetcher.NUM_EXAMPLES, true);

        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .seed(seed)
                .iterations(numberOfIterations)
                .optimizationAlgo(OptimizationAlgorithm.LINE_GRADIENT_DESCENT)
                .list()
                .layer(0, new RBM.Builder().nIn(numberOfRows * numberOfColumns)
                        .nOut(1000)
                        .lossFunction(LossFunctions.LossFunction.RMSE_XENT).build())
                .layer(1, new RBM.Builder().nIn(1000).nOut(500)
                        .lossFunction(LossFunctions.LossFunction.RMSE_XENT).build())
                .layer(2, new RBM.Builder().nIn(500).nOut(250)
                        .lossFunction(LossFunctions.LossFunction.RMSE_XENT).build())
                .layer(3, new RBM.Builder().nIn(250).nOut(100)
                        .lossFunction(LossFunctions.LossFunction.RMSE_XENT).build())
                .layer(4, new RBM.Builder().nIn(100).nOut(30)
                        .lossFunction(LossFunctions.LossFunction.RMSE_XENT).build()) //encoding stops
                .layer(5, new RBM.Builder().nIn(30).nOut(100)
                        .lossFunction(LossFunctions.LossFunction.RMSE_XENT).build()) //decoding starts
                .layer(6, new RBM.Builder().nIn(100).nOut(250)
                        .lossFunction(LossFunctions.LossFunction.RMSE_XENT).build())
                .layer(7, new RBM.Builder().nIn(250).nOut(500)
                        .lossFunction(LossFunctions.LossFunction.RMSE_XENT).build())
                .layer(8, new RBM.Builder().nIn(500).nOut(1000)
                        .lossFunction(LossFunctions.LossFunction.RMSE_XENT).build())
                .layer(9, new OutputLayer.Builder(
                                LossFunctions.LossFunction.RMSE_XENT).nIn(1000)
                        .nOut(numberOfRows * numberOfColumns).build())
                .pretrain(true).backprop(true)
                .build();

        model = new MultiLayerNetwork(conf);
        model.init();

        IterationListener scoreListener = new ScoreIterationListener();
        model.setListeners(Collections.singletonList(scoreListener));

        // Auto-encoder training: the target of each batch is its own feature matrix.
        while (iterator.hasNext()) {
            DataSet batch = iterator.next();
            model.fit(new DataSet(batch.getFeatureMatrix(), batch.getFeatureMatrix()));
        }

        final boolean saveUpdater = true;
        modelFile = new File("savedModel");
        ModelSerializer.writeModel(model, modelFile, saveUpdater);
    } catch (IOException ex) {
        ex.printStackTrace();
    }
}
 
Example 15
Source File: ParallelWrapperMainTest.java — from deeplearning4j (Apache License 2.0), 4 votes
/**
 * Saves an untrained LeNet-style CNN (without updater state) and runs ParallelWrapperMain on it
 * with an MNIST iterator factory, writing the output model to a temp file.
 */
@Test
public void runParallelWrapperMain() throws Exception {

    int nChannels = 1;
    int outputNum = 10;

    // for GPU you usually want to have higher batchSize
    int batchSize = 128;
    int seed = 123;
    int uiPort = 9500;

    // The UI port is passed via a JVM-wide system property. Remember the previous value and restore
    // it afterwards so this test does not leak the setting into other tests in the same JVM.
    String uiPortProperty = "org.deeplearning4j.ui.port";
    String previousUiPort = System.getProperty(uiPortProperty);
    System.setProperty(uiPortProperty, String.valueOf(uiPort));
    try {
        log.info("Load data....");
        // NOTE(review): these iterators are not used below; presumably they make sure MNIST is
        // available before ParallelWrapperMain runs — confirm before removing.
        DataSetIterator mnistTrain = new MnistDataSetIterator(batchSize, true, 12345);
        DataSetIterator mnistTest = new MnistDataSetIterator(batchSize, false, 12345);

        log.info("Build model....");
        MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder().seed(seed)
                        .l2(0.0005)
                        .weightInit(WeightInit.XAVIER)
                        .updater(new Nesterovs(0.01, 0.9)).list()
                        .layer(0, new ConvolutionLayer.Builder(5, 5)
                                        //nIn and nOut specify channels. nIn here is the nChannels and nOut is the number of filters to be applied
                                        .nIn(nChannels).stride(1, 1).nOut(20).activation(Activation.IDENTITY).build())
                        .layer(1, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX).kernelSize(2, 2)
                                        .stride(2, 2).build())
                        .layer(2, new ConvolutionLayer.Builder(5, 5)
                                        //Note that nIn need not be specified in later layers (setInputType below provides it)
                                        .stride(1, 1).nOut(50).activation(Activation.IDENTITY).build())
                        .layer(3, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX).kernelSize(2, 2)
                                        .stride(2, 2).build())
                        .layer(4, new DenseLayer.Builder().activation(Activation.RELU).nOut(500).build())
                        .layer(5, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                                        .nOut(outputNum).activation(Activation.SOFTMAX).build())
                        .setInputType(InputType.convolutionalFlat(28, 28, nChannels));

        MultiLayerConfiguration conf = builder.build();
        MultiLayerNetwork model = new MultiLayerNetwork(conf);
        model.init();
        File tempModel = testDir.newFile("tmpmodel.zip");
        tempModel.deleteOnExit();
        ModelSerializer.writeModel(model, tempModel, false);
        File tmp = testDir.newFile("tmpmodel.bin");
        tmp.deleteOnExit();
        ParallelWrapperMain parallelWrapperMain = new ParallelWrapperMain();
        try {
            parallelWrapperMain.runMain(new String[]{"--modelPath", tempModel.getAbsolutePath(),
                    "--dataSetIteratorFactoryClazz", MnistDataSetIteratorProviderFactory.class.getName(),
                    "--modelOutputPath", tmp.getAbsolutePath(), "--uiUrl", "localhost:" + uiPort});
        } finally {
            parallelWrapperMain.stop();
        }
    } finally {
        if (previousUiPort == null) {
            System.clearProperty(uiPortProperty);
        } else {
            System.setProperty(uiPortProperty, previousUiPort);
        }
    }
}
 
Example 16
Source File: LocalFileGraphSaver.java — from deeplearning4j (Apache License 2.0), 4 votes
/**
 * Writes the ComputationGraph {@code net}, including its updater state, to the file path
 * {@code confOut}.
 */
private void save(ComputationGraph net, String confOut) throws IOException {
    final boolean includeUpdater = true;
    ModelSerializer.writeModel(net, confOut, includeUpdater);
}
 
Example 17
Source File: ModelTupleStreamTest.java — from deeplearning4j (Apache License 2.0), 4 votes
/**
 * Serializes {@code originalModel} to a temp file, reloads it with ModelGuesser, and for every
 * input row from {@code floatsList(numInputs)} checks three things:
 * <ul>
 *   <li>the original and restored models agree within 1e-5;</li>
 *   <li>a {@code model(...)} streaming expression over the file yields one tuple whose output
 *       values match the original model's outputs within 1e-5, followed by an EOF tuple;</li>
 *   <li>the stream's expression and explanation round-trip correctly.</li>
 * </ul>
 *
 * @param originalModel the model under test
 * @param numInputs number of input keys fed to the stream
 * @param numOutputs number of output keys read back from the tuple
 */
private void doTest(Model originalModel, int numInputs, int numOutputs) throws Exception {

    final Path tempDirPath = Files.createTempDirectory(null);
    final File tempDirFile = tempDirPath.toFile();
    tempDirFile.deleteOnExit();

    final SolrResourceLoader solrResourceLoader = new SolrResourceLoader(tempDirPath);

    final File tempFile = File.createTempFile("prefix", "suffix", tempDirFile);
    tempFile.deleteOnExit();

    final String serializedModelFileName = tempFile.getPath();

    ModelSerializer.writeModel(originalModel, serializedModelFileName, false);

    final Model restoredModel = ModelGuesser.loadModelGuess(serializedModelFileName);

    final StreamContext streamContext = new StreamContext();
    final SolrClientCache solrClientCache = new SolrClientCache();
    streamContext.setSolrClientCache(solrClientCache);

    final String[] inputKeys = new String[numInputs];
    final String inputKeysList = fillArray(inputKeys, "input", ",");

    final String[] outputKeys = new String[numOutputs];
    final String outputKeysList = fillArray(outputKeys, "output", ",");

    try {
        for (final float[] floats : floatsList(numInputs)) {

            final String inputValuesList;
            {
                final StringBuilder sb = new StringBuilder();
                for (int ii = 0; ii < inputKeys.length; ++ii) {
                    if (0 < ii) sb.append(',');
                    sb.append(inputKeys[ii]).append('=').append(floats[ii]);
                }
                inputValuesList = sb.toString();
            }

            final StreamFactory streamFactory = new SolrDefaultStreamFactory()
                    .withSolrResourceLoader(solrResourceLoader)
                    .withFunctionName("model", ModelTupleStream.class);

            final StreamExpression streamExpression = StreamExpressionParser.parse("model("
                    + "tuple(" + inputValuesList + ")"
                    + ",serializedModelFileName=\"" + serializedModelFileName + "\""
                    + ",inputKeys=\"" + inputKeysList + "\""
                    + ",outputKeys=\"" + outputKeysList + "\""
                    + ")");

            final TupleStream tupleStream = streamFactory.constructStream(streamExpression);
            tupleStream.setStreamContext(streamContext);

            assertTrue(tupleStream instanceof ModelTupleStream);
            final ModelTupleStream modelTupleStream = (ModelTupleStream) tupleStream;

            modelTupleStream.open();
            try {
                final Tuple tuple1 = modelTupleStream.read();
                assertNotNull(tuple1);
                assertFalse(tuple1.EOF);

                // The model outputs depend only on this input row, not on the output index,
                // so compute them once instead of once per output key.
                final INDArray inputs = Nd4j.create(new float[][] { floats });
                final INDArray originalOutput = NetworkUtils.output(originalModel, inputs);
                final INDArray restoredOutput = NetworkUtils.output(restoredModel, inputs);

                for (int ii = 0; ii < outputKeys.length; ++ii) {
                    final double originalScore = originalOutput.getDouble(ii);
                    final double restoredScore = restoredOutput.getDouble(ii);
                    assertEquals(
                            originalModel.getClass().getSimpleName()+" (originalScore-restoredScore)="+(originalScore-restoredScore),
                            originalScore, restoredScore, 1e-5);

                    final Double outputValue = tuple1.getDouble(outputKeys[ii]);
                    assertNotNull(outputValue);
                    final double tupleScore = outputValue.doubleValue();
                    assertEquals(
                            originalModel.getClass().getSimpleName()+" (originalScore-tupleScore["+ii+"])="+(originalScore-tupleScore),
                            originalScore, tupleScore, 1e-5);
                }

                final Tuple tuple2 = modelTupleStream.read();
                assertNotNull(tuple2);
                assertTrue(tuple2.EOF);
            } finally {
                // Close the stream even when an assertion above fails.
                modelTupleStream.close();
            }

            doToExpressionTest(streamExpression,
                    modelTupleStream.toExpression(streamFactory),
                    inputKeys.length);

            doToExplanationTest(modelTupleStream.toExplanation(streamFactory));
        }
    } finally {
        // Release any clients the streams may have opened through the shared cache.
        solrClientCache.close();
    }
}
 
Example 18
Source File: ActorCriticSeparate.java — from deeplearning4j (Apache License 2.0), 4 votes
/**
 * Writes the value network to {@code streamValue} and then the policy network to
 * {@code streamPolicy}, both including updater state.
 */
public void save(OutputStream streamValue, OutputStream streamPolicy) throws IOException {
    final boolean withUpdater = true;
    ModelSerializer.writeModel(valueNet, streamValue, withUpdater);
    ModelSerializer.writeModel(policyNet, streamPolicy, withUpdater);
}
 
Example 19
Source File: Word2VecTestsSmall.java — from deeplearning4j (Apache License 2.0), 4 votes
@Test(timeout = 300000)
public void testW2VEmbeddingLayerInit() throws Exception {
    Nd4j.setDefaultDataTypes(DataType.FLOAT, DataType.FLOAT);

    // --- Train a tiny Word2Vec model on the raw sentences corpus ---
    val inputFile = Resources.asFile("big/raw_sentences.txt");
    val iter = ParagraphVectorsTest.getIterator(isIntegrationTests(), inputFile);
    val tokenizerFactory = new DefaultTokenizerFactory();
    tokenizerFactory.setTokenPreProcessor(new CommonPreprocessor());

    Word2Vec vec = new Word2Vec.Builder()
            .minWordFrequency(1)
            .epochs(1)
            .layerSize(300)
            // NOTE(review): value is 1 but the original comment said "Limit the vocab size to 2
            // words" — presumably 1 word plus UNK; confirm.
            .limitVocabularySize(1)
            .windowSize(5)
            .allowParallelTokenization(true)
            .batchSize(512)
            .learningRate(0.025)
            .minLearningRate(0.0001)
            .negativeSample(0.0)
            .sampling(0.0)
            .useAdaGrad(false)
            .useHierarchicSoftmax(true)
            .iterations(1)
            .useUnknown(true) // Using UNK with limited vocab size causes the issue
            .seed(42)
            .iterate(iter)
            .workers(4)
            .tokenizerFactory(tokenizerFactory).build();
    vec.fit();

    INDArray w = vec.lookupTable().getWeights();
    System.out.println(w);

    // --- An EmbeddingLayer initialised from the Word2Vec model must start with exactly its weights ---
    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
            .seed(12345).list()
            .layer(new EmbeddingLayer.Builder().weightInit(vec).build())
            .layer(new DenseLayer.Builder().activation(Activation.TANH).nIn(w.size(1)).nOut(3).build())
            .layer(new OutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE).nIn(3)
                    .nOut(4).build())
            .build();

    final MultiLayerNetwork net = new MultiLayerNetwork(conf);
    net.init();
    assertEquals(w, net.getParam("0_W"));

    // --- The serialization round trip keeps the configuration and (within 2e-3) the parameters ---
    ByteArrayOutputStream serialized = new ByteArrayOutputStream();
    ModelSerializer.writeModel(net, serialized, true);
    final MultiLayerNetwork restored = ModelSerializer.restoreMultiLayerNetwork(
            new ByteArrayInputStream(serialized.toByteArray()), true);

    assertEquals(net.getLayerWiseConfigurations(), restored.getLayerWiseConfigurations());
    assertTrue(net.params().equalsWithEps(restored.params(), 2e-3));
}
 
Example 20
Source File: MultiLayerNetwork.java — from deeplearning4j (Apache License 2.0), 2 votes
/**
 * Save the MultiLayerNetwork to a file. Restore using {@link #load(File, boolean)}.
 *
 * @param f File to save the network to
 * @param saveUpdater If true: save the updater (i.e., the state array for momentum/Adam/rmsprop etc), which should
 *                    usually be saved if further training is required
 * @throws IOException if the network cannot be written to {@code f}
 * @see ModelSerializer ModelSerializer for more details (and saving/loading via streams)
 * @see #load(File, boolean)
 */
public void save(File f, boolean saveUpdater) throws IOException{
    ModelSerializer.writeModel(this, f, saveUpdater);
}