Java Code Examples for org.nd4j.linalg.api.buffer.util.DataTypeUtil#getDtypeFromContext()

The following examples show how to use org.nd4j.linalg.api.buffer.util.DataTypeUtil#getDtypeFromContext(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may check out the related API usage on the sidebar.
Example 1
Source File: BaseNDArrayList.java — from nd4j (Apache License 2.0), 6 votes
/**
 * Appends {@code aX} after the last stored element.
 *
 * <p>The backing array is created with 10 slots on first use. When every slot is already in
 * use, a capacity of {@code size * 2} is requested through {@code growCapacity}.
 *
 * @param aX the value to append
 * @return always {@code true}
 */
@Override
public boolean add(X aX) {
    if (container == null) {
        container = Nd4j.create(10);
    } else if (container.length() == size) {
        growCapacity(size * 2);
    }

    // Only a DOUBLE context dtype writes the double value; every other dtype writes the float value.
    boolean doublePrecision = DataTypeUtil.getDtypeFromContext() == DataBuffer.Type.DOUBLE;
    if (doublePrecision) {
        container.putScalar(size, aX.doubleValue());
    } else {
        container.putScalar(size, aX.floatValue());
    }

    size += 1;
    return true;
}
 
Example 2
Source File: BaseDataBuffer.java — from nd4j (Apache License 2.0), 6 votes
/**
 * Allocates a fresh pointer of {@code length()} elements and installs a matching indexer.
 *
 * <p>A {@code LONG} or {@code INT} request is honoured directly and also updates {@code type}.
 * For any other requested type, the element type is taken from the global context dtype
 * ({@code DataTypeUtil.getDtypeFromContext()}) rather than from {@code currentType}.
 *
 * @param currentType the requested buffer type
 */
public void pointerIndexerByGlobalType(Type currentType) {
    if (currentType == Type.LONG) {
        pointer = new LongPointer(length());
        setIndexer(LongRawIndexer.create((LongPointer) pointer));
        type = Type.LONG;
    } else if (currentType == Type.INT) {
        pointer = new IntPointer(length());
        setIndexer(IntIndexer.create((IntPointer) pointer));
        type = Type.INT;
    } else {
        // Read the global dtype once so every branch test sees the same value.
        Type globalType = DataTypeUtil.getDtypeFromContext();
        if (globalType == Type.DOUBLE) {
            pointer = new DoublePointer(length());
            // Install through setIndexer like every other branch, instead of assigning the field directly.
            setIndexer(DoubleIndexer.create((DoublePointer) pointer));
        } else if (globalType == Type.FLOAT) {
            pointer = new FloatPointer(length());
            setIndexer(FloatIndexer.create((FloatPointer) pointer));
        } else if (globalType == Type.LONG) {
            pointer = new LongPointer(length());
            setIndexer(LongIndexer.create((LongPointer) pointer));
        }
        // NOTE(review): the following behaviours look unintentional — confirm:
        //  - `type` is not updated on this path.
        //  - Any other global dtype leaves pointer/indexer untouched.
        //  - The LONG case here uses LongIndexer, whereas the explicit LONG request above uses LongRawIndexer.
    }
}
 
Example 3
Source File: BaseNDArrayList.java — from deeplearning4j (Apache License 2.0), 6 votes
/**
 * Appends {@code aX} to the end of this list and returns {@code true}.
 *
 * <p>Storage is allocated lazily (10 slots). Once {@code size} reaches the container length,
 * a capacity of {@code size * 2} is requested from {@code growCapacity}.
 *
 * @param aX the value to append
 * @return always {@code true}
 */
@Override
public boolean add(X aX) {
    boolean needsAllocation = container == null;
    if (needsAllocation) {
        container = Nd4j.create(10);
    } else {
        boolean full = size == container.length();
        if (full) {
            growCapacity(size * 2);
        }
    }

    // Store with double precision only when the global dtype is DOUBLE; otherwise store the float value.
    if (DataTypeUtil.getDtypeFromContext() != DataType.DOUBLE) {
        container.putScalar(size, aX.floatValue());
    } else {
        container.putScalar(size, aX.doubleValue());
    }

    size = size + 1;
    return true;
}
 
Example 4
Source File: BaseNDArrayList.java — from nd4j (Apache License 2.0), 5 votes
/**
 * Replaces the element at index {@code i} with {@code aX}.
 *
 * <p>The value is written as a double when the global context dtype is DOUBLE, and as a float
 * otherwise.
 *
 * @param i  index of the element to replace
 * @param aX the new value
 * @return the element previously stored at {@code i}, as required by the {@code List.set} contract
 *         (the old code returned {@code aX}, the new value)
 */
