Python collections.namedtuple() Examples

The following code examples show how to use collections.namedtuple(). They are taken from open source Python projects. You can vote up the examples you like or vote down the ones you don't like.

Example 1
Project: pyblish-win   Author: pyblish   File: _pslinux.py    GNU Lesser General Public License v3.0 6 votes vote down vote up
def _get_cputimes_fields():
    """Return the list of CPU-time field names supported by the running
    Linux kernel, determined from the first line of /proc/stat.

    The seven base fields are always present; 'steal', 'guest' and
    'guest_nice' appear only on newer kernels.
    """
    with open('/proc/stat', 'rb') as f:
        values = f.readline().split()[1:]
    fields = ['user', 'nice', 'system', 'idle', 'iowait', 'irq', 'softirq']
    # Optional columns paired with the minimum column count that implies them.
    optional = (
        (8, 'steal'),        # Linux >= 2.6.11
        (9, 'guest'),        # Linux >= 2.6.24
        (10, 'guest_nice'),  # Linux >= 3.2.0
    )
    for min_len, name in optional:
        if len(values) >= min_len:
            fields.append(name)
    return fields
Example 2
Project: pyblish-win   Author: pyblish   File: test_collections.py    GNU Lesser General Public License v3.0 6 votes vote down vote up
def test_tupleness(self):
        """Verify that a namedtuple instance behaves like a plain tuple."""
        Point = namedtuple('Point', 'x y')
        p = Point(11, 22)

        # Tuple protocol: identity, conversion, iteration, star-unpacking.
        self.assertIsInstance(p, tuple)
        self.assertEqual(p, (11, 22))
        self.assertEqual(tuple(p), (11, 22))
        self.assertEqual(list(p), [11, 22])
        self.assertEqual(max(p), 22)                # iterable
        self.assertEqual(max(*p), 22)               # star-able
        x, y = p
        self.assertEqual(p, (x, y))                 # unpacks like a tuple
        self.assertEqual((p[0], p[1]), (11, 22))    # indexable like a tuple
        self.assertRaises(IndexError, p.__getitem__, 3)

        # Named access mirrors the positional values; unknown names fail.
        self.assertEqual(p.x, x)
        self.assertEqual(p.y, y)
        self.assertRaises(AttributeError, eval, 'p.z', locals())
Example 3
Project: OpenAPS   Author: medicinexlab   File: oldpred.py    MIT License 6 votes vote down vote up
def _get_old_pred(bg_df, start_index, end_index, num_pred_minutes):
    """Extract the historical prediction series (eventualBG, IOB, COB, aCOB)
    from the BG dataframe and align each with the actual BG values it
    predicted, via _find_compare_array.

    Args:
        bg_df: dataframe of BG entries (schema defined by _get_raw_pred_array).
        start_index, end_index: row range of bg_df to scan.
        num_pred_minutes: how many minutes ahead the predictions look.

    Returns:
        Four namedtuples (eventual, iob, cob, acob), each produced by
        _find_compare_array.
    """
    #The number of 5 minute sections until the prediction (e.g. 30 minutes = 6 sections)
    # NOTE(review): '/' is float division under Python 3; an integer section
    # count looks intended (Python 2 semantics) -- confirm before porting.
    pred_array_index = num_pred_minutes / DATA_SPACING

    actual_bg_array, actual_bg_time_array, eventual_pred_array, eventual_pred_time_array, iob_pred_array, iob_pred_time_array, cob_pred_array, cob_pred_time_array, acob_pred_array, acob_pred_time_array = _get_raw_pred_array(bg_df, start_index, end_index, pred_array_index)

    # eventualBG is always compared 30 minutes out; the other three series use
    # the requested horizon.
    eventual_pred_data = _find_compare_array(actual_bg_array, actual_bg_time_array, eventual_pred_array, eventual_pred_time_array, 30)
    iob_pred_data = _find_compare_array(actual_bg_array, actual_bg_time_array, iob_pred_array, iob_pred_time_array, num_pred_minutes)
    cob_pred_data= _find_compare_array(actual_bg_array, actual_bg_time_array, cob_pred_array, cob_pred_time_array, num_pred_minutes)
    acob_pred_data = _find_compare_array(actual_bg_array, actual_bg_time_array, acob_pred_array, acob_pred_time_array, num_pred_minutes)

    return eventual_pred_data, iob_pred_data, cob_pred_data, acob_pred_data


#Plots old pred data given namedtuple of old data (eventualBG, acob, cob, or iob).
#Can show or save prediction plot based on show_pred_plot or save_pred_plot, respectively.
#Same goes for the Clarke Error grid with show_clarke_plot or save_clarke_plot, respectively.
#id_str, algorithm_str, minutes_str are strings of the ID, the prediction algorithm and the number of prediction minutes used for the title. 
Example 4
Project: pytuber   Author: tefra   File: test_services.py    MIT License 6 votes vote down vote up
def test_sync_with_user_friends_tracks(self, get_user, friends, *args):
        """get_tracks must collect each friend's recent track, in order."""
        get_user.return_value = self.get_user()
        Friend = namedtuple("Friend", ["recent_track"])
        friends.return_value = [Friend(recent_track=n) for n in (1, 2, 3)]

        actual = LastService.get_tracks(
            type=PlaylistType.USER_FRIENDS_RECENT_TRACKS.value,
            limit=10,
            username="foo",
        )

        # One track per friend; each collaborator called exactly once.
        self.assertEqual([1, 2, 3], actual)
        get_user.assert_called_once_with("foo")
        friends.assert_called_once_with(limit=10, recent_tracks=True)
Example 5
Project: iSDX   Author: sdn-ixp   File: replay.py    Apache License 2.0 6 votes vote down vote up
def __init__(self, config, flows_dir, ports_dir, num_timesteps, debug=False):
        """Parse the config and replay logs, building per-timestep counters.

        Args:
            config: configuration input consumed by self.parse_config.
            flows_dir: directory of flow logs (consumed by self.parse_logs).
            ports_dir: directory of port logs (consumed by self.parse_logs).
            num_timesteps: number of timesteps to replay.
            debug: when True, enable DEBUG-level logging for this instance.
        """
        self.logger = logging.getLogger("LogHistory")
        if debug:
            self.logger.setLevel(logging.DEBUG)

        # Record schema used for individual log entries.
        self.log_entry = namedtuple("LogEntry", "source destination type")
        self.ports = defaultdict(list)
        self.flows = defaultdict(list)

        # data[key1][key2][key3] -> int counter, auto-initialised to 0.
        self.data = defaultdict(lambda: defaultdict(lambda: defaultdict(int)))
        self.current_timestep = 0
        self.total_timesteps = num_timesteps

        # Eagerly ingest everything, then run self.info() and pretty-print
        # the collected data.
        self.parse_config(config)
        self.parse_logs(num_timesteps, flows_dir, ports_dir)
        self.info()

        pretty(self.data)
Example 6
Project: neural-fingerprinting   Author: StephanZheng   File: submissions.py    BSD 3-Clause "New" or "Revised" License 6 votes vote down vote up
def __init__(self, datastore_client, storage_client, round_name):
    """Initializes CompetitionSubmissions.

    Args:
      datastore_client: instance of CompetitionDatastoreClient
      storage_client: instance of CompetitionStorageClient
      round_name: name of the round
    """
    self._datastore_client = datastore_client
    self._storage_client = storage_client
    self._round_name = round_name
    # each of the variables is a dictionary,
    # where key - submission ID
    # value - SubmissionDescriptor namedtuple
    self._attacks = None
    self._targeted_attacks = None
    self._defenses = None 
Example 7
Project: Ansible-Example-AB2018   Author: umit-ozturk   File: poolmanager.py    MIT License 6 votes vote down vote up
def connection_from_pool_key(self, pool_key, request_context=None):
        """
        Get a :class:`ConnectionPool` based on the provided pool key.

        ``pool_key`` should be a namedtuple that only contains immutable
        objects. At a minimum it must have the ``scheme``, ``host``, and
        ``port`` fields.
        """
        with self.pools.lock:
            # Reuse an already-open pool for this key when one exists.
            existing = self.pools.get(pool_key)
            if existing:
                return existing

            # Otherwise build a new pool from the request context and cache
            # it under the key for future lookups.
            created = self._new_pool(
                request_context['scheme'],
                request_context['host'],
                request_context['port'],
                request_context=request_context,
            )
            self.pools[pool_key] = created

        return created
Example 8
Project: Ansible-Example-AB2018   Author: umit-ozturk   File: bigip_vcmp_guest.py    MIT License 6 votes vote down vote up
def mgmt_tuple(self):
        """Parse ``mgmt_address`` into a ``Destination(ip, subnet)`` namedtuple.

        Accepts either ``ip/subnet`` or a bare ``ip`` (subnet becomes None).
        Returns Destination(None, None) when splitting raises ValueError.

        Raises:
            F5ModuleError: when the address contains more than one '/'.
        """
        result = None
        Destination = namedtuple('Destination', ['ip', 'subnet'])
        try:
            parts = self._values['mgmt_address'].split('/')
            if len(parts) == 2:
                result = Destination(ip=parts[0], subnet=parts[1])
            elif len(parts) < 2:
                result = Destination(ip=parts[0], subnet=None)
            else:
                # Bug fix: the error object was constructed but never raised,
                # so a malformed address silently fell through to None.
                raise F5ModuleError(
                    "The provided mgmt_address is malformed."
                )
        except ValueError:
            result = Destination(ip=None, subnet=None)
        return result
Example 9
Project: DJFeet   Author: libre-man   File: test_communicators.py    MIT License 6 votes vote down vote up
def test_protocol_communicator_iteration(protocol_communicator, monkeypatch,
                                         _):
    """iteration() must POST the mixed song's filename and the session id."""
    filename = '%dsong_location' % random.randint(0, 1000)
    file_loc = '/tmp/#sdaas_only/%s.mp3' % filename
    my_remote = '%dremote' % random.randint(1000, 2000)
    my_id = '%did' % random.randint(1000, 2000)

    a_song = namedtuple('my_song', 'file_location')(file_loc)
    mocked_post_request = MockingFunction()
    monkeypatch.setattr(requests, 'post', mocked_post_request)

    protocol_communicator.iteration(my_remote, my_id, a_song)

    # Exactly one POST to <remote>/iteration/ carrying the id and the
    # extension-less basename of the mixed song.
    assert mocked_post_request.called
    expected_call = ((my_remote + '/iteration/', ), {
        'json': {
            'id': my_id,
            'filename_mixed': filename,
        }
    })
    assert mocked_post_request.args == [expected_call]
Example 10
Project: dynamic-training-with-apache-mxnet-on-aws   Author: awslabs   File: test_module.py    Apache License 2.0 6 votes vote down vote up
def test_forward_types():
    """Forward/predict must accept several data-batch input types."""
    # Test forward with a namedtuple-based data batch API.
    Batch = namedtuple('Batch', ['data'])
    data = mx.sym.Variable('data')
    out = data * 2
    mod = mx.mod.Module(symbol=out, label_names=None)
    mod.bind(data_shapes=[('data', (1, 10))])
    mod.init_params()
    data1 = [mx.nd.ones((1, 10))]
    mod.forward(Batch(data1))
    assert mod.get_outputs()[0].shape == (1, 10)
    # A differently-shaped batch must be accepted after re-binding internally.
    data2 = [mx.nd.ones((3, 5))]
    mod.forward(Batch(data2))
    assert mod.get_outputs()[0].shape == (3, 5)

    # Test forward with plain NDArray and np.ndarray inputs.
    data = mx.sym.Variable('data')
    out = data * 2
    mod = mx.mod.Module(symbol=out, label_names=None)
    mod.bind(data_shapes=[('data', (1, 10))])
    mod.init_params()
    data1 = mx.nd.ones((1, 10))
    assert mod.predict(data1).shape == (1, 10)
    data2 = np.ones((1, 10))
    # Bug fix: this previously predicted with data1 again, leaving the
    # np.ndarray input path untested.
    assert mod.predict(data2).shape == (1, 10)
