From 489809a8cbe2d88a014085b9b479a48a9e988e2f Mon Sep 17 00:00:00 2001 From: Ashish Kumar Singh Date: Tue, 10 Sep 2024 01:16:55 -0400 Subject: [PATCH] feat: added support for passing custom feature extractors to online dataset loader --- flaxdiff/data/online_loader.py | 24 +++++++++++++++++++----- setup.py | 2 +- 2 files changed, 20 insertions(+), 6 deletions(-) diff --git a/flaxdiff/data/online_loader.py b/flaxdiff/data/online_loader.py index 6dba6fa..6c64506 100644 --- a/flaxdiff/data/online_loader.py +++ b/flaxdiff/data/online_loader.py @@ -117,7 +117,12 @@ def map_sample( # "error": str(e) # }) pass - + +def default_feature_extractor(sample): + return { + "url": sample["url"], + "caption": sample["caption"], + } def map_batch( batch, num_threads=256, image_shape=(256, 256), @@ -125,6 +130,7 @@ def map_batch( timeout=15, retries=3, image_processor=default_image_processor, upscale_interpolation=cv2.INTER_CUBIC, downscale_interpolation=cv2.INTER_AREA, + feature_extractor=default_feature_extractor, ): try: map_sample_fn = partial(map_sample, image_shape=image_shape, min_image_shape=min_image_shape, @@ -132,7 +138,9 @@ def map_batch( upscale_interpolation=upscale_interpolation, downscale_interpolation=downscale_interpolation) with ThreadPoolExecutor(max_workers=num_threads) as executor: - executor.map(map_sample_fn, batch["url"], batch['caption']) + features = feature_extractor(batch) + url, caption = features["url"], features["caption"] + executor.map(map_sample_fn, url, caption) except Exception as e: print(f"Error maping batch", e) traceback.print_exc() @@ -149,12 +157,14 @@ def parallel_image_loader( num_threads=256, timeout=15, retries=3, image_processor=default_image_processor, upscale_interpolation=cv2.INTER_CUBIC, downscale_interpolation=cv2.INTER_AREA, + feature_extractor=default_feature_extractor, ): map_batch_fn = partial(map_batch, num_threads=num_threads, image_shape=image_shape, min_image_shape=min_image_shape, timeout=timeout, retries=retries, image_processor=image_processor, upscale_interpolation=upscale_interpolation, - downscale_interpolation=downscale_interpolation) + downscale_interpolation=downscale_interpolation, + feature_extractor=feature_extractor) shard_len = len(dataset) // num_workers print(f"Local Shard lengths: {shard_len}") with multiprocessing.Pool(num_workers) as pool: @@ -181,6 +191,7 @@ def __init__( image_processor=default_image_processor, upscale_interpolation=cv2.INTER_CUBIC, downscale_interpolation=cv2.INTER_AREA, + feature_extractor=default_feature_extractor, ): self.dataset = dataset self.num_workers = num_workers @@ -191,7 +202,8 @@ def __init__( num_workers=num_workers, timeout=timeout, retries=retries, image_processor=image_processor, upscale_interpolation=upscale_interpolation, - downscale_interpolation=downscale_interpolation) + downscale_interpolation=downscale_interpolation, + feature_extractor=feature_extractor) self.thread = threading.Thread(target=loader, args=(dataset,)) self.thread.start() @@ -256,6 +268,7 @@ def __init__( image_processor=default_image_processor, upscale_interpolation=cv2.INTER_CUBIC, downscale_interpolation=cv2.INTER_AREA, + feature_extractor=default_feature_extractor, ): if isinstance(dataset, str): dataset_path = dataset @@ -281,7 +294,8 @@ def __init__( num_workers=num_workers, batch_size=batch_size, num_threads=num_threads, timeout=timeout, retries=retries, image_processor=image_processor, upscale_interpolation=upscale_interpolation, - downscale_interpolation=downscale_interpolation) + downscale_interpolation=downscale_interpolation, + feature_extractor=feature_extractor) self.batch_size = batch_size # Launch a thread to load batches in the background diff --git a/setup.py b/setup.py index f7acb66..8b5f880 100644 --- a/setup.py +++ b/setup.py @@ -11,7 +11,7 @@ setup( name='flaxdiff', packages=find_packages(), - version='0.1.31', + version='0.1.32', description='A versatile and easy to understand Diffusion library', long_description=open('README.md').read(), long_description_content_type='text/markdown',