Python functools.partial() Examples
The following are 30 code examples of functools.partial().
Example #1
Source File: From mmdetection with Apache License 2.0 | 8 votes |
def multi_apply(func, *args, **kwargs):
    """Call ``func`` over parallel argument lists and regroup its outputs.

    ``func`` is applied position-by-position across the iterables in ``args``
    (exactly like ``map(func, *args)``), with ``kwargs`` bound to every call.
    Each call is expected to return a tuple; the i-th elements of all those
    tuples are gathered into the i-th output list.

    Args:
        func (Function): Callable applied to each group of positional
            arguments; should return a fixed-length tuple.
        *args: Iterables consumed in lockstep to supply positional arguments.
        **kwargs: Keyword arguments passed to every call of ``func``.

    Returns:
        tuple(list): One list per output slot of ``func``; each list holds
        that slot's value for every input, in input order.
    """
    call = partial(func, **kwargs) if kwargs else func
    per_input_outputs = map(call, *args)
    # zip(*...) transposes "one tuple per input" into "one tuple per output slot".
    regrouped = zip(*per_input_outputs)
    return tuple(list(slot_values) for slot_values in regrouped)
Example #2
Source File: From gql with MIT License | 6 votes |
def get_arg_serializer(arg_type):
    """Return a callable that converts a Python value into a GraphQL AST value node.

    The serializer's shape follows ``arg_type``:

    * ``GraphQLNonNull`` and ``GraphQLInputField`` are unwrapped to their
      inner type.
    * ``GraphQLInputObjectType`` values (mappings) become ``ObjectValueNode``s;
      each entry is serialized with its field's own type.
    * ``GraphQLList`` values go through ``serialize_list`` with the element
      type's serializer.
    * ``GraphQLEnumType`` values become ``EnumValueNode``s.
    * Anything else is handed to ``ast_from_value``.

    Per-field serializers of input object types are resolved lazily, on
    first use. Building them eagerly never terminated for self-referential
    input types (e.g. ``input Filter { and: [Filter] }``).

    Args:
        arg_type: GraphQL type (or input field) describing the value.

    Returns:
        A one-argument callable mapping a value to a GraphQL AST node.
    """
    if isinstance(arg_type, GraphQLNonNull):
        return get_arg_serializer(arg_type.of_type)
    if isinstance(arg_type, GraphQLInputField):
        return get_arg_serializer(arg_type.type)
    if isinstance(arg_type, GraphQLInputObjectType):
        fields = arg_type.fields
        # Cache of field name -> serializer, filled on demand.
        field_serializers = {}

        def serialize_field(name, value):
            if name not in field_serializers:
                # fields[name] raises KeyError for an unknown key, as the
                # eager lookup did.
                field_serializers[name] = get_arg_serializer(fields[name])
            return field_serializers[name](value)

        return lambda value: ObjectValueNode(
            fields=FrozenList(
                ObjectFieldNode(name=NameNode(value=k), value=serialize_field(k, v))
                for k, v in value.items()
            )
        )
    if isinstance(arg_type, GraphQLList):
        inner_serializer = get_arg_serializer(arg_type.of_type)
        return partial(serialize_list, inner_serializer)
    if isinstance(arg_type, GraphQLEnumType):
        return lambda value: EnumValueNode(value=arg_type.serialize(value))
    return lambda value: ast_from_value(arg_type.serialize(value), arg_type)
Example #3
Source File: From cherrypy with BSD 3-Clause "New" or "Revised" License | 6 votes |
def subscribe(self, channel, callback=None, priority=None):
    """Add the given callback at the given channel (if not present).

    If callback is None, return a partial suitable for decorating the
    callback, e.g. ``@bus.subscribe('start', priority=10)``.

    Args:
        channel: Name of the channel to listen on.
        callback: Callable to register. When None, a decorator is returned
            instead.
        priority: Priority stored for this listener. When None, the
            callback's own ``priority`` attribute is used, falling back to 50.

    Returns:
        The decorator when ``callback`` is None, otherwise ``callback`` itself.
        Returning the callback keeps the decorated name bound to the function
        rather than rebinding it to ``None``.
    """
    if callback is None:
        return functools.partial(
            self.subscribe,
            channel,
            priority=priority,
        )

    ch_listeners = self.listeners.setdefault(channel, set())
    ch_listeners.add(callback)

    if priority is None:
        priority = getattr(callback, 'priority', 50)
    self._priorities[(channel, callback)] = priority
    return callback
