Java Code Examples for org.nd4j.linalg.api.ndarray.INDArray#isScalar()

The following examples show how to use org.nd4j.linalg.api.ndarray.INDArray#isScalar(). Each example names the source file it was taken from, the project it belongs to, and that project's license.
Example 1
Source File: RegressionMetrics.java    From konduit-serving with Apache License 2.0 6 votes vote down vote up
/**
 * Records one output array into the per-column statistics counters.
 *
 * <p>For a vector, element {@code i} is added to counter {@code i}. For a matrix, element
 * {@code (row, col)} is added to counter {@code col}. A {@code [1, n]} row vector and an
 * {@code [m, n]} matrix therefore update the same counter for the same column. A scalar is
 * added to counter 0.
 *
 * @param output the output array to record
 * @throws IllegalArgumentException if {@code output} is not a vector, matrix or scalar
 */
private void handleNdArray(INDArray output) {
    if (output.isVector()) {
        for (int i = 0; i < output.length(); i++) {
            statCounters.get(i).add(output.getDouble(i));
        }
    } else if (output.isMatrix() && output.length() > 1) {
        for (int row = 0; row < output.rows(); row++) {
            for (int col = 0; col < output.columns(); col++) {
                // Index the counter by column, matching the vector branch above.
                // Rows are separate examples of the same outputs.
                statCounters.get(col).add(output.getDouble(row, col));
            }
        }
    } else if (output.isScalar()) {
        statCounters.get(0).add(output.sumNumber().doubleValue());
    } else {
        throw new IllegalArgumentException("Only vectors and matrices supported right now");
    }
}
 
Example 2
Source File: BaseComplexNDArray.java    From nd4j with Apache License 2.0 6 votes vote down vote up
/**
 * In-place (element-wise) multiplication of this array by {@code other}, written into
 * {@code result}.
 *
 * <p>If {@code other} is a scalar, this delegates to the complex-scalar overload of
 * {@code muli}. Both {@code other} and {@code result} must be {@link IComplexNDArray}s.
 *
 * @param other  the second ndarray to multiply
 * @param result the result ndarray
 * @return the result of the multiplication
 * @throws ClassCastException if {@code other} or {@code result} is not complex
 */
@Override
public IComplexNDArray muli(INDArray other, INDArray result) {
    IComplexNDArray cOther = (IComplexNDArray) other;
    IComplexNDArray cResult = (IComplexNDArray) result;

    // Take the scalar shortcut first, so we don't build linear views that would go unused.
    if (other.isScalar()) {
        return muli(cOther.getComplex(0), result);
    }

    IComplexNDArray linear = linearView();
    IComplexNDArray cOtherLinear = cOther.linearView();
    IComplexNDArray cResultLinear = cResult.linearView();

    // Scratch numbers handed to getComplex(i, holder). They are presumably filled in place,
    // which avoids a per-element allocation.
    IComplexNumber c = Nd4j.createComplexNumber(0, 0);
    IComplexNumber d = Nd4j.createComplexNumber(0, 0);

    for (int i = 0; i < length(); i++) {
        cResultLinear.putScalar(i, linear.getComplex(i, c).muli(cOtherLinear.getComplex(i, d)));
    }
    return cResult;
}
 
Example 3
Source File: CpuNDArrayFactory.java    From deeplearning4j with Apache License 2.0 6 votes vote down vote up
/**
 * Sorts {@code x} in place along the given dimensions, one tensor-along-dimension (TAD)
 * at a time, using the native {@code sortTad} kernel.
 *
 * @param x          the array to sort; scalars are returned unchanged
 * @param descending {@code true} to sort in descending order
 * @param dimension  the dimensions to sort along; the caller's array is not modified
 * @return {@code x}
 * @throws RuntimeException if the native sort reports an error
 */
@Override
public INDArray sort(INDArray x, boolean descending, int... dimension) {
    if (x.isScalar())
        return x;

    // Sort a copy so the caller's varargs array is not reordered as a side effect.
    final int[] sortedDims = dimension.clone();
    Arrays.sort(sortedDims);
    Pair<DataBuffer, DataBuffer> tadBuffers = Nd4j.getExecutioner().getTADManager().getTADOnlyShapeInfo(x, sortedDims);

    NativeOpsHolder.getInstance().getDeviceNativeOps().sortTad(null,
            x.data().addressPointer(), (LongPointer) x.shapeInfoDataBuffer().addressPointer(),
            null, null,
            (IntPointer) Nd4j.getConstantHandler().getConstantBuffer(sortedDims, DataType.INT).addressPointer(),
            sortedDims.length,
            (LongPointer) tadBuffers.getFirst().addressPointer(),
            new LongPointerWrapper(tadBuffers.getSecond().addressPointer()),
            descending);

    // Surface native failures instead of silently returning a possibly unsorted array.
    if (NativeOpsHolder.getInstance().getDeviceNativeOps().lastErrorCode() != 0)
        throw new RuntimeException(NativeOpsHolder.getInstance().getDeviceNativeOps().lastErrorMessage());

    return x;
}
 
Example 4
Source File: BaseNDArrayFactory.java    From nd4j with Apache License 2.0 6 votes vote down vote up
/**
 * Generate a vector of {@code num} evenly spaced values from {@code lower} to
 * {@code upper}, both inclusive.
 *
 * <p>When {@code num == 1} the single value is {@code lower}, as in NumPy's
 * {@code linspace}.
 *
 * @param lower the first value (lower bound)
 * @param upper the last value (upper bound)
 * @param num   the number of values to generate
 * @return the linearly spaced vector
 */
@Override
public INDArray linspace(int lower, int upper, int num) {
    double[] data = new double[num];
    for (int i = 0; i < num; i++) {
        // With one point, (num - 1) is 0. Use t = 0 instead of dividing by zero (NaN).
        double t = num == 1 ? 0.0 : (double) i / (num - 1);
        data[i] = lower * (1 - t) + t * upper;
    }

    INDArray ret = Nd4j.create(data.length);
    // Nd4j.create(length) does not take the values, so every element must be written
    // explicitly. This includes the single-element (scalar) case.
    for (int i = 0; i < ret.length(); i++) {
        ret.putScalar(i, data[i]);
    }
    return ret;
}
 
Example 5
Source File: NDArrayPreconditionsFormat.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
/**
 * Renders an {@link INDArray} argument for one of the {@code %nd*} precondition message tags.
 *
 * @param tag the format tag, e.g. {@code %ndShape}
 * @param arg the argument to render; {@code null} becomes the string {@code "null"}
 * @return the rendered text
 * @throws IllegalStateException if {@code tag} is not a known tag
 */
@Override
public String format(String tag, Object arg) {
    if (arg == null) {
        return "null";
    }
    final INDArray array = (INDArray) arg;
    final String formatted;
    switch (tag) {
        case "%ndRank":
            formatted = String.valueOf(array.rank());
            break;
        case "%ndShape":
            formatted = Arrays.toString(array.shape());
            break;
        case "%ndStride":
            formatted = Arrays.toString(array.stride());
            break;
        case "%ndLength":
            formatted = String.valueOf(array.length());
            break;
        case "%ndSInfo":
            // Collapse the multi-line shape info onto one line.
            formatted = array.shapeInfoToString().replace("\n", "");
            break;
        case "%nd10":
            if (array.isScalar() || array.isEmpty()) {
                formatted = array.toString();
            } else {
                // Flatten, then show at most the first 10 elements.
                formatted = array.reshape(array.length())
                        .get(NDArrayIndex.interval(0, Math.min(array.length(), 10)))
                        .toString();
            }
            break;
        default:
            //Should never happen
            throw new IllegalStateException("Unknown format tag: " + tag);
    }
    return formatted;
}
 
Example 6
Source File: FirstAxisIterator.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
/**
 * Returns the next slice along the first axis and advances the cursor {@code i}.
 * A scalar slice is returned as its boxed {@code double} value rather than as an array.
 *
 * @return a {@link Double} for scalar slices, otherwise the slice {@link INDArray}
 */
@Override
public Object next() {
    final INDArray current = iterateOver.slice(i++);
    if (!current.isScalar()) {
        return current;
    }
    return current.getDouble(0);
}
 
Example 7
Source File: BaseComplexNDArray.java    From nd4j with Apache License 2.0 5 votes vote down vote up
/**
 * Copy the imaginary parts of this array's elements into the given ndarray.
 *
 * <p>Elements are written by linear index. The loop runs over the destination's linear
 * length and is not checked against this array's length.
 *
 * @param arr the array to copy imaginary numbers to
 */
protected void copyImagTo(INDArray arr) {
    INDArray linear = arr.linearView();
    IComplexNDArray thisLinear = linearView();
    if (arr.isScalar()) {
        // Single destination element: copy the imaginary (not the real) component,
        // exactly as the loop below does for each element.
        arr.putScalar(0, thisLinear.getImag(0));
    } else {
        for (int i = 0; i < linear.length(); i++) {
            arr.putScalar(i, thisLinear.getImag(i));
        }
    }
}
 
Example 8
Source File: BaseComplexNDArray.java    From nd4j with Apache License 2.0 5 votes vote down vote up
/**
 * Copies the real component of this array's elements into {@code arr}.
 *
 * <p>Elements are written by linear index, over the destination's linear length.
 *
 * @param arr the destination array
 */
protected void copyRealTo(INDArray arr) {
    final INDArray arrLinear = arr.linearView();
    final IComplexNDArray source = linearView();
    if (arr.isScalar()) {
        arr.putScalar(0, getReal(0));
        return;
    }
    final long count = arrLinear.length();
    for (int idx = 0; idx < count; idx++) {
        arr.putScalar(idx, source.getReal(idx));
    }
}
 
Example 9
Source File: JCublasNDArrayFactory.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
/**
 * Sorts {@code x} in place on the device along the given dimensions, one
 * tensor-along-dimension (TAD) at a time.
 *
 * @param x          the array to sort; scalars are returned unchanged
 * @param descending {@code true} to sort in descending order
 * @param dimension  the dimensions to sort along; the caller's array is not modified
 * @return {@code x}
 * @throws RuntimeException if the native sort reports an error
 */
