Skip to content

Commit

Permalink
feat(pluginhost): allow cmd plugins to access host FS
Browse files Browse the repository at this point in the history
Also remove the hard dependency on (now-gone) xingque while refactoring
the plugin host API for addition of the feature flag.
  • Loading branch information
xen0n committed Nov 2, 2024
1 parent 008a716 commit da6e33c
Show file tree
Hide file tree
Showing 6 changed files with 91 additions and 39 deletions.
36 changes: 27 additions & 9 deletions ruyi/pluginhost/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,33 +57,42 @@ def __init__(
@abc.abstractmethod
def make_loader(
self,
plugin_root: pathlib.Path,
originating_file: pathlib.Path,
module_cache: MutableMapping[str, ModuleTy],
is_cmd: bool,
) -> "BasePluginLoader[ModuleTy]":
raise NotImplementedError

@abc.abstractmethod
def make_evaluator(self) -> EvalTy:
raise NotImplementedError

def load_plugin(self, plugin_id: str) -> None:
@property
def plugin_root(self) -> pathlib.Path:
return self._plugin_root

def load_plugin(self, plugin_id: str, is_cmd: bool) -> None:
plugin_dir = paths.get_plugin_dir(plugin_id, self._plugin_root)

loader = self.make_loader(
self._plugin_root,
plugin_dir / paths.PLUGIN_ENTRYPOINT_FILENAME,
self._module_cache,
is_cmd,
)
loaded_plugin = loader.load_this_plugin()
self._loaded_plugins[plugin_id] = loaded_plugin

def is_plugin_loaded(self, plugin_id: str) -> bool:
return plugin_id in self._loaded_plugins

def get_from_plugin(self, plugin_id: str, key: str) -> object | None:
def get_from_plugin(
self,
plugin_id: str,
key: str,
is_cmd_plugin: bool = False,
) -> object | None:
if not self.is_plugin_loaded(plugin_id):
self.load_plugin(plugin_id)
self.load_plugin(plugin_id, is_cmd_plugin)

if plugin_id not in self._value_cache:
self._value_cache[plugin_id] = {}
Expand Down Expand Up @@ -111,19 +120,26 @@ class BasePluginLoader(Generic[ModuleTy], metaclass=abc.ABCMeta):

def __init__(
self,
root: pathlib.Path,
phctx: PluginHostContext[ModuleTy, SupportsEvalFunction],
originating_file: pathlib.Path,
module_cache: MutableMapping[str, ModuleTy],
is_cmd: bool,
) -> None:
self.root = root
self._phctx = phctx
self.originating_file = originating_file
self.module_cache = module_cache
self.is_cmd = is_cmd

@property
def root(self) -> pathlib.Path:
return self._phctx.plugin_root

def make_sub_loader(self, originating_file: pathlib.Path) -> Self:
return self.__class__(
self.root,
self._phctx,
originating_file,
self.module_cache,
self.is_cmd,
)

def load_this_plugin(self) -> ModuleTy:
Expand All @@ -142,6 +158,7 @@ def _load(self, path: str, is_root: bool) -> ModuleTy:
self.root,
False,
self.originating_file,
self.is_cmd,
)
resolved_path_str = str(resolved_path)
if resolved_path_str in self.module_cache:
Expand All @@ -151,9 +168,10 @@ def _load(self, path: str, is_root: bool) -> ModuleTy:
plugin_dir = self.root / plugin_id

host_bridge = api.make_ruyi_plugin_api_for_module(
self.root,
self._phctx,
resolved_path,
plugin_dir,
self.is_cmd,
)

