# -*- coding: utf-8 -*- # Copyright 2018 Google LLC. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """Parameterizations for layer classes.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function # Dependency imports import numpy as np from scipy import fftpack from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops _matrix_cache = {} def irdft_matrix(shape, dtype=dtypes.float32): """Matrix for implementing kernel reparameterization with `tf.matmul`. This can be used to represent a kernel with the provided shape in the RDFT domain. Example code for kernel creation, assuming 2D kernels: ``` def create_kernel(init): shape = init.shape.as_list() matrix = irdft_matrix(shape[:2]) init = tf.reshape(init, (shape[0] * shape[1], shape[2] * shape[3])) init = tf.matmul(tf.transpose(matrix), init) kernel = tf.Variable(init) kernel = tf.matmul(matrix, kernel) kernel = tf.reshape(kernel, shape) return kernel ``` Args: shape: Iterable of integers. Shape of kernel to apply this matrix to. dtype: `dtype` of returned matrix. Returns: `Tensor` of shape `(prod(shape), prod(shape))` and dtype `dtype`. """ shape = tuple(int(s) for s in shape) dtype = dtypes.as_dtype(dtype) key = (ops.get_default_graph(), "irdft", shape, dtype.as_datatype_enum) matrix = _matrix_cache.get(key) if matrix is None: size = np.prod(shape) rank = len(shape) matrix = np.identity(size, dtype=np.float64).reshape((size,) + shape) for axis in range(rank): matrix = fftpack.rfft(matrix, axis=axis + 1) slices = (rank + 1) * [slice(None)] if shape[axis] % 2 == 1: slices[axis + 1] = slice(1, None) else: slices[axis + 1] = slice(1, -1) matrix[tuple(slices)] *= np.sqrt(2) matrix /= np.sqrt(size) matrix = np.reshape(matrix, (size, size)) matrix = array_ops.constant( matrix, dtype=dtype, name="irdft_" + "x".join([str(s) for s in shape])) _matrix_cache[key] = matrix return matrix