Skip to content

Commit

Permalink
add support for optional fields (using None) to AutoState
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 688964651
  • Loading branch information
Qwlouse authored and The kauldron Authors committed Oct 23, 2024
1 parent fe16083 commit b598bef
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 30 deletions.
81 changes: 51 additions & 30 deletions kauldron/metrics/auto_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,7 @@

_MetricT = TypeVar("_MetricT")
_SelfT = TypeVar("_SelfT")

ArrayOrEmpty: TypeAlias = Array | Literal[base_state._EMPTY_TYPE.EMPTY] # pylint: disable=protected-access
Empty: TypeAlias = Literal[base_state._EMPTY_TYPE.EMPTY] # pylint: disable=protected-access


class AutoState(base_state.State[_MetricT]):
Expand Down Expand Up @@ -287,13 +286,16 @@ class _FieldMerger(abc.ABC):

@abc.abstractmethod
def merge(
self, v1: ArrayOrEmpty, v2: ArrayOrEmpty, state: base_state.State
) -> ArrayOrEmpty:
self,
v1: Array | Empty | None,
v2: Array | Empty | None,
state: base_state.State,
) -> Array | Empty | None:
...

def finalize(self, v: ArrayOrEmpty) -> np.ndarray | None:
def finalize(self, v: Array | Empty | None) -> np.ndarray | None:
# by default convert to numpy array
if v is EMPTY:
if v is EMPTY or v is None:
return None
return np.asarray(v)

Expand All @@ -308,14 +310,19 @@ class _ReduceSum(_FieldMerger):

def merge(
self,
v1: ArrayOrEmpty,
v2: ArrayOrEmpty,
v1: Array | Empty | None,
v2: Array | Empty | None,
state: base_state.State,
) -> ArrayOrEmpty:
if v1 is EMPTY or v2 is EMPTY:
return v1 if v2 is EMPTY else v2
else:
return v1 + v2
) -> Array | Empty:
if v1 is EMPTY:
return v2
if v2 is EMPTY:
return v1
if v1 is None or v2 is None:
if not (v1 is None and v2 is None):
raise ValueError("Cannot sum None and non-None values.")
return None
return v1 + v2


@dataclasses.dataclass(kw_only=True, frozen=True)
Expand All @@ -326,23 +333,30 @@ class _Concatenate(_FieldMerger):

def merge(
self,
v1: ArrayOrEmpty | tuple[Array, ...],
v2: ArrayOrEmpty | tuple[Array, ...],
v1: Array | Empty | None | tuple[Array, ...],
v2: Array | Empty | None | tuple[Array, ...],
state: base_state.State,
) -> tuple[Array, ...]:
) -> tuple[Array, ...] | None:
if v1 is None or v2 is None:
if not (v1 is None and v2 is None):
raise ValueError("Cannot concatenate None and non-None values.")
return None

v1 = _normalize_to_tuple(v1)
v2 = _normalize_to_tuple(v2)
return v1 + v2 # concatenated tuples

def finalize(self, v: ArrayOrEmpty | tuple[Array, ...]) -> Array | None:
v = _normalize_to_tuple(v)
if not v:
def finalize(
self, v: Array | Empty | None | tuple[Array, ...]
) -> Array | None:
if v is EMPTY or v is None:
return None
v = _normalize_to_tuple(v)
return np.concatenate(v, axis=self.axis)


def _normalize_to_tuple(
v: ArrayOrEmpty | tuple[Array, ...],
v: Array | Empty | tuple[Array, ...],
) -> tuple[np.ndarray, ...]:
if v is EMPTY:
return ()
Expand All @@ -366,27 +380,34 @@ class _Truncate(_FieldMerger):

def merge(
self,
v1: ArrayOrEmpty,
v2: ArrayOrEmpty,
v1: Array | Empty | None,
v2: Array | Empty | None,
state: base_state.State,
) -> ArrayOrEmpty:
) -> Array | Empty | None:
num = kontext.get_by_path(state, self.num_field)
v1 = self._maybe_truncate(v1, num)
v2 = self._maybe_truncate(v2, num)
if v1 is EMPTY and v2 is EMPTY:
return EMPTY
if v1 is EMPTY or v2 is EMPTY:
return v1 if v2 is EMPTY else v2
if v1 is EMPTY:
return v2
if v2 is EMPTY:
return v1

if v1 is None or v2 is None:
if not (v1 is None and v2 is None):
raise ValueError(
"Cannot concatenate (& truncate) None and non-None values."
)
return None

