Java Code Examples for org.deeplearning4j.util.ModelSerializer#restoreMultiLayerNetwork()

The following examples show how to use org.deeplearning4j.util.ModelSerializer#restoreMultiLayerNetwork(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may check out the related API usage on the sidebar.
Example 1
Source File: Vasttext.java — from scava (Eclipse Public License 2.0), 6 votes
/**
 * Restores a Vasttext model, together with its stored configuration, from {@code file}.
 *
 * <p>The configuration map stored under the {@code "vasttext.config"} key provides the
 * multi-label flag, the vectorizer dictionary, the model type and the multi-label activation.
 * Depending on the model type, either a ComputationGraph ({@code "textAndNumeric"}) or a
 * MultiLayerNetwork ({@code "onlyText"}) is then restored from the same file.
 *
 * @param file the serialized model file
 * @throws FileNotFoundException declared for the underlying file access
 * @throws ClassNotFoundException propagated from deserializing the stored configuration object
 * @throws IOException if reading the model or configuration fails
 * @throws UnsupportedOperationException if the stored model type is missing or not recognised
 */
@SuppressWarnings("unchecked")
public void loadModel(File file) throws FileNotFoundException, ClassNotFoundException, IOException
{
	HashMap<String, Object> configuration = (HashMap<String, Object>) ModelSerializer.getObjectFromFile(file, "vasttext.config");
	multiLabel = (Boolean) configuration.get("multiLabel");
	vectorizer = new VasttextTextVectorizer();
	vectorizer.loadDictionary(configuration.get("dictionary"));
	labels = vectorizer.getLabels();
	labelsSize = labels.size();
	typeVasttext = (String) configuration.get("typeVasttext");
	multiLabelActivation = (Double) configuration.get("multiLabelActivation");

	// Compare against the literals so that a configuration without a "typeVasttext" entry
	// is reported as an unknown model type instead of failing with a NullPointerException.
	if ("textAndNumeric".equalsIgnoreCase(typeVasttext))
	{
		vasttextTextAndNumeric = ModelSerializer.restoreComputationGraph(file);
	}
	else if ("onlyText".equalsIgnoreCase(typeVasttext))
	{
		vasttextText = ModelSerializer.restoreMultiLayerNetwork(file);
	}
	else
	{
		throw new UnsupportedOperationException("Unknown type of model: " + typeVasttext);
	}
}
 
Example 2
Source File: TestUtils.java — from deeplearning4j (Apache License 2.0), 6 votes
/**
 * Serializes {@code net} (including its updater) into an in-memory zip, restores it, and checks
 * that the restored network has the same configuration and parameters.
 *
 * @param net the network to round-trip
 * @return the network restored from the serialized bytes
 */
public static MultiLayerNetwork testModelSerialization(MultiLayerNetwork net){
    final MultiLayerNetwork restored;
    try {
        ByteArrayOutputStream sink = new ByteArrayOutputStream();
        ModelSerializer.writeModel(net, sink, true);
        restored = ModelSerializer.restoreMultiLayerNetwork(new ByteArrayInputStream(sink.toByteArray()), true);
    } catch (IOException e) {
        // Purely in-memory streams; an IOException here indicates a serializer bug.
        throw new RuntimeException(e);
    }

    assertEquals(net.getLayerWiseConfigurations(), restored.getLayerWiseConfigurations());
    assertEquals(net.params(), restored.params());
    return restored;
}
 
Example 3
Source File: CustomerRetentionPredictionApi.java — from Java-Deep-Learning-Cookbook (MIT License), 5 votes
/**
 * Scores the records in {@code inputFile} with the network saved at {@code modelFilePath}.
 *
 * <p>The normalizer stored alongside the network is applied as-is. It already holds the
 * statistics computed on the training data. Fitting it again on the inference input would
 * replace those statistics with ones derived from the input itself, so it is deliberately not
 * re-fit here.
 *
 * @param inputFile     file containing the records to score
 * @param modelFilePath path of the model zip that was saved with its normalizer
 * @return the network output for the records of {@code inputFile}
 * @throws IOException if the model or input cannot be read
 * @throws InterruptedException propagated from building the record reader
 * @throws IllegalStateException if the model file does not contain a normalizer
 */
