import os
import socket
import ssl
from typing import Tuple
from unittest.mock import Mock

import pytest
from _pytest.monkeypatch import MonkeyPatch

import hypercorn.config
from hypercorn.config import Config

access_log_format = "bob"
h11_max_incomplete_size = 4


def _check_standard_config(config: Config) -> None:
    assert config.access_log_format == access_log_format
    assert config.h11_max_incomplete_size == h11_max_incomplete_size
    assert config.bind == ["127.0.0.1:5555"]


def test_config_from_pyfile() -> None:
    path = os.path.join(os.path.dirname(__file__), "assets/config.py")
    config = Config.from_pyfile(path)
    _check_standard_config(config)


def test_ssl_config_from_pyfile() -> None:
    path = os.path.join(os.path.dirname(__file__), "assets/config_ssl.py")
    config = Config.from_pyfile(path)
    _check_standard_config(config)
    assert config.ssl_enabled


def test_config_from_toml() -> None:
    path = os.path.join(os.path.dirname(__file__), "assets/config.toml")
    config = Config.from_toml(path)
    _check_standard_config(config)


def test_create_ssl_context() -> None:
    path = os.path.join(os.path.dirname(__file__), "assets/config_ssl.py")
    config = Config.from_pyfile(path)
    context = config.create_ssl_context()
    assert context.options & (
        ssl.OP_NO_SSLv2
        | ssl.OP_NO_SSLv3
        | ssl.OP_NO_TLSv1
        | ssl.OP_NO_TLSv1_1
        | ssl.OP_NO_COMPRESSION
    )


@pytest.mark.parametrize(
    "bind, expected_family, expected_binding",
    [
        ("127.0.0.1:5000", socket.AF_INET, ("127.0.0.1", 5000)),
        ("127.0.0.1", socket.AF_INET, ("127.0.0.1", 8000)),
        ("[::]:5000", socket.AF_INET6, ("::", 5000)),
        ("[::]", socket.AF_INET6, ("::", 8000)),
    ],
)
def test_create_sockets_ip(
    bind: str,
    expected_family: socket.AddressFamily,
    expected_binding: Tuple[str, int],
    monkeypatch: MonkeyPatch,
) -> None:
    mock_socket = Mock()
    monkeypatch.setattr(socket, "socket", mock_socket)
    config = Config()
    config.bind = [bind]
    sockets = config.create_sockets()
    sock = sockets.insecure_sockets[0]
    mock_socket.assert_called_with(expected_family, socket.SOCK_STREAM)
    sock.setsockopt.assert_called_with(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)  # type: ignore
    sock.bind.assert_called_with(expected_binding)  # type: ignore
    sock.setblocking.assert_called_with(False)  # type: ignore
    sock.set_inheritable.assert_called_with(True)  # type: ignore


def test_create_sockets_unix(monkeypatch: MonkeyPatch) -> None:
    mock_socket = Mock()
    monkeypatch.setattr(socket, "socket", mock_socket)
    monkeypatch.setattr(os, "chown", Mock())
    config = Config()
    config.bind = ["unix:/tmp/hypercorn.sock"]
    sockets = config.create_sockets()
    sock = sockets.insecure_sockets[0]
    mock_socket.assert_called_with(socket.AF_UNIX, socket.SOCK_STREAM)
    sock.setsockopt.assert_called_with(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)  # type: ignore
    sock.bind.assert_called_with("/tmp/hypercorn.sock")  # type: ignore
    sock.setblocking.assert_called_with(False)  # type: ignore
    sock.set_inheritable.assert_called_with(True)  # type: ignore


def test_create_sockets_fd(monkeypatch: MonkeyPatch) -> None:
    mock_fromfd = Mock()
    monkeypatch.setattr(socket, "fromfd", mock_fromfd)
    config = Config()
    config.bind = ["fd://2"]
    sockets = config.create_sockets()
    sock = sockets.insecure_sockets[0]
    mock_fromfd.assert_called_with(2, socket.AF_UNIX, socket.SOCK_STREAM)
    sock.setsockopt.assert_called_with(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)  # type: ignore
    sock.setblocking.assert_called_with(False)  # type: ignore
    sock.set_inheritable.assert_called_with(True)  # type: ignore


def test_create_sockets_multiple(monkeypatch: MonkeyPatch) -> None:
    mock_socket = Mock()
    monkeypatch.setattr(socket, "socket", mock_socket)
    monkeypatch.setattr(os, "chown", Mock())
    config = Config()
    config.bind = ["127.0.0.1", "unix:/tmp/hypercorn.sock"]
    sockets = config.create_sockets()
    assert len(sockets.insecure_sockets) == 2


def test_response_headers(monkeypatch: MonkeyPatch) -> None:
    monkeypatch.setattr(hypercorn.config, "time", lambda: 1_512_229_395)
    config = Config()
    assert config.response_headers("test") == [
        (b"date", b"Sat, 02 Dec 2017 15:43:15 GMT"),
        (b"server", b"hypercorn-test"),
    ]
    config.include_server_header = False
    assert config.response_headers("test") == [(b"date", b"Sat, 02 Dec 2017 15:43:15 GMT")]