Java Code Examples for org.nd4j.linalg.api.buffer.DataType#SHORT

The following examples show how to use org.nd4j.linalg.api.buffer.DataType#SHORT. You can vote up the examples you like or vote down the ones you don't, and go to the original project or source file by following the links above each example. Related API usage is listed in the sidebar.
Example 1
Source File: TestNativeImageLoader.java    From deeplearning4j with Apache License 2.0 6 votes vote down vote up
/**
 * Verifies that the array returned by {@code NativeImageLoader.asMatrix(File)} has the
 * currently configured global ND4J data type, for each of FLOAT, HALF, SHORT and INT.
 * <p>
 * The global data type that was active on entry is restored in a {@code finally} block, so a
 * failing assertion or loader error cannot leak a modified global data type into other tests.
 *
 * @throws Exception if the test image cannot be resolved or loaded
 */
@Test
public void testDataTypes_2() throws Exception {
    val dtypes = new DataType[]{DataType.FLOAT, DataType.HALF, DataType.SHORT, DataType.INT};

    // remember the global data type so it can be put back afterwards
    val dt = Nd4j.dataType();

    try {
        for (val dtype : dtypes) {
            Nd4j.setDataType(dtype);
            int w3 = 123, h3 = 77;
            // third constructor argument is fixed to 1 here (presumably the channel count)
            val loader = new NativeImageLoader(h3, w3, 1);
            File f3 = new ClassPathResource("datavec-data-image/testimages/class0/2.jpg").getFile();
            val array = loader.asMatrix(f3);

            assertEquals(dtype, array.dataType());
        }
    } finally {
        // always restore, even when an assertion above failed
        Nd4j.setDataType(dt);
    }
}
 
Example 2
Source File: TestNativeImageLoader.java    From deeplearning4j with Apache License 2.0 6 votes vote down vote up
/**
 * Verifies that converting an {@code ImageWritable} to a matrix via
 * {@code NativeImageLoader.asMatrix(ImageWritable)} yields an array with the currently
 * configured global ND4J data type, for each of FLOAT, HALF, SHORT and INT.
 * <p>
 * The global data type that was active on entry is restored in a {@code finally} block, so a
 * failing assertion or loader error cannot leak a modified global data type into other tests.
 *
 * @throws Exception if the test image cannot be resolved or loaded
 */
@Test
public void testDataTypes_1() throws Exception {
    val dtypes = new DataType[]{DataType.FLOAT, DataType.HALF, DataType.SHORT, DataType.INT};

    // remember the global data type so it can be put back afterwards
    val dt = Nd4j.dataType();

    try {
        for (val dtype : dtypes) {
            Nd4j.setDataType(dtype);
            int w3 = 123, h3 = 77, ch3 = 3;
            val loader = new NativeImageLoader(h3, w3, ch3);
            File f3 = new ClassPathResource("datavec-data-image/testimages/class0/2.jpg").getFile();
            ImageWritable iw3 = loader.asWritable(f3);

            val array = loader.asMatrix(iw3);

            assertEquals(dtype, array.dataType());
        }
    } finally {
        // always restore, even when an assertion above failed
        Nd4j.setDataType(dt);
    }
}
 
Example 3
Source File: ArrayOptionsHelper.java    From deeplearning4j with Apache License 2.0 6 votes vote down vote up
/**
 * Resolves the {@link DataType} encoded in an array-options value.
 * <p>
 * The type bits are probed in a fixed order (compressed, half, bfloat16, float, double, int,
 * long, bool, byte, short, utf8) and the first bit that is set decides the result. For the
 * int, long, byte and short bits, the unsigned bit selects the unsigned variant.
 *
 * @param opt options value to decode
 * @return the data type indicated by the first matching bit
 * @throws ND4JUnknownDataTypeException if none of the probed bits is set
 */
public static DataType dataType(long opt) {
    if (hasBitSet(opt, DTYPE_COMPRESSED_BIT)) {
        return DataType.COMPRESSED;
    }
    if (hasBitSet(opt, DTYPE_HALF_BIT)) {
        return DataType.HALF;
    }
    if (hasBitSet(opt, DTYPE_BFLOAT16_BIT)) {
        return DataType.BFLOAT16;
    }
    if (hasBitSet(opt, DTYPE_FLOAT_BIT)) {
        return DataType.FLOAT;
    }
    if (hasBitSet(opt, DTYPE_DOUBLE_BIT)) {
        return DataType.DOUBLE;
    }
    if (hasBitSet(opt, DTYPE_INT_BIT)) {
        if (hasBitSet(opt, DTYPE_UNSIGNED_BIT)) {
            return DataType.UINT32;
        }
        return DataType.INT;
    }
    if (hasBitSet(opt, DTYPE_LONG_BIT)) {
        if (hasBitSet(opt, DTYPE_UNSIGNED_BIT)) {
            return DataType.UINT64;
        }
        return DataType.LONG;
    }
    if (hasBitSet(opt, DTYPE_BOOL_BIT)) {
        return DataType.BOOL;
    }
    if (hasBitSet(opt, DTYPE_BYTE_BIT)) {
        // the byte bit is shared by BYTE and UBYTE; the unsigned bit tells them apart
        if (hasBitSet(opt, DTYPE_UNSIGNED_BIT)) {
            return DataType.UBYTE;
        }
        return DataType.BYTE;
    }
    if (hasBitSet(opt, DTYPE_SHORT_BIT)) {
        if (hasBitSet(opt, DTYPE_UNSIGNED_BIT)) {
            return DataType.UINT16;
        }
        return DataType.SHORT;
    }
    if (hasBitSet(opt, DTYPE_UTF8_BIT)) {
        return DataType.UTF8;
    }

    throw new ND4JUnknownDataTypeException("Unknown extras set: [" + opt + "]");
}
 
Example 4
Source File: TensorflowConversion.java    From deeplearning4j with Apache License 2.0 6 votes vote down vote up
/**
 * Maps a TensorFlow type code to the corresponding ND4J {@link DataType}.
 *
 * @param tensorflowType one of the {@code DT_*} type codes
 * @return the matching ND4J data type
 * @throws IllegalArgumentException if the code is not one of the handled {@code DT_*} values
 */
private DataType typeFor(int tensorflowType) {
    final DataType mapped;
    switch (tensorflowType) {
        // floating point
        case DT_DOUBLE:
            mapped = DataType.DOUBLE;
            break;
        case DT_FLOAT:
            mapped = DataType.FLOAT;
            break;
        case DT_HALF:
            mapped = DataType.HALF;
            break;
        case DT_BFLOAT16:
            mapped = DataType.BFLOAT16;
            break;
        // signed integers
        case DT_INT8:
            mapped = DataType.BYTE;
            break;
        case DT_INT16:
            mapped = DataType.SHORT;
            break;
        case DT_INT32:
            mapped = DataType.INT;
            break;
        case DT_INT64:
            mapped = DataType.LONG;
            break;
        // unsigned integers
        case DT_UINT8:
            mapped = DataType.UBYTE;
            break;
        case DT_UINT16:
            mapped = DataType.UINT16;
            break;
        case DT_UINT32:
            mapped = DataType.UINT32;
            break;
        case DT_UINT64:
            mapped = DataType.UINT64;
            break;
        // everything else
        case DT_STRING:
            mapped = DataType.UTF8;
            break;
        case DT_BOOL:
            mapped = DataType.BOOL;
            break;
        default:
            throw new IllegalArgumentException("Illegal type " + tensorflowType);
    }
    return mapped;
}
 
Example 5
Source File: SpecialTests.java    From deeplearning4j with Apache License 2.0 6 votes vote down vote up
/**
 * Casts a 100x100 array of ones between every ordered pair of the listed data types while the
 * workspace named "WS" is activated, and checks each result against an array of ones created
 * directly in the target type.
 */
@Test
public void reproduceWorkspaceCrash_4() {
    val conf = WorkspaceConfiguration.builder().build();

    // obtained up front with the explicit configuration; the handle itself is not used below
    val ws = Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(conf, "WS");

    val dtypes = new DataType[]{
            DataType.LONG, DataType.DOUBLE, DataType.FLOAT, DataType.HALF, DataType.INT,
            DataType.SHORT, DataType.BYTE, DataType.UBYTE, DataType.BOOL};

    for (val sourceType : dtypes) {
        for (val targetType : dtypes) {
            try (val ws2 = Nd4j.getWorkspaceManager().getAndActivateWorkspace("WS")) {
                val ones = Nd4j.create(sourceType, 100, 100).assign(1);

                val casted = ones.castTo(targetType);

                val expected = Nd4j.create(targetType, 100, 100).assign(1);
                assertEquals(expected, casted);
            }
        }
    }
}
 
Example 6
Source File: SpecialTests.java    From deeplearning4j with Apache License 2.0 6 votes vote down vote up
/**
 * Casts a 2x5 array of ones between every ordered pair of the listed data types inside a scope
 * entered on the "WS" workspace, checks each result against an array of ones created directly
 * in the target type, and commits the executioner before the scope is closed.
 */
@Test
public void reproduceWorkspaceCrash_3() {
    val conf = WorkspaceConfiguration.builder().build();
    val ws = Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(conf, "WS");

    val dtypes = new DataType[]{
            DataType.DOUBLE, DataType.FLOAT, DataType.HALF, DataType.LONG, DataType.INT,
            DataType.SHORT, DataType.BYTE, DataType.UBYTE, DataType.BOOL};

    for (val sourceType : dtypes) {
        for (val targetType : dtypes) {
            try (val ws2 = ws.notifyScopeEntered()) {
                val ones = Nd4j.create(sourceType, 2, 5).assign(1);
                val casted = ones.castTo(targetType);
                val expected = Nd4j.create(targetType, 2, 5).assign(1);

                assertEquals(expected, casted);

                Nd4j.getExecutioner().commit();
            }
        }
    }
}
 
