Java Code Examples for org.tensorflow.framework.NodeDef#containsAttr()

The following examples show how to use org.tensorflow.framework.NodeDef#containsAttr(). On the original page you can vote examples up or down, follow the links above each example to the original project or source file, and find related API usage in the sidebar.
Example 1
Source File: ScatterNdUpdate.java (from deeplearning4j, Apache License 2.0; 5 votes)
/**
 * Initializes this op from a TensorFlow {@code NodeDef}.
 *
 * <p>First maps the node's properties onto this op via
 * {@code TFGraphMapper.initFunctionFromProperties}, then appends exactly one boolean to
 * {@code bArguments}: the value of the optional {@code use_locking} attribute, or {@code false}
 * when the attribute is absent.
 *
 * @param nodeDef           TensorFlow node being imported
 * @param initWith          SameDiff instance the op is imported into
 * @param attributesForNode attribute map of the node
 * @param graph             enclosing TensorFlow graph
 */
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
    TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);

    // Short-circuit: getB() is only read when the attribute exists; absent means "no locking".
    boolean useLocking = nodeDef.containsAttr("use_locking") && nodeDef.getAttrOrThrow("use_locking").getB();
    bArguments.add(useLocking);
}
 
Example 2
Source File: ScatterNdAdd.java (from deeplearning4j, Apache License 2.0; 5 votes)
/**
 * Initializes this op from a TensorFlow {@code NodeDef}.
 *
 * <p>First maps the node's properties onto this op via
 * {@code TFGraphMapper.initFunctionFromProperties}, then appends exactly one boolean to
 * {@code bArguments}: the value of the optional {@code use_locking} attribute, or {@code false}
 * when the attribute is absent.
 *
 * @param nodeDef           TensorFlow node being imported
 * @param initWith          SameDiff instance the op is imported into
 * @param attributesForNode attribute map of the node
 * @param graph             enclosing TensorFlow graph
 */
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
    TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);

    // Short-circuit: getB() is only read when the attribute exists; absent means "no locking".
    boolean useLocking = nodeDef.containsAttr("use_locking") && nodeDef.getAttrOrThrow("use_locking").getB();
    bArguments.add(useLocking);
}
 
Example 3
Source File: ScatterAdd.java (from deeplearning4j, Apache License 2.0; 5 votes)
/**
 * Initializes this op from a TensorFlow {@code NodeDef}.
 *
 * <p>First maps the node's properties onto this op via
 * {@code TFGraphMapper.initFunctionFromProperties}, then appends exactly one boolean to
 * {@code bArguments}: the value of the optional {@code use_locking} attribute, or {@code false}
 * when the attribute is absent.
 *
 * @param nodeDef           TensorFlow node being imported
 * @param initWith          SameDiff instance the op is imported into
 * @param attributesForNode attribute map of the node
 * @param graph             enclosing TensorFlow graph
 */
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
    TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);

    // Short-circuit: getB() is only read when the attribute exists; absent means "no locking".
    boolean useLocking = nodeDef.containsAttr("use_locking") && nodeDef.getAttrOrThrow("use_locking").getB();
    bArguments.add(useLocking);
}
 
Example 4
Source File: ScatterNdSub.java (from deeplearning4j, Apache License 2.0; 5 votes)
/**
 * Initializes this op from a TensorFlow {@code NodeDef}.
 *
 * <p>First maps the node's properties onto this op via
 * {@code TFGraphMapper.initFunctionFromProperties}, then appends exactly one boolean to
 * {@code bArguments}: the value of the optional {@code use_locking} attribute, or {@code false}
 * when the attribute is absent.
 *
 * @param nodeDef           TensorFlow node being imported
 * @param initWith          SameDiff instance the op is imported into
 * @param attributesForNode attribute map of the node
 * @param graph             enclosing TensorFlow graph
 */
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
    TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);

    // Short-circuit: getB() is only read when the attribute exists; absent means "no locking".
    boolean useLocking = nodeDef.containsAttr("use_locking") && nodeDef.getAttrOrThrow("use_locking").getB();
    bArguments.add(useLocking);
}
 
Example 5
Source File: ScatterNd.java (from deeplearning4j, Apache License 2.0; 5 votes)
/**
 * Initializes this op from a TensorFlow {@code NodeDef}.
 *
 * <p>First maps the node's properties onto this op via
 * {@code TFGraphMapper.initFunctionFromProperties}, then appends exactly one boolean to
 * {@code bArguments}: the value of the optional {@code use_locking} attribute, or {@code false}
 * when the attribute is absent.
 *
 * @param nodeDef           TensorFlow node being imported
 * @param initWith          SameDiff instance the op is imported into
 * @param attributesForNode attribute map of the node
 * @param graph             enclosing TensorFlow graph
 */
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
    TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);

    // Short-circuit: getB() is only read when the attribute exists; absent means "no locking".
    boolean useLocking = nodeDef.containsAttr("use_locking") && nodeDef.getAttrOrThrow("use_locking").getB();
    bArguments.add(useLocking);
}
 
