Java Code Examples for org.nd4j.autodiff.samediff.SameDiff#output()

The following examples show how to use `org.nd4j.autodiff.samediff.SameDiff#output()`. You can vote up the examples you find useful or vote down the ones you don't, and follow the link above each example to its original project or source file. Related API usage is listed in the sidebar.
Example 1
Source File: MiscOpValidation.java (from deeplearning4j, Apache License 2.0; 6 votes)
@Test
public void testBatchMmulBasic() {
    OpValidationSuite.ignoreFailing();  //https://github.com/deeplearning4j/deeplearning4j/issues/6873
    int M = 5;
    int N = 3;
    int K = 4;

    //A is [M,N] and B is [N,K] (see reshape calls), so each batch element is an ordinary [M,K] matrix product
    INDArray A = Nd4j.create(new float[]{1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}).reshape(M, N).castTo(DataType.DOUBLE);
    INDArray B = Nd4j.create(new float[]{1,2,3,4,5,6,7,8,9,10,11,12}).reshape(N, K).castTo(DataType.DOUBLE);

    SameDiff sd = SameDiff.create();

    SDVariable A1 = sd.var("A1", A);
    SDVariable A2 = sd.var("A2", A);
    SDVariable B1 = sd.var("B1", B);
    SDVariable B2 = sd.var("B2", B);

    //Batch of two products built from identical operands: A1*B1 and A2*B2
    SDVariable[] batchMul = sd.batchMmul(new SDVariable[] {A1, A2}, new SDVariable[] {B1, B2});
    Map<String,INDArray> m = sd.output(Collections.emptyMap(), sd.outputs());

    //Both batch elements use the same operands, so every output must match a plain A*B reference product
    INDArray expected = A.mmul(B);
    for (SDVariable out : batchMul) {
        assertEquals(expected, m.get(out.name()));
    }
}
 
Example 2
Source File: MiscOpValidation.java (from deeplearning4j, Apache License 2.0; 6 votes)
@Test
public void testMergeRank1(){
    //mergeAvg over a single rank-1 variable: both the forward output and the gradient w.r.t. the input must stay rank 1
    SameDiff sd = SameDiff.create();
    INDArray initValue = Nd4j.create(new long[]{1}).assign(5);
    SDVariable in = sd.var("in", initValue);

    SDVariable[] mergeInputs = new SDVariable[]{in};
    SDVariable merged = sd.math().mergeAvg("merged", mergeInputs);
    //Reduction of the merged output. No loss variable is set explicitly here; presumably calculateGradients
    //picks this reduction up as the graph's loss - TODO confirm
    SDVariable total = sd.sum(merged);

    Map<String,INDArray> forward = sd.output(Collections.emptyMap(), "merged");
    Map<String,INDArray> grads = sd.calculateGradients(null, "in");

    assertEquals(1, forward.get("merged").rank());
    assertEquals(1, grads.get("in").rank());
}
 
Example 3
Source File: ShapeOpValidation.java (from deeplearning4j, Apache License 2.0; 6 votes)
@Test
public void testGather2(){
    SameDiff sd = SameDiff.create();
    INDArray inArr = Nd4j.arange(6).castTo(DataType.FLOAT).reshape(2,3);
    SDVariable input = sd.var("in", inArr);
    SDVariable indices = sd.constant("indices", Nd4j.createFromArray(0));

    //Gather index 0 along dimension 1 of the [2,3] input
    SDVariable gathered = sd.gather(input, indices, 1);
    //NOTE(review): this std op is added to the graph, but the loss variable set below is `gathered` itself,
    //not this variable - confirm which one the gradient check is meant to use
    SDVariable loss = gathered.std(true);

    String gatheredName = gathered.name();
    Map<String,INDArray> noPlaceholders = null;
    sd.output(noPlaceholders, gatheredName);
    sd.setLossVariables(gatheredName);

    TestCase tc = new TestCase(sd)
            .gradCheckEpsilon(1e-3)
            .gradCheckMaxRelativeError(1e-4);
    //validate returns null when every check passes, otherwise an error description
    assertNull(OpValidation.validate(tc));
}
 
