Python collections.namedtuple() Examples

The following are 30 code examples showing how to use collections.namedtuple(). These examples are extracted from open source projects. You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example.

You may check out the related API usage on the sidebar.

You may also want to check out all available functions/classes of the module collections, or try the search function.

Example 1
Project: iSDX   Author: sdn-ixp   File: replay.py    License: Apache License 2.0 6 votes vote down vote up
def __init__(self, config, flows_dir, ports_dir, num_timesteps, debug=False):
        """Prepare empty bookkeeping, then load the config and the logs.

        Args:
            config: forwarded to ``self.parse_config``.
            flows_dir: forwarded to ``self.parse_logs``.
            ports_dir: forwarded to ``self.parse_logs``.
            num_timesteps: stored as ``total_timesteps`` and forwarded to
                ``self.parse_logs``.
            debug: when true, the "LogHistory" logger is set to DEBUG level.
        """
        history_logger = logging.getLogger("LogHistory")
        if debug:
            history_logger.setLevel(logging.DEBUG)
        self.logger = history_logger

        # namedtuple type with fields (source, destination, type).
        self.log_entry = namedtuple("LogEntry", "source destination type")
        self.flows = defaultdict(list)
        self.ports = defaultdict(list)

        # Three nested auto-created dict levels ending in int counters, so
        # self.data[a][b][c] reads as 0 for any key path not seen before.
        self.data = defaultdict(lambda: defaultdict(lambda: defaultdict(int)))
        self.total_timesteps = num_timesteps
        self.current_timestep = 0

        # Populate the structures above from the inputs, then hand the
        # collected data to pretty().
        self.parse_config(config)
        self.parse_logs(num_timesteps, flows_dir, ports_dir)
        self.info()

        pretty(self.data)
Example 2
Project: neural-fingerprinting   Author: StephanZheng   File: submissions.py    License: BSD 3-Clause "New" or "Revised" License 6 votes vote down vote up
def __init__(self, datastore_client, storage_client, round_name):
    """Create a CompetitionSubmissions bound to the given clients and round.

    Args:
      datastore_client: CompetitionDatastoreClient instance.
      storage_client: CompetitionStorageClient instance.
      round_name: name of the competition round.
    """
    self._round_name = round_name
    self._storage_client = storage_client
    self._datastore_client = datastore_client
    # Submission tables: each is meant to become a dict mapping a submission
    # ID to a SubmissionDescriptor namedtuple. They start out as None.
    self._attacks = self._targeted_attacks = self._defenses = None
Example 3
Project: python-clean-architecture   Author: pcah   File: test_yaml.py    License: MIT License 6 votes vote down vote up
def test_construct_namedtuple():
    """Loading a YAMLObject whose state is built by __new__ must still work.

    The stock Loader has trouble with objects whose state is set by
    ``__new__`` rather than ``__init__``. This checks that loading through
    CustomYamlLoader ends up calling ``__setstate__`` with the mapping.
    """
    from collections import namedtuple

    FooBase = namedtuple('Foo', "x, y")

    class FooClass(serialization.yaml.yaml.YAMLObject, FooBase):
        yaml_tag = 'foo'
        yaml_constructor = serialization.CustomYamlLoader

        def __setstate__(self, data):
            # Keep what the loader hands over so the test can inspect it.
            self.data = data

    document = "---\nfoo: !<foo> {x: 1, y: 2}\n"
    loaded = serialization.load_yaml(document)['foo']

    assert isinstance(loaded, FooClass)
    assert loaded.data == {'x': 1, 'y': 2}
Example 4
def test_forward_types():
    """Check Module.forward / Module.predict accept several input types."""
    # Forward with a user-defined batch object (any object exposing `.data`).
    Batch = namedtuple('Batch', ['data'])
    data = mx.sym.Variable('data')
    out = data * 2
    mod = mx.mod.Module(symbol=out, label_names=None)
    mod.bind(data_shapes=[('data', (1, 10))])
    mod.init_params()
    data1 = [mx.nd.ones((1, 10))]
    mod.forward(Batch(data1))
    assert mod.get_outputs()[0].shape == (1, 10)
    # A batch with a different shape than the one bound: output must follow.
    data2 = [mx.nd.ones((3, 5))]
    mod.forward(Batch(data2))
    assert mod.get_outputs()[0].shape == (3, 5)

    # Predict with a bare NDArray and with a bare np.ndarray.
    data = mx.sym.Variable('data')
    out = data * 2
    mod = mx.mod.Module(symbol=out, label_names=None)
    mod.bind(data_shapes=[('data', (1, 10))])
    mod.init_params()
    data1 = mx.nd.ones((1, 10))
    assert mod.predict(data1).shape == (1, 10)
    data2 = np.ones((1, 10))
    # Previously this re-checked data1, so the np.ndarray path was untested.
    assert mod.predict(data2).shape == (1, 10)
