
Bottleneck at "utils.py:interpolate_img" with larger datasets #74

Open
bsmarine opened this issue Dec 4, 2020 · 12 comments

@bsmarine

bsmarine commented Dec 4, 2020

Hi,

Thanks for sharing and maintaining batchgenerators; it's been very useful for our work!

We are trying to scale up training to larger datasets but are seeing a considerable bottleneck at interpolate_img in the spatial transform: it takes nearly 11 seconds per 3D image, per channel, per batch sample.

input batch data shape = 8 x 3 x 512 x 512 x 768 (float32)

SpatialTransform(patch_size=[256, 256, 256], patch_center_dist_from_border=(125.0, 125.0),
                 do_elastic_deform=False, alpha=(0.0, 1500.0), sigma=(30.0, 50.0),
                 do_rotation=True, angle_x=(0, 0.0), angle_y=(0, 0.0),
                 angle_z=(0.0, 6.283185307179586), do_scale=True, scale=(0.8, 1.1),
                 random_crop=False)

Profiling suggests this stems from scipy's spline_filter1d function.

Have you or anyone else encountered this with larger datasets? Any suggestions on how to speed this up or work around it?

Thanks,

Brett

@FabianIsensee
Member

Hi Brett,
this is normal and unfortunately there is nothing I can really do about it, because the function you mention is called as part of the scipy.ndimage.map_coordinates function we need for interpolation. I have not experimented with disabling the filter (there should be an option to do so in map_coordinates), so I don't know what the effect on the output would be.
If you need a speedup you can also try setting order_data=0 and order_seg=0, which will do nearest neighbor interpolation.
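A minimal sketch of those two knobs at the scipy level (dummy arrays; order and prefilter are genuine map_coordinates parameters, but whether skipping the prefilter is acceptable for your data is untested here):

import numpy as np
from scipy.ndimage import map_coordinates

img = np.random.rand(64, 64, 64).astype(np.float32)
# identity coordinate grid, shape (3, 64, 64, 64), just to have something runnable
coords = np.mgrid[0:64, 0:64, 0:64].astype(np.float32)

# order=0 -> nearest neighbor, no spline prefilter needed, fastest
out_nn = map_coordinates(img, coords, order=0)

# order=3 with prefilter=False skips spline_filter1d, but the output is then
# a smoothed approximation rather than a true cubic spline interpolation
out_approx = map_coordinates(img, coords, order=3, prefilter=False)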
Best,
Fabian

@bsmarine
Author

bsmarine commented Dec 9, 2020

Hi Fabian,

Thanks so much for the quick response. I took your advice and changed the order for data and seg to 0. Unfortunately this step is still the major bottleneck.

I recently heard of MONAI, a PyTorch library that appears to implement some of these augmentations on the GPU. That may be the way to go, since interpolation for augmentation with datasets this large just isn't cheap.
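For illustration only, a generic sketch of GPU-side affine resampling in plain PyTorch (this is not MONAI's API; the shapes and the identity transform are assumptions):

import torch
import torch.nn.functional as F

x = torch.rand(1, 3, 128, 128, 128, device='cuda')  # b, c, z, y, x dummy volume

# one 3x4 affine matrix per batch element (identity here; a real augmentation
# would put rotation/scaling into the 3x3 part)
theta = torch.eye(3, 4, device='cuda').unsqueeze(0)

grid = F.affine_grid(theta, size=x.shape, align_corners=False)
out = F.grid_sample(x, grid, mode='bilinear', align_corners=False)  # runs on the GPU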

Best,

Brett

@FabianIsensee
Member

Hi Brett,
GPU augmentation is certainly an interesting topic and I will look into it once it is fully implemented. However, it will complicate the data loading a lot, and I will have to create my own implementations because I like to do some things differently than most other people.
Also keep in mind that every second the GPU is not being used for training is potentially wasted. Given that you can get really high-end server CPUs with 64 cores relatively cheaply (compared to equally high-end graphics cards), my bet would be on building servers with more CPU cores per GPU (we will be aiming for 16C/32T per GPU in our next orders).
There are a lot of pitfalls in data augmentation pipelines, and even in MONAI it is not implemented efficiently. batchgenerators is more complicated than MONAI, but it offers far more room for optimization. When using the CPU, I doubt you will get the same throughput with (the current state of) MONAI as you can get with a well-optimized batchgenerators pipeline ;-)
Best,
Fabian

@FabianIsensee
Member

If you want, you can share a standalone dummy script of your data augmentation pipeline and I will have a look at it. Please make sure it is standalone (no funky dependencies) and can be run by itself.

@bsmarine
Author

Hi Fabian,

Would definitely appreciate your input. We are actually adapting MIC-DKFZ's medicaldetectiontoolkit repo and using its data_loader.py on batches of 8 x 3 x 256 x 256 x 256. Each 3x256x256x256 input numpy array is float16 and 97 MB.

https://github.com/bsmarine/BleedDetection/blob/caef73ad87f0d0f27e1a5eb6d9d9fc01e324b4c9/experiments/bleed_exp/data_loader.py#L178

These are our system's CPU specs:

Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Byte Order: Little Endian
CPU(s): 80
On-line CPU(s) list: 0-79
Thread(s) per core: 2
Core(s) per socket: 20
Socket(s): 2
NUMA node(s): 2
Vendor ID: GenuineIntel
CPU family: 6
Model: 79
Model name: Intel(R) Xeon(R) CPU E5-2698 v4 @ 2.20GHz
Stepping: 1
CPU MHz: 3498.769
CPU max MHz: 3600.0000
CPU min MHz: 1200.0000
BogoMIPS: 4390.14
Virtualization: VT-x
L1d cache: 32K
L1i cache: 32K
L2 cache: 256K
L3 cache: 51200K
NUMA node0 CPU(s): 0-19,40-59
NUMA node1 CPU(s): 20-39,60-79

@FabianIsensee
Member

Hi,
can you please provide a standalone script that does not depend on external resources to run? The script should contain only a dummy data loader (representative of your workload) and the data augmentation pipeline that goes with it.
Since you are using the medicaldetectiontoolkit, I will also tag @pfjaeger. Maybe your augmentations are so slow because of the way bounding boxes are handled; maybe he has some insights. I am mostly a segmentation guy :-)
Best,
Fabian

@pfjaeger
Member

Hi,
I would be surprised if bounding boxes were the problem; they are drawn around segmentations only after the spatial transform. Which format are your pixel-wise annotations in? (A toy illustration of both formats follows below.)

  1. A label map with individual ROIs identified by increasing label values.
  2. A binary label map: there is only one foreground class, single lesions are not identified, and all lesions have the same class target (foreground).
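To make the two formats concrete, a toy illustration (hypothetical arrays, not from the repo):

import numpy as np

# Format 1: instance label map, each ROI has its own increasing integer id
seg_instances = np.zeros((64, 64, 64), dtype=np.uint8)
seg_instances[10:15, 10:15, 10:15] = 1   # lesion 1
seg_instances[40:45, 40:45, 40:45] = 2   # lesion 2

# Format 2: binary label map, all lesions collapse to one foreground value
seg_binary = (seg_instances > 0).astype(np.uint8)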

@bsmarine
Author

bsmarine commented Dec 12, 2020

Hi Paul,

Thanks for your help as well.

I'm using the same data_loader that you provide in MDTK for the LIDC dataset (numpy file input), except that our image data has 3 channels. The annotations are pixel-wise numpy binary label maps, uint8.

As I mentioned in the original post, profiling revealed the slowdown to be at the scipy interpolation step of the batchgenerators spatial transform, not ConvertSegToBoundingBoxCoordinates. I profiled in single-threaded mode for simplicity; see the output below the dummy script.

@FabianIsensee, maybe Paul can correct my standalone dummy script of the augmentation below... we actually implement it as a wrapper around an iterator over batches of training data, so this doesn't entirely reproduce the same setup. And sure enough, when I run this standalone script it's fast and does not reproduce the slowdown :(

Maybe the problem lies with the image data I'm using? On visualization it looks unremarkable. Is there something other than the data type I could investigate?

import numpy as np
import cProfile, pstats

from batchgenerators.transforms.spatial_transforms import MirrorTransform as Mirror
from batchgenerators.transforms.abstract_transforms import Compose
from batchgenerators.dataloading.multi_threaded_augmenter import MultiThreadedAugmenter
from batchgenerators.dataloading import SingleThreadedAugmenter
from batchgenerators.transforms.spatial_transforms import SpatialTransform
from batchgenerators.transforms.utility_transforms import ConvertSegToBoundingBoxCoordinates


def augmentation(patient_data):
    my_transforms = []

    mirror_transform = Mirror(axes=np.arange(3))
    my_transforms.append(mirror_transform)

    spatial_transform = SpatialTransform(patch_size=[256, 256, 256], patch_center_dist_from_border=(125.0, 125.0),
                                         do_elastic_deform=False, alpha=(0.0, 1500.0), sigma=(30.0, 50.0),
                                         do_rotation=True, angle_x=(0, 0.0), angle_y=(0, 0.0),
                                         angle_z=(0.0, 6.283185307179586), do_scale=True, scale=(0.8, 1.1),
                                         random_crop=False, order_data=0, order_seg=0)

    my_transforms.append(spatial_transform)


    my_transforms.append(ConvertSegToBoundingBoxCoordinates(3, get_rois_from_seg_flag=False, class_specific_seg_flag=False))
    all_transforms = Compose(my_transforms)

    multithreaded_generator = SingleThreadedAugmenter(patient_data, all_transforms)
    #multithreaded_generator = MultiThreadedAugmenter(patient_data, all_transforms, num_processes=cf.n_workers, seeds=range(cf.n_workers))
    
    return multithreaded_generator

## Dummy Data Creation

# NOTE: ndarray.astype returns a copy, so the result must be assigned back
dumb_img = (np.random.random_sample((3, 256, 256, 256)) - 0.5).astype('float16')
data = [dumb_img for _ in range(8)]

dumb_seg = np.zeros(shape=(256, 256, 256))
dumb_seg[120:135, 120:135, 120:135] = 1
dumb_seg = dumb_seg.astype('uint8')
seg = [dumb_seg for _ in range(8)]

pid = [str(i) for i in range(1, 9)]

patient_data = {'data': data, 'seg': seg, 'pid': pid}  # Data, Seg, PID dictionary

### Run Standalone Script

augmented_data = augmentation(patient_data)
         25308 function calls (24375 primitive calls) in 63.707 seconds

   Ordered by: cumulative time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
      2/1    0.000    0.000   59.316   59.316 {built-in method builtins.next}
        1    0.002    0.002   59.316   59.316 /home/aisinai/medicaldetectiontoolkit/mdt/lib/python3.8/site-packages/batchgenerators-0.20.1-py3.8.egg/batchgenerators/dataloading/single_threaded_augmenter.py:44(__next__)
        1    0.032    0.032   56.767   56.767 /home/aisinai/medicaldetectiontoolkit/mdt/lib/python3.8/site-packages/batchgenerators-0.20.1-py3.8.egg/batchgenerators/transforms/abstract_transforms.py:86(__call__)
        1    0.001    0.001   53.701   53.701 /home/aisinai/medicaldetectiontoolkit/mdt/lib/python3.8/site-packages/batchgenerators-0.20.1-py3.8.egg/batchgenerators/transforms/spatial_transforms.py:331(__call__)
        1    2.994    2.994   53.700   53.700 /home/aisinai/medicaldetectiontoolkit/mdt/lib/python3.8/site-packages/batchgenerators-0.20.1-py3.8.egg/batchgenerators/augmentations/spatial_transformations.py:190(augment_spatial)
       32    0.270    0.008   45.717    1.429 /home/aisinai/medicaldetectiontoolkit/mdt/lib/python3.8/site-packages/batchgenerators-0.20.1-py3.8.egg/batchgenerators/augmentations/utils.py:142(interpolate_img)
       32    0.001    0.000   41.897    1.309 /home/aisinai/medicaldetectiontoolkit/mdt/lib/python3.8/site-packages/scipy-1.5.4-py3.8-linux-x86_64.egg/scipy/ndimage/interpolation.py:308(map_coordinates)
       32   41.893    1.309   41.893    1.309 {built-in method scipy.ndimage._nd_image.geometric_transform}
      130    5.456    0.042    5.456    0.042 {method 'astype' of 'numpy.ndarray' objects}
      524    4.844    0.009    4.844    0.009 {built-in method numpy.array}
  800/471    1.082    0.001    3.492    0.007 {built-in method numpy.core._multiarray_umath.implement_array_function}
        1    0.181    0.181    3.477    3.477 models/mrcnn.py:949(train_forward)

@bsmarine
Author

@FabianIsensee and @pfjaeger

My apologies! The standalone script above is wrong.

The standalone script below does replicate the issue on my machine with dummy data. I'd greatly appreciate it if you could let me know whether you also see the long augmentation runtime. Profiling shows it stems from the scipy map_coordinates function (with both order 0 and order 3); see the profile below as well.

@pfjaeger, is there a time benchmark for one batch that you might be able to share for MDTK's spatial augmentation on 3D or multichannel 3D data? It'd be great to know whether this is an inherent limitation of batchgenerators/scipy before moving on to something like MONAI.

Thank you both for your input!

import numpy as np
import cProfile, pstats

from batchgenerators.transforms.spatial_transforms import MirrorTransform as Mirror
from batchgenerators.transforms.abstract_transforms import Compose
from batchgenerators.dataloading.multi_threaded_augmenter import MultiThreadedAugmenter
from batchgenerators.dataloading import SingleThreadedAugmenter
from batchgenerators.transforms.spatial_transforms import SpatialTransform
from batchgenerators.transforms.utility_transforms import ConvertSegToBoundingBoxCoordinates


def augmentation(patient_data):
    my_transforms = []

    mirror_transform = Mirror(axes=np.arange(3))
    my_transforms.append(mirror_transform)

    spatial_transform = SpatialTransform(patch_size=[256, 256, 256], patch_center_dist_from_border=(125.0, 125.0),
                                         do_elastic_deform=False, alpha=(0.0, 1500.0), sigma=(30.0, 50.0),
                                         do_rotation=True, angle_x=(0, 0.0), angle_y=(0, 0.0),
                                         angle_z=(0.0, 6.283185307179586), do_scale=True, scale=(0.8, 1.1),
                                         random_crop=False, order_data=2, order_seg=2)

    my_transforms.append(spatial_transform)


    my_transforms.append(ConvertSegToBoundingBoxCoordinates(3, get_rois_from_seg_flag=False, class_specific_seg_flag=False))
    all_transforms = Compose(my_transforms)

    multithreaded_generator = SingleThreadedAugmenter(patient_data, all_transforms)
    #multithreaded_generator = MultiThreadedAugmenter(patient_data, all_transforms, num_processes=cf.n_workers, seeds=range(cf.n_workers))
    
    return multithreaded_generator

## Dummy Data Creation

# NOTE: ndarray.astype returns a copy, so the result must be assigned back
dumb_img = (np.random.random_sample((3, 256, 256, 256)) - 0.5).astype('float16')
data = [dumb_img for _ in range(8)]

dumb_seg = np.zeros(shape=(1, 256, 256, 256))
dumb_seg[0][120:135, 120:135, 120:135] = 1
dumb_seg = dumb_seg.astype('uint8')
seg = [dumb_seg for _ in range(8)]

class_target = [[1] for _ in range(8)]

batch_ids = [['1'], ['2'], ['3'], ['4'], ['5'], ['6'], ['7'], ['8']]

data = np.array(data)
seg = np.array(seg)
class_target = np.array(class_target)
print(data.shape, seg.shape, class_target.shape, class_target)

batches = list()

batch_one = {'data':data,'seg':seg,'pid':batch_ids,'class_target':class_target} #Data, Seg, PID dictionary

batches.append(batch_one)

batches_i = iter(batches)

### Run and Profile Standalone Script

profiler = cProfile.Profile()
profiler.enable()

augmented_data = augmentation(batches_i)

result = next(augmented_data)

profiler.disable()
stats = pstats.Stats(profiler).sort_stats('cumtime')
stats.print_stats()
         5898 function calls (5841 primitive calls) in 167.779 seconds

   Ordered by: cumulative time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
      2/1    0.000    0.000  167.778  167.778 {built-in method builtins.next}
        1    0.000    0.000  167.778  167.778 /home/aisinai/medicaldetectiontoolkit/mdt/lib/python3.8/site-packages/batchgenerators-0.20.1-py3.8.egg/batchgenerators/dataloading/single_threaded_augmenter.py:44(__next__)
        1    0.019    0.019  167.778  167.778 /home/aisinai/medicaldetectiontoolkit/mdt/lib/python3.8/site-packages/batchgenerators-0.20.1-py3.8.egg/batchgenerators/transforms/abstract_transforms.py:86(__call__)
        1    0.001    0.001  160.102  160.102 /home/aisinai/medicaldetectiontoolkit/mdt/lib/python3.8/site-packages/batchgenerators-0.20.1-py3.8.egg/batchgenerators/transforms/spatial_transforms.py:331(__call__)
        1    2.729    2.729  160.101  160.101 /home/aisinai/medicaldetectiontoolkit/mdt/lib/python3.8/site-packages/batchgenerators-0.20.1-py3.8.egg/batchgenerators/augmentations/spatial_transformations.py:190(augment_spatial)
       32    1.474    0.046  151.464    4.733 /home/aisinai/medicaldetectiontoolkit/mdt/lib/python3.8/site-packages/batchgenerators-0.20.1-py3.8.egg/batchgenerators/augmentations/utils.py:142(interpolate_img)
       40    0.001    0.000  143.406    3.585 /home/aisinai/medicaldetectiontoolkit/mdt/lib/python3.8/site-packages/scipy-1.5.4-py3.8-linux-x86_64.egg/scipy/ndimage/interpolation.py:308(map_coordinates)
       40  118.755    2.969  118.755    2.969 {built-in method scipy.ndimage._nd_image.geometric_transform}
       40    0.001    0.000   24.644    0.616 /home/aisinai/medicaldetectiontoolkit/mdt/lib/python3.8/site-packages/scipy-1.5.4-py3.8-linux-x86_64.egg/scipy/ndimage/interpolation.py:133(spline_filter)
      120    0.002    0.000   24.641    0.205 /home/aisinai/medicaldetectiontoolkit/mdt/lib/python3.8/site-packages/scipy-1.5.4-py3.8-linux-x86_64.egg/scipy/ndimage/interpolation.py:55(spline_filter1d)
      120   24.637    0.205   24.637    0.205 {built-in method scipy.ndimage._nd_image.spline_filter1d}
  449/393    1.330    0.003    6.556    0.017 {built-in method numpy.core._multiarray_umath.implement_array_function}
        1    0.000    0.000    4.794    4.794 /home/aisinai/medicaldetectiontoolkit/mdt/lib/python3.8/site-packages/batchgenerators-0.20.1-py3.8.egg/batchgenerators/transforms/spatial_transforms.py:203(__call__)

@FabianIsensee
Member

Hi @bsmarine,
I finally found the time to look into this today. From my perspective, everything looks fine. I could replicate the long run time you reported (179s for me), but that is completely normal for this input size. Unfortunately, speeding this up is beyond our control: we merely generate a coordinate grid that is scaled and rotated, then let scipy do the interpolation. That interpolation is implemented in C in scipy's backend (built-in method scipy.ndimage._nd_image.geometric_transform), and I would presume that they know what they are doing coding-wise :-)
Here are a few things you can consider to speed up the calculations:

  1. Reduce the order of interpolation. Setting it to 0 gives nearest neighbor, which is a lot faster; 1 is linear. Run times for one batch, in seconds:

order 0: 28
order 1: 49
order 2: 179
order 3: 253

  2. What is the patch size your model is actually trained with? If it is not 256x256x256, make sure to tell SpatialTransform the actual final patch size. If I replace 256x256x256 with 128x128x128 in the SpatialTransform, the run time drops from 179s to 96s. Note that the output size is then of course 128, not 256.

  3. SpatialTransform has parameters p_rot_per_sample and p_scale_per_sample, which default to 1, meaning these augmentations are applied to every patch. I have confirmed experimentally (for segmentation) that this is not necessarily ideal: you want diversity on the one hand, but at the same time you don't want to mess with the data distribution too much. I would therefore recommend setting these to lower values; 0.3 works well for me. With 0.3 for both, only 1 - (1 - 0.3) * (1 - 0.3) = 51% of patches receive any spatial augmentation, which would roughly cut your CPU time in half. You can even go lower than that.

  4. I presume you are doing this already, but multithreaded augmentation really goes a long way. Use as many CPUs for it as you can. A sketch combining these suggestions follows below.
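
A sketch pulling suggestions 1-4 together (the parameter names are the ones discussed in this thread; the concrete values are illustrative only, and my_loader is a hypothetical batch generator):

import numpy as np
from batchgenerators.transforms.spatial_transforms import SpatialTransform
from batchgenerators.dataloading.multi_threaded_augmenter import MultiThreadedAugmenter

spatial_transform = SpatialTransform(
    patch_size=[128, 128, 128],   # (2) the final patch size the network actually sees
    do_elastic_deform=False,
    do_rotation=True, angle_z=(0.0, 2 * np.pi),
    do_scale=True, scale=(0.8, 1.1),
    random_crop=False,
    order_data=1, order_seg=1,    # (1) linear instead of cubic interpolation
    p_rot_per_sample=0.3,         # (3) only ~51% of patches get any
    p_scale_per_sample=0.3)       #     spatial augmentation at all

# (4) spread the work across worker processes; my_loader is a placeholder
# for your own batch generator
# augmenter = MultiThreadedAugmenter(my_loader, spatial_transform,
#                                    num_processes=16, seeds=list(range(16)))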

Best,
Fabian

@bsmarine
Author

bsmarine commented Dec 24, 2020 via email

@FabianIsensee
Member

Hi,
the complexity of SpatialTransform should only be related to the output size, so you could initially crop a larger patch from your volumes and then let SpatialTransform reduce that to 256^3. That way you prevent border artifacts from appearing; see the sketch below.
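A rough sketch of that crop-then-transform idea, with shapes scaled down for illustration (the crop here is plain numpy slicing; the **data_dict call style is how batchgenerators transforms are invoked):

import numpy as np
from batchgenerators.transforms.spatial_transforms import SpatialTransform

volume = np.random.rand(1, 1, 192, 192, 192).astype(np.float32)  # b, c, x, y, z

# 1) first crop generously, larger than the final patch (160^3 here; at full
#    scale this would be something larger than 256^3)
big = volume[:, :, 16:176, 16:176, 16:176]

# 2) SpatialTransform then shrinks it to the final patch size while rotating
#    and scaling, so border regions are filled from real voxels, not padding
tr = SpatialTransform(patch_size=[128, 128, 128], do_rotation=True,
                      angle_z=(0.0, 2 * np.pi), do_scale=True,
                      scale=(0.8, 1.1), random_crop=False)
out = tr(data=big)['data']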
I was a bit surprised by the COVID segmentation results. We did not spend a lot of time on the competition, so it is nice to see that what we did still kind of worked. But it was very hard to find anything that would work better than the original nnU-Net :-D
Best,
Fabian