Example 6
Source File: ScatterDiv.java (from deeplearning4j, Apache License 2.0; 5 votes)
/**
 * Initializes this op from a TensorFlow {@code NodeDef}.
 *
 * <p>First maps the node's properties onto this op via
 * {@code TFGraphMapper.initFunctionFromProperties}, then appends exactly one boolean to
 * {@code bArguments}: the value of the optional {@code use_locking} attribute, or {@code false}
 * when the attribute is absent.
 *
 * @param nodeDef           TensorFlow node being imported
 * @param initWith          SameDiff instance the op is imported into
 * @param attributesForNode attribute map of the node
 * @param graph             enclosing TensorFlow graph
 */
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
    TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);

    // Short-circuit: getB() is only read when the attribute exists; absent means "no locking".
    boolean useLocking = nodeDef.containsAttr("use_locking") && nodeDef.getAttrOrThrow("use_locking").getB();
    bArguments.add(useLocking);
}
 
Example 7
Source File: ScatterUpdate.java (from deeplearning4j, Apache License 2.0; 5 votes)
/**
 * Initializes this op from a TensorFlow {@code NodeDef}.
 *
 * <p>Appends exactly one boolean to {@code bArguments}: the value of the optional
 * {@code use_locking} attribute, or {@code false} when the attribute is absent.
 *
 * <p>NOTE(review): unlike the other scatter-op importers listed alongside this one, this method
 * does not call {@code TFGraphMapper.initFunctionFromProperties} — confirm that is intentional.
 *
 * @param nodeDef           TensorFlow node being imported
 * @param initWith          SameDiff instance the op is imported into (unused here)
 * @param attributesForNode attribute map of the node (unused here)
 * @param graph             enclosing TensorFlow graph (unused here)
 */
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
    // Short-circuit: getB() is only read when the attribute exists; absent means "no locking".
    boolean useLocking = nodeDef.containsAttr("use_locking") && nodeDef.getAttrOrThrow("use_locking").getB();
    bArguments.add(useLocking);
}
 
Example 8
Source File: ScatterMul.java (from deeplearning4j, Apache License 2.0; 5 votes)
/**
 * Initializes this op from a TensorFlow {@code NodeDef}.
 *
 * <p>First maps the node's properties onto this op via
 * {@code TFGraphMapper.initFunctionFromProperties}, then appends exactly one boolean to
 * {@code bArguments}: the value of the optional {@code use_locking} attribute, or {@code false}
 * when the attribute is absent.
 *
 * @param nodeDef           TensorFlow node being imported
 * @param initWith          SameDiff instance the op is imported into
 * @param attributesForNode attribute map of the node
 * @param graph             enclosing TensorFlow graph
 */
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
    TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);

    // Short-circuit: getB() is only read when the attribute exists; absent means "no locking".
    boolean useLocking = nodeDef.containsAttr("use_locking") && nodeDef.getAttrOrThrow("use_locking").getB();
    bArguments.add(useLocking);
}
 
Example 9
Source File: ScatterMax.java (from deeplearning4j, Apache License 2.0; 5 votes)
/**
 * Initializes this op from a TensorFlow {@code NodeDef}.
 *
 * <p>First maps the node's properties onto this op via
 * {@code TFGraphMapper.initFunctionFromProperties}, then appends exactly one boolean to
 * {@code bArguments}: the value of the optional {@code use_locking} attribute, or {@code false}
 * when the attribute is absent.
 *
 * @param nodeDef           TensorFlow node being imported
 * @param initWith          SameDiff instance the op is imported into
 * @param attributesForNode attribute map of the node
 * @param graph             enclosing TensorFlow graph
 */
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
    TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);

    // Short-circuit: getB() is only read when the attribute exists; absent means "no locking".
    boolean useLocking = nodeDef.containsAttr("use_locking") && nodeDef.getAttrOrThrow("use_locking").getB();
    bArguments.add(useLocking);
}
 
Example 10
Source File: ScatterSub.java (from deeplearning4j, Apache License 2.0; 5 votes)
/**
 * Initializes this op from a TensorFlow {@code NodeDef}.
 *
 * <p>First maps the node's properties onto this op via
 * {@code TFGraphMapper.initFunctionFromProperties}, then appends exactly one boolean to
 * {@code bArguments}: the value of the optional {@code use_locking} attribute, or {@code false}
 * when the attribute is absent.
 *
 * @param nodeDef           TensorFlow node being imported
 * @param initWith          SameDiff instance the op is imported into
 * @param attributesForNode attribute map of the node
 * @param graph             enclosing TensorFlow graph
 */
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
    TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);

    // Short-circuit: getB() is only read when the attribute exists; absent means "no locking".
    boolean useLocking = nodeDef.containsAttr("use_locking") && nodeDef.getAttrOrThrow("use_locking").getB();
    bArguments.add(useLocking);
}
 
