"""Integration tests for the RPC client/server helpers of aioamqp_consumer.

Every test relies on fixtures defined outside this module
(``rpc_client_close``, ``rpc_client_factory``, ``rpc_server_close``,
``amqp_queue_name``).
"""
import asyncio

import pytest
from aioamqp import AioamqpException
from async_timeout import timeout

from aioamqp_consumer import JsonRpcMethod, RpcError, RpcMethod


@pytest.mark.asyncio
async def test_rpc(rpc_client_close, rpc_server_close, amqp_queue_name):
    """A bytes payload echoed by the method is returned by ``client.wait``."""
    test_data = b'test'

    @RpcMethod.init(amqp_queue_name)
    async def test_method(payload):
        return payload

    await rpc_server_close(test_method, amqp_queue_name)

    client = await rpc_client_close()

    test_result = await client.wait(test_method(test_data))

    assert test_result == test_data


@pytest.mark.asyncio
async def test_json_rpc_kwargs(
    rpc_client_close,
    rpc_server_close,
    amqp_queue_name,
):
    """A ``JsonRpcMethod`` can be invoked with a keyword-only argument."""
    @JsonRpcMethod.init(amqp_queue_name)
    async def test_method(*, x):
        return x ** 2

    await rpc_server_close(test_method, amqp_queue_name)

    client = await rpc_client_close()

    test_result = await client.wait(test_method(x=2))

    assert test_result == 4


@pytest.mark.asyncio
async def test_json_rpc_args(
    rpc_client_close,
    rpc_server_close,
    amqp_queue_name,
):
    """A ``JsonRpcMethod`` can be invoked with a positional argument."""
    @JsonRpcMethod.init(amqp_queue_name)
    async def test_method(x):
        return x ** 2

    await rpc_server_close(test_method, amqp_queue_name)

    client = await rpc_client_close()

    test_result = await client.wait(test_method(2))

    assert test_result == 4


@pytest.mark.asyncio
async def test_json_empty_payload(
    rpc_client_close,
    rpc_server_close,
    amqp_queue_name,
):
    """A ``JsonRpcMethod`` taking no arguments still returns its result."""
    test_data = 42

    @JsonRpcMethod.init(amqp_queue_name)
    async def test_method():
        return test_data

    await rpc_server_close(test_method, amqp_queue_name)

    client = await rpc_client_close()

    test_result = await client.wait(test_method())

    assert test_result == test_data


@pytest.mark.asyncio
async def test_rpc_none_arg(
    rpc_client_close,
    rpc_server_close,
    amqp_queue_name,
):
    """Passing ``None`` as the argument yields ``None`` inside the method."""
    @RpcMethod.init(amqp_queue_name)
    async def test_method(obj):
        assert obj is None

    await rpc_server_close(test_method, amqp_queue_name)

    client = await rpc_client_close()

    test_result = await client.wait(test_method(None))

    assert test_result is None


@pytest.mark.asyncio
async def test_rpc_no_result(
    rpc_client_close,
    rpc_server_close,
    amqp_queue_name,
):
    """A method that returns nothing produces ``None`` on the client."""
    @RpcMethod.init(amqp_queue_name)
    async def test_method():
        pass

    await rpc_server_close(test_method, amqp_queue_name)

    client = await rpc_client_close()

    test_result = await client.wait(test_method())

    assert test_result is None


@pytest.mark.asyncio
async def test_rpc_no_payload(
    rpc_client_close,
    rpc_server_close,
    amqp_queue_name,
):
    """A method called without arguments can still return bytes."""
    test_data = b'test'

    @RpcMethod.init(amqp_queue_name)
    async def test_method():
        return test_data

    await rpc_server_close(test_method, amqp_queue_name)

    client = await rpc_client_close()

    test_result = await client.wait(test_method())

    assert test_result == test_data


@pytest.mark.asyncio
async def test_rpc_empty_payload(
    rpc_client_close,
    rpc_server_close,
    amqp_queue_name,
):
    """An empty bytes payload is delivered and returned as ``b''``."""
    test_data = b''

    @RpcMethod.init(amqp_queue_name)
    async def test_method(payload):
        assert payload == test_data
        return test_data

    await rpc_server_close(test_method, amqp_queue_name)

    client = await rpc_client_close()

    test_result = await client.wait(test_method(test_data))

    assert test_result == test_data


@pytest.mark.asyncio
async def test_rpc_call(
    rpc_client_factory,
    rpc_server_close,
    amqp_queue_name,
):
    """``client.call`` returns ``None`` and leaves ``client._map`` empty.

    The method body is still expected to run: ``fut`` must be resolved
    within the timeout after the client has been closed.
    """
    fut = asyncio.Future()

    @RpcMethod.init(amqp_queue_name)
    async def test_method():
        fut.set_result(True)

    await rpc_server_close(test_method, amqp_queue_name)

    client = await rpc_client_factory()

    resp = await client.call(test_method())

    assert resp is None
    assert not client._map

    await client.close()

    async with timeout(0.1):
        assert await fut


@pytest.mark.asyncio
async def test_rpc_remote(rpc_client_close, rpc_server_close, amqp_queue_name):
    """A method built with ``RpcMethod.remote_init`` reaches the served one."""
    test_data = b'test'

    @RpcMethod.init(amqp_queue_name)
    async def remote_test_method(payload):
        return payload

    await rpc_server_close(remote_test_method, amqp_queue_name)

    client = await rpc_client_close()

    # Client-side handle created from the queue name only.
    local_test_method = RpcMethod.remote_init(amqp_queue_name)

    test_result = await client.wait(local_test_method(test_data))

    assert test_result == test_data


