Java Code Examples for org.nd4j.linalg.api.ndarray.INDArray#stride()

The following examples show how to use org.nd4j.linalg.api.ndarray.INDArray#stride(). Each example is preceded by its source file, the project it comes from (deeplearning4j or nd4j), the project's license, and the number of votes it received.
Example 1
Source File: Shape.java (from deeplearning4j, Apache License 2.0) — 6 votes
/**
 * Returns {@code input} in a memory layout that matrix-multiply routines can consume directly,
 * copying only when the current layout would not be accepted as-is.
 * Used by the DL4J LSTM implementation.
 *
 * <p>A 'c' ordered matrix is accepted when its strides are {@code [size(1), 1]};
 * an 'f' ordered matrix is accepted when its strides are {@code [1, size(0)]}.
 * Any other 'c'/'f' layout is copied via {@code Shape.toOffsetZeroCopyAnyOrder}.
 * Arrays with any other ordering value are returned unchanged.
 *
 * @param input the rank 2 array to check
 * @return {@code input} itself when already compatible, otherwise a copy
 * @throws IllegalArgumentException if {@code input} is not rank 2
 */
public static INDArray toMmulCompatible(INDArray input) {
    if (input.rank() != 2) {
        throw new IllegalArgumentException("Input must be rank 2 (matrix)");
    }
    // Mirrors the checks in GemmParams' copy-if-necessary logic.
    final char order = input.ordering();
    boolean needsCopy;
    switch (order) {
        case 'c':
            needsCopy = input.stride(0) != input.size(1) || input.stride(1) != 1;
            break;
        case 'f':
            needsCopy = input.stride(0) != 1 || input.stride(1) != input.size(0);
            break;
        default:
            needsCopy = false;
            break;
    }
    return needsCopy ? Shape.toOffsetZeroCopyAnyOrder(input) : input;
}
 
Example 2
Source File: Shape.java (from deeplearning4j, Apache License 2.0) — 6 votes
/**
 * Checks whether the strides of {@code array} are ordered in a way suitable for
 * non-strided operations such as mmul.
 * <ul>
 * <li>'c' order: strides must be strictly descending, e.g. [100,10,1]</li>
 * <li>'f' order: strides must be strictly ascending, e.g. [1,10,100]</li>
 * <li>'a' order: always accepted</li>
 * </ul>
 * Arrays of rank 0 or 1, and vectors whose first two strides are both 1, are always accepted.
 *
 * @param array the array to inspect
 * @return true if c+descending, f+ascending (or one of the accepted special cases), false otherwise
 * @throws RuntimeException if the ordering is not 'c', 'f' or 'a'
 */
public static boolean strideDescendingCAscendingF(INDArray array) {
    if (array.rank() <= 1) {
        return true;
    }
    long[] s = array.stride();
    if (array.isVector() && s[0] == 1 && s[1] == 1) {
        return true;
    }
    char order = array.ordering();
    switch (order) {
        case 'c':
            // Every stride must be strictly larger than the one after it.
            for (int k = 0; k + 1 < s.length; k++) {
                if (s[k] <= s[k + 1]) {
                    return false;
                }
            }
            return true;
        case 'f':
            // Every stride must be strictly smaller than the one after it.
            for (int k = 0; k + 1 < s.length; k++) {
                if (s[k] >= s[k + 1]) {
                    return false;
                }
            }
            return true;
        case 'a':
            return true;
        default:
            throw new RuntimeException("Invalid order: not c or f (is: " + order + ")");
    }
}
 
Example 3
Source File: Shape.java (from nd4j, Apache License 2.0) — 6 votes
/**
 * Checks whether the strides of {@code array} are in an order suitable for non-strided mmul etc.
 * <ul>
 * <li>Returns true if c order and strides are strictly descending [100,10,1] etc</li>
 * <li>Returns true if f order and strides are strictly ascending [1,10,100] etc</li>
 * <li>Returns true for 'a' order, for rank 0/1 arrays, and for vectors whose first two strides are both 1</li>
 * </ul>
 * False otherwise.
 *
 * @param array the array to inspect
 * @return true if c+descending, f+ascending (or one of the special cases above), false otherwise
 * @throws RuntimeException if the ordering is not 'c', 'f' or 'a'
 */
