Java Code Examples for org.nd4j.linalg.api.shape.Shape#shapeEquals()

The following examples show how to use org.nd4j.linalg.api.shape.Shape#shapeEquals(). You can vote up the examples you like or vote down the ones you don't, and you can follow the link above each example to its original project or source file. You may also check out related API usage in the sidebar.
Example 1
Source File: BaseNDArray.java - from nd4j (Apache License 2.0), 6 votes
/**
 * In-place (element-wise) division of this array by another array, writing into {@code result}.
 *
 * <p>If {@code other} is a scalar, this delegates to the scalar overload of {@code divi}. If this
 * array is a scalar, the reversed division is performed on {@code other} via {@code rdivi}. When
 * the two shapes differ, a broadcast division is executed along the dimensions returned by
 * {@code Shape.getBroadcastDimensions}; otherwise a plain element-wise division is used after
 * checking that {@code other} and {@code result} have the same shape.
 *
 * <p>When {@code Nd4j.ENFORCE_NUMERICAL_STABILITY} is set, NaNs are cleared from {@code result}
 * on both the broadcast and the element-wise path.
 *
 * @param other  the second ndarray to divide
 * @param result the result ndarray
 * @return the result of the divide
 */
@Override
public INDArray divi(INDArray other, INDArray result) {
    if (other.isScalar()) {
        return divi(other.getDouble(0), result);
    }

    if (isScalar()) {
        // scalar / array: let the other operand perform the reversed division
        return other.rdivi(getDouble(0), result);
    }

    if (!Shape.shapeEquals(this.shape(), other.shape())) {
        int[] broadcastDimensions = Shape.getBroadcastDimensions(this.shape(), other.shape());
        Nd4j.getExecutioner().exec(new BroadcastDivOp(this, other, result, broadcastDimensions), broadcastDimensions);
    } else {
        LinAlgExceptions.assertSameShape(other, result);
        Nd4j.getExecutioner().exec(new OldDivOp(this, other, result, length()));
    }

    // 0/0 produces NaN regardless of whether the op was broadcast, so the optional cleanup
    // applies to both paths (previously the broadcast path returned before reaching it).
    if (Nd4j.ENFORCE_NUMERICAL_STABILITY)
        Nd4j.clearNans(result);
    return result;
}
 
Example 2
Source File: BaseNDArray.java - from deeplearning4j (Apache License 2.0), 6 votes
/**
 * Validates that {@code put} may be written into slice number {@code slice} of this array.
 *
 * <p>No shape comparison is performed when {@code put} has a row-vector shape, is a scalar, has a
 * column-vector shape, or when both this array and {@code put} are vectors and {@code put} is
 * shorter. Otherwise the shape of {@code put} must equal this array's shape with dimension 0
 * removed, unless that required shape is itself a row-vector shape.
 *
 * @param put   the array to be placed into the slice
 * @param slice index of the target slice; must satisfy {@code 0 <= slice < slices()}
 * @throws IllegalArgumentException if {@code slice} is out of range
 * @throws IllegalStateException    if the shape of {@code put} is incompatible with the slice
 */
protected void assertSlice(INDArray put, long slice) {
    // The message already promises "0 (inclusive)"; enforce the lower bound so negative indices
    // are rejected here instead of failing later.
    Preconditions.checkArgument(slice >= 0 && slice < slices(), "Invalid slice specified: slice %s must be in range 0 (inclusive) to numSlices=%s (exclusive)", slice, slices());
    long[] sliceShape = put.shape();
    if (Shape.isRowVectorShape(sliceShape))
        return;

    long[] requiredShape = ArrayUtil.removeIndex(shape(), 0);

    //no need to compare for scalar; primarily due to shapes either being [1] or length 0
    if (put.isScalar())
        return;

    if (isVector() && put.isVector() && put.length() < length())
        return;
    //edge case for column vectors
    if (Shape.isColumnVectorShape(sliceShape))
        return;
    // sliceShape is known not to be a row-vector shape here (first early return above)
    if (!Shape.shapeEquals(sliceShape, requiredShape) && !Shape.isRowVectorShape(requiredShape))
        throw new IllegalStateException(String.format("Invalid shape size of %s . Should have been %s ",
                Arrays.toString(sliceShape), Arrays.toString(requiredShape)));
}
 
Example 3
Source File: BaseNDArray.java - from nd4j (Apache License 2.0), 5 votes
/**
 * Multiplies this array element-wise by {@code other}, storing the product in {@code result}.
 *
 * <p>A scalar {@code other} is routed to the scalar overload of {@code muli}. If this array is the
 * scalar, the multiplication is handed to {@code other}. Differently shaped operands use a
 * broadcast multiply. Equally shaped operands use a plain element-wise multiply, after which
 * NaNs are cleared if {@code Nd4j.ENFORCE_NUMERICAL_STABILITY} is enabled.
 *
 * @param other  the second ndarray to multiply
 * @param result the result ndarray
 * @return the result of the multiplication
 */
@Override
public INDArray muli(INDArray other, INDArray result) {
    if (other.isScalar())
        return muli(other.getDouble(0), result);
    if (isScalar())
        return other.muli(getDouble(0), result); // multiplication commutes, so the scalar can go to the other side

    boolean sameShape = Shape.shapeEquals(this.shape(), other.shape());
    if (!sameShape) {
        int[] dims = Shape.getBroadcastDimensions(this.shape(), other.shape());
        BroadcastMulOp broadcastOp = new BroadcastMulOp(this, other, result, dims);
        Nd4j.getExecutioner().exec(broadcastOp, dims);
        return result;
    }

    LinAlgExceptions.assertSameShape(other, result);
    OldMulOp elementwiseOp = new OldMulOp(this, other, result, length());
    Nd4j.getExecutioner().exec(elementwiseOp);

    if (Nd4j.ENFORCE_NUMERICAL_STABILITY) {
        Nd4j.clearNans(result);
    }
    return result;
}
 
Example 4
Source File: BaseNDArray.java - from nd4j (Apache License 2.0), 5 votes
/**
 * Subtracts {@code other} from this array in place, writing the difference into {@code result}.
 *
 * <p>Scalar operands are delegated to the scalar overloads: {@code subi} when {@code other} is a
 * scalar, and {@code other.rsubi} when this array is the scalar. Operands of equal shape are
 * subtracted element-wise, followed by optional NaN clearing when
 * {@code Nd4j.ENFORCE_NUMERICAL_STABILITY} is set. Operands of differing shape use a broadcast
 * subtraction.
 *
 * @param other  the second ndarray to subtract
 * @param result the result ndarray
 * @return the result of the subtraction
 */
@Override
public INDArray subi(INDArray other, INDArray result) {
    if (other.isScalar())
        return subi(other.getDouble(0), result);
    if (isScalar())
        return other.rsubi(getDouble(0), result); // scalar - array is a reversed subtraction on other

    if (Shape.shapeEquals(this.shape(), other.shape())) {
        LinAlgExceptions.assertSameShape(other, result);
        Nd4j.getExecutioner().exec(new OldSubOp(this, other, result));
        if (Nd4j.ENFORCE_NUMERICAL_STABILITY)
            Nd4j.clearNans(result);
    } else {
        int[] broadcastDims = Shape.getBroadcastDimensions(this.shape(), other.shape());
        Nd4j.getExecutioner().exec(new BroadcastSubOp(this, other, result, broadcastDims), broadcastDims);
    }
    return result;
}
 
Example 5
Source File: BaseNDArray.java - from nd4j (Apache License 2.0), 5 votes
/**
 * In-place addition of two arrays, writing the sum into {@code result}.
 *
 * <p>If {@code other} is a scalar, the scalar is added to this array via the scalar overload of
 * {@code addi}. If this array is the scalar, the addition is handed to {@code other}. When the
 * shapes differ, a broadcast addition is executed. The supplied {@code result} is used when its
 * shape already equals the broadcast output shape; otherwise a new uninitialized array of that
 * shape is allocated and returned. Equally shaped operands use a plain element-wise add, followed
 * by optional NaN clearing when {@code Nd4j.ENFORCE_NUMERICAL_STABILITY} is set.
 *
 * @param other  the second ndarray to add
 * @param result the result ndarray
 * @return the result of the addition (a newly allocated array if {@code result} could not hold
 *         the broadcast output)
 */
