Java Code Examples for org.tensorflow.framework.NodeDef#getInput()

The following examples show how to use `org.tensorflow.framework.NodeDef#getInput()`. You can vote up the examples you like or vote down the ones you don't, and you can go to the original project or source file by following the links above each example. Related API usage is listed in the sidebar.
Example 1
Source file: TensorArrayV3.java — from nd4j (Apache License 2.0), 6 votes
/**
 * Initializes this op from a TensorFlow node.
 *
 * <p>The node named by {@code nodeDef}'s last input is looked up in {@code graph}. If the
 * tensor held in that node's {@code "value"} attribute resolves to an array, the array's
 * first element is appended as an integer argument; otherwise nothing is added.
 *
 * @param nodeDef           the TensorFlow node being imported
 * @param initWith          the SameDiff instance being populated (not read here)
 * @param attributesForNode attributes of {@code nodeDef} (not read here)
 * @param graph             the graph containing {@code nodeDef} and its inputs
 */
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
    val lastInputName = nodeDef.getInput(nodeDef.getInputCount() - 1);

    // Walking backwards and stopping at the first hit selects the same node that a full
    // forward walk would have settled on last.
    NodeDef sourceNode = null;
    for (int j = graph.getNodeCount() - 1; j >= 0 && sourceNode == null; j--) {
        NodeDef candidate = graph.getNode(j);
        if (candidate.getName().equals(lastInputName)) {
            sourceNode = candidate;
        }
    }

    val valueArr = TFGraphMapper.getInstance().getNDArrayFromTensor("value", sourceNode, graph);
    if (valueArr == null) {
        return;
    }
    addIArgument(valueArr.getInt(0));
}
 
Example 2
Source file: StridedSlice.java — from deeplearning4j (Apache License 2.0), 6 votes
/**
 * Initializes this strided-slice op from a TensorFlow {@code StridedSlice} node.
 *
 * <p>Only the five bit-mask attributes are read. Each mask is stored in its field and
 * appended as an integer argument, in the order begin, ellipsis, end, new-axis, shrink-axis.
 * The node's begin/end/strides inputs (inputs 1..3) are not read by this method.
 *
 * @param nodeDef           the TensorFlow node being imported
 * @param initWith          the SameDiff instance being populated (not read here)
 * @param attributesForNode attributes of {@code nodeDef} (not read here)
 * @param graph             the enclosing graph (not read here)
 * @throws IllegalArgumentException if any of the mask attributes is missing from {@code nodeDef}
 */
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
    // Read every mask before touching any field, so a missing attribute fails without
    // leaving this op partially initialized.
    val beginAttr = nodeDef.getAttrOrThrow("begin_mask");
    val ellipsisAttr = nodeDef.getAttrOrThrow("ellipsis_mask");
    val endAttr = nodeDef.getAttrOrThrow("end_mask");
    val newAxisAttr = nodeDef.getAttrOrThrow("new_axis_mask");
    val shrinkAxisAttr = nodeDef.getAttrOrThrow("shrink_axis_mask");

    // TF stores masks as int64; the op keeps them as int bit fields.
    beginMask = (int) beginAttr.getI();
    ellipsisMask = (int) ellipsisAttr.getI();
    endMask = (int) endAttr.getI();
    newAxisMask = (int) newAxisAttr.getI();
    shrinkAxisMask = (int) shrinkAxisAttr.getI();

    addIArgument(beginMask);
    addIArgument(ellipsisMask);
    addIArgument(endMask);
    addIArgument(newAxisMask);
    addIArgument(shrinkAxisMask);
}
 
Example 3
Source file: TensorArray.java — from deeplearning4j (Apache License 2.0), 6 votes
/**
 * Initializes this tensor-array op from a TensorFlow node.
 *
 * <p>The node named by {@code nodeDef}'s last input is looked up in {@code graph}. If an
 * array can be read from that node, its first element is appended as an integer argument.
 * The element data type is then taken from the {@code "dtype"} entry of
 * {@code attributesForNode}.
 *
 * <p>NOTE(review): if {@code "dtype"} is absent, {@code attributesForNode.get} returns
 * {@code null} and this method throws a NullPointerException.
 *
 * @param nodeDef           the TensorFlow node being imported
 * @param initWith          the SameDiff instance being populated (not read here)
 * @param attributesForNode attributes of {@code nodeDef}; must contain {@code "dtype"}
 * @param graph             the graph containing {@code nodeDef} and its inputs
 */
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
    val lastInputName = nodeDef.getInput(nodeDef.getInputCount() - 1);

    // Walking backwards and stopping at the first hit selects the same node that a full
    // forward walk would have settled on last.
    NodeDef sourceNode = null;
    for (int j = graph.getNodeCount() - 1; j >= 0 && sourceNode == null; j--) {
        NodeDef candidate = graph.getNode(j);
        if (candidate.getName().equals(lastInputName)) {
            sourceNode = candidate;
        }
    }

    val valueArr = TFGraphMapper.getNDArrayFromTensor(sourceNode);
    if (valueArr != null) {
        addIArgument(valueArr.getInt(0));
    }

    val dtypeAttr = attributesForNode.get("dtype");
    this.tensorArrayDataType = TFGraphMapper.convertType(dtypeAttr.getType());
}
 
