Java Code Examples for org.nd4j.autodiff.samediff.SameDiff#setListeners()

The following examples show how to use org.nd4j.autodiff.samediff.SameDiff#setListeners(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You can also check out related API usage in the sidebar.
Example 1
Source File: CheckpointListenerTest.java (from deeplearning4j, Apache License 2.0) - 5 votes
/**
 * Trains for 3 epochs with a checkpoint saved after every epoch, then verifies that exactly one
 * checkpoint per epoch (plus the metadata file) was written to the output directory.
 */
@Test
public void testCheckpointEveryEpoch() throws Exception {
    File checkpointDir = testDir.newFolder();

    SameDiff model = getModel();
    model.setListeners(CheckpointListener.builder(checkpointDir)
            .saveEveryNEpochs(1)
            .build());

    model.fit(getIter(), 3);

    //Note: epoch is 10 iterations, 0-9, 10-19, 20-29, etc
    File[] outputFiles = checkpointDir.listFiles();
    boolean foundEpoch0 = false;
    boolean foundEpoch1 = false;
    boolean foundEpoch2 = false;
    for (File file : outputFiles) {
        String path = file.getAbsolutePath();
        foundEpoch0 |= path.contains("checkpoint-0_epoch-0_iter-9");
        foundEpoch1 |= path.contains("checkpoint-1_epoch-1_iter-19");
        foundEpoch2 |= path.contains("checkpoint-2_epoch-2_iter-29");
    }
    assertEquals(4, outputFiles.length);  //3 checkpoints and 1 text file (metadata)
    assertTrue(foundEpoch0 && foundEpoch1 && foundEpoch2);
}
 
Example 2
Source File: CheckpointListenerTest.java (from deeplearning4j, Apache License 2.0) - 5 votes
/**
 * Trains for 2 epochs (20 iterations) with a checkpoint every 5 iterations and verifies that the
 * four expected checkpoints, plus the metadata file, are present.
 */
@Test
public void testCheckpointEvery5Iter() throws Exception {
    File checkpointDir = testDir.newFolder();

    SameDiff model = getModel();
    model.setListeners(CheckpointListener.builder(checkpointDir)
            .saveEveryNIterations(5)
            .build());

    model.fit(getIter(), 2);                //2 epochs = 20 iter

    File[] outputFiles = checkpointDir.listFiles();
    List<String> expected = Arrays.asList(
            "checkpoint-0_epoch-0_iter-4",
            "checkpoint-1_epoch-0_iter-9",
            "checkpoint-2_epoch-1_iter-14",
            "checkpoint-3_epoch-1_iter-19");
    boolean[] seen = new boolean[expected.size()];
    for (File file : outputFiles) {
        String path = file.getAbsolutePath();
        //Mark only the first expected name this path contains
        int idx = 0;
        while (idx < expected.size() && !path.contains(expected.get(idx))) {
            idx++;
        }
        if (idx < expected.size()) {
            seen[idx] = true;
        }
    }
    assertEquals(5, outputFiles.length);  //4 checkpoints and 1 text file (metadata)

    for (int j = 0; j < seen.length; j++) {
        assertTrue(expected.get(j), seen[j]);
    }
}
 
Example 3
Source File: ExecDebuggingListenerTest.java (from deeplearning4j, Apache License 2.0) - 5 votes
/**
 * Builds a tiny softmax/log-loss graph and runs one fit step per
 * {@link ExecDebuggingListener.PrintMode}, so every print mode gets exercised.
 * This is a smoke test: it makes no assertions.
 */
@Test
public void testExecDebugListener() {
    SameDiff graph = SameDiff.create();
    SDVariable input = graph.placeHolder("in", DataType.FLOAT, -1, 3);
    SDVariable labels = graph.placeHolder("label", DataType.FLOAT, 1, 2);
    SDVariable weights = graph.var("w", Nd4j.rand(DataType.FLOAT, 3, 2));
    SDVariable bias = graph.var("b", Nd4j.rand(DataType.FLOAT, 1, 2));
    SDVariable preSoftmax = input.mmul(weights).add(bias);
    SDVariable probabilities = graph.nn.softmax("softmax", preSoftmax);
    graph.loss.logLoss("loss", labels, probabilities);

    INDArray features = Nd4j.rand(DataType.FLOAT, 1, 3);
    INDArray targets = Nd4j.rand(DataType.FLOAT, 1, 2);

    graph.setTrainingConfig(TrainingConfig.builder()
            .dataSetFeatureMapping("in")
            .dataSetLabelMapping("label")
            .updater(new Adam(0.001))
            .build());

    ExecDebuggingListener.PrintMode[] modes = ExecDebuggingListener.PrintMode.values();
    for (int idx = 0; idx < modes.length; idx++) {
        //Replace the listener each time so only the current print mode is active
        graph.setListeners(new ExecDebuggingListener(modes[idx], -1, true));
        graph.fit(new DataSet(features, targets));

        System.out.println("\n\n\n");
    }
}
 
