Python pickle.dump() Examples

The following are 30 code examples showing how to use pickle.dump(). These examples are extracted from open source projects. You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example.

You may check out the related API usage on the sidebar.

You may also want to check out all available functions/classes of the module pickle, or try the search function.

Example 1
Project: svviz   Author: svviz   File: app.py    License: MIT License 6 votes vote down vote up
def saveState(dataHub):
    """Pickle ``dataHub`` into a gzip-compressed file.

    The destination path is read from ``dataHub.args.save_state``. A
    warning-level banner is logged after the dump has been written.

    :param dataHub: object to serialize; must expose ``args.save_state``
    """
    import gzip
    import pickle

    # The context manager closes the gzip stream (writing its trailer) even
    # when pickling fails part-way through; the bare gzip.open() call leaked
    # the handle.
    with gzip.open(dataHub.args.save_state, "wb") as stateFile:
        pickle.dump(dataHub, stateFile)
    # logging.warn is a deprecated alias of logging.warning.
    logging.warning("^"*20 + " saving state to pickle and exiting " + "^"*20)
Example 2
Project: wechat-alfred-workflow   Author: TKkk-iOSer   File: workflow.py    License: MIT License 6 votes vote down vote up
def register(self, name, serializer):
        """Store ``serializer`` in the registry under the key ``name``.

        The serializer is only accepted if it exposes both a ``load`` and
        a ``dump`` attribute; otherwise :class:`AttributeError` is raised
        and the registry is left untouched.

        .. note::

            ``name`` will be used as the file extension of the saved files.

        :param name: key to register ``serializer`` under
        :type name: ``unicode`` or ``str``
        :param serializer: object with ``load()`` and ``dump()``
            methods

        """
        # Looking an attribute up without a default raises AttributeError
        # when it is missing, which is exactly the validation wanted here.
        for required in ('load', 'dump'):
            getattr(serializer, required)

        self._serializers[name] = serializer
Example 3
Project: wechat-alfred-workflow   Author: TKkk-iOSer   File: workflow.py    License: MIT License 6 votes vote down vote up
def cache_data(self, name, data):
        """Store ``data`` in the cache under ``name``, or drop the entry.

        Passing ``None`` as ``data`` removes the cache file for ``name``
        if one exists; nothing is written in that case.

        :param name: name of datastore
        :param data: object to store; it is handed to the ``dump()`` of
            the serializer selected by ``self.cache_serializer``

        """
        serializer = manager.serializer(self.cache_serializer)
        filename = '%s.%s' % (name, self.cache_serializer)
        cache_path = self.cachefile(filename)

        if data is not None:
            with atomic_writer(cache_path, 'wb') as file_obj:
                serializer.dump(data, file_obj)
            self.logger.debug('cached data: %s', cache_path)
            return

        # ``None`` means "forget this entry": remove the file if present.
        if os.path.exists(cache_path):
            os.unlink(cache_path)
            self.logger.debug('deleted cache file: %s', cache_path)
Example 4
def gt_roidb(self):
    """
    Return the database of ground-truth regions of interest.

    The result is cached as ``<name>_gt_roidb.pkl`` under ``self.cache_path``.
    When that file exists it is unpickled and returned; otherwise the labels
    are loaded for every entry of ``self.image_index`` and the cache is
    written before returning.
    """
    cache_file = os.path.join(self.cache_path, self.name + '_gt_roidb.pkl')
    if os.path.exists(cache_file):
      with open(cache_file, 'rb') as fid:
        try:
          roidb = pickle.load(fid)
        except Exception:
          # The failed attempt has already consumed part of the stream, so
          # rewind before retrying.  encoding='bytes' reads 8-bit string
          # instances from older pickles as bytes objects.
          fid.seek(0)
          roidb = pickle.load(fid, encoding='bytes')
      print('{} gt roidb loaded from {}'.format(self.name, cache_file))
      return roidb

    gt_roidb = [self._load_pascal_labels(index)
                for index in self.image_index]
    with open(cache_file, 'wb') as fid:
      pickle.dump(gt_roidb, fid, pickle.HIGHEST_PROTOCOL)
    print('wrote gt roidb to {}'.format(cache_file))

    return gt_roidb
