#!/usr/bin/ipython
"""Command-line entry point.

Parses options, configures GPU / Horovod settings and random seeds, loads the
per-dataset YAML config, then runs training (followed by a test pass) or
testing via ``main``.
"""
from __future__ import print_function
from misc.utils import PRINT, config_yaml
import os
from data_loader import get_loader
import config as cfg
import warnings
import sys
import torch
from misc.utils import horovod

hvd = horovod()
warnings.filterwarnings('ignore')


def _PRINT(config):
    """Write every attribute of ``config`` to ``config.log``.

    Attributes are written one per line as ``name: value``, sorted by name,
    between an "Options" header banner and an "End" footer banner.
    """
    PRINT(config.log, '------------ Options -------------')
    for key, value in sorted(vars(config).items()):
        # '%s' already applies str(), so no explicit conversion is needed.
        PRINT(config.log, '%s: %s' % (key, value))
    PRINT(config.log, '-------------- End ---------------')


def main(config):
    """Build the data loader and run the mode selected by ``config.mode``.

    Returns early, without training or testing, when
    ``misc.scores.set_score(config)`` returns a truthy value.

    * ``'train'``: constructs ``train.Train(config, data_loader)``
      (presumably this runs training - TODO confirm), then runs a
      ``test.Test`` pass on ``config.dataset_real``.
    * ``'test'``: runs ``Test.DEMO(config.DEMO_PATH)`` when ``DEMO_PATH`` is
      set, otherwise a regular test pass on ``config.dataset_real``.

    Any other mode does nothing after building the loader.
    """
    from torch.backends import cudnn
    # For fast training: let cuDNN benchmark and pick convolution algorithms.
    cudnn.benchmark = True

    data_loader = get_loader(
        config.mode_data,
        config.image_size,
        config.batch_size,
        config.dataset_fake,
        config.mode,
        num_workers=config.num_workers,
        all_attr=config.ALL_ATTR,
        c_dim=config.c_dim)

    from misc.scores import set_score
    if set_score(config):
        return

    if config.mode == 'train':
        from train import Train
        Train(config, data_loader)
        from test import Test
        test = Test(config, data_loader)
        test(dataset=config.dataset_real)

    elif config.mode == 'test':
        from test import Test
        test = Test(config, data_loader)
        if config.DEMO_PATH:
            test.DEMO(config.DEMO_PATH)
        else:
            test(dataset=config.dataset_real)


def _configure_devices(config):
    """Turn the ``config.GPU`` option string into a list of device ids.

    Also scales ``g_lr``/``d_lr`` (and, outside Horovod, ``batch_size``) by
    the device count. Mutates ``config`` in place.

    * ``'-1'``: Horovod mode. Pins this process to GPU ``hvd.local_rank()``
      and sets ``config.GPU`` to ``[0, ..., hvd.size() - 1]``. The batch size
      is not scaled in this mode.
    * ``'NO_CUDA'``: exports ``CUDA_VISIBLE_DEVICES='-1'`` (presumably hiding
      all GPUs) and leaves ``config.GPU == [-1]``, so the device count is 1.
    * otherwise: a comma-separated id list such as ``'0,1'``, which is
      exported as ``CUDA_VISIBLE_DEVICES``.
    """
    if config.GPU == '-1':
        # Horovod
        torch.cuda.set_device(hvd.local_rank())
        n_devices = hvd.size()
        config.GPU = list(range(n_devices))
    else:
        if config.GPU == 'NO_CUDA':
            config.GPU = '-1'
        os.environ["CUDA_VISIBLE_DEVICES"] = config.GPU
        config.GPU = [int(i) for i in config.GPU.split(',')]
        n_devices = len(config.GPU)
        config.batch_size *= n_devices
    config.g_lr *= n_devices
    config.d_lr *= n_devices


if __name__ == '__main__':
    from misc.options import base_parser
    config = base_parser()

    _configure_devices(config)

    torch.manual_seed(config.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(config.seed)

    config_yaml(config, 'datasets/{}.yaml'.format(config.dataset_fake))
    config = cfg.update_config(config)

    if config.mode == 'train':
        # Only Horovod rank 0 records the command line and options.
        if hvd.rank() == 0:
            PRINT(config.log, ' '.join(sys.argv))
            _PRINT(config)
        try:
            main(config)
        finally:
            # Close the log even when training/testing raises.
            config.log.close()
    else:
        main(config)