import asyncio
import json
import logging
import os
import pathlib
import ssl
import types

import pytest
import websockets
from aiohttp.test_utils import TestServer as AIOHTTPTestServer
from websockets.exceptions import ConnectionClosed

from gql import Client
from gql.transport.websockets import WebsocketsTransport


def pytest_addoption(parser):
    parser.addoption(
        "--run-online",
        action="store_true",
        default=False,
        help="run tests necessitating online resources",
    )


def pytest_configure(config):
    config.addinivalue_line(
        "markers", "online: mark test as necessitating external online resources"
    )


def pytest_collection_modifyitems(config, items):
    if config.getoption("--run-online"):
        # --run-online given in cli: do not skip online tests
        return
    skip_online = pytest.mark.skip(reason="need --run-online option to run")
    for item in items:
        if "online" in item.keywords:
            item.add_marker(skip_online)


@pytest.fixture
async def aiohttp_server():
    """Factory to create a TestServer instance, given an app.

    aiohttp_server(app, **kwargs)
    """
    servers = []

    async def go(app, *, port=None, **kwargs):  # type: ignore
        server = AIOHTTPTestServer(app, port=port)
        await server.start_server(**kwargs)
        servers.append(server)
        return server

    yield go

    while servers:
        await servers.pop().close()


# Adding debug logs to websocket tests
for name in ["websockets.server", "gql.transport.websockets"]:
    logger = logging.getLogger(name)
    logger.setLevel(logging.DEBUG)

    if len(logger.handlers) < 1:
        logger.addHandler(logging.StreamHandler())

# Unit for timeouts. May be increased on slow machines by setting the
# GQL_TESTS_TIMEOUT_FACTOR environment variable.
# Copied from websockets source
MS = 0.001 * int(os.environ.get("GQL_TESTS_TIMEOUT_FACTOR", 1))


class WebSocketServer:
    """Websocket server on localhost on a free port.

    This server allows us to test our client by simulating different correct and
    incorrect server responses.
    """

    def __init__(self, with_ssl: bool = False):
        self.with_ssl = with_ssl

    async def start(self, handler):

        print("Starting server")

        extra_serve_args = {}

        if self.with_ssl:
            # This is a copy of certificate from websockets tests folder
            #
            # Generate TLS certificate with:
            # $ openssl req -x509 -config test_localhost.cnf \
            #       -days 15340 -newkey rsa:2048 \
            #       -out test_localhost.crt -keyout test_localhost.key
            # $ cat test_localhost.key test_localhost.crt > test_localhost.pem
            # $ rm test_localhost.key test_localhost.crt
            self.testcert = bytes(
                pathlib.Path(__file__).with_name("test_localhost.pem")
            )
            ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
            ssl_context.load_cert_chain(self.testcert)

            extra_serve_args["ssl"] = ssl_context

        # Start a server with a random open port
        self.start_server = websockets.server.serve(
            handler, "127.0.0.1", 0, **extra_serve_args
        )

        # Wait that the server is started
        self.server = await self.start_server

        # Get hostname and port
        hostname, port = self.server.sockets[0].getsockname()[:2]
        assert hostname == "127.0.0.1"

        self.hostname = hostname
        self.port = port

        print(f"Server started on port {port}")

    async def stop(self):
        print("Stopping server")

        self.server.close()
        try:
            await asyncio.wait_for(self.server.wait_closed(), timeout=1)
        except asyncio.TimeoutError:  # pragma: no cover
            assert False, "Server failed to stop"

        print("Server stopped\n\n\n")

    @staticmethod
    async def send_complete(ws, query_id):
        await ws.send(f'{{"type":"complete","id":"{query_id}","payload":null}}')

    @staticmethod
    async def send_keepalive(ws):
        await ws.send('{"type":"ka"}')

    @staticmethod
    async def send_connection_ack(ws):

        # Line return for easy debugging
        print("")

        # Wait for init
        result = await ws.recv()
        json_result = json.loads(result)
        assert json_result["type"] == "connection_init"

        # Send ack
        await ws.send('{"type":"connection_ack"}')

    @staticmethod
    async def wait_connection_terminate(ws):
        result = await ws.recv()
        json_result = json.loads(result)
        assert json_result["type"] == "connection_terminate"


def get_server_handler(request):
    """Get the server handler.

    Either get it from test or use the default server handler
    if the test provides only an array of answers.
    """

    if isinstance(request.param, types.FunctionType):
        server_handler = request.param

    else:
        answers = request.param

        async def default_server_handler(ws, path):

            try:
                await WebSocketServer.send_connection_ack(ws)
                query_id = 1

                for answer in answers:
                    result = await ws.recv()
                    print(f"Server received: {result}")

                    if isinstance(answer, str) and "{query_id}" in answer:
                        answer_format_params = {"query_id": query_id}
                        formatted_answer = answer.format(**answer_format_params)
                    else:
                        formatted_answer = answer

                    await ws.send(formatted_answer)
                    await WebSocketServer.send_complete(ws, query_id)
                    query_id += 1

                await WebSocketServer.wait_connection_terminate(ws)
                await ws.wait_closed()
            except ConnectionClosed:
                pass

        server_handler = default_server_handler

    return server_handler


@pytest.fixture
async def ws_ssl_server(request):
    """Websockets server fixture using SSL.

    It can take as argument either a handler function for the websocket server for
    complete control OR an array of answers to be sent by the default server handler.
    """

    server_handler = get_server_handler(request)

    try:
        test_server = WebSocketServer(with_ssl=True)

        # Starting the server with the fixture param as the handler function
        await test_server.start(server_handler)

        yield test_server
    except Exception as e:
        print("Exception received in ws server fixture:", e)
    finally:
        await test_server.stop()


@pytest.fixture
async def server(request):
    """Fixture used to start a dummy server to test the client behaviour.

    It can take as argument either a handler function for the websocket server for
    complete control OR an array of answers to be sent by the default server handler.
    """

    server_handler = get_server_handler(request)

    try:
        test_server = WebSocketServer()

        # Starting the server with the fixture param as the handler function
        await test_server.start(server_handler)

        yield test_server
    except Exception as e:
        print("Exception received in server fixture:", e)
    finally:
        await test_server.stop()


@pytest.fixture
async def client_and_server(server):
    """Helper fixture to start a server and a client connected to its port."""

    # Generate transport to connect to the server fixture
    path = "/graphql"
    url = f"ws://{server.hostname}:{server.port}{path}"
    sample_transport = WebsocketsTransport(url=url)

    async with Client(transport=sample_transport) as session:

        # Yield both client session and server
        yield session, server