import asyncio
import json
import types
from typing import List

import pytest
import websockets

from gql import Client, gql
from gql.transport.exceptions import (
    TransportAlreadyConnected,
    TransportClosed,
    TransportProtocolError,
    TransportQueryError,
)
from gql.transport.websockets import WebsocketsTransport

from .conftest import MS, WebSocketServer

# GraphQL query asking for a field ("bloh") that Continent does not have; the
# canned server answers in this file report exactly that error.
invalid_query_str = """
    query getContinents {
      continents {
        code
        bloh
      }
    }
"""

# Canned server reply to invalid_query_str: a "data" message whose payload only
# carries an "errors" list. It is a str.format template: literal JSON braces are
# doubled and `{query_id}` must be filled in, e.g. `.format(query_id=1)`.
invalid_query1_server_answer = (
    '{{"type":"data","id":"{query_id}",'
    '"payload":{{"errors":['
    '{{"message":"Cannot query field \\"bloh\\" on type \\"Continent\\".",'
    '"locations":[{{"line":4,"column":5}}],'
    '"extensions":{{"code":"INTERNAL_SERVER_ERROR"}}}}]}}}}'
)

# Script for the `server` fixture (used through indirect parametrization);
# presumably the fixture sends each entry in order, formatting in the query id
# - TODO confirm against conftest.
invalid_query1_server = [invalid_query1_server_answer]


@pytest.mark.asyncio
@pytest.mark.parametrize("server", [invalid_query1_server], indirect=True)
@pytest.mark.parametrize("query_str", [invalid_query_str])
async def test_websocket_invalid_query(event_loop, client_and_server, query_str):
    """GraphQL errors in a query answer surface as TransportQueryError.

    The canned server answer carries only an ``errors`` list, so executing the
    query must raise and expose those errors on the exception.
    """

    session, server = client_and_server

    query = gql(query_str)

    with pytest.raises(TransportQueryError) as exc_info:
        await session.execute(query)

    exception = exc_info.value

    # Check against the builtin ``list``: ``typing.List`` is a deprecated
    # alias intended for annotations, not for runtime isinstance checks.
    assert isinstance(exception.errors, list)
    # Fail with a clear message rather than an IndexError if no error came back.
    assert exception.errors, "expected at least one GraphQL error"

    error = exception.errors[0]

    assert error["extensions"]["code"] == "INTERNAL_SERVER_ERROR"


# Subscription counterpart of invalid_query_str: requests the same unknown
# field "bloh" on Continent.
invalid_subscription_str = """
    subscription getContinents {
      continents {
        code
        bloh
      }
    }
"""


async def server_invalid_subscription(ws, path):
    """Server handler that rejects the client's first operation with errors.

    After acking the connection it consumes one client message, replies with
    the canned error answer for query id 1, marks id 1 as complete and then
    waits until the connection is closed.
    """
    error_answer = invalid_query1_server_answer.format(query_id=1)

    await WebSocketServer.send_connection_ack(ws)

    # The content of the client's start message does not matter here.
    _ = await ws.recv()

    await ws.send(error_answer)
    await WebSocketServer.send_complete(ws, 1)
    await ws.wait_closed()


@pytest.mark.asyncio
@pytest.mark.parametrize("server", [server_invalid_subscription], indirect=True)
@pytest.mark.parametrize("query_str", [invalid_subscription_str])
async def test_websocket_invalid_subscription(event_loop, client_and_server, query_str):
    """GraphQL errors in a subscription answer surface as TransportQueryError.

    The server replies to the subscription with an ``errors``-only answer, so
    iterating the subscription must raise and expose those errors.
    """

    session, server = client_and_server

    query = gql(query_str)

    with pytest.raises(TransportQueryError) as exc_info:
        # Results are irrelevant; only the raised exception is inspected.
        async for _ in session.subscribe(query):
            pass

    exception = exc_info.value

    # Check against the builtin ``list``: ``typing.List`` is a deprecated
    # alias intended for annotations, not for runtime isinstance checks.
    assert isinstance(exception.errors, list)
    # Fail with a clear message rather than an IndexError if no error came back.
    assert exception.errors, "expected at least one GraphQL error"

    error = exception.errors[0]

    assert error["extensions"]["code"] == "INTERNAL_SERVER_ERROR"


