Skip to content

Commit

Permalink
upload of artifact to mlflow works, but instantiates a new experiment
Browse files Browse the repository at this point in the history
  • Loading branch information
khintz committed Oct 4, 2024
1 parent 3fbe2d0 commit e0284a8
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 4 deletions.
55 changes: 51 additions & 4 deletions neural_lam/train_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,24 @@
"hi_lam_parallel": HiLAMParallel,
}

class CustomMLFlowLogger(pl.loggers.MLFlowLogger):
    """MLFlow logger with ``log_image`` support.

    PyTorch Lightning's built-in ``MLFlowLogger`` has no ``log_image``
    method, so wandb-style ``logger.log_image(...)`` calls fail. This
    subclass adds one that logs through the logger's own MLflow client
    and run id, attaching the image to the *current* run instead of
    letting the fluent API spin up a new run/experiment.
    """

    def log_image(self, key, images):
        """Log a matplotlib figure as a PNG artifact on the active run.

        Parameters
        ----------
        key : str
            Artifact name; the image is stored as ``{key}.png``.
        images : list
            List of matplotlib figures; only the first is logged
            (mirrors the wandb logger's list-based interface).
        """
        import io

        from PIL import Image

        # Render the figure to an in-memory PNG rather than writing a
        # temporary "{key}.png" file to the working directory (which was
        # never cleaned up).
        buffer = io.BytesIO()
        images[0].savefig(buffer, format="png")
        buffer.seek(0)
        img = Image.open(buffer)

        # Log via the logger's MlflowClient (`self.experiment`) bound to
        # `self.run_id`. The module-level `mlflow.log_image` uses the
        # fluent API, which has no active run here and therefore creates
        # a brand-new run in the default experiment — the bug this fixes.
        self.experiment.log_image(self.run_id, img, f"{key}.png")


def _setup_training_logger(config, datastore, args, run_name):
if config.training.logger == "wandb":
Expand All @@ -36,17 +54,42 @@ def _setup_training_logger(config, datastore, args, run_name):
raise ValueError(
"MLFlow logger requires a URL to the MLFlow server"
)
logger = pl.loggers.MLFlowLogger(
# logger = pl.loggers.MLFlowLogger(
# experiment_name=args.wandb_project,
# tracking_uri=url,
# )
logger = CustomMLFlowLogger(
experiment_name=args.wandb_project,
tracking_uri=url,
)
print(logger)
logger.log_hyperparams(
dict(training=vars(args), datastore=datastore._config)
)
print("Logged hyperparams")
print(run_name)

print(logger.__str__)
# logger.log_image = log_image

return logger


# def log_image(key, images):
# import mlflow

# # Log the image
# # https://learn.microsoft.com/en-us/azure/machine-learning/how-to-log-view-metrics?view=azureml-api-2&tabs=interactive#log-images
# # For mlflow a matplotlib figure should use log_figure instead of log_image
# # Need to save the image to a temporary file, then log that file
# # mlflow.log_image, should do this automatically, but it doesn't work
# temporary_image = f"/tmp/key.png"
# images[0].savefig(temporary_image)

# mlflow.log_figure(temporary_image, key)
# mlflow.log_figure(images[0], key)


@logger.catch
def main(input_args=None):
"""Main function for training and evaluating models."""
Expand Down Expand Up @@ -301,15 +344,18 @@ def main(input_args=None):
trainer = pl.Trainer(
max_epochs=args.epochs,
deterministic=True,
strategy="ddp",
#strategy="ddp",
#devices=2,
devices=[1, 3],
strategy="auto",
accelerator=device_name,
logger=training_logger,
log_every_n_steps=1,
callbacks=[checkpoint_callback],
check_val_every_n_epoch=args.val_interval,
precision=args.precision,
)

import ipdb
# Only init once, on rank 0 only
if trainer.global_rank == 0:
utils.init_training_logger_metrics(
Expand All @@ -318,7 +364,8 @@ def main(input_args=None):
if args.eval:
trainer.test(model=model, datamodule=data_module, ckpt_path=args.load)
else:
trainer.fit(model=model, datamodule=data_module, ckpt_path=args.load)
with ipdb.launch_ipdb_on_exception():
trainer.fit(model=model, datamodule=data_module, ckpt_path=args.load)


if __name__ == "__main__":
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ dependencies = [
"dataclass-wizard>=0.22.3",
"mllam-data-prep @ git+https://github.com/leifdenby/mllam-data-prep/@feat/extra-section-in-config",
"mlflow>=2.16.2",
"boto3>=1.35.32",
]
requires-python = ">=3.9"

Expand Down

0 comments on commit e0284a8

Please sign in to comment.