Example 5
Project: DOTA_models   Author: ringringyi   File: data_utils.py    License: Apache License 2.0 6 votes vote down vote up
def read_names(names_path):
    """Load a baby-names CSV and total the counts per lower-cased name.

    See SmallNames.txt for the expected format, or
    https://www.kaggle.com/kaggle/us-baby-names for the full lists.

    Args:
        names_path: path (or anything ``pd.read_csv`` accepts) of a CSV file
            with ``Name`` and ``Count`` columns.

    Returns:
        Dataset: a namedtuple ``(Name, Count)`` of two numpy arrays. ``Name``
            holds the distinct lower-cased names in groupby order, and
            ``Count`` holds the summed counts for each of them.
    """
    frame = pd.read_csv(names_path)
    frame["Name"] = frame["Name"].str.lower()

    totals = frame.groupby(by=["Name"])["Count"].sum()

    Dataset = collections.namedtuple('Dataset', ['Name', 'Count'])
    unique_names = np.array(totals.index.tolist())
    summed_counts = np.array(totals.tolist())
    return Dataset(unique_names, summed_counts)
Example 6
Project: DOTA_models   Author: ringringyi   File: model_deploy.py    License: Apache License 2.0 6 votes vote down vote up
def _optimize_clone(optimizer, clone, num_clones, regularization_losses,
                    **kwargs):
  """Compute the total loss of one clone and, when there is one, its gradients.

  Args:
    optimizer: A tf.Optimizer object.
    clone: A Clone namedtuple; its ``device`` field places the gradient ops.
    num_clones: The number of clones being deployed.
    regularization_losses: Possibly empty list of regularization losses to
      add to the clone losses.
    **kwargs: Extra keyword arguments forwarded to compute_gradients().

  Returns:
    A tuple (clone_loss, clone_grads_and_vars).
      - clone_loss: A tensor for the total loss for the clone. Can be None.
      - clone_grads_and_vars: List of (gradient, variable) for the clone, or
        None when clone_loss is None.
  """
  clone_loss = _gather_clone_loss(clone, num_clones, regularization_losses)
  if clone_loss is None:
    # Nothing to differentiate for this clone.
    return clone_loss, None
  with tf.device(clone.device):
    grads_and_vars = optimizer.compute_gradients(clone_loss, **kwargs)
  return clone_loss, grads_and_vars
Example 7
Project: DOTA_models   Author: ringringyi   File: model.py    License: Apache License 2.0 6 votes vote down vote up
def create_loss(self, data, endpoints):
    """Register the sequence loss and return the overall training loss.

    Args:
      data: InputEndpoints namedtuple; its ``labels`` field is used.
      endpoints: Model namedtuple; its ``chars_logit`` field is used.

    Returns:
      The tensor returned by ``slim.losses.get_total_loss()``, which is also
      logged as the 'TotalLoss' summary scalar.
    """
    # The value returned by sequence_loss_fn is intentionally ignored. Under
    # the hood it calls slim.losses.AddLoss, which registers the loss in an
    # internal collection that get_total_loss later sums. The total is used
    # because the model may have several losses, regularization included.
    self.sequence_loss_fn(endpoints.chars_logit, data.labels)

    loss = slim.losses.get_total_loss()
    tf.summary.scalar('TotalLoss', loss)
    return loss
Example 8
Project: DOTA_models   Author: ringringyi   File: fsns_test.py    License: Apache License 2.0 6 votes vote down vote up
def test_decodes_example_proto(self):
    """Round-trip a synthetic FSNS example through the split's decoder."""
    label_ids = range(37)
    image, png_bytes = unittest_utils.create_random_image(
        'PNG', shape=(150, 600, 3))
    features = {
        'image/encoded': [png_bytes],
        'image/format': ['PNG'],
        'image/class': label_ids,
        'image/unpadded_class': range(10),
        'image/text': ['Raw text'],
        'image/orig_width': [150],
        'image/width': [600],
    }
    serialized = unittest_utils.create_serialized_example(features)

    decoder = fsns.get_split('train', dataset_dir()).decoder
    with self.test_session() as sess:
      # Wrap the decoded tensors in a namedtuple so results are read by name.
      Decoded = collections.namedtuple('DecodedData', decoder.list_items())
      decoded = sess.run(Decoded(*decoder.decode(serialized)))

    self.assertAllEqual(image, decoded.image)
    self.assertAllEqual(label_ids, decoded.label)
    self.assertEqual(['Raw text'], decoded.text)
    self.assertEqual([1], decoded.num_of_views)
