from __future__ import print_function, division

# Imports
import sys
import timeit
from collections import deque

import numpy as np
from numpy.fft import fftn, ifftn, fft2, ifft2

try:
    import cv2
except ImportError:
    # OpenCV is only needed by estimate_std(); the rest of this module works
    # without it, so the failure is deferred until that function is called.
    cv2 = None

###############################################################################
# Implementations supported
###############################################################################

Impl = {'numpy': 0, 'halide': 1, 'pycuda': 2}

###############################################################################
# Backwards compatibility: provide np.stack on NumPy versions that lack it
###############################################################################

if not hasattr(np, 'stack'):
    def _stack(arrays, axis=0):
        """Join a sequence of same-shaped arrays along a new axis.

        Minimal stand-in for ``np.stack``; installed only when the running
        NumPy does not provide it.

        Raises:
            ValueError: if ``arrays`` is empty or the shapes differ.
            IndexError: if ``axis`` is out of range for the result.
        """
        arrays = [np.asanyarray(arr) for arr in arrays]
        if not arrays:
            raise ValueError('need at least one array to stack')

        shapes = set(arr.shape for arr in arrays)
        if len(shapes) != 1:
            raise ValueError('all input arrays must have the same shape')

        result_ndim = arrays[0].ndim + 1
        if not -result_ndim <= axis < result_ndim:
            msg = 'axis {0} out of bounds [-{1}, {1})'.format(axis, result_ndim)
            raise IndexError(msg)
        if axis < 0:
            axis += result_ndim

        # Insert a length-1 axis at position `axis`, then concatenate along it
        sl = (slice(None),) * axis + (np.newaxis,)
        expanded_arrays = [arr[sl] for arr in arrays]
        return np.concatenate(expanded_arrays, axis=axis)

    np.stack = _stack

###############################################################################
# Image utils
###############################################################################


def im2nparray(img, datatype=np.float32):
    """Convert an integer-typed image to ``datatype`` scaled by its bit depth.

    The input is converted with ``np.array`` and divided by ``2**bits - 1``,
    where ``bits`` is the bit width of its integer dtype (e.g. 255 for uint8).

    Raises:
        ValueError: if the image dtype is not an integer type (``np.iinfo``).
    """
    np_img = np.array(img)
    bits = np.iinfo(np_img.dtype).bits
    # NOTE(review): for signed dtypes this divisor is not the dtype's maximum
    # value (e.g. int16 is divided by 65535); kept for compatibility.
    max_class = 2 ** bits - 1
    return np_img.astype(datatype) / max_class

###############################################################################
# Timing utils
###############################################################################

# Time stamp of the most recent tic() call; None until tic() has been called.
lastticstamp = None


def tic():
    """Start a timer and return its time stamp (``timeit.default_timer``).

    The stamp is also stored module-wide, so ``toc()`` without arguments
    measures from the most recent ``tic()``.

    Example:
        t = tic()
        ... code
        elapsed = toc(t)
        print('{0}: {1:.4f}ms'.format(message, elapsed))
    """
    global lastticstamp
    lastticstamp = timeit.default_timer()
    return lastticstamp


def toc(t=None):
    """Return the milliseconds elapsed since time stamp ``t``.

    When ``t`` is omitted (or falsy), the stamp of the last ``tic()`` call is
    used. If ``tic()`` was never called, an error is printed to stderr and
    0.0 is returned.
    """
    if not t:
        if lastticstamp:
            t = lastticstamp
        else:
            print('Error: Call to toc did never call tic before.', file=sys.stderr)
            return 0.0

    return (timeit.default_timer() - t) * 1000.0

###############################################################################
# FFT utils
###############################################################################


def _fft_axes(dims):
    """Return the leading ``dims`` axes as a tuple, or None meaning all axes."""
    return None if dims is None else tuple(range(dims))


def fftd(I, dims=None):
    """Forward n-D FFT over the first ``dims`` axes (all axes if None)."""
    return fftn(I, axes=_fft_axes(dims))


