Python collections.namedtuple() Examples
The following are 30 code examples of collections.namedtuple().
You can vote up the ones you like or vote down the ones you don't like,
and go to the original project or source file by following the links above each example.
You may also want to check out all available functions/classes of the module collections, or try the search function.

Example #1
Source File: replay.py From iSDX with Apache License 2.0 | 6 votes |
def __init__(self, config, flows_dir, ports_dir, num_timesteps, debug=False):
    """Set up the per-run containers, then parse the config and logs.

    Args:
        config: configuration handed to ``self.parse_config``.
        flows_dir: flows location handed to ``self.parse_logs``.
        ports_dir: ports location handed to ``self.parse_logs``.
        num_timesteps: number of timesteps handed to ``self.parse_logs``;
            also stored as ``self.total_timesteps``.
        debug: when true, the "LogHistory" logger level is set to DEBUG.
    """
    logger = logging.getLogger("LogHistory")
    if debug:
        logger.setLevel(logging.DEBUG)
    self.logger = logger

    self.log_entry = namedtuple("LogEntry", "source destination type")

    self.ports = defaultdict(list)
    self.flows = defaultdict(list)

    # Three-level auto-vivifying counter: data[a][b][c] -> int (0 by default).
    def _int_counter():
        return defaultdict(int)

    def _middle_level():
        return defaultdict(_int_counter)

    self.data = defaultdict(_middle_level)

    self.current_timestep = 0
    self.total_timesteps = num_timesteps

    self.parse_config(config)
    self.parse_logs(num_timesteps, flows_dir, ports_dir)
    self.info()
    pretty(self.data)
Example #2
Source File: submissions.py From neural-fingerprinting with BSD 3-Clause "New" or "Revised" License | 6 votes |
def __init__(self, datastore_client, storage_client, round_name):
    """Initializes CompetitionSubmissions.

    Args:
      datastore_client: instance of CompetitionDatastoreClient
      storage_client: instance of CompetitionStorageClient
      round_name: name of the round
    """
    self._datastore_client, self._storage_client, self._round_name = (
        datastore_client, storage_client, round_name)
    # Submission caches, unset until loaded. When populated, each one is a
    # dict mapping submission ID -> SubmissionDescriptor namedtuple.
    self._attacks = self._targeted_attacks = self._defenses = None
Example #3
Source File: test_yaml.py From python-clean-architecture with MIT License | 6 votes |
def test_construct_namedtuple():
    """Check the custom YAML loader builds objects whose state is set by
    __new__ rather than __init__ (the original Loader gets this wrong).
    """
    from collections import namedtuple

    FooBase = namedtuple('Foo', "x, y")

    class FooClass(serialization.yaml.yaml.YAMLObject, FooBase):
        yaml_tag = 'foo'
        yaml_constructor = serialization.CustomYamlLoader

        def __setstate__(self, data):
            self.data = data

    contents = "---\nfoo: !<foo> {x: 1, y: 2}\n"
    loaded = serialization.load_yaml(contents)
    foo_object = loaded['foo']

    assert isinstance(foo_object, FooClass)
    assert foo_object.data == {'x': 1, 'y': 2}
Example #4
Source File: test_module.py From dynamic-training-with-apache-mxnet-on-aws with Apache License 2.0 | 6 votes |
def test_forward_types():
    """Check Module.forward/predict accept several input container types."""
    # Test forward with other data batch API
    Batch = namedtuple('Batch', ['data'])
    data = mx.sym.Variable('data')
    out = data * 2
    mod = mx.mod.Module(symbol=out, label_names=None)
    mod.bind(data_shapes=[('data', (1, 10))])
    mod.init_params()
    data1 = [mx.nd.ones((1, 10))]
    mod.forward(Batch(data1))
    assert mod.get_outputs()[0].shape == (1, 10)
    # forward() with a batch whose shape differs from the bound shape
    data2 = [mx.nd.ones((3, 5))]
    mod.forward(Batch(data2))
    assert mod.get_outputs()[0].shape == (3, 5)

    # Test forward with other NDArray and np.ndarray inputs
    data = mx.sym.Variable('data')
    out = data * 2
    mod = mx.mod.Module(symbol=out, label_names=None)
    mod.bind(data_shapes=[('data', (1, 10))])
    mod.init_params()
    data1 = mx.nd.ones((1, 10))
    assert mod.predict(data1).shape == (1, 10)
    data2 = np.ones((1, 10))
    # Previously this predicted data1 a second time, so the np.ndarray input
    # path was never exercised.
    assert mod.predict(data2).shape == (1, 10)
