Skip to content

Commit

Permalink
Merge pull request #6 from jnolan14/dev
Browse files Browse the repository at this point in the history
pybind interface for DSWbeta, Wishart, and Frobenius calculators
  • Loading branch information
ste93ste authored Aug 31, 2023
2 parents 9fbfa25 + dd89913 commit f31228e
Show file tree
Hide file tree
Showing 3 changed files with 244 additions and 8 deletions.
6 changes: 6 additions & 0 deletions samseg/cxx/module.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,12 @@ PYBIND11_MODULE(gemsbindings, m) {
py::arg("K0"),
py::arg("K1"),
py::arg("transform"))

// Set parameters for DTI calculators
.def("SetWishartParams", &KvlCostAndGradientCalculator::SetWishartParams)
.def("SetFrobeniusParams", &KvlCostAndGradientCalculator::SetFrobeniusParams)
.def("SetDSWparams", &KvlCostAndGradientCalculator::SetDSWparams)

.def("evaluate_mesh_position", &KvlCostAndGradientCalculator::EvaluateMeshPosition)
// Aliases to help with profiling
.def("evaluate_mesh_position_a", &KvlCostAndGradientCalculator::EvaluateMeshPosition)
Expand Down
244 changes: 237 additions & 7 deletions samseg/cxx/pyKvlCalculator.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@
#include "kvlMutualInformationCostAndGradientCalculator.h"
#include "kvlAtlasMeshToPointSetCostAndGradientCalculator.h"
#include "kvlAverageAtlasMeshPositionCostAndGradientCalculator.h"
#include "kvlAtlasMeshToWishartGaussMixtureCostAndGradientCalculator.h"
#include "kvlAtlasMeshToFrobeniusGaussMixtureCostAndGradientCalculator.h"
#include "kvlAtlasMeshToDSWbetaGaussMixtureCostAndGradientCalculator.h"