@Override
public INDArray sort(INDArray x, boolean descending, int... dimension) {
    if (x.isScalar())
        return x;

    // Sort a copy so the caller's varargs array is not reordered as a side effect.
    final int[] sortedDims = dimension.clone();
    Arrays.sort(sortedDims);

    Nd4j.getExecutioner().push();

    val tadBuffers = Nd4j.getExecutioner().getTADManager().getTADOnlyShapeInfo(x, sortedDims);

    val context = AtomicAllocator.getInstance().getFlowController().prepareAction(x);

    val extraz = new PointerPointer(AtomicAllocator.getInstance().getHostPointer(x.shapeInfoDataBuffer()), // not used
            context.getOldStream(), AtomicAllocator.getInstance().getDeviceIdPointer());

    val dimensionPointer = AtomicAllocator.getInstance()
            .getHostPointer(AtomicAllocator.getInstance().getConstantBuffer(sortedDims));

    nativeOps.sortTad(extraz,
                null,
                (LongPointer) x.shapeInfoDataBuffer().addressPointer(),
                AtomicAllocator.getInstance().getPointer(x, context),
                (LongPointer) AtomicAllocator.getInstance().getPointer(x.shapeInfoDataBuffer(), context),
                (IntPointer) dimensionPointer,
                sortedDims.length,
                (LongPointer) AtomicAllocator.getInstance().getPointer(tadBuffers.getFirst(), context),
                new LongPointerWrapper(AtomicAllocator.getInstance().getPointer(tadBuffers.getSecond(), context)),
                descending
        );

    if (nativeOps.lastErrorCode() != 0)
        throw new RuntimeException(nativeOps.lastErrorMessage());

    AtomicAllocator.getInstance().getFlowController().registerAction(context, x);

    return x;
}
 
Example 10
Source File: CpuNDArrayFactory.java    From nd4j with Apache License 2.0 5 votes vote down vote up
/**
 * Sorts all elements of {@code x} in place using the native flat sort.
 * Only FLOAT and DOUBLE buffers are supported.
 *
 * @param x          the array to sort; scalars are returned unchanged
 * @param descending {@code true} to sort in descending order
 * @return {@code x}
 * @throws UnsupportedOperationException for any other data type
 */
@Override
public INDArray sort(INDArray x, boolean descending) {
    // A single value is already sorted.
    if (x.isScalar()) {
        return x;
    }

    final DataBuffer.Type type = x.data().dataType();
    if (type == DataBuffer.Type.FLOAT) {
        NativeOpsHolder.getInstance().getDeviceNativeOps().sortFloat(null,
                (FloatPointer) x.data().addressPointer(),
                (LongPointer) x.shapeInfoDataBuffer().addressPointer(),
                descending);
        return x;
    }
    if (type == DataBuffer.Type.DOUBLE) {
        NativeOpsHolder.getInstance().getDeviceNativeOps().sortDouble(null,
                (DoublePointer) x.data().addressPointer(),
                (LongPointer) x.shapeInfoDataBuffer().addressPointer(),
                descending);
        return x;
    }
    throw new UnsupportedOperationException("Unknown dataype " + type);
}
 
Example 11
Source File: CpuSparseNDArrayFactory.java    From nd4j with Apache License 2.0 5 votes vote down vote up
/**
 * Sorts {@code x} in place along the given dimensions, one tensor-along-dimension (TAD)
 * at a time. Only FLOAT and DOUBLE buffers are supported.
 *
 * @param x          the array to sort; scalars are returned unchanged
 * @param descending {@code true} to sort in descending order
 * @param dimension  the dimensions to sort along; the caller's array is not modified
 * @return {@code x}
 * @throws UnsupportedOperationException for any other data type
 */
@Override
public INDArray sort(INDArray x, boolean descending, int... dimension) {
    if (x.isScalar())
        return x;

    // Sort a copy so the caller's varargs array is not reordered as a side effect.
    final int[] sortedDims = dimension.clone();
    Arrays.sort(sortedDims);
    Pair<DataBuffer, DataBuffer> tadBuffers = Nd4j.getExecutioner().getTADManager().getTADOnlyShapeInfo(x, sortedDims);

    if (x.data().dataType() == DataBuffer.Type.FLOAT) {
        NativeOpsHolder.getInstance().getDeviceNativeOps().sortTadFloat(null,
                (FloatPointer) x.data().addressPointer(),
                (LongPointer) x.shapeInfoDataBuffer().addressPointer(),
                new IntPointer(sortedDims),
                sortedDims.length,
                (LongPointer) tadBuffers.getFirst().addressPointer(),
                new LongPointerWrapper(tadBuffers.getSecond().addressPointer()),
                descending);
    } else if (x.data().dataType() == DataBuffer.Type.DOUBLE) {
        NativeOpsHolder.getInstance().getDeviceNativeOps().sortTadDouble(null,
                (DoublePointer) x.data().addressPointer(),
                (LongPointer) x.shapeInfoDataBuffer().addressPointer(),
                new IntPointer(sortedDims),
                sortedDims.length,
                (LongPointer) tadBuffers.getFirst().addressPointer(),
                new LongPointerWrapper(tadBuffers.getSecond().addressPointer()),
                descending);
    } else {
        throw new UnsupportedOperationException("Unknown datatype " + x.data().dataType());
    }

    return x;
}
 