@Override
public X set(int i, X aX) {
    // Read the old value before overwriting it so it can be returned.
    // NOTE(review): presumably get(i) range-checks i against size — confirm.
    X previous = get(i);
    if (DataTypeUtil.getDtypeFromContext() == DataBuffer.Type.DOUBLE) {
        container.putScalar(i, aX.doubleValue());
    } else {
        container.putScalar(i, aX.floatValue());
    }
    return previous;
}
 
Example 5
Source File: BaseNDArrayList.java — from nd4j (Apache License 2.0), 5 votes
/**
 * Inserts {@code aX} at index {@code i} and increments {@code size}.
 *
 * <p>Capacity is ensured with the same policy as {@code add(X)}:
 * <ul>
 *   <li>lazy allocation of 10 slots, and</li>
 *   <li>a request for {@code size * 2} once every slot is used.</li>
 * </ul>
 * The previous {@code growCapacity(i)} call asked for a capacity of only {@code i}, which never
 * makes room for one more element when the list is full.
 *
 * @param i  insertion index (validated by {@code rangeCheck})
 * @param aX the value to insert
 */
@Override
public void add(int i, X aX) {
    rangeCheck(i);
    if (container == null) {
        container = Nd4j.create(10);
    } else if (size == container.length()) {
        growCapacity(size * 2);
    }
    // NOTE(review): presumably shifts elements at [i, size) up by one slot — confirm.
    moveForward(i);
    if (DataTypeUtil.getDtypeFromContext() == DataBuffer.Type.DOUBLE) {
        container.putScalar(i, aX.doubleValue());
    } else {
        container.putScalar(i, aX.floatValue());
    }

    size++;
}
 
Example 6
Source File: BaseNDArrayList.java — from deeplearning4j (Apache License 2.0), 5 votes
/**
 * Replaces the element at index {@code i} with {@code aX}.
 *
 * <p>The value is written as a double when the global context dtype is DOUBLE, and as a float
 * otherwise.
 *
 * @param i  index of the element to replace
 * @param aX the new value
 * @return the element previously stored at {@code i}, as required by the {@code List.set} contract
 *         (the old code returned {@code aX}, the new value)
 */
@Override
public X set(int i, X aX) {
    // Read the old value before overwriting it so it can be returned.
    // NOTE(review): presumably get(i) range-checks i against size — confirm.
    X previous = get(i);
    if (DataTypeUtil.getDtypeFromContext() == DataType.DOUBLE) {
        container.putScalar(i, aX.doubleValue());
    } else {
        container.putScalar(i, aX.floatValue());
    }
    return previous;
}
 
Example 7
Source File: BaseNDArrayList.java — from deeplearning4j (Apache License 2.0), 5 votes
/**
 * Inserts {@code aX} at index {@code i} and increments {@code size}.
 *
 * <p>Capacity is ensured with the same policy as {@code add(X)}:
 * <ul>
 *   <li>lazy allocation of 10 slots, and</li>
 *   <li>a request for {@code size * 2} once every slot is used.</li>
 * </ul>
 * The previous {@code growCapacity(i)} call asked for a capacity of only {@code i}, which never
 * makes room for one more element when the list is full.
 *
 * @param i  insertion index (validated by {@code rangeCheck})
 * @param aX the value to insert
 */
@Override
public void add(int i, X aX) {
    rangeCheck(i);
    if (container == null) {
        container = Nd4j.create(10);
    } else if (size == container.length()) {
        growCapacity(size * 2);
    }
    // NOTE(review): presumably shifts elements at [i, size) up by one slot — confirm.
    moveForward(i);
    if (DataTypeUtil.getDtypeFromContext() == DataType.DOUBLE) {
        container.putScalar(i, aX.doubleValue());
    } else {
        container.putScalar(i, aX.floatValue());
    }

    size++;
}
 