Example 11
Project: DOTA_models   Author: ringringyi   File: data_utils.py    Apache License 2.0 6 votes vote down vote up
def read_names(names_path):
    """Load a baby-names CSV and aggregate counts per lower-cased name.

    See SmallNames.txt for an example of the format, or go to
    https://www.kaggle.com/kaggle/us-baby-names for full lists.

    Args:
        names_path: path to a csv file with 'Name' and 'Count' columns.
    Returns:
        Dataset: a namedtuple of two numpy arrays -- deduped lower-case
            names and their summed counts.
    """
    Dataset = collections.namedtuple('Dataset', ['Name', 'Count'])

    frame = pd.read_csv(names_path)
    frame.Name = frame.Name.str.lower()

    # Sum the counts of rows that collapse to the same lower-case name.
    grouped = frame.groupby(by=["Name"])["Count"].sum()
    counts = np.array(grouped.tolist())
    deduped = np.array(grouped.index.tolist())

    return Dataset(deduped, counts)
Example 12
Project: DOTA_models   Author: ringringyi   File: model_deploy.py    Apache License 2.0 6 votes vote down vote up
def _optimize_clone(optimizer, clone, num_clones, regularization_losses,
                    **kwargs):
  """Compute the loss and gradients for a single clone.

  Args:
    optimizer: A tf.Optimizer object.
    clone: A Clone namedtuple.
    num_clones: The number of clones being deployed.
    regularization_losses: Possibly empty list of regularization_losses
      to add to the clone losses.
    **kwargs: Dict of kwarg to pass to compute_gradients().

  Returns:
    A tuple (clone_loss, clone_grads_and_vars).
      - clone_loss: A tensor for the total loss for the clone.  Can be None.
      - clone_grads_and_vars: List of (gradient, variable) for the clone.
        Can be empty.
  """
  total = _gather_clone_loss(clone, num_clones, regularization_losses)
  if total is None:
    # Nothing to optimize for this clone.
    return None, None
  # Place the gradient computation on the clone's own device.
  with tf.device(clone.device):
    grads_and_vars = optimizer.compute_gradients(total, **kwargs)
  return total, grads_and_vars
Example 13
Project: DOTA_models   Author: ringringyi   File: model.py    Apache License 2.0 6 votes vote down vote up
def create_loss(self, data, endpoints):
    """Creates all losses required to train the model.

    Args:
      data: InputEndpoints namedtuple.
      endpoints: Model namedtuple.

    Returns:
      Total loss.
    """
    # NOTE: the return value of ModelLoss is not used directly for the
    # gradient computation because under the hood it calls slim.losses.AddLoss,
    # which registers the loss in an internal collection and later returns it
    # as part of GetTotalLoss. We need to use total loss because model may have
    # multiple losses including regularization losses.
    # Registers the char-sequence loss as a side effect (see NOTE above).
    self.sequence_loss_fn(endpoints.chars_logit, data.labels)
    total_loss = slim.losses.get_total_loss()
    # Export the scalar so training can be monitored in TensorBoard.
    tf.summary.scalar('TotalLoss', total_loss)
    return total_loss
Example 14
Project: DOTA_models   Author: ringringyi   File: fsns_test.py    Apache License 2.0 6 votes vote down vote up
def test_decodes_example_proto(self):
    """A serialized example proto must decode back to the original values."""
    # NOTE(review): range(...) values are Python 2 style here; confirm the
    # serialization helper accepts range objects under Python 3.
    expected_label = range(37)
    expected_image, encoded = unittest_utils.create_random_image(
        'PNG', shape=(150, 600, 3))
    serialized = unittest_utils.create_serialized_example({
        'image/encoded': [encoded],
        'image/format': ['PNG'],
        'image/class':
        expected_label,
        'image/unpadded_class':
        range(10),
        'image/text': ['Raw text'],
        'image/orig_width': [150],
        'image/width': [600]
    })

    # Decode through the dataset's decoder, wrapping the outputs in a
    # namedtuple keyed by the decoder's item names for readable access.
    decoder = fsns.get_split('train', dataset_dir()).decoder
    with self.test_session() as sess:
      data_tuple = collections.namedtuple('DecodedData', decoder.list_items())
      data = sess.run(data_tuple(*decoder.decode(serialized)))

    self.assertAllEqual(expected_image, data.image)
    self.assertAllEqual(expected_label, data.label)
    self.assertEqual(['Raw text'], data.text)
    self.assertEqual([1], data.num_of_views)
Example 15
Project: soccer-matlab   Author: utra-robosoccer   File: minitaur_reactive_env.py    BSD 2-Clause "Simplified" License 6 votes vote down vote up
def _reset(self):
    # TODO(b/73666007): Use composition instead of inheritance.
    # (http://go/design-for-testability-no-inheritance).
    # Every leg starts from the same nominal swing/extension position.
    pose_kwargs = {}
    for leg in range(1, 5):
        pose_kwargs['swing_angle_%d' % leg] = INIT_SWING_POS
        pose_kwargs['extension_angle_%d' % leg] = INIT_EXTENSION_POS
    init_pose = MinitaurPose(**pose_kwargs)
    # TODO(b/73734502): Refactor input of _convert_from_leg_model to namedtuple.
    initial_motor_angles = self._convert_from_leg_model(list(init_pose))
    super(MinitaurReactiveEnv, self)._reset(
        initial_motor_angles=initial_motor_angles, reset_duration=0.5)
    return self._get_observation()
Example 16
Project: sic   Author: Yanixos   File: poolmanager.py    GNU General Public License v3.0 6 votes vote down vote up
def connection_from_pool_key(self, pool_key):
        """
        Get a :class:`ConnectionPool` based on the provided pool key.

        ``pool_key`` should be a namedtuple that only contains immutable
        objects. At a minimum it must have the ``scheme``, ``host``, and
        ``port`` fields.
        """
        with self.pools.lock:
            # Serve from the cache of open pools when possible.
            cached = self.pools.get(pool_key)
            if cached:
                return cached

            # Create, cache and hand back a brand-new pool for this key.
            created = self._new_pool(pool_key.scheme, pool_key.host,
                                     pool_key.port)
            self.pools[pool_key] = created

        return created
Example 17
Project: fs_image   Author: facebookincubator   File: enriched_namedtuple.py    MIT License 5 votes vote down vote up
def _assert_all_fields_constructible(class_name, field_to_value):
    'We do not check that all fields are set, since namedtuple will.'
    # Fail fast on the first field the customize hook left unconstructed.
    for field_name in field_to_value:
        if field_to_value[field_name] is NonConstructibleField:
            raise AssertionError(
                'customize_fields_fn for {} failed to construct field {}'
                .format(class_name, field_name)
            )
    return field_to_value
Example 18
Project: leapp-repository   Author: oamg   File: test_forcedefaultboot.py    Apache License 2.0 5 votes vote down vote up
def __init__(self, case):
        # Pick the architecture under test, then expose it the same way the
        # real actor configuration would.
        if case.arch_s390x:
            arch = architecture.ARCH_S390X
        else:
            arch = architecture.ARCH_X86_64
        self.configuration = namedtuple('configuration', ['architecture'])(arch)
Example 19
Project: leapp-repository   Author: oamg   File: test_kernelcmdlineconfig.py    Apache License 2.0 5 votes vote down vote up
def __init__(self, arch):
        self.configuration = namedtuple('configuration', ['architecture'])(arch) 
Example 20
Project: leapp-repository   Author: oamg   File: unit_test.py    Apache License 2.0 5 votes vote down vote up
def test_actor(monkeypatch):
    """library.check() must report an error for an inhibitor title."""
    def consume_mocked(*models):
        yield namedtuple('msg', ['report'])(Report('title_with_inhibitor'))

    error_sink = ReportError()
    monkeypatch.setattr(api, "consume", consume_mocked)
    monkeypatch.setattr(api, "report_error", error_sink.set)
    library.check()
    # The inhibitor title must be propagated exactly once.
    assert error_sink.message == 'title_with_inhibitor'
    assert error_sink.called == 1
Example 21
Project: leapp-repository   Author: oamg   File: unit_test.py    Apache License 2.0 5 votes vote down vote up
def test_actor_no_inhibitor(monkeypatch):
    """library.check() must stay silent when no inhibitor is present."""
    def consume_mocked(*models):
        yield namedtuple('msg', ['report'])(Report('title_without_inhibitor'))

    error_sink = ReportError()
    monkeypatch.setattr(api, "consume", consume_mocked)
    monkeypatch.setattr(api, "report_error", error_sink.set)
    library.check()
    # No error may be recorded for a non-inhibitor report.
    assert not error_sink.message
    assert error_sink.called == 0
Example 22
Project: leapp-repository   Author: oamg   File: unit_test.py    Apache License 2.0 5 votes vote down vote up
def __init__(self, arch):
        self.configuration = namedtuple('configuration', ['architecture'])(arch) 
Example 23
Project: leapp-repository   Author: oamg   File: unit_test.py    Apache License 2.0 5 votes vote down vote up
def __init__(self, arch):
        self.configuration = namedtuple('configuration', ['architecture'])(arch) 
Example 24
Project: leapp-repository   Author: oamg   File: unit_test.py    Apache License 2.0 5 votes vote down vote up
def __init__(self, arch):
        self.configuration = namedtuple('configuration', ['architecture'])(arch) 
Example 25
Project: leapp-repository   Author: oamg   File: test_checkmemory.py    Apache License 2.0 5 votes vote down vote up
def __init__(self, arch):
        self.configuration = namedtuple('configuration', ['architecture'])(arch) 
Example 26
Project: leapp-repository   Author: oamg   File: test_check_cpu.py    Apache License 2.0 5 votes vote down vote up
def __init__(self, arch):
        self.configuration = namedtuple('configuration', ['architecture'])(arch) 
Example 27
Project: leapp-repository   Author: oamg   File: test_checksystemarch.py    Apache License 2.0 5 votes vote down vote up
def test_valid_architectures(monkeypatch):
    """No report may be created when the architecture is an accepted one."""
    class CurrentActorMocked(object):
        configuration = namedtuple('configuration', ['architecture'])(
            architecture.ARCH_ACCEPTED[0])

    monkeypatch.setattr(reporting, "create_report", create_report_mocked())
    monkeypatch.setattr(api, 'current_actor', CurrentActorMocked)

    library.check_architecture()

    # An accepted architecture must not trigger any report.
    assert reporting.create_report.called == 0
Example 28
Project: pyblish-win   Author: pyblish   File: _psbsd.py    GNU Lesser General Public License v3.0 5 votes vote down vote up
def virtual_memory():
    """System virtual memory as a namedtuple."""
    (total, free, active, inactive, wired, cached, buffers,
     shared) = cext.virtual_mem()
    # "avail" = memory obtainable without swapping; "used" = memory actively
    # held by the system (active + wired + cached).
    avail = inactive + cached + free
    used = active + wired + cached
    percent = usage_percent(total - avail, total, _round=1)
    return svmem(total, avail, percent, used, free,
                 active, inactive, buffers, cached, shared, wired)
