Skip to content

Commit

Permalink
convert ShowImages summary to new metrics protocol
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 683621186
  • Loading branch information
Qwlouse authored and The kauldron Authors committed Oct 9, 2024
1 parent 8fb6db4 commit a9e4cf2
Show file tree
Hide file tree
Showing 6 changed files with 161 additions and 69 deletions.
1 change: 1 addition & 0 deletions kauldron/metrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
# pylint: disable=g-importing-member
from kauldron.metrics.auto_state import AutoState
from kauldron.metrics.auto_state import concat_field
from kauldron.metrics.auto_state import static_field
from kauldron.metrics.auto_state import sum_field
from kauldron.metrics.auto_state import truncate_field
from kauldron.metrics.base import Metric
Expand Down
26 changes: 26 additions & 0 deletions kauldron/metrics/auto_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,32 @@ def compute(self: _SelfT) -> _SelfT:
# TODO(klausg): overloaded type annotation similar to dataclasses.field?


def static_field(default: Any = dataclasses.MISSING, **kwargs):
  """Declare a static (non-pytree) field on an AutoState.

  A static field is never merged: `merge` only checks that both sides hold an
  equal value. Because the field is not a pytree node, jax transforms leave it
  untouched, although changing its value can trigger recompilation.

  Typical use is to carry a parameter of the metric (e.g. how many elements to
  keep). Static fields are rarely required, because such parameters are
  usually better defined on the metric itself and read through the `parent`
  field.

  Args:
    default: Default value of the field.
    **kwargs: Extra keyword arguments forwarded to `dataclasses.field`. A
      user-supplied `metadata` mapping is preserved, except that its
      `pytree_node` entry is always overridden with `False`.

  Returns:
    A `dataclasses.Field` whose metadata marks it as a static field.
  """
  user_metadata = kwargs.pop("metadata", {})
  # The static marker is applied last so that it wins over any user value.
  kwargs["metadata"] = user_metadata | {"pytree_node": False}
  return dataclasses.field(default=default, **kwargs)