Example 4
Source File: RnnOpValidation.java (from deeplearning4j, Apache License 2.0; 4 votes)
//Checks all seven outputs of sd.rnn().lstmCell (peephole enabled, forget bias 1.0, clippingCellValue 0.0)
//against a single LSTM step computed by hand from the same random constant inputs.
@Test
public void testRnnBlockCell(){
    Nd4j.getRandom().setSeed(12345);
    int mb = 2;
    int nIn = 3;
    int nOut = 4;

    SameDiff sd = SameDiff.create();
    SDVariable x = sd.constant(Nd4j.rand(DataType.FLOAT, mb, nIn));
    SDVariable cLast = sd.constant(Nd4j.rand(DataType.FLOAT, mb, nOut));
    SDVariable yLast = sd.constant(Nd4j.rand(DataType.FLOAT, mb, nOut));
    //W: rows [0,nIn) multiply x, rows [nIn,nIn+nOut) multiply yLast; four column blocks of width nOut, one per gate
    SDVariable W = sd.constant(Nd4j.rand(DataType.FLOAT, (nIn+nOut), 4*nOut));
    SDVariable Wci = sd.constant(Nd4j.rand(DataType.FLOAT, nOut));
    SDVariable Wcf = sd.constant(Nd4j.rand(DataType.FLOAT, nOut));
    SDVariable Wco = sd.constant(Nd4j.rand(DataType.FLOAT, nOut));
    SDVariable b = sd.constant(Nd4j.rand(DataType.FLOAT, 4*nOut));

    double fb = 1.0;
    //clippingCellValue 0.0: the manual computation below applies no clipping, so presumably 0 disables it - TODO confirm
    LSTMConfiguration conf = LSTMConfiguration.builder()
            .peepHole(true)
            .forgetBias(fb)
            .clippingCellValue(0.0)
            .build();

    LSTMWeights weights = LSTMWeights.builder().weights(W).bias(b)
            .inputPeepholeWeights(Wci).forgetPeepholeWeights(Wcf).outputPeepholeWeights(Wco).build();

    LSTMCellOutputs v = new LSTMCellOutputs(sd.rnn().lstmCell(x, cLast, yLast, weights, conf));  //Output order: i, c, f, o, z, h, y
    //toExec holds the output names in that same order; the index passed to toExec.get(...) below selects the output
    List<String> toExec = new ArrayList<>();
    for(SDVariable sdv : v.getAllOutputs()){
        toExec.add(sdv.name());
    }

    //Test forward pass:
    Map<String,INDArray> m = sd.output(null, toExec);

    //Weights and bias column blocks, as sliced below: [i, z, f, o]
    //Each expected value is built with in-place ops (addi, Transforms.*(arr, false)), so statement order matters:
    //zExp, iExp, fExp and cExp are reused later in their post-activation form.

    //Block input (z) - post tanh:
    INDArray wz_x = W.getArr().get(NDArrayIndex.interval(0,nIn), NDArrayIndex.interval(nOut, 2*nOut));           //Input weights
    INDArray wz_r = W.getArr().get(NDArrayIndex.interval(nIn,nIn+nOut), NDArrayIndex.interval(nOut, 2*nOut));    //Recurrent weights
    INDArray bz = b.getArr().get(NDArrayIndex.interval(nOut, 2*nOut));

    INDArray zExp = x.getArr().mmul(wz_x).addiRowVector(bz);        //[mb,nIn]*[nIn, nOut] + [nOut]
    zExp.addi(yLast.getArr().mmul(wz_r));   //[mb,nOut]*[nOut,nOut]
    //Second argument false: the activation is applied to zExp itself (the return value is not used)
    Transforms.tanh(zExp, false);

    INDArray zAct = m.get(toExec.get(4));
    assertEquals(zExp, zAct);

    //Input modulation gate (post sigmoid) - i: (note: peephole input - last time step)
    INDArray wi_x = W.getArr().get(NDArrayIndex.interval(0,nIn), NDArrayIndex.interval(0, nOut));           //Input weights
    INDArray wi_r = W.getArr().get(NDArrayIndex.interval(nIn,nIn+nOut), NDArrayIndex.interval(0, nOut));    //Recurrent weights
    INDArray bi = b.getArr().get(NDArrayIndex.interval(0, nOut));

    INDArray iExp = x.getArr().mmul(wi_x).addiRowVector(bi);        //[mb,nIn]*[nIn, nOut] + [nOut]
    iExp.addi(yLast.getArr().mmul(wi_r));   //[mb,nOut]*[nOut,nOut]
    iExp.addi(cLast.getArr().mulRowVector(Wci.getArr()));    //Peephole
    Transforms.sigmoid(iExp, false);
    assertEquals(iExp, m.get(toExec.get(0)));

    //Forget gate (post sigmoid): (note: peephole input - last time step)
    INDArray wf_x = W.getArr().get(NDArrayIndex.interval(0,nIn), NDArrayIndex.interval(2*nOut, 3*nOut));           //Input weights
    INDArray wf_r = W.getArr().get(NDArrayIndex.interval(nIn,nIn+nOut), NDArrayIndex.interval(2*nOut, 3*nOut));    //Recurrent weights
    INDArray bf = b.getArr().get(NDArrayIndex.interval(2*nOut, 3*nOut));

    INDArray fExp = x.getArr().mmul(wf_x).addiRowVector(bf);        //[mb,nIn]*[nIn, nOut] + [nOut]
    fExp.addi(yLast.getArr().mmul(wf_r));   //[mb,nOut]*[nOut,nOut]
    fExp.addi(cLast.getArr().mulRowVector(Wcf.getArr()));   //Peephole
    //Forget bias from the configuration is added before the sigmoid
    fExp.addi(fb);
    Transforms.sigmoid(fExp, false);
    assertEquals(fExp, m.get(toExec.get(2)));

    //Cell state (pre tanh): tanh(z) .* sigmoid(i) + sigmoid(f) .* cLast
    INDArray cExp = zExp.mul(iExp).add(fExp.mul(cLast.getArr()));
    INDArray cAct = m.get(toExec.get(1));
    assertEquals(cExp, cAct);

    //Output gate (post sigmoid): (note: peephole input: current time step)
    INDArray wo_x = W.getArr().get(NDArrayIndex.interval(0,nIn), NDArrayIndex.interval(3*nOut, 4*nOut));           //Input weights
    INDArray wo_r = W.getArr().get(NDArrayIndex.interval(nIn,nIn+nOut), NDArrayIndex.interval(3*nOut, 4*nOut));    //Recurrent weights
    INDArray bo = b.getArr().get(NDArrayIndex.interval(3*nOut, 4*nOut));

    INDArray oExp = x.getArr().mmul(wo_x).addiRowVector(bo);        //[mb,nIn]*[nIn, nOut] + [nOut]
    oExp.addi(yLast.getArr().mmul(wo_r));   //[mb,nOut]*[nOut,nOut]
    oExp.addi(cExp.mulRowVector(Wco.getArr())); //Peephole
    Transforms.sigmoid(oExp, false);
    assertEquals(oExp, m.get(toExec.get(3)));

    //Cell state, post tanh
    //Second argument true here, unlike above: cExp is left untouched and a new array is returned
    INDArray hExp = Transforms.tanh(cExp, true);
    assertEquals(hExp, m.get(toExec.get(5)));

    //Final output
    INDArray yExp = hExp.mul(oExp);
    assertEquals(yExp, m.get(toExec.get(6)));
}
 
