Python unittest.mock.DEFAULT Examples

The following are 30 code examples of unittest.mock.DEFAULT, the sentinel object that tells patch.multiple() to create a mock for an attribute and lets a side_effect (function or iterable) fall back to the mock's normal return value. You can vote up the examples you like or vote down the ones you don't, and follow the links above each example to the original project or source file. You may also want to check out all available functions/classes of the module unittest.mock, or try the search function.
Example #1
Source File: test_event_listener.py    From clearly with MIT License 6 votes vote down vote up
def test_listener_process_event_worker(listener):
    """Route a 'worker-*' event exclusively to _set_worker_event.

    The patched _set_worker_event yields the object followed by its state.
    Exactly one message is expected on the workers queue and none on the
    tasks queue, built through obj_to_message with WorkerMessage.
    """
    setters = ('_set_task_event', '_set_worker_event', '_set_custom_event')
    with mock.patch.multiple(listener, **dict.fromkeys(setters, DEFAULT)) as mtw:
        with mock.patch('clearly.server.event_listener.obj_to_message') as otm:
            worker_setter = mtw['_set_worker_event']
            worker_setter.return_value = (item for item in ('obj', 'ok'))

            # noinspection PyProtectedMember
            listener._process_event(dict(type='worker-anything'))

    mtw['_set_worker_event'].assert_called_once_with(dict(type='worker-anything'))
    mtw['_set_task_event'].assert_not_called()
    mtw['_set_custom_event'].assert_not_called()
    assert listener.queue_workers.qsize() == 1
    assert listener.queue_tasks.qsize() == 0
    expected_messages = [call('obj', WorkerMessage, state='ok')]
    assert otm.call_args_list == expected_messages
Example #2
Source File: env_test.py    From pypyr-cli with Apache License 2.0 6 votes vote down vote up
def test_env_only_calls_get():
    """Run the env step with only a 'get' section configured.

    Even so, the assertions below expect run_step to invoke each of
    env_get, env_set and env_unset exactly once.
    """
    get_section = {
        'key2': 'ARB_GET_ME1',
        'key4': 'ARB_GET_ME2'
    }
    context = Context({
        'key1': 'value1',
        'key2': 'value2',
        'key3': 'value3',
        'env': {'get': get_section},
    })

    helpers = ('env_get', 'env_set', 'env_unset')
    with patch.multiple('pypyr.steps.env',
                        **dict.fromkeys(helpers, DEFAULT)) as mock_env:
        pypyr.steps.env.run_step(context)

    for helper in helpers:
        mock_env[helper].assert_called_once()
Example #3
Source File: env_test.py    From pypyr-cli with Apache License 2.0 6 votes vote down vote up
def test_env_only_calls_set():
    """Run the env step with only a 'set' section configured.

    Even so, the assertions below expect run_step to invoke each of
    env_get, env_set and env_unset exactly once.
    """
    set_section = {
        'ARB_SET_ME1': 'key2',
        'ARB_SET_ME2': 'key1'
    }
    context = Context({
        'key1': 'value1',
        'key2': 'value2',
        'key3': 'value3',
        'env': {'set': set_section},
    })

    helpers = ('env_get', 'env_set', 'env_unset')
    with patch.multiple('pypyr.steps.env',
                        **dict.fromkeys(helpers, DEFAULT)) as mock_env:
        pypyr.steps.env.run_step(context)

    for helper in helpers:
        mock_env[helper].assert_called_once()
Example #4
Source File: tar_test.py    From pypyr-cli with Apache License 2.0 6 votes vote down vote up
def test_tar_only_calls_extract():
    """With only an 'extract' list configured, tar_archive must stay unused."""
    extract_specs = [
        {'in': 'key2', 'out': 'ARB_GET_ME1'},
        {'in': 'key4', 'out': 'ARB_GET_ME2'},
    ]
    context = Context({
        'key1': 'value1',
        'key2': 'value2',
        'key3': 'value3',
        'tar': {'extract': extract_specs},
    })

    with patch.multiple('pypyr.steps.tar',
                        tar_archive=DEFAULT,
                        tar_extract=DEFAULT) as mock_tar:
        pypyr.steps.tar.run_step(context)

    extract_mock = mock_tar['tar_extract']
    archive_mock = mock_tar['tar_archive']
    extract_mock.assert_called_once()
    archive_mock.assert_not_called()
