Java Code Examples for org.deeplearning4j.util.ModelSerializer#restoreComputationGraph()

The following examples show how to use org.deeplearning4j.util.ModelSerializer#restoreComputationGraph(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may check out the related API usage on the sidebar.
Example 1
Source File: Vasttext.java    From scava with Eclipse Public License 2.0 6 votes vote down vote up
/**
 * Restores this instance's state from a saved model file.
 *
 * <p>The settings map is read from the "vasttext.config" entry of {@code file};
 * the network itself is then restored from the same file, either as a
 * computation graph or as a multi-layer network depending on the stored
 * "typeVasttext" value.
 *
 * @param file the saved model file
 * @throws UnsupportedOperationException if the stored model type is neither
 *         "textAndNumeric" nor "onlyText" (compared case-insensitively)
 */
@SuppressWarnings("unchecked")
public void loadModel(File file) throws FileNotFoundException, ClassNotFoundException, IOException
{
	HashMap<String, Object> settings =
			(HashMap<String, Object>) ModelSerializer.getObjectFromFile(file, "vasttext.config");

	multiLabel = (Boolean) settings.get("multiLabel");

	// The vectorizer is rebuilt from the stored dictionary; the labels come from it.
	vectorizer = new VasttextTextVectorizer();
	vectorizer.loadDictionary(settings.get("dictionary"));
	labels = vectorizer.getLabels();
	labelsSize = labels.size();

	typeVasttext = (String) settings.get("typeVasttext");
	multiLabelActivation = (Double) settings.get("multiLabelActivation");

	if (typeVasttext.equalsIgnoreCase("textAndNumeric"))
	{
		vasttextTextAndNumeric = ModelSerializer.restoreComputationGraph(file);
		return;
	}
	if (typeVasttext.equalsIgnoreCase("onlyText"))
	{
		vasttextText = ModelSerializer.restoreMultiLayerNetwork(file);
		return;
	}
	throw new UnsupportedOperationException("Unknown type of model.");
}
 
Example 2
Source File: TestUtils.java    From deeplearning4j with Apache License 2.0 6 votes vote down vote up
/**
 * Writes {@code net} to an in-memory buffer, restores it, and asserts that the
 * restored graph has an equal configuration and equal parameters. Also passes
 * the configuration through {@code serializeDeserializeJava}.
 *
 * @param net the graph to round-trip
 * @return the graph restored from the serialized bytes
 */
public static ComputationGraph testModelSerialization(ComputationGraph net){
    ComputationGraph roundTripped;
    try {
        ByteArrayOutputStream sink = new ByteArrayOutputStream();
        ModelSerializer.writeModel(net, sink, true);

        roundTripped = ModelSerializer.restoreComputationGraph(
                new ByteArrayInputStream(sink.toByteArray()), true);

        assertEquals(net.getConfiguration(), roundTripped.getConfiguration());
        assertEquals(net.params(), roundTripped.params());
    } catch (IOException ioe){
        // Only in-memory streams are involved, so this is not expected.
        throw new RuntimeException(ioe);
    }

    // The configuration must also survive Java serialization (required by Spark etc).
    serializeDeserializeJava(net.getConfiguration());

    return roundTripped;
}
 
Example 3
Source File: DLModel.java    From java-ml-projects with Apache License 2.0 6 votes vote down vote up
/**
 * Loads a model from {@code file}, first as a computation graph and, if that
 * fails, as a multi-layer network.
 *
 * @param file the serialized model file
 * @return a {@code DLModel} wrapping whichever model type loaded successfully
 * @throws Exception the computation-graph failure when both attempts fail; the
 *         multi-layer-network failure is attached to it as a suppressed
 *         exception so neither cause is lost
 */
