Java Code Examples for org.nd4j.linalg.api.shape.Shape#isRowVectorShape()

The following examples show how to use org.nd4j.linalg.api.shape.Shape#isRowVectorShape(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may check out the related API usage on the sidebar.
Example 1
Source File: BaseNDArray.java (from deeplearning4j, Apache License 2.0) — 6 votes
/**
 * Validates that {@code put} may be assigned as slice number {@code slice} of this array.
 *
 * <p>Row vectors, scalars, shorter vectors (when this array is itself a vector) and column
 * vectors are accepted without a shape comparison. Any other {@code put} must have this array's
 * shape with the leading dimension removed, unless that required shape is itself a row vector shape.
 *
 * @param put   the candidate slice
 * @param slice the slice index; must be in {@code [0, slices())}
 * @throws IllegalArgumentException if {@code slice} is out of range
 * @throws IllegalStateException    if {@code put}'s shape does not match the required slice shape
 */
protected void assertSlice(INDArray put, long slice) {
    // The message promises a lower bound of 0, so enforce it (negative slices were previously accepted).
    Preconditions.checkArgument(slice >= 0 && slice < slices(), "Invalid slice specified: slice %s must be in range 0 (inclusive) to numSlices=%s (exclusive)", slice, slices());
    long[] sliceShape = put.shape();
    if (Shape.isRowVectorShape(sliceShape))
        return;

    //no need to compare for scalar; primarily due to shapes either being [1] or length 0
    if (put.isScalar())
        return;

    if (isVector() && put.isVector() && put.length() < length())
        return;
    //edge case for column vectors
    if (Shape.isColumnVectorShape(sliceShape))
        return;

    // Expected slice shape: this array's shape without its leading dimension.
    long[] requiredShape = ArrayUtil.removeIndex(shape(), 0);
    // sliceShape cannot be a row vector shape here (that case returned above), so only
    // requiredShape needs the row-vector exemption.
    if (!Shape.shapeEquals(sliceShape, requiredShape) && !Shape.isRowVectorShape(requiredShape))
        throw new IllegalStateException(String.format("Invalid shape size of %s . Should have been %s ",
                Arrays.toString(sliceShape), Arrays.toString(requiredShape)));
}
 
Example 2
Source File: BaseNDArray.java (from nd4j, Apache License 2.0) — 5 votes
/**
 * Applies an in-place broadcast operation of {@code vector} onto this array.
 *
 * <p>The broadcast dimension is 1 when {@code vector} has a row-vector shape, otherwise 0.
 * The result is written back into this array.
 *
 * @param vector    the operand to broadcast
 * @param operation 'a' add, 's' subtract, 'm' multiply, 'd' divide,
 *                  'h' reverse subtract, 't' reverse divide, 'p' copy
 * @throws UnsupportedOperationException for any other operation code
 */
private void applyBroadcastOp(INDArray vector, final char operation) {
    Nd4j.getCompressor().autoDecompress(this);

    final int broadcastDim;
    if (Shape.isRowVectorShape(vector.shape())) {
        broadcastDim = 1;
    } else {
        broadcastDim = 0;
    }

    // FIXME: probably this is wrong, because strict equality is always false in current DataBuffer mechanics
    // Intent: avoid reading from and writing to the same buffer by copying the operand first.
    if (this.data() == vector.data()) {
        vector = vector.dup();
    }

    if (operation == 'a') {
        Nd4j.getExecutioner().exec(new BroadcastAddOp(this, vector, this, broadcastDim), broadcastDim);
    } else if (operation == 's') {
        Nd4j.getExecutioner().exec(new BroadcastSubOp(this, vector, this, broadcastDim), broadcastDim);
    } else if (operation == 'm') {
        Nd4j.getExecutioner().exec(new BroadcastMulOp(this, vector, this, broadcastDim), broadcastDim);
    } else if (operation == 'd') {
        Nd4j.getExecutioner().exec(new BroadcastDivOp(this, vector, this, broadcastDim), broadcastDim);
    } else if (operation == 'h') {
        Nd4j.getExecutioner().exec(new BroadcastRSubOp(this, vector, this, broadcastDim), broadcastDim);
    } else if (operation == 't') {
        Nd4j.getExecutioner().exec(new BroadcastRDivOp(this, vector, this, broadcastDim), broadcastDim);
    } else if (operation == 'p') {
        Nd4j.getExecutioner().exec(new BroadcastCopyOp(this, vector, this, broadcastDim), broadcastDim);
    } else {
        throw new UnsupportedOperationException("Unknown operation: " + operation);
    }
}
 
