Python pickle.dump() Examples

The following are 30 code examples of pickle.dump(), collected from open-source projects. The source project, author, file name, and license are listed above each example. You may also want to check out the other available functions and classes of the pickle module.
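As a quick refresher before the examples: pickle.dump(obj, file) serializes obj to an open binary file, and pickle.load(file) reads it back. A minimal round-trip sketch (the file name here is arbitrary):

import pickle

data = {"epoch": 3, "scores": [0.91, 0.87]}

# pickle requires a binary-mode file: "wb" for writing, "rb" for reading.
with open("data.pkl", "wb") as f:
    pickle.dump(data, f, protocol=pickle.HIGHEST_PROTOCOL)

with open("data.pkl", "rb") as f:
    restored = pickle.load(f)

assert restored == data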
Example #1
Source Project: svviz   Author: svviz   File: app.py    License: MIT License
def saveState(dataHub):
    import gzip
    import logging
    import pickle

    # Compress and pickle in one pass; "with" closes the file promptly.
    with gzip.open(dataHub.args.save_state, "wb") as f:
        pickle.dump(dataHub, f)
    logging.warning("^"*20 + " saving state to pickle and exiting " + "^"*20)
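To restore a state saved this way, a matching loader would presumably look like the sketch below (loadState and its argument are assumptions for illustration, not part of svviz):

import gzip
import pickle

def loadState(path):
    # Mirror of saveState: gunzip the file, then unpickle the dataHub.
    with gzip.open(path, "rb") as f:
        return pickle.load(f)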
Example #2
Source Project: wechat-alfred-workflow   Author: TKkk-iOSer   File: workflow.py    License: MIT License
def register(self, name, serializer):
        """Register ``serializer`` object under ``name``.

        Raises :class:`AttributeError` if ``serializer`` is invalid.

        .. note::

            ``name`` will be used as the file extension of the saved files.

        :param name: Name to register ``serializer`` under
        :type name: ``unicode`` or ``str``
        :param serializer: object with ``load()`` and ``dump()``
            methods

        """
        # Basic validation
        getattr(serializer, 'load')
        getattr(serializer, 'dump')

        self._serializers[name] = serializer 
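For illustration, registering a serializer with this method might look like the sketch below. CSVSerializer is invented here; register() only requires an object exposing load() and dump(), and manager is assumed to be the module-level instance used elsewhere in workflow.py:

import csv

class CSVSerializer(object):
    # Hypothetical serializer: only load() and dump() are required.

    @classmethod
    def load(cls, file_obj):
        return list(csv.reader(file_obj))

    @classmethod
    def dump(cls, obj, file_obj):
        csv.writer(file_obj).writerows(obj)

manager.register('csv', CSVSerializer)  # 'csv' becomes the file extension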
Example #3
Source Project: wechat-alfred-workflow   Author: TKkk-iOSer   File: workflow.py    License: MIT License
def cache_data(self, name, data):
        """Save ``data`` to cache under ``name``.

        If ``data`` is ``None``, the corresponding cache file will be
        deleted.

        :param name: name of datastore
        :param data: data to store. This may be any object supported by
                the cache serializer

        """
        serializer = manager.serializer(self.cache_serializer)

        cache_path = self.cachefile('%s.%s' % (name, self.cache_serializer))

        if data is None:
            if os.path.exists(cache_path):
                os.unlink(cache_path)
                self.logger.debug('deleted cache file: %s', cache_path)
            return

        with atomic_writer(cache_path, 'wb') as file_obj:
            serializer.dump(data, file_obj)

        self.logger.debug('cached data: %s', cache_path) 
Example #4
def gt_roidb(self):
    """
    Return the database of ground-truth regions of interest.

    This function loads/saves from/to a cache file to speed up future calls.
    """
    cache_file = os.path.join(self.cache_path, self.name + '_gt_roidb.pkl')
    if os.path.exists(cache_file):
      with open(cache_file, 'rb') as fid:
        try:
          roidb = pickle.load(fid)
        except UnicodeDecodeError:
          # Cache written by Python 2: rewind and decode as raw bytes.
          fid.seek(0)
          roidb = pickle.load(fid, encoding='bytes')
      print('{} gt roidb loaded from {}'.format(self.name, cache_file))
      return roidb

    gt_roidb = [self._load_pascal_labels(index)
                for index in self.image_index]
    with open(cache_file, 'wb') as fid:
      pickle.dump(gt_roidb, fid, pickle.HIGHEST_PROTOCOL)
    print('wrote gt roidb to {}'.format(cache_file))

    return gt_roidb 
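The try/except retry above exists because caches pickled under Python 2 can fail to load under Python 3 with the default ASCII decoding; passing encoding='bytes' (or encoding='latin1' when NumPy arrays are involved) works around it. A standalone sketch of the pattern (the file name is hypothetical):

import pickle

with open('legacy_cache.pkl', 'rb') as f:
    try:
        data = pickle.load(f)
    except UnicodeDecodeError:
        f.seek(0)  # rewind before retrying with a permissive encoding
        data = pickle.load(f, encoding='latin1')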
Example #5
Source Project: Collaborative-Learning-for-Weakly-Supervised-Object-Detection   Author: Sunarker   File: coco.py    License: MIT License
def _write_coco_results_file(self, all_boxes, res_file):
    # [{"image_id": 42,
    #   "category_id": 18,
    #   "bbox": [258.15,41.29,348.26,243.78],
    #   "score": 0.236}, ...]
    results = []
    for cls_ind, cls in enumerate(self.classes):
      if cls == '__background__':
        continue
      print('Collecting {} results ({:d}/{:d})'.format(cls, cls_ind,
                                                       self.num_classes - 1))
      coco_cat_id = self._class_to_coco_cat_id[cls]
      results.extend(self._coco_results_one_category(all_boxes[cls_ind],
                                                     coco_cat_id))
    print('Writing results json to {}'.format(res_file))
    with open(res_file, 'w') as fid:
      json.dump(results, fid) 
Example #6
Source Project: comet-commonsense   Author: atcbosselut   File: data.py    License: Apache License 2.0
def save_eval_file(opt, stats, eval_type="losses", split="dev", ext="pickle"):
    if cfg.test_save:
        name = "{}/{}.{}".format(utils.make_name(
            opt, prefix="garbage/{}/".format(eval_type),
            is_dir=True, eval_=True), split, ext)
    else:
        name = "{}/{}.{}".format(utils.make_name(
            opt, prefix="results/{}/".format(eval_type),
            is_dir=True, eval_=True), split, ext)
    print("Saving {} {} to {}".format(split, eval_type, name))

    if ext == "pickle":
        with open(name, "wb") as f:
            pickle.dump(stats, f)
    elif ext == "txt":
        with open(name, "w") as f:
            f.write(stats)
    elif ext == "json":
        with open(name, "w") as f:
            json.dump(stats, f)
    else:
        raise ValueError("Unsupported extension: {}".format(ext))
Example #7
Source Project: gated-graph-transformer-network   Author: hexahedria   File: update_cache_compatibility.py    License: MIT License
def main(cache_dir):
    files_list = list(os.listdir(cache_dir))
    for file in files_list:
        full_filename = os.path.join(cache_dir, file)
        if os.path.isfile(full_filename):
            print("Processing {}".format(full_filename))
            with open(full_filename, 'rb') as f:
                m, stored_kwargs = pickle.load(f)
            updated_kwargs = util.get_compatible_kwargs(model.Model, stored_kwargs)

            model_hash = util.object_hash(updated_kwargs)
            print("New hash -> " + model_hash)
            model_filename = os.path.join(cache_dir, "model_{}.p".format(model_hash))
            sys.setrecursionlimit(100000)
            with open(model_filename, 'wb') as f:
                pickle.dump((m, updated_kwargs), f, protocol=pickle.HIGHEST_PROTOCOL)

            os.remove(full_filename) 
Example #8
Source Project: Neural-LP   Author: fanyangxyz   File: experiment.py    License: MIT License
def train(self):
        while (self.epoch < self.option.max_epoch and not self.early_stopped):
            self.one_epoch_train()
            self.one_epoch_valid()
            self.one_epoch_test()
            self.epoch += 1
            model_path = self.saver.save(self.sess, 
                                         self.option.model_path,
                                         global_step=self.epoch)
            print("Model saved at %s" % model_path)
            
            if self.early_stop():
                self.early_stopped = True
                print("Early stopped at epoch %d" % (self.epoch))
        
        all_test_in_top = [np.mean(x[1]) for x in self.test_stats]
        best_test_epoch = np.argmax(all_test_in_top)
        best_test = all_test_in_top[best_test_epoch]
        
        msg = "Best test in top: %0.4f at epoch %d." % (best_test, best_test_epoch + 1)       
        print(msg)
        self.log_file.write(msg + "\n")
        with open(os.path.join(self.option.this_expsdir, "results.pckl"), "wb") as f:
            pickle.dump([self.train_stats, self.valid_stats, self.test_stats], f)
Example #9
Source Project: graph-neural-networks   Author: alelab-upenn   File: miscTools.py    License: GNU General Public License v3.0
def saveSeed(randomStates, saveDir):
    """
    Takes a list of dictionaries of random generator states of different modules
    and saves them in a .pkl format.
    
    Inputs:
        randomStates (list): The length of this list is equal to the number of
            modules whose states want to be saved (torch, numpy, etc.). Each
            element in this list is a dictionary. The dictionary has three keys:
            'module' with the name of the module in string format ('numpy' or
            'torch', for example), 'state' with the saved generator state and,
            if corresponds, 'seed' with the specific seed for the generator
            (note that torch has both state and seed, but numpy only has state)
        saveDir (path): where to save the seed, it will be saved under the 
            filename 'randomSeedUsed.pkl'
    """
    pathToSeed = os.path.join(saveDir, 'randomSeedUsed.pkl')
    with open(pathToSeed, 'wb') as seedFile:
        pickle.dump({'randomStates': randomStates}, seedFile) 
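The corresponding restore step is not shown in miscTools.py; a minimal sketch under the same file layout (loadSeed is an assumed name) would be:

import os
import pickle

def loadSeed(loadDir):
    # Counterpart to saveSeed: recover the list of per-module generator states.
    pathToSeed = os.path.join(loadDir, 'randomSeedUsed.pkl')
    with open(pathToSeed, 'rb') as seedFile:
        return pickle.load(seedFile)['randomStates']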
Example #10
Source Project: Paradrop   Author: ParadropLabs   File: pd_storage.py    License: Apache License 2.0
def saveToDisk(self):
        """Saves the data to disk."""
        out.info('Saving to disk (%s)\n' % (self.filename))

        # Make sure they want to save
        if(not self.attrSaveable()):
            return

        # Get whatever the data is
        pyld = self.exportAttr(self.getAttr())

        # Write the file to disk, truncate if it exists
        try:
            with open(self.filename, 'wb') as output:
                pickle.dump(pyld, output)
                # Flush Python's buffer before fsync so the bytes reach the OS.
                output.flush()
                os.fsync(output.fileno())
        except Exception as e:
            out.err('Error writing to disk %s\n' % (str(e)))

        try:
            with open(self.filename + ".yaml", "w") as output:
                yaml.dump(pyld, output)
        except Exception as error:
            out.err("Error writing yaml file: {}".format(error)) 
Example #11
Source Project: cascade-rcnn_Pytorch   Author: guoruoqian   File: pascal_voc_rbg.py    License: MIT License
def gt_roidb(self):
    """
    Return the database of ground-truth regions of interest.

    This function loads/saves from/to a cache file to speed up future calls.
    """
    cache_file = os.path.join(self.cache_path, self.name + '_gt_roidb.pkl')
    if os.path.exists(cache_file):
      with open(cache_file, 'rb') as fid:
        try:
          roidb = pickle.load(fid)
        except UnicodeDecodeError:
          # Cache written by Python 2: rewind and decode as raw bytes.
          fid.seek(0)
          roidb = pickle.load(fid, encoding='bytes')
      print('{} gt roidb loaded from {}'.format(self.name, cache_file))
      return roidb

    gt_roidb = [self._load_pascal_annotation(index)
                for index in self.image_index]
    with open(cache_file, 'wb') as fid:
      pickle.dump(gt_roidb, fid, pickle.HIGHEST_PROTOCOL)
    print('wrote gt roidb to {}'.format(cache_file))

    return gt_roidb 
Example #12
Source Project: cascade-rcnn_Pytorch   Author: guoruoqian   File: pascal_voc.py    License: MIT License
def gt_roidb(self):
        """
        Return the database of ground-truth regions of interest.

        This function loads/saves from/to a cache file to speed up future calls.
        """
        cache_file = os.path.join(self.cache_path, self.name + '_gt_roidb.pkl')
        if os.path.exists(cache_file):
            with open(cache_file, 'rb') as fid:
                roidb = pickle.load(fid)
            print('{} gt roidb loaded from {}'.format(self.name, cache_file))
            return roidb

        gt_roidb = [self._load_pascal_annotation(index)
                    for index in self.image_index]
        with open(cache_file, 'wb') as fid:
            pickle.dump(gt_roidb, fid, pickle.HIGHEST_PROTOCOL)
        print('wrote gt roidb to {}'.format(cache_file))

        return gt_roidb 
Example #13
Source Project: cascade-rcnn_Pytorch   Author: guoruoqian   File: coco.py    License: MIT License
def _write_coco_results_file(self, all_boxes, res_file):
    # [{"image_id": 42,
    #   "category_id": 18,
    #   "bbox": [258.15,41.29,348.26,243.78],
    #   "score": 0.236}, ...]
    results = []
    for cls_ind, cls in enumerate(self.classes):
      if cls == '__background__':
        continue
      print('Collecting {} results ({:d}/{:d})'.format(cls, cls_ind,
                                                       self.num_classes - 1))
      coco_cat_id = self._class_to_coco_cat_id[cls]
      results.extend(self._coco_results_one_category(all_boxes[cls_ind],
                                                     coco_cat_id))
    print('Writing results json to {}'.format(res_file))
    with open(res_file, 'w') as fid:
      json.dump(results, fid) 
Example #14
Source Project: cascade-rcnn_Pytorch   Author: guoruoqian   File: imagenet.py    License: MIT License
def gt_roidb(self):
        """
        Return the database of ground-truth regions of interest.
        This function loads/saves from/to a cache file to speed up future calls.
        """
        cache_file = os.path.join(self.cache_path, self.name + '_gt_roidb.pkl')
        if os.path.exists(cache_file):
            with open(cache_file, 'rb') as fid:
                roidb = pickle.load(fid)
            print('{} gt roidb loaded from {}'.format(self.name, cache_file))
            return roidb

        gt_roidb = [self._load_imagenet_annotation(index)
                    for index in self.image_index]
        with open(cache_file, 'wb') as fid:
            pickle.dump(gt_roidb, fid, pickle.HIGHEST_PROTOCOL)
        print('wrote gt roidb to {}'.format(cache_file))

        return gt_roidb 
Example #15
Source Project: cascade-rcnn_Pytorch   Author: guoruoqian   File: vg.py    License: MIT License
def gt_roidb(self):
        """
        Return the database of ground-truth regions of interest.

        This function loads/saves from/to a cache file to speed up future calls.
        """        
        cache_file = os.path.join(self.cache_path, self.name + '_gt_roidb.pkl')
        if os.path.exists(cache_file):
            with gzip.open(cache_file, 'rb') as fid:
                roidb = pickle.load(fid)
            print('{} gt roidb loaded from {}'.format(self.name, cache_file))
            return roidb

        gt_roidb = [self._load_vg_annotation(index)
                    for index in self.image_index]
        with gzip.open(cache_file, 'wb') as fid:
            pickle.dump(gt_roidb, fid, pickle.HIGHEST_PROTOCOL)
        print('wrote gt roidb to {}'.format(cache_file))
        return gt_roidb 
Example #16
Source Project: vergeml   Author: mme   File: cache.py    License: MIT License
def write(self, file):
        """Write the content index to file and update the header.
        """
        pos = file.tell()
        pickle.dump((self.index, self.meta, self.info), file)
        file.seek(0)

        # update the header with the position of the content index.
        file.write(struct.pack('<Q', pos)) 
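write() leaves an 8-byte little-endian offset at the start of the file pointing at the pickled content index. A read-side sketch under that assumption (not taken from vergeml) would be:

import pickle
import struct

def read_index(file):
    # The first 8 bytes hold the offset of the pickled content index.
    file.seek(0)
    (pos,) = struct.unpack('<Q', file.read(8))
    file.seek(pos)
    return pickle.load(file)  # -> (index, meta, info)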
Example #17
Source Project: natural-questions   Author: google-research-datasets   File: nq_eval.py    License: Apache License 2.0
def main(_):
  cache_path = os.path.join(os.path.dirname(FLAGS.gold_path), 'cache')
  if FLAGS.cache_gold_data and os.path.exists(cache_path):
    logging.info('Reading from cache: %s', cache_path)
    with open(cache_path, 'rb') as f:
      nq_gold_dict = pickle.load(f)
  else:
    nq_gold_dict = util.read_annotation(
        FLAGS.gold_path, n_threads=FLAGS.num_threads)
    if FLAGS.cache_gold_data:
      logging.info('Caching gold data for next time to: %s', cache_path)
      with open(cache_path, 'wb') as f:
        pickle.dump(nq_gold_dict, f)

  nq_pred_dict = util.read_prediction_json(FLAGS.predictions_path)

  long_answer_stats, short_answer_stats = score_answers(nq_gold_dict,
                                                        nq_pred_dict)

  if FLAGS.pretty_print:
    print('*' * 20)
    print('LONG ANSWER R@P TABLE:')
    print_r_at_p_table(long_answer_stats)
    print('*' * 20)
    print('SHORT ANSWER R@P TABLE:')
    print_r_at_p_table(short_answer_stats)

    scores = compute_final_f1(long_answer_stats, short_answer_stats)
    print('*' * 20)
    print('METRICS IGNORING SCORES (n={}):'.format(scores['long-answer-n']))
    print('              F1     /  P      /  R')
    print('Long answer  {: >7.2%} / {: >7.2%} / {: >7.2%}'.format(
        scores['long-answer-f1'], scores['long-answer-precision'],
        scores['long-answer-recall']))
    print('Short answer {: >7.2%} / {: >7.2%} / {: >7.2%}'.format(
        scores['short-answer-f1'], scores['short-answer-precision'],
        scores['short-answer-recall']))
  else:
    metrics = get_metrics_with_answer_stats(long_answer_stats,
                                            short_answer_stats)
    print(json.dumps(metrics)) 
Example #18
Source Project: wechat-alfred-workflow   Author: TKkk-iOSer   File: workflow.py    License: MIT License
def dump(cls, obj, file_obj):
        """Serialize object ``obj`` to open JSON file.

        .. versionadded:: 1.8

        :param obj: Python object to serialize
        :type obj: JSON-serializable data structure
        :param file_obj: file handle
        :type file_obj: ``file`` object

        """
        return json.dump(obj, file_obj, indent=2, encoding='utf-8') 
Example #19
Source Project: wechat-alfred-workflow   Author: TKkk-iOSer   File: workflow.py    License: MIT License
def dump(cls, obj, file_obj):
        """Serialize object ``obj`` to open pickle file.

        .. versionadded:: 1.8

        :param obj: Python object to serialize
        :type obj: Python object
        :param file_obj: file handle
        :type file_obj: ``file`` object

        """
        return cPickle.dump(obj, file_obj, protocol=-1) 
Example #20
Source Project: wechat-alfred-workflow   Author: TKkk-iOSer   File: workflow.py    License: MIT License
def dump(cls, obj, file_obj):
        """Serialize object ``obj`` to open pickle file.

        .. versionadded:: 1.8

        :param obj: Python object to serialize
        :type obj: Python object
        :param file_obj: file handle
        :type file_obj: ``file`` object

        """
        return pickle.dump(obj, file_obj, protocol=-1)


# Set up default manager and register built-in serializers 
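Note: protocol=-1 in Examples #19 and #20 is equivalent to passing pickle.HIGHEST_PROTOCOL; any negative protocol number selects the highest protocol available.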
Example #21
Source Project: Collaborative-Learning-for-Weakly-Supervised-Object-Detection   Author: Sunarker   File: train_val.py    License: MIT License
def snapshot(self, iter):
    net = self.net

    if not os.path.exists(self.output_dir):
      os.makedirs(self.output_dir)

    # Store the model snapshot
    filename = cfg.TRAIN.SNAPSHOT_PREFIX + '_iter_{:d}'.format(iter) + '.pth'
    filename = os.path.join(self.output_dir, filename)
    torch.save(self.net.state_dict(), filename)
    print('Wrote snapshot to: {:s}'.format(filename))

    if iter % 10000 == 0:
        shutil.copyfile(filename, filename + '.{:d}_cache'.format(iter))
    
    # Also store some meta information, random state, etc.
    nfilename = cfg.TRAIN.SNAPSHOT_PREFIX + '_iter_{:d}'.format(iter) + '.pkl'
    nfilename = os.path.join(self.output_dir, nfilename)
    # current state of numpy random
    st0 = np.random.get_state()
    # current position in the database
    cur = self.data_layer._cur
    # current shuffled indexes of the database
    perm = self.data_layer._perm
    # current position in the validation database
    cur_val = self.data_layer_val._cur
    # current shuffled indexes of the validation database
    perm_val = self.data_layer_val._perm

    # Dump the meta info
    with open(nfilename, 'wb') as fid:
      pickle.dump(st0, fid, pickle.HIGHEST_PROTOCOL)
      pickle.dump(cur, fid, pickle.HIGHEST_PROTOCOL)
      pickle.dump(perm, fid, pickle.HIGHEST_PROTOCOL)
      pickle.dump(cur_val, fid, pickle.HIGHEST_PROTOCOL)
      pickle.dump(perm_val, fid, pickle.HIGHEST_PROTOCOL)
      pickle.dump(iter, fid, pickle.HIGHEST_PROTOCOL)

    return filename, nfilename 
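Because snapshot() issues six sequential pickle.dump() calls on one file, restoring the metadata takes six pickle.load() calls in the same order. A sketch of a matching reader (the function name and return shape are assumptions based on the dump order above):

import pickle

def load_snapshot_meta(nfilename):
    # Read back the six objects in exactly the order they were dumped.
    with open(nfilename, 'rb') as fid:
        st0 = pickle.load(fid)        # numpy random state
        cur = pickle.load(fid)        # position in the training database
        perm = pickle.load(fid)       # shuffled training indexes
        cur_val = pickle.load(fid)    # position in the validation database
        perm_val = pickle.load(fid)   # shuffled validation indexes
        last_iter = pickle.load(fid)  # iteration count
    return st0, cur, perm, cur_val, perm_val, last_iter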
Example #22
def selective_search_roidb(self):
    """
    Return the database of selective search regions of interest.
    Ground-truth ROIs are also included.

    This function loads/saves from/to a cache file to speed up future calls.
    """
    cache_file = os.path.join(self.cache_path,
                              self.name + '_selective_search_roidb.pkl')

    if os.path.exists(cache_file):
        with open(cache_file, 'rb') as fid:
            roidb = pickle.load(fid)
        print('{} ss roidb loaded from {}'.format(self.name, cache_file))
        return roidb

    if int(self._year) == 2007 or self._image_set != 'test':
        gt_roidb = self.gt_roidb()
        # ss_roidb = self._load_selective_search_roidb(gt_roidb)
        # roidb = datasets.imdb.merge_roidbs(gt_roidb, ss_roidb)
        roidb = self._load_selective_search_roidb(gt_roidb)
    else:
        roidb = self._load_selective_search_roidb(None)
    with open(cache_file, 'wb') as fid:
        pickle.dump(roidb, fid, pickle.HIGHEST_PROTOCOL)
    print('wrote ss roidb to {}'.format(cache_file))

    return roidb 
Example #23
def _eval_discovery(self, output_dir):
    annopath = os.path.join(
        self._devkit_path,
        'VOC' + self._year,
        'Annotations',
        '{:s}.xml')
    imagesetfile = os.path.join(
        self._devkit_path,
        'VOC' + self._year,
        'ImageSets',
        'Main',
        self._image_set + '.txt')
    cachedir = os.path.join(self._devkit_path, 'annotations_dis_cache')
    corlocs = []
    if not os.path.isdir(output_dir):
        os.mkdir(output_dir)
    for i, cls in enumerate(self._classes):
        if cls == '__background__':
            continue
        filename = self._get_voc_results_file_template().format(cls)
        corloc = dis_eval(
            filename, annopath, imagesetfile, cls, cachedir, ovthresh=0.5)
        corlocs += [corloc]
        print('CorLoc for {} = {:.4f}'.format(cls, corloc))
        with open(os.path.join(output_dir, cls + '_corloc.pkl'), 'wb') as f:
            pickle.dump({'corloc': corloc}, f)
    print('Mean CorLoc = {:.4f}'.format(np.mean(corlocs)))
    print('~~~~~~~~')
    print('Results:')
    for corloc in corlocs:
        print('{:.3f}'.format(corloc))
    print('{:.3f}'.format(np.mean(corlocs)))
    print('~~~~~~~~') 
Example #24
Source Project: Collaborative-Learning-for-Weakly-Supervised-Object-Detection   Author: Sunarker   File: coco.py    License: MIT License
def _do_detection_eval(self, res_file, output_dir):
    ann_type = 'bbox'
    coco_dt = self._COCO.loadRes(res_file)
    coco_eval = COCOeval(self._COCO, coco_dt)
    coco_eval.params.useSegm = (ann_type == 'segm')
    coco_eval.evaluate()
    coco_eval.accumulate()
    self._print_detection_eval_metrics(coco_eval)
    eval_file = osp.join(output_dir, 'detection_results.pkl')
    with open(eval_file, 'wb') as fid:
      pickle.dump(coco_eval, fid, pickle.HIGHEST_PROTOCOL)
    print('Wrote COCO eval results to: {}'.format(eval_file)) 
Example #25
Source Project: cherrypy   Author: cherrypy   File: sessions.py    License: BSD 3-Clause "New" or "Revised" License
def _save(self, expiration_time):
        assert self.locked, ('The session was saved without being locked.  '
                             "Check your tools' priority levels.")
        with open(self._get_file_path(), 'wb') as f:
            pickle.dump((self._data, expiration_time), f, self.pickle_protocol)
Example #26
Source Project: gated-graph-transformer-network   Author: hexahedria   File: util.py    License: MIT License
def save_params(params, file):
    """
    Save params into a pickle file
    """
    values = [param.get_value() for param in params]
    pickle.dump(values, file) 
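A matching loader, sketched below under the assumption that params are Theano-style shared variables with set_value() (the function name is invented):

import pickle

def load_params(params, file):
    # Restore values saved by save_params into existing shared variables.
    values = pickle.load(file)
    for param, value in zip(params, values):
        param.set_value(value)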
Example #27
Source Project: gated-graph-transformer-network   Author: hexahedria   File: fix_old_file_list.py    License: MIT License
def main(task_dir, dry_run=False):
    with open(os.path.join(task_dir,'file_list.p'),'rb') as f:
        bucketed = pickle.load(f)
    if dry_run:
        print("Got {} (for example)".format(bucketed[0][0]))
    bucketed = [['./bucket_' + x.split('bucket_')[1] for x in b] for b in bucketed]
    if dry_run:
        print("Converting to {} (for example)".format(bucketed[0][0]))
        print("Will resolve to {} (for example)".format(os.path.normpath(os.path.join(task_dir,bucketed[0][0]))))
    else:
        with open(os.path.join(task_dir,'file_list.p'),'wb') as f:
            pickle.dump(bucketed, f) 
Example #28
Source Project: Traffic_sign_detection_YOLO   Author: AmeyaWagh   File: flow.py    License: MIT License
def _save_ckpt(self, step, loss_profile):
    file = '{}-{}{}'
    model = self.meta['name']

    profile = file.format(model, step, '.profile')
    profile = os.path.join(self.FLAGS.backup, profile)
    with open(profile, 'wb') as profile_ckpt: 
        pickle.dump(loss_profile, profile_ckpt)

    ckpt = file.format(model, step, '')
    ckpt = os.path.join(self.FLAGS.backup, ckpt)
    self.say('Checkpoint at step {}'.format(step))
    self.saver.save(self.sess, ckpt) 
Example #29
Source Project: Neural-LP   Author: fanyangxyz   File: experiment.py    License: MIT License
def get_attentions(self):
        if self.option.query_is_language:
            num_batch = int(np.ceil(1.0*len(self.data.query_for_rules)/self.option.batch_size))
            query_batches = np.array_split(self.data.query_for_rules, num_batch)
        else:   
            #print(self.data.query_for_rules)
            if not self.option.type_check:
                num_batch = int(np.ceil(1.*len(self.data.query_for_rules)/self.option.batch_size))
                query_batches = np.array_split(self.data.query_for_rules, num_batch)       
            else:
                query_batches = [[i] for i in self.data.query_for_rules]

        all_attention_operators = {}
        all_attention_memories = {}

        for queries in query_batches:
            attention_operators, attention_memories \
            = self.learner.get_attentions_given_queries(self.sess, queries)
            
            # Tuple-ize in order to be used as dict keys
            if self.option.query_is_language:
                queries = [tuple(q) for q in queries]

            for i in range(len(queries)):
                all_attention_operators[queries[i]] \
                                        = [[attn[i] 
                                        for attn in attn_step] 
                                        for attn_step in attention_operators]
                all_attention_memories[queries[i]] = \
                                        [attn_step[i, :] 
                                        for attn_step in attention_memories]
        with open(os.path.join(self.option.this_expsdir, "attentions.pckl"), "wb") as f:
            pickle.dump([all_attention_operators, all_attention_memories], f)
               
        msg = self.msg_with_time("Attentions collected.")
        print(msg)
        self.log_file.write(msg + "\n")

        all_queries = [q for batch in query_batches for q in batch]
        return all_attention_operators, all_attention_memories, all_queries 
Example #30
Source Project: Neural-LP   Author: fanyangxyz   File: experiment.py    License: MIT License
def get_rules(self):
        all_attention_operators, all_attention_memories, queries = self.get_attentions()

        all_listed_rules = {}
        all_printed_rules = []
        for i, q in enumerate(queries):
            if not self.option.query_is_language:
                if (i+1) % max(1, len(queries) // 5) == 0:
                    sys.stdout.write("%d/%d\t" % (i, len(queries)))
                    sys.stdout.flush()
            else: 
                # Tuple-ize in order to be used as dict keys
                q = tuple(q)
            all_listed_rules[q] = list_rules(all_attention_operators[q], 
                                             all_attention_memories[q],
                                             self.option.rule_thr,)
            all_printed_rules += print_rules(q, 
                                             all_listed_rules[q], 
                                             self.data.parser,
                                             self.option.query_is_language)

        with open(os.path.join(self.option.this_expsdir, "rules.pckl"), "wb") as f:
            pickle.dump(all_listed_rules, f)
        with open(os.path.join(self.option.this_expsdir, "rules.txt"), "w") as f:
            for line in all_printed_rules:
                f.write(line + "\n")
        msg = self.msg_with_time("\nRules listed and printed.")
        print(msg)
        self.log_file.write(msg + "\n")