Java Code Examples for org.deeplearning4j.nn.multilayer.MultiLayerNetwork#rnnTimeStep()

The following examples show how to use org.deeplearning4j.nn.multilayer.MultiLayerNetwork#rnnTimeStep(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may check out the related API usage on the sidebar.
Example 1
Source File: TestInvalidInput.java — from deeplearning4j (Apache License 2.0), 4 votes
@Test
public void testInvalidRnnTimeStep() {
    // Scenario: call rnnTimeStep twice with a different number of examples per call,
    // without resetting the stored RNN state in between. The second call must fail
    // with an exception whose message mentions both "rnn" and "batch".
    String[] recurrentTypes = {"simple", "lstm", "graves"};
    for (String recurrentType : recurrentTypes) {

        Layer recurrent;
        if ("simple".equals(recurrentType)) {
            recurrent = new SimpleRnn.Builder().nIn(5).nOut(5).build();
        } else if ("lstm".equals(recurrentType)) {
            recurrent = new LSTM.Builder().nIn(5).nOut(5).build();
        } else if ("graves".equals(recurrentType)) {
            recurrent = new GravesLSTM.Builder().nIn(5).nOut(5).build();
        } else {
            throw new RuntimeException();
        }

        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .list()
                .layer(recurrent)
                .layer(new RnnOutputLayer.Builder().nIn(5).nOut(5).activation(Activation.SOFTMAX).build())
                .build();

        MultiLayerNetwork network = new MultiLayerNetwork(conf);
        network.init();

        // First call with 3 examples: afterwards layer 0 must hold non-empty stored state.
        network.rnnTimeStep(Nd4j.create(3, 5, 10));

        Map<String, INDArray> storedState = network.rnnGetPreviousState(0);
        assertNotNull(storedState);
        assertFalse(storedState.isEmpty());

        // Second call with 5 examples while the 3-example state is still stored.
        // fail() throws an AssertionError, which this catch (Exception) does not intercept.
        try {
            network.rnnTimeStep(Nd4j.create(5, 5, 10));
            fail("Expected Exception - " + recurrentType);
        } catch (Exception e) {
            log.error("", e);
            String message = e.getMessage();
            boolean mentionsRnnAndBatch = message != null
                    && message.contains("rnn")
                    && message.contains("batch");
            assertTrue(message, mentionsRnnAndBatch);
        }
    }
}
 
Example 2
Source File: WorkspaceTests.java — from deeplearning4j (Apache License 2.0), 4 votes
@Test
public void testRnnTimeStep() {
    // Smoke test (no assertions; passes if nothing throws). For every WorkspaceMode and
    // each recurrent layer type (0 = SimpleRnn, 1 = LSTM, 2 = GravesLSTM), build two
    // stacked 10-unit recurrent layers plus an RnnOutputLayer, both as a
    // MultiLayerNetwork and as an equivalent ComputationGraph. Then call rnnTimeStep
    // three times on each without clearing the stored state.
    for (WorkspaceMode mode : WorkspaceMode.values()) {
        for (int layerType = 0; layerType < 3; layerType++) {

            System.out.println("Starting test: " + mode + " - " + layerType);

            NeuralNetConfiguration.ListBuilder listBuilder = new NeuralNetConfiguration.Builder()
                    .weightInit(WeightInit.XAVIER)
                    .activation(Activation.TANH)
                    .inferenceWorkspaceMode(mode)
                    .trainingWorkspaceMode(mode)
                    .list();

            ComputationGraphConfiguration.GraphBuilder graphBuilder = new NeuralNetConfiguration.Builder()
                    .weightInit(WeightInit.XAVIER)
                    .activation(Activation.TANH)
                    .inferenceWorkspaceMode(mode)
                    .trainingWorkspaceMode(mode)
                    .graphBuilder()
                    .addInputs("in");

            if (layerType == 0) {
                listBuilder
                        .layer(new SimpleRnn.Builder().nIn(10).nOut(10).build())
                        .layer(new SimpleRnn.Builder().nIn(10).nOut(10).build());
                graphBuilder
                        .addLayer("0", new SimpleRnn.Builder().nIn(10).nOut(10).build(), "in")
                        .addLayer("1", new SimpleRnn.Builder().nIn(10).nOut(10).build(), "0");
            } else if (layerType == 1) {
                listBuilder
                        .layer(new LSTM.Builder().nIn(10).nOut(10).build())
                        .layer(new LSTM.Builder().nIn(10).nOut(10).build());
                graphBuilder
                        .addLayer("0", new LSTM.Builder().nIn(10).nOut(10).build(), "in")
                        .addLayer("1", new LSTM.Builder().nIn(10).nOut(10).build(), "0");
            } else if (layerType == 2) {
                listBuilder
                        .layer(new GravesLSTM.Builder().nIn(10).nOut(10).build())
                        .layer(new GravesLSTM.Builder().nIn(10).nOut(10).build());
                graphBuilder
                        .addLayer("0", new GravesLSTM.Builder().nIn(10).nOut(10).build(), "in")
                        .addLayer("1", new GravesLSTM.Builder().nIn(10).nOut(10).build(), "0");
            } else {
                throw new RuntimeException();
            }

            listBuilder.layer(new RnnOutputLayer.Builder().nIn(10).nOut(10).activation(Activation.SOFTMAX).build());
            graphBuilder
                    .addLayer("out", new RnnOutputLayer.Builder().nIn(10).nOut(10).activation(Activation.SOFTMAX).build(), "1")
                    .setOutputs("out");

            MultiLayerConfiguration mlnConf = listBuilder.build();
            ComputationGraphConfiguration graphConf = graphBuilder.build();

            MultiLayerNetwork multiLayer = new MultiLayerNetwork(mlnConf);
            multiLayer.init();

            ComputationGraph graph = new ComputationGraph(graphConf);
            graph.init();

            // All MultiLayerNetwork steps run first, then all ComputationGraph steps.
            for (int step = 0; step < 3; step++) {
                multiLayer.rnnTimeStep(Nd4j.rand(new int[]{3, 10, 5}));
            }

            for (int step = 0; step < 3; step++) {
                graph.rnnTimeStep(Nd4j.rand(new int[]{3, 10, 5}));
            }
        }
    }
}
 
