From a7c521b82b0f2073ea12764288a1f48228482950 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Wed, 24 Jul 2024 17:37:12 -0700 Subject: [PATCH] [core][distributed] fix zmq hang (#6759) (cherry picked from commit 740374d456a638df98ffbc7d9dab328752330e62) --- .../device_communicators/shm_broadcast.py | 60 +++++++------------ 1 file changed, 21 insertions(+), 39 deletions(-) diff --git a/vllm/distributed/device_communicators/shm_broadcast.py b/vllm/distributed/device_communicators/shm_broadcast.py index db0064951cd1b..3ab7b1b1fc8be 100644 --- a/vllm/distributed/device_communicators/shm_broadcast.py +++ b/vllm/distributed/device_communicators/shm_broadcast.py @@ -9,7 +9,7 @@ import torch import torch.distributed as dist from torch.distributed import ProcessGroup -from zmq import PUB, REP, REQ, SUB, SUBSCRIBE, Context # type: ignore +from zmq import SUB, SUBSCRIBE, XPUB, XPUB_VERBOSE, Context # type: ignore import vllm.envs as envs from vllm.logger import init_logger @@ -145,9 +145,7 @@ class Handle: buffer: Optional[ShmRingBuffer] = None local_subscribe_port: Optional[int] = None - local_sync_port: Optional[int] = None remote_subscribe_port: Optional[int] = None - remote_sync_port: Optional[int] = None class MessageQueue: @@ -181,38 +179,36 @@ def __init__( self.buffer = ShmRingBuffer(n_local_reader, max_chunk_bytes, max_chunks) - self.local_socket = context.socket(PUB) + # XPUB is very similar to PUB, + # except that it can receive subscription messages + # to confirm the number of subscribers + self.local_socket = context.socket(XPUB) + # set the verbose option so that we can receive every subscription + # message. otherwise, we will only receive the first subscription + # see http://api.zeromq.org/3-3:zmq-setsockopt for more details + self.local_socket.setsockopt(XPUB_VERBOSE, True) local_subscribe_port = get_open_port() self.local_socket.bind(f"tcp://*:{local_subscribe_port}") - self.local_sync_socket = context.socket(REP) - local_sync_port = get_open_port() - self.local_sync_socket.bind(f"tcp://*:{local_sync_port}") self.current_idx = 0 else: self.buffer = None # type: ignore local_subscribe_port = None - local_sync_port = None self.local_socket = None - self.local_sync_socket = None self.current_idx = -1 if n_remote_reader > 0: # for remote readers, we will: # create a publish-subscribe socket to communicate large data - self.remote_socket = context.socket(PUB) + self.remote_socket = context.socket(XPUB) + self.remote_socket.setsockopt(XPUB_VERBOSE, True) remote_subscribe_port = get_open_port() self.remote_socket.bind(f"tcp://*:{remote_subscribe_port}") - self.remote_sync_socket = context.socket(REP) - remote_sync_port = get_open_port() - self.remote_sync_socket.bind(f"tcp://*:{remote_sync_port}") else: remote_subscribe_port = None - remote_sync_port = None self.remote_socket = None - self.remote_sync_socket = None self._is_writer = True self._is_local_reader = False @@ -225,9 +221,7 @@ def __init__( local_reader_ranks=local_reader_ranks, buffer=self.buffer, local_subscribe_port=local_subscribe_port, - local_sync_port=local_sync_port, remote_subscribe_port=remote_subscribe_port, - remote_sync_port=remote_sync_port, ) def export_handle(self) -> Handle: @@ -254,12 +248,7 @@ def create_from_handle(handle: Handle, rank) -> "MessageQueue": self.local_socket.connect( f"tcp://{handle.connect_ip}:{handle.local_subscribe_port}") - self.local_sync_socket = context.socket(REQ) - self.local_sync_socket.connect( - f"tcp://{handle.connect_ip}:{handle.local_sync_port}") - self.remote_socket = None - self.remote_sync_socket = None else: self.buffer = None # type: ignore self.current_idx = -1 @@ -268,17 +257,12 @@ def create_from_handle(handle: Handle, rank) -> "MessageQueue": self._is_remote_reader = True self.local_socket = None - self.local_sync_socket = None self.remote_socket = context.socket(SUB) self.remote_socket.setsockopt_string(SUBSCRIBE, "") self.remote_socket.connect( f"tcp://{handle.connect_ip}:{handle.remote_subscribe_port}") - self.remote_sync_socket = context.socket(REQ) - self.remote_sync_socket.connect( - f"tcp://{handle.connect_ip}:{handle.remote_sync_port}") - return self def wait_until_ready(self): @@ -290,29 +274,27 @@ def wait_until_ready(self): # local readers for i in range(self.n_local_reader): - recv = self.local_sync_socket.recv() - assert recv == b"READY" - self.local_sync_socket.send(b"READY") + # wait for subscription messages from all local readers + self.local_socket.recv() if self.n_local_reader > 0: + # send a message to all local readers + # to make sure the publish channel is working self.local_socket.send(b"READY") # remote readers for i in range(self.n_remote_reader): - recv = self.remote_sync_socket.recv() - assert recv == b"READY" - self.remote_sync_socket.send(b"READY") + # wait for subscription messages from all remote readers + self.remote_socket.recv() if self.n_remote_reader > 0: + # send a message to all remote readers + # to make sure the publish channel is working self.remote_socket.send(b"READY") elif self._is_local_reader: - self.local_sync_socket.send(b"READY") - recv = self.local_sync_socket.recv() - assert recv == b"READY" + # wait for the writer to send a message recv = self.local_socket.recv() assert recv == b"READY" elif self._is_remote_reader: - self.remote_sync_socket.send(b"READY") - recv = self.remote_sync_socket.recv() - assert recv == b"READY" + # wait for the writer to send a message recv = self.remote_socket.recv() assert recv == b"READY"