Example 7
Source File: ArrayOptionsHelper.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
/**
 * Translates a TensorFlow protobuf data type into the corresponding ND4J {@link DataType}.
 *
 * @param dataType TensorFlow data type to translate
 * @return the matching ND4J data type
 * @throws UnsupportedOperationException if the TensorFlow type is not one of the handled values
 */
public static DataType convertToDataType(org.tensorflow.framework.DataType dataType) {
    switch (dataType) {
        // floating point
        case DT_DOUBLE:   return DataType.DOUBLE;
        case DT_FLOAT:    return DataType.FLOAT;
        case DT_HALF:     return DataType.HALF;
        case DT_BFLOAT16: return DataType.BFLOAT16;

        // signed integers
        case DT_INT8:     return DataType.BYTE;
        case DT_INT16:    return DataType.SHORT;
        case DT_INT32:    return DataType.INT;
        case DT_INT64:    return DataType.LONG;

        // unsigned integers
        case DT_UINT8:    return DataType.UBYTE;
        case DT_UINT16:   return DataType.UINT16;
        case DT_UINT32:   return DataType.UINT32;
        case DT_UINT64:   return DataType.UINT64;

        // boolean and string
        case DT_BOOL:     return DataType.BOOL;
        case DT_STRING:   return DataType.UTF8;

        default:
            throw new UnsupportedOperationException("Unknown TF data type: [" + dataType.name() + "]");
    }
}
 
Example 8
Source File: JsonSerdeTests.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
/**
 * Round-trips a {@code TestClass} holding a 3x4 array through a Jackson {@code ObjectMapper}
 * (write to a JSON string, read back) and checks equality with the original.
 * The check is repeated for both array orders ('c' and 'f'), for three global default data
 * types, and for every array data type in the list below. UTF8 uses a fixed string array;
 * all other types are cast from a seeded random DOUBLE array.
 *
 * @throws Exception if serialization or deserialization fails
 */
@Test
public void testNDArrayTextSerializer() throws Exception {
    final char[] orders = {'c', 'f'};
    final DataType[] globalTypes = {DataType.DOUBLE, DataType.FLOAT, DataType.HALF};
    final DataType[] arrayTypes = {
            DataType.DOUBLE, DataType.FLOAT, DataType.HALF, DataType.LONG, DataType.INT,
            DataType.SHORT, DataType.BYTE, DataType.UBYTE, DataType.BOOL, DataType.UTF8};

    for (char order : orders) {
        Nd4j.factory().setOrder(order);

        for (DataType globalDT : globalTypes) {
            Nd4j.setDefaultDataTypes(globalDT, globalDT);

            // fixed seed so the source values are the same on every pass
            Nd4j.getRandom().setSeed(12345);
            INDArray in = Nd4j.rand(DataType.DOUBLE, 3, 4).muli(20).subi(10);

            val om = new ObjectMapper();

            for (DataType dt : arrayTypes) {
                INDArray arr = (dt == DataType.UTF8)
                        ? Nd4j.create("aaaaa", "bbbb", "ccc", "dd", "e", "f", "g", "h", "i", "j", "k", "l").reshape('c', 3, 4)
                        : in.castTo(dt);

                TestClass tc = new TestClass(arr);
                String json = om.writeValueAsString(tc);

                TestClass deserialized = om.readValue(json, TestClass.class);
                assertEquals(dt.toString(), tc, deserialized);
            }
        }
    }
}
 
Example 9
Source File: SpecialTests.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
/**
 * Casts a 2x5 array of ones between every ordered pair of the listed data types and checks
 * each result against an array of ones created directly in the target type.
 */
@Test
public void reproduceWorkspaceCrash_2() {
    val dtypes = new DataType[]{
            DataType.DOUBLE, DataType.FLOAT, DataType.HALF, DataType.LONG, DataType.INT,
            DataType.SHORT, DataType.BYTE, DataType.UBYTE, DataType.BOOL};

    for (val sourceType : dtypes) {
        for (val targetType : dtypes) {
            val ones = Nd4j.create(sourceType, 2, 5).assign(1);

            val casted = ones.castTo(targetType);

            val expected = Nd4j.create(targetType, 2, 5).assign(1);
            assertEquals(expected, casted);
        }
    }
}
 
Example 10
Source File: CustomOpsTests.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
/**
 * Runs the "size" custom op on a 100-element FLOAT array once per output data type and expects
 * the scalar output to equal 100 in that type. Failures are collected per data type so that
 * every type is tried before the test reports which ones failed.
 */
@Test
public void testSizeTypes() {
    final DataType[] outputTypes = {
            DataType.LONG, DataType.INT, DataType.SHORT, DataType.BYTE,
            DataType.UINT64, DataType.UINT32, DataType.UINT16, DataType.UBYTE,
            DataType.DOUBLE, DataType.FLOAT, DataType.HALF, DataType.BFLOAT16};

    List<DataType> failed = new ArrayList<>();

    for (DataType dt : outputTypes) {
        INDArray in = Nd4j.create(DataType.FLOAT, 100);
        INDArray out = Nd4j.scalar(dt, 0);
        INDArray e = Nd4j.scalar(dt, 100);

        DynamicCustomOp op = DynamicCustomOp.builder("size")
                .addInputs(in)
                .addOutputs(out)
                .build();

        try {
            Nd4j.exec(op);
            assertEquals(e, out);
        } catch (Throwable ignored) {
            // both execution errors and assertion failures just mark this type as failed
            failed.add(dt);
        }
    }

    if (failed.isEmpty()) {
        return;
    }
    fail("Failed datatypes: " + failed);
}
 
