Java Code Examples for org.nd4j.linalg.api.shape.Shape#rankFromShape()

The following examples show how to use org.nd4j.linalg.api.shape.Shape#rankFromShape(). Each example is taken from the nd4j project; the originating source file, its licence, and its vote count are listed above each example.
Example 1
Source file: ClipByNorm.java — from nd4j (Apache License 2.0), 6 votes
@Override
public List<SDVariable> doDiff(List<SDVariable> grad) {
    // Forward pass (norms taken along `dimensions`):
    //   out = x                          if norm2(x) <  clipValue
    //   out = clipValue * x / norm2(x)   if norm2(x) >= clipValue
    //
    // Clipped case, full Jacobian with n = norm2(x):
    //   dOut_i/dx_j = clipValue * (delta_ij / n - x_i * x_j / n^3)
    // Contracting with the upstream gradient g gives
    //   dL/dx_j = clipValue * (g_j / n - x_j * sum_i(g_i * x_i) / n^3)
    // where the sum runs over the same reduction group as the norm. Keeping only the diagonal
    // (i == j) term, i.e. -x_j^2/n^3, is incorrect whenever a group has more than one element.
    SDVariable gradOut = grad.get(0);
    int origRank = Shape.rankFromShape(arg().getShape());
    SDVariable l2norm = f().norm2(arg(), dimensions);
    SDVariable broadcastableNorm = f().reductionBroadcastableWithOrigShape(origRank, dimensions, l2norm);
    SDVariable isClippedBC = f().gte(broadcastableNorm, clipValue);
    SDVariable notClippedBC = isClippedBC.rsub(1.0);

    // sum_i(g_i * x_i) per reduction group, expanded back to the input's rank so it broadcasts against x
    SDVariable gradDotX = f().sum(arg().mul(gradOut), dimensions);
    SDVariable broadcastableGradDotX = f().reductionBroadcastableWithOrigShape(origRank, dimensions, gradDotX);

    SDVariable clippedGrad = gradOut.div(broadcastableNorm)
            .sub(arg().mul(broadcastableGradDotX).div(f().cube(broadcastableNorm)))
            .mul(clipValue);

    // NOTE(review): if norm2(x) == 0 the clipped expression divides by zero, and masking it with
    // isClippedBC == 0 does not remove a NaN (0 * NaN = NaN). This matches the previous implementation.
    SDVariable ret = notClippedBC.mul(gradOut).add(isClippedBC.mul(clippedGrad));
    return Arrays.asList(ret);
}
 
Example 2
Source file: EuclideanDistance.java — from nd4j (Apache License 2.0), 6 votes
@Override
public List<SDVariable> doDiff(List<SDVariable> i_v1) {
    //ddist(x,y)/dxi = (xi-yi)/dist(x,y)
    SDVariable euc = outputVariables()[0];
    SDVariable difference = larg().sub(rarg());
    SDVariable divBroadcastable = i_v1.get(0).div(euc);

    // A whole-array reduction (dimensions null, empty, or {Integer.MAX_VALUE}) yields a 1x1 output that
    // already broadcasts against the inputs. An along-dimension reduction removes the reduced axes, so the
    // gradient must be expanded back to the input's rank first. The earlier condition had these two cases
    // swapped: it expanded only in the whole-array case.
    boolean wholeArrayReduction = dimensions == null || dimensions.length == 0
            || (dimensions.length == 1 && dimensions[0] == Integer.MAX_VALUE);
    if (!wholeArrayReduction) {
        int origRank = Shape.rankFromShape(arg().getShape());   //TODO shape may not always be defined?
        divBroadcastable = f().reductionBroadcastableWithOrigShape(origRank, dimensions, divBroadcastable);
    }

    SDVariable gradX = difference.mul(divBroadcastable);
    SDVariable gradY = f().neg(gradX);
    return Arrays.asList(gradX, gradY);
}
 
