From c4f3f8ba5d54488c7626f9da930c11922178997e Mon Sep 17 00:00:00 2001 From: Yujing Huang Date: Tue, 2 Jan 2024 17:18:18 -0500 Subject: [PATCH] add samseg/cli scripts gems_compute_atlas_probs.py, merge_add_mesh_alphas.py, sbtiv.py --- samseg/cli/gems_compute_atlas_probs.py | 238 +++++++++++++++++++++++++ samseg/cli/merge_add_mesh_alphas.py | 194 ++++++++++++++++++++ samseg/cli/sbtiv.py | 44 +++++ setup.cfg | 5 +- 4 files changed, 480 insertions(+), 1 deletion(-) create mode 100644 samseg/cli/gems_compute_atlas_probs.py create mode 100644 samseg/cli/merge_add_mesh_alphas.py create mode 100755 samseg/cli/sbtiv.py diff --git a/samseg/cli/gems_compute_atlas_probs.py b/samseg/cli/gems_compute_atlas_probs.py new file mode 100644 index 0000000..a6353d2 --- /dev/null +++ b/samseg/cli/gems_compute_atlas_probs.py @@ -0,0 +1,238 @@ +#!/usr/bin/env python + +########################### +# +# Compute mesh node probabilities (alphas) from "ground truth" segmentation images using estimated node deformations. +# Node probabilities are estimated using an expectation maximization (EM) algorithm. +# +# The script requires that SAMSEG has been run with the --history flag on the subjects of interest +# The script works with 1 or more structures. +# For estimating mesh node probabilities for more than one structure +# the --multi-structure flag should be on and a list of labels should be given as input using the flag --labels +# +########################### + +import sys +import os +import numpy as np +import nibabel as nib +import surfa as sf +import time +import argparse +import samseg +from samseg import gems + +def parseArguments(argv): + parser = argparse.ArgumentParser() + parser.add_argument('--subjects-dir', help='Directory with saved SAMSEG runs with --history flag.', required=True) + parser.add_argument('--mesh-collections', nargs='+', help='Mesh collection file(s).', required=True) + parser.add_argument('--out-dir', help='Output directory.', required=True) + parser.add_argument('--segmentations-dir', help='Directory with GT segmentations.') + parser.add_argument('--gt-from-FS', action='store_true', default=False, help='GT from FreeSurfer segmentations.') + parser.add_argument('--segmentation-name', default='aseg.mgz',help='Filename of the segmentations, assumed to be the same for each subject.') + parser.add_argument('--multi-structure', action='store_true', default=False, help="Estimate alphas from more than 1 structure.") + parser.add_argument('--labels', type=int, nargs='+', help="Labels numbers. Needs --multi-structure flag on.") + parser.add_argument('--from-samseg', action='store_true', default=False, help="SAMSEG runs obtained from command samseg instead of run_samseg.") + parser.add_argument('--EM-iterations', type=int, default=10, help="EM iterations.") + parser.add_argument('--show-figs', action='store_true', default=False, help='Show figures during run.') + parser.add_argument('--save-figs', action='store_true', default=False, help='Save rasterized prior of each subject.') + parser.add_argument('--save-average-figs', action='store_true', default=False, help='Save average rasterized prior.') + parser.add_argument('--subjects_file', help='Text file with list of subjects.') + parser.add_argument('--labels_file', help='Text file with list of labels (instead of --labels).') + parser.add_argument('--samseg-subdir', default='samseg',help='Name of samseg subdir in subject/mri folder') + + args = parser.parse_args() + + return args + + +def main(): + args = parseArguments(sys.argv[1:]) + + if not os.path.exists(args.out_dir): + os.makedirs(args.out_dir) + + if args.show_figs: + visualizer = samseg.initVisualizer(True, True) + else: + visualizer = samseg.initVisualizer(False, False) + + if args.save_figs: + import nibabel as nib + + if(args.subjects_file != None): + with open(args.subjects_file) as f: + lines = f.readlines(); + subject_list = []; + for line in lines: + subject_list.append(line.strip()); + else: + subject_list = [pathname for pathname in os.listdir(args.subjects_dir) if os.path.isdir(os.path.join(args.subjects_dir, pathname))] + subject_list.sort() + number_of_subjects = len(subject_list) + + logfile = os.path.join(args.out_dir,'gems_compute_atlas_probs.log'); + logfp = open(logfile,"w") + logfp.write("cd "+os.getcwd()+"\n"); + logfp.write(' '.join(sys.argv)+"\n") + logfp.flush(); + + outsubjectsfile = os.path.join(args.out_dir,'subjects.txt'); + with open(outsubjectsfile,"w") as f: + for subject in subject_list: + f.write(subject+"\n") + + if(args.labels_file == None): + labels = args.labels; + else: + with open(args.labels_file) as f: + lines = f.readlines(); + labels = []; + for line in lines: + for item in line.split(): + labels.append(int(item)) + + outlabelfile = os.path.join(args.out_dir,'labels.txt'); + with open(outlabelfile,"w") as f: + for label in labels: + f.write(str(label)+"\n") + + print("Labels") + print(labels) + + if args.multi_structure: + number_of_classes = len(labels) + 1 # + 1 for background + else: + number_of_classes = 2 # 1 is background + + # We need an init of the probabilistic segmentation class + # to call instance methods + atlas = samseg.ProbabilisticAtlas() + + t0 = time.time(); + for level, mesh_collection_file in enumerate(args.mesh_collections): + + print("Working on mesh collection at level " + str(level + 1)) + + # Read mesh collection + print("Loading mesh collection at: " + str(mesh_collection_file)) + mesh_collection = gems.KvlMeshCollection() + mesh_collection.read(mesh_collection_file) + + # We are interested only on the reference mesh + mesh = mesh_collection.reference_mesh + number_of_nodes = mesh.point_count + + print('Number of subjects: ' + str(len(subject_list))) + + # Define what we are interesting in, i.e., the label statistics of the structure(s) of interest + label_statistics_in_mesh_nodes = np.zeros([number_of_nodes, number_of_classes, number_of_subjects]) + + for subject_number, subject_dir in enumerate(subject_list): + + # Show progress to anyone who's watching + telapsed = (time.time()-t0)/60; + print("====================================================================") + print("") + #print("Subject number: " + str(subject_number + 1)+"/"+str(number_of_subjects) + " "+ subject_dir + " " str(telapsed)) + print("Level %d Subject %d/%d %s %6.1f" % (level+1,subject_number+1,number_of_subjects,subject_dir,telapsed)) + print("") + print("====================================================================") + logfp.write("Level %d Subject %d/%d %s %6.1f\n" % (level+1,subject_number+1,number_of_subjects,subject_dir,telapsed)) + logfp.flush(); + + # Read the manually annotated segmentation for the specific subject + if args.gt_from_FS: + segpath = os.path.join(args.segmentations_dir, subject_dir, 'mri', args.segmentation_name); + print("seg %s" % segpath); + segmentation_image = nib.load(segpath).get_fdata() + affine = nib.load(os.path.join(args.segmentations_dir, subject_dir, 'mri', args.segmentation_name)).affine + else: + segmentation_image = nib.load(os.path.join(args.segmentations_dir, subject_dir, args.segmentation_name)).get_fdata() + affine = nib.load(os.path.join(args.segmentations_dir, subject_dir, args.segmentation_name)).affine + + if args.from_samseg: + history = np.load(os.path.join(args.subjects_dir, subject_dir, 'mri', args.samseg_subdir, 'history.p'), allow_pickle=True) + else: + history = np.load(os.path.join(args.subjects_dir, subject_dir, 'history.p'), allow_pickle=True) + + # Get the node positions in image voxels + model_specifications = history['input']['modelSpecifications'] + transform_matrix = history['transform'] + transform = gems.KvlTransform(samseg.requireNumpyArray(transform_matrix)) + deformations = history['historyWithinEachMultiResolutionLevel'][level]['deformation'] + node_positions = atlas.getMesh( + mesh_collection_file, + transform, + K=model_specifications.K, + initialDeformation=deformations, + initialDeformationMeshCollectionFileName=mesh_collection_file).points + + # The image is cropped as well so the voxel coordinates + # do not exactly match with the original image, + # i.e., there's a shift. Let's undo that. + cropping = history['cropping'] + node_positions += [slc.start for slc in cropping] + + # Estimate n-class alphas representing the segmentation map, initialized with a flat prior + segmentation_map = np.zeros([segmentation_image.shape[0], segmentation_image.shape[1], segmentation_image.shape[2], + number_of_classes], np.uint16) + if args.multi_structure: + for label_number, label in enumerate(labels): + # + 1 here since we want background as first class + segmentation_map[:, :, :, label_number + 1] = (segmentation_image == label) * 65535 + # Make sure to fill what is left in background class + segmentation_map[:, :, :, 0] = 65535 - np.sum(segmentation_map[:, :, :, 1:], axis=3) + else: + segmentation_map[:, :, :, 0] = (1 - segmentation_image) * 65535 + segmentation_map[:, :, :, 1] = segmentation_image * 65535 + + mesh = mesh_collection.reference_mesh + mesh.points = node_positions + mesh.alphas = mesh.fit_alphas(segmentation_map, args.EM_iterations) + + # Show rasterized prior with updated alphas + if args.show_figs: + rasterized_prior = mesh.rasterize(segmentation_image.shape) / 65535 + rasterized_prior = rasterized_prior[:, :, :, 1:] # No need to show background + visualizer.show(images=rasterized_prior) + + # Save rasterized prior with updated alphas + if args.save_figs: + rasterized_prior = mesh.rasterize(segmentation_image.shape) / 65535 + rasterized_prior = rasterized_prior[:, :, :, 1:] # No need to save background + img = nib.Nifti1Image(rasterized_prior, affine) + nib.save(img, os.path.join(args.out_dir, "level_" + str(level + 1) + "_rasterized_prior_sub" + str(subject_number + 1))) + + # Save label statistics of subject + label_statistics_in_mesh_nodes[:, :, subject_number] = mesh.alphas.copy() + + # Show rasterized prior with alphas as mean + if args.show_figs: + mesh.alphas = np.mean(label_statistics_in_mesh_nodes, axis=2) + rasterized_prior = mesh.rasterize(segmentation_image.shape) / 65535 + rasterized_prior = rasterized_prior[:, :, :, 1:] # No need to show background + visualizer.show(images=rasterized_prior) + + # Save rasterized prior with alphas as mean + if args.save_average_figs: + + mesh.alphas = np.mean(label_statistics_in_mesh_nodes, axis=2) + rasterized_prior = mesh.rasterize(segmentation_image.shape) / 65535 + rasterized_prior = rasterized_prior[:, :, :, 1:] # No need to save background + img = nib.Nifti1Image(rasterized_prior, np.eye(4)) + nib.save(img, os.path.join(args.out_dir, "level_" + str(level + 1) + "_average_rasterized_prior")) + + # Save label statistics in a npy file + np.save(os.path.join(args.out_dir, "label_statistics_atlas_" + str(level + 1)), label_statistics_in_mesh_nodes) + + # end level loop + + logfp.write("gems_compute_atlas_probs done\n"); + logfp.close(); + print ("gems_compute_atlas_probs done"); + + + +if __name__ == '__main__': + main() diff --git a/samseg/cli/merge_add_mesh_alphas.py b/samseg/cli/merge_add_mesh_alphas.py new file mode 100644 index 0000000..f33dd6f --- /dev/null +++ b/samseg/cli/merge_add_mesh_alphas.py @@ -0,0 +1,194 @@ +#!/usr/bin/env python + +description = """Merge prior spatial probabilties (alphas) with those of an existing +SAMSEG mesh to create a new mesh/atlas. Should be run on the output of +gems_compute_atlas_probs. + +An example of use is the following: +fspython merge_add_mesh_alphas [...] --merge-labels 2 3 --add-indexes 3 + +Here we merge labels 2 and 3 and then add the third index of +estimated_alphas.npy to the mesh. For example, in case we used +gems_compute_alphas with --labels 2 3 99, we are adding label 99 to +the mesh. Indexes start from 1 since 0 is reserved for the background +class. + +I also add a --merge-indexes flag in case the merge and add indexes +are mixed. If None the first indexes are used for merging. For +example: gems_compute_alphas --labels 2 99 3, I can do --merge-labels +2 3 --merge-indexes 1 3 --add-indexes 2. + +Few things to consider: + +- First merging and then add != first add and then merging. I + implemented the former, but maybe it's better to implement both and + leave the choice to the user. Not sure what the default should be + though. + +- How we add the extra class(es) can be done in different ways. For + one class I did the following (same thing I did for MS lesion): add + the new alphas and rescale everything else with (1-newalpha). For + more than one class, the new alphas are just added and then the + alphas are re-normalized. Again, maybe it's better to expose this to + the user. + +- You still need to manually modify the compressionLookupTable.txt + file right now (i.e., adding extra entries at the end of the + file). Might be worth considering doing this automatically in the + script. + +- Last functionality to add would be to remove some structures from + the mesh, and automatically update the compressionLookupTable.txt + file. + +""" + +import sys +import os +import numpy as np +import argparse +import surfa as sf +from samseg import gems + +from samseg.io import kvlReadCompressionLookupTable + + +eps = np.finfo(float).eps + +description = """ +Merge/add prior spatial probabilities (alphas) with those of an existing +SAMSEG mesh to create a new mesh/atlas. Should be run on the output of +gems_compute_atlas_probs. +""" + +def main(): + parser = argparse.ArgumentParser(description=description) + parser.add_argument('--out-dir', help='Output directory.', required=True) + parser.add_argument('--estimated-alpha-dirs', nargs='+', help='Estimated alphas dirs', required=True) + parser.add_argument('--samseg-atlas-dir', help='Samseg original directory; default in FS/average') + parser.add_argument('--merge-labels', nargs='+', type=int, help='Labels to merge estimated alphas to.') + parser.add_argument('--merge-indexes', nargs='+', type=int, help='Indexes of estimated alphas to merge to current alphas. First index is 1 (0 is background). If None, first indexes are used.') + parser.add_argument('--merge-names', nargs='+', help='Structures name to merge estimated alphas to.') + parser.add_argument('--add-indexes', nargs='+', type=int, help='Indexes of estimated alphas to add to current alphas. First index is 1 (0 is background).') + parser.add_argument('--level', dest='levellist', type=int, nargs='+', help='Atlas level; default is 1 and 2') + args = parser.parse_args() + + # Assuming 20 subjects were used to build the atlas + samseg_subjects = 20 + + # Get the samseg atlas dir + if (args.samseg_atlas_dir != None): + samseg_atlas_dir = args.samseg_atlas_dir + else: + fsh = os.environ.get('FREESURFER_HOME') + samseg_atlas_dir = fsh + "/average/samseg/20Subjects_smoothing2_down2_smoothingForAffine2" + + if args.levellist != None: + levellist = args.levellist + else: + levellist = [1, 2] + + # Make the output dir + os.makedirs(args.out_dir, exist_ok=True) + logfile = os.path.join(args.out_dir, 'merge_add_mesh_alphas.log') + with open(logfile, "w") as f: + f.write("cd " + os.getcwd() + "\n") + f.write(' '.join(sys.argv) + "\n") + f.write(samseg_atlas_dir + "\n") + if args.merge_labels is not None: + outlabelfile = os.path.join(args.out_dir, 'merged_labels.txt') + with open(outlabelfile, "w") as f: + for label in args.merge_labels: + f.write(str(label) + "\n") + + for level in levellist: + print("level = %d =========================================" % level) + + # Read in alphas + estimated_alphas = [] + for adir in args.estimated_alpha_dirs: + afile = os.path.join(adir, 'label_statistics_atlas_%d.npy' % level) + a = np.load(afile) + print("a " + str(a.shape)) + if (len(estimated_alphas) == 0): + estimated_alphas = a + else: + estimated_alphas = np.concatenate((estimated_alphas, a), axis=2) + # endif + # endfor + print(str(estimated_alphas.shape)) + + # Retrieve all the SAMSEG related files + # Here we are making some assumptions about file names + mesh_collection_path = os.path.join(samseg_atlas_dir, "atlas_level" + str(level) + ".txt.gz") + freesurfer_labels, names, colors = kvlReadCompressionLookupTable( + os.path.join(samseg_atlas_dir, 'compressionLookupTable.txt')) + # Get also compressed labels, as they are in the same order as FreeSurfer_labels + compressed_labels = list(np.arange(0, len(freesurfer_labels))) + + # Read mesh collection + print("Loading mesh collection at: " + str(mesh_collection_path)) + mesh_collection = gems.KvlMeshCollection() + mesh_collection.read(mesh_collection_path) + + # Load reference mesh + mesh = mesh_collection.reference_mesh + alphas = mesh.alphas.copy() + + # First merge classes + if args.merge_labels is not None: + for l, merge_label in enumerate(args.merge_labels): + if args.merge_indexes is not None: + idx2 = args.merge_indexes[l] + else: + idx2 = l + 1 + idx = freesurfer_labels.index(merge_label) + alphas[:, idx] = (alphas[:, idx] * samseg_subjects + np.sum(estimated_alphas[:, idx2, :], axis=1)) \ + / (samseg_subjects + estimated_alphas.shape[2]) + if args.merge_names is not None: + for merge_name in args.merge_names: + if args.merge_indexes is not None: + idx2 = args.merge_indexes[l] + else: + idx2 = l + 1 + idx = names.index(merge_name) + alphas[:, idx] = (alphas[:, idx] * samseg_subjects + np.sum(estimated_alphas[:, idx2, :], axis=1)) \ + / (samseg_subjects + estimated_alphas.shape[2]) + # for mergnames + # endif + + # Re-normalize alphas + if args.merge_labels is not None or args.merge_names is not None: + normalizer = np.sum(alphas, axis=1) + eps + alphas = alphas / normalizer[:, None] + + # Here we add classes. This is done after merging classes (if any). + # Note that switching the order of merging and adding produces a different output + if args.add_indexes is not None: + tmp = np.zeros([alphas.shape[0], alphas.shape[1] + len(args.add_indexes)]) + tmp[:, :alphas.shape[1]] = alphas.copy() + if len(args.add_indexes) == 1: + idx = args.add_indexes[0] + # Only one class to add (every other class is lower down by a factor (1-estimated-alpha) + tmp[:, -1] = np.mean(estimated_alphas[:, idx, :], axis=1) + tmp[:, :alphas.shape[1]] *= (1 - tmp[:, -1])[:, None] + else: + # More than one class, add all the estimated alphas and re-normalize + for l, idx in enumerate(args.add_indexes): + tmp[:, alphas.shape[1] + l] = np.mean(estimated_alphas[:, idx, :], axis=1) + normalizer = np.sum(tmp, axis=1) + eps + tmp = tmp / normalizer[:, None] + + alphas = tmp + + # Add alphas in mesh + mesh.alphas = alphas + # Save mesh + mesh_collection.write(os.path.join(args.out_dir, "atlas_level" + str(level) + ".txt")) + # end loop over levels + + print("merge_add_mesh_alphas done") + + +if __name__ == '__main__': + main() diff --git a/samseg/cli/sbtiv.py b/samseg/cli/sbtiv.py new file mode 100755 index 0000000..cafc0a7 --- /dev/null +++ b/samseg/cli/sbtiv.py @@ -0,0 +1,44 @@ +#!/usr/bin/env python + +import argparse +import surfa as sf +from samseg import icv + +description = ''' +Calculates the total intracranial volume of a subject by summing individual volumes computed by samseg. \ +A file containing a list of intracranial labelnames can be provided via the '--labels' flag, but if omitted, \ +a default list is used. Labelnames must be identical to those defined in the samseg atlas. +''' + +def main(): + # parse command line args + parser = argparse.ArgumentParser(description=description) + parser.add_argument('stats', metavar='FILE', help='Volume stats input file.') + parser.add_argument('-o', '--out', metavar='FILE', help='Intracranial stats output file.') + parser.add_argument('-l', '--labels', metavar="FILE", help='File containing a list of intracranial structure labelnames to include in the calculation.') + args = parser.parse_args() + + # read in structure names and volumes from samseg stats + structures = [] + with open(args.input) as fid: + for line in fid.readlines(): + name, vol, _ = line.split(',') + _, _, name = name.split(' ') + structures.append([name.strip(), float(vol)]) + + # read in structure names that are considered intra-cranial + includeStructures = None + if args.map: + with open(args.map) as fid: includeStructures = [line.strip() for line in fid.readlines()] + + # compute intra-cranial volume + sbtiv = icv(structures, includeStructures) + + # write out and exit + print('intracranial volume: %.6f mm^3' % sbtiv) + if args.output: + with open(args.output, 'w') as fid: fid.write('# Measure Intra-Cranial, %.6f, mm^3\n' % sbtiv) + + +if __name__ == '__main__': + main() diff --git a/setup.cfg b/setup.cfg index 954b4ad..85f287a 100644 --- a/setup.cfg +++ b/setup.cfg @@ -33,7 +33,10 @@ console_scripts= computeTissueConcentrations = samseg.cli.computeTissueConcentrations:main prepareAtlasDirectory = samseg.cli.prepareAtlasDirectory:main run_samseg_long = samseg.cli.run_samseg_long:main - segment_subregions = samseg.cl.segment_subregions:main + segment_subregions = samseg.cli.segment_subregions:main + sbtiv = samseg.cli.sbtiv:main + gems_compute_atlas_probs = samseg.cli.gems_compute_atlas_probs:main + merge_add_mesh_alphas = samseg.cli.merge_add_mesh_alphas:main