import asyncio
import logging
import collections
import socket
import ctypes
from functools import partial

from . import constants
from .utils import detect_af
from .baselistener import BaseListener


# Maximum number of bytes requested per stream read in
# TransparentListener._pump; the value is taken from the constants module.
BUFSIZE = constants.BUFSIZE


def detect_af(addr):
    """Return the address family of a numeric IP address string.

    The lookup is done with ``AI_NUMERICHOST``, so *addr* must already be a
    literal address; the family of the first ``getaddrinfo`` result is
    returned. ``socket.gaierror`` from ``getaddrinfo`` propagates unchanged.

    Note: this definition shadows the ``detect_af`` imported from ``.utils``
    at the top of the module.
    """
    results = socket.getaddrinfo(
        addr,
        None,
        family=socket.AF_UNSPEC,
        type=0,
        proto=0,
        flags=socket.AI_NUMERICHOST,
    )
    family, *_ = results[0]
    return family


class sockaddr(ctypes.Structure):
    """ctypes layout of a generic socket address.

    A 16-bit family field followed by 14 bytes of opaque address data
    (16 bytes in total). In this module it is only used to size the
    getsockopt buffer via ``sockaddr_size``.
    """

    _fields_ = [
        ('sa_family', ctypes.c_uint16),
        ('sa_data', ctypes.c_char * 14),
    ]


class sockaddr_in(ctypes.Structure):
    """ctypes layout of the leading fields of an IPv4 socket address.

    Only family, port and address are declared (8 bytes); ``sin_port`` and
    ``sin_addr`` hold the raw values as stored in the buffer, so callers
    convert them from network byte order themselves.
    """

    _fields_ = [
        ('sin_family', ctypes.c_uint16),
        ('sin_port', ctypes.c_uint16),
        ('sin_addr', ctypes.c_uint32),
    ]


# Buffer length passed to getsockopt() for the IPv4 lookup in get_orig_dst:
# the larger of the sockaddr_in and generic sockaddr struct sizes.
sockaddr_size = max(ctypes.sizeof(sockaddr_in), ctypes.sizeof(sockaddr))


class sockaddr_in6(ctypes.Structure):
    """ctypes layout of an IPv6 socket address (28 bytes).

    ``sin6_port`` holds the raw value as stored in the buffer (callers convert
    from network byte order). ``sin6_addr`` is the 16-byte packed address; it
    is bytes-like, so it can be passed to ``socket.inet_ntop`` or ``bytes()``.
    """

    _fields_ = [('sin6_family', ctypes.c_uint16),
                ('sin6_port', ctypes.c_uint16),
                ('sin6_flowinfo', ctypes.c_uint32),
                # Must be an unsigned-byte array, not c_char: ctypes reads a
                # c_char array field back as a NUL-terminated bytes value,
                # which would cut the address off at its first zero byte.
                ('sin6_addr', ctypes.c_ubyte * 16),
                ('sin6_scope_id', ctypes.c_uint32),
               ]


# Buffer length passed to getsockopt() for the IPv6 lookup in get_orig_dst.
sockaddr6_size = ctypes.sizeof(sockaddr_in6)


def get_orig_dst(sock):
    """Return the ``(address, port)`` stored in the socket's SO_ORIGINAL_DST option.

    The address family is derived from the socket's own local address
    (``sock.getsockname()``), and the option is then read at the matching
    level (``socket.SOL_IP`` for IPv4, ``constants.SOL_IPV6`` for IPv6).

    Args:
        sock: a connected socket object exposing ``getsockname()`` and
            ``getsockopt()``.

    Returns:
        Tuple ``(addr, port)`` where ``addr`` is the textual IP address and
        ``port`` is an int in host byte order.

    Raises:
        RuntimeError: if the local address is neither IPv4 nor IPv6.
        OSError: propagated from ``getsockopt()`` / ``getaddrinfo()``.
    """
    own_addr = sock.getsockname()[0]
    own_af = detect_af(own_addr)
    if own_af == socket.AF_INET:
        buf = sock.getsockopt(socket.SOL_IP, constants.SO_ORIGINAL_DST, sockaddr_size)
        sa = sockaddr_in.from_buffer_copy(buf)
        # The address bytes are already in network order, which is exactly
        # the packed form inet_ntop expects, so slice them from the buffer.
        field = sockaddr_in.sin_addr
        addr = socket.inet_ntop(socket.AF_INET,
                                buf[field.offset:field.offset + field.size])
        port = socket.ntohs(sa.sin_port)
        return addr, port
    elif own_af == socket.AF_INET6:
        buf = sock.getsockopt(constants.SOL_IPV6, constants.SO_ORIGINAL_DST, sockaddr6_size)
        sa = sockaddr_in6.from_buffer_copy(buf)
        # Take the 16 address bytes from the raw buffer rather than from the
        # struct attribute: a ctypes c_char array field is returned as a
        # NUL-terminated value and would lose everything after a zero byte.
        field = sockaddr_in6.sin6_addr
        addr = socket.inet_ntop(socket.AF_INET6,
                                buf[field.offset:field.offset + field.size])
        # The IPv6 struct names its port field sin6_port (not sin_port).
        port = socket.ntohs(sa.sin6_port)
        return addr, port
    else:
        raise RuntimeError("Unknown address family!")