Example #5
Source File: data_utils.py From DOTA_models with Apache License 2.0 | 6 votes |
def read_names(names_path):
    """Read name counts from a downloaded CSV file.

    See SmallNames.txt for an example of the format, or go to
    https://www.kaggle.com/kaggle/us-baby-names for the full lists.

    Args:
        names_path: path to a CSV file with at least ``Name`` and ``Count``
            columns, like the example file.

    Returns:
        Dataset: a namedtuple of two numpy arrays, ``Name`` and ``Count``.
        ``Name`` holds the distinct names, lower-cased and sorted. ``Count``
        holds the summed count for each of them, so names differing only in
        case are merged. Names are only lower-cased, not filtered: any
        non-letter characters in the file are kept.
    """
    names_data = pd.read_csv(names_path)
    # Bracket assignment always targets the column; pandas discourages
    # attribute-style column assignment.
    names_data["Name"] = names_data["Name"].str.lower()
    name_data = names_data.groupby(by=["Name"])["Count"].sum()
    name_counts = np.array(name_data.tolist())
    names_deduped = np.array(name_data.index.tolist())
    Dataset = collections.namedtuple('Dataset', ['Name', 'Count'])
    return Dataset(names_deduped, name_counts)
Example #6
Source File: model_deploy.py From DOTA_models with Apache License 2.0 | 6 votes |
def _optimize_clone(optimizer, clone, num_clones, regularization_losses,
                    **kwargs):
    """Compute losses and gradients for a single clone.

    Args:
      optimizer: A tf.Optimizer object.
      clone: A Clone namedtuple.
      num_clones: The number of clones being deployed.
      regularization_losses: Possibly empty list of regularization_losses
        to add to the clone losses.
      **kwargs: Dict of kwarg to pass to compute_gradients().

    Returns:
      A tuple (clone_loss, clone_grads_and_vars).
        - clone_loss: A tensor for the total loss for the clone. Can be None.
        - clone_grads_and_vars: List of (gradient, variable) for the clone,
          or None when clone_loss is None.
    """
    clone_loss = _gather_clone_loss(clone, num_clones, regularization_losses)
    if clone_loss is None:
        # Nothing to differentiate for this clone.
        return clone_loss, None
    # Place the gradient ops on the clone's own device.
    with tf.device(clone.device):
        grads_and_vars = optimizer.compute_gradients(clone_loss, **kwargs)
    return clone_loss, grads_and_vars
Example #7
Source File: model.py From DOTA_models with Apache License 2.0 | 6 votes |
def create_loss(self, data, endpoints):
    """Creates all losses required to train the model.

    Args:
      data: InputEndpoints namedtuple.
      endpoints: Model namedtuple.

    Returns:
      Total loss.
    """
    # NOTE: the return value of ModelLoss is not used directly for the
    # gradient computation because under the hood it calls slim.losses.AddLoss,
    # which registers the loss in an internal collection and later returns it
    # as part of GetTotalLoss. We need to use total loss because model may have
    # multiple losses including regularization losses.
    logits, labels = endpoints.chars_logit, data.labels
    self.sequence_loss_fn(logits, labels)
    loss = slim.losses.get_total_loss()
    tf.summary.scalar('TotalLoss', loss)
    return loss
Example #8
Source File: fsns_test.py From DOTA_models with Apache License 2.0 | 6 votes |
def test_decodes_example_proto(self):
    """Decode a serialized FSNS example and check the decoded fields."""
    expected_label = range(37)
    expected_image, encoded = unittest_utils.create_random_image(
        'PNG', shape=(150, 600, 3))
    features = {
        'image/encoded': [encoded],
        'image/format': ['PNG'],
        'image/class': expected_label,
        'image/unpadded_class': range(10),
        'image/text': ['Raw text'],
        'image/orig_width': [150],
        'image/width': [600],
    }
    serialized = unittest_utils.create_serialized_example(features)

    decoder = fsns.get_split('train', dataset_dir()).decoder
    with self.test_session() as sess:
        DecodedData = collections.namedtuple('DecodedData',
                                             decoder.list_items())
        decoded = decoder.decode(serialized)
        data = sess.run(DecodedData(*decoded))

    self.assertAllEqual(expected_image, data.image)
    self.assertAllEqual(expected_label, data.label)
    self.assertEqual(['Raw text'], data.text)
    self.assertEqual([1], data.num_of_views)
