
Commit: wip

ThibaudDauce committed Aug 20, 2024
1 parent 0c4c1cb commit aabc831
Showing 5 changed files with 82 additions and 47 deletions.
83 changes: 54 additions & 29 deletions udata/api_fields.py
@@ -6,6 +6,7 @@
 import udata.api.fields as custom_restx_fields
 from udata.api import api, base_reference
+from udata.core.elasticsearch import is_elasticsearch_enable
 from udata.mongo.errors import FieldValidationError

 lazy_reference = api.model(
@@ -244,6 +245,9 @@ def wrapper(cls):
             if info is None:
                 continue

+            if not info.get("api", True):
+                continue
+
             def make_lambda(method):
                 """
                 Factory function to create a lambda with the correct scope.
@@ -308,43 +312,64 @@ def make_lambda(method):
         def apply_sort_filters_and_pagination(base_query):
             args = cls.__index_parser__.parse_args()

-            if sortables and args["sort"]:
-                negate = args["sort"].startswith("-")
-                sort_key = args["sort"][1:] if negate else args["sort"]
-
-                sort_by = next(
-                    (sortable["value"] for sortable in sortables if sortable["key"] == sort_key),
-                    None,
+            if (
+                args.get("q")
+                and is_elasticsearch_enable()
+                and getattr(cls, "__elasticsearch_search__", None) is not None
+            ):
+                # Do an Elasticsearch query
+                print(cls.__elasticsearch_search__(args.get("q")))
+                print(
+                    {
+                        "data": cls.__elasticsearch_search__(args.get("q")),
+                    }
                 )
+                return {
+                    "data": cls.__elasticsearch_search__(args.get("q")),
+                }
+            else:
+                # Do a regular MongoDB query
+                if sortables and args["sort"]:
+                    negate = args["sort"].startswith("-")
+                    sort_key = args["sort"][1:] if negate else args["sort"]
+
+                    sort_by = next(
+                        (
+                            sortable["value"]
+                            for sortable in sortables
+                            if sortable["key"] == sort_key
+                        ),
+                        None,
+                    )

-                if sort_by:
-                    if negate:
-                        sort_by = "-" + sort_by
+                    if sort_by:
+                        if negate:
+                            sort_by = "-" + sort_by

-                    base_query = base_query.order_by(sort_by)
+                        base_query = base_query.order_by(sort_by)

-            if searchable and args.get("q"):
-                phrase_query = " ".join([f'"{elem}"' for elem in args["q"].split(" ")])
-                base_query = base_query.search_text(phrase_query)
+                if searchable and args.get("q"):
+                    phrase_query = " ".join([f'"{elem}"' for elem in args["q"].split(" ")])
+                    base_query = base_query.search_text(phrase_query)

-            for filterable in filterables:
-                if args.get(filterable["key"]):
-                    for constraint in filterable["constraints"]:
-                        if constraint == "objectid" and not ObjectId.is_valid(
-                            args[filterable["key"]]
-                        ):
-                            api.abort(400, f'`{filterable["key"]}` must be an identifier')
+                for filterable in filterables:
+                    if args.get(filterable["key"]):
+                        for constraint in filterable["constraints"]:
+                            if constraint == "objectid" and not ObjectId.is_valid(
+                                args[filterable["key"]]
+                            ):
+                                api.abort(400, f'`{filterable["key"]}` must be an identifier')

-                    base_query = base_query.filter(
-                        **{
-                            filterable["column"]: args[filterable["key"]],
-                        }
-                    )
+                        base_query = base_query.filter(
+                            **{
+                                filterable["column"]: args[filterable["key"]],
+                            }
+                        )

-            if paginable:
-                base_query = base_query.paginate(args["page"], args["page_size"])
+                if paginable:
+                    base_query = base_query.paginate(args["page"], args["page_size"])

-            return base_query
+                return base_query

         cls.apply_sort_filters_and_pagination = apply_sort_filters_and_pagination
         return cls
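In short, the hunk above gives `apply_sort_filters_and_pagination` a second code path: when the request has a `q` argument, Elasticsearch is enabled, and the class carries an `__elasticsearch_search__` method (set by the `@elasticsearch` decorator), the index route returns a plain `{"data": ...}` payload and skips sorting, filtering and pagination; otherwise the existing MongoDB pipeline runs unchanged, just one indentation level deeper. A minimal sketch of that dispatch, using hypothetical stand-in callables rather than udata APIs:

```python
# Sketch only: mirrors the new branching; search_with_elasticsearch and
# build_mongo_query are hypothetical stand-ins, not udata APIs.
def list_items(args, elasticsearch_enabled, search_with_elasticsearch, build_mongo_query):
    if args.get("q") and elasticsearch_enabled and search_with_elasticsearch is not None:
        # Full-text path: bare payload, no sort/filter/pagination applied.
        return {"data": search_with_elasticsearch(args["q"])}
    # Fallback path: the usual MongoDB sort/filter/paginate pipeline.
    return build_mongo_query(args)


# Dummy usage:
print(list_items({"q": "AMDAC"}, True, lambda q: [q.lower()], lambda a: a))  # {'data': ['amdac']}
print(list_items({}, True, lambda q: [q.lower()], lambda a: a))              # {}
```

Note that the Elasticsearch branch returns a raw dict rather than a marshalled, paginated page, which is why the updated test below reads `json["data"]` and indexes each item as a dict.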
6 changes: 5 additions & 1 deletion udata/core/dataservices/api.py
@@ -25,7 +25,11 @@ def get(self):
        """List or search all dataservices"""
        query = Dataservice.objects.visible()

-        return Dataservice.apply_sort_filters_and_pagination(query)
+        results = Dataservice.apply_sort_filters_and_pagination(query)
+        print(results)
+
+        print("here")
+        return results

    @api.secure
    @api.doc("create_dataservice", responses={400: "Validation error"})
2 changes: 1 addition & 1 deletion udata/core/dataservices/models.py
@@ -95,7 +95,7 @@ class HarvestMetadata(db.EmbeddedDocument):
    archived_at = field(db.DateTimeField())


-@generate_fields()
+@generate_fields(searchable=True)
 @elasticsearch(
     score_functions_description={
         "public_service_score": {"factor": 8, "modifier": "sqrt", "missing": 1},
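For context, `searchable=True` appears to be the `generate_fields` option that exposes the `q` query-string argument used above; when the Elasticsearch branch is not taken, the list endpoint falls back to MongoDB text search over a quoted phrase query. A small, self-contained illustration of how that phrase query is assembled (the same expression as in `api_fields.py`):

```python
# Illustrative only: how the MongoDB fallback turns ?q=... into a phrase query.
q = "mobility data api"
phrase_query = " ".join([f'"{elem}"' for elem in q.split(" ")])
assert phrase_query == '"mobility" "data" "api"'
# base_query.search_text(phrase_query) then runs MongoDB's text search.
```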
6 changes: 5 additions & 1 deletion udata/core/elasticsearch/__init__.py
@@ -79,6 +79,10 @@
 T = TypeVar("T")


+def is_elasticsearch_enable() -> bool:
+    return True
+
+
 def elasticsearch(
     score_functions_description: dict[str, dict] = {},
     build_search_query=None,
@@ -166,7 +170,7 @@ def elasticsearch_search(query_text: str):
        else:
            query = Q(
                "function_score",
-                query=query.MatchAll(),
+                query=query.MatchAll(),  # todo only match `searchable` field and not `indexable` / `filterable`
                functions=score_functions,
            )

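`is_elasticsearch_enable()` is hard-coded to `True` in this WIP commit. A sketch of what a configuration-driven version could look like; `ELASTICSEARCH_ENABLED` is an assumed setting name, not an existing udata config key:

```python
# Hypothetical follow-up, not part of this commit: read the flag from Flask config.
from flask import current_app


def is_elasticsearch_enable() -> bool:
    # ELASTICSEARCH_ENABLED is an assumed key; default to disabled.
    return bool(current_app.config.get("ELASTICSEARCH_ENABLED", False))
```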
32 changes: 17 additions & 15 deletions udata/tests/api/test_dataservices_api.py
@@ -353,48 +353,50 @@ def test_elasticsearch(self):
        )
        time.sleep(1)

-        dataservices = Dataservice.__elasticsearch_search__("AMDAC")
+        print(self.get(url_for("api.dataservices", q="AMDAC")).json)
+
+        dataservices = self.get(url_for("api.dataservices", q="AMDAC")).json["data"]

        assert len(dataservices) == 3
-        assert dataservices[0].title == dataservice_c.title
-        assert dataservices[1].title == dataservice_a.title
+        assert dataservices[0]["title"] == dataservice_c.title
+        assert dataservices[1]["title"] == dataservice_a.title
        assert (
-            dataservices[2].title == dataservice_b.title
+            dataservices[2]["title"] == dataservice_b.title
        )  # b is last even if it doesn't really match.

        dataservice_b.title = "B - Hello AMD world!"
        dataservice_b.save()
        time.sleep(3)

-        dataservices = Dataservice.__elasticsearch_search__("AMDAC")
+        dataservices = self.get(url_for("api.dataservices", q="AMDAC")).json["data"]

        assert len(dataservices) == 3

        # `dataservice_b` should be first because it has a lot of followers
-        assert dataservices[0].title == dataservice_b.title
-        assert dataservices[1].title == dataservice_c.title
-        assert dataservices[2].title == dataservice_a.title
+        assert dataservices[0]["title"] == dataservice_b.title
+        assert dataservices[1]["title"] == dataservice_c.title
+        assert dataservices[2]["title"] == dataservice_a.title

        dataservice_a.organization = orga_sp
        dataservice_a.save()
        assert dataservice_a.public_service_score() == 4
        time.sleep(3)

-        dataservices = Dataservice.__elasticsearch_search__("AMDAC")
+        dataservices = self.get(url_for("api.dataservices", q="AMDAC")).json["data"]

        assert len(dataservices) == 3

-        assert dataservices[0].title == dataservice_b.title
-        assert dataservices[1].title == dataservice_a.title
-        assert dataservices[2].title == dataservice_c.title
+        assert dataservices[0]["title"] == dataservice_b.title
+        assert dataservices[1]["title"] == dataservice_a.title
+        assert dataservices[2]["title"] == dataservice_c.title

        dataservice_b.archived_at = datetime.utcnow()
        dataservice_b.save()
        time.sleep(3)

-        dataservices = Dataservice.__elasticsearch_search__("AMDAC")
+        dataservices = self.get(url_for("api.dataservices", q="AMDAC")).json["data"]

        assert len(dataservices) == 2

-        assert dataservices[0].title == dataservice_a.title
-        assert dataservices[1].title == dataservice_c.title
+        assert dataservices[0]["title"] == dataservice_a.title
+        assert dataservices[1]["title"] == dataservice_c.title