Example #4
Source File: From LPHK with GNU General Public License v3.0 | 6 votes |
def popup_choice(self, window, title, image, text, choices):
    """Show a modal popup with an image, a message and one button per choice.

    Blocks until the popup window is closed.

    Args:
        window: Parent Tk window of the popup.
        title: Popup window title.
        image: Tk image shown to the left of the message.
        text: Message displayed next to the image.
        choices: Sequence of ``(label, callback)`` items; one button is made
            per item. Pressing it closes the popup, then calls ``callback``
            unless it is None.
    """
    popup = tk.Toplevel(window)
    popup.resizable(False, False)
    if MAIN_ICON is not None:
        if os.path.splitext(MAIN_ICON)[1].lower() == ".gif":
            # GIF icons are set through Tk's "wm iconphoto" command rather
            # than iconbitmap().
            popup.tk.call('wm', 'iconphoto', popup._w, tk.PhotoImage(file=MAIN_ICON))
        else:
            popup.iconbitmap(MAIN_ICON)
    popup.wm_title(title)
    popup.tkraise(window)

    def run_end(func):
        # Close the popup first, then run the selected action (if any).
        popup.destroy()
        if func is not None:
            func()

    picture_label = tk.Label(popup, image=image)
    # Keep a reference to the image on the widget so it stays alive while shown.
    picture_label.photo = image
    picture_label.grid(column=0, row=0, rowspan=2, padx=10, pady=10)
    tk.Label(popup, text=text, justify=tk.CENTER).grid(
        column=1, row=0, columnspan=len(choices), padx=10, pady=10)
    for idx, choice in enumerate(choices):
        # partial binds this iteration's callback (a bare lambda would late-bind).
        run_end_func = partial(run_end, choice[1])
        tk.Button(popup, text=choice[0], command=run_end_func).grid(
            column=1 + idx, row=1, padx=10, pady=10)
    popup.wait_visibility()
    popup.grab_set()
    popup.wait_window()
Example #5
Source File: From wafw00f with BSD 3-Clause "New" or "Revised" License | 6 votes |
def load_plugins():
    """Load every plugin found in the ``plugins`` directory beside this module.

    The directory is registered with ``PluginBase`` under the
    ``wafw00f.plugins`` package, and each listed plugin is loaded.

    Returns:
        dict: Plugin name -> object returned by ``load_plugin`` for it.
    """
    module_dir = os.path.abspath(os.path.dirname(__file__))
    plugin_dir = os.path.join(module_dir, 'plugins')
    base = PluginBase(package='wafw00f.plugins', searchpath=[plugin_dir])
    source = base.make_plugin_source(searchpath=[plugin_dir], persist=True)
    return {name: source.load_plugin(name) for name in source.list_plugins()}
Example #6
Source File: From aegea with Apache License 2.0 | 6 votes |
def tasks(args):
    """Describe ECS tasks and page them out as a table.

    If ``args.tasks`` is set, those tasks are described directly in the first
    cluster of ``args.clusters``. Otherwise, tasks are listed concurrently for
    every (cluster, desired status) combination and then described in
    batches.

    Args:
        args: Parsed CLI namespace. Reads ``clusters``, ``tasks`` and
            ``desired_status``, and is passed on to ``tabulate``.
            ``args.clusters`` is filled in when it is None.
    """
    list_clusters = clients.ecs.get_paginator("list_clusters")
    list_tasks = clients.ecs.get_paginator("list_tasks")

    def list_tasks_worker(worker_args):
        cluster, status = worker_args
        return cluster, status, list(paginate(list_tasks, cluster=cluster, desiredStatus=status))

    def describe_tasks_worker(t, cluster=None):
        # Skip the API call entirely for an empty batch.
        return clients.ecs.describe_tasks(cluster=cluster, tasks=t)["tasks"] if t else []

    task_descs = []
    if args.clusters is None:
        args.clusters = [__name__.replace(".", "_")] if args.tasks else list(paginate(list_clusters))
    if args.tasks:
        task_descs = describe_tasks_worker(args.tasks, cluster=args.clusters[0])
    else:
        with ThreadPoolExecutor() as executor:
            listings = executor.map(list_tasks_worker, product(args.clusters, args.desired_status))
            for cluster, status, cluster_tasks in listings:
                worker = partial(describe_tasks_worker, cluster=cluster)
                # Describe in batches of at most 100 task IDs per request.
                batches = (cluster_tasks[pos:pos + 100] for pos in range(0, len(cluster_tasks), 100))
                # extend() instead of sum(..., []), which copies the list on every add.
                for descs in executor.map(worker, batches):
                    task_descs.extend(descs)
    page_output(tabulate(task_descs, args))
