org.nd4j.imports.graphmapper.tf.TFGraphMapper Java Examples

The following examples show how to use `org.nd4j.imports.graphmapper.tf.TFGraphMapper`. You can vote each example up or down, and you can follow the links above each example to the original project or source file. Related API usage is listed in the sidebar.
Example #1
Source File: TensorFlowImportTest.java    From nd4j with Apache License 2.0 6 votes vote down vote up
@Test
@Ignore
public void testCrash_119_matrix_diag() throws Exception {
    // Result is discarded; presumably called only to force ND4J backend initialisation up front.
    Nd4j.create(1);

    val resourcePath = "tf_graphs/examples/partition_stitch_misc/frozen_model.pb";
    val imported = TFGraphMapper.getInstance().importGraph(new ClassPathResource(resourcePath).getInputStream());
    assertNotNull(imported);

    // Constant-filled arrays for the graph inputs input_0 (rank 3), input_1 and input_2 (rank 4).
    val first = Nd4j.create(2, 5, 4).assign(1.0);
    val second = Nd4j.create(2, 3, 5, 4).assign(2.0);
    val third = Nd4j.create(3, 1, 5, 4).assign(3.0);

    imported.associateArrayWithVariable(first, imported.getVariable("input_0"));
    imported.associateArrayWithVariable(second, imported.getVariable("input_1"));
    imported.associateArrayWithVariable(third, imported.getVariable("input_2"));

    // Write the FlatBuffers form into libnd4j's CPU test resources.
    val target = new File("../../../libnd4j/tests_cpu/resources/partition_stitch_misc.fb");
    imported.asFlatFile(target);
}
 
Example #2
Source File: TensorArray.java    From deeplearning4j with Apache License 2.0 6 votes vote down vote up
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
    // The node's last input names another node in the graph; look that node up by name.
    val idd = nodeDef.getInput(nodeDef.getInputCount() - 1);
    NodeDef iddNode = null;
    for (int i = 0; i < graph.getNodeCount(); i++) {
        NodeDef candidate = graph.getNode(i);
        if (candidate.getName().equals(idd)) {
            // Node names are unique within a TF GraphDef, so stop at the first match.
            iddNode = candidate;
            break;
        }
    }

    // When the referenced node exists and yields an array, its first element becomes an
    // integer argument of this op. A missing node is skipped instead of being passed on as null.
    if (iddNode != null) {
        val arr = TFGraphMapper.getNDArrayFromTensor(iddNode);
        if (arr != null) {
            addIArgument(arr.getInt(0));
        }
    }

    // Element data type of the tensor array; fail with a descriptive message instead of an NPE.
    AttrValue dtypeAttr = attributesForNode.get("dtype");
    if (dtypeAttr == null) {
        throw new IllegalStateException("TensorArray node \"" + nodeDef.getName()
                + "\" is missing required attribute \"dtype\"");
    }
    this.tensorArrayDataType = TFGraphMapper.convertType(dtypeAttr.getType());
}
 
Example #3
Source File: TensorFlowImportTest.java    From deeplearning4j with Apache License 2.0 6 votes vote down vote up
@Test
public void testWhileMapping2() throws Exception {
    Nd4j.create(1);
    val whileGraph = TFGraphMapper.importGraph(new ClassPathResource("tf_graphs/examples/simplewhile_0/frozen_model.pb").getInputStream());
    assertNotNull(whileGraph);

    // Bind a scalar 4.0 to the graph input "input_1".
    val scalarInput = Nd4j.scalar(4.0);
    whileGraph.associateArrayWithVariable(scalarInput, whileGraph.getVariable("input_1"));

    val flatTarget = new File("../../../libnd4j/tests_cpu/resources/simplewhile_0_4.fb");
    whileGraph.asFlatFile(flatTarget);

    // Output verification is currently disabled; when enabled it expected a 2x2 array of 2s:
    //   val array = whileGraph.outputAll(null).get(whileGraph.outputs().get(0));
    //   assertNotNull(array);
    //   assertEquals(Nd4j.create(2, 2).assign(2), array);
}
 