Example 12
Source File: CpuSparseNDArrayFactory.java    From nd4j with Apache License 2.0 5 votes vote down vote up
/**
 * Sorts every element of {@code x} in place using the native flat sort.
 * FLOAT and DOUBLE buffers are supported; scalars are left untouched.
 *
 * @param x          the array to sort
 * @param descending {@code true} for descending order
 * @return {@code x}
 * @throws UnsupportedOperationException for any other data type
 */
@Override
public INDArray sort(INDArray x, boolean descending) {
    if (!x.isScalar()) {
        final DataBuffer.Type dtype = x.data().dataType();
        if (dtype == DataBuffer.Type.FLOAT) {
            NativeOpsHolder.getInstance().getDeviceNativeOps().sortFloat(null,
                    (FloatPointer) x.data().addressPointer(),
                    (LongPointer) x.shapeInfoDataBuffer().addressPointer(),
                    descending);
        } else if (dtype == DataBuffer.Type.DOUBLE) {
            NativeOpsHolder.getInstance().getDeviceNativeOps().sortDouble(null,
                    (DoublePointer) x.data().addressPointer(),
                    (LongPointer) x.shapeInfoDataBuffer().addressPointer(),
                    descending);
        } else {
            throw new UnsupportedOperationException("Unknown dataype " + dtype);
        }
    }
    return x;
}
 
Example 13
Source File: CoverageModelEMWorkspaceMathUtils.java    From gatk-protected with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
/**
 * Solves the linear system {@code mat . x = vec} for {@code x} using an Apache Commons
 * Math LU decomposition.
 *
 * <p>A scalar {@code mat} is treated as a 1x1 system and handled by plain division.
 *
 * @param mat the coefficients matrix (must be square and full-rank)
 * @param vec the right hand side vector
 * @param singularityThreshold a threshold for detecting singularity
 * @return solution of the linear system, with the same shape as {@code vec}
 * @throws IllegalArgumentException if {@code mat} is not square
 */
public static INDArray linsolve(@Nonnull final INDArray mat, @Nonnull final INDArray vec,
                                final double singularityThreshold) {
    if (mat.isScalar()) {
        final double coefficient = mat.getDouble(0);
        return vec.div(coefficient);
    }
    if (!mat.isSquare()) {
        throw new IllegalArgumentException("invalid array: must be a square matrix");
    }
    final LUDecomposition lu = new LUDecomposition(
            Nd4jApacheAdapterUtils.convertINDArrayToApacheMatrix(mat), singularityThreshold);
    final RealVector rhs = Nd4jApacheAdapterUtils.convertINDArrayToApacheVector(vec);
    final RealVector solution = lu.getSolver().solve(rhs);
    return Nd4j.create(solution.toArray(), vec.shape());
}
 
Example 14
Source File: CoverageModelEMWorkspaceMathUtils.java    From gatk-protected with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
/**
 * Calculates the log of the absolute determinant of a square matrix via LU decomposition.
 *
 * <p>{@code log|det(A)| = sum_i log|U_ii|}. This holds because the {@code L} factor of
 * Apache Commons Math's {@link LUDecomposition} has a unit diagonal, and row pivoting only
 * changes the sign. Summing logs avoids the overflow/underflow that multiplying the
 * diagonal directly can cause for large matrices.
 *
 * <p>If the decomposition flags the matrix as singular (per
 * {@code DEFAULT_LU_DECOMPOSITION_SINGULARITY_THRESHOLD}), this returns
 * {@link Double#NEGATIVE_INFINITY}, i.e. {@code log(0)}. That matches the scalar case for a
 * zero input.
 *
 * @param mat a square matrix
 * @return log abs determinant of {@code mat}
 * @throws IllegalArgumentException if {@code mat} is not square
 */
public static double logdet(@Nonnull final INDArray mat) {
    if (mat.isScalar()) {
        return FastMath.log(FastMath.abs(mat.getDouble(0)));
    }
    if (!mat.isSquare()) {
        throw new IllegalArgumentException("Invalid array: must be square matrix");
    }
    final LUDecomposition decomp = new LUDecomposition(Nd4jApacheAdapterUtils.convertINDArrayToApacheMatrix(mat),
            DEFAULT_LU_DECOMPOSITION_SINGULARITY_THRESHOLD);
    // getL()/getU() return null for a singular decomposition, so check this first.
    if (!decomp.getSolver().isNonSingular()) {
        return Double.NEGATIVE_INFINITY;
    }
    // L has a unit diagonal, so its log-sum is exactly 0; only U contributes.
    final double[] diagU = diag(decomp.getU());
    return Arrays.stream(diagU).map(FastMath::abs).map(FastMath::log).sum();
}
 