public static INDArray generateOutput(File inputFile, String modelFilePath) throws IOException, InterruptedException {
    final File modelFile = new File(modelFilePath);
    final MultiLayerNetwork network = ModelSerializer.restoreMultiLayerNetwork(modelFile);
    final RecordReader recordReader = generateReader(inputFile);
    final NormalizerStandardize normalizerStandardize = ModelSerializer.restoreNormalizerFromFile(modelFile);
    if (normalizerStandardize == null) {
        // Without the training normalizer the network would silently see unnormalized input.
        throw new IllegalStateException("No normalizer stored in model file " + modelFilePath);
    }
    final DataSetIterator dataSetIterator = new RecordReaderDataSetIterator.Builder(recordReader, 1).build();
    dataSetIterator.setPreProcessor(normalizerStandardize);
    return network.output(dataSetIterator);
}
 
Example 4
Source File: ModelGuesser.java — from deeplearning4j (Apache License 2.0), 5 votes
/**
 * Load the model from the given file path by trying each supported format in turn.
 *
 * <p>Attempts, in order: a DL4J MultiLayerNetwork and ComputationGraph with updater state, the
 * same two without updater state, a Keras functional model, and finally a Keras Sequential
 * model. The first loader that succeeds determines the result.
 *
 * @param path the path of the file to "guess"
 * @return the loaded model
 * @throws ModelGuesserException if no loader can read the file; the exception raised by every
 *         attempt is attached as a suppressed exception so the underlying cause is not lost
 */
public static Model loadModelGuess(String path) throws Exception {
    final File file = new File(path);
    // Created up front so that each failed attempt can be recorded on it via addSuppressed.
    final ModelGuesserException failure = new ModelGuesserException("Unable to load model from path " + path
            + " (invalid model file or not a known model type)");

    try {
        return ModelSerializer.restoreMultiLayerNetwork(file, true);
    } catch (Exception e) {
        log.warn("Tried multi layer network");
        failure.addSuppressed(e);
    }
    try {
        return ModelSerializer.restoreComputationGraph(file, true);
    } catch (Exception e) {
        log.warn("Tried computation graph");
        failure.addSuppressed(e);
    }
    try {
        return ModelSerializer.restoreMultiLayerNetwork(file, false);
    } catch (Exception e) {
        log.warn("Tried multi layer network without updater");
        failure.addSuppressed(e);
    }
    try {
        return ModelSerializer.restoreComputationGraph(file, false);
    } catch (Exception e) {
        log.warn("Tried computation graph without updater");
        failure.addSuppressed(e);
    }
    try {
        // importKerasModelAndWeights handles functional models (imported as a ComputationGraph).
        return KerasModelImport.importKerasModelAndWeights(path);
    } catch (Exception e) {
        log.warn("Tried computation graph keras");
        failure.addSuppressed(e);
    }
    try {
        return KerasModelImport.importKerasSequentialModelAndWeights(path);
    } catch (Exception e) {
        log.warn("Tried multi layer network keras");
        failure.addSuppressed(e);
    }
    throw failure;
}
 
Example 5
Source File: Main.java — from gluon-samples (BSD 3-Clause "New" or "Revised" License), 5 votes
/**
 * Restores the saved model from {@code savedModelLocation} when it exists; otherwise creates,
 * trains and saves a new one. Either way, the model is evaluated, run against the tests,
 * re-evaluated, and passed to {@code correctImage} with the bundled {@code /mytestdata/3b.png}
 * sample before a final evaluation.
 */
public static void main(String[] args) throws Exception {
    final MultiLayerNetwork model;
    if (new File(savedModelLocation).exists()) {
        LOGGER.info("Model exists, restore it");
        model = ModelSerializer.restoreMultiLayerNetwork(savedModelLocation);
    } else {
        LOGGER.info("Create model");
        model = utils.createModel();
        LOGGER.info("Train model");
        utils.trainModel(model, true, null, -1);
        LOGGER.info("Save model");
        utils.saveModel(model, savedModelLocation);
        LOGGER.info("Eval model");
    }
    // Both paths evaluate the freshly obtained model before running the tests.
    utils.evaluateModel(model);

    LOGGER.info("Run tests");
    runTests(model);

    LOGGER.info("Evaluate model after tests");
    utils.evaluateModel(model);

    LOGGER.info("Correct Image");
    correctImage(model, Main.class.getResourceAsStream("/mytestdata/3b.png"), 3);
    utils.evaluateModel(model);
}
 
Example 6
Source File: ParallelInferenceTest.java — from deeplearning4j (Apache License 2.0), 5 votes
/**
 * Loads the LeNet MNIST network (with updater) and a batch-size-1 MNIST iterator, but only if
 * {@code model} has not been initialized yet.
 */
