Skip to content

Commit

Permalink
include aliases in serialized configuration
Browse files Browse the repository at this point in the history
  • Loading branch information
mdekstrand committed Aug 8, 2024
1 parent 1411196 commit c85d425
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 0 deletions.
6 changes: 6 additions & 0 deletions lenskit/lenskit/pipeline/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -477,6 +477,8 @@ def get_config(self, *, include_hash: bool = True) -> PipelineConfig:
case _: # pragma: nocover
raise RuntimeError(f"invalid node {node}")

config.aliases = {a: t.name for (a, t) in self._aliases.items()}

if include_hash:
config.meta.hash = hash_config(config)

Expand Down Expand Up @@ -540,6 +542,10 @@ def from_config(cls, config: object) -> Self:
elif not comp.code.startswith("@"):
raise PipelineError(f"component {name} inputs must be dict, not list")

# pass 4: aliases
for n, t in cfg.aliases.items():
pipe.alias(n, t)

if cfg.meta.hash is not None:
h2 = pipe.config_hash()
if h2 != cfg.meta.hash:
Expand Down
1 change: 1 addition & 0 deletions lenskit/lenskit/pipeline/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ class PipelineConfig(BaseModel):
meta: PipelineMeta
inputs: list[PipelineInput] = Field(default_factory=list)
components: OrderedDict[str, PipelineComponent] = Field(default_factory=OrderedDict)
aliases: dict[str, str] = Field(default_factory=dict)


class PipelineMeta(BaseModel):
Expand Down
29 changes: 29 additions & 0 deletions lenskit/tests/pipeline/test_save_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,3 +173,32 @@ def test_hash_validate():

with warns(PipelineWarning):
Pipeline.from_config(cfg)


def test_alias_input():
"alias an input node"
pipe = Pipeline()
user = pipe.create_input("user", int, str)

pipe.alias("person", user)

cfg = pipe.get_config()

p2 = Pipeline.from_config(cfg)
assert p2.run("person", user=32) == 32


def test_alias_node():
pipe = Pipeline()
a = pipe.create_input("a", int)
b = pipe.create_input("b", int)

nd = pipe.add_component("double", double, x=a)
na = pipe.add_component("add", add, x=nd, y=b)
pipe.alias("result", na)

assert pipe.run("result", a=5, b=7) == 17

cfg = pipe.get_config()
p2 = Pipeline.from_config(cfg)
assert p2.run("result", a=5, b=7) == 17

0 comments on commit c85d425

Please sign in to comment.