class TransparentListener(BaseListener):  # pylint: disable=too-many-instance-attributes
    """TCP listener that relays each accepted connection to its original destination.

    For every accepted client the destination is read from the client socket
    via ``get_orig_dst``, an upstream stream is opened through a connection
    borrowed from ``pool``, and data is copied in both directions until the
    streams end. Usable as an async context manager (``start`` on enter,
    ``stop`` on exit).
    """

    def __init__(self, *,
                 listen_address,
                 listen_port,
                 pool,
                 timeout=4,
                 loop=None):
        """Store configuration; no sockets are opened until ``start()``.

        Args:
            listen_address: host/address handed to ``asyncio.start_server``.
            listen_port: port handed to ``asyncio.start_server``.
            pool: object whose ``borrow()`` returns an async context manager
                yielding a connection with an awaitable
                ``open_connection(addr, port)`` returning ``(reader, writer)``.
            timeout: seconds allowed for ``open_connection`` to complete.
            loop: event loop used to create handler tasks; defaults to
                ``asyncio.get_event_loop()``.
        """
        self._loop = loop if loop is not None else asyncio.get_event_loop()
        self._logger = logging.getLogger(self.__class__.__name__)
        self._listen_address = listen_address
        self._listen_port = listen_port
        # Handler tasks currently in flight; maintained by start()._spawn.
        self._children = set()
        # Set by start(); None means the listener has not been started.
        self._server = None
        self._pool = pool
        self._timeout = timeout

    async def stop(self):
        """Close the listening server and cancel all running client handlers.

        Safe to call when the listener was never started: the server shutdown
        is skipped and there are no handlers to cancel.
        """
        if self._server is not None:
            self._server.close()
            await self._server.wait_closed()
        while self._children:
            children = list(self._children)
            self._children.clear()
            self._logger.debug("Cancelling %d client handlers...",
                               len(children))
            for task in children:
                task.cancel()
            await asyncio.wait(children)
            # workaround for TCP server keeps spawning handlers for a while
            # after wait_closed() completed
            await asyncio.sleep(.5)

    async def _pump(self, writer, reader):
        """Copy data from ``reader`` to ``writer`` until ``reader`` returns no data.

        Reads at most ``BUFSIZE`` bytes at a time and drains the writer after
        every chunk so a slow consumer applies backpressure.
        """
        while True:
            data = await reader.read(BUFSIZE)
            if not data:
                break
            writer.write(data)
            await writer.drain()

    async def handler(self, reader, writer):
        """Serve one accepted client connection.

        Looks up the destination for the client socket, opens an upstream
        stream via a borrowed pool connection and pumps data both ways.
        Cancellation is re-raised; any other exception is logged and
        swallowed. Both writers are closed on exit in every case.
        """
        peer_addr = writer.transport.get_extra_info('peername')
        self._logger.info("Client %s connected", str(peer_addr))
        dst_writer = None
        try:
            # The client never names its target; the destination address is
            # taken from the accepted socket's options instead.
            sock = writer.transport.get_extra_info('socket')
            dst_addr, dst_port = get_orig_dst(sock)
            self._logger.info("Client %s requested connection to %s:%s",
                              peer_addr, dst_addr, dst_port)
            async with self._pool.borrow() as ssh_conn:
                dst_reader, dst_writer = await asyncio.wait_for(
                    ssh_conn.open_connection(dst_addr, dst_port),
                    self._timeout)
                t1 = asyncio.ensure_future(self._pump(writer, dst_reader))
                t2 = asyncio.ensure_future(self._pump(dst_writer, reader))
                try:
                    await asyncio.gather(t1, t2)
                finally:
                    # Make sure neither pump outlives the handler: cancel
                    # whichever is still running and wait until it is done.
                    for t in (t1, t2):
                        if not t.done():
                            t.cancel()
                            while not t.done():
                                try:
                                    await t
                                except asyncio.CancelledError:
                                    pass
        except asyncio.CancelledError:  # pylint: disable=try-except-raise
            raise
        except Exception as exc:  # pragma: no cover
            self._logger.exception("Connection handler stopped with exception:"
                                   " %s", str(exc))
        finally:
            self._logger.info("Client %s disconnected", str(peer_addr))
            if dst_writer is not None:
                dst_writer.close()
            writer.close()

    async def start(self):
        """Start listening and spawn a tracked handler task per connection."""
        def _spawn(reader, writer):
            # Called by the server for each accepted connection; runs the
            # handler as a task and tracks it so stop() can cancel it.
            def task_cb(task, fut):
                self._children.discard(task)
            task = self._loop.create_task(self.handler(reader, writer))
            self._children.add(task)
            task.add_done_callback(partial(task_cb, task))

        self._server = await asyncio.start_server(_spawn,
                                                  self._listen_address,
                                                  self._listen_port)
        self._logger.info("Transparent Proxy server listening on %s:%d",
                          self._listen_address, self._listen_port)

    async def __aenter__(self):
        """Start the listener and return it."""
        await self.start()
        return self

    async def __aexit__(self, exc_type, exc, tb):
        """Stop the listener; exceptions from the ``async with`` body propagate."""
        await self.stop()