Java Code Examples for org.nd4j.linalg.api.buffer.DataBuffer#Type

The following examples show how to use org.nd4j.linalg.api.buffer.DataBuffer#Type. You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may check out the related API usage on the sidebar.
Example 1
Source File: ArrowSerde.java    From nd4j with Apache License 2.0 6 votes vote down vote up
/**
 * Create the nd4j data buffer type from the given Arrow tensor type,
 * relative to the type constants in the Arrow class:
 * {@link Type}
 *
 * <p>{@code Type.Decimal} and {@code Type.FloatingPoint} map to {@link DataBuffer.Type#FLOAT}
 * for 4-byte elements and {@link DataBuffer.Type#DOUBLE} for 8-byte elements;
 * {@code Type.Int} maps to {@link DataBuffer.Type#INT} (4 bytes) and
 * {@link DataBuffer.Type#LONG} (8 bytes).
 *
 * <p>NOTE(review): Arrow's Decimal is a fixed-point type; mapping it to FLOAT/DOUBLE purely by
 * element size may lose its scale - confirm this is intended.
 *
 * @param type the Arrow type byte to create the nd4j {@link DataBuffer.Type} from
 * @param elementSize the element size, in bytes
 * @return the data buffer type
 * @throws IllegalArgumentException if {@code type} is not Decimal, FloatingPoint or Int, or if
 *         {@code elementSize} is neither 4 nor 8
 */
public static DataBuffer.Type typeFromTensorType(byte type, int elementSize) {
    if (type == Type.Decimal || type == Type.FloatingPoint) {
        if (elementSize == 4) {
            return DataBuffer.Type.FLOAT;
        } else if (elementSize == 8) {
            return DataBuffer.Type.DOUBLE;
        }
    } else if (type == Type.Int) {
        if (elementSize == 4) {
            return DataBuffer.Type.INT;
        } else if (elementSize == 8) {
            return DataBuffer.Type.LONG;
        }
    } else {
        // FloatingPoint is accepted above, so the message must list it as well.
        throw new IllegalArgumentException(
                "Only valid types are Type.Decimal, Type.FloatingPoint and Type.Int, got type " + type);
    }

    // Reached only for a supported type with an unsupported element size.
    throw new IllegalArgumentException("Unable to determine data type for type " + type
            + " with element size " + elementSize);
}
 
Example 2
Source File: SporadicTests.java    From nd4j with Apache License 2.0 6 votes vote down vote up
@Test
public void randomStrangeTest() {
    // Remember the global data type so it is restored even if the body below throws.
    DataBuffer.Type type = Nd4j.dataType();
    try {
        DataTypeUtil.setDTypeForContext(DataBuffer.Type.DOUBLE);

        int a = 9;
        int b = 2;
        // An a-dimensional shape with every dimension of size b.
        int[] shapes = new int[a];
        Arrays.fill(shapes, b);
        INDArray c = Nd4j.linspace(1, (int) (100 * 1 + 1 + 2), (int) Math.pow(b, a)).reshape(shapes);
        c = c.sum(0);
        double[] d = c.data().asDouble();
        System.out.println("d: " + Arrays.toString(d));
    } finally {
        DataTypeUtil.setDTypeForContext(type);
    }
}
 
Example 3
Source File: OnnxGraphMapper.java    From nd4j with Apache License 2.0 6 votes vote down vote up
/**
 * Convert an ONNX tensor proto into an nd4j array built from the tensor's raw data bytes.
 *
 * <p>NOTE(review): only {@code raw_data} is read; tensors that store values in typed fields
 * (e.g. float_data) are not handled here - confirm callers never pass those. The bytes are
 * interpreted in native byte order.
 *
 * @param tensor the ONNX tensor; may be null
 * @return the array reshaped to the tensor's shape, or null if {@code tensor} is null
 */