@Before
public void setUp() throws Exception {
    if (model != null) {
        return;
    }
    model = ModelSerializer.restoreMultiLayerNetwork(Resources.asFile("models/LenetMnistMLN.zip"), true);
    iterator = new MnistDataSetIterator(1, false, 12345);
}
 
Example 7
Source File: RegressionTest071.java — from deeplearning4j (Apache License 2.0), 5 votes
/**
 * Restores the stored LSTM regression model (with updater) and checks the configuration of its
 * three layers: GravesLSTM, GravesBidirectionalLSTM and RnnOutputLayer.
 */
@Test
public void regressionTestLSTM1() throws Exception {
    // Path suggests the file was written by DL4J 0.7.1 — TODO confirm.
    final File modelZip = Resources.asFile("regression_testing/071/071_ModelSerializer_Regression_LSTM_1.zip");
    final MultiLayerNetwork restored = ModelSerializer.restoreMultiLayerNetwork(modelZip, true);
    final MultiLayerConfiguration layerConfs = restored.getLayerWiseConfigurations();

    assertEquals(3, layerConfs.getConfs().size());

    // Layer 0: GravesLSTM, 3 -> 4, tanh, element-wise clipping at 1.5.
    final GravesLSTM lstm = (GravesLSTM) layerConfs.getConf(0).getLayer();
    assertEquals("tanh", lstm.getActivationFn().toString());
    assertEquals(3, lstm.getNIn());
    assertEquals(4, lstm.getNOut());
    assertEquals(GradientNormalization.ClipElementWiseAbsoluteValue, lstm.getGradientNormalization());
    assertEquals(1.5, lstm.getGradientNormalizationThreshold(), 1e-5);

    // Layer 1: GravesBidirectionalLSTM, 4 -> 4, softsign, same clipping.
    final GravesBidirectionalLSTM biLstm = (GravesBidirectionalLSTM) layerConfs.getConf(1).getLayer();
    assertEquals("softsign", biLstm.getActivationFn().toString());
    assertEquals(4, biLstm.getNIn());
    assertEquals(4, biLstm.getNOut());
    assertEquals(GradientNormalization.ClipElementWiseAbsoluteValue, biLstm.getGradientNormalization());
    assertEquals(1.5, biLstm.getGradientNormalizationThreshold(), 1e-5);

    // Layer 2: RnnOutputLayer, 4 -> 5, softmax with MCXENT loss.
    final RnnOutputLayer output = (RnnOutputLayer) layerConfs.getConf(2).getLayer();
    assertEquals(4, output.getNIn());
    assertEquals(5, output.getNOut());
    assertEquals("softmax", output.getActivationFn().toString());
    assertTrue(output.getLossFn() instanceof LossMCXENT);
}
 
Example 8
Source File: TrainUtil.java — from FancyBing (GNU General Public License v3.0), 5 votes
/**
 * Loads {@code model/<fn>} relative to the working directory and sets the per-parameter learning
 * rate of the "W" and "b" parameters of every layer to {@code learningRate}.
 *
 * @param fn           file name inside the {@code model} directory
 * @param learningRate learning rate applied to the "W" and "b" parameters of each layer
 * @return the restored network with updated learning rates
 */
public static MultiLayerNetwork loadNetwork(String fn, double learningRate) throws Exception {
    System.err.println("Loading model...");
    final String modelPath = System.getProperty("user.dir") + "/model/" + fn;
    final MultiLayerNetwork network = ModelSerializer.restoreMultiLayerNetwork(new File(modelPath));

    final String[] paramKeys = {"W", "b"};
    final int layerCount = network.getnLayers();
    for (int layerIdx = 0; layerIdx < layerCount; layerIdx++) {
        for (String key : paramKeys) {
            network.getLayer(layerIdx).conf().setLearningRateByParam(key, learningRate);
        }
    }
    return network;
}
 
Example 9
Source File: PolicyNetUtil.java — from FancyBing (GNU General Public License v3.0), 5 votes
/**
 * Loads {@code model/<fn>} relative to the working directory.
 *
 * @param fn file name inside the {@code model} directory
 * @return the restored network
 */
public static MultiLayerNetwork loadNetwork(String fn) throws Exception {
    System.err.println("Loading model...");
    final String modelPath = System.getProperty("user.dir") + "/model/" + fn;
    return ModelSerializer.restoreMultiLayerNetwork(new File(modelPath));
}
 
