from typing import Optional, Type, Union
from flask_restx import fields as fr, inputs
from marshmallow import fields as ma
from marshmallow.schema import Schema, SchemaMeta
import uuid


def unpack_list(val, api, model_name: str = None, operation: str = "dump"):
    """Convert a marshmallow ``List`` field into a flask-restx ``fr.List``.

    The element field ``val.inner`` is converted with ``map_type``; the list's own
    options (example/required/description/default) are attached to the outer
    ``fr.List``. When ``model_name`` is falsy a fresh default model name is
    generated.
    """
    if not model_name:
        model_name = get_default_model_name()
    element = map_type(val.inner, api, model_name, operation)
    return fr.List(element, **_ma_field_to_fr_field(val))


def unpack_nested(val, api, model_name: str = None, operation: str = "dump"):
    """Convert a marshmallow ``Nested`` field into a flask-restx field.

    ``Nested("self")`` is delegated to ``unpack_nested_self``. Otherwise the nested
    schema is converted into its own model, named after the nested schema (the
    caller's ``model_name`` is not used here), and wrapped in ``fr.Nested``. When
    the field was declared with ``many=True`` the result is wrapped in ``fr.List``
    so the documentation shows an array of that model.

    Args:
        val (ma.Nested): The nested field to convert.
        api (Namespace): Flask-restx namespace used to register models.
        model_name (str): Only forwarded for ``Nested("self")``.
        operation (str): "dump" or "load".

    Returns:
        fr.Nested, or fr.List of fr.Nested when ``val.many`` is true.
    """
    if val.nested == "self":
        return unpack_nested_self(val, api, model_name, operation)

    nested_model_name = get_default_model_name(val.nested)
    nested_model = map_type(val.nested, api, nested_model_name, operation)
    field_kwargs = _ma_field_to_fr_field(val)

    if val.many:
        # Field-level options (required, description, ...) belong on the outer
        # List, as in unpack_list, so flask-restx marks the property itself.
        return fr.List(fr.Nested(nested_model), **field_kwargs)
    return fr.Nested(nested_model, **field_kwargs)


def unpack_nested_self(val, api, model_name: str = None, operation: str = "dump"):
    """Convert a self-referencing ``Nested("self")`` field into a flask-restx field.

    The fields of ``val.schema`` are registered as a ``"<model_name>-child"``
    model. As in ``for_swagger``, each property is keyed by the field's
    ``data_key`` when set (otherwise its attribute name), and fields that do not
    take part in ``operation`` are dropped. With ``many=True`` the result is a
    list of that model.

    NOTE(review): if ``val.schema`` still contains this same ``Nested("self")``
    field, converting it re-enters this function; confirm callers exclude it.

    Args:
        val (ma.Nested): The self-referencing nested field.
        api (Namespace): Flask-restx namespace used to register the child model.
        model_name (str): Base model name; derived from ``val.schema`` if omitted.
        operation (str): "dump" or "load".

    Returns:
        fr.Nested, or fr.List of fr.Nested when ``val.many`` is true.
    """
    model_name = model_name or get_default_model_name(val.schema)
    child_fields = {
        v.data_key or k: map_type(v, api, model_name, operation)
        for k, v in vars(val.schema).get("fields").items()
        if type(v) in type_map and _check_load_dump_only(v, operation)
    }
    child_model = api.model(f"{model_name}-child", child_fields)
    field_kwargs = _ma_field_to_fr_field(val)

    if val.many:
        # Options such as `required` must sit on the outer List for flask-restx
        # to add this property to the parent model's required list.
        return fr.List(fr.Nested(child_model), **field_kwargs)
    return fr.Nested(child_model, **field_kwargs)