public static DLModel fromFile(File file) throws Exception {
	Model model;
	try {
		System.out.println("Trying to load file as computation graph: " + file);
		model = ModelSerializer.restoreComputationGraph(file);
		System.out.println("Loaded Computation Graph.");
	} catch (Exception graphFailure) {
		try {
			System.out.println("Failed to load computation graph. Trying to load model.");
			model = ModelSerializer.restoreMultiLayerNetwork(file);
			System.out.println("Loaded Multilayernetwork");
		} catch (Exception networkFailure) {
			System.out.println("Give up trying to load file: " + file);
			// Keep the second failure visible in the stack trace of the first.
			// addSuppressed rejects self-suppression, hence the identity check.
			if (networkFailure != graphFailure) {
				graphFailure.addSuppressed(networkFailure);
			}
			throw graphFailure;
		}
	}
	return new DLModel(model);
}
 
Example 4
Source File: TestUtils.java    From deeplearning4j with Apache License 2.0 6 votes vote down vote up
/**
 * Serializes {@code net} into a byte array, deserializes it again, and checks
 * that configuration and parameters of the copy equal those of the input. The
 * configuration is additionally passed to {@code serializeDeserializeJava}.
 *
 * @param net the graph under test
 * @return the deserialized copy
 */
public static ComputationGraph testModelSerialization(ComputationGraph net){
    ComputationGraph copy;
    try {
        ByteArrayOutputStream buffer = new ByteArrayOutputStream();
        ModelSerializer.writeModel(net, buffer, true);
        byte[] serialized = buffer.toByteArray();

        ByteArrayInputStream source = new ByteArrayInputStream(serialized);
        copy = ModelSerializer.restoreComputationGraph(source, true);

        assertEquals(net.getConfiguration(), copy.getConfiguration());
        assertEquals(net.params(), copy.params());
    } catch (IOException unexpected){
        // Should never happen
        throw new RuntimeException(unexpected);
    }

    // Also check the ComputationGraphConfiguration is serializable (required by Spark etc)
    ComputationGraphConfiguration graphConf = net.getConfiguration();
    serializeDeserializeJava(graphConf);

    return copy;
}
 
Example 5
Source File: Vgg16DeepLearning4jClassifier.java    From vision4j-collection with MIT License 5 votes vote down vote up
/**
 * Restores the network from {@code computationGraph} and sets up the
 * pre-processor, the 224x224x3 image loader and the category list.
 *
 * @param computationGraph the serialized computation graph file
 * @throws IOException if the model cannot be restored
 */
private void init(File computationGraph) throws IOException {
    this.vgg16 = ModelSerializer.restoreComputationGraph(computationGraph);
    this.scaler = new VGG16ImagePreProcessor();

    this.imageSize = new ImageSize(224, 224, 3);
    this.imageLoader = new NativeImageLoader(
            imageSize.getHeight(), imageSize.getWidth(), imageSize.channels());

    // The returned labels are not used here; the call is retained as-is.
    // NOTE(review): presumably safe to remove if it has no side effects - confirm.
    ImageNetLabels.getLabels();

    // One Category per ImageNet name, indexed by its position in the array.
    String[] names = Constants.IMAGENET_CATEGORIES;
    ArrayList<Category> known = new ArrayList<>();
    for (int index = 0; index < names.length; index++) {
        known.add(new Category(names[index], index));
    }
    this.categories = new Categories(known);
}
 
Example 6
Source File: RegressionTest060.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
/**
 * Restores the 0.6.0 CG LSTM regression archive and checks the three vertices
 * ("0", "1", "2") for the expected layer types and settings.
 */
