Skip to content

Commit

Permalink
convert ShowImages summary to new metrics protocol
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 683621186
  • Loading branch information
Qwlouse authored and The kauldron Authors committed Oct 23, 2024
1 parent ebb9144 commit 2a06199
Show file tree
Hide file tree
Showing 7 changed files with 220 additions and 163 deletions.
1 change: 1 addition & 0 deletions kauldron/metrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
# pylint: disable=g-importing-member
from kauldron.metrics.auto_state import AutoState
from kauldron.metrics.auto_state import concat_field
from kauldron.metrics.auto_state import static_field
from kauldron.metrics.auto_state import sum_field
from kauldron.metrics.auto_state import truncate_field
from kauldron.metrics.base import Metric
Expand Down
31 changes: 29 additions & 2 deletions kauldron/metrics/auto_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from typing import Any, Literal, Self, TypeAlias, TypeVar

import jax
from kauldron import kontext
from kauldron.metrics import base_state
from kauldron.metrics.base_state import EMPTY # pylint: disable=g-importing-member
from kauldron.typing import Array # pylint: disable=g-multiple-import
Expand Down Expand Up @@ -137,6 +138,32 @@ def compute(self: _SelfT) -> _SelfT:
# TODO(klausg): overloaded type annotation similar to dataclasses.field?


def static_field(default: Any = dataclasses.MISSING, **kwargs):
"""Define an AutoState static field.
Static fields are not merged, and instead are checked for equality during
`merge`. They are also not pytree nodes, so they are not touched by jax
transforms (but can lead to recompilation if changed).
These can be useful to store some parameters of the metric, e.g. the number
of elements to keep. Note that static fields are rarely needed, since it is
usually better to define static params in the corresponding metric and access
them through the `parent` field.
Args:
default: The default value of the field.
**kwargs: Additional arguments to pass to the dataclasses.field.
Returns:
A dataclasses.Field instance with additional metadata that marks this field
as a static field.
"""
metadata = kwargs.pop("metadata", {})
metadata = metadata | {
"pytree_node": False,
}
return dataclasses.field(default=default, metadata=metadata, **kwargs)


