Java Code Examples for org.nd4j.autodiff.samediff.SDVariable#mul()

The following examples show how to use org.nd4j.autodiff.samediff.SDVariable#mul(). Each example is taken from an open-source project; the source file, originating project, and license are noted above its code. A minimal usage sketch comes first.
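
As a sketch of the basic API (assuming a recent Deeplearning4j/ND4J release where SDVariable.eval() is available; the variable names and values are illustrative): mul() multiplies element-wise against another SDVariable, and an overload accepts a scalar.

import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.factory.Nd4j;

SameDiff sd = SameDiff.create();
SDVariable a = sd.var("a", Nd4j.createFromArray(1.0, 2.0, 3.0));
SDVariable b = sd.var("b", Nd4j.createFromArray(4.0, 5.0, 6.0));

SDVariable elementwise = a.mul(b);   //element-wise product: [4.0, 10.0, 18.0]
SDVariable scaled = a.mul(2.0);      //scalar overload: [2.0, 4.0, 6.0]

System.out.println(elementwise.eval());
System.out.println(scaled.eval());
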
Example 1
Source File: CosineSimilarity.java    From deeplearning4j with Apache License 2.0
public static List<SDVariable> doDiff(SameDiff sameDiff, SDVariable x, SDVariable y,
                                      SDVariable gradOut, boolean keepDims, int... dimensions){
    SDVariable a = sameDiff.sum(x.mul(y), true, dimensions);   //numerator of cosine similarity: sum(x*y)
    SDVariable l2x = sameDiff.norm2(x, true, dimensions);
    SDVariable l2y = sameDiff.norm2(y, true, dimensions);
    SDVariable b = l2x.mul(l2y);                                //denominator: norm2(x) * norm2(y)

    SDVariable l2xSq = sameDiff.math().square(l2x);
    SDVariable l2ySq = sameDiff.math().square(l2y);
    SDVariable broadcastableGrad;
    if(keepDims || dimensions == null || dimensions.length == 0 || (dimensions.length == 1 && dimensions[0] == Integer.MAX_VALUE)){
        //keepDims or full array reduction
        broadcastableGrad = gradOut;
    } else {
        broadcastableGrad = SameDiffUtils.reductionBroadcastableWithOrigShape(x, sameDiff.constant(Nd4j.createFromArray(dimensions)), gradOut);
    }

    SDVariable dcdx = y.sub(x.mul(a).div(l2xSq)).div(b);
    SDVariable dcdy = x.sub(y.mul(a).div(l2ySq)).div(b);

    return Arrays.asList(dcdx.mul(broadcastableGrad), dcdy.mul(broadcastableGrad));
}
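
For context, the forward pass these gradients correspond to can be assembled from the same ops the example uses. A minimal sketch (the shapes and reduction dimension are illustrative):

SameDiff sd = SameDiff.create();
SDVariable x = sd.var("x", Nd4j.rand(DataType.DOUBLE, 3, 4));
SDVariable y = sd.var("y", Nd4j.rand(DataType.DOUBLE, 3, 4));

//cosineSim(x,y) = sum(x*y) / (norm2(x) * norm2(y)) -- the quantities a and b in doDiff above
SDVariable num = sd.sum(x.mul(y), true, 1);
SDVariable denom = sd.norm2(x, true, 1).mul(sd.norm2(y, true, 1));
SDVariable cosine = num.div(denom);

System.out.println(cosine.eval());   //shape [3, 1], values in [-1, 1]
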
 
Example 2
Source File: Cross.java    From deeplearning4j with Apache License 2.0
@Override
public List<SDVariable> doDiff(List<SDVariable> gradients) {
    /**
     * dL / dx = dL / dCross * dCross / dx
     * dCross(a,b) / da = Cross(1, b)
     * dCross(a,b) / db = Cross(a, 1)
     *
     * return (grad * Cross(1, b), grad * Cross(a, 1))
     */
    SDVariable grad = gradients.get(0);
    SDVariable a = larg();
    SDVariable b = rarg();
    SDVariable ones = sameDiff.onesLike(a);

    SDVariable gradLeft = grad.mul(sameDiff.math().cross(b, ones));
    SDVariable gradRight = grad.mul(sameDiff.math().cross(ones, a));

    return Arrays.asList(gradLeft, gradRight);
}
 
Example 3
Source File: ScatterMin.java    From deeplearning4j with Apache License 2.0
@Override
public List<SDVariable> doDiff(List<SDVariable> gradOut) {
    //3 args: ref, indices, updates
    //For non-modified indices, input gradient (reference) is same as output gradient
    //For modified indices, dL/dref = dL/dOut if(ref[index[i],j] == min) or 0 otherwise
    //And for updates, dL/du = dL/dOut if(update[i,j]==min) or 0 otherwise

    SDVariable notModified = arg(0).eq(outputVariable()).castTo(arg(0).dataType());   //0 if modified, 1 otherwise
    SDVariable refGrad = gradOut.get(0).mul(notModified);

    SDVariable gatherOut = sameDiff.gather(outputVariable(), arg(1), 0);
    SDVariable gatherGrad = sameDiff.gather(gradOut.get(0), arg(1), 0);
    SDVariable outIsUpdate = gatherOut.eq(arg(2)).castTo(arg(2).dataType());
    SDVariable updateGrad = gatherGrad.mul(outIsUpdate);

    return Arrays.asList(refGrad, sameDiff.zerosLike(arg(1)), updateGrad);
}
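
To make the three arguments concrete, here is a forward-pass sketch of the op being differentiated (the shapes and values are illustrative, and sd.scatterMin is assumed to take ref, indices, updates in that order, matching the comment above):

SameDiff sd = SameDiff.create();
SDVariable ref = sd.var("ref", Nd4j.valueArrayOf(new long[]{4, 2}, 5.0));       //initial values, all 5
SDVariable indices = sd.constant(Nd4j.createFromArray(0, 2));                   //rows of ref to modify
SDVariable updates = sd.var("updates", Nd4j.valueArrayOf(new long[]{2, 2}, 1.0)); //candidate minima, all 1

//rows 0 and 2 become min(5, 1) = 1; rows 1 and 3 keep their original value 5
SDVariable out = sd.scatterMin(ref, indices, updates);
System.out.println(out.eval());
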
 
Example 4
Source File: ScatterMul.java    From deeplearning4j with Apache License 2.0
@Override
public List<SDVariable> doDiff(List<SDVariable> gradOut){
    //3 args: ref, indices, updates
    //For non-modified indices, input gradient (reference) is same as output gradient
    //For modified indices, dL/dref = dL/dOut * dOut/dRef = dL/dOut * d(ref * update)/dRef = dL/dOut * update
    //And for updates, dL/du = dL/dOut * dOut/du = dL/dOut * d(ref * update)/du = dL/dOut * ref

    SDVariable ref = arg(0);
    SDVariable indices = arg(1);
    SDVariable updates = arg(2);

    List<SDVariable> ret = new ArrayList<>(3);
    SDVariable gradRef = sameDiff.scatterMul(gradOut.get(0), indices, updates);
    ret.add(gradRef);            //Reference array
    ret.add(sameDiff.zerosLike(arg(1)));  //Indices

    SDVariable gatherOutGrad = sameDiff.gather(gradOut.get(0), indices, 0);       //Updates
    SDVariable gatherRef = sameDiff.gather(ref, indices, 0);
    SDVariable updateGrad = gatherOutGrad.mul(gatherRef);
    ret.add(updateGrad);

    return ret;
}
 
Example 5
Source File: Cross.java    From nd4j with Apache License 2.0
@Override
public List<SDVariable> doDiff(List<SDVariable> gradients) {
    /**
     * dL / dx = dL / dCross * dCross / dx
     * dCross(a,b) / da = Cross(1, b)
     * dCross(a,b) / db = Cross(a, 1)
     *
     * return (grad * Cross(1, b), grad * Cross(a, 1))
     */
    SDVariable grad = gradients.get(0);
    SDVariable a = larg();
    SDVariable b = rarg();
    SDVariable ones = sameDiff.onesLike(a);

    SDVariable gradLeft = grad.mul(sameDiff.cross(ones, b));
    SDVariable gradRight = grad.mul(sameDiff.cross(a, ones));

    return Arrays.asList(gradLeft, gradRight);
}
 
Example 6
Source File: AttentionVertex.java    From deeplearning4j with Apache License 2.0
@Override
public SDVariable defineVertex(SameDiff sameDiff, Map<String, SDVariable> layerInput, Map<String, SDVariable> paramTable, Map<String, SDVariable> maskVars) {
    final SDVariable queries = layerInput.get("queries");
    final SDVariable keys = layerInput.get("keys");
    final SDVariable values = layerInput.get("values");
    final SDVariable mask = maskVars != null ? sameDiff.min(maskVars.get("keys"), maskVars.get("values")) : null;

    SDVariable attention;
    if(projectInput){
        val Wq = paramTable.get(WEIGHT_KEY_QUERY_PROJECTION);
        val Wk = paramTable.get(WEIGHT_KEY_KEY_PROJECTION);
        val Wv = paramTable.get(WEIGHT_KEY_VALUE_PROJECTION);
        val Wo = paramTable.get(WEIGHT_KEY_OUT_PROJECTION);

        attention = sameDiff.nn.multiHeadDotProductAttention(getLayerName(), queries, keys, values, Wq, Wk, Wv, Wo, mask, true);
    }else{
        attention = sameDiff.nn.dotProductAttention(getLayerName(), queries, keys, values, mask, true);
    }

    if(maskVars != null){
        return attention.mul(sameDiff.expandDims(maskVars.get("queries"), 1));
    }else{
        return attention;
    }
}
 
Example 7
Source File: NormMax.java    From nd4j with Apache License 2.0
@Override
public List<SDVariable> doDiff(List<SDVariable> i_v1) {
    //maxnorm(in) = max_i |x_i|
    //d maxnorm(in)/dx = 0 if x_i is not the max, or d|x|/dx otherwise

    SDVariable absIn = sameDiff.abs(arg());
    SDVariable maxnorm = outputVariables()[0];
    int origRank = Shape.rankFromShape(arg().getShape());   //TODO shape may not always be defined?
    SDVariable maxnormBc = f().reductionBroadcastableWithOrigShape(origRank, dimensions, maxnorm);
    maxnormBc = sameDiff.onesLike(arg()).mul(maxnormBc);
    SDVariable eq = sameDiff.eq(absIn, maxnormBc);
    SDVariable dAbsXdX = sameDiff.sign(arg());
    SDVariable dNormmaxDx = eq.mul(dAbsXdX);
    SDVariable broadcastableGrad = f().reductionBroadcastableWithOrigShape(origRank, dimensions, i_v1.get(0));
    SDVariable ret = dNormmaxDx.mul(broadcastableGrad);
    return Arrays.asList(ret);
}
 
Example 8
Source File: Tan.java    From nd4j with Apache License 2.0
@Override
public List<SDVariable> doDiff(List<SDVariable> i_v) {
    //d(tan(x))/dx = (sec(x))^2 = 1 / (cos(x))^2

    SDVariable oneDivCos2 = sameDiff.square(sameDiff.cos(arg())).rdiv(1.0);
    SDVariable ret = oneDivCos2.mul(i_v.get(0));
    return Arrays.asList(ret);
}
 
Example 9
Source File: EuclideanDistance.java    From deeplearning4j with Apache License 2.0
@Override
public List<SDVariable> doDiff(List<SDVariable> i_v1) {
    //ddist(x,y)/dxi = (xi-yi)/dist(x,y)
    SDVariable euc = outputVariables()[0];
    SDVariable difference = larg().sub(rarg());
    SDVariable divBroadcastable = i_v1.get(0).div(euc);
    if(!keepDims && !(dimensions == null || dimensions.length == 0 || (dimensions.length == 1 && dimensions[0] == Integer.MAX_VALUE))){
        //Not keep dims, and not full array reduction -> need to make broadcastable
        divBroadcastable = SameDiffUtils.reductionBroadcastableWithOrigShape(arg(), sameDiff.constant(Nd4j.createFromArray(dimensions)), divBroadcastable);
    }

    SDVariable gradX = difference.mul(divBroadcastable);
    SDVariable gradY = sameDiff.math.neg(gradX);
    return Arrays.asList(gradX, gradY);
}
 
Example 10
Source File: Variance.java    From nd4j with Apache License 2.0
@Override
public List<SDVariable> doDiff(List<SDVariable> i_v1) {
    //If out = var(in) then:
    //dL/dIn = dL/dOut * dOut/dIn
    // with dOut/dIn = (in-mean) * 2/(n-1), or (in-mean) * 2/n when not bias-corrected
    val n = f().getReductionLength(this);
    int origRank = Shape.rankFromShape(arg().getShape());
    SDVariable broadcastableMean = f().reductionBroadcastableWithOrigShape(origRank, dimensions, f().mean(arg(), dimensions));
    SDVariable broadcastableGrad = f().reductionBroadcastableWithOrigShape(origRank, dimensions, i_v1.get(0));
    SDVariable dOutdIn = arg().sub(broadcastableMean).mul(2.0 / (biasCorrected ? (n - 1) : n));

    SDVariable dLdIn = dOutdIn.mul(broadcastableGrad);
    return Arrays.asList(dLdIn);
}
 
Example 11
Source File: SameDiffSimpleLambdaVertex.java    From deeplearning4j with Apache License 2.0
@Override
public SDVariable defineVertex(SameDiff sameDiff, VertexInputs inputs) {
    SDVariable in1 = inputs.getInput(0);
    SDVariable in2 = inputs.getInput(1);
    SDVariable ret = in1.mul(in2);
    return ret;
}
 
Example 12
Source File: Max.java    From nd4j with Apache License 2.0
@Override
public List<SDVariable> doDiff(List<SDVariable> i_v1) {
    //TODO do we need to handle the "multiple equal maximums" case?
    //TODO code duplication (min/max)

    SDVariable out = outputVariables()[0];
    int origRank = Shape.rankFromShape(arg().getShape());
    SDVariable expandedOut = sameDiff.f().reductionBroadcastableWithOrigShape(origRank, dimensions, out);
    expandedOut = sameDiff.onesLike(arg()).mul(expandedOut);
    SDVariable expandedGrad = sameDiff.f().reductionBroadcastableWithOrigShape(origRank, dimensions, i_v1.get(0));

    SDVariable eq = sameDiff.eq(arg(), expandedOut);
    SDVariable ret = eq.mul(expandedGrad);
    return Arrays.asList(ret);
}
 
Example 13
Source File: TransformOpValidation.java    From deeplearning4j with Apache License 2.0
@Test
public void testScalarOps() {
    int d0 = 2;
    int d1 = 3;
    int d2 = 4;

    int n = d0 * d1 * d2;

    List<String> failed = new ArrayList<>();

    for (int i = 0; i < 11; i++) {
        for (char inOrder : new char[]{'c', 'f'}) {
            SameDiff sd = SameDiff.create();

            INDArray inArr = Nd4j.linspace(1, n, n, DataType.DOUBLE).reshape('c', d0, d1, d2).dup(inOrder);
            SDVariable in = sd.var("in", inArr);
            TestCase tc = new TestCase(sd).gradientCheck(true);

            SDVariable out;
            String msg;
            switch (i) {
                case 0:
                    out = in.mul(2);
                    tc.expectedOutput(out.name(), inArr.mul(2));
                    msg = "mul - " + inOrder;
                    break;
                case 1:
                    out = in.div(2);
                    tc.expectedOutput(out.name(), inArr.div(2));
                    msg = "div - " + inOrder;
                    break;
                case 2:
                    out = in.add(2);
                    tc.expectedOutput(out.name(), inArr.add(2));
                    msg = "add - " + inOrder;
                    break;
                case 3:
                    out = in.sub(2);
                    tc.expectedOutput(out.name(), inArr.sub(2));
                    msg = "sub - " + inOrder;
                    break;
                case 4:
                    out = in.rdiv(2);
                    tc.expectedOutput(out.name(), inArr.rdiv(2));
                    msg = "rdiv - " + inOrder;
                    break;
                case 5:
                    out = in.rsub(2);
                    tc.expectedOutput(out.name(), inArr.rsub(2));
                    msg = "rsub - " + inOrder;
                    break;
                case 6:
                    out = sd.math().pow(in, 2);
                    tc.expectedOutput(out.name(), Transforms.pow(inArr, 2));
                    msg = "pow - " + inOrder;
                    break;
                case 7:
                    inArr.assign(Nd4j.rand(inArr.dataType(), inArr.shape()).muli(5).subi(2.5));
                    out = sd.math().floorMod(in, 2.0);
                    tc.expected(out, Nd4j.getExecutioner().exec(new ScalarFMod(inArr.dup(), 2.0)));
                    msg = "scalarFloorMod - " + inOrder;
                    break;
                case 8:
                    inArr.assign(Nd4j.rand(inArr.shape()));
                    out = sd.scalarMax(in, 0.5);
                    tc.expected(out, Transforms.max(inArr.dup(), 0.5));
                    msg = "scalarMax - " + inOrder;
                    break;
                case 9:
                    inArr.assign(Nd4j.rand(inArr.shape()));
                    out = sd.scalarMin(in, 0.5);
                    tc.expected(out, Transforms.min(inArr.dup(), 0.5));
                    msg = "scalarMin - " + inOrder;
                    break;
                case 10:
                    out = in.assign(0.5);
                    tc.expected(out, Nd4j.valueArrayOf(inArr.shape(), 0.5));
                    msg = "scalarSet - " + inOrder;
                    break;
                default:
                    throw new RuntimeException();
            }

            tc.testName(msg);

            SDVariable loss = sd.standardDeviation(out, true);

            log.info("Starting test: " + msg);
            String err = OpValidation.validate(tc, true);
            if (err != null) {
                failed.add(err);
            }
        }
    }
    assertEquals(failed.toString(), 0, failed.size());
}
 
Example 14
Source File: MiscOpValidation.java    From deeplearning4j with Apache License 2.0
@Test
public void testGradientAutoBroadcast2() {
    Nd4j.getRandom().setSeed(12345);

    List<String> failed = new ArrayList<>();

    for (int[] dim_sz1s : new int[][]{{0, 1}, {0, 2}, {1, 2}, {0, 1, 2}}) {

        long[] otherShape = {3, 4, 5};
        otherShape[dim_sz1s[0]] = 1;
        otherShape[dim_sz1s[1]] = 1;
        if (dim_sz1s.length == 3) {
            otherShape[dim_sz1s[2]] = 1;
        }

        for (int i = 0; i < 8; i++) {

            SameDiff sd = SameDiff.create();

            SDVariable in3 = sd.var("in3", DataType.DOUBLE, 3, 4, 5);
            SDVariable in2 = sd.var("inToBc", DataType.DOUBLE, otherShape);

            String name;
            SDVariable bcOp;
            switch (i) {
                case 0:
                    bcOp = in3.add(in2);
                    name = "add";
                    break;
                case 1:
                    bcOp = in3.sub(in2);
                    name = "sub";
                    break;
                case 2:
                    bcOp = in3.mul(in2);
                    name = "mul";
                    break;
                case 3:
                    bcOp = in3.div(in2);
                    name = "div";
                    break;
                case 4:
                    bcOp = in3.rsub(in2);
                    name = "rsub";
                    break;
                case 5:
                    bcOp = in3.rdiv(in2);
                    name = "rdiv";
                    break;
                case 6:
                    //bcOp = sd.scalarFloorDiv(in3, in2);
                    bcOp = new FloorDivOp(sd, in3, in2).outputVariable();
                    name = "floordiv";
                    break;
                case 7:
                    //bcOp = sd.scalarFloorMod(in3, in2);
                    bcOp = new FloorModOp(sd, in3, in2).outputVariable();
                    name = "floormod";
                    if(OpValidationSuite.IGNORE_FAILING){
                        //https://github.com/deeplearning4j/deeplearning4j/issues/5976
                        continue;
                    }
                    break;
                default:
                    throw new RuntimeException();
            }

            SDVariable outVar = sd.sum(bcOp);

            String msg = "(test " + i + ": " + name + ", dimensions=" + Arrays.toString(dim_sz1s) + ")";
            log.info("*** Starting test: " + msg);

            INDArray in3Arr = Nd4j.randn(DataType.DOUBLE, 3, 4, 5).muli(100);
            INDArray in2Arr = Nd4j.randn(DataType.DOUBLE, otherShape).muli(100);

            sd.associateArrayWithVariable(in3Arr, in3);
            sd.associateArrayWithVariable(in2Arr, in2);

            TestCase tc = new TestCase(sd);
            String error = OpValidation.validate(tc);
            if(error != null){
                failed.add(name);
            }
        }
    }

    assertEquals("Failed: " + failed, 0, failed.size());
}
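
The auto-broadcasting that this test exercises can also be seen directly. A minimal sketch (the shapes are illustrative):

SameDiff sd = SameDiff.create();
SDVariable a = sd.var("a", Nd4j.rand(DataType.DOUBLE, 3, 1));
SDVariable b = sd.var("b", Nd4j.rand(DataType.DOUBLE, 1, 4));

//size-1 dimensions are broadcast: [3,1] * [1,4] -> [3,4]
SDVariable product = a.mul(b);
System.out.println(Arrays.toString(product.eval().shape()));   //[3, 4]
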
 
Example 15
Source File: ReductionOpValidation.java    From deeplearning4j with Apache License 2.0
@Test
    public void testReductionsBackwards() {
//        for (int i = 0; i < 7; i++) {
        int i = 5;
        {

            SameDiff sd = SameDiff.create();

            int nOut = 4;
            int minibatch = 3;
            SDVariable input = sd.var("in", DataType.DOUBLE, new long[]{minibatch, nOut});
            SDVariable label = sd.var("label", DataType.DOUBLE, new long[]{minibatch, nOut});

            SDVariable diff = input.sub(label);
            SDVariable sqDiff = diff.mul(diff);
            SDVariable msePerEx = sd.mean("msePerEx", sqDiff, 1);

            SDVariable loss;    //Scalar value
            String name;
            switch (i) {
                case 0:
                    loss = sd.mean("loss", msePerEx, 0);
                    name = "mean";
                    break;
                case 1:
                    loss = sd.sum("loss", msePerEx, 0);
                    name = "sum";
                    break;
                case 2:
                    loss = sd.standardDeviation("loss", msePerEx, true, 0);
                    name = "stdev";
                    break;
                case 3:
                    loss = sd.min("loss", msePerEx, 0);
                    name = "min";
                    break;
                case 4:
                    loss = sd.max("loss", msePerEx, 0);
                    name = "max";
                    break;
                case 5:
                    loss = sd.variance("loss", msePerEx, true, 0);
                    name = "variance";
                    break;
                case 6:
                    loss = sd.prod("loss", msePerEx, 0);
                    name = "prod";
                    break;
                default:
                    throw new RuntimeException();
            }


            String msg = "test: " + i + " - " + name;
            log.info("*** Starting test: " + msg);

            INDArray inputArr = Nd4j.rand(DataType.DOUBLE, minibatch, nOut);
            INDArray labelArr = Nd4j.rand(DataType.DOUBLE, minibatch, nOut);

            sd.associateArrayWithVariable(inputArr, input);
            sd.associateArrayWithVariable(labelArr, label);

            INDArray result = loss.eval();
            assertEquals(1, result.length());

            sd.calculateGradients(Collections.emptyMap(), sd.getVariables().keySet());
        }
    }
 
Example 16
Source File: KerasLambdaTest.java    From deeplearning4j with Apache License 2.0
@Override
public SDVariable defineLayer(SameDiff sd, SDVariable x) { return x.mul(3); }
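
A lambda layer like this one is registered under the corresponding Keras layer's name before the model is imported. A sketch, assuming the override above sits in a SameDiffLambdaLayer subclass hypothetically named TimesThreeLambda; the layer name "lambda_1" and file "model.h5" are placeholders, not taken from the original test:

//Pair the Keras lambda layer's name with its SameDiff implementation, then import.
//Names below are illustrative placeholders.
KerasLayer.registerLambdaLayer("lambda_1", new TimesThreeLambda());
ComputationGraph model = KerasModelImport.importKerasModelAndWeights("model.h5");
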
 
Example 17
Source File: GradCheckMisc.java    From nd4j with Apache License 2.0
@Test
    public void testGradientAutoBroadcast1() {

        Nd4j.getRandom().setSeed(12345);

        List<String> allFailed = new ArrayList<>();

        for (int dim_sz1 : new int[]{0, 1, 2}) {

            int[] in2Shape = {3, 4, 5};
            in2Shape[dim_sz1] = 1;

            for (int i = 2; i < 3; i++) {

                SameDiff sd = SameDiff.create();

                SDVariable in3 = sd.var("in3", Nd4j.rand(new int[]{3, 4, 5}));
                SDVariable in2 = sd.var("in2", in2Shape);

                SDVariable bcOp;
                String name;
                switch (i) {
                    case 0:
                        bcOp = in3.add(in2);
                        name = "add";
                        break;
                    case 1:
                        bcOp = in3.sub(in2);
                        name = "sub";
                        break;
                    case 2:
                        bcOp = in3.mul(in2);
                        name = "mul";
                        break;
                    case 3:
                        bcOp = in3.div(in2);
                        name = "div";
                        break;
                    case 4:
                        bcOp = in3.rsub(in2);
                        name = "rsub";
                        break;
                    case 5:
                        bcOp = in3.rdiv(in2);
                        name = "rdiv";
                        break;
                    case 6:
                        bcOp = sd.f().floorDiv(in3, in2);
                        name = "floordiv";
                        break;
                    case 7:
                        bcOp = sd.f().floorMod(in3, in2);
                        name = "floormod";
                        break;
                    default:
                        throw new RuntimeException();
                }

                SDVariable outVar = sd.sum(bcOp);

                String msg = "(test " + i + ": " + name + ", dimension=" + dim_sz1 + ")";
                log.info("*** Starting test: " + msg);

                INDArray in3Arr = Nd4j.randn(new int[]{3, 4, 5}).muli(100);
                INDArray in2Arr = Nd4j.randn(in2Shape).muli(100);

                sd.associateArrayWithVariable(in3Arr, in3);
                sd.associateArrayWithVariable(in2Arr, in2);

                try {
                    INDArray out = sd.execAndEndResult();
                    assertNotNull(out);
                    assertArrayEquals(new long[]{1, 1}, out.shape());

//                    System.out.println(sd.asFlatPrint());

                    boolean ok = GradCheckUtil.checkGradients(sd);
                    if (!ok) {
                        allFailed.add(msg);
                    }
                } catch (Exception e) {
                    e.printStackTrace();
                    allFailed.add(msg + " - EXCEPTION");
                }
            }
        }

        assertEquals("Failed: " + allFailed, 0, allFailed.size());
    }
 
Example 18
Source File: MiscOpValidation.java    From deeplearning4j with Apache License 2.0
@Test
public void testGradientAutoBroadcast3() {
    //These tests: output size > input sizes

    Nd4j.getRandom().setSeed(12345);

    List<String> failed = new ArrayList<>();

    //Test cases: in1Shape, in2Shape, shapeOf(op(in1,in2))
    List<Triple<long[], long[], long[]>> testCases = new ArrayList<>();
    testCases.add(new Triple<>(new long[]{3, 1}, new long[]{1, 4}, new long[]{3, 4}));
    testCases.add(new Triple<>(new long[]{3, 1}, new long[]{3, 4}, new long[]{3, 4}));
    testCases.add(new Triple<>(new long[]{3, 4}, new long[]{1, 4}, new long[]{3, 4}));
    testCases.add(new Triple<>(new long[]{3, 4, 1}, new long[]{1, 1, 5}, new long[]{3, 4, 5}));
    testCases.add(new Triple<>(new long[]{3, 4, 1}, new long[]{3, 1, 5}, new long[]{3, 4, 5}));
    testCases.add(new Triple<>(new long[]{3, 1, 5}, new long[]{1, 4, 1}, new long[]{3, 4, 5}));
    testCases.add(new Triple<>(new long[]{3, 1, 5}, new long[]{1, 4, 5}, new long[]{3, 4, 5}));
    testCases.add(new Triple<>(new long[]{3, 1, 5}, new long[]{3, 4, 5}, new long[]{3, 4, 5}));
    testCases.add(new Triple<>(new long[]{3, 1, 1, 1}, new long[]{1, 4, 5, 6}, new long[]{3, 4, 5, 6}));
    testCases.add(new Triple<>(new long[]{1, 1, 1, 6}, new long[]{3, 4, 5, 6}, new long[]{3, 4, 5, 6}));
    testCases.add(new Triple<>(new long[]{1, 4, 5, 1}, new long[]{3, 1, 1, 6}, new long[]{3, 4, 5, 6}));
    if(!OpValidationSuite.IGNORE_FAILING) {
        testCases.add(new Triple<>(new long[]{1, 6}, new long[]{3, 4, 5, 1}, new long[]{3, 4, 5, 6}));
    }

    for (val p : testCases) {

        for (int i = 0; i < 8; i++) {

            SameDiff sd = SameDiff.create();

            SDVariable in3 = sd.var("in1", DataType.DOUBLE, p.getFirst());
            SDVariable in2 = sd.var("in2", DataType.DOUBLE, p.getSecond());

            String name;
            SDVariable bcOp;
            switch (i) {
                case 0:
                    bcOp = in3.add(in2);
                    name = "add";
                    break;
                case 1:
                    bcOp = in3.sub(in2);
                    name = "sub";
                    break;
                case 2:
                    bcOp = in3.mul(in2);
                    name = "mul";
                    break;
                case 3:
                    bcOp = in3.div(in2);
                    name = "div";
                    break;
                case 4:
                    bcOp = in3.rsub(in2);
                    name = "rsub";
                    break;
                case 5:
                    bcOp = in3.rdiv(in2);
                    name = "rdiv";
                    break;
                case 6:
                    //bcOp = sd.scalarFloorDiv(in3, in2);
                    bcOp = new FloorDivOp(sd, in3, in2).outputVariable();
                    name = "floordiv";
                    break;
                case 7:
                    //bcOp = sd.scalarFloorMod(in3, in2);
                    bcOp = new FloorModOp(sd, in3, in2).outputVariable();
                    name = "floormod";
                    if(OpValidationSuite.IGNORE_FAILING){
                        //https://github.com/deeplearning4j/deeplearning4j/issues/5976
                        continue;
                    }
                    break;
                default:
                    throw new RuntimeException();
            }

            SDVariable outVar = sd.sum(bcOp);

            String msg = "(test " + i + ": " + name + ", array 1 size =" + Arrays.toString(p.getFirst())
                    + ", array 2 size = " + Arrays.toString(p.getSecond()) + ")";
            log.info("*** Starting test: " + msg);

            INDArray in3Arr = Nd4j.rand(DataType.DOUBLE, p.getFirst()).muli(100);
            INDArray in2Arr = Nd4j.rand(DataType.DOUBLE, p.getSecond()).muli(100);

            sd.associateArrayWithVariable(in3Arr, in3);
            sd.associateArrayWithVariable(in2Arr, in2);

            TestCase tc = new TestCase(sd);
            String error = OpValidation.validate(tc);
            if(error != null){
                failed.add(name + " " + i +  " - " + error);
            }
        }
    }

    assertEquals("Failed: " + failed, 0, failed.size());
}
 
Example 19
Source File: GradCheckReductions.java    From nd4j with Apache License 2.0
@Test
    public void testReductionGradients1() {
        //Test reductions: final, but *not* the only function
        Nd4j.getRandom().setSeed(12345);

        List<String> allFailed = new ArrayList<>();

        for (int dim : new int[]{0, Integer.MAX_VALUE}) {    //These two cases are equivalent here

            for (int i = 0; i < 10; i++) {

                SameDiff sd = SameDiff.create();

                int nOut = 4;
                int minibatch = 10;
                SDVariable input = sd.var("in", new int[]{-1, nOut});
                SDVariable label = sd.var("label", new int[]{-1, nOut});

                SDVariable diff = input.sub(label);
                SDVariable sqDiff = diff.mul(diff);
                SDVariable msePerEx = sd.mean("msePerEx", sqDiff, 1);

                SDVariable loss;
                String name;
                switch (i) {
                    case 0:
                        loss = sd.mean("loss", msePerEx, dim);
                        name = "mean";
                        break;
                    case 1:
                        loss = sd.sum("loss", msePerEx, dim);
                        name = "sum";
                        break;
                    case 2:
                        loss = sd.standardDeviation("loss", msePerEx, true, dim);
                        name = "stdev";
                        break;
                    case 3:
                        loss = sd.min("loss", msePerEx, dim);
                        name = "min";
                        break;
                    case 4:
                        loss = sd.max("loss", msePerEx, dim);
                        name = "max";
                        break;
                    case 5:
                        loss = sd.variance("loss", msePerEx, true, dim);
                        name = "variance";
                        break;
                    case 6:
                        loss = sd.prod("loss", msePerEx, dim);
                        name = "prod";
                        break;
                    case 7:
                        loss = sd.norm1("loss", msePerEx, dim);
                        name = "norm1";
                        break;
                    case 8:
                        loss = sd.norm2("loss", msePerEx, dim);
                        name = "norm2";
                        break;
                    case 9:
                        loss = sd.normmax("loss", msePerEx, dim);
                        name = "normmax";
                        break;
                    default:
                        throw new RuntimeException();
                }


                String msg = "(test " + i + " - " + name + ", dimension=" + dim + ")";
                log.info("*** Starting test: " + msg);

                INDArray inputArr = Nd4j.randn(minibatch, nOut).muli(100);
                INDArray labelArr = Nd4j.randn(minibatch, nOut).muli(100);

                sd.associateArrayWithVariable(inputArr, input);
                sd.associateArrayWithVariable(labelArr, label);

                try {
                    INDArray out = sd.execAndEndResult();
                    assertNotNull(out);
                    assertArrayEquals(new int[]{1, 1}, out.shape());

//                    System.out.println(sd.asFlatPrint());

                    boolean ok = GradCheckUtil.checkGradients(sd);
                    if (!ok) {
                        allFailed.add(msg);
                    }
                } catch (Exception e) {
                    e.printStackTrace();
                    allFailed.add(msg + " - EXCEPTION");
                }
            }
        }

        assertEquals("Failed: " + allFailed, 0, allFailed.size());
    }
 
Example 20
Source File: GradCheckMisc.java    From nd4j with Apache License 2.0 4 votes vote down vote up
@Test
    public void testGradientAutoBroadcast2() {

        Nd4j.getRandom().setSeed(12345);

        List<String> allFailed = new ArrayList<>();

        for (int[] dim_sz1s : new int[][]{{0, 1}, {0, 2}, {1, 2}, {0,1,2}}) {

            int[] otherShape = {3, 4, 5};
            otherShape[dim_sz1s[0]] = 1;
            otherShape[dim_sz1s[1]] = 1;
            if(dim_sz1s.length == 3){
                otherShape[dim_sz1s[2]] = 1;
            }

            for (int i = 0; i < 6; i++) {

                SameDiff sd = SameDiff.create();

                SDVariable in3 = sd.var("in3", new int[]{3, 4, 5});
                SDVariable in2 = sd.var("inToBc", otherShape);

                String name;
                SDVariable bcOp;
                switch (i) {
                    case 0:
                        bcOp = in3.add(in2);
                        name = "add";
                        break;
                    case 1:
                        bcOp = in3.sub(in2);
                        name = "sub";
                        break;
                    case 2:
                        bcOp = in3.mul(in2);
                        name = "mul";
                        break;
                    case 3:
                        bcOp = in3.div(in2);
                        name = "div";
                        break;
                    case 4:
                        bcOp = in3.rsub(in2);
                        name = "rsub";
                        break;
                    case 5:
                        bcOp = in3.rdiv(in2);
                        name = "rdiv";
                        break;
                    case 6:
                        bcOp = sd.f().floorDiv(in3, in2);
                        name = "floordiv";
                        break;
                    case 7:
                        bcOp = sd.f().floorMod(in3, in2);
                        name = "floormod";
                        break;
                    default:
                        throw new RuntimeException();
                }

                SDVariable outVar = sd.sum(bcOp);

                String msg = "(test " + i + ": " + name + ", dimensions=" + Arrays.toString(dim_sz1s) + ")";
                log.info("*** Starting test: " + msg);

                INDArray in3Arr = Nd4j.randn(new int[]{3,4,5}).muli(100);
                INDArray in2Arr = Nd4j.randn(otherShape).muli(100);

                sd.associateArrayWithVariable(in3Arr, in3);
                sd.associateArrayWithVariable(in2Arr, in2);

                try {
                    INDArray out = sd.execAndEndResult();
                    assertNotNull(out);
                    assertArrayEquals(new long[]{1, 1}, out.shape());

//                    System.out.println(sd.asFlatPrint());

                    boolean ok = GradCheckUtil.checkGradients(sd);
                    if (!ok) {
                        allFailed.add(msg);
                    }
                } catch (Exception e) {
                    e.printStackTrace();
                    allFailed.add(msg + " - EXCEPTION");
                }
            }
        }

        assertEquals("Failed: " + allFailed, 0, allFailed.size());
    }