@Test
public void regressionTestCGLSTM1() throws Exception {
    File archive = Resources.asFile("regression_testing/060/060_ModelSerializer_Regression_CG_LSTM_1.zip");
    ComputationGraph restored = ModelSerializer.restoreComputationGraph(archive, true);
    ComputationGraphConfiguration graphConf = restored.getConfiguration();

    assertEquals(3, graphConf.getVertices().size());

    // Vertex "0": GravesLSTM
    LayerVertex firstVertex = (LayerVertex) graphConf.getVertices().get("0");
    GravesLSTM lstm = (GravesLSTM) firstVertex.getLayerConf().getLayer();
    assertEquals("tanh", lstm.getActivationFn().toString());
    assertEquals(3, lstm.getNIn());
    assertEquals(4, lstm.getNOut());
    assertEquals(GradientNormalization.ClipElementWiseAbsoluteValue, lstm.getGradientNormalization());
    assertEquals(1.5, lstm.getGradientNormalizationThreshold(), 1e-5);

    // Vertex "1": GravesBidirectionalLSTM
    LayerVertex secondVertex = (LayerVertex) graphConf.getVertices().get("1");
    GravesBidirectionalLSTM biLstm = (GravesBidirectionalLSTM) secondVertex.getLayerConf().getLayer();
    assertEquals("softsign", biLstm.getActivationFn().toString());
    assertEquals(4, biLstm.getNIn());
    assertEquals(4, biLstm.getNOut());
    assertEquals(GradientNormalization.ClipElementWiseAbsoluteValue, biLstm.getGradientNormalization());
    assertEquals(1.5, biLstm.getGradientNormalizationThreshold(), 1e-5);

    // Vertex "2": RnnOutputLayer
    LayerVertex thirdVertex = (LayerVertex) graphConf.getVertices().get("2");
    RnnOutputLayer output = (RnnOutputLayer) thirdVertex.getLayerConf().getLayer();
    assertEquals(4, output.getNIn());
    assertEquals(5, output.getNOut());
    assertEquals("softmax", output.getActivationFn().toString());
    assertTrue(output.getLossFn() instanceof LossMCXENT);
}
 
Example 7
Source File: Classifier.java    From java-ml-projects with Apache License 2.0 5 votes vote down vote up
/**
 * Reads the classifier settings from {@code Properties}, creates the image
 * loader from the three-element input format, then restores and initializes
 * the model.
 *
 * @throws IOException if the model cannot be restored
 */
public static void init() throws IOException {
	String location = Properties.classifierModelPath();
	labels = Properties.classifierLabels();

	int[] inputFormat = Properties.classifierInputFormat();
	int first = inputFormat[0];
	int second = inputFormat[1];
	int third = inputFormat[2];
	loader = new NativeImageLoader(first, second, third);

	model = ModelSerializer.restoreComputationGraph(location);
	model.init();
}
 
Example 8
Source File: RegressionTest080.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
/**
 * Restores the 0.8.0 CG LSTM regression archive and checks the three vertices
 * ("0", "1", "2") for the expected layer types, activations and settings.
 */
@Test
public void regressionTestCGLSTM1() throws Exception {
    File archive = Resources.asFile("regression_testing/080/080_ModelSerializer_Regression_CG_LSTM_1.zip");
    ComputationGraph restored = ModelSerializer.restoreComputationGraph(archive, true);
    ComputationGraphConfiguration graphConf = restored.getConfiguration();

    assertEquals(3, graphConf.getVertices().size());

    // Vertex "0": GravesLSTM
    LayerVertex firstVertex = (LayerVertex) graphConf.getVertices().get("0");
    GravesLSTM lstm = (GravesLSTM) firstVertex.getLayerConf().getLayer();
    assertTrue(lstm.getActivationFn() instanceof ActivationTanH);
    assertEquals(3, lstm.getNIn());
    assertEquals(4, lstm.getNOut());
    assertEquals(GradientNormalization.ClipElementWiseAbsoluteValue, lstm.getGradientNormalization());
    assertEquals(1.5, lstm.getGradientNormalizationThreshold(), 1e-5);

    // Vertex "1": GravesBidirectionalLSTM
    LayerVertex secondVertex = (LayerVertex) graphConf.getVertices().get("1");
    GravesBidirectionalLSTM biLstm = (GravesBidirectionalLSTM) secondVertex.getLayerConf().getLayer();
    assertTrue(biLstm.getActivationFn() instanceof ActivationSoftSign);
    assertEquals(4, biLstm.getNIn());
    assertEquals(4, biLstm.getNOut());
    assertEquals(GradientNormalization.ClipElementWiseAbsoluteValue, biLstm.getGradientNormalization());
    assertEquals(1.5, biLstm.getGradientNormalizationThreshold(), 1e-5);

    // Vertex "2": RnnOutputLayer
    LayerVertex thirdVertex = (LayerVertex) graphConf.getVertices().get("2");
    RnnOutputLayer output = (RnnOutputLayer) thirdVertex.getLayerConf().getLayer();
    assertEquals(4, output.getNIn());
    assertEquals(5, output.getNOut());
    assertTrue(output.getActivationFn() instanceof ActivationSoftmax);
    assertTrue(output.getLossFn() instanceof LossMCXENT);
}
 