def for_swagger(schema, api, model_name: str = None, operation: str = "dump"):
    """
    Build a Flask-restx model that mirrors the fields of a marshmallow schema.

    Args:
        schema (Marshmallow Schema): Schema class or instance to convert. A class
            is instantiated with no arguments before its fields are read.
        api (Namespace): Flask-restx namespace used to register the model
        model_name (str): Base name of the model; derived from the schema's class
            name when omitted
        operation (str): "dump" or "load"; load-only fields are left out for
            "dump" and dump-only fields are left out for "load"

    Returns:
        api.model: The registered model. Each property is keyed by the field's
            data_key when set, else by its attribute name. The model name gets a
            "-<operation>" suffix when the schema has any load-only or dump-only
            field.
    """
    base_name = model_name or get_default_model_name(schema)

    # Bound field objects live in an instance's `fields` mapping.
    instance = schema() if isinstance(schema, SchemaMeta) else schema

    model_fields = {}
    for attr_name, field in vars(instance).get("fields").items():
        if type(field) not in type_map:
            continue
        if not _check_load_dump_only(field, operation):
            continue
        key = field.data_key or attr_name
        model_fields[key] = map_type(field, api, base_name, operation)

    final_name = _maybe_add_operation(instance, base_name, operation)
    return api.model(final_name, model_fields)


def _maybe_add_operation(schema, model_name: str, operation: str):
    if any(f.load_only or f.dump_only for k, f in (vars(schema).get("fields").items())):
        return f"{model_name}-{operation}"
    return f"{model_name}"


def _check_load_dump_only(field: ma.Field, operation: str) -> bool:
    """Return True if ``field`` takes part in ``operation``.

    Load-only fields are excluded when dumping; dump-only fields are excluded
    when loading.

    Raises:
        ValueError: if ``operation`` is neither "load" nor "dump".
    """
    if operation not in ("dump", "load"):
        raise ValueError(
            f"Invalid operation: {operation}. Options are 'load' and 'dump'."
        )
    excluded = field.load_only if operation == "dump" else field.dump_only
    return not excluded


def make_type_mapper(field_type):
    """Wrap a flask-restx field class in a converter with the `type_map` signature.

    The returned callable accepts ``(val, api, model_name, operation)``. It ignores
    everything but ``val`` and instantiates ``field_type`` with the keyword
    arguments derived from the marshmallow field's options.
    """

    def _convert(val, api, model_name, operation):
        field_kwargs = _ma_field_to_fr_field(val)
        return field_type(**field_kwargs)

    return _convert


# Marshmallow fields that correspond directly to a flask-restx field class. The
# values are wrapped with `make_type_mapper` below so every entry of the final
# `type_map` shares the (val, api, model_name, operation) converter signature.
type_map = {
    ma.AwareDateTime: fr.DateTime,
    ma.Bool: fr.Boolean,
    ma.Boolean: fr.Boolean,
    ma.Constant: fr.Raw,
    ma.Date: fr.Date,
    ma.DateTime: fr.DateTime,
    # For some reason, fr.Decimal has no example parameter, so use Float instead
    ma.Decimal: fr.Float,
    ma.Dict: fr.Raw,
    ma.Email: fr.String,
    ma.Float: fr.Float,
    ma.Function: fr.Raw,
    ma.Int: fr.Integer,
    ma.Integer: fr.Integer,
    ma.Mapping: fr.Raw,
    ma.Method: fr.Raw,
    ma.NaiveDateTime: fr.DateTime,
    ma.Number: fr.Float,
    ma.Pluck: fr.Raw,
    ma.Raw: fr.Raw,
    ma.Str: fr.String,
    ma.String: fr.String,
    # flask-restx has no time-of-day field; fr.DateTime would advertise a full
    # "date-time" and cannot format a `datetime.time`, so document as a string.
    ma.Time: fr.String,
    ma.Url: fr.Url,
    ma.URL: fr.Url,
    ma.UUID: fr.String,
}


type_map = {k: make_type_mapper(v) for k, v in type_map.items()}

# Add in the special cases: containers and schemas need recursive conversion.
type_map.update(
    {
        ma.List: unpack_list,
        ma.Nested: unpack_nested,
        Schema: for_swagger,
        SchemaMeta: for_swagger,
    }
)

# Counter used by get_default_model_name to number anonymous models.
num_default_models = 0