Example 5
def _write_coco_results_file(self, all_boxes, res_file):
    # [{"image_id": 42,
    #   "category_id": 18,
    #   "bbox": [258.15,41.29,348.26,243.78],
    #   "score": 0.236}, ...]
    results = []
    for cls_ind, cls in enumerate(self.classes):
      if cls == '__background__':
        continue
      print('Collecting {} results ({:d}/{:d})'.format(cls, cls_ind,
                                                       self.num_classes - 1))
      coco_cat_id = self._class_to_coco_cat_id[cls]
      results.extend(self._coco_results_one_category(all_boxes[cls_ind],
                                                     coco_cat_id))
    print('Writing results json to {}'.format(res_file))
    with open(res_file, 'w') as fid:
      json.dump(results, fid) 
Example 6
Project: comet-commonsense   Author: atcbosselut   File: data.py    License: Apache License 2.0 6 votes vote down vote up
def save_eval_file(opt, stats, eval_type="losses", split="dev", ext="pickle"):
    """Write evaluation ``stats`` to ``<dir>/<split>.<ext>``.

    The directory comes from ``utils.make_name``; its prefix is
    ``garbage/<eval_type>/`` when ``cfg.test_save`` is truthy and
    ``results/<eval_type>/`` otherwise.

    :param opt: options object forwarded to ``utils.make_name``
    :param stats: data to save; written with ``pickle.dump`` for
        ``"pickle"``, ``f.write`` for ``"txt"`` and ``json.dump`` for
        ``"json"``
    :param eval_type: sub-directory name used inside the prefix
    :param split: base name of the output file
    :param ext: one of ``"pickle"``, ``"txt"`` or ``"json"``
    :raises ValueError: if ``ext`` is not one of the supported values
    """
    # Only the top-level directory differs between the two modes.
    root = "garbage" if cfg.test_save else "results"
    name = "{}/{}.{}".format(utils.make_name(
        opt, prefix="{}/{}/".format(root, eval_type),
        is_dir=True, eval_=True), split, ext)
    print("Saving {} {} to {}".format(split, eval_type, name))

    if ext == "pickle":
        with open(name, "wb") as f:
            pickle.dump(stats, f)
    elif ext == "txt":
        with open(name, "w") as f:
            f.write(stats)
    elif ext == "json":
        with open(name, "w") as f:
            json.dump(stats, f)
    else:
        # A bare ``raise`` with no active exception only produced an opaque
        # "No active exception to reraise" error; name the actual problem.
        raise ValueError(
            "Unsupported extension {!r}; expected 'pickle', 'txt' or "
            "'json'".format(ext))
Example 7
Project: gated-graph-transformer-network   Author: hexahedria   File: update_cache_compatibility.py    License: MIT License 6 votes vote down vote up
def main(cache_dir):
    """Rewrite every cached model in ``cache_dir`` under a refreshed hash name.

    Each regular file is unpickled as a ``(model, kwargs)`` pair, the kwargs
    are passed through ``util.get_compatible_kwargs`` and the pair is saved
    again as ``model_<hash>.p``, where the hash is ``util.object_hash`` of
    the updated kwargs.  The source file is then removed.

    :param cache_dir: directory holding the pickled model caches
    """
    # os.listdir is snapshotted so files written below are not re-processed.
    for entry in list(os.listdir(cache_dir)):
        full_filename = os.path.join(cache_dir, entry)
        if not os.path.isfile(full_filename):
            continue

        print("Processing {}".format(full_filename))
        # NOTE(review): unpickling runs arbitrary code from the file; only
        # point this script at caches you trust.
        with open(full_filename, 'rb') as source:
            m, stored_kwargs = pickle.load(source)
        updated_kwargs = util.get_compatible_kwargs(model.Model, stored_kwargs)

        model_hash = util.object_hash(updated_kwargs)
        print("New hash -> " + model_hash)
        model_filename = os.path.join(cache_dir, "model_{}.p".format(model_hash))
        sys.setrecursionlimit(100000)
        with open(model_filename, 'wb') as target:
            pickle.dump((m, updated_kwargs), target,
                        protocol=pickle.HIGHEST_PROTOCOL)

        # If the refreshed name is the file we just read, removing the
        # "old" file would delete the copy that was written a moment ago.
        if not os.path.samefile(model_filename, full_filename):
            os.remove(full_filename)