# Server "connection_error" message; server_connection_error sends it after
# receiving the client's first message. This is a plain string, not a
# str.format template, so its braces are single.
connection_error_server_answer = (
    '{"type":"connection_error","id":null,'
    '"payload":{"message":"Unexpected token Q in JSON at position 0"}}'
)


async def server_no_ack(ws, path):
    """Server handler that never acknowledges the connection.

    Nothing is ever sent; the handler only waits for the socket to close, so a
    client waiting for ``connection_ack`` never receives one.
    """
    closed = ws.wait_closed()
    await closed


@pytest.mark.asyncio
@pytest.mark.parametrize("server", [server_no_ack], indirect=True)
@pytest.mark.parametrize("query_str", [invalid_query_str])
async def test_websocket_server_does_not_send_ack(event_loop, server, query_str):
    """Connecting times out when the server never sends connection_ack."""

    # A small ack_timeout keeps this test short, since no ack will ever come.
    transport = WebsocketsTransport(
        url=f"ws://{server.hostname}:{server.port}/graphql", ack_timeout=1
    )

    with pytest.raises(asyncio.TimeoutError):
        async with Client(transport=transport):
            pass


async def server_connection_error(ws, path):
    """Server handler: ack, log the first client message, answer connection_error."""
    await WebSocketServer.send_connection_ack(ws)

    received = await ws.recv()
    print(f"Server received: {received}")

    # Whatever the client sent, answer with a protocol-level connection error.
    await ws.send(connection_error_server_answer)
    await ws.wait_closed()


@pytest.mark.asyncio
@pytest.mark.parametrize("server", [server_connection_error], indirect=True)
@pytest.mark.parametrize("query_str", [invalid_query_str])
async def test_websocket_sending_invalid_data(event_loop, client_and_server, query_str):
    """Push a non-JSON frame straight through the raw websocket.

    The server replies with a ``connection_error`` message. This is a smoke
    test with no assertion: it passes as long as sending the frame and the
    short wait that follows complete without raising.
    """

    session, _server = client_and_server
    raw_websocket = session.transport.websocket

    garbage = "QSDF"
    print(f">>> {garbage}")
    await raw_websocket.send(garbage)

    # Presumably leaves the client time to receive the server's reply.
    await asyncio.sleep(2 * MS)


# Server "error" message for operation id "1"; server_invalid_payload sends it
# in reply to the client's deliberately malformed start message.
invalid_payload_server_answer = (
    '{"type":"error","id":"1","payload":{"message":"Must provide document"}}'
)


async def server_invalid_payload(ws, path):
    """Server handler answering the first client message with an error for id "1".

    After sending the error it waits for the client's connection termination
    (via ``WebSocketServer.wait_connection_terminate``) and then for the socket
    to close.
    """
    await WebSocketServer.send_connection_ack(ws)

    first_message = await ws.recv()
    print(f"Server received: {first_message}")

    await ws.send(invalid_payload_server_answer)

    # Shut down in order: the terminate message first, then the socket itself.
    await WebSocketServer.wait_connection_terminate(ws)
    await ws.wait_closed()