Example 4
Source File: UIListenerTest.java (from deeplearning4j, Apache License 2.0) - 5 votes
/**
 * Trains a simple network on Iris with a {@link UIListener} attached. It then checks that
 * inference still works, and produces a [150, 3] softmax output, while the listener is still
 * registered.
 */
@Test
public void testUIListenerBasic() throws Exception {
    Nd4j.getRandom().setSeed(12345);

    IrisDataSetIterator iter = new IrisDataSetIterator(150, 150);
    SameDiff sd = getSimpleNet();

    File logFile = new File(testDir.newFolder(), "logFile.bin");
    sd.setListeners(UIListener.builder(logFile)
            .plotLosses(1)
            .trainEvaluationMetrics("softmax", 0, Evaluation.Metric.ACCURACY, Evaluation.Metric.F1)
            .updateRatios(1)
            .build());

    TrainingConfig config = TrainingConfig.builder()
            .dataSetFeatureMapping("in")
            .dataSetLabelMapping("label")
            .updater(new Adam(1e-1))
            .weightDecay(1e-3, true)
            .build();
    sd.setTrainingConfig(config);
    sd.fit(iter, 20);

    //Inference after training, with the UI listener still attached
    iter.reset();
    INDArray features = iter.next().getFeatures();
    Map<String, INDArray> placeholders = new HashMap<>();
    placeholders.put("in", features);
    INDArray out = sd.outputSingle(placeholders, "softmax");
    assertNotNull(out);
    assertArrayEquals(new long[]{150, 3}, out.shape());
}
 
Example 5
Source File: ProfilingListenerTest.java (from deeplearning4j, Apache License 2.0) - 4 votes
/**
 * Runs inference 10 times with a {@link ProfilingListener} configured with {@code warmup(5)}.
 * It then checks that the JSON profile written to disk is non-empty and mentions each op the
 * expected number of times. Finally, it prints a profile summary.
 */
@Test
public void testProfilingListenerSimple() throws Exception {

    SameDiff sd = SameDiff.create();
    SDVariable in = sd.placeHolder("in", DataType.FLOAT, -1, 3);
    SDVariable label = sd.placeHolder("label", DataType.FLOAT, 1, 2);
    SDVariable w = sd.var("w", Nd4j.rand(DataType.FLOAT, 3, 2));
    SDVariable b = sd.var("b", Nd4j.rand(DataType.FLOAT, 1, 2));
    SDVariable sm = sd.nn.softmax("predictions", in.mmul("matmul", w).add("addbias", b));
    //The loss is part of the graph definition, but only "predictions" is computed below
    sd.loss.logLoss("loss", label, sm);

    //Only the "in" placeholder is needed for inference, so no label array is created
    INDArray input = Nd4j.rand(DataType.FLOAT, 1, 3);

    File dir = testDir.newFolder();
    File f = new File(dir, "test.json");
    ProfilingListener listener = ProfilingListener.builder(f)
            .recordAll()
            .warmup(5)
            .build();

    sd.setListeners(listener);

    Map<String, INDArray> ph = new HashMap<>();
    ph.put("in", input);

    for (int x = 0; x < 10; x++) {
        sd.outputSingle(ph, "predictions");
    }

    String content = FileUtils.readFileToString(f, StandardCharsets.UTF_8);
    assertFalse(content.isEmpty());

    //10 executions: 5 warmup + 5 profiled. Per the original author, each profiled execution
    //mentions every op twice (once as the op name, once as the op "instance" name), giving 10.
    //NOTE(review): confirm this factor of 2 against the ProfilingListener trace format
    String[] opNames = {"mmul", "add", "softmax"};
    for (String s : opNames) {
        assertEquals(s, 10, StringUtils.countMatches(content, s));
    }

    System.out.println("///////////////////////////////////////////");
    ProfileAnalyzer.summarizeProfile(f, ProfileAnalyzer.ProfileFormat.SAMEDIFF);
}
 
Example 6
Source File: CheckpointListenerTest.java (from deeplearning4j, Apache License 2.0) - 4 votes
/**
 * Checks time-based checkpointing. It saves on a 4-second interval while keeping only the last
 * 2 checkpoints. It then verifies which checkpoints remain after 5 epochs, each followed by a
 * 5-second pause.
 */
@Test
public void testCheckpointListenerEveryTimeUnit() throws Exception {
    File dir = testDir.newFolder();
    SameDiff sd = getModel();

    CheckpointListener l = new CheckpointListener.Builder(dir)
            .keepLast(2)
            .saveEvery(4, TimeUnit.SECONDS)
            .build();
    sd.setListeners(l);

    DataSetIterator iter = getIter(15, 150);

    //5 epochs of 10 iterations each (50 iterations total). Sleeping 5s (longer than the 4s save
    //interval) after each epoch means a checkpoint is expected at the start of each later epoch.
    for (int epoch = 0; epoch < 5; epoch++) {
        sd.fit(iter, 1);
        Thread.sleep(5000);
    }

    //Expect models saved at iterations: 10, 20, 30, 40
    //But: keep only 30, 40
    File[] files = dir.listFiles();

    assertEquals(3, files.length);  //2 files, 1 metadata file

    List<String> names = Arrays.asList(
            "checkpoint-2_epoch-3_iter-30",
            "checkpoint-3_epoch-4_iter-40");
    boolean[] found = new boolean[names.size()];
    for (File f : files) {
        String s = f.getAbsolutePath();
        for (int n = 0; n < names.size(); n++) {
            if (s.contains(names.get(n))) {
                found[n] = true;
                break;
            }
        }
    }

    for (int n = 0; n < found.length; n++) {
        assertTrue(names.get(n), found[n]);
    }
}
 