@Override
public INDArray addi(INDArray other, INDArray result) {
    if (other.isScalar()) {
        // Add the scalar to *this* array. Calling result.addi(...) here would add it to whatever
        // result already held and ignore this array's values entirely.
        return addi(other.getDouble(0), result);
    }

    if (isScalar()) {
        return other.addi(getDouble(0), result);
    }

    if (!Shape.shapeEquals(this.shape(), other.shape())) {
        int[] broadcastDimensions = Shape.getBroadcastDimensions(this.shape(), other.shape());
        // Only allocate when the caller's target cannot hold the broadcast output; otherwise
        // honour the in-place contract (e.g. addi(other) with result == this).
        if (!Shape.shapeEquals(result.shape(), Shape.broadcastOutputShape(this.shape(), other.shape())))
            result = Nd4j.createUninitialized(Shape.broadcastOutputShape(this.shape(), other.shape()));
        Nd4j.getExecutioner().exec(new BroadcastAddOp(this, other, result, broadcastDimensions), broadcastDimensions);
        return result;
    }

    LinAlgExceptions.assertSameShape(other, result);

    Nd4j.getExecutioner().exec(new OldAddOp(this, other, result, length()));

    if (Nd4j.ENFORCE_NUMERICAL_STABILITY)
        Nd4j.clearNans(result);

    return result;
}
 
Example 6
Source File: BaseComplexNDArray.java - from nd4j (Apache License 2.0), 5 votes
/**
 * Copies the values of a real-valued array into this complex array, writing each value as a
 * complex number with imaginary part {@code 0.0}.
 *
 * @param real the real array to copy from; must have the same shape as this array
 * @throws IllegalStateException if the shapes differ
 */
protected void copyFromReal(INDArray real) {
    boolean sameShape = Shape.shapeEquals(shape(), real.shape());
    if (!sameShape) {
        throw new IllegalStateException("Unable to copy array. Not the same shape");
    }
    INDArray source = real.linearView();
    IComplexNDArray target = linearView();
    int idx = 0;
    while (idx < source.length()) {
        target.putScalar(idx, Nd4j.createComplexNumber(source.getDouble(idx), 0.0));
        idx++;
    }
}
 
Example 7
Source File: BaseNDArray.java - from deeplearning4j (Apache License 2.0), 5 votes
/**
 * Element-wise "less than" comparison, returning a BOOL array.
 *
 * <p>Equal shapes produce an output with this array's shape and ordering. Broadcastable shapes
 * produce an output with the broadcast shape.
 *
 * @throws IllegalArgumentException if the shapes are neither equal nor broadcastable
 */
@Override
public INDArray lt(INDArray other) {
    validateNumericalArray("less than (lt)", false);
    if (!Shape.shapeEquals(this.shape(), other.shape())) {
        if (!Shape.areShapesBroadcastable(this.shape(), other.shape()))
            throw new IllegalArgumentException("Shapes must be broadcastable");
        INDArray broadcastOut = Nd4j.createUninitialized(DataType.BOOL, Shape.broadcastOutputShape(this.shape(), other.shape()));
        return Nd4j.exec(new LessThan(new INDArray[]{this, other}, new INDArray[]{broadcastOut}))[0];
    }
    INDArray sameShapeOut = Nd4j.createUninitialized(DataType.BOOL, this.shape(), this.ordering());
    return Nd4j.getExecutioner().exec(new LessThan(this, other, sameShapeOut))[0];
}
 
Example 8
Source File: BaseNDArray.java - from deeplearning4j (Apache License 2.0), 5 votes
/**
 * Element-wise equality comparison, returning a BOOL array.
 *
 * <p>Unlike {@code lt}/{@code gt}, no numerical-type validation is performed here. Equal shapes
 * produce an output with this array's shape and ordering. Broadcastable shapes produce an output
 * with the broadcast shape.
 *
 * @throws IllegalArgumentException if the shapes are neither equal nor broadcastable
 */