Example 11
Source File: PythonUtils.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
/**
 * Rebuilds a {@link NumpyArray} from its JSON metadata.
 * <p>
 * Reads the entries {@code "dtype"}, {@code "shape"}, {@code "address"} and, when present,
 * {@code "strides"}. Previously the stride array was filled from {@code "shape"} as well,
 * which hands the array's dimensions to the constructor in place of its strides.
 * NOTE(review): the key name {@code "strides"} is assumed to be what the producer of this
 * JSON emits - confirm against the serializing side. When the key is absent the old
 * behaviour (strides copied from {@code "shape"}) is kept so existing payloads still load.
 *
 * @param map JSON object describing the array
 * @return a NumpyArray constructed from the address, shape, strides and dtype with the
 *         final constructor flag set to {@code true}
 * @throws RuntimeException if the dtype name is not one of float64, float32, int16, int32, int64
 */
private static NumpyArray jsonToNumpyArray(JSONObject map) {
    String dtypeName = (String) map.get("dtype");
    DataType dtype;
    if (dtypeName.equals("float64")) {
        dtype = DataType.DOUBLE;
    } else if (dtypeName.equals("float32")) {
        dtype = DataType.FLOAT;
    } else if (dtypeName.equals("int16")) {
        dtype = DataType.SHORT;
    } else if (dtypeName.equals("int32")) {
        dtype = DataType.INT;
    } else if (dtypeName.equals("int64")) {
        dtype = DataType.LONG;
    } else {
        throw new RuntimeException("Unsupported array type " + dtypeName + ".");
    }

    List shapeList = map.getJSONArray("shape").toList();
    long[] shape = new long[shapeList.size()];
    for (int i = 0; i < shape.length; i++) {
        shape[i] = ((Number) shapeList.get(i)).longValue();
    }

    // Prefer the real strides; fall back to the legacy source only when they are missing.
    String strideKey = map.has("strides") ? "strides" : "shape";
    List strideList = map.getJSONArray(strideKey).toList();
    long[] stride = new long[strideList.size()];
    for (int i = 0; i < stride.length; i++) {
        stride[i] = ((Number) strideList.get(i)).longValue();
    }

    long address = ((Number) map.get("address")).longValue();
    return new NumpyArray(address, shape, stride, dtype, true);
}
 
Example 12
Source File: PythonUtils.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
/**
 * Rebuilds a {@link NumpyArray} from a metadata map.
 * <p>
 * Reads the entries {@code "dtype"}, {@code "shape"}, {@code "address"} and, when present,
 * {@code "strides"}. Previously the stride array was filled from {@code "shape"} as well,
 * which hands the array's dimensions to the constructor in place of its strides.
 * NOTE(review): the key name {@code "strides"} is assumed to be what the producer of this
 * map emits - confirm against the serializing side. When the key is absent the old
 * behaviour (strides copied from {@code "shape"}) is kept so existing payloads still load.
 * <p>
 * Numeric entries are read through {@link Number}, so both {@code Integer} and {@code Long}
 * values are accepted.
 *
 * @param map metadata describing the array
 * @return a NumpyArray constructed from the address, shape, strides and dtype with the
 *         final constructor flag set to {@code true}
 * @throws RuntimeException if the dtype entry is missing or is not one of float64, float32,
 *                          int16, int32, int64
 */
public static NumpyArray mapToNumpyArray(Map map) {
    String dtypeName = (String) map.get("dtype");
    DataType dtype;
    // constant-first comparisons: a missing dtype reaches the descriptive exception below
    if ("float64".equals(dtypeName)) {
        dtype = DataType.DOUBLE;
    } else if ("float32".equals(dtypeName)) {
        dtype = DataType.FLOAT;
    } else if ("int16".equals(dtypeName)) {
        dtype = DataType.SHORT;
    } else if ("int32".equals(dtypeName)) {
        dtype = DataType.INT;
    } else if ("int64".equals(dtypeName)) {
        dtype = DataType.LONG;
    } else {
        throw new RuntimeException("Unsupported array type " + dtypeName + ".");
    }

    List shapeList = (List) map.get("shape");
    long[] shape = new long[shapeList.size()];
    for (int i = 0; i < shape.length; i++) {
        shape[i] = ((Number) shapeList.get(i)).longValue();
    }

    // Prefer the real strides; fall back to the legacy source only when they are missing.
    Object strideSource = map.containsKey("strides") ? map.get("strides") : map.get("shape");
    List strideList = (List) strideSource;
    long[] stride = new long[strideList.size()];
    for (int i = 0; i < stride.length; i++) {
        stride[i] = ((Number) strideList.get(i)).longValue();
    }

    long address = ((Number) map.get("address")).longValue();
    return new NumpyArray(address, shape, stride, dtype, true);
}
 
