Java Code Examples for org.nd4j.linalg.api.ndarray.INDArray#isSparse()

The following examples show how to use org.nd4j.linalg.api.ndarray.INDArray#isSparse(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may check out the related API usage on the sidebar.
Example 1
Source File: BaseLevel1.java    From nd4j with Apache License 2.0 6 votes vote down vote up
/**
 * Computes the Euclidean norm of a vector.
 *
 * <p>Sparse arrays are handed to the sparse BLAS wrapper and returned immediately, before the
 * profiling hook runs. Dense arrays are dispatched on the data type of their buffer:
 * {@code DOUBLE} goes to {@code dnrm2}; any other type is passed through
 * {@code validateDataType} as {@code FLOAT} and goes to {@code snrm2}.
 *
 * @param arr the vector whose norm is requested
 * @return the value reported by the selected nrm2 routine
 */
@Override
public double nrm2(INDArray arr) {
    // Sparse input is handled entirely by the sparse wrapper. This single check is sufficient:
    // nothing below can turn a dense array into a sparse one.
    if (arr.isSparse()) {
        return Nd4j.getSparseBlasWrapper().level1().nrm2(arr);
    }

    if (Nd4j.getExecutioner().getProfilingMode() == OpExecutioner.ProfilingMode.ALL) {
        OpProfiler.getInstance().processBlasCall(false, arr);
    }

    if (arr.data().dataType() == DataBuffer.Type.DOUBLE) {
        DefaultOpExecutioner.validateDataType(DataBuffer.Type.DOUBLE, arr);
        return dnrm2(arr.length(), arr, BlasBufferUtil.getBlasStride(arr));
    } else {
        DefaultOpExecutioner.validateDataType(DataBuffer.Type.FLOAT, arr);
        return snrm2(arr.length(), arr, BlasBufferUtil.getBlasStride(arr));
    }
    // TODO: add nrm2 for half, as call to appropriate NativeOp<HALF>
}
 
Example 2
Source File: BaseLevel1.java    From nd4j with Apache License 2.0 6 votes vote down vote up
/**
 * Computes the sum of magnitudes of all elements of a vector.
 *
 * <p>Sparse arrays are delegated to the sparse BLAS wrapper before the profiling hook runs.
 * Dense arrays are dispatched on the data type of their buffer: {@code DOUBLE} to
 * {@code dasum}, {@code FLOAT} to {@code sasum}, and anything else is treated as {@code HALF}
 * and sent to {@code hasum}.
 *
 * @param arr the vector to reduce
 * @return the value reported by the selected asum routine
 */
@Override
public double asum(INDArray arr) {
    if (arr.isSparse()) {
        return Nd4j.getSparseBlasWrapper().level1().asum(arr);
    }

    final boolean profileAll = Nd4j.getExecutioner().getProfilingMode() == OpExecutioner.ProfilingMode.ALL;
    if (profileAll) {
        OpProfiler.getInstance().processBlasCall(false, arr);
    }

    final DataBuffer.Type bufferType = arr.data().dataType();
    if (bufferType == DataBuffer.Type.DOUBLE) {
        DefaultOpExecutioner.validateDataType(DataBuffer.Type.DOUBLE, arr);
        return dasum(arr.length(), arr, BlasBufferUtil.getBlasStride(arr));
    }
    if (bufferType == DataBuffer.Type.FLOAT) {
        DefaultOpExecutioner.validateDataType(DataBuffer.Type.FLOAT, arr);
        return sasum(arr.length(), arr, BlasBufferUtil.getBlasStride(arr));
    }

    // Fallback: every remaining buffer type is handled as half precision.
    DefaultOpExecutioner.validateDataType(DataBuffer.Type.HALF, arr);
    return hasum(arr.length(), arr, BlasBufferUtil.getBlasStride(arr));
}
 
Example 3
Source File: BaseLevel1.java    From nd4j with Apache License 2.0 6 votes vote down vote up
/**
 * Finds the index of the vector element with the largest absolute value.
 *
 * <p>Sparse arrays are delegated to the sparse BLAS wrapper before the profiling hook runs.
 * Dense arrays go to {@code idamax} when the buffer type is {@code DOUBLE}; every other type
 * is passed through {@code validateDataType} as {@code FLOAT} and goes to {@code isamax}.
 *
 * @param arr the vector to search
 * @return the index reported by the selected iamax routine
 */
