Java Code Examples for org.nd4j.linalg.api.shape.Shape#getOrder()

The following examples show how to use org.nd4j.linalg.api.shape.Shape#getOrder(). You can vote up the examples you find useful or vote down the ones you don't. You can also follow the links above each example to the original project or source file. Related API usage is listed in the sidebar.
Example 1
Source File: BaseNDArray.java — from nd4j, Apache License 2.0 (5 votes)
/**
 * Builds the sub-array described by {@code resolution} over this array's data buffer.
 *
 * <p>If the resolution covers this whole array (same rank and shape) with all-zero offsets,
 * {@code this} is returned unchanged.
 *
 * @param resolution resolved shape, strides and offsets of the requested region
 * @return a new array sharing {@code data}, or {@code this} for the whole-array case
 * @throws IllegalArgumentException if the offset or stride count does not match the shape rank,
 *                                  or if a whole-array resolution carries non-zero offsets
 */
@Override
public INDArray subArray(ShapeOffsetResolution resolution) {
    Nd4j.getCompressor().autoDecompress(this);

    final long[] subOffsets = resolution.getOffsets();
    final int[] subShape = LongUtils.toInts(resolution.getShapes());
    final int[] subStride = LongUtils.toInts(resolution.getStrides());
    final long absoluteOffset = offset() + resolution.getOffset();
    final int subRank = subShape.length;

    // FIXME: shapeInfo should be used here
    if (subRank == 0) {
        return create(Nd4j.createBufferDetached(subShape));
    }

    // Every dimension needs exactly one offset and one stride.
    if (subOffsets.length != subRank) {
        throw new IllegalArgumentException("Invalid offset " + Arrays.toString(subOffsets));
    }
    if (subStride.length != subRank) {
        throw new IllegalArgumentException("Invalid stride " + Arrays.toString(subStride));
    }

    final boolean sameShapeAsThis = subRank == rank() && Shape.contentEquals(subShape, shapeOf());
    if (sameShapeAsThis) {
        // A full-size region can only start at the origin.
        if (!ArrayUtil.isZero(subOffsets)) {
            throw new IllegalArgumentException("Invalid subArray offsets");
        }
        return this;
    }

    final char subOrder = Shape.getOrder(subShape, subStride, 1);
    final int[] shapeCopy = Arrays.copyOf(subShape, subRank);
    return create(data, shapeCopy, subStride, absoluteOffset, subOrder);
}
 
Example 2
Source File: ShapeTest.java — from deeplearning4j, Apache License 2.0 (5 votes)
@Test
public void testShapeOrder(){
    // Strides grow from the first to the last dimension, so getOrder should report Fortran order.
    final long[] dims = new long[]{2, 2};
    final long[] ascendingStrides = new long[]{1, 8};

    assertEquals('f', Shape.getOrder(dims, ascendingStrides, 1));
}
 
Example 3
Source File: TADTests.java — from deeplearning4j, Apache License 2.0 (5 votes)
@Test
public void testTADEWSStride(){
    // A 10 x 1 x 60 array laid out in 'f' order; the test slices it along the last axis.
    final INDArray source = Nd4j.linspace(1, 600, 600).reshape('f', 10, 1, 60);
    final int sliceCount = 60;

    for (int slice = 0; slice < sliceCount; slice++) {
        // The TAD over dimensions (0, 1) is expected to equal get(all, all, point(slice)).
        final INDArray viaTad = source.tensorAlongDimension(slice, 0, 1);
        final INDArray viaGet = source.get(all(), all(), point(slice));
        final String msg = String.valueOf(slice);

        // Same contents, same buffer offset, same element-wise stride.
        assertEquals(msg, viaGet, viaTad);
        assertEquals(msg, viaGet.data().offset(), viaTad.data().offset());
        assertEquals(msg, viaGet.elementWiseStride(), viaTad.elementWiseStride());

        // Both views should be classified as Fortran order from their shape/stride alone.
        assertEquals('f', Shape.getOrder(viaTad.shape(), viaTad.stride(), 1));
        assertEquals('f', Shape.getOrder(viaGet.shape(), viaGet.stride(), 1));

        // ...and the EWS recomputed from shape/stride must be 1 for both.
        final boolean tadIsF = viaTad.ordering() == 'f';
        final boolean getIsF = viaGet.ordering() == 'f';
        assertEquals(1, Shape.elementWiseStride(viaTad.shape(), viaTad.stride(), tadIsF));
        assertEquals(1, Shape.elementWiseStride(viaGet.shape(), viaGet.stride(), getIsF));
    }
}
 
Example 4
Source File: BaseNDArray.java — from nd4j, Apache License 2.0 (4 votes)
/**
 * Returns the {@code index}-th tensor along the given dimensions (TAD), created over this
 * array's data buffer.
 *
 * <p>Shortcuts:
 * <ul>
 *   <li>a single {@code Integer.MAX_VALUE} dimension returns {@code this};</li>
 *   <li>when the distinct (normalised) dimensions number at least {@code rank()}, {@code this}
 *       is returned;</li>
 *   <li>dimension 0 on a column vector returns its transpose; dimension 1 on a row vector returns
 *       {@code this}.</li>
 * </ul>
 *
 * @param index     which TAD to return; must be below {@code tensorssAlongDimension(dimension)}
 * @param dimension dimensions spanned by each TAD; negative values are offset by {@code rank()}.
 *                  The caller's array is not modified.
 * @return the requested TAD (or one of the shortcut results above)
 * @throws IllegalArgumentException if no dimensions are given or {@code index} is out of range
 */