Example #5
Source File: tar_test.py    From pypyr-cli with Apache License 2.0 6 votes vote down vote up
def test_tar_only_calls_archive():
    """With only an 'archive' list configured, tar_extract must stay unused."""
    archive_specs = [
        {'in': 'key2', 'out': 'ARB_GET_ME1'},
        {'in': 'key4', 'out': 'ARB_GET_ME2'},
    ]
    context = Context({
        'key1': 'value1',
        'key2': 'value2',
        'key3': 'value3',
        'tar': {'archive': archive_specs},
    })

    with patch.multiple('pypyr.steps.tar',
                        tar_archive=DEFAULT,
                        tar_extract=DEFAULT) as mock_tar:
        pypyr.steps.tar.run_step(context)

    extract_mock = mock_tar['tar_extract']
    archive_mock = mock_tar['tar_archive']
    extract_mock.assert_not_called()
    archive_mock.assert_called_once()
Example #6
Source File: test_luks.py    From os-brick with Apache License 2.0 6 votes vote down vote up
def test_attach_volume_fail(self, mock_execute):
        """attach_volume surfaces a luksOpen failure; isLuks runs afterwards.

        mock_execute raises ProcessExecutionError on its first call (luksOpen)
        and returns its normal value on the second (mock.DEFAULT, isLuks).
        """
        fake_key = 'ea6c2e1b8f7f4f84ae3560116d659ba2'
        self.encryptor._get_key = mock.MagicMock(
            return_value=test_cryptsetup.fake__get_key(None, fake_key))

        luks_open_failure = putils.ProcessExecutionError(exit_code=1)
        mock_execute.side_effect = [
            luks_open_failure,  # luksOpen
            mock.DEFAULT,       # isLuks
        ]

        self.assertRaises(putils.ProcessExecutionError,
                          self.encryptor.attach_volume, None)

        # Keyword arguments shared by both cryptsetup invocations.
        shared_kwargs = dict(root_helper=self.root_helper,
                             run_as_root=True, check_exit_code=True)
        expected_calls = [
            mock.call('cryptsetup', 'luksOpen', '--key-file=-', self.dev_path,
                      self.dev_name, process_input=fake_key, **shared_kwargs),
            mock.call('cryptsetup', 'isLuks', '--verbose', self.dev_path,
                      **shared_kwargs),
        ]
        mock_execute.assert_has_calls(expected_calls, any_order=False)
Example #7
Source File: test_allowed_address_pairs.py    From octavia with Apache License 2.0 6 votes vote down vote up
def test_set_port_admin_state_up(self):
        """Cover the success, NotFound and unexpected-error paths.

        ``update_port`` is given three successive outcomes: ``mock.DEFAULT``
        (return the mock's normal value), neutron's NotFound, and a generic
        exception. Every call passes the same admin-state value, because the
        driver builds the ``{'port': {...}}`` request body itself.
        """
        PORT_ID = uuidutils.generate_uuid()
        TEST_STATE = 'test state'

        self.driver.neutron_client.update_port.side_effect = [
            mock.DEFAULT, neutron_exceptions.NotFound, Exception('boom')]

        # Test successful state set
        self.driver.set_port_admin_state_up(PORT_ID, TEST_STATE)

        self.driver.neutron_client.update_port.assert_called_once_with(
            PORT_ID, {'port': {'admin_state_up': TEST_STATE}})

        # Test port NotFound. The second argument is the state value itself
        # (wrapped by the driver, as asserted above), not a request body.
        self.assertRaises(network_base.PortNotFound,
                          self.driver.set_port_admin_state_up,
                          PORT_ID, TEST_STATE)

        # Test unknown exception
        self.assertRaises(exceptions.NetworkServiceError,
                          self.driver.set_port_admin_state_up,
                          PORT_ID, TEST_STATE)