Example #4
Source File: TensorFlowImportTest.java    From deeplearning4j with Apache License 2.0 6 votes vote down vote up
@Test
public void testIntermediateLoop3() throws Exception {
    Nd4j.create(1);
    val nested = TFGraphMapper.importGraph(new ClassPathResource("tf_graphs/nested_while.pb.txt").getInputStream());
    assertNotNull(nested);

    // Serialise to FlatBuffers and inspect the result through the generated FlatGraph accessor.
    val buffer = nested.asFlatBuffers(true);
    assertNotNull(buffer);
    val flat = FlatGraph.getRootAsFlatGraph(buffer);

    assertEquals(15, flat.variablesLength());

    // Earlier versions also checked node names ("phi/Assign", "alpha/Assign") for nodes 0 and 1.
    // Each of the first two nodes must have exactly two paired inputs.
    for (int nodeIdx = 0; nodeIdx < 2; nodeIdx++) {
        assertEquals(2, flat.nodes(nodeIdx).inputPairedLength());
    }

    // Export to ../../../libnd4j/tests_cpu/resources/nested_while.fb is currently disabled.
}
 
Example #5
Source File: TensorFlowImportTest.java    From nd4j with Apache License 2.0 6 votes vote down vote up
@Test
@Ignore
public void importGraph1() throws Exception {
    SameDiff graph = TFGraphMapper.getInstance().importGraph(new ClassPathResource("tf_graphs/max_add_2.pb.txt").getInputStream());
    assertNotNull(graph);

    // The imported graph should contain exactly the two constants "zeros" and "ones".
    val variables = graph.variableMap();
    assertEquals(2, variables.size());

    SDVariable zeros = variables.get("zeros");
    SDVariable ones = variables.get("ones");
    assertNotNull(zeros);
    assertNotNull(ones);

    assertNotNull(zeros.getArr());
    assertNotNull(ones.getArr());

    // Element sums: all-zeros -> 0, the "ones" constant -> 12.
    assertEquals(0.0, zeros.getArr().sumNumber().doubleValue(), 1e-5);
    assertEquals(12.0, ones.getArr().sumNumber().doubleValue(), 1e-5);
}
 
Example #6
Source File: TensorFlowImportTest.java    From deeplearning4j with Apache License 2.0 6 votes vote down vote up
@Test
@Ignore
public void testCrash_119_transpose() throws Exception {
    Nd4j.create(1);

    val sd = TFGraphMapper.importGraph(new ClassPathResource("tf_graphs/examples/transpose/frozen_model.pb").getInputStream());
    assertNotNull(sd);

    // 3x3 values in row-major order; transposedValues is matrixValues transposed.
    val matrixValues = new double[]{0.98114507, 0.96400015, 0.58669623,
                                    0.60073098, 0.75425418, 0.44258752,
                                    0.76373084, 0.96593234, 0.34067846};
    val transposedValues = new double[]{0.98114507, 0.60073098, 0.76373084,
                                        0.96400015, 0.75425418, 0.96593234,
                                        0.58669623, 0.44258752, 0.34067846};
    val matrix = Nd4j.create(matrixValues, new int[] {3, 3});
    val transposed = Nd4j.create(transposedValues, new int[] {3, 3});

    sd.associateArrayWithVariable(matrix, sd.getVariable("input"));
    sd.associateArrayWithVariable(transposed, sd.getVariable("input_1"));

    sd.asFlatFile(new File("../../../libnd4j/tests_cpu/resources/transpose.fb"));
}
 
Example #7
Source File: TensorFlowImportTest.java    From deeplearning4j with Apache License 2.0 6 votes vote down vote up
@Test
public void testWhileDualMapping1() throws Exception {
    Nd4j.create(1);
    val loopGraph = TFGraphMapper.importGraph(new ClassPathResource("tf_graphs/examples/simplewhile_1/frozen_model.pb").getInputStream());
    assertNotNull(loopGraph);

    // input_0: 2x2 filled with -4; input_1: scalar 1.
    val matrixIn = Nd4j.create(2, 2).assign(-4.0);
    val scalarIn = Nd4j.scalar(1.0);
    loopGraph.associateArrayWithVariable(matrixIn, loopGraph.getVariable("input_0"));
    loopGraph.associateArrayWithVariable(scalarIn, loopGraph.getVariable("input_1"));

    // Execute every output, then pick the first declared graph output.
    val allOutputs = loopGraph.outputAll(null);
    INDArray actual = allOutputs.get(loopGraph.outputs().get(0));

    val expected = Nd4j.create(2, 2).assign(-1);
    assertNotNull(actual);
    assertEquals(expected, actual);
}
 
