From 7ebd333cc7df4a5133b6bf1d276c7ca81dd57fad Mon Sep 17 00:00:00 2001
From: amykglen
Date: Fri, 11 Oct 2024 17:05:45 -0700
Subject: [PATCH] Add 'max_synonyms' param to get_normalizer_results() #2367

---
 code/ARAX/NodeSynonymizer/node_synonymizer.py | 37 ++++++++++++++++---
 code/ARAX/test/test_ARAX_synonymizer.py       | 21 +++++++++++
 2 files changed, 53 insertions(+), 5 deletions(-)

diff --git a/code/ARAX/NodeSynonymizer/node_synonymizer.py b/code/ARAX/NodeSynonymizer/node_synonymizer.py
index 8abf890be..c74fe7abe 100644
--- a/code/ARAX/NodeSynonymizer/node_synonymizer.py
+++ b/code/ARAX/NodeSynonymizer/node_synonymizer.py
@@ -7,7 +7,7 @@
 import string
 import sys
 import time
-from collections import defaultdict
+from collections import defaultdict, Counter
 from typing import Optional, Union, List, Set, Dict, Tuple
 
 import pandas as pd
@@ -286,6 +286,7 @@ def get_curie_names(self, curies: Union[str, Set[str], List[str]], debug: bool =
         return results_dict
 
     def get_normalizer_results(self, entities: Optional[Union[str, Set[str], List[str]]],
+                               max_synonyms: int = 1000000,
                                debug: bool = False) -> dict:
 
         start = time.time()
@@ -307,6 +308,26 @@ def get_normalizer_results(self, entities: Optional[Union[str, Set[str], List[st
             equivalent_curies_dict_names = self.get_equivalent_nodes(names=unrecognized_entities, include_unrecognized_entities=False)
             equivalent_curies_dict.update(equivalent_curies_dict_names)
 
+        # Truncate synonyms to max number allowed per node
+        # First record counts for full list of equivalent curies before trimming
+        equiv_curie_counts_untrimmed = {input_entity: len(equivalent_curies) if equivalent_curies else 0
+                                        for input_entity, equivalent_curies in equivalent_curies_dict.items()}
+        all_node_ids_untrimmed = set().union(*equivalent_curies_dict.values())
+        sql_query_template = f"""
+            SELECT N.id, N.category
+            FROM nodes as N
+            WHERE N.id in ('{self.placeholder_lookup_values_str}')"""
+        matching_rows = self._run_sql_query_in_batches(sql_query_template, all_node_ids_untrimmed)
+        categories_map_untrimmed = {row[0]: f"biolink:{row[1]}" for row in matching_rows}
+        category_counts_untrimmed = dict()
+        equivalent_curies_dict_trimmed = dict()
+        for input_entity, equivalent_curies in equivalent_curies_dict.items():
+            category_counts_untrimmed[input_entity] = dict(Counter([categories_map_untrimmed[equiv_curie]
+                                                                    for equiv_curie in equivalent_curies]))
+            equivalent_curies_trimmed = equivalent_curies[:max_synonyms] if equivalent_curies else None
+            equivalent_curies_dict_trimmed[input_entity] = equivalent_curies_trimmed
+        equivalent_curies_dict = equivalent_curies_dict_trimmed
+
         # Then get info for all of those equivalent nodes
         # Note: We don't need to query by capitalized curies because these are all curies that exist in the synonymizer
         all_node_ids = set().union(*equivalent_curies_dict.values())
@@ -340,13 +361,13 @@ def get_normalizer_results(self, entities: Optional[Union[str, Set[str], List[st
                                                                 "SRI_normalizer_name": cluster_rep["name_sri"],
                                                                 "SRI_normalizer_category": cluster_rep["category_sri"],
                                                                 "SRI_normalizer_curie": cluster_id if cluster_rep["category_sri"] else None},
-                                      "categories": defaultdict(int),
+                                      "total_synonyms": equiv_curie_counts_untrimmed[input_entity],
+                                      "categories": category_counts_untrimmed[input_entity],
                                       "nodes": [nodes_dict[equivalent_curie] for equivalent_curie in equivalent_curies]}
 
-        # Do some post-processing (tally up category counts and remove no-longer-needed 'cluster_id' property)
+        # Do some post-processing (remove no-longer-needed 'cluster_id' property)
         for normalizer_info in results_dict.values():
             for equivalent_node in normalizer_info["nodes"]:
-                normalizer_info["categories"][equivalent_node["category"]] += 1
                 if "cluster_id" in equivalent_node:
                     del equivalent_node["cluster_id"]
                 if "cluster_preferred_name" in equivalent_node:
@@ -495,7 +516,13 @@ def _get_cluster_graph(self, normalizer_info: dict) -> dict:
         intra_cluster_edge_ids_str = "[]" if cluster_row[0] == "nan" else cluster_row[0]
         intra_cluster_edge_ids = ast.literal_eval(intra_cluster_edge_ids_str)  # Lists are stored as strings in sqlite
 
-        edges_query = f"SELECT * FROM edges WHERE id IN ('{self._convert_to_str_format(intra_cluster_edge_ids)}')"
+        # Get rid of any orphan edges (may be present if max_synonyms is specified in get_normalizer_results())
+        subj_obj_query = f"SELECT id, subject, object FROM edges WHERE id IN ('{self._convert_to_str_format(intra_cluster_edge_ids)}')"
+        subj_obj_rows = self._execute_sql_query(subj_obj_query)
+        intra_cluster_edge_ids_trimmed = {edge_id for edge_id, subject_id, object_id in subj_obj_rows
+                                          if subject_id in kg.nodes and object_id in kg.nodes}
+
+        edges_query = f"SELECT * FROM edges WHERE id IN ('{self._convert_to_str_format(intra_cluster_edge_ids_trimmed)}')"
         edge_rows = self._execute_sql_query(edges_query)
         edges_df = self._load_records_into_dataframe(edge_rows, "edges")
         edge_dicts = edges_df.to_dict(orient="records")
diff --git a/code/ARAX/test/test_ARAX_synonymizer.py b/code/ARAX/test/test_ARAX_synonymizer.py
index 945a87a48..3f8500e98 100644
--- a/code/ARAX/test/test_ARAX_synonymizer.py
+++ b/code/ARAX/test/test_ARAX_synonymizer.py
@@ -417,5 +417,26 @@ def test_cluster_graphs():
         assert node["attributes"]
 
 
+def test_truncate_cluster():
+    synonymizer = NodeSynonymizer()
+    results = synonymizer.get_normalizer_results([ACETAMINOPHEN_CURIE, PARKINSONS_CURIE], max_synonyms=2)
+
+    print(json.dumps(results[ACETAMINOPHEN_CURIE]["nodes"], indent=2))
+    assert len(results[ACETAMINOPHEN_CURIE]["nodes"]) == 2
+    assert len(results[ACETAMINOPHEN_CURIE]["knowledge_graph"]["nodes"]) == 2
+    assert len(results[ACETAMINOPHEN_CURIE]["knowledge_graph"]["edges"]) < 20
+    assert results[ACETAMINOPHEN_CURIE]["total_synonyms"] > 2
+    assert results[ACETAMINOPHEN_CURIE]["categories"]["biolink:Drug"] > 2
+    assert "biolink:Disease" not in results[ACETAMINOPHEN_CURIE]["categories"]
+
+    print(json.dumps(results[PARKINSONS_CURIE]["nodes"], indent=2))
+    assert len(results[PARKINSONS_CURIE]["nodes"]) == 2
+    assert len(results[PARKINSONS_CURIE]["knowledge_graph"]["nodes"]) == 2
+    assert len(results[PARKINSONS_CURIE]["knowledge_graph"]["edges"]) < 20
+    assert results[PARKINSONS_CURIE]["total_synonyms"] > 2
+    assert results[PARKINSONS_CURIE]["categories"]["biolink:Disease"] > 2
+    assert "biolink:Drug" not in results[PARKINSONS_CURIE]["categories"]
+
+
 if __name__ == "__main__":
     pytest.main(['-v', 'test_ARAX_synonymizer.py'])