Example #7
Source File: From neural-fingerprinting with BSD 3-Clause "New" or "Revised" License | 6 votes |
def fprop(self, x, **kwargs):
    """Run the forward pass and return the model's named outputs.

    The network has five 3x3, stride-2, 'valid'-padded ReLU convolutions
    (96, 256, 384, 384, 256 filters), a flatten, two 4096-unit ReLU dense
    layers and a 1000-unit dense layer with no activation. Variables are
    created under ``self.scope`` with ``reuse=tf.AUTO_REUSE``.

    Args:
        x: Input tensor fed to the first convolution.
        **kwargs: Accepted for interface compatibility and ignored.

    Returns:
        dict: ``'fc7'`` -> output of the second dense layer,
        ``self.O_LOGITS`` -> logits, ``self.O_PROBS`` -> softmax of the logits.
    """
    del kwargs  # unused
    conv = functools.partial(
        tf.layers.conv2d,
        kernel_size=3,
        strides=2,
        padding='valid',
        activation=tf.nn.relu,
        kernel_initializer=HeReLuNormalInitializer)
    dense = functools.partial(tf.layers.dense, kernel_initializer=HeReLuNormalInitializer)

    with tf.variable_scope(self.scope, reuse=tf.AUTO_REUSE):
        features = x
        for filters in (96, 256, 384, 384, 256):
            features = conv(features, filters)
        hidden = tf.layers.flatten(features)
        hidden = dense(hidden, 4096, tf.nn.relu)
        fc7 = dense(hidden, 4096, tf.nn.relu)
        logits = dense(fc7, 1000)
        return {
            'fc7': fc7,
            self.O_LOGITS: logits,
            self.O_PROBS: tf.nn.softmax(logits=logits),
        }
Example #8
Source File: From multibootusb with GNU General Public License v2.0 | 6 votes |
def run(self):
    """Thread body: pass devices from the monitor to the callback until stopped.

    Polls the monitor and the stop pipe together. When the stop pipe fires,
    our side of it is closed and the thread returns. When the monitor is
    readable, devices are read without blocking until the read yields None,
    and each one is handed to ``self._callback``.

    Raises:
        EnvironmentError: If the monitor reports any event other than 'r'.
    """
    self.monitor.start()
    notifier = poll.Poll.for_events((self.monitor, 'r'), (self._stop_event.source, 'r'))
    while True:
        for fd, event in eintr_retry_call(notifier.poll):
            if fd == self._stop_event.source.fileno():
                # Stop requested: close our pipe side and leave the thread.
                self._stop_event.source.close()
                return
            if fd != self.monitor.fileno() or event != 'r':
                raise EnvironmentError('Observed monitor hung up')
            next_device = partial(eintr_retry_call, self.monitor.poll, timeout=0)
            for device in iter(next_device, None):
                self._callback(device)
Example #9
Source File: From spleeter with MIT License | 6 votes |
def _create_evaluation_spec(params, audio_adapter, audio_path):
    """ Setup eval spec evaluating every ``params['throttle_secs']`` seconds.

    :param params: TF params to build spec from.
    :param audio_adapter: Audio adapter forwarded to the validation dataset.
    :param audio_path: Audio path forwarded to the validation dataset.
    :returns: Built evaluation spec.
    """
    validation_input = partial(get_validation_dataset, params, audio_adapter, audio_path)
    return tf.estimator.EvalSpec(
        input_fn=validation_input,
        steps=None,
        throttle_secs=params['throttle_secs'])