@Override
public int iamax(INDArray arr) {
    if (arr.isSparse()) {
        return Nd4j.getSparseBlasWrapper().level1().iamax(arr);
    }

    if (Nd4j.getExecutioner().getProfilingMode() == OpExecutioner.ProfilingMode.ALL) {
        OpProfiler.getInstance().processBlasCall(false, arr);
    }

    final boolean useDouble = arr.data().dataType() == DataBuffer.Type.DOUBLE;
    DefaultOpExecutioner.validateDataType(useDouble ? DataBuffer.Type.DOUBLE : DataBuffer.Type.FLOAT, arr);
    return useDouble
                    ? idamax(arr.length(), arr, BlasBufferUtil.getBlasStride(arr))
                    : isamax(arr.length(), arr, BlasBufferUtil.getBlasStride(arr));
}
 
Example 4
Source File: BaseLevel1.java    From nd4j with Apache License 2.0 6 votes vote down vote up
/**
 * Swaps the contents of two vectors.
 *
 * <p>The profiling hook runs first. If either argument is sparse the call is delegated to the
 * sparse BLAS wrapper; otherwise the routine is chosen from the buffer type of {@code x}
 * ({@code DOUBLE} uses {@code dswap}, anything else is treated as {@code FLOAT} and uses
 * {@code sswap}).
 *
 * @param x the first vector
 * @param y the second vector
 */
@Override
public void swap(INDArray x, INDArray y) {
    final boolean profileAll = Nd4j.getExecutioner().getProfilingMode() == OpExecutioner.ProfilingMode.ALL;
    if (profileAll) {
        OpProfiler.getInstance().processBlasCall(false, x, y);
    }

    if (x.isSparse() || y.isSparse()) {
        Nd4j.getSparseBlasWrapper().level1().swap(x, y);
        return;
    }

    // Only x's buffer type decides the precision; y is checked by validateDataType.
    final boolean useDouble = x.data().dataType() == DataBuffer.Type.DOUBLE;
    DefaultOpExecutioner.validateDataType(useDouble ? DataBuffer.Type.DOUBLE : DataBuffer.Type.FLOAT, x, y);
    if (useDouble) {
        dswap(x.length(), x, BlasBufferUtil.getBlasStride(x), y, BlasBufferUtil.getBlasStride(y));
    } else {
        sswap(x.length(), x, BlasBufferUtil.getBlasStride(x), y, BlasBufferUtil.getBlasStride(y));
    }
}
 
Example 5
Source File: BaseLevel1.java    From nd4j with Apache License 2.0 6 votes vote down vote up
/**
 * Copies a vector to another vector via the BLAS {@code copy} routines.
 *
 * <p>The profiling hook runs first. If either argument is sparse the call is delegated to the
 * sparse BLAS wrapper; otherwise the routine is chosen from the buffer type of {@code x}
 * ({@code DOUBLE} uses {@code dcopy}, anything else is treated as {@code FLOAT} and uses
 * {@code scopy}).
 *
 * @param x the first vector passed to the copy routine (the source, by BLAS convention)
 * @param y the second vector passed to the copy routine (the destination, by BLAS convention)
 */
@Override
public void copy(INDArray x, INDArray y) {
    final boolean profileAll = Nd4j.getExecutioner().getProfilingMode() == OpExecutioner.ProfilingMode.ALL;
    if (profileAll) {
        OpProfiler.getInstance().processBlasCall(false, x, y);
    }

    final boolean anySparse = x.isSparse() || y.isSparse();
    if (anySparse) {
        Nd4j.getSparseBlasWrapper().level1().copy(x, y);
        return;
    }

    // Only x's buffer type decides the precision; y is checked by validateDataType.
    final boolean useDouble = x.data().dataType() == DataBuffer.Type.DOUBLE;
    DefaultOpExecutioner.validateDataType(useDouble ? DataBuffer.Type.DOUBLE : DataBuffer.Type.FLOAT, x, y);
    if (useDouble) {
        dcopy(x.length(), x, BlasBufferUtil.getBlasStride(x), y, BlasBufferUtil.getBlasStride(y));
    } else {
        scopy(x.length(), x, BlasBufferUtil.getBlasStride(x), y, BlasBufferUtil.getBlasStride(y));
    }
}
 