public static boolean strideDescendingCAscendingF(INDArray array) {
    // Rank 0/1 arrays have fewer than two strides, so there is no ordering to violate.
    // Without this guard the vector check below would read strides[1] out of bounds
    // for a rank 1 array.
    if (array.rank() <= 1)
        return true;
    long[] strides = array.stride();
    if (array.isVector() && strides[0] == 1 && strides[1] == 1)
        return true;
    char order = array.ordering();

    if (order == 'c') { //Expect descending. [100,10,1] etc
        for (int i = 1; i < strides.length; i++)
            if (strides[i - 1] <= strides[i])
                return false;
        return true;
    } else if (order == 'f') {//Expect ascending. [1,10,100] etc
        for (int i = 1; i < strides.length; i++)
            if (strides[i - 1] >= strides[i])
                return false;
        return true;
    } else if (order == 'a') {
        return true;
    } else {
        throw new RuntimeException("Invalid order: not c or f (is: " + order + ")");
    }
}
 
Example 4
Source File: Shape.java (from nd4j, Apache License 2.0) — 6 votes
/**
 * Ensures a matrix has a memory layout that mmul can use without a strided copy.
 * Used in the DL4J LSTM implementation.
 *
 * <p>Uses the same conditions as GemmParams' copy-if-necessary check:
 * 'c' order needs strides {@code [size(1), 1]}, 'f' order needs strides {@code [1, size(0)]}.
 *
 * @param input the rank 2 array to check
 * @return {@code input} unchanged if it is already compatible, otherwise a copy made by
 *         {@code Shape.toOffsetZeroCopyAnyOrder}
 * @throws IllegalArgumentException if {@code input} is not rank 2
 */
public static INDArray toMmulCompatible(INDArray input) {
    if (input.rank() != 2)
        throw new IllegalArgumentException("Input must be rank 2 (matrix)");

    char order = input.ordering();
    if (order == 'c' && (input.stride(0) != input.size(1) || input.stride(1) != 1)) {
        // 'c' matrix whose rows are not densely packed
        return Shape.toOffsetZeroCopyAnyOrder(input);
    }
    if (order == 'f' && (input.stride(0) != 1 || input.stride(1) != input.size(0))) {
        // 'f' matrix whose columns are not densely packed
        return Shape.toOffsetZeroCopyAnyOrder(input);
    }
    return input;
}
 
Example 5
Source File: NumpyArray.java (from deeplearning4j, Apache License 2.0) — 6 votes
/**
 * Describes an existing nd4j array's host buffer in numpy terms: raw address, shape,
 * byte strides and data type. The array is also stored in {@code arrayCache}.
 *
 * @param nd4jArray the array to describe
 */
public NumpyArray(INDArray nd4jArray) {
    // The raw host address is read below, so make sure the data is located on the host first.
    Nd4j.getAffinityManager().ensureLocation(nd4jArray, AffinityManager.Location.HOST);
    DataBuffer buffer = nd4jArray.data();
    address = buffer.pointer().address();
    shape = nd4jArray.shape();

    // nd4j strides are in elements; they are scaled by the element size to get byte strides.
    long[] elementStrides = nd4jArray.stride();
    strides = new long[elementStrides.length];
    int bytesPerElement = buffer.getElementSize();
    int dim = 0;
    for (long elementStride : elementStrides) {
        strides[dim++] = elementStride * bytesPerElement;
    }

    dtype = nd4jArray.dataType();
    this.nd4jArray = nd4jArray;

    // Cache entry keyed by buffer address, length, dtype and byte strides.
    String cacheKey = address + "_" + nd4jArray.length() + "_" + dtype + "_" + ArrayUtils.toString(strides);
    arrayCache.put(cacheKey, nd4jArray);
}
 
Example 6
Source File: CudaAffinityManager.java (from nd4j, Apache License 2.0) — 5 votes
/**
 * This method replicates given INDArray, and places it to target device.
 *
 * <p>The calling thread is temporarily switched to {@code deviceId} while the data buffer is
 * replicated and a new shape buffer is built. The thread's previous device is restored
 * afterwards, including when replication throws.
 *
 * @param deviceId target deviceId
 * @param array    INDArray to replicate; views are rejected
 * @return a new INDArray backed by the replicated buffer, or {@code null} if {@code array} is null
 * @throws UnsupportedOperationException if {@code array} is a view
 */
