Java Code Examples for org.nd4j.autodiff.samediff.SDVariable#dataType()

The following examples show how to use org.nd4j.autodiff.samediff.SDVariable#dataType(). You can vote up the examples you like or vote down the ones you don't, and you can go to the original project or source file by following the links above each example. Related API usage is listed in the sidebar.
Example 1
Source File: ExternalErrorsFunction.java    From deeplearning4j with Apache License 2.0 6 votes vote down vote up
/**
 * Returns one gradient placeholder variable per argument of this function, in {@code args()} order.
 * <p>
 * On the first call the placeholders are created and cached in {@code gradVariables}, keyed by argument name.
 * On later calls the cached placeholders are returned. (Previously an empty list was returned after the
 * first call, so repeated differentiation lost the gradient inputs.)
 *
 * @param f1 incoming gradients; not read by this implementation
 * @return the gradient placeholder for each argument
 */
@Override
public List<SDVariable> doDiff(List<SDVariable> f1) {
    List<SDVariable> out = new ArrayList<>();
    if (gradVariables == null) {
        gradVariables = new HashMap<>();
        for (SDVariable arg : args()) {
            INDArray gradArr = gradients.get(arg.name());
            DataType dt = arg.dataType();
            String n = getGradPlaceholderName();
            SDVariable grad;
            if (gradArr != null) {
                long[] shape = gradArr.shape().clone();
                // Leave dimension 0 unspecified (-1), presumably the batch dimension - TODO confirm.
                // Guarded so a rank-0 (scalar) gradient array does not throw ArrayIndexOutOfBoundsException.
                if (shape.length > 0) {
                    shape[0] = -1;
                }
                grad = sameDiff.var(n, VariableType.PLACEHOLDER, null, dt, shape);
            } else {
                grad = sameDiff.var(n, VariableType.PLACEHOLDER, null, dt);
            }
            gradVariables.put(arg.name(), grad);
            out.add(grad);
        }
    } else {
        // Placeholders already exist: return them in the same argument order instead of an empty list.
        for (SDVariable arg : args()) {
            out.add(gradVariables.get(arg.name()));
        }
    }
    return out;
}
 
Example 2
Source File: SDValidation.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
/**
 * Validate that the operation is being applied on a numerical SDVariable (not boolean or utf8).
 * Some operations (such as sum, norm2, add(Number) etc.) don't make sense when applied to boolean/utf8 arrays.
 *
 * @param opName    Operation name to print in the exception
 * @param inputName Name of the input to the op to validate
 * @param v         Variable to validate datatype for (input to operation); ignored if null
 * @throws IllegalStateException if {@code v} has a BOOL or UTF8 data type
 */
protected static void validateNumerical(String opName, String inputName, SDVariable v) {
    if (v == null) {
        return;
    }
    if (v.dataType() == DataType.BOOL || v.dataType() == DataType.UTF8) {
        // Message previously said "non-integer", which is wrong for a numerical check.
        throw new IllegalStateException("Input \"" + inputName + "\" for operation \"" + opName + "\" must be a numerical type; got variable \"" +
                v.name() + "\" with non-numerical data type " + v.dataType());
    }
}
 
Example 3
Source File: SDValidation.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
/**
 * Validate that the operation is being applied on a boolean type SDVariable.
 *
 * @param opName    Operation name to print in the exception
 * @param inputName Name of the input to the op to validate
 * @param v         Variable to validate datatype for (input to operation); ignored if null
 * @throws IllegalStateException if {@code v} does not have the BOOL data type
 */
protected static void validateBool(String opName, String inputName, SDVariable v) {
    if (v == null) {
        return;
    }
    if (v.dataType() != DataType.BOOL) {
        throw new IllegalStateException("Input \"" + inputName + "\" for operation \"" + opName + "\" must be a boolean variable; got variable \"" +
                v.name() + "\" with non-boolean data type " + v.dataType());
    }
}
 
Example 4
Source File: SDValidation.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
/**
 * Validate that the operation is being applied on a boolean type SDVariable.
 *
 * @param opName Operation name to print in the exception
 * @param v      Variable to validate datatype for (input to operation); ignored if null
 * @throws IllegalStateException if {@code v} does not have the BOOL data type
 */