#include "kvlAtlasMeshCollection.h"
#include "pyKvlImage.h"
Expand All @@ -30,7 +33,8 @@ class KvlCostAndGradientCalculator {
py::array_t<int> numberOfGaussiansPerClass=py::array_t<int>(),
py::array_t<double> targetPoints=py::array_t<double>()
){
if (typeName == "AtlasMeshToIntensityImage" || typeName == "AtlasMeshToIntensityImageLogDomain") {
if (typeName == "AtlasMeshToIntensityImage" || typeName == "AtlasMeshToIntensityImageLogDomain" || \
typeName == "DSWbeta" || typeName == "Frobenius" || typeName == "Wishart") {

py::buffer_info means_info = means.request();

Expand Down Expand Up @@ -71,16 +75,50 @@ class KvlCostAndGradientCalculator {
numberOfGaussiansPerClass_converted[ classNumber ] = numberOfGaussiansPerClass.at(classNumber);
}

kvl::AtlasMeshToIntensityImageCostAndGradientCalculator::Pointer myCalculator;
//kvl::AtlasMeshToIntensityImageCostAndGradientCalculator::Pointer myCalculator;
std::vector< ImageType::ConstPointer> images_converted;
for(auto image: images){
ImageType::ConstPointer constImage = static_cast< const ImageType* >( image.m_image.GetPointer() );
images_converted.push_back( constImage );
}
if (typeName == "AtlasMeshToIntensityImage")
{
myCalculator = kvl::AtlasMeshToIntensityImageCostAndGradientCalculator::New();
kvl::AtlasMeshToIntensityImageCostAndGradientCalculator::Pointer myCalculator = kvl::AtlasMeshToIntensityImageCostAndGradientCalculator::New();
myCalculator->SetImages( images_converted );
myCalculator->SetParameters( means_converted, variances_converted, mixtureWeights_converted, numberOfGaussiansPerClass_converted );
calculator = myCalculator;
}

else if (typeName == "DSWbeta")
{

kvl::AtlasMeshToDSWbetaGaussMixtureCostAndGradientCalculator::Pointer myCalculator = kvl::AtlasMeshToDSWbetaGaussMixtureCostAndGradientCalculator::New();
myCalculator->SetImages( images_converted );
myCalculator->SetParameters( means_converted, variances_converted, mixtureWeights_converted, numberOfGaussiansPerClass_converted );
calculator = myCalculator;
}
else if (typeName == "Frobenius")
{
kvl::AtlasMeshToFrobeniusGaussMixtureCostAndGradientCalculator::Pointer myCalculator = kvl::AtlasMeshToFrobeniusGaussMixtureCostAndGradientCalculator::New();
myCalculator->SetImages( images_converted );
myCalculator->SetParameters( means_converted, variances_converted, mixtureWeights_converted, numberOfGaussiansPerClass_converted );
calculator = myCalculator;
}
else if (typeName == "Wishart")
{
kvl::AtlasMeshToWishartGaussMixtureCostAndGradientCalculator::Pointer myCalculator = kvl::AtlasMeshToWishartGaussMixtureCostAndGradientCalculator::New();
myCalculator->SetImages( images_converted );
myCalculator->SetParameters( means_converted, variances_converted, mixtureWeights_converted, numberOfGaussiansPerClass_converted );
calculator = myCalculator;
}
else
{
myCalculator = kvl::AtlasMeshToIntensityImageLogDomainCostAndGradientCalculator::New();
{
kvl::AtlasMeshToIntensityImageLogDomainCostAndGradientCalculator::Pointer myCalculator = kvl::AtlasMeshToIntensityImageLogDomainCostAndGradientCalculator::New();
myCalculator->SetImages( images_converted );
myCalculator->SetParameters( means_converted, variances_converted, mixtureWeights_converted, numberOfGaussiansPerClass_converted );
calculator = myCalculator;
}

/*
std::vector< ImageType::ConstPointer> images_converted;
for(auto image: images){
ImageType::ConstPointer constImage = static_cast< const ImageType* >( image.m_image.GetPointer() );
Expand All @@ -89,7 +127,7 @@ class KvlCostAndGradientCalculator {
myCalculator->SetImages( images_converted );
myCalculator->SetParameters( means_converted, variances_converted, mixtureWeights_converted, numberOfGaussiansPerClass_converted );
calculator = myCalculator;

*/
} else if (typeName == "MutualInformation") {

kvl::MutualInformationCostAndGradientCalculator::Pointer myCalculator = kvl::MutualInformationCostAndGradientCalculator::New();
Expand Down Expand Up @@ -191,6 +229,198 @@ class KvlCostAndGradientCalculator {
py::array_t<double> gradient_np = createNumpyArrayCStyle({numberOfNodes, 3}, data);
return {cost, gradient_np};
};

void SetDSWparams(int numberOfContrasts,
std::vector<KvlImage> DTIimages,
py::array_t< double > DSWbetaMixtureWeights,
py::array_t< int > numberOfDSWbetaePerClass,
double voxratio,
py::array_t< double > DSWbetaAlpha,
py::array_t< double > DSWbetaMeans,
py::array_t< double > DSWbetaBeta,
py::array_t< double > DSWbetaConcentration,
py::array_t< double > logKummerSamples,
double logKummerIncrement)
{
// convert DSWbetaMixWeights
std::vector< double > DSWbetaMixtureWeights_converted;
int DSWbetaWeightsToConvert = DSWbetaMixtureWeights.request().size;
for (int i = 0; i <DSWbetaWeightsToConvert; i++){
DSWbetaMixtureWeights_converted[ i ] = DSWbetaMixtureWeights.at(i);
}

// conver numberOfDSWbetaePerClass
std::vector< int > numberOfDSWbetaePerClass_converted;
int numberOfDSWbetaeToConvert = numberOfDSWbetaePerClass.request().size;
for (int i = 0; i < numberOfDSWbetaeToConvert; i++){
numberOfDSWbetaePerClass_converted[ i ] = numberOfDSWbetaePerClass.at(i);
}

//convet DSWbetaAlpha
std::vector< double > DSWbetaAlpha_converted;
int numDSWbetaAlphaToConvert = DSWbetaAlpha.request().size;
for (int i = 0; i < numDSWbetaAlphaToConvert; i++){
DSWbetaAlpha_converted[ i ] = DSWbetaAlpha.at(i);
}

//convert DSWbetaMeans
py::buffer_info DSWbetaMeans_info = DSWbetaMeans.request();
std::vector< vnl_vector< double >> DSWbetaMeans_converted;
int numDSWMeansToConvert = DSWbetaMeans_info.shape[0];
int meansToConvert = DSWbetaMeans_info.shape[1];
for (int i = 0; i < numDSWMeansToConvert; i++){
vnl_vector< double > betaMean_converted(meansToConvert, 0.0f);
for (int j = 0; j < meansToConvert; j++){
betaMean_converted[j] = DSWbetaMeans.at(i,j);
}
DSWbetaMeans_converted.push_back(betaMean_converted);
}

//convert DSWbetaBeta
std::vector< double > DSWbetaBeta_converted;
int DSWbetaBetaToConvert = DSWbetaBeta.request().size;
for (int i = 0; i < DSWbetaBetaToConvert; i++){
DSWbetaBeta_converted[ i ] = DSWbetaBeta.at(i);
}

// convert DSWbetaConcentration
std::vector< double > DSWbetaConcentration_converted;
int DSWbetaConcentrationToConvert = DSWbetaConcentration.request().size;
for (int i = 0; i < DSWbetaConcentrationToConvert; i++){
DSWbetaConcentration_converted[ i ] = DSWbetaConcentration.at(i);
}

//convert logKummerSamples
std::vector< double > logKummerSamples_converted;
int logKummerSamplesToConvert = logKummerSamples.request().size;
for (int i = 0; i < logKummerSamplesToConvert; i++){
logKummerSamples_converted[ i ] = logKummerSamples.at(i);
}

//convert DTIimages
std::vector< ImageType::ConstPointer> DTIimages_converted;
for(auto image: DTIimages){
ImageType::ConstPointer constImage = static_cast< const ImageType* >( image.m_image.GetPointer() );
DTIimages_converted.push_back( constImage );
}

calculator->SetDiffusionParameters(numberOfContrasts, DSWbetaMixtureWeights_converted, numberOfDSWbetaePerClass_converted, voxratio,\
DSWbetaAlpha_converted, DSWbetaMeans_converted, DSWbetaBeta_converted, DSWbetaConcentration_converted,\
logKummerSamples_converted, logKummerIncrement);
kvl::AtlasMeshToIntensityImageCostAndGradientCalculatorBase::Pointer myCalculator= dynamic_cast< kvl::AtlasMeshToIntensityImageCostAndGradientCalculatorBase*>( calculator.GetPointer() );
myCalculator->SetDiffusionImages(DTIimages_converted);

}

void SetFrobeniusParams(int numberOfContrasts,
std::vector<KvlImage> DTIimages,
py::array_t< double > frobMixtureWeights,
py::array_t< int > numberOfFrobeniusPerClass,
double voxratio,
py::array_t< double > frobVariance,
py::array_t< double > frobMeans)
{
// convert frobMixtureWeights
std::vector< double > frobMixtureWeights_converted;
int frobMixWeightsToConvert = frobMixtureWeights.request().size;
for (int i = 0; i < frobMixWeightsToConvert; i++){
frobMixtureWeights_converted[ i ] = frobMixtureWeights.at(i);
}

//convert numberOfFrobeniusPerClass
std::vector< int > numberOfFrobeniusPerClass_converted;
int numFrobPerClassToConvert = numberOfFrobeniusPerClass.request().size;
for (int i = 0; i < numFrobPerClassToConvert; i++){
numberOfFrobeniusPerClass_converted[ i ] = numberOfFrobeniusPerClass.at(i);
}

//convert frobVariance
std::vector< double > frobVariance_converted;
int frobVarToConvert = frobVariance.request().size;
for (int i = 0; i < frobVarToConvert; i++){
frobVariance_converted[ i ] = frobVariance.at(i);
}

//convert frobMeans
std::vector< vnl_vector < double > > frobMeans_converted;
int frobMeansToConvert = frobMeans.request().size;
for (int i = 0; i < frobMeansToConvert; i++){
vnl_vector< double > frobMean_converted(numFrobPerClassToConvert, 0.0f);
//for (int j = 0; j < frobMeans.at(i).request().size; j++){
for (int j = 0; j < frobMeansToConvert; j++){
frobMean_converted[ j ] = frobMeans.at(i,j);
}
frobMeans_converted.push_back(frobMean_converted);
}

//convert DTIimages
std::vector< ImageType::ConstPointer> DTIimages_converted;
for(auto image: DTIimages){
ImageType::ConstPointer constImage = static_cast< const ImageType* >( image.m_image.GetPointer() );
DTIimages_converted.push_back( constImage );
}

calculator->SetDiffusionParameters(numberOfContrasts, frobMixtureWeights_converted, numberOfFrobeniusPerClass_converted,\
voxratio, frobVariance_converted, frobMeans_converted);
kvl::AtlasMeshToIntensityImageCostAndGradientCalculatorBase::Pointer myCalculator= dynamic_cast< kvl::AtlasMeshToIntensityImageCostAndGradientCalculatorBase*>( calculator.GetPointer() );
myCalculator->SetDiffusionImages(DTIimages_converted);
}

void SetWishartParams(int numberOfContrasts,
std::vector<KvlImage> DTIimages,
py::array_t< double > wmmMixtureWeights,
py::array_t< int > numberOfWishartsPerClass,
double voxratio,
py::array_t< double > degreesOfFreedom,
py::array_t< double > scaleMatrices)
{
//convert wmmMixWeights
std::vector< double > wmmMixtureWeights_converted;
// should this be size_t?
int numMixWeights = wmmMixtureWeights.request().size;
for (int i = 0; i < numMixWeights; i++){
wmmMixtureWeights_converted[ i ] = wmmMixtureWeights.at(i);
}

//convert numWisharts/class
std::vector< int > numberOfWishartsPerClass_converted;
int numWishartToConvert = numberOfWishartsPerClass.request().size;
for (int i = 0; i < numWishartToConvert; i++){
numberOfWishartsPerClass_converted[ i ] = numberOfWishartsPerClass.at(i);
}

//convert degFreedom
std::vector< double > degreesOfFreedom_converted;
int numToConvert = degreesOfFreedom.request().size;
for (int i = 0; i < numToConvert; i++){
degreesOfFreedom_converted[ i ] = degreesOfFreedom.at(i);
}

//convert scaleMatrices
std::vector< vnl_matrix< double >> scaleMatrices_converted;
int numMatToConvert = scaleMatrices.request().size;
// loop through each element of scaleMatrices, each matrix should have dims of wmmMixWeights?
for (int i = 0; i < numMatToConvert; i++){
vnl_matrix< double > matrix_converted(numWishartToConvert, numWishartToConvert);
for (int row = 0; row < numWishartToConvert; row++){
for (int col = 0; col < numWishartToConvert; col++){
matrix_converted[ row ][ col ] = scaleMatrices.at(i, row, col);
}
}
}

//convert DTIimages
std::vector< ImageType::ConstPointer> DTIimages_converted;
for(auto image: DTIimages){
ImageType::ConstPointer constImage = static_cast< const ImageType* >( image.m_image.GetPointer() );
DTIimages_converted.push_back( constImage );
}

kvl::AtlasMeshToIntensityImageCostAndGradientCalculatorBase::Pointer myCalculator= dynamic_cast< kvl::AtlasMeshToIntensityImageCostAndGradientCalculatorBase*>( calculator.GetPointer() );
myCalculator->SetDiffusionImages(DTIimages_converted);
calculator->SetDiffusionParameters(numberOfContrasts, wmmMixtureWeights_converted, numberOfWishartsPerClass_converted, voxratio, degreesOfFreedom_converted, scaleMatrices_converted);

}
};

#endif //GEMS_PYKVLCALCULATOR_H
2 changes: 1 addition & 1 deletion samseg/cxx/pyKvlOptimizer.h
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ class KvlOptimizer {
}
case 'B':
{
kvl::AtlasMeshDeformationLBFGSOptimizer::Pointer myOptimizer
kvl::AtlasMeshDeformationLBFGSOptimizer::Pointer myOptimizer
= dynamic_cast< kvl::AtlasMeshDeformationLBFGSOptimizer* >( optimizer.GetPointer() );
if ( myOptimizer )
{
Expand Down

0 comments on commit f31228e

Please sign in to comment.