Prefect 3.0.10 fails to serialize context when using task_fun.delay() #15753
Thanks for the bug report @gigaverse-oz! Can you provide some example code where you see this error? This simple flow works for me:

```python
from prefect import flow, task

@task
def delay_task():
    print("Hello")

@flow(log_prints=True)
def delay_flow():
    delay_task.delay()

if __name__ == "__main__":
    delay_flow()
```

so I suspect there are some additional variables at play causing this issue.
Hi @desertaxle, the setup is as follows: I have a pod with a Prefect server.

```python
import asyncio
import os
from datetime import datetime, timezone
from typing import List, Optional

from prefect import Flow, flow, serve, task
from prefect.logging import get_logger
from pydantic import BaseModel, Field

logger = get_logger(__name__)


def get_current_iso_time():
    """
    Returns the current time in ISO format with UTC timezone.

    Returns:
        datetime: The current datetime with timezone set to UTC.
    """
    return datetime.now(timezone.utc)


class PrefectFlowInputBase(BaseModel):
    """
    Base model for all Prefect flow input models.

    Attributes:
        timestamp (Optional[datetime]): The timestamp indicating when the
            input was created. Defaults to the current UTC time.
    """
    timestamp: Optional[datetime] = Field(default_factory=get_current_iso_time)


class EndStreamFlowInput(PrefectFlowInputBase):
    """
    Input model for a snapshot-related Prefect flow.

    Attributes:
        channel_id (str): The ID of the channel the snapshot flow should process.
    """
    channel_id: str
    recording_url: Optional[str] = Field(None, description="URL of the recording, if available")


class StartStreamFlowInput(PrefectFlowInputBase):
    """
    Input model for a snapshot-related Prefect flow.

    Attributes:
        channel_id (str): The ID of the channel the snapshot flow should process.
    """
    channel_id: str


class SnapshotFlowInput(PrefectFlowInputBase):
    """
    Input model for a snapshot-related Prefect flow.

    Attributes:
        channel_id (str): The ID of the channel the snapshot flow should process.
    """
    channel_id: str


class TopicPollFlowInput(PrefectFlowInputBase):
    """
    Input model for a topic polling Prefect flow.

    Attributes:
        channel_id (str): The ID of the channel the topic poll flow should process.
    """
    channel_id: str


@task(log_prints=False, persist_result=False)
async def topic_poll_task(flow_input: TopicPollFlowInput) -> bool:
    """
    Orchestrates the topic poll workflow for a given channel, initializes the worker to
    create a topic poll, and schedules a new task if a livestream is active.

    Args:
        channel_id (str): The ID of the channel for which the topic poll is being generated.
        timestamp (datetime): The timestamp indicating when the topic poll flow started.

    Returns:
        bool: True if a new topic poll task was successfully scheduled; False otherwise.
    """
    # Some processing; we never actually get here when using delay(), so put whatever you want.
    return True


@flow(log_prints=False, persist_result=False)
async def topic_poll_flow(flow_input: TopicPollFlowInput) -> bool:
    logger.info(f"Starting topic poll flow: {flow_input.channel_id}. PID: {os.getpid()}")
    await topic_poll_task(flow_input=flow_input)  # This works; topic_poll_task.submit() also works, but runs in the same worker.
    await topic_poll_task.delay(flow_input=flow_input)  # Fails in multiple variations (with and without await).
    return True


@flow
async def start_stream_flow(flow_input: StartStreamFlowInput):
    logger.info(f"Stream {flow_input.channel_id} started. {os.getpid()}")
    list_of_flows = []
    # for i in range(10):
    list_of_flows.append(
        asyncio.create_task(
            topic_poll_flow(TopicPollFlowInput(**flow_input.model_dump())), name=topic_poll_flow.__name__
        )
    )
    # )
    done, pending = await asyncio.wait(list_of_flows, timeout=600)
    if pending:
        raise Exception("Not all tasks are finished")
    for task in done:
        task: asyncio.Task = task
        if task.exception():
            logger.error(f"{task.get_name()} failed: {str(task.exception())}")
            continue
        logger.info(f"{task.get_name()} finished successfully")


def serve_multiple_flows(list_of_flows: List[Flow], concurrent_limit: int = 10):
    list_of_deployments = [flow.to_deployment(name=flow.name) for flow in list_of_flows]
    serve(*list_of_deployments, limit=concurrent_limit)


if __name__ == "__main__":
    list_of_served_flows = [start_stream_flow, topic_poll_flow]
    serve_multiple_flows(list_of_served_flows, concurrent_limit=10)
```

I copied the relevant code snippets and functions from multiple files; this is not working code in a single file.
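As an aside, the done/pending error-collection pattern used in `start_stream_flow` can be exercised on its own with plain asyncio, independent of Prefect. A minimal stdlib sketch (the `worker` coroutine and task names are made up for illustration):

```python
import asyncio


async def worker(fail: bool) -> bool:
    """Stand-in coroutine: either succeeds or raises, like a real flow might."""
    if fail:
        raise RuntimeError("worker blew up")
    return True


async def main() -> dict:
    tasks = [
        asyncio.create_task(worker(False), name="ok_task"),
        asyncio.create_task(worker(True), name="failing_task"),
    ]
    done, pending = await asyncio.wait(tasks, timeout=5)
    if pending:
        raise Exception("Not all tasks are finished")
    # Mirror the flow's pattern: inspect each finished task for an exception
    # instead of letting it propagate out of asyncio.wait.
    return {t.get_name(): "failed" if t.exception() else "ok" for t in done}


print(asyncio.run(main()))  # {'ok_task': 'ok', 'failing_task': 'failed'} (key order may vary)
```

Note that `asyncio.wait` never raises the tasks' exceptions itself; checking `task.exception()` on each member of `done` is what keeps one failing flow from crashing the others.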
Thanks for the example @gigaverse-oz! Unfortunately, I wasn't able to reproduce the issue with your example. Here's the code that I ran:

```python
import asyncio
import os
from datetime import datetime, timezone
from typing import List, Optional

from prefect import Flow, flow, serve, task
from prefect.logging import get_logger
from pydantic import BaseModel, Field

logger = get_logger(__name__)


def get_current_iso_time():
    """
    Returns the current time in ISO format with UTC timezone.

    Returns:
        datetime: The current datetime with timezone set to UTC.
    """
    return datetime.now(timezone.utc)


class PrefectFlowInputBase(BaseModel):
    """
    Base model for all Prefect flow input models.

    Attributes:
        timestamp (Optional[datetime]): The timestamp indicating when the
            input was created. Defaults to the current UTC time.
    """
    timestamp: Optional[datetime] = Field(default_factory=get_current_iso_time)


class EndStreamFlowInput(PrefectFlowInputBase):
    """
    Input model for a snapshot-related Prefect flow.

    Attributes:
        channel_id (str): The ID of the channel the snapshot flow should process.
    """
    channel_id: str
    recording_url: Optional[str] = Field(
        None, description="URL of the recording, if available"
    )


class StartStreamFlowInput(PrefectFlowInputBase):
    """
    Input model for a snapshot-related Prefect flow.

    Attributes:
        channel_id (str): The ID of the channel the snapshot flow should process.
    """
    channel_id: str


class SnapshotFlowInput(PrefectFlowInputBase):
    """
    Input model for a snapshot-related Prefect flow.

    Attributes:
        channel_id (str): The ID of the channel the snapshot flow should process.
    """
    channel_id: str


class TopicPollFlowInput(PrefectFlowInputBase):
    """
    Input model for a topic polling Prefect flow.

    Attributes:
        channel_id (str): The ID of the channel the topic poll flow should process.
    """
    channel_id: str


@task(log_prints=False, persist_result=False)
async def topic_poll_task(flow_input: TopicPollFlowInput) -> bool:
    """
    Orchestrates the topic poll workflow for a given channel, initializes the worker to
    create a topic poll, and schedules a new task if a livestream is active.

    Args:
        channel_id (str): The ID of the channel for which the topic poll is being generated.
        timestamp (datetime): The timestamp indicating when the topic poll flow started.

    Returns:
        bool: True if a new topic poll task was successfully scheduled; False otherwise.
    """
    # Some processing; we never actually get here when using delay(), so put whatever you want.
    return True


@flow(log_prints=False, persist_result=False)
async def topic_poll_flow(flow_input: TopicPollFlowInput) -> bool:
    logger.info(
        f"Starting topic poll flow: {flow_input.channel_id}. PID: {os.getpid()}"
    )
    topic_poll_task.delay(
        flow_input=flow_input
    )
    return True


@flow
async def start_stream_flow(flow_input: StartStreamFlowInput):
    logger.info(f"Stream {flow_input.channel_id} started. {os.getpid()}")
    list_of_flows = []
    # for i in range(10):
    list_of_flows.append(
        asyncio.create_task(
            topic_poll_flow(TopicPollFlowInput(**flow_input.model_dump())),
            name=topic_poll_flow.__name__,
        )
    )
    # )
    done, pending = await asyncio.wait(list_of_flows, timeout=600)
    if pending:
        raise Exception("Not all tasks are finished")
    for task in done:
        task: asyncio.Task = task
        if task.exception():
            logger.error(f"{task.get_name()} failed: {str(task.exception())}")
            continue
        logger.info(f"{task.get_name()} finished successfully")


def serve_multiple_flows(list_of_flows: List[Flow], concurrent_limit: int = 10):
    list_of_deployments = [flow.to_deployment(name=flow.name) for flow in list_of_flows]
    serve(*list_of_deployments, limit=concurrent_limit)


if __name__ == "__main__":
    list_of_served_flows = [start_stream_flow, topic_poll_flow]
    serve_multiple_flows(list_of_served_flows, concurrent_limit=10)
```

Maybe one of your other flows cannot be pickled? You can check to see if a flow is picklable like this:

```python
import cloudpickle

print(cloudpickle.dumps(start_stream_flow))
```
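To run that check across everything being served at once, a small helper can loop over the objects and collect the failures. This is a sketch with a made-up helper name (`report_unpicklable`), using the stdlib `pickle` for illustration; substituting `cloudpickle.dumps` as suggested above would match what Prefect actually uses:

```python
import pickle


def report_unpicklable(objs):
    """Return {name: error message} for every object that fails to pickle."""
    failures = {}
    for obj in objs:
        # Prefect flows expose a .name attribute; fall back to repr() otherwise.
        name = getattr(obj, "name", None) or repr(obj)
        try:
            pickle.dumps(obj)
        except Exception as exc:
            failures[name] = f"{type(exc).__name__}: {exc}"
    return failures


# Example with plain objects: an int pickles fine, a lambda does not.
print(report_unpicklable([42, lambda x: x]))
```

Only the failing entries come back, so an empty dict means every object passed the check.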
Hi @desertaxle, many thanks. Here are the actual files:

```python
# pydantic_models.py
from datetime import datetime, timezone
from typing import Optional

from pydantic import BaseModel, Field


def get_current_iso_time():
    """
    Returns the current time in ISO format with UTC timezone.

    Returns:
        datetime: The current datetime with timezone set to UTC.
    """
    return datetime.now(timezone.utc)


class PrefectFlowInputBase(BaseModel):
    """
    Base model for all Prefect flow input models.

    Attributes:
        timestamp (Optional[datetime]): The timestamp indicating when the
            input was created. Defaults to the current UTC time.
    """
    timestamp: Optional[datetime] = Field(default_factory=get_current_iso_time)


class EndStreamFlowInput(PrefectFlowInputBase):
    """
    Input model for a snapshot-related Prefect flow.

    Attributes:
        channel_id (str): The ID of the channel the snapshot flow should process.
    """
    channel_id: str
    recording_url: Optional[str] = Field(
        None, description="URL of the recording, if available"
    )


class StartStreamFlowInput(PrefectFlowInputBase):
    """
    Input model for a snapshot-related Prefect flow.

    Attributes:
        channel_id (str): The ID of the channel the snapshot flow should process.
    """
    channel_id: str


class SnapshotFlowInput(PrefectFlowInputBase):
    """
    Input model for a snapshot-related Prefect flow.

    Attributes:
        channel_id (str): The ID of the channel the snapshot flow should process.
    """
    channel_id: str


class TopicPollFlowInput(PrefectFlowInputBase):
    """
    Input model for a topic polling Prefect flow.

    Attributes:
        channel_id (str): The ID of the channel the topic poll flow should process.
    """
    channel_id: str
```

```python
# topic_poll_flow.py
import os

from loguru import logger
from prefect import flow, task

from gv.ai.livestream_prefect_worker.pydantic_models import TopicPollFlowInput


@task(log_prints=False, persist_result=False)
async def topic_poll_task(flow_input: TopicPollFlowInput) -> bool:
    """
    Orchestrates the topic poll workflow for a given channel, initializes the worker to
    create a topic poll, and schedules a new task if a livestream is active.

    Args:
        channel_id (str): The ID of the channel for which the topic poll is being generated.
        timestamp (datetime): The timestamp indicating when the topic poll flow started.

    Returns:
        bool: True if a new topic poll task was successfully scheduled; False otherwise.
    """
    # Some processing; we never actually get here when using delay(), so put whatever you want.
    return True


@flow(log_prints=False, persist_result=False)
async def topic_poll_flow(flow_input: TopicPollFlowInput) -> bool:
    logger.info(
        f"Starting topic poll flow: {flow_input.channel_id}. PID: {os.getpid()}"
    )
    topic_poll_task.delay(
        flow_input=flow_input
    )
    return True
```

```python
# main.py
import asyncio
import os
from typing import List, Optional

from prefect import Flow, flow, serve, task
from prefect.logging import get_logger

logger = get_logger(__name__)

from gv.ai.livestream_prefect_worker.pydantic_models import TopicPollFlowInput, StartStreamFlowInput
from gv.ai.livestream_prefect_worker.topic_poll_flow import topic_poll_task, topic_poll_flow
# from .pydantic_models import TopicPollFlowInput, StartStreamFlowInput  # This import version also raises the error, even if topic_poll_task and topic_poll_flow are defined in main.py


@flow
async def start_stream_flow(flow_input: StartStreamFlowInput):
    logger.info(f"Stream {flow_input.channel_id} started. {os.getpid()}")
    list_of_flows = []
    # for i in range(10):
    list_of_flows.append(
        asyncio.create_task(
            topic_poll_flow(flow_input=TopicPollFlowInput(**flow_input.model_dump())),
            name=topic_poll_flow.__name__,
        )
    )
    # )
    done, pending = await asyncio.wait(list_of_flows, timeout=600)
    if pending:
        raise Exception("Not all tasks are finished")
    for task in done:
        task: asyncio.Task = task
        if task.exception():
            logger.error(f"{task.get_name()} failed: {str(task.exception())}")
            continue
        logger.info(f"{task.get_name()} finished successfully")


def serve_multiple_flows(list_of_flows: List[Flow], concurrent_limit: int = 10):
    list_of_deployments = [flow.to_deployment(name=flow.name) for flow in list_of_flows]
    serve(*list_of_deployments, limit=concurrent_limit)


if __name__ == "__main__":
    list_of_served_flows = [start_stream_flow, topic_poll_flow]
    serve_multiple_flows(list_of_served_flows, concurrent_limit=10)
```
Bug summary

I'm trying to launch a task using delay() from a flow. When the task starts, the context passed to the task includes the flow, which is not a serializable object (by either pickle or JSON). More debugging information is in the additional context.

Version info (`prefect version` output)

Additional context

Debugged the error and found that the following field was failing serialization: context["flow_run_context"]["flow"], which is of type Flow. The real exception in serialize_result() is: TypeError("cannot pickle '_thread.RLock' object")
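This pickling failure is easy to reproduce outside Prefect: any object that holds a threading lock fails the same way. A minimal stdlib sketch (the `HoldsLock` class is a made-up stand-in, not a Prefect type):

```python
import pickle
import threading


class HoldsLock:
    """Stand-in for an object (like a Flow) that carries a lock internally."""

    def __init__(self):
        self._lock = threading.RLock()


try:
    pickle.dumps(HoldsLock())
except TypeError as exc:
    print(exc)  # cannot pickle '_thread.RLock' object
```

Pickle walks the object's `__dict__`, hits the `_thread.RLock` instance, and raises, which matches the error surfaced from serialize_result() above.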
To double-check that this was the only problematic field, I changed the following in the EngineContext and serialization passed, but the task worker then fails because it needs the flow: