Python toolz.assoc() Examples

The following are 6 code examples of toolz.assoc(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module toolz, or try the search function.
Example #1
Source File: test_term.py — from catalyst (Apache License 2.0) · 6 votes
def test_parameterized_term_default_value(self):
        # Params omitted from the constructor call fall back to the class-level
        # defaults; explicit keyword arguments override them individually.
        defaults = {'a': 'default for a', 'b': 'default for b'}

        class F(Factor):
            params = defaults

            inputs = (SomeDataSet.foo,)
            dtype = 'f8'
            window_length = 5

        # No overrides at all: params are exactly the defaults.
        assert_equal(F().params, defaults)

        # Overriding a single param leaves the other one at its default.
        for key in ('a', 'b'):
            new_value = 'new ' + key
            actual = F(**{key: new_value}).params
            assert_equal(actual, assoc(defaults, key, new_value))

        # Overriding both params replaces every default.
        actual = F(a='new a', b='new b').params
        assert_equal(actual, {'a': 'new a', 'b': 'new b'})
Example #2
Source File: test_term.py — from catalyst (Apache License 2.0) · 6 votes
def test_parameterized_term_default_value_with_not_specified(self):
        # A default of NotSpecified makes 'b' mandatory: constructing F without
        # 'b' must raise TypeError, while supplying 'b' succeeds.
        defaults = {'a': 'default for a', 'b': NotSpecified}

        class F(Factor):
            params = defaults

            inputs = (SomeDataSet.foo,)
            dtype = 'f8'
            window_length = 5

        missing_b = r"F expected a keyword parameter 'b'\."
        for kwargs in ({}, {'a': 'new a'}):
            with assert_raises_regex(TypeError, missing_b):
                F(**kwargs)

        actual = F(b='new b').params
        assert_equal(actual, assoc(defaults, 'b', 'new b'))

        actual = F(a='new a', b='new b').params
        assert_equal(actual, {'a': 'new a', 'b': 'new b'})
Example #3
Source File: p2p_proto.py — from pyquarkchain (MIT License) · 5 votes
def decode(self, data: bytes) -> _DecodedMsgType:
        """Decode a Disconnect message and attach a human-readable reason.

        Args:
            data: the raw encoded message bytes, decoded via the parent
                class's ``decode``.

        Returns:
            A new mapping equal to the decoded message plus a
            ``"reason_name"`` key computed by ``self.get_reason_name`` from the
            decoded ``"reason"`` value. The decoded mapping itself is not
            modified (``assoc`` returns a copy).

        Raises:
            MalformedMessage: if the payload cannot be deserialized
                (``rlp.exceptions.ListDeserializationError``); the original
                error is chained as ``__cause__``.
        """
        try:
            raw_decoded = cast(Dict[str, int], super().decode(data))
        except rlp.exceptions.ListDeserializationError as err:
            # Lazy %-args: the message is only formatted if the record is emitted.
            self.logger.warning("Malformed Disconnect message: %s", data)
            raise MalformedMessage(
                "Malformed Disconnect message: {}".format(data)
            ) from err
        return assoc(
            raw_decoded, "reason_name", self.get_reason_name(raw_decoded["reason"])
        )
Example #4
Source File: base_network.py — from cloudformation-environmentbase (BSD 2-Clause "Simplified" License) · 5 votes
def _get_subnet_config_w_az(self, network_config):
        """Yield one copy of every subnet config per availability-zone index.

        ``network_config['az_count']`` (default 2, coerced with ``int``) sets how
        many copies are produced for each entry of
        ``network_config['subnet_config']`` (default: empty). Each copy gets an
        ``'AZ'`` key holding the zero-based zone index; ``assoc`` returns a new
        dict, so the input configs are left untouched.
        """
        zone_total = int(network_config.get('az_count', 2))
        for subnet in network_config.get('subnet_config', {}):
            for zone_index in range(zone_total):
                yield assoc(subnet, 'AZ', zone_index)
Example #5
Source File: base_network.py — from cloudformation-environmentbase (BSD 2-Clause "Simplified" License) · 5 votes
def _get_subnet_config_w_cidr(self, network_config):
        """Yield each per-AZ subnet config with a ``'cidr'`` key assigned.

        Subnet configs produced by ``_get_subnet_config_w_az`` are grouped by
        their ``'size'`` value (used as a prefix length via ``int``) and the
        groups are processed in ascending numeric order of that value. Each
        group is carved from the current parent block ``net`` with
        ``net.subnet(size)``. When that block runs dry, the next free block is
        taken from the leftovers of earlier groups, most recently added
        leftovers first.

        Args:
            network_config: mapping read for ``'network_cidr_base'`` (default
                ``'172.16.0.0'``) and ``'network_cidr_size'`` (default
                ``'20'``), and passed on to ``_get_subnet_config_w_az``.

        Yields:
            A new dict per subnet config with ``'cidr'`` set to the assigned
            block as a string. Input configs are not modified (``assoc``
            copies).

        Raises:
            ValueError: if the address space runs out before every subnet
                config has been assigned a CIDR.
        """
        network_cidr_base = str(network_config.get('network_cidr_base', '172.16.0.0'))
        network_cidr_size = str(network_config.get('network_cidr_size', '20'))
        # NOTE(review): a 'first_network_address_block' key used to be read here
        # but was never used; confirm whether it was meant to seed `net`.

        base_cidr = network_cidr_base + '/' + network_cidr_size
        net = netaddr.IPNetwork(base_cidr)
        no_space = ('Not enough address space in {} for the requested '
                    'subnets'.format(base_cidr))

        grouped_subnet = groupby('size', self._get_subnet_config_w_az(network_config))
        # Sort numerically: sizes may be strings, and '8' > '16' lexically.
        subnet_groups = sorted(grouped_subnet.items(), key=lambda item: int(item[0]))
        # Partially consumed subnet generators whose remaining blocks are free.
        available_cidrs = []

        def take_free_block():
            # next() consumes the block from its generator, so a block can never
            # be handed out twice.
            block = next(chain(*reversed(available_cidrs)), None)
            if block is None:
                raise ValueError(no_space)
            return block

        for subnet_size, subnet_configs in subnet_groups:
            prefixlen = int(subnet_size)
            if net is None:
                # The previous group used up its parent block exactly.
                net = take_free_block()
            newcidrs = net.subnet(prefixlen)

            for subnet_config in subnet_configs:
                cidr = next(newcidrs, None)
                if cidr is None:
                    # Current parent block is full: carve from a leftover block.
                    net = take_free_block()
                    newcidrs = net.subnet(prefixlen)
                    cidr = next(newcidrs, None)
                    if cidr is None:
                        raise ValueError(no_space)

                yield assoc(subnet_config, 'cidr', str(cidr))

            # The next unused block of this size becomes the parent for the next
            # group (None if exhausted); the rest of this group stays available.
            net = next(newcidrs, None)
            available_cidrs.append(newcidrs)
Example #6
Source File: core.py — from dask-lightgbm (BSD 3-Clause "New" or "Revised" License) · 4 votes
def train(client, data, label, params, model_factory, weight=None, **kwargs):
    """Train a model in parallel on the Dask workers that hold the data.

    ``data``, ``label`` and optionally ``weight`` are split into parts, zipped
    into delayed tuples (so matching parts stay together) and computed on the
    cluster. Each worker then runs ``_train_part`` on the parts it holds
    locally. Only the first worker in ``worker_map`` is asked to return its
    model.

    Args:
        client: Dask distributed client used to compute and submit work.
        data: features, split with ``_split_to_parts(..., is_matrix=True)``.
        label: targets, split with ``_split_to_parts(..., is_matrix=False)``.
        params: model parameters. The caller's dict is never modified; copies
            are made with ``assoc`` for ``tree_learner`` defaulting and for the
            per-worker ``num_threads``.
        model_factory: forwarded to ``_train_part``.
        weight: optional sample weights, split like ``label``.
        **kwargs: forwarded to ``_train_part``.

    Returns:
        The first truthy result gathered from the workers.

    Raises:
        The remote exception of the first part whose computation errored
        (re-raised by ``Future.result()``), or ``IndexError`` if no worker
        returned a truthy result.
    """
    # Split arrays/dataframes into parts. Arrange parts into tuples to enforce co-locality
    data_parts = _split_to_parts(data, is_matrix=True)
    label_parts = _split_to_parts(label, is_matrix=False)
    if weight is None:
        parts = list(map(delayed, zip(data_parts, label_parts)))
    else:
        weight_parts = _split_to_parts(weight, is_matrix=False)
        parts = list(map(delayed, zip(data_parts, label_parts, weight_parts)))

    # Start computation in the background
    parts = client.compute(parts)
    wait(parts)

    for part in parts:
        if part.status == 'error':
            # Future.result() re-raises the remote exception here, on the client;
            # returning the future itself would not surface the error.
            part.result()

    # Find locations of all parts and map them to particular Dask workers
    key_to_part_dict = {part.key: part for part in parts}
    who_has = client.who_has(parts)
    worker_map = defaultdict(list)
    for key, workers in who_has.items():
        worker_map[first(workers)].append(key_to_part_dict[key])

    master_worker = first(worker_map)
    worker_ncores = client.ncores()

    tree_learner = params.get('tree_learner')
    valid_learners = {'data', 'feature', 'voting'}
    if not isinstance(tree_learner, str) or tree_learner.lower() not in valid_learners:
        logger.warning('Parameter tree_learner not set or set to incorrect value '
                       '(%s), using "data" as default', tree_learner)
        # Rebind to a copy rather than mutating the caller's dict.
        params = assoc(params, 'tree_learner', 'data')

    # Loop-invariant arguments shared by every worker's training task.
    worker_addresses = list(worker_map)
    local_listen_port = params.get('local_listen_port', 12400)
    time_out = params.get('time_out', 120)

    # Tell each worker to train on the parts that it has locally
    futures_classifiers = [client.submit(_train_part,
                                         model_factory=model_factory,
                                         params=assoc(params, 'num_threads', worker_ncores[worker]),
                                         list_of_parts=list_of_parts,
                                         worker_addresses=worker_addresses,
                                         local_listen_port=local_listen_port,
                                         time_out=time_out,
                                         return_model=(worker == master_worker),
                                         **kwargs)
                           for worker, list_of_parts in worker_map.items()]

    results = client.gather(futures_classifiers)
    # Only the master worker was asked to return its model (return_model=True).
    results = [v for v in results if v]
    return results[0]