Example #10
Source File: From deep-learning-note with MIT License | 6 votes |
def __init__(self, learning_rate, max_iteration_steps, seed=None):
    """Initializes a `Generator` that builds `SimpleCNNs`.

    Args:
      learning_rate: The float learning rate to use.
      max_iteration_steps: The number of steps per iteration.
      seed: The random seed, stored on the instance.
    """
    self._seed = seed
    builder_kwargs = dict(
        learning_rate=learning_rate,
        max_iteration_steps=max_iteration_steps)
    # Pre-bind the shared hyper-parameters to the builder class.
    self._cnn_builder_fn = functools.partial(SimpleCNNBuilder, **builder_kwargs)
Example #11
Source File: From dynamic-training-with-apache-mxnet-on-aws with Apache License 2.0 | 6 votes |
def __init__(self, input_shape, num_hidden,
             h2h_kernel=(3, 3), h2h_dilate=(1, 1),
             i2h_kernel=(3, 3), i2h_stride=(1, 1),
             i2h_pad=(1, 1), i2h_dilate=(1, 1),
             i2h_weight_initializer=None, h2h_weight_initializer=None,
             i2h_bias_initializer='zeros', h2h_bias_initializer='zeros',
             activation=functools.partial(symbol.LeakyReLU, act_type='leaky', slope=0.2),
             prefix='ConvRNN_', params=None, conv_layout='NCHW'):
    """Convolutional RNN cell; every argument is forwarded unchanged to the base class.

    Defaults: 3x3 kernels, stride 1, padding 1, no dilation, 'zeros' bias
    initializers, leaky-ReLU activation with slope 0.2, the ``'ConvRNN_'``
    prefix and the ``'NCHW'`` layout.
    """
    super(ConvRNNCell, self).__init__(
        input_shape=input_shape,
        num_hidden=num_hidden,
        h2h_kernel=h2h_kernel,
        h2h_dilate=h2h_dilate,
        i2h_kernel=i2h_kernel,
        i2h_stride=i2h_stride,
        i2h_pad=i2h_pad,
        i2h_dilate=i2h_dilate,
        i2h_weight_initializer=i2h_weight_initializer,
        h2h_weight_initializer=h2h_weight_initializer,
        i2h_bias_initializer=i2h_bias_initializer,
        h2h_bias_initializer=h2h_bias_initializer,
        activation=activation,
        prefix=prefix,
        params=params,
        conv_layout=conv_layout,
    )
Example #12
Source File: From dynamic-training-with-apache-mxnet-on-aws with Apache License 2.0 | 6 votes |
def __init__(self, input_shape, num_hidden,
             h2h_kernel=(3, 3), h2h_dilate=(1, 1),
             i2h_kernel=(3, 3), i2h_stride=(1, 1),
             i2h_pad=(1, 1), i2h_dilate=(1, 1),
             i2h_weight_initializer=None, h2h_weight_initializer=None,
             i2h_bias_initializer='zeros', h2h_bias_initializer='zeros',
             activation=functools.partial(symbol.LeakyReLU, act_type='leaky', slope=0.2),
             prefix='ConvLSTM_', params=None, conv_layout='NCHW'):
    """Convolutional LSTM cell; every argument is forwarded unchanged to the base class.

    Defaults: 3x3 kernels, stride 1, padding 1, no dilation, 'zeros' bias
    initializers, leaky-ReLU activation with slope 0.2, the ``'ConvLSTM_'``
    prefix and the ``'NCHW'`` layout.
    """
    super(ConvLSTMCell, self).__init__(
        input_shape=input_shape,
        num_hidden=num_hidden,
        h2h_kernel=h2h_kernel,
        h2h_dilate=h2h_dilate,
        i2h_kernel=i2h_kernel,
        i2h_stride=i2h_stride,
        i2h_pad=i2h_pad,
        i2h_dilate=i2h_dilate,
        i2h_weight_initializer=i2h_weight_initializer,
        h2h_weight_initializer=h2h_weight_initializer,
        i2h_bias_initializer=i2h_bias_initializer,
        h2h_bias_initializer=h2h_bias_initializer,
        activation=activation,
        prefix=prefix,
        params=params,
        conv_layout=conv_layout,
    )