Example 29
Project: pyblish-win   Author: pyblish   File: _psbsd.py    GNU Lesser General Public License v3.0 5 votes vote down vote up
def swap_memory():
    """System swap memory as (total, used, free, sin, sout) namedtuple."""
    # Values come back as page counts; scale to bytes with PAGESIZE.
    total, used, free, sin, sout = (x * PAGESIZE for x in cext.swap_mem())
    percent = usage_percent(used, total, _round=1)
    return _common.sswap(total, used, free, percent, sin, sout)
Example 30
Project: pyblish-win   Author: pyblish   File: _psbsd.py    GNU Lesser General Public License v3.0 5 votes vote down vote up
def cpu_times():
    """Return system-wide CPU times as a namedtuple."""
    # The C extension exposes exactly five fields on this platform.
    times = cext.cpu_times()
    user, nice, system, idle, irq = times
    return scputimes(user, nice, system, idle, irq)
Example 31
Project: pyblish-win   Author: pyblish   File: _psbsd.py    GNU Lesser General Public License v3.0 5 votes vote down vote up
def per_cpu_times():
        """Return per-CPU times, one scputimes namedtuple per logical CPU."""
        return [scputimes(user, nice, system, idle, irq)
                for (user, nice, system, idle, irq) in cext.per_cpu_times()]
Example 32
Project: pyblish-win   Author: pyblish   File: _pswindows.py    GNU Lesser General Public License v3.0 5 votes vote down vote up
def virtual_memory():
    """System virtual memory as a namedtuple."""
    (totphys, availphys, totpagef, availpagef, totvirt,
     freevirt) = cext.virtual_mem()
    # Here "free" and "avail" are both the available physical memory;
    # "used" is whatever physical memory is not available.
    total = totphys
    avail = availphys
    free = availphys
    used = total - avail
    percent = usage_percent(total - avail, total, _round=1)
    return svmem(total, avail, percent, used, free)
Example 33
Project: pyblish-win   Author: pyblish   File: _psosx.py    GNU Lesser General Public License v3.0 5 votes vote down vote up
def virtual_memory():
    """System virtual memory as a namedtuple."""
    total, active, inactive, wired, free = cext.virtual_mem()
    # Inactive pages are counted as available alongside free ones.
    avail = inactive + free
    used = active + inactive + wired
    percent = usage_percent(total - avail, total, _round=1)
    return svmem(total, avail, percent, used, free,
                 active, inactive, wired)
Example 34
Project: pyblish-win   Author: pyblish   File: _psosx.py    GNU Lesser General Public License v3.0 5 votes vote down vote up
def cpu_times():
    """Return system CPU times as a namedtuple."""
    # The C extension exposes four fields on this platform (no irq).
    times = cext.cpu_times()
    user, nice, system, idle = times
    return scputimes(user, nice, system, idle)
Example 35
Project: pyblish-win   Author: pyblish   File: test_collections.py    GNU Lesser General Public License v3.0 5 votes vote down vote up
def test_factory(self):
        """Exercise namedtuple() class creation and its input validation."""
        Point = namedtuple('Point', 'x y')
        self.assertEqual(Point.__name__, 'Point')
        self.assertEqual(Point.__slots__, ())
        self.assertEqual(Point.__module__, __name__)
        self.assertEqual(Point.__getitem__, tuple.__getitem__)
        self.assertEqual(Point._fields, ('x', 'y'))

        # Invalid typenames and field names must all be rejected.
        bad_specs = [
            ('abc%', 'efg ghi'),     # type has non-alpha char
            ('class', 'efg ghi'),    # type has keyword
            ('9abc', 'efg ghi'),     # type starts with digit
            ('abc', 'efg g%hi'),     # field with non-alpha char
            ('abc', 'abc class'),    # field has keyword
            ('abc', '8efg 9ghi'),    # field starts with digit
            ('abc', '_efg ghi'),     # field with leading underscore
            ('abc', 'efg efg ghi'),  # duplicate field
        ]
        for typename, field_spec in bad_specs:
            self.assertRaises(ValueError, namedtuple, typename, field_spec)

        namedtuple('Point0', 'x1 y2')   # numbers are allowed in names
        namedtuple('_', 'a b c')        # leading underscore in a typename

        # Unicode specs must not leak u'' markers into the field names.
        nt = namedtuple('nt', u'the quick brown fox')
        self.assertNotIn("u'", repr(nt._fields))
        nt = namedtuple('nt', (u'the', u'quick'))
        self.assertNotIn("u'", repr(nt._fields))

        self.assertRaises(TypeError, Point._make, [11])          # too few args
        self.assertRaises(TypeError, Point._make, [11, 22, 33])  # too many args
Example 36
Project: pyblish-win   Author: pyblish   File: test_collections.py    GNU Lesser General Public License v3.0 5 votes vote down vote up
def test_factory_doc_attr(self):
        """The generated docstring must mirror the constructor signature."""
        point_cls = namedtuple('Point', 'x y')
        self.assertEqual(point_cls.__doc__, 'Point(x, y)')
Example 37
Project: pyblish-win   Author: pyblish   File: test_collections.py    GNU Lesser General Public License v3.0 5 votes vote down vote up
def test_instance(self):
        """Exercise instantiation, helper methods and repr of a namedtuple."""
        Point = namedtuple('Point', 'x y')
        p = Point(11, 22)
        # All equivalent construction spellings produce the same value.
        self.assertEqual(p, Point(x=11, y=22))
        self.assertEqual(p, Point(11, y=22))
        self.assertEqual(p, Point(y=22, x=11))
        self.assertEqual(p, Point(*(11, 22)))
        self.assertEqual(p, Point(**dict(x=11, y=22)))
        self.assertRaises(TypeError, Point, 1)                              # too few args
        self.assertRaises(TypeError, Point, 1, 2, 3)                        # too many args
        self.assertRaises(TypeError, eval, 'Point(XXX=1, y=2)', locals())   # wrong keyword argument
        self.assertRaises(TypeError, eval, 'Point(x=1)', locals())          # missing keyword argument
        self.assertEqual(repr(p), 'Point(x=11, y=22)')
        self.assertNotIn('__weakref__', dir(p))
        self.assertEqual(p, Point._make([11, 22]))                          # test _make classmethod
        self.assertEqual(p._fields, ('x', 'y'))                             # test _fields attribute
        self.assertEqual(p._replace(x=1), (1, 22))                          # test _replace method
        self.assertEqual(p._asdict(), dict(x=11, y=22))                     # test _asdict method
        self.assertEqual(vars(p), p._asdict())                              # verify that vars() works

        try:
            p._replace(x=1, error=2)
        except ValueError:
            pass
        else:
            # Bug fix: TestCase has no _fail() method; use fail() so a missed
            # ValueError is reported as a test failure, not an AttributeError.
            self.fail('Did not detect an incorrect fieldname')

        # verify that field string can have commas
        Point = namedtuple('Point', 'x, y')
        p = Point(x=11, y=22)
        self.assertEqual(repr(p), 'Point(x=11, y=22)')

        # verify that fieldspec can be a non-string sequence
        Point = namedtuple('Point', ('x', 'y'))
        p = Point(x=11, y=22)
        self.assertEqual(repr(p), 'Point(x=11, y=22)')
Example 38
Project: pyblish-win   Author: pyblish   File: test_collections.py    GNU Lesser General Public License v3.0 5 votes vote down vote up
def test_odd_sizes(self):
        """namedtuple must work for 0 fields, 1 field, and thousands."""
        # Zero-field tuple.
        Zero = namedtuple('Zero', '')
        self.assertEqual(Zero(), ())
        self.assertEqual(Zero._make([]), ())
        self.assertEqual(repr(Zero()), 'Zero()')
        self.assertEqual(Zero()._asdict(), {})
        self.assertEqual(Zero()._fields, ())

        # Single-field tuple.
        Dot = namedtuple('Dot', 'd')
        self.assertEqual(Dot(1), (1,))
        self.assertEqual(Dot._make([1]), (1,))
        self.assertEqual(Dot(1).d, 1)
        self.assertEqual(repr(Dot(1)), 'Dot(d=1)')
        self.assertEqual(Dot(1)._asdict(), {'d':1})
        self.assertEqual(Dot(1)._replace(d=999), (999,))
        self.assertEqual(Dot(1)._fields, ('d',))

        # Very wide tuple built from random unique identifiers.
        import string, random
        n = 5000
        names = list(set(''.join([random.choice(string.ascii_letters)
                                  for j in range(10)]) for i in range(n)))
        n = len(names)  # set() may have dropped a few duplicates
        Big = namedtuple('Big', names)
        b = Big(*range(n))
        self.assertEqual(b, tuple(range(n)))
        self.assertEqual(Big._make(range(n)), tuple(range(n)))
        for pos, name in enumerate(names):
            self.assertEqual(getattr(b, name), pos)
        repr(b)                                 # repr() must not blow up
        self.assertEqual(b._asdict(), dict(zip(names, range(n))))
        # _replace on fields picked from both ends of the name list.
        b2 = b._replace(**{names[1]: 999, names[-5]: 42})
        b2_expected = list(range(n))
        b2_expected[1] = 999
        b2_expected[-5] = 42
        self.assertEqual(b2, tuple(b2_expected))
        self.assertEqual(b._fields, tuple(names))
Example 39
Project: pyblish-win   Author: pyblish   File: test_collections.py    GNU Lesser General Public License v3.0 5 votes vote down vote up
def test_name_conflicts(self):
        """Field names that clash with implementation names must work."""
        # Names like "self", "cls", "tuple", "itemgetter" and "property"
        # used to fail when used as field names.
        T = namedtuple('T', 'itemgetter property self cls tuple')
        t = T(1, 2, 3, 4, 5)
        self.assertEqual(t, (1, 2, 3, 4, 5))
        newt = t._replace(itemgetter=10, property=20, self=30, cls=40, tuple=50)
        self.assertEqual(newt, (10, 20, 30, 40, 50))

        # Broader check: harvest every identifier appearing in the generated
        # class template and use all of them as field names at once.
        with test_support.captured_stdout() as template:
            T = namedtuple('T', 'x', verbose=True)
        words = set(re.findall('[A-Za-z]+', template.getvalue()))
        words -= set(keyword.kwlist)
        T = namedtuple('T', words)
        values = tuple(range(len(words)))

        # __new__ via positional and keyword arguments.
        t = T(*values)
        self.assertEqual(t, values)
        t = T(**dict(zip(T._fields, values)))
        self.assertEqual(t, values)
        # _make
        t = T._make(values)
        self.assertEqual(t, values)
        # __repr__ must not blow up.
        repr(t)
        # _asdict
        self.assertEqual(t._asdict(), dict(zip(T._fields, values)))
        # _replace
        t = T._make(values)
        newvalues = tuple(v * 10 for v in values)
        newt = t._replace(**dict(zip(T._fields, newvalues)))
        self.assertEqual(newt, newvalues)
        # _fields
        self.assertEqual(T._fields, tuple(words))
        # __getnewargs__
        self.assertEqual(t.__getnewargs__(), values)
Example 40
Project: pyblish-win   Author: pyblish   File: test_pydoc.py    GNU Lesser General Public License v3.0 5 votes vote down vote up
def test_namedtuple_public_underscore(self):
        """pydoc help output must cover renamed fields and _-methods."""
        NT = namedtuple('NT', ['abc', 'def'], rename=True)  # 'def' -> '_1'
        with captured_stdout() as help_io:
            pydoc.help(NT)
        helptext = help_io.getvalue()
        # The renamed field and the private helper methods all show up.
        for expected in ('_1', '_replace', '_asdict'):
            self.assertIn(expected, helptext)
