Add MLflow log_model option #1544
base: main
Changes from all commits:
06d77db, e40e5dd, 454e18b, c8bd06f, 0d3f9ce, 6ea8de5, 81306d8, b854bb2, bc73f65, 1915042, bc29278, 99589c7, 04ddfaa, 8e42217, be04e3d, 5ab2cc7, 79356d8, 4327257, 6c5fb05, bb0dd6a, c5ae4ff
@@ -76,6 +76,11 @@ def _maybe_get_license_filename(

     If the license file does not exist, returns None.
     """
+    # Early return if no local directory exists
+    if not os.path.exists(local_dir):
+        return None
+
     # Try to find the license file
     try:
         license_filename = next(
             file for file in os.listdir(local_dir)
@@ -107,32 +112,82 @@ def _maybe_get_license_filename(

     return None


-def _register_model_with_run_id_multiprocess(
+def _log_model_multiprocess(
     mlflow_logger: MLFlowLogger,
-    composer_logging_level: int,
-    model_uri: str,
-    name: str,
+    python_logging_level: int,
+    transformers_model_path: str,
+    flavor: str,
+    input_example: dict[str, Any],
+    log_model_metadata: dict[str, str],
+    task: str,
     await_creation_for: int,
+    registered_model_name: Optional[str] = None,
 ):
-    """Call MLFlowLogger.register_model_with_run_id.
-
-    Used mainly to register from a child process.
-    """
+    """
+    Call MLFlowLogger.log_model.
+
+    Used mainly to log from a child process.
+
+    Inputs:
+    - mlflow_logger: MLFlowLogger: MLflow logger object
+    - python_logging_level: int: logging level
+    - flavor: str: transformers or peft
+    - input_example: dict[str, Any]: model serving input example for the model
+    - log_model_metadata: dict[str, str]: metadata currently needed for optimized serving
+    - task: str: LLM task for model deployment (i.e. chat or completions)
+    - await_creation_for: int: time to wait for model creation
+    - registered_model_name: Optional[str]: name to register the model under, if any
+    """

Review comment (on the `Inputs:` line): Follow docstring format from other places (e.g. …).
     # Setup logging for child process. This ensures that any logs from composer are surfaced.
-    if composer_logging_level > 0:
+    if python_logging_level > 0:
         # If logging_level is 0, then the composer logger was unset.
         logging.basicConfig(
             format=
             f'%(asctime)s: rank{dist.get_global_rank()}[%(process)d][%(threadName)s]: %(levelname)s: %(name)s: %(message)s',
         )
-        logging.getLogger('composer').setLevel(composer_logging_level)
-
-    # Register model.
-    mlflow_logger.register_model_with_run_id(
-        model_uri=model_uri,
-        name=name,
-        await_creation_for=await_creation_for,
-    )
+        logging.getLogger('llmfoundry').setLevel(python_logging_level)

Review comment (on the `llmfoundry` setLevel line): it should still set composer log level too.
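On the comment above: a minimal sketch of a child-process logging setup that configures the composer logger as well. The rank-aware format string from the diff is simplified here, since it depends on composer's dist utilities, which are not imported in this sketch.

import logging

def _setup_child_logging(python_logging_level: int) -> None:
    # If the level is 0, the parent loggers were unset; skip configuration.
    if python_logging_level > 0:
        logging.basicConfig(
            format='%(asctime)s: [%(process)d][%(threadName)s]: %(levelname)s: %(name)s: %(message)s',
        )
        # Set both loggers, as the reviewer suggests.
        logging.getLogger('composer').setLevel(python_logging_level)
        logging.getLogger('llmfoundry').setLevel(python_logging_level)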
+    # Monkey patch to prevent duplicate tokenizer upload.
+    import mlflow
+    mlflow.start_run(
+        run_id=mlflow_logger._run_id,
+    )
+    original_save_model = mlflow.transformers.save_model
+
+    def save_model_patch(*args: Any, **kwargs: Any):
+        original_save_model(*args, **kwargs)
+        tokenizer_files = []
+        # Check if there are duplicate tokenizer files in the model directory and remove them.
+        try:
+            for tokenizer_file_name in tokenizer_files:
+                dupe_file = os.path.isfile(
+                    os.path.join(kwargs['path'], 'model', tokenizer_file_name),
+                )
+                if dupe_file:
+                    log.debug(f'Removing duplicate tokenizer file: {tokenizer_file_name}')
+                    os.remove(os.path.join(kwargs['path'], 'model', tokenizer_file_name))
+        except Exception as e:
+            log.error('Exception when removing duplicate tokenizer files in the model directory: %s', e)
+
+    mlflow.transformers.save_model = save_model_patch
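Note that `tokenizer_files` is an empty list in the patch above, so the cleanup loop is currently a no-op. Below is a sketch of what the dedupe step presumably intends, with standard Hugging Face tokenizer file names filled in as an assumption; the actual list is not in this diff.

import os
import logging

log = logging.getLogger(__name__)

# Hypothetical list of standard Hugging Face tokenizer artifacts. The PR
# leaves this list empty, so these names are an assumption.
TOKENIZER_FILES = [
    'tokenizer.json',
    'tokenizer_config.json',
    'special_tokens_map.json',
    'vocab.json',
    'merges.txt',
]

def remove_duplicate_tokenizer_files(model_dir: str) -> None:
    """Remove tokenizer files duplicated under <model_dir>/model."""
    for name in TOKENIZER_FILES:
        dupe_path = os.path.join(model_dir, 'model', name)
        try:
            if os.path.isfile(dupe_path):
                log.debug('Removing duplicate tokenizer file: %s', name)
                os.remove(dupe_path)
        except OSError:
            log.exception('Failed to remove duplicate tokenizer file %s', name)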
+    if registered_model_name is not None:

Review comment: you don't need a big if/else here. presumably log_model is fine with you passing …

+        mlflow_logger.log_model(
+            transformers_model=transformers_model_path,
+            flavor=flavor,
+            artifact_path='model',  # TODO: where should we define this parent dir name?

Review comment (on `artifact_path`): fix?

+            input_example=input_example,
+            metadata=log_model_metadata,
+            task=task,
+            registered_model_name=registered_model_name,  # not the full path? mlflow_logger.model_registry_prefix

Review comment (on `registered_model_name`): full path or no?

+            await_creation_for=await_creation_for,
+        )
+    else:
+        mlflow_logger.log_model(
+            transformers_model=transformers_model_path,
+            flavor=flavor,
+            artifact_path='model',  # TODO: where should we define this parent dir name?

Review comment (on `artifact_path`): fix?

+            input_example=input_example,
+            metadata=log_model_metadata,
+            task=task,
+            await_creation_for=await_creation_for,
+        )
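Following the reviewer's suggestion, and assuming log_model accepts registered_model_name=None to mean "do not register" (not verified here), the two branches would collapse into one call. A sketch using the surrounding function's variables:

# Sketch of the reviewer's suggestion: a single call, where
# registered_model_name may be None (assumed to mean "skip registration").
mlflow_logger.log_model(
    transformers_model=transformers_model_path,
    flavor=flavor,
    artifact_path='model',
    input_example=input_example,
    metadata=log_model_metadata,
    task=task,
    registered_model_name=registered_model_name,
    await_creation_for=await_creation_for,
)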
@@ -171,7 +226,7 @@ class HuggingFaceCheckpointer(Callback):

     def __init__(
         self,
-        save_folder: str,
+        save_folder: Optional[str],

Review comment (on the `save_folder` change): this is probably a change for part 2, not for this PR?

         save_interval: Union[str, int, Time],
         huggingface_folder_name: str = 'ba{batch}',
         precision: str = 'float32',
@@ -618,22 +673,16 @@ def tensor_hook(

         log.debug('Saving Hugging Face checkpoint to disk')

-        # This context manager casts the TE extra state in io.BytesIO format to tensor format
-        # Needed for proper hf ckpt saving.
-        context_manager = te.onnx_export(
-            True,
-        ) if is_te_imported and state.precision == Precision.AMP_FP8 else contextlib.nullcontext(
-        )
-        with context_manager:
-            new_model_instance.save_pretrained(temp_save_dir)
-        original_tokenizer.save_pretrained(temp_save_dir)

Review comment: should move the next if statement out too (…)

+        if upload_to_save_folder:
+            # This context manager casts the TE extra state in io.BytesIO format to tensor format
+            # Needed for proper hf ckpt saving.
+            context_manager = te.onnx_export(
+                True,
+            ) if is_te_imported and state.precision == Precision.AMP_FP8 else contextlib.nullcontext(
+            )
+            with context_manager:
+                new_model_instance.save_pretrained(temp_save_dir)
+            if original_tokenizer is not None:
+                assert isinstance(
+                    original_tokenizer,
+                    PreTrainedTokenizerBase,
+                )
+                original_tokenizer.save_pretrained(temp_save_dir)

         # Only need to edit files for MPT because it has custom code
         if new_model_instance.config.model_type == 'mpt':
             log.debug('Editing MPT files for HuggingFace compatibility')
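One reading of the truncated comment above is that the tokenizer save should not be gated on upload_to_save_folder. Under that assumption, a sketch using the surrounding tensor_hook scope's names:

# Sketch: keep only the TE-gated model save inside the upload branch,
# and save the tokenizer unconditionally.
if upload_to_save_folder:
    with context_manager:
        new_model_instance.save_pretrained(temp_save_dir)
if original_tokenizer is not None:
    assert isinstance(original_tokenizer, PreTrainedTokenizerBase)
    original_tokenizer.save_pretrained(temp_save_dir)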
@@ -702,14 +751,6 @@ def tensor_hook(

                     True,
                 ) if is_te_imported and state.precision == Precision.AMP_FP8 else contextlib.nullcontext(
                 )
-                with context_manager:
-                    # Add the pip requirements directly to avoid mlflow
-                    # attempting to run inference on the model
-                    model_saving_kwargs['pip_requirements'] = [
-                        'transformers',
-                        'torch',
-                    ]
-                    mlflow_logger.save_model(**model_saving_kwargs)

Review comment (on the removed `with context_manager:` block): not sure if an equivalent to this is necessary or not, can you determine that?

                 # Upload the license file generated by mlflow during the model saving.
                 license_filename = _maybe_get_license_filename(
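On the question above: the removed block pinned pip_requirements so mlflow would not attempt to run inference on the model just to infer its requirements. If Composer's MLFlowLogger.log_model forwards extra keyword arguments to the underlying mlflow flavor (an assumption, not verified here), an equivalent might look like:

mlflow_logger.log_model(
    transformers_model=transformers_model_path,
    flavor='transformers',
    artifact_path='model',
    task=task,
    # Pin the requirements directly to avoid mlflow attempting to run
    # inference on the model to infer them.
    pip_requirements=['transformers', 'torch'],
)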
@@ -732,17 +773,25 @@ def tensor_hook(

                 monitor_process = None

-                # Spawn a new process to register the model.
+                # Slower method to register the model via log_model.
                 process = SpawnProcess(
-                    target=_register_model_with_run_id_multiprocess,
+                    target=_log_model_multiprocess,
                     kwargs={
                         'mlflow_logger':
                             mlflow_logger,
-                        'composer_logging_level':
-                            logging.getLogger('composer').level,
-                        'model_uri':
-                            local_save_path,
-                        'name':
-                            self.mlflow_registered_model_name,
+                        'transformers_model_path':
+                            temp_save_dir,
+                        'flavor':
+                            'peft' if self.using_peft else 'transformers',

Review comment (on the `flavor` entry): we haven't added peft support for …

+                        'python_logging_level':
+                            logging.getLogger('llmfoundry').level,
+                        'task':
+                            self.mlflow_logging_config['metadata']['task'],

Review comment (on the `task` entry): we should just be ** the …

+                        'log_model_metadata':
+                            self.mlflow_logging_config['metadata'],
+                        'registered_model_name':
+                            self.mlflow_registered_model_name,
+                        'input_example':
+                            self.mlflow_logging_config['input_example'],
                         'await_creation_for':
                             3600,
                     },
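The "** the …" comment above presumably asks for unpacking self.mlflow_logging_config rather than copying out individual keys. A sketch under that assumption; note the child function's parameter names would need to line up with the config's keys for this to work as written:

process = SpawnProcess(
    target=_log_model_multiprocess,
    kwargs={
        'mlflow_logger': mlflow_logger,
        'transformers_model_path': temp_save_dir,
        'flavor': 'peft' if self.using_peft else 'transformers',
        'python_logging_level': logging.getLogger('llmfoundry').level,
        'registered_model_name': self.mlflow_registered_model_name,
        'await_creation_for': 3600,
        # Forward the MLflow logging config wholesale instead of copying
        # out 'task', 'metadata', and 'input_example' one at a time.
        **self.mlflow_logging_config,
    },
)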
Review comment: this should never happen right?

Review comment: assuming that is correct, please remove.