Java Code Examples for org.nd4j.linalg.factory.Nd4j#readBinary()

The following examples show how to use org.nd4j.linalg.factory.Nd4j#readBinary(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may check out the related API usage on the sidebar.
Example 1
Source File: ArraySavingListener.java    From deeplearning4j with Apache License 2.0 4 votes vote down vote up
/**
 * Compares, file by file, the arrays stored in two directories and prints the outcome for each array.
 * File names are expected to look like {@code x_<varName>.bin}; the {@code x_} prefix and {@code .bin}
 * suffix are stripped to obtain the variable name used to pair files across the two directories.
 * For each pair, "Equals: ..." is printed when the arrays match within {@code eps}. Otherwise "FAILS: ..."
 * is printed with either the number of differing elements (BOOL arrays) or the maximum absolute
 * difference (other types), followed by both file paths.
 *
 * @param dir1 first directory of saved arrays
 * @param dir2 second directory of saved arrays
 * @param eps  tolerance passed to {@link INDArray#equalsWithEps}
 * @throws Exception if a directory cannot be listed, the file counts differ, a file in {@code dir1} has no
 *                   counterpart in {@code dir2} (all reported via Preconditions), or an array cannot be read
 */
public static void compare(File dir1, File dir2, double eps) throws Exception {
    File[] files1 = dir1.listFiles();
    File[] files2 = dir2.listFiles();
    Preconditions.checkNotNull(files1, "No files in directory 1: %s", dir1);
    Preconditions.checkNotNull(files2, "No files in directory 2: %s", dir2);
    Preconditions.checkState(files1.length == files2.length, "Different number of files: %s vs %s", files1.length, files2.length);

    //Only directory 2 needs a name-based lookup; directory 1 is iterated directly
    Map<String,File> m2 = toMap(files2);

    for(File f : files1){
        String name = f.getName();
        String varName = name.substring(name.indexOf('_') + 1, name.length()-4); //Strip "x_" and ".bin"
        File f2 = m2.get(varName);
        //Report a missing counterpart clearly instead of failing inside readBinary(null)
        Preconditions.checkState(f2 != null, "No file for variable %s (from %s) found in directory 2: %s", varName, f, dir2);
        String displayName = varName.replace("__", "/");

        //try-with-resources: array memory is released even if the comparison below throws
        try(INDArray arr1 = Nd4j.readBinary(f); INDArray arr2 = Nd4j.readBinary(f2)) {
            //TODO String arrays won't work here!
            if(arr1.equalsWithEps(arr2, eps)){
                System.out.println("Equals: " + displayName);
            } else if(arr1.dataType() == DataType.BOOL){
                try(INDArray xor = Nd4j.exec(new Xor(arr1, arr2))) {
                    int count = xor.castTo(DataType.INT).sumNumber().intValue();
                    System.out.println("FAILS: " + displayName + " - boolean, # differences = " + count);
                    System.out.println("\t" + f.getAbsolutePath());
                    System.out.println("\t" + f2.getAbsolutePath());
                }
            } else {
                try(INDArray sub = arr1.sub(arr2); INDArray diff = Nd4j.math.abs(sub)) {
                    double maxDiff = diff.maxNumber().doubleValue();
                    System.out.println("FAILS: " + displayName + " - max difference = " + maxDiff);
                    System.out.println("\t" + f.getAbsolutePath());
                    System.out.println("\t" + f2.getAbsolutePath());
                }
            }
        }
    }
}
 
Example 2
Source File: Nd4jValidator.java    From deeplearning4j with Apache License 2.0 4 votes vote down vote up
/**
 * Validate whether the file represents a valid INDArray (of one of the allowed/specified data types) saved previously
 * with {@link Nd4j#saveBinary(INDArray, File)} to be read with {@link Nd4j#readBinary(File)}
 *
 * @param f                  File that should represent an INDArray saved with Nd4j.saveBinary
 * @param allowableDataTypes May be null or empty (any data type allowed). If non-empty, the file must represent one of
 *                           the specified data types
 * @return Result of validation. A memory-allocation failure while reading is reported as valid, because the file
 *         cannot be judged invalid in that case
 */
