UPDATE: Optional distribution, tensorflow-gpu
User3574 committed Aug 1, 2023
1 parent 905bdbe commit 605df4c
Showing 3 changed files with 46 additions and 23 deletions.
34 changes: 28 additions & 6 deletions ai/src/itwinai/backend/tensorflow/trainer.py
@@ -3,25 +3,47 @@

 from ..components import Trainer


 class TensorflowTrainer(Trainer):
-    def __init__(self, loss, epochs, batch_size, callbacks, optimizer, model_func, metrics_func, strategy=tf.distribute.MirroredStrategy()):
+    def __init__(
+        self,
+        loss,
+        epochs,
+        batch_size,
+        callbacks,
+        optimizer,
+        model_func,
+        metrics_func,
+        strategy
+    ):
         self.strategy = strategy
         self.epochs = epochs
         self.batch_size = batch_size
         self.loss = loss
         self.callbacks = callbacks
         self.optimizer = optimizer

         # TODO: Wrap strategy into a class inheritance -> Strategy, MultiGPU(Strategy) -> Can be instantiated via conf
         # Create distributed TF vars
-        with self.strategy.scope():
+        if self.strategy:
+            with self.strategy.scope():
+                self.model = model_func()
+                self.model.compile(loss=self.loss, optimizer=self.optimizer, metrics=metrics_func())
+        # Run locally
+        else:
             self.model = model_func()
             self.model.compile(loss=self.loss, optimizer=self.optimizer, metrics=metrics_func())
-        print(f"Strategy is working with: {strategy.num_replicas_in_sync} devices")
+
+        num_devices = strategy.num_replicas_in_sync if strategy else 1
+        print(f"Strategy is working with: {num_devices} devices")

     def train(self, data):
         (train, n_train), (test, n_test) = data
-        #train = self.strategy.experimental_distribute_dataset(train)
-        #test = self.strategy.experimental_distribute_dataset(test)
+
+        # Distribute datasets
+        if self.strategy:
+            train = self.strategy.experimental_distribute_dataset(train)
+            test = self.strategy.experimental_distribute_dataset(test)

         # compute the steps per epoch for train and valid
         train_steps = n_train // self.batch_size
@@ -44,4 +44,4 @@ def execute(self, args):
         raise "Not implemented!"

     def setup(self, args):
-        raise "Not implemented!"
\ No newline at end of file
+        raise "Not implemented!"
4 changes: 2 additions & 2 deletions use-cases/cyclones/env-files/tensorflow-env.yml
@@ -1,9 +1,9 @@
 name: tensorflow-env
 channels:
-  - tensorflow
+  - tensorflow-gpu
   - conda-forge
 dependencies:
-  - tensorflow
+  - tensorflow-gpu
   - python=3.9.12
   - pip:
       - tfx
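
After switching the environment from tensorflow to tensorflow-gpu, a quick sanity check that TensorFlow actually sees the GPUs; this is a generic snippet, not part of the commit:

import tensorflow as tf

# Empty list means TensorFlow was installed without GPU support,
# or no GPU is visible to this process.
gpus = tf.config.list_physical_devices("GPU")
print(f"Visible GPUs: {len(gpus)}")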
31 changes: 16 additions & 15 deletions use-cases/cyclones/trainer.py
@@ -10,24 +10,25 @@
 )
 from itwinai.backend.tensorflow.trainer import TensorflowTrainer

+
 class CyclonesTrainer(TensorflowTrainer):
     def __init__(
-            self,
-            RUN_DIR,
-            epochs,
-            network,
-            activation,
-            regularization_strength,
-            learning_rate: float,
-            loss,
-            channels,
-            batch_size,
-            patch_size,
-            kernel_size: int = None,
+        self,
+        RUN_DIR,
+        epochs,
+        network,
+        activation,
+        regularization_strength,
+        learning_rate: float,
+        loss,
+        channels,
+        batch_size,
+        patch_size,
+        kernel_size: int = None,
     ):
         # Configurable
         regularization_strength, regularizer = \
-                [rg.value for rg in RegularizationStrength if rg.name.lower() == regularization_strength][0]
+            [rg.value for rg in RegularizationStrength if rg.name.lower() == regularization_strength][0]
         loss_name, loss = [l.value for l in Losses if l.name.lower() == loss][0]
         optimizer = keras.optimizers.Adam(learning_rate=learning_rate)

@@ -40,7 +41,7 @@ def __init__(
         BENCHMARK_HISTORY_CSV = join(RUN_DIR, "benchmark_history.csv")

         super().__init__(
-            strategy=tf.distribute.MirroredStrategy(),
+            strategy=None, # tf.distribute.MirroredStrategy(),
             loss=loss,
             epochs=epochs,
             batch_size=batch_size,
@@ -73,7 +74,7 @@ def __init__(
                 kernel_size=kernel_size,
                 channels=channels,
             ),
-            metrics_func=lambda: [keras.metrics.MeanAbsoluteError(name="mae")]
+            metrics_func=lambda: [keras.metrics.MeanAbsoluteError(name="mae")],
         )

     def train(self, data):
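The net effect of this commit on the cyclones use case: CyclonesTrainer now hardcodes strategy=None, so the model is built and trained locally, and passing the commented-out MirroredStrategy back in re-enables multi-GPU runs. A minimal usage sketch of the base class in both modes; the model and argument values are placeholders, not from this repository:

import tensorflow as tf
from tensorflow import keras
from itwinai.backend.tensorflow.trainer import TensorflowTrainer


def make_model():
    return keras.Sequential([keras.layers.Dense(1, input_shape=(4,))])


common = dict(
    loss="mse",
    epochs=2,
    batch_size=32,
    callbacks=[],
    optimizer="adam",
    model_func=make_model,
    metrics_func=lambda: [keras.metrics.MeanAbsoluteError(name="mae")],
)

# Local run: no distribution scope, one device.
local_trainer = TensorflowTrainer(strategy=None, **common)

# Distributed run: variables mirrored across all visible GPUs.
distributed_trainer = TensorflowTrainer(strategy=tf.distribute.MirroredStrategy(), **common)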
