Skip to content

Commit

Permalink
Implement server-side ypywidgets rendering
Browse files Browse the repository at this point in the history
  • Loading branch information
davidbrochart committed Nov 6, 2023
1 parent db6bff4 commit c5dbc6f
Show file tree
Hide file tree
Showing 13 changed files with 491 additions and 55 deletions.
141 changes: 101 additions & 40 deletions plugins/kernels/fps_kernels/kernel_driver/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@
import uuid
from typing import Any, Dict, List, Optional, cast

from pycrdt import Array, Map

from jupyverse_api.yjs import Yjs

from .connect import cfg_t, connect_channel, launch_kernel, read_connection_file
from .connect import write_connection_file as _write_connection_file
from .kernelspec import find_kernelspec
Expand All @@ -23,10 +27,12 @@ def __init__(
connection_file: str = "",
write_connection_file: bool = True,
capture_kernel_output: bool = True,
yjs: Optional[Yjs] = None,
) -> None:
self.capture_kernel_output = capture_kernel_output
self.kernelspec_path = kernelspec_path or find_kernelspec(kernel_name)
self.kernel_cwd = kernel_cwd
self.yjs = yjs
if not self.kernelspec_path:
raise RuntimeError("Could not find a kernel, maybe you forgot to install one?")
if write_connection_file:
Expand All @@ -37,11 +43,12 @@ def __init__(
self.key = cast(str, self.connection_cfg["key"])
self.session_id = uuid.uuid4().hex
self.msg_cnt = 0
self.execute_requests: Dict[str, Dict[str, asyncio.Future]] = {}
self.channel_tasks: List[asyncio.Task] = []
self.execute_requests: Dict[str, Dict[str, asyncio.Queue]] = {}
self.comm_messages = asyncio.Queue()
self.tasks: List[asyncio.Task] = []

async def restart(self, startup_timeout: float = float("inf")) -> None:
for task in self.channel_tasks:
for task in self.tasks:
task.cancel()
msg = create_message("shutdown_request", content={"restart": True})
await send_message(msg, self.control_channel, self.key, change_date_to_str=True)
Expand All @@ -52,7 +59,7 @@ async def restart(self, startup_timeout: float = float("inf")) -> None:
if msg["msg_type"] == "shutdown_reply" and msg["content"]["restart"]:
break
await self._wait_for_ready(startup_timeout)
self.channel_tasks = []
self.tasks = []
self.listen_channels()

async def start(self, startup_timeout: float = float("inf"), connect: bool = True) -> None:
Expand All @@ -69,6 +76,7 @@ async def connect(self, startup_timeout: float = float("inf")) -> None:
self.connect_channels()
await self._wait_for_ready(startup_timeout)
self.listen_channels()
self.tasks.append(asyncio.create_task(self._handle_comms()))

def connect_channels(self, connection_cfg: Optional[cfg_t] = None):
connection_cfg = connection_cfg or self.connection_cfg
Expand All @@ -77,40 +85,42 @@ def connect_channels(self, connection_cfg: Optional[cfg_t] = None):
self.iopub_channel = connect_channel("iopub", connection_cfg)

def listen_channels(self):
self.channel_tasks.append(asyncio.create_task(self.listen_iopub()))
self.channel_tasks.append(asyncio.create_task(self.listen_shell()))
self.tasks.append(asyncio.create_task(self.listen_iopub()))
self.tasks.append(asyncio.create_task(self.listen_shell()))

async def stop(self) -> None:
self.kernel_process.kill()
await self.kernel_process.wait()
os.remove(self.connection_file_path)
for task in self.channel_tasks:
for task in self.tasks:
task.cancel()

async def listen_iopub(self):
while True:
msg = await receive_message(self.iopub_channel, change_str_to_date=True)
msg_id = msg["parent_header"].get("msg_id")
if msg_id in self.execute_requests.keys():
self.execute_requests[msg_id]["iopub_msg"].set_result(msg)
parent_id = msg["parent_header"].get("msg_id")
if msg["msg_type"] in ("comm_open", "comm_msg"):
self.comm_messages.put_nowait(msg)
elif parent_id in self.execute_requests.keys():
self.execute_requests[parent_id]["iopub_msg"].put_nowait(msg)

async def listen_shell(self):
while True:
msg = await receive_message(self.shell_channel, change_str_to_date=True)
msg_id = msg["parent_header"].get("msg_id")
if msg_id in self.execute_requests.keys():
self.execute_requests[msg_id]["shell_msg"].set_result(msg)
self.execute_requests[msg_id]["shell_msg"].put_nowait(msg)

async def execute(
self,
cell: Dict[str, Any],
ycell: Map,
timeout: float = float("inf"),
msg_id: str = "",
wait_for_executed: bool = True,
) -> None:
if cell["cell_type"] != "code":
if ycell["cell_type"] != "code":
return
content = {"code": cell["source"], "silent": False}
content = {"code": str(ycell["source"]), "silent": False}
msg = create_message(
"execute_request", content, session_id=self.session_id, msg_id=str(self.msg_cnt)
)
Expand All @@ -120,40 +130,61 @@ async def execute(
msg_id = msg["header"]["msg_id"]
self.msg_cnt += 1
await send_message(msg, self.shell_channel, self.key, change_date_to_str=True)
self.execute_requests[msg_id] = {
"iopub_msg": asyncio.Queue(),
"shell_msg": asyncio.Queue(),
}
if wait_for_executed:
deadline = time.time() + timeout
self.execute_requests[msg_id] = {
"iopub_msg": asyncio.Future(),
"shell_msg": asyncio.Future(),
}
while True:
try:
await asyncio.wait_for(
self.execute_requests[msg_id]["iopub_msg"],
msg = await asyncio.wait_for(
self.execute_requests[msg_id]["iopub_msg"].get(),
deadline_to_timeout(deadline),
)
except asyncio.TimeoutError:
error_message = f"Kernel didn't respond in {timeout} seconds"
raise RuntimeError(error_message)
msg = self.execute_requests[msg_id]["iopub_msg"].result()
self._handle_outputs(cell["outputs"], msg)
await self._handle_outputs(ycell["outputs"], msg)
if (
msg["header"]["msg_type"] == "status"
and msg["content"]["execution_state"] == "idle"
(msg["header"]["msg_type"] == "status"
and msg["content"]["execution_state"] == "idle")
):
break
self.execute_requests[msg_id]["iopub_msg"] = asyncio.Future()
try:
await asyncio.wait_for(
self.execute_requests[msg_id]["shell_msg"],
msg = await asyncio.wait_for(
self.execute_requests[msg_id]["shell_msg"].get(),
deadline_to_timeout(deadline),
)
except asyncio.TimeoutError:
error_message = f"Kernel didn't respond in {timeout} seconds"
raise RuntimeError(error_message)
msg = self.execute_requests[msg_id]["shell_msg"].result()
cell["execution_count"] = msg["content"]["execution_count"]
ycell["execution_count"] = msg["content"]["execution_count"]
del self.execute_requests[msg_id]
else:
self.tasks.append(asyncio.create_task(self._handle_iopub(msg_id, ycell)))

async def _handle_iopub(self, msg_id: str, ycell: Map) -> None:
while True:
msg = await self.execute_requests[msg_id]["iopub_msg"].get()
await self._handle_outputs(ycell["outputs"], msg)
if (
(msg["header"]["msg_type"] == "status"
and msg["content"]["execution_state"] == "idle")
):
msg = await self.execute_requests[msg_id]["shell_msg"].get()
ycell["execution_count"] = msg["content"]["execution_count"]

async def _handle_comms(self) -> None:
while True:
msg = await self.comm_messages.get()
msg_type = msg["header"]["msg_type"]
if msg_type == "comm_open":
comm_id = msg["content"]["comm_id"]
comm = Comm(comm_id, self.shell_channel, self.session_id, self.key)
self.yjs.widgets.comm_open(msg, comm)
elif msg_type == "comm_msg":
self.yjs.widgets.comm_msg(msg)

async def _wait_for_ready(self, timeout):
deadline = time.time() + timeout
Expand All @@ -178,22 +209,32 @@ async def _wait_for_ready(self, timeout):
break
new_timeout = deadline_to_timeout(deadline)

def _handle_outputs(self, outputs: List[Dict[str, Any]], msg: Dict[str, Any]):
async def _handle_outputs(self, outputs: Array, msg: Dict[str, Any]):
msg_type = msg["header"]["msg_type"]
content = msg["content"]
if msg_type == "stream":
if (not outputs) or (outputs[-1]["name"] != content["name"]):
outputs.append({"name": content["name"], "output_type": msg_type, "text": []})
outputs[-1]["text"].append(content["text"])
elif msg_type in ("display_data", "execute_result"):
outputs.append(
{
"data": {"text/plain": [content["data"].get("text/plain", "")]},
"execution_count": content["execution_count"],
"metadata": {},
"output_type": msg_type,
}
)
if "application/vnd.jupyter.ywidget-view+json" in content["data"]:
# this is a collaborative widget
model_id = content["data"]["application/vnd.jupyter.ywidget-view+json"]["model_id"]
if self.yjs is not None:
if model_id in self.yjs.widgets.widgets:
doc = self.yjs.widgets.widgets[model_id]["model"].ydoc
path = f"ywidget:{doc.guid}"
await self.yjs.room_manager.websocket_server.get_room(path, ydoc=doc)
outputs.append(doc)
else:
outputs.append(
{
"data": {"text/plain": [content["data"].get("text/plain", "")]},
"execution_count": content["execution_count"],
"metadata": {},
"output_type": msg_type,
}
)
elif msg_type == "error":
outputs.append(
{
Expand All @@ -203,5 +244,25 @@ def _handle_outputs(self, outputs: List[Dict[str, Any]], msg: Dict[str, Any]):
"traceback": content["traceback"],
}
)
else:
return


class Comm:
def __init__(self, comm_id: str, shell_channel, session_id: str, key: str):
self.comm_id = comm_id
self.shell_channel = shell_channel
self.session_id = session_id
self.key = key
self.msg_cnt = 0

def send(self, buffers):
msg = create_message(
"comm_msg",
content={"comm_id": self.comm_id},
session_id=self.session_id,
msg_id=self.msg_cnt,
buffers=buffers,
)
self.msg_cnt += 1
asyncio.create_task(
send_message(msg, self.shell_channel, self.key, change_date_to_str=True)
)
3 changes: 2 additions & 1 deletion plugins/kernels/fps_kernels/kernel_driver/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ def create_message(
content: Dict = {},
session_id: str = "",
msg_id: str = "",
buffers: List = [],
) -> Dict[str, Any]:
header = create_message_header(msg_type, session_id, msg_id)
msg = {
Expand All @@ -65,7 +66,7 @@ def create_message(
"parent_header": {},
"content": content,
"metadata": {},
"buffers": [],
"buffers": buffers,
}
return msg

Expand Down
8 changes: 4 additions & 4 deletions plugins/kernels/fps_kernels/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,21 +259,21 @@ async def execute_cell(
execution = Execution(**r)
if kernel_id in kernels:
ynotebook = self.yjs.get_document(execution.document_id)
cell = ynotebook.get_cell(execution.cell_idx)
cell["outputs"] = []
ycell = ynotebook.ycells[execution.cell_idx]
del ycell["outputs"][:]

kernel = kernels[kernel_id]
if not kernel["driver"]:
kernel["driver"] = driver = KernelDriver(
kernelspec_path=Path(find_kernelspec(kernel["name"])).as_posix(),
write_connection_file=False,
connection_file=kernel["server"].connection_file_path,
yjs=self.yjs,
)
await driver.connect()
driver = kernel["driver"]

await driver.execute(cell)
ynotebook.set_cell(execution.cell_idx, cell)
await driver.execute(ycell, wait_for_executed=False)

async def get_kernel(
self,
Expand Down
4 changes: 2 additions & 2 deletions plugins/noauth/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ text = "BSD 3-Clause License"
Homepage = "https://jupyter.org"

[project.entry-points]
"asphalt.components" = {noauth = "fps_noauth.main:NoAuthComponent"}
"jupyverse.components" = {noauth = "fps_noauth.main:NoAuthComponent"}
"asphalt.components" = {auth = "fps_noauth.main:NoAuthComponent"}
"jupyverse.components" = {auth = "fps_noauth.main:NoAuthComponent"}

[tool.check-manifest]
ignore = [ ".*",]
Expand Down
9 changes: 6 additions & 3 deletions plugins/yjs/fps_yjs/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
WebSocketDisconnect,
status,
)
from pycrdt import Doc
from websockets.exceptions import ConnectionClosedOK

from jupyverse_api.app import App
Expand All @@ -27,6 +28,7 @@
from .ywebsocket.websocket_server import WebsocketServer, YRoom
from .ywebsocket.ystore import SQLiteYStore, YDocNotFound
from .ywebsocket.yutils import YMessageType, YSyncMessageType
from .ywidgets import Widgets

YFILE = YDOCS["file"]
AWARENESS = 1
Expand All @@ -48,6 +50,7 @@ def __init__(
super().__init__(app=app, auth=auth)
self.contents = contents
self.room_manager = RoomManager(contents)
self.widgets = Widgets()

async def collaboration_room_websocket(
self,
Expand Down Expand Up @@ -359,17 +362,17 @@ async def maybe_clean_room(self, room, ws_path: str) -> None:


class JupyterWebsocketServer(WebsocketServer):
async def get_room(self, ws_path: str) -> YRoom:
async def get_room(self, ws_path: str, ydoc: Doc | None = None) -> YRoom:
if ws_path not in self.rooms:
if ws_path.count(":") >= 2:
# it is a stored document (e.g. a notebook)
file_format, file_type, file_id = ws_path.split(":", 2)
updates_file_path = f".{file_type}:{file_id}.y"
ystore = JupyterSQLiteYStore(path=updates_file_path) # FIXME: pass in config
self.rooms[ws_path] = YRoom(ready=False, ystore=ystore)
self.rooms[ws_path] = YRoom(ydoc=ydoc, ready=False, ystore=ystore)
else:
# it is a transient document (e.g. awareness)
self.rooms[ws_path] = YRoom()
self.rooms[ws_path] = YRoom(ydoc=ydoc)
room = self.rooms[ws_path]
await self.start_room(room)
return room
5 changes: 3 additions & 2 deletions plugins/yjs/fps_yjs/ywebsocket/websocket_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from anyio import TASK_STATUS_IGNORED, Event, create_task_group
from anyio.abc import TaskGroup, TaskStatus
from pycrdt import Doc

from .websocket import Websocket
from .yroom import YRoom
Expand Down Expand Up @@ -57,7 +58,7 @@ def started(self) -> Event:
self._started = Event()
return self._started

async def get_room(self, name: str) -> YRoom:
async def get_room(self, name: str, ydoc: Doc | None = None) -> YRoom:
"""Get or create a room with the given name, and start it.
Arguments:
Expand All @@ -67,7 +68,7 @@ async def get_room(self, name: str) -> YRoom:
The room with the given name, or a new one if no room with that name was found.
"""
if name not in self.rooms.keys():
self.rooms[name] = YRoom(ready=self.rooms_ready, log=self.log)
self.rooms[name] = YRoom(ydoc=ydoc, ready=self.rooms_ready, log=self.log)
room = self.rooms[name]
await self.start_room(room)
return room
Expand Down
Loading

0 comments on commit c5dbc6f

Please sign in to comment.