Skip to content

Commit

Permalink
support round-trip configuration
Browse files Browse the repository at this point in the history
  • Loading branch information
mdekstrand committed Aug 7, 2024
1 parent b0f44d2 commit 8fbe084
Show file tree
Hide file tree
Showing 4 changed files with 156 additions and 11 deletions.
30 changes: 23 additions & 7 deletions lenskit/lenskit/pipeline/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,9 @@
Component,
ConfigurableComponent,
TrainableComponent,
instantiate_component,
)
from .config import PipelineConfig, PipelineInput
from .config import PipelineComponent, PipelineConfig, PipelineInput
from .nodes import ND, ComponentNode, FallbackNode, InputNode, LiteralNode, Node
from .state import PipelineState

Expand Down Expand Up @@ -384,9 +385,9 @@ def clone(self, *, params: bool = False) -> Pipeline:
if isinstance(comp, FunctionType):
comp = comp
elif isinstance(comp, ConfigurableComponent):
comp = comp.__class__.from_config(comp.get_config())
comp = comp.__class__.from_config(comp.get_config()) # type: ignore
else:
comp = comp.__class__()
comp = comp.__class__() # type: ignore
cn = clone.add_component(node.name, comp) # type: ignore
for wn, wt in wiring.items():
clone.connect(cn, **{wn: clone.node(wt)})
Expand All @@ -407,10 +408,16 @@ def get_config(self) -> PipelineConfig:
checkpoints to load such parameters, depending on the design of the
components in the pipeline.
"""
inputs = [
PipelineInput.from_node(node) for node in self.nodes if isinstance(node, InputNode)
]
return PipelineConfig(inputs=inputs)
return PipelineConfig(
inputs=[
PipelineInput.from_node(node) for node in self.nodes if isinstance(node, InputNode)
],
components={
node.name: PipelineComponent.from_node(node)
for node in self.nodes
if isinstance(node, ComponentNode)
},
)

@classmethod
def from_config(cls, config: object) -> Self:
Expand All @@ -422,6 +429,15 @@ def from_config(cls, config: object) -> Self:
types += [parse_type_string(t) for t in inpt.types]
pipe.create_input(inpt.name, *types)

for name, comp in cfg.components.items():
obj = instantiate_component(comp.code, comp.config)
pipe.add_component(name, obj)

# pass 2: wiring
for name, comp in cfg.components.items():
inputs = {n: pipe.node(t) for (n, t) in comp.inputs.items()}
pipe.connect(name, **inputs)

return pipe

def train(self, data: Dataset) -> None:
Expand Down
23 changes: 23 additions & 0 deletions lenskit/lenskit/pipeline/components.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
from __future__ import annotations

import inspect
from importlib import import_module
from types import FunctionType
from typing import Callable, ClassVar, TypeAlias

from typing_extensions import Any, Generic, Protocol, Self, TypeVar, override, runtime_checkable
Expand Down Expand Up @@ -184,3 +186,24 @@ def from_config(cls, cfg: dict[str, Any]) -> Self:
are passed to the constructor as keyword arguments.
"""
return cls(**cfg)


def instantiate_component(
    comp: str | type | FunctionType, config: dict[str, Any] | None
) -> Callable[..., object]:
    """
    Turn a component specification into a usable pipeline component.

    Args:
        comp:
            The component to instantiate.  This may be a function (returned
            unchanged), a class (instantiated), or a string of the form
            ``module:qualified.name`` naming a function or class to import.
            The part after the colon may be a dotted qualified name (as
            produced by ``__qualname__``), so nested classes and functions
            defined in a class namespace are resolved attribute by attribute.
        config:
            The configuration to pass to ``from_config`` when the component
            class is a :class:`ConfigurableComponent`; ``None`` is treated as
            an empty configuration.  Ignored for functions and
            non-configurable classes.

    Returns:
        The function itself, or a new instance of the component class.

    Raises:
        ValueError: if a string specification has no ``:`` separator.
        TypeError: if the resolved object is neither a function nor a class.
    """
    if isinstance(comp, str):
        mname, sep, oname = comp.partition(":")
        if not sep:
            raise ValueError(
                f"invalid component path {comp!r}: expected the form 'module:name'"
            )
        resolved: Any = import_module(mname)
        # __qualname__ may contain dots (e.g. ``Outer.Inner``), so walk the
        # attribute chain instead of doing a single getattr on the module.
        for part in oname.split("."):
            resolved = getattr(resolved, part)
        comp = resolved

    # make the type checker happy
    assert not isinstance(comp, str)

    if isinstance(comp, FunctionType):
        return comp

    if not isinstance(comp, type):
        # issubclass() below would fail with a much less helpful message
        raise TypeError(f"component must be a function or a class, got {comp!r}")

    if issubclass(comp, ConfigurableComponent):
        if config is None:
            config = {}
        return comp.from_config(config)  # type: ignore

    return comp()  # type: ignore
44 changes: 40 additions & 4 deletions lenskit/lenskit/pipeline/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,14 @@
# pyright: strict
from __future__ import annotations

