From ebb91445be538fa50af40a732cbef2f53c914f7f Mon Sep 17 00:00:00 2001 From: Klaus Greff Date: Tue, 22 Oct 2024 08:09:25 -0700 Subject: [PATCH] copybara strip some externally unsupported features of the metric writer PiperOrigin-RevId: 688554082 --- examples/mnist_autoencoder.py | 12 ++-- examples/test/minimal.py | 71 ------------------- examples/test/mnist_autoencoder.py | 107 ----------------------------- examples/test/mock_data.py | 27 -------- kauldron/train/metric_writer.py | 39 ++--------- pyproject.toml | 4 +- 6 files changed, 13 insertions(+), 247 deletions(-) delete mode 100644 examples/test/minimal.py delete mode 100644 examples/test/mnist_autoencoder.py delete mode 100644 examples/test/mock_data.py diff --git a/examples/mnist_autoencoder.py b/examples/mnist_autoencoder.py index 8c35ceba..9c1508b5 100644 --- a/examples/mnist_autoencoder.py +++ b/examples/mnist_autoencoder.py @@ -14,13 +14,10 @@ r"""Minimal example training a simple Autoencoder on MNIST. -`--xp.use_interpreter` to launch with `ml_python` (no BUILD rules). - +Run: ```sh -xmanager launch third_party/py/kauldron/xm/launch.py -- \ - --cfg=third_party/py/kauldron/examples/mnist_autoencoder.py \ - --xp.use_interpreter \ - --xp.platform=jf=2x2 +python main.py --cfg=examples/mnist_autoencoder.py \ + --cfg.workdir=/tmp/kauldron_oss/workdir ``` """ @@ -90,7 +87,8 @@ def get_config(): def _make_ds(training: bool): - return kd.data.Tfds( + Tfds = kd.data.py.Tfds + return Tfds( name="mnist", split="train" if training else "test", shuffle=True if training else False, diff --git a/examples/test/minimal.py b/examples/test/minimal.py deleted file mode 100644 index 026ad865..00000000 --- a/examples/test/minimal.py +++ /dev/null @@ -1,71 +0,0 @@ -# Copyright 2024 The kauldron Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -r"""Temporary test file to debug the open source release. - -Run: - python main.py --cfg=examples/test/minimal.py \ - --cfg.workdir=/tmp/kauldron_oss/workdir - -TODO(klausg): remove this file once the other examples are working. -""" - -from kauldron import konfig - -# pylint: disable=g-import-not-at-top -with konfig.imports(): - from kauldron import kd - import numpy as np - import functools - import optax -# pylint: enable=g-import-not-at-top - - -def get_config(): - """Get the default hyperparameter configuration.""" - cfg = kd.train.Trainer() - cfg.seed = 42 - - # Dataset - cfg.train_ds = kd.data.InMemoryPipeline( - loader=functools.partial(np.ones, (1, 1)), batch_size=1 - ) - - # Model - cfg.model = kd.nn.DummyModel() - - # Training - cfg.num_train_steps = 0 - - # Losses - cfg.train_losses = {} - - cfg.train_metrics = {} - - cfg.train_summaries = {} - - cfg.writer = kd.train.metric_writer.NoopWriter() - - # Optimizer - cfg.schedules = {} - - cfg.optimizer = optax.sgd(0.1) - - cfg.evals = {} - - cfg.setup = kd.train.setup_utils.Setup( # pytype: disable=wrong-arg-types - add_flatboard=False, flatboard_build_context=None - ) - - return cfg diff --git a/examples/test/mnist_autoencoder.py b/examples/test/mnist_autoencoder.py deleted file mode 100644 index 00318b7b..00000000 --- a/examples/test/mnist_autoencoder.py +++ /dev/null @@ -1,107 +0,0 @@ -# Copyright 2024 The kauldron Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -r"""Temporary adaptation of the mnist_autoencoder example to debug OSS issues. - -Run: - python main.py --cfg=examples/test/mnist_autoencoder.py \ - --cfg.workdir=/tmp/kauldron_oss/workdir - -TODO(klausg): remove this file once the original example is working. -""" - -from kauldron import konfig - -# pylint: disable=g-import-not-at-top -with konfig.imports(): - from flax import linen as nn - from kauldron import kd - import optax -# pylint: enable=g-import-not-at-top - - -def get_config(): - """Get the default hyperparameter configuration.""" - cfg = kd.train.Trainer() - cfg.seed = 42 - - # Dataset - cfg.train_ds = _make_ds(training=True) - - # Model - cfg.model = kd.nn.FlatAutoencoder( - inputs="batch.image", - encoder=nn.Dense(features=128), - decoder=nn.Dense(features=28 * 28), - ) - - # Training - cfg.num_train_steps = 1000 - - # Losses - cfg.train_losses = { - "recon": kd.losses.L2(preds="preds.image", targets="batch.image"), - } - - cfg.train_metrics = { - "latent_norm": kd.metrics.Norm(tensor="interms.encoder.__call__[0]"), - "param_norm": kd.metrics.TreeMap( - metric=kd.metrics.Norm(tensor="params", axis=None) - ), - "grad_norm": kd.metrics.TreeReduce( - metric=kd.metrics.Norm(tensor="grads", axis=None) - ), - } - - cfg.train_summaries = { - "gt": kd.summaries.ShowImages(images="batch.image", num_images=5), - "recon": kd.summaries.ShowImages(images="preds.image", num_images=5), - } - - cfg.writer = kd.train.metric_writer.NoopWriter() - - # Optimizer - cfg.schedules = {} - - cfg.optimizer = optax.adam(learning_rate=0.003) - - # Checkpointer - cfg.checkpointer = kd.ckpts.Checkpointer( - save_interval_steps=500, - ) - - cfg.evals = { - "eval": kd.evals.Evaluator( - run=kd.evals.EveryNSteps(100), - num_batches=None, - ds=_make_ds(training=False), - metrics={}, - ) - } - - return cfg - - -def _make_ds(training: bool): - return kd.data.py.Tfds( - name="mnist", - split="train" if training else "test", - shuffle=True if training else False, - num_epochs=None if training else 1, - transforms=[ - kd.data.Elements(keep=["image"]), - kd.data.ValueRange(key="image", vrange=(0, 1)), - ], - batch_size=256, - ) diff --git a/examples/test/mock_data.py b/examples/test/mock_data.py deleted file mode 100644 index 5ca964a2..00000000 --- a/examples/test/mock_data.py +++ /dev/null @@ -1,27 +0,0 @@ -# Copyright 2024 The kauldron Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Temporary helper functions to generate mock data for testing. - -TODO(klausg): remove once kd.data is working correctly. -""" - -import numpy as np - - -def get_mnist_mockup(): - return { - "image": np.ones((32, 28, 28, 1), dtype=np.float32), - "label": np.ones((32, 1), dtype=np.int32), - } diff --git a/kauldron/train/metric_writer.py b/kauldron/train/metric_writer.py index 5b575a26..bb6c90a1 100644 --- a/kauldron/train/metric_writer.py +++ b/kauldron/train/metric_writer.py @@ -358,15 +358,9 @@ class KDMetricWriter(MetadataWriter): add_artifacts: bool = True - flatboard_build_context: kdash.BuildContext = ( - config_util.ROOT_CFG_REF.setup.flatboard_build_context - ) - @functools.cached_property def _collection_path_prefix(self) -> str: - if self.flatboard_build_context.collection_path_prefix is None: - raise ValueError("collection_path_prefix must be set.") - return self.flatboard_build_context.collection_path_prefix + return "" @functools.cached_property def _scalar_datatable_name(self) -> str: @@ -416,27 +410,7 @@ def _array_writer(self) -> metric_writers.MetricWriter: def _create_datatable_writer( self, name: str, description: str ) -> metric_writers.MetricWriter: - if epy.is_test(): - return self._noop # Do not write to datatable inside tests - if status.on_xmanager: - if not status.is_lead_host: - return self._noop - if status.wid == 1 and self.add_artifacts: - status.xp.create_artifact( - artifact_type=xmanager_api.ArtifactType.ARTIFACT_TYPE_STORAGE2_BIGTABLE, - artifact=name, - description=description, - ) - keys = [("wid", status.wid)] - else: - keys = [] - - return metric_writers.AsyncWriter( - metric_writers.DatatableWriter( - datatable_name=name, - keys=keys, - ), - ) + return self._noop def write_summaries( self, @@ -504,12 +478,9 @@ def write_pointcloud( point_colors: Mapping[str, Array["n 3"]] | None = None, configs: Mapping[str, str | float | bool | None] | None = None, ) -> None: - self._tf_summary_writer.write_pointcloud( - step=step, - point_clouds=point_clouds, - point_colors=point_colors, - configs=configs, - ) + if not point_clouds: + return + logging.info("Pointcloud summary not supported.") def write_hparams(self, hparams: Mapping[str, Any]) -> None: self._log_writer.write_hparams(hparams) diff --git a/pyproject.toml b/pyproject.toml index 8e072746..3a4f6c4d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -42,8 +42,10 @@ dependencies = [ "regex", "scikit-image", "scikit-learn", + "tabulate", # used by pandas.DataFrame.to_markdown (for logging context) "tensorflow", - "tensorflow_datasets", + "tfds-nightly", # TODO(klausg): switch back to tensorflow_datasets>=4.9.7 + # once released: https://github.com/tensorflow/datasets/commit/d4bfd59863c6cb5b64d043b7cb6ab566e7d92440 "tqdm", # closest match to the internal typeguard "typeguard@git+https://github.com/agronholm/typeguard@0dd7f7510b7c694e66a0d17d1d58d185125bad5d",