Java Code Examples for org.tensorflow.framework.GraphDef#getNodeCount()

The following examples show how to use org.tensorflow.framework.GraphDef#getNodeCount(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may check out the related API usage on the sidebar.
Example 1
Source File: TensorArrayV3.java    From nd4j with Apache License 2.0 6 votes vote down vote up
/**
 * Initialises this op from a TensorFlow node definition.
 * <p>
 * Takes the name of the node's last input, searches the graph for a node carrying that name and
 * hands it to {@code TFGraphMapper.getInstance().getNDArrayFromTensor} together with the key
 * {@code "value"}. When an array comes back, its first element is registered as an integer
 * argument.
 *
 * @param nodeDef           the TensorFlow node being imported
 * @param initWith          not used by this method
 * @param attributesForNode not used by this method
 * @param graph             the graph that is searched for the input node
 */
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
    final String lastInputName = nodeDef.getInput(nodeDef.getInputCount() - 1);

    // The scan deliberately runs to the end: if several nodes share the name, the last one is kept.
    NodeDef iddNode = null;
    final int nodeCount = graph.getNodeCount();
    for (int pos = 0; pos < nodeCount; pos++) {
        final NodeDef candidate = graph.getNode(pos);
        if (candidate.getName().equals(lastInputName)) {
            iddNode = candidate;
        }
    }

    // NOTE(review): iddNode is still null here when no node matched — confirm the mapper accepts null.
    val arr = TFGraphMapper.getInstance().getNDArrayFromTensor("value", iddNode, graph);
    if (arr == null) {
        return;
    }

    addIArgument(arr.getInt(0));
}
 
Example 2
Source File: InTopK.java    From deeplearning4j with Apache License 2.0 6 votes vote down vote up
/**
 * Initialises this op from a TensorFlow node definition.
 * <p>
 * The {@code k} value is read from the graph node named {@code <this node's name>/k}; the first
 * element of that node's array is stored in {@code this.k} and registered as an integer argument.
 *
 * @param nodeDef           the TensorFlow node being imported; its name is used to build the lookup key
 * @param initWith          not used by this method
 * @param attributesForNode not used by this method
 * @param graph             the graph that is searched for the {@code k} node
 */
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
    final String opName = nodeDef.getName();
    final String kNodeName = opName + "/k";

    // Stop at the first node whose name matches.
    NodeDef kNode = null;
    int pos = 0;
    while (kNode == null && pos < graph.getNodeCount()) {
        final NodeDef candidate = graph.getNode(pos);
        if (candidate.getName().equals(kNodeName)) {
            kNode = candidate;
        }
        pos++;
    }
    Preconditions.checkState(kNode != null, "Could not find 'k' parameter node for op: %s", opName);

    final INDArray kValues = TFGraphMapper.getNDArrayFromTensor(kNode);
    this.k = kValues.getInt(0);
    addIArgument(k);
}
 
Example 3
Source File: TensorArray.java    From deeplearning4j with Apache License 2.0 6 votes vote down vote up
/**
 * Initialises this op from a TensorFlow node definition.
 * <p>
 * Looks up the graph node named by this node's last input, registers the first element of its
 * array (when one is available) as an integer argument, and then sets
 * {@code tensorArrayDataType} from the {@code "dtype"} attribute.
 *
 * @param nodeDef           the TensorFlow node being imported
 * @param initWith          not used by this method
 * @param attributesForNode attribute map; the {@code "dtype"} entry is read
 * @param graph             the graph that is searched for the input node
 */
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
    final int lastInputIndex = nodeDef.getInputCount() - 1;
    final String lastInputName = nodeDef.getInput(lastInputIndex);

    // Full scan without early exit: the last node with a matching name is the one kept.
    NodeDef iddNode = null;
    for (int pos = 0, total = graph.getNodeCount(); pos < total; pos++) {
        final NodeDef current = graph.getNode(pos);
        if (current.getName().equals(lastInputName)) {
            iddNode = current;
        }
    }

    // NOTE(review): iddNode may still be null at this point — confirm the mapper tolerates that.
    val arr = TFGraphMapper.getNDArrayFromTensor(iddNode);
    if (arr != null) {
        addIArgument(arr.getInt(0));
    }

    final AttrValue dtypeAttr = attributesForNode.get("dtype");
    this.tensorArrayDataType = TFGraphMapper.convertType(dtypeAttr.getType());
}
 