Example 15
Source File: CpuNDArrayFactory.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
/**
 * Sorts all elements of {@code x} in place using the native flat sort.
 *
 * @param x          the array to sort; scalars are returned unchanged
 * @param descending {@code true} to sort in descending order
 * @return {@code x}
 * @throws RuntimeException if the native sort reports an error
 */
@Override
public INDArray sort(INDArray x, boolean descending) {
    if (x.isScalar())
        return x;

    NativeOpsHolder.getInstance().getDeviceNativeOps().sort(null,
            x.data().addressPointer(), (LongPointer) x.shapeInfoDataBuffer().addressPointer(),
            null, null,
            descending);

    // Surface native failures instead of silently returning a possibly unsorted array.
    if (NativeOpsHolder.getInstance().getDeviceNativeOps().lastErrorCode() != 0)
        throw new RuntimeException(NativeOpsHolder.getInstance().getDeviceNativeOps().lastErrorMessage());

    return x;
}
 
Example 16
Source File: NativeOpExecutioner.java    From deeplearning4j with Apache License 2.0 4 votes vote down vote up
/**
 * Executes an index-reduction op ({@link IndexAccumulation}) on the CPU backend.
 *
 * <p>Steps:
 * <ul>
 *   <li>Normalize the op's reduction axes against the input rank.</li>
 *   <li>Allocate a {@code LONG} output of the reduction shape when no output was supplied
 *       (or the output is the input itself).</li>
 *   <li>Call the native scalar index-reduce when the output is a scalar, or the
 *       along-dimension variant otherwise.</li>
 * </ul>
 *
 * @param op the index-reduction op to run
 * @param oc op context, passed through to getX/getZ/setZ
 * @return the op's output array, as returned by {@code getZ(op, oc)}
 * @throws IllegalStateException if a supplied output array has the wrong shape
 * @throws RuntimeException      if the native call reports an error
 */
public INDArray exec(IndexAccumulation op, OpContext oc) {
    checkForCompression(op);

    INDArray x = getX(op, oc);
    INDArray z = getZ(op, oc);

    // Lazily create the extra-pointer holder (extraz is presumably a ThreadLocal; confirm).
    if (extraz.get() == null)
        extraz.set(new PointerPointer(32));

    // Resolve negative axes into [0, rank).
    val dimension = Shape.normalizeAxis(x.rank(), op.dimensions().toIntVector());

    // An empty input may only be reduced along axes whose length is non-zero.
    if (x.isEmpty()) {
        for (val d:dimension) {
            Preconditions.checkArgument(x.shape()[d] != 0, "IndexReduce can't be issued along axis with 0 in shape");
        }
    }

    boolean keepDims = op.isKeepDims();
    long[] retShape = Shape.reductionShape(x, dimension, true, keepDims);

    // No output supplied, or the output aliases the input: allocate a fresh LONG result.
    // Otherwise the supplied output must already have the reduction shape.
    if(z == null || x == z) {
        val ret = Nd4j.createUninitialized(DataType.LONG, retShape);

        setZ(ret, op, oc);
        z = ret;
    } else if(!Arrays.equals(retShape, z.shape())){
        throw new IllegalStateException("Z array shape does not match expected return type for op " + op
                + ": expected shape " + Arrays.toString(retShape) + ", z.shape()=" + Arrays.toString(z.shape()));
    }

    op.validateDataTypes();

    // NOTE(review): dimensionAddress is never used below. The along-dimension call passes
    // op.dimensions()' own buffer instead.
    Pointer dimensionAddress = constantHandler.getConstantBuffer(dimension, DataType.INT).addressPointer();

    Pair<DataBuffer, DataBuffer> tadBuffers = tadManager.getTADOnlyShapeInfo(x, dimension);

    Pointer hostTadShapeInfo = tadBuffers.getFirst().addressPointer();

    // TAD offsets may be absent; pass a null pointer in that case.
    DataBuffer offsets = tadBuffers.getSecond();
    Pointer hostTadOffsets = offsets == null ? null : offsets.addressPointer();

    // Pack the TAD shape info and offsets into the extra-pointer array handed to native code.
    PointerPointer dummy = extraz.get().put(hostTadShapeInfo, hostTadOffsets);

    long st = profilingConfigurableHookIn(op, tadBuffers.getFirst());

    val xb = ((BaseCpuDataBuffer) x.data()).getOpaqueDataBuffer();
    val zb = ((BaseCpuDataBuffer) z.data()).getOpaqueDataBuffer();

    // A scalar output means a full reduction; otherwise reduce along the op's dimensions.
    if (z.isScalar()) {
        loop.execIndexReduceScalar(dummy, op.opNum(),
                    xb, (LongPointer) x.shapeInfoDataBuffer().addressPointer(), null,
                    getPointerForExtraArgs(op, x.dataType()),
                    zb, (LongPointer) z.shapeInfoDataBuffer().addressPointer(), null);
        } else {
            loop.execIndexReduce(dummy, op.opNum(),
                    xb, (LongPointer) x.shapeInfoDataBuffer().addressPointer(), null,
                    getPointerForExtraArgs(op, x.dataType()),
                    zb, (LongPointer) z.shapeInfoDataBuffer().addressPointer(), null,
                    ((BaseCpuDataBuffer) op.dimensions().data()).getOpaqueDataBuffer(), (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), null);
        }

    // Native errors are reported through a status code rather than an exception.
    if (loop.lastErrorCode() != 0)
        throw new RuntimeException(loop.lastErrorMessage());

    profilingConfigurableHookOut(op, oc, st);
    return getZ(op, oc);
}
 