protected static void validateBool(String opName, SDVariable v) {
    if (v == null) {
        return;
    }
    if (v.dataType() != DataType.BOOL) {
        // "point" was a copy-paste leftover from the floating point validator.
        throw new IllegalStateException("Cannot apply operation \"" + opName + "\" to variable \"" + v.name() + "\" with non-boolean data type " + v.dataType());
    }
}
 
Example 5
Source File: SDValidation.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
/**
 * Validate that the operation is being applied on a floating point type SDVariable.
 *
 * @param opName    Operation name to print in the exception
 * @param inputName Name of the input to the op to validate
 * @param v         Variable to validate datatype for (input to operation); ignored if null
 * @throws IllegalStateException if the data type of {@code v} is not a floating point type
 */
protected static void validateFloatingPoint(String opName, String inputName, SDVariable v) {
    if (v == null) {
        return;
    }
    if (!v.dataType().isFPType()) {
        throw new IllegalStateException("Input \"" + inputName + "\" for operation \"" + opName + "\" must be a floating point type; got variable \"" +
                v.name() + "\" with non-floating point data type " + v.dataType());
    }
}
 
Example 6
Source File: SDValidation.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
/**
 * Ensures an operation is only applied to a floating point SDVariable.
 * A null variable is accepted without any check.
 *
 * @param opName Operation name to print in the exception
 * @param v      Variable to validate datatype for (input to operation)
 * @throws IllegalStateException if the data type of {@code v} is not a floating point type
 */
protected static void validateFloatingPoint(String opName, SDVariable v) {
    boolean acceptable = v == null || v.dataType().isFPType();
    if (!acceptable) {
        throw new IllegalStateException("Cannot apply operation \"" + opName + "\" to variable \"" + v.name() + "\" with non-floating point data type " + v.dataType());
    }
}
 
Example 7
Source File: SDValidation.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
/**
 * Validate that every variable in {@code vars} has an integer data type.
 * Null entries, and a null array, are skipped.
 *
 * @param opName    Operation name to print in the exception
 * @param inputName Name of the input to the op to validate
 * @param vars      Variables to validate datatype for (inputs to operation)
 * @throws IllegalStateException on the first non-null variable whose data type is not an integer type
 */
protected static void validateInteger(String opName, String inputName, SDVariable[] vars) {
    if (vars == null) {
        return;
    }
    for (SDVariable v : vars) {
        // Skip null entries rather than returning: a 'return' here left every later element unvalidated.
        if (v == null) {
            continue;
        }
        if (!v.dataType().isIntType()) {
            throw new IllegalStateException("Input \"" + inputName + "\" for operation \"" + opName + "\" must be an integer type; got variable \"" +
                    v.name() + "\" with non-integer data type " + v.dataType());
        }
    }
}
 
Example 8
Source File: SDValidation.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
/**
 * Checks that an operation input is an integer-typed SDVariable; a null input passes unchecked.
 *
 * @param opName    Operation name to print in the exception
 * @param inputName Name of the input to the op to validate
 * @param v         Variable to validate datatype for (input to operation)
 * @throws IllegalStateException if {@code v} is non-null and its data type is not an integer type
 */
protected static void validateInteger(String opName, String inputName, SDVariable v) {
    boolean isInteger = v == null || v.dataType().isIntType();
    if (isInteger) {
        return;
    }
    throw new IllegalStateException("Input \"" + inputName + "\" for operation \"" + opName + "\" must be an integer type; got variable \"" +
            v.name() + "\" with non-integer data type " + v.dataType());
}
 
Example 9
Source File: SDValidation.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
/**
 * Rejects an integer-only operation when the given variable is not integer-typed.
 * A null variable is accepted without any check.
 *
 * @param opName Operation name to print in the exception
 * @param v      Variable to validate datatype for (input to operation)
 * @throws IllegalStateException if {@code v} is non-null and its data type is not an integer type
 */
protected static void validateInteger(String opName, SDVariable v) {
    if (v != null && !v.dataType().isIntType()) {
        throw new IllegalStateException("Cannot apply operation \"" + opName + "\" to variable \"" + v.name() + "\" with non-integer data type " + v.dataType());
    }
}
 
