diff --git a/neural_lam/train_model.py b/neural_lam/train_model.py
index 6f97f5e8..bcf920da 100644
--- a/neural_lam/train_model.py
+++ b/neural_lam/train_model.py
@@ -22,6 +22,24 @@
     "hi_lam_parallel": HiLAMParallel,
 }
 
+class CustomMLFlowLogger(pl.loggers.MLFlowLogger):
+
+    def log_image(self, key, images):
+        import mlflow
+        import io
+        from PIL import Image
+        # Need to save the image to a temporary file, then log that file
+        # mlflow.log_image, should do this automatically, but it doesn't work
+        temporary_image = f"{key}.png"
+        images[0].savefig(temporary_image)
+
+        img = Image.open(temporary_image)
+        print(images)
+        print(images[0])
+        mlflow.log_image(img, f"{key}.png")
+
+        #mlflow.log_figure(images[0], key)
+
 
 def _setup_training_logger(config, datastore, args, run_name):
     if config.training.logger == "wandb":
@@ -36,17 +54,42 @@ def _setup_training_logger(config, datastore, args, run_name):
             raise ValueError(
                 "MLFlow logger requires a URL to the MLFlow server"
             )
-        logger = pl.loggers.MLFlowLogger(
+        # logger = pl.loggers.MLFlowLogger(
+        #     experiment_name=args.wandb_project,
+        #     tracking_uri=url,
+        # )
+        logger = CustomMLFlowLogger(
             experiment_name=args.wandb_project,
             tracking_uri=url,
         )
+        print(logger)
         logger.log_hyperparams(
             dict(training=vars(args), datastore=datastore._config)
         )
+        print("Logged hyperparams")
+        print(run_name)
+
+        print(logger.__str__)
+        # logger.log_image = log_image
 
     return logger
 
 
+# def log_image(key, images):
+#     import mlflow
+
+#     # Log the image
+#     # https://learn.microsoft.com/en-us/azure/machine-learning/how-to-log-view-metrics?view=azureml-api-2&tabs=interactive#log-images
+#     # For mlflow a matplotlib figure should use log_figure instead of log_image
+#     # Need to save the image to a temporary file, then log that file
+#     # mlflow.log_image, should do this automatically, but it doesn't work
+#     temporary_image = f"/tmp/key.png"
+#     images[0].savefig(temporary_image)
+
+#     mlflow.log_figure(temporary_image, key)
+#     mlflow.log_figure(images[0], key)
+
+
 @logger.catch
 def main(input_args=None):
     """Main function for training and evaluating models."""
@@ -301,7 +344,10 @@ def main(input_args=None):
     trainer = pl.Trainer(
         max_epochs=args.epochs,
         deterministic=True,
-        strategy="ddp",
+        #strategy="ddp",
+        #devices=2,
+        devices=[1, 3],
+        strategy="auto",
         accelerator=device_name,
         logger=training_logger,
         log_every_n_steps=1,
@@ -309,7 +355,7 @@
         check_val_every_n_epoch=args.val_interval,
         precision=args.precision,
     )
-
+    import ipdb
     # Only init once, on rank 0 only
     if trainer.global_rank == 0:
         utils.init_training_logger_metrics(
@@ -318,7 +364,8 @@
     if args.eval:
         trainer.test(model=model, datamodule=data_module, ckpt_path=args.load)
     else:
-        trainer.fit(model=model, datamodule=data_module, ckpt_path=args.load)
+        with ipdb.launch_ipdb_on_exception():
+            trainer.fit(model=model, datamodule=data_module, ckpt_path=args.load)
 
 
 if __name__ == "__main__":
diff --git a/pyproject.toml b/pyproject.toml
index b723e322..be26adb6 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -28,6 +28,7 @@ dependencies = [
     "dataclass-wizard>=0.22.3",
     "mllam-data-prep @ git+https://github.com/leifdenby/mllam-data-prep/@feat/extra-section-in-config",
     "mlflow>=2.16.2",
+    "boto3>=1.35.32",
 ]
 requires-python = ">=3.9"
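
For reference, a minimal sketch of the same image-logging workaround without the temporary PNG round-trip or the debug prints, assuming `images` holds matplotlib figures (as the `savefig` call in the diff implies). It goes through `MLFlowLogger.experiment`, which exposes an `MlflowClient`, and uses `MlflowClient.log_figure` so the figure is attached to the logger's own run:

    # Sketch only: log the matplotlib figure through the logger's MlflowClient.
    import pytorch_lightning as pl


    class CustomMLFlowLogger(pl.loggers.MLFlowLogger):
        """MLFlowLogger that accepts the wandb-style log_image(key, images) call."""

        def log_image(self, key, images):
            # self.experiment is an MlflowClient; log_figure stores the figure
            # as an artifact named "<key>.png" on this logger's run, so no
            # temporary file on disk is needed.
            self.experiment.log_figure(self.run_id, images[0], f"{key}.png")

Routing through `self.experiment`/`self.run_id` also keeps the image in the run created by the logger, whereas the fluent `mlflow.log_image` call in the diff writes to whichever run MLflow currently considers active (possibly a new, unrelated one).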