Example 17
Source File: JCublasNDArrayFactory.java    From nd4j with Apache License 2.0 4 votes vote down vote up
/**
 * Sorts {@code x} in place on the device along the given dimensions, one
 * tensor-along-dimension (TAD) at a time. FLOAT, DOUBLE and HALF buffers are supported.
 *
 * @param x          the array to sort; scalars are returned unchanged
 * @param descending {@code true} to sort in descending order
 * @param dimension  the dimensions to sort along; the caller's array is not modified
 * @return {@code x}
 * @throws UnsupportedOperationException for any other data type
 */
@Override
public INDArray sort(INDArray x, boolean descending, int... dimension) {
    if (x.isScalar())
        return x;

    // Sort a copy so the caller's varargs array is not reordered as a side effect.
    final int[] sortedDims = dimension.clone();
    Arrays.sort(sortedDims);

    Nd4j.getExecutioner().push();

    Pair<DataBuffer, DataBuffer> tadBuffers = Nd4j.getExecutioner().getTADManager().getTADOnlyShapeInfo(x, sortedDims);

    CudaContext context = AtomicAllocator.getInstance().getFlowController().prepareAction(x);

    PointerPointer extraz = new PointerPointer(AtomicAllocator.getInstance().getHostPointer(x.shapeInfoDataBuffer()), // not used
            context.getOldStream(), AtomicAllocator.getInstance().getDeviceIdPointer());


    Pointer dimensionPointer = AtomicAllocator.getInstance()
            .getPointer(AtomicAllocator.getInstance().getConstantBuffer(sortedDims), context);

    if (x.data().dataType() == DataBuffer.Type.FLOAT) {
        nativeOps.sortTadFloat(extraz,
                (FloatPointer) AtomicAllocator.getInstance().getPointer(x, context),
                (LongPointer) AtomicAllocator.getInstance().getPointer(x.shapeInfoDataBuffer(), context),
                (IntPointer) dimensionPointer,
                sortedDims.length,
                (LongPointer) AtomicAllocator.getInstance().getPointer(tadBuffers.getFirst(), context),
                new LongPointerWrapper(AtomicAllocator.getInstance().getPointer(tadBuffers.getSecond(), context)),
                descending
        );
    } else if (x.data().dataType() == DataBuffer.Type.DOUBLE) {
        nativeOps.sortTadDouble(extraz,
                (DoublePointer) AtomicAllocator.getInstance().getPointer(x, context),
                (LongPointer) AtomicAllocator.getInstance().getPointer(x.shapeInfoDataBuffer(), context),
                (IntPointer) dimensionPointer,
                sortedDims.length,
                (LongPointer) AtomicAllocator.getInstance().getPointer(tadBuffers.getFirst(), context),
                new LongPointerWrapper(AtomicAllocator.getInstance().getPointer(tadBuffers.getSecond(), context)),
                descending
        );
    } else if (x.data().dataType() == DataBuffer.Type.HALF) {
        nativeOps.sortTadHalf(extraz,
                (ShortPointer) AtomicAllocator.getInstance().getPointer(x, context),
                (LongPointer) AtomicAllocator.getInstance().getPointer(x.shapeInfoDataBuffer(), context),
                (IntPointer) dimensionPointer,
                sortedDims.length,
                (LongPointer) AtomicAllocator.getInstance().getPointer(tadBuffers.getFirst(), context),
                new LongPointerWrapper(AtomicAllocator.getInstance().getPointer(tadBuffers.getSecond(), context)),
                descending
        );
    } else {
        throw new UnsupportedOperationException("Unknown dataType " + x.data().dataType());
    }

    AtomicAllocator.getInstance().getFlowController().registerAction(context, x);

    return x;
}
 
Example 18
Source File: JCublasNDArrayFactory.java    From nd4j with Apache License 2.0 4 votes vote down vote up
/**
 * Sorts every element of {@code x} in place on the device (a flat sort).
 * FLOAT, DOUBLE and HALF buffers are supported.
 *
 * @param x          the array to sort; scalars are returned unchanged
 * @param descending {@code true} to sort in descending order
 * @return {@code x}
 * @throws UnsupportedOperationException for any other data type
 */
