# -*- coding: utf-8 -*-
"""Latent-space arithmetic demo for a trained VAEGAN.

Characters are read from ``test.txt``, mapped to image files named
``U<codepoint>.jpg`` under ``--datadir``, and consumed in groups of three.
For every group the script encodes the images, computes
``z[0] - z[1] + z[2]`` on the encoder's ``'mu'`` output, decodes several
scaled versions of that vector and saves a comparison figure named
``test-arith-<iteration>.png`` in the current working directory.
"""
import os
import re
import sys
import time
import json
import pdb

import numpy as np
import tensorflow as tf
import matplotlib as mpl
mpl.use('Agg')  # must be selected before pyplot is imported
import matplotlib.pyplot as plt
from scipy.io import loadmat
from datetime import datetime
from model.vaegan import VAEGAN
from PIL import Image
from iohandler.datareader import find_files

FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_string(
    'datadir', './data/TWKai98_32x32', 'data dir')
tf.app.flags.DEFINE_string(
    'architecture', None, 'network architecture')
tf.app.flags.DEFINE_string('logdir', 'logdir', 'log dir')
tf.app.flags.DEFINE_string('checkpoint', None, 'model checkpoint')


def SingleFileReader(filename, shape, rtype='tanh', ext='jpg'):
    """Build a queue-based pipeline that reads image files in order.

    Args:
        filename: list of image file paths handed to
            ``tf.train.string_input_producer`` (read without shuffling).
        shape: ``[n, h, w, c]``; ``n`` is the batch size, ``h``/``w`` the
            size of the top-left crop and ``c`` the number of channels
            requested from the decoder.
        rtype: when ``'tanh'`` the pixel values are mapped with
            ``img / 127.5 - 1``; any other value leaves the float pixel
            values unscaled.
        ext: ``'jpg'``, ``'jpeg'`` or ``'png'``; selects the decoder.

    Returns:
        ``(imgs, key)``: the batched image tensor and the key tensor
        produced by the ``WholeFileReader``.

    Raises:
        ValueError: if ``ext`` is not one of the supported types.
    """
    n, h, w, c = shape
    if ext in ('jpg', 'jpeg'):
        decoder = tf.image.decode_jpeg
    elif ext == 'png':
        decoder = tf.image.decode_png
    else:
        # '{}' (not '{:s}') so that a non-string ext still yields ValueError.
        raise ValueError(
            'Unsupported file type: {}.'.format(ext) +
            ' (only *.png and *.jpg are supported)')

    filename_queue = tf.train.string_input_producer(filename, shuffle=False)
    reader = tf.WholeFileReader()
    key, value = reader.read(filename_queue)
    img = decoder(value, channels=c)
    # Crop from the top-left corner; this also gives the tensor the static
    # shape that tf.train.batch needs.
    img = tf.image.crop_to_bounding_box(img, 0, 0, h, w)
    img = tf.to_float(img)
    if rtype == 'tanh':
        img = tf.div(img, 127.5) - 1.
    imgs = tf.train.batch(
        [img],
        batch_size=n,
        capacity=1)
    return imgs, key


def fit_the_shape(x_, shape):
    """Tile a batch of images side by side into a single image array.

    Args:
        x_: array holding ``n * h * w * c`` values laid out as
            ``[n, h, w, c]``.
        shape: ``[n, h, w, c]``.

    Returns:
        Array of shape ``[h, w * n, c]``, or ``[h, w * n]`` when ``c == 1``.
    """
    n, h, w, c = shape
    # [n, h, w, c] -> [h, n, w, c] -> [h, n * w, c]
    x_ = np.reshape(
        np.transpose(x_, [1, 0, 2, 3]),
        [h, w * n, c])
    if x_.shape[-1] == 1:
        x_ = x_[:, :, 0]
    return x_