Example 11
Source File: ScatterMin.java (from deeplearning4j, Apache License 2.0; 5 votes)
/**
 * Initializes this op from a TensorFlow {@code NodeDef}.
 *
 * <p>First maps the node's properties onto this op via
 * {@code TFGraphMapper.initFunctionFromProperties}, then appends exactly one boolean to
 * {@code bArguments}: the value of the optional {@code use_locking} attribute, or {@code false}
 * when the attribute is absent.
 *
 * @param nodeDef           TensorFlow node being imported
 * @param initWith          SameDiff instance the op is imported into
 * @param attributesForNode attribute map of the node
 * @param graph             enclosing TensorFlow graph
 */
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
    TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);

    // Short-circuit: getB() is only read when the attribute exists; absent means "no locking".
    boolean useLocking = nodeDef.containsAttr("use_locking") && nodeDef.getAttrOrThrow("use_locking").getB();
    bArguments.add(useLocking);
}
 
Example 12
Source File: AvgPooling2D.java (from deeplearning4j, Apache License 2.0; 4 votes)
/**
 * Configures this average-pooling op from a TensorFlow node.
 *
 * <p>Reads the required {@code strides}, {@code ksize} and {@code padding} attributes plus the
 * optional {@code data_format} (default {@code "nhwc"}; any other value is handled as NCHW),
 * builds an AVG {@link Pooling2DConfig}, stores it in {@code config} and calls {@code addArgs()}.
 *
 * @param nodeDef           TensorFlow node being imported
 * @param initWith          SameDiff instance the op is imported into (unused here)
 * @param attributesForNode attribute map of the node (unused here)
 * @param graph             enclosing TensorFlow graph (unused here)
 */
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
    val strideList = nodeDef.getAttrOrThrow("strides").getList().getIList();
    val kernelList = nodeDef.getAttrOrThrow("ksize").getList().getIList();

    val paddingAttr = nodeDef.getAttrOrThrow("padding");
    val explicitPadding = paddingAttr.getList().getIList();
    // Strip any double quotes from the padding mode before comparing it.
    boolean sameMode = paddingAttr.getS().toStringUtf8().replace("\"", "").equalsIgnoreCase("SAME");

    String format = "nhwc";
    if (nodeDef.containsAttr("data_format")) {
        format = nodeDef.getAttrOrThrow("data_format").getS().toStringUtf8().toLowerCase();
    }
    boolean channelsLast = format.equalsIgnoreCase("nhwc");

    // Height/width entries sit at indices 1/2 for NHWC and 2/3 for NCHW.
    int hIdx = channelsLast ? 1 : 2;
    int wIdx = hIdx + 1;
    boolean padded = explicitPadding.size() > 0;

    int strideH = strideList.get(hIdx).intValue();
    int strideW = strideList.get(wIdx).intValue();
    int kernelH = kernelList.get(hIdx).intValue();
    int kernelW = kernelList.get(wIdx).intValue();
    int padH = padded ? explicitPadding.get(hIdx).intValue() : 0;
    int padW = padded ? explicitPadding.get(wIdx).intValue() : 0;

    this.config = Pooling2DConfig.builder()
            .sH(strideH)
            .sW(strideW)
            .type(Pooling2D.Pooling2DType.AVG)
            .isSameMode(sameMode)
            .kH(kernelH)
            .kW(kernelW)
            .pH(padH)
            .pW(padW)
            .isNHWC(channelsLast)
            .extra(0.0) // averaging only for non-padded values
            .build();
    addArgs();
}
 
