Java Code Examples for org.tensorflow.framework.NodeDef#getAttrOrThrow()

The following examples show how to use org.tensorflow.framework.NodeDef#getAttrOrThrow(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may check out the related API usage on the sidebar.
Example 1
Source File: StridedSlice.java    From deeplearning4j with Apache License 2.0 6 votes vote down vote up
/**
 * Imports the bit-mask attributes of a TensorFlow {@code StridedSlice} node.
 *
 * <p>The five masks are stored in this op's fields and then appended as integer arguments in
 * the order begin, ellipsis, end, new-axis, shrink-axis. The begin/end/strides tensors
 * (inputs 1-3 of the node) are not resolved by this method.
 *
 * @param nodeDef           the TensorFlow node being imported; must carry all five mask attributes
 * @param initWith          the SameDiff instance being built (not used here)
 * @param attributesForNode attributes of the node (not used here; values are read from {@code nodeDef})
 * @param graph             the enclosing TensorFlow graph (not used here)
 */
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
    // getAttrOrThrow rejects a node that lacks any of the mask attributes.
    beginMask = (int) nodeDef.getAttrOrThrow("begin_mask").getI();
    ellipsisMask = (int) nodeDef.getAttrOrThrow("ellipsis_mask").getI();
    endMask = (int) nodeDef.getAttrOrThrow("end_mask").getI();
    newAxisMask = (int) nodeDef.getAttrOrThrow("new_axis_mask").getI();
    shrinkAxisMask = (int) nodeDef.getAttrOrThrow("shrink_axis_mask").getI();

    // Argument order is significant; presumably it mirrors the native op's integer-argument layout.
    addIArgument(beginMask);
    addIArgument(ellipsisMask);
    addIArgument(endMask);
    addIArgument(newAxisMask);
    addIArgument(shrinkAxisMask);
}
 
Example 2
Source File: LocalResponseNormalization.java    From nd4j with Apache License 2.0 6 votes vote down vote up
/**
 * Configures this local response normalization op from a TensorFlow {@code LRN} node.
 *
 * <p>{@code alpha}, {@code beta} and {@code bias} are float attributes; {@code depth_radius} is
 * an integer attribute and must be read with {@code getI()}. Reading it with {@code getF()}
 * returns the unset float field (0.0) and silently yields a depth of 0.
 *
 * @param nodeDef           the TensorFlow node being imported; must carry all four attributes
 * @param initWith          the SameDiff instance being built (not used here)
 * @param attributesForNode attributes of the node (not used here; values are read from {@code nodeDef})
 * @param graph             the enclosing TensorFlow graph (not used here)
 */
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {

    val aAlpha = nodeDef.getAttrOrThrow("alpha");
    val aBeta = nodeDef.getAttrOrThrow("beta");
    val aBias = nodeDef.getAttrOrThrow("bias");
    val aDepth = nodeDef.getAttrOrThrow("depth_radius");

    val alpha = aAlpha.getF();
    val beta = aBeta.getF();
    val bias = aBias.getF();
    // depth_radius is an int attribute: getI(), not getF().
    int depth = (int) aDepth.getI();

    LocalResponseNormalizationConfig localResponseNormalizationConfig = LocalResponseNormalizationConfig.builder()
            .alpha(alpha)
            .beta(beta)
            .bias(bias)
            .depth(depth)
            .build();
    this.config = localResponseNormalizationConfig;
    addArgs();
}
 
Example 3
Source File: LocalResponseNormalization.java    From deeplearning4j with Apache License 2.0 6 votes vote down vote up
/**
 * Builds the local response normalization configuration from a TensorFlow {@code LRN} node.
 *
 * <p>The float attributes {@code alpha}, {@code beta} and {@code bias} and the integer attribute
 * {@code depth_radius} are all mandatory; the resulting config is stored and turned into op
 * arguments.
 *
 * @param nodeDef           the TensorFlow node being imported
 * @param initWith          the SameDiff instance being built (not used here)
 * @param attributesForNode attributes of the node (not used here; values are read from {@code nodeDef})
 * @param graph             the enclosing TensorFlow graph (not used here)
 */
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
    double alphaValue = nodeDef.getAttrOrThrow("alpha").getF();
    double betaValue = nodeDef.getAttrOrThrow("beta").getF();
    double biasValue = nodeDef.getAttrOrThrow("bias").getF();
    int depthRadius = (int) nodeDef.getAttrOrThrow("depth_radius").getI();

    this.config = LocalResponseNormalizationConfig.builder()
            .alpha(alphaValue)
            .beta(betaValue)
            .bias(biasValue)
            .depth(depthRadius)
            .build();
    addArgs();
}
 
Example 4
Source File: Unstack.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
/**
 * Imports the unstack axis, and optionally the output count, from a TensorFlow {@code Unpack} node.
 *
 * <p>{@code axis} is mandatory. {@code num} is optional and only overwrites the {@code num} field
 * when the node carries it.
 *
 * @param nodeDef           the TensorFlow node being imported
 * @param initWith          the SameDiff instance being built (not used here)
 * @param attributesForNode attributes of the node (not used here; values are read from {@code nodeDef})
 * @param graph             the enclosing TensorFlow graph (not used here)
 */
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
    this.jaxis = (int) nodeDef.getAttrOrThrow("axis").getI();

    AttrValue numAttr = nodeDef.getAttrOrDefault("num", null);
    if (numAttr != null) {
        this.num = (int) numAttr.getI();
    }

    addArgs();
}
 
