Skip to content

Commit

Permalink
ADD: Optional strategy
Browse files Browse the repository at this point in the history
  • Loading branch information
User3574 committed Aug 1, 2023
1 parent 114a4c0 commit 0a501f9
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 6 deletions.
13 changes: 9 additions & 4 deletions ai/src/itwinai/backend/tensorflow/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def __init__(
optimizer,
model_func,
metrics_func,
strategy=tf.distribute.MirroredStrategy()
strategy
):
self.strategy = strategy
self.epochs = epochs
Expand All @@ -24,7 +24,11 @@ def __init__(
self.optimizer = optimizer

# Create distributed TF vars
with self.strategy.scope():
if self.strategy:
with self.strategy.scope():
self.model = model_func()
self.model.compile(loss=self.loss, optimizer=self.optimizer, metrics=metrics_func())
else:
self.model = model_func()
self.model.compile(loss=self.loss, optimizer=self.optimizer, metrics=metrics_func())
print(f"Strategy is working with: {strategy.num_replicas_in_sync} devices")
Expand All @@ -41,8 +45,9 @@ def train(self, data):
n_test = test.cardinality().numpy()

# Distribute dataset
train = self.strategy.experimental_distribute_dataset(train)
test = self.strategy.experimental_distribute_dataset(test)
if self.strategy:
train = self.strategy.experimental_distribute_dataset(train)
test = self.strategy.experimental_distribute_dataset(test)

# compute the steps per epoch for train and valid
train_steps = n_train // self.batch_size
Expand Down
2 changes: 1 addition & 1 deletion use-cases/mnist/tensorflow/env-files/tensorflow-env.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ channels:
- tensorflow
- conda-forge
dependencies:
- tensorflow
- tensorflow-gpu
- python=3.9.12
- pip:
- tfx
Expand Down
2 changes: 1 addition & 1 deletion use-cases/mnist/tensorflow/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def __init__(
optimizer=keras.optimizers.get(optimizer),
model_func=lambda: model,
metrics_func=lambda: [],
strategy=tf.distribute.MirroredStrategy()
strategy=None
)

def train(self, data):
Expand Down

0 comments on commit 0a501f9

Please sign in to comment.