-
Notifications
You must be signed in to change notification settings - Fork 44
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
WIP: Add support for mlflow #77
base: main
Are you sure you want to change the base?
Conversation
…on/neural-lam into feature_dataset_yaml
tracking metrics is disabled currently because neural-lam previously used a Logger.define_metrics method which isn't available with the mlflow logger in pytorch-lightning as far as I'm aware
WIP, mlflow logger still not working, but got wandb working with the pytorch_lightning wandb logger. |
This comment was marked as resolved.
This comment was marked as resolved.
I now got model metrics, system metrics and artifacts logging (including model logging) supported for mlflow. See e.g: However I get this warning:
I am calling a training_logger.log_model(model) which is def log_model(self, model):
mlflow.pytorch.log_model(model, "model") But I need to set the signature.
It should be possible to use Any thoughts @joeloskarsson, @sadamov, @TomasLandelius ? |
@khintz Thanks for adding mlflow to the list of loggers, it's nice to give the user more choice. And clearly you already got most of the work done 🚀 . About this warning you are seeing: I don't think manually specifying the signatures is a good idea, as it is too error prone. How long would it take to use a single example as a signature to pass to mlflow with smth like this: Modify class CustomMLFlowLogger(pl.loggers.MLFlowLogger):
def __init__(self, experiment_name, tracking_uri, data_module):
super().__init__(experiment_name=experiment_name, tracking_uri=tracking_uri)
mlflow.start_run(run_id=self.run_id, log_system_metrics=True)
mlflow.log_param("run_id", self.run_id)
self.data_module = data_module
def log_image(self, key, images):
from PIL import Image
temporary_image = f"{key}.png"
images[0].savefig(temporary_image)
mlflow.log_image(Image.open(temporary_image), f"{key}.png")
def log_model(self, model):
input_example = self.create_input_example()
with torch.no_grad():
model_output = model(*input_example)
#TODO: Are we sure we can hardcode the input names?
signature = infer_signature(
{name: tensor.cpu().numpy() for name, tensor in zip(['init_states', 'target_states', 'forcing', 'target_times'], input_example)},
model_output.cpu().numpy()
)
mlflow.pytorch.log_model(
model,
"model",
input_example=input_example,
signature=signature
)
def create_input_example(self):
if self.data_module.val_dataset is None:
self.data_module.setup(stage="fit")
return self.data_module.val_dataset[0] |
From my understanding you don't need to feed the whole dataset to the model to infer this signature, only one example batch. Going by this, something like what @sadamov proposed should work. However:
I agree. Optimally we would even get rid of the hard-coded argument names in the zip from @sadamov 's code (but I don't have an immediate idea how to do that). Something else to consider here is that there are additional important inputs that are necessary to make a forecast with the model (that do not enter as arguments when calling the
I don't know if these (or rather their shape) should be considered for the third part of the model signature ("Parameters (params)"), or somehow also viewed as part of the input. But I also fear that including these might just make this complex enough that this signature is no longer particularly useful. I think we should be motivated by how useful we actually find this signature to be. If we just want to get rid of the warning maybe we don't have to worry about these. |
Describe your changes
Add support for mlflow logger by utilising pytorch_lightning.loggers
The native wandb module is replaced with pytorch_lightning wandb logger and introducing pytorch_lightning mlflow logger.
https://github.com/Lightning-AI/pytorch-lightning/blob/master/src/lightning/pytorch/loggers/logger.py
This will allow people to choose between wandb and mlflow.
Builds upon #66 although this is not strictly necessary for this change, but I am working with this feature to work with our dataset.
Issue Link
Closes #76
Type of change
Checklist before requesting a review
pull
with--rebase
option if possible).Checklist for reviewers
Each PR comes with its own improvements and flaws. The reviewer should check the following:
Author checklist after completed review
reflecting type of change (add section where missing):
Checklist for assignee