"""Tests for the Celery client hooks of ``opentracing_instrumentation``."""
import celery as celery_module
import mock
import pytest

from celery import Celery
from celery.signals import (
    before_task_publish, after_task_publish, task_postrun
)
from celery.states import SUCCESS, FAILURE
from celery.worker import state as celery_worker_state
from kombu import Connection
from opentracing.ext import tags

from opentracing_instrumentation.client_hooks import celery as celery_hooks


# True when the installed Celery has major version 3; several code paths
# below (logger location, result backend, worker shutdown) differ by version.
CELERY_3 = celery_module.__version__.split('.', 1)[0] == '3'


@pytest.fixture(autouse=True, scope='module')
def patch_celery():
    """Install the Celery hooks for this module and reset them afterwards."""
    celery_hooks.install_patches()
    try:
        yield
    finally:
        celery_hooks.reset_patches()


def assert_span(span, result, operation, span_kind):
    """Assert that ``span`` describes ``operation`` on the ``foo`` task.

    :param span: recorded span to inspect.
    :param result: task result whose ``task_id`` must match the span tag.
    :param operation: operation part of the expected span name
        (``'Celery:<operation>:foo'``).
    :param span_kind: expected value of the ``span.kind`` tag.
    """
    assert span.operation_name == 'Celery:{}:foo'.format(operation)
    assert span.tags.get(tags.SPAN_KIND) == span_kind
    assert span.tags.get(tags.COMPONENT) == 'Celery'
    assert span.tags.get('celery.task_name') == 'foo'
    assert span.tags.get('celery.task_id') == result.task_id


@mock.patch(
    'celery.worker.job.logger' if CELERY_3 else 'celery.app.trace.logger'
)
def _test_foo_task(celery, task_error, celery_logger):
    """Register and run a ``foo`` task, then check its status and logging.

    :param celery: Celery application the task is registered on.
    :param task_error: when true, the task raises ``ValueError``.
    :param celery_logger: mock injected by ``mock.patch`` in place of the
        version-specific Celery logger.
    :return: the result object returned by ``foo.delay()``.
    """

    @celery.task(name='foo')
    def foo():
        foo.called = True
        if task_error:
            raise ValueError('Task error')

    foo.called = False
    result = foo.delay()
    assert foo.called

    if task_error:
        assert result.status == FAILURE
        # The logging assertion is skipped for Celery 3.x in eager mode.
        if not (
            CELERY_3 and celery.conf.defaults[0].get('CELERY_ALWAYS_EAGER')
        ):
            celery_logger.log.assert_called_once()
    else:
        assert result.status == SUCCESS
        celery_logger.log.assert_not_called()

    return result


def _test_with_instrumented_client(celery, tracer, task_error):
    """Run the task with hooks installed and expect a client/server span pair.

    The client span must be a root span and the server span must be its
    child within the same trace.
    """
    result = _test_foo_task(celery, task_error)

    span_server, span_client = tracer.recorder.get_spans()
    assert span_client.parent_id is None
    assert span_client.context.trace_id == span_server.context.trace_id
    assert span_client.context.span_id == span_server.parent_id

    assert_span(span_client, result, 'apply_async', tags.SPAN_KIND_RPC_CLIENT)
    assert_span(span_server, result, 'run', tags.SPAN_KIND_RPC_SERVER)


@mock.patch(
    'celery.app.task.Task.apply_async', new=celery_hooks._task_apply_async
)
def _test_with_regular_client(celery, tracer, task_error):
    """Run the task with client-side hooks disabled and expect one span.

    ``Task.apply_async`` is replaced with ``celery_hooks._task_apply_async``
    (presumably the original, unpatched implementation -- confirm in the
    hooks module) and the ``before_task_publish`` handler is disconnected
    for the duration of the test, so only the server-side ``run`` span is
    expected, with no parent.
    """
    before_task_publish.disconnect(celery_hooks.before_task_publish_handler)
    try:
        result = _test_foo_task(celery, task_error)

        spans = tracer.recorder.get_spans()
        assert len(spans) == 1

        span = spans[0]
        assert span.parent_id is None
        assert_span(span, result, 'run', tags.SPAN_KIND_RPC_SERVER)
    finally:
        # Always restore the handler so later tests see the hooks again.
        before_task_publish.connect(celery_hooks.before_task_publish_handler)