Example 5
Source File: RnnOpValidation.java (from deeplearning4j, Apache License 2.0; 4 votes)
@Test
public void testRnnBlockCellManualTFCompare() {
    //Test case: "rnn/lstmblockcell/static_batch1_n3-2_tsLength1_noPH_noClip_fBias1_noIS"
    //Batch 1, nIn 3, nOut 2; zero initial state, zero bias and zero peephole weights, no peepholes, forget bias 1.0.
    //The expected values below are hard-coded (presumably produced by TensorFlow for this case - TODO confirm).

    SameDiff sd = SameDiff.create();
    INDArray zero2d = Nd4j.createFromArray(new float[][]{{0,0}});
    INDArray zero1d = Nd4j.createFromArray(new float[]{0,0});
    INDArray xArr = Nd4j.createFromArray(new float[][]{{0.7787856f,0.80119777f,0.72437465f}});
    SDVariable x = sd.constant(xArr);
    SDVariable cLast = sd.constant(zero2d);
    SDVariable yLast = sd.constant(zero2d);

    //Weights shape: [(nIn+nOut), 4*nOut] = [5, 8]
    INDArray wArr = Nd4j.createFromArray(-0.61977,-0.5708851,-0.38089648,-0.07994056,-0.31706482,0.21500933,-0.35454142,-0.3239095,-0.3177906,
            0.39918554,-0.3115911,0.540841,0.38552666,0.34270835,-0.63456273,-0.13917702,-0.2985368,0.343238,
            -0.3178353,0.017154932,-0.060259163,0.28841054,-0.6257687,0.65097713,0.24375653,-0.22315514,0.2033832,
            0.24894875,-0.2062299,-0.2242794,-0.3809483,-0.023048997,-0.036284804,-0.46398938,-0.33979666,0.67012596,
            -0.42168984,0.34208286,-0.0456419,0.39803517).castTo(DataType.FLOAT).reshape(5,8);
    SDVariable W = sd.constant(wArr);
    SDVariable Wci = sd.constant(zero1d);
    SDVariable Wcf = sd.constant(zero1d);
    SDVariable Wco = sd.constant(zero1d);
    SDVariable b = sd.constant(Nd4j.zeros(DataType.FLOAT, 8));

    double forgetBias = 1.0;
    LSTMConfiguration conf = LSTMConfiguration.builder()
            .peepHole(false)
            .forgetBias(forgetBias)
            .clippingCellValue(0.0)
            .build();

    LSTMWeights weights = LSTMWeights.builder()
            .weights(W)
            .bias(b)
            .inputPeepholeWeights(Wci)
            .forgetPeepholeWeights(Wcf)
            .outputPeepholeWeights(Wco)
            .build();

    //Output order: i, c, f, o, z, h, y
    LSTMCellOutputs cellOut = new LSTMCellOutputs(sd.rnn().lstmCell(x, cLast, yLast, weights, conf));
    List<String> outNames = new ArrayList<>();
    for (SDVariable out : cellOut.getAllOutputs()) {
        outNames.add(out.name());
    }

    //Test forward pass:
    Map<String,INDArray> results = sd.output(null, outNames);

    //Expected values, indexed in the same order as outNames
    INDArray[] expected = new INDArray[]{
            Nd4j.create(new float[]{0.27817473f, 0.53092605f}, new int[]{1,2}),     //Input modulation gate
            Nd4j.create(new float[]{-0.18100877f, 0.19417824f}, new int[]{1,2}),    //Cell state (pre tanh)
            Nd4j.create(new float[]{0.73464274f, 0.83901811f}, new int[]{1,2}),     //Forget gate
            Nd4j.create(new float[]{0.22481689f, 0.52692068f}, new int[]{1,2}),     //Output gate
            Nd4j.create(new float[]{-0.65070170f, 0.36573499f}, new int[]{1,2}),    //Block input
            Nd4j.create(new float[]{-0.17905743f, 0.19177397f}, new int[]{1,2}),    //Cell state
            Nd4j.create(new float[]{-0.04025514f, 0.10104967f}, new int[]{1,2})     //Output
    };

    for (int idx = 0; idx < expected.length; idx++) {
        assertEquals(expected[idx], results.get(outNames.get(idx)));
    }
}
 
