Skip to content

Commit

Permalink
Type test_api_fields.py and partially api_fields.py
Browse files Browse the repository at this point in the history
  • Loading branch information
magopian committed Oct 28, 2024
1 parent 8539f8a commit 991b1d1
Show file tree
Hide file tree
Showing 2 changed files with 120 additions and 105 deletions.
110 changes: 58 additions & 52 deletions udata/api_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,18 +37,20 @@
"""

import functools
from typing import Dict, List
from typing import Any, Callable, Iterable

import flask_restx.fields as restx_fields
import mongoengine
import mongoengine.fields as mongo_fields
from bson import ObjectId
from flask_restx.inputs import boolean
from flask_restx.reqparse import RequestParser
from flask_storage.mongo import ImageField as FlaskStorageImageField

import udata.api.fields as custom_restx_fields
from udata.api import api, base_reference
from udata.mongo.errors import FieldValidationError
from udata.mongo.queryset import DBPaginator, UDataQuerySet

lazy_reference = api.model(
"LazyReference",
Expand All @@ -59,7 +61,7 @@
)


def convert_db_to_field(key, field, info):
def convert_db_to_field(key, field, info) -> tuple[Callable | None, Callable | None]:
"""Map a Mongo field to a Flask RestX field.
Most of the types are a simple 1-to-1 mapping except lists and references that requires
Expand All @@ -71,15 +73,15 @@ def convert_db_to_field(key, field, info):
user-supplied overrides, setting the readonly flag…), it's easier to have to do this only once at the end of the function.
"""
params = {}
params: dict = {}
params["required"] = field.required

read_params = {}
write_params = {}
read_params: dict = {}
write_params: dict = {}

constructor = None
constructor_read = None
constructor_write = None
constructor: Callable
constructor_read: Callable | None = None
constructor_write: Callable | None = None

if info.get("convert_to"):
# TODO: this is currently never used. We may remove it if the auto-conversion
Expand All @@ -105,7 +107,7 @@ def convert_db_to_field(key, field, info):
elif isinstance(field, mongo_fields.DictField):
constructor = restx_fields.Raw
elif isinstance(field, mongo_fields.ImageField) or isinstance(field, FlaskStorageImageField):
size = info.get("size", None)
size: int | None = info.get("size", None)
if size:
params["description"] = f"URL of the cropped and squared image ({size}x{size})"
else:
Expand Down Expand Up @@ -142,7 +144,7 @@ def constructor_read(**kwargs):
# 1. `inner_field_info` inside `__additional_field_info__` on the parent
# 2. `__additional_field_info__` of the inner field
# 3. `__additional_field_info__` of the parent
inner_info = getattr(field.field, "__additional_field_info__", {})
inner_info: dict = getattr(field.field, "__additional_field_info__", {})
field_read, field_write = convert_db_to_field(
f"{key}.inner", field.field, {**info, **inner_info, **info.get("inner_field_info", {})}
)
Expand All @@ -169,7 +171,7 @@ def constructor(**kwargs):
# For reading, if the user supplied a `nested_fields` (RestX model), we use it to convert
# the referenced model, if not we return a String (and RestX will call the `str()` of the model
# when returning from an endpoint)
nested_fields = info.get("nested_fields")
nested_fields: dict | None = info.get("nested_fields")
if nested_fields is None:
# If there is no `nested_fields` convert the object to the string representation.
constructor_read = restx_fields.String
Expand Down Expand Up @@ -206,17 +208,19 @@ def constructor_write(**kwargs):
read_params = {**params, **read_params, **info}
write_params = {**params, **write_params, **info}

read = constructor_read(**read_params) if constructor_read else constructor(**read_params)
read: Callable = (
constructor_read(**read_params) if constructor_read else constructor(**read_params)
)
if write_params.get("readonly", False) or (constructor_write is None and constructor is None):
write = None
write: Callable | None = None
else:
write = (
constructor_write(**write_params) if constructor_write else constructor(**write_params)
)
return read, write


def get_fields(cls):
def get_fields(cls) -> Iterable[tuple[str, Callable, dict]]:
"""Return all the document fields that are wrapped with the `field()` helper.
Also expand image fields to add thumbnail fields.
Expand All @@ -237,7 +241,7 @@ def get_fields(cls):
)