Example 4
Source File: TopK.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
/**
 * Initialises this op from a TensorFlow node definition.
 * <p>
 * Reads the boolean {@code "sorted"} attribute and, when a graph node named
 * {@code <this node's name>/k} exists, the first element of that node's array as {@code k}.
 * The integer arguments become {@code (sorted, k)} when the {@code k} node was found and
 * {@code (sorted)} alone otherwise.
 *
 * @param nodeDef           the TensorFlow node being imported
 * @param initWith          not used by this method
 * @param attributesForNode not used by this method
 * @param graph             the graph that is searched for the {@code k} node
 */
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {

    String thisName = nodeDef.getName();

    // NOTE(review): 'k' is located by the naming convention "<op name>/k" rather than through the
    // node's declared inputs — confirm this holds for every graph that is imported.
    String inputName = thisName + "/k";
    NodeDef kNode = null;
    for (int i = 0; i < graph.getNodeCount(); i++) {
        NodeDef candidate = graph.getNode(i);
        if (candidate.getName().equals(inputName)) {
            kNode = candidate;
            break;
        }
    }

    this.sorted = nodeDef.getAttrOrThrow("sorted").getB();

    if (kNode != null) {
        // A missing 'k' node is a supported case (handled below), so no precondition check is needed here.
        INDArray arr = TFGraphMapper.getNDArrayFromTensor(kNode);
        this.k = arr.getInt(0);

        addIArgument(ArrayUtil.fromBoolean(sorted), k);
    } else {
        addIArgument(ArrayUtil.fromBoolean(sorted));
    }
}
 
Example 5
Source File: Transpose.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
/**
 * Initialises this op from a TensorFlow node definition.
 * <p>
 * After delegating to the superclass, the node's second input (if there is one) is looked up in
 * the graph and its array, when available, becomes {@code permuteDims}. Once the op's argument
 * has a shape, is not a placeholder and has an array, that array (plus the permutation array,
 * when present) is added as input. If no permutation was read, the dimensions default to the
 * reversed range {@code 0..rank}.
 *
 * @param nodeDef           the TensorFlow node being imported
 * @param initWith          forwarded to the superclass implementation
 * @param attributesForNode forwarded to the superclass implementation
 * @param graph             the graph that is searched for the permutation node
 * @throws ND4JIllegalStateException if fewer permute dimensions are given than the argument's shape has entries
 */
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
    super.initFromTensorFlow(nodeDef, initWith, attributesForNode, graph);

    // Without a second input there are no permute dimensions to read.
    if (nodeDef.getInputCount() < 2) {
        return;
    }

    // No early exit: the last node whose name matches the second input is used.
    final String permuteInputName = nodeDef.getInput(1);
    NodeDef permuteDimsNode = null;
    for (int pos = 0; pos < graph.getNodeCount(); pos++) {
        final NodeDef candidate = graph.getNode(pos);
        if (candidate.getName().equals(permuteInputName)) {
            permuteDimsNode = candidate;
        }
    }

    final INDArray permuteArrayOp = TFGraphMapper.getNDArrayFromTensor(permuteDimsNode);
    final boolean hasPermuteArray = permuteArrayOp != null;
    if (hasPermuteArray) {
        this.permuteDims = permuteArrayOp.data().asInt();
    }

    // The remaining setup is only possible once the argument has been properly mapped.
    final boolean argNotReady = arg().getShape() == null
            || arg().getVariableType() == VariableType.PLACEHOLDER
            || arg().getArr() == null;
    if (argNotReady) {
        return;
    }

    final INDArray arr = sameDiff.getArrForVarName(arg().name());

    if (hasPermuteArray) {
        addInputArgument(arr, permuteArrayOp);
    } else {
        addInputArgument(arr);
    }

    // Fall back to reversing all dimensions when no explicit permutation was supplied.
    if (arr != null && permuteDims == null) {
        this.permuteDims = ArrayUtil.reverseCopy(ArrayUtil.range(0, arr.rank()));
    }

    if (permuteDims != null && permuteDims.length < arg().getShape().length) {
        throw new ND4JIllegalStateException("Illegal permute found. Not all dimensions specified");
    }
}
 