def _collect_filenames(datadir, list_path):
    """Return one ``U<codepoint>.jpg`` path per character in *list_path*."""
    filenames = []
    with open(list_path, encoding='utf8') as f:
        for line in f:
            for char in line.strip():
                filenames.append(os.path.join(
                    datadir, 'U{:d}.jpg'.format(ord(char))))
    return filenames


def _save_arithmetic_figure(path, x_, xh_, x_f_all, xr, n, n_scales, hwc):
    """Save the four-row comparison figure for one group to *path*."""
    h, w, c = hwc
    fig = plt.figure()
    try:
        # Row 1: input images.
        plt.subplot(4, 1, 1)
        plt.imshow(fit_the_shape(x_, [n, h, w, c]), cmap='gray')
        # Row 2: their reconstructions.
        plt.subplot(4, 1, 2)
        plt.imshow(fit_the_shape(xh_, [n, h, w, c]), cmap='gray')
        # Row 3: decodings of the scaled latent-arithmetic result.
        plt.subplot(4, 1, 3)
        plt.imshow(
            fit_the_shape(x_f_all, [n_scales, h, w, c]), cmap='gray')
        # Row 4: the same arithmetic done directly in pixel space.
        plt.subplot(4, 1, 4)
        plt.imshow(xr, cmap='gray')
        plt.savefig(path)
    finally:
        # Close even when plotting/saving fails so figures do not pile up.
        plt.close(fig)


def main():
    """Run the latent arithmetic demo using the command-line flags.

    Characters that do not fill a complete group of three at the end of
    ``test.txt`` are ignored.

    Raises:
        ValueError: if no checkpoint is given or ``test.txt`` yields no
            characters.
    """
    if FLAGS.checkpoint is None:
        raise ValueError('You must specify a checkpoint file.')

    # Default to the architecture file stored next to the checkpoint.
    if FLAGS.architecture is None:
        ckpt_dir = os.path.split(FLAGS.checkpoint)[0]
        architecture = os.path.join(ckpt_dir, 'architecture.json')
    else:
        architecture = FLAGS.architecture

    with open(architecture) as f:
        arch = json.load(f)

    h, w, c = arch['hwc']

    coord = tf.train.Coordinator()

    print(FLAGS.datadir)
    net = VAEGAN(arch, is_training=False)

    n = 3  # images per group: z[0] - z[1] + z[2]
    filenames = _collect_filenames(FLAGS.datadir, 'test.txt')
    if not filenames:
        # string_input_producer would otherwise fail with an opaque message.
        raise ValueError('No characters found in test.txt.')

    n_iter = len(filenames) // n
    xs, _ = SingleFileReader(filenames, shape=[n, h, w, c])
    z = net.encode(xs)['mu']
    xh = net.decode(z, tanh=True)

    # Separate decoding path fed with arbitrary latent vectors.
    z_any = tf.placeholder(dtype=tf.float32, shape=[None, arch['z_dim']])
    xh_any = net.decode(z_any, tanh=True)

    # Column vector of scales; broadcasts against the [1, z_dim] latent.
    a = np.asarray([.25, .5, 1., 2., 3]).reshape([-1, 1])

    saver = tf.train.Saver()
    with tf.Session() as sess:
        # Restore model
        print('Restoring model from {}'.format(FLAGS.checkpoint))
        saver.restore(sess, FLAGS.checkpoint)

        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        try:
            for it in range(n_iter):
                x_, z_, xh_ = sess.run([xs, z, xh])

                z_final = z_[0] - z_[1] + z_[2]
                z_final = np.reshape(z_final, [1, -1])
                x_f_all = sess.run(xh_any, feed_dict={z_any: a * z_final})

                # Pixel-space reference; only the first channel is shown.
                x_refer = x_[0] - x_[1] + x_[2]
                xr = x_refer[:, :, 0]

                _save_arithmetic_figure(
                    'test-arith-{}.png'.format(it),
                    x_, xh_, x_f_all, xr, n, len(a), (h, w, c))
        finally:
            # Stop the queue runners while the session is still open.
            coord.request_stop()
            coord.join(threads)


if __name__ == '__main__':
    main()