@pytest.mark.asyncio
async def test_rpc_timeout(
    rpc_client_close,
    rpc_server_close,
    amqp_queue_name,
):
    """``client.wait`` raises ``asyncio.TimeoutError`` when the method is slow.

    The method sleeps 0.2s while the client waits only 0.1s; the method is
    nevertheless expected to finish and resolve ``fut``.
    """
    fut = asyncio.Future()

    @RpcMethod.init(amqp_queue_name)
    async def test_method():
        await asyncio.sleep(0.2)
        fut.set_result(True)

    server = await rpc_server_close(test_method, amqp_queue_name)

    client = await rpc_client_close()

    with pytest.raises(asyncio.TimeoutError):
        await client.wait(test_method(), timeout=0.1)

    await server.join()

    async with timeout(0.2):
        assert await fut


@pytest.mark.asyncio
async def test_rpc_wait_response(
    rpc_client_close,
    rpc_server_close,
    amqp_queue_name,
):
    """With ``wait_response=False`` the returned object is awaited later.

    The first await is expected to time out (0.1s timeout vs. a 0.2s
    method); after a further pause the same object is awaited again and
    must yield the method's result.
    """
    fut = asyncio.Future()
    test_result = b'result'

    @RpcMethod.init(amqp_queue_name)
    async def test_method():
        await asyncio.sleep(0.2)
        fut.set_result(test_result)
        return b'result'

    await rpc_server_close(test_method, amqp_queue_name)

    client = await rpc_client_close()

    response = await client.wait(
        test_method(),
        timeout=0.1,
        wait_response=False,
    )

    with pytest.raises(asyncio.TimeoutError):
        await response

    await asyncio.sleep(0.15)

    async with timeout(0.2):
        assert await fut == test_result
        assert await response == test_result


@pytest.mark.asyncio
async def test_rpc_error(rpc_client_close, rpc_server_close, amqp_queue_name):
    """An exception raised in the method reaches the client as ``RpcError``."""
    @RpcMethod.init(amqp_queue_name)
    async def test_method():
        class Error(Exception):
            pass

        raise Error

    await rpc_server_close(test_method, amqp_queue_name)

    client = await rpc_client_close()

    with pytest.raises(RpcError) as exc_info:
        await client.wait(test_method())

    # Can't pickle local object
    assert isinstance(exc_info.value.err, AttributeError)


@pytest.mark.asyncio
async def test_on_error_shutdown(
    rpc_client_close,
    rpc_server_close,
    amqp_queue_name,
):
    """A call after the client's transport was closed fails, a later one works.

    The first ``wait`` must raise ``AioamqpException``; after a short pause
    a second ``wait`` must return the method's result.
    """
    test_data = b'result'

    @RpcMethod.init(amqp_queue_name)
    async def test_method():
        return test_data

    await rpc_server_close(test_method, amqp_queue_name)

    client = await rpc_client_close()

    client._transport.close()

    with pytest.raises(AioamqpException):
        await client.wait(test_method())

    # NOTE(review): presumably the client re-establishes its connection
    # during this pause -- confirm against the client implementation.
    await asyncio.sleep(0.2)

    test_result = await client.wait(test_method())

    # Previously a bare comparison whose value was discarded, so the
    # recovered call's result was never actually verified.
    assert test_result == test_data


@pytest.mark.asyncio
async def test_rpc_server_down(
    rpc_client_close,
    rpc_server_close,
    amqp_queue_name,
):
    """The call still succeeds when the server transport closes mid-request.

    The method closes the server's transport on its first invocation; the
    test expects the method to be invoked twice in total.
    """
    test_data = b'test'
    calls = 0

    @RpcMethod.init(amqp_queue_name)
    async def test_method(payload):
        nonlocal calls

        # Only the very first invocation tears the transport down.
        if not calls:
            server._transport.close()

        calls += 1

        return payload

    server = await rpc_server_close(test_method, amqp_queue_name)

    client = await rpc_client_close()

    test_result = await client.wait(test_method(test_data))

    assert test_result == test_data
    assert calls == 2


@pytest.mark.asyncio
async def test_rpc_marshal_exc(
    rpc_client_close,
    rpc_server_close,
    amqp_queue_name,
):
    """A failing ``unmarshal`` is reported as the configured ``marshal_exc``.

    ``unmarshal`` is replaced with a coroutine that raises ``ValueError``;
    the client must see ``RpcError`` wrapping ``ZeroDivisionError`` and the
    replacement must have been called exactly once.
    """
    @RpcMethod.init(amqp_queue_name)
    async def test_method():
        pass

    calls = 0

    async def _unmarshal(obj):
        nonlocal calls
        calls += 1
        raise ValueError

    test_method.packer.unmarshal = _unmarshal

    await rpc_server_close(
        test_method,
        amqp_queue_name,
        marshal_exc=ZeroDivisionError,
    )

    client = await rpc_client_close()

    with pytest.raises(RpcError) as exc_info:
        await client.wait(test_method())

    assert isinstance(exc_info.value.err, ZeroDivisionError)
    assert calls == 1