Example 8
Project: Neural-LP   Author: fanyangxyz   File: experiment.py    License: MIT License 6 votes vote down vote up
def train(self):
        """Run the train/valid/test loop and save the collected statistics.

        Epochs run until ``self.option.max_epoch`` is reached or
        ``self.early_stop()`` reports true; a checkpoint is saved after each
        epoch.  Afterwards the epoch with the highest mean of ``x[1]`` over
        ``self.test_stats`` is printed and logged, and the train/valid/test
        statistics are pickled to ``results.pckl`` in
        ``self.option.this_expsdir``.
        """
        while (self.epoch < self.option.max_epoch and not self.early_stopped):
            self.one_epoch_train()
            self.one_epoch_valid()
            self.one_epoch_test()
            self.epoch += 1
            model_path = self.saver.save(self.sess,
                                         self.option.model_path,
                                         global_step=self.epoch)
            print("Model saved at %s" % model_path)

            if self.early_stop():
                self.early_stopped = True
                print("Early stopped at epoch %d" % (self.epoch))

        all_test_in_top = [np.mean(x[1]) for x in self.test_stats]
        best_test_epoch = np.argmax(all_test_in_top)
        best_test = all_test_in_top[best_test_epoch]

        # Epochs are reported 1-based, hence the +1 on the argmax index.
        msg = "Best test in top: %0.4f at epoch %d." % (best_test, best_test_epoch + 1)
        print(msg)
        self.log_file.write(msg + "\n")

        # Pickle output is bytes: the file must be opened in binary mode, and
        # the context manager makes sure it is flushed and closed.
        results_path = os.path.join(self.option.this_expsdir, "results.pckl")
        with open(results_path, "wb") as results_file:
            pickle.dump([self.train_stats, self.valid_stats, self.test_stats],
                        results_file)
Example 9
Project: graph-neural-networks   Author: alelab-upenn   File: miscTools.py    License: GNU General Public License v3.0 6 votes vote down vote up
def saveSeed(randomStates, saveDir):
    """
    Pickle the given random generator states to 'randomSeedUsed.pkl'.

    Inputs:
        randomStates (list): one dictionary per module whose generator state
            should be stored (torch, numpy, etc.). Each dictionary carries
            'module' (module name as a string, e.g. 'numpy' or 'torch'),
            'state' (the saved generator state) and, where applicable,
            'seed' (the specific seed of the generator).
        saveDir (path): directory that receives 'randomSeedUsed.pkl'

    The list is stored wrapped in a dictionary under the key 'randomStates'.
    """
    payload = {'randomStates': randomStates}
    targetPath = os.path.join(saveDir, 'randomSeedUsed.pkl')
    with open(targetPath, 'wb') as handle:
        pickle.dump(payload, handle)
Example 10
Project: Paradrop   Author: ParadropLabs   File: pd_storage.py    License: Apache License 2.0 6 votes vote down vote up
def saveToDisk(self):
        """Save the exported attribute data to disk as pickle and YAML.

        Nothing is written when ``self.attrSaveable()`` is false.  The pickle
        goes to ``self.filename`` and a YAML copy to ``self.filename`` +
        ".yaml".  Failures of either write are reported through ``out.err``
        and are not re-raised.
        """
        out.info('Saving to disk (%s)\n' % (self.filename))

        # Make sure they want to save
        if(not self.attrSaveable()):
            return

        # Get whatever the data is
        pyld = self.exportAttr(self.getAttr())

        # Write the file to disk, truncate if it exists
        try:
            with open(self.filename, 'wb') as output:
                pickle.dump(pyld, output)
                # fsync only pushes what the OS already holds; flush Python's
                # own buffer first, otherwise the data may not be synced.
                output.flush()
                os.fsync(output.fileno())
        except Exception as e:
            out.err('Error writing to disk %s\n' % (str(e)))

        try:
            with open(self.filename + ".yaml", "w") as output:
                yaml.dump(pyld, output)
        except Exception as error:
            out.err("Error writing yaml file: {}".format(error))
Example 11
Project: cascade-rcnn_Pytorch   Author: guoruoqian   File: pascal_voc_rbg.py    License: MIT License 6 votes vote down vote up
def gt_roidb(self):
    """
    Return the database of ground-truth regions of interest.

    The result is cached as ``<name>_gt_roidb.pkl`` under ``self.cache_path``.
    When that file exists it is unpickled and returned; otherwise the
    annotations are loaded for every entry of ``self.image_index`` and the
    cache is written before returning.
    """
    cache_file = os.path.join(self.cache_path, self.name + '_gt_roidb.pkl')
    if os.path.exists(cache_file):
      with open(cache_file, 'rb') as fid:
        try:
          roidb = pickle.load(fid)
        except Exception:
          # The failed attempt has already consumed part of the stream, so
          # rewind before retrying.  encoding='bytes' reads 8-bit string
          # instances from older pickles as bytes objects.
          fid.seek(0)
          roidb = pickle.load(fid, encoding='bytes')
      print('{} gt roidb loaded from {}'.format(self.name, cache_file))
      return roidb

    gt_roidb = [self._load_pascal_annotation(index)
                for index in self.image_index]
    with open(cache_file, 'wb') as fid:
      pickle.dump(gt_roidb, fid, pickle.HIGHEST_PROTOCOL)
    print('wrote gt roidb to {}'.format(cache_file))

    return gt_roidb
