#!/usr/bin/python
import numpy as np
import numpy.linalg as la
import re
import math
import tensorflow as tf

__doc__ = """
This file contains various separable shrinkage functions for use in TensorFlow.
All functions perform shrinkage toward zero on each element of an input vector
    r = x + w,
where x is sparse and w is iid Gaussian noise of a known variance rvar.

All shrink_* functions are called with the signature
    xhat, dxdr = func(r, rvar, theta)

Hyperparameters are supplied via theta (which has length ranging from 1 to 5):
    shrink_soft_threshold  : 1 or 2 parameters
    shrink_bgest           : 2 parameters
    shrink_expo            : 3 parameters
    shrink_spline          : 3 parameters
    shrink_piecwise_linear : 5 parameters

A note about dxdr:
    dxdr is the per-column average derivative of xhat with respect to r.
    So if r is in Real^(NxL), then xhat is in Real^(NxL) and dxdr is in Real^L.
"""

def simple_soft_threshold(r_, lam_):
    "implement a soft threshold function y=sign(r)*max(0,abs(r)-lam)"
    lam_ = tf.maximum(lam_, 0)
    return tf.sign(r_) * tf.maximum(tf.abs(r_) - lam_, 0)

def auto_gradients(xhat, r):
    """Return the per-column average gradient of xhat with respect to r."""
    dxdr = tf.gradients(xhat, r)[0]
    dxdr = tf.reduce_mean(dxdr, 0)
    minVal = .5 / int(r.get_shape()[0])
    dxdr = tf.maximum(dxdr, minVal)
    return dxdr

def shrink_soft_threshold(r, rvar, theta):
    """
    soft threshold function
        y = sign(x) * max(0, abs(x) - theta[0]*sqrt(rvar)) * scaling
    where scaling is theta[1] (default=1).
    In other words, if theta has length 1, this is the standard soft threshold
    with no output scaling.
    """
    if len(theta.get_shape()) > 0 and theta.get_shape() != (1,):
        lam = theta[0] * tf.sqrt(rvar)
        scale = theta[1]
    else:
        lam = theta * tf.sqrt(rvar)
        scale = None
    lam = tf.maximum(lam, 0)
    arml = tf.abs(r) - lam
    xhat = tf.sign(r) * tf.maximum(arml, 0)
    dxdr = tf.reduce_mean(tf.to_float(arml > 0), 0)
    if scale is not None:
        xhat = xhat * scale
        dxdr = dxdr * scale
    return (xhat, dxdr)

def shrink_bgest(r, rvar, theta):
    """Bernoulli-Gaussian MMSE estimator.
    Perform MMSE estimation E[x|r] for
        x ~ BernoulliGaussian(lambda, xvar1)
        r|x ~ Normal(x, rvar)
    The parameters theta[0], theta[1] represent
        the variance of a nonzero x[i]:     xvar1 = abs(theta[0])
        the probability of a nonzero x[i]:  lambda = 1/(exp(theta[1])+1)
    """
    xvar1 = abs(theta[..., 0])
    loglam = theta[..., 1]  # log(1/lambda - 1)
    beta = 1 / (1 + rvar / xvar1)
    r2scale = r * r * beta / rvar
    rho = tf.exp(loglam - .5 * r2scale) * tf.sqrt(1 + xvar1 / rvar)
    rho1 = rho + 1
    xhat = beta * r / rho1
    dxdr = beta * ((1 + rho * (1 + r2scale)) / tf.square(rho1))
    dxdr = tf.reduce_mean(dxdr, 0)
    return (xhat, dxdr)
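# Illustrative sketch (not part of the original API): how the common
# xhat, dxdr = func(r, rvar, theta) calling convention is used. The helper
# name _demo_bgest and its constants are hypothetical; any of the shrink_*
# functions in this file could be substituted for shrink_bgest.
def _demo_bgest():
    r_ = tf.constant(np.linspace(-3, 3, 12).reshape(6, 2), dtype=tf.float32)   # pseudo-measurements, N=6, L=2
    rvar_ = tf.constant([.1, .1], dtype=tf.float32)                            # per-column noise variance
    theta_ = tf.constant([1., math.log(1 / .1 - 1)], dtype=tf.float32)         # xvar1=1, lambda=0.1
    xhat_, dxdr_ = shrink_bgest(r_, rvar_, theta_)
    with tf.Session() as sess:
        xhat, dxdr = sess.run((xhat_, dxdr_))
    assert xhat.shape == (6, 2)   # xhat has the same shape as r
    assert dxdr.shape == (2,)     # dxdr is one average derivative per column
    return xhat, dxdr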
def shrink_piecwise_linear(r, rvar, theta):
    """Implement the piecewise linear shrinkage function,
    with minor modifications and variance normalization.
        theta[...,0] : abscissa of first vertex, scaled by sqrt(rvar)
        theta[...,1] : abscissa of second vertex, scaled by sqrt(rvar)
        theta[...,2] : slope from origin to first vertex
        theta[...,3] : slope from first vertex to second vertex
        theta[...,4] : slope after second vertex
    """
    ab0 = theta[..., 0]
    ab1 = theta[..., 1]
    sl0 = theta[..., 2]
    sl1 = theta[..., 3]
    sl2 = theta[..., 4]

    # scale each column by sqrt(rvar)
    scale_out = tf.sqrt(rvar)
    scale_in = 1 / scale_out
    rs = tf.sign(r * scale_in)
    ra = tf.abs(r * scale_in)

    # split the piecewise linear function into regions
    rgn0 = tf.to_float(ra < ab0)
    rgn1 = tf.to_float(ra < ab1) - rgn0
    rgn2 = tf.to_float(ra >= ab1)
    xhat = scale_out * rs * (
        rgn0 * sl0 * ra +
        rgn1 * (sl1 * (ra - ab0) + sl0 * ab0) +
        rgn2 * (sl2 * (ra - ab1) + sl0 * ab0 + sl1 * (ab1 - ab0))
    )
    dxdr = sl0 * rgn0 + sl1 * rgn1 + sl2 * rgn2
    dxdr = tf.reduce_mean(dxdr, 0)
    return (xhat, dxdr)

def pwlin_grid(r_, rvar_, theta_, dtheta=.75):
    """piecewise linear with noise-adaptive grid spacing.
    returns xhat, dxdr
    where
        q = r/dtheta/sqrt(rvar)
        xhat = r * interp(q, theta)

    all but the last dimensions of theta must broadcast to r_,
    e.g. r.shape = (500,1000) is compatible with theta.shape = (500,1,7)
    """
    ntheta = int(theta_.get_shape()[-1])
    scale_ = dtheta / tf.sqrt(rvar_)
    ars_ = tf.clip_by_value(tf.expand_dims(tf.abs(r_) * scale_, -1), 0.0, ntheta - 1.0)
    centers_ = tf.constant(np.arange(ntheta), dtype=tf.float32)
    outer_distance_ = tf.maximum(0., 1.0 - tf.abs(ars_ - centers_))  # new dimension for distance to closest bin centers (or center)
    gain_ = tf.reduce_sum(theta_ * outer_distance_, axis=-1)  # apply the (learnable) gain
    xhat_ = gain_ * r_
    dxdr_ = tf.gradients(xhat_, r_)[0]
    return (xhat_, dxdr_)

def shrink_expo(r, rvar, theta):
    """ Exponential shrinkage function
        xhat = r*(theta[1] + theta[2]*exp( -r^2/(2*theta[0]^2*rvar) ))
    """
    r2 = tf.square(r)
    den = -1 / (2 * tf.square(theta[0]) * rvar)
    rho = tf.exp(r2 * den)
    xhat = r * (theta[1] + theta[2] * rho)
    return (xhat, auto_gradients(xhat, r))

def shrink_spline(r, rvar, theta):
    """ Spline-based shrinkage function """
    scale = theta[0] * tf.sqrt(rvar)
    rs = tf.sign(r)
    ar = tf.abs(r / scale)
    ar2 = tf.square(ar)
    ar3 = ar * ar2
    reg1 = tf.to_float(ar < 1)
    reg2 = tf.to_float(ar < 2) - reg1
    ar_m2 = 2 - ar
    ar_m2_p2 = tf.square(ar_m2)
    ar_m2_p3 = ar_m2 * ar_m2_p2
    beta3 = ((2. / 3 - ar2 + .5 * ar3) * reg1 + (1. / 6 * ar_m2_p3) * reg2)
    xhat = r * (theta[1] + theta[2] * beta3)
    return (xhat, auto_gradients(xhat, r))
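# A plain-numpy restatement of the three-region logic in shrink_piecwise_linear,
# included only as an illustrative cross-check; the name pwlin_numpy and its
# scalar-theta assumption are not part of the original module.
def pwlin_numpy(r, rvar, theta):
    ab0, ab1, sl0, sl1, sl2 = theta
    scale_out = np.sqrt(rvar)
    rs = np.sign(r)
    ra = np.abs(r) / scale_out
    rgn0 = (ra < ab0).astype(float)          # |r|/sqrt(rvar) below the first vertex
    rgn1 = (ra < ab1).astype(float) - rgn0   # between the two vertices
    rgn2 = (ra >= ab1).astype(float)         # beyond the second vertex
    return scale_out * rs * (rgn0 * sl0 * ra
                             + rgn1 * (sl1 * (ra - ab0) + sl0 * ab0)
                             + rgn2 * (sl2 * (ra - ab1) + sl0 * ab0 + sl1 * (ab1 - ab0)))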
def get_shrinkage_function(name):
    "retrieve a shrinkage function and some (probably awful) default parameter values"
    try:
        return {
            'soft':   (shrink_soft_threshold, (1., 1.)),
            'bg':     (shrink_bgest, (1, math.log(1 / .1 - 1))),
            'pwlin':  (shrink_piecwise_linear, (2, 4, 0.1, 1.5, .95)),
            'pwgrid': (pwlin_grid, np.linspace(.1, 1, 15).astype(np.float32)),
            'expo':   (shrink_expo, (2.5, .9, -1)),
            'spline': (shrink_spline, (3.7, .9, -1.5)),
        }[name]
    except KeyError, ke:
        raise ValueError('unrecognized shrink function %s' % name)

def tfcf(v):
    "return a tensorflow constant float version of v"
    return tf.constant(v, dtype=tf.float32)

def tfvar(v):
    "return a tensorflow variable float version of v"
    return tf.Variable(v, dtype=tf.float32)

def nmse(x1, x2):
    "return the normalized mean squared error between 2 numpy arrays"
    xdif = x1 - x2
    return 2 * (xdif * xdif).sum() / ((x1 * x1).sum() + (x2 * x2).sum())

def test_func(shrink_func, theta, **kwargs):
    # repeat the same experiment
    tf.reset_default_graph()
    tf.set_random_seed(kwargs.get('seed', 1))

    N = kwargs.get('N', 200)
    L = kwargs.get('L', 400)
    tol = kwargs.get('tol', 1e-6)
    step = kwargs.get('step', 1e-4)

    shape = (N, L)
    xvar_ = tfcf(kwargs.get('xvar1', 1))
    pnz_ = tfcf(kwargs.get('pnz', .1))
    rvar = np.ones(L) * kwargs.get('rvar', .1)
    rvar_ = tfcf(rvar)

    # generate a Bernoulli-Gaussian x and its noisy observation r = x + w
    gx = tf.to_float(tf.random_uniform(shape) < pnz_) * tf.random_normal(shape, stddev=tf.sqrt(xvar_), dtype=tf.float32)
    gr = gx + tf.random_normal(shape, stddev=tf.sqrt(rvar_), dtype=tf.float32)
    x_ = tf.placeholder(gx.dtype, gx.get_shape())
    r_ = tf.placeholder(gr.dtype, gr.get_shape())
    theta_ = tfvar(theta)
    xhat_, dxdr_ = shrink_func(r_, rvar_, theta_)
    loss = tf.nn.l2_loss(xhat_ - x_)
    optimize_theta = tf.train.AdamOptimizer(step).minimize(loss, var_list=[theta_])

    # calculate an empirical gradient for comparison
    dr_ = tfcf(1e-4)
    dxdre_ = tf.reduce_mean((shrink_func(r_ + .5 * dr_, rvar_, theta_)[0] - shrink_func(r_ - .5 * dr_, rvar_, theta_)[0]) / dr_, 0)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        (x, r) = sess.run((gx, gr))
        fd = {x_: x, r_: r}
        loss_prev = float('inf')
        for i in range(500):
            for j in range(50):
                sess.run(optimize_theta, fd)
            loss_cur, theta_cur = sess.run((loss, theta_), fd)
            #print 'loss=%s, theta=%s' % (str(loss_cur),str(theta_cur))
            if (1 - loss_cur / loss_prev) < tol:
                break
            loss_prev = loss_cur
        xhat, dxdr, theta, dxdre = sess.run((xhat_, dxdr_, theta_, dxdre_), fd)
    assert xhat.shape == (N, L)
    assert dxdr.shape == (L,)  # MMV-specific -- we assume one average gradient per column
    assert nmse(dxdr, dxdre) < tol
    tf.reset_default_graph()
    estname = re.sub('.*shrink_([^ ]*).*', '\\1', repr(shrink_func))
    print '#### %s loss=%g \ttheta=%s' % (estname, loss_cur, repr(theta))
    if False:
        import matplotlib.pyplot as plt
        plt.figure(1)
        plt.plot(r.reshape(-1), xhat.reshape(-1), 'b.')
        plt.plot(r, xhat, '.')
        plt.show()
    return (x, r, xhat, rvar)

def show_shrinkage(shrink_func, theta, **kwargs):
    tf.reset_default_graph()
    tf.set_random_seed(kwargs.get('seed', 1))

    N = kwargs.get('N', 500)
    L = kwargs.get('L', 4)
    nsigmas = kwargs.get('sigmas', 10)
    shape = (N, L)
    rvar = 1e-4
    r = np.reshape(np.linspace(0, nsigmas, N * L) * math.sqrt(rvar), shape)
    r_ = tfcf(r)
    rvar_ = tfcf(np.ones(L) * rvar)
    xhat_, dxdr_ = shrink_func(r_, rvar_, tfcf(theta))
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        xhat = sess.run(xhat_)
    import matplotlib.pyplot as plt
    plt.figure(1)
    plt.plot(r.reshape(-1), r.reshape(-1), 'y')
    plt.plot(r.reshape(-1), xhat.reshape(-1), 'b')
    if kwargs.has_key('title'):
        plt.suptitle(kwargs['title'])
    plt.show()
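# Usage sketch (illustrative only; the wrapper name _demo_pipeline and the 'bg'
# default are not part of the original module) showing how get_shrinkage_function,
# test_func and show_shrinkage fit together.
def _demo_pipeline(name='bg'):
    f, theta0 = get_shrinkage_function(name)
    x, r, xhat, rvar = test_func(f, theta0)   # fit theta to simulated data and check dxdr against a finite difference
    show_shrinkage(f, theta0, title='default %s shrinkage' % name)
    return x, r, xhat, rvar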
if __name__ == "__main__":
    import sys
    import getopt

    usage = """
    -h          : help
    -p file     : load problem definition parameters from npz file
    -s file     : show shrinkage for theta_0, theta_1, ... values loaded from an npz file
    -f function : use the named shrinkage function, one of {soft,bg,pwlin,pwgrid,expo,spline}
    """
    try:
        opts, args = getopt.getopt(sys.argv[1:], 'hp:s:f:')
        opts = dict(opts)
    except getopt.GetoptError, e:
        opts = {'-h': True}

    if opts.has_key('-h'):
        sys.stderr.write(usage)
        sys.exit()

    shrinkage_name = opts.get('-f', 'soft')
    f, theta = get_shrinkage_function(shrinkage_name)

    if opts.has_key('-s'):
        D = dict(np.load(opts['-s']).items())
        t = 0
        while D.has_key('theta_%d' % t):
            theta_t = D['theta_%d' % t]
            show_shrinkage(f, theta_t, title='shrinkage=%s, theta_%d=%s' % (shrinkage_name, t, theta_t))
            t += 1
    else:
        show_shrinkage(f, theta)

    """
    test_func(shrink_bgest, (1,math.log(1/.1-1)), **parms)
    test_func(shrink_soft_threshold, (1.7,1.2), **parms)
    test_func(shrink_piecwise_linear, (2,4,0.1,1.5,.95), **parms)
    test_func(shrink_expo, (2.5,.9,-1), **parms)
    test_func(shrink_spline, (3.7,.9,-1.5), **parms)
    """
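    # A sketch of the kwargs dict `parms` that the quoted-out test_func calls above
    # assume; these values simply mirror test_func's own defaults and are only
    # illustrative, not tuned:
    #     parms = dict(N=200, L=400, pnz=.1, xvar1=1., rvar=.1, tol=1e-6, step=1e-4, seed=1)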