import matplotlib
matplotlib.use('Agg')  # non-interactive backend: figures are only written to disk
import matplotlib.pyplot as plt
import numpy as np
import h5py
import scipy.interpolate as si
from scipy import ndimage
import sys

# Index of each variable along the last (channel) axis of the HR / LR /
# output datasets read in main().
varindices = {'u': 0, 'v': 1, 'w': 2, 'p': 3}


def _periodic_cubic_upsample(LR, HR_shape):
    """Resample a periodic N-D field ``LR`` onto a finer grid of shape ``HR_shape``.

    Both grids are treated as covering the same periodic box with uniformly
    spaced samples starting at the origin (the end point of the period is not
    stored). Values are interpolated with third-order splines via
    ``scipy.ndimage.map_coordinates``.

    Parameters
    ----------
    LR : array_like
        Coarse field with any number of dimensions.
    HR_shape : sequence of int
        Target shape, one entry per dimension of ``LR``.

    Returns
    -------
    numpy.ndarray
        float64 array of shape ``HR_shape``.

    Raises
    ------
    ValueError
        If ``HR_shape`` does not have one entry per dimension of ``LR``.
    """
    LR = np.asarray(LR, dtype=np.float64)
    HR_shape = tuple(HR_shape)
    if len(HR_shape) != LR.ndim:
        raise ValueError(
            'HR_shape {} does not match the {}-D input'.format(HR_shape, LR.ndim))

    # Append a copy of the first slice at the end of every axis (corners
    # included) so the padded array closes the period explicitly; with
    # mode='wrap', map_coordinates treats the first and last samples as the
    # same point.
    LR_padded = np.pad(LR, [(0, 1)] * LR.ndim, mode='wrap')

    # Fine-grid sample positions expressed in coarse-grid index units.
    axes = [np.linspace(0, n_lr, num=n_hr + 1)[:-1]
            for n_lr, n_hr in zip(LR.shape, HR_shape)]
    coords = [c.reshape(-1) for c in np.meshgrid(*axes, indexing='ij')]

    return ndimage.map_coordinates(
        LR_padded, coords, order=3, mode='wrap').reshape(HR_shape)


def bicubic_interpolation(LR, HR_shape):
    """Upsample a periodic 3-D field ``LR`` to ``HR_shape`` with cubic splines.

    Inputs of other ranks are accepted as long as ``HR_shape`` has one entry
    per dimension of ``LR``.

    Parameters
    ----------
    LR : array_like
        Coarse periodic field.
    HR_shape : sequence of int
        Target shape.

    Returns
    -------
    numpy.ndarray
        float64 array of shape ``HR_shape``.
    """
    return _periodic_cubic_upsample(LR, HR_shape)


def plot(var, ax, extent=(0, 2.*np.pi, 0, 2.*np.pi), vmin=None, vmax=None, cmap=None):
    """Draw the 2-D array ``var`` on ``ax`` as an image and return the image.

    Parameters
    ----------
    var : 2-D array
        Field to draw. With ``origin='lower'`` the first array axis runs
        vertically (bottom to top).
    ax : matplotlib.axes.Axes
        Target axes.
    extent : tuple of 4 floats
        ``(left, right, bottom, top)`` in data coordinates; the axis limits
        are set to exactly this box.
    vmin, vmax : float, optional
        Colour-scale limits; ``None`` lets matplotlib use the data range.
    cmap : matplotlib colormap, optional
        Defaults to viridis.

    Returns
    -------
    matplotlib.image.AxesImage
    """
    if cmap is None:
        cmap = plt.get_cmap('viridis')

    im = ax.imshow(var, extent=extent, vmin=vmin, vmax=vmax, cmap=cmap,
                   origin='lower', interpolation='none', aspect='equal')
    ax.set_xlim((extent[0], extent[1]))
    ax.set_ylim((extent[2], extent[3]))
    return im


