Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Serialize full pipeline configurations #469

Merged
merged 24 commits into from
Aug 8, 2024
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
4538ba8
add Pydantic dependency
mdekstrand Aug 7, 2024
fc61a1d
add string type representations
mdekstrand Aug 7, 2024
7943203
move and split pipeline clone / config
mdekstrand Aug 7, 2024
7577d1e
add pydantic to docs
mdekstrand Aug 7, 2024
176df23
add initial config model + input node + test
mdekstrand Aug 7, 2024
ff4d598
store types in sets
mdekstrand Aug 7, 2024
8ee677a
fix type error in pipeline construction
mdekstrand Aug 7, 2024
1ba6aa1
Serialize and round-trip component inputs
mdekstrand Aug 7, 2024
0691cce
parse qualified (colon-separated) types
mdekstrand Aug 7, 2024
b0f44d2
serialize to qualified names
mdekstrand Aug 7, 2024
8fbe084
support round-trip configuration
mdekstrand Aug 7, 2024
8c398df
refactor configuration with pattern-matching
mdekstrand Aug 8, 2024
6153b6e
implement fallback node serialization
mdekstrand Aug 8, 2024
e43f51b
document lack of support for serializing literal nodes
mdekstrand Aug 8, 2024
4a03762
add pipeline metadata
mdekstrand Aug 8, 2024
1b3c838
add configuration hashing
mdekstrand Aug 8, 2024
fa5ad11
clarify from_config comments
mdekstrand Aug 8, 2024
ae4b9d8
add pipeline error & warning classes
mdekstrand Aug 8, 2024
2a598cb
update to properly throw and test for pipeline errors
mdekstrand Aug 8, 2024
cf2ff62
warn when parameter has no annotation
mdekstrand Aug 8, 2024
d218f10
fix deprecated warning method
mdekstrand Aug 8, 2024
6c93561
document @-nodes
mdekstrand Aug 8, 2024
48e4e27
fix SHA docs
mdekstrand Aug 8, 2024
886f35b
only serialize once to insert hash into config
mdekstrand Aug 8, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@
"manylog": ("https://manylog.readthedocs.io/en/latest/", None),
"torch": ("https://pytorch.org/docs/stable/", None),
"implicit": ("https://benfred.github.io/implicit/", None),
"pydantic": ("https://docs.pydantic.dev/latest/", None),
}

bibtex_bibfiles = ["lenskit.bib"]
Expand Down
56 changes: 52 additions & 4 deletions lenskit/lenskit/pipeline/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,19 @@
from typing import Literal, cast
from uuid import uuid4

from typing_extensions import Any, LiteralString, TypeVar, overload
from typing_extensions import Any, LiteralString, Self, TypeVar, overload

from lenskit.data import Dataset
from lenskit.pipeline.types import parse_type_string

from .components import (
AutoConfig, # noqa: F401 # type: ignore
Component,
ConfigurableComponent,
TrainableComponent,
instantiate_component,
)
from .config import PipelineComponent, PipelineConfig, PipelineInput
from .nodes import ND, ComponentNode, FallbackNode, InputNode, LiteralNode, Node
from .state import PipelineState

Expand All @@ -36,6 +39,7 @@
"Component",
"ConfigurableComponent",
"TrainableComponent",
"PipelineConfig",
]

