from __future__ import absolute_import

import opentracing
from opentracing.ext import tags

from ..request_context import get_current_span, span_in_context
from ._patcher import Patcher


# Celery is an optional dependency: when it cannot be imported, none of the
# names below are defined and the module still imports cleanly.
try:
    from celery.app.task import Task
    from celery.signals import (
        before_task_publish, task_prerun, task_success, task_failure
    )
except ImportError:
    pass
else:
    # Keep a reference to the unpatched method so the wrapper can delegate to
    # it and so the patch can be reverted. The presence of this name in the
    # module globals is also what CeleryPatcher.applicable checks.
    _task_apply_async = Task.apply_async


def task_apply_async_wrapper(task, args=None, kwargs=None, **other_kwargs):
    """Replacement for ``Task.apply_async`` that traces the publish call.

    A span named ``Celery:apply_async:<task name>`` is started as a child of
    the current span, tagged as an RPC client span, and made the current
    span while the original ``apply_async`` runs. The id of the resulting
    task is recorded on the span as ``celery.task_id``.

    :param task: the task instance ``apply_async`` was called on
    :param args: positional arguments for the task, forwarded unchanged
    :param kwargs: keyword arguments for the task, forwarded unchanged
    :param other_kwargs: remaining ``apply_async`` options, forwarded unchanged
    :return: whatever the original ``apply_async`` returns
    """
    name = 'Celery:apply_async:{}'.format(task.name)
    client_span = opentracing.tracer.start_span(
        operation_name=name,
        child_of=get_current_span(),
    )
    set_common_tags(client_span, task, tags.SPAN_KIND_RPC_CLIENT)

    # The span is entered as a context manager too, so it is closed when the
    # block exits, whether normally or through an exception.
    with span_in_context(client_span), client_span:
        async_result = _task_apply_async(task, args, kwargs, **other_kwargs)
        client_span.set_tag('celery.task_id', async_result.task_id)
        return async_result


def set_common_tags(span, task, span_kind):
    """Apply the tags shared by every Celery span.

    :param span: span to tag
    :param task: Celery task whose ``name`` is recorded as ``celery.task_name``
    :param span_kind: value stored under the span-kind tag
    """
    common_tags = (
        (tags.SPAN_KIND, span_kind),
        (tags.COMPONENT, 'Celery'),
        ('celery.task_name', task.name),
    )
    for tag_key, tag_value in common_tags:
        span.set_tag(tag_key, tag_value)


def before_task_publish_handler(headers, **kwargs):
    """Inject the current span context into the outgoing task headers.

    The context is serialized as a text map under the
    ``parent_span_context`` header, which ``task_prerun_handler`` reads on
    the consumer side to continue the trace.

    :param headers: mutable mapping of message headers; modified in place
    :param kwargs: remaining signal arguments, ignored
    """
    # The key is always published, even when it stays empty, so the header
    # shape is the same whether or not a trace is active.
    headers['parent_span_context'] = carrier = {}

    current_span = get_current_span()
    if current_span is None:
        # No active span (e.g. the task was published through a path that
        # does not go through the apply_async wrapper). Publishing must not
        # fail because of tracing, so leave the carrier empty; the consumer
        # treats an empty carrier as "no parent".
        return

    opentracing.tracer.inject(span_context=current_span.context,
                              format=opentracing.Format.TEXT_MAP,
                              carrier=carrier)