Example #8
Source File: test_allowed_address_pairs.py    From octavia with Apache License 2.0 6 votes vote down vote up
def test_delete_port(self):
        """delete_port succeeds, tolerates NotFound and wraps other errors."""
        port_id = uuidutils.generate_uuid()
        delete_mock = self.driver.neutron_client.delete_port
        delete_mock.side_effect = [
            mock.DEFAULT,                  # 1st call: normal return value
            neutron_exceptions.NotFound,   # 2nd call: port not found
            Exception('boom'),             # 3rd call: unexpected failure
        ]

        # A successful delete issues exactly one client call.
        self.driver.delete_port(port_id)
        delete_mock.assert_called_once_with(port_id)

        # NotFound must not propagate out of the driver.
        self.driver.delete_port(port_id)

        # Any other exception is expected to surface as NetworkServiceError.
        self.assertRaises(exceptions.NetworkServiceError,
                          self.driver.delete_port, port_id)
Example #9
Source File: test_vhdutils.py    From os-win with Apache License 2.0 6 votes vote down vote up
def _setup_lib_mocks(self):
        """Replace vhdutils' native-library handles with mocks.

        The ctypes helpers become lambdas returning ``(value, tag)`` tuples,
        so tests can assert on values passed by reference.
        """
        self._vdisk_struct = mock.Mock()

        fake_ctypes = mock.Mock()
        # Tag each wrapped value with the helper name that produced it.
        fake_ctypes.byref = lambda x: (x, "byref")
        fake_ctypes.c_wchar_p = lambda x: (x, "c_wchar_p")
        fake_ctypes.c_ulong = lambda x: (x, "c_ulong")
        self._ctypes = fake_ctypes

        self._ctypes_patcher = mock.patch.object(
            vhdutils, 'ctypes', self._ctypes)
        self._ctypes_patcher.start()

        # NOTE(review): this patcher is not kept on self, so only
        # mock.patch.stopall() can undo it; confirm the test base class
        # registers that cleanup.
        lib_patcher = mock.patch.multiple(
            vhdutils,
            kernel32=mock.DEFAULT,
            wintypes=mock.DEFAULT,
            virtdisk=mock.DEFAULT,
            vdisk_struct=self._vdisk_struct,
            create=True)
        lib_patcher.start()
Example #10
Source File: mock_.py    From asynq with Apache License 2.0 6 votes vote down vote up
def _patch_object(
    target,
    attribute,
    new=mock.DEFAULT,
    spec=None,
    create=False,
    mocksignature=False,
    spec_set=None,
    autospec=False,
    new_callable=None,
    **kwargs
):
    """Build an async-aware patcher for *attribute* on the object *target*.

    A zero-argument getter returning *target* is forwarded to
    ``_make_patch_async`` along with every other argument, in positional
    order. Extra keyword arguments are passed through as a single dict.
    """

    def _get_target():
        return target

    return _make_patch_async(
        _get_target,
        attribute,
        new,
        spec,
        create,
        mocksignature,
        spec_set,
        autospec,
        new_callable,
        kwargs,
    )
Example #11
Source File: test_pathutils.py    From os-win with Apache License 2.0 6 votes vote down vote up
def _setup_lib_mocks(self):
        """Swap pathutils' ctypes, wintypes and kernel32 for mocks.

        BOOL, c_wchar_p and pointer become lambdas returning ``(value, tag)``
        tuples, so converted arguments are easy to assert on.
        """
        fake_wintypes = mock.Mock()
        fake_wintypes.BOOL = lambda x: (x, 'BOOL')

        fake_ctypes = mock.Mock()
        fake_ctypes.c_wchar_p = lambda x: (x, "c_wchar_p")
        fake_ctypes.pointer = lambda x: (x, 'pointer')

        self._ctypes = fake_ctypes
        self._wintypes = fake_wintypes

        self._ctypes_patcher = mock.patch.object(
            pathutils, 'ctypes', new=self._ctypes)
        self._ctypes_patcher.start()

        # NOTE(review): this patcher is not stored, so only
        # mock.patch.stopall() can undo it; confirm a cleanup is registered.
        lib_patcher = mock.patch.multiple(
            pathutils,
            wintypes=self._wintypes,
            kernel32=mock.DEFAULT,
            create=True)
        lib_patcher.start()