Example 6
Source File: RnnOpValidation.java (from deeplearning4j, Apache License 2.0; 4 votes)
//Checks the four outputs of sd.rnn().gruCell (reset gate, update gate, candidate state, new hidden state)
//against a single GRU step computed by hand from the same random constant inputs.
@Test
public void testGRUCell(){
    Nd4j.getRandom().setSeed(12345);
    int mb = 2;
    int nIn = 3;
    int nOut = 4;

    SameDiff sd = SameDiff.create();
    SDVariable x = sd.constant(Nd4j.rand(DataType.FLOAT, mb, nIn));
    SDVariable hLast = sd.constant(Nd4j.rand(DataType.FLOAT, mb, nOut));
    //Wru/Wc: rows [0,nIn) multiply x, rows [nIn,nIn+nOut) multiply the recurrent input
    SDVariable Wru = sd.constant(Nd4j.rand(DataType.FLOAT, (nIn+nOut), 2*nOut));
    SDVariable Wc = sd.constant(Nd4j.rand(DataType.FLOAT, (nIn+nOut), nOut));
    SDVariable bru = sd.constant(Nd4j.rand(DataType.FLOAT, 2*nOut));
    SDVariable bc = sd.constant(Nd4j.rand(DataType.FLOAT, nOut));

    GRUWeights weights = GRUWeights.builder()
            .ruWeight(Wru)
            .cWeight(Wc)
            .ruBias(bru)
            .cBias(bc)
            .build();

    //Outputs in order: r, u, c, h (matching the indices used with toExec below)
    SDVariable[] v = sd.rnn().gruCell(x, hLast, weights);
    List<String> toExec = new ArrayList<>();
    for(SDVariable sdv : v){
        toExec.add(sdv.name());
    }

    //Test forward pass:
    Map<String,INDArray> m = sd.output(null, toExec);

    //Weights and bias order: [r, u], [c]

    //Reset gate:
    INDArray wr_x = Wru.getArr().get(NDArrayIndex.interval(0,nIn), NDArrayIndex.interval(0, nOut));           //Input weights
    INDArray wr_r = Wru.getArr().get(NDArrayIndex.interval(nIn,nIn+nOut), NDArrayIndex.interval(0, nOut));    //Recurrent weights
    INDArray br = bru.getArr().get(NDArrayIndex.interval(0, nOut));

    INDArray rExp = x.getArr().mmul(wr_x).addiRowVector(br);        //[mb,nIn]*[nIn, nOut] + [nOut]
    rExp.addi(hLast.getArr().mmul(wr_r));   //[mb,nOut]*[nOut,nOut]
    Transforms.sigmoid(rExp,false);         //in place: rExp now holds the activated reset gate

    INDArray rAct = m.get(toExec.get(0));
    assertEquals(rExp, rAct);

    //Update gate:
    INDArray wu_x = Wru.getArr().get(NDArrayIndex.interval(0,nIn), NDArrayIndex.interval(nOut, 2*nOut));           //Input weights
    INDArray wu_r = Wru.getArr().get(NDArrayIndex.interval(nIn,nIn+nOut), NDArrayIndex.interval(nOut, 2*nOut));    //Recurrent weights
    INDArray bu = bru.getArr().get(NDArrayIndex.interval(nOut, 2*nOut));

    INDArray uExp = x.getArr().mmul(wu_x).addiRowVector(bu);        //[mb,nIn]*[nIn, nOut] + [nOut]
    uExp.addi(hLast.getArr().mmul(wu_r));   //[mb,nOut]*[nOut,nOut]
    Transforms.sigmoid(uExp,false);

    INDArray uAct = m.get(toExec.get(1));
    assertEquals(uExp, uAct);

    //c = tanh(x * Wcx + (hLast .* r) * Wcr + bc)
    INDArray Wcx = Wc.getArr().get(NDArrayIndex.interval(0,nIn), NDArrayIndex.all());
    INDArray Wcr = Wc.getArr().get(NDArrayIndex.interval(nIn, nIn+nOut), NDArrayIndex.all());
    INDArray cExp = x.getArr().mmul(Wcx);
    cExp.addi(hLast.getArr().mul(rExp).mmul(Wcr));
    cExp.addiRowVector(bc.getArr());
    Transforms.tanh(cExp, false);

    assertEquals(cExp, m.get(toExec.get(2)));

    //h = u * hLast + (1-u) * c
    INDArray hExp = uExp.mul(hLast.getArr()).add(uExp.rsub(1.0).mul(cExp));
    assertEquals(hExp, m.get(toExec.get(3)));
}
 