def task_prerun_handler(task, task_id, **kwargs):
    """Start a server-side span just before a task body runs.

    For eager tasks the parent is the current in-process span. Otherwise the
    parent context is extracted from the ``parent_span_context`` carrier
    delivered with the message, when one is present. The new span and the
    entered tracing context are stored on the task request so that
    ``finish_current_span`` can close them later.

    :param task: the task about to run
    :param task_id: id of the task, recorded as ``celery.task_id``
    :param kwargs: remaining signal arguments, ignored
    """
    request = task.request
    run_operation = 'Celery:run:{}'.format(task.name)

    parent = None
    if request.delivery_info.get('is_eager'):
        # Eager execution happens in the caller's process, so the caller's
        # span is directly available.
        parent = get_current_span()
    else:
        message_headers = getattr(request, 'headers', None)
        if message_headers is None:
            # Celery 4.x
            carrier = getattr(request, 'parent_span_context', None)
        else:
            # Celery 3.x
            carrier = message_headers.get('parent_span_context')

        if carrier:
            parent = opentracing.tracer.extract(
                opentracing.Format.TEXT_MAP, carrier
            )

    server_span = opentracing.tracer.start_span(
        operation_name=run_operation,
        child_of=parent,
    )
    task.request.span = server_span
    set_common_tags(server_span, task, tags.SPAN_KIND_RPC_SERVER)
    server_span.set_tag('celery.task_id', task_id)

    # Entered manually because the matching exit happens in a different
    # signal handler, once the task has finished.
    request.tracing_context = span_in_context(server_span)
    request.tracing_context.__enter__()


def finish_current_span(task, exc_type=None, exc_val=None, exc_tb=None):
    """Finish the span of a completed task and leave its tracing context.

    Reads ``span`` and ``tracing_context`` from ``task.request``, where
    ``task_prerun_handler`` stored them.

    :param task: the task that has just finished
    :param exc_type: exception class if the task failed, else ``None``
    :param exc_val: exception instance if the task failed, else ``None``
    :param exc_tb: traceback if the task failed, else ``None``
    :raises: whatever ``span.finish()`` or the context's ``__exit__`` raise
    """
    request = task.request
    try:
        request.span.finish()
    finally:
        # Always leave the tracing context, even if finishing the span
        # fails; otherwise the context entered before the task ran would
        # stay active after the task is done.
        request.tracing_context.__exit__(exc_type, exc_val, exc_tb)


def task_success_handler(sender, **kwargs):
    """Close the span of a task that completed successfully.

    :param sender: the task that succeeded
    :param kwargs: remaining signal arguments, ignored
    """
    finish_current_span(sender)


def task_failure_handler(sender, exception, traceback, **kwargs):
    """Close the span of a task that raised, forwarding the error details.

    :param sender: the task that failed
    :param exception: the exception instance raised by the task
    :param traceback: traceback associated with ``exception``
    :param kwargs: remaining signal arguments, ignored
    """
    exc_info = (type(exception), exception, traceback)
    finish_current_span(sender, *exc_info)


class CeleryPatcher(Patcher):
    """Installs and removes the Celery tracing instrumentation.

    Installing replaces ``Task.apply_async`` with the tracing wrapper and
    connects the publish/prerun/success/failure signal handlers; resetting
    restores the original method and disconnects the handlers.
    """

    # True only when the celery import at the top of the module succeeded,
    # since that is the only place ``_task_apply_async`` gets defined.
    applicable = '_task_apply_async' in globals()

    def _install_patches(self):
        Task.apply_async = task_apply_async_wrapper
        # The pairs are built here rather than at class level because the
        # signal names exist only when celery is importable.
        bindings = (
            (before_task_publish, before_task_publish_handler),
            (task_prerun, task_prerun_handler),
            (task_success, task_success_handler),
            (task_failure, task_failure_handler),
        )
        for signal, handler in bindings:
            signal.connect(handler)

    def _reset_patches(self):
        Task.apply_async = _task_apply_async
        bindings = (
            (before_task_publish, before_task_publish_handler),
            (task_prerun, task_prerun_handler),
            (task_success, task_success_handler),
            (task_failure, task_failure_handler),
        )
        for signal, handler in bindings:
            signal.disconnect(handler)


# Hand this module's namespace to the patcher at import time.
# NOTE(review): presumably this publishes module-level install/reset hooks
# for the patcher -- confirm against Patcher.configure_hook_module.
CeleryPatcher.configure_hook_module(globals())