Example 13
Source File: BaseCpuDataBuffer.java    From deeplearning4j with Apache License 2.0 4 votes vote down vote up
/**
 * Creates a buffer of {@code length} elements, allocates its native storage through
 * {@code OpaqueDataBuffer.allocateDataBuffer} and binds a typed pointer plus a matching
 * indexer according to {@code dataType()}.
 * <p>
 * There is no trailing {@code else}: for a data type that none of the branches below handles,
 * {@code pointer} and the indexer are left as they were.
 *
 * @param length      number of elements; must be at least 1
 * @param elementSize element size, stored narrowed to a {@code byte}; it is assigned after
 *                    {@code initTypeAndSize()} runs, so this argument is the value that is kept
 * @throws IllegalArgumentException if {@code length < 1}
 */
public BaseCpuDataBuffer(long length, int elementSize) {
    if (length < 1)
        throw new IllegalArgumentException("Length must be >= 1");
    // type setup first: dataType() is consulted repeatedly below
    initTypeAndSize();
    allocationMode = AllocUtil.getAllocationModeFromContext();
    this.length = length;
    this.underlyingLength = length;
    this.elementSize = (byte) elementSize;

    // UTF8 allocates its own storage in its branch below (as INT8), so it is skipped here
    if (dataType() != DataType.UTF8)
        ptrDataBuffer = OpaqueDataBuffer.allocateDataBuffer(length, dataType(), false);

    if (dataType() == DataType.DOUBLE) {
        pointer = new PagedPointer(ptrDataBuffer.primaryBuffer(), length).asDoublePointer();

        // NOTE(review): this branch assigns the field directly; every other branch uses setIndexer(...)
        indexer = DoubleIndexer.create((DoublePointer) pointer);
    } else if (dataType() == DataType.FLOAT) {
        pointer = new PagedPointer(ptrDataBuffer.primaryBuffer(), length).asFloatPointer();

        setIndexer(FloatIndexer.create((FloatPointer) pointer));
    } else if (dataType() == DataType.INT32) {
        pointer = new PagedPointer(ptrDataBuffer.primaryBuffer(), length).asIntPointer();

        setIndexer(IntIndexer.create((IntPointer) pointer));
    } else if (dataType() == DataType.LONG) {
        pointer = new PagedPointer(ptrDataBuffer.primaryBuffer(), length).asLongPointer();

        setIndexer(LongIndexer.create((LongPointer) pointer));
    } else if (dataType() == DataType.SHORT) {
        pointer = new PagedPointer(ptrDataBuffer.primaryBuffer(), length).asShortPointer();

        setIndexer(ShortIndexer.create((ShortPointer) pointer));
    } else if (dataType() == DataType.BYTE) {
        pointer = new PagedPointer(ptrDataBuffer.primaryBuffer(), length).asBytePointer();

        setIndexer(ByteIndexer.create((BytePointer) pointer));
    } else if (dataType() == DataType.UBYTE) {
        // same byte pointer as BYTE, but read through the unsigned indexer
        pointer = new PagedPointer(ptrDataBuffer.primaryBuffer(), length).asBytePointer();

        setIndexer(UByteIndexer.create((BytePointer) pointer));
    } else if (dataType() == DataType.UTF8) {
        // storage for UTF8 is allocated here as INT8 and accessed through a byte pointer
        ptrDataBuffer = OpaqueDataBuffer.allocateDataBuffer(length, INT8, false);
        pointer = new PagedPointer(ptrDataBuffer.primaryBuffer(), length).asBytePointer();

        setIndexer(ByteIndexer.create((BytePointer) pointer));
    } else if(dataType() == DataType.FLOAT16){
        // FLOAT16 and BFLOAT16 both sit on a short pointer; only the indexer differs
        pointer = new PagedPointer(ptrDataBuffer.primaryBuffer(), length).asShortPointer();
        setIndexer(HalfIndexer.create((ShortPointer) pointer));
    } else if(dataType() == DataType.BFLOAT16){
        pointer = new PagedPointer(ptrDataBuffer.primaryBuffer(), length).asShortPointer();
        setIndexer(Bfloat16Indexer.create((ShortPointer) pointer));
    } else if(dataType() == DataType.BOOL){
        pointer = new PagedPointer(ptrDataBuffer.primaryBuffer(), length).asBoolPointer();
        setIndexer(BooleanIndexer.create((BooleanPointer) pointer));
    } else if(dataType() == DataType.UINT16){
        pointer = new PagedPointer(ptrDataBuffer.primaryBuffer(), length).asShortPointer();
        setIndexer(UShortIndexer.create((ShortPointer) pointer));
    } else if(dataType() == DataType.UINT32){
        pointer = new PagedPointer(ptrDataBuffer.primaryBuffer(), length).asIntPointer();
        setIndexer(UIntIndexer.create((IntPointer) pointer));
    } else if (dataType() == DataType.UINT64) {
        // UINT64 is read through the same LongIndexer as LONG
        pointer = new PagedPointer(ptrDataBuffer.primaryBuffer(), length).asLongPointer();
        setIndexer(LongIndexer.create((LongPointer) pointer));
    }

    // hand this buffer to the deallocator service (presumably so native memory is released later - confirm)
    Nd4j.getDeallocatorService().pickObject(this);
}
 
