import itertools from torch.utils.data import Subset class ConcatDataloader: def __init__(self, dataloaders): self.loaders = dataloaders def __iter__(self): self.iters = [iter(loader) for loader in self.loaders] self.idx_cycle = itertools.cycle(list(range(len(self.loaders)))) return self def __next__(self): loader_idx = next(self.idx_cycle) loader = self.iters[loader_idx] batch = next(loader) if isinstance(loader.dataset, Subset): dataset = loader.dataset.dataset else: dataset = loader.dataset dat_name = dataset.pose_dataset.name batch["dataset"] = dat_name if dat_name == "stereohands" or dat_name == "zimsynth": batch["root"] = "palm" else: batch["root"] = "wrist" if dat_name == "stereohands": batch["use_stereohands"] = True else: batch["use_stereohands"] = False batch["split"] = dataset.pose_dataset.split return batch def __len__(self): return min(len(loader) for loader in self.loaders) * len(self.loaders)