Skip to content

Commit

Permalink
feat: add support for passing custom feature extractors to the online dataset loader
Browse files Browse the repository at this point in the history
  • Loading branch information
AshishKumar4 committed Sep 10, 2024
1 parent 3c22222 commit 489809a
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 6 deletions.
24 changes: 19 additions & 5 deletions flaxdiff/data/online_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,22 +117,30 @@ def map_sample(
# "error": str(e)
# })
pass


def default_feature_extractor(sample):
    """Extract the image URL and caption from a raw dataset sample.

    This is the default ``feature_extractor`` for the online dataset
    loaders: given a sample (or columnar batch) mapping, it returns a
    dict exposing only the ``"url"`` and ``"caption"`` fields that the
    download/processing pipeline consumes.
    """
    return {field: sample[field] for field in ("url", "caption")}

def map_batch(
batch, num_threads=256, image_shape=(256, 256),
min_image_shape=(128, 128),
timeout=15, retries=3, image_processor=default_image_processor,
upscale_interpolation=cv2.INTER_CUBIC,
downscale_interpolation=cv2.INTER_AREA,
feature_extractor=default_feature_extractor,
):
try:
map_sample_fn = partial(map_sample, image_shape=image_shape, min_image_shape=min_image_shape,
timeout=timeout, retries=retries, image_processor=image_processor,
upscale_interpolation=upscale_interpolation,
downscale_interpolation=downscale_interpolation)
with ThreadPoolExecutor(max_workers=num_threads) as executor:
executor.map(map_sample_fn, batch["url"], batch['caption'])
features = feature_extractor(batch)
url, caption = features["url"], features["caption"]
executor.map(map_sample_fn, url, caption)
except Exception as e:
print(f"Error maping batch", e)
traceback.print_exc()
Expand All @@ -149,12 +157,14 @@ def parallel_image_loader(
num_threads=256, timeout=15, retries=3, image_processor=default_image_processor,
upscale_interpolation=cv2.INTER_CUBIC,
downscale_interpolation=cv2.INTER_AREA,
feature_extractor=default_feature_extractor,
):
map_batch_fn = partial(map_batch, num_threads=num_threads, image_shape=image_shape,
min_image_shape=min_image_shape,
timeout=timeout, retries=retries, image_processor=image_processor,
upscale_interpolation=upscale_interpolation,
downscale_interpolation=downscale_interpolation)
downscale_interpolation=downscale_interpolation,
feature_extractor=feature_extractor)
shard_len = len(dataset) // num_workers
print(f"Local Shard lengths: {shard_len}")
with multiprocessing.Pool(num_workers) as pool:
Expand All @@ -181,6 +191,7 @@ def __init__(
image_processor=default_image_processor,
upscale_interpolation=cv2.INTER_CUBIC,
downscale_interpolation=cv2.INTER_AREA,
feature_extractor=default_feature_extractor,
):
self.dataset = dataset
self.num_workers = num_workers
Expand All @@ -191,7 +202,8 @@ def __init__(
num_workers=num_workers,
timeout=timeout, retries=retries, image_processor=image_processor,
upscale_interpolation=upscale_interpolation,
downscale_interpolation=downscale_interpolation)
downscale_interpolation=downscale_interpolation,
feature_extractor=feature_extractor)
self.thread = threading.Thread(target=loader, args=(dataset,))
self.thread.start()

Expand Down Expand Up @@ -256,6 +268,7 @@ def __init__(
image_processor=default_image_processor,
upscale_interpolation=cv2.INTER_CUBIC,
downscale_interpolation=cv2.INTER_AREA,
feature_extractor=default_feature_extractor,
):
if isinstance(dataset, str):
dataset_path = dataset
Expand All @@ -281,7 +294,8 @@ def __init__(
num_workers=num_workers, batch_size=batch_size, num_threads=num_threads,
timeout=timeout, retries=retries, image_processor=image_processor,
upscale_interpolation=upscale_interpolation,
downscale_interpolation=downscale_interpolation)
downscale_interpolation=downscale_interpolation,
feature_extractor=feature_extractor)
self.batch_size = batch_size

# Launch a thread to load batches in the background
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
setup(
name='flaxdiff',
packages=find_packages(),
version='0.1.31',
version='0.1.32',
description='A versatile and easy to understand Diffusion library',
long_description=open('README.md').read(),
long_description_content_type='text/markdown',
Expand Down

0 comments on commit 489809a

Please sign in to comment.