import autograd.numpy as np
from . import fft
from .cache import Cache

try:
    # Deprecated in SciPy 1.5 and removed in SciPy 1.9. The name is kept
    # importable for backwards compatibility; this module no longer uses it.
    from scipy.stats import median_absolute_deviation as mad
except ImportError:  # SciPy >= 1.9
    from functools import partial
    from scipy.stats import median_abs_deviation

    # `median_abs_deviation` *divides* the MAD by `scale`; "normal" reproduces
    # the ~1.4826 factor that the old function multiplied the MAD by.
    mad = partial(median_abs_deviation, scale="normal")

# Filter for the starlet transform: the 1D B3-spline kernel of the
# 'a trous' algorithm.
h = np.array([1. / 16, 1. / 4, 3. / 8, 1. / 4, 1. / 16])

# Factor turning a median absolute deviation into a standard-deviation
# estimate for Gaussian noise (same value as the old SciPy default).
_MAD_TO_SIGMA = 1.4826


class Starlet(object):
    """Starlet ('a trous' wavelet) transform of a cube of images.

    The transform is computed either directly with the 'a trous' algorithm
    (`direct=True`, see `mk_starlet`), or by convolving each image with a seed
    starlet: the transform of an all-zero image whose central pixel is set to
    one. The FFT of the seed starlet is cached (keyed on the starlet shape) so
    that it can be reused for other images with the same shape.

    Images are stored as cubes of shape (n_images, n1, n2) and coefficients as
    arrays of shape (n_images, lvl, n1, n2); a 2D image or a 3D coefficient
    array is promoted by adding a leading axis.
    """

    def __init__(self, image=None, lvl=None, coefficients=None, direct=True):
        """Initialise the Starlet object.

        Exactly one of `image` and `coefficients` must be provided.

        Parameters
        ----------
        image: numpy ndarray
            image (n1, n2) or cube of images (n_images, n1, n2) to transform
        lvl: int
            number of starlet levels to use in the decomposition. If None, it
            is set by `get_starlet_shape`. Ignored when `coefficients` are
            given: the number of levels is then read from their shape.
        coefficients: numpy ndarray
            starlet coefficients, (lvl, n1, n2) or (n_images, lvl, n1, n2)
        direct: bool
            if True, uses the direct wavelet transform with the 'a trous'
            algorithm. If False, the transform is performed by convolving the
            image with the wavelet transform of a dirac.

        Raises
        ------
        InputError
            if neither or both of `image` and `coefficients` are provided
        """
        # Validate the arguments before doing any work.
        if coefficients is None and image is None:
            raise InputError('At least an image or a set of coefficients should be provided')
        if coefficients is not None and image is not None:
            raise InputError("Ambiguous initialisation: Starlet objects should be instantiated "
                             "either with an image or a set of coefficients, not both")

        # Transform method
        self._direct = direct

        if coefficients is None:
            # Original shape of the image (before promotion to a cube)
            self._image_shape = image.shape
            if lvl is None:
                self._lvl = get_starlet_shape(image.shape)
            else:
                self._lvl = lvl
            if len(image.shape) == 2:
                image = image[np.newaxis, :, :]
        else:
            if len(np.shape(coefficients)) == 3:
                coefficients = coefficients[np.newaxis, :, :, :]
            self._image_shape = [coefficients.shape[0], *coefficients.shape[-2:]]
            self._lvl = coefficients.shape[1]

        self._image = image
        self._coeffs = coefficients
        self._starlet_shape = [self._lvl, *self._image_shape[-2:]]
        # Seed starlet (transform of a centred dirac), shape (lvl, n1, n2);
        # its per-level L2 norm scales the noise thresholds in `filter`.
        self.seed = mk_starlet(self._starlet_shape)
        self._norm = np.sqrt(np.sum(self.seed ** 2, axis=(-2, -1)))

    @property
    def image(self):
        """The real space image (cube of shape (n_images, n1, n2)).

        When coefficients are available, the image is reconstructed from them
        with `iuwt` and stored; otherwise the stored image is returned.
        """
        if self._coeffs is None:
            # Built from an image whose coefficients were never computed.
            return self._image
        self._image = np.array([iuwt(star) for star in self._coeffs])
        return self._image

    @image.setter
    def image(self, image):
        """Updates the coefficients if the image is changed"""
        if len(image.shape) == 2:
            self._image = image[np.newaxis, :, :]
        else:
            self._image = image
        if self._direct:
            self._coeffs = self.direct_transform()
        else:
            self._coeffs = self.transform()

    @property
    def norm(self):
        """L2 norm of the seed starlet at each level, coarse level included; shape (lvl,)"""
        return self._norm

    @property
    def coefficients(self):
        """Starlet coefficients, shape (n_images, lvl, n1, n2).

        They are recomputed from the current image; objects built from
        coefficients only (no image yet) return the stored coefficients.
        """
        if self._image is None:
            # Nothing to transform: the stored coefficients are the state.
            return self._coeffs
        if self._direct:
            self._coeffs = self.direct_transform()
        else:
            self._coeffs = self.transform()
        return self._coeffs

    @coefficients.setter
    def coefficients(self, coeffs):
        """Updates the image if the coefficients are changed"""
        if len(np.shape(coeffs)) == 3:
            coeffs = coeffs[np.newaxis, :, :, :]
        self._coeffs = coeffs
        self._image = np.array([iuwt(star) for star in self._coeffs])

    @property
    def shape(self):
        """The shape of the real space image"""
        image = self._image if self._image is not None else self.image
        return image.shape

    @property
    def scales(self):
        """Number of starlet scales"""
        return self._lvl

    def transform(self):
        """Wavelet transform of the image by convolution with the seed wavelet.

        The seed wavelet is the starlet transform of a dirac. For a given
        starlet shape its FFT is computed once and cached with `Cache` so it
        can be reused for other images with the same shape. The transform is
        applied to `self._image`.

        Returns
        -------
        starlet: numpy ndarray
            the starlet transform of the Starlet object's image
        """
        try:
            # Check if the FFT of the starlet seed is already cached
            seed_fft = Cache.check('Starlet', tuple(self._starlet_shape))
        except KeyError:
            # Make a starlet seed, take its FFT and cache it
            self.seed = mk_starlet(self._starlet_shape)
            seed_fft = fft.Fourier(self.seed)
            seed_fft.fft(self._starlet_shape[-2:], (-2, -1))
            Cache.set('Starlet', tuple(self._starlet_shape), seed_fft)

        coefficients = [
            fft.convolve(seed_fft, fft.Fourier(im[np.newaxis, :, :]), axes=(-2, -1)).image
            for im in self._image
        ]
        return np.array(coefficients)

    def direct_transform(self):
        """Computes the direct starlet transform of the starlet's image

        Returns
        -------
        starlet: numpy ndarray
            the starlet transform of the Starlet object's image
        """
        return mk_starlet(self._starlet_shape, self._image)

    def __len__(self):
        """Number of images in the cube."""
        return self.shape[0]

    def filter(self, niter=20, k=5):
        """Denoise the image by iterative wavelet thresholding.

        The noise level of each image is estimated with `mad_wavelet` and
        scaled by the seed norm of each level. Coefficients of the non-coarse
        levels whose magnitude is below `k` times that level's noise define a
        mask. At each iteration, the residual between the image and the
        current estimate is transformed (with the same number of levels and
        transform method as this object), its masked coefficients are set to
        zero, and its reconstruction is added to the estimate, which is then
        clipped at zero. The final estimate becomes this object's image.

        Parameters
        ----------
        niter: int
            number of iterations
        k: float
            threshold, in units of the noise level, below which coefficients
            are set to zero

        Returns
        -------
        filtered: array
            the cube of filtered images
        """
        # Make sure both representations are available (the property getters
        # compute and store whichever one is missing).
        wavelet = self._coeffs if self._coeffs is not None else self.coefficients
        image = self._image if self._image is not None else self.image

        # Per-image, per-level thresholds: shape (n_images, lvl)
        sigma = k * mad_wavelet(image)[:, None] * self.norm[None, :]
        # Indices of the non-coarse coefficients considered as noise
        support = np.where(np.abs(wavelet[:, :-1, :, :]) < sigma[:, :-1, None, None])

        filtered = 0
        for _ in range(niter):
            residual = image - filtered
            # Same levels/method as self, so `support` indexes the right scales
            residual_starlet = Starlet(residual, lvl=self._lvl, direct=self._direct)
            # The getter returns the stored array, so this edits it in place
            residual_starlet.coefficients[support] = 0
            filtered += residual_starlet.image
            filtered[filtered < 0] = 0
        self.image = filtered
        return filtered


