# Copyright 2015 Google Inc. All Rights Reserved. # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # http://www.apache.org/licenses/LICENSE-2.0 # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """A quick hack to try deconv out.""" import collections import tensorflow as tf from tensorflow.python.framework import tensor_shape from prettytensor import layers from prettytensor import pretty_tensor_class as prettytensor from prettytensor.pretty_tensor_class import PAD_SAME from prettytensor.pretty_tensor_class import Phase from prettytensor.pretty_tensor_class import PROVIDED class _deconv2d(prettytensor.VarStoreMethod): def __call__(self, input_layer, kernel, depth, name, stride, activation_fn, l2loss, init, stddev, bias, edges, batch_normalize): """Adds a convolution to the stack of operations. The current head must be a rank 4 Tensor. Args: input_layer: The chainable object, supplied. kernel: The size of the patch for the pool, either an int or a length 1 or 2 sequence (if length 1 or int, it is expanded). depth: The depth of the new Tensor. name: The name for this operation is also used to create/find the parameter variables. stride: The strides as a length 1, 2 or 4 sequence or an integer. If an int, length 1 or 2, the stride in the first and last dimensions are 1. activation_fn: A tuple of (activation_function, extra_parameters). Any function that takes a tensor as its first argument can be used. More common functions will have summaries added (e.g. relu). l2loss: Set to a value greater than 0 to use L2 regularization to decay the weights. init: An optional initialization. If not specified, uses Xavier initialization. stddev: A standard deviation to use in parameter initialization. bias: Set to False to not have a bias. edges: Either SAME to use 0s for the out of bounds area or VALID to shrink the output size and only uses valid input pixels. batch_normalize: Set to True to batch_normalize this layer. Returns: Handle to the generated layer. Raises: ValueError: If head is not a rank 4 tensor or the depth of the input (4th dim) is not known. """ if len(input_layer.shape) != 4: raise ValueError( 'Cannot perform conv2d on tensor with shape %s' % input_layer.shape) if input_layer.shape[3] is None: raise ValueError('Input depth must be known') kernel = _kernel(kernel) stride = _stride(stride) size = [kernel[0], kernel[1], depth, input_layer.shape[3]] books = input_layer.bookkeeper if init is None: if stddev is None: patch_size = size[0] * size[1] init = layers.xavier_init(size[2] * patch_size, size[3] * patch_size) elif stddev: init = tf.truncated_normal_initializer(stddev=stddev) else: init = tf.zeros_initializer() elif stddev is not None: raise ValueError('Do not set both init and stddev.') dtype = input_layer.tensor.dtype params = self.variable('weights', size, init, dt=dtype) input_height = input_layer.shape[1] input_width = input_layer.shape[2] filter_height = kernel[0] filter_width = kernel[1] row_stride = stride[1] col_stride = stride[2] out_rows, out_cols = get2d_deconv_output_size(input_height, input_width, filter_height, filter_width, row_stride, col_stride, edges) output_shape = [input_layer.shape[0], out_rows, out_cols, depth] y = tf.nn.conv2d_transpose(input_layer, params, output_shape, stride, edges) layers.add_l2loss(books, params, l2loss) if bias: y += self.variable( 'bias', [size[-2]], tf.zeros_initializer(), dt=dtype) books.add_scalar_summary( tf.reduce_mean( layers.spatial_slice_zeros(y)), '%s/zeros_spatial' % y.op.name) if batch_normalize: y = input_layer.with_tensor(y).batch_normalize() if activation_fn is not None: if not isinstance(activation_fn, collections.Sequence): activation_fn = (activation_fn,) y = layers.apply_activation( books, y, activation_fn[0], activation_args=activation_fn[1:]) return input_layer.with_tensor(y) # pylint: enable=redefined-outer-name,invalid-name # Helper methods def get2d_deconv_output_size(input_height, input_width, filter_height, filter_width, row_stride, col_stride, padding_type): """Returns the number of rows and columns in a convolution/pooling output.""" input_height = tensor_shape.as_dimension(input_height) input_width = tensor_shape.as_dimension(input_width) filter_height = tensor_shape.as_dimension(filter_height) filter_width = tensor_shape.as_dimension(filter_width) row_stride = int(row_stride) col_stride = int(col_stride) # Compute number of rows in the output, based on the padding. if input_height.value is None or filter_height.value is None: out_rows = None elif padding_type == "VALID": out_rows = (input_height.value - 1) * row_stride + filter_height.value elif padding_type == "SAME": out_rows = input_height.value * row_stride else: raise ValueError("Invalid value for padding: %r" % padding_type) # Compute number of columns in the output, based on the padding. if input_width.value is None or filter_width.value is None: out_cols = None elif padding_type == "VALID": out_cols = (input_width.value - 1) * col_stride + filter_width.value elif padding_type == "SAME": out_cols = input_width.value * col_stride return out_rows, out_cols def _kernel(kernel_spec): """Expands the kernel spec into a length 2 list. Args: kernel_spec: An integer or a length 1 or 2 sequence that is expanded to a list. Returns: A length 2 list. """ if isinstance(kernel_spec, int): return [kernel_spec, kernel_spec] elif len(kernel_spec) == 1: return [kernel_spec[0], kernel_spec[0]] else: assert len(kernel_spec) == 2 return kernel_spec def _stride(stride_spec): """Expands the stride spec into a length 4 list. Args: stride_spec: None, an integer or a length 1, 2, or 4 sequence. Returns: A length 4 list. """ if stride_spec is None: return [1, 1, 1, 1] elif isinstance(stride_spec, int): return [1, stride_spec, stride_spec, 1] elif len(stride_spec) == 1: return [1, stride_spec[0], stride_spec[0], 1] elif len(stride_spec) == 2: return [1, stride_spec[0], stride_spec[1], 1] else: assert len(stride_spec) == 4 return stride_spec # pylint: disable=redefined-outer-name,invalid-name @prettytensor.Register( assign_defaults=('activation_fn', 'l2loss', 'stddev', 'batch_normalize')) class deconv2d(_deconv2d): def __call__(self, input_layer, kernel, depth, name=PROVIDED, stride=None, activation_fn=None, l2loss=None, init=None, stddev=None, bias=True, edges=PAD_SAME, batch_normalize=False): return super().__call__( input_layer, kernel, depth, name, stride, activation_fn, l2loss, init, stddev, bias, edges, batch_normalize )