Example 9
Source File: LocalFileNetResultReference.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
/**
 * Restores the model stored in {@code modelFile}: a computation graph when
 * {@code isGraph} is set, otherwise a multi-layer network. Both are restored
 * with the boolean flag set to {@code false}.
 *
 * @return the restored model
 * @throws IOException if the model cannot be read
 */
@Override
public Object getResultModel() throws IOException {
    if (!isGraph) {
        return ModelSerializer.restoreMultiLayerNetwork(modelFile, false);
    }
    return ModelSerializer.restoreComputationGraph(modelFile, false);
}
 
Example 10
Source File: TrainUtil.java    From FancyBing with GNU General Public License v3.0 5 votes vote down vote up
/**
 * Restores the graph stored at {@code <user.dir>/model/<fn>} and sets the
 * learning rate of the "W" and "b" parameters on every layer.
 *
 * @param fn           file name of the model inside the {@code model} directory
 * @param learningRate learning rate applied to each layer's "W" and "b" parameters
 * @return the restored graph
 * @throws Exception if the model cannot be restored
 */
public static ComputationGraph loadComputationGraph(String fn, double learningRate) throws Exception {
    System.err.println("Loading model...");
    File modelFile = new File(System.getProperty("user.dir") + "/model/" + fn);
    ComputationGraph graph = ModelSerializer.restoreComputationGraph(modelFile);

    for (int idx = 0, total = graph.getNumLayers(); idx < total; idx++) {
        graph.getLayer(idx).conf().setLearningRateByParam("W", learningRate);
        graph.getLayer(idx).conf().setLearningRateByParam("b", learningRate);
    }

    return graph;
}
 
Example 11
Source File: PolicyNetService.java    From FancyBing with GNU General Public License v3.0 5 votes vote down vote up
/**
 * Restores the graph stored at {@code <user.dir>/model/<fn>}, printing the
 * resolved file to standard output first.
 *
 * @param fn file name of the model inside the {@code model} directory
 * @return the restored graph
 * @throws Exception if the model cannot be restored
 */
private static ComputationGraph loadComputationGraph(String fn) throws Exception {
    File modelFile = new File(System.getProperty("user.dir") + "/model/" + fn);
    System.out.println("Loading model " + modelFile);
    return ModelSerializer.restoreComputationGraph(modelFile);
}
 
Example 12
Source File: PolicyNetUtil.java    From FancyBing with GNU General Public License v3.0 5 votes vote down vote up
/**
 * Restores the graph stored at {@code <user.dir>/model/<fn>}.
 *
 * @param fn file name of the model inside the {@code model} directory
 * @return the restored graph
 * @throws Exception if the model cannot be restored
 */
public static ComputationGraph loadComputationGraph(String fn) throws Exception {
    System.err.println("Loading model...");
    // Resolve against the working directory instead of a machine-specific
    // absolute Windows path, so the loader works outside one developer's setup.
    File locationToSave = new File(System.getProperty("user.dir") + "/model/" + fn);
    return ModelSerializer.restoreComputationGraph(locationToSave);
}
 
Example 13
Source File: ModelGuesser.java    From konduit-serving with Apache License 2.0 5 votes vote down vote up
/**
 * Load the model from the given file path, trying each known format in turn:
 * multi-layer network and computation graph with the boolean flag set to
 * {@code true}, the same two with {@code false}, then Keras functional and
 * Keras sequential import.
 *
 * @param path the path of the file to "guess"
 * @return the loaded model
 * @throws Exception if every model load attempt fails; the thrown
 *         {@code ModelGuesserException} carries each attempt's failure as a
 *         suppressed exception, in the order the attempts were made
 */