Example #9
Source File: minitaur_reactive_env.py From soccer-matlab with BSD 2-Clause "Simplified" License | 6 votes |
def _reset(self):
    """Reset to the initial leg pose; returns ``self._get_observation()``."""
    # TODO(b/73666007): Use composition instead of inheritance.
    # (http://go/design-for-testability-no-inheritance).
    swing = INIT_SWING_POS
    extension = INIT_EXTENSION_POS
    init_pose = MinitaurPose(
        swing_angle_1=swing,
        swing_angle_2=swing,
        swing_angle_3=swing,
        swing_angle_4=swing,
        extension_angle_1=extension,
        extension_angle_2=extension,
        extension_angle_3=extension,
        extension_angle_4=extension,
    )
    # TODO(b/73734502): Refactor input of _convert_from_leg_model to namedtuple.
    leg_model_angles = list(init_pose)
    initial_motor_angles = self._convert_from_leg_model(leg_model_angles)
    parent = super(MinitaurReactiveEnv, self)
    parent._reset(initial_motor_angles=initial_motor_angles, reset_duration=0.5)
    return self._get_observation()
Example #10
Source File: __init__.py From aws-ops-automator with Apache License 2.0 | 6 votes |
def as_namedtuple(name, d, deep=True, name_func=None, excludes=None):
    """Convert a dict into a namedtuple instance.

    Args:
        name: name of the generated namedtuple type (passed through name_func).
        d: value to convert; anything that is not a dict is returned as-is.
        deep: when true, nested dict values and the items of array values are
            converted recursively (except for keys listed in ``excludes``).
        name_func: maps keys/names to namedtuple identifiers; defaults to
            ``tuple_name_func``.
        excludes: keys whose values are copied without conversion.

    Returns:
        A namedtuple instance, or ``d`` unchanged if it is not a dict.
    """
    if name_func is None:
        name_func = tuple_name_func
    if not isinstance(d, dict) or getattr(d, "keys") is None:
        return d
    excludes = [] if excludes is None else excludes

    def _convert(key, value):
        return as_namedtuple(key, value, deep=True, name_func=name_func,
                             excludes=excludes)

    if not deep:
        fields = {name_func(key): d[key] for key in list(d.keys())}
    else:
        # Build a fresh dict so the input (and its nested dicts) is not modified.
        fields = {}
        for key in list(d.keys()):
            field = name_func(key)
            value = d[key]
            if is_dict(value) and key not in excludes:
                fields[field] = _convert(key, value)
            elif is_array(value) and key not in excludes:
                fields[field] = [_convert(key, item) for item in value]
            else:
                fields[field] = value

    tuple_type = collections.namedtuple(name_func(name), list(fields.keys()))
    return tuple_type(*list(fields.values()))
Example #11
Source File: features.py From NGU-scripts with GNU Lesser General Public License v3.0 | 6 votes |
def get_inventory_slots(slots: int) -> list:
    """Get coords for inventory slots from 1 to slots.

    Slots are laid out 12 per row, spaced 50 apart horizontally and
    vertically (presumably screen pixels), relative to
    ``coords.INVENTORY_SLOTS``. The first slot sits one spacing step to the
    right of that origin.

    Returns:
        A list of ``p(x, y)`` namedtuples, one per slot, in slot order
        (empty when ``slots`` < 1).
    """
    point = namedtuple("p", ("x", "y"))
    x_pos, y_pos = coords.INVENTORY_SLOTS
    res = []
    for i in range(1, slots + 1):
        row, col = divmod(i - 1, 12)  # 0-based row and column of slot i
        res.append(point(x_pos + (col + 1) * 50, y_pos + row * 50))
    return res
Example #12
Source File: leveldb.py From leveldb-py with MIT License | 6 votes |
def next(self):
    """Advances the iterator one step. Also returns the current value prior
    to moving the iterator

    @rtype: Row (namedtuple of key, value) if keys_only=False, otherwise
            string (the key)

    @raise StopIteration: if called on an iterator that is not valid
    """
    if not self.valid():
        raise StopIteration()
    # Capture the current entry before the underlying iterator moves.
    current = self.key() if self._keys_only else Row(self.key(), self.value())
    self._impl.next()
    return current
Example #13
Source File: leveldb.py From leveldb-py with MIT License | 6 votes |
def prev(self):
    """Backs the iterator up one step. Also returns the current value prior
    to moving the iterator.

    @rtype: Row (namedtuple of key, value) if keys_only=False, otherwise
            string (the key)

    @raise StopIteration: if called on an iterator that is not valid
    """
    if not self.valid():
        raise StopIteration()
    # Capture the current entry before the underlying iterator moves.
    current = self.key() if self._keys_only else Row(self.key(), self.value())
    self._impl.prev()
    return current
