Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use OrderedSet as default set_class #1896

Merged
merged 8 commits into from
Jul 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 0 additions & 31 deletions docs/quickstart.rst
Original file line number Diff line number Diff line change
Expand Up @@ -524,37 +524,6 @@ Note that ``name`` will be automatically formatted as a :class:`String <marshmal
# No need to include 'uppername'
additional = ("name", "email", "created_at")

Ordering Output
---------------

To maintain field ordering, set the ``ordered`` option to `True`. This will instruct marshmallow to serialize data to a `collections.OrderedDict`.

.. code-block:: python

from collections import OrderedDict
from pprint import pprint

from marshmallow import Schema, fields


class UserSchema(Schema):
first_name = fields.String()
last_name = fields.String()
email = fields.Email()

class Meta:
ordered = True


u = User("Charlie", "Stones", "charlie@stones.com")
schema = UserSchema()
result = schema.dump(u)
assert isinstance(result, OrderedDict)
pprint(result, indent=2)
#  OrderedDict([('first_name', 'Charlie'),
# ('last_name', 'Stones'),
# ('email', 'charlie@stones.com')])

Next Steps
----------

Expand Down
4 changes: 0 additions & 4 deletions src/marshmallow/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,6 @@ class Field(FieldABC):
# to exist as attributes on the objects to serialize. Set this to False
# for those fields
_CHECK_ATTRIBUTE = True
_creation_index = 0 # Used for sorting

#: Default error messages for various kinds of errors. The keys in this dictionary
#: are passed to `Field.make_error`. The values are error messages passed to
Expand Down Expand Up @@ -227,9 +226,6 @@ def __init__(
stacklevel=2,
)

self._creation_index = Field._creation_index
Field._creation_index += 1

# Collect default error message from self and parent classes
messages = {} # type: dict[str, str]
for cls in reversed(self.__class__.__mro__):
Expand Down
40 changes: 15 additions & 25 deletions src/marshmallow/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,25 +42,21 @@
_T = typing.TypeVar("_T")


def _get_fields(attrs, ordered=False):
"""Get fields from a class. If ordered=True, fields will sorted by creation index.
def _get_fields(attrs):
"""Get fields from a class

:param attrs: Mapping of class attributes
:param bool ordered: Sort fields by creation index
"""
fields = [
return [
(field_name, field_value)
for field_name, field_value in attrs.items()
if is_instance_or_subclass(field_value, base.FieldABC)
]
if ordered:
fields.sort(key=lambda pair: pair[1]._creation_index)
return fields


# This function allows Schemas to inherit from non-Schema classes and ensures
# inheritance according to the MRO
def _get_fields_by_mro(klass, ordered=False):
def _get_fields_by_mro(klass):
"""Collect fields from a class, following its method resolution order. The
class itself is excluded from the search; only its parents are checked. Get
fields from ``_declared_fields`` if available, else use ``__dict__``.
Expand All @@ -73,7 +69,6 @@ class itself is excluded from the search; only its parents are checked. Get
(
_get_fields(
getattr(base, "_declared_fields", base.__dict__),
ordered=ordered,
)
for base in mro[:0:-1]
),
Expand Down Expand Up @@ -102,13 +97,13 @@ def __new__(mcs, name, bases, attrs):
break
else:
ordered = False
cls_fields = _get_fields(attrs, ordered=ordered)
cls_fields = _get_fields(attrs)
# Remove fields from list of class attributes to avoid shadowing
# Schema attributes/methods in case of name conflict
for field_name, _ in cls_fields:
del attrs[field_name]
klass = super().__new__(mcs, name, bases, attrs)
inherited_fields = _get_fields_by_mro(klass, ordered=ordered)
inherited_fields = _get_fields_by_mro(klass)

meta = klass.Meta
# Set klass.opts in __new__ rather than __init__ so that it is accessible in
Expand All @@ -117,13 +112,11 @@ def __new__(mcs, name, bases, attrs):
# Add fields specified in the `include` class Meta option
cls_fields += list(klass.opts.include.items())

dict_cls = OrderedDict if ordered else dict
# Assign _declared_fields on class
klass._declared_fields = mcs.get_declared_fields(
klass=klass,
cls_fields=cls_fields,
inherited_fields=inherited_fields,
dict_cls=dict_cls,
)
return klass

Expand All @@ -133,7 +126,7 @@ def get_declared_fields(
klass: type,
cls_fields: list,
inherited_fields: list,
dict_cls: type,
dict_cls: type = dict,
):
"""Returns a dictionary of field_name => `Field` pairs declared on the class.
This is exposed mainly so that plugins can add additional fields, e.g. fields
Expand All @@ -143,8 +136,7 @@ def get_declared_fields(
:param cls_fields: The fields declared on the class, including those added
by the ``include`` class Meta option.
:param inherited_fields: Inherited fields.
:param dict_cls: Either `dict` or `OrderedDict`, depending on whether
the user specified `ordered=True`.
:param dict_cls: dict-like class to use for dict output. Defaults to ``dict``.
"""
return dict_cls(inherited_fields + cls_fields)

Expand Down Expand Up @@ -319,6 +311,8 @@ class AlbumSchema(Schema):

OPTIONS_CLASS = SchemaOpts # type: type

set_class = OrderedSet

# These get set by SchemaMeta
opts = None # type: SchemaOpts
_declared_fields = {} # type: typing.Dict[str, ma_fields.Field]
Expand Down Expand Up @@ -350,9 +344,7 @@ class Meta:
- ``timeformat``: Default format for `Time <fields.Time>` fields.
- ``render_module``: Module to use for `loads <Schema.loads>` and `dumps <Schema.dumps>`.
Defaults to `json` from the standard library.
- ``ordered``: If `True`, order serialization output according to the
order in which fields were declared. Output of `Schema.dump` will be a
`collections.OrderedDict`.
- ``ordered``: If `True`, output of `Schema.dump` will be a `collections.OrderedDict`.
- ``index_errors``: If `True`, errors dictionaries will include the index
of invalid items in a collection.
- ``load_only``: Tuple or list of fields to exclude from serialized results.
Expand Down Expand Up @@ -386,7 +378,9 @@ def __init__(
self.declared_fields = copy.deepcopy(self._declared_fields)
self.many = many
self.only = only
self.exclude = set(self.opts.exclude) | set(exclude)
self.exclude: set[typing.Any] | typing.MutableSet[typing.Any] = set(
self.opts.exclude
) | set(exclude)
self.ordered = self.opts.ordered
self.load_only = set(load_only) or set(self.opts.load_only)
self.dump_only = set(dump_only) or set(self.opts.dump_only)
Expand Down Expand Up @@ -419,10 +413,6 @@ def __repr__(self) -> str:
def dict_class(self) -> type:
return OrderedDict if self.ordered else dict

@property
def set_class(self) -> type:
return OrderedSet if self.ordered else set

@classmethod
def from_dict(
cls,
Expand Down Expand Up @@ -970,7 +960,7 @@ def _init_fields(self) -> None:

if self.only is not None:
# Return only fields specified in only option
field_names = self.set_class(self.only)
field_names: typing.AbstractSet[typing.Any] = self.set_class(self.only)

invalid_fields |= field_names - available_field_names
else:
Expand Down
2 changes: 1 addition & 1 deletion src/marshmallow/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,6 @@
"""
import typing

StrSequenceOrSet = typing.Union[typing.Sequence[str], typing.Set[str]]
StrSequenceOrSet = typing.Union[typing.Sequence[str], typing.AbstractSet[str]]
Tag = typing.Union[str, typing.Tuple[str, bool]]
Validator = typing.Callable[[typing.Any], typing.Any]
9 changes: 5 additions & 4 deletions tests/test_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
RAISE,
missing,
)
from marshmallow.orderedset import OrderedSet
from marshmallow.exceptions import StringNotCollectionError

from tests.base import ALL_FIELDS
Expand Down Expand Up @@ -380,14 +381,14 @@ class MySchema(Schema):
@pytest.mark.parametrize(
("param", "fields_list"), [("only", ["foo"]), ("exclude", ["bar"])]
)
def test_ordered_instanced_nested_schema_only_and_exclude(self, param, fields_list):
def test_nested_schema_only_and_exclude(self, param, fields_list):
class NestedSchema(Schema):
# This test targets OrderedSet specifically, so set it explicitly
# even though it is already the default.
set_class = OrderedSet
foo = fields.String()
bar = fields.String()

class Meta:
ordered = True

class MySchema(Schema):
nested = fields.Nested(NestedSchema(), **{param: fields_list})

Expand Down
6 changes: 1 addition & 5 deletions tests/test_options.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,9 +178,6 @@ def test_nested_field_order_with_only_arg_is_maintained_on_load(self):

def test_nested_field_order_with_exclude_arg_is_maintained(self, user):
class HasNestedExclude(Schema):
class Meta:
ordered = True

user = fields.Nested(KeepOrder, exclude=("birthdate",))

ser = HasNestedExclude()
Expand Down Expand Up @@ -231,7 +228,7 @@ def test_fields_are_added(self):
result = s.load({"name": "Steve", "from": "Oskosh"})
assert result == in_data

def test_ordered_included(self):
def test_included_fields_ordered_after_declared_fields(self):
class AddFieldsOrdered(Schema):
name = fields.Str()
email = fields.Str()
Expand All @@ -242,7 +239,6 @@ class Meta:
"in": fields.Str(),
"@at": fields.Str(),
}
ordered = True

s = AddFieldsOrdered()
in_data = {
Expand Down
13 changes: 13 additions & 0 deletions tests/test_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -2933,3 +2933,16 @@ class Meta:
MySchema(unknown="badval")
else:
MySchema().load({"foo": "bar"}, unknown="badval")


@pytest.mark.parametrize("dict_cls", (dict, OrderedDict))
def test_set_dict_class(dict_cls):
    """Show that ``dict_class`` can be set as a plain class attribute.

    The schema below assigns ``dict_class`` in its class body; the dumped
    result is expected to be an instance of that class and to compare
    equal to the input mapping.
    """

    class MySchema(Schema):
        foo = fields.String()
        # Output mapping type is chosen here rather than through ``Meta``.
        dict_class = dict_cls

    payload = {"foo": "bar"}
    dumped = MySchema().dump(payload)

    assert dumped == {"foo": "bar"}
    assert isinstance(dumped, dict_cls)