Example #12
Source File: env_test.py    From pypyr-cli with Apache License 2.0 6 votes vote down vote up
def test_env_only_calls_unset():
    """Run the env step with only an 'unset' list configured.

    Even so, the assertions below expect run_step to invoke each of
    env_get, env_set and env_unset exactly once.
    """
    unset_section = [
        'ARB_DELETE_ME1',
        'ARB_DELETE_ME2'
    ]
    context = Context({
        'key1': 'value1',
        'key2': 'value2',
        'key3': 'value3',
        'env': {'unset': unset_section},
    })

    helpers = ('env_get', 'env_set', 'env_unset')
    with patch.multiple('pypyr.steps.env',
                        **dict.fromkeys(helpers, DEFAULT)) as mock_env:
        pypyr.steps.env.run_step(context)

    for helper in helpers:
        mock_env[helper].assert_called_once()
Example #13
Source File: test_checker.py    From awslimitchecker with GNU Affero General Public License v3.0 6 votes vote down vote up
def test_check_version_not_old(self):
        """Only the region debug line is logged when no newer version exists."""
        targets = ('logger', '_get_version_info', 'TrustedAdvisor',
                   '_get_latest_version')
        with patch.multiple('awslimitchecker.checker', autospec=True,
                            **dict.fromkeys(targets, DEFAULT)) as mocks:
            mocks['_get_version_info'].return_value = self.mock_ver_info
            mocks['_get_latest_version'].return_value = None
            AwsLimitChecker()

        assert mocks['_get_latest_version'].mock_calls == [call()]
        region_debug = call.debug('Connecting to region %s', None)
        assert mocks['logger'].mock_calls == [region_debug]
Example #14
Source File: test_p11_crypto.py    From barbican with Apache License 2.0 6 votes vote down vote up
def test_call_pkcs11_with_token_error(self):
        """encrypt() calls _encrypt again after a P11CryptoTokenException.

        The first _encrypt call raises and the second returns a payload. The
        assertions expect two _encrypt calls, two key-handle lookups and one
        session fetch.
        """
        token_error = ex.P11CryptoTokenException('Testing error handling')
        self.plugin._encrypt = mock.Mock(
            side_effect=[token_error, 'test payload'])

        self.plugin._reinitialize_pkcs11 = mock.Mock()
        # Assigning mock.DEFAULT keeps the mock's ordinary (child-mock)
        # return value.
        self.plugin._reinitialize_pkcs11.return_value = mock.DEFAULT

        self.plugin.encrypt(mock.MagicMock(), mock.MagicMock(),
                            mock.MagicMock())

        expected_counts = (
            (2, self.pkcs11.get_key_handle),
            (1, self.pkcs11.get_session),
            (0, self.pkcs11.return_session),
            (2, self.plugin._encrypt),
        )
        for expected, tracked in expected_counts:
            self.assertEqual(expected, tracked.call_count)
Example #15
Source File: test_route53.py    From awslimitchecker with GNU Affero General Public License v3.0 6 votes vote down vote up
def test_update_limits_from_api(self):
        """_update_limits_from_api delegates to each Route53 limit finder once."""
        finders = ('_find_limit_hosted_zone',)
        mock_conn = Mock()

        with patch('%s.connect' % pb) as mock_connect:
            with patch.multiple(pb, **dict.fromkeys(finders, DEFAULT)) as mocks:
                service = _Route53Service(21, 43, {}, None)
                service.conn = mock_conn
                service._update_limits_from_api()

        assert mock_connect.mock_calls == [call()]
        assert mock_conn.mock_calls == []
        for finder in finders:
            assert mocks[finder].mock_calls == [call()]