public static Model loadModelGuess(String path) throws Exception {
    File modelFile = new File(path);

    Exception multiLayerTrueFailure;
    try {
        return ModelSerializer.restoreMultiLayerNetwork(modelFile, true);
    } catch (Exception e) {
        multiLayerTrueFailure = e;
        log.warn("Tried multi layer network");
    }

    Exception graphTrueFailure;
    try {
        return ModelSerializer.restoreComputationGraph(modelFile, true);
    } catch (Exception e) {
        graphTrueFailure = e;
        log.warn("Tried computation graph");
    }

    Exception multiLayerFalseFailure;
    try {
        return ModelSerializer.restoreMultiLayerNetwork(modelFile, false);
    } catch (Exception e) {
        multiLayerFalseFailure = e;
    }

    Exception graphFalseFailure;
    try {
        return ModelSerializer.restoreComputationGraph(modelFile, false);
    } catch (Exception e) {
        graphFalseFailure = e;
    }

    Exception kerasModelFailure;
    try {
        return KerasModelImport.importKerasModelAndWeights(path);
    } catch (Exception e) {
        kerasModelFailure = e;
        log.warn("Tried multi layer network keras");
    }

    try {
        return KerasModelImport.importKerasSequentialModelAndWeights(path);
    } catch (Exception kerasSequentialFailure) {
        ModelGuesserException failure = new ModelGuesserException("Unable to load model from path " + path
                + " (invalid model file or not a known model type)");
        // Keep every underlying cause so the caller can see why each format was rejected.
        failure.addSuppressed(multiLayerTrueFailure);
        failure.addSuppressed(graphTrueFailure);
        failure.addSuppressed(multiLayerFalseFailure);
        failure.addSuppressed(graphFalseFailure);
        failure.addSuppressed(kerasModelFailure);
        failure.addSuppressed(kerasSequentialFailure);
        throw failure;
    }
}
 
Example 14
Source File: CatVsDogRecognition.java    From Java-Machine-Learning-for-Computer-Vision with MIT License 4 votes vote down vote up
/**
 * Restores the graph from the file at {@code TRAINED_PATH_MODEL}, stores it in
 * {@code computationGraph} and returns it.
 *
 * @return the restored graph
 * @throws IOException if the model cannot be read
 */
public ComputationGraph loadModel() throws IOException {
    File trainedModel = new File(TRAINED_PATH_MODEL);
    return computationGraph = ModelSerializer.restoreComputationGraph(trainedModel);
}
 
Example 15
Source File: BidirectionalTest.java    From deeplearning4j with Apache License 2.0 4 votes vote down vote up
/**
 * For every {@code WorkspaceMode}: builds a three-layer graph (two
 * bidirectional GravesLSTM layers and an RNN output layer), fits it once,
 * round-trips it through {@code ModelSerializer} in memory, and checks that
 * the original and the restored graph produce equal output, score and gradient.
 */
@Test
public void testSerializationCompGraph() throws Exception {

    for(WorkspaceMode mode : WorkspaceMode.values()) {
        log.info("*** Starting workspace mode: " + mode);

        Nd4j.getRandom().setSeed(12345);

        ComputationGraphConfiguration graphConf = new NeuralNetConfiguration.Builder()
                .activation(Activation.TANH)
                .weightInit(WeightInit.XAVIER)
                .trainingWorkspaceMode(mode)
                .inferenceWorkspaceMode(mode)
                .updater(new Adam())
                .graphBuilder()
                .addInputs("in")
                .layer("0", new Bidirectional(Bidirectional.Mode.ADD, new GravesLSTM.Builder().nIn(10).nOut(10).dataFormat(rnnDataFormat).build()), "in")
                .layer("1", new Bidirectional(Bidirectional.Mode.ADD, new GravesLSTM.Builder().nIn(10).nOut(10).dataFormat(rnnDataFormat).build()), "0")
                .layer("2", new RnnOutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE).dataFormat(rnnDataFormat)
                        .nIn(10).nOut(10).build(), "1")
                .setOutputs("2")
                .build();

        ComputationGraph original = new ComputationGraph(graphConf);
        original.init();

        // Shape depends on the data format under test.
        long[] shape;
        if (rnnDataFormat == NCW) {
            shape = new long[]{3, 10, 5};
        } else {
            shape = new long[]{3, 5, 10};
        }

        INDArray features = Nd4j.rand(shape);
        INDArray targets = Nd4j.rand(shape);
        original.fit(new DataSet(features, targets));

        byte[] serialized;
        try (ByteArrayOutputStream sink = new ByteArrayOutputStream()) {
            ModelSerializer.writeModel(original, sink, true);
            serialized = sink.toByteArray();
        }

        ComputationGraph restored =
                ModelSerializer.restoreComputationGraph(new ByteArrayInputStream(serialized), true);

        // Fresh random data for the comparison.
        features = Nd4j.rand(shape);
        targets = Nd4j.rand(shape);

        INDArray originalOut = original.outputSingle(features);
        INDArray restoredOut = restored.outputSingle(features);
        assertEquals(originalOut, restoredOut);

        original.setInput(0, features);
        restored.setInput(0, features);
        original.setLabels(targets);
        restored.setLabels(targets);

        original.computeGradientAndScore();
        restored.computeGradientAndScore();

        assertEquals(original.score(), restored.score(), 1e-6);
        assertEquals(original.gradient().gradient(), restored.gradient().gradient());
    }
}
 
