import importlib
import logging
import logging.config
import uuid
import os
import tempfile
import time

import cloudpickle
import json
import skein
import tensorflow as tf

from typing import (
    Dict,
    Optional,
    Tuple,
    NamedTuple,
    Callable,
    Union,
    List
)
from contextlib import suppress, contextmanager
from functools import partial

from threading import Thread
from datetime import timedelta
from skein.exceptions import SkeinError
from skein.model import FinalStatus, ApplicationReport, ACLs

import cluster_pack

from tf_yarn import (
    _env,
    _internal,
    cluster,
    constants,
    metrics,
    evaluator_metrics,
    mlflow,
    tensorboard,
    event,
    experiment,
    topologies
)

YARN_LOG_TRIES = 15

ExperimentFn = Callable[[], experiment.Experiment]

TASK_SPEC_NONE = topologies.single_server_topology()

logger = logging.getLogger(__name__)

here = os.path.dirname(__file__)


class SkeinCluster(NamedTuple):
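    """Handle on a submitted Skein application, its tasks and the event-listener state."""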
    client: skein.Client
    app: skein.ApplicationClient
    tasks: List[Tuple[str, int]]
    event_listener: Thread
    events: Dict[str, Dict[str, str]]


class ContainerLogStatus(NamedTuple):
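    """Per-task YARN container log URLs and final statuses, keyed by task name."""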
    log_urls: Dict[str, str] = dict()
    container_status: Dict[str, str] = dict()

    def by_container_id(self) -> Dict[str, Tuple[str, str]]:
        containers: Dict[str, Tuple[str, str]] = {}
        if len(self.log_urls) != len(self.container_status):
            logger.warning("log_urls and container_status dicts do not have the same length")
            return containers

        for (task, url), status in zip(self.log_urls.items(),
                                       self.container_status.values()):
            container_id = self._get_container_id(url)
            containers[container_id] = (task, status)

        return containers

    def _get_container_id(self, url: str) -> str:
        if not url:
            return ""
        url_components = url.split("/")
        if len(url_components) > 1:
            return url_components[-2]

        return ""


class RunFailed(Exception):
    """``run_on_yarn`` failed."""


def _setup_pyenvs(
        pyenv_zip_path: Union[str, Dict[topologies.NodeLabel, str]]
) -> Dict[topologies.NodeLabel, _env.PythonEnvDescription]:
    if isinstance(pyenv_zip_path, str):
        pyenv = _env.gen_pyenv_from_existing_archive(pyenv_zip_path)
        pyenvs = {
            topologies.NodeLabel.CPU: pyenv,
            topologies.NodeLabel.GPU: pyenv
        }
    else:
        pyenvs = {label: _env.gen_pyenv_from_existing_archive(env_zip_path)
                  for label, env_zip_path in pyenv_zip_path.items()}
    return pyenvs


def _setup_task_env(
        tempdir: str,
        files: Dict[str, str] = None,
        env: Dict[str, str] = {},
        n_try: int = 0
):
    task_files = _maybe_zip_task_files(files or {}, tempdir)
    task_files[__package__] = cluster_pack.zip_path(here, False, tempdir)

    # Work on a copy so that neither the caller's dict nor the shared mutable
    # default is modified (otherwise retries would prepend LIBHDFS_OPTS again).
    env = dict(env)
    # -Xms64m/-Xmx512m bound the libhdfs JVM heap inside the container.
    _add_to_env(env, "LIBHDFS_OPTS", "-Xms64m -Xmx512m")

    env["TF_YARN_N_TRY"] = str(n_try)

    task_env = {
        **env,
        # Make Python modules/packages passed via ``files`` importable.
        "PYTHONPATH": ".:" + env.get("PYTHONPATH", ""),
        "PEX_ROOT": os.path.join("/tmp", str(uuid.uuid4()))
    }

    if mlflow.use_mlflow:
        task_env["MLFLOW_RUN_ID"] = mlflow.active_run_id()
        task_env["MLFLOW_TRACKING_URI"] = mlflow.get_tracking_uri()
        task_env["GIT_PYTHON_REFRESH"] = "quiet"

    return task_files, task_env


def _add_to_env(env: Dict[str, str], env_name: str, opts: str):
    if env_name in env:
        env[env_name] = f"{opts} {env.get(env_name)}"
    else:
        env[env_name] = f"{opts}"


def _maybe_zip_task_files(files, tempdir):
    task_files = {}
    for target, source in files.items():
        assert target not in task_files
        if os.path.isdir(source):
            source = cluster_pack.zip_path(source, False, tempdir)

        task_files[target] = source
    return task_files


def _setup_cluster_spec(
    task_instances: List[Tuple[str, int]],
    app: skein.ApplicationClient
) -> tf.train.ClusterSpec:
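    # ``evaluator`` and ``tensorboard`` tasks never join the TensorFlow server
    # cluster, so they are excluded from the ClusterSpec shared via the kv store.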
    tasks_not_in_cluster = ['evaluator', 'tensorboard']
    cluster_instances = [t for t in task_instances if t[0] not in tasks_not_in_cluster]
    app.kv[constants.KV_CLUSTER_INSTANCES] = json.dumps(cluster_instances).encode()
    return tf.train.ClusterSpec(
        cluster.aggregate_spec(app, list(_internal.iter_tasks(cluster_instances)))
    )


def _setup_skein_cluster(
        pyenvs: Dict[topologies.NodeLabel, _env.PythonEnvDescription],
        task_specs: Dict[str, topologies.TaskSpec] = TASK_SPEC_NONE,
        *,
        custom_task_module: Optional[str] = None,
        skein_client: skein.Client = None,
        files: Dict[str, str] = None,
        env: Dict[str, str] = {},
        queue: str = "default",
        acls: ACLs = None,
        file_systems: List[str] = None,
        name: str = "RunOnYarn",
        n_try: int = 0,
        pre_script_hook: Optional[str] = None
) -> SkeinCluster:
    os.environ["JAVA_TOOL_OPTIONS"] = \
        "-XX:ParallelGCThreads=1 -XX:CICompilerCount=2 "\
        f"{os.environ.get('JAVA_TOOL_OPTIONS', '')}"

    pre_script_hook = pre_script_hook if pre_script_hook else ""
    with tempfile.TemporaryDirectory() as tempdir:
        task_files, task_env = _setup_task_env(tempdir, files, env, n_try)
        services = {}
        for task_type, task_spec in list(task_specs.items()):
            pyenv = pyenvs[task_spec.label]
            service_env = task_env.copy()
            if task_spec.tb_termination_timeout_seconds >= 0:
                service_env["TB_TERMINATION_TIMEOUT_SECONDS"] = \
                    str(task_spec.tb_termination_timeout_seconds)
            if task_spec.tb_model_dir:
                service_env["TB_MODEL_DIR"] = str(task_spec.tb_model_dir)
            if task_spec.tb_extra_args:
                service_env["TB_EXTRA_ARGS"] = str(task_spec.tb_extra_args)

            services[task_type] = skein.Service(
                script=f'''
                            set -x
                            {pre_script_hook}
                            {_env.gen_task_cmd(
                                pyenv,
                                task_type,
                                custom_task_module)}
                        ''',
                resources=skein.model.Resources(task_spec.memory, task_spec.vcores),
                max_restarts=0,
                instances=task_spec.instances,
                node_label=task_spec.label.value,
                files={
                    **task_files,
                    pyenv.dest_path: pyenv.path_to_archive
                },
                env=service_env)

        # on the cluster we don't ask again for delegation tokens
        if "HADOOP_TOKEN_FILE_LOCATION" in os.environ:
            file_systems = None

        spec = skein.ApplicationSpec(
            services,
            queue=queue,
            acls=acls,
            file_systems=file_systems,
            name=name
        )

        if skein_client is None:
            skein_client = skein.Client()

        task_instances = [(task_type, task_spec.instances)
                          for task_type, task_spec in task_specs.items()]
        events: Dict[str, Dict[str, str]] = \
            {task: {} for task in _internal.iter_tasks(task_instances)}
        app = skein_client.submit_and_connect(spec)

        # Start a thread which collects all events posted by all tasks in kv store
        event_listener = Thread(target=_aggregate_events, args=(app.kv, events))
        event_listener.start()

        return SkeinCluster(skein_client, app, task_instances, event_listener, events)


def _hook_name_already_exists(
        hook: tf.estimator.SessionRunHook,
        hooks: List[tf.estimator.SessionRunHook]) -> bool:
    hook_name = type(hook).__name__
    return any(type(h).__name__ == hook_name for h in hooks)


def _add_monitor_to_experiment(exp: experiment.Experiment) -> experiment.Experiment:
    logger.info(f"configured training hooks: {exp.train_spec.hooks}")

    training_hooks = list(exp.train_spec.hooks)

    if exp.config.log_step_count_steps is not None:
        steps_per_second_hook = metrics.StepPerSecondHook(
            every_n_steps=exp.config.log_step_count_steps
        )
        if not _hook_name_already_exists(steps_per_second_hook, training_hooks):
            training_hooks.append(steps_per_second_hook)
        else:
            logger.warning("not adding StepPerSecondHook as one is already configured")

    monitored_train_spec = exp.train_spec._replace(
        hooks=training_hooks
    )

    monitored_eval_spec = exp.eval_spec._replace(
        hooks=(evaluator_metrics.EvalMonitorHook(), *exp.eval_spec.hooks)
    )

    return exp._replace(eval_spec=monitored_eval_spec, train_spec=monitored_train_spec)


def _run_on_cluster(
    experiment_fn: ExperimentFn,
    skein_cluster: SkeinCluster,
    eval_monitor_log_thresholds: Dict[str, Tuple[float, float]] = None,
    n_try: int = 0
) -> Optional[metrics.Metrics]:
    def _new_experiment_fn():
        return _add_monitor_to_experiment(experiment_fn())

    # Attempt serialization early to avoid allocating unnecessary resources
    serialized_fn = cloudpickle.dumps(_new_experiment_fn)
    with skein_cluster.client:
        return _execute_and_await_termination(
            skein_cluster,
            serialized_fn,
            eval_monitor_log_thresholds,
            n_try=n_try
        )


def _default_acls_all_access() -> skein.model.ACLs:
    return skein.model.ACLs(
        enable=True,
        ui_users=['*'],
        view_users=['*']
    )


def run_on_yarn(
    pyenv_zip_path: Union[str, Dict[topologies.NodeLabel, str]],
    experiment_fn: ExperimentFn,
    task_specs: Dict[str, topologies.TaskSpec] = TASK_SPEC_NONE,
    *,
    skein_client: skein.Client = None,
    files: Dict[str, str] = None,
    env: Dict[str, str] = {},
    queue: str = "default",
    acls: ACLs = _default_acls_all_access(),
    file_systems: List[str] = None,
    eval_monitor_log_thresholds: Dict[str, Tuple[float, float]] = None,
    nb_retries: int = 0,
    custom_task_module: Optional[str] = None,
    name: str = "RunOnYarn",
    pre_script_hook: Optional[str] = None
) -> Optional[metrics.Metrics]:
    """Run an experiment on YARN.

    The implementation allocates a service with the requested number
    of instances for each distributed TensorFlow task type. Each
    instance runs ``_dispatch_task`` which roughly does the following.

    1. Reserve a TCP port and communicate the resulting socket address
        (host/port pair) to other instances using the "init" barrier.
    2. Spawn ``train_and_evaluate`` in a separate thread.
    3. Synchronize the "ps" tasks on the "stop" barrier.
        The barrier compensates for the fact that "ps" tasks never
        terminate, and therefore should be killed once all other
        tasks are finished.

    Parameters
    ----------

    pyenv_zip_path
        Path to an archive of the Python environment to be deployed.
        It can be a zipped conda env or a pex archive.
        In the case of a mixed GPU/CPU cluster, provide a dictionary with
        both environments, keyed by node label.

    experiment_fn
        A function constructing the estimator alongside the train
        and eval specs.

    skein_client
        Skein client used to submit yarn jobs

    task_specs
        Resources to allocate for each task type. The keys
        must be a subset of ``"chief"``, ``"worker"``, ``"ps"``,
        ``"evaluator"`` and ``"tensorboard"``. The minimal spec must
        contain at least ``"chief"``.

    files
        Local files or directories to upload to the container.
        The keys are the target locations of the resources relative
        to the container root, while the values are their
        corresponding local sources. Note that the container root is
        appended to ``PYTHONPATH``. Therefore, any listed Python
        module or package is automatically importable.

    env
        Environment variables to forward to the containers.

    queue
        YARN queue to use.

    acls
        Configures the application-level Access Control Lists (ACLs).
        Optional; defaults to ACLs granting all users access.

        See `ACLs <https://jcrist.github.io/skein/specification.html#acls>`_ for details.

    file_systems
        A list of namenode URIs to acquire delegation tokens for
        in addition to ``fs.defaultFS``.

    eval_monitor_log_thresholds
        Optional dictionary mapping a monitored metric name to a
        (lower bound, upper bound) pair of floats. Each entry corresponds
        to an evaluation monitored metric and its associated range: the
        metric is logged only if its value lies in [lower bound, upper bound].
        If the lower bound is None it is set to 0; if the upper bound is
        None, it is set to the maximum value.
        A monitored metric with no range is always logged. List of monitored metrics:
        'awake_time_ratio': 'Awake/idle ratio',
        'eval_step_mean_duration': 'Eval step mean duration (in sec)',
        'last_training_step': 'Training step of last checkpoint',
        'nb_eval_steps': 'Number of evaluation steps done'

    nb_retries
        Number of times the YARN application is retried in case of failure.

    custom_task_module
        Full module name of a custom task that is executed on each worker;
        None by default.
        (The module will be invoked with ``python -m {custom_task_module}``
        on the cluster.)
        Only for advanced use cases; it can be useful, for example,
        to bypass or tweak the existing ``Estimator.train_and_evaluate``
        pattern.

    name
        Name of the YARN application.

    pre_script_hook
        Bash command used to prepare the Hadoop environment.

    Raises
    ------
    RunFailed
        If the final status of the YARN application is ``"FAILED"``.
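
    Examples
    --------
    A minimal sketch of a call; the environment path, resources and the
    experiment body are illustrative placeholders (``my_model_fn``,
    ``my_train_input_fn`` and ``my_eval_input_fn`` are assumed to exist)::

        def experiment_fn() -> experiment.Experiment:
            estimator = tf.estimator.Estimator(model_fn=my_model_fn)
            return experiment.Experiment(
                estimator,
                tf.estimator.TrainSpec(my_train_input_fn, max_steps=1000),
                tf.estimator.EvalSpec(my_eval_input_fn)
            )

        run_on_yarn(
            "hdfs:///path/to/env.pex",
            experiment_fn,
            task_specs={
                "chief": topologies.TaskSpec(memory="2 GiB", vcores=4),
                "evaluator": topologies.TaskSpec(memory="2 GiB", vcores=1)
            }
        )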
    """
    if nb_retries < 0:
        raise ValueError(f'nb_retries must be greater or equal to 0. Got {nb_retries}')

    pyenvs = _setup_pyenvs(pyenv_zip_path)

    n_try = 0
    while True:
        try:
            skein_cluster = _setup_skein_cluster(
                pyenvs=pyenvs,
                task_specs=task_specs,
                skein_client=skein_client,
                files=files,
                env=env,
                queue=queue,
                acls=acls,
                file_systems=file_systems,
                name=name,
                n_try=n_try,
                custom_task_module=custom_task_module,
                pre_script_hook=pre_script_hook
            )
            with _shutdown_on_exception(skein_cluster.app):
                _setup_cluster_spec(skein_cluster.tasks, skein_cluster.app)

                return _run_on_cluster(
                    experiment_fn,
                    skein_cluster,
                    eval_monitor_log_thresholds,
                    n_try
                )
        except Exception:
            n_try += 1
            if n_try == nb_retries + 1:
                raise
            logger.exception(f"Retrying user application ... "
                             f"{nb_retries + 1 - n_try} remaining attempts")

    # Necessary for type checking
    return None


def get_safe_experiment_fn(full_fn_name: str, *args):
    """
    tf-yarn serializes the provided experiment function with ``cloudpickle.dumps``.
    This is good for interactive experiments but can sometimes fail
    because the function is not serializable.
    You can use this wrapper function
    if you ship your experiment function (via conda, pex) manually to the workers.

    Parameters
    ----------

    full_fn_name
        the name of the function (with the full path to its package and module),
        e.g. ``tf_yarn.my_module.my_experiment_fn``

    args
        arguments to be passed to this function

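    Example; the package, module and dataset path are illustrative
    placeholders::

        # "my_pkg.my_module" must be importable on the workers, e.g. because
        # it is shipped inside the pex/conda archive passed to ``run_on_yarn``.
        experiment_fn = get_safe_experiment_fn(
            "my_pkg.my_module.make_experiment", "hdfs:///path/to/dataset"
        )
        run_on_yarn("hdfs:///path/to/env.pex", experiment_fn)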
    """
    module_name, fn_name = full_fn_name.rsplit('.', 1)
    module = importlib.import_module(module_name)
    experiment_fn = getattr(module, fn_name)
    return partial(experiment_fn, *args)


def _send_config_proto(
        skein_cluster: SkeinCluster,
        tf_session_config: tf.compat.v1.ConfigProto):
    serialized_fn = cloudpickle.dumps(tf_session_config)
    skein_cluster.app.kv[constants.KV_TF_SESSION_CONFIG] = serialized_fn


@contextmanager
def _shutdown_on_exception(app: skein.ApplicationClient):
    # Ensure SIGINT is not masked to enable kill on C-c.
    import signal
    signal.signal(signal.SIGINT, signal.default_int_handler)

    try:
        yield
    except (KeyboardInterrupt, SystemExit):
        with suppress(SkeinError):
            app.shutdown(FinalStatus.KILLED)
        logger.error("Application killed on user request")
    except Exception:
        with suppress(SkeinError):
            app.shutdown(FinalStatus.FAILED)
        logger.exception("Application shutdown due to an exception")
        raise


def _execute_and_await_termination(
    skein_cluster: SkeinCluster,
    serialized_fn: bytes,
    eval_monitor_log_thresholds: Dict[str, Tuple[float, float]] = None,
    n_try: int = 0,
    poll_every_secs: int = 10
) -> Optional[metrics.Metrics]:
    skein_cluster.app.kv[constants.KV_EXPERIMENT_FN] = serialized_fn
    eval_metrics_logger = evaluator_metrics.EvaluatorMetricsLogger(
        [task for task in _internal.iter_tasks(skein_cluster.tasks)
         if task.startswith('evaluator')],
        skein_cluster.app,
        eval_monitor_log_thresholds
    )

    tensorboard_url_event_name = tensorboard.url_event_name(
        _internal.iter_tasks(skein_cluster.tasks)
    )
    tensorboard_url_logger = metrics.OneShotMetricsLogger(
        skein_cluster.app,
        [(tensorboard_url_event_name, tensorboard.URL_EVENT_LABEL)]
        if tensorboard_url_event_name else [],
        n_try
    )

    state = None
    while True:
        report = skein_cluster.client.application_report(skein_cluster.app.id)
        logger.info(
            f"Application report for {skein_cluster.app.id} (state: {report.state})")
        if state != report.state:
            logger.info(_format_app_report(report))

        if report.final_status != "undefined":
            skein_cluster.event_listener.join()
            log_events, result_metrics, container_status = _handle_events(skein_cluster.events,
                                                                          n_try)
            logger.info(log_events)

            containers = container_status.by_container_id()
            # add one for AM container
            wait_for_nb_logs = sum([instances for task, instances in skein_cluster.tasks]) + 1
            logs = _get_app_logs(
                skein_cluster.client,
                skein_cluster.app,
                wait_for_nb_logs
            )
            _save_logs_to_mlflow(logs, containers, n_try)

            if report.final_status == "failed":
                raise RunFailed
            else:
                break
        else:
            eval_metrics_logger.log()
            tensorboard_url_logger.log()
        time.sleep(poll_every_secs)
        state = report.state

    result_metrics.log_mlflow(n_try)
    return result_metrics


def _save_logs_to_mlflow(logs: Optional[skein.model.ApplicationLogs],
                         containers: Dict[str, Tuple[str, str]],
                         n_try: int):
    if not logs:
        return

    for key, log_content in logs.items():
        if key in containers:
            task, status = containers[key]
            filename = mlflow.format_key(f"{task}_{status}_{n_try}")
        else:
            filename = mlflow.format_key(f"{key}_{n_try}")
        mlflow.save_text_to_mlflow(log_content, filename)


def _format_app_report(report: ApplicationReport) -> str:
    attrs = [
        "queue",
        "start_time",
        "finish_time",
        "final_status",
        "tracking_url",
        "user"
    ]
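    # Renders one right-aligned "attribute: value" line per attribute
    # (e.g. "    final_status: SUCCEEDED"; the value is illustrative).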
    return os.linesep + os.linesep.join(
        f"{attr:>16}: {getattr(report, attr) or ''}" for attr in attrs)


def _aggregate_events(
    kv: skein.kv.KeyValueStore,
    events: Dict[str, Dict[str, str]]
) -> None:
    """
    Aggregate events from all dispatched tasks.

    The lifecycle of a task consists of three stages:
    * init which carries the reserved socket address,
    * start with no payload, and
    * stop with an optional formatted exception.
    """
    # ``ConnectionError`` indicates that the app has finished and
    # the AM is down.
    queue = kv.events(event_type="PUT")
    with suppress(skein.exceptions.ConnectionError), queue:
        for evt in queue:
            if "/" in evt.key:
                task, stage = evt.key.rsplit("/", 1)
                events[task][stage] = evt.result.value.decode()


def _handle_events(
    events: Dict[str, Dict[str, str]],
    n_try: int
) -> Tuple[str, metrics.Metrics, ContainerLogStatus]:
    header = []
    details = []
    min_training_start_time = timedelta.max
    max_training_stop_time = timedelta.min
    min_eval_start_time = timedelta.max
    max_eval_stop_time = timedelta.min
    valid_training_time = True
    valid_eval_time = True
    container_duration: Dict[str, Optional[timedelta]] = dict()
    train_eval_time_per_node: Dict[str, Optional[timedelta]] = dict()
    container_log_urls: Dict[str, str] = dict()
    container_status: Dict[str, str] = dict()
    for task, stages in sorted(events.items()):
        if "stop" in stages:
            status = "FAILED" if stages["stop"] else "SUCCEEDED"
        elif stages:
            status = "KILLED"
        else:
            # No events -- container was never started.
            status = "REQUESTED"

        sock_addr = stages.get("init", "")
        exception = stages.get("stop", "")
        logs = stages.get("logs", "")

        container_log_urls[task] = logs
        container_status[task] = status
        container_duration[task] = None
        if 'container_start_time' in stages and 'container_stop_time' in stages:
            container_duration[task] = timedelta(seconds=(float(stages['container_stop_time'])
                                                          - float(stages['container_start_time'])))

        train_eval_time_per_node[task] = None
        task_type = cluster.get_task_type(task)
        if 'train_eval_start_time' in stages and 'train_eval_stop_time' in stages and not exception:
            start_time = timedelta(seconds=float(stages['train_eval_start_time']))
            stop_time = timedelta(seconds=float(stages['train_eval_stop_time']))
            train_eval_time_per_node[task] = stop_time - start_time
            if cluster.is_worker(task_type) or cluster.is_chief(task_type):
                if start_time < min_training_start_time:
                    min_training_start_time = start_time
                if stop_time > max_training_stop_time:
                    max_training_stop_time = stop_time
            elif cluster.is_evaluator(task_type):
                if start_time < min_eval_start_time:
                    min_eval_start_time = start_time
                if stop_time > max_eval_stop_time:
                    max_eval_stop_time = stop_time
        else:
            if cluster.is_worker(task_type) or cluster.is_chief(task_type):
                valid_training_time = False
            elif cluster.is_evaluator(task_type):
                valid_eval_time = False

        header.append(f"{task:>16}  {sock_addr}  {status}  {logs}"
                      f"  Container duration: {container_duration[task]}"
                      f"  Training/evaluation duration : {train_eval_time_per_node[task]}")

        if exception:
            details.append(f"Exception in task {task}:")
            details.append(exception)

    training_time = max_training_stop_time - min_training_start_time\
        if valid_training_time and min_training_start_time < timedelta.max else None
    eval_time = max_eval_stop_time - min_eval_start_time\
        if valid_eval_time and min_eval_start_time < timedelta.max else None
    header.append(f'Training time = {training_time}')
    header.append(f'Evaluation time = {eval_time}')

    result_metrics = metrics.Metrics(
        training_time,
        eval_time,
        container_duration,
        train_eval_time_per_node
    )

    return ((os.linesep + os.linesep.join(header)
             + os.linesep * (1 + bool(details))
             + os.linesep.join(details)),
            result_metrics,
            ContainerLogStatus(container_log_urls, container_status))


def _get_app_logs(
    client: skein.Client,
    app: skein.ApplicationClient,
    wait_for_nb_logs: int
) -> Optional[skein.model.ApplicationLogs]:
    for ind in range(YARN_LOG_TRIES):
        try:
            logs = client.application_logs(app.id)
            nb_keys = len(logs.keys())
            logger.info(f"Got {nb_keys}/{wait_for_nb_logs} log files")
            if nb_keys == wait_for_nb_logs:
                return logs
        except Exception:
            logger.warning(
                f"Cannot collect logs (attempt {ind+1}/{YARN_LOG_TRIES})",
                exc_info=True)
        time.sleep(3)
    return None