Skip to content

Commit

Permalink
Add search queries customisation
Browse files Browse the repository at this point in the history
  • Loading branch information
ThibaudDauce committed Jul 31, 2024
1 parent c73f4cf commit 1989e8d
Show file tree
Hide file tree
Showing 4 changed files with 118 additions and 13 deletions.
45 changes: 44 additions & 1 deletion udata/core/dataservices/models.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from datetime import datetime

from elasticsearch_dsl import Search, query

import udata.core.contact_point.api_fields as contact_api_fields
import udata.core.dataset.api_fields as datasets_api_fields
from udata.api_fields import field, function_field, generate_fields
Expand All @@ -23,6 +25,42 @@
DATASERVICE_FORMATS = ["REST", "WMS", "WSL"]


def build_search_query(query_text: str, score_functions: list) -> "query.Q":
    """Build the Elasticsearch query used for full-text dataservice search.

    Two weighted sub-queries over the same user text are OR-ed in a `bool`
    `should` clause:

    - a ``phrase`` multi-match with high boosts (title/acronym ^15,
      description ^8) that strongly rewards exact phrase matches;
    - a ``cross_fields`` multi-match with ``operator="and"`` and lower
      boosts (^7 / ^4) that rewards documents containing *all* query terms
      spread across the fields.

    Each sub-query is wrapped in a ``function_score`` so the caller-provided
    ``score_functions`` (e.g. a followers-count ``field_value_factor``) can
    re-rank the hits.

    :param query_text: raw search text typed by the user.
    :param score_functions: list of ``elasticsearch_dsl.query.SF`` scoring
        functions applied to both sub-queries.
    :return: an ``elasticsearch_dsl`` query object.
    """
    # NOTE(review): a fuzzy title match was considered here as a third
    # `should` clause — `query.Match(title={"query": query_text,
    # "fuzziness": "AUTO:4,6"})` — but is intentionally not enabled.
    return query.Q(
        "bool",
        should=[
            query.Q(
                "function_score",
                query=query.Bool(
                    should=[
                        query.MultiMatch(
                            query=query_text,
                            type="phrase",
                            fields=["title^15", "acronym^15", "description^8"],
                        )
                    ]
                ),
                functions=score_functions,
            ),
            query.Q(
                "function_score",
                query=query.Bool(
                    should=[
                        query.MultiMatch(
                            query=query_text,
                            type="cross_fields",
                            fields=["title^7", "acronym^7", "description^4"],
                            operator="and",
                        )
                    ]
                ),
                functions=score_functions,
            ),
        ],
    )


class DataserviceQuerySet(OwnedQuerySet):
    def visible(self):
        # A dataservice is "visible" when it is public and has been neither
        # archived nor soft-deleted.
        return self(archived_at=None, deleted_at=None, private=False)
Expand Down Expand Up @@ -58,7 +96,12 @@ class HarvestMetadata(db.EmbeddedDocument):