Example 16
Source File: KerasZooModel.java    From wekaDeeplearning4j with GNU General Public License v3.0 4 votes vote down vote up
/**
 * Returns the pretrained graph for {@code pretrainedType}.
 *
 * <p>The model archive is downloaded into the zoo-model cache directory unless
 * a cached copy already exists. When the expected checksum is non-zero, the
 * cached file's Adler-32 checksum is compared against it; on mismatch the file
 * is deleted and the call fails.
 *
 * @param pretrainedType which pretrained weights to load
 * @return the restored graph, or {@code null} if the cached archive could not
 *         be loaded (the failure is logged)
 * @throws IOException if the download or the checksum computation fails
 * @throws UnsupportedOperationException if no URL is available for {@code pretrainedType}
 * @throws IllegalStateException if the checksum does not match
 */
@Override
public ComputationGraph initPretrained(PretrainedType pretrainedType) throws IOException {
    String remoteUrl = pretrainedUrl(pretrainedType);
    if (remoteUrl == null) {
        throw new UnsupportedOperationException(
                "Pretrained " + pretrainedType + " weights are not available for this model.");
    }

    // Set up file locations
    String localFilename = modelPrettyName() + ".zip";

    File rootCacheDir = DL4JResources.getDirectory(ResourceType.ZOO_MODEL, modelFamily());
    File cachedFile = new File(rootCacheDir, localFilename);

    // Download the file if necessary
    if (!cachedFile.exists()) {
        log.info("Downloading model to " + cachedFile.toString());
        FileUtils.copyURLToFile(new URL(remoteUrl), cachedFile);
    } else {
        log.info("Using cached model at " + cachedFile.toString());
    }

    // Validate the checksum - ensure this is the correct file.
    // An expected checksum of 0 means "skip verification".
    long expectedChecksum = pretrainedChecksum(pretrainedType);
    if (expectedChecksum != 0L) {
        log.info("Verifying download...");
        Checksum adler = new Adler32();
        FileUtils.checksum(cachedFile, adler);
        long localChecksum = adler.getValue();
        log.info("Checksum local is " + localChecksum + ", expecting " + expectedChecksum);

        if (expectedChecksum != localChecksum) {
            log.error("Checksums do not match. Cleaning up files and failing...");
            if (!cachedFile.delete()) {
                // A stale file would be reused on the next call, so make this visible.
                log.warn("Could not delete invalid cached model at " + cachedFile.toString());
            }
            throw new IllegalStateException(
                    String.format("Pretrained model file for model %s failed checksum.", this.modelPrettyName()));
        }
    }

    // Load the .zip file to a ComputationGraph
    try {
        return ModelSerializer.restoreComputationGraph(cachedFile);
    } catch (Exception ex) {
        // Report through the logger (with the stack trace) rather than stderr.
        // The null return is kept so existing callers see the same contract.
        log.error("Failed to load model", ex);
        return null;
    }
}
 
