from .cem_base_controller import CEMBaseController
from .visualizer.construct_html import save_gifs, save_img, save_html, fill_template
from visual_mpc.video_prediction.pred_util import get_context, rollout_predictions
import numpy as np
import imp
from collections import OrderedDict
import sys

# Interactive prompt that works on both Python 2 and Python 3.
if sys.version_info[0] == 2:
    input_fn = raw_input
else:
    input_fn = input


class HumanCEMController(CEMBaseController):
    """CEM controller whose trajectory scores are entered by a human operator.

    For every CEM iteration, the sampled action sequences are rolled through the
    video predictor. The predicted frames are written to an HTML page through
    the verbose worker, and the operator is then asked on stdin for one score
    per trajectory.
    """

    def __init__(self, ag_params, policyparams, gpu_id, ngpu):
        """
        :param ag_params: agent parameter dictionary; must contain 'current_dir',
                          the folder holding the predictor's conf.py
        :param policyparams: policy parameter dict
        :param gpu_id: gpu id
        :param ngpu: number of gpus to use
        """
        CEMBaseController.__init__(self, ag_params, policyparams)

        # The predictor configuration is loaded from <current_dir>/conf.py.
        params = imp.load_source('params', ag_params['current_dir'] + '/conf.py')
        netconf = params.configuration
        self.predictor = netconf['setup_predictor'](ag_params, netconf, gpu_id, ngpu, self._logger)

        self._net_bsize = netconf['batch_size']
        self._net_context = netconf['context_frames']
        # Presumably delays planning until the predictor has a full context
        # window of frames -- TODO confirm against CEMBaseController.
        self._hp.start_planning = self._net_context

        self._n_cam = netconf['ncam']
        # Set on every call to act(); read by evaluate_rollouts().
        self._images, self._verbose_worker = None, None
        # Optional pre-recorded action sequence replayed instead of planning.
        self._save_actions = None

    def reset(self):
        """Reset the base controller and forget any replayed action sequence."""
        super(HumanCEMController, self).reset()
        self._save_actions = None

    def _default_hparams(self):
        """Extend the parent's hyperparameters with this controller's defaults."""
        default_dict = {
            "verbose_img_height": 128,  # pixel height of images in the HTML pages
            'state_append': None,
        }

        parent_params = super(HumanCEMController, self)._default_hparams()
        for k in default_dict.keys():
            parent_params.add_hparam(k, default_dict[k])
        return parent_params

    def _build_prediction_rows(self, gen_images, verbose_folder):
        """Save the start images and predicted GIFs, then return the HTML rows.

        The returned OrderedDict holds one start-image row per camera, followed
        by one predicted-GIF row per camera. Each row has one entry per
        trajectory.

        :param gen_images: predictions indexed as [trajectory, time, camera, ...]
        :param verbose_folder: sub-folder (relative to the verbose worker) to save into
        :return: OrderedDict mapping row name -> list of saved paths
        """
        n_traj = gen_images.shape[0]
        content_dict = OrderedDict()

        # Start images: the latest observed frame, repeated in every trajectory column.
        # Assumes self._images is indexed as [time, camera, ...] -- TODO confirm.
        for c in range(self._n_cam):
            name = 'cam_{}_start'.format(c)
            save_path = save_img(self._verbose_worker, verbose_folder, name, self._images[-1, c])
            content_dict[name] = [save_path for _ in range(n_traj)]

        # Predicted frames, one GIF per trajectory and camera.
        for c in range(self._n_cam):
            # Assumes predicted pixels are floats in [0, 1] -- TODO confirm.
            verbose_images = [(gen_images[g_i, :, c] * 255).astype(np.uint8) for g_i in range(n_traj)]
            row_name = 'cam_{}_pred_images'.format(c)
            content_dict[row_name] = save_gifs(self._verbose_worker, verbose_folder, row_name, verbose_images)
        return content_dict

    @staticmethod
    def _prompt_score(traj_idx):
        """Ask the operator for a numeric score for trajectory ``traj_idx``.

        Re-prompts until the reply parses as a float. This way a typo does not
        abort the planning step after the predictions have been computed.
        """
        while True:
            reply = input_fn("Score for traj {}: ".format(traj_idx))
            try:
                return float(reply)
            except ValueError:
                print("Could not parse {!r} as a number, please try again.".format(reply))

    def evaluate_rollouts(self, actions, cem_itr):
        """Score candidate action sequences by showing predictions to a human.

        :param actions: candidate action sequences sampled in this CEM iteration
        :param cem_itr: index of the current CEM iteration (used in folder/page names)
        :return: np.ndarray holding one float score per predicted trajectory.
                 Whether lower or higher is better is decided by
                 CEMBaseController -- confirm there.
        """
        last_frames, last_states = get_context(self._net_context, self._t, self._state, self._images, self._hp)
        gen_images = rollout_predictions(self.predictor, self._net_bsize, actions,
                                         last_frames, last_states, logger=self._logger)[0]
        # rollout_predictions returns a sequence of arrays (presumably one per
        # predictor batch -- confirm). Merge them along the trajectory axis.
        gen_images = np.concatenate(gen_images, 0)
        n_traj = gen_images.shape[0]

        verbose_folder = "planning_{}_itr_{}".format(self._t, cem_itr)
        content_dict = self._build_prediction_rows(gen_images, verbose_folder)

        # Publish the predictions first so the operator can inspect them before scoring.
        html_page = fill_template(cem_itr, self._t, content_dict, img_height=self._hp.verbose_img_height)
        save_html(self._verbose_worker, "{}/preds.html".format(verbose_folder), html_page)

        scores = np.zeros(n_traj)
        for i in range(n_traj):
            scores[i] = self._prompt_score(i)

        # Re-render the page with the operator's scores for later review.
        content_dict['scores'] = scores
        html_page = fill_template(cem_itr, self._t, content_dict, img_height=self._hp.verbose_img_height)
        save_html(self._verbose_worker, "{}/plan.html".format(verbose_folder), html_page)
        return scores

    @staticmethod
    def _load_saved_actions(path):
        """Load a pickled sequence of per-step dicts, each with an 'actions' entry."""
        try:
            import cPickle as pkl  # Python 2
        except ImportError:
            import pickle as pkl  # Python 3
        # NOTE(security): unpickling can execute arbitrary code; only load
        # files from a trusted source.
        with open(path, 'rb') as f:
            return pkl.load(f)

    def act(self, t=None, i_tr=None, images=None, state=None, verbose_worker=None):
        """Return the action for time step ``t``.

        At the first step (t <= 0), the operator may choose to replay a
        previously pickled action sequence. While that sequence lasts, its
        actions are returned directly and no planning is done. Otherwise the
        call is delegated to CEMBaseController.act.

        Args:
            t: the current controller time step
            i_tr: trajectory index, forwarded to the base controller
            images: observed images; the last entry is shown as the start frame
            state: current state, forwarded to the base controller
            verbose_worker: handle passed to the HTML/GIF/image saving helpers
        Returns:
            {'actions': ...} when replaying a saved sequence; otherwise whatever
            CEMBaseController.act returns.
        """
        if t <= 0 and 'y' == input_fn("restore traj?: "):
            self._save_actions = self._load_saved_actions(input_fn('path:'))

        if self._save_actions is not None and t < len(self._save_actions):
            return {'actions': self._save_actions[t]['actions']}

        self._images = images
        self._verbose_worker = verbose_worker
        return super(HumanCEMController, self).act(t, i_tr, state)