Java Code Examples for org.deeplearning4j.nn.graph.ComputationGraph#setLayerMaskArrays()

The following examples show how to use org.deeplearning4j.nn.graph.ComputationGraph#setLayerMaskArrays(). You can vote up the examples you like or vote down the ones you don't, and you can go to the original project or source file by following the links above each example. You can also check out the related API usage in the sidebar.
Example 1
Source File: GradientCheckTestsComputationGraph.java — from deeplearning4j (Apache License 2.0) · 4 votes
@Test
    public void testLSTMWithReverseTimeSeriesVertex() {
        // Gradient check for a graph in which one LSTM ("lstm_a") reads the input as-is and a second LSTM
        // ("lstm_b") reads it time-reversed via ReverseTimeSeriesVertex; lstm_b's output is reversed back
        // before both are fed into the RNN output layer. Checked first without, then with, an input mask.

        Nd4j.getRandom().setSeed(12345);
        ComputationGraphConfiguration conf =
                new NeuralNetConfiguration.Builder().seed(12345)
                        .dataType(DataType.DOUBLE)
                        .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                        .dist(new NormalDistribution(0, 1))
                        .updater(new NoOp()).graphBuilder()
                        .addInputs("input").setOutputs("out")
                        .addLayer("lstm_a",
                                new LSTM.Builder().nIn(2).nOut(3)
                                        .activation(Activation.TANH).build(),
                                "input")
                        .addVertex("input_rev", new ReverseTimeSeriesVertex("input"), "input")
                        .addLayer("lstm_b",
                                new LSTM.Builder().nIn(2).nOut(3)
                                        .activation(Activation.TANH).build(),
                                "input_rev")
                        .addVertex("lstm_b_rev", new ReverseTimeSeriesVertex("input"), "lstm_b")
                        .addLayer("out", new RnnOutputLayer.Builder().nIn(3 + 3).nOut(2)
                                        .activation(Activation.SOFTMAX)
                                        .lossFunction(LossFunctions.LossFunction.MCXENT).build(),
                                "lstm_a", "lstm_b_rev")
                        .build();

        ComputationGraph graph = new ComputationGraph(conf);
        graph.init();

        // Input is [minibatch, nIn = 2, timeSeriesLength]; labels are a one-hot time series over 2 classes.
        int minibatch = 2;
        int tsLength = 4;
        INDArray input  = Nd4j.rand(new int[] {minibatch, 2, tsLength});
        INDArray labels = TestUtils.randomOneHotTimeSeries(minibatch, 2, tsLength);

        String msg = "testLSTMWithReverseTimeSeriesVertex()";
        if (PRINT_RESULTS) {
            System.out.println(msg);
        }

        // First: no mask arrays.
        boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(graph)
                .inputs(new INDArray[]{input})
                .labels(new INDArray[]{labels}));
        assertTrue(msg, gradOK);

        // Second: with an input mask. It must be [minibatch, tsLength] to match `input`.
        // Row 0 masks out the final time step; row 1 masks out the last two (variable-length sequences).
        INDArray inMask = Nd4j.zeros(minibatch, tsLength);
        inMask.putRow(0, Nd4j.create(new double[] {1, 1, 1, 0}));
        inMask.putRow(1, Nd4j.create(new double[] {1, 1, 0, 0}));
        graph.setLayerMaskArrays(new INDArray[] {inMask}, null);
        // The mask is also handed to the gradient checker via inputMask(...). A mask installed only via
        // setLayerMaskArrays() is presumably replaced by the checker's own (null) masks, in which case the
        // masked code path would never be exercised. NOTE(review): confirm against GradientCheckUtil.
        gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(graph)
                .inputs(new INDArray[]{input})
                .labels(new INDArray[]{labels})
                .inputMask(new INDArray[]{inMask}));

        assertTrue(msg + " - with input mask", gradOK);
        TestUtils.testModelSerialization(graph);
    }
 
