Python functools.partial() Examples

The following are 30 code examples of functools.partial(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module functools, or try the search function.
Example #1
Source File: misc.py    From mmdetection with Apache License 2.0 8 votes vote down vote up
def multi_apply(func, *args, **kwargs):
    """Map ``func`` over parallel argument sequences and regroup its outputs.

    ``func`` is called once per position across the iterables in ``args``
    (exactly as ``map(func, *args)`` would), with ``kwargs`` forwarded to
    every call. The per-call results are then transposed with ``zip`` so
    that the i-th returned list collects the i-th output of every call.

    Args:
        func (Function): Callable applied to each group of arguments; every
            call must return an iterable of outputs.
        *args: Iterables consumed in lockstep, one element each per call.
        **kwargs: Keyword arguments bound to ``func`` for every call.

    Returns:
        tuple(list): One list per output position of ``func``.
    """
    bound = func
    if kwargs:
        # Only wrap the callable when there is something to bind.
        bound = partial(func, **kwargs)
    per_call_outputs = map(bound, *args)
    return tuple(list(column) for column in zip(*per_call_outputs))
Example #2
Source File: dsl.py    From gql with MIT License 6 votes vote down vote up
def get_arg_serializer(arg_type):
    """Return a one-argument callable that turns a value into an AST node.

    The serializer is chosen by the kind of ``arg_type``:

    * ``GraphQLNonNull`` / ``GraphQLInputField`` wrappers are unwrapped and
      the serializer of the wrapped type is returned.
    * ``GraphQLInputObjectType`` yields a serializer that builds an
      ``ObjectValueNode`` from a mapping, serializing each entry with the
      serializer of the field of the same name.
    * ``GraphQLList`` yields ``serialize_list`` pre-bound (via ``partial``)
      to the serializer of the list's item type.
    * ``GraphQLEnumType`` yields a serializer producing an ``EnumValueNode``.
    * Anything else goes through ``arg_type.serialize`` and
      ``ast_from_value``.
    """
    # Wrapper types: delegate to whatever they wrap.
    if isinstance(arg_type, GraphQLNonNull):
        return get_arg_serializer(arg_type.of_type)
    if isinstance(arg_type, GraphQLInputField):
        return get_arg_serializer(arg_type.type)

    if isinstance(arg_type, GraphQLInputObjectType):
        # Resolve every field serializer once, up front.
        field_serializers = {
            field_name: get_arg_serializer(field)
            for field_name, field in arg_type.fields.items()
        }

        def serialize_object(value):
            return ObjectValueNode(
                fields=FrozenList(
                    ObjectFieldNode(
                        name=NameNode(value=key),
                        value=field_serializers[key](item),
                    )
                    for key, item in value.items()
                )
            )

        return serialize_object

    if isinstance(arg_type, GraphQLList):
        item_serializer = get_arg_serializer(arg_type.of_type)
        return partial(serialize_list, item_serializer)

    if isinstance(arg_type, GraphQLEnumType):

        def serialize_enum(value):
            return EnumValueNode(value=arg_type.serialize(value))

        return serialize_enum

    def serialize_scalar(value):
        return ast_from_value(arg_type.serialize(value), arg_type)

    return serialize_scalar
Example #3
Source File: wspbus.py    From cherrypy with BSD 3-Clause "New" or "Revised" License 6 votes vote down vote up
def subscribe(self, channel, callback=None, priority=None):
        """Register ``callback`` as a listener on ``channel``.

        When called without a callback, nothing is registered; instead a
        ``functools.partial`` of this method (with ``channel`` and
        ``priority`` already bound) is returned so it can be used as a
        decorator on the callback.

        Adding a callback that is already listening on the channel does not
        duplicate it. If ``priority`` is None, the callback's own
        ``priority`` attribute is used, falling back to 50.
        """
        if callback is not None:
            self.listeners.setdefault(channel, set()).add(callback)

            effective_priority = priority
            if effective_priority is None:
                effective_priority = getattr(callback, 'priority', 50)
            self._priorities[(channel, callback)] = effective_priority
            return None

        # Decorator form: defer registration until the callback is supplied.
        return functools.partial(self.subscribe, channel, priority=priority)
Example #4
Source File: window.py    From LPHK with GNU General Public License v3.0 6 votes vote down vote up
def popup_choice(self, window, title, image, text, choices):
        """Show a popup with a picture, a message and one button per choice.

        The popup grabs input and this call blocks (``wait_window``) until
        the popup is destroyed. Clicking a button destroys the popup and
        then runs that choice's callback, if it has one.

        Args:
            window: parent widget the popup is created for and raised above.
            title: window title of the popup.
            image: image shown in the left-hand label.
            text: message shown above the buttons.
            choices: sequence of items where index 0 is the button label and
                index 1 is a zero-argument callable (or None for no action).
        """
        popup = tk.Toplevel(window)
        popup.resizable(False, False)
        # The icon is only applied for non-GIF icon files; a GIF icon is
        # deliberately left unset here.
        if MAIN_ICON is not None:
            if os.path.splitext(MAIN_ICON)[1].lower() != ".gif":
                popup.iconbitmap(MAIN_ICON)
        popup.wm_title(title)
        popup.tkraise(window)

        def run_end(func):
            # Close the popup first so the callback runs with it gone.
            popup.destroy()
            if func is not None:
                func()

        picture_label = tk.Label(popup, image=image)
        # Keep a reference on the widget so the image is not garbage collected.
        picture_label.photo = image
        picture_label.grid(column=0, row=0, rowspan=2, padx=10, pady=10)
        tk.Label(popup, text=text, justify=tk.CENTER).grid(column=1, row=0, columnspan=len(choices), padx=10, pady=10)
        for idx, choice in enumerate(choices):
            # partial binds this choice's callback now, avoiding the
            # late-binding pitfall a closure over the loop variable would have.
            run_end_func = partial(run_end, choice[1])
            tk.Button(popup, text=choice[0], command=run_end_func).grid(column=1 + idx, row=1, padx=10, pady=10)
        popup.wait_visibility()
        popup.grab_set()
        popup.wait_window()
Example #5
Source File: manager.py    From wafw00f with BSD 3-Clause "New" or "Revised" License 6 votes vote down vote up
def load_plugins():
    """Load every plugin from the ``plugins`` directory beside this file.

    Returns:
        dict: maps each name reported by the plugin source's
        ``list_plugins()`` to the object returned by ``load_plugin(name)``.
    """
    base_dir = os.path.abspath(os.path.dirname(__file__))
    # Path builder anchored at this module's directory.
    get_path = partial(os.path.join, base_dir)
    plugin_dir = get_path('plugins')

    plugin_source = PluginBase(
        package='wafw00f.plugins', searchpath=[plugin_dir]
    ).make_plugin_source(searchpath=[plugin_dir], persist=True)

    return {
        plugin_name: plugin_source.load_plugin(plugin_name)
        for plugin_name in plugin_source.list_plugins()
    }
Example #6
Source File: ecs.py    From aegea with Apache License 2.0 6 votes vote down vote up
def tasks(args):
    """Describe ECS tasks and print them as a table.

    If ``args.tasks`` is given, only those tasks are described, in the first
    cluster of ``args.clusters``. Otherwise tasks are listed for every
    (cluster, desired status) combination and described in batches of 100,
    using a thread pool for both the listing and the describing.

    When ``args.clusters`` is None it is filled in (and stored back on
    ``args``): with a single name derived from this module's ``__name__``
    if ``args.tasks`` is set, else with all clusters from the paginator.
    """
    cluster_paginator = clients.ecs.get_paginator("list_clusters")
    task_paginator = clients.ecs.get_paginator("list_tasks")

    def list_tasks_worker(worker_args):
        cluster, status = worker_args
        listed = list(paginate(task_paginator, cluster=cluster, desiredStatus=status))
        return cluster, status, listed

    def describe_tasks_worker(t, cluster=None):
        if not t:
            return []
        return clients.ecs.describe_tasks(cluster=cluster, tasks=t)["tasks"]

    if args.clusters is None:
        if args.tasks:
            args.clusters = [__name__.replace(".", "_")]
        else:
            args.clusters = list(paginate(cluster_paginator))

    task_descs = []
    if args.tasks:
        task_descs = describe_tasks_worker(args.tasks, cluster=args.clusters[0])
    else:
        with ThreadPoolExecutor() as executor:
            combos = product(args.clusters, args.desired_status)
            for cluster, status, cluster_tasks in executor.map(list_tasks_worker, combos):
                # Bind the cluster so the worker can be mapped over batches.
                worker = partial(describe_tasks_worker, cluster=cluster)
                batches = (cluster_tasks[start:start + 100]
                           for start in range(0, len(cluster_tasks), 100))
                for batch_descs in executor.map(worker, batches):
                    task_descs.extend(batch_descs)
    page_output(tabulate(task_descs, args))
Example #7
Source File: model.py    From neural-fingerprinting with BSD 3-Clause "New" or "Revised" License 6 votes vote down vote up
def fprop(self, x, **kwargs):
        """Build the forward pass and return its named outputs.

        Inside ``self.scope`` (with ``tf.AUTO_REUSE``) the input goes
        through five 3x3, stride-2, valid-padded ReLU convolutions with
        96, 256, 384, 384 and 256 filters, is flattened, and then passes
        through two 4096-unit ReLU dense layers and a final 1000-unit dense
        layer. All layers use ``HeReLuNormalInitializer`` for their kernels.

        Args:
            x: input tensor fed to the first convolution.
            **kwargs: accepted for interface compatibility and discarded.

        Returns:
            dict with the second 4096-unit layer's output under ``'fc7'``,
            the final layer's output under ``self.O_LOGITS`` and its softmax
            under ``self.O_PROBS``.
        """
        del kwargs
        conv = functools.partial(
            tf.layers.conv2d,
            kernel_size=3,
            strides=2,
            padding='valid',
            activation=tf.nn.relu,
            kernel_initializer=HeReLuNormalInitializer)
        dense = functools.partial(
            tf.layers.dense, kernel_initializer=HeReLuNormalInitializer)

        conv_depths = [96, 256, 384, 384, 256]
        with tf.variable_scope(self.scope, reuse=tf.AUTO_REUSE):
            features = x
            for depth in conv_depths:
                features = conv(features, depth)
            flat = tf.layers.flatten(features)
            fc6 = dense(flat, 4096, tf.nn.relu)
            fc7 = dense(fc6, 4096, tf.nn.relu)
            logits = dense(fc7, 1000)
            probs = tf.nn.softmax(logits=logits)
            return {'fc7': fc7,
                    self.O_LOGITS: logits,
                    self.O_PROBS: probs}
Example #8
Source File: monitor.py    From multibootusb with GNU General Public License v2.0 6 votes vote down vote up
def run(self):
        """Poll the monitor and the stop event until asked to stop.

        Starts the monitor, then loops forever over poll results:

        * an event on the stop-event descriptor closes that descriptor and
          returns;
        * a ``'r'`` event on the monitor descriptor drains all pending
          devices (non-blocking polls until ``None``) and passes each one
          to ``self._callback``;
        * anything else raises ``EnvironmentError``.
        """
        self.monitor.start()
        notifier = poll.Poll.for_events(
            (self.monitor, 'r'), (self._stop_event.source, 'r'))
        while True:
            for fd, event in eintr_retry_call(notifier.poll):
                if fd == self._stop_event.source.fileno():
                    # Stop requested: close our side of the pipe and leave
                    # the thread.
                    self._stop_event.source.close()
                    return
                if fd != self.monitor.fileno() or event != 'r':
                    raise EnvironmentError('Observed monitor hung up')
                # timeout=0 makes each poll non-blocking; iter() stops at
                # the first None, i.e. once no device is pending.
                read_device = partial(eintr_retry_call, self.monitor.poll, timeout=0)
                for device in iter(read_device, None):
                    self._callback(device)
Example #9
Source File: train.py    From spleeter with MIT License 6 votes vote down vote up
def _create_evaluation_spec(params, audio_adapter, audio_path):
    """ Build the evaluation spec, throttled by ``params['throttle_secs']``.

    The spec's input function is ``get_validation_dataset`` with
    ``params``, ``audio_adapter`` and ``audio_path`` already bound.

    :param params: TF params to build spec from.
    :param audio_adapter: Forwarded to ``get_validation_dataset``.
    :param audio_path: Forwarded to ``get_validation_dataset``.
    :returns: Built evaluation spec.
    """
    validation_input_fn = partial(
        get_validation_dataset, params, audio_adapter, audio_path)
    return tf.estimator.EvalSpec(
        input_fn=validation_input_fn,
        steps=None,
        throttle_secs=params['throttle_secs'])
Example #10
Source File: 2_simple_mnist.py    From deep-learning-note with MIT License 6 votes vote down vote up
def __init__(self, learning_rate, max_iteration_steps, seed=None):
        """Initializes a `Generator` that builds `SimpleCNNs`.

        Stores the seed and a factory for ``SimpleCNNBuilder`` that already
        has ``learning_rate`` and ``max_iteration_steps`` bound.

        Args:
          learning_rate: The float learning rate to use.
          max_iteration_steps: The number of steps per iteration.
          seed: The random seed.
        """
        builder_kwargs = {
            'learning_rate': learning_rate,
            'max_iteration_steps': max_iteration_steps,
        }
        self._cnn_builder_fn = functools.partial(
            SimpleCNNBuilder, **builder_kwargs)
        self._seed = seed
Example #11
Source File: rnn_cell.py    From dynamic-training-with-apache-mxnet-on-aws with Apache License 2.0 6 votes vote down vote up
def __init__(self, input_shape, num_hidden,
                 h2h_kernel=(3, 3), h2h_dilate=(1, 1),
                 i2h_kernel=(3, 3), i2h_stride=(1, 1),
                 i2h_pad=(1, 1), i2h_dilate=(1, 1),
                 i2h_weight_initializer=None, h2h_weight_initializer=None,
                 i2h_bias_initializer='zeros', h2h_bias_initializer='zeros',
                 activation=functools.partial(symbol.LeakyReLU, act_type='leaky', slope=0.2),
                 prefix='ConvRNN_', params=None, conv_layout='NCHW'):
        """Initialize the cell by delegating to the parent initializer.

        Every argument is forwarded unchanged, by keyword, to the parent
        class; this method only supplies the defaults (notably the
        ``'ConvRNN_'`` prefix).

        The default ``activation`` is ``symbol.LeakyReLU`` pre-bound with
        ``act_type='leaky'`` and ``slope=0.2`` via ``functools.partial``.
        As a default argument it is created once, when this function is
        defined, and shared by all calls that do not pass ``activation``.
        """
        super(ConvRNNCell, self).__init__(input_shape=input_shape, num_hidden=num_hidden,
                                          h2h_kernel=h2h_kernel, h2h_dilate=h2h_dilate,
                                          i2h_kernel=i2h_kernel, i2h_stride=i2h_stride,
                                          i2h_pad=i2h_pad, i2h_dilate=i2h_dilate,
                                          i2h_weight_initializer=i2h_weight_initializer,
                                          h2h_weight_initializer=h2h_weight_initializer,
                                          i2h_bias_initializer=i2h_bias_initializer,
                                          h2h_bias_initializer=h2h_bias_initializer,
                                          activation=activation, prefix=prefix,
                                          params=params, conv_layout=conv_layout)
Example #12
Source File: rnn_cell.py    From dynamic-training-with-apache-mxnet-on-aws with Apache License 2.0 6 votes vote down vote up
def __init__(self, input_shape, num_hidden,
                 h2h_kernel=(3, 3), h2h_dilate=(1, 1),
                 i2h_kernel=(3, 3), i2h_stride=(1, 1),
                 i2h_pad=(1, 1), i2h_dilate=(1, 1),
                 i2h_weight_initializer=None, h2h_weight_initializer=None,
                 i2h_bias_initializer='zeros', h2h_bias_initializer='zeros',
                 activation=functools.partial(symbol.LeakyReLU, act_type='leaky', slope=0.2),
                 prefix='ConvLSTM_', params=None,
                 conv_layout='NCHW'):
        """Initialize the cell by delegating to the parent initializer.

        Every argument is forwarded unchanged, by keyword, to the parent
        class; this method only supplies the defaults (notably the
        ``'ConvLSTM_'`` prefix).

        The default ``activation`` is ``symbol.LeakyReLU`` pre-bound with
        ``act_type='leaky'`` and ``slope=0.2`` via ``functools.partial``.
        As a default argument it is created once, when this function is
        defined, and shared by all calls that do not pass ``activation``.
        """
        super(ConvLSTMCell, self).__init__(input_shape=input_shape, num_hidden=num_hidden,
                                           h2h_kernel=h2h_kernel, h2h_dilate=h2h_dilate,
                                           i2h_kernel=i2h_kernel, i2h_stride=i2h_stride,
                                           i2h_pad=i2h_pad, i2h_dilate=i2h_dilate,
                                           i2h_weight_initializer=i2h_weight_initializer,
                                           h2h_weight_initializer=h2h_weight_initializer,
                                           i2h_bias_initializer=i2h_bias_initializer,
                                           h2h_bias_initializer=h2h_bias_initializer,
                                           activation=activation, prefix=prefix,
                                           params=params, conv_layout=conv_layout)
Example #13
Source File: rnn_cell.py    From dynamic-training-with-apache-mxnet-on-aws with Apache License 2.0 6 votes vote down vote up
def __init__(self, input_shape, num_hidden,
                 h2h_kernel=(3, 3), h2h_dilate=(1, 1),
                 i2h_kernel=(3, 3), i2h_stride=(1, 1),
                 i2h_pad=(1, 1), i2h_dilate=(1, 1),
                 i2h_weight_initializer=None, h2h_weight_initializer=None,
                 i2h_bias_initializer='zeros', h2h_bias_initializer='zeros',
                 activation=functools.partial(symbol.LeakyReLU, act_type='leaky', slope=0.2),
                 prefix='ConvGRU_', params=None, conv_layout='NCHW'):
        """Initialize the cell by delegating to the parent initializer.

        Every argument is forwarded unchanged, by keyword, to the parent
        class; this method only supplies the defaults (notably the
        ``'ConvGRU_'`` prefix).

        The default ``activation`` is ``symbol.LeakyReLU`` pre-bound with
        ``act_type='leaky'`` and ``slope=0.2`` via ``functools.partial``.
        As a default argument it is created once, when this function is
        defined, and shared by all calls that do not pass ``activation``.
        """
        super(ConvGRUCell, self).__init__(input_shape=input_shape, num_hidden=num_hidden,
                                          h2h_kernel=h2h_kernel, h2h_dilate=h2h_dilate,
                                          i2h_kernel=i2h_kernel, i2h_stride=i2h_stride,
                                          i2h_pad=i2h_pad, i2h_dilate=i2h_dilate,
                                          i2h_weight_initializer=i2h_weight_initializer,
                                          h2h_weight_initializer=h2h_weight_initializer,
                                          i2h_bias_initializer=i2h_bias_initializer,
                                          h2h_bias_initializer=h2h_bias_initializer,
                                          activation=activation, prefix=prefix,
                                          params=params, conv_layout=conv_layout)
Example #14
Source File: eval.py    From DOTA_models with Apache License 2.0 6 votes vote down vote up
def main(unused_argv):
  """Run evaluation using the configuration selected by the flags.

  Requires ``FLAGS.checkpoint_dir`` and ``FLAGS.eval_dir``. Configs are
  read from the pipeline file when ``FLAGS.pipeline_config_path`` is set,
  otherwise from the separate config files. The model builder and input
  reader builder are pre-bound with ``functools.partial`` and handed to
  ``evaluator.evaluate`` together with the categories derived from the
  label map.
  """
  assert FLAGS.checkpoint_dir, '`checkpoint_dir` is missing.'
  assert FLAGS.eval_dir, '`eval_dir` is missing.'

  if FLAGS.pipeline_config_path:
    configs = get_configs_from_pipeline_file()
  else:
    configs = get_configs_from_multiple_files()
  model_config, eval_config, input_config = configs

  # Evaluation always builds the model in inference mode.
  model_fn = functools.partial(
      model_builder.build, model_config=model_config, is_training=False)
  create_input_dict_fn = functools.partial(
      input_reader_builder.build, input_config)

  label_map = label_map_util.load_labelmap(input_config.label_map_path)
  max_num_classes = max(item.id for item in label_map.item)
  categories = label_map_util.convert_label_map_to_categories(
      label_map, max_num_classes)

  evaluator.evaluate(
      create_input_dict_fn, model_fn, eval_config, categories,
      FLAGS.checkpoint_dir, FLAGS.eval_dir)
Example #15
Source File: losses.py    From DOTA_models with Apache License 2.0 6 votes vote down vote up
def mmd_loss(source_samples, target_samples, weight, scope=None):
  """Adds a weighted MMD similarity loss between two representations.

  The Maximum Mean Discrepancy (MMD) is computed with a Gaussian kernel
  matrix over a fixed set of bandwidths, floored at 1e-4 and multiplied by
  ``weight``. The result is written as a scalar summary and registered via
  ``tf.losses.add_loss``, both guarded by a finiteness assertion.

  Args:
    source_samples: a tensor of shape [num_samples, num_features].
    target_samples: a tensor of shape [num_samples, num_features].
    weight: the weight of the MMD loss.
    scope: optional name scope, prepended to the summary tag.

  Returns:
    a scalar tensor representing the MMD loss value.
  """
  sigmas = [
      1e-6, 1e-5, 1e-4, 1e-3, 1e-2, 1e-1,
      1, 5, 10, 15, 20, 25, 30, 35, 100,
      1e3, 1e4, 1e5, 1e6,
  ]
  # Bind the bandwidths so the kernel only needs the two sample tensors.
  gaussian_kernel = partial(
      utils.gaussian_kernel_matrix, sigmas=tf.constant(sigmas))

  raw_mmd = maximum_mean_discrepancy(
      source_samples, target_samples, kernel=gaussian_kernel)
  loss_value = tf.maximum(1e-4, raw_mmd) * weight

  finite_check = tf.Assert(tf.is_finite(loss_value), [loss_value])
  with tf.control_dependencies([finite_check]):
    tag = scope + 'MMD Loss' if scope else 'MMD Loss'
    tf.summary.scalar(tag, loss_value)
    tf.losses.add_loss(loss_value)

  return loss_value
Example #16
Source File: minitaur_env_randomizer_from_config.py    From soccer-matlab with BSD 2-Clause "Simplified" License 6 votes vote down vote up
def _build_randomization_function_dict(self, env):
    """Map each randomization parameter name to its bound randomizer.

    Every entry is a ``functools.partial`` of one of this object's
    ``_randomize_*`` methods. All but one are bound to ``env.minitaur``;
    the ``"control step"`` randomizer is bound to ``env`` itself.

    Args:
      env: environment exposing a ``minitaur`` attribute.

    Returns:
      dict from parameter name to partially-applied randomizer.
    """
    minitaur_randomizers = (
        ("mass", self._randomize_masses),
        ("inertia", self._randomize_inertia),
        ("latency", self._randomize_latency),
        ("joint friction", self._randomize_joint_friction),
        ("motor friction", self._randomize_motor_friction),
        ("restitution", self._randomize_contact_restitution),
        ("lateral friction", self._randomize_contact_friction),
        ("battery", self._randomize_battery_level),
        ("motor strength", self._randomize_motor_strength),
    )
    func_dict = {
        name: functools.partial(randomizer, minitaur=env.minitaur)
        for name, randomizer in minitaur_randomizers
    }
    # Setting the control step needs access to the environment.
    func_dict["control step"] = functools.partial(
        self._randomize_control_step, env=env)
    return func_dict
Example #17
Source File: wrappers.py    From soccer-matlab with BSD 2-Clause "Simplified" License 6 votes vote down vote up
def step(self, action, blocking=True):
    """Send an action to the environment and fetch the transition.

    The action message is always sent immediately; only the receiving of
    the result can be deferred.

    Args:
      action: The action to apply to the environment.
      blocking: Whether to wait for the result.

    Returns:
      Transition tuple when blocking, otherwise callable that returns the
      transition tuple.
    """
    self._conn.send((self._ACTION, action))
    receive_transition = functools.partial(self._receive, self._TRANSITION)
    return receive_transition() if blocking else receive_transition
Example #18
Source File: data.py    From End-to-end-ASR-Pytorch with MIT License 6 votes vote down vote up
def load_textset(n_jobs, use_gpu, pin_memory, corpus, text):
    """Build the tokenizer and the train/dev text data loaders.

    ``text`` is unpacked into ``load_text_encoder`` and ``corpus`` into
    ``create_textset``. Both loaders use ``collect_text_batch`` as their
    collate function, bound to ``mode='train'`` or ``mode='dev'``; the
    training loader shuffles and drops the last incomplete batch, the dev
    loader does neither.

    NOTE(review): ``n_jobs`` is not used here, and the training loader sets
    ``pin_memory`` from ``use_gpu`` while the dev loader uses
    ``pin_memory`` - confirm both are intended.

    Returns:
        (train loader, dev loader, vocabulary size, tokenizer, messages)
    """
    # Text tokenizer
    tokenizer = load_text_encoder(**text)

    # Dataset
    (tr_set, dv_set,
     tr_loader_bs, dv_loader_bs, data_msg) = create_textset(tokenizer, **corpus)

    # Dataloader (text data is stored in RAM, so no worker processes)
    tr_loader = DataLoader(
        tr_set,
        batch_size=tr_loader_bs,
        shuffle=True,
        drop_last=True,
        collate_fn=partial(collect_text_batch, mode='train'),
        num_workers=0,
        pin_memory=use_gpu,
    )
    dv_loader = DataLoader(
        dv_set,
        batch_size=dv_loader_bs,
        shuffle=False,
        drop_last=False,
        collate_fn=partial(collect_text_batch, mode='dev'),
        num_workers=0,
        pin_memory=pin_memory,
    )

    # Messages to show
    io_spec = 'I/O spec.  | Token type = {}\t| Vocab size = {}'.format(
        tokenizer.token_type, tokenizer.vocab_size)
    data_msg.append(io_spec)

    return tr_loader, dv_loader, tokenizer.vocab_size, tokenizer, data_msg
Example #19
Source File: __init__.py    From facebook-wda with MIT License 6 votes vote down vote up
def set_alert_callback(self, callback):
        """
        Install (or clear) the callback stored on ``self.http``.

        A callable is stored pre-bound to this object, so it is later
        invoked as ``callback(self)``. Anything that is not callable
        clears the stored callback (sets it to None).

        Args:
            callback (func): called when alert popup

        Example of callback:

            def callback(session):
                session.alert.accept()
        """
        bound_callback = functools.partial(callback, self) if callable(callback) else None
        self.http.alert_callback = bound_callback

    #Not working
    #def get_clipboard(self):
    #    return self.http.post("/wda/getPasteboard").value

    # Not working
    #def siri_activate(self, text):
    #    self.http.post("/wda/siri/activate", {"text": text}) 
Example #20
Source File: __init__.py    From aegea with Apache License 2.0 5 votes vote down vote up
def add_time_bound_args(p, snap=0):
    """Add ``--start-time`` and ``--end-time`` options to parser ``p``.

    Both options are parsed by ``Timestamp`` with ``snap`` pre-bound.
    ``--start-time`` defaults to ``Timestamp("-7d", snap=snap)``;
    ``--end-time`` has no explicit default.
    """
    parse_timestamp = partial(Timestamp, snap=snap)
    default_start = Timestamp("-7d", snap=snap)
    p.add_argument(
        "--start-time",
        type=parse_timestamp,
        default=default_start,
        help=Timestamp.__doc__,
        metavar="START",
    )
    p.add_argument(
        "--end-time",
        type=parse_timestamp,
        help=Timestamp.__doc__,
        metavar="END",
    )
Example #21
Source File: losses_test.py    From DOTA_models with Apache License 2.0 5 votes vote down vote up
def test_mmd_is_zero_when_inputs_are_same(self):
    """MMD of a batch of samples against itself must evaluate to 0."""
    with self.test_session():
      x = tf.random_uniform((2, 3), seed=1)
      kernel = partial(utils.gaussian_kernel_matrix, sigmas=tf.constant([1.]))
      # assertEqual, not assertEquals: the latter is a deprecated unittest
      # alias that no longer exists in Python 3.12+.
      self.assertEqual(0, losses.maximum_mean_discrepancy(x, x, kernel).eval())
Example #22
Source File: losses_test.py    From DOTA_models with Apache License 2.0 5 votes vote down vote up
def test_mmd_name(self):
    """The MMD loss op must be named 'MaximumMeanDiscrepancy/value'."""
    with self.test_session():
      x = tf.random_uniform((2, 3), seed=1)
      kernel = partial(utils.gaussian_kernel_matrix, sigmas=tf.constant([1.]))
      loss = losses.maximum_mean_discrepancy(x, x, kernel)

      # assertEqual, not assertEquals: the latter is a deprecated unittest
      # alias that no longer exists in Python 3.12+.
      self.assertEqual(loss.op.name, 'MaximumMeanDiscrepancy/value')
Example #23
Source File: losses_test.py    From DOTA_models with Apache License 2.0 5 votes vote down vote up
def test_fast_mmd_is_similar_to_slow_mmd(self):
    """The kernel-based MMD must match the slow reference within 1e-5."""
    with self.test_session():
      normal_samples = tf.constant(np.random.normal(size=(2, 3)), tf.float32)
      uniform_samples = tf.constant(np.random.rand(2, 3), tf.float32)

      # Reference value from the slow implementation, same single sigma.
      expected = MaximumMeanDiscrepancySlow(
          normal_samples, uniform_samples, [1.]).eval()

      kernel = partial(utils.gaussian_kernel_matrix, sigmas=tf.constant([1.]))
      actual = losses.maximum_mean_discrepancy(
          normal_samples, uniform_samples, kernel).eval()

      self.assertAlmostEqual(expected, actual, delta=1e-5)
Example #24
Source File: losses_test.py    From DOTA_models with Apache License 2.0 5 votes vote down vote up
def test_mmd_is_zero_when_distributions_are_same(self):
    """Two differently-seeded uniform batches give an MMD within 1e-4 of 0."""
    with self.test_session():
      first_batch = tf.random_uniform((1000, 10), seed=1)
      second_batch = tf.random_uniform((1000, 10), seed=3)

      wide_kernel = partial(
          utils.gaussian_kernel_matrix, sigmas=tf.constant([100.]))
      mmd = losses.maximum_mean_discrepancy(
          first_batch, second_batch, kernel=wide_kernel)

      self.assertAlmostEqual(0, mmd.eval(), delta=1e-4)
Example #25
Source File: wrappers.py    From soccer-matlab with BSD 2-Clause "Simplified" License 5 votes vote down vote up
def reset(self, blocking=True):
    """Send a reset request to the environment and fetch the observation.

    The reset message is always sent immediately; only the receiving of
    the result can be deferred.

    Args:
      blocking: Whether to wait for the result.

    Returns:
      New observation when blocking, otherwise callable that returns the new
      observation.
    """
    self._conn.send((self._RESET, None))
    receive_observ = functools.partial(self._receive, self._OBSERV)
    return receive_observ() if blocking else receive_observ
Example #26
Source File: wrappers_test.py    From soccer-matlab with BSD 2-Clause "Simplified" License 5 votes vote down vote up
def test_close_no_hang_after_init(self):
    """Closing an external-process env right after creation must return."""
    mock_env_kwargs = dict(
        observ_shape=(2, 3), action_shape=(2,),
        min_duration=2, max_duration=2)
    constructor = functools.partial(tools.MockEnvironment, **mock_env_kwargs)
    tools.wrappers.ExternalProcess(constructor).close()
Example #27
Source File: wrappers_test.py    From soccer-matlab with BSD 2-Clause "Simplified" License 5 votes vote down vote up
def test_reraise_exception_in_step(self):
    """An exception raised inside the wrapped env's step must reach the caller.

    The mock environment is configured with ``crash_at_step=3``; the first
    two steps are taken normally and the third is expected to raise.
    """
    constructor = functools.partial(
        MockEnvironmentCrashInStep, crash_at_step=3)
    env = tools.wrappers.ExternalProcess(constructor)
    env.reset()
    for _ in range(2):
      env.step(env.action_space.sample())
    with self.assertRaises(Exception):
      env.step(env.action_space.sample())
Example #28
Source File: train_ppo_test.py    From soccer-matlab with BSD 2-Clause "Simplified" License 5 votes vote down vote up
def test_no_crash_observation_shape(self):
    """Training runs for every policy network / observation shape pairing."""
    nets = (networks.ForwardGaussianPolicy, networks.RecurrentGaussianPolicy)
    observ_shapes = ((1,), (2, 3), (2, 3, 4))
    for network in nets:
      for observ_shape in observ_shapes:
        config = self._define_config()
        with config.unlocked:
          # Mock environment with the observation shape under test.
          config.env = functools.partial(
              tools.MockEnvironment, observ_shape, action_shape=(3,),
              min_duration=15, max_duration=15)
          config.max_length = 20
          config.steps = 100
          config.network = network
        for score in train.train(config, env_processes=False):
          # Every reported score must be convertible to a float.
          float(score)
Example #29
Source File: configs.py    From soccer-matlab with BSD 2-Clause "Simplified" License 5 votes vote down vote up
def pybullet_racecar():
  """Configuration for Bullet MIT Racecar task.

  Starts from the settings returned by ``default()`` and then sets the
  task-specific entries, which take precedence over defaults of the same
  name.

  Returns:
    dict: the default settings plus ``env``, ``max_length`` and ``steps``.
  """
  # Build the result explicitly instead of writing through ``locals()``:
  # mutating the ``locals()`` mapping is implementation-dependent, and since
  # Python 3.13 (PEP 667) each ``locals()`` call in a function returns an
  # independent snapshot, so injected defaults would be silently dropped.
  config = dict(default())
  # Environment
  # (An alternative env is a functools.partial of racecarGymEnv.RacecarGymEnv
  # with isDiscrete=False, renders=True.)
  config.update(
      env='RacecarBulletEnv-v0',
      max_length=10,
      steps=1e7,  # 10M
  )
  return config
Example #30
Source File: train_ppo_test.py    From soccer-matlab with BSD 2-Clause "Simplified" License 5 votes vote down vote up
def test_no_crash_variable_duration(self):
    """Training runs when mock episode lengths vary between 5 and 25."""
    mock_env = functools.partial(
        tools.MockEnvironment, observ_shape=(2, 3), action_shape=(3,),
        min_duration=5, max_duration=25)
    config = self._define_config()
    with config.unlocked:
      config.env = mock_env
      config.max_length = 25
      config.steps = 200
      config.network = networks.RecurrentGaussianPolicy
    for score in train.train(config, env_processes=False):
      # Every reported score must be convertible to a float.
      float(score)