Example 9
Project: soccer-matlab   Author: utra-robosoccer   File: minitaur_reactive_env.py    License: BSD 2-Clause "Simplified" License 6 votes vote down vote up
def _reset(self):
    """Reset the environment with every leg at its initial pose.

    Returns:
      The observation from ``self._get_observation()`` after the reset.
    """
    # TODO(b/73666007): Use composition instead of inheritance.
    # (http://go/design-for-testability-no-inheritance).
    # All four legs share the same initial swing and extension angles.
    leg_angles = {}
    for leg in (1, 2, 3, 4):
        leg_angles["swing_angle_%d" % leg] = INIT_SWING_POS
        leg_angles["extension_angle_%d" % leg] = INIT_EXTENSION_POS
    init_pose = MinitaurPose(**leg_angles)

    # TODO(b/73734502): Refactor input of _convert_from_leg_model to namedtuple.
    motor_angles = self._convert_from_leg_model(list(init_pose))
    super(MinitaurReactiveEnv, self)._reset(
        initial_motor_angles=motor_angles, reset_duration=0.5)
    return self._get_observation()
Example 10
Project: aws-ops-automator   Author: awslabs   File: __init__.py    License: Apache License 2.0 6 votes vote down vote up
def as_namedtuple(name, d, deep=True, name_func=None, excludes=None):
    """Convert a dict (optionally nested) into a namedtuple instance.

    Args:
        name: name for the generated type; passed through ``name_func``.
        d: value to convert. Anything that is not a dict is returned as is.
        deep: when true, values for which ``is_dict`` / ``is_array`` hold are
            converted recursively, unless their key is in ``excludes``.
        name_func: maps a raw key or name to a field/type name. Defaults to
            ``tuple_name_func``.
        excludes: raw keys whose values are copied without conversion (only
            consulted when ``deep`` is true).

    Returns:
        A namedtuple instance with one field ``name_func(key)`` per key of
        ``d``, or ``d`` itself when it is not a dict.
    """
    if name_func is None:
        name_func = tuple_name_func

    if not isinstance(d, dict) or d.keys is None:
        return d

    skip = [] if excludes is None else excludes

    if not deep:
        fields = {name_func(key): d[key] for key in list(d.keys())}
    else:
        # Build a fresh dict so the caller's (nested) dicts are not modified.
        fields = {}
        for key in list(d.keys()):
            target = name_func(key)
            value = d[key]
            if is_dict(value) and key not in skip:
                fields[target] = as_namedtuple(key, value, deep=True, name_func=name_func, excludes=skip)
            elif is_array(value) and key not in skip:
                fields[target] = [
                    as_namedtuple(key, item, deep=True, name_func=name_func, excludes=skip)
                    for item in value
                ]
            else:
                fields[target] = value

    record_type = collections.namedtuple(name_func(name), list(fields.keys()))
    return record_type(*list(fields.values()))
Example 11
Project: NGU-scripts   Author: kujan   File: features.py    License: GNU Lesser General Public License v3.0 6 votes vote down vote up
def get_inventory_slots(slots: int) -> list:
        """Return the coordinates of inventory slots 1 to ``slots``.

        Slots are laid out in rows of 12, starting from
        ``coords.INVENTORY_SLOTS`` and spaced 50 units apart horizontally and
        vertically. The first slot of each row sits one step (50) to the
        right of the base x coordinate.

        Args:
            slots: number of slots to generate; 0 or less yields ``[]``.

        Returns:
            A list of ``p(x, y)`` namedtuples in slot order. (The previous
            ``-> None`` annotation was wrong: a list is always returned.)
        """
        point = namedtuple("p", ("x", "y"))
        x_pos, y_pos = coords.INVENTORY_SLOTS
        res = []
        for index in range(slots):
            # Zero-based row and column of this slot in the 12-wide grid.
            row, col = divmod(index, 12)
            res.append(point(x_pos + (col + 1) * 50, y_pos + row * 50))
        return res
Example 12
Project: leveldb-py   Author: jtolio   File: leveldb.py    License: MIT License 6 votes vote down vote up
def next(self):
        """Return the entry at the current position, then step forward.

        @rtype: Row (namedtuple of key, value) if keys_only=False, otherwise
                string (the key)

        @raise StopIteration: if called on an iterator that is not valid
        """
        if not self.valid():
            raise StopIteration()
        current = self.key() if self._keys_only else Row(self.key(), self.value())
        # Advance only after the current entry has been captured.
        self._impl.next()
        return current
