diff --git a/pyproject.toml b/pyproject.toml index 8706b10..c7b41e4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,7 +6,7 @@ requires = ["setuptools", "wheel"] name = "motile_toolbox" description = "A toolbox for tracking with (motile)[https://github.com/funkelab/motile]." readme = "README.md" -requires-python = ">=3.7" +requires-python = ">=3.10" classifiers = [ "Programming Language :: Python :: 3", ] @@ -17,7 +17,7 @@ authors = [ ] dynamic = ["version"] dependencies = [ - "motile", + "motile @git+https://github.com/funkelab/motile.git", "networkx", "numpy", "matplotlib", diff --git a/src/motile_toolbox/candidate_graph/graph_attributes.py b/src/motile_toolbox/candidate_graph/graph_attributes.py index a2927e7..cc6a9d4 100644 --- a/src/motile_toolbox/candidate_graph/graph_attributes.py +++ b/src/motile_toolbox/candidate_graph/graph_attributes.py @@ -11,6 +11,8 @@ class NodeAttr(Enum): TIME = "time" SEG_ID = "seg_id" SEG_HYPO = "seg_hypo" + BBOX = "bbox" + FLOW = "flow" class EdgeAttr(Enum): diff --git a/src/motile_toolbox/candidate_graph/utils.py b/src/motile_toolbox/candidate_graph/utils.py index 61307ba..c42a9c3 100644 --- a/src/motile_toolbox/candidate_graph/utils.py +++ b/src/motile_toolbox/candidate_graph/utils.py @@ -72,7 +72,9 @@ def nodes_from_segmentation( if hypo_id is not None: attrs[NodeAttr.SEG_HYPO.value] = hypo_id centroid = regionprop.centroid # [z,] y, x + bbox = regionprop.bbox attrs[NodeAttr.POS.value] = centroid + attrs[NodeAttr.BBOX.value] = bbox cand_graph.add_node(node_id, **attrs) nodes_in_frame.append(node_id) if nodes_in_frame: diff --git a/src/motile_toolbox/utils/flow.py b/src/motile_toolbox/utils/flow.py new file mode 100644 index 0000000..8392231 --- /dev/null +++ b/src/motile_toolbox/utils/flow.py @@ -0,0 +1,78 @@ +import math + +import networkx as nx +import numpy as np +from skimage.registration import phase_cross_correlation + +from motile_toolbox.candidate_graph.graph_attributes import EdgeAttr, NodeAttr + + +def compute_pcc_flow(candidate_graph: nx.DiGraph, images: np.ndarray): + """This calculates the flow using phase cross correlation + for the image cropped around an object + at `t` and the same region of interest at `t+1`, + and updates the `NodeAttr.FLOW`. + + Args: + candidate_graph (nx.DiGraph): Existing candidate graph with nodes. + + images (np.ndarray): Raw images (t, c, [z], y, x). + + """ + for node in candidate_graph.nodes(data=True): + frame = node[1][NodeAttr.TIME.value] + if frame + 1 >= len(images): + continue + loc = node[1][NodeAttr.POS.value] + bbox = node[1][NodeAttr.BBOX.value] + if len(loc) == 2: + reference_image = images[frame][ + 0, bbox[0] : bbox[2] + 1, bbox[1] : bbox[3] + 1 + ] + shifted_image = images[frame + 1][ + 0, bbox[0] : bbox[2] + 1, bbox[1] : bbox[3] + 1 + ] + elif len(loc) == 3: + reference_image = ( + images[frame][ + 0, + bbox[0] : bbox[3] + 1, + bbox[1] : bbox[4] + 1, + bbox[2] : bbox[5] + 1, + ], + ) + shifted_image = images[frame + 1][ + 0, + bbox[0] : bbox[3] + 1, + bbox[1] : bbox[4] + 1, + bbox[2] : bbox[5] + 1, + ] + shift, _, _ = phase_cross_correlation(reference_image, shifted_image) + node[1][NodeAttr.FLOW.value] = shift + + +def correct_edge_distance(candidate_graph: nx.DiGraph): + """This corrects for the edge distance in case the flow at a segmentation + node is available. The EdgeAttr.DISTANCE.value is set equal to + the L2 norm of (pos@t+1 - (flow + pos@t). + + + Args: + candidate_graph (nx.DiGraph): Existing candidate graph with nodes and + edges. + + Returns: + candidate_graph (nx.DiGraph): Updated candidate graph. (Edge + distance attribute is updated, by taking flow into account). + + """ + for edge in candidate_graph.edges(data=True): + in_node = candidate_graph.nodes[edge[0]] + out_node = candidate_graph.nodes[edge[1]] + dist = math.dist( + out_node[NodeAttr.POS.value], + in_node[NodeAttr.POS.value] + in_node[NodeAttr.FLOW.value], + ) + edge[2][EdgeAttr.DISTANCE.value] = dist + + return candidate_graph diff --git a/tests/test_candidate_graph/test_compute_graph.py b/tests/test_candidate_graph/test_compute_graph.py index 77ce13d..66f91a6 100644 --- a/tests/test_candidate_graph/test_compute_graph.py +++ b/tests/test_candidate_graph/test_compute_graph.py @@ -1,7 +1,7 @@ from collections import Counter import pytest -from motile_toolbox.candidate_graph import EdgeAttr, get_candidate_graph +from motile_toolbox.candidate_graph import EdgeAttr, NodeAttr, get_candidate_graph def test_graph_from_segmentation_2d(segmentation_2d, graph_2d): @@ -14,7 +14,11 @@ def test_graph_from_segmentation_2d(segmentation_2d, graph_2d): assert Counter(list(cand_graph.nodes)) == Counter(list(graph_2d.nodes)) assert Counter(list(cand_graph.edges)) == Counter(list(graph_2d.edges)) for node in cand_graph.nodes: - assert Counter(cand_graph.nodes[node]) == Counter(graph_2d.nodes[node]) + assert ( + pytest.approx(cand_graph.nodes[node][NodeAttr.POS.value], abs=0.01) + == graph_2d.nodes[node][NodeAttr.POS.value] + ) + for edge in cand_graph.edges: print(cand_graph.edges[edge]) assert ( @@ -39,8 +43,13 @@ def test_graph_from_segmentation_3d(segmentation_3d, graph_3d): ) assert Counter(list(cand_graph.nodes)) == Counter(list(graph_3d.nodes)) assert Counter(list(cand_graph.edges)) == Counter(list(graph_3d.edges)) + for node in cand_graph.nodes: - assert Counter(cand_graph.nodes[node]) == Counter(graph_3d.nodes[node]) + assert ( + pytest.approx(cand_graph.nodes[node][NodeAttr.POS.value], abs=0.01) + == graph_3d.nodes[node][NodeAttr.POS.value] + ) + for edge in cand_graph.edges: assert pytest.approx(cand_graph.edges[edge], abs=0.01) == graph_3d.edges[edge] @@ -61,9 +70,11 @@ def test_graph_from_multi_segmentation_2d( list(multi_hypothesis_graph_2d.edges) ) for node in cand_graph.nodes: - assert Counter(cand_graph.nodes[node]) == Counter( - multi_hypothesis_graph_2d.nodes[node] + assert ( + pytest.approx(cand_graph.nodes[node][NodeAttr.POS.value], abs=0.01) + == multi_hypothesis_graph_2d.nodes[node][NodeAttr.POS.value] ) + for edge in cand_graph.edges: assert ( pytest.approx(cand_graph.edges[edge][EdgeAttr.IOU.value], abs=0.01)