Java Code Examples for org.tensorflow.framework.NodeDef#getInputCount()

The following examples show how to use org.tensorflow.framework.NodeDef#getInputCount(). You can vote up the examples you like or vote down the ones you don't, and follow the link above each example to reach the original project or source file. Related API usage is listed in the sidebar.
Example 1
Source File: Fill.java — from nd4j (Apache License 2.0), 6 votes
/**
 * Reads the fill value for this op from a TensorFlow node.
 *
 * <p>Only nodes with exactly two inputs are handled. The second input is resolved to its graph
 * node, its {@code "value"} tensor is converted to an array, and the single element of that
 * array is added as a floating-point argument. Nodes with any other input count are ignored.
 *
 * @throws ND4JIllegalStateException if the second input does not hold exactly one element
 */
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
    if (nodeDef.getInputCount() != 2) {
        return;
    }

    val mapper = TFGraphMapper.getInstance();
    val valueNode = mapper.getNodeWithNameFromGraph(graph, nodeDef.getInput(1));
    val fillValue = mapper.getNDArrayFromTensor("value", valueNode, graph);

    // The second input is required to be a scalar (single-element) array.
    if (fillValue.length() != 1) {
        throw new ND4JIllegalStateException("Second input to node " + nodeDef + " should be scalar!");
    }
    addTArgument(fillValue.getDouble(0));
}
 
Example 2
Source File: BaseAccumulation.java — from nd4j (Apache License 2.0), 5 votes
/**
 * Reports whether any input of the given node refers to reduction indices.
 *
 * @param nodeDef the TensorFlow node whose input names are inspected
 * @return {@code true} if at least one input name contains {@code "reduction_indices"}
 */
protected boolean hasReductionIndices(NodeDef nodeDef) {
    return nodeDef.getInputList().stream()
            .anyMatch(inputName -> inputName.contains("reduction_indices"));
}
 
Example 3
Source File: Transpose.java — from deeplearning4j (Apache License 2.0), 5 votes
/**
 * Maps a TensorFlow transpose node onto this op.
 *
 * <p>After the superclass mapping, this looks up the graph node that feeds the second input
 * (the permutation), converts it to an array and stores it in {@code permuteDims}. Once the
 * argument variable has a shape and a concrete, non-placeholder array, the input arguments are
 * registered. If {@code permuteDims} is still null at that point, it defaults to the reversed
 * axis order of the input array.
 *
 * @throws ND4JIllegalStateException if the permutation has fewer entries than the input's rank
 */
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
    super.initFromTensorFlow(nodeDef, initWith, attributesForNode, graph);
    // Nothing more to map unless the permutation is supplied as a second input.
    if (nodeDef.getInputCount() < 2)
        return;
    // Find the graph node whose name equals the second input; if several match, the last one wins.
    // NOTE(review): an input reference with an output suffix ("perm:0") or a control prefix
    // ("^perm") matches no node name, leaving permuteDimsNode null - confirm that
    // getNDArrayFromTensor tolerates a null node.
    NodeDef permuteDimsNode = null;
    for (int i = 0; i < graph.getNodeCount(); i++) {
        if (graph.getNode(i).getName().equals(nodeDef.getInput(1))) {
            permuteDimsNode = graph.getNode(i);
        }

    }

    INDArray permuteArrayOp = TFGraphMapper.getNDArrayFromTensor(permuteDimsNode);
    if (permuteArrayOp != null) {
        this.permuteDims = permuteArrayOp.data().asInt();
    }

    // Defer input wiring until the argument has a shape, is not a placeholder, and has an array.
    if (arg().getShape() == null || arg().getVariableType() == VariableType.PLACEHOLDER || arg().getArr() == null) {
        return;
    }

    INDArray arr = sameDiff.getArrForVarName(arg().name());

    // Pass the permutation array as a second input when one was found.
    if(permuteArrayOp != null){
        addInputArgument(arr, permuteArrayOp);
    } else {
        addInputArgument(arr);
    }

    // Default permutation: reverse the axis order of the input.
    if (arr != null && permuteDims == null) {
        this.permuteDims = ArrayUtil.reverseCopy(ArrayUtil.range(0, arr.rank()));
    }

    // Reject a permutation that is shorter than the input's rank.
    if (permuteDims != null && permuteDims.length < arg().getShape().length)
        throw new ND4JIllegalStateException("Illegal permute found. Not all dimensions specified");
}
 
