Python torch.utils.data.WeightedRandomSampler() Examples
The following are 5 code examples of torch.utils.data.WeightedRandomSampler().
You can vote up the ones you like or vote down the ones you don't like,
and go to the original project or source file by following the links above each example.
You may also want to check out all available functions/classes of the module torch.utils.data, or try the search function.
Example #1
Source File: test_auto.py From ignite with BSD 3-Clause "New" or "Revised" License | 6 votes |
def _test_auto_methods_xla(index, ws):
    """Run the auto-helper checks on the ``"xla"`` device.

    Args:
        index: accepted but not used by this function (kept for the caller's
            calling convention — presumably a per-process index; confirm).
        ws: world size forwarded to the dataloader and model/optimizer checks.
    """
    expected_loader_cls = DataLoader
    if ws > 1:
        # ignite's own wrapper is the fallback when torch_xla's loader is unavailable.
        from ignite.distributed.auto import _MpDeviceLoader

        try:
            from torch_xla.distributed.parallel_loader import MpDeviceLoader
        except ImportError:
            expected_loader_cls = _MpDeviceLoader
        else:
            expected_loader_cls = MpDeviceLoader

    dataloader_cases = (
        {"batch_size": 1},
        {"batch_size": 10, "num_workers": 10},
        {"batch_size": 1, "sampler_name": "WeightedRandomSampler"},
    )
    for case_kwargs in dataloader_cases:
        _test_auto_dataloader(ws=ws, nproc=ws, dl_type=expected_loader_cls, **case_kwargs)

    _test_auto_model_optimizer(ws, "xla")
Example #2
Source File: test_auto.py From ignite with BSD 3-Clause "New" or "Revised" License | 5 votes |
def _test_auto_dataloader(ws, nproc, batch_size, num_workers=1, sampler_name=None, dl_type=DataLoader):
    """Build a dataloader via ``auto_dataloader`` and assert how it was configured.

    Args:
        ws: expected world size; checked against ``idist.get_world_size()`` and
            used as the divisor for the expected per-process batch size.
        nproc: processes per node; used as the divisor for the expected
            per-process ``num_workers``.
        batch_size: total batch size passed to ``auto_dataloader``.
        num_workers: total number of workers passed to ``auto_dataloader``.
        sampler_name: ``None`` for no explicit sampler (``shuffle=True`` is then
            passed), or ``"WeightedRandomSampler"``.
        dl_type: expected type of the object returned by ``auto_dataloader``.

    Raises:
        RuntimeError: if ``sampler_name`` is not one of the supported values.
    """
    data = torch.rand(100, 3, 12, 12)

    if sampler_name is None:
        sampler = None
    elif sampler_name == "WeightedRandomSampler":
        sampler = WeightedRandomSampler(weights=torch.ones(100), num_samples=100)
    else:
        raise RuntimeError("Unknown sampler name: {}".format(sampler_name))

    # Test auto_dataloader
    assert idist.get_world_size() == ws
    dataloader = auto_dataloader(
        data, batch_size=batch_size, num_workers=num_workers, sampler=sampler, shuffle=sampler is None
    )

    assert isinstance(dataloader, dl_type)
    # Some wrapper types expose the wrapped loader as ``_loader``; inspect that one.
    if hasattr(dataloader, "_loader"):
        dataloader = dataloader._loader

    # Both checks use ``<=`` so that the equality boundary (ws == batch_size,
    # nproc == num_workers) expects the rescaled value, and the worker check
    # compares against nproc, the same divisor used in the expected formula.
    # NOTE(review): assumes auto_dataloader rescales when value >= divisor;
    # confirm against ignite.distributed.auto.auto_dataloader.
    if ws <= batch_size:
        assert dataloader.batch_size == batch_size // ws
    else:
        assert dataloader.batch_size == batch_size
    if nproc <= num_workers:
        assert dataloader.num_workers == (num_workers + nproc - 1) // nproc
    else:
        assert dataloader.num_workers == num_workers

    if ws < 2:
        # Single process: expect RandomSampler (from shuffle=True) or the sampler we passed.
        sampler_type = RandomSampler if sampler is None else type(sampler)
        assert isinstance(dataloader.sampler, sampler_type)
    else:
        # Multi-process: expect a distributed sampler, or a proxy around the user sampler.
        sampler_type = DistributedSampler if sampler is None else DistributedProxySampler
        assert isinstance(dataloader.sampler, sampler_type)
    if isinstance(dataloader, DataLoader):
        assert dataloader.pin_memory == ("cuda" in idist.device().type)
Example #3
Source File: test_auto.py From ignite with BSD 3-Clause "New" or "Revised" License | 5 votes |
def test_auto_methods_no_dist():
    """Run the auto-helper checks with world size 1 and device ``"cpu"``."""
    world_size = nproc = 1

    _test_auto_dataloader(world_size, nproc, batch_size=1)
    _test_auto_dataloader(world_size, nproc, batch_size=10, num_workers=10)
    _test_auto_dataloader(world_size, nproc, batch_size=10, sampler_name="WeightedRandomSampler")

    _test_auto_model_optimizer(world_size, "cpu")
Example #4
Source File: test_auto.py From ignite with BSD 3-Clause "New" or "Revised" License | 5 votes |
def test_auto_methods_gloo(distributed_context_single_node_gloo):
    """Run the auto-helper checks under the single-node gloo context fixture on ``"cpu"``."""
    world_size = distributed_context_single_node_gloo["world_size"]

    for extra_kwargs in (
        {"batch_size": 1},
        {"batch_size": 10, "num_workers": 10},
        {"batch_size": 10, "sampler_name": "WeightedRandomSampler"},
    ):
        _test_auto_dataloader(ws=world_size, nproc=world_size, **extra_kwargs)

    _test_auto_model_optimizer(world_size, "cpu")
Example #5
Source File: test_auto.py From ignite with BSD 3-Clause "New" or "Revised" License | 5 votes |
def test_auto_methods_nccl(distributed_context_single_node_nccl):
    """Run the auto-helper checks under the single-node NCCL context fixture.

    Exercises ``_test_auto_dataloader`` with and without an explicit
    ``WeightedRandomSampler``, then ``_test_auto_model_optimizer`` on ``"cuda"``.
    """
    # Only the world size is needed; the previously read (and unused)
    # "local_rank" entry has been dropped.
    ws = distributed_context_single_node_nccl["world_size"]

    _test_auto_dataloader(ws=ws, nproc=ws, batch_size=1)
    _test_auto_dataloader(ws=ws, nproc=ws, batch_size=10, num_workers=10)
    _test_auto_dataloader(ws=ws, nproc=ws, batch_size=1, sampler_name="WeightedRandomSampler")

    _test_auto_model_optimizer(ws, "cuda")