""" Interactively generate images with a pre-trained PG-GAN and edit them along learned
feature axes, using a matplotlib GUI (one -/lock/+ button group per feature). """

import os
import glob
import sys
import numpy as np
import time
import pickle
import tensorflow as tf
import PIL
import matplotlib
matplotlib.use('TkAgg')  # the backend must be selected before pyplot is imported
import matplotlib.pyplot as plt
import matplotlib.widgets as widgets

plt.ion()

import src.tl_gan.feature_axis as feature_axis


def gen_time_str():
    """ tool function: return the current UTC time formatted as 'YYYYmmdd_HHMMSS' """
    return time.strftime("%Y%m%d_%H%M%S", time.gmtime())


# location to save images
path_gan_explore_interactive = './asset_results/pggan_celeba_feature_axis_explore_interactive/'
# makedirs also creates missing parent directories (e.g. ./asset_results); os.mkdir would fail
os.makedirs(path_gan_explore_interactive, exist_ok=True)


##
# load feature directions
path_feature_direction = './asset_results/pg_gan_celeba_feature_direction_40'
# glob returns files in arbitrary order, so sort before taking the last one; presumably the
# file names embed a gen_time_str()-style timestamp so the last is the most recent -- confirm
list_pathfile_feature_direction = sorted(
    glob.glob(os.path.join(path_feature_direction, 'feature_direction_*.pkl')))
if not list_pathfile_feature_direction:
    raise FileNotFoundError(
        'no feature_direction_*.pkl file found in {}'.format(path_feature_direction))
pathfile_feature_direction = list_pathfile_feature_direction[-1]

with open(pathfile_feature_direction, 'rb') as f:
    feature_direction_name = pickle.load(f)

# feature_direction holds one direction per column (indexed below as [:, idx_feature])
feature_direction = feature_direction_name['direction']
feature_name = feature_direction_name['name']
num_feature = feature_direction.shape[1]


##
# load gan model

# path to model code and weight
path_pg_gan_code = './src/model/pggan'
path_model = './asset_model/karras2018iclr-celebahq-1024x1024.pkl'
# presumably required so that unpickling the model can import the PG-GAN network code -- confirm
sys.path.append(path_pg_gan_code)

# create tf session
yn_CPU_only = False

if yn_CPU_only:
    config = tf.ConfigProto(device_count={'GPU': 0}, allow_soft_placement=True)
else:
    config = tf.ConfigProto(allow_soft_placement=True)

# an InteractiveSession installs itself as the default session
sess = tf.InteractiveSession(config=config)

try:
    # NOTE: pickle can execute arbitrary code on load; only use model files from a trusted source
    with open(path_model, 'rb') as file:
        G, D, Gs = pickle.load(file)
except FileNotFoundError:
    print('before running the code, download pre-trained model to project_root/asset_model/')
    raise

num_latent = Gs.input_shapes[0][1]


##
# Generate random latent variables (a batch of one)
latents = np.random.randn(1, *Gs.input_shapes[0][1:])
# Generate dummy labels, sized for the same batch of one
dummies = np.zeros([latents.shape[0]] + Gs.input_shapes[1][1:])


def gen_image(latents):
    """
    tool function to generate an image from latent variables

    :param latents: latent variables; the batch size must match the module-level `dummies` (1)
    :return: the first generated image as a uint8 array, channels last
    """
    images = Gs.run(latents, dummies)
    images = np.clip(np.rint((images + 1.0) / 2.0 * 255.0), 0.0, 255.0).astype(np.uint8)  # [-1,1] => [0,255]
    images = images.transpose(0, 2, 3, 1)  # NCHW => NHWC
    return images[0]


img_cur = gen_image(latents)


##
# plot figure with GUI
h_fig = plt.figure(figsize=[12, 6])
h_ax = plt.axes([0.0, 0.0, 0.5, 1.0])
h_ax.axis('off')
h_img = plt.imshow(img_cur)

# when True, the whole figure is saved to path_gan_explore_interactive after every feature edit
yn_save_fig = True


