Java Code Examples for org.deeplearning4j.nn.graph.ComputationGraph#setInput()

The following examples show how to use org.deeplearning4j.nn.graph.ComputationGraph#setInput(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may check out the related API usage on the sidebar.
Example 1
Source File: CenterLossOutputLayerTest.java    From deeplearning4j with Apache License 2.0 6 votes vote down vote up
/**
 * Scores two graphs obtained from {@code getGraph} with different lambda values on the same
 * features and one-hot labels, and asserts that the two scores are not equal.
 */
@Test
public void testLambdaConf() {
    final int numClasses = 2;
    final double[] lambdas = {0.1, 0.01};
    final double[] scores = new double[2];

    INDArray features = Nd4j.rand(150, 4);
    INDArray oneHotLabels = Nd4j.zeros(150, numClasses);

    // Fixed seed: each row gets exactly one randomly chosen class set to 1.0.
    Random rng = new Random(12345);
    for (int row = 0; row < 150; row++) {
        int chosenClass = rng.nextInt(numClasses);
        oneHotLabels.putScalar(row, chosenClass, 1.0);
    }

    for (int run = 0; run < lambdas.length; run++) {
        ComputationGraph graph = getGraph(numClasses, lambdas[run]);
        graph.setInput(0, features);
        graph.setLabel(0, oneHotLabels);
        graph.computeGradientAndScore();
        scores[run] = graph.score();
    }

    assertNotEquals(scores[0], scores[1]);
}
 
Example 2
Source File: BNGradientCheckTest.java    From deeplearning4j with Apache License 2.0 4 votes vote down vote up
/**
 * Gradient-checks a ComputationGraph built as convolution -> batch norm -> max subsampling ->
 * batch norm -> activation layer -> output layer.
 *
 * One network is built for every combination of {@code useLogStd} (true/false), hidden activation
 * (TANH/IDENTITY), loss function index and l1/l2 index. Each network is first fitted for 20
 * iterations and its score must drop below 90% of the initial score. Gradients are then checked
 * with the batch norm mean/variance/log10stdev parameters excluded, and the network is passed to
 * {@code TestUtils.testModelSerialization}.
 */
