Python functools.partial() Examples

The following are 29 code examples of functools.partial(), collected from open-source projects. Each example notes its source file, the project it comes from, and that project's license. You may also want to check out all available functions and classes of the functools module.
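As a quick refresher before the examples: functools.partial(func, /, *args, **keywords) returns a new callable with the given positional and keyword arguments pre-bound, leaving the rest to be supplied at call time. A minimal sketch:

from functools import partial

def power(base, exponent):
    return base ** exponent

square = partial(power, exponent=2)  # pre-bind the keyword argument
print(square(5))                     # 25 -- only `base` is still required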
Example #1
Source File: misc.py    From mmdetection with Apache License 2.0
def multi_apply(func, *args, **kwargs):
    """Apply function to a list of arguments.

    Note:
        This function applies ``func`` to multiple inputs and maps the
        multiple outputs of ``func`` into different lists. Each list
        contains the same type of outputs corresponding to different
        inputs.

    Args:
        func (Function): A function that will be applied to a list of
            arguments.

    Returns:
        tuple(list): A tuple containing multiple lists, each of which
            contains one kind of result returned by the function.
    """
    pfunc = partial(func, **kwargs) if kwargs else func
    map_results = map(pfunc, *args)
    return tuple(map(list, zip(*map_results))) 
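A hypothetical usage sketch (the function and values below are made up, not part of mmdetection): a func that returns a pair yields a pair of lists.

def square_and_cube(x, scale=1):
    return x * x * scale, x ** 3 * scale

squares, cubes = multi_apply(square_and_cube, [1, 2, 3], scale=2)
# squares == [2, 8, 18], cubes == [2, 16, 54]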
Example #2
Source File: manager.py    From wafw00f with BSD 3-Clause "New" or "Revised" License
def load_plugins():
    here = os.path.abspath(os.path.dirname(__file__))
    get_path = partial(os.path.join, here)
    plugin_dir = get_path('plugins')

    plugin_base = PluginBase(
        package='wafw00f.plugins', searchpath=[plugin_dir]
    )
    plugin_source = plugin_base.make_plugin_source(
        searchpath=[plugin_dir], persist=True
    )

    plugin_dict = {}
    for plugin_name in plugin_source.list_plugins():
        plugin_dict[plugin_name] = plugin_source.load_plugin(plugin_name)

    return plugin_dict 
Example #3
Source File: ecs.py    From aegea with Apache License 2.0
def tasks(args):
    list_clusters = clients.ecs.get_paginator("list_clusters")
    list_tasks = clients.ecs.get_paginator("list_tasks")

    def list_tasks_worker(worker_args):
        cluster, status = worker_args
        return cluster, status, list(paginate(list_tasks, cluster=cluster, desiredStatus=status))

    def describe_tasks_worker(t, cluster=None):
        return clients.ecs.describe_tasks(cluster=cluster, tasks=t)["tasks"] if t else []

    task_descs = []
    if args.clusters is None:
        args.clusters = [__name__.replace(".", "_")] if args.tasks else list(paginate(list_clusters))
    if args.tasks:
        task_descs = describe_tasks_worker(args.tasks, cluster=args.clusters[0])
    else:
        with ThreadPoolExecutor() as executor:
            for cluster, status, tasks in executor.map(list_tasks_worker, product(args.clusters, args.desired_status)):
                worker = partial(describe_tasks_worker, cluster=cluster)
                descs = executor.map(worker, (tasks[pos:pos + 100] for pos in range(0, len(tasks), 100)))
                task_descs += sum(descs, [])
    page_output(tabulate(task_descs, args)) 
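The pattern worth noting here: executor.map supplies exactly one argument per call, so partial pins the cluster keyword while the task chunks vary. A self-contained sketch of the same pattern, with a made-up worker rather than the aegea/ECS API:

from concurrent.futures import ThreadPoolExecutor
from functools import partial

def fetch(chunk, cluster=None):
    return ["{}:{}".format(cluster, item) for item in chunk]

with ThreadPoolExecutor() as executor:
    worker = partial(fetch, cluster="default")
    results = list(executor.map(worker, [[1, 2], [3, 4]]))
# results == [['default:1', 'default:2'], ['default:3', 'default:4']]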
Example #4
Source File: dsl.py    From gql with MIT License
def get_arg_serializer(arg_type):
    if isinstance(arg_type, GraphQLNonNull):
        return get_arg_serializer(arg_type.of_type)
    if isinstance(arg_type, GraphQLInputField):
        return get_arg_serializer(arg_type.type)
    if isinstance(arg_type, GraphQLInputObjectType):
        serializers = {k: get_arg_serializer(v) for k, v in arg_type.fields.items()}
        return lambda value: ObjectValueNode(
            fields=FrozenList(
                ObjectFieldNode(name=NameNode(value=k), value=serializers[k](v))
                for k, v in value.items()
            )
        )
    if isinstance(arg_type, GraphQLList):
        inner_serializer = get_arg_serializer(arg_type.of_type)
        return partial(serialize_list, inner_serializer)
    if isinstance(arg_type, GraphQLEnumType):
        return lambda value: EnumValueNode(value=arg_type.serialize(value))
    return lambda value: ast_from_value(arg_type.serialize(value), arg_type) 
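serialize_list itself is not shown in this excerpt; judging from the call site, the partial pre-binds the inner serializer so that the returned callable accepts only the list value. A plausible shape, offered purely as an assumption:

def serialize_list(serializer, values):
    # hypothetical helper: apply the pre-bound inner serializer to each element
    return ListValueNode(values=FrozenList(serializer(value) for value in values))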
Example #5
Source File: wspbus.py    From cherrypy with BSD 3-Clause "New" or "Revised" License
def subscribe(self, channel, callback=None, priority=None):
        """Add the given callback at the given channel (if not present).

        If callback is None, return a partial suitable for decorating
        the callback.
        """
        if callback is None:
            return functools.partial(
                self.subscribe,
                channel,
                priority=priority,
            )

        ch_listeners = self.listeners.setdefault(channel, set())
        ch_listeners.add(callback)

        if priority is None:
            priority = getattr(callback, 'priority', 50)
        self._priorities[(channel, callback)] = priority 
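Because the callback-is-None branch returns a partial, a pre-configured subscriber can be built first and applied later, decorator-style. A usage sketch with a hypothetical bus instance:

subscribe_to_start = bus.subscribe('start', priority=10)  # returns a partial

def on_start():
    print('engine started')

subscribe_to_start(on_start)  # == bus.subscribe('start', on_start, priority=10)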
Example #6
Source File: window.py    From LPHK with GNU General Public License v3.0
def popup_choice(self, window, title, image, text, choices):
        popup = tk.Toplevel(window)
        popup.resizable(False, False)
        if MAIN_ICON is not None:
            if os.path.splitext(MAIN_ICON)[1].lower() == ".gif":
                dummy = None
                #popup.call('wm', 'iconphoto', popup._w, tk.PhotoImage(file=MAIN_ICON))
            else:
                popup.iconbitmap(MAIN_ICON)
        popup.wm_title(title)
        popup.tkraise(window)
        
        def run_end(func):
            popup.destroy()
            if func is not None:
                func()

        picture_label = tk.Label(popup, image=image)
        picture_label.photo = image
        picture_label.grid(column=0, row=0, rowspan=2, padx=10, pady=10)
        tk.Label(popup, text=text, justify=tk.CENTER).grid(column=1, row=0, columnspan=len(choices), padx=10, pady=10)
        for idx, choice in enumerate(choices):
            run_end_func = partial(run_end, choice[1])
            tk.Button(popup, text=choice[0], command=run_end_func).grid(column=1 + idx, row=1, padx=10, pady=10)
        popup.wait_visibility()
        popup.grab_set()
        popup.wait_window() 
Example #7
Source File: model.py    From neural-fingerprinting with BSD 3-Clause "New" or "Revised" License
def fprop(self, x, **kwargs):
        del kwargs
        my_conv = functools.partial(tf.layers.conv2d,
                                    kernel_size=3,
                                    strides=2,
                                    padding='valid',
                                    activation=tf.nn.relu,
                                    kernel_initializer=HeReLuNormalInitializer)
        my_dense = functools.partial(
            tf.layers.dense, kernel_initializer=HeReLuNormalInitializer)

        with tf.variable_scope(self.scope, reuse=tf.AUTO_REUSE):
            for depth in [96, 256, 384, 384, 256]:
                x = my_conv(x, depth)
            y = tf.layers.flatten(x)
            y = my_dense(y, 4096, tf.nn.relu)
            y = fc7 = my_dense(y, 4096, tf.nn.relu)
            y = my_dense(y, 1000)
            return {'fc7': fc7,
                    self.O_LOGITS: y,
                    self.O_PROBS: tf.nn.softmax(logits=y)} 
Example #8
Source File: monitor.py    From multibootusb with GNU General Public License v2.0
def run(self):
        self.monitor.start()
        notifier = poll.Poll.for_events(
            (self.monitor, 'r'), (self._stop_event.source, 'r'))
        while True:
            for file_descriptor, event in eintr_retry_call(notifier.poll):
                if file_descriptor == self._stop_event.source.fileno():
                    # in case of a stop event, close our pipe side, and
                    # return from the thread
                    self._stop_event.source.close()
                    return
                elif file_descriptor == self.monitor.fileno() and event == 'r':
                    read_device = partial(eintr_retry_call, self.monitor.poll, timeout=0)
                    for device in iter(read_device, None):
                        self._callback(device)
                else:
                    raise EnvironmentError('Observed monitor hung up') 
Example #9
Source File: train.py    From spleeter with MIT License
def _create_evaluation_spec(params, audio_adapter, audio_path):
    """ Setup eval spec evaluating ever n seconds

    :param params: TF params to build spec from.
    :returns: Built evaluation spec.
    """
    input_fn = partial(
        get_validation_dataset,
        params,
        audio_adapter,
        audio_path)
    evaluation_spec = tf.estimator.EvalSpec(
        input_fn=input_fn,
        steps=None,
        throttle_secs=params['throttle_secs'])
    return evaluation_spec 
Example #10
Source File: 2_simple_mnist.py    From deep-learning-note with MIT License
def __init__(self, learning_rate, max_iteration_steps, seed=None):
        """Initializes a `Generator` that builds `SimpleCNNs`.

        Args:
          learning_rate: The float learning rate to use.
          max_iteration_steps: The number of steps per iteration.
          seed: The random seed.

        Returns:
          An instance of `Generator`.
        """
        self._seed = seed
        self._cnn_builder_fn = functools.partial(
            SimpleCNNBuilder,
            learning_rate=learning_rate,
            max_iteration_steps=max_iteration_steps) 
Example #11
Source File: rnn_cell.py    From dynamic-training-with-apache-mxnet-on-aws with Apache License 2.0
def __init__(self, input_shape, num_hidden,
                 h2h_kernel=(3, 3), h2h_dilate=(1, 1),
                 i2h_kernel=(3, 3), i2h_stride=(1, 1),
                 i2h_pad=(1, 1), i2h_dilate=(1, 1),
                 i2h_weight_initializer=None, h2h_weight_initializer=None,
                 i2h_bias_initializer='zeros', h2h_bias_initializer='zeros',
                 activation=functools.partial(symbol.LeakyReLU, act_type='leaky', slope=0.2),
                 prefix='ConvRNN_', params=None, conv_layout='NCHW'):
        super(ConvRNNCell, self).__init__(input_shape=input_shape, num_hidden=num_hidden,
                                          h2h_kernel=h2h_kernel, h2h_dilate=h2h_dilate,
                                          i2h_kernel=i2h_kernel, i2h_stride=i2h_stride,
                                          i2h_pad=i2h_pad, i2h_dilate=i2h_dilate,
                                          i2h_weight_initializer=i2h_weight_initializer,
                                          h2h_weight_initializer=h2h_weight_initializer,
                                          i2h_bias_initializer=i2h_bias_initializer,
                                          h2h_bias_initializer=h2h_bias_initializer,
                                          activation=activation, prefix=prefix,
                                          params=params, conv_layout=conv_layout) 
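The same idiom recurs in the next two cells: a partial used as a default argument value bakes the LeakyReLU hyperparameters into a plain callable that the cell can invoke like any activation. A standalone sketch of the idea with a toy activation (not the mxnet API):

from functools import partial

def leaky_relu(x, slope=0.25):
    return x if x > 0 else slope * x

activation = partial(leaky_relu, slope=0.2)  # hyperparameter pre-bound
print(activation(-1.0))                      # -0.2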
Example #12
Source File: rnn_cell.py    From dynamic-training-with-apache-mxnet-on-aws with Apache License 2.0
def __init__(self, input_shape, num_hidden,
                 h2h_kernel=(3, 3), h2h_dilate=(1, 1),
                 i2h_kernel=(3, 3), i2h_stride=(1, 1),
                 i2h_pad=(1, 1), i2h_dilate=(1, 1),
                 i2h_weight_initializer=None, h2h_weight_initializer=None,
                 i2h_bias_initializer='zeros', h2h_bias_initializer='zeros',
                 activation=functools.partial(symbol.LeakyReLU, act_type='leaky', slope=0.2),
                 prefix='ConvLSTM_', params=None,
                 conv_layout='NCHW'):
        super(ConvLSTMCell, self).__init__(input_shape=input_shape, num_hidden=num_hidden,
                                           h2h_kernel=h2h_kernel, h2h_dilate=h2h_dilate,
                                           i2h_kernel=i2h_kernel, i2h_stride=i2h_stride,
                                           i2h_pad=i2h_pad, i2h_dilate=i2h_dilate,
                                           i2h_weight_initializer=i2h_weight_initializer,
                                           h2h_weight_initializer=h2h_weight_initializer,
                                           i2h_bias_initializer=i2h_bias_initializer,
                                           h2h_bias_initializer=h2h_bias_initializer,
                                           activation=activation, prefix=prefix,
                                           params=params, conv_layout=conv_layout) 
Example #13
Source File: rnn_cell.py    From dynamic-training-with-apache-mxnet-on-aws with Apache License 2.0
def __init__(self, input_shape, num_hidden,
                 h2h_kernel=(3, 3), h2h_dilate=(1, 1),
                 i2h_kernel=(3, 3), i2h_stride=(1, 1),
                 i2h_pad=(1, 1), i2h_dilate=(1, 1),
                 i2h_weight_initializer=None, h2h_weight_initializer=None,
                 i2h_bias_initializer='zeros', h2h_bias_initializer='zeros',
                 activation=functools.partial(symbol.LeakyReLU, act_type='leaky', slope=0.2),
                 prefix='ConvGRU_', params=None, conv_layout='NCHW'):
        super(ConvGRUCell, self).__init__(input_shape=input_shape, num_hidden=num_hidden,
                                          h2h_kernel=h2h_kernel, h2h_dilate=h2h_dilate,
                                          i2h_kernel=i2h_kernel, i2h_stride=i2h_stride,
                                          i2h_pad=i2h_pad, i2h_dilate=i2h_dilate,
                                          i2h_weight_initializer=i2h_weight_initializer,
                                          h2h_weight_initializer=h2h_weight_initializer,
                                          i2h_bias_initializer=i2h_bias_initializer,
                                          h2h_bias_initializer=h2h_bias_initializer,
                                          activation=activation, prefix=prefix,
                                          params=params, conv_layout=conv_layout) 
Example #14
Source File: eval.py    From DOTA_models with Apache License 2.0
def main(unused_argv):
  assert FLAGS.checkpoint_dir, '`checkpoint_dir` is missing.'
  assert FLAGS.eval_dir, '`eval_dir` is missing.'
  if FLAGS.pipeline_config_path:
    model_config, eval_config, input_config = get_configs_from_pipeline_file()
  else:
    model_config, eval_config, input_config = get_configs_from_multiple_files()

  model_fn = functools.partial(
      model_builder.build,
      model_config=model_config,
      is_training=False)

  create_input_dict_fn = functools.partial(
      input_reader_builder.build,
      input_config)

  label_map = label_map_util.load_labelmap(input_config.label_map_path)
  max_num_classes = max([item.id for item in label_map.item])
  categories = label_map_util.convert_label_map_to_categories(
      label_map, max_num_classes)

  evaluator.evaluate(create_input_dict_fn, model_fn, eval_config, categories,
                     FLAGS.checkpoint_dir, FLAGS.eval_dir) 
Example #15
Source File: losses.py    From DOTA_models with Apache License 2.0
def mmd_loss(source_samples, target_samples, weight, scope=None):
  """Adds a similarity loss term, the MMD between two representations.

  This Maximum Mean Discrepancy (MMD) loss is calculated with a number of
  different Gaussian kernels.

  Args:
    source_samples: a tensor of shape [num_samples, num_features].
    target_samples: a tensor of shape [num_samples, num_features].
    weight: the weight of the MMD loss.
    scope: optional name scope for summary tags.

  Returns:
    a scalar tensor representing the MMD loss value.
  """
  sigmas = [
      1e-6, 1e-5, 1e-4, 1e-3, 1e-2, 1e-1, 1, 5, 10, 15, 20, 25, 30, 35, 100,
      1e3, 1e4, 1e5, 1e6
  ]
  gaussian_kernel = partial(
      utils.gaussian_kernel_matrix, sigmas=tf.constant(sigmas))

  loss_value = maximum_mean_discrepancy(
      source_samples, target_samples, kernel=gaussian_kernel)
  loss_value = tf.maximum(1e-4, loss_value) * weight
  assert_op = tf.Assert(tf.is_finite(loss_value), [loss_value])
  with tf.control_dependencies([assert_op]):
    tag = 'MMD Loss'
    if scope:
      tag = scope + tag
    tf.summary.scalar(tag, loss_value)
    tf.losses.add_loss(loss_value)

  return loss_value 
Example #16
Source File: minitaur_env_randomizer_from_config.py    From soccer-matlab with BSD 2-Clause "Simplified" License
def _build_randomization_function_dict(self, env):
    func_dict = {}
    func_dict["mass"] = functools.partial(
        self._randomize_masses, minitaur=env.minitaur)
    func_dict["inertia"] = functools.partial(
        self._randomize_inertia, minitaur=env.minitaur)
    func_dict["latency"] = functools.partial(
        self._randomize_latency, minitaur=env.minitaur)
    func_dict["joint friction"] = functools.partial(
        self._randomize_joint_friction, minitaur=env.minitaur)
    func_dict["motor friction"] = functools.partial(
        self._randomize_motor_friction, minitaur=env.minitaur)
    func_dict["restitution"] = functools.partial(
        self._randomize_contact_restitution, minitaur=env.minitaur)
    func_dict["lateral friction"] = functools.partial(
        self._randomize_contact_friction, minitaur=env.minitaur)
    func_dict["battery"] = functools.partial(
        self._randomize_battery_level, minitaur=env.minitaur)
    func_dict["motor strength"] = functools.partial(
        self._randomize_motor_strength, minitaur=env.minitaur)
    # Setting the control step needs access to the environment.
    func_dict["control step"] = functools.partial(
        self._randomize_control_step, env=env)
    return func_dict 
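Binding minitaur=env.minitaur (or env=env) up front gives every entry in the dict a uniform call signature. A hypothetical driver sketch, assuming each underlying _randomize_* method needs no further arguments:

func_dict = randomizer._build_randomization_function_dict(env)
for name in ("mass", "latency", "battery"):
    func_dict[name]()  # each partial already carries its target object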
Example #17
Source File: wrappers.py    From soccer-matlab with BSD 2-Clause "Simplified" License
def step(self, action, blocking=True):
    """Step the environment.

    Args:
      action: The action to apply to the environment.
      blocking: Whether to wait for the result.

    Returns:
      Transition tuple when blocking, otherwise callable that returns the
      transition tuple.
    """
    self._conn.send((self._ACTION, action))
    if blocking:
      return self._receive(self._TRANSITION)
    else:
      return functools.partial(self._receive, self._TRANSITION) 
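In the non-blocking case the caller receives the pending receive step as a partial, effectively a lightweight future. A usage sketch with a hypothetical env and action:

receive = env.step(action, blocking=False)  # returns a partial, does not wait
# ... do other work while the environment process computes ...
transition = receive()                      # now block for the transition tuple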
Example #18
Source File: data.py    From End-to-end-ASR-Pytorch with MIT License
def load_textset(n_jobs, use_gpu, pin_memory, corpus, text):

    # Text tokenizer
    tokenizer = load_text_encoder(**text)
    # Dataset
    tr_set, dv_set, tr_loader_bs, dv_loader_bs, data_msg = create_textset(
        tokenizer, **corpus)
    collect_tr = partial(collect_text_batch, mode='train')
    collect_dv = partial(collect_text_batch, mode='dev')
    # Dataloader (Text data stored in RAM, no need num_workers)
    tr_set = DataLoader(tr_set, batch_size=tr_loader_bs, shuffle=True, drop_last=True, collate_fn=collect_tr,
                        num_workers=0, pin_memory=use_gpu)
    dv_set = DataLoader(dv_set, batch_size=dv_loader_bs, shuffle=False, drop_last=False, collate_fn=collect_dv,
                        num_workers=0, pin_memory=pin_memory)

    # Messages to show
    data_msg.append('I/O spec.  | Token type = {}\t| Vocab size = {}'
                    .format(tokenizer.token_type, tokenizer.vocab_size))

    return tr_set, dv_set, tokenizer.vocab_size, tokenizer, data_msg 
Example #19
Source File: __init__.py    From facebook-wda with MIT License
def set_alert_callback(self, callback):
        """
        Args:
            callback (func): called when alert popup
        
        Example of callback:

            def callback(session):
                session.alert.accept()
        """
        if callable(callback):
            self.http.alert_callback = functools.partial(callback, self)
        else:
            self.http.alert_callback = None

    #Not working
    #def get_clipboard(self):
    #    return self.http.post("/wda/getPasteboard").value

    # Not working
    #def siri_activate(self, text):
    #    self.http.post("/wda/siri/activate", {"text": text}) 
Example #20
Source File: logs.py    From aegea with Apache License 2.0
def grep(args):
    if args.context:
        args.before_context = args.after_context = args.context
    if not args.end_time:
        args.end_time = Timestamp("-0s")
    query = clients.logs.start_query(logGroupName=args.log_group,
                                     startTime=int(timestamp(args.start_time) * 1000),
                                     endTime=int(timestamp(args.end_time) * 1000),
                                     queryString=args.query)
    seen_results = {}
    print_with_context = partial(print_log_event_with_context, before=args.before_context, after=args.after_context)
    try:
        with ThreadPoolExecutor() as executor:
            while True:
                res = clients.logs.get_query_results(queryId=query["queryId"])
                log_record_pointers = []
                for record in res["results"]:
                    event = {r["field"]: r["value"] for r in record}
                    event_hash = hashlib.sha256(json.dumps(event, sort_keys=True).encode()).hexdigest()[:32]
                    if event_hash in seen_results:
                        continue
                    if "@ptr" in event and (args.before_context or args.after_context):
                        log_record_pointers.append(event["@ptr"])
                    else:
                        print_log_event(event)
                    seen_results[event_hash] = event
                if log_record_pointers:
                    executor.map(print_with_context, log_record_pointers)
                if res["status"] == "Complete":
                    break
                elif res["status"] in {"Failed", "Cancelled"}:
                    raise AegeaException("Query status: {}".format(res["status"]))
                time.sleep(1)
    finally:
        try:
            clients.logs.stop_query(queryId=query["queryId"])
        except clients.logs.exceptions.InvalidParameterException:
            pass
    logger.debug("Query %s: %s", query["queryId"], res["statistics"])
    return SystemExit(os.EX_OK if seen_results else os.EX_DATAERR) 
Example #21
Source File: __init__.py    From aegea with Apache License 2.0
def add_time_bound_args(p, snap=0):
    t = partial(Timestamp, snap=snap)
    p.add_argument("--start-time", type=t, default=Timestamp("-7d", snap=snap), help=Timestamp.__doc__, metavar="START")
    p.add_argument("--end-time", type=t, help=Timestamp.__doc__, metavar="END") 
Example #22
Source File: util.py    From BASS with GNU General Public License v2.0
def file_sha256(path):
    """Get the hex digest of the given file."""
    sha256 = hashlib.sha256()
    with open(path, "rb") as f:
        # iter(callable, sentinel) keeps calling f.read(BLOCK_SIZE) until it
        # returns the sentinel; the file is binary, so the sentinel must be
        # b"", and an explicit loop is needed because map() is lazy in
        # Python 3 (the original map(...) call never consumed the iterator).
        for block in iter(partial(f.read, BLOCK_SIZE), b""):
            sha256.update(block)
    return sha256.hexdigest()
Example #23
Source File: asyncioEvent.py    From Learning-Concurrency-in-Python with MIT License
async def main(loop):
    # Create a shared event
    event = asyncio.Event()
    print('event start state: {}'.format(event.is_set()))

    loop.call_later(
        0.1, functools.partial(set_event, event)
    )

    # gather (unlike wait) still accepts bare coroutines on modern Python
    await asyncio.gather(coro1(event), coro2(event))
    print('event end state: {}'.format(event.is_set())) 
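set_event, coro1 and coro2 are defined elsewhere in the script; loop.call_later expects a plain callback, which is why the setter is wrapped in a partial. A plausible shape for the missing setter, stated as an assumption:

def set_event(event):
    # hypothetical helper matching the call site above
    print('setting event in callback')
    event.set()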
Example #24
Source File: test_attacks.py    From neural-fingerprinting with BSD 3-Clause "New" or "Revised" License
def setUp(self):
        super(TestFastFeatureAdversaries, self).setUp()

        def make_imagenet_cnn(input_shape=(None, 224, 224, 3)):
            """
            Similar CNN to AlexNet.
            """

            class ModelImageNetCNN(Model):
                def __init__(self, scope, nb_classes=1000, **kwargs):
                    del kwargs
                    Model.__init__(self, scope, nb_classes, locals())

                def fprop(self, x, **kwargs):
                    del kwargs
                    my_conv = functools.partial(tf.layers.conv2d,
                                                kernel_size=3,
                                                strides=2,
                                                padding='valid',
                                                activation=tf.nn.relu,
                                                kernel_initializer=HeReLuNormalInitializer)
                    my_dense = functools.partial(tf.layers.dense,
                                                 kernel_initializer=HeReLuNormalInitializer)
                    with tf.variable_scope(self.scope, reuse=tf.AUTO_REUSE):
                        for depth in [96, 256, 384, 384, 256]:
                            x = my_conv(x, depth)
                        y = tf.layers.flatten(x)
                        y = my_dense(y, 4096, tf.nn.relu)
                        y = fc7 = my_dense(y, 4096, tf.nn.relu)
                        y = my_dense(y, 1000)
                        return {'fc7': fc7,
                                self.O_LOGITS: y,
                                self.O_PROBS: tf.nn.softmax(logits=y)}

            return ModelImageNetCNN('imagenet')

        self.input_shape = [10, 224, 224, 3]
        self.sess = tf.Session()
        self.model = make_imagenet_cnn(self.input_shape)
        self.attack = FastFeatureAdversaries(self.model) 
Example #25
Source File: mnist_blackbox_keras.py    From neural-fingerprinting with BSD 3-Clause "New" or "Revised" License
def fprop(self, x, **kwargs):
        del kwargs
        my_dense = functools.partial(
            tf.layers.dense, kernel_initializer=HeReLuNormalInitializer)
        with tf.variable_scope(self.scope, reuse=tf.AUTO_REUSE):
            y = tf.layers.flatten(x)
            y = my_dense(y, self.nb_filters, activation=tf.nn.relu)
            y = my_dense(y, self.nb_filters, activation=tf.nn.relu)
            logits = my_dense(y, self.nb_classes)
            return {self.O_LOGITS: logits,
                    self.O_PROBS: tf.nn.softmax(logits=logits)} 
Example #26
Source File: mnist_blackbox.py    From neural-fingerprinting with BSD 3-Clause "New" or "Revised" License
def fprop(self, x, **kwargs):
        del kwargs
        my_dense = functools.partial(
            tf.layers.dense, kernel_initializer=HeReLuNormalInitializer)
        with tf.variable_scope(self.scope, reuse=tf.AUTO_REUSE):
            y = tf.layers.flatten(x)
            y = my_dense(y, self.nb_filters, activation=tf.nn.relu)
            y = my_dense(y, self.nb_filters, activation=tf.nn.relu)
            logits = my_dense(y, self.nb_classes)
            return {self.O_LOGITS: logits,
                    self.O_PROBS: tf.nn.softmax(logits=logits)} 
Example #27
Source File: pipe.py    From multibootusb with GNU General Public License v2.0
def _get_pipe2_implementation():
    """Find the appropriate implementation for ``pipe2``.

Return a function implementing ``pipe2``."""
    if hasattr(os, 'pipe2'):
        return os.pipe2 # pylint: disable=no-member
    else:
        try:
            libc = load_ctypes_library("libc", SIGNATURES, ERROR_CHECKERS)
            return (partial(_pipe2_ctypes, libc)
                    if hasattr(libc, 'pipe2') else
                    _pipe2_by_pipe)
        except ImportError:
            return _pipe2_by_pipe 
Example #28
Source File: param_rewrite.py    From multibootusb with GNU General Public License v2.0
def add_tokens(*tokens):
    return partial(op_add_tokens, tokens) 
Example #29
Source File: param_rewrite.py    From multibootusb with GNU General Public License v2.0
def remove_tokens(*tokens):
    return partial(op_remove_tokens, tokens)
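op_add_tokens and op_remove_tokens are not shown in this excerpt; the partials pre-bind the token tuple so the returned callables can later be applied to a parameter list. A purely illustrative usage sketch under that assumption:

existing_params = ['ro', 'quiet']            # hypothetical parameter list
appender = add_tokens('splash')              # partial(op_add_tokens, ('splash',))
remover = remove_tokens('ro')                # partial(op_remove_tokens, ('ro',))
params = remover(appender(existing_params))  # presumably ['quiet', 'splash']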