Add MLflow log_model option #1544
base: main
Changes from 6 commits
@@ -114,7 +114,7 @@ def _register_model_with_run_id_multiprocess(
    name: str,
    await_creation_for: int,
):
-   """Call MLFlowLogger.register_model_with_run_id.
+   """Call MLFlowLogger.register_model.

    Used mainly to register from a child process.
    """
@@ -135,6 +135,88 @@ def _register_model_with_run_id_multiprocess(
    )

def _log_model_multiprocess(
    mlflow_logger: MLFlowLogger,
    composer_logging_level: int,
    flavor: str,
    input_example: dict[str, Any],
    log_model_metadata: dict[str, str],
    task: str,
    await_creation_for: int,
    registered_model_name: Optional[str] = None,
):
    """
    Call MLFlowLogger.log_model.

    Used mainly to log from a child process.

    Inputs:
    - mlflow_logger: MLFlowLogger: MLflow logger object
    - composer_logging_level: int: logging level for composer
    - flavor: str: transformers or peft
    - input_example: dict[str, Any]: model serving input example for model
    - log_model_metadata: dict[str, str]: This metadata is currently needed for optimized serving
    - task: str: LLM task for model deployment (i.e. chat or completions)
    - await_creation_for: int: time to wait for model creation
    - registered_model_name: Optional
    """
    # Setup logging for child process. This ensures that any logs from composer are surfaced.
    if composer_logging_level > 0:
        # If logging_level is 0, then the composer logger was unset.
        logging.basicConfig(
            format=
            f'%(asctime)s: rank{dist.get_global_rank()}[%(process)d][%(threadName)s]: %(levelname)s: %(name)s: %(message)s',
        )
        logging.getLogger('composer').setLevel(composer_logging_level)

    # Monkey patch to prevent duplicate tokenizer upload.
    import mlflow
    original_save_model = mlflow.transformers.save_model
    def save_model_patch(*args, **kwargs):
        original_save_model(*args, **kwargs)
        print(f"List of root path: {os.listdir(kwargs['path'])}")
        components_path = os.path.join(kwargs['path'], 'components')
        if os.path.exists(components_path):
            print(f"List of components path: {components_path}: {os.listdir(components_path)}")
        tokenizer_path = os.path.join(kwargs['path'], 'components', 'tokenizer')
        if os.path.exists(tokenizer_path):
            tokenizer_files = os.listdir(os.path.join(kwargs['path'], 'components', 'tokenizer'))
            print(f"Tokenizer files: {tokenizer_files}")
            try:
                print(f"List of model/ files: {os.listdir(os.path.join(kwargs['path'], 'model'))}")
            except Exception as e:
                print('exception', e)
            # TODO: what do we do here in code??
            # Remove any tokenizer file that was also written under model/.
            for tokenizer_file_name in tokenizer_files:
                try:
                    dupe_file = os.path.isfile(os.path.join(kwargs['path'], 'model', tokenizer_file_name))
                    if dupe_file:
                        os.remove(os.path.join(kwargs['path'], 'model', tokenizer_file_name))
                except Exception as e:
                    print('exception', e)
    mlflow.transformers.save_model = save_model_patch

    if registered_model_name is not None:

Review comment: you don't need a big if/else here. presumably log_model is fine with you passing [...]

        mlflow_logger.log_model(
            flavor=flavor,
            artifact_path='model',  # TODO: where should we define this parent dir name?

Review comment: fix?

            input_example=input_example,
            metadata=log_model_metadata,
            task=task,
            registered_model_name=registered_model_name,
            await_creation_for=await_creation_for
        )
    else:
        mlflow_logger.log_model(
            flavor=flavor,
            artifact_path='model',  # TODO: where should we define this parent dir name?

Review comment: fix?

            input_example=input_example,
            metadata=log_model_metadata,
            task=task,
            await_creation_for=await_creation_for
        )

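As the review comment above suggests, the two branches could presumably be collapsed by forwarding registered_model_name unchanged, with None meaning "log without registering". A minimal sketch of that simplification, not part of the diff:

    # Hypothetical drop-in replacement for the if/else above; assumes log_model
    # accepts registered_model_name=None and simply skips registration in that case.
    mlflow_logger.log_model(
        flavor=flavor,
        artifact_path='model',
        input_example=input_example,
        metadata=log_model_metadata,
        task=task,
        registered_model_name=registered_model_name,
        await_creation_for=await_creation_for,
    )
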
class HuggingFaceCheckpointer(Callback):
    """Save a huggingface formatted checkpoint during training.

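Separately, the deduplication performed by the patched save_model in _log_model_multiprocess (removing tokenizer files that are also written under model/) could be factored into a small helper. A minimal sketch under the same assumptions as the diff; the helper name is hypothetical and does not appear in the PR:

    import os

    def _remove_duplicate_tokenizer_files(save_path: str) -> None:
        # Tokenizer files live under components/tokenizer; anything with the same
        # name under model/ is a duplicate and can be deleted before upload.
        tokenizer_dir = os.path.join(save_path, 'components', 'tokenizer')
        model_dir = os.path.join(save_path, 'model')
        if not os.path.isdir(tokenizer_dir) or not os.path.isdir(model_dir):
            return
        for file_name in os.listdir(tokenizer_dir):
            duplicate = os.path.join(model_dir, file_name)
            if os.path.isfile(duplicate):
                os.remove(duplicate)
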
@@ -202,6 +284,7 @@ def __init__(
                +
                f'Defaulting to final_register_only=False and saving the HuggingFace checkpoint to {save_folder=}.',
            )
        self.use_mlflow_log_model = False

        # mlflow config setup
        if mlflow_logging_config is None:
@@ -232,6 +315,8 @@ def __init__( | |
'input_example', | ||
default_input_example, | ||
) | ||
if mlflow_logging_config['use_mlflow_log_model']: | ||
self.use_mlflow_log_model = True | ||
|
||
self.mlflow_logging_config = mlflow_logging_config | ||
if 'metadata' in self.mlflow_logging_config: | ||
|
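For context, enabling the new path goes through the callback's mlflow_logging_config. A hypothetical example of what that config might look like; only the use_mlflow_log_model, input_example, and metadata['task'] keys are taken from this diff, and the concrete values are illustrative:

    mlflow_logging_config = {
        'use_mlflow_log_model': True,                 # opt in to the log_model path
        'metadata': {'task': 'llm/v1/completions'},   # task consumed by _log_model_multiprocess (value illustrative)
        'input_example': {'prompt': ['What is MLflow?']},  # serving input example (value illustrative)
    }
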
@@ -639,23 +724,29 @@ def tensor_hook(
                    self.flatten_imports,
                )

            if self.remote_ud is not None:
                for filename in os.listdir(temp_save_dir):
                    remote_file_name = os.path.join(save_dir, filename)
                    remote_file_uri = self.remote_ud.remote_backend.get_uri(
                        remote_file_name,
                    )
                    log.info(
                        f'Uploading HuggingFace formatted checkpoint to {remote_file_uri}',
                    )
                    self.remote_ud.upload_file(
                        state=state,
                        remote_file_name=remote_file_name,
                        file_path=Path(
                            os.path.join(temp_save_dir, filename),
                        ),
                        overwrite=self.overwrite,
                    )

            # TODO: Log the model without registering
            for i, mlflow_logger in enumerate(self.mlflow_loggers):
                process = SpawnProcess(
                    target=_log_model_multiprocess,
                    kwargs={
                        'mlflow_logger': mlflow_logger,
                        'flavor': 'peft' if self.using_peft else 'transformers',
                        'composer_logging_level': logging.getLogger('composer').level,
                        'task': self.mlflow_logging_config['metadata']['task'],
                        'log_model_metadata': self.mlflow_logging_config['metadata'],
                        'model_name': self.pretrained_model_name,
                        'input_example': self.mlflow_logging_config['input_example'],
                        'await_creation_for': 3600,
                    },
                )
                process.start()

            dist.barrier()

@@ -729,22 +820,35 @@ def tensor_hook( | |
monitor_process = None | ||
|
||
# Spawn a new process to register the model. | ||
process = SpawnProcess( | ||
target=_register_model_with_run_id_multiprocess, | ||
kwargs={ | ||
'mlflow_logger': | ||
mlflow_logger, | ||
'composer_logging_level': | ||
logging.getLogger('composer').level, | ||
'model_uri': | ||
local_save_path, | ||
'name': | ||
self.mlflow_registered_model_name, | ||
'await_creation_for': | ||
3600, | ||
}, | ||
) | ||
process.start() | ||
# TODO: how do we fix intermediate checkpointing logic to use this too but not register | ||
# the model with that param | ||
# TODO: pass in model correctly | ||
# Slower method to register the model via log_model. | ||
if self.use_mlflow_log_model: | ||
nancyhung marked this conversation as resolved.
Show resolved
Hide resolved
|
||
print("----------------- REACHED MLFLOW LOG MODEL -----------------") | ||
process = SpawnProcess( | ||
target=_log_model_multiprocess, | ||
kwargs={ | ||
'mlflow_logger': | ||
mlflow_logger, | ||
'flavor': | ||
'peft' if self.using_peft else 'transformers', | ||
'composer_logging_level': | ||
logging.getLogger('composer').level, | ||
'task': | ||
self.mlflow_logging_config['metadata']['task'], | ||
'log_model_metadata': self.mlflow_logging_config['metadata'], | ||
'registered_model_name': | ||
self.mlflow_registered_model_name, | ||
'model_name': | ||
self.pretrained_model_name, | ||
'input_example': | ||
self.mlflow_logging_config['input_example'], | ||
'await_creation_for': | ||
3600, | ||
}, | ||
) | ||
process.start() | ||
|
||
# Restore the monitor process. | ||
if monitor_process is not None: | ||
|
Review comment: Follow docstring format from other places (e.g. Args: and the indentation and such).
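A sketch of what the reviewer is asking for: the same docstring rewritten in the Args: style, with the signature and descriptions carried over from the Inputs: list in the diff (any additional wording is assumed):

    def _log_model_multiprocess(
        mlflow_logger: MLFlowLogger,
        composer_logging_level: int,
        flavor: str,
        input_example: dict[str, Any],
        log_model_metadata: dict[str, str],
        task: str,
        await_creation_for: int,
        registered_model_name: Optional[str] = None,
    ):
        """Call MLFlowLogger.log_model.

        Used mainly to log from a child process.

        Args:
            mlflow_logger (MLFlowLogger): MLflow logger object.
            composer_logging_level (int): Logging level for composer.
            flavor (str): ``transformers`` or ``peft``.
            input_example (dict[str, Any]): Model serving input example for the model.
            log_model_metadata (dict[str, str]): Metadata currently needed for optimized serving.
            task (str): LLM task for model deployment (i.e. chat or completions).
            await_creation_for (int): Time to wait for model creation.
            registered_model_name (Optional[str]): Name to register the model under, if provided.
        """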