From e7fc9f2b99e218b1c2eceaf08e7666a6784eca9b Mon Sep 17 00:00:00 2001 From: Jeff Shepherd Date: Thu, 24 Oct 2024 16:05:53 -0700 Subject: [PATCH 1/2] Upgraded distributed tensorflow sample to 2.16 --- .../tensorflow/mnist-distributed/src/main.py | 17 +++++++++++++++-- .../tensorflow-mnist-distributed.ipynb | 2 +- 2 files changed, 16 insertions(+), 3 deletions(-) diff --git a/sdk/python/jobs/single-step/tensorflow/mnist-distributed/src/main.py b/sdk/python/jobs/single-step/tensorflow/mnist-distributed/src/main.py index 82e1441baf..6ea3b21c65 100644 --- a/sdk/python/jobs/single-step/tensorflow/mnist-distributed/src/main.py +++ b/sdk/python/jobs/single-step/tensorflow/mnist-distributed/src/main.py @@ -79,6 +79,17 @@ def write_filepath(filepath, task_type, task_id): return os.path.join(dirpath, base) +def fix_tf_config(): + # This is necessary for TensorFlow 2.13 and later + tf_config = json.loads(os.environ["TF_CONFIG"]) + if "cluster" in tf_config: + cluster = tf_config["cluster"] + if "ps" in cluster and len(cluster["ps"]) == 0: + cluster.pop("ps") + os.environ["TF_CONFIG"] = json.dumps(tf_config) + return tf_config + + def main(): parser = argparse.ArgumentParser() parser.add_argument("--epochs", type=int, default=3) @@ -93,14 +104,16 @@ def main(): args = parser.parse_args() - tf_config = json.loads(os.environ["TF_CONFIG"]) + tf_config = fix_tf_config() + num_workers = len(tf_config["cluster"]["worker"]) - strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy() + strategy = tf.distribute.MultiWorkerMirroredStrategy() # Here the batch size scales up by number of workers since # `tf.data.Dataset.batch` expects the global batch size. global_batch_size = args.per_worker_batch_size * num_workers + multi_worker_dataset = mnist_dataset(global_batch_size) with strategy.scope(): diff --git a/sdk/python/jobs/single-step/tensorflow/mnist-distributed/tensorflow-mnist-distributed.ipynb b/sdk/python/jobs/single-step/tensorflow/mnist-distributed/tensorflow-mnist-distributed.ipynb index f681ee2087..a8c3b04e00 100644 --- a/sdk/python/jobs/single-step/tensorflow/mnist-distributed/tensorflow-mnist-distributed.ipynb +++ b/sdk/python/jobs/single-step/tensorflow/mnist-distributed/tensorflow-mnist-distributed.ipynb @@ -130,7 +130,7 @@ " code=\"./src\", # local path where the code is stored\n", " command=\"python main.py --epochs ${{inputs.epochs}} --model-dir ${{inputs.model_dir}}\",\n", " inputs={\"epochs\": 1, \"model_dir\": \"outputs/keras-model\"},\n", - " environment=\"AzureML-tensorflow-2.12-cuda11@latest\",\n", + " environment=\"AzureML-tensorflow-2.16-cuda11@latest\",\n", " compute=\"cpu-cluster\",\n", " instance_count=2,\n", " # distribution = {\"type\": \"mpi\", \"process_count_per_instance\": 1},\n", From 69eed0a9118f076f6ed1f2d6e37c1e460423dafe Mon Sep 17 00:00:00 2001 From: Jeff Shepherd Date: Thu, 24 Oct 2024 16:09:02 -0700 Subject: [PATCH 2/2] Removed unintended blank line --- .../jobs/single-step/tensorflow/mnist-distributed/src/main.py | 1 - 1 file changed, 1 deletion(-) diff --git a/sdk/python/jobs/single-step/tensorflow/mnist-distributed/src/main.py b/sdk/python/jobs/single-step/tensorflow/mnist-distributed/src/main.py index 6ea3b21c65..032a6ebffc 100644 --- a/sdk/python/jobs/single-step/tensorflow/mnist-distributed/src/main.py +++ b/sdk/python/jobs/single-step/tensorflow/mnist-distributed/src/main.py @@ -113,7 +113,6 @@ def main(): # Here the batch size scales up by number of workers since # `tf.data.Dataset.batch` expects the global batch size. global_batch_size = args.per_worker_batch_size * num_workers - multi_worker_dataset = mnist_dataset(global_batch_size) with strategy.scope():