import asyncio
import socket
import struct
from asyncio import sslproto

from . import constants as c
from .helpers import (
    Socks4Addr, Socks5Addr, Socks5Auth, Socks4Auth
)
from .errors import (
    SocksError, NoAcceptableAuthMethods, LoginAuthenticationFailed,
    InvalidServerReply, InvalidServerVersion
)


DEFAULT_LIMIT = getattr(asyncio.streams, '_DEFAULT_LIMIT', 2**16)


class BaseSocksProtocol(asyncio.StreamReaderProtocol):
    def __init__(self, proxy, proxy_auth, dst, app_protocol_factory, waiter, *,
                 remote_resolve=True, loop=None, ssl=False,
                 server_hostname=None, negotiate_done_cb=None,
                 reader_limit=DEFAULT_LIMIT):
        if not isinstance(dst, (tuple, list)) or len(dst) != 2:
            raise ValueError(
                'Invalid dst format, tuple("dst_host", dst_port))'
            )

        self._proxy = proxy
        self._auth = proxy_auth
        self._dst_host, self._dst_port = dst
        self._remote_resolve = remote_resolve
        self._waiter = waiter
        self._ssl = ssl
        self._server_hostname = server_hostname
        self._negotiate_done_cb = negotiate_done_cb
        self._loop = loop or asyncio.get_event_loop()

        self._transport = None
        self._negotiate_done = False
        self._proxy_peername = None
        self._proxy_sockname = None

        if app_protocol_factory:
            self._app_protocol = app_protocol_factory()
        else:
            self._app_protocol = self

        reader = asyncio.StreamReader(loop=self._loop, limit=reader_limit)

        super().__init__(stream_reader=reader,
                         client_connected_cb=self.negotiate, loop=self._loop)

    async def negotiate(self, reader, writer):
        try:
            req = self.socks_request(c.SOCKS_CMD_CONNECT)
            self._proxy_peername, self._proxy_sockname = await req
        except SocksError as exc:
            exc = SocksError('Can not connect to %s:%s. %s' %
                             (self._dst_host, self._dst_port, exc))
            if not self._waiter.cancelled():
                self._loop.call_soon(self._waiter.set_exception, exc)
        except Exception as exc:
            if not self._waiter.cancelled():
                self._loop.call_soon(self._waiter.set_exception, exc)
        else:
            self._negotiate_done = True

            if self._ssl:
                # Creating a ssl transport needs to be reworked.
                # See details: http://bugs.python.org/issue23749
                self._tls_protocol = sslproto.SSLProtocol(
                    app_protocol=self, sslcontext=self._ssl, server_side=False,
                    server_hostname=self._server_hostname, waiter=self._waiter,
                    loop=self._loop, call_connection_made=False)

                # starttls
                original_transport = self._transport
                self._transport.set_protocol(self._tls_protocol)
                self._transport = self._tls_protocol._app_transport

                self._tls_protocol.connection_made(original_transport)

                self._loop.call_soon(self._app_protocol.connection_made,
                                     self._transport)
            else:
                self._loop.call_soon(self._app_protocol.connection_made,
                                     self._transport)
                self._loop.call_soon(self._waiter.set_result, True)

            if self._negotiate_done_cb is not None:
                res = self._negotiate_done_cb(reader, writer)

                if asyncio.iscoroutine(res):
                    self._loop.create_task(res)
                    return res

    def connection_made(self, transport):
        # connection_made is called
        if self._transport:
            return

        super().connection_made(transport)
        self._transport = transport

    def connection_lost(self, exc):
        if self._negotiate_done and self._app_protocol is not self:
            self._loop.call_soon(self._app_protocol.connection_lost, exc)
        super().connection_lost(exc)

    def pause_writing(self):
        if self._negotiate_done and self._app_protocol is not self:
            self._app_protocol.pause_writing()
        else:
            super().pause_writing()

    def resume_writing(self):
        if self._negotiate_done and self._app_protocol is not self:
            self._app_protocol.resume_writing()
        else:
            super().resume_writing()

    def data_received(self, data):
        if self._negotiate_done and self._app_protocol is not self:
            self._app_protocol.data_received(data)
        else:
            super().data_received(data)

    def eof_received(self):
        if self._negotiate_done and self._app_protocol is not self:
            self._app_protocol.eof_received()
        super().eof_received()

    async def socks_request(self, cmd):
        raise NotImplementedError

    def write_request(self, request):
        bdata = bytearray()

        for item in request:
            if isinstance(item, int):
                bdata.append(item)
            elif isinstance(item, (bytearray, bytes)):
                bdata += item
            else:
                raise ValueError('Unsupported item')
        self._stream_writer.write(bdata)

    async def read_response(self, n):
        try:
            return (await self._stream_reader.readexactly(n))
        except asyncio.IncompleteReadError as e:
            raise InvalidServerReply(
                'Server sent fewer bytes than required (%s)' % str(e))

    async def _get_dst_addr(self):
        infos = await self._loop.getaddrinfo(
            self._dst_host, self._dst_port, family=socket.AF_UNSPEC,
            type=socket.SOCK_STREAM, proto=socket.IPPROTO_TCP,
            flags=socket.AI_ADDRCONFIG)
        if not infos:
            raise OSError('getaddrinfo() returned empty list')
        return infos[0][0], infos[0][4][0]

    @property
    def app_protocol(self):
        return self._app_protocol

    @property
    def app_transport(self):
        return self._transport

    @property
    def proxy_sockname(self):
        """
        Returns the bound IP address and port number at the proxy.
        """
        return self._proxy_sockname

    @property
    def proxy_peername(self):
        """
        Returns the IP and port number of the proxy.
        """
        sock = self._transport.get_extra_info('socket')
        return sock.peername if sock else None

    @property
    def peername(self):
        """
        Returns the IP address and port number of the destination
        machine (note: get_proxy_peername returns the proxy)
        """
        return self._proxy_peername

    @property
    def reader(self):
        return self._stream_reader

    @property
    def writer(self):
        return self._stream_writer