@Test
    public void testGradientBNWithCNNandSubsamplingCompGraph() {
        //Parameterized test, testing combinations of:
        // (a) activation function
        // (b) Whether to test at random initialization, or after some learning (i.e., 'characteristic mode of operation')
        // (c) Loss function (with specified output activations)
        // (d) l1 and l2 values
        Activation[] activFns = {Activation.TANH, Activation.IDENTITY};
        // Never reassigned below, so the "learning first" branch always runs.
        boolean doLearningFirst = true;

        LossFunctions.LossFunction[] lossFunctions = {LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD};
        Activation[] outputActivations = {Activation.SOFTMAX}; //i.e., lossFunctions[i] used with outputActivations[i] here

        // NOTE(review): within this method l1vals/l2vals only drive the inner loop count and the
        // println below; they are never passed to the network configuration - confirm this is intended.
        double[] l2vals = {0.0, 0.1};
        double[] l1vals = {0.0, 0.2}; //i.e., use l2vals[j] with l1vals[j]

        Nd4j.getRandom().setSeed(12345);
        int minibatch = 10;
        int depth = 2;
        int hw = 5;
        int nOut = 3;
        INDArray input = Nd4j.rand(new int[]{minibatch, depth, hw, hw});
        INDArray labels = Nd4j.zeros(minibatch, nOut);
        // One-hot labels: one randomly chosen class per example, from a fixed-seed Random.
        Random r = new Random(12345);
        for (int i = 0; i < minibatch; i++) {
            labels.putScalar(i, r.nextInt(nOut), 1.0);
        }

        DataSet ds = new DataSet(input, labels);

        for (boolean useLogStd : new boolean[]{true, false}) {
            for (Activation afn : activFns) {
                for (int i = 0; i < lossFunctions.length; i++) {
                    for (int j = 0; j < l2vals.length; j++) {
                        LossFunctions.LossFunction lf = lossFunctions[i];
                        Activation outputActivation = outputActivations[i];

                        ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345)
                                .dataType(DataType.DOUBLE)
                                .optimizationAlgo(OptimizationAlgorithm.LINE_GRADIENT_DESCENT)
                                .updater(new NoOp())
                                .dist(new UniformDistribution(-2, 2)).seed(12345L).graphBuilder()
                                .addInputs("in")
                                .addLayer("0", new ConvolutionLayer.Builder(2, 2).stride(1, 1).nOut(3)
                                        .activation(afn).build(), "in")
                                .addLayer("1", new BatchNormalization.Builder().useLogStd(useLogStd).build(), "0")
                                .addLayer("2", new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
                                        .kernelSize(2, 2).stride(1, 1).build(), "1")
                                .addLayer("3", new BatchNormalization.Builder().useLogStd(useLogStd).build(), "2")
                                .addLayer("4", new ActivationLayer.Builder().activation(afn).build(), "3")
                                .addLayer("5", new OutputLayer.Builder(lf).activation(outputActivation)
                                        .nOut(nOut).build(), "4")
                                .setOutputs("5").setInputTypes(InputType.convolutional(hw, hw, depth))
                                .build();

                        ComputationGraph net = new ComputationGraph(conf);
                        net.init();
                        // Name of this test method, obtained via an anonymous class; used in messages only.
                        String name = new Object() {
                        }.getClass().getEnclosingMethod().getName();

                        if (doLearningFirst) {
                            //Run a number of iterations of learning
                            net.setInput(0, ds.getFeatures());
                            net.setLabels(ds.getLabels());
                            net.computeGradientAndScore();
                            double scoreBefore = net.score();
                            for (int k = 0; k < 20; k++)
                                net.fit(ds);
                            net.computeGradientAndScore();
                            double scoreAfter = net.score();
                            //Can't test in 'characteristic mode of operation' if not learning
                            String msg = name
                                    + " - score did not (sufficiently) decrease during learning - activationFn="
                                    + afn + ", lossFn=" + lf + ", outputActivation=" + outputActivation
                                    + ", doLearningFirst= " + doLearningFirst + " (before=" + scoreBefore
                                    + ", scoreAfter=" + scoreAfter + ")";
                            // Require the score to fall below 90% of its initial value.
                            assertTrue(msg, scoreAfter < 0.9 * scoreBefore);
                        }

                        System.out.println(name + " - activationFn=" + afn + ", lossFn=" + lf
                                + ", outputActivation=" + outputActivation + ", doLearningFirst="
                                + doLearningFirst + ", l1=" + l1vals[j] + ", l2=" + l2vals[j]);
//                        for (int k = 0; k < net.getNumLayers(); k++)
//                            System.out.println("Layer " + k + " # params: " + net.getLayer(k).numParams());

                        //Mean and variance vars are not gradient checkable; mean/variance "gradient" is used to implement running mean/variance calc
                        //i.e., runningMean = decay * runningMean + (1-decay) * batchMean
                        //However, numerical gradient will be 0 as forward pass doesn't depend on this "parameter"
                        Set<String> excludeParams = new HashSet<>(Arrays.asList("1_mean", "1_var", "3_mean", "3_var", "1_log10stdev", "3_log10stdev"));
                        boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(net).inputs(new INDArray[]{input})
                                .labels(new INDArray[]{labels}).excludeParams(excludeParams));

                        assertTrue(gradOK);
                        TestUtils.testModelSerialization(net);
                    }
                }
            }
        }
    }
 
Example 3
Source File: BidirectionalTest.java    From deeplearning4j with Apache License 2.0 4 votes vote down vote up
/**
 * Compares a graph built from {@code Bidirectional(ADD, GravesLSTM)} layers with a graph built
 * from {@code GravesBidirectionalLSTM} layers, for workspace modes NONE and ENABLED.
 *
 * After copying the parameters of the first graph into the second, asserts that outputs, scores,
 * gradients, updater state and post-fit parameters are equal.
 */