@pytest.mark.asyncio
@pytest.mark.parametrize("server", [server_invalid_payload], indirect=True)
@pytest.mark.parametrize("query_str", [invalid_query_str])
async def test_websocket_sending_invalid_payload(
    event_loop, client_and_server, query_str
):
    """A server error for a malformed start message raises TransportQueryError.

    The transport's ``_send_query`` is monkey patched so the client sends a
    start message whose payload is the string "BLAHBLAH" instead of a query.
    The server answers with an ``error`` message for id "1", which must surface
    as a TransportQueryError carrying that message.
    """

    session, server = client_and_server

    # Monkey patching the _send_query method to send an invalid payload

    async def monkey_patch_send_query(
        self, document, variable_values=None, operation_name=None,
    ) -> int:
        # Consume a query id from the transport counter, but ignore the
        # document and send a bogus payload instead.
        query_id = self.next_query_id
        self.next_query_id += 1

        # Named distinctly so it does not shadow the enclosing test's
        # ``query_str`` parameter.
        start_message = json.dumps(
            {"id": str(query_id), "type": "start", "payload": "BLAHBLAH"}
        )

        await self._send(start_message)
        return query_id

    # Bind the replacement to this transport instance only.
    session.transport._send_query = types.MethodType(
        monkey_patch_send_query, session.transport
    )

    query = gql(query_str)

    with pytest.raises(TransportQueryError) as exc_info:
        await session.execute(query)

    exception = exc_info.value

    # Check against the builtin ``list``: ``typing.List`` is a deprecated
    # alias intended for annotations, not for runtime isinstance checks.
    assert isinstance(exception.errors, list)
    # Fail with a clear message rather than an IndexError if no error came back.
    assert exception.errors, "expected at least one error"

    error = exception.errors[0]

    assert error["message"] == "Must provide document"


# Malformed server answers. Each list is one `server` fixture parameter for
# test_websocket_transport_protocol_errors, which expects every one of them to
# make the client raise TransportProtocolError.
not_json_answer = ["BLAHBLAH"]  # not JSON at all
missing_type_answer = ["{}"]  # JSON object without a "type" key
missing_id_answer_1 = ['{"type": "data"}']  # "data" without an "id"
missing_id_answer_2 = ['{"type": "error"}']  # "error" without an "id"
missing_id_answer_3 = ['{"type": "complete"}']  # "complete" without an "id"
data_without_payload = ['{"type": "data", "id":"1"}']
error_without_payload = ['{"type": "error", "id":"1"}']
payload_is_not_a_dict = ['{"type": "data", "id":"1", "payload": "BLAH"}']
empty_payload = ['{"type": "data", "id":"1", "payload": {}}']  # no data/errors keys
# bytes rather than str (presumably sent as a binary frame - confirm in conftest)
sending_bytes = [b"\x01\x02\x03"]


@pytest.mark.asyncio
@pytest.mark.parametrize(
    "server",
    [
        not_json_answer,
        missing_type_answer,
        missing_id_answer_1,
        missing_id_answer_2,
        missing_id_answer_3,
        data_without_payload,
        error_without_payload,
        payload_is_not_a_dict,
        empty_payload,
        sending_bytes,
    ],
    indirect=True,
)
async def test_websocket_transport_protocol_errors(event_loop, client_and_server):
    """Each malformed server answer makes execute() raise TransportProtocolError."""

    session, _ = client_and_server
    hello_query = gql("query { hello }")

    with pytest.raises(TransportProtocolError):
        await session.execute(hello_query)


async def server_without_ack(ws, path):
    """Server handler that opens the conversation with a keepalive, not an ack."""
    # The client expects connection_ack first; a keepalive in its place is the
    # protocol violation under test.
    await WebSocketServer.send_keepalive(ws)
    await ws.wait_closed()


@pytest.mark.asyncio
@pytest.mark.parametrize("server", [server_without_ack], indirect=True)
async def test_websocket_server_does_not_ack(event_loop, server):
    """Connecting fails with TransportProtocolError when the first message is not an ack."""

    url = f"ws://{server.hostname}:{server.port}/graphql"
    print(f"url = {url}")

    transport = WebsocketsTransport(url=url)

    with pytest.raises(TransportProtocolError):
        async with Client(transport=transport):
            pass


async def server_closing_directly(ws, path):
    """Server handler that hangs up immediately, without sending any message."""
    # No connection_ack is ever sent: the socket is closed straight away.
    await ws.close()
    return None


