Java Code Examples for org.deeplearning4j.nn.graph.ComputationGraph#getConfiguration()

The following examples show how to use org.deeplearning4j.nn.graph.ComputationGraph#getConfiguration(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may check out the related API usage on the sidebar.
Example 1
Source File: TestUtils.java    From deeplearning4j with Apache License 2.0 6 votes vote down vote up
/**
 * Round-trips a graph through {@code ModelSerializer} using in-memory buffers and
 * asserts that the restored copy has an equal configuration and equal parameters.
 * Afterwards the original graph's configuration is passed to
 * {@code serializeDeserializeJava}.
 *
 * @param net the graph to write out and read back
 * @return the graph instance restored from the serialized bytes
 * @throws RuntimeException wrapping any {@link IOException} raised while writing or restoring
 */
public static ComputationGraph testModelSerialization(ComputationGraph net){
    ComputationGraph roundTripped;
    try {
        // NOTE(review): the 'true' flags presumably include updater state on save/load - confirm against ModelSerializer
        ByteArrayOutputStream sink = new ByteArrayOutputStream();
        ModelSerializer.writeModel(net, sink, true);

        roundTripped = ModelSerializer.restoreComputationGraph(
                        new ByteArrayInputStream(sink.toByteArray()), true);

        assertEquals(net.getConfiguration(), roundTripped.getConfiguration());
        assertEquals(net.params(), roundTripped.params());
    } catch (IOException ioFailure){
        // Not expected for in-memory streams; rethrow unchecked with the cause attached
        throw new RuntimeException(ioFailure);
    }

    // The ComputationGraphConfiguration must also be serializable (required by Spark etc)
    serializeDeserializeJava(net.getConfiguration());

    return roundTripped;
}
 
Example 2
Source File: DTypeTests.java    From deeplearning4j with Apache License 2.0 6 votes vote down vote up
/**
 * Records the classes used by a graph: every vertex class goes into {@code seenVertices};
 * for layer vertices the layer class goes into {@code seenLayers} and a non-null
 * preprocessor's class into {@code seenPreprocs}; for preprocessor vertices the
 * preprocessor's class goes into {@code seenPreprocs}.
 *
 * @param net the graph whose configured vertices are inspected
 */
public static void logUsedClasses(ComputationGraph net) {
    for (GraphVertex vertex : net.getConfiguration().getVertices().values()) {
        seenVertices.add(vertex.getClass());

        if (vertex instanceof LayerVertex) {
            LayerVertex layerVertex = (LayerVertex) vertex;
            seenLayers.add(layerVertex.getLayerConf().getLayer().getClass());

            // A layer vertex may have no preprocessor attached
            InputPreProcessor preProcessor = layerVertex.getPreProcessor();
            if (preProcessor != null) {
                seenPreprocs.add(preProcessor.getClass());
            }
            continue;
        }

        if (vertex instanceof PreprocessorVertex) {
            PreprocessorVertex preprocessorVertex = (PreprocessorVertex) vertex;
            seenPreprocs.add(preprocessorVertex.getPreProcessor().getClass());
        }
    }
}
 
Example 3
Source File: TestUtils.java    From deeplearning4j with Apache License 2.0 6 votes vote down vote up
/**
 * Round-trips a graph through {@code ModelSerializer} using in-memory buffers and
 * asserts that the restored copy has an equal configuration and equal parameters.
 * Afterwards the original graph's configuration is passed to
 * {@code serializeDeserializeJava}.
 *
 * @param net the graph to write out and read back
 * @return the graph instance restored from the serialized bytes
 * @throws RuntimeException wrapping any {@link IOException} raised while writing or restoring
 */
