from __future__ import print_function
import argparse
from time import time
from pydoc import locate

import numpy as np
import theano
import theano.tensor as T
import lasagne
from scipy import optimize
from PIL import Image

from lib import AlexNet
from lib import HOGNet
from lib import activations
from lib.rng import np_rng
from lib.theano_utils import floatX, sharedX


def def_feature(layer='conv4', up_scale=4):
    # compile a function that extracts AlexNet features from real images
    print('COMPILING...')
    t = time()
    x = T.tensor4()
    x_t = AlexNet.transform_im(x)
    x_net = AlexNet.build_model(x_t, layer=layer, shape=(None, 3, 64, 64), up_scale=up_scale)
    AlexNet.load_model(x_net, layer=layer)
    x_f = lasagne.layers.get_output(x_net[layer], deterministic=True)
    _ftr = theano.function(inputs=[x], outputs=x_f)
    print('%.2f seconds to compile _feature function' % (time() - t))
    return _ftr


def def_bfgs(model_G, layer='conv4', npx=64, alpha=0.002):
    # compile the objective (and its gradient w.r.t. z) used by the L-BFGS solver:
    # cost(z) = alpha * ||F(G(tanh(z))) - F(x)||^2 + ||G(tanh(z)) - x||^2
    print('COMPILING...')
    t = time()
    x_f = T.tensor4()
    x = T.tensor4()
    z = T.matrix()

    tanh = activations.Tanh()
    gx = model_G(tanh(z))
    if layer == 'hog':
        gx_f = HOGNet.get_hog(gx, use_bin=True, BS=4)
    else:
        gx_t = AlexNet.transform_im(gx)
        gx_net = AlexNet.build_model(gx_t, layer=layer, shape=(None, 3, npx, npx))
        AlexNet.load_model(gx_net, layer=layer)
        gx_f = lasagne.layers.get_output(gx_net[layer], deterministic=True)

    f_rec = T.mean(T.sqr(x_f - gx_f), axis=(1, 2, 3)) * sharedX(alpha)  # feature-space loss
    x_rec = T.mean(T.sqr(x - gx), axis=(1, 2, 3))                       # pixel-space loss
    cost = T.sum(f_rec) + T.sum(x_rec)
    grad = T.grad(cost, z)
    output = [cost, grad, gx]
    _invert = theano.function(inputs=[z, x, x_f], outputs=output)
    print('%.2f seconds to compile _bfgs function' % (time() - t))
    return _invert, z


def def_predict(model_P):
    # compile the feed-forward network that maps an image to a latent code z
    print('COMPILING...')
    t = time()
    x = T.tensor4()
    z = model_P(x)
    _predict = theano.function([x], [z])
    print('%.2f seconds to compile _predict function' % (time() - t))
    return _predict


def def_invert_models(gen_model, layer='conv4', alpha=0.002):
    bfgs_model = def_bfgs(gen_model.model_G, layer=layer, npx=gen_model.npx, alpha=alpha)
    ftr_model = def_feature(layer=layer)
    predict_model = def_predict(gen_model.model_P)
    return gen_model, bfgs_model, ftr_model, predict_model


def predict_z(gen_model, _predict, ims, batch_size=32):
    # predict latent codes for a batch of images with the feed-forward network
    n = ims.shape[0]
    n_gen = 0
    zs = []
    n_batch = int(np.ceil(n / float(batch_size)))
    for i in range(n_batch):
        imb = gen_model.transform(ims[batch_size * i:min(n, batch_size * (i + 1)), :, :, :])
        zmb = _predict(imb)
        zs.append(zmb)
        n_gen += len(imb)
    zs = np.squeeze(np.concatenate(zs, axis=0))
    if np.ndim(zs) == 1:
        zs = zs[np.newaxis, :]
    return zs
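
# A minimal usage sketch (kept as a comment; it assumes a loaded dcgan_theano model
# `gen_model` and an image batch `ims`, prepared as in __main__ below, and only uses
# functions defined in this file):
#
#   gen_model, bfgs_model, ftr_model, predict_model = def_invert_models(gen_model, layer='conv4')
#   z0 = predict_z(gen_model, predict_model, ims)              # feed-forward initialization
#   recs, zs, _ = invert_bfgs_batch(gen_model, bfgs_model, ftr_model,
#                                   ims, z_predict=z0, npx=gen_model.npx)
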
def invert_bfgs_batch(gen_model, invert_model, ftr_model, ims, z_predict=None, npx=64):
    # reconstruct each image independently with L-BFGS, starting from the predicted z if given
    zs = []
    recs = []
    fs = []
    n_imgs = ims.shape[0]
    print('reconstruct %d images using bfgs' % n_imgs)

    for n in range(n_imgs):
        im_n = ims[[n], :, :, :]
        if z_predict is not None:
            z0_n = z_predict[[n], ...]
        else:
            z0_n = None
        gx, z_value, f_value = invert_bfgs(gen_model, invert_model, ftr_model,
                                           im=im_n, z_predict=z0_n, npx=npx)
        rec_im = (gx * 255).astype(np.uint8)
        fs.append(f_value[np.newaxis, ...])
        zs.append(z_value[np.newaxis, ...])
        recs.append(rec_im)

    recs = np.concatenate(recs, axis=0)
    zs = np.concatenate(zs, axis=0)
    fs = np.concatenate(fs, axis=0)
    return recs, zs, fs


def invert_bfgs(gen_model, invert_model, ftr_model, im, z_predict=None, npx=64):
    # optimize a single latent code with L-BFGS-B so that G(tanh(z)) matches the input image
    _f, z = invert_model
    nz = gen_model.nz
    if z_predict is None:
        z_predict = np_rng.uniform(-1., 1., size=(1, nz))
    else:
        z_predict = floatX(z_predict)
    z_predict = np.arctanh(z_predict)  # optimize in the pre-tanh space
    im_t = gen_model.transform(im)
    ftr = ftr_model(im_t)

    prob = optimize.minimize(f_bfgs, z_predict, args=(_f, im_t, ftr),
                             tol=1e-6, jac=True, method='L-BFGS-B', options={'maxiter': 200})
    print('n_iters = %3d, f = %.3f' % (prob.nit, prob.fun))

    z_opt = prob.x
    z_opt_n = floatX(z_opt[np.newaxis, :])
    [f_opt, g, gx] = _f(z_opt_n, im_t, ftr)
    gx = gen_model.inverse_transform(gx, npx=npx)
    z_opt = np.tanh(z_opt)
    return gx, z_opt, f_opt


def f_bfgs(z0, _f, x, x_f):
    # objective wrapper for scipy.optimize: returns float64 cost and gradient
    z0_n = floatX(z0[np.newaxis, :])
    [f, g, gx] = _f(z0_n, x, x_f)
    f = f.astype(np.float64)
    g = g[0].astype(np.float64)
    return f, g


def invert_images_CNN_opt(invert_models, ims, solver='cnn'):
    # solver: 'cnn' (feed-forward only), 'opt' (L-BFGS only), or 'cnn_opt' (feed-forward + L-BFGS)
    gen_model, invert_model, ftr_model, predict_model = invert_models
    n_imgs = len(ims)
    print('process %d images' % n_imgs)
    # gen_samples(self, z0=None, n=32, batch_size=32, use_transform=True)
    if solver == 'cnn' or solver == 'cnn_opt':
        z_predict = predict_z(gen_model, predict_model, ims, batch_size=n_imgs)
    else:
        z_predict = None

    if solver == 'cnn':
        recs = gen_model.gen_samples(z0=z_predict, n=n_imgs, batch_size=n_imgs)
        zs = None
    if solver == 'cnn_opt' or solver == 'opt':
        recs, zs, loss = invert_bfgs_batch(gen_model, invert_model, ftr_model,
                                           ims, z_predict=z_predict, npx=gen_model.npx)
    return recs, zs, z_predict


def parse_args():
    parser = argparse.ArgumentParser(description='iGAN: Interactive Visual Synthesis Powered by GAN')
    parser.add_argument('--model_name', dest='model_name', help='the model name', default='shoes_64', type=str)
    parser.add_argument('--model_type', dest='model_type', help='the generative model and its deep learning framework', default='dcgan_theano', type=str)
    parser.add_argument('--input_image', dest='input_image', help='input image', default='./pics/shoes_test.png', type=str)
    parser.add_argument('--output_image', dest='output_image', help='output reconstruction image', default=None, type=str)
    parser.add_argument('--model_file', dest='model_file', help='the file that stores the generative model', type=str, default=None)
    # cnn: feed-forward network; opt: optimization based; cnn_opt: hybrid of the two methods
    parser.add_argument('--solver', dest='solver', help='solver (cnn, opt, or cnn_opt)', type=str, default='cnn_opt')
    args = parser.parse_args()
    return args
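
# Example invocation (the script filename below is assumed; flags and defaults come from parse_args):
#   python iGAN_predict.py --model_name shoes_64 --model_type dcgan_theano \
#       --input_image ./pics/shoes_test.png --solver cnn_opt
# With these defaults, the reconstruction is written next to the input as
# ./pics/shoes_test_cnn_opt.png (see the output_image logic in __main__ below).
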
if __name__ == "__main__":
    args = parse_args()

    if not args.model_file:  # if the model file is not specified
        args.model_file = './models/%s.%s' % (args.model_name, args.model_type)
    if not args.output_image:  # if the output image path is not specified
        args.output_image = args.input_image.replace('.png', '_%s.png' % args.solver)

    for arg in vars(args):
        print('[%s] =' % arg, getattr(args, arg))

    # read a single image (PIL reports size as width x height)
    im = Image.open(args.input_image)
    [w, h] = im.size
    print('read image: %s (%dx%d)' % (args.input_image, w, h))

    # define the theano models
    model_class = locate('model_def.%s' % args.model_type)
    gen_model = model_class.Model(model_name=args.model_name, model_file=args.model_file, use_predict=True)
    invert_models = def_invert_models(gen_model, layer='conv4', alpha=0.002)

    # pre-processing steps: resize to the generator resolution and add a batch dimension
    npx = gen_model.npx
    im = im.resize((npx, npx))
    im = np.array(im)
    im_pre = im[np.newaxis, :, :, :]

    # run the model
    rec, _, _ = invert_images_CNN_opt(invert_models, im_pre, solver=args.solver)
    rec = np.squeeze(rec)
    rec_im = Image.fromarray(rec)

    # resize the reconstruction back to the original input size
    rec_im = rec_im.resize((w, h))
    print('write result to %s' % args.output_image)
    rec_im.save(args.output_image)