def ifftd(I, dims=None):
    """Inverse n-D FFT over the first ``dims`` axes (all axes if None)."""
    return ifftn(I, axes=_fft_axes(dims))


def circshift(x, shifts):
    """Circularly shift ``x`` by ``shifts[j]`` along axis ``j`` (MATLAB-style)."""
    for axis, shift in enumerate(shifts):
        x = np.roll(x, shift, axis=axis)
    return x


def psf2otf(K, outsize, dims=None):
    """Convert a point-spread function ``K`` to an optical transfer function.

    ``K`` is zero-padded at the end of every axis up to ``outsize``. It is then
    circularly shifted so that index ``size // 2`` of each (transformed) axis
    moves to the origin. Finally it is transformed with :func:`fftd` over the
    first ``dims`` axes (all axes when ``dims`` is None). If the imaginary part
    of the result is within the estimated FFT round-off error, only the real
    part is returned.

    Args:
        K: PSF array; must have the same number of dimensions as ``outsize``.
        outsize: target shape, not smaller than ``K.shape`` on any axis.
        dims: number of leading axes to shift and transform (None = all).

    Raises:
        ValueError: if ``K`` and ``outsize`` have different dimensionality.
    """
    sK = K.shape
    if len(sK) != len(outsize):
        raise ValueError('psf2otf: K and outsize must have the same number '
                         'of dimensions.')

    # Zero-pad at the end of every axis up to the output size
    padfull = [(0, out - s) for s, out in zip(sK, outsize)]
    Kfull = np.pad(K, padfull, mode='constant', constant_values=0.0)

    # Circularly shift so the PSF center lands on the origin; only the
    # transformed axes are shifted and counted below.
    shifts = -np.floor_divide(np.array(sK), 2)
    if dims is not None and dims < len(sK):
        shifts = shifts[:dims]
        sK = sK[:dims]
    Kfull = circshift(Kfull, shifts)

    otf = fftd(Kfull, dims)

    # Rough number of operations of the FFT: for each axis of length n,
    # (#elements / n) transforms of cost n * log2(n).
    nElem = np.prod(sK)
    nOps = 0
    for n in sK:
        nOps += n * np.log2(n) * (nElem / n)

    # Discard the imaginary part of the OTF if it is within round-off error.
    # An all-zero OTF would make the ratio 0/0 (NaN), so treat it as real.
    eps = np.finfo(np.float32).eps
    peak = np.amax(np.absolute(otf))
    if peak == 0 or np.amax(np.absolute(otf.imag)) / peak <= nOps * eps:
        otf = otf.real

    return otf

###############################################################################
# Image metrics
###############################################################################


def psnr(x, ref, pad=None, maxval=1.0):
    """Peak signal-to-noise ratio (dB) of ``x`` against ``ref``.

    Args:
        x, ref: arrays of identical shape.
        pad: optional per-axis border widths; ``pad[j] > 0`` excludes that many
            elements at both ends of axis ``j`` from the comparison.
        maxval: peak signal value.

    Returns:
        ``10 * log10(maxval**2 / mse)``, or ``np.inf`` when the MSE is not
        larger than float machine epsilon.

    Raises:
        ValueError (a subclass of Exception): if the shapes differ.
    """
    if ref.shape != x.shape:
        raise ValueError("Wrong size in PSNR evaluation.")

    # Crop the requested borders before comparing
    if pad is not None:
        crop = tuple(
            np.s_[pad[j]:-pad[j]] if j < len(pad) and pad[j] > 0 else np.s_[:]
            for j in range(len(x.shape)))
        x, ref = x[crop], ref[crop]

    mse = np.mean((x - ref) ** 2)
    if mse > np.finfo(float).eps:
        return 10.0 * np.log10(maxval ** 2 / mse)
    return np.inf

###############################################################################
# Noise estimation
###############################################################################

# Supported noise estimation methods. Both use the same Daubechies-wavelet
# estimator and differ only in how image borders are extended.
NoiseEstMethod = {'daub_reflect': 0, 'daub_replicate': 1}


