import inspect
import typing
import unittest
from enum import Enum
from typing import Dict, Optional, Union, Any, List, Tuple

from marshmallow import fields, Schema

from marshmallow_dataclass import field_for_schema, dataclass


class TestFieldForSchema(unittest.TestCase):
    """Check the marshmallow field that ``field_for_schema`` builds per type.

    Each test calls ``field_for_schema`` with a Python type (plus optional
    ``default`` / ``metadata`` / ``base_schema`` arguments) and compares the
    result against a hand-built marshmallow field.
    """

    def assertFieldsEqual(self, a: fields.Field, b: fields.Field):
        """Assert that two fields have the same class and public attributes.

        Attributes are read from each instance's ``__dict__``; names starting
        with an underscore are ignored. Values are compared through their
        ``repr``, and class-valued attributes are rendered together with
        their MRO.
        """
        self.assertEqual(a.__class__, b.__class__, "field class")

        def attrs(x):
            # Compare by repr so that nested field instances, which are
            # distinct objects on each side, are matched by how they render.
            return {
                k: f"{v!r} ({v.__mro__!r})" if inspect.isclass(v) else repr(v)
                for k, v in x.__dict__.items()
                if not k.startswith("_")
            }

        self.assertEqual(attrs(a), attrs(b))

    def test_int(self):
        """``int`` with a default yields an Integer field carrying it."""
        self.assertFieldsEqual(
            field_for_schema(int, default=9, metadata=dict(required=False)),
            fields.Integer(default=9, missing=9, required=False),
        )

    def test_any(self):
        """``Any`` yields a required Raw field that allows ``None``."""
        self.assertFieldsEqual(
            field_for_schema(Any), fields.Raw(required=True, allow_none=True)
        )

    def test_dict_from_typing(self):
        """``Dict[str, float]`` yields a Dict with String keys, Float values."""
        self.assertFieldsEqual(
            field_for_schema(Dict[str, float]),
            fields.Dict(
                keys=fields.String(required=True),
                values=fields.Float(required=True),
                required=True,
            ),
        )

    def test_builtin_dict(self):
        """Bare ``dict`` yields a Dict whose keys and values are Raw fields."""
        self.assertFieldsEqual(
            field_for_schema(dict),
            fields.Dict(
                keys=fields.Raw(required=True, allow_none=True),
                values=fields.Raw(required=True, allow_none=True),
                required=True,
            ),
        )

    def test_builtin_list(self):
        """Bare ``list`` yields a List of Raw fields; metadata is applied."""
        self.assertFieldsEqual(
            field_for_schema(list, metadata=dict(required=False)),
            fields.List(fields.Raw(required=True, allow_none=True), required=False),
        )

    def test_explicit_field(self):
        """A ``marshmallow_field`` metadata entry is matched as the result."""
        explicit_field = fields.Url(required=True)
        self.assertFieldsEqual(
            field_for_schema(str, metadata={"marshmallow_field": explicit_field}),
            explicit_field,
        )

    def test_str(self):
        """``str`` yields a required String field."""
        self.assertFieldsEqual(field_for_schema(str), fields.String(required=True))

    def test_optional_str(self):
        """``Optional[str]`` yields an optional String defaulting to ``None``."""
        self.assertFieldsEqual(
            field_for_schema(Optional[str]),
            fields.String(allow_none=True, required=False, default=None, missing=None),
        )

    def test_enum(self):
        """An ``Enum`` subclass yields a required ``EnumField`` for it."""
        import marshmallow_enum

        class Color(Enum):
            # Members must be assigned with ``=``; a bare annotation such as
            # ``RED: 1`` creates no member and leaves the enum empty.
            RED = 1
            GREEN = 2
            BLUE = 3

        self.assertFieldsEqual(
            field_for_schema(Color),
            marshmallow_enum.EnumField(enum=Color, required=True),
        )

    def test_union(self):
        """``Union[int, str]`` yields a ``marshmallow_union.Union`` field."""
        import marshmallow_union

        self.assertFieldsEqual(
            field_for_schema(Union[int, str]),
            marshmallow_union.Union(
                fields=[fields.Integer(), fields.String()], required=True
            ),
        )

    def test_newtype(self):
        """A ``NewType`` over ``int`` yields an Integer described by its name."""
        self.assertFieldsEqual(
            field_for_schema(typing.NewType("UserId", int), default=0),
            fields.Integer(required=False, description="UserId", default=0, missing=0),
        )

    def test_marshmallow_dataclass(self):
        """A marshmallow dataclass yields a Nested field over its ``Schema``."""

        class NewSchema(Schema):
            pass

        @dataclass(base_schema=NewSchema)
        class NewDataclass:
            pass

        self.assertFieldsEqual(
            field_for_schema(NewDataclass, metadata=dict(required=False)),
            fields.Nested(NewDataclass.Schema),
        )

    def test_override_container_type_with_type_mapping(self):
        """A base schema's ``TYPE_MAPPING`` can replace container fields.

        For each of ``List``, ``Dict`` and ``Tuple``: without a base schema
        the stock marshmallow field is produced, and with a base schema that
        maps the container type to a subclass, that subclass is produced.
        """
        type_mapping = [
            (List, fields.List, List[int]),
            (Dict, fields.Dict, Dict[str, int]),
            (Tuple, fields.Tuple, Tuple[int, str, bytes]),
        ]
        for base_type, marshmallow_field, schema in type_mapping:
            # subTest reports which container failed and lets the remaining
            # containers still be checked.
            with self.subTest(base_type=base_type):

                class MyType(marshmallow_field):
                    ...

                self.assertIsInstance(field_for_schema(schema), marshmallow_field)

                class BaseSchema(Schema):
                    TYPE_MAPPING = {base_type: MyType}

                self.assertIsInstance(
                    field_for_schema(schema, base_schema=BaseSchema), MyType
                )


# Run this module's tests with unittest's command-line runner when the file
# is executed directly as a script.
if __name__ == "__main__":
    unittest.main()