Example #14
Source File: tests.py From psmqtt with MIT License | 6 votes |
def test_tuple_command_handler(self):
    """Check TupleCommandHandler field access, '*' expansion and bad input."""
    def _get_value(_handler):
        return namedtuple('test', 'a b')(10, 20)

    handler_cls = type("TestHandler", (TupleCommandHandler, object),
                       {"get_value": _get_value})
    handler = handler_cls('test')

    # normal execution
    self.assertEqual(10, handler.handle('a'))
    self.assertEqual({'a': 10, 'b': 20}, handler.handle('*'))
    self.assertEqual('{"a": 10, "b": 20}', handler.handle('*;'))

    # exceptions
    for bad_command in ('', '/', '*/', '/*', 'blabla', 'bla/bla', 'bla/', '/bla'):
        self.assertRaises(Exception, handler.handle, bad_command)
Example #15
Source File: tests.py From psmqtt with MIT License | 6 votes |
def test_index_tuple_command_handler(self):
    """Check IndexTupleCommandHandler indexing, '*' expansion and bad input."""
    def _make(a, b):
        return namedtuple('test', 'a b')(a, b)

    values = [_make(1, 2), _make(3, 4)]
    handler_cls = type("TestHandler", (IndexTupleCommandHandler, object),
                       {"get_value": lambda _self: values})
    handler = handler_cls('test')

    # normal execution
    expectations = [
        ([1, 3], 'a/*'),
        ("[1, 3]", 'a/*;'),
        (3, 'a/1'),
        ({'a': 3, 'b': 4}, '*/1'),
        ('{"a": 3, "b": 4}', '*;/1'),
    ]
    for expected, command in expectations:
        self.assertEqual(expected, handler.handle(command))

    # exceptions
    bad_commands = ('', '*', '*;', 'a', 'a/', '/', '*/', '/*',
                    'blabla', 'bla/bla', 'bla/', '/bla')
    for bad_command in bad_commands:
        self.assertRaises(Exception, handler.handle, bad_command)
Example #16
Source File: poolmanager.py From gist-alfred with MIT License | 6 votes |
def connection_from_pool_key(self, pool_key, request_context=None):
    """
    Get a :class:`ConnectionPool` based on the provided pool key.

    ``pool_key`` should be a namedtuple that only contains immutable
    objects. At a minimum it must have the ``scheme``, ``host``, and
    ``port`` fields.
    """
    with self.pools.lock:
        # Reuse the pool already registered under this key, if any.
        existing = self.pools.get(pool_key)
        if existing:
            return existing

        # Otherwise build a fresh ConnectionPool of the desired type from
        # the request context and register it under the key.
        new_pool = self._new_pool(
            request_context['scheme'],
            request_context['host'],
            request_context['port'],
            request_context=request_context,
        )
        self.pools[pool_key] = new_pool
    return new_pool
Example #17
Source File: sina.py From backtrader-cn with GNU General Public License v3.0 | 5 votes |
def _json_object_hook(d):
    """Turn a decoded JSON object (a dict) into a namedtuple instance.

    The optional ``_class_name`` entry names the namedtuple type (default
    ``'NamedTuple'``) and is popped from ``d``. The remaining keys become the
    fields, in the dict's order.
    """
    type_name = d.pop('_class_name', 'NamedTuple')
    record_type = namedtuple(type_name, d.keys())
    return record_type(*d.values())
Example #18
Source File: run_multigpu.py From neural-fingerprinting with BSD 3-Clause "New" or "Revised" License | 5 votes |
def main(argv=None):
    """Pack the value of every name in ``dir(flags.FLAGS)`` into an HParams
    namedtuple and pass it to ``run_trainer``. ``argv`` is accepted but unused.
    """
    flag_values = {}
    for flag_name in dir(flags.FLAGS):
        flag_values[flag_name] = flags.FLAGS[flag_name].value
    HParams = namedtuple('HParams', list(flag_values))
    run_trainer(HParams(**flag_values))