Example 41
Project: backtrader-cn   Author: pandalibin   File: sina.py    GNU General Public License v3.0 5 votes vote down vote up
def _json_object_hook(d):
    class_name = d.pop('_class_name', 'NamedTuple')
    return namedtuple(class_name, d.keys())(*d.values()) 
Example 42
Project: OpenAPS   Author: medicinexlab   File: oldpred.py    MIT License 5 votes vote down vote up
def _find_nearest_index(array, value):
    """Index of the element of `array` closest to `value`, or -1 when even
    the closest element is ACTUAL_BG_RANGE or more away."""
    nearest_index = (np.abs(array - value)).argmin()
    distance = int(np.abs(array[nearest_index] - value))
    if distance < ACTUAL_BG_RANGE:
        return nearest_index
    return -1


#Given the actual_bg_array, actual_bg_time_array, pred_array, pred_time_array, and num_pred_minutes,
#this function finds the nearest actual bg value to compare to the prediction value.
#If there is one, then it adds all the values to the result arrays, which are returned as a namedtuple.
#Returns the arrays such that the predBG corresponds to the actualBG in NUM_PRED_MINUTES in the future 
Example 43
Project: OpenAPS   Author: medicinexlab   File: oldpred.py    MIT License 5 votes vote down vote up
def _find_compare_array(actual_bg_array, actual_bg_time_array, pred_array, pred_time_array, num_pred_minutes):
    """Pair each prediction with the actual BG reading nearest to the
    time it predicts for (prediction time + num_pred_minutes).

    Predictions with no actual reading close enough are dropped; the
    four result arrays are returned trimmed to the kept entries, as an
    OldPredData namedtuple.
    """
    total = len(pred_array)

    matched_bg = np.zeros(total)
    matched_bg_time = np.zeros(total)
    matched_pred = np.zeros(total)
    matched_pred_time = np.zeros(total)

    kept = 0
    for i in range(total):
        # The time this prediction is predicting for.
        target_time = int(pred_time_array[i]) + num_pred_minutes
        nearest = _find_nearest_index(actual_bg_time_array, target_time)
        if nearest == -1:
            continue  # no corresponding bg reading for this prediction
        matched_bg[kept] = actual_bg_array[nearest]
        matched_bg_time[kept] = actual_bg_time_array[nearest]
        matched_pred[kept] = pred_array[i]
        matched_pred_time[kept] = target_time
        kept += 1

    # Trim the unused tail left by the misses.
    matched_bg = np.resize(matched_bg, kept)
    matched_bg_time = np.resize(matched_bg_time, kept)
    matched_pred = np.resize(matched_pred, kept)
    matched_pred_time = np.resize(matched_pred_time, kept)

    # Namedtuple to hold the data.
    OldPredData = namedtuple('OldPredData', ['result_actual_bg_array', 'result_actual_bg_time_array', 'result_pred_array', 'result_pred_time_array'])

    return OldPredData(matched_bg, matched_bg_time, matched_pred, matched_pred_time)


#This function takes in the bg dataframe, the start and end indices, and the number of minutes in the future
#that you want to make a prediction for AKA prediction horizon (e.g. make a prediction for 30 minutes in the future).
#It returns the namedtuple with the following attributes ['result_actual_bg_array', 'result_actual_bg_time_array', 'result_pred_array', 'result_pred_time_array'] 
Example 44
Project: pytuber   Author: tefra   File: test_params.py    MIT License 5 votes vote down vote up
def test_convert_successful(self, get_user):
    # Stub the user lookup to hand back a minimal record type.
    UserRecord = namedtuple("User", ["name"])
    get_user.return_value = UserRecord(name="Rj")

    converted = self.param.convert("rj", None, None)

    self.assertEqual("Rj", converted)
    get_user.assert_called_once_with("rj")
Example 45
Project: kuaa   Author: rafaelwerneck   File: application.py    GNU General Public License v3.0 5 votes vote down vote up
def update_temporary_link(self):
    """Routine for drawing a temporary link to the mouse position.

    Re-schedules itself every 30 ms (via Tk's ``after``) while a link is
    being dragged from ``self._current_link_block``; stops once that
    block is cleared.
    """
    # Drop the previous temporary link, if any, before redrawing.
    if self._currentMousePositionPair:
        self._links.remove(self._currentMousePositionPair)
        self._currentMousePositionPair = None

    if not self._current_link_block:
        # No link is being dragged any more: draw one last time and
        # break the after() loop.
        self.draw_links()
        return

    # Pointer position relative to this widget.
    mousePos = self.winfo_pointerxy()
    mousePos = (mousePos[0] - self.winfo_rootx(),
                mousePos[1] - self.winfo_rooty())

    if (mousePos[0] >= self._current_link_block.x and
            mousePos[0] <= self._current_link_block.x + config.BLOCK_SIZE and
            mousePos[1] >= self._current_link_block.y and
            mousePos[1] <= self._current_link_block.y + config.BLOCK_SIZE):
        # Pointer is still inside the source block: nothing to draw yet.
        # (_currentMousePositionPair was already cleared above and is not
        # reassigned between there and here, so the extra remove/None
        # re-check the original did in this branch was dead code.)
        self.draw_links()
        self.after(30, self.update_temporary_link)
        return

    # Fake a mouse "block" with x and y attributes so draw_links() can
    # treat the pointer like any other block endpoint.
    MouseBlock = namedtuple('MouseBlock', ['x', 'y'])
    mouse_block = MouseBlock(mousePos[0], mousePos[1])

    self._currentMousePositionPair = (self._current_link_block,
                                      mouse_block, False)
    self._links.append(self._currentMousePositionPair)

    self.draw_links()
    self.after(30, self.update_temporary_link)
Example 46
Project: neural-fingerprinting   Author: StephanZheng   File: run_multigpu.py    BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
def main(argv=None):
    """Collect every registered flag into an HParams namedtuple and
    hand it to the trainer."""
    flag_values = {name: flags.FLAGS[name].value for name in dir(flags.FLAGS)}
    HParams = namedtuple('HParams', flag_values.keys())
    run_trainer(HParams(**flag_values))
Example 47
Project: Ansible-Example-AB2018   Author: umit-ozturk   File: bigip_virtual_server.py    MIT License 5 votes vote down vote up
def destination_tuple(self):
    """Return the configured destination as a Destination namedtuple
    (ip, port, route_domain); all-None when no destination is set."""
    Destination = namedtuple('Destination', ['ip', 'port', 'route_domain'])
    destination = self._values['destination']
    if destination is None:
        return Destination(ip=None, port=None, route_domain=None)
    # Strip any '%<route-domain>' suffix from the address portion.
    address = destination.split("%")[0]
    return Destination(ip=address, port=self.port,
                       route_domain=self.route_domain)
Example 48
Project: friendly-telegram   Author: friendly-telegram   File: main.py    GNU Affero General Public License v3.0 5 votes vote down vote up
def get_api_token():
    """Get API Token from disk or environment.

    Tries, in order: the package-local ``api_token`` module, then the
    ``api_id``/``api_hash`` environment variables.  When neither is
    available it calls ``run_config({})`` and loops, so this blocks
    until a token can be produced.

    Returns:
        An object exposing ``ID`` and ``HASH`` attributes (either the
        imported module or an equivalent namedtuple instance).
    """
    while True:
        try:
            # Preferred source: a module written to disk by configuration.
            from . import api_token
        except ImportError:
            try:
                # Fall back to the environment, mirroring the module's
                # ID/HASH attribute interface with a namedtuple.
                api_token = collections.namedtuple("api_token", ["ID", "HASH"])(os.environ["api_id"],
                                                                                os.environ["api_hash"])
            except KeyError:
                # Neither source available: run the interactive config
                # step, then loop and try the sources again.
                run_config({})
            else:
                return api_token
        else:
            return api_token
Example 49
Project: interact   Author: dongshengmu   File: util.py    MIT License 5 votes vote down vote up
def return_o_e(func):
    """Decorator to convert return values to namedtuple(o, e).

    2-tuples returned by *func* are wrapped in ``Namedtuple_oe``; any
    other return value passes through untouched.
    """
    @wraps(func)
    def wrapper(*args, **kwargs):
        result = func(*args, **kwargs)
        if isinstance(result, tuple) and len(result) == 2:
            return Namedtuple_oe(*result)
        return result
    return wrapper
Example 50
Project: interact   Author: dongshengmu   File: util.py    MIT License 5 votes vote down vote up
def return_o_e_r(func):
    """Decorator to convert return values to namedtuple(o, e, r).

    3-tuples returned by *func* are wrapped in ``Namedtuple_oer``; any
    other return value passes through untouched.
    """
    @wraps(func)
    def wrapper(*args, **kwargs):
        result = func(*args, **kwargs)
        if isinstance(result, tuple) and len(result) == 3:
            return Namedtuple_oer(*result)
        return result
    return wrapper


## unit test 
Example 51
Project: interact   Author: dongshengmu   File: util.py    MIT License 5 votes vote down vote up
def return_o_e(func):
    """Decorator to convert return values to namedtuple(o, e).

    Wraps any 2-tuple result in ``Namedtuple_oe``; everything else is
    returned unchanged.
    """
    @wraps(func)
    def wrapped(*args, **kwargs):
        value = func(*args, **kwargs)
        is_pair = isinstance(value, tuple) and len(value) == 2
        return Namedtuple_oe(*value) if is_pair else value
    return wrapped
Example 52
Project: interact   Author: dongshengmu   File: util.py    MIT License 5 votes vote down vote up
def return_o_e_r(func):
    """Decorator to convert return values to namedtuple(o, e, r).

    Wraps any 3-tuple result in ``Namedtuple_oer``; everything else is
    returned unchanged.
    """
    @wraps(func)
    def wrapped(*args, **kwargs):
        value = func(*args, **kwargs)
        is_triple = isinstance(value, tuple) and len(value) == 3
        return Namedtuple_oer(*value) if is_triple else value
    return wrapped


## unit test 
Example 53
Project: core   Author: lifemapper   File: lmobj.py    GNU General Public License v3.0 5 votes vote down vote up
def _processGEOGCS(geocsStr):
        """
        @summary: Processes a geographic coordinate system's WKT into an object
        """
        GeoCS = namedtuple('GEOGCS', ['name', 'datum', 'spheroid', 'primeMeridian', 'unit'])
        Spheroid = namedtuple('Spheroid', ['name', 'semiAxisMajor', 'denomFlatRatio'])
        PrimeM = namedtuple('PrimeMeridian', ['name', 'longitude'])
        
        # Name
        name = geocsStr.split('"')[1]
        
        # Datum
        datumString = geocsStr.split('DATUM')[1].split('PRIMEM')[0]
        datum = datumString.split('"')[1]
        
        # Spheroid
        spheroidString = datumString.split('SPHEROID')[1]
        spheroidParts = spheroidString.split(',')
        spheroid = Spheroid(
                                  name=spheroidParts[0].split('"')[1],
                                  semiAxisMajor=float(spheroidParts[1]),
                                  denomFlatRatio=float(spheroidParts[2].split(']')[0])
                                 )
        
        # Prime Meridian
        pmString = geocsStr.split('PRIMEM')[1].split('UNIT')[0]
        primeM = PrimeM(name=pmString.split('"')[1],
                             longitude=float(pmString.split(',')[1].split(']')[0]))
    
        # Unit
        unit = geocsStr.split('UNIT')[1].split('"')[1]
        if unit.lower() == "metre": # Must match for EML
            unit = "meter"
        elif unit == "Degree":
            unit = "degree"
        
        ret = GeoCS(name, datum, spheroid, primeM, unit)
        return ret
    
    # .............................................................................. 
