'''This work tries to rebuild DUFVSR (Deep Video Super-Resolution Network
Using Dynamic Upsampling Filters Without Explicit Motion Compensation).
The code is mainly based on https://github.com/psychopa4/MMCNN,
https://github.com/jiangsutx/SPMC_VideoSR and
https://github.com/yhjo09/VSR-DUF.
'''
import tensorflow as tf
from tensorflow.python.ops import control_flow_ops
from os.path import join, exists
import glob
import random
import numpy as np
from PIL import Image
import time
import os
from utils import Huber, LoadImage, DownSample, DownSample_4D, BLUR, AVG_PSNR, depth_to_space_3D, DynFilter3D, LoadParams, automkdir, get_num_params, cv2_imread, cv2_imsave
from model.nets import FR_16L, FR_28L, FR_52L
from model.base_model import VSR
from tqdm import trange, tqdm


class DUFVSR(VSR):
    """DUF video super-resolution model built on the 52-layer filter network.

    The network consumes a window of ``num_frames`` low-resolution frames and
    produces a single super-resolved frame for the window centre.
    """

    def __init__(self):
        """Set the hyper-parameters and file locations used by this model."""
        # Length of the temporal input window (time axis of the L placeholder).
        self.num_frames = 7
        # Spatial upscaling factor.
        self.scale = 4
        # in_size / gt_size / batch_size / train_dir are not read in this file;
        # presumably consumed by the input producer in the base class -- verify.
        self.in_size = 32
        self.gt_size = self.in_size * self.scale
        # [height, width] of the low-resolution crop used by eval().
        self.eval_in_size = [128, 240]
        self.batch_size = 11  # can be increased with larger GPU memory
        # Number of samples per evaluation batch.
        self.eval_basz = 4
        # Learning rate decays linearly from learning_rate to end_lr over
        # decay_step steps (see train()).
        self.learning_rate = 1e-3
        self.end_lr = 1e-4
        # When True, train() calls self.load() before the training loop.
        self.reload = True
        self.max_step = int(1.5e5 + 1)
        self.decay_step = 1.2e5
        self.train_dir = './data/filelist_train.txt'
        # Text file listing one evaluation sequence directory per line.
        self.eval_dir = './data/filelist_val.txt'
        # Directory passed to self.save() / self.load().
        self.save_dir = './checkpoint/duf_52'
        # Text file that eval() appends one result line to per call.
        self.log_dir = './duf_52.txt'

    def forward(self, x, is_train):
        """Build the super-resolution graph for a window of frames.

        Args:
            x: input frames, shape [B,T_in,H,W,C].
            is_train: boolean tensor forwarded to the filter network.

        Returns:
            The filtered centre frame plus the predicted residual
            ([B,1,H*R,W*R,3] according to the shape notes below).
        """
        # Generate filters and residual
        # Fx: [B,1,H,W,1*5*5,R*R]
        # Rx: [B,1,H,W,3*R*R]
        # NOTE(review): every statement is kept inside the variable scope;
        # the original nesting could not be recovered -- confirm.
        with tf.variable_scope('G', reuse=tf.AUTO_REUSE):
            Fx, Rx = FR_52L(x, is_train)

            mid = self.num_frames // 2
            x_c = []
            # The same dynamic filters are applied to each colour channel of
            # the centre frame independently.
            for c in range(3):
                t = DynFilter3D(x[:, mid:mid + 1, :, :, c], Fx[:, 0, :, :, :, :], [1, 5, 5])  # [B,H,W,R*R]
                t = tf.depth_to_space(t, self.scale)  # [B,H*R,W*R,1]
                x_c += [t]
            x = tf.concat(x_c, axis=3)  # [B,H*R,W*R,3]
            x = tf.expand_dims(x, axis=1)

            Rx = depth_to_space_3D(Rx, self.scale)  # [B,1,H*R,W*R,3]
            x += Rx

            return x

    def build(self):
        """Create the placeholders, the network output and the loss tensors.

        Stores ``L``, ``H``, ``SR``, ``is_train``, ``loss`` and ``eval_mse``
        on the instance.
        """
        # Placeholders are created in the same order as before (H, L,
        # is_train) so that auto-generated op names do not change.
        H = tf.placeholder(tf.float32, shape=[None, 1, None, None, 3], name='H_truth')
        L = tf.placeholder(tf.float32, shape=[None, self.num_frames, None, None, 3], name='L_input')
        is_train = tf.placeholder(tf.bool, shape=[])  # Phase, scalar
        SR = self.forward(L, is_train)
        loss = Huber(SR, H, 0.01)
        # Mean squared error over the spatial and channel axes; the batch and
        # time axes are kept.
        eval_mse = tf.reduce_mean((SR - H) ** 2, axis=[2, 3, 4])
        self.loss, self.eval_mse = loss, eval_mse
        self.L, self.H, self.SR, self.is_train = L, H, SR, is_train

    def _window_indices(self, center, max_frame):
        """Return the frame indices around ``center``, clamped to the sequence."""
        half = self.num_frames // 2
        window = np.arange(center - half, center + half + 1)
        return np.clip(window, 0, max_frame - 1).tolist()

    def _infer_and_save(self, lr_list, save_path, part):
        """Run the network over ``lr_list`` in chunks and write PNG results.

        Args:
            lr_list: array of input windows; the first axis indexes frames.
            save_path: directory that receives ``0000.png``, ``0001.png``, ...
            part: requested number of chunks to split the sequence into.
        """
        max_frame = lr_list.shape[0]
        # Clamp so that neither an empty sequence nor part <= 0 can cause a
        # division by zero.
        part = max(1, min(part, max_frame))
        num_once = -(-max_frame // part)  # ceil(max_frame / part)
        # Only iterate over chunks that actually contain frames; with the
        # rounded-up chunk size fewer than ``part`` chunks may be needed.
        num_chunks = -(-max_frame // num_once)

        all_time = 0
        timed_frames = 0
        for i in trange(num_chunks):
            st_time = time.time()
            sr = self.sess.run(self.SR, feed_dict={self.L: lr_list[i * num_once:(i + 1) * num_once], self.is_train: False})
            once_time = time.time() - st_time
            # The first chunk is excluded from the timing statistics.
            if i > 0:
                all_time += once_time
                timed_frames += sr.shape[0]
            for j in range(sr.shape[0]):
                img = np.clip(sr[j][0] * 255., 0, 255).astype(np.uint8)
                imgname = '{:0>4}.png'.format(i * num_once + j)
                cv2_imsave(join(save_path, imgname), img)
        # Average over the frames that were actually timed.
        average = all_time / timed_frames if timed_frames else 0.0
        print('spent {} s in total and {} s in average'.format(all_time, average))

    def eval(self):
        """Evaluate on the sequences listed in ``self.eval_dir``.

        Prints the average PSNR/MSE and appends one line with the current
        global step, PSNR and MSE to ``self.log_dir``. Returns without
        logging when no complete evaluation batch could be formed.
        """
        print('Evaluating ...')
        if not hasattr(self, 'sess'):
            global_step = tf.Variable(initial_value=0, trainable=False)
            self.global_step = global_step
            self.build()
            sess = tf.Session()
            self.load(sess, self.save_dir)
        else:
            sess = self.sess

        border = 8
        in_h, in_w = self.eval_in_size
        out_h = in_h * self.scale  # 512
        out_w = in_w * self.scale  # 960
        bd = border // self.scale

        # Close the list file deterministically instead of leaking the handle.
        with open(self.eval_dir, 'rt') as list_file:
            filenames = list_file.read().splitlines()
        hr_list = [sorted(glob.glob(join(f, 'truth', '*.png'))) for f in filenames]
        lr_list = [sorted(glob.glob(join(f, 'blur{}'.format(self.scale), '*.png'))) for f in filenames]

        center = 15
        batch_hr = []
        batch_lr = []
        batch_cnt = 0
        mse_acc = None
        for lrlist, hrlist in zip(lr_list, hr_list):
            max_frame = len(lrlist)
            # Sample one centre frame every 32 frames, starting at `center`.
            for idx0 in range(center, max_frame, 32):
                index = self._window_indices(idx0, max_frame)
                gt = [cv2_imread(hrlist[idx0])]
                inp = [cv2_imread(lrlist[i]) for i in index]
                inp = [im[bd:in_h + bd, bd:in_w + bd].astype(np.float32) / 255.0 for im in inp]
                gt = [im[border:out_h + border, border:out_w + border, :].astype(np.float32) / 255.0 for im in gt]
                batch_hr.append(np.stack(gt, axis=0))
                batch_lr.append(np.stack(inp, axis=0))

                # NOTE: only full batches are evaluated; a trailing partial
                # batch is never fed to the network.
                if len(batch_hr) == self.eval_basz:
                    batch_hr = np.stack(batch_hr, 0)
                    batch_lr = np.stack(batch_lr, 0)
                    mse_val = sess.run(self.eval_mse, feed_dict={self.L: batch_lr, self.H: batch_hr, self.is_train: False})
                    if mse_acc is None:
                        mse_acc = mse_val
                    else:
                        mse_acc = np.concatenate([mse_acc, mse_val], axis=0)
                    batch_hr = []
                    batch_lr = []
                    print('\tEval batch {} - {} ...'.format(batch_cnt, batch_cnt + self.eval_basz))
                    batch_cnt += self.eval_basz

        if mse_acc is None:
            # Previously this fell through to `1.0 / None` and raised.
            print('No complete evaluation batch was formed; skipping evaluation.')
            return

        psnr_acc = 10 * np.log10(1.0 / mse_acc)
        mse_avg = np.mean(mse_acc, axis=0)
        psnr_avg = np.mean(psnr_acc, axis=0)
        # NOTE(review): this adds new summary ops to the graph on every call
        # and nothing in this file merges or writes them -- confirm whether
        # it is still needed.
        for i in range(mse_avg.shape[0]):
            tf.summary.scalar('val_mse{}'.format(i), tf.convert_to_tensor(mse_avg[i], dtype=tf.float32))
        print('Eval PSNR: {}, MSE: {}'.format(psnr_avg, mse_avg))
        # write to log file
        with open(self.log_dir, 'a+') as f:
            # Truncate to six decimal places before logging.
            mse_avg = (mse_avg * 1e6).astype(np.int64) / (1e6)
            psnr_avg = (psnr_avg * 1e6).astype(np.int64) / (1e6)
            f.write('{' + '"Iter": {} , "PSNR": {}, "MSE": {}'.format(sess.run(self.global_step), psnr_avg.tolist(), mse_avg.tolist()) + '}\n')

    def train(self):
        """Train the model until ``self.max_step`` or until the loss diverges.

        Every 500 steps the model is saved (except at the starting step) and
        evaluated; every 20 steps the latest loss is printed.
        """
        LR, HR = self.double_input_producer()
        global_step = tf.Variable(initial_value=0, trainable=False)
        self.global_step = global_step
        self.build()
        lr = tf.train.polynomial_decay(self.learning_rate, global_step, self.decay_step, end_learning_rate=self.end_lr, power=1.)

        vars_all = tf.trainable_variables()
        print('Params num of all:', get_num_params(vars_all))
        training_op = tf.train.AdamOptimizer(lr).minimize(self.loss, var_list=vars_all, global_step=global_step)

        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        sess = tf.Session(config=config)
        self.sess = sess
        sess.run(tf.global_variables_initializer())

        self.saver = tf.train.Saver(max_to_keep=100, keep_checkpoint_every_n_hours=1)
        if self.reload:
            self.load(sess, self.save_dir)

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        try:
            cost_time = 0
            start_time = time.time()
            # Step the run starts from (non-zero after a reload).
            gs = sess.run(global_step)
            for step in range(gs, self.max_step):
                # `step > gs` guarantees that loss_v was assigned by an
                # earlier iteration of this run.
                if step > gs and step % 20 == 0:
                    print(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()), 'Step:{}, loss:{}'.format(step, loss_v))

                if step % 500 == 0:
                    if step > gs:
                        self.save(sess, self.save_dir, step)
                    cost_time = time.time() - start_time
                    print('cost {}s.'.format(cost_time))
                    self.eval()
                    cost_time = time.time() - start_time
                    start_time = time.time()
                    print('cost {}s.'.format(cost_time))

                lr1, hr = sess.run([LR, HR])
                _, loss_v = sess.run([training_op, self.loss], feed_dict={self.L: lr1, self.H: hr, self.is_train: True})

                if step > 500 and loss_v > 10:
                    print('Model collapsed with loss={}'.format(loss_v))
                    break
        finally:
            # Stop the queue-runner threads on every exit path so they do not
            # outlive the training loop.
            coord.request_stop()
            coord.join(threads)

    def test_video_truth(self, path, name='result', reuse=False, part=8):
        """Super-resolve a sequence starting from its ground-truth frames.

        Frames under ``<path>/truth`` are down-sampled with ``DownSample``
        and the results are written to ``<path>/<name>``.

        Args:
            path: sequence directory containing a ``truth`` sub-directory.
            name: name of the output sub-directory.
            reuse: when False, build the graph, open a session and load the
                checkpoint; when True, reuse ``self.sess``.
            part: number of chunks the sequence is split into for inference.
        """
        save_path = join(path, name)
        automkdir(save_path)
        imgs = sorted(glob.glob(join(path, 'truth', '*.png')))
        imgs = [cv2_imread(i) / 255. for i in imgs]

        test_gt = tf.placeholder(tf.float32, [None, self.num_frames, None, None, 3])
        test_inp = DownSample(test_gt, BLUR, scale=self.scale)

        if not reuse:
            self.build()
            sess = tf.Session()
            self.sess = sess
            sess.run(tf.global_variables_initializer())
            self.saver = tf.train.Saver(max_to_keep=100, keep_checkpoint_every_n_hours=1)
            self.load(sess, self.save_dir)

        max_frame = len(imgs)
        if max_frame == 0:
            # Checked after the graph/session setup so that a later call with
            # reuse=True still finds self.sess.
            print('No frames found under {}'.format(join(path, 'truth')))
            return

        gt_list = np.array([
            np.array([imgs[j] for j in self._window_indices(i, max_frame)])
            for i in range(max_frame)
        ])
        lr_list = self.sess.run(test_inp, feed_dict={test_gt: gt_list})
        print('Save at {}'.format(save_path))
        print('{} Inputs With Shape {}'.format(lr_list.shape[0], lr_list.shape[1:]))

        self._infer_and_save(lr_list, save_path, part)

    def test_video_lr(self, path, name='result', reuse=False, part=8):
        """Super-resolve a sequence from its stored low-resolution frames.

        Frames are read from ``<path>/blur<scale>`` and the results are
        written to ``<path>/<name>``.

        Args:
            path: sequence directory containing the ``blur<scale>`` folder.
            name: name of the output sub-directory.
            reuse: when False, build the graph, open a session and load the
                checkpoint; when True, reuse ``self.sess``.
            part: number of chunks the sequence is split into for inference.
        """
        save_path = join(path, name)
        automkdir(save_path)
        inp_path = join(path, 'blur{}'.format(self.scale))
        imgs = sorted(glob.glob(join(inp_path, '*.png')))
        imgs = np.array([cv2_imread(i) / 255. for i in imgs])

        max_frame = imgs.shape[0]
        lr_list = np.array([
            np.array([imgs[j] for j in self._window_indices(i, max_frame)])
            for i in range(max_frame)
        ])

        if not reuse:
            self.build()
            sess = tf.Session()
            self.sess = sess
            sess.run(tf.global_variables_initializer())
            self.saver = tf.train.Saver(max_to_keep=100, keep_checkpoint_every_n_hours=1)
            self.load(sess, self.save_dir)

        if max_frame == 0:
            # Checked after the graph/session setup so that a later call with
            # reuse=True still finds self.sess.
            print('No frames found under {}'.format(inp_path))
            return

        print('Save at {}'.format(save_path))
        print('{} Inputs With Shape {}'.format(lr_list.shape[0], lr_list.shape[1:]))

        self._infer_and_save(lr_list, save_path, part)

    def testvideos(self, path='/dev/f/data/video/test2/vid4', start=0, name='duf_52'):
        """Run ``test_video_lr`` on every sub-directory of ``path``.

        Args:
            path: directory whose sub-directories are individual sequences.
            start: index (in sorted order) of the first sequence to process.
            name: output sub-directory name passed to ``test_video_lr``.
        """
        kind = sorted(glob.glob(join(path, '*')))
        kind = [k for k in kind if os.path.isdir(k)]
        for idx, k in enumerate(kind):
            if idx < start:
                continue
            # glob already returns paths that include `path`; joining it on
            # again duplicated the prefix whenever `path` was relative.
            # The graph and session are built for the first processed
            # sequence and reused afterwards.
            self.test_video_lr(k, name=name, reuse=idx > start, part=1000)


if __name__ == '__main__':
    model = DUFVSR()
    #model.train()
    model.testvideos()