import numpy as np import scipy as sp from scipy.sparse.linalg import LinearOperator import matplotlib import matplotlib.pyplot as plt from vec import vec import timeit import pywt import os def vec(x): return x.ravel(order='F') def sigmoid(x): return 1/(1+np.exp(-x)) def wavelet_transform(x): w_coeffs_rgb = [] # np.zeros(x.shape[3], np.prod(x.shape)) for i in range(x.shape[3]): w_coeffs_list = pywt.wavedec2(x[0,:,:,i], 'db4', level=None, mode='periodization') w_coeffs, coeff_slices = pywt.coeffs_to_array(w_coeffs_list) w_coeffs_rgb.append(w_coeffs) w_coeffs_rgb = np.array(w_coeffs_rgb) return w_coeffs_rgb, coeff_slices def inverse_wavelet_transform(w_coeffs_rgb, coeff_slices, x_shape): x_hat = np.zeros(x_shape) for i in range(w_coeffs_rgb.shape[0]): w_coeffs_list = pywt.array_to_coeffs(w_coeffs_rgb[i,:,:], coeff_slices) x_hat[0,:,:,i] = pywt.waverecn(w_coeffs_list, wavelet='db4', mode='periodization') return x_hat def soft_threshold(x, beta): y = np.maximum(0, x-beta) - np.maximum(0, -x-beta) return y # A_fun, AT_fun takes a vector (d,1) or (d,) as input def solve(y, A_fun, AT_fun, lambda_l1, reshape_img_fun, base_folder, show_img_progress=False, alpha=0.2, max_iter=100, solver_tol=1e-6): """ See Wang, Yu, Wotao Yin, and Jinshan Zeng. "Global convergence of ADMM in nonconvex nonsmooth optimization." arXiv preprint arXiv:1511.06324 (2015). It provides convergence condition: basically with large enough alpha, the program will converge. """ #result_folder = '%s/iter-imgs' % base_folder #if not os.path.exists(result_folder): #os.makedirs(result_folder) obj_lss = np.zeros(max_iter) x_zs = np.zeros(max_iter) u_norms = np.zeros(max_iter) times = np.zeros(max_iter) ATy = AT_fun(y) x_shape = ATy.shape d = np.prod(x_shape) def A_cgs_fun(x): x = np.reshape(x, x_shape, order='F') y = AT_fun(A_fun(x)) + alpha * x return vec(y) A_cgs = LinearOperator((d,d), matvec=A_cgs_fun, dtype='float') def compute_p_inv_A(b, z0): (z,info) = sp.sparse.linalg.cgs(A_cgs, vec(b), x0=vec(z0), tol=1e-3, maxiter=100) if info > 0: print 'cgs convergence to tolerance not achieved' elif info <0: print 'cgs gets illegal input or breakdown' z = np.reshape(z, x_shape, order='F') return z def A_cgs_fun_init(x): x = np.reshape(x, x_shape, order='F') y = AT_fun(A_fun(x)) return vec(y) A_cgs_init = LinearOperator((d,d), matvec=A_cgs_fun_init, dtype='float') def compute_init(b, z0): (z,info) = sp.sparse.linalg.cgs(A_cgs_init, vec(b), x0=vec(z0), tol=1e-2) if info > 0: print 'cgs convergence to tolerance not achieved' elif info <0: print 'cgs gets illegal input or breakdown' z = np.reshape(z, x_shape, order='F') return z # initialize z and u z = compute_init(ATy, ATy) u = np.zeros(x_shape) plot_normalozer = matplotlib.colors.Normalize(vmin=0.0, vmax=1.0, clip=True) start_time = timeit.default_timer() for iter in range(max_iter): # x-update net_input = z+u Wzu, wbook = wavelet_transform(net_input) q = soft_threshold(Wzu, lambda_l1/alpha) x = inverse_wavelet_transform(q, wbook, x_shape) x = np.reshape(x, x_shape) # z-update b = ATy + alpha * (x - u) z = compute_p_inv_A(b, z) # u-update u += z - x; if show_img_progress == True: fig = plt.figure('current_sol') plt.gcf().clear() fig.canvas.set_window_title('iter %d' % iter) plt.subplot(1,3,1) plt.imshow(reshape_img_fun(np.clip(x, 0.0, 1.0)), interpolation='nearest', norm=plot_normalozer) plt.title('x') plt.subplot(1,3,2) plt.imshow(reshape_img_fun(np.clip(z, 0.0, 1.0)), interpolation='nearest', norm=plot_normalozer) plt.title('z') plt.subplot(1,3,3) plt.imshow(reshape_img_fun(np.clip(net_input, 0.0, 1.0)), interpolation='nearest', norm=plot_normalozer) plt.title('netin') plt.pause(0.00001) obj_ls = 0.5 * np.sum(np.square(y - A_fun(x))) x_z = np.sqrt(np.mean(np.square(x-z))) u_norm = np.sqrt(np.mean(np.square(u))) print 'iter = %d: obj_ls = %.3e |x-z| = %.3e u_norm = %.3e' % (iter, obj_ls, x_z, u_norm) obj_lss[iter] = obj_ls x_zs[iter] = x_z u_norms[iter] = u_norm times[iter] = timeit.default_timer() - start_time ## save images #filename = '%s/%d-x.jpg' % (result_folder, iter) #sp.misc.imsave(filename, sp.misc.imresize((reshape_img_fun(np.clip(x, 0.0, 1.0))*255).astype(np.uint8), 4.0, interp='nearest')) #filename = '%s/%d-z.jpg' % (result_folder, iter) #sp.misc.imsave(filename, sp.misc.imresize((reshape_img_fun(np.clip(z, 0.0, 1.0))*255).astype(np.uint8), 4.0, interp='nearest')) #filename = '%s/%d-u.jpg' % (result_folder, iter) #sp.misc.imsave(filename, sp.misc.imresize((reshape_img_fun(np.clip(u, 0.0, 1.0))*255).astype(np.uint8), 4.0, interp='nearest')) #_ = raw_input('') if x_z < solver_tol: break infos = {'obj_lss': obj_lss, 'x_zs': x_zs, 'u_norms': u_norms, 'times': times, 'alpha':alpha, 'lambda_l1':lambda_l1, 'max_iter':max_iter, 'solver_tol':solver_tol} return (x, z, u, infos)