Example 10
Source File: SDValidation.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
/**
 * Validate that every variable in {@code vars} is numerical (not boolean or utf8).
 * Null entries, and a null array, are skipped.
 *
 * @param opName    Operation name to print in the exception
 * @param inputName Name of the input to the op to validate
 * @param vars      Variables to validate datatype for (inputs to operation)
 * @throws IllegalStateException on the first non-null variable with a BOOL or UTF8 data type
 */
protected static void validateNumerical(String opName, String inputName, SDVariable[] vars) {
    if (vars == null) {
        return;
    }
    for (SDVariable v : vars) {
        if (v == null) {
            continue;
        }
        if (v.dataType() == DataType.BOOL || v.dataType() == DataType.UTF8) {
            // Message previously said "non-integer", which is wrong for a numerical check.
            throw new IllegalStateException("Input \"" + inputName + "\" for operation \"" + opName + "\" must be a numerical type; got variable \"" +
                    v.name() + "\" with non-numerical data type " + v.dataType());
        }
    }
}
 
Example 11
Source File: ZerosLike.java    From deeplearning4j with Apache License 2.0 4 votes vote down vote up
/**
 * Convenience constructor that delegates to the full constructor.
 * The data type argument is taken from {@code input}.
 *
 * @param name     op name, forwarded unchanged
 * @param sameDiff SameDiff instance, forwarded unchanged
 * @param input    input variable; its {@code dataType()} is passed as the data type argument
 * @param inPlace  in-place flag, forwarded unchanged
 */
public ZerosLike(String name, SameDiff sameDiff, SDVariable input, boolean inPlace) {
    this(name,
            sameDiff,
            input,
            inPlace,
            input.dataType());
}
 
Example 12
Source File: SDValidation.java    From deeplearning4j with Apache License 2.0 4 votes vote down vote up
/**
 * Reports whether two variables have identical data types.
 *
 * @param x first variable (must be non-null)
 * @param y second variable (must be non-null)
 * @return {@code true} if {@code x.dataType()} and {@code y.dataType()} are the same value
 */
public static boolean isSameType(SDVariable x, SDVariable y) {
    boolean sameType = x.dataType() == y.dataType();
    return sameType;
}
 
Example 13
Source File: FusedBatchNorm.java    From deeplearning4j with Apache License 2.0 4 votes vote down vote up
/**
 * Creates the op with {@code dataFormat} and {@code isTraining} supplied as graph variables.
 * All five variables are passed to the superclass as op inputs. The output data type is copied from {@code x}.
 */
public FusedBatchNorm(@NonNull SameDiff sameDiff, @NonNull SDVariable x, @NonNull SDVariable scale,
                      @NonNull SDVariable offset, @NonNull SDVariable dataFormat, @NonNull SDVariable isTraining) {
    super("", sameDiff, new SDVariable[]{
            x, scale, offset, dataFormat, isTraining
    });
    outputDataType = x.dataType();
}
 
Example 14
Source File: FusedBatchNorm.java    From deeplearning4j with Apache License 2.0 4 votes vote down vote up
/**
 * Creates the op with {@code dataFormat} and {@code isTraining} supplied as integer arguments.
 * Only {@code x}, {@code scale} and {@code offset} become op inputs. The output data type is copied from {@code x}.
 */
public FusedBatchNorm(@NonNull SameDiff sameDiff, @NonNull SDVariable x, @NonNull SDVariable scale,
                      @NonNull SDVariable offset, int dataFormat, int isTraining) {
    super("", sameDiff, new SDVariable[]{
            x, scale, offset
    });
    addIArgument(dataFormat, isTraining);
    outputDataType = x.dataType();
}
 
Example 15
Source File: OnesLike.java    From deeplearning4j with Apache License 2.0 4 votes vote down vote up
/**
 * Convenience constructor that delegates to the full constructor.
 * The data type argument is taken from {@code input}.
 *
 * @param name     op name, forwarded unchanged
 * @param sameDiff SameDiff instance, forwarded unchanged
 * @param input    input variable; its {@code dataType()} is passed as the data type argument
 */
public OnesLike(String name, SameDiff sameDiff, SDVariable input) {
    this(name,
            sameDiff,
            input,
            input.dataType());
}
 
