Skip to content

Commit

Permalink
Parse json validation (#16923)
Browse files Browse the repository at this point in the history
Co-authored-by: Andrew Morgan <[email protected]>
  • Loading branch information
TrevisGordan and anoadragon453 authored Apr 18, 2024
1 parent 09f0957 commit 1d47532
Show file tree
Hide file tree
Showing 6 changed files with 220 additions and 47 deletions.
1 change: 1 addition & 0 deletions changelog.d/16923.bugfix
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Return `400 M_NOT_JSON` upon receiving invalid JSON in query parameters across various client and admin endpoints, rather than an internal server error.
82 changes: 82 additions & 0 deletions synapse/http/servlet.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

import enum
import logging
import urllib.parse as urlparse
from http import HTTPStatus
from typing import (
TYPE_CHECKING,
Expand Down Expand Up @@ -450,6 +451,87 @@ def parse_string(
)


def parse_json(
    request: Request,
    name: str,
    default: Optional[dict] = None,
    required: bool = False,
    encoding: str = "ascii",
) -> Optional[JsonDict]:
    """
    Read a JSON-encoded query parameter from a request.

    This is a convenience wrapper that pulls the query arguments off the
    request and hands them to `parse_json_from_args`.

    Args:
        request: the twisted HTTP request.
        name: the name of the query parameter.
        default: value returned when the parameter is absent and not
            required. Defaults to None.
        required: if True, an absent parameter results in a 400
            SynapseError rather than `default`. Defaults to False.
        encoding: the encoding used to decode the raw parameter bytes.

    Returns:
        The decoded JSON, or `default` if the named query parameter was not
        found and `required` was False.

    Raises:
        SynapseError: raised by `parse_json_from_args` when the parameter is
            absent and `required` is set, or when its value cannot be
            decoded as JSON.
    """
    query_args: Mapping[bytes, Sequence[bytes]] = request.args  # type: ignore
    return parse_json_from_args(
        query_args,
        name,
        default=default,
        required=required,
        encoding=encoding,
    )


def parse_json_from_args(
    args: Mapping[bytes, Sequence[bytes]],
    name: str,
    default: Optional[dict] = None,
    required: bool = False,
    encoding: str = "ascii",
) -> Optional[JsonDict]:
    """
    Parse a JSON object from a named query parameter.

    Args:
        args: a mapping of request args as bytes to a list of bytes
            (e.g. request.args).
        name: the name of the query parameter.
        default: value to use if the parameter is absent,
            defaults to None.
        required: whether to raise a 400 SynapseError if the
            parameter is absent, defaults to False.
        encoding: the encoding to decode the string content with.

    Returns:
        The decoded JSON object, or `default` if the named query parameter
        was not found and `required` was False.

    Raises:
        SynapseError if the parameter is absent and required (400,
        `Codes.MISSING_PARAM`), or if the parameter is present and is not a
        JSON object (400, `Codes.NOT_JSON`).
    """
    name_bytes = name.encode("ascii")

    if name_bytes not in args:
        if not required:
            return default

        message = f"Missing required JSON query parameter {name}"
        raise SynapseError(HTTPStatus.BAD_REQUEST, message, errcode=Codes.MISSING_PARAM)

    # Presence was checked above, so `required=True` here only guards the
    # string extraction itself.
    json_str = parse_string_from_args(args, name, required=True, encoding=encoding)

    message = f"Query parameter {name} must be a valid JSON object"

    # The value is unquoted before decoding, matching what the call sites
    # did before they were switched over to this helper.
    # NOTE(review): `args` is presumably already percent-decoded by the time
    # it gets here, making this a second unquote pass - confirm before
    # changing it.
    try:
        decoded = json_decoder.decode(urlparse.unquote(json_str))
    except Exception:
        # Any failure to decode client-supplied input is the client's fault,
        # so report it as a 400 rather than letting it surface as a 500.
        raise SynapseError(HTTPStatus.BAD_REQUEST, message, errcode=Codes.NOT_JSON)

    # Syntactically valid JSON that is not an object (e.g. a list or a bare
    # string) does not satisfy the declared return type, and callers treat
    # the result as a dict, so reject it here with the same error.
    if not isinstance(decoded, dict):
        raise SynapseError(HTTPStatus.BAD_REQUEST, message, errcode=Codes.NOT_JSON)

    return decoded


EnumT = TypeVar("EnumT", bound=enum.Enum)


Expand Down
36 changes: 12 additions & 24 deletions synapse/rest/admin/rooms.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
import logging
from http import HTTPStatus
from typing import TYPE_CHECKING, List, Optional, Tuple, cast
from urllib import parse as urlparse

import attr

Expand All @@ -38,6 +37,7 @@
assert_params_in_dict,
parse_enum,
parse_integer,
parse_json,
parse_json_object_from_request,
parse_string,
)
Expand All @@ -51,7 +51,6 @@
from synapse.streams.config import PaginationConfig
from synapse.types import JsonDict, RoomID, ScheduledTask, UserID, create_requester
from synapse.types.state import StateFilter
from synapse.util import json_decoder

if TYPE_CHECKING:
from synapse.api.auth import Auth
Expand Down Expand Up @@ -776,14 +775,8 @@ async def on_GET(
limit = parse_integer(request, "limit", default=10)

# picking the API shape for symmetry with /messages
filter_str = parse_string(request, "filter", encoding="utf-8")
if filter_str:
filter_json = urlparse.unquote(filter_str)
event_filter: Optional[Filter] = Filter(
self._hs, json_decoder.decode(filter_json)
)
else:
event_filter = None
filter_json = parse_json(request, "filter", encoding="utf-8")
event_filter = Filter(self._hs, filter_json) if filter_json else None

event_context = await self.room_context_handler.get_event_context(
requester,
Expand Down Expand Up @@ -914,21 +907,16 @@ async def on_GET(
)
# Twisted will have processed the args by now.
assert request.args is not None

filter_json = parse_json(request, "filter", encoding="utf-8")
event_filter = Filter(self._hs, filter_json) if filter_json else None

as_client_event = b"raw" not in request.args
filter_str = parse_string(request, "filter", encoding="utf-8")
if filter_str:
filter_json = urlparse.unquote(filter_str)
event_filter: Optional[Filter] = Filter(
self._hs, json_decoder.decode(filter_json)
)
if (
event_filter
and event_filter.filter_json.get("event_format", "client")
== "federation"
):
as_client_event = False
else:
event_filter = None
if (
event_filter
and event_filter.filter_json.get("event_format", "client") == "federation"
):
as_client_event = False

msgs = await self._pagination_handler.get_messages(
room_id=room_id,
Expand Down
35 changes: 12 additions & 23 deletions synapse/rest/client/room.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
parse_boolean,
parse_enum,
parse_integer,
parse_json,
parse_json_object_from_request,
parse_string,
parse_strings_from_args,
Expand All @@ -65,7 +66,6 @@
from synapse.streams.config import PaginationConfig
from synapse.types import JsonDict, Requester, StreamToken, ThirdPartyInstanceID, UserID
from synapse.types.state import StateFilter
from synapse.util import json_decoder
from synapse.util.cancellation import cancellable
from synapse.util.stringutils import parse_and_validate_server_name, random_string

Expand Down Expand Up @@ -703,21 +703,16 @@ async def on_GET(
)
# Twisted will have processed the args by now.
assert request.args is not None

filter_json = parse_json(request, "filter", encoding="utf-8")
event_filter = Filter(self._hs, filter_json) if filter_json else None

as_client_event = b"raw" not in request.args
filter_str = parse_string(request, "filter", encoding="utf-8")
if filter_str:
filter_json = urlparse.unquote(filter_str)
event_filter: Optional[Filter] = Filter(
self._hs, json_decoder.decode(filter_json)
)
if (
event_filter
and event_filter.filter_json.get("event_format", "client")
== "federation"
):
as_client_event = False
else:
event_filter = None
if (
event_filter
and event_filter.filter_json.get("event_format", "client") == "federation"
):
as_client_event = False

msgs = await self.pagination_handler.get_messages(
room_id=room_id,
Expand Down Expand Up @@ -898,14 +893,8 @@ async def on_GET(
limit = parse_integer(request, "limit", default=10)

# picking the API shape for symmetry with /messages
filter_str = parse_string(request, "filter", encoding="utf-8")
if filter_str:
filter_json = urlparse.unquote(filter_str)
event_filter: Optional[Filter] = Filter(
self._hs, json_decoder.decode(filter_json)
)
else:
event_filter = None
filter_json = parse_json(request, "filter", encoding="utf-8")
event_filter = Filter(self._hs, filter_json) if filter_json else None

event_context = await self.room_context_handler.get_event_context(
requester, room_id, event_id, limit, event_filter
Expand Down
61 changes: 61 additions & 0 deletions tests/rest/admin/test_room.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import json
import time
import urllib.parse
from http import HTTPStatus
from typing import List, Optional
from unittest.mock import AsyncMock, Mock

Expand Down Expand Up @@ -2190,6 +2191,33 @@ def test_room_messages_purge(self) -> None:
chunk = channel.json_body["chunk"]
self.assertEqual(len(chunk), 0, [event["content"] for event in chunk])

def test_room_message_filter_query_validation(self) -> None:
    """
    Check JSON validation of the `filter` query parameter on the admin
    room messages endpoint.

    Only the well-formedness of the JSON is exercised here, not whether the
    filter itself is valid.
    """
    # A well-formed JSON filter should be accepted: expect 200.
    valid_filter_str = '{"types": ["m.room.message"]}'
    path = f"/_synapse/admin/v1/rooms/{self.room_id}/messages?dir=b&filter={valid_filter_str}"
    channel = self.make_request("GET", path, access_token=self.admin_user_tok)
    self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body)

    # A malformed JSON filter should be rejected: expect 400 NOT_JSON.
    invalid_filter_str = "}}}{}"
    path = f"/_synapse/admin/v1/rooms/{self.room_id}/messages?dir=b&filter={invalid_filter_str}"
    channel = self.make_request("GET", path, access_token=self.admin_user_tok)
    self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.json_body)

    errcode = channel.json_body["errcode"]
    self.assertEqual(errcode, Codes.NOT_JSON, channel.json_body)


class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
servlets = [
Expand Down Expand Up @@ -2522,6 +2550,39 @@ def test_context_as_admin(self) -> None:
else:
self.fail("Event %s from events_after not found" % j)

def test_room_event_context_filter_query_validation(self) -> None:
    """
    Check JSON validation of the `filter` query parameter on the admin
    event context endpoint.

    Only the well-formedness of the JSON is exercised here, not whether the
    filter itself is valid.
    """
    # Set up a user who owns a room containing a single message event.
    user_id = self.register_user("test", "test")
    user_tok = self.login("test", "test")
    room_id = self.helper.create_room_as(user_id, tok=user_tok)
    send_result = self.helper.send(room_id, "message 1", tok=user_tok)
    event_id = send_result["event_id"]

    # A well-formed JSON filter should be accepted: expect 200.
    valid_filter_str = '{"types": ["m.room.message"]}'
    path = f"/_synapse/admin/v1/rooms/{room_id}/context/{event_id}?filter={valid_filter_str}"
    channel = self.make_request("GET", path, access_token=self.admin_user_tok)
    self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body)

    # A malformed JSON filter should be rejected: expect 400 NOT_JSON.
    invalid_filter_str = "}}}{}"
    path = f"/_synapse/admin/v1/rooms/{room_id}/context/{event_id}?filter={invalid_filter_str}"
    channel = self.make_request("GET", path, access_token=self.admin_user_tok)
    self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.json_body)

    errcode = channel.json_body["errcode"]
    self.assertEqual(errcode, Codes.NOT_JSON, channel.json_body)


class MakeRoomAdminTestCase(unittest.HomeserverTestCase):
servlets = [
Expand Down
52 changes: 52 additions & 0 deletions tests/rest/client/test_rooms.py
Original file line number Diff line number Diff line change
Expand Up @@ -2175,6 +2175,31 @@ def test_room_messages_purge(self) -> None:
chunk = channel.json_body["chunk"]
self.assertEqual(len(chunk), 0, [event["content"] for event in chunk])

def test_room_message_filter_query_validation(self) -> None:
    """
    Check JSON validation of the `filter` query parameter on the client
    room messages endpoint.

    Only the well-formedness of the JSON is exercised here, not whether the
    filter itself is valid.
    """
    # A well-formed JSON filter should be accepted: expect 200.
    valid_filter_str = '{"types": ["m.room.message"]}'
    path = f"/rooms/{self.room_id}/messages?access_token=x&dir=b&filter={valid_filter_str}"
    channel = self.make_request("GET", path)
    self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body)

    # A malformed JSON filter should be rejected: expect 400 NOT_JSON.
    invalid_filter_str = "}}}{}"
    path = f"/rooms/{self.room_id}/messages?access_token=x&dir=b&filter={invalid_filter_str}"
    channel = self.make_request("GET", path)
    self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.json_body)

    errcode = channel.json_body["errcode"]
    self.assertEqual(errcode, Codes.NOT_JSON, channel.json_body)


class RoomMessageFilterTestCase(RoomBase):
"""Tests /rooms/$room_id/messages REST events."""
Expand Down Expand Up @@ -3213,6 +3238,33 @@ def test_erased_sender(self) -> None:
self.assertDictEqual(events_after[0].get("content"), {}, events_after[0])
self.assertEqual(events_after[1].get("content"), {}, events_after[1])

def test_room_event_context_filter_query_validation(self) -> None:
    """
    Check JSON validation of the `filter` query parameter on the client
    event context endpoint.

    Only the well-formedness of the JSON is exercised here, not whether the
    filter itself is valid.
    """
    # Send a message so there is an event to request context for.
    send_result = self.helper.send(self.room_id, "message 7", tok=self.tok)
    event_id = send_result["event_id"]

    # A well-formed JSON filter should be accepted: expect 200.
    valid_filter_str = '{"types": ["m.room.message"]}'
    path = f"/rooms/{self.room_id}/context/{event_id}?filter={valid_filter_str}"
    channel = self.make_request("GET", path, access_token=self.tok)
    self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body)

    # A malformed JSON filter should be rejected: expect 400 NOT_JSON.
    invalid_filter_str = "}}}{}"
    path = f"/rooms/{self.room_id}/context/{event_id}?filter={invalid_filter_str}"
    channel = self.make_request("GET", path, access_token=self.tok)
    self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.json_body)

    errcode = channel.json_body["errcode"]
    self.assertEqual(errcode, Codes.NOT_JSON, channel.json_body)


class RoomAliasListTestCase(unittest.HomeserverTestCase):
servlets = [
Expand Down

0 comments on commit 1d47532

Please sign in to comment.