Python tensorflow.python.ops.gen_nn_ops.conv3d() Examples
The following is 1 code example of tensorflow.python.ops.gen_nn_ops.conv3d().
You can vote up the ones you like or vote down the ones you don't like,
and go to the original project or source file by following the links above each example.
You may also want to check out all available functions/classes of the module tensorflow.python.ops.gen_nn_ops, or try the search function.
Example #1
Source File: nn_ops.py From Serverless-Deep-Learning-with-TensorFlow-and-AWS-Lambda with MIT License | 4 votes |
def __init__(self,
             input_shape,
             filter_shape,  # pylint: disable=redefined-builtin
             padding,
             data_format=None,
             strides=None,
             name=None):
    """Validate the arguments and pick the rank-specific convolution op.

    The rank of the convolution is read from `input_shape.ndims`; the number
    of convolution dimensions is that rank minus two. Based on it, this
    stores `self.strides`, `self.data_format` and `self.conv_op`.

    Args:
      input_shape: Shape object exposing `ndims` and `with_rank`.
      filter_shape: Shape object exposing `ndims` and `with_rank`.
      padding: Stored unchanged as `self.padding`.
      data_format: Optional layout string. Accepted values are "NWC"/"NCW"
        for 1 convolution dimension, "NHWC"/"NCHW" for 2 and "NDHWC"/"NCDHW"
        for 3. `None` is handled like the first name of each pair.
      strides: Optional sequence with one entry per convolution dimension.
        Defaults to all ones.
      name: Stored unchanged as `self.name`.

    Raises:
      ValueError: If the rank is unknown or outside [3, 5], if `strides` has
        the wrong length, or if `data_format` is not one of the names
        accepted for the detected number of convolution dimensions.
    """
    # Each shape is passed through `with_rank` using the other's `ndims`.
    filter_shape = filter_shape.with_rank(input_shape.ndims)
    self.padding = padding
    self.name = name
    input_shape = input_shape.with_rank(filter_shape.ndims)

    rank = input_shape.ndims
    if rank is None:
        raise ValueError("Rank of convolution must be known")
    if not 3 <= rank <= 5:
        raise ValueError(
            "`input` and `filter` must have rank at least 3 and at most 5")

    num_conv_dims = rank - 2
    if strides is None:
        strides = [1] * num_conv_dims
    elif len(strides) != num_conv_dims:
        raise ValueError("len(strides)=%d, but should be %d" %
                         (len(strides), num_conv_dims))

    if num_conv_dims == 1:
        # conv1d uses the 2-d data format names
        if data_format is None or data_format == "NWC":
            two_d_format = "NHWC"
        elif data_format == "NCW":
            two_d_format = "NCHW"
        else:
            raise ValueError("data_format must be \"NWC\" or \"NCW\".")
        # The 1-D path keeps a single scalar stride rather than a list.
        self.strides = strides[0]
        self.data_format = two_d_format
        self.conv_op = self._conv1d
        return

    if num_conv_dims == 2:
        if data_format is None or data_format == "NHWC":
            data_format = "NHWC"
            # Channels-last: pad with a unit stride on both ends.
            padded_strides = [1] + list(strides) + [1]
        elif data_format == "NCHW":
            # Channels-first: the two unit strides go in front.
            padded_strides = [1, 1] + list(strides)
        else:
            raise ValueError("data_format must be \"NHWC\" or \"NCHW\".")
        self.strides = padded_strides
        self.data_format = data_format
        self.conv_op = gen_nn_ops.conv2d
        return

    if num_conv_dims == 3:
        if data_format is None or data_format == "NDHWC":
            # Unlike the 2-D case, a `None` data_format is stored as-is.
            padded_strides = [1] + list(strides) + [1]
        elif data_format == "NCDHW":
            padded_strides = [1, 1] + list(strides)
        else:
            raise ValueError(
                "data_format must be \"NDHWC\" or \"NCDHW\". Have: %s"
                % data_format)
        self.strides = padded_strides
        self.data_format = data_format
        self.conv_op = gen_nn_ops.conv3d


# Note that we need this adapter since argument names for conv1d don't match
# those for gen_nn_ops.conv2d and gen_nn_ops.conv3d.
# pylint: disable=redefined-builtin