# Copyright 2018-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the 'License'). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the 'license' file accompanying this file. This file is
# distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""Placeholder docstring"""
import argparse
import inspect
import logging
import os
import subprocess
import time
from typing import Any, List, Tuple  # noqa ignore=F401 imported but unused

import paramiko
import psutil

import gethostname
from sagemaker_containers import _logging, _process, _timeout

# Module-level logger shared by every runner and helper in this file.
logger = _logging.get_logger()
# Set paramiko's own logger to INFO so that its DEBUG-level records are not emitted.
logging.getLogger("paramiko").setLevel(logging.INFO)


class WorkerRunner(_process.ProcessRunner):
    """Process runner for the worker side of an MPI distributed training job.

    A worker starts its own SSH daemon and then blocks until the MPI ``orted``
    process running on this host has finished.
    """

    def __init__(self, user_entry_point, args, env_vars, master_hostname):
        """Initialize the worker runner.

        Args:
            user_entry_point: Forwarded unchanged to the base runner's ``__init__``.
            args: Forwarded unchanged to the base runner's ``__init__``.
            env_vars: Forwarded unchanged to the base runner's ``__init__``.
            master_hostname: Hostname of the MPI master; stored as a ``str``.
        """
        super(WorkerRunner, self).__init__(user_entry_point, args, env_vars)
        self._master_hostname = str(master_hostname)

    def run(
        self, wait=True, capture_error=False
    ):  # type: (bool, bool) -> None # pylint: disable=unused-argument
        """Run the worker side of the MPI job.

        Steps, in order:

        - when ``wait`` is truthy, block until the MPI master accepts SSH connections
        - start the local SSH daemon
        - when ``wait`` is truthy, block until the local ``orted`` process has finished

        Args:
            wait (bool): Whether to block on the master and on the ``orted`` process.
            capture_error (bool): Unused by this runner.
        """
        logger.info("Starting MPI run as worker node.")

        if wait:
            logger.info("Waiting for MPI Master to create SSH daemon.")
            self._wait_master_to_start()

        # This message is logged regardless of ``wait``.
        logger.info("MPI Master online, creating SSH daemon.")
        _start_sshd_daemon()

        if wait:
            logger.info("Waiting for MPI process to finish.")
            _wait_orted_process_to_finish()
            # NOTE(review): presumably a grace period before the worker exits - confirm.
            time.sleep(30)

        logger.info("MPI process finished.")

    def _wait_master_to_start(self):  # type: () -> None
        """Poll once per second until an SSH connection to the master succeeds."""
        while True:
            if _can_connect(self._master_hostname):
                return
            time.sleep(1)

    def _wait_master_to_finish(self):  # type: () -> None
        """Poll every 30 seconds until an SSH connection to the master fails."""
        while True:
            if not _can_connect(self._master_hostname):
                return
            time.sleep(30)


def _wait_orted_process_to_finish():  # type: () -> None
    """Block until every ``orted`` process found on this host has terminated.

    ``_orted_process`` returns ``None`` when no ``orted`` process appears within its
    polling window. Handing that ``None`` to ``psutil.wait_procs`` raises ``TypeError``,
    so the "nothing found" case is logged and treated as nothing to wait for.
    """
    orted = _orted_process()
    if not orted:
        logger.warning("No orted process was found; there is no MPI process to wait for.")
        return

    psutil.wait_procs(orted)


def _orted_process():
    """Look for running ``orted`` processes, polling once per second.

    Returns:
        A non-empty list of ``psutil`` process objects whose name is ``orted``, or
        ``None`` when none was seen after ``5 * 60`` polls.
    """
    polls_left = 5 * 60
    while polls_left > 0:
        found = []
        for candidate in psutil.process_iter(attrs=["name"]):
            if candidate.info["name"] == "orted":
                found.append(candidate)

        if found:
            return found

        time.sleep(1)
        polls_left -= 1

    return None


