Python contextlib.ExitStack() Examples

The following are 30 code examples of contextlib.ExitStack(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module contextlib, or try the search function.
Example #1
Source File: ptpimg_uploader.py    From ptpimg-uploader with BSD 2-Clause "Simplified" License 7 votes vote down vote up
def upload_files(self, *filenames):
    """Upload local image files through the multipart upload form.

    Each path is opened in binary mode and attached as
    ``file-upload[<index>]``. Every handle is registered on an ExitStack,
    so all of them are closed once ``self._perform`` returns or raises.

    Raises:
        ValueError: if a file's guessed MIME type is missing or is not an
            ``image/*`` type.
    """
    with contextlib.ExitStack() as stack:
        form_fields = {}
        for index, path in enumerate(filenames):
            handle = stack.enter_context(open(path, 'rb'))
            guessed_type = mimetypes.guess_type(path)[0]
            is_image = bool(guessed_type) and guessed_type.split('/')[0] == 'image'
            if not is_image:
                raise ValueError(
                    'Unknown image file type {}'.format(guessed_type))

            upload_name = os.path.basename(path)
            # Until https://github.com/shazow/urllib3/issues/303 is resolved,
            # only send the real filename when it survives Latin-1 encoding.
            try:
                upload_name.encode('latin1')
            except UnicodeEncodeError:
                upload_name = 'justfilename'
            form_fields['file-upload[{}]'.format(index)] = (
                upload_name, handle, guessed_type)
        return self._perform(files=form_fields)
Example #2
Source File: conftest.py    From indy-plenum with Apache License 2.0 6 votes vote down vote up
def txnPoolNodeSet(node_config_helper_class,
                   patchPluginManager,
                   txnPoolNodesLooper,
                   tdirWithPoolTxns,
                   tdirWithDomainTxns,
                   tdir,
                   tconf,
                   poolTxnNodeNames,
                   allPluginsPath,
                   tdirWithNodeKeepInited,
                   testNodeClass,
                   do_post_node_creation,
                   testNodeBootstrapClass):
    """Create, register and connect one test node per pool transaction name.

    Every ``create_new_test_node`` context is held on a single ExitStack, so
    all nodes are torn down, in reverse creation order, when the generator
    resumes after its ``yield``. Several parameters are not used in the
    body; presumably they are requested only for their setup side effects
    (confirm against the fixture definitions).

    Yields:
        list: the created nodes, after ``checkNodesConnected`` has run on the
        looper and ``ensureElectionsDone`` has returned.
    """
    looper = txnPoolNodesLooper
    with ExitStack() as teardown:
        pool = []
        for node_name in poolTxnNodeNames:
            node_cm = create_new_test_node(
                testNodeClass, node_config_helper_class, node_name, tconf,
                tdir, allPluginsPath, bootstrap_cls=testNodeBootstrapClass)
            node = teardown.enter_context(node_cm)
            do_post_node_creation(node)
            looper.add(node)
            pool.append(node)
        looper.run(checkNodesConnected(pool))
        ensureElectionsDone(looper=looper, nodes=pool)
        yield pool
Example #3
Source File: pool.py    From aioredis with MIT License 6 votes vote down vote up
async def discover_slave(self, service, timeout, **kwargs):
    """Perform Slave discovery for specified service.

    Tries each known sentinel pool in turn. For each one it asks for a
    slave address, opens a connection to that address and verifies that the
    server reports the ``slave`` role. The first verified connection is
    returned.

    Args:
        service: name of the service to look up.
        timeout: seconds allowed for the address lookup and, separately, for
            connecting plus role verification. Also reused as the pause
            before moving to the next sentinel after a failure.
        **kwargs: currently unused (see TODO).

    Returns:
        The verified connection, left open for the caller.

    Raises:
        SlaveReplyError: if a RedisError is raised during lookup,
            connection or verification.
        SlaveNotFoundError: if no sentinel yields a usable slave.
        asyncio.CancelledError: propagated unchanged.
    """
    # TODO: use kwargs to change how slaves are picked up
    #   (eg: round-robin, priority, random, etc)
    idle_timeout = timeout
    # Iterate over a copy; presumably so that changes to self._pools made
    # while we are suspended in the awaits below do not affect iteration.
    pools = self._pools[:]
    for sentinel in pools:
        try:
            with async_timeout(timeout):
                address = await self._get_slave_address(
                    sentinel, service)  # add **kwargs
            pool = self._slaves[service]
            with async_timeout(timeout), \
                    contextlib.ExitStack() as stack:
                conn = await pool._create_new_connection(address)
                # Close the connection if verification fails or times out;
                # pop_all() disarms this callback on success.
                stack.callback(conn.close)
                await self._verify_service_role(conn, 'slave')
                stack.pop_all()
            return conn
        except asyncio.CancelledError:
            raise
        except asyncio.TimeoutError:
            continue
        except DiscoverError:
            await asyncio.sleep(idle_timeout)
            continue
        except RedisError as err:
            raise SlaveReplyError(
                "Service {} error".format(service), err) from err
        except Exception:
            # Any other failure: back off, then try the next sentinel.
            await asyncio.sleep(idle_timeout)
            continue
    raise SlaveNotFoundError("No slave found for {}".format(service))
Example #4
Source File: test__task_commons.py    From tf-yarn with Apache License 2.0 6 votes vote down vote up
def test__execute_dispatched_function():
    """The dispatched function runs train_and_evaluate in a thread that ends SUCCEEDED."""
    recorded_args = []

    with contextlib.ExitStack() as stack:
        # Patch the event helper, the estimator entry point and the cluster
        # lookup for the duration of the test.
        event_mock = stack.enter_context(patch(f'{MODULE_TO_TEST}.event'))
        train_mock = stack.enter_context(
            patch(f'{MODULE_TO_TEST}.tf.estimator.train_and_evaluate'))
        train_mock.side_effect = lambda *call_args: recorded_args.append(call_args)
        cluster_mock = stack.enter_context(patch(f'{MODULE_TO_TEST}.cluster'))
        cluster_mock.get_task_description.return_value = ("worker", "0")

        client = mock.MagicMock(spec=skein.ApplicationClient)
        experiment = Experiment(None, None, None)
        worker = _execute_dispatched_function(client, experiment)
        worker.join()

        event_mock.start_event.assert_called_once()
        assert recorded_args == [(None, None, None)]
        assert worker.state == 'SUCCEEDED'
Example #5
Source File: shadow_demo.py    From ratcave with MIT License 6 votes vote down vote up
def on_draw():
    """Draw handler: render the monkey grid and the plane in two passes.

    Pass 1 binds ``rc.resources.shadow_shader`` and enters ``fbo_shadow`` on
    the ExitStack, so that pass draws into the framebuffer object (presumably
    the shadow map; confirm). Pass 2 binds ``rc.default_shader`` and first
    calls ``stack.close()``, which exits ``fbo_shadow`` before drawing
    (presumably back to the window's default framebuffer).
    """
    window.clear()
    with ExitStack() as stack:
        for shader in [rc.resources.shadow_shader, rc.default_shader]:
            with shader, rc.default_states, light, rc.default_camera:
                if shader == rc.resources.shadow_shader:
                    # Entered on the outer stack, not in this `with`, so
                    # fbo_shadow stays entered past the end of this pass and
                    # is only exited by stack.close() in the next pass.
                    stack.enter_context(fbo_shadow)
                    window.clear()
                else:
                    # Exit fbo_shadow, entered during the previous pass.
                    stack.close()

                # 5x5 grid of monkeys at integer x, y in [-2, 2].
                for x, y in it.product([-2, -1, 0, 1, 2], [-2, -1, 0, 1, 2]):
                    monkey.position.x = x
                    monkey.position.y = y
                    # Points when both coordinates are odd, triangles otherwise.
                    monkey.drawmode = rc.GL_POINTS if x % 2 and y % 2 else rc.GL_TRIANGLES
                    # Diffuse colour varies with x (first channel) and y (rest).
                    monkey.uniforms['diffuse'][0] = (x + 1) / 4.
                    monkey.uniforms['diffuse'][1:] = (y + 1) / 4.
                    # z-scale grows with distance from the origin.
                    monkey.scale.z = np.linalg.norm((x, y)) / 10. + .03
                    monkey.draw()

                plane.draw()

    fps_display.draw()
Example #6
Source File: test__task_commons.py    From tf-yarn with Apache License 2.0 6 votes vote down vote up
def test__prepare_container():
    """_prepare_container wires the skein client, log setup and cluster start together."""
    host_port = ('localhost', 1234)
    instances = [('worker', 10), ('chief', 1)]

    with contextlib.ExitStack() as stack:
        # Patch the client factory, container-log setup and cluster start.
        from_current = stack.enter_context(
            patch(f"{MODULE_TO_TEST}.skein.ApplicationClient.from_current"))
        setup_logs = stack.enter_context(
            patch(f'{MODULE_TO_TEST}._setup_container_logs'))
        start_cluster = stack.enter_context(
            patch(f'{MODULE_TO_TEST}.cluster.start_cluster'))

        # The client returned by from_current() answers kv.wait with the
        # JSON-encoded instance list.
        fake_client = mock.MagicMock(spec=skein.ApplicationClient)
        fake_client.kv.wait.return_value = json.dumps(instances).encode()
        from_current.return_value = fake_client

        client, _cluster_spec, cluster_tasks = _prepare_container(host_port)

        setup_logs.assert_called_once()
        start_cluster.assert_called_once_with(host_port, fake_client, cluster_tasks)
        assert client == fake_client
        assert cluster_tasks == list(iter_tasks(instances))
Example #7
Source File: test_cluster.py    From tf-yarn with Apache License 2.0 6 votes vote down vote up
def test_start_tf_server(task_name, task_index, is_server_started):
    """start_tf_server creates a tf Server only for the expected task."""
    cluster_spec = {
        "worker": [f"worker0.{WORKER0_HOST}:{WORKER0_PORT}",
                   f"worker1.{WORKER1_HOST}:{WORKER1_PORT}"],
        "ps": [f"ps0.{CURRENT_HOST}:{CURRENT_PORT}"],
    }

    with contextlib.ExitStack() as stack:
        # patch.dict restores os.environ when the stack unwinds.
        stack.enter_context(mock.patch.dict(os.environ))
        os.environ["SKEIN_CONTAINER_ID"] = f"{task_name}_{task_index}"
        distribute_mock = stack.enter_context(
            mock.patch(f"{MODULE_TO_TEST}.tf.distribute"))
        cluster.start_tf_server(cluster_spec)

        server_cls = distribute_mock.Server
        if not is_server_started:
            assert server_cls.call_count == 0
            return

        assert server_cls.call_count == 1
        _, server_kwargs = server_cls.call_args
        assert server_kwargs["job_name"] == task_name
        assert server_kwargs["task_index"] == task_index
        assert server_kwargs["start"] is True
Example #8
Source File: configure_synapse_test.py    From synapse-tools with Apache License 2.0 6 votes vote down vote up
def mock_available_location_types():
    """Patch ``available_location_types`` in both modules that look it up.

    Both patches return the same fixed list of location types. The patches
    stay active while the generator is suspended at its ``yield`` and are
    undone, in reverse order, when it resumes or is closed.

    Yields:
        tuple: the two mocks, in the same order as the patch targets.
    """
    location_types = [
        'runtimeenv',
        'ecosystem',
        'superregion',
        'region',
        'habitat',
    ]
    targets = (
        'environment_tools.type_utils.available_location_types',
        'synapse_tools.configure_synapse.available_location_types',
    )
    patchers = [mock.patch(target, return_value=location_types) for target in targets]

    with contextlib.ExitStack() as stack:
        mocks = tuple(stack.enter_context(patcher) for patcher in patchers)
        yield mocks
Example #9
Source File: zmq_decor.py    From bert-as-service with MIT License 6 votes vote down vote up
def __call__(self, *dec_args, **dec_kwargs):
    """Build a decorator that appends ``num_socket`` entered objects to each call.

    ``dec_kwargs['num_socket']`` names an attribute of the decorated
    function's first positional argument (presumably ``self``); its value is
    how many objects to create per call. Each object is the value of
    entering ``self.get_target(*args, **kwargs)(*dec_args, **dec_kwargs)``.
    The objects are appended as extra trailing positional arguments and are
    all exited when the decorated call returns.
    """
    _, dec_args, dec_kwargs = self.process_decorator_args(*dec_args, **dec_kwargs)
    socket_count_attr = dec_kwargs.pop('num_socket')

    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            count = getattr(args[0], socket_count_attr)
            factories = [self.get_target(*args, **kwargs) for _ in range(count)]
            with ExitStack() as stack:
                injected = tuple(
                    stack.enter_context(factory(*dec_args, **dec_kwargs))
                    for factory in factories
                )
                return func(*(args + injected), **kwargs)

        return wrapper

    return decorator
Example #10
Source File: test_start_vc_ts_in_node_info.py    From indy-plenum with Apache License 2.0 6 votes vote down vote up
def create_node_and_not_start(testNodeClass,
                              node_config_helper_class,
                              tconf,
                              tdir,
                              allPluginsPath,
                              looper,
                              tdirWithPoolTxns,
                              tdirWithDomainTxns,
                              tdirWithNodeKeepInited):
    """Yield a single "Alpha" test node that is created but never started.

    When the consumer resumes the generator, ``node.stop()`` is called and
    then the ExitStack exits the ``create_new_test_node`` context. The
    ``looper`` and ``tdirWith*`` parameters are not used in the body.
    """
    with ExitStack() as cleanup:
        alpha = cleanup.enter_context(
            create_new_test_node(testNodeClass, node_config_helper_class,
                                 "Alpha", tconf, tdir, allPluginsPath))
        yield alpha
        alpha.stop()
Example #11
Source File: ptpimg_uploader.py    From ptpimg-uploader with BSD 2-Clause "Simplified" License 6 votes vote down vote up
def upload_urls(self, *urls):
    """Download each image URL and upload the bytes through the form.

    Every URL is fetched with ``requests.get`` (using ``self.timeout``) and
    checked for an OK status and an ``image/*`` Content-Type. The body is
    wrapped in an in-memory ``BytesIO`` and attached as
    ``file-upload[<index>]`` under the name ``file-<index>``. The buffers are
    closed by the ExitStack once ``self._perform`` returns or raises.

    Raises:
        ValueError: if a download does not return HTTP OK, or its
            Content-Type header is missing or not an ``image/*`` type.
    """
    with contextlib.ExitStack() as stack:
        files = {}
        for i, url in enumerate(urls):
            resp = requests.get(url, timeout=self.timeout)
            if resp.status_code != requests.codes.ok:
                raise ValueError(
                    'Cannot fetch url {} with error {}'.format(url, resp.status_code))

            # .get(): a response without a Content-Type header must fail the
            # image check below with ValueError instead of escaping as KeyError.
            mime_type = resp.headers.get('content-type')
            if not mime_type or mime_type.split('/')[0] != 'image':
                raise ValueError(
                    'Unknown image file type {}'.format(mime_type))
            open_file = stack.enter_context(BytesIO(resp.content))
            files['file-upload[{}]'.format(i)] = (
                'file-{}'.format(i), open_file, mime_type)

        return self._perform(files=files)
Example #12
Source File: conftest.py    From indy-plenum with Apache License 2.0 6 votes vote down vote up
def txnPoolNodeSetNotStarted(node_config_helper_class,
                             patchPluginManager,
                             txnPoolNodesLooper,
                             tdirWithPoolTxns,
                             tdirWithDomainTxns,
                             tdir,
                             tconf,
                             poolTxnNodeNames,
                             allPluginsPath,
                             tdirWithNodeKeepInited,
                             testNodeClass,
                             do_post_node_creation):
    """Create one test node per pool transaction name without starting them.

    Nodes are not added to any looper here, and no connection or election
    checks are run. All ``create_new_test_node`` contexts are exited, in
    reverse order, when the generator resumes after its ``yield``.

    Yields:
        list: the created nodes, in ``poolTxnNodeNames`` order.
    """
    with ExitStack() as teardown:
        pool = []
        for node_name in poolTxnNodeNames:
            node = teardown.enter_context(create_new_test_node(
                testNodeClass, node_config_helper_class, node_name,
                tconf, tdir, allPluginsPath))
            do_post_node_creation(node)
            pool.append(node)
        yield pool
Example #13
Source File: conftest.py    From indy-plenum with Apache License 2.0 6 votes vote down vote up
def create_node_and_not_start(testNodeClass,
                              node_config_helper_class,
                              tconf,
                              tdir,
                              allPluginsPath,
                              looper,
                              tdirWithPoolTxns,
                              tdirWithDomainTxns,
                              tdirWithNodeKeepInited):
    """Yield an "Alpha" test node that has been created but not started.

    Before yielding, ``write_manager.on_catchup_finished()`` is called on the
    new node. When the consumer resumes the generator, the node is stopped
    and then its ``create_new_test_node`` context is exited.
    """
    with ExitStack() as cleanup:
        alpha = cleanup.enter_context(create_new_test_node(
            testNodeClass, node_config_helper_class, "Alpha",
            tconf, tdir, allPluginsPath))
        alpha.write_manager.on_catchup_finished()
        yield alpha
        alpha.stop()
Example #14
Source File: base.py    From rankedftw with GNU Affero General Public License v3.0 6 votes vote down vote up
def __call__(self):
    """Run the command and terminate the process with its exit status.

    Parses and logs the arguments, logs the cached config, then calls
    ``self.run(args, logger)`` inside an ExitStack. The stack holds a PID-file
    guard when ``self._pid_file`` is truthy and a ``stoppable()`` guard when
    ``self._stoppable`` is truthy. A falsy return value from ``run`` is
    mapped to 0. If ``run`` raises and ``log_exception`` lets execution reach
    ``sys.exit`` (not visible here), the status stays 1.
    """
    exit_status = 1
    with log_exception(status=1):
        parsed = self.parser.parse_args()
        log_args(parsed)
        config.log_cached()
        django_logger = getLogger('django')

        with ExitStack() as guards:
            if self._pid_file:
                guards.enter_context(
                    pid_file(dirname=config.PID_DIR, max_age=self._pid_file_max_age))

            if self._stoppable:
                self._stoppable_instance = stoppable()
                guards.enter_context(self._stoppable_instance)

            exit_status = self.run(parsed, django_logger) or 0
    sys.exit(exit_status)
Example #15
Source File: data.py    From dl4mt-nonauto with BSD 3-Clause "New" or "Revised" License 6 votes vote down vote up
def __init__(self, path=None, exts=None, fields=None,
             load_dataset=False, prefix='', examples=None, **kwargs):
    """Build an N-way parallel text dataset.

    Args:
        path: common path prefix of the parallel files; ``~`` is expanded.
        exts: one suffix per file; ``path + ext`` is read for each.
        fields: one field per file; must be the same length as ``exts``.
        load_dataset: if True, reuse ``<path>.processed.<prefix>.pt`` when it
            exists; otherwise write that file after reading the text files.
        prefix: tag inserted into the cache file name.
        examples: pre-built examples; when given, no files are read.
        **kwargs: forwarded to the constructor after
            ``datasets.TranslationDataset`` in the MRO, bypassing
            TranslationDataset's own ``__init__``.

    Only line tuples in which every file's stripped line is non-empty are kept.
    """
    if examples is None:
        assert len(exts) == len(fields), 'N parallel dataset must match'
        self.N = len(fields)

        # Expand ``~`` once, so the text files and the cache file resolve to
        # the same location (previously only the text-file paths were expanded).
        path = os.path.expanduser(path)
        paths = tuple(path + x for x in exts)
        cache_path = path + '.processed.{}.pt'.format(prefix)
        if load_dataset and os.path.exists(cache_path):
            examples = torch.load(cache_path)
        else:
            examples = []
            with ExitStack() as stack:
                # Decode explicitly as UTF-8 rather than with the platform's
                # locale encoding.
                files = [stack.enter_context(open(fname, encoding='utf-8'))
                         for fname in paths]
                for lines in zip(*files):
                    lines = [line.strip() for line in lines]
                    if all(lines):
                        examples.append(data.Example.fromlist(lines, fields))
            if load_dataset:
                torch.save(examples, cache_path)

    super(datasets.TranslationDataset, self).__init__(examples, fields, **kwargs)
Example #16
Source File: evaluate_custom.py    From OpenBookQA with Apache License 2.0 6 votes vote down vote up
def evaluate(model: Model,
             instances: Iterable[Instance],
             data_iterator: DataIterator,
             output_file: str = None) -> Dict[str, Any]:
    """Run ``model`` over ``instances`` for one epoch and return its metrics.

    The model is put in eval mode and called on every batch produced by
    ``data_iterator(instances, num_epochs=1)``. The running metrics are shown
    in the tqdm progress-bar description. When ``output_file`` is given, each
    batch's ``"metadata"`` and the model output are written to it via
    ``_persist_data``, together with the "labels" index-to-token map.

    Returns:
        The result of ``model.get_metrics()`` after the last batch.
    """
    model.eval()

    iterator = data_iterator(instances, num_epochs=1)
    logger.info("Iterating over dataset")
    generator_tqdm = Tqdm.tqdm(iterator, total=data_iterator.get_num_batches(instances))
    with ExitStack() as stack:
        file_handle = None
        id2label = None
        if output_file is not None:
            file_handle = stack.enter_context(open(output_file, 'w'))
            # This lookup takes no per-batch input, so do it once here rather
            # than once per batch inside the loop.
            id2label = model.vocab.get_index_to_token_vocabulary("labels")
        for batch in generator_tqdm:
            model_output = model(**batch)
            metrics = model.get_metrics()
            if file_handle is not None:
                _persist_data(file_handle, batch.get("metadata"), model_output, id2label=id2label)
            description = ', '.join(["%s: %.2f" % (name, value) for name, value in metrics.items()]) + " ||"
            generator_tqdm.set_description(description)

    return model.get_metrics()
Example #17
Source File: evaluate_predictions_qa_mc_know_visualize.py    From OpenBookQA with Apache License 2.0 6 votes vote down vote up
def evaluate(model: Model,
             instances: Iterable[Instance],
             data_iterator: DataIterator,
             output_file: str = None,
             eval_type: str = None) -> Dict[str, Any]:
    """Run ``model`` over ``instances`` for one epoch and return its metrics.

    Each batch from ``data_iterator(instances, num_epochs=1)`` is fed to the
    model in eval mode, and the running metrics are shown on the progress
    bar. When ``output_file`` is given, each batch's ``"metadata"``, the
    model output and ``eval_type`` are written to it via ``_persist_data``.

    Returns:
        ``model.get_metrics(reset=True)`` after the last batch (presumably
        also resetting the model's metric state; confirm).
    """
    model.eval()

    batches = data_iterator(instances, num_epochs=1)
    logger.info("Iterating over dataset")
    progress = Tqdm.tqdm(batches, total=data_iterator.get_num_batches(instances))
    with ExitStack() as stack:
        out_file = (None if output_file is None
                    else stack.enter_context(open(output_file, 'w')))
        for batch in progress:
            model_output = model(**batch)
            metrics = model.get_metrics()
            if out_file:
                _persist_data(out_file, batch.get("metadata"), model_output, eval_type)
            summary = ', '.join("%s: %.2f" % (name, value)
                                for name, value in metrics.items())
            progress.set_description(summary + " ||")

    return model.get_metrics(reset=True)
Example #18
Source File: env.py    From parasol with MIT License 6 votes vote down vote up
def rollouts(self, num_rollouts, num_steps, show_progress=False,
             noise=None,
             callback=lambda x: None,
             **kwargs):
    """Collect ``num_rollouts`` rollouts of ``num_steps`` steps each.

    Args:
        num_rollouts: number of rollouts to run.
        num_steps: steps per rollout, forwarded to ``self.rollout``.
        show_progress: iterate with a ``tqdm.trange`` progress bar.
        noise: optional zero-argument factory. When given, it is called once
            per rollout and the result is passed as ``noise=`` to
            ``self.rollout``; otherwise ``noise=None`` is passed.
        callback: called exactly once with the rollout index before each
            rollout. If it returns anything other than ``None``, that object
            is entered as a context manager for the duration of the rollout.
        **kwargs: forwarded to ``self.rollout``.

    Returns:
        ``(states, actions, costs, infos)``, where the arrays have shapes
        ``(num_rollouts, num_steps, self.get_state_dim())``,
        ``(num_rollouts, num_steps, self.get_action_dim())`` and
        ``(num_rollouts, num_steps)``, and ``infos`` is a list with one entry
        per rollout.
    """
    states, actions, costs = (
        np.empty([num_rollouts, num_steps] + [self.get_state_dim()]),
        np.empty([num_rollouts, num_steps] + [self.get_action_dim()]),
        np.empty([num_rollouts, num_steps])
    )
    infos = [None] * num_rollouts
    indices = tqdm.trange(num_rollouts, desc='Rollouts') if show_progress else range(num_rollouts)
    for i in indices:
        with contextlib.ExitStack() as stack:
            # Call the callback once and enter exactly the object it returned.
            # Calling it again would repeat its side effects and enter a
            # different object from the one that was checked for None.
            context = callback(i)
            if context is not None:
                stack.enter_context(context)
            n = noise() if noise is not None else None
            states[i], actions[i], costs[i], infos[i] = \
                self.rollout(num_steps, noise=n, **kwargs)
    return states, actions, costs, infos
Example #19
Source File: compat.py    From pipenv with MIT License 6 votes vote down vote up
def _ensure_wheel_cache(
    wheel_cache=None,  # type: Optional[Type[TWheelCache]]
    wheel_cache_provider=None,  # type: Optional[Callable]
    format_control=None,  # type: Optional[TFormatControl]
    format_control_provider=None,  # type: Optional[Type[TShimmedFunc]]
    options=None,  # type: Optional[Values]
    cache_dir=None,  # type: Optional[str]
):
    """Yield an existing wheel cache, or one built from ``wheel_cache_provider``.

    If ``wheel_cache`` is given, it is yielded unchanged. Otherwise
    ``wheel_cache_provider(cache_dir, format_control)`` is entered as a
    context manager and its value is yielded; it is exited when the
    generator resumes or is closed. ``options.cache_dir`` and
    ``options.format_control`` take precedence over the matching arguments
    when present. ``format_control_provider(None, None)`` is called only
    when no format control was supplied either way.

    Raises:
        TypeError: if neither ``wheel_cache`` nor ``wheel_cache_provider`` is
            given, or if a format control is needed but neither it nor
            ``format_control_provider`` is available.
    """
    if wheel_cache is not None:
        yield wheel_cache
        return
    if wheel_cache_provider is None:
        raise TypeError("Wheel cache or wheel cache provider needed!")
    cache_dir = getattr(options, "cache_dir", cache_dir)
    # Honour an explicit ``format_control`` argument and call the provider
    # lazily; it used to be called unconditionally, failing when it was None.
    format_control = getattr(options, "format_control", format_control)
    if format_control is None:
        if format_control_provider is None:
            raise TypeError("Format control or provider needed for wheel cache!")
        format_control = format_control_provider(None, None)  # TFormatControl
    with ExitStack() as stack:
        wheel_cache = stack.enter_context(
            wheel_cache_provider(cache_dir, format_control)
        )
        yield wheel_cache
Example #20
Source File: compat.py    From pipenv with MIT License 6 votes vote down vote up
def wheel_cache(
    cache_dir=None,  # type: str
    format_control=None,  # type: Any
    wheel_cache_provider=None,  # type: TShimmedFunc
    format_control_provider=None,  # type: Optional[TShimmedFunc]
    tempdir_manager_provider=None,  # type: TShimmedFunc
):
    """Yield a wheel cache while a temporary-directory manager is active.

    All three providers are first passed through ``resolve_possible_shim``.
    When ``format_control`` is falsy, it is built with
    ``format_control_provider(None, None)``. The context returned by
    ``tempdir_manager_provider()`` stays entered while the generator is
    suspended at its ``yield``.

    Raises:
        TypeError: if neither ``format_control`` nor a format-control
            provider is available.
    """
    tempdir_manager_provider = resolve_possible_shim(tempdir_manager_provider)
    wheel_cache_provider = resolve_possible_shim(wheel_cache_provider)
    format_control_provider = resolve_possible_shim(format_control_provider)
    if not format_control:
        if not format_control_provider:
            raise TypeError("Format control or provider needed for wheel cache!")
        format_control = format_control_provider(None, None)
    with ExitStack() as ctx:
        ctx.enter_context(tempdir_manager_provider())
        yield wheel_cache_provider(cache_dir, format_control)
Example #21
Source File: compat.py    From pipenv with MIT License 6 votes vote down vote up
def get_requirement_tracker(req_tracker_creator=None):
    # type: (Optional[Callable]) -> Generator[Optional[TReqTracker], None, None]
    """Yield a requirement tracker, or ``None`` when no creator is given.

    If ``PIP_REQ_TRACKER`` is not set, a temporary directory (prefix
    ``req-tracker``) is entered and exported as ``PIP_REQ_TRACKER`` for the
    lifetime of the tracker. That root is passed to ``req_tracker_creator``
    only when ``get_method_args`` reports ``root`` among the required
    arguments of its ``__init__``.
    """
    root = os.environ.get("PIP_REQ_TRACKER")
    if not req_tracker_creator:
        yield None
        return

    tracker_args = []
    _, required_args = get_method_args(req_tracker_creator.__init__)  # type: ignore
    with ExitStack() as ctx:
        if root is None:
            root = ctx.enter_context(TemporaryDirectory(prefix="req-tracker"))
            if root:
                root = str(root)
                # temp_environ() is entered on the stack before mutating the
                # environment; presumably it restores os.environ on exit.
                ctx.enter_context(temp_environ())
                os.environ["PIP_REQ_TRACKER"] = root
        if required_args is not None and "root" in required_args:
            tracker_args.append(root)
        with req_tracker_creator(*tracker_args) as tracker:
            yield tracker
Example #22
Source File: metrics.py    From armada with Apache License 2.0 6 votes vote down vote up
def get_context(self, *args, **kwargs):
    """ Any extra args are used as metric label values.

    Increments the attempt counter immediately. It then enters, in order,
    ``progress.track_inprogress()``, ``failure_total.count_exceptions()`` and
    ``duration.time()`` for these labels, and returns an ExitStack holding
    all three; exiting or closing it exits them in reverse order. If entering
    any of them raises, the ones already entered are exited before the
    exception propagates.

    :return: a context manager for the action which observes the desired
    metrics.
    :rtype: contextmanager
    """
    progress = self.progress.labels(*args, **kwargs)
    attempt_total = self.attempt_total.labels(*args, **kwargs)
    attempt_total.inc()
    failure_total = self.failure_total.labels(*args, **kwargs)
    duration = self.duration.labels(*args, **kwargs)

    contexts = [
        progress.track_inprogress(),
        failure_total.count_exceptions(),
        duration.time()
    ]
    # Enter inside a `with`, so a failure part-way through unwinds the
    # contexts already entered. On success, pop_all() hands them to a fresh
    # ExitStack owned by the caller.
    with ExitStack() as stack:
        for ctx in contexts:
            stack.enter_context(ctx)
        return stack.pop_all()
Example #23
Source File: evaluate.py    From swiftnet with GNU General Public License v3.0 6 votes vote down vote up
def evaluate_semseg(model, data_loader, class_info, observers=()):
    """Evaluate a segmentation model over ``data_loader``.

    The model runs in eval mode under ``torch.no_grad()``, with every
    observer entered as a context manager. Each observer is also called as
    ``observer(pred, batch, additional)`` for every batch. Predictions are
    accumulated into a ``num_classes x num_classes`` uint64 confusion matrix,
    which is summarised by ``compute_errors``. The model is put back in
    training mode afterwards, even if evaluation raises.

    Returns:
        ``(iou_acc, per_class_iou)`` as produced by ``compute_errors``.
    """
    # Materialise once: observers are iterated twice (entered, then called
    # per batch), which would silently skip the calls for a one-shot iterator.
    observers = list(observers)
    model.eval()
    try:
        with contextlib.ExitStack() as stack:
            stack.enter_context(torch.no_grad())
            for observer in observers:
                stack.enter_context(observer)
            conf_mat = np.zeros((model.num_classes, model.num_classes), dtype=np.uint64)
            for batch in tqdm(data_loader, total=len(data_loader)):
                batch['original_labels'] = batch['original_labels'].numpy().astype(np.uint32)
                # Presumably shape[1:3] is the label map's (H, W); confirm.
                logits, additional = model.do_forward(batch, batch['original_labels'].shape[1:3])
                pred = torch.argmax(logits.data, dim=1).byte().cpu().numpy().astype(np.uint32)
                for observer in observers:
                    observer(pred, batch, additional)
                cylib.collect_confusion_matrix(pred.flatten(), batch['original_labels'].flatten(), conf_mat)
            print('')
            pixel_acc, iou_acc, recall, precision, _, per_class_iou = compute_errors(conf_mat, class_info, verbose=True)
    finally:
        model.train()
    return iou_acc, per_class_iou
Example #24
Source File: test_start_view_change_ts_set.py    From indy-plenum with Apache License 2.0 6 votes vote down vote up
def create_node_and_not_start(testNodeClass,
                              node_config_helper_class,
                              tconf,
                              tdir,
                              allPluginsPath,
                              looper,
                              tdirWithPoolTxns,
                              tdirWithDomainTxns,
                              tdirWithNodeKeepInited):
    """Yield an "Alpha" test node that is created but never started.

    When the consumer resumes the generator, the node is stopped and its
    ``create_new_test_node`` context is then exited by the ExitStack. The
    ``looper`` and ``tdirWith*`` parameters are unused in the body.
    """
    with ExitStack() as cleanup:
        node_cm = create_new_test_node(testNodeClass,
                                       node_config_helper_class,
                                       "Alpha",
                                       tconf,
                                       tdir,
                                       allPluginsPath)
        alpha = cleanup.enter_context(node_cm)
        yield alpha
        alpha.stop()
Example #25
Source File: test__task_commons.py    From tf-yarn with Apache License 2.0 5 votes vote down vote up
def test_wait_for_connected_tasks():
    """wait_for_connected_tasks waits on '<task>/<message>' for every task."""
    tasks = ['task:1', 'task:2']
    message = 'tag'
    expected_calls = [mock.call(None, f'{task}/{message}') for task in tasks]

    with contextlib.ExitStack() as stack:
        event_mock = stack.enter_context(patch(f'{MODULE_TO_TEST}.event'))
        filter_mock = stack.enter_context(
            patch(f'{MODULE_TO_TEST}.matches_device_filters'))
        # Every task passes the device filter.
        filter_mock.return_value = True

        wait_for_connected_tasks(None, tasks, None, message)

        event_mock.wait.assert_has_calls(expected_calls, any_order=True)
Example #26
Source File: test_particles.py    From Particle-Cloud-Framework with Apache License 2.0 5 votes vote down vote up
def test_apply(definition, changes, test_type):
    """A particle can be created, optionally updated, and terminated.

    ``test_type[0] == "placebo"`` replays recorded responses from the data
    path ``test_type[1]`` (relative to this file). Otherwise every entry of
    ``test_type`` names a ``moto`` mock that is activated via the ExitStack.
    """
    particle_class = particle_flavor_scanner.get_particle_flavor(definition.get("flavor"))
    session = None
    with ExitStack() as stack:
        if test_type[0] == "placebo":
            session = boto3.Session()
            data_path = os.path.join(os.path.dirname(__file__), test_type[1])
            pill = placebo.attach(session, data_path=data_path)
            pill.playback()
        else:
            for mock_name in test_type:
                stack.enter_context(getattr(moto, mock_name)())

        # create
        particle = particle_class(definition, session)
        particle.set_desired_state(State.running)
        particle.apply(sync=True)
        assert particle.get_state() == State.running
        assert particle.is_state_definition_equivalent()

        # update
        if changes:
            updated_definition, _diff = pcf_util.update_dict(definition, changes)
            aws_changes = changes.get("aws_resource", {})
            # Tags are copied verbatim rather than merged; "Tags" wins over "tags".
            for tag_key in ("Tags", "tags"):
                if aws_changes.get(tag_key):
                    updated_definition["aws_resource"][tag_key] = aws_changes.get(tag_key)
                    break
            particle = particle_class(updated_definition, session)
            particle.set_desired_state(State.running)
            particle.apply(sync=True)
            assert particle.is_state_definition_equivalent()

        # terminate
        particle.set_desired_state(State.terminated)
        particle.apply(sync=True)
        assert particle.get_state() == State.terminated
Example #27
Source File: test_query_device_capacity.py    From bitmath with MIT License 5 votes vote down vote up
def nested(*contexts):
    """Emulate the removed ``contextlib.nested`` with an ExitStack.

    Enters ``contexts`` left to right and yields a tuple of their entered
    values. They are exited in reverse order when the generator resumes or is
    closed. This keeps the pitfalls that led to ``nested`` being removed
    (https://docs.python.org/2/library/contextlib.html#contextlib.nested),
    which the original author notes do not matter for ``mock.patch``.
    """
    with ExitStack() as stack:
        entered = [stack.enter_context(ctx) for ctx in contexts]
        yield tuple(entered)
Example #28
Source File: compatibility.py    From pex with Apache License 2.0 5 votes vote down vote up
def nested(*context_managers):
    """Enter every manager in order and yield the tuple of entered values.

    Presumably a stand-in for Python 2's ``contextlib.nested``. All managers
    are held on one ExitStack and exited in reverse order when the generator
    resumes or is closed.
    """
    with ExitStack() as stack:
        yield tuple(stack.enter_context(manager) for manager in context_managers)
Example #29
Source File: fixtures.py    From python-netsurv with MIT License 5 votes vote down vote up
def setUp(self):
    """Enter a ``tempdir()`` context whose value becomes ``self.site_dir``.

    ``self.fixtures`` is an ExitStack whose ``close`` is registered with
    ``addCleanup``, so everything entered on it is exited during test
    cleanup (presumably removing the temporary directory; confirm).
    """
    stack = ExitStack()
    self.fixtures = stack
    self.addCleanup(stack.close)
    self.site_dir = stack.enter_context(tempdir())
Example #30
Source File: test__task_commons.py    From tf-yarn with Apache License 2.0 5 votes vote down vote up
def test__shutdown_container():
    """Shutdown re-raises the thread's exception after stopping and waiting."""
    with contextlib.ExitStack() as stack:
        # cluster is patched only to neutralise it; its mock is not inspected.
        stack.enter_context(patch(f'{MODULE_TO_TEST}.cluster'))
        event_mock = stack.enter_context(patch(f'{MODULE_TO_TEST}.event'))
        wait_mock = stack.enter_context(
            patch(f'{MODULE_TO_TEST}.wait_for_connected_tasks'))

        run_config = mock.MagicMock(spec=tf.estimator.RunConfig)
        failed_thread = mock.MagicMock(spec=MonitoredThread)
        # The monitored thread reports an exception, so shutdown must raise.
        failed_thread.exception.return_value = Exception()

        with pytest.raises(Exception):
            _shutdown_container(None, None, run_config, failed_thread)

        event_mock.stop_event.assert_called_once()
        wait_mock.assert_called_once()