From 6153b6e1e53a467bd0afc9fed7ed6c024e004427 Mon Sep 17 00:00:00 2001 From: Michael Ekstrand Date: Thu, 8 Aug 2024 14:56:28 -0400 Subject: [PATCH] implement fallback node serialization --- lenskit/lenskit/pipeline/__init__.py | 29 +++++++++++++++++--- lenskit/lenskit/pipeline/config.py | 8 +++--- lenskit/tests/pipeline/test_save_load.py | 34 ++++++++++++++++++++++++ 3 files changed, 64 insertions(+), 7 deletions(-) diff --git a/lenskit/lenskit/pipeline/__init__.py b/lenskit/lenskit/pipeline/__init__.py index 860164fd3..fc3eda34f 100644 --- a/lenskit/lenskit/pipeline/__init__.py +++ b/lenskit/lenskit/pipeline/__init__.py @@ -417,8 +417,10 @@ def get_config(self) -> PipelineConfig: raise RuntimeError("literal nodes cannot be serialized to config") case ComponentNode(name): config.components[name] = PipelineComponent.from_node(node) - case FallbackNode(): - raise NotImplementedError() + case FallbackNode(name, alternatives): + config.components[name] = PipelineComponent( + code="@use-first-of", inputs=[n.name for n in alternatives] + ) case _: # pragma: nocover raise RuntimeError(f"invalid node {node}") @@ -434,14 +436,33 @@ def from_config(cls, config: object) -> Self: types += [parse_type_string(t) for t in inpt.types] pipe.create_input(inpt.name, *types) + # pass 1: add components + to_wire: list[PipelineComponent] = [] for name, comp in cfg.components.items(): + if comp.code.startswith("@"): + # ignore special nodes in first pass + continue + obj = instantiate_component(comp.code, comp.config) pipe.add_component(name, obj) + to_wire.append(comp) + + # pass 2: add meta nodes + for name, comp in cfg.components.items(): + if comp.code == "@use-first-of": + if not isinstance(comp.inputs, list): + raise RuntimeError("@use-first-of must have input list, not dict") + pipe.use_first_of(name, *[pipe.node(n) for n in comp.inputs]) + elif comp.code.startswith("@"): + raise RuntimeError(f"unsupported meta-component {comp.code}") # pass 2: wiring for name, comp in cfg.components.items(): - inputs = {n: pipe.node(t) for (n, t) in comp.inputs.items()} - pipe.connect(name, **inputs) + if isinstance(comp.inputs, dict): + inputs = {n: pipe.node(t) for (n, t) in comp.inputs.items()} + pipe.connect(name, **inputs) + elif not comp.code.startswith("@"): + raise RuntimeError(f"component {name} inputs must be dict, not list") return pipe diff --git a/lenskit/lenskit/pipeline/config.py b/lenskit/lenskit/pipeline/config.py index cd0432944..043fdffb8 100644 --- a/lenskit/lenskit/pipeline/config.py +++ b/lenskit/lenskit/pipeline/config.py @@ -5,6 +5,7 @@ # pyright: strict from __future__ import annotations +from collections import OrderedDict from types import FunctionType from pydantic import BaseModel, Field @@ -23,7 +24,7 @@ class PipelineConfig(BaseModel): """ inputs: list[PipelineInput] = Field(default_factory=list) - components: dict[str, PipelineComponent] = Field(default_factory=dict) + components: OrderedDict[str, PipelineComponent] = Field(default_factory=OrderedDict) class PipelineInput(BaseModel): @@ -55,9 +56,10 @@ class PipelineComponent(BaseModel): with its default constructor parameters. """ - inputs: dict[str, str] = Field(default_factory=dict) + inputs: dict[str, str] | list[str] = Field(default_factory=dict) """ - The component's input wirings, mapping input names to node names. + The component's input wirings, mapping input names to node names. For + certain meta-nodes, it is specified as a list instead of a dict. """ @classmethod diff --git a/lenskit/tests/pipeline/test_save_load.py b/lenskit/tests/pipeline/test_save_load.py index d864ef5c6..78b3dab99 100644 --- a/lenskit/tests/pipeline/test_save_load.py +++ b/lenskit/tests/pipeline/test_save_load.py @@ -4,6 +4,7 @@ from lenskit.pipeline import InputNode, Node, Pipeline from lenskit.pipeline.components import AutoConfig +from lenskit.pipeline.config import PipelineConfig from lenskit.pipeline.nodes import ComponentNode @@ -115,3 +116,36 @@ def test_configurable_component(): assert r2.connections == {"msg": "msg"} assert p2.run("prefix", msg="HACKEM MUCHE") == "scroll named HACKEM MUCHE" + + +def negative(x: int) -> int: + return -x + + +def double(x: int) -> int: + return x * 2 + + +def add(x: int, y: int) -> int: + return x + y + + +def test_save_with_fallback(): + pipe = Pipeline() + a = pipe.create_input("a", int) + b = pipe.create_input("b", int) + + nd = pipe.add_component("double", double, x=a) + nn = pipe.add_component("negate", negative, x=a) + fb = pipe.use_first_of("fill-operand", b, nn) + pipe.add_component("add", add, x=nd, y=fb) + + cfg = pipe.get_config() + json = cfg.model_dump_json(exclude_none=True) + print(json) + c2 = PipelineConfig.model_validate_json(json) + + p2 = Pipeline.from_config(c2) + + # 3 * 2 + -3 = 3 + assert p2.run("fill-operand", "add", a=3) == (-3, 3)