Java Code Examples for org.nd4j.autodiff.samediff.SDVariable#setArray()

The following examples show how to use org.nd4j.autodiff.samediff.SDVariable#setArray(). All of them are drawn from the deeplearning4j project; the source file and license are noted above each example.
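Before diving into the full test cases below, here is a minimal, self-contained sketch of the pattern they all share: declare a variable on a SameDiff graph, bind a concrete INDArray to it with setArray(), and evaluate. This is an illustrative sketch rather than code from the project: the class name SetArraySketch and the 2x3 shape are made up, and it assumes an ND4J version recent enough to provide SDVariable#eval().

import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class SetArraySketch {
    public static void main(String[] args) {
        SameDiff sd = SameDiff.create();

        //Declare a 2x3 variable; no data is associated with it yet
        SDVariable in = sd.var("in", 2, 3);
        SDVariable out = sd.math().tanh(in);

        //Bind a concrete array to the variable; the shape must match the declaration
        in.setArray(Nd4j.rand(2, 3));

        //The bound array is used when the graph is evaluated
        INDArray result = out.eval();
        System.out.println(result);
    }
}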
Example 1
Source File: LayerOpValidation.java    From deeplearning4j with Apache License 2.0
@Test
public void testLrn2d() {
    Nd4j.getRandom().setSeed(12345);

    int[][] inputSizes = new int[][]{{1, 3, 8, 8}, {3, 6, 12, 12}};

    List<String> failed = new ArrayList<>();

    for (int[] inSizeNCHW : inputSizes) {

        SameDiff sd = SameDiff.create();
        //LRN
        String msg = "LRN with NCHW - input" + Arrays.toString(inSizeNCHW);
        int[] inSize = inSizeNCHW;
        SDVariable in = sd.var("in", inSize);
        SDVariable out = sd.cnn().localResponseNormalization(in, LocalResponseNormalizationConfig.builder()
                .depth(3)
                .bias(1)
                .alpha(1)
                .beta(0.5)
                .build());

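        //Bind a concrete input array to the variable before running validation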
        INDArray inArr = Nd4j.rand(inSize).muli(10);
        in.setArray(inArr);
        SDVariable loss = sd.mean("loss", out);

        log.info("Starting test: " + msg);
        TestCase tc = new TestCase(sd).gradientCheck(true);
        String error = OpValidation.validate(tc);
        if (error != null) {
            failed.add(msg);
        }

    }
    assertEquals(failed.toString(), 0, failed.size());
}
 
Example 2
Source File: LayerOpValidation.java    From deeplearning4j with Apache License 2.0
@Test
public void testConv3d() {
    //Pooling3d, Conv3D, batch norm
    Nd4j.getRandom().setSeed(12345);

    //NCDHW format
    int[][] inputSizes = new int[][]{{2, 3, 4, 5, 5}};

    List<String> failed = new ArrayList<>();

    for (int[] inSizeNCDHW : inputSizes) {
        for (boolean ncdhw : new boolean[]{true, false}) {
            int nIn = inSizeNCDHW[1];
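            //ncdhwToNdhwc is a helper in this test class that permutes a shape from NCDHW to NDHWC order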
            int[] shape = (ncdhw ? inSizeNCDHW : ncdhwToNdhwc(inSizeNCDHW));

            for (int i = 0; i < 5; i++) {
                SameDiff sd = SameDiff.create();
                SDVariable in = sd.var("in", shape);

                SDVariable out;
                String msg;
                switch (i) {
                    case 0:
                        //Conv3d, with bias, same
                        msg = "0 - conv3d+bias+same, ncdhw=" + ncdhw + " - input " + Arrays.toString(shape);
                        SDVariable w0 = sd.var("w0", Nd4j.rand(new int[]{2, 2, 2, nIn, 3}).muli(10));  //[kD, kH, kW, iC, oC]
                        SDVariable b0 = sd.var("b0", Nd4j.rand(new long[]{3}).muli(10));
                        out = sd.cnn().conv3d(in, w0, b0, Conv3DConfig.builder()
                                .dataFormat(ncdhw ? Conv3DConfig.NCDHW : Conv3DConfig.NDHWC)
                                .isSameMode(true)
                                .kH(2).kW(2).kD(2)
                                .sD(1).sH(1).sW(1)
                                .build());
                        break;
                    case 1:
                        //Conv3d, no bias, no same
                        msg = "1 - conv3d+no bias+no same, ncdhw=" + ncdhw + " - input " + Arrays.toString(shape);
                        SDVariable w1 = sd.var("w1", Nd4j.rand(new int[]{2, 2, 2, nIn, 3}).muli(10));  //[kD, kH, kW, iC, oC]
                        out = sd.cnn().conv3d(in, w1, Conv3DConfig.builder()
                                .dataFormat(ncdhw ? Conv3DConfig.NCDHW : Conv3DConfig.NDHWC)
                                .isSameMode(false)
                                .kH(2).kW(2).kD(2)
                                .sD(1).sH(1).sW(1)
                                .build());
                        break;
                    case 2:
                        //pooling3d - average, no same
                        msg = "2 - pooling 3d, average, same";
                        out = sd.cnn().avgPooling3d(in, Pooling3DConfig.builder()
                                .kH(2).kW(2).kD(2)
                                .sH(1).sW(1).sD(1)
                                .isSameMode(false)
                                .isNCDHW(ncdhw)
                                .build());
                        break;
                    case 3:
                        //pooling 3d - max, same
                        msg = "3 - pooling 3d, max, same";
                        out = sd.cnn().maxPooling3d(in, Pooling3DConfig.builder()
                                .kH(2).kW(2).kD(2)
                                .sH(1).sW(1).sD(1)
                                .isSameMode(true)
                                .isNCDHW(ncdhw)
                                .build());
                        break;
                    case 4:
                        //Deconv3d
                        msg = "4 - deconv3d, ncdhw=" + ncdhw;
                        SDVariable wDeconv = sd.var(Nd4j.rand(new int[]{2, 2, 2, 3, nIn}));  //[kD, kH, kW, oC, iC]
                        SDVariable bDeconv = sd.var(Nd4j.rand(new int[]{3}));
                        out = sd.cnn().deconv3d("Deconv3d", in, wDeconv, bDeconv, DeConv3DConfig.builder()
                                .kD(2).kH(2).kW(2)
                                .isSameMode(true)
                                .dataFormat(ncdhw ? DeConv3DConfig.NCDHW : DeConv3DConfig.NDHWC)
                                .build());
                        break;
                    case 5:
                        //Batch norm - 3d input
                        throw new RuntimeException("Batch norm test not yet implemented");
                    default:
                        throw new RuntimeException();
                }

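                //Bind a concrete input array to "in" before running validation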
                INDArray inArr = Nd4j.rand(shape).muli(10);
                in.setArray(inArr);
                SDVariable loss = sd.standardDeviation("loss", out, true);

                log.info("Starting test: " + msg);
                TestCase tc = new TestCase(sd).gradientCheck(true);
                tc.testName(msg);
                String error = OpValidation.validate(tc);
                if (error != null) {
                    failed.add(msg);
                }
            }
        }
    }

    assertEquals(failed.toString(), 0, failed.size());
}
 
Example 3
Source File: MiscOpValidation.java    From deeplearning4j with Apache License 2.0
@Test
public void testScatterOpGradients() {
    List<String> failed = new ArrayList<>();

    for (int i = 0; i < 7; i++) {
        Nd4j.getRandom().setSeed(12345);

        SameDiff sd = SameDiff.create();

        SDVariable in = sd.var("in", DataType.DOUBLE, 20, 10);
        SDVariable indices = sd.var("indices", DataType.INT, new long[]{5});
        SDVariable updates = sd.var("updates", DataType.DOUBLE, 5, 10);


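        //Attach concrete arrays to the graph variables declared above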
        in.setArray(Nd4j.rand(DataType.DOUBLE, 20, 10));
        indices.setArray(Nd4j.create(new double[]{3, 4, 5, 10, 18}).castTo(DataType.INT));
        updates.setArray(Nd4j.rand(DataType.DOUBLE, 5, 10).muli(2).subi(1));

        SDVariable scatter;
        String name;
        switch (i) {
            case 0:
                scatter = sd.scatterAdd("s", in, indices, updates);
                name = "scatterAdd";
                break;
            case 1:
                scatter = sd.scatterSub("s", in, indices, updates);
                name = "scatterSub";
                break;
            case 2:
                scatter = sd.scatterMul("s", in, indices, updates);
                name = "scatterMul";
                break;
            case 3:
                scatter = sd.scatterDiv("s", in, indices, updates);
                name = "scatterDiv";
                break;
            case 4:
                scatter = sd.scatterUpdate("s", in, indices, updates);
                name = "scatterUpdate";
                break;
            case 5:
                scatter = sd.scatterMax("s", in, indices, updates);
                name = "scatterMax";
                break;
            case 6:
                scatter = sd.scatterMin("s", in, indices, updates);
                name = "scatterMin";
                break;
            default:
                throw new RuntimeException();
        }

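        //Manually compute the expected result: apply each update row to its destination row in a copy of the input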
        INDArray exp = in.getArr().dup();
        int[] indicesInt = indices.getArr().dup().data().asInt();
        for( int j=0; j<indicesInt.length; j++ ){
            INDArray updateRow = updates.getArr().getRow(j);
            INDArray destinationRow = exp.getRow(indicesInt[j]);
            switch (i){
                case 0:
                    destinationRow.addi(updateRow);
                    break;
                case 1:
                    destinationRow.subi(updateRow);
                    break;
                case 2:
                    destinationRow.muli(updateRow);
                    break;
                case 3:
                    destinationRow.divi(updateRow);
                    break;
                case 4:
                    destinationRow.assign(updateRow);
                    break;
                case 5:
                    destinationRow.assign(Transforms.max(destinationRow, updateRow, true));
                    break;
                case 6:
                    destinationRow.assign(Transforms.min(destinationRow, updateRow, true));
                    break;
                default:
                    throw new RuntimeException();
            }
        }

        SDVariable loss = sd.sum(scatter);  //.standardDeviation(scatter, true);  //.sum(scatter);  //TODO stdev might be better here as gradients are non-symmetrical...


        TestCase tc = new TestCase(sd)
                .expected(scatter, exp)
                .gradCheckSkipVariables(indices.name());

        String error = OpValidation.validate(tc);
        if(error != null){
            failed.add(name);
        }
    }

    assertEquals(failed.toString(), 0, failed.size());
}
 
Example 4
Source File: TransformOpValidation.java    From deeplearning4j with Apache License 2.0
@Test
public void testIsX() {
    List<String> failed = new ArrayList<>();

    for (int i = 0; i < 4; i++) {

        SameDiff sd = SameDiff.create();
        SDVariable in = sd.var("in", 4);

        SDVariable out;
        INDArray exp;
        INDArray inArr;
        switch (i) {
            case 0:
                inArr = Nd4j.create(new double[]{10, Double.POSITIVE_INFINITY, 0, Double.NEGATIVE_INFINITY});
                exp = Nd4j.create(new boolean[]{true, false, true, false});
                out = sd.math().isFinite(in);
                break;
            case 1:
                inArr = Nd4j.create(new double[]{10, Double.POSITIVE_INFINITY, 0, Double.NEGATIVE_INFINITY});
                exp = Nd4j.create(new boolean[]{false, true, false, true});
                out = sd.math().isInfinite(in);
                break;
            case 2:
                //TODO: IsMax supports both bool and float out: https://github.com/deeplearning4j/deeplearning4j/issues/6872
                inArr = Nd4j.create(new double[]{-3, 5, 0, 2});
                exp = Nd4j.create(new boolean[]{false, true, false, false});
                out = sd.math().isMax(in);
                break;
            case 3:
                inArr = Nd4j.create(new double[]{0, Double.NaN, 10, Double.NaN});
                exp = Nd4j.create(new boolean[]{false, true, false, true});
                out = sd.math().isNaN(in);
                break;
            default:
                throw new RuntimeException();
        }

        SDVariable other = sd.var("other", Nd4j.rand(DataType.DOUBLE, 4));

        SDVariable loss = out.castTo(DataType.DOUBLE).add(other).mean();
        TestCase tc = new TestCase(sd)
                .gradientCheck(false)   //Can't gradient check - in -> boolean -> cast(double)
                .expected(out, exp);

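        //Bind the case-specific input array selected in the switch above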
        in.setArray(inArr);

        String err = OpValidation.validate(tc, true);
        if (err != null) {
            failed.add(err);
        }
    }
    assertEquals(failed.toString(), 0, failed.size());
}