Example 6
Source File: Transpose.java    From nd4j with Apache License 2.0 4 votes vote down vote up
/**
 * Initialises this op from a TensorFlow node definition.
 * <p>
 * After delegating to the superclass, the node's second input (if there is one) is looked up in
 * the graph; when an array can be obtained for it, its values become {@code permuteDims} and are
 * each registered as an integer argument. Once the op's argument has a shape, its array is
 * fetched (or created through the variable's weight-init scheme and stored when absent) and added
 * as input. If no permutation was read, the dimensions default to the reversed range
 * {@code 0..rank}.
 *
 * @param nodeDef           the TensorFlow node being imported
 * @param initWith          forwarded to the superclass implementation
 * @param attributesForNode forwarded to the superclass implementation
 * @param graph             the graph that is searched for the permutation node
 * @throws ND4JIllegalStateException if fewer permute dimensions are given than the argument's shape has entries
 */
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
    super.initFromTensorFlow(nodeDef, initWith, attributesForNode, graph);

    // Permute dimensions are only available when a second input is present.
    if (nodeDef.getInputCount() < 2) {
        return;
    }

    // No early exit: the last node whose name matches the second input is used.
    final String permuteInputName = nodeDef.getInput(1);
    NodeDef permuteDimsNode = null;
    for (int pos = 0; pos < graph.getNodeCount(); pos++) {
        final NodeDef candidate = graph.getNode(pos);
        if (candidate.getName().equals(permuteInputName)) {
            permuteDimsNode = candidate;
        }
    }

    val permuteArrayOp = TFGraphMapper.getInstance().getNDArrayFromTensor("value", permuteDimsNode, graph);
    if (permuteArrayOp != null) {
        this.permuteDims = permuteArrayOp.data().asInt();
        for (int dim : permuteDims) {
            addIArgument(dim);
        }
    }

    // The remaining setup is only possible once the argument has a shape.
    if (arg().getShape() == null) {
        return;
    }

    INDArray arr = sameDiff.getArrForVarName(arg().getVarName());
    if (arr == null) {
        // No array registered yet: build one from the variable's weight-init scheme and store it.
        val arrVar = sameDiff.getVariable(arg().getVarName());
        arr = arrVar.getWeightInitScheme().create(arrVar.getShape());
        sameDiff.putArrayForVarName(arg().getVarName(), arr);
    }

    addInputArgument(arr);

    // Fall back to reversing all dimensions when no explicit permutation was supplied.
    if (arr != null && permuteDims == null) {
        this.permuteDims = ArrayUtil.reverseCopy(ArrayUtil.range(0, arr.rank()));
    }

    if (permuteDims != null && permuteDims.length < arg().getShape().length) {
        throw new ND4JIllegalStateException("Illegal permute found. Not all dimensions specified");
    }
}
 