Example 12
Project: cascade-rcnn_Pytorch   Author: guoruoqian   File: pascal_voc.py    License: MIT License 6 votes vote down vote up
def gt_roidb(self):
        """
        Build or reload the ground-truth regions-of-interest database.

        A pickle named ``<name>_gt_roidb.pkl`` under ``self.cache_path`` acts
        as the cache: it is read when present and written otherwise.
        """
        cache_file = os.path.join(self.cache_path, self.name + '_gt_roidb.pkl')

        if not os.path.exists(cache_file):
            # Cache miss: load every annotation, then persist the result.
            gt_roidb = []
            for index in self.image_index:
                gt_roidb.append(self._load_pascal_annotation(index))
            with open(cache_file, 'wb') as cache_out:
                pickle.dump(gt_roidb, cache_out, pickle.HIGHEST_PROTOCOL)
            print('wrote gt roidb to {}'.format(cache_file))
            return gt_roidb

        with open(cache_file, 'rb') as cache_in:
            roidb = pickle.load(cache_in)
        print('{} gt roidb loaded from {}'.format(self.name, cache_file))
        return roidb
Example 13
Project: cascade-rcnn_Pytorch   Author: guoruoqian   File: coco.py    License: MIT License 6 votes vote down vote up
def _write_coco_results_file(self, all_boxes, res_file):
    # [{"image_id": 42,
    #   "category_id": 18,
    #   "bbox": [258.15,41.29,348.26,243.78],
    #   "score": 0.236}, ...]
    results = []
    for cls_ind, cls in enumerate(self.classes):
      if cls == '__background__':
        continue
      print('Collecting {} results ({:d}/{:d})'.format(cls, cls_ind,
                                                       self.num_classes - 1))
      coco_cat_id = self._class_to_coco_cat_id[cls]
      results.extend(self._coco_results_one_category(all_boxes[cls_ind],
                                                     coco_cat_id))
    print('Writing results json to {}'.format(res_file))
    with open(res_file, 'w') as fid:
      json.dump(results, fid) 
Example 14
Project: cascade-rcnn_Pytorch   Author: guoruoqian   File: imagenet.py    License: MIT License 6 votes vote down vote up
def gt_roidb(self):
        """
        Return the ground-truth regions-of-interest database.
        The database is read from ``<name>_gt_roidb.pkl`` in
        ``self.cache_path`` when that file exists, and is otherwise built
        from the annotations and written there for later calls.
        """
        pickle_name = self.name + '_gt_roidb.pkl'
        cache_file = os.path.join(self.cache_path, pickle_name)

        cache_hit = os.path.exists(cache_file)
        if cache_hit:
            with open(cache_file, 'rb') as source:
                roidb = pickle.load(source)
            print('{} gt roidb loaded from {}'.format(self.name, cache_file))
            return roidb

        # One annotation record per image index, in index order.
        gt_roidb = list(map(self._load_imagenet_annotation, self.image_index))
        with open(cache_file, 'wb') as sink:
            pickle.dump(gt_roidb, sink, pickle.HIGHEST_PROTOCOL)
        print('wrote gt roidb to {}'.format(cache_file))

        return gt_roidb
Example 15
Project: cascade-rcnn_Pytorch   Author: guoruoqian   File: vg.py    License: MIT License 6 votes vote down vote up
def gt_roidb(self):
        """
        Return the database of ground-truth regions of interest.

        The database is cached as a gzip-compressed pickle named
        ``<name>_gt_roidb.pkl`` under ``self.cache_path``: it is read when
        the file exists, and built from the annotations and written
        otherwise.
        """
        cache_file = os.path.join(self.cache_path, self.name + '_gt_roidb.pkl')
        if os.path.exists(cache_file):
            # ``with`` closes the gzip handle even when unpickling raises;
            # the manual close() was skipped on errors.
            with gzip.open(cache_file, 'rb') as fid:
                roidb = pickle.load(fid)
            print('{} gt roidb loaded from {}'.format(self.name, cache_file))
            return roidb

        gt_roidb = [self._load_vg_annotation(index)
                    for index in self.image_index]
        with gzip.open(cache_file, 'wb') as fid:
            pickle.dump(gt_roidb, fid, pickle.HIGHEST_PROTOCOL)
        print('wrote gt roidb to {}'.format(cache_file))
        return gt_roidb