Example 5
Source File: Unstack.java    From nd4j with Apache License 2.0 5 votes vote down vote up
/**
 * Imports the mandatory {@code axis} attribute of a TensorFlow unstack node into the
 * {@code axis} field, then builds the op arguments.
 *
 * @param nodeDef           the TensorFlow node being imported
 * @param initWith          the SameDiff instance being built (not used here)
 * @param attributesForNode attributes of the node (not used here; values are read from {@code nodeDef})
 * @param graph             the enclosing TensorFlow graph (not used here)
 */
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
    this.axis = (int) nodeDef.getAttrOrThrow("axis").getI();
    addArgs();
}
 
Example 6
Source File: Pooling2D.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
/**
 * Configures this generic 2D pooling op from a TensorFlow pooling node.
 *
 * <p>Strides and kernel sizes are read from indices 1 and 2 of the node's {@code strides} and
 * {@code ksize} lists. The pooling type is left unset ({@code null}).
 *
 * <p>The {@code padding} attribute is read as a string ({@code getS()}). Because an
 * {@code AttrValue} holds a single value, its list form is then empty. Explicit paddings are
 * therefore taken from that list only when entries exist and default to 0 otherwise. Indexing
 * the list unconditionally threw {@link IndexOutOfBoundsException}.
 *
 * @param nodeDef           the TensorFlow node being imported
 * @param initWith          the SameDiff instance being built (not used here)
 * @param attributesForNode attributes of the node (not used here; values are read from {@code nodeDef})
 * @param graph             the enclosing TensorFlow graph (not used here)
 */
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
    // NOTE(review): indices 1/2 assume NHWC; the node's data_format attribute is not consulted here.
    val aStrides = nodeDef.getAttrOrThrow("strides");
    val tfStrides = aStrides.getList().getIList();
    val sH = tfStrides.get(1);
    val sW = tfStrides.get(2);

    val aKernels = nodeDef.getAttrOrThrow("ksize");
    val tfKernels = aKernels.getList().getIList();
    val kH = tfKernels.get(1);
    val kW = tfKernels.get(2);

    val aPadding = nodeDef.getAttrOrThrow("padding");
    val padding = aPadding.getList().getIList();
    int pH = padding.size() > 0 ? padding.get(0).intValue() : 0;
    int pW = padding.size() > 1 ? padding.get(1).intValue() : 0;

    val paddingMode = aPadding.getS().toStringUtf8().replaceAll("\"", "");
    boolean isSameMode = paddingMode.equalsIgnoreCase("SAME");

    if (!isSameMode) {
        log.debug("Mode: {}", paddingMode);
    }

    Pooling2DConfig pooling2DConfig = Pooling2DConfig.builder()
            .sH(sH.intValue())
            .sW(sW.intValue())
            .type(null)
            .isSameMode(isSameMode)
            .kH(kH.intValue())
            .kW(kW.intValue())
            .pH(pH)
            .pW(pW)
            .build();
    this.config = pooling2DConfig;
    addArgs();
    log.debug("Pooling: k: [{},{}]; s: [{}, {}], padding: {}", kH, kW, sH, sW, aPadding);
}
 
Example 7
Source File: MaxPooling2D.java    From deeplearning4j with Apache License 2.0 4 votes vote down vote up
/**
 * Configures this max-pooling op from a TensorFlow 2D max-pool node.
 *
 * <p>The H/W entries of the 4-element {@code strides}, {@code ksize} and (if non-empty) explicit
 * padding lists are selected according to {@code data_format}: indices 1/2 for NHWC (the
 * default when the attribute is absent), 2/3 for any other format. An empty padding list yields
 * zero padding. "SAME" mode is detected from the string form of {@code padding}.
 *
 * @param nodeDef           the TensorFlow node being imported
 * @param initWith          the SameDiff instance being built (not used here)
 * @param attributesForNode attributes of the node (not used here; values are read from {@code nodeDef})
 * @param graph             the enclosing TensorFlow graph (not used here)
 */
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
    val tfStrides = nodeDef.getAttrOrThrow("strides").getList().getIList();
    val tfKernels = nodeDef.getAttrOrThrow("ksize").getList().getIList();

    val paddingAttr = nodeDef.getAttrOrThrow("padding");
    val explicitPadding = paddingAttr.getList().getIList();
    boolean sameMode = paddingAttr.getS().toStringUtf8().replaceAll("\"", "").equalsIgnoreCase("SAME");

    String format = "nhwc";
    if (nodeDef.containsAttr("data_format")) {
        format = nodeDef.getAttrOrThrow("data_format").getS().toStringUtf8().toLowerCase();
    }
    boolean channelsLast = format.equalsIgnoreCase("nhwc");

    // Height/width positions inside the 4-element TF lists.
    int hAxis = channelsLast ? 1 : 2;
    int wAxis = hAxis + 1;

    int strideH = tfStrides.get(hAxis).intValue();
    int strideW = tfStrides.get(wAxis).intValue();
    int kernelH = tfKernels.get(hAxis).intValue();
    int kernelW = tfKernels.get(wAxis).intValue();

    boolean hasExplicitPadding = explicitPadding.size() > 0;
    int padH = hasExplicitPadding ? explicitPadding.get(hAxis).intValue() : 0;
    int padW = hasExplicitPadding ? explicitPadding.get(wAxis).intValue() : 0;

    this.config = Pooling2DConfig.builder()
            .sH(strideH)
            .sW(strideW)
            .type(Pooling2D.Pooling2DType.MAX)
            .isSameMode(sameMode)
            .kH(kernelH)
            .kW(kernelW)
            .pH(padH)
            .pW(padW)
            .isNHWC(channelsLast)
            // NOTE(review): extra = 1.0 is passed through as-is; its meaning for MAX pooling is not visible here — confirm.
            .extra(1.0)
            .build();
    addArgs();
}
 