Example 10
Source File: ModelGuesser.java — from konduit-serving (Apache License 2.0), 5 votes
/**
 * Loads a dl4j zip file (either computation graph or multi layer network).
 *
 * <p>If the zip contains both the coefficients and the configuration entries, the configuration
 * JSON is inspected. A top-level {@code "vertexInputs"} key selects a computation graph. In every
 * other case the file is restored as a multi layer network.
 *
 * @param path the path to the file to load
 * @return a loaded dl4j model, or {@code null} if {@code path} is not a zip file
 * @throws Exception if loading a dl4j model fails
 */
public static Model loadDl4jGuess(String path) throws Exception {
    final File file = new File(path);
    if (!isZipFile(file)) {
        // Not a dl4j zip; callers are expected to handle the null result.
        return null;
    }

    log.debug("Loading file " + path);
    boolean compGraph = false;
    try (ZipFile zipFile = new ZipFile(path)) {
        List<String> entryNames = zipFile.stream().map(ZipEntry::getName)
                .collect(Collectors.toList());
        log.debug("Entries " + entryNames);
        if (entryNames.contains(ModelSerializer.COEFFICIENTS_BIN) && entryNames.contains(ModelSerializer.CONFIGURATION_JSON)) {
            ZipEntry entry = zipFile.getEntry(ModelSerializer.CONFIGURATION_JSON);
            log.debug("Loaded configuration");
            try (InputStream is = zipFile.getInputStream(entry)) {
                String configJson = IOUtils.toString(is, StandardCharsets.UTF_8);
                JSONObject jsonObject = new JSONObject(configJson);
                if (jsonObject.has("vertexInputs")) {
                    log.debug("Loading computation graph.");
                    compGraph = true;
                } else {
                    log.debug("Loading multi layer network.");
                }
            }
        }
    }

    if (compGraph) {
        return ModelSerializer.restoreComputationGraph(file);
    }
    return ModelSerializer.restoreMultiLayerNetwork(file);
}
 
Example 11
Source File: CustomerRetentionPredictionApi.java — from Java-Deep-Learning-Cookbook (MIT License), 5 votes
/**
 * Scores the records in {@code inputFile} with the network saved at {@code modelFilePath}.
 *
 * <p>The normalizer stored alongside the network is applied as-is. It already holds the
 * statistics computed on the training data. Fitting it again on the inference input would
 * replace those statistics with ones derived from the input itself, so it is deliberately not
 * re-fit here.
 *
 * @param inputFile     file containing the records to score
 * @param modelFilePath path of the model zip that was saved with its normalizer
 * @return the network output for the records of {@code inputFile}
 * @throws IOException if the model or input cannot be read
 * @throws InterruptedException propagated from building the record reader
 * @throws IllegalStateException if the model file does not contain a normalizer
 */
public static INDArray generateOutput(File inputFile, String modelFilePath) throws IOException, InterruptedException {
    final File modelFile = new File(modelFilePath);
    final MultiLayerNetwork network = ModelSerializer.restoreMultiLayerNetwork(modelFile);
    final RecordReader recordReader = generateReader(inputFile);
    final NormalizerStandardize normalizerStandardize = ModelSerializer.restoreNormalizerFromFile(modelFile);
    if (normalizerStandardize == null) {
        // Without the training normalizer the network would silently see unnormalized input.
        throw new IllegalStateException("No normalizer stored in model file " + modelFilePath);
    }
    final DataSetIterator dataSetIterator = new RecordReaderDataSetIterator.Builder(recordReader, 1).build();
    dataSetIterator.setPreProcessor(normalizerStandardize);
    return network.output(dataSetIterator);
}
 
Example 12
Source File: OCNNOutputLayerTest.java — from deeplearning4j (Apache License 2.0), 5 votes
/**
 * Trains the single-layer OCNN on classes 0 and 1, then checks its outputs on the class-2
 * samples. The trained network must also survive a ModelSerializer round trip with identical
 * parameters.
 */