@Override
public INDArray tensorAlongDimension(int index, int... dimension) {
    if (dimension == null || dimension.length == 0)
        throw new IllegalArgumentException("Invalid input: dimensions not specified (null or length 0)");

    if (dimension.length == 1 && dimension[0] == Integer.MAX_VALUE)
        return this;

    // Normalise negative indices on a private copy: the varargs array may be owned by the
    // caller, who must not see its contents rewritten.
    int rank = rank();
    dimension = dimension.clone();
    for (int i = 0; i < dimension.length; i++)
        if (dimension[i] < 0)
            dimension[i] += rank;

    // Dedup and sort in one step (TreeSet iterates in ascending order).
    if (dimension.length > 1)
        dimension = Ints.toArray(new TreeSet<>(Ints.asList(dimension)));

    // Compare against the rank only after dedup, so duplicated axes such as (1, 1) on a
    // rank-2 array are not mistaken for "all dimensions".
    if (dimension.length >= rank)
        return this;

    long tads = tensorssAlongDimension(dimension);
    if (index >= tads)
        throw new IllegalArgumentException("Illegal index " + index + " out of tads " + tads);


    if (dimension.length == 1) {
        if (dimension[0] == 0 && isColumnVector()) {
            return this.transpose();
        } else if (dimension[0] == 1 && isRowVector()) {
            return this;
        }
    }

    Pair<DataBuffer, DataBuffer> tadInfo =
            Nd4j.getExecutioner().getTADManager().getTADOnlyShapeInfo(this, dimension);
    DataBuffer shapeInfo = tadInfo.getFirst();
    val shape = Shape.shape(shapeInfo);
    val stride = Shape.stride(shapeInfo).asLong();
    // tadInfo's second buffer holds per-TAD offsets, relative to this array's own offset.
    long offset = offset() + tadInfo.getSecond().getLong(index);
    INDArray toTad = Nd4j.create(data(), shape, stride, offset);
    BaseNDArray baseNDArray = (BaseNDArray) toTad;

    //preserve immutability
    char newOrder = Shape.getOrder(shape, stride, 1);

    // Second-to-last shape-info entry is the element-wise stride.
    int ews = baseNDArray.shapeInfoDataBuffer().getInt(baseNDArray.shapeInfoDataBuffer().length() - 2);

    //TAD always calls permute. Permute EWS is always -1. This is not true
    // for row vector shapes though.
    if (!Shape.isRowVectorShape(baseNDArray.shapeInfoDataBuffer()))
        ews = -1;

    // we create new shapeInfo with possibly new ews & order
    /**
     * NOTE HERE THAT ZERO IS PRESET FOR THE OFFSET AND SHOULD STAY LIKE THAT.
     * Zero is preset for caching purposes.
     * We don't actually use the offset defined in the
     * shape info data buffer.
     * We calculate and cache the offsets separately.
     *
     */
    baseNDArray.setShapeInformation(
            Nd4j.getShapeInfoProvider().createShapeInformation(shape, stride, 0, ews, newOrder));

    return toTad;
}
 
Example 5
Source File: BaseNDArray.java — from nd4j, Apache License 2.0 (4 votes)
/**
 * An <b>in-place</b> version of permute. This array's shape information (shape, strides, order)
 * is replaced by this operation; the data buffer itself is not touched.
 * See: http://www.mathworks.com/help/matlab/ref/permute.html
 *
 * <p>An identity permutation (one entry per dimension, {@code rearrange[i] == i}) returns
 * immediately without changing anything.
 *
 * @param rearrange the dimensions to swap to; expected to hold one entry per dimension
 *                  (validated by {@code checkArrangeArray})
 * @return the current array
 */
@Override
public INDArray permutei(int... rearrange) {
    val shapeInfo = shapeInfo();
    int rank = Shape.rank(javaShapeInformation);

    // Identity check. Only index into rearrange when its length matches the rank, so a
    // too-short array is handed to checkArrangeArray for validation instead of failing here
    // with an ArrayIndexOutOfBoundsException.
    boolean alreadyInOrder = rearrange.length == rank;
    for (int i = 0; alreadyInOrder && i < rank; i++) {
        alreadyInOrder = rearrange[i] == i;
    }

    if (alreadyInOrder)
        return this;

    checkArrangeArray(rearrange);
    val newShape = doPermuteSwap(Shape.shapeOf(shapeInfo), rearrange);
    val newStride = doPermuteSwap(Shape.stride(shapeInfo), rearrange);
    char newOrder = Shape.getOrder(newShape, newStride, elementStride());

    // Shape info buffer: [rank, [shape], [stride], offset, elementwiseStride, order],
    // so index 2 * rank + 2 is the element-wise stride of the array before permutation.
    val ews = shapeInfo.get(2 * rank + 2);

    val si = Nd4j.getShapeInfoProvider().createShapeInformation(newShape, newStride, 0, ews, newOrder);
    setShapeInformation(si);

    // NOTE(review): shapeInfo is the reference captured at the top of this method. Assuming
    // setShapeInformation installs a new buffer rather than writing into that one, this re-reads
    // the pre-permutation ews. When it is positive, the shape info is rebuilt with ews = -1
    // and this.offset() (not 0) as the offset.
    if (shapeInfo.get(2 * rank + 2) > 0) {
        //for the backend to work - no ews for permutei
        //^^ not true anymore? Not sure here. Marking this for raver
        setShapeInformation(Nd4j.getShapeInfoProvider().createShapeInformation(newShape, newStride,
                this.offset(), -1, newOrder));
    }

    return this;
}