Java Code Examples for org.nd4j.linalg.factory.Nd4j#createFromNpyFile()

The following examples show how to use org.nd4j.linalg.factory.Nd4j#createFromNpyFile(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also check out related API usage in the sidebar.
Example 1
Source File: NumpyFormatTests.java    From deeplearning4j with Apache License 2.0 6 votes vote down vote up
/**
 * Overwrites bytes 0..9 of a known-good .npy resource with zeros, writes the result to a
 * temporary file and tries to load it. The test is declared to end in a
 * {@link RuntimeException}.
 */
@Test(expected = RuntimeException.class)
public void readNumpyCorruptHeader1() throws Exception {
    final File tmpFolder = testDir.newFolder();

    final File validNpy = new ClassPathResource("numpy_arrays/arange_3,4_float32.npy").getFile();
    final byte[] content = FileUtils.readFileToByteArray(validNpy);
    // Zero out the first ten bytes of the file content.
    int pos = 0;
    while (pos < 10) {
        content[pos++] = 0;
    }
    final File corruptNpy = new File(tmpFolder, "corrupt.npy");
    FileUtils.writeByteArrayToFile(corruptNpy, content);

    final INDArray expected = Nd4j.arange(12).castTo(DataType.FLOAT).reshape(3, 4);

    // Sanity check: the untouched file must load as the expected 3x4 float array.
    final INDArray fromValid = Nd4j.createFromNpyFile(validNpy);
    assertEquals(expected, fromValid);

    // NOTE(review): presumably this load is what raises the expected RuntimeException - confirm.
    final INDArray fromCorrupt = Nd4j.createFromNpyFile(corruptNpy);
    final boolean sameContent = expected.equals(fromCorrupt);
}
 
Example 2
Source File: NumpyFormatTests.java    From deeplearning4j with Apache License 2.0 6 votes vote down vote up
/**
 * Overwrites bytes 1..9 of a known-good .npy resource with zeros (byte 0 is left untouched),
 * writes the result to a temporary file and tries to load it. The test is declared to end in a
 * {@link RuntimeException}.
 */
@Test(expected = RuntimeException.class)
public void readNumpyCorruptHeader2() throws Exception {
    final File tmpFolder = testDir.newFolder();

    final File validNpy = new ClassPathResource("numpy_arrays/arange_3,4_float32.npy").getFile();
    final byte[] content = FileUtils.readFileToByteArray(validNpy);
    // Zero out bytes 1 through 9; the very first byte keeps its original value.
    int pos = 1;
    while (pos < 10) {
        content[pos++] = 0;
    }
    final File corruptNpy = new File(tmpFolder, "corrupt.npy");
    FileUtils.writeByteArrayToFile(corruptNpy, content);

    final INDArray expected = Nd4j.arange(12).castTo(DataType.FLOAT).reshape(3, 4);

    // Sanity check: the untouched file must load as the expected 3x4 float array.
    final INDArray fromValid = Nd4j.createFromNpyFile(validNpy);
    assertEquals(expected, fromValid);

    // NOTE(review): presumably this load is what raises the expected RuntimeException - confirm.
    final INDArray fromCorrupt = Nd4j.createFromNpyFile(corruptNpy);
    final boolean sameContent = expected.equals(fromCorrupt);
}
 
Example 3
Source File: ImportModelDebugger.java    From deeplearning4j with Apache License 2.0 6 votes vote down vote up
/**
 * Loads every regular file found under the {@code __placeholders} sub-directory of
 * {@code rootDir} via {@link Nd4j#createFromNpyFile(File)}.
 * <p>
 * The map key for each file is its path relative to the placeholders directory, truncated just
 * before the last occurrence of {@code "__"} in that path.
 *
 * @param rootDir directory that contains the {@code __placeholders} sub-directory
 * @return map from derived placeholder name to the array loaded from the corresponding file
 * @throws IllegalStateException if the placeholders directory does not exist, or if a file's
 *                               relative path contains no {@code "__"} separator
 */
public static Map<String, INDArray> loadPlaceholders(File rootDir){
    File dir = new File(rootDir, "__placeholders");
    if(!dir.exists()){
        throw new IllegalStateException("Cannot find placeholders: directory does not exist: " + dir.getAbsolutePath());
    }

    Map<String, INDArray> ret = new HashMap<>();
    // NOTE(review): the (null, true) arguments presumably mean "any extension, recurse into sub-directories" - see FileUtils docs
    Iterator<File> iter = FileUtils.iterateFiles(dir, null, true);
    while(iter.hasNext()){
        File f = iter.next();
        if(!f.isFile()) {
            continue;
        }
        String s = dir.toURI().relativize(f.toURI()).getPath();
        int idx = s.lastIndexOf("__");
        if(idx < 0){
            // Without this check substring(0, -1) would fail with an opaque StringIndexOutOfBoundsException
            throw new IllegalStateException("Cannot determine placeholder name: no \"__\" separator in file path: "
                    + f.getAbsolutePath());
        }
        String name = s.substring(0, idx);
        INDArray arr = Nd4j.createFromNpyFile(f);
        ret.put(name, arr);
    }

    return ret;
}
 
