Skip to content

Commit

Permalink
Merge pull request #21 from funkelab/add_area
Browse files Browse the repository at this point in the history
Add area attribute to candidate graph nodes
  • Loading branch information
cmalinmayor authored Sep 16, 2024
2 parents 7ebd9c1 + 6ad4d55 commit bf5c52b
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 6 deletions.
1 change: 1 addition & 0 deletions src/motile_toolbox/candidate_graph/graph_attributes.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ class NodeAttr(Enum):
TIME = "time"
SEG_ID = "seg_id"
SEG_HYPO = "seg_hypo"
AREA = "area"


class EdgeAttr(Enum):
Expand Down
17 changes: 11 additions & 6 deletions src/motile_toolbox/candidate_graph/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,16 @@ def nodes_from_segmentation(
segmentation: np.ndarray,
scale: list[float] | None = None,
) -> tuple[nx.DiGraph, dict[int, list[Any]]]:
"""Extract candidate nodes from a segmentation. Also computes specified attributes.
Returns a networkx graph with only nodes, and also a dictionary from frames to
node_ids for efficient edge adding.
"""Extract candidate nodes from a segmentation. Returns a networkx graph
with only nodes, and also a dictionary from frames to node_ids for
efficient edge adding.
Each node will have the following attributes (named as in NodeAttrs):
- time
- position
- segmentation id
- area
- hypothesis id (optional)
Args:
segmentation (np.ndarray): A numpy array with integer labels and dimensions
Expand Down Expand Up @@ -77,9 +84,7 @@ def nodes_from_segmentation(
props = regionprops(hypo, spacing=tuple(scale[1:]))
for regionprop in props:
node_id = get_node_id(t, regionprop.label, hypothesis_id=hypo_id)
attrs = {
NodeAttr.TIME.value: t,
}
attrs = {NodeAttr.TIME.value: t, NodeAttr.AREA.value: regionprop.area}
attrs[NodeAttr.SEG_ID.value] = regionprop.label
if hypo_id is not None:
attrs[NodeAttr.SEG_HYPO.value] = hypo_id
Expand Down
17 changes: 17 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ def graph_2d():
NodeAttr.POS.value: (50, 50),
NodeAttr.TIME.value: 0,
NodeAttr.SEG_ID.value: 1,
NodeAttr.AREA.value: 1245,
},
),
(
Expand All @@ -79,6 +80,7 @@ def graph_2d():
NodeAttr.POS.value: (20, 80),
NodeAttr.TIME.value: 1,
NodeAttr.SEG_ID.value: 1,
NodeAttr.AREA.value: 305,
},
),
(
Expand All @@ -87,6 +89,7 @@ def graph_2d():
NodeAttr.POS.value: (60, 45),
NodeAttr.TIME.value: 1,
NodeAttr.SEG_ID.value: 2,
NodeAttr.AREA.value: 697,
},
),
]
Expand All @@ -110,6 +113,7 @@ def multi_hypothesis_graph_2d():
NodeAttr.TIME.value: 0,
NodeAttr.SEG_HYPO.value: 0,
NodeAttr.SEG_ID.value: 1,
NodeAttr.AREA.value: 1245,
},
),
(
Expand All @@ -119,6 +123,7 @@ def multi_hypothesis_graph_2d():
NodeAttr.TIME.value: 0,
NodeAttr.SEG_HYPO.value: 1,
NodeAttr.SEG_ID.value: 1,
NodeAttr.AREA.value: 697,
},
),
(
Expand All @@ -128,6 +133,7 @@ def multi_hypothesis_graph_2d():
NodeAttr.TIME.value: 1,
NodeAttr.SEG_HYPO.value: 0,
NodeAttr.SEG_ID.value: 1,
NodeAttr.AREA.value: 305,
},
),
(
Expand All @@ -137,6 +143,7 @@ def multi_hypothesis_graph_2d():
NodeAttr.TIME.value: 1,
NodeAttr.SEG_HYPO.value: 1,
NodeAttr.SEG_ID.value: 1,
NodeAttr.AREA.value: 697,
},
),
(
Expand All @@ -146,6 +153,7 @@ def multi_hypothesis_graph_2d():
NodeAttr.TIME.value: 1,
NodeAttr.SEG_HYPO.value: 0,
NodeAttr.SEG_ID.value: 2,
NodeAttr.AREA.value: 697,
},
),
(
Expand All @@ -155,6 +163,7 @@ def multi_hypothesis_graph_2d():
NodeAttr.TIME.value: 1,
NodeAttr.SEG_HYPO.value: 1,
NodeAttr.SEG_ID.value: 2,
NodeAttr.AREA.value: 1245,
},
),
]
Expand Down Expand Up @@ -256,6 +265,7 @@ def graph_3d():
NodeAttr.POS.value: (50, 50, 50),
NodeAttr.TIME.value: 0,
NodeAttr.SEG_ID.value: 1,
NodeAttr.AREA.value: 33401,
},
),
(
Expand All @@ -264,6 +274,7 @@ def graph_3d():
NodeAttr.POS.value: (20, 50, 80),
NodeAttr.TIME.value: 1,
NodeAttr.SEG_ID.value: 1,
NodeAttr.AREA.value: 4169,
},
),
(
Expand All @@ -272,6 +283,7 @@ def graph_3d():
NodeAttr.POS.value: (60, 50, 45),
NodeAttr.TIME.value: 1,
NodeAttr.SEG_ID.value: 2,
NodeAttr.AREA.value: 14147,
},
),
]
Expand All @@ -297,6 +309,7 @@ def multi_hypothesis_graph_3d():
NodeAttr.TIME.value: 0,
NodeAttr.SEG_HYPO.value: 0,
NodeAttr.SEG_ID.value: 1,
NodeAttr.AREA.value: 305,
},
),
(
Expand All @@ -306,6 +319,7 @@ def multi_hypothesis_graph_3d():
NodeAttr.TIME.value: 1,
NodeAttr.SEG_HYPO.value: 1,
NodeAttr.SEG_ID.value: 1,
NodeAttr.AREA.value: 305,
},
),
(
Expand All @@ -315,6 +329,7 @@ def multi_hypothesis_graph_3d():
NodeAttr.TIME.value: 1,
NodeAttr.SEG_HYPO.value: 0,
NodeAttr.SEG_ID.value: 1,
NodeAttr.AREA.value: 305,
},
),
(
Expand All @@ -324,6 +339,7 @@ def multi_hypothesis_graph_3d():
NodeAttr.TIME.value: 1,
NodeAttr.SEG_HYPO.value: 0,
NodeAttr.SEG_ID.value: 2,
NodeAttr.AREA.value: 305,
},
),
(
Expand All @@ -333,6 +349,7 @@ def multi_hypothesis_graph_3d():
NodeAttr.TIME.value: 1,
NodeAttr.SEG_HYPO.value: 1,
NodeAttr.SEG_ID.value: 1,
NodeAttr.AREA.value: 305,
},
),
]
Expand Down
5 changes: 5 additions & 0 deletions tests/test_candidate_graph/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def test_nodes_from_segmentation_2d(segmentation_2d):
assert Counter(list(node_graph.nodes)) == Counter(["0_1", "1_1", "1_2"])
assert node_graph.nodes["1_1"][NodeAttr.SEG_ID.value] == 1
assert node_graph.nodes["1_1"][NodeAttr.TIME.value] == 1
assert node_graph.nodes["1_1"][NodeAttr.AREA.value] == 305
assert node_graph.nodes["1_1"][NodeAttr.POS.value] == (20, 80)

