# Copyright (C) 2018 Heron Systems, Inc.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
import abc
from collections import deque
from functools import reduce

import cv2
import numpy as np
import torch
from torch.nn import functional as F

cv2.ocl.setUseOpenCL(False)


class Operation(abc.ABC):
    def __init__(self, name_filters, rank_filters):
        # Only one filter type is honored: name filters take precedence,
        # and rank filters are ignored when name filters are given.
        self.name_filters = (
            frozenset(name_filters) if name_filters else frozenset()
        )
        self.rank_filters = (
            frozenset(rank_filters)
            if (rank_filters and not name_filters)
            else frozenset()
        )

    def reset(self):
        pass

    @abc.abstractmethod
    def update_shape(self, old_shape):
        raise NotImplementedError

    @abc.abstractmethod
    def update_dtype(self, old_dtype):
        raise NotImplementedError

    @abc.abstractmethod
    def update_obs(self, obs):
        raise NotImplementedError


class CastToFloat(Operation):
    def __init__(self, name_filters=None, rank_filters=None):
        super(CastToFloat, self).__init__(name_filters, rank_filters)

    def update_shape(self, old_shape):
        return old_shape

    def update_dtype(self, old_dtype):
        return {k: torch.float32 for k in old_dtype.keys()}

    def update_obs(self, obs):
        return {k: ob.float() for k, ob in obs.items()}


class CastToHalf(Operation):
    def __init__(self, name_filters=None, rank_filters=None):
        super(CastToHalf, self).__init__(name_filters, rank_filters)

    def update_shape(self, old_shape):
        return old_shape

    def update_dtype(self, old_dtype):
        return {k: torch.float16 for k in old_dtype.keys()}

    def update_obs(self, obs):
        return {k: ob.half() for k, ob in obs.items()}


class GrayScaleAndMoveChannel(Operation):
    def __init__(self, name_filters=None, rank_filters=frozenset([3])):
        super(GrayScaleAndMoveChannel, self).__init__(
            name_filters, rank_filters
        )

    def update_shape(self, old_shape):
        # (H, W, C) -> (1, H, W)
        return {k: (1,) + v[:-1] for k, v in old_shape.items()}

    def update_dtype(self, old_dtype):
        return old_dtype

    def update_obs(self, obs):
        updated = {}
        for k, v in obs.items():
            if v.dim() == 3:
                # single HWC frame: OpenCV grayscale, then channel-first
                result = torch.from_numpy(
                    cv2.cvtColor(v.numpy(), cv2.COLOR_RGB2GRAY)
                ).unsqueeze(0)
            elif v.dim() == 4:
                # batched BHWC frames: average over the trailing channel dim
                result = v.mean(dim=3).unsqueeze(1)
            else:
                raise ValueError(
                    "can't grayscale a rank {} tensor".format(v.dim())
                )
            updated[k] = result
        return updated


class ResizeTo84x84(Operation):
    def __init__(self, name_filters=None, rank_filters=frozenset([3])):
        super().__init__(name_filters, rank_filters)

    def update_shape(self, old_shape):
        return {k: (1, 84, 84) for k in old_shape.keys()}

    def update_dtype(self, old_dtype):
        return old_dtype

    def update_obs(self, obs):
        updated = {}
        for k, v in obs.items():
            if v.dim() == 3:
                result = cv2.resize(
                    v.squeeze(0).numpy(),
                    (84, 84),
                    interpolation=cv2.INTER_AREA,
                )
                result = torch.from_numpy(result).unsqueeze(0)
            elif v.dim() == 4:
                result = F.interpolate(v, (84, 84), mode="area")
            else:
                raise ValueError(
                    "can't resize a rank {} tensor to 84x84".format(v.dim())
                )
            updated[k] = result
        return updated


class Divide255(Operation):
    def __init__(self, name_filters=None, rank_filters=frozenset([3])):
        super().__init__(name_filters,
rank_filters)

    def update_shape(self, old_shape):
        return old_shape

    def update_dtype(self, old_dtype):
        return {k: torch.float32 for k in old_dtype.keys()}

    def update_obs(self, obs):
        # Note: unlike the other ops, this one only rescales keys that are
        # explicitly name-filtered; unfiltered keys pass through untouched.
        updated = {}
        for k, v in obs.items():
            if k in self.name_filters:
                v = v.float()
                v *= 1.0 / 255.0
            updated[k] = v
        return updated


class FrameStackCPU(Operation):
    def __init__(
        self, nb_frame, name_filters=None, rank_filters=frozenset([3])
    ):
        super().__init__(name_filters, rank_filters)
        self.nb_frame = nb_frame
        self.frames = None
        self.obs_space = None

    def update_shape(self, old_shape):
        # lazily initialize old observation space
        if self.obs_space is None:
            self.obs_space = old_shape
            self.reset()
        updated = {}
        for k, v in old_shape.items():
            result = (v[0] * self.nb_frame,) + v[1:]
            updated[k] = result
        return updated

    def update_dtype(self, old_dtype):
        return old_dtype

    def update_obs(self, obs):
        updated = {}
        for k, v in obs.items():
            self.frames[k].append(v)
            updated[k] = self._update_obs(k, v)
        return updated

    def _update_obs(self, k, obs):
        if obs.dim() == 3:  # cpu
            # concatenate this key's buffered frames along the channel dim
            if len(self.frames[k]) == self.nb_frame:
                return torch.cat(list(self.frames[k]))
        else:
            raise NotImplementedError(
                f"Dimensionality not supported: {obs.dim()}"
            )

    def reset(self):
        self.frames = {
            k: deque(
                [torch.zeros(dims)] * self.nb_frame, maxlen=self.nb_frame
            )
            for k, dims in self.obs_space.items()
        }


class FrameStackGPU(FrameStackCPU):
    def reset(self):
        # GPU variant buffers batched observations; the zero-initialized
        # frames assume a leading batch dim of 1
        self.frames = {
            k: deque(
                [torch.zeros((1,) + v)] * self.nb_frame,
                maxlen=self.nb_frame,
            )
            for k, v in self.obs_space.items()
        }

    def _update_obs(self, k, obs):
        if obs.dim() == 4:
            if len(self.frames[k]) == self.nb_frame:
                return torch.cat(list(self.frames[k]), dim=1)
        else:
            raise NotImplementedError(
                f"Dimensionality not supported: {obs.dim()}"
            )


class FlattenSpace(Operation):
    def __init__(self, name_filters=None, rank_filters=None):
        super().__init__(name_filters, rank_filters)

    def update_shape(self, old_shape):
        updated = {}
        for k, olds in old_shape.items():
            updated[k] = (reduce(lambda prev, cur: prev * cur, olds),)
        return updated

    def update_dtype(self, old_dtype):
        return old_dtype

    def update_obs(self, obs):
        updated = {}
        for k, v in obs.items():
            updated[k] = v.view(-1)
        return updated


class FromNumpy(Operation):
    def __init__(self, name_filters=None, rank_filters=None):
        super().__init__(name_filters, rank_filters)

    def update_shape(self, old_shape):
        return old_shape

    def update_dtype(self, old_dtype):
        updated = {}
        for k, v in old_dtype.items():
            if v == np.float32:
                dt = torch.float32
            elif v == np.float64:
                dt = torch.float64
            elif v == np.float16:
                dt = torch.float16
            elif v == np.uint8:
                dt = torch.uint8
            elif v == np.int8:
                dt = torch.int8
            elif v == np.int16:
                dt = torch.int16
            elif v == np.int32:
                dt = torch.int32
            elif v == np.int64:
                dt = torch.int64
            else:
                raise ValueError("Unsupported dtype {}".format(v))
            updated[k] = dt
        return updated

    def update_obs(self, obs):
        updated = {}
        for k, v in obs.items():
            updated[k] = torch.from_numpy(v)
        return updated


if __name__ == "__main__":
    # fstack = FrameStackCPU(4)
    # fstack.update_shape({"box": (3, 5, 5)})
    # fstack.update_obs({"box": torch.ones(3, 5, 5)})
    # print(fstack.frames)

    fstack = FrameStackGPU(4)
    fstack.update_shape({"box": (3, 5, 5)})
    # batch size 1 to match the zero-initialized frame buffer
    fstack.update_obs({"box": torch.ones(1, 3, 5, 5)})
    print(fstack.frames)
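
    # A minimal end-to-end sketch (not part of the original demo): chain
    # several ops the way a preprocessing pipeline might apply them. The
    # "screen" key, the 84x84 RGB uint8 frame, and the op ordering below are
    # illustrative assumptions, not a documented API.
    ops = [
        FromNumpy(),
        GrayScaleAndMoveChannel(),
        ResizeTo84x84(),
        CastToFloat(),
    ]
    shapes = {"screen": (84, 84, 3)}
    dtypes = {"screen": np.uint8}
    obs = {"screen": np.random.randint(0, 256, (84, 84, 3), dtype=np.uint8)}
    for op in ops:
        shapes = op.update_shape(shapes)
        dtypes = op.update_dtype(dtypes)
        obs = op.update_obs(obs)
    # expect shape (1, 84, 84) and dtype torch.float32 for "screen"
    print(shapes, dtypes, {k: v.shape for k, v in obs.items()})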