Python torch.set_num_threads() Examples

The following are 26 code examples of torch.set_num_threads(), which sets the number of threads PyTorch uses for intra-op parallelism on the CPU. You can go to the original project or source file via the links above each example. You may also want to check out all available functions/classes of the module torch, or try the search function.
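As a quick orientation before the project examples, here is a minimal sketch of the call itself (the default thread count printed will depend on your machine):

import torch

print(torch.get_num_threads())  # default intra-op thread count, machine-dependent
torch.set_num_threads(1)        # restrict CPU intra-op parallelism to one thread
assert torch.get_num_threads() == 1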
Example #1
Source File: dataloader_new.py    From MetaFGNet with MIT License
def _worker_loop(dataset, index_queue, data_queue, collate_fn):
    global _use_shared_memory
    _use_shared_memory = True

    torch.set_num_threads(1)
    while True:
        r = index_queue.get()
        if r is None:
            data_queue.put(None)
            break
        idx, batch_indices = r
        try:
            samples = collate_fn([dataset[i] for i in batch_indices])
        except Exception:
            data_queue.put((idx, ExceptionWrapper(sys.exc_info())))
        else:
            data_queue.put((idx, samples)) 
Example #2
Source File: buffer.py    From rlpyt with MIT License
def get_example_outputs(agent, env, examples, subprocess=False):
    """Do this in a sub-process to avoid setup conflict in master/workers (e.g.
    MKL)."""
    if subprocess:  # i.e. in subprocess.
        import torch
        torch.set_num_threads(1)  # Some fix to prevent MKL hang.
    o = env.reset()
    a = env.action_space.sample()
    o, r, d, env_info = env.step(a)
    r = np.asarray(r, dtype="float32")  # Must match torch float dtype here.
    agent.reset()
    agent_inputs = torchify_buffer(AgentInputs(o, a, r))
    a, agent_info = agent.step(*agent_inputs)
    if "prev_rnn_state" in agent_info:
        # Agent leaves B dimension in, strip it: [B,N,H] --> [N,H]
        agent_info = agent_info._replace(prev_rnn_state=agent_info.prev_rnn_state[0])
    examples["observation"] = o
    examples["reward"] = r
    examples["done"] = d
    examples["env_info"] = env_info
    examples["action"] = a  # OK to put torch tensor here, could numpify.
    examples["agent_info"] = agent_info 
Example #3
Source File: worker.py    From rlpyt with MIT License
def initialize_worker(rank, seed=None, cpu=None, torch_threads=None):
    """Assign CPU affinity, set random seed, set torch_threads if needed to
    prevent MKL deadlock.
    """
    log_str = f"Sampler rank {rank} initialized"
    cpu = [cpu] if isinstance(cpu, int) else cpu
    p = psutil.Process()
    try:
        if cpu is not None:
            p.cpu_affinity(cpu)
        cpu_affin = p.cpu_affinity()
    except AttributeError:
        cpu_affin = "UNAVAILABLE MacOS"
    log_str += f", CPU affinity {cpu_affin}"
    torch_threads = (1 if torch_threads is None and cpu is not None else
        torch_threads)  # Default to 1 to avoid possible MKL hang.
    if torch_threads is not None:
        torch.set_num_threads(torch_threads)
    log_str += f", Torch threads {torch.get_num_threads()}"
    if seed is not None:
        set_seed(seed)
        time.sleep(0.3)  # (so the printing from set_seed is not intermixed)
        log_str += f", Seed {seed}"
    logger.log(log_str) 
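A hypothetical invocation of the function above, pinning the worker to two cores with a single Torch thread and a fixed seed:

initialize_worker(rank=0, seed=42, cpu=[0, 1], torch_threads=1)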
Example #4
Source File: torch_model.py    From rltime with Apache License 2.0
def __init__(self, observation_space):
        """Initializes the model with the given observation space

        Currently supported observation spaces are:
        - Box spaces
        - A tuple of box spaces, where the 1st one is the 'main' observation,
          and the rest contain additional 1D vectors of linear features for
          the model which are fed to one of the non-convolutional layers
          (Usually the RNN layer)
        """
        super().__init__()
        # When using multiple actors, each with its own CPU copy of the model,
        # we need to limit them to be single-threaded, otherwise they slow each
        # other down. This should not affect training time if training is on
        # the GPU.
        torch.set_num_threads(1)
        self._setup_inputs(observation_space) 
Example #5
Source File: async_actor.py    From rltime with Apache License 2.0
def run(self):
        # TODO Fix this dependency. The policy itself sets the thread limit
        # to 1, but this configuration seems to be per-thread in PyTorch,
        # so we need to set it here too :(
        import torch
        torch.set_num_threads(1)

        while not self._close_event.is_set():
            # If queue is full, wait for it not to be
            while len(self._queue) >= self._max_pending:
                self._queue_empty_event.wait()
                self._queue_empty_event.clear()
            # Get the next sample(s)
            samples = super().get_samples(1)
            self._queue.extend(samples)
            self._queue_fill_event.set() 
Example #6
Source File: remote.py    From leap with MIT License
def __init__(
                self,
                env,
                policy,
                exploration_policy,
                max_path_length,
                train_rollout_function,
                eval_rollout_function,
        ):
            torch.set_num_threads(1)
            self._env = env
            self._policy = policy
            self._exploration_policy = exploration_policy
            self._max_path_length = max_path_length
            self.train_rollout_function = cloudpickle.loads(train_rollout_function)
            self.eval_rollout_function = cloudpickle.loads(eval_rollout_function) 
Example #7
Source File: cpu_sampler.py    From rlpyt with MIT License
def initialize(self, affinity):
        """
        Runs inside the main sampler process.  Sets process hardware affinity
        and calls the ``agent.async_cpu()`` initialization.  Then proceeds with
        usual parallel sampler initialization.
        """
        p = psutil.Process()
        if affinity.get("set_affinity", True):
            p.cpu_affinity(affinity["master_cpus"])
        torch.set_num_threads(1)  # Needed to prevent MKL hang :( .
        self.agent.async_cpu(share_memory=True)
        super().initialize(
            agent=self.agent,
            affinity=affinity,
            seed=self.seed,
            bootstrap_value=None,  # Don't need here.
            traj_info_kwargs=None,  # Already done.
            world_size=1,
            rank=0,
        ) 
Example #8
Source File: mono_3d_estimation.py    From 3d-vehicle-tracking with BSD 3-Clause "New" or "Revised" License
def main():
    torch.set_num_threads(multiprocessing.cpu_count())
    args = parse_args()
    if args.set == 'gta':
        from model.model import Model
    elif args.set == 'kitti':
        from model.model_cen import Model
    else:
        raise ValueError("Model not found")

    model = Model(args.arch,
                  args.roi_name,
                  args.down_ratio,
                  args.roi_kernel)
    model = nn.DataParallel(model)
    model = model.to(args.device)

    if args.phase == 'train':
        run_training(model, args)
    elif args.phase == 'test':
        test_model(model, args) 
Example #9
Source File: dataloader.py    From TAKG with MIT License
def _worker_loop(dataset, index_queue, data_queue, collate_fn):
    global _use_shared_memory
    _use_shared_memory = True

    torch.set_num_threads(1)
    while True:
        r = index_queue.get()
        if r is None:
            data_queue.put(None)
            break
        idx, batch_indices = r
        try:
            samples = collate_fn([dataset[i] for i in batch_indices])
        except Exception:
            data_queue.put((idx, ExceptionWrapper(sys.exc_info())))
        else:
            data_queue.put((idx, samples)) 
Example #10
Source File: dataloader.py    From weakalign with MIT License
def _worker_loop(dataset, index_queue, data_queue, collate_fn, rng_seed):
    global _use_shared_memory
    _use_shared_memory = True

    np.random.seed(rng_seed)
    torch.set_num_threads(1)
    while True:
        r = index_queue.get()
        if r is None:
            data_queue.put(None)
            break
        idx, batch_indices = r
        try:
            samples = collate_fn([dataset[i] for i in batch_indices])
        except Exception:
            data_queue.put((idx, ExceptionWrapper(sys.exc_info())))
        else:
            data_queue.put((idx, samples)) 
Example #11
Source File: async_rl.py    From rlpyt with MIT License
def optim_startup(self):
        """
        Sets the hardware affinity, moves the agent's model parameters onto
        device and initialize data-parallel agent, if applicable.  Computes
        optimizer throttling settings.
        """
        main_affinity = self.affinity.optimizer[0]
        p = psutil.Process()
        if main_affinity.get("set_affinity", True):
            p.cpu_affinity(main_affinity["cpus"])
        logger.log(f"Optimizer master CPU affinity: {p.cpu_affinity()}.")
        torch.set_num_threads(main_affinity["torch_threads"])
        logger.log(f"Optimizer master Torch threads: {torch.get_num_threads()}.")
        self.agent.to_device(main_affinity.get("cuda_idx", None))
        if self.world_size > 1:
            self.agent.data_parallel()
        self.algo.optim_initialize(rank=0)
        throttle_itr = 1 + getattr(self.algo,
            "min_steps_learn", 0) // self.sampler_batch_size
        delta_throttle_itr = (self.algo.batch_size * self.world_size *
            self.algo.updates_per_optimize /  # (is updates_per_sync)
            (self.sampler_batch_size * self.algo.replay_ratio))
        self.initialize_logging()
        return throttle_itr, delta_throttle_itr 
Example #12
Source File: async_rl.py    From rlpyt with MIT License
def startup(self):
        torch.distributed.init_process_group(
            backend="nccl",
            rank=self.rank,
            world_size=self.world_size,
            init_method=f"tcp://127.0.0.1:{self.port}",
        )
        p = psutil.Process()
        if self.affinity.get("set_affinity", True):
            p.cpu_affinity(self.affinity["cpus"])
        logger.log(f"Optimizer rank {self.rank} CPU affinity: {p.cpu_affinity()}.")
        torch.set_num_threads(self.affinity["torch_threads"])
        logger.log(f"Optimizer rank {self.rank} Torch threads: {torch.get_num_threads()}.")
        logger.log(f"Optimizer rank {self.rank} CUDA index: "
            f"{self.affinity.get('cuda_idx', None)}.")
        set_seed(self.seed)
        self.agent.to_device(cuda_idx=self.affinity.get("cuda_idx", None))
        self.agent.data_parallel()
        self.algo.optim_initialize(rank=self.rank) 
Example #13
Source File: dataloader.py    From keyphrase-generation-rl with MIT License
def _worker_loop(dataset, index_queue, data_queue, collate_fn):
    global _use_shared_memory
    _use_shared_memory = True

    torch.set_num_threads(1)
    while True:
        r = index_queue.get()
        if r is None:
            data_queue.put(None)
            break
        idx, batch_indices = r
        try:
            samples = collate_fn([dataset[i] for i in batch_indices])
        except Exception:
            data_queue.put((idx, ExceptionWrapper(sys.exc_info())))
        else:
            data_queue.put((idx, samples)) 
Example #14
Source File: data_collector.py    From torchsupport with MIT License
def _collector_worker(statistics, buffer, distributor,
                      collector, done, piecewise):
  torch.set_num_threads(1)
  while True:
    if done.value:
      break
    result = collector.sample_trajectory()
    trajectory_statistics = collector.compute_statistics(result)
    trajectory = distributor.commit_trajectory(result)

    if piecewise:
      for item in trajectory:
        buffer.append(item)
    else:
      buffer.append(trajectory)

    statistics.update(trajectory_statistics) 
Example #15
Source File: path_collector.py    From oac-explore with MIT License
def __init__(self,
                 domain_name, env_seed, policy_producer,
                 max_num_epoch_paths_saved=None,
                 render=False,
                 render_kwargs=None,
                 ):

        torch.set_num_threads(1)

        env = env_producer(domain_name, env_seed)

        self._policy_producer = policy_producer

        super().__init__(env,
                         max_num_epoch_paths_saved=max_num_epoch_paths_saved,
                         render=render,
                         render_kwargs=render_kwargs,
                         ) 
Example #16
Source File: trainer.py    From mrqa with Apache License 2.0
def set_random_seed(random_seed):
        if random_seed is not None:
            print("Set random seed as {}".format(random_seed))
            os.environ['PYTHONHASHSEED'] = str(random_seed)
            random.seed(random_seed)
            np.random.seed(random_seed)
            torch.manual_seed(random_seed)
            torch.cuda.manual_seed_all(random_seed)
            torch.set_num_threads(1)
            cudnn.benchmark = False
            cudnn.deterministic = True
            warnings.warn('You have chosen to seed training. '
                          'This will turn on the CUDNN deterministic setting, '
                          'which can slow down your training considerably! '
                          'You may see unexpected behavior when restarting '
                          'from checkpoints.') 
Example #17
Source File: torch_ranker_agent.py    From KBRD with MIT License
def share(self):
        """Share model parameters."""
        shared = super().share()
        shared['model'] = self.model
        if self.opt.get('numthreads', 1) > 1 and isinstance(self.metrics, dict):
            torch.set_num_threads(1)
            # move metrics and model to shared memory
            self.metrics = SharedTable(self.metrics)
            self.model.share_memory()
        shared['metrics'] = self.metrics
        shared['fixed_candidates'] = self.fixed_candidates
        shared['fixed_candidate_vecs'] = self.fixed_candidate_vecs
        shared['fixed_candidate_encs'] = self.fixed_candidate_encs
        shared['vocab_candidates'] = self.vocab_candidates
        shared['vocab_candidate_vecs'] = self.vocab_candidate_vecs
        shared['optimizer'] = self.optimizer
        return shared 
Example #18
Source File: dataloader.py    From Dense-CoAttention-Network with MIT License
def _worker_process_loop(dataset, index_queue, data_queue, collate_fn):
    global _use_shared_memory
    _use_shared_memory = True

    torch.set_num_threads(1)
    while True:
        r = index_queue.get()
        if r is None:
            data_queue.put(None)
            break
        idx, batch_indices = r
        try:
            samples = collate_fn([dataset[i] for i in batch_indices])
        except Exception:
            data_queue.put((idx, ExceptionWrapper(sys.exc_info())))
        else:
            data_queue.put((idx, samples)) 
Example #19
Source File: dataloader.py    From OISR-PyTorch with BSD 2-Clause "Simplified" License
def _ms_loop(dataset, index_queue, data_queue, collate_fn, scale, seed, init_fn, worker_id):
    global _use_shared_memory
    _use_shared_memory = True
    _set_worker_signal_handlers()

    torch.set_num_threads(1)
    torch.manual_seed(seed)
    while True:
        r = index_queue.get()
        if r is None:
            break
        idx, batch_indices = r
        try:
            idx_scale = 0
            if len(scale) > 1 and dataset.train:
                idx_scale = random.randrange(0, len(scale))
                dataset.set_scale(idx_scale)

            samples = collate_fn([dataset[i] for i in batch_indices])
            samples.append(idx_scale)

        except Exception:
            data_queue.put((idx, ExceptionWrapper(sys.exc_info())))
        else:
            data_queue.put((idx, samples)) 
Example #20
Source File: dataloader.py    From seq2seq-keyphrase-pytorch with Apache License 2.0
def _worker_loop(dataset, index_queue, data_queue, collate_fn):
    global _use_shared_memory
    _use_shared_memory = True

    torch.set_num_threads(1)
    while True:
        r = index_queue.get()
        if r is None:
            data_queue.put(None)
            break
        idx, batch_indices = r
        try:
            samples = collate_fn([dataset[i] for i in batch_indices])
        except Exception:
            data_queue.put((idx, ExceptionWrapper(sys.exc_info())))
        else:
            data_queue.put((idx, samples)) 
Example #21
Source File: EndToEnd_Evaluation.py    From gnn-comparison with GNU General Public License v3.0
def main(config_file, dataset_name,
         outer_k, outer_processes, inner_k, inner_processes, result_folder, debug=False):

    # Needed to avoid thread spawning, which conflicts with multi-processing. You may set a number > 1,
    # but take into account the number of processes running on the machine
    torch.set_num_threads(1)

    experiment_class = EndToEndExperiment

    model_configurations = Grid(config_file, dataset_name)
    model_configuration = Config(**model_configurations[0])

    exp_path = os.path.join(result_folder, f'{model_configuration.exp_name}_assessment')

    model_selector = HoldOutSelector(max_processes=inner_processes)
    risk_assesser = KFoldAssessment(outer_k, model_selector, exp_path, model_configurations,
                                    outer_processes=outer_processes)

    risk_assesser.risk_assessment(experiment_class, debug=debug) 
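If you do want more than one thread per process, one reasonable heuristic (an assumption, not taken from this project) is to divide the visible cores among the worker processes:

import os
import torch

num_processes = 4  # hypothetical, e.g. outer_processes * inner_processes here
torch.set_num_threads(max(1, (os.cpu_count() or 1) // num_processes))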
Example #22
Source File: ppo_pybullet.py    From cherry with Apache License 2.0
def main(env='MinitaurTrottingEnv-v0'):
    env = gym.make(env)
    env = envs.AddTimestep(env)
    env = envs.Logger(env, interval=PPO_STEPS)
    env = envs.Normalizer(env, states=True, rewards=True)
    env = envs.Torch(env)
    # env = envs.Recorder(env)
    env = envs.Runner(env)
    env.seed(SEED)

    th.set_num_threads(1)
    policy = ActorCriticNet(env)
    optimizer = optim.Adam(policy.parameters(), lr=LR, eps=1e-5)
    num_updates = TOTAL_STEPS // PPO_STEPS + 1
    lr_schedule = optim.lr_scheduler.LambdaLR(optimizer, lambda epoch: 1 - epoch/num_updates)
    get_action = lambda state: get_action_value(state, policy)

    for epoch in range(num_updates):
        # We use the Runner collector, but could've written our own
        replay = env.run(get_action, steps=PPO_STEPS, render=RENDER)

        # Update policy
        update(replay, optimizer, policy, env, lr_schedule) 
Example #23
Source File: learner.py    From sample-factory with MIT License
def initialize(self, timing):
        with timing.timeit('init'):
            # initialize the Torch modules
            if self.cfg.seed is None:
                log.info('Starting seed is not provided')
            else:
                log.info('Setting fixed seed %d', self.cfg.seed)
                torch.manual_seed(self.cfg.seed)
                np.random.seed(self.cfg.seed)

            # this does not help with a single experiment
            # but seems to do better when we're running more than one experiment in parallel
            torch.set_num_threads(1)

            if self.cfg.device == 'gpu':
                torch.backends.cudnn.benchmark = True

                # we should already see only one CUDA device, because of env vars
                assert torch.cuda.device_count() == 1
                self.device = torch.device('cuda', index=0)
            else:
                self.device = torch.device('cpu')
            self.init_model(timing)

            self.optimizer = torch.optim.Adam(
                self.actor_critic.parameters(),
                self.cfg.learning_rate,
                betas=(self.cfg.adam_beta1, self.cfg.adam_beta2),
                eps=self.cfg.adam_eps,
            )

            self.load_from_checkpoint(self.policy_id)

            self._broadcast_model_weights()  # sync the very first version of the weights

        self.train_thread_initialized.set() 
Example #24
Source File: dataloader.py    From 3D_Appearance_SR with MIT License
def _ms_loop(dataset, index_queue, data_queue, collate_fn, scale, seed, init_fn, worker_id):
    global _use_shared_memory
    _use_shared_memory = True
    _set_worker_signal_handlers()

    torch.set_num_threads(1)
    torch.manual_seed(seed)
    while True:
        r = index_queue.get()
        if r is None:
            break
        idx, batch_indices = r
        try:
            idx_scale = 0
            if len(scale) > 1 and dataset.train:
                idx_scale = random.randrange(0, len(scale))
                dataset.set_scale(idx_scale)

            samples = collate_fn([dataset[i] for i in batch_indices])
            samples.append(idx_scale)
            # This is why idx_scale appears in the samples of the train loader

        except Exception:
            data_queue.put((idx, ExceptionWrapper(sys.exc_info())))
        else:
            data_queue.put((idx, samples)) 
Example #25
Source File: utils.py    From robotics-rl-srl with MIT License
def _run(self, env_kwargs):
        # this is to control the number of CPUs that torch is allowed to use.
        # By default it will use all CPUs, even with GPU acceleration
        th.set_num_threads(1)
        self.model = loadSRLModel(env_kwargs.get("srl_model_path", None), th.cuda.is_available(), self.state_dim,
                                  env_object=None)
        # run until the end of the caller thread
        while True:
            # pop an item, get state, and return to sender.
            env_id, var = self.pipe[0].get()
            self.pipe[1][env_id].put(self.model.getState(var, env_id=env_id)) 
Example #26
Source File: torchloader.py    From mxbox with BSD 3-Clause "New" or "Revised" License
def _worker_loop(dataset, index_queue, data_queue, collate_fn):
    global _use_shared_memory
    _use_shared_memory = True

    # torch.set_num_threads(1)
    while True:
        r = index_queue.get()
        if r is None:
            data_queue.put(None)
            break
        idx, batch_indices = r
        try:
            samples = collate_fn([dataset[i] for i in batch_indices])
        except Exception:
            data_queue.put((idx, ExceptionWrapper(sys.exc_info())))
        else:
            data_queue.put((idx, samples))


# numpy_type_map = {
#     'float64': torch.DoubleTensor,
#     'float32': torch.FloatTensor,
#     'float16': torch.HalfTensor,
#     'int64': torch.LongTensor,
#     'int32': torch.IntTensor,
#     'int16': torch.ShortTensor,
#     'int8': torch.CharTensor,
#     'uint8': torch.ByteTensor,
# }