Example 3
Source file: JaccardDistance.java — from nd4j (Apache License 2.0), 6 votes
@Override
public List<SDVariable> doDiff(List<SDVariable> f1) {
    //Jaccard distance: https://en.wikipedia.org/wiki/Jaccard_index#Generalized_Jaccard_similarity_and_distance
    //J(x,y) = 1 - sum_i min(x_i, y_i) / sum_i max(x_i, y_i)
    //dJ/dx_i = (xIsMax_i * Jsim - xIsMin_i) / sum_i max(x_i, y_i), and symmetrically for y

    int rank = Shape.rankFromShape(larg().getShape());

    SDVariable jSim = outputVariables()[0].rsub(1.0);   //jaccard similarity = 1 - jaccard distance
    SDVariable min = f().min(larg(), rarg());
    SDVariable max = f().max(larg(), rarg());
    SDVariable sumMax = f().sum(max, dimensions);
    SDVariable broadcastableSumMax = f().reductionBroadcastableWithOrigShape(rank, dimensions, sumMax);
    SDVariable broadcastableJSim = f().reductionBroadcastableWithOrigShape(rank, dimensions, jSim);

    // NOTE(review): where x_i == y_i both the "is min" and "is max" masks are 1, which selects one
    // particular subgradient at that non-differentiable point.
    SDVariable xIsMin = f().eq(min, larg());
    SDVariable xIsMax = f().eq(max, larg());
    SDVariable yIsMin = f().eq(min, rarg());
    SDVariable yIsMax = f().eq(max, rarg());

    SDVariable dldx = xIsMax.mul(broadcastableJSim).sub(xIsMin).div(broadcastableSumMax);
    SDVariable dldy = yIsMax.mul(broadcastableJSim).sub(yIsMin).div(broadcastableSumMax);

    // The upstream gradient has the reduced shape, like sumMax and jSim above, so expand it to the
    // input rank in the same way before multiplying.
    SDVariable broadcastableGrad = f().reductionBroadcastableWithOrigShape(rank, dimensions, f1.get(0));
    return Arrays.asList(dldx.mul(broadcastableGrad), dldy.mul(broadcastableGrad));
}
 
Example 4
Source file: ManhattanDistance.java — from nd4j (Apache License 2.0), 6 votes
@Override
public List<SDVariable> doDiff(List<SDVariable> i_v1) {
    //ddist(x,y)/dxi = sign(xi-yi)
    SDVariable difference = larg().sub(rarg());
    SDVariable gradBroadcastable = i_v1.get(0);

    // A whole-array reduction (dimensions null, empty, or {Integer.MAX_VALUE}) yields a 1x1 output that
    // already broadcasts against the inputs. An along-dimension reduction removes the reduced axes, so the
    // gradient must be expanded back to the input's rank first. The earlier condition had these two cases
    // swapped: it expanded only in the whole-array case.
    boolean wholeArrayReduction = dimensions == null || dimensions.length == 0
            || (dimensions.length == 1 && dimensions[0] == Integer.MAX_VALUE);
    if (!wholeArrayReduction) {
        int origRank = Shape.rankFromShape(arg().getShape());   //TODO shape may not always be defined?
        gradBroadcastable = f().reductionBroadcastableWithOrigShape(origRank, dimensions, gradBroadcastable);
    }

    SDVariable gradX = sameDiff.sign(difference).mul(gradBroadcastable);
    SDVariable gradY = f().neg(gradX);
    return Arrays.asList(gradX, gradY);
}
 
Example 5
Source file: CosineSimilarity.java — from nd4j (Apache License 2.0), 6 votes
public static List<SDVariable> doDiff(SameDiff sameDiff, DifferentialFunctionFactory f, SDVariable x, SDVariable y,
                                      SDVariable gradOut, int... dimensions){
    // Cosine similarity c = dot / (|x| * |y|), where dot = sum(x*y) and both norms are reduced along
    // `dimensions`. Its partial derivatives are
    //   dc/dx = (y - x * dot / |x|^2) / (|x| * |y|)
    //   dc/dy = (x - y * dot / |y|^2) / (|x| * |y|)
    // and each is multiplied by the upstream gradient, broadcast back to the input rank.
    SDVariable dot = sameDiff.sum(x.mul(y), dimensions);
    SDVariable normX = f.norm2(x, dimensions);
    SDVariable normY = f.norm2(y, dimensions);
    SDVariable normProduct = normX.mul(normY);

    int rank = Shape.rankFromShape(x.getShape());
    SDVariable dotBc = f.reductionBroadcastableWithOrigShape(rank, dimensions, dot);
    SDVariable normProductBc = f.reductionBroadcastableWithOrigShape(rank, dimensions, normProduct);
    SDVariable normXSqBc = f.reductionBroadcastableWithOrigShape(rank, dimensions, sameDiff.square(normX));
    SDVariable normYSqBc = f.reductionBroadcastableWithOrigShape(rank, dimensions, sameDiff.square(normY));
    SDVariable gradBc = f.reductionBroadcastableWithOrigShape(rank, dimensions, gradOut);

    SDVariable projX = x.mul(dotBc).div(normXSqBc);
    SDVariable dCdX = y.sub(projX).div(normProductBc);
    SDVariable projY = y.mul(dotBc).div(normYSqBc);
    SDVariable dCdY = x.sub(projY).div(normProductBc);

    return Arrays.asList(dCdX.mul(gradBc), dCdY.mul(gradBc));
}
 