@Override
public synchronized INDArray replicateToDevice(Integer deviceId, INDArray array) {
    if (array == null)
        return null;

    if (array.isView())
        throw new UnsupportedOperationException("It's impossible to replicate View");

    val shape = array.shape();
    val stride = array.stride();
    val elementWiseStride = array.elementWiseStride();
    val ordering = array.ordering();

    // we use this call to get device memory updated
    AtomicAllocator.getInstance().getPointer(array,
                    (CudaContext) AtomicAllocator.getInstance().getDeviceContext().getContext());

    int currentDeviceId = getDeviceForCurrentThread();

    try {
        NativeOpsHolder.getInstance().getDeviceNativeOps().setDevice(new CudaPointer(deviceId));
        attachThreadToDevice(Thread.currentThread().getId(), deviceId);

        DataBuffer newDataBuffer = replicateToDevice(deviceId, array.data());
        DataBuffer newShapeBuffer = Nd4j.getShapeInfoProvider().createShapeInformation(shape, stride, 0,
                        elementWiseStride, ordering).getFirst();
        return Nd4j.createArrayFromShapeBuffer(newDataBuffer, newShapeBuffer);
    } finally {
        // Always restore the caller's device binding; previously a failure during replication
        // left this thread attached to the target device.
        attachThreadToDevice(Thread.currentThread().getId(), currentDeviceId);
        NativeOpsHolder.getInstance().getDeviceNativeOps().setDevice(new CudaPointer(currentDeviceId));
    }
}
 
Example 7
Source File: OpExecutionerUtil.java (from nd4j, Apache License 2.0) — 5 votes
/**
 * Can we do the transform op (X = Op(X,Y)) directly on the arrays without breaking them up into 1d tensors first?
 * <p>
 * The answer is yes when {@code x} is a vector, or when {@code x} and {@code y}:
 * <ul>
 * <li>share the same ordering,</li>
 * <li>both have an element-wise stride of at least 1,</li>
 * <li>have identical strides, and</li>
 * <li>either both span their entire data buffers, or have exactly the strides of a freshly
 *     created array with {@code x}'s shape and ordering.</li>
 * </ul>
 *
 * @param x the array being transformed (X in X = Op(X,Y))
 * @param y the second operand
 * @return true if the op can run directly over the underlying buffers
 */
public static boolean canDoOpDirectly(INDArray x, INDArray y) {
    if (x.isVector()) {
        return true;
    }
    if (x.ordering() != y.ordering()) {
        // Apart from vectors, elements in 'f' vs. 'c' arrays never line up.
        return false;
    }
    if (x.elementWiseStride() < 1 || y.elementWiseStride() < 1) {
        return false;
    }

    boolean xSpansBuffer = x.lengthLong() == x.data().length();
    boolean ySpansBuffer = y.lengthLong() == y.data().length();
    long[] xStrides = x.stride();
    // Without identical strides the elements in the two buffers cannot line up.
    if (!Arrays.equals(xStrides, y.stride())) {
        return false;
    }

    // Full buffers + matching strides -> all elements are contiguous (and match).
    if (xSpansBuffer && ySpansBuffer) {
        return true;
    }

    // Otherwise the strides must be those of a zero-offset array of the same shape and order.
    long[] xShape = x.shape();
    long[] defaultStrides = x.ordering() == 'c'
                    ? ArrayUtil.calcStrides(xShape)
                    : ArrayUtil.calcStridesFortran(xShape);
    return Arrays.equals(xStrides, defaultStrides);
}
 
Example 8
Source File: OpExecutionerUtil.java (from nd4j, Apache License 2.0) — 5 votes
/**
 * Can we do the transform op (Z = Op(X,Y)) directly on the arrays without breaking them up into 1d tensors first?
 * <p>
 * Returns true when {@code x} is a vector, or when {@code x}, {@code y} and {@code z} share the
 * same ordering, all have an element-wise stride of at least 1, have identical strides, and either
 * all span their full data buffers or have exactly the strides of a freshly created array with
 * {@code x}'s shape and ordering.
 *
 * @param x first operand
 * @param y second operand
 * @param z result array
 * @return true if the op can run directly over the underlying buffers
 */
public static boolean canDoOpDirectly(INDArray x, INDArray y, INDArray z) {
    if (x.isVector())
        return true;
    if (x.ordering() != y.ordering() || x.ordering() != z.ordering())
        return false; //other than vectors, elements in f vs. c NDArrays will never line up
    // z is written directly too, so it is held to the same element-wise stride requirement
    // as x and y (previously only x and y were checked).
    if (x.elementWiseStride() < 1 || y.elementWiseStride() < 1 || z.elementWiseStride() < 1)
        return false;
    //Full buffer + matching strides -> implies all elements are contiguous (and match)
    long l1 = x.lengthLong();
    long dl1 = x.data().length();
    long l2 = y.lengthLong();
    long dl2 = y.data().length();
    long l3 = z.lengthLong();
    long dl3 = z.data().length();
    long[] strides1 = x.stride();
    long[] strides2 = y.stride();
    long[] strides3 = z.stride();
    boolean equalStrides = Arrays.equals(strides1, strides2) && Arrays.equals(strides1, strides3);
    if (l1 == dl1 && l2 == dl2 && l3 == dl3 && equalStrides)
        return true;

    //Strides match + are same as a zero offset NDArray -> all elements are contiguous (and match)
    if (equalStrides) {
        long[] shape1 = x.shape();
        long[] stridesAsInit = (x.ordering() == 'c' ? ArrayUtil.calcStrides(shape1)
                        : ArrayUtil.calcStridesFortran(shape1));
        return Arrays.equals(strides1, stridesAsInit);
    }

    return false;
}
 