@Override
public INDArray sort(INDArray x, boolean descending) {
    if (x.isScalar()) {
        return x;
    }

    Nd4j.getExecutioner().push();

    final AtomicAllocator allocator = AtomicAllocator.getInstance();
    final CudaContext context = allocator.getFlowController().prepareAction(x);
    final Pointer hostShapeInfo = allocator.getHostPointer(x.shapeInfoDataBuffer());

    // Positional "extra" pointer array handed to the native sort; slot order is significant.
    final PointerPointer extras = new PointerPointer(hostShapeInfo, // 0
            context.getOldStream(), // 1
            allocator.getDeviceIdPointer(), // 2
            context.getBufferAllocation(), // 3
            context.getBufferReduction(), // 4
            context.getBufferScalar(), // 5
            context.getBufferSpecial(), // 6
            hostShapeInfo, // 7
            allocator.getHostPointer(x.shapeInfoDataBuffer()), // 8
            hostShapeInfo, // 9
            hostShapeInfo, // 10
            hostShapeInfo, // 11
            hostShapeInfo, // 12
            hostShapeInfo, // 13
            hostShapeInfo, // 14
            hostShapeInfo, // special pointer for IsMax  // 15
            hostShapeInfo, // special pointer for IsMax  // 16
            hostShapeInfo, // special pointer for IsMax // 17
            new CudaPointer(0));

    // Non-view arrays with more than 10M elements go to radix sort. All pending work must
    // be finished before that happens.
    if (!x.isView() && x.lengthLong() > 1024 * 1024 * 10) {
        Nd4j.getExecutioner().commit();
    }

    final DataBuffer.Type dtype = x.data().dataType();
    if (dtype == DataBuffer.Type.FLOAT) {
        nativeOps.sortFloat(extras,
                (FloatPointer) allocator.getPointer(x, context),
                (LongPointer) allocator.getPointer(x.shapeInfoDataBuffer(), context),
                descending);
    } else if (dtype == DataBuffer.Type.DOUBLE) {
        nativeOps.sortDouble(extras,
                (DoublePointer) allocator.getPointer(x, context),
                (LongPointer) allocator.getPointer(x.shapeInfoDataBuffer(), context),
                descending);
    } else if (dtype == DataBuffer.Type.HALF) {
        nativeOps.sortHalf(extras,
                (ShortPointer) allocator.getPointer(x, context),
                (LongPointer) allocator.getPointer(x.shapeInfoDataBuffer(), context),
                descending);
    } else {
        throw new UnsupportedOperationException("Unknown dataType " + dtype);
    }

    allocator.getFlowController().registerAction(context, x);

    return x;
}
 
Example 19
Source File: BaseComplexNDArray.java    From nd4j with Apache License 2.0 4 votes vote down vote up
/**
 * Performs a matrix multiplication of this array by {@code other}, storing the product in
 * {@code result}.
 *
 * <p>Behavior by input shape:
 * <ul>
 *   <li>Inputs with more than two dimensions are multiplied slice by slice.</li>
 *   <li>Scalar operands fall back to element-wise scaling.</li>
 *   <li>Otherwise BLAS gemv (single-column {@code other}) or gemm is used.</li>
 * </ul>
 * BLAS cannot multiply in place, so when {@code result} aliases either operand the product
 * is computed into a temporary and then copied into {@code result}.
 *
 * @param other  the other matrix to perform matrix multiply with
 * @param result the result ndarray
 * @return the result of the matrix multiplication
 */
@Override
public IComplexNDArray mmuli(INDArray other, INDArray result) {
    final IComplexNDArray otherArray = (IComplexNDArray) other;
    final IComplexNDArray resultArray = (IComplexNDArray) result;

    if (other.shape().length > 2) {
        // Higher-rank input: multiply matching slices independently.
        for (int s = 0; s < other.slices(); s++) {
            resultArray.putSlice(s, slice(s).mmul(otherArray.slice(s)));
        }
        return resultArray;
    }

    LinAlgExceptions.assertMultiplies(this, other);

    if (other.isScalar()) {
        return muli(otherArray.getComplex(0), resultArray);
    }
    if (isScalar()) {
        return otherArray.muli(getComplex(0), resultArray);
    }

    // Compute into a temporary when the destination is one of the BLAS inputs.
    final boolean aliased = result == this || result == other;
    final IComplexNDArray target = aliased ? Nd4j.createComplex(resultArray.shape()) : resultArray;

    if (otherArray.columns() == 1) {
        Nd4j.getBlasWrapper().level2().gemv(BlasBufferUtil.getCharForTranspose(target),
                BlasBufferUtil.getCharForTranspose(this), Nd4j.UNIT, this, otherArray, Nd4j.ZERO, target);
    } else {
        Nd4j.getBlasWrapper().level3().gemm(BlasBufferUtil.getCharForTranspose(target),
                BlasBufferUtil.getCharForTranspose(this), BlasBufferUtil.getCharForTranspose(other),
                Nd4j.UNIT, this, otherArray, Nd4j.ZERO, target);
    }

    if (aliased) {
        Nd4j.getBlasWrapper().copy(target, resultArray);
    }
    return resultArray;
}
 
