Skip to content

Commit

Permalink
Send timeouts in order, and get user
Browse files Browse the repository at this point in the history
  • Loading branch information
AndrewFerr committed Jul 30, 2024
1 parent 7959679 commit e0d2d76
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 26 deletions.
24 changes: 6 additions & 18 deletions synapse/handlers/delayed_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,11 @@
)

import attr
from twisted.internet import defer
from twisted.internet.interfaces import IDelayedCall

from synapse.api.constants import EventTypes
from synapse.api.errors import Codes, NotFoundError, ShadowBanError, SynapseError
from synapse.events import EventBase
from synapse.logging.context import make_deferred_yieldable
from synapse.logging.opentracing import set_tag
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.databases.main.delayed_events import (
Expand Down Expand Up @@ -92,15 +90,8 @@ def __init__(self, hs: "HomeServer"):
async def _schedule_db_events() -> None:
# TODO: Sync all state first, so that affected delayed state events will be cancelled
events, remaining_timeout_delays = await self.store.process_all_delays(self._get_current_ts())
await make_deferred_yieldable(
defer.gatherResults(
[
run_as_background_process("_send_event", self._send_event, *args)
for args in events
],
consumeErrors=True,
)
)
for args in events:
await self._send_event(*args)

for delay_id, user_localpart, relative_delay in remaining_timeout_delays:
self._schedule(delay_id, user_localpart, relative_delay)
Expand Down Expand Up @@ -243,14 +234,11 @@ async def update(self, requester: Requester, delay_id: str, action: str) -> None
self._schedule(delay_id, user_localpart, delay)

elif enum_action == _UpdateDelayedEventAction.SEND:
await self._send_now(delay_id, user_localpart)

async def _send_now(self, delay_id: DelayID, user_localpart: UserLocalpart) -> None:
args, removed_timeout_delay_ids = await self.store.pop_event(delay_id, user_localpart)
args, removed_timeout_delay_ids = await self.store.pop_event(delay_id, user_localpart)

for timeout_delay_id in removed_timeout_delay_ids:
self._unschedule(timeout_delay_id, user_localpart)
await self._send_event(user_localpart, *args)
for timeout_delay_id in removed_timeout_delay_ids:
self._unschedule(timeout_delay_id, user_localpart)
await self._send_event(user_localpart, *args)

async def _send_on_timeout(self, delay_id: DelayID, user_localpart: UserLocalpart) -> None:
del self._delayed_calls[_DelayedCallKey(delay_id, user_localpart)]
Expand Down
26 changes: 18 additions & 8 deletions synapse/storage/databases/main/delayed_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,16 @@
JsonDict,
]

# TODO: If a Tuple type hint can be extended, extend the above one
DelayedPartialEventWithUser = Tuple[
UserLocalpart,
RoomID,
EventType,
Optional[StateKey],
Optional[Timestamp],
JsonDict,
]


# TODO: Try to support workers
class DelayedEventsStore(SQLBaseStore):
Expand Down Expand Up @@ -295,7 +305,7 @@ async def get_all_for_user(
]

async def process_all_delays(self, current_ts: Timestamp) -> Tuple[
List[DelayedPartialEvent],
List[DelayedPartialEventWithUser],
List[Tuple[DelayID, UserLocalpart, Delay]],
]:
"""
Expand All @@ -305,34 +315,34 @@ async def process_all_delays(self, current_ts: Timestamp) -> Tuple[
"""

def process_all_delays_txn(txn: LoggingTransaction) -> Tuple[
List[DelayedPartialEvent],
List[DelayedPartialEventWithUser],
List[Tuple[DelayID, UserLocalpart, Delay]],
]:
events: List[DelayedPartialEvent] = []
events: List[DelayedPartialEventWithUser] = []
removed_timeout_delay_ids: Set[DelayID] = set()

txn.execute(
"""
WITH delay_send_times AS (
SELECT delay_rowid, running_since + delay AS send_ts
SELECT delay_rowid, user_localpart, running_since + delay AS send_ts
FROM delayed_events JOIN delayed_event_timeouts USING (delay_rowid)
)
SELECT delay_rowid
SELECT delay_rowid, user_localpart
FROM delay_send_times
WHERE send_ts < ?
ORDER BY send_ts
""",
(current_ts,),
)
for (delay_rowid,) in txn:
for row in txn:
try:
event, removed_timeout_delay_id = self._pop_event_txn(
txn,
keyvalues={"delay_rowid": delay_rowid},
keyvalues={"delay_rowid": row[0]},
)
except NotFoundError:
pass
events.append(event)
events.append((UserLocalpart(row[1]), *event))
removed_timeout_delay_ids |= removed_timeout_delay_id

txn.execute(
Expand Down

0 comments on commit e0d2d76

Please sign in to comment.