/*******************************************************************************
 * Copyright (c) 2015-2018 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

package org.deeplearning4j.nn.multilayer;

import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.*;
import org.deeplearning4j.nn.conf.distribution.NormalDistribution;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.GlobalPoolingLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
import org.deeplearning4j.nn.conf.layers.misc.FrozenLayer;
import org.deeplearning4j.nn.conf.preprocessor.FeedForwardToRnnPreProcessor;
import org.deeplearning4j.nn.conf.preprocessor.RnnToFeedForwardPreProcessor;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.layers.recurrent.BaseRecurrentLayer;
import org.deeplearning4j.nn.layers.recurrent.GravesLSTM;
import org.deeplearning4j.nn.layers.recurrent.LSTM;
import org.deeplearning4j.nn.layers.recurrent.SimpleRnn;
import org.deeplearning4j.nn.params.GravesLSTMParamInitializer;
import org.deeplearning4j.nn.weights.WeightInit;
import org.junit.Test;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction;
import org.nd4j.common.primitives.Pair;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;

import static org.junit.Assert.*;

@Slf4j
public class MultiLayerTestRNN extends BaseDL4JTest {

    @Test
    public void testGravesLSTMInit() {
        int nIn = 8;
        int nOut = 25;
        int nHiddenUnits = 17;
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().list()
                        .layer(0, new org.deeplearning4j.nn.conf.layers.GravesLSTM.Builder().nIn(nIn)
                                        .nOut(nHiddenUnits).activation(Activation.TANH).build())
                        .layer(1, new RnnOutputLayer.Builder(LossFunction.MSE).nIn(nHiddenUnits).nOut(nOut)
                                        .activation(Activation.TANH).build())
                        .build();
        MultiLayerNetwork network = new MultiLayerNetwork(conf);
        network.init();

        //Ensure that we have the correct number of weights and biases, and that these have the correct shape etc.
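        //Expected parameter layout for GravesLSTM (see GravesLSTMParamInitializer):
        // - input weights:     [nIn, 4*layerSize]            (4 = block input + input/forget/output gates)
        // - recurrent weights: [layerSize, 4*layerSize + 3]  (3 extra columns hold the peephole weights)
        // - biases:            [1, 4*layerSize]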
        Layer layer = network.getLayer(0);
        assertTrue(layer instanceof GravesLSTM);

        Map<String, INDArray> paramTable = layer.paramTable();
        assertEquals(3, paramTable.size()); //2 sets of weights, 1 set of biases

        INDArray recurrentWeights = paramTable.get(GravesLSTMParamInitializer.RECURRENT_WEIGHT_KEY);
        assertArrayEquals(new long[] {nHiddenUnits, 4 * nHiddenUnits + 3}, recurrentWeights.shape()); //Shape: [layerSize,4*layerSize+3]
        INDArray inputWeights = paramTable.get(GravesLSTMParamInitializer.INPUT_WEIGHT_KEY);
        assertArrayEquals(new long[] {nIn, 4 * nHiddenUnits}, inputWeights.shape()); //Shape: [nIn,4*layerSize]
        INDArray biases = paramTable.get(GravesLSTMParamInitializer.BIAS_KEY);
        assertArrayEquals(new long[] {1, 4 * nHiddenUnits}, biases.shape()); //Shape: [1,4*layerSize]

        //Want forget gate biases to be initialized to > 0. See parameter initializer for details
        INDArray forgetGateBiases = biases.get(NDArrayIndex.point(0),
                        NDArrayIndex.interval(nHiddenUnits, 2 * nHiddenUnits));
        INDArray gt = forgetGateBiases.gt(0);
        INDArray gtSum = gt.castTo(DataType.INT).sum(Integer.MAX_VALUE);
        int count = gtSum.getInt(0);
        assertEquals(nHiddenUnits, count);

        val nParams = recurrentWeights.length() + inputWeights.length() + biases.length();
        assertEquals(nParams, layer.numParams());
    }

    @Test
    public void testGravesLSTMInitStacked() {
        int nIn = 8;
        int nOut = 25;
        int[] nHiddenUnits = {17, 19, 23};
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().list()
                        .layer(0, new org.deeplearning4j.nn.conf.layers.GravesLSTM.Builder().nIn(nIn).nOut(17)
                                        .activation(Activation.TANH).build())
                        .layer(1, new org.deeplearning4j.nn.conf.layers.GravesLSTM.Builder().nIn(17).nOut(19)
                                        .activation(Activation.TANH).build())
                        .layer(2, new org.deeplearning4j.nn.conf.layers.GravesLSTM.Builder().nIn(19).nOut(23)
                                        .activation(Activation.TANH).build())
                        .layer(3, new RnnOutputLayer.Builder(LossFunction.MSE).nIn(23).nOut(nOut)
                                        .activation(Activation.TANH).build())
                        .build();
        MultiLayerNetwork network = new MultiLayerNetwork(conf);
        network.init();

        //Ensure that we have the correct number of weights and biases, and that these have the correct shape, for each layer
        for (int i = 0; i < nHiddenUnits.length; i++) {
            Layer layer = network.getLayer(i);
            assertTrue(layer instanceof GravesLSTM);

            Map<String, INDArray> paramTable = layer.paramTable();
            assertEquals(3, paramTable.size()); //2 sets of weights, 1 set of biases

            int layerNIn = (i == 0 ? nIn : nHiddenUnits[i - 1]);

            INDArray recurrentWeights = paramTable.get(GravesLSTMParamInitializer.RECURRENT_WEIGHT_KEY);
            assertArrayEquals(new long[] {nHiddenUnits[i], 4 * nHiddenUnits[i] + 3}, recurrentWeights.shape()); //Shape: [layerSize,4*layerSize+3]
            INDArray inputWeights = paramTable.get(GravesLSTMParamInitializer.INPUT_WEIGHT_KEY);
            assertArrayEquals(new long[] {layerNIn, 4 * nHiddenUnits[i]}, inputWeights.shape()); //Shape: [nIn,4*layerSize]
            INDArray biases = paramTable.get(GravesLSTMParamInitializer.BIAS_KEY);
            assertArrayEquals(new long[] {1, 4 * nHiddenUnits[i]}, biases.shape()); //Shape: [1,4*layerSize]

            //Want forget gate biases to be initialized to > 0. See parameter initializer for details
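            //The forget gate biases are deliberately initialized to a positive value (1.0 by default,
            //configurable via forgetGateBiasInit) so the gates start out "open" - a standard LSTM trick
            //to improve gradient flow early in training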
            INDArray forgetGateBiases = biases.get(NDArrayIndex.point(0),
                            NDArrayIndex.interval(nHiddenUnits[i], 2 * nHiddenUnits[i]));
            INDArray gt = forgetGateBiases.gt(0).castTo(DataType.INT);
            INDArray gtSum = gt.sum(Integer.MAX_VALUE);
            double count = gtSum.getDouble(0);
            assertEquals(nHiddenUnits[i], (int) count);

            val nParams = recurrentWeights.length() + inputWeights.length() + biases.length();
            assertEquals(nParams, layer.numParams());
        }
    }

    @Test
    public void testRnnStateMethods() {
        Nd4j.getRandom().setSeed(12345);
        int timeSeriesLength = 6;

        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().list()
                        .layer(0, new org.deeplearning4j.nn.conf.layers.GravesLSTM.Builder().nIn(5).nOut(7)
                                        .activation(Activation.TANH).dist(new NormalDistribution(0, 0.5)).build())
                        .layer(1, new org.deeplearning4j.nn.conf.layers.GravesLSTM.Builder().nIn(7).nOut(8)
                                        .activation(Activation.TANH).dist(new NormalDistribution(0, 0.5)).build())
                        .layer(2, new RnnOutputLayer.Builder(LossFunction.MCXENT).nIn(8).nOut(4)
                                        .activation(Activation.SOFTMAX).dist(new NormalDistribution(0, 0.5)).build())
                        .build();
        MultiLayerNetwork mln = new MultiLayerNetwork(conf);
        mln.init();

        INDArray input = Nd4j.rand(new int[] {3, 5, timeSeriesLength});

        List<INDArray> allOutputActivations = mln.feedForward(input, true);
        INDArray outAct = allOutputActivations.get(3);

        INDArray outRnnTimeStep = mln.rnnTimeStep(input);

        assertEquals(outAct, outRnnTimeStep); //Should be identical here

        Map<String, INDArray> currStateL0 = mln.rnnGetPreviousState(0);
        Map<String, INDArray> currStateL1 = mln.rnnGetPreviousState(1);

        assertEquals(2, currStateL0.size());
        assertEquals(2, currStateL1.size());

        INDArray lastActL0 = currStateL0.get(GravesLSTM.STATE_KEY_PREV_ACTIVATION);
        INDArray lastMemL0 = currStateL0.get(GravesLSTM.STATE_KEY_PREV_MEMCELL);
        assertTrue(lastActL0 != null && lastMemL0 != null);

        INDArray lastActL1 = currStateL1.get(GravesLSTM.STATE_KEY_PREV_ACTIVATION);
        INDArray lastMemL1 = currStateL1.get(GravesLSTM.STATE_KEY_PREV_MEMCELL);
        assertTrue(lastActL1 != null && lastMemL1 != null);

        INDArray expectedLastActL0 = allOutputActivations.get(1).tensorAlongDimension(timeSeriesLength - 1, 1, 0);
        assertEquals(expectedLastActL0, lastActL0);

        INDArray expectedLastActL1 = allOutputActivations.get(2).tensorAlongDimension(timeSeriesLength - 1, 1, 0);
        assertEquals(expectedLastActL1, lastActL1);

        //Check clearing and setting of state:
        mln.rnnClearPreviousState();
        assertTrue(mln.rnnGetPreviousState(0).isEmpty());
        assertTrue(mln.rnnGetPreviousState(1).isEmpty());
        mln.rnnSetPreviousState(0, currStateL0);
        assertEquals(2, mln.rnnGetPreviousState(0).size());
        mln.rnnSetPreviousState(1, currStateL1);
        assertEquals(2, mln.rnnGetPreviousState(1).size());
    }
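    //rnnTimeStep should produce exactly the same output as a single full-length feedForward, no matter
    //how the input sequence is chunked, because the stored layer state carries over between calls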
    @Test
    public void testRnnTimeStepLayers() {
        for (int layerType = 0; layerType < 3; layerType++) {
            org.deeplearning4j.nn.conf.layers.Layer l0;
            org.deeplearning4j.nn.conf.layers.Layer l1;
            String lastActKey;

            if (layerType == 0) {
                l0 = new org.deeplearning4j.nn.conf.layers.GravesLSTM.Builder().nIn(5).nOut(7)
                                .activation(Activation.TANH).dist(new NormalDistribution(0, 0.5)).build();
                l1 = new org.deeplearning4j.nn.conf.layers.GravesLSTM.Builder().nIn(7).nOut(8)
                                .activation(Activation.TANH).dist(new NormalDistribution(0, 0.5)).build();
                lastActKey = GravesLSTM.STATE_KEY_PREV_ACTIVATION;
            } else if (layerType == 1) {
                l0 = new org.deeplearning4j.nn.conf.layers.LSTM.Builder().nIn(5).nOut(7)
                                .activation(Activation.TANH).dist(new NormalDistribution(0, 0.5)).build();
                l1 = new org.deeplearning4j.nn.conf.layers.LSTM.Builder().nIn(7).nOut(8)
                                .activation(Activation.TANH).dist(new NormalDistribution(0, 0.5)).build();
                lastActKey = LSTM.STATE_KEY_PREV_ACTIVATION;
            } else {
                l0 = new org.deeplearning4j.nn.conf.layers.recurrent.SimpleRnn.Builder().nIn(5).nOut(7)
                                .activation(Activation.TANH).dist(new NormalDistribution(0, 0.5)).build();
                l1 = new org.deeplearning4j.nn.conf.layers.recurrent.SimpleRnn.Builder().nIn(7).nOut(8)
                                .activation(Activation.TANH).dist(new NormalDistribution(0, 0.5)).build();
                lastActKey = SimpleRnn.STATE_KEY_PREV_ACTIVATION;
            }
            log.info("Starting test for layer type: {}", l0.getClass().getSimpleName());

            Nd4j.getRandom().setSeed(12345);
            int timeSeriesLength = 12;

            //4 layer network: 2 recurrent layers + DenseLayer + RnnOutputLayer. Hence also tests preprocessors.
            MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345).list()
                            .layer(0, l0)
                            .layer(1, l1)
                            .layer(2, new DenseLayer.Builder().nIn(8).nOut(9).activation(Activation.TANH)
                                            .dist(new NormalDistribution(0, 0.5)).build())
                            .layer(3, new RnnOutputLayer.Builder(LossFunction.MCXENT).nIn(9).nOut(4)
                                            .activation(Activation.SOFTMAX).dist(new NormalDistribution(0, 0.5)).build())
                            .inputPreProcessor(2, new RnnToFeedForwardPreProcessor())
                            .inputPreProcessor(3, new FeedForwardToRnnPreProcessor())
                            .build();
            MultiLayerNetwork mln = new MultiLayerNetwork(conf);
            mln.init();

            INDArray input = Nd4j.rand(new int[] {3, 5, timeSeriesLength});

            List<INDArray> allOutputActivations = mln.feedForward(input, true);
            INDArray fullOutL0 = allOutputActivations.get(1);
            INDArray fullOutL1 = allOutputActivations.get(2);
            INDArray fullOutL3 = allOutputActivations.get(4);

            int[] inputLengths = {1, 2, 3, 4, 6, 12};

            //Do steps of length 1, then of length 2, ..., 12
            //Should get the same result regardless of step size; should be identical to standard forward pass
            for (int i = 0; i < inputLengths.length; i++) {
                int inLength = inputLengths[i];
                int nSteps = timeSeriesLength / inLength; //each of length inLength
                mln.rnnClearPreviousState();
                mln.setInputMiniBatchSize(1); //Reset; should be set by rnnTimeStep method

                for (int j = 0; j < nSteps; j++) {
                    int startTimeRange = j * inLength;
                    int endTimeRange = startTimeRange + inLength;

                    INDArray inputSubset;
                    if (inLength == 1) { //Workaround to nd4j bug
                        val sizes = new long[] {input.size(0), input.size(1), 1};
                        inputSubset = Nd4j.create(sizes);
                        inputSubset.tensorAlongDimension(0, 1, 0).assign(input.get(NDArrayIndex.all(),
                                        NDArrayIndex.all(), NDArrayIndex.point(startTimeRange)));
                    } else {
                        inputSubset = input.get(NDArrayIndex.all(), NDArrayIndex.all(),
                                        NDArrayIndex.interval(startTimeRange, endTimeRange));
                    }
                    if (inLength > 1)
                        assertEquals(inLength, inputSubset.size(2));

                    INDArray out = mln.rnnTimeStep(inputSubset);

                    INDArray expOutSubset;
                    if (inLength == 1) {
                        val sizes = new long[] {fullOutL3.size(0), fullOutL3.size(1), 1};
                        expOutSubset = Nd4j.create(DataType.FLOAT, sizes);
                        expOutSubset.tensorAlongDimension(0, 1, 0).assign(fullOutL3.get(NDArrayIndex.all(),
                                        NDArrayIndex.all(), NDArrayIndex.point(startTimeRange)));
                    } else {
                        expOutSubset = fullOutL3.get(NDArrayIndex.all(), NDArrayIndex.all(),
                                        NDArrayIndex.interval(startTimeRange, endTimeRange));
                    }

                    assertEquals(expOutSubset, out);

                    Map<String, INDArray> currL0State = mln.rnnGetPreviousState(0);
                    Map<String, INDArray> currL1State = mln.rnnGetPreviousState(1);

                    INDArray lastActL0 = currL0State.get(lastActKey);
                    INDArray lastActL1 = currL1State.get(lastActKey);

                    INDArray expLastActL0 = fullOutL0.tensorAlongDimension(endTimeRange - 1, 1, 0);
                    INDArray expLastActL1 = fullOutL1.tensorAlongDimension(endTimeRange - 1, 1, 0);

                    assertEquals(expLastActL0, lastActL0);
                    assertEquals(expLastActL1, lastActL1);
                }
            }
        }
    }

    @Test
    public void testRnnTimeStep2dInput() {
        Nd4j.getRandom().setSeed(12345);
        int timeSeriesLength = 6;

        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().list()
                        .layer(0, new org.deeplearning4j.nn.conf.layers.GravesLSTM.Builder().nIn(5).nOut(7)
                                        .activation(Activation.TANH).dist(new NormalDistribution(0, 0.5)).build())
                        .layer(1, new org.deeplearning4j.nn.conf.layers.GravesLSTM.Builder().nIn(7).nOut(8)
                                        .activation(Activation.TANH).dist(new NormalDistribution(0, 0.5)).build())
                        .layer(2, new RnnOutputLayer.Builder(LossFunction.MCXENT).nIn(8).nOut(4)
                                        .activation(Activation.SOFTMAX).dist(new NormalDistribution(0, 0.5)).build())
                        .build();
        MultiLayerNetwork mln = new MultiLayerNetwork(conf);
        mln.init();

        INDArray input3d = Nd4j.rand(new long[] {3, 5, timeSeriesLength});
        INDArray out3d = mln.rnnTimeStep(input3d);
        assertArrayEquals(new long[] {3, 4, timeSeriesLength}, out3d.shape());

        mln.rnnClearPreviousState();
        for (int i = 0; i < timeSeriesLength; i++) {
            INDArray input2d = input3d.tensorAlongDimension(i, 1, 0);
            INDArray out2d = mln.rnnTimeStep(input2d);

            assertArrayEquals(new long[] {3, 4}, out2d.shape());

            INDArray expOut2d = out3d.tensorAlongDimension(i, 1, 0);
            assertEquals(expOut2d, out2d);
        }

        //Check same but for input of size [3,5,1]. Expect [3,4,1] out
        mln.rnnClearPreviousState();
        for (int i = 0; i < timeSeriesLength; i++) {
            INDArray temp = Nd4j.create(new int[] {3, 5, 1});
            temp.tensorAlongDimension(0, 1, 0).assign(input3d.tensorAlongDimension(i, 1, 0));
            INDArray out3dSlice = mln.rnnTimeStep(temp);
            assertArrayEquals(new long[] {3, 4, 1}, out3dSlice.shape());

            assertEquals(out3d.tensorAlongDimension(i, 1, 0), out3dSlice.tensorAlongDimension(0, 1, 0));
        }
    }

    @Test
    public void testTruncatedBPTTVsBPTT() {
        //Under some (limited) circumstances, we expect BPTT and truncated BPTT to be identical
        //Specifically: TBPTT over the entire data vector
        int timeSeriesLength = 12;
        int miniBatchSize = 7;
        int nIn = 5;
        int nOut = 4;

        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345)
                        .trainingWorkspaceMode(WorkspaceMode.NONE).inferenceWorkspaceMode(WorkspaceMode.NONE)
                        .list()
                        .layer(0, new org.deeplearning4j.nn.conf.layers.GravesLSTM.Builder().nIn(nIn).nOut(7)
                                        .activation(Activation.TANH).dist(new NormalDistribution(0, 0.5)).build())
                        .layer(1, new org.deeplearning4j.nn.conf.layers.GravesLSTM.Builder().nIn(7).nOut(8)
                                        .activation(Activation.TANH).dist(new NormalDistribution(0, 0.5)).build())
                        .layer(2, new RnnOutputLayer.Builder(LossFunction.MCXENT).nIn(8).nOut(nOut)
                                        .activation(Activation.SOFTMAX).dist(new NormalDistribution(0, 0.5)).build())
                        .build();
        assertEquals(BackpropType.Standard, conf.getBackpropType());
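        //With tBPTTForwardLength == tBPTTBackwardLength == timeSeriesLength, truncated BPTT processes
        //the full sequence as a single segment, so gradients and score should match standard BPTT exactly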
        MultiLayerConfiguration confTBPTT = new NeuralNetConfiguration.Builder().seed(12345)
                        .trainingWorkspaceMode(WorkspaceMode.NONE).inferenceWorkspaceMode(WorkspaceMode.NONE)
                        .list()
                        .layer(0, new org.deeplearning4j.nn.conf.layers.GravesLSTM.Builder().nIn(nIn).nOut(7)
                                        .activation(Activation.TANH).dist(new NormalDistribution(0, 0.5)).build())
                        .layer(1, new org.deeplearning4j.nn.conf.layers.GravesLSTM.Builder().nIn(7).nOut(8)
                                        .activation(Activation.TANH).dist(new NormalDistribution(0, 0.5)).build())
                        .layer(2, new RnnOutputLayer.Builder(LossFunction.MCXENT).nIn(8).nOut(nOut)
                                        .activation(Activation.SOFTMAX).dist(new NormalDistribution(0, 0.5)).build())
                        .backpropType(BackpropType.TruncatedBPTT).tBPTTBackwardLength(timeSeriesLength)
                        .tBPTTForwardLength(timeSeriesLength).build();

        Nd4j.getRandom().setSeed(12345);
        MultiLayerNetwork mln = new MultiLayerNetwork(conf);
        mln.init();

        Nd4j.getRandom().setSeed(12345);
        MultiLayerNetwork mlnTBPTT = new MultiLayerNetwork(confTBPTT);
        mlnTBPTT.init();

        mlnTBPTT.clearTbpttState = false;

        assertEquals(BackpropType.TruncatedBPTT, mlnTBPTT.getLayerWiseConfigurations().getBackpropType());
        assertEquals(timeSeriesLength, mlnTBPTT.getLayerWiseConfigurations().getTbpttFwdLength());
        assertEquals(timeSeriesLength, mlnTBPTT.getLayerWiseConfigurations().getTbpttBackLength());

        INDArray inputData = Nd4j.rand(new int[] {miniBatchSize, nIn, timeSeriesLength});
        INDArray labels = Nd4j.rand(new int[] {miniBatchSize, nOut, timeSeriesLength});

        mln.setInput(inputData);
        mln.setLabels(labels);

        mlnTBPTT.setInput(inputData);
        mlnTBPTT.setLabels(labels);

        mln.computeGradientAndScore();
        mlnTBPTT.computeGradientAndScore();

        Pair<Gradient, Double> mlnPair = mln.gradientAndScore();
        Pair<Gradient, Double> tbpttPair = mlnTBPTT.gradientAndScore();

        assertEquals(mlnPair.getFirst().gradientForVariable(), tbpttPair.getFirst().gradientForVariable());
        assertEquals(mlnPair.getSecond(), tbpttPair.getSecond(), 1e-8);

        //Check states: expect stateMap to be empty but tBpttStateMap to not be
        Map<String, INDArray> l0StateMLN = mln.rnnGetPreviousState(0);
        Map<String, INDArray> l0StateTBPTT = mlnTBPTT.rnnGetPreviousState(0);
        Map<String, INDArray> l1StateMLN = mln.rnnGetPreviousState(1);
        Map<String, INDArray> l1StateTBPTT = mlnTBPTT.rnnGetPreviousState(1);

        Map<String, INDArray> l0TBPTTStateMLN = ((BaseRecurrentLayer<?>) mln.getLayer(0)).rnnGetTBPTTState();
        Map<String, INDArray> l0TBPTTStateTBPTT = ((BaseRecurrentLayer<?>) mlnTBPTT.getLayer(0)).rnnGetTBPTTState();
        Map<String, INDArray> l1TBPTTStateMLN = ((BaseRecurrentLayer<?>) mln.getLayer(1)).rnnGetTBPTTState();
        Map<String, INDArray> l1TBPTTStateTBPTT = ((BaseRecurrentLayer<?>) mlnTBPTT.getLayer(1)).rnnGetTBPTTState();

        assertTrue(l0StateMLN.isEmpty());
        assertTrue(l0StateTBPTT.isEmpty());
        assertTrue(l1StateMLN.isEmpty());
        assertTrue(l1StateTBPTT.isEmpty());

        assertTrue(l0TBPTTStateMLN.isEmpty());
        assertEquals(2, l0TBPTTStateTBPTT.size());
        assertTrue(l1TBPTTStateMLN.isEmpty());
        assertEquals(2, l1TBPTTStateTBPTT.size());

        INDArray tbpttActL0 = l0TBPTTStateTBPTT.get(GravesLSTM.STATE_KEY_PREV_ACTIVATION);
        INDArray tbpttActL1 = l1TBPTTStateTBPTT.get(GravesLSTM.STATE_KEY_PREV_ACTIVATION);

        List<INDArray> activations = mln.feedForward(inputData, true);
        INDArray l0Act = activations.get(1);
        INDArray l1Act = activations.get(2);
        INDArray expL0Act = l0Act.tensorAlongDimension(timeSeriesLength - 1, 1, 0);
        INDArray expL1Act = l1Act.tensorAlongDimension(timeSeriesLength - 1, 1, 0);
        assertEquals(tbpttActL0, expL0Act);
        assertEquals(tbpttActL1, expL1Act);
    }

    @Test
    public void testRnnActivateUsingStoredState() {
        int timeSeriesLength = 12;
        int miniBatchSize = 7;
        int nIn = 5;
        int nOut = 4;
        int nTimeSlices = 5;
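        //rnnActivateUsingStoredState(input, training, storeLastForTBPTT) does a forward pass starting
        //from the stored TBPTT state (initially zeros) rather than resetting it; with
        //storeLastForTBPTT = true, the final time step's state is saved for the next call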
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345).list()
                        .layer(0, new org.deeplearning4j.nn.conf.layers.GravesLSTM.Builder().nIn(nIn).nOut(7)
                                        .activation(Activation.TANH).dist(new NormalDistribution(0, 0.5)).build())
                        .layer(1, new org.deeplearning4j.nn.conf.layers.GravesLSTM.Builder().nIn(7).nOut(8)
                                        .activation(Activation.TANH).dist(new NormalDistribution(0, 0.5)).build())
                        .layer(2, new RnnOutputLayer.Builder(LossFunction.MCXENT).nIn(8).nOut(nOut)
                                        .activation(Activation.SOFTMAX).dist(new NormalDistribution(0, 0.5)).build())
                        .build();
        Nd4j.getRandom().setSeed(12345);
        MultiLayerNetwork mln = new MultiLayerNetwork(conf);
        mln.init();

        INDArray inputLong = Nd4j.rand(new int[] {miniBatchSize, nIn, nTimeSlices * timeSeriesLength});
        INDArray input = inputLong.get(NDArrayIndex.all(), NDArrayIndex.all(),
                        NDArrayIndex.interval(0, timeSeriesLength));

        List<INDArray> outStandard = mln.feedForward(input, true);
        List<INDArray> outRnnAct = mln.rnnActivateUsingStoredState(input, true, true);

        //As initial state is zeros: expect these to be the same
        assertEquals(outStandard, outRnnAct);

        //Furthermore, expect multiple calls to this function to be the same:
        for (int i = 0; i < 3; i++) {
            assertEquals(outStandard, mln.rnnActivateUsingStoredState(input, true, true));
        }

        List<INDArray> outStandardLong = mln.feedForward(inputLong, true);
        BaseRecurrentLayer<?> l0 = ((BaseRecurrentLayer<?>) mln.getLayer(0));
        BaseRecurrentLayer<?> l1 = ((BaseRecurrentLayer<?>) mln.getLayer(1));

        for (int i = 0; i < nTimeSlices; i++) {
            INDArray inSlice = inputLong.get(NDArrayIndex.all(), NDArrayIndex.all(),
                            NDArrayIndex.interval(i * timeSeriesLength, (i + 1) * timeSeriesLength));
            List<INDArray> outSlice = mln.rnnActivateUsingStoredState(inSlice, true, true);
            List<INDArray> expOut = new ArrayList<>();
            for (INDArray temp : outStandardLong) {
                expOut.add(temp.get(NDArrayIndex.all(), NDArrayIndex.all(),
                                NDArrayIndex.interval(i * timeSeriesLength, (i + 1) * timeSeriesLength)));
            }

            for (int j = 0; j < expOut.size(); j++) {
                INDArray exp = expOut.get(j);
                INDArray act = outSlice.get(j);
                assertEquals(exp, act);
            }

            assertEquals(expOut, outSlice);

            //Again, expect multiple calls to give the same output
            for (int j = 0; j < 3; j++) {
                outSlice = mln.rnnActivateUsingStoredState(inSlice, true, true);
                assertEquals(expOut, outSlice);
            }

            l0.rnnSetPreviousState(l0.rnnGetTBPTTState());
            l1.rnnSetPreviousState(l1.rnnGetTBPTTState());
        }
    }

    @Test
    public void testTruncatedBPTTSimple() {
        //Extremely simple test of the 'does it throw an exception' variety
        int timeSeriesLength = 12;
        int miniBatchSize = 7;
        int nIn = 5;
        int nOut = 4;
        int nTimeSlices = 20;

        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345)
                        .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).list()
                        .layer(0, new org.deeplearning4j.nn.conf.layers.GravesLSTM.Builder().nIn(nIn).nOut(7)
                                        .activation(Activation.TANH).dist(new NormalDistribution(0, 0.5)).build())
                        .layer(1, new org.deeplearning4j.nn.conf.layers.GravesLSTM.Builder().nIn(7).nOut(8)
                                        .activation(Activation.TANH).dist(new NormalDistribution(0, 0.5)).build())
                        .layer(2, new RnnOutputLayer.Builder(LossFunction.MCXENT).nIn(8).nOut(nOut)
                                        .activation(Activation.SOFTMAX).dist(new NormalDistribution(0, 0.5)).build())
                        .backpropType(BackpropType.TruncatedBPTT)
                        .tBPTTBackwardLength(timeSeriesLength).tBPTTForwardLength(timeSeriesLength).build();

        Nd4j.getRandom().setSeed(12345);
        MultiLayerNetwork mln = new MultiLayerNetwork(conf);
        mln.init();

        INDArray inputLong = Nd4j.rand(new int[] {miniBatchSize, nIn, nTimeSlices * timeSeriesLength});
        INDArray labelsLong = Nd4j.rand(new int[] {miniBatchSize, nOut, nTimeSlices * timeSeriesLength});

        mln.fit(inputLong, labelsLong);
    }

    @Test
    public void testTruncatedBPTTWithMasking() {
        //Extremely simple test of the 'does it throw an exception' variety
        int timeSeriesLength = 100;
        int tbpttLength = 10;
        int miniBatchSize = 7;
        int nIn = 5;
        int nOut = 4;
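        //Masks are all ones here, so this just exercises TBPTT segmenting the mask arrays
        //along with the features and labels, without checking particular values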
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345)
                        .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).list()
                        .layer(0, new org.deeplearning4j.nn.conf.layers.GravesLSTM.Builder().nIn(nIn).nOut(7)
                                        .activation(Activation.TANH).dist(new NormalDistribution(0, 0.5)).build())
                        .layer(1, new org.deeplearning4j.nn.conf.layers.GravesLSTM.Builder().nIn(7).nOut(8)
                                        .activation(Activation.TANH).dist(new NormalDistribution(0, 0.5)).build())
                        .layer(2, new RnnOutputLayer.Builder(LossFunction.MCXENT).nIn(8).nOut(nOut)
                                        .activation(Activation.SOFTMAX).dist(new NormalDistribution(0, 0.5)).build())
                        .backpropType(BackpropType.TruncatedBPTT)
                        .tBPTTBackwardLength(tbpttLength).tBPTTForwardLength(tbpttLength).build();

        Nd4j.getRandom().setSeed(12345);
        MultiLayerNetwork mln = new MultiLayerNetwork(conf);
        mln.init();

        INDArray features = Nd4j.rand(new int[] {miniBatchSize, nIn, timeSeriesLength});
        INDArray labels = Nd4j.rand(new int[] {miniBatchSize, nOut, timeSeriesLength});
        INDArray maskArrayInput = Nd4j.ones(miniBatchSize, timeSeriesLength);
        INDArray maskArrayOutput = Nd4j.ones(miniBatchSize, timeSeriesLength);

        DataSet ds = new DataSet(features, labels, maskArrayInput, maskArrayOutput);

        mln.fit(ds);
    }

    @Test
    public void testRnnTimeStepWithPreprocessor() {
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                        .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).list()
                        .layer(0, new org.deeplearning4j.nn.conf.layers.GravesLSTM.Builder().nIn(10).nOut(10)
                                        .activation(Activation.TANH).build())
                        .layer(1, new org.deeplearning4j.nn.conf.layers.GravesLSTM.Builder().nIn(10).nOut(10)
                                        .activation(Activation.TANH).build())
                        .layer(2, new RnnOutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                                        .activation(Activation.SOFTMAX).nIn(10).nOut(10).build())
                        .inputPreProcessor(0, new FeedForwardToRnnPreProcessor())
                        .build();

        MultiLayerNetwork net = new MultiLayerNetwork(conf);
        net.init();

        INDArray in = Nd4j.rand(1, 10);
        net.rnnTimeStep(in);
    }

    @Test
    public void testRnnTimeStepWithPreprocessorGraph() {
        ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
                        .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                        .graphBuilder().addInputs("in")
                        .addLayer("0", new org.deeplearning4j.nn.conf.layers.GravesLSTM.Builder().nIn(10).nOut(10)
                                        .activation(Activation.TANH).build(), "in")
                        .addLayer("1", new org.deeplearning4j.nn.conf.layers.GravesLSTM.Builder().nIn(10).nOut(10)
                                        .activation(Activation.TANH).build(), "0")
                        .addLayer("2", new RnnOutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                                        .activation(Activation.SOFTMAX).nIn(10).nOut(10).build(), "1")
                        .setOutputs("2").inputPreProcessor("0", new FeedForwardToRnnPreProcessor())
                        .build();

        ComputationGraph net = new ComputationGraph(conf);
        net.init();

        INDArray in = Nd4j.rand(1, 10);
        net.rnnTimeStep(in);
    }

    @Test
    public void testTBPTTLongerThanTS() {
        //Extremely simple test of the 'does it throw an exception' variety
        int timeSeriesLength = 20;
        int tbpttLength = 1000;
        int miniBatchSize = 7;
        int nIn = 5;
        int nOut = 4;
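        //tbpttLength deliberately exceeds the series length: the expectation is that the network
        //simply processes the whole sequence as a single segment (and still learns, i.e. params change)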
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345)
                        .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                        .weightInit(WeightInit.XAVIER).list()
                        .layer(0, new org.deeplearning4j.nn.conf.layers.GravesLSTM.Builder().nIn(nIn).nOut(7)
                                        .activation(Activation.TANH).build())
                        .layer(1, new org.deeplearning4j.nn.conf.layers.GravesLSTM.Builder().nIn(7).nOut(8)
                                        .activation(Activation.TANH).build())
                        .layer(2, new RnnOutputLayer.Builder(LossFunction.MSE).nIn(8).nOut(nOut)
                                        .activation(Activation.IDENTITY).build())
                        .backpropType(BackpropType.TruncatedBPTT)
                        .tBPTTBackwardLength(tbpttLength).tBPTTForwardLength(tbpttLength).build();

        Nd4j.getRandom().setSeed(12345);
        MultiLayerNetwork mln = new MultiLayerNetwork(conf);
        mln.init();

        INDArray features = Nd4j.rand(new int[] {miniBatchSize, nIn, timeSeriesLength});
        INDArray labels = Nd4j.rand(new int[] {miniBatchSize, nOut, timeSeriesLength});
        INDArray maskArrayInput = Nd4j.ones(miniBatchSize, timeSeriesLength);
        INDArray maskArrayOutput = Nd4j.ones(miniBatchSize, timeSeriesLength);

        DataSet ds = new DataSet(features, labels, maskArrayInput, maskArrayOutput);

        INDArray initialParams = mln.params().dup();
        mln.fit(ds);
        INDArray afterParams = mln.params();

        assertNotEquals(initialParams, afterParams);
    }

    @Test
    public void testInvalidTBPTT() {
        int nIn = 8;
        int nOut = 25;
        int nHiddenUnits = 17;

        try {
            new NeuralNetConfiguration.Builder().list()
                            .layer(new org.deeplearning4j.nn.conf.layers.LSTM.Builder().nIn(nIn).nOut(nHiddenUnits).build())
                            .layer(new GlobalPoolingLayer())
                            .layer(new OutputLayer.Builder(LossFunction.MSE).nIn(nHiddenUnits).nOut(nOut)
                                            .activation(Activation.TANH).build())
                            .backpropType(BackpropType.TruncatedBPTT)
                            .build();
            fail("Exception expected");
        } catch (IllegalStateException e) {
            //TBPTT is invalid here: the GlobalPoolingLayer collapses the time dimension
            log.info(e.toString());
            assertTrue(e.getMessage().contains("TBPTT") && e.getMessage().contains("validateTbpttConfig"));
        }
    }

    @Test
    public void testWrapperLayerGetPreviousState() {
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().list()
                        .layer(new FrozenLayer(new org.deeplearning4j.nn.conf.layers.LSTM.Builder()
                                        .nIn(5).nOut(5).build()))
                        .build();

        MultiLayerNetwork net = new MultiLayerNetwork(conf);
        net.init();
        INDArray in = Nd4j.create(1, 5, 2);

        net.rnnTimeStep(in);
        Map<String, INDArray> m = net.rnnGetPreviousState(0);
        assertNotNull(m);
        assertEquals(2, m.size()); //activation and cell state
        net.rnnSetPreviousState(0, m);

        ComputationGraph cg = net.toComputationGraph();
        cg.rnnTimeStep(in);
        m = cg.rnnGetPreviousState(0);
        assertNotNull(m);
        assertEquals(2, m.size()); //activation and cell state
        cg.rnnSetPreviousState(0, m);
    }
}