mod = self.do_load_module(
Expand Down
53 changes: 35 additions & 18 deletions ruyi/pluginhost/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,15 @@
import subprocess
import time
import tomllib
from typing import Any, Callable, TypeVar, cast
from typing import TYPE_CHECKING, Any, Callable, TypeVar, cast

import xingque

from ruyi import log
from ruyi.cli import user_input
from ruyi.version import RUYI_SEMVER
from .. import log
from ..cli import user_input
from ..version import RUYI_SEMVER
from .paths import resolve_ruyi_load_path

if TYPE_CHECKING:
from . import PluginHostContext, SupportsEvalFunction, SupportsGetOption

T = TypeVar("T")
U = TypeVar("U")
Expand All @@ -20,14 +20,16 @@
class RuyiHostAPI:
def __init__(
self,
plugin_root: pathlib.Path,
phctx: "PluginHostContext[SupportsGetOption, SupportsEvalFunction]",
this_file: pathlib.Path,
this_plugin_dir: pathlib.Path,
allow_host_fs_access: bool,
) -> None:
self._plugin_root = plugin_root
self._phctx = phctx
self._this_file = this_file
self._this_plugin_dir = this_plugin_dir
self._ev = xingque.Evaluator()
self._ev = phctx.make_evaluator()
self._allow_host_fs_access = allow_host_fs_access

self._logger = RuyiPluginLogger()

Expand All @@ -42,9 +44,10 @@ def ruyi_plugin_api_rev(self) -> int:
def load_toml(self, path: str) -> object:
resolved_path = resolve_ruyi_load_path(
path,
self._plugin_root,
self._phctx.plugin_root,
True,
self._this_file,
self._allow_host_fs_access,
)
with open(resolved_path, "rb") as f:
return tomllib.load(f)
Expand Down Expand Up @@ -82,12 +85,10 @@ def sleep(self, seconds: float, /) -> None:
def with_(
self,
cm: AbstractContextManager[T],
fn: xingque.Value | Callable[[T], U],
fn: object | Callable[[T], U],
) -> U:
with cm as obj:
if isinstance(fn, xingque.Value):
return cast(U, self._ev.eval_function(fn, obj))
return fn(obj)
return cast(U, self._ev.eval_function(fn, obj))


class RuyiPluginLogger:
Expand Down Expand Up @@ -141,9 +142,10 @@ def F(


def _ruyi_plugin_rev(
plugin_root: pathlib.Path,
phctx: "PluginHostContext[SupportsGetOption, SupportsEvalFunction]",
this_file: pathlib.Path,
this_plugin_dir: pathlib.Path,
allow_host_fs_access: bool,
rev: object,
) -> RuyiHostAPI:
if not isinstance(rev, int):
Expand All @@ -152,12 +154,27 @@ def _ruyi_plugin_rev(
raise ValueError(
f"Ruyi plugin API revision {rev} is not supported by this Ruyi"
)
return RuyiHostAPI(plugin_root, this_file, this_plugin_dir)
return RuyiHostAPI(
phctx,
this_file,
this_plugin_dir,
allow_host_fs_access,
)


def make_ruyi_plugin_api_for_module(
plugin_root: pathlib.Path,
phctx: "PluginHostContext[SupportsGetOption, SupportsEvalFunction]",
this_file: pathlib.Path,
this_plugin_dir: pathlib.Path,
is_cmd: bool,
) -> Callable[[object], RuyiHostAPI]:
return lambda rev: _ruyi_plugin_rev(plugin_root, this_file, this_plugin_dir, rev)
# Only allow access to host FS when we're being loaded as a command plugin
allow_host_fs_access = is_cmd

return lambda rev: _ruyi_plugin_rev(
phctx,
this_file,
this_plugin_dir,
allow_host_fs_access,
rev,
)
17 changes: 17 additions & 0 deletions ruyi/pluginhost/paths.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ def resolve_ruyi_load_path(
plugin_root: pathlib.Path,
is_for_data: bool,
originating_file: pathlib.Path,
allow_host_fs_access: bool,
) -> pathlib.Path:
parsed = urlparse(path)
if parsed.params or parsed.query or parsed.fragment:
Expand Down Expand Up @@ -80,6 +81,22 @@ def resolve_ruyi_load_path(
plugin_id=parsed.netloc,
)

case "host":
if not allow_host_fs_access:
raise RuntimeError("the host protocol is not allowed in this context")

if not parsed.path:
raise RuntimeError(
"empty path segment is not allowed for host:// load paths"
)

if parsed.netloc:
raise RuntimeError(
"non-empty location is not allowed for host:// load paths"
)

return pathlib.Path(parsed.path)

case _:
raise RuntimeError(
f"unsupported Ruyi Starlark load path scheme {parsed.scheme}"
Expand Down
4 changes: 2 additions & 2 deletions ruyi/pluginhost/unsandboxed.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,11 +72,11 @@ class UnsandboxedPluginHostContext(
):
def make_loader(
self,
plugin_root: pathlib.Path,
originating_file: pathlib.Path,
module_cache: MutableMapping[str, UnsandboxedModuleDict],
is_cmd: bool,
) -> BasePluginLoader[UnsandboxedModuleDict]:
return UnsandboxedRuyiPluginLoader(plugin_root, originating_file, module_cache)
return UnsandboxedRuyiPluginLoader(self, originating_file, module_cache, is_cmd)

def make_evaluator(self) -> UnsandboxedTrivialEvaluator:
return UnsandboxedTrivialEvaluator()
Expand Down
11 changes: 6 additions & 5 deletions ruyi/ruyipkg/repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@

from pygit2 import clone_repository
from pygit2.repository import Repository
import xingque
import yaml

from .. import log
Expand Down Expand Up @@ -507,14 +506,16 @@ def run_plugin_cmd(self, cmd_name: str, args: list[str]) -> int:
plugin_entrypoint = self._plugin_host_ctx.get_from_plugin(
plugin_id,
"plugin_cmd_main_v1",
is_cmd_plugin=True, # allow access to host FS for command plugins
)
if plugin_entrypoint is None:
raise RuntimeError(f"cmd entrypoint not found in plugin '{plugin_id}'")

ev = xingque.Evaluator()
ret = ev.eval_function(plugin_entrypoint, args)
ret = self.eval_plugin_fn(plugin_entrypoint, args)
if not isinstance(ret, int):
raise TypeError(
f"unexpected return type of cmd plugin '{plugin_id}': {type(ret)} is not int"
log.W(
f"unexpected return type of cmd plugin '{plugin_id}': {type(ret)} is not int."
)
log.I("forcing return code to 1; the plugin should be fixed")
ret = 1
return ret
9 changes: 4 additions & 5 deletions tests/pluginhost/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from types import TracebackType

import pytest
import xingque

from ruyi.pluginhost import PluginHostContext

Expand Down Expand Up @@ -31,8 +30,8 @@ def __exit__(
return None

with ruyi_file.plugin_suite("with_") as plugin_root:
phctx = PluginHostContext(plugin_root)
ev = xingque.Evaluator()
phctx = PluginHostContext.new(plugin_root)
ev = phctx.make_evaluator()

fn1 = phctx.get_from_plugin("foo", "fn1")
assert fn1 is not None
Expand All @@ -42,12 +41,12 @@ def __exit__(
assert cm1.exited == 1
assert ret1 == 466

# even when the Starlark side panics, the context manager semantics
# even when the plugin side panics, the context manager semantics
# shall remain enforced
fn2 = phctx.get_from_plugin("foo", "fn2")
assert fn2 is not None
cm2 = MockContextManager()
with pytest.raises(RuntimeError):
with pytest.raises((RuntimeError, AttributeError)):
ev.eval_function(fn2, cm2)
assert cm2.entered == 1
assert cm2.exited == 1
Expand Down

0 comments on commit da6e33c

Please sign in to comment.