import pytest from dataclasses import dataclass from unittest.mock import patch, Mock from marshmallow import Schema, fields as ma from marshmallow.schema import SchemaMeta from flask import Flask from flask_restx import Api, fields as fr, namespace # from .utils import unpack_list, unpack_nested import flask_accepts.utils as utils def test_unpack_list(): app = Flask(__name__) api = Api(app) with patch("flask_accepts.utils.unpack_list", wraps=utils.unpack_list) as mock: result = utils.unpack_list(ma.List(ma.Integer()), api=api) assert isinstance(result, fr.List) assert mock.call_count == 1 def test_unpack_list_of_list(): app = Flask(__name__) api = Api(app) with patch( "flask_accepts.utils.unpack_list", wraps=utils.unpack_list ) as mock, patch.dict("flask_accepts.utils.type_map", {ma.List: mock}): result = utils.unpack_list(ma.List(ma.List(ma.Integer())), api=api) assert isinstance(result, fr.List) assert mock.call_count == 2 def test_unpack_nested(): app = Flask(__name__) api = Api(app) class IntegerSchema(Schema): my_int: ma.Integer() result = utils.unpack_nested(ma.Nested(IntegerSchema), api=api) assert result def test_unpack_nested_self(): app = Flask(__name__) api = Api(app) class IntegerSchema(Schema): my_int = ma.Integer() children = ma.Nested("self", exclude=["children"]) schema = IntegerSchema() result = utils.unpack_nested(schema.fields.get("children"), api=api) assert type(result) == fr.Nested def test_unpack_nested_self_many(): app = Flask(__name__) api = Api(app) class IntegerSchema(Schema): my_int = ma.Integer() children = ma.Nested("self", exclude=["children"], many=True) schema = IntegerSchema() result = utils.unpack_nested(schema.fields.get("children"), api=api) assert type(result) == fr.List def test_get_default_model_name(): from .utils import get_default_model_name class TestSchema(Schema): pass result = get_default_model_name(TestSchema) expected = "Test" assert result == expected def test_get_default_model_name_works_with_multiple_schema_in_name(): from .utils import get_default_model_name class TestSchemaSchema(Schema): pass result = get_default_model_name(TestSchemaSchema) expected = "TestSchema" assert result == expected def test_get_default_model_name_that_does_not_end_in_schema(): from .utils import get_default_model_name class SomeOtherName(Schema): pass result = get_default_model_name(SomeOtherName) expected = "SomeOtherName" assert result == expected def test_get_default_model_name_default_names(): from .utils import get_default_model_name, num_default_models for model_num in range(5): result = get_default_model_name() expected = f"DefaultResponseModel_{model_num + num_default_models}" assert result == expected def test__check_load_dump_only_on_dump(): @dataclass class FakeField: load_only: bool dump_only: bool assert not utils._check_load_dump_only( FakeField(load_only=True, dump_only=False), "dump" ) assert utils._check_load_dump_only( FakeField(load_only=False, dump_only=True), "dump" ) def test__check_load_dump_only_on_load(): @dataclass class FakeField: load_only: bool dump_only: bool assert utils._check_load_dump_only( FakeField(load_only=True, dump_only=False), "load" ) assert not utils._check_load_dump_only( FakeField(load_only=False, dump_only=True), "load" ) def test__check_load_dump_only_raises_on_invalid_operation(): @dataclass class FakeField: load_only: bool dump_only: bool with pytest.raises(ValueError): utils._check_load_dump_only( FakeField(load_only=True, dump_only=False), "not an operation" ) def test__ma_field_to_fr_field_converts_required_param_if_present(): @dataclass class FakeFieldWithRequired(ma.Field): required: bool fr_field_dict = utils._ma_field_to_fr_field(FakeFieldWithRequired(required=True)) assert fr_field_dict["required"] is True @dataclass class FakeFieldNoRequired(ma.Field): pass fr_field_dict = utils._ma_field_to_fr_field(FakeFieldNoRequired()) assert "required" not in fr_field_dict def test__ma_field_to_fr_field_converts_missing_param_to_default_if_present(): @dataclass class FakeFieldWithMissing(ma.Field): missing: bool fr_field_dict = utils._ma_field_to_fr_field(FakeFieldWithMissing(missing=True)) assert fr_field_dict["default"] is True @dataclass class FakeFieldNoMissing(ma.Field): pass fr_field_dict = utils._ma_field_to_fr_field(FakeFieldNoMissing()) assert "default" not in fr_field_dict def test__ma_field_to_fr_field_converts_metadata_param_to_description_if_present(): @dataclass class FakeFieldWithDescription(ma.Field): metadata: dict expected_description = "test" fr_field_dict = utils._ma_field_to_fr_field( FakeFieldWithDescription(metadata={"description": expected_description}) ) assert fr_field_dict["description"] == expected_description @dataclass class FakeFieldNoMetaData(ma.Field): pass fr_field_dict = utils._ma_field_to_fr_field(FakeFieldNoMetaData()) assert "description" not in fr_field_dict @dataclass class FakeFieldNoDescription(ma.Field): metadata: dict fr_field_dict = utils._ma_field_to_fr_field(FakeFieldNoDescription(metadata={})) assert "description" not in fr_field_dict def test__ma_field_to_fr_field_converts_default_to_example_if_present(): @dataclass class FakeFieldWithDefault(ma.Field): default: str expected_example_value = "test" fr_field_dict = utils._ma_field_to_fr_field( FakeFieldWithDefault(default=expected_example_value) ) assert fr_field_dict["example"] == expected_example_value @dataclass class FakeFieldNoDefault(ma.Field): pass fr_field_dict = utils._ma_field_to_fr_field(FakeFieldNoDefault()) assert "example" not in fr_field_dict def test__ma_field_to_fr_field_returns_empty_dict_for_no_params_present_in_ma_field(): @dataclass class FakeFieldWithNoParams(ma.Field): pass fr_field_dict = utils._ma_field_to_fr_field(FakeFieldWithNoParams()) assert not fr_field_dict def test_make_type_mapper_works_with_required(): from flask_accepts.utils import make_type_mapper app = Flask(__name__) api = Api(app) mapper = make_type_mapper(fr.Raw) result = mapper(ma.Raw(required=True), api=api, model_name="test_model_name", operation="load") assert result.required def test_make_type_mapper_produces_nonrequired_param_by_default(): from flask_accepts.utils import make_type_mapper app = Flask(__name__) api = Api(app) mapper = make_type_mapper(fr.Raw) result = mapper(ma.Raw(), api=api, model_name="test_model_name", operation="load") assert not result.required def test__maybe_add_operation_passes_through_if_no_load_only(): from flask_accepts.utils import _maybe_add_operation class TestSchema(Schema): _id = ma.Integer() model_name = "TestSchema" operation = "load" result = _maybe_add_operation(TestSchema(), model_name, operation) expected = model_name assert result == expected def test__maybe_add_operation_append_if_load_only(): from flask_accepts.utils import _maybe_add_operation class TestSchema(Schema): _id = ma.Integer(load_only=True) model_name = "TestSchema" operation = "load" result = _maybe_add_operation(TestSchema(), model_name, operation) expected = f"{model_name}-load" assert result == expected def test__maybe_add_operation_passes_through_if_no_dump_only(): from flask_accepts.utils import _maybe_add_operation class TestSchema(Schema): _id = ma.Integer() model_name = "TestSchema" operation = "dump" result = _maybe_add_operation(TestSchema(), model_name, operation) expected = model_name assert result == expected def test__maybe_add_operation_append_if_dump_only(): from flask_accepts.utils import _maybe_add_operation class TestSchema(Schema): _id = ma.Integer(dump_only=True) model_name = "TestSchema" operation = "dump" result = _maybe_add_operation(TestSchema(), model_name, operation) expected = f"{model_name}-dump" assert result == expected def test_map_type_calls_type_map_dict_function_for_known_type_with_correct_parameters(): expected_ma_field = ma.Float expected_model_name, expected_operation, expected_namespace = _get_type_mapper_default_params() float_type_mapper = Mock() type_map_mock = { type(expected_ma_field): float_type_mapper } type_map_patch = patch.object(utils, "type_map", new=type_map_mock) with type_map_patch: utils.map_type(expected_ma_field, expected_namespace, expected_model_name, expected_operation) float_type_mapper.assert_called_with( expected_ma_field, expected_namespace, expected_model_name, expected_operation ) def test_map_type_calls_type_map_dict_function_for_schema_instance(): class MarshmallowSchema(Schema): test_field: ma.Float expected_ma_field = MarshmallowSchema() expected_model_name, expected_operation, expected_namespace = _get_type_mapper_default_params() schema_type_mapper_mock = Mock() type_map_mock = dict(utils.type_map) type_map_mock[Schema] = schema_type_mapper_mock type_map_patch = patch.object(utils, "type_map", new=type_map_mock) with type_map_patch: utils.map_type(expected_ma_field, expected_namespace, expected_model_name, expected_operation) schema_type_mapper_mock.assert_called_with( expected_ma_field, expected_namespace, expected_model_name, expected_operation ) def test_map_type_calls_type_map_dict_function_for_schema_class(): class InheritedMeta(SchemaMeta): pass class MarshmallowSchema(Schema, metaclass=InheritedMeta): test_field: ma.Float expected_ma_field = MarshmallowSchema expected_model_name, expected_operation, expected_namespace = _get_type_mapper_default_params() schema_type_mapper_mock = Mock() type_map_mock = dict(utils.type_map) type_map_mock[Schema] = schema_type_mapper_mock type_map_patch = patch.object(utils, "type_map", new=type_map_mock) with type_map_patch: utils.map_type(expected_ma_field, expected_namespace, expected_model_name, expected_operation) schema_type_mapper_mock.assert_called_with( expected_ma_field, expected_namespace, expected_model_name, expected_operation ) def test_map_type_raises_error_for_unknown_type(): class UnknownType: test_field: ma.Float unknown_ma_field = UnknownType expected_model_name, expected_operation, expected_namespace = _get_type_mapper_default_params() with pytest.raises(TypeError): utils.map_type(unknown_ma_field, expected_namespace, expected_model_name, expected_operation) def test_map_type_dump_ma_method_returns_fr_raw(): class TestSchema(Schema): method_field = ma.Method() TestApi = Api() expected_method_field_mapping = fr.Raw map_result = utils.map_type(TestSchema, TestApi, 'TestSchema','dump') assert isinstance(map_result['method_field'], expected_method_field_mapping) def _get_type_mapper_default_params(): return "test-model", "test-operation", namespace.Namespace("test-ns") def test_ma_field_to_reqparse_argument_single_values(): # Test a simple integer. result = utils.ma_field_to_reqparse_argument(ma.Integer(required=True)) assert result["type"] is int assert result["required"] is True assert result["action"] == "store" assert "help" not in result # Test that complex fields default to string. result = utils.ma_field_to_reqparse_argument(ma.Email(required=True, description="A description")) assert result["type"] is str assert result["required"] is True assert result["action"] == "store" assert result["help"] == "A description" def test_ma_field_to_reqparse_argument_list_values(): result = utils.ma_field_to_reqparse_argument(ma.List(ma.Integer())) assert result["type"] is int assert result["required"] is False assert result["action"] == "append" assert "help" not in result result = utils.ma_field_to_reqparse_argument(ma.List(ma.String(), description="A description")) assert result["type"] is str assert result["required"] is False assert result["action"] == "append" assert result["help"] == "A description"