import io
from typing import AsyncIterator, NamedTuple

import aiohttp
import pytest
from aiohttp import FormData
from aiohttp.web import HTTPOk

from nima.api import Config, ServerConfig, WorkersConfig, create_app

@pytest.fixture
def config(state_dict_path) -> Config:
    """Build the ``Config`` consumed by the ``api`` fixture.

    Args:
        state_dict_path: Location of the model state dict, supplied by a
            ``state_dict_path`` fixture that is not defined in this file
            (presumably in a conftest -- confirm).

    Returns:
        A ``Config`` pairing a ``ServerConfig()`` built with no arguments and
        a ``WorkersConfig`` that points at ``state_dict_path``.
    """
    server_config = ServerConfig()
    workers_config = WorkersConfig(path_to_model_state=state_dict_path)
    return Config(server=server_config, worker=workers_config)

class ApiConfig(NamedTuple):
    """Address of the API server started for a test.

    The URL helpers are read-only properties so callers can use them as plain
    attributes (``api.ping_url``), which is how the tests consume them.
    """

    host: str  # hostname/IP the HTTP client connects to
    port: int  # TCP port the server listens on

    @property
    def endpoint(self) -> str:
        """Base URL of the server, e.g. ``http://127.0.0.1:8080``."""
        return f"http://{self.host}:{self.port}"

    @property
    def model_base_url(self) -> str:
        """Base URL for model routes; currently identical to ``endpoint``."""
        return self.endpoint

    @property
    def ping_url(self) -> str:
        """URL of the ``/ping`` route."""
        return self.endpoint + "/ping"

# NOTE(review): a plain @pytest.fixture on an async generator relies on
# pytest-asyncio in auto mode (or an equivalent plugin) -- confirm project config.
@pytest.fixture
async def api(config: Config) -> AsyncIterator[ApiConfig]:
    """Serve the app created by ``create_app(config)`` on a local TCP port.

    Yields:
        The ``ApiConfig`` (host/port) the running server is reachable on.

    The runner is cleaned up in ``finally`` so it is released even if
    ``site.start()`` fails (e.g. the port is already in use).
    """
    app = await create_app(config)
    runner = aiohttp.web.AppRunner(app)
    await runner.setup()
    # Bind and connect on loopback: an empty host would yield the invalid
    # client URL "http://:8080".
    api_config = ApiConfig(host="127.0.0.1", port=8080)
    try:
        site = aiohttp.web.TCPSite(runner, api_config.host, api_config.port)
        await site.start()
        yield api_config
    finally:
        await runner.cleanup()

@pytest.fixture
async def client() -> AsyncIterator[aiohttp.ClientSession]:
    """Yield an ``aiohttp.ClientSession`` that is closed when the test ends.

    The ``async with`` block closes the session on fixture teardown.
    """
    async with aiohttp.ClientSession() as session:
        yield session

class TestModelApi:
    """End-to-end tests for the model prediction route."""

    async def test_predict(self, api: ApiConfig, client: aiohttp.ClientSession, image_file_obj: io.BytesIO) -> None:
        """POSTing an image to ``/predict`` returns 200 and the score fields."""
        predict_url = api.model_base_url + "/predict"

        data = FormData()
        # The upload is named *.jpg, so declare the matching registered media
        # type ("image/img" is not one). Assumes image_file_obj holds JPEG bytes.
        data.add_field("file", image_file_obj, filename="test_image.jpg", content_type="image/jpeg")

        async with client.post(predict_url, data=data) as response:
            assert response.status == HTTPOk.status_code
            res_data = await response.json()
            for key in ("mean_score", "std_score", "scores", "total_time"):
                assert key in res_data, f"missing {key!r} in response: {res_data!r}"

class TestApi:
    """Liveness checks for the running HTTP server."""

    async def test_ping(self, api: ApiConfig, client: aiohttp.ClientSession) -> None:
        """The ping route must answer with ``200 OK``."""
        url = api.ping_url
        expected_status = HTTPOk.status_code
        async with client.get(url) as resp:
            assert resp.status == expected_status