TEST_METHODS = _test_with_instrumented_client, _test_with_regular_client


def is_rabbitmq_running():
    """Return True if a local RabbitMQ broker accepts a connection.

    This is evaluated at import time by the ``skipif`` marker below, so it
    must not swallow ``KeyboardInterrupt``/``SystemExit`` and must not leave
    the probe connection open.
    """
    try:
        # The context manager releases the connection on exit.
        with Connection('amqp://guest:guest@127.0.0.1:5672//') as connection:
            connection.connect()
        return True
    except Exception:
        return False


@pytest.mark.skipif(not is_rabbitmq_running(),
                    reason='RabbitMQ is not running or cannot connect')
@pytest.mark.parametrize('task_error', (False, True))
@pytest.mark.parametrize('test_method', TEST_METHODS)
def test_celery_with_rabbitmq(test_method, tracer, task_error):
    """Run each test method against a real broker with an in-process worker.

    The worker is started from an ``after_task_publish`` handler and asked
    to stop from a ``task_postrun`` handler.
    """
    celery = Celery(
        'test',
        # For Celery 3.x we have to use rpc:// to get the results
        # because with Redis we can get only PENDING for the status.
        # For Celery 4.x we need redis:// since with RPC we can
        # correctly assert status only for the first one task.
        # Feel free to suggest a better solution here.
        backend='rpc://' if CELERY_3 else 'redis://',
        # avoiding CDeprecationWarning
        changes={
            'CELERY_ACCEPT_CONTENT': ['pickle', 'json'],
        }
    )

    @after_task_publish.connect
    def run_worker(**kwargs):
        celery_worker_state.should_stop = False
        # One-shot handler: disconnect before starting the worker.
        after_task_publish.disconnect(run_worker)
        worker = celery.Worker(concurrency=1,
                               pool_cls='solo',
                               use_eventloop=False,
                               prefetch_multiplier=1,
                               quiet=True,
                               without_heartbeat=True)

        @task_postrun.connect
        def stop_worker_soon(**kwargs):
            celery_worker_state.should_stop = True
            task_postrun.disconnect(stop_worker_soon)
            if hasattr(worker.consumer, '_pending_operations'):
                # Celery 4.x

                def stop_worker():
                    # avoiding AttributeError that makes tests noisy
                    worker.consumer.connection.drain_events = mock.Mock()

                    worker.stop()

                # worker must be stopped not earlier than
                # data exchange with RabbitMQ is completed
                worker.consumer._pending_operations.insert(0, stop_worker)
            else:
                # Celery 3.x
                worker.stop()

        worker.start()

    test_method(celery, tracer, task_error)


@pytest.fixture
def celery_eager():
    """Return a Celery app configured to execute tasks eagerly."""
    celery = Celery('test')
    celery.config_from_object({
        'task_always_eager': True,  # Celery 4.x
        'CELERY_ALWAYS_EAGER': True,  # Celery 3.x
    })
    return celery


@pytest.mark.parametrize('task_error', (False, True))
@pytest.mark.parametrize('test_method', TEST_METHODS)
def test_celery_eager(test_method, celery_eager, tracer, task_error):
    """Run each test method against the eager Celery app."""
    test_method(celery_eager, tracer, task_error)


@mock.patch.object(celery_hooks, 'patcher')
def test_set_custom_patcher(default_patcher):
    """Check that ``set_patcher`` replaces the patcher used by the hooks."""
    patcher = mock.Mock()
    celery_hooks.set_patcher(patcher)

    assert celery_hooks.patcher is not default_patcher
    assert celery_hooks.patcher is patcher

    celery_hooks.install_patches()
    celery_hooks.reset_patches()

    patcher.install_patches.assert_called_once()
    patcher.reset_patches.assert_called_once()