Skip to content

Commit

Permalink
add static_field support to AutoState
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 688944515
  • Loading branch information
Qwlouse authored and The kauldron Authors committed Oct 23, 2024
1 parent 300c767 commit 42e6c07
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 3 deletions.
1 change: 1 addition & 0 deletions kauldron/metrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
# pylint: disable=g-importing-member
from kauldron.metrics.auto_state import AutoState
from kauldron.metrics.auto_state import concat_field
from kauldron.metrics.auto_state import static_field
from kauldron.metrics.auto_state import sum_field
from kauldron.metrics.auto_state import truncate_field
from kauldron.metrics.base import Metric
Expand Down
31 changes: 29 additions & 2 deletions kauldron/metrics/auto_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from typing import Any, Literal, Self, TypeAlias, TypeVar

import jax
from kauldron import kontext
from kauldron.metrics import base_state
from kauldron.metrics.base_state import EMPTY # pylint: disable=g-importing-member
from kauldron.typing import Array # pylint: disable=g-multiple-import
Expand Down Expand Up @@ -137,6 +138,32 @@ def compute(self: _SelfT) -> _SelfT:
# TODO(klausg): overloaded type annotation similar to dataclasses.field?


def static_field(default: Any = dataclasses.MISSING, **kwargs):
"""Define an AutoState static field.
Static fields are not merged, and instead are checked for equality during
`merge`. They are also not pytree nodes, so they are not touched by jax
transforms (but can lead to recompilation if changed).
These can be useful to store some parameters of the metric, e.g. the number
of elements to keep. Note that static fields are rarely needed, since it is
usually better to define static params in the corresponding metric and access
them through the `parent` field.
Args:
default: The default value of the field.
**kwargs: Additional arguments to pass to the dataclasses.field.
Returns:
A dataclasses.Field instance with additional metadata that marks this field
as a static field.
"""
metadata = kwargs.pop("metadata", {})
metadata = metadata | {
"pytree_node": False,
}
return dataclasses.field(default=default, metadata=metadata, **kwargs)


def sum_field(
*,
default: Any = dataclasses.MISSING,
Expand Down Expand Up @@ -331,7 +358,7 @@ class _Truncate(_FieldMerger):
Attributes:
axis: The axis along which to concatenate the two arrays. Defaults to 0.
num_field: The name of the field that contains the number of elements to
keep.
keep. Can be any valid kontext.Path (e.g. "parent.num_images").
"""

axis: int | None = 0
Expand All @@ -343,7 +370,7 @@ def merge(
v2: ArrayOrEmpty,
state: base_state.State,
) -> ArrayOrEmpty:
num = getattr(state, self.num_field)
num = kontext.get_by_path(state, self.num_field)
v1 = self._maybe_truncate(v1, num)
v2 = self._maybe_truncate(v2, num)
if v1 is EMPTY and v2 is EMPTY:
Expand Down
4 changes: 3 additions & 1 deletion kauldron/metrics/base_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,9 @@ class State(abc.ABC, Generic[_MetricT]):
"""

_: dataclasses.KW_ONLY
parent: _MetricT = flax.struct.field(pytree_node=False, default=EMPTY)
parent: _MetricT = flax.struct.field(
pytree_node=False, default=EMPTY
) # pytype: disable=annotation-type-mismatch

def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
Expand Down

0 comments on commit 42e6c07

Please sign in to comment.