Example 8
Source File: AvgPooling2D.java    From deeplearning4j with Apache License 2.0 4 votes vote down vote up
/**
 * Configures this average-pooling op from a TensorFlow 2D avg-pool node.
 *
 * <p>{@code data_format} (NHWC when absent) decides which entries of the 4-element
 * {@code strides}, {@code ksize} and explicit padding lists are the H/W values: 1/2 for NHWC,
 * 2/3 otherwise. An empty padding list yields zero padding.
 *
 * @param nodeDef           the TensorFlow node being imported
 * @param initWith          the SameDiff instance being built (not used here)
 * @param attributesForNode attributes of the node (not used here; values are read from {@code nodeDef})
 * @param graph             the enclosing TensorFlow graph (not used here)
 */
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
    val strideList = nodeDef.getAttrOrThrow("strides").getList().getIList();
    val kernelList = nodeDef.getAttrOrThrow("ksize").getList().getIList();

    val paddingAttr = nodeDef.getAttrOrThrow("padding");
    val paddingList = paddingAttr.getList().getIList();
    String mode = paddingAttr.getS().toStringUtf8().replaceAll("\"", "");
    boolean sameMode = mode.equalsIgnoreCase("SAME");

    String layout = nodeDef.containsAttr("data_format")
            ? nodeDef.getAttrOrThrow("data_format").getS().toStringUtf8().toLowerCase()
            : "nhwc";
    boolean nhwc = layout.equalsIgnoreCase("nhwc");

    int rowIdx = nhwc ? 1 : 2;
    int colIdx = nhwc ? 2 : 3;

    int sH = strideList.get(rowIdx).intValue();
    int sW = strideList.get(colIdx).intValue();
    int kH = kernelList.get(rowIdx).intValue();
    int kW = kernelList.get(colIdx).intValue();
    int pH = paddingList.size() > 0 ? paddingList.get(rowIdx).intValue() : 0;
    int pW = paddingList.size() > 0 ? paddingList.get(colIdx).intValue() : 0;

    this.config = Pooling2DConfig.builder()
            .sH(sH)
            .sW(sW)
            .type(Pooling2D.Pooling2DType.AVG)
            .isSameMode(sameMode)
            .kH(kH)
            .kW(kW)
            .pH(pH)
            .pW(pW)
            .isNHWC(nhwc)
            // Per the original author: extra = 0.0 averages only over non-padded values (confirm against the native op).
            .extra(0.0)
            .build();
    addArgs();
}
 
Example 9
Source File: Pooling3D.java    From deeplearning4j with Apache License 2.0 4 votes vote down vote up
/**
 * Configures this 3D pooling op from a TensorFlow 3D pooling node.
 *
 * <p>The 5-element {@code strides}/{@code ksize} lists follow the node's {@code data_format}. The
 * spatial entries (depth, height, width) sit at indices 1..3 for NDHWC (the default) and at
 * 2..4 for NCDHW. The previous implementation always read 1..3, which mixed the channel entry
 * into the configuration of NCDHW graphs. The pooling type (MAX/AVG) is derived from the op name.
 *
 * @param nodeDef           the TensorFlow node being imported
 * @param initWith          the SameDiff instance being built (not used here)
 * @param attributesForNode attributes of the node (not used here; values are read from {@code nodeDef})
 * @param graph             the enclosing TensorFlow graph (not used here)
 * @throws IllegalStateException if the op name starts with neither "max" nor "av"
 */
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
    val aStrides = nodeDef.getAttrOrThrow("strides");
    List<Long> tfStrides = aStrides.getList().getIList();
    val aKernels = nodeDef.getAttrOrThrow("ksize");
    List<Long> tfKernels = aKernels.getList().getIList();
    val aPadding = nodeDef.getAttrOrThrow("padding");
    List<Long> tfPadding = aPadding.getList().getIList();

    String paddingMode = aPadding.getS().toStringUtf8().replaceAll("\"", "");
    boolean isSameMode = paddingMode.equalsIgnoreCase("SAME");

    String data_format = "ndhwc";
    if (nodeDef.containsAttr("data_format")) {
        val attr = nodeDef.getAttrOrThrow("data_format");
        data_format = attr.getS().toStringUtf8().toLowerCase();
    }
    boolean isNCDHW = data_format.equalsIgnoreCase("ncdhw");

    // Index of the depth entry in the 5-element TF lists; height and width follow it.
    int firstSpatial = isNCDHW ? 2 : 1;

    // Order: depth, height, width.
    // NOTE(review): no dilation is read; TF's 3D pooling ops do not appear to expose one — confirm.
    int[] strides = new int[3];
    int[] padding = new int[3];
    int[] kernel = new int[3];
    for (int i = 0; i < 3; i++) {
        int idx = firstSpatial + i;
        strides[i] = tfStrides.get(idx).intValue();
        // The padding list is empty for SAME mode; leave the entry at 0 when it has no value here.
        if (tfPadding != null && tfPadding.size() > idx) {
            padding[i] = tfPadding.get(idx).intValue();
        }
        kernel[i] = tfKernels.get(idx).intValue();
    }

    Pooling3DType type;
    String name = nodeDef.getOp().toLowerCase();
    if (name.startsWith("max")) {
        type = Pooling3DType.MAX;
    } else if (name.startsWith("av")) {
        type = Pooling3DType.AVG;
    } else {
        throw new IllegalStateException("Unknown or not supported pooling type: " + name);
    }

    Pooling3DConfig conf = Pooling3DConfig.builder()
            .sD(strides[0]).sH(strides[1]).sW(strides[2])
            .pD(padding[0]).pH(padding[1]).pW(padding[2])
            .kD(kernel[0]).kH(kernel[1]).kW(kernel[2])
            .type(type)
            .isSameMode(isSameMode)
            .isNCDHW(isNCDHW)
            .build();
    this.config = conf;
    addArgs();
}
 
Example 10
Source File: MaxPoolWithArgmax.java    From deeplearning4j with Apache License 2.0 4 votes vote down vote up
/**
 * Configures this max-pool-with-argmax op from a TensorFlow {@code MaxPoolWithArgmax} node.
 *
 * <p>Pooling parameters are read as for 2D max pooling. The H/W entries of {@code strides},
 * {@code ksize} and explicit padding are at 1/2 for NHWC (the default) and at 2/3 otherwise.
 *
 * <p>TF declares the dtype of the argmax output as the attribute {@code Targmax}. The previous
 * lookup used the key {@code "argmax"}, which never matched a TF attribute, so the declared
 * dtype was ignored. {@code Targmax} is now consulted first and {@code "argmax"} is kept as a
 * fallback. {@link DataType#LONG} is used when neither is present.
 *
 * @param nodeDef           the TensorFlow node being imported
 * @param initWith          the SameDiff instance being built (not used here)
 * @param attributesForNode attribute map of the node, consulted for the argmax dtype
 * @param graph             the enclosing TensorFlow graph (not used here)
 */
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
    val aStrides = nodeDef.getAttrOrThrow("strides");
    val tfStrides = aStrides.getList().getIList();

    val aKernels = nodeDef.getAttrOrThrow("ksize");
    val tfKernels = aKernels.getList().getIList();

    val aPadding = nodeDef.getAttrOrThrow("padding");
    val padding = aPadding.getList().getIList();

    val paddingMode = aPadding.getS().toStringUtf8().replaceAll("\"", "");
    boolean isSameMode = paddingMode.equalsIgnoreCase("SAME");

    String data_format = "nhwc";
    if (nodeDef.containsAttr("data_format")) {
        val attr = nodeDef.getAttrOrThrow("data_format");
        data_format = attr.getS().toStringUtf8().toLowerCase();
    }
    boolean isNHWC = data_format.equalsIgnoreCase("nhwc");

    // Height/width positions inside the 4-element TF lists.
    int hIdx = isNHWC ? 1 : 2;
    int wIdx = hIdx + 1;

    int sH = tfStrides.get(hIdx).intValue();
    int sW = tfStrides.get(wIdx).intValue();
    int kH = tfKernels.get(hIdx).intValue();
    int kW = tfKernels.get(wIdx).intValue();
    int pH = padding.size() > 0 ? padding.get(hIdx).intValue() : 0;
    int pW = padding.size() > 0 ? padding.get(wIdx).intValue() : 0;

    Pooling2DConfig pooling2DConfig = Pooling2DConfig.builder()
            .sH(sH)
            .sW(sW)
            .type(Pooling2D.Pooling2DType.MAX)
            .isSameMode(isSameMode)
            .kH(kH)
            .kW(kW)
            .pH(pH)
            .pW(pW)
            .isNHWC(isNHWC)
            // NOTE(review): extra = 1.0 is passed through as-is; its meaning for MAX pooling is not visible here — confirm.
            .extra(1.0)
            .build();
    this.config = pooling2DConfig;
    addArgs();

    AttrValue argmaxType = attributesForNode.get("Targmax");
    if (argmaxType == null) {
        // Legacy key kept for backward compatibility with maps that used it.
        argmaxType = attributesForNode.get("argmax");
    }
    if (argmaxType != null) {
        outputType = TFGraphMapper.convertType(argmaxType.getType());
    } else {
        outputType = DataType.LONG;
    }
}
 
Example 11
Source File: DeConv2D.java    From deeplearning4j with Apache License 2.0 4 votes vote down vote up
/**
 * Configures this 2D deconvolution op from a TensorFlow node.
 *
 * <p>Strides come from the node's {@code strides} attribute (H/W at 2/3 for NCHW, 1/2
 * otherwise). Kernel height/width come from the weight array bound to the op's second
 * argument. If that variable has no array yet, an array is obtained via
 * {@code TFGraphMapper.getNDArrayFromTensor(nodeDef)} and associated with the variable.
 *
 * <p>Fixes over the previous version:
 * <ul>
 *   <li>H/W strides are passed to the config unswapped (they were passed as {@code .sH(sW).sW(sH)}).</li>
 *   <li>A missing {@code padding} attribute now raises a descriptive error instead of an NPE.</li>
 *   <li>An unresolvable weight array now raises a descriptive error instead of an NPE.</li>
 * </ul>
 *
 * @param nodeDef           the TensorFlow node being imported
 * @param initWith          the SameDiff instance being built; receives the weight array if it was loaded here
 * @param attributesForNode attributes of the node (not used here; values are read from {@code nodeDef})
 * @param graph             the enclosing TensorFlow graph (not used here)
 * @throws IllegalStateException if no weight array can be resolved for the op
 */
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
    val aStrides = nodeDef.getAttrOrThrow("strides");
    val tfStrides = aStrides.getList().getIList();
    long sH;
    long sW;
    long kH;
    long kW;

    val aPadding = nodeDef.getAttrOrThrow("padding");
    val paddingMode = aPadding.getS().toStringUtf8();

    val args = args();
    INDArray arr = sameDiff.getVariable(args[1].name()).getArr();
    if (arr == null) {
        // TODO: arguable - it might be easier to permute the weights once, e.g. arr.permute(3, 2, 0, 1).dup('c').
        arr = TFGraphMapper.getNDArrayFromTensor(nodeDef);
        val varForOp = initWith.getVariable(args[1].name());
        if (arr != null) {
            initWith.associateArrayWithVariable(arr, varForOp);
        }
    }
    if (arr == null) {
        throw new IllegalStateException("DeConv2D import: no weight array could be resolved for node " + nodeDef.getName());
    }

    String dataFormat = "nhwc";
    if (nodeDef.containsAttr("data_format")) {
        val attr = nodeDef.getAttrOrThrow("data_format");
        dataFormat = attr.getS().toStringUtf8().toLowerCase();
    }

    if (dataFormat.equalsIgnoreCase(DeConv2DConfig.NCHW)) {
        sH = tfStrides.get(2).longValue();
        sW = tfStrides.get(3).longValue();

        kH = arr.size(2);
        kW = arr.size(3);
    } else {
        sH = tfStrides.get(1).longValue();
        sW = tfStrides.get(2).longValue();

        kH = arr.size(0);
        kW = arr.size(1);
    }

    boolean isSameMode = paddingMode.equalsIgnoreCase("SAME");
    DeConv2DConfig conv2DConfig = DeConv2DConfig.builder()
            .kH(kH)
            .kW(kW)
            .sH(sH)
            .sW(sW)
            .isSameMode(isSameMode)
            .dataFormat(dataFormat.equalsIgnoreCase(DeConv2DConfig.NHWC) ? DeConv2DConfig.NHWC : DeConv2DConfig.NCHW)
            .build();
    this.config = conv2DConfig;

    addArgs();
}
 
Example 12
Source File: MaxPooling2D.java    From nd4j with Apache License 2.0 4 votes vote down vote up
/**
 * Configures this max-pooling op from a TensorFlow 2D max-pool node.
 *
 * <p>Y/X (height/width) strides, kernel sizes and explicit paddings are taken from the
 * 4-element TF lists at indices 1/2 for NHWC (the default when {@code data_format} is absent)
 * or 2/3 otherwise. An empty padding list yields zero padding. Virtual width/height are fixed
 * at 1.
 *
 * @param nodeDef           the TensorFlow node being imported
 * @param initWith          the SameDiff instance being built (not used here)
 * @param attributesForNode attributes of the node (not used here; values are read from {@code nodeDef})
 * @param graph             the enclosing TensorFlow graph (not used here)
 */
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
    val strideList = nodeDef.getAttrOrThrow("strides").getList().getIList();
    val kernelList = nodeDef.getAttrOrThrow("ksize").getList().getIList();

    val aPadding = nodeDef.getAttrOrThrow("padding");
    val paddingList = aPadding.getList().getIList();
    boolean sameMode = aPadding.getS().toStringUtf8().replaceAll("\"", "").equalsIgnoreCase("SAME");

    String layout = "nhwc";
    if (nodeDef.containsAttr("data_format")) {
        layout = nodeDef.getAttrOrThrow("data_format").getS().toStringUtf8().toLowerCase();
    }
    boolean nhwc = layout.equalsIgnoreCase("nhwc");

    int yIdx = nhwc ? 1 : 2;
    int xIdx = yIdx + 1;
    boolean padded = paddingList.size() > 0;

    int sY = strideList.get(yIdx).intValue();
    int sX = strideList.get(xIdx).intValue();
    int kY = kernelList.get(yIdx).intValue();
    int kX = kernelList.get(xIdx).intValue();
    int ph = padded ? paddingList.get(yIdx).intValue() : 0;
    int pw = padded ? paddingList.get(xIdx).intValue() : 0;

    this.config = Pooling2DConfig.builder()
            .sy(sY)
            .sx(sX)
            .type(Pooling2D.Pooling2DType.MAX)
            .isSameMode(sameMode)
            .kh(kY)
            .kw(kX)
            .ph(ph)
            .pw(pw)
            .virtualWidth(1)
            .virtualHeight(1)
            .isNHWC(nhwc)
            // NOTE(review): extra = 1.0 is passed through as-is; its meaning for MAX pooling is not visible here — confirm.
            .extra(1.0)
            .build();
    addArgs();
    log.debug("Pooling: k: [{},{}]; s: [{}, {}], padding: {}", kY, kX, sY, sX, aPadding);
}
 