Example #19
Source File: osdriver.py From multibootusb with GNU General Public License v2.0 | 5 votes |
def collect_relevant_info(obj, tuple_name, attributes, named_tuple):
    """Snapshot selected attributes of ``obj`` into a namedtuple.

    Args:
        obj: object whose attributes are read.
        tuple_name: type name used when the namedtuple class is first built.
        attributes: sequence of ``(attribute_name, converter)`` pairs; each
            non-None attribute value is passed through its converter.
        named_tuple: list used as a one-slot cache for the namedtuple class.
            When empty, the class is built from the attribute names and
            appended; later calls reuse ``named_tuple[0]``.

    Returns:
        An instance of the cached class. Attributes that are None stay None.
    """
    if not named_tuple:
        field_names = [pair[0] for pair in attributes]
        named_tuple.append(collections.namedtuple(tuple_name, field_names))

    def _converted(attr, convert):
        raw = getattr(obj, attr)
        return None if raw is None else convert(raw)

    record_type = named_tuple[0]
    return record_type(*[_converted(attr, convert) for attr, convert in attributes])
Example #20
Source File: win32.py From multibootusb with GNU General Public License v2.0 | 5 votes |
def findVolumeGuids():
    """Enumerate volumes and, for removable/fixed drives, their disk extents.

    Returns:
        A list of ``Volume(Guid, MediaType, DosDevice, Extents)`` namedtuples,
        one per volume returned by FindFirstVolume/FindNextVolume.
        ``MediaType`` is the ``win32file.GetDriveType`` value. ``Extents`` is
        a list of ``DiskExtent(DiskNumber, StartingOffset, ExtentLength)``,
        left empty for drive types other than removable/fixed.
    """
    DiskExtent = collections.namedtuple(
        'DiskExtent', ['DiskNumber', 'StartingOffset', 'ExtentLength'])
    Volume = collections.namedtuple(
        'Volume', ['Guid', 'MediaType', 'DosDevice', 'Extents'])

    # VOLUME_DISK_EXTENTS layout: DWORD NumberOfDiskExtents + 4 bytes of
    # alignment padding, followed by DISK_EXTENT records of
    # DWORD DiskNumber (+4 padding), LARGE_INTEGER StartingOffset and
    # LARGE_INTEGER ExtentLength. Skipping the padding explicitly keeps any
    # garbage in it out of the parsed values.
    header_fmt = '<I4x'
    extent_fmt = '<I4xqq'
    header_size = struct.calcsize(header_fmt)
    extent_size = struct.calcsize(extent_fmt)

    found = []
    h, guid = FindFirstVolume()
    while h and guid:
        extents = []
        driveType = win32file.GetDriveType(guid)
        if driveType in [win32con.DRIVE_REMOVABLE, win32con.DRIVE_FIXED]:
            # Open the volume only when it is actually queried, and always
            # release the handle (it was previously leaked for every volume).
            # guid[:-1] drops the trailing backslash so the volume device
            # itself is opened.
            hVolume = win32file.CreateFile(
                guid[:-1], win32con.GENERIC_READ,
                win32con.FILE_SHARE_READ | win32con.FILE_SHARE_WRITE,
                None, win32con.OPEN_EXISTING,
                win32con.FILE_ATTRIBUTE_NORMAL, None)
            try:
                x = win32file.DeviceIoControl(
                    hVolume, winioctlcon.IOCTL_VOLUME_GET_VOLUME_DISK_EXTENTS,
                    None, 512, None)
            finally:
                hVolume.Close()
            instream = io.BytesIO(x)
            numRecords = struct.unpack(header_fmt, instream.read(header_size))[0]
            # Parse only the reported extents, stopping early on a short buffer.
            for _ in range(numRecords):
                b = instream.read(extent_size)
                if len(b) < extent_size:
                    break
                extents.append(DiskExtent(*struct.unpack(extent_fmt, b)))
        vinfo = Volume(guid, driveType,
                       win32file.QueryDosDevice(guid[4:-1]), extents)
        found.append(vinfo)
        guid = FindNextVolume(h)
    return found
