import numpy as np
import pandas as pd
import pytest

from respy.config import EXAMPLE_MODELS
from respy.config import INDEXER_INVALID_INDEX
from respy.config import KEANE_WOLPIN_1994_MODELS
from respy.config import KEANE_WOLPIN_1997_MODELS
from respy.interface import get_example_model
from respy.pre_processing.model_checking import check_model_solution
from respy.pre_processing.model_processing import process_params_and_options
from respy.shared import create_core_state_space_columns
from respy.solve import get_solve_func
from respy.state_space import _create_core_and_indexer
from respy.state_space import _insert_indices_of_child_states
from respy.tests._former_code import _create_state_space_kw94
from respy.tests._former_code import _create_state_space_kw97_base
from respy.tests._former_code import _create_state_space_kw97_extended
from respy.tests.utils import apply_to_attributes_of_two_state_spaces
from respy.tests.utils import process_model_or_seed


@pytest.mark.parametrize("model_or_seed", EXAMPLE_MODELS)
def test_check_solution(model_or_seed):
    """Solve a model and run ``check_model_solution`` on the resulting state space."""
    params, options = process_model_or_seed(model_or_seed)

    # The solve function is built from the options as returned by
    # ``process_model_or_seed``; they are only processed afterwards.
    solved_state_space = get_solve_func(params, options)(params)

    processed = process_params_and_options(params, options)
    optim_paras, options = processed

    check_model_solution(optim_paras, options, solved_state_space)


@pytest.mark.integration
@pytest.mark.precise
@pytest.mark.parametrize("model", EXAMPLE_MODELS)
def test_state_space_restrictions_by_traversing_forward(model):
    """Test for inadmissible states in the state space.

    The test is motivated by the addition of another restriction in
    https://github.com/OpenSourceEconomics/respy/pull/145. To ensure that similar errors
    do not happen again, this test takes all states of the first period and finds all
    their child states. Taking only the child states, their children are found and so
    on. At last, the set of visited states is compared against the total set of states.

    The test can only be applied to some models. Most models would need custom
    ``options["core_state_space_filters"]`` to remove inaccessible states from the state
    space.

    """
    params, options = process_model_or_seed(model)
    optim_paras, options = process_params_and_options(params, options)

    solve = get_solve_func(params, options)
    state_space = solve(params)

    indices = np.full(
        (state_space.core.shape[0], len(optim_paras["choices"])), INDEXER_INVALID_INDEX
    )
    core_columns = create_core_state_space_columns(optim_paras)

    # Convert the core state space once instead of in every period. The builtin
    # ``int`` is used because the ``np.int`` alias was removed from NumPy.
    all_states = state_space.core[core_columns].to_numpy(dtype=int)

    for period in range(options["n_periods"] - 1):

        if period == 0:
            states = state_space.core.query("period == 0")[core_columns].to_numpy(
                dtype=int
            )
        else:
            # The states of this period are the children of the previous period's
            # states. Negative entries are dropped before indexing.
            indices_period = state_space.indices_of_child_states[
                state_space.slices_by_periods[period - 1]
            ]
            indices_period = indices_period[indices_period >= 0]
            states = all_states[indices_period]

        indices = _insert_indices_of_child_states(
            indices,
            states,
            state_space.indexer[period],
            state_space.indexer[period + 1],
            state_space.is_inadmissible,
            len(optim_paras["choices_w_exp"]),
            optim_paras["n_lagged_choices"],
        )

    # Take all valid indices and add the indices of the first period.
    set_valid_indices = set(indices[indices != INDEXER_INVALID_INDEX]) | set(
        range(state_space.core.query("period == 0").shape[0])
    )

    assert set_valid_indices == set(range(state_space.core.shape[0]))


@pytest.mark.integration
@pytest.mark.parametrize("model_or_seed", EXAMPLE_MODELS)
def test_invariance_of_solution(model_or_seed):
    """Test for the invariance of the solution.

    The model is solved twice with the same parameters and the core state space as
    well as several attributes of both solutions are compared for exact equality.

    """
    params, options = process_model_or_seed(model_or_seed)

    # Only the processed options are needed to build the solve function.
    _, options = process_params_and_options(params, options)

    solve = get_solve_func(params, options)
    first_solution = solve(params)
    second_solution = solve(params)

    apply_to_attributes_of_two_state_spaces(
        first_solution.core, second_solution.core, np.testing.assert_array_equal
    )

    compared_attributes = (
        "wages",
        "nonpecs",
        "expected_value_functions",
        "base_draws_sol",
    )
    for attribute in compared_attributes:
        apply_to_attributes_of_two_state_spaces(
            first_solution.get_attribute(attribute),
            second_solution.get_attribute(attribute),
            np.testing.assert_array_equal,
        )


