Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Include phase cross correlation flow #10

Open
wants to merge 11 commits into
base: main
Choose a base branch
from
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ requires = ["setuptools", "wheel"]
name = "motile_toolbox"
description = "A toolbox for tracking with (motile)[https://github.com/funkelab/motile]."
readme = "README.md"
requires-python = ">=3.7"
requires-python = ">=3.10"
classifiers = [
"Programming Language :: Python :: 3",
]
Expand All @@ -17,7 +17,7 @@ authors = [
]
dynamic = ["version"]
dependencies = [
"motile",
"motile @git+https://github.com/funkelab/motile.git",
"networkx",
"numpy",
"matplotlib",
Expand Down
2 changes: 2 additions & 0 deletions src/motile_toolbox/candidate_graph/graph_attributes.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ class NodeAttr(Enum):
TIME = "time"
SEG_ID = "seg_id"
SEG_HYPO = "seg_hypo"
BBOX = "bbox"
FLOW = "flow"


class EdgeAttr(Enum):
Expand Down
2 changes: 2 additions & 0 deletions src/motile_toolbox/candidate_graph/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,9 @@ def nodes_from_segmentation(
if hypo_id is not None:
attrs[NodeAttr.SEG_HYPO.value] = hypo_id
centroid = regionprop.centroid # [z,] y, x
bbox = regionprop.bbox
attrs[NodeAttr.POS.value] = centroid
attrs[NodeAttr.BBOX.value] = bbox
cand_graph.add_node(node_id, **attrs)
nodes_in_frame.append(node_id)
if nodes_in_frame:
Expand Down
78 changes: 78 additions & 0 deletions src/motile_toolbox/utils/flow.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
import math

import networkx as nx
import numpy as np
from skimage.registration import phase_cross_correlation

from motile_toolbox.candidate_graph.graph_attributes import EdgeAttr, NodeAttr


def compute_pcc_flow(candidate_graph: nx.DiGraph, images: np.ndarray):
"""This calculates the flow using phase cross correlation
for the image cropped around an object
at `t` and the same region of interest at `t+1`,
and updates the `NodeAttr.FLOW`.

Args:
candidate_graph (nx.DiGraph): Existing candidate graph with nodes.

images (np.ndarray): Raw images (t, c, [z], y, x).

"""
for node in candidate_graph.nodes(data=True):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

unpack into node_id, data would be nicer to read

frame = node[1][NodeAttr.TIME.value]
if frame + 1 >= len(images):
continue
loc = node[1][NodeAttr.POS.value]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can get rid of loc and use the size of the bbox to infer the number of dimensions

bbox = node[1][NodeAttr.BBOX.value]
if len(loc) == 2:
reference_image = images[frame][
0, bbox[0] : bbox[2] + 1, bbox[1] : bbox[3] + 1
]
shifted_image = images[frame + 1][
0, bbox[0] : bbox[2] + 1, bbox[1] : bbox[3] + 1
]
elif len(loc) == 3:
reference_image = (
images[frame][
0,
bbox[0] : bbox[3] + 1,
bbox[1] : bbox[4] + 1,
bbox[2] : bbox[5] + 1,
],
)
shifted_image = images[frame + 1][
0,
bbox[0] : bbox[3] + 1,
bbox[1] : bbox[4] + 1,
bbox[2] : bbox[5] + 1,
]
shift, _, _ = phase_cross_correlation(reference_image, shifted_image)
node[1][NodeAttr.FLOW.value] = shift
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What does this represent? Is it a single number? A vector? A vector at each pixel in the input image?



def correct_edge_distance(candidate_graph: nx.DiGraph):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

First, I strongly don't think this should overwrite distance, it needs its own attribute. Second, Jan and I removed the distance attribute (and updated Motile.EdgeDistance cost) 😆 so now this feature definitely needs its own attribute!

"""This corrects for the edge distance in case the flow at a segmentation
node is available. The EdgeAttr.DISTANCE.value is set equal to
the L2 norm of (pos@t+1 - (flow + pos@t).


Args:
candidate_graph (nx.DiGraph): Existing candidate graph with nodes and
edges.

Returns:
candidate_graph (nx.DiGraph): Updated candidate graph. (Edge
distance attribute is updated, by taking flow into account).

"""
for edge in candidate_graph.edges(data=True):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

again, I prefer unpacking the tuple here rather than using [2] later

in_node = candidate_graph.nodes[edge[0]]
out_node = candidate_graph.nodes[edge[1]]
dist = math.dist(
out_node[NodeAttr.POS.value],
in_node[NodeAttr.POS.value] + in_node[NodeAttr.FLOW.value],
)
edge[2][EdgeAttr.DISTANCE.value] = dist

return candidate_graph
21 changes: 16 additions & 5 deletions tests/test_candidate_graph/test_compute_graph.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from collections import Counter

import pytest
from motile_toolbox.candidate_graph import EdgeAttr, get_candidate_graph
from motile_toolbox.candidate_graph import EdgeAttr, NodeAttr, get_candidate_graph


def test_graph_from_segmentation_2d(segmentation_2d, graph_2d):
Expand All @@ -14,7 +14,11 @@ def test_graph_from_segmentation_2d(segmentation_2d, graph_2d):
assert Counter(list(cand_graph.nodes)) == Counter(list(graph_2d.nodes))
assert Counter(list(cand_graph.edges)) == Counter(list(graph_2d.edges))
for node in cand_graph.nodes:
assert Counter(cand_graph.nodes[node]) == Counter(graph_2d.nodes[node])
assert (
pytest.approx(cand_graph.nodes[node][NodeAttr.POS.value], abs=0.01)
== graph_2d.nodes[node][NodeAttr.POS.value]
)

for edge in cand_graph.edges:
print(cand_graph.edges[edge])
assert (
Expand All @@ -39,8 +43,13 @@ def test_graph_from_segmentation_3d(segmentation_3d, graph_3d):
)
assert Counter(list(cand_graph.nodes)) == Counter(list(graph_3d.nodes))
assert Counter(list(cand_graph.edges)) == Counter(list(graph_3d.edges))

for node in cand_graph.nodes:
assert Counter(cand_graph.nodes[node]) == Counter(graph_3d.nodes[node])
assert (
pytest.approx(cand_graph.nodes[node][NodeAttr.POS.value], abs=0.01)
== graph_3d.nodes[node][NodeAttr.POS.value]
)

for edge in cand_graph.edges:
assert pytest.approx(cand_graph.edges[edge], abs=0.01) == graph_3d.edges[edge]

Expand All @@ -61,9 +70,11 @@ def test_graph_from_multi_segmentation_2d(
list(multi_hypothesis_graph_2d.edges)
)
for node in cand_graph.nodes:
assert Counter(cand_graph.nodes[node]) == Counter(
multi_hypothesis_graph_2d.nodes[node]
assert (
pytest.approx(cand_graph.nodes[node][NodeAttr.POS.value], abs=0.01)
== multi_hypothesis_graph_2d.nodes[node][NodeAttr.POS.value]
)

for edge in cand_graph.edges:
assert (
pytest.approx(cand_graph.edges[edge][EdgeAttr.IOU.value], abs=0.01)
Expand Down
Loading