# Copyright 2016 Hynek Schlawack # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import asyncio import http.client import inspect import uuid from unittest import mock import pytest from prometheus_client import Counter from prometheus_async import aio from prometheus_async.aio.sd import ConsulAgent, _LocalConsulAgentClient try: import aiohttp except ImportError: aiohttp = None async def coro(): await asyncio.sleep(0) class TestTime: @pytest.mark.asyncio async def test_still_coroutine_function(self, fo): """ It's ensured that a decorated function still passes as a coroutine function. Otherwise PYTHONASYNCIODEBUG=1 breaks. """ func = aio.time(fo)(coro) new_coro = func() assert inspect.iscoroutine(new_coro) assert inspect.iscoroutinefunction(func) await new_coro @pytest.mark.asyncio async def test_decorator_sync(self, fo, patch_timer): """ time works with sync results functions. """ @aio.time(fo) async def func(): if True: return 42 else: await asyncio.sleep(0) assert 42 == await func() assert [1] == fo._observed @pytest.mark.asyncio async def test_decorator(self, fo, patch_timer): """ time works with asyncio results functions. """ @aio.time(fo) async def func(): await asyncio.sleep(0) return 42 rv = func() assert asyncio.iscoroutine(rv) assert [] == fo._observed rv = await rv assert [1] == fo._observed assert 42 == rv @pytest.mark.asyncio async def test_decorator_exc(self, fo, patch_timer): """ Does not swallow exceptions. """ v = ValueError("foo") @aio.time(fo) async def func(): await asyncio.sleep(0) raise v with pytest.raises(ValueError) as e: await func() assert v is e.value assert [1] == fo._observed @pytest.mark.asyncio async def test_future(self, fo, patch_timer): """ time works with a asyncio.Future. """ fut = asyncio.Future() coro = aio.time(fo, fut) assert [] == fo._observed fut.set_result(42) assert 42 == await coro assert [1] == fo._observed @pytest.mark.asyncio async def test_future_exc(self, fo, patch_timer): """ Does not swallow exceptions. """ fut = asyncio.Future() coro = aio.time(fo, fut) v = ValueError("foo") assert [] == fo._observed fut.set_exception(v) with pytest.raises(ValueError) as e: await coro assert [1] == fo._observed assert v is e.value class TestCountExceptions: @pytest.mark.asyncio async def test_decorator_no_exc(self, fc): """ If no exception is raised, the counter does not change. """ @aio.count_exceptions(fc) async def func(): await asyncio.sleep(0.0) return 42 assert 42 == await func() assert 0 == fc._val @pytest.mark.asyncio async def test_decorator_wrong_exc(self, fc): """ If a wrong exception is raised, the counter does not change. """ @aio.count_exceptions(fc, exc=ValueError) async def func(): await asyncio.sleep(0.0) raise Exception() with pytest.raises(Exception): await func() assert 0 == fc._val @pytest.mark.asyncio async def test_decorator_exc(self, fc): """ If the correct exception is raised, count it. """ @aio.count_exceptions(fc, exc=ValueError) async def func(): await asyncio.sleep(0.0) raise ValueError() with pytest.raises(ValueError): await func() assert 1 == fc._val @pytest.mark.asyncio async def test_future_no_exc(self, fc): """ If no exception is raised, the counter does not change. """ fut = asyncio.Future() coro = aio.count_exceptions(fc, future=fut) fut.set_result(42) assert 42 == await coro assert 0 == fc._val @pytest.mark.asyncio async def test_future_wrong_exc(self, fc): """ If a wrong exception is raised, the counter does not change. """ fut = asyncio.Future() coro = aio.count_exceptions(fc, exc=ValueError, future=fut) exc = Exception() fut.set_exception(exc) with pytest.raises(Exception): assert 42 == await coro assert 0 == fc._val @pytest.mark.asyncio async def test_future_exc(self, fc): """ If the correct exception is raised, count it. """ fut = asyncio.Future() coro = aio.count_exceptions(fc, exc=ValueError, future=fut) exc = ValueError() fut.set_exception(exc) with pytest.raises(Exception): assert 42 == await coro assert 1 == fc._val class TestTrackInprogress: @pytest.mark.asyncio async def test_coroutine(self, fg): """ Incs and decs. """ f = aio.track_inprogress(fg)(coro) await f() assert 0 == fg._val assert 2 == fg._calls @pytest.mark.asyncio async def test_future(self, fg): """ Incs and decs. """ fut = asyncio.Future() wrapped = aio.track_inprogress(fg, fut) assert 1 == fg._val fut.set_result(42) await wrapped assert 0 == fg._val class FakeSD: """ Fake Service Discovery. """ registered_ms = None async def register(self, metrics_server): self.registered_ms = metrics_server async def deregister(): return True return deregister @pytest.mark.skipif(aiohttp is None, reason="Needs aiohttp.") class TestWeb: @pytest.mark.asyncio async def test_server_stats(self): """ Returns a response with the current stats. """ Counter("test_server_stats_total", "cnt").inc() rv = await aio.web.server_stats(None) assert ( b"# HELP test_server_stats_total cnt\n# TYPE " b"test_server_stats_total counter\n" b"test_server_stats_total 1.0\n" in rv.body ) @pytest.mark.asyncio async def test_cheap(self): """ Returns a simple string. """ rv = await aio.web._cheap(None) assert ( b'<html><body><a href="/metrics">Metrics</a></body></html>' == rv.body ) assert "text/html" == rv.content_type @pytest.mark.asyncio @pytest.mark.parametrize("sd", [None, FakeSD()]) async def test_start_http_server(self, sd): """ Integration test: server gets started, is registered, and serves stats. """ server = await aio.web.start_http_server( addr="127.0.0.1", service_discovery=sd ) assert isinstance(server, aio.web.MetricsHTTPServer) assert server.is_registered is (sd is not None) if sd is not None: assert sd.registered_ms is server addr, port = server.socket Counter("test_start_http_server_total", "cnt").inc() async with aiohttp.ClientSession() as s: rv = await s.request( "GET", "http://{addr}:{port}/metrics".format(addr=addr, port=port), ) body = await rv.text() assert ( "# HELP test_start_http_server_total cnt\n# " "TYPE test_start_http_server_total" " counter\ntest_start_http_server_total 1.0\n" in body ) await server.close() @pytest.mark.parametrize("sd", [None, FakeSD()]) def test_start_in_thread(self, sd): """ Threaded version starts and exits properly, passes on service discovery. """ Counter("test_start_http_server_in_thread_total", "cnt").inc() t = aio.web.start_http_server_in_thread( addr="127.0.0.1", service_discovery=sd ) assert isinstance(t, aio.web.ThreadedMetricsHTTPServer) assert "PrometheusAsyncWebEndpoint" == t._thread.name assert t.url.startswith("http") assert False is t.https assert t.is_registered is (sd is not None) if sd is not None: assert sd.registered_ms is t._http_server s = t.socket h = http.client.HTTPConnection(s.addr, port=s[1]) h.request("GET", "/metrics") rsp = h.getresponse() body = rsp.read().decode() rsp.close() h.close() assert "HELP test_start_http_server_in_thread_total cnt" in body t.close() assert False is t._thread.is_alive() @pytest.mark.asyncio @pytest.mark.parametrize("addr,url", [("127.0.0.1", "127.0.0.1:")]) async def test_url(self, addr, url): """ The URL of a MetricsHTTPServer is correctly computed. """ server = await aio.web.start_http_server(addr=addr) sock = server.socket part = url + str(sock.port) + "/" assert "http://" + part == server.url server.https = True assert "https://" + part == server.url await server.close() class TestNeedsAioHTTP: @pytest.mark.skipif(aiohttp is None, reason="Needs aiohttp.") def test_present(self): """ If aiohttp is present, the original object is returned. """ o = object() assert o is aio.web._needs_aiohttp(o) @pytest.mark.skipif(aiohttp is not None, reason="Needs missing aiohttp.") def test_missing(self): """ If aiohttp is missing, raise RuntimeError if called. """ with pytest.raises(RuntimeError) as e: aio.web._needs_aiohttp(coro)() assert "'coro' requires aiohttp." == str(e.value) @pytest.mark.skipif(aiohttp is None, reason="Needs aiohttp.") class TestConsulAgent: @pytest.mark.parametrize("deregister", [True, False]) @pytest.mark.asyncio async def test_integration(self, deregister): """ Integration test with a real consul agent. Start a service, register it, close it, verify it's deregistered. """ tags = ("foo", "bar") service_id = str(uuid.uuid4()) # allow for parallel tests con = _LocalConsulAgentClient(token=None) ca = ConsulAgent( name="test-metrics", service_id=service_id, tags=tags, deregister=deregister, ) try: server = await aio.web.start_http_server( addr="127.0.0.1", service_discovery=ca ) except aiohttp.ClientOSError: pytest.skip("Missing consul agent.") svc = (await con.get_services())[service_id] assert "test-metrics" == svc["Service"] assert sorted(tags) == sorted(svc["Tags"]) assert server.socket.addr == svc["Address"] assert server.socket.port == svc["Port"] await server.close() services = await con.get_services() if deregister: # Assert service is gone iff we are supposed to deregister. assert service_id not in services else: assert service_id in services # Clean up behind ourselves. resp = await con.deregister_service(service_id) assert 200 == resp.status @pytest.mark.asyncio async def test_loop_warns(self, event_loop): """ If a loop is passed, raise a DeprecationWarning. """ with pytest.deprecated_call(): server = await aio.web.start_http_server(loop=event_loop) await server.close() @pytest.mark.asyncio async def test_none_if_register_fails(self): """ If register fails, return None. """ class FakeMetricsServer: socket = mock.Mock(addr="127.0.0.1", port=12345) url = "http://127.0.0.1:12345/metrics" class FakeSession: async def __aexit__(self, exc_type, exc_value, traceback): pass async def __aenter__(self): class FakeConnection: async def put(self, *args, **kw): return mock.Mock(status=400) return FakeConnection() ca = ConsulAgent() ca.consul.session_factory = FakeSession assert None is (await ca.register(FakeMetricsServer())) @pytest.mark.skipif(aiohttp is None, reason="Needs aiohttp.") class TestLocalConsulAgentClient: def test_sets_headers(self): """ If a token is passed, "X-Consul-Token" header is set. """ con = _LocalConsulAgentClient(token="token42") assert "token42" == con.headers["X-Consul-Token"]