#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""
file : main.py
author: Xiaohan Chen
email : chernxh@tamu.edu
last_modified: 2018-10-13

Main script. Start running model from main.py.
"""

import os , sys
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' # BE QUIET!!!!

# timing
import time
from datetime import timedelta

from config import get_config
import utils.prob as problem
import utils.data as data
import utils.train as train

import numpy as np
import tensorflow as tf
try :
    # Optional dependency: only the compressive-sensing image utilities need
    # scikit-learn; a missing install is tolerated here (best-effort import).
    from sklearn.feature_extraction.image \
            import extract_patches_2d, reconstruct_from_patches_2d
except Exception as e :
    pass


def setup_model(config , **kwargs) :
    """Build the network selected by ``config.net`` and derive its name.

    Side effects on ``config``: sets ``config.model`` (the experiment
    identifier string) and ``config.modelfn`` / ``config.resfn`` (checkpoint
    and result paths); the ``robust_ALISTA`` branch instead sets
    ``config.encoder`` / ``config.decoder`` plus their load/save paths.

    ``kwargs`` carries whatever the chosen network constructor needs:
    ``A`` (dictionary matrix), ``Phi`` / ``D`` (sensing matrix and dictionary
    for the CS models), ``filters`` (for the conv models), ``Binit`` / ``Q``
    (for the encoder models).

    Returns the constructed model, or the pair ``(encoder, decoder)`` when
    ``config.net == "robust_ALISTA"``.
    """
    untiedf = 'u' if config.untied else 't'
    coordf = 'c' if config.coord else 's'

    if config.net == 'LISTA' :
        """LISTA"""
        config.model = ("LISTA_T{T}_lam{lam}_{untiedf}_{coordf}_{exp_id}"
                        .format (T=config.T, lam=config.lam,
                                 untiedf=untiedf, coordf=coordf,
                                 exp_id=config.exp_id))
        from models.LISTA import LISTA
        model = LISTA (kwargs['A'], T=config.T, lam=config.lam,
                       untied=config.untied, coord=config.coord,
                       scope=config.scope)

    if config.net == 'LAMP' :
        """LAMP"""
        config.model = ("LAMP_T{T}_lam{lam}_{untiedf}_{coordf}_{exp_id}"
                        .format (T=config.T, lam=config.lam,
                                 untiedf=untiedf, coordf=coordf,
                                 exp_id=config.exp_id))
        from models.LAMP import LAMP
        model = LAMP (kwargs['A'], T=config.T, lam=config.lam,
                      untied=config.untied, coord=config.coord,
                      scope=config.scope)

    if config.net == 'LIHT' :
        """LIHT"""
        # NOTE(review): `p` is not defined in this function's scope and
        # `cord=` looks like a typo for `coord=`; this branch raises
        # NameError as written, and `config.model` is never set here.
        # Confirm against models/LIHT before using this network.
        from models.LIHT import LIHT
        model = LIHT (p, T=config.T, lam=config.lam, y_=p.y_ , x0_=None ,
                      untied=config.untied , cord=config.coord)

    if config.net == 'LISTA_cp' :
        """LISTA-CP"""
        config.model = ("LISTA_cp_T{T}_lam{lam}_{untiedf}_{coordf}_{exp_id}"
                        .format (T=config.T, lam=config.lam,
                                 untiedf=untiedf, coordf=coordf,
                                 exp_id=config.exp_id))
        from models.LISTA_cp import LISTA_cp
        model = LISTA_cp (kwargs['A'], T=config.T, lam=config.lam,
                          untied=config.untied, coord=config.coord,
                          scope=config.scope)

    if config.net == 'LISTA_ss' :
        """LISTA-SS"""
        config.model = ("LISTA_ss_T{T}_lam{lam}_p{p}_mp{mp}_"
                        "{untiedf}_{coordf}_{exp_id}"
                        .format (T=config.T, lam=config.lam,
                                 p=config.percent, mp=config.max_percent,
                                 untiedf=untiedf, coordf=coordf,
                                 exp_id=config.exp_id))
        from models.LISTA_ss import LISTA_ss
        model = LISTA_ss (kwargs['A'], T=config.T, lam=config.lam,
                          percent=config.percent,
                          max_percent=config.max_percent,
                          untied=config.untied , coord=config.coord,
                          scope=config.scope)

    if config.net == 'LISTA_cpss' :
        """LISTA-CPSS"""
        config.model = ("LISTA_cpss_T{T}_lam{lam}_p{p}_mp{mp}_"
                        "{untiedf}_{coordf}_{exp_id}"
                        .format (T=config.T, lam=config.lam,
                                 p=config.percent, mp=config.max_percent,
                                 untiedf=untiedf, coordf=coordf,
                                 exp_id=config.exp_id))
        from models.LISTA_cpss import LISTA_cpss
        model = LISTA_cpss (kwargs['A'], T=config.T, lam=config.lam,
                            percent=config.percent,
                            max_percent=config.max_percent,
                            untied=config.untied , coord=config.coord,
                            scope=config.scope)

    if config.net == 'TiLISTA':
        """TiLISTA"""
        config.model = ("TiLISTA_T{T}_lam{lam}_p{p}_mp{mp}_"
                        "{coordf}_{exp_id}"
                        .format (T=config.T, lam=config.lam,
                                 p=config.percent, mp=config.max_percent,
                                 coordf=coordf, exp_id=config.exp_id))
        from models.TiLISTA import TiLISTA
        # Note that TiLISTA is just LISTA-CPSS with tied weight in all layers.
        model = TiLISTA(kwargs['A'], T=config.T, lam=config.lam,
                        percent=config.percent,
                        max_percent=config.max_percent,
                        coord=config.coord, scope=config.scope)

    if config.net == "ALISTA":
        """ALISTA"""
        config.model = ("ALISTA_T{T}_lam{lam}_p{p}_mp{mp}_{W}_{coordf}_{exp_id}"
                        .format(T=config.T, lam=config.lam,
                                p=config.percent, mp=config.max_percent,
                                W=os.path.basename(config.W),
                                coordf=coordf, exp_id=config.exp_id))
        # ALISTA uses an analytically pre-computed weight matrix W from disk.
        W = np.load(config.W)
        print("Pre-calculated weight W loaded from {}".format(config.W))
        from models.ALISTA import ALISTA
        model = ALISTA(kwargs['A'], T=config.T, lam=config.lam, W=W,
                       percent=config.percent,
                       max_percent=config.max_percent,
                       coord=config.coord, scope=config.scope)

    if config.net == 'LISTA_cs':
        """LISTA-CS"""
        config.model = ("LISTA_cs_T{T}_lam{lam}_llam{llam}_"
                        "{untiedf}_{coordf}_{exp_id}"
                        .format (T=config.T, lam=config.lam,
                                 llam=config.lasso_lam, untiedf=untiedf,
                                 coordf=coordf, exp_id=config.exp_id))
        from models.LISTA_cs import LISTA_cs
        model = LISTA_cs (kwargs['Phi'], kwargs['D'], T=config.T,
                          lam=config.lam, untied=config.untied,
                          coord=config.coord, scope=config.scope)

    if config.net == 'LISTA_ss_cs' :
        """LISTA-SS-CS"""
        config.model = ("LISTA_ss_cs_T{T}_lam{lam}_p{p}_mp{mp}_llam{llam}_"
                        "{untiedf}_{coordf}_{exp_id}"
                        .format (T=config.T, lam=config.lam,
                                 p=config.percent, mp=config.max_percent,
                                 llam=config.lasso_lam, untiedf=untiedf,
                                 coordf=coordf, exp_id=config.exp_id))
        from models.LISTA_ss_cs import LISTA_ss_cs
        model = LISTA_ss_cs (kwargs['Phi'], kwargs['D'], T=config.T,
                             lam=config.lam, percent=config.percent,
                             max_percent=config.max_percent,
                             untied=config.untied, coord=config.coord,
                             scope=config.scope)

    if config.net == 'LISTA_cpss_cs' :
        """LISTA-CPSS-CS"""
        config.model = ("LISTA_cpss_cs_T{T}_lam{lam}_p{p}_mp{mp}_llam{llam}_"
                        "{untiedf}_{coordf}_{exp_id}"
                        .format (T=config.T, lam=config.lam,
                                 p=config.percent, mp=config.max_percent,
                                 llam=config.lasso_lam, untiedf=untiedf,
                                 coordf=coordf, exp_id=config.exp_id))
        from models.LISTA_cpss_cs import LISTA_cpss_cs
        model = LISTA_cpss_cs (kwargs['Phi'], kwargs['D'], T=config.T,
                               lam=config.lam, percent=config.percent,
                               max_percent=config.max_percent,
                               untied=config.untied, coord=config.coord,
                               scope=config.scope)

    if config.net == 'LISTA_cp_conv':
        """LISTA-CP-CONV"""
        # `coordf` is passed to .format below but unused by this template;
        # str.format silently ignores the extra keyword.
        config.model = ("LISTA_cp_conv_T{T}_lam{lam}_alpha{alpha}_"
                        "sigma{sigma}_{untiedf}_{exp_id}.npz"
                        .format(T=config.T, lam=config.lam,
                                alpha=config.conv_alpha, sigma=config.sigma,
                                untiedf=untiedf, coordf=coordf,
                                exp_id=config.exp_id))
        from models.LISTA_cp_conv import LISTA_cp_conv
        model = LISTA_cp_conv(kwargs['filters'], T=config.T,
                              lam=config.lam, alpha=config.conv_alpha,
                              untied=config.untied, scope=config.scope)

    if config.net == 'ALISTA_conv':
        """ALISTA-CONV"""
        config.model = ("ALISTA_conv_T{T}_lam{lam}_alpha{alpha}_"
                        "sigma{sigma}_{exp_id}.npz"
                        .format(T=config.T, lam=config.lam,
                                alpha=config.conv_alpha, sigma=config.sigma,
                                exp_id=config.exp_id))
        W = np.load(config.W)
        print("Pre-calculated weight W loaded from {}".format(config.W))
        from models.ALISTA_conv import ALISTA_conv
        model = ALISTA_conv(kwargs['filters'], W=W, T=config.T,
                            lam=config.lam, alpha=config.conv_alpha,
                            scope=config.scope)

    if config.net == "AtoW_grad":
        """AtoW_grad"""
        config.model = ("AtoW_grad_eT{eT}_Binit-{Binit}_eta{eta}_loss-{loss}_ps{ps}_lr{lr}_{id}"
                        .format(eT=config.eT, Binit=config.encoder_Binit,
                                eta=config.eta, loss=config.encoder_loss,
                                ps=config.encoder_psigma,
                                lr=config.encoder_pre_lr, id=config.exp_id))
        from models.AtoW_grad import AtoW_grad
        model = AtoW_grad(config.M, config.N, config.eT,
                          Binit=kwargs["Binit"], eta=config.eta,
                          loss=config.encoder_loss, Q=kwargs["Q"],
                          scope=config.scope)

    if config.net == "robust_ALISTA":
        """Robust ALISTA"""
        config.encoder = ("AtoW_grad_eT{eT}_Binit-{Binit}_eta{eta}_loss-{loss}_ps{ps}_lr{lr}_{id}"
                          .format(eT=config.eT, Binit=config.encoder_Binit,
                                  eta=config.eta, loss=config.encoder_loss,
                                  ps=config.encoder_psigma,
                                  lr=config.encoder_pre_lr,
                                  id=config.exp_id))
        config.decoder = ("ALISTA_robust_T{T}_lam{lam}_p{p}_mp{mp}_{W}_{coordf}_{exp_id}"
                          .format(T=config.T, lam=config.lam,
                                  p=config.percent, mp=config.max_percent,
                                  W=os.path.basename(config.W),
                                  coordf=coordf, exp_id=config.exp_id))
        # set up encoder
        from models.AtoW_grad import AtoW_grad
        encoder = AtoW_grad(config.M, config.N, config.eT,
                            Binit=kwargs["Binit"], eta=config.eta,
                            loss=config.encoder_loss, Q=kwargs["Q"],
                            scope=config.encoder_scope)
        # set up decoder
        from models.ALISTA_robust import ALISTA_robust
        decoder = ALISTA_robust(M=config.M, N=config.N, T=config.T,
                                percent=config.percent,
                                max_percent=config.max_percent,
                                coord=config.coord,
                                scope=config.decoder_scope)

        model_desc = ("robust_" + config.encoder + '_' + config.decoder +
                      "_elr{}_dlr{}_psmax{}_psteps{}_{}"
                      .format(config.encoder_lr, config.decoder_lr,
                              config.psigma_max, config.psteps,
                              config.exp_id))
        model_dir = os.path.join(config.expbase, model_desc)
        config.resfn = os.path.join(config.resbase, model_desc)
        # Training creates the experiment folder; testing requires it to exist.
        if not os.path.exists(model_dir):
            if config.test:
                raise ValueError("Testing folder {} not existed".format(model_dir))
            else:
                os.makedirs(model_dir)

        # Paths to load the separately pre-trained encoder/decoder from, and
        # paths to save the jointly trained pair into.
        config.enc_load = os.path.join(config.expbase, config.encoder)
        config.dec_load = os.path.join(
            config.expbase, config.decoder.replace("_robust", ""))
        config.encoderfn = os.path.join(model_dir, config.encoder)
        config.decoderfn = os.path.join(model_dir, config.decoder)
        return encoder, decoder

    # NOTE(review): if `config.net` matched none of the branches above,
    # `model` is unbound and the `return model` below raises NameError.
    config.modelfn = os.path.join(config.expbase, config.model)
    config.resfn = os.path.join(config.resbase, config.model)
    print ("model disc:", config.model)

    return model


############################################################
######################   Training    #######################
############################################################

def run_train(config) :
    """Dispatch to the trainer that matches ``config.task_type``."""
    if config.task_type == "sc":
        run_sc_train(config)
    elif config.task_type == "cs":
        run_cs_train(config)
    elif config.task_type == "denoise":
        run_denoise_train(config)
    elif config.task_type == "encoder":
        run_encoder_train(config)
    elif config.task_type == "robust":
        run_robust_train(config)


def run_sc_train(config) :
    """Train a sparse-coding model on the synthetic problem in ``config.probfn``."""
    """Load problem."""
    if not os.path.exists(config.probfn):
        raise ValueError ("Problem file not found.")
    else:
        p = problem.load_problem(config.probfn)

    """Set up model."""
    model = setup_model (config, A=p.A)

    """Set up input."""
    config.SNR = np.inf if config.SNR == 'inf' else float (config.SNR)
    y_, x_, y_val_, x_val_ = (
        train.setup_input_sc (
            config.test, p, config.tbs, config.vbs, config.fixval,
            config.supp_prob, config.SNR, config.magdist,
            **config.distargs))

    """Set up training."""
    stages = train.setup_sc_training (
            model, y_, x_, y_val_, x_val_, None,
            config.init_lr, config.decay_rate, config.lr_decay)

    tfconfig = tf.ConfigProto (allow_soft_placement=True)
    tfconfig.gpu_options.allow_growth = True
    with tf.Session (config=tfconfig) as sess:
        # graph initialization
        sess.run (tf.global_variables_initializer ())
        # start timer
        start = time.time ()
        # train model
        model.do_training(sess, stages, config.modelfn, config.scope,
                          config.val_step, config.maxit, config.better_wait)
        # end timer
        end = time.time ()
        elapsed = end - start
        print ("elapsed time of training = " +
               str (timedelta (seconds=elapsed)))

# end of run_sc_train


def run_cs_train (config) :
    """Train a compressive-sensing model (sensing matrix Phi, dictionary D)."""
    """Load dictionary and sensing matrix."""
    Phi = np.load (config.sensing)['A']
    D = np.load (config.dict)

    """Set up model."""
    model = setup_model (config, Phi=Phi, D=D)

    """Set up inputs."""
    y_, f_, y_val_, f_val_ = train.setup_input_cs(config.train_file,
                                                  config.val_file,
                                                  config.tbs, config.vbs)

    """Set up training."""
    stages = train.setup_cs_training (
        model, y_, f_, y_val_, f_val_, None, config.init_lr,
        config.decay_rate, config.lr_decay, config.lasso_lam)

    """Start training."""
    tfconfig = tf.ConfigProto (allow_soft_placement=True)
    tfconfig.gpu_options.allow_growth = True
    with tf.Session (config=tfconfig) as sess:
        # graph initialization
        sess.run (tf.global_variables_initializer ())
        # start timer
        start = time.time ()
        # train model
        model.do_training (sess, stages, config.modelfn, config.scope,
                           config.val_step, config.maxit,
                           config.better_wait)
        # end timer
        end = time.time ()
        elapsed = end - start
        print ("elapsed time of training = " +
               str (timedelta (seconds=elapsed)))

# end of run_cs_train


def run_denoise_train (config) :
    """Train a convolutional denoising model on BSD500 crops."""
    """Load problem."""
    # Shadows the module-level `utils.prob` import with the conv variant.
    import utils.prob_conv as problem
    if not os.path.exists (config.probfn):
        raise ValueError ("Problem file not found.")
    else:
        p = problem.load_problem (config.probfn)

    """Set up model."""
    model = setup_model (config, filters=p._fs)

    """Set up input."""
    # training
    clean_ = data.bsd500_denoise_inputs(config.data_folder,
                                        config.train_file, config.tbs,
                                        config.height_crop,
                                        config.width_crop,
                                        config.num_epochs)
    clean_.set_shape((config.tbs, *clean_.get_shape()[1:],))
    # validation
    clean_val_ = data.bsd500_denoise_inputs(config.data_folder,
                                            config.val_file, config.vbs,
                                            config.height_crop,
                                            config.width_crop, 1)
    clean_val_.set_shape((config.vbs, *clean_val_.get_shape()[1:],))
    # add noise
    noise_ = tf.random_normal(clean_.shape, stddev=config.denoise_std,
                              dtype=tf.float32)
    noise_val_ = tf.random_normal(clean_val_.shape,
                                  stddev=config.denoise_std,
                                  dtype=tf.float32)
    noisy_ = clean_ + noise_
    noisy_val_= clean_val_ + noise_val_
    # fix validation set: freezing the tensors into variables pins one fixed
    # validation batch (and its noise) for the whole run
    with tf.name_scope ('input'):
        clean_val_ = tf.get_variable(name='clean_val', dtype=tf.float32,
                                     initializer=clean_val_)
        noisy_val_ = tf.get_variable(name='noisy_val', dtype=tf.float32,
                                     initializer=noisy_val_)

    """Set up training."""
    stages = train.setup_denoise_training(
        model, noisy_, clean_, noisy_val_, clean_val_, None,
        config.init_lr, config.decay_rate, config.lr_decay)

    tfconfig = tf.ConfigProto (allow_soft_placement=True)
    tfconfig.gpu_options.allow_growth = True
    with tf.Session (config=tfconfig) as sess:
        # graph initialization
        sess.run (tf.global_variables_initializer ())
        # start timer
        start = time.time ()
        # train model
        model.do_training(sess, stages, config.modelfn, config.scope,
                          config.val_step, config.maxit, config.better_wait)
        # end timer
        end = time.time ()
        elapsed = end - start
        print ("elapsed time of training = " +
               str (timedelta (seconds=elapsed)))

# end of run_denoise_train


def run_encoder_train(config):
    """Pre-train the AtoW encoder that maps (perturbed) dictionaries A to
    analytic weights W, minimizing a loss on ``A^T W - I``."""
    """Load problem."""
    if not os.path.exists(config.probfn):
        raise ValueError("Problem file not found.")
    else:
        p = problem.load_problem(config.probfn)

    """Load the Q reweighting matrix."""
    if config.Q is None: # use default Q reweighting matrix
        if "re" in config.encoder_loss: # if using reweighted loss
            Q = np.sqrt((np.ones(shape=(config.N, config.N),
                                 dtype=np.float32) +
                         np.eye(config.N, dtype=np.float32) *
                         (config.N - 2)))
        else:
            Q = None
    elif os.path.exists(config.Q) and config.Q.endswith(".npy"):
        Q = np.load(config.Q)
        assert Q.shape == (config.N, config.N)
    else:
        raise ValueError("Invalid parameter `--Q`\n"
                         "A valid `--Q` parameter should be one of the following:\n"
                         "    1) omitted for default value as in the paper;\n"
                         "    2) path/to/your/npy/file that contains your Q matrix.\n")

    """Binit matrix."""
    if config.encoder_Binit == "default":
        Binit = p.A
    elif config.Binit in ["uniform", "normal"]:
        # NOTE(review): this branch leaves `Binit` unbound, so the
        # setup_model call below raises NameError; it also reads
        # `config.Binit` while the branch above reads
        # `config.encoder_Binit` -- confirm the intended attribute.
        pass
    else:
        raise ValueError("Invalid parameter `--Binit`\n"
                         "A valid `--Binit` parameter should be one of the following:\n"
                         "    1) omitted for default value `p.A`;\n"
                         "    2) `normal` or `uniform`.\n")

    """Set up model."""
    model = setup_model(config, Binit=Binit, Q=Q)
    print("The trained model will be saved in {}".format(config.model))

    """Set up training."""
    from utils.tf import get_loss_func, bmxbm, mxbm
    with tf.name_scope ('input'):
        A_ = tf.constant(p.A, dtype=tf.float32)
        # A fresh batch of perturbed dictionaries every step.
        perturb_ = tf.random.normal(shape=(config.Abs, config.M, config.N),
                                    mean=0.0,
                                    stddev=config.encoder_psigma,
                                    dtype=tf.float32)
        Ap_ = A_ + perturb_
        # Column-normalize each perturbed dictionary.
        Ap_ = Ap_ / tf.sqrt(tf.reduce_sum(tf.square( Ap_ ),
                                          axis=1, keepdims=True))
        Apt_ = tf.transpose(Ap_, [0,2,1])
    W_ = model.inference(Ap_)

    """Set up loss."""
    eye_ = tf.eye(config.N, batch_shape=[config.Abs], dtype=tf.float32)
    residual_ = bmxbm(Apt_, W_, batch_first=True) - eye_
    loss_func = get_loss_func(config.encoder_loss, model._Q_)
    loss_ = loss_func(residual_)

    # fix validation set
    Ap_val_ = tf.get_variable(name='Ap_val', dtype=tf.float32,
                              initializer=Ap_, trainable=False)
    Apt_val_ = tf.transpose(Ap_val_, [0,2,1])
    W_val_ = model.inference(Ap_val_)
    # validation loss
    residual_val_ = bmxbm(Apt_val_, W_val_, batch_first=True) - eye_
    loss_val_ = loss_func(residual_val_)

    """Set up optimizer."""
    global_step = tf.Variable(0, trainable=False)
    lr = tf.train.exponential_decay(config.encoder_lr, global_step,
                                    5000, 0.75, staircase=True)
    learning_step = (tf.train.AdamOptimizer(lr)
                     .minimize(loss_, global_step=global_step))

    # create session and initialize the graph
    tfconfig = tf.ConfigProto (allow_soft_placement=True)
    tfconfig.gpu_options.allow_growth = True
    with tf.Session (config=tfconfig) as sess:
        sess.run (tf.global_variables_initializer ())
        # start timer
        start = time.time ()
        for i in range (config.maxit):
            # training step
            _, loss = sess.run([learning_step, loss_])
            # validation step
            if i % config.val_step == 0:
                # validation step
                loss_val = sess.run(loss_val_)
                sys.stdout.write (
                    "\ri={i:<7d} | train_loss={train_loss:.6f} | "
                    "loss_val={loss_val:.6f}"
                    .format(i=i, train_loss=loss, loss_val=loss_val))
                sys.stdout.flush()
        # end timer
        end = time.time()
        elapsed = end - start
        print("elapsed time of training = " +
              str(timedelta(seconds=elapsed)))
        train.save_trainable_variables (sess, config.modelfn, config.scope)
        print("model saved to {}".format(config.modelfn))

# end of run_encoder_train


def run_robust_train(config):
    """Curriculum training under perturbed dictionaries: sweep the
    perturbation stddev `psigma` upward, fine-tuning the decoder (and, for
    ``robust_ALISTA``, jointly the encoder) at each level."""
    """Load problem."""
    if not os.path.exists(config.probfn):
        raise ValueError("Problem file not found.")
    else:
        p = problem.load_problem(config.probfn)

    """Set up input."""
    # `psigma` is a list of standard deviations for curriculum learning
    psigmas = np.linspace(0, config.psigma_max, config.psteps)[1:]
    psigma_ = tf.placeholder(dtype=tf.float32, shape=())
    with tf.name_scope ('input'):
        Ap_, y_, x_ = train.setup_input_robust(p.A, psigma_, config.msigma,
                                               p.pnz, config.Abs,
                                               config.xbs)
        if config.net != "robust_ALISTA":
            # If not joint robust training
            # reshape y_ into shape (m, Abs * xbs)
            # reshape x_ into shape (n, Abs * xbs)
            y_ = tf.reshape(tf.transpose(y_, [1, 0, 2]), (config.M, -1))
            x_ = tf.reshape(tf.transpose(x_, [1, 0, 2]), (config.N, -1))
        # fix validation set
        Ap_val_ = tf.get_variable(name="Ap_val", dtype=tf.float32,
                                  initializer=Ap_)
        y_val_ = tf.get_variable(name="y_val", dtype=tf.float32,
                                 initializer=y_)
        x_val_ = tf.get_variable(name="x_val", dtype=tf.float32,
                                 initializer=x_)

    """Set up model."""
    if config.net == "robust_ALISTA":
        """Load the Q reweighting matrix."""
        if config.Q is None: # use default Q reweighting matrix
            if "re" in config.encoder_loss: # if using reweighted loss
                Q = np.sqrt((np.ones(shape=(config.N, config.N),
                                     dtype=np.float32) +
                             np.eye(config.N, dtype=np.float32) *
                             (config.N - 2)))
            else:
                Q = None
        elif os.path.exists(config.Q) and config.Q.endswith(".npy"):
            Q = np.load(config.Q)
            assert Q.shape == (config.N, config.N)
        else:
            raise ValueError("Invalid parameter `--Q`\n"
                             "A valid `--Q` parameter should be one of the following:\n"
                             "    1) omitted for default value as in the paper;\n"
                             "    2) path/to/your/npy/file that contains your Q matrix.\n")
        """Binit matrix."""
        if config.encoder_Binit == "default":
            Binit = p.A
        elif config.Binit in ["uniform", "normal"]:
            # NOTE(review): leaves `Binit` unbound (NameError at the
            # setup_model call below); also reads `config.Binit` while the
            # branch above reads `config.encoder_Binit` -- confirm.
            pass
        else:
            raise ValueError("Invalid parameter `--Binit`\n"
                             "A valid `--Binit` parameter should be one of the following:\n"
                             "    1) omitted for default value `p.A`;\n"
                             "    2) `normal` or `uniform`.\n")
        encoder, decoder = setup_model(config, Q=Q, Binit=Binit)
        W_ = encoder.inference(Ap_)
        W_val_ = encoder.inference(Ap_val_)
        xh_ = decoder.inference(y_, Ap_, W_, x0_=None)[-1]
        xh_val_ = decoder.inference(y_val_, Ap_val_, W_val_, x0_=None)[-1]
    else:
        decoder = setup_model(config, A=p.A)
        xh_ = decoder.inference(y_, None)[-1]
        xh_val_ = decoder.inference(y_val_, None)[-1]
        config.dec_load = config.modelfn
        config.decoder = ("robust_" + config.model +
                          '_ps{ps}_nsteps{nsteps}_ms{ms}_lr{lr}'
                          .format(ps=config.psigma_max,
                                  nsteps=config.psteps,
                                  ms=config.msigma,
                                  lr=config.decoder_lr))
        config.decoderfn = os.path.join(config.expbase, config.decoder)
        print("\npretrained decoder loaded from {}".format(config.modelfn))
        print("trained augmented model will be saved to {}".format(config.decoderfn))

    """Set up loss."""
    loss_ = tf.nn.l2_loss (xh_ - x_)
    nmse_denom_ = tf.nn.l2_loss (x_)
    nmse_ = loss_ / nmse_denom_
    db_ = 10.0 * tf.log (nmse_) / tf.log (10.0)
    # validation
    loss_val_ = tf.nn.l2_loss (xh_val_ - x_val_)
    nmse_denom_val_ = tf.nn.l2_loss (x_val_)
    nmse_val_ = loss_val_ / nmse_denom_val_
    db_val_ = 10.0 * tf.log (nmse_val_) / tf.log (10.0)

    """Set up optimizer."""
    global_step_ = tf.Variable (0, trainable=False)
    if config.net == "robust_ALISTA":
        """Encoder and decoder apply different initial learning rate."""
        # get trainable variable for de encoder and decoder
        encoder_variables_ = tf.get_collection(
            key=tf.GraphKeys.TRAINABLE_VARIABLES,
            scope=config.encoder_scope)
        decoder_variables_ = tf.get_collection(
            key=tf.GraphKeys.TRAINABLE_VARIABLES,
            scope=config.decoder_scope)
        trainable_variables_ = encoder_variables_ + decoder_variables_
        # calculate gradients w.r.t. all trainable variables in the model
        grads_ = tf.gradients (loss_, trainable_variables_)
        encoder_grads_ = grads_[:len (encoder_variables_)]
        decoder_grads_ = grads_[len (encoder_variables_):]
        # define learning rates for optimizers over two parts
        # (rebinds `global_step_` defined above; per-curriculum-stage reset
        # happens via its initializer in the session loop below)
        global_step_ = tf.Variable (0, trainable=False)
        encoder_lr_ = tf.train.exponential_decay(
            config.encoder_lr, global_step_, 5000, 0.75, staircase=False)
        encoder_opt_ = tf.train.AdamOptimizer(encoder_lr_)
        decoder_lr_ = tf.train.exponential_decay(
            config.decoder_lr, global_step_, 5000, 0.75, staircase=False)
        decoder_opt_ = tf.train.AdamOptimizer(decoder_lr_)
        # define training operator
        encoder_op_ = encoder_opt_.apply_gradients(
            zip(encoder_grads_, encoder_variables_))
        decoder_op_ = decoder_opt_.apply_gradients(
            zip(decoder_grads_, decoder_variables_))
        learning_step_ = tf.group(encoder_op_, decoder_op_)
    else:
        lr_ = tf.train.exponential_decay(config.decoder_lr, global_step_,
                                         5000, 0.75, staircase=False)
        learning_step_ = (tf.train.AdamOptimizer(lr_)
                          .minimize(loss_, global_step=global_step_))

    tfconfig = tf.ConfigProto (allow_soft_placement=True)
    tfconfig.gpu_options.allow_growth = True
    with tf.Session (config=tfconfig) as sess:
        # graph initialization
        sess.run (tf.global_variables_initializer (),
                  feed_dict={psigma_: psigmas[0]})
        # load pre-trained model(s)
        if config.net == "robust_ALISTA":
            encoder.load_trainable_variables(sess, config.enc_load)
        decoder.load_trainable_variables(sess, config.dec_load)

        # start timer
        start = time.time ()

        for psigma in psigmas:
            print ('\ncurrent sigma: {}'.format (psigma))
            # reset the learning-rate decay schedule for each sigma level
            global_step_.initializer.run ()

            for i in range (config.maxit):
                db, loss, _ = sess.run([db_, loss_, learning_step_],
                                       feed_dict={psigma_: psigma})
                if i % config.val_step == 0:
                    db_val, loss_val = sess.run(
                        [db_val_, loss_val_],
                        feed_dict={psigma_: psigma})
                    sys.stdout.write(
                        "\ri={i:<7d} | loss_train={loss_train:.6f} | "
                        "db_train={db_train:.6f} | loss_val={loss_val:.6f} | "
                        "db_val={db_val:.6f}".format(
                            i=i, loss_train=loss, db_train=db,
                            loss_val=loss_val, db_val=db_val))
                    sys.stdout.flush()

        if config.net == "robust_ALISTA":
            encoder.save_trainable_variables(sess, config.encoderfn)
        decoder.save_trainable_variables(sess, config.decoderfn)

        # end timer
        end = time.time()
        elapsed = end - start
        print("elapsed time of training = " +
              str(timedelta(seconds=elapsed)))

# end of run_robust_train


############################################################
######################   Testing    ########################
############################################################

def run_test (config):
    """Dispatch to the tester that matches ``config.task_type``."""
    if config.task_type == "sc":
        run_sc_test (config)
    elif config.task_type == "cs":
        run_cs_test (config)
    elif config.task_type == "denoise":
        run_denoise_test(config)
    elif config.task_type == "robust":
        run_robust_test(config)


def run_sc_test (config) :
    """
    Test model.

    Evaluates per-layer NMSE, sparsity, support error, false positives and
    false negatives on the test set in ``config.xtest``; saves all records
    to ``config.resfn`` (npz).
    """
    """Load problem."""
    if not os.path.exists (config.probfn):
        raise ValueError ("Problem file not found.")
    else:
        p = problem.load_problem (config.probfn)

    """Load testing data."""
    xt = np.load (config.xtest)

    """Set up input for testing."""
    config.SNR = np.inf if config.SNR == 'inf' else float (config.SNR)
    input_, label_ = (
        train.setup_input_sc (config.test, p, xt.shape [1], None, False,
                              config.supp_prob, config.SNR,
                              config.magdist, **config.distargs))

    """Set up model."""
    model = setup_model (config , A=p.A)
    # xhs_ holds the estimate after every layer, not just the last one.
    xhs_ = model.inference (input_, None)

    """Create session and initialize the graph."""
    tfconfig = tf.ConfigProto (allow_soft_placement=True)
    tfconfig.gpu_options.allow_growth = True
    with tf.Session (config=tfconfig) as sess:
        # graph initialization
        sess.run (tf.global_variables_initializer ())
        # load model
        model.load_trainable_variables (sess , config.modelfn)

        nmse_denom = np.sum (np.square (xt))
        supp_gt = xt != 0

        lnmse  = []
        lspar  = []
        lsperr = []
        lflspo = []
        lflsne = []

        # test model
        for xh_ in xhs_ :
            xh = sess.run (xh_ , feed_dict={label_:xt})

            # nmse:
            loss = np.sum (np.square (xh - xt))
            nmse_dB = 10.0 * np.log10 (loss / nmse_denom)
            print (nmse_dB)
            lnmse.append (nmse_dB)

            supp = xh != 0.0
            # intermediate sparsity
            spar = np.sum (supp , axis=0)
            lspar.append (spar)

            # support error
            sperr = np.logical_xor(supp, supp_gt)
            lsperr.append (np.sum (sperr , axis=0))

            # false positive
            flspo = np.logical_and (supp , np.logical_not (supp_gt))
            lflspo.append (np.sum (flspo , axis=0))

            # false negative
            flsne = np.logical_and (supp_gt , np.logical_not (supp))
            lflsne.append (np.sum (flsne , axis=0))

    res = dict (nmse=np.asarray  (lnmse),
                spar=np.asarray  (lspar),
                sperr=np.asarray (lsperr),
                flspo=np.asarray (lflspo),
                flsne=np.asarray (lflsne))

    np.savez (config.resfn , **res)
# end of test


def run_cs_test (config) :
    """Test a CS model on ./data/test_images, reporting average patch-level
    NMSE and image-level PSNR."""
    from utils.cs import imread_CS_py, img2col_py, col2im_CS_py
    from skimage.io import imsave

    """Load dictionary and sensing matrix."""
    Phi = np.load (config.sensing) ['A']
    D   = np.load (config.dict)

    # loading compressive sensing settings
    M = Phi.shape [0]
    F = Phi.shape [1]
    N = D.shape [1]
    assert M == config.M and F == config.F and N == config.N
    patch_size = int (np.sqrt (F))
    assert patch_size ** 2 == F

    """Set up model."""
    model = setup_model (config, Phi=Phi, D=D)

    """Inference."""
    y_ = tf.placeholder (shape=(M, None), dtype=tf.float32)
    _, fhs_ = model.inference (y_, None)

    """Start testing."""
    tfconfig = tf.ConfigProto (allow_soft_placement=True)
    tfconfig.gpu_options.allow_growth = True
    with tf.Session (config=tfconfig) as sess:
        # graph initialization
        sess.run (tf.global_variables_initializer ())
        # load model
        model.load_trainable_variables (sess , config.modelfn)

        # calculate average NMSE and PSRN on test images
        test_dir = './data/test_images/'
        test_files = os.listdir (test_dir)
        avg_nmse = 0.0
        avg_psnr = 0.0
        overlap = 0
        stride = patch_size - overlap
        # NOTE(review): `out_dir` / `out_fn` are computed but no image is
        # written in this view (an imsave call may have been removed).
        out_dir = "./data/recon_images"
        if 'joint' in config.net :
            # jointly-learned dictionary: read the trained D back out
            D = sess.run (model.D_)
        for test_fn in test_files :
            # read in image
            out_fn = test_fn[:-4] + "_recon_{}.png".format(config.sample_rate)
            out_fn = os.path.join(out_dir, out_fn)
            test_fn = os.path.join (test_dir, test_fn)
            test_im, H, W, test_im_pad, H_pad, W_pad = \
                    imread_CS_py (test_fn, patch_size, stride)
            test_fs = img2col_py (test_im_pad, patch_size, stride)

            # remove dc from features
            test_dc = np.mean (test_fs, axis=0, keepdims=True)
            test_cfs = test_fs - test_dc
            test_cfs = np.asarray (test_cfs) / 255.0

            # sensing signals
            test_ys = np.matmul (Phi, test_cfs)
            num_patch = test_ys.shape [1]

            rec_cfs = sess.run (fhs_ [-1], feed_dict={y_: test_ys})
            rec_fs  = rec_cfs * 255.0 + test_dc

            # patch-level NMSE
            patch_err = np.sum (np.square (rec_fs - test_fs))
            patch_denom = np.sum (np.square (test_fs))
            avg_nmse += 10.0 * np.log10 (patch_err / patch_denom)

            rec_im = col2im_CS_py (rec_fs, patch_size, stride,
                                   H, W, H_pad, W_pad)

            # image-level PSNR
            image_mse = np.mean (np.square (
                np.clip(rec_im, 0.0, 255.0) - test_im))
            avg_psnr += 10.0 * np.log10 (255.**2 / image_mse)

    num_test_ims = len (test_files)
    print ('Average Patch-level NMSE is {}'.format (avg_nmse / num_test_ims))
    print ('Average Image-level PSNR is {}'.format (avg_psnr / num_test_ims))

# end of cs_testing


def run_denoise_test(config) :
    """Test a denoising model on 256x256 grayscale images from
    ``config.test_dir``; reports per-file and average PSNR, then times
    inference."""
    import glob
    from PIL import Image

    """Load problem."""
    # Shadows the module-level `utils.prob` import with the conv variant.
    import utils.prob_conv as problem
    if not os.path.exists(config.probfn):
        raise ValueError("Problem file not found.")
    else:
        p = problem.load_problem(config.probfn)

    """Set up model."""
    model = setup_model(config, filters=p._fs)

    """Set up input."""
    orig_clean_ = tf.placeholder(dtype=tf.float32,
                                 shape=(None, 256, 256, 1))
    # scale to [0, 1] and remove the per-image mean before denoising
    clean_ = orig_clean_ * (1.0 / 255.0)
    mean_ = tf.reduce_mean(clean_, axis=(1,2,3,), keepdims=True)
    demean_ = clean_ - mean_

    """Add noise."""
    noise_ = tf.random_normal (tf.shape (demean_),
                               stddev=config.denoise_std,
                               dtype=tf.float32)
    noisy_ = demean_ + noise_

    """Inference."""
    _, recons_ = model.inference(noisy_, None)
    recon_ = recons_[-1]
    # denormalize
    recon_ = (recon_ + mean_) * 255.0

    """PSNR."""
    mse2_ = tf.reduce_mean(tf.square(orig_clean_ - recon_), axis=(1,2,3,))
    psnr_ = 10.0 * tf.log(255.0 ** 2 / mse2_) / tf.log (10.0)
    avg_psnr_ = tf.reduce_mean(psnr_)

    """Load test images."""
    test_images = []
    filenames = []
    types = ("*.tif", "*.png", "*.jpg", "*.gif",)
    for type in types:
        filenames.extend(glob.glob(os.path.join(config.test_dir, type)))
    for filename in filenames:
        im = Image.open(filename)
        if im.size != (256, 256):
            im = im.resize((256, 256))
        test_images.append(np.asarray (im).astype(np.float32))
    test_images = np.asarray(test_images).reshape((-1, 256, 256, 1))

    tfconfig = tf.ConfigProto(allow_soft_placement=True)
    tfconfig.gpu_options.allow_growth = True
    with tf.Session(config=tfconfig) as sess:
        # graph initialization
        sess.run(tf.global_variables_initializer())
        # load model
        model.load_trainable_variables(sess, config.modelfn)

        # testing
        psnr, avg_psnr = sess.run([psnr_, avg_psnr_],
                                  feed_dict={orig_clean_:test_images})
        print('file names\t| PSNR/dB')
        for fname, p in zip(filenames, psnr):
            print(os.path.basename (fname), '\t', p)
        print("average PSNR = {} dB".format(avg_psnr))
        print("full PSNR records on testing set are stored in {}".format(config.resfn))
        np.save(config.resfn, psnr)

        # benchmark inference time averaged over `ntimes` repeated runs
        sum_time = 0.0
        ntimes = 200
        for i in range(ntimes):
            # start timer
            start = time.time()
            # testing
            sess.run(recon_, feed_dict={orig_clean_:test_images})
            # end timer
            end = time.time()
            sum_time = sum_time + end - start
        print("average elapsed time for one image inference = " +
              str(timedelta(seconds=sum_time/ntimes/test_images.shape[0])))

    # start timer
    # NOTE(review): this trailing timer start is dead code -- nothing reads
    # `start` after this point.
    start = time.time()

# end of run_denoise_test


def run_robust_test(config):
    """Evaluate robustness to dictionary perturbations using pre-generated
    perturbed matrices from ./data/robust_test_A.npz, reporting average dB
    NMSE per perturbation level."""
    """Load problem."""
    print(config.probfn)
    if not os.path.exists(config.probfn):
        raise ValueError("Problem file not found.")
    else:
        p = problem.load_problem(config.probfn)

    """Set tesing data."""
    test_As = np.load('./data/robust_test_A.npz')
    x = np.load('./data/xtest_n500_p10.npy')

    """Set up input."""
    # npz keys are the perturbation sigmas (as strings)
    psigmas = sorted([float(k) for k in test_As.keys()])
    psigma_ = tf.placeholder(dtype=tf.float32, shape=())
    with tf.name_scope ('input'):
        # NOTE(review): the (250, 500) placeholder shapes are hard-coded
        # rather than taken from config.M / config.N -- confirm they match
        # the test data.
        Ap_ = tf.placeholder (dtype=tf.float32, shape=(250, 500))
        x_ = tf.placeholder (dtype=tf.float32, shape=(500, None))
        ## measure y_ from x_ using Ap_
        y_ = tf.matmul (Ap_, x_)

    """Set up model."""
    if config.net == "robust_ALISTA":
        """Load the Q reweighting matrix."""
        if config.Q is None: # use default Q reweighting matrix
            if "re" in config.encoder_loss: # if using reweighted loss
                Q = np.sqrt((np.ones(shape=(config.N, config.N),
                                     dtype=np.float32) +
                             np.eye(config.N, dtype=np.float32) *
                             (config.N - 2)))
            else:
                Q = None
        elif os.path.exists(config.Q) and config.Q.endswith(".npy"):
            Q = np.load(config.Q)
            assert Q.shape == (config.N, config.N)
        else:
            raise ValueError("Invalid parameter `--Q`\n"
                             "A valid `--Q` parameter should be one of the following:\n"
                             "    1) omitted for default value as in the paper;\n"
                             "    2) path/to/your/npy/file that contains your Q matrix.\n")
        """Binit matrix."""
        if config.encoder_Binit == "default":
            Binit = p.A
        elif config.Binit in ["uniform", "normal"]:
            # NOTE(review): leaves `Binit` unbound (NameError at the
            # setup_model call below); also reads `config.Binit` while the
            # branch above reads `config.encoder_Binit` -- confirm.
            pass
        else:
            raise ValueError("Invalid parameter `--Binit`\n"
                             "A valid `--Binit` parameter should be one of the following:\n"
                             "    1) omitted for default value `p.A`;\n"
                             "    2) `normal` or `uniform`.\n")
        encoder, decoder = setup_model(config, Q=Q, Binit=Binit)
        # encoder works on a batch dimension; add/remove it around a single A
        W_ = tf.squeeze(encoder.inference(tf.expand_dims(Ap_, axis=0)),
                        axis=0)
        xh_ = decoder.inference(y_, Ap_, W_, x0_=None)[-1]
    else:
        decoder = setup_model(config, A=p.A)
        xh_ = decoder.inference(y_, None)[-1]
        config.decoder = ("robust_" + config.model +
                          '_ps{ps}_nsteps{nsteps}_ms{ms}_lr{lr}'
                          .format(ps=config.psigma_max,
                                  nsteps=config.psteps,
                                  ms=config.msigma,
                                  lr=config.decoder_lr))
        config.decoderfn = os.path.join(config.expbase, config.decoder)
        print("\ntrained augmented model loaded from {}".format(config.decoderfn))

    """Set up loss."""
    loss_ = tf.nn.l2_loss (xh_ - x_)
    nmse_denom_ = tf.nn.l2_loss (x_)
    nmse_ = loss_ / nmse_denom_
    db_ = 10.0 * tf.log (nmse_) / tf.log (10.0)

    tfconfig = tf.ConfigProto (allow_soft_placement=True)
    tfconfig.gpu_options.allow_growth = True
    with tf.Session (config=tfconfig) as sess:
        # graph initialization
        sess.run (tf.global_variables_initializer (),
                  feed_dict={psigma_: psigmas[0]})
        # load pre-trained model(s)
        if config.net == "robust_ALISTA":
            encoder.load_trainable_variables(sess, config.encoderfn)
        decoder.load_trainable_variables(sess, config.decoderfn)

        # start timer
        start = time.time ()

        res = dict (sigmas=np.array (psigmas))
        avg_dBs = []
        print ('sigma\tnmse')
        sum_time = 0.0
        tcounter = 0
        for psigma in psigmas:
            Aps = test_As[str(psigma)]
            sum_dB = 0.0
            counter = 0
            for Ap in Aps:
                db = sess.run(db_, feed_dict={x_:x, Ap_:Ap})
                # start timer
                start = time.time ()
                # inference
                sess.run (xh_, feed_dict={x_:x, Ap_:Ap})
                # end timer
                end = time.time ()
                elapsed = end - start
                sum_time = sum_time + elapsed
                tcounter = tcounter + 1

                sum_dB = sum_dB + db
                counter = counter + 1
            avg_dB = sum_dB / counter
            print(psigma, '\t', avg_dB)
            avg_dBs.append (avg_dB)

        print("average elapsed time of inference =",
              str (timedelta (seconds=sum_time/tcounter)))

    res['avg_dBs'] = np.asarray(avg_dBs)
    print('saving results to', config.resfn)
    np.savez (config.resfn, **res)
# end of run_robust_test


############################################################
#######################    Main    #########################
############################################################

def main ():
    """Entry point: parse config, select GPUs, then train or test."""
    # parse configuration
    config, _ = get_config()
    # set visible GPUs
    os.environ['CUDA_VISIBLE_DEVICES'] = config.gpu

    if config.test:
        run_test (config)
    else:
        run_train (config)
# end of main

if __name__ == "__main__":
    main ()