Example 4
Source File: TestNDArrayCreation.java    From nd4j with Apache License 2.0 5 votes vote down vote up
/**
 * Loads {@code test.npy} from the classpath and checks that it is a 2x2 array holding
 * 1, 2, 3, 4 (compared with a tolerance of 0.1). Currently disabled via {@code @Ignore}.
 */
@Test
@Ignore
public void testCreateNpy() throws Exception {
    final File npyFile = new ClassPathResource("test.npy").getFile();
    final INDArray loaded = Nd4j.createFromNpyFile(npyFile);

    // Both dimensions must have size 2.
    assertEquals(2, loaded.size(0));
    assertEquals(2, loaded.size(1));

    // Element-wise check of the four values.
    final double tolerance = 1e-1;
    assertEquals(1.0, loaded.getDouble(0, 0), tolerance);
    assertEquals(2.0, loaded.getDouble(0, 1), tolerance);
    assertEquals(3.0, loaded.getDouble(1, 0), tolerance);
    assertEquals(4.0, loaded.getDouble(1, 1), tolerance);
}
 
Example 5
Source File: TestNDArrayCreation.java    From nd4j with Apache License 2.0 5 votes vote down vote up
/**
 * Loads {@code rank3.npy} from the classpath, checks its length (8) and rank (3), and verifies
 * that a native pointer created for the buffer's address reports that same address.
 */
@Test
public void testCreateNpy3() throws Exception {
    final File npyFile = new ClassPathResource("rank3.npy").getFile();
    final INDArray loaded = Nd4j.createFromNpyFile(npyFile);
    assertEquals(8, loaded.length());
    assertEquals(3, loaded.rank());

    // Round-trip the buffer address through the native ops pointer factory.
    final long bufferAddress = loaded.data().address();
    Pointer nativePointer = NativeOpsHolder.getInstance()
                    .getDeviceNativeOps()
                    .pointerForAddress(bufferAddress);
    assertEquals(loaded.data().address(), nativePointer.address());
}
 
Example 6
Source File: TestNDArrayCreation.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
/**
 * Loads {@code nd4j-tests/test.npy} from the classpath and checks that it is a 2x2 array holding
 * 1, 2, 3, 4 (compared with a tolerance of 0.1). Currently disabled via {@code @Ignore}.
 */
@Test
@Ignore
public void testCreateNpy() throws Exception {
    final File npyFile = new ClassPathResource("nd4j-tests/test.npy").getFile();
    final INDArray loaded = Nd4j.createFromNpyFile(npyFile);

    // Both dimensions must have size 2.
    assertEquals(2, loaded.size(0));
    assertEquals(2, loaded.size(1));

    // Element-wise check of the four values.
    final double tolerance = 1e-1;
    assertEquals(1.0, loaded.getDouble(0, 0), tolerance);
    assertEquals(2.0, loaded.getDouble(0, 1), tolerance);
    assertEquals(3.0, loaded.getDouble(1, 0), tolerance);
    assertEquals(4.0, loaded.getDouble(1, 1), tolerance);
}
 
Example 7
Source File: TestNDArrayCreation.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
/**
 * Loads {@code nd4j-tests/rank3.npy} from the classpath, checks its length (8) and rank (3), and
 * verifies that a native pointer created for the buffer's address reports that same address.
 * Disabled via {@code @Ignore}; the annotation text references issue #7657.
 */
@Test
@Ignore("AB 2019/05/23 - Failing on linux-x86_64-cuda-9.2 - see issue #7657")
public void testCreateNpy3() throws Exception {
    final File npyFile = new ClassPathResource("nd4j-tests/rank3.npy").getFile();
    final INDArray loaded = Nd4j.createFromNpyFile(npyFile);
    assertEquals(8, loaded.length());
    assertEquals(3, loaded.rank());

    // Round-trip the buffer address through the native ops pointer factory.
    final long bufferAddress = loaded.data().address();
    Pointer nativePointer = NativeOpsHolder.getInstance()
                    .getDeviceNativeOps()
                    .pointerForAddress(bufferAddress);
    assertEquals(loaded.data().address(), nativePointer.address());
}
 
Example 8
Source File: NumpyFormatTests.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
/**
 * Loads an .npy file from a hard-coded local path. The loaded array is not inspected;
 * the test is disabled via {@code @Ignore}.
 */
@Ignore
@Test
public void testNumpyBoolean() {
    final File npyFile = new File("c:/Users/raver/Downloads/error2.npy");
    // Only the load itself is exercised here; no assertions are made on the result.
    INDArray loaded = Nd4j.createFromNpyFile(npyFile);
}
 
