Add MLflow log_model option #1544

Open · wants to merge 21 commits into base: main
llmfoundry/callbacks/hf_checkpointer.py (139 changes: 94 additions & 45 deletions)
@@ -76,6 +76,11 @@ def _maybe_get_license_filename(
 
     If the license file does not exist, returns None.
     """
+    # Early return if no local directory exists
+    if not os.path.exists(local_dir):
+        return None
+
+    # Try to find the license file
     try:
         license_filename = next(
             file for file in os.listdir(local_dir)

Review comment (Collaborator), on the early return: this should never happen, right?
Review comment (Collaborator): assuming that is correct, please remove it.
@@ -107,32 +112,82 @@ def _maybe_get_license_filename(
     return None
 
 
-def _register_model_with_run_id_multiprocess(
+def _log_model_multiprocess(
     mlflow_logger: MLFlowLogger,
-    composer_logging_level: int,
-    model_uri: str,
-    name: str,
+    python_logging_level: int,
+    transformers_model_path: str,
+    flavor: str,
+    input_example: dict[str, Any],
+    log_model_metadata: dict[str, str],
+    task: str,
     await_creation_for: int,
+    registered_model_name: Optional[str] = None,
 ):
-    """Call MLFlowLogger.register_model_with_run_id.
-
-    Used mainly to register from a child process.
-    """
+    """
+    Call MLFlowLogger.log_model.
+
+    Used mainly to log from a child process.
+
+    Inputs:
+    - mlflow_logger: MLFlowLogger: MLflow logger object
+    - python_logging_level: int: logging level
+    - flavor: str: transformers or peft
+    - input_example: dict[str, Any]: model serving input example for the model
+    - log_model_metadata: dict[str, str]: metadata currently needed for optimized serving
+    - task: str: LLM task for model deployment (i.e. chat or completions)
+    - await_creation_for: int: time to wait for model creation
+    - registered_model_name: Optional[str]: name to register the model under, if any
+    """

Review comment (Collaborator), on the Inputs: section: follow the docstring format used elsewhere (e.g. Args: and the matching indentation).
     # Setup logging for child process. This ensures that any logs from composer are surfaced.
-    if composer_logging_level > 0:
+    if python_logging_level > 0:
         # If logging_level is 0, then the composer logger was unset.
         logging.basicConfig(
             format=
             f'%(asctime)s: rank{dist.get_global_rank()}[%(process)d][%(threadName)s]: %(levelname)s: %(name)s: %(message)s',
         )
-        logging.getLogger('composer').setLevel(composer_logging_level)
-
-    # Register model.
-    mlflow_logger.register_model_with_run_id(
-        model_uri=model_uri,
-        name=name,
-        await_creation_for=await_creation_for,
-    )
+        logging.getLogger('llmfoundry').setLevel(python_logging_level)

Review comment (Collaborator): it should still set the composer log level too.
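A minimal sketch of what that comment asks for, assuming the single python_logging_level is reused for both loggers:

    if python_logging_level > 0:
        logging.basicConfig(
            format=
            f'%(asctime)s: rank{dist.get_global_rank()}[%(process)d][%(threadName)s]: %(levelname)s: %(name)s: %(message)s',
        )
        # Still surface composer logs in the child process, not just llmfoundry ones
        logging.getLogger('composer').setLevel(python_logging_level)
        logging.getLogger('llmfoundry').setLevel(python_logging_level)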


+    # Monkey patch to prevent duplicate tokenizer upload
+    import mlflow
+    mlflow.start_run(
+        run_id=mlflow_logger._run_id,
+    )
+    original_save_model = mlflow.transformers.save_model
+
+    def save_model_patch(*args: Any, **kwargs: Any):
+        original_save_model(*args, **kwargs)
+        tokenizer_files = []
+        # Check if there are duplicate tokenizer files in the model directory and remove them.
+        try:
+            for tokenizer_file_name in tokenizer_files:
+                dupe_file = os.path.isfile(
+                    os.path.join(kwargs['path'], 'model', tokenizer_file_name),
+                )
+                if dupe_file:
+                    log.debug(f'Removing duplicate tokenizer file: {tokenizer_file_name}')
+                    os.remove(os.path.join(kwargs['path'], 'model', tokenizer_file_name))
+        except Exception:
+            log.error('Exception when removing duplicate tokenizer files in the model directory', exc_info=True)
+
+    mlflow.transformers.save_model = save_model_patch

+    if registered_model_name is not None:
+        mlflow_logger.log_model(
+            transformers_model=transformers_model_path,
+            flavor=flavor,
+            artifact_path='model',  # TODO: where should we define this parent dir name?
+            input_example=input_example,
+            metadata=log_model_metadata,
+            task=task,
+            registered_model_name=registered_model_name,  # not the full path? mlflow_logger.model_registry_prefix
+            await_creation_for=await_creation_for,
+        )
+    else:
+        mlflow_logger.log_model(
+            transformers_model=transformers_model_path,
+            flavor=flavor,
+            artifact_path='model',  # TODO: where should we define this parent dir name?
+            input_example=input_example,
+            metadata=log_model_metadata,
+            task=task,
+            await_creation_for=await_creation_for,
+        )

Review comment (Collaborator), on the if/else: you don't need a big if/else here; presumably log_model is fine with you passing registered_model_name=None.
Review comment (Collaborator), on both artifact_path TODOs: fix?
Review comment (Collaborator), on registered_model_name: full path or no (i.e. with mlflow_logger.model_registry_prefix)?
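A sketch of the reviewer's suggestion, assuming log_model accepts registered_model_name=None and simply skips registration in that case:

    # One call covers both cases; None means "do not register"
    mlflow_logger.log_model(
        transformers_model=transformers_model_path,
        flavor=flavor,
        artifact_path='model',
        input_example=input_example,
        metadata=log_model_metadata,
        task=task,
        registered_model_name=registered_model_name,
        await_creation_for=await_creation_for,
    )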


class HuggingFaceCheckpointer(Callback):
Expand Down Expand Up @@ -171,7 +226,7 @@ class HuggingFaceCheckpointer(Callback):

def __init__(
self,
save_folder: str,
save_folder: Optional[str],
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is probably a change for part 2 not for this pr?

save_interval: Union[str, int, Time],
huggingface_folder_name: str = 'ba{batch}',
precision: str = 'float32',
@@ -618,22 +673,16 @@ def tensor_hook(
 
             log.debug('Saving Hugging Face checkpoint to disk')
 
-            # This context manager casts the TE extra state in io.BytesIO format to tensor format
-            # Needed for proper hf ckpt saving.
-            context_manager = te.onnx_export(
-                True,
-            ) if is_te_imported and state.precision == Precision.AMP_FP8 else contextlib.nullcontext(
-            )
-            with context_manager:
-                new_model_instance.save_pretrained(temp_save_dir)
-                original_tokenizer.save_pretrained(temp_save_dir)
+            if upload_to_save_folder:
+                # This context manager casts the TE extra state in io.BytesIO format to tensor format
+                # Needed for proper hf ckpt saving.
+                context_manager = te.onnx_export(
+                    True,
+                ) if is_te_imported and state.precision == Precision.AMP_FP8 else contextlib.nullcontext(
+                )
+                with context_manager:
+                    new_model_instance.save_pretrained(temp_save_dir)
+            if original_tokenizer is not None:
+                assert isinstance(
+                    original_tokenizer,
+                    PreTrainedTokenizerBase,
+                )
+                original_tokenizer.save_pretrained(temp_save_dir)
 
             # Only need to edit files for MPT because it has custom code
             if new_model_instance.config.model_type == 'mpt':
                 log.debug('Editing MPT files for HuggingFace compatibility')

Review comment (Collaborator), on the tokenizer save: should move the next if statement out too (the if new_model_instance... block).
Expand Down Expand Up @@ -702,14 +751,6 @@ def tensor_hook(
True,
) if is_te_imported and state.precision == Precision.AMP_FP8 else contextlib.nullcontext(
)
with context_manager:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not sure if an equivalent to this is necessary or not, can you determine that?

# Add the pip requirements directly to avoid mlflow
# attempting to run inference on the model
model_saving_kwargs['pip_requirements'] = [
'transformers',
'torch',
]
mlflow_logger.save_model(**model_saving_kwargs)

# Upload the license file generated by mlflow during the model saving.
license_filename = _maybe_get_license_filename(
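If an equivalent does turn out to be needed, one hedged option (assuming Composer's log_model forwards extra kwargs to the underlying mlflow flavor, where pip_requirements is a supported parameter) is to pin the requirements in the new call:

    mlflow_logger.log_model(
        transformers_model=transformers_model_path,
        flavor=flavor,
        artifact_path='model',
        input_example=input_example,
        metadata=log_model_metadata,
        task=task,
        await_creation_for=await_creation_for,
        # Pin requirements directly to avoid mlflow attempting to run
        # inference on the model just to infer them
        pip_requirements=['transformers', 'torch'],
    )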
@@ -732,17 +773,25 @@ def tensor_hook(
                 monitor_process = None
 
                 # Spawn a new process to register the model.
+                # Slower method to register the model via log_model.
                 process = SpawnProcess(
-                    target=_register_model_with_run_id_multiprocess,
+                    target=_log_model_multiprocess,
                     kwargs={
                         'mlflow_logger':
                             mlflow_logger,
-                        'composer_logging_level':
-                            logging.getLogger('composer').level,
-                        'model_uri':
-                            local_save_path,
-                        'name':
-                            self.mlflow_registered_model_name,
+                        'transformers_model_path':
+                            temp_save_dir,
+                        'flavor':
+                            'peft' if self.using_peft else 'transformers',
+                        'python_logging_level':
+                            logging.getLogger('llmfoundry').level,
+                        'task':
+                            self.mlflow_logging_config['metadata']['task'],
+                        'log_model_metadata':
+                            self.mlflow_logging_config['metadata'],
+                        'registered_model_name':
+                            self.mlflow_registered_model_name,
+                        'input_example':
+                            self.mlflow_logging_config['input_example'],
                         'await_creation_for':
                             3600,
                     },

Review comment (Collaborator), on the 'peft' flavor: we haven't added peft support for log_model in Composer yet. Will need to think about how to handle this.
Review comment (Collaborator), on the kwargs: we should just be **-ing the self.mlflow_logging_config, rather than picking out pieces of it.
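A rough sketch of that ** suggestion; the keys in self.mlflow_logging_config would have to line up with _log_model_multiprocess's parameter names (they do not yet: the config uses 'metadata' and a nested 'task', while the function takes log_model_metadata and a top-level task), so this is directional rather than drop-in:

    process = SpawnProcess(
        target=_log_model_multiprocess,
        kwargs={
            'mlflow_logger': mlflow_logger,
            'python_logging_level': logging.getLogger('llmfoundry').level,
            'transformers_model_path': temp_save_dir,
            'flavor': 'peft' if self.using_peft else 'transformers',
            'registered_model_name': self.mlflow_registered_model_name,
            'await_creation_for': 3600,
            # Pass the whole logging config through instead of picking pieces out
            **self.mlflow_logging_config,
        },
    )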