diff --git a/ai/src/itwinai/backend/tensorflow/trainer.py b/ai/src/itwinai/backend/tensorflow/trainer.py index 14c1cdf0..a5bd8858 100644 --- a/ai/src/itwinai/backend/tensorflow/trainer.py +++ b/ai/src/itwinai/backend/tensorflow/trainer.py @@ -14,7 +14,7 @@ def __init__( optimizer, model_func, metrics_func, - strategy=tf.distribute.MirroredStrategy() + strategy ): self.strategy = strategy self.epochs = epochs @@ -24,7 +24,11 @@ def __init__( self.optimizer = optimizer # Create distributed TF vars - with self.strategy.scope(): + if self.strategy: + with self.strategy.scope(): + self.model = model_func() + self.model.compile(loss=self.loss, optimizer=self.optimizer, metrics=metrics_func()) + else: self.model = model_func() self.model.compile(loss=self.loss, optimizer=self.optimizer, metrics=metrics_func()) print(f"Strategy is working with: {strategy.num_replicas_in_sync} devices") @@ -41,8 +45,9 @@ def train(self, data): n_test = test.cardinality().numpy() # Distribute dataset - train = self.strategy.experimental_distribute_dataset(train) - test = self.strategy.experimental_distribute_dataset(test) + if self.strategy: + train = self.strategy.experimental_distribute_dataset(train) + test = self.strategy.experimental_distribute_dataset(test) # compute the steps per epoch for train and valid train_steps = n_train // self.batch_size diff --git a/use-cases/mnist/tensorflow/env-files/tensorflow-env.yml b/use-cases/mnist/tensorflow/env-files/tensorflow-env.yml index d326eb84..23963af5 100644 --- a/use-cases/mnist/tensorflow/env-files/tensorflow-env.yml +++ b/use-cases/mnist/tensorflow/env-files/tensorflow-env.yml @@ -3,7 +3,7 @@ channels: - tensorflow - conda-forge dependencies: - - tensorflow + - tensorflow-gpu - python=3.9.12 - pip: - tfx diff --git a/use-cases/mnist/tensorflow/trainer.py b/use-cases/mnist/tensorflow/trainer.py index b0a1b5ed..ff942c1a 100644 --- a/use-cases/mnist/tensorflow/trainer.py +++ b/use-cases/mnist/tensorflow/trainer.py @@ -27,7 +27,7 @@ def __init__( optimizer=keras.optimizers.get(optimizer), model_func=lambda: model, metrics_func=lambda: [], - strategy=tf.distribute.MirroredStrategy() + strategy=None ) def train(self, data):