""" Test conv fwd-prop and back-prop for ND convs""" from typing import Tuple import hypothesis.extra.numpy as hnp import hypothesis.strategies as st import numpy as np import pytest from hypothesis import HealthCheck, assume, given, settings from numpy.testing import assert_allclose from pytest import raises import mygrad as mg from mygrad import Tensor from mygrad.nnet.layers import conv_nd from ...utils.numerical_gradient import numerical_gradient_full from ...wrappers.uber import backprop_test_factory, fwdprop_test_factory @pytest.mark.parametrize( "shapes", [ # x has too few dims st.tuples( hnp.array_shapes(min_dims=0, max_dims=2), hnp.array_shapes(min_dims=3) ), # x.ndim != k.ndim hnp.array_shapes(min_dims=3).flatmap( lambda x: st.tuples( st.just(x), hnp.array_shapes(min_dims=2).filter(lambda s: len(s) != len(x)), ) ), # channel sizes don't match hnp.array_shapes(min_dims=3).flatmap( lambda x: st.tuples( st.just(x), hnp.array_shapes(min_dims=len(x), max_dims=len(x)).filter( lambda s: s[1] != x[1] ), ) ), ], ) @given(data=st.data()) def test_input_validation( shapes: st.SearchStrategy[Tuple[Tuple[int, ...], Tuple[int, ...]]], data: st.DataObject, ): x_shape, k_shape = data.draw(shapes, label="x_shape, k_shape") x = mg.zeros(x_shape, dtype="float") k = mg.zeros(k_shape, dtype="float") with raises(ValueError): conv_nd(x, k, stride=1) def get_outshape(x_shape, w_shape, stride, dilation): """ Compute the shape of the output tensor given an input shape, convolutional filter shape, and stride. Parameters ---------- x_shape : Tuple[int, ...] The shape of the input tensor. w_shape : Tuple[int, ...] The shape of the convolutional filter. stride : Tuple[int, ...] The stride at which to apply the convolutional filter to the input. dilation : Tuple[int, ...] The dilation used to form each window over the data. 

    Returns
    -------
    numpy.ndarray[int], shape=(num_conv,)
        The shape of the output tensor resulting from convolving a tensor of shape
        `x_shape` with a tensor of shape `w_shape`.

        Returns `None` if an invalid combination of shapes is provided.
    """
    x_shape = np.array(x_shape)
    w_shape = np.array(w_shape)
    stride = np.array(stride)
    dilation = np.array(dilation)
    out_shape = (x_shape - ((w_shape - 1) * dilation + 1)) / stride + 1

    if not all(i.is_integer() and i > 0 for i in out_shape):
        msg = "Stride and kernel dimensions are incompatible: \n"
        msg += "Input dimensions: {}\n".format(tuple(x_shape))
        msg += "Stride dimensions: {}\n".format(tuple(stride))
        msg += "Kernel dimensions: {}\n".format(tuple(w_shape))
        msg += "Dilation dimensions: {}\n".format(tuple(dilation))
        return None
    return out_shape.astype(np.int32)


def convolve_numpy(input_image, conv_filter, stride, dilation=None):
    """Convolve `input_image` with `conv_filter` at a stride of `stride`.

    Parameters
    ----------
    input_image : numpy.ndarray, shape=(C, H, ...)
        The input over which to perform convolution.

    conv_filter : numpy.ndarray, shape=(C, Hf, ...)
        The convolutional filter to slide across the image.

    stride : Sequence[int]
        The stride at which to apply `conv_filter` across `input_image`.

    dilation : Optional[Sequence[int]]
        The dilation used to form each window over the image. Defaults to
        no dilation (1 along each spatial axis).

    Returns
    -------
    numpy.ndarray, shape=(H', ...)
        The result of convolving `input_image` with `conv_filter` at a stride of
        `stride`, where (H', ...) is the result of `get_outshape`.
""" conv_shape = conv_filter.shape[1:] in_shape = input_image.shape[1:] if dilation is None: dilation = (1,) * len(stride) out_shape = tuple(get_outshape(in_shape, conv_shape, stride, dilation)) out = np.empty(out_shape, np.float32) for ind in np.ndindex(out_shape): slices = (slice(None),) + tuple( slice(i * s, i * s + w * d, d) for i, w, s, d in zip(ind, conv_shape, stride, dilation) ) out[ind] = np.sum(conv_filter * input_image[slices]) return out def conv_bank(input_images, conv_filters, stride, dilation=None, padding=tuple()): """ Convolve a bank of filters over a stack of images. Parameters ---------- input_images : numpy.ndarray, shape=(N, C, H, ...) The images over which to convolve our filters. conv_filters : numpy.ndarray, shape=(K, C, Hf, ...) The convolutional filters to apply to the images. stride : Sequence[int] The stride at which to apply each filter to the images. dilation : Sequence[int] Returns ------- numpy.ndarray, shape=(N, K, H', ...) The result of convolving `input_image` with `conv_filter` at a stride of `stride`, where (H', ...) is the result of `get_outshape`. """ if isinstance(padding, int): padding = (padding,) * (input_images.ndim - 2) if sum(padding): # symmetric 0-padding for X0, X1, ... 
dimensions axis_pad = tuple((i, i) for i in (0, 0, *padding)) input_images = np.pad(input_images, axis_pad, mode="constant") img_shape = input_images.shape[2:] conv_shape = conv_filters.shape[2:] if dilation is None: dilation = (1,) * len(stride) out_shape = get_outshape(img_shape, conv_shape, stride, dilation) out = np.empty((len(input_images), len(conv_filters), *out_shape)) for i, image in enumerate(input_images): for j, conv in enumerate(conv_filters): out[i, j] = convolve_numpy(image, conv, stride, dilation) return out def test_convnd_fwd_trivial(): # trivial by-hand test: 1-dimensional conv # x: # [ 1, 2, 3, 4] # k: # [-1, -2], # stride = (2,) x = Tensor(np.arange(1, 5).reshape(1, 1, 4).astype(float)) k = Tensor(-1 * np.arange(1, 3).reshape(1, 1, 2).astype(float)) o = conv_nd(x, k, stride=(2,), constant=True) out = np.array([[[-5.0, -11.0]]]) assert isinstance(o, Tensor) assert o.constant is True assert o.scalar_only is False assert_allclose(actual=o.data, desired=out, err_msg="1d trivial test failed") # trivial by-hand test: 2-dimensional conv # x: # [ 1, 2, 3, 4], # [ 5, 6, 7, 8], # [ 9, 10, 11, 12]] # k: # [-1, -2], # [-3, -4] # stride = (1, 2) x = Tensor(np.arange(1, 13).reshape(1, 1, 3, 4).astype(float)) k = Tensor(-1 * np.arange(1, 5).reshape(1, 1, 2, 2).astype(float)) o = conv_nd(Tensor(x), k, stride=(1, 2), constant=True) out = np.array([[[[-44.0, -64.0], [-84.0, -104.0]]]]) assert isinstance(o, Tensor) assert o.constant is True assert o.scalar_only is False assert_allclose(actual=o.data, desired=out, err_msg="2d trivial test failed") def test_bad_conv_shapes(): x = np.zeros((1, 2, 2, 2)) w = np.zeros((1, 3, 2, 2)) with raises(ValueError): conv_nd(x, w, stride=1, padding=0) # mismatched channels w = np.zeros((1, 2, 3, 2)) with raises(ValueError): conv_nd(x, w, stride=1, padding=0) # large filter w = np.zeros((1, 2, 2, 2)) with raises(AssertionError): conv_nd(x, w, stride=0, padding=0) # bad stride with raises(AssertionError): conv_nd(x, w, stride=[1, 
2, 3]) # bad stride with raises(AssertionError): conv_nd(x, w, stride=1, padding=-1) # bad pad with raises(AssertionError): conv_nd(x, w, stride=1, padding=[1, 2, 3]) # bad pad with raises(ValueError): conv_nd(x, w, stride=3, padding=1) # shape mismatch @settings(deadline=None) @given(ndim=st.integers(1, 4), data=st.data()) def test_padding(ndim: int, data: st.DataObject): """Ensure that convolving a padding-only image with a commensurate kernel yields the single entry: 0""" padding = data.draw( st.integers(1, 3) | st.tuples(*[st.integers(1, 3)] * ndim), label="padding" ) x = Tensor( data.draw( hnp.arrays(shape=(1, 1) + (0,) * ndim, dtype=float, elements=st.floats()), label="x", ) ) pad_tuple = padding if isinstance(padding, tuple) else (padding,) * ndim kernel = data.draw( hnp.arrays( shape=(1, 1) + tuple(2 * p for p in pad_tuple), dtype=float, elements=st.floats(allow_nan=False, allow_infinity=False), ) ) out = conv_nd(x, kernel, padding=padding, stride=1) assert out.shape == (1,) * x.ndim assert out.item() == 0.0 out.sum().backward() assert x.grad.shape == x.shape @fwdprop_test_factory( mygrad_func=conv_nd, true_func=conv_bank, num_arrays=2, index_to_arr_shapes={0: (4, 5, 7), 1: (2, 5, 3)}, kwargs=dict(stride=(1,), dilation=(1,)), index_to_bnds={0: (-10, 10), 1: (-10, 10)}, ) def test_conv_1d_fwd(): """ (N=4, C=5, W=7) x (F=2, C=5, Wf=3); stride=1, dilation=1 Also tests meta properties of conv function - appropriate return type, behavior with `constant` arg, etc.""" def _conv_nd(x, w, stride, dilation=1, padding=0): """ use mygrad-conv_nd forward pass for numerical derivative Returns ------- numpy.ndarray""" return conv_nd( x, w, stride=stride, dilation=dilation, padding=padding, constant=True ).data @settings(deadline=None) @backprop_test_factory( mygrad_func=conv_nd, true_func=_conv_nd, num_arrays=2, index_to_arr_shapes={0: (2, 1, 7), 1: (2, 1, 3)}, kwargs={"stride": (1,)}, index_to_bnds={0: (-10, 10), 1: (-10, 10)}, vary_each_element=True, ) def 
test_conv_1d_bkwd(): """ (N=2, C=1, W=7) x (F=2, C=1, Wf=3); stride=1, dilation=1 Also tests meta properties of conv-backprop - appropriate return type, behavior with `constant` arg, good behavior of null_gradients, etc.""" @settings(deadline=None, suppress_health_check=(HealthCheck.filter_too_much,)) @given( data=st.data(), shape=hnp.array_shapes(min_dims=1, max_dims=3, max_side=10), num_filters=st.integers(1, 3), num_batch=st.integers(1, 3), num_channel=st.integers(1, 3), ) def test_conv_ND_fwd(data, shape, num_filters, num_batch, num_channel): img_shape = (num_batch, num_channel) + shape padding = data.draw( st.integers(0, 2) | st.tuples(*[st.integers(0, 2)] * len(shape)), label="padding", ) if isinstance(padding, tuple): shape = tuple(s + 2 * p for s, p in zip(shape, padding)) else: shape = tuple(s + 2 * padding for s in shape) win_dim = len(shape) shape = (num_batch, num_channel) + shape win_shape = data.draw( st.tuples(*(st.integers(1, s) for s in shape[-win_dim:])), label="win_shape" ) kernel_shape = (num_filters, shape[1], *win_shape) stride = data.draw( st.tuples(*(st.integers(1, s) for s in shape[-win_dim:])), label="stride" ) max_dilation = np.array(shape[-win_dim:]) // win_shape dilation = data.draw( st.tuples(*(st.integers(1, s) for s in max_dilation)), label="dilation" ) conf = dict(stride=stride, dilation=dilation, padding=padding) # skip invalid data/kernel/stride/dilation combinations assume(get_outshape(shape[2:], kernel_shape[2:], stride, dilation) is not None) kernels = data.draw( hnp.arrays(dtype=float, shape=kernel_shape, elements=st.floats(-10, 10)), label="kernels", ) x = data.draw( hnp.arrays(dtype=float, shape=img_shape, elements=st.floats(-10, 10)), label="x" ) mygrad_conv = conv_nd(x, kernels, **conf).data numpy_conv = conv_bank(x, kernels, **conf) assert_allclose(actual=mygrad_conv, desired=numpy_conv, atol=1e-6, rtol=1e-6) @settings(deadline=None, suppress_health_check=(HealthCheck.filter_too_much,)) @given( data=st.data(), 
    shape=hnp.array_shapes(min_dims=1, max_dims=3, max_side=6),
    num_filters=st.integers(1, 3),
    num_batch=st.integers(1, 3),
    num_channel=st.integers(1, 3),
)
def test_conv_ND_bkwd(data, shape, num_filters, num_batch, num_channel):
    """Test conv-backprop 1D-3D with various strides, dilations, and paddings."""
    img_shape = (num_batch, num_channel) + shape

    padding = data.draw(
        st.integers(0, 2) | st.tuples(*[st.integers(0, 2)] * len(shape)),
        label="padding",
    )

    # the kernel, stride, and dilation are drawn relative to the padded shape
    if isinstance(padding, tuple):
        shape = tuple(s + 2 * p for s, p in zip(shape, padding))
    else:
        shape = tuple(s + 2 * padding for s in shape)

    win_dim = len(shape)
    shape = (num_batch, num_channel) + shape
    win_shape = data.draw(
        st.tuples(*(st.integers(1, s) for s in shape[-win_dim:])), label="win_shape"
    )
    kernel_shape = (num_filters, shape[1], *win_shape)
    stride = data.draw(
        st.tuples(*(st.integers(1, s) for s in shape[-win_dim:])), label="stride"
    )

    max_dilation = np.array(shape[-win_dim:]) // win_shape
    dilation = data.draw(
        st.tuples(*(st.integers(1, s) for s in max_dilation)), label="dilation"
    )
    conf = dict(stride=stride, dilation=dilation, padding=padding)

    # skip invalid data/kernel/stride/dilation combinations
    assume(get_outshape(shape[2:], kernel_shape[2:], stride, dilation) is not None)

    kernels = data.draw(
        hnp.arrays(dtype=float, shape=kernel_shape, elements=st.floats(-10, 10)),
        label="kernels",
    )
    x = data.draw(
        hnp.arrays(dtype=float, shape=img_shape, elements=st.floats(-10, 10)), label="x"
    )

    x = Tensor(x)
    kernels = Tensor(kernels)

    out = conv_nd(x, kernels, **conf)
    grad = data.draw(
        hnp.arrays(
            shape=out.shape, dtype=float, elements=st.floats(-10, 10), unique=True
        ),
        label="grad",
    )

    out.backward(grad)
    grads_numerical = numerical_gradient_full(
        _conv_nd, *(i.data for i in (x, kernels)), back_grad=grad, kwargs=conf
    )

    for n, (arr, d_num) in enumerate(zip((x, kernels), grads_numerical)):
        assert_allclose(
            arr.grad,
            d_num,
            atol=1e-4,
            rtol=1e-4,
            err_msg="arr-{}: numerical derivative and mygrad derivative do not match".format(
                n
            ),
        )
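

# A minimal, integer-only sketch of the per-axis arithmetic in `get_outshape`,
# assuming a single spatial axis is treated at a time. The helper
# `_dilated_out_size` is hypothetical (it is not part of mygrad). A window of
# `w` taps spaced `d` apart covers `(w - 1) * d + 1` input elements, so the
# output size along the axis is
#     out = (x - ((w - 1) * d + 1)) / s + 1,
# which must be a positive integer for the stride to tile the input exactly.
def _dilated_out_size(x: int, w: int, s: int, d: int):
    """Return the output size along one axis, or None if the shapes are incompatible."""
    span = (w - 1) * d + 1  # input elements covered by one dilated window
    num, rem = divmod(x - span, s)
    if x < span or rem != 0:
        return None
    return num + 1


def test_dilated_out_size_by_hand():
    # x=7, w=3, s=1, d=1: windows start at 0, 1, ..., 4
    assert _dilated_out_size(7, 3, 1, 1) == 5
    # x=7, w=3, s=1, d=3: the dilated window spans all 7 elements
    assert _dilated_out_size(7, 3, 1, 3) == 1
    # x=7, w=3, s=2, d=1: (7 - 3) / 2 + 1 = 3
    assert _dilated_out_size(7, 3, 2, 1) == 3
    # the padded 4x4 input, 2x2 filter, and stride 3 from `test_bad_conv_shapes`:
    # (4 - 2) / 3 is not an integer, so the shapes are incompatible
    assert _dilated_out_size(4, 2, 3, 1) is None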
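

# A minimal sketch of the dilated-window slicing in `convolve_numpy`, reduced to
# one spatial axis and plain Python lists. The helper `_dilated_conv1d_by_hand`
# is hypothetical (it is not part of mygrad). Like the by-hand cases in
# `test_convnd_fwd_trivial`, it computes a cross-correlation (the kernel is not
# flipped). Output position `i` reads the taps x[i*s], x[i*s + d], ...,
# x[i*s + (w - 1)*d], which is exactly the slice x[i*s : i*s + w*d : d].
def _dilated_conv1d_by_hand(x, k, stride, dilation):
    w = len(k)
    span = (w - 1) * dilation + 1  # input elements covered by one dilated window
    num, rem = divmod(len(x) - span, stride)
    if len(x) < span or rem != 0:
        # same validity condition as `get_outshape`: the stride must tile the input
        raise ValueError("kernel/stride/dilation do not tile the input")
    out = []
    for i in range(num + 1):
        window = x[i * stride : i * stride + w * dilation : dilation]
        out.append(sum(a * b for a, b in zip(window, k)))
    return out


def test_dilated_conv1d_by_hand():
    # stride 2, no dilation: reproduces the 1-D case in `test_convnd_fwd_trivial`
    out = _dilated_conv1d_by_hand([1, 2, 3, 4], [-1, -2], stride=2, dilation=1)
    assert out == [-5, -11]
    # stride 1, dilation 2: each window pairs x[j] with x[j + 2]
    out = _dilated_conv1d_by_hand([1, 2, 3, 4, 5, 6, 7], [-1, -2], stride=1, dilation=2)
    assert out == [-7, -10, -13, -16, -19]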