from future import standard_library
standard_library.install_aliases()
from future.builtins import str, object
import datetime
from bson import ObjectId
from redis.exceptions import LockError
import time
from .exceptions import RetryInterrupt, MaxRetriesInterrupt, AbortInterrupt, MaxConcurrencyInterrupt
from .utils import load_class_by_path, group_iter
import gevent
import objgraph
import random
import gc
from collections import defaultdict
import traceback
import sys
import urllib.parse
import re
import linecache
import fnmatch
import encodings
import copyreg
from . import context


FINAL_STATUSES = {"timeout", "abort", "failed", "success", "interrupt", "retry", "maxretries", "maxconcurrency"}
TRANSIENT_STATUSES = {"cancel", "queued", "started"}
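
# A job may only leave a final status by being explicitly requeued or
# cancelled. A minimal sketch of the transition guard enforced in
# Job._save_status() below (illustrative helper, not part of this module):
#
#   def transition_allowed(old_status, new_status):
#       return old_status not in FINAL_STATUSES or new_status in TRANSIENT_STATUSES
#
#   transition_allowed("success", "failed")  # False: "success" is final
#   transition_allowed("failed", "queued")   # True: requeuing is always allowed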


class Job(object):

    timeout = None
    result_ttl = None
    abort_ttl = None
    cancel_ttl = None
    max_retries = None
    retry_delay = None

    # All values above can be overridden from the TASKS config

    progress = 0

    _memory_stop = 0
    _memory_start = 0

    _current_io = None

    # Has this job been inserted in MongoDB yet?
    stored = None

    # List of statuses that don't trigger storage of this job in MongoDB on raw queues.
    # In the task config, use ("started", "success") to avoid storing successful raw tasks at all
    statuses_no_storage = None
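
    # A hypothetical TASKS config entry overriding the values above
    # (the keys shown are the ones read in get_task_config()/set_data() below):
    #
    #   TASKS = {
    #       "tasks.io.Fetch": {
    #           "timeout": 300,
    #           "result_ttl": 3600,
    #           "max_retries": 3,
    #           "retry_delay": 60,
    #           "statuses_no_storage": ("started", "success")
    #       }
    #   }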

    def __init__(self, job_id, queue=None, start=False, fetch=False):
        self.worker = context.get_current_worker()
        self.queue = queue
        self.datestarted = None

        self.collection = context.connections.mongodb_jobs.mrq_jobs

        if job_id is None:
            self.id = None
        else:
            if isinstance(job_id, bytes):
                self.id = ObjectId(job_id.decode('utf-8'))
            else:
                self.id = ObjectId(job_id)

        self.data = None
        self.saved = True

        self.task = None
        self.greenlet_switches = 0
        self.greenlet_time = 0

        self._trace_mongodb = defaultdict(int)

        if start:
            self.fetch(start=True, full_data=False)
        elif fetch:
            self.fetch(start=False, full_data=False)

    @property
    def redis_max_concurrency_key(self):
        """ Returns the global redis key used to control job concurrency """
        return "%s:c:%s" % (context.get_current_config()["redis_prefix"], self.data["path"])

    def exists(self):
        """ Returns True if a job with the current _id exists in MongoDB. """
        return bool(self.collection.find_one({"_id": self.id}, projection={"_id": 1}))

    def fetch(self, start=False, full_data=True):
        """ Get the current job data and possibly flag it as started. """

        if self.id is None:
            return self

        if full_data is True:
            fields = None
        elif isinstance(full_data, dict):
            fields = full_data
        else:
            fields = {
                "_id": 0,
                "path": 1,
                "params": 1,
                "status": 1,
                "retry_count": 1,
            }

        if start:
            self.datestarted = datetime.datetime.utcnow()
            self.set_data(self.collection.find_and_modify(
                {
                    "_id": self.id,
                    "status": {"$nin": ["cancel", "abort", "maxretries"]}
                },
                {"$set": {
                    "status": "started",
                    "datestarted": self.datestarted,
                    "worker": self.worker.id
                },
                "$unset": {
                    "dateexpires": 1 # we don't want started jobs to expire unexpectedly
                }},
                projection=fields)
            )

            context.metric("jobs.status.started")

        else:
            self.set_data(self.collection.find_one({
                "_id": self.id
            }, projection=fields))

        if self.data is None:
            context.log.info(
                "Job %s not found in MongoDB or status was cancelled!" %
                self.id)

        self.stored = True

        return self

    def get_task_config(self):
        cfg = context.get_current_config()
        return cfg.get("tasks", {}).get(
            self.data["path"]
        ) or {}

    def set_data(self, data):
        self.data = data
        if self.data is None:
            return

        if "path" in self.data:
            cfg = context.get_current_config()
            task_def = self.get_task_config()

            self.timeout = task_def.get("timeout", cfg["default_job_timeout"])
            self.default_ttl = task_def.get("default_ttl", cfg["default_job_ttl"])
            self.result_ttl = task_def.get("result_ttl", cfg["default_job_result_ttl"]) # success ttl
            self.abort_ttl = task_def.get("abort_ttl", cfg["default_job_abort_ttl"])
            self.cancel_ttl = task_def.get("cancel_ttl", cfg["default_job_cancel_ttl"])
            self.max_retries = task_def.get("max_retries", cfg["default_job_max_retries"])
            self.retry_delay = task_def.get("retry_delay", cfg["default_job_retry_delay"])

    def set_progress(self, ratio, save=False):
        self.data["progress"] = ratio
        self.saved = False

        # If not saved now, it will be persisted at the next worker report
        if save:
            self.save()

    def save(self):
        """ Persists the current job metadata to MongoDB. Will be called at each worker report. """

        if not self.saved and self.data and "progress" in self.data:
            # TODO should we save more fields?
            self.collection.update({"_id": self.id}, {"$set": {
                "progress": self.data["progress"]
            }})
            self.saved = True

    @classmethod
    def insert(cls, jobs_data, queue=None, statuses_no_storage=None, return_jobs=True, w=None, j=None):
        """ Insert a job into MongoDB """

        now = datetime.datetime.utcnow()
        for data in jobs_data:
            if data["status"] == "started":
                data["datestarted"] = now

        no_storage = (statuses_no_storage is not None) and ("started" in statuses_no_storage)
        if no_storage and return_jobs:
            for data in jobs_data:
                data["_id"] = ObjectId()  # Give the job a temporary ID
        else:
            inserted = context.connections.mongodb_jobs.mrq_jobs.insert(
                jobs_data,
                manipulate=True,
                w=w,
                j=j
            )

        if return_jobs:
            jobs = []
            for data in jobs_data:
                job = cls(data["_id"], queue=queue)
                job.set_data(data)
                job.statuses_no_storage = statuses_no_storage
                job.stored = (not no_storage)
                if data["status"] == "started":
                    job.datestarted = data["datestarted"]
                jobs.append(job)

            return jobs
        else:
            return inserted
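
    # Illustrative call, mirroring what queue_jobs() below builds for each job
    # (the task path is hypothetical):
    #
    #   Job.insert([{
    #       "path": "tasks.io.Fetch",
    #       "params": {"url": "http://example.com"},
    #       "queue": "default",
    #       "datequeued": datetime.datetime.utcnow(),
    #       "status": "queued"
    #   }])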

    def _attach_original_exception(self, exc):
        """ Often, a retry will be raised inside an "except" block.
            This Keep track of the first exception for debugging purposes """

        original_exception = sys.exc_info()
        if original_exception[0] is not None:
            exc.original_exception = original_exception

    def retry(self, queue=None, delay=None, max_retries=None):
        """ Marks the current job as needing to be retried. Interrupts it. """

        if max_retries is None:
            max_retries = self.max_retries

        if self.data.get("retry_count", 0) >= max_retries:
            raise MaxRetriesInterrupt()

        exc = RetryInterrupt()

        exc.queue = queue or self.queue or self.data.get("queue") or "default"
        exc.retry_count = self.data.get("retry_count", 0) + 1
        exc.delay = delay
        if exc.delay is None:
            exc.delay = self.retry_delay

        self._attach_original_exception(exc)

        raise exc
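
    # Typical usage sketch from inside a running task (hypothetical task body;
    # assumes the job is fetched through the current context):
    #
    #   def run(self, params):
    #       try:
    #           do_flaky_io(params)  # hypothetical helper
    #       except IOError:
    #           context.get_current_job().retry(delay=60)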

    def abort(self):
        """ Aborts the current task mid-excution. """
        exc = AbortInterrupt()
        self._attach_original_exception(exc)
        raise exc

    def cancel(self):
        """ Markes the current job as cancelled. Doesn't interrupt it. """
        self._save_status("cancel")

    def requeue(self, queue=None, retry_count=0):
        """ Requeues the current job. Doesn't interrupt it """

        if not queue:
            if not self.data or not self.data.get("queue"):
                self.fetch(full_data={"_id": 0, "queue": 1, "path": 1})
            queue = self.data["queue"]

        self._save_status("queued", updates={
            "queue": queue,
            "datequeued": datetime.datetime.utcnow(),
            "retry_count": retry_count
        })

    def perform(self):
        """ Loads and starts the main task for this job, the saves the result. """

        if self.data is None:
            return

        context.log.debug("Starting %s(%s)" % (self.data["path"], self.data["params"]))
        task_class = load_class_by_path(self.data["path"])

        self.task = task_class()

        self.task.is_main_task = True

        if not self.task.max_concurrency:

            result = self.task.run_wrapped(self.data["params"])

        else:

            if self.task.max_concurrency > 1:
                raise NotImplementedError()

            lock = None
            try:

                # TODO: implement a semaphore
                lock = context.connections.redis.lock(self.redis_max_concurrency_key, timeout=self.timeout + 5)
                if not lock.acquire(blocking=True, blocking_timeout=0):
                    raise MaxConcurrencyInterrupt()

                result = self.task.run_wrapped(self.data["params"])

            finally:
                try:
                    if lock:
                        lock.release()
                except LockError:
                    pass

        self.save_success(result)

        if context.get_current_config().get("trace_greenlets"):

            # TODO: this is not the exact greenlet_time measurement because it doesn't
            # take into account the last switch's time. This is why we force a last switch.
            # This does cause a performance overhead. Instead, we should print the
            # last timing directly from the trace() function in context?

            # pylint: disable=protected-access

            gevent.sleep(0)
            current_greenlet = gevent.getcurrent()
            t = (datetime.datetime.utcnow() - self.datestarted).total_seconds()

            context.log.debug(
                "Job %s success: %0.6fs total, %0.6fs in greenlet, %s switches" %
                (self.id,
                 t,
                 current_greenlet._trace_time,
                 current_greenlet._trace_switches - 1)
            )

        else:
            context.log.debug("Job %s success: %0.6fs total" % (
                self.id, (datetime.datetime.utcnow() -
                          self.datestarted).total_seconds()
            ))

        return result
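
    # The concurrency limit above is read off the task class itself. A minimal
    # sketch (hypothetical task) of opting into the Redis lock:
    #
    #   class Fetch(Task):
    #       max_concurrency = 1  # the only value currently supported, see above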

    def wait(self, poll_interval=1, timeout=None, full_data=False):
        """ Wait for this job to finish. """

        end_time = None
        if timeout:
            end_time = time.time() + timeout

        while end_time is None or time.time() < end_time:

            job_data = self.collection.find_one({
                "_id": ObjectId(self.id),
                "status": {"$nin": ["started", "queued"]}
            }, projection=({
                "_id": 0,
                "result": 1,
                "status": 1
            } if not full_data else None))

            if job_data:
                return job_data

            time.sleep(poll_interval)
        raise Exception("Waited for job result for %s seconds, timeout." % timeout)

    def kill(self, block=False, reason="unknown"):
        """ Forcefully kill all greenlets associated with this job """

        current_greenletid = id(gevent.getcurrent())

        trace = "Job killed: %s" % reason
        # Iterate over a copy: we mutate the registry while looping over it
        for greenlet, job in list(context._GLOBAL_CONTEXT["greenlets"].values()):
            greenletid = id(greenlet)
            if job and job.id == self.id and greenletid != current_greenletid:
                greenlet.kill(block=block)
                trace += "\n\n--- Greenlet %s ---\n" % greenletid
                trace += "".join(traceback.format_stack(greenlet.gr_frame))
                context._GLOBAL_CONTEXT["greenlets"].pop(greenletid, None)

        if reason == "timeout" and self.data["status"] != "timeout":
            updates = {
                "exceptiontype": "TimeoutInterrupt",
                "traceback": trace
            }
            self._save_status("timeout", updates=updates, exception=False)

    def save_retry(self, retry_exc):

        # If delay=0, requeue right away, don't go through the "retry" status
        if retry_exc.delay == 0:
            self.requeue(queue=retry_exc.queue, retry_count=retry_exc.retry_count)

        else:

            dateretry = datetime.datetime.utcnow() + datetime.timedelta(seconds=retry_exc.delay)
            updates = {
                "dateretry": dateretry,
                "queue": retry_exc.queue,
                "retry_count": retry_exc.retry_count
            }

            self._save_status("retry", updates, exception=True)

    def _save_traceback_history(self, status, trace, job_exc):
        """ Create traceback history or add a new traceback to history. """
        failure_date = datetime.datetime.utcnow()

        new_history = {
            "date": failure_date,
            "status": status,
            "exceptiontype": job_exc.__name__
        }

        traces = trace.split("---- Original exception: -----")
        if len(traces) > 1:
            new_history["original_traceback"] = traces[1]
        worker = context.get_current_worker()
        if worker:
            new_history["worker"] = worker.id
        new_history["traceback"] = traces[0]
        self.collection.update({
            "_id": self.id
        }, {"$push": {"traceback_history": new_history}})

    def save_success(self, result=None):

        dateexpires = datetime.datetime.utcnow() + datetime.timedelta(seconds=self.result_ttl)
        updates = {
            "dateexpires": dateexpires
        }
        if result is not None:
            updates["result"] = result
        if "progress" in self.data:
            updates["progress"] = 1

        self._save_status("success", updates)

    def save_cancel(self):

        dateexpires = datetime.datetime.utcnow() + datetime.timedelta(seconds=self.cancel_ttl)
        updates = {
            "dateexpires": dateexpires
        }

        self._save_status("cancel", updates)

    def save_abort(self):
        dateexpires = datetime.datetime.utcnow() + datetime.timedelta(seconds=self.abort_ttl)
        updates = {
            "dateexpires": dateexpires
        }

        self._save_status("abort", updates, exception=True)

    def _save_status(self, status, updates=None, exception=False, w=None, j=None):

        if self.id is None:
            return

        # Forbid some status transitions
        if self.data and self.data.get("status") in FINAL_STATUSES and status not in TRANSIENT_STATUSES:
            context.log.error("Can't go from status %s to %s" % (self.data["status"], status))
            return

        context.metric("jobs.status.%s" % status)

        if self.stored is False and self.statuses_no_storage is not None and status in self.statuses_no_storage:
            return

        now = datetime.datetime.utcnow()
        db_updates = {
            "status": status,
            "dateupdated": now
        }

        # we don't want started jobs to expire unexpectedly
        if status not in ["started", "success", "abort", "cancel"] and hasattr(self, "default_ttl") and self.default_ttl is not None:
            db_updates["dateexpires"] = (self.data.get("datequeued") or now) + datetime.timedelta(seconds=self.default_ttl)

        db_updates.update(updates or {})

        if self.datestarted:
            db_updates["totaltime"] = (now - self.datestarted).total_seconds()

        if context.get_current_config().get("trace_greenlets"):
            current_greenlet = gevent.getcurrent()

            # TODO are we sure the current job is doing the save_status() on itself?
            if hasattr(current_greenlet, "_trace_time"):
                # pylint: disable=protected-access
                db_updates["time"] = current_greenlet._trace_time
                db_updates["switches"] = current_greenlet._trace_switches

        if exception:
            trace = traceback.format_exc()
            context.log.error(trace)
            exc, value = sys.exc_info()[0:2]
            if hasattr(value, "subpool_traceback"):
                trace = "Exception first caught in a subpool. Traceback:\n%s\n%s" % (value.subpool_traceback, trace)
            db_updates["traceback"] = trace
            db_updates["exceptiontype"] = exc.__name__

        if self.data:
            # Get all data before updating it: the queuesize bookkeeping below
            # compares the old state with the new one.
            current_queue = db_updates.get("queue") or self.data["queue"]
            old_queue = self.data.get("queue")
            old_status = self.data.get("status")
            raw_queue = self.data.get("raw_queue")
            retry_count = self.data.get("retry_count", 0)

            self.data.update(db_updates)

        # In the most common case, we allow an optimization on Mongo writes
        if status == "success":
            if w is None:
                w = getattr(self.task, "status_success_update_w", None)
            if j is None:
                j = getattr(self.task, "status_success_update_j", None)

        # This job wasn't inserted because "started" is in statuses_no_storage
        # So we must insert it for the first time instead of updating it.
        if self.stored is False:
            db_updates["queue"] = self.data["queue"]
            db_updates["params"] = self.data["params"]
            db_updates["path"] = self.data["path"]
            self.collection.insert(db_updates, w=w, j=j, manipulate=True)
            self.id = db_updates["_id"]  # Persistent ID generated at insert time
            self.stored = True

        else:
            self.collection.update({
                "_id": self.id
            }, {"$set": db_updates}, w=w, j=j, manipulate=False)

        if exception:
            self._save_traceback_history(status, trace, exc)

        if self.data:
            with context.connections.redis.pipeline(transaction=False) as pipe:
                if status != "started":
                    # Queue change
                    if current_queue != old_queue:
                        pipe.decr("queuesize:%s" % old_queue)
                        if status == "queued":
                            pipe.incr("queuesize:%s" % current_queue)

                    # Regular queues
                    elif status == "queued" and old_status != "started":
                        pipe.incr("queuesize:%s" % current_queue)

                    elif status != "queued" and not raw_queue:
                        pipe.decr("queuesize:%s" % current_queue)

                    # Raw queues retries
                    elif db_updates.get("retry_count", 0) > retry_count:
                        pipe.incr("queuesize:%s" % current_queue)

                    pipe.expire("queuesize:%s" % current_queue, context.get_current_config().get("queue_ttl"))
                pipe.execute()
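
    # Redis bookkeeping sketch for the block above: "queuesize:<queue>"
    # counters mirror the number of pending jobs. For example, requeuing a job
    # from "default" to "crawl" decrements "queuesize:default" and increments
    # "queuesize:crawl", then refreshes the key's TTL.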

    def set_current_io(self, io_data):

        # pylint: disable=protected-access

        if io_data is None:
            if not self._current_io:
                return

            t = time.time() - self._current_io["started"]
            if self.worker and self.data.get("path"):
                self.worker._traced_io["types"][self._current_io["type"]] += t
                self.worker._traced_io["tasks"][self.data["path"]] += t
                self.worker._traced_io["total"] += t

            self._current_io = None

        else:
            io_data["started"] = time.time()
            self._current_io = io_data

    def trace_memory_clean_caches(self):
        """ Avoid polluting results with some builtin python caches """

        urllib.parse.clear_cache()
        re.purge()
        linecache.clearcache()
        copyreg.clear_extension_cache()

        if hasattr(fnmatch, "purge"):
            fnmatch.purge()  # pylint: disable=no-member
        elif hasattr(fnmatch, "_purge"):
            fnmatch._purge()  # pylint: disable=no-member

        if hasattr(encodings, "_cache") and len(encodings._cache) > 0:
            encodings._cache = {}

        for handler in context.log.handlers:
            handler.flush()

    def trace_memory_start(self):
        """ Starts measuring memory consumption """

        self.trace_memory_clean_caches()

        objgraph.show_growth(limit=30)

        gc.collect()
        self._memory_start = self.worker.get_memory()["total"]

    def trace_memory_stop(self):
        """ Stops measuring memory consumption """

        self.trace_memory_clean_caches()

        objgraph.show_growth(limit=30)

        trace_type = context.get_current_config()["trace_memory_type"]
        if trace_type:

            filename = '%s/%s-%s.png' % (
                context.get_current_config()["trace_memory_output_dir"],
                trace_type,
                self.id)

            chain = objgraph.find_backref_chain(
                random.choice(
                    objgraph.by_type(trace_type)
                ),
                objgraph.is_proper_module
            )
            objgraph.show_chain(chain, filename=filename)
            del filename
            del chain

        gc.collect()
        self._memory_stop = self.worker.get_memory()["total"]

        diff = self._memory_stop - self._memory_start

        context.log.debug("Memory diff for job %s : %s" % (self.id, diff))

        # We update this separately, after the results: they need to be out of
        # memory already so they don't pollute the measured diff.
        self.collection.update(
            {"_id": self.id},
            {"$set": {
                "memory_diff": diff
            }},
            w=1
        )


def get_job_result(job_id):
    job = Job(job_id)
    job.fetch(full_data={"result": 1, "status": 1, "_id": 0})
    return job.data
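
# Usage sketch (hypothetical job id):
#
#   data = get_job_result("56a9f3e8e3f8a31d1c8b4567")
#   if data["status"] == "success":
#       print(data["result"])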


def queue_raw_jobs(queue, params_list, **kwargs):
    """ Queue some jobs on a raw queue """

    from .queue import Queue
    queue_obj = Queue(queue)
    queue_obj.enqueue_raw_jobs(params_list, **kwargs)
    # No need to store queue size as we already have a fast way to get raw queue size

def queue_job(main_task_path, params, **kwargs):
    """ Queue one job on a regular queue """

    return queue_jobs(main_task_path, [params], **kwargs)[0]

def set_queues_size(size_by_queues, action="incr"):
    if len(size_by_queues) > 0:
        with context.connections.redis.pipeline(transaction=False) as pipe:
            action_func = getattr(pipe, action)
            for queue in size_by_queues:
                action_func("queuesize:%s" % queue, amount=size_by_queues[queue])
                pipe.expire("queuesize:%s" % queue, context.get_current_config().get("queue_ttl"))
            pipe.execute()
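
# Usage sketch: set_queues_size({"default": 3}) increments "queuesize:default"
# by 3 and refreshes its TTL; action="decr" decrements instead.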

def queue_jobs(main_task_path, params_list, queue=None, batch_size=1000):
    """ Queue multiple jobs on a regular queue """
    if len(params_list) == 0:
        return []
    if queue is None:
        task_def = context.get_current_config().get("tasks", {}).get(main_task_path) or {}
        queue = task_def.get("queue", "default")

    from .queue import Queue
    queue_obj = Queue(queue)

    if queue_obj.is_raw:
        raise Exception("Can't queue regular jobs on a raw queue")

    all_ids = []

    for params_group in group_iter(params_list, n=batch_size):

        context.metric("jobs.status.queued", len(params_group))

        # Insert the jobs in MongoDB
        job_ids = Job.insert([{
            "path": main_task_path,
            "params": params,
            "queue": queue,
            "datequeued": datetime.datetime.utcnow(),
            "status": "queued"
        } for params in params_group], w=1, return_jobs=False)

        all_ids += job_ids

    queue_obj.notify(len(all_ids))
    set_queues_size({queue: len(all_ids)})

    return all_ids
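
# Batch usage sketch (hypothetical task path): each params dict becomes one job
# document, inserted in groups of batch_size:
#
#   job_ids = queue_jobs(
#       "tasks.io.Fetch",
#       [{"url": url} for url in urls],
#       queue="crawl",
#       batch_size=1000
#   )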