Example 13
Project: leveldb-py   Author: jtolio   File: leveldb.py    License: MIT License 6 votes vote down vote up
def prev(self):
        """Return the entry at the current position, then step backward.

        @rtype: Row (namedtuple of key, value) if keys_only=False, otherwise
                string (the key)

        @raise StopIteration: if called on an iterator that is not valid
        """
        if not self.valid():
            raise StopIteration()
        current = self.key() if self._keys_only else Row(self.key(), self.value())
        # Move back only after the current entry has been captured.
        self._impl.prev()
        return current
Example 14
Project: psmqtt   Author: eschava   File: tests.py    License: MIT License 6 votes vote down vote up
def test_tuple_command_handler(self):
        """TupleCommandHandler resolves fields of a namedtuple and rejects bad topics."""
        def get_value(_self):
            return namedtuple('test', 'a b')(10, 20)

        handler_cls = type("TestHandler", (TupleCommandHandler, object),
                           {"get_value": get_value})
        handler = handler_cls('test')
        # normal execution
        self.assertEqual(10, handler.handle('a'))
        self.assertEqual({'a': 10, 'b': 20}, handler.handle('*'))
        self.assertEqual('{"a": 10, "b": 20}', handler.handle('*;'))
        # malformed or unknown topics must raise
        for topic in ('', '/', '*/', '/*', 'blabla', 'bla/bla', 'bla/', '/bla'):
            self.assertRaises(Exception, handler.handle, topic)
Example 15
Project: psmqtt   Author: eschava   File: tests.py    License: MIT License 6 votes vote down vote up
def test_index_tuple_command_handler(self):
        """IndexTupleCommandHandler indexes a list of namedtuples and rejects bad topics."""
        rows = [namedtuple('test', 'a b')(a, b) for a, b in ((1, 2), (3, 4))]
        handler_cls = type("TestHandler", (IndexTupleCommandHandler, object),
                           {"get_value": lambda _self: rows})
        handler = handler_cls('test')
        # normal execution
        self.assertEqual([1, 3], handler.handle('a/*'))
        self.assertEqual("[1, 3]", handler.handle('a/*;'))
        self.assertEqual(3, handler.handle('a/1'))
        self.assertEqual({'a': 3, 'b': 4}, handler.handle('*/1'))
        self.assertEqual('{"a": 3, "b": 4}', handler.handle('*;/1'))
        # malformed or unknown topics must raise
        bad_topics = ('', '*', '*;', 'a', 'a/', '/', '*/', '/*',
                      'blabla', 'bla/bla', 'bla/', '/bla')
        for topic in bad_topics:
            self.assertRaises(Exception, handler.handle, topic)
Example 16
Project: gist-alfred   Author: danielecook   File: poolmanager.py    License: MIT License 6 votes vote down vote up
def connection_from_pool_key(self, pool_key, request_context=None):
        """
        Return the :class:`ConnectionPool` stored under ``pool_key``, creating
        and caching a new one when there is none yet.

        ``pool_key`` should be a namedtuple that only contains immutable
        objects. At a minimum it must have the ``scheme``, ``host``, and
        ``port`` fields. ``request_context`` supplies the ``scheme``,
        ``host`` and ``port`` used when a new pool has to be built.
        """
        with self.pools.lock:
            # Lookup and insertion both happen while holding self.pools.lock.
            existing = self.pools.get(pool_key)
            if existing:
                return existing

            # No usable pool for this key: build one of the requested type.
            pool = self._new_pool(
                request_context['scheme'],
                request_context['host'],
                request_context['port'],
                request_context=request_context,
            )
            self.pools[pool_key] = pool

        return pool
Example 17
Project: backtrader-cn   Author: pandalibin   File: sina.py    License: GNU General Public License v3.0 5 votes vote down vote up
def _json_object_hook(d):
    """``json`` object hook that turns each decoded object into a namedtuple.

    The optional ``_class_name`` entry is removed from ``d`` and names the
    generated type (default ``'NamedTuple'``). The remaining keys become the
    fields, in their original order.
    """
    type_name = d.pop('_class_name', 'NamedTuple')
    record = namedtuple(type_name, d.keys())
    return record(*d.values())
Example 18
Project: neural-fingerprinting   Author: StephanZheng   File: run_multigpu.py    License: BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
def main(argv=None):
    """Snapshot every flag value into an HParams namedtuple and train with it."""
    flag_values = {}
    for flag_name in dir(flags.FLAGS):
        flag_values[flag_name] = flags.FLAGS[flag_name].value
    HParams = namedtuple('HParams', flag_values.keys())
    run_trainer(HParams(**flag_values))