@Test
public void testLabelProbabilities() throws Exception {
    Nd4j.getRandom().setSeed(42);
    DataSetIterator iter = getNormalizedIterator();
    MultiLayerNetwork network = getSingleLayer();
    DataSet batch = iter.next();

    // Fit only on classes 0 and 1; class 2 is used below as the anomaly set.
    DataSet normalSet = batch.filterBy(new int[]{0, 1});
    for (int epoch = 0; epoch < 10; epoch++) {
        network.setEpochCount(epoch);
        network.getLayerWiseConfigurations().setEpochCount(epoch);
        network.fit(normalSet);
    }

    DataSet anomalySet = batch.filterBy(new int[]{2});
    INDArray anomalyOut = network.output(anomalySet.getFeatures());
    INDArray anomalyRawOut = network.output(anomalySet.getFeatures(), false);
    double negativeCount = anomalyOut.lt(0.0).castTo(Nd4j.defaultFloatingPointType()).sumNumber().doubleValue();
    double zeroCount = anomalyRawOut.eq(0.0).castTo(Nd4j.defaultFloatingPointType()).sumNumber().doubleValue();
    assertEquals(negativeCount, zeroCount, 1e-1);

    INDArray normalProbs = network.output(normalSet.getFeatures());
    INDArray normalRawOut = network.output(normalSet.getFeatures(), false);
    System.out.println("Normal probabilities " + normalProbs);
    System.out.println("Normal raw output " + normalRawOut);

    File savedModel = new File(testDir.getRoot(), "tmp-file-" + UUID.randomUUID().toString());
    ModelSerializer.writeModel(network, savedModel, true);
    savedModel.deleteOnExit();

    MultiLayerNetwork reloaded = ModelSerializer.restoreMultiLayerNetwork(savedModel);
    assertEquals(network.params(), reloaded.params());
    assertEquals(network.numParams(), reloaded.numParams());
}
 
Example 13
Source File: ImageClassifierAPI.java — from Java-Deep-Learning-Cookbook (MIT License), 5 votes
/**
 * Classifies the image(s) read from {@code inputFile} with the network saved at
 * {@code modelFileLocation}.
 *
 * <p>The image scaler stored in the model file is applied exactly as it was saved with the
 * network and is not re-fit on the inference input.
 *
 * @param inputFile         the image input to classify
 * @param modelFileLocation path of the model zip that was saved with its image scaler
 * @return the network output for the input
 * @throws IOException if the model or input cannot be read
 * @throws InterruptedException propagated from building the record reader
 * @throws IllegalStateException if the model file does not contain a scaler
 */
public static INDArray generateOutput(File inputFile, String modelFileLocation) throws IOException, InterruptedException {
    //retrieve the saved model
    final File modelFile = new File(modelFileLocation);
    final MultiLayerNetwork model = ModelSerializer.restoreMultiLayerNetwork(modelFile);
    final RecordReader imageRecordReader = generateReader(inputFile);
    final ImagePreProcessingScaler scaler = ModelSerializer.restoreNormalizerFromFile(modelFile);
    if (scaler == null) {
        // Without the saved scaler the network would silently receive unscaled pixels.
        throw new IllegalStateException("No image scaler stored in model file " + modelFileLocation);
    }
    final DataSetIterator dataSetIterator = new RecordReaderDataSetIterator.Builder(imageRecordReader, 1).build();
    dataSetIterator.setPreProcessor(scaler);
    return model.output(dataSetIterator);
}
 
Example 14
Source File: DigitRecognizerNeuralNetwork.java — from Java-Machine-Learning-for-Computer-Vision (MIT License), 4 votes
/**
 * Restores the pre-trained network from {@code OUTPUT_DIRECTORY}.
 * NOTE(review): despite its name, OUTPUT_DIRECTORY must point at the model zip file itself — confirm.
 */
public void init() throws IOException {
    final File modelFile = new File(OUTPUT_DIRECTORY);
    preTrainedModel = ModelSerializer.restoreMultiLayerNetwork(modelFile);
}
 
Example 15
Source File: DigitRecognizerConvolutionalNeuralNetwork.java — from Java-Machine-Learning-for-Computer-Vision (MIT License), 4 votes
/**
 * Restores the pre-trained convolutional network from {@code TRAINED_MODEL_FILE}.
 */
public void init() throws IOException {
    final File modelFile = new File(TRAINED_MODEL_FILE);
    preTrainedModel = ModelSerializer.restoreMultiLayerNetwork(modelFile);
}
 