Example #16
Source File: test_rds.py    From awslimitchecker with GNU Affero General Public License v3.0 6 votes vote down vote up
def test_find_usage(self):
        """find_usage delegates to each RDS usage finder exactly once."""
        finders = (
            '_find_usage_instances',
            '_find_usage_subnet_groups',
            '_find_usage_security_groups',
            '_update_limits_from_api',
        )
        mock_conn = Mock()

        with patch('%s.connect' % self.pb) as mock_connect:
            with patch.multiple(self.pb,
                                **dict.fromkeys(finders, DEFAULT)) as mocks:
                service = _RDSService(21, 43, {}, None)
                service.conn = mock_conn
                assert service._have_usage is False
                service.find_usage()

        assert mock_connect.mock_calls == [call()]
        assert service._have_usage is True
        for finder in finders:
            assert mocks[finder].mock_calls == [call()]
Example #17
Source File: test_sr_amqp.py    From sarracenia with GNU General Public License v2.0 6 votes vote down vote up
def test_connect__multiple_amqp_init_errors(self, sleep, chan, conn):
        """connect() gets through three construction failures, then succeeds.

        The patched connection class raises on its first three calls and
        returns normally (DEFAULT) on the fourth. sleep is expected to be
        called with 2, 4 and 8 between attempts.
        """
        # Prepare test: three failing constructions, then a normal return.
        init_failures = [
            RecoverableConnectionError('connection already closed'),
            ValueError("Must supply authentication or userid/password"),
            ValueError("Invalid login method", 'login_method'),
        ]
        conn.side_effect = init_failures + [DEFAULT]

        # Execute test
        ok = self.hc.connect()

        # Evaluate results
        hc = self.hc
        init_call = call(hc.host, userid=hc.user, password=hc.password,
                         virtual_host=hc.vhost, ssl=hc.ssl)
        expected_conn_calls = [init_call] * 4 + [call().connect(),
                                                 call().channel()]
        self.assertEqual(expected_conn_calls, conn.mock_calls,
                         self.amqp_connection_assert_msg)
        self.assertEqual([call(2), call(4), call(8)], sleep.mock_calls)
        self.assertTrue(ok)
        self.assertIsNotNone(hc.connection)
        self.assertIsNotNone(hc.channel)
        self.assertEqual(1, len(hc.toclose))
        self.assertErrorInLog()
Example #18
Source File: test_apigateway.py    From awslimitchecker with GNU Affero General Public License v3.0 6 votes vote down vote up
def test_find_usage(self):
        """find_usage calls each API Gateway usage finder once with the service."""
        finders = (
            '_find_usage_apis',
            '_find_usage_api_keys',
            '_find_usage_certs',
            '_find_usage_plans',
            '_find_usage_vpc_links',
        )
        mock_conn = Mock()

        with patch('%s.connect' % pb) as mock_connect:
            with patch.multiple(pb, autospec=True,
                                **dict.fromkeys(finders, DEFAULT)) as mocks:
                service = _ApigatewayService(21, 43, {}, None)
                service.conn = mock_conn
                assert service._have_usage is False
                service.find_usage()

        assert mock_connect.mock_calls == [call()]
        assert service._have_usage is True
        assert mock_conn.mock_calls == []
        # autospec=True records the bound instance as the first argument.
        for finder in finders:
            assert mocks[finder].mock_calls == [call(service)]
Example #19
Source File: test_elasticbeanstalk.py    From awslimitchecker with GNU Affero General Public License v3.0 6 votes vote down vote up
def test_find_usage(self):
        """find_usage delegates to each Elastic Beanstalk usage finder once."""
        finders = (
            '_find_usage_applications',
            '_find_usage_application_versions',
            '_find_usage_environments',
        )
        mock_conn = Mock()

        with patch('%s.connect' % pb) as mock_connect:
            with patch.multiple(pb, **dict.fromkeys(finders, DEFAULT)) as mocks:
                service = _ElasticBeanstalkService(21, 43, {}, None)
                service.conn = mock_conn
                assert service._have_usage is False
                service.find_usage()

        assert mock_connect.mock_calls == [call()]
        assert service._have_usage is True
        assert mock_conn.mock_calls == []
        for finder in finders:
            assert mocks[finder].mock_calls == [call()]