Example 19
Project: multibootusb   Author: mbusb   File: osdriver.py    License: GNU General Public License v2.0 5 votes vote down vote up
def collect_relevant_info(obj, tuple_name, attributes, named_tuple):
    """Read selected attributes of ``obj`` into a cached namedtuple type.

    Args:
        obj: object to read attributes from.
        tuple_name: name of the namedtuple type; only used when the cache is
            still empty.
        attributes: sequence of ``(attribute_name, converter)`` pairs. Each
            attribute value that is not None is passed through its converter.
        named_tuple: list used as a one-slot cache for the generated type. It
            is filled on the first call and reused on later calls.

    Returns:
        An instance of the cached namedtuple type.
    """
    if not named_tuple:
        field_names = [pair[0] for pair in attributes]
        named_tuple.append(collections.namedtuple(tuple_name, field_names))

    def convert(attr_name, converter):
        raw = getattr(obj, attr_name)
        return None if raw is None else converter(raw)

    values = [convert(attr_name, converter) for attr_name, converter in attributes]
    return named_tuple[0](*values)
Example 20
Project: multibootusb   Author: mbusb   File: win32.py    License: GNU General Public License v2.0 5 votes vote down vote up
def findVolumeGuids():
    """Enumerate volumes and describe each one as a ``Volume`` namedtuple.

    For removable and fixed drives, the extents backing the volume are read
    with IOCTL_VOLUME_GET_VOLUME_DISK_EXTENTS and decoded into ``DiskExtent``
    records. Every other drive type gets an empty extent list.

    Returns:
        list of ``Volume(Guid, MediaType, DosDevice, Extents)`` in the order
        produced by ``FindFirstVolume`` / ``FindNextVolume``.
    """
    DiskExtent = collections.namedtuple(
        'DiskExtent', ['DiskNumber', 'StartingOffset', 'ExtentLength'])
    Volume = collections.namedtuple(
        'Volume', ['Guid', 'MediaType', 'DosDevice', 'Extents'])
    # Each extent record is decoded as three little-endian signed 64-bit ints.
    fmt = '<qqq'
    sz = struct.calcsize(fmt)
    found = []
    h, guid = FindFirstVolume()
    # NOTE(review): the enumeration handle ``h`` is never released here;
    # confirm whether the FindFirstVolume helper takes care of it.
    while h and guid:
        hVolume = win32file.CreateFile(
            guid[:-1], win32con.GENERIC_READ,
            win32con.FILE_SHARE_READ | win32con.FILE_SHARE_WRITE,
            None, win32con.OPEN_EXISTING, win32con.FILE_ATTRIBUTE_NORMAL,  None)
        try:
            extents = []
            driveType = win32file.GetDriveType(guid)
            if driveType in [win32con.DRIVE_REMOVABLE, win32con.DRIVE_FIXED]:
                x = win32file.DeviceIoControl(
                    hVolume, winioctlcon.IOCTL_VOLUME_GET_VOLUME_DISK_EXTENTS,
                    None, 512, None)
                instream = io.BytesIO(x)
                # The first 8 bytes are the extent-count header. The loop
                # below reads records until the buffer runs out instead of
                # using this value.
                _num_records = struct.unpack('<q', instream.read(8))[0]
                while True:
                    b = instream.read(sz)
                    if len(b) < sz:
                        break
                    extents.append(DiskExtent(*struct.unpack(fmt, b)))
        finally:
            # Close the volume handle on every iteration, even if the IOCTL
            # raises, instead of leaving it to garbage collection.
            hVolume.Close()
        vinfo = Volume(guid, driveType, win32file.QueryDosDevice(guid[4:-1]),
                       extents)
        found.append(vinfo)
        guid = FindNextVolume(h)
    return found