Example 13
Source File: MaxPooling2D.java (from nd4j, Apache License 2.0; 4 votes)
/**
 * Configures this max-pooling op from a TensorFlow node.
 *
 * <p>Reads the required {@code strides}, {@code ksize} and {@code padding} attributes plus the
 * optional {@code data_format} (default {@code "nhwc"}; any other value is handled as NCHW),
 * builds a MAX {@link Pooling2DConfig} with virtual width/height of 1, stores it in
 * {@code config}, calls {@code addArgs()} and logs kernel, stride and padding at debug level.
 *
 * @param nodeDef           TensorFlow node being imported
 * @param initWith          SameDiff instance the op is imported into (unused here)
 * @param attributesForNode attribute map of the node (unused here)
 * @param graph             enclosing TensorFlow graph (unused here)
 */
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
    val strideList = nodeDef.getAttrOrThrow("strides").getList().getIList();
    val kernelList = nodeDef.getAttrOrThrow("ksize").getList().getIList();

    val aPadding = nodeDef.getAttrOrThrow("padding");
    val explicitPadding = aPadding.getList().getIList();
    // Compare the padding mode with any double quotes removed.
    boolean sameMode = "SAME".equalsIgnoreCase(aPadding.getS().toStringUtf8().replace("\"", ""));

    String format = nodeDef.containsAttr("data_format")
            ? nodeDef.getAttrOrThrow("data_format").getS().toStringUtf8().toLowerCase()
            : "nhwc";
    boolean nhwc = format.equalsIgnoreCase("nhwc");

    // Row/column entries: indices 1/2 for NHWC, 2/3 for NCHW.
    int rowIdx = nhwc ? 1 : 2;
    int colIdx = rowIdx + 1;
    boolean hasPadValues = explicitPadding.size() > 0;

    int sY = strideList.get(rowIdx).intValue();
    int sX = strideList.get(colIdx).intValue();
    int kY = kernelList.get(rowIdx).intValue();
    int kX = kernelList.get(colIdx).intValue();
    int ph = hasPadValues ? explicitPadding.get(rowIdx).intValue() : 0;
    int pw = hasPadValues ? explicitPadding.get(colIdx).intValue() : 0;

    this.config = Pooling2DConfig.builder()
            .sy(sY)
            .sx(sX)
            .type(Pooling2D.Pooling2DType.MAX)
            .isSameMode(sameMode)
            .kh(kY)
            .kw(kX)
            .ph(ph)
            .pw(pw)
            .virtualWidth(1)
            .virtualHeight(1)
            .isNHWC(nhwc)
            // NOTE(review): an earlier comment here read "averaging only for non-padded values",
            // which sounds like an AVG-pooling note — confirm what extra means for MAX pooling.
            .extra(1.0)
            .build();
    addArgs();
    log.debug("Pooling: k: [{},{}]; s: [{}, {}], padding: {}", kY, kX, sY, sX, aPadding);
}
 
Example 14
Source File: Pooling3D.java (from deeplearning4j, Apache License 2.0; 4 votes)
/**
 * Configures this 3D pooling op from a TensorFlow node.
 *
 * <p>Reads the required {@code strides}, {@code ksize} and {@code padding} attributes and the
 * optional {@code data_format} (default {@code "ndhwc"}). The pooling type is derived from the
 * op name: names starting with {@code "max"} give MAX, names starting with {@code "av"} give AVG.
 * The resulting {@link Pooling3DConfig} is stored in {@code config} and {@code addArgs()} is
 * called.
 *
 * <p>The spatial (depth, height, width) entries of the 5-element {@code ksize}/{@code strides}
 * lists are read at indices 1..3 for NDHWC and 2..4 for NCDHW, matching the input layout.
 *
 * @param nodeDef           TensorFlow node being imported
 * @param initWith          SameDiff instance the op is imported into (unused here)
 * @param attributesForNode attribute map of the node (unused here)
 * @param graph             enclosing TensorFlow graph (unused here)
 * @throws IllegalStateException if the op name matches neither max nor average pooling
 */
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
    val aStrides = nodeDef.getAttrOrThrow("strides");
    List<Long> tfStrides = aStrides.getList().getIList();
    val aKernels = nodeDef.getAttrOrThrow("ksize");
    List<Long> tfKernels = aKernels.getList().getIList();
    val aPadding = nodeDef.getAttrOrThrow("padding");
    List<Long> tfPadding = aPadding.getList().getIList();

    String paddingMode = aPadding.getS().toStringUtf8().replaceAll("\"", "");

    boolean isSameMode = paddingMode.equalsIgnoreCase("SAME");

    String data_format = "ndhwc";
    if (nodeDef.containsAttr("data_format")) {
        val attr = nodeDef.getAttrOrThrow("data_format");

        data_format = attr.getS().toStringUtf8().toLowerCase();
    }
    boolean isNCDHW = data_format.equalsIgnoreCase("ncdhw");

    // TF orders ksize/strides like the input tensor: [N, D, H, W, C] for NDHWC and
    // [N, C, D, H, W] for NCDHW, so the spatial entries start at index 1 or 2 respectively.
    int spatialOffset = isNCDHW ? 2 : 1;

    //Order: depth, height, width
    //TF doesn't have dilation, it seems?
    int[] strides = new int[3];
    int[] padding = new int[3];
    int[] kernel = new int[3];
    // Explicit padding values are empty for SAME mode (padding is then left as zeros).
    boolean hasExplicitPadding = tfPadding != null && !tfPadding.isEmpty();
    for (int i = 0; i < 3; i++) {
        strides[i] = tfStrides.get(i + spatialOffset).intValue();
        if (hasExplicitPadding) {
            padding[i] = tfPadding.get(i + spatialOffset).intValue();
        }
        kernel[i] = tfKernels.get(i + spatialOffset).intValue();
    }

    Pooling3DType type;
    String name = nodeDef.getOp().toLowerCase();
    if (name.startsWith("max")) {
        type = Pooling3DType.MAX;
    } else if (name.startsWith("av")) {
        type = Pooling3DType.AVG;
    } else {
        throw new IllegalStateException("Unknown or not supported pooling type: " + name);
    }

    Pooling3DConfig conf = Pooling3DConfig.builder()
            .sD(strides[0]).sH(strides[1]).sW(strides[2])
            .pD(padding[0]).pH(padding[1]).pW(padding[2])
            .kD(kernel[0]).kH(kernel[1]).kW(kernel[2])
            .type(type)
            .isSameMode(isSameMode)
            .isNCDHW(isNCDHW)
            .build();
    this.config = conf;
    addArgs();
}
 