Example 6
Source File: BaseLevel1.java    From nd4j with Apache License 2.0 6 votes vote down vote up
/**
 * Computes a vector-scalar product and adds the result to a vector.
 *
 * <p>The profiling hook runs first. A sparse {@code x} combined with a dense {@code y} is
 * delegated to the sparse BLAS wrapper. Every other combination is dispatched on the buffer
 * type of {@code x}: {@code DOUBLE} to {@code daxpy}, {@code FLOAT} to {@code saxpy}, and
 * anything else is treated as {@code HALF} and sent to {@code haxpy}. For the non-double
 * routines {@code alpha} is narrowed to {@code float}.
 *
 * @param n     element count forwarded to the BLAS routine
 * @param alpha the scalar multiplier
 * @param x     the vector that is scaled
 * @param y     the vector that is accumulated into
 */
@Override
public void axpy(long n, double alpha, INDArray x, INDArray y) {
    if (Nd4j.getExecutioner().getProfilingMode() == OpExecutioner.ProfilingMode.ALL) {
        OpProfiler.getInstance().processBlasCall(false, x, y);
    }

    final boolean sparseIntoDense = x.isSparse() && !y.isSparse();
    if (sparseIntoDense) {
        Nd4j.getSparseBlasWrapper().level1().axpy(n, alpha, x, y);
        return;
    }

    final DataBuffer.Type bufferType = x.data().dataType();
    if (bufferType == DataBuffer.Type.DOUBLE) {
        DefaultOpExecutioner.validateDataType(DataBuffer.Type.DOUBLE, x, y);
        daxpy(n, alpha, x, BlasBufferUtil.getBlasStride(x), y, BlasBufferUtil.getBlasStride(y));
        return;
    }
    if (bufferType == DataBuffer.Type.FLOAT) {
        DefaultOpExecutioner.validateDataType(DataBuffer.Type.FLOAT, x, y);
        saxpy(n, (float) alpha, x, BlasBufferUtil.getBlasStride(x), y, BlasBufferUtil.getBlasStride(y));
        return;
    }

    // Fallback: every remaining buffer type is handled as half precision.
    DefaultOpExecutioner.validateDataType(DataBuffer.Type.HALF, x, y);
    haxpy(n, (float) alpha, x, BlasBufferUtil.getBlasStride(x), y, BlasBufferUtil.getBlasStride(y));
}
 
Example 7
Source File: BaseLevel1.java    From nd4j with Apache License 2.0 6 votes vote down vote up
/**
 * Performs rotation of points in the plane.
 *
 * <p>The profiling hook runs first. A sparse {@code X} combined with a dense {@code Y} is
 * delegated to the sparse BLAS wrapper. Otherwise the routine is chosen from the buffer type
 * of {@code X}: {@code DOUBLE} uses {@code drot}; anything else is treated as {@code FLOAT}
 * and uses {@code srot}, with {@code c} and {@code s} narrowed to {@code float}.
 *
 * @param N number of elements forwarded to the BLAS routine
 * @param X the first vector
 * @param Y the second vector
 * @param c the first rotation coefficient (cosine, by BLAS convention)
 * @param s the second rotation coefficient (sine, by BLAS convention)
 */
@Override
public void rot(long N, INDArray X, INDArray Y, double c, double s) {

    if (Nd4j.getExecutioner().getProfilingMode() == OpExecutioner.ProfilingMode.ALL) {
        OpProfiler.getInstance().processBlasCall(false, X, Y);
    }

    if (X.isSparse() && !Y.isSparse()) {
        Nd4j.getSparseBlasWrapper().level1().rot(N, X, Y, c, s);
    } else if (X.data().dataType() == DataBuffer.Type.DOUBLE) {
        DefaultOpExecutioner.validateDataType(DataBuffer.Type.DOUBLE, X, Y);
        // Each vector is walked with its own stride; Y's increment must come from Y, not X.
        drot(N, X, BlasBufferUtil.getBlasStride(X), Y, BlasBufferUtil.getBlasStride(Y), c, s);
    } else {
        DefaultOpExecutioner.validateDataType(DataBuffer.Type.FLOAT, X, Y);
        // Each vector is walked with its own stride; Y's increment must come from Y, not X.
        srot(N, X, BlasBufferUtil.getBlasStride(X), Y, BlasBufferUtil.getBlasStride(Y), (float) c, (float) s);
    }
}
 