Example 21
Project: multibootusb   Author: mbusb   File: imager.py    License: GNU General Public License v2.0 5 votes vote down vote up
def imager_usb_detail(physical_disk):
        """
        Detect size, type and model of a USB disk.

        On Linux this parses ``lsblk`` output; on other platforms it asks
        ``osdriver.wmi_get_physicaldrive_info_ex``.

        :param physical_disk: /dev/sd? (linux) or integer disk number (win)
        :return: ``usage(total_size, usb_type, model)`` namedtuple. On Linux,
            fields that could not be determined (e.g. no line is reported as
            a removable whole disk) are ``"Unknown"`` instead of raising
            UnboundLocalError as before.
        """
        _ntuple_diskusage = collections.namedtuple(
            'usage', 'total_size usb_type model')

        if platform.system() != "Linux":
            dinfo = osdriver.wmi_get_physicaldrive_info_ex(physical_disk)
            return _ntuple_diskusage(*[dinfo[a] for a in [
                'size_total', 'mediatype', 'model']])

        total_size = usb_type = model = "Unknown"
        # Arguments are passed as a list (no shell), so the device path is
        # never interpreted by a shell.
        output = subprocess.check_output(["lsblk", "-ib", physical_disk])
        for line in output.splitlines():
            fields = line.split()
            # Fields 2 and 5 are assumed to be lsblk's RM and TYPE columns;
            # short lines (if any) are skipped instead of raising IndexError.
            if (len(fields) > 5 and fields[2].strip() == b'1'
                    and fields[5].strip() == b'disk'):
                total_size = fields[3]
                if not total_size:
                    total_size = "Unknown"
                usb_type = "Removable"
                model = subprocess.check_output(
                    ["lsblk", "-in", "-f", "-o", "MODEL", physical_disk]
                ).decode().strip()
                if not model:
                    model = "Unknown"

        return _ntuple_diskusage(total_size, usb_type, model)
Example 22
Project: multibootusb   Author: mbusb   File: usb.py    License: GNU General Public License v2.0 5 votes vote down vote up
def disk_usage(mount_path):
    """
    Return disk usage statistics about the given path as a (total, used, free)
    namedtuple.  Values are expressed in bytes.

    Supported on Windows (GetDiskFreeSpaceEx) and on every platform that
    provides ``os.statvfs`` (Linux and other POSIX systems). Previously only
    Linux took the statvfs path.

    :raises NotImplementedError: on platforms offering neither.
    """
    # Author: Giampaolo Rodola' <g.rodola [AT] gmail [DOT] com>
    # License: MIT
    _ntuple_diskusage = collections.namedtuple('usage', 'total used free')

    if platform.system() == "Windows":

        _, total, free = ctypes.c_ulonglong(), ctypes.c_ulonglong(), \
                         ctypes.c_ulonglong()
        if sys.version_info >= (3,) or isinstance(mount_path, unicode):
            fun = ctypes.windll.kernel32.GetDiskFreeSpaceExW
        else:
            fun = ctypes.windll.kernel32.GetDiskFreeSpaceExA
        ret = fun(mount_path, ctypes.byref(_), ctypes.byref(total), ctypes.byref(free))
        if ret == 0:
            raise ctypes.WinError()
        used = total.value - free.value

        return _ntuple_diskusage(total.value, used, free.value)

    if hasattr(os, "statvfs"):
        st = os.statvfs(mount_path)
        # f_bavail: free blocks available to unprivileged users.
        free = st.f_bavail * st.f_frsize
        total = st.f_blocks * st.f_frsize
        used = (st.f_blocks - st.f_bfree) * st.f_frsize

        return _ntuple_diskusage(total, used, free)

    raise NotImplementedError("Platform not supported.")
Example 23
Project: friendly-telegram   Author: friendly-telegram   File: heroku.py    License: GNU Affero General Public License v3.0 5 votes vote down vote up
def __init__(self, **kwargs):
        """Initialise the base class; on Heroku, also wire up the Heroku app.

        The Heroku branch only runs when ``heroku_api_token`` is in the
        environment. It reads ``api_id``, ``api_hash`` and ``DYNO`` from the
        environment as well.
        """
        super().__init__(**kwargs)
        if "heroku_api_token" not in os.environ:
            return
        # This is called before asyncio is even set up. We can only use sync methods which is fine.
        token_type = collections.namedtuple("api_token", ["ID", "HASH"])
        api_token = token_type(os.environ["api_id"], os.environ["api_hash"])
        sessions = [c[1] for c in self.client_data]
        app, _config = heroku.get_app(sessions, os.environ["heroku_api_token"],
                                      api_token, False, True)
        if os.environ["DYNO"].startswith("web."):
            app.scale_formation_process("worker-DO-NOT-TURN-ON-OR-THINGS-WILL-BREAK", 0)
        atexit.register(functools.partial(exit_handler, app))