@generate_fields()
@elasticsearch()
@elasticsearch(
score_functions_description={
"metrics.followers": {"factor": 4, "modifier": "sqrt", "missing": 1}
},
build_search_query=build_search_query,
)
class Dataservice(WithMetrics, Owned, db.Document):
meta = {
"indexes": [
Expand Down
63 changes: 56 additions & 7 deletions udata/core/elasticsearch/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
import logging
import random
import string
Expand All @@ -12,13 +13,17 @@
Field,
Float,
Index,
InnerDoc,
Integer,
Keyword,
Nested,
Object,
Q,
Search,
Text,
analyzer,
connections,
query,
token_filter,
tokenizer,
)
Expand Down Expand Up @@ -74,15 +79,21 @@
client = connections.create_connection(hosts=["localhost"])


def elasticsearch(score_functions_description=None, build_search_query=None, **kwargs):
    """Class decorator wiring a Mongo document class to an Elasticsearch model.

    :param score_functions_description: mapping of document field path (dot
        notation, e.g. ``"metrics.followers"``) to ``field_value_factor``
        options (``factor``, ``modifier``, ``missing``…) used to build
        scoring functions. Defaults to no scoring functions.
    :param build_search_query: optional callable ``(query_text,
        score_functions) -> Q`` customising the search query; when ``None``
        the generated model falls back to its default query.
    :return: the decorated class, with its ``elasticsearch`` attribute set to
        the generated Elasticsearch model.
    """
    # `None` sentinel instead of a `{}` default: a mutable default argument
    # would be shared across every call of the decorator.
    if score_functions_description is None:
        score_functions_description = {}

    def wrapper(cls):
        cls.elasticsearch = generate_elasticsearch_model(
            cls,
            score_functions_description=score_functions_description,
            build_search_query=build_search_query,
        )
        return cls

    return wrapper


def generate_elasticsearch_model(cls: type) -> type:
def generate_elasticsearch_model(
cls: type, score_functions_description, build_search_query
) -> type:
index_name = cls._get_collection_name()

# Testing name to have a new index in each test.
Expand All @@ -103,10 +114,45 @@ class Index:
def elasticsearch_index(cls, document, **kwargs):
convert_mongo_document_to_elasticsearch_document(document).save()

score_functions = [
query.SF("field_value_factor", field=key, **value)
for key, value in score_functions_description.items()
]

def elasticsearch_search(query_text):
s = Search(using=client, index=index_name).query("match", title=query_text)
response = s.execute()
print(response)
s: Search = ElasticSearchModel.search()

if query_text:
query = build_search_query(query_text, score_functions)
else:
query = Q(
"function_score",
query=query.MatchAll(),
functions=score_functions,
)

print("---------------------")
print("---------------------")
print("---------------------")
print("---------------------")
print("---------------------")
print(score_functions_description)
for field in score_functions_description.keys():
print(field)
levels = field.split(".")
print(levels)

if len(levels) == 1:
pass
elif len(levels) == 2:
query = Q("nested", path=levels[0], query=query)
else:
raise RuntimeError(
f"This system only support one level deep score function fields. '{field}' contains two or more dots."
)

print(json.dumps(s.query(query).to_dict(), indent=2))
response = s.query(query).execute()

# Get all the models from MongoDB to fetch all the correct fields.
models = {
Expand Down Expand Up @@ -151,6 +197,8 @@ def convert_db_field_to_elasticsearch(field, searchable: bool | str) -> Field:
return Boolean()
elif isinstance(field, mongo_fields.DateTimeField):
return Date()
elif isinstance(field, mongo_fields.DictField):
return Nested()
elif isinstance(field, mongo_fields.ReferenceField):
return Nested(field.document_type_obj.__elasticsearch_model__)
else:
Expand All @@ -160,6 +208,7 @@ def convert_db_field_to_elasticsearch(field, searchable: bool | str) -> Field:
def convert_mongo_document_to_elasticsearch_document(document: MongoDocument) -> Document:
attributes = {}
attributes["id"] = str(document.id)
attributes["meta"] = {"id": str(document.id)}

for key, field, searchable in get_searchable_fields(document.__class__):
attributes[key] = getattr(document, key)
Expand All @@ -180,7 +229,7 @@ def ensure_index_exists(index: Index, index_name: str) -> None:
if index.exists():
return

now = datetime.now(datetime.UTC).strftime("%Y-%m-%d-%H-%M")
now = datetime.utcnow().strftime("%Y-%m-%d-%H-%M")
index_name_with_suffix = f"{index_name}-{now}"

# Because we create the index manually (`elasticsearch_dsl` creates an index
Expand Down
1 change: 1 addition & 0 deletions udata/core/metrics/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ class WithMetrics(object):
metrics = field(
db.DictField(),
readonly=True,
searchable=True, # TODO change to indexable
)

__metrics_keys__ = []
Expand Down
22 changes: 17 additions & 5 deletions udata/tests/api/test_dataservices_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,8 +319,18 @@ def test_dataservice_api_create_with_custom_user_or_org(self):
self.assertEqual(dataservice.organization.id, me_org.id)

def test_elasticsearch(self):
dataservice_a = DataserviceFactory(title="Hello AMD world!")
dataservice_b = DataserviceFactory(title="Other one")
dataservice_a = DataserviceFactory(
title="Hello AMD world!",
metrics={
"followers": 42,
},
)
dataservice_b = DataserviceFactory(
title="Other one",
metrics={
"followers": 1337,
},
)
time.sleep(1)

dataservices = Dataservice.__elasticsearch_search__("AMDAC")
Expand All @@ -330,10 +340,12 @@ def test_elasticsearch(self):

dataservice_b.title = "Hello AMD world!"
dataservice_b.save()
time.sleep(1)
time.sleep(3)

dataservices = Dataservice.__elasticsearch_search__("AMDAC")

assert len(dataservices) == 2
assert dataservices[0].id == dataservice_a.id
assert dataservices[1].id == dataservice_b.id

# `dataservice_b` should be first because it has a lot of followers
assert dataservices[0].id == dataservice_b.id
assert dataservices[1].id == dataservice_a.id

0 comments on commit 1989e8d

Please sign in to comment.