Example 8
Source File: SparseCOOLevel1Test.java    From nd4j with Apache License 2.0 5 votes vote down vote up
/**
 * Checks {@code rot} with a sparse COO vector as the first operand and a dense vector as the
 * second operand.
 */
@Test
public void shouldComputeRotWithFullVector() {

    // Reference values, obtained by running rot on two dense vectors:
    //   x = y = [1, 2, 3, 4], c = 1, s = 2
    //   x' = c*x + s*y = [3, 6, 9, 12]
    //   y' = c*y - s*x = [-1, -2, -3, -4]

    INDArray sparseVec = Nd4j.createSparseCOO(data, indexes, shape);
    INDArray vec = Nd4j.create(new double[] {1, 2, 3, 4});
    Nd4j.getBlasWrapper().level1().rot(vec.length(), sparseVec, vec, 1, 2);

    INDArray expectedSparseVec = Nd4j.createSparseCSR(new double[] {3, 6, 9, 12}, new int[] {0, 1, 2, 3},
                    new int[] {0}, new int[] {4}, new int[] {1, 4});
    INDArray expectedVec = Nd4j.create(new double[] {-1, -2, -3, -4});
    assertEquals(getFailureMessage(), expectedSparseVec.data(), sparseVec.data());
    assertEquals(getFailureMessage(), expectedVec, vec);
    if (expectedSparseVec.isSparse() && sparseVec.isSparse()) {
        BaseSparseNDArray vec2 = ((BaseSparseNDArray) expectedSparseVec);
        BaseSparseNDArray vecSparse2 = ((BaseSparseNDArray) sparseVec);
        // Compare coordinates with coordinates (not with the array object itself).
        assertEquals(getFailureMessage(), vec2.getVectorCoordinates(), vecSparse2.getVectorCoordinates());
    }
}
 
Example 9
Source File: SparseCSRLevel1Test.java    From nd4j with Apache License 2.0 5 votes vote down vote up
/**
 * Checks {@code rot} with a sparse CSR vector as the first operand and a dense vector as the
 * second operand.
 */
@Test
public void shouldComputeRotWithFullVector() {

    // Reference values, obtained by running rot on two dense vectors:
    //   x = y = [1, 2, 3, 4], c = 1, s = 2
    //   x' = c*x + s*y = [3, 6, 9, 12]
    //   y' = c*y - s*x = [-1, -2, -3, -4]

    int[] cols = {0, 1, 2, 3};
    double[] values = {1, 2, 3, 4};
    INDArray sparseVec = Nd4j.createSparseCSR(values, cols, pointerB, pointerE, shape);
    INDArray vec = Nd4j.create(new double[] {1, 2, 3, 4});
    Nd4j.getBlasWrapper().level1().rot(vec.length(), sparseVec, vec, 1, 2);

    INDArray expectedSparseVec = Nd4j.createSparseCSR(new double[] {3, 6, 9, 12}, new int[] {0, 1, 2, 3},
                    new int[] {0}, new int[] {4}, new int[] {1, 4});
    INDArray expectedVec = Nd4j.create(new double[] {-1, -2, -3, -4});
    assertEquals(getFailureMessage(), expectedSparseVec.data(), sparseVec.data());
    assertEquals(getFailureMessage(), expectedVec, vec);
    if (expectedSparseVec.isSparse() && sparseVec.isSparse()) {
        BaseSparseNDArray vec2 = ((BaseSparseNDArray) expectedSparseVec);
        BaseSparseNDArray vecSparse2 = ((BaseSparseNDArray) sparseVec);
        // Compare coordinates with coordinates (not with the array object itself).
        assertEquals(getFailureMessage(), vec2.getVectorCoordinates(), vecSparse2.getVectorCoordinates());
    }
}
 
