Skip to content

Commit

Permalink
Filter by org badge complex
Browse files Browse the repository at this point in the history
  • Loading branch information
ThibaudDauce committed Oct 3, 2024
1 parent 52b7d2b commit 48ede80
Show file tree
Hide file tree
Showing 6 changed files with 215 additions and 149 deletions.
181 changes: 119 additions & 62 deletions udata/api_fields.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
import functools
from typing import Dict, List

import flask_restx.fields as restx_fields
import mongoengine
import mongoengine.fields as mongo_fields
Expand Down Expand Up @@ -183,8 +186,11 @@ def wrapper(cls):
write_fields = {}
ref_fields = {}
sortables = kwargs.get("additional_sorts", [])
related_filters = kwargs.get("related_filters", [])

filterables = []
additional_filters = get_fields_with_additional_filters(
kwargs.get("additional_filters", [])
)

read_fields["id"] = restx_fields.String(required=True, readonly=True)

Expand All @@ -200,32 +206,39 @@ def wrapper(cls):

filterable = info.get("filterable", None)
if filterable is not None:
if "key" not in filterable:
filterable["key"] = key
if "column" not in filterable:
filterable["column"] = key
filterables.append(compute_filter(key, field, info, filterable))

additional_filter = additional_filters.get(key, None)
if additional_filter:
if not isinstance(field, mongo_fields.ReferenceField):
raise Exception("Cannot use additional_filters on not a ref.")

ref_model = field.document_type

for child in additional_filter.get("children", []):
inner_field = getattr(ref_model, child["key"])

if "constraints" not in filterable:
filterable["constraints"] = []
if isinstance(field, mongo_fields.ReferenceField) or (
isinstance(field, mongo_fields.ListField)
and isinstance(field.field, mongo_fields.ReferenceField)
):
filterable["constraints"].append("objectid")
column = f"{key}__{child['key']}"
child["key"] = f"{key}_{child['key']}"
filterable = compute_filter(column, inner_field, info, child)

if "type" not in filterable:
filterable["type"] = str
if isinstance(field, mongo_fields.BooleanField):
filterable["type"] = boolean
# Since MongoDB is not capable of doing joins with a column like `organization__slug` we need to
# do a custom filter by splitting the query in two.

filterable["choices"] = info.get("choices", None)
if hasattr(field, "choices") and field.choices:
filterable["choices"] = field.choices
def query(filterable, query, value):
# We use the computed `filterable["column"]` here because the `compute_filter` function
# could have add default filter at the end (for example `organization__badges` converted
# in `organization__badges__kind`)
parts = filterable["column"].split("__", 1)
print(parts)
models = ref_model.objects.filter(**{parts[1]: value}).only("id")
return query.filter(**{f"{parts[0]}__in": models})

# We may add more information later here:
# - type of mongo query to execute (right now only simple =)
# do a query-based filter instead of a column based one
filterable["query"] = functools.partial(query, filterable)

filterables.append(filterable)
print(filterable)
filterables.append(filterable)

read, write = convert_db_to_field(key, field, info)

Expand Down Expand Up @@ -318,15 +331,7 @@ def make_lambda(method):
filterable["key"],
type=filterable["type"],
location="args",
choices=filterable["choices"],
)

for related_filter in related_filters:
parser.add_argument(
related_filter["key"],
type=related_filter.get("type", str),
location="args",
choices=related_filter.get("choices"),
choices=filterable.get("choices", None),
)

cls.__index_parser__ = parser
Expand Down Expand Up @@ -355,49 +360,29 @@ def apply_sort_filters_and_pagination(base_query):

for filterable in filterables:
if args.get(filterable["key"]) is not None:
for constraint in filterable["constraints"]:
for constraint in filterable.get("constraints", []):
if constraint == "objectid" and not ObjectId.is_valid(
args[filterable["key"]]
):
api.abort(400, f'`{filterable["key"]}` must be an identifier')

base_query = base_query.filter(
**{
filterable["column"]: args[filterable["key"]],
}
)

for related_filter in related_filters:
# This allows to define a query like so:
# related_filters=[
# {
# "key": "organization_badge",
# "lookup": "organization__in",
# "object": Organization,
# "queryset": "with_badge",
# "choices": list(Organization.__badges__),
# },
# ]
# This will return reuses with an organization that have the badge provided in the
# `organization_badge` parameter:
# - referenced_object: Organization
# - queryset: Organization.objects.with_badge
# - filtered_objects: Organization.objects.with_badge(<value of the parameter `organization_badge`>)
# - Reuse.objects.filter(organization__in=list(filtered_objects))
if args.get(related_filter["key"]) is not None:
referenced_object = related_filter["object"]
queryset = getattr(referenced_object.objects, related_filter["queryset"])
filtered_objects = queryset(args[related_filter["key"]]).only("id")
base_query = base_query.filter(
**{related_filter["lookup"]: list(filtered_objects)}
)
query = filterable.get("query", None)
if query:
base_query = filterable["query"](base_query, args[filterable["key"]])
else:
base_query = base_query.filter(
**{
filterable["column"]: args[filterable["key"]],
}
)

if paginable:
base_query = base_query.paginate(args["page"], args["page_size"])

return base_query

cls.apply_sort_filters_and_pagination = apply_sort_filters_and_pagination
cls.__additional_class_info__ = kwargs
return cls

