From c953db6648073112a2bd30507167ef96bc125771 Mon Sep 17 00:00:00 2001 From: Yujing Huang Date: Tue, 2 Jan 2024 13:13:41 -0500 Subject: [PATCH 1/4] fix collapsed label handling (freesurfer/freesurfer commit db3bdc7) 1. fix the logic of building collapsed label map: for any labels that are in collapsed label list if the label is not in m_CompressionLookupTable assign new class number, add it to m_CompressionLookupTable add the class number to collapsed label map else retrieve its assigned class number add the class number to collapsed label map 2. add unittest for kvlCompressionLookupTable --- gems/Testing/CMakeLists.txt | 5 +- gems/Testing/testCompressionLookupTable.cxx | 62 +++++++++++++++++++++ gems/kvlCompressionLookupTable.cxx | 27 ++++++--- gems/kvlCompressionLookupTable.h | 10 ++++ 4 files changed, 96 insertions(+), 8 deletions(-) create mode 100644 gems/Testing/testCompressionLookupTable.cxx diff --git a/gems/Testing/CMakeLists.txt b/gems/Testing/CMakeLists.txt index 8ef916b..39e5239 100644 --- a/gems/Testing/CMakeLists.txt +++ b/gems/Testing/CMakeLists.txt @@ -8,7 +8,10 @@ endif() add_executable(kvlAtlasMeshRasterizorTestGPU kvlAtlasMeshRasterizorTestGPU.cxx) target_link_libraries(kvlAtlasMeshRasterizorTestGPU kvlGEMSCommon) -# +# testCompressionLookupTable +add_executable(testCompressionLookupTable testCompressionLookupTable.cxx) +target_link_libraries(testCompressionLookupTable kvlGEMSCommon) + add_executable(testdeformMeshPython testdeformMeshPython.cxx) target_link_libraries(testdeformMeshPython kvlGEMSCommon) diff --git a/gems/Testing/testCompressionLookupTable.cxx b/gems/Testing/testCompressionLookupTable.cxx new file mode 100644 index 0000000..56a9be8 --- /dev/null +++ b/gems/Testing/testCompressionLookupTable.cxx @@ -0,0 +1,62 @@ +#include "itkImageFileReader.h" +#include "itkMGHImageIOFactory.h" +#include "kvlCompressionLookupTable.h" + +int main(int argc, char **argv) +{ + // Add support for MGH file format to ITK. An alternative way to add this by default would be + // to edit ITK's itkImageIOFactory.cxx and explicitly adding it in the code there. + itk::ObjectFactoryBase::RegisterFactory( itk::MGHImageIOFactory::New() ); + + // Read the input images + typedef kvl::CompressionLookupTable::ImageType LabelImageType; + std::vector< LabelImageType::ConstPointer > labelImages; + for ( int argumentNumber = 1; argumentNumber < argc; argumentNumber++ ) + { + std::cout << "Reading input image: " << argv[ argumentNumber ] << std::endl; + // Read the input image + typedef itk::ImageFileReader< LabelImageType > ReaderType; + ReaderType::Pointer reader = ReaderType::New(); + reader->SetFileName( argv[ argumentNumber ] ); + reader->Update(); + LabelImageType::ConstPointer labelImage = reader->GetOutput(); + + // Over-ride the spacing and origin since at this point we can't deal with that + const double spacing[] = { 1, 1, 1 }; + const double origin[] = { 0, 0, 0 }; + const_cast< LabelImageType* >( labelImage.GetPointer() )->SetSpacing( spacing ); + const_cast< LabelImageType* >( labelImage.GetPointer() )->SetOrigin( origin ); + + // Remember this image + labelImages.push_back( labelImage ); + } + + // Build a lookup table that maps the original intensities onto class numbers starting + // at 0 and densely packed + kvl::CompressionLookupTable::Pointer lookupTable = kvl::CompressionLookupTable::New(); + lookupTable->Construct( labelImages ); + lookupTable->Write( "compressionLookupTable.txt" ); + + /*********** START OF SIMULATION ***********/ + printf("\nClasses Contributed To Label Cost Calculations:\n"); + // loop through each label, + // report which classes will contribute to cost/gradient/likelihood/color calculation + std::vector labels = lookupTable->GetLabels(); + std::vector::const_iterator labelIt; + for (labelIt = labels.begin(); labelIt != labels.end(); labelIt++) + { + printf("label %5d <= ", *labelIt); + const std::vector< int >& classNumbers = lookupTable->GetClassNumbers(*labelIt); + for ( std::vector< int >::const_iterator classIt = classNumbers.begin(); + classIt != classNumbers.end(); + ++classIt ) + printf(" %3d, ", *classIt); + + printf("\n"); + } + + printf("\nTotal # of Labels: %d\n", labels.size()); + printf("Total # of Classes: %d\n", lookupTable->GetNumberOfClasses()); + + exit(0); +} diff --git a/gems/kvlCompressionLookupTable.cxx b/gems/kvlCompressionLookupTable.cxx index 261e4be..d0b7014 100755 --- a/gems/kvlCompressionLookupTable.cxx +++ b/gems/kvlCompressionLookupTable.cxx @@ -108,6 +108,8 @@ ::Construct( const std::vector< ImageType::ConstPointer >& images ) m_LabelStringLookupTable.clear(); m_ColorLookupTable.clear(); m_NumberOfClasses = 0; + + printf("\n"); // Loop over all collapsed labels, if any, adding a new class for // each real label that is encountered while also pointing the @@ -128,25 +130,35 @@ ::Construct( const std::vector< ImageType::ConstPointer >& images ) collapsedIt != collapsedLabels.end(); ++collapsedIt ) { m_CompressionLookupTable[ collapsedIt->first ] = std::vector< int >(); + printf("Collapsed Label %-5d : \n", collapsedIt->first); + // loop through the Collapsed Label list, add any labels not in m_CompressionLookupTable yet for ( std::vector< int >::const_iterator labelIt = collapsedIt->second.begin(); labelIt != collapsedIt->second.end(); ++labelIt ) { - if ( m_CompressionLookupTable.find( *labelIt ) == m_CompressionLookupTable.end() ) - { - std::cout << "Encountered new real label " << *labelIt << std::endl; - + CompressionLookupTableType::const_iterator compressionIt = m_CompressionLookupTable.find( *labelIt ); + if ( compressionIt == m_CompressionLookupTable.end() ) //if ( m_CompressionLookupTable.find( *labelIt ) == m_CompressionLookupTable.end() ) + { const int newClassNumber = m_NumberOfClasses; m_CompressionLookupTable[ *labelIt ] = std::vector< int >( 1, newClassNumber ); m_CompressionLookupTable[ collapsedIt->first ].push_back( newClassNumber ); + printf("\t Encountered new real label (collapsed label list): %-5d -> %-3d\n", *labelIt, newClassNumber); + //std::cout << "Encountered new real label (collapsedLabels.txt): " << *labelIt << ", assigned class " << newClassNumber << std::endl; m_NumberOfClasses++; } + else + { + // the label is in m_CompressionLookupTable already, find the class assigned to it + const int newClassNumber = compressionIt->second[ 0 ]; + m_CompressionLookupTable[ collapsedIt->first ].push_back( newClassNumber ); + printf("\t (collapsed label list): %-5d -> %-3d\n", *labelIt, newClassNumber); + } } // End loop over all real labels belonging to a certain collapsed label } // End loop over all collapsed labels - // Also loop over all pixels of all images, and create a new entry for every new intensity encountered + // Also loop over all pixels of all images, and create a new entry for any labels not in m_CompressionLookupTable yet for ( std::vector< ImageType::ConstPointer >::const_iterator it = images.begin(); it != images.end(); ++it ) { @@ -157,10 +169,11 @@ ::Construct( const std::vector< ImageType::ConstPointer >& images ) { if ( m_CompressionLookupTable.find( voxelIt.Get() ) == m_CompressionLookupTable.end() ) { - std::cout << "Encountered new real label " << voxelIt.Get() << std::endl; - + // label not in m_CompressionLookupTable yet, add it const int newClassNumber = m_NumberOfClasses; m_CompressionLookupTable[ voxelIt.Get() ] = std::vector< int >( 1, newClassNumber ); + printf("\t Encountered new real label (image) %-5d -> %-3d\n" , voxelIt.Get(), newClassNumber); + //std::cout << "Encountered new real label (image) " << voxelIt.Get() << ", assigned class " << newClassNumber << std::endl; m_NumberOfClasses++; } } // End loop over all pixels diff --git a/gems/kvlCompressionLookupTable.h b/gems/kvlCompressionLookupTable.h index 535a424..e51088b 100755 --- a/gems/kvlCompressionLookupTable.h +++ b/gems/kvlCompressionLookupTable.h @@ -49,6 +49,16 @@ public : { return m_CompressionLookupTable.find( label )->second; } + + const std::vector GetLabels() const + { + std::vector labels; + CompressionLookupTableType::const_iterator it; + for (it = m_CompressionLookupTable.begin(); it != m_CompressionLookupTable.end(); ++it) + labels.push_back(it->first); + + return labels; + } const ColorType& GetColor( int classNumber ) const { From 3ca979f70e8e5c28e48f7f5e1d432990dab1c527 Mon Sep 17 00:00:00 2001 From: Yujing Huang Date: Tue, 2 Jan 2024 14:40:04 -0500 Subject: [PATCH 2/4] merge in Doug's changes from freesurfer/freesurfer/python/gems --- samseg/Samseg.py | 95 +++++++++++++++++++++++++++++++-------------- samseg/utilities.py | 12 +++++- 2 files changed, 76 insertions(+), 31 deletions(-) diff --git a/samseg/Samseg.py b/samseg/Samseg.py index 640b3ca..c1975c5 100644 --- a/samseg/Samseg.py +++ b/samseg/Samseg.py @@ -4,16 +4,16 @@ import pickle import scipy.io import surfa as sf -from scipy.ndimage import binary_dilation as dilation +from scipy.ndimage.morphology import binary_dilation as dilation -from samseg import gems -from .utilities import Specification -from .BiasField import BiasField -from .ProbabilisticAtlas import ProbabilisticAtlas -from .GMM import GMM -from .Affine import Affine -from .SamsegUtility import * -from .merge_alphas import kvlMergeAlphas, kvlGetMergingFractionsTable +import gems +from gems.utilities import Specification +from gems.BiasField import BiasField +from gems.ProbabilisticAtlas import ProbabilisticAtlas +from gems.GMM import GMM +from gems.Affine import Affine +from gems.SamsegUtility import * +from gems.merge_alphas import kvlMergeAlphas, kvlGetMergingFractionsTable eps = np.finfo(float).eps @@ -74,10 +74,10 @@ def __init__(self, raise ValueError('In photo mode, you cannot provide more than one input image volume') input_vol = sf.load_volume(self.imageFileNames[0]) - input_vol.save(os.path.join(self.savePath, 'original_input.mgz')) + input_vol.save(self.savePath + '/original_input.mgz') if input_vol.nframes > 1: input_vol = input_vol.mean(frames=True) - input_vol.save(os.path.join(self.savePath, 'grayscale_input.mgz')) + input_vol.save(self.savePath + '/grayscale_input.mgz') # We also a small band of noise around the mask; otherwise the background/skull/etc may fit the cortex self.photo_mask = input_vol.data > 0 @@ -87,14 +87,14 @@ def __init__(self, rng = np.random.default_rng(2021) input_vol.data[ring] = max_noise * rng.random(input_vol.data[ring].shape[0]) self.imageFileNames = [] - self.imageFileNames.append(os.path.join(self.savePath,'grayscale_input_with_ring.mgz')) + self.imageFileNames.append(self.savePath + '/grayscale_input_with_ring.mgz') input_vol.save(self.imageFileNames[0]) # Initialize some objects self.affine = Affine( imageFileName=self.imageFileNames[0], meshCollectionFileName=os.path.join(self.atlasDir, 'atlasForAffineRegistration.txt.gz'), - templateFileName=os.path.join(self.atlasDir, 'template.nii.gz' ) ) + templateFileName=os.path.join(self.atlasDir, 'template.nii' ) ) self.probabilisticAtlas = ProbabilisticAtlas() # Get full model specifications and optimization options (using default unless overridden by user) @@ -102,11 +102,11 @@ def __init__(self, self.optimizationOptions = getOptimizationOptions(atlasDir, userOptimizationOptions) if dissectionPhoto and (gmmFileName is None): if dissectionPhoto == 'left': - gmmFileName = os.path.join(self.atlasDir, 'photo.lh.sharedGMMParameters.txt') + gmmFileName = self.atlasDir + '/photo.lh.sharedGMMParameters.txt' elif dissectionPhoto == 'right': - gmmFileName = os.path.join(self.atlasDir, 'photo.rh.sharedGMMParameters.txt') + gmmFileName = self.atlasDir + '/photo.rh.sharedGMMParameters.txt' elif dissectionPhoto == 'both': - gmmFileName = os.path.join(self.atlasDir, 'photo.both.sharedGMMParameters.txt') + gmmFileName = self.atlasDir + '/photo.both.sharedGMMParameters.txt' else: sf.system.fatal('dissection photo mode must be left, right, or both') self.modelSpecifications = getModelSpecifications( @@ -229,16 +229,16 @@ def segment(self, costfile=None, timer=None, reg_only=False, transformFile=None, if self.imageToImageTransformMatrix is None: if self.dissectionPhoto is not None: - reference = os.path.join(self.savePath, 'grayscale_input.mgz') + reference = self.savePath + '/grayscale_input.mgz' if self.dissectionPhoto=='left': - moving = os.path.join(self.atlasDir, 'exvivo.template.lh.suptent.nii') + moving = self.atlasDir + '/exvivo.template.lh.suptent.nii' elif self.dissectionPhoto=='right': - moving = os.path.join(self.atlasDir, 'exvivo.template.rh.suptent.nii') + moving = self.atlasDir + '/exvivo.template.rh.suptent.nii' elif self.dissectionPhoto=='both': - moving = os.path.join(self.atlasDir, 'exvivo.template.suptent.nii') + moving = self.atlasDir + '/exvivo.template.suptent.nii' else: sf.system.fatal('dissection photo mode must be left, right, or both') - transformFile = os.path.join(self.savePath, 'atlas2image.lta') + transformFile = self.savePath + '/atlas2image.lta' cmd = 'mri_coreg --seed 2021 --mov ' + moving + ' --ref ' + reference + ' --reg ' + transformFile + \ ' --dof 12 --threads ' + str(self.nthreads) os.system(cmd) @@ -318,7 +318,7 @@ def preProcess(self): else: self.imageBuffers, self.transform, self.voxelSpacing, self.cropping = readCroppedImages( self.imageFileNames, - os.path.join(self.atlasDir, 'template.nii.gz'), + os.path.join(self.atlasDir, 'template.nii'), self.imageToImageTransformMatrix ) @@ -479,11 +479,11 @@ def writeResults(self, biasFields, posteriors): print(self.scalingFactors[contrastNumber], file=fid) else: # photos - self.writeImage(expBiasFields[..., 0], os.path.join(self.savePath, 'illlumination_field.mgz')) + self.writeImage(expBiasFields[..., 0], self.savePath + '/illlumination_field.mgz') original_vol = sf.load_volume(self.originalImageFileNames[0]) - bias_native = sf.load_volume(os.path.join(self.savePath, 'illlumination_field.mgz')) + bias_native = sf.load_volume(self.savePath + '/illlumination_field.mgz') original_vol = original_vol / (1e-6 + bias_native) - original_vol.save(os.path.join(self.savePath, 'illlumination_corrected.mgz')) + original_vol.save(self.savePath + '/illlumination_corrected.mgz') if self.savePosteriors: posteriorPath = os.path.join(self.savePath, 'posteriors') @@ -502,15 +502,50 @@ def writeResults(self, biasFields, posteriors): volumeOfOneVoxel = np.abs(np.linalg.det(exampleImage.transform_matrix.as_numpy_array[:3, :3])) volumesInCubicMm = np.sum(posteriors, axis=0) * volumeOfOneVoxel + # Write intracranial volume + sbtiv = icv(zip(*[names, volumesInCubicMm])) + with open(os.path.join(self.savePath, 'sbtiv.stats'), 'w') as fid: + fid.write('# Measure Intra-Cranial, %.6f, mm^3\n' % sbtiv) + # Write structural volumes with open(os.path.join(self.savePath, 'samseg.stats'), 'w') as fid: + fid.write('# Measure %s, %.6f, mm^3\n' % ('Intra-Cranial', sbtiv)) for volume, name in zip(volumesInCubicMm, names): fid.write('# Measure %s, %.6f, mm^3\n' % (name, volume)) - # Write intracranial volume - sbtiv = icv(zip(*[names, volumesInCubicMm])) - with open(os.path.join(self.savePath, 'sbtiv.stats'), 'w') as fid: - fid.write('# Measure Intra-Cranial, %.6f, mm^3\n' % sbtiv) + # Write structural volumes in a csv + with open(os.path.join(self.savePath, 'samseg.csv'), 'w') as fid: + fid.write('ROI,volume_mm3,volume_ICV_x1000\n'); + fid.write('%s,%.6f,1000\n' % ('Intra-Cranial', sbtiv)) + for volume, name in zip(volumesInCubicMm, names): + fid.write('%s,%.6f,%.6f\n' % (name, volume,1000*volume/sbtiv)) + + # Write out a freesurfer-style stats file (good for use with asegstats2table) + with open(os.path.join(self.savePath, 'samseg.fs.stats'), 'w') as fid: + fid.write('# Measure EstimatedTotalIntraCranialVol, eTIV, Estimated Total Intracranial Volume, %0.6f, mm^3\n'%(sbtiv)); + # Could add other measures here like total brain volume + fid.write('# ColHeaders Index SegId NVoxels Volume_mm3 StructName\n'); + k = 0; + voxsize = self.voxelSpacing[0]*self.voxelSpacing[1]*self.voxelSpacing[2]; + # Sort them by seg index like with the aseg.stats. Exclude Unknown, but keep WM and Cortex + seglist = []; + k = 0; + for volume, name in zip(volumesInCubicMm, names): + if(name == 'Unknown'): + k = k+1; + continue + idx = fslabels[k]; + seglist.append([idx,volume,name]); + k = k+1; + seglist.sort() + k = 0; + for seg in seglist: + idx = seg[0]; + volume = seg[1]; + name = seg[2]; + nvox = round(volume/voxsize,2); + fid.write('%3d %4d %7d %9.1f %s\n' % (k+1,idx,nvox,volume,name)); + k = k+1; return volumesInCubicMm @@ -525,7 +560,7 @@ def saveWarpField(self, filename): # extract geometries source = sf.load_volume(self.imageFileNames[0]).geom - target = sf.load_volume(os.path.join(self.atlasDir, 'template.nii.gz')).geom + target = sf.load_volume(os.path.join(self.atlasDir, 'template.nii')).geom # extract vox-to-vox template transform # TODO: Grabbing the transform from the saved .mat file in either the cross or base diff --git a/samseg/utilities.py b/samseg/utilities.py index 577e770..cf550f4 100644 --- a/samseg/utilities.py +++ b/samseg/utilities.py @@ -105,6 +105,16 @@ def icv(structures, includeStructures=None): 'Left-Thalamus', 'non-WM-hypointensities', '5th-Ventricle', - 'Lesions' + 'Lesions', + 'Left-WMCrowns', + 'Right-WMCrowns', + 'Left-Vermis-Area', + 'Right-Vermis-Area', + 'Corpus_Callosum', + 'Pons', + 'Pons-Belly-Area', + 'Vein', + 'ctx_lh_high_myelin', + 'ctx_rh_high_myelin', ] return sum(structure[1] for structure in structures if structure[0] in includeStructures) From 60df5897faa01d0658f50c4ef8a8b20322dc8f24 Mon Sep 17 00:00:00 2001 From: Yujing Huang Date: Tue, 2 Jan 2024 15:03:34 -0500 Subject: [PATCH 3/4] merge in Doug's changes from freesurfer/freesurfer/python/gems apply standalone samseg fixes --- samseg/Samseg.py | 52 ++++++++++++++++++++++++------------------------ 1 file changed, 26 insertions(+), 26 deletions(-) diff --git a/samseg/Samseg.py b/samseg/Samseg.py index c1975c5..c2dbbfa 100644 --- a/samseg/Samseg.py +++ b/samseg/Samseg.py @@ -4,16 +4,16 @@ import pickle import scipy.io import surfa as sf -from scipy.ndimage.morphology import binary_dilation as dilation +from scipy.ndimage import binary_dilation as dilation -import gems -from gems.utilities import Specification -from gems.BiasField import BiasField -from gems.ProbabilisticAtlas import ProbabilisticAtlas -from gems.GMM import GMM -from gems.Affine import Affine -from gems.SamsegUtility import * -from gems.merge_alphas import kvlMergeAlphas, kvlGetMergingFractionsTable +from samseg import gems +from .utilities import Specification +from .BiasField import BiasField +from .ProbabilisticAtlas import ProbabilisticAtlas +from .GMM import GMM +from .Affine import Affine +from .SamsegUtility import * +from .merge_alphas import kvlMergeAlphas, kvlGetMergingFractionsTable eps = np.finfo(float).eps @@ -74,10 +74,10 @@ def __init__(self, raise ValueError('In photo mode, you cannot provide more than one input image volume') input_vol = sf.load_volume(self.imageFileNames[0]) - input_vol.save(self.savePath + '/original_input.mgz') + input_vol.save(os.path.join(self.savePath, 'original_input.mgz')) if input_vol.nframes > 1: input_vol = input_vol.mean(frames=True) - input_vol.save(self.savePath + '/grayscale_input.mgz') + input_vol.save(os.path.join(self.savePath, 'grayscale_input.mgz')) # We also a small band of noise around the mask; otherwise the background/skull/etc may fit the cortex self.photo_mask = input_vol.data > 0 @@ -87,14 +87,14 @@ def __init__(self, rng = np.random.default_rng(2021) input_vol.data[ring] = max_noise * rng.random(input_vol.data[ring].shape[0]) self.imageFileNames = [] - self.imageFileNames.append(self.savePath + '/grayscale_input_with_ring.mgz') + self.imageFileNames.append(os.path.join(self.savePath,'grayscale_input_with_ring.mgz')) input_vol.save(self.imageFileNames[0]) # Initialize some objects self.affine = Affine( imageFileName=self.imageFileNames[0], meshCollectionFileName=os.path.join(self.atlasDir, 'atlasForAffineRegistration.txt.gz'), - templateFileName=os.path.join(self.atlasDir, 'template.nii' ) ) + templateFileName=os.path.join(self.atlasDir, 'template.nii.gz' ) ) self.probabilisticAtlas = ProbabilisticAtlas() # Get full model specifications and optimization options (using default unless overridden by user) @@ -102,11 +102,11 @@ def __init__(self, self.optimizationOptions = getOptimizationOptions(atlasDir, userOptimizationOptions) if dissectionPhoto and (gmmFileName is None): if dissectionPhoto == 'left': - gmmFileName = self.atlasDir + '/photo.lh.sharedGMMParameters.txt' + gmmFileName = os.path.join(self.atlasDir, 'photo.lh.sharedGMMParameters.txt') elif dissectionPhoto == 'right': - gmmFileName = self.atlasDir + '/photo.rh.sharedGMMParameters.txt' + gmmFileName = os.path.join(self.atlasDir, 'photo.rh.sharedGMMParameters.txt') elif dissectionPhoto == 'both': - gmmFileName = self.atlasDir + '/photo.both.sharedGMMParameters.txt' + gmmFileName = os.path.join(self.atlasDir, 'photo.both.sharedGMMParameters.txt') else: sf.system.fatal('dissection photo mode must be left, right, or both') self.modelSpecifications = getModelSpecifications( @@ -229,16 +229,16 @@ def segment(self, costfile=None, timer=None, reg_only=False, transformFile=None, if self.imageToImageTransformMatrix is None: if self.dissectionPhoto is not None: - reference = self.savePath + '/grayscale_input.mgz' + reference = os.path.join(self.savePath, 'grayscale_input.mgz') if self.dissectionPhoto=='left': - moving = self.atlasDir + '/exvivo.template.lh.suptent.nii' + moving = os.path.join(self.atlasDir, 'exvivo.template.lh.suptent.nii') elif self.dissectionPhoto=='right': - moving = self.atlasDir + '/exvivo.template.rh.suptent.nii' + moving = os.path.join(self.atlasDir, 'exvivo.template.rh.suptent.nii') elif self.dissectionPhoto=='both': - moving = self.atlasDir + '/exvivo.template.suptent.nii' + moving = os.path.join(self.atlasDir, 'exvivo.template.suptent.nii') else: sf.system.fatal('dissection photo mode must be left, right, or both') - transformFile = self.savePath + '/atlas2image.lta' + transformFile = os.path.join(self.savePath, 'atlas2image.lta') cmd = 'mri_coreg --seed 2021 --mov ' + moving + ' --ref ' + reference + ' --reg ' + transformFile + \ ' --dof 12 --threads ' + str(self.nthreads) os.system(cmd) @@ -318,7 +318,7 @@ def preProcess(self): else: self.imageBuffers, self.transform, self.voxelSpacing, self.cropping = readCroppedImages( self.imageFileNames, - os.path.join(self.atlasDir, 'template.nii'), + os.path.join(self.atlasDir, 'template.nii.gz'), self.imageToImageTransformMatrix ) @@ -479,11 +479,11 @@ def writeResults(self, biasFields, posteriors): print(self.scalingFactors[contrastNumber], file=fid) else: # photos - self.writeImage(expBiasFields[..., 0], self.savePath + '/illlumination_field.mgz') + self.writeImage(expBiasFields[..., 0], os.path.join(self.savePath, 'illlumination_field.mgz')) original_vol = sf.load_volume(self.originalImageFileNames[0]) - bias_native = sf.load_volume(self.savePath + '/illlumination_field.mgz') + bias_native = sf.load_volume(os.path.join(self.savePath, 'illlumination_field.mgz')) original_vol = original_vol / (1e-6 + bias_native) - original_vol.save(self.savePath + '/illlumination_corrected.mgz') + original_vol.save(os.path.join(self.savePath, 'illlumination_corrected.mgz')) if self.savePosteriors: posteriorPath = os.path.join(self.savePath, 'posteriors') @@ -560,7 +560,7 @@ def saveWarpField(self, filename): # extract geometries source = sf.load_volume(self.imageFileNames[0]).geom - target = sf.load_volume(os.path.join(self.atlasDir, 'template.nii')).geom + target = sf.load_volume(os.path.join(self.atlasDir, 'template.nii.gz')).geom # extract vox-to-vox template transform # TODO: Grabbing the transform from the saved .mat file in either the cross or base From c4f3f8ba5d54488c7626f9da930c11922178997e Mon Sep 17 00:00:00 2001 From: Yujing Huang Date: Tue, 2 Jan 2024 17:18:18 -0500 Subject: [PATCH 4/4] add samseg/cli scripts gems_compute_atlas_probs.py, merge_add_mesh_alphas.py, sbtiv.py --- samseg/cli/gems_compute_atlas_probs.py | 238 +++++++++++++++++++++++++ samseg/cli/merge_add_mesh_alphas.py | 194 ++++++++++++++++++++ samseg/cli/sbtiv.py | 44 +++++ setup.cfg | 5 +- 4 files changed, 480 insertions(+), 1 deletion(-) create mode 100644 samseg/cli/gems_compute_atlas_probs.py create mode 100644 samseg/cli/merge_add_mesh_alphas.py create mode 100755 samseg/cli/sbtiv.py diff --git a/samseg/cli/gems_compute_atlas_probs.py b/samseg/cli/gems_compute_atlas_probs.py new file mode 100644 index 0000000..a6353d2 --- /dev/null +++ b/samseg/cli/gems_compute_atlas_probs.py @@ -0,0 +1,238 @@ +#!/usr/bin/env python + +########################### +# +# Compute mesh node probabilities (alphas) from "ground truth" segmentation images using estimated node deformations. +# Node probabilities are estimated using an expectation maximization (EM) algorithm. +# +# The script requires that SAMSEG has been run with the --history flag on the subjects of interest +# The script works with 1 or more structures. +# For estimating mesh node probabilities for more than one structure +# the --multi-structure flag should be on and a list of labels should be given as input using the flag --labels +# +########################### + +import sys +import os +import numpy as np +import nibabel as nib +import surfa as sf +import time +import argparse +import samseg +from samseg import gems + +def parseArguments(argv): + parser = argparse.ArgumentParser() + parser.add_argument('--subjects-dir', help='Directory with saved SAMSEG runs with --history flag.', required=True) + parser.add_argument('--mesh-collections', nargs='+', help='Mesh collection file(s).', required=True) + parser.add_argument('--out-dir', help='Output directory.', required=True) + parser.add_argument('--segmentations-dir', help='Directory with GT segmentations.') + parser.add_argument('--gt-from-FS', action='store_true', default=False, help='GT from FreeSurfer segmentations.') + parser.add_argument('--segmentation-name', default='aseg.mgz',help='Filename of the segmentations, assumed to be the same for each subject.') + parser.add_argument('--multi-structure', action='store_true', default=False, help="Estimate alphas from more than 1 structure.") + parser.add_argument('--labels', type=int, nargs='+', help="Labels numbers. Needs --multi-structure flag on.") + parser.add_argument('--from-samseg', action='store_true', default=False, help="SAMSEG runs obtained from command samseg instead of run_samseg.") + parser.add_argument('--EM-iterations', type=int, default=10, help="EM iterations.") + parser.add_argument('--show-figs', action='store_true', default=False, help='Show figures during run.') + parser.add_argument('--save-figs', action='store_true', default=False, help='Save rasterized prior of each subject.') + parser.add_argument('--save-average-figs', action='store_true', default=False, help='Save average rasterized prior.') + parser.add_argument('--subjects_file', help='Text file with list of subjects.') + parser.add_argument('--labels_file', help='Text file with list of labels (instead of --labels).') + parser.add_argument('--samseg-subdir', default='samseg',help='Name of samseg subdir in subject/mri folder') + + args = parser.parse_args() + + return args + + +def main(): + args = parseArguments(sys.argv[1:]) + + if not os.path.exists(args.out_dir): + os.makedirs(args.out_dir) + + if args.show_figs: + visualizer = samseg.initVisualizer(True, True) + else: + visualizer = samseg.initVisualizer(False, False) + + if args.save_figs: + import nibabel as nib + + if(args.subjects_file != None): + with open(args.subjects_file) as f: + lines = f.readlines(); + subject_list = []; + for line in lines: + subject_list.append(line.strip()); + else: + subject_list = [pathname for pathname in os.listdir(args.subjects_dir) if os.path.isdir(os.path.join(args.subjects_dir, pathname))] + subject_list.sort() + number_of_subjects = len(subject_list) + + logfile = os.path.join(args.out_dir,'gems_compute_atlas_probs.log'); + logfp = open(logfile,"w") + logfp.write("cd "+os.getcwd()+"\n"); + logfp.write(' '.join(sys.argv)+"\n") + logfp.flush(); + + outsubjectsfile = os.path.join(args.out_dir,'subjects.txt'); + with open(outsubjectsfile,"w") as f: + for subject in subject_list: + f.write(subject+"\n") + + if(args.labels_file == None): + labels = args.labels; + else: + with open(args.labels_file) as f: + lines = f.readlines(); + labels = []; + for line in lines: + for item in line.split(): + labels.append(int(item)) + + outlabelfile = os.path.join(args.out_dir,'labels.txt'); + with open(outlabelfile,"w") as f: + for label in labels: + f.write(str(label)+"\n") + + print("Labels") + print(labels) + + if args.multi_structure: + number_of_classes = len(labels) + 1 # + 1 for background + else: + number_of_classes = 2 # 1 is background + + # We need an init of the probabilistic segmentation class + # to call instance methods + atlas = samseg.ProbabilisticAtlas() + + t0 = time.time(); + for level, mesh_collection_file in enumerate(args.mesh_collections): + + print("Working on mesh collection at level " + str(level + 1)) + + # Read mesh collection + print("Loading mesh collection at: " + str(mesh_collection_file)) + mesh_collection = gems.KvlMeshCollection() + mesh_collection.read(mesh_collection_file) + + # We are interested only on the reference mesh + mesh = mesh_collection.reference_mesh + number_of_nodes = mesh.point_count + + print('Number of subjects: ' + str(len(subject_list))) + + # Define what we are interesting in, i.e., the label statistics of the structure(s) of interest + label_statistics_in_mesh_nodes = np.zeros([number_of_nodes, number_of_classes, number_of_subjects]) + + for subject_number, subject_dir in enumerate(subject_list): + + # Show progress to anyone who's watching + telapsed = (time.time()-t0)/60; + print("====================================================================") + print("") + #print("Subject number: " + str(subject_number + 1)+"/"+str(number_of_subjects) + " "+ subject_dir + " " str(telapsed)) + print("Level %d Subject %d/%d %s %6.1f" % (level+1,subject_number+1,number_of_subjects,subject_dir,telapsed)) + print("") + print("====================================================================") + logfp.write("Level %d Subject %d/%d %s %6.1f\n" % (level+1,subject_number+1,number_of_subjects,subject_dir,telapsed)) + logfp.flush(); + + # Read the manually annotated segmentation for the specific subject + if args.gt_from_FS: + segpath = os.path.join(args.segmentations_dir, subject_dir, 'mri', args.segmentation_name); + print("seg %s" % segpath); + segmentation_image = nib.load(segpath).get_fdata() + affine = nib.load(os.path.join(args.segmentations_dir, subject_dir, 'mri', args.segmentation_name)).affine + else: + segmentation_image = nib.load(os.path.join(args.segmentations_dir, subject_dir, args.segmentation_name)).get_fdata() + affine = nib.load(os.path.join(args.segmentations_dir, subject_dir, args.segmentation_name)).affine + + if args.from_samseg: + history = np.load(os.path.join(args.subjects_dir, subject_dir, 'mri', args.samseg_subdir, 'history.p'), allow_pickle=True) + else: + history = np.load(os.path.join(args.subjects_dir, subject_dir, 'history.p'), allow_pickle=True) + + # Get the node positions in image voxels + model_specifications = history['input']['modelSpecifications'] + transform_matrix = history['transform'] + transform = gems.KvlTransform(samseg.requireNumpyArray(transform_matrix)) + deformations = history['historyWithinEachMultiResolutionLevel'][level]['deformation'] + node_positions = atlas.getMesh( + mesh_collection_file, + transform, + K=model_specifications.K, + initialDeformation=deformations, + initialDeformationMeshCollectionFileName=mesh_collection_file).points + + # The image is cropped as well so the voxel coordinates + # do not exactly match with the original image, + # i.e., there's a shift. Let's undo that. + cropping = history['cropping'] + node_positions += [slc.start for slc in cropping] + + # Estimate n-class alphas representing the segmentation map, initialized with a flat prior + segmentation_map = np.zeros([segmentation_image.shape[0], segmentation_image.shape[1], segmentation_image.shape[2], + number_of_classes], np.uint16) + if args.multi_structure: + for label_number, label in enumerate(labels): + # + 1 here since we want background as first class + segmentation_map[:, :, :, label_number + 1] = (segmentation_image == label) * 65535 + # Make sure to fill what is left in background class + segmentation_map[:, :, :, 0] = 65535 - np.sum(segmentation_map[:, :, :, 1:], axis=3) + else: + segmentation_map[:, :, :, 0] = (1 - segmentation_image) * 65535 + segmentation_map[:, :, :, 1] = segmentation_image * 65535 + + mesh = mesh_collection.reference_mesh + mesh.points = node_positions + mesh.alphas = mesh.fit_alphas(segmentation_map, args.EM_iterations) + + # Show rasterized prior with updated alphas + if args.show_figs: + rasterized_prior = mesh.rasterize(segmentation_image.shape) / 65535 + rasterized_prior = rasterized_prior[:, :, :, 1:] # No need to show background + visualizer.show(images=rasterized_prior) + + # Save rasterized prior with updated alphas + if args.save_figs: + rasterized_prior = mesh.rasterize(segmentation_image.shape) / 65535 + rasterized_prior = rasterized_prior[:, :, :, 1:] # No need to save background + img = nib.Nifti1Image(rasterized_prior, affine) + nib.save(img, os.path.join(args.out_dir, "level_" + str(level + 1) + "_rasterized_prior_sub" + str(subject_number + 1))) + + # Save label statistics of subject + label_statistics_in_mesh_nodes[:, :, subject_number] = mesh.alphas.copy() + + # Show rasterized prior with alphas as mean + if args.show_figs: + mesh.alphas = np.mean(label_statistics_in_mesh_nodes, axis=2) + rasterized_prior = mesh.rasterize(segmentation_image.shape) / 65535 + rasterized_prior = rasterized_prior[:, :, :, 1:] # No need to show background + visualizer.show(images=rasterized_prior) + + # Save rasterized prior with alphas as mean + if args.save_average_figs: + + mesh.alphas = np.mean(label_statistics_in_mesh_nodes, axis=2) + rasterized_prior = mesh.rasterize(segmentation_image.shape) / 65535 + rasterized_prior = rasterized_prior[:, :, :, 1:] # No need to save background + img = nib.Nifti1Image(rasterized_prior, np.eye(4)) + nib.save(img, os.path.join(args.out_dir, "level_" + str(level + 1) + "_average_rasterized_prior")) + + # Save label statistics in a npy file + np.save(os.path.join(args.out_dir, "label_statistics_atlas_" + str(level + 1)), label_statistics_in_mesh_nodes) + + # end level loop + + logfp.write("gems_compute_atlas_probs done\n"); + logfp.close(); + print ("gems_compute_atlas_probs done"); + + + +if __name__ == '__main__': + main() diff --git a/samseg/cli/merge_add_mesh_alphas.py b/samseg/cli/merge_add_mesh_alphas.py new file mode 100644 index 0000000..f33dd6f --- /dev/null +++ b/samseg/cli/merge_add_mesh_alphas.py @@ -0,0 +1,194 @@ +#!/usr/bin/env python + +description = """Merge prior spatial probabilties (alphas) with those of an existing +SAMSEG mesh to create a new mesh/atlas. Should be run on the output of +gems_compute_atlas_probs. + +An example of use is the following: +fspython merge_add_mesh_alphas [...] --merge-labels 2 3 --add-indexes 3 + +Here we merge labels 2 and 3 and then add the third index of +estimated_alphas.npy to the mesh. For example, in case we used +gems_compute_alphas with --labels 2 3 99, we are adding label 99 to +the mesh. Indexes start from 1 since 0 is reserved for the background +class. + +I also add a --merge-indexes flag in case the merge and add indexes +are mixed. If None the first indexes are used for merging. For +example: gems_compute_alphas --labels 2 99 3, I can do --merge-labels +2 3 --merge-indexes 1 3 --add-indexes 2. + +Few things to consider: + +- First merging and then add != first add and then merging. I + implemented the former, but maybe it's better to implement both and + leave the choice to the user. Not sure what the default should be + though. + +- How we add the extra class(es) can be done in different ways. For + one class I did the following (same thing I did for MS lesion): add + the new alphas and rescale everything else with (1-newalpha). For + more than one class, the new alphas are just added and then the + alphas are re-normalized. Again, maybe it's better to expose this to + the user. + +- You still need to manually modify the compressionLookupTable.txt + file right now (i.e., adding extra entries at the end of the + file). Might be worth considering doing this automatically in the + script. + +- Last functionality to add would be to remove some structures from + the mesh, and automatically update the compressionLookupTable.txt + file. + +""" + +import sys +import os +import numpy as np +import argparse +import surfa as sf +from samseg import gems + +from samseg.io import kvlReadCompressionLookupTable + + +eps = np.finfo(float).eps + +description = """ +Merge/add prior spatial probabilities (alphas) with those of an existing +SAMSEG mesh to create a new mesh/atlas. Should be run on the output of +gems_compute_atlas_probs. +""" + +def main(): + parser = argparse.ArgumentParser(description=description) + parser.add_argument('--out-dir', help='Output directory.', required=True) + parser.add_argument('--estimated-alpha-dirs', nargs='+', help='Estimated alphas dirs', required=True) + parser.add_argument('--samseg-atlas-dir', help='Samseg original directory; default in FS/average') + parser.add_argument('--merge-labels', nargs='+', type=int, help='Labels to merge estimated alphas to.') + parser.add_argument('--merge-indexes', nargs='+', type=int, help='Indexes of estimated alphas to merge to current alphas. First index is 1 (0 is background). If None, first indexes are used.') + parser.add_argument('--merge-names', nargs='+', help='Structures name to merge estimated alphas to.') + parser.add_argument('--add-indexes', nargs='+', type=int, help='Indexes of estimated alphas to add to current alphas. First index is 1 (0 is background).') + parser.add_argument('--level', dest='levellist', type=int, nargs='+', help='Atlas level; default is 1 and 2') + args = parser.parse_args() + + # Assuming 20 subjects were used to build the atlas + samseg_subjects = 20 + + # Get the samseg atlas dir + if (args.samseg_atlas_dir != None): + samseg_atlas_dir = args.samseg_atlas_dir + else: + fsh = os.environ.get('FREESURFER_HOME') + samseg_atlas_dir = fsh + "/average/samseg/20Subjects_smoothing2_down2_smoothingForAffine2" + + if args.levellist != None: + levellist = args.levellist + else: + levellist = [1, 2] + + # Make the output dir + os.makedirs(args.out_dir, exist_ok=True) + logfile = os.path.join(args.out_dir, 'merge_add_mesh_alphas.log') + with open(logfile, "w") as f: + f.write("cd " + os.getcwd() + "\n") + f.write(' '.join(sys.argv) + "\n") + f.write(samseg_atlas_dir + "\n") + if args.merge_labels is not None: + outlabelfile = os.path.join(args.out_dir, 'merged_labels.txt') + with open(outlabelfile, "w") as f: + for label in args.merge_labels: + f.write(str(label) + "\n") + + for level in levellist: + print("level = %d =========================================" % level) + + # Read in alphas + estimated_alphas = [] + for adir in args.estimated_alpha_dirs: + afile = os.path.join(adir, 'label_statistics_atlas_%d.npy' % level) + a = np.load(afile) + print("a " + str(a.shape)) + if (len(estimated_alphas) == 0): + estimated_alphas = a + else: + estimated_alphas = np.concatenate((estimated_alphas, a), axis=2) + # endif + # endfor + print(str(estimated_alphas.shape)) + + # Retrieve all the SAMSEG related files + # Here we are making some assumptions about file names + mesh_collection_path = os.path.join(samseg_atlas_dir, "atlas_level" + str(level) + ".txt.gz") + freesurfer_labels, names, colors = kvlReadCompressionLookupTable( + os.path.join(samseg_atlas_dir, 'compressionLookupTable.txt')) + # Get also compressed labels, as they are in the same order as FreeSurfer_labels + compressed_labels = list(np.arange(0, len(freesurfer_labels))) + + # Read mesh collection + print("Loading mesh collection at: " + str(mesh_collection_path)) + mesh_collection = gems.KvlMeshCollection() + mesh_collection.read(mesh_collection_path) + + # Load reference mesh + mesh = mesh_collection.reference_mesh + alphas = mesh.alphas.copy() + + # First merge classes + if args.merge_labels is not None: + for l, merge_label in enumerate(args.merge_labels): + if args.merge_indexes is not None: + idx2 = args.merge_indexes[l] + else: + idx2 = l + 1 + idx = freesurfer_labels.index(merge_label) + alphas[:, idx] = (alphas[:, idx] * samseg_subjects + np.sum(estimated_alphas[:, idx2, :], axis=1)) \ + / (samseg_subjects + estimated_alphas.shape[2]) + if args.merge_names is not None: + for merge_name in args.merge_names: + if args.merge_indexes is not None: + idx2 = args.merge_indexes[l] + else: + idx2 = l + 1 + idx = names.index(merge_name) + alphas[:, idx] = (alphas[:, idx] * samseg_subjects + np.sum(estimated_alphas[:, idx2, :], axis=1)) \ + / (samseg_subjects + estimated_alphas.shape[2]) + # for mergnames + # endif + + # Re-normalize alphas + if args.merge_labels is not None or args.merge_names is not None: + normalizer = np.sum(alphas, axis=1) + eps + alphas = alphas / normalizer[:, None] + + # Here we add classes. This is done after merging classes (if any). + # Note that switching the order of merging and adding produces a different output + if args.add_indexes is not None: + tmp = np.zeros([alphas.shape[0], alphas.shape[1] + len(args.add_indexes)]) + tmp[:, :alphas.shape[1]] = alphas.copy() + if len(args.add_indexes) == 1: + idx = args.add_indexes[0] + # Only one class to add (every other class is lower down by a factor (1-estimated-alpha) + tmp[:, -1] = np.mean(estimated_alphas[:, idx, :], axis=1) + tmp[:, :alphas.shape[1]] *= (1 - tmp[:, -1])[:, None] + else: + # More than one class, add all the estimated alphas and re-normalize + for l, idx in enumerate(args.add_indexes): + tmp[:, alphas.shape[1] + l] = np.mean(estimated_alphas[:, idx, :], axis=1) + normalizer = np.sum(tmp, axis=1) + eps + tmp = tmp / normalizer[:, None] + + alphas = tmp + + # Add alphas in mesh + mesh.alphas = alphas + # Save mesh + mesh_collection.write(os.path.join(args.out_dir, "atlas_level" + str(level) + ".txt")) + # end loop over levels + + print("merge_add_mesh_alphas done") + + +if __name__ == '__main__': + main() diff --git a/samseg/cli/sbtiv.py b/samseg/cli/sbtiv.py new file mode 100755 index 0000000..cafc0a7 --- /dev/null +++ b/samseg/cli/sbtiv.py @@ -0,0 +1,44 @@ +#!/usr/bin/env python + +import argparse +import surfa as sf +from samseg import icv + +description = ''' +Calculates the total intracranial volume of a subject by summing individual volumes computed by samseg. \ +A file containing a list of intracranial labelnames can be provided via the '--labels' flag, but if omitted, \ +a default list is used. Labelnames must be identical to those defined in the samseg atlas. +''' + +def main(): + # parse command line args + parser = argparse.ArgumentParser(description=description) + parser.add_argument('stats', metavar='FILE', help='Volume stats input file.') + parser.add_argument('-o', '--out', metavar='FILE', help='Intracranial stats output file.') + parser.add_argument('-l', '--labels', metavar="FILE", help='File containing a list of intracranial structure labelnames to include in the calculation.') + args = parser.parse_args() + + # read in structure names and volumes from samseg stats + structures = [] + with open(args.input) as fid: + for line in fid.readlines(): + name, vol, _ = line.split(',') + _, _, name = name.split(' ') + structures.append([name.strip(), float(vol)]) + + # read in structure names that are considered intra-cranial + includeStructures = None + if args.map: + with open(args.map) as fid: includeStructures = [line.strip() for line in fid.readlines()] + + # compute intra-cranial volume + sbtiv = icv(structures, includeStructures) + + # write out and exit + print('intracranial volume: %.6f mm^3' % sbtiv) + if args.output: + with open(args.output, 'w') as fid: fid.write('# Measure Intra-Cranial, %.6f, mm^3\n' % sbtiv) + + +if __name__ == '__main__': + main() diff --git a/setup.cfg b/setup.cfg index 954b4ad..85f287a 100644 --- a/setup.cfg +++ b/setup.cfg @@ -33,7 +33,10 @@ console_scripts= computeTissueConcentrations = samseg.cli.computeTissueConcentrations:main prepareAtlasDirectory = samseg.cli.prepareAtlasDirectory:main run_samseg_long = samseg.cli.run_samseg_long:main - segment_subregions = samseg.cl.segment_subregions:main + segment_subregions = samseg.cli.segment_subregions:main + sbtiv = samseg.cli.sbtiv:main + gems_compute_atlas_probs = samseg.cli.gems_compute_atlas_probs:main + merge_add_mesh_alphas = samseg.cli.merge_add_mesh_alphas:main