def sum_field(
*,
default: Any = dataclasses.MISSING,
Expand Down
2 changes: 1 addition & 1 deletion kauldron/summaries/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,11 @@
from kauldron.summaries.base import PerImageChannelPCA
from kauldron.summaries.base import ShowBoxes
from kauldron.summaries.base import ShowDifferenceImages
from kauldron.summaries.base import ShowImages
from kauldron.summaries.base import ShowSegmentations
from kauldron.summaries.base import Summary
from kauldron.summaries.histograms import Histogram
from kauldron.summaries.histograms import HistogramSummary
from kauldron.summaries.images import ShowImages
from kauldron.summaries.pointclouds import PointCloud
from kauldron.summaries.pointclouds import ShowPointCloud
# pylint: enable=g-importing-member
66 changes: 0 additions & 66 deletions kauldron/summaries/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,72 +74,6 @@ def __call__(self, *, context: Any = None, **kwargs) -> Images:
return self.get_images(**kwargs)


@dataclasses.dataclass(kw_only=True, frozen=True, eq=True)
class ShowImages(ImageSummary):
  """Show a set of images with optional reshaping and resizing.

  `gather_kwargs` resolves the inputs from the context, optionally rearranges
  the images and keeps only the first `num_images` entries. `get_images` then
  rescales, applies the masks, clips to [0, 1] and optionally resizes.
  """

  # Context keys from which the images (and optional masks) are resolved.
  images: kontext.Key
  masks: Optional[kontext.Key] = None

  # Upper bound on the number of images kept (applied by slicing).
  num_images: int
  # Optional einops pattern (plus its keyword arguments) applied to the images
  # in `gather_kwargs`. Note that it is not applied to the masks there.
  rearrange: Optional[str] = None
  rearrange_kwargs: Mapping[str, Any] = dataclasses.field(
      default_factory=flax.core.FrozenDict[str, Any]
  )
  # Target size; resizing happens only if at least one of them is set.
  width: Optional[int] = None
  height: Optional[int] = None
  # Optional (vmin, vmax) input range, mapped linearly so vmin -> 0, vmax -> 1.
  in_vrange: Optional[tuple[float, float]] = None
  # Value assigned to every pixel selected by the mask.
  mask_color: float | tuple[float, float, float] = 0.5

  def gather_kwargs(self, context: Any) -> dict[str, Images | Masks]:
    """Resolves images/masks from `context`, truncated to `num_images`.

    Args:
      context: Object from which the keyed fields of this summary are resolved.

    Returns:
      A dict with the entries "images" and "masks" ("masks" may be None).

    Raises:
      ValueError: If the images do not match `Float["n h w #3"]` or the masks
        do not match `Bool["n h w 1"]`.
    """
    # optimize gather_kwargs to only return num_images many images
    kwargs = kontext.resolve_from_keyed_obj(context, self)
    images = kwargs["images"]
    masks = kwargs.get("masks", None)
    if self.rearrange:
      images = einops.rearrange(images, self.rearrange, **self.rearrange_kwargs)
    # NOTE(review): assumed to run regardless of `self.rearrange`; the nesting
    # was reconstructed from a flattened view -- confirm.
    images = images.astype(jnp.float32)
    if not isinstance(images, Float["n h w #3"]):
      raise ValueError(f"Bad shape or dtype: {images.shape} {images.dtype}")

    # Truncate early so that only `num_images` images are passed on.
    images = images[: self.num_images]
    if masks is not None:
      if not isinstance(masks, Bool["n h w 1"]):
        raise ValueError(
            f"Bad mask shape or dtype: {masks.shape} {masks.dtype}"  # pylint: disable=attribute-error
        )
      masks = masks[: self.num_images]

    return {"images": images, "masks": masks}

  @typechecked
  def get_images(
      self, images: Images, masks: Optional[Masks] = None
  ) -> Float["n _h _w _c"]:
    """Turns the gathered images (and masks) into displayable images.

    Args:
      images: Images with trailing `h w c` axes; leading axes are flattened.
      masks: Optional masks with trailing `h w c` axes; only channel 0 is used.

    Returns:
      The processed images, clipped to [0, 1] and possibly resized.
    """
    # flatten batch dimensions
    images = einops.rearrange(images, "... h w c -> (...) h w c")
    # Converted to a NumPy array; the masking below assigns in place.
    images = np.array(images[: self.num_images])
    # maybe rescale
    if self.in_vrange is not None:
      vmin, vmax = self.in_vrange
      images = (images - vmin) / (vmax - vmin)
    # convert to float
    images = media.to_type(images.astype(jnp.float32), np.float32)

    if masks is not None:
      masks = einops.rearrange(masks, "... h w c -> (...) h w c")
      # Drop the channel axis so the mask indexes whole pixels.
      masks = np.array(masks[: self.num_images, :, :, 0])
      images[masks] = self.mask_color

    # always clip to avoid display problems in TB and Datatables
    images = np.clip(images, 0.0, 1.0)
    # maybe resize
    if (self.width, self.height) != (None, None):
      # `Shape("h w")` is presumably resolved from the shapes bound by
      # `@typechecked` -- confirm against kauldron.typing.
      shape = _get_height_width(self.width, self.height, Shape("h w"))
      images = media.resize_video(images, shape)
    return images


@dataclasses.dataclass(kw_only=True, frozen=True, eq=True)
class ShowDifferenceImages(ImageSummary):
"""Show a set of difference images with optional reshaping and resizing."""
Expand Down
123 changes: 123 additions & 0 deletions kauldron/summaries/images.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
# Copyright 2024 The kauldron Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Image summaries."""

from __future__ import annotations

import dataclasses
from typing import Any, Mapping, Optional, TypeVar

import einops
from flax import struct
import jax.numpy as jnp
from kauldron import kontext
from kauldron import metrics
from kauldron.typing import Bool, Float, Shape, check_type, typechecked  # pylint: disable=g-multiple-import,g-importing-member
import mediapy as media
import numpy as np

_MetricT = TypeVar("_MetricT")


@struct.dataclass
class CollectImages(metrics.AutoState[_MetricT]):
  """Collects the first num_images images and optionally resizes them.

  `images` is declared with `metrics.truncate_field(num_field="num_images")`;
  `num_images`, `width` and `height` are static fields.
  """

  images: Float["n h w #3"] = metrics.truncate_field(num_field="num_images")

  num_images: int = metrics.static_field()
  width: int | None = metrics.static_field(default=None)
  height: int | None = metrics.static_field(default=None)

  @typechecked
  def compute(self):
    """Returns the collected images, optionally resized, clipped to [0, 1].

    Resizing happens when at least one of `width` / `height` is set; a missing
    dimension is derived by `_get_height_width` from the image aspect ratio.
    """
    images = super().compute().images
    check_type(images, Float["n h w #3"])
    # Resize when EITHER dimension is given. Requiring both would silently
    # ignore a width-only or height-only request.
    if self.width is not None or self.height is not None:
      shape = _get_height_width(self.width, self.height, Shape("h w"))
      images = media.resize_video(images, shape)

    # always clip to avoid display problems in TB and Datatables
    return np.clip(images, 0.0, 1.0)


@dataclasses.dataclass(kw_only=True, frozen=True, eq=True)
class ShowImages(metrics.Metric):
  """Show image summaries with optional reshaping, masking, and resizing.

  `get_state` rearranges, flattens and truncates the inputs, rescales the
  images from `in_vrange`, paints masked pixels with `mask_color`, and returns
  a `State` that carries `num_images`, `width` and `height`.
  """

  # Context keys for the images and the optional masks.
  images: kontext.Key
  masks: Optional[kontext.Key] = None

  # Maximum number of images to keep.
  num_images: int
  # Optional einops pattern (and its keyword arguments), applied to both the
  # images and the masks.
  rearrange: Optional[str] = None
  rearrange_kwargs: Optional[Mapping[str, Any]] = None
  # Optional target size, forwarded to the state.
  width: Optional[int] = None
  height: Optional[int] = None
  # Optional (vmin, vmax) input range, mapped linearly so vmin -> 0, vmax -> 1.
  in_vrange: Optional[tuple[float, float]] = None
  # Value written to every masked pixel.
  mask_color: float | tuple[float, float, float] = 0.5

  @struct.dataclass
  class State(CollectImages["ShowImages"]):
    pass

  @typechecked
  def get_state(
      self,
      images: Float["*b h w #3"],
      masks: Optional[Bool["*b h w 1"]] = None,
  ) -> ShowImages.State:
    """Builds the state from a batch of images and optional masks.

    Args:
      images: Images with arbitrary leading batch axes.
      masks: Optional boolean masks; True marks pixels painted `mask_color`.

    Returns:
      A `ShowImages.State` holding at most `num_images` processed images.
    """
    images = self._rearrange(images)
    masks = self._rearrange(masks)

    if self.in_vrange is not None:
      vmin, vmax = self.in_vrange
      images = (images - vmin) / (vmax - vmin)

    if masks is not None:
      # Boolean-mask indexing (`images.at[mask].set(...)`) cannot be traced
      # under jit, so select with `where`. The mask's trailing axis of size 1
      # broadcasts over the channels.
      fill = jnp.asarray(self.mask_color, dtype=images.dtype)
      images = jnp.where(masks, fill, images)

    return self.State(
        num_images=self.num_images,
        width=self.width,
        height=self.height,
        images=images,
    )

  def _rearrange(self, img_or_mask):
    """Rearranges, flattens batch axes and truncates; passes None through."""
    if img_or_mask is None:
      return None
    if self.rearrange:
      rearrange_kwargs = self.rearrange_kwargs or {}
      img_or_mask = einops.rearrange(
          img_or_mask, self.rearrange, **rearrange_kwargs
      )
    # flatten batch dimensions
    img_or_mask = einops.rearrange(img_or_mask, "... h w c -> (...) h w c")
    # truncate to num_images. Just an optimization to avoid unnecessary
    # computation.
    return img_or_mask[: self.num_images]


def _get_height_width(
width: Optional[int], height: Optional[int], shape: tuple[int, int]
) -> tuple[int, int]:
"""Returns (width, height) given optional parameters and image shape."""
h, w = shape
if width and height:
return height, width
if width and not height:
return int(width * (h / w) + 0.5), width
if height and not width:
return height, int(height * (w / h) + 0.5)
return shape
12 changes: 10 additions & 2 deletions kauldron/train/metric_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,12 +201,20 @@ def write_step_metrics(

if log_summaries:
with jax.transfer_guard("allow"):
# image summaries # TODO(klausg): unify with metrics
image_summaries = {
# TODO(klausg): remove once all summaries are migrated to new protocol
# image summaries
image_summaries_old = {
name: summary.get_images(**aux.summary_kwargs[name])
for name, summary in model_with_aux.summaries.items()
if isinstance(summary, summaries.ImageSummary)
}

image_summaries = image_summaries_old | {
name: value
for name, value in aux_result.summary_values.items()
if isinstance(value, Float["n h w #3"])
}

# Throw an error if empty arrays are given. TB throws very odd errors
# and kills Colab runtimes if we don't catch these ourselves.
for name, image in image_summaries.items():
Expand Down

0 comments on commit a9e4cf2

Please sign in to comment.