Example 9
Source File: Nd4jValidator.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
/**
 * Validate whether the file represents a valid Numpy .npy file to be read with {@link Nd4j#createFromNpyFile(File)}.
 * <p>
 * The file is first passed through {@code Nd4jCommonValidator.isValidFile}; if that yields an invalid
 * result it is returned as-is. Otherwise the file is loaded with {@link Nd4j#createFromNpyFile(File)}:
 * a failed load is reported as invalid, except when the failure is an {@link OutOfMemoryError} or
 * its message contains "failed to allocate" (case-insensitive), in which case the result is reported as valid
 * because validity could not be determined.
 *
 * @param f File that should represent a Numpy .npy file written with Numpy save method
 * @return Result of validation
 */
public static ValidationResult validateNpyFile(@NonNull File f) {

    ValidationResult vr = Nd4jCommonValidator.isValidFile(f, "Numpy .npy File", false);
    if (vr != null && !vr.isValid())
        return vr;

    //TODO let's do this without reading whole thing into memory
    try (INDArray arr = Nd4j.createFromNpyFile(f)) {   //Using the fact that INDArray.close() exists -> deallocate memory as soon as reading is done
    } catch (Throwable t) {
        //Throwable.getMessage() may be null: guard against it so the validator itself never fails with a NullPointerException
        String msg = t.getMessage();
        //Locale.ROOT keeps the lower-casing independent of the JVM default locale
        boolean allocationFailure = msg != null
                && msg.toLowerCase(java.util.Locale.ROOT).contains("failed to allocate");
        if (t instanceof OutOfMemoryError || allocationFailure) {
            //This is a memory exception during reading... result is indeterminate (might be valid, might not be, can't tell here)
            return ValidationResult.builder()
                    .valid(true)
                    .formatType("Numpy .npy File")
                    .path(Nd4jCommonValidator.getPath(f))
                    .build();
        }

        return ValidationResult.builder()
                .valid(false)
                .formatType("Numpy .npy File")
                .path(Nd4jCommonValidator.getPath(f))
                .issues(Collections.singletonList("File may be corrupt or is not a Numpy .npy file"))
                .exception(t)
                .build();
    }

    return ValidationResult.builder()
            .valid(true)
            .formatType("Numpy .npy File")
            .path(Nd4jCommonValidator.getPath(f))
            .build();
}
 
Example 10
Source File: FullModelComparisons.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
/**
 * Imports the Keras model {@code cnn_batch_norm.h5} as a sequential model, runs it on a stored
 * input array and compares the first five output values (read via {@code getDouble(i)}) against
 * stored Keras predictions with a tolerance of 1e-4.
 */
@Test
public void cnnBatchNormTest() throws IOException, UnsupportedKerasConfigurationException,
        InvalidKerasConfigurationException {

    String modelPath = "modelimport/keras/fullconfigs/cnn/cnn_batch_norm.h5";

    KerasSequentialModel kerasModel = new KerasModel().modelBuilder()
            .modelHdf5Filename(Resources.asFile(modelPath).getAbsolutePath())
            .enforceTrainingConfig(false)
            .buildSequential();

    MultiLayerNetwork model = kerasModel.getMultiLayerNetwork();
    model.init();

    System.out.println(model.summary());

    INDArray input = Nd4j.createFromNpyFile(Resources.asFile("modelimport/keras/fullconfigs/cnn/input.npy"));
    // Reorder the axes; NOTE(review): presumably channels-last -> channels-first, confirm against the model
    input = input.permute(0, 3, 1, 2);
    assertTrue(Arrays.equals(input.shape(), new long[] {5, 3, 10, 10}));

    INDArray output = model.output(input);

    INDArray kerasOutput = Nd4j.createFromNpyFile(Resources.asFile("modelimport/keras/fullconfigs/cnn/predictions.npy"));

    for (int i = 0; i < 5; i++) {
        // The stored Keras predictions are the reference, so they go in the "expected" position
        TestCase.assertEquals(kerasOutput.getDouble(i), output.getDouble(i), 1e-4);
    }
}
 
Example 11
Source File: FullModelComparisons.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
/**
 * Imports the Keras model {@code cnn_batch_norm_medium.h5} as a sequential model, runs it on a
 * stored input array and compares the first five output values (read via {@code getDouble(i)})
 * against stored Keras predictions with a tolerance of 1e-2.
 */