Example 10
Source File: BaseLevel2.java    From nd4j with Apache License 2.0 5 votes vote down vote up
/**
 * Computes a matrix-vector product using a general matrix, performing one of the BLAS gemv
 * operations:
 * y := alpha*a*x + beta*y  for trans = 'N'or'n';
 * y := alpha*a'*x + beta*y  for trans = 'T'or't';
 * y := alpha*conjg(a')*x + beta*y  for trans = 'C'or'c'.
 * Here a is an m-by-n matrix, x and y are vectors, alpha and beta are scalars.
 *
 * <p>The profiling hook runs first. A sparse {@code A} combined with a dense {@code X} is
 * delegated to the sparse BLAS wrapper. Otherwise the operands are wrapped in
 * {@code GemvParameters} and the routine is chosen from the buffer type of {@code A}
 * ({@code DOUBLE} uses {@code dgemv}, anything else is treated as {@code FLOAT} and uses
 * {@code sgemv}). On the dense path the transpose flag handed to the routine is the one
 * reported by {@code GemvParameters}, not {@code transA}. {@code Y} is finally passed to
 * {@code OpExecutionerUtil.checkForAny}.
 *
 * @param order  storage order forwarded to the BLAS routine
 * @param transA transpose flag; forwarded only on the sparse path
 * @param alpha  scalar applied to the matrix-vector product
 * @param A      the matrix
 * @param X      the input vector
 * @param beta   scalar applied to {@code Y}
 * @param Y      the vector that is passed to the routine as its output operand
 */
@Override
public void gemv(char order, char transA, double alpha, INDArray A, INDArray X, double beta, INDArray Y) {
    final boolean profileAll = Nd4j.getExecutioner().getProfilingMode() == OpExecutioner.ProfilingMode.ALL;
    if (profileAll) {
        OpProfiler.getInstance().processBlasCall(false, A, X, Y);
    }

    final boolean sparseTimesDense = A.isSparse() && !X.isSparse();
    if (sparseTimesDense) {
        Nd4j.getSparseBlasWrapper().level2().gemv(order, transA, alpha, A, X, beta, Y);
        return;
    }

    final GemvParameters params = new GemvParameters(A, X, Y);
    final boolean useDouble = A.data().dataType() == DataBuffer.Type.DOUBLE;
    if (!useDouble) {
        DefaultOpExecutioner.validateDataType(DataBuffer.Type.FLOAT, params.getA(), params.getX(),
                        params.getY());
        sgemv(order, params.getAOrdering(), params.getM(), params.getN(), (float) alpha,
                        params.getA(), params.getLda(), params.getX(), params.getIncx(),
                        (float) beta, params.getY(), params.getIncy());
    } else {
        DefaultOpExecutioner.validateDataType(DataBuffer.Type.DOUBLE, params.getA(), params.getX(),
                        params.getY());
        dgemv(order, params.getAOrdering(), params.getM(), params.getN(), alpha, params.getA(),
                        params.getLda(), params.getX(), params.getIncx(), beta, params.getY(),
                        params.getIncy());
    }

    OpExecutionerUtil.checkForAny(Y);
}
 
Example 11
Source File: BaseLevel1.java    From nd4j with Apache License 2.0 5 votes vote down vote up
/**
 * Computes a vector-vector dot product.
 *
 * <p>The profiling hook runs first. When exactly one operand is sparse the call is delegated
 * to the sparse BLAS wrapper with the sparse operand first. When both operands are sparse no
 * routine is available and {@code 0} is returned. Dense operands are dispatched on the buffer
 * type of {@code X}: {@code DOUBLE} to {@code ddot}, {@code FLOAT} to {@code sdot}, and
 * anything else is treated as {@code HALF} and sent to {@code hdot}.
 *
 * @param n     number of accessed elements
 * @param alpha forwarded only to the sparse routine; unused on the dense path
 * @param X     an INDArray
 * @param Y     an INDArray
 * @return the vector-vector dot product of X and Y (always {@code 0} when both are sparse)
 */