def generate_fields(**kwargs):
def generate_fields(**kwargs) -> Callable:
"""Mongoengine document decorator.
This decorator will create two auto-generated attributes on the class `__read_fields__` and `__write_fields__`
Expand All @@ -249,21 +253,23 @@ def generate_fields(**kwargs):
"""

def wrapper(cls):
read_fields = {}
write_fields = {}
ref_fields = {}
sortables = kwargs.get("additional_sorts", [])
def wrapper(cls) -> Callable:
from udata.models import db

read_fields: dict = {}
write_fields: dict = {}
ref_fields: dict = {}
sortables: list = kwargs.get("additional_sorts", [])

filterables = []
additional_filters = get_fields_with_additional_filters(
filterables: list[dict] = []
additional_filters: dict[str, dict] = get_fields_with_additional_filters(
kwargs.get("additional_filters", [])
)

read_fields["id"] = restx_fields.String(required=True, readonly=True)

for key, field, info in get_fields(cls):
sortable_key = info.get("sortable", False)
sortable_key: bool = info.get("sortable", False)
if sortable_key:
sortables.append(
{
Expand All @@ -272,36 +278,37 @@ def wrapper(cls):
}
)

filterable = info.get("filterable", None)

filterable: dict[str, Any] | None = info.get("filterable", None)
if filterable is not None:
filterables.append(compute_filter(key, field, info, filterable))

additional_filter = additional_filters.get(key, None)
additional_filter: dict | None = additional_filters.get(key, None)
if additional_filter:
if not isinstance(
field, mongo_fields.ReferenceField | mongo_fields.LazyReferenceField
):
raise Exception("Cannot use additional_filters on not a ref.")

ref_model = field.document_type
ref_model: db.Document = field.document_type

for child in additional_filter.get("children", []):
inner_field = getattr(ref_model, child["key"])
inner_field: str = getattr(ref_model, child["key"])

column = f"{key}__{child['key']}"
column: str = f"{key}__{child['key']}"
child["key"] = f"{key}_{child['key']}"
filterable = compute_filter(column, inner_field, info, child)

# Since MongoDB is not capable of doing joins with a column like `organization__slug` we need to
# do a custom filter by splitting the query in two.

def query(filterable, query, value):
def query(filterable, query, value) -> UDataQuerySet:
# We use the computed `filterable["column"]` here because the `compute_filter` function
# could have add default filter at the end (for example `organization__badges` converted
# in `organization__badges__kind`)
parts = filterable["column"].split("__", 1)
models = ref_model.objects.filter(**{parts[1]: value}).only("id")
models: UDataQuerySet = ref_model.objects.filter(**{parts[1]: value}).only(
"id"
)
return query.filter(**{f"{parts[0]}__in": models})

# do a query-based filter instead of a column based one
Expand Down Expand Up @@ -335,8 +342,8 @@ def query(filterable, query, value):
if not callable(method):
continue

info = getattr(method, "__additional_field_info__", None)
if info is None:
additional_field_info = getattr(method, "__additional_field_info__", None)
if additional_field_info is None:
continue

def make_lambda(method):
Expand All @@ -348,16 +355,16 @@ def make_lambda(method):
return lambda o: method(o)

read_fields[method_name] = restx_fields.String(
attribute=make_lambda(method), **{"readonly": True, **info}
attribute=make_lambda(method), **{"readonly": True, **additional_field_info}
)
if info.get("show_as_ref", False):
if additional_field_info.get("show_as_ref", False):
ref_fields[key] = read_fields[method_name]

cls.__read_fields__ = api.model(f"{cls.__name__} (read)", read_fields, **kwargs)
cls.__write_fields__ = api.model(f"{cls.__name__} (write)", write_fields, **kwargs)
cls.__ref_fields__ = api.inherit(f"{cls.__name__}Reference", base_reference, ref_fields)

mask = kwargs.pop("mask", None)
mask: str | None = kwargs.pop("mask", None)
if mask is not None:
mask = "data{{{0}}},*".format(mask)
cls.__page_fields__ = api.model(
Expand All @@ -368,8 +375,8 @@ def make_lambda(method):
)

# Parser for index sort/filters
paginable = kwargs.get("paginable", True)
parser = api.parser()
paginable: bool = kwargs.get("paginable", True)
parser: RequestParser = api.parser()

if paginable:
parser.add_argument(
Expand All @@ -380,7 +387,7 @@ def make_lambda(method):
)

if sortables:
choices = [sortable["key"] for sortable in sortables] + [
choices: list[str] = [sortable["key"] for sortable in sortables] + [
"-" + sortable["key"] for sortable in sortables
]
parser.add_argument(
Expand All @@ -391,7 +398,7 @@ def make_lambda(method):
help="The field (and direction) on which sorting apply",
)

searchable = kwargs.pop("searchable", False)
searchable: bool = kwargs.pop("searchable", False)
if searchable:
parser.add_argument("q", type=str, location="args")

Expand All @@ -405,26 +412,25 @@ def make_lambda(method):

cls.__index_parser__ = parser

def apply_sort_filters_and_pagination(base_query):
def apply_sort_filters_and_pagination(base_query) -> DBPaginator:
args = cls.__index_parser__.parse_args()

if sortables and args["sort"]:
negate = args["sort"].startswith("-")
sort_key = args["sort"][1:] if negate else args["sort"]
negate: bool = args["sort"].startswith("-")
sort_key: str = args["sort"][1:] if negate else args["sort"]

sort_by = next(
sort_by: str | None = next(
(sortable["value"] for sortable in sortables if sortable["key"] == sort_key),
None,
)

if sort_by:
if negate:
sort_by = "-" + sort_by

base_query = base_query.order_by(sort_by)

if searchable and args.get("q"):
phrase_query = " ".join([f'"{elem}"' for elem in args["q"].split(" ")])
phrase_query: str = " ".join([f'"{elem}"' for elem in args["q"].split(" ")])
base_query = base_query.search_text(phrase_query)

for filterable in filterables:
Expand Down Expand Up @@ -457,7 +463,7 @@ def apply_sort_filters_and_pagination(base_query):
return wrapper


def function_field(**info):
def function_field(**info) -> Callable:
def inner(func):
func.__additional_field_info__ = info
return func
Expand All @@ -475,7 +481,7 @@ def field(inner, **kwargs):
return inner


def patch(obj, request):
def patch(obj, request) -> type:
"""Patch the object with the data from the request.
Only fields decorated with the `field()` decorator will be read (and not readonly).
Expand Down Expand Up @@ -527,7 +533,7 @@ def patch(obj, request):
return obj


def patch_and_save(obj, request):
def patch_and_save(obj, request) -> type:
obj = patch(obj, request)

try:
Expand Down Expand Up @@ -595,7 +601,7 @@ def wrap_primary_key(
)


def get_fields_with_additional_filters(additional_filters: List[str]) -> Dict[str, any]:
def get_fields_with_additional_filters(additional_filters: list[str]) -> dict[str, dict]:
"""Filter on additional related fields.
Right now we only support additional filters with a depth of two, eg "organization.badges".
Expand All @@ -604,7 +610,7 @@ def get_fields_with_additional_filters(additional_filters: List[str]) -> Dict[st
be able to compute them when we loop over all the fields (`title`, `organization`…)
"""
results = {}
results: dict = {}
for key in additional_filters:
parts = key.split(".")
if len(parts) == 2:
Expand All @@ -626,7 +632,7 @@ def get_fields_with_additional_filters(additional_filters: List[str]) -> Dict[st
return results


def compute_filter(column: str, field, info, filterable):
def compute_filter(column: str, field, info, filterable) -> dict:
# "key" is the param key in the URL
if "key" not in filterable:
filterable["key"] = column
Expand Down
Loading

0 comments on commit 991b1d1

Please sign in to comment.