assert isinstance(v1, Array) and isinstance(v2, Array)
if v1.shape[self.axis] < num:
v1 = np.concatenate([v1, v2], axis=self.axis)
return self._maybe_truncate(v1, num)

def _maybe_truncate(self, v: ArrayOrEmpty, num: int) -> ArrayOrEmpty:
def _maybe_truncate(self, v: Array | Empty, num: int) -> Array | Empty:
"""If v is not None, then truncate it to num elements along axis."""
if v is EMPTY:
return EMPTY
if v is EMPTY or v is None:
return v
assert isinstance(v, Array)
axis = np.lib.array_utils.normalize_axis_index(self.axis, v.ndim)
return np.asarray(v[(slice(None),) * axis + (slice(None, num),)])
Expand Down
34 changes: 34 additions & 0 deletions kauldron/metrics/auto_state_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from kauldron.metrics import base_state
from kauldron.typing import Float # pylint: disable=g-multiple-import,member-import
import numpy as np
import pytest


def test_empty():
Expand Down Expand Up @@ -57,26 +58,39 @@ def test_merge_sum():
class SumState(auto_state.AutoState):
a: int = 3
b: Float = auto_state.sum_field()
c: Float | None = auto_state.sum_field(default=None)

s1 = SumState(a=3, b=np.ones((3, 2)))
s2 = SumState(a=3, b=np.ones((3, 2)) * 5)
s3 = SumState(a=3, b=np.ones((3, 2)) * 10, c=np.ones((1, 1)))

s = s1.merge(s2)
assert s.a == 3
result = s.compute()
assert result.b.shape == (3, 2)
np.testing.assert_allclose(result.b, 6.0)

# no error:
_ = s3.merge(s3)

with pytest.raises(ValueError, match="Cannot sum None"):
s1.merge(s3)

with pytest.raises(ValueError, match="Cannot sum None"):
s3.merge(s1)


def test_merge_concat():
@functools.partial(flax.struct.dataclass, kw_only=True)
class ConcatState(auto_state.AutoState):
a: str = "irrelevant"
b: Float = auto_state.concat_field()
c: Float = auto_state.concat_field(axis=1)
d: Float | None = auto_state.concat_field(default=None)

s1 = ConcatState(b=np.ones((3, 2)), c=np.ones((3, 2)))
s2 = ConcatState(b=np.zeros((3, 2)), c=np.zeros((3, 2)))
s3 = ConcatState(b=np.ones((3, 2)), c=np.ones((3, 2)), d=np.ones((1, 1)))

s = s1.merge(s2)
assert s.a == "irrelevant"
Expand All @@ -90,6 +104,15 @@ class ConcatState(auto_state.AutoState):
np.testing.assert_allclose(result.c[:, :2], 1.0)
np.testing.assert_allclose(result.c[:, 2:], 0.0)

# no error:
_ = s3.merge(s3)

with pytest.raises(ValueError, match="Cannot concatenate None"):
s1.merge(s3)

with pytest.raises(ValueError, match="Cannot concatenate None"):
s3.merge(s1)


def test_merge_truncate():
@functools.partial(flax.struct.dataclass, kw_only=True)
Expand All @@ -98,9 +121,11 @@ class TruncateState(auto_state.AutoState):
num_c: int = 3
b: Float = auto_state.truncate_field(num_field="num_b")
c: Float = auto_state.truncate_field(num_field="num_c", axis=1)
d: Float | None = auto_state.truncate_field(num_field="num_b", default=None)

s1 = TruncateState(b=np.ones((3, 2)), c=np.ones((3, 2)))
s2 = TruncateState(b=np.ones((3, 2)), c=np.ones((3, 2)))
s3 = TruncateState(b=np.ones((3, 2)), c=np.ones((3, 2)), d=np.ones((1, 1)))

s = s1.merge(s2)
assert s.num_b == 4
Expand All @@ -111,3 +136,12 @@ class TruncateState(auto_state.AutoState):
assert result.c.shape == (3, 3)
np.testing.assert_allclose(result.b, 1.0)
np.testing.assert_allclose(result.c, 1.0)

# no error:
_ = s3.merge(s3)

with pytest.raises(ValueError, match=r"Cannot .*truncate.* None"):
s1.merge(s3)

with pytest.raises(ValueError, match=r"Cannot .*truncate.* None"):
s3.merge(s1)

0 comments on commit b598bef

Please sign in to comment.