def get_default_model_name(schema: Optional[Union[Schema, Type[Schema]]] = None) -> str:
    """Derive a flask-restx model name.

    For a schema class or instance, the last occurrence of "Schema" is removed
    from the class name (e.g. ``UserSchema`` -> ``User``). Without a schema, a
    unique ``DefaultResponseModel_<n>`` name is produced by advancing the
    module-level ``num_default_models`` counter.
    """
    global num_default_models

    if not schema:
        generated = f"DefaultResponseModel_{num_default_models}"
        num_default_models += 1
        return generated

    schema_cls = schema.__class__ if isinstance(schema, Schema) else schema
    head, _, tail = schema_cls.__name__.rpartition("Schema")
    return head + tail


def _ma_field_to_fr_field(value: ma.Field) -> dict:
    """Translate marshmallow field options into flask-restx field keyword args.

    Mapping:
        ``default``                  -> ``example``  (only when configured)
        ``required``                 -> ``required``
        ``metadata["description"]``  -> ``description``
        ``missing``                  -> ``default``  (only when configured)

    marshmallow marks "not configured" with its ``missing`` sentinel (an
    instance of ``ma.utils._Missing``); that sentinel is never forwarded, because
    it is not a meaningful example or default value.
    """
    fr_field_parameters = {}

    if hasattr(value, "default") and type(value.default) != ma.utils._Missing:
        fr_field_parameters["example"] = value.default

    if hasattr(value, "required"):
        fr_field_parameters["required"] = value.required

    if hasattr(value, "metadata") and "description" in value.metadata:
        fr_field_parameters["description"] = value.metadata["description"]

    if hasattr(value, "missing") and type(value.missing) != ma.utils._Missing:
        fr_field_parameters["default"] = value.missing

    return fr_field_parameters


def map_type(val, api, model_name, operation):
    """Convert a marshmallow field, Schema instance or Schema class via `type_map`.

    The converter is found by walking the MRO of ``type(val)``: an exact match
    wins, otherwise the nearest registered base class is used. This covers Schema
    instances (via ``Schema``), Schema classes and custom metaclasses (via
    ``SchemaMeta``), and user-defined subclasses of supported marshmallow fields.

    Args:
        val: The marshmallow field, Schema instance, or Schema class to convert.
        api (Namespace): Flask-restx namespace passed to the converter.
        model_name (str): Model name passed to the converter.
        operation (str): "dump" or "load", passed to the converter.

    Returns:
        Whatever the matching converter returns (a flask-restx field or model).

    Raises:
        TypeError: if no class in the MRO of ``type(val)`` has a converter.
    """
    value_type = type(val)

    for klass in value_type.__mro__:
        converter = type_map.get(klass)
        if converter is not None:
            return converter(val, api, model_name, operation)

    raise TypeError(
        "Unknown type for marshmallow model field was used. "
        f"Got {value_type!r}."
    )


# Marshmallow field type -> reqparse `type` converter. Field types not listed
# here fall back to `str` in ma_field_to_reqparse_argument.
type_map_ma_to_reqparse = {
    **dict.fromkeys((ma.Bool, ma.Boolean), inputs.boolean),
    **dict.fromkeys((ma.Int, ma.Integer), int),
    ma.Float: float,
}


def ma_field_to_reqparse_argument(value: ma.Field) -> dict:
    """Build keyword arguments for a request-parser argument from a marshmallow field.

    List fields use ``action="append"`` and the element type (``value.inner``);
    other fields use ``action="store"`` and their own type. The reqparse ``type``
    is looked up in ``type_map_ma_to_reqparse`` and defaults to ``str``. The
    field's ``required`` flag and ``metadata["description"]`` (as ``help``) are
    copied when present.
    """
    if is_list_field(value):
        element_type, action = type(value.inner), "append"
    else:
        element_type, action = type(value), "store"

    argument = {
        "action": action,
        "type": type_map_ma_to_reqparse.get(element_type, str),
    }

    if hasattr(value, "required"):
        argument["required"] = value.required

    if hasattr(value, "metadata") and "description" in value.metadata:
        argument["help"] = value.metadata["description"]

    return argument


def is_list_field(field):
    """Return ``True`` if ``field`` is a marshmallow or flask-restx List field."""
    return isinstance(field, (ma.List, fr.List))