Java Code Examples for org.nd4j.linalg.factory.Nd4j#createFromNpzFile()

The following examples show how to use org.nd4j.linalg.factory.Nd4j#createFromNpzFile(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may check out the related API usage on the sidebar.
Example 1
Source file: TestNDArrayCreation.java, from deeplearning4j (Apache License 2.0) · 6 votes
@Test
@Ignore
public void testCreateNpz() throws Exception {
    // Load every array stored in the bundled .npz archive, keyed by its entry name.
    Map<String, INDArray> loaded = Nd4j.createFromNpzFile(new ClassPathResource("nd4j-tests/test.npz").getFile());

    // Both entries must exist before their contents are inspected.
    assertEquals(true, loaded.containsKey("x"));
    assertEquals(true, loaded.containsKey("y"));

    INDArray xValues = loaded.get("x");
    INDArray yValues = loaded.get("y");

    // The first four elements of "x" are expected to be 1..4, checked with a 0.1 tolerance.
    for (int i = 0; i < 4; i++) {
        assertEquals(1.0 + i, xValues.getDouble(i), 1e-1);
    }
    // The first four elements of "y" are expected to be 5..8, same tolerance.
    for (int i = 0; i < 4; i++) {
        assertEquals(5.0 + i, yValues.getDouble(i), 1e-1);
    }
}
 
Example 2
Source file: NumpyFormatTests.java, from deeplearning4j (Apache License 2.0) · 5 votes
@Test
public void testNpzReading() throws Exception {
    // Copy the bundled .npz fixtures out of the classpath so they can be read as plain files.
    val dir = testDir.newFolder();
    new ClassPathResource("numpy_arrays/npz/").copyDirectory(dir);

    File[] files = dir.listFiles();
    // File.listFiles() signals an I/O error by returning null rather than throwing; fail with
    // context here instead of an anonymous NullPointerException in the loop below.
    if (files == null) {
        throw new IllegalStateException("Could not list files in " + dir);
    }

    // The "secondArr" expectation does not depend on the fixture, so build it once.
    INDArray expectedSecond = Nd4j.linspace(DataType.FLOAT, 0, 3, 10);
    int cnt = 0;

    for (File f : files) {
        String name = f.getName();
        if (!name.endsWith(".npz")) {
            log.warn("Skipping: {}", f);
            continue;
        }

        // The file's base name (without ".npz") is the numpy dtype name of the stored arrays.
        String dtype = name.substring(0, name.lastIndexOf('.'));
        DataType dt = DataType.fromNumpy(dtype);

        INDArray expectedFirst = Nd4j.arange(12).castTo(dt).reshape(3, 4);

        Map<String, INDArray> m = Nd4j.createFromNpzFile(f);
        assertEquals(2, m.size());
        assertTrue(m.containsKey("firstArr"));
        assertTrue(m.containsKey("secondArr"));

        assertEquals(expectedFirst, m.get("firstArr"));
        assertEquals(expectedSecond, m.get("secondArr"));
        cnt++;
    }

    // Guard against the test passing vacuously because no .npz fixture was found.
    assertTrue(cnt > 0);
}
 
Example 3
Source file: Nd4jValidator.java, from deeplearning4j (Apache License 2.0) · 5 votes
/**
 * Validate whether the file represents a valid Numpy .npz file to be read with {@link Nd4j#createFromNpzFile(File)}.
 * <p>
 * The file is actually loaded to perform the check; any arrays created while doing so are
 * closed (best effort) before this method returns.
 *
 * @param f File that should represent a Numpy .npz file written with Numpy savez method
 * @return Result of validation: the generic file-check result if that check fails, an invalid
 *         result carrying the exception if loading fails, otherwise a valid result
 */
public static ValidationResult validateNpzFile(@NonNull File f) {
    final String formatType = "Numpy .npz File";

    // Generic file-level checks are delegated; a non-null invalid result is returned as-is.
    ValidationResult vr = Nd4jCommonValidator.isValidFile(f, formatType, false);
    if (vr != null && !vr.isValid()) {
        return vr;
    }

    Map<String, INDArray> m;
    try {
        m = Nd4j.createFromNpzFile(f);
    } catch (Throwable t) {
        // Deliberately broad: any failure while loading means the file cannot be used, and is
        // reported in the result rather than propagated to the caller.
        return ValidationResult.builder()
                .valid(false)
                .formatType(formatType)
                .path(Nd4jCommonValidator.getPath(f))
                .issues(Collections.singletonList("File may be corrupt or is not a Numpy .npz file"))
                .exception(t)
                .build();
    }

    // Deallocate immediately: the arrays were only loaded to prove the file is readable.
    if (m != null) {
        for (INDArray arr : m.values()) {
            if (arr == null) {
                continue;
            }
            try {
                arr.close();
            } catch (RuntimeException ignored) {
                // Best-effort cleanup. NOTE(review): close() may refuse to release some arrays;
                // that must not turn an otherwise-successful validation into an exception.
            }
        }
    }

    return ValidationResult.builder()
            .valid(true)
            .formatType(formatType)
            .path(Nd4jCommonValidator.getPath(f))
            .build();
}