Java Code Examples for org.nd4j.autodiff.samediff.SameDiff#addLossVariable()

The following examples show how to use org.nd4j.autodiff.samediff.SameDiff#addLossVariable(). You can vote up the examples you find helpful or vote down those you don't, and you can navigate to the original project or source file via the links above each example. Related API usage examples are available in the sidebar.
Example 1
Source File: MiscOpValidation.java    From deeplearning4j with Apache License 2.0 6 votes vote down vote up
@Test
public void testTensorGradTensorMmul() {
    OpValidationSuite.ignoreFailing();

    Nd4j.getRandom().setSeed(12345);

    // Two random rank-3 tensors; contract dimension 0 of x against dimension 1 of y.
    SameDiff sd = SameDiff.create();
    INDArray firstInput = Nd4j.rand(new long[]{2, 2, 2});
    INDArray secondInput = Nd4j.rand(new long[]{2, 2, 2});
    SDVariable x = sd.var("x", firstInput);
    SDVariable y = sd.var("y", secondInput);
    SDVariable contracted = sd.tensorMmul(x, y, new int[]{0}, new int[]{1});

    // The output shape must match what ArrayUtil predicts for this contraction.
    long[] expectedShape = ArrayUtil.getTensorMmulShape(
            new long[]{2, 2, 2}, new long[]{2, 2, 2}, new int[][]{{0}, {1}});
    assertArrayEquals(expectedShape, contracted.eval().shape());
    assertEquals(16, sd.numElements());

    // Gradient checking requires a scalar loss; stddev of the result serves that role.
    sd.addLossVariable(sd.standardDeviation(contracted, true));

    assertNull(OpValidation.validate(new TestCase(sd)));
}
 
Example 2
Source File: MiscOpValidation.java    From deeplearning4j with Apache License 2.0 6 votes vote down vote up
@Test
public void testFlatten() {
    SameDiff sd = SameDiff.create();

    // 3x3x3 input holding 1..27; flattening in 'c' order should yield the same
    // values as a plain linspace vector.
    INDArray input = Nd4j.linspace(DataType.DOUBLE, 1, 27, 1).reshape(3, 3, 3);
    INDArray flatExpected = Nd4j.linspace(DataType.DOUBLE, 1, 27, 1);
    SDVariable sdInput = sd.var(input);

    SDVariable flattened = new Flatten(sd, 'c', sdInput).outputVariable();

    // Scalar loss needed so the gradient check has something to differentiate.
    sd.addLossVariable(sd.standardDeviation(sdInput, true));

    TestCase testCase = new TestCase(sd)
            .expectedOutput(flattened.name(), flatExpected)
            .gradientCheck(true);

    assertNull(OpValidation.validate(testCase));
}
 
Example 3
Source File: MiscOpValidation.java    From deeplearning4j with Apache License 2.0 6 votes vote down vote up
@Test
public void testLgamma() {
    SameDiff sd = SameDiff.create();

    // 3x4 input holding 1..12; expected values are log-gamma of each element.
    INDArray input = Nd4j.linspace(DataType.DOUBLE, 1, 12, 1).reshape(3, 4);
    SDVariable sdInput = sd.var(input);

    INDArray expectedLgamma = Nd4j.createFromArray(new double[]{
            0.0,0.0,0.6931472,1.7917595,3.1780539,4.787492,6.5792513,8.525162,10.604603,12.801827,15.104413,17.502308
    }).reshape(3, 4);

    SDVariable lgammaOut = new Lgamma(sd, sdInput).outputVariable();

    // Scalar loss needed so the gradient check has something to differentiate.
    sd.addLossVariable(sd.standardDeviation(sdInput, true));

    TestCase testCase = new TestCase(sd)
            .expectedOutput(lgammaOut.name(), expectedLgamma)
            .gradientCheck(true);

    assertNull(OpValidation.validate(testCase));
}
 