from pydantic import BaseModel
from typing_extensions import Any, Optional, Self
from types import FunctionType

from lenskit.pipeline.types import type_string
from pydantic import BaseModel, Field
from typing_extensions import Any, Optional, Self

from .nodes import InputNode
from .components import ConfigurableComponent
from .nodes import ComponentNode, InputNode
from .types import type_string


class PipelineConfig(BaseModel):
Expand All @@ -21,6 +23,7 @@ class PipelineConfig(BaseModel):
"""

inputs: list[PipelineInput]
components: dict[str, PipelineComponent] = Field(default_factory=dict)


class PipelineInput(BaseModel):
Expand All @@ -37,3 +40,36 @@ def from_node(cls, node: InputNode[Any]) -> Self:
types = None

return cls(name=node.name, types=types)


class PipelineComponent(BaseModel):
    """
    Serializable description of one component node in a pipeline.
    """

    code: str
    """
    Where to find the component's implementation (a class or a function),
    written as a Python qualified path of the form ``module:name``.
    """

    config: dict[str, object] | None = Field(default=None)
    """
    Configuration for the component.  When absent, the component is created
    with its default constructor parameters.
    """

    inputs: dict[str, str] = Field(default_factory=dict)
    """
    Input wirings for the component: each input name maps to the name of the
    node that supplies it.
    """

    @classmethod
    def from_node(cls, node: ComponentNode[Any]) -> Self:
        """
        Build the serializable description of a component node.

        The implementation path is taken from the component itself when it is
        a plain function, and from its class otherwise.  The configuration is
        recorded only for :class:`ConfigurableComponent` instances.
        """
        component = node.component

        # functions are referenced directly; objects are referenced by class
        impl = component if isinstance(component, FunctionType) else component.__class__
        path = f"{impl.__module__}:{impl.__qualname__}"

        settings = None
        if isinstance(component, ConfigurableComponent):
            settings = component.get_config()

        return cls(code=path, config=settings, inputs=node.connections)
70 changes: 70 additions & 0 deletions lenskit/tests/pipeline/test_save_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,18 @@
from typing_extensions import assert_type

from lenskit.pipeline import InputNode, Node, Pipeline
from lenskit.pipeline.components import AutoConfig
from lenskit.pipeline.nodes import ComponentNode


class Prefixer(AutoConfig):
    """
    Test component that prepends a configurable prefix to a message.
    """

    # the text placed in front of every message
    prefix: str

    def __init__(self, prefix: str = "hello"):
        self.prefix = prefix

    def __call__(self, msg: str) -> str:
        """Return ``msg`` with the configured prefix in front of it."""
        prefixed = self.prefix + msg
        return prefixed


def test_serialize_input():
Expand Down Expand Up @@ -45,3 +57,61 @@ def test_round_trip_optional_input():
assert isinstance(i2, InputNode)
assert i2.name == "user"
assert i2.types == {int, str, NoneType}


def msg_ident(msg: str) -> str:
    """Identity component for tests: hand the message back unchanged."""
    result = msg
    return result


def test_config_single_node():
    """A pipeline with one function component serializes that component."""
    pipeline = Pipeline()
    source = pipeline.create_input("msg", str)
    pipeline.add_component("return", msg_ident, msg=source)

    config = pipeline.get_config()
    assert len(config.inputs) == 1
    assert len(config.components) == 1

    # a function component has a code path and wiring but no configuration
    entry = config.components["return"]
    assert entry.code == "lenskit.tests.pipeline.test_save_load:msg_ident"
    assert entry.config is None
    assert entry.inputs == {"msg": "msg"}


def test_round_trip_single_node():
    """A single function component survives a config round trip and still runs."""
    original = Pipeline()
    source = original.create_input("msg", str)
    original.add_component("return", msg_ident, msg=source)

    restored = Pipeline.from_config(original.get_config())
    assert len(restored.nodes) == 2

    # the rebuilt node refers to the very same function, with the same wiring
    node = restored.node("return")
    assert isinstance(node, ComponentNode)
    assert node.component is msg_ident
    assert node.connections == {"msg": "msg"}

    assert restored.run("return", msg="foo") == "foo"


def test_configurable_component():
    """A configurable component is rebuilt from its saved configuration."""
    original = Pipeline()
    source = original.create_input("msg", str)

    prefixer = Prefixer("scroll named ")
    original.add_component("prefix", prefixer, msg=source)

    config = original.get_config()
    assert config.components["prefix"].config == {"prefix": "scroll named "}

    restored = Pipeline.from_config(config)
    assert len(restored.nodes) == 2

    # the component must be a fresh Prefixer, not the original instance
    node = restored.node("prefix")
    assert isinstance(node, ComponentNode)
    assert isinstance(node.component, Prefixer)
    assert node.component is not prefixer
    assert node.connections == {"msg": "msg"}

    assert restored.run("prefix", msg="HACKEM MUCHE") == "scroll named HACKEM MUCHE"

0 comments on commit 8fbe084

Please sign in to comment.