Example 9
Source File: GemvParameters.java (from deeplearning4j, Apache License 2.0) — 5 votes
/**
 * Returns {@code arr} itself if its values are laid out so that gemv can read them directly,
 * otherwise a {@code dup()} copy.
 * See also: Shape.toMmulCompatible - want same conditions here and there.
 *
 * @param arr the matrix to check
 * @return {@code arr}, or a copy of it when its layout is not usable as-is
 */
private INDArray copyIfNecessary(INDArray arr) {
    // Contiguous for 'c' if: stride[0] == shape[1] and stride[1] == 1
    // Contiguous for 'f' if: stride[0] == 1 and stride[1] == shape[0]
    // The element-wise stride must also be at least 1.
    char order = arr.ordering();
    boolean mustCopy = (order == 'c' && (arr.stride(0) != arr.size(1) || arr.stride(1) != 1))
                    || (order == 'f' && (arr.stride(0) != 1 || arr.stride(1) != arr.size(0)))
                    || arr.elementWiseStride() < 1;
    return mustCopy ? arr.dup() : arr;
}
 
Example 10
Source File: LogEntry.java (from nd4j, Apache License 2.0) — 5 votes
public LogEntry(INDArray toLog, StackTraceElement[] stackTraceElements, String status) {
    //this.id = toLog.id();
    this.shape = toLog.shape();
    this.stride = toLog.stride();
    this.ndArrayType = toLog.getClass().getName();
    this.length = toLog.length();
    this.references = toLog.data().references();
    this.dataType = toLog.data().dataType() == DataBuffer.Type.DOUBLE ? "double" : "float";
    this.timestamp = System.currentTimeMillis();
    this.stackTraceElements = stackTraceElements;
    this.status = status;
}
 
Example 11
Source File: GemvParameters.java (from nd4j, Apache License 2.0) — 5 votes
/**
 * Duplicates {@code arr} unless its values are contiguous in memory in its declared ordering
 * and it has an element-wise stride of at least 1.
 * See also: Shape.toMmulCompatible - want same conditions here and there.
 *
 * @param arr the matrix to check
 * @return {@code arr} when usable as-is, otherwise {@code arr.dup()}
 */
private INDArray copyIfNecessary(INDArray arr) {
    boolean contiguous;
    if (arr.ordering() == 'c') {
        // 'c': stride[0] == shape[1] and stride[1] == 1
        contiguous = arr.stride(0) == arr.size(1) && arr.stride(1) == 1;
    } else if (arr.ordering() == 'f') {
        // 'f': stride[0] == 1 and stride[1] == shape[0]
        contiguous = arr.stride(0) == 1 && arr.stride(1) == arr.size(0);
    } else {
        contiguous = true;
    }

    if (!contiguous || arr.elementWiseStride() < 1)
        return arr.dup();
    return arr;
}
 
Example 12
Source File: GemmParams.java (from nd4j, Apache License 2.0) — 5 votes
/**
 * Returns {@code arr} unchanged if gemm can read its values directly, otherwise a
 * {@code dup()} copy.
 * See also: Shape.toMmulCompatible - want same conditions here and there.
 *
 * @param arr the matrix to check
 * @return {@code arr}, or a copy of it when its memory layout is not usable as-is
 */
private INDArray copyIfNeccessary(INDArray arr) {
    //Check if matrix values are contiguous in memory. If not: dup
    //Contiguous for c if: stride[0] == shape[1] and stride[1] == 1
    //Contiguous for f if: stride[0] == 1 and stride[1] == shape[0]
    // The 'c' check is skipped when Nd4j.allowsSpecifyOrdering() is true.
    if (!Nd4j.allowsSpecifyOrdering() && arr.ordering() == 'c'
            && (arr.stride(0) != arr.size(1) || arr.stride(1) != 1))
        return arr.dup();
    else if (arr.ordering() == 'f' && (arr.stride(0) != 1 || arr.stride(1) != arr.size(0)))
        return arr.dup();
    // Treat any element-wise stride below 1 as unusable, matching GemvParameters.copyIfNecessary;
    // the previous "< 0" test let an element-wise stride of 0 through without a copy.
    else if (arr.elementWiseStride() < 1)
        return arr.dup();
    return arr;
}
 