Example 7
Source File: StridedSlice.java    From nd4j with Apache License 2.0 4 votes vote down vote up
/**
 * Initialises this op from a TensorFlow node definition.
 * <p>
 * Integer arguments are registered in this order: {@code begin_mask}, {@code ellipsis_mask},
 * {@code end_mask}, {@code new_axis_mask}, {@code shrink_axis_mask}, and then, only when arrays
 * could be obtained for all three of inputs 1, 2 and 3, every element of the begin array, the
 * end array and the strides array.
 *
 * @param nodeDef           the TensorFlow node being imported; inputs 1-3 and the five mask attributes are read
 * @param initWith          not used by this method
 * @param attributesForNode not used by this method
 * @param graph             the graph that is searched for the begin / end / strides nodes
 */
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
    final String beginName = nodeDef.getInput(1);
    final String endName = nodeDef.getInput(2);
    final String stridesName = nodeDef.getInput(3);

    NodeDef beginNode = null;
    NodeDef endNode = null;
    NodeDef stridesNode = null;

    // Single pass over the graph; the checks are independent, so one node may satisfy several of
    // them, and the last match for each name is the one kept.
    for (int pos = 0; pos < graph.getNodeCount(); pos++) {
        final NodeDef candidate = graph.getNode(pos);
        final String candidateName = candidate.getName();
        if (candidateName.equals(beginName)) {
            beginNode = candidate;
        }
        if (candidateName.equals(endName)) {
            endNode = candidate;
        }
        if (candidateName.equals(stridesName)) {
            stridesNode = candidate;
        }
    }

    // All five mask attributes are fetched before any argument is registered.
    final long beginMask = nodeDef.getAttrOrThrow("begin_mask").getI();
    final long ellipsisMask = nodeDef.getAttrOrThrow("ellipsis_mask").getI();
    final long endMask = nodeDef.getAttrOrThrow("end_mask").getI();
    final long newAxisMask = nodeDef.getAttrOrThrow("new_axis_mask").getI();
    final long shrinkAxisMask = nodeDef.getAttrOrThrow("shrink_axis_mask").getI();

    addIArgument((int) beginMask);
    addIArgument((int) ellipsisMask);
    addIArgument((int) endMask);
    addIArgument((int) newAxisMask);
    addIArgument((int) shrinkAxisMask);

    val beginArr = TFGraphMapper.getInstance().getNDArrayFromTensor("value", beginNode, graph);
    val endArr = TFGraphMapper.getInstance().getNDArrayFromTensor("value", endNode, graph);
    val stridesArr = TFGraphMapper.getInstance().getNDArrayFromTensor("value", stridesNode, graph);

    // Unless all three arrays are available, nothing beyond the masks is registered.
    if (beginArr == null || endArr == null || stridesArr == null) {
        return;
    }

    for (int e = 0; e < beginArr.length(); e++) {
        addIArgument(beginArr.getInt(e));
    }
    for (int e = 0; e < endArr.length(); e++) {
        addIArgument(endArr.getInt(e));
    }
    for (int e = 0; e < stridesArr.length(); e++) {
        addIArgument(stridesArr.getInt(e));
    }
}
 
Example 8
Source File: Slice.java    From nd4j with Apache License 2.0 4 votes vote down vote up
/**
 * Initialises this op from a TensorFlow node definition.
 * <p>
 * Only inputs 1 and 2 of the node are consulted; they are treated here as the begin and end
 * indices. When arrays could be obtained for both, every element of the begin array followed by
 * every element of the end array is registered as an integer argument; otherwise nothing is
 * registered.
 * <p>
 * NOTE(review): TensorFlow documents the third input of its Slice op as {@code size} rather than
 * an end index — confirm that the consumer of these arguments interprets them accordingly.
 *
 * @param nodeDef           the TensorFlow node being imported; inputs 1 and 2 are read
 * @param initWith          not used by this method
 * @param attributesForNode not used by this method
 * @param graph             the graph that is searched for the two input nodes
 */
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
    final String beginName = nodeDef.getInput(1);
    final String endName = nodeDef.getInput(2);

    NodeDef beginNode = null;
    NodeDef endNode = null;

    // Single pass over the graph; the two checks are independent and the last match for each
    // name is the one kept.
    for (int pos = 0; pos < graph.getNodeCount(); pos++) {
        final NodeDef candidate = graph.getNode(pos);
        final String candidateName = candidate.getName();
        if (candidateName.equals(beginName)) {
            beginNode = candidate;
        }
        if (candidateName.equals(endName)) {
            endNode = candidate;
        }
    }

    val beginArr = TFGraphMapper.getInstance().getNDArrayFromTensor("value", beginNode, graph);
    val endArr = TFGraphMapper.getInstance().getNDArrayFromTensor("value", endNode, graph);

    // Both arrays are required; with either one missing no arguments are added.
    if (beginArr == null || endArr == null) {
        return;
    }

    for (int e = 0; e < beginArr.length(); e++) {
        addIArgument(beginArr.getInt(e));
    }
    for (int e = 0; e < endArr.length(); e++) {
        addIArgument(endArr.getInt(e));
    }
}
 