Example 16
Project: vergeml   Author: mme   File: cache.py    License: MIT License 5 votes vote down vote up
def write(self, file):
        """Append the content index to ``file`` and record where it starts.

        The tuple ``(index, meta, info)`` is pickled at the current file
        position; that position is then written at offset 0 as a
        little-endian unsigned 64-bit integer.
        """
        index_offset = file.tell()
        payload = (self.index, self.meta, self.info)
        pickle.dump(payload, file)

        # Go back to the start and overwrite the header with the offset.
        file.seek(0)
        header = struct.pack('<Q', index_offset)
        file.write(header)
Example 17
Project: natural-questions   Author: google-research-datasets   File: nq_eval.py    License: Apache License 2.0 5 votes vote down vote up
def main(_):
  """Score predictions against the gold annotations and print the result.

  Gold data is read from a pickle cache next to ``FLAGS.gold_path`` when
  ``FLAGS.cache_gold_data`` is set and the cache exists; otherwise it is
  parsed with ``util.read_annotation`` and, if caching is enabled, written
  to that cache.  Output is either the R@P tables plus F1 summary
  (``FLAGS.pretty_print``) or a JSON dump of the metrics.
  """
  cache_path = os.path.join(os.path.dirname(FLAGS.gold_path), 'cache')
  if FLAGS.cache_gold_data and os.path.exists(cache_path):
    logging.info('Reading from cache: %s', format(cache_path))
    # Pickle streams are binary; open in 'rb' and close the handle promptly.
    # NOTE(review): unpickling executes code from the file - the cache must
    # come from a trusted location.
    with open(cache_path, 'rb') as cache_file:
      nq_gold_dict = pickle.load(cache_file)
  else:
    nq_gold_dict = util.read_annotation(
        FLAGS.gold_path, n_threads=FLAGS.num_threads)
    if FLAGS.cache_gold_data:
      logging.info('Caching gold data for next time to: %s', format(cache_path))
      with open(cache_path, 'wb') as cache_file:
        pickle.dump(nq_gold_dict, cache_file)

  nq_pred_dict = util.read_prediction_json(FLAGS.predictions_path)

  long_answer_stats, short_answer_stats = score_answers(nq_gold_dict,
                                                        nq_pred_dict)

  if FLAGS.pretty_print:
    print('*' * 20)
    print('LONG ANSWER R@P TABLE:')
    print_r_at_p_table(long_answer_stats)
    print('*' * 20)
    print('SHORT ANSWER R@P TABLE:')
    print_r_at_p_table(short_answer_stats)

    scores = compute_final_f1(long_answer_stats, short_answer_stats)
    print('*' * 20)
    print('METRICS IGNORING SCORES (n={}):'.format(scores['long-answer-n']))
    print('              F1     /  P      /  R')
    print('Long answer  {: >7.2%} / {: >7.2%} / {: >7.2%}'.format(
        scores['long-answer-f1'], scores['long-answer-precision'],
        scores['long-answer-recall']))
    print('Short answer {: >7.2%} / {: >7.2%} / {: >7.2%}'.format(
        scores['short-answer-f1'], scores['short-answer-precision'],
        scores['short-answer-recall']))
  else:
    metrics = get_metrics_with_answer_stats(long_answer_stats,
                                            short_answer_stats)
    print(json.dumps(metrics))
Example 18
Project: wechat-alfred-workflow   Author: TKkk-iOSer   File: workflow.py    License: MIT License 5 votes vote down vote up
def dump(cls, obj, file_obj):
        """Serialize object ``obj`` to open JSON file.

        The output is indented by two spaces per level.

        .. versionadded:: 1.8

        :param obj: Python object to serialize
        :type obj: JSON-serializable data structure
        :param file_obj: file handle
        :type file_obj: ``file`` object
        :returns: whatever ``json.dump`` returns (``None``)

        """
        # ``encoding`` is deliberately not passed: UTF-8 is already the
        # default on Python 2, and Python 3's json.dump rejects the keyword
        # with a TypeError.
        return json.dump(obj, file_obj, indent=2)