def make_comparison_plots(LR, HR, out, output_label='Generated output'):
    """Plot LR, its cubic upsampling, a model output and HR side by side.

    Parameters
    ----------
    LR : 2-D array
        Low-resolution periodic field.
    HR : 2-D array
        High-resolution reference; its shape is the upsampling target and its
        min/max set the common colour scale of all four panels.
    out : 2-D array
        Field to compare against (e.g. a model prediction).
    output_label : str
        Title of the ``out`` panel.

    Returns
    -------
    (matplotlib.figure.Figure, tuple of 4 Axes)
    """
    # One colour scale for every panel, taken from the reference field.
    vmin = HR.min()
    vmax = HR.max()

    out_BC = _periodic_cubic_upsample(LR, HR.shape)

    fig, axes = plt.subplots(1, 4, sharey=True, figsize=(17, 5))
    panels = (
        (LR, 'Low resolution'),
        (out_BC, 'Bicubic'),
        (out, output_label),
        (HR, 'High resolution'),
    )
    for ax, (field, title) in zip(axes, panels):
        plot(field, ax, vmin=vmin, vmax=vmax)
        ax.set_title(title)

    fig.tight_layout()
    return fig, tuple(axes)


def convert_to_rgb(a, vmin, vmax, cmap=plt.cm.viridis):
    """Map the scalar array ``a`` to RGB colours.

    ``a`` is normalised linearly so that ``vmin`` -> 0 and ``vmax`` -> 1
    (``vmax`` must differ from ``vmin``), passed through ``cmap``, and the
    alpha channel is dropped.

    Returns
    -------
    numpy.ndarray
        Array of shape ``a.shape + (3,)``.
    """
    rgba = cmap((a - vmin) / (vmax - vmin))
    return rgba[..., :-1]


def _plane_selection(batch, plane, index, var):
    """Build the selection of one 2-D plane of one sample and one variable.

    The datasets are indexed as ``[batch, axis0, axis1, axis2, var]``;
    ``plane`` picks which of the three spatial axes is fixed to ``index``.
    """
    selection = [batch, slice(None), slice(None), slice(None), var]
    selection[plane + 1] = index
    return tuple(selection)


def _matching_index(index, n_ref, n):
    """Map ``index`` on a grid of ``n_ref`` points to a grid of ``n`` points.

    Both grids span the same periodic period, so the mapping is the identity
    when ``n == n_ref``. When the plane does not fall exactly on the target
    grid, the nearest lower plane is used.
    """
    return index * n // n_ref


def main(argv=None):
    """Command-line entry point; see the usage message for arguments."""
    argv = sys.argv if argv is None else argv
    if len(argv) < 6:
        print("Usage: ")
        print(" python {} <HDF5 output file> <variable> <plane> <index> <output filename>".format(argv[0]))
        sys.exit(1)

    filename = argv[1]
    if argv[2] not in varindices:
        raise ValueError('Variable has to be one of {}. Given {}'.format(
            sorted(varindices), argv[2]))
    var = varindices[argv[2]]
    plane = int(argv[3])
    index = int(argv[4])
    output_filename = argv[5]

    # Validate before touching the file.
    if plane not in (0, 1, 2):
        raise ValueError('Plane has to be 0, 1 or 2. Given {}'.format(plane))

    # Read-only access; only the three requested 2-D slices are loaded.
    with h5py.File(filename, 'r') as h5f:
        hr_ds, lr_ds, out_ds = h5f['HR'], h5f['LR'], h5f['output']

        batch_size = hr_ds.shape[0]
        n_hr = hr_ds.shape[plane + 1]
        if not -n_hr <= index < n_hr:
            raise ValueError('Index has to be in [{}, {}). Given {}'.format(
                -n_hr, n_hr, index))
        index %= n_hr  # h5py does not accept negative indices

        # The same physical plane sits at a different index on a grid of a
        # different resolution (LR is coarser than HR).
        n_lr = lr_ds.shape[plane + 1]
        n_out = out_ds.shape[plane + 1]
        index_lr = _matching_index(index, n_hr, n_lr)
        index_out = _matching_index(index, n_hr, n_out)
        if (index * n_lr) % n_hr:
            print('Note: HR plane {} is not on the LR grid; showing LR plane {}'.format(
                index, index_lr))

        batch = np.random.randint(batch_size)
        print('Plotting sample {} of {}'.format(batch, batch_size))

        var_HR = hr_ds[_plane_selection(batch, plane, index, var)]
        var_LR = lr_ds[_plane_selection(batch, plane, index_lr, var)]
        var_out = out_ds[_plane_selection(batch, plane, index_out, var)]

    fig, _ = make_comparison_plots(var_LR, var_HR, var_out)
    fig.savefig(output_filename)
    plt.close(fig)
    print("Saved plot to {}".format(output_filename))


if __name__ == '__main__':
    main()