Java Code Examples for org.nd4j.autodiff.samediff.SameDiff#fit()

The following examples show how to use org.nd4j.autodiff.samediff.SameDiff#fit(). You can vote up the examples you like or vote down the ones you don't, and go to the original project or source file by following the links above each example. You can also check out related API usage in the sidebar.
Example 1
Source File: CheckpointListenerTest.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
/**
 * Fits the model for three epochs with a listener configured through
 * {@code saveEveryNEpochs(1)} and verifies that the checkpoint directory
 * holds one checkpoint per epoch plus one additional file.
 */
@Test
public void testCheckpointEveryEpoch() throws Exception {
    File dir = testDir.newFolder();

    SameDiff sd = getModel();
    CheckpointListener listener = CheckpointListener.builder(dir)
            .saveEveryNEpochs(1)
            .build();
    sd.setListeners(listener);

    DataSetIterator iter = getIter();
    sd.fit(iter, 3);

    // An epoch is 10 iterations (0-9, 10-19, 20-29, ...), so each checkpoint
    // name carries the index of the last iteration of its epoch.
    String[] expectedNames = {
            "checkpoint-0_epoch-0_iter-9",
            "checkpoint-1_epoch-1_iter-19",
            "checkpoint-2_epoch-2_iter-29"
    };
    boolean[] seen = new boolean[expectedNames.length];

    File[] files = dir.listFiles();
    for (File file : files) {
        String path = file.getAbsolutePath();
        // Every expected name is tested against every file (no early exit).
        for (int i = 0; i < expectedNames.length; i++) {
            if (path.contains(expectedNames[i])) {
                seen[i] = true;
            }
        }
    }

    // 3 checkpoints and 1 text file (metadata)
    assertEquals(4, files.length);
    assertTrue(seen[0] && seen[1] && seen[2]);
}
 
Example 2
Source File: CheckpointListenerTest.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
/**
 * Fits the model for two epochs with a listener configured through
 * {@code saveEveryNIterations(5)} and verifies that the four expected
 * checkpoint files are present in the checkpoint directory.
 */
@Test
    public void testCheckpointEvery5Iter() throws Exception {
        File dir = testDir.newFolder();

        SameDiff sd = getModel();
        CheckpointListener listener = CheckpointListener.builder(dir)
                .saveEveryNIterations(5)
                .build();
        sd.setListeners(listener);

        DataSetIterator iter = getIter();
        // 2 epochs = 20 iterations
        sd.fit(iter, 2);

        List<String> expectedNames = Arrays.asList(
                "checkpoint-0_epoch-0_iter-4",
                "checkpoint-1_epoch-0_iter-9",
                "checkpoint-2_epoch-1_iter-14",
                "checkpoint-3_epoch-1_iter-19");
        boolean[] matched = new boolean[expectedNames.size()];

        File[] files = dir.listFiles();
        for (File file : files) {
            String path = file.getAbsolutePath();
            // Locate the first expected name contained in this path, if any;
            // only that first match is recorded for the file.
            int pos = 0;
            while (pos < expectedNames.size() && !path.contains(expectedNames.get(pos))) {
                pos++;
            }
            if (pos < expectedNames.size()) {
                matched[pos] = true;
            }
        }

        // 4 checkpoints and 1 text file (metadata)
        assertEquals(5, files.length);

        // The expected name doubles as the failure message.
        for (int k = 0; k < matched.length; k++) {
            assertTrue(expectedNames.get(k), matched[k]);
        }
    }
 
Example 3
Source File: ExecDebuggingListenerTest.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
/**
 * Builds a small softmax + log-loss graph and runs a single {@code fit} call
 * once for every {@code ExecDebuggingListener.PrintMode} value, with an
 * {@code ExecDebuggingListener} in that mode attached. The test contains no
 * assertions.
 */