Example #21
Source File: imager.py From multibootusb with GNU General Public License v2.0 | 5 votes |
def imager_usb_detail(physical_disk):
    """
    Function to detect details of USB disk using lsblk
    :param physical_disk: /dev/sd? (linux) or integer disk number (win)
    :return: namedtuple ``usage(total_size, usb_type, model)``. On Linux any
             field that cannot be determined is "Unknown".
    """
    _ntuple_diskusage = collections.namedtuple(
        'usage', 'total_size usb_type model')

    if platform.system() != "Linux":
        dinfo = osdriver.wmi_get_physicaldrive_info_ex(physical_disk)
        return _ntuple_diskusage(*[dinfo[a] for a in [
            'size_total', 'mediatype', 'model']])

    # Defaults, so a disk that lsblk does not report as a removable 'disk'
    # yields "Unknown" fields instead of raising UnboundLocalError.
    total_size = usb_type = model = "Unknown"
    # Arguments are passed as a list (no shell) so the device path is never
    # interpreted by a shell.
    output = subprocess.check_output(["lsblk", "-ib", physical_disk])
    for line in output.splitlines():
        fields = line.split()
        # Assumes lsblk's default columns: NAME MAJ:MIN RM SIZE RO TYPE ...
        # (RM == 1 marks a removable device).
        if (len(fields) > 5 and fields[2].strip() == b'1'
                and fields[5].strip() == b'disk'):
            total_size = fields[3] or "Unknown"
            usb_type = "Removable"
            model = subprocess.check_output(
                ["lsblk", "-in", "-f", "-o", "MODEL", physical_disk]
            ).decode().strip()
            if not model:
                model = "Unknown"
            break
    return _ntuple_diskusage(total_size, usb_type, model)
Example #22
Source File: usb.py From multibootusb with GNU General Public License v2.0 | 5 votes |
def disk_usage(mount_path):
    """
    Return disk usage statistics about the given path as a (total, used, free)
    namedtuple.  Values are expressed in bytes.

    Works on any platform that provides ``os.statvfs`` (Linux, macOS and
    other POSIX systems) and on Windows. Anything else raises
    NotImplementedError.
    """
    # Author: Giampaolo Rodola' <g.rodola [AT] gmail [DOT] com>
    # License: MIT
    _ntuple_diskusage = collections.namedtuple('usage', 'total used free')
    if hasattr(os, "statvfs"):
        st = os.statvfs(mount_path)
        # f_bavail: blocks available to unprivileged users.
        free = st.f_bavail * st.f_frsize
        total = st.f_blocks * st.f_frsize
        used = (st.f_blocks - st.f_bfree) * st.f_frsize
        return _ntuple_diskusage(total, used, free)
    elif platform.system() == "Windows":
        available, total, free = (ctypes.c_ulonglong(), ctypes.c_ulonglong(),
                                  ctypes.c_ulonglong())
        # On Python 3 every str is unicode, so the wide-char API is used;
        # `unicode` is only evaluated on Python 2.
        if sys.version_info >= (3,) or isinstance(mount_path, unicode):
            fun = ctypes.windll.kernel32.GetDiskFreeSpaceExW
        else:
            fun = ctypes.windll.kernel32.GetDiskFreeSpaceExA
        ret = fun(mount_path, ctypes.byref(available), ctypes.byref(total),
                  ctypes.byref(free))
        if ret == 0:
            raise ctypes.WinError()
        used = total.value - free.value
        return _ntuple_diskusage(total.value, used, free.value)
    else:
        raise NotImplementedError("Platform not supported.")
Example #23
Source File: heroku.py From friendly-telegram with GNU Affero General Public License v3.0 | 5 votes |
def __init__(self, **kwargs):
    """Initialise the base class, then do Heroku-specific setup when the
    ``heroku_api_token`` environment variable is present."""
    super().__init__(**kwargs)
    if "heroku_api_token" not in os.environ:
        return
    # This is called before asyncio is even set up. We can only use sync methods which is fine.
    ApiToken = collections.namedtuple("api_token", ["ID", "HASH"])
    api_token = ApiToken(os.environ["api_id"], os.environ["api_hash"])
    client_values = [entry[1] for entry in self.client_data]
    app, config = heroku.get_app(client_values, os.environ["heroku_api_token"],
                                 api_token, False, True)
    if os.environ["DYNO"].startswith("web."):
        app.scale_formation_process("worker-DO-NOT-TURN-ON-OR-THINGS-WILL-BREAK", 0)
    atexit.register(functools.partial(exit_handler, app))