public static ValidationResult validateINDArrayFile(@NonNull File f, DataType... allowableDataTypes) {

    ValidationResult vr = Nd4jCommonValidator.isValidFile(f, "INDArray File", false);
    if (vr != null && !vr.isValid()) {
        vr.setFormatClass(INDArray.class);
        return vr;
    }

    //TODO let's do this without reading the whole thing into memory - check header + length...
    try (INDArray arr = Nd4j.readBinary(f)) {   //Using the fact that INDArray.close() exists -> deallocate memory as soon as reading is done
        //Calling with no varargs yields an empty (non-null) array: treat empty like null, i.e. no restriction
        if (allowableDataTypes != null && allowableDataTypes.length > 0
                && !ArrayUtils.contains(allowableDataTypes, arr.dataType())) {
            return ValidationResult.builder()
                    .valid(false)
                    .formatType("INDArray File")
                    .formatClass(INDArray.class)
                    .path(Nd4jCommonValidator.getPath(f))
                    .issues(Collections.singletonList("INDArray data type " + arr.dataType()
                            + " is not one of the allowed data types: " + java.util.Arrays.toString(allowableDataTypes)))
                    .build();
        }
    } catch (IOException e) {
        return ValidationResult.builder()
                .valid(false)
                .formatType("INDArray File")
                .formatClass(INDArray.class)
                .path(Nd4jCommonValidator.getPath(f))
                .issues(Collections.singletonList("Unable to read file (IOException)"))
                .exception(e)
                .build();
    } catch (Throwable t) {
        //getMessage() may be null; guard so we don't throw an NPE out of this handler
        String msg = t.getMessage();
        boolean allocationFailure = t instanceof OutOfMemoryError
                || (msg != null && msg.toLowerCase(java.util.Locale.ROOT).contains("failed to allocate"));
        if (allocationFailure) {
            //This is a memory exception during reading... result is indeterminate (might be valid, might not be, can't tell here)
            return ValidationResult.builder()
                    .valid(true)
                    .formatType("INDArray File")
                    .formatClass(INDArray.class)
                    .path(Nd4jCommonValidator.getPath(f))
                    .build();
        }

        return ValidationResult.builder()
                .valid(false)
                .formatType("INDArray File")
                .formatClass(INDArray.class)
                .path(Nd4jCommonValidator.getPath(f))
                .issues(Collections.singletonList("File may be corrupt or is not a binary INDArray file"))
                .exception(t)
                .build();
    }

    return ValidationResult.builder()
            .valid(true)
            .formatType("INDArray File")
            .formatClass(INDArray.class)
            .path(Nd4jCommonValidator.getPath(f))
            .build();
}
 
Example 3
Source File: DistributionStats.java    From nd4j with Apache License 2.0 2 votes vote down vote up
/**
 * Builds distribution statistics from two array files, each read with {@link Nd4j#readBinary(File)}.
 *
 * @param meanFile file holding the mean array
 * @param stdFile  file holding the standard-deviation array
 * @return a new {@link DistributionStats} wrapping the two arrays
 * @throws IOException if either file cannot be read
 */
public static DistributionStats load(@NonNull File meanFile, @NonNull File stdFile) throws IOException {
    // Java evaluates constructor arguments left to right, so the mean file is read first
    return new DistributionStats(
            Nd4j.readBinary(meanFile),
            Nd4j.readBinary(stdFile));
}
 
Example 4
Source File: StandardScaler.java    From nd4j with Apache License 2.0 2 votes vote down vote up
/**
 * Load the given mean and std arrays from files readable with {@link Nd4j#readBinary(File)}.
 * Both files are read before either field is assigned. If reading fails, this scaler keeps its previous
 * mean and std rather than ending up with a new mean paired with an old std.
 *
 * @param mean the mean file
 * @param std the std file
 * @throws IOException if either file cannot be read
 */
public void load(File mean, File std) throws IOException {
    //Read everything first, then commit: keeps the update all-or-nothing
    INDArray loadedMean = Nd4j.readBinary(mean);
    INDArray loadedStd = Nd4j.readBinary(std);
    this.mean = loadedMean;
    this.std = loadedStd;
}
 
Example 5
Source File: DistributionStats.java    From deeplearning4j with Apache License 2.0 2 votes vote down vote up
/**
 * Builds distribution statistics from two array files, each read with {@link Nd4j#readBinary(File)}.
 *
 * @param meanFile file holding the mean array
 * @param stdFile  file holding the standard-deviation array
 * @return a new {@link DistributionStats} wrapping the two arrays
 * @throws IOException if either file cannot be read
 */
public static DistributionStats load(@NonNull File meanFile, @NonNull File stdFile) throws IOException {
    // Java evaluates constructor arguments left to right, so the mean file is read first
    return new DistributionStats(
            Nd4j.readBinary(meanFile),
            Nd4j.readBinary(stdFile));
}
 
Example 6
Source File: StandardScaler.java    From deeplearning4j with Apache License 2.0 2 votes vote down vote up
/**
 * Load the given mean and std arrays from files readable with {@link Nd4j#readBinary(File)}.
 * Both files are read before either field is assigned. If reading fails, this scaler keeps its previous
 * mean and std rather than ending up with a new mean paired with an old std.
 *
 * @param mean the mean file
 * @param std the std file
 * @throws IOException if either file cannot be read
 */
public void load(File mean, File std) throws IOException {
    //Read everything first, then commit: keeps the update all-or-nothing
    INDArray loadedMean = Nd4j.readBinary(mean);
    INDArray loadedStd = Nd4j.readBinary(std);
    this.mean = loadedMean;
    this.std = loadedStd;
}