diff --git a/ai/src/itwinai/backend/tensorflow/trainer.py b/ai/src/itwinai/backend/tensorflow/trainer.py index e8d3731b..0c6b78c0 100644 --- a/ai/src/itwinai/backend/tensorflow/trainer.py +++ b/ai/src/itwinai/backend/tensorflow/trainer.py @@ -1,15 +1,15 @@ import logging +import tensorflow as tf from ..components import Trainer class TensorflowTrainer(Trainer): - def __init__(self, strategy, loss, epochs, batch_size, callbacks, optimizer, model_func, metrics_func): + def __init__(self, loss, epochs, batch_size, callbacks, optimizer, model_func, metrics_func, strategy=tf.distribute.MirroredStrategy()): self.strategy = strategy self.epochs = epochs self.batch_size = batch_size self.loss = loss self.callbacks = callbacks - self.global_batch_size = self.batch_size * self.strategy.num_replicas_in_sync self.optimizer = optimizer # Create distributed TF vars diff --git a/use-cases/cyclones/env-files/tensorflow-env.yml b/use-cases/cyclones/env-files/tensorflow-env.yml index 920a1750..6d38c6ad 100644 --- a/use-cases/cyclones/env-files/tensorflow-env.yml +++ b/use-cases/cyclones/env-files/tensorflow-env.yml @@ -10,4 +10,4 @@ dependencies: - scipy==1.9.3 - scikit-learn - gdown - - git+https://github.com/User3574/C_T6.git#egg=itwinai&subdirectory=ai \ No newline at end of file + - git+https://github.com/interTwin-eu/T6.5-AI-and-ML.git@backend_full#egg=itwinai&subdirectory=ai \ No newline at end of file