import contextlib
import os
import typing
from unittest import mock

import pytest
import skein

from tf_yarn import cluster

# Dotted path of the module under test; used as the prefix of mock.patch targets.
MODULE_TO_TEST: str = "tf_yarn.cluster"


def test_aggregate_spec():
    """``aggregate_spec`` groups published addresses by task type, ordered by index.

    The client's key-value store is mocked: ``wait(key)`` answers from
    ``kv_content``. Tasks are passed out of order to check the result ordering.
    """
    kv_content: typing.Dict[str, bytes] = {
        "worker:0/init": b"1.1.1.1:8020",
        "worker:1/init": b"1.1.1.2:4042",
        "ps:0/init": b"1.1.1.3:8888",
    }
    mock_kv = mock.MagicMock(spec=skein.kv.KeyValueStore)
    mock_kv.wait.side_effect = lambda arg: kv_content[arg]

    client = mock.MagicMock(spec=skein.ApplicationClient)
    client.kv = mock_kv

    spec = cluster.aggregate_spec(client, ["worker:1", "ps:0", "worker:0"])

    expected = {
        "worker": ["1.1.1.1:8020", "1.1.1.2:4042"],
        "ps": ["1.1.1.3:8888"],
    }
    assert spec == expected


def test_get_task_description():
    """Task type and index are parsed from ``SKEIN_CONTAINER_ID``.

    A container id of ``"MYTASK_42"`` must yield the pair ``("MYTASK", 42)``.
    """
    with mock.patch.dict(os.environ):
        os.environ["SKEIN_CONTAINER_ID"] = "MYTASK_42"
        # Compare against an explicit tuple. The previous form,
        # ``assert "MYTASK", 42 == ...``, parsed as ``assert <non-empty str>, <msg>``
        # and could never fail, so the function under test was never checked.
        assert cluster.get_task_description() == ("MYTASK", 42)


# Fake (host, port) pairs for the cluster tests: "CURRENT" is the task under
# test, WORKER0/WORKER1 are other tasks of the cluster. All share port 8888.
CURRENT_HOST, CURRENT_PORT = "1.1.1.1", 8888
WORKER0_HOST, WORKER0_PORT = "1.1.1.2", 8888
WORKER1_HOST, WORKER1_PORT = "1.1.1.3", 8888


@pytest.mark.parametrize("task_name, task_index", [
    ("worker", 1),
    ("ps", 0),
])
def test_start_cluster_worker(task_name, task_index):
    """``start_cluster`` publishes the current task's address via ``event.init_event``.

    The ``event`` module is patched: ``event.wait`` answers from ``published``,
    which maps ``"<task>/init"`` keys to the addresses each task advertised.
    """
    current_task = f"{task_name}:{task_index}"
    current_addr = f"{CURRENT_HOST}:{CURRENT_PORT}"
    published = {
        "worker:0/init": [f"{WORKER0_HOST}:{WORKER0_PORT}"],
        f"{current_task}/init": [current_addr],
    }

    env_patch = mock.patch.dict(
        os.environ, {"SKEIN_CONTAINER_ID": f"{task_name}_{task_index}"})
    event_patch = mock.patch(f"{MODULE_TO_TEST}.event")

    with env_patch, event_patch as mock_event:
        mock_event.wait.side_effect = lambda client, key: published[key][0]
        fake_client = mock.Mock(spec=skein.ApplicationClient)

        cluster.start_cluster(
            (CURRENT_HOST, CURRENT_PORT), fake_client, [current_task, "worker:0"])

        mock_event.init_event.assert_called_once_with(
            fake_client, current_task, current_addr)


@pytest.mark.parametrize("task_name, task_index, is_server_started", [
    ("worker", 1, True),
    ("ps", 0, False),
])
def test_start_tf_server(task_name, task_index, is_server_started):
    """Check whether ``start_tf_server`` constructs a ``tf.distribute.Server``.

    ``tf.distribute`` is patched on the module under test, so no real server is
    built. In the started case, exactly one server must be created with the
    task's own job name and index and ``start=True``. Otherwise none is created.
    """
    cluster_spec = {
        "worker": [
            f"worker0.{WORKER0_HOST}:{WORKER0_PORT}",
            f"worker1.{WORKER1_HOST}:{WORKER1_PORT}",
        ],
        "ps": [f"ps0.{CURRENT_HOST}:{CURRENT_PORT}"],
    }
    container_env = {"SKEIN_CONTAINER_ID": f"{task_name}_{task_index}"}

    with contextlib.ExitStack() as stack:
        stack.enter_context(mock.patch.dict(os.environ, container_env))
        mock_distribute = stack.enter_context(
            mock.patch(f"{MODULE_TO_TEST}.tf.distribute"))
        cluster.start_tf_server(cluster_spec)

    # The mock keeps its recorded calls after the patch has been undone.
    server_factory = mock_distribute.Server
    if not is_server_started:
        assert server_factory.call_count == 0
        return

    assert server_factory.call_count == 1
    _, server_kwargs = server_factory.call_args
    assert server_kwargs["job_name"] == task_name
    assert server_kwargs["task_index"] == task_index
    assert server_kwargs["start"] is True