public static ComputationGraph testModelSerialization(ComputationGraph net){
    ComputationGraph roundTripped;
    try {
        // NOTE(review): the 'true' flags presumably include updater state on save/load - confirm against ModelSerializer
        ByteArrayOutputStream sink = new ByteArrayOutputStream();
        ModelSerializer.writeModel(net, sink, true);

        roundTripped = ModelSerializer.restoreComputationGraph(
                        new ByteArrayInputStream(sink.toByteArray()), true);

        assertEquals(net.getConfiguration(), roundTripped.getConfiguration());
        assertEquals(net.params(), roundTripped.params());
    } catch (IOException ioFailure){
        // Not expected for in-memory streams; rethrow unchecked with the cause attached
        throw new RuntimeException(ioFailure);
    }

    // The ComputationGraphConfiguration must also be serializable (required by Spark etc)
    serializeDeserializeJava(net.getConfiguration());

    return roundTripped;
}
 
Example 4
Source File: TestUtils.java    From deeplearning4j with Apache License 2.0 6 votes vote down vote up
/**
 * Serializes the given graph to an in-memory byte array, restores it, and asserts
 * that configuration and parameters of the restored graph equal those of the input.
 * The input graph's configuration is then handed to {@code serializeDeserializeJava}.
 *
 * @param net the graph to write out and read back
 * @return the graph instance restored from the serialized bytes
 * @throws RuntimeException wrapping any {@link IOException} raised while writing or restoring
 */
public static ComputationGraph testModelSerialization(ComputationGraph net){
    ComputationGraph reloaded;
    try {
        // NOTE(review): the 'true' flags presumably include updater state on save/load - confirm against ModelSerializer
        ByteArrayOutputStream buffer = new ByteArrayOutputStream();
        ModelSerializer.writeModel(net, buffer, true);

        byte[] serialized = buffer.toByteArray();
        reloaded = ModelSerializer.restoreComputationGraph(new ByteArrayInputStream(serialized), true);

        assertEquals(net.getConfiguration(), reloaded.getConfiguration());
        assertEquals(net.params(), reloaded.params());
    } catch (IOException ioFailure){
        // Not expected for in-memory streams; rethrow unchecked with the cause attached
        throw new RuntimeException(ioFailure);
    }

    // The ComputationGraphConfiguration must also be serializable (required by Spark etc)
    serializeDeserializeJava(net.getConfiguration());

    return reloaded;
}
 
Example 5
Source File: TestUtils.java    From deeplearning4j with Apache License 2.0 6 votes vote down vote up
/**
 * Serializes the given graph to an in-memory byte array, restores it, and asserts
 * that configuration and parameters of the restored graph equal those of the input.
 * The input graph's configuration is then handed to {@code serializeDeserializeJava}.
 *
 * @param net the graph to write out and read back
 * @return the graph instance restored from the serialized bytes
 * @throws RuntimeException wrapping any {@link IOException} raised while writing or restoring
 */
public static ComputationGraph testModelSerialization(ComputationGraph net){
    ComputationGraph reloaded;
    try {
        // NOTE(review): the 'true' flags presumably include updater state on save/load - confirm against ModelSerializer
        ByteArrayOutputStream buffer = new ByteArrayOutputStream();
        ModelSerializer.writeModel(net, buffer, true);

        byte[] serialized = buffer.toByteArray();
        reloaded = ModelSerializer.restoreComputationGraph(new ByteArrayInputStream(serialized), true);

        assertEquals(net.getConfiguration(), reloaded.getConfiguration());
        assertEquals(net.params(), reloaded.params());
    } catch (IOException ioFailure){
        // Not expected for in-memory streams; rethrow unchecked with the cause attached
        throw new RuntimeException(ioFailure);
    }

    // The ComputationGraphConfiguration must also be serializable (required by Spark etc)
    serializeDeserializeJava(net.getConfiguration());

    return reloaded;
}
 
Example 6
Source File: RegressionTest080.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
/**
 * Restores the archived graph stored under {@code regression_testing/080} and checks
 * that its three vertices carry the expected layer types and settings.
 */
