Java Code Examples for org.nd4j.linalg.api.shape.Shape#isColumnVectorShape()

The following examples show how to use org.nd4j.linalg.api.shape.Shape#isColumnVectorShape().
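All of the examples below branch on two shape predicates: `isColumnVectorShape` and its counterpart `isRowVectorShape`. The following is a minimal standalone sketch of those predicates. It assumes, based on our understanding of ND4J's `Shape` class (not verified against every release), that a column-vector shape is a rank-2 shape whose second dimension is 1, and that a row-vector shape is either rank 1 or a rank-2 shape whose first dimension is 1. `ShapeChecks` is a hypothetical class and is not part of ND4J.

```java
import java.util.Arrays;

// Hypothetical helper; it mirrors the assumed semantics of
// org.nd4j.linalg.api.shape.Shape but is not that class.
public class ShapeChecks {

    // Assumed: rank 2 with exactly one column, e.g. [5, 1].
    public static boolean isColumnVectorShape(long[] shape) {
        return shape.length == 2 && shape[1] == 1;
    }

    // Assumed: rank 1, or rank 2 with exactly one row, e.g. [5] or [1, 5].
    public static boolean isRowVectorShape(long[] shape) {
        return shape.length == 1 || (shape.length == 2 && shape[0] == 1);
    }

    public static void main(String[] args) {
        long[][] shapes = { {5, 1}, {1, 5}, {5}, {1, 1}, {3, 4} };
        for (long[] s : shapes) {
            System.out.println(Arrays.toString(s)
                    + " column=" + isColumnVectorShape(s)
                    + " row=" + isRowVectorShape(s));
        }
    }
}
```

Under these assumptions a `[1, 1]` shape satisfies both predicates, so code that tests one before the other (as `rows()` and `columns()` below do) is relying on branch order for that case.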
Example 1
Source File: BaseNDArray.java from deeplearning4j (Apache License 2.0)
protected void assertSlice(INDArray put, long slice) {
    Preconditions.checkArgument(slice < slices(), "Invalid slice specified: slice %s must be in range 0 (inclusive) to numSlices=%s (exclusive)", slice, slices());
    long[] sliceShape = put.shape();
    if (Shape.isRowVectorShape(sliceShape)) {
        return;
    } else {
        long[] requiredShape = ArrayUtil.removeIndex(shape(), 0);

        //no need to compare for scalar; primarily due to shapes either being [1] or length 0
        if (put.isScalar())
            return;

        if (isVector() && put.isVector() && put.length() < length())
            return;
        //edge case for column vectors
        if (Shape.isColumnVectorShape(sliceShape))
            return;
        if (!Shape.shapeEquals(sliceShape, requiredShape) && !Shape.isRowVectorShape(requiredShape)
                && !Shape.isRowVectorShape(sliceShape))
            throw new IllegalStateException(String.format("Invalid shape size of %s . Should have been %s ",
                    Arrays.toString(sliceShape), Arrays.toString(requiredShape)));
    }
}
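The acceptance rules in `assertSlice` can be easier to follow on raw shapes. The sketch below is a simplified, hypothetical re-expression of that logic: it takes `long[]` shapes instead of `INDArray` instances, assumes slices run along dimension 0, uses the same assumed predicate semantics as ND4J's `Shape` (column vector = rank 2 with `shape[1] == 1`; row vector = rank 1 or rank 2 with `shape[0] == 1`), and omits the original's "both are vectors and `put` is shorter" case. It also rejects negative slice indices, which the original precondition message implies but does not check. `SliceCheck` and `checkSlice` are illustrative names only.

```java
import java.util.Arrays;

// Hypothetical, simplified sketch of assertSlice's acceptance rules on raw long[] shapes.
public class SliceCheck {

    // Assumed: rank 2 with exactly one column, e.g. [4, 1].
    public static boolean isColumnVectorShape(long[] s) { return s.length == 2 && s[1] == 1; }

    // Assumed: rank 1, or rank 2 with exactly one row, e.g. [5] or [1, 5].
    public static boolean isRowVectorShape(long[] s) { return s.length == 1 || (s.length == 2 && s[0] == 1); }

    /** Throws if an array of shape sliceShape cannot be put at index slice of an array of shape targetShape. */
    public static void checkSlice(long[] targetShape, long[] sliceShape, long slice) {
        long numSlices = targetShape[0]; // assumption: slices run along dimension 0
        if (slice < 0 || slice >= numSlices)
            throw new IllegalArgumentException("Invalid slice " + slice + ": must be in [0, " + numSlices + ")");
        // Scalars (rank 0) and vector shapes are accepted as edge cases, as in the original.
        if (sliceShape.length == 0 || isRowVectorShape(sliceShape) || isColumnVectorShape(sliceShape))
            return;
        // Same idea as ArrayUtil.removeIndex(shape(), 0): drop the leading dimension.
        long[] required = Arrays.copyOfRange(targetShape, 1, targetShape.length);
        if (!Arrays.equals(sliceShape, required) && !isRowVectorShape(required))
            throw new IllegalStateException("Invalid shape size of " + Arrays.toString(sliceShape)
                    + ". Should have been " + Arrays.toString(required));
    }

    public static void main(String[] args) {
        checkSlice(new long[]{3, 4, 5}, new long[]{4, 5}, 0); // matches the required shape [4, 5]
        checkSlice(new long[]{3, 4, 5}, new long[]{4, 1}, 1); // column vector: accepted as an edge case
        try {
            checkSlice(new long[]{3, 4, 5}, new long[]{5, 4}, 2);
        } catch (IllegalStateException e) {
            System.out.println(e.getMessage()); // Invalid shape size of [5, 4]. Should have been [4, 5]
        }
    }
}
```

Note that the column-vector check acts as an escape hatch: a `[4, 1]` slice is accepted even though it does not equal the required `[4, 5]`, which matches the "edge case for column vectors" comment in the original.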
 
Example 2
Source File: BaseNDArray.java from nd4j (Apache License 2.0)
/**
 * Number of columns (shape[1]); throws an exception when
 * called on an array that is not 2d
 *
 * @return the number of columns in the array (only 2d)
 */
@Override
public int columns() {
    // FIXME: int cast
    if (isMatrix())
        return (int) size(1);
    else if (Shape.isColumnVectorShape(shape())) {
        return 1;
    } else if (Shape.isRowVectorShape(shape())) {
        return (int) length();
    }
    throw new IllegalStateException("Rank is [" + rank() + "]; columns() call is not valid");
}
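The `// FIXME: int cast` comment flags that `(int) size(1)` and `(int) length()` silently truncate values above `Integer.MAX_VALUE`. Below is a hypothetical standalone sketch of the same `columns()` dispatch on a raw shape that uses `Math.toIntExact` instead, which throws `ArithmeticException` on overflow. It assumes a column vector is rank 2 with `shape[1] == 1` and a row vector is rank 1 or rank 2 with `shape[0] == 1`; it also assumes, purely for illustration, that a "matrix" is rank 2 with no dimension equal to 1, so that the vector branches are reachable. `ColumnCount` is an illustrative name, not ND4J API.

```java
// Hypothetical sketch of the columns() dispatch with overflow-checked int conversion.
public class ColumnCount {

    static boolean isColumnVectorShape(long[] s) { return s.length == 2 && s[1] == 1; }
    static boolean isRowVectorShape(long[] s) { return s.length == 1 || (s.length == 2 && s[0] == 1); }
    // Assumption for this sketch only: a matrix is rank 2 with no dimension equal to 1.
    static boolean isMatrixShape(long[] s) { return s.length == 2 && s[0] != 1 && s[1] != 1; }

    // Total number of elements: the product of all dimensions.
    static long length(long[] shape) {
        long n = 1;
        for (long d : shape) n *= d;
        return n;
    }

    public static int columns(long[] shape) {
        if (isMatrixShape(shape))
            return Math.toIntExact(shape[1]);         // throws instead of truncating
        if (isColumnVectorShape(shape))
            return 1;                                 // [n, 1] has a single column
        if (isRowVectorShape(shape))
            return Math.toIntExact(length(shape));    // every element is its own column
        throw new IllegalStateException("Rank is [" + shape.length + "]; columns() call is not valid");
    }

    public static void main(String[] args) {
        System.out.println(columns(new long[]{3, 4})); // 4
        System.out.println(columns(new long[]{5, 1})); // 1 (column vector)
        System.out.println(columns(new long[]{1, 5})); // 5 (row vector)
        System.out.println(columns(new long[]{5}));    // 5 (rank 1 treated as a row vector)
    }
}
```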
 
Example 3
Source File: BaseNDArray.java from nd4j (Apache License 2.0)
/**
 * Returns the number of rows in the array (only 2d);
 * throws an exception when called on an array that
 * is not 2d
 *
 * @return the number of rows in the matrix
 */
@Override
public int rows() {
    // FIXME:
    if (isMatrix())
        return (int) size(0);
    else if (Shape.isRowVectorShape(shape())) {
        return 1;
    } else if (Shape.isColumnVectorShape(shape())) {
        return (int) length();
    }

    throw new IllegalStateException("Rank is [" + rank() + "]; rows() call is not valid");
}
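`rows()` is the mirror image of `columns()`: the row-vector test comes first and the column-vector branch returns the full length, because an `[n, 1]` array has one row per element. The hypothetical sketch below reproduces that dispatch on raw shapes. It assumes a column vector is rank 2 with `shape[1] == 1`, a row vector is rank 1 or rank 2 with `shape[0] == 1`, and (for illustration only) a "matrix" is rank 2 with no dimension equal to 1. `RowCount` is an illustrative name, and `Math.toIntExact` stands in for the unchecked `(int)` cast.

```java
// Hypothetical sketch of the rows() dispatch on a raw shape.
public class RowCount {

    static boolean isColumnVectorShape(long[] s) { return s.length == 2 && s[1] == 1; }
    static boolean isRowVectorShape(long[] s) { return s.length == 1 || (s.length == 2 && s[0] == 1); }
    // Assumption for this sketch only: a matrix is rank 2 with no dimension equal to 1.
    static boolean isMatrixShape(long[] s) { return s.length == 2 && s[0] != 1 && s[1] != 1; }

    // Total number of elements: the product of all dimensions.
    static long length(long[] shape) {
        long n = 1;
        for (long d : shape) n *= d;
        return n;
    }

    public static int rows(long[] shape) {
        if (isMatrixShape(shape))
            return Math.toIntExact(shape[0]);
        if (isRowVectorShape(shape))
            return 1;                                 // checked first, so [1, 1] lands here
        if (isColumnVectorShape(shape))
            return Math.toIntExact(length(shape));    // one row per element
        throw new IllegalStateException("Rank is [" + shape.length + "]; rows() call is not valid");
    }

    public static void main(String[] args) {
        System.out.println(rows(new long[]{3, 4})); // 3
        System.out.println(rows(new long[]{5, 1})); // 5 (column vector)
        System.out.println(rows(new long[]{1, 5})); // 1 (row vector)
        System.out.println(rows(new long[]{1, 1})); // 1 (matches both predicates; row test wins)
    }
}
```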
 
Example 4
Source File: BaseNDArray.java from deeplearning4j (Apache License 2.0)
@Override
public int columns() {
    if (isMatrix())
        return (int) size(1);
    else if (Shape.isColumnVectorShape(shape())) {
        return 1;
    } else if (Shape.isRowVectorShape(shape())) {
        return (int) length();
    }
    throw new IllegalStateException("Rank is [" + rank() + "]; columns() call is not valid");
}
 
Example 5
Source File: BaseNDArray.java from deeplearning4j (Apache License 2.0)
@Override
public int rows() {
    if (isMatrix())
        return (int) size(0);
    else if (Shape.isRowVectorShape(shape())) {
        return 1;
    } else if (Shape.isColumnVectorShape(shape())) {
        return (int) length();
    }

    throw new IllegalStateException("Rank is [" + rank() + "]; rows() call is not valid");
}