Prefect 3.0.10 fails to serialize context when using task_fun.delay() #15753

Open
gigaverse-oz opened this issue Oct 18, 2024 · 4 comments
Labels
bug Something isn't working


@gigaverse-oz
Contributor

gigaverse-oz commented Oct 18, 2024

Bug summary

I'm trying to launch a task with delay() from inside a flow.
When the task starts, the context passed to it includes the flow, which is not a serializable object (with either pickle or JSON).

  File "/workspaces/gigaverse-ai/.venv/lib/python3.11/site-packages/prefect/results.py", line 1013, in serialize_result
    raise SerializationError(
prefect.exceptions.SerializationError: Failed to serialize object of type 'dict' with serializer 'pickle'. You can try a different serializer (e.g. result_serializer="json") or disabling persistence (persist_result=False) for this flow or task.
11:47:31.785 | ERROR   | Flow run 'super-hamster' - Finished in state Failed('Flow run encountered an exception: SerializationError: Failed to serialize object of type \'dict\' with serializer \'pickle\'. You can try a different serializer (e.g. result_serializer="json") or disabling persistence (persist_result=False) for this flow or task.')
2024-10-18 11:47:31.791 | ERROR    | __prefect_loader__:start_stream_flow:35 - topic_poll_flow failed: Failed to serialize object of type 'dict' with serializer 'pickle'. You can try a different serializer (e.g. result_serializer="json") or disabling persistence (persist_result=False) for this flow or task.

More debugging information is in the Additional context section.

Version info (prefect version output)

Version:             3.0.10
API version:         0.8.4
Python version:      3.11.6
Git commit:          3aa2d893
Built:               Tue, Oct 15, 2024 1:31 PM
OS/Arch:             linux/x86_64
Profile:             ephemeral
Server type:         server
Pydantic version:    2.9.2

Additional context

I debugged the error and found that the field failing serialization is
context["flow_run_context"]["flow"], which is of type Flow.
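The report doesn't say how the failing field was located. One way to narrow it down is a small recursive probe like the sketch below. `find_unpicklable` is a hypothetical helper (not part of Prefect); it uses the standard-library `pickle` rather than cloudpickle (which Prefect's pickle serializer uses by default), and a `SimpleNamespace` holding an `RLock` is an assumed stand-in for the real `Flow` object. Because plain pickle rejects some things cloudpickle accepts (lambdas, for example), treat any extra hits as hints rather than proof.

```python
import pickle
import threading
from types import SimpleNamespace
from typing import Any, List, Optional, Set


def find_unpicklable(obj: Any, path: str = "context",
                     _seen: Optional[Set[int]] = None) -> List[str]:
    """Return the paths of the innermost values inside `obj` that pickle rejects.

    Hypothetical debugging helper: it recurses into dicts, lists, tuples and
    objects' __dict__, so the report points at the failing attribute rather
    than at the top-level context dict.
    """
    try:
        pickle.dumps(obj)
        return []  # this whole subtree pickles fine
    except Exception:
        pass

    _seen = set() if _seen is None else _seen
    if id(obj) in _seen:  # guard against reference cycles
        return []
    _seen.add(id(obj))

    if isinstance(obj, dict):
        children = [(f"{path}[{k!r}]", v) for k, v in obj.items()]
    elif isinstance(obj, (list, tuple)):
        children = [(f"{path}[{i}]", v) for i, v in enumerate(obj)]
    elif hasattr(obj, "__dict__"):
        children = [(f"{path}.{k}", v) for k, v in vars(obj).items()]
    else:
        children = []

    bad = [p for child_path, child in children
           for p in find_unpicklable(child, child_path, _seen)]
    # No child is to blame (e.g. a lock has no attributes): report this node.
    return bad or [path]


# Assumed stand-in for the Flow stored under context["flow_run_context"]["flow"].
fake_flow = SimpleNamespace(name="topic_poll_flow", _lock=threading.RLock())
context = {"flow_run_context": {"flow": fake_flow, "parameters": {"channel_id": "abc"}}}
print(find_unpicklable(context))  # ["context['flow_run_context']['flow']._lock"]
```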

The underlying exception raised inside serialize_result() is:

TypeError("cannot pickle '_thread.RLock' object")
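This `TypeError` can be reproduced without Prefect: any object that keeps an `RLock` as an attribute makes the dict containing it unpicklable. In this minimal sketch a `SimpleNamespace` is an assumed stand-in for the `Flow` instance, and `pickle_error` is a hypothetical helper.

```python
import pickle
import threading
from types import SimpleNamespace
from typing import Optional


def pickle_error(obj: object) -> Optional[str]:
    """Return the TypeError message if `obj` cannot be pickled, else None."""
    try:
        pickle.dumps(obj)
    except TypeError as exc:
        return str(exc)
    return None


# Stand-in for a Flow: an ordinary attribute plus a reentrant lock.
fake_flow = SimpleNamespace(name="topic_poll_flow", _lock=threading.RLock())
context = {"flow_run_context": {"flow": fake_flow, "parameters": {"channel_id": "abc"}}}

print(pickle_error(context))  # a message naming '_thread.RLock'
print(pickle_error({"parameters": {"channel_id": "abc"}}))  # None
```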

To double-check that this is the only problematic field, I excluded it in EngineContext.serialize() as shown below. Serialization then succeeded, but the task worker failed because it needs the flow:

def serialize(self):
    return self.model_dump(
        include={
            "flow_run",
            # "flow",  # serialization succeeds when this field is excluded
            "parameters",
            "log_prints",
            "start_time",
            "input_keyset",
            "result_store",
            "persist_result",
        },
        exclude_unset=True,
    )
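The effect of the `include` experiment can be sketched with plain dicts: whitelisting keys before pickling drops the unpicklable `flow`, and the round-tripped payload no longer contains it, which is consistent with the task worker failing when it needs the flow. This is a stdlib analogue under stated assumptions, not Prefect's `EngineContext.serialize()`; `serialize_subset` and the sample values are hypothetical.

```python
import pickle
import threading
from types import SimpleNamespace

# Field names copied from the include-set above; "flow" is left out,
# mirroring the experiment in the report.
INCLUDE = {"flow_run", "parameters", "log_prints", "start_time",
           "input_keyset", "result_store", "persist_result"}


def serialize_subset(context: dict, include: set) -> bytes:
    """Hypothetical helper: keep only whitelisted keys, then pickle them."""
    return pickle.dumps({k: v for k, v in context.items() if k in include})


# Sample context; the SimpleNamespace with a lock stands in for the Flow.
flow_run_context = {
    "flow_run": {"id": "1234"},
    "flow": SimpleNamespace(name="topic_poll_flow", _lock=threading.RLock()),
    "parameters": {"channel_id": "abc"},
    "log_prints": False,
}

payload = serialize_subset(flow_run_context, INCLUDE)
restored = pickle.loads(payload)
print(sorted(restored))  # ['flow_run', 'log_prints', 'parameters']
```

The receiving side gets everything except `flow`, so any code there that expects the flow object has to obtain it some other way.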
@gigaverse-oz gigaverse-oz added the bug Something isn't working label Oct 18, 2024
@desertaxle
Member

Thanks for the bug report @gigaverse-oz! Can you provide some example code where you see this error? This simple flow works for me:

from prefect import flow, task


@task
def delay_task():
    print("Hello")

@flow(log_prints=True)
def delay_flow():
    delay_task.delay()
    
if __name__ == "__main__":
    delay_flow()

So I suspect there are some additional variables at play causing this issue.

@gigaverse-oz
Contributor Author

gigaverse-oz commented Oct 18, 2024

Hi @desertaxle,
Thanks for the response.

The setup is as follows:

I have a pod running the Prefect server.
I serve multiple flows from separate pods, with PREFECT_URL_API="http://localhost:4200/api":

def get_current_iso_time():
    """
    Returns the current time in ISO format with UTC timezone.

    Returns:
        datetime: The current datetime with timezone set to UTC.
    """
    return datetime.now(timezone.utc)


class PrefectFlowInputBase(BaseModel):
    """
    Base model for all Prefect flow input models.

    Attributes:
        timestamp (Optional[datetime]): The timestamp indicating when the
            input was created. Defaults to the current UTC time.
    """

    timestamp: Optional[datetime] = Field(default_factory=get_current_iso_time)


class EndStreamFlowInput(PrefectFlowInputBase):
    """
    Input model for a snapshot-related Prefect flow.

    Attributes:
        channel_id (str): The ID of the channel the snapshot flow should process.
    """

    channel_id: str
    recording_url: Optional[str] = Field(None, description="URL of the recording, if available")


class StartStreamFlowInput(PrefectFlowInputBase):
    """
    Input model for a snapshot-related Prefect flow.

    Attributes:
        channel_id (str): The ID of the channel the snapshot flow should process.
    """

    channel_id: str


class SnapshotFlowInput(PrefectFlowInputBase):
    """
    Input model for a snapshot-related Prefect flow.

    Attributes:
        channel_id (str): The ID of the channel the snapshot flow should process.
    """

    channel_id: str


class TopicPollFlowInput(PrefectFlowInputBase):
    """
    Input model for a topic polling Prefect flow.

    Attributes:
        channel_id (str): The ID of the channel the topic poll flow should process.
    """

    channel_id: str

@task(log_prints=False, persist_result=False)
async def topic_poll_task(flow_input: TopicPollFlowInput) -> bool:
    """
    Orchestrates the topic poll workflow for a given channel, initializes the worker to
    create a topic poll, and schedules a new task if a livestream is active.

    Args:
        channel_id (str): The ID of the channel for which the topic poll is being generated.
        timestamp (datetime): The timestamp indicating when the topic poll flow started.

    Returns:
        bool: True if a new topic poll task was successfully scheduled; False otherwise.
    """
    # SOME PROCESSING, WE NEVER REALLY GET HERE WHEN USING DELAY SO PUT WHATEVER YOU WANT
    return True

@flow(log_prints=False, persist_result=False)
async def topic_poll_flow(flow_input: TopicPollFlowInput) -> bool:
    logger.info(f"Starting topic poll flow: {flow_input.channel_id}. PID: {os.getpid()}")
    await topic_poll_task(flow_input=flow_input)  # THIS WORKS, as does topic_poll_task.submit(), but both run in the same worker.
    await topic_poll_task.delay(flow_input=flow_input)  # FAILS; NO VARIATION I TRIED (WITH OR WITHOUT await) WORKS
    return True

@flow
async def start_stream_flow(flow_input: StartStreamFlowInput):
    logger.info(f"Stream {flow_input.channel_id} started. {os.getpid()}")
    list_of_flows = []
    # for i in range(10):
    list_of_flows.append(
        asyncio.create_task(
            topic_poll_flow(TopicPollFlowInput(**flow_input.model_dump())), name=topic_poll_flow.__name__
        )
    )
    # )

    done, pending = await asyncio.wait(list_of_flows, timeout=600)

    if pending:
        raise Exception("Not all tasks are finished")

    for task in done:
        task: asyncio.Task = task
        if task.exception():
            logger.error(f"{task.get_name()} failed: {str(task.exception())}")
            continue
        logger.info(f"{task.get_name()} finished successfully")



def serve_multiple_flows(list_of_flows: List[Flow], concurrent_limit: int = 10):
    list_of_deployments = [flow.to_deployment(name=flow.name) for flow in list_of_flows]
    serve(*list_of_deployments, limit=concurrent_limit)
    
if __name__ == "__main__":
    list_of_served_flows = [start_stream_flow, topic_poll_flow]
    serve_multiple_flows(list_of_served_flows, concurrent_limit=10)

I copied the relevant snippets and functions from multiple files, so this is not working code in a single file.

@desertaxle
Member

Thanks for the example @gigaverse-oz! Unfortunately, I wasn't able to reproduce the issue with your example.

Here's the code that I ran:

import asyncio
from datetime import datetime, timezone
import os
from typing import List, Optional
from prefect import Flow, flow, serve, task
from prefect.logging import get_logger
from pydantic import BaseModel, Field

logger = get_logger(__name__)


def get_current_iso_time():
    """
    Returns the current time in ISO format with UTC timezone.

    Returns:
        datetime: The current datetime with timezone set to UTC.
    """
    return datetime.now(timezone.utc)


class PrefectFlowInputBase(BaseModel):
    """
    Base model for all Prefect flow input models.

    Attributes:
        timestamp (Optional[datetime]): The timestamp indicating when the
            input was created. Defaults to the current UTC time.
    """

    timestamp: Optional[datetime] = Field(default_factory=get_current_iso_time)


class EndStreamFlowInput(PrefectFlowInputBase):
    """
    Input model for a snapshot-related Prefect flow.

    Attributes:
        channel_id (str): The ID of the channel the snapshot flow should process.
    """

    channel_id: str
    recording_url: Optional[str] = Field(
        None, description="URL of the recording, if available"
    )


class StartStreamFlowInput(PrefectFlowInputBase):
    """
    Input model for a snapshot-related Prefect flow.

    Attributes:
        channel_id (str): The ID of the channel the snapshot flow should process.
    """

    channel_id: str


class SnapshotFlowInput(PrefectFlowInputBase):
    """
    Input model for a snapshot-related Prefect flow.

    Attributes:
        channel_id (str): The ID of the channel the snapshot flow should process.
    """

    channel_id: str


class TopicPollFlowInput(PrefectFlowInputBase):
    """
    Input model for a topic polling Prefect flow.

    Attributes:
        channel_id (str): The ID of the channel the topic poll flow should process.
    """

    channel_id: str


@task(log_prints=False, persist_result=False)
async def topic_poll_task(flow_input: TopicPollFlowInput) -> bool:
    """
    Orchestrates the topic poll workflow for a given channel, initializes the worker to
    create a topic poll, and schedules a new task if a livestream is active.

    Args:
        channel_id (str): The ID of the channel for which the topic poll is being generated.
        timestamp (datetime): The timestamp indicating when the topic poll flow started.

    Returns:
        bool: True if a new topic poll task was successfully scheduled; False otherwise.
    """
    # SOME PROCESSING, WE NEVER REALLY GET HERE WHEN USING DELAY SO PUT WHATEVER YOU WANT
    return True


@flow(log_prints=False, persist_result=False)
async def topic_poll_flow(flow_input: TopicPollFlowInput) -> bool:
    logger.info(
        f"Starting topic poll flow: {flow_input.channel_id}. PID: {os.getpid()}"
    )
    topic_poll_task.delay(
        flow_input=flow_input
    ) 
    return True


@flow
async def start_stream_flow(flow_input: StartStreamFlowInput):
    logger.info(f"Stream {flow_input.channel_id} started. {os.getpid()}")
    list_of_flows = []
    # for i in range(10):
    list_of_flows.append(
        asyncio.create_task(
            topic_poll_flow(TopicPollFlowInput(**flow_input.model_dump())),
            name=topic_poll_flow.__name__,
        )
    )
    # )

    done, pending = await asyncio.wait(list_of_flows, timeout=600)

    if pending:
        raise Exception("Not all tasks are finished")

    for task in done:
        task: asyncio.Task = task
        if task.exception():
            logger.error(f"{task.get_name()} failed: {str(task.exception())}")
            continue
        logger.info(f"{task.get_name()} finished successfully")


def serve_multiple_flows(list_of_flows: List[Flow], concurrent_limit: int = 10):
    list_of_deployments = [flow.to_deployment(name=flow.name) for flow in list_of_flows]
    serve(*list_of_deployments, limit=concurrent_limit)


if __name__ == "__main__":
    list_of_served_flows = [start_stream_flow, topic_poll_flow]
    serve_multiple_flows(list_of_served_flows, concurrent_limit=10)

Maybe one of your other flows cannot be pickled? You can check to see if a flow is picklable like this:

import cloudpickle

print(cloudpickle.dumps(start_stream_flow))

@gigaverse-oz
Contributor Author

gigaverse-oz commented Oct 18, 2024

Hi @desertaxle,
The code you ran is fine and works on my machine as well.
In practice my code is split across multiple files, and after a long investigation I found that the following file structure fails. Could you verify that on your machine as well?
BTW, calling cloudpickle.dumps() on the flow in main always works, but calling it inside the first flow fails. The serialization seems to be very sensitive to the structure of the files and imports.

Many thanks

# pydantic_models.py
from datetime import datetime, timezone
from pydantic import BaseModel, Field
from typing import Optional

def get_current_iso_time():
    """
    Returns the current time in ISO format with UTC timezone.

    Returns:
        datetime: The current datetime with timezone set to UTC.
    """
    return datetime.now(timezone.utc)


class PrefectFlowInputBase(BaseModel):
    """
    Base model for all Prefect flow input models.

    Attributes:
        timestamp (Optional[datetime]): The timestamp indicating when the
            input was created. Defaults to the current UTC time.
    """

    timestamp: Optional[datetime] = Field(default_factory=get_current_iso_time)


class EndStreamFlowInput(PrefectFlowInputBase):
    """
    Input model for a snapshot-related Prefect flow.

    Attributes:
        channel_id (str): The ID of the channel the snapshot flow should process.
    """

    channel_id: str
    recording_url: Optional[str] = Field(
        None, description="URL of the recording, if available"
    )


class StartStreamFlowInput(PrefectFlowInputBase):
    """
    Input model for a snapshot-related Prefect flow.

    Attributes:
        channel_id (str): The ID of the channel the snapshot flow should process.
    """

    channel_id: str


class SnapshotFlowInput(PrefectFlowInputBase):
    """
    Input model for a snapshot-related Prefect flow.

    Attributes:
        channel_id (str): The ID of the channel the snapshot flow should process.
    """

    channel_id: str


class TopicPollFlowInput(PrefectFlowInputBase):
    """
    Input model for a topic polling Prefect flow.

    Attributes:
        channel_id (str): The ID of the channel the topic poll flow should process.
    """

    channel_id: str
    
# topic_poll_flow.py
import os

from loguru import logger
from prefect import flow, task
from gv.ai.livestream_prefect_worker.pydantic_models import TopicPollFlowInput

@task(log_prints=False, persist_result=False)
async def topic_poll_task(flow_input: TopicPollFlowInput) -> bool:
    """
    Orchestrates the topic poll workflow for a given channel, initializes the worker to
    create a topic poll, and schedules a new task if a livestream is active.

    Args:
        channel_id (str): The ID of the channel for which the topic poll is being generated.
        timestamp (datetime): The timestamp indicating when the topic poll flow started.

    Returns:
        bool: True if a new topic poll task was successfully scheduled; False otherwise.
    """
    # SOME PROCESSING, WE NEVER REALLY GET HERE WHEN USING DELAY SO PUT WHATEVER YOU WANT
    return True


@flow(log_prints=False, persist_result=False)
async def topic_poll_flow(flow_input: TopicPollFlowInput) -> bool:
    logger.info(
        f"Starting topic poll flow: {flow_input.channel_id}. PID: {os.getpid()}"
    )
    topic_poll_task.delay(
        flow_input=flow_input
    ) 
    return True
    
    
# main.py
import asyncio
import os
from typing import List, Optional
from prefect import Flow, flow, serve, task
from prefect.logging import get_logger

logger = get_logger(__name__)


from gv.ai.livestream_prefect_worker.pydantic_models import TopicPollFlowInput, StartStreamFlowInput
from gv.ai.livestream_prefect_worker.topic_poll_flow import topic_poll_task, topic_poll_flow

# from .pydantic_models import TopicPollFlowInput, StartStreamFlowInput # This import version also raises the error even if topic_poll_task, topic_poll_flow are in the main.py

@flow
async def start_stream_flow(flow_input: StartStreamFlowInput):
    logger.info(f"Stream {flow_input.channel_id} started. {os.getpid()}")
    list_of_flows = []
    # for i in range(10):
    list_of_flows.append(
        asyncio.create_task(
            topic_poll_flow(flow_input=TopicPollFlowInput(**flow_input.model_dump())),
            name=topic_poll_flow.__name__,
        )
    )
    # )

    done, pending = await asyncio.wait(list_of_flows, timeout=600)

    if pending:
        raise Exception("Not all tasks are finished")

    for task in done:
        task: asyncio.Task = task
        if task.exception():
            logger.error(f"{task.get_name()} failed: {str(task.exception())}")
            continue
        logger.info(f"{task.get_name()} finished successfully")


def serve_multiple_flows(list_of_flows: List[Flow], concurrent_limit: int = 10):
    list_of_deployments = [flow.to_deployment(name=flow.name) for flow in list_of_flows]
    serve(*list_of_deployments, limit=concurrent_limit)


if __name__ == "__main__":
    list_of_served_flows = [start_stream_flow, topic_poll_flow]
    serve_multiple_flows(list_of_served_flows, concurrent_limit=10)

Projects
None yet
Development

No branches or pull requests

2 participants