Example 15
Source File: MaxPoolWithArgmax.java (from deeplearning4j, Apache License 2.0; 4 votes)
/**
 * Configures this op from a TensorFlow {@code MaxPoolWithArgmax} node.
 *
 * <p>Builds a MAX {@link Pooling2DConfig} from the required {@code strides}, {@code ksize} and
 * {@code padding} attributes and the optional {@code data_format} (default {@code "nhwc"}; any
 * other value is handled as NCHW), stores it in {@code config}, calls {@code addArgs()}, and then
 * sets {@code outputType} for the argmax (indices) output.
 *
 * <p>TensorFlow declares the index dtype as the attribute {@code Targmax}. It is looked up first,
 * with the key {@code "argmax"} kept as a fallback. If neither is present the output type is
 * {@link DataType#LONG}.
 *
 * @param nodeDef           TensorFlow node being imported
 * @param initWith          SameDiff instance the op is imported into (unused here)
 * @param attributesForNode attribute map of the node; consulted for the index dtype
 * @param graph             enclosing TensorFlow graph (unused here)
 */
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
    val aStrides = nodeDef.getAttrOrThrow("strides");
    val tfStrides = aStrides.getList().getIList();

    val aKernels = nodeDef.getAttrOrThrow("ksize");
    val tfKernels = aKernels.getList().getIList();

    val aPadding = nodeDef.getAttrOrThrow("padding");
    val padding = aPadding.getList().getIList();
    // Strip any double quotes from the padding mode before comparing it.
    boolean isSameMode = aPadding.getS().toStringUtf8().replace("\"", "").equalsIgnoreCase("SAME");

    String data_format = "nhwc";
    if (nodeDef.containsAttr("data_format")) {
        data_format = nodeDef.getAttrOrThrow("data_format").getS().toStringUtf8().toLowerCase();
    }
    boolean isNhwc = data_format.equalsIgnoreCase("nhwc");

    // Height/width entries sit at indices 1/2 for NHWC and 2/3 for NCHW.
    int hIdx = isNhwc ? 1 : 2;
    int wIdx = hIdx + 1;
    boolean hasExplicitPadding = padding.size() > 0;

    int sH = tfStrides.get(hIdx).intValue();
    int sW = tfStrides.get(wIdx).intValue();
    int kH = tfKernels.get(hIdx).intValue();
    int kW = tfKernels.get(wIdx).intValue();
    int pH = hasExplicitPadding ? padding.get(hIdx).intValue() : 0;
    int pW = hasExplicitPadding ? padding.get(wIdx).intValue() : 0;

    this.config = Pooling2DConfig.builder()
            .sH(sH)
            .sW(sW)
            .type(Pooling2D.Pooling2DType.MAX)
            .isSameMode(isSameMode)
            .kH(kH)
            .kW(kW)
            .pH(pH)
            .pW(pW)
            .isNHWC(isNhwc)
            .extra(1.0)
            .build();
    addArgs();

    // TF's attribute for the indices dtype is "Targmax"; "argmax" is checked as a legacy key.
    AttrValue argmaxType = attributesForNode.get("Targmax");
    if (argmaxType == null) {
        argmaxType = attributesForNode.get("argmax");
    }
    outputType = argmaxType != null ? TFGraphMapper.convertType(argmaxType.getType()) : DataType.LONG;
}
 
