Skip to content

Commit

Permalink
Refactor common functions for data processing
Browse files Browse the repository at this point in the history
  • Loading branch information
milank94 committed May 30, 2024
1 parent b206678 commit 306fff2
Showing 1 changed file with 31 additions and 22 deletions.
53 changes: 31 additions & 22 deletions benchmark/models/yolo_v5/utils/common.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
# SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC
# SPDX-License-Identifier: Apache-2.0

from pathlib import Path
Expand Down Expand Up @@ -31,27 +31,36 @@ def data_preprocessing(ims: Image.Image, size: tuple) -> tuple:
List of images, number of images, filenames, original image shapes, inference shape, preprocessed image tensor
"""

n, ims = (len(ims), list(ims)) if isinstance(ims, (list, tuple)) else (1, [ims]) # number, list of images
shape0, shape1, files = [], [], [] # image and inference shapes, filenames

for i, im in enumerate(ims):
f = f"image{i}" # filename
im, f = np.asarray(exif_transpose(im)), getattr(im, "filename", f) or f
files.append(Path(f).with_suffix(".jpg").name)
if im.shape[0] < 5: # image in CHW
im = im.transpose((1, 2, 0)) # reverse dataloader .transpose(2, 0, 1)
im = im[..., :3] if im.ndim == 3 else cv2.cvtColor(im, cv2.COLOR_GRAY2BGR) # enforce 3ch input
s = im.shape[:2] # HWC
shape0.append(s) # image shape
g = max(size) / max(s) # gain
shape1.append([int(y * g) for y in s])
ims[i] = im if im.data.contiguous else np.ascontiguousarray(im) # update
shape1 = [size[0] for _ in np.array(shape1).max(0)] # inf shape
x = [letterbox(im, shape1, auto=False)[0] for im in ims] # pad
x = np.ascontiguousarray(np.array(x).transpose((0, 3, 1, 2))) # stack and BHWC to BCHW
x = torch.from_numpy(x) / 255 # uint8 to fp16/32

return ims, n, files, shape0, shape1, x
if not isinstance(ims, (list, tuple)):
ims = [ims]
num_images = len(ims)
shape_orig, shape_infer, filenames = [], [], []

for idx, img in enumerate(ims):
filename = getattr(img, "filename", f"image{idx}")
img = np.asarray(exif_transpose(img))
filename = Path(filename).with_suffix(".jpg").name
filenames.append(filename)

if img.shape[0] < 5:
img = img.transpose((1, 2, 0))

if img.ndim == 3:
img = img[..., :3]
else:
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)

shape_orig.append(img.shape[:2])
scale = max(size) / max(img.shape[:2])
shape_infer.append([int(dim * scale) for dim in img.shape[:2]])
ims[idx] = img if img.flags["C_CONTIGUOUS"] else np.ascontiguousarray(img)

shape_infer = [size[0] for _ in np.array(shape_infer).max(0)]
imgs_padded = [letterbox(img, shape_infer, auto=False)[0] for img in ims]
imgs_padded = np.ascontiguousarray(np.array(imgs_padded).transpose((0, 3, 1, 2)))
tensor_imgs = torch.from_numpy(imgs_padded) / 255

return ims, num_images, filenames, shape_orig, shape_infer, tensor_imgs


def yolov5_preprocessing(dataset, target_height, target_width):
Expand Down

0 comments on commit 306fff2

Please sign in to comment.