Example 54
Project: dynamic-training-with-apache-mxnet-on-aws   Author: awslabs   File: utils.py    Apache License 2.0 5 votes vote down vote up
def namedtuple_with_defaults(typename, field_names, default_values=()):
    """Create a namedtuple type whose fields have default values.

    Args:
        typename: Name of the generated namedtuple type.
        field_names: Field names, as accepted by ``collections.namedtuple``.
        default_values: Either a mapping of field name -> default, or a
            sequence of defaults applied to the leading fields.  Fields
            without an explicit default fall back to ``None``.

    Returns:
        The generated namedtuple class.
    """
    T = collections.namedtuple(typename, field_names)
    # Default every field to None first so a partial mapping/sequence of
    # defaults can still construct the prototype instance below.
    T.__new__.__defaults__ = (None,) * len(T._fields)
    # Use collections.abc.Mapping: the bare collections.Mapping alias was
    # deprecated in Python 3.3 and removed in Python 3.10.
    if isinstance(default_values, collections.abc.Mapping):
        prototype = T(**default_values)
    else:
        prototype = T(*default_values)
    T.__new__.__defaults__ = tuple(prototype)
    return T
Example 55
Project: dynamic-training-with-apache-mxnet-on-aws   Author: awslabs   File: utils.py    Apache License 2.0 5 votes vote down vote up
def zip_namedtuple(nt_list):
    """Accept a list of namedtuples (all of one type, or a single bare
    namedtuple) and return a dict mapping each field name to the list of
    that field's values across the inputs."""
    if not nt_list:
        return dict()
    if not isinstance(nt_list, list):
        nt_list = [nt_list]
    first = nt_list[0]
    # All entries must share the first entry's exact type.
    for nt in nt_list:
        assert type(nt) == type(first)
    zipped = {field: [] for field in first._asdict()}
    for nt in nt_list:
        for field, value in nt._asdict().items():
            zipped[field].append(value)
    return zipped
Example 56
Project: dynamic-training-with-apache-mxnet-on-aws   Author: awslabs   File: super_resolution.py    Apache License 2.0 5 votes vote down vote up
def perform_inference(sym, arg_params, aux_params, input_img, img_cb, img_cr):
    """Perform inference on image using mxnet."""
    metadata = onnx_mxnet.get_model_metadata('super_resolution.onnx')
    data_names = [entry[0] for entry in metadata.get('input_tensor_data')]

    # Build and bind an inference-only module around the imported symbol.
    mod = mx.mod.Module(symbol=sym, data_names=data_names, label_names=None)
    mod.bind(for_training=False,
             data_shapes=[(data_names[0], input_img.shape)])
    mod.set_params(arg_params=arg_params, aux_params=aux_params)

    # Forward pass.
    Batch = namedtuple('Batch', ['data'])
    mod.forward(Batch([mx.nd.array(input_img)]))

    # The network produces the luma channel; clamp to the valid 0..255
    # pixel range before converting back to an image.
    img_out_y = Image.fromarray(
        np.uint8(mod.get_outputs()[0][0][0].asnumpy().clip(0, 255)),
        mode='L')

    # Upscale the chroma channels to match and recombine into RGB.
    result_img = Image.merge(
        "YCbCr",
        [img_out_y,
         img_cb.resize(img_out_y.size, Image.BICUBIC),
         img_cr.resize(img_out_y.size, Image.BICUBIC)]).convert("RGB")

    output_img_dim = 672
    assert result_img.size == (output_img_dim, output_img_dim)
    LOGGER.info("Super Resolution example success.")
    result_img.save("super_res_output.jpg")
    return result_img
Example 57
Project: dynamic-training-with-apache-mxnet-on-aws   Author: awslabs   File: onnx_import_test.py    Apache License 2.0 5 votes vote down vote up
def test_bvlc_googlenet():
    """ Tests Googlenet model"""
    model_path, inputs, outputs = get_test_files('bvlc_googlenet')
    logging.info("Translating Googlenet model from ONNX to Mxnet")
    sym, arg_params, aux_params = onnx_mxnet.import_model(model_path)

    # Sanity-check the model metadata first.
    metadata = onnx_mxnet.get_model_metadata(model_path)
    assert len(metadata) == 2
    assert metadata.get('input_tensor_data')
    assert metadata.get('input_tensor_data') == [(u'data_0', (1, 3, 224, 224))]
    assert metadata.get('output_tensor_data')
    assert metadata.get('output_tensor_data') == [(u'prob_1', (1, 1000))]
    data_names = [entry[0] for entry in metadata.get('input_tensor_data')]

    Batch = namedtuple('Batch', ['data'])
    # Run a test for each test file.
    for input_data, output_data in zip(inputs, outputs):
        # Fresh inference-only module per test file.
        mod = mx.mod.Module(symbol=sym, data_names=data_names,
                            context=mx.cpu(), label_names=None)
        mod.bind(for_training=False,
                 data_shapes=[(data_names[0], input_data.shape)],
                 label_shapes=None)
        mod.set_params(arg_params=arg_params, aux_params=aux_params,
                       allow_missing=True, allow_extra=True)
        mod.forward(Batch([mx.nd.array(input_data)]), is_train=False)

        # Verify the results against the reference outputs.
        npt.assert_equal(mod.get_outputs()[0].shape, output_data.shape)
        npt.assert_almost_equal(output_data, mod.get_outputs()[0].asnumpy(), decimal=3)
    logging.info("Googlenet model conversion Successful")
Example 58
Project: dynamic-training-with-apache-mxnet-on-aws   Author: awslabs   File: onnx_import_test.py    Apache License 2.0 5 votes vote down vote up
def test_bvlc_rcnn_ilsvrc13():
    """Tests the bvlc rcnn model"""
    model_path, inputs, outputs = get_test_files('bvlc_reference_rcnn_ilsvrc13')
    logging.info("Translating rcnn_ilsvrc13 model from ONNX to Mxnet")
    sym, arg_params, aux_params = onnx_mxnet.import_model(model_path)

    # Sanity-check the model metadata first.
    metadata = onnx_mxnet.get_model_metadata(model_path)
    assert len(metadata) == 2
    assert metadata.get('input_tensor_data')
    assert metadata.get('input_tensor_data') == [(u'data_0', (1, 3, 224, 224))]
    assert metadata.get('output_tensor_data')
    assert metadata.get('output_tensor_data') == [(u'fc-rcnn_1', (1, 200))]
    data_names = [entry[0] for entry in metadata.get('input_tensor_data')]

    Batch = namedtuple('Batch', ['data'])
    # Run a test for each test file.
    for input_data, output_data in zip(inputs, outputs):
        # Fresh inference-only module per test file.
        mod = mx.mod.Module(symbol=sym, data_names=data_names,
                            context=mx.cpu(), label_names=None)
        mod.bind(for_training=False,
                 data_shapes=[(data_names[0], input_data.shape)],
                 label_shapes=None)
        mod.set_params(arg_params=arg_params, aux_params=aux_params,
                       allow_missing=True, allow_extra=True)
        mod.forward(Batch([mx.nd.array(input_data)]), is_train=False)

        # Verify the results against the reference outputs.
        npt.assert_equal(mod.get_outputs()[0].shape, output_data.shape)
        npt.assert_almost_equal(output_data, mod.get_outputs()[0].asnumpy(), decimal=3)
    logging.info("rcnn_ilsvrc13 model conversion Successful")
Example 59
Project: dynamic-training-with-apache-mxnet-on-aws   Author: awslabs   File: mxnet_export_test.py    Apache License 2.0 5 votes vote down vote up
def forward_pass(sym, arg, aux, data_names, input_data):
    """ Perform forward pass on given data and return the first output
    as a NumPy array."""
    # Inference-only module bound to the input's exact shape.
    mod = mx.mod.Module(symbol=sym, data_names=data_names,
                        context=mx.cpu(), label_names=None)
    mod.bind(for_training=False,
             data_shapes=[(data_names[0], input_data.shape)],
             label_shapes=None)
    mod.set_params(arg_params=arg, aux_params=aux,
                   allow_missing=True, allow_extra=True)

    # Run inference.
    Batch = namedtuple('Batch', ['data'])
    mod.forward(Batch([mx.nd.array(input_data)]), is_train=False)
    return mod.get_outputs()[0].asnumpy()
Example 60
Project: dynamic-training-with-apache-mxnet-on-aws   Author: awslabs   File: test_mxnet_converter.py    Apache License 2.0 5 votes vote down vote up
def test_tiny_synset_random_input(self):
    np.random.seed(1989)
    input_shape = (1, 10)

    # Tiny one-layer softmax network with random weights.
    net = mx.sym.Variable('data')
    net = mx.sym.FullyConnected(data=net, name='fc1', num_hidden=5)
    net = mx.sym.SoftmaxOutput(net, name='softmax')
    mod = _get_mxnet_module(net, data_shapes=[('data', input_shape)],
                            mode='random', label_names=['softmax_label'])

    # Push some dummy data through mxnet.
    input_data = np.random.uniform(-0.1, 0.1, input_shape)
    Batch = namedtuple('Batch', ['data'])
    mod.forward(Batch([mx.nd.array(input_data)]))

    # Convert to CoreML and compare its predicted class label.
    coreml_model = convert(
        model=mod,
        class_labels=['Category1', 'Category2', 'Category3', 'Category4', 'Category5'],
        mode='classifier',
        input_shape={'data': input_shape},
    )
    prediction = coreml_model.predict(
        _mxnet_remove_batch({'data': input_data}))
    self.assertEqual(prediction['classLabel'], 'Category3')
Example 61
Project: DOTA_models   Author: ringringyi   File: model_deploy.py    Apache License 2.0 5 votes vote down vote up
def _gather_clone_loss(clone, num_clones, regularization_losses):
  """Gather the loss for a single clone.

  Args:
    clone: A Clone namedtuple.
    num_clones: The number of clones being deployed.
    regularization_losses: Possibly empty list of regularization_losses
      to add to the clone losses.

  Returns:
    A tensor for the total loss for the clone.  Can be None.
  """
  sum_loss = None             # the return value
  clone_loss = None           # loss components that get summaries
  regularization_loss = None
  # Compute and aggregate losses on the clone's own device.
  with tf.device(clone.device):
    losses = []
    collected = tf.get_collection(tf.GraphKeys.LOSSES, clone.scope)
    if collected:
      clone_loss = tf.add_n(collected, name='clone_loss')
      if num_clones > 1:
        # Scale so the sum over all clones keeps its magnitude.
        clone_loss = tf.div(clone_loss, 1.0 * num_clones,
                            name='scaled_clone_loss')
      losses.append(clone_loss)
    if regularization_losses:
      regularization_loss = tf.add_n(regularization_losses,
                                     name='regularization_loss')
      losses.append(regularization_loss)
    if losses:
      sum_loss = tf.add_n(losses)
  # Summaries are added outside the clone device block.
  if clone_loss is not None:
    tf.summary.scalar(clone.scope + '/clone_loss', clone_loss)
  if regularization_loss is not None:
    tf.summary.scalar('regularization_loss', regularization_loss)
  return sum_loss
