Skip to content

Commit

Permalink
implement fallback node serialization
Browse files Browse the repository at this point in the history
  • Loading branch information
mdekstrand committed Aug 8, 2024
1 parent 8c398df commit 6153b6e
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 7 deletions.
29 changes: 25 additions & 4 deletions lenskit/lenskit/pipeline/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,8 +417,10 @@ def get_config(self) -> PipelineConfig:
raise RuntimeError("literal nodes cannot be serialized to config")
case ComponentNode(name):
config.components[name] = PipelineComponent.from_node(node)
case FallbackNode():
raise NotImplementedError()
case FallbackNode(name, alternatives):
config.components[name] = PipelineComponent(
code="@use-first-of", inputs=[n.name for n in alternatives]
)
case _: # pragma: nocover
raise RuntimeError(f"invalid node {node}")

Expand All @@ -434,14 +436,33 @@ def from_config(cls, config: object) -> Self:
types += [parse_type_string(t) for t in inpt.types]
pipe.create_input(inpt.name, *types)

# pass 1: add components
to_wire: list[PipelineComponent] = []
for name, comp in cfg.components.items():
if comp.code.startswith("@"):
# ignore special nodes in first pass
continue

obj = instantiate_component(comp.code, comp.config)
pipe.add_component(name, obj)
to_wire.append(comp)

# pass 2: add meta nodes
for name, comp in cfg.components.items():
if comp.code == "@use-first-of":
if not isinstance(comp.inputs, list):
raise RuntimeError("@use-first-of must have input list, not dict")
pipe.use_first_of(name, *[pipe.node(n) for n in comp.inputs])
elif comp.code.startswith("@"):
raise RuntimeError(f"unsupported meta-component {comp.code}")

# pass 2: wiring
for name, comp in cfg.components.items():
inputs = {n: pipe.node(t) for (n, t) in comp.inputs.items()}
pipe.connect(name, **inputs)
if isinstance(comp.inputs, dict):
inputs = {n: pipe.node(t) for (n, t) in comp.inputs.items()}
pipe.connect(name, **inputs)
elif not comp.code.startswith("@"):
raise RuntimeError(f"component {name} inputs must be dict, not list")

return pipe

Expand Down
8 changes: 5 additions & 3 deletions lenskit/lenskit/pipeline/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
# pyright: strict
from __future__ import annotations

from collections import OrderedDict
from types import FunctionType

from pydantic import BaseModel, Field
Expand All @@ -23,7 +24,7 @@ class PipelineConfig(BaseModel):
"""

inputs: list[PipelineInput] = Field(default_factory=list)
components: dict[str, PipelineComponent] = Field(default_factory=dict)
components: OrderedDict[str, PipelineComponent] = Field(default_factory=OrderedDict)


class PipelineInput(BaseModel):
Expand Down Expand Up @@ -55,9 +56,10 @@ class PipelineComponent(BaseModel):
with its default constructor parameters.
"""

inputs: dict[str, str] = Field(default_factory=dict)
inputs: dict[str, str] | list[str] = Field(default_factory=dict)
"""
The component's input wirings, mapping input names to node names.
The component's input wirings, mapping input names to node names. For
certain meta-nodes, it is specified as a list instead of a dict.
"""

@classmethod
Expand Down
34 changes: 34 additions & 0 deletions lenskit/tests/pipeline/test_save_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from lenskit.pipeline import InputNode, Node, Pipeline
from lenskit.pipeline.components import AutoConfig
from lenskit.pipeline.config import PipelineConfig
from lenskit.pipeline.nodes import ComponentNode


Expand Down Expand Up @@ -115,3 +116,36 @@ def test_configurable_component():
assert r2.connections == {"msg": "msg"}

assert p2.run("prefix", msg="HACKEM MUCHE") == "scroll named HACKEM MUCHE"


def negative(x: int) -> int:
return -x


def double(x: int) -> int:
return x * 2


def add(x: int, y: int) -> int:
return x + y


def test_save_with_fallback():
pipe = Pipeline()
a = pipe.create_input("a", int)
b = pipe.create_input("b", int)

nd = pipe.add_component("double", double, x=a)
nn = pipe.add_component("negate", negative, x=a)
fb = pipe.use_first_of("fill-operand", b, nn)
pipe.add_component("add", add, x=nd, y=fb)

cfg = pipe.get_config()
json = cfg.model_dump_json(exclude_none=True)
print(json)
c2 = PipelineConfig.model_validate_json(json)

p2 = Pipeline.from_config(c2)

# 3 * 2 + -3 = 3
assert p2.run("fill-operand", "add", a=3) == (-3, 3)

0 comments on commit 6153b6e

Please sign in to comment.