Skip to content

Commit

Permalink
Replace Ypy with pycrdt
Browse files Browse the repository at this point in the history
  • Loading branch information
davidbrochart committed Oct 13, 2023
1 parent 1167c29 commit 0881506
Show file tree
Hide file tree
Showing 20 changed files with 1,941 additions and 17 deletions.
22 changes: 11 additions & 11 deletions plugins/yjs/fps_yjs/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,17 @@
WebSocketDisconnect,
status,
)
from jupyter_ydoc import ydocs as YDOCS
from jupyter_ydoc.ybasedoc import YBaseDoc
from .ydocs import ydocs as YDOCS
from .ydocs.ybasedoc import YBaseDoc
from jupyverse_api.app import App
from jupyverse_api.auth import Auth, User
from jupyverse_api.contents import Contents
from jupyverse_api.yjs import Yjs
from jupyverse_api.yjs.models import CreateDocumentSession
from websockets.exceptions import ConnectionClosedOK
from ypy_websocket.websocket_server import WebsocketServer, YRoom
from ypy_websocket.ystore import SQLiteYStore, YDocNotFound
from ypy_websocket.yutils import YMessageType, YSyncMessageType
from .ywebsocket.websocket_server import WebsocketServer, YRoom
from .ywebsocket.ystore import SQLiteYStore, YDocNotFound
from .ywebsocket.yutils import YMessageType, YSyncMessageType

