#!/usr/bin/env python
"""Convert trajectories from `imitation` format to openai/baselines GAIL format."""

import argparse
import os
from pathlib import Path
from typing import Sequence

import numpy as np

from imitation.data import rollout, types


def convert_trajs_to_sb(trajs: Sequence[types.TrajectoryWithRew]) -> dict:
    """Converts Trajectories into the dict format used by Stable Baselines GAIL.

    Args:
        trajs: The trajectories (with rewards) to convert.

    Returns:
        A dict with four entries:
            acs: the `acts` field of the flattened transitions.
            rews: the `rews` field of the flattened transitions.
            obs: the `obs` field of the flattened transitions.
            ep_rets: a NumPy array holding, for each trajectory in `trajs`
                (in order), the sum of that trajectory's rewards.
    """
    trans = rollout.flatten_trajectories_with_rew(trajs)
    # One return per episode, computed from the unflattened trajectories.
    ep_rets = np.array([np.sum(t.rews) for t in trajs])
    return dict(
        acs=trans.acts,
        rews=trans.rews,
        obs=trans.obs,
        ep_rets=ep_rets,
    )


def main() -> None:
    """Command-line entry point.

    Reads two positional arguments, `src_path` and `dst_path`. Loads the
    trajectories stored at `src_path`, converts them with
    `convert_trajs_to_sb`, and writes the result to `dst_path` as a
    compressed NumPy archive. Missing parent directories of `dst_path` are
    created.

    Exits with a usage error (status 2) if `src_path` is not an existing file.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("src_path", type=str)
    parser.add_argument("dst_path", type=str)
    args = parser.parse_args()

    src_path = Path(args.src_path)
    dst_path = Path(args.dst_path)

    # Validate explicitly instead of with `assert`: assertions are removed
    # when Python runs with -O, which would silently skip this check.
    if not src_path.is_file():
        parser.error(f"src_path is not an existing file: {src_path}")

    src_trajs = types.load(str(src_path))
    dst_trajs = convert_trajs_to_sb(src_trajs)

    os.makedirs(dst_path.parent, exist_ok=True)
    # Write through an open file handle rather than passing the path, so the
    # output lands exactly at `dst_path` (NumPy appends ".npz" to bare path
    # names that lack that extension).
    with open(dst_path, "wb") as f:
        np.savez_compressed(f, **dst_trajs)

    print(f"Dumped rollouts to {dst_path}")


if __name__ == "__main__":
    main()