@Test
    public void compareImplementationsCompGraph(){
        WorkspaceMode[] modes = {WorkspaceMode.NONE, WorkspaceMode.ENABLED};
        for (WorkspaceMode mode : modes) {
            log.info("*** Starting workspace mode: " + mode);

            // Per the original author: Bidirectional(GravesLSTM) and GravesBidirectionalLSTM should
            // be equivalent given equivalent params; GravesBidirectionalLSTM implements ADD mode only.

            // Graph built from the generic Bidirectional wrapper.
            ComputationGraphConfiguration wrapperConf = new NeuralNetConfiguration.Builder()
                    .activation(Activation.TANH)
                    .weightInit(WeightInit.XAVIER)
                    .updater(new Adam())
                    .trainingWorkspaceMode(mode)
                    .inferenceWorkspaceMode(mode)
                    .graphBuilder()
                    .addInputs("in")
                    .layer("0", new Bidirectional(Bidirectional.Mode.ADD, new GravesLSTM.Builder().nIn(10).nOut(10).build()), "in")
                    .layer("1", new Bidirectional(Bidirectional.Mode.ADD, new GravesLSTM.Builder().nIn(10).nOut(10).build()), "0")
                    .layer("2", new RnnOutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE)
                            .nIn(10).nOut(10).build(), "1")
                    .setOutputs("2")
                    .build();

            // Graph built from the dedicated bidirectional LSTM layer.
            ComputationGraphConfiguration dedicatedConf = new NeuralNetConfiguration.Builder()
                    .activation(Activation.TANH)
                    .weightInit(WeightInit.XAVIER)
                    .updater(new Adam())
                    .trainingWorkspaceMode(mode)
                    .inferenceWorkspaceMode(mode)
                    .graphBuilder()
                    .addInputs("in")
                    .layer("0", new GravesBidirectionalLSTM.Builder().nIn(10).nOut(10).build(), "in")
                    .layer("1", new GravesBidirectionalLSTM.Builder().nIn(10).nOut(10).build(), "0")
                    .layer("2", new RnnOutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE)
                            .nIn(10).nOut(10).build(), "1")
                    .setOutputs("2")
                    .build();

            ComputationGraph wrapperNet = new ComputationGraph(wrapperConf);
            wrapperNet.init();

            ComputationGraph dedicatedNet = new ComputationGraph(dedicatedConf);
            dedicatedNet.init();

            // Parameter counts must agree, in total and for each of the three layers.
            assertEquals(wrapperNet.numParams(), dedicatedNet.numParams());
            for (int layerIdx = 0; layerIdx < 3; layerIdx++) {
                int wrapperCount = (int)wrapperNet.getLayer(layerIdx).numParams();
                int dedicatedCount = (int)dedicatedNet.getLayer(layerIdx).numParams();
                assertEquals(wrapperCount, dedicatedCount);
            }

            // Copies the flat parameter vector across; relies on both graphs sharing one layout.
            dedicatedNet.setParams(wrapperNet.params());

            INDArray features = Nd4j.rand(new int[]{3, 10, 5});

            INDArray wrapperOut = wrapperNet.outputSingle(features);
            INDArray dedicatedOut = dedicatedNet.outputSingle(features);

            assertEquals(wrapperOut, dedicatedOut);

            INDArray targets = Nd4j.rand(new int[]{3, 10, 5});

            wrapperNet.setInput(0,features);
            wrapperNet.setLabels(targets);

            dedicatedNet.setInput(0,features);
            dedicatedNet.setLabels(targets);

            wrapperNet.computeGradientAndScore();
            dedicatedNet.computeGradientAndScore();

            // Scores must match to within 1e-6.
            assertEquals(wrapperNet.score(), dedicatedNet.score(), 1e-6);

            // Gradients must match.
            Gradient wrapperGrad = wrapperNet.gradient();
            Gradient dedicatedGrad = dedicatedNet.gradient();
            assertEquals(wrapperGrad.gradient(), dedicatedGrad.gradient());

            // Updater state must match before and after applying one update to the gradients.
            ComputationGraphUpdater wrapperUpdater = (ComputationGraphUpdater) wrapperNet.getUpdater();
            ComputationGraphUpdater dedicatedUpdater = (ComputationGraphUpdater) dedicatedNet.getUpdater();
            assertEquals(wrapperUpdater.getUpdaterStateViewArray(), dedicatedUpdater.getUpdaterStateViewArray());
            wrapperUpdater.update(wrapperGrad, 0, 0, 3, LayerWorkspaceMgr.noWorkspaces());
            dedicatedUpdater.update(dedicatedGrad, 0, 0, 3, LayerWorkspaceMgr.noWorkspaces());
            assertEquals(wrapperGrad.gradient(), dedicatedGrad.gradient());
            assertEquals(wrapperUpdater.getUpdaterStateViewArray(), dedicatedUpdater.getUpdaterStateViewArray());

            // Parameters must still match after one fit call on each graph.
            wrapperNet.fit(new DataSet(features, targets));
            dedicatedNet.fit(new DataSet(features, targets));

            INDArray wrapperParams = wrapperNet.params();
            INDArray dedicatedParams = dedicatedNet.params();
            assertEquals(wrapperParams, dedicatedParams);
        }
    }
 