Example #8
Source File: TensorFlowImportTest.java    From nd4j with Apache License 2.0 6 votes vote down vote up
@Test
public void testLenet() throws Exception {
    /*
     * Produced with:
     * python  ~/anaconda2/lib/python2.7/site-packages/tensorflow/python/tools/freeze_graph.py  --input_graph=graph2.pb.txt  --input_checkpoint=test3.ckpt  --output_graph=graph_frozen2.pb  --output_node_name=output/BiasAdd --input_binary=False
     */

    Nd4j.create(1);
    final String modelPath = "tf_graphs/lenet_cnn.pb";

    // Parse the raw protobuf only to print its node names; the stream is closed afterwards.
    try (java.io.InputStream rawIn = new ClassPathResource(modelPath).getInputStream()) {
        val rawGraph = GraphDef.parseFrom(rawIn);
        val nodeNames = rawGraph.getNodeList().stream().map(node -> node.getName()).collect(Collectors.toList());
        System.out.println(nodeNames);
    }

    // Import through the mapper using a second, separately closed stream.
    try (java.io.InputStream modelIn = new ClassPathResource(modelPath).getInputStream()) {
        val tg = TFGraphMapper.getInstance().importGraph(modelIn);

        val convNode = tg.getVariable("conv2d/kernel");
        assertNotNull(convNode.getArr());
        val shape = convNode.getShape();
        System.out.println(Arrays.toString(shape));

        // Kernel is kept in TF's conv2d filter layout [kH, kW, inChannels, outChannels]; may change later.
        assertArrayEquals(new int[]{5, 5, 1, 32}, shape);
        System.out.println(convNode);
    }
}
 
Example #9
Source File: TensorFlowImportTest.java    From deeplearning4j with Apache License 2.0 6 votes vote down vote up
@Test
public void testWhileDualMapping2() throws Exception {
    Nd4j.create(1);
    val modelFile = new ClassPathResource("tf_graphs/examples/simplewhile_1/frozen_model.pb");
    val sd = TFGraphMapper.importGraph(modelFile.getInputStream());
    assertNotNull(sd);

    // input_0: 2x2 filled with -9; input_1: scalar 1.
    val matrixIn = Nd4j.create(2, 2).assign(-9.0);
    val scalarIn = Nd4j.scalar(1.0);
    sd.associateArrayWithVariable(matrixIn, sd.getVariable("input_0"));
    sd.associateArrayWithVariable(scalarIn, sd.getVariable("input_1"));

    val results = sd.outputAll(null);
    val array = results.get(sd.outputs().get(0));
    val expected = Nd4j.create(2, 2).assign(-3);
    assertNotNull(array);
    assertEquals(expected, array);
}
 
Example #10
Source File: TensorFlowImportTest.java    From deeplearning4j with Apache License 2.0 6 votes vote down vote up
@Test
@Ignore
public void testCrash_119_matrix_diag() throws Exception {
    Nd4j.create(1);

    val sd = TFGraphMapper.importGraph(new ClassPathResource("tf_graphs/examples/partition_stitch_misc/frozen_model.pb").getInputStream());
    assertNotNull(sd);

    // input_0: rank-3 array of 1s; input_1: rank-4 of 2s; input_2: rank-4 of 3s.
    sd.associateArrayWithVariable(Nd4j.create(2, 5, 4).assign(1.0), sd.getVariable("input_0"));
    sd.associateArrayWithVariable(Nd4j.create(2, 3, 5, 4).assign(2.0), sd.getVariable("input_1"));
    sd.associateArrayWithVariable(Nd4j.create(3, 1, 5, 4).assign(3.0), sd.getVariable("input_2"));

    val fbPath = "../../../libnd4j/tests_cpu/resources/partition_stitch_misc.fb";
    sd.asFlatFile(new File(fbPath));
}
 