Example 4
Source File: BaseReduceOp.java — from deeplearning4j (Apache License 2.0), 5 votes
/**
 * Reports whether any input of the given node refers to reduction indices.
 *
 * @param nodeDef the TensorFlow node whose input names are inspected
 * @return {@code true} if at least one input name contains {@code "reduction_indices"}
 */
protected boolean hasReductionIndices(NodeDef nodeDef) {
    return nodeDef.getInputList().stream()
            .anyMatch(inputName -> inputName.contains("reduction_indices"));
}
 
Example 5
Source File: Concat.java — from nd4j (Apache License 2.0), 4 votes
/**
 * Maps a TensorFlow concat node onto this op by locating the input that carries the
 * concatenation axis.
 *
 * <p>The axis input is the first input whose name contains {@code "/concat_dim"}, or otherwise
 * the last input. If that variable exists but has no array yet, it is registered through
 * {@code addPropertyToResolve} for later resolution. If it has an array, a single-element array
 * is read as the axis and added as an integer argument. Finally, the axis input is detached from
 * this op's arguments.
 */
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
    int concatDimension = -1;
    String input = null;
    // Prefer an input explicitly named as the concat dimension.
    for(int i = 0; i < nodeDef.getInputCount(); i++) {
        if(nodeDef.getInput(i).contains("/concat_dim")) {
            input = nodeDef.getInput(i);
            break;
        }
    }

    // No "/concat_dim" input found: fall back to the last input, assumed to hold the axis.
    // NOTE(review): a node with zero inputs would call getInput(-1) here and fail.
    if(input == null) {
        input = nodeDef.getInput(nodeDef.getInputCount() - 1);
    }

    val variable = initWith.getVariable(input);
    // The axis variable exists but has no array yet: defer resolving it.
    if (variable != null && variable.getArr() == null) {
        sameDiff.addPropertyToResolve(this, input);

    } else if (variable != null) {
        // The axis value is available. Only a single-element array is read; otherwise
        // concatDimension keeps its initial -1.
        // NOTE(review): a multi-element axis array silently yields -1 - confirm intended.
        val arr = variable.getArr();
        if (arr.length() == 1) {
            concatDimension = arr.getInt(0);
        }

        this.concatDimension = concatDimension;
        addIArgument(this.concatDimension);
        log.debug("Concat dimension: {}", concatDimension);

    }

    // Don't pass both the iArg and the trailing axis input down to libnd4j: if every node input
    // became an input argument, drop the last input argument.
    // NOTE(review): this removes the last argument even when the axis was matched via
    // "/concat_dim" at a different position - confirm that case can't occur.
    if(inputArguments().length == nodeDef.getInputCount()) {
        val inputArgs = inputArguments();
        removeInputArgument(inputArgs[inputArguments().length - 1]);
    }

    // Detach the axis input from this op's argument list in SameDiff.
    sameDiff.removeArgFromFunction(input,this);
}
 
