A collection of pre-trained medical image models in PyTorch. This repository aims to provide a unified and easy-to-use interface for comparing and deploying these models.
- STU-Net (`STU-Net-S`, `STU-Net-B`, `STU-Net-L`, `STU-Net-H`) pre-trained on TotalSegmentator, CT-ORG, FeTA21, and BraTS21 (more datasets are WIP).
- SAM-Med3D (`SAM-Med3D`) pre-trained on SA-Med3D-140K.
- Other pre-trained medical image models are WIP. (You can request support for your model in Issues.)
You can install this toolkit via pip:

```shell
pip install medim
```
For developers, you can install it in editable mode via:

```shell
git clone https://github.com/uni-medical/MedIM.git
cd MedIM
pip install -e .
```
First, import `medim`:

```python
import medim
```

You have four ways to create a PyTorch-compatible model with `create_model`:
1. Use a model without pre-trained weights:

   ```python
   model = medim.create_model("STU-Net-S")
   ```
2. Use a local checkpoint:

   ```python
   model = medim.create_model(
       "STU-Net-S",
       pretrained=True,
       checkpoint_path="../tests/data/small_ep4k.model")
   ```
3. Use a checkpoint pre-trained on a validated dataset (it will be downloaded automatically from HuggingFace):

   ```python
   model = medim.create_model("STU-Net-B", dataset="BraTS21")
   ```
4. Use a HuggingFace URL (the checkpoint will be downloaded automatically from HuggingFace):

   ```python
   model = medim.create_model(
       "STU-Net-S",
       pretrained=True,
       checkpoint_path="https://huggingface.co/ziyanhuang/STU-Net/blob/main/small_ep4k.model")
   ```
Tip: you can set the `MEDIM_CKPT_DIR` environment variable to choose a custom directory for checkpoints that medim downloads from HuggingFace.
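As a minimal sketch, the environment variable can also be set from Python before any download is triggered (the cache path below is an arbitrary example, not a recommended location):

```python
import os

# Point medim's checkpoint downloads at a custom cache directory.
# The path itself is only an illustrative example.
os.environ["MEDIM_CKPT_DIR"] = os.path.expanduser("~/.cache/medim_ckpts")

# Any later create_model(..., pretrained=True) call that fetches weights
# from HuggingFace should now store them under this directory.
print(os.environ["MEDIM_CKPT_DIR"])
```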
Then, you can use the model as you like:

```python
import torch

input_tensor = torch.randn(1, 1, 128, 128, 128)
output_tensor = model(input_tensor)
print("Output tensor shape:", output_tensor.shape)
```
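For segmentation models, the output is typically a tensor of per-class logits with shape `(N, C, D, H, W)`, and a common post-processing step is an argmax over the channel dimension to obtain an integer label map. The sketch below uses a stand-in `nn.Conv3d` in place of a real medim model (the class count and output layout are assumptions for illustration):

```python
import torch
import torch.nn as nn

# Stand-in for a 3D segmentation model: maps 1 input channel to 3
# per-class logit channels. Replace with medim.create_model(...).
model = nn.Conv3d(in_channels=1, out_channels=3, kernel_size=3, padding=1)

input_tensor = torch.randn(1, 1, 32, 32, 32)

with torch.no_grad():                # no gradients needed for inference
    logits = model(input_tensor)     # shape (1, 3, 32, 32, 32)
    labels = logits.argmax(dim=1)    # shape (1, 32, 32, 32), integer labels

print("Logits shape:", tuple(logits.shape))
print("Label map shape:", tuple(labels.shape))
```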
More examples are in the `examples` directory.
- Support more pre-training of STU-Net on different datasets.
- Support more pre-trained medical image models.
- An easy-to-use interface compatible with MONAI/nnU-Net is still under development. Once developed, you will be able to deploy medical image models more elegantly within the Python/PyTorch ecosystem.