Example #24
Source File: initial_setup.py From friendly-telegram with GNU Affero General Public License v3.0 | 5 votes |
async def set_tg_api(self, request):
    """Accept the Telegram API hash and ID posted in the request body.

    The body is the 32-character hexadecimal API hash immediately followed by
    the all-digit API ID. On success the pair is written to ``api_token.py``
    in ``utils.get_base_dir()``, stored on ``self.api_token``, and
    ``self.api_set`` is set.

    Must be a coroutine (``async def``) because it awaits
    ``self.check_user`` and ``request.text()``.

    Returns:
        A 302 redirect to "/" if client data exists and ``check_user``
        returns None; a 400 response for a malformed body; otherwise a
        default ``web.Response()``.
    """
    if self.client_data and await self.check_user(request) is None:
        return web.Response(status=302, headers={"Location": "/"})  # They gotta sign in.
    text = await request.text()
    if len(text) < 36:
        return web.Response(status=400)
    api_id = text[32:]
    api_hash = text[:32]
    # This validation also makes the values safe to embed in the generated
    # Python source below.
    if any(c not in string.hexdigits for c in api_hash) or any(c not in string.digits for c in api_id):
        return web.Response(status=400)
    with open(os.path.join(utils.get_base_dir(), "api_token.py"), "w") as f:
        f.write("HASH = \"" + api_hash + "\"\nID = \"" + api_id + "\"\n")
    self.api_token = collections.namedtuple("api_token", ("ID", "HASH"))(api_id, api_hash)
    self.api_set.set()
    return web.Response()
Example #25
Source File: main.py From friendly-telegram with GNU Affero General Public License v3.0 | 5 votes |
def get_api_token():
    """Get API Token from disk or environment.

    Returns:
        The package's ``api_token`` module if it can be imported. Otherwise
        an ``api_token(ID, HASH)`` namedtuple built from the ``api_id`` and
        ``api_hash`` environment variables, or None if either is unset.
    """
    # The original wrapped this in a `while True:` loop that always returned
    # on its first pass; the loop has been removed.
    try:
        from . import api_token
    except ImportError:
        pass
    else:
        return api_token
    try:
        return collections.namedtuple("api_token", ("ID", "HASH"))(
            os.environ["api_id"], os.environ["api_hash"])
    except KeyError:
        return None
Example #26
Source File: utils.py From dynamic-training-with-apache-mxnet-on-aws with Apache License 2.0 | 5 votes |
def namedtuple_with_defaults(typename, field_names, default_values=()):
    """Create a namedtuple class whose fields all have default values.

    Args:
        typename: name of the new namedtuple type.
        field_names: field names in any form ``collections.namedtuple``
            accepts.
        default_values: either a mapping of field name -> default, or a
            sequence of defaults for the leading fields. Fields without an
            explicit default get None.

    Returns:
        The new namedtuple class.
    """
    # `collections.Mapping` was removed in Python 3.10; use collections.abc.
    try:
        from collections.abc import Mapping
    except ImportError:  # Python 2
        from collections import Mapping
    T = collections.namedtuple(typename, field_names)
    # Make every field default to None first, so the prototype below can be
    # built from a partial set of defaults.
    T.__new__.__defaults__ = (None, ) * len(T._fields)
    if isinstance(default_values, Mapping):
        prototype = T(**default_values)
    else:
        prototype = T(*default_values)
    T.__new__.__defaults__ = tuple(prototype)
    return T
Example #27
Source File: utils.py From dynamic-training-with-apache-mxnet-on-aws with Apache License 2.0 | 5 votes |
def zip_namedtuple(nt_list):
    """Zip the fields of one or more namedtuples into a dict of lists.

    Args:
        nt_list: a single namedtuple, or a list/tuple/other iterable of
            namedtuples that all share the same type. None or an empty
            input is allowed.

    Returns:
        dict mapping each field name to the list of that field's values, in
        input order. Empty input gives an empty dict.

    Raises:
        TypeError: if the namedtuples are not all of the same type.
    """
    # A namedtuple is itself a tuple, so recognise a single record by its
    # _asdict API before treating the argument as a collection.
    if hasattr(nt_list, "_asdict"):
        records = [nt_list]
    else:
        records = list(nt_list or ())
    if not records:
        return dict()
    first_type = type(records[0])
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if any(type(nt) is not first_type for nt in records):
        raise TypeError("zip_namedtuple requires namedtuples of a single type")
    ret = {}
    for nt in records:
        for k, v in nt._asdict().items():
            ret.setdefault(k, []).append(v)
    return ret