Example 4
Source file: Pow.java — from nd4j (Apache License 2.0), 5 votes
/**
 * Initializes the exponent of this op from a TensorFlow {@code Pow} node.
 *
 * <p>The node's second input is the exponent. If {@code initWith} already holds an array for
 * it and that array is a scalar, its value is stored in {@code pow}. In every other case
 * {@code pow} is left unchanged.
 *
 * @param nodeDef           the TensorFlow node being imported
 * @param initWith          the SameDiff instance that may hold the exponent's array
 * @param attributesForNode attributes of {@code nodeDef} (not read here)
 * @param graph             the enclosing graph (not read here)
 */
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
    val exponentName = nodeDef.getInput(1);
    val exponentArr = initWith.getArrForVarName(exponentName);

    // Only a scalar exponent can be folded into the op's single `pow` value.
    if (exponentArr != null && exponentArr.isScalar()) {
        this.pow = exponentArr.getDouble(0);
    }
}
 
Example 5
Source file: BaseTensorOp.java — from nd4j (Apache License 2.0), 5 votes
/**
 * Initializes this op from a TensorFlow node by reading an integer from its second input.
 *
 * <p>The node named by {@code nodeDef}'s input 1 is looked up in {@code graph}. If an array
 * can be obtained from that node, its first element is appended as an integer argument;
 * otherwise nothing is added.
 *
 * @param nodeDef           the TensorFlow node being imported
 * @param initWith          the SameDiff instance being populated (not read here)
 * @param attributesForNode attributes of {@code nodeDef} (not read here)
 * @param graph             the graph containing {@code nodeDef} and its inputs
 */
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
    val secondInputName = nodeDef.getInput(1);
    val secondInputNode = TFGraphMapper.getInstance().getNodeWithNameFromGraph(graph, secondInputName);
    val secondInputArr = TFGraphMapper.getInstance().getArrayFrom(secondInputNode, graph);
    if (secondInputArr != null) {
        addIArgument(secondInputArr.getInt(0));
    }
}
 
Example 6
Source file: BaseTensorOp.java — from deeplearning4j (Apache License 2.0), 5 votes
/**
 * Initializes this op from a TensorFlow node by reading an integer from its second input.
 *
 * <p>The node named by {@code nodeDef}'s input 1 is looked up in {@code graph}. If an array
 * can be obtained from that node, its first element is appended as an integer argument;
 * otherwise nothing is added.
 *
 * @param nodeDef           the TensorFlow node being imported
 * @param initWith          the SameDiff instance being populated (not read here)
 * @param attributesForNode attributes of {@code nodeDef} (not read here)
 * @param graph             the graph containing {@code nodeDef} and its inputs
 */
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
    val secondInputName = nodeDef.getInput(1);
    val secondInputNode = TFGraphMapper.getNodeWithNameFromGraph(graph, secondInputName);
    val secondInputArr = TFGraphMapper.getArrayFrom(secondInputNode, graph);
    if (secondInputArr != null) {
        addIArgument(secondInputArr.getInt(0));
    }
}
 
Example 7
Source file: SavedModel.java — from jpmml-tensorflow (GNU Affero General Public License v3.0), 4 votes
/**
 * Builds lookup tables from the nodes listed in the {@code "table_initializer"} collection.
 *
 * <p>For each initializer node:
 * <ul>
 *   <li>input 0 is the table name;</li>
 *   <li>input 1 is run to obtain the keys;</li>
 *   <li>input 2 is run to obtain the values.</li>
 * </ul>
 * The key/value pairs are stored in encounter order and registered through
 * {@code putTable}. If a key repeats, the later value wins ({@code Map.put} semantics).
 *
 * <p>If looking up the collection throws {@link IllegalArgumentException}, no tables are
 * built. Presumably that exception means the collection is absent; TODO confirm against
 * {@code getCollectionDef}.
 *
 * @throws IllegalArgumentException if an initializer yields different numbers of keys and values
 */
