import sys
import enum
import socket
import datetime

import anyio
import async_exit_stack
from async_generator import async_generator, yield_, yield_from_
from purerpc.utils import is_darwin, is_windows
from purerpc.grpclib.exceptions import ProtocolError

from .grpclib.connection import GRPCConfiguration, GRPCConnection
from .grpclib.events import RequestReceived, RequestEnded, ResponseEnded, MessageReceived, WindowUpdated
from .grpclib.buffers import MessageWriteBuffer, MessageReadBuffer
from .grpclib.exceptions import StreamClosedError


class SocketWrapper(async_exit_stack.AsyncExitStack):
    def __init__(self, grpc_connection: GRPCConnection, sock: anyio.SocketStream):
        super().__init__()
        self._set_socket_options(sock)
        self._socket = sock
        self._grpc_connection = grpc_connection
        self._flush_event = anyio.create_event()
        self._running = True

    async def __aenter__(self):
        await super().__aenter__()
        task_group = await self.enter_async_context(anyio.create_task_group())
        await task_group.spawn(self._writer_thread)

        async def callback():
            self._running = False
            await self._flush_event.set()

        self.push_async_callback(callback)
        return self

    @staticmethod
    def _set_socket_options(sock: anyio.SocketStream):
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
        if hasattr(socket, "TCP_KEEPIDLE"):
            sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, 300)
        elif is_darwin():
            # Darwin specific option
            TCP_KEEPALIVE = 16
            sock.setsockopt(socket.IPPROTO_TCP, TCP_KEEPALIVE, 300)
        if not is_windows():
            sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, 30)
            sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPCNT, 5)
        sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)

    async def _writer_thread(self):
        while True:
            data = self._grpc_connection.data_to_send()
            if data:
                await self._socket.send_all(data)
            elif self._running:
                await self._flush_event.wait()
                self._flush_event.clear()
            else:
                return

    async def flush(self):
        """This maybe called from different threads."""
        await self._flush_event.set()

    async def recv(self, buffer_size: int):
        """This may only be called from single thread."""
        return await self._socket.receive_some(buffer_size)


class GRPCStreamState(enum.Enum):
    OPEN = 1
    HALF_CLOSED_REMOTE = 2
    HALF_CLOSED_LOCAL = 3
    CLOSED = 4


class GRPCStream:
    def __init__(self, grpc_connection: GRPCConnection, stream_id: int, socket: SocketWrapper,
                 grpc_socket: "GRPCSocket"):
        self._stream_id = stream_id
        self._grpc_connection = grpc_connection
        self._grpc_socket = grpc_socket
        self._socket = socket
        self._flow_control_update_event = anyio.create_event()
        self._incoming_events = anyio.create_queue(sys.maxsize)
        self._response_started = False
        self._state = GRPCStreamState.OPEN
        self._start_stream_event = None
        self._end_stream_event = None

    @property
    def state(self):
        return self._state

    @property
    def start_stream_event(self):
        return self._start_stream_event

    @property
    def end_stream_event(self):
        return self._end_stream_event

    @property
    def stream_id(self):
        return self._stream_id

    @property
    def client_side(self):
        return self._grpc_connection.config.client_side

    @property
    def debug_prefix(self):
        return "[CLIENT] " if self.client_side else "[SERVER] "

    def _close_remote(self):
        if self._state == GRPCStreamState.OPEN:
            self._state = GRPCStreamState.HALF_CLOSED_REMOTE
        elif self._state == GRPCStreamState.HALF_CLOSED_LOCAL:
            self._state = GRPCStreamState.CLOSED
            del self._grpc_socket._streams[self._stream_id]

    def _close_local(self):
        if self._state == GRPCStreamState.OPEN:
            self._state = GRPCStreamState.HALF_CLOSED_LOCAL
        elif self._state == GRPCStreamState.HALF_CLOSED_REMOTE:
            self._state = GRPCStreamState.CLOSED
            del self._grpc_socket._streams[self._stream_id]

    async def _set_flow_control_update(self):
        await self._flow_control_update_event.set()

    async def _wait_flow_control_update(self):
        await self._flow_control_update_event.wait()
        self._flow_control_update_event.clear()

    async def _send(self, message: bytes, compress=False):
        message_write_buffer = MessageWriteBuffer(self._grpc_connection.config.message_encoding,
                                                  self._grpc_connection.config.max_message_length)
        message_write_buffer.write_message(message, compress)
        while message_write_buffer:
            window_size = self._grpc_connection.flow_control_window(self._stream_id)
            if window_size <= 0:
                await self._wait_flow_control_update()
                continue
            num_data_to_send = min(window_size, len(message_write_buffer))
            data = message_write_buffer.data_to_send(num_data_to_send)
            self._grpc_connection.send_data(self._stream_id, data)
            await self._socket.flush()

    async def _receive(self):
        event = await self._incoming_events.get()
        if isinstance(event, MessageReceived):
            self._grpc_connection.acknowledge_received_data(self._stream_id,
                                                            event.flow_controlled_length)
            await self._socket.flush()
        elif isinstance(event, RequestEnded) or isinstance(event, ResponseEnded):
            assert self._end_stream_event is None
            self._end_stream_event = event
        else:
            assert self._start_stream_event is None
            self._start_stream_event = event
        return event

    async def close(self, status=None, content_type_suffix="", custom_metadata=()):
        if self.client_side and (status or custom_metadata):
            raise ValueError("Client side streams cannot be closed with non-default arguments")
        if self._state in (GRPCStreamState.HALF_CLOSED_LOCAL, GRPCStreamState.CLOSED):
            raise TypeError("Closing already closed stream")
        self._close_local()
        if self.client_side:
            try:
                self._grpc_connection.end_request(self._stream_id)
            except StreamClosedError:
                # Remote end already closed connection, do nothing here
                pass
        elif self._response_started:
            self._grpc_connection.end_response(self._stream_id, status, custom_metadata)
        else:
            self._grpc_connection.respond_status(self._stream_id, status,
                                                 content_type_suffix, custom_metadata)
        await self._socket.flush()

    async def start_response(self, content_type_suffix="", custom_metadata=()):
        if self.client_side:
            raise ValueError("Cannot start response on client-side socket")
        self._grpc_connection.start_response(self._stream_id, content_type_suffix, custom_metadata)
        self._response_started = True
        await self._socket.flush()