Example 19
Project: wechat-alfred-workflow   Author: TKkk-iOSer   File: workflow.py    License: MIT License 5 votes vote down vote up
def dump(cls, obj, file_obj):
        """Write ``obj`` to an already-open file using ``cPickle``.

        .. versionadded:: 1.8

        :param obj: Python object to serialize
        :type obj: Python object
        :param file_obj: file handle
        :type file_obj: ``file`` object
        :returns: the result of ``cPickle.dump``

        """
        # A negative protocol number requests the highest protocol available.
        highest_available = -1
        return cPickle.dump(obj, file_obj, protocol=highest_available)
Example 20
Project: wechat-alfred-workflow   Author: TKkk-iOSer   File: workflow.py    License: MIT License 5 votes vote down vote up
def dump(cls, obj, file_obj):
        """Write ``obj`` to an already-open file using ``pickle``.

        .. versionadded:: 1.8

        :param obj: Python object to serialize
        :type obj: Python object
        :param file_obj: file handle
        :type file_obj: ``file`` object
        :returns: the result of ``pickle.dump``

        """
        newest_protocol = -1  # negative selects the highest protocol available
        return pickle.dump(obj, file_obj, protocol=newest_protocol)


# Set up default manager and register built-in serializers 
Example 21
def snapshot(self, iter):
    """Save the network weights and the training meta state for ``iter``.

    Two files are written into ``self.output_dir`` (created if missing): a
    ``.pth`` file with the network ``state_dict`` and a ``.pkl`` file with,
    in this order, the numpy random state, the position and permutation of
    the training data layer, the same two values for the validation data
    layer, and ``iter`` itself.  Every 10000th iteration the ``.pth`` file
    is additionally copied to a ``*_cache`` file.

    Returns the pair ``(filename, nfilename)`` of the two written paths.
    """
    net = self.net

    if not os.path.exists(self.output_dir):
        os.makedirs(self.output_dir)

    # Both output files share this stem and differ only in their suffix.
    stem = cfg.TRAIN.SNAPSHOT_PREFIX + '_iter_{:d}'.format(iter)

    # Model weights
    filename = os.path.join(self.output_dir, stem + '.pth')
    torch.save(net.state_dict(), filename)
    print('Wrote snapshot to: {:s}'.format(filename))

    if iter % 10000 == 0:
        shutil.copyfile(filename, filename + '.{:d}_cache'.format(iter))

    # Meta information: random state and data-layer cursors
    nfilename = os.path.join(self.output_dir, stem + '.pkl')
    train_layer = self.data_layer
    val_layer = self.data_layer_val
    meta_items = (
        np.random.get_state(),  # current state of numpy random
        train_layer._cur,       # current position in the database
        train_layer._perm,      # current shuffled indexes of the database
        val_layer._cur,         # current position in the validation database
        val_layer._perm,        # shuffled indexes of the validation database
        iter,
    )

    # Each item is written as its own pickle record, one after the other.
    with open(nfilename, 'wb') as meta_file:
        for item in meta_items:
            pickle.dump(item, meta_file, pickle.HIGHEST_PROTOCOL)

    return filename, nfilename
Example 22
def selective_search_roidb(self):
    """
    Return the selective-search regions-of-interest database.

    The database is cached as ``<name>_selective_search_roidb.pkl`` under
    ``self.cache_path``.  On a cache miss it is produced by
    ``_load_selective_search_roidb``, which receives the ground-truth roidb
    when the year is 2007 or the image set is not 'test', and ``None``
    otherwise.
    """
    pickle_name = self.name + '_selective_search_roidb.pkl'
    cache_file = os.path.join(self.cache_path, pickle_name)

    if not os.path.exists(cache_file):
        has_ground_truth = int(self._year) == 2007 or self._image_set != 'test'
        ground_truth = self.gt_roidb() if has_ground_truth else None
        roidb = self._load_selective_search_roidb(ground_truth)
        with open(cache_file, 'wb') as cache_out:
            pickle.dump(roidb, cache_out, pickle.HIGHEST_PROTOCOL)
        print('wrote ss roidb to {}'.format(cache_file))
        return roidb

    with open(cache_file, 'rb') as cache_in:
        roidb = pickle.load(cache_in)
    print('{} ss roidb loaded from {}'.format(self.name, cache_file))
    return roidb