Example #11
Source File: TensorFlowImportTest.java    From nd4j with Apache License 2.0 6 votes vote down vote up
@Test
@Ignore
public void testIntermediateTensorArraySimple1() throws Exception {
    Nd4j.create(1);
    val tg = TFGraphMapper.getInstance().importGraph(new ClassPathResource("tf_graphs/tensor_array.pb.txt").getInputStream());
    // Assert the import succeeded BEFORE using the graph; asserting afterwards could only ever
    // be reached with a non-null tg (a null would already have thrown an NPE below).
    assertNotNull(tg);
    tg.updateVariable("input_matrix", Nd4j.ones(3, 2));

    val fb = tg.asFlatBuffers();
    assertNotNull(fb);

    val graph = FlatGraph.getRootAsFlatGraph(fb);
    assertEquals(36, graph.variablesLength());

    assertTrue(graph.nodesLength() > 1);
    // Earlier, now-disabled checks: node 0 named "strided_slice", node 1 named "TensorArray",
    // node 0 having 4 paired inputs, and an export to .../tests_cpu/resources/tensor_array.fb.
}
 
Example #12
Source File: TensorFlowImportTest.java    From deeplearning4j with Apache License 2.0 6 votes vote down vote up
@Test
public void testCondMapping2() throws Exception {
    Nd4j.create(1);
    val ifGraph = TFGraphMapper.importGraph(new ClassPathResource("tf_graphs/examples/simpleif_0/frozen_model.pb").getInputStream());
    assertNotNull(ifGraph);

    val negativeOnes = Nd4j.create(2, 2).assign(-1);
    ifGraph.associateArrayWithVariable(negativeOnes, ifGraph.getVariable("input_0"));

    // Write the FlatBuffers form into libnd4j's CPU test resources before executing.
    val flatTarget = new File("../../../libnd4j/tests_cpu/resources/simpleif_0.fb");
    ifGraph.asFlatFile(flatTarget);

    // A 2x2 input of -1 is expected to yield a 2x2 output of 1.
    val outputs = ifGraph.outputAll(null);
    val actual = outputs.get(ifGraph.outputs().get(0));
    val expected = Nd4j.create(2, 2).assign(1);
    assertNotNull(actual);
    assertEquals(expected, actual);
}
 
Example #13
Source File: GraphInferenceGrpcClientTest.java    From deeplearning4j with Apache License 2.0 6 votes vote down vote up
@Test
public void testSimpleGraph_1() throws Exception {
    val expectedValues = new double[] {-0.95938617, -1.20301781, 1.22260064, 0.50172403, 0.59972949, 0.78568028, 0.31609724, 1.51674747, 0.68013491, -0.05227458, 0.25903158, 1.13243439};
    val exp = Nd4j.create(expectedValues, new long[]{3, 1, 4});

    // Client for an inference server assumed to be listening on 127.0.0.1:40123.
    val client = new GraphInferenceGrpcClient("127.0.0.1", 40123);
    val graphId = RandomUtils.nextLong(0, Long.MAX_VALUE);

    // Registering the graph is optional: it might instead be embedded into the server's Docker image.
    val modelStream = new ClassPathResource("tf_graphs/examples/expand_dim/frozen_model.pb").getInputStream();
    val tg = TFGraphMapper.importGraph(modelStream);
    assertNotNull(tg);
    val config = ExecutorConfiguration.builder().outputMode(OutputMode.IMPLICIT).build();
    client.registerGraph(graphId, tg, config);

    // A single 3x4 input bound to "input_0".
    val inputValues = new double[] {0.09753360, 0.76124972, 0.24693797, 0.13813169, 0.33144656, 0.08299957, 0.67197708, 0.80659380, 0.98274191, 0.63566073, 0.21592326, 0.54902743};
    val input0 = Nd4j.create(inputValues, new int[] {3, 4});
    val operands = new Operands().addArgument("input_0", input0);

    // Send the request and compare the "output" result.
    val result = client.output(graphId, operands);
    assertEquals(exp, result.getById("output"));
}
 
Example #14
Source File: TensorFlowImportTest.java    From nd4j with Apache License 2.0 6 votes vote down vote up
@Test
public void testMixedWhileCond1() throws Exception {
    Nd4j.create(1);
    val nestedGraph = TFGraphMapper.getInstance().importGraph(new ClassPathResource("tf_graphs/examples/simplewhile_nested/frozen_model.pb").getInputStream());
    assertNotNull(nestedGraph);

    // input_0: 2x2 of ones; input_1: 3x3 of twos.
    val smallOnes = Nd4j.create(2, 2).assign(1.0);
    val largeTwos = Nd4j.create(3, 3).assign(2.0);
    nestedGraph.associateArrayWithVariable(smallOnes, nestedGraph.getVariable("input_0"));
    nestedGraph.associateArrayWithVariable(largeTwos, nestedGraph.getVariable("input_1"));

    // FlatBuffers export to .../simplewhile_nested.fb is currently disabled.

    val result = nestedGraph.execAndEndResult();
    val expected = Nd4j.create(2, 2).assign(15.0);
    assertNotNull(result);
    assertEquals(expected, result);
}
 