Example #13
Source File: From dynamic-training-with-apache-mxnet-on-aws with Apache License 2.0 | 6 votes |
def __init__(self, input_shape, num_hidden,
             h2h_kernel=(3, 3), h2h_dilate=(1, 1),
             i2h_kernel=(3, 3), i2h_stride=(1, 1),
             i2h_pad=(1, 1), i2h_dilate=(1, 1),
             i2h_weight_initializer=None, h2h_weight_initializer=None,
             i2h_bias_initializer='zeros', h2h_bias_initializer='zeros',
             activation=functools.partial(symbol.LeakyReLU, act_type='leaky', slope=0.2),
             prefix='ConvGRU_', params=None, conv_layout='NCHW'):
    """Convolutional GRU cell; every argument is forwarded unchanged to the base class.

    Defaults: 3x3 kernels, stride 1, padding 1, no dilation, 'zeros' bias
    initializers, leaky-ReLU activation with slope 0.2, the ``'ConvGRU_'``
    prefix and the ``'NCHW'`` layout.
    """
    super(ConvGRUCell, self).__init__(
        input_shape=input_shape,
        num_hidden=num_hidden,
        h2h_kernel=h2h_kernel,
        h2h_dilate=h2h_dilate,
        i2h_kernel=i2h_kernel,
        i2h_stride=i2h_stride,
        i2h_pad=i2h_pad,
        i2h_dilate=i2h_dilate,
        i2h_weight_initializer=i2h_weight_initializer,
        h2h_weight_initializer=h2h_weight_initializer,
        i2h_bias_initializer=i2h_bias_initializer,
        h2h_bias_initializer=h2h_bias_initializer,
        activation=activation,
        prefix=prefix,
        params=params,
        conv_layout=conv_layout,
    )
Example #14
Source File: From DOTA_models with Apache License 2.0 | 6 votes |
def main(unused_argv):
    """Evaluate a detection model's checkpoints with the configured pipeline.

    Configs are read from ``FLAGS.pipeline_config_path`` when it is set,
    otherwise from the separate per-config flags. Checkpoints are read from
    ``FLAGS.checkpoint_dir`` and results are written to ``FLAGS.eval_dir``.

    Args:
        unused_argv: Arguments left over after flag parsing (ignored).
    """
    assert FLAGS.checkpoint_dir, '`checkpoint_dir` is missing.'
    assert FLAGS.eval_dir, '`eval_dir` is missing.'
    if FLAGS.pipeline_config_path:
        model_config, eval_config, input_config = get_configs_from_pipeline_file()
    else:
        model_config, eval_config, input_config = get_configs_from_multiple_files()

    # Deferred constructors handed to the evaluator.
    model_fn = functools.partial(
        model_builder.build,
        model_config=model_config,
        is_training=False)

    create_input_dict_fn = functools.partial(
        input_reader_builder.build,
        input_config)

    label_map = label_map_util.load_labelmap(input_config.label_map_path)
    # The largest class id in the label map bounds the category list.
    max_num_classes = max(item.id for item in label_map.item)
    categories = label_map_util.convert_label_map_to_categories(
        label_map, max_num_classes)

    evaluator.evaluate(create_input_dict_fn, model_fn, eval_config, categories,
                       FLAGS.checkpoint_dir, FLAGS.eval_dir)
