import matplotlib.pyplot as plt import numpy as np import tensorflow as tf import tensorflow.contrib.slim as slim def periodic_padding(inpt, pad): L = inpt[:,:pad[0][0],:,:,:] if pad[0][1] > 0: R = inpt[:,-pad[0][1]:,:,:,:] else: R = inpt[:,:0,:,:,:] inpt_pad = tf.concat([R, inpt, L], axis=1) L = inpt_pad[:,:,:pad[1][0],:,:] if pad[1][1] > 0: R = inpt_pad[:,:,-pad[1][1]:,:,:] else: R = inpt_pad[:,:,:0,:,:] inpt_pad = tf.concat([R, inpt_pad, L], axis=2) L = inpt_pad[:,:,:,:pad[2][0],:] if pad[2][1] > 0: R = inpt_pad[:,:,:,-pad[2][1]:,:] else: R = inpt_pad[:,:,:,:0,:] inpt_pad = tf.concat([R, inpt_pad, L], axis=3) return inpt_pad def conv3d_withPeriodicPadding(inpt, filtr, strides, name=None): ### Does not work for large strides ### inpt_shape = inpt.get_shape().as_list() filtr_shape = filtr.get_shape().as_list() pad = [] for i_dim in range(3): # Compute pad assuming output_size = input_size / stride and odd filter sizes padL = int( 0.5*(filtr_shape[i_dim] - 1) ) padR = padL pad_idim = (padL,padR) pad.append(pad_idim) inpt_pad = periodic_padding(inpt, pad) output = tf.nn.conv3d(inpt_pad, filtr, strides, padding = 'VALID', data_format = 'NDHWC', name=name) return output def conv3d(inpt, f, output_channels, s, use_bias=False, scope='conv', name=None): inpt_shape = inpt.get_shape().as_list() with tf.variable_scope(scope): filtr = tf.get_variable(initializer=tf.contrib.layers.xavier_initializer(), shape=[f,f,f,inpt_shape[-1],output_channels],name='filtr') strides = [1,s,s,s,1] output = conv3d_withPeriodicPadding(inpt,filtr,strides,name) if use_bias: with tf.variable_scope(scope): bias = tf.get_variable(intializer=tf.zeros_initializer( [1,1,1,1,output_channels],dtype=tf.float32),name='bias') output = output + bias; return output def filter3d(inpt, scope='filter', name=None): inpt_shape = inpt.get_shape().as_list() with tf.variable_scope(scope): filter1D = tf.constant([0.04997364, 0.13638498, 0.20002636, 0.22723004, 0.20002636, 0.13638498, 0.04997364], dtype=tf.float32) filter1Dx = tf.reshape(filter1D, shape=(-1,1,1)) filter1Dy = tf.reshape(filter1D, shape=(1,-1,1)) filter1Dz = tf.reshape(filter1D, shape=(1,1,-1)) filter3D = filter1Dx * filter1Dy * filter1Dz # Tensor product 3D filter using broadcasting filter3D = tf.expand_dims(filter3D, axis=3) zero = tf.constant( np.zeros((7,7,7,1), dtype=np.float32) ) filter3Du = tf.concat( [filter3D, zero, zero, zero], axis=3 ) filter3Dv = tf.concat( [zero, filter3D, zero, zero], axis=3 ) filter3Dw = tf.concat( [zero, zero, filter3D, zero], axis=3 ) filter3Dp = tf.concat( [zero, zero, zero, filter3D], axis=3 ) filter3D = tf.stack( [filter3Du, filter3Dv, filter3Dw, filter3Dp], axis=4, name='filter' ) strides = [1,4,4,4,1] inpt_pad = periodic_padding( inpt, ((3,3),(3,3),(3,3)) ) output = tf.nn.conv3d(inpt_pad, filter3D, strides, padding = 'VALID', data_format = 'NDHWC', name=name) return output def ddx(inpt, channel, dx, scope='ddx', name=None): inpt_shape = inpt.get_shape().as_list() var = tf.expand_dims( inpt[:,:,:,:,channel], axis=4 ) with tf.variable_scope(scope): ddx1D = tf.constant([-1./60., 3./20., -3./4., 0., 3./4., -3./20., 1./60.], dtype=tf.float32) ddx3D = tf.reshape(ddx1D, shape=(-1,1,1,1,1)) strides = [1,1,1,1,1] var_pad = periodic_padding( var, ((3,3),(0,0),(0,0)) ) output = tf.nn.conv3d(var_pad, ddx3D, strides, padding = 'VALID', data_format = 'NDHWC', name=name) output = tf.scalar_mul(1./dx, output) return output def ddy(inpt, channel, dy, scope='ddy', name=None): inpt_shape = inpt.get_shape().as_list() var = tf.expand_dims( inpt[:,:,:,:,channel], axis=4 ) with tf.variable_scope(scope): ddy1D = tf.constant([-1./60., 3./20., -3./4., 0., 3./4., -3./20., 1./60.], dtype=tf.float32) ddy3D = tf.reshape(ddy1D, shape=(1,-1,1,1,1)) strides = [1,1,1,1,1] var_pad = periodic_padding( var, ((0,0),(3,3),(0,0)) ) output = tf.nn.conv3d(var_pad, ddy3D, strides, padding = 'VALID', data_format = 'NDHWC', name=name) output = tf.scalar_mul(1./dy, output) return output def ddz(inpt, channel, dz, scope='ddz', name=None): inpt_shape = inpt.get_shape().as_list() var = tf.expand_dims( inpt[:,:,:,:,channel], axis=4 ) with tf.variable_scope(scope): ddz1D = tf.constant([-1./60., 3./20., -3./4., 0., 3./4., -3./20., 1./60.], dtype=tf.float32) ddz3D = tf.reshape(ddz1D, shape=(1,1,-1,1,1)) strides = [1,1,1,1,1] var_pad = periodic_padding( var, ((0,0),(0,0),(3,3)) ) output = tf.nn.conv3d(var_pad, ddz3D, strides, padding = 'VALID', data_format = 'NDHWC', name=name) output = tf.scalar_mul(1./dz, output) return output def d2dx2(inpt, channel, dx, scope='d2dx2', name=None): inpt_shape = inpt.get_shape().as_list() var = tf.expand_dims( inpt[:,:,:,:,channel], axis=4 ) with tf.variable_scope(scope): ddx1D = tf.constant([1./90., -3./20., 3./2., -49./18., 3./2., -3./20., 1./90.], dtype=tf.float32) ddx3D = tf.reshape(ddx1D, shape=(-1,1,1,1,1)) strides = [1,1,1,1,1] var_pad = periodic_padding( var, ((3,3),(0,0),(0,0)) ) output = tf.nn.conv3d(var_pad, ddx3D, strides, padding = 'VALID', data_format = 'NDHWC', name=name) output = tf.scalar_mul(1./dx**2, output) return output def d2dy2(inpt, channel, dy, scope='d2dy2', name=None): inpt_shape = inpt.get_shape().as_list() var = tf.expand_dims( inpt[:,:,:,:,channel], axis=4 ) with tf.variable_scope(scope): ddy1D = tf.constant([1./90., -3./20., 3./2., -49./18., 3./2., -3./20., 1./90.], dtype=tf.float32) ddy3D = tf.reshape(ddy1D, shape=(1,-1,1,1,1)) strides = [1,1,1,1,1] var_pad = periodic_padding( var, ((0,0),(3,3),(0,0)) ) output = tf.nn.conv3d(var_pad, ddy3D, strides, padding = 'VALID', data_format = 'NDHWC', name=name) output = tf.scalar_mul(1./dy**2, output) return output def d2dz2(inpt, channel, dz, scope='d2dz2', name=None): inpt_shape = inpt.get_shape().as_list() var = tf.expand_dims( inpt[:,:,:,:,channel], axis=4 ) with tf.variable_scope(scope): ddz1D = tf.constant([1./90., -3./20., 3./2., -49./18., 3./2., -3./20., 1./90.], dtype=tf.float32) ddz3D = tf.reshape(ddz1D, shape=(1,1,-1,1,1)) strides = [1,1,1,1,1] var_pad = periodic_padding( var, ((0,0),(0,0),(3,3)) ) output = tf.nn.conv3d(var_pad, ddz3D, strides, padding = 'VALID', data_format = 'NDHWC', name=name) output = tf.scalar_mul(1./dz**2, output) return output def get_TKE(inpt, name='TKE'): with tf.name_scope(name): TKE = tf.square( inpt[:,:,:,0] ) TKE = TKE + tf.square( inpt[:,:,:,1] ) TKE = TKE + tf.square( inpt[:,:,:,2] ) TKE = 0.5*TKE TKE = tf.expand_dims(TKE, axis=4) return TKE def get_velocity_grad(inpt, dx, dy, dz, scope='vel_grad', name=None): with tf.variable_scope(scope): dudx = ddx(inpt, 0, dx, scope='dudx') dudy = ddy(inpt, 0, dy, scope='dudy') dudz = ddz(inpt, 0, dz, scope='dudz') dvdx = ddx(inpt, 1, dx, scope='dvdx') dvdy = ddy(inpt, 1, dy, scope='dvdy') dvdz = ddz(inpt, 1, dz, scope='dvdz') dwdx = ddx(inpt, 2, dx, scope='dwdx') dwdy = ddy(inpt, 2, dy, scope='dwdy') dwdz = ddz(inpt, 2, dz, scope='dwdz') return dudx, dvdx, dwdx, dudy, dvdy, dwdy, dudz, dvdz, dwdz def get_strain_rate_mag2(vel_grad, scope='strain_rate_mag', name=None): dudx, dvdx, dwdx, dudy, dvdy, dwdy, dudz, dvdz, dwdz = vel_grad strain_rate_mag2 = dudx**2 + dvdy**2 + dwdz**2 \ + 2*( (0.5*(dudy + dvdx))**2 + (0.5*(dudz + dwdx))**2 + (0.5*(dvdz + dwdy))**2 ) return strain_rate_mag2 def get_vorticity(vel_grad, scope='vorticity', name=None): dudx, dvdx, dwdx, dudy, dvdy, dwdy, dudz, dvdz, dwdz = vel_grad vort_x = dwdy - dvdz vort_y = dudz - dwdx vort_z = dvdx - dudy return vort_x, vort_y, vort_z def get_enstrophy(vorticity, name='enstrophy'): omega_x, omega_y, omega_z = vorticity with tf.name_scope(name): Omega = omega_x**2 + omega_y**2 + omega_z**2 return Omega def get_continuity_residual(vel_grad, name='continuity'): dudx, dvdx, dwdx, dudy, dvdy, dwdy, dudz, dvdz, dwdz = vel_grad with tf.name_scope(name): res = dudx + dvdy + dwdz return res def get_pressure_residual(inpt, vel_grad, dx, dy, dz, scope='pressure'): dudx, dvdx, dwdx, dudy, dvdy, dwdy, dudz, dvdz, dwdz = vel_grad with tf.variable_scope(scope): d2pdx2 =d2dx2(inpt, 3, dx) d2pdy2 =d2dy2(inpt, 3, dy) d2pdz2 =d2dz2(inpt, 3, dz) res = (d2pdx2 + d2pdy2 + d2pdz2) res = res + dudx*dudx + dvdy*dvdy + dwdz*dwdz \ + 2*(dudy*dvdx + dudz*dwdx + dvdz*dwdy) return res def prelu_tf(inputs, name='Prelu'): with tf.variable_scope(name): alphas = tf.get_variable('alpha',inputs.get_shape()[-1], initializer=tf.zeros_initializer(),dtype=tf.float32) pos = tf.nn.relu(inputs) neg = alphas * (inputs - abs(inputs)) * 0.5 return pos + neg def lrelu(inputs, alpha): return tf.keras.layers.LeakyReLU(alpha=alpha).call(inputs) def denselayer(inputs, output_size): output = tf.layers.dense(inputs, output_size, activation=None, kernel_initializer=tf.contrib.layers.xavier_initializer()) return output def batchnorm(inputs, is_training): return slim.batch_norm(inputs,decay=0.9,epsilon=0.001, updates_collections=tf.GraphKeys.UPDATE_OPS, scale=False,fused=True,is_training=is_training) def print_configuration_op(FLAGS): print('[Configurations]:') a = FLAGS.mode #pdb.set_trace() for name, value in FLAGS.__flags.items(): if type(value) == float: print('\t%s: %f'%(name, value)) elif type(value) == int: print('\t%s: %d'%(name, value)) elif type(value) == str: print('\t%s: %s'%(name, value)) elif type(value) == bool: print('\t%s: %s'%(name, value)) else: print('\t%s: %s' % (name, value)) print('End of configuration') def phaseShift(inputs, shape_1, shape_2): # Tackle the condition when the batch is None X = tf.reshape(inputs, shape_1) X = tf.transpose(X, [0, 1, 4, 2, 5, 3, 6]) return tf.reshape(X, shape_2) # The implementation of PixelShuffler def pixelShuffler(inputs, scale=2): # size = tf.shape(inputs) size = inputs.get_shape().as_list() batch_size = size[0] d = size[1] h = size[2] w = size[3] c = size[4] # Get the target channel size channel_target = c // (scale * scale * scale) channel_factor = c // channel_target shape_1 = [batch_size, d, h, w, scale, scale, scale] shape_2 = [batch_size, d * scale, h * scale, w * scale, 1] # Reshape and transpose for periodic shuffling for each channel input_split = tf.split(inputs, channel_target, axis=4) output = tf.concat([phaseShift(x, shape_1, shape_2) for x in input_split], axis=4) return output def convert_to_rgba(a, vmin, vmax, cmap=plt.cm.viridis): rgba = cmap( (a - vmin)/(vmax - vmin) ) return rgba def get_slice_images(HR, LR, out, n_images=1): batch_size, nx, ny, nz, nvars = HR.shape grid_size = [nx, ny, nz] images = [] for i in range(n_images): batch = np.random.randint(batch_size, size=1)[0] var = np.random.randint(nvars-1, size=1)[0] plane = np.random.randint(3, size=1)[0] index = np.random.randint(grid_size[plane]//4, size=1)[0] print("batch = {}, var = {}, plane = {}, index = {}".format(batch, var, plane, index)) if plane == 0: var_HR = HR[batch, index*4, :, :, var] var_LR = LR[batch, index, :, :, var] var_out = out[batch, index*4, :, :, var] elif plane == 1: var_HR = HR[batch, :, index*4, :, var] var_LR = LR[batch, :, index, :, var] var_out = out[batch, :, index*4, :, var] elif plane == 2: var_HR = HR[batch, :, :,index*4, var] var_LR = LR[batch, :, :,index, var] var_out = out[batch, :, :,index*4, var] else: raise ValueError('Plane has to be 0, 1 or 2. Given {}'.format(plane)) vmin = var_HR.min() - 1.e-10 vmax = var_HR.max() + 1.e-10 # Repeat values to make it 64^3 var_LR = np.repeat(var_LR, 4, axis=0) var_LR = np.repeat(var_LR, 4, axis=1) im_HR = convert_to_rgba(var_HR, vmin, vmax) im_LR = convert_to_rgba(var_LR, vmin, vmax) im_out = convert_to_rgba(var_out, vmin, vmax) im = np.concatenate((im_HR, im_LR, im_out), axis=1) images.append( im ) return np.stack(images, axis=0) if __name__ == "__main__": import matplotlib.pyplot as plt import numpy as np # Check if derivatives are correct with periodic padding X = tf.placeholder(tf.float32, shape=(1, 5, 5, 5, 1)) pad = [(1,1),(1,1),(1,1)] X_pad = periodic_padding(X, pad) x_sum = tf.reduce_sum(X_pad) dX = tf.gradients(x_sum, X) with tf.Session() as sess: grad, = sess.run(dX, feed_dict={X: np.ones((1, 5, 5, 5, 1), dtype=np.float32)}) print(grad[0,:,:,:,0]) # Checking the padding in conv3d filtr = tf.constant(np.ones((3,3,3,1,1), dtype =np.float32)) output = conv3d_withPeriodicPadding(X, filtr, [1,1,1,1,1]) with tf.Session() as sess: XConv = sess.run(output, feed_dict={X: np.ones((1, 5, 5, 5, 1), dtype=np.float32)}) print(XConv.shape) print(XConv[0,:,:,:,0]) Xphase = tf.placeholder(tf.float32, shape=(1, 4, 4, 4, 4)) shape_1 = [1, 4, 4, 4, 2, 2, 1] shape_2 = [1, 8, 8, 4, 1] X_ups = phaseShift(Xphase, shape_1, shape_2) with tf.Session() as sess: xphase = np.zeros((1,4,4,4,4)) for i in range(4): xphase[:,:,:,:,i] = i xups = sess.run( X_ups, feed_dict={ Xphase: xphase } ) print(xups[0,:,:,0,0]) print(xups.shape) HR = tf.placeholder(tf.float32, shape=(2,16,16,16,4)) LR = filter3d(HR) with tf.Session() as sess: hr = np.zeros( (2,16,16,16,4), dtype=np.float32 ) for i in range(4): hr[:,:,:,:,i] = i lr = sess.run(LR, feed_dict={HR: hr} ) print(lr[0,:,:,:,3]) print(lr.shape) var = tf.placeholder(tf.float32, shape=(2,16,16,16,4)) dudx = ddx(var, 0, 2.*np.pi/16.) dudy = ddy(var, 0, 2.*np.pi/16.) dudz = ddz(var, 0, 2.*np.pi/16.) d2udx2 = d2dx2(var, 0, 2.*np.pi/16.) d2udy2 = d2dy2(var, 0, 2.*np.pi/16.) d2udz2 = d2dz2(var, 0, 2.*np.pi/16.) with tf.Session() as sess: x = np.linspace(0,2.*np.pi,num=16+1)[:-1].reshape((16,1,1)).repeat(16, axis=1).repeat(16, axis=2) y = np.linspace(0,2.*np.pi,num=16+1)[:-1].reshape((1,16,1)).repeat(16, axis=0).repeat(16, axis=2) z = np.linspace(0,2.*np.pi,num=16+1)[:-1].reshape((1,1,16)).repeat(16, axis=0).repeat(16, axis=1) v = np.zeros( (2,16,16,16,4) ) v[0,:,:,:,0] = np.sin(x) v[1,:,:,:,0] = np.cos(x) dv = np.zeros( (2,16,16,16,1) ) dv[0,:,:,:,0] = np.cos(x) dv[1,:,:,:,0] = -np.sin(x) dv2 = np.zeros( (2,16,16,16,1) ) dv2[0,:,:,:,0] = -np.sin(x) dv2[1,:,:,:,0] = -np.cos(x) du, du2 = sess.run( [dudx, d2udx2], feed_dict={var: v}) print(du.shape) print("Max X derivative error = {}".format( np.absolute(du-dv).max())) print("Max X 2nd derivative error = {}".format( np.absolute(du2-dv2).max())) v[0,:,:,:,0] = np.sin(y) v[1,:,:,:,0] = np.cos(y) dv[0,:,:,:,0] = np.cos(y) dv[1,:,:,:,0] = -np.sin(y) dv2[0,:,:,:,0] = -np.sin(y) dv2[1,:,:,:,0] = -np.cos(y) du, du2 = sess.run( [dudy, d2udy2], feed_dict={var: v}) print(du.shape) print("Max Y derivative error = {}".format( np.absolute(du-dv).max())) print("Max Y 2nd derivative error = {}".format( np.absolute(du2-dv2).max())) v[0,:,:,:,0] = np.sin(z) v[1,:,:,:,0] = np.cos(z) dv[0,:,:,:,0] = np.cos(z) dv[1,:,:,:,0] = -np.sin(z) dv2[0,:,:,:,0] = -np.sin(z) dv2[1,:,:,:,0] = -np.cos(z) du, du2 = sess.run( [dudz, d2udz2], feed_dict={var: v}) print(du.shape) print("Max Z derivative error = {}".format( np.absolute(du-dv).max())) print("Max Z 2nd derivative error = {}".format( np.absolute(du2-dv2).max())) for i in range(4): v[0,:,:,:,i] = np.sin(3*z) v[1,:,:,:,i] = np.cos(3*z) images = get_slice_images(v, v[:,::4,::4,::4,:], v, n_images=1) print(images.shape)