Example 6
Source file: Min.java — from nd4j (Apache License 2.0), 6 votes
@Override
public List<SDVariable> doDiff(List<SDVariable> i_v1) {
    //TODO do we need to handle the "multiple equal minimums" case?
    //TODO code duplication (min/max)
    // The gradient of a min-reduction flows only to input positions equal to the reduced minimum.
    // Tied positions each receive the full upstream gradient (eq mask times gradient).

    SDVariable out = outputVariables()[0];

    int origRank = Shape.rankFromShape(arg().getShape());
    SDVariable expandedOut = sameDiff.f().reductionBroadcastableWithOrigShape(origRank, dimensions, out);
    // Let SameDiff generate the variable names, as Max does. The fixed names "temp0"/"tempmul" would
    // collide when more than one Min op is differentiated in the same graph.
    expandedOut = sameDiff.onesLike(arg()).mul(expandedOut);
    SDVariable expandedGrad = sameDiff.f().reductionBroadcastableWithOrigShape(origRank, dimensions, i_v1.get(0));

    SDVariable eq = sameDiff.eq(arg(), expandedOut);
    SDVariable ret = eq.mul(expandedGrad);
    return Arrays.asList(ret);
}
 
Example 7
Source file: StandardDeviation.java — from nd4j (Apache License 2.0), 6 votes
@Override
public List<SDVariable> doDiff(List<SDVariable> i_v1) {
    // Backprop for out = stdev(in) along `dimensions`:
    //   dL/dIn_i   = dL/dOut * dOut/dIn_i
    //   dOut/dIn_i = (in_i - mean) / (stdev * N)
    // where N = n - 1 when biasCorrected and N = n otherwise. n is the per-output reduction length
    // reported by getReductionLength.
    final int rank = Shape.rankFromShape(arg().getShape());
    final long count = f().getReductionLength(this);
    final SDVariable stdevBc = f().reductionBroadcastableWithOrigShape(rank, dimensions, outputVariables()[0]);
    final SDVariable meanBc = f().reductionBroadcastableWithOrigShape(rank, dimensions, f().mean(arg(), dimensions));
    final SDVariable centered = arg().sub(meanBc);

    final long denominator = biasCorrected ? count - 1 : count;
    final SDVariable dOutdIn = centered.div(stdevBc).div(denominator);

    final SDVariable gradBc = f().reductionBroadcastableWithOrigShape(rank, dimensions, i_v1.get(0));
    return Arrays.asList(dOutdIn.mul(gradBc));
}
 
Example 8
Source file: NormMax.java — from nd4j (Apache License 2.0), 6 votes
@Override
public List<SDVariable> doDiff(List<SDVariable> i_v1) {
    // maxnorm(in) = max_i |in_i|. Its derivative is sign(in_i) where |in_i| equals the max-norm and 0
    // elsewhere. That local derivative is then scaled by the upstream gradient broadcast to the input rank.
    SDVariable magnitude = sameDiff.abs(arg());
    SDVariable normOut = outputVariables()[0];
    int rank = Shape.rankFromShape(arg().getShape());   //TODO shape may not always be defined?

    SDVariable normBc = f().reductionBroadcastableWithOrigShape(rank, dimensions, normOut);
    SDVariable normFull = sameDiff.onesLike(arg()).mul(normBc);   // materialise at the full input shape
    SDVariable isMaxMask = sameDiff.eq(magnitude, normFull);
    SDVariable localGrad = isMaxMask.mul(sameDiff.sign(arg()));   // d|x|/dx = sign(x), masked

    SDVariable gradBc = f().reductionBroadcastableWithOrigShape(rank, dimensions, i_v1.get(0));
    return Arrays.asList(localGrad.mul(gradBc));
}
 
Example 9
Source file: Norm2.java — from nd4j (Apache License 2.0), 5 votes
@Override
public List<SDVariable> doDiff(List<SDVariable> i_v1) {
    // d norm2(in)/d in = in / norm2(in). That is scaled by the upstream gradient, with both the norm
    // and the gradient expanded back to the input's rank.
    final int rank = Shape.rankFromShape(arg().getShape());   //TODO shape may not always be defined?
    final SDVariable normBc = f().reductionBroadcastableWithOrigShape(rank, dimensions, outputVariables()[0]);
    final SDVariable gradBc = f().reductionBroadcastableWithOrigShape(rank, dimensions, i_v1.get(0));

    final SDVariable unitDirection = arg().div(normBc);
    return Arrays.asList(unitDirection.mul(gradBc));
}
 
Example 10
Source file: Variance.java — from nd4j (Apache License 2.0), 5 votes
@Override
public List<SDVariable> doDiff(List<SDVariable> i_v1) {
    // Backprop for out = var(in) along `dimensions`:
    //   dL/dIn = dL/dOut * dOut/dIn,  dOut/dIn = 2 * (in - mean) / N
    // where N = n - 1 with bias correction and N = n without. n comes from getReductionLength.
    final long count = f().getReductionLength(this);
    final int rank = Shape.rankFromShape(arg().getShape());
    final SDVariable meanBc = f().reductionBroadcastableWithOrigShape(rank, dimensions, f().mean(arg(), dimensions));
    final SDVariable gradBc = f().reductionBroadcastableWithOrigShape(rank, dimensions, i_v1.get(0));

    final long denominator = biasCorrected ? count - 1 : count;
    final double scale = 2.0 / denominator;
    return Arrays.asList(arg().sub(meanBc).mul(scale).mul(gradBc));
}
 
