Python itertools.chain() Examples

The following are code examples of itertools.chain(), drawn from open-source projects. The source file and originating project are noted above each example. You may also want to check out the other functions and classes of the itertools module.
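As a quick refresher before the examples: itertools.chain() takes any number of iterables and yields their items one after another, lazily, without copying them into a new list. A minimal sketch:

from itertools import chain

# items are produced from each iterable in turn, on demand
combined = chain([1, 2], (3, 4), 'ab')
assert list(combined) == [1, 2, 3, 4, 'a', 'b']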
Example #1
Source File: DDPAE.py    From DDPAE-video-prediction with MIT License
def setup_training(self):
    '''
    Set up Pyro SVI and the optimizers.
    '''
    if not self.is_train:
      return

    self.pyro_optimizer = optim.Adam({'lr': self.lr_init})
    self.svis = {'elbo': SVI(self.model, self.guide, self.pyro_optimizer, loss=Trace_ELBO())}

    # Separate pose_model parameters and other networks' parameters
    params = []
    for name, net in self.nets.items():
      if name != 'pose_model':
        params.append(net.parameters())
    self.optimizer = torch.optim.Adam(
        [{'params': self.pose_model.parameters(), 'lr': self.lr_init},
         {'params': itertools.chain(*params), 'lr': self.lr_init}],
        betas=(0.5, 0.999))
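The chain(*params) call above merges the .parameters() generators of several networks into a single iterable for one optimizer parameter group. The same pattern in isolation, as a minimal sketch (the two Linear layers are stand-ins for illustration):

import itertools
import torch

net_a = torch.nn.Linear(4, 8)
net_b = torch.nn.Linear(8, 2)

# chain() yields net_a's parameters, then net_b's, without building a list first
optimizer = torch.optim.Adam(
    itertools.chain(net_a.parameters(), net_b.parameters()), lr=1e-3)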
Example #2
Source File: block_base.py    From DOTA_models with Apache License 2.0
def CreateWeightLoss(self):
    """Returns L2 loss list of (almost) all variables used inside this block.

    When this method needs to be overridden, there are two choices.

    1. Override CreateWeightLoss() to change the weight loss of all variables
       that belong to this block, both directly and indirectly.
    2. Override _CreateWeightLoss() to change the weight loss of all
       variables that directly belong to this block but not to the sub-blocks.

    Returns:
      A list of Tensor objects.
    """
    losses = list(itertools.chain(
        itertools.chain.from_iterable(
            t.CreateWeightLoss() for t in self._subblocks),
        self._CreateWeightLoss()))
    return losses 
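This example layers the two forms: chain.from_iterable() lazily flattens the per-subblock loss lists, and the outer chain() appends the block's own losses. A minimal sketch with placeholder strings instead of tensors:

from itertools import chain

sub_losses = [['l2_a'], ['l2_b', 'l2_c']]  # one list per sub-block
own_losses = ['l2_own']

# flatten the nested lists, then append this block's own losses
losses = list(chain(chain.from_iterable(sub_losses), own_losses))
assert losses == ['l2_a', 'l2_b', 'l2_c', 'l2_own']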
Example #3
Source File: single_path_functions.py    From apted with MIT License
def sub_spf1(ni, subtree1, subtree2, op, calculate):
    """Implements spf1 single path function for the case when the
    other subtree is a single node

    Params:
      ni -- node indexer for the subtree that has more than one element
      subtree1 -- subtree that has a single element
      subtree2 -- subtree that has more than one element
      op -- cost of deleting/inserting node
      calculate -- function(node, other) that returns the cost of
        renaming nodes
    """
    # pylint: disable=invalid-name
    # pylint: disable=too-many-arguments
    cost = subtree2.sum_cost
    max_cost = cost + op
    min_ren_minus_op = min(chain([cost], [
        calculate(subtree1, info)
        for _, info in ni.preorder_ltr(subtree2)
    ]))
    return min(min_ren_minus_op + cost, max_cost) 
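Here chain([cost], [...]) prepends a single seed value to the list of rename costs, so min() is never called on an empty sequence. The same pattern in isolation (smallest is a hypothetical helper):

from itertools import chain

def smallest(seed, values):
    # prepend a one-element list so min() never sees an empty iterable
    return min(chain([seed], values))

assert smallest(10, []) == 10
assert smallest(10, [3, 7]) == 3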
Example #4
Source File: EncodingDataParallel.py    From torch-toolbox with BSD 3-Clause "New" or "Revised" License
def forward(self, *inputs, **kwargs):
        if not self.device_ids:
            return self.module(*inputs, **kwargs)

        for t in chain(self.module.parameters(), self.module.buffers()):
            if t.device != self.src_device_obj:
                raise RuntimeError(
                    "module must have its parameters and buffers "
                    "on device {} (device_ids[0]) but found one of "
                    "them on device: {}".format(
                        self.src_device_obj, t.device))
        inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
        if len(self.device_ids) == 1:
            return self.module(*inputs, **kwargs)
        replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
        outputs = self.parallel_apply(replicas, inputs, kwargs)
        return outputs 
Example #5
Source File: app.py    From quart with MIT License
async def update_template_context(self, context: dict) -> None:
        """Update the provided template context.

        This adds additional context from the various template context
        processors.

        Arguments:
            context: The context to update (mutate).
        """
        processors = self.template_context_processors[None]
        if has_request_context():
            blueprint = _request_ctx_stack.top.request.blueprint
            if blueprint is not None and blueprint in self.template_context_processors:
                processors = chain(  # type: ignore
                    processors, self.template_context_processors[blueprint]
                )
        extra_context: dict = {}
        for processor in processors:
            extra_context.update(await processor())
        original = context.copy()
        context.update(extra_context)
        context.update(original) 
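This and the next four examples all use the same idiom: start from the app-wide functions stored under the None key and lazily append the blueprint-specific ones with chain(), rather than copying the lists. A minimal sketch with strings standing in for the handler functions:

from itertools import chain

processors_by_scope = {None: ['app_processor'], 'blog': ['blog_processor']}

blueprint = 'blog'
processors = processors_by_scope[None]
if blueprint is not None:
    processors = chain(processors, processors_by_scope[blueprint])

assert list(processors) == ['app_processor', 'blog_processor']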
Example #6
Source File: app.py    From quart with MIT License
async def do_teardown_request(
        self, exc: Optional[BaseException], request_context: Optional[RequestContext] = None
    ) -> None:
        """Teardown the request, calling the teardown functions.

        Arguments:
            exc: Any exception not handled that has caused the request
                to teardown.
            request_context: The request context, optional as Flask
                omits this argument.
        """
        request_ = (request_context or _request_ctx_stack.top).request
        functions = self.teardown_request_funcs[None]
        blueprint = request_.blueprint
        if blueprint is not None:
            functions = chain(functions, self.teardown_request_funcs[blueprint])  # type: ignore

        for function in functions:
            await function(exc)
        await request_tearing_down.send(self, exc=exc) 
Example #7
Source File: app.py    From quart with MIT License
async def do_teardown_websocket(
        self, exc: Optional[BaseException], websocket_context: Optional[WebsocketContext] = None
    ) -> None:
        """Teardown the websocket, calling the teardown functions.

        Arguments:
            exc: Any exception not handled that has caused the websocket
                to teardown.
            websocket_context: The websocket context, optional as Flask
                omits this argument.
        """
        websocket_ = (websocket_context or _websocket_ctx_stack.top).websocket
        functions = self.teardown_websocket_funcs[None]
        blueprint = websocket_.blueprint
        if blueprint is not None:
            functions = chain(functions, self.teardown_websocket_funcs[blueprint])  # type: ignore

        for function in functions:
            await function(exc)
        await websocket_tearing_down.send(self, exc=exc) 
Example #8
Source File: app.py    From quart with MIT License
async def preprocess_request(
        self, request_context: Optional[RequestContext] = None
    ) -> Optional[ResponseReturnValue]:
        """Preprocess the request i.e. call before_request functions.

        Arguments:
            request_context: The request context, optional as Flask
                omits this argument.
        """
        request_ = (request_context or _request_ctx_stack.top).request
        blueprint = request_.blueprint
        processors = self.url_value_preprocessors[None]
        if blueprint is not None:
            processors = chain(processors, self.url_value_preprocessors[blueprint])  # type: ignore
        for processor in processors:
            processor(request.endpoint, request.view_args)

        functions = self.before_request_funcs[None]
        if blueprint is not None:
            functions = chain(functions, self.before_request_funcs[blueprint])  # type: ignore
        for function in functions:
            result = await function()
            if result is not None:
                return result
        return None 
Example #9
Source File: app.py    From quart with MIT License
async def preprocess_websocket(
        self, websocket_context: Optional[WebsocketContext] = None
    ) -> Optional[ResponseReturnValue]:
        """Preprocess the websocket i.e. call before_websocket functions.

        Arguments:
            websocket_context: The websocket context, optional as Flask
                omits this argument.
        """
        websocket_ = (websocket_context or _websocket_ctx_stack.top).websocket
        blueprint = websocket_.blueprint
        processors = self.url_value_preprocessors[None]
        if blueprint is not None:
            processors = chain(processors, self.url_value_preprocessors[blueprint])  # type: ignore
        for processor in processors:
            processor(websocket_.endpoint, websocket_.view_args)

        functions = self.before_websocket_funcs[None]
        if blueprint is not None:
            functions = chain(functions, self.before_websocket_funcs[blueprint])  # type: ignore
        for function in functions:
            result = await function()
            if result is not None:
                return result
        return None 
Example #10
Source File: hmm.py    From razzy-spinner with GNU General Public License v3.0
def point_entropy(self, unlabeled_sequence):
        """
        Returns the pointwise entropy over the possible states at each
        position in the chain, given the observation sequence.
        """
        unlabeled_sequence = self._transform(unlabeled_sequence)

        T = len(unlabeled_sequence)
        N = len(self._states)

        alpha = self._forward_probability(unlabeled_sequence)
        beta = self._backward_probability(unlabeled_sequence)
        normalisation = logsumexp2(alpha[T-1])

        entropies = np.zeros(T, np.float64)
        probs = np.zeros(N, np.float64)
        for t in range(T):
            for s in range(N):
                probs[s] = alpha[t, s] + beta[t, s] - normalisation

            for s in range(N):
                entropies[t] -= 2**(probs[s]) * probs[s]

        return entropies 
Example #11
Source File: berny.py    From pyberny with Mozilla Public License 2.0
def __init__(
        self, geom, debug=False, restart=None, maxsteps=100, logger=None, **params
    ):
        self._debug = debug
        self._maxsteps = maxsteps
        self._converged = False
        self._n = 0
        self._log = BernyAdapter(logger or log, {'step': self._n})
        s = self._state = Berny.State()
        if restart:
            vars(s).update(restart)
            return
        s.geom = geom
        s.params = dict(chain(defaults.items(), params.items()))
        s.trust = s.params['trust']
        s.coords = InternalCoords(
            s.geom, dihedral=s.params['dihedral'], superweakdih=s.params['superweakdih']
        )
        s.H = s.coords.hessian_guess(s.geom)
        s.weights = s.coords.weights(s.geom)
        s.future = Berny.Point(s.coords.eval_geom(s.geom), None, None)
        s.first = True
        for line in str(s.coords).split('\n'):
            self._log.info(line) 
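dict(chain(defaults.items(), params.items())) merges two dicts, with keys from params overriding defaults because later items win. A minimal sketch using the 'trust' and 'dihedral' keys from the example (the values are hypothetical):

from itertools import chain

defaults = {'trust': 0.3, 'dihedral': True}
params = {'trust': 0.15}

merged = dict(chain(defaults.items(), params.items()))
assert merged == {'trust': 0.15, 'dihedral': True}
assert merged == {**defaults, **params}  # equivalent literal on Python 3.5+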
Example #12
Source File: didyoumean_internal.py    From DidYouMean-Python with MIT License
def get_attribute_suggestions(type_str, attribute, frame):
    """Get the suggestions closest to the attribute name for a given type."""
    types = get_types_for_str(type_str, frame)
    attributes = set(a for t in types for a in dir(t))
    if type_str == 'module':
        # For module, we manage to get the corresponding 'module' type
        # but the type doesn't bring much information about its content.
        # A hacky way to do so is to assume that the exception was something
        # like 'module_name.attribute' so that we can actually find the module
        # based on the name. Eventually, we check that the found object is a
        # module indeed. This is not foolproof, but it brings a whole lot of
        # interesting suggestions and the (minimal) risk is to have invalid
        # suggestions.
        module_name = frame.f_code.co_names[0]
        objs = get_objects_in_frame(frame)
        mod = objs[module_name][0].obj
        if inspect.ismodule(mod):
            attributes = set(dir(mod))

    return itertools.chain(
        suggest_attribute_as_builtin(attribute, type_str, frame),
        suggest_attribute_alternative(attribute, type_str, attributes),
        suggest_attribute_as_typo(attribute, attributes),
        suggest_attribute_as_special_case(attribute)) 
Example #13
Source File: argparse_to_json.py    From me-ica with GNU Lesser General Public License v2.1
def process(parser, widget_dict):
  mutually_exclusive_groups = [
                  [mutex_action for mutex_action in group_actions._group_actions]
                  for group_actions in parser._mutually_exclusive_groups]

  group_options = list(chain(*mutually_exclusive_groups))

  base_actions = [action for action in parser._actions
                  if action not in group_options
                  and action.dest != 'help']

  required_actions = filter(is_required, base_actions)
  optional_actions = filter(is_optional, base_actions)

  return list(categorize(required_actions, widget_dict, required=True)) + \
         list(categorize(optional_actions, widget_dict)) + \
         list(map(build_radio_group, mutually_exclusive_groups))
Example #14
Source File: manager.py    From zun with Apache License 2.0
def _wait_for_volumes_deleted(self, context, volmaps, container,
                                  timeout=60, poll_interval=1):
        start_time = time.time()
        try:
            volmaps = itertools.chain(volmaps)
            volmap = next(volmaps)
            while time.time() - start_time < timeout:
                if not volmap.auto_remove:
                    volmap = next(volmaps)
                driver = self._get_driver(container)
                is_deleted, is_error = driver.is_volume_deleted(
                    context, volmap)
                if is_deleted:
                    volmap = next(volmaps)
                if is_error:
                    break
                time.sleep(poll_interval)
        except StopIteration:
            return
        msg = _("Volumes cannot be successfully deleted after "
                "%d seconds") % (timeout)
        self._fail_container(context, container, msg, unset_host=True)
        raise exception.Conflict(msg) 
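Note that itertools.chain(volmaps) is called with a single argument here: it simply wraps the list in a one-shot iterator so next() can advance it inside the polling loop, equivalent to iter(volmaps). A minimal sketch:

import itertools

volmaps = ['vol1', 'vol2']
it = itertools.chain(volmaps)  # with one iterable, same as iter(volmaps)
assert next(it) == 'vol1'
assert next(it) == 'vol2'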
Example #15
Source File: agent_pop.py    From indras_net with GNU General Public License v3.0
def __iter__(self):
        alists = []
        for var in self.varieties_iter():
            alists.append(self.vars[var]["agents"])
        # create an iterator that chains the lists together as if one:
        return itertools.chain(*alists) 
Example #16
Source File: grid_env.py    From indras_net with GNU General Public License v3.0
def __iter__(self):
        """
        Iterate over all our cells: note,
        right now, this returns the center cell twice.
        """
        return itertools.chain(self.views) 
Example #17
Source File: grid_env.py    From indras_net with GNU General Public License v3.0
def __iter__(self):
        # create an iterator that chains the
        #  rows of grid together as if one list:
        return itertools.chain(*self.grid) 
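chain(*self.grid) splats the rows so the whole grid iterates as one flat sequence; chain.from_iterable(self.grid) would do the same without unpacking the outer list eagerly. A minimal sketch with a 2x2 grid:

from itertools import chain

grid = [['a', 'b'], ['c', 'd']]
assert list(chain(*grid)) == ['a', 'b', 'c', 'd']
assert list(chain.from_iterable(grid)) == ['a', 'b', 'c', 'd']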
Example #18
Source File: pipes.py    From pypette with MIT License
def report(self):
        """Method to pretty print the report."""
        print("")
        print(crayons.green(self.name, bold=True))

        if not self.thread_map:
            print(crayons.red("No jobs run in pipeline yet !"))
            return

        joblen = len(self.thread_map)
        for i, jobs in enumerate(self.thread_map.values()):
            print(crayons.blue(u"| "))
            if len(jobs) == 1:
                print(crayons.blue(u"\u21E8  ") + Pipe._cstate(jobs[0]))
            else:
                if i == joblen - 1:
                    pre = u"  "
                else:
                    pre = u"| "
                l1 = [u"-" * 10 for j in jobs]
                l1 = u"".join(l1)
                l1 = l1[:-1]
                print(crayons.blue(u"\u21E8 ") + crayons.blue(l1))
                fmt = u"{0:^{wid}}"
                l2 = [fmt.format(u"\u21E9", wid=12) for j in jobs]
                print(crayons.blue(pre) + crayons.blue(u"".join(l2)))
                l3 = [
                    Pipe._cstate(fmt.format(j.state.name, wid=12))
                    for j in jobs
                ]
                print(crayons.blue(pre) + u"".join(l3))

        pipes = filter(
            lambda x: isinstance(x.job, Pipe), chain(*self.thread_map.values())
        )

        for item in pipes:
            item.job.report() 
Example #19
Source File: pipes.py    From pypette with MIT License
def _pretty_print(self):
        """Method to pretty print the pipeline."""
        print("")
        print(crayons.green(self.name, bold=True))

        if not self.job_map:
            print(crayons.red("No jobs added to the pipeline yet !"))
            return

        joblen = len(self.job_map)
        for i, jobs in enumerate(self.job_map.values()):
            print(crayons.blue(u"| "))
            if len(jobs) == 1:
                print(crayons.blue(u"\u21E8  ") + crayons.white(jobs[0].name))
            else:
                if i == joblen - 1:
                    pre = u"  "
                else:
                    pre = u"| "
                l1 = [u"-" * (len(j.name) + 2) for j in jobs]
                l1 = u"".join(l1)
                l1 = l1[: -len(jobs[-1].name) // 2 + 1]
                print(crayons.blue(u"\u21E8 ") + crayons.blue(l1))
                fmt = u"{0:^{wid}}"
                l2 = [fmt.format(u"\u21E9", wid=len(j.name) + 2) for j in jobs]
                print(crayons.blue(pre) + crayons.blue(u"".join(l2)))
                l3 = [fmt.format(j.name, wid=len(j.name) + 2) for j in jobs]
                print(crayons.blue(pre) + crayons.white(u"".join(l3)))

        pipes = filter(
            lambda x: isinstance(x, Pipe), chain(*self.job_map.values())
        )

        for item in pipes:
            item._pretty_print() 
Example #20
Source File: ggtnn_graph_parse.py    From gated-graph-transformer-network with MIT License
def get_wordlist(stories):
    words = [PAD_WORD] + sorted(set(
        word
        for (sents_graphs, query, answer) in stories
        for wordbag in itertools.chain((s for s, g in sents_graphs), [query])
        for word in wordbag))
    wordmap = list_to_map(words)
    return words, wordmap 
Example #21
Source File: model.py    From gated-graph-transformer-network with MIT License
def params(self):
        return list(itertools.chain(*(l.params for l in self.parameterized))) 
Example #22
Source File: input_embedding.py    From icme2019 with MIT License
def get_inputs_list(inputs):
    return list(chain(*list(map(lambda x: x.values(), filter(lambda x: x is not None, inputs))))) 
Example #23
Source File: managers.py    From pinax-documents with MIT License
def members(self, folder, **kwargs):
        direct = kwargs.get("direct", True)
        user = kwargs.get("user")
        Document = apps.get_model("documents", "Document")
        folders = self.filter(parent=folder)
        documents = Document.objects.filter(folder=folder)
        if user:
            folders = folders.for_user(user)
            documents = documents.for_user(user)
        M = sorted(itertools.chain(folders, documents), key=operator.attrgetter("name"))
        if direct:
            return M
        for child in folders:
            M.extend(self.members(child, **kwargs))
        return M 
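sorted(itertools.chain(folders, documents), key=operator.attrgetter("name")) merges two querysets of different model types into a single name-ordered list without concatenating them first. A minimal sketch with a hypothetical Item class standing in for the models:

import itertools
import operator

class Item:
    def __init__(self, name):
        self.name = name

folders = [Item('b')]
documents = [Item('a'), Item('c')]
merged = sorted(itertools.chain(folders, documents),
                key=operator.attrgetter('name'))
assert [i.name for i in merged] == ['a', 'b', 'c']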
Example #24
Source File: db.py    From query-exporter with GNU General Public License v3.0
def labels(self) -> FrozenSet[str]:
        """Resturn all labels for metrics in the query."""
        return frozenset(chain(*(metric.labels for metric in self.metrics))) 
Example #25
Source File: bio_utils.py    From models with MIT License
def sequence_to_int(sequences, max_len):
    if type(sequences) is list:
        seqs_enc = np.asarray([nucleotide_to_int(read, max_len) for read in sequences], 'uint8')
    else:
        seqs_enc = np.asarray([nucleotide_to_int(read, max_len) for read in sequences], 'uint8')
        seqs_enc = list(itertools.chain(*seqs_enc))
        seqs_enc = np.asarray(seqs_enc)

    return seqs_enc 
Example #26
Source File: model.py    From VSE-C with MIT License
def init_weights(self):
        for name, parameter in itertools.chain(self.gru_f.named_parameters(), self.gru_b.named_parameters()):
            if name.startswith('weight'):
                init.orthogonal(parameter.data)
            elif name.startswith('bias'):
                parameter.data.zero_()
            else:
                raise ValueError('Unknown parameter type: {}'.format(name)) 
Example #27
Source File: data_producer.py    From neural-pipeline with MIT License
def __init__(self, dataset: AbstractDataset, indices: list):
        super().__init__(dataset)
        self.indices = list(itertools.chain(*indices)) 
Example #28
Source File: ner.py    From dynamic-training-with-apache-mxnet-on-aws with Apache License 2.0
def build_vocab(nested_list):
    """
    :param nested_list: list of list of string
    :return: dictionary mapping from string to int, inverse of that dictionary
    """
    # Build vocabulary
    word_counts = Counter(itertools.chain(*nested_list))

    # Mapping from index to label
    vocabulary_inv = [x[0] for x in word_counts.most_common()]

    # Mapping from label to index
    vocabulary = {x: i for i, x in enumerate(vocabulary_inv)}
    return vocabulary, vocabulary_inv 
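Counter(itertools.chain(*nested_list)) counts tokens across all inner lists without materializing the flattened list. A minimal sketch with NER-style tags as stand-in tokens:

import itertools
from collections import Counter

nested = [['B-PER', 'O'], ['O', 'B-LOC']]
counts = Counter(itertools.chain(*nested))
assert counts == Counter({'O': 2, 'B-PER': 1, 'B-LOC': 1})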
Example #29
Source File: data_helpers.py    From dynamic-training-with-apache-mxnet-on-aws with Apache License 2.0
def build_vocab(sentences):
    """
    Builds a vocabulary mapping from word to index based on the sentences.
    Returns vocabulary mapping and inverse vocabulary mapping.
    """
    # Build vocabulary
    word_counts = Counter(itertools.chain(*sentences))
    # Mapping from index to word
    vocabulary_inv = [x[0] for x in word_counts.most_common()]
    # Mapping from word to index
    vocabulary = {x: i for i, x in enumerate(vocabulary_inv)}
    return [vocabulary, vocabulary_inv] 