@Test
public void cnnBatchNormLargerTest() throws IOException, UnsupportedKerasConfigurationException,
        InvalidKerasConfigurationException {

    String modelPath = "modelimport/keras/fullconfigs/cnn_batch_norm/cnn_batch_norm_medium.h5";

    KerasSequentialModel kerasModel = new KerasModel().modelBuilder()
            .modelHdf5Filename(Resources.asFile(modelPath).getAbsolutePath())
            .enforceTrainingConfig(false)
            .buildSequential();

    MultiLayerNetwork model = kerasModel.getMultiLayerNetwork();
    model.init();

    System.out.println(model.summary());

    INDArray input = Nd4j.createFromNpyFile(Resources.asFile("modelimport/keras/fullconfigs/cnn_batch_norm/input.npy"));
    // Reorder the axes; NOTE(review): presumably channels-last -> channels-first, confirm against the model
    input = input.permute(0, 3, 1, 2);
    assertTrue(Arrays.equals(input.shape(), new long[] {5, 1, 48, 48}));

    INDArray output = model.output(input);

    INDArray kerasOutput = Nd4j.createFromNpyFile(Resources.asFile("modelimport/keras/fullconfigs/cnn_batch_norm/predictions.npy"));

    for (int i = 0; i < 5; i++) {
        // TODO this should be a little closer
        // The stored Keras predictions are the reference, so they go in the "expected" position
        TestCase.assertEquals(kerasOutput.getDouble(i), output.getDouble(i), 1e-2);
    }
}
 
Example 12
Source File: NumpyFormatTests.java    From deeplearning4j with Apache License 2.0 4 votes vote down vote up
/**
 * For both a non-empty and an empty resource directory, copies the directory to a temporary
 * folder, loads every {@code .npy} file in it and compares the result with an array built
 * from the dtype name embedded in the file name. At least one file must be checked per directory.
 */
@Test
public void testNpy() throws Exception {
    for (boolean empty : new boolean[]{false, true}) {
        val dir = testDir.newFolder();
        String resourceDir = empty ? "numpy_arrays/npy/0,3_empty/" : "numpy_arrays/npy/3,4/";
        new ClassPathResource(resourceDir).copyDirectory(dir);

        int checked = 0;
        for (File npy : dir.listFiles()) {
            if (!npy.getPath().endsWith(".npy")) {
                log.warn("Skipping: {}", npy);
                continue;
            }

            // The dtype name is the text between the last '_' and the last '.' of the absolute path.
            String absPath = npy.getAbsolutePath();
            String dtype = absPath.substring(absPath.lastIndexOf('_') + 1, absPath.lastIndexOf('.'));
            DataType dt = DataType.fromNumpy(dtype);

            // Expected content: a 0x3 array for the empty case, otherwise 0..11 reshaped to 3x4.
            INDArray expected = empty
                    ? Nd4j.create(dt, 0, 3)
                    : Nd4j.arange(12).castTo(dt).reshape(3, 4);
            INDArray actual = Nd4j.createFromNpyFile(npy);

            assertEquals("Failed with file [" + npy.getName() + "]", expected, actual);
            checked++;
        }

        assertTrue(checked > 0);
    }
}
 
Example 13
Source File: NumpyFormatTests.java    From deeplearning4j with Apache License 2.0 4 votes vote down vote up
/**
 * Loads {@code numpy_oneoff/scalar.npy} from the classpath and checks that it equals an
 * INT scalar with value 1.
 */
@Test
public void testFromNumpyScalar() throws Exception {
    final File scalarFile = new ClassPathResource("numpy_oneoff/scalar.npy").getFile();
    val loaded = Nd4j.createFromNpyFile(scalarFile);
    val expected = Nd4j.scalar(DataType.INT, 1);
    assertEquals(expected, loaded);
}
 
Example 14
Source File: NumpyFormatTests.java    From deeplearning4j with Apache License 2.0 4 votes vote down vote up
/**
 * Attempts to load an .npy array from a relative path with a made-up extension; the test is
 * declared to end in an {@link IllegalArgumentException}.
 */
@Test(expected = IllegalArgumentException.class)
public void testAbsentNumpyFile_1() throws Exception {
    val missingFile = new File("pew-pew-zomg.some_extension_that_wont_exist");
    // The returned array is never used; only the outcome of the load matters here.
    INDArray loaded = Nd4j.createFromNpyFile(missingFile);
}
 
Example 15
Source File: NumpyFormatTests.java    From deeplearning4j with Apache License 2.0 4 votes vote down vote up
/**
 * Attempts to load an .npy array from a hard-coded absolute path; the test is declared to end
 * in an {@link IllegalArgumentException}. If the load succeeds, the array's shape and sum are logged.
 */
@Test(expected = IllegalArgumentException.class)
public void testAbsentNumpyFile_2() throws Exception {
    val npyFile = new File("c:/develop/batch-x-1.npy");
    INDArray loaded = Nd4j.createFromNpyFile(npyFile);
    // Only reached when the file could actually be loaded.
    log.info("Array shape: {}; sum: {};", loaded.shape(), loaded.sumNumber().doubleValue());
}