diff --git a/HD_BET/hd-bet b/HD_BET/hd-bet index e0bf563..11faaab 100755 --- a/HD_BET/hd-bet +++ b/HD_BET/hd-bet @@ -1,11 +1,13 @@ #!/usr/bin/env python import os +import multiprocessing from HD_BET.run import run_hd_bet from HD_BET.utils import maybe_mkdir_p, subfiles import HD_BET + if __name__ == "__main__": print("\n########################") print("If you are using hd-bet, please cite the following paper:") @@ -33,6 +35,10 @@ if __name__ == "__main__": '\'cpu\' to run on CPU. When using CPU you should ' 'consider disabling tta. Default for -device is: 0', required=False) + parser.add_argument('-threads', default=0, type=int, help='used to set the number of cpu threads. ' + 'Must be either int or str. Use 0 for max available cpus.' + 'Default for -threads is: 0', + required=False) parser.add_argument('-tta', default=1, required=False, type=int, help='whether to use test time data augmentation ' '(mirroring). 1= True, 0=False. Disable this ' 'if you are using CPU to speed things up! ' @@ -58,6 +64,7 @@ if __name__ == "__main__": mode = args.mode device = args.device + threads = args.threads tta = args.tta pp = args.pp save_mask = args.save_mask @@ -90,6 +97,13 @@ if __name__ == "__main__": output_files = [output_file_or_dir] input_files = [input_file_or_dir] + max_cpu_count = multiprocessing.cpu_count() + if threads == 0: + threads = max_cpu_count + elif threads not in range(1, max_cpu_count + 1): + raise ValueError(f"Unknown value for threads: {threads}. Expected: value between 1 and maximum number of available threads ({max_cpu_count}) \ + Tip: A value of 0 will pick the maximum available number.") + if tta == 0: tta = False elif tta == 1: @@ -118,4 +132,4 @@ if __name__ == "__main__": else: raise ValueError("Unknown value for pp: %s. Expected: 0 or 1" % str(pp)) - run_hd_bet(input_files, output_files, mode, config_file, device, pp, tta, save_mask, overwrite_existing) + run_hd_bet(input_files, output_files, mode, config_file, device, threads, pp, tta, save_mask, overwrite_existing) diff --git a/HD_BET/run.py b/HD_BET/run.py index 858934d..f2464e8 100755 --- a/HD_BET/run.py +++ b/HD_BET/run.py @@ -20,7 +20,7 @@ def apply_bet(img, bet, out_fname): def run_hd_bet(mri_fnames, output_fnames, mode="accurate", config_file=os.path.join(HD_BET.__path__[0], "config.py"), device=0, - postprocess=False, do_tta=True, keep_mask=True, overwrite=True): + threads=0, postprocess=False, do_tta=True, keep_mask=True, overwrite=True): """ :param mri_fnames: str or list/tuple of str @@ -35,6 +35,8 @@ def run_hd_bet(mri_fnames, output_fnames, mode="accurate", config_file=os.path.j :return: """ + torch.set_num_threads(threads) + list_of_param_files = [] if mode == 'fast':