Example 16
Source File: Create.java    From deeplearning4j with Apache License 2.0 4 votes vote down vote up
/**
 * Convenience constructor that delegates to the full constructor with order {@code 'c'}.
 * NOTE(review): presumably C/row-major order; confirm against the delegate. The data type is taken from {@code input}.
 *
 * @param name       op name, forwarded unchanged
 * @param sameDiff   SameDiff instance, forwarded unchanged
 * @param input      input variable; its {@code dataType()} is passed as the data type argument
 * @param initialize initialize flag, forwarded unchanged
 */
public Create(String name, SameDiff sameDiff, SDVariable input, boolean initialize) {
    this(name,
            sameDiff,
            input,
            'c',
            initialize,
            input.dataType());
}
 
Example 17
Source File: ZerosLike.java    From deeplearning4j with Apache License 2.0 4 votes vote down vote up
/**
 * Convenience constructor with no op name and in-place disabled.
 * The data type argument is taken from {@code input}.
 *
 * @param sameDiff SameDiff instance, forwarded unchanged
 * @param input    input variable; its {@code dataType()} is passed as the data type argument
 */
public ZerosLike(SameDiff sameDiff, SDVariable input) {
    this(null,
            sameDiff,
            input,
            false,
            input.dataType());
}
 
Example 18
Source File: ZerosLike.java    From deeplearning4j with Apache License 2.0 4 votes vote down vote up
/**
 * Convenience constructor with in-place disabled.
 * The data type argument is taken from {@code input}.
 *
 * @param name     op name, forwarded unchanged
 * @param sameDiff SameDiff instance, forwarded unchanged
 * @param input    input variable; its {@code dataType()} is passed as the data type argument
 */
public ZerosLike(String name, SameDiff sameDiff, SDVariable input) {
    this(name,
            sameDiff,
            input,
            false,
            input.dataType());
}
 
Example 19
Source File: SDValidation.java    From deeplearning4j with Apache License 2.0 3 votes vote down vote up
/**
 * Rejects an operation applied to a boolean or utf8 SDVariable.
 * Operations such as sum, norm2 or add(Number) are only meaningful on numerical arrays.
 * A null variable is accepted without any check.
 *
 * @param opName Operation name to print in the exception
 * @param v      Variable to perform operation on
 * @throws IllegalStateException if {@code v} has a BOOL or UTF8 data type
 */
protected static void validateNumerical(String opName, SDVariable v) {
    if (v == null) {
        return;
    }
    boolean nonNumerical = v.dataType() == DataType.BOOL || v.dataType() == DataType.UTF8;
    if (nonNumerical) {
        throw new IllegalStateException("Cannot apply operation \"" + opName + "\" to variable \"" + v.name() + "\" with non-numerical data type " + v.dataType());
    }
}
 
Example 20
Source File: SDValidation.java    From deeplearning4j with Apache License 2.0 2 votes vote down vote up
/**
 * Validate that the operation is being applied on numerical SDVariables (not boolean or utf8).
 * Some operations (such as sum, norm2, add(Number) etc.) don't make sense when applied to boolean/utf8 arrays.
 * A null variable is skipped, consistent with the single-variable overloads (previously this threw an NPE).
 *
 * @param opName Operation name to print in the exception
 * @param v1     Variable to validate datatype for (input to operation); may be null
 * @param v2     Variable to validate datatype for (input to operation); may be null
 * @throws IllegalStateException if either non-null variable has a BOOL or UTF8 data type
 */
protected static void validateNumerical(String opName, SDVariable v1, SDVariable v2) {
    boolean v1NonNumerical = v1 != null && (v1.dataType() == DataType.BOOL || v1.dataType() == DataType.UTF8);
    boolean v2NonNumerical = v2 != null && (v2.dataType() == DataType.BOOL || v2.dataType() == DataType.UTF8);
    if (v1NonNumerical || v2NonNumerical) {
        // Build the message null-safely: the other variable may be absent.
        String name1 = v1 == null ? null : v1.name();
        String name2 = v2 == null ? null : v2.name();
        throw new IllegalStateException("Cannot perform operation \"" + opName + "\" on variables \"" + name1 + "\" and \"" +
                name2 + "\" if one or both variables are non-numerical: " + (v1 == null ? null : v1.dataType())
                + " and " + (v2 == null ? null : v2.dataType()));
    }
}