@Test
public void regressionTestCGLSTM1() throws Exception {
    File modelFile = Resources.asFile("regression_testing/080/080_ModelSerializer_Regression_CG_LSTM_1.zip");
    ComputationGraph graph = ModelSerializer.restoreComputationGraph(modelFile, true);

    ComputationGraphConfiguration graphConf = graph.getConfiguration();
    assertEquals(3, graphConf.getVertices().size());

    // Vertex "0": GravesLSTM, 3 -> 4, tanh activation, element-wise gradient clipping at 1.5
    LayerVertex lstmVertex = (LayerVertex) graphConf.getVertices().get("0");
    GravesLSTM lstm = (GravesLSTM) lstmVertex.getLayerConf().getLayer();
    assertTrue(lstm.getActivationFn() instanceof ActivationTanH);
    assertEquals(3, lstm.getNIn());
    assertEquals(4, lstm.getNOut());
    assertEquals(GradientNormalization.ClipElementWiseAbsoluteValue, lstm.getGradientNormalization());
    assertEquals(1.5, lstm.getGradientNormalizationThreshold(), 1e-5);

    // Vertex "1": GravesBidirectionalLSTM, 4 -> 4, softsign activation, same clipping
    LayerVertex biLstmVertex = (LayerVertex) graphConf.getVertices().get("1");
    GravesBidirectionalLSTM biLstm = (GravesBidirectionalLSTM) biLstmVertex.getLayerConf().getLayer();
    assertTrue(biLstm.getActivationFn() instanceof ActivationSoftSign);
    assertEquals(4, biLstm.getNIn());
    assertEquals(4, biLstm.getNOut());
    assertEquals(GradientNormalization.ClipElementWiseAbsoluteValue, biLstm.getGradientNormalization());
    assertEquals(1.5, biLstm.getGradientNormalizationThreshold(), 1e-5);

    // Vertex "2": RnnOutputLayer, 4 -> 5, softmax activation with LossMCXENT
    LayerVertex outputVertex = (LayerVertex) graphConf.getVertices().get("2");
    RnnOutputLayer outputLayer = (RnnOutputLayer) outputVertex.getLayerConf().getLayer();
    assertEquals(4, outputLayer.getNIn());
    assertEquals(5, outputLayer.getNOut());
    assertTrue(outputLayer.getActivationFn() instanceof ActivationSoftmax);
    assertTrue(outputLayer.getLossFn() instanceof LossMCXENT);
}
 
Example 7
Source File: RegressionTest071.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
/**
 * Restores the archived graph stored under {@code regression_testing/071} and checks
 * that its three vertices carry the expected layer types and settings. Activations
 * are compared through their {@code toString()} value.
 */