Example 4
Source File: MiscOpValidation.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
@Test
public void testBiasAdd() {
    SameDiff sd = SameDiff.create();

    // Both operands are 1..12, so element-wise bias-add doubles each value.
    INDArray data = Nd4j.linspace(1, 12, 12);
    INDArray bias = Nd4j.linspace(1, 12, 12);
    SDVariable dataVar = sd.var(data);
    SDVariable biasVar = sd.var(bias);

    INDArray expectedSum = Nd4j.createFromArray(new double[]{
            2.0000,    4.0000,    6.0000,    8.0000,   10.0000,   12.0000,   14.0000,   16.0000,   18.0000,   20.0000,   22.0000,   24.0000
    });

    SDVariable biasAdded = new BiasAdd(sd, dataVar, biasVar, false).outputVariable();

    // Register a scalar loss per input so gradients flow to both variables.
    sd.addLossVariable(sd.standardDeviation(dataVar, true));
    sd.addLossVariable(sd.standardDeviation(biasVar, true));

    TestCase testCase = new TestCase(sd)
            .expectedOutput(biasAdded.name(), expectedSum)
            .gradientCheck(true);

    assertNull(OpValidation.validate(testCase));
}
 
Example 5
Source File: MiscOpValidation.java    From deeplearning4j with Apache License 2.0 4 votes vote down vote up
@Test
public void testFusedBatchNorm() {
    // Test is currently known-failing; skipped by the validation suite.
    OpValidationSuite.ignoreFailing();
    SameDiff sameDiff = SameDiff.create();

    // Input: 2x2x3x4 tensor of 1..48; scale = 0.5 and offset = 2.0 broadcast
    // over the last dimension (size 4).
    INDArray x = Nd4j.linspace(DataType.DOUBLE, 1.0, 1.0, 2*2*3*4).reshape(2,2,3,4);
    INDArray scale = Nd4j.create(DataType.DOUBLE, 4);
    scale.assign(0.5);
    INDArray offset = Nd4j.create(DataType.DOUBLE, 4);
    offset.assign(2.0);

    SDVariable input1 = sameDiff.var(x);
    SDVariable input2 = sameDiff.var(scale);
    SDVariable input3 = sameDiff.var(offset);

    // Precomputed expected normalized output, per-channel batch mean, and
    // per-channel batch variance for the input above.
    INDArray expectedY = Nd4j.createFromArray(new double[]{
            985.5258,  985.5258,  985.5258,  985.5258,
            659.7321,  659.7321,  659.7321,  659.7321,
            399.0972,  399.0972,  399.0972,  399.0972,
            203.6210,  203.6210,  203.6210,  203.6210,
            73.3036,   73.3036,   73.3036,   73.3036,
            8.1448,    8.1448,    8.1448,    8.1448,
            8.1448,    8.1448,    8.1448,    8.1448,
            73.3036,   73.3036,   73.3036,   73.3036,
            203.6210,  203.6210,  203.6210,  203.6210,
            399.0972,  399.0972,  399.0972,  399.0972,
            659.7321,  659.7321,  659.7321,  659.7321,
            985.5258,  985.5258,  985.5258,  985.5258}).reshape(x.shape());
    INDArray expectedBatchMean = Nd4j.createFromArray(new double[]{23.,  24.,  25.,  26.});
    INDArray expectedBatchVar = Nd4j.createFromArray(new double[]{208.00001526,  208.00001526,  208.00001526,  208.00001526});

    // FusedBatchNorm emits three outputs: y, batch mean, batch variance.
    // NOTE(review): the trailing int args (0, 1) presumably select data format
    // and training mode — confirm against the FusedBatchNorm op definition.
    SDVariable[] outputs = new FusedBatchNorm(sameDiff, input1, input2, input3, 0, 1).outputVariables();
    // Scalar loss required for the gradient check.
    SDVariable loss = sameDiff.standardDeviation(input1, true);
    sameDiff.addLossVariable(loss);

    TestCase tc = new TestCase(sameDiff)
            .gradientCheck(true)
            .expectedOutput(outputs[0].name(), expectedY)
            .expectedOutput(outputs[1].name(), expectedBatchMean)
            .expectedOutput(outputs[2].name(), expectedBatchVar);

    String err = OpValidation.validate(tc);
    assertNull(err);
}
 