Example 62
Project: DOTA_models   Author: ringringyi   File: sequence_layers.py    Apache License 2.0 5 votes vote down vote up
def __init__(self, net, labels_one_hot, model_params, method_params):
    """Stores arguments in member variables for further use.

    Args:
      net: A tensor with shape [batch_size, num_features, feature_size]
        which contains some extracted image features.
      labels_one_hot: An optional (can be None) ground truth labels for
        the input features. Is a tensor with shape
        [batch_size, seq_length, num_char_classes]
      model_params: A namedtuple with model parameters (model.ModelParams).
      method_params: A SequenceLayerParams instance.
    """
    self._params = model_params
    self._mparams = method_params
    self._net = net
    self._labels_one_hot = labels_one_hot
    self._batch_size = net.get_shape().dims[0].value

    # Char logits are computed lazily inside an LSTM decoder; the same
    # regularizer is shared by both softmax variables.
    self._char_logits = {}
    l2_regularizer = slim.l2_regularizer(self._mparams.weight_decay)
    self._softmax_w = slim.model_variable(
        'softmax_w',
        [self._mparams.num_lstm_units, self._params.num_char_classes],
        initializer=orthogonal_initializer,
        regularizer=l2_regularizer)
    self._softmax_b = slim.model_variable(
        'softmax_b', [self._params.num_char_classes],
        initializer=tf.zeros_initializer(),
        regularizer=l2_regularizer)
Example 63
Project: o2g   Author: hiposfer   File: gtfs_dummy.py    MIT License 5 votes vote down vote up
def create_dummy_data(routes, stops):
    """Create `calendar`, `stop_times`, `trips` and `shapes`.

    :return: DummyData namedtuple
    """
    # Group stops by the route serving them; also index stops by id.
    stops_per_route = defaultdict(list)
    stops_map = {}
    for stop in stops:
        if not stop.route_id:
            continue
        stops_per_route[stop.route_id].append(stop)
        stops_map[stop.stop_id] = stop

    calendar = _create_dummy_calendar()
    trips = _create_dummy_trips(routes, stops_per_route, calendar)
    stop_times = _create_dummy_stoptimes(trips, stops_per_route)
    frequencies = _create_dummy_frequencies(trips)

    return DummyData(calendar, stop_times, trips, frequencies)
Example 64
Project: Trusted-Platform-Module-nova   Author: BU-NU-CLOUD-SP16   File: vmops.py    Apache License 2.0 5 votes vote down vote up
def _get_esx_host_and_cookies(self, datastore, dc_path, file_path):
    """Pick a connected ESX host for *datastore* and return its name
    together with a transfer-ticket cookie list for an upload."""
    connected_hosts = datastore.get_connected_hosts(self._session)
    chosen_host = ds_obj.Datastore.choose_host(connected_hosts)
    host_name = self._session._call_method(vutil, 'get_object_property',
                                           chosen_host, 'name')
    url = ds_obj.DatastoreURL('https', host_name, file_path, dc_path,
                              datastore.name)
    cookie_header = url.get_transfer_ticket(self._session, 'PUT')
    name, value = cookie_header.split('=')
    # TODO(rgerganov): this is a hack to emulate cookiejar until we fix
    # oslo.vmware to accept plain http headers
    Cookie = collections.namedtuple('Cookie', ['name', 'value'])
    return host_name, [Cookie(name, value)]
Example 65
Project: Trusted-Platform-Module-nova   Author: BU-NU-CLOUD-SP16   File: test_neutronv2.py    Apache License 2.0 5 votes vote down vote up
def test_populate_neutron_extension_values_binding_sriov(self,
                                         mock_get_instance_pci_devs,
                                         mock_get_pci_device_devspec):
    api = neutronapi.API()
    instance = {'host': 'my_host_id'}
    port_req_body = {'port': {}}
    pci_req_id = 'my_req_id'

    # A fake SR-IOV PCI device assigned to the instance.
    PciDevice = collections.namedtuple('PciDevice',
                           ['vendor_id', 'product_id', 'address'])
    mydev = PciDevice(vendor_id='1377', product_id='0047',
                      address='0000:0a:00.1')
    mock_get_instance_pci_devs.return_value = [mydev]

    devspec = mock.Mock()
    devspec.get_tags.return_value = {'physical_network': 'phynet1'}
    mock_get_pci_device_devspec.return_value = devspec

    api._populate_neutron_binding_profile(instance,
                                          pci_req_id, port_req_body)

    expected_profile = {'pci_vendor_info': '1377:0047',
                        'pci_slot': '0000:0a:00.1',
                        'physical_network': 'phynet1',
                       }
    self.assertEqual(expected_profile,
                     port_req_body['port']['binding:profile'])
Example 66
Project: Trusted-Platform-Module-nova   Author: BU-NU-CLOUD-SP16   File: test_driver_api.py    Apache License 2.0 5 votes vote down vote up
def _test_get_vnc_console(self):
    self._create_vm()
    fake_vm = self._get_vm_record()
    # Attach a fake VNC port option to the VM record, then check that
    # the driver reports it back on the console object.
    OptionValue = collections.namedtuple('OptionValue', ['key', 'value'])
    fake_vm.set(vm_util.VNC_CONFIG_KEY, OptionValue(key='', value=5906))
    vnc_console = self.conn.get_vnc_console(self.context, self.instance)
    self.assertEqual(self.vnc_host, vnc_console.host)
    self.assertEqual(5906, vnc_console.port)
Example 67
Project: Trusted-Platform-Module-nova   Author: BU-NU-CLOUD-SP16   File: test_vm_util.py    Apache License 2.0 5 votes vote down vote up
def _create_fake_vms(self):
    """Build a FakeRetrieveResult holding ten VMs whose fake VNC ports
    are 5900 through 5909."""
    fake_vms = fake.FakeRetrieveResult()
    OptionValue = collections.namedtuple('OptionValue', ['key', 'value'])
    for port in range(5900, 5910):
        vm = fake.ManagedObject()
        vm.set(vm_util.VNC_CONFIG_KEY, OptionValue(key='', value=port))
        fake_vms.add_object(vm)
    return fake_vms
Example 68
Project: Trusted-Platform-Module-nova   Author: BU-NU-CLOUD-SP16   File: test_vm_util.py    Apache License 2.0 5 votes vote down vote up
def test_propset_dict_simple(self):
    """propset_dict should map each DynamicProperty's name to its val."""
    ObjectContent = collections.namedtuple('ObjectContent', ['propSet'])
    DynamicProperty = collections.namedtuple('Property', ['name', 'val'])

    # Renamed from `object`, which shadowed the builtin.
    obj = ObjectContent(propSet=[
                DynamicProperty(name='foo', val="bar")])
    propdict = vm_util.propset_dict(obj.propSet)
    self.assertEqual("bar", propdict['foo'])
Example 69
Project: sic   Author: Yanixos   File: poolmanager.py    GNU General Public License v3.0 5 votes vote down vote up
def _default_key_normalizer(key_class, request_context):
    """
    Create a pool key of type ``key_class`` for a request.

    According to RFC 3986, both the scheme and host are case-insensitive.
    Therefore, this function normalizes both before constructing the pool
    key for an HTTPS request. If you wish to change this behaviour, provide
    alternate callables to ``key_fn_by_scheme``.

    :param key_class:
        The class to use when constructing the key. This should be a namedtuple
        with the ``scheme`` and ``host`` keys at a minimum.

    :param request_context:
        A dictionary-like object that contain the context for a request.
        It should contain a key for each field in the :class:`HTTPPoolKey`
    """
    context = {}
    for key in key_class._fields:
        context[key] = request_context.get(key)
    context['scheme'] = context['scheme'].lower()
    context['host'] = context['host'].lower()
    return key_class(**context)


# A dictionary that maps a scheme to a callable that creates a pool key.
# This can be used to alter the way pool keys are constructed, if desired.
# Each PoolManager makes a copy of this dictionary so they can be configured
# globally here, or individually on the instance. 
Example 70
Project: parfive   Author: Cadair   File: results.py    MIT License 5 votes vote down vote up
def __init__(self, *args, errors=None):
    super().__init__(*args)
    # Collected error records; a falsy `errors` argument (None or an
    # empty list) starts a fresh list.
    self._errors = errors or []
    # Record type for (filepath_partial, url, exception) error entries.
    self._error = namedtuple("error",
                             ("filepath_partial", "url", "exception"))
Example 71
Project: fs_image   Author: facebookincubator   File: enriched_namedtuple.py    MIT License 4 votes vote down vote up
def _normalize_enriched_namedtuple_fields(
    cls, field_to_value, field_to_base_and_default
):
    '''
    When constructing an enriched namedtuple instance, the user passes
    a number of keyword arguments to populate the namedtuple's fields.
    This helper takes the user-supplied keyword arguments as the
    dictionary `field_to_value`, and:
     - validates that all the keys are fields of this enriched namedtuple,
     - populates defaults for any keys that the user did not specify,
     - errors when a field is required, but the user did not supply a key,
     - adds `DO_NOT_USE_type` to prevent type-punning (see doc above).

    After this helper is done modifying `field_to_value`, the dictionary
    is additionally passed into the user-supplied `customize_fields_fn`,
    and only then is the namedtuple instantiated.

    If the namedtuple defines `NonConstructibleFields`, the user's
    `customize_fields_fn` will have to supply them.

    `field_to_base_and_default` has the form:

      {'field_name': (BaseClassDefiningField, default_value_or_RequiredField)}

    This dictionary is built by the metaclass via `_merge_fields_across_bases`
    at the time that your enriched type is being instantiated.

    DANGER: This **MUTATES** field_to_value.
    '''
    # Make sure all arguments are known.  A field is "unknown" both when
    # it was never declared and when it is marked NonConstructibleField
    # (i.e. may not be set by the caller).
    for field, _value in field_to_value.items():
        base_and_default = field_to_base_and_default.get(field)
        assert (
            (base_and_default is not None) and
            (base_and_default[1] is not NonConstructibleField)
        ), 'Constructing {} with unknown field {}'.format(cls, field)

    # Check we have required args, and back-fill optional ones.
    for field, (_base, default) in field_to_base_and_default.items():
        if field not in field_to_value:
            assert default is not RequiredField, (
                '{} requires the field {}'.format(cls, field)
            )
            field_to_value[field] = default

    # `customize_fields_fn` can do these sorts of checks and assignments for
    # other, non-builtin fields.
    assert field_to_value['DO_NOT_USE_type'] is NonConstructibleField
    field_to_value['DO_NOT_USE_type'] = cls
