Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support multithreading for PYRO_STACK #1343

Open
wants to merge 1 commit into
base: master
Choose a base branch
from

Conversation

fehiepsi
Copy link
Member

Based on the discussion with @fritzo and @eb8680, it seems that the simplest way to support multithreading is to get PYRO_STACK from an intermediate function get_pyro_stack, where we can modify to return a stack for a particular thread.

TODOs:

  • see how to do threading in jax
  • finish the example

@fehiepsi fehiepsi added the WIP label Feb 21, 2022
@chkothe
Copy link

chkothe commented Oct 20, 2024

There's actually a much simpler way to accomplish basically the same, but without any code changes to either user code or any Messengers or other numpyro internals that have to interact with this variable. The trick is to use a subclass of threading.local as per Python's official documentation here ("or a subclass") and their example here. Then every thread automatically sees a unique copy of the data structure.

You can also see the pattern in action for example in Haiku, where they have a similar global state stack (here). The numpyro equivalent would come down to pretty much this:

import threading


class MessengerStack(threading.local):
    """A stack of Messenger instances that is unique to the current thread."""

    def __init__(self):
        super().__init__()
        self.lst = []

    def __len__(self):
        return len(self.lst)

    def __iter__(self):
        return iter(self.lst)

    def __contains__(self, item):
        return item in self.lst

    def __getitem__(self, idx):
        return self.lst[idx]

    def append(self, item):
        self.lst.append(item)

    def index(self, item):
        return self.lst.index(item)

    def pop(self):
        return self.lst.pop()


_PYRO_STACK = MessengerStack()

I'd be happy to prepare a pull request for that since we do have a need for that sort of thread safety.

@fehiepsi
Copy link
Member Author

Thanks @chkothe! Looking forward to it!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants