From c85d425fb5bc4cabaca6dd7d25005bc85c8f3ae4 Mon Sep 17 00:00:00 2001 From: Michael Ekstrand Date: Thu, 8 Aug 2024 16:56:41 -0400 Subject: [PATCH] include aliases in serialized configuration --- lenskit/lenskit/pipeline/__init__.py | 6 +++++ lenskit/lenskit/pipeline/config.py | 1 + lenskit/tests/pipeline/test_save_load.py | 29 ++++++++++++++++++++++++ 3 files changed, 36 insertions(+) diff --git a/lenskit/lenskit/pipeline/__init__.py b/lenskit/lenskit/pipeline/__init__.py index aecd811ef..965b0e5f4 100644 --- a/lenskit/lenskit/pipeline/__init__.py +++ b/lenskit/lenskit/pipeline/__init__.py @@ -477,6 +477,8 @@ def get_config(self, *, include_hash: bool = True) -> PipelineConfig: case _: # pragma: nocover raise RuntimeError(f"invalid node {node}") + config.aliases = {a: t.name for (a, t) in self._aliases.items()} + if include_hash: config.meta.hash = hash_config(config) @@ -540,6 +542,10 @@ def from_config(cls, config: object) -> Self: elif not comp.code.startswith("@"): raise PipelineError(f"component {name} inputs must be dict, not list") + # pass 4: aliases + for n, t in cfg.aliases.items(): + pipe.alias(n, t) + if cfg.meta.hash is not None: h2 = pipe.config_hash() if h2 != cfg.meta.hash: diff --git a/lenskit/lenskit/pipeline/config.py b/lenskit/lenskit/pipeline/config.py index dd96bd26c..3efdf4d98 100644 --- a/lenskit/lenskit/pipeline/config.py +++ b/lenskit/lenskit/pipeline/config.py @@ -27,6 +27,7 @@ class PipelineConfig(BaseModel): meta: PipelineMeta inputs: list[PipelineInput] = Field(default_factory=list) components: OrderedDict[str, PipelineComponent] = Field(default_factory=OrderedDict) + aliases: dict[str, str] = Field(default_factory=dict) class PipelineMeta(BaseModel): diff --git a/lenskit/tests/pipeline/test_save_load.py b/lenskit/tests/pipeline/test_save_load.py index 8272fa766..660338451 100644 --- a/lenskit/tests/pipeline/test_save_load.py +++ b/lenskit/tests/pipeline/test_save_load.py @@ -173,3 +173,32 @@ def test_hash_validate(): with warns(PipelineWarning): Pipeline.from_config(cfg) + + +def test_alias_input(): + "alias an input node" + pipe = Pipeline() + user = pipe.create_input("user", int, str) + + pipe.alias("person", user) + + cfg = pipe.get_config() + + p2 = Pipeline.from_config(cfg) + assert p2.run("person", user=32) == 32 + + +def test_alias_node(): + pipe = Pipeline() + a = pipe.create_input("a", int) + b = pipe.create_input("b", int) + + nd = pipe.add_component("double", double, x=a) + na = pipe.add_component("add", add, x=nd, y=b) + pipe.alias("result", na) + + assert pipe.run("result", a=5, b=7) == 17 + + cfg = pipe.get_config() + p2 = Pipeline.from_config(cfg) + assert p2.run("result", a=5, b=7) == 17