class Socks4Protocol(BaseSocksProtocol):
    def __init__(self, proxy, proxy_auth, dst, app_protocol_factory, waiter,
                 remote_resolve=True, loop=None, ssl=False,
                 server_hostname=None, negotiate_done_cb=None,
                 reader_limit=DEFAULT_LIMIT):
        proxy_auth = proxy_auth or Socks4Auth('')

        if not isinstance(proxy, Socks4Addr):
            raise ValueError('Invalid proxy format')

        if not isinstance(proxy_auth, Socks4Auth):
            raise ValueError('Invalid proxy_auth format')

        super().__init__(proxy, proxy_auth, dst, app_protocol_factory,
                         waiter, remote_resolve=remote_resolve, loop=loop,
                         ssl=ssl, server_hostname=server_hostname,
                         reader_limit=reader_limit,
                         negotiate_done_cb=negotiate_done_cb)

    async def socks_request(self, cmd):
        # prepare destination addr/port
        host, port = self._dst_host, self._dst_port
        port_bytes = struct.pack(b'>H', port)
        include_hostname = False

        try:
            host_bytes = socket.inet_aton(host)
        except socket.error:
            if self._remote_resolve:
                host_bytes = bytes([c.NULL, c.NULL, c.NULL, 0x01])
                include_hostname = True
            else:
                # it's not an IP number, so it's probably a DNS name.
                family, host = await self._get_dst_addr()
                host_bytes = socket.inet_aton(host)

        # build and send connect command
        req = [c.SOCKS_VER4, cmd, port_bytes,
               host_bytes, self._auth.login, c.NULL]
        if include_hostname:
            req += [self._dst_host.encode('idna'), c.NULL]

        self.write_request(req)

        # read/process result
        resp = await self.read_response(8)

        if resp[0] != c.NULL:
            raise InvalidServerReply('SOCKS4 proxy server sent invalid data')
        if resp[1] != c.SOCKS4_GRANTED:
            error = c.SOCKS4_ERRORS.get(resp[1], 'Unknown error')
            raise SocksError('[Errno {0:#04x}]: {1}'.format(resp[1], error))

        binded = socket.inet_ntoa(resp[4:]), struct.unpack('>H', resp[2:4])[0]
        return (host, port), binded


