Skip to content

Commit

Permalink
Add 'max_synonyms' param to get_normalizer_results() #2367
Browse files Browse the repository at this point in the history
  • Loading branch information
amykglen committed Oct 12, 2024
1 parent 1520592 commit 7ebd333
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 5 deletions.
37 changes: 32 additions & 5 deletions code/ARAX/NodeSynonymizer/node_synonymizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import string
import sys
import time
from collections import defaultdict
from collections import defaultdict, Counter
from typing import Optional, Union, List, Set, Dict, Tuple

import pandas as pd
Expand Down Expand Up @@ -286,6 +286,7 @@ def get_curie_names(self, curies: Union[str, Set[str], List[str]], debug: bool =
return results_dict

def get_normalizer_results(self, entities: Optional[Union[str, Set[str], List[str]]],
max_synonyms: int = 1000000,
debug: bool = False) -> dict:
start = time.time()

Expand All @@ -307,6 +308,26 @@ def get_normalizer_results(self, entities: Optional[Union[str, Set[str], List[st
equivalent_curies_dict_names = self.get_equivalent_nodes(names=unrecognized_entities, include_unrecognized_entities=False)
equivalent_curies_dict.update(equivalent_curies_dict_names)

# Truncate synonyms to max number allowed per node
# First record counts for full list of equivalent curies before trimming
equiv_curie_counts_untrimmed = {input_entity: len(equivalent_curies) if equivalent_curies else 0
for input_entity, equivalent_curies in equivalent_curies_dict.items()}
all_node_ids_untrimmed = set().union(*equivalent_curies_dict.values())
sql_query_template = f"""
SELECT N.id, N.category
FROM nodes as N
WHERE N.id in ('{self.placeholder_lookup_values_str}')"""
matching_rows = self._run_sql_query_in_batches(sql_query_template, all_node_ids_untrimmed)
categories_map_untrimmed = {row[0]: f"biolink:{row[1]}" for row in matching_rows}
category_counts_untrimmed = dict()
equivalent_curies_dict_trimmed = dict()
for input_entity, equivalent_curies in equivalent_curies_dict.items():
category_counts_untrimmed[input_entity] = dict(Counter([categories_map_untrimmed[equiv_curie]
for equiv_curie in equivalent_curies]))
equivalent_curies_trimmed = equivalent_curies[:max_synonyms] if equivalent_curies else None
equivalent_curies_dict_trimmed[input_entity] = equivalent_curies_trimmed
equivalent_curies_dict = equivalent_curies_dict_trimmed

# Then get info for all of those equivalent nodes
# Note: We don't need to query by capitalized curies because these are all curies that exist in the synonymizer
all_node_ids = set().union(*equivalent_curies_dict.values())
Expand Down Expand Up @@ -340,13 +361,13 @@ def get_normalizer_results(self, entities: Optional[Union[str, Set[str], List[st
"SRI_normalizer_name": cluster_rep["name_sri"],
"SRI_normalizer_category": cluster_rep["category_sri"],
"SRI_normalizer_curie": cluster_id if cluster_rep["category_sri"] else None},
"categories": defaultdict(int),
"total_synonyms": equiv_curie_counts_untrimmed[input_entity],
"categories": category_counts_untrimmed[input_entity],
"nodes": [nodes_dict[equivalent_curie] for equivalent_curie in equivalent_curies]}

# Do some post-processing (tally up category counts and remove no-longer-needed 'cluster_id' property)
# Do some post-processing (remove no-longer-needed 'cluster_id' property)
for normalizer_info in results_dict.values():
for equivalent_node in normalizer_info["nodes"]:
normalizer_info["categories"][equivalent_node["category"]] += 1
if "cluster_id" in equivalent_node:
del equivalent_node["cluster_id"]
if "cluster_preferred_name" in equivalent_node:
Expand Down Expand Up @@ -495,7 +516,13 @@ def _get_cluster_graph(self, normalizer_info: dict) -> dict:
intra_cluster_edge_ids_str = "[]" if cluster_row[0] == "nan" else cluster_row[0]
intra_cluster_edge_ids = ast.literal_eval(intra_cluster_edge_ids_str) # Lists are stored as strings in sqlite

edges_query = f"SELECT * FROM edges WHERE id IN ('{self._convert_to_str_format(intra_cluster_edge_ids)}')"
# Get rid of any orphan edges (may be present if max_synonyms is specified in get_normalizer_results())
subj_obj_query = f"SELECT id, subject, object FROM edges WHERE id IN ('{self._convert_to_str_format(intra_cluster_edge_ids)}')"
subj_obj_rows = self._execute_sql_query(subj_obj_query)
intra_cluster_edge_ids_trimmed = {edge_id for edge_id, subject_id, object_id in subj_obj_rows
if subject_id in kg.nodes and object_id in kg.nodes}

edges_query = f"SELECT * FROM edges WHERE id IN ('{self._convert_to_str_format(intra_cluster_edge_ids_trimmed)}')"
edge_rows = self._execute_sql_query(edges_query)
edges_df = self._load_records_into_dataframe(edge_rows, "edges")
edge_dicts = edges_df.to_dict(orient="records")
Expand Down
21 changes: 21 additions & 0 deletions code/ARAX/test/test_ARAX_synonymizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,5 +417,26 @@ def test_cluster_graphs():
assert node["attributes"]


def test_truncate_cluster():
    """Verify that max_synonyms=2 trims each cluster to two nodes while still
    reporting untrimmed totals and per-category counts."""
    node_synonymizer = NodeSynonymizer()
    results = node_synonymizer.get_normalizer_results([ACETAMINOPHEN_CURIE, PARKINSONS_CURIE], max_synonyms=2)

    # (input curie, category expected to dominate its cluster, category expected absent)
    cases = [(ACETAMINOPHEN_CURIE, "biolink:Drug", "biolink:Disease"),
             (PARKINSONS_CURIE, "biolink:Disease", "biolink:Drug")]
    for curie, expected_category, unexpected_category in cases:
        normalizer_info = results[curie]
        print(json.dumps(normalizer_info["nodes"], indent=2))
        # Trimmed views should contain exactly max_synonyms entries
        assert len(normalizer_info["nodes"]) == 2
        assert len(normalizer_info["knowledge_graph"]["nodes"]) == 2
        assert len(normalizer_info["knowledge_graph"]["edges"]) < 20
        # Untrimmed bookkeeping should still reflect the full cluster
        assert normalizer_info["total_synonyms"] > 2
        assert normalizer_info["categories"][expected_category] > 2
        assert unexpected_category not in normalizer_info["categories"]


if __name__ == "__main__":
    # Allow running this test module directly (e.g. `python test_ARAX_synonymizer.py`)
    # rather than only via a pytest invocation; -v gives per-test output.
    pytest.main(['-v', 'test_ARAX_synonymizer.py'])

0 comments on commit 7ebd333

Please sign in to comment.