Example #15
Source File: From DOTA_models with Apache License 2.0 | 6 votes |
def mmd_loss(source_samples, target_samples, weight, scope=None):
    """Adds a similarity loss term, the MMD between two representations.

    The Maximum Mean Discrepancy (MMD) is computed with a bank of Gaussian
    kernels whose sigmas span 1e-6 to 1e6. The value is clamped from below
    at 1e-4, scaled by ``weight``, logged as a scalar summary and registered
    with ``tf.losses.add_loss``.

    Args:
      source_samples: a tensor of shape [num_samples, num_features].
      target_samples: a tensor of shape [num_samples, num_features].
      weight: the weight of the MMD loss.
      scope: optional name scope for summary tags (prepended verbatim).

    Returns:
      a scalar tensor representing the MMD loss value.
    """
    sigmas = [
        1e-6, 1e-5, 1e-4, 1e-3, 1e-2, 1e-1, 1, 5, 10, 15, 20, 25, 30, 35, 100,
        1e3, 1e4, 1e5, 1e6
    ]
    kernel = partial(utils.gaussian_kernel_matrix, sigmas=tf.constant(sigmas))

    raw_mmd = maximum_mean_discrepancy(source_samples, target_samples, kernel=kernel)
    loss_value = tf.maximum(1e-4, raw_mmd) * weight

    # Only ops created inside the block below (the summary) depend on this check.
    assert_op = tf.Assert(tf.is_finite(loss_value), [loss_value])
    with tf.control_dependencies([assert_op]):
        tag = scope + 'MMD Loss' if scope else 'MMD Loss'
        tf.summary.scalar(tag, loss_value)
        tf.losses.add_loss(loss_value)

    return loss_value
Example #16
Source File: From soccer-matlab with BSD 2-Clause "Simplified" License | 6 votes |
def _build_randomization_function_dict(self, env): func_dict = {} func_dict["mass"] = functools.partial( self._randomize_masses, minitaur=env.minitaur) func_dict["inertia"] = functools.partial( self._randomize_inertia, minitaur=env.minitaur) func_dict["latency"] = functools.partial( self._randomize_latency, minitaur=env.minitaur) func_dict["joint friction"] = functools.partial( self._randomize_joint_friction, minitaur=env.minitaur) func_dict["motor friction"] = functools.partial( self._randomize_motor_friction, minitaur=env.minitaur) func_dict["restitution"] = functools.partial( self._randomize_contact_restitution, minitaur=env.minitaur) func_dict["lateral friction"] = functools.partial( self._randomize_contact_friction, minitaur=env.minitaur) func_dict["battery"] = functools.partial( self._randomize_battery_level, minitaur=env.minitaur) func_dict["motor strength"] = functools.partial( self._randomize_motor_strength, minitaur=env.minitaur) # Settinmg control step needs access to the environment. func_dict["control step"] = functools.partial( self._randomize_control_step, env=env) return func_dict
Example #17
Source File: From soccer-matlab with BSD 2-Clause "Simplified" License | 6 votes |
def step(self, action, blocking=True):
    """Step the environment.

    Sends the action over the connection, then either waits for the reply
    or hands back a callable that waits for it later.

    Args:
      action: The action to apply to the environment.
      blocking: Whether to wait for the result.

    Returns:
      Transition tuple when blocking, otherwise callable that returns the
      transition tuple.
    """
    self._conn.send((self._ACTION, action))
    if not blocking:
        return functools.partial(self._receive, self._TRANSITION)
    return self._receive(self._TRANSITION)
Example #18
Source File: From End-to-end-ASR-Pytorch with MIT License | 6 votes |
def load_textset(n_jobs, use_gpu, pin_memory, corpus, text):
    """Build the text tokenizer and the train/dev DataLoaders for a text corpus.

    Args:
        n_jobs: Not used; text data is held in RAM, so both loaders use
            ``num_workers=0``.
        use_gpu: Used as ``pin_memory`` for the training loader.
        pin_memory: Used as ``pin_memory`` for the dev loader only.
        corpus: Keyword arguments forwarded to ``create_textset``.
        text: Keyword arguments forwarded to ``load_text_encoder``.

    Returns:
        tuple: ``(train_loader, dev_loader, vocab_size, tokenizer, data_msg)``.
    """
    tokenizer = load_text_encoder(**text)
    train_data, dev_data, train_bs, dev_bs, data_msg = create_textset(tokenizer, **corpus)

    # NOTE(review): the training loader pins memory based on `use_gpu`, while
    # the dev loader honours `pin_memory`. Confirm this asymmetry is intended.
    train_loader = DataLoader(
        train_data, batch_size=train_bs, shuffle=True, drop_last=True,
        collate_fn=partial(collect_text_batch, mode='train'),
        num_workers=0, pin_memory=use_gpu)
    dev_loader = DataLoader(
        dev_data, batch_size=dev_bs, shuffle=False, drop_last=False,
        collate_fn=partial(collect_text_batch, mode='dev'),
        num_workers=0, pin_memory=pin_memory)

    data_msg.append('I/O spec. | Token type = {}\t| Vocab size = {}'
                    .format(tokenizer.token_type, tokenizer.vocab_size))
    return train_loader, dev_loader, tokenizer.vocab_size, tokenizer, data_msg