class Socks5Protocol(BaseSocksProtocol):
    def __init__(self, proxy, proxy_auth, dst, app_protocol_factory, waiter,
                 remote_resolve=True, loop=None, ssl=False,
                 server_hostname=None, negotiate_done_cb=None,
                 reader_limit=DEFAULT_LIMIT):
        proxy_auth = proxy_auth or Socks5Auth('', '')

        if not isinstance(proxy, Socks5Addr):
            raise ValueError('Invalid proxy format')

        if not isinstance(proxy_auth, Socks5Auth):
            raise ValueError('Invalid proxy_auth format')

        super().__init__(proxy, proxy_auth, dst, app_protocol_factory,
                         waiter, remote_resolve=remote_resolve, loop=loop,
                         ssl=ssl, server_hostname=server_hostname,
                         reader_limit=reader_limit,
                         negotiate_done_cb=negotiate_done_cb)

    async def socks_request(self, cmd):
        await self.authenticate()

        # build and send command
        dst_addr, resolved = await self.build_dst_address(
            self._dst_host, self._dst_port)
        self.write_request([c.SOCKS_VER5, cmd, c.RSV] + dst_addr)

        # read/process command response
        resp = await self.read_response(3)

        if resp[0] != c.SOCKS_VER5:
            raise InvalidServerVersion(
                'SOCKS5 proxy server sent invalid version'
            )
        if resp[1] != c.SOCKS5_GRANTED:
            error = c.SOCKS5_ERRORS.get(resp[1], 'Unknown error')
            raise SocksError('[Errno {0:#04x}]: {1}'.format(resp[1], error))

        binded = await self.read_address()

        return resolved, binded

    async def authenticate(self):
        # send available auth methods
        if self._auth.login and self._auth.password:
            req = [c.SOCKS_VER5, 0x02,
                   c.SOCKS5_AUTH_ANONYMOUS, c.SOCKS5_AUTH_UNAME_PWD]
        else:
            req = [c.SOCKS_VER5, 0x01, c.SOCKS5_AUTH_ANONYMOUS]

        self.write_request(req)

        # read/process response and send auth data if necessary
        chosen_auth = await self.read_response(2)

        if chosen_auth[0] != c.SOCKS_VER5:
            raise InvalidServerVersion(
                'SOCKS5 proxy server sent invalid version'
            )

        if chosen_auth[1] == c.SOCKS5_AUTH_UNAME_PWD:
            req = [0x01, chr(len(self._auth.login)).encode(), self._auth.login,
                   chr(len(self._auth.password)).encode(), self._auth.password]
            self.write_request(req)

            auth_status = await self.read_response(2)
            if auth_status[0] != 0x01:
                raise InvalidServerReply(
                    'SOCKS5 proxy server sent invalid data'
                )
            if auth_status[1] != c.SOCKS5_GRANTED:
                raise LoginAuthenticationFailed(
                    "SOCKS5 authentication failed"
                )
        # offered auth methods rejected
        elif chosen_auth[1] != c.SOCKS5_AUTH_ANONYMOUS:
            if chosen_auth[1] == c.SOCKS5_AUTH_NO_ACCEPTABLE_METHODS:
                raise NoAcceptableAuthMethods(
                    'All offered SOCKS5 authentication methods were rejected'
                )
            else:
                raise InvalidServerReply(
                    'SOCKS5 proxy server sent invalid data'
                )

    async def build_dst_address(self, host, port):
        family_to_byte = {socket.AF_INET: c.SOCKS5_ATYP_IPv4,
                          socket.AF_INET6: c.SOCKS5_ATYP_IPv6}
        port_bytes = struct.pack('>H', port)

        # if the given destination address is an IP address, we will
        # use the IP address request even if remote resolving was specified.
        for family in (socket.AF_INET, socket.AF_INET6):
            try:
                host_bytes = socket.inet_pton(family, host)
                req = [family_to_byte[family], host_bytes, port_bytes]
                return req, (host, port)
            except socket.error:
                pass

        # it's not an IP number, so it's probably a DNS name.
        if self._remote_resolve:
            host_bytes = host.encode('idna')
            req = [c.SOCKS5_ATYP_DOMAIN, chr(len(host_bytes)).encode(),
                   host_bytes, port_bytes]
        else:
            family, host_bytes = await self._get_dst_addr()
            host_bytes = socket.inet_pton(family, host_bytes)
            req = [family_to_byte[family], host_bytes, port_bytes]
            host = socket.inet_ntop(family, host_bytes)

        return req, (host, port)

    async def read_address(self):
        atype = await self.read_response(1)

        if atype[0] == c.SOCKS5_ATYP_IPv4:
            addr = socket.inet_ntoa((await self.read_response(4)))
        elif atype[0] == c.SOCKS5_ATYP_DOMAIN:
            length = await self.read_response(1)
            addr = await self.read_response(ord(length))
        elif atype[0] == c.SOCKS5_ATYP_IPv6:
            addr = await self.read_response(16)
            addr = socket.inet_ntop(socket.AF_INET6, addr)
        else:
            raise InvalidServerReply('SOCKS5 proxy server sent invalid data')

        port = await self.read_response(2)
        port = struct.unpack('>H', port)[0]

        return addr, port