@Test
    public void testExecDebugListener(){

        SameDiff sd = SameDiff.create();
        SDVariable input = sd.placeHolder("in", DataType.FLOAT, -1, 3);
        SDVariable labelVar = sd.placeHolder("label", DataType.FLOAT, 1, 2);
        SDVariable weights = sd.var("w", Nd4j.rand(DataType.FLOAT, 3, 2));
        SDVariable bias = sd.var("b", Nd4j.rand(DataType.FLOAT, 1, 2));
        SDVariable preActivation = input.mmul(weights).add(bias);
        SDVariable softmax = sd.nn.softmax("softmax", preActivation);
        // The result is not used locally; the call registers the "loss" op in the graph.
        SDVariable loss = sd.loss.logLoss("loss", labelVar, softmax);

        // Random arrays that feed the "in" and "label" placeholders.
        INDArray features = Nd4j.rand(DataType.FLOAT, 1, 3);
        INDArray labels = Nd4j.rand(DataType.FLOAT, 1, 2);

        TrainingConfig config = TrainingConfig.builder()
                .dataSetFeatureMapping("in")
                .dataSetLabelMapping("label")
                .updater(new Adam(0.001))
                .build();
        sd.setTrainingConfig(config);

        ExecDebuggingListener.PrintMode[] modes = ExecDebuggingListener.PrintMode.values();
        for (int idx = 0; idx < modes.length; idx++) {
            // Replace the listener so each pass uses exactly one print mode.
            sd.setListeners(new ExecDebuggingListener(modes[idx], -1, true));
            sd.fit(new DataSet(features, labels));

            // Blank lines to separate the output of consecutive modes.
            System.out.println("\n\n\n");
        }

    }
 
Example 4
Source File: UIListenerTest.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
/**
 * Trains a simple network on the Iris iterator for 20 epochs with a
 * {@code UIListener} attached, then checks that inference still produces a
 * non-null {@code [150, 3]} output while the listener remains registered.
 */
@Test
public void testUIListenerBasic() throws Exception {
    Nd4j.getRandom().setSeed(12345);

    IrisDataSetIterator irisIter = new IrisDataSetIterator(150, 150);

    SameDiff sd = getSimpleNet();

    File logDir = testDir.newFolder();
    File logFile = new File(logDir, "logFile.bin");
    UIListener uiListener = UIListener.builder(logFile)
            .plotLosses(1)
            .trainEvaluationMetrics("softmax", 0, Evaluation.Metric.ACCURACY, Evaluation.Metric.F1)
            .updateRatios(1)
            .build();
    sd.setListeners(uiListener);

    TrainingConfig config = TrainingConfig.builder()
            .dataSetFeatureMapping("in")
            .dataSetLabelMapping("label")
            .updater(new Adam(1e-1))
            .weightDecay(1e-3, true)
            .build();
    sd.setTrainingConfig(config);

    sd.fit(irisIter, 20);

    // Test inference after training with the UI listener still around.
    Map<String, INDArray> placeholders = new HashMap<>();
    // Rewind the iterator before pulling a batch to use as inference input.
    irisIter.reset();
    INDArray features = irisIter.next().getFeatures();
    placeholders.put("in", features);

    INDArray prediction = sd.outputSingle(placeholders, "softmax");
    assertNotNull(prediction);
    assertArrayEquals(new long[]{150, 3}, prediction.shape());
}
 
Example 5
Source File: CheckpointListenerTest.java    From deeplearning4j with Apache License 2.0 4 votes vote down vote up
/**
 * Runs five single-epoch {@code fit} calls, sleeping five seconds after each,
 * with a listener configured through {@code keepLast(2)} and
 * {@code saveEvery(4, TimeUnit.SECONDS)}, then verifies which checkpoint
 * files remain in the checkpoint directory.
 */