Example #28
Source File: super_resolution.py From dynamic-training-with-apache-mxnet-on-aws with Apache License 2.0 | 5 votes |
def perform_inference(sym, arg_params, aux_params, input_img, img_cb, img_cr):
    """Perform inference on image using mxnet"""
    metadata = onnx_mxnet.get_model_metadata('super_resolution.onnx')
    data_names = [entry[0] for entry in metadata.get('input_tensor_data')]

    # create module
    mod = mx.mod.Module(symbol=sym, data_names=data_names, label_names=None)
    mod.bind(for_training=False, data_shapes=[(data_names[0], input_img.shape)])
    mod.set_params(arg_params=arg_params, aux_params=aux_params)

    # run inference
    Batch = namedtuple('Batch', ['data'])
    mod.forward(Batch([mx.nd.array(input_img)]))

    # Turn the model output into an 8-bit "L" image used as the Y channel, and
    # merge it with Cb/Cr resized (bicubic) to the same size.
    luma = mod.get_outputs()[0][0][0].asnumpy().clip(0, 255)
    img_out_y = Image.fromarray(np.uint8(luma), mode='L')
    target_size = img_out_y.size
    channels = [
        img_out_y,
        img_cb.resize(target_size, Image.BICUBIC),
        img_cr.resize(target_size, Image.BICUBIC),
    ]
    result_img = Image.merge("YCbCr", channels).convert("RGB")

    output_img_dim = 672
    assert result_img.size == (output_img_dim, output_img_dim)
    LOGGER.info("Super Resolution example success.")
    result_img.save("super_res_output.jpg")
    return result_img
Example #29
Source File: onnx_import_test.py From dynamic-training-with-apache-mxnet-on-aws with Apache License 2.0 | 5 votes |
def test_bvlc_googlenet():
    """ Tests Googlenet model"""
    model_path, inputs, outputs = get_test_files('bvlc_googlenet')
    logging.info("Translating Googlenet model from ONNX to Mxnet")
    sym, arg_params, aux_params = onnx_mxnet.import_model(model_path)

    metadata = onnx_mxnet.get_model_metadata(model_path)
    input_meta = metadata.get('input_tensor_data')
    output_meta = metadata.get('output_tensor_data')
    assert len(metadata) == 2
    assert input_meta
    assert input_meta == [(u'data_0', (1, 3, 224, 224))]
    assert output_meta
    assert output_meta == [(u'prob_1', (1, 1000))]
    data_names = [entry[0] for entry in input_meta]

    Batch = namedtuple('Batch', ['data'])
    # run test for each test file
    for input_data, output_data in zip(inputs, outputs):
        # create module
        mod = mx.mod.Module(symbol=sym, data_names=data_names,
                            context=mx.cpu(), label_names=None)
        mod.bind(for_training=False,
                 data_shapes=[(data_names[0], input_data.shape)],
                 label_shapes=None)
        mod.set_params(arg_params=arg_params, aux_params=aux_params,
                       allow_missing=True, allow_extra=True)
        # run inference
        mod.forward(Batch([mx.nd.array(input_data)]), is_train=False)

        # verify the results
        prediction = mod.get_outputs()[0]
        npt.assert_equal(prediction.shape, output_data.shape)
        npt.assert_almost_equal(output_data, prediction.asnumpy(), decimal=3)
    logging.info("Googlenet model conversion Successful")
Example #30
Source File: onnx_import_test.py From dynamic-training-with-apache-mxnet-on-aws with Apache License 2.0 | 5 votes |
def test_bvlc_rcnn_ilsvrc13():
    """Tests the bvlc rcnn model"""
    model_path, inputs, outputs = get_test_files('bvlc_reference_rcnn_ilsvrc13')
    logging.info("Translating rcnn_ilsvrc13 model from ONNX to Mxnet")
    sym, arg_params, aux_params = onnx_mxnet.import_model(model_path)

    metadata = onnx_mxnet.get_model_metadata(model_path)
    input_meta = metadata.get('input_tensor_data')
    output_meta = metadata.get('output_tensor_data')
    assert len(metadata) == 2
    assert input_meta
    assert input_meta == [(u'data_0', (1, 3, 224, 224))]
    assert output_meta
    assert output_meta == [(u'fc-rcnn_1', (1, 200))]
    data_names = [entry[0] for entry in input_meta]

    Batch = namedtuple('Batch', ['data'])
    # run test for each test file
    for input_data, output_data in zip(inputs, outputs):
        # create module
        mod = mx.mod.Module(symbol=sym, data_names=data_names,
                            context=mx.cpu(), label_names=None)
        mod.bind(for_training=False,
                 data_shapes=[(data_names[0], input_data.shape)],
                 label_shapes=None)
        mod.set_params(arg_params=arg_params, aux_params=aux_params,
                       allow_missing=True, allow_extra=True)
        # run inference
        mod.forward(Batch([mx.nd.array(input_data)]), is_train=False)

        # verify the results
        prediction = mod.get_outputs()[0]
        npt.assert_equal(prediction.shape, output_data.shape)
        npt.assert_almost_equal(output_data, prediction.asnumpy(), decimal=3)
    logging.info("rcnn_ilsvrc13 model conversion Successful")