Python torch.utils.data.WeightedRandomSampler() Examples

The following are 5 code examples of torch.utils.data.WeightedRandomSampler(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module torch.utils.data, or try the search function.
Example #1
Source File: test_auto.py    From ignite with BSD 3-Clause "New" or "Revised" License 6 votes vote down vote up
def _test_auto_methods_xla(index, ws):
    """Run the auto-dataloader and auto-model/optimizer checks for an XLA setup.

    Args:
        index: Not used by this function.
            NOTE(review): presumably the process index supplied by the
            spawner that launches this helper -- confirm against the caller.
        ws: World size. It is forwarded to the checks both as ``ws`` and as
            ``nproc``.
    """
    # With a single process a plain DataLoader is expected. With several
    # processes the expected type is torch_xla's MpDeviceLoader when it can
    # be imported, and ignite's own _MpDeviceLoader otherwise.
    dl_type = DataLoader
    if ws > 1:
        from ignite.distributed.auto import _MpDeviceLoader

        try:
            from torch_xla.distributed.parallel_loader import MpDeviceLoader
        except ImportError:
            dl_type = _MpDeviceLoader
        else:
            dl_type = MpDeviceLoader

    scenarios = (
        {"batch_size": 1},
        {"batch_size": 10, "num_workers": 10},
        {"batch_size": 1, "sampler_name": "WeightedRandomSampler"},
    )
    for scenario in scenarios:
        _test_auto_dataloader(ws=ws, nproc=ws, dl_type=dl_type, **scenario)

    _test_auto_model_optimizer(ws, "xla")
Example #2
Source File: test_auto.py    From ignite with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
def _test_auto_dataloader(ws, nproc, batch_size, num_workers=1, sampler_name=None, dl_type=DataLoader):
    """Build a loader with ``auto_dataloader`` and assert its configuration.

    Args:
        ws: Expected world size; compared against ``idist.get_world_size()``.
        nproc: Divisor used for the expected per-loader ``num_workers``.
        batch_size: Batch size requested from ``auto_dataloader``.
        num_workers: Worker count requested from ``auto_dataloader``.
        sampler_name: ``None`` for no explicit sampler (shuffling enabled) or
            ``"WeightedRandomSampler"`` for a uniformly weighted sampler.
        dl_type: Type the returned loader is expected to be an instance of.

    Raises:
        RuntimeError: If ``sampler_name`` is neither ``None`` nor
            ``"WeightedRandomSampler"``.
    """
    data = torch.rand(100, 3, 12, 12)

    if sampler_name is None:
        sampler = None
    elif sampler_name == "WeightedRandomSampler":
        sampler = WeightedRandomSampler(weights=torch.ones(100), num_samples=100)
    else:
        raise RuntimeError("Unknown sampler name: {}".format(sampler_name))

    # Test auto_dataloader
    assert idist.get_world_size() == ws
    no_sampler = sampler is None
    dataloader = auto_dataloader(
        data, batch_size=batch_size, num_workers=num_workers, sampler=sampler, shuffle=no_sampler
    )

    assert isinstance(dataloader, dl_type)
    # Wrapper loaders expose the wrapped loader as ``_loader``; the remaining
    # checks are made against that inner object.
    if hasattr(dataloader, "_loader"):
        dataloader = dataloader._loader

    # The batch size is expected to be split across processes only when it
    # exceeds the world size.
    expected_batch_size = batch_size // ws if ws < batch_size else batch_size
    assert dataloader.batch_size == expected_batch_size

    # The worker count is expected to be ceil-divided by nproc when there are
    # at least as many workers as processes.
    expected_num_workers = (num_workers + nproc - 1) // nproc if ws <= num_workers else num_workers
    assert dataloader.num_workers == expected_num_workers

    if ws < 2:
        expected_sampler_type = RandomSampler if no_sampler else type(sampler)
    else:
        expected_sampler_type = DistributedSampler if no_sampler else DistributedProxySampler
    assert isinstance(dataloader.sampler, expected_sampler_type)

    if isinstance(dataloader, DataLoader):
        on_cuda = "cuda" in idist.device().type
        assert dataloader.pin_memory == on_cuda
Example #3
Source File: test_auto.py    From ignite with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
def test_auto_methods_no_dist():
    """Run the auto-dataloader and auto-model/optimizer checks with world size 1 on CPU."""
    scenarios = (
        {"batch_size": 1},
        {"batch_size": 10, "num_workers": 10},
        {"batch_size": 10, "sampler_name": "WeightedRandomSampler"},
    )
    for scenario in scenarios:
        _test_auto_dataloader(1, 1, **scenario)

    _test_auto_model_optimizer(1, "cpu")
Example #4
Source File: test_auto.py    From ignite with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
def test_auto_methods_gloo(distributed_context_single_node_gloo):
    """Run the auto-dataloader and auto-model/optimizer checks on CPU.

    Args:
        distributed_context_single_node_gloo: Fixture mapping; its
            ``"world_size"`` entry is used as both ``ws`` and ``nproc``.
    """
    world_size = distributed_context_single_node_gloo["world_size"]
    scenarios = (
        {"batch_size": 1},
        {"batch_size": 10, "num_workers": 10},
        {"batch_size": 10, "sampler_name": "WeightedRandomSampler"},
    )
    for scenario in scenarios:
        _test_auto_dataloader(ws=world_size, nproc=world_size, **scenario)

    _test_auto_model_optimizer(world_size, "cpu")
Example #5
Source File: test_auto.py    From ignite with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
def test_auto_methods_nccl(distributed_context_single_node_nccl):
    """Run the auto-dataloader and auto-model/optimizer checks on CUDA.

    Args:
        distributed_context_single_node_nccl: Fixture mapping; its
            ``"world_size"`` entry is used as both ``ws`` and ``nproc``.
    """
    # The fixture's "local_rank" entry was previously read into a local that
    # nothing used; the unused assignment has been dropped.
    ws = distributed_context_single_node_nccl["world_size"]
    _test_auto_dataloader(ws=ws, nproc=ws, batch_size=1)
    _test_auto_dataloader(ws=ws, nproc=ws, batch_size=10, num_workers=10)
    _test_auto_dataloader(ws=ws, nproc=ws, batch_size=1, sampler_name="WeightedRandomSampler")

    device = "cuda"
    _test_auto_model_optimizer(ws, device)