private void initializeTables(){
	Collection<String> tableInitializerNames = Collections.emptyList();

	try {
		CollectionDef collectionDef = getCollectionDef("table_initializer");

		CollectionDef.NodeList nodeList = collectionDef.getNodeList();

		tableInitializerNames = nodeList.getValueList();
	} catch(IllegalArgumentException ignored){
		// Deliberately best-effort: a failed lookup is treated as "no table initializers".
	}

	for(String tableInitializerName : tableInitializerNames){
		NodeDef tableInitializer = getNodeDef(tableInitializerName);

		String name = tableInitializer.getInput(0);

		List<?> keys;
		List<?> values;

		try(Tensor tensor = run(tableInitializer.getInput(1))){
			keys = TensorUtil.getValues(tensor);
		}

		try(Tensor tensor = run(tableInitializer.getInput(2))){
			values = TensorUtil.getValues(tensor);
		}

		// Validate before building anything, and say which initializer is inconsistent.
		if(keys.size() != values.size()){
			throw new IllegalArgumentException("Table initializer " + tableInitializerName + " (table " + name + ") produced " + keys.size() + " keys but " + values.size() + " values");
		}

		Map<Object, Object> table = new LinkedHashMap<>();

		for(int i = 0; i < keys.size(); i++){
			table.put(keys.get(i), values.get(i));
		}

		putTable(name, table);
	}
}
 
Example 8
Source file: Concat.java — from nd4j (Apache License 2.0), 4 votes
/**
 * Initializes the concat dimension of this op from a TensorFlow concat node.
 *
 * <p>The axis input is chosen as follows:
 * <ol>
 *   <li>the first input whose name contains {@code "/concat_dim"}, if there is one;</li>
 *   <li>otherwise, the node's last input.</li>
 * </ol>
 *
 * <p>What happens next depends on the variable {@code initWith} holds for that input:
 * <ul>
 *   <li>Known variable with no array yet: the property is deferred through
 *       {@code sameDiff.addPropertyToResolve}.</li>
 *   <li>Variable with an array: if the array has exactly one element, that element becomes
 *       the dimension; otherwise the dimension stays -1. The value is stored in the
 *       {@code concatDimension} field and appended as an integer argument.</li>
 *   <li>No variable: neither action is taken.</li>
 * </ul>
 *
 * <p>Afterwards, if the op has as many input arguments as the node has inputs, the last
 * input argument is removed. The axis input is then removed from this function's arguments
 * via {@code sameDiff.removeArgFromFunction}.
 *
 * <p>NOTE(review): lookups use the {@code initWith} parameter, but mutations go through the
 * {@code sameDiff} field. Presumably these are the same instance; confirm.
 */
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
    // Fallback dimension used when the axis array does not hold exactly one element.
    int concatDimension = -1;
    String input = null;
    // Prefer an input explicitly named ".../concat_dim".
    for(int i = 0; i < nodeDef.getInputCount(); i++) {
        if(nodeDef.getInput(i).contains("/concat_dim")) {
            input = nodeDef.getInput(i);
            break;
        }
    }

    // No input named concat_dim: treat the last input as the axis (as in TF ConcatV2, where
    // the axis is the final input). Assumes the node has at least one input.
    if(input == null) {
        input = nodeDef.getInput(nodeDef.getInputCount() - 1);
    }

    val variable = initWith.getVariable(input);
    // The axis variable exists but has no array yet: defer resolving it.
    if (variable != null && variable.getArr() == null) {
        sameDiff.addPropertyToResolve(this, input);

    } else if (variable != null) {
        // Only a single-element array is taken as the axis; anything else keeps -1.
        val arr = variable.getArr();
        if (arr.length() == 1) {
            concatDimension = arr.getInt(0);
        }

        this.concatDimension = concatDimension;
        addIArgument(this.concatDimension);
        log.debug("Concat dimension: {}", concatDimension);

    }

    //don't pass both iArg and last axis down to libnd4j
    // (if every TF input became an input argument, drop the last one, assumed to be the axis)
    if(inputArguments().length == nodeDef.getInputCount()) {
        val inputArgs = inputArguments();
        removeInputArgument(inputArgs[inputArguments().length - 1]);
    }

    // Detach the axis input from this function's argument list.
    sameDiff.removeArgFromFunction(input,this);
}
 