Example 7
Source File: CheckpointListenerTest.java (from deeplearning4j, Apache License 2.0) - 4 votes
/**
 * Saves a checkpoint every 2 epochs over 20 epochs, with {@code keepLastAndEvery(3, 3)} retention
 * and a custom file-name prefix. It verifies which checkpoint numbers and epochs survive on disk,
 * and that the listener reports the same number of available checkpoints.
 */
@Test
public void testCheckpointListenerKeepLast3AndEvery3() throws Exception {
    File dir = testDir.newFolder();
    SameDiff sd = getModel();

    CheckpointListener l = new CheckpointListener.Builder(dir)
            .keepLastAndEvery(3, 3)
            .saveEveryNEpochs(2)
            .fileNamePrefix("myFilePrefix")
            .build();
    sd.setListeners(l);

    DataSetIterator iter = getIter();

    sd.fit(iter, 20);

    //Expect models saved at end of epochs: 1, 3, 5, 7, 9, 11, 13, 15, 17, 19 (checkpoints 0-9)
    //But: keep only 5, 11, 15, 17, 19 (checkpoints 2, 5, 7, 8, 9)
    File[] files = dir.listFiles();
    Set<Integer> cpNums = new HashSet<>();
    Set<Integer> epochNums = new HashSet<>();
    for (File f2 : files) {
        if (!f2.getPath().endsWith(".bin")) {
            continue;   //Not a checkpoint file (e.g. the metadata text file)
        }
        String name = f2.getName();
        epochNums.add(parseIntAfter(name, "epoch-"));
        cpNums.add(parseIntAfter(name, "checkpoint-"));
    }

    assertEquals(cpNums.toString(), 5, cpNums.size());
    assertTrue(cpNums.toString(), cpNums.containsAll(Arrays.asList(2, 5, 7, 8, 9)));
    assertTrue(epochNums.toString(), epochNums.containsAll(Arrays.asList(5, 11, 15, 17, 19)));

    assertEquals(5, l.availableCheckpoints().size());
}

/**
 * Parses the integer that immediately follows {@code key} in {@code name}, up to the next
 * {@code '_'}. For example, {@code parseIntAfter("checkpoint-2_epoch-5_iter-59", "epoch-")}
 * returns 5.
 *
 * @throws IllegalArgumentException if {@code key} does not occur in {@code name}
 */
private static int parseIntAfter(String name, String key) {
    int keyPos = name.indexOf(key);
    if (keyPos < 0) {
        throw new IllegalArgumentException("'" + key + "' not found in file name: " + name);
    }
    int start = keyPos + key.length();
    int end = name.indexOf('_', start);
    return Integer.parseInt(name.substring(start, end));
}
 
Example 8
Source File: ImportModelDebugger.java (from deeplearning4j, Apache License 2.0) - 4 votes
/**
 * Debugging entry point for a single model. It imports a TensorFlow graph, attaches an
 * {@link ImportDebugListener} rooted at {@code rootDir}, and executes all graph outputs using
 * placeholders loaded from {@code rootDir}.
 *
 * <p>Usage: {@code ImportModelDebugger [modelFile [rootDir]]}. Any argument that is omitted
 * falls back to the original hard-coded developer path.
 *
 * @throws IllegalArgumentException if the model file does not exist
 */
public static void main(String[] args) {
    //Defaults keep the original (Windows-specific) developer paths; pass arguments to override
    File modelFile = new File(args.length > 0 ? args[0] : "C:\\Temp\\TF_Graphs\\cifar10_gan_85\\tf_model.pb");
    File rootDir = new File(args.length > 1 ? args[1] : "C:\\Temp\\TF_Graphs\\cifar10_gan_85");

    if (!modelFile.isFile()) {
        throw new IllegalArgumentException("Model file not found: " + modelFile.getAbsolutePath());
    }

    SameDiff sd = TFGraphMapper.importGraph(modelFile);

    ImportDebugListener l = ImportDebugListener.builder(rootDir)
            .checkShapesOnly(true)
            .floatingPointEps(1e-5)
            .onFailure(ImportDebugListener.OnFailure.EXCEPTION)
            .logPass(true)
            .build();

    sd.setListeners(l);

    Map<String, INDArray> ph = loadPlaceholders(rootDir);

    List<String> outputs = sd.outputs();

    sd.output(ph, outputs);
}