Example 13
Source File: DeConv3DTF.java    From deeplearning4j with Apache License 2.0 4 votes vote down vote up
/**
 * Configures this 3D deconvolution op from a TensorFlow node.
 *
 * <p>Strides and (optional) dilations are read from the 5-element TF lists: D/H/W at 2..4 for
 * NCDHW and 1..3 otherwise. Missing dilations default to 1. Kernel sizes are left as -1 here.
 *
 * <p>Fixes over the previous version:
 * <ul>
 *   <li>H/W strides are passed to the config unswapped (they were passed as {@code .sH(sW).sW(sH)}).</li>
 *   <li>A missing {@code padding} attribute now raises a descriptive error instead of an NPE.</li>
 * </ul>
 *
 * @param nodeDef           the TensorFlow node being imported
 * @param initWith          the SameDiff instance being built (not used here)
 * @param attributesForNode attributes of the node (not used here; values are read from {@code nodeDef})
 * @param graph             the enclosing TensorFlow graph (not used here)
 */
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {

    val aStrides = nodeDef.getAttrOrThrow("strides");
    val aDilations = nodeDef.getAttrOrDefault("dilations", null);
    val tfStrides = aStrides.getList().getIList();
    val tfDilation = aDilations == null ? null : aDilations.getList().getIList();
    int sD, sH, sW, dD, dH, dW;

    val aPadding = nodeDef.getAttrOrThrow("padding");
    String paddingMode = aPadding.getS().toStringUtf8();

    String dataFormat = DeConv3DConfig.NDHWC;
    if (nodeDef.containsAttr("data_format")) {
        val attr = nodeDef.getAttrOrThrow("data_format");
        dataFormat = attr.getS().toStringUtf8().toLowerCase();
    }

    if (dataFormat.equalsIgnoreCase(DeConv3DConfig.NCDHW)) {
        sD = tfStrides.get(2).intValue();
        sH = tfStrides.get(3).intValue();
        sW = tfStrides.get(4).intValue();

        dD = tfDilation == null ? 1 : tfDilation.get(2).intValue();
        dH = tfDilation == null ? 1 : tfDilation.get(3).intValue();
        dW = tfDilation == null ? 1 : tfDilation.get(4).intValue();
    } else {
        sD = tfStrides.get(1).intValue();
        sH = tfStrides.get(2).intValue();
        sW = tfStrides.get(3).intValue();

        dD = tfDilation == null ? 1 : tfDilation.get(1).intValue();
        dH = tfDilation == null ? 1 : tfDilation.get(2).intValue();
        dW = tfDilation == null ? 1 : tfDilation.get(3).intValue();
    }

    boolean isSameMode = paddingMode.equalsIgnoreCase("SAME");
    DeConv3DConfig conv3DConfig = DeConv3DConfig.builder()
            .kD(-1)
            .kH(-1)
            .kW(-1)
            .sD(sD)
            .sH(sH)
            .sW(sW)
            .dD(dD)
            .dH(dH)
            .dW(dW)
            .isSameMode(isSameMode)
            .dataFormat(dataFormat.equalsIgnoreCase(DeConv3DConfig.NCDHW) ? DeConv3DConfig.NCDHW : DeConv3DConfig.NDHWC)
            .build();
    this.config = conv3DConfig;

    addArgs();
}
 
Example 14
Source File: StridedSlice.java    From nd4j with Apache License 2.0 4 votes vote down vote up
/**
 * Configures this StridedSlice op from a TensorFlow {@code StridedSlice} node.
 *
 * <p>The five bit masks are appended as integer arguments (begin, ellipsis, end, new-axis,
 * shrink-axis). Next, the nodes producing the begin/end/strides inputs (inputs 1-3) are looked
 * up in {@code graph} by exact name. When all three are found and all three yield an array via
 * their {@code "value"} attribute, their elements are appended as further integer arguments:
 * all begin values, then all end values, then all strides.
 *
 * <p>If any producer node is not found, constant resolution is skipped. The previous code
 * handed a null node to {@code getNDArrayFromTensor} in that case.
 *
 * @param nodeDef           the TensorFlow node being imported
 * @param initWith          the SameDiff instance being built (not used here)
 * @param attributesForNode attributes of the node (not used here; values are read from {@code nodeDef})
 * @param graph             the enclosing TensorFlow graph, searched for the input producer nodes
 */
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
    val inputBegin = nodeDef.getInput(1);
    val inputEnd = nodeDef.getInput(2);
    val inputStrides = nodeDef.getInput(3);

    // Names are compared verbatim; if several nodes match, the last one wins.
    NodeDef beginNode = null;
    NodeDef endNode = null;
    NodeDef stridesNode = null;
    for (NodeDef candidate : graph.getNodeList()) {
        String candidateName = candidate.getName();
        if (candidateName.equals(inputBegin)) {
            beginNode = candidate;
        }
        if (candidateName.equals(inputEnd)) {
            endNode = candidate;
        }
        if (candidateName.equals(inputStrides)) {
            stridesNode = candidate;
        }
    }

    // bit masks for this slice
    val bm = nodeDef.getAttrOrThrow("begin_mask");
    val xm = nodeDef.getAttrOrThrow("ellipsis_mask");
    val em = nodeDef.getAttrOrThrow("end_mask");
    val nm = nodeDef.getAttrOrThrow("new_axis_mask");
    val sm = nodeDef.getAttrOrThrow("shrink_axis_mask");

    addIArgument((int) bm.getI());
    addIArgument((int) xm.getI());
    addIArgument((int) em.getI());
    addIArgument((int) nm.getI());
    addIArgument((int) sm.getI());

    if (beginNode == null || endNode == null || stridesNode == null) {
        // Producer node(s) not in the graph: leave begin/end/strides unresolved.
        return;
    }

    val beginArr = TFGraphMapper.getInstance().getNDArrayFromTensor("value", beginNode, graph);
    val endArr = TFGraphMapper.getInstance().getNDArrayFromTensor("value", endNode, graph);
    val stridesArr = TFGraphMapper.getInstance().getNDArrayFromTensor("value", stridesNode, graph);

    if (beginArr != null && endArr != null && stridesArr != null) {
        for (int e = 0; e < beginArr.length(); e++) {
            addIArgument(beginArr.getInt(e));
        }
        for (int e = 0; e < endArr.length(); e++) {
            addIArgument(endArr.getInt(e));
        }
        for (int e = 0; e < stridesArr.length(); e++) {
            addIArgument(stridesArr.getInt(e));
        }
    }
}
 
