Python collections.defaultdict() Examples

The following are 30 code examples of collections.defaultdict(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module collections, or try the search function.
Example #1
Source File:    From news-corpus-builder with MIT License 6 votes vote down vote up
def __init__(self,corpus_dir,datastore_type='file',db_name='corpus.db'):
        """Read links and associated categories for specified articles
        in a text file, separated by a space.

        Args:
            corpus_dir (str): The directory to save the generated corpus.
            datastore_type (Optional[str]): Format to save generated corpus.
                                            Specify either 'file' or 'sqlite'.
            db_name (Optional[str]): Name of database if 'sqlite' is selected.
        """
        # Article extractor; the 'soup' parser class is requested explicitly.
        self.g = Goose({'browser_user_agent': 'Mozilla', 'parser_class': 'soup'})
        self.corpus_dir = corpus_dir
        self.datastore_type = datastore_type
        self.db_name = db_name
        # Counters keyed by stat name; a missing key starts at 0.
        self.stats = defaultdict(int)

        # Path of the SQLite database file; only set for the 'sqlite' backend.
        self.db = None
        if self.datastore_type == 'sqlite':
            self.db = self.corpus_dir + '/' + self.db_name
Example #2
Source File:    From DDPAE-video-prediction with MIT License 6 votes vote down vote up
def sample_latent(self, input, input_latent_mu, input_latent_sigma, pred_latent_mu,
                    pred_latent_sigma, initial_pose_mu, initial_pose_sigma, sample=True):
    """Return latent variables: dictionary containing pose and content.

    Samples the pose trajectory, then crops the objects from the images at the
    sampled input poses and encodes them into the content latent ``z``.

    Args:
        input: images the objects are cropped from.
        input_latent_mu, input_latent_sigma, pred_latent_mu, pred_latent_sigma:
            forwarded to ``self.get_transitions``.
        initial_pose_mu, initial_pose_sigma: parameters of the Normal
            distribution the initial pose is drawn from.
        sample: forwarded to the sampling helpers (presumably selects sampling
            versus using the mean -- confirm).

    Returns:
        defaultdict: ``{'pose': pose, 'content': z}``; any other key reads as
        None.
    """
    latent = defaultdict(lambda: None)

    beta = self.get_transitions(input_latent_mu, input_latent_sigma,
                                pred_latent_mu, pred_latent_sigma, sample)
    pose = self.accumulate_pose(beta)
    # Sample initial pose; the singleton dim 1 lets it broadcast over every
    # step of the pose sequence (dim 1 is presumably time -- confirm).
    initial_pose = self.pyro_sample('initial_pose', dist.Normal, initial_pose_mu,
                                    initial_pose_sigma, sample)
    pose += initial_pose.view(-1, 1, self.n_components, self.pose_latent_size)
    pose = self.constrain_pose(pose)

    # Get input objects: only the first n_frames_input steps along dim 1.
    input_pose = pose[:, :self.n_frames_input, :, :]
    input_obj = self.get_objects(input, input_pose)
    # Encode the sampled objects
    z = self.object_encoder(input_obj)
    z = self.sample_content(z, sample)
    latent.update({'pose': pose, 'content': z})
    return latent
Example #3
Source File:    From iSDX with Apache License 2.0 6 votes vote down vote up
def __init__(self, config, flows_dir, ports_dir, num_timesteps, debug=False):
        """Set up logging and empty containers, then parse the flow and port logs.

        Args:
            config: not referenced in the visible body.
            flows_dir: forwarded to ``self.parse_logs``.
            ports_dir: forwarded to ``self.parse_logs``.
            num_timesteps: forwarded to ``self.parse_logs`` and stored as
                ``self.total_timesteps``.
            debug: guards the ``if debug:`` branch (its body is missing here).
        """
        # NOTE(review): this snippet is damaged and does not compile as shown.
        # The body of `if debug:` is missing (presumably it raises the logger's
        # level -- confirm). The `self.flows = ...` line fuses two assignments,
        # and a call cannot be an assignment target. Restore both from the
        # upstream source before use.
        self.logger = logging.getLogger("LogHistory")
        if debug:

        # Record type for a single log entry.
        self.log_entry = namedtuple("LogEntry", "source destination type")
        # Missing keys start as empty lists.
        self.ports = defaultdict(list)
        self.flows = defaultdict(list) = defaultdict(lambda: defaultdict(lambda: defaultdict(int)))
        self.current_timestep = 0
        self.total_timesteps = num_timesteps

        self.parse_logs(num_timesteps, flows_dir, ports_dir)

Example #4
Source File:    From deep-summarization with MIT License 6 votes vote down vote up
def precook(s, n=4, out=False):
    """Split a sentence into words and count its n-grams.

    Takes a string as input and returns an object that can be given to
    either cook_refs or cook_test. This is optional: cook_refs and cook_test
    can take string arguments as well.

    :param s: sentence to process; tokenized on whitespace with ``str.split``.
    :param n: maximum n-gram order; n-grams of length 1..n are counted.
    :param out: unused; kept for interface compatibility.
    :return: tuple ``(word_count, counts)`` where ``counts`` maps each n-gram
        (a tuple of words) to its number of occurrences.
    """
    words = s.split()
    counts = defaultdict(int)
    # ``range`` works on Python 2 and 3; ``xrange`` is a NameError on Python 3.
    for k in range(1, n + 1):
        for i in range(len(words) - k + 1):
            counts[tuple(words[i:i + k])] += 1
    return (len(words), counts)
Example #5
Source File:    From dynamic-training-with-apache-mxnet-on-aws with Apache License 2.0 6 votes vote down vote up
def __init__(self, annotation_file=None):
        """Constructor of Microsoft COCO helper class for reading and visualizing annotations.

        :param annotation_file (str): location of annotation file. When None,
            the helper starts with empty containers and nothing is loaded.
        """
        # load dataset
        self.dataset, self.anns, self.cats, self.imgs = dict(), dict(), dict(), dict()
        self.imgToAnns, self.catToImgs = defaultdict(list), defaultdict(list)
        if annotation_file is not None:
            print('loading annotations into memory...')
            tic = time.time()
            # Context manager so the file handle is closed promptly.
            with open(annotation_file, 'r') as annotation_fh:
                dataset = json.load(annotation_fh)
            assert isinstance(dataset, dict), 'annotation file format {} not supported'.format(type(dataset))
            print('Done (t={:0.2f}s)'.format(time.time() - tic))
            self.dataset = dataset
Example #6
Source File:    From aegea with Apache License 2.0 6 votes vote down vote up
def lifecycle(args):
    """Manage the lifecycle configuration of the S3 bucket ``args.bucket_name``.

    With ``args.delete``, the bucket's lifecycle configuration is deleted and
    the result is returned. Otherwise a single rule is assembled from the
    transition/expiration options, and the bucket's existing rules are
    iterated. A missing lifecycle configuration is logged as an error.
    """
    if args.delete:
        return resources.s3.BucketLifecycle(args.bucket_name).delete()
    # defaultdict(list) lets rule["Transitions"] be appended to directly; that
    # key only appears once a transition is added.
    rule = defaultdict(list, Prefix=args.prefix, Status="Enabled")
    if args.transition_to_infrequent_access is not None:
        rule["Transitions"].append(dict(StorageClass="STANDARD_IA", Days=args.transition_to_infrequent_access))
    if args.transition_to_glacier is not None:
        rule["Transitions"].append(dict(StorageClass="GLACIER", Days=args.transition_to_glacier))
    if args.expire is not None:
        rule["Expiration"] = dict(Days=args.expire)
    if args.abort_incomplete_multipart_upload is not None:
        rule["AbortIncompleteMultipartUpload"] = dict(DaysAfterInitiation=args.abort_incomplete_multipart_upload)
    # More keys than the two initial ones (Prefix, Status) means at least one
    # option was given.
    # NOTE(review): damaged snippet; the block does not compile as shown. The
    # following are missing:
    #   - the body of this `if` (presumably the call that applies `rule`),
    #   - the `try:` header,
    #   - the body of the `for` loop.
    # Note also that the loop variable rebinds `rule`.
    if len(rule) > 2:
        for rule in resources.s3.BucketLifecycle(args.bucket_name).rules:
    except ClientError as e:
        # Presumably re-raises unless the code is NoSuchLifecycleConfiguration -- confirm.
        expect_error_codes(e, "NoSuchLifecycleConfiguration")
        logger.error("No lifecycle configuration for bucket %s", args.bucket_name) 
Example #7
Source File:    From EDeN with MIT License 6 votes vote down vote up
def compute_matching_neighborhoods_fraction(GA, GB, pairings):
    """Return the fraction of GA's nodes whose neighbourhood is fully matched.

    ``pairings[i]`` is the GB node paired with GA node ``i``. The mapping is
    built with ``enumerate``, so GA nodes are presumably
    ``0..len(pairings)-1`` -- confirm. A GA node counts when it has at least
    one recorded matching edge and every one of its GA neighbours appears in
    its ``matching_edges`` list.

    Raises:
        ZeroDivisionError: if GA has no nodes.
    """
    count = 0
    matches = dict([(i, j) for i, j in enumerate(pairings)])
    # Presumably: node -> neighbours whose edge survives the pairing. The code
    # that fills it is missing below.
    matching_edges = defaultdict(list)
    for i, j in GA.edges():
        ii = matches[i]
        jj = matches[j]
        # NOTE(review): damaged snippet; the body of this `if` (presumably
        # recording the preserved edge in `matching_edges`) is missing, so the
        # block does not compile as shown.
        if (ii, jj) in GB.edges():
    for u in GA.nodes():
        # .get avoids inserting empty entries into the defaultdict.
        if matching_edges.get(u, False):
            neighbors = nx.neighbors(GA, u)
            matches_neighborhood = True
            for v in neighbors:
                if v not in matching_edges[u]:
                    matches_neighborhood = False
            if matches_neighborhood:
                count += 1
    return float(count) / len(GA.nodes()) 
Example #8
Source File:    From EDeN with MIT License 6 votes vote down vote up
def extract_sequence_and_score(graph=None):
    """Rebuild the label sequence and per-position importance from a graph.

    Every node must carry a ``'position'`` attribute, and nodes sharing a
    position must have identical ``'label'`` attributes. Positions are used
    directly as list indices, so they are expected to be the integers
    ``0..P-1``, where P is the number of distinct positions.

    Args:
        graph: graph exposing ``nodes()`` and per-node attribute dicts via
            ``graph.node[u]``.

    Returns:
        tuple: ``(seq, score)``. ``seq[p]`` is the label at position p
        (``'N/A'`` if unlabelled). ``score[p]`` is the mean ``'importance'``
        (default 0) of the nodes at p.

    Raises:
        Exception: if a node has no ``'position'`` attribute.
    """
    # make dict with positions as keys and lists of ids as values
    pos_to_ids = defaultdict(list)
    for u in graph.nodes():
        attributes = graph.node[u]
        if 'position' not in attributes:
            raise Exception('Missing "position" attribute in node:%s %s' % (u, attributes))
        # This lookup previously sat after the raise (unreachable), leaving
        # ``pos`` undefined for every node that does have a position.
        pos = attributes['position']
        # accumulate all node ids
        pos_to_ids[pos].append(u)

    # extract sequence of labels and importances
    seq = [None] * len(pos_to_ids)
    score = [0] * len(pos_to_ids)
    for pos in sorted(pos_to_ids):
        ids = pos_to_ids[pos]
        labels = [graph.node[u].get('label', 'N/A') for u in ids]
        # check that all labels for the same position are identical
        assert all(label == labels[0] for label in labels), \
            'ERROR: non identical labels referring to same position: %s  %s' % (pos, labels)
        seq[pos] = labels[0]
        # average all importance score for the same position
        importances = [graph.node[u].get('importance', 0) for u in ids]
        score[pos] = np.mean(importances)
    return seq, score
Example #9
Source File:    From goodtables-py with MIT License 6 votes vote down vote up
def _create_unique_fields_cache(cells):
    """Build empty value caches for unique-constrained and primary-key columns.

    Args:
        cells: iterable of cell dicts. Each may carry a ``'field'`` object
            (with ``descriptor`` and ``constraints`` mappings) and a
            ``'column-number'``.

    Returns:
        dict: maps a tuple of column numbers to an empty ``defaultdict(list)``.
        Each column with a ``unique`` constraint gets its own one-element key.
        All primary-key columns together share one composite key.
    """
    primary_key_column_numbers = []
    cache = {}

    # Unique
    for cell in cells:
        field = cell.get('field')
        if field is None:
            continue
        column_number = cell.get('column-number')
        if field.descriptor.get('primaryKey'):
            # Previously this branch had no body, so the composite
            # primary-key entry below could never be created.
            primary_key_column_numbers.append(column_number)
        if field.constraints.get('unique'):
            cache[(column_number,)] = defaultdict(list)

    # Primary key
    if primary_key_column_numbers:
        cache[tuple(primary_key_column_numbers)] = defaultdict(list)

    return cache
Example #10
Source File:    From EDeN with MIT License 6 votes vote down vote up
def _add_sparse_vector_labes(self, graph, vertex_v, node_feature_list):
        """Cross a vertex's hashed graph features with its sparse vector label.

        If the vertex's attribute under ``self.key_svec`` is truthy, each
        feature value in ``node_feature_list`` is combined with each component
        of that sparse vector. The (feature, component index) pair is hashed
        with ``fast_hash_2(..., self.bitmask)``, and ``value * component`` is
        accumulated under that hash. Otherwise ``node_feature_list`` itself is
        returned unchanged.

        Returns:
            Either the original ``node_feature_list`` or a nested
            ``defaultdict(lambda: defaultdict(float))`` keyed first by the
            radius/distance key, then by the hashed feature.
        """
        sparse_vec = graph.nodes[vertex_v].get(self.key_svec, None)
        if not sparse_vec:
            # No sparse vector label on this vertex: nothing to combine.
            return node_feature_list

        combined = defaultdict(lambda: defaultdict(float))
        for rd_key in node_feature_list:
            features = node_feature_list[rd_key]
            for feature in features:
                weight = features[feature]
                for idx in sparse_vec:
                    component = sparse_vec[idx]
                    hashed = fast_hash_2(feature, idx, self.bitmask)
                    combined[rd_key][hashed] += weight * component
        return combined
Example #11
Source File:    From mmdetection with Apache License 2.0 6 votes vote down vote up
def load_json_logs(json_logs):
    """Load JSON-lines training logs into per-epoch metric lists.

    Each file is read line by line, and each non-blank line is a JSON object.
    Objects without an ``epoch`` field are skipped. For the rest, every
    remaining key/value pair is appended to that epoch's list for the key, so
    each metric collects the values of all logged iterations of the epoch.

    Args:
        json_logs (list[str]): paths of the log files.

    Returns:
        list[dict]: one dict per input file, in the same order, mapping
        ``epoch -> defaultdict(list)`` of ``metric name -> [values...]``.
    """
    log_dicts = [dict() for _ in json_logs]
    for json_log, log_dict in zip(json_logs, log_dicts):
        with open(json_log, 'r') as log_file:
            for line in log_file:
                line = line.strip()
                # Tolerate blank lines; json.loads('') would raise.
                if not line:
                    continue
                log = json.loads(line)
                # skip lines without `epoch` field
                if 'epoch' not in log:
                    continue
                epoch = log.pop('epoch')
                if epoch not in log_dict:
                    log_dict[epoch] = defaultdict(list)
                for k, v in log.items():
                    log_dict[epoch][k].append(v)
    return log_dicts
Example #12
Source File:    From svviz with MIT License 6 votes vote down vote up
def __init__(self):
        """Initialise empty state: samples, sources, annotation sets, per-allele tracks."""
        self.args = None
        self.alignDistance = 0
        # Insertion-ordered containers.
        self.samples = collections.OrderedDict()
        self.genome = None
        self.sources = {}
        self.annotationSets = collections.OrderedDict()

        # for storing axes, annotations, etc, by allele
        # A missing allele key is created on first access as an empty OrderedDict.
        self.alleleTracks = collections.defaultdict(collections.OrderedDict)
        self.trackCompositor = None

        # NOTE(review): damaged line; two assignments are fused and a dict
        # literal cannot be an assignment target, so this does not compile as
        # shown. It is presumably `self.dotplots = {}` plus one more attribute
        # initialised to `{}` -- restore from upstream.
        self.dotplots = {} = {}

Example #13
Source File:    From query-exporter with GNU General Public License v3.0 6 votes vote down vote up
def metric_values(metric, by_labels=()):
    """Return values for the metric.

    Only samples whose suffix matches the metric type are considered: ``""``
    for gauges and ``"_total"`` for counters.

    Args:
        metric: object exposing ``_type`` and ``_samples()``, where each sample
            is ``(suffix, labels, value)`` (presumably a prometheus_client
            metric -- confirm).
        by_labels: label names to group by.

    Returns:
        If ``by_labels`` is given, a mapping from the tuple of those label
        values to the sample value. Otherwise, the list of matching sample
        values in the order they were produced.

    Raises:
        ValueError: if the metric type is neither "gauge" nor "counter".
    """
    if metric._type == "gauge":
        suffix = ""
    elif metric._type == "counter":
        suffix = "_total"
    else:
        # Previously fell through and failed later with UnboundLocalError.
        raise ValueError("unsupported metric type: %r" % (metric._type,))

    values = defaultdict(list)
    for sample_suffix, labels, value in metric._samples():
        if sample_suffix != suffix:
            continue
        if by_labels:
            label_values = tuple(labels[label] for label in by_labels)
            values[label_values] = value
        else:
            # This branch was missing, so the ungrouped result was always [].
            values[suffix].append(value)

    return values if by_labels else values[suffix]
Example #14
Source File:    From neural-fingerprinting with BSD 3-Clause "New" or "Revised" License 6 votes vote down vote up
def _save_sorted_results(self, run_stats, scores, image_count, filename):
    """Saves sorted (by score) results of the evaluation.

    Writes a CSV with a header row and one row per submission, ordered by
    descending score.

    Args:
      run_stats: dictionary with runtime statistics for submissions,
        can be generated by WorkPiecesBase.compute_work_statistics
      scores: dictionary mapping submission ids to scores
      image_count: dictionary with number of images processed by submission
      filename: output filename
    """
    with open(filename, 'w') as f:
      writer = csv.writer(f)
      # Header row.
      writer.writerow(['SubmissionID', 'ExternalTeamId', 'Score',
                       'MedianTime', 'ImageCount'])
      get_second = lambda x: x[1]
      # Highest score first.
      for s_id, score in sorted(iteritems(scores),
                                key=get_second, reverse=True):
        external_id = self.submissions.get_external_id(s_id)
        # Submissions without runtime stats get a defaultdict that yields NaN
        # for any statistic looked up.
        stat = run_stats.get(
            s_id, collections.defaultdict(lambda: float('NaN')))
        # NOTE(review): this call is truncated in this snippet (the remaining
        # row items and closing brackets are missing), so the block does not
        # compile as shown.
        writer.writerow([s_id, external_id, score,
Example #15
Source File:    From neural-fingerprinting with BSD 3-Clause "New" or "Revised" License 6 votes vote down vote up
def __init__(self, tau=0, name="", ds_name=""):
        """Initialise empty bookkeeping for one evaluation run.

        Args:
            tau: fingerprint-match threshold; an example is "legal" when there
                is a fingerprint match below it.
            name: stored as ``self.name``.
            ds_name: stored as ``self.ds_name``.
        """
        # The damaged snippet read `: = name`; the `name` parameter is
        # otherwise unused, matching the parallel attribute assignments below.
        self.name = name
        self.ds_name = ds_name
        self.tau = tau

        self.ids = set()
        self.ids_correct = set()
        self.ids_correct_fp = set()
        self.ids_agree = set()

        # Legal = there is a fingerprint match below threshold tau
        self.ids_legal = set()

        # `int` gives the same 0 default as `lambda: 0` but keeps the
        # counters picklable.
        self.counts = defaultdict(int)
        self.counts_legal = defaultdict(int)
        self.counts_correct = defaultdict(int)

        # Total number of examples
        self.i = 0
Example #16
Source File:    From deep-learning-note with MIT License 5 votes vote down vote up
def bleu(pred_tokens, label_tokens, k):
    """Compute the BLEU score of a predicted token sequence against a label.

    Applies the brevity penalty ``exp(min(0, 1 - len_label / len_pred))``.
    Then, for each n-gram order ``n`` in ``1..k``, multiplies in the clipped
    n-gram precision raised to the weight ``0.5 ** n``.

    Args:
        pred_tokens: sequence of predicted tokens.
        label_tokens: sequence of reference tokens.
        k: maximum n-gram order.

    Returns:
        float: the score. It is 0.0 when the prediction is empty or too short
        to contain an n-gram of some order ``n <= k``.
    """
    len_pred, len_label = len(pred_tokens), len(label_tokens)
    if len_pred == 0:
        return 0.0
    score = math.exp(min(0, 1 - len_label / len_pred))
    for n in range(1, k + 1):
        num_pred_ngrams = len_pred - n + 1
        if num_pred_ngrams <= 0:
            # No n-grams of this order, so precision is zero. Previously this
            # raised ZeroDivisionError when the count was exactly zero.
            return 0.0
        # Tuple keys keep token boundaries; ''.join made ('ab', 'c') and
        # ('a', 'bc') the same key.
        label_subs = collections.Counter(
            tuple(label_tokens[i:i + n]) for i in range(len_label - n + 1))
        num_matches = 0
        for i in range(num_pred_ngrams):
            ngram = tuple(pred_tokens[i:i + n])
            # Clip: each reference n-gram can be matched at most as often as it occurs.
            if label_subs[ngram] > 0:
                num_matches += 1
                label_subs[ngram] -= 1
        score *= math.pow(num_matches / num_pred_ngrams, math.pow(0.5, n))
    return score
Example #17
Source File:    From friendly-telegram with GNU Affero General Public License v3.0 5 votes vote down vote up
def __init__(self, **kwargs):
        """Initialise web-server state and configure the aiohttp application.

        Extra keyword arguments are forwarded to ``super().__init__``.
        """
        self.runner = None
        self.port = None
        self.running = asyncio.Event()
        self.ready = asyncio.Event()
        self.client_data = {}
        # NOTE(review): damaged line; two statements are fused here and it
        # does not compile as shown. The first creates per-handler rate-limit
        # storage: a defaultdict(dict) that the `ratelimit` middleware factory
        # reaches through `lambda f: self._ratelimit_data[f]`. The second
        # builds the `web.Application`, but its assignment target is missing.
        self._ratelimit_data = collections.defaultdict(dict) = web.Application(middlewares=[ratelimit(lambda f: self._ratelimit_data[f])])
        # NOTE(review): the first argument of `aiohttp_jinja2.setup` (the
        # application) is missing. The next line also fuses the end of this
        # call with an item assignment whose target is missing.
        aiohttp_jinja2.setup(, filters={"getdoc": inspect.getdoc, "ascii": ascii},
                             loader=jinja2.FileSystemLoader("web-resources"))["static_root_url"] = "/static"
        # NOTE(review): damaged line; `super().__init__(**kwargs)` is fused
        # with the tail of a static-route registration ("/static/" served from
        # "web-resources/static") whose call target is missing.
        super().__init__(**kwargs)"/static/", "web-resources/static") 
Example #18
Source File:    From goodtables-py with MIT License 5 votes vote down vote up
def _get_foreign_keys_values(schema, relations):
    """Index the rows of every referenced resource by their foreign-key values.

    Rather than keeping whole reference tables around, this builds a lookup
    for each foreign key in ``schema``. The lookup maps the tuple of
    referenced field values to the first row carrying those values. Several
    foreign keys may point at the same resource, so the index is keyed by
    resource name and then by the tuple of referenced field names:

        {resource: {(field, ...): {(value, ...): first_matching_row}}}

    Args:
        schema: object with a ``foreign_keys`` list; a falsy schema yields no keys.
        relations: mapping of resource name to an iterable of row mappings;
            ``None`` is treated as empty.

    Returns:
        defaultdict(dict): the nested index described above.
    """
    index = defaultdict(dict)
    if not schema:
        return index
    for foreign_key in schema.foreign_keys:
        reference = foreign_key['reference']
        resource = reference['resource']
        ref_fields = reference['fields']
        by_values = {}
        index[resource][tuple(ref_fields)] = by_values
        for row in relations[resource] or []:
            values = tuple([row[field_name] for field_name in ref_fields])
            # Keep the first row seen for each value tuple.
            by_values.setdefault(values, row)
    return index
Example #19
Source File:    From friendly-telegram with GNU Affero General Public License v3.0 5 votes vote down vote up
def ratelimit(get_storage):
    """Return a middleware that delays or rejects rapid repeat requests.

    ``get_storage(handler)`` supplies a mutable mapping in which
    per-remote-address counters for that handler are kept. When a request
    arrives within 30 seconds of the previous one from the same address, that
    address's delay grows Fibonacci-style (in tenths of a second) and the
    middleware sleeps for it. Once the delay would exceed 5 seconds, the
    request is answered with HTTP 429.
    """
    async def ratelimit_middleware(request, handler):
        storage = get_storage(handler)
        # NOTE(review): `hasattr` checks an attribute, not a key. For a plain
        # dict it is always False, so the setdefault calls run on every
        # request. That is harmless, because setdefault never overwrites, but
        # the check was presumably meant to be `"ratelimit" not in storage`.
        if not hasattr(storage, "_ratelimit"):
            storage.setdefault("ratelimit", collections.defaultdict(lambda: 0))
            storage.setdefault("ratelimit_last", collections.defaultdict(lambda: 1))
            storage.setdefault("last_request", collections.defaultdict(lambda: 0))
        if storage["last_request"][request.remote] > time.time() - 30:
            # Maybe ratelimit, was requested within 30 seconds
            last = storage["ratelimit_last"][request.remote]
            # Fibonacci step: new delay = current delay + previous delay.
            storage["ratelimit_last"][request.remote] = storage["ratelimit"][request.remote]
            storage["ratelimit"][request.remote] += last
            if storage["ratelimit"][request.remote] > 50:
                # If they have to wait more than 5 seconds (10 requests), kill em.
                return web.Response(status=429)
            # The stored delay is in tenths of a second.
            await asyncio.sleep(storage["ratelimit"][request.remote] / 10)
        # NOTE(review): damaged snippet; the block does not compile as shown.
        # The lines introducing the two `del` statements are missing
        # (presumably an `else:` for requests older than 30 s, then a `try:`),
        # and so is the body of `except KeyError:`.
                del storage["ratelimit"][request.remote]
                del storage["ratelimit_last"][request.remote]
            except KeyError:
        storage["last_request"][request.remote] = time.time()
        return await handler(request)
    return ratelimit_middleware 
Example #20
Source File:    From dynamic-training-with-apache-mxnet-on-aws with Apache License 2.0 5 votes vote down vote up
def createIndex(self):
        """Build lookup tables from ``self.dataset`` and store them on ``self``.

        ``anns``, ``imgs`` and ``cats`` map each entry's ``'id'`` to the entry
        itself. ``imgToAnns`` and ``catToImgs`` are ``defaultdict(list)``
        reverse indexes.
        """
        # create index
        print('creating index...')
        anns, cats, imgs = {}, {}, {}
        imgToAnns,catToImgs = defaultdict(list),defaultdict(list)
        if 'annotations' in self.dataset:
            for ann in self.dataset['annotations']:
                anns[ann['id']] = ann

        if 'images' in self.dataset:
            for img in self.dataset['images']:
                imgs[img['id']] = img

        if 'categories' in self.dataset:
            for cat in self.dataset['categories']:
                cats[cat['id']] = cat

        if 'annotations' in self.dataset and 'categories' in self.dataset:
            # NOTE(review): damaged snippet; this loop body is missing, so the
            # block does not compile as shown. Nothing in the visible code
            # fills `imgToAnns` either. Presumably annotations are appended per
            # image and image ids per category -- restore from upstream.
            for ann in self.dataset['annotations']:

        print('index created!')

        # create class members
        self.anns = anns
        self.imgToAnns = imgToAnns
        self.catToImgs = catToImgs
        self.imgs = imgs
        self.cats = cats 
Example #21
Source File:    From fuku-ml with MIT License 5 votes vote down vote up
def classify_with_missing_data(self, x, tree):
        """Classify ``x`` with a decision tree, tolerating missing feature values.

        At each internal node, ``x[tree.col]`` is compared with ``tree.value``.
        The comparison is numeric (``>=``) when the feature converts to float,
        and exact equality otherwise. When the value is missing (the string
        ``'None'`` or ``None``), both subtrees are evaluated and their class
        counts are merged, each weighted by that subtree's share of the total
        count.

        Args:
            x: sequence of feature values indexed by ``tree.col``.
            tree: node with ``is_leaf``, ``each_class_counts``, ``col``,
                ``value``, ``true_branch`` and ``false_branch``.

        Returns:
            Mapping of class label to (possibly weighted) count.
        """
        if tree.is_leaf:
            # leaf
            return tree.each_class_counts

        v = x[tree.col]

        value_is_float = True
        try:
            v = float(v)
        except (TypeError, ValueError):
            # ValueError: non-numeric string (including 'None').
            # TypeError: e.g. an actual None, which used to crash here.
            value_is_float = False

        if v is None or v == 'None':
            # Missing value: follow both branches and blend their counts.
            true_branch = self.classify_with_missing_data(x, tree.true_branch)
            false_branch = self.classify_with_missing_data(x, tree.false_branch)
            true_branch_count = sum(true_branch.values())
            false_branch_count = sum(false_branch.values())
            total = true_branch_count + false_branch_count
            true_branch_weight = float(true_branch_count) / total
            false_branch_weight = float(false_branch_count) / total
            each_class_counts = collections.defaultdict(int)
            for label, count in true_branch.items():
                each_class_counts[label] += count * true_branch_weight
            for label, count in false_branch.items():
                each_class_counts[label] += count * false_branch_weight
            return dict(each_class_counts)

        if value_is_float:
            # Numeric split: true branch when v >= threshold.
            branch = tree.true_branch if v >= float(tree.value) else tree.false_branch
        else:
            # Categorical split: true branch on exact match.
            branch = tree.true_branch if v == tree.value else tree.false_branch
        return self.classify_with_missing_data(x, branch)
Example #22
Source File:    From dynamic-training-with-apache-mxnet-on-aws with Apache License 2.0 5 votes vote down vote up
def args_wrapper(*args):
    """Generate the combined callback arguments for a set of callback objects.

    Callback objects like PandasLogger(), LiveLearningCurve() get passed in.
    This assembles all their callback arguments: each object's
    ``callback_args()`` mapping is merged so that every keyword maps to the
    list of values contributed by the callbacks, in argument order.

    Returns:
        dict: keyword -> list of values.
    """
    out = defaultdict(list)
    for callback in args:
        for k, v in callback.callback_args().items():
            out[k].append(v)
    # Plain dict so callers don't get silent empty-list defaults.
    return dict(out)
Example #23
Source File:    From dynamic-training-with-apache-mxnet-on-aws with Apache License 2.0 5 votes vote down vote up
def test_scope():
    """Check that ``cond`` subgraph names are counted in ``AttrScope._subgraph_names``.

    Two distinct HybridBlocks each run one ``F.contrib.cond``. After the
    global counter is reset, each of the three generated subgraph names
    (``*_pred``, ``*_then``, ``*_else``) is expected to have been seen twice.
    """
    class TestBlock1(gluon.HybridBlock):
        def __init__(self, prefix=None, params=None):
            super(TestBlock1, self).__init__(prefix=prefix, params=params)
        def hybrid_forward(self, F, data):
            # NOTE(review): damaged snippet; this call is never closed (its
            # final argument line and `)` are missing), so the block does not
            # compile as shown. The assertions below expect names prefixed
            # `my_cond_`, presumably passed via a `name=` argument -- confirm.
            (new_data, ) = F.contrib.cond(
                data > 0.5,
                then_func=lambda: data * 2,
                else_func=lambda: data * 3,
            return new_data
    class TestBlock2(gluon.HybridBlock):
        def __init__(self, prefix=None, params=None):
            super(TestBlock2, self).__init__(prefix=prefix, params=params)
        def hybrid_forward(self, F, data):
            # NOTE(review): same truncated call as in TestBlock1.
            (new_data, ) = F.contrib.cond(
                data > 0.5,
                then_func=lambda: data * 2,
                else_func=lambda: data * 3,
            return new_data
    # Reset the global name counter so the counts below start from zero.
    AttrScope._subgraph_names = defaultdict(int)
    data = mx.nd.normal(loc=0, scale=1, shape=(1, ))
    block1 = TestBlock1()
    _ = block1(data)
    block2 = TestBlock2()
    _ = block2(data)
    # Three distinct names, each produced once per block.
    assert len(AttrScope._subgraph_names) == 3
    assert AttrScope._subgraph_names['my_cond_else'] == 2
    assert AttrScope._subgraph_names['my_cond_pred'] == 2
    assert AttrScope._subgraph_names['my_cond_then'] == 2 
Example #24
Source File:    From tmhmm.py with MIT License 5 votes vote down vote up
def predict(sequence, model_or_filelike='TMHMM2.0.model', compute_posterior=True):
    """Predict the most likely state path and, optionally, group posteriors.

    Args:
        sequence: the sequence to label.
        model_or_filelike: an already-parsed model tuple, or anything ``parse``
            accepts (the default looks like a model file name).
        compute_posterior: also compute per-position group probabilities.

    Returns:
        ``path`` from Viterbi decoding, or ``(path, table)`` when
        ``compute_posterior`` is true. ``table[i, k]`` is the posterior mass of
        group ``GROUP_NAMES[k]`` at position ``i``, with each row normalised
        to sum to 1.
    """
    if isinstance(model_or_filelike, tuple):
        # Already parsed: use it directly.
        model = model_or_filelike
    else:
        # Previously this parse ran unconditionally, discarding the shortcut
        # above and handing a model tuple to ``parse``.
        _, model = parse(model_or_filelike)

    _, path = viterbi(sequence, *model)
    if compute_posterior:
        forward_table, constants = forward(sequence, *model)
        backward_table = backward(sequence, constants, *model)

        # Element-wise product, indexed [position, state].
        posterior = forward_table * backward_table
        _, _, _, char_map, label_map, name_map = model

        observations = len(sequence)
        states = len(name_map)

        # Sum state posteriors into their (lower-cased) label groups.
        table = np.zeros(shape=(observations, 3))
        for i in range(observations):
            group_probs = defaultdict(float)
            for j in range(states):
                group = label_map[j].lower()
                group_probs[group] += posterior[i, j]

            for k, group in enumerate(GROUP_NAMES):
                table[i, k] = group_probs[group]
        # Normalise each row so it sums to 1.
        return path, table/table.sum(axis=1, keepdims=True)
    return path
Example #25
Source File:    From DOTA_models with Apache License 2.0 5 votes vote down vote up
def main(_):
  """Build and write a vocabulary with term frequencies.

  Term frequencies and document counts are accumulated over all documents.
  Terms whose document count does not exceed ``FLAGS.doc_count_threshold``
  are dropped. The rest are sorted by frequency and truncated to
  ``MAX_VOCAB_SIZE``, an EOS token is appended, and the result is written to
  ``FLAGS.output_dir``.
  """
  vocab_freqs = defaultdict(int)
  doc_counts = defaultdict(int)

  # Fill vocabulary frequencies map and document counts map
  # NOTE(review): damaged snippet; the arguments and closing `):` of this
  # `documents(` call are missing, so the block does not compile as shown.
  for doc in document_generators.documents(
    fill_vocab_from_doc(doc, vocab_freqs, doc_counts)

  # Filter out low-occurring terms
  # NOTE(review): `dict.iteritems` exists only on Python 2.
  vocab_freqs = dict((term, freq) for term, freq in vocab_freqs.iteritems()
                     if doc_counts[term] > FLAGS.doc_count_threshold)

  # Sort by frequency
  ordered_vocab_freqs = data_utils.sort_vocab_by_frequency(vocab_freqs)

  # Limit vocab size
  ordered_vocab_freqs = ordered_vocab_freqs[:MAX_VOCAB_SIZE]

  # Add EOS token
  ordered_vocab_freqs.append((data_utils.EOS_TOKEN, 1))

  # Write
  data_utils.write_vocab_and_frequency(ordered_vocab_freqs, FLAGS.output_dir) 
Example #26
Source File:    From soccer-matlab with BSD 2-Clause "Simplified" License 5 votes vote down vote up
def variable_summaries(vars_, groups=None, scope='weights'):
  """Create histogram summaries for the provided variables.

  Summaries can be grouped via regexes matching variables names.

  Args:
    vars_: List of variables to summarize.
    groups: Mapping of name to regex for grouping summaries.
    scope: Name scope for this operation.

  Returns:
    Summary tensor.
  """
  groups = groups or {r'all': r'.*'}
  grouped = collections.defaultdict(list)
  for var in vars_:
    for name, pattern in groups.items():
      # NOTE(review): damaged snippet; the block does not compile as shown.
      # The arguments of `re.match(` and `re.sub(` are cut off (presumably
      # the variable's name), and the line that adds `var` to `grouped` is
      # missing.
      if re.match(pattern,
        name = re.sub(pattern, name,
  for name in groups:
    if name not in grouped:
      tf.logging.warn("No variables matching '{}' group.".format(name))
  summaries = []
  for name, vars_ in grouped.items():
    # Flatten each variable and concatenate them into one histogram per group.
    vars_ = [tf.reshape(var, [-1]) for var in vars_]
    vars_ = tf.concat(vars_, 0)
    summaries.append(tf.summary.histogram(scope + '/' + name, vars_))
  return tf.summary.merge(summaries) 
Example #27
Source File:    From QCElemental with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
def order_molecular_formula(formula: str, order: str = "alphabetical") -> str:
    """Reorders a molecular formula.

    Parameters
    ----------
    formula: str
        A molecular formula
    order: str, optional
        Sorting order of the formula. Valid choices are "alphabetical" and "hill".

    Returns
    -------
    str
        The molecular formula.

    Raises
    ------
    ValueError
        If ``formula`` has characters before its first uppercase letter.
    """
    matches = re.findall(r"[A-Z][^A-Z]*", formula)
    if "".join(matches) != formula:
        raise ValueError(f"{formula} is not a valid molecular formula.")
    count = collections.defaultdict(int)
    for match in matches:
        # Split a token such as "He12" into its symbol ("He") and optional count ("12").
        match_n = re.match(r"(\D+)(\d*)", match)
        assert match_n
        symbol, digits = match_n.group(1), match_n.group(2)
        n = int(digits) if digits else 1
        count[symbol] += n
    # Expand into one symbol per atom, e.g. {"H": 2, "O": 1} -> ["H", "H", "O"].
    symbols = [k for k, v in count.items() for _ in range(v)]
    return molecular_formula_from_symbols(symbols=symbols, order=order)
Example #28
Source File:    From fine-lm with MIT License 5 votes vote down vote up
def create_slots(self, var):
    """Create the factorized Adam accumulators for diet variables."""
    # NOTE(review): damaged snippet; the block does not compile as shown. The
    # right-hand side of `name =` is missing, and every `tf.get_variable(`
    # call below is cut off after its first line(s).
    params = self.params
    shape = var.get_shape().as_list()

    # Lazily create the slot store on params: name -> dict of slot variables.
    if not hasattr(params, "slots"):
      params.slots = defaultdict(dict)

    name =
    slots = params.slots[name]

    # 2-D variables with the factored option get a row accumulator (vr,
    # shape [rows, 1]) and a column accumulator (vc, shape [1, cols]).
    if params.factored_second_moment_accumulator and len(shape) == 2:
      slots["adam_vr"] = tf.get_variable(
          name + "_adam_vr", [shape[0], 1],
      slots["adam_vc"] = tf.get_variable(
          name + "_adam_vc", [1, shape[1]],
      # NOTE(review): the visible indentation places adam_v in the same branch
      # as the factored accumulators; an `else:` is presumably missing -- confirm.
      slots["adam_v"] = tf.get_variable(
          name + "_adam_v",
    # Only allocated when beta1 is non-zero.
    if params.beta1 != 0.0:
      slots["adam_m"] = tf.get_variable(
          name + "_adam_m",
Example #29
Source File:    From QCElemental with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
def __init__(self):
        """Load the NIST 2011 atomic-weight tables and build element/isotope lookups."""

        from . import data

        # Of length number of elements
        self.Z = data.nist_2011_atomic_weights["Z"]
        # NOTE(review): damaged line. As written it is a chained assignment
        # that sets BOTH `self.E` and the table's "E" entry to the "name"
        # column. Almost certainly two lines were fused here (presumably
        # `self.E = ...["E"]` and a separate attribute holding ...["name"]) --
        # restore from upstream.
        self.E = data.nist_2011_atomic_weights["E"] = data.nist_2011_atomic_weights["name"]

        # E -> Z and Z -> E (the latter in table order).
        self._el2z = dict(zip(self.E, self.Z))
        self._z2el = collections.OrderedDict(zip(self.Z, self.E))
        # NOTE(review): the first `zip` argument below is missing, and the
        # second `dict(zip(` call is never closed; neither line compiles as shown.
        self._element2el = dict(zip(, self.E))
        self._el2element = dict(zip(self.E,

        # Of length number of isotopes
        self._EE = data.nist_2011_atomic_weights["_EE"]
        self.EA = data.nist_2011_atomic_weights["EA"]
        self.A = data.nist_2011_atomic_weights["A"]
        self.mass = data.nist_2011_atomic_weights["mass"]

        # Isotope label (EA) -> mass / element (_EE) / A.
        self._eliso2mass = dict(zip(self.EA, self.mass))
        self._eliso2el = dict(zip(self.EA, self._EE))
        self._eliso2a = dict(zip(self.EA, self.A))
        # Element (_EE) -> {A: mass as float}; inner dicts are created on demand.
        self._el2a2mass = collections.defaultdict(dict)
        for EE, m, A in zip(self._EE, self.mass, self.A):
            self._el2a2mass[EE][A] = float(m) 
Example #30
Source File:    From fine-lm with MIT License 5 votes vote down vote up
def make_diet_var_getter(params):
  """Create a custom variable getter for diet variables according to params."""

  def diet_var_initializer(shape, dtype, partition_info=None):
    """Initializer for a diet variable."""
    # dtype and partition_info are part of the initializer signature but unused.
    del dtype
    del partition_info

    # NOTE(review): `out_deps` is bound but never used in this snippet.
    # Confirm whether `ret` should be registered with it; lines may be
    # missing here.
    with common_layers.fn_device_dependency("diet_init") as out_deps:
      # Uniform on [-sqrt(3), sqrt(3)) has unit variance.
      float_range = math.sqrt(3)
      ret = tf.random_uniform(shape, -float_range, float_range)
      if params.quantize:
        ret = _quantize(ret, params, randomize=False)
      return ret

  def diet_var_getter(getter, **kwargs):
    """Get diet variable and return it dequantized."""
    # When quantizing, the underlying variable is requested as float16.
    if params.quantize:
      kwargs["dtype"] = tf.float16
    kwargs["initializer"] = diet_var_initializer
    kwargs["trainable"] = False

    base_var = getter(**kwargs)

    dequantized = _dequantize(base_var, params)

    # NOTE(review): the `dequantized` store on params is created here, but
    # nothing is added to it in this snippet -- confirm against upstream.
    if not hasattr(params, "dequantized"):
      params.dequantized = defaultdict(list)

    return dequantized

  return diet_var_getter