Example 72
Project: UR5_Controller   Author: tsinghua-rll   File: rtif.py    MIT License 4 votes vote down vote up
def unpack(byte_stream):
        """Decode one robot-state packet from the UR realtime interface.

        The packet starts with a 4-byte big-endian length and an 8-byte
        double timestamp, followed by fixed-offset blocks of big-endian
        doubles (48 bytes per 6-double block).

        Returns:
            (message, complete): the populated message dict, and a bool
            that is True when the final (bias == 684) block was parsed.
        """
        # Record templates: six joint values / a 6-DOF tool pose.
        joint_data = namedtuple("vector6d", ("p0", "p1", "p2", "p3", "p4", "p5"))
        coordinate_data = namedtuple("coordinate6d", ("x", "y", "z", "rx", "ry", "rz"))
        # Everything defaults to zeros until its block is parsed.
        message = {"Time Step": 0.0,
                   "Target Joint Positions": joint_data(0., 0., 0., 0., 0., 0.),
                   "Target Joint Velocities": joint_data(0., 0., 0., 0., 0., 0.),
                   "Target Joint Accelerations": joint_data(0., 0., 0., 0., 0., 0.),
                   "Target Joint Currents": joint_data(0., 0., 0., 0., 0., 0.),
                   "Target Joint Torques": joint_data(0., 0., 0., 0., 0., 0.),
                   "Actual Joint Positions": joint_data(0., 0., 0., 0., 0., 0.),
                   "Actual Joint Velocities": joint_data(0., 0., 0., 0., 0., 0.),
                   "Actual Joint Currents": joint_data(0., 0., 0., 0., 0., 0.),
                   "Joint Control Currents": joint_data(0., 0., 0., 0., 0., 0.),
                   "Actual Tool Coordinates": coordinate_data(0., 0., 0., 0., 0., 0.),
                   "Actual Tool Speed": coordinate_data(0., 0., 0., 0., 0., 0.),
                   "Generalized Tool Force": coordinate_data(0., 0., 0., 0., 0., 0.),
                   "Target Tool Coordinates": coordinate_data(0., 0., 0., 0., 0., 0.),
                   "Target Tool Speed": coordinate_data(0., 0., 0., 0., 0., 0.),
                   "Digit Input": 0.0,
                   "Temperature": joint_data(0., 0., 0., 0., 0., 0.),
                   "Execute Time": 0.0,
                   "Robot Mode": 0.0,
                   "Joint Mode": joint_data(0., 0., 0., 0., 0., 0.),
                   "Safety Mode": 0.0,
                   }
        cnt, message["Time Step"] = struct.unpack_from('!Id', byte_stream)
        bias = 12                       # byte count + time step = 12 byte
        # NOTE(review): if `bias` ever holds a value other than 12, 444 or
        # 684 while still < cnt, no branch below fires and this loop spins
        # forever -- confirm the packet layout guarantees these offsets.
        while 0 < bias < cnt:
            if bias == 12:
                # Nine consecutive 6-double joint blocks (48 bytes each).
                message["Target Joint Positions"] = joint_data(*struct.unpack_from('!6d', byte_stream, bias + 0 * 48))
                message["Target Joint Velocities"] = joint_data(*struct.unpack_from('!6d', byte_stream, bias + 1 * 48))
                message["Target Joint Accelerations"] = joint_data(*struct.unpack_from('!6d', byte_stream, bias + 2 * 48))
                message["Target Joint Currents"] = joint_data(*struct.unpack_from('!6d', byte_stream, bias + 3 * 48))
                message["Target Joint Torques"] = joint_data(*struct.unpack_from('!6d', byte_stream, bias + 4 * 48))
                message["Actual Joint Positions"] = joint_data(*struct.unpack_from('!6d', byte_stream, bias + 5 * 48))
                message["Actual Joint Velocities"] = joint_data(*struct.unpack_from('!6d', byte_stream, bias + 6 * 48))
                message["Actual Joint Currents"] = joint_data(*struct.unpack_from('!6d', byte_stream, bias + 7 * 48))
                # NOTE(review): this re-reads offset 7*48 -- the same block
                # as "Actual Joint Currents" above; 8*48 looks intended.
                # Confirm against the realtime-interface spec.
                message["Joint Control Currents"] = joint_data(*struct.unpack_from('!6d', byte_stream, bias + 7 * 48))
                bias += 9 * 48
            elif bias == 444:
                # Five 6-double tool-coordinate blocks.
                message["Actual Tool Coordinates"] = coordinate_data(*struct.unpack_from('!6d', byte_stream, bias + 0 * 48))
                message["Actual Tool Speed"] = coordinate_data(*struct.unpack_from('!6d', byte_stream, bias + 1 * 48))
                message["Generalized Tool Force"] = coordinate_data(*struct.unpack_from('!6d', byte_stream, bias + 2 * 48))
                message["Target Tool Coordinates"] = coordinate_data(*struct.unpack_from('!6d', byte_stream, bias + 3 * 48))
                message["Target Tool Speed"] = coordinate_data(*struct.unpack_from('!6d', byte_stream, bias + 4 * 48))
                bias += 5 * 48
            elif bias == 684:
                # Trailing scalar fields plus two more 6-double blocks.
                # NOTE(review): struct.unpack_from always returns a tuple,
                # so the four scalar fields below end up as 1-element
                # tuples, unlike their 0.0 defaults -- looks like a latent
                # bug; confirm whether callers index into them.
                message["Digit Input"] = struct.unpack_from('!d', byte_stream, bias + 0)
                message["Temperature"] = joint_data(*struct.unpack_from('!6d', byte_stream, bias + 8))
                message["Execute Time"] = struct.unpack_from('!d', byte_stream, bias + 56)
                message["Robot Mode"] = struct.unpack_from('!d', byte_stream, bias + 72)
                message["Joint Mode"] = joint_data(*struct.unpack_from('!6d', byte_stream, bias + 80))
                message["Safety Mode"] = struct.unpack_from('!d', byte_stream, bias + 128)
                bias = -1
        return message, (bias < 0)
Example 73
Project: Ansible-Example-AB2018   Author: umit-ozturk   File: poolmanager.py    MIT License 4 votes vote down vote up
def _default_key_normalizer(key_class, request_context):
    """
    Create a pool key out of a request context dictionary.

    According to RFC 3986, both the scheme and host are case-insensitive.
    Therefore, this function normalizes both before constructing the pool
    key for an HTTPS request. If you wish to change this behaviour, provide
    alternate callables to ``key_fn_by_scheme``.

    :param key_class:
        The class to use when constructing the key. This should be a namedtuple
        with the ``scheme`` and ``host`` keys at a minimum.
    :type  key_class: namedtuple
    :param request_context:
        A dictionary-like object that contain the context for a request.
    :type  request_context: dict

    :return: A namedtuple that can be used as a connection pool key.
    :rtype:  PoolKey
    """
    # Since we mutate the dictionary, make a copy first
    context = request_context.copy()
    context['scheme'] = context['scheme'].lower()
    context['host'] = context['host'].lower()

    # These are both dictionaries and need to be transformed into frozensets
    for key in ('headers', '_proxy_headers', '_socks_options'):
        if key in context and context[key] is not None:
            context[key] = frozenset(context[key].items())

    # The socket_options key may be a list and needs to be transformed into a
    # tuple.
    socket_opts = context.get('socket_options')
    if socket_opts is not None:
        context['socket_options'] = tuple(socket_opts)

    # Map the kwargs to the names in the namedtuple - this is necessary since
    # namedtuples can't have fields starting with '_'.
    for key in list(context.keys()):
        context['key_' + key] = context.pop(key)

    # Default to ``None`` for keys missing from the context
    for field in key_class._fields:
        if field not in context:
            context[field] = None

    return key_class(**context)


#: A dictionary that maps a scheme to a callable that creates a pool key.
#: This can be used to alter the way pool keys are constructed, if desired.
#: Each PoolManager makes a copy of this dictionary so they can be configured
#: globally here, or individually on the instance. 
Example 74
Project: Ansible-Example-AB2018   Author: umit-ozturk   File: ec2_instance.py    MIT License 4 votes vote down vote up
def diff_instance_and_params(instance, params, ec2=None, skip=None):
    """Compare a live boto3 instance description against module params.

    Builds the list of keyword-argument dicts (one per attribute that
    differs) suitable for EC2 ``modify_instance_attribute`` calls.

    :param instance: instance dict as returned by EC2 ``describe_instances``.
    :param params: module parameter dict (e.g. Ansible ``module.params``).
    :param ec2: optional boto3 EC2 client; defaults to the module's client.
    :param skip: optional list of instance attribute keys to ignore.
    :return: list of dicts, each containing ``InstanceId`` plus the new
        attribute value wrapped as ``{'Value': ...}``.
    """
    if ec2 is None:
        ec2 = module.client('ec2')

    if skip is None:
        skip = []

    changes_to_apply = []
    id_ = instance['InstanceId']

    # Maps a module param name to the EC2 attribute's naming variants and a
    # wrapper that formats the new value for modify_instance_attribute.
    ParamMapper = namedtuple('ParamMapper', ['param_key', 'instance_key', 'attribute_name', 'add_value'])

    def value_wrapper(v):
        return {'Value': v}

    param_mappings = [
        ParamMapper('ebs_optimized', 'EbsOptimized', 'ebsOptimized', value_wrapper),
        ParamMapper('termination_protection', 'DisableApiTermination', 'disableApiTermination', value_wrapper),
        # user data is an immutable property
        # ParamMapper('user_data', 'UserData', 'userData', value_wrapper),
    ]

    for mapping in param_mappings:
        desired = params.get(mapping.param_key)
        # Skip attributes the caller did not set or explicitly ignored.
        if desired is None or mapping.instance_key in skip:
            continue
        # Only query the live attribute once we know we might change it.
        value = ec2.describe_instance_attribute(Attribute=mapping.attribute_name, InstanceId=id_)
        if value[mapping.instance_key]['Value'] != desired:
            arguments = {
                'InstanceId': id_,
                mapping.instance_key: mapping.add_value(desired),
            }
            changes_to_apply.append(arguments)

    if (params.get('network') or {}).get('source_dest_check') is not None:
        # network.source_dest_check is nested, so needs to be treated separately
        check = bool(params.get('network').get('source_dest_check'))
        if instance['SourceDestCheck'] != check:
            changes_to_apply.append(dict(
                InstanceId=id_,
                SourceDestCheck={'Value': check},
            ))

    return changes_to_apply
Example 75
Project: L   Author: vaultah   File: auth.py    MIT License 4 votes vote down vote up
def _logged(self, acid, token):
        """Validate cookie-based login state for *acid* and cache it on self.

        Loads every stored cookie row for the given ACID, computes each
        token's remaining TTL, prunes expired tokens (and the whole ACID
        when none survive), then records the surviving state on the
        instance: ``records``, ``tokens``, ``max_ttl``, ``last``,
        ``token``, ``acid``, ``multi``, ``uptodate`` and ``record``.

        :param acid: account/cookie identifier, or ``None`` when absent.
        :param token: the token presented by the client (may be ``None``).
        :return: ``True`` if at least one unexpired token exists and the
            presented token (when given) is among them, else ``False``.
        """
        if acid is None:
            return False

        # All stored cookie rows for this ACID.
        rows = list(cookies.get_by_acid(acid))

        if not rows:
            return False

        def _make_ntuple(x):
            # Calculate TTL from the ObjectId's creation time plus the
            # configured lifetime (session_age vs cookie_age).
            ts = datetime.datetime.timestamp(x['_id'].generation_time)
            is_session = x.get('session', False)
            ttl = ts - time.time() + (session_age if is_session else cookie_age)
            # Load record, construct instance, add instance to dicts
            token, record = x['token'], Record(id=x['account'])
            o = self._tup_cls(token, record, ttl, is_session)
            # Only still-valid tokens get indexed on self.
            if o.ttl > 0:
                self.records[record] = o
                self.tokens[token] = o
            return o

        tups = {_make_ntuple(x) for x in rows}
        uptodate = {x for x in tups if x.ttl > 0}

        # No tokens left, delete the ACID
        if not uptodate:
            cookies.delete_acid(acid)
            return False

        # Delete the outdated tokens
        cookies.delete_tokens([x.token for x in tups - uptodate])

        # Get max ttl and the corresponding namedtuple instance
        self.max_ttl, self.last = max((x.ttl, x) for x in uptodate)

        # Current token (can be None)
        if token is not None:
            try:
                self.token = (token, self.tokens[token].ttl)
            except KeyError:
                # Token not in self.tokens (expired or never issued)
                return False
        else:
            # No token presented: fall back to the longest-lived one.
            self.token = self.last
            self._newer_token()

        # We need self values outside the BL layer
        self.acid = acid, self.max_ttl
        self.multi = len(self.records) > 1
        self.uptodate = uptodate
        self.record = self.tokens[self.token[0]].record
        return True