_log = logging.getLogger(__name__)
Expand Down Expand Up @@ -139,7 +143,7 @@ def create_input(self, name: str, *types: type[T] | None) -> Node[T]:
"""
self._check_available_name(name)

node = InputNode[Any](name, types=set((t if t is not None else type[None]) for t in types))
node = InputNode[Any](name, types=set((t if t is not None else type(None)) for t in types))
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bugfix for existing code around None.

self._nodes[name] = node
self._clear_caches()
return node
Expand Down Expand Up @@ -381,9 +385,9 @@ def clone(self, *, params: bool = False) -> Pipeline:
if isinstance(comp, FunctionType):
comp = comp
elif isinstance(comp, ConfigurableComponent):
comp = comp.__class__.from_config(comp.get_config())
comp = comp.__class__.from_config(comp.get_config()) # type: ignore
else:
comp = comp.__class__()
comp = comp.__class__() # type: ignore
cn = clone.add_component(node.name, comp) # type: ignore
for wn, wt in wiring.items():
clone.connect(cn, **{wn: clone.node(wt)})
Expand All @@ -392,6 +396,50 @@ def clone(self, *, params: bool = False) -> Pipeline:

return clone

def get_config(self) -> PipelineConfig:
    """
    Build the serializable configuration for this pipeline.

    The result records every input node and every component node, together
    with each component's own configuration and its input connections.  It
    can be written to disk (JSON, YAML, or a similar format) to save the
    pipeline's structure.

    Trained parameter values are **not** part of the configuration.  A
    component's configuration may still carry things like checkpoint paths
    from which such parameters can be loaded, depending on how the
    component is designed.
    """
    input_cfgs: list[PipelineInput] = []
    for node in self.nodes:
        if isinstance(node, InputNode):
            input_cfgs.append(PipelineInput.from_node(node))

    component_cfgs: dict[str, PipelineComponent] = {}
    for node in self.nodes:
        if isinstance(node, ComponentNode):
            component_cfgs[node.name] = PipelineComponent.from_node(node)

    return PipelineConfig(inputs=input_cfgs, components=component_cfgs)

@classmethod
def from_config(cls, config: object) -> Self:
    """
    Reconstruct a pipeline from a serialized configuration.

    Args:
        config:
            The configuration to load; it is validated with
            ``PipelineConfig.model_validate``.

    Returns:
        A new pipeline with the configured inputs, components, and wiring.
    """
    spec = PipelineConfig.model_validate(config)
    pipeline = cls()

    # inputs: resolve the serialized type names back into types
    for in_spec in spec.inputs:
        type_names = in_spec.types if in_spec.types is not None else ()
        in_types: list[type[Any] | None] = [parse_type_string(tn) for tn in type_names]
        pipeline.create_input(in_spec.name, *in_types)

    # pass 1: create every component node, without any wiring
    for comp_name, comp_spec in spec.components.items():
        pipeline.add_component(
            comp_name, instantiate_component(comp_spec.code, comp_spec.config)
        )

    # pass 2: wiring
    for comp_name, comp_spec in spec.components.items():
        wiring = {}
        for in_name, src_name in comp_spec.inputs.items():
            wiring[in_name] = pipeline.node(src_name)
        pipeline.connect(comp_name, **wiring)

    return pipeline

def train(self, data: Dataset) -> None:
"""
Trains the pipeline's trainable components (those implementing the
Expand Down
23 changes: 23 additions & 0 deletions lenskit/lenskit/pipeline/components.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
from __future__ import annotations

import inspect
from importlib import import_module
from types import FunctionType
from typing import Callable, ClassVar, TypeAlias

