from functools import wraps

import cvxpy as cp
import matplotlib.cm as cm
from matplotlib import pyplot as plt
from numpy import (
    arange,
    array,
    expand_dims,
    inf as Inf,  # NumPy 2.0 removed the ``Inf`` alias; ``inf`` exists in every version
    maximum,
    minimum,
    ones,
    repeat,
    round,
    tile,
    zeros,
)
from numpy.ma import masked_where


def pplot(As, titles):
    """Plot a row of matrices side by side and show the figure.

    Parameters
    ----------
    As : sequence
        One item per panel. An item whose title does not contain
        ``"missing"`` is drawn with ``imshow`` (it must support ``.min()``,
        ``.max()`` and ``.shape``). An item whose title contains
        ``"missing"`` is iterated as ``(row, col)`` pairs: the previous item's
        matrix is drawn and those entries are overlaid with ``cm.binary``.
    titles : sequence of str
        Panel titles, parallel to ``As``.

    Every panel shares one color range, taken from the non-"missing" items in
    ``As[:-1]`` (or from ``As[0]`` if there are none). The exception is the
    last panel: when it is not a "missing" overlay, it uses its own min/max.
    """
    # Shared color range ("pixel color reference"). An explicit emptiness
    # check replaces the old bare ``except:`` so genuine errors surface.
    reference = [A for A, t in zip(As[:-1], titles) if "missing" not in t]
    if reference:
        vmin = min(A.min() for A in reference)
        vmax = max(A.max() for A in reference)
    else:
        vmin, vmax = As[0].min(), As[0].max()

    my_dpi = 96
    # Each panel is 250 px tall and 1.4 * 250 px wide.
    plt.figure(figsize=(1.4 * (250 * len(As)) / my_dpi, 250 / my_dpi), dpi=my_dpi)

    for i, (A, title) in enumerate(zip(As, titles)):
        plt.subplot(1, len(As), i + 1)
        is_missing = "missing" in title
        if i == len(As) - 1 and not is_missing:
            # The last panel gets its own color range. A "missing" panel holds
            # index pairs rather than pixel values, so it keeps the shared one.
            vmin, vmax = A.min(), A.max()
        if is_missing:
            # NOTE(review): for i == 0, As[i - 1] wraps around to the last item.
            base = As[i - 1]
            observed = ones(base.shape)
            for j, k in A:
                observed[j, k] = 0
            # Mask every observed cell so only the missing entries are drawn.
            overlay = masked_where(observed > 0.5, observed)
            plt.imshow(base, interpolation="nearest", vmin=vmin, vmax=vmax)
            plt.colorbar()
            plt.imshow(overlay, cmap=cm.binary, interpolation="nearest")
        else:
            plt.imshow(A, interpolation="nearest", vmin=vmin, vmax=vmax)
            plt.colorbar()
        plt.title(title)
        plt.axis("off")
    plt.show()


#
# def unroll_missing(missing, ns):
#     missing_unrolled = []
#     for i, (MM, n) in enumerate(zip(missing, ns)):
#         for m in MM:
#             n2 = m[1] + sum([ns[j] for j in range(i)])
#             missing_unrolled.append((m[0], n2))
#     return missing_unrolled
#


def shrinkage(a, kappa):
    """Soft-threshold ``a`` with parameter ``kappa``.

    Computes ``max(a - kappa, 0) - max(-a - kappa, 0)`` element-wise. For
    ``kappa >= 0`` this equals ``sign(a) * max(|a| - kappa, 0)``: entries with
    ``|a| <= kappa`` become 0 and the rest move ``kappa`` closer to 0.

    ``a`` may be a scalar or an array; ``kappa`` may be a scalar or an array
    broadcastable against ``a``. Returns a NumPy scalar or array.
    """
    # numpy.maximum is element-wise and broadcasts, so a single expression
    # covers scalars and arrays alike (the builtin ``max`` cannot do arrays).
    return maximum(a - kappa, 0) - maximum(-a - kappa, 0)