@pytest.mark.precise
@pytest.mark.unit
@pytest.mark.parametrize("model", KEANE_WOLPIN_1994_MODELS)
def test_create_state_space_vs_specialized_kw94(model):
    """Compare ``_create_core_and_indexer`` with the former KW94 state space code.

    The sets of states have to be identical and, in every period, both indexers have
    to mark the same positions as valid.

    """
    constraints = {"n_lagged_choices": 1, "observables": False}
    params, options = process_model_or_seed(model, point_constr=constraints)

    optim_paras, options = process_params_and_options(params, options)

    # Collect the arguments of the former state space creation.
    n_periods = options["n_periods"]
    n_types = optim_paras["n_types"]
    edu_spec = optim_paras["choices"]["edu"]
    edu_max = edu_spec["max"]
    edu_starts = np.array(list(edu_spec["start"]))

    # States and indexer produced by the former implementation.
    former_states, former_indexer = _create_state_space_kw94(
        n_periods, n_types, edu_starts, edu_max
    )
    if n_types == 1:
        # Drop the last column of the former states (presumably the type column)
        # and merge the last two axes of each period's indexer.
        former_states = former_states[:, :-1]
        for position, period_indexer in enumerate(former_indexer):
            merged_shape = period_indexer.shape[:-2] + (-1,)
            former_indexer[position] = period_indexer.reshape(merged_shape)

    current_states, current_indexer = _create_core_and_indexer(optim_paras, options)

    # The ordering of states changed in some cases, so compare them as sets.
    former_state_set = {tuple(row) for row in former_states}
    current_state_set = {tuple(row) for row in current_states.to_numpy()}
    assert former_state_set == current_state_set

    # Indexers are compared through their masks of valid entries.
    for period in range(n_periods):
        former_is_valid = former_indexer[period] != INDEXER_INVALID_INDEX
        current_is_valid = current_indexer[period] != INDEXER_INVALID_INDEX
        assert np.array_equal(former_is_valid, current_is_valid)


@pytest.mark.precise
@pytest.mark.unit
@pytest.mark.parametrize("model", KEANE_WOLPIN_1997_MODELS)
def test_create_state_space_vs_specialized_kw97(model):
    """Compare ``_create_core_and_indexer`` with the former KW97 state space code.

    The new core states are replicated with a ``type`` column before the sets of
    states are compared. The indexer masks are compared period by period.

    """
    params, options = process_model_or_seed(model)

    # Cap the number of periods at ten to reduce runtime.
    options["n_periods"] = min(options["n_periods"], 10)

    optim_paras, options = process_params_and_options(params, options)

    # Collect the arguments of the former state space creation.
    n_periods = options["n_periods"]
    n_types = optim_paras["n_types"]
    school_spec = optim_paras["choices"]["school"]
    edu_max = school_spec["max"]
    edu_starts = np.array(list(school_spec["start"]))

    # Pick the former implementation which belongs to the model.
    if model == "kw_97_basic":
        create_former_state_space = _create_state_space_kw97_base
    else:
        create_former_state_space = _create_state_space_kw97_extended
    former_states, former_indexer = create_former_state_space(
        n_periods, n_types, edu_starts, edu_max
    )

    if n_types == 1:
        # Drop the last column of the former states (presumably the type column)
        # and merge the last two axes of each period's indexer.
        former_states = former_states[:, :-1]
        for position, period_indexer in enumerate(former_indexer):
            merged_shape = period_indexer.shape[:-2] + (-1,)
            former_indexer[position] = period_indexer.reshape(merged_shape)

    current_states, current_indexer = _create_core_and_indexer(optim_paras, options)
    # NOTE(review): the literal 4 presumably equals the number of types of the KW97
    # models -- confirm before replacing it with ``n_types``.
    replicated_frames = [current_states.copy().assign(type=i) for i in range(4)]
    current_states = pd.concat(replicated_frames)

    # The ordering of states changed in some cases, so compare them as sets.
    former_state_set = {tuple(row) for row in former_states}
    current_state_set = {tuple(row) for row in current_states.to_numpy()}
    assert former_state_set == current_state_set

    # Indexers are compared through their masks of valid entries. Each entry of the
    # new mask is repeated four times to match the shape of the former mask.
    for period in range(n_periods):
        former_is_valid = former_indexer[period] != INDEXER_INVALID_INDEX
        current_is_valid = current_indexer[period] != INDEXER_INVALID_INDEX
        expanded_is_valid = np.repeat(current_is_valid, 4).reshape(
            former_is_valid.shape
        )
        assert np.array_equal(former_is_valid, expanded_is_valid)


@pytest.mark.edge_case
@pytest.mark.integration
def test_explicitly_nonpec_choice_rewards_of_kw_94_one():
    """Check the ``nonpecs`` of the solved ``kw_94_one`` model against known values."""
    params, options = get_example_model("kw_94_one", with_data=False)

    state_space = get_solve_func(params, options)(params)
    nonpecs = state_space.nonpecs

    # The first two columns are zero everywhere.
    assert (nonpecs[:, :2] == 0).all()

    # The third column only takes a small set of values.
    allowed_third_column = [0, -4_000, -400_000, -404_000]
    assert np.isin(nonpecs[:, 2], allowed_third_column).all()

    # The fourth column is constant.
    assert (nonpecs[:, 3] == 17_750).all()


@pytest.mark.edge_case
@pytest.mark.integration
def test_explicitly_nonpec_choice_rewards_of_kw_94_two():
    """Check the ``nonpecs`` of the solved ``kw_94_two`` model against known values."""
    params, options = get_example_model("kw_94_two", with_data=False)

    state_space = get_solve_func(params, options)(params)
    nonpecs = state_space.nonpecs

    # The first two columns are zero everywhere.
    assert (nonpecs[:, :2] == 0).all()

    # The third column only takes a small set of values.
    allowed_third_column = [5_000, 0, -10_000, -15_000, -400_000, -415_000]
    assert np.isin(nonpecs[:, 2], allowed_third_column).all()

    # The fourth column is constant.
    assert (nonpecs[:, 3] == 14_500).all()