from typing_extensions import Any, Generic, Protocol, Self, TypeVar, override, runtime_checkable
Expand Down Expand Up @@ -184,3 +186,24 @@ def from_config(cls, cfg: dict[str, Any]) -> Self:
are passed to the constructor as keyword arguments.
"""
return cls(**cfg)


def instantiate_component(
    comp: str | type | FunctionType, config: dict[str, Any] | None
) -> Callable[..., object]:
    """
    Obtain a usable component from its implementation (or the path to it).

    Args:
        comp:
            The component implementation: a function, a class, or a string
            of the form ``module:qualname`` naming one of those.  The
            qualified name may be dotted (``module:Outer.Inner``), matching
            the ``__qualname__``-based paths written when a pipeline is
            serialized.
        config:
            The configuration for a configurable component; ``None`` is
            treated as an empty configuration.  It is ignored for functions
            and for classes that are not configurable.

    Returns:
        The function itself, or an instance of the component class.

    Raises:
        ValueError: if ``comp`` is a string without a ``:`` separator.
    """
    if isinstance(comp, str):
        mname, sep, oname = comp.partition(":")
        if not sep:
            raise ValueError(f"invalid component path {comp!r} (expected 'module:name')")
        # walk the qualified name one attribute at a time so that nested
        # definitions (dotted qualnames) resolve correctly
        found: Any = import_module(mname)
        for attr in oname.split("."):
            found = getattr(found, attr)
        comp = found

    # make the type checker happy
    assert not isinstance(comp, str)

    if isinstance(comp, FunctionType):
        # plain functions are used directly as components
        return comp
    elif issubclass(comp, ConfigurableComponent):
        if config is None:
            config = {}
        return comp.from_config(config)  # type: ignore
    else:
        return comp()  # type: ignore
75 changes: 75 additions & 0 deletions lenskit/lenskit/pipeline/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
"""
Pydantic models for pipeline configuration and serialization support.
"""

# pyright: strict
from __future__ import annotations

from types import FunctionType

from pydantic import BaseModel, Field
from typing_extensions import Any, Optional, Self

from .components import ConfigurableComponent
from .nodes import ComponentNode, InputNode
from .types import type_string


class PipelineConfig(BaseModel):
    """
    Root type for serialized pipeline configuration. A pipeline config contains
    the full configuration, components, and wiring for the pipeline, but does
    not contain any trained parameter values.
    """

    # The pipeline's input nodes.
    inputs: list[PipelineInput]
    # The pipeline's component nodes, keyed by node name (empty by default).
    components: dict[str, PipelineComponent] = Field(default_factory=dict)


class PipelineInput(BaseModel):
    """
    Serialized form of a single pipeline input node.
    """

    name: str
    "The name for this input."
    types: Optional[set[str]]
    "The list of types for this input."

    @classmethod
    def from_node(cls, node: InputNode[Any]) -> Self:
        """
        Create the serialized description of an input node, converting each
        of its types (when it has any) to a type string.
        """
        declared = node.types
        type_names = None if declared is None else {type_string(t) for t in declared}
        return cls(name=node.name, types=type_names)


class PipelineComponent(BaseModel):
    """
    Serialized form of a single pipeline component node.
    """

    code: str
    """
    The path to the component's implementation, either a class or a function.
    This is a Python qualified path of the form ``module:name``.
    """

    config: dict[str, object] | None = Field(default=None)
    """
    The component configuration. If not provided, the component will be created
    with its default constructor parameters.
    """

    inputs: dict[str, str] = Field(default_factory=dict)
    """
    The component's input wirings, mapping input names to node names.
    """

    @classmethod
    def from_node(cls, node: ComponentNode[Any]) -> Self:
        """
        Create the serialized description of a component node: the path to
        its implementation, its configuration (for configurable components
        only), and its input connections.
        """
        component = node.component
        # a function is its own implementation; anything else is located by its class
        impl = component if isinstance(component, FunctionType) else component.__class__
        path = f"{impl.__module__}:{impl.__qualname__}"

        settings = None
        if isinstance(component, ConfigurableComponent):
            settings = component.get_config()

        return cls(code=path, config=settings, inputs=node.connections)
42 changes: 41 additions & 1 deletion lenskit/lenskit/pipeline/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@
# pyright: basic
from __future__ import annotations

import re
import warnings
from types import GenericAlias
from importlib import import_module
from types import GenericAlias, NoneType
from typing import Union, _GenericAlias, get_args, get_origin # type: ignore

import numpy as np
Expand Down Expand Up @@ -118,3 +120,41 @@ def is_compatible_data(obj: object, *targets: type) -> bool:
return True

return False


def type_string(typ: type | None) -> str:
"""
Compute a string representation of a type that is both resolvable and
human-readable. Type parameterizations are lost.
"""
if typ is None or typ is NoneType:
return "None"
elif typ.__module__ == "builtins":
return typ.__name__
elif typ.__qualname__ == typ.__name__:
return f"{typ.__module__}.{typ.__name__}"
else:
return f"{typ.__module__}:{typ.__qualname__}"


def parse_type_string(tstr: str) -> type:
"""
Compute a string representation of a type that is both resolvable and
human-readable. Type parameterizations are lost.
"""
if tstr == "None":
return NoneType
elif re.match(r"^\w+$", tstr):
return __builtins__[tstr]
else:
if ":" in tstr:
mod_name, typ_name = tstr.split(":", 1)
else:
# separate last element from module
parts = re.match(r"(.*)\.(\w+)$", tstr)
if not parts:
raise ValueError(f"unparsable type string {tstr}")
mod_name, typ_name = parts.groups()

mod = import_module(mod_name)
return getattr(mod, typ_name)
3 changes: 2 additions & 1 deletion lenskit/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ dependencies = [
"scipy >= 1.9.0",
"torch ~=2.1", # conda: pytorch>=2.1,<3
"threadpoolctl >=3.0",
"seedbank >= 0.2.0a2", # conda: @pip
"pydantic >=2.8,<3",
"seedbank >=0.2.0a2", # conda: @pip
"progress-api >=0.1.0a9", # conda: @pip
"manylog >=0.1.0a5", # conda: @pip
]
Expand Down
48 changes: 48 additions & 0 deletions lenskit/tests/pipeline/test_component_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
# This file is part of LensKit.
# Copyright (C) 2018-2023 Boise State University
# Copyright (C) 2023-2024 Drexel University
# Licensed under the MIT license, see LICENSE.md for details.
# SPDX-License-Identifier: MIT

import json

from lenskit.pipeline import Pipeline
from lenskit.pipeline.components import AutoConfig
from lenskit.pipeline.nodes import ComponentNode


class Prefixer(AutoConfig):
    """
    Test component that prepends a configurable prefix to a message.
    """

    # NOTE(review): presumably AutoConfig derives the configuration from the
    # constructor parameters / attributes — confirm against its implementation.
    prefix: str

    def __init__(self, prefix: str = "hello"):
        self.prefix = prefix

    def __call__(self, msg: str) -> str:
        """Return ``msg`` with the configured prefix prepended."""
        return self.prefix + msg


def test_auto_config_roundtrip():
    """A component rebuilt from its own config is a distinct but equivalent object."""
    original = Prefixer("FOOBIE BLETCH")

    saved = original.get_config()
    assert "prefix" in saved

    restored = Prefixer.from_config(saved)
    assert restored is not original
    assert restored.prefix == original.prefix


def test_pipeline_config():
    """The pipeline reports each component's configuration under its node name."""
    prefixer = Prefixer("scroll named ")

    pipeline = Pipeline()
    msg_node = pipeline.create_input("msg", str)
    pipeline.add_component("prefix", prefixer, msg=msg_node)

    # make sure the pipeline itself works before inspecting its configuration
    result = pipeline.run(msg="FOOBIE BLETCH")
    assert result == "scroll named FOOBIE BLETCH"

    configs = pipeline.component_configs()
    print(json.dumps(configs, indent=2))

    assert "prefix" in configs
    prefix_cfg = configs["prefix"]
    assert prefix_cfg["prefix"] == "scroll named "