Example 76
Project: core   Author: lifemapper   File: lmobj.py    GNU General Public License v3.0 4 votes vote down vote up
def _processPROJCS(prjcsStr):
        """
        @summary: Processes a projected coordinate system's WKT into an object
        @param prjcsStr: WKT string beginning with PROJCS[...]
        @return: PROJCS namedtuple (name, geogcs, projectionName, parameters, unit)
        """
        PrjCS = namedtuple('PROJCS', ['name', 'geogcs', 'projectionName', 'parameters', 'unit'])
        Parameter = namedtuple('Parameter', ['name', 'value'])

        # Name: first double-quoted token of the WKT
        name = prjcsStr.split('"')[1]

        # GeoGCS: the nested GEOGCS[...] clause is parsed by the helper
        geocsStr = "GEOGCS{}".format(prjcsStr.split('GEOGCS')[1].split('PROJECTION')[0])
        geocs = LMSpatialObject._processGEOGCS(geocsStr)

        # Projection Name (empty when no PROJECTION clause is present).
        # Narrowed from a bare except: only the split-indexing can fail here.
        try:
            prjName = prjcsStr.split('PROJECTION')[1].split('"')[1]
        except IndexError:
            prjName = ""

        # Parameters: one Parameter per PARAMETER["name", value] clause
        parameters = []
        parametersGroup = prjcsStr.split('PARAMETER')

        try:
            for param in parametersGroup[1:]: # Cut out beginning string
                n = param.split('"')[1]
                v = param.split(']')[0].split(',')[1]
                parameters.append(Parameter(name=n, value=v))
        except IndexError:
            # Malformed PARAMETER clause: keep whatever parsed so far
            pass

        # Unit, normalized for EML ("metre" -> "meter", "Degree" -> "degree")
        unit = prjcsStr.split('UNIT')[-1].split('"')[1]
        if unit.lower() == "metre": # Must match for EML
            unit = "meter"
        elif unit == "Degree":
            unit = "degree"

        ret = PrjCS(name, geocs, prjName, parameters, unit)
        return ret
    
    # .............................................................................. 
Example 77
Project: dynamic-training-with-apache-mxnet-on-aws   Author: awslabs   File: test_mxnet_converter.py    Apache License 2.0 4 votes vote down vote up
def _test_mxnet_model(self, net, input_shape, mode, class_labels=None,
                          coreml_mode=None, label_names=None, delta=1e-2,
                          pre_processing_args=None, input_name='data'):
        """ Helper method that converts the MXNet model into CoreML and compares
        the predictions over random data.

        Parameters
        ----------
        net: MXNet Symbol Graph
            The graph that we'll be converting into CoreML.

        input_shape: tuple of ints
            The shape of input data. Generally of the format (batch-size, channels, height, width)

        mode: (random|zeros|ones)
            The mode to use in order to set the parameters (weights and biases).

        class_labels: list
            Forwarded to the CoreML converter. Default: None

        coreml_mode: str
            Forwarded to the CoreML converter as its ``mode``. Default: None

        label_names: list of strings
            The names of the output labels. Default: None

        delta: float
            The maximum difference b/w predictions of MXNet and CoreML that is tolerable.

        pre_processing_args: dict
            Pre-processing arguments forwarded to the converter. Default: None

        input_name: str
            The name of the input variable to the symbolic graph.
        """

        data_shapes = [(input_name, input_shape)]

        mod = _get_mxnet_module(net, data_shapes, mode, label_names)

        # Generate some dummy data
        input_data = {input_name: np.random.uniform(-10., 10., input_shape)}
        Batch = namedtuple('Batch', ['data'])
        mod.forward(Batch([mx.nd.array(input_data[input_name])]))
        mxnet_preds = mod.get_outputs()[0].asnumpy().flatten()

        # Get predictions from coreml
        coreml_model = convert(
            model=mod,
            class_labels=class_labels,
            mode=coreml_mode,
            input_shape={input_name: input_shape},
            preprocessor_args=pre_processing_args
        )
        # NOTE(review): `.values()[0]` relies on Python 2 dict semantics
        # (Python 3 returns a non-indexable view); this suite appears to
        # target Python 2 — confirm before porting.
        coreml_preds = coreml_model.predict(_mxnet_remove_batch(input_data)).values()[0].flatten()

        # Check prediction accuracy (assertEquals/assertAlmostEquals are the
        # deprecated unittest aliases; kept as-is for the py2 target).
        self.assertEquals(len(mxnet_preds), len(coreml_preds))
        for i in range(len(mxnet_preds)):
            self.assertAlmostEquals(mxnet_preds[i], coreml_preds[i], delta=delta)
Example 78
Project: dynamic-training-with-apache-mxnet-on-aws   Author: awslabs   File: test_mxnet_models.py    Apache License 2.0 4 votes vote down vote up
def _test_model(self, model_name, epoch_num, input_shape=(1, 3, 224, 224),
                    files=None):
        """ Tests whether the converted CoreML model's preds are equal to MXNet
        preds for a given model or not.

        Parameters
        ----------
        model_name: str
            Prefix of the MXNet model name as stored on the local directory.

        epoch_num : int
            Epoch number of model we would like to load.

        input_shape: tuple
            The shape of the input data in the form of (batch_size, channels,
            height, width)

        files: list of strings
            List of URLs pertaining to files that need to be downloaded in
            order to use the model.
        """

        # Fetch any auxiliary model files before loading.
        if files is not None:
            print("Downloading files from urls: %s" % (files))
            for url in files:
                mx.test_utils.download(url)
                print("Downloaded %s" % (url))

        module = self._load_model(
            model_name=model_name,
            epoch_num=epoch_num,
            input_shape=input_shape
        )

        coreml_model = convert(module, input_shape={'data': input_shape})

        # Get predictions from MXNet and coreml
        div = []  # For storing KL divergence for each input.
        # NOTE(review): `xrange` and `.values()[0]` are Python-2-only; this
        # suite appears to target Python 2 — confirm before porting.
        for _ in xrange(1):
            # Fixed seed keeps the random input (and thus the KL check)
            # reproducible across runs.
            np.random.seed(1993)
            input_data = {'data': np.random.uniform(0, 1, input_shape)
                                           .astype(np.float32)}
            Batch = namedtuple('Batch', ['data'])
            module.forward(Batch([mx.nd.array(input_data['data'])]),
                           is_train=False)
            mxnet_pred = module.get_outputs()[0].asnumpy().flatten()
            coreml_pred = coreml_model \
                .predict(_mxnet_remove_batch(input_data)) \
                .values()[0] \
                .flatten()
            self.assertEqual(len(mxnet_pred), len(coreml_pred))
            div.append(_kl_divergence(mxnet_pred, coreml_pred))

        # Predictions are "equal" when the mean KL divergence is tiny.
        print("Average KL divergence is % s" % np.mean(div))
        self.assertTrue(np.mean(div) < 1e-4)
Example 79
Project: DOTA_models   Author: ringringyi   File: model_deploy.py    Apache License 2.0 4 votes vote down vote up
def create_clones(config, model_fn, args=None, kwargs=None):
  """Creates `config.num_clones` model clones by calling `model_fn`.

  Each call's return value is packaged, together with the name scope and
  device it was created under, into a namedtuple
  `Clone(outputs, scope, device)`.

  Note: it is assumed that any loss created by `model_fn` is collected at
  the tf.GraphKeys.LOSSES collection.

  To recover the losses, summaries or update_ops created by the clone use:
  ```python
    losses = tf.get_collection(tf.GraphKeys.LOSSES, clone.scope)
    summaries = tf.get_collection(tf.GraphKeys.SUMMARIES, clone.scope)
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, clone.scope)
  ```

  The deployment options come from `config` and support one or several
  clones on different GPUs, and one or several replicas of such clones.
  When `config` specifies multiple replicas, the default tensorflow device
  is set appropriately for each `model_fn` call and for the slim variable
  creation functions: model and global variables go on the `ps` device,
  clone operations on the `worker` device.

  Args:
    config: A DeploymentConfig object.
    model_fn: A callable. Called as `model_fn(*args, **kwargs)`
    args: Optional list of arguments to pass to `model_fn`.
    kwargs: Optional list of keyword arguments to pass to `model_fn`.

  Returns:
    A list of namedtuples `Clone`.
  """
  args = args or []
  kwargs = kwargs or {}
  clones = []
  # Model/global variables are pinned to the configured variables device.
  with slim.arg_scope([slim.model_variable, slim.variable],
                      device=config.variables_device()):
    for clone_index in range(config.num_clones):
      with tf.name_scope(config.clone_scope(clone_index)) as scope:
        device = config.clone_device(clone_index)
        with tf.device(device):
          # Every clone after the first reuses the first clone's variables.
          reuse = None if clone_index == 0 else True
          with tf.variable_scope(tf.get_variable_scope(), reuse=reuse):
            outputs = model_fn(*args, **kwargs)
          clones.append(Clone(outputs, scope, device))
  return clones
Example 80
Project: bigquerylayers   Author: smandaric   File: poolmanager.py    GNU General Public License v3.0 4 votes vote down vote up
def _default_key_normalizer(key_class, request_context):
    """
    Create a pool key out of a request context dictionary.

    According to RFC 3986, both the scheme and host are case-insensitive.
    Therefore, this function normalizes both before constructing the pool
    key for an HTTPS request. If you wish to change this behaviour, provide
    alternate callables to ``key_fn_by_scheme``.

    :param key_class:
        The class to use when constructing the key. This should be a namedtuple
        with the ``scheme`` and ``host`` keys at a minimum.
    :type  key_class: namedtuple
    :param request_context:
        A dictionary-like object that contain the context for a request.
    :type  request_context: dict

    :return: A namedtuple that can be used as a connection pool key.
    :rtype:  PoolKey
    """
    # Since we mutate the dictionary, make a copy first
    context = request_context.copy()
    # Scheme and host are case-insensitive (RFC 3986); normalize to lowercase.
    context["scheme"] = context["scheme"].lower()
    context["host"] = context["host"].lower()

    # These are both dictionaries and need to be transformed into frozensets
    # so the resulting key is hashable.
    for key in ("headers", "_proxy_headers", "_socks_options"):
        if key in context and context[key] is not None:
            context[key] = frozenset(context[key].items())

    # The socket_options key may be a list and needs to be transformed into a
    # tuple.
    socket_opts = context.get("socket_options")
    if socket_opts is not None:
        context["socket_options"] = tuple(socket_opts)

    # Map the kwargs to the names in the namedtuple - this is necessary since
    # namedtuples can't have fields starting with '_'.
    for key in list(context.keys()):
        context["key_" + key] = context.pop(key)

    # Default to ``None`` for keys missing from the context
    for field in key_class._fields:
        if field not in context:
            context[field] = None

    # Any context key with no matching namedtuple field makes this raise.
    return key_class(**context)


#: A dictionary that maps a scheme to a callable that creates a pool key.
#: This can be used to alter the way pool keys are constructed, if desired.
#: Each PoolManager makes a copy of this dictionary so they can be configured
#: globally here, or individually on the instance.