Example 16
Source File: ZooModel.java — from deeplearning4j (Apache License 2.0), 4 votes
/**
 * Returns a pretrained model for the given dataset, if available.
 *
 * <p>The weights are downloaded once into the zoo-model cache directory and reused on later
 * calls. If the download fails, the partially written file is removed so it is not mistaken for
 * a complete model on the next call. When a non-zero checksum is published for
 * {@code pretrainedType}, the cached file is verified with Adler-32 and deleted on mismatch.
 *
 * @param pretrainedType which set of pretrained weights to load
 * @param <M> the model type to return; expected to match {@link #modelType()}
 * @return the restored pretrained model
 * @throws IOException if downloading or reading the model file fails
 * @throws UnsupportedOperationException if no weights exist for {@code pretrainedType}, or the
 *         model type is neither MultiLayerNetwork nor ComputationGraph
 * @throws IllegalStateException if the cached file fails checksum verification
 */
@SuppressWarnings("unchecked") // casts of the restored model to M; callers choose M to match modelType()
public <M extends Model> M initPretrained(PretrainedType pretrainedType) throws IOException {
    String remoteUrl = pretrainedUrl(pretrainedType);
    if (remoteUrl == null)
        throw new UnsupportedOperationException(
                        "Pretrained " + pretrainedType + " weights are not available for this model.");

    String localFilename = new File(remoteUrl).getName();

    File rootCacheDir = DL4JResources.getDirectory(ResourceType.ZOO_MODEL, modelName());
    File cachedFile = new File(rootCacheDir, localFilename);

    if (!cachedFile.exists()) {
        log.info("Downloading model to " + cachedFile.toString());
        try {
            FileUtils.copyURLToFile(new URL(remoteUrl), cachedFile);
        } catch (IOException e) {
            // A truncated download would otherwise be treated as a cached model next time,
            // and without a published checksum nothing would catch it.
            if (cachedFile.exists() && !cachedFile.delete()) {
                log.warn("Could not delete partially downloaded model at " + cachedFile.toString());
            }
            throw e;
        }
    } else {
        log.info("Using cached model at " + cachedFile.toString());
    }

    long expectedChecksum = pretrainedChecksum(pretrainedType);
    if (expectedChecksum != 0L) {
        log.info("Verifying download...");
        Checksum adler = new Adler32();
        FileUtils.checksum(cachedFile, adler);
        long localChecksum = adler.getValue();
        log.info("Checksum local is " + localChecksum + ", expecting " + expectedChecksum);

        if (expectedChecksum != localChecksum) {
            log.error("Checksums do not match. Cleaning up files and failing...");
            cachedFile.delete();
            throw new IllegalStateException(
                            "Pretrained model file failed checksum. If this error persists, please open an issue at https://github.com/deeplearning4j/deeplearning4j.");
        }
    }

    if (modelType() == MultiLayerNetwork.class) {
        return (M) ModelSerializer.restoreMultiLayerNetwork(cachedFile);
    } else if (modelType() == ComputationGraph.class) {
        return (M) ModelSerializer.restoreComputationGraph(cachedFile);
    } else {
        throw new UnsupportedOperationException(
                        "Pretrained models are only supported for MultiLayerNetwork and ComputationGraph.");
    }
}
 
Example 17
Source File: RegressionTest060.java — from deeplearning4j (Apache License 2.0), 4 votes
/**
 * Restores the stored 060 MLP regression model (with updater) and checks both layers'
 * configuration. It also checks that parameters and RmsProp updater state match the 1..n ramps
 * the file is expected to contain.
 */
