import asyncio
import collections
import errno
import logging
import os
import socket
import sys
import time
from functools import partial

from .resolver import STSResolver, STSFetchResult
from .constants import QUEUE_LIMIT, CHUNK, REQUEST_LIMIT
from .utils import create_custom_socket, filter_domain, is_ipaddr
from .base_cache import CacheEntry
from . import netstring


# Per-zone settings record: the strict-testing flag, the resolver object
# serving the zone, and whether the SNI hint is required for the zone.
ZoneEntry = collections.namedtuple(
    'ZoneEntry',
    ['strict', 'resolver', 'require_sni'],
)


# pylint: disable=too-many-instance-attributes
class STSSocketmapResponder:
    def __init__(self, cfg, loop, cache):
        """Set up the responder from its configuration mapping.

        ``cfg`` must provide ``reuse_port``, ``shutdown_timeout``,
        ``cache_grace``, ``default_zone`` and ``zones``, plus either
        ``path`` (optionally with ``mode``) to listen on a UNIX socket or
        ``host`` and ``port`` to listen on TCP.  Each zone mapping supplies
        ``strict_testing``, ``timeout`` and ``require_sni``.

        ``loop`` is the event loop used for resolvers and spawned tasks;
        ``cache`` is the policy cache object kept for later lookups.
        """
        def build_zone(zone_cfg):
            # Every zone gets its own resolver with that zone's timeout
            strict = zone_cfg["strict_testing"]
            resolver = STSResolver(loop=loop, timeout=zone_cfg["timeout"])
            return ZoneEntry(strict, resolver, zone_cfg["require_sni"])

        self._logger = logging.getLogger("STS")
        self._loop = loop

        # A configured path selects UNIX socket mode, otherwise TCP is used
        socket_path = cfg.get('path')
        self._unix = socket_path is not None
        if self._unix:
            self._path = socket_path
            self._sockmode = cfg.get('mode')
        else:
            self._host = cfg['host']
            self._port = cfg['port']
        self._reuse_port = cfg['reuse_port']
        self._shutdown_timeout = cfg['shutdown_timeout']
        self._grace = cfg['cache_grace']

        # Construct configurations and resolvers for every socketmap name
        self._default_zone = build_zone(cfg["default_zone"])
        self._zones = {name: build_zone(zone_cfg)
                       for name, zone_cfg in cfg["zones"].items()}

        self._cache = cache
        self._children = set()
        self._server = None

    # Check if cached record is nonexistent or stale
    def is_stale(self, cached):
        ts = time.time()  # pylint: disable=invalid-name

        # Nonexistent ?
        if cached is None:
            return True

        # Expired grace period ?
        if ts - cached.ts > self._grace:
            return True

        # Expired policy ?
        if cached.pol_body['max_age'] + cached.ts < ts:
            return True

        return False

    async def start(self):
        def _spawn(reader, writer):
            def done_cb(task, fut):
                self._children.discard(task)
            task = self._loop.create_task(self.handler(reader, writer))
            task.add_done_callback(partial(done_cb, task))
            self._children.add(task)
            self._logger.debug("len(self._children) = %d", len(self._children))

        if self._unix:
            self._server = await asyncio.start_unix_server(_spawn, path=self._path)
            if self._sockmode is not None:
                os.chmod(self._path, self._sockmode)
        else:
            if self._reuse_port: # pragma: no cover
                if sys.platform in ('win32', 'cygwin'):
                    opts = {
                        'host': self._host,
                        'port': self._port,
                        'reuse_address': True,
                    }
                elif os.name == 'posix':
                    if sys.platform.startswith('freebsd'):
                        sockopts = [
                            (socket.SOL_SOCKET, socket.SO_REUSEADDR, 1),
                            (socket.SOL_SOCKET, 0x10000, 1),  # SO_REUSEPORT_LB
                        ]
                        sock = await create_custom_socket(self._host, self._port,
                                                          options=sockopts)
                        opts = {
                            'sock': sock,
                        }
                    else:
                        opts = {
                            'host': self._host,
                            'port': self._port,
                            'reuse_address': True,
                            'reuse_port': True,
                        }
            self._server = await asyncio.start_server(_spawn, **opts)

    async def stop(self):
        """Close the listener and wait for client handlers to finish.

        Handlers are awaited in rounds: each round waits for the handlers
        known at that moment for at most ``self._shutdown_timeout``
        seconds.  When the timeout expires ``asyncio.wait_for`` cancels
        the gathered handlers.  Rounds repeat until no handlers are
        registered after a one-second pause.
        """
        listener = self._server
        listener.close()
        await listener.wait_closed()

        drained = False
        while not drained:
            self._logger.warning("Awaiting %d client handlers to finish...",
                                 len(self._children))
            # Snapshot the current handlers and reset the registry, so that
            # handlers registered meanwhile are picked up by the next round
            pending = asyncio.gather(*self._children, return_exceptions=True)
            self._children.clear()
            try:
                await asyncio.wait_for(pending, self._shutdown_timeout)
            except asyncio.TimeoutError:
                self._logger.warning("Shutdown timeout expired. "
                                     "Remaining handlers terminated.")
                # Let the cancellation settle before moving on
                try:
                    await pending
                except asyncio.CancelledError:
                    pass
            await asyncio.sleep(1)
            drained = not self._children

    async def sender(self, queue, writer):
        def cleanup_queue():
            while not queue.empty():
                task = queue.get_nowait()
                try:
                    task.cancel()
                except Exception:  # pragma: no cover
                    pass

        try:
            while True:
                fut = await queue.get()
                # Check for shutdown
                if fut is None:
                    return
                self._logger.debug("Got new future from queue")
                data = await fut
                self._logger.debug("Future await complete: data=%s", repr(data))
                writer.write(data)
                self._logger.debug("Wrote: %s", repr(data))
                await writer.drain()
        except asyncio.CancelledError:
            cleanup_queue()
        except Exception as exc: # pragma: no cover
            self._logger.exception("Exception in sender coro: %s", exc)
            cleanup_queue()
        finally:
            writer.close()

    # pylint: disable=too-many-locals,too-many-branches,too-many-statements
    async def process_request(self, raw_req):
        """Build the reply for a single lookup request.

        ``raw_req`` is the ASCII request payload of the form
        ``<zone> <domain>``.  The reply is always netstring-encoded:
        either ``NOTFOUND `` or ``OK secure match=<mx>[:<mx>...]``, the
        latter followed by `` servername=hostname`` when the zone requires
        SNI.
        """
        # Parse request and canonicalize domain
        zone_name, _, raw_domain = raw_req.decode('ascii').partition(' ')
        domain = filter_domain(raw_domain)

        # No lookups for parent domain policies or for non-domains
        if domain.startswith('.') or is_ipaddr(domain):
            return netstring.encode(b'NOTFOUND ')

        # Unknown zone names fall back to the default zone settings
        zone_cfg = self._zones.get(zone_name, self._default_zone)

        # A failing cache read is treated like a cache miss
        try:
            cached = await self._cache.get(domain)
        except asyncio.CancelledError:  # pragma: no cover pylint: disable=try-except-raise
            raise
        except Exception as exc:  # pragma: no cover
            self._logger.exception("Cache get failed: %s", str(exc))
            cached = None

        if not self.is_stale(cached):
            self._logger.debug("Lookup skipped: domain = %s", domain)
            have_policy = True
        else:
            now = time.time()
            self._logger.debug("Lookup PERFORMED: domain = %s", domain)
            # Hand over the known policy id, if any, so that the resolver
            # can check for a newer policy instead of fetching from scratch
            known_id = cached.pol_id if cached is not None else None
            status, policy = await zone_cfg.resolver.resolve(domain, known_id)

            if status is STSFetchResult.NOT_CHANGED:
                # Same policy: only the timestamp is refreshed
                cached = CacheEntry(now, cached.pol_id, cached.pol_body)
                await self._cache.safe_set(domain, cached, self._logger)
                have_policy = True
            elif status is STSFetchResult.VALID:
                new_id, new_body = policy
                cached = CacheEntry(now, new_id, new_body)
                await self._cache.safe_set(domain, cached, self._logger)
                have_policy = True
            else:
                # Any other status: keep using the cached policy unless
                # there is none or it is already past its max_age
                have_policy = (
                    cached is not None
                    and not (cached.pol_body['max_age'] + cached.ts < now)
                )

        if not have_policy:
            return netstring.encode(b'NOTFOUND ')

        mode = cached.pol_body['mode']
        # "testing" policies are enforced only for strict zones
        if mode == 'none' or (mode == 'testing' and not zone_cfg.strict):
            return netstring.encode(b'NOTFOUND ')

        assert cached.pol_body['mx'], "Empty MX list for restrictive policy!"
        # Deduplicate MX patterns and drop their leading wildcard asterisks
        patterns = ":".join(mx.lstrip('*') for mx in set(cached.pol_body['mx']))
        resp = "OK secure match=" + patterns
        if zone_cfg.require_sni:
            resp += " servername=hostname"
        return netstring.encode(resp.encode('utf-8'))

    async def handler(self, reader, writer):
        """Serve a single client connection until it ends.

        Netstring-framed requests are read from ``reader``.  Each complete
        request is dispatched to ``process_request`` as a separate task,
        and the task is pushed into a bounded queue consumed by ``sender``,
        which writes replies to ``writer`` in request order.

        On a parse error, a disconnect or an unexpected error the sender is
        asked to finish its queued work and is awaited.  On cancellation
        the sender is cancelled and the cancellation is re-raised.
        """
        # Construct netstring parser
        stream_reader = netstring.StreamReader(REQUEST_LIMIT)

        # Construct queue for responses ordering
        queue = asyncio.Queue(QUEUE_LIMIT)

        # Create coroutine which awaits for steady responses and sends them
        sender = asyncio.ensure_future(self.sender(queue, writer), loop=self._loop)

        class EndOfStream(Exception):
            pass

        async def finalize():
            # Enqueue the shutdown marker and wait until the sender has
            # flushed everything queued before it
            try:
                await queue.put(None)
            except asyncio.CancelledError:  # pragma: no cover
                sender.cancel()
                raise
            await sender

        try:
            while True:
                # Extract and parse request
                string_reader = stream_reader.next_string()
                request_parts = []
                while True:
                    try:
                        buf = string_reader.read()
                    except netstring.WantRead:
                        # Parser needs more input: read the next chunk
                        part = await reader.read(CHUNK)
                        if not part:
                            raise EndOfStream()
                        self._logger.debug("Read: %s", repr(part))
                        stream_reader.feed(part)
                    else:
                        if buf:
                            request_parts.append(buf)
                        else:
                            # This loop treats an empty chunk as the end
                            # of the current string
                            req = b''.join(request_parts)
                            self._logger.debug("Enq request: %s", repr(req))
                            fut = asyncio.ensure_future(self.process_request(req), loop=self._loop)
                            try:
                                await queue.put(fut)
                            except asyncio.CancelledError:
                                # The task never reached the queue, so the
                                # sender cleanup cannot see it: cancel it
                                # here to avoid leaving it orphaned
                                fut.cancel()
                                raise
                            break
        except netstring.ParseError:
            self._logger.warning("Bad netstring message received")
            await finalize()
        except (EndOfStream, ConnectionError, TimeoutError):
            self._logger.debug("Client disconnected")
            await finalize()
        except OSError as exc:  # pragma: no cover
            # Compare against the platform's own ENOTCONN value rather
            # than a hard-coded Linux errno number
            if exc.errno == errno.ENOTCONN:
                self._logger.debug("Client disconnected")
            else:
                self._logger.exception("Unhandled exception: %s", exc)
            await finalize()
        except asyncio.CancelledError:
            sender.cancel()
            raise
        except Exception as exc:  # pragma: no cover
            self._logger.exception("Unhandled exception: %s", exc)
            await finalize()