Skip to content

Commit

Permalink
FIX: Dependencies
Browse files Browse the repository at this point in the history
  • Loading branch information
User3574 committed Jul 26, 2023
1 parent 35dd2c9 commit 75f6fd3
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 3 deletions.
4 changes: 2 additions & 2 deletions ai/src/itwinai/backend/tensorflow/trainer.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
import logging
import tensorflow as tf

from ..components import Trainer

class TensorflowTrainer(Trainer):
def __init__(self, strategy, loss, epochs, batch_size, callbacks, optimizer, model_func, metrics_func):
def __init__(self, loss, epochs, batch_size, callbacks, optimizer, model_func, metrics_func, strategy=tf.distribute.MirroredStrategy()):
self.strategy = strategy
self.epochs = epochs
self.batch_size = batch_size
self.loss = loss
self.callbacks = callbacks
self.global_batch_size = self.batch_size * self.strategy.num_replicas_in_sync
self.optimizer = optimizer

# Create distributed TF vars
Expand Down
2 changes: 1 addition & 1 deletion use-cases/cyclones/env-files/tensorflow-env.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,4 @@ dependencies:
- scipy==1.9.3
- scikit-learn
- gdown
- git+https://github.com/User3574/C_T6.git#egg=itwinai&subdirectory=ai
- git+https://github.com/interTwin-eu/T6.5-AI-and-ML.git@backend_full#egg=itwinai&subdirectory=ai

0 comments on commit 75f6fd3

Please sign in to comment.