Example #19
Source File: From facebook-wda with MIT License | 6 votes |
def set_alert_callback(self, callback):
    """Register (or clear) the handler used when an alert popup appears.

    Args:
        callback (func): Called with this session as its only argument.
            Any non-callable value (e.g. None) clears the handler.

    Example of callback:

        def callback(session):
            session.alert.accept()
    """
    if callable(callback):
        # Pre-bind this session so the stored handler needs no arguments.
        self.http.alert_callback = functools.partial(callback, self)
    else:
        self.http.alert_callback = None
Example #20
Source File: From aegea with Apache License 2.0 | 5 votes |
def add_time_bound_args(p, snap=0):
    """Add ``--start-time`` and ``--end-time`` options to an argparse parser.

    Both values are parsed with ``Timestamp(value, snap=snap)``. The start
    time defaults to ``Timestamp("-7d", snap=snap)``; the end time has no
    default.

    Args:
        p: Parser (or sub-parser) to extend via ``add_argument``.
        snap: Forwarded as ``snap`` to every ``Timestamp`` constructed.
    """
    parse_timestamp = partial(Timestamp, snap=snap)
    shared = dict(type=parse_timestamp, help=Timestamp.__doc__)
    p.add_argument("--start-time", default=Timestamp("-7d", snap=snap), metavar="START", **shared)
    p.add_argument("--end-time", metavar="END", **shared)
Example #21
Source File: From DOTA_models with Apache License 2.0 | 5 votes |
def test_mmd_is_zero_when_inputs_are_same(self):
    """The MMD of a sample with itself is exactly zero."""
    with self.test_session():
        x = tf.random_uniform((2, 3), seed=1)
        kernel = partial(utils.gaussian_kernel_matrix, sigmas=tf.constant([1.]))
        # assertEqual: the deprecated assertEquals alias was removed in Python 3.12.
        self.assertEqual(0, losses.maximum_mean_discrepancy(x, x, kernel).eval())
Example #22
Source File: From DOTA_models with Apache License 2.0 | 5 votes |
def test_mmd_name(self):
    """The MMD output op is named ``MaximumMeanDiscrepancy/value``."""
    with self.test_session():
        x = tf.random_uniform((2, 3), seed=1)
        kernel = partial(utils.gaussian_kernel_matrix, sigmas=tf.constant([1.]))
        loss = losses.maximum_mean_discrepancy(x, x, kernel)
        # assertEqual: the deprecated assertEquals alias was removed in Python 3.12.
        self.assertEqual(loss.op.name, 'MaximumMeanDiscrepancy/value')
Example #23
Source File: From DOTA_models with Apache License 2.0 | 5 votes |
def test_fast_mmd_is_similar_to_slow_mmd(self):
    """``maximum_mean_discrepancy`` matches ``MaximumMeanDiscrepancySlow`` within 1e-5."""
    with self.test_session():
        x = tf.constant(np.random.normal(size=(2, 3)), tf.float32)
        y = tf.constant(np.random.rand(2, 3), tf.float32)

        slow_cost = MaximumMeanDiscrepancySlow(x, y, [1.]).eval()

        gaussian = partial(utils.gaussian_kernel_matrix, sigmas=tf.constant([1.]))
        fast_cost = losses.maximum_mean_discrepancy(x, y, gaussian).eval()

        self.assertAlmostEqual(slow_cost, fast_cost, delta=1e-5)
Example #24
Source File: From DOTA_models with Apache License 2.0 | 5 votes |
def test_mmd_is_zero_when_distributions_are_same(self):
    """Two large uniform samples (different seeds) give an MMD within 1e-4 of zero."""
    with self.test_session():
        first = tf.random_uniform((1000, 10), seed=1)
        second = tf.random_uniform((1000, 10), seed=3)
        wide_kernel = partial(utils.gaussian_kernel_matrix, sigmas=tf.constant([100.]))
        mmd = losses.maximum_mean_discrepancy(first, second, kernel=wide_kernel).eval()
        self.assertAlmostEqual(0, mmd, delta=1e-4)