class MasterRunner(_process.ProcessRunner):
    """Prepare MPI distributed training and synchronize work with the workers.

    The master starts an SSH daemon, waits until every worker host accepts SSH
    connections and builds the ``mpirun`` command line used to launch the user
    entry point.
    """

    def __init__(
        self,
        user_entry_point,
        args,
        env_vars,
        master_hostname,
        hosts,
        process_per_host,
        custom_mpi_options,
        network_interface_name,
        interval=1,
        timeout_in_seconds=60 * 60,
        num_processes=None,
    ):
        """Initialize the master runner.

        Args:
            user_entry_point: Forwarded unchanged to the base runner's ``__init__``.
            args: Forwarded unchanged to the base runner's ``__init__``.
            env_vars: Forwarded unchanged to the base runner's ``__init__``.
            master_hostname: Hostname of the master; excluded from the hosts that are
                waited for in ``_wait_for_workers``.
            hosts: Hostnames taking part in the MPI job (passed to ``mpirun --host``).
            process_per_host: Number of MPI processes to start on each host.
            custom_mpi_options: Whitespace separated extra options, parsed by
                ``_parse_custom_mpi_options``.
            network_interface_name: Interface name used for ``btl_tcp_if_include``,
                ``oob_tcp_if_include`` and ``NCCL_SOCKET_IFNAME``.
            interval: Seconds to sleep between connection attempts to a worker.
            timeout_in_seconds: Overall time limit for waiting on all workers.
            num_processes: Total number of MPI processes. When falsy,
                ``process_per_host * len(hosts)`` is used instead.
        """

        super(MasterRunner, self).__init__(user_entry_point, args, env_vars)

        self._master_hostname = master_hostname
        self._hosts = hosts
        self._process_per_host = process_per_host
        self._num_processes = num_processes
        self._custom_mpi_options = custom_mpi_options
        self._network_interface_name = network_interface_name
        self._interval = interval
        self.timeout_in_seconds = timeout_in_seconds

    def _setup(self):  # type: () -> None
        """Start the local SSH daemon and wait until every worker is reachable."""
        logger.info("Starting MPI run as master node.")
        logger.info("Creating SSH daemon.")
        _start_sshd_daemon()

        self._wait_for_workers()

    def _wait_for_workers(self):  # type: () -> None
        """Block until every non-master host accepts an SSH connection.

        Hosts are checked one after another, sleeping ``interval`` seconds between
        attempts. The whole wait runs inside ``_timeout.timeout`` with
        ``timeout_in_seconds`` as its limit.
        """
        logger.info("Waiting for MPI workers to establish their SSH connections")

        workers = [host for host in self._hosts if host != self._master_hostname]
        with _timeout.timeout(seconds=self.timeout_in_seconds):
            for host in workers:
                while not _can_connect(host):
                    time.sleep(self._interval)
                logger.info("Worker %s available for communication", host)

    def _create_command(self):  # type: () -> List[str]
        """Build the ``mpirun`` command line that launches the user entry point.

        Returns:
            list: ``mpirun`` and its options, followed by the command returned by the
            base runner's ``_create_command``.
        """
        num_hosts = len(self._hosts)
        num_processes = self._num_processes or self._process_per_host * num_hosts

        # By default, use one process per GPU, or one process per node (if training with CPU).
        if self._process_per_host == 1:
            host_list = self._hosts
        else:
            # "host:slots" entries tell mpirun how many processes each host runs.
            host_list = ["%s:%s" % (host, self._process_per_host) for host in self._hosts]

        msg = "Env Hosts: %s Hosts: %s process_per_hosts: %s num_processes: %s"
        logger.info(msg, self._hosts, host_list, self._process_per_host, num_processes)

        overridden_known_options, additional_options = _parse_custom_mpi_options(
            self._custom_mpi_options
        )

        logger.info("Network interface name: %s", self._network_interface_name)

        command = [
            "mpirun",
            "--host",
            ",".join(host_list),
            "-np",
            str(num_processes),
            "--allow-run-as-root",
            "--display-map",
            "--tag-output",
            "-mca",
            "btl_tcp_if_include",
            self._network_interface_name,
            "-mca",
            "oob_tcp_if_include",
            self._network_interface_name,
            "-mca",
            "plm_rsh_no_tree_spawn",
            "1",
            "-bind-to",
            "socket",
            "-map-by",
            "slot",
            "-mca",
            "pml",
            "ob1",
            "-mca",
            "btl",
            "^openib",
            "-mca",
            "orte_abort_on_non_zero_status",
            "1",
            "-x",
            "NCCL_MIN_NRINGS=4",
            "-x",
            "NCCL_SOCKET_IFNAME=%s" % self._network_interface_name,
            "-x",
            "NCCL_DEBUG=%s" % overridden_known_options.NCCL_DEBUG,
            "-x",
            "LD_LIBRARY_PATH",
            "-x",
            "PATH",
            "-x",
            "LD_PRELOAD=%s" % inspect.getfile(gethostname),
        ]

        # Options that were not recognised by the parser are passed to mpirun untouched.
        command.extend(additional_options)

        # Export AWS credentials only when they are present in this environment.
        for credential in ["AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY", "AWS_SESSION_TOKEN"]:
            if credential in os.environ:
                command.extend(["-x", credential])

        # Export every entry of ``self._env_vars`` by name.
        for name in self._env_vars:
            command.extend(["-x", name])

        command.extend(super(MasterRunner, self)._create_command())

        return command

    def _python_command(self):
        """Use mpi4py to force processes to abort if an uncaught exception occurs.
        https://docs.chainer.org/en/stable/chainermn/tutorial/tips_faqs.html#mpi-process-hangs-after-an-unhandled-python-exception

        Returns:
            list: The base runner's python command with ``-m mpi4py`` appended.
        """
        return super(MasterRunner, self)._python_command() + ["-m", "mpi4py"]