def estimate_std(z, method='daub_reflect'):
    """Estimate the per-channel standard deviation of additive Gaussian noise.

    Each channel is filtered separably (rows and columns) with the 6-tap
    Daubechies (db3) high-pass kernel. The noise level is then estimated as
    ``median(|detail|) / 0.6745``, since the median absolute value of a
    zero-mean Gaussian is 0.6745 sigma.

    Args:
        z: 2-D (H, W) or 3-D (H, W, C) image array.
        method: a key of ``NoiseEstMethod`` ('daub_reflect' or
            'daub_replicate') or its numeric value; selects the border mode.

    Returns:
        1-D float array with one standard-deviation estimate per channel.

    Raises:
        ValueError (a subclass of Exception): for an unknown method or an
            image that is not 2-D or 3-D.
        ImportError: if OpenCV (cv2) is not installed.
    """
    # Accept either a method name or its numeric id
    if method in NoiseEstMethod:
        method = NoiseEstMethod[method]
    elif method not in NoiseEstMethod.values():
        raise ValueError("Invalid noise estimation method.")

    if len(z.shape) == 2:
        z = z[..., np.newaxis]
    elif len(z.shape) != 3:
        raise ValueError("Supports only up to 3D images.")

    if cv2 is None:
        raise ImportError("estimate_std requires OpenCV (cv2).")

    daub6kern = np.array([0.03522629188571, 0.08544127388203,
                          -0.13501102001025, -0.45987750211849,
                          0.80689150931109, -0.33267055295008],
                         dtype=np.float32)
    if method == NoiseEstMethod['daub_reflect']:
        border = cv2.BORDER_REFLECT_101
    else:
        border = cv2.BORDER_REPLICATE

    channels = z.shape[2]
    dev = np.zeros(channels)
    for ch in range(channels):
        # Filter each channel on its own so dev[ch] reflects only that channel
        plane = np.ascontiguousarray(z[:, :, ch])
        # The high-pass response is signed; integer output (ddepth=-1) would
        # saturate negative values to 0, so filter integer images as float32.
        if not np.issubdtype(plane.dtype, np.floating):
            plane = plane.astype(np.float32)
        wav_det = cv2.sepFilter2D(plane, -1, daub6kern, daub6kern,
                                  borderType=border)
        dev[ch] = np.median(np.absolute(wav_det)) / 0.6745

    return dev


def graph_visualize(prox_fns, filename=None):
    """Draw the computation graph of ``prox_fns`` with graphviz.

    Each prox function becomes a node with an edge from its ``lin_op``.
    Linear operators reachable through ``input_nodes`` become nodes labelled
    with their type, with an edge from each input. Operators shared between
    prox functions are drawn, and have their input edges emitted, only once.

    Args:
        prox_fns: iterable of prox functions exposing ``lin_op``.
        filename: if None the graph is displayed inline via IPython;
            otherwise it is written with ``Digraph.render(filename)``.
    """
    import graphviz

    dot = graphviz.Digraph()
    nodes = {}
    visited = set()  # operators whose input edges were already emitted

    def node(obj):
        # Stable, unique graphviz id per object
        if obj not in nodes:
            nodes[obj] = 'N%d' % len(nodes)
        return nodes[obj]

    for pfn in prox_fns:
        dot.node(node(pfn), str(pfn))

        # Register every operator reachable from this prox function
        queue = deque([pfn.lin_op])
        while queue:
            n = queue.popleft()
            if n not in nodes:
                dot.node(node(n), str(type(n)))
                queue.extend(n.input_nodes)
        dot.edge(nodes[pfn.lin_op], nodes[pfn])

        # Connect each operator to its inputs (once per operator overall)
        queue = deque([pfn.lin_op])
        while queue:
            n = queue.popleft()
            if n not in visited:
                visited.add(n)
                queue.extend(n.input_nodes)
                for inn in n.input_nodes:
                    dot.edge(nodes[inn], nodes[n])

    if filename is None:
        from IPython.display import display
        display(dot)
    else:
        dot.render(filename)