#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""ASNG-NAS (integer variant) architecture search and retraining for image inpainting.

For each experiment index the script
  1. searches a ``ProbablisticCAE`` architecture with ``arch_search_valid``,
  2. builds a new ``ProbablisticCAE`` from the search distribution's most likely
     architecture (``asng.mle()``) and trains it with ``train``,
  3. evaluates both models on the test split and appends the metrics to
     ``train_result.csv`` / ``retrain_result.csv`` in the output directory.
"""

import os
import csv
import argparse

import numpy as np
import torch
import torch.utils
import torchvision
import torch.backends.cudnn as cudnn

from common import utils
from common.utils import RandomPixelMasking, RandomHalfMasking, CenterMasking
from common.eval_test import evaluate
from inpainting_int.cae_model import ProbablisticCAE
from inpainting_int.train import arch_search_valid, train


# Supported corruption (masking) functions, keyed by the name accepted by
# ``experiment(corrupt_type=...)`` and the ``--corrupt_type`` CLI option.
_CORRUPT_FUNCS = {
    'RandomPixel': RandomPixelMasking,
    'RandomHalf': RandomHalfMasking,
    'Center': CenterMasking,
}


def load_data(path='../data/', data_name='celebA', img_size=64):
    """Load the train and test splits of an image dataset.

    Args:
        path: Root data directory (a trailing path separator is optional).
        data_name: Dataset name. ``'svhn'`` is loaded (and downloaded if
            needed) through ``torchvision.datasets.SVHN``; any other name is
            read with ``ImageFolder`` from ``<path>/<data_name>/train`` and
            ``<path>/<data_name>/test``.
        img_size: Image size forwarded to ``utils.data_transforms``.

    Returns:
        Tuple ``(train_data, test_data)`` of torchvision datasets.
    """
    print('Loading ' + data_name + ' data...')
    train_transform, test_transform = utils.data_transforms(img_size=img_size)

    if data_name != 'svhn':
        # ImageFolder expects images inside sub-folders,
        # e.g. ../data/celebA/train/image/aaa.png
        train_data = torchvision.datasets.ImageFolder(os.path.join(path, data_name, 'train'),
                                                      transform=train_transform)
        test_data = torchvision.datasets.ImageFolder(os.path.join(path, data_name, 'test'),
                                                     transform=test_transform)
    else:
        train_data = torchvision.datasets.SVHN(path, split='train', transform=train_transform, download=True)
        test_data = torchvision.datasets.SVHN(path, split='test', transform=test_transform, download=True)
        # To additionally train on the SVHN 'extra' split:
        # extra_data = torchvision.datasets.SVHN(path, split='extra', transform=train_transform, download=True)
        # train_data = torch.utils.data.ConcatDataset([train_data, extra_data])

    print('train_data_size: %d, test_data_size: %d' % (len(train_data), len(test_data)))
    return train_data, test_data


class SaveResult(object):
    """Write one CSV row of evaluation results per experiment."""

    HEADER = ['exp_index', 'train_time', 'MLE_MSE', 'MLE_PSNR', 'MLE_SSIM', 'det_param', 'max_param',
              'node_num', 'cat_d', 'cat_valid_d', 'n_cat', 'int_d', 'n_int', 'active_num', 'net_str']

    def __init__(self, res_file_name='result.csv'):
        """Create (truncating any existing file) ``res_file_name`` and write the header row."""
        self.res_file_name = res_file_name
        # newline='' is what the csv module expects; the writer controls line endings.
        with open(self.res_file_name, 'w', newline='') as fp:
            writer = csv.writer(fp, lineterminator='\n')
            writer.writerow(self.HEADER)

    def save(self, exp_index, model, train_time, res):
        """Append the results of one experiment.

        Args:
            exp_index: Experiment index written to the ``exp_index`` column.
            model: Model exposing ``asng``, ``module_info``, ``is_active``,
                ``get_params_mle()`` and ``mle_network_string()``.
            train_time: Training time returned by the training routine.
            res: Evaluation dict; must contain ``'MLE_MSE'``, ``'MLE_PSNR'``
                and ``'MLE_SSIM'``.
        """
        dist = model.asng
        # Total element count of every parameter registered on the model ('max_param').
        max_params = sum(param.numel() for param in model.parameters())
        net_str = model.mle_network_string(sep=' ')
        row = [exp_index, train_time, res['MLE_MSE'], res['MLE_PSNR'], res['MLE_SSIM'],
               model.get_params_mle(), max_params, len(model.module_info),
               dist.d_cat, dist.valid_d_cat, dist.n_cat, dist.d_int, dist.n_int,
               int(model.is_active.sum()), net_str]
        with open(self.res_file_name, 'a', newline='') as fp:
            writer = csv.writer(fp, lineterminator='\n')
            writer.writerow(row)


def _save_interval(max_ite):
    """Return the integer 'save' period used in the training ``period`` dict (1/50 of the run, at least 1)."""
    # An integer keeps e.g. `ite % period['save'] == 0` style checks exact
    # (presumably how train/arch_search_valid use it -- verify in inpainting_int.train).
    return max(1, int(max_ite) // 50)


def experiment(exp_num=1, start_id=0, data_name='celebA', dataset_path='../data/', corrupt_type='RandomPixel',
               gpu_id=0, init_delta_factor=0.0, batchsize=16, train_ite=200000, retrain_ite=500000,
               out_dir='./result/'):
    """Run ``exp_num`` search + retrain experiments and record their results.

    Args:
        exp_num: Number of experiments to run.
        start_id: Index of the first experiment (used in output file prefixes).
        data_name: Dataset name passed to :func:`load_data`.
        dataset_path: Data root passed to :func:`load_data`.
        corrupt_type: One of ``'RandomPixel'``, ``'RandomHalf'``, ``'Center'``.
        gpu_id: CUDA device index; a negative value skips device/cuDNN setup.
        init_delta_factor: ``delta_init_factor`` of the searched ``ProbablisticCAE``.
        batchsize: Mini-batch size for training and evaluation.
        train_ite: Number of iterations for the architecture search.
        retrain_ite: Number of iterations for retraining the selected architecture.
        out_dir: Output directory (created if missing).

    Raises:
        ValueError: If ``corrupt_type`` is not supported.
    """
    # Validate the configuration before touching the GPU or the filesystem.
    if corrupt_type not in _CORRUPT_FUNCS:
        raise ValueError('Invalid corrupt function type: {!r} (expected one of: {})'.format(
            corrupt_type, ', '.join(sorted(_CORRUPT_FUNCS))))

    if gpu_id >= 0:
        torch.cuda.set_device(gpu_id)
        cudnn.benchmark = True
        cudnn.enabled = True

    corrupt_func = _CORRUPT_FUNCS[corrupt_type]()
    os.makedirs(out_dir, exist_ok=True)

    train_res = SaveResult(res_file_name=os.path.join(out_dir, 'train_result.csv'))
    retrain_res = SaveResult(res_file_name=os.path.join(out_dir, 'retrain_result.csv'))

    with open(os.path.join(out_dir, 'description.txt'), 'w') as o:
        o.write('data_name: ' + data_name + '\n')
        o.write('corrupt_func: ' + corrupt_type + '\n')
        o.write('batchsize: %d\n' % batchsize)
        o.write('train_ite: %d\n' % train_ite)
        o.write('retrain_ite: %d\n' % retrain_ite)

    train_data, test_data = load_data(path=dataset_path, data_name=data_name, img_size=64)
    # Size of the first dimension of the first sample's image, used as the
    # network's input/output channel count (assumes channel-first tensors).
    ch_size = train_data[0][0].shape[0]

    for n in range(start_id, start_id + exp_num):
        prefix = os.path.join(out_dir, '{:02d}_'.format(n))

        # ---- Architecture search ----
        print('Architecture Search...')
        nn_model = ProbablisticCAE(in_ch_size=ch_size, out_ch_size=ch_size, row_size=1, col_size=20, level_back=5,
                                   downsample=True, k_sizes=(1, 3, 5), ch_range=(64, 256), c=None,
                                   delta_init_factor=init_delta_factor)
        optimizer = torch.optim.SGD(nn_model.parameters(), lr=0.025, momentum=0.9, weight_decay=0., nesterov=False)
        lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, train_ite)

        period = {'max_ite': train_ite, 'save': _save_interval(train_ite), 'verbose_ite': 100}
        # Bind the returned model, as the retraining phase does with `train`.
        nn_model, train_time = arch_search_valid(nn_model, train_data, test_data, corrupt_func, optimizer,
                                                 lr_scheduler, clip_value=5., batchsize=batchsize, lam=2,
                                                 valid_rate=0.5, gpu_id=gpu_id, period=period,
                                                 out_model=prefix + 'trained_model.pt',
                                                 log_file=prefix + 'train_log.csv')

        res = evaluate(nn_model, test_data, corrupt_func, gpu_id=gpu_id, batchsize=batchsize,
                       img_out_dir=prefix + 'trained_model_out_img/')
        train_res.save(n, nn_model, train_time, res)

        # To rebuild the search result from disk instead of memory, the original
        # code sketched: read <prefix>train_log.csv with pandas, take
        # `np.array(df.iloc[-1, 14:])` as theta and call
        # `nn_model.asng.load_theta_from_log(theta)`. NOTE(review): that sketch
        # constructed ProbablisticCAE with ch_nums/skip/M arguments, which differ
        # from the ch_range/c arguments used here -- update before reusing it.

        # ---- Retraining of the most likely architecture ----
        print('Retraining...')
        nn_model = ProbablisticCAE(in_ch_size=ch_size, out_ch_size=ch_size, row_size=1, col_size=20, level_back=5,
                                   downsample=True, k_sizes=(1, 3, 5), ch_range=(64, 256), c=nn_model.asng.mle())
        optimizer = torch.optim.Adam(nn_model.parameters(), lr=0.001, betas=(0.9, 0.999))
        # Decay the learning rate by 10x at 40% and 80% of the retraining run.
        milestones = [int(retrain_ite * 2 / 5), int(retrain_ite * 4 / 5)]
        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=milestones, gamma=0.1)

        period = {'max_ite': retrain_ite, 'save': _save_interval(retrain_ite), 'verbose_ite': 100}
        nn_model, train_time = train(nn_model, train_data, test_data, corrupt_func, optimizer, lr_scheduler,
                                     clip_value=5., batchsize=batchsize, gpu_id=gpu_id, period=period,
                                     out_model=prefix + 'retrained_model.pt', log_file=prefix + 'retrain_log.csv')

        res = evaluate(nn_model, test_data, corrupt_func, gpu_id=gpu_id, batchsize=batchsize,
                       img_out_dir=prefix + 'retrained_model_out_img/')
        retrain_res.save(n, nn_model, train_time, res)


def main():
    """Parse command-line arguments and run :func:`experiment`."""
    parser = argparse.ArgumentParser(description='ASNG-NAS (Int) for Inpainting')
    parser.add_argument('--exp_id_start', '-s', type=int, default=0, help='Starting index number of experiment')
    parser.add_argument('--exp_num', '-e', type=int, default=1, help='Number of experiments')
    parser.add_argument('--data_path', '-p', default='../data/', help='Data path')
    parser.add_argument('--data_name', '-d', default='celebA', help='Data name (celebA / cars / svhn)')
    parser.add_argument('--corrupt_type', '-c', default='RandomPixel', choices=sorted(_CORRUPT_FUNCS),
                        help='Corrupt function (RandomPixel / RandomHalf / Center)')
    parser.add_argument('--gpu_id', '-g', type=int, default=0, help='GPU ID')
    parser.add_argument('--init_delta_factor', '-f', type=float, default=0.0, help='Init delta factor')
    parser.add_argument('--batch_size', '-b', type=int, default=16, help='Mini-batch size')
    parser.add_argument('--train_ite', '-t', type=int, default=50000,
                        help='Maximum number of training iterations (W updates)')
    parser.add_argument('--retrain_ite', '-r', type=int, default=500000,
                        help='Maximum number of re-training iterations (W updates)')
    parser.add_argument('--out_dir', '-o', default='./result/', help='Output directory')
    args = parser.parse_args()

    # Each dataset / corruption pair gets its own sub-directory (trailing separator kept).
    out_dir = os.path.join(args.out_dir, args.data_name + '_' + args.corrupt_type, '')

    experiment(exp_num=args.exp_num, start_id=args.exp_id_start, data_name=args.data_name,
               dataset_path=args.data_path, corrupt_type=args.corrupt_type, gpu_id=args.gpu_id,
               init_delta_factor=args.init_delta_factor, batchsize=args.batch_size,
               train_ite=args.train_ite, retrain_ite=args.retrain_ite, out_dir=out_dir)


if __name__ == '__main__':
    main()