"""NumPy reference implementations of the convolution and max-pooling ops."""

import numpy as np

from ..util.typing import zX, zX_like


class ConvolutionOp:
    """Stateless 2-D convolution (cross-correlation) routines.

    Arrays are unpacked as ``(im, ic, iy, ix)`` for inputs and
    ``(nf, fc, fy, fx)`` for filters, where ``ic``/``fc`` are the input and
    filter depths that must agree.
    """

    @staticmethod
    def valid(A, F):
        """Slide every filter in ``F`` over ``A`` without padding.

        :param A: 4-D input array, unpacked as ``(im, ic, iy, ix)``.
        :param F: 4-D filter array, unpacked as ``(nf, fc, fy, fx)``.
        :return: float array of shape ``(im, nf, iy - fy + 1, ix - fx + 1)``.
        :raises ValueError: if the filter depth differs from the input depth
            or the filter is spatially larger than the input.
        """
        im, ic, iy, ix = A.shape
        nf, fc, fy, fx = F.shape

        # Validate before allocating anything.
        if fc != ic:
            err = "Supplied filter (F) is incompatible with supplied input! (X)\n"
            err += "input depth: {} != {} :filter depth".format(ic, fc)
            raise ValueError(err)

        oy, ox = iy - fy + 1, ix - fx + 1
        if oy < 0 or ox < 0:
            raise ValueError(
                "Filter ({}x{}) is larger than the input ({}x{})".format(fy, fx, iy, ix)
            )

        recfield_size = fx * fy * fc

        # One row per output position, holding the flattened receptive field
        # (depth, then y, then x -- the same order F is flattened in below).
        rfields = np.zeros((im, oy * ox, recfield_size))
        for sy in range(oy):
            for sx in range(ox):
                # Fill the whole batch axis at once for this output position.
                window = A[:, :, sy:sy + fy, sx:sx + fx]
                rfields[:, sy * ox + sx] = window.reshape(im, recfield_size)

        # (im, oy*ox, recfield) @ (recfield, nf) -> (im, oy*ox, nf)
        output = np.matmul(rfields, F.reshape(nf, recfield_size).T)
        return output.transpose((0, 2, 1)).reshape((im, nf, oy, ox))

    @staticmethod
    def full(A, F):
        """Zero-pad ``A`` by ``(fy - 1, fx - 1)`` on each side, then run ``valid``.

        :param A: 4-D input array.
        :param F: 4-D filter array, unpacked as ``(nf, fc, fy, fx)``.
        :return: array of shape ``(im, nf, iy + fy - 1, ix + fx - 1)``.
        """
        _, _, fy, fx = F.shape
        py, px = fy - 1, fx - 1
        pA = np.pad(
            A,
            pad_width=((0, 0), (0, 0), (py, py), (px, px)),
            mode="constant",
            constant_values=0.,
        )
        return ConvolutionOp.valid(pA, F)

    @staticmethod
    def forward(A, F, mode="valid"):
        """Dispatch to ``valid`` or ``full`` according to ``mode``.

        :raises RuntimeError: if ``mode`` is neither ``"valid"`` nor
            ``"full"`` (same error ``outshape`` raises), instead of silently
            computing a full convolution.
        """
        if mode == "valid":
            return ConvolutionOp.valid(A, F)
        if mode == "full":
            return ConvolutionOp.full(A, F)
        raise RuntimeError("Unsupported mode:", mode)

    @staticmethod
    def backward(X, E, F):
        """Compute the gradients of a ``valid`` forward pass.

        :param X: the input that was fed to the forward pass.
        :param E: the error arriving at the output of the forward pass.
        :param F: the filters used in the forward pass.
        :return: tuple ``(dF, db, dX)`` -- filter gradient, ``E`` summed over
            axes 0, 2 and 3 (dims kept), and the error w.r.t. ``X``.
        """
        # Swapping the first two axes turns the batch axis into the "depth"
        # that valid() contracts over, yielding one gradient per filter tap.
        dF = ConvolutionOp.forward(
            A=X.transpose(1, 0, 2, 3),
            F=E.transpose(1, 0, 2, 3),
            mode="valid"
        ).transpose(1, 0, 2, 3)
        db = E.sum(axis=(0, 2, 3), keepdims=True)
        # Full pass with spatially flipped filters whose first two axes are
        # swapped maps the error back to the input's shape.
        dX = ConvolutionOp.forward(
            A=E,
            F=F[:, :, ::-1, ::-1].transpose(1, 0, 2, 3),
            mode="full"
        )
        return dF, db, dX

    @staticmethod
    def outshape(inshape, fshape, mode="valid"):
        """Return the per-sample output shape ``(nf, oy, ox)``.

        :param inshape: shape whose last three entries are ``(ic, iy, ix)``.
        :param fshape: 4-tuple unpacked as ``(fx, fy, fc, nf)``.
            NOTE(review): this is the reverse of the ``(nf, fc, fy, fx)``
            order that ``valid`` unpacks from ``F.shape`` -- confirm that
            callers really pass it in this order.
        :param mode: ``"valid"`` or ``"full"``.
        :raises RuntimeError: for any other ``mode``.
        """
        iy, ix = inshape[-2:]
        fx, fy, _, nf = fshape
        if mode == "valid":
            return nf, iy - fy + 1, ix - fx + 1
        elif mode == "full":
            return nf, iy + fy - 1, ix + fx - 1
        else:
            raise RuntimeError("Unsupported mode:", mode)

    def __str__(self):
        return "Convolution"