Example #15
Source File: TensorFlowImportTest.java    From nd4j with Apache License 2.0 6 votes vote down vote up
@Test
public void testWhileDualMapping2() throws Exception {
    Nd4j.create(1);
    val loop = TFGraphMapper.getInstance().importGraph(new ClassPathResource("tf_graphs/examples/simplewhile_1/frozen_model.pb").getInputStream());
    assertNotNull(loop);

    // input_0: 2x2 filled with -9; input_1: rank-0 scalar 1.
    val matrix = Nd4j.create(2, 2).assign(-9.0);
    val scalar = Nd4j.trueScalar(1.0);
    loop.associateArrayWithVariable(matrix, loop.getVariable("input_0"));
    loop.associateArrayWithVariable(scalar, loop.getVariable("input_1"));

    val result = loop.execAndEndResult();
    val expected = Nd4j.create(2, 2).assign(-3);
    assertNotNull(result);
    assertEquals(expected, result);
}
 
Example #16
Source File: Create.java    From deeplearning4j with Apache License 2.0 6 votes vote down vote up
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
    final boolean hasDtype = attributesForNode.containsKey("dtype");
    final boolean hasInit = attributesForNode.containsKey("init");

    // Map TF's "dtype" attribute onto the output data type when the node declares one.
    if (hasDtype) {
        this.outputType = TFGraphMapper.convertType(attributesForNode.get("dtype").getType());
    }

    // Copy TF's boolean "init" attribute into the initialize flag when present.
    if (hasInit) {
        this.initialize = attributesForNode.get("init").getB();
    }

    // TF has no array-ordering concept, so plain C order is always used.
    this.order = 'c';
    addArgs();
}
 
Example #17
Source File: TensorFlowImportTest.java    From nd4j with Apache License 2.0 6 votes vote down vote up
@Test
public void testWhileMapping2() throws Exception {
    Nd4j.create(1);
    val simpleWhile = TFGraphMapper.getInstance().importGraph(new ClassPathResource("tf_graphs/examples/simplewhile_0/frozen_model.pb").getInputStream());
    assertNotNull(simpleWhile);

    // Bind a rank-0 scalar 4.0 to the graph input "input_1".
    simpleWhile.associateArrayWithVariable(Nd4j.trueScalar(4.0), simpleWhile.getVariable("input_1"));

    val flatOut = new File("../../../libnd4j/tests_cpu/resources/simplewhile_0_4.fb");
    simpleWhile.asFlatFile(flatOut);

    // Execution check is disabled; when enabled it expected a 2x2 array of 2s:
    //   val array = simpleWhile.execAndEndResult();
    //   assertNotNull(array);
    //   assertEquals(Nd4j.create(2, 2).assign(2), array);
}
 
Example #18
Source File: TensorFlowImportTest.java    From nd4j with Apache License 2.0 6 votes vote down vote up
@Test
public void testWhileMapping1() throws Exception {
    Nd4j.create(1);
    val modelStream = new ClassPathResource("tf_graphs/examples/simplewhile_0/frozen_model.pb").getInputStream();
    val whileGraph = TFGraphMapper.getInstance().importGraph(modelStream);
    assertNotNull(whileGraph);

    val ones = Nd4j.create(2, 2).assign(1);
    whileGraph.associateArrayWithVariable(ones, whileGraph.getVariable("input_0"));

    // FlatBuffers export to .../simplewhile_0_3.fb is currently disabled.

    // A 2x2 input of ones is expected to come back as a 2x2 array of ones.
    val outcome = whileGraph.execAndEndResult();
    val expected = Nd4j.create(2, 2).assign(1);
    assertNotNull(outcome);
    assertEquals(expected, outcome);
}
 