Example 9
Source file: StridedSlice.java — from nd4j (Apache License 2.0), 4 votes
/**
 * Initializes this strided-slice op from a TensorFlow {@code StridedSlice} node.
 *
 * <p>First, the five bit-mask attributes are appended as integer arguments, in the order
 * begin, ellipsis, end, new-axis, shrink-axis.
 *
 * <p>Next, the nodes named by inputs 1, 2 and 3 are looked up in {@code graph}. If the
 * {@code "value"} tensors of all three resolve to arrays, every element of the begin array,
 * then the end array, then the strides array is appended as an integer argument. If any of
 * them does not resolve, no further arguments are added.
 *
 * @param nodeDef           the TensorFlow node being imported
 * @param initWith          the SameDiff instance being populated (not read here)
 * @param attributesForNode attributes of {@code nodeDef} (not read here)
 * @param graph             the graph containing {@code nodeDef} and its inputs
 * @throws IllegalArgumentException if any of the mask attributes is missing from {@code nodeDef}
 */
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
    val beginName = nodeDef.getInput(1);
    val endName = nodeDef.getInput(2);
    val stridesName = nodeDef.getInput(3);

    // Walking backwards and keeping only the first hit per name selects the same nodes a
    // full forward walk would have settled on last.
    NodeDef beginNode = null;
    NodeDef endNode = null;
    NodeDef stridesNode = null;
    for (int j = graph.getNodeCount() - 1; j >= 0; j--) {
        NodeDef candidate = graph.getNode(j);
        String candidateName = candidate.getName();
        if (beginNode == null && candidateName.equals(beginName)) {
            beginNode = candidate;
        }
        if (endNode == null && candidateName.equals(endName)) {
            endNode = candidate;
        }
        if (stridesNode == null && candidateName.equals(stridesName)) {
            stridesNode = candidate;
        }
    }

    // Read every mask first, so a missing attribute fails before any argument is added.
    AttrValue[] masks = {
            nodeDef.getAttrOrThrow("begin_mask"),
            nodeDef.getAttrOrThrow("ellipsis_mask"),
            nodeDef.getAttrOrThrow("end_mask"),
            nodeDef.getAttrOrThrow("new_axis_mask"),
            nodeDef.getAttrOrThrow("shrink_axis_mask")
    };
    for (AttrValue mask : masks) {
        addIArgument((int) mask.getI());
    }

    val mapper = TFGraphMapper.getInstance();
    val beginArr = mapper.getNDArrayFromTensor("value", beginNode, graph);
    val endArr = mapper.getNDArrayFromTensor("value", endNode, graph);
    val stridesArr = mapper.getNDArrayFromTensor("value", stridesNode, graph);

    // All three must be constant-resolvable; otherwise leave the arguments as they are.
    if (beginArr == null || endArr == null || stridesArr == null) {
        return;
    }

    for (int k = 0; k < beginArr.length(); k++) {
        addIArgument(beginArr.getInt(k));
    }
    for (int k = 0; k < endArr.length(); k++) {
        addIArgument(endArr.getInt(k));
    }
    for (int k = 0; k < stridesArr.length(); k++) {
        addIArgument(stridesArr.getInt(k));
    }
}
 
Example 10
Source file: Slice.java — from nd4j (Apache License 2.0), 4 votes
/**
 * Initializes this slice op from a TensorFlow {@code Slice} node.
 *
 * <p>TensorFlow's {@code Slice} takes three inputs:
 * <ol start="0">
 *   <li>input: the tensor to slice;</li>
 *   <li>begin: the starting index along each dimension;</li>
 *   <li>size: the number of elements to take along each dimension.</li>
 * </ol>
 *
 * <p>The nodes named by inputs 1 and 2 are looked up in {@code graph}. If the
 * {@code "value"} tensors of both resolve to arrays, every element of the begin array,
 * followed by every element of the size array, is appended as an integer argument.
 * Otherwise nothing is added.
 *
 * @param nodeDef           the TensorFlow node being imported
 * @param initWith          the SameDiff instance being populated (not read here)
 * @param attributesForNode attributes of {@code nodeDef} (not read here)
 * @param graph             the graph containing {@code nodeDef} and its inputs
 */
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
    val inputBegin = nodeDef.getInput(1);
    // Slice's third input is a per-dimension size, not an end index (unlike StridedSlice).
    val inputSize = nodeDef.getInput(2);

    NodeDef beginNode = null;
    NodeDef sizeNode = null;

    // No early exit: if a name matched several nodes, the last match would be kept.
    for(int i = 0; i < graph.getNodeCount(); i++) {
        NodeDef node = graph.getNode(i);
        if(node.getName().equals(inputBegin)) {
            beginNode = node;
        }
        if(node.getName().equals(inputSize)) {
            sizeNode = node;
        }
    }

    val beginArr = TFGraphMapper.getInstance().getNDArrayFromTensor("value", beginNode, graph);
    val sizeArr = TFGraphMapper.getInstance().getNDArrayFromTensor("value", sizeNode, graph);

    // Both must resolve to constant arrays; otherwise leave the arguments untouched.
    if (beginArr == null || sizeArr == null) {
        return;
    }

    for (int e = 0; e < beginArr.length(); e++) {
        addIArgument(beginArr.getInt(e));
    }

    for (int e = 0; e < sizeArr.length(); e++) {
        addIArgument(sizeArr.getInt(e));
    }
}