Java Code Examples for org.nd4j.linalg.api.ndarray.INDArray#isRowVectorOrScalar()

The following examples show how to use org.nd4j.linalg.api.ndarray.INDArray#isRowVectorOrScalar().
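To make the shape rule concrete, here is a minimal plain-Java sketch of what "row vector or scalar" means for an array shape. It does not use ND4J. The class RowVectorCheck and its methods are hypothetical, and the rule it encodes is an approximation of ND4J's behavior that may differ in edge cases: single-element arrays (including rank 0) count as scalars, rank-1 arrays and rank-2 arrays with exactly one row count as row vectors, and empty arrays count as neither.

```java
import java.util.Arrays;

// Hypothetical, simplified shape rule approximating INDArray#isRowVectorOrScalar().
// It works on a plain long[] shape and is NOT ND4J's implementation.
final class RowVectorCheck {

    private RowVectorCheck() {
    }

    // Number of elements implied by a shape; an empty shape (rank 0) holds one element.
    static long length(long[] shape) {
        long n = 1;
        for (long d : shape) {
            n *= d;
        }
        return n;
    }

    static boolean isRowVectorOrScalar(long[] shape) {
        long len = length(shape);
        if (len == 0) {
            return false;                  // empty arrays: treated as neither (assumption)
        }
        if (len == 1) {
            return true;                   // scalar-like: [], [1], [1, 1], ...
        }
        if (shape.length == 1) {
            return true;                   // rank-1 vector such as [5]
        }
        return shape.length == 2 && shape[0] == 1;   // one-row matrix such as [1, 5]
    }

    public static void main(String[] args) {
        long[][] shapes = {{}, {5}, {1, 5}, {5, 1}, {2, 3}, {1, 1}};
        for (long[] s : shapes) {
            System.out.println(Arrays.toString(s) + " -> " + isRowVectorOrScalar(s));
        }
    }
}
```

Running main prints true for [], [5], [1, 5] and [1, 1], and false for the column vector [5, 1] and the 2x3 matrix [2, 3]. Note that [5, 1] is rejected even though it holds the same number of values as [1, 5].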
Example 1
Source File: EvaluationBinary.java    From deeplearning4j with Apache License 2.0
/**
 * Create an EvaluationBinary instance with an optional decision threshold array.
 *
 * @param decisionThreshold Decision threshold for each output; may be null. Should be a row vector with length
 *                          equal to the number of outputs, with values in range 0 to 1. An array of 0.5 values is
 *                          equivalent to the default (no manually specified decision threshold).
 */
public EvaluationBinary(INDArray decisionThreshold) {
    if (decisionThreshold != null) {
        if (!decisionThreshold.isRowVectorOrScalar()) {
            throw new IllegalArgumentException(
                            "Decision threshold array must be a row vector; got array with shape "
                                            + Arrays.toString(decisionThreshold.shape()));
        }
        if (decisionThreshold.minNumber().doubleValue() < 0.0) {
            throw new IllegalArgumentException("Invalid decision threshold array: minimum value is less than 0");
        }
        if (decisionThreshold.maxNumber().doubleValue() > 1.0) {
            throw new IllegalArgumentException(
                            "Invalid decision threshold array: maximum value is greater than 1.0");
        }

        this.decisionThreshold = decisionThreshold;
    }
}
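The constructor above only validates the threshold array. As a minimal sketch of how per-output thresholds might then be applied, the plain-Java class below (BinaryThresholds, a hypothetical name) repeats the shape and range checks on a long[] shape and a double[] of values, then marks each output positive when its probability is at or above its threshold. The >= comparison is an assumption made for illustration; EvaluationBinary's own comparison is not shown in this excerpt.

```java
import java.util.Arrays;

// Hypothetical plain-Java sketch: validates per-output decision thresholds the way the
// EvaluationBinary constructor does, then applies them to one row of probabilities.
final class BinaryThresholds {

    private BinaryThresholds() {
    }

    // Simplified row-vector-or-scalar rule: rank 0, rank 1, or a rank-2 shape with one row.
    static boolean isRowVectorOrScalar(long[] shape) {
        return shape.length <= 1 || (shape.length == 2 && shape[0] == 1);
    }

    static void validate(long[] shape, double[] thresholds) {
        if (!isRowVectorOrScalar(shape)) {
            throw new IllegalArgumentException(
                    "Decision threshold array must be a row vector; got array with shape "
                            + Arrays.toString(shape));
        }
        for (double t : thresholds) {
            if (t < 0.0 || t > 1.0) {
                throw new IllegalArgumentException("Decision threshold out of range [0, 1]: " + t);
            }
        }
    }

    // Assumption: an output is predicted positive when its probability is >= its threshold.
    static boolean[] predict(double[] probabilities, double[] thresholds) {
        if (probabilities.length != thresholds.length) {
            throw new IllegalArgumentException("Expected one threshold per output");
        }
        boolean[] positive = new boolean[probabilities.length];
        for (int i = 0; i < probabilities.length; i++) {
            positive[i] = probabilities[i] >= thresholds[i];
        }
        return positive;
    }

    public static void main(String[] args) {
        double[] thresholds = {0.5, 0.8, 0.2};
        validate(new long[]{1, 3}, thresholds);
        boolean[] positive = predict(new double[]{0.6, 0.6, 0.1}, thresholds);
        System.out.println(Arrays.toString(positive)); // [true, false, false]
    }
}
```

With every threshold set to 0.5, predict reduces to a single fixed cut-off, which matches the Javadoc's note that an all-0.5 array is equivalent to the default.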
 
Example 2
Source File: RecordReaderMultiDataSetIterator.java    From deeplearning4j with Apache License 2.0
private int countLength(List<Writable> list, int from, int to) {
    int length = 0;
    for (int i = from; i <= to; i++) {
        Writable w = list.get(i);
        if (w instanceof NDArrayWritable) {
            INDArray a = ((NDArrayWritable) w).get();
            if (!a.isRowVectorOrScalar()) {
                throw new UnsupportedOperationException("Multiple writables present but NDArrayWritable is "
                                + "not a row vector. Can only concat row vectors with other writables. Shape: "
                                + Arrays.toString(a.shape()));
            }
            length += a.length();
        } else {
            // Assume all other writables hold a single value
            length++;
        }
    }

    return length;
}
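The counting rule in countLength can be illustrated without DataVec types. In this minimal sketch, ArrayEntry is a hypothetical stand-in for NDArrayWritable that only records a shape; any other object counts as one value, just as the original assumes for non-array writables. The shape rule is the same simplified approximation used elsewhere, not ND4J's code.

```java
import java.util.Arrays;
import java.util.List;

// Hypothetical sketch of the counting rule in countLength(...), without DataVec types.
final class ConcatLength {

    private ConcatLength() {
    }

    // Stand-in for NDArrayWritable: only the array's shape matters for counting.
    static final class ArrayEntry {
        final long[] shape;

        ArrayEntry(long... shape) {
            this.shape = shape;
        }
    }

    static boolean isRowVectorOrScalar(long[] shape) {
        return shape.length <= 1 || (shape.length == 2 && shape[0] == 1);
    }

    static long length(long[] shape) {
        long n = 1;
        for (long d : shape) {
            n *= d;
        }
        return n;
    }

    // Counts the values in entries[from..to] (both inclusive), as the original method does.
    static long countLength(List<?> entries, int from, int to) {
        long length = 0;
        for (int i = from; i <= to; i++) {
            Object entry = entries.get(i);
            if (entry instanceof ArrayEntry) {
                long[] shape = ((ArrayEntry) entry).shape;
                if (!isRowVectorOrScalar(shape)) {
                    throw new UnsupportedOperationException(
                            "Can only concat row vectors with other values. Shape: " + Arrays.toString(shape));
                }
                length += length(shape);   // each element of the row vector counts as one value
            } else {
                length++;                  // any other entry is assumed to be a single value
            }
        }
        return length;
    }

    public static void main(String[] args) {
        List<Object> record = Arrays.<Object>asList(3.5, new ArrayEntry(1, 4), "label", new ArrayEntry(2));
        System.out.println(countLength(record, 0, 3)); // 1 + 4 + 1 + 2 = 8
    }
}
```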
 
Example 3
Source File: Evaluation.java    From deeplearning4j with Apache License 2.0
/**
 * Create an evaluation instance with the specified cost array. A cost array can be used to bias the multi-class
 * predictions towards or away from certain classes: when a cost array is present, the predicted class is
 * argMax(cost * probability); otherwise it is argMax(probability).
 *
 * @param labels Labels for the output classes. May be null
 * @param costArray Row vector cost array. May be null
 */
public Evaluation(List<String> labels, INDArray costArray) {
    if (costArray != null && !costArray.isRowVectorOrScalar()) {
        throw new IllegalArgumentException("Invalid cost array: must be a row vector (got shape: "
                        + Arrays.toString(costArray.shape()) + ")");
    }
    if (costArray != null && costArray.minNumber().doubleValue() < 0.0) {
        throw new IllegalArgumentException("Invalid cost array: cost array values must be non-negative");
    }
    this.labelsList = labels;
    this.costArray = costArray == null ? null : costArray.castTo(DataType.FLOAT);
    this.topN = 1;
}
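The cost-array rule from the Javadoc can be sketched with plain arrays. CostWeightedArgMax is a hypothetical helper, not part of deeplearning4j. Ties resolve to the lowest class index in this sketch, which may not match ND4J's argMax.

```java
// Hypothetical sketch of the rule in the Javadoc above: with a cost array the predicted
// class is argMax(cost * probability); without one it is argMax(probability).
final class CostWeightedArgMax {

    private CostWeightedArgMax() {
    }

    // cost may be null; otherwise it must hold one non-negative weight per class.
    static int predict(double[] probabilities, double[] cost) {
        if (cost != null && cost.length != probabilities.length) {
            throw new IllegalArgumentException("Cost array must have one entry per class");
        }
        int best = 0;
        double bestScore = Double.NEGATIVE_INFINITY;
        for (int i = 0; i < probabilities.length; i++) {
            double score = (cost == null) ? probabilities[i] : cost[i] * probabilities[i];
            if (score > bestScore) {       // strict '>' keeps the lowest index on ties
                bestScore = score;
                best = i;
            }
        }
        return best;
    }

    public static void main(String[] args) {
        double[] probabilities = {0.5, 0.3, 0.2};
        System.out.println(predict(probabilities, null));                        // 0
        System.out.println(predict(probabilities, new double[]{1.0, 2.0, 1.0})); // 1 (scores 0.5, 0.6, 0.2)
    }
}
```

Doubling the weight of class 1 is enough to overturn the unweighted prediction here, which is the kind of bias the cost array is meant to introduce.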
 
Example 4
Source File: AdaGradUpdater.java    From deeplearning4j with Apache License 2.0
@Override
public void setStateViewArray(INDArray viewArray, long[] gradientShape, char gradientOrder, boolean initialize) {
    if (!viewArray.isRowVectorOrScalar())
        throw new IllegalArgumentException("Invalid input: expected a row vector");
    if (initialize)
        viewArray.assign(epsilon);
    this.historicalGradient = viewArray;
    //Reshape to match the expected shape of the input gradient arrays
    this.historicalGradient = Shape.newShapeNoCopy(this.historicalGradient, gradientShape, gradientOrder == 'f');
    if (historicalGradient == null)
        throw new IllegalStateException("Could not correctly reshape gradient view array");

    this.gradientReshapeOrder = gradientOrder;
}
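This updater receives its state as a flat row vector and reshapes it, without copying, to the gradient's shape in either 'c' (row-major) or 'f' (column-major) order. The minimal sketch below illustrates that idea with a hypothetical FlatStateView class over a plain double[]. It is not ND4J's Shape.newShapeNoCopy: it only handles the 2D case and simply rejects a size mismatch, and the epsilon value used for initialization is illustrative.

```java
import java.util.Arrays;

// Hypothetical sketch: a 2D view over a flat state buffer, without copying, in 'c'
// (row-major) or 'f' (column-major) order. Illustrates the idea only; this is not
// ND4J's Shape.newShapeNoCopy.
final class FlatStateView {
    private final double[] data;   // shared flat storage: the row-vector state array
    private final int rows;
    private final int cols;
    private final char order;      // 'c' or 'f'

    FlatStateView(double[] data, int rows, int cols, char order) {
        if ((long) rows * cols != data.length) {
            throw new IllegalStateException(
                    "Could not reshape " + data.length + " elements to " + rows + "x" + cols);
        }
        if (order != 'c' && order != 'f') {
            throw new IllegalArgumentException("order must be 'c' or 'f', got " + order);
        }
        this.data = data;
        this.rows = rows;
        this.cols = cols;
        this.order = order;
    }

    // Maps a (row, col) position in the view to an index in the flat buffer.
    int flatIndex(int r, int c) {
        if (r < 0 || r >= rows || c < 0 || c >= cols) {
            throw new IndexOutOfBoundsException("(" + r + ", " + c + ")");
        }
        return order == 'c' ? r * cols + c : c * rows + r;
    }

    double get(int r, int c) {
        return data[flatIndex(r, c)];
    }

    void set(int r, int c, double value) {
        data[flatIndex(r, c)] = value;   // writes go straight to the shared buffer
    }

    public static void main(String[] args) {
        double[] state = new double[6];
        Arrays.fill(state, 1e-6);                      // like viewArray.assign(epsilon); value is illustrative
        FlatStateView view = new FlatStateView(state, 2, 3, 'f');
        view.set(1, 0, 42.0);
        System.out.println(state[1]);                  // 42.0: in 'f' order, (1, 0) is flat index 1
    }
}
```

Because the view and the flat array share storage, updates made through the reshaped view are visible in the original row vector, which is what makes a no-copy reshape useful for updater state.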
 
Example 5
Source File: RmsPropUpdater.java    From deeplearning4j with Apache License 2.0
@Override
public void setStateViewArray(INDArray viewArray, long[] gradientShape, char gradientOrder, boolean initialize) {
    if (!viewArray.isRowVectorOrScalar())
        throw new IllegalArgumentException("Invalid input: expected a row vector");
    if (initialize)
        viewArray.assign(config.getEpsilon());
    this.lastGradient = viewArray;

    //Reshape to match the expected shape of the input gradient arrays
    this.lastGradient = Shape.newShapeNoCopy(this.lastGradient, gradientShape, gradientOrder == 'f');
    if (lastGradient == null)
        throw new IllegalStateException("Could not correctly reshape gradient view array");

    gradientReshapeOrder = gradientOrder;
}
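In RMSProp the state array holds a per-parameter running average of squared gradients; above it is initialized to epsilon. A minimal plain-Java sketch of the standard RMSProp step follows. The class name, the hyperparameter values, and the choice to add epsilon inside the square root are assumptions, since RmsPropUpdater's update code is not shown in this excerpt.

```java
import java.util.Arrays;

// Hypothetical sketch of a standard RMSProp step over flat arrays. The epsilon placement
// (inside the square root) and all names are assumptions, not RmsPropUpdater's code.
final class RmsPropSketch {
    private final double[] cache;      // running average of squared gradients (the state array)
    private final double learningRate;
    private final double decay;
    private final double epsilon;

    RmsPropSketch(int size, double learningRate, double decay, double epsilon) {
        this.cache = new double[size];
        Arrays.fill(cache, epsilon);   // mirrors viewArray.assign(epsilon) when initialize is true
        this.learningRate = learningRate;
        this.decay = decay;
        this.epsilon = epsilon;
    }

    // Returns the step to subtract from the parameters for this gradient.
    double[] step(double[] gradient) {
        if (gradient.length != cache.length) {
            throw new IllegalArgumentException("Gradient length does not match state length");
        }
        double[] update = new double[gradient.length];
        for (int i = 0; i < gradient.length; i++) {
            double g = gradient[i];
            cache[i] = decay * cache[i] + (1.0 - decay) * g * g;
            update[i] = learningRate * g / Math.sqrt(cache[i] + epsilon);
        }
        return update;
    }

    public static void main(String[] args) {
        RmsPropSketch rms = new RmsPropSketch(2, 0.01, 0.95, 1e-8);
        double[] params = {1.0, -1.0};
        double[] update = rms.step(new double[]{0.5, -2.0});
        for (int i = 0; i < params.length; i++) {
            params[i] -= update[i];
        }
        System.out.println(Arrays.toString(params));
    }
}
```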
 
Example 6
Source File: NesterovsUpdater.java    From deeplearning4j with Apache License 2.0
@Override
public void setStateViewArray(INDArray viewArray, long[] gradientShape, char gradientOrder, boolean initialize) {
    if (!viewArray.isRowVectorOrScalar())
        throw new IllegalArgumentException("Invalid input: expected a row vector");
    if (initialize)
        viewArray.assign(0);

    this.v = viewArray;

    //Reshape to match the expected shape of the input gradient arrays
    this.v = Shape.newShapeNoCopy(this.v, gradientShape, gradientOrder == 'f');
    if (v == null)
        throw new IllegalStateException("Could not correctly reshape gradient view array");
    this.gradientReshapeOrder = gradientOrder;
}
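Nesterov momentum keeps a velocity per parameter, and momentum velocities conventionally start at zero, matching viewArray.assign(0) above. The minimal sketch below uses a common reformulation of Nesterov momentum that updates the current parameters directly: vPrev = v; v = mu * v - lr * g; params += -mu * vPrev + (1 + mu) * v. The class name and hyperparameter values are hypothetical, and this is not presented as NesterovsUpdater's exact code.

```java
import java.util.Arrays;

// Hypothetical sketch of Nesterov momentum in a common reformulation:
//   vPrev = v;  v = mu * v - lr * g;  params += -mu * vPrev + (1 + mu) * v
// Names and hyperparameters are illustrative; this is not NesterovsUpdater's code.
final class NesterovSketch {
    private final double[] v;          // velocity (the state array), starts at 0
    private final double learningRate;
    private final double momentum;

    NesterovSketch(int size, double learningRate, double momentum) {
        this.v = new double[size];     // Java zero-fills new arrays, like viewArray.assign(0)
        this.learningRate = learningRate;
        this.momentum = momentum;
    }

    // Applies one step to params in place.
    void step(double[] params, double[] gradient) {
        if (params.length != v.length || gradient.length != v.length) {
            throw new IllegalArgumentException("Length mismatch");
        }
        for (int i = 0; i < v.length; i++) {
            double vPrev = v[i];
            v[i] = momentum * v[i] - learningRate * gradient[i];
            params[i] += -momentum * vPrev + (1.0 + momentum) * v[i];
        }
    }

    public static void main(String[] args) {
        NesterovSketch nesterov = new NesterovSketch(1, 0.1, 0.9);
        double[] params = {0.0};
        nesterov.step(params, new double[]{1.0});   // params is about -0.19
        nesterov.step(params, new double[]{1.0});   // params is about -0.461
        System.out.println(Arrays.toString(params));
    }
}
```

With a constant gradient the steps grow (0.19, then 0.271) as the velocity builds up, which is the accelerating behavior momentum is designed to produce.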