diff --git a/src/spdl/dataloader/_builder.py b/src/spdl/dataloader/_builder.py index 40d4a434..d3e81fcf 100644 --- a/src/spdl/dataloader/_builder.py +++ b/src/spdl/dataloader/_builder.py @@ -264,6 +264,9 @@ def _ordered_pipe( else hooks ) + # This has been checked in `PipelineBuilder.pipe()` + assert not inspect.isasyncgenfunction(op) + afunc = _convert.convert_to_async(op, executor) async def _wrap(item: T) -> asyncio.Task[U]: @@ -634,13 +637,15 @@ def pipe( f"Found: {output_order}" ) - if name is None: - if hasattr(op, "__name__"): - name = op.__name__ # type: ignore[attr-defined] - else: - name = op.__class__.__name__ - - _convert.validate_op(op, executor, output_order) + if inspect.iscoroutinefunction(op) or inspect.isasyncgenfunction(op): + if executor is not None: + raise ValueError("`executor` cannot be specified when op is async.") + if inspect.isasyncgenfunction(op): + if output_order == "input": + raise ValueError( + "pipe does not support async generator function " + "when `output_order` is 'input'." + ) self._process_args.append( ( @@ -649,7 +654,7 @@ def pipe( "op": op, "executor": executor, "concurrency": concurrency, - "name": name, + "name": name or getattr(op, "__name__", op.__class__.__name__), "hooks": hooks, "report_stats_interval": report_stats_interval, }, diff --git a/src/spdl/dataloader/_convert.py b/src/spdl/dataloader/_convert.py index 5b844176..9729dfbc 100644 --- a/src/spdl/dataloader/_convert.py +++ b/src/spdl/dataloader/_convert.py @@ -93,28 +93,13 @@ def _next() -> U: return afunc -def validate_op( - op: Callables[T, U], - executor: type[Executor] | None, - output_order: str, -) -> None: - if inspect.iscoroutinefunction(op) or inspect.isasyncgenfunction(op): - if executor is not None: - raise ValueError("`executor` cannot be specified when op is async.") - if inspect.isasyncgenfunction(op): - if output_order == "input": - raise ValueError( - "pipe does not support async generator function " - "when output_order is 'input'." - ) - - def convert_to_async( op: Callables[T, U], executor: type[Executor] | None, ) -> AsyncCallables[T, U]: if inspect.iscoroutinefunction(op) or inspect.isasyncgenfunction(op): # op is async function. No need to convert. + assert executor is None # This has been checked in `PipelineBuilder.pipe()` return op # pyre-ignore: [7] if inspect.isgeneratorfunction(op):