Example 13
Source File: ArrowSerde.java (from nd4j, Apache License 2.0) — 5 votes
/**
 * Get the strides of this {@link INDArray} multiplied by the element size,
 * i.e. strides measured in bytes rather than in elements.
 * This is the {@link Tensor} and numpy format.
 *
 * @param arr the array to convert
 * @return one byte stride per dimension of {@code arr}
 */
public static long[] getArrowStrides(INDArray arr) {
    int rank = arr.rank();
    long[] ret = new long[rank];
    if (rank == 0) {
        return ret;
    }

    // Held as long so the multiplication below is done in 64-bit arithmetic: an int stride
    // times an int element size would otherwise overflow before being widened.
    long elementSize = arr.data().getElementSize();
    for (int i = 0; i < rank; i++) {
        ret[i] = arr.stride(i) * elementSize;
    }

    return ret;
}
 
Example 14
Source File: ArrayDescriptor.java (from deeplearning4j, Apache License 2.0) — 4 votes
/**
 * Creates a descriptor from an existing INDArray, capturing its buffer address, shape,
 * strides, data type and ordering.
 *
 * @param array the array to describe; must not be empty
 * @throws UnsupportedOperationException if {@code array} is empty
 */
public ArrayDescriptor(INDArray array) throws Exception{
    // The emptiness check runs before array.data() is touched. Previously it ran only after the
    // delegated constructor had already read the buffer address.
    this(checkNotEmpty(array).data().address(), array.shape(), array.stride(), array.data().dataType(), array.ordering());
}

/**
 * Returns {@code array} unchanged, or throws if it is empty. Static so it can be used inside
 * the arguments of an explicit {@code this(...)} call.
 */
private static INDArray checkNotEmpty(INDArray array) {
    if (array.isEmpty()) {
        throw new UnsupportedOperationException("Empty arrays are not supported");
    }
    return array;
}
 
Example 15
Source File: StaticShapeTests.java (from deeplearning4j, Apache License 2.0) — 4 votes
/**
 * Checks that the static Shape helpers give the same answers whether they read from
 * {@code arr.shapeInfo()} or from {@code arr.shapeInfoDataBuffer()}, and that those answers
 * match the array's own shape() and stride(). Covers rank, shape, per-dimension size and
 * stride, base offset and element offsets. Runs over the 2d-6d test arrays produced by
 * NDArrayCreationUtil.
 */