class GuiCallback(object):
    """ Holds the current latent vector and the per-feature lock state; its methods serve as
    the button callbacks of the GUI. """

    def __init__(self):
        self.latents = np.random.randn(1, *Gs.input_shapes[0][1:])
        self.feature_direction = feature_direction
        # True for features the user has locked; locked features are passed as idx_base
        self.feature_lock_status = np.zeros(num_feature).astype('bool')
        self._update_disentangled_direction()
        self._update_image()

    def _update_disentangled_direction(self):
        """ recompute the disentangled feature axes using the currently locked features as base """
        self.feature_direction_disentangled = feature_axis.disentangle_feature_axis_by_idx(
            self.feature_direction, idx_base=np.flatnonzero(self.feature_lock_status))

    def _update_image(self):
        """ regenerate the image from self.latents and redraw it """
        img_cur = gen_image(self.latents)
        h_img.set_data(img_cur)
        plt.draw()

    def random_gen(self, event):
        """ callback: draw a fresh random latent vector and show its image """
        self.latents = np.random.randn(1, *Gs.input_shapes[0][1:])
        self._update_image()

    def modify_along_feature(self, event, idx_feature, step_size=0.05):
        """ callback: move the latent vector by step_size along one disentangled feature axis """
        self.latents += self.feature_direction_disentangled[:, idx_feature] * step_size
        self._update_image()
        if yn_save_fig:
            plt.savefig(os.path.join(path_gan_explore_interactive,
                                     '{}_{}_{}.png'.format(gen_time_str(), feature_name[idx_feature],
                                                          ('pos' if step_size>0 else 'neg'))))

    def set_feature_lock(self, event, idx_feature):
        """ callback: toggle the lock of one feature and recompute the disentangled axes """
        self.feature_lock_status[idx_feature] = np.logical_not(self.feature_lock_status[idx_feature])
        self._update_disentangled_direction()


callback = GuiCallback()

ax_randgen = plt.axes([0.55, 0.90, 0.15, 0.05])
b_randgen = widgets.Button(ax_randgen, 'Random Generate')
b_randgen.on_clicked(callback.random_gen)


def get_loc_control(idx_feature, nrows=8, ncols=5, xywh_range=(0.51, 0.05, 0.48, 0.8)):
    """
    Location of the control cell of one feature, in figure coordinates.

    Cells are laid out row-major in an nrows x ncols grid covering xywh_range, with row 0 at
    the top. Feature indices >= nrows * ncols fall outside the range.

    :return: tuple (x, y, w, h) of the cell
    """
    r = idx_feature // ncols
    c = idx_feature % ncols
    x, y, w, h = xywh_range
    xywh = x + c * w / ncols, y + (nrows - r - 1) * h / nrows, w / ncols, h / nrows
    return xywh


# step applied per click; read by the button lambdas at click time, so changing it takes effect live
step_size = 0.4


def create_button(idx_feature):
    """ build the button group (label, '-', lock checkbox, '+') for one feature

    :return: (b_neg, b_pos, b_lock) widgets; callers must keep references so they stay responsive
    """
    x, y, w, h = get_loc_control(idx_feature)

    plt.text(x + w / 2, y + h / 2 + 0.01, feature_name[idx_feature], horizontalalignment='center',
             transform=plt.gcf().transFigure)

    ax_neg = plt.axes((x + w / 8, y, w / 4, h / 2))
    b_neg = widgets.Button(ax_neg, '-', hovercolor='0.1')
    b_neg.on_clicked(lambda event: callback.modify_along_feature(event, idx_feature, step_size=-1 * step_size))

    ax_pos = plt.axes((x + w * 5 / 8, y, w / 4, h / 2))
    b_pos = widgets.Button(ax_pos, '+', hovercolor='0.1')
    b_pos.on_clicked(lambda event: callback.modify_along_feature(event, idx_feature, step_size=+1 * step_size))

    ax_lock = plt.axes((x + w * 3 / 8, y, w / 4, h / 2))
    b_lock = widgets.CheckButtons(ax_lock, ['L'], [False])
    b_lock.on_clicked(lambda event: callback.set_feature_lock(event, idx_feature))

    return b_neg, b_pos, b_lock


# keep references to all widgets; matplotlib widgets stop responding once garbage-collected
list_buttons = [create_button(idx_feature) for idx_feature in range(num_feature)]

plt.show()


##
# sess.close()  # left commented out, presumably so Gs.run keeps working from the GUI callbacks