@Override
public double dot(long n, double alpha, INDArray X, INDArray Y) {
    if (Nd4j.getExecutioner().getProfilingMode() == OpExecutioner.ProfilingMode.ALL) {
        OpProfiler.getInstance().processBlasCall(false, X, Y);
    }

    final boolean sparseX = X.isSparse();
    final boolean sparseY = Y.isSparse();
    if (sparseX && sparseY) {
        // TODO - MKL doesn't contain such routines
        return 0;
    }
    if (sparseX) {
        return Nd4j.getSparseBlasWrapper().level1().dot(n, alpha, X, Y);
    }
    if (sparseY) {
        // The sparse routine expects the sparse operand first, so the arguments are swapped.
        return Nd4j.getSparseBlasWrapper().level1().dot(n, alpha, Y, X);
    }

    final DataBuffer.Type bufferType = X.data().dataType();
    if (bufferType == DataBuffer.Type.DOUBLE) {
        DefaultOpExecutioner.validateDataType(DataBuffer.Type.DOUBLE, X, Y);
        return ddot(n, X, BlasBufferUtil.getBlasStride(X), Y, BlasBufferUtil.getBlasStride(Y));
    }
    if (bufferType == DataBuffer.Type.FLOAT) {
        DefaultOpExecutioner.validateDataType(DataBuffer.Type.FLOAT, X, Y);
        return sdot(n, X, BlasBufferUtil.getBlasStride(X), Y, BlasBufferUtil.getBlasStride(Y));
    }

    // Fallback: every remaining buffer type is handled as half precision.
    DefaultOpExecutioner.validateDataType(DataBuffer.Type.HALF, X, Y);
    return hdot(n, X, BlasBufferUtil.getBlasStride(X), Y, BlasBufferUtil.getBlasStride(Y));
}
 
Example 12
Source File: BaseLevel1.java    From nd4j with Apache License 2.0 5 votes vote down vote up
/**
 * Finds the element of a vector that has the minimum absolute value.
 *
 * <p>Only sparse arrays are supported here; they are delegated to the sparse BLAS wrapper.
 *
 * @param arr the vector to search
 * @return the index reported by the sparse iamin routine
 * @throws UnsupportedOperationException if {@code arr} is not sparse
 */
@Override
public int iamin(INDArray arr) {
    // Guard clause: there is no dense implementation.
    if (!arr.isSparse()) {
        throw new UnsupportedOperationException();
    }
    return Nd4j.getSparseBlasWrapper().level1().iamin(arr);
}
 
Example 13
Source File: BaseLevel1.java    From nd4j with Apache License 2.0 5 votes vote down vote up
/**
 * Computes the product of a vector and a scalar.
 *
 * <p>The profiling hook runs first. Sparse arrays are delegated to the sparse BLAS wrapper.
 * Dense arrays are dispatched on their buffer type: {@code DOUBLE} to {@code dscal},
 * {@code FLOAT} to {@code sscal} (with {@code alpha} narrowed to {@code float}), and
 * {@code HALF} to a {@code ScalarMultiplication} op run by the executioner. Any other buffer
 * type results in no operation.
 *
 * @param N     element count forwarded to the BLAS routine
 * @param alpha the scalar multiplier
 * @param X     the vector to scale
 */
@Override
public void scal(long N, double alpha, INDArray X) {
    final boolean profileAll = Nd4j.getExecutioner().getProfilingMode() == OpExecutioner.ProfilingMode.ALL;
    if (profileAll) {
        OpProfiler.getInstance().processBlasCall(false, X);
    }

    if (X.isSparse()) {
        Nd4j.getSparseBlasWrapper().level1().scal(N, alpha, X);
        return;
    }

    final DataBuffer.Type bufferType = X.data().dataType();
    if (bufferType == DataBuffer.Type.DOUBLE) {
        dscal(N, alpha, X, BlasBufferUtil.getBlasStride(X));
        return;
    }
    if (bufferType == DataBuffer.Type.FLOAT) {
        sscal(N, (float) alpha, X, BlasBufferUtil.getBlasStride(X));
        return;
    }
    if (bufferType == DataBuffer.Type.HALF) {
        // No half-precision BLAS call is used here; the scaling runs as a regular op instead.
        Nd4j.getExecutioner().exec(new ScalarMultiplication(X, alpha));
    }
}