@Test
public void testBufferToIntShapeStrideMethods() {
    //Specifically: Shape.shape(IntBuffer), Shape.shape(DataBuffer)
    //.isRowVectorShape(DataBuffer), .isRowVectorShape(IntBuffer)
    //Shape.size(DataBuffer,int), Shape.size(IntBuffer,int)
    //Also: Shape.stride(IntBuffer), Shape.stride(DataBuffer)
    //Shape.stride(DataBuffer,int), Shape.stride(IntBuffer,int)
    //NOTE(review): isRowVectorShape is listed above but not exercised below.

    // lists.get(i) holds test arrays whose expected shape is shapes[i]; the two must stay in sync.
    List<List<Pair<INDArray, String>>> lists = new ArrayList<>();
    lists.add(NDArrayCreationUtil.getAllTestMatricesWithShape(3, 4, 12345, DataType.DOUBLE));
    lists.add(NDArrayCreationUtil.getAllTestMatricesWithShape(1, 4, 12345, DataType.DOUBLE));
    lists.add(NDArrayCreationUtil.getAllTestMatricesWithShape(3, 1, 12345, DataType.DOUBLE));
    lists.add(NDArrayCreationUtil.getAll3dTestArraysWithShape(12345, new long[]{3, 4, 5}, DataType.DOUBLE));
    lists.add(NDArrayCreationUtil.getAll4dTestArraysWithShape(12345, new int[]{3, 4, 5, 6}, DataType.DOUBLE));
    lists.add(NDArrayCreationUtil.getAll4dTestArraysWithShape(12345, new int[]{3, 1, 5, 1}, DataType.DOUBLE));
    lists.add(NDArrayCreationUtil.getAll5dTestArraysWithShape(12345, new int[]{3, 4, 5, 6, 7}, DataType.DOUBLE));
    lists.add(NDArrayCreationUtil.getAll6dTestArraysWithShape(12345, new int[]{3, 4, 5, 6, 7, 8}, DataType.DOUBLE));

    val shapes = new long[][] {{3, 4}, {1, 4}, {3, 1}, {3, 4, 5}, {3, 4, 5, 6}, {3, 1, 5, 1}, {3, 4, 5, 6, 7}, {3, 4, 5, 6, 7, 8}};

    for (int i = 0; i < shapes.length; i++) {
        List<Pair<INDArray, String>> list = lists.get(i);
        val shape = shapes[i];

        for (Pair<INDArray, String> p : list) {
            INDArray arr = p.getFirst();

            assertArrayEquals(shape, arr.shape());

            // Reference strides come from the array itself; both shape-info forms must agree with them.
            val thisStride = arr.stride();

            val ib = arr.shapeInfo();
            DataBuffer db = arr.shapeInfoDataBuffer();

            //Check shape calculation
            assertEquals(shape.length, Shape.rank(ib));
            assertEquals(shape.length, Shape.rank(db));

            assertArrayEquals(shape, Shape.shape(ib));
            assertArrayEquals(shape, Shape.shape(db));

            for (int j = 0; j < shape.length; j++) {
                assertEquals(shape[j], Shape.size(ib, j));
                assertEquals(shape[j], Shape.size(db, j));

                assertEquals(thisStride[j], Shape.stride(ib, j));
                assertEquals(thisStride[j], Shape.stride(db, j));
            }

            //Check base offset
            assertEquals(Shape.offset(ib), Shape.offset(db));

            //Check offset calculation:
            // For every index in the shape, the index-array getOffset result is the reference;
            // the fixed-arity overloads (2d-4d) must return the same value.
            NdIndexIterator iter = new NdIndexIterator(shape);
            while (iter.hasNext()) {
                val next = iter.next();
                long offset1 = Shape.getOffset(ib, next);

                assertEquals(offset1, Shape.getOffset(db, next));

                switch (shape.length) {
                    case 2:
                        assertEquals(offset1, Shape.getOffset(ib, next[0], next[1]));
                        assertEquals(offset1, Shape.getOffset(db, next[0], next[1]));
                        break;
                    case 3:
                        assertEquals(offset1, Shape.getOffset(ib, next[0], next[1], next[2]));
                        assertEquals(offset1, Shape.getOffset(db, next[0], next[1], next[2]));
                        break;
                    case 4:
                        assertEquals(offset1, Shape.getOffset(ib, next[0], next[1], next[2], next[3]));
                        assertEquals(offset1, Shape.getOffset(db, next[0], next[1], next[2], next[3]));
                        break;
                    case 5:
                    case 6:
                        //No 5 and 6d getOffset overloads
                        break;
                    default:
                        // Unreachable with the shapes above; guards against adding a rank without handling it here.
                        throw new RuntimeException();
                }
            }
        }
    }
}
 
Example 16
Source File: GemvParameters.java (from nd4j, Apache License 2.0) — 4 votes
/**
 * Derives the gemv dimension (m, n), leading-dimension (lda), increment (incx, incy) and
 * ordering (aOrdering) arguments from {@code a}, {@code x} and {@code y}.
 * {@code a} and {@code x} are first passed through copyIfNecessary / copyIfNecessaryVector,
 * and the possibly-copied arrays are the ones stored. {@code y} is stored as given.
 *
 * @param a the matrix argument
 * @param x the vector argument
 * @param y the result vector argument
 * @throws ND4JArraySizeException if a row or column count of {@code a} or {@code x} exceeds Integer.MAX_VALUE
 */
public GemvParameters(INDArray a, INDArray x, INDArray y) {
    a = copyIfNecessary(a);
    x = copyIfNecessaryVector(x);
    this.a = a;
    this.x = x;
    this.y = y;

    // The dimensions are narrowed to int below, so reject anything that would not fit.
    if (a.columns() > Integer.MAX_VALUE || a.rows() > Integer.MAX_VALUE)
        throw new ND4JArraySizeException();

    if (x.columns() > Integer.MAX_VALUE || x.rows() > Integer.MAX_VALUE)
        throw new ND4JArraySizeException();


    if (a.ordering() == 'f' && a.isMatrix()) {
        this.m = (int) a.rows();
        this.n = (int) a.columns();
        this.lda = (int) a.rows();
    } else if (a.ordering() == 'c' && a.isMatrix()) {
        // A 'c'-ordered matrix is described with rows/columns swapped and aOrdering = 'T'
        // (i.e. passed as a transpose).
        this.m = (int) a.columns();
        this.n = (int) a.rows();
        this.lda = (int) a.columns();
        aOrdering = 'T';
    }

    else {
        this.m = (int) a.rows();
        this.n = (int) a.columns();
        this.lda = (int) a.size(0);
    }


    // incx: 1 for rank 1, stride(0) for a column vector, otherwise the stride along dimension 1.
    if (x.rank() == 1) {
        incx = 1;
    } else if (x.isColumnVector()) {
        incx = x.stride(0);
    } else {
        incx = x.stride(1);
    }

    // y is not copied; its element-wise stride is used directly.
    // NOTE(review): if y has no usable element-wise stride (< 1) it is passed through unchanged — confirm callers guarantee this cannot happen.
    this.incy = y.elementWiseStride();

    // Increments are halved for complex arrays (presumably because each complex element occupies two slots — confirm).
    if (x instanceof IComplexNDArray)
        this.incx /= 2;
    if (y instanceof IComplexNDArray)
        this.incy /= 2;

}
 