@pytest.mark.asyncio
@pytest.mark.parametrize("server", [server_closing_directly], indirect=True)
async def test_websocket_server_closing_directly(event_loop, server):
    """Connecting fails with ConnectionClosed when the server hangs up at once."""

    url = f"ws://{server.hostname}:{server.port}/graphql"
    print(f"url = {url}")

    transport = WebsocketsTransport(url=url)

    # The websockets library's ConnectionClosed escapes from the connection
    # attempt (presumably while the client is still waiting for the ack).
    with pytest.raises(websockets.exceptions.ConnectionClosed):
        async with Client(transport=transport):
            pass


async def server_closing_after_ack(ws, path):
    """Server handler that acknowledges the connection and then closes it."""
    await WebSocketServer.send_connection_ack(ws)
    # Close right after the ack, before any operation can be answered.
    await ws.close()


@pytest.mark.asyncio
@pytest.mark.parametrize("server", [server_closing_after_ack], indirect=True)
async def test_websocket_server_closing_after_ack(event_loop, client_and_server):
    """A server closing after the ack first yields ConnectionClosed, then TransportClosed."""

    session, _ = client_and_server
    hello_query = gql("query { hello }")

    # First request: the server closes the socket right after acking, so the
    # websockets library's closure exception reaches the caller.
    with pytest.raises(websockets.exceptions.ConnectionClosed):
        await session.execute(hello_query)

    # Once the transport has finished closing, later requests are refused
    # with gql's own TransportClosed error.
    await session.transport.wait_closed()

    with pytest.raises(TransportClosed):
        await session.execute(hello_query)


async def server_sending_invalid_query_errors(ws, path):
    """Server handler that sends an error for query id "404" right after acking.

    The tests using this handler never start an operation, so the error refers
    to a query the client does not know about.
    """
    stray_error = (
        '{"type":"error","id":"404","payload":'
        '{"message":"error for no good reason on non existing query"}}'
    )

    await WebSocketServer.send_connection_ack(ws)
    await ws.send(stray_error)
    await ws.wait_closed()


@pytest.mark.asyncio
@pytest.mark.parametrize("server", [server_sending_invalid_query_errors], indirect=True)
async def test_websocket_server_sending_invalid_query_errors(event_loop, server):
    """An error message for an unknown query id must not break the connection."""
    url = f"ws://{server.hostname}:{server.port}/graphql"
    print(f"url = {url}")

    transport = WebsocketsTransport(url=url)

    # Pass condition: staying connected for a short while without any
    # exception, i.e. the stray server message is ignored.
    async with Client(transport=transport):
        await asyncio.sleep(2 * MS)


@pytest.mark.asyncio
@pytest.mark.parametrize("server", [server_sending_invalid_query_errors], indirect=True)
async def test_websocket_non_regression_bug_105(event_loop, server):
    """Connecting the same client twice concurrently raises TransportAlreadyConnected."""

    # This test will check a fix to a race condition which happens if the user is trying
    # to connect using the same client twice at the same time
    # See bug #105

    url = f"ws://{server.hostname}:{server.port}/graphql"
    print(f"url = {url}")

    sample_transport = WebsocketsTransport(url=url)

    client = Client(transport=sample_transport)

    # Create a coroutine which start the connection with the transport but does nothing
    async def client_connect(client):
        async with client:
            await asyncio.sleep(2 * MS)

    # Create two tasks which will try to connect using the same client (not allowed)
    connect_task1 = asyncio.ensure_future(client_connect(client))
    connect_task2 = asyncio.ensure_future(client_connect(client))

    with pytest.raises(TransportAlreadyConnected):
        await asyncio.gather(connect_task1, connect_task2)

    # gather() propagates the first exception without cancelling or awaiting
    # the other task, so the task that did connect may still be running.
    # Drain both tasks so none is left pending when the server fixture is torn
    # down. Their outcomes are not re-checked; the assertion above is the test.
    await asyncio.gather(connect_task1, connect_task2, return_exceptions=True)