38 changes: 38 additions & 0 deletions lenskit/tests/pipeline/test_config_node.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
from lenskit.pipeline.config import PipelineInput
from lenskit.pipeline.nodes import InputNode


def test_untyped_input():
    """An input node created without types serializes with ``types`` unset."""
    serialized = PipelineInput.from_node(InputNode("scroll"))
    print(serialized)

    assert serialized.name == "scroll"
    assert serialized.types is None


def test_input_with_type():
    """A builtin input type is serialized by its bare name."""
    scroll = InputNode("scroll", types={str})
    serialized = PipelineInput.from_node(scroll)
    print(serialized)

    assert serialized.name == "scroll"
    assert serialized.types == {"str"}


def test_input_with_none():
    """The ``None`` type is serialized as the string ``"None"``."""
    scroll = InputNode("scroll", types={str, type(None)})
    serialized = PipelineInput.from_node(scroll)
    print(serialized)

    assert serialized.name == "scroll"
    assert serialized.types == {"None", "str"}


def test_input_with_generic():
    """A parameterized generic is serialized without its type parameters."""
    scroll = InputNode("scroll", types={list[str]})
    serialized = PipelineInput.from_node(scroll)
    print(serialized)

    assert serialized.name == "scroll"
    assert serialized.types == {"list"}
Loading
Loading