# Error text raised by ``_start_sshd_daemon`` when the sshd executable is missing.
# It lists the shell commands for installing and configuring SSH in the container.
_SSH_DAEMON_NOT_FOUND_ERROR_MESSAGE = """
SSH daemon not found, please install SSH to allow MPI to communicate different nodes in cluster.

You can install ssh by running following commands:
-------------------------------------------------

1. Install SSH via apt-get:

apt-get update && apt-get install -y --no-install-recommends openssh-server && mkdir -p /var/run/sshd

2. SSH login fix. Otherwise user is kicked off after login:
sed 's@session\\s*required\\s*pam_loginuid.so@session optional pam_loginuid.so@g' -i /etc/pam.d/sshd

3. Create SSH key to allow password less ssh between different docker instances:
mkdir -p /root/.ssh/ && ssh-keygen -q -t rsa -N '' -f /root/.ssh/id_rsa && \
  cp /root/.ssh/id_rsa.pub /root/.ssh/authorized_keys && \
  printf "Host *\n  StrictHostKeyChecking no\n" >> /root/.ssh/config
"""


def _start_sshd_daemon():  # type: () -> None
    """Launch ``/usr/sbin/sshd -D`` as a child process without waiting for it.

    Raises:
        RuntimeError: If ``/usr/sbin/sshd`` does not exist.
    """
    sshd_path = "/usr/sbin/sshd"

    if os.path.exists(sshd_path):
        subprocess.Popen([sshd_path, "-D"])
        return

    raise RuntimeError(_SSH_DAEMON_NOT_FOUND_ERROR_MESSAGE)


def _can_connect(host, port=22):  # type: (str, int) -> bool
    """Check whether an SSH connection to ``host`` on ``port`` can be established.

    The SSH client is always closed before returning, including when the connection
    attempt fails, so that repeated polling does not leave clients open.

    Args:
        host (str): Hostname for the host to check connection.
        port (int): Port of the host to check connection on.

    Returns:
        bool: ``True`` if the connection succeeded, ``False`` if any exception was raised.
    """
    client = None
    try:
        logger.debug("Testing connection to host %s", host)
        client = paramiko.SSHClient()
        client.load_system_host_keys()
        client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
        client.connect(host, port=port)
        logger.info("Can connect to host %s", host)
        return True
    except Exception as error:  # pylint: disable=broad-except
        logger.info("Cannot connect to host %s", host)

        logger.info("Connection failed with exception: \n %s", str(error))
        return False
    finally:
        if client is not None:
            try:
                client.close()
            except Exception:  # pylint: disable=broad-except
                # Closing is best effort; it must not change the outcome of the probe.
                logger.debug("Ignoring error while closing SSH client for host %s", host)


def _parse_custom_mpi_options(custom_mpi_options):
    # type: (str) -> Tuple[argparse.Namespace, List[str]]
    """Parse custom MPI options provided by user. Known options default value will be overridden
    and unknown options would be identified separately."""

    parser = argparse.ArgumentParser()
    parser.add_argument("--NCCL_DEBUG", default="INFO", type=str)

    return parser.parse_known_args(custom_mpi_options.split())