Example 15
Source File: AvgPooling2D.java    From nd4j with Apache License 2.0 4 votes vote down vote up
/**
 * Configures this average-pooling op from a TensorFlow 2D avg-pool node.
 *
 * <p>The node's {@code data_format} (NHWC when absent) selects the Y/X entries of the 4-element
 * {@code strides}, {@code ksize} and explicit padding lists: 1/2 for NHWC, 2/3 otherwise. An
 * empty padding list yields zero padding. Virtual width/height are fixed at 1.
 *
 * @param nodeDef           the TensorFlow node being imported
 * @param initWith          the SameDiff instance being built (not used here)
 * @param attributesForNode attributes of the node (not used here; values are read from {@code nodeDef})
 * @param graph             the enclosing TensorFlow graph (not used here)
 */
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
    val tfStrides = nodeDef.getAttrOrThrow("strides").getList().getIList();
    val tfKernels = nodeDef.getAttrOrThrow("ksize").getList().getIList();

    val aPadding = nodeDef.getAttrOrThrow("padding");
    val tfPadding = aPadding.getList().getIList();
    String mode = aPadding.getS().toStringUtf8().replaceAll("\"", "");
    boolean sameMode = mode.equalsIgnoreCase("SAME");

    String layout = nodeDef.containsAttr("data_format")
            ? nodeDef.getAttrOrThrow("data_format").getS().toStringUtf8().toLowerCase()
            : "nhwc";
    boolean nhwc = layout.equalsIgnoreCase("nhwc");

    int first = nhwc ? 1 : 2;
    int second = nhwc ? 2 : 3;

    int sY = tfStrides.get(first).intValue();
    int sX = tfStrides.get(second).intValue();
    int kY = tfKernels.get(first).intValue();
    int kX = tfKernels.get(second).intValue();
    int ph = tfPadding.size() > 0 ? tfPadding.get(first).intValue() : 0;
    int pw = tfPadding.size() > 0 ? tfPadding.get(second).intValue() : 0;

    this.config = Pooling2DConfig.builder()
            .sy(sY)
            .sx(sX)
            .type(Pooling2D.Pooling2DType.AVG)
            .isSameMode(sameMode)
            .kh(kY)
            .kw(kX)
            .ph(ph)
            .pw(pw)
            .virtualWidth(1)
            .virtualHeight(1)
            .isNHWC(nhwc)
            // Per the original author: extra = 0.0 averages only over non-padded values (confirm against the native op).
            .extra(0.0)
            .build();
    addArgs();
    log.debug("Pooling: k: [{},{}]; s: [{}, {}], padding: {}", kY, kX, sY, sX, aPadding);
}
 