Example 23
def _eval_discovery(self, output_dir):
    """Run ``dis_eval`` for every non-background class and report CorLoc.

    The per-class value is printed and pickled to ``<cls>_corloc.pkl`` in
    ``output_dir`` (created when it is not an existing directory); the mean
    and the individual values are printed at the end.
    """
    voc_root = os.path.join(self._devkit_path, 'VOC' + self._year)
    annopath = os.path.join(voc_root, 'Annotations', '{:s}.xml')
    imagesetfile = os.path.join(voc_root, 'ImageSets', 'Main',
                                self._image_set + '.txt')
    cachedir = os.path.join(self._devkit_path, 'annotations_dis_cache')

    if not os.path.isdir(output_dir):
        os.mkdir(output_dir)

    corlocs = []
    for cls in self._classes:
        if cls == '__background__':
            continue
        results_file = self._get_voc_results_file_template().format(cls)
        corloc = dis_eval(
            results_file, annopath, imagesetfile, cls, cachedir, ovthresh=0.5)
        corlocs.append(corloc)
        print('CorLoc for {} = {:.4f}'.format(cls, corloc))
        pickle_path = os.path.join(output_dir, cls + '_corloc.pkl')
        with open(pickle_path, 'wb') as pickle_file:
            pickle.dump({'corloc': corloc}, pickle_file)

    print('Mean CorLoc = {:.4f}'.format(np.mean(corlocs)))
    print('~~~~~~~~')
    print('Results:')
    for value in corlocs:
        print('{:.3f}'.format(value))
    print('{:.3f}'.format(np.mean(corlocs)))
    print('~~~~~~~~')
Example 24
def _do_detection_eval(self, res_file, output_dir):
    """Evaluate the detections in ``res_file`` and pickle the evaluator.

    The results are loaded through ``self._COCO.loadRes``, evaluated and
    accumulated with ``COCOeval``, printed via
    ``_print_detection_eval_metrics`` and finally pickled to
    ``detection_results.pkl`` inside ``output_dir``.
    """
    ann_type = 'bbox'
    detections = self._COCO.loadRes(res_file)

    evaluator = COCOeval(self._COCO, detections)
    evaluator.params.useSegm = (ann_type == 'segm')
    evaluator.evaluate()
    evaluator.accumulate()
    self._print_detection_eval_metrics(evaluator)

    eval_file = osp.join(output_dir, 'detection_results.pkl')
    with open(eval_file, 'wb') as pickle_file:
      pickle.dump(evaluator, pickle_file, pickle.HIGHEST_PROTOCOL)
    print('Wrote COCO eval results to: {}'.format(eval_file))
Example 25
Project: cherrypy   Author: cherrypy   File: sessions.py    License: BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
def _save(self, expiration_time):
        assert self.locked, ('The session was saved without being locked.  '
                             "Check your tools' priority levels.")
        f = open(self._get_file_path(), 'wb')
        try:
            pickle.dump((self._data, expiration_time), f, self.pickle_protocol)
        finally:
            f.close() 
Example 26
Project: gated-graph-transformer-network   Author: hexahedria   File: util.py    License: MIT License 5 votes vote down vote up
def save_params(params, file):
    """
    Pickle the current values of ``params`` into the open file ``file``.

    Each parameter's ``get_value()`` result is collected, in order, into a
    list and that list is dumped as a single pickle record.
    """
    values = []
    for param in params:
        values.append(param.get_value())
    pickle.dump(values, file)
Example 27
Project: gated-graph-transformer-network   Author: hexahedria   File: fix_old_file_list.py    License: MIT License 5 votes vote down vote up
def main(task_dir, dry_run=False):
    """Rewrite the paths in ``file_list.p`` relative to their bucket folder.

    Every path in the pickled, bucketed list is replaced by ``./bucket_``
    followed by the text after its first ``bucket_`` marker.  With
    ``dry_run`` the file is left untouched and example conversions are
    printed instead.
    """
    list_path = os.path.join(task_dir, 'file_list.p')
    with open(list_path, 'rb') as f:
        original = pickle.load(f)
    if dry_run:
        print("Got {} (for example)".format(original[0][0]))

    bucketed = []
    for bucket in original:
        bucketed.append(['./bucket_' + path.split('bucket_')[1]
                         for path in bucket])

    if not dry_run:
        with open(os.path.join(task_dir, 'file_list.p'), 'wb') as f:
            pickle.dump(bucketed, f)
        return

    sample = bucketed[0][0]
    print("Converting to {} (for example)".format(sample))
    print("Will resolve to {} (for example)".format(os.path.normpath(os.path.join(task_dir,sample))))
Example 28
Project: Traffic_sign_detection_YOLO   Author: AmeyaWagh   File: flow.py    License: MIT License 5 votes vote down vote up
def _save_ckpt(self, step, loss_profile):
    file = '{}-{}{}'
    model = self.meta['name']

    profile = file.format(model, step, '.profile')
    profile = os.path.join(self.FLAGS.backup, profile)
    with open(profile, 'wb') as profile_ckpt: 
        pickle.dump(loss_profile, profile_ckpt)

    ckpt = file.format(model, step, '')
    ckpt = os.path.join(self.FLAGS.backup, ckpt)
    self.say('Checkpoint at step {}'.format(step))
    self.saver.save(self.sess, ckpt) 