Example 14
Source File: BaseCpuDataBuffer.java    From deeplearning4j with Apache License 2.0 4 votes vote down vote up
/**
 * Re-binds {@code pointer} and the indexer to the buffer currently reported by
 * {@code ptrDataBuffer.primaryBuffer()}, choosing the typed pointer and indexer from
 * {@code dataType()}. Returns without changes when both pointers are non-null and already
 * share the same address.
 *
 * @throws IllegalArgumentException if the data type is not one of the handled types
 */
public void actualizePointerAndIndexer() {
    val cptr = ptrDataBuffer.primaryBuffer();

    // nothing to refresh when the native address has not moved
    final boolean sameAddress = cptr != null && pointer != null && cptr.address() == pointer.address();
    if (sameAddress)
        return;

    val t = dataType();
    // a switch on null would throw NullPointerException; keep the original exception instead
    if (t == null)
        throw new IllegalArgumentException("Unknown datatype: " + dataType());

    switch (t) {
        case BOOL:
            pointer = new PagedPointer(cptr, length).asBoolPointer();
            setIndexer(BooleanIndexer.create((BooleanPointer) pointer));
            break;
        case UBYTE:
            pointer = new PagedPointer(cptr, length).asBytePointer();
            setIndexer(UByteIndexer.create((BytePointer) pointer));
            break;
        case BYTE:
            pointer = new PagedPointer(cptr, length).asBytePointer();
            setIndexer(ByteIndexer.create((BytePointer) pointer));
            break;
        case UINT16:
            pointer = new PagedPointer(cptr, length).asShortPointer();
            setIndexer(UShortIndexer.create((ShortPointer) pointer));
            break;
        case SHORT:
            pointer = new PagedPointer(cptr, length).asShortPointer();
            setIndexer(ShortIndexer.create((ShortPointer) pointer));
            break;
        case UINT32:
            pointer = new PagedPointer(cptr, length).asIntPointer();
            setIndexer(UIntIndexer.create((IntPointer) pointer));
            break;
        case INT:
            pointer = new PagedPointer(cptr, length).asIntPointer();
            setIndexer(IntIndexer.create((IntPointer) pointer));
            break;
        case UINT64:
        case LONG:
            // both 64-bit integer types are read through LongIndexer
            pointer = new PagedPointer(cptr, length).asLongPointer();
            setIndexer(LongIndexer.create((LongPointer) pointer));
            break;
        case BFLOAT16:
            pointer = new PagedPointer(cptr, length).asShortPointer();
            setIndexer(Bfloat16Indexer.create((ShortPointer) pointer));
            break;
        case HALF:
            pointer = new PagedPointer(cptr, length).asShortPointer();
            setIndexer(HalfIndexer.create((ShortPointer) pointer));
            break;
        case FLOAT:
            pointer = new PagedPointer(cptr, length).asFloatPointer();
            setIndexer(FloatIndexer.create((FloatPointer) pointer));
            break;
        case DOUBLE:
            pointer = new PagedPointer(cptr, length).asDoublePointer();
            setIndexer(DoubleIndexer.create((DoublePointer) pointer));
            break;
        case UTF8:
            // this branch sizes the pointer via the length() accessor rather than the field
            pointer = new PagedPointer(cptr, length()).asBytePointer();
            setIndexer(ByteIndexer.create((BytePointer) pointer));
            break;
        default:
            throw new IllegalArgumentException("Unknown datatype: " + dataType());
    }
}
 
Example 15
Source File: Int16Buffer.java    From deeplearning4j with Apache License 2.0 4 votes vote down vote up
/**
 * Sets this buffer's data type to {@link DataType#SHORT} and its element size to 2
 * (presumably bytes per element).
 */
@Override
protected void initTypeAndSize() {
    elementSize = 2;
    type = DataType.SHORT;
}
 
Example 16
Source File: CudaShortDataBuffer.java    From deeplearning4j with Apache License 2.0 4 votes vote down vote up
/**
 * Reports the element type of this buffer.
 *
 * @return always {@link DataType#SHORT}
 */
@Override
public DataType dataType() {
    return DataType.SHORT;
}
 
Example 17
Source File: CudaShortDataBuffer.java    From deeplearning4j with Apache License 2.0 4 votes vote down vote up
/**
 * Sets this buffer's data type to {@link DataType#SHORT} and its element size to 2
 * (presumably bytes per element).
 */
@Override
protected void initTypeAndSize() {
    type = DataType.SHORT;
    elementSize = 2;
}
 