Example 20
Source File: MtcnnService.java    From mtcnn-java with Apache License 2.0 4 votes vote down vote up
/**
 * STAGE 2: refines the candidate boxes from the previous stage with the refinement
 * network ("rnet").
 *
 * <p>The Python snippets in the comments show the statement each Java line corresponds to.
 * They appear to come from the reference implementation this code was ported from.
 *
 * <p>Steps:
 * <ol>
 *   <li>Build the rnet input batch via {@code computeTempImage} (size argument 24).</li>
 *   <li>Run the graph.</li>
 *   <li>Keep boxes whose score exceeds {@code stepsThreshold[1]}.</li>
 *   <li>Apply non-maximum suppression (0.7, Union), {@code bbreg} and {@code rerec}.</li>
 * </ol>
 *
 * @param image      the input image
 * @param totalBoxes candidate boxes from the previous stage, one row per box; may be empty
 * @param padResult  padding result passed through to {@code computeTempImage}
 * @return {@code totalBoxes} unchanged if it has no rows, {@code Nd4j.empty()} if no box
 *         passes the score threshold, otherwise the refined boxes. These are built from
 *         columns 0-3 of {@code totalBoxes} plus the score as column 4, then passed through
 *         NMS, {@code bbreg} and {@code rerec}.
 * @throws IOException declared by this method; presumably raised by image preparation or
 *         graph execution (TODO confirm)
 */
private INDArray refinementStage(INDArray image, INDArray totalBoxes, MtcnnUtil.PadResult padResult) throws IOException {

	// num_boxes = total_boxes.shape[0]
	int numBoxes = totalBoxes.isEmpty() ? 0 : (int) totalBoxes.shape()[0];
	// if num_boxes == 0:
	//   return total_boxes, stage_status
	if (numBoxes == 0) {
		return totalBoxes;
	}

	// Build the batched network input for all candidate boxes.
	INDArray tempImg1 = computeTempImage(image, numBoxes, padResult, 24);

	//this.refineNetGraph.associateArrayWithVariable(tempImg1, this.refineNetGraph.variableMap().get("rnet/input"));
	//List<DifferentialFunction> refineNetResults = this.refineNetGraph.exec().getRight();
	//INDArray out0 = refineNetResults.stream().filter(df -> df.getOwnName().equalsIgnoreCase("rnet/fc2-2/fc2-2"))
	//		.findFirst().get().outputVariable().getArr();
	//INDArray out1 = refineNetResults.stream().filter(df -> df.getOwnName().equalsIgnoreCase("rnet/prob1"))
	//		.findFirst().get().outputVariable().getArr();

	Map<String, INDArray> resultMap = this.refineNetGraphRunner.run(Collections.singletonMap("rnet/input", tempImg1));
	// out0 is indexed by box below (rows = boxes) and feeds bbreg; presumably box-regression offsets.
	//INDArray out0 = resultMap.get("rnet/fc2-2/fc2-2");  // for ipazc/mtcnn model
	INDArray out0 = resultMap.get("rnet/conv5-2/conv5-2");
	// out1: column 1 is used as the per-box score (presumably the face probability).
	INDArray out1 = resultMap.get("rnet/prob1");

	//  score = out1[1, :]
	INDArray score = out1.get(all(), point(1)).transposei();

	// ipass = np.where(score > self.__steps_threshold[1])
	INDArray ipass = MtcnnUtil.getIndexWhereVector(score.transpose(), s -> s > stepsThreshold[1]);
	//INDArray ipass = MtcnnUtil.getIndexWhereVector2(score.transpose(), Conditions.greaterThan(stepsThreshold[1]));

	// No box passed the threshold: nothing left to refine.
	if (ipass.isEmpty()) {
		totalBoxes = Nd4j.empty();
		return totalBoxes;
	}
	// total_boxes = np.hstack([total_boxes[ipass[0], 0:4].copy(), np.expand_dims(score[ipass].copy(), 1)])
	INDArray b1 = totalBoxes.get(new SpecifiedIndex(ipass.toLongVector()), interval(0, 4));
	// A single passing index is reshaped to an explicit (1, 1) column so it can be hstacked with b1.
	INDArray b2 = ipass.isScalar() ? score.get(ipass).reshape(1, 1)
			: Nd4j.expandDims(score.get(ipass), 1);
	totalBoxes = Nd4j.hstack(b1, b2);

	// mv = out0[:, ipass[0]]
	INDArray mv = out0.get(new SpecifiedIndex(ipass.toLongVector()), all()).transposei();

	// if total_boxes.shape[0] > 0:
	if (!totalBoxes.isEmpty() && totalBoxes.shape()[0] > 0) {
		// pick = self.__nms(total_boxes, 0.7, 'Union')
		INDArray pick = MtcnnUtil.nonMaxSuppression(totalBoxes.dup(), 0.7, MtcnnUtil.NonMaxSuppressionType.Union).transpose();

		// total_boxes = total_boxes[pick, :]
		totalBoxes = totalBoxes.get(new SpecifiedIndex(pick.toLongVector()), all());

		// total_boxes = self.__bbreg(total_boxes.copy(), np.transpose(mv[:, pick]))
		totalBoxes = MtcnnUtil.bbreg(totalBoxes, mv.get(all(), new SpecifiedIndex(pick.toLongVector())).transpose());

		// total_boxes = self.__rerec(total_boxes.copy())
		totalBoxes = MtcnnUtil.rerec(totalBoxes, false);
	}

	return totalBoxes;
}