Example 17
Source File: Dl4jMlpClassifier.java    From wekaDeeplearning4j with GNU General Public License v3.0 4 votes vote down vote up
/**
 * Custom deserialization method.
 *
 * <p>Reads the default fields, rebuilds the layers from their stored
 * {@code className::options} strings and, if initialization had finished when
 * the object was written, copies {@code modelSize} bytes from the stream into
 * a temporary file and restores the network from it. The context class loader
 * is switched to this class's loader for the duration and restored afterwards.
 *
 * <p>Any failure is logged and not rethrown.
 *
 * @param ois the object input stream
 */
private void readObject(ObjectInputStream ois) throws ClassNotFoundException,
    IOException {
  ClassLoader origLoader = Thread.currentThread().getContextClassLoader();
  try {
    Thread.currentThread().setContextClassLoader(
        this.getClass().getClassLoader());
    // default deserialization
    ois.defaultReadObject();

    // Restore the layers
    String[] layerConfigs = (String[]) ois.readObject();
    layers = new Layer[layerConfigs.length];
    for (int i = 0; i < layerConfigs.length; i++) {
      // Split on the first "::" only: the option string may itself contain
      // "::", and may be empty (which a limit-less split would drop).
      String[] split = layerConfigs[i].split("::", 2);
      String clsName = split[0];
      String layerConfig = split.length > 1 ? split[1] : "";
      String[] options = weka.core.Utils.splitOptions(layerConfig);
      layers[i] =
          (Layer) weka.core.Utils.forName(Layer.class, clsName, options);
    }

    // restore the network model
    if (isInitializationFinished) {
      File tmpFile = File.createTempFile("restore", "multiLayer");
      tmpFile.deleteOnExit();

      // try-with-resources closes (and thereby flushes) the file even when
      // copying fails, and guarantees it is complete before it is read back.
      try (BufferedOutputStream bos =
          new BufferedOutputStream(new FileOutputStream(tmpFile))) {
        byte[] buffer = new byte[10024];
        long remaining = modelSize;
        while (remaining > 0) {
          // Never read past the model bytes into the rest of the stream.
          int chunk = (int) Math.min(buffer.length, remaining);
          int len = ois.read(buffer, 0, chunk);
          if (len == -1) {
            throw new IOException(
                "Reached end of network model prematurely during deserialization.");
          }
          bos.write(buffer, 0, len);
          remaining -= len;
        }
      }
      model = ModelSerializer.restoreComputationGraph(tmpFile, false);
    }
  } catch (Exception e) {
    // Pass the exception to the logger so the stack trace is not lost.
    log.error("Failed to restore serialized model. Error: " + e.getMessage(), e);
  } finally {
    Thread.currentThread().setContextClassLoader(origLoader);
  }
}
 
Example 18
Source File: KerasZooModel.java    From wekaDeeplearning4j with GNU General Public License v3.0 4 votes vote down vote up
/**
 * Loads the pretrained graph for {@code pretrainedType} from the local zoo-model
 * cache, downloading the archive first if it is not cached yet.
 *
 * <p>If the expected checksum is non-zero, the cached file is verified with
 * Adler-32; a mismatch deletes the file and fails the call.
 *
 * @param pretrainedType which pretrained weights to load
 * @return the restored graph, or {@code null} if restoring the cached archive
 *         failed (the failure is logged)
 * @throws IOException if the download or the checksum computation fails
 * @throws UnsupportedOperationException if no URL is available for {@code pretrainedType}
 * @throws IllegalStateException if the checksum does not match
 */