Example 16
Source File: Pooling2D.java    From nd4j with Apache License 2.0 4 votes vote down vote up
/**
 * Configures this generic 2D pooling op from a TensorFlow pooling node.
 *
 * <p>Y/X strides and kernel sizes are read from indices 1 and 2 of the node's {@code strides}
 * and {@code ksize} lists. The pooling type is left unset ({@code null}) and virtual
 * width/height are fixed at 1.
 *
 * <p>The {@code padding} attribute is read as a string ({@code getS()}). Because an
 * {@code AttrValue} holds a single value, its list form is then empty. Explicit paddings are
 * therefore taken from that list only when entries exist and default to 0 otherwise. Indexing
 * the list unconditionally threw {@link IndexOutOfBoundsException}.
 *
 * @param nodeDef           the TensorFlow node being imported
 * @param initWith          the SameDiff instance being built (not used here)
 * @param attributesForNode attributes of the node (not used here; values are read from {@code nodeDef})
 * @param graph             the enclosing TensorFlow graph (not used here)
 */
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
    // NOTE(review): indices 1/2 assume NHWC; the node's data_format attribute is not consulted here.
    val aStrides = nodeDef.getAttrOrThrow("strides");
    val tfStrides = aStrides.getList().getIList();
    val sY = tfStrides.get(1);
    val sX = tfStrides.get(2);

    val aKernels = nodeDef.getAttrOrThrow("ksize");
    val tfKernels = aKernels.getList().getIList();
    val kY = tfKernels.get(1);
    val kX = tfKernels.get(2);

    val aPadding = nodeDef.getAttrOrThrow("padding");
    val padding = aPadding.getList().getIList();
    int ph = padding.size() > 0 ? padding.get(0).intValue() : 0;
    int pw = padding.size() > 1 ? padding.get(1).intValue() : 0;

    val paddingMode = aPadding.getS().toStringUtf8().replaceAll("\"", "");
    boolean isSameMode = paddingMode.equalsIgnoreCase("SAME");

    if (!isSameMode) {
        log.debug("Mode: {}", paddingMode);
    }

    Pooling2DConfig pooling2DConfig = Pooling2DConfig.builder()
            .sy(sY.intValue())
            .sx(sX.intValue())
            .type(null)
            .isSameMode(isSameMode)
            .kh(kY.intValue())
            .kw(kX.intValue())
            .ph(ph)
            .pw(pw)
            .virtualWidth(1)
            .virtualHeight(1)
            .build();
    this.config = pooling2DConfig;
    addArgs();
    log.debug("Pooling: k: [{},{}]; s: [{}, {}], padding: {}", kY, kX, sY, sX, aPadding);
}
 
Example 17
Source File: DeConv2D.java    From nd4j with Apache License 2.0 4 votes vote down vote up
/**
 * Configures this 2D deconvolution op from a TensorFlow node.
 *
 * <p>Y/X strides come from the node's {@code strides} attribute (2/3 for NCHW, 1/2 otherwise).
 * Kernel sizes come from the weight array bound to the op's second argument. If that variable
 * has no array yet, one is obtained via {@code TFGraphMapper.getNDArrayFromTensor} and
 * associated with the variable.
 *
 * <p>Fixes over the previous version:
 * <ul>
 *   <li>A missing {@code padding} attribute now raises a descriptive error instead of an NPE.</li>
 *   <li>An unresolvable weight array now raises a descriptive error instead of an NPE.</li>
 *   <li>Kernel dimensions are narrowed with {@link Math#toIntExact}, so values outside the
 *       {@code int} range fail loudly instead of being silently truncated (the former FIXME).</li>
 * </ul>
 *
 * @param nodeDef           the TensorFlow node being imported
 * @param initWith          the SameDiff instance being built; receives the weight array if it was loaded here
 * @param attributesForNode attributes of the node (not used here; values are read from {@code nodeDef})
 * @param graph             the enclosing TensorFlow graph, used when loading the weight array
 * @throws IllegalStateException if no weight array can be resolved for the op
 * @throws ArithmeticException   if a kernel dimension does not fit in an {@code int}
 */
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
    val aStrides = nodeDef.getAttrOrThrow("strides");
    val tfStrides = aStrides.getList().getIList();
    int sY;
    int sX;
    int kY;
    int kX;

    val aPadding = nodeDef.getAttrOrThrow("padding");
    val paddingMode = aPadding.getS().toStringUtf8();

    val args = args();
    INDArray arr = sameDiff.getVariable(args[1].getVarName()).getArr();
    if (arr == null) {
        // TODO: arguable - it might be easier to permute the weights once, e.g. arr.permute(3, 2, 0, 1).dup('c').
        arr = TFGraphMapper.getInstance().getNDArrayFromTensor(nodeDef.getInput(0), nodeDef, graph);
        val varForOp = initWith.getVariable(args[1].getVarName());
        if (arr != null) {
            initWith.associateArrayWithVariable(arr, varForOp);
        }
    }
    if (arr == null) {
        throw new IllegalStateException("DeConv2D import: no weight array could be resolved for node " + nodeDef.getName());
    }

    String dataFormat = "nhwc";
    if (nodeDef.containsAttr("data_format")) {
        val attr = nodeDef.getAttrOrThrow("data_format");
        dataFormat = attr.getS().toStringUtf8().toLowerCase();
    }

    if (dataFormat.equalsIgnoreCase("nchw")) {
        sY = tfStrides.get(2).intValue();
        sX = tfStrides.get(3).intValue();

        kY = Math.toIntExact(arr.size(2));
        kX = Math.toIntExact(arr.size(3));
    } else {
        sY = tfStrides.get(1).intValue();
        sX = tfStrides.get(2).intValue();

        kY = Math.toIntExact(arr.size(0));
        kX = Math.toIntExact(arr.size(1));
    }

    boolean isSameMode = paddingMode.equalsIgnoreCase("SAME");
    DeConv2DConfig conv2DConfig = DeConv2DConfig.builder()
            .kY(kY)
            .kX(kX)
            .sX(sX)
            .sY(sY)
            .isSameMode(isSameMode)
            // The native (C++) check is written in terms of NCHW; NHWC is flagged explicitly.
            .isNHWC(dataFormat.equalsIgnoreCase("nhwc"))
            .build();
    this.config = conv2DConfig;

    addArgs();
}