return wrapper
Expand Down Expand Up @@ -533,3 +518,75 @@ def wrap_primary_key(
raise ValueError(
f"Unknown ID field type {id_field.__class__} for {document_type} (ID field name is {id_field_name}, value was {value})"
)


def get_fields_with_additional_filters(additional_filters: List[str]) -> Dict[str, any]:
"""
Right now we only support additional filters like "organization.badges"
The goal of this function is to keyby the additional filters by the first part (`organization`) to
be able to compute them when we loop over all the fields (`title`, `organization`…)
"""
results = {}
for key in additional_filters:
parts = key.split(".")
if len(parts) == 2:
parent = parts[0]
child = parts[1]

if parent not in results:
results[parent] = {"children": []}

results[parent]["children"].append(
{
"key": child,
"type": str,
}
)
else:
raise Exception(f"Do not support `additional_filters` without two parts: {key}.")

return results


def compute_filter(column: str, field, info, filterable):
# "key" is the param key in the URL
if "key" not in filterable:
filterable["key"] = column

# If we do a filter on a embed document, get the class info
# of this document to see if there is a default filter value
embed_info = None
if isinstance(field, mongo_fields.EmbeddedDocumentField):
embed_info = field.get("__additional_class_info__", None)
elif isinstance(field, mongo_fields.EmbeddedDocumentListField):
embed_info = getattr(field.field.document_type, "__additional_class_info__", None)

if embed_info and embed_info.get("default_filterable_field", None):
# There is a default filterable field so append it to the column and replace the
# field to use the inner one (for example using the `kind` `StringField` instead of
# the embed `Badge` field.)
filterable["column"] = f"{column}__{embed_info['default_filterable_field']}"
field = getattr(field.field.document_type, embed_info["default_filterable_field"])
else:
filterable["column"] = column

if "constraints" not in filterable:
filterable["constraints"] = []

if isinstance(field, mongo_fields.ReferenceField) or (
isinstance(field, mongo_fields.ListField)
and isinstance(field.field, mongo_fields.ReferenceField)
):
filterable["constraints"].append("objectid")

if "type" not in filterable:
if isinstance(field, mongo_fields.BooleanField):
filterable["type"] = boolean
else:
filterable["type"] = str

filterable["choices"] = info.get("choices", None)
if hasattr(field, "choices") and field.choices:
filterable["choices"] = field.choices

return filterable
140 changes: 75 additions & 65 deletions udata/core/badges/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from mongoengine.signals import post_save

from udata.api_fields import field
from udata.api_fields import field, generate_fields
from udata.auth import current_user
from udata.core.badges.fields import badge_fields
from udata.mongo import db
Expand All @@ -12,27 +12,23 @@

log = logging.getLogger(__name__)

__all__ = ("Badge", "BadgeMixin")

def new_badge_type(choices):
@generate_fields(default_filterable_field="kind")
class Badge(db.EmbeddedDocument):
kind = db.StringField(required=True, choices=choices)
created = db.DateTimeField(default=datetime.utcnow, required=True)
created_by = db.ReferenceField("User")

class Badge(db.EmbeddedDocument):
kind = db.StringField(required=True)
created = db.DateTimeField(default=datetime.utcnow, required=True)
created_by = db.ReferenceField("User")
def __str__(self):
return self.kind

def __str__(self):
return self.kind

def validate(self, clean=True):
badges = getattr(self._instance, "__badges__", {})
if self.kind not in badges.keys():
raise db.ValidationError("Unknown badge type %s" % self.kind)
return super(Badge, self).validate(clean=clean)
return Badge


class BadgesList(db.EmbeddedDocumentListField):
def __init__(self, *args, **kwargs):
return super(BadgesList, self).__init__(Badge, *args, **kwargs)
def __init__(self, badge_type, *args, **kwargs):
return super(BadgesList, self).__init__(badge_type, *args, **kwargs)

def validate(self, value):
kinds = [b.kind for b in value]
Expand All @@ -41,52 +37,66 @@ def validate(self, value):
return super(BadgesList, self).validate(value)


class BadgeMixin(object):
badges = field(
BadgesList(),
readonly=True,
inner_field_info={"nested_fields": badge_fields},
)

def get_badge(self, kind):
"""Get a badge given its kind if present"""
candidates = [b for b in self.badges if b.kind == kind]
return candidates[0] if candidates else None

def add_badge(self, kind):
"""Perform an atomic prepend for a new badge"""
badge = self.get_badge(kind)
if badge:
return badge
if kind not in getattr(self, "__badges__", {}):
msg = "Unknown badge type for {model}: {kind}"
raise db.ValidationError(msg.format(model=self.__class__.__name__, kind=kind))
badge = Badge(kind=kind)
if current_user.is_authenticated:
badge.created_by = current_user.id

self.update(__raw__={"$push": {"badges": {"$each": [badge.to_mongo()], "$position": 0}}})
self.reload()
post_save.send(self.__class__, document=self)
on_badge_added.send(self, kind=kind)
return self.get_badge(kind)

def remove_badge(self, kind):
"""Perform an atomic removal for a given badge"""
self.update(__raw__={"$pull": {"badges": {"kind": kind}}})
self.reload()
on_badge_removed.send(self, kind=kind)
post_save.send(self.__class__, document=self)

def toggle_badge(self, kind):
"""Toggle a bdage given its kind"""
badge = self.get_badge(kind)
if badge:
return self.remove_badge(kind)
else:
return self.add_badge(kind)

def badge_label(self, badge):
"""Display the badge label for a given kind"""
kind = badge.kind if isinstance(badge, Badge) else badge
return self.__badges__[kind]
def badge_mixin(choices):
badge_type = new_badge_type(choices)

class BadgeMixin(object):
badges = field(
BadgesList(badge_type),
readonly=True,
inner_field_info={"nested_fields": badge_fields},
)

def get_badge(self, kind):
"""Get a badge given its kind if present"""
candidates = [b for b in self.badges if b.kind == kind]
return candidates[0] if candidates else None

def add_badge(self, kind):
"""Perform an atomic prepend for a new badge"""
badge = self.get_badge(kind)
if badge:
return badge
if kind not in getattr(self, "__badges__", {}):
msg = "Unknown badge type for {model}: {kind}"
raise db.ValidationError(msg.format(model=self.__class__.__name__, kind=kind))
badge = self.badge_type(kind=kind)
if current_user.is_authenticated:
badge.created_by = current_user.id

self.update(
__raw__={"$push": {"badges": {"$each": [badge.to_mongo()], "$position": 0}}}
)
self.reload()
post_save.send(self.__class__, document=self)
on_badge_added.send(self, kind=kind)
return self.get_badge(kind)

def remove_badge(self, kind):
"""Perform an atomic removal for a given badge"""
self.update(__raw__={"$pull": {"badges": {"kind": kind}}})
self.reload()
on_badge_removed.send(self, kind=kind)
post_save.send(self.__class__, document=self)

def toggle_badge(self, kind):
"""Toggle a bdage given its kind"""
badge = self.get_badge(kind)
if badge:
return self.remove_badge(kind)
else:
return self.add_badge(kind)

def badge_label(self, badge):
"""Display the badge label for a given kind"""
kind = badge.kind if isinstance(badge, self.badge_type) else badge
return self.__badges__[kind]

setattr(BadgeMixin, "badge_type", badge_type)

return BadgeMixin


# For backward compatibility create a badge type without any choices
Badge = new_badge_type(None)
BadgeMixin = badge_mixin(None)
Loading

0 comments on commit 48ede80

Please sign in to comment.