Skip to content

Commit

Permalink
Simplify config manager
Browse files Browse the repository at this point in the history
  • Loading branch information
cbalioglu committed Oct 3, 2024
1 parent 942ecf5 commit 83550be
Show file tree
Hide file tree
Showing 6 changed files with 90 additions and 116 deletions.
2 changes: 1 addition & 1 deletion src/fairseq2/data/parquet/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ def inner_iterator(wrap_table: _TableWrapper) -> DataPipeline:
.map(
table_func_wrap(
partial(
apply_filter, filters=config.filters, drop_null=config.drop_null
apply_filter, filters=config.filters, drop_null=config.drop_null # type: ignore[arg-type]
)
)
)
Expand Down
51 changes: 27 additions & 24 deletions src/fairseq2/dependency.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def resolve(self, kls: type[T], key: str | None = None) -> T:
"""
Returns the object of type ``T``.
:param kls: The :class:`type` of ``T``.
:param kls: The value of ``T``.
:param key: If not ``None``, the object with the specified key will be
returned.
Expand All @@ -51,10 +51,10 @@ def resolve(self, kls: type[T], key: str | None = None) -> T:
def resolve_optional(self, kls: type[T], key: str | None = None) -> T | None:
"""
Returns the object of type ``T`` similar to :meth:`resolve`, but returns
``None`` instead of raising a :class:`LookupError` if the object is not
found.
``None`` instead of raising a :class:`DependencyNotFoundError` if the
object is not found.
:param kls: The :class:`type` of ``T``.
:param kls: The value of ``T``.
:param key: If not ``None``, the object with the specified key will be
returned.
Expand All @@ -70,7 +70,7 @@ def resolve_all(self, kls: type[T]) -> Iterable[T]:
only the last registered one. In contrast, ``resolve_all`` returns
them all in the order they were registered.
:param kls: The :class:`type` of ``T``.
:param kls: The value of ``T``.
:returns: An iterable of resolved objects. If no object is found, an
empty iterable.
Expand All @@ -84,7 +84,7 @@ def resolve_all_keyed(self, kls: type[T]) -> Iterable[tuple[str, T]]:
This method behaves similar to :meth:`resolve_all`, but returns an
iterable of key-object pairs instead.
:param kls: The :class:`type` of ``T``.
:param kls: The value of ``T``.
:returns: An iterable of resolved key-object pairs. If no object is
found, an empty iterable.
Expand Down Expand Up @@ -127,7 +127,7 @@ def register(
registered, :meth:`~DependencyResolver.resolve` will return only the
last registered one.
:param kls: The :class:`type` of ``T``.
:param kls: The value of ``T``.
:param sub_kls: The real type of the object. If not ``None``, must be a
subclass of ``kls``.
:param key: If not ``None``, registers the object with the specified key.
Expand All @@ -150,7 +150,7 @@ def register_factory(
registered, :meth:`~DependencyResolver.resolve` will return only the
last registered one.
:param kls: The :class:`type` of ``T``.
:param kls: The value of ``T``.
:param factory: A callable to create an object of type ``T``. If the
callable returns ``None``, the object is considered to not exist.
:param key: If not ``None``, registers the object with the specified key.
Expand All @@ -169,7 +169,7 @@ def register_instance(self, kls: type[T], obj: T, key: str | None = None) -> Non
Other than registering an existing object instead of a factory, the
method behaves the same as :meth:`register`.
:param kls: The :class:`type` of ``T``.
:param kls: The value of ``T``.
:param obj: The object to register.
:param key: If not ``None``, registers the object with the specified key.
:meth:`~DependencyResolver.resolve` will return the object only if
Expand Down Expand Up @@ -379,18 +379,18 @@ def factory(resolver: DependencyResolver) -> T:
continue

try:
type_hint = type_hints[param_name]
param_type = type_hints[param_name]
except KeyError:
raise DependencyError(
f"The `{param_name}` parameter of `{init_method}` has no type annotation."
)

param_type = get_origin(type_hint) or type_hint
param_origin_type = get_origin(param_type)

arg: Any

if param_type is Iterable:
param_type_args = get_args(type_hint)
if param_origin_type is Iterable:
param_type_args = get_args(param_type)
if len(param_type_args) != 1:
raise DependencyError(
f"The iterable `{param_name}` parameter of `{init_method}` has no element type expression."
Expand All @@ -405,16 +405,17 @@ def factory(resolver: DependencyResolver) -> T:
f"The element type of the iterable `{param_name}` parameter of `{init_method}` is not a `type`."
)

if isinstance(element_type, tuple):
# TODO: implement!
raise RuntimeError("not supported yet!")
else:
arg = list(resolver.resolve_all(element_type))
if get_origin(element_type) is tuple:
element_type_args = get_args(element_type)

if len(element_type_args) == 2 and element_type_args[0] is str:
kwargs[param_name] = resolver.resolve_all_keyed(
element_type_args[1]
)

if len(arg) == 0 and param.default != Parameter.empty:
continue

kwargs[param_name] = arg
kwargs[param_name] = resolver.resolve_all(element_type)
else:
if not isinstance(param_type, type):
if param.default != Parameter.empty:
Expand Down Expand Up @@ -456,11 +457,13 @@ def __init__(
"Neither `obj` nor `factory` is specified. Please file a bug report."
)

if factory is None:
factory = lambda _: obj

self.obj = obj
self.factory = factory

if factory is not None:
self.factory = factory
else:
self.factory = lambda _: obj

self.transient = transient


Expand Down
65 changes: 31 additions & 34 deletions src/fairseq2/recipes/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
)
from fairseq2.dependency import DependencyContainer, DependencyResolver
from fairseq2.gang import Gang
from fairseq2.recipes.config_manager import ConfigManager, register_config
from fairseq2.recipes.config_manager import ConfigManager, ConfigNotFoundError
from fairseq2.utils.file import TensorDumper, TensorLoader
from fairseq2.utils.structured import ValueConverter

Expand All @@ -27,24 +27,7 @@ class FileCheckpointManagerConfig:
path: Path = field(default_factory=lambda: Path("checkpoints"))


@dataclass
class ScoreConfig:
metric: str = "loss"
lower_better: bool = True


def register_checkpoint_manager(container: DependencyContainer) -> None:
register_config(
container,
path="checkpoint_manager",
kls=FileCheckpointManagerConfig,
default_factory=FileCheckpointManagerConfig,
)

register_config(container, path="output_dir", kls=Path)

register_config(container, path="score", kls=ScoreConfig)

container.register_factory(CheckpointManager, _create_checkpoint_manager)

container.register_factory(
Expand All @@ -55,17 +38,25 @@ def register_checkpoint_manager(container: DependencyContainer) -> None:
def _create_checkpoint_manager(resolver: DependencyResolver) -> CheckpointManager:
config_manager = resolver.resolve(ConfigManager)

type_ = config_manager.get_config(
"checkpoint_manager_type", str, default_factory=lambda: "file"
)
try:
type_ = config_manager.get_config("checkpoint_manager_type", str)
except ConfigNotFoundError:
type_ = "file"

return resolver.resolve(CheckpointManager, key=type_)


def _create_file_checkpoint_manager(resolver: DependencyResolver) -> CheckpointManager:
output_dir = resolver.resolve(Path, key="output_dir")
config_manager = resolver.resolve(ConfigManager)

config = resolver.resolve(FileCheckpointManagerConfig, key="checkpoint_manager")
output_dir = config_manager.get_config("output_dir", Path)

try:
config = config_manager.get_config(
"checkpoint_manager", FileCheckpointManagerConfig
)
except ConfigNotFoundError:
config = FileCheckpointManagerConfig()

checkpoint_dir = output_dir.joinpath(config.path)

Expand All @@ -79,9 +70,10 @@ def _create_file_checkpoint_manager(resolver: DependencyResolver) -> CheckpointM

value_converter = resolver.resolve(ValueConverter)

score_config = resolver.resolve_optional(ScoreConfig, key="score")

lower_score_better = False if score_config is None else score_config.lower_better
try:
lower_score_better = config_manager.get_config("lower_score_better", bool)
except ConfigNotFoundError:
lower_score_better = False

return FileCheckpointManager(
checkpoint_dir,
Expand All @@ -96,10 +88,6 @@ def _create_file_checkpoint_manager(resolver: DependencyResolver) -> CheckpointM


def register_checkpoint_metadata_provider(container: DependencyContainer) -> None:
register_config(
container, path="checkpoint_search_dir", kls=Path, type_expr=Path | None
)

container.register_factory(
AssetMetadataProvider, _create_file_checkpoint_metadata_provider
)
Expand All @@ -108,13 +96,22 @@ def register_checkpoint_metadata_provider(container: DependencyContainer) -> Non
def _create_file_checkpoint_metadata_provider(
resolver: DependencyResolver,
) -> AssetMetadataProvider | None:
checkpoint_search_dir = resolver.resolve_optional(Path, key="checkpoint_search_dir")
config_manager = resolver.resolve(ConfigManager)

try:
checkpoint_search_dir = config_manager.get_config(
"checkpoint_search_dir", Path | None
)
except ConfigNotFoundError:
checkpoint_search_dir = None

if checkpoint_search_dir is None:
return None

score_config = resolver.resolve_optional(ScoreConfig, key="score")

lower_score_better = False if score_config is None else score_config.lower_better
try:
lower_score_better = config_manager.get_config("lower_score_better", bool)
except ConfigNotFoundError:
lower_score_better = False

return FileCheckpointMetadataProvider(
checkpoint_search_dir, lower_score_better=lower_score_better
Expand Down
43 changes: 3 additions & 40 deletions src/fairseq2/recipes/config_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,24 +7,17 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from collections.abc import Callable
from typing import Any, final

from typing_extensions import override

from fairseq2.dependency import DependencyContainer, DependencyResolver
from fairseq2.dependency import DependencyContainer
from fairseq2.utils.structured import StructuredError, ValueConverter


class ConfigManager(ABC):
@abstractmethod
def get_config(
self,
path: str,
type_expr: Any,
*,
default_factory: Callable[[], Any] | None = None,
) -> Any:
def get_config(self, path: str, type_expr: Any) -> Any:
...


Expand All @@ -42,19 +35,10 @@ def update_config_dict(self, config_dict: dict[str, object]) -> None:
self._config_dict.update(config_dict)

@override
def get_config(
self,
path: str,
type_expr: Any,
*,
default_factory: Callable[[], Any] | None = None,
) -> Any:
def get_config(self, path: str, type_expr: Any) -> Any:
try:
config = self._config_dict[path]
except KeyError:
if default_factory is not None:
return default_factory()

raise ConfigNotFoundError(
f"The '{path}' configuration is not found."
) from None
Expand All @@ -81,24 +65,3 @@ def register_config_manager(container: DependencyContainer) -> None:
container.register_factory(
ConfigManager, lambda r: r.resolve(StandardConfigManager)
)


def register_config(
container: DependencyContainer,
path: str,
kls: type,
*,
type_expr: Any | None = None,
default_factory: Callable[[], Any] | None = None,
) -> None:
def create(resolver: DependencyResolver) -> Any:
config_manager = resolver.resolve(ConfigManager)

try:
return config_manager.get_config(
path, type_expr or kls, default_factory=default_factory
)
except ConfigNotFoundError:
return None

container.register_factory(kls, create, key=path)
Loading

0 comments on commit 83550be

Please sign in to comment.