Example 16
Source File: DeConv2D.java (from deeplearning4j, Apache License 2.0; 4 votes)
/**
 * Configures this 2D deconvolution op from a TensorFlow node.
 *
 * <p>Reads the {@code strides} and {@code padding} attributes and the optional
 * {@code data_format} (default {@code "nhwc"}). Kernel height/width come from the weights array
 * of variable {@code args()[1]}. If that variable has no array yet, one is obtained via
 * {@code TFGraphMapper.getNDArrayFromTensor(nodeDef)} and associated with the variable in
 * {@code initWith}. The resulting {@link DeConv2DConfig} is stored in {@code config} and
 * {@code addArgs()} is called.
 *
 * @param nodeDef           TensorFlow node being imported
 * @param initWith          SameDiff instance the op is imported into
 * @param attributesForNode attribute map of the node (unused here)
 * @param graph             enclosing TensorFlow graph (unused here)
 * @throws IllegalStateException if the node has no {@code padding} attribute, or no weights
 *                               array is available for {@code args()[1]}
 */
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
    val aStrides = nodeDef.getAttrOrThrow("strides");
    val tfStrides = aStrides.getList().getIList();
    long sH;
    long sW;
    long kH;
    long kW;

    val aPadding = nodeDef.getAttrOrDefault("padding", null);
    if (aPadding == null) {
        throw new IllegalStateException("TensorFlow node '" + nodeDef.getName() + "' has no 'padding' attribute");
    }
    val paddingMode = aPadding.getS().toStringUtf8();

    val args = args();
    INDArray arr = sameDiff.getVariable(args[1].name()).getArr();
    if (arr == null) {
        arr = TFGraphMapper.getNDArrayFromTensor(nodeDef);
        // TODO: arguable. it might be easier to permute weights once
        //arr = (arr.permute(3, 2, 0, 1).dup('c'));
        val varForOp = initWith.getVariable(args[1].name());
        if (arr != null) {
            initWith.associateArrayWithVariable(arr, varForOp);
        }
    }
    if (arr == null) {
        // Kernel sizes are read from the weights below; fail clearly instead of with an NPE.
        throw new IllegalStateException("No weights array available for variable '" + args[1].name()
                + "' of TensorFlow node '" + nodeDef.getName() + "'");
    }

    String dataFormat = "nhwc";
    if (nodeDef.containsAttr("data_format")) {
        val attr = nodeDef.getAttrOrThrow("data_format");
        dataFormat = attr.getS().toStringUtf8().toLowerCase();
    }

    if (dataFormat.equalsIgnoreCase(DeConv2DConfig.NCHW)) {
        sH = tfStrides.get(2).longValue();
        sW = tfStrides.get(3).longValue();

        // NOTE(review): this reads kernel dims 2/3 for NCHW; confirm the weights layout actually
        // differs by data_format (the TODO above suggests weights are not permuted).
        kH = arr.size(2);
        kW = arr.size(3);
    } else {
        sH = tfStrides.get(1).longValue();
        sW = tfStrides.get(2).longValue();

        kH = arr.size(0);
        kW = arr.size(1);
    }

    boolean isSameMode = paddingMode.equalsIgnoreCase("SAME");
    DeConv2DConfig conv2DConfig = DeConv2DConfig.builder()
            .kH(kH)
            .kW(kW)
            .sH(sH)
            .sW(sW)
            .isSameMode(isSameMode)
            .dataFormat(dataFormat.equalsIgnoreCase(DeConv2DConfig.NHWC) ? DeConv2DConfig.NHWC : DeConv2DConfig.NCHW)
            .build();
    this.config = conv2DConfig;

    addArgs();
}
 
Example 17
Source File: MaxPooling2D.java (from deeplearning4j, Apache License 2.0; 4 votes)
/**
 * Configures this max-pooling op from a TensorFlow node.
 *
 * <p>Reads the required {@code strides}, {@code ksize} and {@code padding} attributes plus the
 * optional {@code data_format} (default {@code "nhwc"}; any other value is handled as NCHW),
 * builds a MAX {@link Pooling2DConfig}, stores it in {@code config} and calls {@code addArgs()}.
 *
 * @param nodeDef           TensorFlow node being imported
 * @param initWith          SameDiff instance the op is imported into (unused here)
 * @param attributesForNode attribute map of the node (unused here)
 * @param graph             enclosing TensorFlow graph (unused here)
 */
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
    val strideValues = nodeDef.getAttrOrThrow("strides").getList().getIList();
    val kernelValues = nodeDef.getAttrOrThrow("ksize").getList().getIList();

    val paddingAttr = nodeDef.getAttrOrThrow("padding");
    val padValues = paddingAttr.getList().getIList();
    String mode = paddingAttr.getS().toStringUtf8().replace("\"", "");
    boolean same = mode.equalsIgnoreCase("SAME");

    String layout = nodeDef.containsAttr("data_format")
            ? nodeDef.getAttrOrThrow("data_format").getS().toStringUtf8().toLowerCase()
            : "nhwc";
    boolean isChannelsLast = layout.equalsIgnoreCase("nhwc");

    // Index of the height entry in the 4-element lists; width follows directly after it.
    int first = isChannelsLast ? 1 : 2;
    int second = first + 1;
    boolean anyPad = padValues.size() > 0;

    this.config = Pooling2DConfig.builder()
            .sH(strideValues.get(first).intValue())
            .sW(strideValues.get(second).intValue())
            .type(Pooling2D.Pooling2DType.MAX)
            .isSameMode(same)
            .kH(kernelValues.get(first).intValue())
            .kW(kernelValues.get(second).intValue())
            .pH(anyPad ? padValues.get(first).intValue() : 0)
            .pW(anyPad ? padValues.get(second).intValue() : 0)
            .isNHWC(isChannelsLast)
            // NOTE(review): an earlier comment here read "averaging only for non-padded values",
            // which sounds like an AVG-pooling note — confirm what extra means for MAX pooling.
            .extra(1.0)
            .build();
    addArgs();
}
 
