"""Export real and generated validation images to disk.

The script loads a generator from a validation checkpoint, runs it over the
"fdf" validation dataloader with an all-zero latent variable, and writes the
real images and the generated images as JPEG files into ``<save_path>/real``
and ``<save_path>/fake``. The default output folder is named ``fid_images``,
so the folders are presumably meant as input to an FID computation.
"""
import multiprocessing
import os
import shutil

import matplotlib.pyplot as plt
import numpy as np
import torch
import tqdm

from deep_privacy import torch_utils, config_parser
from deep_privacy.data_tools.dataloaders import load_dataset
from deep_privacy.data_tools.data_utils import denormalize_img
from deep_privacy.utils import load_checkpoint
from deep_privacy.inference.infer import init_generator


def read_args():
    """Parse the configuration and build the generator to evaluate.

    The config parser is asked for one extra option, ``target_path``
    (default ``""``). When it is empty, the output directory defaults to
    ``fid_images`` next to the config file.

    Returns:
        A tuple ``(generator, imsize, save_path, pose_size)`` where
        ``imsize`` is the checkpoint's ``"current_imsize"`` entry,
        ``save_path`` is the output directory and ``pose_size`` is
        ``config.models.pose_size``.
    """
    config = config_parser.initialize_and_validate_config([
        {"name": "target_path", "default": ""}
    ])
    save_path = config.target_path
    if save_path == "":
        default_path = os.path.join(
            os.path.dirname(config.config_path),
            "fid_images"
        )
        print("Setting target path to default:", default_path)
        save_path = default_path
    # The model name is the directory that contains the config file.
    model_name = config.config_path.split("/")[-2]
    ckpt = load_checkpoint(os.path.join("validation_checkpoints", model_name))
    # Alternative checkpoint location (next to the config file), kept for
    # reference:
    # ckpt = load_checkpoint(os.path.join(
    #     os.path.dirname(config.config_path),
    #     "checkpoints"))
    generator = init_generator(config, ckpt)
    imsize = ckpt["current_imsize"]
    pose_size = config.models.pose_size
    return generator, imsize, save_path, pose_size


def save_im(fp, im):
    """Write the image array ``im`` to the file path ``fp``.

    Defined at module level so that multiprocessing workers can import it.
    """
    plt.imsave(fp, im)


def save_images(images, path):
    """Save every image in ``images`` as ``<path>/<index>.jpg``.

    The writes are dispatched to a process pool with one worker per CPU
    core; the function returns once every write has finished.

    Args:
        images: Iterable of image arrays accepted by ``plt.imsave``.
        path: Existing directory that receives the files.
    """
    with multiprocessing.Pool(multiprocessing.cpu_count()) as pool:
        jobs = []
        for idx, im in enumerate(tqdm.tqdm(images, desc="Starting jobs")):
            fp = os.path.join(path, "{}.jpg".format(idx))
            jobs.append(pool.apply_async(save_im, (fp, im)))
        # .get() blocks until the job is done and re-raises worker errors.
        for j in tqdm.tqdm(jobs, desc="Saving images"):
            j.get()


def main():
    """Generate fake images for the validation set and dump real/fake folders."""
    generator, imsize, save_path, pose_size = read_args()
    batch_size = 128
    # Only the validation loader is used; the training loader is discarded.
    _, dataloader_val = load_dataset(
        "fdf", batch_size, 128, True, pose_size, True
    )
    dataloader_val.update_next_transition_variable(1.0)

    num_images = len(dataloader_val) * batch_size
    fake_images = np.zeros((num_images, imsize, imsize, 3), dtype=np.uint8)
    real_images = np.zeros((num_images, imsize, imsize, 3), dtype=np.uint8)

    # A single all-zero latent batch is reused (cloned) for every step.
    z = generator.generate_latent_variable(
        batch_size, "cuda", torch.float32
    ).zero_()
    with torch.no_grad():
        for idx, (real_data, condition, landmarks) in enumerate(
                tqdm.tqdm(dataloader_val)):
            fake_data = generator(condition, landmarks, z.clone())
            fake_data = torch_utils.image_to_numpy(
                fake_data, to_uint8=True, denormalize=True)
            real_data = torch_utils.image_to_numpy(
                real_data, to_uint8=True, denormalize=True)
            # NOTE(review): assumes every batch holds exactly `batch_size`
            # samples; confirm the dataloader never yields a short last batch.
            start_idx = idx * batch_size
            end_idx = (idx + 1) * batch_size
            real_images[start_idx:end_idx] = real_data
            fake_images[start_idx:end_idx] = fake_data

    # Move the model off the GPU and drop it before starting worker processes.
    generator.cpu()
    del generator

    # Start from a clean output directory.
    if os.path.isdir(save_path):
        shutil.rmtree(save_path)
    os.makedirs(os.path.join(save_path, "real"))
    os.makedirs(os.path.join(save_path, "fake"))

    save_images(real_images, os.path.join(save_path, "real"))
    save_images(fake_images, os.path.join(save_path, "fake"))


# The guard is required because save_images() starts a multiprocessing pool:
# with the "spawn" start method each worker re-imports this module, and
# without the guard it would re-run the whole inference pipeline.
if __name__ == "__main__":
    main()