diff --git a/udata/api_fields.py b/udata/api_fields.py index 02111f2af..a53fcef5c 100644 --- a/udata/api_fields.py +++ b/udata/api_fields.py @@ -37,18 +37,20 @@ """ import functools -from typing import Dict, List +from typing import Any, Callable, Iterable import flask_restx.fields as restx_fields import mongoengine import mongoengine.fields as mongo_fields from bson import ObjectId from flask_restx.inputs import boolean +from flask_restx.reqparse import RequestParser from flask_storage.mongo import ImageField as FlaskStorageImageField import udata.api.fields as custom_restx_fields from udata.api import api, base_reference from udata.mongo.errors import FieldValidationError +from udata.mongo.queryset import DBPaginator, UDataQuerySet lazy_reference = api.model( "LazyReference", @@ -59,7 +61,7 @@ ) -def convert_db_to_field(key, field, info): +def convert_db_to_field(key, field, info) -> tuple[Callable | None, Callable | None]: """Map a Mongo field to a Flask RestX field. Most of the types are a simple 1-to-1 mapping except lists and references that requires @@ -71,15 +73,15 @@ def convert_db_to_field(key, field, info): user-supplied overrides, setting the readonly flag…), it's easier to have to do this only once at the end of the function. """ - params = {} + params: dict = {} params["required"] = field.required - read_params = {} - write_params = {} + read_params: dict = {} + write_params: dict = {} - constructor = None - constructor_read = None - constructor_write = None + constructor: Callable + constructor_read: Callable | None = None + constructor_write: Callable | None = None if info.get("convert_to"): # TODO: this is currently never used. We may remove it if the auto-conversion @@ -105,7 +107,7 @@ def convert_db_to_field(key, field, info): elif isinstance(field, mongo_fields.DictField): constructor = restx_fields.Raw elif isinstance(field, mongo_fields.ImageField) or isinstance(field, FlaskStorageImageField): - size = info.get("size", None) + size: int | None = info.get("size", None) if size: params["description"] = f"URL of the cropped and squared image ({size}x{size})" else: @@ -142,7 +144,7 @@ def constructor_read(**kwargs): # 1. `inner_field_info` inside `__additional_field_info__` on the parent # 2. `__additional_field_info__` of the inner field # 3. `__additional_field_info__` of the parent - inner_info = getattr(field.field, "__additional_field_info__", {}) + inner_info: dict = getattr(field.field, "__additional_field_info__", {}) field_read, field_write = convert_db_to_field( f"{key}.inner", field.field, {**info, **inner_info, **info.get("inner_field_info", {})} ) @@ -169,7 +171,7 @@ def constructor(**kwargs): # For reading, if the user supplied a `nested_fields` (RestX model), we use it to convert # the referenced model, if not we return a String (and RestX will call the `str()` of the model # when returning from an endpoint) - nested_fields = info.get("nested_fields") + nested_fields: dict | None = info.get("nested_fields") if nested_fields is None: # If there is no `nested_fields` convert the object to the string representation. constructor_read = restx_fields.String @@ -206,9 +208,11 @@ def constructor_write(**kwargs): read_params = {**params, **read_params, **info} write_params = {**params, **write_params, **info} - read = constructor_read(**read_params) if constructor_read else constructor(**read_params) + read: Callable = ( + constructor_read(**read_params) if constructor_read else constructor(**read_params) + ) if write_params.get("readonly", False) or (constructor_write is None and constructor is None): - write = None + write: Callable | None = None else: write = ( constructor_write(**write_params) if constructor_write else constructor(**write_params) @@ -216,7 +220,7 @@ def constructor_write(**kwargs): return read, write -def get_fields(cls): +def get_fields(cls) -> Iterable[tuple[str, Callable, dict]]: """Return all the document fields that are wrapped with the `field()` helper. Also expand image fields to add thumbnail fields. @@ -237,7 +241,7 @@ def get_fields(cls): ) -def generate_fields(**kwargs): +def generate_fields(**kwargs) -> Callable: """Mongoengine document decorator. This decorator will create two auto-generated attributes on the class `__read_fields__` and `__write_fields__` @@ -249,59 +253,60 @@ def generate_fields(**kwargs): """ - def wrapper(cls): - read_fields = {} - write_fields = {} - ref_fields = {} - sortables = kwargs.get("additional_sorts", []) + def wrapper(cls) -> Callable: + from udata.models import db + + read_fields: dict = {} + write_fields: dict = {} + ref_fields: dict = {} + sortables: list = kwargs.get("additional_sorts", []) - filterables = [] - additional_filters = get_fields_with_additional_filters( + filterables: list[dict] = [] + additional_filters: dict[str, dict] = get_fields_with_additional_filters( kwargs.get("additional_filters", []) ) read_fields["id"] = restx_fields.String(required=True, readonly=True) for key, field, info in get_fields(cls): - sortable_key = info.get("sortable", False) + sortable_key: bool = info.get("sortable", False) if sortable_key: - sortables.append( - { - "key": sortable_key if isinstance(sortable_key, str) else key, - "value": key, - } - ) - - filterable = info.get("filterable", None) + sortables.append({ + "key": sortable_key if isinstance(sortable_key, str) else key, + "value": key, + }) + filterable: dict[str, Any] | None = info.get("filterable", None) if filterable is not None: filterables.append(compute_filter(key, field, info, filterable)) - additional_filter = additional_filters.get(key, None) + additional_filter: dict | None = additional_filters.get(key, None) if additional_filter: if not isinstance( field, mongo_fields.ReferenceField | mongo_fields.LazyReferenceField ): raise Exception("Cannot use additional_filters on not a ref.") - ref_model = field.document_type + ref_model: db.Document = field.document_type for child in additional_filter.get("children", []): - inner_field = getattr(ref_model, child["key"]) + inner_field: str = getattr(ref_model, child["key"]) - column = f"{key}__{child['key']}" + column: str = f"{key}__{child['key']}" child["key"] = f"{key}_{child['key']}" filterable = compute_filter(column, inner_field, info, child) # Since MongoDB is not capable of doing joins with a column like `organization__slug` we need to # do a custom filter by splitting the query in two. - def query(filterable, query, value): + def query(filterable, query, value) -> UDataQuerySet: # We use the computed `filterable["column"]` here because the `compute_filter` function # could have add default filter at the end (for example `organization__badges` converted # in `organization__badges__kind`) parts = filterable["column"].split("__", 1) - models = ref_model.objects.filter(**{parts[1]: value}).only("id") + models: UDataQuerySet = ref_model.objects.filter(**{parts[1]: value}).only( + "id" + ) return query.filter(**{f"{parts[0]}__in": models}) # do a query-based filter instead of a column based one @@ -335,8 +340,8 @@ def query(filterable, query, value): if not callable(method): continue - info = getattr(method, "__additional_field_info__", None) - if info is None: + additional_field_info = getattr(method, "__additional_field_info__", None) + if additional_field_info is None: continue def make_lambda(method): @@ -348,16 +353,16 @@ def make_lambda(method): return lambda o: method(o) read_fields[method_name] = restx_fields.String( - attribute=make_lambda(method), **{"readonly": True, **info} + attribute=make_lambda(method), **{"readonly": True, **additional_field_info} ) - if info.get("show_as_ref", False): + if additional_field_info.get("show_as_ref", False): ref_fields[key] = read_fields[method_name] cls.__read_fields__ = api.model(f"{cls.__name__} (read)", read_fields, **kwargs) cls.__write_fields__ = api.model(f"{cls.__name__} (write)", write_fields, **kwargs) cls.__ref_fields__ = api.inherit(f"{cls.__name__}Reference", base_reference, ref_fields) - mask = kwargs.pop("mask", None) + mask: str | None = kwargs.pop("mask", None) if mask is not None: mask = "data{{{0}}},*".format(mask) cls.__page_fields__ = api.model( @@ -368,8 +373,8 @@ def make_lambda(method): ) # Parser for index sort/filters - paginable = kwargs.get("paginable", True) - parser = api.parser() + paginable: bool = kwargs.get("paginable", True) + parser: RequestParser = api.parser() if paginable: parser.add_argument( @@ -380,7 +385,7 @@ def make_lambda(method): ) if sortables: - choices = [sortable["key"] for sortable in sortables] + [ + choices: list[str] = [sortable["key"] for sortable in sortables] + [ "-" + sortable["key"] for sortable in sortables ] parser.add_argument( @@ -391,7 +396,7 @@ def make_lambda(method): help="The field (and direction) on which sorting apply", ) - searchable = kwargs.pop("searchable", False) + searchable: bool = kwargs.pop("searchable", False) if searchable: parser.add_argument("q", type=str, location="args") @@ -405,18 +410,17 @@ def make_lambda(method): cls.__index_parser__ = parser - def apply_sort_filters_and_pagination(base_query): + def apply_sort_filters_and_pagination(base_query) -> DBPaginator: args = cls.__index_parser__.parse_args() if sortables and args["sort"]: - negate = args["sort"].startswith("-") - sort_key = args["sort"][1:] if negate else args["sort"] + negate: bool = args["sort"].startswith("-") + sort_key: str = args["sort"][1:] if negate else args["sort"] - sort_by = next( + sort_by: str | None = next( (sortable["value"] for sortable in sortables if sortable["key"] == sort_key), None, ) - if sort_by: if negate: sort_by = "-" + sort_by @@ -424,7 +428,7 @@ def apply_sort_filters_and_pagination(base_query): base_query = base_query.order_by(sort_by) if searchable and args.get("q"): - phrase_query = " ".join([f'"{elem}"' for elem in args["q"].split(" ")]) + phrase_query: str = " ".join([f'"{elem}"' for elem in args["q"].split(" ")]) base_query = base_query.search_text(phrase_query) for filterable in filterables: @@ -439,11 +443,9 @@ def apply_sort_filters_and_pagination(base_query): if query: base_query = filterable["query"](base_query, args[filterable["key"]]) else: - base_query = base_query.filter( - **{ - filterable["column"]: args[filterable["key"]], - } - ) + base_query = base_query.filter(**{ + filterable["column"]: args[filterable["key"]], + }) if paginable: base_query = base_query.paginate(args["page"], args["page_size"]) @@ -457,7 +459,7 @@ def apply_sort_filters_and_pagination(base_query): return wrapper -def function_field(**info): +def function_field(**info) -> Callable: def inner(func): func.__additional_field_info__ = info return func @@ -475,7 +477,7 @@ def field(inner, **kwargs): return inner -def patch(obj, request): +def patch(obj, request) -> type: """Patch the object with the data from the request. Only fields decorated with the `field()` decorator will be read (and not readonly). @@ -527,7 +529,7 @@ def patch(obj, request): return obj -def patch_and_save(obj, request): +def patch_and_save(obj, request) -> type: obj = patch(obj, request) try: @@ -595,7 +597,7 @@ def wrap_primary_key( ) -def get_fields_with_additional_filters(additional_filters: List[str]) -> Dict[str, any]: +def get_fields_with_additional_filters(additional_filters: list[str]) -> dict[str, dict]: """Filter on additional related fields. Right now we only support additional filters with a depth of two, eg "organization.badges". @@ -604,7 +606,7 @@ def get_fields_with_additional_filters(additional_filters: List[str]) -> Dict[st be able to compute them when we loop over all the fields (`title`, `organization`…) """ - results = {} + results: dict = {} for key in additional_filters: parts = key.split(".") if len(parts) == 2: @@ -614,19 +616,17 @@ def get_fields_with_additional_filters(additional_filters: List[str]) -> Dict[st if parent not in results: results[parent] = {"children": []} - results[parent]["children"].append( - { - "key": child, - "type": str, - } - ) + results[parent]["children"].append({ + "key": child, + "type": str, + }) else: raise Exception(f"Do not support `additional_filters` without two parts: {key}.") return results -def compute_filter(column: str, field, info, filterable): +def compute_filter(column: str, field, info, filterable) -> dict: # "key" is the param key in the URL if "key" not in filterable: filterable["key"] = column diff --git a/udata/tests/test_api_fields.py b/udata/tests/test_api_fields.py index b00ebe5a6..3fd494c7a 100644 --- a/udata/tests/test_api_fields.py +++ b/udata/tests/test_api_fields.py @@ -1,34 +1,36 @@ import factory import pytest -from flask_restx.reqparse import RequestParser +from flask_restx.reqparse import Argument, RequestParser from werkzeug.exceptions import BadRequest from udata.api_fields import field, function_field, generate_fields, patch, patch_and_save from udata.core.dataset.api_fields import dataset_fields from udata.core.organization import constants as org_constants from udata.core.organization.factories import OrganizationFactory +from udata.core.organization.models import Organization from udata.core.owned import Owned from udata.core.storages import default_image_basename, images from udata.factories import ModelFactory from udata.models import Badge, BadgeMixin, BadgesList, WithMetrics, db +from udata.mongo.queryset import DBPaginator from udata.utils import faker pytestmark = [ pytest.mark.usefixtures("clean_db"), ] -BIGGEST_IMAGE_SIZE = 500 +BIGGEST_IMAGE_SIZE: int = 500 -BADGES = { +BADGES: dict[str, str] = { "badge-1": "badge 1", "badge-2": "badge 2", } -URL_RAISE_ERROR = "/raise/validation/error" -URL_EXISTS_ERROR_MESSAGE = "Url exists" +URL_RAISE_ERROR: str = "/raise/validation/error" +URL_EXISTS_ERROR_MESSAGE: str = "Url exists" -def check_url(url=""): +def check_url(url: str = "") -> None: if url == URL_RAISE_ERROR: raise ValueError(URL_EXISTS_ERROR_MESSAGE) return @@ -113,20 +115,20 @@ class Fake(WithMetrics, FakeBadgeMixin, Owned, db.Document): db.DateTimeField(), ) - def __str__(self): + def __str__(self) -> str: return self.title or "" @function_field(description="Link to the API endpoint for this fake", show_as_ref=True) - def uri(self): + def uri(self) -> str: return "fake/foobar/endpoint/" - __metrics_keys__ = [ + __metrics_keys__: list[str] = [ "datasets", "followers", "views", ] - meta = { + meta: dict = { "indexes": [ "$title", ], @@ -145,14 +147,14 @@ class Meta: class IndexParserTest: - index_parser = Fake.__index_parser__ - index_parser_args = Fake.__index_parser__.args - index_parser_args_names = set([field.name for field in Fake.__index_parser__.args]) + index_parser: RequestParser = Fake.__index_parser__ + index_parser_args: list[Argument] = Fake.__index_parser__.args + index_parser_args_names: set[str] = set([field.name for field in Fake.__index_parser__.args]) - def test_index_parser(self): + def test_index_parser(self) -> None: assert type(self.index_parser) is RequestParser - def test_filterable_fields_in_parser(self): + def test_filterable_fields_in_parser(self) -> None: """All filterable fields should have a parser arg. The parser arg uses the `key` provided instead of the field name, so @@ -167,40 +169,47 @@ def test_filterable_fields_in_parser(self): ] ).issubset(self.index_parser_args_names) - def test_readonly_and_non_wrapped_fields_not_in_parser(self): + def test_readonly_and_non_wrapped_fields_not_in_parser(self) -> None: """Readonly fields() and non wrapped fields should not have a parser arg.""" for field_ in ["slug", "image_url"]: assert field_ not in self.index_parser_args_names - def test_filterable_fields_from_mixins_in_parser(self): + def test_filterable_fields_from_mixins_in_parser(self) -> None: """Filterable fields from mixins should have a parser arg.""" assert set(["owner", "organization"]).issubset(self.index_parser_args_names) - def test_additional_filters_in_parser(self): + def test_additional_filters_in_parser(self) -> None: """Filterable fields from the `additional_filters` decorater parameter should have a parser arg.""" assert "organization_badges" in self.index_parser_args_names - def test_pagination_fields_in_parser(self): + def test_pagination_fields_in_parser(self) -> None: """Pagination fields should have a parser arg.""" assert "page" in self.index_parser_args_names assert "page_size" in self.index_parser_args_names - def test_searchable(self): + def test_searchable(self) -> None: """Searchable documents have a `q` parser arg.""" assert "q" in self.index_parser_args_names - def test_sortable_fields_in_parser(self): + def test_sortable_fields_in_parser(self) -> None: """Sortable fields are listed in the `sort` parser arg choices.""" - sort_arg = next(arg for arg in self.index_parser_args if arg.name == "sort") - choices = sort_arg.choices + sort_arg: Argument = next(arg for arg in self.index_parser_args if arg.name == "sort") + choices: list[str] = sort_arg.choices assert "title" in choices assert "-title" in choices - def test_additional_sorts_in_parser(self): + def test_additional_sorts_in_parser(self) -> None: """Additional sorts are listed in the `sort` parser arg choices.""" - sort_arg = next(arg for arg in self.index_parser_args if arg.name == "sort") - choices = sort_arg.choices - additional_sorts = ["datasets", "-datasets", "followers", "-followers", "views", "-views"] + sort_arg: Argument = next(arg for arg in self.index_parser_args if arg.name == "sort") + choices: list[str] = sort_arg.choices + additional_sorts: list[str] = [ + "datasets", + "-datasets", + "followers", + "-followers", + "views", + "-views", + ] assert set(additional_sorts).issubset(set(choices)) @@ -208,13 +217,13 @@ class PatchTest: class FakeRequest: json = {"url": URL_RAISE_ERROR, "description": None} - def test_patch_check(self): - fake = FakeFactory.create() + def test_patch_check(self) -> None: + fake: Fake = FakeFactory.create() with pytest.raises(ValueError, match=URL_EXISTS_ERROR_MESSAGE): patch(fake, self.FakeRequest()) - def test_patch_and_save(self): - fake = FakeFactory.create() + def test_patch_and_save(self) -> None: + fake: Fake = FakeFactory.create() fake_request = self.FakeRequest() fake_request.json["url"] = "ok url" with pytest.raises(BadRequest): @@ -222,56 +231,56 @@ def test_patch_and_save(self): class ApplySortAndFiltersTest: - def test_filterable_field(self, app): + def test_filterable_field(self, app) -> None: """A filterable field filters the results.""" - fake1 = FakeFactory(filter_field="test filter") - fake2 = FakeFactory(filter_field="some other filter") + fake1: Fake = FakeFactory(filter_field="test filter") + fake2: Fake = FakeFactory(filter_field="some other filter") with app.test_request_context("/foobar", query_string={"filter_field_name": "test filter"}): - results = Fake.apply_sort_filters_and_pagination(Fake.objects) + results: DBPaginator = Fake.apply_sort_filters_and_pagination(Fake.objects) assert fake1 in results assert fake2 not in results - def test_additional_filters(self, app): + def test_additional_filters(self, app) -> None: """Filtering on an additional filter filters the results.""" - org_public_service = OrganizationFactory() + org_public_service: Organization = OrganizationFactory() org_public_service.add_badge(org_constants.PUBLIC_SERVICE) - org_company = OrganizationFactory() + org_company: Organization = OrganizationFactory() org_company.add_badge(org_constants.COMPANY) - fake1 = FakeFactory(organization=org_public_service) - fake2 = FakeFactory(organization=org_company) + fake1: Fake = FakeFactory(organization=org_public_service) + fake2: Fake = FakeFactory(organization=org_company) with app.test_request_context( "/foobar", query_string={"organization_badges": org_constants.PUBLIC_SERVICE} ): - results = Fake.apply_sort_filters_and_pagination(Fake.objects) + results: DBPaginator = Fake.apply_sort_filters_and_pagination(Fake.objects) assert fake1 in results assert fake2 not in results - def test_searchable(self, app): + def test_searchable(self, app) -> None: """If `@generate_fields(searchable=True)`, then the document can be full-text searched.""" - fake1 = FakeFactory(title="foobar crux") - fake2 = FakeFactory(title="barbaz crux") + fake1: Fake = FakeFactory(title="foobar crux") + fake2: Fake = FakeFactory(title="barbaz crux") with app.test_request_context("/foobar", query_string={"q": "foobar"}): - results = Fake.apply_sort_filters_and_pagination(Fake.objects) + results: DBPaginator = Fake.apply_sort_filters_and_pagination(Fake.objects) assert fake1 in results assert fake2 not in results - def test_sortable_field(self, app): + def test_sortable_field(self, app) -> None: """A sortable field should sort the results.""" - fake1 = FakeFactory(title="abc") - fake2 = FakeFactory(title="def") + fake1: Fake = FakeFactory(title="abc") + fake2: Fake = FakeFactory(title="def") with app.test_request_context("/foobar", query_string={"sort": "title"}): - results = Fake.apply_sort_filters_and_pagination(Fake.objects) + results: DBPaginator = Fake.apply_sort_filters_and_pagination(Fake.objects) assert tuple(results) == (fake1, fake2) with app.test_request_context("/foobar", query_string={"sort": "-title"}): results = Fake.apply_sort_filters_and_pagination(Fake.objects) assert tuple(results) == (fake2, fake1) - def test_additional_sorts(self, app): + def test_additional_sorts(self, app) -> None: """Sorting on additional sort sorts the results.""" - fake1 = FakeFactory(metrics={"datasets": 1}) - fake2 = FakeFactory(metrics={"datasets": 2}) + fake1: Fake = FakeFactory(metrics={"datasets": 1}) + fake2: Fake = FakeFactory(metrics={"datasets": 2}) with app.test_request_context("/foobar", query_string={"sort": "datasets"}): - results = Fake.apply_sort_filters_and_pagination(Fake.objects) + results: DBPaginator = Fake.apply_sort_filters_and_pagination(Fake.objects) assert tuple(results) == (fake1, fake2) with app.test_request_context("/foobar", query_string={"sort": "-datasets"}): results = Fake.apply_sort_filters_and_pagination(Fake.objects)