Example 24
Project: friendly-telegram   Author: friendly-telegram   File: initial_setup.py    License: GNU Affero General Public License v3.0 5 votes vote down vote up
async def set_tg_api(self, request):
        """Store the Telegram API hash/id posted by the setup page.

        Declared ``async`` because the body awaits; as a plain ``def`` it was
        a SyntaxError.

        The request body must be the 32-character hexadecimal API hash
        immediately followed by the all-digit API id. On success, the values
        are written to ``api_token.py`` in the base directory, exposed as
        ``self.api_token``, and ``self.api_set`` is set.

        Returns:
            302 to "/" when client data exists but ``check_user`` returns
            None, 400 for a malformed body, otherwise an empty 200 response.
        """
        if self.client_data and await self.check_user(request) is None:
            return web.Response(status=302, headers={"Location": "/"})  # They gotta sign in.
        text = await request.text()
        # 32 chars of hash plus at least 4 chars of id.
        if len(text) < 36:
            return web.Response(status=400)
        api_id = text[32:]
        api_hash = text[:32]
        # These strict character checks also guarantee that the values are
        # safe to embed in the generated Python source below.
        if any(c not in string.hexdigits for c in api_hash) or any(c not in string.digits for c in api_id):
            return web.Response(status=400)
        with open(os.path.join(utils.get_base_dir(), "api_token.py"), "w") as f:
            f.write("HASH = \"" + api_hash + "\"\nID = \"" + api_id + "\"\n")
        self.api_token = collections.namedtuple("api_token", ("ID", "HASH"))(api_id, api_hash)
        self.api_set.set()
        return web.Response()
Example 25
Project: friendly-telegram   Author: friendly-telegram   File: main.py    License: GNU Affero General Public License v3.0 5 votes vote down vote up
def get_api_token():
    """Get API Token from disk or environment.

    Prefers the sibling ``api_token`` module. If it cannot be imported,
    builds an ``api_token(ID, HASH)`` namedtuple from the ``api_id`` and
    ``api_hash`` environment variables. (The former ``while True`` loop was
    removed: every path returned on its first pass.)

    Returns:
        The imported module, the namedtuple, or None when neither source is
        available.
    """
    try:
        from . import api_token
    except ImportError:
        pass
    else:
        return api_token

    try:
        return collections.namedtuple("api_token", ("ID", "HASH"))(os.environ["api_id"],
                                                                   os.environ["api_hash"])
    except KeyError:
        return None
Example 26
Project: dynamic-training-with-apache-mxnet-on-aws   Author: awslabs   File: utils.py    License: Apache License 2.0 5 votes vote down vote up
def namedtuple_with_defaults(typename, field_names, default_values=()):
    """Create a namedtuple type whose fields default to ``None``.

    Args:
        typename: name of the generated type.
        field_names: field names in any form ``collections.namedtuple``
            accepts.
        default_values: overrides for the ``None`` defaults. Either a mapping
            of field name to value, or a sequence applied to the leading
            fields.

    Returns:
        The new namedtuple class.
    """
    # collections.Mapping was removed in Python 3.10; the ABC lives in
    # collections.abc. Fall back for very old interpreters.
    try:
        from collections.abc import Mapping
    except ImportError:  # Python 2
        from collections import Mapping

    T = collections.namedtuple(typename, field_names)
    # Temporarily make every field optional so that a prototype can be built
    # from a partial set of defaults.
    T.__new__.__defaults__ = (None, ) * len(T._fields)
    if isinstance(default_values, Mapping):
        prototype = T(**default_values)
    else:
        prototype = T(*default_values)
    # The prototype's values become the final per-field defaults.
    T.__new__.__defaults__ = tuple(prototype)
    return T
Example 27
Project: dynamic-training-with-apache-mxnet-on-aws   Author: awslabs   File: utils.py    License: Apache License 2.0 5 votes vote down vote up
def zip_namedtuple(nt_list):
    """Transpose namedtuples into a dict mapping each field to its values.

    A single namedtuple is treated as a one-element list, and a falsy input
    yields ``{}``. All entries must have exactly the same type (checked with
    ``assert``).
    """
    if not nt_list:
        return dict()
    records = nt_list if isinstance(nt_list, list) else [nt_list]
    first = records[0]
    assert all(type(record) == type(first) for record in records)
    columns = {field: [value] for field, value in first._asdict().items()}
    for record in records[1:]:
        for field, value in record._asdict().items():
            columns[field].append(value)
    return columns
