Skip to content

Commit

Permalink
-
Browse files Browse the repository at this point in the history
  • Loading branch information
davidbrochart committed Apr 10, 2024
1 parent 8b07d9b commit 8b79890
Show file tree
Hide file tree
Showing 37 changed files with 273 additions and 215 deletions.
3 changes: 0 additions & 3 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -344,6 +344,3 @@ $RECYCLE.BIN/
.jupyter_ystore.db
.jupyter_ystore.db-journal
fps_cli_args.toml

# pixi environments
.pixi
20 changes: 15 additions & 5 deletions jupyverse_api/jupyverse_api/main/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@
import webbrowser
from typing import Any, Callable, Dict, Sequence, Tuple

from anyio import Event
from asgiref.typing import ASGI3Application
from asphalt.core import Component, add_resource, request_resource
from asphalt.core import Component, add_resource, get_resource, start_service_task
from asphalt.web.fastapi import FastAPIComponent
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
Expand All @@ -23,10 +24,10 @@ def __init__(
self.mount_path = mount_path

async def start(self) -> None:
app = await request_resource(FastAPI)
app = await get_resource(FastAPI, wait=True)

_app = App(app, mount_path=self.mount_path)
await add_resource(_app)
add_resource(_app)


class JupyverseComponent(FastAPIComponent):
Expand Down Expand Up @@ -64,19 +65,23 @@ def __init__(
self.port = port
self.open_browser = open_browser
self.query_params = query_params
self.lifespan = Lifespan()

async def start(self) -> None:
query_params = QueryParams(d={})
host = self.host
if not host.startswith("http"):
host = f"http://{host}"
host_url = Host(url=f"{host}:{self.port}/")
await add_resource(query_params)
await add_resource(host_url)
add_resource(query_params)
add_resource(host_url)
add_resource(self.lifespan)

await super().start()

# at this point, the server has started
await start_service_task(self.lifespan.shutdown_request.wait, "Server lifespan notifier", teardown_action=self.lifespan.shutdown_request.set)

if self.open_browser:
qp = query_params.d
if self.query_params:
Expand All @@ -91,3 +96,8 @@ class QueryParams(BaseModel):

class Host(BaseModel):
url: str


class Lifespan:
def __init__(self):
self.shutdown_request = Event()
2 changes: 2 additions & 0 deletions jupyverse_api/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ classifiers = [
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Programming Language :: Python :: Implementation :: CPython",
"Programming Language :: Python :: Implementation :: PyPy",
]
Expand All @@ -28,6 +29,7 @@ dependencies = [
"pydantic >=2,<3",
"fastapi >=0.95.0,<1",
"rich-click >=1.6.1,<2",
"importlib_metadata >=3.6; python_version<'3.10'",
#"asphalt >=4.11.0,<5",
#"asphalt-web[fastapi] >=1.1.0,<2",
]
Expand Down
14 changes: 7 additions & 7 deletions plugins/auth/fps_auth/main.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import logging

from asphalt.core import Component, add_resource, request_resource
from asphalt.core import Component, add_resource, get_resource
from fastapi_users.exceptions import UserAlreadyExists

from jupyverse_api.app import App
Expand All @@ -19,13 +19,13 @@ def __init__(self, **kwargs):
self.auth_config = _AuthConfig(**kwargs)

async def start(self) -> None:
await add_resource(self.auth_config, types=AuthConfig)
add_resource(self.auth_config, types=AuthConfig)

app = await request_resource(App)
frontend_config = await request_resource(FrontendConfig)
app = await get_resource(App, wait=True)
frontend_config = await get_resource(FrontendConfig, wait=True)

auth = auth_factory(app, self.auth_config, frontend_config)
await add_resource(auth, types=Auth)
add_resource(auth, types=Auth)

await auth.db.create_db_and_tables()

Expand Down Expand Up @@ -56,8 +56,8 @@ async def start(self) -> None:
)

if self.auth_config.mode == "token":
query_params = await request_resource(QueryParams)
host = await request_resource(Host)
query_params = await get_resource(QueryParams, wait=True)
host = await get_resource(Host, wait=True)
query_params.d["token"] = self.auth_config.token

logger.info("")
Expand Down
13 changes: 5 additions & 8 deletions plugins/auth_fief/fps_auth_fief/main.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from asphalt.core import Component, Context
from asphalt.core import Component, add_resource, get_resource

from jupyverse_api.app import App
from jupyverse_api.auth import Auth, AuthConfig
Expand All @@ -11,13 +11,10 @@ class AuthFiefComponent(Component):
def __init__(self, **kwargs):
self.auth_fief_config = _AuthFiefConfig(**kwargs)

async def start(
self,
ctx: Context,
) -> None:
await ctx.add_resource(self.auth_fief_config, types=AuthConfig)
async def start(self) -> None:
add_resource(self.auth_fief_config, types=AuthConfig)

app = await ctx.request_resource(App)
app = await get_resource(App, wait=True)

auth_fief = auth_factory(app, self.auth_fief_config)
await ctx.add_resource(auth_fief, types=Auth)
add_resource(auth_fief, types=Auth)
32 changes: 16 additions & 16 deletions plugins/auth_jupyterhub/fps_auth_jupyterhub/main.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,10 @@
from asphalt.core import Component, ContainerComponent, Context
from asphalt.core import (
Component,
ContainerComponent,
add_resource,
get_resource,
start_background_task,
)
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession

from jupyverse_api.app import App
Expand All @@ -11,18 +17,15 @@

class _AuthJupyterHubComponent(Component):
#@context_teardown
async def start(
self,
ctx: Context,
) -> None:
app = await ctx.request_resource(App)
db_session = await ctx.request_resource(AsyncSession)
db_engine = await ctx.request_resource(AsyncEngine)
async def start(self) -> None:
app = await get_resource(App, wait=True)
db_session = await get_resource(AsyncSession, wait=True)
db_engine = await get_resource(AsyncEngine, wait=True)

http_client = httpx.AsyncClient()
auth_jupyterhub = auth_factory(app, db_session)
await ctx.start_background_task(auth_jupyterhub.start, "JupyterHub Auth", auth_jupyterhub.stop)
await ctx.add_resource(auth_jupyterhub, types=Auth)
await start_background_task(auth_jupyterhub.start, "JupyterHub Auth", auth_jupyterhub.stop)
add_resource(auth_jupyterhub, types=Auth)

async with db_engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
Expand All @@ -37,14 +40,11 @@ def __init__(self, **kwargs):
self.auth_jupyterhub_config = AuthJupyterHubConfig(**kwargs)
super().__init__()

async def start(
self,
ctx: Context,
) -> None:
await ctx.add_resource(self.auth_jupyterhub_config, types=AuthConfig)
async def start(self) -> None:
add_resource(self.auth_jupyterhub_config, types=AuthConfig)
self.add_component(
"sqlalchemy",
url=self.auth_jupyterhub_config.db_url,
)
self.add_component("auth_jupyterhub", type=_AuthJupyterHubComponent)
await super().start(ctx)
await super().start()
20 changes: 16 additions & 4 deletions plugins/contents/fps_contents/fileid.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import logging
from contextlib import AsyncExitStack
from typing import Dict, List, Optional, Set
Expand Down Expand Up @@ -41,6 +43,7 @@ def __init__(self, db_path: str = ".fileid.db"):
self.initialized = Event()
self.watchers = {}
self.stop_watching_files = Event()
self.started_watching_files = Event()
self.stopped_watching_files = Event()
self.lock = Lock()
self._task_group = None
Expand All @@ -65,19 +68,27 @@ async def __aexit__(self, exc_type, exc_value, exc_tb):
return await self._exit_stack.__aexit__(exc_type, exc_value, exc_tb)

async def start(self) -> None:
self._db = await connect(self.db_path)
await self.watch_files()
async def _start():
self._db = await connect(self.db_path)
await self.watch_files()

async with create_task_group() as self._task_group:
self._task_group.start_soon(_start)

async def stop(self) -> None:
print("fileid stopping")
self._task_group.cancel_scope.cancel()
await self._db.close()
self.stop_watching_files.set()
await self.stopped_watching_files.wait()
if self.started_watching_files.is_set():
await self.stopped_watching_files.wait()
print("fileid stopped")

async def get_id(self, path: str) -> Optional[str]:
await self.initialized.wait()
async with self.lock:
cursor = await self._db.cursor()
await cur.execute("SELECT id FROM fileids WHERE path = ?", (path,))
await cursor.execute("SELECT id FROM fileids WHERE path = ?", (path,))
for (idx,) in await cursor.fetchall():
return idx
return None
Expand Down Expand Up @@ -127,6 +138,7 @@ async def watch_files(self):
await self._db.commit()
self.initialized.set()

self.started_watching_files.set()
async for changes in awatch(".", stop_event=self.stop_watching_files):
async with self.lock:
deleted_paths = set()
Expand Down
8 changes: 4 additions & 4 deletions plugins/contents/fps_contents/main.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from asphalt.core import Component, add_resource, request_resource
from asphalt.core import Component, add_resource, get_resource

from jupyverse_api.app import App
from jupyverse_api.auth import Auth
Expand All @@ -9,8 +9,8 @@

class ContentsComponent(Component):
async def start(self) -> None:
app = await request_resource(App)
auth = await request_resource(Auth) # type: ignore
app = await get_resource(App, wait=True)
auth = await get_resource(Auth, wait=True)

contents = _Contents(app, auth)
await add_resource(contents, types=Contents)
add_resource(contents, types=Contents)
2 changes: 1 addition & 1 deletion plugins/frontend/fps_frontend/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,4 @@ def __init__(self, **kwargs):
self.frontend_config = FrontendConfig(**kwargs)

async def start(self) -> None:
await add_resource(self.frontend_config, types=FrontendConfig)
add_resource(self.frontend_config, types=FrontendConfig)
14 changes: 7 additions & 7 deletions plugins/jupyterlab/fps_jupyterlab/main.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from asphalt.core import Component, add_resource, request_resource
from asphalt.core import Component, add_resource, get_resource

from jupyverse_api.app import App
from jupyverse_api.auth import Auth
Expand All @@ -14,12 +14,12 @@ def __init__(self, **kwargs):
self.jupyterlab_config = JupyterLabConfig(**kwargs)

async def start(self) -> None:
await add_resource(self.jupyterlab_config, types=JupyterLabConfig)
add_resource(self.jupyterlab_config, types=JupyterLabConfig)

app = await request_resource(App)
auth = await request_resource(Auth) # type: ignore
frontend_config = await request_resource(FrontendConfig)
lab = await request_resource(Lab) # type: ignore
app = await get_resource(App, wait=True)
auth = await get_resource(Auth, wait=True)
frontend_config = await get_resource(FrontendConfig, wait=True)
lab = await get_resource(Lab, wait=True)

jupyterlab = _JupyterLab(app, self.jupyterlab_config, auth, frontend_config, lab)
await add_resource(jupyterlab, types=JupyterLab)
add_resource(jupyterlab, types=JupyterLab)
2 changes: 2 additions & 0 deletions plugins/kernels/fps_kernels/kernel_driver/connect.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,8 @@ async def launch_kernel(
else:
stdout = None
stderr = None
if not kernel_cwd:
kernel_cwd = None
process = await open_process(cmd, stdout=stdout, stderr=stderr, cwd=kernel_cwd)
return process

Expand Down
26 changes: 20 additions & 6 deletions plugins/kernels/fps_kernels/kernel_driver/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,14 @@
import uuid
from typing import Any, Dict, Optional, cast

from anyio import create_memory_object_stream, create_task_group, fail_after
from anyio import (
TASK_STATUS_IGNORED,
Event,
create_memory_object_stream,
create_task_group,
fail_after,
)
from anyio.abc import TaskStatus
from anyio.streams.stapled import StapledObjectStream
from pycrdt import Array, Map

Expand Down Expand Up @@ -45,7 +52,9 @@ def __init__(
self.session_id = uuid.uuid4().hex
self.msg_cnt = 0
self.execute_requests: Dict[str, Dict[str, StapledObjectStream]] = {}
self.comm_messages: StapledObjectStream = StapledObjectStream(create_memory_object_stream[dict](max_buffer_size=1024))
self.comm_messages: StapledObjectStream = StapledObjectStream(*create_memory_object_stream[dict](max_buffer_size=1024))
self.stop_event = Event()
self.stopped_event = Event()

async def restart(self, startup_timeout: float = float("inf")) -> None:
self.task_group.cancel_scope.cancel()
Expand All @@ -61,7 +70,7 @@ async def restart(self, startup_timeout: float = float("inf")) -> None:
async with create_task_group() as self.task_group:
self.listen_channels()

async def start(self, startup_timeout: float = float("inf"), connect: bool = True) -> None:
async def start(self, startup_timeout: float = float("inf"), connect: bool = True, *, task_status: TaskStatus[None] = TASK_STATUS_IGNORED) -> None:
async with create_task_group() as self.task_group:
self.kernel_process = await launch_kernel(
self.kernelspec_path,
Expand All @@ -70,7 +79,10 @@ async def start(self, startup_timeout: float = float("inf"), connect: bool = Tru
self.capture_kernel_output,
)
if connect:
await self.connect(startup_timeout)
await self.connect()
task_status.started()
await self.stop_event.wait()
self.stopped_event.set()

async def connect(self, startup_timeout: float = float("inf")) -> None:
self.connect_channels()
Expand All @@ -93,6 +105,8 @@ async def stop(self) -> None:
await self.kernel_process.wait()
await self.kernel_process.aclose()
os.remove(self.connection_file_path)
self.stop_event.set()
await self.stopped_event.wait()
self.task_group.cancel_scope.cancel()

async def listen_iopub(self):
Expand Down Expand Up @@ -132,8 +146,8 @@ async def execute(
self.msg_cnt += 1
await send_message(msg, self.shell_channel, self.key, change_date_to_str=True)
self.execute_requests[msg_id] = {
"iopub_msg": StapledObjectStream(create_memory_object_stream[dict](max_buffer_size=1024)),
"shell_msg": StapledObjectStream(create_memory_object_stream[dict](max_buffer_size=1024)),
"iopub_msg": StapledObjectStream(*create_memory_object_stream[dict](max_buffer_size=1024)),
"shell_msg": StapledObjectStream(*create_memory_object_stream[dict](max_buffer_size=1024)),
}
if wait_for_executed:
deadline = time.time() + timeout
Expand Down
Loading

0 comments on commit 8b79890

Please sign in to comment.