Java Code Examples for org.nd4j.linalg.dataset.DataSet#save()

The following examples show how to use org.nd4j.linalg.dataset.DataSet#save(). They are drawn from open-source projects; the originating project, source file, and license are noted above each example.
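Before the project examples, here is a minimal round-trip sketch: DataSet#save() writes the features, labels, and any masks to a file (or OutputStream) in ND4J's binary format, and DataSet#load() reads them back into an existing DataSet instance. The array shapes and file name below are illustrative assumptions, not taken from any of the listed projects.

import java.io.File;

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;

public class DataSetSaveSketch {
    public static void main(String[] args) throws Exception {
        // Illustrative data: 4 examples, 3 features and 2 labels each (assumed shapes)
        INDArray features = Nd4j.rand(4, 3);
        INDArray labels = Nd4j.rand(4, 2);
        DataSet ds = new DataSet(features, labels);

        // save() serializes the DataSet to a file; the path is hypothetical
        File f = new File("dataset-example.bin");
        ds.save(f);

        // load() on an empty DataSet restores the saved contents
        DataSet restored = new DataSet();
        restored.load(f);
    }
}
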
Example 1
Source File: SporadicTests.java    From nd4j with Apache License 2.0
@Test
public void testDataSetSaveLost() throws Exception {
    INDArray features = Nd4j.linspace(1, 16 * 784, 16 * 784).reshape(16, 784);
    INDArray labels = Nd4j.linspace(1, 160, 160).reshape(16, 10);

    for (int i = 0; i < 100; i++) {
        DataSet ds = new DataSet(features, labels);

        File tempFile = File.createTempFile("dataset", "temp");
        tempFile.deleteOnExit();

        ds.save(tempFile);

        DataSet restore = new DataSet();
        restore.load(tempFile);

        assertEquals(features, restore.getFeatureMatrix());
        assertEquals(labels, restore.getLabels());

    }
}
 
Example 2
Source File: InFileDataSetCache.java    From nd4j with Apache License 2.0
@Override
public void put(String key, DataSet dataSet) {
    File file = resolveKey(key);

    File parentDir = file.getParentFile();
    if (!parentDir.exists()) {
        if (!parentDir.mkdirs()) {
            throw new IllegalStateException("ERROR: cannot create parent directory: " + parentDir);
        }
    }

    if (file.exists()) {
        file.delete();
    }

    dataSet.save(file);
}
 
Example 3
Source File: InFileDataSetCache.java    From deeplearning4j with Apache License 2.0
@Override
public void put(String key, DataSet dataSet) {
    File file = resolveKey(key);

    File parentDir = file.getParentFile();
    if (!parentDir.exists()) {
        if (!parentDir.mkdirs()) {
            throw new IllegalStateException("ERROR: cannot create parent directory: " + parentDir);
        }
    }

    if (file.exists()) {
        file.delete();
    }

    dataSet.save(file);
}
 
Example 4
Source File: BatchAndExportDataSetsFunction.java    From deeplearning4j with Apache License 2.0
private String export(DataSet dataSet, int partitionIdx, int outputCount) throws Exception {
    String filename = "dataset_" + partitionIdx + jvmuid + "_" + outputCount + ".bin";

    URI uri = new URI(exportBaseDirectory
                    + (exportBaseDirectory.endsWith("/") || exportBaseDirectory.endsWith("\\") ? "" : "/")
                    + filename);

    Configuration c = conf == null ? DefaultHadoopConfig.get() : conf.getValue().getConfiguration();

    FileSystem file = FileSystem.get(uri, c);
    try (FSDataOutputStream out = file.create(new Path(uri))) {
        dataSet.save(out);
    }

    return uri.toString();
}
 
Example 5
Source File: DataSetExportFunction.java    From deeplearning4j with Apache License 2.0
@Override
public void call(Iterator<DataSet> iter) throws Exception {
    String jvmuid = UIDProvider.getJVMUID();
    uid = Thread.currentThread().getId() + jvmuid.substring(0, Math.min(8, jvmuid.length()));

    Configuration c = conf == null ? DefaultHadoopConfig.get() : conf.getValue().getConfiguration();

    while (iter.hasNext()) {
        DataSet next = iter.next();

        String filename = "dataset_" + uid + "_" + (outputCount++) + ".bin";

        String path = outputDir.getPath();
        URI uri = new URI(path + (path.endsWith("/") || path.endsWith("\\") ? "" : "/") + filename);
        FileSystem file = FileSystem.get(uri, c);
        try (FSDataOutputStream out = file.create(new Path(uri))) {
            next.save(out);
        }
    }
}
 
Example 6
Source File: StringToDataSetExportFunction.java    From deeplearning4j with Apache License 2.0
private void processBatchIfRequired(List<List<Writable>> list, boolean finalRecord) throws Exception {
    if (list.isEmpty())
        return;
    if (list.size() < batchSize && !finalRecord)
        return;

    RecordReader rr = new CollectionRecordReader(list);
    RecordReaderDataSetIterator iter = new RecordReaderDataSetIterator(rr, null, batchSize, labelIndex, labelIndex, numPossibleLabels, -1, regression);

    DataSet ds = iter.next();

    String filename = "dataset_" + uid + "_" + (outputCount++) + ".bin";

    URI uri = new URI(outputDir.getPath() + "/" + filename);
    Configuration c = conf == null ? DefaultHadoopConfig.get() : conf.getValue().getConfiguration();
    FileSystem file = FileSystem.get(uri, c);
    try (FSDataOutputStream out = file.create(new Path(uri))) {
        ds.save(out);
    }

    list.clear();
}
 
Example 7
Source File: SaveFeaturizedDataExample.java    From Java-Deep-Learning-Cookbook with MIT License
private static void saveToDisk(DataSet currentFeaturized, int iterNum, boolean isTrain){
    File fileFolder = isTrain ? new File("{PATH-TO-SAVE-TRAIN-SAMPLES}"): new File("{PATH-TO-SAVE-TEST-SAMPLES}");
    if (iterNum == 0) {
        fileFolder.mkdirs();
    }
    String fileName = "churn-" + featurizeExtractionLayer + "-";
    fileName += isTrain ? "train-" : "test-";
    fileName += iterNum + ".bin";
    currentFeaturized.save(new File(fileFolder,fileName));
    log.info("Saved " + (isTrain?"train ":"test ") + "dataset #"+ iterNum);
}
 
Example 8
Source File: SaveFeaturizedDataExample.java    From Java-Deep-Learning-Cookbook with MIT License
private static void saveToDisk(DataSet currentFeaturized, int iterNum, boolean isTrain){
    File fileFolder = isTrain ? new File("{PATH-TO-SAVE-TRAIN-SAMPLES}"): new File("{PATH-TO-SAVE-TEST-SAMPLES}");
    if (iterNum == 0) {
        fileFolder.mkdirs();
    }
    String fileName = "churn-" + featurizeExtractionLayer + "-";
    fileName += isTrain ? "train-" : "test-";
    fileName += iterNum + ".bin";
    currentFeaturized.save(new File(fileFolder,fileName));
    log.info("Saved " + (isTrain?"train ":"test ") + "dataset #"+ iterNum);
}
 
Example 9
Source File: EndlessWorkspaceTests.java    From nd4j with Apache License 2.0
@Test
public void endlessTestSerDe1() throws Exception {
    INDArray features = Nd4j.create(32, 3, 224, 224);
    INDArray labels = Nd4j.create(32, 200);
    File tmp = File.createTempFile("12dadsad", "dsdasds");
    float[] array = new float[33 * 3 * 224 * 224];
    DataSet ds = new DataSet(features, labels);
    ds.save(tmp);

    WorkspaceConfiguration wsConf = WorkspaceConfiguration.builder().initialSize(0)
                    .policyLearning(LearningPolicy.FIRST_LOOP).build();

    while (true) {

        try (MemoryWorkspace workspace = Nd4j.getWorkspaceManager().getAndActivateWorkspace(wsConf, "serde")) {
            /*
            try (FileOutputStream fos = new FileOutputStream(tmp); BufferedOutputStream bos = new BufferedOutputStream(fos)) {
                SerializationUtils.serialize(array, fos);
            }

            try (FileInputStream fis = new FileInputStream(tmp); BufferedInputStream bis = new BufferedInputStream(fis)) {
                long time1 = System.currentTimeMillis();
                float[] arrayR = (float[]) SerializationUtils.deserialize(bis);
                long time2 = System.currentTimeMillis();

                log.info("Load time: {}", time2 - time1);
            }
            */

            long time1 = System.currentTimeMillis();
            ds.load(tmp);
            long time2 = System.currentTimeMillis();

            log.info("Load time: {}", time2 - time1);
        }
    }
}
 
Example 10
Source File: InMemoryDataSetCache.java    From nd4j with Apache License 2.0
@Override
public void put(String key, DataSet dataSet) {
    if (cache.containsKey(key)) {
        log.debug("evicting key %s from data set cache", key);
        cache.remove(key);
    }

    ByteArrayOutputStream os = new ByteArrayOutputStream();

    dataSet.save(os);

    cache.put(key, os.toByteArray());
}
 
Example 11
Source File: EndlessWorkspaceTests.java    From deeplearning4j with Apache License 2.0
@Test
public void endlessTestSerDe1() throws Exception {
    INDArray features = Nd4j.create(32, 3, 224, 224);
    INDArray labels = Nd4j.create(32, 200);
    File tmp = File.createTempFile("12dadsad", "dsdasds");
    float[] array = new float[33 * 3 * 224 * 224];
    DataSet ds = new DataSet(features, labels);
    ds.save(tmp);

    WorkspaceConfiguration wsConf = WorkspaceConfiguration.builder().initialSize(0)
                    .policyLearning(LearningPolicy.FIRST_LOOP).build();

    while (true) {

        try (MemoryWorkspace workspace = Nd4j.getWorkspaceManager().getAndActivateWorkspace(wsConf, "serde")) {
            /*
            try (FileOutputStream fos = new FileOutputStream(tmp); BufferedOutputStream bos = new BufferedOutputStream(fos)) {
                SerializationUtils.serialize(array, fos);
            }

            try (FileInputStream fis = new FileInputStream(tmp); BufferedInputStream bis = new BufferedInputStream(fis)) {
                long time1 = System.currentTimeMillis();
                float[] arrayR = (float[]) SerializationUtils.deserialize(bis);
                long time2 = System.currentTimeMillis();

                log.info("Load time: {}", time2 - time1);
            }
            */

            long time1 = System.currentTimeMillis();
            ds.load(tmp);
            long time2 = System.currentTimeMillis();

            log.info("Load time: {}", time2 - time1);
        }
    }
}
 
Example 12
Source File: InMemoryDataSetCache.java    From deeplearning4j with Apache License 2.0
@Override
public void put(String key, DataSet dataSet) {
    if (cache.containsKey(key)) {
        log.debug("evicting key %s from data set cache", key);
        cache.remove(key);
    }

    ByteArrayOutputStream os = new ByteArrayOutputStream();

    dataSet.save(os);

    cache.put(key, os.toByteArray());
}
 
Example 13
Source File: GradientSharingTrainingTest.java    From deeplearning4j with Apache License 2.0
@Test @Ignore //AB https://github.com/eclipse/deeplearning4j/issues/8985
public void differentNetsTrainingTest() throws Exception {
    int batch = 3;

    File temp = testDir.newFolder();
    DataSet ds = new IrisDataSetIterator(150, 150).next();
    List<DataSet> list = ds.asList();
    Collections.shuffle(list, new Random(12345));
    int pos = 0;
    int dsCount = 0;
    while (pos < list.size()) {
        List<DataSet> l2 = new ArrayList<>();
        for (int i = 0; i < 3 && pos < list.size(); i++) {
            l2.add(list.get(pos++));
        }
        DataSet d = DataSet.merge(l2);
        File f = new File(temp, dsCount++ + ".bin");
        d.save(f);
    }

    INDArray last = null;
    INDArray lastDup = null;
    for (int i = 0; i < 2; i++) {
        System.out.println("||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||");
        log.info("Starting: {}", i);

        MultiLayerConfiguration conf;
        if (i == 0) {
            conf = new NeuralNetConfiguration.Builder()
                    .weightInit(WeightInit.XAVIER)
                    .seed(12345)
                    .list()
                    .layer(new OutputLayer.Builder().nIn(4).nOut(3).activation(Activation.SOFTMAX).lossFunction(LossFunctions.LossFunction.MCXENT).build())
                    .build();
        } else {
            conf = new NeuralNetConfiguration.Builder()
                    .weightInit(WeightInit.XAVIER)
                    .seed(12345)
                    .list()
                    .layer(new DenseLayer.Builder().nIn(4).nOut(4).activation(Activation.TANH).build())
                    .layer(new OutputLayer.Builder().nIn(4).nOut(3).activation(Activation.SOFTMAX).lossFunction(LossFunctions.LossFunction.MCXENT).build())
                    .build();
        }
        MultiLayerNetwork net = new MultiLayerNetwork(conf);
        net.init();


        //TODO this probably won't work everywhere...
        String controller = Inet4Address.getLocalHost().getHostAddress();
        String networkMask = controller.substring(0, controller.lastIndexOf('.')) + ".0" + "/16";

        VoidConfiguration voidConfiguration = VoidConfiguration.builder()
                .unicastPort(40123) // Should be open for IN/OUT communications on all Spark nodes
                .networkMask(networkMask) // Local network mask
                .controllerAddress(controller)
                .build();
        TrainingMaster tm = new SharedTrainingMaster.Builder(voidConfiguration, 2, new FixedThresholdAlgorithm(1e-4), batch)
                .rngSeed(12345)
                .collectTrainingStats(false)
                .batchSizePerWorker(batch) // Minibatch size for each worker
                .workersPerNode(2) // Workers per node
                .build();


        SparkDl4jMultiLayer sparkNet = new SparkDl4jMultiLayer(sc, net, tm);

        //System.out.println(Arrays.toString(sparkNet.getNetwork().params().get(NDArrayIndex.point(0), NDArrayIndex.interval(0, 256)).dup().data().asFloat()));

        String fitPath = "file:///" + temp.getAbsolutePath().replaceAll("\\\\", "/");
        INDArray paramsBefore = net.params().dup();
        for( int j=0; j<3; j++ ) {
            sparkNet.fit(fitPath);
        }

        INDArray paramsAfter = net.params();
        assertNotEquals(paramsBefore, paramsAfter);

        //Also check we don't have any issues
        if(i == 0) {
            last = sparkNet.getNetwork().params();
            lastDup = last.dup();
        } else {
            assertEquals(lastDup, last);
        }
    }
}
 
Example 14
Source File: GradientSharingTrainingTest.java    From deeplearning4j with Apache License 2.0
@Test @Ignore
public void testEpochUpdating() throws Exception {
    //Ensure that epoch counter is incremented properly on the workers

    File temp = testDir.newFolder();

    //TODO this probably won't work everywhere...
    String controller = Inet4Address.getLocalHost().getHostAddress();
    String networkMask = controller.substring(0, controller.lastIndexOf('.')) + ".0" + "/16";

    VoidConfiguration voidConfiguration = VoidConfiguration.builder()
            .unicastPort(40123) // Should be open for IN/OUT communications on all Spark nodes
            .networkMask(networkMask) // Local network mask
            .controllerAddress(controller)
            .meshBuildMode(MeshBuildMode.PLAIN) // everyone is connected to the master
            .build();
    SharedTrainingMaster tm = new SharedTrainingMaster.Builder(voidConfiguration, 2, new AdaptiveThresholdAlgorithm(1e-3), 16)
            .rngSeed(12345)
            .collectTrainingStats(false)
            .batchSizePerWorker(16) // Minibatch size for each worker
            .workersPerNode(2) // Workers per node
            .exportDirectory("file:///" + temp.getAbsolutePath().replaceAll("\\\\", "/"))
            .build();


    ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
            .seed(12345)
            .updater(new AMSGrad(0.001))
            .graphBuilder()
            .addInputs("in")
            .layer("out", new OutputLayer.Builder().nIn(784).nOut(10).activation(Activation.SOFTMAX)
                    .lossFunction(LossFunctions.LossFunction.MCXENT).build(), "in")
            .setOutputs("out")
            .build();


    SparkComputationGraph sparkNet = new SparkComputationGraph(sc, conf, tm);
    sparkNet.setListeners(new TestListener());

    DataSetIterator iter = new MnistDataSetIterator(16, true, 12345);
    int count = 0;
    List<String> paths = new ArrayList<>();
    List<DataSet> ds = new ArrayList<>();
    File f = testDir.newFolder();
    while (iter.hasNext() && count++ < 8) {
        DataSet d = iter.next();
        File out = new File(f, count + ".bin");
        d.save(out);
        String path = "file:///" + out.getAbsolutePath().replaceAll("\\\\", "/");
        paths.add(path);
        ds.add(d);
    }

    JavaRDD<String> pathRdd = sc.parallelize(paths);
    for( int i=0; i<3; i++ ) {
        ThresholdAlgorithm ta = tm.getThresholdAlgorithm();
        sparkNet.fitPaths(pathRdd);
        //Check also that threshold algorithm was updated/averaged
        ThresholdAlgorithm taAfter = tm.getThresholdAlgorithm();
        assertTrue("Threshold algorithm should have been updated with different instance after averaging", ta != taAfter);
        AdaptiveThresholdAlgorithm ataAfter = (AdaptiveThresholdAlgorithm) taAfter;
        assertFalse(Double.isNaN(ataAfter.getLastSparsity()));
        assertFalse(Double.isNaN(ataAfter.getLastThreshold()));
    }

    Set<Integer> expectedEpochs = new HashSet<>(Arrays.asList(0, 1, 2));
    assertEquals(expectedEpochs, TestListener.epochs);
}
 
Example 15
Source File: TestValidation.java    From deeplearning4j with Apache License 2.0
@Test
public void testDataSetValidation() throws Exception {

    File f = folder.newFolder();

    for( int i=0; i<3; i++ ) {
        DataSet ds = new DataSet(Nd4j.create(1,10), Nd4j.create(1,10));
        ds.save(new File(f, i + ".bin"));
    }

    ValidationResult r = SparkDataValidation.validateDataSets(sc, f.toURI().toString());
    ValidationResult exp = ValidationResult.builder()
            .countTotal(3)
            .countTotalValid(3)
            .build();
    assertEquals(exp, r);

    //Add a DataSet that is corrupt (can't be loaded)
    File f3 = new File(f, "3.bin");
    FileUtils.writeStringToFile(f3, "This isn't a DataSet!");
    r = SparkDataValidation.validateDataSets(sc, f.toURI().toString());
    exp = ValidationResult.builder()
            .countTotal(4)
            .countTotalValid(3)
            .countTotalInvalid(1)
            .countLoadingFailure(1)
            .build();
    assertEquals(exp, r);
    f3.delete();


    //Add a DataSet with missing features:
    new DataSet(null, Nd4j.create(1,10)).save(f3);

    r = SparkDataValidation.validateDataSets(sc, f.toURI().toString());
    exp = ValidationResult.builder()
            .countTotal(4)
            .countTotalValid(3)
            .countTotalInvalid(1)
            .countMissingFeatures(1)
            .build();
    assertEquals(exp, r);

    r = SparkDataValidation.deleteInvalidDataSets(sc, f.toURI().toString());
    exp.setCountInvalidDeleted(1);
    assertEquals(exp, r);
    assertFalse(f3.exists());
    for( int i=0; i<3; i++ ){
        assertTrue(new File(f,i + ".bin").exists());
    }

    //Add DataSet with incorrect labels shape:
    new DataSet(Nd4j.create(1,10), Nd4j.create(1,20)).save(f3);
    r = SparkDataValidation.validateDataSets(sc, f.toURI().toString(), new int[]{-1,10}, new int[]{-1,10});
    exp = ValidationResult.builder()
            .countTotal(4)
            .countTotalValid(3)
            .countTotalInvalid(1)
            .countInvalidLabels(1)
            .build();

    assertEquals(exp, r);
}
 
Example 16
Source File: TestSparkMultiLayerParameterAveraging.java    From deeplearning4j with Apache License 2.0
@Test
public void testFitViaStringPaths() throws Exception {

    Path tempDir = testDir.newFolder("DL4J-testFitViaStringPaths").toPath();
    File tempDirF = tempDir.toFile();
    tempDirF.deleteOnExit();

    int dataSetObjSize = 5;
    int batchSizePerExecutor = 25;
    DataSetIterator iter = new MnistDataSetIterator(dataSetObjSize, 1000, false);
    int i = 0;
    while (iter.hasNext()) {
        File nextFile = new File(tempDirF, i + ".bin");
        DataSet ds = iter.next();
        ds.save(nextFile);
        i++;
    }

    System.out.println("Saved to: " + tempDirF.getAbsolutePath());

    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().updater(new RmsProp())
                    .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).list()
                    .layer(0, new org.deeplearning4j.nn.conf.layers.DenseLayer.Builder().nIn(28 * 28).nOut(50)
                                    .activation(Activation.TANH).build())
                    .layer(1, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(
                                    LossFunctions.LossFunction.MCXENT).nIn(50).nOut(10)
                                                    .activation(Activation.SOFTMAX).build())
                    .build();

    SparkDl4jMultiLayer sparkNet = new SparkDl4jMultiLayer(sc, conf,
                    new ParameterAveragingTrainingMaster.Builder(numExecutors(), dataSetObjSize)
                                    .workerPrefetchNumBatches(5).batchSizePerWorker(batchSizePerExecutor)
                                    .averagingFrequency(1).repartionData(Repartition.Always).build());
    sparkNet.setCollectTrainingStats(true);

    //List files:
    Configuration config = new Configuration();
    FileSystem hdfs = FileSystem.get(tempDir.toUri(), config);
    RemoteIterator<LocatedFileStatus> fileIter =
                    hdfs.listFiles(new org.apache.hadoop.fs.Path(tempDir.toString()), false);

    List<String> paths = new ArrayList<>();
    while (fileIter.hasNext()) {
        String path = fileIter.next().getPath().toString();
        paths.add(path);
    }

    INDArray paramsBefore = sparkNet.getNetwork().params().dup();
    JavaRDD<String> pathRdd = sc.parallelize(paths);
    sparkNet.fitPaths(pathRdd);

    INDArray paramsAfter = sparkNet.getNetwork().params().dup();
    assertNotEquals(paramsBefore, paramsAfter);

    SparkTrainingStats stats = sparkNet.getSparkTrainingStats();
    //System.out.println(stats.statsAsString());
    stats.statsAsString();

    sparkNet.getTrainingMaster().deleteTempFiles(sc);
}