from datetime import date, datetime, time, timedelta
from decimal import Decimal
from typing import Dict, List, Tuple, Union
from uuid import UUID

from marshmallow import fields
from marshmallow.base import FieldABC, SchemaABC

from ._compat import _get_base
from .base import AbstractConverter, ConfigOptions, FieldFactory, TypeRegistry
from .exceptions import AnnotationConversionError


def _is_generic(typehint: type) -> bool:
    # this *could* be isinstance(typehint, (Generic, _GenericAlias)) but
    # this works out better given that __origin__ isn't likely to go away
    # the way _GenericAlias might
    return getattr(typehint, "__origin__", None) is not None


def field_factory(field: FieldABC) -> FieldFactory:
    """
    Wrap a marshmallow field in a field factory.

    The factory that comes back disregards the converter and subtypes it
    receives and just calls ``field`` with the given options as keyword
    arguments. It is named ``<field.__name__>FieldFactory``.
    """
    factory_name = f"{field.__name__}FieldFactory"

    def _(
        converter: AbstractConverter, subtypes: Tuple[type], opts: ConfigOptions
    ) -> FieldABC:
        # only the options matter for a plain field
        return field(**opts)

    _.__name__ = factory_name
    return _


def scheme_factory(scheme_name: Union[str, SchemaABC]) -> FieldFactory:
    """
    Maps a scheme or scheme name into a field factory

    :param scheme_name: Either the name of a scheme or the scheme itself; it
        is handed to ``fields.Nested`` unchanged.
    :return: A field factory that builds ``fields.Nested`` for the scheme
        using the supplied options. The factory is flagged with
        ``__is_scheme__ = True`` and named ``<scheme>FieldFactory``.
    """
    # A scheme class (or instance) interpolated directly into the f-string
    # would yield its repr, e.g. "<class '...'>FieldFactory", so derive a
    # readable label for anything that is not already a string.
    if isinstance(scheme_name, str):
        label = scheme_name
    else:
        label = getattr(scheme_name, "__name__", None) or type(scheme_name).__name__

    def _(
        converter: AbstractConverter, subtypes: Tuple[type], opts: ConfigOptions
    ) -> FieldABC:
        return fields.Nested(scheme_name, **opts)

    _.__name__ = f"{label}FieldFactory"
    _.__is_scheme__ = True  # type: ignore
    return _


def _list_converter(
    converter: AbstractConverter, subtypes: Tuple[type], opts: ConfigOptions
) -> FieldABC:
    if converter.is_scheme(subtypes[0]):
        opts["many"] = True
        return converter.convert(subtypes[0], opts)
    sub_opts = opts.pop("_interior", {})
    return fields.List(converter.convert(subtypes[0], sub_opts), **opts)


class DefaultTypeRegistry(TypeRegistry):
    """
    Stock implementation of :class:`~marshmallow_annotations.base.TypeRegistry`.

    Out of the box the following types are mapped to marshmallow fields:

    - bool -> fields.Boolean
    - date -> fields.Date
    - datetime -> fields.DateTime
    - Decimal -> fields.Decimal
    - float -> fields.Float
    - int -> fields.Integer
    - str -> fields.String
    - time -> fields.Time
    - timedelta -> fields.TimeDelta
    - UUID -> fields.UUID
    - dict -> fields.Dict
    - typing.Dict -> fields.Dict

    In addition, typing.List[T] (and list) is served by a special factory
    that will generate either fields.List or fields.Nested
    """

    # Class-level defaults; every instance merges these into its own dict.
    _registry = dict(
        (hint, field_factory(field_cls))
        for hint, field_cls in (
            (bool, fields.Boolean),
            (date, fields.Date),
            (datetime, fields.DateTime),
            (Decimal, fields.Decimal),
            (float, fields.Float),
            (int, fields.Integer),
            (str, fields.String),
            (time, fields.Time),
            (timedelta, fields.TimeDelta),
            (UUID, fields.UUID),
            (dict, fields.Dict),
            (Dict, fields.Dict),
        )
    )

    # py36, py37 compatibility: typing.List and the builtin list are both
    # registered, out of practicality
    _registry.update({List: _list_converter, list: _list_converter})

    def __init__(self, registry: Dict[type, FieldFactory] = None) -> None:
        """
        Build a registry from the class defaults plus optional overrides.

        Entries in ``registry`` win over the defaults for the same type.
        """
        overrides = registry if registry is not None else {}
        # a fresh dict per instance keeps the class-level defaults untouched
        self._registry = {**self._registry, **overrides}

    def register(self, target: type, constructor: FieldFactory) -> None:
        """
        Associate ``target`` with ``constructor``, replacing any prior entry.
        """
        self._registry[target] = constructor

    def get(self, target: type) -> FieldFactory:
        """
        Look up the field factory registered for ``target``.

        When there is no entry for ``target`` itself and it is a generic,
        the entry for its base (as given by ``_get_base``) is tried.

        :raises AnnotationConversionError: if neither lookup finds a factory
        """
        found = self._registry.get(target)
        if found is not None:
            return found

        if _is_generic(target):
            found = self._registry.get(_get_base(target))

        if found is None:
            raise AnnotationConversionError(f"No field factory found for {target!r}")
        return found

    def register_field_for_type(self, target: type, field: FieldABC) -> None:
        """
        Register ``field`` for ``target`` by wrapping it with ``field_factory``.
        """
        factory = field_factory(field)
        self.register(target, factory)

    def register_scheme_factory(
        self, target: type, scheme_or_name: Union[str, SchemaABC]
    ) -> None:
        """
        Register a scheme (or scheme name) for ``target`` by wrapping it with
        ``scheme_factory``.
        """
        factory = scheme_factory(scheme_or_name)
        self.register(target, factory)

    def has(self, target: type) -> bool:
        """
        Tell whether ``target`` itself has an entry.

        Unlike :meth:`get`, no fallback to a generic's base is attempted.
        """
        return target in self._registry


# Module-level registry instance; constructed without overrides, so it starts
# out holding only the default mappings.
registry = DefaultTypeRegistry()