Example 3
Source File: ValidateCudnnLSTM.java — from deeplearning4j (Apache License 2.0), 4 votes
@Test
public void validateImplMultiLayerRnnTimeStep() throws Exception {
    // Compares two identically-initialized MultiLayerNetworks using rnnTimeStep:
    //  - mln1: the LSTM layers' "helper" field is nulled, forcing the built-in path;
    //  - mln2: the LSTM layers keep their CudnnLSTMHelper.
    // Outputs must match step by step. Each workspace mode in the array is exercised:
    // the configuration uses the loop's wsm (it previously hard-coded NONE, so the
    // ENABLED iteration silently repeated the NONE case).
    for (WorkspaceMode wsm : new WorkspaceMode[]{WorkspaceMode.NONE, WorkspaceMode.ENABLED}) {
        Nd4j.getRandom().setSeed(12345);
        int minibatch = 10;
        int inputSize = 3;
        int lstmLayerSize = 4;
        int timeSeriesLength = 3;
        int tbpttLength = 5;
        int nOut = 2;

        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().updater(new NoOp())
                .inferenceWorkspaceMode(wsm).trainingWorkspaceMode(wsm)
                .cacheMode(CacheMode.NONE).seed(12345L)
                .dist(new NormalDistribution(0, 2)).list()
                .layer(0, new LSTM.Builder().nIn(inputSize).nOut(lstmLayerSize)
                        .gateActivationFunction(Activation.SIGMOID).activation(Activation.TANH).build())
                .layer(1, new LSTM.Builder().nIn(lstmLayerSize).nOut(lstmLayerSize)
                        .gateActivationFunction(Activation.SIGMOID).activation(Activation.TANH).build())
                .layer(2, new RnnOutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                        .activation(Activation.SOFTMAX).nIn(lstmLayerSize).nOut(nOut).build())
                .backpropType(BackpropType.TruncatedBPTT)
                .tBPTTLength(tbpttLength).build();

        // Separate clones of one config; the params-equality assert below confirms
        // both networks start from the same weights.
        MultiLayerNetwork mln1 = new MultiLayerNetwork(conf.clone());
        mln1.init();

        MultiLayerNetwork mln2 = new MultiLayerNetwork(conf.clone());
        mln2.init();

        assertEquals(mln1.params(), mln2.params());

        // Reflective access to the private "helper" field of the LSTM layer implementation.
        Field f = org.deeplearning4j.nn.layers.recurrent.LSTM.class.getDeclaredField("helper");
        f.setAccessible(true);

        // mln1: remove the helper from both LSTM layers.
        Layer l0 = mln1.getLayer(0);
        Layer l1 = mln1.getLayer(1);
        f.set(l0, null);
        f.set(l1, null);
        assertNull(f.get(l0));
        assertNull(f.get(l1));

        // mln2: both LSTM layers are expected to carry a CudnnLSTMHelper.
        l0 = mln2.getLayer(0);
        l1 = mln2.getLayer(1);
        assertTrue(f.get(l0) instanceof CudnnLSTMHelper);
        assertTrue(f.get(l1) instanceof CudnnLSTMHelper);

        // Repeated rnnTimeStep calls with no state reset in between, so each step
        // builds on the state stored by the previous call.
        for (int x = 0; x < 5; x++) {
            INDArray input = Nd4j.rand(new int[]{minibatch, inputSize, timeSeriesLength});

            INDArray step1 = mln1.rnnTimeStep(input);
            INDArray step2 = mln2.rnnTimeStep(input);

            assertEquals("Step: " + x, step1, step2);
        }

        // rnnTimeStep is inference only; parameters must still be identical.
        assertEquals(mln1.params(), mln2.params());

        //Also check fit (mainly for workspaces sanity check):
        INDArray in = Nd4j.rand(new int[]{minibatch, inputSize, 3 * tbpttLength});
        INDArray label = TestUtils.randomOneHotTimeSeries(minibatch, nOut, 3 * tbpttLength);
        for (int i = 0; i < 3; i++) {
            mln1.fit(in, label);
            mln2.fit(in, label);
        }
    }
}