@Override
public INDArray eq(INDArray other) {
    boolean identicalShapes = Shape.shapeEquals(this.shape(), other.shape());
    if (identicalShapes) {
        INDArray out = Nd4j.createUninitialized(DataType.BOOL, this.shape(), this.ordering());
        return Nd4j.getExecutioner().exec(new EqualTo(this, other, out))[0];
    }
    if (!Shape.areShapesBroadcastable(this.shape(), other.shape())) {
        throw new IllegalArgumentException("Shapes must be broadcastable");
    }
    INDArray[] inputs = new INDArray[]{this, other};
    INDArray[] outputs = new INDArray[]{Nd4j.createUninitialized(DataType.BOOL, Shape.broadcastOutputShape(this.shape(), other.shape()))};
    return Nd4j.exec(new EqualTo(inputs, outputs))[0];
}
 
Example 9
Source File: BaseNDArray.java - from deeplearning4j (Apache License 2.0), 5 votes
/**
 * Element-wise "greater than" comparison, returning a BOOL array.
 *
 * <p>Equal shapes produce an output with this array's shape and ordering. Broadcastable shapes
 * produce an output with the broadcast shape.
 *
 * @throws IllegalArgumentException if the shapes are neither equal nor broadcastable
 */
@Override
public INDArray gt(INDArray other) {
    validateNumericalArray("greater than (gt)", false);
    boolean identicalShapes = Shape.shapeEquals(this.shape(), other.shape());
    if (identicalShapes) {
        INDArray out = Nd4j.createUninitialized(DataType.BOOL, this.shape(), this.ordering());
        return Nd4j.getExecutioner().exec(new GreaterThan(this, other, out))[0];
    }
    if (!Shape.areShapesBroadcastable(this.shape(), other.shape())) {
        throw new IllegalArgumentException("Shapes must be broadcastable");
    }
    INDArray[] inputs = new INDArray[]{this, other};
    INDArray[] outputs = new INDArray[]{Nd4j.createUninitialized(DataType.BOOL, Shape.broadcastOutputShape(this.shape(), other.shape()))};
    return Nd4j.exec(new GreaterThan(inputs, outputs))[0];
}
 
Example 10
Source File: LinAlgExceptions.java - from deeplearning4j (Apache License 2.0), 5 votes
/**
 * Asserts that {@code y} and {@code z} both have exactly the same shape as {@code x}.
 *
 * @throws IllegalStateException for the first of {@code y}, {@code z} (checked in that order)
 *         whose shape differs from {@code x}
 */
public static void assertSameShape(INDArray x, INDArray y, INDArray z) {
    for (INDArray candidate : new INDArray[] {y, z}) {
        if (!Shape.shapeEquals(x.shape(), candidate.shape()))
            throw new IllegalStateException("Mis matched shapes: " + Arrays.toString(x.shape()) + ", " + Arrays.toString(candidate.shape()));
    }
}
 
Example 11
Source File: LinAlgExceptions.java - from nd4j (Apache License 2.0), 4 votes
/**
 * Asserts that two arrays have exactly the same shape.
 *
 * @throws IllegalStateException listing both shapes when they differ
 */
public static void assertSameShape(INDArray n, INDArray n2) {
    if (Shape.shapeEquals(n.shape(), n2.shape())) {
        return;
    }
    throw new IllegalStateException("Mis matched shapes: " + Arrays.toString(n.shape()) + ", "
            + Arrays.toString(n2.shape()));
}
 
Example 12
Source File: LinAlgExceptions.java - from deeplearning4j (Apache License 2.0), 4 votes
/**
 * Asserts that two arrays have compatible shapes.
 *
 * <p>When neither array has a vector shape (per {@code Shape.isVector}), the shapes must match
 * exactly. When at least one is a vector, the exact shapes are not compared, so vector
 * orientation differences are tolerated, but the element counts must still agree. Previously a
 * vector paired with any non-vector passed unchecked.
 *
 * @throws IllegalStateException listing both shapes when they are incompatible
 */
public static void assertSameShape(INDArray n, INDArray n2) {
    boolean anyVector = Shape.isVector(n.shape()) || Shape.isVector(n2.shape());
    boolean compatible = anyVector
            ? n.length() == n2.length()
            : Shape.shapeEquals(n.shape(), n2.shape());
    if (!compatible)
        throw new IllegalStateException("Mis matched shapes: " + Arrays.toString(n.shape()) + ", " + Arrays.toString(n2.shape()));
}