Java Code Examples for org.nd4j.linalg.api.shape.Shape#areShapesBroadcastable()

The following examples show how to use org.nd4j.linalg.api.shape.Shape#areShapesBroadcastable(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may check out the related API usage on the sidebar.
Example 1
Source File: BaseNDArray.java    From deeplearning4j with Apache License 2.0 6 votes vote down vote up
/**
 * Runs an fmod operation of this array against {@code denominator}, writing into {@code result}.
 *
 * <p>When the two input shapes are broadcastable, {@code result} must already have the
 * broadcast output shape; the op is then executed through {@code Nd4j.exec}. Otherwise an
 * {@code FModOp} is handed to the executioner directly.
 *
 * <p>NOTE(review): the broadcast path executes {@code FloorModOp} while the fallback path
 * executes {@code FModOp} - confirm both are intended to yield the same results.
 *
 * @param denominator the array to divide by
 * @param result      the array that receives the output
 * @return {@code result}
 */
@Override
public INDArray fmod(INDArray denominator, INDArray result) {
    validateNumericalArray("fmod", false);

    // Non-broadcastable inputs: delegate straight to the pairwise op.
    if (!Shape.areShapesBroadcastable(this.shape(), denominator.shape())) {
        Nd4j.getExecutioner().exec(new FModOp(this, denominator, result));
        return result;
    }

    // Broadcastable inputs: the caller-supplied result must match the broadcast shape.
    long[] expectedShape = Shape.broadcastOutputShape(this.shape(), denominator.shape());
    boolean resultShapeOk = Shape.shapeEquals(expectedShape, result.shape());
    Preconditions.checkArgument(resultShapeOk,
            "Result shape doesn't match expectations: " + Arrays.toString(result.shape()));

    INDArray[] inputs = new INDArray[]{this, denominator};
    INDArray[] outputs = new INDArray[]{result};
    Nd4j.exec(new FloorModOp(inputs, outputs));
    return result;
}
 
Example 2
Source File: BaseNDArray.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
/**
 * Executes a {@code LessThan} op of this array against {@code other} into a newly created,
 * uninitialized {@code BOOL} array.
 *
 * <p>Equal shapes go through the executioner with an output shaped and ordered like this
 * array; broadcastable shapes go through {@code Nd4j.exec} with an output of the broadcast shape.
 *
 * @param other the array to compare against
 * @return the first output array of the executed op
 * @throws IllegalArgumentException if the shapes are neither equal nor broadcastable
 */
@Override
public INDArray lt(INDArray other) {
    validateNumericalArray("less than (lt)", false);

    if (Shape.shapeEquals(this.shape(), other.shape())) {
        LessThan sameShapeOp = new LessThan(this, other,
                Nd4j.createUninitialized(DataType.BOOL, this.shape(), this.ordering()));
        return Nd4j.getExecutioner().exec(sameShapeOp)[0];
    }

    if (!Shape.areShapesBroadcastable(this.shape(), other.shape())) {
        throw new IllegalArgumentException("Shapes must be broadcastable");
    }

    INDArray broadcastOut = Nd4j.createUninitialized(DataType.BOOL,
            Shape.broadcastOutputShape(this.shape(), other.shape()));
    LessThan broadcastOp = new LessThan(new INDArray[]{this, other}, new INDArray[]{broadcastOut});
    return Nd4j.exec(broadcastOp)[0];
}
 
Example 3
Source File: BaseNDArray.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
/**
 * Executes an {@code EqualTo} op of this array against {@code other} into a newly created,
 * uninitialized {@code BOOL} array.
 *
 * <p>Equal shapes go through the executioner with an output shaped and ordered like this
 * array; broadcastable shapes go through {@code Nd4j.exec} with an output of the broadcast shape.
 *
 * @param other the array to compare against
 * @return the first output array of the executed op
 * @throws IllegalArgumentException if the shapes are neither equal nor broadcastable
 */
@Override
public INDArray eq(INDArray other) {
    if (Shape.shapeEquals(this.shape(), other.shape())) {
        EqualTo sameShapeOp = new EqualTo(this, other,
                Nd4j.createUninitialized(DataType.BOOL, this.shape(), this.ordering()));
        return Nd4j.getExecutioner().exec(sameShapeOp)[0];
    }

    if (!Shape.areShapesBroadcastable(this.shape(), other.shape())) {
        throw new IllegalArgumentException("Shapes must be broadcastable");
    }

    INDArray broadcastOut = Nd4j.createUninitialized(DataType.BOOL,
            Shape.broadcastOutputShape(this.shape(), other.shape()));
    EqualTo broadcastOp = new EqualTo(new INDArray[]{this, other}, new INDArray[]{broadcastOut});
    return Nd4j.exec(broadcastOp)[0];
}
 
Example 4
Source File: BaseNDArray.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
/**
 * Executes a {@code GreaterThan} op of this array against {@code other} into a newly created,
 * uninitialized {@code BOOL} array.
 *
 * <p>Equal shapes go through the executioner with an output shaped and ordered like this
 * array; broadcastable shapes go through {@code Nd4j.exec} with an output of the broadcast shape.
 *
 * @param other the array to compare against
 * @return the first output array of the executed op
 * @throws IllegalArgumentException if the shapes are neither equal nor broadcastable
 */