public INDArray mapTensorProto(OnnxProto3.TensorProto tensor) {
    if (tensor == null)
        return null;

    DataBuffer.Type type = nd4jTypeFromOnnxType(tensor.getDataType());

    ByteString bytes = tensor.getRawData();
    ByteBuffer byteBuffer = bytes.asReadOnlyByteBuffer().order(ByteOrder.nativeOrder());
    // Size the copy by remaining(), not capacity(): the read-only view may share a larger
    // backing array, in which case capacity() overstates the number of tensor bytes.
    ByteBuffer directAlloc = ByteBuffer.allocateDirect(byteBuffer.remaining()).order(ByteOrder.nativeOrder());
    directAlloc.put(byteBuffer);
    directAlloc.rewind();
    long[] shape = getShapeFromTensor(tensor);
    DataBuffer buffer = Nd4j.createBuffer(directAlloc, type, ArrayUtil.prod(shape));
    return Nd4j.create(buffer).reshape(shape);
}
 
Example 4
Source File: Cast.java    From nd4j with Apache License 2.0 6 votes vote down vote up
/**
 * Set the given field of this op from a property value.
 *
 * <p>Only {@link DataBuffer.Type} values are applied; values of any other type are currently
 * ignored (see FIXME below).
 *
 * @param target the field to set
 * @param value the value to assign
 * @throws ND4JIllegalStateException if {@code value} is null or the field cannot be written
 */
@Override
public void setValueFor(Field target, Object value) {
    if (value == null) {
        throw new ND4JIllegalStateException("Unable to set field " + target + " using null value!");
    }

    // FIXME! values that are not DataBuffer.Type are silently ignored.
    if (!(value instanceof DataBuffer.Type))
        return;

    try {
        target.set(this, value);
    } catch (IllegalAccessException e) {
        // Surface the failure (with its cause) instead of printing it and leaving the field unset.
        throw new ND4JIllegalStateException("Unable to set field " + target + " to value " + value, e);
    }
}
 
Example 5
Source File: TFGraphMapper.java    From nd4j with Apache License 2.0 6 votes vote down vote up
/**
 * Determine the nd4j data type of a TensorFlow node from its {@code dtype}, {@code T} or
 * {@code Tidx} attribute, checked in that order of precedence.
 *
 * @param tensorProto the node to inspect
 * @return the matching nd4j data type, or {@link DataBuffer.Type#UNKNOWN} if the node has none
 *         of those attributes or its TensorFlow type has no mapping
 */
@Override
public DataBuffer.Type dataTypeForTensor(NodeDef tensorProto) {
    if(!tensorProto.containsAttr("dtype") && !tensorProto.containsAttr("Tidx") && !tensorProto.containsAttr("T"))
        return DataBuffer.Type.UNKNOWN;

    val type = tensorProto.containsAttr("dtype") ? tensorProto.getAttrOrThrow("dtype").getType()
            : tensorProto.containsAttr("T") ? tensorProto.getAttrOrThrow("T").getType() : tensorProto
            .getAttrOrThrow("Tidx").getType();
    switch(type) {
        case DT_DOUBLE: return DataBuffer.Type.DOUBLE;
        // NOTE(review): DT_INT64 is narrowed to INT rather than LONG - confirm this is intended.
        case DT_INT32:
        case DT_INT64: return DataBuffer.Type.INT;
        case DT_FLOAT: return DataBuffer.Type.FLOAT;
        // IEEE half precision corresponds directly to nd4j's 16-bit HALF type.
        case DT_HALF: return DataBuffer.Type.HALF;
        // NOTE(review): bfloat16 has a different bit layout from IEEE half - confirm mapping to HALF.
        case DT_BFLOAT16: return DataBuffer.Type.HALF;
        default: return DataBuffer.Type.UNKNOWN;
    }
}
 
Example 6
Source File: DefaultOpExecutioner.java    From nd4j with Apache License 2.0 5 votes vote down vote up
/**
 * Validate the data types of the given operation's operands.
 *
 * <p>Any of {@code x}, {@code y} or {@code z} whose buffer is COMPRESSED is first passed to
 * {@code decompressi}. Afterwards every non-null operand must have {@code expectedType};
 * an operand still reporting COMPRESSED is not rejected.
 *
 * @param expectedType the data type every operand is expected to have
 * @param op the operation whose operands are checked
 * @throws ND4JIllegalStateException if an operand's data type differs from {@code expectedType}
 */