Example 28
def perform_inference(sym, arg_params, aux_params, input_img, img_cb, img_cr):
    """Run the ONNX super-resolution model on ``input_img`` with MXNet.

    The first model output becomes the Y channel. ``img_cb`` and ``img_cr``
    are resized to its size with bicubic resampling. The three channels are
    merged as YCbCr, converted to RGB, checked to be 672x672, saved to
    ``super_res_output.jpg`` and returned.
    """
    metadata = onnx_mxnet.get_model_metadata('super_resolution.onnx')
    data_names = [entry[0] for entry in metadata.get('input_tensor_data')]

    # Inference-only module bound to the input image's shape.
    mod = mx.mod.Module(symbol=sym, data_names=data_names, label_names=None)
    mod.bind(for_training=False, data_shapes=[(data_names[0], input_img.shape)])
    mod.set_params(arg_params=arg_params, aux_params=aux_params)

    Batch = namedtuple('Batch', ['data'])
    mod.forward(Batch([mx.nd.array(input_img)]))

    # Clip to the 0..255 range before converting to an 8-bit grayscale image.
    y_plane = mod.get_outputs()[0][0][0].asnumpy().clip(0, 255)
    img_out_y = Image.fromarray(np.uint8(y_plane), mode='L')
    target_size = img_out_y.size
    channels = [
        img_out_y,
        img_cb.resize(target_size, Image.BICUBIC),
        img_cr.resize(target_size, Image.BICUBIC),
    ]
    result_img = Image.merge("YCbCr", channels).convert("RGB")

    output_img_dim = 672
    assert result_img.size == (output_img_dim, output_img_dim)
    LOGGER.info("Super Resolution example success.")
    result_img.save("super_res_output.jpg")
    return result_img
Example 29
def test_bvlc_googlenet():
    """Import the ONNX GoogLeNet model and compare its outputs to the references."""
    model_path, inputs, outputs = get_test_files('bvlc_googlenet')
    logging.info("Translating Googlenet model from ONNX to Mxnet")
    sym, arg_params, aux_params = onnx_mxnet.import_model(model_path)

    metadata = onnx_mxnet.get_model_metadata(model_path)
    input_meta = metadata.get('input_tensor_data')
    output_meta = metadata.get('output_tensor_data')
    assert len(metadata) == 2
    assert input_meta
    assert input_meta == [(u'data_0', (1, 3, 224, 224))]
    assert output_meta
    assert output_meta == [(u'prob_1', (1, 1000))]
    data_names = [entry[0] for entry in input_meta]

    Batch = namedtuple('Batch', ['data'])
    # One freshly bound module per reference input/output pair.
    for input_data, output_data in zip(inputs, outputs):
        mod = mx.mod.Module(symbol=sym, data_names=data_names, context=mx.cpu(), label_names=None)
        mod.bind(for_training=False, data_shapes=[(data_names[0], input_data.shape)], label_shapes=None)
        mod.set_params(arg_params=arg_params, aux_params=aux_params,
                       allow_missing=True, allow_extra=True)
        mod.forward(Batch([mx.nd.array(input_data)]), is_train=False)

        prediction = mod.get_outputs()[0]
        npt.assert_equal(prediction.shape, output_data.shape)
        npt.assert_almost_equal(output_data, prediction.asnumpy(), decimal=3)
    logging.info("Googlenet model conversion Successful")
Example 30
def test_bvlc_rcnn_ilsvrc13():
    """Import the ONNX R-CNN ILSVRC13 model and compare its outputs to the references."""
    model_path, inputs, outputs = get_test_files('bvlc_reference_rcnn_ilsvrc13')
    logging.info("Translating rcnn_ilsvrc13 model from ONNX to Mxnet")
    sym, arg_params, aux_params = onnx_mxnet.import_model(model_path)

    metadata = onnx_mxnet.get_model_metadata(model_path)
    input_meta = metadata.get('input_tensor_data')
    output_meta = metadata.get('output_tensor_data')
    assert len(metadata) == 2
    assert input_meta
    assert input_meta == [(u'data_0', (1, 3, 224, 224))]
    assert output_meta
    assert output_meta == [(u'fc-rcnn_1', (1, 200))]
    data_names = [entry[0] for entry in input_meta]

    Batch = namedtuple('Batch', ['data'])
    # One freshly bound module per reference input/output pair.
    for input_data, output_data in zip(inputs, outputs):
        mod = mx.mod.Module(symbol=sym, data_names=data_names, context=mx.cpu(), label_names=None)
        mod.bind(for_training=False, data_shapes=[(data_names[0], input_data.shape)], label_shapes=None)
        mod.set_params(arg_params=arg_params, aux_params=aux_params,
                       allow_missing=True, allow_extra=True)
        mod.forward(Batch([mx.nd.array(input_data)]), is_train=False)

        prediction = mod.get_outputs()[0]
        npt.assert_equal(prediction.shape, output_data.shape)
        npt.assert_almost_equal(output_data, prediction.asnumpy(), decimal=3)
    logging.info("rcnn_ilsvrc13 model conversion Successful")