Example #19
Source File: TensorFlowImportTest.java    From nd4j with Apache License 2.0 6 votes vote down vote up
@Test
public void testCondMapping2() throws Exception {
    Nd4j.create(1);
    val conditional = TFGraphMapper.getInstance().importGraph(new ClassPathResource("tf_graphs/examples/simpleif_0/frozen_model.pb").getInputStream());
    assertNotNull(conditional);

    val negativeOnes = Nd4j.create(2, 2).assign(-1);
    conditional.associateArrayWithVariable(negativeOnes, conditional.getVariable("input_0"));
    // FlatBuffers export to .../simpleif_0.fb is currently disabled.

    // A 2x2 input of -1 is expected to produce a 2x2 output of 1.
    val produced = conditional.execAndEndResult();
    val expected = Nd4j.create(2, 2).assign(1);
    assertNotNull(produced);
    assertEquals(expected, produced);
}
 
Example #20
Source File: TensorFlowImportTest.java    From deeplearning4j with Apache License 2.0 6 votes vote down vote up
@Test
public void testLenet() throws Exception {
    /*
     * Produced with:
     * python  ~/anaconda2/lib/python2.7/site-packages/tensorflow/python/tools/freeze_graph.py  --input_graph=graph2.pb.txt  --input_checkpoint=test3.ckpt  --output_graph=graph_frozen2.pb  --output_node_name=output/BiasAdd --input_binary=False
     */

    Nd4j.create(1);
    final String modelPath = "tf_graphs/lenet_cnn.pb";

    // Parse the raw protobuf only to print its node names; the stream is closed afterwards.
    try (java.io.InputStream rawIn = new ClassPathResource(modelPath).getInputStream()) {
        val rawGraph = GraphDef.parseFrom(rawIn);
        val nodeNames = rawGraph.getNodeList().stream().map(node -> node.getName()).collect(Collectors.toList());
        System.out.println(nodeNames);
    }

    // Import through the mapper using a second, separately closed stream.
    try (java.io.InputStream modelIn = new ClassPathResource(modelPath).getInputStream()) {
        val tg = TFGraphMapper.importGraph(modelIn);

        val convNode = tg.getVariable("conv2d/kernel");
        assertNotNull(convNode.getArr());
        val shape = convNode.getShape();
        System.out.println(Arrays.toString(shape));

        // Kernel is kept in TF's conv2d filter layout [kH, kW, inChannels, outChannels]; may change later.
        assertArrayEquals(new long[]{5, 5, 1, 32}, shape);
        System.out.println(convNode);
    }
}
 
Example #21
Source File: ScatterMin.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
    TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);

    // Forward TF's optional "use_locking" attribute as this op's boolean argument; absent means false.
    boolean useLocking = nodeDef.containsAttr("use_locking") && nodeDef.getAttrOrThrow("use_locking").getB();
    bArguments.add(useLocking);
}
 
Example #22
Source File: Range.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
    super.initFromTensorFlow(nodeDef, initWith, attributesForNode, graph);

    // When the node carries a "Tidx" type attribute, it overrides the current dataType.
    final boolean hasIndexType = attributesForNode.containsKey("Tidx");
    if (hasIndexType) {
        this.dataType = TFGraphMapper.convertType(attributesForNode.get("Tidx").getType());
    }

    // The (possibly updated) data type is always forwarded as a data-type argument.
    addDArgument(this.dataType);
}
 
Example #23
Source File: SameDiffModelLoader.java    From konduit-serving with Apache License 2.0 5 votes vote down vote up
/**
 * Loads {@code pathToModel} as a SameDiff graph, choosing the loader by file type.
 *
 * <p>A TensorFlow file is imported via {@code TFGraphMapper}. A SameDiff zip file is loaded via
 * {@code SameDiff.load}. Anything else is read as a SameDiff flat (FlatBuffers) file.
 *
 * @return the loaded graph
 * @throws Exception if detection or loading fails
 */
@Override
public SameDiff loadModel() throws Exception {
    if (ModelGuesser.isTensorflowFile(pathToModel)) {
        log.debug("Loading tensorflow model from {}", pathToModel.getAbsolutePath());
        return TFGraphMapper.importGraph(pathToModel);
    }
    if (ModelGuesser.isSameDiffZipFile(pathToModel)) {
        log.debug("Loading samediff zip model from {}", pathToModel.getAbsolutePath());
        return SameDiff.load(pathToModel, true);
    }

    // Fallback: treat everything else as a SameDiff flat file.
    log.debug("Loading samediff model from {}", pathToModel.getAbsolutePath());
    return SameDiff.fromFlatFile(pathToModel);
}
 
