#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
The joblib spark backend implementation.
"""
import warnings
from multiprocessing.pool import ThreadPool
import uuid
from distutils.version import LooseVersion

from joblib.parallel \
    import AutoBatchingMixin, ParallelBackendBase, register_parallel_backend, SequentialBackend
from joblib._parallel_backends import SafeFunction

import pyspark
from pyspark.sql import SparkSession
from pyspark import cloudpickle
from pyspark.util import VersionUtils


def register():
    """
    Register joblib spark backend.
    """
    try:
        import sklearn  # pylint: disable=C0415
        if LooseVersion(sklearn.__version__) < LooseVersion('0.21'):
            warnings.warn("Your sklearn version is < 0.21, but joblib-spark only support "
                          "sklearn >=0.21 . You can upgrade sklearn to version >= 0.21 to "
                          "make sklearn use spark backend.")
    except ImportError:
        pass
    register_parallel_backend('spark', SparkDistributedBackend)

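# Example usage (a minimal sketch, not part of this module): after `register()` has
# been called, joblib's `parallel_backend('spark')` context manager routes
# scikit-learn's internal joblib calls to this backend. The estimator, dataset and
# n_jobs value below are illustrative assumptions.
#
#   from joblib import parallel_backend
#   from sklearn import datasets, svm
#   from sklearn.model_selection import cross_val_score
#
#   register()
#   iris = datasets.load_iris()
#   clf = svm.SVC(kernel='linear', C=1)
#   with parallel_backend('spark', n_jobs=3):
#       scores = cross_val_score(clf, iris.data, iris.target, cv=5)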

class SparkDistributedBackend(ParallelBackendBase, AutoBatchingMixin):
    """A ParallelBackend which will execute all batches on spark.

    This backend will launch one spark job for task batch. Multiple spark job will run parallelly.
    The maximum parallelism won't exceed `sparkContext._jsc.sc().maxNumConcurrentTasks()`

    Each task batch will be run inside one spark task on worker node, and will be executed
    by `SequentialBackend`
    """

    def __init__(self, **backend_args):
        super(SparkDistributedBackend, self).__init__(**backend_args)
        self._pool = None
        self._n_jobs = None
        self._spark = SparkSession \
            .builder \
            .appName("JoblibSparkBackend") \
            .getOrCreate()
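        # Generate a unique job group ID per backend instance so that
        # `_cancel_all_jobs` cancels only the Spark jobs submitted by this backend.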
        self._job_group = "joblib-spark-job-group-" + str(uuid.uuid4())

    def _cancel_all_jobs(self):
        if VersionUtils.majorMinorVersion(pyspark.__version__)[0] < 3:
            # Note: There is a bug in `sparkContext.cancelJobGroup`.
            # See https://issues.apache.org/jira/browse/SPARK-31549
            warnings.warn("For Spark versions < 3, the pyspark job cancellation API has bugs, "
                          "so running Spark jobs cannot be terminated correctly. "
                          "See https://issues.apache.org/jira/browse/SPARK-31549 for reference.")
        else:
            self._spark.sparkContext.cancelJobGroup(self._job_group)

    def effective_n_jobs(self, n_jobs):
        max_num_concurrent_tasks = self._get_max_num_concurrent_tasks()
        if n_jobs is None:
            n_jobs = 1
        elif n_jobs == -1:
            # n_jobs=-1 means requesting all available workers
            n_jobs = max_num_concurrent_tasks
        if n_jobs > max_num_concurrent_tasks:
            warnings.warn("User-specified n_jobs ({n}) is greater than the max number of "
                          "concurrent tasks ({c}) this cluster can run now. If dynamic "
                          "allocation is enabled for the cluster, you might see more "
                          "executors allocated."
                          .format(n=n_jobs, c=max_num_concurrent_tasks))
        return n_jobs

    def _get_max_num_concurrent_tasks(self):
        # maxNumConcurrentTasks() is a package private API
        # pylint: disable=W0212
        return self._spark.sparkContext._jsc.sc().maxNumConcurrentTasks()

    def abort_everything(self, ensure_ready=True):
        self._cancel_all_jobs()
        if ensure_ready:
            self.configure(n_jobs=self.parallel.n_jobs, parallel=self.parallel,
                           **self.parallel._backend_args)  # pylint: disable=W0212

    def terminate(self):
        self._cancel_all_jobs()

    def configure(self, n_jobs=1, parallel=None, prefer=None, require=None, **backend_args):
        n_jobs = self.effective_n_jobs(n_jobs)
        self._n_jobs = n_jobs
        self.parallel = parallel  # pylint: disable=attribute-defined-outside-init
        return n_jobs

    def _get_pool(self):
        """Lazily initialize the thread pool
        The actual pool of worker threads is only initialized at the first
        call to apply_async.
        """
        if self._pool is None:
            self._pool = ThreadPool(self._n_jobs)
        return self._pool

    def apply_async(self, func, callback=None):
        # Note: the `func` argument is a batch of tasks here (a joblib BatchedCalls object).
        # See joblib.parallel.Parallel._dispatch
        def run_on_worker_and_fetch_result():
            # TODO: handle possible spark exception here. # pylint: disable=fixme
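            # Run the batch in a single-partition, single-task Spark job on a worker,
            # and serialize its result with cloudpickle so that arbitrary Python
            # objects can be shipped back to the driver.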
            rdd = self._spark.sparkContext.parallelize([0], 1) \
                .map(lambda _: cloudpickle.dumps(func()))
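            # `collectWithJobGroup` is only available in pyspark >= 3.0; it tags the
            # job with this backend's job group so that `_cancel_all_jobs` can cancel it.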
            if VersionUtils.majorMinorVersion(pyspark.__version__)[0] < 3:
                ser_res = rdd.collect()[0]
            else:
                ser_res = rdd.collectWithJobGroup(self._job_group, "joblib spark jobs")[0]
            return cloudpickle.loads(ser_res)

        return self._get_pool().apply_async(
            SafeFunction(run_on_worker_and_fetch_result),
            callback=callback
        )

    def get_nested_backend(self):
        """Backend instance to be used by nested Parallel calls.
           For nested backend, always use `SequentialBackend`
        """
        nesting_level = getattr(self, 'nesting_level', 0) + 1
        return SequentialBackend(nesting_level=nesting_level), None
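
# Nested-parallelism sketch (illustrative, not part of this module): when a task
# batch itself calls joblib.Parallel on the worker, `get_nested_backend` makes the
# inner call run sequentially inside the same Spark task instead of spawning
# further Spark jobs. The function and parameters below are assumptions.
#
#   from joblib import Parallel, delayed, parallel_backend
#
#   def inner(x):
#       # The inner Parallel call is executed by SequentialBackend on the worker.
#       return sum(Parallel(n_jobs=2)(delayed(pow)(x, i) for i in range(3)))
#
#   with parallel_backend('spark', n_jobs=2):
#       results = Parallel()(delayed(inner)(x) for x in range(4))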