def sum_field(
*,
default: Any = dataclasses.MISSING,
Expand Down Expand Up @@ -331,7 +358,7 @@ class _Truncate(_FieldMerger):
Attributes:
axis: The axis along which to concatenate the two arrays. Defaults to 0.
num_field: The name of the field that contains the number of elements to
keep.
keep. Can be any valid kontext.Path (e.g. "parent.num_images").
"""

axis: int | None = 0
Expand All @@ -343,7 +370,7 @@ def merge(
v2: ArrayOrEmpty,
state: base_state.State,
) -> ArrayOrEmpty:
num = getattr(state, self.num_field)
num = kontext.get_by_path(state, self.num_field)
v1 = self._maybe_truncate(v1, num)
v2 = self._maybe_truncate(v2, num)
if v1 is EMPTY and v2 is EMPTY:
Expand Down
4 changes: 3 additions & 1 deletion kauldron/metrics/base_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,9 @@ class State(abc.ABC, Generic[_MetricT]):
"""

_: dataclasses.KW_ONLY
parent: _MetricT = flax.struct.field(pytree_node=False, default=EMPTY)
parent: _MetricT = flax.struct.field(
pytree_node=False, default=EMPTY
) # pytype: disable=annotation-type-mismatch

def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
Expand Down
4 changes: 2 additions & 2 deletions kauldron/summaries/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,12 @@
from kauldron.summaries.base import ImageSummary
from kauldron.summaries.base import PerImageChannelPCA
from kauldron.summaries.base import ShowBoxes
from kauldron.summaries.base import ShowDifferenceImages
from kauldron.summaries.base import ShowImages
from kauldron.summaries.base import ShowSegmentations
from kauldron.summaries.base import Summary
from kauldron.summaries.histograms import Histogram
from kauldron.summaries.histograms import HistogramSummary
from kauldron.summaries.images import ShowDifferenceImages
from kauldron.summaries.images import ShowImages
from kauldron.summaries.pointclouds import PointCloud
from kauldron.summaries.pointclouds import ShowPointCloud
# pylint: enable=g-importing-member
156 changes: 0 additions & 156 deletions kauldron/summaries/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
from etils import epy
import flax
import jax
import jax.numpy as jnp
from kauldron import kontext
from kauldron.typing import Array, Bool, Float, Integer, Shape, UInt8, typechecked # pylint: disable=g-multiple-import,g-importing-member
from kauldron.utils import plot_segmentation as segplot # pylint: disable=g-import-not-at-top
Expand Down Expand Up @@ -74,161 +73,6 @@ def __call__(self, *, context: Any = None, **kwargs) -> Images:
return self.get_images(**kwargs)


@dataclasses.dataclass(kw_only=True, frozen=True, eq=True)
class ShowImages(ImageSummary):
"""Show a set of images with optional reshaping and resizing."""

images: kontext.Key
masks: Optional[kontext.Key] = None

num_images: int
rearrange: Optional[str] = None
rearrange_kwargs: Mapping[str, Any] = dataclasses.field(
default_factory=flax.core.FrozenDict[str, Any]
)
width: Optional[int] = None
height: Optional[int] = None
in_vrange: Optional[tuple[float, float]] = None
mask_color: float | tuple[float, float, float] = 0.5

def gather_kwargs(self, context: Any) -> dict[str, Images | Masks]:
# optimize gather_kwargs to only return num_images many images
kwargs = kontext.resolve_from_keyed_obj(context, self)
images = kwargs["images"]
masks = kwargs.get("masks", None)
if self.rearrange:
images = einops.rearrange(images, self.rearrange, **self.rearrange_kwargs)
images = images.astype(jnp.float32)
if not isinstance(images, Float["n h w #3"]):
raise ValueError(f"Bad shape or dtype: {images.shape} {images.dtype}")

images = images[: self.num_images]
if masks is not None:
if not isinstance(masks, Bool["n h w 1"]):
raise ValueError(
f"Bad mask shape or dtype: {masks.shape} {masks.dtype}" # pylint: disable=attribute-error
)
masks = masks[: self.num_images]

return {"images": images, "masks": masks}

@typechecked
def get_images(
self, images: Images, masks: Optional[Masks] = None
) -> Float["n _h _w _c"]:
# flatten batch dimensions
images = einops.rearrange(images, "... h w c -> (...) h w c")
images = np.array(images[: self.num_images])
# maybe rescale
if self.in_vrange is not None:
vmin, vmax = self.in_vrange
images = (images - vmin) / (vmax - vmin)
# convert to float
images = media.to_type(images.astype(jnp.float32), np.float32)

if masks is not None:
masks = einops.rearrange(masks, "... h w c -> (...) h w c")
masks = np.array(masks[: self.num_images, :, :, 0])
images[masks] = self.mask_color

# always clip to avoid display problems in TB and Datatables
images = np.clip(images, 0.0, 1.0)
# maybe resize
if (self.width, self.height) != (None, None):
shape = _get_height_width(self.width, self.height, Shape("h w"))
images = media.resize_video(images, shape)
return images


@dataclasses.dataclass(kw_only=True, frozen=True, eq=True)
class ShowDifferenceImages(ImageSummary):
"""Show a set of difference images with optional reshaping and resizing."""

images1: kontext.Key
images2: kontext.Key
masks: Optional[kontext.Key] = None

num_images: int
vrange: tuple[float, float]
rearrange: Optional[str] = None
rearrange_kwargs: Mapping[str, Any] = dataclasses.field(
default_factory=flax.core.FrozenDict[str, Any]
)
width: Optional[int] = None
height: Optional[int] = None
cmap: str | None = None
mask_color: float | tuple[float, float, float] = 0.5

def gather_kwargs(self, context: Any) -> dict[str, Images | Masks]:
# optimize gather_kwargs to only return num_images many images
kwargs = kontext.resolve_from_keyed_obj(context, self)
images1, images2 = kwargs["images1"], kwargs["images2"]
masks = kwargs.get("masks", None)
if self.rearrange:
images1 = einops.rearrange(
images1, self.rearrange, **self.rearrange_kwargs
)
images2 = einops.rearrange(
images2, self.rearrange, **self.rearrange_kwargs
)

images1 = images1.astype(jnp.float32)
images2 = images2.astype(jnp.float32)
if not isinstance(images1, Float["n h w #3"]):
raise ValueError(f"Bad shape or dtype: {images1.shape} {images1.dtype}")
if not isinstance(images2, Float["n h w #3"]):
raise ValueError(f"Bad shape or dtype: {images2.shape} {images2.dtype}")

num_images_per_device = math.ceil(
self.num_images / jax.local_device_count()
)
images1 = images1[:num_images_per_device]
images2 = images2[:num_images_per_device]
if masks is not None:
if not isinstance(masks, Bool["n h w 1"]):
raise ValueError(
f"Bad mask shape or dtype: {masks.shape} {masks.dtype}" # pylint: disable=attribute-error
)
masks = masks[:num_images_per_device]

return {"images1": images1, "images2": images2, "masks": masks}

@typechecked
def get_images(
self, images1: Images, images2: Images, masks: Optional[Masks] = None
) -> Float["n _h _w _c"]:
# flatten batch dimensions
images1 = einops.rearrange(images1, "... h w c -> (...) h w c")
images2 = einops.rearrange(images2, "... h w c -> (...) h w c")
images1 = np.array(images1[: self.num_images])
images2 = np.array(images2[: self.num_images])
# convert to float
images1 = media.to_type(images1, np.float32)
images2 = media.to_type(images2, np.float32)

# Compute absolute difference and mean across channels
vmin, vmax = self.vrange
images = np.abs(np.clip(images1, vmin, vmax) - np.clip(images2, vmin, vmax))
images = np.mean(images, axis=-1, keepdims=True)

# Normalize difference image to 0-1 and color.
cmap = self.cmap if self.cmap else "gray"
images = media.to_rgb(images[..., 0], cmap=cmap, vmin=0, vmax=vmax - vmin)

if masks is not None:
masks = einops.rearrange(masks, "... h w c -> (...) h w c")
masks = np.array(masks[: self.num_images, :, :, 0])
images[masks] = self.mask_color

# always clip to avoid display problems in TB and Datatables
images = np.clip(images, 0.0, 1.0)
# maybe resize
if (self.width, self.height) != (None, None):
shape = _get_height_width(self.width, self.height, Shape("h w"))
images = media.resize_video(images, shape)
return images


@dataclasses.dataclass(kw_only=True, frozen=True, eq=True)
class ShowSegmentations(ImageSummary):
"""Show a set of segmentations with optional reshaping and resizing."""
Expand Down
Loading

0 comments on commit 2a06199

Please sign in to comment.