Example #20
Source File: test_sr_amqp.py    From sarracenia with GNU General Public License v2.0 6 votes vote down vote up
def test_connect__multiple_amqp_connect_errors(self, sleep, chan, conn):
        """connect() gets through five connect() failures before succeeding.

        The connection mock returns itself when instantiated. Its connect()
        raises five different errors and then returns normally (DEFAULT).
        sleep is expected to be called with 2, 4, 8, 16 and 32.
        """
        # Prepare test
        conn.return_value = conn
        connect_failures = [
            AMQPError(self.AMQPError_msg),
            SSLError('SSLError stub'),
            IOError('IOError stub'),
            OSError('OSError stub'),
            Exception(self.Exception_msg),
        ]
        conn.connect.side_effect = connect_failures + [DEFAULT]

        # Execute test
        ok = self.hc.connect()

        # Evaluate results
        hc = self.hc
        one_attempt = [
            call(hc.host, userid=hc.user, password=hc.password,
                 virtual_host=hc.vhost, ssl=hc.ssl),
            call.connect(),
        ]
        self.assertEqual(one_attempt * 6 + [call.channel()], conn.mock_calls,
                         self.amqp_connection_assert_msg)
        expected_sleeps = [call(2), call(4), call(8), call(16), call(32)]
        self.assertEqual(expected_sleeps, sleep.mock_calls, self.sleep_assert_msg)
        self.assertTrue(ok)
        self.assertIsNotNone(hc.connection)
        self.assertIsNotNone(hc.channel)
        self.assertEqual(1, len(hc.toclose))
        self.assertErrorInLog()
Example #21
Source File: test_redshift.py    From awslimitchecker with GNU Affero General Public License v3.0 6 votes vote down vote up
def test_find_usage(self):
        """find_usage delegates to each Redshift usage finder exactly once."""
        finders = (
            '_find_cluster_manual_snapshots',
            '_find_cluster_subnet_groups',
        )
        mock_conn = Mock()

        with patch('%s.connect' % pb) as mock_connect:
            with patch.multiple(pb, **dict.fromkeys(finders, DEFAULT)) as mocks:
                service = _RedshiftService(21, 43, {}, None)
                service.conn = mock_conn
                assert service._have_usage is False
                service.find_usage()

        assert mock_connect.mock_calls == [call()]
        assert service._have_usage is True
        assert mock_conn.mock_calls == []
        for finder in finders:
            assert mocks[finder].mock_calls == [call()]
Example #22
Source File: test_sr_amqp.py    From sarracenia with GNU General Public License v2.0 6 votes vote down vote up
def test_build__bind_Exception(self, hc, chan):
        """build() issues queue_bind again after the first call raises."""
        # Prepare test
        base_name = self.test_build.__name__
        xname = self.xname_fmt.format(base_name)
        xkey = self.xkey_fmt.format(base_name)
        self.q.bindings.append((xname, xkey))
        chan.queue_bind.side_effect = [Exception(self.Exception_msg), DEFAULT]
        hc.new_channel.return_value = chan
        hc.user = '{}_user'.format(base_name)
        hc.host = '{}_host'.format(base_name)
        self.q.hc = hc
        self.q.declare = Mock(return_value=self.msg_count)

        # Execute test
        self.q.build()

        # Evaluate results
        self.assertEqual([call.new_channel()], hc.mock_calls, self.hc_assert_msg)
        bind_call = call.queue_bind(self.q.name, xname, xkey)
        # Not expected here: call.queue_bind(self.q.name, xname, self.pulse_key)
        self.assertEqual([bind_call, bind_call], chan.mock_calls,
                         self.amqp_channel_assert_msg)
        self.assertErrorInLog()