Example 4
Source File: BidirectionalTest.java    From deeplearning4j with Apache License 2.0 4 votes vote down vote up
/**
 * For every workspace mode, fits a graph of two {@code Bidirectional(ADD, GravesLSTM)} layers and
 * an RNN output layer, writes it to a byte array with {@code ModelSerializer}, restores it, and
 * asserts that the original and restored graphs give equal outputs, scores and gradients on
 * fresh random data.
 *
 * @throws Exception if writing or restoring the model fails
 */
@Test
public void testSerializationCompGraph() throws Exception {

    for (WorkspaceMode mode : WorkspaceMode.values()) {
        log.info("*** Starting workspace mode: " + mode);

        Nd4j.getRandom().setSeed(12345);

        ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
                .activation(Activation.TANH)
                .weightInit(WeightInit.XAVIER)
                .trainingWorkspaceMode(mode)
                .inferenceWorkspaceMode(mode)
                .updater(new Adam())
                .graphBuilder()
                .addInputs("in")
                .layer("0", new Bidirectional(Bidirectional.Mode.ADD, new GravesLSTM.Builder().nIn(10).nOut(10).dataFormat(rnnDataFormat).build()), "in")
                .layer("1", new Bidirectional(Bidirectional.Mode.ADD, new GravesLSTM.Builder().nIn(10).nOut(10).dataFormat(rnnDataFormat).build()), "0")
                .layer("2", new RnnOutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE).dataFormat(rnnDataFormat)
                        .nIn(10).nOut(10).build(), "1")
                .setOutputs("2")
                .build();

        ComputationGraph original = new ComputationGraph(conf);
        original.init();

        // The last two dimensions are swapped depending on the configured RNN data format.
        long[] shape;
        if (rnnDataFormat == NCW) {
            shape = new long[]{3, 10, 5};
        } else {
            shape = new long[]{3, 5, 10};
        }
        INDArray trainIn = Nd4j.rand(shape);
        INDArray trainLabels = Nd4j.rand(shape);

        original.fit(new DataSet(trainIn, trainLabels));

        // Round-trip the fitted model through an in-memory byte array.
        byte[] serialized;
        try (ByteArrayOutputStream buffer = new ByteArrayOutputStream()) {
            ModelSerializer.writeModel(original, buffer, true);
            serialized = buffer.toByteArray();
        }

        ComputationGraph restored =
                ModelSerializer.restoreComputationGraph(new ByteArrayInputStream(serialized), true);

        // Fresh random data for the comparison.
        INDArray evalIn = Nd4j.rand(shape);
        INDArray evalLabels = Nd4j.rand(shape);

        INDArray originalOut = original.outputSingle(evalIn);
        INDArray restoredOut = restored.outputSingle(evalIn);

        assertEquals(originalOut, restoredOut);

        original.setInput(0, evalIn);
        restored.setInput(0, evalIn);
        original.setLabels(evalLabels);
        restored.setLabels(evalLabels);

        original.computeGradientAndScore();
        restored.computeGradientAndScore();

        assertEquals(original.score(), restored.score(), 1e-6);
        assertEquals(original.gradient().gradient(), restored.gradient().gradient());
    }
}
 
