Skip to content

Commit

Permalink
Add serialization registry
Browse files Browse the repository at this point in the history
  • Loading branch information
sevdog committed Aug 21, 2024
1 parent 13cef45 commit b0173a2
Show file tree
Hide file tree
Showing 5 changed files with 255 additions and 61 deletions.
39 changes: 39 additions & 0 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,7 @@ to 10, and all ``websocket.send!`` channels to 20:
If you want to enforce a matching order, use an ``OrderedDict`` as the
argument; channels will then be matched in the order the dict provides them.

.. _encryption
``symmetric_encryption_keys``
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Expand Down Expand Up @@ -237,6 +238,44 @@ And then in your channels consumer, you can implement the handler:
async def redis_disconnect(self, *args):
# Handle disconnect
``serializer_format``
~~~~~~~~~~~~~~~~~~~~~~
By default every message which reach redis is encoded using `msgpack <https://msgpack.org/>`_.
It is also possible to switch to `JSON <http://www.json.org/>`_:

.. code-block:: python
CHANNEL_LAYERS = {
"default": {
"BACKEND": "channels_redis.core.RedisChannelLayer",
"CONFIG": {
"hosts": ["redis://:[email protected]:6379/0"],
"serializer_format": "json",
},
},
}
A new serializer may be registered (or can be overriden) by using ``channels_redis.serializers.registry``,
providing a class which extends ``channels_redis.serializers.BaseMessageSerializer``, implementing ``dumps``
and ``loads`` methods, or which provides ``serialize``/``deserialize`` methods and calling the registration method on registry:

.. code-block:: python
from channels_redis.serializers import registry
class MyFormatSerializer:
def serialize(self, message):
...
def deserialize(self, message):
...
registry.register_serializer('myformat', MyFormatSerializer)
**NOTE**: Serializers also perform the encryption job see *symmetric_encryption_keys*.


Dependencies
------------

Expand Down
59 changes: 12 additions & 47 deletions channels_redis/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,15 @@
import hashlib
import itertools
import logging
import random
import time
import uuid

import msgpack
from redis import asyncio as aioredis

from channels.exceptions import ChannelFull
from channels.layers import BaseChannelLayer

