import marshmallow as m import marshmallow.fields as f from marshmallow_oneofschema import OneOfSchema import pytest REQUIRED_ERROR = "Missing data for required field." class Foo: def __init__(self, value=None): self.value = value def __repr__(self): return "<Foo value=%s>" % self.value def __eq__(self, other): return isinstance(other, self.__class__) and self.value == other.value class FooSchema(m.Schema): value = f.String(required=True) @m.post_load def make_foo(self, data, **kwargs): return Foo(**data) class Bar: def __init__(self, value=None): self.value = value def __repr__(self): return "<Bar value=%s>" % self.value def __eq__(self, other): return isinstance(other, self.__class__) and self.value == other.value class BarSchema(m.Schema): value = f.Integer(required=True) @m.post_load def make_bar(self, data, **kwargs): return Bar(**data) class Baz: def __init__(self, value1=None, value2=None): self.value1 = value1 self.value2 = value2 def __repr__(self): return "<Bar value1={} value2={}>".format(self.value1, self.value2) def __eq__(self, other): return ( isinstance(other, self.__class__) and self.value1 == other.value1 and self.value2 == other.value2 ) class BazSchema(m.Schema): value1 = f.Integer(required=True) value2 = f.String(required=True) @m.post_load def make_baz(self, data, **kwargs): return Baz(**data) class Empty: pass class EmptySchema(m.Schema): @m.post_load def make_empty(self, data, **kwargs): return Empty(**data) class MySchema(OneOfSchema): type_schemas = { "Foo": FooSchema, "Bar": BarSchema, "Baz": BazSchema, "Empty": EmptySchema, } class TestOneOfSchema: def test_dump(self): foo_result = MySchema().dump(Foo("hello")) assert {"type": "Foo", "value": "hello"} == foo_result bar_result = MySchema().dump(Bar(123)) assert {"type": "Bar", "value": 123} == bar_result def test_dump_many(self): result = MySchema().dump([Foo("hello"), Bar(123)], many=True) assert [ {"type": "Foo", "value": "hello"}, {"type": "Bar", "value": 123}, ] == result def test_dump_many_in_constructor(self): result = MySchema(many=True).dump([Foo("hello"), Bar(123)]) assert [ {"type": "Foo", "value": "hello"}, {"type": "Bar", "value": 123}, ] == result def test_dump_with_empty_keeps_type(self): result = MySchema().dump(Empty()) assert {"type": "Empty"} == result def test_load(self): foo_result = MySchema().load({"type": "Foo", "value": "world"}) assert Foo("world") == foo_result bar_result = MySchema().load({"type": "Bar", "value": 456}) assert Bar(456) == bar_result def test_load_many(self): result = MySchema().load( [{"type": "Foo", "value": "hello world!"}, {"type": "Bar", "value": 123}], many=True, ) assert Foo("hello world!"), Bar(123) == result def test_load_many_in_constructor(self): result = MySchema(many=True).load( [{"type": "Foo", "value": "hello world!"}, {"type": "Bar", "value": 123}] ) assert Foo("hello world!"), Bar(123) == result def test_load_removes_type_field(self): class Nonlocal: data = None class MySchema(m.Schema): def load(self, data, *args, **kwargs): Nonlocal.data = data return super().load(data, *args, **kwargs) class FooSchema(MySchema): foo = f.String(required=True) class BarSchema(MySchema): bar = f.Integer(required=True) class TestSchema(OneOfSchema): type_schemas = {"Foo": FooSchema, "Bar": BarSchema} TestSchema().load({"type": "Foo", "foo": "hello"}) assert "type" not in Nonlocal.data TestSchema().load({"type": "Bar", "bar": 123}) assert "type" not in Nonlocal.data def test_load_keeps_type_field(self): class Nonlocal: data = None type = None class MySchema(m.Schema): def load(self, data, *args, **kwargs): Nonlocal.data = data return super().load(data, *args, **kwargs) class FooSchema(MySchema): foo = f.String(required=True) class BarSchema(MySchema): bar = f.Integer(required=True) class TestSchema(OneOfSchema): type_field_remove = False type_schemas = {"Foo": FooSchema, "Bar": BarSchema} TestSchema(unknown="exclude").load({"type": "Foo", "foo": "hello"}) assert Nonlocal.data["type"] == "Foo" TestSchema(unknown="exclude").load({"type": "Bar", "bar": 123}) assert Nonlocal.data["type"] == "Bar" def test_load_non_dict(self): with pytest.raises(m.ValidationError) as exc_info: MySchema().load(123) assert {} != exc_info.value with pytest.raises(m.ValidationError) as exc_info: MySchema().load("Foo") assert {} != exc_info.value def test_load_errors_no_type(self): with pytest.raises(m.ValidationError) as exc_info: MySchema().load({"value": "Foo"}) assert {"type": [REQUIRED_ERROR]} == exc_info.value.messages def test_load_errors_field_error(self): with pytest.raises(m.ValidationError) as exc_info: MySchema().load({"type": "Foo"}) assert {"value": [REQUIRED_ERROR]} == exc_info.value.messages def test_load_errors_strict(self): with pytest.raises(m.ValidationError) as exc_info: MySchema().load({"type": "Foo"}) assert { "value": ["Missing data for required field."] } == exc_info.value.messages def test_load_many_errors_are_indexed_by_object_position(self): with pytest.raises(m.ValidationError) as exc_info: MySchema().load([{"type": "Foo"}, {"type": "Bar", "value": 123}], many=True) assert {0: {"value": [REQUIRED_ERROR]}} == exc_info.value.messages def test_load_many_errors_strict(self): with pytest.raises(m.ValidationError) as exc_info: MySchema().load( [ {"type": "Foo", "value": "hello world!"}, {"type": "Foo"}, {"type": "Bar", "value": 123}, {"type": "Bar", "value": "hello"}, ], many=True, ) assert { 1: {"value": ["Missing data for required field."]}, 3: {"value": ["Not a valid integer."]}, } == exc_info.value.messages def test_load_partial_specific(self): result = MySchema().load({"type": "Foo"}, partial=("value", "value2")) assert Foo() == result result = MySchema().load( {"type": "Baz", "value1": 123}, partial=("value", "value2") ) assert Baz(value1=123) == result def test_load_partial_any(self): result = MySchema().load({"type": "Foo"}, partial=True) assert Foo() == result result = MySchema().load({"type": "Baz", "value1": 123}, partial=True) assert Baz(value1=123) == result result = MySchema().load({"type": "Baz", "value2": "hello"}, partial=True) assert Baz(value2="hello") == result def test_load_partial_specific_in_constructor(self): result = MySchema(partial=("value", "value2")).load({"type": "Foo"}) assert Foo() == result result = MySchema(partial=("value", "value2")).load( {"type": "Baz", "value1": 123} ) assert Baz(value1=123) == result def test_load_partial_any_in_constructor(self): result = MySchema(partial=True).load({"type": "Foo"}) assert Foo() == result result = MySchema(partial=True).load({"type": "Baz", "value1": 123}) assert Baz(value1=123) == result result = MySchema(partial=True).load({"type": "Baz", "value2": "hello"}) assert Baz(value2="hello") == result def test_validate(self): assert {} == MySchema().validate({"type": "Foo", "value": "123"}) assert {"value": [REQUIRED_ERROR]} == MySchema().validate({"type": "Bar"}) assert {"value": [REQUIRED_ERROR]} == MySchema().validate({"type": "Bar"}) def test_validate_many(self): errors = MySchema().validate( [{"type": "Foo", "value": "123"}, {"type": "Bar", "value": 123}], many=True ) assert {} == errors errors = MySchema().validate([{"value": "123"}, {"type": "Bar"}], many=True) assert {0: {"type": [REQUIRED_ERROR]}, 1: {"value": [REQUIRED_ERROR]}} == errors errors = MySchema().validate([{"value": "123"}, {"type": "Bar"}], many=True) assert {0: {"type": [REQUIRED_ERROR]}, 1: {"value": [REQUIRED_ERROR]}} == errors def test_validate_many_in_constructor(self): errors = MySchema(many=True).validate( [{"type": "Foo", "value": "123"}, {"type": "Bar", "value": 123}] ) assert {} == errors errors = MySchema(many=True).validate([{"value": "123"}, {"type": "Bar"}]) assert {0: {"type": [REQUIRED_ERROR]}, 1: {"value": [REQUIRED_ERROR]}} == errors def test_validate_partial_specific(self): errors = MySchema().validate({"type": "Foo"}, partial=("value", "value2")) assert {} == errors errors = MySchema().validate( {"type": "Baz", "value1": 123}, partial=("value", "value2") ) assert {} == errors def test_validate_partial_any(self): errors = MySchema().validate({"type": "Foo"}, partial=True) assert {} == errors errors = MySchema().validate({"type": "Baz", "value1": 123}, partial=True) assert {} == errors errors = MySchema().validate({"type": "Baz", "value2": "hello"}, partial=True) assert {} == errors def test_validate_partial_specific_in_constructor(self): errors = MySchema(partial=("value", "value2")).validate({"type": "Foo"}) assert {} == errors errors = MySchema(partial=("value", "value2")).validate( {"type": "Baz", "value1": 123} ) assert {} == errors def test_validate_partial_any_in_constructor(self): errors = MySchema(partial=True).validate({"type": "Foo"}) assert {} == errors errors = MySchema(partial=True).validate({"type": "Baz", "value1": 123}) assert {} == errors errors = MySchema(partial=True).validate({"type": "Baz", "value2": "hello"}) assert {} == errors def test_using_as_nested_schema(self): class SchemaWithList(m.Schema): items = f.List(f.Nested(MySchema)) schema = SchemaWithList() result = schema.load( { "items": [ {"type": "Foo", "value": "hello world!"}, {"type": "Bar", "value": 123}, ] } ) assert {"items": [Foo("hello world!"), Bar(123)]} == result with pytest.raises(m.ValidationError) as exc_info: schema.load( {"items": [{"type": "Foo", "value": "hello world!"}, {"value": 123}]} ) assert {"items": {1: {"type": [REQUIRED_ERROR]}}} == exc_info.value.messages def test_using_as_nested_schema_with_many(self): class SchemaWithMany(m.Schema): items = f.Nested(MySchema, many=True) schema = SchemaWithMany() result = schema.load( { "items": [ {"type": "Foo", "value": "hello world!"}, {"type": "Bar", "value": 123}, ] } ) assert {"items": [Foo("hello world!"), Bar(123)]} == result with pytest.raises(m.ValidationError) as exc_info: schema.load( {"items": [{"type": "Foo", "value": "hello world!"}, {"value": 123}]} ) assert {"items": {1: {"type": [REQUIRED_ERROR]}}} == exc_info.value.messages def test_using_custom_type_names(self): class MyCustomTypeNameSchema(OneOfSchema): type_schemas = {"baz": FooSchema, "bam": BarSchema} def get_obj_type(self, obj): return {"Foo": "baz", "Bar": "bam"}.get(obj.__class__.__name__) schema = MyCustomTypeNameSchema() data = [Foo("hello"), Bar(111)] marshalled = schema.dump(data, many=True) assert [ {"type": "baz", "value": "hello"}, {"type": "bam", "value": 111}, ] == marshalled unmarshalled = schema.load(marshalled, many=True) assert data == unmarshalled def test_using_custom_type_field(self): class MyCustomTypeFieldSchema(MySchema): type_field = "object_type" schema = MyCustomTypeFieldSchema() data = [Foo("hello"), Bar(111)] marshalled = schema.dump(data, many=True) assert [ {"object_type": "Foo", "value": "hello"}, {"object_type": "Bar", "value": 111}, ] == marshalled unmarshalled = schema.load(marshalled, many=True) assert data == unmarshalled