Example 7
Source File: ShapeOpValidation.java (from deeplearning4j, Apache License 2.0; 4 votes)
@Test
public void testDistancesExec(){
    //Regression test for https://github.com/deeplearning4j/deeplearning4j/issues/7001
    //For each distance op: split a 4x4 matrix into left/right column halves, run the forward pass,
    //then compute gradients for every variable in the graph.
    String[] distanceOps = new String[]{"euclidean", "manhattan", "cosinesim", "cosinedist", "jaccard"};
    for (String s : distanceOps) {
        log.info("Starting: {}", s);
        INDArray data = Nd4j.create(4, 4);
        data.putRow(0, Nd4j.create(new float[]{0, 2, -2, 0}));
        data.putRow(1, Nd4j.create(new float[]{0, 1, -1, 0}));
        data.putRow(2, Nd4j.create(new float[]{0, -1, 1, 0}));
        data.putRow(3, Nd4j.create(new float[]{0, -2, 2, 0}));
        long numCols = data.size(1);
        long half = numCols / 2L;

        // Split vectors: x = columns [0, half), y = columns [half, numCols)
        INDArray x = data.get(NDArrayIndex.all(), NDArrayIndex.interval(0, half));
        INDArray y = data.get(NDArrayIndex.all(), NDArrayIndex.interval(half, numCols));

        log.info(y.shapeInfoToString());

        SameDiff sd = SameDiff.create();
        sd.enableDebugMode();

        SDVariable xSd = sd.var("x", x);
        SDVariable yRaw = sd.var("y", y);
        SDVariable ySd = yRaw.add(yRaw);

        //The op is named after the distance type, so `s` doubles as the output name requested below
        SDVariable dist;
        if ("euclidean".equals(s)) {
            dist = sd.math().euclideanDistance(s, ySd, xSd, 0);
        } else if ("manhattan".equals(s)) {
            dist = sd.math().manhattanDistance(s, ySd, xSd, 0);
        } else if ("cosinesim".equals(s)) {
            dist = sd.math().cosineSimilarity(s, ySd, xSd, 0);
        } else if ("cosinedist".equals(s)) {
            dist = sd.math().cosineDistance(s, ySd, xSd, 0);
        } else if ("jaccard".equals(s)) {
            dist = sd.math().jaccardDistance(s, ySd, xSd, 0);
        } else {
            throw new RuntimeException();
        }

        //Scalar reduction of the distance; never set as the loss explicitly - presumably calculateGradients
        //infers it as the graph's loss (TODO confirm)
        SDVariable loss = dist.sum();

        sd.output(Collections.emptyMap(), Lists.newArrayList(s));
        sd.calculateGradients(Collections.emptyMap(), sd.getVariables().keySet());
    }
}
 
