import torch
import numpy as np


def soft_update_from_to(source, target, tau):
    """Blend the parameters of ``source`` into ``target`` in place.

    Every target parameter is overwritten with
    ``target * (1 - tau) + source * tau``. Parameters are paired
    positionally by zipping ``target.parameters()`` with
    ``source.parameters()``.

    Args:
        source: Module whose parameters are read.
        target: Module whose parameters are overwritten.
        tau: Interpolation weight given to the ``source`` values.
    """
    for target_param, param in zip(target.parameters(), source.parameters()):
        blended = target_param.data * (1.0 - tau) + param.data * tau
        target_param.data.copy_(blended)


def copy_model_params_from_to(source, target):
    """Copy every parameter value of ``source`` into ``target`` in place.

    Parameters are paired positionally by zipping ``target.parameters()``
    with ``source.parameters()``.
    """
    for target_param, param in zip(target.parameters(), source.parameters()):
        target_param.data.copy_(param.data)


def _fanin_bound(size):
    """Return ``1 / sqrt(fan_in)`` for a tensor shape.

    For a 2-D shape ``fan_in`` is ``size[0]``; for more than two dimensions
    it is the product of every dimension after the first.

    Raises:
        ValueError: If the shape has fewer than two dimensions.
    """
    ndim = len(size)
    if ndim < 2:
        raise ValueError("Shape must have dimension at least 2.")
    fan_in = size[0] if ndim == 2 else np.prod(size[1:])
    return float(1. / np.sqrt(fan_in))


def fanin_init(tensor):
    """Fill ``tensor`` in place with uniform samples in ``[-b, b)``.

    ``b`` is ``1 / sqrt(fan_in)``; see ``_fanin_bound`` for how ``fan_in``
    is derived from the shape.

    Returns:
        ``tensor.data`` after the in-place fill.

    Raises:
        ValueError: If ``tensor`` has fewer than two dimensions.
    """
    bound = _fanin_bound(tensor.size())
    return tensor.data.uniform_(-bound, bound)


def fanin_init_weights_like(tensor):
    """Return a new float tensor shaped like ``tensor`` with fan-in init.

    The input is only inspected for its shape and is not modified. The
    result is created through the ``FloatTensor`` wrapper, so it lives on
    the module's default device.

    Raises:
        ValueError: If ``tensor`` has fewer than two dimensions.
    """
    size = tensor.size()
    bound = _fanin_bound(size)
    new_tensor = FloatTensor(size)
    new_tensor.uniform_(-bound, bound)
    return new_tensor


"""
GPU wrappers
"""

# Module-wide device state; updated by set_gpu_mode().
_use_gpu = False
device = None  # None until set_gpu_mode() is called
_gpu_id = 0


def set_gpu_mode(mode, gpu_id=0):
    """Select the default device used by the wrappers in this module.

    Args:
        mode: Truthy selects ``cuda:<gpu_id>``; falsy selects ``cpu``.
        gpu_id: CUDA device index used when ``mode`` is truthy.
    """
    global _use_gpu
    global device
    global _gpu_id
    _gpu_id = gpu_id
    _use_gpu = mode
    device_name = "cuda:" + str(gpu_id) if _use_gpu else "cpu"
    device = torch.device(device_name)


def gpu_enabled():
    """Return the ``mode`` value last passed to ``set_gpu_mode``."""
    return _use_gpu


def set_device(gpu_id):
    """Forward ``gpu_id`` to ``torch.cuda.set_device``."""
    torch.cuda.set_device(gpu_id)


def _resolve_device(torch_device):
    """Return ``torch_device``, or the module default when it is ``None``."""
    return device if torch_device is None else torch_device


# noinspection PyPep8Naming
def FloatTensor(*args, torch_device=None, **kwargs):
    """Build a ``torch.FloatTensor`` and place it on the requested device.

    Args:
        *args, **kwargs: Forwarded to ``torch.FloatTensor``.
        torch_device: Target device; defaults to the module-level ``device``.
    """
    target = _resolve_device(torch_device)
    # The legacy torch.FloatTensor constructor only builds CPU tensors and
    # rejects a CUDA ``device=`` argument, so build on the CPU first and
    # move the result afterwards.
    new_tensor = torch.FloatTensor(*args, **kwargs)
    if target is None:
        return new_tensor
    return new_tensor.to(target)


def from_numpy(*args, **kwargs):
    """Wrap ``torch.from_numpy``, cast to float32, move to default device."""
    return torch.from_numpy(*args, **kwargs).float().to(device)


def get_numpy(tensor):
    """Return ``tensor`` as a NumPy array (moved to CPU and detached)."""
    return tensor.to('cpu').detach().numpy()


def zeros(*sizes, torch_device=None, **kwargs):
    """``torch.zeros`` on ``torch_device`` (module default when ``None``)."""
    return torch.zeros(*sizes, **kwargs, device=_resolve_device(torch_device))


def ones(*sizes, torch_device=None, **kwargs):
    """``torch.ones`` on ``torch_device`` (module default when ``None``)."""
    return torch.ones(*sizes, **kwargs, device=_resolve_device(torch_device))


def ones_like(*args, torch_device=None, **kwargs):
    """``torch.ones_like`` on ``torch_device`` (module default when ``None``)."""
    return torch.ones_like(
        *args, **kwargs, device=_resolve_device(torch_device)
    )


def randn(*args, torch_device=None, **kwargs):
    """``torch.randn`` on ``torch_device`` (module default when ``None``)."""
    return torch.randn(*args, **kwargs, device=_resolve_device(torch_device))


def zeros_like(*args, torch_device=None, **kwargs):
    """``torch.zeros_like`` on ``torch_device`` (module default when ``None``)."""
    return torch.zeros_like(
        *args, **kwargs, device=_resolve_device(torch_device)
    )


def tensor(*args, torch_device=None, **kwargs):
    """``torch.tensor`` on ``torch_device`` (module default when ``None``)."""
    return torch.tensor(*args, **kwargs, device=_resolve_device(torch_device))


def normal(*args, **kwargs):
    """Call ``torch.normal`` and move the result to the default device."""
    return torch.normal(*args, **kwargs).to(device)


"""
CPU wrappers
"""


def state_dict_cpu(network):
    """Return ``network.state_dict()`` with every value passed through ``.cpu()``.

    The dictionary returned by ``state_dict()`` is updated in place (only
    existing keys are reassigned) and then returned.
    """
    pol_state_dict = network.state_dict()
    for key, value in pol_state_dict.items():
        # .cpu() yields a CPU-resident tensor; a tensor that is already on
        # the CPU is returned as-is rather than copied.
        pol_state_dict[key] = value.cpu()
    return pol_state_dict