Example 3
Source File: BaseNDArray.java (from nd4j, Apache License 2.0) — 5 votes
/**
 * Returns the number of columns: {@code size(1)} for a matrix, 1 for a column-vector shape,
 * or the total length for a row-vector shape.
 *
 * @return the number of columns
 * @throws IllegalStateException if the array is neither a matrix nor a row/column vector shape
 * @throws ArithmeticException   if the column count does not fit in an {@code int}
 */
@Override
public int columns() {
    // Math.toIntExact instead of an (int) cast: a silent truncation would return a wrong (possibly negative) count.
    if (isMatrix())
        return Math.toIntExact(size(1));
    else if (Shape.isColumnVectorShape(shape())) {
        return 1;
    } else if (Shape.isRowVectorShape(shape())) {
        return Math.toIntExact(length());
    }
    throw new IllegalStateException("Rank is [" + rank() + "]; columns() call is not valid");
}
 
Example 4
Source File: BaseNDArray.java (from nd4j, Apache License 2.0) — 5 votes
/**
 * Returns the number of rows: {@code size(0)} for a matrix, 1 for a row-vector shape,
 * or the total length for a column-vector shape.
 *
 * @return the number of rows
 * @throws IllegalStateException if the array is neither a matrix nor a row/column vector shape
 * @throws ArithmeticException   if the row count does not fit in an {@code int}
 */
@Override
public int rows() {
    // Math.toIntExact instead of an (int) cast: a silent truncation would return a wrong (possibly negative) count.
    if (isMatrix())
        return Math.toIntExact(size(0));
    else if (Shape.isRowVectorShape(shape())) {
        return 1;
    } else if (Shape.isColumnVectorShape(shape())) {
        return Math.toIntExact(length());
    }

    throw new IllegalStateException("Rank is " + rank() + " rows() call is not valid");
}
 
Example 5
Source File: BaseNDArray.java (from deeplearning4j, Apache License 2.0) — 5 votes
/**
 * Broadcasts {@code vector} against this array with the selected operation, in place.
 *
 * <p>Dimension 1 is used when {@code vector} has a row-vector shape, dimension 0 otherwise.
 *
 * @param vector    the broadcast operand
 * @param operation 'a' add, 's' subtract, 'm' multiply, 'd' divide,
 *                  'h' reverse subtract, 't' reverse divide, 'p' copy
 * @throws UnsupportedOperationException if {@code operation} is not one of the codes above
 */
private void applyBroadcastOp(INDArray vector, final char operation) {
    Nd4j.getCompressor().autoDecompress(this);
    int dim = 0;
    if (Shape.isRowVectorShape(vector.shape()))
        dim = 1;

    // FIXME: probably this is wrong, because strict equality is always false in current DataBuffer mechanics
    INDArray operand = vector;
    if (this.data() == operand.data())
        operand = operand.dup();

    if (operation == 'a') {
        Nd4j.getExecutioner().exec(new BroadcastAddOp(this, operand, this, dim));
        return;
    }
    if (operation == 's') {
        Nd4j.getExecutioner().exec(new BroadcastSubOp(this, operand, this, dim));
        return;
    }
    if (operation == 'm') {
        Nd4j.getExecutioner().exec(new BroadcastMulOp(this, operand, this, dim));
        return;
    }
    if (operation == 'd') {
        Nd4j.getExecutioner().exec(new BroadcastDivOp(this, operand, this, dim));
        return;
    }
    if (operation == 'h') {
        Nd4j.getExecutioner().exec(new BroadcastRSubOp(this, operand, this, dim));
        return;
    }
    if (operation == 't') {
        Nd4j.getExecutioner().exec(new BroadcastRDivOp(this, operand, this, dim));
        return;
    }
    if (operation == 'p') {
        Nd4j.getExecutioner().exec(new BroadcastCopyOp(this, operand, this, dim));
        return;
    }
    throw new UnsupportedOperationException("Unknown operation: " + operation);
}
 
Example 6
Source File: BaseNDArray.java (from deeplearning4j, Apache License 2.0) — 5 votes
/**
 * Returns the number of columns: {@code size(1)} for a matrix, 1 for a column-vector shape,
 * or the total length for a row-vector shape.
 *
 * @return the number of columns
 * @throws IllegalStateException if the array is neither a matrix nor a row/column vector shape
 * @throws ArithmeticException   if the column count does not fit in an {@code int}
 */
@Override
public int columns() {
    // Math.toIntExact instead of an (int) cast: a silent truncation would return a wrong (possibly negative) count.
    if (isMatrix())
        return Math.toIntExact(size(1));
    else if (Shape.isColumnVectorShape(shape())) {
        return 1;
    } else if (Shape.isRowVectorShape(shape())) {
        return Math.toIntExact(length());
    }
    throw new IllegalStateException("Rank is [" + rank() + "]; columns() call is not valid");
}
 
