Serialize full pipeline configurations #469

Merged (24 commits, Aug 8, 2024)

Commits
4538ba8  add Pydantic dependency (mdekstrand, Aug 7, 2024)
fc61a1d  add string type representations (mdekstrand, Aug 7, 2024)
7943203  move and split pipeline clone / config (mdekstrand, Aug 7, 2024)
7577d1e  add pydantic to docs (mdekstrand, Aug 7, 2024)
176df23  add initial config model + input node + test (mdekstrand, Aug 7, 2024)
ff4d598  store types in sets (mdekstrand, Aug 7, 2024)
8ee677a  fix type error in pipeline construction (mdekstrand, Aug 7, 2024)
1ba6aa1  Serialize and round-trip component inputs (mdekstrand, Aug 7, 2024)
0691cce  parse qualified (colon-separated) types (mdekstrand, Aug 7, 2024)
b0f44d2  serialize to qualified names (mdekstrand, Aug 7, 2024)
8fbe084  support round-trip configuration (mdekstrand, Aug 7, 2024)
8c398df  refactor configuration with pattern-matching (mdekstrand, Aug 8, 2024)
6153b6e  implement fallback node serialization (mdekstrand, Aug 8, 2024)
e43f51b  document lack of support for serializing literal nodes (mdekstrand, Aug 8, 2024)
4a03762  add pipeline metadata (mdekstrand, Aug 8, 2024)
1b3c838  add configuration hashing (mdekstrand, Aug 8, 2024)
fa5ad11  clarify from_config comments (mdekstrand, Aug 8, 2024)
ae4b9d8  add pipeline error & warning classes (mdekstrand, Aug 8, 2024)
2a598cb  update to properly throw and test for pipeline errors (mdekstrand, Aug 8, 2024)
cf2ff62  warn when parameter has no annotation (mdekstrand, Aug 8, 2024)
d218f10  fix deprecated warning method (mdekstrand, Aug 8, 2024)
6c93561  document @-nodes (mdekstrand, Aug 8, 2024)
48e4e27  fix SHA docs (mdekstrand, Aug 8, 2024)
886f35b  only serialize once to insert hash into config (mdekstrand, Aug 8, 2024)
1 change: 1 addition & 0 deletions docs/conf.py
@@ -101,6 +101,7 @@
"manylog": ("https://manylog.readthedocs.io/en/latest/", None),
"torch": ("https://pytorch.org/docs/stable/", None),
"implicit": ("https://benfred.github.io/implicit/", None),
"pydantic": ("https://docs.pydantic.dev/latest/", None),
}

bibtex_bibfiles = ["lenskit.bib"]
168 changes: 162 additions & 6 deletions lenskit/lenskit/pipeline/__init__.py
@@ -12,30 +12,37 @@
from __future__ import annotations

import logging
import warnings
from types import FunctionType
from typing import Literal, cast
from uuid import uuid4

from typing_extensions import Any, LiteralString, TypeVar, overload
from typing_extensions import Any, LiteralString, Self, TypeVar, overload

from lenskit.data import Dataset
from lenskit.pipeline.types import parse_type_string

from .components import (
AutoConfig, # noqa: F401 # type: ignore
Component,
ConfigurableComponent,
TrainableComponent,
instantiate_component,
)
from .config import PipelineComponent, PipelineConfig, PipelineInput, PipelineMeta, hash_config
from .nodes import ND, ComponentNode, FallbackNode, InputNode, LiteralNode, Node
from .state import PipelineState

__all__ = [
"Pipeline",
"PipelineError",
"PipelineWarning",
"Node",
"topn_pipeline",
"Component",
"ConfigurableComponent",
"TrainableComponent",
"PipelineConfig",
]

_log = logging.getLogger(__name__)
@@ -49,6 +56,32 @@
T5 = TypeVar("T5")


class PipelineError(Exception):
"""
Pipeline configuration errors.

.. note::

This exception is only to note problems with the pipeline configuration
and structure (e.g. circular dependencies). Errors *running* the
pipeline are raised as-is.
"""


class PipelineWarning(Warning):
"""
Pipeline configuration and setup warnings. We also emit warnings to the
logger in many cases, but this allows critical ones to be visible even if
the client code has not enabled logging.

.. note::

This warning is only to note problems with the pipeline configuration
and structure (e.g. circular dependencies). Errors *running* the
pipeline are raised as-is.
"""


class Pipeline:
"""
LensKit recommendation pipeline. This is the core abstraction for using
@@ -58,14 +91,25 @@ class Pipeline:

If you have a scoring model and just want to generate recommendations with a
default setup and minimal configuration, see :func:`topn_pipeline`.

Args:
name:
A name for the pipeline.
version:
A numeric version for the pipeline.
"""

name: str | None = None
version: str | None = None

_nodes: dict[str, Node[Any]]
_aliases: dict[str, Node[Any]]
_defaults: dict[str, Node[Any] | Any]
_components: dict[str, Component[Any]]

def __init__(self):
def __init__(self, name: str | None = None, version: str | None = None):
self.name = name
self.version = version
self._nodes = {}
self._aliases = {}
self._defaults = {}
@@ -139,12 +183,19 @@ def create_input(self, name: str, *types: type[T] | None) -> Node[T]:
"""
self._check_available_name(name)

node = InputNode[Any](name, types=set((t if t is not None else type[None]) for t in types))
node = InputNode[Any](name, types=set((t if t is not None else type(None)) for t in types))
mdekstrand (author) commented: Bugfix for existing code around None.

self._nodes[name] = node
self._clear_caches()
return node
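
For readers of the inline comment above, a minimal sketch (not part of the diff) of the distinction behind the change from type[None] to type(None):

# type(None) is the actual NoneType class, so it works in runtime
# isinstance() checks; type[None] is a typing-level generic alias, not a class.
NoneType = type(None)
assert isinstance(None, NoneType)
# isinstance(None, type[None]) raises TypeError, because a parameterized
# generic cannot be used with isinstance().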

def literal(self, value: T) -> LiteralNode[T]:
"""
Create a literal node (a node with a fixed value).

.. note::
Literal nodes cannot be serialized with :meth:`get_config` or
:meth:`save_config`.
"""
name = str(uuid4())
node = LiteralNode(name, value, types=set([type(value)]))
self._nodes[name] = node
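
An illustration of the note above, using the behavior of get_config later in this diff (literal nodes cause serialization to fail):

# A pipeline containing a literal node cannot be converted to a config.
pipe = Pipeline()
pipe.literal(42)
try:
    pipe.get_config()
except RuntimeError as e:
    print(e)  # "literal nodes cannot be serialized to config"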
@@ -381,9 +432,9 @@ def clone(self, *, params: bool = False) -> Pipeline:
if isinstance(comp, FunctionType):
comp = comp
elif isinstance(comp, ConfigurableComponent):
comp = comp.__class__.from_config(comp.get_config())
comp = comp.__class__.from_config(comp.get_config()) # type: ignore
else:
comp = comp.__class__()
comp = comp.__class__() # type: ignore
cn = clone.add_component(node.name, comp) # type: ignore
for wn, wt in wiring.items():
clone.connect(cn, **{wn: clone.node(wt)})
@@ -392,6 +443,111 @@

return clone

def get_config(self, *, include_hash: bool = True) -> PipelineConfig:
"""
Get this pipeline's configuration for serialization. The configuration
consists of all inputs and components along with their configurations
and input connections. It can be serialized to disk (in JSON, YAML, or
a similar format) to save a pipeline.

The configuration does **not** include any trained parameter values,
although the configuration may include things such as paths to
checkpoints to load such parameters, depending on the design of the
components in the pipeline.

.. note::
Literal nodes (from :meth:`literal`, or literal values wired to
inputs) cannot be serialized, and this method will fail if they
are present in the pipeline.
"""
meta = PipelineMeta(name=self.name, version=self.version)
config = PipelineConfig(meta=meta)
for node in self.nodes:
match node:
case InputNode():
config.inputs.append(PipelineInput.from_node(node))
case LiteralNode():
raise RuntimeError("literal nodes cannot be serialized to config")
case ComponentNode(name):
config.components[name] = PipelineComponent.from_node(node)
case FallbackNode(name, alternatives):
config.components[name] = PipelineComponent(
code="@use-first-of", inputs=[n.name for n in alternatives]
)
case _: # pragma: nocover
raise RuntimeError(f"invalid node {node}")

if include_hash:
config.meta.hash = hash_config(config)

return config
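
A possible usage sketch for get_config (not part of the diff). It assumes PipelineConfig is a Pydantic v2 model, as the model_validate call in from_config below suggests, so model_dump_json() is available:

# Hypothetical usage: persist a pipeline's configuration as JSON.
pipe = Pipeline(name="example", version="1")
# ... create inputs, add components, and connect them ...
cfg = pipe.get_config()
with open("pipeline.json", "w") as f:
    f.write(cfg.model_dump_json(indent=2))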

def config_hash(self) -> str:
"""
Get a hash of the pipeline's configuration to uniquely identify it for
logging, version control, or other purposes.

The precise algorithm to compute the hash is not guaranteed, except that
the same configuration with the same version of LensKit and its
dependencies will produce the same hash. In LensKit 2024.1, the
configuration hash is computed by taking the JSON serialization of
the pipeline configuration *without* a hash and returning the hex-encoded
SHA256 hash of that serialization.
"""
# get the config *without* a hash
cfg = self.get_config(include_hash=False)
return hash_config(cfg)
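
The hash_config helper is defined in .config and its implementation is not shown in this diff; a sketch consistent with the docstring above (assuming Pydantic JSON serialization and a SHA-256 hex digest) might look like:

import hashlib

def hash_config_sketch(cfg: PipelineConfig) -> str:
    # Assumed behavior: hex-encoded SHA-256 of the config's JSON serialization.
    json_text = cfg.model_dump_json()
    return hashlib.sha256(json_text.encode("utf-8")).hexdigest()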

@classmethod
def from_config(cls, config: object) -> Self:
cfg = PipelineConfig.model_validate(config)
pipe = cls()
for inpt in cfg.inputs:
types: list[type[Any] | None] = []
if inpt.types is not None:
types += [parse_type_string(t) for t in inpt.types]
pipe.create_input(inpt.name, *types)

# we now add the components and other nodes in multiple passes to ensure
# that nodes are available before they are wired (since `connect` can
# introduce out-of-order dependencies).

# pass 1: add components
to_wire: list[PipelineComponent] = []
for name, comp in cfg.components.items():
if comp.code.startswith("@"):
# ignore special nodes in first pass
continue

obj = instantiate_component(comp.code, comp.config)
pipe.add_component(name, obj)
to_wire.append(comp)

# pass 2: add meta nodes
for name, comp in cfg.components.items():
if comp.code == "@use-first-of":
if not isinstance(comp.inputs, list):
raise PipelineError("@use-first-of must have input list, not dict")
pipe.use_first_of(name, *[pipe.node(n) for n in comp.inputs])
elif comp.code.startswith("@"):
raise PipelineError(f"unsupported meta-component {comp.code}")

# pass 3: wiring
for name, comp in cfg.components.items():
if isinstance(comp.inputs, dict):
inputs = {n: pipe.node(t) for (n, t) in comp.inputs.items()}
pipe.connect(name, **inputs)
elif not comp.code.startswith("@"):
raise PipelineError(f"component {name} inputs must be dict, not list")

if cfg.meta.hash is not None:
h2 = pipe.config_hash()
if h2 != cfg.meta.hash:
_log.warning("loaded pipeline does not match hash")
warnings.warn("loaded pipeline config does not match hash", PipelineWarning)

return pipe
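
A hypothetical round-trip using the two methods above; model_dump(mode="json") is an assumption about the Pydantic config model, not part of this diff:

# Serialize a configured pipeline and rebuild an equivalent (untrained) one.
pipe = Pipeline(name="example")
# ... create inputs, add components, and connect them ...
cfg = pipe.get_config()                  # includes meta.hash by default
data = cfg.model_dump(mode="json")       # plain dict, suitable for JSON or YAML
clone = Pipeline.from_config(data)       # warns if the stored hash does not match
assert clone.config_hash() == pipe.config_hash()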

def train(self, data: Dataset) -> None:
"""
Trains the pipeline's trainable components (those implementing the
@@ -524,7 +680,7 @@ def _check_available_name(self, name: str) -> None:
def _check_member_node(self, node: Node[Any]) -> None:
nw = self._nodes.get(node.name)
if nw is not node:
raise RuntimeError(f"node {node} not in pipeline")
raise PipelineError(f"node {node} not in pipeline")

def _clear_caches(self):
pass
23 changes: 23 additions & 0 deletions lenskit/lenskit/pipeline/components.py
@@ -10,6 +10,8 @@
from __future__ import annotations

import inspect
from importlib import import_module
from types import FunctionType
from typing import Callable, ClassVar, TypeAlias

from typing_extensions import Any, Generic, Protocol, Self, TypeVar, override, runtime_checkable
@@ -184,3 +186,24 @@ def from_config(cls, cfg: dict[str, Any]) -> Self:
are passed to the constructor as keyword arguments.
"""
return cls(**cfg)


def instantiate_component(
comp: str | type | FunctionType, config: dict[str, Any] | None
) -> Callable[..., object]:
if isinstance(comp, str):
mname, oname = comp.split(":", 1)
mod = import_module(mname)
comp = getattr(mod, oname)

# make the type checker happy
assert not isinstance(comp, str)

if isinstance(comp, FunctionType):
return comp
elif issubclass(comp, ConfigurableComponent):
if config is None:
config = {}
return comp.from_config(config) # type: ignore
else:
return comp() # type: ignore
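
A couple of hypothetical calls showing the "module:object" string form this helper accepts (the dotted path in the comment is illustrative, not a real component):

# Load and configure a component from a qualified name (illustrative path):
#   scorer = instantiate_component("my_pkg.scorers:MyScorer", {"damping": 5})

# Plain functions (and classes) can also be passed through directly:
def negate(x: float) -> float:
    return -x

assert instantiate_component(negate, None) is negate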