@Test
public void regressionTestMLP2() throws Exception {

    File f = Resources.asFile("regression_testing/060/060_ModelSerializer_Regression_MLP_2.zip");

    MultiLayerNetwork net = ModelSerializer.restoreMultiLayerNetwork(f, true);

    MultiLayerConfiguration conf = net.getLayerWiseConfigurations();
    assertEquals(2, conf.getConfs().size());

    DenseLayer l0 = (DenseLayer) conf.getConf(0).getLayer();
    assertTrue(l0.getActivationFn() instanceof ActivationLReLU);
    assertEquals(3, l0.getNIn());
    assertEquals(4, l0.getNOut());
    assertEquals(new WeightInitDistribution(new NormalDistribution(0.1, 1.2)), l0.getWeightInitFn());
    assertEquals(new RmsProp(0.15, 0.96, RmsProp.DEFAULT_RMSPROP_EPSILON), l0.getIUpdater());
    assertEquals(0.15, ((RmsProp) l0.getIUpdater()).getLearningRate(), 1e-6);
    assertEquals(new Dropout(0.6), l0.getIDropout());
    assertEquals(0.1, TestUtils.getL1(l0), 1e-6);
    assertEquals(new WeightDecay(0.2, false), TestUtils.getWeightDecayReg(l0));
    assertEquals(GradientNormalization.ClipElementWiseAbsoluteValue, l0.getGradientNormalization());
    assertEquals(1.5, l0.getGradientNormalizationThreshold(), 1e-5);

    // Every assertion in this section must target the output layer (l1).
    OutputLayer l1 = (OutputLayer) conf.getConf(1).getLayer();
    assertEquals("identity", l1.getActivationFn().toString());
    assertTrue(l1.getLossFn() instanceof LossMSE);
    assertEquals(4, l1.getNIn());
    assertEquals(5, l1.getNOut());
    assertEquals(new WeightInitDistribution(new NormalDistribution(0.1, 1.2)), l1.getWeightInitFn());
    assertEquals(new RmsProp(0.15, 0.96, RmsProp.DEFAULT_RMSPROP_EPSILON), l1.getIUpdater());
    assertEquals(0.15, ((RmsProp) l1.getIUpdater()).getLearningRate(), 1e-6);
    assertEquals(new Dropout(0.6), l1.getIDropout());
    assertEquals(0.1, TestUtils.getL1(l1), 1e-6);
    assertEquals(new WeightDecay(0.2, false), TestUtils.getWeightDecayReg(l1));
    assertEquals(GradientNormalization.ClipElementWiseAbsoluteValue, l1.getGradientNormalization());
    assertEquals(1.5, l1.getGradientNormalizationThreshold(), 1e-5);

    int numParams = (int) net.numParams();
    assertEquals(Nd4j.linspace(1, numParams, numParams, Nd4j.dataType()).reshape(1, numParams), net.params());
    int updaterSize = (int) new RmsProp().stateSize(numParams);
    // The expected updater state is shaped by its own length, not by the parameter count.
    assertEquals(Nd4j.linspace(1, updaterSize, updaterSize, Nd4j.dataType()).reshape(1, updaterSize), net.getUpdater().getStateViewArray());
}
 
Example 18
Source File: BidirectionalTest.java — from deeplearning4j (Apache License 2.0), 4 votes
/**
 * For every workspace mode, trains a two-layer Bidirectional(GravesLSTM) network for one step,
 * round-trips it through ModelSerializer, and checks that the restored copy produces the same
 * output, score and gradient on fresh random data.
 */
@Test
public void testSerialization() throws Exception {

    for (WorkspaceMode mode : WorkspaceMode.values()) {
        log.info("*** Starting workspace mode: " + mode);

        Nd4j.getRandom().setSeed(12345);

        MultiLayerConfiguration config = new NeuralNetConfiguration.Builder()
                .activation(Activation.TANH)
                .weightInit(WeightInit.XAVIER)
                .trainingWorkspaceMode(mode)
                .inferenceWorkspaceMode(mode)
                .updater(new Adam())
                .list()
                .layer(new Bidirectional(Bidirectional.Mode.ADD, new GravesLSTM.Builder().nIn(10).nOut(10).dataFormat(rnnDataFormat).build()))
                .layer(new Bidirectional(Bidirectional.Mode.ADD, new GravesLSTM.Builder().nIn(10).nOut(10).dataFormat(rnnDataFormat).build()))
                .layer(new RnnOutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE)
                        .nIn(10).nOut(10).dataFormat(rnnDataFormat).build())
                .build();

        MultiLayerNetwork original = new MultiLayerNetwork(config);
        original.init();

        // Minibatch 3, 10 features, 5 time steps, laid out according to rnnDataFormat.
        final long[] shape;
        if (rnnDataFormat == NCW) {
            shape = new long[]{3, 10, 5};
        } else {
            shape = new long[]{3, 5, 10};
        }

        // Arguments are evaluated left to right: features first, then labels.
        original.fit(Nd4j.rand(shape), Nd4j.rand(shape));

        byte[] serialized;
        try (ByteArrayOutputStream sink = new ByteArrayOutputStream()) {
            ModelSerializer.writeModel(original, sink, true);
            serialized = sink.toByteArray();
        }
        MultiLayerNetwork restored = ModelSerializer.restoreMultiLayerNetwork(new ByteArrayInputStream(serialized), true);

        INDArray features = Nd4j.rand(shape);
        INDArray targets = Nd4j.rand(shape);

        assertEquals(original.output(features), restored.output(features));

        original.setInput(features);
        restored.setInput(features);
        original.setLabels(targets);
        restored.setLabels(targets);

        original.computeGradientAndScore();
        restored.computeGradientAndScore();

        assertEquals(original.score(), restored.score(), 1e-6);
        assertEquals(original.gradient().gradient(), restored.gradient().gradient());
    }
}
 
