"""
Implementation of proxy-related features.
"""
import os
import sys
import time

import Pyro4
from Pyro4.errors import ConnectionClosedError
from Pyro4.errors import NamingError
from Pyro4.message import FLAGS_ONEWAY

from . import config
from .address import SocketAddress
from .address import address_to_host_port


def locate_ns(nsaddr, timeout=3.0):
    """
    Locate a name server to ensure it actually exists.

    Parameters
    ----------
    nsaddr : SocketAddress
        The address where the name server should be up and running.
    timeout : float
        Timeout in seconds before aborting location.

    Returns
    -------
    nsaddr
        The address where the name server was located.

    Raises
    ------
    NamingError
        If the name server could not be located.
    """
    host, port = address_to_host_port(nsaddr)
    time0 = time.time()
    while True:
        try:
            Pyro4.locateNS(host, port)
            return nsaddr
        except NamingError:
            if time.time() - time0 < timeout:
                time.sleep(0.1)
                continue
            raise TimeoutError('Could not locate the name server!')


class Proxy(Pyro4.core.Proxy):
    """
    A proxy to access remote agents.

    Parameters
    ----------
    name : str
        Proxy name, as registered in the name server.
    nsaddr : SocketAddress, str
        Name server address.
    timeout : float
        Timeout, in seconds, to wait until the agent is discovered.
    safe : bool, default is None
        Use safe calls by default. When not set, osbrain default's
        :py:data:`osbrain.config['SAFE']` is used.
    """

    def __init__(self, name, nsaddr=None, timeout=3.0, safe=None):
        if not nsaddr:
            nsaddr = os.environ.get('OSBRAIN_NAMESERVER_ADDRESS')
        nshost, nsport = address_to_host_port(nsaddr)
        # Make sure name server exists
        locate_ns(nsaddr)
        time0 = time.time()
        super().__init__('PYRONAME:%s@%s:%s' % (name, nshost, nsport))
        if safe is not None:
            self._default_safe = safe
        else:
            self._default_safe = config['SAFE']
        self._safe = self._default_safe
        self._next_oneway = False
        while not self._ready_or_timeout(time0, timeout):
            continue

    def _ready_or_timeout(self, time0, timeout):
        """
        Check if the proxy is ready or raise after a timeout.

        Parameters
        ----------
        time0 : float
            Timestamp (in seconds) to take as the initial time.
        timeout : float
            Time (in seconds) allowed after `time0` before raising an
            exception.
        """
        try:
            self.unsafe.ping()
        except Exception:
            time.sleep(0.1)
            if time.time() - time0 < timeout:
                return False
            raise
        return True

    def wait_for_running(self, timeout=3.0):
        """
        Wait until the agent is running.

        Parameters
        ----------
        timeout : float
            Raise and exception if the agent is not running after this number
            of seconds. Use a negative value to wait forever.

        Raises
        ------
        TimeoutError
            If the agent is not running after the given timeout.

        Returns
        -------
        Proxy
            The object itself.
        """
        time0 = time.time()
        while not self.is_running():
            if timeout >= 0 and time.time() - time0 > timeout:
                msg = 'Timed out while waiting for the agent to be running'
                raise TimeoutError(msg)
            time.sleep(0.01)

        return self

    def __getstate__(self):
        return super().__getstate__() + (
            self._next_oneway,
            self._default_safe,
            self._safe,
        )

    def __setstate__(self, state):
        super().__setstate__(state[:-3])
        self._next_oneway = state[-3]
        self._default_safe = state[-2]
        self._safe = state[-1]

    def __setattr__(self, name, value):
        if name in ('_safe', '_default_safe', '_next_oneway'):
            return super(Pyro4.core.Proxy, self).__setattr__(name, value)
        if name.startswith('_'):
            return super().__setattr__(name, value)
        kwargs = {name: value}
        return self.set_attr(**kwargs)

    def __getattr__(self, name):
        if name in self._pyroAttrs:
            return self.get_attr(name)
        return super().__getattr__(name)

    def release(self):
        """
        Release the connection to the Pyro daemon.
        """
        self._pyroRelease()

    def nsaddr(self):
        """
        Get the socket address of the name server.

        Returns
        -------
        SocketAddress
            The socket address.
        """
        return SocketAddress(self._pyroUri.host, self._pyroUri.port)

    @property
    def safe(self):
        """
        Make the next remote method call be safe.

        Returns
        -------
        The proxy itself.
        """
        self._safe = True
        return self

    @property
    def unsafe(self):
        """
        Make the next remote method call be unsafe.

        Returns
        -------
        The proxy itself.
        """
        self._safe = False
        return self

    @property
    def oneway(self):
        """
        Make the next remote method call be one way.

        Returns
        -------
        The proxy itself.
        """
        self._next_oneway = True
        return self

    def _pyroInvoke(  # noqa: N802
        self, methodname, args, kwargs, flags=0, objectId=None  # noqa: N803
    ):
        """
        Wrapper around `_remote_call` to safely execute methods on remote
        objects.
        """
        try:
            result = self._remote_call(
                methodname, args, kwargs, flags, objectId
            )
        except Exception:
            sys.stdout.write(''.join(Pyro4.util.getPyroTraceback()))
            sys.stdout.flush()
            raise
        finally:
            self._safe = self._default_safe
            self._next_oneway = False
        self._post_invoke(methodname, args, kwargs)
        return result

    def _is_safe_method(self, methodname):
        """
        Check if a remote method can be called safely.

        Parameters
        ----------
        methodname : str
            The name of the method to evaluate.

        Returns
        -------
        bool
            Whether the method can be safely called.
        """
        return (
            methodname in self._pyroMethods
            and not methodname.startswith('_')
            and methodname
            not in (
                'run',
                'get_attr',
                'kill',
                'safe_call',
                'concurrent',
                'is_running',
            )
        )

    def _remote_call(
        self, methodname, args, kwargs, flags, objectId  # noqa: N803
    ):
        """
        Call a remote method from the proxy.
        """
        if self._next_oneway:
            flags |= FLAGS_ONEWAY
            result = super()._pyroInvoke(
                methodname, args, kwargs, flags=flags, objectId=objectId
            )
            return result
        if self._safe and self._is_safe_method(methodname):
            safe_args = [methodname] + list(args)
            result = super()._pyroInvoke(
                'safe_call', safe_args, kwargs, flags=flags, objectId=objectId
            )
            if isinstance(result, Exception):
                raise result
        else:
            result = super()._pyroInvoke(
                methodname, args, kwargs, flags=flags, objectId=objectId
            )
        return result

    def _post_invoke(self, methodname, args, kwargs):
        """
        After invoking a call, check if the proxy must be modified.

        This could happen if the `set_method` or `set_attr` have been invoked.
        In that case, the method(s) or attribute(s) are added to the proxy's
        available method(s)/attributes(s).
        """
        if methodname == 'set_method':
            self._set_new_available_methods(args, kwargs)
        elif methodname == 'set_attr':
            self._set_new_available_attributes(kwargs)

    def _set_new_available_methods(self, args, kwargs):
        """
        Set new methods available from the proxy.

        Parameters
        ----------
        args : list
            A list of new methods to be made available from the proxy.
        kwargs : dict
            A dictionary with the methods' names and their values.
        """
        for method in args:
            self._pyroMethods.add(method.__name__)
        for name, _ in kwargs.items():
            self._pyroMethods.add(name)

    def _set_new_available_attributes(self, kwargs):
        """
        Set new attributes available from the proxy.

        Parameters
        ----------
        kwargs : dict
            A dictionary with the attributes' names and their values.
        """
        for name in kwargs:
            self._pyroAttrs.add(name)


