diff --git a/src/motile_toolbox/candidate_graph/graph_attributes.py b/src/motile_toolbox/candidate_graph/graph_attributes.py index a2927e7..f7f9a50 100644 --- a/src/motile_toolbox/candidate_graph/graph_attributes.py +++ b/src/motile_toolbox/candidate_graph/graph_attributes.py @@ -11,6 +11,7 @@ class NodeAttr(Enum): TIME = "time" SEG_ID = "seg_id" SEG_HYPO = "seg_hypo" + AREA = "area" class EdgeAttr(Enum): diff --git a/src/motile_toolbox/candidate_graph/utils.py b/src/motile_toolbox/candidate_graph/utils.py index 7775410..feebf44 100644 --- a/src/motile_toolbox/candidate_graph/utils.py +++ b/src/motile_toolbox/candidate_graph/utils.py @@ -38,9 +38,16 @@ def nodes_from_segmentation( segmentation: np.ndarray, scale: list[float] | None = None, ) -> tuple[nx.DiGraph, dict[int, list[Any]]]: - """Extract candidate nodes from a segmentation. Also computes specified attributes. - Returns a networkx graph with only nodes, and also a dictionary from frames to - node_ids for efficient edge adding. + """Extract candidate nodes from a segmentation. Returns a networkx graph + with only nodes, and also a dictionary from frames to node_ids for + efficient edge adding. + + Each node will have the following attributes (named as in NodeAttrs): + - time + - position + - segmentation id + - area + - hypothesis id (optional) Args: segmentation (np.ndarray): A numpy array with integer labels and dimensions @@ -77,9 +84,7 @@ def nodes_from_segmentation( props = regionprops(hypo, spacing=tuple(scale[1:])) for regionprop in props: node_id = get_node_id(t, regionprop.label, hypothesis_id=hypo_id) - attrs = { - NodeAttr.TIME.value: t, - } + attrs = {NodeAttr.TIME.value: t, NodeAttr.AREA.value: regionprop.area} attrs[NodeAttr.SEG_ID.value] = regionprop.label if hypo_id is not None: attrs[NodeAttr.SEG_HYPO.value] = hypo_id diff --git a/tests/conftest.py b/tests/conftest.py index b728b7e..b88eaaf 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -71,6 +71,7 @@ def graph_2d(): NodeAttr.POS.value: (50, 50), NodeAttr.TIME.value: 0, NodeAttr.SEG_ID.value: 1, + NodeAttr.AREA.value: 1245, }, ), ( @@ -79,6 +80,7 @@ def graph_2d(): NodeAttr.POS.value: (20, 80), NodeAttr.TIME.value: 1, NodeAttr.SEG_ID.value: 1, + NodeAttr.AREA.value: 305, }, ), ( @@ -87,6 +89,7 @@ def graph_2d(): NodeAttr.POS.value: (60, 45), NodeAttr.TIME.value: 1, NodeAttr.SEG_ID.value: 2, + NodeAttr.AREA.value: 697, }, ), ] @@ -110,6 +113,7 @@ def multi_hypothesis_graph_2d(): NodeAttr.TIME.value: 0, NodeAttr.SEG_HYPO.value: 0, NodeAttr.SEG_ID.value: 1, + NodeAttr.AREA.value: 1245, }, ), ( @@ -119,6 +123,7 @@ def multi_hypothesis_graph_2d(): NodeAttr.TIME.value: 0, NodeAttr.SEG_HYPO.value: 1, NodeAttr.SEG_ID.value: 1, + NodeAttr.AREA.value: 697, }, ), ( @@ -128,6 +133,7 @@ def multi_hypothesis_graph_2d(): NodeAttr.TIME.value: 1, NodeAttr.SEG_HYPO.value: 0, NodeAttr.SEG_ID.value: 1, + NodeAttr.AREA.value: 305, }, ), ( @@ -137,6 +143,7 @@ def multi_hypothesis_graph_2d(): NodeAttr.TIME.value: 1, NodeAttr.SEG_HYPO.value: 1, NodeAttr.SEG_ID.value: 1, + NodeAttr.AREA.value: 697, }, ), ( @@ -146,6 +153,7 @@ def multi_hypothesis_graph_2d(): NodeAttr.TIME.value: 1, NodeAttr.SEG_HYPO.value: 0, NodeAttr.SEG_ID.value: 2, + NodeAttr.AREA.value: 697, }, ), ( @@ -155,6 +163,7 @@ def multi_hypothesis_graph_2d(): NodeAttr.TIME.value: 1, NodeAttr.SEG_HYPO.value: 1, NodeAttr.SEG_ID.value: 2, + NodeAttr.AREA.value: 1245, }, ), ] @@ -256,6 +265,7 @@ def graph_3d(): NodeAttr.POS.value: (50, 50, 50), NodeAttr.TIME.value: 0, NodeAttr.SEG_ID.value: 1, + NodeAttr.AREA.value: 33401, }, ), ( @@ -264,6 +274,7 @@ def graph_3d(): NodeAttr.POS.value: (20, 50, 80), NodeAttr.TIME.value: 1, NodeAttr.SEG_ID.value: 1, + NodeAttr.AREA.value: 4169, }, ), ( @@ -272,6 +283,7 @@ def graph_3d(): NodeAttr.POS.value: (60, 50, 45), NodeAttr.TIME.value: 1, NodeAttr.SEG_ID.value: 2, + NodeAttr.AREA.value: 14147, }, ), ] @@ -297,6 +309,7 @@ def multi_hypothesis_graph_3d(): NodeAttr.TIME.value: 0, NodeAttr.SEG_HYPO.value: 0, NodeAttr.SEG_ID.value: 1, + NodeAttr.AREA.value: 305, }, ), ( @@ -306,6 +319,7 @@ def multi_hypothesis_graph_3d(): NodeAttr.TIME.value: 1, NodeAttr.SEG_HYPO.value: 1, NodeAttr.SEG_ID.value: 1, + NodeAttr.AREA.value: 305, }, ), ( @@ -315,6 +329,7 @@ def multi_hypothesis_graph_3d(): NodeAttr.TIME.value: 1, NodeAttr.SEG_HYPO.value: 0, NodeAttr.SEG_ID.value: 1, + NodeAttr.AREA.value: 305, }, ), ( @@ -324,6 +339,7 @@ def multi_hypothesis_graph_3d(): NodeAttr.TIME.value: 1, NodeAttr.SEG_HYPO.value: 0, NodeAttr.SEG_ID.value: 2, + NodeAttr.AREA.value: 305, }, ), ( @@ -333,6 +349,7 @@ def multi_hypothesis_graph_3d(): NodeAttr.TIME.value: 1, NodeAttr.SEG_HYPO.value: 1, NodeAttr.SEG_ID.value: 1, + NodeAttr.AREA.value: 305, }, ), ] diff --git a/tests/test_candidate_graph/test_utils.py b/tests/test_candidate_graph/test_utils.py index f03c404..a2f9462 100644 --- a/tests/test_candidate_graph/test_utils.py +++ b/tests/test_candidate_graph/test_utils.py @@ -32,6 +32,7 @@ def test_nodes_from_segmentation_2d(segmentation_2d): assert Counter(list(node_graph.nodes)) == Counter(["0_1", "1_1", "1_2"]) assert node_graph.nodes["1_1"][NodeAttr.SEG_ID.value] == 1 assert node_graph.nodes["1_1"][NodeAttr.TIME.value] == 1 + assert node_graph.nodes["1_1"][NodeAttr.AREA.value] == 305 assert node_graph.nodes["1_1"][NodeAttr.POS.value] == (20, 80) assert node_frame_dict[0] == ["0_1"] @@ -44,6 +45,7 @@ def test_nodes_from_segmentation_2d(segmentation_2d): assert Counter(list(node_graph.nodes)) == Counter(["0_1", "1_1", "1_2"]) assert node_graph.nodes["1_1"][NodeAttr.SEG_ID.value] == 1 assert node_graph.nodes["1_1"][NodeAttr.TIME.value] == 1 + assert node_graph.nodes["1_1"][NodeAttr.AREA.value] == 610 assert node_graph.nodes["1_1"][NodeAttr.POS.value] == (20, 160) assert node_frame_dict[0] == ["0_1"] @@ -63,6 +65,7 @@ def test_nodes_from_segmentation_2d_hypo( assert node_graph.nodes["1_0_1"][NodeAttr.SEG_ID.value] == 1 assert node_graph.nodes["1_0_1"][NodeAttr.SEG_HYPO.value] == 0 assert node_graph.nodes["1_0_1"][NodeAttr.TIME.value] == 1 + assert node_graph.nodes["1_0_1"][NodeAttr.AREA.value] == 305 assert node_graph.nodes["1_0_1"][NodeAttr.POS.value] == (20, 80) assert Counter(node_frame_dict[0]) == Counter(["0_0_1", "0_1_1"]) @@ -77,6 +80,7 @@ def test_nodes_from_segmentation_3d(segmentation_3d): assert Counter(list(node_graph.nodes)) == Counter(["0_1", "1_1", "1_2"]) assert node_graph.nodes["1_1"][NodeAttr.SEG_ID.value] == 1 assert node_graph.nodes["1_1"][NodeAttr.TIME.value] == 1 + assert node_graph.nodes["1_1"][NodeAttr.AREA.value] == 4169 assert node_graph.nodes["1_1"][NodeAttr.POS.value] == (20, 50, 80) assert node_frame_dict[0] == ["0_1"] @@ -88,6 +92,7 @@ def test_nodes_from_segmentation_3d(segmentation_3d): ) assert Counter(list(node_graph.nodes)) == Counter(["0_1", "1_1", "1_2"]) assert node_graph.nodes["1_1"][NodeAttr.SEG_ID.value] == 1 + assert node_graph.nodes["1_1"][NodeAttr.AREA.value] == 4169 * 4.5 assert node_graph.nodes["1_1"][NodeAttr.TIME.value] == 1 assert node_graph.nodes["1_1"][NodeAttr.POS.value] == (20.0, 225.0, 80.0)