Example 5
Source File: OutputLayerTest.java    From deeplearning4j with Apache License 2.0 4 votes vote down vote up
/**
 * For every workspace mode and for activations TANH and SELU, builds two graphs: one applies the
 * activation in the {@code CnnLossLayer} after an identity convolution, the other applies it in
 * the convolution before an identity loss layer. With shared parameters, asserts that both give
 * equal outputs, scores and gradients. Also checks {@code scoreExamples} on a duplicated example
 * and runs the serialization check.
 */
@Test
public void testCnnLossLayerCompGraph(){

    for (WorkspaceMode mode : WorkspaceMode.values()) {
        log.info("*** Testing workspace: " + mode);

        Activation[] activations = {Activation.TANH, Activation.SELU};
        for (Activation activation : activations) {
            // (identity conv + A in loss layer) should equal (A in conv + identity loss layer):
            // same output and same weight gradients are expected from both.

            ComputationGraphConfiguration lossActivatedConf =
                    new NeuralNetConfiguration.Builder().seed(12345L)
                            .updater(new NoOp())
                            .convolutionMode(ConvolutionMode.Same)
                            .inferenceWorkspaceMode(mode)
                            .trainingWorkspaceMode(mode)
                            .graphBuilder()
                            .addInputs("in")
                            .addLayer("0", new ConvolutionLayer.Builder().nIn(3).nOut(4).activation(Activation.IDENTITY)
                                    .kernelSize(2, 2).stride(1, 1)
                                    .dist(new NormalDistribution(0, 1.0))
                                    .updater(new NoOp()).build(), "in")
                            .addLayer("1", new CnnLossLayer.Builder(LossFunction.MSE)
                                    .activation(activation)
                                    .build(), "0")
                            .setOutputs("1")
                            .build();

            ComputationGraphConfiguration convActivatedConf =
                    new NeuralNetConfiguration.Builder().seed(12345L)
                            .updater(new NoOp())
                            .convolutionMode(ConvolutionMode.Same)
                            .inferenceWorkspaceMode(mode)
                            .trainingWorkspaceMode(mode)
                            .graphBuilder()
                            .addInputs("in")
                            .addLayer("0", new ConvolutionLayer.Builder().nIn(3).nOut(4).activation(activation)
                                    .kernelSize(2, 2).stride(1, 1)
                                    .dist(new NormalDistribution(0, 1.0))
                                    .updater(new NoOp()).build(), "in")
                            .addLayer("1", new CnnLossLayer.Builder(LossFunction.MSE)
                                    .activation(Activation.IDENTITY)
                                    .build(), "0")
                            .setOutputs("1")
                            .build();

            ComputationGraph lossActivated = new ComputationGraph(lossActivatedConf);
            lossActivated.init();

            ComputationGraph convActivated = new ComputationGraph(convActivatedConf);
            convActivated.init();

            // Give both graphs the same parameters.
            convActivated.setParams(lossActivated.params());

            INDArray features = Nd4j.rand(new int[]{3, 3, 5, 5});

            INDArray lossActivatedOut = lossActivated.outputSingle(features);
            INDArray convActivatedOut = convActivated.outputSingle(features);

            assertEquals(lossActivatedOut, convActivatedOut);

            // Random labels with the same shape as the network output.
            INDArray targets = Nd4j.rand(lossActivatedOut.shape());

            lossActivated.setInput(0,features);
            lossActivated.setLabels(targets);

            convActivated.setInput(0,features);
            convActivated.setLabels(targets);

            lossActivated.computeGradientAndScore();
            convActivated.computeGradientAndScore();

            assertEquals(lossActivated.score(), convActivated.score(), 1e-6);
            assertEquals(lossActivated.gradient().gradient(), convActivated.gradient().gradient());

            // scoreExamples: the same example stacked twice must yield shape [2, 1] and equal scores.
            INDArray singleIn = Nd4j.rand(new int[]{1, 3, 5, 5});
            INDArray singleLabels = Nd4j.rand(new int[]{1, 4, 5, 5});

            INDArray doubledIn = Nd4j.concat(0, singleIn, singleIn);
            INDArray doubledLabels = Nd4j.concat(0, singleLabels, singleLabels);

            INDArray exampleScores = lossActivated.scoreExamples(new DataSet(doubledIn, doubledLabels), false);
            assertArrayEquals(new long[]{2, 1}, exampleScores.shape());
            assertEquals(exampleScores.getDouble(0), exampleScores.getDouble(1), 1e-6);

            TestUtils.testModelSerialization(lossActivated);
        }
    }
}
 
