# pylint: disable=too-many-instance-attributes
import base64
import datetime
import http
from collections import Counter
from contextlib import contextmanager
from copy import deepcopy
from enum import IntEnum
from logging import LogRecord
from typing import TYPE_CHECKING, Any, Callable, Dict, Generator, Iterator, List, Optional, Sequence, Tuple, Union, cast
from urllib.parse import urljoin, urlsplit, urlunsplit

import attr
import requests
import werkzeug
from hypothesis.strategies import SearchStrategy
from starlette.testclient import TestClient as ASGIClient

from .checks import ALL_CHECKS
from .exceptions import InvalidSchema
from .types import Body, Cookies, FormData, Headers, PathParameters, Query
from .utils import GenericResponse, WSGIResponse

if TYPE_CHECKING:
    from .hooks import HookDispatcher
    from .schemas import BaseSchema
    from .stateful import StatefulTest


@attr.s(slots=True)  # pragma: no mutate
class Case:
    """A single test case parameters.

    Bundles the concrete values for one call to ``endpoint`` (path parameters, headers, cookies, query,
    body / form data) and converts them into arguments for ``requests``, ``werkzeug`` or an ASGI test client.
    """

    endpoint: "Endpoint" = attr.ib()  # pragma: no mutate
    path_parameters: Optional[PathParameters] = attr.ib(default=None)  # pragma: no mutate
    headers: Optional[Headers] = attr.ib(default=None)  # pragma: no mutate
    cookies: Optional[Cookies] = attr.ib(default=None)  # pragma: no mutate
    query: Optional[Query] = attr.ib(default=None)  # pragma: no mutate
    body: Optional[Body] = attr.ib(default=None)  # pragma: no mutate
    form_data: Optional[FormData] = attr.ib(default=None)  # pragma: no mutate

    @property
    def path(self) -> str:
        """Path template of the underlying endpoint."""
        return self.endpoint.path

    @property
    def full_path(self) -> str:
        """Full path of the underlying endpoint, as reported by ``endpoint.full_path``."""
        return self.endpoint.full_path

    @property
    def method(self) -> str:
        """HTTP method of the underlying endpoint."""
        return self.endpoint.method

    @property
    def base_url(self) -> Optional[str]:
        """Base URL of the underlying endpoint, if any."""
        return self.endpoint.base_url

    @property
    def app(self) -> Any:
        """Application object attached to the underlying endpoint, if any."""
        return self.endpoint.app

    @property
    def formatted_path(self) -> str:
        """The endpoint path with ``path_parameters`` substituted into its ``{placeholders}``.

        :raises InvalidSchema: if the path template references a parameter that is absent from
            ``path_parameters``.
        """
        # pylint: disable=not-a-mapping
        try:
            return self.path.format(**self.path_parameters or {})
        except KeyError as exc:
            # Chain the original error so the name of the missing placeholder is not lost
            raise InvalidSchema("Missing required property `required: true`") from exc

    def get_full_base_url(self) -> Optional[str]:
        """Create a full base url, adding "localhost" for WSGI apps."""
        parts = urlsplit(self.base_url)
        if not parts.hostname:
            # No host in the base URL - keep only its path and put it behind `http://localhost`
            path = cast(str, parts.path or "")
            return urlunsplit(("http", "localhost", path, "", ""))
        return self.base_url

    def get_code_to_reproduce(self, headers: Optional[Dict[str, Any]] = None) -> str:
        """Construct a Python code to reproduce this case with `requests`.

        :param headers: extra headers to show in the generated call; they take precedence over the case's
            own headers. The case itself is not modified.
        :return: a string like ``requests.get('http://localhost/users', params={...})``.
        """
        base_url = self.get_full_base_url()
        kwargs = self.as_requests_kwargs(base_url)
        if headers:
            # `kwargs["headers"]` is the very same object as `self.headers`, so build a new dictionary
            # instead of updating it in place - otherwise generating the code would change the case.
            kwargs["headers"] = {**(kwargs["headers"] or {}), **headers}
        method = kwargs["method"].lower()

        def are_defaults(key: str, value: Optional[Dict]) -> bool:
            # Arguments that are equal to their default (`None`) are not worth printing
            default_value: Optional[Dict] = {"json": None}.get(key, None)
            return value == default_value

        printed_kwargs = ", ".join(
            f"{key}={repr(value)}"
            for key, value in kwargs.items()
            if key not in ("method", "url") and not are_defaults(key, value)
        )
        args_repr = f"'{kwargs['url']}'"
        if printed_kwargs:
            args_repr += f", {printed_kwargs}"
        return f"requests.{method}({args_repr})"

    def _get_base_url(self, base_url: Optional[str] = None) -> str:
        """Return the explicit ``base_url`` or fall back to the endpoint's one.

        :raises ValueError: if neither of them is set.
        """
        if base_url is None:
            if self.base_url is not None:
                base_url = self.base_url
            else:
                raise ValueError(
                    "Base URL is required as `base_url` argument in `call` or should be specified "
                    "in the schema constructor as a part of Schema URL."
                )
        return base_url

    def as_requests_kwargs(self, base_url: Optional[str] = None) -> Dict[str, Any]:
        """Convert the case into a dictionary acceptable by requests.

        :raises ValueError: if no base URL is available (see ``_get_base_url``).
        """
        base_url = self._get_base_url(base_url)
        # Strip the leading slash, so `urljoin` appends the path to the base URL instead of replacing its path
        formatted_path = self.formatted_path.lstrip("/")  # pragma: no mutate
        url = urljoin(base_url + "/", formatted_path)
        # Form data and body are mutually exclusive
        extra: Dict[str, Optional[Body]]
        if self.form_data:
            extra = {"files": self.form_data}
        elif is_multipart(self.body):
            extra = {"data": self.body}
        else:
            extra = {"json": self.body}
        return {
            "method": self.method,
            "url": url,
            "cookies": self.cookies,
            "headers": self.headers,
            "params": self.query,
            **extra,
        }

    def call(
        self,
        base_url: Optional[str] = None,
        session: Optional[requests.Session] = None,
        headers: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> requests.Response:
        """Make a network call with `requests`.

        :param base_url: overrides the endpoint's base URL.
        :param session: session to use; when omitted a temporary one is created and closed afterwards.
        :param headers: extra headers that take precedence over the case's own headers.
        :param kwargs: passed to ``session.request`` and override the generated arguments.
        :raises ValueError: if no base URL is available.
        """
        close_session = session is None
        if session is None:
            session = requests.Session()
        try:
            data = self.as_requests_kwargs(base_url)
            if headers is not None:
                data["headers"] = {**(data["headers"] or {}), **headers}
            data.update(kwargs)
            return session.request(**data)  # type: ignore
        finally:
            # Close only the session created here, and do it even if the request fails
            if close_session:
                session.close()

    def as_werkzeug_kwargs(self, headers: Optional[Dict[str, str]] = None) -> Dict[str, Any]:
        """Convert the case into a dictionary acceptable by werkzeug.Client.

        :param headers: extra headers that take precedence over the case's own headers.
        """
        # Work on a copy, so the case's own headers stay untouched
        final_headers = self.headers.copy() if self.headers is not None else {}
        if headers:
            final_headers.update(headers)
        extra: Dict[str, Optional[Body]]
        if self.form_data:
            extra = {"data": self.form_data}
            final_headers.setdefault("Content-Type", "multipart/form-data")
        elif is_multipart(self.body):
            extra = {"data": self.body}
        else:
            extra = {"json": self.body}
        return {
            "method": self.method,
            "path": self.endpoint.schema.get_full_path(self.formatted_path),
            "headers": final_headers,
            "query_string": self.query,
            **extra,
        }

    def call_wsgi(self, app: Any = None, headers: Optional[Dict[str, str]] = None, **kwargs: Any) -> WSGIResponse:
        """Make a call to a WSGI application via ``werkzeug.Client``.

        :param app: application to call; defaults to the one attached to the endpoint.
        :param headers: extra headers that take precedence over the case's own headers.
        :param kwargs: passed to ``client.open``.
        :raises RuntimeError: if no application is available.
        """
        application = app or self.app
        if application is None:
            raise RuntimeError(
                "WSGI application instance is required. "
                "Please, set `app` argument in the schema constructor or pass it to `call_wsgi`"
            )
        data = self.as_werkzeug_kwargs(headers)
        client = werkzeug.Client(application, WSGIResponse)
        with cookie_handler(client, self.cookies):
            return client.open(**data, **kwargs)

    def call_asgi(
        self,
        app: Any = None,
        base_url: Optional[str] = "http://testserver",
        headers: Optional[Dict[str, str]] = None,
        **kwargs: Any,
    ) -> requests.Response:
        """Make a call to an ASGI application via the Starlette test client.

        :param app: application to call; defaults to the one attached to the endpoint.
        :param base_url: base URL passed to ``call``.
        :param headers: extra headers that take precedence over the case's own headers.
        :param kwargs: passed to ``call``.
        :raises RuntimeError: if no application is available.
        """
        application = app or self.app
        if application is None:
            raise RuntimeError(
                "ASGI application instance is required. "
                "Please, set `app` argument in the schema constructor or pass it to `call_asgi`"
            )
        client = ASGIClient(application)

        return self.call(base_url=base_url, session=client, headers=headers, **kwargs)

    def validate_response(
        self,
        response: Union[requests.Response, WSGIResponse],
        checks: Tuple[Callable[[Union[requests.Response, WSGIResponse], "Case"], None], ...] = ALL_CHECKS,
    ) -> None:
        """Run all ``checks`` against ``response`` and report every failure at once.

        :raises AssertionError: if at least one check fails; its arguments are the messages of all failed checks.
        """
        errors = []
        for check in checks:
            try:
                check(response, self)
            except AssertionError as exc:
                # A bare `assert` gives an exception without arguments - do not fail with `IndexError` on it
                errors.append(exc.args[0] if exc.args else "")
        if errors:
            raise AssertionError(*errors)

    def get_full_url(self) -> str:
        """Make a full URL to the current endpoint, including query parameters."""
        base_url = self.base_url or "http://localhost"
        kwargs = self.as_requests_kwargs(base_url)
        request = requests.Request(**kwargs)
        prepared = requests.Session().prepare_request(request)  # type: ignore
        return prepared.url

    def partial_deepcopy(self) -> "Case":
        """Copy the case: parameter containers are deep-copied, the endpoint is copied via its own method."""
        return self.__class__(
            endpoint=self.endpoint.partial_deepcopy(),
            path_parameters=deepcopy(self.path_parameters),
            headers=deepcopy(self.headers),
            cookies=deepcopy(self.cookies),
            query=deepcopy(self.query),
            body=deepcopy(self.body),
            form_data=deepcopy(self.form_data),
        )


def is_multipart(item: Optional[Body]) -> bool:
    """Guess whether the given body should be sent as a multipart request.

    The heuristic is deliberately simple: the body is considered multipart if there is a ``bytes`` value anywhere
    inside it. Only dictionaries (their values) and lists are traversed. Bytes may appear only if there was
    `format: binary`, which usually is in multipart payloads.
    Probably a better way would be checking actual content types defined in `requestBody` and drive behavior based on
    that fact.
    """
    if isinstance(item, bytes):
        return True
    # Pick the nested values to inspect; anything that is not a dict or a list has none
    if isinstance(item, dict):
        nested = item.values()
    elif isinstance(item, list):
        nested = item
    else:
        return False
    return any(is_multipart(element) for element in nested)


@contextmanager
def cookie_handler(client: werkzeug.Client, cookies: Optional[Cookies]) -> Generator[None, None, None]:
    """Set cookies required for a call.

    The cookies are set on ``client`` for the "localhost" domain before the managed block runs and are
    deleted afterwards, even if the block raises. When ``cookies`` is empty or ``None`` the client is left untouched.
    """
    if not cookies:
        yield
        return
    for key, value in cookies.items():
        client.set_cookie("localhost", key, value)
    try:
        yield
    finally:
        # Clean up on errors as well, so a failed call does not leave its cookies in the client
        for key in cookies:
            client.delete_cookie("localhost", key)


def empty_object() -> Dict[str, Any]:
    """Build a fresh schema-like dictionary describing an object without properties.

    A new dictionary (with new nested containers) is created on every call, so callers may modify the result.
    """
    return dict(properties={}, additionalProperties=False, type="object", required=[])


@attr.s(slots=True)  # pragma: no mutate
class EndpointDefinition:
    """A wrapper to store unresolved endpoint definitions.

    To prevent recursion errors we need to store definitions without resolving references. But an endpoint
    definition itself can be behind a reference (when there is a ``$ref`` in ``paths`` values), therefore we need to
    store this scope change to resolve references properly later.
    """

    # The endpoint definition with references left unresolved (see the class docstring)
    raw: Dict[str, Any] = attr.ib()  # pragma: no mutate
    # The definition that `Endpoint.get_hypothesis_conversions` reads "parameters" from;
    # presumably the same definition with references resolved - TODO confirm against the schema loaders
    resolved: Dict[str, Any] = attr.ib()  # pragma: no mutate
    # The reference resolution scope the definition was found in (see the class docstring);
    # the exact format of this string is not visible here - TODO confirm
    scope: str = attr.ib()  # pragma: no mutate


@attr.s(slots=True)  # pragma: no mutate
class Endpoint:
    """A container that could be used for test cases generation.

    Most of the actual work is delegated to ``schema``; the endpoint only carries the data that identifies it
    (``path``, ``method``, ``definition``) and the per-location parameter containers.
    """

    # `path` does not contain `basePath`
    # Example <scheme>://<host>/<basePath>/users - "/users" is path
    # https://swagger.io/docs/specification/2-0/api-host-and-base-path/
    path: str = attr.ib()  # pragma: no mutate
    method: str = attr.ib()  # pragma: no mutate
    definition: EndpointDefinition = attr.ib()  # pragma: no mutate
    schema: "BaseSchema" = attr.ib()  # pragma: no mutate
    app: Any = attr.ib(default=None)  # pragma: no mutate
    base_url: Optional[str] = attr.ib(default=None)  # pragma: no mutate
    path_parameters: Optional[PathParameters] = attr.ib(default=None)  # pragma: no mutate
    headers: Optional[Headers] = attr.ib(default=None)  # pragma: no mutate
    cookies: Optional[Cookies] = attr.ib(default=None)  # pragma: no mutate
    query: Optional[Query] = attr.ib(default=None)  # pragma: no mutate
    body: Optional[Body] = attr.ib(default=None)  # pragma: no mutate
    form_data: Optional[FormData] = attr.ib(default=None)  # pragma: no mutate

    @property
    def full_path(self) -> str:
        """The endpoint path as returned by ``schema.get_full_path``."""
        return self.schema.get_full_path(self.path)

    def as_strategy(self, hooks: Optional["HookDispatcher"] = None) -> SearchStrategy:
        """Create a Hypothesis strategy that generates test cases for this endpoint."""
        # Imported lazily; presumably to avoid a circular import with `_hypothesis`
        from ._hypothesis import get_case_strategy  # pylint: disable=import-outside-toplevel

        return get_case_strategy(self, hooks)

    def get_strategies_from_examples(self) -> List[SearchStrategy[Case]]:
        """Get examples from endpoint."""
        return self.schema.get_strategies_from_examples(self)

    def get_stateful_tests(self, response: GenericResponse, stateful: Optional[str]) -> Sequence["StatefulTest"]:
        """Ask the schema for stateful tests that follow the given response to this endpoint."""
        return self.schema.get_stateful_tests(response, self, stateful)

    def get_hypothesis_conversions(self, location: str) -> Optional[Callable]:
        """Get a conversion function for the parameters defined in ``location``.

        :param location: value of the ``in`` key to select parameter definitions by.
        :return: the result of ``schema.get_hypothesis_conversion`` or ``None`` if there are no such parameters.
        """
        definitions = [item for item in self.definition.resolved.get("parameters", []) if item["in"] == location]
        if definitions:
            return self.schema.get_hypothesis_conversion(definitions)
        return None

    def partial_deepcopy(self) -> "Endpoint":
        """Copy the endpoint, deep-copying only what is safe and worthwhile to copy."""
        return self.__class__(
            path=self.path,  # string, immutable
            method=self.method,  # string, immutable
            definition=deepcopy(self.definition),
            schema=self.schema.clone(),  # shallow copy
            app=self.app,  # not deepcopyable
            base_url=self.base_url,  # string, immutable
            path_parameters=deepcopy(self.path_parameters),
            # Copy the headers themselves; copying `path_parameters` here would drop the endpoint's headers
            headers=deepcopy(self.headers),
            cookies=deepcopy(self.cookies),
            query=deepcopy(self.query),
            body=deepcopy(self.body),
            form_data=deepcopy(self.form_data),
        )


class Status(IntEnum):
    """Status of an action or multiple actions.

    Being an ``IntEnum``, the members are also plain integers.
    """

    # Recorded for a check that passed (see `TestResult.add_success`)
    success = 1  # pragma: no mutate
    # Recorded for a check that did not pass (see `TestResult.add_failure`)
    failure = 2  # pragma: no mutate
    # Not referenced in the code visible here; presumably marks an unexpected error - TODO confirm
    error = 3  # pragma: no mutate


@attr.s(slots=True, repr=False)  # pragma: no mutate
class Check:
    """Single check run result."""

    # Name of the check that was run
    name: str = attr.ib()  # pragma: no mutate
    # Outcome of the check
    value: Status = attr.ib()  # pragma: no mutate
    # The test case the check was run for, if any
    example: Optional[Case] = attr.ib(default=None)  # pragma: no mutate
    # Additional text; `TestResult.add_failure` stores the failure message here
    message: Optional[str] = attr.ib(default=None)  # pragma: no mutate


@attr.s(slots=True, repr=False)  # pragma: no mutate
class Request:
    """Request data extracted from `Case`.

    The body is stored as a base64-encoded string and every header value is wrapped into a list.
    """

    method: str = attr.ib()  # pragma: no mutate
    uri: str = attr.ib()  # pragma: no mutate
    body: str = attr.ib()  # pragma: no mutate
    headers: Headers = attr.ib()  # pragma: no mutate

    @classmethod
    def from_case(cls, case: Case, session: requests.Session) -> "Request":
        """Create a new `Request` instance from `Case`.

        The case is converted to a `requests` request and prepared with the given session.
        """
        request_kwargs = case.as_requests_kwargs(case.get_full_base_url())
        prepared = session.prepare_request(requests.Request(**request_kwargs))  # type: ignore
        return cls.from_prepared_request(prepared)

    @classmethod
    def from_prepared_request(cls, prepared: requests.PreparedRequest) -> "Request":
        """A prepared request version is already stored in `requests.Response`."""
        payload = prepared.body or b""
        # can be a string for `application/x-www-form-urlencoded`
        raw_payload = payload.encode("utf-8") if isinstance(payload, str) else payload
        header_lists = {name: [content] for (name, content) in prepared.headers.items()}
        encoded_payload = base64.b64encode(raw_payload).decode()
        # `url` and `method` have `str` type at this point
        return cls(
            method=cast(str, prepared.method),
            uri=cast(str, prepared.url),
            body=encoded_payload,
            headers=header_lists,
        )


def serialize_payload(payload: bytes) -> str:
    """Encode the given bytes with base64 and return the result as a string."""
    encoded = base64.b64encode(payload)
    return encoded.decode()


@attr.s(slots=True, repr=False)  # pragma: no mutate
class Response:
    """Unified response data.

    The body is stored as a base64-encoded string and headers map a name to the list of its values.
    """

    status_code: int = attr.ib()  # pragma: no mutate
    message: str = attr.ib()  # pragma: no mutate
    headers: Dict[str, List[str]] = attr.ib()  # pragma: no mutate
    body: str = attr.ib()  # pragma: no mutate
    encoding: str = attr.ib()  # pragma: no mutate
    http_version: str = attr.ib()  # pragma: no mutate
    elapsed: float = attr.ib()  # pragma: no mutate

    @classmethod
    def from_requests(cls, response: requests.Response) -> "Response":
        """Create a response from requests.Response."""
        # Read headers from the raw response to keep all values of repeated headers
        headers = {name: response.raw.headers.getlist(name) for name in response.raw.headers.keys()}
        # Similar to http.client:319 (HTTP version detection in stdlib's `http` package)
        http_version = "1.0" if response.raw.version == 10 else "1.1"
        return cls(
            status_code=response.status_code,
            message=response.reason,
            body=serialize_payload(response.content),
            encoding=response.encoding or "utf8",
            headers=headers,
            http_version=http_version,
            elapsed=response.elapsed.total_seconds(),
        )

    @classmethod
    def from_wsgi(cls, response: WSGIResponse, elapsed: float) -> "Response":
        """Create a response from WSGI response.

        :param response: the response returned by the WSGI test client.
        :param elapsed: response time, stored as is.
        """
        # `import http` alone does not load the `http.client` submodule - import what is needed explicitly
        # instead of relying on some other module having imported it before.
        from http.client import responses as reason_phrases  # pylint: disable=import-outside-toplevel

        message = reason_phrases.get(response.status_code, "UNKNOWN")
        headers = {name: response.headers.getlist(name) for name in response.headers.keys()}
        return cls(
            status_code=response.status_code,
            message=message,
            body=serialize_payload(response.data),
            # NOTE(review): `content_encoding` is stored as the body encoding - confirm that it holds a charset
            # and not the value of the `Content-Encoding` header
            encoding=response.content_encoding or "utf-8",
            headers=headers,
            http_version="1.1",
            elapsed=elapsed,
        )


@attr.s(slots=True)  # pragma: no mutate
class Interaction:
    """A request / response pair recorded for one call to the target app.

    ``recorded_at`` defaults to the current local time in ISO 8601 format, taken when the instance is created.
    """

    request: Request = attr.ib()  # pragma: no mutate
    response: Response = attr.ib()  # pragma: no mutate
    recorded_at: str = attr.ib(factory=lambda: datetime.datetime.now().isoformat())  # pragma: no mutate

    @classmethod
    def from_requests(cls, response: requests.Response) -> "Interaction":
        """Build an interaction from a `requests` response and the prepared request stored in it."""
        recorded_request = Request.from_prepared_request(response.request)
        recorded_response = Response.from_requests(response)
        return cls(request=recorded_request, response=recorded_response)

    @classmethod
    def from_wsgi(cls, case: Case, response: WSGIResponse, headers: Dict[str, Any], elapsed: float) -> "Interaction":
        """Build an interaction from a WSGI response.

        The request is reconstructed from ``case`` by preparing it with a `requests` session
        that carries the given ``headers``.
        """
        session = requests.Session()
        session.headers.update(headers)
        recorded_request = Request.from_case(case, session)
        recorded_response = Response.from_wsgi(response, elapsed)
        return cls(request=recorded_request, response=recorded_response)


@attr.s(slots=True, repr=False)  # pragma: no mutate
class TestResult:
    """Everything collected while testing a single endpoint: checks, errors, interactions and logs."""

    endpoint: Endpoint = attr.ib()  # pragma: no mutate
    checks: List[Check] = attr.ib(factory=list)  # pragma: no mutate
    errors: List[Tuple[Exception, Optional[Case]]] = attr.ib(factory=list)  # pragma: no mutate
    interactions: List[Interaction] = attr.ib(factory=list)  # pragma: no mutate
    logs: List[LogRecord] = attr.ib(factory=list)  # pragma: no mutate
    is_errored: bool = attr.ib(default=False)  # pragma: no mutate
    seed: Optional[int] = attr.ib(default=None)  # pragma: no mutate
    # To show a proper reproduction code if a failure happens
    overridden_headers: Optional[Dict[str, Any]] = attr.ib(default=None)  # pragma: no mutate

    def mark_errored(self) -> None:
        """Set the ``is_errored`` flag."""
        self.is_errored = True

    @property
    def has_errors(self) -> bool:
        """Whether at least one error was stored."""
        return True if self.errors else False

    @property
    def has_failures(self) -> bool:
        """Whether at least one stored check has the ``failure`` status."""
        for check in self.checks:
            if check.value == Status.failure:
                return True
        return False

    @property
    def has_logs(self) -> bool:
        """Whether at least one log record was stored."""
        return True if self.logs else False

    def add_success(self, name: str, example: Case) -> None:
        """Store a passed check."""
        passed = Check(name=name, value=Status.success, example=example)
        self.checks.append(passed)

    def add_failure(self, name: str, example: Case, message: str) -> None:
        """Store a failed check together with the failure message."""
        failed = Check(name=name, value=Status.failure, example=example, message=message)
        self.checks.append(failed)

    def add_error(self, exception: Exception, example: Optional[Case] = None) -> None:
        """Store an exception and, optionally, the case it is related to."""
        entry = (exception, example)
        self.errors.append(entry)

    def store_requests_response(self, response: requests.Response) -> None:
        """Store an interaction built from a `requests` response."""
        interaction = Interaction.from_requests(response)
        self.interactions.append(interaction)

    def store_wsgi_response(self, case: Case, response: WSGIResponse, headers: Dict[str, Any], elapsed: float) -> None:
        """Store an interaction built from a WSGI response and the case that was sent."""
        interaction = Interaction.from_wsgi(case, response, headers, elapsed)
        self.interactions.append(interaction)


@attr.s(slots=True, repr=False)  # pragma: no mutate
class TestResultSet:
    """A collection of test results with aggregated views over them."""

    results: List[TestResult] = attr.ib(factory=list)  # pragma: no mutate

    def __iter__(self) -> Iterator[TestResult]:
        """Iterate over the stored results in insertion order."""
        return iter(self.results)

    @property
    def is_empty(self) -> bool:
        """If the result set contains no results."""
        return not self.results

    @property
    def has_failures(self) -> bool:
        """If any result has any failures."""
        for result in self:
            if result.has_failures:
                return True
        return False

    @property
    def has_errors(self) -> bool:
        """If any result has any errors."""
        for result in self:
            if result.has_errors:
                return True
        return False

    @property
    def has_logs(self) -> bool:
        """If any result has any captured logs."""
        for result in self:
            if result.has_logs:
                return True
        return False

    def _count(self, predicate: Callable) -> int:
        """Count the results for which ``predicate`` returns a truthy value."""
        matched = 0
        for result in self:
            if predicate(result):
                matched += 1
        return matched

    @property
    def passed_count(self) -> int:
        """Number of results that have neither errors nor failures."""

        def is_passed(result: TestResult) -> bool:
            return not (result.has_errors or result.has_failures)

        return self._count(is_passed)

    @property
    def failed_count(self) -> int:
        """Number of results that have failures and are not marked as errored."""

        def is_failed(result: TestResult) -> bool:
            return result.has_failures and not result.is_errored

        return self._count(is_failed)

    @property
    def errored_count(self) -> int:
        """Number of results that have errors or are marked as errored."""

        def is_errored(result: TestResult) -> bool:
            return result.has_errors or result.is_errored

        return self._count(is_errored)

    @property
    def total(self) -> Dict[str, Dict[Union[str, Status], int]]:
        """Aggregated statistic about test results.

        Maps every check name to the number of its runs per status plus the overall number under the "total" key.
        """
        output: Dict[str, Dict[Union[str, Status], int]] = {}
        for item in self.results:
            for check in item.checks:
                stats = output.setdefault(check.name, Counter())
                stats[check.value] += 1
                stats["total"] += 1
        # Avoid using Counter, since its behavior could harm in other places:
        # `if not total["unknown"]:` - this will lead to the branch execution
        # It is better to let it fail if there is a wrong key
        return {name: dict(stats) for name, stats in output.items()}

    def append(self, item: TestResult) -> None:
        """Add a new item to the results list."""
        self.results.append(item)


# Type of a check: a callable that takes a response and a `Case` and returns `None`.
# `Case.validate_response` calls checks this way and treats a raised `AssertionError` as a failed check.
CheckFunction = Callable[[GenericResponse, Case], None]  # pragma: no mutate