public static void validateDataType(DataBuffer.Type expectedType, Op op) {
    // Decompress first (decompressi presumably works in place) so the checks below see the
    // real element type.
    if (op.x() != null && op.x().data().dataType() == DataBuffer.Type.COMPRESSED) {
        Nd4j.getCompressor().decompressi(op.x());
    }

    if (op.y() != null && op.y().data().dataType() == DataBuffer.Type.COMPRESSED) {
        Nd4j.getCompressor().decompressi(op.y());
    }

    if (op.z() != null && op.z().data().dataType() == DataBuffer.Type.COMPRESSED) {
        Nd4j.getCompressor().decompressi(op.z());
    }

    if (op.x() != null && op.x().data().dataType() != expectedType
                    && op.x().data().dataType() != DataBuffer.Type.COMPRESSED)
        throw new ND4JIllegalStateException("op.X dataType is [" + op.x().data().dataType()
                        + "] instead of expected [" + expectedType + "]");

    if (op.z() != null && op.z().data().dataType() != expectedType
                    && op.z().data().dataType() != DataBuffer.Type.COMPRESSED)
        throw new ND4JIllegalStateException("op.Z dataType is [" + op.z().data().dataType()
                        + "] instead of expected [" + expectedType + "]");

    // y gets the same COMPRESSED exemption as x and z so all operands are validated consistently.
    if (op.y() != null && op.y().data().dataType() != expectedType
                    && op.y().data().dataType() != DataBuffer.Type.COMPRESSED)
        throw new ND4JIllegalStateException("op.Y dataType is [" + op.y().data().dataType()
                        + "] instead of expected [" + expectedType + "]");
}
 
Example 7
Source File: BaseNDArrayFactory.java    From nd4j with Apache License 2.0 5 votes vote down vote up
/**
 * @param dtype the data opType
 * @param order the ordering
 */
protected BaseNDArrayFactory(DataBuffer.Type dtype, char order) {
    // this.dtype = dtype;
    if (Character.toLowerCase(order) != 'c' && Character.toLowerCase(order) != 'f')
        throw new IllegalArgumentException("Order must either be c or f");

    this.order = order;
}
 
Example 8
Source File: BaseNDArrayFactory.java    From nd4j with Apache License 2.0 5 votes vote down vote up
/**
 *
 * Initialize with the given data opType and ordering
 * The ndarray factory will use this for
 * @param dtype the data opType
 * @param order the ordering in mem
 */
protected BaseNDArrayFactory(DataBuffer.Type dtype, Character order) {
    // this.dtype = dtype;
    if (Character.toLowerCase(order) != 'c' && Character.toLowerCase(order) != 'f')
        throw new IllegalArgumentException("Order must either be c or f");

    this.order = order;
}
 
Example 9
Source File: AbstractCompressor.java    From nd4j with Apache License 2.0 5 votes vote down vote up
/**
 * Map a floating-point {@link DataBuffer.Type} to its {@link DataBuffer.TypeEx} counterpart.
 *
 * @param type the data type to convert; HALF, FLOAT or DOUBLE
 * @return FLOAT16, FLOAT or DOUBLE respectively
 * @throws IllegalStateException for any other type, including null
 */
protected static DataBuffer.TypeEx convertType(DataBuffer.Type type) {
    // Guard against null so it reaches the same IllegalStateException as other unsupported types
    // instead of the NPE a switch on null would raise.
    if (type != null) {
        switch (type) {
            case HALF:
                return DataBuffer.TypeEx.FLOAT16;
            case FLOAT:
                return DataBuffer.TypeEx.FLOAT;
            case DOUBLE:
                return DataBuffer.TypeEx.DOUBLE;
            default:
                break;
        }
    }
    throw new IllegalStateException("Unknown dataType: [" + type + "]");
}
 