Example 6
Source File: DTypeTests.java    From deeplearning4j with Apache License 2.0 4 votes vote down vote up
/**
 * For each global default data type (DOUBLE, FLOAT, HALF), trains a DOUBLE network for one fit
 * call, converts it to FLOAT and to HALF with {@code convertDataType}, and asserts that:
 * the converted networks' params, flattened gradients and updater state carry the target type;
 * and their output, score, gradients and updater state agree with the DOUBLE network
 * (by equality or within the tolerances used below).
 */
@Test
public void testComputationGraphTypeConversion() {

    DataType[] globalTypes = {DataType.DOUBLE, DataType.FLOAT, DataType.HALF};
    for (DataType globalType : globalTypes) {
        Nd4j.setDefaultDataTypes(globalType, globalType);

        ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
                .seed(12345)
                .weightInit(WeightInit.XAVIER)
                .updater(new Adam(0.01))
                .dataType(DataType.DOUBLE)
                .graphBuilder()
                .addInputs("in")
                .layer("l0", new DenseLayer.Builder().activation(Activation.TANH).nIn(10).nOut(10).build(), "in")
                .layer("l1", new DenseLayer.Builder().activation(Activation.TANH).nIn(10).nOut(10).build(), "l0")
                .layer("out", new OutputLayer.Builder().nIn(10).nOut(10).activation(Activation.SOFTMAX).lossFunction(LossFunctions.LossFunction.MCXENT).build(), "l1")
                .setOutputs("out")
                .build();

        ComputationGraph doubleNet = new ComputationGraph(conf);
        doubleNet.init();

        // --- Reference values from the DOUBLE network ---
        INDArray featuresD = Nd4j.rand(DataType.DOUBLE, 1, 10);
        INDArray labelsD = Nd4j.create(DataType.DOUBLE, 1, 10);
        doubleNet.fit(new DataSet(featuresD, labelsD));

        INDArray outD = doubleNet.outputSingle(featuresD);
        doubleNet.setInput(0, featuresD);
        doubleNet.setLabels(labelsD);
        doubleNet.computeGradientAndScore();
        double scoreD = doubleNet.score();
        INDArray gradsD = doubleNet.getFlattenedGradients();
        INDArray updaterStateD = doubleNet.getUpdater().getStateViewArray();
        assertEquals(DataType.DOUBLE, doubleNet.params().dataType());
        assertEquals(DataType.DOUBLE, gradsD.dataType());
        assertEquals(DataType.DOUBLE, updaterStateD.dataType());


        // --- FLOAT conversion ---
        ComputationGraph floatNet = doubleNet.convertDataType(DataType.FLOAT);
        floatNet.initGradientsView();
        assertEquals(DataType.FLOAT, floatNet.params().dataType());
        assertEquals(DataType.FLOAT, floatNet.getFlattenedGradients().dataType());
        assertEquals(DataType.FLOAT, floatNet.getUpdater(true).getStateViewArray().dataType());
        INDArray featuresF = featuresD.castTo(DataType.FLOAT);
        INDArray labelsF = labelsD.castTo(DataType.FLOAT);
        INDArray outF = floatNet.outputSingle(featuresF);
        floatNet.setInput(0, featuresF);
        floatNet.setLabels(labelsF);
        floatNet.computeGradientAndScore();
        double scoreF = floatNet.score();
        INDArray gradsF = floatNet.getFlattenedGradients();
        INDArray updaterStateF = floatNet.getUpdater().getStateViewArray();

        assertEquals(scoreD, scoreF, 1e-6);
        assertEquals(outD.castTo(DataType.FLOAT), outF);
        assertEquals(gradsD.castTo(DataType.FLOAT), gradsF);
        assertTrue(updaterStateD.castTo(DataType.FLOAT).equalsWithEps(updaterStateF, 1e-4));

        // --- HALF conversion (looser tolerances) ---
        ComputationGraph halfNet = doubleNet.convertDataType(DataType.HALF);
        halfNet.initGradientsView();
        assertEquals(DataType.HALF, halfNet.params().dataType());
        assertEquals(DataType.HALF, halfNet.getFlattenedGradients().dataType());
        assertEquals(DataType.HALF, halfNet.getUpdater(true).getStateViewArray().dataType());

        INDArray featuresH = featuresD.castTo(DataType.HALF);
        INDArray labelsH = labelsD.castTo(DataType.HALF);
        INDArray outH = halfNet.outputSingle(featuresH);
        halfNet.setInput(0, featuresH);
        halfNet.setLabels(labelsH);
        halfNet.computeGradientAndScore();
        double scoreH = halfNet.score();
        INDArray gradsH = halfNet.getFlattenedGradients();
        INDArray updaterStateH = halfNet.getUpdater().getStateViewArray();

        assertEquals(scoreD, scoreH, 1e-4);
        assertTrue(outD.castTo(DataType.HALF).equalsWithEps(outH, 1e-3));
        assertTrue(gradsD.castTo(DataType.HALF).equalsWithEps(gradsH, 1e-3));
        assertTrue(updaterStateD.castTo(DataType.HALF).equalsWithEps(updaterStateH, 1e-4));
    }
}
 