@Override
public INDArray gt(INDArray other) {
    validateNumericalArray("greater than (gt)", false);

    if (Shape.shapeEquals(this.shape(), other.shape())) {
        GreaterThan sameShapeOp = new GreaterThan(this, other,
                Nd4j.createUninitialized(DataType.BOOL, this.shape(), this.ordering()));
        return Nd4j.getExecutioner().exec(sameShapeOp)[0];
    }

    if (!Shape.areShapesBroadcastable(this.shape(), other.shape())) {
        throw new IllegalArgumentException("Shapes must be broadcastable");
    }

    INDArray broadcastOut = Nd4j.createUninitialized(DataType.BOOL,
            Shape.broadcastOutputShape(this.shape(), other.shape()));
    GreaterThan broadcastOp = new GreaterThan(new INDArray[]{this, other}, new INDArray[]{broadcastOut});
    return Nd4j.exec(broadcastOp)[0];
}
 
Example 5
Source File: BaseNDArray.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
/**
 * Divides this array by {@code other}, writing into a newly created, uninitialized array
 * and delegating the computation to {@code divi(other, result)}.
 *
 * <p>The output data type is chosen by {@code Shape.pickPairwiseDataType}. Its shape is the
 * broadcast output shape when the inputs are broadcastable, otherwise this array's shape;
 * its ordering is always this array's ordering.
 *
 * @param other the divisor array
 * @return the array returned by {@code divi}
 */
@Override
public INDArray div(INDArray other) {
    // Validate up front, like the sibling arithmetic methods (add/sub/mul/rdiv/rsub), so an
    // unsupported array is rejected before the output buffer is allocated.
    validateNumericalArray("div", false);
    if (Shape.areShapesBroadcastable(this.shape(), other.shape())) {
        return divi(other, Nd4j.createUninitialized(Shape.pickPairwiseDataType(this.dataType(), other.dataType()), Shape.broadcastOutputShape(this.shape(), other.shape()), this.ordering()));
    } else {
        return divi(other, Nd4j.createUninitialized(Shape.pickPairwiseDataType(this.dataType(), other.dataType()), this.shape(), this.ordering()));
    }
}
 
Example 6
Source File: BaseNDArray.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
/**
 * Multiplies this array by {@code other}, writing into a newly created, uninitialized array
 * and delegating the computation to {@code muli(other, result)}.
 *
 * <p>The output data type is chosen by {@code Shape.pickPairwiseDataType}. Its shape is the
 * broadcast output shape when the inputs are broadcastable, otherwise this array's shape;
 * its ordering is always this array's ordering.
 *
 * @param other the array to multiply by
 * @return the array returned by {@code muli}
 */
@Override
public INDArray mul(INDArray other) {
    validateNumericalArray("mul", false);

    long[] outShape = Shape.areShapesBroadcastable(this.shape(), other.shape())
            ? Shape.broadcastOutputShape(this.shape(), other.shape())
            : this.shape();

    INDArray z = Nd4j.createUninitialized(
            Shape.pickPairwiseDataType(this.dataType(), other.dataType()),
            outShape,
            this.ordering());
    return muli(other, z);
}
 
Example 7
Source File: BaseNDArray.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
/**
 * Subtracts {@code other} from this array, writing into a newly created, uninitialized array
 * and delegating the computation to {@code subi(other, result)}.
 *
 * <p>The output data type is chosen by {@code Shape.pickPairwiseDataType}. Its shape is the
 * broadcast output shape when the inputs are broadcastable, otherwise this array's shape;
 * its ordering is always this array's ordering.
 *
 * @param other the array to subtract
 * @return the array returned by {@code subi}
 */
@Override
public INDArray sub(INDArray other) {
    validateNumericalArray("sub", false);

    // Fallback first: output simply mirrors this array's shape.
    if (!Shape.areShapesBroadcastable(this.shape(), other.shape())) {
        INDArray sameShapeOut = Nd4j.createUninitialized(
                Shape.pickPairwiseDataType(this.dataType(), other.dataType()),
                this.shape(),
                this.ordering());
        return subi(other, sameShapeOut);
    }

    INDArray broadcastOut = Nd4j.createUninitialized(
            Shape.pickPairwiseDataType(this.dataType(), other.dataType()),
            Shape.broadcastOutputShape(this.shape(), other.shape()),
            this.ordering());
    return subi(other, broadcastOut);
}
 
Example 8
Source File: BaseNDArray.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
/**
 * Adds {@code other} to this array, writing into a newly created, uninitialized array
 * and delegating the computation to {@code addi(other, result)}.
 *
 * <p>The output data type is chosen by {@code Shape.pickPairwiseDataType}. Its shape is the
 * broadcast output shape when the inputs are broadcastable, otherwise this array's shape;
 * its ordering is always this array's ordering.
 *
 * @param other the array to add
 * @return the array returned by {@code addi}
 */