Example 8
Source File: ValidateZooModelPredictions.java (from deeplearning4j, Apache License 2.0; 4 votes)
//Imports the MobileNet V1 (0.5, 128) TF zoo model, runs it on a saved golden retriever image and checks the
//top-1 predicted label. Diagnostics go through the class logger instead of stdout.
@Test
public void testMobilenetV1() throws Exception {
    if(TFGraphTestZooModels.isPPC()){
        /*
        Ugly hack to temporarily disable tests on PPC only on CI
        Issue logged here: https://github.com/deeplearning4j/deeplearning4j/issues/7657
        These will be re-enabled for PPC once fixed - in the mean time, remaining tests will be used to detect and prevent regressions
         */

        log.warn("TEMPORARILY SKIPPING TEST ON PPC ARCHITECTURE DUE TO KNOWN JVM CRASH ISSUES - SEE https://github.com/deeplearning4j/deeplearning4j/issues/7657");
        OpValidationSuite.ignoreFailing();
    }

    TFGraphTestZooModels.currentTestDir = testDir.newFolder();

    //Load model
    String path = "tf_graphs/zoo_models/mobilenet_v1_0.5_128/tf_model.txt";
    File resource = new ClassPathResource(path).getFile();
    SameDiff sd = TFGraphTestZooModels.LOADER.apply(resource, "mobilenet_v1_0.5_128");

    //Load data
    //Because we don't have DataVec NativeImageLoader in ND4J tests due to circular dependencies, we'll load the image previously saved...
    File imgFile = new ClassPathResource("deeplearning4j-zoo/goldenretriever_rgb128_unnormalized_nchw_INDArray.bin").getFile();
    INDArray img = Nd4j.readBinary(imgFile).castTo(DataType.FLOAT);
    img = img.permute(0,2,3,1).dup();   //to NHWC

    //Mobilenet V1 - not sure, but probably using inception preprocessing
    //i.e., scale to (-1,1) range
    //Image is originally 0 to 255
    img.divi(255).subi(0.5).muli(2);

    //Sanity check that the scaling above actually produced values spread across roughly (-1, 1)
    double min = img.minNumber().doubleValue();
    double max = img.maxNumber().doubleValue();

    assertTrue("Unexpected minimum pixel value after scaling: " + min, min >= -1 && min <= -0.6);
    assertTrue("Unexpected maximum pixel value after scaling: " + max, max <= 1 && max >= 0.6);

    //Perform inference
    List<String> inputs = sd.inputs();
    assertEquals(1, inputs.size());

    String out = "MobilenetV1/Predictions/Softmax";
    Map<String,INDArray> m = sd.output(Collections.singletonMap(inputs.get(0), img), out);

    INDArray outArr = m.get(out);

    log.info("SHAPE: {}", Arrays.toString(outArr.shape()));
    log.info("{}", outArr);

    //Top-1 class index for the first (only) example in the batch
    INDArray argmax = outArr.argMax(1);

    //Load labels
    List<String> labels = labels();

    int classIdx = argmax.getInt(0);
    String className = labels.get(classIdx);
    String expClass = "golden retriever";
    double prob = outArr.getDouble(classIdx);

    log.info("Predicted class: \"{}\" - probability = {}", className, prob);
    assertEquals(expClass, className);
}
 
