Add MLflow log_model option #1544
base: main

Conversation
…cleans up the code a little and prevents us from having forked logic in Composer to fetch by run_id
What testing have you done? We need to make sure everything shows up properly end to end.
@@ -76,6 +76,11 @@ def _maybe_get_license_filename(

    If the license file does not exist, returns None.
    """
    # Early return if no local directory exists
this should never happen right?
assuming that is correct, please remove
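For context, the guard under discussion is a standard early-return pattern. A minimal sketch of the shape of the function (a hypothetical reconstruction for illustration, not the actual llm-foundry implementation):

```python
import os
from typing import Optional

def maybe_get_license_filename(local_dir: str) -> Optional[str]:
    """Return the license filename found in local_dir, or None if absent."""
    # Early return if no local directory exists -- this is the guard the
    # reviewers suspect is unreachable and suggest removing.
    if not os.path.isdir(local_dir):
        return None
    for name in os.listdir(local_dir):
        if name.lower().startswith('license'):
            return name
    return None
```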
Used mainly to log from a child process.

Inputs:
Follow the docstring format used elsewhere (e.g. Args:), including the indentation and such.
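What the suggested docstring format might look like, sketched with the parameter names visible in this diff (the signature itself is a hypothetical stand-in):

```python
def log_model_from_child(
    model_uri: str,
    name: str,
    await_creation_for: int = 300,
) -> None:
    """Log and register a model with MLflow.

    Used mainly to log from a child process.

    Args:
        model_uri (str): URI of the model artifacts to register.
        name (str): Name to register the model under.
        await_creation_for (int): Seconds to wait for model creation.
    """
```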
model_uri=model_uri,
name=name,
await_creation_for=await_creation_for,

logging.getLogger('llmfoundry').setLevel(python_logging_level)
it should still set the Composer log level too
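What the suggested change might look like: the `llmfoundry` logger name comes from the diff, and the `composer` logger name is assumed here (a sketch, not the actual patch):

```python
import logging

def set_child_process_log_levels(python_logging_level: int) -> None:
    # Set the llmfoundry level, as in the diff...
    logging.getLogger('llmfoundry').setLevel(python_logging_level)
    # ...and, per the review comment, the Composer level too, so the
    # child process logs match the parent for both libraries.
    logging.getLogger('composer').setLevel(python_logging_level)
```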
mlflow_logger.log_model(
    transformers_model=transformers_model_path,
    flavor=flavor,
    artifact_path='model',  # TODO: where should we define this parent dir name?
fix?
@@ -171,7 +226,7 @@ class HuggingFaceCheckpointer(Callback):

    def __init__(
        self,
-       save_folder: str,
+       save_folder: Optional[str],
this is probably a change for part 2, not for this PR?
)
with context_manager:
    new_model_instance.save_pretrained(temp_save_dir)
    original_tokenizer.save_pretrained(temp_save_dir)
should move the next if statement out too (the if new_model_instance... block)
@@ -702,14 +751,6 @@ def tensor_hook(
    True,
) if is_te_imported and state.precision == Precision.AMP_FP8 else contextlib.nullcontext(
)
with context_manager:
not sure if an equivalent to this is necessary or not, can you determine that?
'transformers_model_path':
    temp_save_dir,
'flavor':
    'peft' if self.using_peft else 'transformers',
we haven't added peft support for log_model in Composer yet. Will need to think about how to handle this.
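Until peft support lands in Composer's log_model, one interim option is to fail loudly rather than pass an unsupported flavor through. A sketch, with a hypothetical helper name and supported-flavor set:

```python
# Hypothetical: only the flavors Composer's log_model can handle today.
SUPPORTED_LOG_MODEL_FLAVORS = {'transformers'}

def resolve_flavor(using_peft: bool) -> str:
    """Pick the MLflow flavor, rejecting ones Composer cannot log yet."""
    flavor = 'peft' if using_peft else 'transformers'
    if flavor not in SUPPORTED_LOG_MODEL_FLAVORS:
        raise NotImplementedError(
            f'Composer log_model does not support the {flavor!r} flavor yet.')
    return flavor
```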
'peft' if self.using_peft else 'transformers',
'python_logging_level':
    logging.getLogger('llmfoundry').level,
'task':
we should just be unpacking (**) the self.mlflow_logging_config, rather than picking out pieces of it.
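The suggestion, sketched with a stand-in for mlflow_logger.log_model (the config keys shown here are hypothetical examples, not the real contents of self.mlflow_logging_config):

```python
# Hypothetical config contents, for illustration only.
mlflow_logging_config = {
    'task': 'llm/v1/completions',
    'metadata': {'source': 'hf_checkpointer'},
}

def log_model(**kwargs):
    """Stand-in for mlflow_logger.log_model; just records its kwargs."""
    return kwargs

# Unpack the whole config instead of picking out individual pieces:
received = log_model(
    transformers_model='/tmp/model',
    flavor='transformers',
    **mlflow_logging_config,
)
```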
Context

In order to support customers with sensitive storage network configurations, we have to use the log_model API. This will cause duplicate artifact uploads, which is not efficient, so we will reserve rolling this out only to customers who require it.

This PR contains the first of 2 changes: use log_model instead of uploading to MLflow artifacts. Instead of save_model, register_model, and uploading to UC directly via the remote uploader downloader object, this change simplifies the control logic with the mlflow.log_model function. This function is also critical to support secure training requirements, such as customer firewalls or private endpoints. Logging a model to MLflow will call the necessary steps to save and register a model for deployment. Intermediate checkpoints call log_model but do not register the model; that way, a user can still manually register their intermediate checkpoints for evaluation.

Testing
When incorporating this in MAPI, we should enable final_register_only to only upload using the log_model logic instead of uploading a duplicate copy to MLflow artifacts. All tests were done in AWS staging.

Works for older models
[Databricks staging] Llama3 8b
Run: llama3-log-model-xusOti
Llama3 8b was successfully deployed here: https://e2-dogfood.staging.cloud.databricks.com/ml/endpoints/test-log-model?o=6051921418418893.
Works for newest models with extra security
[MCT] Llama3.2 1b
Run: llama3-log-model-O50ClW
Experiment: https://dbc-559ffd80-2bfc.cloud.databricks.com/ml/experiments/2854093459220376?viewStateShareKey=55a332dc80d7200b6a6301d8f0163155ce9aac54d21436c9d292f0745e0bff05
Endpoint: https://dbc-559ffd80-2bfc.cloud.databricks.com/ml/endpoints/test-llama321b?o=7395834863327820