Example 9
Source File: TensorFlowImportValidator.java    From deeplearning4j with Apache License 2.0 4 votes vote down vote up
/**
 * Parses a serialized {@code GraphDef} from the given stream and reports which of its ops have
 * an import mapping.
 * <p>
 * Nodes that {@code TFGraphMapper} classifies as variable or placeholder nodes are skipped. For
 * the remaining nodes the op names are collected and counted, then split into supported and
 * unsupported sets depending on whether
 * {@code DifferentialFunctionClassHolder.getOpWithTensorflowName} returns a non-null value.
 * <p>
 * The supplied stream is wrapped in a {@link BufferedInputStream} inside try-with-resources, so
 * it is closed by this method.
 *
 * @param path            identifier of the model; used in messages and recorded in the returned status
 * @param is              stream holding the serialized graph; closed by this method
 * @param exceptionOnRead if {@code true}, any failure is rethrown wrapped in an {@link IOException};
 *                        if {@code false}, the failure is logged as a warning and a status recording
 *                        {@code path} with zero op counts and empty sets is returned
 * @return the import status for this single model
 * @throws IOException if anything fails (including an empty graph) and {@code exceptionOnRead} is {@code true}
 */
public static TFImportStatus checkModelForImport(String path, InputStream is, boolean exceptionOnRead) throws IOException {

        try {
            int opCount = 0;
            Set<String> opNames = new HashSet<>();
            Map<String,Integer> opCounts = new HashMap<>();

            try(InputStream bis = new BufferedInputStream(is)) {
                GraphDef graphDef = GraphDef.parseFrom(bis);
                int nodeCount = graphDef.getNodeCount();

                if(nodeCount == 0){
                    throw new IllegalStateException("Error loading model for import - loaded graph def has no nodes (empty/corrupt file?): " + path);
                }

                // Iterate the graph directly; an intermediate copy of the node list is not needed.
                for (int i = 0; i < nodeCount; i++) {
                    NodeDef nd = graphDef.getNode(i);
                    if (TFGraphMapper.isVariableNode(nd) || TFGraphMapper.isPlaceHolder(nd))
                        continue;

                    String op = nd.getOp();
                    opNames.add(op);
                    // Single lookup: a missing key means this is the first occurrence of the op.
                    Integer soFar = opCounts.get(op);
                    opCounts.put(op, soFar == null ? 1 : soFar + 1);
                    opCount++;
                }
            }

            Set<String> importSupportedOpNames = new HashSet<>();
            Set<String> unsupportedOpNames = new HashSet<>();
            Map<String,Set<String>> unsupportedOpModel = new HashMap<>();

            for (String s : opNames) {
                if (DifferentialFunctionClassHolder.getInstance().getOpWithTensorflowName(s) != null) {
                    importSupportedOpNames.add(s);
                } else {
                    unsupportedOpNames.add(s);
                    // opNames is a Set and the map starts empty, so each name is seen exactly once
                    // and can never already be a key here; no containsKey guard is required.
                    Set<String> l = new HashSet<>();
                    l.add(path);
                    unsupportedOpModel.put(s, l);
                }
            }

            return new TFImportStatus(
                    Collections.singletonList(path),
                    unsupportedOpNames.size() > 0 ? Collections.singletonList(path) : Collections.<String>emptyList(),
                    Collections.<String>emptyList(),
                    opCount,
                    opNames.size(),
                    opNames,
                    opCounts,
                    importSupportedOpNames,
                    unsupportedOpNames,
                    unsupportedOpModel);
        } catch (Throwable t){
            // Deliberately broad: every failure, including the empty-graph check above, is either
            // surfaced as an IOException or recorded as a failed read for this path.
            if(exceptionOnRead) {
                throw new IOException("Error reading model from path " + path + " - not a TensorFlow frozen model in ProtoBuf format?", t);
            }
            log.warn("Failed to import model from: " + path + " - not a TensorFlow frozen model in ProtoBuf format?", t);
            return new TFImportStatus(
                    Collections.<String>emptyList(),
                    Collections.<String>emptyList(),
                    Collections.singletonList(path),
                    0,
                    0,
                    Collections.<String>emptySet(),
                    Collections.<String, Integer>emptyMap(),
                    Collections.<String>emptySet(),
                    Collections.<String>emptySet(),
                    Collections.<String, Set<String>>emptyMap());
        }
    }