Example 18
Source File: BaseCudaDataBuffer.java    From deeplearning4j with Apache License 2.0 4 votes vote down vote up
/**
 * Re-binds {@code pointer} and the indexer to the buffer currently reported by
 * {@code ptrDataBuffer.primaryBuffer()}, choosing the typed pointer and indexer from
 * {@code dataType()}. Returns without changes when both pointers are non-null and already
 * share the same address.
 *
 * @throws IllegalArgumentException if the data type is not one of the handled types
 */
public void actualizePointerAndIndexer() {
    val cptr = ptrDataBuffer.primaryBuffer();

    // skip update if pointers are equal
    if (cptr != null && pointer != null && cptr.address() == pointer.address()) {
        return;
    }

    val t = dataType();

    // --- bool / byte-pointer backed types ---
    if (t == DataType.BOOL) {
        pointer = new PagedPointer(cptr, length).asBoolPointer();
        setIndexer(BooleanIndexer.create((BooleanPointer) pointer));
        return;
    }
    if (t == DataType.UBYTE) {
        pointer = new PagedPointer(cptr, length).asBytePointer();
        setIndexer(UByteIndexer.create((BytePointer) pointer));
        return;
    }
    if (t == DataType.BYTE) {
        pointer = new PagedPointer(cptr, length).asBytePointer();
        setIndexer(ByteIndexer.create((BytePointer) pointer));
        return;
    }
    if (t == DataType.UTF8) {
        // this branch sizes the pointer via the length() accessor rather than the field
        pointer = new PagedPointer(cptr, length()).asBytePointer();
        setIndexer(ByteIndexer.create((BytePointer) pointer));
        return;
    }

    // --- short-pointer backed types ---
    if (t == DataType.UINT16) {
        pointer = new PagedPointer(cptr, length).asShortPointer();
        setIndexer(UShortIndexer.create((ShortPointer) pointer));
        return;
    }
    if (t == DataType.SHORT) {
        pointer = new PagedPointer(cptr, length).asShortPointer();
        setIndexer(ShortIndexer.create((ShortPointer) pointer));
        return;
    }
    if (t == DataType.BFLOAT16) {
        pointer = new PagedPointer(cptr, length).asShortPointer();
        setIndexer(Bfloat16Indexer.create((ShortPointer) pointer));
        return;
    }
    if (t == DataType.HALF) {
        pointer = new PagedPointer(cptr, length).asShortPointer();
        setIndexer(HalfIndexer.create((ShortPointer) pointer));
        return;
    }

    // --- int-pointer backed types ---
    if (t == DataType.UINT32) {
        pointer = new PagedPointer(cptr, length).asIntPointer();
        setIndexer(UIntIndexer.create((IntPointer) pointer));
        return;
    }
    if (t == DataType.INT) {
        pointer = new PagedPointer(cptr, length).asIntPointer();
        setIndexer(IntIndexer.create((IntPointer) pointer));
        return;
    }

    // --- long-pointer backed types: both are read through LongIndexer ---
    if (t == DataType.UINT64 || t == DataType.LONG) {
        pointer = new PagedPointer(cptr, length).asLongPointer();
        setIndexer(LongIndexer.create((LongPointer) pointer));
        return;
    }

    // --- floating point ---
    if (t == DataType.FLOAT) {
        pointer = new PagedPointer(cptr, length).asFloatPointer();
        setIndexer(FloatIndexer.create((FloatPointer) pointer));
        return;
    }
    if (t == DataType.DOUBLE) {
        pointer = new PagedPointer(cptr, length).asDoublePointer();
        setIndexer(DoubleIndexer.create((DoublePointer) pointer));
        return;
    }

    throw new IllegalArgumentException("Unknown datatype: " + dataType());
}
 
Example 19
Source File: ExecDebuggingListener.java    From deeplearning4j with Apache License 2.0 4 votes vote down vote up
/**
 * Renders an array as a Java expression string intended to recreate it.
 * <p>
 * Empty arrays become {@code Nd4j.empty(DataType.X);}. Otherwise the values are written as
 * arguments of {@code Nd4j.createFromArray(...)}, followed by {@code .reshape(...)} with the
 * array's shape. Float-like types get an {@code f} suffix per value and long-like types an
 * {@code L} suffix. Types that are written through a wider Java primitive (HALF, BFLOAT16,
 * UINT32, UINT64, SHORT, UBYTE, BYTE, UINT16, BOOL) get a trailing
 * {@code .castTo(DataType.X)}. This used to be emitted as {@code .cast(...)}; the change
 * assumes {@code castTo} is the INDArray method, as used elsewhere in ND4J code.
 * <p>
 * For UTF8, COMPRESSED and UNKNOWN no values are written, so the argument list stays empty.
 *
 * @param arr array to render
 * @return the generated expression text
 */