Example #23
Source File: test_event_listener.py    From clearly with MIT License 6 votes vote down vote up
def test_listener_process_event_task(listener):
    """Route a 'task-*' event exclusively to _set_task_event.

    The patched _set_task_event yields the object followed by three states.
    Each state is expected to become one message on the tasks queue.
    """
    setters = ('_set_task_event', '_set_worker_event', '_set_custom_event')
    with mock.patch.multiple(listener, **dict.fromkeys(setters, DEFAULT)) as mtw:
        with mock.patch('clearly.server.event_listener.obj_to_message') as otm:
            mtw['_set_task_event'].return_value = (
                item for item in chain(('obj',), 'abc'))

            # noinspection PyProtectedMember
            listener._process_event(dict(type='task-anything'))

    mtw['_set_task_event'].assert_called_once_with(dict(type='task-anything'))
    mtw['_set_worker_event'].assert_not_called()
    mtw['_set_custom_event'].assert_not_called()
    assert listener.queue_tasks.qsize() == 3
    assert listener.queue_workers.qsize() == 0
    assert otm.call_args_list == [
        call('obj', TaskMessage, state=state) for state in 'abc'
    ]
Example #24
Source File: test_checker.py    From awslimitchecker with GNU Affero General Public License v3.0 6 votes vote down vote up
def test_check_version_old(self):
        """An upgrade warning precedes the region debug line for a newer release."""
        targets = ('logger', '_get_version_info', 'TrustedAdvisor',
                   '_get_latest_version')
        with patch.multiple('awslimitchecker.checker', autospec=True,
                            **dict.fromkeys(targets, DEFAULT)) as mocks:
            mocks['_get_version_info'].return_value = self.mock_ver_info
            mocks['_get_latest_version'].return_value = '3.4.5'
            AwsLimitChecker()

        assert mocks['_get_latest_version'].mock_calls == [call()]
        upgrade_warning = call.warning(
            'You are running awslimitchecker %s, but the latest version'
            ' is %s; please consider upgrading.', '1.2.3', '3.4.5'
        )
        region_debug = call.debug('Connecting to region %s', None)
        assert mocks['logger'].mock_calls == [upgrade_warning, region_debug]
Example #25
Source File: test_local.py    From magnum with Apache License 2.0 5 votes vote down vote up
def _get_cert_with_fail(self, cert_id, failed='crt'):
        """Assert that get_cert raises CertificateStorageException on a read error.

        ``open`` is patched so that opening ``/tmp/<cert_id>.<failed>`` raises
        IOError. Every other path falls through to the mock_open file handle,
        because the side_effect returns ``mock.DEFAULT``.

        :param cert_id: id of the certificate to look up.
        :param failed: extension of the file whose open() should fail.
        """
        # The path is a single pre-formatted string; the original wrapped it in
        # a one-argument os.path.join, which returns it unchanged.
        failing_path = '/tmp/{0}.{1}'.format(cert_id, failed)

        def fake_open(path, mode='r', *args, **kwargs):
            # Accept open()'s full signature, so a call without an explicit
            # mode (or with extra keyword arguments) doesn't raise TypeError.
            if path == failing_path:
                raise IOError()
            return mock.DEFAULT

        file_mock = mock.mock_open()
        file_mock.side_effect = fake_open
        # Attempt to retrieve the cert
        with mock.patch('six.moves.builtins.open', file_mock, create=True):
            self.assertRaises(
                exception.CertificateStorageException,
                local_cert_manager.CertManager.get_cert,
                cert_id
            )
Example #26
Source File: testsentinel.py    From ironpython3 with Apache License 2.0 5 votes vote down vote up
def testDEFAULT(self):
        self.assertIs(DEFAULT, sentinel.DEFAULT) 