@Override
public INDArray add(INDArray other) {
    validateNumericalArray("add", false);

    INDArray sum;
    if (Shape.areShapesBroadcastable(this.shape(), other.shape())) {
        sum = Nd4j.createUninitialized(
                Shape.pickPairwiseDataType(this.dataType(), other.dataType()),
                Shape.broadcastOutputShape(this.shape(), other.shape()),
                this.ordering());
    } else {
        sum = Nd4j.createUninitialized(
                Shape.pickPairwiseDataType(this.dataType(), other.dataType()),
                this.shape(),
                this.ordering());
    }
    return addi(other, sum);
}
 
Example 9
Source File: BaseNDArray.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
/**
 * Reverse division: delegates to {@code rdivi(other, result)} with a newly created output array.
 *
 * <p>When the shapes are broadcastable the output is an uninitialized array with the
 * pairwise-picked data type, the broadcast output shape and this array's ordering.
 * Otherwise the output is {@code this.ulike()}.
 *
 * @param other the other operand
 * @return the array returned by {@code rdivi}
 */
@Override
public INDArray rdiv(INDArray other) {
    validateNumericalArray("rdiv", false);

    if (!Shape.areShapesBroadcastable(this.shape(), other.shape())) {
        return rdivi(other, this.ulike());
    }

    INDArray broadcastOut = Nd4j.createUninitialized(
            Shape.pickPairwiseDataType(this.dataType(), other.dataType()),
            Shape.broadcastOutputShape(this.shape(), other.shape()),
            this.ordering());
    return rdivi(other, broadcastOut);
}
 
Example 10
Source File: BaseNDArray.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
/**
 * Reverse subtraction: delegates to {@code rsubi(other, result)} with a newly created output array.
 *
 * <p>When the shapes are broadcastable the output is an uninitialized array with the
 * pairwise-picked data type, the broadcast output shape and this array's ordering.
 * Otherwise the output is {@code this.ulike()}.
 *
 * @param other the other operand
 * @return the array returned by {@code rsubi}
 */
@Override
public INDArray rsub(INDArray other) {
    validateNumericalArray("rsub", false);

    INDArray out = Shape.areShapesBroadcastable(this.shape(), other.shape())
            ? Nd4j.createUninitialized(
                    Shape.pickPairwiseDataType(this.dataType(), other.dataType()),
                    Shape.broadcastOutputShape(this.shape(), other.shape()),
                    this.ordering())
            : this.ulike();
    return rsubi(other, out);
}
 
Example 11
Source File: BaseNDArray.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
/**
 * Delegates to {@code remainder(denominator, result)} with a newly created output array.
 *
 * <p>When the shapes are broadcastable the output is an uninitialized array with this
 * array's data type and the broadcast output shape. Otherwise the output is {@code this.ulike()}.
 *
 * @param denominator the array to divide by
 * @return the array returned by the two-argument {@code remainder}
 */
@Override
public INDArray remainder(INDArray denominator) {
    INDArray out = Shape.areShapesBroadcastable(this.shape(), denominator.shape())
            ? Nd4j.createUninitialized(this.dataType(),
                    Shape.broadcastOutputShape(this.shape(), denominator.shape()))
            : this.ulike();
    return remainder(denominator, out);
}
 
Example 12
Source File: BaseNDArray.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
/**
 * Delegates to {@code fmod(denominator, result)} with a newly created output array.
 *
 * <p>When the shapes are broadcastable the output is an uninitialized array of
 * {@code Nd4j.defaultFloatingPointType()} with the broadcast output shape.
 * Otherwise the output is {@code this.ulike()}.
 *
 * @param denominator the array to divide by
 * @return the array returned by the two-argument {@code fmod}
 */
@Override
public INDArray fmod(INDArray denominator) {
    validateNumericalArray("fmod", false);

    if (!Shape.areShapesBroadcastable(this.shape(), denominator.shape())) {
        return fmod(denominator, this.ulike());
    }

    INDArray broadcastOut = Nd4j.createUninitialized(
            Nd4j.defaultFloatingPointType(),
            Shape.broadcastOutputShape(this.shape(), denominator.shape()));
    return fmod(denominator, broadcastOut);
}
 
Example 13
Source File: Transforms.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
/**
 * Determines the result shape for a pairwise operation on two arrays.
 *
 * @param first  the first operand
 * @param second the second operand
 * @return {@code first.shape()} when the shapes are equal, otherwise the broadcast output shape
 * @throws IllegalStateException if the shapes are neither equal nor broadcastable
 */
protected static long[] broadcastResultShape(INDArray first, INDArray second){
    if (first.equalShapes(second)) {
        return first.shape();
    }

    if (Shape.areShapesBroadcastable(first.shape(), second.shape())) {
        return Shape.broadcastOutputShape(first.shape(), second.shape());
    }

    String firstShape = Arrays.toString(first.shape());
    String secondShape = Arrays.toString(second.shape());
    throw new IllegalStateException("Array shapes are not broadcastable: " + firstShape +
            " vs. " + secondShape);
}