@Override
public ComputationGraph initPretrained(PretrainedType pretrainedType) throws IOException {
    String remoteUrl = pretrainedUrl(pretrainedType);
    if (remoteUrl == null) {
        throw new UnsupportedOperationException(
                "Pretrained " + pretrainedType + " weights are not available for this model.");
    }

    // Set up file locations
    String localFilename = modelPrettyName() + ".zip";

    File rootCacheDir = DL4JResources.getDirectory(ResourceType.ZOO_MODEL, modelFamily());
    File cachedFile = new File(rootCacheDir, localFilename);

    // Download the file if necessary
    if (!cachedFile.exists()) {
        log.info("Downloading model to " + cachedFile.toString());
        FileUtils.copyURLToFile(new URL(remoteUrl), cachedFile);
    } else {
        log.info("Using cached model at " + cachedFile.toString());
    }

    // Validate the checksum - ensure this is the correct file.
    // An expected checksum of 0 means "skip verification".
    long expectedChecksum = pretrainedChecksum(pretrainedType);
    if (expectedChecksum != 0L) {
        log.info("Verifying download...");
        Checksum adler = new Adler32();
        FileUtils.checksum(cachedFile, adler);
        long localChecksum = adler.getValue();
        log.info("Checksum local is " + localChecksum + ", expecting " + expectedChecksum);

        if (expectedChecksum != localChecksum) {
            log.error("Checksums do not match. Cleaning up files and failing...");
            if (!cachedFile.delete()) {
                // A stale file would be reused on the next call, so make this visible.
                log.warn("Could not delete invalid cached model at " + cachedFile.toString());
            }
            throw new IllegalStateException(
                    String.format("Pretrained model file for model %s failed checksum.", this.modelPrettyName()));
        }
    }

    // Load the .zip file to a ComputationGraph
    try {
        return ModelSerializer.restoreComputationGraph(cachedFile);
    } catch (Exception ex) {
        // Report through the logger (with the stack trace) rather than stderr.
        // The null return is kept so existing callers see the same contract.
        log.error("Failed to load model", ex);
        return null;
    }
}
 
Example 19
Source File: KerasYolo9000PredictTest.java    From deeplearning4j with Apache License 2.0 3 votes vote down vote up
/**
 * Imports the YOLO Keras model, appends a {@code Yolo2OutputLayer} with five
 * bounding-box priors via transfer learning, writes the result with
 * {@code ModelSerializer}, restores it, prints its summary and runs one
 * forward pass on a 416x416x3 input. Currently ignored.
 */
@Ignore
@Test
public void testYoloPredictionImport() throws Exception {
    int inputHeight = 416;
    int inputWidth = 416;
    INDArray image = Nd4j.create(inputHeight, inputWidth, 3);
    IMAGE_PREPROCESSING_SCALER.transform(image);

    KerasLayer.registerCustomLayer("Lambda", KerasSpaceToDepth.class);

    String kerasFile = "modelimport/keras/examples/yolo/yolo-voc.h5";
    ComputationGraph imported = KerasModelImport.importKerasModelAndWeights(kerasFile, false);

    double[][] priorBoxes = {
            {1.3221, 1.73145},
            {3.19275, 4.00944},
            {5.05587, 8.09892},
            {9.47112, 4.84053},
            {11.2364, 10.0071}
    };
    INDArray priors = Nd4j.create(priorBoxes);

    // Attach the detection output layer on top of the imported "conv2d_23" layer.
    ComputationGraph withOutputLayer = new TransferLearning.GraphBuilder(imported)
            .addLayer("outputs",
                    new org.deeplearning4j.nn.conf.layers.objdetect.Yolo2OutputLayer.Builder()
                            .boundingBoxPriors(priors)
                            .build(),
                    "conv2d_23")
            .setOutputs("outputs")
            .build();

    ModelSerializer.writeModel(withOutputLayer, DL4J_MODEL_FILE_NAME, false);

    ComputationGraph restored = ModelSerializer.restoreComputationGraph(new File(DL4J_MODEL_FILE_NAME));

    System.out.println(restored.summary(InputType.convolutional(416, 416, 3)));

    // The output itself is not asserted on; the call only has to complete.
    INDArray results = restored.outputSingle(image);
}
 
Example 20
Source File: SolverDL4j.java    From twse-captcha-solver-dl4j with MIT License 2 votes vote down vote up
/**
 * Creates a new <code>SolverDL4j</code> instance.
 *
 * @exception IOException if an error occurs
 */
public SolverDL4j() throws IOException {
  InputStream is = SolverDL4j.class.getClass().getResourceAsStream("/model.zip");
  model = ModelSerializer.restoreComputationGraph(is);
}