
Does SkeletonTransform extract incorrect skeleton lines for multi-class segmentation? #4

Open
Yuxiang1990 opened this issue Sep 26, 2024 · 1 comment

@Yuxiang1990

Hi,
for multi-class segmentation, extracting the skeleton of each label separately may give a different result from extracting the skeleton of the binary foreground mask and then multiplying it by the label mask.

import numpy as np
import torch
from skimage.morphology import dilation, skeletonize
from batchgeneratorsv2.transforms.base.basic_transform import BasicTransform


class SkeletonTransform(BasicTransform):
    def __init__(self, do_tube: bool = True, num_classes: int = 1):
        """
        Calculates the skeleton of the segmentation (plus an optional 2 px tube around it)
        and adds it to the dict with the key "skel"
        """
        super().__init__()
        self.do_tube = do_tube
        self.num_classes = num_classes  # needed for compatibility with 3D data
        assert self.num_classes >= 1

    def apply(self, data_dict, **params):
        # Proposed variant: skeletonize each class label on its own
        seg_all = data_dict['segmentation'].numpy()
        # Add tubed skeleton GT
        seg_all_skel = np.zeros_like(seg_all, dtype=np.int16)

        for labelid in range(1, self.num_classes + 1):
            label_mask = seg_all[0] == labelid
            if label_mask.any():
                # Skeletonize only the voxels of this class
                skel = skeletonize(label_mask)
                skel = (skel > 0).astype(np.int16)
                if self.do_tube:
                    skel = dilation(dilation(skel))  # 2 px tube, as in apply_old
                seg_all_skel[0][skel > 0] = labelid

        data_dict["skel"] = torch.from_numpy(seg_all_skel)
        return data_dict

    def apply_old(self, data_dict, **params):
        # Current behavior: skeletonize the binarized foreground, then assign class labels
        seg_all = data_dict['segmentation'].numpy()
        # Add tubed skeleton GT
        bin_seg = (seg_all > 0)
        seg_all_skel = np.zeros_like(bin_seg, dtype=np.int16)

        if bin_seg[0].any():
            skel = skeletonize(bin_seg[0])
            skel = (skel > 0).astype(np.int16)
            if self.do_tube:
                skel = dilation(dilation(skel))
            # Multiply by the label map so each skeleton voxel carries its class label
            skel *= seg_all[0].astype(np.int16)
            seg_all_skel[0] = skel

        data_dict["skel"] = torch.from_numpy(seg_all_skel)
        return data_dict

@ykirchhoff
Contributor

Hi @Yuxiang1990,

you are correct that skeletonizing each label individually gives slightly different results from the binarized skeletonization we use for multi-class problems. However, the binarized version is usually what we want: it keeps the skeletons of different classes connected wherever the original segmentations were connected. Think of vessels, where you might be interested in blood flow but have different classes within your vessel tree.

Best,
Yannick
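A minimal sketch of the connectivity argument above, assuming scikit-image's `skeletonize` and SciPy's `ndimage.label` are available. The helper names (`skeleton_binary_then_label`, `skeleton_per_label`, `n_components`) and the toy two-class bar are hypothetical, and the tube dilation is left out to keep the comparison simple. Thinning preserves the topology of the binary foreground, so the binarized variant yields one connected skeleton that simply switches label at the class boundary; the per-label variant thins each class on its own, which can shorten each piece near the boundary and leave a gap between them.

```python
import numpy as np
from scipy import ndimage
from skimage.morphology import skeletonize


def skeleton_binary_then_label(seg: np.ndarray) -> np.ndarray:
    """Skeletonize the whole foreground as one binary mask, then copy class labels onto it
    (hypothetical helper mirroring the binarized strategy, without the tube)."""
    skel = skeletonize(seg > 0) > 0
    return np.where(skel, seg, 0).astype(np.int16)


def skeleton_per_label(seg: np.ndarray) -> np.ndarray:
    """Skeletonize each class mask separately (hypothetical helper mirroring the
    per-label strategy from the issue, without the tube)."""
    out = np.zeros(seg.shape, dtype=np.int16)
    for label in np.unique(seg):
        if label == 0:
            continue  # background
        out[skeletonize(seg == label) > 0] = label
    return out


def n_components(skel: np.ndarray) -> int:
    """Count connected components of the skeleton, ignoring class labels
    (8-connectivity in 2D, 26-connectivity in 3D)."""
    structure = np.ones((3,) * skel.ndim, dtype=bool)
    _, n = ndimage.label(skel > 0, structure=structure)
    return n


# Toy "vessel": a 5 px thick bar whose left half is class 1 and right half is class 2
seg = np.zeros((9, 30), dtype=np.int16)
seg[2:7, 3:15] = 1
seg[2:7, 15:27] = 2

binary_skel = skeleton_binary_then_label(seg)
per_label_skel = skeleton_per_label(seg)
print("binary, then label:", n_components(binary_skel), "component(s)")
print("per label:         ", n_components(per_label_skel), "component(s)")
```

The binarized result is guaranteed to be a single component here because the bar is one connected object. The per-label count depends on how far each class's skeleton retreats from the shared boundary, so it is worth checking on real label maps rather than assuming.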