Example 11
Source file: Mean.java — from nd4j (Apache License 2.0), 5 votes
@Override
public List<SDVariable> doDiff(List<SDVariable> i_v1) {
    // out = mean(in), so every reduced input element contributes 1/N:
    //   dL/dIn = (1/N) * dL/dOut, broadcast to the input's shape.
    // N comes from getReductionLength and differs between along-dimension and whole-array reductions.
    final long count = f().getReductionLength(this);

    final int origRank = Shape.rankFromShape(arg().getShape());
    final SDVariable gradBc = f().reductionBroadcastableWithOrigShape(origRank, dimensions, i_v1.get(0));
    final SDVariable perElement = sameDiff.onesLike(arg()).div(count);   // 1/N at the input's shape
    return Arrays.asList(perElement.mul(gradBc));
}
 
Example 12
Source file: Sum.java — from nd4j (Apache License 2.0), 5 votes
@Override
public List<SDVariable> doDiff(List<SDVariable> i_v1) {
    // out = sum(in), so dOut/dIn = 1: dL/dIn is just dL/dOut copied to every input position.
    // The gradient is expanded to the input rank, then multiplied by ones to take the input's full shape.
    final int rank = Shape.rankFromShape(arg().getShape());   //TODO shape may not always be defined?
    final SDVariable gradBc = sameDiff.f().reductionBroadcastableWithOrigShape(rank, dimensions, i_v1.get(0));
    return Arrays.asList(sameDiff.onesLike(arg()).mul(gradBc));
}
 
Example 13
Source file: Norm1.java — from nd4j (Apache License 2.0), 5 votes
@Override
public List<SDVariable> doDiff(List<SDVariable> i_v1) {
    // d norm1(in)/d in = sign(in), scaled by the upstream gradient.
    final SDVariable localGrad = sameDiff.sign(arg());

    // The incoming gradient has the reduced shape. Expand it explicitly to the input rank, because
    // auto-broadcast won't work for all cases.
    final int rank = Shape.rankFromShape(arg().getShape());   //TODO shape may not always be defined?
    final SDVariable gradBc = sameDiff.f().reductionBroadcastableWithOrigShape(rank, dimensions, i_v1.get(0));
    return Arrays.asList(localGrad.mul(gradBc));
}
 
Example 14
Source file: Prod.java — from nd4j (Apache License 2.0), 5 votes
@Override
public List<SDVariable> doDiff(List<SDVariable> i_v1) {
    // out = prod(in), so dOut/dIn_i = prod / in_i. That is scaled by the upstream gradient, with the
    // product and the gradient both expanded to the input rank.
    // NOTE(review): dividing by arg() yields inf/NaN wherever an input element is zero. A zero-safe
    // gradient would use the product of all *other* elements rather than prod / in_i.
    final SDVariable prodOut = outputVariables()[0];
    final int rank = Shape.rankFromShape(arg().getShape());   //TODO shape may not always be defined?
    final SDVariable gradBc = sameDiff.f().reductionBroadcastableWithOrigShape(rank, dimensions, i_v1.get(0));
    final SDVariable prodBc = sameDiff.f().reductionBroadcastableWithOrigShape(rank, dimensions, prodOut);
    final SDVariable gradOverIn = gradBc.div(arg());
    return Arrays.asList(prodBc.mul(gradOverIn));
}
 
Example 15
Source file: Max.java — from nd4j (Apache License 2.0), 5 votes
@Override
public List<SDVariable> doDiff(List<SDVariable> i_v1) {
    //TODO do we need to handle the "multiple equal maximums" case?
    //TODO code duplication (min/max)
    // The gradient of a max-reduction flows only to input positions equal to the reduced maximum.
    // Tied positions each receive the full upstream gradient (eq mask times gradient).
    final SDVariable maxOut = outputVariables()[0];
    final int rank = Shape.rankFromShape(arg().getShape());

    final SDVariable maxBc = sameDiff.f().reductionBroadcastableWithOrigShape(rank, dimensions, maxOut);
    final SDVariable maxFull = sameDiff.onesLike(arg()).mul(maxBc);   // materialise at the full input shape
    final SDVariable gradBc = sameDiff.f().reductionBroadcastableWithOrigShape(rank, dimensions, i_v1.get(0));

    final SDVariable isMaxMask = sameDiff.eq(arg(), maxFull);
    return Arrays.asList(isMaxMask.mul(gradBc));
}