Example 17
Source File: StaticShapeTests.java (from nd4j, Apache License 2.0) — 4 votes
/**
 * Checks that the static Shape helpers give the same answers whether they read from
 * {@code arr.shapeInfo()} or from {@code arr.shapeInfoDataBuffer()}, and that those answers
 * match the array's own shape() and stride(). Covers rank, shape, per-dimension size and
 * stride, base offset and element offsets. Runs over the 2d-6d test arrays produced by
 * NDArrayCreationUtil.
 */
@Test
public void testBufferToIntShapeStrideMethods() {
    //Specifically: Shape.shape(IntBuffer), Shape.shape(DataBuffer)
    //.isRowVectorShape(DataBuffer), .isRowVectorShape(IntBuffer)
    //Shape.size(DataBuffer,int), Shape.size(IntBuffer,int)
    //Also: Shape.stride(IntBuffer), Shape.stride(DataBuffer)
    //Shape.stride(DataBuffer,int), Shape.stride(IntBuffer,int)
    //NOTE(review): isRowVectorShape is listed above but not exercised below.

    // lists.get(i) holds test arrays whose expected shape is shapes[i]; the two must stay in sync.
    List<List<Pair<INDArray, String>>> lists = new ArrayList<>();
    lists.add(NDArrayCreationUtil.getAllTestMatricesWithShape(3, 4, 12345));
    lists.add(NDArrayCreationUtil.getAllTestMatricesWithShape(1, 4, 12345));
    lists.add(NDArrayCreationUtil.getAllTestMatricesWithShape(3, 1, 12345));
    lists.add(NDArrayCreationUtil.getAll3dTestArraysWithShape(12345, 3, 4, 5));
    lists.add(NDArrayCreationUtil.getAll4dTestArraysWithShape(12345, 3, 4, 5, 6));
    lists.add(NDArrayCreationUtil.getAll4dTestArraysWithShape(12345, 3, 1, 5, 1));
    lists.add(NDArrayCreationUtil.getAll5dTestArraysWithShape(12345, 3, 4, 5, 6, 7));
    lists.add(NDArrayCreationUtil.getAll6dTestArraysWithShape(12345, 3, 4, 5, 6, 7, 8));

    val shapes = new long[][] {{3, 4}, {1, 4}, {3, 1}, {3, 4, 5}, {3, 4, 5, 6}, {3, 1, 5, 1}, {3, 4, 5, 6, 7},
                    {3, 4, 5, 6, 7, 8}};

    for (int i = 0; i < shapes.length; i++) {
        List<Pair<INDArray, String>> list = lists.get(i);
        val shape = shapes[i];

        for (Pair<INDArray, String> p : list) {
            INDArray arr = p.getFirst();

            assertArrayEquals(shape, arr.shape());

            // Reference strides come from the array itself; both shape-info forms must agree with them.
            val thisStride = arr.stride();

            val ib = arr.shapeInfo();
            DataBuffer db = arr.shapeInfoDataBuffer();

            //Check shape calculation
            assertEquals(shape.length, Shape.rank(ib));
            assertEquals(shape.length, Shape.rank(db));

            assertArrayEquals(shape, Shape.shape(ib));
            assertArrayEquals(shape, Shape.shape(db));

            for (int j = 0; j < shape.length; j++) {
                assertEquals(shape[j], Shape.size(ib, j));
                assertEquals(shape[j], Shape.size(db, j));

                assertEquals(thisStride[j], Shape.stride(ib, j));
                assertEquals(thisStride[j], Shape.stride(db, j));
            }

            //Check base offset
            assertEquals(Shape.offset(ib), Shape.offset(db));

            //Check offset calculation:
            // For every index in the shape, the index-array getOffset result is the reference;
            // the fixed-arity overloads (2d-4d) must return the same value.
            NdIndexIterator iter = new NdIndexIterator(shape);
            while (iter.hasNext()) {
                val next = iter.next();
                long offset1 = Shape.getOffset(ib, next);

                assertEquals(offset1, Shape.getOffset(db, next));

                switch (shape.length) {
                    case 2:
                        assertEquals(offset1, Shape.getOffset(ib, next[0], next[1]));
                        assertEquals(offset1, Shape.getOffset(db, next[0], next[1]));
                        break;
                    case 3:
                        assertEquals(offset1, Shape.getOffset(ib, next[0], next[1], next[2]));
                        assertEquals(offset1, Shape.getOffset(db, next[0], next[1], next[2]));
                        break;
                    case 4:
                        assertEquals(offset1, Shape.getOffset(ib, next[0], next[1], next[2], next[3]));
                        assertEquals(offset1, Shape.getOffset(db, next[0], next[1], next[2], next[3]));
                        break;
                    case 5:
                    case 6:
                        //No 5 and 6d getOffset overloads
                        break;
                    default:
                        // Unreachable with the shapes above; guards against adding a rank without handling it here.
                        throw new RuntimeException();
                }
            }
        }
    }
}
 
