From 044f57a6765a252a861f2d50c67e4984dcd397fb Mon Sep 17 00:00:00 2001
From: Andrew Ferrazzutti
Date: Tue, 6 Aug 2024 13:22:46 -0400
Subject: [PATCH 01/79] Add DB schema for delayed events

---
 synapse/storage/schema/__init__.py            |  6 ++-
 .../main/delta/87/01_add_delayed_events.sql   | 47 +++++++++++++++++++
 .../87/01_add_delayed_events.sql.postgres     | 14 ++++++
 3 files changed, 66 insertions(+), 1 deletion(-)
 create mode 100644 synapse/storage/schema/main/delta/87/01_add_delayed_events.sql
 create mode 100644 synapse/storage/schema/main/delta/87/01_add_delayed_events.sql.postgres

diff --git a/synapse/storage/schema/__init__.py b/synapse/storage/schema/__init__.py
index 581d00346bf..7f70b014dc9 100644
--- a/synapse/storage/schema/__init__.py
+++ b/synapse/storage/schema/__init__.py
@@ -19,7 +19,7 @@
 #
 #
 
-SCHEMA_VERSION = 86  # remember to update the list below when updating
+SCHEMA_VERSION = 87  # remember to update the list below when updating
 """Represents the expectations made by the codebase about the database schema
 
 This should be incremented whenever the codebase changes its requirements on the
@@ -142,6 +142,10 @@
 
 Changes in SCHEMA_VERSION = 86
     - Add a column `authenticated` to the tables `local_media_repository` and
      `remote_media_cache`
+
+Changes in SCHEMA_VERSION = 87
+    - MSC4140: Add `delayed_events` table that keeps track of events that are to
+      be posted in response to a resettable timeout or an on-demand action.
 """

diff --git a/synapse/storage/schema/main/delta/87/01_add_delayed_events.sql b/synapse/storage/schema/main/delta/87/01_add_delayed_events.sql
new file mode 100644
index 00000000000..cb1bafcb681
--- /dev/null
+++ b/synapse/storage/schema/main/delta/87/01_add_delayed_events.sql
@@ -0,0 +1,47 @@
+--
+-- This file is licensed under the Affero General Public License (AGPL) version 3.
+--
+-- Copyright (C) 2024 New Vector, Ltd
+--
+-- This program is free software: you can redistribute it and/or modify
+-- it under the terms of the GNU Affero General Public License as
+-- published by the Free Software Foundation, either version 3 of the
+-- License, or (at your option) any later version.
+--
+-- See the GNU Affero General Public License for more details:
+-- <https://www.gnu.org/licenses/agpl-3.0.html>.
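+
+-- How these tables relate: every delayed event has a row in delayed_events.
+-- Events to be sent automatically after a timeout also have a row in
+-- delayed_event_timeouts. Delayed events may be grouped under a parent
+-- (delayed_event_parents, with members listed in delayed_event_children);
+-- sending any event in a group removes the whole group, via the
+-- ON DELETE CASCADE foreign keys below.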
+ +CREATE TABLE delayed_events ( + delay_rowid INTEGER PRIMARY KEY, -- An alias of rowid in SQLite + delay_id TEXT NOT NULL, + user_localpart TEXT NOT NULL, + running_since BIGINT NOT NULL, + room_id TEXT NOT NULL, + event_type TEXT NOT NULL, + state_key TEXT, + origin_server_ts BIGINT, + content bytea NOT NULL, + UNIQUE (delay_id, user_localpart) +); + +CREATE INDEX delayed_events_room_state_event_idx ON delayed_events (room_id, event_type, state_key) WHERE state_key IS NOT NULL; +CREATE INDEX delayed_events_user_idx ON delayed_events (user_localpart); + +CREATE TABLE delayed_event_timeouts ( + delay_rowid INTEGER PRIMARY KEY + REFERENCES delayed_events (delay_rowid) ON DELETE CASCADE, + delay BIGINT NOT NULL +); + +CREATE TABLE delayed_event_parents ( + delay_rowid INTEGER PRIMARY KEY + REFERENCES delayed_event_timeouts (delay_rowid) ON DELETE CASCADE +); + +CREATE TABLE delayed_event_children ( + child_rowid INTEGER PRIMARY KEY + REFERENCES delayed_events (delay_rowid) ON DELETE CASCADE, + parent_rowid INTEGER NOT NULL + REFERENCES delayed_event_parents (delay_rowid) ON DELETE CASCADE, + CHECK (child_rowid <> parent_rowid) +); diff --git a/synapse/storage/schema/main/delta/87/01_add_delayed_events.sql.postgres b/synapse/storage/schema/main/delta/87/01_add_delayed_events.sql.postgres new file mode 100644 index 00000000000..4e4b16d0a28 --- /dev/null +++ b/synapse/storage/schema/main/delta/87/01_add_delayed_events.sql.postgres @@ -0,0 +1,14 @@ +-- +-- This file is licensed under the Affero General Public License (AGPL) version 3. +-- +-- Copyright (C) 2024 New Vector, Ltd +-- +-- This program is free software: you can redistribute it and/or modify +-- it under the terms of the GNU Affero General Public License as +-- published by the Free Software Foundation, either version 3 of the +-- License, or (at your option) any later version. +-- +-- See the GNU Affero General Public License for more details: +-- . + +ALTER TABLE delayed_events ALTER COLUMN delay_rowid ADD GENERATED ALWAYS AS IDENTITY; From 82c5437e280c9f86977de4dc495f30f5af46f416 Mon Sep 17 00:00:00 2001 From: Andrew Ferrazzutti Date: Tue, 6 Aug 2024 23:02:55 -0400 Subject: [PATCH 02/79] Support scheduling delayed events --- synapse/handlers/delayed_events.py | 227 +++++++++++ synapse/rest/client/room.py | 48 ++- synapse/server.py | 6 + synapse/storage/databases/main/__init__.py | 2 + .../storage/databases/main/delayed_events.py | 358 ++++++++++++++++++ 5 files changed, 637 insertions(+), 4 deletions(-) create mode 100644 synapse/handlers/delayed_events.py create mode 100644 synapse/storage/databases/main/delayed_events.py diff --git a/synapse/handlers/delayed_events.py b/synapse/handlers/delayed_events.py new file mode 100644 index 00000000000..d38b9dbaee3 --- /dev/null +++ b/synapse/handlers/delayed_events.py @@ -0,0 +1,227 @@ +# +# This file is licensed under the Affero General Public License (AGPL) version 3. +# +# Copyright (C) 2024 New Vector, Ltd +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as +# published by the Free Software Foundation, either version 3 of the +# License, or (at your option) any later version. +# +# See the GNU Affero General Public License for more details: +# . +# +# Originally licensed under the Apache License, Version 2.0: +# . 
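+#
+# This module implements the MSC4140 "delayed events" handler: scheduled
+# events are stored in the database, an in-memory timer is armed for each
+# event that has a timeout, and the event is sent when its timer fires or
+# when a client asks for it to be sent on demand.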
+#
+#
+
+import logging
+from typing import TYPE_CHECKING, Dict, Optional
+
+import attr
+
+from twisted.internet.interfaces import IDelayedCall
+
+from synapse.api.constants import EventTypes
+from synapse.api.errors import ShadowBanError
+from synapse.logging.opentracing import set_tag
+from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.storage.databases.main.delayed_events import (
+    Delay,
+    DelayID,
+    EventType,
+    StateKey,
+    Timestamp,
+    UserLocalpart,
+)
+from synapse.types import JsonDict, Requester, RoomID, UserID, create_requester
+from synapse.util.stringutils import random_string
+
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
+logger = logging.getLogger(__name__)
+
+
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class _DelayedCallKey:
+    delay_id: DelayID
+    user_localpart: UserLocalpart
+
+    def __str__(self) -> str:
+        return f"{self.user_localpart}:{self.delay_id}"
+
+
+class DelayedEventsHandler:
+    def __init__(self, hs: "HomeServer"):
+        self.store = hs.get_datastores().main
+        self.config = hs.config
+        self.clock = hs.get_clock()
+        self.request_ratelimiter = hs.get_request_ratelimiter()
+        self.event_creation_handler = hs.get_event_creation_handler()
+        self.room_member_handler = hs.get_room_member_handler()
+
+        self._delayed_calls: Dict[_DelayedCallKey, IDelayedCall] = {}
+
+    async def add(
+        self,
+        requester: Requester,
+        *,
+        room_id: str,
+        event_type: str,
+        state_key: Optional[str],
+        origin_server_ts: Optional[int],
+        content: JsonDict,
+        delay: Optional[int],
+        parent_id: Optional[str],
+    ) -> str:
+        """
+        Creates a new delayed event.
+
+        Params:
+            requester: The requester of the delayed event, who will be its owner.
+            room_id: The room where the event should be sent.
+            event_type: The type of event to be sent.
+            state_key: The state key of the event to be sent, or None if it is not a state event.
+            origin_server_ts: The custom timestamp to send the event with.
+                If None, the timestamp will be the actual time when the event is sent.
+            content: The content of the event to be sent.
+            delay: How long (in milliseconds) to wait before automatically sending the event.
+                If None, the event won't be automatically sent (allowed only when parent_id is set).
+            parent_id: The ID of the delayed event this one is grouped with.
+                May only refer to a delayed event that has no parent itself.
+
+        Returns:
+            The ID of the added delayed event.
+ """ + # Callers should ensure that at least one of these are set + assert delay or parent_id + + await self.request_ratelimiter.ratelimit(requester) + + # TODO: Validate that the event is valid before scheduling it + + user_localpart = UserLocalpart(requester.user.localpart) + delay_id = await self.store.add( + user_localpart=user_localpart, + current_ts=self._get_current_ts(), + room_id=RoomID.from_string(room_id), + event_type=event_type, + state_key=state_key, + origin_server_ts=origin_server_ts, + content=content, + delay=delay, + parent_id=parent_id, + ) + + if delay is not None: + self._schedule(delay_id, user_localpart, Delay(delay)) + + return delay_id + + async def _send_on_timeout( + self, delay_id: DelayID, user_localpart: UserLocalpart + ) -> None: + del self._delayed_calls[_DelayedCallKey(delay_id, user_localpart)] + + args, removed_timeout_delay_ids = await self.store.pop_event( + delay_id, user_localpart + ) + + removed_timeout_delay_ids.remove(delay_id) + for timeout_delay_id in removed_timeout_delay_ids: + self._unschedule(timeout_delay_id, user_localpart) + await self._send_event(user_localpart, *args) + + def _schedule( + self, + delay_id: DelayID, + user_localpart: UserLocalpart, + delay: Delay, + ) -> None: + """NOTE: Should not be called with a delay_id that isn't in the DB, or with a negative delay.""" + delay_sec = delay / 1000 + + logger.info( + "Scheduling delayed event %s for local user %s to be sent in %.3fs", + delay_id, + user_localpart, + delay_sec, + ) + + self._delayed_calls[_DelayedCallKey(delay_id, user_localpart)] = ( + self.clock.call_later( + delay_sec, + run_as_background_process, + "_send_on_timeout", + self._send_on_timeout, + delay_id, + user_localpart, + ) + ) + + def _unschedule(self, delay_id: DelayID, user_localpart: UserLocalpart) -> None: + delayed_call = self._delayed_calls.pop( + _DelayedCallKey(delay_id, user_localpart) + ) + self.clock.cancel_call_later(delayed_call) + + async def _send_event( + self, + user_localpart: UserLocalpart, + room_id: RoomID, + event_type: EventType, + state_key: Optional[StateKey], + origin_server_ts: Optional[Timestamp], + content: JsonDict, + txn_id: Optional[str] = None, + ) -> None: + user_id = UserID(user_localpart, self.config.server.server_name) + user_id_str = user_id.to_string() + requester = create_requester( + user_id, + is_guest=await self.store.is_guest(user_id_str), + ) + + try: + if state_key is not None and event_type == EventTypes.Member: + membership = content.get("membership", None) + event_id, _ = await self.room_member_handler.update_membership( + requester, + target=UserID.from_string(state_key), + room_id=room_id.to_string(), + action=membership, + content=content, + origin_server_ts=origin_server_ts, + ) + else: + event_dict: JsonDict = { + "type": event_type, + "content": content, + "room_id": room_id.to_string(), + "sender": user_id_str, + } + + if state_key is not None: + event_dict["state_key"] = state_key + + if origin_server_ts is not None: + event_dict["origin_server_ts"] = origin_server_ts + + ( + event, + _, + ) = await self.event_creation_handler.create_and_send_nonmember_event( + requester, + event_dict, + txn_id=txn_id, + ) + event_id = event.event_id + except ShadowBanError: + event_id = "$" + random_string(43) + + set_tag("event_id", event_id) + + def _get_current_ts(self) -> Timestamp: + return Timestamp(self.clock.time_msec()) diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py index 903c74f6d8f..8a20fc26314 100644 --- a/synapse/rest/client/room.py 
+++ b/synapse/rest/client/room.py @@ -193,6 +193,7 @@ def __init__(self, hs: "HomeServer"): self.event_creation_handler = hs.get_event_creation_handler() self.room_member_handler = hs.get_room_member_handler() self.message_handler = hs.get_message_handler() + self.delayed_events_handler = hs.get_delayed_events_handler() self.auth = hs.get_auth() def register(self, http_server: HttpServer) -> None: @@ -289,6 +290,24 @@ async def on_PUT( if requester.app_service: origin_server_ts = parse_integer(request, "ts") + delay = parse_integer(request, "org.matrix.msc4140.delay") + parent_id = parse_string(request, "org.matrix.msc4140.parent_delay_id") + if delay is not None or parent_id is not None: + delay_id = await self.delayed_events_handler.add( + requester, + room_id=room_id, + event_type=event_type, + state_key=state_key, + origin_server_ts=origin_server_ts, + content=content, + delay=delay, + parent_id=parent_id, + ) + + set_tag("delay_id", delay_id) + ret = {"delay_id": delay_id} + return 200, ret + try: if event_type == EventTypes.Member: membership = content.get("membership", None) @@ -339,6 +358,7 @@ class RoomSendEventRestServlet(TransactionRestServlet): def __init__(self, hs: "HomeServer"): super().__init__(hs) self.event_creation_handler = hs.get_event_creation_handler() + self.delayed_events_handler = hs.get_delayed_events_handler() self.auth = hs.get_auth() def register(self, http_server: HttpServer) -> None: @@ -356,6 +376,28 @@ async def _do( ) -> Tuple[int, JsonDict]: content = parse_json_object_from_request(request) + origin_server_ts = None + if requester.app_service: + origin_server_ts = parse_integer(request, "ts") + + delay = parse_integer(request, "org.matrix.msc4140.delay") + parent_id = parse_string(request, "org.matrix.msc4140.parent_delay_id") + if delay is not None or parent_id is not None: + delay_id = await self.delayed_events_handler.add( + requester, + room_id=room_id, + event_type=event_type, + state_key=None, + origin_server_ts=origin_server_ts, + content=content, + delay=delay, + parent_id=parent_id, + ) + + set_tag("delay_id", delay_id) + ret = {"delay_id": delay_id} + return 200, ret + event_dict: JsonDict = { "type": event_type, "content": content, @@ -363,10 +405,8 @@ async def _do( "sender": requester.user.to_string(), } - if requester.app_service: - origin_server_ts = parse_integer(request, "ts") - if origin_server_ts is not None: - event_dict["origin_server_ts"] = origin_server_ts + if origin_server_ts is not None: + event_dict["origin_server_ts"] = origin_server_ts try: ( diff --git a/synapse/server.py b/synapse/server.py index 46b9d83a044..18bed8fe54a 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -67,6 +67,7 @@ from synapse.handlers.auth import AuthHandler, PasswordAuthProvider from synapse.handlers.cas import CasHandler from synapse.handlers.deactivate_account import DeactivateAccountHandler +from synapse.handlers.delayed_events import DelayedEventsHandler from synapse.handlers.device import DeviceHandler, DeviceWorkerHandler from synapse.handlers.devicemessage import DeviceMessageHandler from synapse.handlers.directory import DirectoryHandler @@ -249,6 +250,7 @@ class HomeServer(metaclass=abc.ABCMeta): "account_validity", "auth", "deactivate_account", + "delayed_events", "message", "pagination", "profile", @@ -941,3 +943,7 @@ def get_worker_locks_handler(self) -> WorkerLocksHandler: @cache_in_self def get_task_scheduler(self) -> TaskScheduler: return TaskScheduler(self) + + @cache_in_self + def get_delayed_events_handler(self) -> 
DelayedEventsHandler: + return DelayedEventsHandler(self) diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py index 586e84f2a4d..490d0c14b67 100644 --- a/synapse/storage/databases/main/__init__.py +++ b/synapse/storage/databases/main/__init__.py @@ -43,6 +43,7 @@ from .cache import CacheInvalidationWorkerStore from .censor_events import CensorEventsStore from .client_ips import ClientIpWorkerStore +from .delayed_events import DelayedEventsStore from .deviceinbox import DeviceInboxStore from .devices import DeviceStore from .directory import DirectoryStore @@ -156,6 +157,7 @@ class DataStore( LockStore, SessionStore, TaskSchedulerWorkerStore, + DelayedEventsStore, ): def __init__( self, diff --git a/synapse/storage/databases/main/delayed_events.py b/synapse/storage/databases/main/delayed_events.py new file mode 100644 index 00000000000..85673650613 --- /dev/null +++ b/synapse/storage/databases/main/delayed_events.py @@ -0,0 +1,358 @@ +# +# This file is licensed under the Affero General Public License (AGPL) version 3. +# +# Copyright (C) 2024 New Vector, Ltd +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as +# published by the Free Software Foundation, either version 3 of the +# License, or (at your option) any later version. +# +# See the GNU Affero General Public License for more details: +# . +# +# Originally licensed under the Apache License, Version 2.0: +# . +# +# + +from binascii import crc32 +from http import HTTPStatus +from typing import TYPE_CHECKING, Any, Dict, List, NewType, Optional, Set, Tuple + +from synapse.api.errors import Codes, NotFoundError, SynapseError +from synapse.storage._base import SQLBaseStore, db_to_json +from synapse.storage.database import ( + DatabasePool, + LoggingDatabaseConnection, + LoggingTransaction, +) +from synapse.storage.engines import PostgresEngine +from synapse.types import JsonDict, RoomID, StrCollection +from synapse.util import json_encoder, stringutils as stringutils +from synapse.util.stringutils import base62_encode + +if TYPE_CHECKING: + from synapse.server import HomeServer + + +DelayID = NewType("DelayID", str) +UserLocalpart = NewType("UserLocalpart", str) +EventType = NewType("EventType", str) +StateKey = NewType("StateKey", str) + +Delay = NewType("Delay", int) +Timestamp = NewType("Timestamp", int) + +DelayedPartialEvent = Tuple[ + RoomID, + EventType, + Optional[StateKey], + Optional[Timestamp], + JsonDict, +] + + +# TODO: Try to support workers +class DelayedEventsStore(SQLBaseStore): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): + super().__init__(database, db_conn, hs) + # TODO: Always use RETURNING once the minimum supported Sqlite version is 3.35.0 + self._use_returning = isinstance(self.database_engine, PostgresEngine) + + async def add( + self, + *, + user_localpart: UserLocalpart, + current_ts: Timestamp, + room_id: RoomID, + event_type: str, + state_key: Optional[str], + origin_server_ts: Optional[int], + content: JsonDict, + delay: Optional[int], + parent_id: Optional[str], + ) -> DelayID: + """ + Inserts a new delayed event in the DB. + + Returns: The generated ID assigned to the added delayed event. + + Raises: + SynapseError if the delayed event failed to be added. 
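+
+        Note:
+            The insert is retried up to 10 times if the randomly generated
+            delay_id collides with an existing one.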
+ """ + + def add_txn(txn: LoggingTransaction) -> DelayID: + delay_id = _generate_delay_id() + try: + sql = """ + INSERT INTO delayed_events ( + delay_id, user_localpart, running_since, + room_id, event_type, state_key, origin_server_ts, + content + ) VALUES ( + ?, ?, ?, + ?, ?, ?, ?, + ? + ) + """ + if self._use_returning: + sql += "RETURNING delay_rowid" + txn.execute( + sql, + ( + delay_id, + user_localpart, + current_ts, + room_id.to_string(), + event_type, + state_key, + origin_server_ts, + json_encoder.encode(content), + ), + ) + # TODO: Handle only the error for DB key collisions + except Exception as e: + raise SynapseError( + HTTPStatus.INTERNAL_SERVER_ERROR, + f"Couldn't generate a unique delay_id for user_localpart {user_localpart}", + # TODO: Maybe remove this + additional_fields={"db_error": str(e)}, + ) + + if not self._use_returning: + txn.execute( + """ + SELECT delay_rowid + FROM delayed_events + WHERE delay_id = ? AND user_localpart = ? + """, + ( + delay_id, + user_localpart, + ), + ) + row = txn.fetchone() + assert row is not None + delay_rowid = row[0] + + if delay is not None: + self.db_pool.simple_insert_txn( + txn, + table="delayed_event_timeouts", + values={ + "delay_rowid": delay_rowid, + "delay": delay, + }, + ) + + if parent_id is None: + self.db_pool.simple_insert_txn( + txn, + table="delayed_event_parents", + values={"delay_rowid": delay_rowid}, + ) + else: + try: + txn.execute( + """ + INSERT INTO delayed_event_children (child_rowid, parent_rowid) + SELECT ?, delay_rowid + FROM delayed_events + WHERE delay_id = ? AND user_localpart = ? + """, + ( + delay_rowid, + parent_id, + user_localpart, + ), + ) + # TODO: Handle only the error for the relevant foreign key / check violation + except Exception as e: + raise SynapseError( + HTTPStatus.BAD_REQUEST, + # TODO: Improve the wording for this + "Invalid parent delayed event", + Codes.INVALID_PARAM, + # TODO: Maybe remove this + additional_fields={"db_error": str(e)}, + ) + if txn.rowcount != 1: + raise NotFoundError("Parent delayed event not found") + + return delay_id + + attempts_remaining = 10 + while True: + try: + return await self.db_pool.runInteraction("add", add_txn) + except SynapseError as e: + if ( + e.code == HTTPStatus.INTERNAL_SERVER_ERROR + and attempts_remaining > 0 + ): + attempts_remaining -= 1 + else: + raise e + + async def pop_event( + self, + delay_id: DelayID, + user_localpart: UserLocalpart, + ) -> Tuple[ + DelayedPartialEvent, + Set[DelayID], + ]: + """ + Get the partial event of the matching delayed event, + and remove it and all of its parent/child/sibling events from the DB. + + Returns: + A tuple of: + - The partial event to send for the matching delayed event. + - The IDs of all removed delayed events with a timeout that must be unscheduled. + + Raises: + NotFoundError if there is no matching delayed event. 
+ """ + return await self.db_pool.runInteraction( + "pop_event", + self._pop_event_txn, + keyvalues={ + "delay_id": delay_id, + "user_localpart": user_localpart, + }, + ) + + def _pop_event_txn( + self, + txn: LoggingTransaction, + keyvalues: Dict[str, Any], + ) -> Tuple[ + DelayedPartialEvent, + Set[DelayID], + ]: + row = self.db_pool.simple_select_one_txn( + txn, + table="delayed_events", + keyvalues=keyvalues, + retcols=( + "delay_rowid", + "room_id", + "event_type", + "state_key", + "origin_server_ts", + "content", + ), + allow_none=True, + ) + if row is None: + raise NotFoundError("Delayed event not found") + target_delay_rowid = row[0] + event_row = row[1:] + + parent_rowid = self.db_pool.simple_select_one_onecol_txn( + txn, + table="delayed_event_children JOIN delayed_events ON child_rowid = delay_rowid", + keyvalues={"delay_rowid": target_delay_rowid}, + retcol="parent_rowid", + allow_none=True, + ) + + removed_timeout_delay_ids = self._remove_txn( + txn, + keyvalues={ + "delay_rowid": ( + parent_rowid if parent_rowid is not None else target_delay_rowid + ), + }, + retcols=("delay_id",), + ) + + contents: JsonDict = db_to_json(event_row[4]) + return ( + ( + RoomID.from_string(event_row[0]), + EventType(event_row[1]), + StateKey(event_row[2]) if event_row[2] is not None else None, + Timestamp(event_row[3]) if event_row[3] is not None else None, + # TODO: Verify contents? + contents, + ), + {DelayID(r[0]) for r in removed_timeout_delay_ids}, + ) + + def _remove_txn( + self, + txn: LoggingTransaction, + keyvalues: Dict[str, Any], + retcols: StrCollection, + allow_none: bool = False, + ) -> List[Tuple]: + """ + Removes delayed events matching the keyvalues, and any children they may have. + + Returns: + The specified columns for each delayed event with a timeout that was removed. + + Raises: + NotFoundError if allow_none is False and no delayed events match the keyvalues. + """ + sql_with = f""" + WITH target_rowids AS ( + SELECT delay_rowid + FROM delayed_events + WHERE {" AND ".join("%s = ?" % k for k in keyvalues)} + ) + """ + sql_where = """ + WHERE delay_rowid IN (SELECT * FROM target_rowids) + OR delay_rowid IN ( + SELECT child_rowid + FROM delayed_event_children + JOIN target_rowids ON parent_rowid = delay_rowid + ) + """ + args = list(keyvalues.values()) + txn.execute( + f""" + {sql_with} + SELECT {", ".join(retcols)} + FROM delayed_events + JOIN delayed_event_timeouts USING (delay_rowid) + {sql_where} + """, + args, + ) + rows = txn.fetchall() + txn.execute( + f""" + {sql_with} + DELETE FROM delayed_events + {sql_where} + """, + args, + ) + if not allow_none and txn.rowcount == 0: + raise NotFoundError("No delayed event found") + return rows + + +def _generate_delay_id() -> DelayID: + """Generates an opaque string, for use as a delay ID""" + + # We use the following format for delay IDs: + # syf__ + # They are scoped to user localparts, so it is possible for + # the same ID to exist for multiple users. 
+ + random_string = stringutils.random_string(20) + base = f"syd_{random_string}" + + crc = base62_encode(crc32(base.encode("ascii")), minwidth=6) + return DelayID(f"{base}_{crc}") From 645e225c6d146ba13e7786fa237e2c01a3e6efce Mon Sep 17 00:00:00 2001 From: Andrew Ferrazzutti Date: Tue, 6 Aug 2024 14:28:19 -0400 Subject: [PATCH 03/79] Support config for maximum allowed event delay --- synapse/config/experimental.py | 14 ++++++++++++++ synapse/handlers/delayed_events.py | 15 ++++++++++++++- 2 files changed, 28 insertions(+), 1 deletion(-) diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index bae9cc80476..cb22932be6c 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -443,6 +443,20 @@ def read_config(self, config: JsonDict, **kwargs: Any) -> None: "msc3823_account_suspension", False ) + # MSC4140: Delayed events + # The maximum allowed duration for delayed events. + try: + self.msc4140_max_delay = int(experimental["msc4140_max_delay"]) + if self.msc4140_max_delay <= 0: + raise ValueError + except ValueError: + raise ConfigError( + "msc4140_max_delay must be a positive integer", + ("experimental", "msc4140_max_delay"), + ) + except KeyError: + self.msc4140_max_delay = 10 * 365 * 24 * 60 * 60 * 1000 # 10 years + # MSC4151: Report room API (Client-Server API) self.msc4151_enabled: bool = experimental.get("msc4151_enabled", False) diff --git a/synapse/handlers/delayed_events.py b/synapse/handlers/delayed_events.py index d38b9dbaee3..30646b133ad 100644 --- a/synapse/handlers/delayed_events.py +++ b/synapse/handlers/delayed_events.py @@ -17,6 +17,7 @@ # import logging +from http import HTTPStatus from typing import TYPE_CHECKING, Dict, Optional import attr @@ -24,7 +25,7 @@ from twisted.internet.interfaces import IDelayedCall from synapse.api.constants import EventTypes -from synapse.api.errors import ShadowBanError +from synapse.api.errors import Codes, ShadowBanError, SynapseError from synapse.logging.opentracing import set_tag from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage.databases.main.delayed_events import ( @@ -95,6 +96,18 @@ async def add( Returns: The ID of the added delayed event. 
""" + if delay is not None: + max_delay = self.config.experimental.msc4140_max_delay + if delay > max_delay: + raise SynapseError( + HTTPStatus.BAD_REQUEST, + "The requested delay exceeds the allowed maximum.", + Codes.UNKNOWN, + { + "org.matrix.msc4140.errcode": "M_MAX_DELAY_EXCEEDED", + "org.matrix.msc4140.max_delay": max_delay, + }, + ) # Callers should ensure that at least one of these are set assert delay or parent_id From 54725883c602850f674ec482ab28155b380fab9c Mon Sep 17 00:00:00 2001 From: Andrew Ferrazzutti Date: Tue, 6 Aug 2024 22:38:54 -0400 Subject: [PATCH 04/79] Support updating delayed events --- synapse/handlers/delayed_events.py | 59 ++++++++++++ synapse/rest/__init__.py | 2 + synapse/rest/client/delayed_events.py | 70 +++++++++++++++ .../storage/databases/main/delayed_events.py | 90 ++++++++++++++++++- 4 files changed, 220 insertions(+), 1 deletion(-) create mode 100644 synapse/rest/client/delayed_events.py diff --git a/synapse/handlers/delayed_events.py b/synapse/handlers/delayed_events.py index 30646b133ad..bdc1589c243 100644 --- a/synapse/handlers/delayed_events.py +++ b/synapse/handlers/delayed_events.py @@ -17,6 +17,7 @@ # import logging +from enum import Enum from http import HTTPStatus from typing import TYPE_CHECKING, Dict, Optional @@ -45,6 +46,12 @@ logger = logging.getLogger(__name__) +class _UpdateDelayedEventAction(Enum): + CANCEL = "cancel" + RESTART = "restart" + SEND = "send" + + @attr.s(slots=True, frozen=True, auto_attribs=True) class _DelayedCallKey: delay_id: DelayID @@ -133,6 +140,58 @@ async def add( return delay_id + async def update(self, requester: Requester, delay_id: str, action: str) -> None: + """ + Executes the appropriate action for the matching delayed event. + + Params: + delay_id: The ID of the delayed event to act on. + action: What to do with the delayed event. + + Raises: + SynapseError if the provided action is unknown, or is unsupported for the target delayed event. + NotFoundError if no matching delayed event could be found. 
+ """ + try: + enum_action = _UpdateDelayedEventAction(action) + except ValueError: + raise SynapseError( + HTTPStatus.BAD_REQUEST, + "'action' is not one of " + + ", ".join(f"'{m.value}'" for m in _UpdateDelayedEventAction), + Codes.INVALID_PARAM, + ) + + delay_id = DelayID(delay_id) + user_localpart = UserLocalpart(requester.user.localpart) + + await self.request_ratelimiter.ratelimit(requester) + + if enum_action == _UpdateDelayedEventAction.CANCEL: + for removed_timeout_delay_id in await self.store.remove( + delay_id, user_localpart + ): + self._unschedule(removed_timeout_delay_id, user_localpart) + + elif enum_action == _UpdateDelayedEventAction.RESTART: + delay = await self.store.restart( + delay_id, + user_localpart, + self._get_current_ts(), + ) + + self._unschedule(delay_id, user_localpart) + self._schedule(delay_id, user_localpart, delay) + + elif enum_action == _UpdateDelayedEventAction.SEND: + args, removed_timeout_delay_ids = await self.store.pop_event( + delay_id, user_localpart + ) + + for timeout_delay_id in removed_timeout_delay_ids: + self._unschedule(timeout_delay_id, user_localpart) + await self._send_event(user_localpart, *args) + async def _send_on_timeout( self, delay_id: DelayID, user_localpart: UserLocalpart ) -> None: diff --git a/synapse/rest/__init__.py b/synapse/rest/__init__.py index 1aa9ea3877a..d6638d133ca 100644 --- a/synapse/rest/__init__.py +++ b/synapse/rest/__init__.py @@ -30,6 +30,7 @@ auth, auth_issuer, capabilities, + delayed_events, devices, directory, events, @@ -103,6 +104,7 @@ def register_servlets(client_resource: HttpServer, hs: "HomeServer") -> None: events.register_servlets(hs, client_resource) room.register_servlets(hs, client_resource) + delayed_events.register_servlets(hs, client_resource) login.register_servlets(hs, client_resource) profile.register_servlets(hs, client_resource) presence.register_servlets(hs, client_resource) diff --git a/synapse/rest/client/delayed_events.py b/synapse/rest/client/delayed_events.py new file mode 100644 index 00000000000..3ccea5d3d96 --- /dev/null +++ b/synapse/rest/client/delayed_events.py @@ -0,0 +1,70 @@ +# +# This file is licensed under the Affero General Public License (AGPL) version 3. +# +# Copyright (C) 2024 New Vector, Ltd +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as +# published by the Free Software Foundation, either version 3 of the +# License, or (at your option) any later version. +# +# See the GNU Affero General Public License for more details: +# . +# +# Originally licensed under the Apache License, Version 2.0: +# . 
+
+#
+#
+
+""" This module contains REST servlets to do with delayed events: /delayed_events/ """
+import logging
+from http import HTTPStatus
+from typing import TYPE_CHECKING, Tuple
+
+from synapse.api.errors import Codes, SynapseError
+from synapse.http.server import HttpServer
+from synapse.http.servlet import RestServlet, parse_json_object_from_request
+from synapse.http.site import SynapseRequest
+from synapse.rest.client._base import client_patterns
+from synapse.types import JsonDict
+
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
+logger = logging.getLogger(__name__)
+
+
+# TODO: Needs unit testing
+class UpdateDelayedEventServlet(RestServlet):
+    PATTERNS = client_patterns(
+        r"/org\.matrix\.msc4140/delayed_events/(?P<delay_id>[^/]*)$",
+        releases=(),
+    )
+    CATEGORY = "Delayed event management requests"
+
+    def __init__(self, hs: "HomeServer"):
+        super().__init__()
+        self.auth = hs.get_auth()
+        self.delayed_events_handler = hs.get_delayed_events_handler()
+
+    async def on_POST(
+        self, request: SynapseRequest, delay_id: str
+    ) -> Tuple[int, JsonDict]:
+        requester = await self.auth.get_user_by_req(request)
+
+        body = parse_json_object_from_request(request)
+        try:
+            action = str(body["action"])
+        except KeyError:
+            raise SynapseError(
+                HTTPStatus.BAD_REQUEST,
+                "'action' is missing",
+                Codes.MISSING_PARAM,
+            )
+
+        await self.delayed_events_handler.update(requester, delay_id, action)
+        return 200, {}
+
+
+def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
+    UpdateDelayedEventServlet(hs).register(http_server)
diff --git a/synapse/storage/databases/main/delayed_events.py b/synapse/storage/databases/main/delayed_events.py
index 85673650613..2d46a2fa221 100644
--- a/synapse/storage/databases/main/delayed_events.py
+++ b/synapse/storage/databases/main/delayed_events.py
@@ -20,7 +20,13 @@
 from http import HTTPStatus
 from typing import TYPE_CHECKING, Any, Dict, List, NewType, Optional, Set, Tuple
 
-from synapse.api.errors import Codes, NotFoundError, SynapseError
+from synapse.api.errors import (
+    Codes,
+    InvalidAPICallError,
+    NotFoundError,
+    StoreError,
+    SynapseError,
+)
 from synapse.storage._base import SQLBaseStore, db_to_json
 from synapse.storage.database import (
     DatabasePool,
     LoggingDatabaseConnection,
     LoggingTransaction,
 )
@@ -200,6 +206,62 @@ def add_txn(txn: LoggingTransaction) -> DelayID:
             else:
                 raise e
 
+    async def restart(
+        self,
+        delay_id: DelayID,
+        user_localpart: UserLocalpart,
+        current_ts: Timestamp,
+    ) -> Delay:
+        """
+        Resets the matching delayed event, as long as it has a timeout.
+
+        Params:
+            delay_id: The ID of the delayed event to restart.
+            user_localpart: The localpart of the delayed event's owner.
+            current_ts: The current time, to which the delayed event's "running_since" will be set.
+
+        Returns: The delay after which the delayed event will be sent (unless it is reset again).
+
+        Raises:
+            NotFoundError if there is no matching delayed event.
+            SynapseError if the matching delayed event has no timeout.
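+
+        Restarting only resets "running_since"; the stored delay duration is
+        unchanged, so the full delay counts down again from current_ts.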
+ """ + + def restart_txn(txn: LoggingTransaction) -> Delay: + keyvalues = { + "delay_id": delay_id, + "user_localpart": user_localpart, + } + row = self.db_pool.simple_select_one_txn( + txn, + table="delayed_events JOIN delayed_event_timeouts USING (delay_rowid)", + keyvalues=keyvalues, + retcols=("delay_rowid", "delay"), + allow_none=True, + ) + if row is None: + try: + self.db_pool.simple_select_one_onecol_txn( + txn, + table="delayed_events", + keyvalues=keyvalues, + retcol="1", + ) + except StoreError: + raise NotFoundError("Delayed event not found") + else: + raise InvalidAPICallError("Delayed event has no timeout") + + self.db_pool.simple_update_txn( + txn, + table="delayed_events", + keyvalues={"delay_rowid": row[0]}, + updatevalues={"running_since": current_ts}, + ) + return Delay(row[1]) + + return await self.db_pool.runInteraction("restart", restart_txn) + async def pop_event( self, delay_id: DelayID, @@ -287,6 +349,32 @@ def _pop_event_txn( {DelayID(r[0]) for r in removed_timeout_delay_ids}, ) + async def remove( + self, + delay_id: DelayID, + user_localpart: UserLocalpart, + ) -> Set[DelayID]: + """ + Removes the matching delayed event, as well as all of its child events if it is a parent. + + Returns: + The IDs of all removed delayed events with a timeout that must be unscheduled. + + Raises: + NotFoundError if there is no matching delayed event. + """ + + removed_timeout_delay_ids = await self.db_pool.runInteraction( + "remove", + self._remove_txn, + keyvalues={ + "delay_id": delay_id, + "user_localpart": user_localpart, + }, + retcols=("delay_id",), + ) + return {DelayID(r[0]) for r in removed_timeout_delay_ids} + def _remove_txn( self, txn: LoggingTransaction, From 14cf8ec89583e42e236dee7b2c33a25041887d76 Mon Sep 17 00:00:00 2001 From: Andrew Ferrazzutti Date: Tue, 6 Aug 2024 22:52:28 -0400 Subject: [PATCH 05/79] Support listing delayed events --- synapse/handlers/delayed_events.py | 9 +++- synapse/rest/client/delayed_events.py | 23 +++++++++- .../storage/databases/main/delayed_events.py | 42 +++++++++++++++++++ 3 files changed, 72 insertions(+), 2 deletions(-) diff --git a/synapse/handlers/delayed_events.py b/synapse/handlers/delayed_events.py index bdc1589c243..6a1f79b5782 100644 --- a/synapse/handlers/delayed_events.py +++ b/synapse/handlers/delayed_events.py @@ -19,7 +19,7 @@ import logging from enum import Enum from http import HTTPStatus -from typing import TYPE_CHECKING, Dict, Optional +from typing import TYPE_CHECKING, Dict, List, Optional import attr @@ -239,6 +239,13 @@ def _unschedule(self, delay_id: DelayID, user_localpart: UserLocalpart) -> None: ) self.clock.cancel_call_later(delayed_call) + async def get_all_for_user(self, requester: Requester) -> List[JsonDict]: + """Return all pending delayed events requested by the given user.""" + await self.request_ratelimiter.ratelimit(requester) + return await self.store.get_all_for_user( + UserLocalpart(requester.user.localpart) + ) + async def _send_event( self, user_localpart: UserLocalpart, diff --git a/synapse/rest/client/delayed_events.py b/synapse/rest/client/delayed_events.py index 3ccea5d3d96..c50cc96f6cf 100644 --- a/synapse/rest/client/delayed_events.py +++ b/synapse/rest/client/delayed_events.py @@ -19,7 +19,7 @@ """ This module contains REST servlets to do with delayed events: /delayed_events/ """ import logging from http import HTTPStatus -from typing import TYPE_CHECKING, Tuple +from typing import TYPE_CHECKING, List, Tuple from synapse.api.errors import Codes, SynapseError from synapse.http.server 
import HttpServer @@ -66,5 +66,26 @@ async def on_POST( return 200, {} +# TODO: Needs unit testing +class DelayedEventsServlet(RestServlet): + PATTERNS = client_patterns( + r"/org\.matrix\.msc4140/delayed_events$", + releases=(), + ) + CATEGORY = "Delayed event management requests" + + def __init__(self, hs: "HomeServer"): + super().__init__() + self.auth = hs.get_auth() + self.delayed_events_handler = hs.get_delayed_events_handler() + + async def on_GET(self, request: SynapseRequest) -> Tuple[int, List[JsonDict]]: + requester = await self.auth.get_user_by_req(request) + # TODO: Support Pagination stream API ("from" query parameter) + data = await self.delayed_events_handler.get_all_for_user(requester) + return 200, data + + def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: UpdateDelayedEventServlet(hs).register(http_server) + DelayedEventsServlet(hs).register(http_server) diff --git a/synapse/storage/databases/main/delayed_events.py b/synapse/storage/databases/main/delayed_events.py index 2d46a2fa221..561f0342786 100644 --- a/synapse/storage/databases/main/delayed_events.py +++ b/synapse/storage/databases/main/delayed_events.py @@ -262,6 +262,48 @@ def restart_txn(txn: LoggingTransaction) -> Delay: return await self.db_pool.runInteraction("restart", restart_txn) + async def get_all_for_user( + self, + user_localpart: UserLocalpart, + ) -> List[JsonDict]: + """Returns all pending delayed events owned by the given user.""" + # TODO: Store and return "transaction_id" + # TODO: Support Pagination stream API ("next_batch" field) + rows = await self.db_pool.execute( + "get_all_for_user", + """ + SELECT + delay_id, + room_id, event_type, state_key, + delay, parent_id, + running_since, + content + FROM delayed_events + LEFT JOIN delayed_event_timeouts USING (delay_rowid) + LEFT JOIN ( + SELECT delay_id AS parent_id, child_rowid + FROM delayed_event_children + JOIN delayed_events ON parent_rowid = delay_rowid + ) ON delay_rowid = child_rowid + WHERE user_localpart = ? + """, + user_localpart, + ) + return [ + { + "delay_id": DelayID(row[0]), + "room_id": str(RoomID.from_string(row[1])), + "type": EventType(row[2]), + **({"state_key": StateKey(row[3])} if row[3] is not None else {}), + **({"delay": Delay(row[4])} if row[4] is not None else {}), + **({"parent_delay_id": DelayID(row[5])} if row[5] is not None else {}), + "running_since": Timestamp(row[6]), + # TODO: Verify contents? 
+ "content": db_to_json(row[7]), + } + for row in rows + ] + async def pop_event( self, delay_id: DelayID, From d8e313584c2b5fece7ed4492f5f20a2c742966c0 Mon Sep 17 00:00:00 2001 From: Andrew Ferrazzutti Date: Tue, 6 Aug 2024 23:23:38 -0400 Subject: [PATCH 06/79] Restore pending delayed events on startup --- synapse/handlers/delayed_events.py | 21 +++++ .../storage/databases/main/delayed_events.py | 77 +++++++++++++++++++ 2 files changed, 98 insertions(+) diff --git a/synapse/handlers/delayed_events.py b/synapse/handlers/delayed_events.py index 6a1f79b5782..b66ce0bfad3 100644 --- a/synapse/handlers/delayed_events.py +++ b/synapse/handlers/delayed_events.py @@ -72,6 +72,24 @@ def __init__(self, hs: "HomeServer"): self._delayed_calls: Dict[_DelayedCallKey, IDelayedCall] = {} + async def _schedule_db_events() -> None: + # TODO: Sync all state first, so that affected delayed state events will be cancelled + events, remaining_timeout_delays = await self.store.process_all_delays( + self._get_current_ts() + ) + for args in events: + try: + await self._send_event(*args) + except Exception: + logger.exception("Failed to send delayed event on startup") + + for delay_id, user_localpart, relative_delay in remaining_timeout_delays: + self._schedule(delay_id, user_localpart, relative_delay) + + self._initialized_from_db = run_as_background_process( + "_schedule_db_events", _schedule_db_events + ) + async def add( self, requester: Requester, @@ -119,6 +137,7 @@ async def add( assert delay or parent_id await self.request_ratelimiter.ratelimit(requester) + await self._initialized_from_db # TODO: Validate that the event is valid before scheduling it @@ -166,6 +185,7 @@ async def update(self, requester: Requester, delay_id: str, action: str) -> None user_localpart = UserLocalpart(requester.user.localpart) await self.request_ratelimiter.ratelimit(requester) + await self._initialized_from_db if enum_action == _UpdateDelayedEventAction.CANCEL: for removed_timeout_delay_id in await self.store.remove( @@ -242,6 +262,7 @@ def _unschedule(self, delay_id: DelayID, user_localpart: UserLocalpart) -> None: async def get_all_for_user(self, requester: Requester) -> List[JsonDict]: """Return all pending delayed events requested by the given user.""" await self.request_ratelimiter.ratelimit(requester) + await self._initialized_from_db return await self.store.get_all_for_user( UserLocalpart(requester.user.localpart) ) diff --git a/synapse/storage/databases/main/delayed_events.py b/synapse/storage/databases/main/delayed_events.py index 561f0342786..b7e8d3f4cc8 100644 --- a/synapse/storage/databases/main/delayed_events.py +++ b/synapse/storage/databases/main/delayed_events.py @@ -58,6 +58,16 @@ JsonDict, ] +# TODO: If a Tuple type hint can be extended, extend the above one +DelayedPartialEventWithUser = Tuple[ + UserLocalpart, + RoomID, + EventType, + Optional[StateKey], + Optional[Timestamp], + JsonDict, +] + # TODO: Try to support workers class DelayedEventsStore(SQLBaseStore): @@ -304,6 +314,73 @@ async def get_all_for_user( for row in rows ] + async def process_all_delays(self, current_ts: Timestamp) -> Tuple[ + List[DelayedPartialEventWithUser], + List[Tuple[DelayID, UserLocalpart, Delay]], + ]: + """ + Pops all delayed events that should have timed out prior to the provided time, + and returns all remaining timeout delayed events along with + how much later from the provided time they should time out at. 
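+
+        Called once on startup: events whose timeout expired while the server
+        was down are sent immediately, and the rest are rescheduled.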
+ """ + + def process_all_delays_txn(txn: LoggingTransaction) -> Tuple[ + List[DelayedPartialEventWithUser], + List[Tuple[DelayID, UserLocalpart, Delay]], + ]: + events: List[DelayedPartialEventWithUser] = [] + removed_timeout_delay_ids: Set[DelayID] = set() + + txn.execute( + """ + WITH delay_send_times AS ( + SELECT *, running_since + delay AS send_ts + FROM delayed_events + JOIN delayed_event_timeouts USING (delay_rowid) + ) + SELECT delay_rowid, user_localpart + FROM delay_send_times + WHERE send_ts < ? + ORDER BY send_ts + """, + (current_ts,), + ) + for row in txn.fetchall(): + try: + event, removed_timeout_delay_id = self._pop_event_txn( + txn, + keyvalues={"delay_rowid": row[0]}, + ) + except NotFoundError: + pass + events.append((UserLocalpart(row[1]), *event)) + removed_timeout_delay_ids |= removed_timeout_delay_id + + txn.execute( + """ + SELECT + delay_id, + user_localpart, + running_since + delay - ? AS relative_delay + FROM delayed_events + JOIN delayed_event_timeouts USING (delay_rowid) + """, + (current_ts,), + ) + remaining_timeout_delays = [ + ( + DelayID(row[0]), + UserLocalpart(row[1]), + Delay(row[2]), + ) + for row in txn + ] + return events, remaining_timeout_delays + + return await self.db_pool.runInteraction( + "process_all_delays", process_all_delays_txn + ) + async def pop_event( self, delay_id: DelayID, From f9261b9ac1c8a69025cb264781e3b6b8c66bbca1 Mon Sep 17 00:00:00 2001 From: Andrew Ferrazzutti Date: Tue, 6 Aug 2024 22:52:59 -0400 Subject: [PATCH 07/79] Cancel delayed state events on state change --- synapse/handlers/delayed_events.py | 36 ++++++++++++++++++- .../storage/databases/main/delayed_events.py | 24 +++++++++++++ 2 files changed, 59 insertions(+), 1 deletion(-) diff --git a/synapse/handlers/delayed_events.py b/synapse/handlers/delayed_events.py index b66ce0bfad3..9c448b6e471 100644 --- a/synapse/handlers/delayed_events.py +++ b/synapse/handlers/delayed_events.py @@ -27,6 +27,7 @@ from synapse.api.constants import EventTypes from synapse.api.errors import Codes, ShadowBanError, SynapseError +from synapse.events import EventBase from synapse.logging.opentracing import set_tag from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage.databases.main.delayed_events import ( @@ -37,7 +38,14 @@ Timestamp, UserLocalpart, ) -from synapse.types import JsonDict, Requester, RoomID, UserID, create_requester +from synapse.types import ( + JsonDict, + Requester, + RoomID, + StateMap, + UserID, + create_requester, +) from synapse.util.stringutils import random_string if TYPE_CHECKING: @@ -86,10 +94,36 @@ async def _schedule_db_events() -> None: for delay_id, user_localpart, relative_delay in remaining_timeout_delays: self._schedule(delay_id, user_localpart, relative_delay) + hs.get_module_api().register_third_party_rules_callbacks( + on_new_event=self.on_new_event, + ) + self._initialized_from_db = run_as_background_process( "_schedule_db_events", _schedule_db_events ) + async def on_new_event( + self, event: EventBase, _state_events: StateMap[EventBase] + ) -> None: + """ + Checks if a received event is a state event, and if so, + cancels any delayed events that target the same state. 
+ """ + state_key = event.get_state_key() + if state_key is not None: + for ( + removed_timeout_delay_id, + removed_timeout_delay_user_localpart, + ) in await self.store.remove_state_events( + event.room_id, + event.type, + state_key, + ): + self._unschedule( + removed_timeout_delay_id, + removed_timeout_delay_user_localpart, + ) + async def add( self, requester: Requester, diff --git a/synapse/storage/databases/main/delayed_events.py b/synapse/storage/databases/main/delayed_events.py index b7e8d3f4cc8..9790caabd8a 100644 --- a/synapse/storage/databases/main/delayed_events.py +++ b/synapse/storage/databases/main/delayed_events.py @@ -494,6 +494,30 @@ async def remove( ) return {DelayID(r[0]) for r in removed_timeout_delay_ids} + async def remove_state_events( + self, + room_id: str, + event_type: str, + state_key: str, + ) -> List[Tuple[DelayID, UserLocalpart]]: + """ + Removes all matching delayed state events from the DB, as well as their children. + + Returns: + The ID & owner of every removed delayed event with a timeout that must be unscheduled. + """ + return await self.db_pool.runInteraction( + "remove_state_events", + self._remove_txn, + keyvalues={ + "room_id": room_id, + "event_type": event_type, + "state_key": state_key, + }, + retcols=("delay_id", "user_localpart"), + allow_none=True, + ) + def _remove_txn( self, txn: LoggingTransaction, From d3ea9683fb5a50f298c9596bc09d8ecf029dccd7 Mon Sep 17 00:00:00 2001 From: Andrew Ferrazzutti Date: Tue, 6 Aug 2024 16:35:24 -0400 Subject: [PATCH 08/79] Prevent race conditions in delayed event updates --- synapse/handlers/delayed_events.py | 121 +++++++++++++++++++---------- 1 file changed, 80 insertions(+), 41 deletions(-) diff --git a/synapse/handlers/delayed_events.py b/synapse/handlers/delayed_events.py index 9c448b6e471..04974a95d37 100644 --- a/synapse/handlers/delayed_events.py +++ b/synapse/handlers/delayed_events.py @@ -17,16 +17,24 @@ # import logging +from contextlib import asynccontextmanager from enum import Enum from http import HTTPStatus -from typing import TYPE_CHECKING, Dict, List, Optional +from typing import ( + TYPE_CHECKING, + AsyncContextManager, + AsyncIterator, + Dict, + List, + Optional, +) import attr from twisted.internet.interfaces import IDelayedCall from synapse.api.constants import EventTypes -from synapse.api.errors import Codes, ShadowBanError, SynapseError +from synapse.api.errors import Codes, NotFoundError, ShadowBanError, SynapseError from synapse.events import EventBase from synapse.logging.opentracing import set_tag from synapse.metrics.background_process_metrics import run_as_background_process @@ -46,6 +54,7 @@ UserID, create_requester, ) +from synapse.util.async_helpers import Linearizer, ReadWriteLock from synapse.util.stringutils import random_string if TYPE_CHECKING: @@ -53,6 +62,8 @@ logger = logging.getLogger(__name__) +_STATE_LOCK_KEY = "STATE_LOCK_KEY" + class _UpdateDelayedEventAction(Enum): CANCEL = "cancel" @@ -79,6 +90,10 @@ def __init__(self, hs: "HomeServer"): self.room_member_handler = hs.get_room_member_handler() self._delayed_calls: Dict[_DelayedCallKey, IDelayedCall] = {} + # This is for making delayed event actions atomic + self._linearizer = Linearizer("delayed_events_handler") + # This is to prevent running actions on delayed events removed due to state changes + self._state_lock = ReadWriteLock() async def _schedule_db_events() -> None: # TODO: Sync all state first, so that affected delayed state events will be cancelled @@ -111,18 +126,19 @@ async def on_new_event( """ 
state_key = event.get_state_key() if state_key is not None: - for ( - removed_timeout_delay_id, - removed_timeout_delay_user_localpart, - ) in await self.store.remove_state_events( - event.room_id, - event.type, - state_key, - ): - self._unschedule( + async with self._get_state_context(): + for ( removed_timeout_delay_id, removed_timeout_delay_user_localpart, - ) + ) in await self.store.remove_state_events( + event.room_id, + event.type, + state_key, + ): + self._unschedule( + removed_timeout_delay_id, + removed_timeout_delay_user_localpart, + ) async def add( self, @@ -219,46 +235,55 @@ async def update(self, requester: Requester, delay_id: str, action: str) -> None user_localpart = UserLocalpart(requester.user.localpart) await self.request_ratelimiter.ratelimit(requester) - await self._initialized_from_db - if enum_action == _UpdateDelayedEventAction.CANCEL: - for removed_timeout_delay_id in await self.store.remove( - delay_id, user_localpart - ): - self._unschedule(removed_timeout_delay_id, user_localpart) + async with self._get_delay_context(delay_id, user_localpart): + if enum_action == _UpdateDelayedEventAction.CANCEL: + for removed_timeout_delay_id in await self.store.remove( + delay_id, user_localpart + ): + self._unschedule(removed_timeout_delay_id, user_localpart) + + elif enum_action == _UpdateDelayedEventAction.RESTART: + delay = await self.store.restart( + delay_id, + user_localpart, + self._get_current_ts(), + ) - elif enum_action == _UpdateDelayedEventAction.RESTART: - delay = await self.store.restart( - delay_id, - user_localpart, - self._get_current_ts(), - ) + self._unschedule(delay_id, user_localpart) + self._schedule(delay_id, user_localpart, delay) - self._unschedule(delay_id, user_localpart) - self._schedule(delay_id, user_localpart, delay) + elif enum_action == _UpdateDelayedEventAction.SEND: + args, removed_timeout_delay_ids = await self.store.pop_event( + delay_id, user_localpart + ) - elif enum_action == _UpdateDelayedEventAction.SEND: - args, removed_timeout_delay_ids = await self.store.pop_event( - delay_id, user_localpart - ) - - for timeout_delay_id in removed_timeout_delay_ids: - self._unschedule(timeout_delay_id, user_localpart) - await self._send_event(user_localpart, *args) + for timeout_delay_id in removed_timeout_delay_ids: + self._unschedule(timeout_delay_id, user_localpart) + await self._send_event(user_localpart, *args) async def _send_on_timeout( self, delay_id: DelayID, user_localpart: UserLocalpart ) -> None: del self._delayed_calls[_DelayedCallKey(delay_id, user_localpart)] - args, removed_timeout_delay_ids = await self.store.pop_event( - delay_id, user_localpart - ) + async with self._get_delay_context(delay_id, user_localpart): + try: + args, removed_timeout_delay_ids = await self.store.pop_event( + delay_id, user_localpart + ) + except NotFoundError: + logger.debug( + "delay_id %s for local user %s was removed after it timed out, but before it was sent on timeout", + delay_id, + user_localpart, + ) + return - removed_timeout_delay_ids.remove(delay_id) - for timeout_delay_id in removed_timeout_delay_ids: - self._unschedule(timeout_delay_id, user_localpart) - await self._send_event(user_localpart, *args) + removed_timeout_delay_ids.remove(delay_id) + for timeout_delay_id in removed_timeout_delay_ids: + self._unschedule(timeout_delay_id, user_localpart) + await self._send_event(user_localpart, *args) def _schedule( self, @@ -359,3 +384,17 @@ async def _send_event( def _get_current_ts(self) -> Timestamp: return Timestamp(self.clock.time_msec()) + + 
@asynccontextmanager + async def _get_delay_context( + self, delay_id: DelayID, user_localpart: UserLocalpart + ) -> AsyncIterator[None]: + await self._initialized_from_db + # TODO: Use parenthesized context manager once the minimum supported Python version is 3.10 + async with self._state_lock.read(_STATE_LOCK_KEY), self._linearizer.queue( + _DelayedCallKey(delay_id, user_localpart) + ): + yield + + def _get_state_context(self) -> AsyncContextManager: + return self._state_lock.write(_STATE_LOCK_KEY) From c34221f5c285ee58558f7cdc34e011733a64f203 Mon Sep 17 00:00:00 2001 From: Andrew Ferrazzutti Date: Fri, 2 Aug 2024 09:33:35 -0400 Subject: [PATCH 09/79] Check startup delayed state events for same state If on startup there are multiple delayed state events to be sent, do not send multiple events that target the same state key for a room. --- synapse/handlers/delayed_events.py | 30 ++++++++++++++++++++++++++++-- 1 file changed, 28 insertions(+), 2 deletions(-) diff --git a/synapse/handlers/delayed_events.py b/synapse/handlers/delayed_events.py index 04974a95d37..5a4481fb300 100644 --- a/synapse/handlers/delayed_events.py +++ b/synapse/handlers/delayed_events.py @@ -27,6 +27,8 @@ Dict, List, Optional, + Set, + Tuple, ) import attr @@ -100,11 +102,35 @@ async def _schedule_db_events() -> None: events, remaining_timeout_delays = await self.store.process_all_delays( self._get_current_ts() ) - for args in events: + sent_state: Set[Tuple[RoomID, EventType, StateKey]] = set() + for ( + user_localpart, + room_id, + event_type, + state_key, + timestamp, + content, + ) in events: + if state_key is not None: + state_info = (room_id, event_type, state_key) + if state_info in sent_state: + continue + else: + state_info = None try: - await self._send_event(*args) + await self._send_event( + user_localpart, + room_id, + event_type, + state_key, + timestamp, + content, + ) + if state_info is not None: + sent_state.add(state_info) except Exception: logger.exception("Failed to send delayed event on startup") + sent_state.clear() for delay_id, user_localpart, relative_delay in remaining_timeout_delays: self._schedule(delay_id, user_localpart, relative_delay) From f2d81446654e593b47251eab5dc129ca55cbbcdd Mon Sep 17 00:00:00 2001 From: Andrew Ferrazzutti Date: Tue, 6 Aug 2024 13:50:38 -0400 Subject: [PATCH 10/79] Advertise as unstable feature --- synapse/rest/client/versions.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py index 75df6844166..3096d04f051 100644 --- a/synapse/rest/client/versions.py +++ b/synapse/rest/client/versions.py @@ -167,6 +167,8 @@ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: is not None ) ), + # MSC4140: Delayed events + "org.matrix.msc4140": True, # MSC4151: Report room API (Client-Server API) "org.matrix.msc4151": self.config.experimental.msc4151_enabled, }, From d2c9ca7f994a5b3451a39482b1e5ae9672ad392e Mon Sep 17 00:00:00 2001 From: Andrew Ferrazzutti Date: Tue, 6 Aug 2024 13:20:27 -0400 Subject: [PATCH 11/79] Update copyright years --- synapse/config/experimental.py | 2 +- synapse/rest/__init__.py | 2 +- synapse/rest/client/room.py | 2 +- synapse/rest/client/versions.py | 2 +- synapse/server.py | 2 +- synapse/storage/databases/main/__init__.py | 2 +- synapse/storage/schema/__init__.py | 2 +- 7 files changed, 7 insertions(+), 7 deletions(-) diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index cb22932be6c..f9b38a5c04e 100644 --- 
a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -2,7 +2,7 @@ # This file is licensed under the Affero General Public License (AGPL) version 3. # # Copyright 2021 The Matrix.org Foundation C.I.C. -# Copyright (C) 2023 New Vector, Ltd +# Copyright (C) 2023-2024 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as diff --git a/synapse/rest/__init__.py b/synapse/rest/__init__.py index d6638d133ca..ca7e3c75ed0 100644 --- a/synapse/rest/__init__.py +++ b/synapse/rest/__init__.py @@ -2,7 +2,7 @@ # This file is licensed under the Affero General Public License (AGPL) version 3. # # Copyright 2014-2016 OpenMarket Ltd -# Copyright (C) 2023 New Vector, Ltd +# Copyright (C) 2023-2024 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py index 8a20fc26314..e5b1f360f11 100644 --- a/synapse/rest/client/room.py +++ b/synapse/rest/client/room.py @@ -2,7 +2,7 @@ # This file is licensed under the Affero General Public License (AGPL) version 3. # # Copyright 2014-2016 OpenMarket Ltd -# Copyright (C) 2023 New Vector, Ltd +# Copyright (C) 2023-2024 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py index 3096d04f051..ff3509f2056 100644 --- a/synapse/rest/client/versions.py +++ b/synapse/rest/client/versions.py @@ -4,7 +4,7 @@ # Copyright 2019 The Matrix.org Foundation C.I.C. # Copyright 2017 Vector Creations Ltd # Copyright 2016 OpenMarket Ltd -# Copyright (C) 2023 New Vector, Ltd +# Copyright (C) 2023-2024 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as diff --git a/synapse/server.py b/synapse/server.py index 18bed8fe54a..9f363032fec 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -2,7 +2,7 @@ # This file is licensed under the Affero General Public License (AGPL) version 3. # # Copyright 2021 The Matrix.org Foundation C.I.C. -# Copyright (C) 2023 New Vector, Ltd +# Copyright (C) 2023-2024 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py index 490d0c14b67..04b6684bc95 100644 --- a/synapse/storage/databases/main/__init__.py +++ b/synapse/storage/databases/main/__init__.py @@ -3,7 +3,7 @@ # # Copyright 2019-2021 The Matrix.org Foundation C.I.C. # Copyright 2014-2016 OpenMarket Ltd -# Copyright (C) 2023 New Vector, Ltd +# Copyright (C) 2023-2024 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as diff --git a/synapse/storage/schema/__init__.py b/synapse/storage/schema/__init__.py index 7f70b014dc9..8715dec5c60 100644 --- a/synapse/storage/schema/__init__.py +++ b/synapse/storage/schema/__init__.py @@ -2,7 +2,7 @@ # This file is licensed under the Affero General Public License (AGPL) version 3. # # Copyright 2021 The Matrix.org Foundation C.I.C. 
-# Copyright (C) 2023 New Vector, Ltd +# Copyright (C) 2023-2024 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as From 56c6d87767c6ff831e69f7fe45f7590482c04f03 Mon Sep 17 00:00:00 2001 From: Andrew Ferrazzutti Date: Tue, 6 Aug 2024 13:20:40 -0400 Subject: [PATCH 12/79] Add changelog --- changelog.d/17326.feature | 2 ++ 1 file changed, 2 insertions(+) create mode 100644 changelog.d/17326.feature diff --git a/changelog.d/17326.feature b/changelog.d/17326.feature new file mode 100644 index 00000000000..348c54b0404 --- /dev/null +++ b/changelog.d/17326.feature @@ -0,0 +1,2 @@ +Add initial implementation of delayed events as proposed by [MSC4140](https://github.com/matrix-org/matrix-spec-proposals/pull/4140). + From c24c41b6abcf152fa795f8c93fe4eb5629b837e3 Mon Sep 17 00:00:00 2001 From: Andrew Ferrazzutti Date: Wed, 7 Aug 2024 09:34:30 -0400 Subject: [PATCH 13/79] Don't throw when event callback finds no event --- .../callbacks/third_party_event_rules_callbacks.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/synapse/module_api/callbacks/third_party_event_rules_callbacks.py b/synapse/module_api/callbacks/third_party_event_rules_callbacks.py index 9f7a04372de..f4cf8596a8a 100644 --- a/synapse/module_api/callbacks/third_party_event_rules_callbacks.py +++ b/synapse/module_api/callbacks/third_party_event_rules_callbacks.py @@ -426,7 +426,10 @@ async def on_new_event(self, event_id: str) -> None: if len(self._on_new_event_callbacks) == 0: return - event = await self.store.get_event(event_id) + event = await self.store.get_event(event_id, allow_none=True) + if not event: + logger.warning("Could not find event %s" % (event_id,)) + return # We *don't* want to wait for the full state here, because waiting for full # state will persist event, which in turn will call this method. 
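
An aside on the `_get_delay_context` helper added earlier in this series: its TODO notes that the parenthesized multi-manager form of `async with` requires Python 3.10, but a single unparenthesized `async with a, b:` statement works on older versions too, entering the managers in order and exiting them in reverse. Below is a minimal self-contained sketch of the same shape; plain `asyncio.Lock`s stand in for Synapse's `ReadWriteLock` and `Linearizer`, and every name and key in it is illustrative rather than a Synapse API.

    import asyncio
    from contextlib import asynccontextmanager
    from typing import AsyncIterator, Dict

    state_lock = asyncio.Lock()  # stand-in for the shared state lock
    per_key_locks: Dict[str, asyncio.Lock] = {}  # stand-in for the linearizer

    @asynccontextmanager
    async def delay_context(key: str) -> AsyncIterator[None]:
        per_key = per_key_locks.setdefault(key, asyncio.Lock())
        # One unparenthesized `async with` enters both managers in order
        # and releases them in reverse, with no Python 3.10 syntax needed.
        async with state_lock, per_key:
            yield

    async def main() -> None:
        async with delay_context("syd_example:alice"):
            pass  # a delayed-event action would run here, serialized per key

    asyncio.run(main())

Per-key locking gives the same guarantee as the linearizer in the patches above: two actions on the same delayed event cannot interleave, while actions on different events still run concurrently. (A production version would also have to evict unused per-key locks, something Synapse's `Linearizer` takes care of internally.)
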
From 6e382dfe11bae3b35e26a425150b70378b384896 Mon Sep 17 00:00:00 2001 From: Andrew Ferrazzutti Date: Wed, 7 Aug 2024 09:41:27 -0400 Subject: [PATCH 14/79] Increase expected db_txn_counts Include counts for the delayed event handler's state event callback --- tests/rest/client/test_rooms.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py index c559dfda834..da243a09d61 100644 --- a/tests/rest/client/test_rooms.py +++ b/tests/rest/client/test_rooms.py @@ -742,7 +742,7 @@ def test_post_room_no_keys(self) -> None: self.assertEqual(HTTPStatus.OK, channel.code, channel.result) self.assertTrue("room_id" in channel.json_body) assert channel.resource_usage is not None - self.assertEqual(33, channel.resource_usage.db_txn_count) + self.assertEqual(38, channel.resource_usage.db_txn_count) def test_post_room_initial_state(self) -> None: # POST with initial_state config key, expect new room id @@ -755,7 +755,7 @@ def test_post_room_initial_state(self) -> None: self.assertEqual(HTTPStatus.OK, channel.code, channel.result) self.assertTrue("room_id" in channel.json_body) assert channel.resource_usage is not None - self.assertEqual(35, channel.resource_usage.db_txn_count) + self.assertEqual(40, channel.resource_usage.db_txn_count) def test_post_room_visibility_key(self) -> None: # POST with visibility config key, expect new room id From 32cbacf4c8825d9090105de4b0ce90ab2cdc544e Mon Sep 17 00:00:00 2001 From: Andrew Ferrazzutti Date: Fri, 9 Aug 2024 15:21:25 -0400 Subject: [PATCH 15/79] Validate a delayed event before scheduling it --- synapse/handlers/delayed_events.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/synapse/handlers/delayed_events.py b/synapse/handlers/delayed_events.py index 5a4481fb300..94ded5c32e6 100644 --- a/synapse/handlers/delayed_events.py +++ b/synapse/handlers/delayed_events.py @@ -215,7 +215,18 @@ async def add( await self.request_ratelimiter.ratelimit(requester) await self._initialized_from_db - # TODO: Validate that the event is valid before scheduling it + self.event_creation_handler.validator.validate_builder( + self.event_creation_handler.event_builder_factory.for_room_version( + await self.store.get_room_version(room_id), + { + "type": event_type, + "content": content, + "room_id": room_id, + "sender": str(requester.user), + **({"state_key": state_key} if state_key is not None else {}), + }, + ) + ) user_localpart = UserLocalpart(requester.user.localpart) delay_id = await self.store.add( From 71e8997e32d4710f897d60e1920b12e31b9ec1fa Mon Sep 17 00:00:00 2001 From: Andrew Ferrazzutti Date: Fri, 9 Aug 2024 16:08:23 -0400 Subject: [PATCH 16/79] Start adding unit tests for delayed events --- tests/rest/client/test_rooms.py | 45 +++++++++++++++++++++++++++++++++ 1 file changed, 45 insertions(+) diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py index da243a09d61..d132cc7c540 100644 --- a/tests/rest/client/test_rooms.py +++ b/tests/rest/client/test_rooms.py @@ -2291,6 +2291,51 @@ def test_room_message_filter_wildcard(self) -> None: self.assertEqual(len(chunk), 2, [event["content"] for event in chunk]) +class RoomDelayedEventTestCase(RoomBase): + """Tests delayed events.""" + + user_id = "@sid1:red" + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.room_id = self.helper.create_room_as(self.user_id) + + def test_send_delayed_invalid_event(self) -> None: + """Test sending a delayed event with 
invalid content.""" + channel = self.make_request( + "PUT", + ( + "rooms/%s/send/m.room.message/mid1?org.matrix.msc4140.delay=2000" + % self.room_id + ).encode("ascii"), + {}, + ) + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, channel.result) + + def test_send_delayed_message_event(self) -> None: + """Test sending a delayed message event.""" + channel = self.make_request( + "PUT", + ( + "rooms/%s/send/m.room.message/mid1?org.matrix.msc4140.delay=2000" + % self.room_id + ).encode("ascii"), + {"body": "test", "msgtype": "m.text"}, + ) + self.assertEqual(HTTPStatus.OK, channel.code, channel.result) + + def test_send_delayed_state_event(self) -> None: + """Test sending a delayed state event.""" + channel = self.make_request( + "PUT", + ( + "rooms/%s/state/m.room.topic/?org.matrix.msc4140.delay=2000" + % self.room_id + ).encode("ascii"), + {"topic": "This is a topic"}, + ) + self.assertEqual(HTTPStatus.OK, channel.code, channel.result) + + class RoomSearchTestCase(unittest.HomeserverTestCase): servlets = [ synapse.rest.admin.register_servlets_for_client_rest_resource, From fafaa03214bf735a347626a907bb226f7971b870 Mon Sep 17 00:00:00 2001 From: Andrew Ferrazzutti Date: Thu, 15 Aug 2024 08:36:19 -0400 Subject: [PATCH 17/79] Add comments to explain rowid / identity columns --- .../storage/schema/main/delta/87/01_add_delayed_events.sql | 7 ++++++- .../main/delta/87/01_add_delayed_events.sql.postgres | 2 ++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/synapse/storage/schema/main/delta/87/01_add_delayed_events.sql b/synapse/storage/schema/main/delta/87/01_add_delayed_events.sql index cb1bafcb681..7ae4d189ca7 100644 --- a/synapse/storage/schema/main/delta/87/01_add_delayed_events.sql +++ b/synapse/storage/schema/main/delta/87/01_add_delayed_events.sql @@ -12,7 +12,12 @@ -- . CREATE TABLE delayed_events ( - delay_rowid INTEGER PRIMARY KEY, -- An alias of rowid in SQLite + -- An alias of rowid in SQLite. + -- Newly-inserted rows that don't assign a (non-NULL) value for this column + -- will have it set to a table-unique value. + -- For Postgres to do this, the column must be set as an identity column. + delay_rowid INTEGER PRIMARY KEY, + delay_id TEXT NOT NULL, user_localpart TEXT NOT NULL, running_since BIGINT NOT NULL, diff --git a/synapse/storage/schema/main/delta/87/01_add_delayed_events.sql.postgres b/synapse/storage/schema/main/delta/87/01_add_delayed_events.sql.postgres index 4e4b16d0a28..3771db43781 100644 --- a/synapse/storage/schema/main/delta/87/01_add_delayed_events.sql.postgres +++ b/synapse/storage/schema/main/delta/87/01_add_delayed_events.sql.postgres @@ -11,4 +11,6 @@ -- See the GNU Affero General Public License for more details: -- . +-- Sets the column as an identity column, meaning that the column in new rows +-- will automatically have values from an implicit sequence assigned to it.
ALTER TABLE delayed_events ALTER COLUMN delay_rowid ADD GENERATED ALWAYS AS IDENTITY; From 3afa3cf8f2d458929d9a93945fbca5fa7124f612 Mon Sep 17 00:00:00 2001 From: Andrew Ferrazzutti Date: Thu, 15 Aug 2024 08:37:28 -0400 Subject: [PATCH 18/79] Prefix internally-used attributes with underscore --- synapse/handlers/delayed_events.py | 56 +++++++++++++++--------------- 1 file changed, 28 insertions(+), 28 deletions(-) diff --git a/synapse/handlers/delayed_events.py b/synapse/handlers/delayed_events.py index 94ded5c32e6..d0810db4748 100644 --- a/synapse/handlers/delayed_events.py +++ b/synapse/handlers/delayed_events.py @@ -84,12 +84,12 @@ def __str__(self) -> str: class DelayedEventsHandler: def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastores().main - self.config = hs.config - self.clock = hs.get_clock() - self.request_ratelimiter = hs.get_request_ratelimiter() - self.event_creation_handler = hs.get_event_creation_handler() - self.room_member_handler = hs.get_room_member_handler() + self._store = hs.get_datastores().main + self._config = hs.config + self._clock = hs.get_clock() + self._request_ratelimiter = hs.get_request_ratelimiter() + self._event_creation_handler = hs.get_event_creation_handler() + self._room_member_handler = hs.get_room_member_handler() self._delayed_calls: Dict[_DelayedCallKey, IDelayedCall] = {} # This is for making delayed event actions atomic @@ -99,7 +99,7 @@ def __init__(self, hs: "HomeServer"): async def _schedule_db_events() -> None: # TODO: Sync all state first, so that affected delayed state events will be cancelled - events, remaining_timeout_delays = await self.store.process_all_delays( + events, remaining_timeout_delays = await self._store.process_all_delays( self._get_current_ts() ) sent_state: Set[Tuple[RoomID, EventType, StateKey]] = set() @@ -156,7 +156,7 @@ async def on_new_event( for ( removed_timeout_delay_id, removed_timeout_delay_user_localpart, - ) in await self.store.remove_state_events( + ) in await self._store.remove_state_events( event.room_id, event.type, state_key, @@ -198,7 +198,7 @@ async def add( The ID of the added delayed event. 
""" if delay is not None: - max_delay = self.config.experimental.msc4140_max_delay + max_delay = self._config.experimental.msc4140_max_delay if delay > max_delay: raise SynapseError( HTTPStatus.BAD_REQUEST, @@ -212,12 +212,12 @@ async def add( # Callers should ensure that at least one of these are set assert delay or parent_id - await self.request_ratelimiter.ratelimit(requester) + await self._request_ratelimiter.ratelimit(requester) await self._initialized_from_db - self.event_creation_handler.validator.validate_builder( - self.event_creation_handler.event_builder_factory.for_room_version( - await self.store.get_room_version(room_id), + self._event_creation_handler.validator.validate_builder( + self._event_creation_handler.event_builder_factory.for_room_version( + await self._store.get_room_version(room_id), { "type": event_type, "content": content, @@ -229,7 +229,7 @@ async def add( ) user_localpart = UserLocalpart(requester.user.localpart) - delay_id = await self.store.add( + delay_id = await self._store.add( user_localpart=user_localpart, current_ts=self._get_current_ts(), room_id=RoomID.from_string(room_id), @@ -271,17 +271,17 @@ async def update(self, requester: Requester, delay_id: str, action: str) -> None delay_id = DelayID(delay_id) user_localpart = UserLocalpart(requester.user.localpart) - await self.request_ratelimiter.ratelimit(requester) + await self._request_ratelimiter.ratelimit(requester) async with self._get_delay_context(delay_id, user_localpart): if enum_action == _UpdateDelayedEventAction.CANCEL: - for removed_timeout_delay_id in await self.store.remove( + for removed_timeout_delay_id in await self._store.remove( delay_id, user_localpart ): self._unschedule(removed_timeout_delay_id, user_localpart) elif enum_action == _UpdateDelayedEventAction.RESTART: - delay = await self.store.restart( + delay = await self._store.restart( delay_id, user_localpart, self._get_current_ts(), @@ -291,7 +291,7 @@ async def update(self, requester: Requester, delay_id: str, action: str) -> None self._schedule(delay_id, user_localpart, delay) elif enum_action == _UpdateDelayedEventAction.SEND: - args, removed_timeout_delay_ids = await self.store.pop_event( + args, removed_timeout_delay_ids = await self._store.pop_event( delay_id, user_localpart ) @@ -306,7 +306,7 @@ async def _send_on_timeout( async with self._get_delay_context(delay_id, user_localpart): try: - args, removed_timeout_delay_ids = await self.store.pop_event( + args, removed_timeout_delay_ids = await self._store.pop_event( delay_id, user_localpart ) except NotFoundError: @@ -339,7 +339,7 @@ def _schedule( ) self._delayed_calls[_DelayedCallKey(delay_id, user_localpart)] = ( - self.clock.call_later( + self._clock.call_later( delay_sec, run_as_background_process, "_send_on_timeout", @@ -353,13 +353,13 @@ def _unschedule(self, delay_id: DelayID, user_localpart: UserLocalpart) -> None: delayed_call = self._delayed_calls.pop( _DelayedCallKey(delay_id, user_localpart) ) - self.clock.cancel_call_later(delayed_call) + self._clock.cancel_call_later(delayed_call) async def get_all_for_user(self, requester: Requester) -> List[JsonDict]: """Return all pending delayed events requested by the given user.""" - await self.request_ratelimiter.ratelimit(requester) + await self._request_ratelimiter.ratelimit(requester) await self._initialized_from_db - return await self.store.get_all_for_user( + return await self._store.get_all_for_user( UserLocalpart(requester.user.localpart) ) @@ -373,17 +373,17 @@ async def _send_event( content: JsonDict, txn_id: 
Optional[str] = None, ) -> None: - user_id = UserID(user_localpart, self.config.server.server_name) + user_id = UserID(user_localpart, self._config.server.server_name) user_id_str = user_id.to_string() requester = create_requester( user_id, - is_guest=await self.store.is_guest(user_id_str), + is_guest=await self._store.is_guest(user_id_str), ) try: if state_key is not None and event_type == EventTypes.Member: membership = content.get("membership", None) - event_id, _ = await self.room_member_handler.update_membership( + event_id, _ = await self._room_member_handler.update_membership( requester, target=UserID.from_string(state_key), room_id=room_id.to_string(), @@ -408,7 +408,7 @@ async def _send_event( ( event, _, - ) = await self.event_creation_handler.create_and_send_nonmember_event( + ) = await self._event_creation_handler.create_and_send_nonmember_event( requester, event_dict, txn_id=txn_id, @@ -420,7 +420,7 @@ async def _send_event( set_tag("event_id", event_id) def _get_current_ts(self) -> Timestamp: - return Timestamp(self.clock.time_msec()) + return Timestamp(self._clock.time_msec()) @asynccontextmanager async def _get_delay_context( From b3d4d6cb41d86a32a15ec6f842bd76bbb0b6d843 Mon Sep 17 00:00:00 2001 From: Andrew Ferrazzutti Date: Thu, 15 Aug 2024 08:38:00 -0400 Subject: [PATCH 19/79] Use "Args" instead of "Params" in docstrings to follow Google's Python style guide https://google.github.io/styleguide/pyguide.html#doc-function-args --- synapse/handlers/delayed_events.py | 4 ++-- synapse/storage/databases/main/delayed_events.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/synapse/handlers/delayed_events.py b/synapse/handlers/delayed_events.py index d0810db4748..e5d19b92b35 100644 --- a/synapse/handlers/delayed_events.py +++ b/synapse/handlers/delayed_events.py @@ -181,7 +181,7 @@ async def add( """ Creates a new delayed event. - Params: + Args: requester: The requester of the delayed event, who will be its owner. room_id: The room where the event should be sent. event_type: The type of event to be sent. @@ -250,7 +250,7 @@ async def update(self, requester: Requester, delay_id: str, action: str) -> None """ Executes the appropriate action for the matching delayed event. - Params: + Args: delay_id: The ID of the delayed event to act on. action: What to do with the delayed event. diff --git a/synapse/storage/databases/main/delayed_events.py b/synapse/storage/databases/main/delayed_events.py index 9790caabd8a..c6400cacfc7 100644 --- a/synapse/storage/databases/main/delayed_events.py +++ b/synapse/storage/databases/main/delayed_events.py @@ -225,7 +225,7 @@ async def restart( """ Resets the matching delayed event, as long as it has a timeout. - Params: + Args: delay_id: The ID of the delayed event to restart. user_localpart: The localpart of the delayed event's owner. current_ts: The current time, to which the delayed event's "running_since" will be set to. From 48d81262b5392f470db7e56730d2a7f7390d07fb Mon Sep 17 00:00:00 2001 From: Andrew Ferrazzutti Date: Thu, 15 Aug 2024 08:39:30 -0400 Subject: [PATCH 20/79] Add comment to explain parent delayed events --- synapse/handlers/delayed_events.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/synapse/handlers/delayed_events.py b/synapse/handlers/delayed_events.py index e5d19b92b35..a4a77eab6ef 100644 --- a/synapse/handlers/delayed_events.py +++ b/synapse/handlers/delayed_events.py @@ -193,6 +193,8 @@ async def add( If None, the event won't be automatically sent (allowed only when parent_id is set).
parent_id: The ID of the delayed event this one is grouped with. May only refer to a delayed event that has no parent itself. + When the parent event is sent or cancelled, this one is cancelled; + and when this event is sent, the parent is cancelled. Returns: The ID of the added delayed event. From bee52bd95fd0fd707c801b71a3e1444f787e44f5 Mon Sep 17 00:00:00 2001 From: Andrew Ferrazzutti Date: Thu, 15 Aug 2024 09:37:06 -0400 Subject: [PATCH 21/79] Use utility function for generating a fake ID --- synapse/handlers/delayed_events.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/synapse/handlers/delayed_events.py b/synapse/handlers/delayed_events.py index a4a77eab6ef..1d3682888bc 100644 --- a/synapse/handlers/delayed_events.py +++ b/synapse/handlers/delayed_events.py @@ -57,7 +57,7 @@ create_requester, ) from synapse.util.async_helpers import Linearizer, ReadWriteLock -from synapse.util.stringutils import random_string +from synapse.util.events import generate_fake_event_id if TYPE_CHECKING: from synapse.server import HomeServer @@ -417,7 +417,7 @@ async def _send_event( ) event_id = event.event_id except ShadowBanError: - event_id = "$" + random_string(43) + event_id = generate_fake_event_id() set_tag("event_id", event_id) From 62527082e6b95be078848ed86fd98ad2d508e606 Mon Sep 17 00:00:00 2001 From: Andrew Ferrazzutti Date: Thu, 15 Aug 2024 10:54:48 -0400 Subject: [PATCH 22/79] Move DB error messages to debug log --- synapse/storage/databases/main/delayed_events.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/synapse/storage/databases/main/delayed_events.py b/synapse/storage/databases/main/delayed_events.py index c6400cacfc7..80bf8fc4cda 100644 --- a/synapse/storage/databases/main/delayed_events.py +++ b/synapse/storage/databases/main/delayed_events.py @@ -16,6 +16,7 @@ # # +import logging from binascii import crc32 from http import HTTPStatus from typing import TYPE_CHECKING, Any, Dict, List, NewType, Optional, Set, Tuple @@ -41,6 +42,8 @@ if TYPE_CHECKING: from synapse.server import HomeServer +logger = logging.getLogger(__name__) + DelayID = NewType("DelayID", str) UserLocalpart = NewType("UserLocalpart", str) @@ -134,11 +137,13 @@ def add_txn(txn: LoggingTransaction) -> DelayID: ) # TODO: Handle only the error for DB key collisions except Exception as e: + logger.debug( + "Error inserting into delayed_events", + str(e), + ) raise SynapseError( HTTPStatus.INTERNAL_SERVER_ERROR, f"Couldn't generate a unique delay_id for user_localpart {user_localpart}", - # TODO: Maybe remove this - additional_fields={"db_error": str(e)}, ) if not self._use_returning: @@ -190,13 +195,15 @@ def add_txn(txn: LoggingTransaction) -> DelayID: ) # TODO: Handle only the error for the relevant foreign key / check violation except Exception as e: + logger.debug( + "Error inserting into delayed_event_children", + str(e), + ) raise SynapseError( HTTPStatus.BAD_REQUEST, # TODO: Improve the wording for this "Invalid parent delayed event", Codes.INVALID_PARAM, - # TODO: Maybe remove this - additional_fields={"db_error": str(e)}, ) if txn.rowcount != 1: raise NotFoundError("Parent delayed event not found") From b1f74a81c38813f759182be32c50aceaaec7d6d7 Mon Sep 17 00:00:00 2001 From: Andrew Ferrazzutti Date: Thu, 15 Aug 2024 11:22:54 -0400 Subject: [PATCH 23/79] Remove TODO to verify delayed event contents because they are already verified when the event is scheduled --- synapse/storage/databases/main/delayed_events.py | 2 -- 1 file changed, 2 deletions(-) diff 
--git a/synapse/storage/databases/main/delayed_events.py b/synapse/storage/databases/main/delayed_events.py index 80bf8fc4cda..515e74cb5d8 100644 --- a/synapse/storage/databases/main/delayed_events.py +++ b/synapse/storage/databases/main/delayed_events.py @@ -315,7 +315,6 @@ async def get_all_for_user( **({"delay": Delay(row[4])} if row[4] is not None else {}), **({"parent_delay_id": DelayID(row[5])} if row[5] is not None else {}), "running_since": Timestamp(row[6]), - # TODO: Verify contents? "content": db_to_json(row[7]), } for row in rows @@ -469,7 +468,6 @@ def _pop_event_txn( EventType(event_row[1]), StateKey(event_row[2]) if event_row[2] is not None else None, Timestamp(event_row[3]) if event_row[3] is not None else None, - # TODO: Verify contents? contents, ), {DelayID(r[0]) for r in removed_timeout_delay_ids}, From 221e0af65f923c42137976cd83f8d875549375d3 Mon Sep 17 00:00:00 2001 From: Andrew Ferrazzutti Date: Thu, 15 Aug 2024 11:26:06 -0400 Subject: [PATCH 24/79] Don't bother using a CRC for delay_ids as all they need to be is a random opaque string --- synapse/storage/databases/main/delayed_events.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/synapse/storage/databases/main/delayed_events.py b/synapse/storage/databases/main/delayed_events.py index 515e74cb5d8..f4645fb9c46 100644 --- a/synapse/storage/databases/main/delayed_events.py +++ b/synapse/storage/databases/main/delayed_events.py @@ -17,7 +17,6 @@ # import logging -from binascii import crc32 from http import HTTPStatus from typing import TYPE_CHECKING, Any, Dict, List, NewType, Optional, Set, Tuple @@ -37,7 +36,6 @@ from synapse.storage.engines import PostgresEngine from synapse.types import JsonDict, RoomID, StrCollection from synapse.util import json_encoder, stringutils as stringutils -from synapse.util.stringutils import base62_encode if TYPE_CHECKING: from synapse.server import HomeServer @@ -583,12 +581,8 @@ def _generate_delay_id() -> DelayID: """Generates an opaque string, for use as a delay ID""" # We use the following format for delay IDs: - # syf__ + # syf_ # They are scoped to user localparts, so it is possible for # the same ID to exist for multiple users. - random_string = stringutils.random_string(20) - base = f"syd_{random_string}" - - crc = base62_encode(crc32(base.encode("ascii")), minwidth=6) - return DelayID(f"{base}_{crc}") + return DelayID(f"syd_{stringutils.random_string(20)}") From 08f54ca54c0735420b58cffaa5cd133c97bd0188 Mon Sep 17 00:00:00 2001 From: Andrew Ferrazzutti Date: Thu, 15 Aug 2024 11:56:35 -0400 Subject: [PATCH 25/79] Assert non-negative delay; allow missing delay ID A delayed event missing from the DB on timeout isn't destructive, so don't worry about asserting for its presence. 
Do post a debug message explaining its absence, though --- synapse/handlers/delayed_events.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/synapse/handlers/delayed_events.py b/synapse/handlers/delayed_events.py index 1d3682888bc..a8c8286e375 100644 --- a/synapse/handlers/delayed_events.py +++ b/synapse/handlers/delayed_events.py @@ -313,7 +313,7 @@ async def _send_on_timeout( ) except NotFoundError: logger.debug( - "delay_id %s for local user %s was removed after it timed out, but before it was sent on timeout", + "delay_id %s for local user %s was removed from the DB before it timed out (or was always missing)", delay_id, user_localpart, ) @@ -330,7 +330,7 @@ def _schedule( user_localpart: UserLocalpart, delay: Delay, ) -> None: - """NOTE: Should not be called with a delay_id that isn't in the DB, or with a negative delay.""" + assert delay > 0, "Clock.call_later doesn't support negative delays" delay_sec = delay / 1000 logger.info( From 56abbb95d74184cf2f1de0bb81fe671c31be95ae Mon Sep 17 00:00:00 2001 From: Andrew Ferrazzutti Date: Thu, 15 Aug 2024 13:29:33 -0400 Subject: [PATCH 26/79] Check for membership in delayed member events --- synapse/handlers/delayed_events.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/synapse/handlers/delayed_events.py b/synapse/handlers/delayed_events.py index a8c8286e375..226ad8f4248 100644 --- a/synapse/handlers/delayed_events.py +++ b/synapse/handlers/delayed_events.py @@ -384,7 +384,8 @@ async def _send_event( try: if state_key is not None and event_type == EventTypes.Member: - membership = content.get("membership", None) + membership = content.get("membership") + assert membership is not None event_id, _ = await self._room_member_handler.update_membership( requester, target=UserID.from_string(state_key), From c4e80adcbe11ef35f98904bc3953df6d0b45e7e2 Mon Sep 17 00:00:00 2001 From: Andrew Ferrazzutti Date: Thu, 15 Aug 2024 14:00:41 -0400 Subject: [PATCH 27/79] Put colons after exception types in docstrings --- synapse/handlers/delayed_events.py | 4 ++-- synapse/storage/databases/main/delayed_events.py | 12 ++++++------ 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/synapse/handlers/delayed_events.py b/synapse/handlers/delayed_events.py index 226ad8f4248..f626584a774 100644 --- a/synapse/handlers/delayed_events.py +++ b/synapse/handlers/delayed_events.py @@ -257,8 +257,8 @@ async def update(self, requester: Requester, delay_id: str, action: str) -> None action: What to do with the delayed event. Raises: - SynapseError if the provided action is unknown, or is unsupported for the target delayed event. - NotFoundError if no matching delayed event could be found. + SynapseError: if the provided action is unknown, or is unsupported for the target delayed event. + NotFoundError: if no matching delayed event could be found. """ try: enum_action = _UpdateDelayedEventAction(action) diff --git a/synapse/storage/databases/main/delayed_events.py b/synapse/storage/databases/main/delayed_events.py index f4645fb9c46..09ef6539e42 100644 --- a/synapse/storage/databases/main/delayed_events.py +++ b/synapse/storage/databases/main/delayed_events.py @@ -101,7 +101,7 @@ async def add( Returns: The generated ID assigned to the added delayed event. Raises: - SynapseError if the delayed event failed to be added. + SynapseError: if the delayed event failed to be added. 
""" def add_txn(txn: LoggingTransaction) -> DelayID: @@ -238,8 +238,8 @@ async def restart( Returns: The delay at which the delayed event will be sent (unless it is reset again). Raises: - NotFoundError if there is no matching delayed event. - SynapseError if the matching delayed event has no timeout. + NotFoundError: if there is no matching delayed event. + SynapseError: if the matching delayed event has no timeout. """ def restart_txn(txn: LoggingTransaction) -> Delay: @@ -403,7 +403,7 @@ async def pop_event( - The IDs of all removed delayed events with a timeout that must be unscheduled. Raises: - NotFoundError if there is no matching delayed event. + NotFoundError: if there is no matching delayed event. """ return await self.db_pool.runInteraction( "pop_event", @@ -483,7 +483,7 @@ async def remove( The IDs of all removed delayed events with a timeout that must be unscheduled. Raises: - NotFoundError if there is no matching delayed event. + NotFoundError: if there is no matching delayed event. """ removed_timeout_delay_ids = await self.db_pool.runInteraction( @@ -535,7 +535,7 @@ def _remove_txn( The specified columns for each delayed event with a timeout that was removed. Raises: - NotFoundError if allow_none is False and no delayed events match the keyvalues. + NotFoundError: if allow_none is False and no delayed events match the keyvalues. """ sql_with = f""" WITH target_rowids AS ( From 99e421cb7e21dee77367e61f020b3b3f3620f212 Mon Sep 17 00:00:00 2001 From: Andrew Ferrazzutti Date: Thu, 15 Aug 2024 14:54:13 -0400 Subject: [PATCH 28/79] Use built-in method to check for RETURNING support and remove some now-unneeded imports --- .../storage/databases/main/delayed_events.py | 26 +++---------------- 1 file changed, 4 insertions(+), 22 deletions(-) diff --git a/synapse/storage/databases/main/delayed_events.py b/synapse/storage/databases/main/delayed_events.py index 09ef6539e42..6651b5fabdd 100644 --- a/synapse/storage/databases/main/delayed_events.py +++ b/synapse/storage/databases/main/delayed_events.py @@ -18,7 +18,7 @@ import logging from http import HTTPStatus -from typing import TYPE_CHECKING, Any, Dict, List, NewType, Optional, Set, Tuple +from typing import Any, Dict, List, NewType, Optional, Set, Tuple from synapse.api.errors import ( Codes, @@ -28,18 +28,10 @@ SynapseError, ) from synapse.storage._base import SQLBaseStore, db_to_json -from synapse.storage.database import ( - DatabasePool, - LoggingDatabaseConnection, - LoggingTransaction, -) -from synapse.storage.engines import PostgresEngine +from synapse.storage.database import LoggingTransaction from synapse.types import JsonDict, RoomID, StrCollection from synapse.util import json_encoder, stringutils as stringutils -if TYPE_CHECKING: - from synapse.server import HomeServer - logger = logging.getLogger(__name__) @@ -72,16 +64,6 @@ # TODO: Try to support workers class DelayedEventsStore(SQLBaseStore): - def __init__( - self, - database: DatabasePool, - db_conn: LoggingDatabaseConnection, - hs: "HomeServer", - ): - super().__init__(database, db_conn, hs) - # TODO: Always use RETURNING once the minimum supported Sqlite version is 3.35.0 - self._use_returning = isinstance(self.database_engine, PostgresEngine) - async def add( self, *, @@ -118,7 +100,7 @@ def add_txn(txn: LoggingTransaction) -> DelayID: ? 
) """ - if self._use_returning: + if self.database_engine.supports_returning: sql += "RETURNING delay_rowid" txn.execute( sql, @@ -144,7 +126,7 @@ def add_txn(txn: LoggingTransaction) -> DelayID: f"Couldn't generate a unique delay_id for user_localpart {user_localpart}", ) - if not self._use_returning: + if not self.database_engine.supports_returning: txn.execute( """ SELECT delay_rowid From 335eeb7793325bd1ff922ec71c14a424dfdcf073 Mon Sep 17 00:00:00 2001 From: Andrew Ferrazzutti Date: Thu, 15 Aug 2024 15:10:53 -0400 Subject: [PATCH 29/79] Properly indent comment block Co-authored-by: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> --- synapse/storage/schema/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/synapse/storage/schema/__init__.py b/synapse/storage/schema/__init__.py index 8715dec5c60..df8edb70025 100644 --- a/synapse/storage/schema/__init__.py +++ b/synapse/storage/schema/__init__.py @@ -145,7 +145,7 @@ Changes in SCHEMA_VERSION = 87 - MSC4140: Add `delayed_events` table that keeps track of events that are to - be posted in response to a resettable timeout or an on-demand action. + be posted in response to a resettable timeout or an on-demand action. """ From 21311fbcb30a4b76876faa71465db05003f4d2ea Mon Sep 17 00:00:00 2001 From: Andrew Ferrazzutti Date: Mon, 19 Aug 2024 13:07:06 -0400 Subject: [PATCH 30/79] Remove support for delayed event parents as they are not part of MSC4140 anymore (and if reinstated, will likely require their own MSC) --- synapse/handlers/delayed_events.py | 56 +-- synapse/rest/client/room.py | 8 +- .../storage/databases/main/delayed_events.py | 465 +++++++----------- .../main/delta/87/01_add_delayed_events.sql | 20 +- tests/rest/client/test_rooms.py | 4 +- 5 files changed, 200 insertions(+), 353 deletions(-) diff --git a/synapse/handlers/delayed_events.py b/synapse/handlers/delayed_events.py index f626584a774..8f9ef171e05 100644 --- a/synapse/handlers/delayed_events.py +++ b/synapse/handlers/delayed_events.py @@ -175,8 +175,7 @@ async def add( state_key: Optional[str], origin_server_ts: Optional[int], content: JsonDict, - delay: Optional[int], - parent_id: Optional[str], + delay: int, ) -> str: """ Creates a new delayed event. @@ -190,29 +189,21 @@ async def add( If None, the timestamp will be the actual time when the event is sent. content: The content of the event to be sent. delay: How long (in milliseconds) to wait before automatically sending the event. - If None, the event won't be automatically sent (allowed only when parent_id is set). - parent_id: The ID of the delayed event this one is grouped with. - May only refer to a delayed event that has no parent itself. - When the parent event is sent or cancelled, this one is cancelled; - and when this event is sent, the parent is cancelled. Returns: The ID of the added delayed event. 
""" - if delay is not None: - max_delay = self._config.experimental.msc4140_max_delay - if delay > max_delay: - raise SynapseError( - HTTPStatus.BAD_REQUEST, - "The requested delay exceeds the allowed maximum.", - Codes.UNKNOWN, - { - "org.matrix.msc4140.errcode": "M_MAX_DELAY_EXCEEDED", - "org.matrix.msc4140.max_delay": max_delay, - }, - ) - # Callers should ensure that at least one of these are set - assert delay or parent_id + max_delay = self._config.experimental.msc4140_max_delay + if delay > max_delay: + raise SynapseError( + HTTPStatus.BAD_REQUEST, + "The requested delay exceeds the allowed maximum.", + Codes.UNKNOWN, + { + "org.matrix.msc4140.errcode": "M_MAX_DELAY_EXCEEDED", + "org.matrix.msc4140.max_delay": max_delay, + }, + ) await self._request_ratelimiter.ratelimit(requester) await self._initialized_from_db @@ -240,11 +231,9 @@ async def add( origin_server_ts=origin_server_ts, content=content, delay=delay, - parent_id=parent_id, ) - if delay is not None: - self._schedule(delay_id, user_localpart, Delay(delay)) + self._schedule(delay_id, user_localpart, Delay(delay)) return delay_id @@ -277,10 +266,7 @@ async def update(self, requester: Requester, delay_id: str, action: str) -> None async with self._get_delay_context(delay_id, user_localpart): if enum_action == _UpdateDelayedEventAction.CANCEL: - for removed_timeout_delay_id in await self._store.remove( - delay_id, user_localpart - ): - self._unschedule(removed_timeout_delay_id, user_localpart) + self._unschedule(delay_id, user_localpart) elif enum_action == _UpdateDelayedEventAction.RESTART: delay = await self._store.restart( @@ -293,12 +279,9 @@ async def update(self, requester: Requester, delay_id: str, action: str) -> None self._schedule(delay_id, user_localpart, delay) elif enum_action == _UpdateDelayedEventAction.SEND: - args, removed_timeout_delay_ids = await self._store.pop_event( - delay_id, user_localpart - ) + args = await self._store.pop_event(delay_id, user_localpart) - for timeout_delay_id in removed_timeout_delay_ids: - self._unschedule(timeout_delay_id, user_localpart) + self._unschedule(delay_id, user_localpart) await self._send_event(user_localpart, *args) async def _send_on_timeout( @@ -308,9 +291,7 @@ async def _send_on_timeout( async with self._get_delay_context(delay_id, user_localpart): try: - args, removed_timeout_delay_ids = await self._store.pop_event( - delay_id, user_localpart - ) + args = await self._store.pop_event(delay_id, user_localpart) except NotFoundError: logger.debug( "delay_id %s for local user %s was removed from the DB before it timed out (or was always missing)", @@ -319,9 +300,6 @@ async def _send_on_timeout( ) return - removed_timeout_delay_ids.remove(delay_id) - for timeout_delay_id in removed_timeout_delay_ids: - self._unschedule(timeout_delay_id, user_localpart) await self._send_event(user_localpart, *args) def _schedule( diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py index 53208e0605a..43ca93b9446 100644 --- a/synapse/rest/client/room.py +++ b/synapse/rest/client/room.py @@ -292,8 +292,7 @@ async def on_PUT( origin_server_ts = parse_integer(request, "ts") delay = parse_integer(request, "org.matrix.msc4140.delay") - parent_id = parse_string(request, "org.matrix.msc4140.parent_delay_id") - if delay is not None or parent_id is not None: + if delay is not None: delay_id = await self.delayed_events_handler.add( requester, room_id=room_id, @@ -302,7 +301,6 @@ async def on_PUT( origin_server_ts=origin_server_ts, content=content, delay=delay, - parent_id=parent_id, 
) set_tag("delay_id", delay_id) @@ -382,8 +380,7 @@ async def _do( origin_server_ts = parse_integer(request, "ts") delay = parse_integer(request, "org.matrix.msc4140.delay") - parent_id = parse_string(request, "org.matrix.msc4140.parent_delay_id") - if delay is not None or parent_id is not None: + if delay is not None: delay_id = await self.delayed_events_handler.add( requester, room_id=room_id, @@ -392,7 +389,6 @@ async def _do( origin_server_ts=origin_server_ts, content=content, delay=delay, - parent_id=parent_id, ) set_tag("delay_id", delay_id) diff --git a/synapse/storage/databases/main/delayed_events.py b/synapse/storage/databases/main/delayed_events.py index 6651b5fabdd..b602958a46c 100644 --- a/synapse/storage/databases/main/delayed_events.py +++ b/synapse/storage/databases/main/delayed_events.py @@ -18,18 +18,12 @@ import logging from http import HTTPStatus -from typing import Any, Dict, List, NewType, Optional, Set, Tuple - -from synapse.api.errors import ( - Codes, - InvalidAPICallError, - NotFoundError, - StoreError, - SynapseError, -) +from typing import List, NewType, Optional, Tuple + +from synapse.api.errors import NotFoundError, SynapseError from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage.database import LoggingTransaction -from synapse.types import JsonDict, RoomID, StrCollection +from synapse.types import JsonDict, RoomID from synapse.util import json_encoder, stringutils as stringutils logger = logging.getLogger(__name__) @@ -74,8 +68,7 @@ async def add( state_key: Optional[str], origin_server_ts: Optional[int], content: JsonDict, - delay: Optional[int], - parent_id: Optional[str], + delay: int, ) -> DelayID: """ Inserts a new delayed event in the DB. @@ -91,11 +84,11 @@ def add_txn(txn: LoggingTransaction) -> DelayID: try: sql = """ INSERT INTO delayed_events ( - delay_id, user_localpart, running_since, + delay_id, user_localpart, delay, running_since, room_id, event_type, state_key, origin_server_ts, content ) VALUES ( - ?, ?, ?, + ?, ?, ?, ?, ?, ?, ?, ?, ? ) @@ -107,6 +100,7 @@ def add_txn(txn: LoggingTransaction) -> DelayID: ( delay_id, user_localpart, + delay, current_ts, room_id.to_string(), event_type, @@ -140,53 +134,6 @@ def add_txn(txn: LoggingTransaction) -> DelayID: ) row = txn.fetchone() assert row is not None - delay_rowid = row[0] - - if delay is not None: - self.db_pool.simple_insert_txn( - txn, - table="delayed_event_timeouts", - values={ - "delay_rowid": delay_rowid, - "delay": delay, - }, - ) - - if parent_id is None: - self.db_pool.simple_insert_txn( - txn, - table="delayed_event_parents", - values={"delay_rowid": delay_rowid}, - ) - else: - try: - txn.execute( - """ - INSERT INTO delayed_event_children (child_rowid, parent_rowid) - SELECT ?, delay_rowid - FROM delayed_events - WHERE delay_id = ? AND user_localpart = ? - """, - ( - delay_rowid, - parent_id, - user_localpart, - ), - ) - # TODO: Handle only the error for the relevant foreign key / check violation - except Exception as e: - logger.debug( - "Error inserting into delayed_event_children", - str(e), - ) - raise SynapseError( - HTTPStatus.BAD_REQUEST, - # TODO: Improve the wording for this - "Invalid parent delayed event", - Codes.INVALID_PARAM, - ) - if txn.rowcount != 1: - raise NotFoundError("Parent delayed event not found") return delay_id @@ -210,7 +157,7 @@ async def restart( current_ts: Timestamp, ) -> Delay: """ - Resets the matching delayed event, as long as it has a timeout. + Restarts the send time of matching delayed event. 
Args: delay_id: The ID of the delayed event to restart. @@ -221,40 +168,46 @@ async def restart( Raises: NotFoundError: if there is no matching delayed event. - SynapseError: if the matching delayed event has no timeout. """ def restart_txn(txn: LoggingTransaction) -> Delay: - keyvalues = { - "delay_id": delay_id, - "user_localpart": user_localpart, - } - row = self.db_pool.simple_select_one_txn( - txn, - table="delayed_events JOIN delayed_event_timeouts USING (delay_rowid)", - keyvalues=keyvalues, - retcols=("delay_rowid", "delay"), - allow_none=True, - ) - if row is None: - try: - self.db_pool.simple_select_one_onecol_txn( - txn, - table="delayed_events", - keyvalues=keyvalues, - retcol="1", - ) - except StoreError: + if self.database_engine.supports_returning: + txn.execute( + """ + UPDATE delayed_events + SET running_since = ? + WHERE delay_id = ? AND user_localpart = ? + RETURNING delay_rowid, delay + """, + ( + current_ts, + delay_id, + user_localpart, + ), + ) + row = txn.fetchone() + if row is None: raise NotFoundError("Delayed event not found") - else: - raise InvalidAPICallError("Delayed event has no timeout") - - self.db_pool.simple_update_txn( - txn, - table="delayed_events", - keyvalues={"delay_rowid": row[0]}, - updatevalues={"running_since": current_ts}, - ) + else: + row = self.db_pool.simple_select_one_txn( + txn, + table="delayed_events", + keyvalues={ + "delay_id": delay_id, + "user_localpart": user_localpart, + }, + retcols=("delay_rowid", "delay"), + allow_none=True, + ) + if row is None: + raise NotFoundError("Delayed event not found") + self.db_pool.simple_update_txn( + txn, + table="delayed_events", + keyvalues={"delay_rowid": row[0]}, + updatevalues={"running_since": current_ts}, + ) + assert txn.rowcount == 1 return Delay(row[1]) return await self.db_pool.runInteraction("restart", restart_txn) @@ -266,25 +219,19 @@ async def get_all_for_user( """Returns all pending delayed events owned by the given user.""" # TODO: Store and return "transaction_id" # TODO: Support Pagination stream API ("next_batch" field) - rows = await self.db_pool.execute( - "get_all_for_user", - """ - SELECT - delay_id, - room_id, event_type, state_key, - delay, parent_id, - running_since, - content - FROM delayed_events - LEFT JOIN delayed_event_timeouts USING (delay_rowid) - LEFT JOIN ( - SELECT delay_id AS parent_id, child_rowid - FROM delayed_event_children - JOIN delayed_events ON parent_rowid = delay_rowid - ) ON delay_rowid = child_rowid - WHERE user_localpart = ? - """, - user_localpart, + rows = await self.db_pool.simple_select_list( + table="delayed_events", + keyvalues={"user_localpart": user_localpart}, + retcols=( + "delay_id", + "room_id", + "event_type", + "state_key", + "delay", + "running_since", + "content", + ), + desc="get_all_for_user", ) return [ { @@ -293,9 +240,8 @@ async def get_all_for_user( "type": EventType(row[2]), **({"state_key": StateKey(row[3])} if row[3] is not None else {}), **({"delay": Delay(row[4])} if row[4] is not None else {}), - **({"parent_delay_id": DelayID(row[5])} if row[5] is not None else {}), - "running_since": Timestamp(row[6]), - "content": db_to_json(row[7]), + "running_since": Timestamp(row[5]), + "content": db_to_json(row[6]), } for row in rows ] @@ -308,39 +254,61 @@ async def process_all_delays(self, current_ts: Timestamp) -> Tuple[ Pops all delayed events that should have timed out prior to the provided time, and returns all remaining timeout delayed events along with how much later from the provided time they should time out at. 
+ + Does not return any delayed events that got removed but not sent, as this is + meant to be called on startup before any delayed events have been scheduled. """ def process_all_delays_txn(txn: LoggingTransaction) -> Tuple[ List[DelayedPartialEventWithUser], List[Tuple[DelayID, UserLocalpart, Delay]], ]: - events: List[DelayedPartialEventWithUser] = [] - removed_timeout_delay_ids: Set[DelayID] = set() - - txn.execute( - """ - WITH delay_send_times AS ( - SELECT *, running_since + delay AS send_ts - FROM delayed_events - JOIN delayed_event_timeouts USING (delay_rowid) + sql_cols = ", ".join( + ( + "user_localpart", + "room_id", + "event_type", + "state_key", + "origin_server_ts", + "content", ) - SELECT delay_rowid, user_localpart - FROM delay_send_times - WHERE send_ts < ? - ORDER BY send_ts - """, - (current_ts,), ) - for row in txn.fetchall(): - try: - event, removed_timeout_delay_id = self._pop_event_txn( - txn, - keyvalues={"delay_rowid": row[0]}, - ) - except NotFoundError: - pass - events.append((UserLocalpart(row[1]), *event)) - removed_timeout_delay_ids |= removed_timeout_delay_id + sql_from = "FROM delayed_events WHERE running_since + delay < ?" + sql_order = "ORDER BY running_since + delay" + sql_args = (current_ts,) + if self.database_engine.supports_returning: + txn.execute( + f""" + WITH timed_out_events AS ( + DELETE {sql_from} RETURNING * + ) SELECT {sql_cols} FROM timed_out_events {sql_order} + """, + sql_args, + ) + rows = txn.fetchall() + else: + txn.execute( + f"SELECT {sql_cols}, delay_rowid {sql_from} {sql_order}", sql_args + ) + rows = txn.fetchall() + self.db_pool.simple_delete_many_txn( + txn, + table="delayed_events", + column="delay_rowid", + values=tuple(row[-1] for row in rows), + keyvalues={}, + ) + events = [ + ( + UserLocalpart(row[0]), + RoomID.from_string(row[1]), + EventType(row[2]), + StateKey(row[3]) if row[3] is not None else None, + Timestamp(row[4]) if row[4] is not None else None, + db_to_json(row[5]), + ) + for row in rows + ] txn.execute( """ @@ -349,9 +317,8 @@ def process_all_delays_txn(txn: LoggingTransaction) -> Tuple[ user_localpart, running_since + delay - ? AS relative_delay FROM delayed_events - JOIN delayed_event_timeouts USING (delay_rowid) """, - (current_ts,), + sql_args, ) remaining_timeout_delays = [ ( @@ -371,113 +338,72 @@ async def pop_event( self, delay_id: DelayID, user_localpart: UserLocalpart, - ) -> Tuple[ - DelayedPartialEvent, - Set[DelayID], - ]: + ) -> DelayedPartialEvent: """ - Get the partial event of the matching delayed event, - and remove it and all of its parent/child/sibling events from the DB. + Gets the partial event of the matching delayed event, and remove it from the DB. Returns: - A tuple of: - - The partial event to send for the matching delayed event. - - The IDs of all removed delayed events with a timeout that must be unscheduled. + The partial event to send for the matching delayed event. Raises: NotFoundError: if there is no matching delayed event. 
""" - return await self.db_pool.runInteraction( - "pop_event", - self._pop_event_txn, - keyvalues={ - "delay_id": delay_id, - "user_localpart": user_localpart, - }, - ) - - def _pop_event_txn( - self, - txn: LoggingTransaction, - keyvalues: Dict[str, Any], - ) -> Tuple[ - DelayedPartialEvent, - Set[DelayID], - ]: - row = self.db_pool.simple_select_one_txn( - txn, - table="delayed_events", - keyvalues=keyvalues, - retcols=( - "delay_rowid", - "room_id", - "event_type", - "state_key", - "origin_server_ts", - "content", - ), - allow_none=True, - ) - if row is None: - raise NotFoundError("Delayed event not found") - target_delay_rowid = row[0] - event_row = row[1:] - - parent_rowid = self.db_pool.simple_select_one_onecol_txn( - txn, - table="delayed_event_children JOIN delayed_events ON child_rowid = delay_rowid", - keyvalues={"delay_rowid": target_delay_rowid}, - retcol="parent_rowid", - allow_none=True, - ) - removed_timeout_delay_ids = self._remove_txn( - txn, - keyvalues={ - "delay_rowid": ( - parent_rowid if parent_rowid is not None else target_delay_rowid + def pop_event_txn(txn: LoggingTransaction) -> DelayedPartialEvent: + sql_cols = ", ".join( + ( + "room_id", + "event_type", + "state_key", + "origin_server_ts", + "content", + ) + ) + sql_from = "FROM delayed_events WHERE delay_id = ? AND user_localpart = ?" + txn.execute( + ( + f"DELETE {sql_from} RETURNING {sql_cols}" + if self.database_engine.supports_returning + else f"SELECT {sql_cols} {sql_from}" ), - }, - retcols=("delay_id",), - ) + (delay_id, user_localpart), + ) + row = txn.fetchone() + if row is None: + raise NotFoundError("Delayed event not found") + elif not self.database_engine.supports_returning: + txn.execute(f"DELETE {sql_from}") + assert txn.rowcount == 1 + + return ( + RoomID.from_string(row[0]), + EventType(row[1]), + StateKey(row[2]) if row[2] is not None else None, + Timestamp(row[3]) if row[3] is not None else None, + db_to_json(row[4]), + ) - contents: JsonDict = db_to_json(event_row[4]) - return ( - ( - RoomID.from_string(event_row[0]), - EventType(event_row[1]), - StateKey(event_row[2]) if event_row[2] is not None else None, - Timestamp(event_row[3]) if event_row[3] is not None else None, - contents, - ), - {DelayID(r[0]) for r in removed_timeout_delay_ids}, - ) + return await self.db_pool.runInteraction("pop_event", pop_event_txn) async def remove( self, delay_id: DelayID, user_localpart: UserLocalpart, - ) -> Set[DelayID]: + ) -> None: """ - Removes the matching delayed event, as well as all of its child events if it is a parent. - - Returns: - The IDs of all removed delayed events with a timeout that must be unscheduled. + Removes the matching delayed event. Raises: NotFoundError: if there is no matching delayed event. """ - - removed_timeout_delay_ids = await self.db_pool.runInteraction( - "remove", - self._remove_txn, + await self.db_pool.simple_delete( + table="delayed_events", keyvalues={ "delay_id": delay_id, "user_localpart": user_localpart, }, - retcols=("delay_id",), + desc="remove", ) - return {DelayID(r[0]) for r in removed_timeout_delay_ids} async def remove_state_events( self, @@ -486,77 +412,42 @@ async def remove_state_events( state_key: str, ) -> List[Tuple[DelayID, UserLocalpart]]: """ - Removes all matching delayed state events from the DB, as well as their children. + Removes all matching delayed state events from the DB. Returns: - The ID & owner of every removed delayed event with a timeout that must be unscheduled. 
- """ - return await self.db_pool.runInteraction( - "remove_state_events", - self._remove_txn, - keyvalues={ - "room_id": room_id, - "event_type": event_type, - "state_key": state_key, - }, - retcols=("delay_id", "user_localpart"), - allow_none=True, - ) - - def _remove_txn( - self, - txn: LoggingTransaction, - keyvalues: Dict[str, Any], - retcols: StrCollection, - allow_none: bool = False, - ) -> List[Tuple]: + The ID & owner of every removed delayed event. """ - Removes delayed events matching the keyvalues, and any children they may have. - - Returns: - The specified columns for each delayed event with a timeout that was removed. - Raises: - NotFoundError: if allow_none is False and no delayed events match the keyvalues. - """ - sql_with = f""" - WITH target_rowids AS ( - SELECT delay_rowid - FROM delayed_events - WHERE {" AND ".join("%s = ?" % k for k in keyvalues)} + def remove_state_events_txn(txn: LoggingTransaction) -> List[Tuple]: + sql_cols = ", ".join( + ( + "delay_id", + "user_localpart", + ) ) - """ - sql_where = """ - WHERE delay_rowid IN (SELECT * FROM target_rowids) - OR delay_rowid IN ( - SELECT child_rowid - FROM delayed_event_children - JOIN target_rowids ON parent_rowid = delay_rowid + sql_from = ( + "FROM delayed_events " + "WHERE room_id = ? AND event_type = ? AND state_key = ?" ) - """ - args = list(keyvalues.values()) - txn.execute( - f""" - {sql_with} - SELECT {", ".join(retcols)} - FROM delayed_events - JOIN delayed_event_timeouts USING (delay_rowid) - {sql_where} - """, - args, - ) - rows = txn.fetchall() - txn.execute( - f""" - {sql_with} - DELETE FROM delayed_events - {sql_where} - """, - args, + sql_args = (room_id, event_type, state_key) + if self.database_engine.supports_returning: + txn.execute(f"DELETE {sql_from} RETURNING {sql_cols}", sql_args) + rows = txn.fetchall() + else: + txn.execute(f"SELECT {sql_cols} {sql_from}", sql_args) + rows = txn.fetchall() + txn.execute(f"DELETE {sql_from}") + return [ + ( + DelayID(row[0]), + UserLocalpart(row[1]), + ) + for row in rows + ] + + return await self.db_pool.runInteraction( + "remove_state_events", remove_state_events_txn ) - if not allow_none and txn.rowcount == 0: - raise NotFoundError("No delayed event found") - return rows def _generate_delay_id() -> DelayID: diff --git a/synapse/storage/schema/main/delta/87/01_add_delayed_events.sql b/synapse/storage/schema/main/delta/87/01_add_delayed_events.sql index 7ae4d189ca7..f7ad172c08b 100644 --- a/synapse/storage/schema/main/delta/87/01_add_delayed_events.sql +++ b/synapse/storage/schema/main/delta/87/01_add_delayed_events.sql @@ -20,6 +20,7 @@ CREATE TABLE delayed_events ( delay_id TEXT NOT NULL, user_localpart TEXT NOT NULL, + delay BIGINT NOT NULL, running_since BIGINT NOT NULL, room_id TEXT NOT NULL, event_type TEXT NOT NULL, @@ -31,22 +32,3 @@ CREATE TABLE delayed_events ( CREATE INDEX delayed_events_room_state_event_idx ON delayed_events (room_id, event_type, state_key) WHERE state_key IS NOT NULL; CREATE INDEX delayed_events_user_idx ON delayed_events (user_localpart); - -CREATE TABLE delayed_event_timeouts ( - delay_rowid INTEGER PRIMARY KEY - REFERENCES delayed_events (delay_rowid) ON DELETE CASCADE, - delay BIGINT NOT NULL -); - -CREATE TABLE delayed_event_parents ( - delay_rowid INTEGER PRIMARY KEY - REFERENCES delayed_event_timeouts (delay_rowid) ON DELETE CASCADE -); - -CREATE TABLE delayed_event_children ( - child_rowid INTEGER PRIMARY KEY - REFERENCES delayed_events (delay_rowid) ON DELETE CASCADE, - parent_rowid INTEGER NOT NULL - REFERENCES 
delayed_event_parents (delay_rowid) ON DELETE CASCADE, - CHECK (child_rowid <> parent_rowid) -); diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py index d132cc7c540..bd299802182 100644 --- a/tests/rest/client/test_rooms.py +++ b/tests/rest/client/test_rooms.py @@ -742,7 +742,7 @@ def test_post_room_no_keys(self) -> None: self.assertEqual(HTTPStatus.OK, channel.code, channel.result) self.assertTrue("room_id" in channel.json_body) assert channel.resource_usage is not None - self.assertEqual(38, channel.resource_usage.db_txn_count) + self.assertEqual(33, channel.resource_usage.db_txn_count) def test_post_room_initial_state(self) -> None: # POST with initial_state config key, expect new room id @@ -755,7 +755,7 @@ def test_post_room_initial_state(self) -> None: self.assertEqual(HTTPStatus.OK, channel.code, channel.result) self.assertTrue("room_id" in channel.json_body) assert channel.resource_usage is not None - self.assertEqual(40, channel.resource_usage.db_txn_count) + self.assertEqual(35, channel.resource_usage.db_txn_count) def test_post_room_visibility_key(self) -> None: # POST with visibility config key, expect new room id From 347811891af4a5310884620e246d82b17c0d2cb0 Mon Sep 17 00:00:00 2001 From: Andrew Ferrazzutti Date: Tue, 20 Aug 2024 11:01:00 -0400 Subject: [PATCH 31/79] Fix comments --- synapse/storage/databases/main/delayed_events.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/synapse/storage/databases/main/delayed_events.py b/synapse/storage/databases/main/delayed_events.py index b602958a46c..0c7475045c0 100644 --- a/synapse/storage/databases/main/delayed_events.py +++ b/synapse/storage/databases/main/delayed_events.py @@ -157,7 +157,7 @@ async def restart( current_ts: Timestamp, ) -> Delay: """ - Restarts the send time of matching delayed event. + Restarts the send time of the matching delayed event. Args: delay_id: The ID of the delayed event to restart. @@ -454,7 +454,7 @@ def _generate_delay_id() -> DelayID: """Generates an opaque string, for use as a delay ID""" # We use the following format for delay IDs: - # syf_ + # syd_ # They are scoped to user localparts, so it is possible for # the same ID to exist for multiple users. 
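A note on the delay ID format referenced in the comment above: these patches never
show the body of _generate_delay_id itself, so the following is only a sketch of a
generator consistent with that comment. It assumes stringutils.random_string (already
imported by this module) as the source of randomness; the suffix length of 20 is an
illustrative choice, not a value taken from the patches.

    def _generate_delay_id() -> DelayID:
        """Generates an opaque string, for use as a delay ID"""
        # "syd_" plus a random suffix. IDs are scoped to user localparts,
        # and a later patch ("Don't bother handling DB key collisions")
        # relies on the suffix having enough entropy that collisions
        # between IDs for the same user are extremely rare.
        return DelayID(f"syd_{stringutils.random_string(20)}")
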
From 00217f3152cf49117b6e4976bca3037a375a8177 Mon Sep 17 00:00:00 2001 From: Andrew Ferrazzutti Date: Tue, 20 Aug 2024 09:23:16 -0400 Subject: [PATCH 32/79] Reraise the error for an invalid max delay config --- synapse/config/experimental.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index f9b38a5c04e..4c00ad2b042 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -449,11 +449,11 @@ def read_config(self, config: JsonDict, **kwargs: Any) -> None: self.msc4140_max_delay = int(experimental["msc4140_max_delay"]) if self.msc4140_max_delay <= 0: raise ValueError - except ValueError: + except ValueError as e: raise ConfigError( "msc4140_max_delay must be a positive integer", ("experimental", "msc4140_max_delay"), - ) + ) from e except KeyError: self.msc4140_max_delay = 10 * 365 * 24 * 60 * 60 * 1000 # 10 years From 90cc8b5a22fa14a07fe4cb265bb8fd4d5889ba01 Mon Sep 17 00:00:00 2001 From: Andrew Ferrazzutti Date: Tue, 20 Aug 2024 09:27:40 -0400 Subject: [PATCH 33/79] Don't bother handling DB key collisions because generated delay IDs should have enough entropy for collisions to be extremely rare --- .../storage/databases/main/delayed_events.py | 79 +++++++------------ 1 file changed, 28 insertions(+), 51 deletions(-) diff --git a/synapse/storage/databases/main/delayed_events.py b/synapse/storage/databases/main/delayed_events.py index 0c7475045c0..f6ca1f9631c 100644 --- a/synapse/storage/databases/main/delayed_events.py +++ b/synapse/storage/databases/main/delayed_events.py @@ -17,10 +17,9 @@ # import logging -from http import HTTPStatus from typing import List, NewType, Optional, Tuple -from synapse.api.errors import NotFoundError, SynapseError +from synapse.api.errors import NotFoundError from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage.database import LoggingTransaction from synapse.types import JsonDict, RoomID @@ -81,44 +80,33 @@ async def add( def add_txn(txn: LoggingTransaction) -> DelayID: delay_id = _generate_delay_id() - try: - sql = """ - INSERT INTO delayed_events ( - delay_id, user_localpart, delay, running_since, - room_id, event_type, state_key, origin_server_ts, - content - ) VALUES ( - ?, ?, ?, ?, - ?, ?, ?, ?, - ? - ) - """ - if self.database_engine.supports_returning: - sql += "RETURNING delay_rowid" - txn.execute( - sql, - ( - delay_id, - user_localpart, - delay, - current_ts, - room_id.to_string(), - event_type, - state_key, - origin_server_ts, - json_encoder.encode(content), - ), - ) - # TODO: Handle only the error for DB key collisions - except Exception as e: - logger.debug( - "Error inserting into delayed_events", - str(e), - ) - raise SynapseError( - HTTPStatus.INTERNAL_SERVER_ERROR, - f"Couldn't generate a unique delay_id for user_localpart {user_localpart}", + sql = """ + INSERT INTO delayed_events ( + delay_id, user_localpart, delay, running_since, + room_id, event_type, state_key, origin_server_ts, + content + ) VALUES ( + ?, ?, ?, ?, + ?, ?, ?, ?, + ? 
) + """ + if self.database_engine.supports_returning: + sql += "RETURNING delay_rowid" + txn.execute( + sql, + ( + delay_id, + user_localpart, + delay, + current_ts, + room_id.to_string(), + event_type, + state_key, + origin_server_ts, + json_encoder.encode(content), + ), + ) if not self.database_engine.supports_returning: txn.execute( @@ -137,18 +125,7 @@ def add_txn(txn: LoggingTransaction) -> DelayID: return delay_id - attempts_remaining = 10 - while True: - try: - return await self.db_pool.runInteraction("add", add_txn) - except SynapseError as e: - if ( - e.code == HTTPStatus.INTERNAL_SERVER_ERROR - and attempts_remaining > 0 - ): - attempts_remaining -= 1 - else: - raise e + return await self.db_pool.runInteraction("add", add_txn) async def restart( self, From 8a65c773362c3abf33ff2fd0755faeca3e53f87a Mon Sep 17 00:00:00 2001 From: Andrew Ferrazzutti Date: Tue, 20 Aug 2024 11:00:34 -0400 Subject: [PATCH 34/79] Move update action value check to REST layer --- synapse/handlers/delayed_events.py | 78 +++++++++++++++------------ synapse/rest/client/delayed_events.py | 23 +++++++- 2 files changed, 66 insertions(+), 35 deletions(-) diff --git a/synapse/handlers/delayed_events.py b/synapse/handlers/delayed_events.py index 8f9ef171e05..3bbafcaa6e6 100644 --- a/synapse/handlers/delayed_events.py +++ b/synapse/handlers/delayed_events.py @@ -18,7 +18,6 @@ import logging from contextlib import asynccontextmanager -from enum import Enum from http import HTTPStatus from typing import ( TYPE_CHECKING, @@ -67,12 +66,6 @@ _STATE_LOCK_KEY = "STATE_LOCK_KEY" -class _UpdateDelayedEventAction(Enum): - CANCEL = "cancel" - RESTART = "restart" - SEND = "send" - - @attr.s(slots=True, frozen=True, auto_attribs=True) class _DelayedCallKey: delay_id: DelayID @@ -237,52 +230,69 @@ async def add( return delay_id - async def update(self, requester: Requester, delay_id: str, action: str) -> None: + async def cancel(self, requester: Requester, delay_id: str) -> None: """ - Executes the appropriate action for the matching delayed event. + Cancels the scheduled delivery of the matching delayed event. Args: + requester: The owner of the delayed event to act on. delay_id: The ID of the delayed event to act on. - action: What to do with the delayed event. Raises: - SynapseError: if the provided action is unknown, or is unsupported for the target delayed event. NotFoundError: if no matching delayed event could be found. """ - try: - enum_action = _UpdateDelayedEventAction(action) - except ValueError: - raise SynapseError( - HTTPStatus.BAD_REQUEST, - "'action' is not one of " - + ", ".join(f"'{m.value}'" for m in _UpdateDelayedEventAction), - Codes.INVALID_PARAM, - ) + await self._request_ratelimiter.ratelimit(requester) delay_id = DelayID(delay_id) user_localpart = UserLocalpart(requester.user.localpart) + async with self._get_delay_context(delay_id, user_localpart): + self._unschedule(delay_id, user_localpart) + + async def restart(self, requester: Requester, delay_id: str) -> None: + """ + Restarts the scheduled delivery of the matching delayed event. + Args: + requester: The owner of the delayed event to act on. + delay_id: The ID of the delayed event to act on. + + Raises: + NotFoundError: if no matching delayed event could be found. 
+ """ await self._request_ratelimiter.ratelimit(requester) + delay_id = DelayID(delay_id) + user_localpart = UserLocalpart(requester.user.localpart) async with self._get_delay_context(delay_id, user_localpart): - if enum_action == _UpdateDelayedEventAction.CANCEL: - self._unschedule(delay_id, user_localpart) + delay = await self._store.restart( + delay_id, + user_localpart, + self._get_current_ts(), + ) - elif enum_action == _UpdateDelayedEventAction.RESTART: - delay = await self._store.restart( - delay_id, - user_localpart, - self._get_current_ts(), - ) + self._unschedule(delay_id, user_localpart) + self._schedule(delay_id, user_localpart, delay) - self._unschedule(delay_id, user_localpart) - self._schedule(delay_id, user_localpart, delay) + async def send(self, requester: Requester, delay_id: str) -> None: + """ + Immediately sends the matching delayed event, instead of waiting for its scheduled delivery. - elif enum_action == _UpdateDelayedEventAction.SEND: - args = await self._store.pop_event(delay_id, user_localpart) + Args: + requester: The owner of the delayed event to act on. + delay_id: The ID of the delayed event to act on. - self._unschedule(delay_id, user_localpart) - await self._send_event(user_localpart, *args) + Raises: + NotFoundError: if no matching delayed event could be found. + """ + await self._request_ratelimiter.ratelimit(requester) + + delay_id = DelayID(delay_id) + user_localpart = UserLocalpart(requester.user.localpart) + async with self._get_delay_context(delay_id, user_localpart): + args = await self._store.pop_event(delay_id, user_localpart) + + self._unschedule(delay_id, user_localpart) + await self._send_event(user_localpart, *args) async def _send_on_timeout( self, delay_id: DelayID, user_localpart: UserLocalpart diff --git a/synapse/rest/client/delayed_events.py b/synapse/rest/client/delayed_events.py index c50cc96f6cf..8608ccebcc4 100644 --- a/synapse/rest/client/delayed_events.py +++ b/synapse/rest/client/delayed_events.py @@ -18,6 +18,7 @@ """ This module contains REST servlets to do with delayed events: /delayed_events/ """ import logging +from enum import Enum from http import HTTPStatus from typing import TYPE_CHECKING, List, Tuple @@ -34,6 +35,12 @@ logger = logging.getLogger(__name__) +class _UpdateDelayedEventAction(Enum): + CANCEL = "cancel" + RESTART = "restart" + SEND = "send" + + # TODO: Needs unit testing class UpdateDelayedEventServlet(RestServlet): PATTERNS = client_patterns( @@ -61,8 +68,22 @@ async def on_POST( "'action' is missing", Codes.MISSING_PARAM, ) + try: + enum_action = _UpdateDelayedEventAction(action) + except ValueError: + raise SynapseError( + HTTPStatus.BAD_REQUEST, + "'action' is not one of " + + ", ".join(f"'{m.value}'" for m in _UpdateDelayedEventAction), + Codes.INVALID_PARAM, + ) - await self.delayed_events_handler.update(requester, delay_id, action) + if enum_action == _UpdateDelayedEventAction.CANCEL: + await self.delayed_events_handler.cancel(requester, delay_id) + elif enum_action == _UpdateDelayedEventAction.RESTART: + await self.delayed_events_handler.restart(requester, delay_id) + elif enum_action == _UpdateDelayedEventAction.SEND: + await self.delayed_events_handler.send(requester, delay_id) return 200, {} From 47b6e69bebe608ba591c3301e68964fb31b709ac Mon Sep 17 00:00:00 2001 From: Andrew Ferrazzutti Date: Tue, 20 Aug 2024 11:15:49 -0400 Subject: [PATCH 35/79] Move delay value check to REST layer --- synapse/rest/client/room.py | 23 +++++++++++++++++++++-- 1 file changed, 21 insertions(+), 2 deletions(-) diff 
--git a/synapse/rest/client/room.py b/synapse/rest/client/room.py index 43ca93b9446..694cf42214e 100644 --- a/synapse/rest/client/room.py +++ b/synapse/rest/client/room.py @@ -196,6 +196,7 @@ def __init__(self, hs: "HomeServer"): self.message_handler = hs.get_message_handler() self.delayed_events_handler = hs.get_delayed_events_handler() self.auth = hs.get_auth() + self._msc4140_max_delay = hs.config.experimental.msc4140_max_delay def register(self, http_server: HttpServer) -> None: # /rooms/$roomid/state/$eventtype @@ -291,7 +292,7 @@ async def on_PUT( if requester.app_service: origin_server_ts = parse_integer(request, "ts") - delay = parse_integer(request, "org.matrix.msc4140.delay") + delay = _parse_request_delay(request, self._msc4140_max_delay) if delay is not None: delay_id = await self.delayed_events_handler.add( requester, @@ -359,6 +360,7 @@ def __init__(self, hs: "HomeServer"): self.event_creation_handler = hs.get_event_creation_handler() self.delayed_events_handler = hs.get_delayed_events_handler() self.auth = hs.get_auth() + self._msc4140_max_delay = hs.config.experimental.msc4140_max_delay def register(self, http_server: HttpServer) -> None: # /rooms/$roomid/send/$event_type[/$txn_id] @@ -379,7 +381,7 @@ async def _do( if requester.app_service: origin_server_ts = parse_integer(request, "ts") - delay = parse_integer(request, "org.matrix.msc4140.delay") + delay = _parse_request_delay(request, self._msc4140_max_delay) if delay is not None: delay_id = await self.delayed_events_handler.add( requester, @@ -446,6 +448,23 @@ async def on_PUT( ) +def _parse_request_delay(request: SynapseRequest, max_delay: int) -> Optional[int]: + delay = parse_integer(request, "org.matrix.msc4140.delay") + if delay is None: + return None + if delay > max_delay: + raise SynapseError( + HTTPStatus.BAD_REQUEST, + "The requested delay exceeds the allowed maximum.", + Codes.UNKNOWN, + { + "org.matrix.msc4140.errcode": "M_MAX_DELAY_EXCEEDED", + "org.matrix.msc4140.max_delay": max_delay, + }, + ) + return delay + + # TODO: Needs unit testing for room ID + alias joins class JoinRoomAliasServlet(ResolveRoomIdMixin, TransactionRestServlet): CATEGORY = "Event sending requests" From e0e6802f1dfb430af7d1dc223bb221fed89d6937 Mon Sep 17 00:00:00 2001 From: Andrew Ferrazzutti Date: Tue, 20 Aug 2024 11:16:25 -0400 Subject: [PATCH 36/79] Replace assert with a single-row update --- synapse/storage/databases/main/delayed_events.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/synapse/storage/databases/main/delayed_events.py b/synapse/storage/databases/main/delayed_events.py index f6ca1f9631c..80558b7b749 100644 --- a/synapse/storage/databases/main/delayed_events.py +++ b/synapse/storage/databases/main/delayed_events.py @@ -178,13 +178,12 @@ def restart_txn(txn: LoggingTransaction) -> Delay: ) if row is None: raise NotFoundError("Delayed event not found") - self.db_pool.simple_update_txn( + self.db_pool.simple_update_one_txn( txn, table="delayed_events", keyvalues={"delay_rowid": row[0]}, updatevalues={"running_since": current_ts}, ) - assert txn.rowcount == 1 return Delay(row[1]) return await self.db_pool.runInteraction("restart", restart_txn) From 5672d0d973870a7864a0ac4204be57eef68a40ba Mon Sep 17 00:00:00 2001 From: Andrew Ferrazzutti Date: Tue, 20 Aug 2024 12:20:49 -0400 Subject: [PATCH 37/79] Replace rowid primary key with delay ID + user ID --- .../storage/databases/main/delayed_events.py | 113 +++++++----------- .../main/delta/87/01_add_delayed_events.sql | 9 +- 
.../87/01_add_delayed_events.sql.postgres | 16 --- 3 files changed, 41 insertions(+), 97 deletions(-) delete mode 100644 synapse/storage/schema/main/delta/87/01_add_delayed_events.sql.postgres diff --git a/synapse/storage/databases/main/delayed_events.py b/synapse/storage/databases/main/delayed_events.py index 80558b7b749..d9c8da03c25 100644 --- a/synapse/storage/databases/main/delayed_events.py +++ b/synapse/storage/databases/main/delayed_events.py @@ -21,7 +21,7 @@ from synapse.api.errors import NotFoundError from synapse.storage._base import SQLBaseStore, db_to_json -from synapse.storage.database import LoggingTransaction +from synapse.storage.database import LoggingTransaction, make_tuple_in_list_sql_clause from synapse.types import JsonDict, RoomID from synapse.util import json_encoder, stringutils as stringutils @@ -73,59 +73,24 @@ async def add( Inserts a new delayed event in the DB. Returns: The generated ID assigned to the added delayed event. - - Raises: - SynapseError: if the delayed event failed to be added. """ - - def add_txn(txn: LoggingTransaction) -> DelayID: - delay_id = _generate_delay_id() - sql = """ - INSERT INTO delayed_events ( - delay_id, user_localpart, delay, running_since, - room_id, event_type, state_key, origin_server_ts, - content - ) VALUES ( - ?, ?, ?, ?, - ?, ?, ?, ?, - ? - ) - """ - if self.database_engine.supports_returning: - sql += "RETURNING delay_rowid" - txn.execute( - sql, - ( - delay_id, - user_localpart, - delay, - current_ts, - room_id.to_string(), - event_type, - state_key, - origin_server_ts, - json_encoder.encode(content), - ), - ) - - if not self.database_engine.supports_returning: - txn.execute( - """ - SELECT delay_rowid - FROM delayed_events - WHERE delay_id = ? AND user_localpart = ? - """, - ( - delay_id, - user_localpart, - ), - ) - row = txn.fetchone() - assert row is not None - - return delay_id - - return await self.db_pool.runInteraction("add", add_txn) + delay_id = _generate_delay_id() + await self.db_pool.simple_insert( + table="delayed_events", + values={ + "delay_id": delay_id, + "user_localpart": user_localpart, + "delay": delay, + "running_since": current_ts, + "room_id": room_id.to_string(), + "event_type": event_type, + "state_key": state_key, + "origin_server_ts": origin_server_ts, + "content": json_encoder.encode(content), + }, + desc="add", + ) + return delay_id async def restart( self, @@ -154,7 +119,7 @@ def restart_txn(txn: LoggingTransaction) -> Delay: UPDATE delayed_events SET running_since = ? WHERE delay_id = ? AND user_localpart = ? 
- RETURNING delay_rowid, delay + RETURNING delay """, ( current_ts, @@ -165,26 +130,28 @@ def restart_txn(txn: LoggingTransaction) -> Delay: row = txn.fetchone() if row is None: raise NotFoundError("Delayed event not found") + return Delay(row[0]) else: - row = self.db_pool.simple_select_one_txn( + keyvalues = { + "delay_id": delay_id, + "user_localpart": user_localpart, + } + delay = self.db_pool.simple_select_one_onecol_txn( txn, table="delayed_events", - keyvalues={ - "delay_id": delay_id, - "user_localpart": user_localpart, - }, - retcols=("delay_rowid", "delay"), + keyvalues=keyvalues, + retcol="delay", allow_none=True, ) - if row is None: + if delay is None: raise NotFoundError("Delayed event not found") self.db_pool.simple_update_one_txn( txn, table="delayed_events", - keyvalues={"delay_rowid": row[0]}, + keyvalues=keyvalues, updatevalues={"running_since": current_ts}, ) - return Delay(row[1]) + return Delay(delay) return await self.db_pool.runInteraction("restart", restart_txn) @@ -251,7 +218,6 @@ def process_all_delays_txn(txn: LoggingTransaction) -> Tuple[ ) sql_from = "FROM delayed_events WHERE running_since + delay < ?" sql_order = "ORDER BY running_since + delay" - sql_args = (current_ts,) if self.database_engine.supports_returning: txn.execute( f""" @@ -259,20 +225,21 @@ def process_all_delays_txn(txn: LoggingTransaction) -> Tuple[ DELETE {sql_from} RETURNING * ) SELECT {sql_cols} FROM timed_out_events {sql_order} """, - sql_args, + (current_ts,), ) rows = txn.fetchall() else: txn.execute( - f"SELECT {sql_cols}, delay_rowid {sql_from} {sql_order}", sql_args + f"SELECT {sql_cols}, delay_id {sql_from} {sql_order}", (current_ts,) ) rows = txn.fetchall() - self.db_pool.simple_delete_many_txn( - txn, - table="delayed_events", - column="delay_rowid", - values=tuple(row[-1] for row in rows), - keyvalues={}, + sql_key_clause, sql_key_args = make_tuple_in_list_sql_clause( + self.database_engine, + ("delay_id", "user_localpart"), + tuple((row[-1], row[0]) for row in rows), + ) + txn.execute( + f"DELETE from delayed_events WHERE {sql_key_clause}", sql_key_args ) events = [ ( @@ -294,7 +261,7 @@ def process_all_delays_txn(txn: LoggingTransaction) -> Tuple[ running_since + delay - ? AS relative_delay FROM delayed_events """, - sql_args, + (current_ts,), ) remaining_timeout_delays = [ ( diff --git a/synapse/storage/schema/main/delta/87/01_add_delayed_events.sql b/synapse/storage/schema/main/delta/87/01_add_delayed_events.sql index f7ad172c08b..40a416910fc 100644 --- a/synapse/storage/schema/main/delta/87/01_add_delayed_events.sql +++ b/synapse/storage/schema/main/delta/87/01_add_delayed_events.sql @@ -12,12 +12,6 @@ -- . CREATE TABLE delayed_events ( - -- An alias of rowid in SQLite. - -- Newly-inserted rows that don't assign a (non-NULL) value for this column - -- will have it set to a table-unique value. - -- For Postgres to do this, the column must be set as an identity column. 
- delay_rowid INTEGER PRIMARY KEY, - delay_id TEXT NOT NULL, user_localpart TEXT NOT NULL, delay BIGINT NOT NULL, @@ -27,8 +21,7 @@ CREATE TABLE delayed_events ( state_key TEXT, origin_server_ts BIGINT, content bytea NOT NULL, - UNIQUE (delay_id, user_localpart) + PRIMARY KEY (delay_id, user_localpart) ); CREATE INDEX delayed_events_room_state_event_idx ON delayed_events (room_id, event_type, state_key) WHERE state_key IS NOT NULL; -CREATE INDEX delayed_events_user_idx ON delayed_events (user_localpart); diff --git a/synapse/storage/schema/main/delta/87/01_add_delayed_events.sql.postgres b/synapse/storage/schema/main/delta/87/01_add_delayed_events.sql.postgres deleted file mode 100644 index 3771db43781..00000000000 --- a/synapse/storage/schema/main/delta/87/01_add_delayed_events.sql.postgres +++ /dev/null @@ -1,16 +0,0 @@ --- --- This file is licensed under the Affero General Public License (AGPL) version 3. --- --- Copyright (C) 2024 New Vector, Ltd --- --- This program is free software: you can redistribute it and/or modify --- it under the terms of the GNU Affero General Public License as --- published by the Free Software Foundation, either version 3 of the --- License, or (at your option) any later version. --- --- See the GNU Affero General Public License for more details: --- . - --- Sets the column as an identity column, meaning that the column in new rows --- will automatically have values from an implicit sequence assigned to it. -ALTER TABLE delayed_events ALTER COLUMN delay_rowid ADD GENERATED ALWAYS AS IDENTITY; From 85cb72f9f6bae9ea4c7499224337372c4a40f0f2 Mon Sep 17 00:00:00 2001 From: Andrew Ferrazzutti Date: Tue, 20 Aug 2024 12:37:00 -0400 Subject: [PATCH 38/79] Rename delayed events store methods & txn descs --- synapse/handlers/delayed_events.py | 16 +++++------ .../storage/databases/main/delayed_events.py | 28 +++++++++---------- 2 files changed, 22 insertions(+), 22 deletions(-) diff --git a/synapse/handlers/delayed_events.py b/synapse/handlers/delayed_events.py index 3bbafcaa6e6..09bd263eaf0 100644 --- a/synapse/handlers/delayed_events.py +++ b/synapse/handlers/delayed_events.py @@ -92,8 +92,8 @@ def __init__(self, hs: "HomeServer"): async def _schedule_db_events() -> None: # TODO: Sync all state first, so that affected delayed state events will be cancelled - events, remaining_timeout_delays = await self._store.process_all_delays( - self._get_current_ts() + events, remaining_timeout_delays = ( + await self._store.process_all_delayed_events(self._get_current_ts()) ) sent_state: Set[Tuple[RoomID, EventType, StateKey]] = set() for ( @@ -149,7 +149,7 @@ async def on_new_event( for ( removed_timeout_delay_id, removed_timeout_delay_user_localpart, - ) in await self._store.remove_state_events( + ) in await self._store.remove_delayed_state_events( event.room_id, event.type, state_key, @@ -215,7 +215,7 @@ async def add( ) user_localpart = UserLocalpart(requester.user.localpart) - delay_id = await self._store.add( + delay_id = await self._store.add_delayed_event( user_localpart=user_localpart, current_ts=self._get_current_ts(), room_id=RoomID.from_string(room_id), @@ -264,7 +264,7 @@ async def restart(self, requester: Requester, delay_id: str) -> None: delay_id = DelayID(delay_id) user_localpart = UserLocalpart(requester.user.localpart) async with self._get_delay_context(delay_id, user_localpart): - delay = await self._store.restart( + delay = await self._store.restart_delayed_event( delay_id, user_localpart, self._get_current_ts(), @@ -289,7 +289,7 @@ async def send(self, 
requester: Requester, delay_id: str) -> None: delay_id = DelayID(delay_id) user_localpart = UserLocalpart(requester.user.localpart) async with self._get_delay_context(delay_id, user_localpart): - args = await self._store.pop_event(delay_id, user_localpart) + args = await self._store.pop_delayed_event(delay_id, user_localpart) self._unschedule(delay_id, user_localpart) await self._send_event(user_localpart, *args) @@ -301,7 +301,7 @@ async def _send_on_timeout( async with self._get_delay_context(delay_id, user_localpart): try: - args = await self._store.pop_event(delay_id, user_localpart) + args = await self._store.pop_delayed_event(delay_id, user_localpart) except NotFoundError: logger.debug( "delay_id %s for local user %s was removed from the DB before it timed out (or was always missing)", @@ -349,7 +349,7 @@ async def get_all_for_user(self, requester: Requester) -> List[JsonDict]: """Return all pending delayed events requested by the given user.""" await self._request_ratelimiter.ratelimit(requester) await self._initialized_from_db - return await self._store.get_all_for_user( + return await self._store.get_all_delayed_events_for_user( UserLocalpart(requester.user.localpart) ) diff --git a/synapse/storage/databases/main/delayed_events.py b/synapse/storage/databases/main/delayed_events.py index d9c8da03c25..aacb6ae63bc 100644 --- a/synapse/storage/databases/main/delayed_events.py +++ b/synapse/storage/databases/main/delayed_events.py @@ -57,7 +57,7 @@ # TODO: Try to support workers class DelayedEventsStore(SQLBaseStore): - async def add( + async def add_delayed_event( self, *, user_localpart: UserLocalpart, @@ -88,11 +88,11 @@ async def add( "origin_server_ts": origin_server_ts, "content": json_encoder.encode(content), }, - desc="add", + desc="add_delayed_event", ) return delay_id - async def restart( + async def restart_delayed_event( self, delay_id: DelayID, user_localpart: UserLocalpart, @@ -153,9 +153,9 @@ def restart_txn(txn: LoggingTransaction) -> Delay: ) return Delay(delay) - return await self.db_pool.runInteraction("restart", restart_txn) + return await self.db_pool.runInteraction("restart_delayed_event", restart_txn) - async def get_all_for_user( + async def get_all_delayed_events_for_user( self, user_localpart: UserLocalpart, ) -> List[JsonDict]: @@ -174,7 +174,7 @@ async def get_all_for_user( "running_since", "content", ), - desc="get_all_for_user", + desc="get_all_delayed_events_for_user", ) return [ { @@ -189,7 +189,7 @@ async def get_all_for_user( for row in rows ] - async def process_all_delays(self, current_ts: Timestamp) -> Tuple[ + async def process_all_delayed_events(self, current_ts: Timestamp) -> Tuple[ List[DelayedPartialEventWithUser], List[Tuple[DelayID, UserLocalpart, Delay]], ]: @@ -274,10 +274,10 @@ def process_all_delays_txn(txn: LoggingTransaction) -> Tuple[ return events, remaining_timeout_delays return await self.db_pool.runInteraction( - "process_all_delays", process_all_delays_txn + "process_all_delayed_events", process_all_delays_txn ) - async def pop_event( + async def pop_delayed_event( self, delay_id: DelayID, user_localpart: UserLocalpart, @@ -326,9 +326,9 @@ def pop_event_txn(txn: LoggingTransaction) -> DelayedPartialEvent: db_to_json(row[4]), ) - return await self.db_pool.runInteraction("pop_event", pop_event_txn) + return await self.db_pool.runInteraction("pop_delayed_event", pop_event_txn) - async def remove( + async def remove_delayed_event( self, delay_id: DelayID, user_localpart: UserLocalpart, @@ -345,10 +345,10 @@ async def remove( 
"delay_id": delay_id, "user_localpart": user_localpart, }, - desc="remove", + desc="remove_delayed_event", ) - async def remove_state_events( + async def remove_delayed_state_events( self, room_id: str, event_type: str, @@ -389,7 +389,7 @@ def remove_state_events_txn(txn: LoggingTransaction) -> List[Tuple]: ] return await self.db_pool.runInteraction( - "remove_state_events", remove_state_events_txn + "remove_delayed_state_events", remove_state_events_txn ) From 974463f1a222b484599d2748695fadb3972b4529 Mon Sep 17 00:00:00 2001 From: Andrew Ferrazzutti Date: Tue, 20 Aug 2024 12:41:54 -0400 Subject: [PATCH 39/79] Remove delayed event from DB on cancel This was always supposed to be done, but was lost a few commits ago --- synapse/handlers/delayed_events.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/synapse/handlers/delayed_events.py b/synapse/handlers/delayed_events.py index 09bd263eaf0..c278894951b 100644 --- a/synapse/handlers/delayed_events.py +++ b/synapse/handlers/delayed_events.py @@ -246,6 +246,8 @@ async def cancel(self, requester: Requester, delay_id: str) -> None: delay_id = DelayID(delay_id) user_localpart = UserLocalpart(requester.user.localpart) async with self._get_delay_context(delay_id, user_localpart): + await self._store.remove_delayed_event(delay_id, user_localpart) + self._unschedule(delay_id, user_localpart) async def restart(self, requester: Requester, delay_id: str) -> None: From 05accda65b5161a1a572c9925e6580ec4083024d Mon Sep 17 00:00:00 2001 From: Andrew Ferrazzutti Date: Wed, 21 Aug 2024 08:23:33 -0400 Subject: [PATCH 40/79] Make user_localpart first column of DB key to allow for faster queries on user_localpart alone --- synapse/storage/schema/main/delta/87/01_add_delayed_events.sql | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/synapse/storage/schema/main/delta/87/01_add_delayed_events.sql b/synapse/storage/schema/main/delta/87/01_add_delayed_events.sql index 40a416910fc..6998de35162 100644 --- a/synapse/storage/schema/main/delta/87/01_add_delayed_events.sql +++ b/synapse/storage/schema/main/delta/87/01_add_delayed_events.sql @@ -21,7 +21,7 @@ CREATE TABLE delayed_events ( state_key TEXT, origin_server_ts BIGINT, content bytea NOT NULL, - PRIMARY KEY (delay_id, user_localpart) + PRIMARY KEY (user_localpart, delay_id) ); CREATE INDEX delayed_events_room_state_event_idx ON delayed_events (room_id, event_type, state_key) WHERE state_key IS NOT NULL; From 2a9069c2faa3530d73217dd7b84cd7648851a309 Mon Sep 17 00:00:00 2001 From: Andrew Ferrazzutti Date: Wed, 21 Aug 2024 08:37:47 -0400 Subject: [PATCH 41/79] Don't handle missing delays in DB lookup because delay is now a mandatory field of all delayed events --- synapse/storage/databases/main/delayed_events.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/synapse/storage/databases/main/delayed_events.py b/synapse/storage/databases/main/delayed_events.py index aacb6ae63bc..e363a41ed86 100644 --- a/synapse/storage/databases/main/delayed_events.py +++ b/synapse/storage/databases/main/delayed_events.py @@ -182,7 +182,7 @@ async def get_all_delayed_events_for_user( "room_id": str(RoomID.from_string(row[1])), "type": EventType(row[2]), **({"state_key": StateKey(row[3])} if row[3] is not None else {}), - **({"delay": Delay(row[4])} if row[4] is not None else {}), + "delay": Delay(row[4]), "running_since": Timestamp(row[5]), "content": db_to_json(row[6]), } From 57b72291fa1d13dedace00c111fc161b6c253702 Mon Sep 17 00:00:00 2001 From: Andrew Ferrazzutti Date: Wed, 21 Aug 2024 08:51:30 
-0400 Subject: [PATCH 42/79] Replace running_since with send_ts and index it --- .../storage/databases/main/delayed_events.py | 22 +++++++++---------- .../main/delta/87/01_add_delayed_events.sql | 3 ++- 2 files changed, 13 insertions(+), 12 deletions(-) diff --git a/synapse/storage/databases/main/delayed_events.py b/synapse/storage/databases/main/delayed_events.py index e363a41ed86..253b5363364 100644 --- a/synapse/storage/databases/main/delayed_events.py +++ b/synapse/storage/databases/main/delayed_events.py @@ -81,7 +81,7 @@ async def add_delayed_event( "delay_id": delay_id, "user_localpart": user_localpart, "delay": delay, - "running_since": current_ts, + "send_ts": current_ts + delay, "room_id": room_id.to_string(), "event_type": event_type, "state_key": state_key, @@ -104,7 +104,7 @@ async def restart_delayed_event( Args: delay_id: The ID of the delayed event to restart. user_localpart: The localpart of the delayed event's owner. - current_ts: The current time, to which the delayed event's "running_since" will be set to. + current_ts: The current time, which will be used to calculate the new send time. Returns: The delay at which the delayed event will be sent (unless it is reset again). @@ -117,7 +117,7 @@ def restart_txn(txn: LoggingTransaction) -> Delay: txn.execute( """ UPDATE delayed_events - SET running_since = ? + SET send_ts = ? + delay WHERE delay_id = ? AND user_localpart = ? RETURNING delay """, @@ -149,7 +149,7 @@ def restart_txn(txn: LoggingTransaction) -> Delay: txn, table="delayed_events", keyvalues=keyvalues, - updatevalues={"running_since": current_ts}, + updatevalues={"send_ts": current_ts + delay}, ) return Delay(delay) @@ -171,7 +171,7 @@ async def get_all_delayed_events_for_user( "event_type", "state_key", "delay", - "running_since", + "send_ts", "content", ), desc="get_all_delayed_events_for_user", @@ -183,7 +183,7 @@ async def get_all_delayed_events_for_user( "type": EventType(row[2]), **({"state_key": StateKey(row[3])} if row[3] is not None else {}), "delay": Delay(row[4]), - "running_since": Timestamp(row[5]), + "running_since": Timestamp(row[5] - row[4]), "content": db_to_json(row[6]), } for row in rows @@ -216,14 +216,14 @@ def process_all_delays_txn(txn: LoggingTransaction) -> Tuple[ "content", ) ) - sql_from = "FROM delayed_events WHERE running_since + delay < ?" - sql_order = "ORDER BY running_since + delay" + sql_from = "FROM delayed_events WHERE send_ts <= ?" + sql_order = "ORDER BY send_ts" if self.database_engine.supports_returning: txn.execute( f""" - WITH timed_out_events AS ( + WITH events_to_send AS ( DELETE {sql_from} RETURNING * - ) SELECT {sql_cols} FROM timed_out_events {sql_order} + ) SELECT {sql_cols} FROM events_to_send {sql_order} """, (current_ts,), ) @@ -258,7 +258,7 @@ def process_all_delays_txn(txn: LoggingTransaction) -> Tuple[ SELECT delay_id, user_localpart, - running_since + delay - ? AS relative_delay + send_ts - ? 
AS relative_delay FROM delayed_events """, (current_ts,), diff --git a/synapse/storage/schema/main/delta/87/01_add_delayed_events.sql b/synapse/storage/schema/main/delta/87/01_add_delayed_events.sql index 6998de35162..24ba5b390b3 100644 --- a/synapse/storage/schema/main/delta/87/01_add_delayed_events.sql +++ b/synapse/storage/schema/main/delta/87/01_add_delayed_events.sql @@ -15,7 +15,7 @@ CREATE TABLE delayed_events ( delay_id TEXT NOT NULL, user_localpart TEXT NOT NULL, delay BIGINT NOT NULL, - running_since BIGINT NOT NULL, + send_ts BIGINT NOT NULL, room_id TEXT NOT NULL, event_type TEXT NOT NULL, state_key TEXT, @@ -24,4 +24,5 @@ CREATE TABLE delayed_events ( PRIMARY KEY (user_localpart, delay_id) ); +CREATE INDEX delayed_events_send_ts ON delayed_events (send_ts); CREATE INDEX delayed_events_room_state_event_idx ON delayed_events (room_id, event_type, state_key) WHERE state_key IS NOT NULL; From 4dc41dc5cdf8406881b85a33fe31ad27a3d23bf7 Mon Sep 17 00:00:00 2001 From: Andrew Ferrazzutti Date: Wed, 21 Aug 2024 11:16:14 -0400 Subject: [PATCH 43/79] Remove redundant delay value check Should have been done in 47b6e69 --- synapse/handlers/delayed_events.py | 15 +-------------- 1 file changed, 1 insertion(+), 14 deletions(-) diff --git a/synapse/handlers/delayed_events.py b/synapse/handlers/delayed_events.py index c278894951b..1e57f7f31a1 100644 --- a/synapse/handlers/delayed_events.py +++ b/synapse/handlers/delayed_events.py @@ -18,7 +18,6 @@ import logging from contextlib import asynccontextmanager -from http import HTTPStatus from typing import ( TYPE_CHECKING, AsyncContextManager, @@ -35,7 +34,7 @@ from twisted.internet.interfaces import IDelayedCall from synapse.api.constants import EventTypes -from synapse.api.errors import Codes, NotFoundError, ShadowBanError, SynapseError +from synapse.api.errors import NotFoundError, ShadowBanError from synapse.events import EventBase from synapse.logging.opentracing import set_tag from synapse.metrics.background_process_metrics import run_as_background_process @@ -186,18 +185,6 @@ async def add( Returns: The ID of the added delayed event. """ - max_delay = self._config.experimental.msc4140_max_delay - if delay > max_delay: - raise SynapseError( - HTTPStatus.BAD_REQUEST, - "The requested delay exceeds the allowed maximum.", - Codes.UNKNOWN, - { - "org.matrix.msc4140.errcode": "M_MAX_DELAY_EXCEEDED", - "org.matrix.msc4140.max_delay": max_delay, - }, - ) - await self._request_ratelimiter.ratelimit(requester) await self._initialized_from_db From 3ce7305deeba91512fc04c8fc73a43aa1f4cc6fb Mon Sep 17 00:00:00 2001 From: Andrew Ferrazzutti Date: Tue, 27 Aug 2024 20:59:29 -0400 Subject: [PATCH 44/79] Refactor max delay config - Move from "experimental" config section to "server" - Allow to be set to None to disable delayed event sending - Parse as a duration string, not just an integer of milliseconds - Set default value to None, i.e. 
disable delayed events by default --- .../configuration/config_documentation.md | 12 +++++++++ synapse/config/experimental.py | 16 +---------- synapse/config/server.py | 13 ++++++++- synapse/rest/client/room.py | 22 +++++++++++---- tests/rest/client/test_rooms.py | 27 +++++++++++++++++++ 5 files changed, 69 insertions(+), 21 deletions(-) diff --git a/docs/usage/configuration/config_documentation.md b/docs/usage/configuration/config_documentation.md index 567bbf88d28..cbb194d53d8 100644 --- a/docs/usage/configuration/config_documentation.md +++ b/docs/usage/configuration/config_documentation.md @@ -760,6 +760,18 @@ email: password_reset: "[%(server_name)s] Password reset" email_validation: "[%(server_name)s] Validate your email" ``` +--- +### `max_event_delay_duration` + +The maximum allowed duration by which sent events can be delayed, as per MSC4140. +Must be a positive value if set. + +Defaults to no duration, which disallows sending delayed events. + +Example configuration: +```yaml +max_event_delay_duration: 24h +``` ## Homeserver blocking Useful options for Synapse admins. diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index 4c00ad2b042..bae9cc80476 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -2,7 +2,7 @@ # This file is licensed under the Affero General Public License (AGPL) version 3. # # Copyright 2021 The Matrix.org Foundation C.I.C. -# Copyright (C) 2023-2024 New Vector, Ltd +# Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as @@ -443,20 +443,6 @@ def read_config(self, config: JsonDict, **kwargs: Any) -> None: "msc3823_account_suspension", False ) - # MSC4140: Delayed events - # The maximum allowed duration for delayed events. - try: - self.msc4140_max_delay = int(experimental["msc4140_max_delay"]) - if self.msc4140_max_delay <= 0: - raise ValueError - except ValueError as e: - raise ConfigError( - "msc4140_max_delay must be a positive integer", - ("experimental", "msc4140_max_delay"), - ) from e - except KeyError: - self.msc4140_max_delay = 10 * 365 * 24 * 60 * 60 * 1000 # 10 years - # MSC4151: Report room API (Client-Server API) self.msc4151_enabled: bool = experimental.get("msc4151_enabled", False) diff --git a/synapse/config/server.py b/synapse/config/server.py index fd52c0475cf..cc406aea2c7 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -2,7 +2,7 @@ # This file is licensed under the Affero General Public License (AGPL) version 3. # # Copyright 2014-2021 The Matrix.org Foundation C.I.C. -# Copyright (C) 2023 New Vector, Ltd +# Copyright (C) 2023-2024 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as @@ -780,6 +780,17 @@ def read_config(self, config: JsonDict, **kwargs: Any) -> None: else: self.delete_stale_devices_after = None + # The maximum allowed delay duration for delayed events (MSC4140). 
+ max_event_delay_duration = config.get("max_event_delay_duration") + if max_event_delay_duration is not None: + self.max_event_delay_ms: Optional[int] = self.parse_duration( + max_event_delay_duration + ) + if self.max_event_delay_ms <= 0: + raise ConfigError("max_event_delay_duration must be a positive value") + else: + self.max_event_delay_ms = None + def has_tls_listener(self) -> bool: return any(listener.is_tls() for listener in self.listeners) diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py index 694cf42214e..d92bae15c73 100644 --- a/synapse/rest/client/room.py +++ b/synapse/rest/client/room.py @@ -196,7 +196,7 @@ def __init__(self, hs: "HomeServer"): self.message_handler = hs.get_message_handler() self.delayed_events_handler = hs.get_delayed_events_handler() self.auth = hs.get_auth() - self._msc4140_max_delay = hs.config.experimental.msc4140_max_delay + self._max_event_delay_ms = hs.config.server.max_event_delay_ms def register(self, http_server: HttpServer) -> None: # /rooms/$roomid/state/$eventtype @@ -292,7 +292,7 @@ async def on_PUT( if requester.app_service: origin_server_ts = parse_integer(request, "ts") - delay = _parse_request_delay(request, self._msc4140_max_delay) + delay = _parse_request_delay(request, self._max_event_delay_ms) if delay is not None: delay_id = await self.delayed_events_handler.add( requester, @@ -360,7 +360,7 @@ def __init__(self, hs: "HomeServer"): self.event_creation_handler = hs.get_event_creation_handler() self.delayed_events_handler = hs.get_delayed_events_handler() self.auth = hs.get_auth() - self._msc4140_max_delay = hs.config.experimental.msc4140_max_delay + self._max_event_delay_ms = hs.config.server.max_event_delay_ms def register(self, http_server: HttpServer) -> None: # /rooms/$roomid/send/$event_type[/$txn_id] @@ -381,7 +381,7 @@ async def _do( if requester.app_service: origin_server_ts = parse_integer(request, "ts") - delay = _parse_request_delay(request, self._msc4140_max_delay) + delay = _parse_request_delay(request, self._max_event_delay_ms) if delay is not None: delay_id = await self.delayed_events_handler.add( requester, @@ -448,10 +448,22 @@ async def on_PUT( ) -def _parse_request_delay(request: SynapseRequest, max_delay: int) -> Optional[int]: +def _parse_request_delay( + request: SynapseRequest, + max_delay: Optional[int], +) -> Optional[int]: delay = parse_integer(request, "org.matrix.msc4140.delay") if delay is None: return None + if max_delay is None: + raise SynapseError( + HTTPStatus.BAD_REQUEST, + "Delayed events are not supported on this server", + Codes.UNKNOWN, + { + "org.matrix.msc4140.errcode": "M_MAX_DELAY_UNSUPPORTED", + }, + ) if delay > max_delay: raise SynapseError( HTTPStatus.BAD_REQUEST, diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py index bd299802182..7f3d74ad6cc 100644 --- a/tests/rest/client/test_rooms.py +++ b/tests/rest/client/test_rooms.py @@ -2311,6 +2311,32 @@ def test_send_delayed_invalid_event(self) -> None: ) self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, channel.result) + def test_delayed_event_unsupported_by_default(self) -> None: + """Test that sending a delayed event is unsupported with the default config.""" + channel = self.make_request( + "PUT", + ( + "rooms/%s/send/m.room.message/mid1?org.matrix.msc4140.delay=2000" + % self.room_id + ).encode("ascii"), + {"body": "test", "msgtype": "m.text"}, + ) + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, channel.result) + + @unittest.override_config({"max_event_delay_duration": "1000"}) 
+ def test_delayed_event_exceeds_max_delay(self) -> None: + """Test that sending a delayed event fails if its delay is longer than allowed.""" + channel = self.make_request( + "PUT", + ( + "rooms/%s/send/m.room.message/mid1?org.matrix.msc4140.delay=2000" + % self.room_id + ).encode("ascii"), + {"body": "test", "msgtype": "m.text"}, + ) + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, channel.result) + + @unittest.override_config({"max_event_delay_duration": "24h"}) def test_send_delayed_message_event(self) -> None: """Test sending a delayed event with invalid content.""" channel = self.make_request( @@ -2323,6 +2349,7 @@ def test_send_delayed_message_event(self) -> None: ) self.assertEqual(HTTPStatus.OK, channel.code, channel.result) + @unittest.override_config({"max_event_delay_duration": "24h"}) def test_send_delayed_state_event(self) -> None: """Test sending a delayed event with invalid content.""" channel = self.make_request( From be094e604586876099f0a98233a80acc4935cb6e Mon Sep 17 00:00:00 2001 From: Andrew Ferrazzutti Date: Wed, 28 Aug 2024 04:55:28 -0400 Subject: [PATCH 45/79] Refactor delayed event processing - Instead of popping events to send before they persist, mark them as in-progress and remove them only once persisted - Check for timed out events in bulk - Eschew locks in favour of DB-level atomicity --- synapse/handlers/delayed_events.py | 321 +++++++------ .../storage/databases/main/delayed_events.py | 430 +++++++++++------- .../main/delta/87/01_add_delayed_events.sql | 2 + tests/rest/client/test_rooms.py | 4 +- 4 files changed, 438 insertions(+), 319 deletions(-) diff --git a/synapse/handlers/delayed_events.py b/synapse/handlers/delayed_events.py index 1e57f7f31a1..0c6328efab0 100644 --- a/synapse/handlers/delayed_events.py +++ b/synapse/handlers/delayed_events.py @@ -17,29 +17,18 @@ # import logging -from contextlib import asynccontextmanager -from typing import ( - TYPE_CHECKING, - AsyncContextManager, - AsyncIterator, - Dict, - List, - Optional, - Set, - Tuple, -) - -import attr +from typing import TYPE_CHECKING, List, Optional, Set, Tuple from twisted.internet.interfaces import IDelayedCall from synapse.api.constants import EventTypes -from synapse.api.errors import NotFoundError, ShadowBanError +from synapse.api.errors import ShadowBanError from synapse.events import EventBase from synapse.logging.opentracing import set_tag from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage.databases.main.delayed_events import ( Delay, + DelayedEventDetails, DelayID, EventType, StateKey, @@ -54,7 +43,6 @@ UserID, create_requester, ) -from synapse.util.async_helpers import Linearizer, ReadWriteLock from synapse.util.events import generate_fake_event_id if TYPE_CHECKING: @@ -62,17 +50,6 @@ logger = logging.getLogger(__name__) -_STATE_LOCK_KEY = "STATE_LOCK_KEY" - - -@attr.s(slots=True, frozen=True, auto_attribs=True) -class _DelayedCallKey: - delay_id: DelayID - user_localpart: UserLocalpart - - def __str__(self) -> str: - return f"{self.user_localpart}:{self.delay_id}" - class DelayedEventsHandler: def __init__(self, hs: "HomeServer"): @@ -83,52 +60,40 @@ def __init__(self, hs: "HomeServer"): self._event_creation_handler = hs.get_event_creation_handler() self._room_member_handler = hs.get_room_member_handler() - self._delayed_calls: Dict[_DelayedCallKey, IDelayedCall] = {} - # This is for making delayed event actions atomic - self._linearizer = Linearizer("delayed_events_handler") - # This is to prevent running actions on 
delayed events removed due to state changes - self._state_lock = ReadWriteLock() + self._next_delayed_event_call: Optional[IDelayedCall] = None + + # TODO: Looks like these callbacks are run in background. Find a foreground one + hs.get_module_api().register_third_party_rules_callbacks( + on_new_event=self.on_new_event, + ) async def _schedule_db_events() -> None: # TODO: Sync all state first, so that affected delayed state events will be cancelled - events, remaining_timeout_delays = ( - await self._store.process_all_delayed_events(self._get_current_ts()) + + # Delayed events that are already marked as processed on startup might not have been + # sent properly on the last run of the server, so unmark them to send them again. + # Caveats: + # - This will double-send delayed events that successfully persisted, but failed to be + # removed from the DB table of delayed events. + # - This will interfere with workers that are in the act of processing delayed events. + # TODO: To avoid double-sending, scan the timeline to find which of these events were + # already sent. To do so, must store delay_ids in sent events to retrieve them later. + # TODO: To avoid interfering with workers, think of a way to distinguish between + # events being processed by a worker vs ones that got lost after a server crash. + await self._store.unprocess_delayed_events() + + events, next_send_ts = await self._store.process_timeout_delayed_events( + self._get_current_ts() ) - sent_state: Set[Tuple[RoomID, EventType, StateKey]] = set() - for ( - user_localpart, - room_id, - event_type, - state_key, - timestamp, - content, - ) in events: - if state_key is not None: - state_info = (room_id, event_type, state_key) - if state_info in sent_state: - continue - else: - state_info = None - try: - await self._send_event( - user_localpart, - room_id, - event_type, - state_key, - timestamp, - content, - ) - if state_info is not None: - sent_state.add(state_info) - except Exception: - logger.exception("Failed to send delayed event on startup") - sent_state.clear() - - for delay_id, user_localpart, relative_delay in remaining_timeout_delays: - self._schedule(delay_id, user_localpart, relative_delay) - - hs.get_module_api().register_third_party_rules_callbacks( - on_new_event=self.on_new_event, + + if next_send_ts: + self._schedule_next_at(next_send_ts) + + # Can send the events in background after having awaited on marking them as processed + run_as_background_process( + "_send_events", + self._send_events, + events, ) self._initialized_from_db = run_as_background_process( @@ -144,19 +109,14 @@ async def on_new_event( """ state_key = event.get_state_key() if state_key is not None: - async with self._get_state_context(): - for ( - removed_timeout_delay_id, - removed_timeout_delay_user_localpart, - ) in await self._store.remove_delayed_state_events( - event.room_id, - event.type, - state_key, - ): - self._unschedule( - removed_timeout_delay_id, - removed_timeout_delay_user_localpart, - ) + changed, next_send_ts = await self._store.cancel_delayed_state_events( + room_id=event.room_id, + event_type=event.type, + state_key=state_key, + ) + + if changed: + self._schedule_next_at_or_none(next_send_ts) async def add( self, @@ -170,11 +130,11 @@ async def add( delay: int, ) -> str: """ - Creates a new delayed event. + Creates a new delayed event and schedules its delivery. Args: requester: The requester of the delayed event, who will be its owner. - room_id: The room where the event should be sent. 
+ room_id: The ID of the room where the event should be sent to. event_type: The type of event to be sent. state_key: The state key of the event to be sent, or None if it is not a state event. origin_server_ts: The custom timestamp to send the event with. @@ -182,8 +142,10 @@ async def add( content: The content of the event to be sent. delay: How long (in milliseconds) to wait before automatically sending the event. - Returns: - The ID of the added delayed event. + Returns: The ID of the added delayed event. + + Raises: + SynapseError: if the delayed event fails validation checks. """ await self._request_ratelimiter.ratelimit(requester) await self._initialized_from_db @@ -201,11 +163,12 @@ async def add( ) ) - user_localpart = UserLocalpart(requester.user.localpart) - delay_id = await self._store.add_delayed_event( - user_localpart=user_localpart, - current_ts=self._get_current_ts(), - room_id=RoomID.from_string(room_id), + creation_ts = self._get_current_ts() + + delay_id, changed = await self._store.add_delayed_event( + user_localpart=requester.user.localpart, + creation_ts=creation_ts, + room_id=room_id, event_type=event_type, state_key=state_key, origin_server_ts=origin_server_ts, @@ -213,7 +176,8 @@ async def add( delay=delay, ) - self._schedule(delay_id, user_localpart, Delay(delay)) + if changed: + self._schedule_next_at(Timestamp(creation_ts + delay)) return delay_id @@ -229,13 +193,15 @@ async def cancel(self, requester: Requester, delay_id: str) -> None: NotFoundError: if no matching delayed event could be found. """ await self._request_ratelimiter.ratelimit(requester) + await self._initialized_from_db - delay_id = DelayID(delay_id) - user_localpart = UserLocalpart(requester.user.localpart) - async with self._get_delay_context(delay_id, user_localpart): - await self._store.remove_delayed_event(delay_id, user_localpart) + changed, next_send_ts = await self._store.cancel_delayed_event( + delay_id=delay_id, + user_localpart=requester.user.localpart, + ) - self._unschedule(delay_id, user_localpart) + if changed: + self._schedule_next_at_or_none(next_send_ts) async def restart(self, requester: Requester, delay_id: str) -> None: """ @@ -249,18 +215,16 @@ async def restart(self, requester: Requester, delay_id: str) -> None: NotFoundError: if no matching delayed event could be found. """ await self._request_ratelimiter.ratelimit(requester) + await self._initialized_from_db - delay_id = DelayID(delay_id) - user_localpart = UserLocalpart(requester.user.localpart) - async with self._get_delay_context(delay_id, user_localpart): - delay = await self._store.restart_delayed_event( - delay_id, - user_localpart, - self._get_current_ts(), - ) + changed, next_send_ts = await self._store.restart_delayed_event( + delay_id=delay_id, + user_localpart=requester.user.localpart, + current_ts=self._get_current_ts(), + ) - self._unschedule(delay_id, user_localpart) - self._schedule(delay_id, user_localpart, delay) + if changed: + self._schedule_next_at(next_send_ts) async def send(self, requester: Requester, delay_id: str) -> None: """ @@ -274,76 +238,102 @@ async def send(self, requester: Requester, delay_id: str) -> None: NotFoundError: if no matching delayed event could be found. 
""" await self._request_ratelimiter.ratelimit(requester) + await self._initialized_from_db - delay_id = DelayID(delay_id) - user_localpart = UserLocalpart(requester.user.localpart) - async with self._get_delay_context(delay_id, user_localpart): - args = await self._store.pop_delayed_event(delay_id, user_localpart) + event, changed, next_send_ts = await self._store.process_target_delayed_event( + delay_id=delay_id, + user_localpart=requester.user.localpart, + ) - self._unschedule(delay_id, user_localpart) - await self._send_event(user_localpart, *args) + if changed: + self._schedule_next_at_or_none(next_send_ts) - async def _send_on_timeout( - self, delay_id: DelayID, user_localpart: UserLocalpart - ) -> None: - del self._delayed_calls[_DelayedCallKey(delay_id, user_localpart)] + await self._send_event( + DelayID(delay_id), + UserLocalpart(requester.user.localpart), + *event, + ) - async with self._get_delay_context(delay_id, user_localpart): - try: - args = await self._store.pop_delayed_event(delay_id, user_localpart) - except NotFoundError: - logger.debug( - "delay_id %s for local user %s was removed from the DB before it timed out (or was always missing)", - delay_id, - user_localpart, - ) - return + async def _send_on_timeout(self) -> None: + self._next_delayed_event_call = None - await self._send_event(user_localpart, *args) + events, next_send_ts = await self._store.process_timeout_delayed_events( + self._get_current_ts() + ) - def _schedule( - self, - delay_id: DelayID, - user_localpart: UserLocalpart, - delay: Delay, - ) -> None: - assert delay > 0, "Clock.call_later doesn't support negative delays" - delay_sec = delay / 1000 + if next_send_ts: + self._schedule_next_at(next_send_ts) + + await self._send_events(events) - logger.info( - "Scheduling delayed event %s for local user %s to be sent in %.3fs", + async def _send_events(self, events: List[DelayedEventDetails]) -> None: + sent_state: Set[Tuple[RoomID, EventType, StateKey]] = set() + for ( delay_id, user_localpart, - delay_sec, - ) - - self._delayed_calls[_DelayedCallKey(delay_id, user_localpart)] = ( - self._clock.call_later( + room_id, + event_type, + state_key, + origin_server_ts, + content, + ) in events: + if state_key is not None: + state_info = (room_id, event_type, state_key) + if state_info in sent_state: + continue + else: + state_info = None + try: + # TODO: send in background if message event or non-conflicting state event + await self._send_event( + delay_id, + user_localpart, + room_id, + event_type, + state_key, + origin_server_ts, + content, + ) + if state_info is not None: + # Note that removal from the DB is done by self.on_new_event + sent_state.add(state_info) + except Exception: + logger.exception("Failed to send delayed event") + + def _schedule_next_at_or_none(self, next_send_ts: Optional[Timestamp]) -> None: + if next_send_ts is not None: + self._schedule_next_at(next_send_ts) + elif self._next_delayed_event_call is not None: + self._next_delayed_event_call.cancel() + self._next_delayed_event_call = None + + def _schedule_next_at(self, next_send_ts: Timestamp) -> None: + return self._schedule_next(self._get_delay_until(next_send_ts)) + + def _schedule_next(self, delay: Delay) -> None: + delay_sec = delay / 1000 if delay > 0 else 0 + + if self._next_delayed_event_call is None: + self._next_delayed_event_call = self._clock.call_later( delay_sec, run_as_background_process, "_send_on_timeout", self._send_on_timeout, - delay_id, - user_localpart, ) - ) - - def _unschedule(self, delay_id: DelayID, 
user_localpart: UserLocalpart) -> None: - delayed_call = self._delayed_calls.pop( - _DelayedCallKey(delay_id, user_localpart) - ) - self._clock.cancel_call_later(delayed_call) + else: + self._next_delayed_event_call.reset(delay_sec) async def get_all_for_user(self, requester: Requester) -> List[JsonDict]: """Return all pending delayed events requested by the given user.""" await self._request_ratelimiter.ratelimit(requester) await self._initialized_from_db return await self._store.get_all_delayed_events_for_user( - UserLocalpart(requester.user.localpart) + requester.user.localpart ) async def _send_event( self, + delay_id: DelayID, user_localpart: UserLocalpart, room_id: RoomID, event_type: EventType, @@ -396,22 +386,23 @@ async def _send_event( event_id = event.event_id except ShadowBanError: event_id = generate_fake_event_id() + finally: + # TODO: If this is a temporary error, retry. Otherwise, consider notifying clients of the failure + try: + await self._store.delete_processed_delayed_event( + delay_id, user_localpart + ) + except Exception: + logger.exception("Failed to delete processed delayed event") set_tag("event_id", event_id) def _get_current_ts(self) -> Timestamp: return Timestamp(self._clock.time_msec()) - @asynccontextmanager - async def _get_delay_context( - self, delay_id: DelayID, user_localpart: UserLocalpart - ) -> AsyncIterator[None]: - await self._initialized_from_db - # TODO: Use parenthesized context manager once the minimum supported Python version is 3.10 - async with self._state_lock.read(_STATE_LOCK_KEY), self._linearizer.queue( - _DelayedCallKey(delay_id, user_localpart) - ): - yield - - def _get_state_context(self) -> AsyncContextManager: - return self._state_lock.write(_STATE_LOCK_KEY) + def _get_delay_until(self, to_ts: Timestamp) -> Delay: + return _get_delay_between(self._get_current_ts(), to_ts) + + +def _get_delay_between(from_ts: Timestamp, to_ts: Timestamp) -> Delay: + return Delay(to_ts - from_ts) diff --git a/synapse/storage/databases/main/delayed_events.py b/synapse/storage/databases/main/delayed_events.py index 253b5363364..0c7da982876 100644 --- a/synapse/storage/databases/main/delayed_events.py +++ b/synapse/storage/databases/main/delayed_events.py @@ -21,7 +21,7 @@ from synapse.api.errors import NotFoundError from synapse.storage._base import SQLBaseStore, db_to_json -from synapse.storage.database import LoggingTransaction, make_tuple_in_list_sql_clause +from synapse.storage.database import LoggingTransaction, StoreError from synapse.types import JsonDict, RoomID from synapse.util import json_encoder, stringutils as stringutils @@ -36,16 +36,9 @@ Delay = NewType("Delay", int) Timestamp = NewType("Timestamp", int) -DelayedPartialEvent = Tuple[ - RoomID, - EventType, - Optional[StateKey], - Optional[Timestamp], - JsonDict, -] - -# TODO: If a Tuple type hint can be extended, extend the above one -DelayedPartialEventWithUser = Tuple[ +# TODO: Maybe use attr class +DelayedEventDetails = Tuple[ + DelayID, UserLocalpart, RoomID, EventType, @@ -55,71 +48,93 @@ ] -# TODO: Try to support workers class DelayedEventsStore(SQLBaseStore): async def add_delayed_event( self, *, - user_localpart: UserLocalpart, - current_ts: Timestamp, - room_id: RoomID, + user_localpart: str, + creation_ts: Timestamp, + room_id: str, event_type: str, state_key: Optional[str], origin_server_ts: Optional[int], content: JsonDict, delay: int, - ) -> DelayID: + ) -> Tuple[DelayID, bool]: """ Inserts a new delayed event in the DB. 
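The point of the extra boolean in the new return value is that the insert and the "did the earliest send time move?" check happen inside one transaction, so the flag cannot race with concurrent writers. Reduced to its essentials (sqlite3 as a stand-in for Synapse's database API, with the toy schema noted in the comment):

```python
import sqlite3

# Assumed toy schema:
#   CREATE TABLE delayed_events (delay_id TEXT, send_ts BIGINT, is_processed BOOLEAN DEFAULT 0);


def add_delayed_event(conn: sqlite3.Connection, delay_id: str, send_ts: int) -> bool:
    """Insert one row and report whether it is now the soonest unprocessed event."""
    with conn:  # one transaction: the MIN() read and the INSERT are atomic
        (old_next,) = conn.execute(
            "SELECT MIN(send_ts) FROM delayed_events WHERE is_processed = 0"
        ).fetchone()
        conn.execute(
            "INSERT INTO delayed_events (delay_id, send_ts, is_processed)"
            " VALUES (?, ?, 0)",
            (delay_id, send_ts),
        )
        # The timer must move iff the table was empty or this event beat the old minimum.
        return old_next is None or send_ts < old_next
```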
- Returns: The generated ID assigned to the added delayed event. + Returns: The generated ID assigned to the added delayed event, + and whether the added delayed event is the next to be sent. """ delay_id = _generate_delay_id() - await self.db_pool.simple_insert( - table="delayed_events", - values={ - "delay_id": delay_id, - "user_localpart": user_localpart, - "delay": delay, - "send_ts": current_ts + delay, - "room_id": room_id.to_string(), - "event_type": event_type, - "state_key": state_key, - "origin_server_ts": origin_server_ts, - "content": json_encoder.encode(content), - }, - desc="add_delayed_event", + send_ts = creation_ts + delay + + def _add_delayed_event_txn(txn: LoggingTransaction) -> bool: + old_next_send_ts = self._get_next_delayed_event_send_ts_txn(txn) + + self.db_pool.simple_insert_txn( + txn, + table="delayed_events", + values={ + "delay_id": delay_id, + "user_localpart": user_localpart, + "delay": delay, + "send_ts": send_ts, + "room_id": room_id, + "event_type": event_type, + "state_key": state_key, + "origin_server_ts": origin_server_ts, + "content": json_encoder.encode(content), + }, + ) + + return old_next_send_ts is None or send_ts < old_next_send_ts + + changed = await self.db_pool.runInteraction( + "add_delayed_event", _add_delayed_event_txn ) - return delay_id + + return delay_id, changed async def restart_delayed_event( self, - delay_id: DelayID, - user_localpart: UserLocalpart, + *, + delay_id: str, + user_localpart: str, current_ts: Timestamp, - ) -> Delay: + ) -> Tuple[bool, Timestamp]: """ - Restarts the send time of the matching delayed event. + Restarts the send time of the matching delayed event, + as long as it hasn't already been marked for processing. Args: delay_id: The ID of the delayed event to restart. user_localpart: The localpart of the delayed event's owner. current_ts: The current time, which will be used to calculate the new send time. - Returns: The delay at which the delayed event will be sent (unless it is reset again). + Returns: Whether the matching delayed event would have been the next to be sent, + and if so, what the next soonest send time is, if any. Raises: NotFoundError: if there is no matching delayed event. """ - def restart_txn(txn: LoggingTransaction) -> Delay: + def restart_delayed_event_txn( + txn: LoggingTransaction, + ) -> Tuple[bool, Timestamp]: + old_next_send_ts = self._get_next_delayed_event_send_ts_txn(txn) + if old_next_send_ts is None: + raise NotFoundError("Delayed event not found") + if self.database_engine.supports_returning: txn.execute( """ UPDATE delayed_events SET send_ts = ? + delay WHERE delay_id = ? AND user_localpart = ? 
- RETURNING delay + AND NOT is_processed + RETURNING send_ts """, ( current_ts, @@ -130,11 +145,13 @@ def restart_txn(txn: LoggingTransaction) -> Delay: row = txn.fetchone() if row is None: raise NotFoundError("Delayed event not found") - return Delay(row[0]) + + restarted_send_ts = row[0] else: keyvalues = { "delay_id": delay_id, "user_localpart": user_localpart, + "is_processed": False, } delay = self.db_pool.simple_select_one_onecol_txn( txn, @@ -145,26 +162,36 @@ def restart_txn(txn: LoggingTransaction) -> Delay: ) if delay is None: raise NotFoundError("Delayed event not found") + + restarted_send_ts = current_ts + delay self.db_pool.simple_update_one_txn( txn, table="delayed_events", keyvalues=keyvalues, - updatevalues={"send_ts": current_ts + delay}, + updatevalues={"send_ts": restarted_send_ts}, ) - return Delay(delay) - return await self.db_pool.runInteraction("restart_delayed_event", restart_txn) + new_next_send_ts = self._get_next_delayed_event_send_ts_txn(txn) + assert new_next_send_ts is not None + return new_next_send_ts < old_next_send_ts, new_next_send_ts + + return await self.db_pool.runInteraction( + "restart_delayed_event", restart_delayed_event_txn + ) async def get_all_delayed_events_for_user( self, - user_localpart: UserLocalpart, + user_localpart: str, ) -> List[JsonDict]: """Returns all pending delayed events owned by the given user.""" # TODO: Store and return "transaction_id" # TODO: Support Pagination stream API ("next_batch" field) rows = await self.db_pool.simple_select_list( table="delayed_events", - keyvalues={"user_localpart": user_localpart}, + keyvalues={ + "user_localpart": user_localpart, + "is_processed": False, + }, retcols=( "delay_id", "room_id", @@ -189,25 +216,25 @@ async def get_all_delayed_events_for_user( for row in rows ] - async def process_all_delayed_events(self, current_ts: Timestamp) -> Tuple[ - List[DelayedPartialEventWithUser], - List[Tuple[DelayID, UserLocalpart, Delay]], + async def process_timeout_delayed_events(self, current_ts: Timestamp) -> Tuple[ + List[DelayedEventDetails], + Optional[Timestamp], ]: """ - Pops all delayed events that should have timed out prior to the provided time, - and returns all remaining timeout delayed events along with - how much later from the provided time they should time out at. + Marks for processing all delayed events that should have been sent prior to the provided time + that haven't already been marked as such. - Does not return any delayed events that got removed but not sent, as this is - meant to be called on startup before any delayed events have been scheduled. + Returns: The details of all newly-processed delayed events, + and the send time of the next delayed event to be sent, if any. """ - def process_all_delays_txn(txn: LoggingTransaction) -> Tuple[ - List[DelayedPartialEventWithUser], - List[Tuple[DelayID, UserLocalpart, Delay]], + def process_timeout_delayed_events_txn(txn: LoggingTransaction) -> Tuple[ + List[DelayedEventDetails], + Optional[Timestamp], ]: sql_cols = ", ".join( ( + "delay_id", "user_localpart", "room_id", "event_type", @@ -216,181 +243,280 @@ def process_all_delays_txn(txn: LoggingTransaction) -> Tuple[ "content", ) ) - sql_from = "FROM delayed_events WHERE send_ts <= ?" + sql_update = "UPDATE delayed_events SET is_processed = TRUE" + sql_where = "WHERE send_ts <= ? 
AND NOT is_processed" + sql_args = (current_ts,) sql_order = "ORDER BY send_ts" if self.database_engine.supports_returning: txn.execute( f""" WITH events_to_send AS ( - DELETE {sql_from} RETURNING * + {sql_update} {sql_where} RETURNING * ) SELECT {sql_cols} FROM events_to_send {sql_order} """, - (current_ts,), + sql_args, ) rows = txn.fetchall() else: txn.execute( - f"SELECT {sql_cols}, delay_id {sql_from} {sql_order}", (current_ts,) + f"SELECT {sql_cols} FROM delayed_events {sql_where} {sql_order}", + sql_args, ) rows = txn.fetchall() - sql_key_clause, sql_key_args = make_tuple_in_list_sql_clause( - self.database_engine, - ("delay_id", "user_localpart"), - tuple((row[-1], row[0]) for row in rows), - ) - txn.execute( - f"DELETE from delayed_events WHERE {sql_key_clause}", sql_key_args - ) - events = [ - ( - UserLocalpart(row[0]), - RoomID.from_string(row[1]), - EventType(row[2]), - StateKey(row[3]) if row[3] is not None else None, - Timestamp(row[4]) if row[4] is not None else None, - db_to_json(row[5]), - ) - for row in rows - ] + txn.execute(f"{sql_update} {sql_where}", sql_args) + assert txn.rowcount == len(rows) - txn.execute( - """ - SELECT - delay_id, - user_localpart, - send_ts - ? AS relative_delay - FROM delayed_events - """, - (current_ts,), - ) - remaining_timeout_delays = [ + events = [ ( DelayID(row[0]), UserLocalpart(row[1]), - Delay(row[2]), + RoomID.from_string(row[2]), + EventType(row[3]), + StateKey(row[4]) if row[4] is not None else None, + Timestamp(row[5]) if row[5] is not None else None, + db_to_json(row[6]), ) - for row in txn + for row in rows ] - return events, remaining_timeout_delays + next_send_ts = self._get_next_delayed_event_send_ts_txn(txn) + return events, next_send_ts return await self.db_pool.runInteraction( - "process_all_delayed_events", process_all_delays_txn + "process_timeout_delayed_events", process_timeout_delayed_events_txn ) - async def pop_delayed_event( + async def process_target_delayed_event( self, - delay_id: DelayID, - user_localpart: UserLocalpart, - ) -> DelayedPartialEvent: + *, + delay_id: str, + user_localpart: str, + ) -> Tuple[ + Tuple[ + RoomID, + EventType, + Optional[StateKey], + Optional[Timestamp], + JsonDict, + ], + bool, + Optional[Timestamp], + ]: """ - Gets the partial event of the matching delayed event, and remove it from the DB. + Marks for processing the matching delayed event, regardless of its timeout time, + as long as it has not already been marked as such. + + Args: + delay_id: The ID of the delayed event to restart. + user_localpart: The localpart of the delayed event's owner. - Returns: - The partial event to send for the matching delayed event. + Returns: The details of the matching delayed event, + whether the matching delayed event would have been the next to be sent, + and the send time of the next delayed event to be sent, if any. Raises: NotFoundError: if there is no matching delayed event. """ - def pop_event_txn(txn: LoggingTransaction) -> DelayedPartialEvent: + def process_target_delayed_event_txn(txn: LoggingTransaction) -> Tuple[ + Tuple[ + RoomID, + EventType, + Optional[StateKey], + Optional[Timestamp], + JsonDict, + ], + bool, + Optional[Timestamp], + ]: sql_cols = ", ".join( ( "room_id", "event_type", "state_key", "origin_server_ts", + "send_ts", "content", ) ) - sql_from = "FROM delayed_events WHERE delay_id = ? AND user_localpart = ?" + sql_update = "UPDATE delayed_events SET is_processed = TRUE" + sql_where = "WHERE delay_id = ? AND user_localpart = ? 
AND NOT is_processed" + sql_args = (delay_id, user_localpart) txn.execute( ( - f"DELETE {sql_from} RETURNING {sql_cols}" + f"{sql_update} RETURNING {sql_cols}" if self.database_engine.supports_returning - else f"SELECT {sql_cols} {sql_from}" + else f"SELECT {sql_cols} FROM delayed_events {sql_where}" ), - (delay_id, user_localpart), + sql_args, ) row = txn.fetchone() if row is None: raise NotFoundError("Delayed event not found") elif not self.database_engine.supports_returning: - txn.execute(f"DELETE {sql_from}") + txn.execute(f"{sql_update} {sql_where}", sql_args) assert txn.rowcount == 1 - return ( + send_ts = Timestamp(row[4]) + event = ( RoomID.from_string(row[0]), EventType(row[1]), StateKey(row[2]) if row[2] is not None else None, Timestamp(row[3]) if row[3] is not None else None, - db_to_json(row[4]), + db_to_json(row[5]), ) - return await self.db_pool.runInteraction("pop_delayed_event", pop_event_txn) + next_send_ts = self._get_next_delayed_event_send_ts_txn(txn) + return event, next_send_ts != send_ts, next_send_ts - async def remove_delayed_event( + return await self.db_pool.runInteraction( + "process_target_delayed_event", process_target_delayed_event_txn + ) + + async def cancel_delayed_event( self, - delay_id: DelayID, - user_localpart: UserLocalpart, - ) -> None: + *, + delay_id: str, + user_localpart: str, + ) -> Tuple[bool, Optional[Timestamp]]: """ - Removes the matching delayed event. + Cancels the matching delayed event, i.e. remove it as long as it hasn't been processed. + + Args: + delay_id: The ID of the delayed event to restart. + user_localpart: The localpart of the delayed event's owner. + + Returns: Whether the matching delayed event would have been the next to be sent, + and if so, what the next soonest send time is, if any. Raises: NotFoundError: if there is no matching delayed event. """ - await self.db_pool.simple_delete( - table="delayed_events", - keyvalues={ - "delay_id": delay_id, - "user_localpart": user_localpart, - }, - desc="remove_delayed_event", + + def cancel_delayed_event_txn( + txn: LoggingTransaction, + ) -> Tuple[bool, Optional[Timestamp]]: + old_next_send_ts = self._get_next_delayed_event_send_ts_txn(txn) + if old_next_send_ts is None: + raise NotFoundError("Delayed event not found") + + try: + self.db_pool.simple_delete_one_txn( + txn, + table="delayed_events", + keyvalues={ + "delay_id": delay_id, + "user_localpart": user_localpart, + "is_processed": False, + }, + ) + except StoreError: + if txn.rowcount == 0: + raise NotFoundError("Delayed event not found") + else: + raise + + new_next_send_ts = self._get_next_delayed_event_send_ts_txn(txn) + return new_next_send_ts != old_next_send_ts, new_next_send_ts + + return await self.db_pool.runInteraction( + "cancel_delayed_event", cancel_delayed_event_txn ) - async def remove_delayed_state_events( + async def cancel_delayed_state_events( self, + *, room_id: str, event_type: str, state_key: str, - ) -> List[Tuple[DelayID, UserLocalpart]]: + ) -> Tuple[bool, Optional[Timestamp]]: """ - Removes all matching delayed state events from the DB. + Cancels all matching delayed state events, i.e. remove them as long as they haven't been processed. - Returns: - The ID & owner of every removed delayed event. + Returns: Whether any of the matching delayed events would have been the next to be sent, + and if so, what the next soonest send time is, if any. 
""" - def remove_state_events_txn(txn: LoggingTransaction) -> List[Tuple]: - sql_cols = ", ".join( - ( - "delay_id", - "user_localpart", + def cancel_delayed_state_events_txn( + txn: LoggingTransaction, + ) -> Tuple[bool, Optional[Timestamp]]: + old_next_send_ts = self._get_next_delayed_event_send_ts_txn(txn) + + if ( + self.db_pool.simple_delete_txn( + txn, + table="delayed_events", + keyvalues={ + "room_id": room_id, + "event_type": event_type, + "state_key": state_key, + "is_processed": False, + }, ) - ) - sql_from = ( - "FROM delayed_events " - "WHERE room_id = ? AND event_type = ? AND state_key = ?" - ) - sql_args = (room_id, event_type, state_key) - if self.database_engine.supports_returning: - txn.execute(f"DELETE {sql_from} RETURNING {sql_cols}", sql_args) - rows = txn.fetchall() + > 0 + ): + assert old_next_send_ts is not None + new_next_send_ts = self._get_next_delayed_event_send_ts_txn(txn) + return new_next_send_ts != old_next_send_ts, new_next_send_ts else: - txn.execute(f"SELECT {sql_cols} {sql_from}", sql_args) - rows = txn.fetchall() - txn.execute(f"DELETE {sql_from}") - return [ - ( - DelayID(row[0]), - UserLocalpart(row[1]), - ) - for row in rows - ] + return False, None + + return await self.db_pool.runInteraction( + "cancel_delayed_state_events", cancel_delayed_state_events_txn + ) + + async def delete_processed_delayed_event( + self, + delay_id: DelayID, + user_localpart: UserLocalpart, + ) -> None: + """ + Delete the matching delayed event, as long as it has been marked as processed. + + Throws: + StoreError: if there is no matching delayed event, or if it has not yet been processed. + """ + return await self.db_pool.simple_delete_one( + table="delayed_events", + keyvalues={ + "delay_id": delay_id, + "user_localpart": user_localpart, + "is_processed": True, + }, + desc="delete_processed_delayed_event", + ) + + async def unprocess_delayed_events(self) -> None: + """ + Unmark all delayed events for processing. + """ + await self.db_pool.simple_update( + table="delayed_events", + keyvalues={"is_processed": True}, + updatevalues={"is_processed": False}, + desc="unprocess_delayed_events", + ) + async def get_next_delayed_event_send_ts(self) -> Optional[Timestamp]: + """ + Returns the send time of the next delayed event to be sent, if any. 
+ """ return await self.db_pool.runInteraction( - "remove_delayed_state_events", remove_state_events_txn + "get_next_delayed_event_send_ts", + self._get_next_delayed_event_send_ts_txn, + db_autocommit=True, + ) + + def _get_next_delayed_event_send_ts_txn( + self, txn: LoggingTransaction + ) -> Optional[Timestamp]: + result = self.db_pool.simple_select_one_onecol_txn( + txn, + table="delayed_events", + keyvalues={"is_processed": False}, + retcol="MIN(send_ts)", + allow_none=True, ) + return Timestamp(result) if result is not None else None def _generate_delay_id() -> DelayID: diff --git a/synapse/storage/schema/main/delta/87/01_add_delayed_events.sql b/synapse/storage/schema/main/delta/87/01_add_delayed_events.sql index 24ba5b390b3..5996bcdd0c7 100644 --- a/synapse/storage/schema/main/delta/87/01_add_delayed_events.sql +++ b/synapse/storage/schema/main/delta/87/01_add_delayed_events.sql @@ -21,8 +21,10 @@ CREATE TABLE delayed_events ( state_key TEXT, origin_server_ts BIGINT, content bytea NOT NULL, + is_processed BOOLEAN NOT NULL DEFAULT FALSE, PRIMARY KEY (user_localpart, delay_id) ); CREATE INDEX delayed_events_send_ts ON delayed_events (send_ts); +CREATE INDEX delayed_events_is_processed ON delayed_events (is_processed); CREATE INDEX delayed_events_room_state_event_idx ON delayed_events (room_id, event_type, state_key) WHERE state_key IS NOT NULL; diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py index 7f3d74ad6cc..9a1bae3d70b 100644 --- a/tests/rest/client/test_rooms.py +++ b/tests/rest/client/test_rooms.py @@ -742,7 +742,7 @@ def test_post_room_no_keys(self) -> None: self.assertEqual(HTTPStatus.OK, channel.code, channel.result) self.assertTrue("room_id" in channel.json_body) assert channel.resource_usage is not None - self.assertEqual(33, channel.resource_usage.db_txn_count) + self.assertEqual(38, channel.resource_usage.db_txn_count) def test_post_room_initial_state(self) -> None: # POST with initial_state config key, expect new room id @@ -755,7 +755,7 @@ def test_post_room_initial_state(self) -> None: self.assertEqual(HTTPStatus.OK, channel.code, channel.result) self.assertTrue("room_id" in channel.json_body) assert channel.resource_usage is not None - self.assertEqual(35, channel.resource_usage.db_txn_count) + self.assertEqual(40, channel.resource_usage.db_txn_count) def test_post_room_visibility_key(self) -> None: # POST with visibility config key, expect new room id From 2aed40b6c0fd7e1153e3ece4f7d4e91e353ca6ec Mon Sep 17 00:00:00 2001 From: Andrew Ferrazzutti Date: Wed, 28 Aug 2024 04:58:25 -0400 Subject: [PATCH 46/79] Set delayed event origin time to its send time --- synapse/handlers/delayed_events.py | 6 ++---- synapse/storage/databases/main/delayed_events.py | 13 +++++++------ 2 files changed, 9 insertions(+), 10 deletions(-) diff --git a/synapse/handlers/delayed_events.py b/synapse/handlers/delayed_events.py index 0c6328efab0..54cdd03d7e2 100644 --- a/synapse/handlers/delayed_events.py +++ b/synapse/handlers/delayed_events.py @@ -338,7 +338,7 @@ async def _send_event( room_id: RoomID, event_type: EventType, state_key: Optional[StateKey], - origin_server_ts: Optional[Timestamp], + origin_server_ts: Timestamp, content: JsonDict, txn_id: Optional[str] = None, ) -> None: @@ -367,14 +367,12 @@ async def _send_event( "content": content, "room_id": room_id.to_string(), "sender": user_id_str, + "origin_server_ts": origin_server_ts, } if state_key is not None: event_dict["state_key"] = state_key - if origin_server_ts is not None: - 
event_dict["origin_server_ts"] = origin_server_ts - ( event, _, diff --git a/synapse/storage/databases/main/delayed_events.py b/synapse/storage/databases/main/delayed_events.py index 0c7da982876..e6400a2a82d 100644 --- a/synapse/storage/databases/main/delayed_events.py +++ b/synapse/storage/databases/main/delayed_events.py @@ -43,7 +43,7 @@ RoomID, EventType, Optional[StateKey], - Optional[Timestamp], + Timestamp, JsonDict, ] @@ -240,6 +240,7 @@ def process_timeout_delayed_events_txn(txn: LoggingTransaction) -> Tuple[ "event_type", "state_key", "origin_server_ts", + "send_ts", "content", ) ) @@ -273,8 +274,8 @@ def process_timeout_delayed_events_txn(txn: LoggingTransaction) -> Tuple[ RoomID.from_string(row[2]), EventType(row[3]), StateKey(row[4]) if row[4] is not None else None, - Timestamp(row[5]) if row[5] is not None else None, - db_to_json(row[6]), + Timestamp(row[5] if row[5] is not None else row[6]), + db_to_json(row[7]), ) for row in rows ] @@ -295,7 +296,7 @@ async def process_target_delayed_event( RoomID, EventType, Optional[StateKey], - Optional[Timestamp], + Timestamp, JsonDict, ], bool, @@ -322,7 +323,7 @@ def process_target_delayed_event_txn(txn: LoggingTransaction) -> Tuple[ RoomID, EventType, Optional[StateKey], - Optional[Timestamp], + Timestamp, JsonDict, ], bool, @@ -361,7 +362,7 @@ def process_target_delayed_event_txn(txn: LoggingTransaction) -> Tuple[ RoomID.from_string(row[0]), EventType(row[1]), StateKey(row[2]) if row[2] is not None else None, - Timestamp(row[3]) if row[3] is not None else None, + Timestamp(row[3]) if row[3] is not None else send_ts, db_to_json(row[5]), ) From 235c4322e7aa954306b9fb1df902191c00d4b28c Mon Sep 17 00:00:00 2001 From: Andrew Ferrazzutti Date: Wed, 28 Aug 2024 05:01:57 -0400 Subject: [PATCH 47/79] Save delayed event requester's device ID --- synapse/handlers/delayed_events.py | 8 ++++++++ synapse/storage/databases/main/delayed_events.py | 10 ++++++++++ .../schema/main/delta/87/01_add_delayed_events.sql | 1 + 3 files changed, 19 insertions(+) diff --git a/synapse/handlers/delayed_events.py b/synapse/handlers/delayed_events.py index 54cdd03d7e2..5e4ca6f4d51 100644 --- a/synapse/handlers/delayed_events.py +++ b/synapse/handlers/delayed_events.py @@ -30,6 +30,7 @@ Delay, DelayedEventDetails, DelayID, + DeviceID, EventType, StateKey, Timestamp, @@ -167,6 +168,7 @@ async def add( delay_id, changed = await self._store.add_delayed_event( user_localpart=requester.user.localpart, + device_id=requester.device_id, creation_ts=creation_ts, room_id=room_id, event_type=event_type, @@ -276,6 +278,7 @@ async def _send_events(self, events: List[DelayedEventDetails]) -> None: state_key, origin_server_ts, content, + device_id, ) in events: if state_key is not None: state_info = (room_id, event_type, state_key) @@ -293,6 +296,7 @@ async def _send_events(self, events: List[DelayedEventDetails]) -> None: state_key, origin_server_ts, content, + device_id, ) if state_info is not None: # Note that removal from the DB is done by self.on_new_event @@ -340,13 +344,17 @@ async def _send_event( state_key: Optional[StateKey], origin_server_ts: Timestamp, content: JsonDict, + device_id: Optional[DeviceID], txn_id: Optional[str] = None, ) -> None: user_id = UserID(user_localpart, self._config.server.server_name) user_id_str = user_id.to_string() + # Create a new requester from what data is currently available + # TODO: Consider storing the requester in the DB at add time and deserialize it here requester = create_requester( user_id, is_guest=await 
self._store.is_guest(user_id_str), + device_id=device_id, ) try: diff --git a/synapse/storage/databases/main/delayed_events.py b/synapse/storage/databases/main/delayed_events.py index e6400a2a82d..80d46e9ed17 100644 --- a/synapse/storage/databases/main/delayed_events.py +++ b/synapse/storage/databases/main/delayed_events.py @@ -30,6 +30,7 @@ DelayID = NewType("DelayID", str) UserLocalpart = NewType("UserLocalpart", str) +DeviceID = NewType("DeviceID", str) EventType = NewType("EventType", str) StateKey = NewType("StateKey", str) @@ -45,6 +46,7 @@ Optional[StateKey], Timestamp, JsonDict, + Optional[DeviceID], ] @@ -53,6 +55,7 @@ async def add_delayed_event( self, *, user_localpart: str, + device_id: Optional[str], creation_ts: Timestamp, room_id: str, event_type: str, @@ -79,6 +82,7 @@ def _add_delayed_event_txn(txn: LoggingTransaction) -> bool: values={ "delay_id": delay_id, "user_localpart": user_localpart, + "device_id": device_id, "delay": delay, "send_ts": send_ts, "room_id": room_id, @@ -242,6 +246,7 @@ def process_timeout_delayed_events_txn(txn: LoggingTransaction) -> Tuple[ "origin_server_ts", "send_ts", "content", + "device_id", ) ) sql_update = "UPDATE delayed_events SET is_processed = TRUE" @@ -276,6 +281,7 @@ def process_timeout_delayed_events_txn(txn: LoggingTransaction) -> Tuple[ StateKey(row[4]) if row[4] is not None else None, Timestamp(row[5] if row[5] is not None else row[6]), db_to_json(row[7]), + DeviceID(row[8]) if row[8] is not None else None, ) for row in rows ] @@ -298,6 +304,7 @@ async def process_target_delayed_event( Optional[StateKey], Timestamp, JsonDict, + Optional[DeviceID], ], bool, Optional[Timestamp], @@ -325,6 +332,7 @@ def process_target_delayed_event_txn(txn: LoggingTransaction) -> Tuple[ Optional[StateKey], Timestamp, JsonDict, + Optional[DeviceID], ], bool, Optional[Timestamp], @@ -337,6 +345,7 @@ def process_target_delayed_event_txn(txn: LoggingTransaction) -> Tuple[ "origin_server_ts", "send_ts", "content", + "device_id", ) ) sql_update = "UPDATE delayed_events SET is_processed = TRUE" @@ -364,6 +373,7 @@ def process_target_delayed_event_txn(txn: LoggingTransaction) -> Tuple[ StateKey(row[2]) if row[2] is not None else None, Timestamp(row[3]) if row[3] is not None else send_ts, db_to_json(row[5]), + DeviceID(row[6]) if row[6] is not None else None, ) next_send_ts = self._get_next_delayed_event_send_ts_txn(txn) diff --git a/synapse/storage/schema/main/delta/87/01_add_delayed_events.sql b/synapse/storage/schema/main/delta/87/01_add_delayed_events.sql index 5996bcdd0c7..129e74b5bdf 100644 --- a/synapse/storage/schema/main/delta/87/01_add_delayed_events.sql +++ b/synapse/storage/schema/main/delta/87/01_add_delayed_events.sql @@ -14,6 +14,7 @@ CREATE TABLE delayed_events ( delay_id TEXT NOT NULL, user_localpart TEXT NOT NULL, + device_id TEXT, delay BIGINT NOT NULL, send_ts BIGINT NOT NULL, room_id TEXT NOT NULL, From 798c79e7735ee52c458938d2aa10a001745c127c Mon Sep 17 00:00:00 2001 From: Andrew Ferrazzutti Date: Wed, 4 Sep 2024 09:20:47 -0400 Subject: [PATCH 48/79] Remove license headers on new files --- synapse/handlers/delayed_events.py | 18 ------------------ synapse/rest/client/delayed_events.py | 18 ------------------ .../storage/databases/main/delayed_events.py | 18 ------------------ .../main/delta/88/01_add_delayed_events.sql | 13 ------------- 4 files changed, 67 deletions(-) diff --git a/synapse/handlers/delayed_events.py b/synapse/handlers/delayed_events.py index 5e4ca6f4d51..cd9c37d3026 100644 --- a/synapse/handlers/delayed_events.py +++ 
b/synapse/handlers/delayed_events.py @@ -1,21 +1,3 @@ -# -# This file is licensed under the Affero General Public License (AGPL) version 3. -# -# Copyright (C) 2024 New Vector, Ltd -# -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU Affero General Public License as -# published by the Free Software Foundation, either version 3 of the -# License, or (at your option) any later version. -# -# See the GNU Affero General Public License for more details: -# . -# -# Originally licensed under the Apache License, Version 2.0: -# . -# -# - import logging from typing import TYPE_CHECKING, List, Optional, Set, Tuple diff --git a/synapse/rest/client/delayed_events.py b/synapse/rest/client/delayed_events.py index 8608ccebcc4..a386bcf2d6d 100644 --- a/synapse/rest/client/delayed_events.py +++ b/synapse/rest/client/delayed_events.py @@ -1,21 +1,3 @@ -# -# This file is licensed under the Affero General Public License (AGPL) version 3. -# -# Copyright (C) 2024 New Vector, Ltd -# -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU Affero General Public License as -# published by the Free Software Foundation, either version 3 of the -# License, or (at your option) any later version. -# -# See the GNU Affero General Public License for more details: -# . -# -# Originally licensed under the Apache License, Version 2.0: -# . -# -# - """ This module contains REST servlets to do with delayed events: /delayed_events/ """ import logging from enum import Enum diff --git a/synapse/storage/databases/main/delayed_events.py b/synapse/storage/databases/main/delayed_events.py index 80d46e9ed17..57bff50a78f 100644 --- a/synapse/storage/databases/main/delayed_events.py +++ b/synapse/storage/databases/main/delayed_events.py @@ -1,21 +1,3 @@ -# -# This file is licensed under the Affero General Public License (AGPL) version 3. -# -# Copyright (C) 2024 New Vector, Ltd -# -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU Affero General Public License as -# published by the Free Software Foundation, either version 3 of the -# License, or (at your option) any later version. -# -# See the GNU Affero General Public License for more details: -# . -# -# Originally licensed under the Apache License, Version 2.0: -# . -# -# - import logging from typing import List, NewType, Optional, Tuple diff --git a/synapse/storage/schema/main/delta/88/01_add_delayed_events.sql b/synapse/storage/schema/main/delta/88/01_add_delayed_events.sql index 129e74b5bdf..a807d93eb1a 100644 --- a/synapse/storage/schema/main/delta/88/01_add_delayed_events.sql +++ b/synapse/storage/schema/main/delta/88/01_add_delayed_events.sql @@ -1,16 +1,3 @@ --- --- This file is licensed under the Affero General Public License (AGPL) version 3. --- --- Copyright (C) 2024 New Vector, Ltd --- --- This program is free software: you can redistribute it and/or modify --- it under the terms of the GNU Affero General Public License as --- published by the Free Software Foundation, either version 3 of the --- License, or (at your option) any later version. --- --- See the GNU Affero General Public License for more details: --- . 
- CREATE TABLE delayed_events ( delay_id TEXT NOT NULL, user_localpart TEXT NOT NULL, From 186e55d8bb2a64edbe8669977db0e67dad2baa78 Mon Sep 17 00:00:00 2001 From: Andrew Ferrazzutti Date: Wed, 4 Sep 2024 09:27:10 -0400 Subject: [PATCH 49/79] Lint --- synapse/rest/client/delayed_events.py | 3 ++- synapse/storage/databases/main/delayed_events.py | 12 +++++++++--- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/synapse/rest/client/delayed_events.py b/synapse/rest/client/delayed_events.py index a386bcf2d6d..1461059b482 100644 --- a/synapse/rest/client/delayed_events.py +++ b/synapse/rest/client/delayed_events.py @@ -1,4 +1,5 @@ -""" This module contains REST servlets to do with delayed events: /delayed_events/ """ +"""This module contains REST servlets to do with delayed events: /delayed_events/""" + import logging from enum import Enum from http import HTTPStatus diff --git a/synapse/storage/databases/main/delayed_events.py b/synapse/storage/databases/main/delayed_events.py index 57bff50a78f..84d3bf42cc2 100644 --- a/synapse/storage/databases/main/delayed_events.py +++ b/synapse/storage/databases/main/delayed_events.py @@ -202,7 +202,9 @@ async def get_all_delayed_events_for_user( for row in rows ] - async def process_timeout_delayed_events(self, current_ts: Timestamp) -> Tuple[ + async def process_timeout_delayed_events( + self, current_ts: Timestamp + ) -> Tuple[ List[DelayedEventDetails], Optional[Timestamp], ]: @@ -214,7 +216,9 @@ async def process_timeout_delayed_events(self, current_ts: Timestamp) -> Tuple[ and the send time of the next delayed event to be sent, if any. """ - def process_timeout_delayed_events_txn(txn: LoggingTransaction) -> Tuple[ + def process_timeout_delayed_events_txn( + txn: LoggingTransaction, + ) -> Tuple[ List[DelayedEventDetails], Optional[Timestamp], ]: @@ -307,7 +311,9 @@ async def process_target_delayed_event( NotFoundError: if there is no matching delayed event. """ - def process_target_delayed_event_txn(txn: LoggingTransaction) -> Tuple[ + def process_target_delayed_event_txn( + txn: LoggingTransaction, + ) -> Tuple[ Tuple[ RoomID, EventType, From a3fbdd348d03775a7c179a7ade03ac44c61cfade Mon Sep 17 00:00:00 2001 From: Andrew Ferrazzutti Date: Mon, 9 Sep 2024 08:57:02 -0400 Subject: [PATCH 50/79] Update documentation --- docs/usage/configuration/config_documentation.md | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/docs/usage/configuration/config_documentation.md b/docs/usage/configuration/config_documentation.md index a332037cdc2..23612235aca 100644 --- a/docs/usage/configuration/config_documentation.md +++ b/docs/usage/configuration/config_documentation.md @@ -764,10 +764,11 @@ email: --- ### `max_event_delay_duration` -The maximum allowed duration by which sent events can be delayed, as per MSC4140. +The maximum allowed duration by which sent events can be delayed, as per +[MSC4140](https://github.com/matrix-org/matrix-spec-proposals/pull/4140). Must be a positive value if set. -Defaults to no duration, which disallows sending delayed events. +Defaults to no duration (`null`), which disallows sending delayed events. 
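This setting drives the per-request validation exercised by the unit tests later in this series: with no maximum configured, the delay parameter is rejected outright, and otherwise over-long or negative delays are refused individually. Roughly (a hypothetical helper, not Synapse's actual parsing code; the check order is an assumption, while the error codes are the ones asserted in the tests):

```python
from typing import Optional


def check_requested_delay(
    delay_ms: Optional[int], max_delay_ms: Optional[int]
) -> Optional[int]:
    if delay_ms is None:
        return None  # not a delayed event; send immediately
    if max_delay_ms is None:
        raise ValueError("M_MAX_DELAY_UNSUPPORTED")  # feature disabled
    if delay_ms > max_delay_ms:
        raise ValueError("M_MAX_DELAY_EXCEEDED")
    if delay_ms < 0:
        raise ValueError("M_INVALID_PARAM")  # per the negative-delay test
    return delay_ms
```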
Example configuration: ```yaml From ef7284ff3df7231d4e4ce7f1c82e617e2829e164 Mon Sep 17 00:00:00 2001 From: Andrew Ferrazzutti Date: Mon, 9 Sep 2024 08:59:38 -0400 Subject: [PATCH 51/79] Fix top-level comment --- synapse/rest/client/delayed_events.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/synapse/rest/client/delayed_events.py b/synapse/rest/client/delayed_events.py index 1461059b482..dd7e320b750 100644 --- a/synapse/rest/client/delayed_events.py +++ b/synapse/rest/client/delayed_events.py @@ -1,4 +1,4 @@ -"""This module contains REST servlets to do with delayed events: /delayed_events/""" +# This module contains REST servlets to do with delayed events: /delayed_events/ import logging from enum import Enum From 92d352c2aa27eb7d83ac440f67e46bbffdc8e3ee Mon Sep 17 00:00:00 2001 From: Andrew Ferrazzutti Date: Mon, 9 Sep 2024 09:00:43 -0400 Subject: [PATCH 52/79] Add docstring to helper function --- synapse/rest/client/room.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py index 0afabf02ec7..778256c4163 100644 --- a/synapse/rest/client/room.py +++ b/synapse/rest/client/room.py @@ -453,6 +453,20 @@ def _parse_request_delay( request: SynapseRequest, max_delay: Optional[int], ) -> Optional[int]: + """Parses from the request string the delay parameter for + delayed event requests, and checks it for correctness. + + Args: + request: the twisted HTTP request. + max_delay: the maximum allowed value of the delay parameter, + or None if no delay parameter is allowed. + Returns: + The value of the requested delay, or None if it was absent. + + Raises: + SynapseError: if the delay parameter is present and forbidden, + or if it exceeds the maximum allowed value. + """ delay = parse_integer(request, "org.matrix.msc4140.delay") if delay is None: return None From 0ab82f50b58ff1319d8900b7b4eaba2088477f0f Mon Sep 17 00:00:00 2001 From: Andrew Ferrazzutti Date: Mon, 9 Sep 2024 09:03:13 -0400 Subject: [PATCH 53/79] Fix unit tests --- tests/rest/client/test_rooms.py | 32 ++++++++++++++++++++++++++++++-- 1 file changed, 30 insertions(+), 2 deletions(-) diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py index 9a1bae3d70b..cff7702be01 100644 --- a/tests/rest/client/test_rooms.py +++ b/tests/rest/client/test_rooms.py @@ -2299,6 +2299,7 @@ class RoomDelayedEventTestCase(RoomBase): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.room_id = self.helper.create_room_as(self.user_id) + @unittest.override_config({"max_event_delay_duration": "24h"}) def test_send_delayed_invalid_event(self) -> None: """Test sending a delayed event with invalid content.""" channel = self.make_request( @@ -2310,6 +2311,7 @@ def test_send_delayed_invalid_event(self) -> None: {}, ) self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, channel.result) + self.assertNotIn("org.matrix.msc4140.errcode", channel.json_body) def test_delayed_event_unsupported_by_default(self) -> None: """Test that sending a delayed event is unsupported with the default config.""" @@ -2322,6 +2324,11 @@ def test_delayed_event_unsupported_by_default(self) -> None: {"body": "test", "msgtype": "m.text"}, ) self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, channel.result) + self.assertEqual( + "M_MAX_DELAY_UNSUPPORTED", + channel.json_body.get("org.matrix.msc4140.errcode"), + channel.json_body, + ) @unittest.override_config({"max_event_delay_duration": "1000"}) def 
test_delayed_event_exceeds_max_delay(self) -> None: @@ -2335,10 +2342,31 @@ def test_delayed_event_exceeds_max_delay(self) -> None: {"body": "test", "msgtype": "m.text"}, ) self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, channel.result) + self.assertEqual( + "M_MAX_DELAY_EXCEEDED", + channel.json_body.get("org.matrix.msc4140.errcode"), + channel.json_body, + ) + + @unittest.override_config({"max_event_delay_duration": "24h"}) + def test_delayed_event_with_negative_delay(self) -> None: + """Test that sending a delayed event fails if its delay is negative.""" + channel = self.make_request( + "PUT", + ( + "rooms/%s/send/m.room.message/mid1?org.matrix.msc4140.delay=-2000" + % self.room_id + ).encode("ascii"), + {"body": "test", "msgtype": "m.text"}, + ) + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, channel.result) + self.assertEqual( + Codes.INVALID_PARAM, channel.json_body["errcode"], channel.json_body + ) @unittest.override_config({"max_event_delay_duration": "24h"}) def test_send_delayed_message_event(self) -> None: - """Test sending a delayed event with invalid content.""" + """Test sending a valid delayed message event.""" channel = self.make_request( "PUT", ( @@ -2351,7 +2379,7 @@ def test_send_delayed_message_event(self) -> None: @unittest.override_config({"max_event_delay_duration": "24h"}) def test_send_delayed_state_event(self) -> None: - """Test sending a delayed event with invalid content.""" + """Test sending a valid delayed state event.""" channel = self.make_request( "PUT", ( From 90259228ee2d6bf5b1235cbe32668edaefa7f548 Mon Sep 17 00:00:00 2001 From: Andrew Ferrazzutti Date: Mon, 9 Sep 2024 09:10:31 -0400 Subject: [PATCH 54/79] Comment early match return on no delayed events --- synapse/storage/databases/main/delayed_events.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/synapse/storage/databases/main/delayed_events.py b/synapse/storage/databases/main/delayed_events.py index 84d3bf42cc2..8227aecd09a 100644 --- a/synapse/storage/databases/main/delayed_events.py +++ b/synapse/storage/databases/main/delayed_events.py @@ -111,6 +111,7 @@ def restart_delayed_event_txn( ) -> Tuple[bool, Timestamp]: old_next_send_ts = self._get_next_delayed_event_send_ts_txn(txn) if old_next_send_ts is None: + # Return early if there are no delayed events at all raise NotFoundError("Delayed event not found") if self.database_engine.supports_returning: @@ -396,6 +397,7 @@ def cancel_delayed_event_txn( ) -> Tuple[bool, Optional[Timestamp]]: old_next_send_ts = self._get_next_delayed_event_send_ts_txn(txn) if old_next_send_ts is None: + # Return early if there are no delayed events at all raise NotFoundError("Delayed event not found") try: From c3ad95d565edffff2a1399c6e40a54492751e5c7 Mon Sep 17 00:00:00 2001 From: Andrew Ferrazzutti Date: Mon, 9 Sep 2024 09:11:59 -0400 Subject: [PATCH 55/79] Pick a nit --- synapse/storage/databases/main/delayed_events.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/synapse/storage/databases/main/delayed_events.py b/synapse/storage/databases/main/delayed_events.py index 8227aecd09a..892a8e2e10f 100644 --- a/synapse/storage/databases/main/delayed_events.py +++ b/synapse/storage/databases/main/delayed_events.py @@ -458,8 +458,8 @@ def cancel_delayed_state_events_txn( assert old_next_send_ts is not None new_next_send_ts = self._get_next_delayed_event_send_ts_txn(txn) return new_next_send_ts != old_next_send_ts, new_next_send_ts - else: - return False, None + + return False, None return await self.db_pool.runInteraction( 
"cancel_delayed_state_events", cancel_delayed_state_events_txn From 1dbbb749bc9bbc22bac8153e414a0a48c18fcb2f Mon Sep 17 00:00:00 2001 From: Andrew Ferrazzutti Date: Mon, 9 Sep 2024 09:13:02 -0400 Subject: [PATCH 56/79] Remove TODO for something that's no longer needed --- synapse/handlers/delayed_events.py | 1 - 1 file changed, 1 deletion(-) diff --git a/synapse/handlers/delayed_events.py b/synapse/handlers/delayed_events.py index cd9c37d3026..9ec61587cb3 100644 --- a/synapse/handlers/delayed_events.py +++ b/synapse/handlers/delayed_events.py @@ -332,7 +332,6 @@ async def _send_event( user_id = UserID(user_localpart, self._config.server.server_name) user_id_str = user_id.to_string() # Create a new requester from what data is currently available - # TODO: Consider storing the requester in the DB at add time and deserialize it here requester = create_requester( user_id, is_guest=await self._store.is_guest(user_id_str), From 3e9f76f32defdadd246e6d8cdc3eb7bd1b204891 Mon Sep 17 00:00:00 2001 From: Andrew Ferrazzutti Date: Mon, 9 Sep 2024 09:24:59 -0400 Subject: [PATCH 57/79] Refactor DB logic for delayed event resetting --- synapse/handlers/delayed_events.py | 6 +- .../storage/databases/main/delayed_events.py | 124 +++++++++++------- 2 files changed, 78 insertions(+), 52 deletions(-) diff --git a/synapse/handlers/delayed_events.py b/synapse/handlers/delayed_events.py index 9ec61587cb3..d53dc1cd7bd 100644 --- a/synapse/handlers/delayed_events.py +++ b/synapse/handlers/delayed_events.py @@ -201,14 +201,14 @@ async def restart(self, requester: Requester, delay_id: str) -> None: await self._request_ratelimiter.ratelimit(requester) await self._initialized_from_db - changed, next_send_ts = await self._store.restart_delayed_event( + changed_next_send_ts = await self._store.restart_delayed_event( delay_id=delay_id, user_localpart=requester.user.localpart, current_ts=self._get_current_ts(), ) - if changed: - self._schedule_next_at(next_send_ts) + if changed_next_send_ts is not None: + self._schedule_next_at(changed_next_send_ts) async def send(self, requester: Requester, delay_id: str) -> None: """ diff --git a/synapse/storage/databases/main/delayed_events.py b/synapse/storage/databases/main/delayed_events.py index 892a8e2e10f..838f982bcc6 100644 --- a/synapse/storage/databases/main/delayed_events.py +++ b/synapse/storage/databases/main/delayed_events.py @@ -89,7 +89,7 @@ async def restart_delayed_event( delay_id: str, user_localpart: str, current_ts: Timestamp, - ) -> Tuple[bool, Timestamp]: + ) -> Optional[Timestamp]: """ Restarts the send time of the matching delayed event, as long as it hasn't already been marked for processing. @@ -99,8 +99,9 @@ async def restart_delayed_event( user_localpart: The localpart of the delayed event's owner. current_ts: The current time, which will be used to calculate the new send time. - Returns: Whether the matching delayed event would have been the next to be sent, - and if so, what the next soonest send time is, if any. + Returns: None if the matching delayed event would not have been the next to be sent; + otherwise, the send time of the next delayed event to be sent (which may be the + matching delayed event, or another one sent before it). Raises: NotFoundError: if there is no matching delayed event. 
@@ -108,59 +109,84 @@ async def restart_delayed_event( def restart_delayed_event_txn( txn: LoggingTransaction, - ) -> Tuple[bool, Timestamp]: - old_next_send_ts = self._get_next_delayed_event_send_ts_txn(txn) - if old_next_send_ts is None: + ) -> Optional[Timestamp]: + txn.execute( + """ + SELECT delay_id, user_localpart, send_ts + FROM delayed_events + WHERE NOT is_processed + ORDER BY send_ts + LIMIT 2 + """ + ) + if txn.rowcount == 0: # Return early if there are no delayed events at all raise NotFoundError("Delayed event not found") - if self.database_engine.supports_returning: - txn.execute( - """ - UPDATE delayed_events - SET send_ts = ? + delay - WHERE delay_id = ? AND user_localpart = ? - AND NOT is_processed - RETURNING send_ts - """, - ( - current_ts, - delay_id, - user_localpart, - ), - ) - row = txn.fetchone() - if row is None: - raise NotFoundError("Delayed event not found") - - restarted_send_ts = row[0] + if txn.rowcount == 1: + # The restarted event is the only event + changed = True + second_next_send_ts = None else: - keyvalues = { - "delay_id": delay_id, - "user_localpart": user_localpart, - "is_processed": False, - } - delay = self.db_pool.simple_select_one_onecol_txn( - txn, - table="delayed_events", - keyvalues=keyvalues, - retcol="delay", - allow_none=True, - ) - if delay is None: - raise NotFoundError("Delayed event not found") + next_event, second_next_event = txn.fetchmany(2) + if ( + # The restarted event is the next to be sent + delay_id == next_event[0] + and user_localpart == next_event[1] + # The next event to be sent is the only one to be sent at that time + and next_event[2] != second_next_event[2] + ): + changed = True + second_next_send_ts = Timestamp(second_next_event[2]) + else: + changed = False - restarted_send_ts = current_ts + delay - self.db_pool.simple_update_one_txn( - txn, - table="delayed_events", - keyvalues=keyvalues, - updatevalues={"send_ts": restarted_send_ts}, + sql_base = """ + UPDATE delayed_events + SET send_ts = ? + delay + WHERE delay_id = ? AND user_localpart = ? 
+ AND NOT is_processed + """ + if changed and self.database_engine.supports_returning: + sql_base += "RETURNING send_ts" + + txn.execute( + sql_base, + ( + current_ts, + delay_id, + user_localpart, + ), + ) + if txn.rowcount == 0: + raise NotFoundError("Delayed event not found") + + if changed: + if self.database_engine.supports_returning: + row = txn.fetchone() + assert row is not None + restarted_send_ts = Timestamp(row[0]) + else: + restarted_send_ts = Timestamp( + self.db_pool.simple_select_one_onecol_txn( + txn, + table="delayed_events", + keyvalues={ + "delay_id": delay_id, + "user_localpart": user_localpart, + "is_processed": False, + }, + retcol="send_ts", + ) + ) + + return ( + restarted_send_ts + if second_next_send_ts is None + else min(restarted_send_ts, second_next_send_ts) ) - new_next_send_ts = self._get_next_delayed_event_send_ts_txn(txn) - assert new_next_send_ts is not None - return new_next_send_ts < old_next_send_ts, new_next_send_ts + return None return await self.db_pool.runInteraction( "restart_delayed_event", restart_delayed_event_txn From d36c89f10e2d87c497fbd54533c0d7e421024cfa Mon Sep 17 00:00:00 2001 From: Andrew Ferrazzutti Date: Mon, 9 Sep 2024 16:43:21 -0400 Subject: [PATCH 58/79] Use streams to watch for state deltas --- synapse/handlers/delayed_events.py | 113 +++++++++++++++--- .../third_party_event_rules_callbacks.py | 5 +- .../storage/databases/main/delayed_events.py | 47 ++++++++ .../main/delta/88/01_add_delayed_events.sql | 10 ++ tests/rest/client/test_rooms.py | 4 +- 5 files changed, 154 insertions(+), 25 deletions(-) diff --git a/synapse/handlers/delayed_events.py b/synapse/handlers/delayed_events.py index d53dc1cd7bd..2908dad87af 100644 --- a/synapse/handlers/delayed_events.py +++ b/synapse/handlers/delayed_events.py @@ -5,8 +5,8 @@ from synapse.api.constants import EventTypes from synapse.api.errors import ShadowBanError -from synapse.events import EventBase from synapse.logging.opentracing import set_tag +from synapse.metrics import event_processing_positions from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage.databases.main.delayed_events import ( Delay, @@ -18,15 +18,16 @@ Timestamp, UserLocalpart, ) +from synapse.storage.databases.main.state_deltas import StateDelta from synapse.types import ( JsonDict, Requester, RoomID, - StateMap, UserID, create_requester, ) from synapse.util.events import generate_fake_event_id +from synapse.util.metrics import Measure if TYPE_CHECKING: from synapse.server import HomeServer @@ -37,6 +38,7 @@ class DelayedEventsHandler: def __init__(self, hs: "HomeServer"): self._store = hs.get_datastores().main + self._storage_controllers = hs.get_storage_controllers() self._config = hs.config self._clock = hs.get_clock() self._request_ratelimiter = hs.get_request_ratelimiter() @@ -45,14 +47,20 @@ def __init__(self, hs: "HomeServer"): self._next_delayed_event_call: Optional[IDelayedCall] = None - # TODO: Looks like these callbacks are run in background. 
Find a foreground one - hs.get_module_api().register_third_party_rules_callbacks( - on_new_event=self.on_new_event, - ) + # The current position in the current_state_delta stream + self._event_pos: Optional[int] = None - async def _schedule_db_events() -> None: - # TODO: Sync all state first, so that affected delayed state events will be cancelled + # Guard to ensure we only process event deltas one at a time + self._event_processing = False + + if hs.config.worker.run_background_tasks: + hs.get_notifier().add_replication_callback(self.notify_new_event) + # We kick this off to pick up outstanding work from before the last restart. + self._clock.call_later(0, self.notify_new_event) + + # TODO: Refactor or remove this + async def _schedule_db_events() -> None: # Delayed events that are already marked as processed on startup might not have been # sent properly on the last run of the server, so unmark them to send them again. # Caveats: @@ -83,19 +91,80 @@ async def _schedule_db_events() -> None: "_schedule_db_events", _schedule_db_events ) - async def on_new_event( - self, event: EventBase, _state_events: StateMap[EventBase] - ) -> None: + def notify_new_event(self) -> None: + """Called when there may be more event deltas to process""" + if self._event_processing: + return + + self._event_processing = True + + async def process() -> None: + try: + await self._unsafe_process_new_event() + finally: + self._event_processing = False + + run_as_background_process("delayed_events.notify_new_event", process) + + async def _unsafe_process_new_event(self) -> None: + # If self._event_pos is None then means we haven't fetched it from DB + if self._event_pos is None: + self._event_pos = await self._store.get_delayed_events_stream_pos() + room_max_stream_ordering = self._store.get_room_max_stream_ordering() + if self._event_pos > room_max_stream_ordering: + # apparently, we've processed more events than exist in the database! + # this can happen if events are removed with history purge or similar. + logger.warning( + "Event stream ordering appears to have gone backwards (%i -> %i): " + "rewinding delayed events processor", + self._event_pos, + room_max_stream_ordering, + ) + self._event_pos = room_max_stream_ordering + + # Loop round handling deltas until we're up to date + while True: + with Measure(self._clock, "delayed_events_delta"): + room_max_stream_ordering = self._store.get_room_max_stream_ordering() + if self._event_pos == room_max_stream_ordering: + return + + logger.debug( + "Processing delayed events %s->%s", + self._event_pos, + room_max_stream_ordering, + ) + ( + max_pos, + deltas, + ) = await self._storage_controllers.state.get_current_state_deltas( + self._event_pos, room_max_stream_ordering + ) + + logger.debug("Handling %d state deltas", len(deltas)) + await self._handle_state_deltas(deltas) + + self._event_pos = max_pos + + # Expose current event processing position to prometheus + event_processing_positions.labels("delayed_events").set(max_pos) + + await self._store.update_delayed_events_stream_pos(max_pos) + + async def _handle_state_deltas(self, deltas: List[StateDelta]) -> None: """ - Checks if a received event is a state event, and if so, - cancels any delayed events that target the same state. + Process current state deltas to cancel pending delayed events + that target the same state as any received state events. 
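The commit as a whole follows Synapse's usual catch-up-loop shape: persist a stream position, and on each notification advance it toward the head of the event stream in batches, so a restarted worker resumes where it left off. Stripped of metrics, clamping, and locking, the loop is roughly (an illustrative store interface, not the real one):

```python
async def catch_up(store, handle_deltas) -> None:
    pos = await store.get_delayed_events_stream_pos()
    while True:
        room_max = store.get_room_max_stream_ordering()
        if pos == room_max:
            return  # fully caught up; wait for the next notification
        # The store may hand back fewer deltas than requested; it reports
        # how far it actually got, and the loop resumes from there.
        new_pos, deltas = await store.get_current_state_deltas(pos, room_max)
        await handle_deltas(deltas)  # cancel delayed events made stale by new state
        pos = new_pos
        # Persisting the position makes the work resumable across restarts.
        await store.update_delayed_events_stream_pos(pos)
```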
""" - state_key = event.get_state_key() - if state_key is not None: + for delta in deltas: + logger.debug( + "Handling: %r %r, %s", delta.event_type, delta.state_key, delta.event_id + ) + changed, next_send_ts = await self._store.cancel_delayed_state_events( - room_id=event.room_id, - event_type=event.type, - state_key=state_key, + room_id=delta.room_id, + event_type=delta.event_type, + state_key=delta.state_key, ) if changed: @@ -281,11 +350,17 @@ async def _send_events(self, events: List[DelayedEventDetails]) -> None: device_id, ) if state_info is not None: - # Note that removal from the DB is done by self.on_new_event sent_state.add(state_info) except Exception: logger.exception("Failed to send delayed event") + for room_id, event_type, state_key in sent_state: + await self._store.delete_processed_delayed_state_events( + room_id=str(room_id), + event_type=event_type, + state_key=state_key, + ) + def _schedule_next_at_or_none(self, next_send_ts: Optional[Timestamp]) -> None: if next_send_ts is not None: self._schedule_next_at(next_send_ts) diff --git a/synapse/module_api/callbacks/third_party_event_rules_callbacks.py b/synapse/module_api/callbacks/third_party_event_rules_callbacks.py index f4cf8596a8a..9f7a04372de 100644 --- a/synapse/module_api/callbacks/third_party_event_rules_callbacks.py +++ b/synapse/module_api/callbacks/third_party_event_rules_callbacks.py @@ -426,10 +426,7 @@ async def on_new_event(self, event_id: str) -> None: if len(self._on_new_event_callbacks) == 0: return - event = await self.store.get_event(event_id, allow_none=True) - if not event: - logger.warning("Could not find event %s" % (event_id,)) - return + event = await self.store.get_event(event_id) # We *don't* want to wait for the full state here, because waiting for full # state will persist event, which in turn will call this method. diff --git a/synapse/storage/databases/main/delayed_events.py b/synapse/storage/databases/main/delayed_events.py index 838f982bcc6..0c5d52a6610 100644 --- a/synapse/storage/databases/main/delayed_events.py +++ b/synapse/storage/databases/main/delayed_events.py @@ -33,6 +33,32 @@ class DelayedEventsStore(SQLBaseStore): + async def get_delayed_events_stream_pos(self) -> int: + """ + Gets the stream position of the background process to watch for state events + that target the same piece of state as any pending delayed events. + """ + return await self.db_pool.simple_select_one_onecol( + table="delayed_events_stream_pos", + keyvalues={}, + retcol="stream_id", + desc="get_delayed_events_stream_pos", + ) + + async def update_delayed_events_stream_pos(self, stream_id: Optional[int]) -> None: + """ + Updates the stream position of the background process to watch for state events + that target the same piece of state as any pending delayed events. + + Must only be used by the worker running the background process. + """ + await self.db_pool.simple_update_one( + table="delayed_events_stream_pos", + keyvalues={}, + updatevalues={"stream_id": stream_id}, + desc="update_delayed_events_stream_pos", + ) + async def add_delayed_event( self, *, @@ -512,6 +538,27 @@ async def delete_processed_delayed_event( desc="delete_processed_delayed_event", ) + async def delete_processed_delayed_state_events( + self, + *, + room_id: str, + event_type: str, + state_key: str, + ) -> None: + """ + Delete the matching delayed state events that have been marked as processed. 
+ """ + await self.db_pool.simple_delete( + table="delayed_events", + keyvalues={ + "room_id": room_id, + "event_type": event_type, + "state_key": state_key, + "is_processed": True, + }, + desc="delete_processed_delayed_state_events", + ) + async def unprocess_delayed_events(self) -> None: """ Unmark all delayed events for processing. diff --git a/synapse/storage/schema/main/delta/88/01_add_delayed_events.sql b/synapse/storage/schema/main/delta/88/01_add_delayed_events.sql index a807d93eb1a..91a34583299 100644 --- a/synapse/storage/schema/main/delta/88/01_add_delayed_events.sql +++ b/synapse/storage/schema/main/delta/88/01_add_delayed_events.sql @@ -16,3 +16,13 @@ CREATE TABLE delayed_events ( CREATE INDEX delayed_events_send_ts ON delayed_events (send_ts); CREATE INDEX delayed_events_is_processed ON delayed_events (is_processed); CREATE INDEX delayed_events_room_state_event_idx ON delayed_events (room_id, event_type, state_key) WHERE state_key IS NOT NULL; + +CREATE TABLE delayed_events_stream_pos ( + Lock CHAR(1) NOT NULL DEFAULT 'X' UNIQUE, -- Makes sure this table only has one row. + stream_id BIGINT NOT NULL, + CHECK (Lock='X') +); + +INSERT INTO delayed_events_stream_pos ( + stream_id +) SELECT COALESCE(MAX(stream_ordering), 0) from events; diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py index cff7702be01..00be0051c61 100644 --- a/tests/rest/client/test_rooms.py +++ b/tests/rest/client/test_rooms.py @@ -742,7 +742,7 @@ def test_post_room_no_keys(self) -> None: self.assertEqual(HTTPStatus.OK, channel.code, channel.result) self.assertTrue("room_id" in channel.json_body) assert channel.resource_usage is not None - self.assertEqual(38, channel.resource_usage.db_txn_count) + self.assertEqual(33, channel.resource_usage.db_txn_count) def test_post_room_initial_state(self) -> None: # POST with initial_state config key, expect new room id @@ -755,7 +755,7 @@ def test_post_room_initial_state(self) -> None: self.assertEqual(HTTPStatus.OK, channel.code, channel.result) self.assertTrue("room_id" in channel.json_body) assert channel.resource_usage is not None - self.assertEqual(40, channel.resource_usage.db_txn_count) + self.assertEqual(35, channel.resource_usage.db_txn_count) def test_post_room_visibility_key(self) -> None: # POST with visibility config key, expect new room id From e0225eba21e11e503aa2a2ef73d85ada6fabf718 Mon Sep 17 00:00:00 2001 From: Andrew Ferrazzutti Date: Tue, 10 Sep 2024 09:12:44 -0400 Subject: [PATCH 59/79] Nitpick: rename inner function Remove leading underscore for consistency with other inner functions --- synapse/storage/databases/main/delayed_events.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/synapse/storage/databases/main/delayed_events.py b/synapse/storage/databases/main/delayed_events.py index 0c5d52a6610..bc532cbf5b2 100644 --- a/synapse/storage/databases/main/delayed_events.py +++ b/synapse/storage/databases/main/delayed_events.py @@ -81,7 +81,7 @@ async def add_delayed_event( delay_id = _generate_delay_id() send_ts = creation_ts + delay - def _add_delayed_event_txn(txn: LoggingTransaction) -> bool: + def add_delayed_event_txn(txn: LoggingTransaction) -> bool: old_next_send_ts = self._get_next_delayed_event_send_ts_txn(txn) self.db_pool.simple_insert_txn( @@ -104,7 +104,7 @@ def _add_delayed_event_txn(txn: LoggingTransaction) -> bool: return old_next_send_ts is None or send_ts < old_next_send_ts changed = await self.db_pool.runInteraction( - "add_delayed_event", _add_delayed_event_txn + 
"add_delayed_event", add_delayed_event_txn ) return delay_id, changed From e741c56bfeff866f0745133e2ea4ef78596e5b7c Mon Sep 17 00:00:00 2001 From: Andrew Ferrazzutti Date: Tue, 10 Sep 2024 11:25:48 -0400 Subject: [PATCH 60/79] Add/improve comments & logs --- synapse/handlers/delayed_events.py | 11 +++++++---- synapse/storage/databases/main/delayed_events.py | 2 +- .../schema/main/delta/88/01_add_delayed_events.sql | 2 ++ 3 files changed, 10 insertions(+), 5 deletions(-) diff --git a/synapse/handlers/delayed_events.py b/synapse/handlers/delayed_events.py index 2908dad87af..eb106388e64 100644 --- a/synapse/handlers/delayed_events.py +++ b/synapse/handlers/delayed_events.py @@ -92,7 +92,10 @@ async def _schedule_db_events() -> None: ) def notify_new_event(self) -> None: - """Called when there may be more event deltas to process""" + """ + Called when there may be more state event deltas to process, + which should cancel pending delayed events for the same state. + """ if self._event_processing: return @@ -107,7 +110,7 @@ async def process() -> None: run_as_background_process("delayed_events.notify_new_event", process) async def _unsafe_process_new_event(self) -> None: - # If self._event_pos is None then means we haven't fetched it from DB + # If self._event_pos is None then means we haven't fetched it from the DB yet if self._event_pos is None: self._event_pos = await self._store.get_delayed_events_stream_pos() room_max_stream_ordering = self._store.get_room_max_stream_ordering() @@ -141,7 +144,7 @@ async def _unsafe_process_new_event(self) -> None: self._event_pos, room_max_stream_ordering ) - logger.debug("Handling %d state deltas", len(deltas)) + logger.debug("Handling %d state deltas for delayed events processing", len(deltas)) await self._handle_state_deltas(deltas) self._event_pos = max_pos @@ -154,7 +157,7 @@ async def _unsafe_process_new_event(self) -> None: async def _handle_state_deltas(self, deltas: List[StateDelta]) -> None: """ Process current state deltas to cancel pending delayed events - that target the same state as any received state events. + that target the same state. """ for delta in deltas: logger.debug( diff --git a/synapse/storage/databases/main/delayed_events.py b/synapse/storage/databases/main/delayed_events.py index bc532cbf5b2..a98361d4f07 100644 --- a/synapse/storage/databases/main/delayed_events.py +++ b/synapse/storage/databases/main/delayed_events.py @@ -159,7 +159,7 @@ def restart_delayed_event_txn( # The restarted event is the next to be sent delay_id == next_event[0] and user_localpart == next_event[1] - # The next event to be sent is the only one to be sent at that time + # The next two events to be sent aren't scheduled at the same time and next_event[2] != second_next_event[2] ): changed = True diff --git a/synapse/storage/schema/main/delta/88/01_add_delayed_events.sql b/synapse/storage/schema/main/delta/88/01_add_delayed_events.sql index 91a34583299..55bfbc8ae7c 100644 --- a/synapse/storage/schema/main/delta/88/01_add_delayed_events.sql +++ b/synapse/storage/schema/main/delta/88/01_add_delayed_events.sql @@ -23,6 +23,8 @@ CREATE TABLE delayed_events_stream_pos ( CHECK (Lock='X') ); +-- Start processing events from the point this migration was run, rather +-- than the beginning of time. 
INSERT INTO delayed_events_stream_pos ( stream_id ) SELECT COALESCE(MAX(stream_ordering), 0) from events; From 2d79506a47e9cd3122ef62ded65566a05365390f Mon Sep 17 00:00:00 2001 From: Andrew Ferrazzutti Date: Tue, 10 Sep 2024 11:34:28 -0400 Subject: [PATCH 61/79] Lint --- synapse/handlers/delayed_events.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/synapse/handlers/delayed_events.py b/synapse/handlers/delayed_events.py index eb106388e64..f9ff302af9c 100644 --- a/synapse/handlers/delayed_events.py +++ b/synapse/handlers/delayed_events.py @@ -144,7 +144,10 @@ async def _unsafe_process_new_event(self) -> None: self._event_pos, room_max_stream_ordering ) - logger.debug("Handling %d state deltas for delayed events processing", len(deltas)) + logger.debug( + "Handling %d state deltas for delayed events processing", + len(deltas), + ) await self._handle_state_deltas(deltas) self._event_pos = max_pos From e41b5a1d74fc14e1746ff7a5d5a47cd255193514 Mon Sep 17 00:00:00 2001 From: Andrew Ferrazzutti Date: Wed, 11 Sep 2024 00:42:33 -0400 Subject: [PATCH 62/79] Put retrieved delayed events in field for GET This is to match what is specified in the MSC. Also add a unit test for this. --- synapse/rest/client/delayed_events.py | 8 +++++--- tests/rest/client/test_delayed_events.py | 26 ++++++++++++++++++++++++ 2 files changed, 31 insertions(+), 3 deletions(-) create mode 100644 tests/rest/client/test_delayed_events.py diff --git a/synapse/rest/client/delayed_events.py b/synapse/rest/client/delayed_events.py index dd7e320b750..5c2dcdc07d1 100644 --- a/synapse/rest/client/delayed_events.py +++ b/synapse/rest/client/delayed_events.py @@ -83,11 +83,13 @@ def __init__(self, hs: "HomeServer"): self.auth = hs.get_auth() self.delayed_events_handler = hs.get_delayed_events_handler() - async def on_GET(self, request: SynapseRequest) -> Tuple[int, List[JsonDict]]: + async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) # TODO: Support Pagination stream API ("from" query parameter) - data = await self.delayed_events_handler.get_all_for_user(requester) - return 200, data + delayed_events = await self.delayed_events_handler.get_all_for_user(requester) + + ret = {"delayed_events": delayed_events} + return 200, ret def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: diff --git a/tests/rest/client/test_delayed_events.py b/tests/rest/client/test_delayed_events.py new file mode 100644 index 00000000000..68a0ddcc66b --- /dev/null +++ b/tests/rest/client/test_delayed_events.py @@ -0,0 +1,26 @@ +"""Tests REST events for /delayed_events paths.""" + +from http import HTTPStatus + +from tests.unittest import HomeserverTestCase + +from synapse.rest.client import delayed_events + +PATH_PREFIX = b"/_matrix/client/unstable/org.matrix.msc4140/delayed_events" + + +class DelayedEventsTestCase(HomeserverTestCase): + """Tests getting and managing delayed events.""" + + servlets = [delayed_events.register_servlets] + + user_id = "@sid1:red" + + def test_get_delayed_events(self) -> None: + channel = self.make_request("GET", PATH_PREFIX) + self.assertEqual(HTTPStatus.OK, channel.code, channel.result) + self.assertEqual( + [], + channel.json_body.get("delayed_events"), + channel.json_body, + ) From 94048f72b939d0d1825049867ede8e791788022b Mon Sep 17 00:00:00 2001 From: Andrew Ferrazzutti Date: Wed, 11 Sep 2024 09:49:00 -0400 Subject: [PATCH 63/79] Lint imports --- synapse/rest/client/delayed_events.py | 2 +- 
tests/rest/client/test_delayed_events.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/synapse/rest/client/delayed_events.py b/synapse/rest/client/delayed_events.py index 5c2dcdc07d1..7bb919c5cbf 100644 --- a/synapse/rest/client/delayed_events.py +++ b/synapse/rest/client/delayed_events.py @@ -3,7 +3,7 @@ import logging from enum import Enum from http import HTTPStatus -from typing import TYPE_CHECKING, List, Tuple +from typing import TYPE_CHECKING, Tuple from synapse.api.errors import Codes, SynapseError from synapse.http.server import HttpServer diff --git a/tests/rest/client/test_delayed_events.py b/tests/rest/client/test_delayed_events.py index 68a0ddcc66b..3ff6a20a1cf 100644 --- a/tests/rest/client/test_delayed_events.py +++ b/tests/rest/client/test_delayed_events.py @@ -2,10 +2,10 @@ from http import HTTPStatus -from tests.unittest import HomeserverTestCase - from synapse.rest.client import delayed_events +from tests.unittest import HomeserverTestCase + PATH_PREFIX = b"/_matrix/client/unstable/org.matrix.msc4140/delayed_events" From 8e3df619577600d4390e2b0b1c52c6824c6451a4 Mon Sep 17 00:00:00 2001 From: Andrew Ferrazzutti Date: Wed, 11 Sep 2024 09:49:26 -0400 Subject: [PATCH 64/79] Don't use data-modifying CTE in WITH for sqlite Also don't try to use ordered RETURNING rows for sqlite either --- synapse/storage/databases/main/delayed_events.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/synapse/storage/databases/main/delayed_events.py b/synapse/storage/databases/main/delayed_events.py index a98361d4f07..1062ddb622b 100644 --- a/synapse/storage/databases/main/delayed_events.py +++ b/synapse/storage/databases/main/delayed_events.py @@ -4,6 +4,7 @@ from synapse.api.errors import NotFoundError from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage.database import LoggingTransaction, StoreError +from synapse.storage.engines import PostgresEngine from synapse.types import JsonDict, RoomID from synapse.util import json_encoder, stringutils as stringutils @@ -292,7 +293,13 @@ def process_timeout_delayed_events_txn( sql_where = "WHERE send_ts <= ? 
AND NOT is_processed" sql_args = (current_ts,) sql_order = "ORDER BY send_ts" - if self.database_engine.supports_returning: + if isinstance(self.database_engine, PostgresEngine): + # Do this only in Postgres because: + # - SQLite's RETURNING emits rows in an arbitrary order + # - https://www.sqlite.org/lang_returning.html#limitations_and_caveats + # - SQLite does not support data-modifying statements in a WITH clause + # - https://www.sqlite.org/lang_with.html + # - https://www.postgresql.org/docs/current/queries-with.html#QUERIES-WITH-MODIFYING txn.execute( f""" WITH events_to_send AS ( From dd3c746eb5b7a6d378ab0ceac242cbe4a5692a17 Mon Sep 17 00:00:00 2001 From: Andrew Ferrazzutti Date: Wed, 11 Sep 2024 11:17:07 -0400 Subject: [PATCH 65/79] Remove TODO for returning transaction IDs GET /delayed_events is no longer specced to return them --- synapse/storage/databases/main/delayed_events.py | 1 - 1 file changed, 1 deletion(-) diff --git a/synapse/storage/databases/main/delayed_events.py b/synapse/storage/databases/main/delayed_events.py index 1062ddb622b..8468219281c 100644 --- a/synapse/storage/databases/main/delayed_events.py +++ b/synapse/storage/databases/main/delayed_events.py @@ -224,7 +224,6 @@ async def get_all_delayed_events_for_user( user_localpart: str, ) -> List[JsonDict]: """Returns all pending delayed events owned by the given user.""" - # TODO: Store and return "transaction_id" # TODO: Support Pagination stream API ("next_batch" field) rows = await self.db_pool.simple_select_list( table="delayed_events", From a60fa7f91b7f508755654442e3c54df42cf53c25 Mon Sep 17 00:00:00 2001 From: Andrew Ferrazzutti Date: Wed, 11 Sep 2024 14:31:44 -0400 Subject: [PATCH 66/79] Use attrs classes for delayed event properties --- synapse/handlers/delayed_events.py | 81 +++++++------------ .../storage/databases/main/delayed_events.py | 54 ++++++------- 2 files changed, 53 insertions(+), 82 deletions(-) diff --git a/synapse/handlers/delayed_events.py b/synapse/handlers/delayed_events.py index f9ff302af9c..79be1951a6f 100644 --- a/synapse/handlers/delayed_events.py +++ b/synapse/handlers/delayed_events.py @@ -1,6 +1,8 @@ import logging from typing import TYPE_CHECKING, List, Optional, Set, Tuple +import attr + from twisted.internet.interfaces import IDelayedCall from synapse.api.constants import EventTypes @@ -12,7 +14,6 @@ Delay, DelayedEventDetails, DelayID, - DeviceID, EventType, StateKey, Timestamp, @@ -308,9 +309,12 @@ async def send(self, requester: Requester, delay_id: str) -> None: self._schedule_next_at_or_none(next_send_ts) await self._send_event( - DelayID(delay_id), - UserLocalpart(requester.user.localpart), - *event, + DelayedEventDetails( + # NOTE: mypy thinks that (*attr.astuple, ...) 
is too many args, so use kwargs instead + delay_id=DelayID(delay_id), + user_localpart=UserLocalpart(requester.user.localpart), + **attr.asdict(event), + ) ) async def _send_on_timeout(self) -> None: @@ -327,34 +331,16 @@ async def _send_on_timeout(self) -> None: async def _send_events(self, events: List[DelayedEventDetails]) -> None: sent_state: Set[Tuple[RoomID, EventType, StateKey]] = set() - for ( - delay_id, - user_localpart, - room_id, - event_type, - state_key, - origin_server_ts, - content, - device_id, - ) in events: - if state_key is not None: - state_info = (room_id, event_type, state_key) + for event in events: + if event.state_key is not None: + state_info = (event.room_id, event.type, event.state_key) if state_info in sent_state: continue else: state_info = None try: # TODO: send in background if message event or non-conflicting state event - await self._send_event( - delay_id, - user_localpart, - room_id, - event_type, - state_key, - origin_server_ts, - content, - device_id, - ) + await self._send_event(event) if state_info is not None: sent_state.add(state_info) except Exception: @@ -400,65 +386,58 @@ async def get_all_for_user(self, requester: Requester) -> List[JsonDict]: async def _send_event( self, - delay_id: DelayID, - user_localpart: UserLocalpart, - room_id: RoomID, - event_type: EventType, - state_key: Optional[StateKey], - origin_server_ts: Timestamp, - content: JsonDict, - device_id: Optional[DeviceID], + event: DelayedEventDetails, txn_id: Optional[str] = None, ) -> None: - user_id = UserID(user_localpart, self._config.server.server_name) + user_id = UserID(event.user_localpart, self._config.server.server_name) user_id_str = user_id.to_string() # Create a new requester from what data is currently available requester = create_requester( user_id, is_guest=await self._store.is_guest(user_id_str), - device_id=device_id, + device_id=event.device_id, ) try: - if state_key is not None and event_type == EventTypes.Member: - membership = content.get("membership") + if event.state_key is not None and event.type == EventTypes.Member: + membership = event.content.get("membership") assert membership is not None event_id, _ = await self._room_member_handler.update_membership( requester, - target=UserID.from_string(state_key), - room_id=room_id.to_string(), + target=UserID.from_string(event.state_key), + room_id=event.room_id.to_string(), action=membership, - content=content, - origin_server_ts=origin_server_ts, + content=event.content, + origin_server_ts=event.origin_server_ts, ) else: event_dict: JsonDict = { - "type": event_type, - "content": content, - "room_id": room_id.to_string(), + "type": event.type, + "content": event.content, + "room_id": event.room_id.to_string(), "sender": user_id_str, - "origin_server_ts": origin_server_ts, + "origin_server_ts": event.origin_server_ts, } - if state_key is not None: - event_dict["state_key"] = state_key + if event.state_key is not None: + event_dict["state_key"] = event.state_key ( - event, + sent_event, _, ) = await self._event_creation_handler.create_and_send_nonmember_event( requester, event_dict, txn_id=txn_id, ) - event_id = event.event_id + event_id = sent_event.event_id except ShadowBanError: event_id = generate_fake_event_id() finally: # TODO: If this is a temporary error, retry. 
Otherwise, consider notifying clients of the failure try: await self._store.delete_processed_delayed_event( - delay_id, user_localpart + event.delay_id, event.user_localpart ) except Exception: logger.exception("Failed to delete processed delayed event") diff --git a/synapse/storage/databases/main/delayed_events.py b/synapse/storage/databases/main/delayed_events.py index 8468219281c..b1bcc2ccb10 100644 --- a/synapse/storage/databases/main/delayed_events.py +++ b/synapse/storage/databases/main/delayed_events.py @@ -1,6 +1,8 @@ import logging from typing import List, NewType, Optional, Tuple +import attr + from synapse.api.errors import NotFoundError from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage.database import LoggingTransaction, StoreError @@ -20,17 +22,21 @@ Delay = NewType("Delay", int) Timestamp = NewType("Timestamp", int) -# TODO: Maybe use attr class -DelayedEventDetails = Tuple[ - DelayID, - UserLocalpart, - RoomID, - EventType, - Optional[StateKey], - Timestamp, - JsonDict, - Optional[DeviceID], -] + +@attr.s(slots=True, frozen=True, auto_attribs=True) +class EventDetails: + room_id: RoomID + type: EventType + state_key: Optional[StateKey] + origin_server_ts: Timestamp + content: JsonDict + device_id: Optional[DeviceID] + + +@attr.s(slots=True, frozen=True, auto_attribs=True) +class DelayedEventDetails(EventDetails): + delay_id: DelayID + user_localpart: UserLocalpart class DelayedEventsStore(SQLBaseStore): @@ -318,15 +324,15 @@ def process_timeout_delayed_events_txn( assert txn.rowcount == len(rows) events = [ - ( - DelayID(row[0]), - UserLocalpart(row[1]), + DelayedEventDetails( RoomID.from_string(row[2]), EventType(row[3]), StateKey(row[4]) if row[4] is not None else None, Timestamp(row[5] if row[5] is not None else row[6]), db_to_json(row[7]), DeviceID(row[8]) if row[8] is not None else None, + DelayID(row[0]), + UserLocalpart(row[1]), ) for row in rows ] @@ -343,14 +349,7 @@ async def process_target_delayed_event( delay_id: str, user_localpart: str, ) -> Tuple[ - Tuple[ - RoomID, - EventType, - Optional[StateKey], - Timestamp, - JsonDict, - Optional[DeviceID], - ], + EventDetails, bool, Optional[Timestamp], ]: @@ -373,14 +372,7 @@ async def process_target_delayed_event( def process_target_delayed_event_txn( txn: LoggingTransaction, ) -> Tuple[ - Tuple[ - RoomID, - EventType, - Optional[StateKey], - Timestamp, - JsonDict, - Optional[DeviceID], - ], + EventDetails, bool, Optional[Timestamp], ]: @@ -414,7 +406,7 @@ def process_target_delayed_event_txn( assert txn.rowcount == 1 send_ts = Timestamp(row[4]) - event = ( + event = EventDetails( RoomID.from_string(row[0]), EventType(row[1]), StateKey(row[2]) if row[2] is not None else None, From 092793a287f935a659f40b19006fee8df00bc903 Mon Sep 17 00:00:00 2001 From: Andrew Ferrazzutti Date: Thu, 12 Sep 2024 22:30:18 -0400 Subject: [PATCH 67/79] Restrict delayed events to a single worker --- synapse/app/generic_worker.py | 2 + synapse/handlers/delayed_events.py | 83 ++++++++++++++-------- synapse/replication/http/__init__.py | 2 + synapse/replication/http/delayed_events.py | 48 +++++++++++++ synapse/rest/client/delayed_events.py | 4 +- 5 files changed, 107 insertions(+), 32 deletions(-) create mode 100644 synapse/replication/http/delayed_events.py diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index 18d294f2b2a..6a944998f17 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -65,6 +65,7 @@ ) from synapse.storage.databases.main.censor_events import 
CensorEventsStore from synapse.storage.databases.main.client_ips import ClientIpWorkerStore +from synapse.storage.databases.main.delayed_events import DelayedEventsStore from synapse.storage.databases.main.deviceinbox import DeviceInboxWorkerStore from synapse.storage.databases.main.devices import DeviceWorkerStore from synapse.storage.databases.main.directory import DirectoryWorkerStore @@ -161,6 +162,7 @@ class GenericWorkerStore( TaskSchedulerWorkerStore, ExperimentalFeaturesStore, SlidingSyncStore, + DelayedEventsStore, ): # Properties that multiple storage classes define. Tell mypy what the # expected type is. diff --git a/synapse/handlers/delayed_events.py b/synapse/handlers/delayed_events.py index 79be1951a6f..69da1764a3b 100644 --- a/synapse/handlers/delayed_events.py +++ b/synapse/handlers/delayed_events.py @@ -7,9 +7,13 @@ from synapse.api.constants import EventTypes from synapse.api.errors import ShadowBanError +from synapse.config.workers import MAIN_PROCESS_INSTANCE_NAME from synapse.logging.opentracing import set_tag from synapse.metrics import event_processing_positions from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.replication.http.delayed_events import ( + ReplicationAddedDelayedEventRestServlet, +) from synapse.storage.databases.main.delayed_events import ( Delay, DelayedEventDetails, @@ -54,43 +58,46 @@ def __init__(self, hs: "HomeServer"): # Guard to ensure we only process event deltas one at a time self._event_processing = False - if hs.config.worker.run_background_tasks: + if hs.config.worker.worker_app is None: hs.get_notifier().add_replication_callback(self.notify_new_event) # We kick this off to pick up outstanding work from before the last restart. self._clock.call_later(0, self.notify_new_event) - # TODO: Refactor or remove this - async def _schedule_db_events() -> None: - # Delayed events that are already marked as processed on startup might not have been - # sent properly on the last run of the server, so unmark them to send them again. - # Caveats: - # - This will double-send delayed events that successfully persisted, but failed to be - # removed from the DB table of delayed events. - # - This will interfere with workers that are in the act of processing delayed events. - # TODO: To avoid double-sending, scan the timeline to find which of these events were - # already sent. To do so, must store delay_ids in sent events to retrieve them later. - # TODO: To avoid interfering with workers, think of a way to distinguish between - # events being processed by a worker vs ones that got lost after a server crash. - await self._store.unprocess_delayed_events() - - events, next_send_ts = await self._store.process_timeout_delayed_events( - self._get_current_ts() - ) + self._repl_client = None - if next_send_ts: - self._schedule_next_at(next_send_ts) + async def _schedule_db_events() -> None: + # Delayed events that are already marked as processed on startup might not have been + # sent properly on the last run of the server, so unmark them to send them again. + # Caveat: this will double-send delayed events that successfully persisted, but failed + # to be removed from the DB table of delayed events. + # TODO: To avoid double-sending, scan the timeline to find which of these events were + # already sent. To do so, must store delay_ids in sent events to retrieve them later. 
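+            # Clearing the is_processed flag means the timeout sweep just below
+            # will pick these events up and send them again.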
+ await self._store.unprocess_delayed_events() + + events, next_send_ts = await self._store.process_timeout_delayed_events( + self._get_current_ts() + ) + + if next_send_ts: + self._schedule_next_at(next_send_ts) + + # Can send the events in background after having awaited on marking them as processed + run_as_background_process( + "_send_events", + self._send_events, + events, + ) - # Can send the events in background after having awaited on marking them as processed - run_as_background_process( - "_send_events", - self._send_events, - events, + self._initialized_from_db = run_as_background_process( + "_schedule_db_events", _schedule_db_events ) + else: + self._repl_client = ReplicationAddedDelayedEventRestServlet.make_client(hs) - self._initialized_from_db = run_as_background_process( - "_schedule_db_events", _schedule_db_events - ) + @property + def _is_master(self) -> bool: + return self._repl_client is None def notify_new_event(self) -> None: """ @@ -207,7 +214,6 @@ async def add( SynapseError: if the delayed event fails validation checks. """ await self._request_ratelimiter.ratelimit(requester) - await self._initialized_from_db self._event_creation_handler.validator.validate_builder( self._event_creation_handler.event_builder_factory.for_room_version( @@ -237,10 +243,23 @@ async def add( ) if changed: - self._schedule_next_at(Timestamp(creation_ts + delay)) + next_send_ts = creation_ts + delay + if self._repl_client is None: + self._schedule_next_at(Timestamp(next_send_ts)) + else: + # NOTE: If this throws, the delayed event will remain in the DB and + # will be picked up once the main worker gets another delayed event + # with a sooner send time. + await self._repl_client( + instance_name=MAIN_PROCESS_INSTANCE_NAME, + next_send_ts=next_send_ts, + ) return delay_id + def on_added(self, next_send_ts: int) -> None: + self._schedule_next_at(Timestamp(next_send_ts)) + async def cancel(self, requester: Requester, delay_id: str) -> None: """ Cancels the scheduled delivery of the matching delayed event. @@ -252,6 +271,7 @@ async def cancel(self, requester: Requester, delay_id: str) -> None: Raises: NotFoundError: if no matching delayed event could be found. """ + assert self._is_master await self._request_ratelimiter.ratelimit(requester) await self._initialized_from_db @@ -274,6 +294,7 @@ async def restart(self, requester: Requester, delay_id: str) -> None: Raises: NotFoundError: if no matching delayed event could be found. """ + assert self._is_master await self._request_ratelimiter.ratelimit(requester) await self._initialized_from_db @@ -297,6 +318,7 @@ async def send(self, requester: Requester, delay_id: str) -> None: Raises: NotFoundError: if no matching delayed event could be found. 
""" + assert self._is_master await self._request_ratelimiter.ratelimit(requester) await self._initialized_from_db @@ -379,7 +401,6 @@ def _schedule_next(self, delay: Delay) -> None: async def get_all_for_user(self, requester: Requester) -> List[JsonDict]: """Return all pending delayed events requested by the given user.""" await self._request_ratelimiter.ratelimit(requester) - await self._initialized_from_db return await self._store.get_all_delayed_events_for_user( requester.user.localpart ) diff --git a/synapse/replication/http/__init__.py b/synapse/replication/http/__init__.py index c9cf838255e..1673bd057e6 100644 --- a/synapse/replication/http/__init__.py +++ b/synapse/replication/http/__init__.py @@ -23,6 +23,7 @@ from synapse.http.server import JsonResource from synapse.replication.http import ( account_data, + delayed_events, devices, federation, login, @@ -64,3 +65,4 @@ def register_servlets(self, hs: "HomeServer") -> None: login.register_servlets(hs, self) register.register_servlets(hs, self) devices.register_servlets(hs, self) + delayed_events.register_servlets(hs, self) diff --git a/synapse/replication/http/delayed_events.py b/synapse/replication/http/delayed_events.py new file mode 100644 index 00000000000..77dabb08e63 --- /dev/null +++ b/synapse/replication/http/delayed_events.py @@ -0,0 +1,48 @@ +import logging +from typing import TYPE_CHECKING, Dict, Optional, Tuple + +from twisted.web.server import Request + +from synapse.http.server import HttpServer +from synapse.replication.http._base import ReplicationEndpoint +from synapse.types import JsonDict, JsonMapping + +if TYPE_CHECKING: + from synapse.server import HomeServer + +logger = logging.getLogger(__name__) + + +class ReplicationAddedDelayedEventRestServlet(ReplicationEndpoint): + """Handle a delayed event being added by another worker. + + Request format: + + POST /_synapse/replication/delayed_event_added/ + + {} + """ + + NAME = "added_delayed_event" + PATH_ARGS = () + CACHE = False + + def __init__(self, hs: "HomeServer"): + super().__init__(hs) + + self.handler = hs.get_delayed_events_handler() + + @staticmethod + async def _serialize_payload(next_send_ts: int) -> JsonDict: # type: ignore[override] + return {"next_send_ts": next_send_ts} + + async def _handle_request( # type: ignore[override] + self, request: Request, content: JsonDict + ) -> Tuple[int, Dict[str, Optional[JsonMapping]]]: + self.handler.on_added(int(content["next_send_ts"])) + + return 200, {} + + +def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: + ReplicationAddedDelayedEventRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/delayed_events.py b/synapse/rest/client/delayed_events.py index 7bb919c5cbf..63558db363f 100644 --- a/synapse/rest/client/delayed_events.py +++ b/synapse/rest/client/delayed_events.py @@ -93,5 +93,7 @@ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: - UpdateDelayedEventServlet(hs).register(http_server) + # The following can't currently be instantiated on workers. 
+ if hs.config.worker.worker_app is None: + UpdateDelayedEventServlet(hs).register(http_server) DelayedEventsServlet(hs).register(http_server) From 1d750601f7f71ee8a5e6a78bceb1ad57bab78f74 Mon Sep 17 00:00:00 2001 From: Andrew Ferrazzutti Date: Thu, 12 Sep 2024 22:32:01 -0400 Subject: [PATCH 68/79] Reword docstring --- synapse/storage/databases/main/delayed_events.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/synapse/storage/databases/main/delayed_events.py b/synapse/storage/databases/main/delayed_events.py index b1bcc2ccb10..d992aab0278 100644 --- a/synapse/storage/databases/main/delayed_events.py +++ b/synapse/storage/databases/main/delayed_events.py @@ -83,7 +83,7 @@ async def add_delayed_event( Inserts a new delayed event in the DB. Returns: The generated ID assigned to the added delayed event, - and whether the added delayed event is the next to be sent. + and whether the next delayed event is now this event instead. """ delay_id = _generate_delay_id() send_ts = creation_ts + delay From 34ed5828f3043516eca49482b97827678d59514e Mon Sep 17 00:00:00 2001 From: Andrew Ferrazzutti Date: Thu, 12 Sep 2024 22:48:14 -0400 Subject: [PATCH 69/79] On startup, wait to catch up on state changes --- synapse/handlers/delayed_events.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/synapse/handlers/delayed_events.py b/synapse/handlers/delayed_events.py index 69da1764a3b..8efac14d184 100644 --- a/synapse/handlers/delayed_events.py +++ b/synapse/handlers/delayed_events.py @@ -59,14 +59,17 @@ def __init__(self, hs: "HomeServer"): self._event_processing = False if hs.config.worker.worker_app is None: - hs.get_notifier().add_replication_callback(self.notify_new_event) - - # We kick this off to pick up outstanding work from before the last restart. - self._clock.call_later(0, self.notify_new_event) - self._repl_client = None async def _schedule_db_events() -> None: + # We kick this off to pick up outstanding work from before the last restart. + # Block until we're up to date. + await self._unsafe_process_new_event() + hs.get_notifier().add_replication_callback(self.notify_new_event) + # Kick off again (without blocking) to catch any missed notifications + # that may have fired before the callback was added. + self._clock.call_later(0, self.notify_new_event) + # Delayed events that are already marked as processed on startup might not have been # sent properly on the last run of the server, so unmark them to send them again. 
# Caveat: this will double-send delayed events that successfully persisted, but failed From a6cf11c0680e3ffee656e8b9b44cab4e0cdea42a Mon Sep 17 00:00:00 2001 From: Andrew Ferrazzutti Date: Fri, 13 Sep 2024 01:12:44 -0400 Subject: [PATCH 70/79] Don't expect to remember next_send_ts See DelayedEventsHandler._next_send_ts_changed for details --- synapse/handlers/delayed_events.py | 67 ++++---- .../storage/databases/main/delayed_events.py | 157 +++++------------- 2 files changed, 73 insertions(+), 151 deletions(-) diff --git a/synapse/handlers/delayed_events.py b/synapse/handlers/delayed_events.py index 8efac14d184..fe15efc232c 100644 --- a/synapse/handlers/delayed_events.py +++ b/synapse/handlers/delayed_events.py @@ -15,7 +15,6 @@ ReplicationAddedDelayedEventRestServlet, ) from synapse.storage.databases.main.delayed_events import ( - Delay, DelayedEventDetails, DelayID, EventType, @@ -178,13 +177,13 @@ async def _handle_state_deltas(self, deltas: List[StateDelta]) -> None: "Handling: %r %r, %s", delta.event_type, delta.state_key, delta.event_id ) - changed, next_send_ts = await self._store.cancel_delayed_state_events( + next_send_ts = await self._store.cancel_delayed_state_events( room_id=delta.room_id, event_type=delta.event_type, state_key=delta.state_key, ) - if changed: + if self._next_send_ts_changed(next_send_ts): self._schedule_next_at_or_none(next_send_ts) async def add( @@ -233,7 +232,7 @@ async def add( creation_ts = self._get_current_ts() - delay_id, changed = await self._store.add_delayed_event( + delay_id, next_send_ts = await self._store.add_delayed_event( user_localpart=requester.user.localpart, device_id=requester.device_id, creation_ts=creation_ts, @@ -245,23 +244,22 @@ async def add( delay=delay, ) - if changed: - next_send_ts = creation_ts + delay - if self._repl_client is None: - self._schedule_next_at(Timestamp(next_send_ts)) - else: - # NOTE: If this throws, the delayed event will remain in the DB and - # will be picked up once the main worker gets another delayed event - # with a sooner send time. - await self._repl_client( - instance_name=MAIN_PROCESS_INSTANCE_NAME, - next_send_ts=next_send_ts, - ) + if self._repl_client is not None: + # NOTE: If this throws, the delayed event will remain in the DB and + # will be picked up once the main worker gets another delayed event. 
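+            # (The replication call merely nudges the main process, via its
+            # on_added callback, to re-check when its next timer should fire.)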
+ await self._repl_client( + instance_name=MAIN_PROCESS_INSTANCE_NAME, + next_send_ts=next_send_ts, + ) + elif self._next_send_ts_changed(next_send_ts): + self._schedule_next_at(next_send_ts) return delay_id def on_added(self, next_send_ts: int) -> None: - self._schedule_next_at(Timestamp(next_send_ts)) + next_send_ts = Timestamp(next_send_ts) + if self._next_send_ts_changed(next_send_ts): + self._schedule_next_at(next_send_ts) async def cancel(self, requester: Requester, delay_id: str) -> None: """ @@ -278,12 +276,12 @@ async def cancel(self, requester: Requester, delay_id: str) -> None: await self._request_ratelimiter.ratelimit(requester) await self._initialized_from_db - changed, next_send_ts = await self._store.cancel_delayed_event( + next_send_ts = await self._store.cancel_delayed_event( delay_id=delay_id, user_localpart=requester.user.localpart, ) - if changed: + if self._next_send_ts_changed(next_send_ts): self._schedule_next_at_or_none(next_send_ts) async def restart(self, requester: Requester, delay_id: str) -> None: @@ -301,14 +299,14 @@ async def restart(self, requester: Requester, delay_id: str) -> None: await self._request_ratelimiter.ratelimit(requester) await self._initialized_from_db - changed_next_send_ts = await self._store.restart_delayed_event( + next_send_ts = await self._store.restart_delayed_event( delay_id=delay_id, user_localpart=requester.user.localpart, current_ts=self._get_current_ts(), ) - if changed_next_send_ts is not None: - self._schedule_next_at(changed_next_send_ts) + if self._next_send_ts_changed(next_send_ts): + self._schedule_next_at(next_send_ts) async def send(self, requester: Requester, delay_id: str) -> None: """ @@ -325,12 +323,12 @@ async def send(self, requester: Requester, delay_id: str) -> None: await self._request_ratelimiter.ratelimit(requester) await self._initialized_from_db - event, changed, next_send_ts = await self._store.process_target_delayed_event( + event, next_send_ts = await self._store.process_target_delayed_event( delay_id=delay_id, user_localpart=requester.user.localpart, ) - if changed: + if self._next_send_ts_changed(next_send_ts): self._schedule_next_at_or_none(next_send_ts) await self._send_event( @@ -386,9 +384,7 @@ def _schedule_next_at_or_none(self, next_send_ts: Optional[Timestamp]) -> None: self._next_delayed_event_call = None def _schedule_next_at(self, next_send_ts: Timestamp) -> None: - return self._schedule_next(self._get_delay_until(next_send_ts)) - - def _schedule_next(self, delay: Delay) -> None: + delay = next_send_ts - self._get_current_ts() delay_sec = delay / 1000 if delay > 0 else 0 if self._next_delayed_event_call is None: @@ -471,9 +467,14 @@ async def _send_event( def _get_current_ts(self) -> Timestamp: return Timestamp(self._clock.time_msec()) - def _get_delay_until(self, to_ts: Timestamp) -> Delay: - return _get_delay_between(self._get_current_ts(), to_ts) - - -def _get_delay_between(from_ts: Timestamp, to_ts: Timestamp) -> Delay: - return Delay(to_ts - from_ts) + def _next_send_ts_changed(self, next_send_ts: Optional[Timestamp]) -> bool: + # The DB alone knows if the next send time changed after adding/modifying + # a delayed event, but if we were to ever miss updating our delayed call's + # firing time, we may miss other updates. So, keep track of changes to the + # the next send time here instead of in the DB. 
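+        # IDelayedCall.getTime() reports seconds since the epoch, so scale it
+        # back to milliseconds before comparing against DB timestamps.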
+ cached_next_send_ts = ( + int(self._next_delayed_event_call.getTime() * 1000) + if self._next_delayed_event_call is not None + else None + ) + return next_send_ts != cached_next_send_ts diff --git a/synapse/storage/databases/main/delayed_events.py b/synapse/storage/databases/main/delayed_events.py index d992aab0278..145ef2fcbd8 100644 --- a/synapse/storage/databases/main/delayed_events.py +++ b/synapse/storage/databases/main/delayed_events.py @@ -78,19 +78,18 @@ async def add_delayed_event( origin_server_ts: Optional[int], content: JsonDict, delay: int, - ) -> Tuple[DelayID, bool]: + ) -> Tuple[DelayID, Timestamp]: """ Inserts a new delayed event in the DB. Returns: The generated ID assigned to the added delayed event, - and whether the next delayed event is now this event instead. + and the send time of the next delayed event to be sent, + which is either the event just added or one added earlier. """ delay_id = _generate_delay_id() - send_ts = creation_ts + delay - - def add_delayed_event_txn(txn: LoggingTransaction) -> bool: - old_next_send_ts = self._get_next_delayed_event_send_ts_txn(txn) + send_ts = Timestamp(creation_ts + delay) + def add_delayed_event_txn(txn: LoggingTransaction) -> Timestamp: self.db_pool.simple_insert_txn( txn, table="delayed_events", @@ -108,13 +107,15 @@ def add_delayed_event_txn(txn: LoggingTransaction) -> bool: }, ) - return old_next_send_ts is None or send_ts < old_next_send_ts + next_send_ts = self._get_next_delayed_event_send_ts_txn(txn) + assert next_send_ts is not None + return next_send_ts - changed = await self.db_pool.runInteraction( + next_send_ts = await self.db_pool.runInteraction( "add_delayed_event", add_delayed_event_txn ) - return delay_id, changed + return delay_id, next_send_ts async def restart_delayed_event( self, @@ -122,7 +123,7 @@ async def restart_delayed_event( delay_id: str, user_localpart: str, current_ts: Timestamp, - ) -> Optional[Timestamp]: + ) -> Timestamp: """ Restarts the send time of the matching delayed event, as long as it hasn't already been marked for processing. @@ -132,9 +133,9 @@ async def restart_delayed_event( user_localpart: The localpart of the delayed event's owner. current_ts: The current time, which will be used to calculate the new send time. - Returns: None if the matching delayed event would not have been the next to be sent; - otherwise, the send time of the next delayed event to be sent (which may be the - matching delayed event, or another one sent before it). + Returns: The send time of the next delayed event to be sent, + which is either the event just restarted, or another one + with an earlier send time than the restarted one's new send time. Raises: NotFoundError: if there is no matching delayed event. 
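The _get_next_delayed_event_send_ts_txn helper that these transactions now lean on is
defined outside the hunks shown in this series. A minimal sketch of what it presumably
does on DelayedEventsStore, assuming the delayed_events schema above (the use of
simple_select_one_onecol_txn here is inferred, not taken from the diff):

    def _get_next_delayed_event_send_ts_txn(
        self, txn: LoggingTransaction
    ) -> Optional[Timestamp]:
        # The soonest scheduled send time among unprocessed delayed events,
        # or None if nothing is pending.
        result = self.db_pool.simple_select_one_onecol_txn(
            txn,
            table="delayed_events",
            keyvalues={"is_processed": False},
            retcol="MIN(send_ts)",
            allow_none=True,
        )
        return Timestamp(result) if result is not None else None

Recomputing this minimum inside the same transaction as the write is what lets the
SQL drop its old "did the next send time change?" bookkeeping: the handler now makes
that comparison itself against its cached delayed call.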
@@ -142,49 +143,14 @@ async def restart_delayed_event( def restart_delayed_event_txn( txn: LoggingTransaction, - ) -> Optional[Timestamp]: + ) -> Timestamp: txn.execute( """ - SELECT delay_id, user_localpart, send_ts - FROM delayed_events - WHERE NOT is_processed - ORDER BY send_ts - LIMIT 2 - """ - ) - if txn.rowcount == 0: - # Return early if there are no delayed events at all - raise NotFoundError("Delayed event not found") - - if txn.rowcount == 1: - # The restarted event is the only event - changed = True - second_next_send_ts = None - else: - next_event, second_next_event = txn.fetchmany(2) - if ( - # The restarted event is the next to be sent - delay_id == next_event[0] - and user_localpart == next_event[1] - # The next two events to be sent aren't scheduled at the same time - and next_event[2] != second_next_event[2] - ): - changed = True - second_next_send_ts = Timestamp(second_next_event[2]) - else: - changed = False - - sql_base = """ UPDATE delayed_events SET send_ts = ? + delay WHERE delay_id = ? AND user_localpart = ? AND NOT is_processed - """ - if changed and self.database_engine.supports_returning: - sql_base += "RETURNING send_ts" - - txn.execute( - sql_base, + """, ( current_ts, delay_id, @@ -194,32 +160,9 @@ def restart_delayed_event_txn( if txn.rowcount == 0: raise NotFoundError("Delayed event not found") - if changed: - if self.database_engine.supports_returning: - row = txn.fetchone() - assert row is not None - restarted_send_ts = Timestamp(row[0]) - else: - restarted_send_ts = Timestamp( - self.db_pool.simple_select_one_onecol_txn( - txn, - table="delayed_events", - keyvalues={ - "delay_id": delay_id, - "user_localpart": user_localpart, - "is_processed": False, - }, - retcol="send_ts", - ) - ) - - return ( - restarted_send_ts - if second_next_send_ts is None - else min(restarted_send_ts, second_next_send_ts) - ) - - return None + next_send_ts = self._get_next_delayed_event_send_ts_txn(txn) + assert next_send_ts is not None + return next_send_ts return await self.db_pool.runInteraction( "restart_delayed_event", restart_delayed_event_txn @@ -350,7 +293,6 @@ async def process_target_delayed_event( user_localpart: str, ) -> Tuple[ EventDetails, - bool, Optional[Timestamp], ]: """ @@ -362,7 +304,6 @@ async def process_target_delayed_event( user_localpart: The localpart of the delayed event's owner. Returns: The details of the matching delayed event, - whether the matching delayed event would have been the next to be sent, and the send time of the next delayed event to be sent, if any. 
Raises: @@ -373,7 +314,6 @@ def process_target_delayed_event_txn( txn: LoggingTransaction, ) -> Tuple[ EventDetails, - bool, Optional[Timestamp], ]: sql_cols = ", ".join( @@ -405,18 +345,16 @@ def process_target_delayed_event_txn( txn.execute(f"{sql_update} {sql_where}", sql_args) assert txn.rowcount == 1 - send_ts = Timestamp(row[4]) event = EventDetails( RoomID.from_string(row[0]), EventType(row[1]), StateKey(row[2]) if row[2] is not None else None, - Timestamp(row[3]) if row[3] is not None else send_ts, + Timestamp(row[3]) if row[3] is not None else Timestamp(row[4]), db_to_json(row[5]), DeviceID(row[6]) if row[6] is not None else None, ) - next_send_ts = self._get_next_delayed_event_send_ts_txn(txn) - return event, next_send_ts != send_ts, next_send_ts + return event, self._get_next_delayed_event_send_ts_txn(txn) return await self.db_pool.runInteraction( "process_target_delayed_event", process_target_delayed_event_txn @@ -427,7 +365,7 @@ async def cancel_delayed_event( *, delay_id: str, user_localpart: str, - ) -> Tuple[bool, Optional[Timestamp]]: + ) -> Optional[Timestamp]: """ Cancels the matching delayed event, i.e. remove it as long as it hasn't been processed. @@ -435,8 +373,7 @@ async def cancel_delayed_event( delay_id: The ID of the delayed event to restart. user_localpart: The localpart of the delayed event's owner. - Returns: Whether the matching delayed event would have been the next to be sent, - and if so, what the next soonest send time is, if any. + Returns: The send time of the next delayed event to be sent, if any. Raises: NotFoundError: if there is no matching delayed event. @@ -444,12 +381,7 @@ async def cancel_delayed_event( def cancel_delayed_event_txn( txn: LoggingTransaction, - ) -> Tuple[bool, Optional[Timestamp]]: - old_next_send_ts = self._get_next_delayed_event_send_ts_txn(txn) - if old_next_send_ts is None: - # Return early if there are no delayed events at all - raise NotFoundError("Delayed event not found") - + ) -> Optional[Timestamp]: try: self.db_pool.simple_delete_one_txn( txn, @@ -466,8 +398,7 @@ def cancel_delayed_event_txn( else: raise - new_next_send_ts = self._get_next_delayed_event_send_ts_txn(txn) - return new_next_send_ts != old_next_send_ts, new_next_send_ts + return self._get_next_delayed_event_send_ts_txn(txn) return await self.db_pool.runInteraction( "cancel_delayed_event", cancel_delayed_event_txn @@ -479,37 +410,27 @@ async def cancel_delayed_state_events( room_id: str, event_type: str, state_key: str, - ) -> Tuple[bool, Optional[Timestamp]]: + ) -> Optional[Timestamp]: """ Cancels all matching delayed state events, i.e. remove them as long as they haven't been processed. - Returns: Whether any of the matching delayed events would have been the next to be sent, - and if so, what the next soonest send time is, if any. + Returns: The send time of the next delayed event to be sent, if any. 
""" def cancel_delayed_state_events_txn( txn: LoggingTransaction, - ) -> Tuple[bool, Optional[Timestamp]]: - old_next_send_ts = self._get_next_delayed_event_send_ts_txn(txn) - - if ( - self.db_pool.simple_delete_txn( - txn, - table="delayed_events", - keyvalues={ - "room_id": room_id, - "event_type": event_type, - "state_key": state_key, - "is_processed": False, - }, - ) - > 0 - ): - assert old_next_send_ts is not None - new_next_send_ts = self._get_next_delayed_event_send_ts_txn(txn) - return new_next_send_ts != old_next_send_ts, new_next_send_ts - - return False, None + ) -> Optional[Timestamp]: + self.db_pool.simple_delete_txn( + txn, + table="delayed_events", + keyvalues={ + "room_id": room_id, + "event_type": event_type, + "state_key": state_key, + "is_processed": False, + }, + ) + return self._get_next_delayed_event_send_ts_txn(txn) return await self.db_pool.runInteraction( "cancel_delayed_state_events", cancel_delayed_state_events_txn From fb048335bbc7beb3e86b9d5e4b01a2cedf898c2b Mon Sep 17 00:00:00 2001 From: Andrew Ferrazzutti Date: Fri, 13 Sep 2024 01:21:50 -0400 Subject: [PATCH 71/79] Add more unit tests --- tests/rest/client/test_delayed_events.py | 134 +++++++++++++++++++++-- 1 file changed, 124 insertions(+), 10 deletions(-) diff --git a/tests/rest/client/test_delayed_events.py b/tests/rest/client/test_delayed_events.py index 3ff6a20a1cf..bd4ae5c26e1 100644 --- a/tests/rest/client/test_delayed_events.py +++ b/tests/rest/client/test_delayed_events.py @@ -1,26 +1,140 @@ """Tests REST events for /delayed_events paths.""" from http import HTTPStatus +from typing import List -from synapse.rest.client import delayed_events +from twisted.test.proto_helpers import MemoryReactor + +from synapse.rest.client import delayed_events, room +from synapse.server import HomeServer +from synapse.types import JsonDict +from synapse.util import Clock from tests.unittest import HomeserverTestCase -PATH_PREFIX = b"/_matrix/client/unstable/org.matrix.msc4140/delayed_events" +_HS_NAME = "red" +_EVENT_TYPE = "com.example.test" class DelayedEventsTestCase(HomeserverTestCase): """Tests getting and managing delayed events.""" - servlets = [delayed_events.register_servlets] + servlets = [delayed_events.register_servlets, room.register_servlets] + user_id = f"@sid1:{_HS_NAME}" + + def default_config(self) -> JsonDict: + config = super().default_config() + config["server_name"] = _HS_NAME + config["max_event_delay_duration"] = "24h" + return config + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.room_id = self.helper.create_room_as( + self.user_id, + extra_content={ + "preset": "trusted_private_chat", + }, + ) + + def test_delayed_events_empty_on_startup(self) -> None: + self.assertListEqual([], self._get_delayed_events()) + + def test_delayed_state_events_are_sent_on_timeout(self) -> None: + state_key = "to_send_on_timeout" + + setter_key = "setter" + setter_expected = "on_timeout" + channel = self.make_request( + "PUT", + _get_path_for_delayed_state(self.room_id, _EVENT_TYPE, state_key, 900), + { + setter_key: setter_expected, + }, + ) + self.assertEqual(HTTPStatus.OK, channel.code, channel.result) + events = self._get_delayed_events() + self.assertEqual(1, len(events), events) + content = self._get_delayed_event_content(events[0]) + self.assertEqual(setter_expected, content.get(setter_key), content) + self.helper.get_state( + self.room_id, + _EVENT_TYPE, + "", + state_key=state_key, + expect_code=HTTPStatus.NOT_FOUND, + ) - user_id = "@sid1:red" + 
self.reactor.advance(1) + self.assertListEqual([], self._get_delayed_events()) + content = self.helper.get_state( + self.room_id, + _EVENT_TYPE, + "", + state_key=state_key, + ) + self.assertEqual(setter_expected, content.get(setter_key), content) - def test_get_delayed_events(self) -> None: - channel = self.make_request("GET", PATH_PREFIX) + def test_delayed_state_events_are_cancelled_by_more_recent_state(self) -> None: + state_key = "to_be_cancelled" + + setter_key = "setter" + channel = self.make_request( + "PUT", + _get_path_for_delayed_state(self.room_id, _EVENT_TYPE, state_key, 900), + { + setter_key: "on_timeout", + }, + ) self.assertEqual(HTTPStatus.OK, channel.code, channel.result) - self.assertEqual( - [], - channel.json_body.get("delayed_events"), - channel.json_body, + events = self._get_delayed_events() + self.assertEqual(1, len(events), events) + + setter_expected = "manual" + self.helper.send_state( + self.room_id, + _EVENT_TYPE, + { + setter_key: setter_expected, + }, + None, + state_key=state_key, ) + self.assertListEqual([], self._get_delayed_events()) + + self.reactor.advance(1) + content = self.helper.get_state( + self.room_id, + _EVENT_TYPE, + "", + state_key=state_key, + ) + self.assertEqual(setter_expected, content.get(setter_key), content) + + def _get_delayed_events(self) -> List[JsonDict]: + channel = self.make_request( + "GET", b"/_matrix/client/unstable/org.matrix.msc4140/delayed_events" + ) + self.assertEqual(HTTPStatus.OK, channel.code, channel.result) + + key = "delayed_events" + self.assertIn(key, channel.json_body) + + events = channel.json_body[key] + self.assertIsInstance(events, list) + + return events + + def _get_delayed_event_content(self, event: JsonDict) -> JsonDict: + key = "content" + self.assertIn(key, event) + + content = event[key] + self.assertIsInstance(content, dict) + + return content + + +def _get_path_for_delayed_state( + room_id: str, event_type: str, state_key: str, delay_ms: int +) -> str: + return f"rooms/{room_id}/state/{event_type}/{state_key}?org.matrix.msc4140.delay={delay_ms}" From 3860f75041cdf17e5dd46f15aaae4432bbe80577 Mon Sep 17 00:00:00 2001 From: Andrew Ferrazzutti Date: Fri, 13 Sep 2024 10:32:22 -0400 Subject: [PATCH 72/79] Mention that GET /delayed_events supports workers --- docs/workers.md | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/workers.md b/docs/workers.md index fbf539fa7e8..51b22fef9bf 100644 --- a/docs/workers.md +++ b/docs/workers.md @@ -290,6 +290,7 @@ information. Additionally, the following REST endpoints can be handled for GET requests: ^/_matrix/client/(api/v1|r0|v3|unstable)/pushrules/ + ^/_matrix/client/unstable/org.matrix.msc4140/delayed_events Pagination requests can also be handled, but all requests for a given room must be routed to the same instance. Additionally, care must be taken to From 8ee15580067424253428fcf9570a31d8395b7893 Mon Sep 17 00:00:00 2001 From: Andrew Ferrazzutti Date: Mon, 16 Sep 2024 00:49:21 -0400 Subject: [PATCH 73/79] Use default ts for delayed events sent on request When a delayed event is sent on-demand, let its timestamp be set to whatever time the event is sent at, like non-delayed events. Only timed-out delayed events should have their timestamps set to their timeout time, as that is the time they are meant to be sent. 
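For example, a delayed event scheduled at 12:00 with a 15-minute timeout is
stamped with origin_server_ts 12:15 if it fires on its own, but sending it
early at 12:05 via the "send" action now stamps it 12:05 instead of the
scheduled 12:15.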
--- synapse/handlers/delayed_events.py | 4 +++- synapse/storage/databases/main/delayed_events.py | 10 +++++----- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/synapse/handlers/delayed_events.py b/synapse/handlers/delayed_events.py index fe15efc232c..8acf9610d84 100644 --- a/synapse/handlers/delayed_events.py +++ b/synapse/handlers/delayed_events.py @@ -436,9 +436,11 @@ async def _send_event( "content": event.content, "room_id": event.room_id.to_string(), "sender": user_id_str, - "origin_server_ts": event.origin_server_ts, } + if event.origin_server_ts is not None: + event_dict["origin_server_ts"] = event.origin_server_ts + if event.state_key is not None: event_dict["state_key"] = event.state_key diff --git a/synapse/storage/databases/main/delayed_events.py b/synapse/storage/databases/main/delayed_events.py index 145ef2fcbd8..7b07a5beafb 100644 --- a/synapse/storage/databases/main/delayed_events.py +++ b/synapse/storage/databases/main/delayed_events.py @@ -28,7 +28,7 @@ class EventDetails: room_id: RoomID type: EventType state_key: Optional[StateKey] - origin_server_ts: Timestamp + origin_server_ts: Optional[Timestamp] content: JsonDict device_id: Optional[DeviceID] @@ -271,6 +271,7 @@ def process_timeout_delayed_events_txn( RoomID.from_string(row[2]), EventType(row[3]), StateKey(row[4]) if row[4] is not None else None, + # If no custom_origin_ts is set, use send_ts as the event's timestamp Timestamp(row[5] if row[5] is not None else row[6]), db_to_json(row[7]), DeviceID(row[8]) if row[8] is not None else None, @@ -322,7 +323,6 @@ def process_target_delayed_event_txn( "event_type", "state_key", "origin_server_ts", - "send_ts", "content", "device_id", ) @@ -349,9 +349,9 @@ def process_target_delayed_event_txn( RoomID.from_string(row[0]), EventType(row[1]), StateKey(row[2]) if row[2] is not None else None, - Timestamp(row[3]) if row[3] is not None else Timestamp(row[4]), - db_to_json(row[5]), - DeviceID(row[6]) if row[6] is not None else None, + Timestamp(row[3]) if row[3] is not None else None, + db_to_json(row[4]), + DeviceID(row[5]) if row[5] is not None else None, ) return event, self._get_next_delayed_event_send_ts_txn(txn) From cbedade990fbecaf4327e1cca36ed4bc2a46c89d Mon Sep 17 00:00:00 2001 From: Andrew Ferrazzutti Date: Mon, 16 Sep 2024 00:55:19 -0400 Subject: [PATCH 74/79] Don't use attr.asdict as it converts fields (namely RoomIDs) to dicts too --- synapse/handlers/delayed_events.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/synapse/handlers/delayed_events.py b/synapse/handlers/delayed_events.py index 8acf9610d84..9d59a099486 100644 --- a/synapse/handlers/delayed_events.py +++ b/synapse/handlers/delayed_events.py @@ -1,8 +1,6 @@ import logging from typing import TYPE_CHECKING, List, Optional, Set, Tuple -import attr - from twisted.internet.interfaces import IDelayedCall from synapse.api.constants import EventTypes @@ -333,10 +331,14 @@ async def send(self, requester: Requester, delay_id: str) -> None: await self._send_event( DelayedEventDetails( - # NOTE: mypy thinks that (*attr.astuple, ...) 
is too many args, so use kwargs instead delay_id=DelayID(delay_id), user_localpart=UserLocalpart(requester.user.localpart), - **attr.asdict(event), + room_id=event.room_id, + type=event.type, + state_key=event.state_key, + origin_server_ts=event.origin_server_ts, + content=event.content, + device_id=event.device_id, ) ) From 7ee57d89cd4f4a50967b82c6189a822d783e2fb2 Mon Sep 17 00:00:00 2001 From: Andrew Ferrazzutti Date: Mon, 16 Sep 2024 00:56:59 -0400 Subject: [PATCH 75/79] Fix path regex for delayed_events updating Require non-empty delay ID in path --- synapse/rest/client/delayed_events.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/synapse/rest/client/delayed_events.py b/synapse/rest/client/delayed_events.py index 63558db363f..5765ff2b09c 100644 --- a/synapse/rest/client/delayed_events.py +++ b/synapse/rest/client/delayed_events.py @@ -27,7 +27,7 @@ class _UpdateDelayedEventAction(Enum): # TODO: Needs unit testing class UpdateDelayedEventServlet(RestServlet): PATTERNS = client_patterns( - r"/org\.matrix\.msc4140/delayed_events/(?P[^/]*)$", + r"/org\.matrix\.msc4140/delayed_events/(?P[^/]+)$", releases=(), ) CATEGORY = "Delayed event management requests" From 10b9dee8616d6285570f9f6d561b400d5ba76dc2 Mon Sep 17 00:00:00 2001 From: Andrew Ferrazzutti Date: Mon, 16 Sep 2024 00:58:10 -0400 Subject: [PATCH 76/79] Fix SQL query Add missing WHERE clause in branched case --- synapse/storage/databases/main/delayed_events.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/synapse/storage/databases/main/delayed_events.py b/synapse/storage/databases/main/delayed_events.py index 7b07a5beafb..6ff337e511b 100644 --- a/synapse/storage/databases/main/delayed_events.py +++ b/synapse/storage/databases/main/delayed_events.py @@ -332,7 +332,7 @@ def process_target_delayed_event_txn( sql_args = (delay_id, user_localpart) txn.execute( ( - f"{sql_update} RETURNING {sql_cols}" + f"{sql_update} {sql_where} RETURNING {sql_cols}" if self.database_engine.supports_returning else f"SELECT {sql_cols} FROM delayed_events {sql_where}" ), From a723f6bd7b1587936c3c0f65ce2ebee45d94cb66 Mon Sep 17 00:00:00 2001 From: Andrew Ferrazzutti Date: Mon, 16 Sep 2024 00:59:27 -0400 Subject: [PATCH 77/79] Add more unit tests --- synapse/rest/client/delayed_events.py | 2 - tests/rest/client/test_delayed_events.py | 208 ++++++++++++++++++++++- 2 files changed, 207 insertions(+), 3 deletions(-) diff --git a/synapse/rest/client/delayed_events.py b/synapse/rest/client/delayed_events.py index 5765ff2b09c..eae5c9d2269 100644 --- a/synapse/rest/client/delayed_events.py +++ b/synapse/rest/client/delayed_events.py @@ -24,7 +24,6 @@ class _UpdateDelayedEventAction(Enum): SEND = "send" -# TODO: Needs unit testing class UpdateDelayedEventServlet(RestServlet): PATTERNS = client_patterns( r"/org\.matrix\.msc4140/delayed_events/(?P[^/]+)$", @@ -70,7 +69,6 @@ async def on_POST( return 200, {} -# TODO: Needs unit testing class DelayedEventsServlet(RestServlet): PATTERNS = client_patterns( r"/org\.matrix\.msc4140/delayed_events$", diff --git a/tests/rest/client/test_delayed_events.py b/tests/rest/client/test_delayed_events.py index bd4ae5c26e1..34d9fe79587 100644 --- a/tests/rest/client/test_delayed_events.py +++ b/tests/rest/client/test_delayed_events.py @@ -3,8 +3,11 @@ from http import HTTPStatus from typing import List +from parameterized import parameterized + from twisted.test.proto_helpers import MemoryReactor +from synapse.api.errors import Codes from synapse.rest.client import delayed_events, room from 
synapse.server import HomeServer from synapse.types import JsonDict @@ -12,6 +15,8 @@ from tests.unittest import HomeserverTestCase +PATH_PREFIX = "/_matrix/client/unstable/org.matrix.msc4140/delayed_events" + _HS_NAME = "red" _EVENT_TYPE = "com.example.test" @@ -74,6 +79,206 @@ def test_delayed_state_events_are_sent_on_timeout(self) -> None: ) self.assertEqual(setter_expected, content.get(setter_key), content) + def test_update_delayed_event_without_id(self) -> None: + channel = self.make_request( + "POST", + f"{PATH_PREFIX}/", + ) + self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, channel.result) + + def test_update_delayed_event_without_body(self) -> None: + channel = self.make_request( + "POST", + f"{PATH_PREFIX}/abc", + ) + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, channel.result) + self.assertEqual( + Codes.NOT_JSON, + channel.json_body["errcode"], + ) + + def test_update_delayed_event_without_action(self) -> None: + channel = self.make_request( + "POST", + f"{PATH_PREFIX}/abc", + {}, + ) + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, channel.result) + self.assertEqual( + Codes.MISSING_PARAM, + channel.json_body["errcode"], + ) + + def test_update_delayed_event_with_invalid_action(self) -> None: + channel = self.make_request( + "POST", + f"{PATH_PREFIX}/abc", + {"action": "oops"}, + ) + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, channel.result) + self.assertEqual( + Codes.INVALID_PARAM, + channel.json_body["errcode"], + ) + + @parameterized.expand(["cancel", "restart", "send"]) + def test_update_delayed_event_without_match(self, action: str) -> None: + channel = self.make_request( + "POST", + f"{PATH_PREFIX}/abc", + {"action": action}, + ) + self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, channel.result) + + def test_cancel_delayed_state_event(self) -> None: + state_key = "to_never_send" + + setter_key = "setter" + setter_expected = "none" + channel = self.make_request( + "PUT", + _get_path_for_delayed_state(self.room_id, _EVENT_TYPE, state_key, 1500), + { + setter_key: setter_expected, + }, + ) + self.assertEqual(HTTPStatus.OK, channel.code, channel.result) + delay_id = channel.json_body.get("delay_id") + self.assertIsNotNone(delay_id) + + self.reactor.advance(1) + events = self._get_delayed_events() + self.assertEqual(1, len(events), events) + content = self._get_delayed_event_content(events[0]) + self.assertEqual(setter_expected, content.get(setter_key), content) + self.helper.get_state( + self.room_id, + _EVENT_TYPE, + "", + state_key=state_key, + expect_code=HTTPStatus.NOT_FOUND, + ) + + channel = self.make_request( + "POST", + f"{PATH_PREFIX}/{delay_id}", + {"action": "cancel"}, + ) + self.assertEqual(HTTPStatus.OK, channel.code, channel.result) + self.assertListEqual([], self._get_delayed_events()) + + self.reactor.advance(1) + content = self.helper.get_state( + self.room_id, + _EVENT_TYPE, + "", + state_key=state_key, + expect_code=HTTPStatus.NOT_FOUND, + ) + + def test_send_delayed_state_event(self) -> None: + state_key = "to_send_on_request" + + setter_key = "setter" + setter_expected = "on_send" + channel = self.make_request( + "PUT", + _get_path_for_delayed_state(self.room_id, _EVENT_TYPE, state_key, 100000), + { + setter_key: setter_expected, + }, + ) + self.assertEqual(HTTPStatus.OK, channel.code, channel.result) + delay_id = channel.json_body.get("delay_id") + self.assertIsNotNone(delay_id) + + self.reactor.advance(1) + events = self._get_delayed_events() + self.assertEqual(1, len(events), events) + content = 
self._get_delayed_event_content(events[0]) + self.assertEqual(setter_expected, content.get(setter_key), content) + self.helper.get_state( + self.room_id, + _EVENT_TYPE, + "", + state_key=state_key, + expect_code=HTTPStatus.NOT_FOUND, + ) + + channel = self.make_request( + "POST", + f"{PATH_PREFIX}/{delay_id}", + {"action": "send"}, + ) + self.assertEqual(HTTPStatus.OK, channel.code, channel.result) + self.assertListEqual([], self._get_delayed_events()) + content = self.helper.get_state( + self.room_id, + _EVENT_TYPE, + "", + state_key=state_key, + ) + self.assertEqual(setter_expected, content.get(setter_key), content) + + def test_restart_delayed_state_event(self) -> None: + state_key = "to_send_on_restarted_timeout" + + setter_key = "setter" + setter_expected = "on_timeout" + channel = self.make_request( + "PUT", + _get_path_for_delayed_state(self.room_id, _EVENT_TYPE, state_key, 1500), + { + setter_key: setter_expected, + }, + ) + self.assertEqual(HTTPStatus.OK, channel.code, channel.result) + delay_id = channel.json_body.get("delay_id") + self.assertIsNotNone(delay_id) + + self.reactor.advance(1) + events = self._get_delayed_events() + self.assertEqual(1, len(events), events) + content = self._get_delayed_event_content(events[0]) + self.assertEqual(setter_expected, content.get(setter_key), content) + self.helper.get_state( + self.room_id, + _EVENT_TYPE, + "", + state_key=state_key, + expect_code=HTTPStatus.NOT_FOUND, + ) + + channel = self.make_request( + "POST", + f"{PATH_PREFIX}/{delay_id}", + {"action": "restart"}, + ) + self.assertEqual(HTTPStatus.OK, channel.code, channel.result) + + self.reactor.advance(1) + events = self._get_delayed_events() + self.assertEqual(1, len(events), events) + content = self._get_delayed_event_content(events[0]) + self.assertEqual(setter_expected, content.get(setter_key), content) + self.helper.get_state( + self.room_id, + _EVENT_TYPE, + "", + state_key=state_key, + expect_code=HTTPStatus.NOT_FOUND, + ) + + self.reactor.advance(1) + self.assertListEqual([], self._get_delayed_events()) + content = self.helper.get_state( + self.room_id, + _EVENT_TYPE, + "", + state_key=state_key, + ) + self.assertEqual(setter_expected, content.get(setter_key), content) + def test_delayed_state_events_are_cancelled_by_more_recent_state(self) -> None: state_key = "to_be_cancelled" @@ -112,7 +317,8 @@ def test_delayed_state_events_are_cancelled_by_more_recent_state(self) -> None: def _get_delayed_events(self) -> List[JsonDict]: channel = self.make_request( - "GET", b"/_matrix/client/unstable/org.matrix.msc4140/delayed_events" + "GET", + PATH_PREFIX, ) self.assertEqual(HTTPStatus.OK, channel.code, channel.result) From f32cf9c82813fb1316fc83a2ce3e4ff853329fcb Mon Sep 17 00:00:00 2001 From: Andrew Ferrazzutti Date: Mon, 16 Sep 2024 16:21:09 -0400 Subject: [PATCH 78/79] Run Complement tests Requires matrix-org/complement#734 --- docker/complement/conf/workers-shared-extra.yaml.j2 | 3 +++ scripts-dev/complement.sh | 1 + 2 files changed, 4 insertions(+) diff --git a/docker/complement/conf/workers-shared-extra.yaml.j2 b/docker/complement/conf/workers-shared-extra.yaml.j2 index 6588b3ce147..b9334cc53bb 100644 --- a/docker/complement/conf/workers-shared-extra.yaml.j2 +++ b/docker/complement/conf/workers-shared-extra.yaml.j2 @@ -111,6 +111,9 @@ server_notices: system_mxid_avatar_url: "" room_name: "Server Alert" +# Enable delayed events (msc4140) +max_event_delay_duration: 24h + # Disable sync cache so that initial `/sync` requests are up-to-date. 
caches: diff --git a/scripts-dev/complement.sh b/scripts-dev/complement.sh index 4ad547bc7e5..8fef1ae022f 100755 --- a/scripts-dev/complement.sh +++ b/scripts-dev/complement.sh @@ -223,6 +223,7 @@ test_packages=( ./tests/msc3930 ./tests/msc3902 ./tests/msc3967 + ./tests/msc4140 ) # Enable dirty runs, so tests will reuse the same container where possible. From dfde3c22db2490b8aa6b75813368d006ba32bf0c Mon Sep 17 00:00:00 2001 From: Andrew Ferrazzutti Date: Wed, 18 Sep 2024 14:52:51 -0400 Subject: [PATCH 79/79] Order returned delayed events by send_ts --- .../storage/databases/main/delayed_events.py | 32 +++++++++---------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/synapse/storage/databases/main/delayed_events.py b/synapse/storage/databases/main/delayed_events.py index 6ff337e511b..1a7781713f4 100644 --- a/synapse/storage/databases/main/delayed_events.py +++ b/synapse/storage/databases/main/delayed_events.py @@ -174,22 +174,22 @@ async def get_all_delayed_events_for_user( ) -> List[JsonDict]: """Returns all pending delayed events owned by the given user.""" # TODO: Support Pagination stream API ("next_batch" field) - rows = await self.db_pool.simple_select_list( - table="delayed_events", - keyvalues={ - "user_localpart": user_localpart, - "is_processed": False, - }, - retcols=( - "delay_id", - "room_id", - "event_type", - "state_key", - "delay", - "send_ts", - "content", - ), - desc="get_all_delayed_events_for_user", + rows = await self.db_pool.execute( + "get_all_delayed_events_for_user", + """ + SELECT + delay_id, + room_id, + event_type, + state_key, + delay, + send_ts, + content + FROM delayed_events + WHERE user_localpart = ? AND NOT is_processed + ORDER BY send_ts + """, + user_localpart, ) return [ {