Example 18
Source File: DeConv3DTF.java (from deeplearning4j, Apache License 2.0; 4 votes)
/**
 * Configures this 3D deconvolution op from a TensorFlow node.
 *
 * <p>Reads the {@code strides} and {@code padding} attributes, plus the optional
 * {@code dilations} (all dilations default to 1 when absent) and {@code data_format} (default
 * {@code DeConv3DConfig.NDHWC}). Spatial entries are read at indices 2..4 for NCDHW and 1..3
 * otherwise. Kernel sizes are set to -1 in the config. The resulting {@link DeConv3DConfig} is
 * stored in {@code config} and {@code addArgs()} is called.
 *
 * @param nodeDef           TensorFlow node being imported
 * @param initWith          SameDiff instance the op is imported into (unused here)
 * @param attributesForNode attribute map of the node (unused here)
 * @param graph             enclosing TensorFlow graph (unused here)
 * @throws IllegalStateException if the node has no {@code padding} attribute
 */
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {

    val aStrides = nodeDef.getAttrOrThrow("strides");
    val aDilations = nodeDef.getAttrOrDefault("dilations", null);
    val tfStrides = aStrides.getList().getIList();
    val tfDilation = aDilations == null ? null : aDilations.getList().getIList();
    int sD, sH, sW, dD, dH, dW;

    val aPadding = nodeDef.getAttrOrDefault("padding", null);
    if (aPadding == null) {
        throw new IllegalStateException("TensorFlow node '" + nodeDef.getName() + "' has no 'padding' attribute");
    }
    String paddingMode = aPadding.getS().toStringUtf8();

    String dataFormat = DeConv3DConfig.NDHWC;
    if (nodeDef.containsAttr("data_format")) {
        val attr = nodeDef.getAttrOrThrow("data_format");
        dataFormat = attr.getS().toStringUtf8().toLowerCase();
    }

    if (dataFormat.equalsIgnoreCase(DeConv3DConfig.NCDHW)) {
        sD = tfStrides.get(2).intValue();
        sH = tfStrides.get(3).intValue();
        sW = tfStrides.get(4).intValue();

        dD = tfDilation == null ? 1 : tfDilation.get(2).intValue();
        dH = tfDilation == null ? 1 : tfDilation.get(3).intValue();
        dW = tfDilation == null ? 1 : tfDilation.get(4).intValue();
    } else {
        sD = tfStrides.get(1).intValue();
        sH = tfStrides.get(2).intValue();
        sW = tfStrides.get(3).intValue();

        dD = tfDilation == null ? 1 : tfDilation.get(1).intValue();
        dH = tfDilation == null ? 1 : tfDilation.get(2).intValue();
        dW = tfDilation == null ? 1 : tfDilation.get(3).intValue();
    }

    boolean isSameMode = paddingMode.equalsIgnoreCase("SAME");
    DeConv3DConfig conv3DConfig = DeConv3DConfig.builder()
            .kD(-1)
            .kH(-1)
            .kW(-1)
            .sD(sD)
            .sH(sH)
            .sW(sW)
            .dD(dD)
            .dH(dH)
            .dW(dW)
            .isSameMode(isSameMode)
            .dataFormat(dataFormat.equalsIgnoreCase(DeConv3DConfig.NCDHW) ? DeConv3DConfig.NCDHW : DeConv3DConfig.NDHWC)
            .build();
    this.config = conv3DConfig;

    addArgs();
}
 
Example 19
Source File: AvgPooling2D.java (from nd4j, Apache License 2.0; 4 votes)
/**
 * Configures this average-pooling op from a TensorFlow node.
 *
 * <p>Reads the required {@code strides}, {@code ksize} and {@code padding} attributes plus the
 * optional {@code data_format} (default {@code "nhwc"}; any other value is handled as NCHW),
 * builds an AVG {@link Pooling2DConfig} with virtual width/height of 1, stores it in
 * {@code config}, calls {@code addArgs()} and logs kernel, stride and padding at debug level.
 *
 * @param nodeDef           TensorFlow node being imported
 * @param initWith          SameDiff instance the op is imported into (unused here)
 * @param attributesForNode attribute map of the node (unused here)
 * @param graph             enclosing TensorFlow graph (unused here)
 */
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
    val strideList = nodeDef.getAttrOrThrow("strides").getList().getIList();
    val kernelList = nodeDef.getAttrOrThrow("ksize").getList().getIList();

    val aPadding = nodeDef.getAttrOrThrow("padding");
    val explicitPadding = aPadding.getList().getIList();
    // Compare the padding mode with any double quotes removed.
    boolean sameMode = "SAME".equalsIgnoreCase(aPadding.getS().toStringUtf8().replace("\"", ""));

    String format = nodeDef.containsAttr("data_format")
            ? nodeDef.getAttrOrThrow("data_format").getS().toStringUtf8().toLowerCase()
            : "nhwc";
    boolean nhwc = format.equalsIgnoreCase("nhwc");

    // Row/column entries: indices 1/2 for NHWC, 2/3 for NCHW.
    int rowIdx = nhwc ? 1 : 2;
    int colIdx = rowIdx + 1;
    boolean hasPadValues = explicitPadding.size() > 0;

    int sY = strideList.get(rowIdx).intValue();
    int sX = strideList.get(colIdx).intValue();
    int kY = kernelList.get(rowIdx).intValue();
    int kX = kernelList.get(colIdx).intValue();
    int ph = hasPadValues ? explicitPadding.get(rowIdx).intValue() : 0;
    int pw = hasPadValues ? explicitPadding.get(colIdx).intValue() : 0;

    this.config = Pooling2DConfig.builder()
            .sy(sY)
            .sx(sX)
            .type(Pooling2D.Pooling2DType.AVG)
            .isSameMode(sameMode)
            .kh(kY)
            .kw(kX)
            .ph(ph)
            .pw(pw)
            .virtualWidth(1)
            .virtualHeight(1)
            .isNHWC(nhwc)
            .extra(0.0) // averaging only for non-padded values
            .build();
    addArgs();
    log.debug("Pooling: k: [{},{}]; s: [{}, {}], padding: {}", kY, kX, sY, sX, aPadding);
}
 