@Test
    public void testCheckpointListenerEveryTimeUnit() throws Exception {
        File dir = testDir.newFolder();
        SameDiff sd = getModel();

        CheckpointListener listener = new CheckpointListener.Builder(dir)
                .keepLast(2)
                .saveEvery(4, TimeUnit.SECONDS)
                .build();
        sd.setListeners(listener);

        DataSetIterator iter = getIter(15, 150);

        // NOTE(review): the original comment here read "10 iterations total";
        // presumably that is per fit() call rather than overall - confirm against getIter(15, 150).
        int round = 0;
        while (round < 5) {
            sd.fit(iter, 1);
            // The pause is longer than the 4 second save interval configured above.
            Thread.sleep(5000);
            round++;
        }

        // Expect models saved at iterations: 10, 20, 30, 40
        // But: keep only 30, 40
        File[] files = dir.listFiles();

        // 2 checkpoint files, 1 metadata file
        assertEquals(3, files.length);

        List<String> expectedNames = Arrays.asList(
                "checkpoint-2_epoch-3_iter-30",
                "checkpoint-3_epoch-4_iter-40");
        boolean[] matched = new boolean[expectedNames.size()];
        for (File file : files) {
            String path = file.getAbsolutePath();
            // Record only the first expected name found in this path, if any.
            int pos = 0;
            while (pos < expectedNames.size() && !path.contains(expectedNames.get(pos))) {
                pos++;
            }
            if (pos < expectedNames.size()) {
                matched[pos] = true;
            }
        }

        // The expected name doubles as the failure message.
        for (int k = 0; k < matched.length; k++) {
            assertTrue(expectedNames.get(k), matched[k]);
        }
    }
 
Example 6
Source File: CheckpointListenerTest.java    From deeplearning4j with Apache License 2.0 4 votes vote down vote up
/**
 * Trains for 20 epochs with a listener configured through
 * {@code keepLastAndEvery(3, 3)}, {@code saveEveryNEpochs(2)} and a custom
 * file name prefix, then parses the checkpoint and epoch numbers out of the
 * remaining {@code .bin} file names and checks them against the expected sets.
 */
@Test
public void testCheckpointListenerKeepLast3AndEvery3() throws Exception {
    File dir = testDir.newFolder();
    SameDiff sd = getModel();

    CheckpointListener listener = new CheckpointListener.Builder(dir)
            .keepLastAndEvery(3, 3)
            .saveEveryNEpochs(2)
            .fileNamePrefix("myFilePrefix")
            .build();
    sd.setListeners(listener);

    DataSetIterator iter = getIter();

    sd.fit(iter, 20);

    // Expect models saved at end of epochs: 1, 3, 5, 7, 9, 11, 13, 15, 17, 19
    // But: keep only 5, 11, 15, 17, 19
    final String epochTag = "epoch-";
    final String checkpointTag = "checkpoint-";
    Set<Integer> checkpointNumbers = new HashSet<>();
    Set<Integer> epochNumbers = new HashSet<>();

    File[] files = dir.listFiles();
    for (File candidate : files) {
        // Only ".bin" files are parsed; anything else in the directory is ignored.
        if (candidate.getPath().endsWith(".bin")) {
            String fileName = candidate.getName();

            // Epoch number: the digits between "epoch-" and the next underscore.
            int epochIdx = fileName.indexOf(epochTag);
            int epochEnd = fileName.indexOf("_", epochIdx);
            String epochText = fileName.substring(epochIdx + epochTag.length(), epochEnd);
            epochNumbers.add(Integer.parseInt(epochText));

            // Checkpoint number: the digits between "checkpoint-" and the next underscore.
            int checkpointIdx = fileName.indexOf(checkpointTag);
            int checkpointEnd = fileName.indexOf("_", checkpointIdx + checkpointTag.length());
            String checkpointText = fileName.substring(checkpointIdx + checkpointTag.length(), checkpointEnd);
            checkpointNumbers.add(Integer.parseInt(checkpointText));
        }
    }

    assertEquals(checkpointNumbers.toString(), 5, checkpointNumbers.size());
    Assert.assertTrue(checkpointNumbers.toString(), checkpointNumbers.containsAll(Arrays.asList(2, 5, 7, 8, 9)));
    Assert.assertTrue(epochNumbers.toString(), epochNumbers.containsAll(Arrays.asList(5, 11, 15, 17, 19)));

    assertEquals(5, listener.availableCheckpoints().size());
}