Example 9
Source File: ValidateZooModelPredictions.java (from deeplearning4j, Apache License 2.0; 4 votes)
//Imports the ResNet V2 ImageNet frozen TF graph, runs it on a saved golden retriever image (no external
//normalization) and checks the top-1 predicted label. Diagnostics go through the class logger instead of stdout.
@Test
public void testResnetV2() throws Exception {
    if(TFGraphTestZooModels.isPPC()){
        /*
        Ugly hack to temporarily disable tests on PPC only on CI
        Issue logged here: https://github.com/deeplearning4j/deeplearning4j/issues/7657
        These will be re-enabled for PPC once fixed - in the mean time, remaining tests will be used to detect and prevent regressions
         */

        log.warn("TEMPORARILY SKIPPING TEST ON PPC ARCHITECTURE DUE TO KNOWN JVM CRASH ISSUES - SEE https://github.com/deeplearning4j/deeplearning4j/issues/7657");
        OpValidationSuite.ignoreFailing();
    }

    TFGraphTestZooModels.currentTestDir = testDir.newFolder();

    //Load model
    String path = "tf_graphs/zoo_models/resnetv2_imagenet_frozen_graph/tf_model.txt";
    File resource = new ClassPathResource(path).getFile();
    SameDiff sd = TFGraphTestZooModels.LOADER.apply(resource, "resnetv2_imagenet_frozen_graph");

    //Load data
    //Because we don't have DataVec NativeImageLoader in ND4J tests due to circular dependencies, we'll load the image previously saved...
    File imgFile = new ClassPathResource("deeplearning4j-zoo/goldenretriever_rgb224_unnormalized_nchw_INDArray.bin").getFile();
    INDArray img = Nd4j.readBinary(imgFile).castTo(DataType.FLOAT);
    img = img.permute(0,2,3,1).dup();   //to NHWC

    //Resnet v2 - NO external normalization, just resize and center crop
    // https://github.com/tensorflow/models/blob/d32d957a02f5cffb745a4da0d78f8432e2c52fd4/research/tensorrt/tensorrt.py#L70
    // https://github.com/tensorflow/models/blob/1af55e018eebce03fb61bba9959a04672536107d/official/resnet/imagenet_preprocessing.py#L253-L256

    //Perform inference
    List<String> inputs = sd.inputs();
    assertEquals(1, inputs.size());

    String out = "softmax_tensor";
    Map<String,INDArray> m = sd.output(Collections.singletonMap(inputs.get(0), img), out);

    INDArray outArr = m.get(out);

    log.info("SHAPE: {}", Arrays.toString(outArr.shape()));
    log.info("{}", outArr);

    //Top-1 class index for the first (only) example in the batch
    INDArray argmax = outArr.argMax(1);

    //Load labels
    List<String> labels = labels();

    int classIdx = argmax.getInt(0);
    String className = labels.get(classIdx);
    String expClass = "golden retriever";
    double prob = outArr.getDouble(classIdx);

    log.info("Predicted class: {} - \"{}\" - probability = {}", classIdx, className, prob);
    assertEquals(expClass, className);
}
 