Example #24
Source File: TensorFlowImportTest.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
@Test
//@Ignore
public void testCrash_119_reduce_dim_false() throws Exception {
    Nd4j.create(1);

    val reduceGraph = TFGraphMapper.importGraph(new ClassPathResource("tf_graphs/reduce_dim.pb.txt").getInputStream());
    assertNotNull(reduceGraph);

    // Export with IMPLICIT output mode; the trailing boolean is passed through unchanged
    // (its meaning is defined by asFlatFile, not visible here).
    val target = new File("../../../libnd4j/tests_cpu/resources/reduce_dim_false.fb");
    val config = ExecutorConfiguration.builder().outputMode(OutputMode.IMPLICIT).build();
    reduceGraph.asFlatFile(target, config, true);
}
 
Example #25
Source File: BitCast.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
    TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);

    // The "type" attribute holds the data type to reinterpret into; it is required, so fail
    // with a descriptive message rather than an NPE when it is absent.
    val t = nodeDef.getAttrOrDefault("type", null);
    if (t == null) {
        throw new IllegalStateException("BitCast node \"" + nodeDef.getName()
                + "\" is missing required attribute \"type\"");
    }
    val type = ArrayOptionsHelper.convertToDataType(t.getType());
    // The target type is passed to the op as an integer argument and also recorded in dtype.
    addIArgument(type.toInt());

    dtype = type;
}
 
Example #26
Source File: TensorFlowImportTest.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
@Test
@Ignore
public void testRandomGraph3() throws Exception {
    val modelPath = "tf_graphs/examples/assert_equal/3,4_3,4_float32/frozen_model.pb";
    val assertGraph = TFGraphMapper.importGraph(new ClassPathResource(modelPath).getInputStream());
    assertNotNull(assertGraph);

    // Log a printable view of the graph, then write it out for libnd4j's CPU tests.
    log.info("{}", assertGraph.asFlatPrint());
    assertGraph.asFlatFile(new File("../../../libnd4j/tests_cpu/resources/assertsomething.fb"));
}
 
Example #27
Source File: ScatterDiv.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
    TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);

    // Forward TF's optional "use_locking" attribute as this op's boolean argument; absent means false.
    boolean useLocking = nodeDef.containsAttr("use_locking") && nodeDef.getAttrOrThrow("use_locking").getB();
    bArguments.add(useLocking);
}
 
Example #28
Source File: Relu6.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
    // TF's cutoff is always 0.0. The scalar must have the same type as the input, because scalar ops
    // require matching types, so build it from the node's "T" attribute when present.
    if (attributesForNode.containsKey("T")) {
        DataType dt = TFGraphMapper.convertType(attributesForNode.get("T").getType());
        scalarValue = Nd4j.scalar(dt, 0.0);
    }
}
 
Example #29
Source File: TensorFlowImportTest.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
@Test
public void testBoolImport_1() throws Exception {
    Nd4j.create(1);
    final String modelPath = "tf_graphs/examples/reduce_any/rank0/frozen_model.pb";

    // Repeated import + execution. Each iteration closes its own model stream, so the
    // 1000 iterations do not leak 1000 open classpath streams.
    for (int iteration = 0; iteration < 1000; iteration++) {
        try (java.io.InputStream modelIn = new ClassPathResource(modelPath).getInputStream()) {
            val tg = TFGraphMapper.importGraph(modelIn);

            Map<String, INDArray> result = tg.output(Collections.emptyMap(), tg.outputs());

            assertNotNull(result);
            assertTrue(result.size() > 0);
        }
    }
}
 
Example #30
Source File: TensorFlowImportTest.java    From nd4j with Apache License 2.0 5 votes vote down vote up
@Test
@Ignore
public void testCrash_119_reduce_dim_true() throws Exception {
    Nd4j.create(1);

    val resource = new ClassPathResource("tf_graphs/reduce_dim_true.pb.txt");
    val reduceGraph = TFGraphMapper.getInstance().importGraph(resource.getInputStream());
    assertNotNull(reduceGraph);

    // Export using IMPLICIT output mode into libnd4j's CPU test resources.
    val target = new File("../../../libnd4j/tests_cpu/resources/reduce_dim_true.fb");
    val implicitOutputs = ExecutorConfiguration.builder().outputMode(OutputMode.IMPLICIT).build();
    reduceGraph.asFlatFile(target, implicitOutputs);
}