Java Code Examples for org.deeplearning4j.nn.graph.ComputationGraph#feedForward()

The following examples show how to use org.deeplearning4j.nn.graph.ComputationGraph#feedForward(). Each example is taken from an open-source project; the source file and license are noted above it.
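Across these examples, feedForward(...) returns a Map<String, INDArray> containing the activations of every vertex in the graph, keyed by vertex name (the input arrays themselves are included), and the boolean argument selects training vs. inference behavior for layers such as dropout. The following is a minimal, self-contained sketch of that contract; the class name FeedForwardBasics and its tiny two-layer graph are illustrative only, not taken from any of the projects below.

import java.util.Map;

import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.lossfunctions.LossFunctions;

public class FeedForwardBasics {
    public static void main(String[] args) {
        // A tiny graph: "in" -> "dense" -> "out"
        ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
                .graphBuilder()
                .addInputs("in")
                .addLayer("dense", new DenseLayer.Builder().nIn(4).nOut(8)
                        .activation(Activation.RELU).build(), "in")
                .addLayer("out", new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                        .activation(Activation.SOFTMAX).nIn(8).nOut(3).build(), "dense")
                .setOutputs("out")
                .build();

        ComputationGraph graph = new ComputationGraph(conf);
        graph.init();

        // train=false -> inference mode; the returned map holds one activation
        // array per vertex, including the input array itself
        INDArray input = Nd4j.rand(1, 4);
        Map<String, INDArray> activations = graph.feedForward(input, false);
        System.out.println(activations.keySet());   // e.g. [in, dense, out]
    }
}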
Example 1
Source File: TestConvolution.java    From deeplearning4j with Apache License 2.0
@Test
public void validateXceptionImport() throws Exception {
    File dir = testDir.newFolder();
    File fSource = Resources.asFile("modelimport/keras/examples/xception/xception_tf_keras_2.h5");
    File fExtracted = new File(dir, "xception_tf_keras_2.h5");
    FileUtils.copyFile(fSource, fExtracted);

    int inSize = 256;
    ComputationGraph model = KerasModelImport.importKerasModelAndWeights(fExtracted.getAbsolutePath(), new int[]{inSize, inSize, 3}, false);
    model = model.convertDataType(DataType.DOUBLE);

    INDArray in = Nd4j.rand(DataType.DOUBLE, new int[]{1, inSize, inSize, 3});      //Keras import model -> NHWC

    // Forward pass with cuDNN helpers attached
    CuDNNTestUtils.assertHelpersPresent(model.getLayers());
    Map<String,INDArray> withCudnn = model.feedForward(in, false);

    // Repeat with the helpers removed; the activations should match
    CuDNNTestUtils.removeHelpers(model.getLayers());
    CuDNNTestUtils.assertHelpersAbsent(model.getLayers());
    Map<String,INDArray> noCudnn = model.feedForward(in, false);

    assertEquals(withCudnn.keySet(), noCudnn.keySet());

    for(String s : withCudnn.keySet()){
        assertEquals(s, withCudnn.get(s), noCudnn.get(s));
    }
}
 
Example 2
Source File: TestModels.java    From Java-Machine-Learning-for-Computer-Vision with MIT License
private static INDArray getEmbeddings(ComputationGraph vgg16, File image) throws IOException {
    INDArray imageMatrix = LOADER.asMatrix(image);
    IMAGE_PRE_PROCESSOR.preProcess(imageMatrix);
    Map<String, INDArray> activations = vgg16.feedForward(imageMatrix, false);
    return activations.get("embeddings");
}
 
Example 3
Source File: NeuralStyleTransfer.java    From dl4j-tutorials with MIT License
private void transferStyle() throws IOException {
    ComputationGraph vgg16FineTune = loadModel();
    INDArray content = loadImage(CONTENT_FILE);
    INDArray style = loadImage(STYLE_FILE);
    INDArray combination = createCombinationImage();
    Map<String, INDArray> activationsContentMap = vgg16FineTune.feedForward(content, true);
    Map<String, INDArray> activationsStyleMap = vgg16FineTune.feedForward(style, true);
    HashMap<String, INDArray> activationsStyleGramMap = buildStyleGramValues(activationsStyleMap);
    AdamUpdater adamUpdater = createADAMUpdater();
    for (int iteration = 0; iteration < ITERATIONS; iteration++) {
        log.info("Iteration " + iteration);

        INDArray[] input = new INDArray[] { combination };
        Map<String, INDArray> activationsCombMap = vgg16FineTune.feedForward(input, true, false);

        INDArray backPropStyle = backPropagateStyles(vgg16FineTune, activationsStyleGramMap, activationsCombMap);
        INDArray backPropContent = backPropagateContent(vgg16FineTune, activationsContentMap, activationsCombMap);
        INDArray backPropAllValues = backPropContent.muli(ALPHA).addi(backPropStyle.muli(BETA));
        adamUpdater.applyUpdater(backPropAllValues, iteration, 0);
        combination.subi(backPropAllValues);

        log.info("Total Loss: " + totalLoss(activationsStyleMap, activationsCombMap, activationsContentMap));
        if (iteration % SAVE_IMAGE_CHECKPOINT == 0) {
            // The saved image can be found at target/classes/styletransfer/out
            saveImage(combination.dup(), iteration);
        }
    }
}
 
Example 4
Source File: KerasModelConverter.java    From wekaDeeplearning4j with GNU General Public License v3.0
private static void saveH5File(File modelFile, File outputFolder) {
    try {
        INDArray testShape = Nd4j.zeros(1, 3, 224, 224);
        String modelName = modelFile.getName();
        Method method = null;
        try {
            method = InputType.class.getMethod("setDefaultCNN2DFormat", CNN2DFormat.class);
            method.invoke(null, CNN2DFormat.NCHW);
        } catch (NoSuchMethodException ex) {
            System.err.println("setDefaultCNN2DFormat() not found on InputType class... " +
                    "Are you using the custom built deeplearning4j-nn.jar?");
            System.exit(1);
        }

        if (modelName.contains("EfficientNet")) {
            // Fixes for EfficientNet family of models
            testShape = Nd4j.zeros(1, 224, 224, 3);
            method.invoke(null, CNN2DFormat.NHWC);
            // We don't want the resulting .zip files to have 'Fixed' in the name, so we'll strip it off here
            modelName = modelName.replace("Fixed", "");
        }
        ComputationGraph kerasModel = KerasModelImport.importKerasModelAndWeights(modelFile.getAbsolutePath());
        kerasModel.feedForward(testShape, false);
        // e.g. ResNet50.h5 -> KerasResNet50.zip
        modelName = "Keras" + modelName.replace(".h5", ".zip");
        String newZip = Paths.get(outputFolder.getPath(), modelName).toString();
        kerasModel.save(new File(newZip));
        System.out.println("Saved file " + newZip);
    } catch (Exception e) {
        System.err.println("\n\nCouldn't save " + modelFile.getName());
        e.printStackTrace();
    }
}
 
Example 5
Source File: WorkspaceTests.java    From deeplearning4j with Apache License 2.0
@Test
public void testWithPreprocessorsCG() {
    //https://github.com/deeplearning4j/deeplearning4j/issues/4347
    //Cause for the above issue was layerVertex.setInput() applying the preprocessor, with the result
    // not being detached properly from the workspace...

    for (WorkspaceMode wm : WorkspaceMode.values()) {
        System.out.println(wm);
        ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
                .trainingWorkspaceMode(wm)
                .inferenceWorkspaceMode(wm)
                .graphBuilder()
                .addInputs("in")
                .addLayer("e", new GravesLSTM.Builder().nIn(10).nOut(5).build(), new DupPreProcessor(), "in")
                //.addLayer("e", new GravesLSTM.Builder().nIn(10).nOut(5).build(), "in")    //Note that no preprocessor is OK
                .addLayer("rnn", new GravesLSTM.Builder().nIn(5).nOut(8).build(), "e")
                .addLayer("out", new RnnOutputLayer.Builder(LossFunctions.LossFunction.MSE)
                        .activation(Activation.SIGMOID).nOut(3).build(), "rnn")
                .setInputTypes(InputType.recurrent(10))
                .setOutputs("out")
                .build();

        ComputationGraph cg = new ComputationGraph(conf);
        cg.init();

        INDArray[] input = new INDArray[]{Nd4j.zeros(1, 10, 5)};

        for (boolean train : new boolean[]{false, true}) {
            cg.clear();
            cg.feedForward(input, train);
        }

        cg.setInputs(input);
        cg.setLabels(Nd4j.rand(new int[]{1, 3, 5}));
        cg.computeGradientAndScore();
    }
}
 
Example 6
Source File: TestGraphNodes.java    From deeplearning4j with Apache License 2.0
@Test
public void testStackVertexEmbedding() {
    Nd4j.getRandom().setSeed(12345);
    GraphVertex stack = new StackVertex(null, "", -1, Nd4j.dataType());

    INDArray in1 = Nd4j.zeros(5, 1);
    INDArray in2 = Nd4j.zeros(5, 1);
    for (int i = 0; i < 5; i++) {
        in1.putScalar(i, 0, i);
        in2.putScalar(i, 0, i);
    }

    INDArray l = Nd4j.rand(5, 5);
    MultiDataSet ds = new org.nd4j.linalg.dataset.MultiDataSet(new INDArray[] {in1, in2}, new INDArray[] {l, l},
                    null, null);


    ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().graphBuilder().addInputs("in1", "in2")
                    .addVertex("stack", new org.deeplearning4j.nn.conf.graph.StackVertex(), "in1", "in2")
                    .addLayer("1", new EmbeddingLayer.Builder().nIn(5).nOut(5).build(), "stack")
                    .addVertex("unstack1", new org.deeplearning4j.nn.conf.graph.UnstackVertex(0, 2), "1")
                    .addVertex("unstack2", new org.deeplearning4j.nn.conf.graph.UnstackVertex(0, 2), "1")
                    .addLayer("out1", new OutputLayer.Builder().activation(Activation.TANH)
                                    .lossFunction(LossFunctions.LossFunction.L2).nIn(5).nOut(5).build(), "unstack1")
                    .addLayer("out2", new OutputLayer.Builder().activation(Activation.TANH)
                                    .lossFunction(LossFunctions.LossFunction.L2).nIn(5).nOut(5).build(), "unstack2")
                    .setOutputs("out1", "out2").build();

    ComputationGraph g = new ComputationGraph(conf);
    g.init();

    g.feedForward(new INDArray[] {in1, in2}, false);

    g.fit(ds);

}
 
Example 7
Source File: GradientCheckTestsComputationGraph.java    From deeplearning4j with Apache License 2.0
@Test
public void testBasicIrisTripletStackingL2Loss() {
    Nd4j.getRandom().setSeed(12345);
    ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345)
            .dataType(DataType.DOUBLE)
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
            .dist(new NormalDistribution(0, 1))
            .updater(new NoOp()).graphBuilder()
            .addInputs("input1", "input2", "input3")
            .addVertex("stack1", new StackVertex(), "input1", "input2", "input3")
            .addLayer("l1", new DenseLayer.Builder().nIn(4).nOut(5)
                    .activation(Activation.TANH).build(), "stack1")
            .addVertex("unstack0", new UnstackVertex(0, 3), "l1")
            .addVertex("unstack1", new UnstackVertex(1, 3), "l1")
            .addVertex("unstack2", new UnstackVertex(2, 3), "l1")
            .addVertex("l2-1", new L2Vertex(), "unstack1", "unstack0") // x - x-
            .addVertex("l2-2", new L2Vertex(), "unstack1", "unstack2") // x - x+
            .addLayer("lossLayer", new LossLayer.Builder()
                    .lossFunction(LossFunctions.LossFunction.MCXENT)
                    .activation(Activation.SOFTMAX).build(), "l2-1", "l2-2")
            .setOutputs("lossLayer").build();

    ComputationGraph graph = new ComputationGraph(conf);
    graph.init();

    int numParams = (4 * 5 + 5);
    assertEquals(numParams, graph.numParams());

    Nd4j.getRandom().setSeed(12345);
    long nParams = graph.numParams();
    INDArray newParams = Nd4j.rand(new long[]{1, nParams});
    graph.setParams(newParams);

    INDArray pos = Nd4j.rand(150, 4);
    INDArray anc = Nd4j.rand(150, 4);
    INDArray neg = Nd4j.rand(150, 4);

    INDArray labels = Nd4j.zeros(150, 2);
    Random r = new Random(12345);
    for (int i = 0; i < 150; i++) {
        labels.putScalar(i, r.nextInt(2), 1.0);
    }

    Map<String, INDArray> out = graph.feedForward(new INDArray[] {pos, anc, neg}, true);

    if (PRINT_RESULTS) {
        System.out.println("testBasicIrisTripletStackingL2Loss()");
    }

    boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(graph)
            .inputs(new INDArray[]{pos, anc, neg})
            .labels(new INDArray[]{labels}));

    String msg = "testBasicIrisTripletStackingL2Loss()";
    assertTrue(msg, gradOK);
    TestUtils.testModelSerialization(graph);
}
 
Example 8
Source File: LocallyConnectedLayerTest.java    From deeplearning4j with Apache License 2.0
@Test
public void testLocallyConnected(){
    for (DataType globalDtype : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF}) {
        Nd4j.setDefaultDataTypes(globalDtype, globalDtype);
        for (DataType networkDtype : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF}) {
            assertEquals(globalDtype, Nd4j.dataType());
            assertEquals(globalDtype, Nd4j.defaultFloatingPointType());

            for (int test = 0; test < 2; test++) {
                String msg = "Global dtype: " + globalDtype + ", network dtype: " + networkDtype + ", test=" + test;

                ComputationGraphConfiguration.GraphBuilder b = new NeuralNetConfiguration.Builder()
                        .dataType(networkDtype)
                        .seed(123)
                        .updater(new NoOp())
                        .weightInit(WeightInit.XAVIER)
                        .convolutionMode(ConvolutionMode.Same)
                        .graphBuilder();

                INDArray[] in;
                INDArray label;
                switch (test){
                    case 0:
                        b.addInputs("in")
                                .addLayer("1", new LSTM.Builder().nOut(5).build(), "in")
                                .addLayer("2", new LocallyConnected1D.Builder().kernelSize(2).nOut(4).build(), "1")
                                .addLayer("out", new RnnOutputLayer.Builder().nOut(10).build(), "2")
                                .setOutputs("out")
                                .setInputTypes(InputType.recurrent(5, 4));
                        in = new INDArray[]{Nd4j.rand(networkDtype, 2, 5, 4)};
                        label = TestUtils.randomOneHotTimeSeries(2, 10, 4).castTo(networkDtype);
                        break;
                    case 1:
                        b.addInputs("in")
                                .addLayer("1", new ConvolutionLayer.Builder().kernelSize(2,2).nOut(5).convolutionMode(ConvolutionMode.Same).build(), "in")
                                .addLayer("2", new LocallyConnected2D.Builder().kernelSize(2,2).nOut(5).build(), "1")
                                .addLayer("out", new OutputLayer.Builder().nOut(10).build(), "2")
                                .setOutputs("out")
                                .setInputTypes(InputType.convolutional(8, 8, 1));
                        in = new INDArray[]{Nd4j.rand(networkDtype, 2, 1, 8, 8)};
                        label = TestUtils.randomOneHot(2, 10).castTo(networkDtype);
                        break;
                    default:
                        throw new RuntimeException();
                }

                ComputationGraph net = new ComputationGraph(b.build());
                net.init();

                INDArray out = net.outputSingle(in);
                assertEquals(msg, networkDtype, out.dataType());
                Map<String, INDArray> ff = net.feedForward(in, false);
                for (Map.Entry<String, INDArray> e : ff.entrySet()) {
                    if (e.getKey().equals("in"))
                        continue;
                    String s = msg + " - layer: " + e.getKey();
                    assertEquals(s, networkDtype, e.getValue().dataType());
                }

                net.setInputs(in);
                net.setLabels(label);
                net.computeGradientAndScore();

                net.fit(new MultiDataSet(in, new INDArray[]{label}));
            }
        }
    }
}
 
Example 9
Source File: DTypeTests.java    From deeplearning4j with Apache License 2.0
@Test
public void testEmbeddingDtypes() {
    for (DataType globalDtype : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF}) {
        Nd4j.setDefaultDataTypes(globalDtype, globalDtype);
        for (DataType networkDtype : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF}) {
            for (boolean frozen : new boolean[]{false, true}) {
                for (int test = 0; test < 3; test++) {
                    assertEquals(globalDtype, Nd4j.dataType());
                    assertEquals(globalDtype, Nd4j.defaultFloatingPointType());

                    String msg = "Global dtype: " + globalDtype + ", network dtype: " + networkDtype + ", test=" + test;

                    ComputationGraphConfiguration.GraphBuilder conf = new NeuralNetConfiguration.Builder()
                            .dataType(networkDtype)
                            .seed(123)
                            .updater(new NoOp())
                            .weightInit(new WeightInitDistribution(new UniformDistribution(-6, 6)))
                            .graphBuilder()
                            .addInputs("in")
                            .setOutputs("out");

                    INDArray input;
                    if (test == 0) {
                        if (frozen) {
                            conf.layer("0", new FrozenLayer(new EmbeddingLayer.Builder().nIn(5).nOut(5).build()), "in");
                        } else {
                            conf.layer("0", new EmbeddingLayer.Builder().nIn(5).nOut(5).build(), "in");
                        }
                        input = Nd4j.rand(networkDtype, 10, 1).muli(5).castTo(DataType.INT);
                        conf.setInputTypes(InputType.feedForward(1));
                    } else if (test == 1) {
                        if (frozen) {
                            conf.layer("0", new FrozenLayer(new EmbeddingSequenceLayer.Builder().nIn(5).nOut(5).build()), "in");
                        } else {
                            conf.layer("0", new EmbeddingSequenceLayer.Builder().nIn(5).nOut(5).build(), "in");
                        }
                        conf.layer("gp", new GlobalPoolingLayer.Builder(PoolingType.PNORM).pnorm(2).poolingDimensions(2).build(), "0");
                        input = Nd4j.rand(networkDtype, 10, 1, 5).muli(5).castTo(DataType.INT);
                        conf.setInputTypes(InputType.recurrent(1));
                    } else {
                        conf.layer("0", new RepeatVector.Builder().repetitionFactor(5).nOut(5).build(), "in");
                        conf.layer("gp", new GlobalPoolingLayer.Builder(PoolingType.SUM).build(), "0");
                        input = Nd4j.rand(networkDtype, 10, 5);
                        conf.setInputTypes(InputType.feedForward(5));
                    }

                    conf.appendLayer("el", new ElementWiseMultiplicationLayer.Builder().nOut(5).build())
                            .appendLayer("ae", new AutoEncoder.Builder().nOut(5).build())
                            .appendLayer("prelu", new PReLULayer.Builder().nOut(5).inputShape(5).build())
                            .appendLayer("out", new OutputLayer.Builder().nOut(10).build());

                    ComputationGraph net = new ComputationGraph(conf.build());
                    net.init();

                    INDArray label = Nd4j.zeros(networkDtype, 10, 10);

                    INDArray out = net.outputSingle(input);
                    assertEquals(msg, networkDtype, out.dataType());
                    Map<String, INDArray> ff = net.feedForward(input, false);
                    for (Map.Entry<String, INDArray> e : ff.entrySet()) {
                        if (e.getKey().equals("in"))
                            continue;
                        String s = msg + " - layer: " + e.getKey();
                        assertEquals(s, networkDtype, e.getValue().dataType());
                    }

                    net.setInput(0, input);
                    net.setLabels(label);
                    net.computeGradientAndScore();

                    net.fit(new DataSet(input, label));

                    logUsedClasses(net);

                    //Now, test mismatched dtypes for input/labels:
                    for (DataType inputLabelDtype : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF}) {
                        INDArray in2 = input.castTo(inputLabelDtype);
                        INDArray label2 = label.castTo(inputLabelDtype);
                        net.output(in2);
                        net.setInput(0, in2);
                        net.setLabels(label2);
                        net.computeGradientAndScore();

                        net.fit(new DataSet(in2, label2));
                    }
                }
            }
        }
    }
}
 
Example 10
Source File: DTypeTests.java    From deeplearning4j with Apache License 2.0
@Test
public void testLocallyConnected() {
    for (DataType globalDtype : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF}) {
        Nd4j.setDefaultDataTypes(globalDtype, globalDtype);
        for (DataType networkDtype : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF}) {
            assertEquals(globalDtype, Nd4j.dataType());
            assertEquals(globalDtype, Nd4j.defaultFloatingPointType());

            INDArray[] in = null;
            for (int test = 0; test < 2; test++) {
                String msg = "Global dtype: " + globalDtype + ", network dtype: " + networkDtype + ", test=" + test;

                ComputationGraphConfiguration.GraphBuilder b = new NeuralNetConfiguration.Builder()
                        .dataType(networkDtype)
                        .seed(123)
                        .updater(new NoOp())
                        .weightInit(WeightInit.XAVIER)
                        .convolutionMode(ConvolutionMode.Same)
                        .graphBuilder();

                INDArray label;
                switch (test) {
                    case 0:
                        b.addInputs("in")
                                .addLayer("1", new LSTM.Builder().nOut(5).build(), "in")
                                .addLayer("2", new LocallyConnected1D.Builder().kernelSize(2).nOut(4).build(), "1")
                                .addLayer("out", new RnnOutputLayer.Builder().nOut(10).build(), "2")
                                .setOutputs("out")
                                .setInputTypes(InputType.recurrent(5, 2));
                        in = new INDArray[]{Nd4j.rand(networkDtype, 2, 5, 2)};
                        label = TestUtils.randomOneHotTimeSeries(2, 10, 2);
                        break;
                    case 1:
                        b.addInputs("in")
                                .addLayer("1", new ConvolutionLayer.Builder().kernelSize(2, 2).nOut(5).convolutionMode(ConvolutionMode.Same).build(), "in")
                                .addLayer("2", new LocallyConnected2D.Builder().kernelSize(2, 2).nOut(5).build(), "1")
                                .addLayer("out", new OutputLayer.Builder().nOut(10).build(), "2")
                                .setOutputs("out")
                                .setInputTypes(InputType.convolutional(8, 8, 1));
                        in = new INDArray[]{Nd4j.rand(networkDtype, 2, 1, 8, 8)};
                        label = TestUtils.randomOneHot(2, 10).castTo(networkDtype);
                        break;
                    default:
                        throw new RuntimeException();
                }

                ComputationGraph net = new ComputationGraph(b.build());
                net.init();

                INDArray out = net.outputSingle(in);
                assertEquals(msg, networkDtype, out.dataType());
                Map<String, INDArray> ff = net.feedForward(in, false);
                for (Map.Entry<String, INDArray> e : ff.entrySet()) {
                    if (e.getKey().equals("in"))
                        continue;
                    String s = msg + " - layer: " + e.getKey();
                    assertEquals(s, networkDtype, e.getValue().dataType());
                }

                net.setInputs(in);
                net.setLabels(label);
                net.computeGradientAndScore();

                net.fit(new MultiDataSet(in, new INDArray[]{label}));

                logUsedClasses(net);

                //Now, test mismatched dtypes for input/labels:
                for (DataType inputLabelDtype : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF}) {
                    INDArray[] in2 = new INDArray[in.length];
                    for (int i = 0; i < in.length; i++) {
                        in2[i] = in[i].castTo(inputLabelDtype);
                    }
                    INDArray label2 = label.castTo(inputLabelDtype);
                    net.output(in2);
                    net.setInputs(in2);
                    net.setLabels(label2);
                    net.computeGradientAndScore();

                    net.fit(new MultiDataSet(in2, new INDArray[]{label2}));
                }
            }
        }
    }
}
 
Example 11
Source File: TestGraphNodes.java    From deeplearning4j with Apache License 2.0
@Test
public void testLastTimeStepWithTransfer(){
    int lstmLayerSize = 16;
    int numLabelClasses = 10;
    int numInputs = 5;

    ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
            .trainingWorkspaceMode(WorkspaceMode.NONE)
            .inferenceWorkspaceMode(WorkspaceMode.NONE)
            .seed(123)    //Random number generator seed for improved repeatability. Optional.
            .updater(new AdaDelta())
            .weightInit(WeightInit.XAVIER)
            .graphBuilder()
            .addInputs("rr")
            .setInputTypes(InputType.recurrent(30))
            .addLayer("1", new GravesLSTM.Builder().activation(Activation.TANH).nIn(numInputs).nOut(lstmLayerSize).dropOut(0.9).build(), "rr")
            .addLayer("2", new RnnOutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                    .activation(Activation.SOFTMAX).nOut(numLabelClasses).build(), "1")

            .setOutputs("2")
            .build();


    ComputationGraph net = new ComputationGraph(conf);
    net.init();

    ComputationGraph updatedModel = new TransferLearning.GraphBuilder(net)
            .addVertex("laststepoutput", new LastTimeStepVertex("rr"), "2")
            .setOutputs("laststepoutput")
            .build();


    INDArray input = Nd4j.rand(new int[]{10, numInputs, 16});

    INDArray[] out = updatedModel.output(input);

    assertNotNull(out);
    assertEquals(1, out.length);
    assertNotNull(out[0]);

    assertArrayEquals(new long[]{10, numLabelClasses}, out[0].shape());

    Map<String,INDArray> acts = updatedModel.feedForward(input, false);

    assertEquals(4, acts.size());   //2 layers + input + vertex output
    assertNotNull(acts.get("laststepoutput"));
    assertArrayEquals(new long[]{10, numLabelClasses}, acts.get("laststepoutput").shape());

    String toString = out[0].toString();    //Sanity check: toString() on the output should not throw
}