Example 20
Source File: DeConv2D.java (from nd4j, Apache License 2.0; 4 votes)
/**
 * Configures this 2D deconvolution op from a TensorFlow node.
 *
 * <p>Reads the {@code strides} and {@code padding} attributes and the optional
 * {@code data_format} (default {@code "nhwc"}). Kernel height/width come from the weights array
 * of variable {@code args()[1]}. If that variable has no array yet, one is loaded via
 * {@code TFGraphMapper.getInstance().getNDArrayFromTensor(...)} from the node's first input and
 * associated with the variable in {@code initWith}. The resulting {@link DeConv2DConfig} is
 * stored in {@code config} and {@code addArgs()} is called.
 *
 * @param nodeDef           TensorFlow node being imported
 * @param initWith          SameDiff instance the op is imported into
 * @param attributesForNode attribute map of the node (unused here)
 * @param graph             enclosing TensorFlow graph; used to load the weights if needed
 * @throws IllegalStateException if the node has no {@code padding} attribute, or no weights
 *                               array is available for {@code args()[1]}
 * @throws ArithmeticException   if a kernel dimension does not fit in an {@code int}
 */
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
    val aStrides = nodeDef.getAttrOrThrow("strides");
    val tfStrides = aStrides.getList().getIList();
    int sY;
    int sX;
    int kY;
    int kX;

    val aPadding = nodeDef.getAttrOrDefault("padding", null);
    if (aPadding == null) {
        throw new IllegalStateException("TensorFlow node '" + nodeDef.getName() + "' has no 'padding' attribute");
    }
    val paddingMode = aPadding.getS().toStringUtf8();

    val args = args();
    INDArray arr = sameDiff.getVariable(args[1].getVarName()).getArr();
    if (arr == null) {
        arr = TFGraphMapper.getInstance().getNDArrayFromTensor(nodeDef.getInput(0), nodeDef, graph);
        // TODO: arguable. it might be easier to permute weights once
        //arr = (arr.permute(3, 2, 0, 1).dup('c'));
        val varForOp = initWith.getVariable(args[1].getVarName());
        if (arr != null) {
            initWith.associateArrayWithVariable(arr, varForOp);
        }
    }
    if (arr == null) {
        // Kernel sizes are read from the weights below; fail clearly instead of with an NPE.
        throw new IllegalStateException("No weights array available for variable '" + args[1].getVarName()
                + "' of TensorFlow node '" + nodeDef.getName() + "'");
    }

    String dataFormat = "nhwc";
    if (nodeDef.containsAttr("data_format")) {
        val attr = nodeDef.getAttrOrThrow("data_format");
        dataFormat = attr.getS().toStringUtf8().toLowerCase();
    }

    // Weight dimensions are converted with Math.toIntExact so an oversized dimension raises
    // ArithmeticException instead of being silently truncated by a plain (int) cast.
    if (dataFormat.equalsIgnoreCase("nchw")) {
        sY = tfStrides.get(2).intValue();
        sX = tfStrides.get(3).intValue();

        kY = Math.toIntExact(arr.size(2));
        kX = Math.toIntExact(arr.size(3));
    } else {
        sY = tfStrides.get(1).intValue();
        sX = tfStrides.get(2).intValue();

        kY = Math.toIntExact(arr.size(0));
        kX = Math.toIntExact(arr.size(1));
    }

    boolean isSameMode = paddingMode.equalsIgnoreCase("SAME");
    DeConv2DConfig conv2DConfig = DeConv2DConfig.builder()
            .kY(kY)
            .kX(kX)
            .sX(sX)
            .sY(sY)
            .isSameMode(isSameMode)
            //c++ check checks for nchw
            .isNHWC(dataFormat.equalsIgnoreCase("nhwc"))
            .build();
    this.config = conv2DConfig;

    addArgs();
}