@Test
public void regressionTestCGLSTM1() throws Exception {
    File modelFile = Resources.asFile("regression_testing/071/071_ModelSerializer_Regression_CG_LSTM_1.zip");
    ComputationGraph graph = ModelSerializer.restoreComputationGraph(modelFile, true);

    ComputationGraphConfiguration graphConf = graph.getConfiguration();
    assertEquals(3, graphConf.getVertices().size());

    // Vertex "0": GravesLSTM, 3 -> 4, "tanh" activation, element-wise gradient clipping at 1.5
    LayerVertex lstmVertex = (LayerVertex) graphConf.getVertices().get("0");
    GravesLSTM lstm = (GravesLSTM) lstmVertex.getLayerConf().getLayer();
    assertEquals("tanh", lstm.getActivationFn().toString());
    assertEquals(3, lstm.getNIn());
    assertEquals(4, lstm.getNOut());
    assertEquals(GradientNormalization.ClipElementWiseAbsoluteValue, lstm.getGradientNormalization());
    assertEquals(1.5, lstm.getGradientNormalizationThreshold(), 1e-5);

    // Vertex "1": GravesBidirectionalLSTM, 4 -> 4, "softsign" activation, same clipping
    LayerVertex biLstmVertex = (LayerVertex) graphConf.getVertices().get("1");
    GravesBidirectionalLSTM biLstm = (GravesBidirectionalLSTM) biLstmVertex.getLayerConf().getLayer();
    assertEquals("softsign", biLstm.getActivationFn().toString());
    assertEquals(4, biLstm.getNIn());
    assertEquals(4, biLstm.getNOut());
    assertEquals(GradientNormalization.ClipElementWiseAbsoluteValue, biLstm.getGradientNormalization());
    assertEquals(1.5, biLstm.getGradientNormalizationThreshold(), 1e-5);

    // Vertex "2": RnnOutputLayer, 4 -> 5, "softmax" activation with LossMCXENT
    LayerVertex outputVertex = (LayerVertex) graphConf.getVertices().get("2");
    RnnOutputLayer outputLayer = (RnnOutputLayer) outputVertex.getLayerConf().getLayer();
    assertEquals(4, outputLayer.getNIn());
    assertEquals(5, outputLayer.getNOut());
    assertEquals("softmax", outputLayer.getActivationFn().toString());
    assertTrue(outputLayer.getLossFn() instanceof LossMCXENT);
}
 
Example 8
Source File: RegressionTest060.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
/**
 * Restores the archived graph stored under {@code regression_testing/060} and checks
 * that its three vertices carry the expected layer types and settings. Activations
 * are compared through their {@code toString()} value.
 */
@Test
public void regressionTestCGLSTM1() throws Exception {
    File modelFile = Resources.asFile("regression_testing/060/060_ModelSerializer_Regression_CG_LSTM_1.zip");
    ComputationGraph graph = ModelSerializer.restoreComputationGraph(modelFile, true);

    ComputationGraphConfiguration graphConf = graph.getConfiguration();
    assertEquals(3, graphConf.getVertices().size());

    // Vertex "0": GravesLSTM, 3 -> 4, "tanh" activation, element-wise gradient clipping at 1.5
    LayerVertex lstmVertex = (LayerVertex) graphConf.getVertices().get("0");
    GravesLSTM lstm = (GravesLSTM) lstmVertex.getLayerConf().getLayer();
    assertEquals("tanh", lstm.getActivationFn().toString());
    assertEquals(3, lstm.getNIn());
    assertEquals(4, lstm.getNOut());
    assertEquals(GradientNormalization.ClipElementWiseAbsoluteValue, lstm.getGradientNormalization());
    assertEquals(1.5, lstm.getGradientNormalizationThreshold(), 1e-5);

    // Vertex "1": GravesBidirectionalLSTM, 4 -> 4, "softsign" activation, same clipping
    LayerVertex biLstmVertex = (LayerVertex) graphConf.getVertices().get("1");
    GravesBidirectionalLSTM biLstm = (GravesBidirectionalLSTM) biLstmVertex.getLayerConf().getLayer();
    assertEquals("softsign", biLstm.getActivationFn().toString());
    assertEquals(4, biLstm.getNIn());
    assertEquals(4, biLstm.getNOut());
    assertEquals(GradientNormalization.ClipElementWiseAbsoluteValue, biLstm.getGradientNormalization());
    assertEquals(1.5, biLstm.getGradientNormalizationThreshold(), 1e-5);

    // Vertex "2": RnnOutputLayer, 4 -> 5, "softmax" activation with LossMCXENT
    LayerVertex outputVertex = (LayerVertex) graphConf.getVertices().get("2");
    RnnOutputLayer outputLayer = (RnnOutputLayer) outputVertex.getLayerConf().getLayer();
    assertEquals(4, outputLayer.getNIn());
    assertEquals(5, outputLayer.getNOut());
    assertEquals("softmax", outputLayer.getActivationFn().toString());
    assertTrue(outputLayer.getLossFn() instanceof LossMCXENT);
}