class NSProxy(Pyro4.core.Proxy):
    """
    A proxy to access a name server.

    Parameters
    ----------
    nsaddr : SocketAddress, str
        Name server address.
    timeout : float
        Timeout, in seconds, to wait until the name server is discovered.
    """

    def __init__(self, nsaddr=None, timeout=3):
        if not nsaddr:
            nsaddr = os.environ.get('OSBRAIN_NAMESERVER_ADDRESS')
        nshost, nsport = address_to_host_port(nsaddr)
        # Make sure name server exists
        locate_ns(nsaddr, timeout)
        ns_name = Pyro4.constants.NAMESERVER_NAME
        super().__init__('PYRONAME:%s@%s:%d' % (ns_name, nshost, nsport))

    def release(self):
        """
        Release the connection to the Pyro daemon.
        """
        self._pyroRelease()

    def proxy(self, name, timeout=3.0):
        """
        Get a proxy to access an agent registered in the name server.

        Parameters
        ----------
        name : str
            Proxy name, as registered in the name server.
        timeout : float
            Timeout, in seconds, to wait until the agent is discovered.

        Returns
        -------
        Proxy
            A proxy to access an agent registered in the name server.
        """
        return Proxy(name, nsaddr=self.addr(), timeout=timeout)

    def addr(self, agent_alias=None, address_alias=None):
        """
        Return the name server address or the address of an agent's socket.

        Parameters
        ----------
        agent_alias : str, default is None
            The alias of the agent to retrieve its socket address.
        address_alias : str, default is None
            The alias of the socket address to retrieve from the agent.

        Returns
        -------
        SocketAddress or AgentAddress
            The name server or agent's socket address.
        """
        if not agent_alias and not address_alias:
            return SocketAddress(self._pyroUri.host, self._pyroUri.port)
        agent = self.proxy(agent_alias)
        addr = agent.addr(address_alias)
        agent.release()
        return addr

    def shutdown_agents(self, timeout=10.0):
        """
        Shutdown all agents registered in the name server.

        Parameters
        ----------
        timeout : float, default is 10.
            Timeout, in seconds, to wait for the agents to shutdown.
        """
        # Wait for all agents to be shutdown (unregistered)
        time0 = time.time()
        super()._pyroInvoke('async_shutdown_agents', (self.addr(),), {})
        while time.time() - time0 <= timeout / 2.0:
            if not len(self.agents()):
                return
            time.sleep(0.1)
        super()._pyroInvoke('async_kill_agents', (self.addr(),), {})
        while time.time() - time0 <= timeout:
            if not len(self.agents()):
                return
            time.sleep(0.1)
        raise TimeoutError(
            'Chances are {} were not shutdown after {} s!'.format(
                self.agents(), timeout
            )
        )

    def shutdown(self, timeout=10.0):
        """
        Shutdown the name server. All agents will be shutdown as well.

        Parameters
        ----------
        timeout : float, default is 10.
            Timeout, in seconds, to wait for the agents to shutdown.
        """
        self.shutdown_agents(timeout)
        try:
            super()._pyroInvoke('daemon_shutdown', (), {}, flags=0)
        except ConnectionClosedError:
            pass