class MaxPoolOp:
    """Stateless non-overlapping max-pooling routines."""

    def __str__(self):
        return "MaxPool"

    @staticmethod
    def predict(A):
        """Return the 2x2, stride-2 max pooling of ``A`` (no mask is built).

        An odd trailing row/column is cropped so the four strided views have
        equal shapes; previously such inputs produced a ragged list.
        """
        iy, ix = A.shape[2:4]
        A = A[:, :, :iy - iy % 2, :ix - ix % 2]
        return np.max([
            A[:, :, 0::2, 0::2],
            A[:, :, 0::2, 1::2],
            A[:, :, 1::2, 0::2],
            A[:, :, 1::2, 1::2],
        ], axis=0)

    @staticmethod
    def forward(A, fdim):
        """Pool ``A`` with non-overlapping ``fdim`` x ``fdim`` windows.

        Only complete windows are pooled: trailing rows/columns that do not
        fill a window are ignored, matching ``outshape``'s floor division.
        (Iterating over them used to index past the end of the output.)

        :param A: 4-D input array, unpacked as ``(im, ic, iy, ix)``.
        :param fdim: side length of the square pooling window.
        :return: tuple ``(output, filt)``; ``output`` has shape
            ``(im, ic, iy // fdim, ix // fdim)`` and ``filt`` has the shape
            of ``A`` with 1 added wherever an element equals its window's
            maximum (ties mark every maximal element).
        """
        im, ic, iy, ix = A.shape
        oy, ox = iy // fdim, ix // fdim
        output = zX(im, ic, oy, ox)
        filt = zX_like(A)

        cy, cx = oy * fdim, ox * fdim
        # Split each spatial axis into (window index, offset inside window).
        windows = A[:, :, :cy, :cx].reshape(im, ic, oy, fdim, ox, fdim)
        maxima = windows.max(axis=(3, 5))
        output[...] = maxima

        is_max = np.equal(windows, maxima[:, :, :, None, :, None])
        filt[:, :, :cy, :cx] += is_max.reshape(im, ic, cy, cx)
        return output, filt

    @staticmethod
    def backward(E, filt, fdim=None):
        """Route the error ``E`` back through the max mask ``filt``.

        Every ``fdim`` x ``fdim`` window of ``filt`` is scaled by the matching
        element of ``E``. ``filt`` is modified in place and also returned.

        :param E: 4-D error array, unpacked as ``(em, ec, ey, ex)``.
        :param filt: the mask returned by ``forward``.
        :param fdim: pooling window size. When omitted it is inferred as
            ``filt.shape[2] // ey``, which is only exact when the pooled
            height was a multiple of the window size.
        :return: ``filt`` after scaling.
        """
        ey, ex = E.shape[2:4]
        if ey == 0 or ex == 0:
            # No pooled positions, hence nothing to scale.
            return filt
        if fdim is None:
            fdim = filt.shape[2] // ey

        # Blow each error element up to the size of its pooling window.
        upsampled = np.repeat(np.repeat(E, fdim, axis=2), fdim, axis=3)
        filt[:, :, :ey * fdim, :ex * fdim] *= upsampled
        return filt

    @staticmethod
    def outshape(inshape, fdim):
        """Return ``inshape`` with its last two entries floor-divided by ``fdim``.

        :param inshape: 3-tuple ``(m, iy, ix)`` or 2-tuple ``(iy, ix)``.
        :raises ValueError: for any other length (previously ``None`` was
            returned implicitly).
        """
        if len(inshape) == 3:
            m, iy, ix = inshape
            return m, iy // fdim, ix // fdim
        elif len(inshape) == 2:
            iy, ix = inshape
            return iy // fdim, ix // fdim
        raise ValueError(
            "inshape must have 2 or 3 entries, got {}".format(len(inshape))
        )