from .serializers import registry
from .utils import (
_close_redis,
_consistent_hash,
Expand Down Expand Up @@ -115,6 +114,7 @@ def __init__(
capacity=100,
channel_capacity=None,
symmetric_encryption_keys=None,
serializer_format="msgpack",
):
# Store basic information
self.expiry = expiry
Expand All @@ -126,15 +126,23 @@ def __init__(
# Configure the host objects
self.hosts = decode_hosts(hosts)
self.ring_size = len(self.hosts)
# serialization
self._serializer = registry.get_serializer(
serializer_format,
# As we use an sorted set to expire messages we need to guarantee uniqueness, with 12 bytes.
random_prefix_length=12,
expiry=self.expiry,
symmetric_encryption_keys=symmetric_encryption_keys,
)
self.serialize = self._serializer.serialize
self.deserialize = self._serializer.deserialize
# Cached redis connection pools and the event loop they are from
self._layers = {}
# Normal channels choose a host index by cycling through the available hosts
self._receive_index_generator = itertools.cycle(range(len(self.hosts)))
self._send_index_generator = itertools.cycle(range(len(self.hosts)))
# Decide on a unique client prefix to use in ! sections
self.client_prefix = uuid.uuid4().hex
# Set up any encryption objects
self._setup_encryption(symmetric_encryption_keys)
# Number of coroutines trying to receive right now
self.receive_count = 0
# The receive lock
Expand All @@ -154,24 +162,6 @@ def __init__(
def create_pool(self, index):
return create_pool(self.hosts[index])

def _setup_encryption(self, symmetric_encryption_keys):
# See if we can do encryption if they asked
if symmetric_encryption_keys:
if isinstance(symmetric_encryption_keys, (str, bytes)):
raise ValueError(
"symmetric_encryption_keys must be a list of possible keys"
)
try:
from cryptography.fernet import MultiFernet
except ImportError:
raise ValueError(
"Cannot run with encryption without 'cryptography' installed."
)
sub_fernets = [self.make_fernet(key) for key in symmetric_encryption_keys]
self.crypter = MultiFernet(sub_fernets)
else:
self.crypter = None

### Channel layer API ###

extensions = ["groups", "flush"]
Expand Down Expand Up @@ -650,31 +640,6 @@ def _group_key(self, group):
"""
return f"{self.prefix}:group:{group}".encode("utf8")

### Serialization ###

def serialize(self, message):
"""
Serializes message to a byte string.
"""
value = msgpack.packb(message, use_bin_type=True)
if self.crypter:
value = self.crypter.encrypt(value)

# As we use an sorted set to expire messages we need to guarantee uniqueness, with 12 bytes.
random_prefix = random.getrandbits(8 * 12).to_bytes(12, "big")
return random_prefix + value

def deserialize(self, message):
"""
Deserializes from a byte string.
"""
# Removes the random prefix
message = message[12:]

if self.crypter:
message = self.crypter.decrypt(message, self.expiry + 10)
return msgpack.unpackb(message, raw=False)

### Internal functions ###

def consistent_hash(self, value):
Expand Down
29 changes: 15 additions & 14 deletions channels_redis/pubsub.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
import logging
import uuid

import msgpack
from redis import asyncio as aioredis

from .serializers import registry
from .utils import (
_close_redis,
_consistent_hash,
Expand All @@ -25,10 +25,23 @@ async def _async_proxy(obj, name, *args, **kwargs):


class RedisPubSubChannelLayer:
def __init__(self, *args, **kwargs) -> None:
def __init__(
self,
*args,
symmetric_encryption_keys=None,
serializer_format="msgpack",
**kwargs,
) -> None:
self._args = args
self._kwargs = kwargs
self._layers = {}
# serialization
self._serializer = registry.get_serializer(
serializer_format,
symmetric_encryption_keys=symmetric_encryption_keys,
)
self.serialize = self._serializer.serialize
self.deserialize = self._serializer.deserialize

def __getattr__(self, name):
if name in (
Expand All @@ -44,18 +57,6 @@ def __getattr__(self, name):
else:
return getattr(self._get_layer(), name)

def serialize(self, message):
"""
Serializes message to a byte string.
"""
return msgpack.packb(message)

def deserialize(self, message):
"""
Deserializes from a byte string.
"""
return msgpack.unpackb(message)

def _get_layer(self):
loop = asyncio.get_running_loop()

Expand Down
141 changes: 141 additions & 0 deletions channels_redis/serializers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
import json
import random
import abc


class SerializerDoesNotExist(KeyError):
"""The requested serializer was not found."""


class BaseMessageSerializer(abc.ABC):

def __init__(
self,
symmetric_encryption_keys=None,
random_prefix_length=0,
expiry=None,
):
self.random_prefix_length = random_prefix_length
self.expiry = expiry
# Set up any encryption objects
self._setup_encryption(symmetric_encryption_keys)

def _setup_encryption(self, symmetric_encryption_keys):
# See if we can do encryption if they asked
if symmetric_encryption_keys:
if isinstance(symmetric_encryption_keys, (str, bytes)):
raise ValueError(
"symmetric_encryption_keys must be a list of possible keys"
)
try:
from cryptography.fernet import MultiFernet
except ImportError:
raise ValueError(
"Cannot run with encryption without 'cryptography' installed."
)
sub_fernets = [self.make_fernet(key) for key in symmetric_encryption_keys]
self.crypter = MultiFernet(sub_fernets)
else:
self.crypter = None

@abc.abstractmethod
def dumps(self, message):
raise NotImplementedError

@abc.abstractmethod
def loads(self, message):
raise NotImplementedError

def serialize(self, message):
"""
Serializes message to a byte string.
"""
message = self.dumps(message)
# ensure message is bytes
if isinstance(message, str):
message = message.encode("utf-8")
if self.crypter:
message = self.crypter.encrypt(message)

if self.random_prefix_length > 0:
# provide random prefix
message = (
random.getrandbits(8 * self.random_prefix_length).to_bytes(
self.random_prefix_length, "big"
)
+ message
)
return message

def deserialize(self, message):
"""
Deserializes from a byte string.
"""
if self.random_prefix_length > 0:
# Removes the random prefix
message = message[self.random_prefix_length :] # noqa: E203

if self.crypter:
ttl = self.expiry if self.expiry is None else self.expiry + 10
message = self.crypter.decrypt(message, ttl)
return self.loads(message)


class MissingSerializer(BaseMessageSerializer):
exception = None

def __init__(self, *args, **kwargs):
raise self.exception


class JSONSerializer(BaseMessageSerializer):
dumps = staticmethod(json.dumps)
loads = staticmethod(json.loads)


# code ready for a future in which msgpack may become an optional dependency
try:
import msgpack
except ImportError as exc:

class MsgPackSerializer(MissingSerializer):
exception = exc

else:

class MsgPackSerializer(BaseMessageSerializer):
dumps = staticmethod(msgpack.packb)
loads = staticmethod(msgpack.unpackb)


class SerializersRegistry:
def __init__(self):
self._registry = {}

def register_serializer(self, format, serializer_class):
"""
Register a new serializer for given format
"""
assert isinstance(serializer_class, type) and (
issubclass(serializer_class, BaseMessageSerializer)
or hasattr(serializer_class, "serialize")
and hasattr(serializer_class, "deserialize")
), """
`serializer_class` should be a class which implements `serialize` and `deserialize` method
or a subclass of `channels_redis.serializers.BaseMessageSerializer`
"""

self._registry[format] = serializer_class

def get_serializer(self, format, *args, **kwargs):
try:
serializer_class = self._registry[format]
except KeyError:
raise SerializerDoesNotExist(format)

return serializer_class(*args, **kwargs)


registry = SerializersRegistry()
registry.register_serializer("json", JSONSerializer)
registry.register_serializer("msgpack", MsgPackSerializer)
Loading

0 comments on commit b0173a2

Please sign in to comment.