Example 10
Source File: BaseNDArrayFactory.java    From nd4j with Apache License 2.0 4 votes vote down vote up
/**
 * Create an array of the given shape backed by a buffer created for the given data type.
 *
 * @param shape the shape of the array
 * @param dataType the data type of the backing buffer
 * @return the new array
 */
@Override
public INDArray create(int[] shape, DataBuffer.Type dataType) {
    DataBuffer backing = Nd4j.createBuffer(shape, dataType);
    return create(shape, backing);
}
 
Example 11
Source File: Nd4jWorkspace.java    From nd4j with Apache License 2.0 4 votes vote down vote up
/**
 * Allocate from this workspace using {@link MemoryKind#HOST} memory.
 *
 * @param requiredMemory the amount of memory to allocate
 * @param type the data type of the allocation
 * @param initialize forwarded unchanged to the four-argument {@code alloc}
 * @return the result of the four-argument {@code alloc}
 */
public PagedPointer alloc(long requiredMemory, DataBuffer.Type type, boolean initialize) {
    final MemoryKind kind = MemoryKind.HOST;
    return this.alloc(requiredMemory, kind, type, initialize);
}
 
Example 12
Source File: OpExecutionerTestsC.java    From nd4j with Apache License 2.0 4 votes vote down vote up
@Test
public void testVarianceSingleVsMultipleDimensions() {
    // this test should always run in double; remember the previous type so it is restored
    // even when an assertion fails
    DataBuffer.Type type = Nd4j.dataType();
    try {
        DataTypeUtil.setDTypeForContext(DataBuffer.Type.DOUBLE);
        Nd4j.getRandom().setSeed(12345);

        //Generate C order random numbers. Strides: [500,100,10,1]
        INDArray fourd = Nd4j.rand('c', new int[] {100, 5, 10, 10}).muli(10);
        INDArray twod = Shape.newShapeNoCopy(fourd, new int[] {100, 5 * 10 * 10}, false);

        //Population variance. These two should be identical
        INDArray var4 = fourd.var(false, 1, 2, 3);
        INDArray var2 = twod.var(false, 1);

        //Manual calculation of population variance, not bias corrected
        //https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Na.C3.AFve_algorithm
        double[] sums = new double[100];
        double[] sumSquares = new double[100];
        NdIndexIterator iter = new NdIndexIterator(fourd.shape());
        while (iter.hasNext()) {
            val next = iter.next();
            double d = fourd.getDouble(next);

            // Index along dimension 0; toIntExact fails loudly instead of silently truncating.
            int row = Math.toIntExact(next[0]);
            sums[row] += d;
            sumSquares[row] += d * d;
        }

        double[] manualVariance = new double[100];
        val N = (fourd.length() / sums.length);
        for (int i = 0; i < sums.length; i++) {
            manualVariance[i] = (sumSquares[i] - (sums[i] * sums[i]) / N) / N;
        }

        INDArray var4bias = fourd.var(true, 1, 2, 3);
        INDArray var2bias = twod.var(true, 1);

        assertArrayEquals(var2.data().asDouble(), var4.data().asDouble(), 1e-5);
        assertArrayEquals(manualVariance, var2.data().asDouble(), 1e-5);
        assertArrayEquals(var2bias.data().asDouble(), var4bias.data().asDouble(), 1e-5);
    } finally {
        DataTypeUtil.setDTypeForContext(type);
    }
}
 
Example 13
Source File: CpuNDArrayFactory.java    From nd4j with Apache License 2.0 4 votes vote down vote up
/**
 * Create the factory by delegating to the superclass constructor.
 *
 * @param dtype the data type, passed through to the superclass constructor
 * @param order the ordering, passed through to the superclass constructor
 */
public CpuNDArrayFactory(DataBuffer.Type dtype, char order) {
    super(dtype, order);
}
 
Example 14
Source File: CpuNDArrayFactory.java    From nd4j with Apache License 2.0 4 votes vote down vote up
/**
 * Create the factory by delegating to the superclass constructor.
 *
 * @param dtype the data type, passed through to the superclass constructor
 * @param order the ordering (boxed), passed through to the superclass constructor
 */