Example 29
Project: Neural-LP   Author: fanyangxyz   File: experiment.py    License: MIT License 5 votes vote down vote up
def get_attentions(self):
        """Collect attention operators and memories for every rule query.

        The queries in ``self.data.query_for_rules`` are split into batches
        (one query per batch when type checking is enabled for non-language
        queries), the learner is asked for the attentions of each batch, and
        the per-query results are pickled to ``attentions.pckl`` in
        ``self.option.this_expsdir``.

        :returns: ``(all_attention_operators, all_attention_memories,
            all_queries)`` - two dicts keyed by query and the flat list of
            queries in batch order
        """
        batched = self.option.query_is_language or not self.option.type_check
        if batched:
            num_batch = int(np.ceil(1.0*len(self.data.query_for_rules)/self.option.batch_size))
            query_batches = np.array_split(self.data.query_for_rules, num_batch)
        else:
            query_batches = [[i] for i in self.data.query_for_rules]

        all_attention_operators = {}
        all_attention_memories = {}

        for queries in query_batches:
            attention_operators, attention_memories \
            = self.learner.get_attentions_given_queries(self.sess, queries)

            # Tuple-ize in order to be used as dict keys
            if self.option.query_is_language:
                queries = [tuple(q) for q in queries]

            # enumerate replaces the Python-2-only xrange index loop.
            for i, query in enumerate(queries):
                all_attention_operators[query] \
                                        = [[attn[i]
                                        for attn in attn_step]
                                        for attn_step in attention_operators]
                all_attention_memories[query] = \
                                        [attn_step[i, :]
                                        for attn_step in attention_memories]

        # Pickle output is bytes, so the file is opened in binary mode; the
        # context manager flushes and closes it.
        attentions_path = os.path.join(self.option.this_expsdir, "attentions.pckl")
        with open(attentions_path, "wb") as attentions_file:
            pickle.dump([all_attention_operators, all_attention_memories],
                        attentions_file)

        msg = self.msg_with_time("Attentions collected.")
        print(msg)
        self.log_file.write(msg + "\n")

        # Flatten the batches; equivalent to the former reduce() over list
        # concatenation, without needing the Python 2 builtin.
        all_queries = [query for batch in query_batches for query in batch]
        return all_attention_operators, all_attention_memories, all_queries
Example 30
Project: Neural-LP   Author: fanyangxyz   File: experiment.py    License: MIT License 5 votes vote down vote up
def get_rules(self):
        """List and print the rules derived from the collected attentions.

        For every query returned by ``self.get_attentions()`` the rules are
        produced with ``list_rules`` and rendered with ``print_rules``.  The
        listed rules are pickled to ``rules.pckl`` and the rendered lines
        written to ``rules.txt``, both in ``self.option.this_expsdir``.
        """
        all_attention_operators, all_attention_memories, queries = self.get_attentions()

        # Progress is reported roughly five times; floor division keeps the
        # step an integer on Python 3 as well.
        progress_step = max(1, len(queries) // 5)

        all_listed_rules = {}
        all_printed_rules = []
        for i, q in enumerate(queries):
            if not self.option.query_is_language:
                if (i+1) % progress_step == 0:
                    sys.stdout.write("%d/%d\t" % (i, len(queries)))
                    sys.stdout.flush()
            else:
                # Tuple-ize in order to be used as dict keys
                q = tuple(q)
            all_listed_rules[q] = list_rules(all_attention_operators[q],
                                             all_attention_memories[q],
                                             self.option.rule_thr,)
            all_printed_rules += print_rules(q,
                                             all_listed_rules[q],
                                             self.data.parser,
                                             self.option.query_is_language)

        # Pickle output is bytes, so the file is opened in binary mode; the
        # context manager flushes and closes it.
        with open(os.path.join(self.option.this_expsdir, "rules.pckl"), "wb") as rules_file:
            pickle.dump(all_listed_rules, rules_file)
        with open(os.path.join(self.option.this_expsdir, "rules.txt"), "w") as f:
            for line in all_printed_rules:
                f.write(line + "\n")
        msg = self.msg_with_time("\nRules listed and printed.")
        print(msg)
        self.log_file.write(msg + "\n")