Example 2
Source File: GradientCheckTestsComputationGraph.java — from deeplearning4j (Apache License 2.0) · 4 votes
@Test
    public void testBasicStackUnstackVariableLengthTS() {
        // Two sequence inputs of different lengths ("in1", "in2") each pass through their own SimpleRnn.
        // They are combined by a StackVertex, processed by a shared SimpleRnn, split again by
        // UnstackVertex, and reduced by average global pooling before two L2 output layers.
        // Gradients are checked with input masks for several minibatch sizes.

        int layerSizes = 2;
        // Sequence lengths of the two inputs; in each, every time step except the last one is unmasked.
        int tsLength1 = 4;
        int tsLength2 = 5;

        Nd4j.getRandom().setSeed(12345);
        ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345)
                        .dataType(DataType.DOUBLE)
                        .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                        .dist(new NormalDistribution(0, 1))
                        .activation(Activation.TANH).updater(new NoOp()).graphBuilder()
                        .addInputs("in1", "in2")
                        .addLayer("d0", new SimpleRnn.Builder().nIn(layerSizes).nOut(layerSizes).build(), "in1")
                        .addLayer("d1", new SimpleRnn.Builder().nIn(layerSizes).nOut(layerSizes).build(), "in2")
                        .addVertex("stack", new StackVertex(), "d0", "d1")
                        .addLayer("d2", new SimpleRnn.Builder().nIn(layerSizes).nOut(layerSizes).build(), "stack")
                        .addVertex("u1", new UnstackVertex(0, 2), "d2").addVertex("u2", new UnstackVertex(1, 2), "d2")
                        .addLayer("p1", new GlobalPoolingLayer.Builder(PoolingType.AVG).build(), "u1")
                        .addLayer("p2", new GlobalPoolingLayer.Builder(PoolingType.AVG).build(), "u2")
                        .addLayer("out1", new OutputLayer.Builder().lossFunction(LossFunctions.LossFunction.L2)
                                        .nIn(layerSizes).nOut(layerSizes).activation(Activation.IDENTITY).build(), "p1")
                        .addLayer("out2", new OutputLayer.Builder().lossFunction(LossFunctions.LossFunction.L2)
                                        .nIn(layerSizes).nOut(layerSizes).activation(Activation.IDENTITY).build(), "p2")
                        .setOutputs("out1", "out2").build();

        ComputationGraph graph = new ComputationGraph(conf);
        graph.init();

        // Replace the initial parameters with reproducible random values.
        Nd4j.getRandom().setSeed(12345);
        long nParams = graph.numParams();
        INDArray newParams = Nd4j.rand(new long[]{1, nParams});
        graph.setParams(newParams);

        int[] mbSizes = new int[] {1, 2, 3};
        for (int minibatch : mbSizes) {

            INDArray in1 = Nd4j.rand(new int[] {minibatch, layerSizes, tsLength1});
            INDArray in2 = Nd4j.rand(new int[] {minibatch, layerSizes, tsLength2});
            // interval(a, b) excludes b, so all but the final time step are marked present.
            INDArray inMask1 = Nd4j.zeros(minibatch, tsLength1);
            inMask1.get(NDArrayIndex.all(), NDArrayIndex.interval(0, tsLength1 - 1)).assign(1);
            INDArray inMask2 = Nd4j.zeros(minibatch, tsLength2);
            inMask2.get(NDArrayIndex.all(), NDArrayIndex.interval(0, tsLength2 - 1)).assign(1);

            // Both output layers have nOut = layerSizes, so the label widths must follow it.
            INDArray labels1 = Nd4j.rand(new int[] {minibatch, layerSizes});
            INDArray labels2 = Nd4j.rand(new int[] {minibatch, layerSizes});

            String testName = "testBasicStackUnstackVariableLengthTS() - minibatch = " + minibatch;

            if (PRINT_RESULTS) {
                System.out.println(testName);
            }

            // Masks are installed on the graph and also passed explicitly to the gradient checker.
            graph.setLayerMaskArrays(new INDArray[] {inMask1, inMask2}, null);

            boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(graph).inputs(new INDArray[]{in1, in2})
                    .labels(new INDArray[]{labels1, labels2}).inputMask(new INDArray[]{inMask1, inMask2}));

            assertTrue(testName, gradOK);
            TestUtils.testModelSerialization(graph);
        }
    }
 
Example 3
Source File: TestLastTimeStepLayer.java — from deeplearning4j (Apache License 2.0) · 4 votes
@Test
public void testLastTimeStepVertex() {
    // A LastTimeStep-wrapped SimpleRnn should emit only the wrapped layer's output at the final time step.
    // When an input mask is set, it should instead emit each example's last unmasked time step.
    // rnnDataFormat (declared outside this method) selects NCW or NWC layout.

    ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
            .graphBuilder()
            .addInputs("in")
            .addLayer("lastTS",
                    new LastTimeStep(new SimpleRnn.Builder().nIn(5).nOut(6).dataFormat(rnnDataFormat).build()),
                    "in")
            .setOutputs("lastTS")
            .build();

    ComputationGraph graph = new ComputationGraph(conf);
    graph.init();

    final boolean channelsFirst = rnnDataFormat == RNNFormat.NCW;

    // Part 1 - no mask: output must equal the underlying RNN output at time index 5.
    Nd4j.getRandom().setSeed(12345);
    Layer lastTsLayer = graph.getLayer("lastTS");
    INDArray features = channelsFirst ? Nd4j.rand(3, 5, 6) : Nd4j.rand(3, 6, 5);
    INDArray rnnOut = ((LastTimeStepLayer) lastTsLayer).getUnderlying()
            .activate(features, false, LayerWorkspaceMgr.noWorkspaces());
    INDArray expected = channelsFirst
            ? rnnOut.get(NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.point(5))
            : rnnOut.get(NDArrayIndex.all(), NDArrayIndex.point(5), NDArrayIndex.all());

    INDArray actual = lastTsLayer.activate(features, false, LayerWorkspaceMgr.noWorkspaces());
    assertEquals(expected, actual);

    // Part 2 - with mask: example i's last unmasked step is lastUnmasked[i].
    double[][] maskRows = {
            {1, 1, 1, 0, 0, 0},
            {1, 1, 1, 1, 0, 0},
            {1, 1, 1, 1, 1, 0}
    };
    int[] lastUnmasked = {2, 3, 4};

    INDArray inMask = Nd4j.zeros(3, 6);
    for (int row = 0; row < maskRows.length; row++) {
        inMask.putRow(row, Nd4j.create(maskRows[row]));
    }
    graph.setLayerMaskArrays(new INDArray[]{inMask}, null);

    expected = Nd4j.zeros(3, 6);
    for (int row = 0; row < lastUnmasked.length; row++) {
        int t = lastUnmasked[row];
        INDArray step = channelsFirst
                ? rnnOut.get(NDArrayIndex.point(row), NDArrayIndex.all(), NDArrayIndex.point(t))
                : rnnOut.get(NDArrayIndex.point(row), NDArrayIndex.point(t), NDArrayIndex.all());
        expected.putRow(row, step);
    }

    actual = lastTsLayer.activate(features, false, LayerWorkspaceMgr.noWorkspaces());
    assertEquals(expected, actual);

    TestUtils.testModelSerialization(graph);
}
 