public CpuNDArrayFactory(DataBuffer.Type dtype, Character order) {
    super(dtype, order);
}
 
Example 15
Source File: CudaWorkspace.java    From nd4j with Apache License 2.0 4 votes vote down vote up
/**
 * Allocate from this workspace using {@link MemoryKind#DEVICE} memory.
 *
 * @param requiredMemory the amount of memory to allocate
 * @param type the data type of the allocation
 * @param initialize forwarded unchanged to the four-argument {@code alloc}
 * @return the result of the four-argument {@code alloc}
 */
@Override
public PagedPointer alloc(long requiredMemory, DataBuffer.Type type, boolean initialize) {
    final MemoryKind kind = MemoryKind.DEVICE;
    return alloc(requiredMemory, kind, type, initialize);
}
 
Example 16
Source File: AllocationShape.java    From nd4j with Apache License 2.0 4 votes vote down vote up
/**
 * Describe an allocation by its length, element size and data type.
 *
 * @param length the allocation length (presumably in elements - confirm against callers)
 * @param elementSize the size of a single element
 * @param dataType the data type of the allocation
 */
public AllocationShape(long length, int elementSize, DataBuffer.Type dataType) {
    this.dataType = dataType;
    this.elementSize = elementSize;
    this.length = length;
}
 
Example 17
Source File: JCublasNDArrayFactory.java    From nd4j with Apache License 2.0 4 votes vote down vote up
/**
 * Create the factory by delegating to the superclass constructor.
 *
 * @param dtype the data type, passed through to the superclass constructor
 * @param order the ordering (boxed), passed through to the superclass constructor
 */
public JCublasNDArrayFactory(DataBuffer.Type dtype, Character order) {
    super(dtype, order);
}
 
Example 18
Source File: BasicSerDeTests.java    From nd4j with Apache License 2.0 3 votes vote down vote up
@Test
public void testBasicDataTypeSwitch1() throws Exception {
    // Remember the global data type so it is restored even when an assertion fails.
    DataBuffer.Type initialType = Nd4j.dataType();
    try {
        Nd4j.setDataType(DataBuffer.Type.FLOAT);

        INDArray array = Nd4j.create(new float[] {1, 2, 3, 4, 5, 6});

        // Serialize while the global type is FLOAT...
        ByteArrayOutputStream bos = new ByteArrayOutputStream();
        Nd4j.write(bos, array);

        // ...then deserialize after switching the global type to DOUBLE.
        Nd4j.setDataType(DataBuffer.Type.DOUBLE);
        INDArray restored = Nd4j.read(new ByteArrayInputStream(bos.toByteArray()));

        assertEquals(Nd4j.create(new float[] {1, 2, 3, 4, 5, 6}), restored);

        // Both the data buffer and the shape-info buffer of the restored array use 8-byte elements.
        assertEquals(8, restored.data().getElementSize());
        assertEquals(8, restored.shapeInfoDataBuffer().getElementSize());
    } finally {
        Nd4j.setDataType(initialType);
    }
}
 
Example 19
Source File: MemoryWorkspace.java    From nd4j with Apache License 2.0 2 votes vote down vote up
/**
 * This method does allocation from a given Workspace
 *
 * @param requiredMemory allocation size, in bytes
 * @param dataType dataType that is going to be used
 * @param initialize presumably whether the allocated memory should be initialized (e.g. zeroed)
 *                   before being returned - TODO confirm against implementations
 * @return a {@link PagedPointer} to the allocated memory (assumed from the signature; confirm
 *         against implementations)
 */
PagedPointer alloc(long requiredMemory, DataBuffer.Type dataType, boolean initialize);
 
Example 20
Source File: NDArrayFactory.java    From nd4j with Apache License 2.0 2 votes vote down vote up
/**
 * Create an {@link INDArray} with the given shape and data type.
 *
 * @param shape the shape of the array to create
 * @param dataType the data type of the array's backing buffer
 * @return the created array
 */
INDArray create(int[] shape, DataBuffer.Type dataType);