Example 7
Source File: DTypeTests.java    From deeplearning4j with Apache License 2.0 4 votes vote down vote up
/**
 * Checks data-type handling of graphs whose first layer is an embedding-style layer.
 *
 * Iterates over every combination of global default dtype, network dtype, the {@code frozen} flag
 * and three test variants (0: EmbeddingLayer, 1: EmbeddingSequenceLayer + global pooling,
 * 2: RepeatVector + global pooling). For each, asserts that the network output and every
 * feed-forward activation other than "in" has the network dtype, then runs gradient computation
 * and fitting, first with the original arrays and then with input/labels cast to each of
 * DOUBLE, FLOAT and HALF.
 */
@Test
public void testEmbeddingDtypes() {
    for (DataType globalDtype : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF}) {
        Nd4j.setDefaultDataTypes(globalDtype, globalDtype);
        for (DataType networkDtype : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF}) {
            for (boolean frozen : new boolean[]{false, true}) {
                for (int test = 0; test < 3; test++) {
                    // The global default must still be the one set at the top of the outer loop.
                    assertEquals(globalDtype, Nd4j.dataType());
                    assertEquals(globalDtype, Nd4j.defaultFloatingPointType());

                    String msg = "Global dtype: " + globalDtype + ", network dtype: " + networkDtype + ", test=" + test;

                    // Despite its name, 'conf' is a GraphBuilder; layers are added to it below
                    // before build() is called.
                    ComputationGraphConfiguration.GraphBuilder conf = new NeuralNetConfiguration.Builder()
                            .dataType(networkDtype)
                            .seed(123)
                            .updater(new NoOp())
                            .weightInit(new WeightInitDistribution(new UniformDistribution(-6, 6)))
                            .graphBuilder()
                            .addInputs("in")
                            .setOutputs("out");

                    INDArray input;
                    if (test == 0) {
                        // Variant 0: EmbeddingLayer (optionally wrapped in FrozenLayer), integer input of shape [10, 1].
                        if (frozen) {
                            conf.layer("0", new FrozenLayer(new EmbeddingLayer.Builder().nIn(5).nOut(5).build()), "in");
                        } else {
                            conf.layer("0", new EmbeddingLayer.Builder().nIn(5).nOut(5).build(), "in");
                        }
                        input = Nd4j.rand(networkDtype, 10, 1).muli(5).castTo(DataType.INT);
                        conf.setInputTypes(InputType.feedForward(1));
                    } else if (test == 1) {
                        // Variant 1: EmbeddingSequenceLayer (optionally frozen) followed by PNORM global pooling,
                        // integer input of shape [10, 1, 5].
                        if (frozen) {
                            conf.layer("0", new FrozenLayer(new EmbeddingSequenceLayer.Builder().nIn(5).nOut(5).build()), "in");
                        } else {
                            conf.layer("0", new EmbeddingSequenceLayer.Builder().nIn(5).nOut(5).build(), "in");
                        }
                        conf.layer("gp", new GlobalPoolingLayer.Builder(PoolingType.PNORM).pnorm(2).poolingDimensions(2).build(), "0");
                        input = Nd4j.rand(networkDtype, 10, 1, 5).muli(5).castTo(DataType.INT);
                        conf.setInputTypes(InputType.recurrent(1));
                    } else {
                        // Variant 2: RepeatVector followed by SUM global pooling, floating-point input of shape [10, 5].
                        // The 'frozen' flag is not consulted in this branch.
                        conf.layer("0", new RepeatVector.Builder().repetitionFactor(5).nOut(5).build(), "in");
                        conf.layer("gp", new GlobalPoolingLayer.Builder(PoolingType.SUM).build(), "0");
                        input = Nd4j.rand(networkDtype, 10, 5);
                        conf.setInputTypes(InputType.feedForward(5));
                    }

                    // Common tail appended to every variant; the last layer is the declared output "out".
                    conf.appendLayer("el", new ElementWiseMultiplicationLayer.Builder().nOut(5).build())
                            .appendLayer("ae", new AutoEncoder.Builder().nOut(5).build())
                            .appendLayer("prelu", new PReLULayer.Builder().nOut(5).inputShape(5).build())
                            .appendLayer("out", new OutputLayer.Builder().nOut(10).build());

                    ComputationGraph net = new ComputationGraph(conf.build());
                    net.init();

                    INDArray label = Nd4j.zeros(networkDtype, 10, 10);

                    INDArray out = net.outputSingle(input);
                    assertEquals(msg, networkDtype, out.dataType());
                    Map<String, INDArray> ff = net.feedForward(input, false);
                    for (Map.Entry<String, INDArray> e : ff.entrySet()) {
                        // Skip the raw input entry; its dtype is not checked against the network dtype.
                        if (e.getKey().equals("in"))
                            continue;
                        String s = msg + " - layer: " + e.getKey();
                        assertEquals(s, networkDtype, e.getValue().dataType());
                    }

                    net.setInput(0, input);
                    net.setLabels(label);
                    net.computeGradientAndScore();

                    net.fit(new DataSet(input, label));

                    logUsedClasses(net);

                    //Now, test mismatched dtypes for input/labels:
                    // No assertions here: the calls only need to complete without throwing.
                    for (DataType inputLabelDtype : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF}) {
                        INDArray in2 = input.castTo(inputLabelDtype);
                        INDArray label2 = label.castTo(inputLabelDtype);
                        net.output(in2);
                        net.setInput(0, in2);
                        net.setLabels(label2);
                        net.computeGradientAndScore();

                        net.fit(new DataSet(in2, label2));
                    }
                }
            }
        }
    }
}