Example #25
Source File: From soccer-matlab with BSD 2-Clause "Simplified" License | 5 votes |
def reset(self, blocking=True):
    """Reset the environment.

    Sends the reset request over the connection, then either waits for the
    observation or hands back a callable that waits for it later.

    Args:
      blocking: Whether to wait for the result.

    Returns:
      New observation when blocking, otherwise callable that returns the new
      observation.
    """
    self._conn.send((self._RESET, None))
    if not blocking:
        return functools.partial(self._receive, self._OBSERV)
    return self._receive(self._OBSERV)
Example #26
Source File: From soccer-matlab with BSD 2-Clause "Simplified" License | 5 votes |
def test_close_no_hang_after_init(self):
    """Closing an ExternalProcess right after construction completes."""
    make_env = functools.partial(
        tools.MockEnvironment,
        observ_shape=(2, 3),
        action_shape=(2,),
        min_duration=2,
        max_duration=2)
    wrapped = tools.wrappers.ExternalProcess(make_env)
    wrapped.close()
Example #27
Source File: From soccer-matlab with BSD 2-Clause "Simplified" License | 5 votes |
def test_reraise_exception_in_step(self):
    """A crash in the wrapped environment's third step is re-raised by ``step``."""
    constructor = functools.partial(MockEnvironmentCrashInStep, crash_at_step=3)
    env = tools.wrappers.ExternalProcess(constructor)
    env.reset()
    # The first two steps are expected to succeed.
    for _ in range(2):
        env.step(env.action_space.sample())
    with self.assertRaises(Exception):
        env.step(env.action_space.sample())
Example #28
Source File: From soccer-matlab with BSD 2-Clause "Simplified" License | 5 votes |
def test_no_crash_observation_shape(self):
    """Training runs for both policy networks and observations of rank 1-3."""
    nets = networks.ForwardGaussianPolicy, networks.RecurrentGaussianPolicy
    observ_shapes = (1,), (2, 3), (2, 3, 4)
    for network, observ_shape in itertools.product(nets, observ_shapes):
        config = self._define_config()
        with config.unlocked:
            config.env = functools.partial(
                tools.MockEnvironment, observ_shape, action_shape=(3,),
                min_duration=15, max_duration=15)
            config.max_length = 20
            config.steps = 100
            config.network = network
        # Drain the training generator; every score must convert to float.
        for score in train.train(config, env_processes=False):
            float(score)
Example #29
Source File: From soccer-matlab with BSD 2-Clause "Simplified" License | 5 votes |
def pybullet_racecar():
    """Configuration for Bullet MIT Racecar task.

    Starts from ``default()`` and overrides the task-specific entries.

    Returns:
        dict: The merged configuration.
    """
    # Build the dict explicitly. `locals().update(...)` is not a supported way
    # to extend a function's namespace, and since Python 3.13 (PEP 667) each
    # locals() call returns a fresh snapshot, so the defaults were being lost.
    config = dict(default())
    # Environment
    # Alternative: functools.partial(racecarGymEnv.RacecarGymEnv, isDiscrete=False, renders=True)
    config['env'] = 'RacecarBulletEnv-v0'
    config['max_length'] = 10
    config['steps'] = 1e7  # 10M
    return config
Example #30
Source File: From soccer-matlab with BSD 2-Clause "Simplified" License | 5 votes |
def test_no_crash_variable_duration(self):
    """Training runs on a MockEnvironment with min_duration=5, max_duration=25."""
    config = self._define_config()
    with config.unlocked:
        config.env = functools.partial(
            tools.MockEnvironment, observ_shape=(2, 3), action_shape=(3,),
            min_duration=5, max_duration=25)
        config.max_length = 25
        config.steps = 200
        config.network = networks.RecurrentGaussianPolicy
    # Drain the training generator; every score must convert to float.
    for score in train.train(config, env_processes=False):
        float(score)