YFILE = YDOCS["file"]
AWARENESS = 1
Expand Down Expand Up @@ -56,8 +56,8 @@ async def collaboration_room_websocket(
return
websocket, permissions = websocket_permissions
await websocket.accept()
ypy_websocket = YpyWebsocket(websocket, path)
await self.room_manager.serve(ypy_websocket, permissions)
ywebsocket = YWebsocket(websocket, path)
await self.room_manager.serve(ywebsocket, permissions)

async def create_roomid(
self,
Expand Down Expand Up @@ -95,8 +95,8 @@ def to_datetime(iso_date: str) -> datetime:
return datetime.fromisoformat(iso_date.rstrip("Z"))


class YpyWebsocket:
"""An wrapper to make a Starlette's WebSocket look like a ypy-websocket's WebSocket"""
class YWebsocket:
"""An wrapper to make a Starlette's WebSocket look like a ywebsocket's WebSocket"""

def __init__(self, websocket, path: str):
self._websocket = websocket
Expand Down Expand Up @@ -160,7 +160,7 @@ def stop(self):
cleaner.cancel()
self.websocket_server.stop()

async def serve(self, websocket: YpyWebsocket, permissions) -> None:
async def serve(self, websocket: YWebsocket, permissions) -> None:
room = await self.websocket_server.get_room(websocket.path)
can_write = permissions is None or "write" in permissions.get("yjs", [])
room.on_message = partial(self.filter_message, can_write)
Expand Down Expand Up @@ -309,7 +309,7 @@ async def maybe_save_document(
# if the room cannot be found, don't save
try:
file_path = await self.get_file_path(file_id, document)
except BaseException:
except Exception:
return
assert file_path is not None
async with self.lock:
Expand Down
9 changes: 9 additions & 0 deletions plugins/yjs/fps_yjs/ydocs/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
import sys


if sys.version_info < (3, 10):
from importlib_metadata import entry_points
else:
from importlib.metadata import entry_points

ydocs = {ep.name: ep.load() for ep in entry_points(group="jupyverse_ydoc")}
26 changes: 26 additions & 0 deletions plugins/yjs/fps_yjs/ydocs/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from typing import Dict, List, Type, Union

INT = Type[int]
FLOAT = Type[float]


def cast_all(
o: Union[List, Dict], from_type: Union[INT, FLOAT], to_type: Union[FLOAT, INT]
) -> Union[List, Dict]:
if isinstance(o, list):
for i, v in enumerate(o):
if type(v) is from_type:
v2 = to_type(v)
if v == v2:
o[i] = v2
elif isinstance(v, (list, dict)):
cast_all(v, from_type, to_type)
elif isinstance(o, dict):
for k, v in o.items():
if type(v) is from_type:
v2 = to_type(v)
if v == v2:
o[k] = v2
elif isinstance(v, (list, dict)):
cast_all(v, from_type, to_type)
return o
69 changes: 69 additions & 0 deletions plugins/yjs/fps_yjs/ydocs/ybasedoc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, Optional

from pycrdt import Doc, Map


class YBaseDoc(ABC):
def __init__(self, ydoc: Optional[Doc] = None):
if ydoc is None:
self._ydoc = Doc()
else:
self._ydoc = ydoc
self._ystate = Map()
self._ydoc["state"] = self._ystate
self._subscriptions: Dict[Any, str] = {}

@property
@abstractmethod
def version(self) -> str:
...

@property
def ystate(self) -> Map:
return self._ystate

@property
def ydoc(self) -> Doc:
return self._ydoc

@property
def source(self) -> Any:
return self.get()

@source.setter
def source(self, value: Any):
return self.set(value)

@property
def dirty(self) -> Optional[bool]:
return self._ystate.get("dirty")

@dirty.setter
def dirty(self, value: bool) -> None:
self._ystate["dirty"] = value

@property
def path(self) -> Optional[str]:
return self._ystate.get("path")

@path.setter
def path(self, value: str) -> None:
self._ystate["path"] = value

@abstractmethod
def get(self) -> Any:
...

@abstractmethod
def set(self, value: Any) -> None:
...

@abstractmethod
def observe(self, callback: Callable[[str, Any], None]) -> None:
...

def unobserve(self) -> None:
for k, v in self._subscriptions.items():
k.unobserve(v)
self._subscriptions = {}
39 changes: 39 additions & 0 deletions plugins/yjs/fps_yjs/ydocs/yblob.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import base64
from functools import partial
from typing import Any, Callable, Optional, Union

from pycrdt import Doc, Map

from .ybasedoc import YBaseDoc


class YBlob(YBaseDoc):
"""
Extends :class:`YBaseDoc`, and represents a blob document.
It is currently encoded as base64 because of:
https://github.com/y-crdt/ypy/issues/108#issuecomment-1377055465
The Y document can be set from bytes or from str, in which case it is assumed to be encoded as
base64.
"""

def __init__(self, ydoc: Optional[Doc] = None):
super().__init__(ydoc)
self._ysource = Map()
self._ydoc["source"] = self._ysource

@property
def version(self) -> str:
return "1.0.0"

def get(self) -> bytes:
return base64.b64decode(self._ysource["base64"].encode())

def set(self, value: Union[bytes, str]) -> None:
if isinstance(value, bytes):
value = base64.b64encode(value).decode()
self._ysource["base64"] = value

def observe(self, callback: Callable[[str, Any], None]) -> None:
self.unobserve()
self._subscriptions[self._ystate] = self._ystate.observe(partial(callback, "state"))
self._subscriptions[self._ysource] = self._ysource.observe(partial(callback, "source"))
5 changes: 5 additions & 0 deletions plugins/yjs/fps_yjs/ydocs/yfile.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from .yunicode import YUnicode


class YFile(YUnicode): # for backwards-compatibility
pass
144 changes: 144 additions & 0 deletions plugins/yjs/fps_yjs/ydocs/ynotebook.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
import copy
import json
from functools import partial
from typing import Any, Callable, Dict, Optional
from uuid import uuid4

from pycrdt import Array, Doc, Map, Text

from .utils import cast_all
from .ybasedoc import YBaseDoc

# The default major version of the notebook format.
NBFORMAT_MAJOR_VERSION = 4
# The default minor version of the notebook format.
NBFORMAT_MINOR_VERSION = 5


class YNotebook(YBaseDoc):
def __init__(self, ydoc: Optional[Doc] = None):
super().__init__(ydoc)
self._ymeta = Map()
self._ycells = Array()
self._ydoc["meta"] = self._ymeta
self._ydoc["cells"] = self._ycells

@property
def version(self) -> str:
return "1.0.0"

@property
def ycells(self):
return self._ycells

@property
def cell_number(self) -> int:
return len(self._ycells)

def get_cell(self, index: int) -> Dict[str, Any]:
meta = json.loads(str(self._ymeta))
cell = json.loads(str(self._ycells[index]))
cast_all(cell, float, int) # cells coming from Yjs have e.g. execution_count as float
if "id" in cell and meta["nbformat"] == 4 and meta["nbformat_minor"] <= 4:
# strip cell IDs if we have notebook format 4.0-4.4
del cell["id"]
if (
"attachments" in cell
and cell["cell_type"] in ("raw", "markdown")
and not cell["attachments"]
):
del cell["attachments"]
return cell

def append_cell(self, value: Dict[str, Any]) -> None:
ycell = self.create_ycell(value)
self._ycells.append(ycell)

def set_cell(self, index: int, value: Dict[str, Any]) -> None:
ycell = self.create_ycell(value)
self.set_ycell(index, ycell)

def create_ycell(self, value: Dict[str, Any]) -> Map:
cell = copy.deepcopy(value)
if "id" not in cell:
cell["id"] = str(uuid4())
cell_type = cell["cell_type"]
cell_source = cell["source"]
cell_source = "".join(cell_source) if isinstance(cell_source, list) else cell_source
cell["source"] = Text(cell_source)
cell["metadata"] = Map(cell.get("metadata", {}))

if cell_type in ("raw", "markdown"):
if "attachments" in cell and not cell["attachments"]:
del cell["attachments"]
elif cell_type == "code":
cell["outputs"] = Array(cell.get("outputs", []))

return Map(cell)

def set_ycell(self, index: int, ycell: Map) -> None:
self._ycells[index] = ycell

def get(self) -> Dict:
meta = json.loads(str(self._ymeta))
cast_all(meta, float, int) # notebook coming from Yjs has e.g. nbformat as float
cells = []
for i in range(len(self._ycells)):
cell = self.get_cell(i)
if "id" in cell and meta["nbformat"] == 4 and meta["nbformat_minor"] <= 4:
# strip cell IDs if we have notebook format 4.0-4.4
del cell["id"]
if (
"attachments" in cell
and cell["cell_type"] in ["raw", "markdown"]
and not cell["attachments"]
):
del cell["attachments"]
cells.append(cell)

return dict(
cells=cells,
metadata=meta.get("metadata", {}),
nbformat=int(meta.get("nbformat", 0)),
nbformat_minor=int(meta.get("nbformat_minor", 0)),
)

def set(self, value: Dict) -> None:
nb_without_cells = {key: value[key] for key in value.keys() if key != "cells"}
nb = copy.deepcopy(nb_without_cells)
cast_all(nb, int, float) # Yjs expects numbers to be floating numbers
cells = value["cells"] or [
{
"cell_type": "code",
"execution_count": None,
# auto-created empty code cell without outputs ought be trusted
"metadata": {"trusted": True},
"outputs": [],
"source": "",
"id": str(uuid4()),
}
]

with self._ydoc.transaction():
# clear document
self._ymeta.clear()
self._ycells.clear()
for key in [k for k in self._ystate.keys() if k not in ("dirty", "path")]:
del self._ystate[key]

# initialize document
self._ycells.extend([self.create_ycell(cell) for cell in cells])
self._ymeta["nbformat"] = nb.get("nbformat", NBFORMAT_MAJOR_VERSION)
self._ymeta["nbformat_minor"] = nb.get("nbformat_minor", NBFORMAT_MINOR_VERSION)

metadata = nb.get("metadata", {})
metadata.setdefault("language_info", {"name": ""})
metadata.setdefault("kernelspec", {"name": "", "display_name": ""})

self._ymeta["metadata"] = Map(metadata)

def observe(self, callback: Callable[[str, Any], None]) -> None:
self.unobserve()
self._subscriptions[self._ystate] = self._ystate.observe(partial(callback, "state"))
self._subscriptions[self._ymeta] = self._ymeta.observe_deep(partial(callback, "meta"))
self._subscriptions[self._ycells] = self._ycells.observe_deep(partial(callback, "cells"))
33 changes: 33 additions & 0 deletions plugins/yjs/fps_yjs/ydocs/yunicode.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from functools import partial
from typing import Any, Callable, Optional

from pycrdt import Doc, Text

from .ybasedoc import YBaseDoc


class YUnicode(YBaseDoc):
def __init__(self, ydoc: Optional[Doc] = None):
super().__init__(ydoc)
self._ysource = Text()
self._ydoc["source"] = self._ysource

@property
def version(self) -> str:
return "1.0.0"

def get(self) -> str:
return str(self._ysource)

def set(self, value: str) -> None:
with self._ydoc.transaction():
# clear document
del self._ysource[:]
# initialize document
if value:
self._ysource += value

def observe(self, callback: Callable[[str, Any], None]) -> None:
self.unobserve()
self._subscriptions[self._ystate] = self._ystate.observe(partial(callback, "state"))
self._subscriptions[self._ysource] = self._ysource.observe(partial(callback, "source"))
4 changes: 4 additions & 0 deletions plugins/yjs/fps_yjs/ywebsocket/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from .asgi_server import ASGIServer as ASGIServer
from .websocket_provider import WebsocketProvider as WebsocketProvider
from .websocket_server import WebsocketServer as WebsocketServer, YRoom as YRoom
from .yutils import YMessageType as YMessageType
Loading

0 comments on commit 0881506

Please sign in to comment.