# Copyright 2018/2019 The RLgraph authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import time

import numpy as np

from rlgraph import get_backend
from rlgraph.components.component import Component
from rlgraph.graphs import GraphExecutor
from rlgraph.utils import util
from rlgraph.utils.define_by_run_ops import define_by_run_flatten, define_by_run_unflatten
from rlgraph.utils.util import force_torch_tensors

if get_backend() == "pytorch":
    import torch


class PyTorchExecutor(GraphExecutor):
    """
    Manages execution for component graphs using define-by-run semantics.
    """
    def __init__(self, **kwargs):
        super(PyTorchExecutor, self).__init__(**kwargs)

        self.global_training_timestep = 0

        self.cuda_enabled = torch.cuda.is_available()

        # In PyTorch, tensors are by default created on the CPU unless assigned to a visible CUDA device,
        # e.g. via x = tensor([0, 0], device="cuda:0") for the first GPU.
        self.available_devices = os.environ.get("CUDA_VISIBLE_DEVICES")
        # TODO handle cuda tensors

        self.default_torch_tensor_type = self.execution_spec.get("dtype", "torch.FloatTensor")
        if self.default_torch_tensor_type is not None:
            torch.set_default_tensor_type(self.default_torch_tensor_type)

        self.torch_num_threads = self.execution_spec.get("torch_num_threads", 1)
        self.omp_num_threads = self.execution_spec.get("OMP_NUM_THREADS", 1)

        # Squeeze result dims, often necessary in tests.
        self.remove_batch_dims = True

    def build(self, root_components, input_spaces, **kwargs):
        start = time.perf_counter()
        self.init_execution()

        meta_build_times = []
        build_times = []
        for component in root_components:
            # Time each component's meta-graph build separately so the overall timer above is not clobbered.
            meta_start = time.perf_counter()
            meta_graph = self.meta_graph_builder.build(component, input_spaces)
            meta_build_times.append(time.perf_counter() - meta_start)

            build_time = self.graph_builder.build_define_by_run_graph(
                meta_graph=meta_graph, input_spaces=input_spaces, available_devices=self.available_devices
            )
            build_times.append(build_time)

        return dict(
            total_build_time=time.perf_counter() - start,
            meta_graph_build_times=meta_build_times,
            build_times=build_times,
        )

    def execute(self, *api_method_calls):
        # Have to call each method separately. Each call is either a plain API-method name (string)
        # or a list/tuple of (api_method_name, params[, ops_to_return]).
        ret = []
        for api_method in api_method_calls:
            if api_method is None:
                continue
            elif isinstance(api_method, (list, tuple)):
                # Which ops are supposed to be returned?
                op_or_indices_to_return = api_method[2] if len(api_method) > 2 else None
                params = util.force_list(api_method[1])
                api_method = api_method[0]

                tensor_params = force_torch_tensors(params=params)

                api_ret = self.graph_builder.execute_define_by_run_op(api_method, tensor_params)
                is_dict_result = isinstance(api_ret, dict)
                if not isinstance(api_ret, (list, tuple)):
                    api_ret = [api_ret]

                to_return = []
                if op_or_indices_to_return is not None:
                    # Op indices can be integers into a result list or strings into a result dict.
                    if is_dict_result:
                        if isinstance(op_or_indices_to_return, str):
                            op_or_indices_to_return = [op_or_indices_to_return]
                        result_dict = {}
                        for key in op_or_indices_to_return:
                            result_dict[key] = api_ret[0][key]
                        to_return.append(result_dict)
                    else:
                        # Build return ops in correct order.
                        # TODO clarify op indices order vs tensorflow.
                        for i in sorted(op_or_indices_to_return):
                            op_result = api_ret[i]
                            if isinstance(op_result, torch.Tensor) and op_result.requires_grad is True:
                                op_result = op_result.detach()
                            to_return.append(op_result)
                else:
                    # Just return everything in the order it was returned by the API method.
                    if api_ret is not None:
                        for op_result in api_ret:
                            if isinstance(op_result, torch.Tensor) and op_result.requires_grad is True:
                                op_result = op_result.detach()
                            to_return.append(op_result)

                # Clean and return.
                self.clean_results(ret, to_return)
            else:
                # API method is a string without args.
                to_return = []
                api_ret = self.graph_builder.execute_define_by_run_op(api_method)
                if api_ret is None:
                    continue
                if not isinstance(api_ret, (list, tuple)):
                    api_ret = [api_ret]
                for op_result in api_ret:
                    if isinstance(op_result, torch.Tensor) and op_result.requires_grad is True:
                        op_result = op_result.detach()
                    to_return.append(op_result)

                # Clean and return.
                self.clean_results(ret, to_return)

        # Unwrap if len 1.
        ret = ret[0] if len(ret) == 1 else ret
        return ret

    def clean_results(self, ret, to_return):
        for result in to_return:
            if isinstance(result, dict):
                cleaned_dict = {k: v for k, v in result.items() if v is not None}
                cleaned_dict = self.clean_dict(cleaned_dict)
                ret.append(cleaned_dict)
            elif self.remove_batch_dims and isinstance(result, np.ndarray):
                ret.append(np.array(np.squeeze(result)))
            elif hasattr(result, "numpy"):
                ret.append(np.array(result.numpy()))
            else:
                ret.append(result)

    @staticmethod
    def clean_dict(tensor_dict):
        """
        Detaches tensor values in a nested dict and converts them to numpy arrays.

        Args:
            tensor_dict (dict): Dict containing torch tensors.

        Returns:
            dict: Dict containing numpy arrays.
        """
        # Un-nest.
        param = define_by_run_flatten(tensor_dict)
        ret = {}

        # Detach tensor values.
        for key, value in param.items():
            if isinstance(value, torch.Tensor):
                ret[key] = value.detach().numpy()

        # Pack again.
        return define_by_run_unflatten(ret)

    def read_variable_values(self, variables):
        # For test compatibility.
        if isinstance(variables, dict):
            ret = {}
            for name, var in variables.items():
                ret[name] = Component.read_variable(var)
            return ret
        elif isinstance(variables, list):
            return [Component.read_variable(var) for var in variables]
        else:
            # Attempt to read as single var.
            return Component.read_variable(variables)

    def init_execution(self):
        # TODO: Import guards here are annoying, but otherwise this breaks if torch is not installed.
        if get_backend() == "pytorch":
            torch.set_num_threads(self.torch_num_threads)
            os.environ["OMP_NUM_THREADS"] = str(self.omp_num_threads)

    def finish_graph_setup(self):
        # Nothing to do here for PyTorch.
        pass

    def get_available_devices(self):
        return self.available_devices

    def load_model(self, path=None):
        pass

    def store_model(self, path=None, add_timestep=True):
        pass

    def get_device_assignments(self, device_names=None):
        pass

    def terminate(self):
        pass
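
# ------------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the API defined in this file).
# It assumes an already constructed and built PyTorchExecutor instance named
# `executor`; "reset", "get_action" and "update" are hypothetical agent API
# method names, used purely to show the call formats accepted by execute():
#
#     executor.execute("reset")                        # method name only, no args
#     executor.execute(("get_action", [state]))        # method name plus parameter list
#     executor.execute(("update", [batch], [0, 1]))    # additionally select returned ops 0 and 1
#
# When the third tuple element is given, integer entries index into a list/tuple
# result and string entries select keys from a dict result, as handled above.
# ------------------------------------------------------------------------------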