-
Notifications
You must be signed in to change notification settings - Fork 1
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Include phase cross correlation flow #10
base: main
Are you sure you want to change the base?
Changes from 8 commits
f75e60d
6c27a49
210371f
5eed914
faa5f5f
77d2a85
7ed84bb
f31144b
55847af
e480586
570c976
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
import math | ||
|
||
import networkx as nx | ||
import numpy as np | ||
from skimage.registration import phase_cross_correlation | ||
|
||
from motile_toolbox.candidate_graph.graph_attributes import EdgeAttr, NodeAttr | ||
|
||
|
||
def compute_pcc_flow(candidate_graph: nx.DiGraph, images: np.ndarray): | ||
"""This calculates the flow using phase cross correlation | ||
for the image cropped around an object | ||
at `t` and the same region of interest at `t+1`, | ||
and updates the `NodeAttr.FLOW`. | ||
|
||
Args: | ||
candidate_graph (nx.DiGraph): Existing candidate graph with nodes. | ||
|
||
images (np.ndarray): Raw images (t, c, [z], y, x). | ||
|
||
""" | ||
for node in candidate_graph.nodes(data=True): | ||
frame = node[1][NodeAttr.TIME.value] | ||
if frame + 1 >= len(images): | ||
continue | ||
loc = node[1][NodeAttr.POS.value] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can get rid of loc and use the size of the bbox to infer the number of dimensions |
||
bbox = node[1][NodeAttr.BBOX.value] | ||
if len(loc) == 2: | ||
reference_image = images[frame][ | ||
0, bbox[0] : bbox[2] + 1, bbox[1] : bbox[3] + 1 | ||
] | ||
shifted_image = images[frame + 1][ | ||
0, bbox[0] : bbox[2] + 1, bbox[1] : bbox[3] + 1 | ||
] | ||
elif len(loc) == 3: | ||
reference_image = ( | ||
images[frame][ | ||
0, | ||
bbox[0] : bbox[3] + 1, | ||
bbox[1] : bbox[4] + 1, | ||
bbox[2] : bbox[5] + 1, | ||
], | ||
) | ||
shifted_image = images[frame + 1][ | ||
0, | ||
bbox[0] : bbox[3] + 1, | ||
bbox[1] : bbox[4] + 1, | ||
bbox[2] : bbox[5] + 1, | ||
] | ||
shift, _, _ = phase_cross_correlation(reference_image, shifted_image) | ||
node[1][NodeAttr.FLOW.value] = shift | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What does this represent? Is it a single number? A vector? A vector at each pixel in the input image? |
||
|
||
|
||
def correct_edge_distance(candidate_graph: nx.DiGraph): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. First, I strongly don't think this should overwrite distance, it needs its own attribute. Second, Jan and I removed the distance attribute (and updated Motile.EdgeDistance cost) 😆 so now this feature definitely needs its own attribute! |
||
"""This corrects for the edge distance in case the flow at a segmentation | ||
node is available. The EdgeAttr.DISTANCE.value is set equal to | ||
the L2 norm of (pos@t+1 - (flow + pos@t). | ||
|
||
|
||
Args: | ||
candidate_graph (nx.DiGraph): Existing candidate graph with nodes and | ||
edges. | ||
|
||
Returns: | ||
candidate_graph (nx.DiGraph): Updated candidate graph. (Edge | ||
distance attribute is updated, by taking flow into account). | ||
|
||
""" | ||
for edge in candidate_graph.edges(data=True): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. again, I prefer unpacking the tuple here rather than using [2] later |
||
in_node = candidate_graph.nodes[edge[0]] | ||
out_node = candidate_graph.nodes[edge[1]] | ||
dist = math.dist( | ||
out_node[NodeAttr.POS.value], | ||
in_node[NodeAttr.POS.value] + in_node[NodeAttr.FLOW.value], | ||
) | ||
edge[2][EdgeAttr.DISTANCE.value] = dist | ||
|
||
return candidate_graph |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
unpack into node_id, data would be nicer to read