Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft: use derived fields to implement per-ingredient recipe scoring #121

Draft
wants to merge 8 commits into
base: main
Choose a base branch
from
91 changes: 49 additions & 42 deletions reciperadar/search/recipes.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,15 +30,7 @@ class RecipeSearch(QueryRepository):
def _generate_include_clause(ingredients):
synonyms = load_ingredient_synonyms()
include = EntityClause.term_list(ingredients, lambda x: x.positive, synonyms)
return [
{
"constant_score": {
"boost": pow(10, idx),
"filter": {"match": {"contents": inc}},
}
}
for idx, inc in enumerate(reversed(include))
]
return [{"match": {"contents": inc}} for inc in include]

@staticmethod
def _generate_include_exact_clause(ingredients):
Expand All @@ -48,15 +40,10 @@ def _generate_include_exact_clause(ingredients):
{
"nested": {
"path": "ingredients",
"query": {
"constant_score": {
"boost": pow(10, idx) * 2,
"filter": {"match": {"ingredients.product.singular": inc}},
}
},
"query": {"match": {"ingredients.product.singular": inc}},
}
}
for idx, inc in enumerate(reversed(include))
for inc in include
]

@staticmethod
Expand All @@ -82,15 +69,14 @@ def _generate_equipment_clause(equipment):
return {"bool": conditions}

@staticmethod
def sort_methods(match_count=1):
score_limit = pow(10, match_count) * 2
def sort_methods(score_limit=1):
preamble = f"""
def product_count = doc.product_count.value;
def exact_found_count = 0;
def found_count = 0;
for (def score = (long) _score; score > 0; score /= 10) {{
if (score % 10 > 2) exact_found_count++;
if (score % 10 > 0) found_count++;
for (is_exact in doc._found) {{
if (is_exact == true) exact_found_count++;
if (is_exact == false) found_count++;
}}
def missing_count = product_count - found_count;
def exact_missing_count = product_count - exact_found_count;
Expand Down Expand Up @@ -129,7 +115,7 @@ def _generate_sort_method(self, ingredients, sort):
include = [True for x in ingredients if x.positive]
if include == [] and sort != "duration":
return {"script": "doc.rating.value", "order": "desc"}
return self.sort_methods(match_count=len(include))[sort]
return self.sort_methods(score_limit=len(include))[sort]

def _domain_facets(self):
return {"domains": {"terms": {"field": "domain", "size": 100}}}
Expand Down Expand Up @@ -196,6 +182,29 @@ def _generate_aggregations(self, suggest_products, ingredients, dietary_properti
}
}

@staticmethod
def _generate_derived_fields(ingredients):
    """Build the derived-field definitions used for per-ingredient scoring.

    The positive query ingredients (expanded via the synonym table by
    ``EntityClause.term_list``) are passed to a search-side script as
    ``params.products``.  For each of those products, in order, the script
    emits one value into the ``_found`` field:

    - ``true`` when the product equals the ``product.singular`` value of
      one of the recipe's ingredients,
    - ``false`` when it is otherwise present in the recipe's ``contents``,
    - ``null`` when it is found in neither.

    Returns a 2-tuple: the mapping of derived field name to definition,
    and the list of derived field names.
    """
    wanted_products = EntityClause.term_list(
        ingredients,
        lambda ingredient: ingredient.positive,
        load_ingredient_synonyms(),
    )

    # The script text is kept flush-left so that the string sent to the
    # search engine is unchanged.
    # NOTE(review): confirm that `emit(null)` is accepted for a derived
    # field of type `boolean` -- TODO verify against the engine in use.
    found_script = """
def products = Collections.unmodifiableSet(params._source['ingredients'].stream().map(ingredient -> ingredient.product.singular).collect(Collectors.toSet()));
def contents = Collections.unmodifiableSet(params._source['contents'].stream().collect(Collectors.toSet()));
for (product in params.products) {
if (products.contains(product)) emit(true);
else if (contents.contains(product)) emit(false);
else emit(null);
}
"""

    found_field = {
        "type": "boolean",
        "script": {
            "source": found_script,
            "params": {"products": wanted_products},
        },
    }
    derivations = {"_found": found_field}
    return derivations, list(derivations)

def _generate_post_filter(self, domains):
conditions = defaultdict(list)
for domain in domains:
Expand Down Expand Up @@ -232,8 +241,7 @@ def _render_query(
min_include_match = len(should)

return {
"function_score": {
"boost_mode": "replace",
"script_score": {
"query": {
"bool": {
"should": should,
Expand All @@ -242,7 +250,7 @@ def _render_query(
"minimum_should_match": min_include_match,
}
},
"script_score": {"script": {"source": sort_params["script"]}},
"script": {"source": sort_params["script"]},
}
}, [{"_score": sort_params["order"]}]

Expand Down Expand Up @@ -377,24 +385,24 @@ def query(
To achieve this, we use OpenSearch's query syntax to encode information
about the quality of each match during search execution.

We use `constant_score` queries to store a power-of-ten score for each
query ingredient, with the value doubled for exact matches.
We use `derived` fields to emit a tri-state boolean score, containing
one value for each query ingredient -- `null` for unmatched
ingredients, `false` for partial matches, and `true` for exact matches.

For example, in a query for `onion`, `tomato`, `tofu`:

onion tomato tofu score
recipe 1 exact exact partial 300 + 30 + 1 = 331
recipe 2 partial no exact 100 + 0 + 3 = 103
recipe 3 exact no exact 300 + 0 + 3 = 303
onion tomato tofu _found
recipe 1 exact exact partial [true, true, false]
recipe 2 partial no exact [false, null, true]
recipe 3 exact no exact [true, null, true]

This allows the final sorting stage to determine - with some small
possibility of error* - how many exact and inexact matches were
discovered for each recipe.
This allows the final sorting stage to determine how many exact matches,
and how many matches overall (exact or partial), were discovered for each
recipe.

score exact_matches all_matches
recipe 1 331 1 + 1 + 0 = 2 1 + 1 + 1 = 3
recipe 2 103 0 + 0 + 1 = 1 1 + 0 + 1 = 2
recipe 3 303 1 + 0 + 1 = 2 1 + 0 + 1 = 2
_found exact_matches all_matches
recipe 1 [true, true, false] 1 + 1 + 0 = 2 1 + 1 + 1 = 3
recipe 2 [false, null, true] 0 + 0 + 1 = 1 1 + 0 + 1 = 2
recipe 3 [true, null, true] 1 + 0 + 1 = 2 1 + 0 + 1 = 2

At this stage we have enough information to sort the result set based
on the number of overall matches and to use the number of exact matches
Expand All @@ -405,10 +413,6 @@ def query(
- (3 matches, 2 exact) recipe 1
- (2 matches, 2 exact) recipe 3
- (2 matches, 1 exact) recipe 2


* Inconsistent results and ranking errors can occur if an ingredient
appears multiple times in a recipe, resulting in duplicate counts
"""
offset = max(0, offset)
limit = max(0, limit)
Expand All @@ -419,6 +423,7 @@ def query(
ingredients=ingredients,
dietary_properties=dietary_properties,
)
derived, derived_fields = self._generate_derived_fields(ingredients=ingredients)
post_filter = self._generate_post_filter(domains=domains)

queries = self._refined_queries(
Expand All @@ -434,6 +439,8 @@ def query(
"query": query,
"from": offset,
"size": limit,
"derived": derived,
"fields": derived_fields,
"sort": sort_method,
"aggs": aggregations,
"post_filter": post_filter,
Expand Down