Example 10
Source File: ImportModelDebugger.java (from deeplearning4j, Apache License 2.0; 4 votes)
public static void main(String[] args) {
    //Debugs a TF import by running the graph with an ImportDebugListener attached.
    //Usage: ImportModelDebugger [rootDir] [modelFile]
    //  rootDir   - directory passed to the listener and to loadPlaceholders; defaults to the original hard-coded path
    //  modelFile - TF model to import; defaults to tf_model.pb inside rootDir
    String defaultRootDir = "C:\\Temp\\TF_Graphs\\cifar10_gan_85";
    File rootDir = new File(args.length > 0 ? args[0] : defaultRootDir);
    File modelFile = args.length > 1 ? new File(args[1]) : new File(rootDir, "tf_model.pb");

    //Fail early with a clear message instead of an import error deep inside the mapper
    if (!modelFile.isFile()) {
        throw new IllegalArgumentException("TF model file not found: " + modelFile.getAbsolutePath()
                + " (usage: ImportModelDebugger [rootDir] [modelFile])");
    }

    SameDiff sd = TFGraphMapper.importGraph(modelFile);

    ImportDebugListener l = ImportDebugListener.builder(rootDir)
            .checkShapesOnly(true)
            .floatingPointEps(1e-5)
            .onFailure(ImportDebugListener.OnFailure.EXCEPTION)
            .logPass(true)
            .build();

    sd.setListeners(l);

    Map<String,INDArray> ph = loadPlaceholders(rootDir);

    List<String> outputs = sd.outputs();

    sd.output(ph, outputs);
}