private static String createString(INDArray arr){
    StringBuilder sb = new StringBuilder();

    if(arr.isEmpty()){
        sb.append("Nd4j.empty(DataType.").append(arr.dataType()).append(");");
    } else {
        sb.append("Nd4j.createFromArray(");

        DataType dt = arr.dataType();
        switch (dt){
            case DOUBLE:
                double[] dArr = arr.dup().data().asDouble();
                // Arrays.toString wraps the values in [ ]; drop the brackets
                sb.append(Arrays.toString(dArr).replace("[", "").replace("]", ""));
                break;
            case FLOAT:
            case HALF:
            case BFLOAT16:
                float[] fArr = arr.dup().data().asFloat();
                // suffix every element with 'f': before each comma, and in place of the closing bracket
                sb.append(Arrays.toString(fArr)
                        .replace(",", "f,")
                        .replace("]", "f")
                        .replace("[", ""));
                break;
            case LONG:
            case UINT32:
            case UINT64:
                long[] lArr = arr.dup().data().asLong();
                // suffix every element with 'L': before each comma, and in place of the closing bracket
                sb.append(Arrays.toString(lArr)
                        .replace(",", "L,")
                        .replace("]", "L")
                        .replace("[", ""));
                break;
            case INT:
            case SHORT:
            case UBYTE:
            case BYTE:
            case UINT16:
            case BOOL:
                int[] iArr = arr.dup().data().asInt();
                sb.append(Arrays.toString(iArr).replace("[", "").replace("]", ""));
                break;
            case UTF8:
            case COMPRESSED:
            case UNKNOWN:
            default:
                // no literal form is produced for these types
                break;
        }

        sb.append(").reshape(").append(Arrays.toString(arr.shape()).replace("[", "").replace("]", ""))
                .append(")");

        if(dt == DataType.HALF || dt == DataType.BFLOAT16 || dt == DataType.UINT32 || dt == DataType.UINT64 ||
                dt == DataType.SHORT || dt == DataType.UBYTE || dt == DataType.BYTE || dt == DataType.UINT16 || dt == DataType.BOOL){
            // values were written via float/long/int literals, so convert back to the real type
            sb.append(".castTo(DataType.").append(arr.dataType()).append(")");
        }
    }

    return sb.toString();
}
 
Example 20
Source File: PythonObject.java    From deeplearning4j with Apache License 2.0 4 votes vote down vote up
/**
 * Wraps the underlying Python object as a {@link NumpyArray} after checking that it is an
 * instance of {@code numpy.ndarray}.
 * <p>
 * The references obtained for the {@code numpy} module and its {@code ndarray} attribute are
 * released before the instance check result is acted upon, so they are no longer leaked when
 * the object turns out not to be an ndarray and an exception is thrown.
 *
 * @return a NumpyArray built from the array's data address, shape, strides and mapped data type
 * @throws PythonException if the object is not a numpy array, or its numpy type code is not
 *                         one of the handled {@code NPY_*} values
 */
public NumpyArray toNumpy() throws PythonException{
    PyObject np = PyImport_ImportModule("numpy");
    PyObject ndarray = PyObject_GetAttrString(np, "ndarray");
    // Evaluate the check first, then drop both references on every path before any throw.
    boolean isNdArray = PyObject_IsInstance(nativePythonObject, ndarray) == 1;
    Py_DecRef(ndarray);
    Py_DecRef(np);
    if (!isNdArray){
        throw new PythonException("Object is not a numpy array! Use Python.ndarray() to convert object to a numpy array.");
    }

    Pointer objPtr = new Pointer(nativePythonObject);
    PyArrayObject npArr = new PyArrayObject(objPtr);
    Pointer ptr = PyArray_DATA(npArr);
    long[] shape = new long[PyArray_NDIM(npArr)];
    SizeTPointer shapePtr = PyArray_SHAPE(npArr);
    // a null pointer leaves the corresponding array zero-filled
    if (shapePtr != null)
        shapePtr.get(shape, 0, shape.length);
    long[] strides = new long[shape.length];
    SizeTPointer stridesPtr = PyArray_STRIDES(npArr);
    if (stridesPtr != null)
        stridesPtr.get(strides, 0, strides.length);
    int npdtype = PyArray_TYPE(npArr);

    DataType dtype;
    switch (npdtype){
        case NPY_DOUBLE:
            dtype = DataType.DOUBLE; break;
        case NPY_FLOAT:
            dtype = DataType.FLOAT; break;
        case NPY_SHORT:
            dtype = DataType.SHORT; break;
        case NPY_INT:
            dtype = DataType.INT32; break;
        case NPY_LONG:
            dtype = DataType.LONG; break;
        case NPY_UINT:
            dtype = DataType.UINT32; break;
        case NPY_BYTE:
            dtype = DataType.INT8; break;
        case NPY_UBYTE:
            dtype = DataType.UINT8; break;
        case NPY_BOOL:
            dtype = DataType.BOOL; break;
        case NPY_HALF:
            dtype = DataType.FLOAT16; break;
        case NPY_LONGLONG:
            dtype = DataType.INT64; break;
        case NPY_USHORT:
            dtype = DataType.UINT16; break;
        case NPY_ULONG:
        case NPY_ULONGLONG:
            dtype = DataType.UINT64; break;
        default:
            throw new PythonException("Unsupported array data type: " + npdtype);
    }

    return new NumpyArray(ptr.address(), shape, strides, dtype);

}