Example 19
Source File: RegressionTest071.java — from deeplearning4j (Apache License 2.0), 4 votes
/**
 * Restores the stored 071 MLP regression model (with updater) and checks both layers'
 * configuration. It also checks that parameters and RmsProp updater state match the 1..n ramps
 * the file is expected to contain.
 */
@Test
public void regressionTestMLP2() throws Exception {

    File f = Resources.asFile("regression_testing/071/071_ModelSerializer_Regression_MLP_2.zip");

    MultiLayerNetwork net = ModelSerializer.restoreMultiLayerNetwork(f, true);

    MultiLayerConfiguration conf = net.getLayerWiseConfigurations();
    assertEquals(2, conf.getConfs().size());

    DenseLayer l0 = (DenseLayer) conf.getConf(0).getLayer();
    assertTrue(l0.getActivationFn() instanceof ActivationLReLU);
    assertEquals(3, l0.getNIn());
    assertEquals(4, l0.getNOut());
    assertEquals(new WeightInitDistribution(new NormalDistribution(0.1, 1.2)), l0.getWeightInitFn());
    assertEquals(new RmsProp(0.15, 0.96, RmsProp.DEFAULT_RMSPROP_EPSILON), l0.getIUpdater());
    assertEquals(0.15, ((RmsProp) l0.getIUpdater()).getLearningRate(), 1e-6);
    assertEquals(new Dropout(0.6), l0.getIDropout());
    assertEquals(0.1, TestUtils.getL1(l0), 1e-6);
    assertEquals(new WeightDecay(0.2, false), TestUtils.getWeightDecayReg(l0));
    assertEquals(GradientNormalization.ClipElementWiseAbsoluteValue, l0.getGradientNormalization());
    assertEquals(1.5, l0.getGradientNormalizationThreshold(), 1e-5);

    // Every assertion in this section must target the output layer (l1).
    OutputLayer l1 = (OutputLayer) conf.getConf(1).getLayer();
    assertTrue(l1.getActivationFn() instanceof ActivationIdentity);
    assertTrue(l1.getLossFn() instanceof LossMSE);
    assertEquals(4, l1.getNIn());
    assertEquals(5, l1.getNOut());
    assertEquals(new WeightInitDistribution(new NormalDistribution(0.1, 1.2)), l1.getWeightInitFn());
    assertEquals(new RmsProp(0.15, 0.96, RmsProp.DEFAULT_RMSPROP_EPSILON), l1.getIUpdater());
    assertEquals(0.15, ((RmsProp) l1.getIUpdater()).getLearningRate(), 1e-6);
    assertEquals(new Dropout(0.6), l1.getIDropout());
    assertEquals(0.1, TestUtils.getL1(l1), 1e-6);
    assertEquals(new WeightDecay(0.2, false), TestUtils.getWeightDecayReg(l1));
    assertEquals(GradientNormalization.ClipElementWiseAbsoluteValue, l1.getGradientNormalization());
    assertEquals(1.5, l1.getGradientNormalizationThreshold(), 1e-5);

    long numParams = net.numParams();
    assertEquals(Nd4j.linspace(1, numParams, numParams).reshape(1, numParams), net.params());
    int updaterSize = (int) new RmsProp().stateSize(numParams);
    // The expected updater state is shaped by its own length, not by the parameter count.
    assertEquals(Nd4j.linspace(1, updaterSize, updaterSize).reshape(1, updaterSize), net.getUpdater().getStateViewArray());
}
 
Example 20
Source File: PLNetDyadRanker.java — from AILibs (GNU Affero General Public License v3.0), 2 votes
/**
 * Replaces this ranker's network with one restored from {@code filePath}.
 * Warning: the restored network is not validated. No check is made that it is a PLNet or that
 * it matches this object's configuration.
 *
 * @param filePath
 *            path of the serialized MultiLayerNetwork to load.
 * @throws IOException
 *            if the file cannot be read or restored.
 */
public void loadModelFromFile(final String filePath) throws IOException {
	this.plNet = ModelSerializer.restoreMultiLayerNetwork(filePath);
}