assert node_frame_dict[0] == ["0_1"]
Expand All @@ -44,6 +45,7 @@ def test_nodes_from_segmentation_2d(segmentation_2d):
assert Counter(list(node_graph.nodes)) == Counter(["0_1", "1_1", "1_2"])
assert node_graph.nodes["1_1"][NodeAttr.SEG_ID.value] == 1
assert node_graph.nodes["1_1"][NodeAttr.TIME.value] == 1
assert node_graph.nodes["1_1"][NodeAttr.AREA.value] == 610
assert node_graph.nodes["1_1"][NodeAttr.POS.value] == (20, 160)

assert node_frame_dict[0] == ["0_1"]
Expand All @@ -63,6 +65,7 @@ def test_nodes_from_segmentation_2d_hypo(
assert node_graph.nodes["1_0_1"][NodeAttr.SEG_ID.value] == 1
assert node_graph.nodes["1_0_1"][NodeAttr.SEG_HYPO.value] == 0
assert node_graph.nodes["1_0_1"][NodeAttr.TIME.value] == 1
assert node_graph.nodes["1_0_1"][NodeAttr.AREA.value] == 305
assert node_graph.nodes["1_0_1"][NodeAttr.POS.value] == (20, 80)

assert Counter(node_frame_dict[0]) == Counter(["0_0_1", "0_1_1"])
Expand All @@ -77,6 +80,7 @@ def test_nodes_from_segmentation_3d(segmentation_3d):
assert Counter(list(node_graph.nodes)) == Counter(["0_1", "1_1", "1_2"])
assert node_graph.nodes["1_1"][NodeAttr.SEG_ID.value] == 1
assert node_graph.nodes["1_1"][NodeAttr.TIME.value] == 1
assert node_graph.nodes["1_1"][NodeAttr.AREA.value] == 4169
assert node_graph.nodes["1_1"][NodeAttr.POS.value] == (20, 50, 80)

assert node_frame_dict[0] == ["0_1"]
Expand All @@ -88,6 +92,7 @@ def test_nodes_from_segmentation_3d(segmentation_3d):
)
assert Counter(list(node_graph.nodes)) == Counter(["0_1", "1_1", "1_2"])
assert node_graph.nodes["1_1"][NodeAttr.SEG_ID.value] == 1
assert node_graph.nodes["1_1"][NodeAttr.AREA.value] == 4169 * 4.5
assert node_graph.nodes["1_1"][NodeAttr.TIME.value] == 1
assert node_graph.nodes["1_1"][NodeAttr.POS.value] == (20.0, 225.0, 80.0)

Expand Down

0 comments on commit bf5c52b

Please sign in to comment.