from __future__ import print_function, division

import argparse
import json
import os
import subprocess
from collections import OrderedDict

import numpy as np
import torch as th

try:
    from termcolor import colored
except ImportError:
    # Colouring is purely cosmetic: fall back to plain text instead of making
    # the whole module unimportable when termcolor is not installed.
    def colored(text, *_args, **_kwargs):
        return text


def parseLossArguments(choices, help):
    """
    Creates a custom type for loss parsing, it overrides the type, choice and help of add_argument,
    in order to properly extract the loss type, and still be able to print the choices available.

    Example:
        in:  'autoencoder:1:10'   (loss:weight:state_dim)
        out: autoencoder, 1, 10   (loss_name, weight, state_dim)

    Accepted formats are '<str>', '<str>:<float>' and '<str>:<float>:<int>'.
    A bare '<str>' is returned unchanged (as a string); the two other forms are
    returned as a (loss, weight, state_dim) tuple, state_dim defaulting to 0.

    :param choices: ([str]) the list of valid losses
    :param help: (str) help string
    :return: (dict) the arguments for parse arg ('type' and 'help' keys)
    """
    format_error = "Error: must be of format '<str>:<float>:<int>', '<str>:<float/int>' or '<str>'"

    def _arg_type(arg):
        fields = arg.split(':')
        loss = fields[0]
        if loss not in choices:
            raise argparse.ArgumentTypeError("invalid choice: {} (choose from {})".format(loss, choices))
        if len(fields) == 1:
            # No separator: keep the historical behaviour of returning the bare name
            return arg
        if len(fields) > 3:
            # Extra fields used to be dropped silently; reject them explicitly
            raise argparse.ArgumentTypeError(format_error)
        try:
            weight = float(fields[1])
            state_dim = int(fields[2]) if len(fields) == 3 else 0
        except ValueError:
            raise argparse.ArgumentTypeError(format_error)
        return loss, weight, state_dim

    def _choices_print():
        # join() also copes with an empty list of choices
        return "{" + ", ".join(choices) + "}"

    return {'type': _arg_type, 'help': _choices_print() + " " + help}


def buildConfig(args):
    """
    Building the config file for the trainer

    Note: when "supervised" is in args.losses, args.inverse_model_type is set
    to None on the passed object (in place) before the config is built.

    :param args: (parsed args object)
    :return: (OrderedDict)
    """
    # Fixes to use this function in srl_baselines/
    # (these attributes may be missing from the parsed args, hence the defaults)
    split_dimensions = getattr(args, "split_dimensions", -1)
    beta = getattr(args, "beta", -1)
    l1_reg = getattr(args, "l1_reg", 0)
    l2_reg = getattr(args, "l2_reg", 0)

    if "supervised" in args.losses:
        args.inverse_model_type = None

    exp_config = OrderedDict([
        ("batch-size", args.batch_size),
        ("beta", beta),
        ("data-folder", args.data_folder),
        ("epochs", args.epochs),
        ("learning-rate", args.learning_rate),
        ("training-set-size", args.training_set_size),
        ("log-folder", ""),
        ("model-type", args.model_type),
        ("seed", args.seed),
        ("state-dim", args.state_dim),
        ("knn-samples", 200),
        ("knn-seed", 1),
        ("l1-reg", l1_reg),
        ("l2-reg", l2_reg),
        ("losses", args.losses),
        ("n-neighbors", 5),
        ("n-to-plot", 5),
        ("split-dimensions", split_dimensions),
        ("inverse-model-type", args.inverse_model_type),
    ])
    return exp_config


def loadData(data_folder):
    """
    Load the preprocessed data, the ground truth and the dataset config
    stored under `data/<data_folder>/` (path relative to the working directory).

    :param data_folder: (str) path to the data_folder to be loaded
    :return: (Numpy dictionary-like objects and np.ndarrays)
        training_data, ground_truth, true_states, target_pos_
        where target_pos_ holds, for every frame, the target position of its episode
    """
    training_data = np.load('data/{}/preprocessed_data.npz'.format(data_folder))
    episode_starts = training_data['episode_starts']
    ground_truth = np.load('data/{}/ground_truth.npz'.format(data_folder))

    # Backward compatibility with previous names
    gt_keys = ground_truth.keys()
    states_key = 'ground_truth_states' if 'ground_truth_states' in gt_keys else 'arm_states'
    targets_key = 'target_positions' if 'target_positions' in gt_keys else 'button_positions'
    true_states = ground_truth[states_key]
    target_positions = ground_truth[targets_key]

    with open('data/{}/dataset_config.json'.format(data_folder), 'r') as f:
        relative_pos = json.load(f).get('relative_pos', False)

    target_pos_ = []
    # Index of the current episode: incremented each time an episode starts
    target_idx = -1
    for i, is_start in enumerate(episode_starts):
        if is_start == 1:
            target_idx += 1
        if relative_pos:
            # True state is the relative position to the target
            true_states[i] -= target_positions[target_idx]
        target_pos_.append(target_positions[target_idx])
    target_pos_ = np.array(target_pos_)

    return training_data, ground_truth, true_states, target_pos_


def getInputBuiltin():
    """
    Python 2/3 compatibility
    Returns the python 'input' builtin (`raw_input` when it exists, `input` otherwise)

    :return: (input)
    """
    try:
        return raw_input  # noqa: F821 -- only defined on Python 2
    except NameError:
        return input


def importMaplotlib():
    """
    Fix for plotting when x11 is not available:
    runs `xset -q` and, when it does not exit with status 0 (or cannot be run at all),
    selects the non-interactive 'Agg' matplotlib backend.
    """
    try:
        p = subprocess.Popen(["xset", "-q"], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        p.communicate()
        x11_available = p.returncode == 0
    except OSError:
        # `xset` itself is missing (typical on headless machines)
        x11_available = False

    if not x11_available:
        import matplotlib
        matplotlib.use('Agg')


def detachToNumpy(tensor):
    """
    Gets a th.Tensor and returns a np.ndarray
    (the tensor is moved to the cpu and detached from the graph first)

    :param tensor: (th.Tensor)
    :return: (np.ndarray)
    """
    return tensor.to(th.device('cpu')).detach().numpy()


def parseDataFolder(path):
    """
    Remove `data/` from dataset folder path if needed

    :param path: (str)
    :return: (str) name of the dataset folder
    """
    prefix = 'data/'
    if path.startswith(prefix):
        return path[len(prefix):]
    return path


def createFolder(path_to_folder, exist_msg):
    """
    Try to create a folder (and parents if needed)
    print a message in case the folder already exist

    :param path_to_folder: (str)
    :param exist_msg: (str) message printed when the folder already exists
    :raises OSError: when the creation failed for another reason than
        the folder already existing
    """
    try:
        os.makedirs(path_to_folder)
    except OSError:
        # Only report "already exists" when that is really what happened
        if not os.path.isdir(path_to_folder):
            raise
        print(exist_msg)


def printGreen(string):
    """
    Print a string in green in the terminal

    :param string: (str)
    """
    print(colored(string, 'green'))


def printYellow(string):
    """
    Print a string in yellow in the terminal

    :param string: (str)
    """
    print(colored(string, 'yellow'))


def printRed(string):
    """
    Print a string in red in the terminal

    :param string: (str)
    """
    print(colored(string, 'red'))


def printBlue(string):
    """
    Print a string in blue in the terminal

    :param string: (str)
    """
    print(colored(string, 'blue'))