Example 8
Source File: GradCheckUtil.java — from nd4j (Apache License 2.0), 4 votes
/**
 * Numerically checks a SameDiff graph using central finite differences.
 *
 * <p>Each element of every array in {@code inputParameters} is perturbed to {@code w + epsilon}
 * and then to {@code w - epsilon}. The graph is evaluated at both points, and the numerical
 * derivative {@code (score(w + epsilon) - score(w - epsilon)) / (2 * epsilon)} is compared against
 * a reference value. Here "score" is the sum of the last array returned by
 * {@code SameDiff.eval(...)}.
 *
 * <p>The comparison uses the relative error {@code |a - n| / (|a| + |n|)}, defined as 0 when both
 * values are 0. An element fails when that error exceeds {@code maxRelError} or is NaN.
 *
 * <p>NOTE(review): the reference value is the summed last output of a copy of the graph
 * ({@code SameDiff.create(sameDiff)}) evaluated on the unperturbed inputs. The same value is used
 * for every element. That looks like a forward output rather than a per-element gradient, and
 * {@code wrt} is never used — confirm against the intended gradient subgraph.
 *
 * @param function        variable whose owning SameDiff graph is evaluated
 * @param wrt             currently unused
 * @param epsilon         perturbation size; must be in (0, 0.1]
 * @param maxRelError     largest acceptable relative error; must be in (0, 0.25]
 * @param print           if true, log a one-line summary after all elements are checked
 * @param inputParameters input arrays keyed by name; they are copied before perturbation and never modified
 * @return true if no checked element failed
 * @throws IllegalArgumentException if {@code epsilon} or {@code maxRelError} is out of range
 * @throws IllegalStateException    if the context data type is not DOUBLE
 */
public static boolean checkGradients(
        SDVariable function,
        SDVariable wrt,
        double epsilon,
        double maxRelError,
        boolean print,
        Map<String,INDArray> inputParameters) {
    //Basic sanity checks on input:
    if (epsilon <= 0.0 || epsilon > 0.1) {
        throw new IllegalArgumentException("Invalid epsilon: expect epsilon in range (0,0.1], usually 1e-4 or so");
    }
    if (maxRelError <= 0.0 || maxRelError > 0.25) {
        throw new IllegalArgumentException("Invalid maxRelativeError: " + maxRelError);
    }

    DataBuffer.Type dataType = DataTypeUtil.getDtypeFromContext();
    if (dataType != DataBuffer.Type.DOUBLE) {
        throw new IllegalStateException("Cannot perform gradient check: Datatype is not set to double precision ("
                + "is: " + dataType + "). Double precision must be used for gradient checks. Set "
                + "DataTypeUtil.setDTypeForContext(DataBuffer.Type.DOUBLE); before using GradientCheckUtil");
    }

    SameDiff sameDiff = function.getSameDiff();
    //get just the subgraph for the graph
    SameDiff opExec = SameDiff.create(sameDiff);

    // The reference value does not depend on which element is perturbed, so compute it once.
    double correctVal = lastOutputSum(opExec.eval(inputParameters));

    int totalNFailures = 0;
    long totalNParams = 0;
    double maxError = 0.0;

    for (Map.Entry<String,INDArray> entry : inputParameters.entrySet()) {
        // Perturb a copy so the caller's arrays are never modified.
        INDArray params = entry.getValue().dup();
        long nParams = params.length();
        totalNParams += nParams;

        // Same inputs as the caller's, except this entry is replaced by the perturbable copy.
        Map<String, INDArray> evalParams = new HashMap<>(inputParameters);
        evalParams.put(entry.getKey(), params);

        for (long i = 0; i < nParams; i++) {
            double origValue = params.getDouble(i);

            // Reduce each evaluation to a scalar immediately, before the next eval can run.
            params.putScalar(i, origValue + epsilon);
            double scorePlus = lastOutputSum(sameDiff.eval(evalParams));

            params.putScalar(i, origValue - epsilon);
            double scoreMinus = lastOutputSum(sameDiff.eval(evalParams));

            // Restore the element so perturbations do not accumulate across elements.
            params.putScalar(i, origValue);

            double numericalGradient = (scorePlus - scoreMinus) / (2.0 * epsilon);
            double relError = relativeError(correctVal, numericalGradient);
            maxError = Math.max(maxError, relError);
            // NaN never compares greater than maxRelError, so it has to be checked separately.
            if (relError > maxRelError || Double.isNaN(relError)) {
                totalNFailures++;
            }
        }
    }

    if (print) {
        long nPass = totalNParams - totalNFailures;
        log.info("GradientCheckUtil.checkGradients(): " + totalNParams + " params checked, " + nPass + " passed, "
                + totalNFailures + " failed. Largest relative error = " + maxError);
    }

    return totalNFailures == 0;
}

/** Returns the sum of the last array in {@code outputs} as a double. */
private static double lastOutputSum(INDArray[] outputs) {
    return outputs[outputs.length - 1].sumNumber().doubleValue();
}

/** Relative error {@code |a - b| / (|a| + |b|)}, defined as 0 when both values are 0. */
private static double relativeError(double a, double b) {
    if (a == 0.0 && b == 0.0) {
        return 0.0;
    }
    return Math.abs(a - b) / (Math.abs(a) + Math.abs(b));
}