def get_starlet_shape(shape, lvl=None):
    """Number of starlet levels to use for an image of the given shape.

    Parameters
    ----------
    shape: tuple
        shape of the image; only the last two axes are used
    lvl: int
        requested number of levels. If None, or larger than
        log2(min(n1, n2)), it is set to that maximum.

    Returns
    -------
    lvl: int
        number of levels of the starlet decomposition
    """
    # `np.int` was removed in NumPy 1.24; the builtin is equivalent.
    lvl_max = int(np.log2(np.min(shape[-2:])))
    if lvl is None or lvl > lvl_max:
        lvl = lvl_max
    return lvl


def _dilated_filters(step):
    """Return the B3-spline filter dilated by `step` as (column, row) Fourier objects."""
    n = np.size(h)
    # Insert (step - 1) zeros between consecutive filter taps ('a trous').
    kernel = np.zeros((n + (n - 1) * (step - 1), 1))
    kernel[0::step, 0] = h
    return fft.Fourier(kernel), fft.Fourier(kernel.T)


def mk_starlet(shape, image=None):
    """Starlet transform of an image with the 'a trous' algorithm.

    Parameters
    ----------
    shape: tuple
        (..., lvl, n1, n2): the number of levels and the 2D image shape
    image: array
        image (n1, n2) or cube of images (n_images, n1, n2) to transform.
        If None, the transform of a dirac centred in an (n1, n2) image is
        computed.

    Returns
    -------
    starlet: array
        the starlet coefficients, of shape (lvl, n1, n2) for a 2D input and
        (n_images, lvl, n1, n2) for a cube. Levels 0 to lvl-2 hold the
        wavelet scales and the last level holds the coarse scale.
    """
    lvl, n1, n2 = shape[-3:]

    if image is None:
        c = np.zeros((n1, n2))
        c[n1 // 2, n2 // 2] = 1
    elif len(image.shape) > 2:
        # Cube of images: transform each image independently
        return np.array([mk_starlet(shape, im) for im in image])
    else:
        c = image
    c = fft.Fourier(c)

    # Wavelet set of coefficients
    wave = np.zeros([lvl, n1, n2])
    for i in range(lvl - 1):
        newh, newhT = _dilated_filters(2 ** i)
        # c(j+1): line then column convolution
        cnew = fft.convolve(c, newh, axes=[0])
        cnew = fft.convolve(cnew, newhT, axes=[1])
        # Smooth c(j+1) once more (line then column convolution)
        hc = fft.convolve(cnew, newh, axes=[0])
        hc = fft.convolve(hc, newhT, axes=[1])
        # w(j+1) = c(j) - h * c(j+1)
        wave[i, :, :] = c.image - hc.image
        c = cnew

    wave[-1, :, :] = c.image
    return wave


def iuwt(starlet):
    """Inverse starlet transform.

    Parameters
    ----------
    starlet: array
        starlet coefficients of a single image, shape (lvl, n1, n2), as
        produced by `mk_starlet`

    Returns
    -------
    cJ: array
        the 2D image (n1, n2) that corresponds to the inverse transform of
        `starlet`
    """
    lvl, n1, n2 = np.shape(starlet)

    # Start from the coarse scale and add the finer scales back one by one:
    # c(j) = h * c(j+1) + w(j)
    cJ = fft.Fourier(starlet[-1, :, :])
    for i in range(1, lvl):
        scale = lvl - i - 1
        newh, newhT = _dilated_filters(2 ** scale)
        # Line then column convolution
        cnew = fft.convolve(cJ, newh, axes=[0])
        cnew = fft.convolve(cnew, newhT, axes=[1])
        cJ = fft.Fourier(cnew.image + starlet[scale, :, :])

    return np.reshape(cJ.image, (n1, n2))


class InputError(Exception):
    """Exception raised for errors in the input.

    Attributes:
        message -- explanation of the error
    """

    def __init__(self, message):
        super().__init__(message)
        self.message = message


def mad_wavelet(image):
    """Noise estimate from the first wavelet scale.

    Computes the median absolute deviation of the first starlet scale of each
    image, multiplied by 1.4826 (the factor for Gaussian noise).
    (WARNING: sorry to disappoint, this is not a wavelet for mad scientists)

    Parameters
    ----------
    image: array
        An image or cube of images

    Returns
    -------
    mad: array
        scaled median absolute deviation for each image in the cube
    """
    coeffs = Starlet(image, lvl=2).coefficients[:, 0, ...]
    # MAD over the two spatial axes, computed per image
    flat = np.reshape(coeffs, (np.shape(coeffs)[0], -1))
    deviation = np.abs(flat - np.median(flat, axis=-1, keepdims=True))
    return _MAD_TO_SIGMA * np.median(deviation, axis=-1)