From 84bdca26c402eb9d701981c2c0c25eb95b24a70b Mon Sep 17 00:00:00 2001 From: Jan Funke Date: Thu, 25 Apr 2024 13:44:27 -0400 Subject: [PATCH 1/3] Use KDTrees for candidate edge extraction --- src/motile_toolbox/candidate_graph/utils.py | 31 ++++++++++++++------- 1 file changed, 21 insertions(+), 10 deletions(-) diff --git a/src/motile_toolbox/candidate_graph/utils.py b/src/motile_toolbox/candidate_graph/utils.py index 520d471..61307ba 100644 --- a/src/motile_toolbox/candidate_graph/utils.py +++ b/src/motile_toolbox/candidate_graph/utils.py @@ -1,10 +1,11 @@ import logging import math -from typing import Any +from typing import Any, Iterable import networkx as nx import numpy as np from skimage.measure import regionprops +from scipy.spatial import KDTree from tqdm import tqdm from .graph_attributes import EdgeAttr, NodeAttr @@ -99,6 +100,11 @@ def _compute_node_frame_dict(cand_graph: nx.DiGraph) -> dict[int, list[Any]]: return node_frame_dict +def create_kdtree(cand_graph: nx.DiGraph, node_ids: Iterable[Any]) -> KDTree: + positions = [cand_graph.nodes[node][NodeAttr.POS.value] for node in node_ids] + return KDTree(positions) + + def add_cand_edges( cand_graph: nx.DiGraph, max_edge_distance: float, @@ -122,15 +128,20 @@ def add_cand_edges( node_frame_dict = _compute_node_frame_dict(cand_graph) frames = sorted(node_frame_dict.keys()) + prev_node_ids = node_frame_dict[frames[0]] + prev_kdtree = create_kdtree(cand_graph, prev_node_ids) for frame in tqdm(frames): if frame + 1 not in node_frame_dict: continue - next_nodes = node_frame_dict[frame + 1] - next_locs = [cand_graph.nodes[n][NodeAttr.POS.value] for n in next_nodes] - for node in node_frame_dict[frame]: - loc = cand_graph.nodes[node][NodeAttr.POS.value] - for next_id, next_loc in zip(next_nodes, next_locs): - dist = math.dist(next_loc, loc) - if dist <= max_edge_distance: - attrs = {EdgeAttr.DISTANCE.value: dist} - cand_graph.add_edge(node, next_id, **attrs) + next_node_ids = node_frame_dict[frame + 1] + next_kdtree = create_kdtree(cand_graph, next_node_ids) + + matched_indices = prev_kdtree.query_ball_tree(next_kdtree, max_edge_distance) + + for prev_node_id, next_node_indices in zip(prev_node_ids, matched_indices): + for next_node_index in next_node_indices: + next_node_id = next_node_ids[next_node_index] + cand_graph.add_edge(prev_node_id, next_node_id) + + prev_node_ids = next_node_ids + prev_kdtree = next_kdtree From ee2617099c550e1e6ba1c94d7b3d3ad496d48bf8 Mon Sep 17 00:00:00 2001 From: Jan Funke Date: Thu, 25 Apr 2024 13:45:12 -0400 Subject: [PATCH 2/3] Remove DISTANCE enum from edges Distances are implied by the incident node positions. --- .../candidate_graph/graph_attributes.py | 1 - tests/conftest.py | 36 +++++++++---------- .../test_compute_graph.py | 11 ------ tests/test_candidate_graph/test_utils.py | 7 ---- 4 files changed, 18 insertions(+), 37 deletions(-) diff --git a/src/motile_toolbox/candidate_graph/graph_attributes.py b/src/motile_toolbox/candidate_graph/graph_attributes.py index 478c2b3..a2927e7 100644 --- a/src/motile_toolbox/candidate_graph/graph_attributes.py +++ b/src/motile_toolbox/candidate_graph/graph_attributes.py @@ -19,5 +19,4 @@ class EdgeAttr(Enum): implementations of commonly used ones, listed here. """ - DISTANCE = "distance" IOU = "iou" diff --git a/tests/conftest.py b/tests/conftest.py index 93bbd17..ae62bd1 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -90,8 +90,8 @@ def graph_2d(): ), ] edges = [ - ("0_1", "1_1", {EdgeAttr.DISTANCE.value: 42.43, EdgeAttr.IOU.value: 0.0}), - ("0_1", "1_2", {EdgeAttr.DISTANCE.value: 11.18, EdgeAttr.IOU.value: 0.395}), + ("0_1", "1_1", {EdgeAttr.IOU.value: 0.0}), + ("0_1", "1_2", {EdgeAttr.IOU.value: 0.395}), ] graph.add_nodes_from(nodes) graph.add_edges_from(edges) @@ -159,25 +159,25 @@ def multi_hypothesis_graph_2d(): ] edges = [ - ("0_0_1", "1_0_1", {EdgeAttr.DISTANCE.value: 42.426, EdgeAttr.IOU.value: 0.0}), - ("0_0_1", "1_1_1", {EdgeAttr.DISTANCE.value: 43.011, EdgeAttr.IOU.value: 0.0}), + ("0_0_1", "1_0_1", {EdgeAttr.IOU.value: 0.0}), + ("0_0_1", "1_1_1", {EdgeAttr.IOU.value: 0.0}), ( "0_0_1", "1_0_2", - {EdgeAttr.DISTANCE.value: 11.180, EdgeAttr.IOU.value: 0.3931}, + {EdgeAttr.IOU.value: 0.3931}, ), ( "0_0_1", "1_1_2", - {EdgeAttr.DISTANCE.value: 11.180, EdgeAttr.IOU.value: 0.4768}, + {EdgeAttr.IOU.value: 0.4768}, ), - ("0_1_1", "1_0_1", {EdgeAttr.DISTANCE.value: 43.011, EdgeAttr.IOU.value: 0.0}), - ("0_1_1", "1_1_1", {EdgeAttr.DISTANCE.value: 42.426, EdgeAttr.IOU.value: 0.0}), - ("0_1_1", "1_0_2", {EdgeAttr.DISTANCE.value: 15.0, EdgeAttr.IOU.value: 0.2402}), + ("0_1_1", "1_0_1", {EdgeAttr.IOU.value: 0.0}), + ("0_1_1", "1_1_1", {EdgeAttr.IOU.value: 0.0}), + ("0_1_1", "1_0_2", {EdgeAttr.IOU.value: 0.2402}), ( "0_1_1", "1_1_2", - {EdgeAttr.DISTANCE.value: 11.180, EdgeAttr.IOU.value: 0.3931}, + {EdgeAttr.IOU.value: 0.3931}, ), ] graph.add_nodes_from(nodes) @@ -276,9 +276,9 @@ def graph_3d(): ] edges = [ # math.dist([50, 50], [20, 80]) - ("0_1", "1_1", {EdgeAttr.DISTANCE.value: 42.43}), + ("0_1", "1_1"), # math.dist([50, 50], [60, 45]) - ("0_1", "1_2", {EdgeAttr.DISTANCE.value: 11.18}), + ("0_1", "1_2"), ] graph.add_nodes_from(nodes) graph.add_edges_from(edges) @@ -336,12 +336,12 @@ def multi_hypothesis_graph_3d(): ), ] edges = [ - ("0_0_1", "1_0_1", {EdgeAttr.DISTANCE.value: 42.4264}), - ("0_0_1", "1_0_2", {EdgeAttr.DISTANCE.value: 11.1803}), - ("0_1_1", "1_0_1", {EdgeAttr.DISTANCE.value: 35.3553}), - ("0_1_1", "1_0_2", {EdgeAttr.DISTANCE.value: 18.0277}), - ("0_0_1", "1_1_1", {EdgeAttr.DISTANCE.value: 40.3112}), - ("0_1_1", "1_1_1", {EdgeAttr.DISTANCE.value: 33.5410}), + ("0_0_1", "1_0_1"), + ("0_0_1", "1_0_2"), + ("0_1_1", "1_0_1"), + ("0_1_1", "1_0_2"), + ("0_0_1", "1_1_1"), + ("0_1_1", "1_1_1"), ] graph.add_nodes_from(nodes) graph.add_edges_from(edges) diff --git a/tests/test_candidate_graph/test_compute_graph.py b/tests/test_candidate_graph/test_compute_graph.py index eaeb241..77ce13d 100644 --- a/tests/test_candidate_graph/test_compute_graph.py +++ b/tests/test_candidate_graph/test_compute_graph.py @@ -17,10 +17,6 @@ def test_graph_from_segmentation_2d(segmentation_2d, graph_2d): assert Counter(cand_graph.nodes[node]) == Counter(graph_2d.nodes[node]) for edge in cand_graph.edges: print(cand_graph.edges[edge]) - assert ( - pytest.approx(cand_graph.edges[edge][EdgeAttr.DISTANCE.value], abs=0.01) - == graph_2d.edges[edge][EdgeAttr.DISTANCE.value] - ) assert ( pytest.approx(cand_graph.edges[edge][EdgeAttr.IOU.value], abs=0.01) == graph_2d.edges[edge][EdgeAttr.IOU.value] @@ -33,9 +29,6 @@ def test_graph_from_segmentation_2d(segmentation_2d, graph_2d): ) assert Counter(list(cand_graph.nodes)) == Counter(["0_1", "1_1", "1_2"]) assert Counter(list(cand_graph.edges)) == Counter([("0_1", "1_2")]) - assert cand_graph.edges[("0_1", "1_2")][EdgeAttr.DISTANCE.value] == pytest.approx( - 11.18, abs=0.01 - ) def test_graph_from_segmentation_3d(segmentation_3d, graph_3d): @@ -72,10 +65,6 @@ def test_graph_from_multi_segmentation_2d( multi_hypothesis_graph_2d.nodes[node] ) for edge in cand_graph.edges: - assert ( - pytest.approx(cand_graph.edges[edge][EdgeAttr.DISTANCE.value], abs=0.01) - == multi_hypothesis_graph_2d.edges[edge][EdgeAttr.DISTANCE.value] - ) assert ( pytest.approx(cand_graph.edges[edge][EdgeAttr.IOU.value], abs=0.01) == multi_hypothesis_graph_2d.edges[edge][EdgeAttr.IOU.value] diff --git a/tests/test_candidate_graph/test_utils.py b/tests/test_candidate_graph/test_utils.py index afa9ea1..9b4dda6 100644 --- a/tests/test_candidate_graph/test_utils.py +++ b/tests/test_candidate_graph/test_utils.py @@ -75,11 +75,6 @@ def test_add_cand_edges_2d(graph_2d): cand_graph = nx.create_empty_copy(graph_2d) add_cand_edges(cand_graph, max_edge_distance=50) assert Counter(list(cand_graph.edges)) == Counter(list(graph_2d.edges)) - for edge in cand_graph.edges: - assert ( - pytest.approx(cand_graph.edges[edge][EdgeAttr.DISTANCE.value], abs=0.01) - == graph_2d.edges[edge][EdgeAttr.DISTANCE.value] - ) def test_add_cand_edges_3d(graph_3d): @@ -87,8 +82,6 @@ def test_add_cand_edges_3d(graph_3d): add_cand_edges(cand_graph, max_edge_distance=15) graph_3d.remove_edge("0_1", "1_1") assert Counter(list(cand_graph.edges)) == Counter(list(graph_3d.edges)) - for edge in cand_graph.edges: - assert pytest.approx(cand_graph.edges[edge], abs=0.01) == graph_3d.edges[edge] def test_get_node_id(): From 9c980102df19344bcde53800697c6adb5178e03c Mon Sep 17 00:00:00 2001 From: Caroline Malin-Mayor Date: Mon, 29 Apr 2024 11:08:29 -0400 Subject: [PATCH 3/3] Update mypy.ini to ignore missing scipy stubs --- mypy.ini | 3 +++ 1 file changed, 3 insertions(+) diff --git a/mypy.ini b/mypy.ini index c873b64..4b35b3b 100644 --- a/mypy.ini +++ b/mypy.ini @@ -16,3 +16,6 @@ ignore_missing_imports = True [mypy-motile.*] ignore_missing_imports = True + +[mypy-scipy.*] +ignore_missing_imports = True