Example 6
Source File: ShapeOpValidation.java    From deeplearning4j with Apache License 2.0 4 votes vote down vote up
@Test
public void testSegmentOps(){
    // Test is currently known-failing; skipped by the validation suite.
    OpValidationSuite.ignoreFailing();
    //https://github.com/deeplearning4j/deeplearning4j/issues/6952
    // Sorted-segment fixtures: segment ids must be non-decreasing.
    INDArray s = Nd4j.create(new double[]{0,0,0,1,2,2,3,3}, new long[]{8}).castTo(DataType.INT);
    INDArray d = Nd4j.create(new double[]{5,1,7,2,3,4,1,3}, new long[]{8});
    int numSegments = 4;

    List<String> failed = new ArrayList<>();

    // Ops prefixed with "u" are the unsorted-segment variants.
    for(String op : new String[]{"max", "min", "mean", "prod", "sum",
            "umax", "umin", "umean", "uprod", "usum", "usqrtn"}) {
        log.info("Starting test: {}", op);

        if(op.startsWith("u")){
            //Unsorted segment cases
            // Replace fixtures with out-of-order segment ids; the (id, value)
            // pairs are a permutation of the sorted case, so expected results
            // per segment are identical.
            s = Nd4j.create(new double[]{3,1,0,0,2,0,3,2}, new long[]{8}).castTo(DataType.INT);
            d = Nd4j.create(new double[]{1,2,5,7,3,1,3,4}, new long[]{8});
        }

        SameDiff sd = SameDiff.create();
        SDVariable data = sd.var("data", d);
        // Segment ids are a constant: they are integer indices, not trainable.
        SDVariable segments = sd.constant("segments", s);

        SDVariable sm;
        INDArray exp;
        // Each case pairs the op under test with its hand-computed expected
        // per-segment result for the fixture data above.
        switch (op){
            case "max":
                sm = sd.segmentMax(data, segments);
                exp = Nd4j.create(new double[]{7, 2, 4, 3});
                break;
            case "min":
                sm = sd.segmentMin(data, segments);
                exp = Nd4j.create(new double[]{1, 2, 3, 1});
                break;
            case "mean":
                sm = sd.segmentMean(data, segments);
                exp = Nd4j.create(new double[]{4.3333333333, 2, 3.5, 2});
                break;
            case "prod":
                sm = sd.segmentProd(data, segments);
                exp = Nd4j.create(new double[]{35, 2, 12, 3});
                break;
            case "sum":
                sm = sd.segmentSum(data, segments);
                exp = Nd4j.create(new double[]{13, 2, 7, 4});
                break;
            case "umax":
                sm = sd.unsortedSegmentMax(data, segments, numSegments);
                exp = Nd4j.create(new double[]{7, 2, 4, 3});
                break;
            case "umin":
                sm = sd.unsortedSegmentMin(data, segments, numSegments);
                exp = Nd4j.create(new double[]{1, 2, 3, 1});
                break;
            case "umean":
                sm = sd.unsortedSegmentMean(data, segments, numSegments);
                exp = Nd4j.create(new double[]{4.3333333333, 2, 3.5, 2});
                break;
            case "uprod":
                sm = sd.unsortedSegmentProd(data, segments, numSegments);
                exp = Nd4j.create(new double[]{35, 2, 12, 3});
                break;
            case "usum":
                sm = sd.unsortedSegmentSum(data, segments, numSegments);
                exp = Nd4j.create(new double[]{13, 2, 7, 4});
                break;
            case "usqrtn":
                // sqrtN: per-segment sum divided by sqrt of the segment size.
                sm = sd.unsortedSegmentSqrtN(data, segments, numSegments);
                exp = Nd4j.create(new double[]{(5+7+1)/Math.sqrt(3), 2, (3+4)/Math.sqrt(2), (1+3)/Math.sqrt(2)});
                break;
            default:
                throw new RuntimeException();
        }

        // Scalar loss required for the gradient check.
        SDVariable loss = sm.std(true);
        sd.addLossVariable(loss);

        TestCase tc = new TestCase(sd)
                .testName(op)
                .expected(sm, exp)
                .gradientCheck(true)
                // Segment indices are integer-valued; gradients are undefined
                // for them, so they are excluded from the gradient check.
                .gradCheckSkipVariables(segments.name());

        String err = OpValidation.validate(tc);
        if(err != null)
            failed.add(err);
    }

    assertEquals(failed.toString(), 0, failed.size());
}