Example #27
Source File: volume_test.py    From imagemounter with MIT License 5 votes vote down vote up
def test_luks_key_communication(self, check_call, check_output):
        """Mount a LUKS volume whose passphrase is supplied through stdin.

        check_call and check_output are stubbed for the cryptsetup/losetup
        commands. Popen of ``cryptsetup -r luksOpen`` is replaced by a real
        child Python process that reads one line from stdin and prints it.
        """
        loop_device = "/dev/loop0"

        def fake_check_call(cmd, *args, **kwargs):
            if cmd[0:2] == ['cryptsetup', 'isLuks']:
                return True
            if cmd[0:1] == ['losetup']:
                return loop_device
            return mock.DEFAULT

        def fake_check_output(cmd, *args, **kwargs):
            return loop_device if cmd[0:1] == ['losetup'] else mock.DEFAULT

        check_call.side_effect = fake_check_call
        check_output.side_effect = fake_check_output

        # Capture the real Popen before it is patched below.
        real_popen = subprocess.Popen

        def fake_popen(cmd, *args, **kwargs):
            if cmd[0:3] != ['cryptsetup', '-r', 'luksOpen']:
                return mock.DEFAULT
            # Stand-in for a command that requests user input.
            echo_stdin = [sys.executable, "-c", "print(input(''))"]
            return real_popen(echo_stdin, *args, **kwargs)

        with mock.patch("subprocess.Popen", side_effect=fake_popen):
            disk = Disk(ImageParser(keys={'1': 'p:passphrase'}), "...")
            disk.is_mounted = True
            volume = Volume(disk=disk, fstype='luks', index='1', parent=disk)
            volume.mount()

            self.assertTrue(volume.is_mounted)
            self.assertEqual(len(volume.volumes), 1)
            self.assertEqual(volume.volumes[0].info['fsdescription'], "LUKS Volume")
Example #28
Source File: testsentinel.py    From Imogen with MIT License 5 votes vote down vote up
def testDEFAULT(self):
        self.assertIs(DEFAULT, sentinel.DEFAULT) 
Example #29
Source File: test_activities.py    From heaviside with Apache License 2.0 5 votes vote down vote up
def test_run(self, mCreateSession, mSample):
        """ActivityMixin.run fetches one task and passes it to handle_task.

        The handler sets ``activity.polling`` to False on its first call,
        presumably ending the polling loop after a single task.
        """
        mSample.return_value = 'XXX'
        session = MockSession()
        mCreateSession.return_value = (session, '123456')

        client = session.client('stepfunctions')
        client.list_activities.return_value = {
            'activities': [{'name': 'name', 'activityArn': 'XXX'}]
        }
        client.get_activity_task.return_value = {
            'taskToken': 'YYY',
            'input': '{}',
        }

        handler = mock.MagicMock()
        activity = ActivityMixin(handle_task=handler)

        def stop_after_first_task(*args, **kwargs):
            activity.polling = False
            # Fall back to the mock's normal return value.
            return mock.DEFAULT

        handler.side_effect = stop_after_first_task

        activity.run('name')

        expected_client_calls = [
            mock.call.list_activities(),
            mock.call.get_activity_task(activityArn='XXX',
                                        workerName='name-XXX'),
        ]
        self.assertEqual(client.mock_calls, expected_client_calls)

        expected_handler_calls = [
            mock.call('YYY', {}),
            mock.call().start(),
        ]
        self.assertEqual(handler.mock_calls, expected_handler_calls)
Example #30
Source File: test_projects.py    From ray with Apache License 2.0 5 votes vote down vote up
def run_test_project(project_dir, command, args):
    """Invoke a CLI *command* from inside a test project directory.

    The cluster helpers in ``ray.projects.scripts`` are patched out for the
    duration of the call.

    Returns:
        tuple: ``(result, mock_calls, test_dir)``. *result* is the
        CliRunner result, and *mock_calls* maps each patched name to its mock.
    """
    test_dir = os.path.join(TEST_DIR, project_dir)
    patched_helpers = dict.fromkeys(
        ("create_or_update_cluster", "rsync", "exec_cluster"), DEFAULT)

    with _chdir_and_back(test_dir):
        runner = CliRunner()
        with patch.multiple("ray.projects.scripts",
                            **patched_helpers) as mock_calls:
            result = runner.invoke(command, args)

    return result, mock_calls, test_dir