# TODO: this name is not correct, should be something like GRPCConnection (but this name is already
# occupied)
class GRPCSocket(async_exit_stack.AsyncExitStack):
    StreamClass = GRPCStream

    def __init__(self, config: GRPCConfiguration, sock,
                 receive_buffer_size=1024*1024):
        super().__init__()
        self._grpc_connection = GRPCConnection(config=config)
        self._socket = SocketWrapper(self._grpc_connection, sock)
        self._receive_buffer_size = receive_buffer_size
        self._streams = {}  # type: Dict[int, GRPCStream]

    async def __aenter__(self):
        await super().__aenter__()
        self._socket = await self.enter_async_context(self._socket)
        self._grpc_connection.initiate_connection()
        await self._socket.flush()
        if self.client_side:
            task_group = await self.enter_async_context(anyio.create_task_group())
            self.push_async_callback(task_group.cancel_scope.cancel)
            await task_group.spawn(self._reader_thread)
        return self

    @property
    def client_side(self):
        return self._grpc_connection.config.client_side

    def _stream_ctor(self, stream_id):
        return self.StreamClass(self._grpc_connection, stream_id, self._socket, self)

    def _allocate_stream(self, stream_id):
        self._streams[stream_id] = self._stream_ctor(stream_id)
        return self._streams[stream_id]

    @async_generator
    async def _listen(self):
        while True:
            data = await self._socket.recv(self._receive_buffer_size)
            if not data:
                return
            events = self._grpc_connection.receive_data(data)
            await self._socket.flush()
            for event in events:
                if isinstance(event, WindowUpdated):
                    if event.stream_id == 0:
                        for stream in self._streams.values():
                            await stream._set_flow_control_update()
                    elif event.stream_id in self._streams:
                        await self._streams[event.stream_id]._set_flow_control_update()
                    continue
                elif isinstance(event, RequestReceived):
                    self._allocate_stream(event.stream_id)

                await self._streams[event.stream_id]._incoming_events.put(event)

                if isinstance(event, RequestReceived):
                    await yield_(self._streams[event.stream_id])
                elif isinstance(event, ResponseEnded) or isinstance(event, RequestEnded):
                    self._streams[event.stream_id]._close_remote()

    async def _reader_thread(self):
        async for _ in self._listen():
            raise ProtocolError("Received request on client end")

    @async_generator
    async def listen(self):
        if self.client_side:
            raise ValueError("Cannot listen client-side socket")
        await yield_from_(self._listen())

    async def start_request(self, scheme: str, service_name: str, method_name: str,
                            message_type=None, authority=None, timeout: datetime.timedelta=None,
                            content_type_suffix="", custom_metadata=()):
        if not self.client_side:
            raise ValueError("Cannot start request on server-side socket")
        stream_id = self._grpc_connection.get_next_available_stream_id()
        stream = self._allocate_stream(stream_id)
        self._grpc_connection.start_request(stream_id, scheme, service_name, method_name,
                                            message_type, authority, timeout,
                                            content_type_suffix, custom_metadata)
        await self._socket.flush()
        return stream