Example 18
Source File: BlasBufferUtil.java (from nd4j, Apache License 2.0) — 3 votes
/**
 * Returns the stride to use when walking through a vector, relative to the
 * array's ordering. Intended for the incX/incY parameters in BLAS.
 * <ul>
 * <li>'f' (FORTRAN) ordering: delegates to {@code getBlasStride(arr)}</li>
 * <li>otherwise: {@code arr.stride(1)}, halved for {@code IComplexNDArray} instances</li>
 * </ul>
 *
 * @param arr the array to get the stride for
 * @return the stride with respect to the ordering of the given array
 */
public static int getStrideForOrdering(INDArray arr) {
    if (arr.ordering() == NDArrayFactory.FORTRAN) {
        return getBlasStride(arr);
    }
    int divisor = arr instanceof IComplexNDArray ? 2 : 1;
    return arr.stride(1) / divisor;
}
 
Example 19
Source File: BlasBufferUtil.java (from deeplearning4j, Apache License 2.0) — 3 votes
/**
 * Returns the stride to use when walking through a vector, relative to the
 * array's ordering. Intended for the incX/incY parameters in BLAS.
 * FORTRAN-ordered arrays delegate to {@code getBlasStride(arr)}; all other
 * orderings use the stride along dimension 1.
 *
 * @param arr the array to get the stride for
 * @return the stride with respect to the ordering of the given array
 */
public static int getStrideForOrdering(INDArray arr) {
    boolean fortranOrdered = arr.ordering() == NDArrayFactory.FORTRAN;
    return fortranOrdered ? getBlasStride(arr) : arr.stride(1);
}
 
Example 20
Source File: ProtectedCudaShapeInfoProviderTest.java (from nd4j, Apache License 2.0) — 2 votes
/**
 * Checks that purging memory-manager caches drops the protected shape-info buffer, and that
 * a new array of the same shape then gets a distinct buffer object with identical contents
 * and the same device pointer address.
 */
@Test
public void testPurge2() throws Exception {
    // Two arrays of the same shape share one shape-info buffer object before the purge.
    INDArray arrayA = Nd4j.create(10, 10);
    DataBuffer shapeInfoA = arrayA.shapeInfoDataBuffer();

    INDArray arrayE = Nd4j.create(10, 10);
    DataBuffer shapeInfoE = arrayE.shapeInfoDataBuffer();

    int[] arrayShapeA = shapeInfoA.asInt();
    assertTrue(shapeInfoA == shapeInfoE);

    ShapeDescriptor descriptor = new ShapeDescriptor(arrayA.shape(), arrayA.stride(), 0,
            arrayA.elementWiseStride(), arrayA.ordering());
    ConstantProtector protector = ConstantProtector.getInstance();
    AllocationPoint pointA = AtomicAllocator.getInstance().getAllocationPoint(arrayA.shapeInfoDataBuffer());

    assertTrue(protector.containsDataBuffer(0, descriptor));

    // Purge: the protector must no longer hold a buffer for this descriptor.
    Nd4j.getMemoryManager().purgeCaches();
    assertFalse(protector.containsDataBuffer(0, descriptor));

    // After the purge a new array of the same shape gets a new buffer object with the same contents.
    INDArray arrayB = Nd4j.create(10, 10);
    DataBuffer shapeInfoB = arrayB.shapeInfoDataBuffer();
    assertFalse(shapeInfoA == shapeInfoB);

    AllocationPoint pointB = AtomicAllocator.getInstance().getAllocationPoint(arrayB.shapeInfoDataBuffer());
    assertArrayEquals(arrayShapeA, shapeInfoB.asInt());

    // pointers should be equal, due to offsets reset
    assertEquals(pointA.getPointers().getDevicePointer().address(),
            pointB.getPointers().getDevicePointer().address());
}