diff --git a/.gitignore b/.gitignore index 70e4192c..d4751320 100644 --- a/.gitignore +++ b/.gitignore @@ -344,6 +344,3 @@ $RECYCLE.BIN/ .jupyter_ystore.db .jupyter_ystore.db-journal fps_cli_args.toml - -# pixi environments -.pixi diff --git a/jupyverse_api/jupyverse_api/__init__.py b/jupyverse_api/jupyverse_api/__init__.py index fe1d14ec..2f3b2a9b 100644 --- a/jupyverse_api/jupyverse_api/__init__.py +++ b/jupyverse_api/jupyverse_api/__init__.py @@ -1,5 +1,6 @@ -from typing import Dict +from typing import Any, Dict +from anyio import Event from pydantic import BaseModel from .app import App @@ -41,3 +42,37 @@ def mount(self, path: str, *args, **kwargs) -> None: def add_middleware(self, middleware, *args, **kwargs) -> None: self._app.add_middleware(middleware, *args, **kwargs) + + +class ResourceLock: + """ResourceLock ensures that accesses cannot be done concurrently on the same resource. + """ + _locks: Dict[Any, Event] + + def __init__(self): + self._locks = {} + + def __call__(self, idx: Any): + return _ResourceLock(idx, self._locks) + + +class _ResourceLock: + _idx: Any + _locks: Dict[Any, Event] + _lock: Event + + def __init__(self, idx: Any, locks: Dict[Any, Event]): + self._idx = idx + self._locks = locks + + async def __aenter__(self): + while True: + if self._idx in self._locks: + await self._locks[self._idx].wait() + else: + break + self._locks[self._idx] = self._lock = Event() + + async def __aexit__(self, exc_type, exc_value, exc_tb): + self._lock.set() + del self._locks[self._idx] diff --git a/jupyverse_api/jupyverse_api/cli.py b/jupyverse_api/jupyverse_api/cli.py index 20dd8a2f..bdb0ff03 100644 --- a/jupyverse_api/jupyverse_api/cli.py +++ b/jupyverse_api/jupyverse_api/cli.py @@ -2,7 +2,7 @@ from typing import List, Tuple import rich_click as click -from asphalt.core.cli import run +from asphalt.core._cli import run if sys.version_info < (3, 10): from importlib_metadata import entry_points @@ -66,8 +66,6 @@ def main( set_list.append(f"component.allow_origin={allow_origin}") config = get_config(disable) run.callback( - unsafe=False, - loop=None, set_=set_list, service=None, configfile=[config], diff --git a/jupyverse_api/jupyverse_api/contents/__init__.py b/jupyverse_api/jupyverse_api/contents/__init__.py index 4296e29b..4a9ccd90 100644 --- a/jupyverse_api/jupyverse_api/contents/__init__.py +++ b/jupyverse_api/jupyverse_api/contents/__init__.py @@ -1,11 +1,12 @@ -import asyncio +from __future__ import annotations + from abc import ABC, abstractmethod from pathlib import Path from typing import Dict, List, Optional, Union from fastapi import APIRouter, Depends, Request, Response -from jupyverse_api import Router +from jupyverse_api import ResourceLock, Router from ..app import App from ..auth import Auth, User @@ -13,8 +14,13 @@ class FileIdManager(ABC): - stop_watching_files: asyncio.Event - stopped_watching_files: asyncio.Event + @abstractmethod + async def start(self) -> None: + ... + + @abstractmethod + async def stop(self) -> None: + ... @abstractmethod async def get_path(self, file_id: str) -> str: @@ -32,9 +38,13 @@ def unwatch(self, path: str, watcher): class Contents(Router, ABC): + file_lock: ResourceLock + def __init__(self, app: App, auth: Auth): super().__init__(app=app) + self.file_lock = ResourceLock() + router = APIRouter() @router.post( diff --git a/jupyverse_api/jupyverse_api/main/__init__.py b/jupyverse_api/jupyverse_api/main/__init__.py index cfbaa3a3..7656c2b6 100644 --- a/jupyverse_api/jupyverse_api/main/__init__.py +++ b/jupyverse_api/jupyverse_api/main/__init__.py @@ -3,8 +3,9 @@ import webbrowser from typing import Any, Callable, Dict, Sequence, Tuple +from anyio import Event from asgiref.typing import ASGI3Application -from asphalt.core import Component, Context +from asphalt.core import Component, add_resource, get_resource, start_service_task from asphalt.web.fastapi import FastAPIComponent from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware @@ -22,14 +23,11 @@ def __init__( super().__init__() self.mount_path = mount_path - async def start( - self, - ctx: Context, - ) -> None: - app = await ctx.request_resource(FastAPI) + async def start(self) -> None: + app = await get_resource(FastAPI, wait=True) _app = App(app, mount_path=self.mount_path) - ctx.add_resource(_app) + add_resource(_app) class JupyverseComponent(FastAPIComponent): @@ -67,22 +65,27 @@ def __init__( self.port = port self.open_browser = open_browser self.query_params = query_params + self.lifespan = Lifespan() - async def start( - self, - ctx: Context, - ) -> None: + async def start(self) -> None: query_params = QueryParams(d={}) host = self.host if not host.startswith("http"): host = f"http://{host}" host_url = Host(url=f"{host}:{self.port}/") - ctx.add_resource(query_params) - ctx.add_resource(host_url) + add_resource(query_params) + add_resource(host_url) + add_resource(self.lifespan) - await super().start(ctx) + await super().start() # at this point, the server has started + await start_service_task( + self.lifespan.shutdown_request.wait, + "Server lifespan notifier", + teardown_action=self.lifespan.shutdown_request.set, + ) + if self.open_browser: qp = query_params.d if self.query_params: @@ -97,3 +100,8 @@ class QueryParams(BaseModel): class Host(BaseModel): url: str + + +class Lifespan: + def __init__(self): + self.shutdown_request = Event() diff --git a/jupyverse_api/pyproject.toml b/jupyverse_api/pyproject.toml index a8168883..ee1ea543 100644 --- a/jupyverse_api/pyproject.toml +++ b/jupyverse_api/pyproject.toml @@ -20,6 +20,7 @@ classifiers = [ "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", "Programming Language :: Python :: Implementation :: CPython", "Programming Language :: Python :: Implementation :: PyPy", ] @@ -28,8 +29,9 @@ dependencies = [ "pydantic >=2,<3", "fastapi >=0.95.0,<1", "rich-click >=1.6.1,<2", - "asphalt >=4.11.0,<5", - "asphalt-web[fastapi] >=1.1.0,<2", + "importlib_metadata >=3.6; python_version<'3.10'", + #"asphalt >=4.11.0,<5", + #"asphalt-web[fastapi] >=1.1.0,<2", ] dynamic = ["version"] diff --git a/jupyverse_api/tests/test_resource_lock.py b/jupyverse_api/tests/test_resource_lock.py new file mode 100644 index 00000000..ff0da450 --- /dev/null +++ b/jupyverse_api/tests/test_resource_lock.py @@ -0,0 +1,57 @@ +import pytest +from anyio import create_task_group, sleep + +from jupyverse_api import ResourceLock + +pytestmark = pytest.mark.anyio + + +async def do_op(operation, resource_lock, operations): + op, path = operation + async with resource_lock(path): + operations.append(operation + ["start"]) + await sleep(0.1) + operations.append(operation + ["done"]) + + +async def test_resource_lock(): + resource_lock = ResourceLock() + + # test concurrent accesses to the same resource + idx = "idx" + operations = [] + async with create_task_group() as tg: + tg.start_soon(do_op, [0, idx], resource_lock, operations) + await sleep(0.01) + tg.start_soon(do_op, [1, idx], resource_lock, operations) + + assert operations == [ + [0, idx, "start"], + [0, idx, "done"], + [1, idx, "start"], + [1, idx, "done"], + ] + + # test concurrent accesses to different files + idx0 = "idx0" + idx1 = "idx1" + operations = [] + async with create_task_group() as tg: + tg.start_soon(do_op, [0, idx0], resource_lock, operations) + await sleep(0.01) + tg.start_soon(do_op, [1, idx1], resource_lock, operations) + await sleep(0.01) + tg.start_soon(do_op, [2, idx0], resource_lock, operations) + await sleep(0.01) + tg.start_soon(do_op, [3, idx1], resource_lock, operations) + + assert operations == [ + [0, idx0, "start"], + [1, idx1, "start"], + [0, idx0, "done"], + [2, idx0, "start"], + [1, idx1, "done"], + [3, idx1, "start"], + [2, idx0, "done"], + [3, idx1, "done"], + ] diff --git a/plugins/auth/fps_auth/main.py b/plugins/auth/fps_auth/main.py index 8670a084..df086639 100644 --- a/plugins/auth/fps_auth/main.py +++ b/plugins/auth/fps_auth/main.py @@ -1,6 +1,6 @@ import logging -from asphalt.core import Component, Context +from asphalt.core import Component, add_resource, get_resource from fastapi_users.exceptions import UserAlreadyExists from jupyverse_api.app import App @@ -18,17 +18,14 @@ class AuthComponent(Component): def __init__(self, **kwargs): self.auth_config = _AuthConfig(**kwargs) - async def start( - self, - ctx: Context, - ) -> None: - ctx.add_resource(self.auth_config, types=AuthConfig) + async def start(self) -> None: + add_resource(self.auth_config, types=AuthConfig) - app = await ctx.request_resource(App) - frontend_config = await ctx.request_resource(FrontendConfig) + app = await get_resource(App, wait=True) + frontend_config = await get_resource(FrontendConfig, wait=True) auth = auth_factory(app, self.auth_config, frontend_config) - ctx.add_resource(auth, types=Auth) + add_resource(auth, types=Auth) await auth.db.create_db_and_tables() @@ -59,8 +56,8 @@ async def start( ) if self.auth_config.mode == "token": - query_params = await ctx.request_resource(QueryParams) - host = await ctx.request_resource(Host) + query_params = await get_resource(QueryParams, wait=True) + host = await get_resource(Host, wait=True) query_params.d["token"] = self.auth_config.token logger.info("") diff --git a/plugins/auth_fief/fps_auth_fief/main.py b/plugins/auth_fief/fps_auth_fief/main.py index ddc8224f..04632df9 100644 --- a/plugins/auth_fief/fps_auth_fief/main.py +++ b/plugins/auth_fief/fps_auth_fief/main.py @@ -1,4 +1,4 @@ -from asphalt.core import Component, Context +from asphalt.core import Component, add_resource, get_resource from jupyverse_api.app import App from jupyverse_api.auth import Auth, AuthConfig @@ -11,13 +11,10 @@ class AuthFiefComponent(Component): def __init__(self, **kwargs): self.auth_fief_config = _AuthFiefConfig(**kwargs) - async def start( - self, - ctx: Context, - ) -> None: - ctx.add_resource(self.auth_fief_config, types=AuthConfig) + async def start(self) -> None: + add_resource(self.auth_fief_config, types=AuthConfig) - app = await ctx.request_resource(App) + app = await get_resource(App, wait=True) auth_fief = auth_factory(app, self.auth_fief_config) - ctx.add_resource(auth_fief, types=Auth) + add_resource(auth_fief, types=Auth) diff --git a/plugins/auth_jupyterhub/fps_auth_jupyterhub/main.py b/plugins/auth_jupyterhub/fps_auth_jupyterhub/main.py index 8b9a8b71..27f9e120 100644 --- a/plugins/auth_jupyterhub/fps_auth_jupyterhub/main.py +++ b/plugins/auth_jupyterhub/fps_auth_jupyterhub/main.py @@ -1,5 +1,10 @@ -import httpx -from asphalt.core import Component, ContainerComponent, Context, context_teardown +from asphalt.core import ( + Component, + ContainerComponent, + add_resource, + get_resource, + start_service_task, +) from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession from jupyverse_api.app import App @@ -11,40 +16,29 @@ class _AuthJupyterHubComponent(Component): - @context_teardown - async def start( - self, - ctx: Context, - ) -> None: - app = await ctx.request_resource(App) - db_session = await ctx.request_resource(AsyncSession) - db_engine = await ctx.request_resource(AsyncEngine) - - http_client = httpx.AsyncClient() - auth_jupyterhub = auth_factory(app, db_session, http_client) - ctx.add_resource(auth_jupyterhub, types=Auth) + async def start(self) -> None: + app = await get_resource(App, wait=True) + db_session = await get_resource(AsyncSession, wait=True) + db_engine = await get_resource(AsyncEngine, wait=True) + + auth_jupyterhub = auth_factory(app, db_session) + await start_service_task(auth_jupyterhub.start, "JupyterHub Auth", auth_jupyterhub.stop) + add_resource(auth_jupyterhub, types=Auth) async with db_engine.begin() as conn: await conn.run_sync(Base.metadata.create_all) - yield - - await http_client.aclose() - class AuthJupyterHubComponent(ContainerComponent): def __init__(self, **kwargs): self.auth_jupyterhub_config = AuthJupyterHubConfig(**kwargs) super().__init__() - async def start( - self, - ctx: Context, - ) -> None: - ctx.add_resource(self.auth_jupyterhub_config, types=AuthConfig) + async def start(self) -> None: + add_resource(self.auth_jupyterhub_config, types=AuthConfig) self.add_component( "sqlalchemy", url=self.auth_jupyterhub_config.db_url, ) self.add_component("auth_jupyterhub", type=_AuthJupyterHubComponent) - await super().start(ctx) + await super().start() diff --git a/plugins/auth_jupyterhub/fps_auth_jupyterhub/routes.py b/plugins/auth_jupyterhub/fps_auth_jupyterhub/routes.py index e0cb72e6..b132b2c2 100644 --- a/plugins/auth_jupyterhub/fps_auth_jupyterhub/routes.py +++ b/plugins/auth_jupyterhub/fps_auth_jupyterhub/routes.py @@ -1,14 +1,16 @@ from __future__ import annotations -import asyncio import json import os from datetime import datetime +from functools import partial from typing import Any, Callable, Dict, List, Optional, Tuple, Union -import httpx +from anyio import TASK_STATUS_IGNORED, Lock, create_task_group +from anyio.abc import TaskStatus from fastapi import APIRouter, Cookie, Depends, HTTPException, Request, WebSocket, status from fastapi.responses import RedirectResponse +from httpx import AsyncClient from jupyterhub.services.auth import HubOAuth from jupyterhub.utils import isoformat from sqlalchemy.ext.asyncio import AsyncSession @@ -26,16 +28,15 @@ def auth_factory( app: App, db_session: AsyncSession, - http_client: httpx.AsyncClient, ): class AuthJupyterHub(Auth, Router): def __init__(self) -> None: super().__init__(app) self.hub_auth = HubOAuth() - self.db_lock = asyncio.Lock() + self.db_lock = Lock() self.activity_url = os.environ.get("JUPYTERHUB_ACTIVITY_URL") self.server_name = os.environ.get("JUPYTERHUB_SERVER_NAME") - self.background_tasks = set() + self.http_client = AsyncClient() router = APIRouter() @@ -123,8 +124,9 @@ async def _( "Content-Type": "application/json", } last_activity = isoformat(datetime.utcnow()) - task = asyncio.create_task( - http_client.post( + self.task_group.start_soon( + partial( + self.http_client.post, self.activity_url, headers=headers, json={ @@ -132,8 +134,6 @@ async def _( }, ) ) - self.background_tasks.add(task) - task.add_done_callback(self.background_tasks.discard) return user if permissions: @@ -193,4 +193,13 @@ async def _( return _ + async def start(self, *, task_status: TaskStatus[None] = TASK_STATUS_IGNORED) -> None: + async with create_task_group() as tg: + self.task_group = tg + task_status.started() + + async def stop(self) -> None: + await self.http_client.aclose() + self.task_group.cancel_scope().cancel() + return AuthJupyterHub() diff --git a/plugins/auth_jupyterhub/pyproject.toml b/plugins/auth_jupyterhub/pyproject.toml index 62b5037b..b4121812 100644 --- a/plugins/auth_jupyterhub/pyproject.toml +++ b/plugins/auth_jupyterhub/pyproject.toml @@ -13,6 +13,7 @@ dependencies = [ "httpx >=0.24.1,<1", "jupyterhub >=4.0.1,<5", "jupyverse-api >=0.1.2,<1", + "anyio", ] [[project.authors]] diff --git a/plugins/contents/fps_contents/fileid.py b/plugins/contents/fps_contents/fileid.py index f489c59d..6d6a78fc 100644 --- a/plugins/contents/fps_contents/fileid.py +++ b/plugins/contents/fps_contents/fileid.py @@ -1,28 +1,28 @@ -import asyncio +from __future__ import annotations + import logging +import sqlite3 from typing import Dict, List, Optional, Set from uuid import uuid4 -import aiosqlite -from anyio import Path +from anyio import Event, Lock, Path +from sqlite_anyio import connect from watchfiles import Change, awatch -from jupyverse_api import Singleton - logger = logging.getLogger("contents") class Watcher: def __init__(self, path: str) -> None: self.path = path - self._event = asyncio.Event() + self._event = Event() def __aiter__(self): return self async def __anext__(self): await self._event.wait() - self._event.clear() + self._event = Event() return self._change def notify(self, change): @@ -30,128 +30,133 @@ def notify(self, change): self._event.set() -class FileIdManager(metaclass=Singleton): +class FileIdManager: db_path: str - initialized: asyncio.Event + initialized: Event watchers: Dict[str, List[Watcher]] - lock: asyncio.Lock + lock: Lock def __init__(self, db_path: str = ".fileid.db"): self.db_path = db_path - self.initialized = asyncio.Event() + self.initialized = Event() self.watchers = {} - self.watch_files_task = asyncio.create_task(self.watch_files()) - self.stop_watching_files = asyncio.Event() - self.stopped_watching_files = asyncio.Event() - self.lock = asyncio.Lock() + self.stop_event = Event() + self.lock = Lock() + + async def start(self) -> None: + self._db = await connect(self.db_path) + try: + await self.watch_files() + except sqlite3.ProgrammingError: + pass + + async def stop(self) -> None: + await self._db.close() + self.stop_event.set() async def get_id(self, path: str) -> Optional[str]: await self.initialized.wait() async with self.lock: - async with aiosqlite.connect(self.db_path) as db: - async with db.execute("SELECT id FROM fileids WHERE path = ?", (path,)) as cursor: - async for (idx,) in cursor: - return idx - return None + cursor = await self._db.cursor() + await cursor.execute("SELECT id FROM fileids WHERE path = ?", (path,)) + for (idx,) in await cursor.fetchall(): + return idx + return None async def get_path(self, idx: str) -> Optional[str]: await self.initialized.wait() async with self.lock: - async with aiosqlite.connect(self.db_path) as db: - async with db.execute("SELECT path FROM fileids WHERE id = ?", (idx,)) as cursor: - async for (path,) in cursor: - return path - return None + cursor = await self._db.cursor() + await cursor.execute("SELECT path FROM fileids WHERE id = ?", (idx,)) + for (path,) in await cursor.fetchall(): + return path + return None async def index(self, path: str) -> Optional[str]: await self.initialized.wait() async with self.lock: - async with aiosqlite.connect(self.db_path) as db: - apath = Path(path) - if not await apath.exists(): - return None + apath = Path(path) + if not await apath.exists(): + return None - idx = uuid4().hex - mtime = (await apath.stat()).st_mtime - await db.execute("INSERT INTO fileids VALUES (?, ?, ?)", (idx, path, mtime)) - await db.commit() - return idx + idx = uuid4().hex + mtime = (await apath.stat()).st_mtime + cursor = await self._db.cursor() + await cursor.execute("INSERT INTO fileids VALUES (?, ?, ?)", (idx, path, mtime)) + await self._db.commit() + return idx async def watch_files(self): async with self.lock: - async with aiosqlite.connect(self.db_path) as db: - await db.execute("DROP TABLE IF EXISTS fileids") - await db.execute( - "CREATE TABLE fileids " - "(id TEXT PRIMARY KEY, path TEXT NOT NULL UNIQUE, mtime REAL NOT NULL)" - ) - await db.commit() + cursor = await self._db.cursor() + await cursor.execute("DROP TABLE IF EXISTS fileids") + await cursor.execute( + "CREATE TABLE fileids " + "(id TEXT PRIMARY KEY, path TEXT NOT NULL UNIQUE, mtime REAL NOT NULL)" + ) + await self._db.commit() # index files async with self.lock: - async with aiosqlite.connect(self.db_path) as db: - async for path in Path().rglob("*"): - idx = uuid4().hex - mtime = (await path.stat()).st_mtime - await db.execute( - "INSERT INTO fileids VALUES (?, ?, ?)", (idx, str(path), mtime) - ) - await db.commit() - self.initialized.set() - - async for changes in awatch(".", stop_event=self.stop_watching_files): + cursor = await self._db.cursor() + async for path in Path().rglob("*"): + idx = uuid4().hex + mtime = (await path.stat()).st_mtime + await cursor.execute( + "INSERT INTO fileids VALUES (?, ?, ?)", (idx, str(path), mtime) + ) + await self._db.commit() + self.initialized.set() + + async for changes in awatch(".", stop_event=self.stop_event): async with self.lock: - async with aiosqlite.connect(self.db_path) as db: - deleted_paths = set() - added_paths = set() - for change, changed_path in changes: - # get relative path - changed_path = Path(changed_path).relative_to(await Path().absolute()) - changed_path_str = str(changed_path) - - if change == Change.deleted: - logger.debug("File %s was deleted", changed_path_str) - async with db.execute( - "SELECT COUNT(*) FROM fileids WHERE path = ?", (changed_path_str,) - ) as cursor: - if not (await cursor.fetchone())[0]: - # path is not indexed, ignore - logger.debug( - "File %s is not indexed, ignoring", changed_path_str - ) - continue - # path is indexed - await maybe_rename( - db, changed_path_str, deleted_paths, added_paths, False - ) - elif change == Change.added: - logger.debug("File %s was added", changed_path_str) - await maybe_rename( - db, changed_path_str, added_paths, deleted_paths, True - ) - elif change == Change.modified: - logger.debug("File %s was modified", changed_path_str) - if changed_path_str == self.db_path: - continue - async with db.execute( - "SELECT COUNT(*) FROM fileids WHERE path = ?", (changed_path_str,) - ) as cursor: - if not (await cursor.fetchone())[0]: - # path is not indexed, ignore - logger.debug( - "File %s is not indexed, ignoring", changed_path_str - ) - continue - mtime = (await changed_path.stat()).st_mtime - await db.execute( - "UPDATE fileids SET mtime = ? WHERE path = ?", - (mtime, changed_path_str), - ) - - for path in deleted_paths - added_paths: - logger.debug("Unindexing file %s ", path) - await db.execute("DELETE FROM fileids WHERE path = ?", (path,)) - await db.commit() + deleted_paths = set() + added_paths = set() + cursor = await self._db.cursor() + for change, changed_path in changes: + # get relative path + changed_path = Path(changed_path).relative_to(await Path().absolute()) + changed_path_str = str(changed_path) + + if change == Change.deleted: + logger.debug("File %s was deleted", changed_path_str) + await cursor.execute( + "SELECT COUNT(*) FROM fileids WHERE path = ?", (changed_path_str,) + ) + if not (await cursor.fetchone())[0]: + # path is not indexed, ignore + logger.debug("File %s is not indexed, ignoring", changed_path_str) + continue + # path is indexed + await maybe_rename( + self._db, changed_path_str, deleted_paths, added_paths, False + ) + elif change == Change.added: + logger.debug("File %s was added", changed_path_str) + await maybe_rename( + self._db, changed_path_str, added_paths, deleted_paths, True + ) + elif change == Change.modified: + logger.debug("File %s was modified", changed_path_str) + if changed_path_str == self.db_path: + continue + await cursor.execute( + "SELECT COUNT(*) FROM fileids WHERE path = ?", (changed_path_str,) + ) + if not (await cursor.fetchone())[0]: + # path is not indexed, ignore + logger.debug("File %s is not indexed, ignoring", changed_path_str) + continue + mtime = (await changed_path.stat()).st_mtime + await cursor.execute( + "UPDATE fileids SET mtime = ? WHERE path = ?", + (mtime, changed_path_str), + ) + + for path in deleted_paths - added_paths: + logger.debug("Unindexing file %s ", path) + await cursor.execute("DELETE FROM fileids WHERE path = ?", (path,)) + await self._db.commit() for change in changes: changed_path = change[1] @@ -161,8 +166,6 @@ async def watch_files(self): for watcher in self.watchers.get(relative_changed_path, []): watcher.notify(relative_change) - self.stopped_watching_files.set() - def watch(self, path: str) -> Watcher: watcher = Watcher(path) self.watchers.setdefault(path, []).append(watcher) @@ -174,11 +177,12 @@ def unwatch(self, path: str, watcher: Watcher): async def get_mtime(path, db) -> Optional[float]: if db: - async with db.execute("SELECT mtime FROM fileids WHERE path = ?", (path,)) as cursor: - async for (mtime,) in cursor: - return mtime - # deleted file is not in database, shouldn't happen - return None + cursor = await db.cursor() + await cursor.execute("SELECT mtime FROM fileids WHERE path = ?", (path,)) + for (mtime,) in await cursor.fetchall(): + return mtime + # deleted file is not in database, shouldn't happen + return None try: mtime = (await Path(path).stat()).st_mtime except FileNotFoundError: @@ -204,7 +208,8 @@ async def maybe_rename( if is_added_path: path1, path2 = path2, path1 logger.debug("File %s was renamed to %s", path1, path2) - await db.execute("UPDATE fileids SET path = ? WHERE path = ?", (path2, path1)) + cursor = await db.cursor() + await cursor.execute("UPDATE fileids SET path = ? WHERE path = ?", (path2, path1)) other_paths.remove(other_path) return changed_paths.add(changed_path) diff --git a/plugins/contents/fps_contents/main.py b/plugins/contents/fps_contents/main.py index 81f6fd18..8582ce89 100644 --- a/plugins/contents/fps_contents/main.py +++ b/plugins/contents/fps_contents/main.py @@ -1,4 +1,4 @@ -from asphalt.core import Component, Context +from asphalt.core import Component, add_resource, get_resource from jupyverse_api.app import App from jupyverse_api.auth import Auth @@ -8,12 +8,9 @@ class ContentsComponent(Component): - async def start( - self, - ctx: Context, - ) -> None: - app = await ctx.request_resource(App) - auth = await ctx.request_resource(Auth) # type: ignore + async def start(self) -> None: + app = await get_resource(App, wait=True) + auth = await get_resource(Auth, wait=True) # type: ignore[type-abstract] contents = _Contents(app, auth) - ctx.add_resource(contents, types=Contents) + add_resource(contents, types=Contents) diff --git a/plugins/contents/fps_contents/routes.py b/plugins/contents/fps_contents/routes.py index 7f759b9e..c39f1e04 100644 --- a/plugins/contents/fps_contents/routes.py +++ b/plugins/contents/fps_contents/routes.py @@ -1,13 +1,15 @@ +from __future__ import annotations + import base64 import json import os import shutil -from datetime import datetime +from datetime import datetime, timezone from http import HTTPStatus from pathlib import Path from typing import Dict, List, Optional, Union, cast -from anyio import open_file +from anyio import CancelScope, open_file from fastapi import HTTPException, Response from starlette.requests import Request @@ -25,6 +27,8 @@ class _Contents(Contents): + _file_id_manager: FileIdManager | None = None + async def create_checkpoint( self, path, @@ -143,109 +147,115 @@ async def read_content( ) -> Content: if isinstance(path, str): path = Path(path) - content: Optional[Union[str, Dict, List[Dict]]] = None - if get_content: + async with self.file_lock(path): + content: Optional[Union[str, Dict, List[Dict]]] = None + if get_content: + if path.is_dir(): + content = [ + (await self.read_content(subpath, get_content=False)).model_dump() + for subpath in path.iterdir() + if not subpath.name.startswith(".") + ] + elif path.is_file() or path.is_symlink(): + try: + async with await open_file(path, mode="rb") as f: + content_bytes = await f.read() + if file_format == "base64": + content = base64.b64encode(content_bytes).decode("ascii") + elif file_format == "json": + content = json.loads(content_bytes) + else: + content = content_bytes.decode() + except Exception: + raise HTTPException(status_code=404, detail="Item not found") + format: Optional[str] = None if path.is_dir(): - content = [ - (await self.read_content(subpath, get_content=False)).model_dump() - for subpath in path.iterdir() - if not subpath.name.startswith(".") - ] - elif path.is_file() or path.is_symlink(): - try: - async with await open_file(path, mode="rb") as f: - content_bytes = await f.read() - if file_format == "base64": - content = base64.b64encode(content_bytes).decode("ascii") - elif file_format == "json": - content = json.loads(content_bytes) - else: - content = content_bytes.decode() - except Exception: - raise HTTPException(status_code=404, detail="Item not found") - format: Optional[str] = None - if path.is_dir(): - size = None - type = "directory" - format = "json" - mimetype = None - elif path.is_file() or path.is_symlink(): - size = get_file_size(path) - if path.suffix == ".ipynb": - type = "notebook" - format = None + size = None + type = "directory" + format = "json" mimetype = None - if content is not None: - nb: dict - if file_format == "json": - content = cast(Dict, content) - nb = content - else: - content = cast(str, content) - nb = json.loads(content) - for cell in nb["cells"]: - if "metadata" not in cell: - cell["metadata"] = {} - cell["metadata"].update({"trusted": False}) - if cell["cell_type"] == "code": - cell_source = cell["source"] - if not isinstance(cell_source, str): - cell["source"] = "".join(cell_source) - if file_format != "json": - content = json.dumps(nb) - elif path.suffix == ".json": - type = "json" - format = "text" - mimetype = "application/json" + elif path.is_file() or path.is_symlink(): + size = get_file_size(path) + if path.suffix == ".ipynb": + type = "notebook" + format = None + mimetype = None + if content is not None: + nb: dict + if file_format == "json": + content = cast(Dict, content) + nb = content + else: + content = cast(str, content) + nb = json.loads(content) + for cell in nb["cells"]: + if "metadata" not in cell: + cell["metadata"] = {} + cell["metadata"].update({"trusted": False}) + if cell["cell_type"] == "code": + cell_source = cell["source"] + if not isinstance(cell_source, str): + cell["source"] = "".join(cell_source) + if file_format != "json": + content = json.dumps(nb) + elif path.suffix == ".json": + type = "json" + format = "text" + mimetype = "application/json" + else: + type = "file" + format = None + mimetype = "text/plain" else: - type = "file" - format = None - mimetype = "text/plain" - else: - raise HTTPException(status_code=404, detail="Item not found") + raise HTTPException(status_code=404, detail="Item not found") - return Content( - **{ - "name": path.name, - "path": path.as_posix(), - "last_modified": get_file_modification_time(path), - "created": get_file_creation_time(path), - "content": content, - "format": format, - "mimetype": mimetype, - "size": size, - "writable": is_file_writable(path), - "type": type, - } - ) + return Content( + **{ + "name": path.name, + "path": path.as_posix(), + "last_modified": get_file_modification_time(path), + "created": get_file_creation_time(path), + "content": content, + "format": format, + "mimetype": mimetype, + "size": size, + "writable": is_file_writable(path), + "type": type, + } + ) async def write_content(self, content: Union[SaveContent, Dict]) -> None: - if not isinstance(content, SaveContent): - content = SaveContent(**content) - if content.format == "base64": - async with await open_file(content.path, "wb") as f: - content.content = cast(str, content.content) - content_bytes = content.content.encode("ascii") - await f.write(content_bytes) - else: - async with await open_file(content.path, "wt") as f: - if content.format == "json": - dict_content = cast(Dict, content.content) - if content.type == "notebook": - # see https://github.com/jupyterlab/jupyterlab/issues/11005 - if ( - "metadata" in dict_content - and "orig_nbformat" in dict_content["metadata"] - ): - del dict_content["metadata"]["orig_nbformat"] - await f.write(json.dumps(dict_content, indent=2)) + # writing can never be cancelled, otherwise it would corrupt the file + with CancelScope(shield=True): + if not isinstance(content, SaveContent): + content = SaveContent(**content) + async with self.file_lock(Path(content.path)): + if content.format == "base64": + async with await open_file(content.path, "wb") as f: + content.content = cast(str, content.content) + content_bytes = content.content.encode("ascii") + await f.write(content_bytes) else: - content.content = cast(str, content.content) - await f.write(content.content) + async with await open_file(content.path, "wt") as f: + if content.format == "json": + dict_content = cast(Dict, content.content) + if content.type == "notebook": + # see https://github.com/jupyterlab/jupyterlab/issues/11005 + if ( + "metadata" in dict_content + and "orig_nbformat" in dict_content["metadata"] + ): + del dict_content["metadata"]["orig_nbformat"] + await f.write(json.dumps(dict_content, indent=2)) + else: + content.content = cast(str, content.content) + await f.write(content.content) @property def file_id_manager(self): - return FileIdManager() + if self._file_id_manager is None: + self._file_id_manager = FileIdManager() + return self._file_id_manager def get_available_path(path: Path, sep: str = "") -> Path: @@ -268,12 +278,16 @@ def get_available_path(path: Path, sep: str = "") -> Path: def get_file_modification_time(path: Path): if path.exists(): - return datetime.utcfromtimestamp(path.stat().st_mtime).isoformat() + "Z" + return datetime.fromtimestamp(path.stat().st_mtime, tz=timezone.utc).isoformat().replace( + "+00:00", "Z" + ) def get_file_creation_time(path: Path): if path.exists(): - return datetime.utcfromtimestamp(path.stat().st_ctime).isoformat() + "Z" + return datetime.fromtimestamp(path.stat().st_ctime, tz=timezone.utc).isoformat().replace( + "+00:00", "Z" + ) def get_file_size(path: Path) -> Optional[int]: diff --git a/plugins/contents/pyproject.toml b/plugins/contents/pyproject.toml index 22ac019b..111ed3fa 100644 --- a/plugins/contents/pyproject.toml +++ b/plugins/contents/pyproject.toml @@ -9,7 +9,7 @@ keywords = ["jupyter", "server", "fastapi", "plugins"] requires-python = ">=3.8" dependencies = [ "watchfiles >=0.18.1,<1", - "aiosqlite >=0.17.0,<1", + "sqlite-anyio >=0.2.0,<0.3.0", "anyio>=3.6.2,<5", "jupyverse-api >=0.1.2,<1", ] diff --git a/plugins/frontend/fps_frontend/main.py b/plugins/frontend/fps_frontend/main.py index b34b488b..1b502de3 100644 --- a/plugins/frontend/fps_frontend/main.py +++ b/plugins/frontend/fps_frontend/main.py @@ -1,4 +1,4 @@ -from asphalt.core import Component, Context +from asphalt.core import Component, add_resource from jupyverse_api.frontend import FrontendConfig @@ -7,8 +7,5 @@ class FrontendComponent(Component): def __init__(self, **kwargs): self.frontend_config = FrontendConfig(**kwargs) - async def start( - self, - ctx: Context, - ) -> None: - ctx.add_resource(self.frontend_config, types=FrontendConfig) + async def start(self) -> None: + add_resource(self.frontend_config, types=FrontendConfig) diff --git a/plugins/jupyterlab/fps_jupyterlab/main.py b/plugins/jupyterlab/fps_jupyterlab/main.py index 2cd31dd9..31865b8b 100644 --- a/plugins/jupyterlab/fps_jupyterlab/main.py +++ b/plugins/jupyterlab/fps_jupyterlab/main.py @@ -1,4 +1,4 @@ -from asphalt.core import Component, Context +from asphalt.core import Component, add_resource, get_resource from jupyverse_api.app import App from jupyverse_api.auth import Auth @@ -13,16 +13,13 @@ class JupyterLabComponent(Component): def __init__(self, **kwargs): self.jupyterlab_config = JupyterLabConfig(**kwargs) - async def start( - self, - ctx: Context, - ) -> None: - ctx.add_resource(self.jupyterlab_config, types=JupyterLabConfig) + async def start(self) -> None: + add_resource(self.jupyterlab_config, types=JupyterLabConfig) - app = await ctx.request_resource(App) - auth = await ctx.request_resource(Auth) # type: ignore - frontend_config = await ctx.request_resource(FrontendConfig) - lab = await ctx.request_resource(Lab) # type: ignore + app = await get_resource(App, wait=True) + auth = await get_resource(Auth, wait=True) # type: ignore[type-abstract] + frontend_config = await get_resource(FrontendConfig, wait=True) + lab = await get_resource(Lab, wait=True) # type: ignore[type-abstract] jupyterlab = _JupyterLab(app, self.jupyterlab_config, auth, frontend_config, lab) - ctx.add_resource(jupyterlab, types=JupyterLab) + add_resource(jupyterlab, types=JupyterLab) diff --git a/plugins/kernels/fps_kernels/kernel_driver/connect.py b/plugins/kernels/fps_kernels/kernel_driver/connect.py index 8c177a89..0eb4aaaa 100644 --- a/plugins/kernels/fps_kernels/kernel_driver/connect.py +++ b/plugins/kernels/fps_kernels/kernel_driver/connect.py @@ -1,14 +1,17 @@ -import asyncio +from __future__ import annotations + import json import os import socket +import subprocess import sys import tempfile import uuid from typing import Dict, Optional, Tuple, Union import zmq -import zmq.asyncio +from anyio import open_process +from anyio.abc import Process from zmq.asyncio import Socket channel_socket_types = { @@ -71,8 +74,8 @@ def read_connection_file(fname: str) -> cfg_t: async def launch_kernel( - kernelspec_path: str, connection_file_path: str, kernel_cwd: str, capture_output: bool -) -> asyncio.subprocess.Process: + kernelspec_path: str, connection_file_path: str, kernel_cwd: str | None, capture_output: bool +) -> Process: with open(kernelspec_path) as f: kernelspec = json.load(f) cmd = [s.format(connection_file=connection_file_path) for s in kernelspec["argv"]] @@ -82,18 +85,16 @@ async def launch_kernel( "python%i.%i" % sys.version_info[:2], }: cmd[0] = sys.executable - if kernel_cwd: - prev_dir = os.getcwd() - os.chdir(kernel_cwd) if capture_output: - p = await asyncio.create_subprocess_exec( - *cmd, stdout=asyncio.subprocess.DEVNULL, stderr=asyncio.subprocess.STDOUT - ) + stdout = subprocess.DEVNULL + stderr = subprocess.STDOUT else: - p = await asyncio.create_subprocess_exec(*cmd) - if kernel_cwd: - os.chdir(prev_dir) - return p + stdout = None + stderr = None + if not kernel_cwd: + kernel_cwd = None + process = await open_process(cmd, stdout=stdout, stderr=stderr, cwd=kernel_cwd) + return process def create_socket(channel: str, cfg: cfg_t, identity: Optional[bytes] = None) -> Socket: diff --git a/plugins/kernels/fps_kernels/kernel_driver/driver.py b/plugins/kernels/fps_kernels/kernel_driver/driver.py index d1761b77..b1242339 100644 --- a/plugins/kernels/fps_kernels/kernel_driver/driver.py +++ b/plugins/kernels/fps_kernels/kernel_driver/driver.py @@ -1,9 +1,18 @@ -import asyncio import os import time import uuid -from typing import Any, Dict, List, Optional, cast +from typing import Any, Dict, Optional, cast +import anyio +from anyio import ( + TASK_STATUS_IGNORED, + Event, + create_memory_object_stream, + create_task_group, + fail_after, +) +from anyio.abc import TaskGroup, TaskStatus +from anyio.streams.stapled import StapledObjectStream from pycrdt import Array, Map from jupyverse_api.yjs import Yjs @@ -19,6 +28,8 @@ def deadline_to_timeout(deadline: float) -> float: class KernelDriver: + task_group: TaskGroup + def __init__( self, kernel_name: str = "", @@ -29,6 +40,7 @@ def __init__( capture_kernel_output: bool = True, yjs: Optional[Yjs] = None, ) -> None: + self.write_connection_file = write_connection_file self.capture_kernel_output = capture_kernel_output self.kernelspec_path = kernelspec_path or find_kernelspec(kernel_name) self.kernel_cwd = kernel_cwd @@ -43,40 +55,56 @@ def __init__( self.key = cast(str, self.connection_cfg["key"]) self.session_id = uuid.uuid4().hex self.msg_cnt = 0 - self.execute_requests: Dict[str, Dict[str, asyncio.Queue]] = {} - self.comm_messages: asyncio.Queue = asyncio.Queue() - self.tasks: List[asyncio.Task] = [] + self.execute_requests: Dict[str, Dict[str, StapledObjectStream]] = {} + self.comm_messages: StapledObjectStream = StapledObjectStream( + *create_memory_object_stream[dict](max_buffer_size=1024) + ) + self.stopped_event = Event() async def restart(self, startup_timeout: float = float("inf")) -> None: - for task in self.tasks: - task.cancel() - msg = create_message("shutdown_request", content={"restart": True}) - await send_message(msg, self.control_channel, self.key, change_date_to_str=True) - while True: - msg = cast( - Dict[str, Any], await receive_message(self.control_channel, change_str_to_date=True) - ) - if msg["msg_type"] == "shutdown_reply" and msg["content"]["restart"]: - break - await self._wait_for_ready(startup_timeout) - self.tasks = [] - self.listen_channels() + self.task_group.cancel_scope.cancel() + await self.stopped_event.wait() + self.stopped_event = Event() + async with create_task_group() as tg: + self.task_group = tg + msg = create_message("shutdown_request", content={"restart": True}) + await send_message(msg, self.control_channel, self.key, change_date_to_str=True) + while True: + msg = cast( + Dict[str, Any], + await receive_message(self.control_channel, change_str_to_date=True), + ) + if msg["msg_type"] == "shutdown_reply" and msg["content"]["restart"]: + break + await self._wait_for_ready(startup_timeout) + self.listen_channels() + tg.start_soon(self._handle_comms) - async def start(self, startup_timeout: float = float("inf"), connect: bool = True) -> None: - self.kernel_process = await launch_kernel( - self.kernelspec_path, - self.connection_file_path, - self.kernel_cwd, - self.capture_kernel_output, - ) - if connect: - await self.connect(startup_timeout) + async def start( + self, + startup_timeout: float = float("inf"), + connect: bool = True, + *, + task_status: TaskStatus[None] = TASK_STATUS_IGNORED, + ) -> None: + async with create_task_group() as tg: + self.task_group = tg + self.kernel_process = await launch_kernel( + self.kernelspec_path, + self.connection_file_path, + self.kernel_cwd, + self.capture_kernel_output, + ) + if connect: + await self.connect() + task_status.started() + self.stopped_event.set() async def connect(self, startup_timeout: float = float("inf")) -> None: self.connect_channels() await self._wait_for_ready(startup_timeout) self.listen_channels() - self.tasks.append(asyncio.create_task(self._handle_comms())) + self.task_group.start_soon(self._handle_comms) def connect_channels(self, connection_cfg: Optional[cfg_t] = None): connection_cfg = connection_cfg or self.connection_cfg @@ -85,31 +113,38 @@ def connect_channels(self, connection_cfg: Optional[cfg_t] = None): self.iopub_channel = connect_channel("iopub", connection_cfg) def listen_channels(self): - self.tasks.append(asyncio.create_task(self.listen_iopub())) - self.tasks.append(asyncio.create_task(self.listen_shell())) + (self.task_group.start_soon(self.listen_iopub),) + (self.task_group.start_soon(self.listen_shell),) async def stop(self) -> None: - self.kernel_process.kill() + try: + self.kernel_process.terminate() + except ProcessLookupError: + pass await self.kernel_process.wait() - os.remove(self.connection_file_path) - for task in self.tasks: - task.cancel() + self.task_group.cancel_scope.cancel() + if self.write_connection_file: + path = anyio.Path(self.connection_file_path) + try: + await path.unlink() + except Exception: + pass async def listen_iopub(self): while True: msg = await receive_message(self.iopub_channel, change_str_to_date=True) parent_id = msg["parent_header"].get("msg_id") if msg["msg_type"] in ("comm_open", "comm_msg"): - self.comm_messages.put_nowait(msg) + await self.comm_messages.send(msg) elif parent_id in self.execute_requests.keys(): - self.execute_requests[parent_id]["iopub_msg"].put_nowait(msg) + await self.execute_requests[parent_id]["iopub_msg"].send(msg) async def listen_shell(self): while True: msg = await receive_message(self.shell_channel, change_str_to_date=True) msg_id = msg["parent_header"].get("msg_id") if msg_id in self.execute_requests.keys(): - self.execute_requests[msg_id]["shell_msg"].put_nowait(msg) + await self.execute_requests[msg_id]["shell_msg"].send(msg) async def execute( self, @@ -132,32 +167,32 @@ async def execute( self.msg_cnt += 1 await send_message(msg, self.shell_channel, self.key, change_date_to_str=True) self.execute_requests[msg_id] = { - "iopub_msg": asyncio.Queue(), - "shell_msg": asyncio.Queue(), + "iopub_msg": StapledObjectStream( + *create_memory_object_stream[dict](max_buffer_size=1024) + ), + "shell_msg": StapledObjectStream( + *create_memory_object_stream[dict](max_buffer_size=1024) + ), } if wait_for_executed: deadline = time.time() + timeout while True: try: - msg = await asyncio.wait_for( - self.execute_requests[msg_id]["iopub_msg"].get(), - deadline_to_timeout(deadline), - ) - except asyncio.TimeoutError: + with fail_after(deadline_to_timeout(deadline)): + msg = await self.execute_requests[msg_id]["iopub_msg"].receive() + except TimeoutError: error_message = f"Kernel didn't respond in {timeout} seconds" raise RuntimeError(error_message) await self._handle_outputs(ycell["outputs"], msg) if ( - (msg["header"]["msg_type"] == "status" - and msg["content"]["execution_state"] == "idle") + msg["header"]["msg_type"] == "status" + and msg["content"]["execution_state"] == "idle" ): break try: - msg = await asyncio.wait_for( - self.execute_requests[msg_id]["shell_msg"].get(), - deadline_to_timeout(deadline), - ) - except asyncio.TimeoutError: + with fail_after(deadline_to_timeout(deadline)): + msg = await self.execute_requests[msg_id]["shell_msg"].receive() + except TimeoutError: error_message = f"Kernel didn't respond in {timeout} seconds" raise RuntimeError(error_message) with ycell.doc.transaction(): @@ -165,31 +200,32 @@ async def execute( ycell["execution_state"] = "idle" del self.execute_requests[msg_id] else: - self.tasks.append(asyncio.create_task(self._handle_iopub(msg_id, ycell))) + self.task_group.start_soon(lambda: self._handle_iopub(msg_id, ycell)) async def _handle_iopub(self, msg_id: str, ycell: Map) -> None: while True: - msg = await self.execute_requests[msg_id]["iopub_msg"].get() + msg = await self.execute_requests[msg_id]["iopub_msg"].receive() await self._handle_outputs(ycell["outputs"], msg) if ( - (msg["header"]["msg_type"] == "status" - and msg["content"]["execution_state"] == "idle") + msg["header"]["msg_type"] == "status" + and msg["content"]["execution_state"] == "idle" ): - msg = await self.execute_requests[msg_id]["shell_msg"].get() + msg = await self.execute_requests[msg_id]["shell_msg"].receive() with ycell.doc.transaction(): ycell["execution_count"] = msg["content"]["execution_count"] ycell["execution_state"] = "idle" + break async def _handle_comms(self) -> None: if self.yjs is None or self.yjs.widgets is None: # type: ignore return while True: - msg = await self.comm_messages.get() + msg = await self.comm_messages.receive() msg_type = msg["header"]["msg_type"] if msg_type == "comm_open": comm_id = msg["content"]["comm_id"] - comm = Comm(comm_id, self.shell_channel, self.session_id, self.key) + comm = Comm(comm_id, self.shell_channel, self.session_id, self.key, self.task_group) self.yjs.widgets.comm_open(msg, comm) # type: ignore elif msg_type == "comm_msg": self.yjs.widgets.comm_msg(msg) # type: ignore @@ -228,13 +264,13 @@ async def _handle_outputs(self, outputs: Array, msg: Dict[str, Any]): text = text[:-1] if (not outputs) or (outputs[-1]["name"] != content["name"]): # type: ignore outputs.append( - #Map( + # Map( # { # "name": content["name"], # "output_type": msg_type, # "text": Array([content["text"]]), # } - #) + # ) { "name": content["name"], "output_type": msg_type, @@ -242,7 +278,7 @@ async def _handle_outputs(self, outputs: Array, msg: Dict[str, Any]): } ) else: - #outputs[-1]["text"].append(content["text"]) # type: ignore + # outputs[-1]["text"].append(content["text"]) # type: ignore last_output = outputs[-1] last_output["text"].append(text) # type: ignore outputs[-1] = last_output @@ -277,11 +313,14 @@ async def _handle_outputs(self, outputs: Array, msg: Dict[str, Any]): class Comm: - def __init__(self, comm_id: str, shell_channel, session_id: str, key: str): + def __init__( + self, comm_id: str, shell_channel, session_id: str, key: str, task_group: TaskGroup + ): self.comm_id = comm_id self.shell_channel = shell_channel self.session_id = session_id self.key = key + self.task_group = task_group self.msg_cnt = 0 def send(self, buffers): @@ -293,6 +332,6 @@ def send(self, buffers): buffers=buffers, ) self.msg_cnt += 1 - asyncio.create_task( - send_message(msg, self.shell_channel, self.key, change_date_to_str=True) + self.task_group.start_soon( + lambda: send_message(msg, self.shell_channel, self.key, change_date_to_str=True) ) diff --git a/plugins/kernels/fps_kernels/kernel_driver/message.py b/plugins/kernels/fps_kernels/kernel_driver/message.py index 6946c73c..5905a46c 100644 --- a/plugins/kernels/fps_kernels/kernel_driver/message.py +++ b/plugins/kernels/fps_kernels/kernel_driver/message.py @@ -32,7 +32,7 @@ def date_to_str(obj: Dict[str, Any]): def utcnow() -> datetime: - return datetime.utcnow().replace(tzinfo=timezone.utc) + return datetime.now(tz=timezone.utc) def create_message_header(msg_type: str, session_id: str, msg_id: str) -> Dict[str, Any]: diff --git a/plugins/kernels/fps_kernels/kernel_server/server.py b/plugins/kernels/fps_kernels/kernel_server/server.py index e10a2b84..0336136c 100644 --- a/plugins/kernels/fps_kernels/kernel_server/server.py +++ b/plugins/kernels/fps_kernels/kernel_server/server.py @@ -1,11 +1,12 @@ -import asyncio import json -import os import signal import uuid -from datetime import datetime +from datetime import datetime, timezone from typing import Dict, Iterable, List, Optional, cast +import anyio +from anyio import TASK_STATUS_IGNORED, Event, create_task_group +from anyio.abc import TaskGroup, TaskStatus from fastapi import WebSocket, WebSocketDisconnect from starlette.websockets import WebSocketState @@ -62,7 +63,6 @@ def __init__( self.connection_cfg = connection_cfg self.connection_file = connection_file self.write_connection_file = write_connection_file - self.channel_tasks: List[asyncio.Task] = [] self.sessions: Dict[str, AcceptedWebSocket] = {} # blocked messages and allowed messages are mutually exclusive self.blocked_messages: List[str] = [] @@ -104,77 +104,88 @@ def allow_messages(self, message_types: Optional[Iterable[str]] = None): def connections(self) -> int: return len(self.sessions) - async def start(self, launch_kernel: bool = True) -> None: - self.last_activity = { - "date": datetime.utcnow().isoformat() + "Z", - "execution_state": "starting", - } - if launch_kernel: - if not self.kernelspec_path: - raise RuntimeError("Could not find a kernel, maybe you forgot to install one?") - self.kernel_process = await _launch_kernel( - self.kernelspec_path, - self.connection_file_path, - self.kernel_cwd, - self.capture_kernel_output, + async def start( + self, launch_kernel: bool = True, *, task_status: TaskStatus[None] = TASK_STATUS_IGNORED + ) -> None: + async with create_task_group() as tg: + self.task_group = tg + self.last_activity = { + "date": datetime.now(tz=timezone.utc).isoformat().replace("+00:00", "Z"), + "execution_state": "starting", + } + if launch_kernel: + if not self.kernelspec_path: + raise RuntimeError("Could not find a kernel, maybe you forgot to install one?") + self.kernel_process = await _launch_kernel( + self.kernelspec_path, + self.connection_file_path, + self.kernel_cwd, + self.capture_kernel_output, + ) + assert self.connection_cfg is not None + identity = uuid.uuid4().hex.encode("ascii") + self.shell_channel = connect_channel("shell", self.connection_cfg, identity=identity) + self.stdin_channel = connect_channel("stdin", self.connection_cfg, identity=identity) + self.control_channel = connect_channel( + "control", self.connection_cfg, identity=identity ) - assert self.connection_cfg is not None - identity = uuid.uuid4().hex.encode("ascii") - self.shell_channel = connect_channel("shell", self.connection_cfg, identity=identity) - self.stdin_channel = connect_channel("stdin", self.connection_cfg, identity=identity) - self.control_channel = connect_channel("control", self.connection_cfg, identity=identity) - self.iopub_channel = connect_channel("iopub", self.connection_cfg) - await self._wait_for_ready() - self.channel_tasks += [ - asyncio.create_task(self.listen("shell")), - asyncio.create_task(self.listen("stdin")), - asyncio.create_task(self.listen("control")), - asyncio.create_task(self.listen("iopub")), - ] + self.iopub_channel = connect_channel("iopub", self.connection_cfg) + await self._wait_for_ready() + tg.start_soon(lambda: self.listen("shell")) + tg.start_soon(lambda: self.listen("stdin")) + tg.start_soon(lambda: self.listen("control")) + tg.start_soon(lambda: self.listen("iopub")) + task_status.started() async def stop(self) -> None: + try: + self.kernel_process.terminate() + except ProcessLookupError: + pass + await self.kernel_process.wait() + self.task_group.cancel_scope.cancel() if self.write_connection_file: - # FIXME: stop kernel in a better way + path = anyio.Path(self.connection_file_path) try: - self.kernel_process.send_signal(signal.SIGINT) - self.kernel_process.kill() - await self.kernel_process.wait() - except BaseException: + await path.unlink() + except Exception: pass - try: - os.remove(self.connection_file_path) - except BaseException: - pass - for task in self.channel_tasks: - task.cancel() - self.channel_tasks = [] def interrupt(self) -> None: self.kernel_process.send_signal(signal.SIGINT) - async def restart(self) -> None: + async def restart(self, *, task_status: TaskStatus[None] = TASK_STATUS_IGNORED) -> None: await self.stop() self.setup_connection_file() - await self.start() + await self.start(task_status=task_status) async def serve( self, websocket: AcceptedWebSocket, session_id: str, permissions: Optional[Dict[str, List[str]]], + stop_event: Event, ): self.sessions[session_id] = websocket self.can_execute = permissions is None or "execute" in permissions.get("kernels", []) - await self.listen_web(websocket) + async with create_task_group() as tg: + tg.start_soon(self.listen_web, websocket, tg) + tg.start_soon(self._watch_stop, tg, stop_event) + # the session could have been removed through the REST API, so check if it still exists if session_id in self.sessions: del self.sessions[session_id] - async def listen_web(self, websocket: AcceptedWebSocket): + async def _watch_stop(self, tg: TaskGroup, stop_event: Event): + await stop_event.wait() + tg.cancel_scope.cancel() + + async def listen_web(self, websocket: AcceptedWebSocket, tg: TaskGroup): try: await self.send_to_zmq(websocket) except WebSocketDisconnect: pass + tg.cancel_scope.cancel() async def listen(self, channel_name: str): if channel_name == "shell": @@ -264,7 +275,7 @@ async def send_to_ws(self, websocket, parts, parent_header, channel_name): "execution_state": msg["content"]["execution_state"], } elif websocket.accepted_subprotocol == "v1.kernel.websocket.jupyter.org": - bin_msg = serialize_msg_to_ws_v1(parts, channel_name) + bin_msg = b"".join(serialize_msg_to_ws_v1(parts, channel_name)) try: await websocket.websocket.send_bytes(bin_msg) except BaseException: diff --git a/plugins/kernels/fps_kernels/main.py b/plugins/kernels/fps_kernels/main.py index 3cb4abc2..fb5a3a41 100644 --- a/plugins/kernels/fps_kernels/main.py +++ b/plugins/kernels/fps_kernels/main.py @@ -1,19 +1,14 @@ from __future__ import annotations -import asyncio -from collections.abc import AsyncGenerator -from pathlib import Path -from typing import Optional - -from asphalt.core import Component, Context, context_teardown +from asphalt.core import Component, add_resource, get_resource, start_service_task from jupyverse_api.app import App from jupyverse_api.auth import Auth from jupyverse_api.frontend import FrontendConfig from jupyverse_api.kernels import Kernels, KernelsConfig +from jupyverse_api.main import Lifespan from jupyverse_api.yjs import Yjs -from .kernel_driver.paths import jupyter_runtime_dir from .routes import _Kernels @@ -21,36 +16,18 @@ class KernelsComponent(Component): def __init__(self, **kwargs): self.kernels_config = KernelsConfig(**kwargs) - @context_teardown - async def start( - self, - ctx: Context, - ) -> AsyncGenerator[None, Optional[BaseException]]: - ctx.add_resource(self.kernels_config, types=KernelsConfig) - - app = await ctx.request_resource(App) - auth = await ctx.request_resource(Auth) # type: ignore - frontend_config = await ctx.request_resource(FrontendConfig) - yjs = ( - await ctx.request_resource(Yjs) # type: ignore - if self.kernels_config.require_yjs - else None - ) - - kernels = _Kernels(app, self.kernels_config, auth, frontend_config, yjs) - ctx.add_resource(kernels, types=Kernels) - - if self.kernels_config.allow_external_kernels: - external_connection_dir = self.kernels_config.external_connection_dir - if external_connection_dir is None: - path = Path(jupyter_runtime_dir()) / "external_kernels" - else: - path = Path(external_connection_dir) - task = asyncio.create_task(kernels.watch_connection_files(path)) - - yield - - if self.kernels_config.allow_external_kernels: - task.cancel() - for kernel in kernels.kernels.values(): - await kernel["server"].stop() + async def start(self) -> None: + add_resource(self.kernels_config, types=KernelsConfig) + + app = await get_resource(App, wait=True) + auth = await get_resource(Auth, wait=True) # type: ignore[type-abstract] + frontend_config = await get_resource(FrontendConfig, wait=True) + lifespan = await get_resource(Lifespan, wait=True) + if self.kernels_config.require_yjs: + yjs = await get_resource(Yjs, wait=True) # type: ignore[type-abstract] + else: + yjs = None + + kernels = _Kernels(app, self.kernels_config, auth, frontend_config, yjs, lifespan) # type: ignore[type-abstract] + await start_service_task(kernels.start, "Kernels", teardown_action=kernels.stop) + add_resource(kernels, types=Kernels) diff --git a/plugins/kernels/fps_kernels/routes.py b/plugins/kernels/fps_kernels/routes.py index e15eab80..77e475d5 100644 --- a/plugins/kernels/fps_kernels/routes.py +++ b/plugins/kernels/fps_kernels/routes.py @@ -1,10 +1,13 @@ import json import logging import uuid +from functools import partial from http import HTTPStatus from pathlib import Path from typing import Dict, List, Optional, Set, Tuple +from anyio import TASK_STATUS_IGNORED, Event, create_task_group +from anyio.abc import TaskStatus from fastapi import HTTPException, Response from fastapi.responses import FileResponse from starlette.requests import Request @@ -15,10 +18,12 @@ from jupyverse_api.frontend import FrontendConfig from jupyverse_api.kernels import Kernels, KernelsConfig from jupyverse_api.kernels.models import CreateSession, Execution, Kernel, Notebook, Session +from jupyverse_api.main import Lifespan from jupyverse_api.yjs import Yjs from .kernel_driver.driver import KernelDriver from .kernel_driver.kernelspec import find_kernelspec, kernelspec_dirs +from .kernel_driver.paths import jupyter_runtime_dir from .kernel_server.server import ( AcceptedWebSocket, KernelServer, @@ -36,17 +41,50 @@ def __init__( auth: Auth, frontend_config: FrontendConfig, yjs: Optional[Yjs], + lifespan: Lifespan, ) -> None: super().__init__(app=app, auth=auth) self.kernels_config = kernels_config self.frontend_config = frontend_config self.yjs = yjs + self.lifespan = lifespan self.kernelspecs: dict = {} self.kernel_id_to_connection_file: Dict[str, str] = {} self.sessions: Dict[str, Session] = {} self.kernels = kernels self._app = app + self.stop_event = Event() + + async def start(self, *, task_status: TaskStatus[None] = TASK_STATUS_IGNORED) -> None: + async with create_task_group() as tg: + self.task_group = tg + if self.kernels_config.allow_external_kernels: + external_connection_dir = self.kernels_config.external_connection_dir + if external_connection_dir is None: + path = Path(jupyter_runtime_dir()) / "external_kernels" + else: + path = Path(external_connection_dir) + tg.start_soon(lambda: self.watch_connection_files(path)) + tg.start_soon(self.on_shutdown) + task_status.started() + await self.stop_event.wait() + + async def stop(self) -> None: + if self.stop_event.is_set(): + return + + async with create_task_group(): + for kernel in self.kernels.values(): + self.task_group.start_soon(kernel["server"].stop) + if kernel["driver"] is not None: + self.task_group.start_soon(kernel["driver"].stop) + self.stop_event.set() + self.task_group.cancel_scope.cancel() + + async def on_shutdown(self): + await self.lifespan.shutdown_request.wait() + await self.stop() async def get_status( self, @@ -179,7 +217,7 @@ async def create_session( ) kernel_id = str(uuid.uuid4()) kernels[kernel_id] = {"name": kernel_name, "server": kernel_server, "driver": None} - await kernel_server.start() + await self.task_group.start(kernel_server.start) elif kernel_id is not None: # external kernel kernel_name = kernels[kernel_id]["name"] @@ -188,7 +226,7 @@ async def create_session( write_connection_file=False, ) kernels[kernel_id]["server"] = kernel_server - await kernel_server.start(launch_kernel=False) + await self.task_group.start(partial(kernel_server.start, launch_kernel=False)) else: return session_id = str(uuid.uuid4()) @@ -236,7 +274,7 @@ async def restart_kernel( ): if kernel_id in kernels: kernel = kernels[kernel_id] - await kernel["server"].restart() + await self.task_group.start(kernel["server"].restart) result = { "id": kernel_id, "name": kernel["name"], @@ -274,7 +312,7 @@ async def execute_cell( connection_file=kernel["server"].connection_file_path, yjs=self.yjs, ) - await driver.connect() + await self.task_group.start(driver.start) driver = kernel["driver"] await driver.execute(ycell, wait_for_executed=False) @@ -332,16 +370,18 @@ async def kernel_channels( connection_file=self.kernel_id_to_connection_file[kernel_id], write_connection_file=False, ) - await kernel_server.start(launch_kernel=False) + await self.task_group.start(partial(kernel_server.start, launch_kernel=False)) kernels[kernel_id]["server"] = kernel_server - await kernel_server.serve(accepted_websocket, session_id, permissions) + await kernel_server.serve( + accepted_websocket, session_id, permissions, self.lifespan.shutdown_request + ) async def watch_connection_files(self, path: Path) -> None: # first time scan, treat everything as added files initial_changes = {(Change.added, str(p)) for p in path.iterdir()} await self.process_connection_files(initial_changes) # then, on every change - async for changes in awatch(path): + async for changes in awatch(path, stop_event=self.stop_event): await self.process_connection_files(changes) async def process_connection_files(self, changes: Set[Tuple[Change, str]]): diff --git a/plugins/kernels/pyproject.toml b/plugins/kernels/pyproject.toml index 9a11f90e..4340804b 100644 --- a/plugins/kernels/pyproject.toml +++ b/plugins/kernels/pyproject.toml @@ -14,6 +14,7 @@ dependencies = [ "types-python-dateutil", "watchfiles >=0.16.1,<1", "jupyverse-api >=0.1.2,<1", + "anyio", ] dynamic = [ "version",] [[project.authors]] diff --git a/plugins/lab/fps_lab/main.py b/plugins/lab/fps_lab/main.py index 50292912..2628bab2 100644 --- a/plugins/lab/fps_lab/main.py +++ b/plugins/lab/fps_lab/main.py @@ -1,4 +1,4 @@ -from asphalt.core import Component, Context +from asphalt.core import Component, add_resource, get_resource from jupyverse_api.app import App from jupyverse_api.auth import Auth @@ -10,14 +10,11 @@ class LabComponent(Component): - async def start( - self, - ctx: Context, - ) -> None: - app = await ctx.request_resource(App) - auth = await ctx.request_resource(Auth) # type: ignore - frontend_config = await ctx.request_resource(FrontendConfig) - jupyterlab_config = ctx.get_resource(JupyterLabConfig) + async def start(self) -> None: + app = await get_resource(App, wait=True) + auth = await get_resource(Auth, wait=True) # type: ignore[type-abstract] + frontend_config = await get_resource(FrontendConfig, wait=True) + jupyterlab_config = await get_resource(JupyterLabConfig, optional=True) lab = _Lab(app, auth, frontend_config, jupyterlab_config) - ctx.add_resource(lab, types=Lab) + add_resource(lab, types=Lab) diff --git a/plugins/lab/fps_lab/routes.py b/plugins/lab/fps_lab/routes.py index 3f68f844..11667a35 100644 --- a/plugins/lab/fps_lab/routes.py +++ b/plugins/lab/fps_lab/routes.py @@ -7,6 +7,11 @@ from pathlib import Path from typing import List, Optional, Tuple +if sys.version_info < (3, 10): + from importlib_metadata import entry_points +else: + from importlib.metadata import entry_points + import json5 # type: ignore from babel import Locale from fastapi import Response, status diff --git a/plugins/login/fps_login/main.py b/plugins/login/fps_login/main.py index f513ca57..95550e46 100644 --- a/plugins/login/fps_login/main.py +++ b/plugins/login/fps_login/main.py @@ -1,4 +1,4 @@ -from asphalt.core import Component, Context +from asphalt.core import Component, add_resource, get_resource from jupyverse_api.app import App from jupyverse_api.auth import AuthConfig @@ -8,12 +8,9 @@ class LoginComponent(Component): - async def start( - self, - ctx: Context, - ) -> None: - app = await ctx.request_resource(App) - auth_config = await ctx.request_resource(AuthConfig) + async def start(self) -> None: + app = await get_resource(App, wait=True) + auth_config = await get_resource(AuthConfig, wait=True) login = _Login(app, auth_config) - ctx.add_resource(login, types=Login) + add_resource(login, types=Login) diff --git a/plugins/nbconvert/fps_nbconvert/main.py b/plugins/nbconvert/fps_nbconvert/main.py index c865adf8..5b48d53c 100644 --- a/plugins/nbconvert/fps_nbconvert/main.py +++ b/plugins/nbconvert/fps_nbconvert/main.py @@ -1,4 +1,4 @@ -from asphalt.core import Component, Context +from asphalt.core import Component, add_resource, get_resource from jupyverse_api.app import App from jupyverse_api.auth import Auth @@ -8,12 +8,9 @@ class NbconvertComponent(Component): - async def start( - self, - ctx: Context, - ) -> None: - app = await ctx.request_resource(App) - auth = await ctx.request_resource(Auth) # type: ignore + async def start(self) -> None: + app = await get_resource(App, wait=True) + auth = await get_resource(Auth, wait=True) # type: ignore[type-abstract] nbconvert = _Nbconvert(app, auth) - ctx.add_resource(nbconvert, types=Nbconvert) + add_resource(nbconvert, types=Nbconvert) diff --git a/plugins/noauth/fps_noauth/main.py b/plugins/noauth/fps_noauth/main.py index 911a79f8..92bf0846 100644 --- a/plugins/noauth/fps_noauth/main.py +++ b/plugins/noauth/fps_noauth/main.py @@ -1,4 +1,4 @@ -from asphalt.core import Component, Context +from asphalt.core import Component, add_resource from jupyverse_api.auth import Auth @@ -6,9 +6,6 @@ class NoAuthComponent(Component): - async def start( - self, - ctx: Context, - ) -> None: + async def start(self) -> None: no_auth = _NoAuth() - ctx.add_resource(no_auth, types=Auth) + add_resource(no_auth, types=Auth) diff --git a/plugins/notebook/fps_notebook/main.py b/plugins/notebook/fps_notebook/main.py index 31521a06..a4b9e544 100644 --- a/plugins/notebook/fps_notebook/main.py +++ b/plugins/notebook/fps_notebook/main.py @@ -1,4 +1,4 @@ -from asphalt.core import Component, Context +from asphalt.core import Component, add_resource, get_resource from jupyverse_api.app import App from jupyverse_api.auth import Auth @@ -10,14 +10,11 @@ class NotebookComponent(Component): - async def start( - self, - ctx: Context, - ) -> None: - app = await ctx.request_resource(App) - auth = await ctx.request_resource(Auth) # type: ignore - frontend_config = await ctx.request_resource(FrontendConfig) - lab = await ctx.request_resource(Lab) # type: ignore + async def start(self) -> None: + app = await get_resource(App, wait=True) + auth = await get_resource(Auth, wait=True) # type: ignore[type-abstract] + frontend_config = await get_resource(FrontendConfig, wait=True) + lab = await get_resource(Lab, wait=True) # type: ignore[type-abstract] notebook = _Notebook(app, auth, frontend_config, lab) - ctx.add_resource(notebook, types=Notebook) + add_resource(notebook, types=Notebook) diff --git a/plugins/resource_usage/fps_resource_usage/main.py b/plugins/resource_usage/fps_resource_usage/main.py index 4cc3c2b7..14b64669 100644 --- a/plugins/resource_usage/fps_resource_usage/main.py +++ b/plugins/resource_usage/fps_resource_usage/main.py @@ -1,4 +1,4 @@ -from asphalt.core import Component, Context +from asphalt.core import Component, add_resource, get_resource from jupyverse_api.app import App from jupyverse_api.auth import Auth @@ -11,12 +11,9 @@ class ResourceUsageComponent(Component): def __init__(self, **kwargs): self.resource_usage_config = ResourceUsageConfig(**kwargs) - async def start( - self, - ctx: Context, - ) -> None: - app = await ctx.request_resource(App) - auth = await ctx.request_resource(Auth) # type: ignore + async def start(self) -> None: + app = await get_resource(App, wait=True) + auth = await get_resource(Auth, wait=True) # type: ignore[type-abstract] resource_usage = _ResourceUsage(app, auth, self.resource_usage_config) - ctx.add_resource(resource_usage, types=ResourceUsage) + add_resource(resource_usage, types=ResourceUsage) diff --git a/plugins/terminals/fps_terminals/main.py b/plugins/terminals/fps_terminals/main.py index 93a00719..e56373ff 100644 --- a/plugins/terminals/fps_terminals/main.py +++ b/plugins/terminals/fps_terminals/main.py @@ -1,7 +1,7 @@ -import os +import sys from typing import Type -from asphalt.core import Component, Context +from asphalt.core import Component, add_resource, get_resource, start_service_task from jupyverse_api.app import App from jupyverse_api.auth import Auth @@ -10,19 +10,17 @@ from .routes import _Terminals _TerminalServer: Type[TerminalServer] -if os.name == "nt": +if sys.platform == "win32": from .win_server import _TerminalServer else: from .server import _TerminalServer class TerminalsComponent(Component): - async def start( - self, - ctx: Context, - ) -> None: - app = await ctx.request_resource(App) - auth = await ctx.request_resource(Auth) # type: ignore + async def start(self) -> None: + app = await get_resource(App, wait=True) + auth = await get_resource(Auth, wait=True) # type: ignore[type-abstract] terminals = _Terminals(app, auth, _TerminalServer) - ctx.add_resource(terminals, types=Terminals) + await start_service_task(terminals.start, name="Terminals", teardown_action=terminals.stop) + add_resource(terminals, types=Terminals) diff --git a/plugins/terminals/fps_terminals/routes.py b/plugins/terminals/fps_terminals/routes.py index ca7d2f8d..337c07fb 100644 --- a/plugins/terminals/fps_terminals/routes.py +++ b/plugins/terminals/fps_terminals/routes.py @@ -1,7 +1,8 @@ -from datetime import datetime +from datetime import datetime, timezone from http import HTTPStatus from typing import Any, Dict, Type +from anyio import Event, create_task_group from fastapi import Response from jupyverse_api.app import App @@ -15,6 +16,16 @@ class _Terminals(Terminals): def __init__(self, app: App, auth: Auth, _TerminalServer: Type[TerminalServer]) -> None: super().__init__(app=app, auth=auth) self.TerminalServer = _TerminalServer + self.stop_event = Event() + + async def start(self): + await self.stop_event.wait() + + async def stop(self): + async with create_task_group() as tg: + for terminal in TERMINALS.values(): + tg.start_soon(terminal["server"].stop) + self.stop_event.set() async def get_terminals( self, @@ -29,7 +40,7 @@ async def create_terminal( name = str(len(TERMINALS) + 1) terminal = Terminal( name=name, - last_activity=datetime.utcnow().isoformat() + "Z", + last_activity=datetime.now(tz=timezone.utc).isoformat().replace("+00:00", "Z"), ) server = self.TerminalServer() TERMINALS[name] = {"info": terminal, "server": server} diff --git a/plugins/terminals/fps_terminals/server.py b/plugins/terminals/fps_terminals/server.py index c086f8c6..bbdb55b2 100644 --- a/plugins/terminals/fps_terminals/server.py +++ b/plugins/terminals/fps_terminals/server.py @@ -1,75 +1,140 @@ -import asyncio import fcntl import os import pty +import selectors import shlex import struct import termios +from functools import partial +from anyio import create_memory_object_stream, create_task_group, from_thread, to_thread +from anyio.abc import ByteReceiveStream, ByteSendStream from fastapi import WebSocketDisconnect from jupyverse_api.terminals import TerminalServer -def open_terminal(command="bash", columns=80, lines=24): - pid, fd = pty.fork() - if pid == 0: - argv = shlex.split(command) - env = os.environ.copy() - env.update(TERM="linux", COLUMNS=str(columns), LINES=str(lines)) - os.execvpe(argv[0], argv, env) - return fd - - class _TerminalServer(TerminalServer): def __init__(self): - self.fd = open_terminal() + # FIXME: pass in config + command = "bash" + columns = 80 + lines = 24 + + pid, fd = pty.fork() + if pid == 0: + argv = shlex.split(command) + env = os.environ.copy() + env.update(TERM="linux", COLUMNS=str(columns), LINES=str(lines)) + os.execvpe(argv[0], argv, env) + self.fd = fd self.p_out = os.fdopen(self.fd, "w+b", 0) self.websockets = [] - async def serve(self, websocket, permissions): + async def serve(self, websocket, permissions) -> None: self.websocket = websocket + self.permissions = permissions self.websockets.append(websocket) - self.event = asyncio.Event() - self.loop = asyncio.get_event_loop() - - task = asyncio.create_task(self.send_data()) - - def on_output(): - try: - self.data_or_disconnect = self.p_out.read(65536).decode() - self.event.set() - except Exception: - self.loop.remove_reader(self.p_out) - self.data_or_disconnect = None - self.event.set() - - self.loop.add_reader(self.p_out, on_output) - await websocket.send_json(["setup", {}]) - can_execute = permissions is None or "execute" in permissions.get("terminals", []) + + async with create_task_group() as self.task_group: + self.recv_stream = ReceiveStream(self.p_out, self.task_group) + self.send_stream = SendStream(self.p_out) + self.task_group.start_soon(self.backend_to_frontend) + self.task_group.start_soon(self.frontend_to_backend) + + async def stop(self) -> None: + os.write(self.recv_stream.pipeout, b"0") + self.p_out.close() + try: + self.recv_stream.sel.unregister(self.p_out) + except Exception: + pass + self.task_group.cancel_scope.cancel() + + async def backend_to_frontend(self): + while True: + data = (await self.recv_stream.receive(65536)).decode() + for websocket in self.websockets: + await websocket.send_json(["stdout", data]) + + async def frontend_to_backend(self): + await self.websocket.send_json(["setup", {}]) + can_execute = self.permissions is None or "execute" in self.permissions.get("terminals", []) try: while True: - msg = await websocket.receive_json() + msg = await self.websocket.receive_json() if can_execute: if msg[0] == "stdin": - self.p_out.write(msg[1].encode()) + await self.send_stream.send(msg[1].encode()) elif msg[0] == "set_size": winsize = struct.pack("HH", msg[1], msg[2]) fcntl.ioctl(self.fd, termios.TIOCSWINSZ, winsize) except WebSocketDisconnect: - task.cancel() - - async def send_data(self): - while True: - await self.event.wait() - self.event.clear() - if self.data_or_disconnect is None: - await self.websocket.send_json(["disconnect", 1]) - else: - for websocket in self.websockets: - await websocket.send_json(["stdout", self.data_or_disconnect]) + self.quit(self.websocket) + self.task_group.cancel_scope.cancel() def quit(self, websocket): - self.websockets.remove(websocket) - if not self.websockets: - os.close(self.fd) + try: + os.write(self.recv_stream.pipeout, b"0") + self.p_out.close() + self.recv_stream.sel.unregister(self.p_out) + self.websockets.remove(websocket) + if not self.websockets: + os.close(self.fd) + except Exception: + pass + + +class ReceiveStream(ByteReceiveStream): + def __init__(self, p_out, task_group): + self.p_out = p_out + self.sel = selectors.DefaultSelector() + self.sel.register(self.p_out, selectors.EVENT_READ, self._read) + self.pipein, self.pipeout = os.pipe() + f = os.fdopen(self.pipein, "r+b", 0) + + def cb(): + return True + + self.sel.register(f, selectors.EVENT_READ, cb) + self.send_stream, self.recv_stream = create_memory_object_stream[bytes]( + max_buffer_size=65536 + ) + + def reader(): + while True: + events = self.sel.select() + for key, mask in events: + callback = key.data + if callback(): + return + + task_group.start_soon(partial(to_thread.run_sync, reader, abandon_on_cancel=True)) + + def _read(self) -> bool: + try: + data = self.p_out.read(65536) + except OSError: + self.sel.unregister(self.p_out) + return True + else: + from_thread.run_sync(self.send_stream.send_nowait, data) + return False + + async def receive(self, max_bytes: int = 65536) -> bytes: + data = await self.recv_stream.receive() + return data + + async def aclose(self) -> None: + pass + + +class SendStream(ByteSendStream): + def __init__(self, p_out): + self.p_out = p_out + + async def send(self, item: bytes) -> None: + self.p_out.write(item) + + async def aclose(self) -> None: + pass diff --git a/plugins/terminals/fps_terminals/win_server.py b/plugins/terminals/fps_terminals/win_server.py index f1865391..a71c3474 100644 --- a/plugins/terminals/fps_terminals/win_server.py +++ b/plugins/terminals/fps_terminals/win_server.py @@ -1,8 +1,8 @@ -import asyncio import os from functools import partial -from anyio import to_thread +from anyio import create_task_group, to_thread +from fastapi import WebSocketDisconnect from winpty import PTY # type: ignore from jupyverse_api.terminals import TerminalServer @@ -20,16 +20,20 @@ def __init__(self): self.process = open_terminal() self.websockets = [] - async def serve(self, websocket): + async def serve(self, websocket, permissions) -> None: self.websocket = websocket + self.permissions = permissions self.websockets.append(websocket) await websocket.send_json(["setup", {}]) - self.send_task = asyncio.create_task(self.send_data()) - self.recv_task = asyncio.create_task(self.recv_data()) + async with create_task_group() as tg: + self.task_group = tg + tg.start_soon(self.send_data) + tg.start_soon(self.recv_data) - await asyncio.gather(self.send_task, self.recv_task) + async def stop(self) -> None: + self.task_group.cancel_scope.cancel() async def send_data(self): while True: @@ -43,19 +47,21 @@ async def send_data(self): await websocket.send_json(["stdout", data]) async def recv_data(self): - while True: - try: + can_execute = self.permissions is None or "execute" in self.permissions.get("terminals", []) + try: + while True: msg = await self.websocket.receive_json() - except Exception: - return - if msg[0] == "stdin": - self.process.write(msg[1]) - elif msg[0] == "set_size": - self.process.set_size(msg[2], msg[1]) + if can_execute: + if msg[0] == "stdin": + self.process.write(msg[1]) + elif msg[0] == "set_size": + self.process.set_size(msg[2], msg[1]) + except WebSocketDisconnect: + self.quit(self.websocket) + self.task_group.cancel_scope.cancel() def quit(self, websocket): self.websockets.remove(websocket) if not self.websockets: - self.send_task.cancel() - self.recv_task.cancel() + self.task_group.cancel_scope.cancel() del self.process diff --git a/plugins/webdav/fps_webdav/main.py b/plugins/webdav/fps_webdav/main.py index 2e6c4662..7da847bd 100644 --- a/plugins/webdav/fps_webdav/main.py +++ b/plugins/webdav/fps_webdav/main.py @@ -1,4 +1,4 @@ -from asphalt.core import Component, Context +from asphalt.core import Component, add_resource, get_resource from jupyverse_api.app import App @@ -10,11 +10,8 @@ class WebDAVComponent(Component): def __init__(self, **kwargs): self.webdav_config = WebDAVConfig(**kwargs) - async def start( - self, - ctx: Context, - ) -> None: - app = await ctx.request_resource(App) + async def start(self) -> None: + app = await get_resource(App, wait=True) webdav = WebDAV(app, self.webdav_config) - ctx.add_resource(webdav) + add_resource(webdav) diff --git a/plugins/webdav/fps_webdav/routes.py b/plugins/webdav/fps_webdav/routes.py index b31b1fc9..bf67903a 100644 --- a/plugins/webdav/fps_webdav/routes.py +++ b/plugins/webdav/fps_webdav/routes.py @@ -46,7 +46,7 @@ def __init__(self, app: App, webdav_config: WebDAVConfig): for account in webdav_config.account_mapping: logger.info(f"WebDAV user {account.username} has password {account.password}") - webdav_conf = webdav_config.dict() + webdav_conf = webdav_config.model_dump() init_config_from_obj(webdav_conf) webdav_aep = AppEntryParameters() webdav_app = get_asgi_app(aep=webdav_aep, config_obj=webdav_conf) diff --git a/plugins/webdav/pyproject.toml b/plugins/webdav/pyproject.toml index 718adeee..62f23725 100644 --- a/plugins/webdav/pyproject.toml +++ b/plugins/webdav/pyproject.toml @@ -30,7 +30,6 @@ Homepage = "https://jupyter.org" test = [ "easywebdav", "pytest", - "pytest-asyncio", ] [tool.check-manifest] diff --git a/plugins/webdav/tests/conftest.py b/plugins/webdav/tests/conftest.py new file mode 100644 index 00000000..af7e4799 --- /dev/null +++ b/plugins/webdav/tests/conftest.py @@ -0,0 +1,6 @@ +import pytest + + +@pytest.fixture +def anyio_backend(): + return "asyncio" diff --git a/plugins/webdav/tests/test_webdav.py b/plugins/webdav/tests/test_webdav.py index dc6ed856..a4e53df4 100644 --- a/plugins/webdav/tests/test_webdav.py +++ b/plugins/webdav/tests/test_webdav.py @@ -24,17 +24,17 @@ def configure(components, config): return _components -@pytest.mark.asyncio +@pytest.mark.anyio @pytest.mark.skipif(sys.version_info < (3, 10), reason="Requires Python >=3.10") async def test_webdav(unused_tcp_port): components = configure( COMPONENTS, {"webdav": {"account_mapping": [{"username": "foo", "password": "bar"}]}} ) - async with Context() as ctx: + async with Context(): await JupyverseComponent( components=components, port=unused_tcp_port, - ).start(ctx) + ).start() webdav = easywebdav.connect( "127.0.0.1", port=unused_tcp_port, path="webdav", username="foo", password="bar" diff --git a/plugins/yjs/fps_yjs/main.py b/plugins/yjs/fps_yjs/main.py index eacd1b91..18c9daa8 100644 --- a/plugins/yjs/fps_yjs/main.py +++ b/plugins/yjs/fps_yjs/main.py @@ -1,36 +1,36 @@ from __future__ import annotations -from collections.abc import AsyncGenerator -from typing import Optional - -from asphalt.core import Component, Context, context_teardown +from asphalt.core import ( + Component, + add_resource, + get_resource, + start_background_task_factory, + start_service_task, +) from jupyverse_api.app import App from jupyverse_api.auth import Auth from jupyverse_api.contents import Contents +from jupyverse_api.main import Lifespan from jupyverse_api.yjs import Yjs from .routes import _Yjs class YjsComponent(Component): - @context_teardown - async def start( - self, - ctx: Context, - ) -> AsyncGenerator[None, Optional[BaseException]]: - app = await ctx.request_resource(App) - auth = await ctx.request_resource(Auth) # type: ignore - contents = await ctx.request_resource(Contents) # type: ignore - - yjs = _Yjs(app, auth, contents) - ctx.add_resource(yjs, types=Yjs) - - # start indexing in the background - contents.file_id_manager - - yield - - yjs.room_manager.stop() - contents.file_id_manager.stop_watching_files.set() - await contents.file_id_manager.stopped_watching_files.wait() + async def start(self) -> None: + app = await get_resource(App, wait=True) + auth = await get_resource(Auth, wait=True) # type: ignore[type-abstract] + contents = await get_resource(Contents, wait=True) # type: ignore[type-abstract] + lifespan = await get_resource(Lifespan, wait=True) + + task_factory = await start_background_task_factory("yjs_tasks") + yjs = _Yjs(app, auth, contents, lifespan, task_factory) + add_resource(yjs, types=Yjs) + + await start_service_task(yjs.start, "Room manager", teardown_action=yjs.stop) + await start_service_task( + contents.file_id_manager.start, + "File ID manager", + teardown_action=contents.file_id_manager.stop, + ) diff --git a/plugins/yjs/fps_yjs/routes.py b/plugins/yjs/fps_yjs/routes.py index 1a023d95..4835953c 100644 --- a/plugins/yjs/fps_yjs/routes.py +++ b/plugins/yjs/fps_yjs/routes.py @@ -1,12 +1,14 @@ from __future__ import annotations -import asyncio import logging from datetime import datetime from functools import partial from typing import Dict from uuid import uuid4 +from anyio import TASK_STATUS_IGNORED, sleep +from anyio.abc import TaskStatus +from asphalt.core import TaskFactory, TaskHandle from fastapi import ( HTTPException, Request, @@ -17,9 +19,11 @@ from pycrdt import Doc from websockets.exceptions import ConnectionClosedOK +from jupyverse_api import ResourceLock from jupyverse_api.app import App from jupyverse_api.auth import Auth, User from jupyverse_api.contents import Contents +from jupyverse_api.main import Lifespan from jupyverse_api.yjs import Yjs from jupyverse_api.yjs.models import CreateDocumentSession @@ -46,15 +50,35 @@ def __init__( app: App, auth: Auth, contents: Contents, + lifespan: Lifespan, + task_factory: TaskFactory, ) -> None: super().__init__(app=app, auth=auth) self.contents = contents - self.room_manager = RoomManager(contents) + self.lifespan = lifespan + self.task_factory = task_factory if Widgets is None: self.widgets = None else: self.widgets = Widgets() # type: ignore + async def start(self, *, task_status: TaskStatus[None] = TASK_STATUS_IGNORED) -> None: + self.room_manager = RoomManager(self.contents, self.lifespan, self.task_factory) + await self.task_factory.start_task( + self.room_manager.websocket_server.start, + "WebSocket server", + ) + self.task_factory.start_task_soon(self.room_manager.on_shutdown) + task_status.started() + + async def stop(self) -> None: + for task in ( + list(self.room_manager.watchers.values()) + + list(self.room_manager.savers.values()) + + list(self.room_manager.cleaners.values()) + ): + task.cancel() + async def collaboration_room_websocket( self, path, @@ -140,92 +164,102 @@ async def recv(self): class RoomManager: contents: Contents + lifespan: Lifespan documents: Dict[str, YBaseDoc] - watchers: Dict[str, asyncio.Task] - savers: Dict[str, asyncio.Task] - cleaners: Dict[YRoom, asyncio.Task] + watchers: Dict[str, TaskHandle] + savers: Dict[str, TaskHandle] + cleaners: Dict[YRoom, TaskHandle] last_modified: Dict[str, datetime] websocket_server: JupyterWebsocketServer - lock: asyncio.Lock + room_lock: ResourceLock + task_factory: TaskFactory - def __init__(self, contents: Contents): + def __init__(self, contents: Contents, lifespan: Lifespan, task_factory: TaskFactory): self.contents = contents + self.lifespan = lifespan + self.task_factory = task_factory self.documents = {} # a dictionary of room_name:document self.watchers = {} # a dictionary of file_id:task self.savers = {} # a dictionary of file_id:task self.cleaners = {} # a dictionary of room:task self.last_modified = {} # a dictionary of file_id:last_modification_date self.websocket_server = JupyterWebsocketServer(rooms_ready=False, auto_clean_rooms=False) - self.websocket_server_task = asyncio.create_task(self.websocket_server.start()) - self.lock = asyncio.Lock() - - def stop(self): - for watcher in self.watchers.values(): - watcher.cancel() - for saver in self.savers.values(): - saver.cancel() - for cleaner in self.cleaners.values(): - cleaner.cancel() - self.websocket_server.stop() + self.room_lock = ResourceLock() + + async def on_shutdown(self): + await self.lifespan.shutdown_request.wait() + await self.websocket_server.stop() async def serve(self, websocket: YWebsocket, permissions) -> None: - room = await self.websocket_server.get_room(websocket.path) - can_write = permissions is None or "write" in permissions.get("yjs", []) - room.on_message = partial(self.filter_message, can_write) - is_stored_document = websocket.path.count(":") >= 2 - if is_stored_document: - assert room.ystore is not None - file_format, file_type, file_id = websocket.path.split(":", 2) - if room in self.cleaners: - # cleaning the room was scheduled because there was no client left - # cancel that since there is a new client - self.cleaners[room].cancel() - if not room.ready: - file_path = await self.contents.file_id_manager.get_path(file_id) - logger.info(f"Opening collaboration room: {websocket.path} ({file_path})") - document = YDOCS.get(file_type, YFILE)(room.ydoc) - document.file_id = file_id - self.documents[websocket.path] = document - async with self.lock: - model = await self.contents.read_content(file_path, True, file_format) - assert model.last_modified is not None - self.last_modified[file_id] = to_datetime(model.last_modified) + async with self.room_lock(websocket.path): + room = await self.websocket_server.get_room(websocket.path) + can_write = permissions is None or "write" in permissions.get("yjs", []) + room.on_message = partial(self.filter_message, can_write) + is_stored_document = websocket.path.count(":") >= 2 + if is_stored_document: + assert room.ystore is not None + file_format, file_type, file_id = websocket.path.split(":", 2) + if room in self.cleaners: + # cleaning the room was scheduled because there was no client left + # cancel that since there is a new client + self.cleaners[room].cancel() + await self.cleaners[room].wait_finished() + if room in self.cleaners: + del self.cleaners[room] if not room.ready: - # try to apply Y updates from the YStore for this document - try: - await room.ystore.apply_updates(room.ydoc) - read_from_source = False - except YDocNotFound: - # YDoc not found in the YStore, create the document from - # the source file (no change history) - read_from_source = True - if not read_from_source: - # if YStore updates and source file are out-of-sync, resync updates - # with source - if document.source != model.content: + file_path = await self.contents.file_id_manager.get_path(file_id) + logger.info(f"Opening collaboration room: {websocket.path} ({file_path})") + document = YDOCS.get(file_type, YFILE)(room.ydoc) + document.file_id = file_id + self.documents[websocket.path] = document + model = await self.contents.read_content(file_path, True, file_format) + assert model.last_modified is not None + self.last_modified[file_id] = to_datetime(model.last_modified) + if not room.ready: + # try to apply Y updates from the YStore for this document + try: + await room.ystore.apply_updates(room.ydoc) + read_from_source = False + except YDocNotFound: + # YDoc not found in the YStore, create the document from + # the source file (no change history) read_from_source = True - if read_from_source: - document.source = model.content - await room.ystore.encode_state_as_update(room.ydoc) - - document.dirty = False - room.ready = True - # save the document to file when changed - document.observe( - partial(self.on_document_change, file_id, file_type, file_format, document) - ) - # update the document when file changes - if file_id not in self.watchers: - self.watchers[file_id] = asyncio.create_task( - self.watch_file(file_format, file_id, document) + if not read_from_source: + # if YStore updates and source file are out-of-sync, resync updates + # with source + if document.source != model.content: + read_from_source = True + if read_from_source: + document.source = model.content + await room.ystore.encode_state_as_update(room.ydoc) + + document.dirty = False + room.ready = True + # save the document to file when changed + document.observe( + partial( + self.on_document_change, + file_id, + file_type, + file_format, + document, + ) ) + # update the document when file changes + if file_id not in self.watchers: + self.watchers[file_id] = self.task_factory.start_task_soon( + lambda: self.watch_file(file_format, file_id, document), + f"Watch file {file_id}" + ) - await self.websocket_server.started.wait() - await self.websocket_server.serve(websocket) + await self.websocket_server.serve(websocket, self.lifespan.shutdown_request) if is_stored_document and not room.clients: # no client in this room after we disconnect - self.cleaners[room] = asyncio.create_task(self.maybe_clean_room(room, websocket.path)) + self.cleaners[room] = self.task_factory.start_task_soon( + lambda: self.maybe_clean_room(room, websocket.path), + f"Clean room {websocket.path}" + ) async def filter_message(self, can_write: bool, message: bytes) -> bool: """ @@ -262,28 +296,28 @@ async def watch_file(self, file_format: str, file_id: str, document: YBaseDoc) - file_path = await self.get_file_path(file_id, document) assert file_path is not None logger.debug(f"Watching file: {file_path}") - while True: - watcher = self.contents.file_id_manager.watch(file_path) - async for changes in watcher: - new_file_path = await self.get_file_path(file_id, document) - if new_file_path is None: - continue - if new_file_path != file_path: - # file was renamed - self.contents.file_id_manager.unwatch(file_path, watcher) - file_path = new_file_path - # break - await self.maybe_load_file(file_format, file_path, file_id) + # FIXME: handle file rename/move? + watcher = self.contents.file_id_manager.watch(file_path) + async for changes in watcher: + new_file_path = await self.get_file_path(file_id, document) + if new_file_path is None: + continue + if new_file_path != file_path: + # file was renamed + self.contents.file_id_manager.unwatch(file_path, watcher) + file_path = new_file_path + # break + await self.maybe_load_file(file_format, file_path, file_id) + if file_id in self.watchers: + del self.watchers[file_id] async def maybe_load_file(self, file_format: str, file_path: str, file_id: str) -> None: - async with self.lock: - model = await self.contents.read_content(file_path, False) + model = await self.contents.read_content(file_path, False) # do nothing if the file was saved by us assert model.last_modified is not None if self.last_modified[file_id] < to_datetime(model.last_modified): # the file was not saved by us, update the shared document(s) - async with self.lock: - model = await self.contents.read_content(file_path, True, file_format) + model = await self.contents.read_content(file_path, True, file_format) assert model.last_modified is not None documents = [v for k, v in self.documents.items() if k.split(":", 2)[2] == file_id] for document in documents: @@ -306,23 +340,23 @@ def on_document_change( ) if file_id in self.savers: self.savers[file_id].cancel() - self.savers[file_id] = asyncio.create_task( - self.maybe_save_document(file_id, file_type, file_format, document) + self.savers[file_id] = self.task_factory.start_task_soon( + lambda: self.maybe_save_document(file_id, file_type, file_format, document), + f"Save file {file_id}" ) async def maybe_save_document( self, file_id: str, file_type: str, file_format: str, document: YBaseDoc ) -> None: # save after 1 second of inactivity to prevent too frequent saving - await asyncio.sleep(1) # FIXME: pass in config + await sleep(1) # FIXME: pass in config # if the room cannot be found, don't save try: file_path = await self.get_file_path(file_id, document) except Exception: return assert file_path is not None - async with self.lock: - model = await self.contents.read_content(file_path, True, file_format) + model = await self.contents.read_content(file_path, True, file_format) assert model.last_modified is not None if self.last_modified[file_id] < to_datetime(model.last_modified): # file changed on disk, let's revert @@ -339,30 +373,34 @@ async def maybe_save_document( "path": file_path, "type": file_type, } - async with self.lock: - await self.contents.write_content(content) - model = await self.contents.read_content(file_path, False) + await self.contents.write_content(content) + model = await self.contents.read_content(file_path, False) assert model.last_modified is not None self.last_modified[file_id] = to_datetime(model.last_modified) document.dirty = False # we're done saving, remove the saver - del self.savers[file_id] + if file_id in self.savers: + del self.savers[file_id] async def maybe_clean_room(self, room, ws_path: str) -> None: file_id = ws_path.split(":", 2)[2] # keep the document for a while in case someone reconnects - await asyncio.sleep(60) # FIXME: pass in config + await sleep(60) # FIXME: pass in config document = self.documents[ws_path] document.unobserve() del self.documents[ws_path] documents = [v for k, v in self.documents.items() if k.split(":", 2)[2] == file_id] if not documents: self.watchers[file_id].cancel() - del self.watchers[file_id] + await self.watchers[file_id].wait_finished() + if file_id in self.watchers: + del self.watchers[file_id] room_name = self.websocket_server.get_room_name(room) self.websocket_server.delete_room(room=room) file_path = await self.get_file_path(file_id, document) logger.info(f"Closing collaboration room: {room_name} ({file_path})") + if room in self.cleaners: + del self.cleaners[room] class JupyterWebsocketServer(WebsocketServer): diff --git a/plugins/yjs/fps_yjs/ydocs/ybasedoc.py b/plugins/yjs/fps_yjs/ydocs/ybasedoc.py index 7ea34ed2..ef7c1087 100644 --- a/plugins/yjs/fps_yjs/ydocs/ybasedoc.py +++ b/plugins/yjs/fps_yjs/ydocs/ybasedoc.py @@ -16,8 +16,7 @@ def __init__(self, ydoc: Optional[Doc] = None): @property @abstractmethod - def version(self) -> str: - ... + def version(self) -> str: ... @property def ystate(self) -> Map: @@ -60,16 +59,13 @@ def file_id(self, value: str) -> None: self._ystate["file_id"] = value @abstractmethod - def get(self) -> Any: - ... + def get(self) -> Any: ... @abstractmethod - def set(self, value: Any) -> None: - ... + def set(self, value: Any) -> None: ... @abstractmethod - def observe(self, callback: Callable[[str, Any], None]) -> None: - ... + def observe(self, callback: Callable[[str, Any], None]) -> None: ... def unobserve(self) -> None: for k, v in self._subscriptions.items(): diff --git a/plugins/yjs/fps_yjs/ywebsocket/websocket_provider.py b/plugins/yjs/fps_yjs/ywebsocket/websocket_provider.py index 1e7fb5a2..755e09e0 100644 --- a/plugins/yjs/fps_yjs/ywebsocket/websocket_provider.py +++ b/plugins/yjs/fps_yjs/ywebsocket/websocket_provider.py @@ -25,8 +25,6 @@ class WebsocketProvider: - """WebSocket provider.""" - _ydoc: Doc _update_send_stream: MemoryObjectSendStream _update_receive_stream: MemoryObjectReceiveStream @@ -35,26 +33,6 @@ class WebsocketProvider: _task_group: TaskGroup | None def __init__(self, ydoc: Doc, websocket: Websocket, log: Logger | None = None) -> None: - """Initialize the object. - - The WebsocketProvider instance should preferably be used as an async context manager: - ```py - async with websocket_provider: - ... - ``` - However, a lower-level API can also be used: - ```py - task = asyncio.create_task(websocket_provider.start()) - await websocket_provider.started.wait() - ... - websocket_provider.stop() - ``` - - Arguments: - ydoc: The YDoc to connect through the WebSocket. - websocket: The WebSocket through which to connect the YDoc. - log: An optional logger. - """ self._ydoc = ydoc self._websocket = websocket self.log = log or getLogger(__name__) @@ -68,7 +46,6 @@ def __init__(self, ydoc: Doc, websocket: Websocket, log: Logger | None = None) - @property def started(self) -> Event: - """An async event that is set when the WebSocket provider has started.""" if self._started is None: self._started = Event() return self._started @@ -111,11 +88,6 @@ async def _send(self): pass async def start(self, *, task_status: TaskStatus[None] = TASK_STATUS_IGNORED): - """Start the WebSocket provider. - - Arguments: - task_status: The status to set when the task has started. - """ if self._starting: return else: @@ -131,7 +103,6 @@ async def start(self, *, task_status: TaskStatus[None] = TASK_STATUS_IGNORED): task_status.started() def stop(self): - """Stop the WebSocket provider.""" if self._task_group is None: raise RuntimeError("WebsocketProvider not running") diff --git a/plugins/yjs/fps_yjs/ywebsocket/websocket_server.py b/plugins/yjs/fps_yjs/ywebsocket/websocket_server.py index 40100211..27b69fb5 100644 --- a/plugins/yjs/fps_yjs/ywebsocket/websocket_server.py +++ b/plugins/yjs/fps_yjs/ywebsocket/websocket_server.py @@ -1,6 +1,5 @@ from __future__ import annotations -from contextlib import AsyncExitStack from logging import Logger, getLogger from anyio import TASK_STATUS_IGNORED, Event, create_task_group @@ -12,61 +11,19 @@ class WebsocketServer: - """WebSocket server.""" - auto_clean_rooms: bool rooms: dict[str, YRoom] - _started: Event | None - _starting: bool - _task_group: TaskGroup | None + _task_group: TaskGroup def __init__( self, rooms_ready: bool = True, auto_clean_rooms: bool = True, log: Logger | None = None ) -> None: - """Initialize the object. - - The WebsocketServer instance should preferably be used as an async context manager: - ```py - async with websocket_server: - ... - ``` - However, a lower-level API can also be used: - ```py - task = asyncio.create_task(websocket_server.start()) - await websocket_server.started.wait() - ... - websocket_server.stop() - ``` - - Arguments: - rooms_ready: Whether rooms are ready to be synchronized when opened. - auto_clean_rooms: Whether rooms should be deleted when no client is there anymore. - log: An optional logger. - """ self.rooms_ready = rooms_ready self.auto_clean_rooms = auto_clean_rooms self.log = log or getLogger(__name__) self.rooms = {} - self._started = None - self._starting = False - self._task_group = None - - @property - def started(self) -> Event: - """An async event that is set when the WebSocket server has started.""" - if self._started is None: - self._started = Event() - return self._started async def get_room(self, name: str, ydoc: Doc | None = None) -> YRoom: - """Get or create a room with the given name, and start it. - - Arguments: - name: The room name. - - Returns: - The room with the given name, or a new one if no room with that name was found. - """ if name not in self.rooms.keys(): self.rooms[name] = YRoom(ydoc=ydoc, ready=self.rooms_ready, log=self.log) room = self.rooms[name] @@ -74,41 +31,15 @@ async def get_room(self, name: str, ydoc: Doc | None = None) -> YRoom: return room async def start_room(self, room: YRoom) -> None: - """Start a room, if not already started. - - Arguments: - room: The room to start. - """ - if self._task_group is None: - raise RuntimeError( - "The WebsocketServer is not running: use `async with websocket_server:` " - "or `await websocket_server.start()`" - ) - if not room.started.is_set(): await self._task_group.start(room.start) def get_room_name(self, room: YRoom) -> str: - """Get the name of a room. - - Arguments: - room: The room to get the name from. - - Returns: - The room name. - """ return list(self.rooms.keys())[list(self.rooms.values()).index(room)] def rename_room( self, to_name: str, *, from_name: str | None = None, from_room: YRoom | None = None ) -> None: - """Rename a room. - - Arguments: - to_name: The new name of the room. - from_name: The previous name of the room (if `from_room` is not passed). - from_room: The room to be renamed (if `from_name` is not passed). - """ if from_name is not None and from_room is not None: raise RuntimeError("Cannot pass from_name and from_room") if from_name is None: @@ -117,12 +48,6 @@ def rename_room( self.rooms[to_name] = self.rooms.pop(from_name) def delete_room(self, *, name: str | None = None, room: YRoom | None = None) -> None: - """Delete a room. - - Arguments: - name: The name of the room to delete (if `room` is not passed). - room: The room to delete ( if `name` is not passed). - """ if name is not None and room is not None: raise RuntimeError("Cannot pass name and room") if name is None: @@ -131,20 +56,15 @@ def delete_room(self, *, name: str | None = None, room: YRoom | None = None) -> room = self.rooms.pop(name) room.stop() - async def serve(self, websocket: Websocket) -> None: - """Serve a client through a WebSocket. - - Arguments: - websocket: The WebSocket through which to serve the client. - """ - if self._task_group is None: - raise RuntimeError( - "The WebsocketServer is not running: use `async with websocket_server:` " - "or `await websocket_server.start()`" - ) - + async def serve(self, websocket: Websocket, stop_event: Event | None = None) -> None: async with create_task_group() as tg: tg.start_soon(self._serve, websocket, tg) + if stop_event is not None: + tg.start_soon(self._watch_stop, tg, stop_event) + + async def _watch_stop(self, tg: TaskGroup, stop_event: Event): + await stop_event.wait() + tg.cancel_scope.cancel() async def _serve(self, websocket: Websocket, tg: TaskGroup): room = await self.get_room(websocket.path) @@ -155,51 +75,12 @@ async def _serve(self, websocket: Websocket, tg: TaskGroup): self.delete_room(room=room) tg.cancel_scope.cancel() - async def __aenter__(self) -> WebsocketServer: - if self._task_group is not None: - raise RuntimeError("WebsocketServer already running") - - async with AsyncExitStack() as exit_stack: - tg = create_task_group() - self._task_group = await exit_stack.enter_async_context(tg) - self._exit_stack = exit_stack.pop_all() - self.started.set() - - return self - - async def __aexit__(self, exc_type, exc_value, exc_tb): - if self._task_group is None: - raise RuntimeError("WebsocketServer not running") - - self._task_group.cancel_scope.cancel() - self._task_group = None - return await self._exit_stack.__aexit__(exc_type, exc_value, exc_tb) - async def start(self, *, task_status: TaskStatus[None] = TASK_STATUS_IGNORED): - """Start the WebSocket server. - - Arguments: - task_status: The status to set when the task has started. - """ - if self._starting: - return - else: - self._starting = True - - if self._task_group is not None: - raise RuntimeError("WebsocketServer already running") - # create the task group and wait forever - async with create_task_group() as self._task_group: - self._task_group.start_soon(Event().wait) - self.started.set() - self._starting = False + async with create_task_group() as tg: + self._task_group = tg + tg.start_soon(Event().wait) task_status.started() - def stop(self) -> None: - """Stop the WebSocket server.""" - if self._task_group is None: - raise RuntimeError("WebsocketServer not running") - + async def stop(self) -> None: self._task_group.cancel_scope.cancel() - self._task_group = None diff --git a/plugins/yjs/fps_yjs/ywebsocket/yroom.py b/plugins/yjs/fps_yjs/ywebsocket/yroom.py index 15fd41de..6a2bb01b 100644 --- a/plugins/yjs/fps_yjs/ywebsocket/yroom.py +++ b/plugins/yjs/fps_yjs/ywebsocket/yroom.py @@ -191,11 +191,6 @@ def stop(self): self._task_group = None async def serve(self, websocket: Websocket): - """Serve a client. - - Arguments: - websocket: The WebSocket through which to serve the client. - """ async with create_task_group() as tg: self.clients.append(websocket) await sync(self.ydoc, websocket, self.log) diff --git a/plugins/yjs/fps_yjs/ywebsocket/ystore.py b/plugins/yjs/fps_yjs/ywebsocket/ystore.py index 127a542e..1615b91b 100644 --- a/plugins/yjs/fps_yjs/ywebsocket/ystore.py +++ b/plugins/yjs/fps_yjs/ywebsocket/ystore.py @@ -10,11 +10,11 @@ from pathlib import Path from typing import AsyncIterator, Awaitable, Callable, cast -import aiosqlite import anyio from anyio import TASK_STATUS_IGNORED, Event, Lock, create_task_group from anyio.abc import TaskGroup, TaskStatus from pycrdt import Doc +from sqlite_anyio import connect from .yutils import Decoder, get_new_path, write_var_uint @@ -33,16 +33,13 @@ class BaseYStore(ABC): @abstractmethod def __init__( self, path: str, metadata_callback: Callable[[], Awaitable[bytes] | bytes] | None = None - ): - ... + ): ... @abstractmethod - async def write(self, data: bytes) -> None: - ... + async def write(self, data: bytes) -> None: ... @abstractmethod - async def read(self) -> AsyncIterator[tuple[bytes, bytes]]: - ... + async def read(self) -> AsyncIterator[tuple[bytes, bytes]]: ... @property def started(self) -> Event: @@ -58,16 +55,12 @@ async def __aenter__(self) -> BaseYStore: tg = create_task_group() self._task_group = await exit_stack.enter_async_context(tg) self._exit_stack = exit_stack.pop_all() - tg.start_soon(self.start) + await tg.start(self.start) return self async def __aexit__(self, exc_type, exc_value, exc_tb): - if self._task_group is None: - raise RuntimeError("YStore not running") - - self._task_group.cancel_scope.cancel() - self._task_group = None + await self.stop() return await self._exit_stack.__aexit__(exc_type, exc_value, exc_tb) async def start(self, *, task_status: TaskStatus[None] = TASK_STATUS_IGNORED): @@ -78,8 +71,8 @@ async def start(self, *, task_status: TaskStatus[None] = TASK_STATUS_IGNORED): """ if self._starting: return - else: - self._starting = True + + self._starting = True if self._task_group is not None: raise RuntimeError("YStore already running") @@ -88,7 +81,7 @@ async def start(self, *, task_status: TaskStatus[None] = TASK_STATUS_IGNORED): self._starting = False task_status.started() - def stop(self) -> None: + async def stop(self) -> None: """Stop the store.""" if self._task_group is None: raise RuntimeError("YStore not running") @@ -327,19 +320,14 @@ async def start(self, *, task_status: TaskStatus[None] = TASK_STATUS_IGNORED): Arguments: task_status: The status to set when the task has started. """ - if self._starting: - return - else: - self._starting = True + self._db = await connect(self.db_path) + await self._init_db() + await super().start(task_status=task_status) - if self._task_group is not None: - raise RuntimeError("YStore already running") - - async with create_task_group() as self._task_group: - self._task_group.start_soon(self._init_db) - self.started.set() - self._starting = False - task_status.started() + async def stop(self) -> None: + """Stop the store.""" + await self._db.close() + await super().stop() async def _init_db(self): create_db = False @@ -348,36 +336,36 @@ async def _init_db(self): create_db = True else: async with self.lock: - async with aiosqlite.connect(self.db_path) as db: - cursor = await db.execute( - "SELECT count(name) FROM sqlite_master " - "WHERE type='table' and name='yupdates'" - ) - table_exists = (await cursor.fetchone())[0] - if table_exists: - cursor = await db.execute("pragma user_version") - version = (await cursor.fetchone())[0] - if version != self.version: - move_db = True - create_db = True - else: + cursor = await self._db.cursor() + await cursor.execute( + "SELECT count(name) FROM sqlite_master " + "WHERE type='table' and name='yupdates'" + ) + table_exists = (await cursor.fetchone())[0] + if table_exists: + await cursor.execute("pragma user_version") + version = (await cursor.fetchone())[0] + if version != self.version: + move_db = True create_db = True + else: + create_db = True if move_db: new_path = await get_new_path(self.db_path) self.log.warning(f"YStore version mismatch, moving {self.db_path} to {new_path}") await anyio.Path(self.db_path).rename(new_path) if create_db: async with self.lock: - async with aiosqlite.connect(self.db_path) as db: - await db.execute( - "CREATE TABLE yupdates " - "(path TEXT NOT NULL, yupdate BLOB, metadata BLOB, timestamp REAL NOT NULL)" - ) - await db.execute( - "CREATE INDEX idx_yupdates_path_timestamp ON yupdates (path, timestamp)" - ) - await db.execute(f"PRAGMA user_version = {self.version}") - await db.commit() + cursor = await self._db.cursor() + await cursor.execute( + "CREATE TABLE yupdates " + "(path TEXT NOT NULL, yupdate BLOB, metadata BLOB, timestamp REAL NOT NULL)" + ) + await cursor.execute( + "CREATE INDEX idx_yupdates_path_timestamp ON yupdates (path, timestamp)" + ) + await cursor.execute(f"PRAGMA user_version = {self.version}") + await self._db.commit() self.db_initialized.set() async def read(self) -> AsyncIterator[tuple[bytes, bytes, float]]: # type: ignore @@ -389,17 +377,17 @@ async def read(self) -> AsyncIterator[tuple[bytes, bytes, float]]: # type: igno await self.db_initialized.wait() try: async with self.lock: - async with aiosqlite.connect(self.db_path) as db: - async with db.execute( - "SELECT yupdate, metadata, timestamp FROM yupdates WHERE path = ?", - (self.path,), - ) as cursor: - found = False - async for update, metadata, timestamp in cursor: - found = True - yield update, metadata, timestamp - if not found: - raise YDocNotFound + cursor = await self._db.cursor() + await cursor.execute( + "SELECT yupdate, metadata, timestamp FROM yupdates WHERE path = ?", + (self.path,), + ) + found = False + for update, metadata, timestamp in await cursor.fetchall(): + found = True + yield update, metadata, timestamp + if not found: + raise YDocNotFound except Exception: raise YDocNotFound @@ -411,37 +399,35 @@ async def write(self, data: bytes) -> None: """ await self.db_initialized.wait() async with self.lock: - async with aiosqlite.connect(self.db_path) as db: - # first, determine time elapsed since last update - cursor = await db.execute( - "SELECT timestamp FROM yupdates WHERE path = ? ORDER BY timestamp DESC LIMIT 1", - (self.path,), - ) - row = await cursor.fetchone() - diff = (time.time() - row[0]) if row else 0 - - if self.document_ttl is not None and diff > self.document_ttl: - # squash updates - ydoc = Doc() - async with db.execute( - "SELECT yupdate FROM yupdates WHERE path = ?", (self.path,) - ) as cursor: - async for update, in cursor: - ydoc.apply_update(update) - # delete history - await db.execute("DELETE FROM yupdates WHERE path = ?", (self.path,)) - # insert squashed updates - squashed_update = ydoc.get_update() - metadata = await self.get_metadata() - await db.execute( - "INSERT INTO yupdates VALUES (?, ?, ?, ?)", - (self.path, squashed_update, metadata, time.time()), - ) - - # finally, write this update to the DB + # first, determine time elapsed since last update + cursor = await self._db.cursor() + await cursor.execute( + "SELECT timestamp FROM yupdates WHERE path = ? ORDER BY timestamp DESC LIMIT 1", + (self.path,), + ) + row = await cursor.fetchone() + diff = (time.time() - row[0]) if row else 0 + + if self.document_ttl is not None and diff > self.document_ttl: + # squash updates + ydoc = Doc() + await cursor.execute("SELECT yupdate FROM yupdates WHERE path = ?", (self.path,)) + for (update,) in await cursor.fetchall(): + ydoc.apply_update(update) + # delete history + await cursor.execute("DELETE FROM yupdates WHERE path = ?", (self.path,)) + # insert squashed updates + squashed_update = ydoc.get_update() metadata = await self.get_metadata() - await db.execute( + await cursor.execute( "INSERT INTO yupdates VALUES (?, ?, ?, ?)", - (self.path, data, metadata, time.time()), + (self.path, squashed_update, metadata, time.time()), ) - await db.commit() + + # finally, write this update to the DB + metadata = await self.get_metadata() + await cursor.execute( + "INSERT INTO yupdates VALUES (?, ?, ?, ?)", + (self.path, data, metadata, time.time()), + ) + await self._db.commit() diff --git a/plugins/yjs/fps_yjs/ywebsocket/yutils.py b/plugins/yjs/fps_yjs/ywebsocket/yutils.py index fe731116..5ccec736 100644 --- a/plugins/yjs/fps_yjs/ywebsocket/yutils.py +++ b/plugins/yjs/fps_yjs/ywebsocket/yutils.py @@ -4,6 +4,7 @@ from pathlib import Path import anyio +from anyio import BrokenResourceError from anyio.streams.memory import MemoryObjectSendStream from pycrdt import Doc, TransactionEvent @@ -99,7 +100,10 @@ def read_var_string(self): def put_updates(update_send_stream: MemoryObjectSendStream, event: TransactionEvent) -> None: update = event.update # type: ignore - update_send_stream.send_nowait(update) + try: + update_send_stream.send_nowait(update) + except BrokenResourceError: + pass async def process_sync_message(message: bytes, ydoc: Doc, websocket, log) -> None: diff --git a/plugins/yjs/fps_yjs/ywidgets/widgets.py b/plugins/yjs/fps_yjs/ywidgets/widgets.py index 52eeae03..206e9560 100644 --- a/plugins/yjs/fps_yjs/ywidgets/widgets.py +++ b/plugins/yjs/fps_yjs/ywidgets/widgets.py @@ -11,6 +11,7 @@ process_sync_message, sync, ) + ypywidgets_installed = True except ImportError: ypywidgets_installed = False @@ -24,11 +25,10 @@ Widgets: Any if ypywidgets_installed: + class Widgets: # type: ignore def __init__(self): - self.ydocs = { - ep.name: ep.load() for ep in entry_points(group="ypywidgets") - } + self.ydocs = {ep.name: ep.load() for ep in entry_points(group="ypywidgets")} self.widgets = {} def comm_open(self, msg, comm) -> None: diff --git a/plugins/yjs/pyproject.toml b/plugins/yjs/pyproject.toml index 94ac7f81..b45cc29d 100644 --- a/plugins/yjs/pyproject.toml +++ b/plugins/yjs/pyproject.toml @@ -11,6 +11,7 @@ dependencies = [ "importlib_metadata >=3.6; python_version<'3.10'", "pycrdt >=0.8.16,<0.9.0", "jupyverse-api >=0.1.2,<1", + "sqlite-anyio >=0.2.0,<0.3.0", ] dynamic = [ "version",] [[project.authors]] diff --git a/pyproject.toml b/pyproject.toml index 8486f9f2..e7e702a4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -64,6 +64,12 @@ docs = [ "mkdocs", "mkdocs-material" ] # pre-install commands here and post-install commands in the matrix can be moved # to the dependencies section pre-install-commands = [ + "pip install git+https://github.com/asphalt-framework/asphalt.git@5.0", + "pip install git+https://github.com/asphalt-framework/asphalt-web.git@asphalt5", + "pip install asgiref", + "pip install fastapi", + "pip install hypercorn", + "pip install -e ./jupyverse_api", "pip install -e ./plugins/contents", "pip install -e ./plugins/frontend", @@ -87,6 +93,7 @@ matrix.frontend.scripts = [ { key = "typecheck1", value = "typecheck0 ./plugins/jupyterlab", if = ["jupyterlab"] }, { key = "typecheck1", value = "typecheck0 ./plugins/notebook", if = ["notebook"] }, ] + matrix.auth.post-install-commands = [ { value = "pip install -e ./plugins/noauth", if = ["noauth"] }, { value = "pip install -e ./plugins/auth -e ./plugins/login", if = ["auth"] }, @@ -181,15 +188,3 @@ python_packages = [ [tool.hatch.version] path = "jupyverse/__init__.py" - -[tool.pytest.ini_options] -asyncio_mode = "strict" - -[tool.pixi.project] -name = "" -channels = ["conda-forge"] -platforms = ["linux-64"] - -[tool.pixi.dependencies] -pip = ">=24.0,<25" -python = "<3.12" diff --git a/tests/conftest.py b/tests/conftest.py index 748a3e70..983395db 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,4 +1,5 @@ import os +import signal import subprocess import time from pathlib import Path @@ -7,6 +8,12 @@ import requests +@pytest.fixture +def anyio_backend(): + # at least, SQLAlchemy doesn't support anything else than asyncio + return "asyncio" + + @pytest.fixture() def cwd(): return Path(__file__).parents[1] @@ -38,5 +45,5 @@ def start_jupyverse(auth_mode, clear_users, cwd, unused_tcp_port): else: break yield url - p.kill() + os.kill(p.pid, signal.SIGINT) p.wait() diff --git a/tests/data/notebook1.ipynb b/tests/data/notebook1.ipynb index e1c94429..6ea750cb 100644 --- a/tests/data/notebook1.ipynb +++ b/tests/data/notebook1.ipynb @@ -1,53 +1,56 @@ { "cells": [ { - "execution_count": null, + "execution_count": 1, "outputs": [], "id": "a7243792-6f06-4462-a6b5-7e9ec604348e", "source": "from ypywidgets_textual.switch import Switch", - "cell_type": "code", + "execution_state": "idle", "metadata": { "trusted": false - } + }, + "cell_type": "code" }, { + "execution_count": 2, + "cell_type": "code", + "outputs": [], + "execution_state": "busy", "id": "a7243792-6f06-4462-a6b5-7e9ec604348f", - "source": "switch = Switch()\nswitch", - "execution_count": null, "metadata": { "trusted": false }, - "outputs": [], - "cell_type": "code" + "source": "switch = Switch()\nswitch" }, { + "execution_state": "idle", "outputs": [], "id": "a7243792-6f06-4462-a6b5-7e9ec604349f", "source": "switch.toggle()", "cell_type": "code", + "execution_count": 3, "metadata": { "trusted": false - }, - "execution_count": null + } } ], "metadata": { "kernelspec": { - "language": "python", + "display_name": "Python 3 (ipykernel)", "name": "python3", - "display_name": "Python 3 (ipykernel)" + "language": "python" }, "language_info": { "version": "3.7.12", + "pygments_lexer": "ipython3", + "name": "python", + "nbconvert_exporter": "python", + "mimetype": "text/x-python", + "file_extension": ".py", "codemirror_mode": { "version": 3, "name": "ipython" - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "pygments_lexer": "ipython3", - "nbconvert_exporter": "python" + } } }, "nbformat": 4, diff --git a/tests/test_app.py b/tests/test_app.py index dfd97365..ff62fc0a 100644 --- a/tests/test_app.py +++ b/tests/test_app.py @@ -1,5 +1,5 @@ import pytest -from asphalt.core import Context +from asphalt.core import Context, get_resource from fastapi import APIRouter from httpx import AsyncClient from utils import configure @@ -9,7 +9,7 @@ from jupyverse_api.main import JupyverseComponent -@pytest.mark.asyncio +@pytest.mark.anyio @pytest.mark.parametrize( "mount_path", ( @@ -20,13 +20,13 @@ async def test_mount_path(mount_path, unused_tcp_port): components = configure({"app": {"type": "app"}}, {"app": {"mount_path": mount_path}}) - async with Context() as ctx, AsyncClient() as http: + async with Context(), AsyncClient() as http: await JupyverseComponent( components=components, port=unused_tcp_port, - ).start(ctx) + ).start() - app = await ctx.request_resource(App) + app = await get_resource(App, wait=True) router = APIRouter() @router.get("/") diff --git a/tests/test_auth.py b/tests/test_auth.py index e8a3b5ed..9e35fa25 100644 --- a/tests/test_auth.py +++ b/tests/test_auth.py @@ -1,5 +1,5 @@ import pytest -from asphalt.core import Context +from asphalt.core import Context, get_resource from httpx import AsyncClient from httpx_ws import WebSocketUpgradeError, aconnect_ws from utils import authenticate_client, configure @@ -19,13 +19,13 @@ } -@pytest.mark.asyncio +@pytest.mark.anyio async def test_kernel_channels_unauthenticated(unused_tcp_port): - async with Context() as ctx: + async with Context(): await JupyverseComponent( components=COMPONENTS, port=unused_tcp_port, - ).start(ctx) + ).start() with pytest.raises(WebSocketUpgradeError): async with aconnect_ws( @@ -34,13 +34,13 @@ async def test_kernel_channels_unauthenticated(unused_tcp_port): pass -@pytest.mark.asyncio +@pytest.mark.anyio async def test_kernel_channels_authenticated(unused_tcp_port): - async with Context() as ctx, AsyncClient() as http: + async with Context(), AsyncClient() as http: await JupyverseComponent( components=COMPONENTS, port=unused_tcp_port, - ).start(ctx) + ).start() await authenticate_client(http, unused_tcp_port) async with aconnect_ws( @@ -50,15 +50,15 @@ async def test_kernel_channels_authenticated(unused_tcp_port): pass -@pytest.mark.asyncio +@pytest.mark.anyio @pytest.mark.parametrize("auth_mode", ("noauth", "token", "user")) async def test_root_auth(auth_mode, unused_tcp_port): components = configure(COMPONENTS, {"auth": {"mode": auth_mode}}) - async with Context() as ctx, AsyncClient() as http: + async with Context(), AsyncClient() as http: await JupyverseComponent( components=components, port=unused_tcp_port, - ).start(ctx) + ).start() response = await http.get(f"http://127.0.0.1:{unused_tcp_port}/") if auth_mode == "noauth": @@ -70,31 +70,31 @@ async def test_root_auth(auth_mode, unused_tcp_port): assert response.headers["content-type"] == "application/json" -@pytest.mark.asyncio +@pytest.mark.anyio @pytest.mark.parametrize("auth_mode", ("noauth",)) async def test_no_auth(auth_mode, unused_tcp_port): components = configure(COMPONENTS, {"auth": {"mode": auth_mode}}) - async with Context() as ctx, AsyncClient() as http: + async with Context(), AsyncClient() as http: await JupyverseComponent( components=components, port=unused_tcp_port, - ).start(ctx) + ).start() response = await http.get(f"http://127.0.0.1:{unused_tcp_port}/lab") assert response.status_code == 200 -@pytest.mark.asyncio +@pytest.mark.anyio @pytest.mark.parametrize("auth_mode", ("token",)) async def test_token_auth(auth_mode, unused_tcp_port): components = configure(COMPONENTS, {"auth": {"mode": auth_mode}}) - async with Context() as ctx, AsyncClient() as http: + async with Context(), AsyncClient() as http: await JupyverseComponent( components=components, port=unused_tcp_port, - ).start(ctx) + ).start() - auth_config = await ctx.request_resource(AuthConfig) + auth_config = await get_resource(AuthConfig, wait=True) # no token provided, should not work response = await http.get(f"http://127.0.0.1:{unused_tcp_port}/") @@ -104,7 +104,7 @@ async def test_token_auth(auth_mode, unused_tcp_port): assert response.status_code == 302 -@pytest.mark.asyncio +@pytest.mark.anyio @pytest.mark.parametrize("auth_mode", ("user",)) @pytest.mark.parametrize( "permissions", @@ -115,11 +115,11 @@ async def test_token_auth(auth_mode, unused_tcp_port): ) async def test_permissions(auth_mode, permissions, unused_tcp_port): components = configure(COMPONENTS, {"auth": {"mode": auth_mode}}) - async with Context() as ctx, AsyncClient() as http: + async with Context(), AsyncClient() as http: await JupyverseComponent( components=components, port=unused_tcp_port, - ).start(ctx) + ).start() await authenticate_client(http, unused_tcp_port, permissions=permissions) response = await http.get(f"http://127.0.0.1:{unused_tcp_port}/auth/user/me") diff --git a/tests/test_contents.py b/tests/test_contents.py index b44a4aac..1262bd6a 100644 --- a/tests/test_contents.py +++ b/tests/test_contents.py @@ -16,7 +16,7 @@ } -@pytest.mark.asyncio +@pytest.mark.anyio @pytest.mark.parametrize("auth_mode", ("noauth",)) async def test_tree(auth_mode, tmp_path, unused_tcp_port): prev_dir = os.getcwd() @@ -65,11 +65,11 @@ async def test_tree(auth_mode, tmp_path, unused_tcp_port): ) components = configure(COMPONENTS, {"auth": {"mode": auth_mode}}) - async with Context() as ctx, AsyncClient() as http: + async with Context(), AsyncClient() as http: await JupyverseComponent( components=components, port=unused_tcp_port, - ).start(ctx) + ).start() response = await http.get( f"http://127.0.0.1:{unused_tcp_port}/api/contents", params={"content": 1} diff --git a/tests/test_execute.py b/tests/test_execute.py index d423f1a1..543156d3 100644 --- a/tests/test_execute.py +++ b/tests/test_execute.py @@ -1,9 +1,9 @@ -import asyncio import os from functools import partial from pathlib import Path import pytest +from anyio import create_memory_object_stream, create_task_group, sleep from asphalt.core import Context from fps_yjs.ydocs import ydocs from fps_yjs.ywebsocket import WebsocketProvider @@ -55,7 +55,7 @@ async def recv(self) -> bytes: return bytes(b) -@pytest.mark.asyncio +@pytest.mark.anyio @pytest.mark.parametrize("auth_mode", ("noauth",)) async def test_execute(auth_mode, unused_tcp_port): url = f"http://127.0.0.1:{unused_tcp_port}" @@ -63,11 +63,11 @@ async def test_execute(auth_mode, unused_tcp_port): "auth": {"mode": auth_mode}, "kernels": {"require_yjs": True}, }) - async with Context() as ctx, AsyncClient() as http: + async with Context(), AsyncClient() as http: await JupyverseComponent( components=components, port=unused_tcp_port, - ).start(ctx) + ).start() ws_url = url.replace("http", "ws", 1) name = "notebook1.ipynb" @@ -90,23 +90,24 @@ async def test_execute(auth_mode, unused_tcp_port): json={ "format": "json", "type": "notebook", - } + }, + timeout=20, ) file_id = response.json()["fileId"] document_id = f"json:notebook:{file_id}" ynb = ydocs["notebook"]() - def callback(aevent, events, event): + def callback(event_stream_send, events, event): events.append(event) - aevent.set() - aevent = asyncio.Event() + event_stream_send.send_nowait(None) + event_stream_send, event_stream_recv = create_memory_object_stream[None](1) events = [] - ynb.ydoc.observe_subdocs(partial(callback, aevent, events)) + ynb.ydoc.observe_subdocs(partial(callback, event_stream_send, events)) async with aconnect_ws( f"{ws_url}/api/collaboration/room/{document_id}" ) as websocket, WebsocketProvider(ynb.ydoc, Websocket(websocket, document_id)): # connect to the shared notebook document # wait for file to be loaded and Y model to be created in server and client - await asyncio.sleep(0.5) + await sleep(0.5) # execute notebook for cell_idx in range(2): response = await http.post( @@ -117,23 +118,22 @@ def callback(aevent, events, event): } ) while True: - await aevent.wait() - aevent.clear() + await event_stream_recv.receive() guid = None for event in events: if event.added: guid = event.added[0] if guid is not None: break - task = asyncio.create_task(connect_ywidget(ws_url, guid)) - response = await http.post( - f"{url}/api/kernels/{kernel_id}/execute", - json={ - "document_id": document_id, - "cell_id": ynb.ycells[2]["id"], - } - ) - await task + async with create_task_group() as tg: + tg.start_soon(connect_ywidget, ws_url, guid) + response = await http.post( + f"{url}/api/kernels/{kernel_id}/execute", + json={ + "document_id": document_id, + "cell_id": ynb.ycells[2]["id"], + } + ) async def connect_ywidget(ws_url, guid): @@ -141,10 +141,8 @@ async def connect_ywidget(ws_url, guid): async with aconnect_ws( f"{ws_url}/api/collaboration/room/ywidget:{guid}" ) as websocket, WebsocketProvider(ywidget_doc, Websocket(websocket, guid)): - await asyncio.sleep(0.5) - attrs = Map() - model_name = Text() - ywidget_doc["_attrs"] = attrs - ywidget_doc["_model_name"] = model_name + await sleep(0.5) + ywidget_doc["_attrs"] = attrs = Map() + ywidget_doc["_model_name"] = model_name = Text() assert str(model_name) == "Switch" assert str(attrs) == '{"value":true}' diff --git a/tests/test_kernels.py b/tests/test_kernels.py index dba726b9..77f4cd8c 100644 --- a/tests/test_kernels.py +++ b/tests/test_kernels.py @@ -4,6 +4,7 @@ from time import sleep import pytest +from anyio import create_task_group from asphalt.core import Context from fps_kernels.kernel_server.server import KernelServer, kernels from httpx import AsyncClient @@ -26,9 +27,9 @@ } -@pytest.mark.asyncio +@pytest.mark.anyio @pytest.mark.parametrize("auth_mode", ("noauth",)) -async def test_kernel_messages(auth_mode, capfd, unused_tcp_port): +async def test_kernel_messages(auth_mode, unused_tcp_port, capfd): kernel_id = "kernel_id_0" kernel_name = "python3" kernelspec_path = ( @@ -36,67 +37,70 @@ async def test_kernel_messages(auth_mode, capfd, unused_tcp_port): ) assert kernelspec_path.exists() kernel_server = KernelServer(kernelspec_path=kernelspec_path, capture_kernel_output=False) - await kernel_server.start() - kernels[kernel_id] = {"server": kernel_server} - msg_id = "0" - msg = { - "channel": "shell", - "parent_header": None, - "content": None, - "metadata": None, - "header": { - "msg_type": "msg_type_0", - "msg_id": msg_id, - }, - } + async with create_task_group() as tg: + await tg.start(kernel_server.start) + kernels[kernel_id] = {"server": kernel_server, "driver": None} + msg_id = "0" + msg = { + "channel": "shell", + "parent_header": None, + "content": None, + "metadata": None, + "header": { + "msg_type": "msg_type_0", + "msg_id": msg_id, + }, + } - components = configure(COMPONENTS, {"auth": {"mode": auth_mode}}) - async with Context() as ctx, AsyncClient(): - await JupyverseComponent( - components=components, - port=unused_tcp_port, - ).start(ctx) + components = configure(COMPONENTS, {"auth": {"mode": auth_mode}}) + async with Context(), AsyncClient(): + await JupyverseComponent( + components=components, + port=unused_tcp_port, + ).start() - # block msg_type_0 - msg["header"]["msg_id"] = str(int(msg["header"]["msg_id"]) + 1) - kernel_server.block_messages("msg_type_0") - async with aconnect_ws( - f"http://127.0.0.1:{unused_tcp_port}/api/kernels/{kernel_id}/channels?session_id=session_id_0", - ) as websocket: - await websocket.send_json(msg) - sleep(0.5) - out, err = capfd.readouterr() - assert not err + # block msg_type_0 + msg["header"]["msg_id"] = str(int(msg["header"]["msg_id"]) + 1) + kernel_server.block_messages("msg_type_0") + async with aconnect_ws( + f"http://127.0.0.1:{unused_tcp_port}/api/kernels/{kernel_id}/channels?session_id=session_id_0", + ) as websocket: + await websocket.send_json(msg) + sleep(0.5) + out, err = capfd.readouterr() + assert "IPKernelApp" not in err - # allow only msg_type_0 - msg["header"]["msg_id"] = str(int(msg["header"]["msg_id"]) + 1) - kernel_server.allow_messages("msg_type_0") - async with aconnect_ws( - f"http://127.0.0.1:{unused_tcp_port}/api/kernels/{kernel_id}/channels?session_id=session_id_0", - ) as websocket: - await websocket.send_json(msg) - sleep(0.5) - out, err = capfd.readouterr() - assert err.count("[IPKernelApp] WARNING | Unknown message type: 'msg_type_0'") == 1 + # allow only msg_type_0 + msg["header"]["msg_id"] = str(int(msg["header"]["msg_id"]) + 1) + kernel_server.allow_messages("msg_type_0") + async with aconnect_ws( + f"http://127.0.0.1:{unused_tcp_port}/api/kernels/{kernel_id}/channels?session_id=session_id_0", + ) as websocket: + await websocket.send_json(msg) + sleep(0.5) + out, err = capfd.readouterr() + assert err.count("[IPKernelApp] WARNING | Unknown message type: 'msg_type_0'") == 1 - # block all messages - msg["header"]["msg_id"] = str(int(msg["header"]["msg_id"]) + 1) - kernel_server.allow_messages([]) - async with aconnect_ws( - f"http://127.0.0.1:{unused_tcp_port}/api/kernels/{kernel_id}/channels?session_id=session_id_0", - ) as websocket: - await websocket.send_json(msg) - sleep(0.5) - out, err = capfd.readouterr() - assert not err + # block all messages + msg["header"]["msg_id"] = str(int(msg["header"]["msg_id"]) + 1) + kernel_server.allow_messages([]) + async with aconnect_ws( + f"http://127.0.0.1:{unused_tcp_port}/api/kernels/{kernel_id}/channels?session_id=session_id_0", + ) as websocket: + await websocket.send_json(msg) + sleep(0.5) + out, err = capfd.readouterr() + assert "IPKernelApp" not in err - # allow all messages - msg["header"]["msg_id"] = str(int(msg["header"]["msg_id"]) + 1) - kernel_server.allow_messages() - async with aconnect_ws( - f"http://127.0.0.1:{unused_tcp_port}/api/kernels/{kernel_id}/channels?session_id=session_id_0", - ) as websocket: - await websocket.send_json(msg) - sleep(0.5) - out, err = capfd.readouterr() - assert err.count("[IPKernelApp] WARNING | Unknown message type: 'msg_type_0'") >= 1 + # allow all messages + msg["header"]["msg_id"] = str(int(msg["header"]["msg_id"]) + 1) + kernel_server.allow_messages() + async with aconnect_ws( + f"http://127.0.0.1:{unused_tcp_port}/api/kernels/{kernel_id}/channels?session_id=session_id_0", + ) as websocket: + await websocket.send_json(msg) + sleep(0.5) + out, err = capfd.readouterr() + assert err.count("[IPKernelApp] WARNING | Unknown message type: 'msg_type_0'") >= 1 + + tg.start_soon(kernel_server.stop) diff --git a/tests/test_server.py b/tests/test_server.py index 0ba0cc03..e03f763f 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -1,10 +1,10 @@ -import asyncio import json from functools import partial from pathlib import Path import pytest import requests +from anyio import create_memory_object_stream, create_task_group, sleep from fps_yjs.ydocs import ydocs from fps_yjs.ywebsocket import WebsocketProvider from pycrdt import Array, Doc, Map, Text @@ -47,7 +47,7 @@ def test_settings_persistence_get(start_jupyverse): assert response.status_code == 204 -@pytest.mark.asyncio +@pytest.mark.anyio @pytest.mark.parametrize("auth_mode", ("noauth",)) @pytest.mark.parametrize("clear_users", (False,)) async def test_rest_api(start_jupyverse): @@ -87,7 +87,7 @@ async def test_rest_api(start_jupyverse): ) as websocket, WebsocketProvider(ydoc, websocket): # connect to the shared notebook document # wait for file to be loaded and Y model to be created in server and client - await asyncio.sleep(0.5) + await sleep(0.5) ydoc["cells"] = ycells = Array() # execute notebook for cell_idx in range(3): @@ -101,7 +101,7 @@ async def test_rest_api(start_jupyverse): ), ) # wait for Y model to be updated - await asyncio.sleep(0.5) + await sleep(0.5) # retrieve cells cells = json.loads(str(ycells)) assert cells[0]["outputs"] == [ @@ -125,7 +125,7 @@ async def test_rest_api(start_jupyverse): ] -@pytest.mark.asyncio +@pytest.mark.anyio @pytest.mark.parametrize("auth_mode", ("noauth",)) @pytest.mark.parametrize("clear_users", (False,)) async def test_ywidgets(start_jupyverse): @@ -139,7 +139,6 @@ async def test_ywidgets(start_jupyverse): data=json.dumps( { "kernel": {"name": "python3"}, - #"kernel": {"name": "akernel"}, "name": name, "path": path, "type": "notebook", @@ -161,18 +160,18 @@ async def test_ywidgets(start_jupyverse): file_id = response.json()["fileId"] document_id = f"json:notebook:{file_id}" ynb = ydocs["notebook"]() - def callback(aevent, events, event): + def callback(event_stream_send, events, event): events.append(event) - aevent.set() - aevent = asyncio.Event() + event_stream_send.send_nowait(None) + event_stream_send, event_stream_recv = create_memory_object_stream[None](1) events = [] - ynb.ydoc.observe_subdocs(partial(callback, aevent, events)) + ynb.ydoc.observe_subdocs(partial(callback, event_stream_send, events)) async with connect( f"{ws_url}/api/collaboration/room/{document_id}" ) as websocket, WebsocketProvider(ynb.ydoc, websocket): # connect to the shared notebook document # wait for file to be loaded and Y model to be created in server and client - await asyncio.sleep(0.5) + await sleep(0.5) # execute notebook for cell_idx in range(2): response = requests.post( @@ -185,25 +184,24 @@ def callback(aevent, events, event): ), ) while True: - await aevent.wait() - aevent.clear() + await event_stream_recv.receive() guid = None for event in events: if event.added: guid = event.added[0] if guid is not None: break - task = asyncio.create_task(connect_ywidget(ws_url, guid)) - response = requests.post( - f"{url}/api/kernels/{kernel_id}/execute", - data=json.dumps( - { - "document_id": document_id, - "cell_id": ynb.ycells[2]["id"], - } - ), - ) - await task + async with create_task_group() as tg: + tg.start_soon(connect_ywidget, ws_url, guid) + response = requests.post( + f"{url}/api/kernels/{kernel_id}/execute", + data=json.dumps( + { + "document_id": document_id, + "cell_id": ynb.ycells[2]["id"], + } + ), + ) async def connect_ywidget(ws_url, guid): @@ -211,7 +209,7 @@ async def connect_ywidget(ws_url, guid): async with connect( f"{ws_url}/api/collaboration/room/ywidget:{guid}" ) as websocket, WebsocketProvider(ywidget_doc, websocket): - await asyncio.sleep(0.5) + await sleep(0.5) attrs = Map() model_name = Text() ywidget_doc["_attrs"] = attrs diff --git a/tests/test_settings.py b/tests/test_settings.py index 03cb6a60..1ee953f2 100644 --- a/tests/test_settings.py +++ b/tests/test_settings.py @@ -21,15 +21,15 @@ } -@pytest.mark.asyncio +@pytest.mark.anyio @pytest.mark.parametrize("auth_mode", ("noauth",)) async def test_settings(auth_mode, unused_tcp_port): components = configure(COMPONENTS, {"auth": {"mode": auth_mode}}) - async with Context() as ctx, AsyncClient() as http: + async with Context(), AsyncClient() as http: await JupyverseComponent( components=components, port=unused_tcp_port, - ).start(ctx) + ).start() # get previous theme response = await http.get( @@ -40,7 +40,7 @@ async def test_settings(auth_mode, unused_tcp_port): # put new theme response = await http.put( f"http://127.0.0.1:{unused_tcp_port}/lab/api/settings/@jupyterlab/apputils-extension:themes", - data=json.dumps(test_theme), + content=json.dumps(test_theme), ) assert response.status_code == 204 # get new theme @@ -52,6 +52,6 @@ async def test_settings(auth_mode, unused_tcp_port): # put previous theme back response = await http.put( f"http://127.0.0.1:{unused_tcp_port}/lab/api/settings/@jupyterlab/apputils-extension:themes", - data=json.dumps(theme), + content=json.dumps(theme), ) assert response.status_code == 204