Example 6
Source File: Transpose.java — from nd4j (Apache License 2.0), 4 votes
/**
 * Maps a TensorFlow transpose node onto this op.
 *
 * <p>After the superclass mapping, this looks up the graph node that feeds the second input
 * (the permutation), reads its {@code "value"} tensor, and stores it in {@code permuteDims},
 * also adding each entry as an integer argument. Once the argument variable has a shape, its
 * array is fetched (or created from the variable's weight-init scheme if absent) and added as
 * an input argument. If {@code permuteDims} is still null, it defaults to the reversed axis
 * order of that array.
 *
 * @throws ND4JIllegalStateException if the permutation has fewer entries than the input's rank
 */
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
    super.initFromTensorFlow(nodeDef, initWith, attributesForNode, graph);
    // Nothing more to map unless the permutation dimensions are supplied as a second input.
    if (nodeDef.getInputCount() < 2)
        return;
    // Find the graph node whose name equals the second input; if several match, the last one wins.
    // NOTE(review): permuteDimsNode stays null if no node name matches (e.g. "perm:0" or "^perm")
    // - confirm that getNDArrayFromTensor handles a null node.
    NodeDef permuteDimsNode = null;
    for (int i = 0; i < graph.getNodeCount(); i++) {
        if (graph.getNode(i).getName().equals(nodeDef.getInput(1))) {
            permuteDimsNode = graph.getNode(i);
        }

    }

    // Record the permutation both in the field and as integer arguments.
    val permuteArrayOp = TFGraphMapper.getInstance().getNDArrayFromTensor("value", permuteDimsNode, graph);
    if (permuteArrayOp != null) {
        this.permuteDims = permuteArrayOp.data().asInt();
        for (int i = 0; i < permuteDims.length; i++) {
            addIArgument(permuteDims[i]);
        }
    }

    // Defer input wiring until the argument variable has a shape.
    if (arg().getShape() == null) {
        return;
    }


    // No array yet: create one from the variable's weight-init scheme and register it.
    // NOTE(review): getWeightInitScheme() may be null for variables without an init scheme
    // (e.g. placeholders) - confirm this path can't be reached for them.
    INDArray arr = sameDiff.getArrForVarName(arg().getVarName());
    if (arr == null) {
        val arrVar = sameDiff.getVariable(arg().getVarName());
        arr = arrVar.getWeightInitScheme().create(arrVar.getShape());
        sameDiff.putArrayForVarName(arg().getVarName(), arr);
    }

    addInputArgument(arr);

    // Default permutation: reverse the axis order of the input.
    if (arr != null && permuteDims == null) {
        this.permuteDims = ArrayUtil.reverseCopy(ArrayUtil.range(0, arr.rank()));
    }

    // Reject a permutation that is shorter than the input's rank.
    if (permuteDims != null && permuteDims.length < arg().getShape().length)
        throw new ND4JIllegalStateException("Illegal permute found. Not all dimensions specified");


}
 
Example 7
Source File: GraphRunner.java — from deeplearning4j (Apache License 2.0), 4 votes
/**
 * Infers the graph's output names (when {@code outputOrder} is not already set) and lazily
 * creates the TensorFlow session.
 *
 * <p>Output candidates are the nodes that no other node consumes and that are not
 * {@code Placeholder} ops. When there is more than one candidate, names containing {@code "/"}
 * (treated as generated/scoped names) are dropped. They are only dropped if at least one
 * unscoped candidate remains, so the output list is never emptied by this step.
 *
 * @param graphDef1 the graph definition scanned for consumed nodes and output candidates
 * @throws IllegalStateException if the TensorFlow session cannot be created
 */
private void initSessionAndStatusIfNeeded(org.tensorflow.framework.GraphDef graphDef1) {
    //infer the inputs and outputSchema for the graph
    // GraphDef input references are "name", "name:index" (a specific output of the node) or
    // "^name" (a control dependency). Reduce each to the bare node name so that a node consumed
    // only through a non-zero output index or a control edge still counts as consumed.
    Set<String> seenAsInput = new LinkedHashSet<>();
    for (int i = 0; i < graphDef1.getNodeCount(); i++) {
        NodeDef node = graphDef1.getNode(i);
        for (int input = 0; input < node.getInputCount(); input++) {
            String inputName = node.getInput(input);
            if (inputName.startsWith("^")) {
                inputName = inputName.substring(1);
            }
            int colon = inputName.indexOf(':');
            if (colon >= 0) {
                inputName = inputName.substring(0, colon);
            }
            seenAsInput.add(inputName);
        }
    }

    if (outputOrder == null) {
        outputOrder = new ArrayList<>();
        log.trace("Attempting to automatically resolve tensorflow output names..");
        //find the nodes that were not inputs to any  nodes: these are the outputSchema
        for (int i = 0; i < graphDef1.getNodeCount(); i++) {
            NodeDef node = graphDef1.getNode(i);
            if (!seenAsInput.contains(node.getName()) && !node.getOp().equals("Placeholder")) {
                outputOrder.add(node.getName());
            }
        }

        //multiple names: purge any generated (scoped) names from the output, but only when at
        //least one unscoped name remains - otherwise every inferred output would be discarded
        if (outputOrder.size() > 1 && outputOrder.stream().anyMatch(name -> !name.contains("/"))) {
            outputOrder.removeIf(name -> name.contains("/"));
        }
    }


    //setup and configure the session, factoring
    //in the ConfigObject as needed
    if (session == null) {
        initOptionsIfNeeded();
        session = TF_NewSession(graph, options, status);
        if (TF_GetCode(status) != TF_OK) {
            throw new IllegalStateException("ERROR: Unable to open session " + TF_Message(status).getString());
        }

    }

}