From 4538ba8b1e0c0cc24e93c3e3c5a1b8d22bb6b810 Mon Sep 17 00:00:00 2001 From: Michael Ekstrand Date: Wed, 7 Aug 2024 16:34:43 -0400 Subject: [PATCH 01/24] add Pydantic dependency --- lenskit/pyproject.toml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/lenskit/pyproject.toml b/lenskit/pyproject.toml index bfa38dd43..7949810d1 100644 --- a/lenskit/pyproject.toml +++ b/lenskit/pyproject.toml @@ -25,7 +25,8 @@ dependencies = [ "scipy >= 1.9.0", "torch ~=2.1", # conda: pytorch>=2.1,<3 "threadpoolctl >=3.0", - "seedbank >= 0.2.0a2", # conda: @pip + "pydantic >=2.8,<3", + "seedbank >=0.2.0a2", # conda: @pip "progress-api >=0.1.0a9", # conda: @pip "manylog >=0.1.0a5", # conda: @pip ] From fc61a1d2ec137384ae7c7a8f1a1602a32cf58cc9 Mon Sep 17 00:00:00 2001 From: Michael Ekstrand Date: Wed, 7 Aug 2024 17:08:50 -0400 Subject: [PATCH 02/24] add string type representations --- lenskit/lenskit/pipeline/types.py | 37 +++++++++++++++++- .../{test_pipeline_types.py => test_types.py} | 38 ++++++++++++++++++- 2 files changed, 73 insertions(+), 2 deletions(-) rename lenskit/tests/pipeline/{test_pipeline_types.py => test_types.py} (78%) diff --git a/lenskit/lenskit/pipeline/types.py b/lenskit/lenskit/pipeline/types.py index 00c91c4b1..ef3d61434 100644 --- a/lenskit/lenskit/pipeline/types.py +++ b/lenskit/lenskit/pipeline/types.py @@ -7,8 +7,10 @@ # pyright: basic from __future__ import annotations +import re import warnings -from types import GenericAlias +from importlib import import_module +from types import GenericAlias, NoneType from typing import Union, _GenericAlias, get_args, get_origin # type: ignore import numpy as np @@ -118,3 +120,36 @@ def is_compatible_data(obj: object, *targets: type) -> bool: return True return False + + +def type_string(typ: type) -> str: + """ + Compute a string representation of a type that is both resolvable and + human-readable. Type parameterizations are lost. + """ + if typ is None or typ is NoneType: + return "None" + elif typ.__module__ == "builtins": + return typ.__name__ + else: + return f"{typ.__module__}.{typ.__name__}" + + +def parse_type_string(tstr: str) -> type: + """ + Compute a string representation of a type that is both resolvable and + human-readable. Type parameterizations are lost. + """ + if tstr == "None": + return NoneType + elif re.match(r"^\w+$", tstr): + return __builtins__[tstr] + else: + # separate last element from module + parts = re.match(r"(.*)\.(\w+)$", tstr) + if not parts: + raise ValueError(f"unparsable type string {tstr}") + + mod_name, typ_name = parts.groups() + mod = import_module(mod_name) + return getattr(mod, typ_name) diff --git a/lenskit/tests/pipeline/test_pipeline_types.py b/lenskit/tests/pipeline/test_types.py similarity index 78% rename from lenskit/tests/pipeline/test_pipeline_types.py rename to lenskit/tests/pipeline/test_types.py index f2a9add0b..17b3cfccb 100644 --- a/lenskit/tests/pipeline/test_pipeline_types.py +++ b/lenskit/tests/pipeline/test_types.py @@ -10,6 +10,8 @@ import typing from collections.abc import Iterable, Sequence +from pathlib import Path +from types import NoneType import numpy as np import pandas as pd @@ -18,7 +20,13 @@ from pytest import warns from lenskit.data.dataset import Dataset, MatrixDataset -from lenskit.pipeline.types import TypecheckWarning, is_compatible_data, is_compatible_type +from lenskit.pipeline.types import ( + TypecheckWarning, + is_compatible_data, + is_compatible_type, + parse_type_string, + type_string, +) def test_type_compat_identical(): @@ -91,3 +99,31 @@ def test_numpy_typecheck(): def test_pandas_typecheck(): assert is_compatible_data(pd.Series(["a", "b"]), ArrayLike) + + +def test_type_string_none(): + assert type_string(None) == "None" + + +def test_type_string_str(): + assert type_string(str) == "str" + + +def test_type_string_generic(): + assert type_string(list[str]) == "list" + + +def test_type_string_class(): + assert type_string(Path) == "pathlib.Path" + + +def test_parse_string_None(): + assert parse_type_string("None") == NoneType + + +def test_parse_string_int(): + assert parse_type_string("int") is int + + +def test_parse_string_class(): + assert parse_type_string("pathlib.Path") is Path From 79432031ce5b92c51a50ed62f1576dd33a64c1db Mon Sep 17 00:00:00 2001 From: Michael Ekstrand Date: Wed, 7 Aug 2024 17:09:17 -0400 Subject: [PATCH 03/24] move and split pipeline clone / config --- .../tests/pipeline/test_component_config.py | 48 +++++++++++++++++++ ...eline_config.py => test_pipeline_clone.py} | 27 ----------- 2 files changed, 48 insertions(+), 27 deletions(-) create mode 100644 lenskit/tests/pipeline/test_component_config.py rename lenskit/tests/pipeline/{test_pipeline_config.py => test_pipeline_clone.py} (76%) diff --git a/lenskit/tests/pipeline/test_component_config.py b/lenskit/tests/pipeline/test_component_config.py new file mode 100644 index 000000000..231f50253 --- /dev/null +++ b/lenskit/tests/pipeline/test_component_config.py @@ -0,0 +1,48 @@ +# This file is part of LensKit. +# Copyright (C) 2018-2023 Boise State University +# Copyright (C) 2023-2024 Drexel University +# Licensed under the MIT license, see LICENSE.md for details. +# SPDX-License-Identifier: MIT + +import json + +from lenskit.pipeline import Pipeline +from lenskit.pipeline.components import AutoConfig +from lenskit.pipeline.nodes import ComponentNode + + +class Prefixer(AutoConfig): + prefix: str + + def __init__(self, prefix: str = "hello"): + self.prefix = prefix + + def __call__(self, msg: str) -> str: + return self.prefix + msg + + +def test_auto_config_roundtrip(): + comp = Prefixer("FOOBIE BLETCH") + + cfg = comp.get_config() + assert "prefix" in cfg + + c2 = Prefixer.from_config(cfg) + assert c2 is not comp + assert c2.prefix == comp.prefix + + +def test_pipeline_config(): + comp = Prefixer("scroll named ") + + pipe = Pipeline() + msg = pipe.create_input("msg", str) + pipe.add_component("prefix", comp, msg=msg) + + assert pipe.run(msg="FOOBIE BLETCH") == "scroll named FOOBIE BLETCH" + + config = pipe.component_configs() + print(json.dumps(config, indent=2)) + + assert "prefix" in config + assert config["prefix"]["prefix"] == "scroll named " diff --git a/lenskit/tests/pipeline/test_pipeline_config.py b/lenskit/tests/pipeline/test_pipeline_clone.py similarity index 76% rename from lenskit/tests/pipeline/test_pipeline_config.py rename to lenskit/tests/pipeline/test_pipeline_clone.py index a73d89835..79d77cb50 100644 --- a/lenskit/tests/pipeline/test_pipeline_config.py +++ b/lenskit/tests/pipeline/test_pipeline_clone.py @@ -32,33 +32,6 @@ def exclaim(msg: str) -> str: return msg + "!" -def test_auto_config_roundtrip(): - comp = Prefixer("FOOBIE BLETCH") - - cfg = comp.get_config() - assert "prefix" in cfg - - c2 = Prefixer.from_config(cfg) - assert c2 is not comp - assert c2.prefix == comp.prefix - - -def test_pipeline_config(): - comp = Prefixer("scroll named ") - - pipe = Pipeline() - msg = pipe.create_input("msg", str) - pipe.add_component("prefix", comp, msg=msg) - - assert pipe.run(msg="FOOBIE BLETCH") == "scroll named FOOBIE BLETCH" - - config = pipe.component_configs() - print(json.dumps(config, indent=2)) - - assert "prefix" in config - assert config["prefix"]["prefix"] == "scroll named " - - def test_pipeline_clone(): comp = Prefixer("scroll named ") From 7577d1e1d4568ab49acab66e45b715040aca0720 Mon Sep 17 00:00:00 2001 From: Michael Ekstrand Date: Wed, 7 Aug 2024 17:09:33 -0400 Subject: [PATCH 04/24] add pydantic to docs --- docs/conf.py | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/conf.py b/docs/conf.py index bad22e344..f606f4208 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -101,6 +101,7 @@ "manylog": ("https://manylog.readthedocs.io/en/latest/", None), "torch": ("https://pytorch.org/docs/stable/", None), "implicit": ("https://benfred.github.io/implicit/", None), + "pydantic": ("https://docs.pydantic.dev/latest/", None), } bibtex_bibfiles = ["lenskit.bib"] From 176df23d6b29f0ea0701f3e8207532b7b5e4e849 Mon Sep 17 00:00:00 2001 From: Michael Ekstrand Date: Wed, 7 Aug 2024 17:11:08 -0400 Subject: [PATCH 05/24] add initial config model + input node + test --- lenskit/lenskit/pipeline/config.py | 39 ++++++++++++++++++++++ lenskit/tests/pipeline/test_config_node.py | 34 +++++++++++++++++++ 2 files changed, 73 insertions(+) create mode 100644 lenskit/lenskit/pipeline/config.py create mode 100644 lenskit/tests/pipeline/test_config_node.py diff --git a/lenskit/lenskit/pipeline/config.py b/lenskit/lenskit/pipeline/config.py new file mode 100644 index 000000000..1e38e8f58 --- /dev/null +++ b/lenskit/lenskit/pipeline/config.py @@ -0,0 +1,39 @@ +""" +Pydantic models for pipeline configuration and serialization support. +""" + +# pyright: strict +from __future__ import annotations + +from pydantic import BaseModel +from typing_extensions import Any, Optional, Self + +from lenskit.pipeline.types import type_string + +from .nodes import InputNode + + +class PipelineConfig(BaseModel): + """ + Root type for serialized pipeline configuration. A pipeline config contains + the full configuration, components, and wiring for the pipeline, but does + not contain the + """ + + inputs: list[PipelineInput] + + +class PipelineInput(BaseModel): + name: str + "The name for this input." + types: Optional[list[str]] + "The list of types for this input." + + @classmethod + def from_node(cls, node: InputNode[Any]) -> Self: + if node.types is not None: + types = [type_string(t) for t in node.types] + else: + types = None + + return cls(name=node.name, types=types) diff --git a/lenskit/tests/pipeline/test_config_node.py b/lenskit/tests/pipeline/test_config_node.py new file mode 100644 index 000000000..6e80a9b8b --- /dev/null +++ b/lenskit/tests/pipeline/test_config_node.py @@ -0,0 +1,34 @@ +from lenskit.pipeline.config import PipelineInput +from lenskit.pipeline.nodes import InputNode + + +def test_untyped_input(): + node = InputNode("scroll") + + cfg = PipelineInput.from_node(node) + assert cfg.name == "scroll" + assert cfg.types is None + + +def test_input_with_type(): + node = InputNode("scroll", types={str}) + + cfg = PipelineInput.from_node(node) + assert cfg.name == "scroll" + assert cfg.types == ["str"] + + +def test_input_with_none(): + node = InputNode("scroll", types={str, type(None)}) + + cfg = PipelineInput.from_node(node) + assert cfg.name == "scroll" + assert cfg.types == ["str", "None"] + + +def test_input_with_generic(): + node = InputNode("scroll", types={list[str]}) + + cfg = PipelineInput.from_node(node) + assert cfg.name == "scroll" + assert cfg.types == ["list"] From ff4d5981aba2e5df4eee7d45abaf28f097fbd23c Mon Sep 17 00:00:00 2001 From: Michael Ekstrand Date: Wed, 7 Aug 2024 17:17:27 -0400 Subject: [PATCH 06/24] store types in sets --- lenskit/lenskit/pipeline/config.py | 4 ++-- lenskit/lenskit/pipeline/types.py | 2 +- lenskit/tests/pipeline/test_config_node.py | 10 +++++++--- 3 files changed, 10 insertions(+), 6 deletions(-) diff --git a/lenskit/lenskit/pipeline/config.py b/lenskit/lenskit/pipeline/config.py index 1e38e8f58..53e9fd3a1 100644 --- a/lenskit/lenskit/pipeline/config.py +++ b/lenskit/lenskit/pipeline/config.py @@ -26,13 +26,13 @@ class PipelineConfig(BaseModel): class PipelineInput(BaseModel): name: str "The name for this input." - types: Optional[list[str]] + types: Optional[set[str]] "The list of types for this input." @classmethod def from_node(cls, node: InputNode[Any]) -> Self: if node.types is not None: - types = [type_string(t) for t in node.types] + types = {type_string(t) for t in node.types} else: types = None diff --git a/lenskit/lenskit/pipeline/types.py b/lenskit/lenskit/pipeline/types.py index ef3d61434..19b1f5d8f 100644 --- a/lenskit/lenskit/pipeline/types.py +++ b/lenskit/lenskit/pipeline/types.py @@ -122,7 +122,7 @@ def is_compatible_data(obj: object, *targets: type) -> bool: return False -def type_string(typ: type) -> str: +def type_string(typ: type | None) -> str: """ Compute a string representation of a type that is both resolvable and human-readable. Type parameterizations are lost. diff --git a/lenskit/tests/pipeline/test_config_node.py b/lenskit/tests/pipeline/test_config_node.py index 6e80a9b8b..4f44b325c 100644 --- a/lenskit/tests/pipeline/test_config_node.py +++ b/lenskit/tests/pipeline/test_config_node.py @@ -6,6 +6,7 @@ def test_untyped_input(): node = InputNode("scroll") cfg = PipelineInput.from_node(node) + print(cfg) assert cfg.name == "scroll" assert cfg.types is None @@ -14,21 +15,24 @@ def test_input_with_type(): node = InputNode("scroll", types={str}) cfg = PipelineInput.from_node(node) + print(cfg) assert cfg.name == "scroll" - assert cfg.types == ["str"] + assert cfg.types == {"str"} def test_input_with_none(): node = InputNode("scroll", types={str, type(None)}) cfg = PipelineInput.from_node(node) + print(cfg) assert cfg.name == "scroll" - assert cfg.types == ["str", "None"] + assert cfg.types == {"None", "str"} def test_input_with_generic(): node = InputNode("scroll", types={list[str]}) cfg = PipelineInput.from_node(node) + print(cfg) assert cfg.name == "scroll" - assert cfg.types == ["list"] + assert cfg.types == {"list"} From 8ee677a0ea655474d457d2ffee8659dfc3cf7d10 Mon Sep 17 00:00:00 2001 From: Michael Ekstrand Date: Wed, 7 Aug 2024 17:17:37 -0400 Subject: [PATCH 07/24] fix type error in pipeline construction --- lenskit/lenskit/pipeline/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lenskit/lenskit/pipeline/__init__.py b/lenskit/lenskit/pipeline/__init__.py index 96888d53b..ec75e0671 100644 --- a/lenskit/lenskit/pipeline/__init__.py +++ b/lenskit/lenskit/pipeline/__init__.py @@ -139,7 +139,7 @@ def create_input(self, name: str, *types: type[T] | None) -> Node[T]: """ self._check_available_name(name) - node = InputNode[Any](name, types=set((t if t is not None else type[None]) for t in types)) + node = InputNode[Any](name, types=set((t if t is not None else type(None)) for t in types)) self._nodes[name] = node self._clear_caches() return node From 1ba6aa13ff3b40093241c44af3e39d4660e9097b Mon Sep 17 00:00:00 2001 From: Michael Ekstrand Date: Wed, 7 Aug 2024 17:21:04 -0400 Subject: [PATCH 08/24] Serialize and round-trip component inputs --- lenskit/lenskit/pipeline/__init__.py | 34 ++++++++++++++++- lenskit/tests/pipeline/test_save_load.py | 47 ++++++++++++++++++++++++ 2 files changed, 80 insertions(+), 1 deletion(-) create mode 100644 lenskit/tests/pipeline/test_save_load.py diff --git a/lenskit/lenskit/pipeline/__init__.py b/lenskit/lenskit/pipeline/__init__.py index ec75e0671..46d04fe09 100644 --- a/lenskit/lenskit/pipeline/__init__.py +++ b/lenskit/lenskit/pipeline/__init__.py @@ -16,9 +16,10 @@ from typing import Literal, cast from uuid import uuid4 -from typing_extensions import Any, LiteralString, TypeVar, overload +from typing_extensions import Any, LiteralString, Self, TypeVar, overload from lenskit.data import Dataset +from lenskit.pipeline.types import parse_type_string from .components import ( AutoConfig, # noqa: F401 # type: ignore @@ -26,6 +27,7 @@ ConfigurableComponent, TrainableComponent, ) +from .config import PipelineConfig, PipelineInput from .nodes import ND, ComponentNode, FallbackNode, InputNode, LiteralNode, Node from .state import PipelineState @@ -36,6 +38,7 @@ "Component", "ConfigurableComponent", "TrainableComponent", + "PipelineConfig", ] _log = logging.getLogger(__name__) @@ -392,6 +395,35 @@ def clone(self, *, params: bool = False) -> Pipeline: return clone + def get_config(self) -> PipelineConfig: + """ + Get this pipeline's configuration for serialization. The configuration + consists of all inputs and components along with their configurations + and input connections. It can be serialized to disk (in JSON, YAML, or + a similar format) to save a pipeline. + + The configuration does **not** include any trained parameter values, + although the configuration may include things such as paths to + checkpoints to load such parameters, depending on the design of the + components in the pipeline. + """ + inputs = [ + PipelineInput.from_node(node) for node in self.nodes if isinstance(node, InputNode) + ] + return PipelineConfig(inputs=inputs) + + @classmethod + def from_config(cls, config: object) -> Self: + cfg = PipelineConfig.model_validate(config) + pipe = cls() + for inpt in cfg.inputs: + types: list[type[Any] | None] = [] + if inpt.types is not None: + types += [parse_type_string(t) for t in inpt.types] + pipe.create_input(inpt.name, *types) + + return pipe + def train(self, data: Dataset) -> None: """ Trains the pipeline's trainable components (those implementing the diff --git a/lenskit/tests/pipeline/test_save_load.py b/lenskit/tests/pipeline/test_save_load.py new file mode 100644 index 000000000..377a1bcf5 --- /dev/null +++ b/lenskit/tests/pipeline/test_save_load.py @@ -0,0 +1,47 @@ +from types import NoneType + +from typing_extensions import assert_type + +from lenskit.pipeline import InputNode, Node, Pipeline + + +def test_serialize_input(): + "serialize with one input node" + pipe = Pipeline() + pipe.create_input("user", int, str) + + cfg = pipe.get_config() + print(cfg) + assert len(cfg.inputs) == 1 + assert cfg.inputs[0].name == "user" + assert cfg.inputs[0].types == {"int", "str"} + + +def test_round_trip_input(): + "serialize with one input node" + pipe = Pipeline() + pipe.create_input("user", int, str) + + cfg = pipe.get_config() + print(cfg) + + p2 = Pipeline.from_config(cfg) + i2 = p2.node("user") + assert isinstance(i2, InputNode) + assert i2.name == "user" + assert i2.types == {int, str} + + +def test_round_trip_optional_input(): + "serialize with one input node" + pipe = Pipeline() + pipe.create_input("user", int, str, None) + + cfg = pipe.get_config() + assert cfg.inputs[0].types == {"int", "str", "None"} + + p2 = Pipeline.from_config(cfg) + i2 = p2.node("user") + assert isinstance(i2, InputNode) + assert i2.name == "user" + assert i2.types == {int, str, NoneType} From 0691cce7aab1d915bba2910df5891cfda1023635 Mon Sep 17 00:00:00 2001 From: Michael Ekstrand Date: Wed, 7 Aug 2024 17:22:41 -0400 Subject: [PATCH 09/24] parse qualified (colon-separated) types --- lenskit/lenskit/pipeline/types.py | 13 ++++++++----- lenskit/tests/pipeline/test_types.py | 4 ++++ 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/lenskit/lenskit/pipeline/types.py b/lenskit/lenskit/pipeline/types.py index 19b1f5d8f..0b840be25 100644 --- a/lenskit/lenskit/pipeline/types.py +++ b/lenskit/lenskit/pipeline/types.py @@ -145,11 +145,14 @@ def parse_type_string(tstr: str) -> type: elif re.match(r"^\w+$", tstr): return __builtins__[tstr] else: - # separate last element from module - parts = re.match(r"(.*)\.(\w+)$", tstr) - if not parts: - raise ValueError(f"unparsable type string {tstr}") + if ":" in tstr: + mod_name, typ_name = tstr.split(":", 1) + else: + # separate last element from module + parts = re.match(r"(.*)\.(\w+)$", tstr) + if not parts: + raise ValueError(f"unparsable type string {tstr}") + mod_name, typ_name = parts.groups() - mod_name, typ_name = parts.groups() mod = import_module(mod_name) return getattr(mod, typ_name) diff --git a/lenskit/tests/pipeline/test_types.py b/lenskit/tests/pipeline/test_types.py index 17b3cfccb..e4ee06d90 100644 --- a/lenskit/tests/pipeline/test_types.py +++ b/lenskit/tests/pipeline/test_types.py @@ -127,3 +127,7 @@ def test_parse_string_int(): def test_parse_string_class(): assert parse_type_string("pathlib.Path") is Path + + +def test_parse_string_mod_class(): + assert parse_type_string("pathlib:Path") is Path From b0f44d25fdbb892bc9e76d694fa3235000591eac Mon Sep 17 00:00:00 2001 From: Michael Ekstrand Date: Wed, 7 Aug 2024 17:23:50 -0400 Subject: [PATCH 10/24] serialize to qualified names --- lenskit/lenskit/pipeline/types.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/lenskit/lenskit/pipeline/types.py b/lenskit/lenskit/pipeline/types.py index 0b840be25..35da340bb 100644 --- a/lenskit/lenskit/pipeline/types.py +++ b/lenskit/lenskit/pipeline/types.py @@ -131,8 +131,10 @@ def type_string(typ: type | None) -> str: return "None" elif typ.__module__ == "builtins": return typ.__name__ - else: + elif typ.__qualname__ == typ.__name__: return f"{typ.__module__}.{typ.__name__}" + else: + return f"{typ.__module__}:{typ.__qualname__}" def parse_type_string(tstr: str) -> type: From 8fbe084f6b1fc9449ab1d61f8f5a83a6c5513f23 Mon Sep 17 00:00:00 2001 From: Michael Ekstrand Date: Wed, 7 Aug 2024 17:59:14 -0400 Subject: [PATCH 11/24] support round-trip configuration --- lenskit/lenskit/pipeline/__init__.py | 30 +++++++--- lenskit/lenskit/pipeline/components.py | 23 ++++++++ lenskit/lenskit/pipeline/config.py | 44 +++++++++++++-- lenskit/tests/pipeline/test_save_load.py | 70 ++++++++++++++++++++++++ 4 files changed, 156 insertions(+), 11 deletions(-) diff --git a/lenskit/lenskit/pipeline/__init__.py b/lenskit/lenskit/pipeline/__init__.py index 46d04fe09..07230f910 100644 --- a/lenskit/lenskit/pipeline/__init__.py +++ b/lenskit/lenskit/pipeline/__init__.py @@ -26,8 +26,9 @@ Component, ConfigurableComponent, TrainableComponent, + instantiate_component, ) -from .config import PipelineConfig, PipelineInput +from .config import PipelineComponent, PipelineConfig, PipelineInput from .nodes import ND, ComponentNode, FallbackNode, InputNode, LiteralNode, Node from .state import PipelineState @@ -384,9 +385,9 @@ def clone(self, *, params: bool = False) -> Pipeline: if isinstance(comp, FunctionType): comp = comp elif isinstance(comp, ConfigurableComponent): - comp = comp.__class__.from_config(comp.get_config()) + comp = comp.__class__.from_config(comp.get_config()) # type: ignore else: - comp = comp.__class__() + comp = comp.__class__() # type: ignore cn = clone.add_component(node.name, comp) # type: ignore for wn, wt in wiring.items(): clone.connect(cn, **{wn: clone.node(wt)}) @@ -407,10 +408,16 @@ def get_config(self) -> PipelineConfig: checkpoints to load such parameters, depending on the design of the components in the pipeline. """ - inputs = [ - PipelineInput.from_node(node) for node in self.nodes if isinstance(node, InputNode) - ] - return PipelineConfig(inputs=inputs) + return PipelineConfig( + inputs=[ + PipelineInput.from_node(node) for node in self.nodes if isinstance(node, InputNode) + ], + components={ + node.name: PipelineComponent.from_node(node) + for node in self.nodes + if isinstance(node, ComponentNode) + }, + ) @classmethod def from_config(cls, config: object) -> Self: @@ -422,6 +429,15 @@ def from_config(cls, config: object) -> Self: types += [parse_type_string(t) for t in inpt.types] pipe.create_input(inpt.name, *types) + for name, comp in cfg.components.items(): + obj = instantiate_component(comp.code, comp.config) + pipe.add_component(name, obj) + + # pass 2: wiring + for name, comp in cfg.components.items(): + inputs = {n: pipe.node(t) for (n, t) in comp.inputs.items()} + pipe.connect(name, **inputs) + return pipe def train(self, data: Dataset) -> None: diff --git a/lenskit/lenskit/pipeline/components.py b/lenskit/lenskit/pipeline/components.py index 3d6d59ec4..bae0f79c3 100644 --- a/lenskit/lenskit/pipeline/components.py +++ b/lenskit/lenskit/pipeline/components.py @@ -10,6 +10,8 @@ from __future__ import annotations import inspect +from importlib import import_module +from types import FunctionType from typing import Callable, ClassVar, TypeAlias from typing_extensions import Any, Generic, Protocol, Self, TypeVar, override, runtime_checkable @@ -184,3 +186,24 @@ def from_config(cls, cfg: dict[str, Any]) -> Self: are passed to the constructor as keywrod arguments. """ return cls(**cfg) + + +def instantiate_component( + comp: str | type | FunctionType, config: dict[str, Any] | None +) -> Callable[..., object]: + if isinstance(comp, str): + mname, oname = comp.split(":", 1) + mod = import_module(mname) + comp = getattr(mod, oname) + + # make the type checker happy + assert not isinstance(comp, str) + + if isinstance(comp, FunctionType): + return comp + elif issubclass(comp, ConfigurableComponent): + if config is None: + config = {} + return comp.from_config(config) # type: ignore + else: + return comp() # type: ignore diff --git a/lenskit/lenskit/pipeline/config.py b/lenskit/lenskit/pipeline/config.py index 53e9fd3a1..af29e0e4a 100644 --- a/lenskit/lenskit/pipeline/config.py +++ b/lenskit/lenskit/pipeline/config.py @@ -5,12 +5,14 @@ # pyright: strict from __future__ import annotations -from pydantic import BaseModel -from typing_extensions import Any, Optional, Self +from types import FunctionType -from lenskit.pipeline.types import type_string +from pydantic import BaseModel, Field +from typing_extensions import Any, Optional, Self -from .nodes import InputNode +from .components import ConfigurableComponent +from .nodes import ComponentNode, InputNode +from .types import type_string class PipelineConfig(BaseModel): @@ -21,6 +23,7 @@ class PipelineConfig(BaseModel): """ inputs: list[PipelineInput] + components: dict[str, PipelineComponent] = Field(default_factory=dict) class PipelineInput(BaseModel): @@ -37,3 +40,36 @@ def from_node(cls, node: InputNode[Any]) -> Self: types = None return cls(name=node.name, types=types) + + +class PipelineComponent(BaseModel): + code: str + """ + The path to the component's implementation, either a class or a function. + This is a Python qualified path of the form ``module:name``. + """ + + config: dict[str, object] | None = Field(default=None) + """ + The component configuration. If not provided, the component will be created + with its default constructor parameters. + """ + + inputs: dict[str, str] = Field(default_factory=dict) + """ + The component's input wirings, mapping input names to node names. + """ + + @classmethod + def from_node(cls, node: ComponentNode[Any]) -> Self: + comp = node.component + if isinstance(comp, FunctionType): + ctype = comp + else: + ctype = comp.__class__ + + code = f"{ctype.__module__}:{ctype.__qualname__}" + + config = comp.get_config() if isinstance(comp, ConfigurableComponent) else None + + return cls(code=code, config=config, inputs=node.connections) diff --git a/lenskit/tests/pipeline/test_save_load.py b/lenskit/tests/pipeline/test_save_load.py index 377a1bcf5..d864ef5c6 100644 --- a/lenskit/tests/pipeline/test_save_load.py +++ b/lenskit/tests/pipeline/test_save_load.py @@ -3,6 +3,18 @@ from typing_extensions import assert_type from lenskit.pipeline import InputNode, Node, Pipeline +from lenskit.pipeline.components import AutoConfig +from lenskit.pipeline.nodes import ComponentNode + + +class Prefixer(AutoConfig): + prefix: str + + def __init__(self, prefix: str = "hello"): + self.prefix = prefix + + def __call__(self, msg: str) -> str: + return self.prefix + msg def test_serialize_input(): @@ -45,3 +57,61 @@ def test_round_trip_optional_input(): assert isinstance(i2, InputNode) assert i2.name == "user" assert i2.types == {int, str, NoneType} + + +def msg_ident(msg: str) -> str: + return msg + + +def test_config_single_node(): + pipe = Pipeline() + msg = pipe.create_input("msg", str) + + pipe.add_component("return", msg_ident, msg=msg) + + cfg = pipe.get_config() + assert len(cfg.inputs) == 1 + assert len(cfg.components) == 1 + + assert cfg.components["return"].code == "lenskit.tests.pipeline.test_save_load:msg_ident" + assert cfg.components["return"].config is None + assert cfg.components["return"].inputs == {"msg": "msg"} + + +def test_round_trip_single_node(): + pipe = Pipeline() + msg = pipe.create_input("msg", str) + + pipe.add_component("return", msg_ident, msg=msg) + + cfg = pipe.get_config() + + p2 = Pipeline.from_config(cfg) + assert len(p2.nodes) == 2 + r2 = p2.node("return") + assert isinstance(r2, ComponentNode) + assert r2.component is msg_ident + assert r2.connections == {"msg": "msg"} + + assert p2.run("return", msg="foo") == "foo" + + +def test_configurable_component(): + pipe = Pipeline() + msg = pipe.create_input("msg", str) + + pfx = Prefixer("scroll named ") + pipe.add_component("prefix", pfx, msg=msg) + + cfg = pipe.get_config() + assert cfg.components["prefix"].config == {"prefix": "scroll named "} + + p2 = Pipeline.from_config(cfg) + assert len(p2.nodes) == 2 + r2 = p2.node("prefix") + assert isinstance(r2, ComponentNode) + assert isinstance(r2.component, Prefixer) + assert r2.component is not pfx + assert r2.connections == {"msg": "msg"} + + assert p2.run("prefix", msg="HACKEM MUCHE") == "scroll named HACKEM MUCHE" From 8c398df85c0ef44a5cf7395679166982ce269940 Mon Sep 17 00:00:00 2001 From: Michael Ekstrand Date: Thu, 8 Aug 2024 14:30:37 -0400 Subject: [PATCH 12/24] refactor configuration with pattern-matching --- lenskit/lenskit/pipeline/__init__.py | 25 +++++++++++++++---------- lenskit/lenskit/pipeline/config.py | 2 +- 2 files changed, 16 insertions(+), 11 deletions(-) diff --git a/lenskit/lenskit/pipeline/__init__.py b/lenskit/lenskit/pipeline/__init__.py index 07230f910..860164fd3 100644 --- a/lenskit/lenskit/pipeline/__init__.py +++ b/lenskit/lenskit/pipeline/__init__.py @@ -408,16 +408,21 @@ def get_config(self) -> PipelineConfig: checkpoints to load such parameters, depending on the design of the components in the pipeline. """ - return PipelineConfig( - inputs=[ - PipelineInput.from_node(node) for node in self.nodes if isinstance(node, InputNode) - ], - components={ - node.name: PipelineComponent.from_node(node) - for node in self.nodes - if isinstance(node, ComponentNode) - }, - ) + config = PipelineConfig() + for node in self.nodes: + match node: + case InputNode(): + config.inputs.append(PipelineInput.from_node(node)) + case LiteralNode(): + raise RuntimeError("literal nodes cannot be serialized to config") + case ComponentNode(name): + config.components[name] = PipelineComponent.from_node(node) + case FallbackNode(): + raise NotImplementedError() + case _: # pragma: nocover + raise RuntimeError(f"invalid node {node}") + + return config @classmethod def from_config(cls, config: object) -> Self: diff --git a/lenskit/lenskit/pipeline/config.py b/lenskit/lenskit/pipeline/config.py index af29e0e4a..cd0432944 100644 --- a/lenskit/lenskit/pipeline/config.py +++ b/lenskit/lenskit/pipeline/config.py @@ -22,7 +22,7 @@ class PipelineConfig(BaseModel): not contain the """ - inputs: list[PipelineInput] + inputs: list[PipelineInput] = Field(default_factory=list) components: dict[str, PipelineComponent] = Field(default_factory=dict) From 6153b6e1e53a467bd0afc9fed7ed6c024e004427 Mon Sep 17 00:00:00 2001 From: Michael Ekstrand Date: Thu, 8 Aug 2024 14:56:28 -0400 Subject: [PATCH 13/24] implement fallback node serialization --- lenskit/lenskit/pipeline/__init__.py | 29 +++++++++++++++++--- lenskit/lenskit/pipeline/config.py | 8 +++--- lenskit/tests/pipeline/test_save_load.py | 34 ++++++++++++++++++++++++ 3 files changed, 64 insertions(+), 7 deletions(-) diff --git a/lenskit/lenskit/pipeline/__init__.py b/lenskit/lenskit/pipeline/__init__.py index 860164fd3..fc3eda34f 100644 --- a/lenskit/lenskit/pipeline/__init__.py +++ b/lenskit/lenskit/pipeline/__init__.py @@ -417,8 +417,10 @@ def get_config(self) -> PipelineConfig: raise RuntimeError("literal nodes cannot be serialized to config") case ComponentNode(name): config.components[name] = PipelineComponent.from_node(node) - case FallbackNode(): - raise NotImplementedError() + case FallbackNode(name, alternatives): + config.components[name] = PipelineComponent( + code="@use-first-of", inputs=[n.name for n in alternatives] + ) case _: # pragma: nocover raise RuntimeError(f"invalid node {node}") @@ -434,14 +436,33 @@ def from_config(cls, config: object) -> Self: types += [parse_type_string(t) for t in inpt.types] pipe.create_input(inpt.name, *types) + # pass 1: add components + to_wire: list[PipelineComponent] = [] for name, comp in cfg.components.items(): + if comp.code.startswith("@"): + # ignore special nodes in first pass + continue + obj = instantiate_component(comp.code, comp.config) pipe.add_component(name, obj) + to_wire.append(comp) + + # pass 2: add meta nodes + for name, comp in cfg.components.items(): + if comp.code == "@use-first-of": + if not isinstance(comp.inputs, list): + raise RuntimeError("@use-first-of must have input list, not dict") + pipe.use_first_of(name, *[pipe.node(n) for n in comp.inputs]) + elif comp.code.startswith("@"): + raise RuntimeError(f"unsupported meta-component {comp.code}") # pass 2: wiring for name, comp in cfg.components.items(): - inputs = {n: pipe.node(t) for (n, t) in comp.inputs.items()} - pipe.connect(name, **inputs) + if isinstance(comp.inputs, dict): + inputs = {n: pipe.node(t) for (n, t) in comp.inputs.items()} + pipe.connect(name, **inputs) + elif not comp.code.startswith("@"): + raise RuntimeError(f"component {name} inputs must be dict, not list") return pipe diff --git a/lenskit/lenskit/pipeline/config.py b/lenskit/lenskit/pipeline/config.py index cd0432944..043fdffb8 100644 --- a/lenskit/lenskit/pipeline/config.py +++ b/lenskit/lenskit/pipeline/config.py @@ -5,6 +5,7 @@ # pyright: strict from __future__ import annotations +from collections import OrderedDict from types import FunctionType from pydantic import BaseModel, Field @@ -23,7 +24,7 @@ class PipelineConfig(BaseModel): """ inputs: list[PipelineInput] = Field(default_factory=list) - components: dict[str, PipelineComponent] = Field(default_factory=dict) + components: OrderedDict[str, PipelineComponent] = Field(default_factory=OrderedDict) class PipelineInput(BaseModel): @@ -55,9 +56,10 @@ class PipelineComponent(BaseModel): with its default constructor parameters. """ - inputs: dict[str, str] = Field(default_factory=dict) + inputs: dict[str, str] | list[str] = Field(default_factory=dict) """ - The component's input wirings, mapping input names to node names. + The component's input wirings, mapping input names to node names. For + certain meta-nodes, it is specified as a list instead of a dict. """ @classmethod diff --git a/lenskit/tests/pipeline/test_save_load.py b/lenskit/tests/pipeline/test_save_load.py index d864ef5c6..78b3dab99 100644 --- a/lenskit/tests/pipeline/test_save_load.py +++ b/lenskit/tests/pipeline/test_save_load.py @@ -4,6 +4,7 @@ from lenskit.pipeline import InputNode, Node, Pipeline from lenskit.pipeline.components import AutoConfig +from lenskit.pipeline.config import PipelineConfig from lenskit.pipeline.nodes import ComponentNode @@ -115,3 +116,36 @@ def test_configurable_component(): assert r2.connections == {"msg": "msg"} assert p2.run("prefix", msg="HACKEM MUCHE") == "scroll named HACKEM MUCHE" + + +def negative(x: int) -> int: + return -x + + +def double(x: int) -> int: + return x * 2 + + +def add(x: int, y: int) -> int: + return x + y + + +def test_save_with_fallback(): + pipe = Pipeline() + a = pipe.create_input("a", int) + b = pipe.create_input("b", int) + + nd = pipe.add_component("double", double, x=a) + nn = pipe.add_component("negate", negative, x=a) + fb = pipe.use_first_of("fill-operand", b, nn) + pipe.add_component("add", add, x=nd, y=fb) + + cfg = pipe.get_config() + json = cfg.model_dump_json(exclude_none=True) + print(json) + c2 = PipelineConfig.model_validate_json(json) + + p2 = Pipeline.from_config(c2) + + # 3 * 2 + -3 = 3 + assert p2.run("fill-operand", "add", a=3) == (-3, 3) From e43f51b9efbeccbe89a5bc520c9eee86db7bbab1 Mon Sep 17 00:00:00 2001 From: Michael Ekstrand Date: Thu, 8 Aug 2024 14:59:05 -0400 Subject: [PATCH 14/24] document lack of support for serializing literal nodes --- lenskit/lenskit/pipeline/__init__.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/lenskit/lenskit/pipeline/__init__.py b/lenskit/lenskit/pipeline/__init__.py index fc3eda34f..c02d62652 100644 --- a/lenskit/lenskit/pipeline/__init__.py +++ b/lenskit/lenskit/pipeline/__init__.py @@ -149,6 +149,13 @@ def create_input(self, name: str, *types: type[T] | None) -> Node[T]: return node def literal(self, value: T) -> LiteralNode[T]: + """ + Create a literal node (a node with a fixed value). + + .. note:: + Literal nodes cannot be serialized witih :meth:`get_config` or + :meth:`save_config`. + """ name = str(uuid4()) node = LiteralNode(name, value, types=set([type(value)])) self._nodes[name] = node @@ -407,6 +414,11 @@ def get_config(self) -> PipelineConfig: although the configuration may include things such as paths to checkpoints to load such parameters, depending on the design of the components in the pipeline. + + .. note:: + Literal nodes (from :meth:`literal`, or literal values wired to + inputs) cannot be serialized, and this method will fail if they + are present in the pipeline. """ config = PipelineConfig() for node in self.nodes: From 4a037629621ae039cd39e84bd621dfff6934f864 Mon Sep 17 00:00:00 2001 From: Michael Ekstrand Date: Thu, 8 Aug 2024 15:03:16 -0400 Subject: [PATCH 15/24] add pipeline metadata --- lenskit/lenskit/pipeline/__init__.py | 18 +++++++++++++++--- lenskit/lenskit/pipeline/config.py | 10 ++++++++++ lenskit/tests/pipeline/test_save_load.py | 3 ++- 3 files changed, 27 insertions(+), 4 deletions(-) diff --git a/lenskit/lenskit/pipeline/__init__.py b/lenskit/lenskit/pipeline/__init__.py index c02d62652..8ebb2f23c 100644 --- a/lenskit/lenskit/pipeline/__init__.py +++ b/lenskit/lenskit/pipeline/__init__.py @@ -28,7 +28,7 @@ TrainableComponent, instantiate_component, ) -from .config import PipelineComponent, PipelineConfig, PipelineInput +from .config import PipelineComponent, PipelineConfig, PipelineInput, PipelineMeta from .nodes import ND, ComponentNode, FallbackNode, InputNode, LiteralNode, Node from .state import PipelineState @@ -62,14 +62,25 @@ class Pipeline: If you have a scoring model and just want to generate recommenations with a default setup and minimal configuration, see :func:`topn_pipeline`. + + Args: + name: + A name for the pipeline. + version: + A numeric version for the pipeline. """ + name: str | None = None + version: str | None = None + _nodes: dict[str, Node[Any]] _aliases: dict[str, Node[Any]] _defaults: dict[str, Node[Any] | Any] _components: dict[str, Component[Any]] - def __init__(self): + def __init__(self, name: str | None = None, version: str | None = None): + self.name = name + self.version = version self._nodes = {} self._aliases = {} self._defaults = {} @@ -420,7 +431,8 @@ def get_config(self) -> PipelineConfig: inputs) cannot be serialized, and this method will fail if they are present in the pipeline. """ - config = PipelineConfig() + meta = PipelineMeta(name=self.name, version=self.version) + config = PipelineConfig(meta=meta) for node in self.nodes: match node: case InputNode(): diff --git a/lenskit/lenskit/pipeline/config.py b/lenskit/lenskit/pipeline/config.py index 043fdffb8..a6dc4ba33 100644 --- a/lenskit/lenskit/pipeline/config.py +++ b/lenskit/lenskit/pipeline/config.py @@ -23,10 +23,20 @@ class PipelineConfig(BaseModel): not contain the """ + meta: PipelineMeta inputs: list[PipelineInput] = Field(default_factory=list) components: OrderedDict[str, PipelineComponent] = Field(default_factory=OrderedDict) +class PipelineMeta(BaseModel): + """ + Pipeline metadata. + """ + + name: str | None = Field(default=None) + version: str | None = Field(default=None) + + class PipelineInput(BaseModel): name: str "The name for this input." diff --git a/lenskit/tests/pipeline/test_save_load.py b/lenskit/tests/pipeline/test_save_load.py index 78b3dab99..f7d9a7286 100644 --- a/lenskit/tests/pipeline/test_save_load.py +++ b/lenskit/tests/pipeline/test_save_load.py @@ -20,11 +20,12 @@ def __call__(self, msg: str) -> str: def test_serialize_input(): "serialize with one input node" - pipe = Pipeline() + pipe = Pipeline("test") pipe.create_input("user", int, str) cfg = pipe.get_config() print(cfg) + assert cfg.meta.name == "test" assert len(cfg.inputs) == 1 assert cfg.inputs[0].name == "user" assert cfg.inputs[0].types == {"int", "str"} From 1b3c83892aabe75fccaaf8da5aa931cd11234500 Mon Sep 17 00:00:00 2001 From: Michael Ekstrand Date: Thu, 8 Aug 2024 15:26:12 -0400 Subject: [PATCH 16/24] add configuration hashing --- lenskit/lenskit/pipeline/__init__.py | 32 +++++++++++++++++++++++- lenskit/lenskit/pipeline/config.py | 11 ++++++-- lenskit/tests/pipeline/test_save_load.py | 23 +++++++++++++++++ 3 files changed, 63 insertions(+), 3 deletions(-) diff --git a/lenskit/lenskit/pipeline/__init__.py b/lenskit/lenskit/pipeline/__init__.py index 8ebb2f23c..e02dd3749 100644 --- a/lenskit/lenskit/pipeline/__init__.py +++ b/lenskit/lenskit/pipeline/__init__.py @@ -12,6 +12,8 @@ from __future__ import annotations import logging +import warnings +from hashlib import sha256 from types import FunctionType from typing import Literal, cast from uuid import uuid4 @@ -414,7 +416,7 @@ def clone(self, *, params: bool = False) -> Pipeline: return clone - def get_config(self) -> PipelineConfig: + def get_config(self, *, include_hash=True) -> PipelineConfig: """ Get this pipeline's configuration for serialization. The configuration consists of all inputs and components along with their configurations @@ -432,6 +434,9 @@ def get_config(self) -> PipelineConfig: are present in the pipeline. """ meta = PipelineMeta(name=self.name, version=self.version) + if include_hash: + meta.hash = self.config_hash() + config = PipelineConfig(meta=meta) for node in self.nodes: match node: @@ -450,6 +455,25 @@ def get_config(self) -> PipelineConfig: return config + def config_hash(self) -> str: + """ + Get a hash of the pipeline's configuration to uniquely identify it for + logging, version control, or other purposes. + + The precise algorithm to compute the hash is not guaranteed, except that + the same configuration with the same version of LensKit and its + dependencies will produce the same hash. In LensKit 2024.1, the + configuration hash is computed by computing the JSON serialization of + the pipeline configuration *without* a hash returning the hex-encoded + SHA1 hash of that configuration. + """ + # get the config *without* a hash + cfg = self.get_config(include_hash=False) + json = cfg.model_dump_json(exclude_none=True) + h = sha256() + h.update(json.encode()) + return h.hexdigest() + @classmethod def from_config(cls, config: object) -> Self: cfg = PipelineConfig.model_validate(config) @@ -488,6 +512,12 @@ def from_config(cls, config: object) -> Self: elif not comp.code.startswith("@"): raise RuntimeError(f"component {name} inputs must be dict, not list") + if cfg.meta.hash is not None: + h2 = pipe.config_hash() + if h2 != cfg.meta.hash: + _log.warn("loaded pipeline does not match hash") + warnings.warn("loaded pipeline config does not match hash") + return pipe def train(self, data: Dataset) -> None: diff --git a/lenskit/lenskit/pipeline/config.py b/lenskit/lenskit/pipeline/config.py index a6dc4ba33..98b14d337 100644 --- a/lenskit/lenskit/pipeline/config.py +++ b/lenskit/lenskit/pipeline/config.py @@ -33,8 +33,15 @@ class PipelineMeta(BaseModel): Pipeline metadata. """ - name: str | None = Field(default=None) - version: str | None = Field(default=None) + name: str | None = None + "The pipeline name." + version: str | None = None + "The pipeline version." + hash: str | None = None + """ + The pipeline configuration hash. This is optional, particularly when + hand-crafting pipeline configuration files. + """ class PipelineInput(BaseModel): diff --git a/lenskit/tests/pipeline/test_save_load.py b/lenskit/tests/pipeline/test_save_load.py index f7d9a7286..f720e3ccc 100644 --- a/lenskit/tests/pipeline/test_save_load.py +++ b/lenskit/tests/pipeline/test_save_load.py @@ -2,6 +2,8 @@ from typing_extensions import assert_type +from pytest import fail, warns + from lenskit.pipeline import InputNode, Node, Pipeline from lenskit.pipeline.components import AutoConfig from lenskit.pipeline.config import PipelineConfig @@ -118,6 +120,10 @@ def test_configurable_component(): assert p2.run("prefix", msg="HACKEM MUCHE") == "scroll named HACKEM MUCHE" + print("hash:", pipe.config_hash()) + assert pipe.config_hash() is not None + assert p2.config_hash() == pipe.config_hash() + def negative(x: int) -> int: return -x @@ -150,3 +156,20 @@ def test_save_with_fallback(): # 3 * 2 + -3 = 3 assert p2.run("fill-operand", "add", a=3) == (-3, 3) + + +def test_hash_validate(): + pipe = Pipeline() + msg = pipe.create_input("msg", str) + + pfx = Prefixer("scroll named ") + pipe.add_component("prefix", pfx, msg=msg) + + cfg = pipe.get_config() + print("initial config:", cfg.model_dump_json(indent=2)) + assert cfg.meta.hash is not None + cfg.components["prefix"].config["prefix"] = "scroll called " # type: ignore + print("modified config:", cfg.model_dump_json(indent=2)) + + with warns(UserWarning): + Pipeline.from_config(cfg) From fa5ad118d93ba9e5302b976c6d7867055b5d6f9b Mon Sep 17 00:00:00 2001 From: Michael Ekstrand Date: Thu, 8 Aug 2024 15:30:10 -0400 Subject: [PATCH 17/24] clarify from_config comments --- lenskit/lenskit/pipeline/__init__.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/lenskit/lenskit/pipeline/__init__.py b/lenskit/lenskit/pipeline/__init__.py index e02dd3749..987eea335 100644 --- a/lenskit/lenskit/pipeline/__init__.py +++ b/lenskit/lenskit/pipeline/__init__.py @@ -484,6 +484,10 @@ def from_config(cls, config: object) -> Self: types += [parse_type_string(t) for t in inpt.types] pipe.create_input(inpt.name, *types) + # we now add the components and other nodes in multiple passes to ensure + # that nodes are available before they are wired (since `connect` can + # introduce out-of-order dependencies). + # pass 1: add components to_wire: list[PipelineComponent] = [] for name, comp in cfg.components.items(): @@ -504,7 +508,7 @@ def from_config(cls, config: object) -> Self: elif comp.code.startswith("@"): raise RuntimeError(f"unsupported meta-component {comp.code}") - # pass 2: wiring + # pass 3: wiring for name, comp in cfg.components.items(): if isinstance(comp.inputs, dict): inputs = {n: pipe.node(t) for (n, t) in comp.inputs.items()} From ae4b9d8fc8d839f4d68552ee8ba8cd584da74566 Mon Sep 17 00:00:00 2001 From: Michael Ekstrand Date: Thu, 8 Aug 2024 15:33:53 -0400 Subject: [PATCH 18/24] add pipeline error & warning classes --- lenskit/lenskit/pipeline/__init__.py | 30 ++++++++++++++++++++++-- lenskit/tests/pipeline/test_save_load.py | 4 ++-- 2 files changed, 30 insertions(+), 4 deletions(-) diff --git a/lenskit/lenskit/pipeline/__init__.py b/lenskit/lenskit/pipeline/__init__.py index 987eea335..a6886a26d 100644 --- a/lenskit/lenskit/pipeline/__init__.py +++ b/lenskit/lenskit/pipeline/__init__.py @@ -55,6 +55,32 @@ T5 = TypeVar("T5") +class PipelineError(Exception): + """ + Pipeline configuration errors. + + .. note:: + + This exception is only to note problems with the pipeline configuration + and structure (e.g. circular dependencies). Errors *running* the + pipeline are raised as-is. + """ + + +class PipelineWarning(Warning): + """ + Pipeline configuration and setup warnings. We also emit warnings to the + logger in many cases, but this allows critical ones to be visible even if + the client code has not enabled logging. + + .. note:: + + This warning is only to note problems with the pipeline configuration + and structure (e.g. circular dependencies). Errors *running* the + pipeline are raised as-is. + """ + + class Pipeline: """ LensKit recommendation pipeline. This is the core abstraction for using @@ -416,7 +442,7 @@ def clone(self, *, params: bool = False) -> Pipeline: return clone - def get_config(self, *, include_hash=True) -> PipelineConfig: + def get_config(self, *, include_hash: bool = True) -> PipelineConfig: """ Get this pipeline's configuration for serialization. The configuration consists of all inputs and components along with their configurations @@ -520,7 +546,7 @@ def from_config(cls, config: object) -> Self: h2 = pipe.config_hash() if h2 != cfg.meta.hash: _log.warn("loaded pipeline does not match hash") - warnings.warn("loaded pipeline config does not match hash") + warnings.warn("loaded pipeline config does not match hash", PipelineWarning) return pipe diff --git a/lenskit/tests/pipeline/test_save_load.py b/lenskit/tests/pipeline/test_save_load.py index f720e3ccc..8272fa766 100644 --- a/lenskit/tests/pipeline/test_save_load.py +++ b/lenskit/tests/pipeline/test_save_load.py @@ -4,7 +4,7 @@ from pytest import fail, warns -from lenskit.pipeline import InputNode, Node, Pipeline +from lenskit.pipeline import InputNode, Node, Pipeline, PipelineWarning from lenskit.pipeline.components import AutoConfig from lenskit.pipeline.config import PipelineConfig from lenskit.pipeline.nodes import ComponentNode @@ -171,5 +171,5 @@ def test_hash_validate(): cfg.components["prefix"].config["prefix"] = "scroll called " # type: ignore print("modified config:", cfg.model_dump_json(indent=2)) - with warns(UserWarning): + with warns(PipelineWarning): Pipeline.from_config(cfg) From 2a598cbd864bded2f9253f042c13d349389fbfb1 Mon Sep 17 00:00:00 2001 From: Michael Ekstrand Date: Thu, 8 Aug 2024 15:38:29 -0400 Subject: [PATCH 19/24] update to properly throw and test for pipeline errors --- lenskit/lenskit/pipeline/__init__.py | 10 ++++++---- lenskit/lenskit/pipeline/runner.py | 8 ++++---- lenskit/tests/pipeline/test_pipeline.py | 10 +++++----- 3 files changed, 15 insertions(+), 13 deletions(-) diff --git a/lenskit/lenskit/pipeline/__init__.py b/lenskit/lenskit/pipeline/__init__.py index a6886a26d..f6f91b156 100644 --- a/lenskit/lenskit/pipeline/__init__.py +++ b/lenskit/lenskit/pipeline/__init__.py @@ -36,6 +36,8 @@ __all__ = [ "Pipeline", + "PipelineError", + "PipelineWarning", "Node", "topn_pipeline", "Component", @@ -529,10 +531,10 @@ def from_config(cls, config: object) -> Self: for name, comp in cfg.components.items(): if comp.code == "@use-first-of": if not isinstance(comp.inputs, list): - raise RuntimeError("@use-first-of must have input list, not dict") + raise PipelineError("@use-first-of must have input list, not dict") pipe.use_first_of(name, *[pipe.node(n) for n in comp.inputs]) elif comp.code.startswith("@"): - raise RuntimeError(f"unsupported meta-component {comp.code}") + raise PipelineError(f"unsupported meta-component {comp.code}") # pass 3: wiring for name, comp in cfg.components.items(): @@ -540,7 +542,7 @@ def from_config(cls, config: object) -> Self: inputs = {n: pipe.node(t) for (n, t) in comp.inputs.items()} pipe.connect(name, **inputs) elif not comp.code.startswith("@"): - raise RuntimeError(f"component {name} inputs must be dict, not list") + raise PipelineError(f"component {name} inputs must be dict, not list") if cfg.meta.hash is not None: h2 = pipe.config_hash() @@ -682,7 +684,7 @@ def _check_available_name(self, name: str) -> None: def _check_member_node(self, node: Node[Any]) -> None: nw = self._nodes.get(node.name) if nw is not node: - raise RuntimeError(f"node {node} not in pipeline") + raise PipelineError(f"node {node} not in pipeline") def _clear_caches(self): pass diff --git a/lenskit/lenskit/pipeline/runner.py b/lenskit/lenskit/pipeline/runner.py index d4d601693..668101c12 100644 --- a/lenskit/lenskit/pipeline/runner.py +++ b/lenskit/lenskit/pipeline/runner.py @@ -12,7 +12,7 @@ import logging from typing import Any, Literal, TypeAlias -from . import Pipeline +from . import Pipeline, PipelineError from .components import Component from .nodes import ComponentNode, FallbackNode, InputNode, LiteralNode, Node from .types import is_compatible_data @@ -48,7 +48,7 @@ def run(self, node: Node[Any], *, required: bool = True) -> Any: if status == "finished": return self.state[node.name] elif status == "in-progress": - raise RuntimeError(f"pipeline cycle encountered at {node}") + raise PipelineError(f"pipeline cycle encountered at {node}") elif status == "failed": # pragma: nocover raise RuntimeError(f"{node} previously failed") @@ -81,12 +81,12 @@ def _run_node(self, node: Node[Any], required: bool) -> None: case FallbackNode(name, alts): self._run_fallback(name, alts) case _: # pragma: nocover - raise RuntimeError(f"invalid node {node}") + raise PipelineError(f"invalid node {node}") def _inject_input(self, name: str, types: set[type] | None, required: bool) -> None: val = self.inputs.get(name, None) if val is None and required and types and not is_compatible_data(None, *types): - raise RuntimeError(f"input {name} not specified") + raise PipelineError(f"input {name} not specified") if val is not None and types and not is_compatible_data(val, *types): raise TypeError(f"invalid data for input {name} (expected {types}, got {type(val)})") diff --git a/lenskit/tests/pipeline/test_pipeline.py b/lenskit/tests/pipeline/test_pipeline.py index 03a27117a..ad1a7e1eb 100644 --- a/lenskit/tests/pipeline/test_pipeline.py +++ b/lenskit/tests/pipeline/test_pipeline.py @@ -13,7 +13,7 @@ from pytest import fail, raises from lenskit.data import Dataset, Vocabulary -from lenskit.pipeline import InputNode, Node, Pipeline +from lenskit.pipeline import InputNode, Node, Pipeline, PipelineError from lenskit.pipeline.components import TrainableComponent @@ -133,7 +133,7 @@ def incr(msg: str) -> str: node = pipe.add_component("return", incr, msg=msg) - with raises(RuntimeError, match="not specified"): + with raises(PipelineError, match="not specified"): pipe.run(node) @@ -245,7 +245,7 @@ def add(x: int, y: int) -> int: na = pipe.add_component("add", add, x=nd, y=b) pipe.connect(nd, x=na) - with raises(RuntimeError, match="cycle"): + with raises(PipelineError, match="cycle"): pipe.run(a=1, b=7) @@ -275,7 +275,7 @@ def add(x: int, y: int) -> int: assert pipe.run(nt, a=3, b=7) == 9 # old node should be missing! - with raises(RuntimeError, match="not in pipeline"): + with raises(PipelineError, match="not in pipeline"): pipe.run(nd, a=3, b=7) @@ -443,7 +443,7 @@ def add(x: int, y: int) -> int: nd = pipe.add_component("double", double, x=a) na = pipe.add_component("add", add, x=nd, y=b) - with raises(RuntimeError, match=r"input.*not specified"): + with raises(PipelineError, match=r"input.*not specified"): pipe.run(na, a=3) # missing inputs only matter if they are required From cf2ff62dfc41340a1025c87a4f0650049f859b91 Mon Sep 17 00:00:00 2001 From: Michael Ekstrand Date: Thu, 8 Aug 2024 15:48:00 -0400 Subject: [PATCH 20/24] warn when parameter has no annotation --- lenskit/lenskit/pipeline/nodes.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/lenskit/lenskit/pipeline/nodes.py b/lenskit/lenskit/pipeline/nodes.py index 561ff8950..07cdab78e 100644 --- a/lenskit/lenskit/pipeline/nodes.py +++ b/lenskit/lenskit/pipeline/nodes.py @@ -95,7 +95,14 @@ def __init__(self, name: str, component: Component[ND]): else: self.types = set([sig.return_annotation]) - self.inputs = { - param.name: None if param.annotation == Signature.empty else param.annotation - for param in sig.parameters.values() - } + self.inputs = {} + for param in sig.parameters.values(): + if param.annotation == Signature.empty: + warnings.warn( + f"parameter {param.name} of component {component} has no type annotation", + TypecheckWarning, + 2, + ) + self.inputs[param.name] = None + else: + self.inputs[param.name] = param.annotation From d218f10a1e5d2d87bf4d019c4fb9aee7f0ebed2d Mon Sep 17 00:00:00 2001 From: Michael Ekstrand Date: Thu, 8 Aug 2024 15:48:22 -0400 Subject: [PATCH 21/24] fix deprecated warning method --- lenskit/lenskit/pipeline/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lenskit/lenskit/pipeline/__init__.py b/lenskit/lenskit/pipeline/__init__.py index f6f91b156..70de7ec6c 100644 --- a/lenskit/lenskit/pipeline/__init__.py +++ b/lenskit/lenskit/pipeline/__init__.py @@ -547,7 +547,7 @@ def from_config(cls, config: object) -> Self: if cfg.meta.hash is not None: h2 = pipe.config_hash() if h2 != cfg.meta.hash: - _log.warn("loaded pipeline does not match hash") + _log.warning("loaded pipeline does not match hash") warnings.warn("loaded pipeline config does not match hash", PipelineWarning) return pipe From 6c93561ce9f3c930975fd702091ad0263c9a5599 Mon Sep 17 00:00:00 2001 From: Michael Ekstrand Date: Thu, 8 Aug 2024 15:57:19 -0400 Subject: [PATCH 22/24] document @-nodes --- lenskit/lenskit/pipeline/config.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/lenskit/lenskit/pipeline/config.py b/lenskit/lenskit/pipeline/config.py index 98b14d337..9548557d4 100644 --- a/lenskit/lenskit/pipeline/config.py +++ b/lenskit/lenskit/pipeline/config.py @@ -65,6 +65,10 @@ class PipelineComponent(BaseModel): """ The path to the component's implementation, either a class or a function. This is a Python qualified path of the form ``module:name``. + + Special nodes, like :class:`lenskit.pipeline.Pipeline.use_first_of`, are + serialized as components whose code is a magic name beginning with ``@`` + (e.g. ``@use-first-of``). """ config: dict[str, object] | None = Field(default=None) From 48e4e27230e03d0d63983e8a71d2605ae384f49f Mon Sep 17 00:00:00 2001 From: Michael Ekstrand Date: Thu, 8 Aug 2024 15:58:51 -0400 Subject: [PATCH 23/24] fix SHA docs --- lenskit/lenskit/pipeline/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lenskit/lenskit/pipeline/__init__.py b/lenskit/lenskit/pipeline/__init__.py index 70de7ec6c..76f00ccb9 100644 --- a/lenskit/lenskit/pipeline/__init__.py +++ b/lenskit/lenskit/pipeline/__init__.py @@ -493,7 +493,7 @@ def config_hash(self) -> str: dependencies will produce the same hash. In LensKit 2024.1, the configuration hash is computed by computing the JSON serialization of the pipeline configuration *without* a hash returning the hex-encoded - SHA1 hash of that configuration. + SHA256 hash of that configuration. """ # get the config *without* a hash cfg = self.get_config(include_hash=False) From 886f35b5574e79f801ca721b4308c909bd582039 Mon Sep 17 00:00:00 2001 From: Michael Ekstrand Date: Thu, 8 Aug 2024 16:02:15 -0400 Subject: [PATCH 24/24] only serialize once to insert hash into config --- lenskit/lenskit/pipeline/__init__.py | 14 +++++--------- lenskit/lenskit/pipeline/config.py | 11 +++++++++++ 2 files changed, 16 insertions(+), 9 deletions(-) diff --git a/lenskit/lenskit/pipeline/__init__.py b/lenskit/lenskit/pipeline/__init__.py index 76f00ccb9..aecd811ef 100644 --- a/lenskit/lenskit/pipeline/__init__.py +++ b/lenskit/lenskit/pipeline/__init__.py @@ -13,7 +13,6 @@ import logging import warnings -from hashlib import sha256 from types import FunctionType from typing import Literal, cast from uuid import uuid4 @@ -30,7 +29,7 @@ TrainableComponent, instantiate_component, ) -from .config import PipelineComponent, PipelineConfig, PipelineInput, PipelineMeta +from .config import PipelineComponent, PipelineConfig, PipelineInput, PipelineMeta, hash_config from .nodes import ND, ComponentNode, FallbackNode, InputNode, LiteralNode, Node from .state import PipelineState @@ -462,9 +461,6 @@ def get_config(self, *, include_hash: bool = True) -> PipelineConfig: are present in the pipeline. """ meta = PipelineMeta(name=self.name, version=self.version) - if include_hash: - meta.hash = self.config_hash() - config = PipelineConfig(meta=meta) for node in self.nodes: match node: @@ -481,6 +477,9 @@ def get_config(self, *, include_hash: bool = True) -> PipelineConfig: case _: # pragma: nocover raise RuntimeError(f"invalid node {node}") + if include_hash: + config.meta.hash = hash_config(config) + return config def config_hash(self) -> str: @@ -497,10 +496,7 @@ def config_hash(self) -> str: """ # get the config *without* a hash cfg = self.get_config(include_hash=False) - json = cfg.model_dump_json(exclude_none=True) - h = sha256() - h.update(json.encode()) - return h.hexdigest() + return hash_config(cfg) @classmethod def from_config(cls, config: object) -> Self: diff --git a/lenskit/lenskit/pipeline/config.py b/lenskit/lenskit/pipeline/config.py index 9548557d4..dd96bd26c 100644 --- a/lenskit/lenskit/pipeline/config.py +++ b/lenskit/lenskit/pipeline/config.py @@ -6,6 +6,7 @@ from __future__ import annotations from collections import OrderedDict +from hashlib import sha256 from types import FunctionType from pydantic import BaseModel, Field @@ -96,3 +97,13 @@ def from_node(cls, node: ComponentNode[Any]) -> Self: config = comp.get_config() if isinstance(comp, ConfigurableComponent) else None return cls(code=code, config=config, inputs=node.connections) + + +def hash_config(config: BaseModel) -> str: + """ + Compute the hash of a configuration model. + """ + json = config.model_dump_json(exclude_none=True) + h = sha256() + h.update(json.encode()) + return h.hexdigest()