Example 4
Source File: TestGraphNodes.java — from deeplearning4j (Apache License 2.0) · 4 votes
@Test
public void testLastTimeStepVertex() {
    // LastTimeStepVertex should forward only the final time step of its input, or, with an input mask,
    // each example's last unmasked step. On the backward pass, the epsilon should go to that step only.
    // The configuration should also survive a JSON round trip.

    // The vertex emits the 5 feature values of the selected time step (see expOut below), so the output
    // layer takes nIn(5). This test drives the vertex directly and never feeds data through "out".
    ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().graphBuilder().addInputs("in")
                    .addVertex("lastTS", new LastTimeStepVertex("in"), "in")
                    .addLayer("out", new OutputLayer.Builder().nIn(5).nOut(1).activation(Activation.TANH)
                                    .lossFunction(LossFunctions.LossFunction.MSE).build(), "lastTS")
                    .setOutputs("out")
                    .build();

    ComputationGraph graph = new ComputationGraph(conf);
    graph.init();

    // First: no input mask, so every example should yield its final time step (index 5 of 6).
    Nd4j.getRandom().setSeed(12345);
    INDArray in = Nd4j.rand(new int[] {3, 5, 6});
    INDArray expOut = in.get(NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.point(5));

    GraphVertex gv = graph.getVertex("lastTS");
    gv.setInputs(in);
    // Forward pass:
    INDArray outFwd = gv.doForward(true, LayerWorkspaceMgr.noWorkspaces());
    assertEquals(expOut, outFwd);
    // Backward pass: the returned epsilon has the input's shape. Time steps 0..4 (interval with
    // inclusive end) receive zero, and step 5 receives exactly the epsilon that was set.
    gv.setEpsilon(expOut);
    Pair<Gradient, INDArray[]> pair = gv.doBackward(false, LayerWorkspaceMgr.noWorkspaces());
    INDArray eps = pair.getSecond()[0];
    assertArrayEquals(in.shape(), eps.shape());
    assertEquals(Nd4j.zeros(3, 5, 5),
                    eps.get(NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.interval(0, 4, true)));
    assertEquals(expOut, eps.get(NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.point(5)));

    // Second: with an input mask, examples 0, 1 and 2 should yield time steps 2, 3 and 4 respectively
    // (their last unmasked steps).
    INDArray inMask = Nd4j.zeros(3, 6);
    inMask.putRow(0, Nd4j.create(new double[] {1, 1, 1, 0, 0, 0}));
    inMask.putRow(1, Nd4j.create(new double[] {1, 1, 1, 1, 0, 0}));
    inMask.putRow(2, Nd4j.create(new double[] {1, 1, 1, 1, 1, 0}));
    graph.setLayerMaskArrays(new INDArray[] {inMask}, null);

    expOut = Nd4j.zeros(3, 5);
    expOut.putRow(0, in.get(NDArrayIndex.point(0), NDArrayIndex.all(), NDArrayIndex.point(2)));
    expOut.putRow(1, in.get(NDArrayIndex.point(1), NDArrayIndex.all(), NDArrayIndex.point(3)));
    expOut.putRow(2, in.get(NDArrayIndex.point(2), NDArrayIndex.all(), NDArrayIndex.point(4)));

    gv.setInputs(in);
    outFwd = gv.doForward(true, LayerWorkspaceMgr.noWorkspaces());
    assertEquals(expOut, outFwd);

    // Finally: the configuration must round-trip through JSON unchanged.
    String json = conf.toJson();
    ComputationGraphConfiguration conf2 = ComputationGraphConfiguration.fromJson(json);
    assertEquals(conf, conf2);
}