Example 7
Source File: BaseNDArray.java (from deeplearning4j, Apache License 2.0) — 5 votes
/**
 * Returns the number of rows: {@code size(0)} for a matrix, 1 for a row-vector shape,
 * or the total length for a column-vector shape.
 *
 * @return the number of rows
 * @throws IllegalStateException if the array is neither a matrix nor a row/column vector shape
 * @throws ArithmeticException   if the row count does not fit in an {@code int}
 */
@Override
public int rows() {
    // Math.toIntExact instead of an (int) cast: a silent truncation would return a wrong (possibly negative) count.
    if (isMatrix())
        return Math.toIntExact(size(0));
    else if (Shape.isRowVectorShape(shape())) {
        return 1;
    } else if (Shape.isColumnVectorShape(shape())) {
        return Math.toIntExact(length());
    }

    throw new IllegalStateException("Rank is " + rank() + " rows() call is not valid");
}
 
Example 8
Source File: BaseNDArray.java (from nd4j, Apache License 2.0) — 4 votes
/**
 * Returns the {@code index}-th tensor along the given dimension(s) (TAD).
 *
 * <p>Negative dimensions are interpreted relative to {@code rank()}, and duplicates are removed
 * when more than one dimension is given. The caller's {@code dimension} array is not modified.
 * The result is this array itself when {@code dimension.length >= rank()} or the single dimension
 * is {@code Integer.MAX_VALUE}. It is also this array (or its transpose) for the single-axis
 * row/column-vector shortcuts. Otherwise it is a new array created over this array's
 * {@code data()} buffer at the TAD's offset.
 *
 * @param index     which TAD to return; must be in {@code [0, tensorssAlongDimension(dimension))}
 * @param dimension the dimension(s) to take the tensor along
 * @return the requested tensor
 * @throws IllegalArgumentException if no dimensions are given or {@code index} is out of range
 */
@Override
public INDArray tensorAlongDimension(int index, int... dimension) {
    if (dimension == null || dimension.length == 0)
        throw new IllegalArgumentException("Invalid input: dimensions not specified (null or length 0)");

    if (dimension.length >= rank()  || dimension.length == 1 && dimension[0] == Integer.MAX_VALUE)
        return this;

    // Normalize on a copy: varargs may be the caller's own array, which must not be rewritten.
    dimension = dimension.clone();
    for (int i = 0; i < dimension.length; i++)
        if (dimension[i] < 0)
            dimension[i] += rank();

    // Dedup. TreeSet iterates in ascending order, so the result is also sorted.
    if (dimension.length > 1)
        dimension = Ints.toArray(new ArrayList<>(new TreeSet<>(Ints.asList(dimension))));

    long tads = tensorssAlongDimension(dimension);
    // Reject negative indices too; they would otherwise reach getLong(index) below.
    if (index < 0 || index >= tads)
        throw new IllegalArgumentException("Illegal index " + index + " out of tads " + tads);


    if (dimension.length == 1) {
        if (dimension[0] == 0 && isColumnVector()) {
            return this.transpose();
        } else if (dimension[0] == 1 && isRowVector()) {
            return this;
        }
    }

    Pair<DataBuffer, DataBuffer> tadInfo =
            Nd4j.getExecutioner().getTADManager().getTADOnlyShapeInfo(this, dimension);
    DataBuffer shapeInfo = tadInfo.getFirst();
    val shape = Shape.shape(shapeInfo);
    val stride = Shape.stride(shapeInfo).asLong();
    long offset = offset() + tadInfo.getSecond().getLong(index);
    INDArray toTad = Nd4j.create(data(), shape, stride, offset);
    BaseNDArray baseNDArray = (BaseNDArray) toTad;

    // Recompute the array order for the TAD view from its own shape and stride.
    char newOrder = Shape.getOrder(shape, stride, 1);

    // NOTE(review): assumes the shape-info layout stores the element-wise stride second-to-last.
    int ews = baseNDArray.shapeInfoDataBuffer().getInt(baseNDArray.shapeInfoDataBuffer().length() - 2);

    //TAD always calls permute. Permute EWS is always -1. This is not true
    // for row vector shapes though.
    if (!Shape.isRowVectorShape(baseNDArray.shapeInfoDataBuffer()))
        ews = -1;

    // we create new shapeInfo with possibly new ews & order
    /**
     * NOTE HERE THAT ZERO IS PRESET FOR THE OFFSET AND SHOULD STAY LIKE THAT.
     * Zero is preset for caching purposes.
     * We don't actually use the offset defined in the
     * shape info data buffer.
     * We calculate and cache the offsets separately.
     *
     */
    baseNDArray.setShapeInformation(
            Nd4j.getShapeInfoProvider().createShapeInformation(shape, stride, 0, ews, newOrder));

    return toTad;
}