Improved 2D preprocessing script #287

Open · wants to merge 2 commits into base: main
2 changes: 1 addition & 1 deletion MedSAM_Inference.py
@@ -101,7 +101,7 @@ def medsam_inference(medsam_model, img_embed, box_1024, H, W):
"-chk",
"--checkpoint",
type=str,
default="work_dir/MedSAM/medsam_vit_b.pth",
default="medsam_vit_b.pth",
help="path to the trained model",
)
args = parser.parse_args()
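With this change the demo looks for the checkpoint next to the script rather than under work_dir/MedSAM/. A usage sketch, assuming the remaining flags keep the script's existing defaults (the explicit -chk is redundant here and shown only for clarity):

python MedSAM_Inference.py -chk medsam_vit_b.pth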
58 changes: 39 additions & 19 deletions train_one_gpu.py
@@ -61,8 +61,10 @@ def __init__(self, data_root, bbox_shift=20):
         self.gt_path = join(data_root, "gts")
         self.img_path = join(data_root, "imgs")
         self.gt_path_files = sorted(
-            glob.glob(join(self.gt_path, "**/*.npy"), recursive=True)
+            glob.glob(join(self.gt_path, "**/*.npy.npz"), recursive=True)
         )
+        print(f"number of images: {len(self.gt_path_files)}")
+
         self.gt_path_files = [
             file
             for file in self.gt_path_files
@@ -80,31 +82,48 @@ def __getitem__(self, index):
         img_1024 = np.load(
             join(self.img_path, img_name), "r", allow_pickle=True
         ) # (1024, 1024, 3)
+
+        # Access the array using the default key 'arr_0'
+        img_1024 = img_1024['arr_0']/255.0 # Use the key 'arr_0' here
+
         # convert the shape to (3, H, W)
         img_1024 = np.transpose(img_1024, (2, 0, 1))
         assert (
             np.max(img_1024) <= 1.0 and np.min(img_1024) >= 0.0
         ), "image should be normalized to [0, 1]"
         gt = np.load(
             self.gt_path_files[index], "r", allow_pickle=True
-        ) # multiple labels [0, 1,4,5...], (256,256)
+        )['arr_0'] # Use the key 'arr_0' here # multiple labels [0, 1,4,5...], (256,256)
         assert img_name == os.path.basename(self.gt_path_files[index]), (
             "img gt name error" + self.gt_path_files[index] + self.npy_files[index]
         )
         label_ids = np.unique(gt)[1:]
-        gt2D = np.uint8(
-            gt == random.choice(label_ids.tolist())
-        ) # only one label, (256, 256)
-        assert np.max(gt2D) == 1 and np.min(gt2D) == 0.0, "ground truth should be 0, 1"
-        y_indices, x_indices = np.where(gt2D > 0)
-        x_min, x_max = np.min(x_indices), np.max(x_indices)
-        y_min, y_max = np.min(y_indices), np.max(y_indices)
-        # add perturbation to bounding box coordinates
-        H, W = gt2D.shape
-        x_min = max(0, x_min - random.randint(0, self.bbox_shift))
-        x_max = min(W, x_max + random.randint(0, self.bbox_shift))
-        y_min = max(0, y_min - random.randint(0, self.bbox_shift))
-        y_max = min(H, y_max + random.randint(0, self.bbox_shift))
+        if label_ids.size > 0:
+            gt2D = np.uint8(gt == random.choice(label_ids.tolist())) # Choose a random label
+        else:
+            gt2D = np.zeros_like(gt) # If no labels, create an empty mask
+            # print("Warning: No labels found other than background. Returning an empty mask.")
+
+        # Check and handle if gt2D is empty
+        if np.max(gt2D) == 0: # Means gt2D contains no positive labels
+            y_indices, x_indices = np.array([]), np.array([])
+        else:
+            y_indices, x_indices = np.where(gt2D > 0)
+
+        # If no indices are found, set bounding box to zero-size at top left corner
+        if y_indices.size == 0 or x_indices.size == 0:
+            x_min, x_max, y_min, y_max = 0, 0, 0, 0
+        else:
+            x_min, x_max = np.min(x_indices), np.max(x_indices)
+            y_min, y_max = np.min(y_indices), np.max(y_indices)
+            # Add perturbation to bounding box coordinates
+            H, W = gt2D.shape
+            x_min = max(0, x_min - random.randint(0, self.bbox_shift))
+            x_max = min(W, x_max + random.randint(0, self.bbox_shift))
+            y_min = max(0, y_min - random.randint(0, self.bbox_shift))
+            y_max = min(H, y_max + random.randint(0, self.bbox_shift))
+
+
         bboxes = np.array([x_min, y_min, x_max, y_max])
         return (
             torch.tensor(img_1024).float(),
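As a standalone illustration of the fallback added above, a toy version of the box logic (bbox_from_mask is a hypothetical helper, not part of the diff): an all-background mask now yields a zero-size box at the origin instead of raising on an empty label list.

import random
import numpy as np

def bbox_from_mask(gt2D, bbox_shift=20):
    # mirrors the branch structure added in __getitem__ above
    y_indices, x_indices = np.where(gt2D > 0)
    if y_indices.size == 0 or x_indices.size == 0:
        return np.array([0, 0, 0, 0])  # zero-size box at the top-left corner
    H, W = gt2D.shape
    x_min, x_max = np.min(x_indices), np.max(x_indices)
    y_min, y_max = np.min(y_indices), np.max(y_indices)
    # jitter each side by up to bbox_shift pixels, clipped to the image bounds
    x_min = max(0, x_min - random.randint(0, bbox_shift))
    x_max = min(W, x_max + random.randint(0, bbox_shift))
    y_min = max(0, y_min - random.randint(0, bbox_shift))
    y_max = min(H, y_max + random.randint(0, bbox_shift))
    return np.array([x_min, y_min, x_max, y_max])

empty = np.zeros((256, 256), dtype=np.uint8)
print(bbox_from_mask(empty))  # [0 0 0 0]
mask = empty.copy()
mask[100:150, 80:120] = 1
print(bbox_from_mask(mask))   # a jittered tight box, e.g. [64 92 134 161]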
@@ -115,7 +134,7 @@ def __getitem__(self, index):


 # %% sanity test of dataset class
-tr_dataset = NpyDataset("data/npy/CT_Abd")
+tr_dataset = NpyDataset("/app/data/medsam_practice/npy/Ultrasound_femoralTriangle")
 tr_dataloader = DataLoader(tr_dataset, batch_size=8, shuffle=True)
 for step, (image, gt, bboxes, names_temp) in enumerate(tr_dataloader):
     print(image.shape, gt.shape, bboxes.shape)
@@ -147,7 +166,7 @@ def __getitem__(self, index):
"-i",
"--tr_npy_path",
type=str,
default="data/npy/CT_Abd",
default="/app/data/medsam_practice/npy/Ultrasound_femoralTriangle",
help="path to training npy files; two subfolders: gts and imgs",
)
parser.add_argument("-task_name", type=str, default="MedSAM-ViT-B")
Expand All @@ -162,9 +181,9 @@ def __getitem__(self, index):
parser.add_argument("-pretrain_model_path", type=str, default="")
parser.add_argument("-work_dir", type=str, default="./work_dir")
# train
parser.add_argument("-num_epochs", type=int, default=1000)
parser.add_argument("-num_epochs", type=int, default=5)
parser.add_argument("-batch_size", type=int, default=2)
parser.add_argument("-num_workers", type=int, default=0)
parser.add_argument("-num_workers", type=int, default=4)
# Optimizer parameters
parser.add_argument(
"-weight_decay", type=float, default=0.01, help="weight decay (default: 0.01)"
Expand Down Expand Up @@ -298,6 +317,7 @@ def main():
         shuffle=True,
         num_workers=args.num_workers,
         pin_memory=True,
+        persistent_workers=True
     )
 
     start_epoch = 0
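Two notes for context on the loader changes above. First, np.savez_compressed appends .npz to any filename that does not already end in it, so arrays saved under *.npy names by the new preprocessing script land on disk as *.npy.npz (which is what the updated glob matches), and an unnamed array is stored under the default key 'arr_0'. A minimal round-trip sketch; example.npy is a hypothetical filename:

import numpy as np

img = (np.random.rand(8, 8, 3) * 255).astype(np.uint8)  # dummy stand-in image
np.savez_compressed("example.npy", img)  # written to disk as example.npy.npz

loaded = np.load("example.npy.npz", allow_pickle=True)
print(loaded.files)               # ['arr_0'], the default key for an unnamed array
img_01 = loaded["arr_0"] / 255.0  # scaled to [0, 1], as __getitem__ expects

Second, persistent_workers=True keeps the num_workers=4 loader processes alive across epochs rather than re-creating them every epoch; PyTorch rejects this flag when num_workers is 0, so it pairs with the new default of 4.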
171 changes: 36 additions & 135 deletions tutorial_quickstart.ipynb

Large diffs are not rendered by default.

59 changes: 59 additions & 0 deletions utils/faster_pre_grey_rgb.py
@@ -0,0 +1,59 @@
+import numpy as np
+import os
+from skimage import io, transform
+from tqdm import tqdm
+import multiprocessing as mp
+
+# Function to process each image and mask
+def process_image(name):
+    img_name_suffix = '.PNG'
+    gt_name_suffix = '.png'
+    prefix = modality + '_' + anatomy + '_'
+    npy_save_name = prefix + name.split(gt_name_suffix)[0] + '.npy'
+    gt_data_ori = np.uint8(io.imread(os.path.join(gt_path, name)))
+
+    for remove_label_id in remove_label_ids:
+        gt_data_ori[gt_data_ori == remove_label_id] = 0
+
+    image_name = name.split(gt_name_suffix)[0] + img_name_suffix
+    image_data = io.imread(os.path.join(img_path, image_name))
+    if np.max(image_data) > 255.0:
+        image_data = np.uint8((image_data - image_data.min()) / (np.max(image_data) - image_data.min()) * 255.0)
+    if len(image_data.shape) == 2:
+        image_data = np.repeat(np.expand_dims(image_data, -1), 3, -1)
+
+    if do_intensity_cutoff:
+        lower_bound, upper_bound = np.percentile(image_data[image_data > 0], 0.5), np.percentile(image_data[image_data > 0], 99.5)
+        image_data = np.clip(image_data, lower_bound, upper_bound)
+        image_data = (image_data - image_data.min()) / (image_data.max() - image_data.min()) * 255.0
+        image_data[image_data == 0] = 0
+
+    resize_img = transform.resize(image_data, (image_size, image_size), order=3, mode='constant', preserve_range=True, anti_aliasing=True)
+    resize_gt = transform.resize(gt_data_ori, (image_size, image_size), order=0, mode='constant', preserve_range=True, anti_aliasing=False)
+
+    # reduce image size and save compressed npy
+    np.savez_compressed(os.path.join(npy_path, "imgs", npy_save_name), resize_img.astype(np.uint8))
+    np.savez_compressed(os.path.join(npy_path, "gts", npy_save_name), resize_gt.astype(np.uint8))
+
+# Main script
+if __name__ == '__main__':
+    modality = 'Ultrasound'
+    anatomy = 'femoralTriangle'
+    image_size = 1024
+    img_path = '/app/data/medsam_practice/images'
+    gt_path = '/app/data/medsam_practice/labels'
+    npy_path = '/app/data/medsam_practice/npy/' + modality + '_' + anatomy
+    os.makedirs(os.path.join(npy_path, "gts"), exist_ok=True)
+    os.makedirs(os.path.join(npy_path, "imgs"), exist_ok=True)
+    names = sorted(os.listdir(gt_path))
+    remove_label_ids = []
+    do_intensity_cutoff = False
+
+    # Create a pool of processes. Number of processes is set to the number of CPUs available.
+    pool = mp.Pool(mp.cpu_count())
+
+    # Process each file in parallel
+    list(tqdm(pool.imap(process_image, names), total=len(names)))
+
+    pool.close()
+    pool.join()
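A quick sanity check after a run, as a sketch assuming the default paths above and at least one processed pair on disk:

import os
import numpy as np

npy_path = '/app/data/medsam_practice/npy/Ultrasound_femoralTriangle'
name = sorted(os.listdir(os.path.join(npy_path, 'imgs')))[0]  # e.g. Ultrasound_femoralTriangle_<case>.npy.npz
img = np.load(os.path.join(npy_path, 'imgs', name))['arr_0']
gt = np.load(os.path.join(npy_path, 'gts', name))['arr_0']
print(name, img.shape, img.dtype)  # expect (1024, 1024, 3) uint8
print(gt.shape, np.unique(gt))     # expect (1024, 1024) with background 0 plus label ids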