Python data.data_loader.CreateDataLoader() Examples

The following are 2 code examples of data.data_loader.CreateDataLoader(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module data.data_loader, or try the search function.
Example #1
Source File: test_function.py    From non-stationary_texture_syn with MIT License 5 votes vote down vote up
def test_func(opt_train, webpage, epoch='latest'):
    """Run the test model on the test split and save its visuals to ``webpage``.

    A deep copy of ``opt_train`` is reconfigured for testing, so the caller's
    options object is left untouched. At most ``opt.how_many`` (set to 50
    here) items from the dataset are processed.

    Args:
        opt_train: Training options; copied, then overridden with test settings.
        webpage: Page object supplied by the caller; receives the saved images
            and is saved once at the end.
        epoch: Checkpoint identifier stored in ``opt.which_epoch`` and passed
            on to the visualizer.
    """
    opt = copy.deepcopy(opt_train)
    print(opt)

    # Test-time overrides, applied in order. The test code only supports a
    # single thread and a batch size of 1, with no shuffling and no flipping.
    test_overrides = (
        ('results_dir', './results/'),  # where results are written during training
        ('isTrain', False),
        ('nThreads', 1),
        ('batchSize', 1),
        ('serial_batches', True),
        ('no_flip', True),
        ('dataroot', opt.dataroot + '/test'),
        ('model', 'test'),
        ('dataset_mode', 'single'),
        ('which_epoch', epoch),
        ('how_many', 50),
        ('phase', 'test'),
    )
    for option_name, option_value in test_overrides:
        setattr(opt, option_name, option_value)

    loader = CreateDataLoader(opt)
    dataset = loader.load_data()
    model = create_model(opt)
    visualizer = Visualizer(opt)

    # Push each item through the model and store what it produced.
    for index, sample in enumerate(dataset):
        if index >= opt.how_many:
            break
        model.set_input(sample)
        model.test()
        visuals = model.get_current_visuals()
        img_path = model.get_image_paths()
        print('process image... %s' % img_path)
        visualizer.save_images_epoch(webpage, visuals, img_path, epoch)

    webpage.save()
Example #2
Source File: ocrd_anybaseocr_dewarp.py    From ocrd_anybaseocr with Apache License 2.0 5 votes vote down vote up
def prepare_data(self, opt, page_img, path):

        sys.path.append(path)
        from data.data_loader import CreateDataLoader
        
        data_loader = CreateDataLoader(opt)
        data_loader.dataset.A_paths = [page_img.filename]
        data_loader.dataset.dataset_size = len(data_loader.dataset.A_paths)
        data_loader.dataloader = torch.utils.data.DataLoader(data_loader.dataset,
                                                             batch_size=opt.batchSize,
                                                             shuffle=not opt.serial_batches,
                                                             num_workers=int(opt.nThreads))
        dataset = data_loader.load_data()
        return dataset   
        # test