Add MLflow log_model option #1544

Open · wants to merge 21 commits into base: main

Changes from 6 commits
170 changes: 137 additions & 33 deletions llmfoundry/callbacks/hf_checkpointer.py
@@ -114,7 +114,7 @@ def _register_model_with_run_id_multiprocess(
name: str,
await_creation_for: int,
):
"""Call MLFlowLogger.register_model_with_run_id.
"""Call MLFlowLogger.register_model.

Used mainly to register from a child process.
"""
@@ -135,6 +135,88 @@ def _register_model_with_run_id_multiprocess(
)


def _log_model_multiprocess(
mlflow_logger: MLFlowLogger,
composer_logging_level: int,
flavor: str,
input_example: dict[str, Any],
log_model_metadata: dict[str, str],
task: str,
await_creation_for: int,
registered_model_name: Optional[str] = None,
model_name: Optional[str] = None,
):
"""
Call MLFlowLogger.log_model.

Used mainly to log from a child process.

Inputs:
[Collaborator comment] Follow docstring format from other places (e.g. Args: and the indentation and such)

- mlflow_logger: MLFlowLogger: MLflow logger object
- composer_logging_level: int: logging level for composer
- flavor: str: transformers or peft
- input_example: dict[str, Any]: model serving input example for model
- log_model_metadata: dict[str, str]: This metadata is currently needed for optimized serving
- task: str: LLM task for model deployment (i.e. chat or completions)
- await_creation_for: int: time to wait for model creation
- registered_model_name: Optional
- model_name: Optional (currently unused)
"""
# Setup logging for child process. This ensures that any logs from composer are surfaced.
if composer_logging_level > 0:
# If logging_level is 0, then the composer logger was unset.
logging.basicConfig(
format=
f'%(asctime)s: rank{dist.get_global_rank()}[%(process)d][%(threadName)s]: %(levelname)s: %(name)s: %(message)s',
)
logging.getLogger('composer').setLevel(composer_logging_level)

# monkey patch to prevent duplicate tokenizer upload
import mlflow
original_save_model = mlflow.transformers.save_model
def save_model_patch(*args, **kwargs):
original_save_model(*args, **kwargs)
print(f"List of root path: {os.listdir(kwargs['path'])}")
components_path = os.path.join(kwargs['path'], 'components')
if os.path.exists(components_path):
print(f"List of components path: {components_path}: {os.listdir(components_path)}")
tokenizer_path = os.path.join(kwargs['path'], 'components', 'tokenizer')
if os.path.exists(tokenizer_path):
tokenizer_files = os.listdir(os.path.join(kwargs['path'], 'components', 'tokenizer'))
print(f"Tokenizer files: {tokenizer_files}")
try:
print(f"List of model/model/ files: {os.listdir(os.path.join(kwargs['path'], 'model'))}")
except Exception as e:
print(f"exception", e)
# TODO: what do we do here in code??
for tokenizer_file_name in tokenizer_files:
try:
dupe_file = os.path.isfile(os.path.join(kwargs['path'], 'model', tokenizer_file_name))
if dupe_file:
os.remove(os.path.join(kwargs['path'], 'model', tokenizer_file_name))
except Exception as e:
print(f"exception", e)
mlflow.transformers.save_model = save_model_patch
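The same patch written a bit more defensively; a sketch only, assuming (as the patch above already does) that save_model always receives path as a keyword argument. functools.wraps keeps the patched function's metadata intact, and the debug prints go through a logger instead of stdout:

```python
import functools
import logging
import os

import mlflow

log = logging.getLogger(__name__)

_original_save_model = mlflow.transformers.save_model

@functools.wraps(_original_save_model)
def _save_model_dedupe_tokenizer(*args, **kwargs):
    """Save as usual, then drop tokenizer files duplicated into model/."""
    _original_save_model(*args, **kwargs)
    root = kwargs['path']  # assumes 'path' is always passed as a kwarg
    tokenizer_dir = os.path.join(root, 'components', 'tokenizer')
    model_dir = os.path.join(root, 'model')
    if os.path.isdir(tokenizer_dir) and os.path.isdir(model_dir):
        for name in os.listdir(tokenizer_dir):
            dupe = os.path.join(model_dir, name)
            if os.path.isfile(dupe):
                log.debug('Removing duplicate tokenizer file: %s', dupe)
                os.remove(dupe)

mlflow.transformers.save_model = _save_model_dedupe_tokenizer
```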

if registered_model_name is not None:
[Collaborator comment] you don't need a big if/else here. presumably log_model is fine with you passing registered_model_name=None
mlflow_logger.log_model(
flavor=flavor,
artifact_path='model', # TODO: where should we define this parent dir name?
[Collaborator comment] fix?
input_example=input_example,
metadata=log_model_metadata,
task=task,
registered_model_name=registered_model_name,
await_creation_for=await_creation_for
)
else:
mlflow_logger.log_model(
flavor=flavor,
artifact_path='model', # TODO: where should we define this parent dir name?
[Collaborator comment] fix?
input_example=input_example,
metadata=log_model_metadata,
task=task,
await_creation_for=await_creation_for
)
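If the collaborator's presumption holds and log_model accepts registered_model_name=None (worth confirming against the MLFlowLogger API before relying on it), the two branches collapse into a single call:

```python
mlflow_logger.log_model(
    flavor=flavor,
    artifact_path='model',  # TODO: where should we define this parent dir name?
    input_example=input_example,
    metadata=log_model_metadata,
    task=task,
    # None is presumed to skip registration; verify before relying on it.
    registered_model_name=registered_model_name,
    await_creation_for=await_creation_for,
)
```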


class HuggingFaceCheckpointer(Callback):
"""Save a huggingface formatted checkpoint during training.

@@ -202,6 +284,7 @@ def __init__(
+
f'Defaulting to final_register_only=False and saving the HuggingFace checkpoint to {save_folder=}.',
)
self.use_mlflow_log_model = False

# mlflow config setup
if mlflow_logging_config is None:
@@ -232,6 +315,8 @@ def __init__(
'input_example',
default_input_example,
)
if mlflow_logging_config.get('use_mlflow_log_model', False):
self.use_mlflow_log_model = True
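For illustration only, a hypothetical mlflow_logging_config that would turn this path on; the exact key set beyond use_mlflow_log_model, metadata, and input_example depends on the rest of this __init__, and the task value here is just an example:

```python
mlflow_logging_config = {
    'use_mlflow_log_model': True,
    # Passed through as both the log_model metadata and the deployment task.
    'metadata': {'task': 'llm/v1/completions'},
    # Omit to fall back to default_input_example, as set up above.
    'input_example': {'prompt': 'What is Machine Learning?'},
}
```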

self.mlflow_logging_config = mlflow_logging_config
if 'metadata' in self.mlflow_logging_config:
@@ -639,23 +724,29 @@ def tensor_hook(
self.flatten_imports,
)

-if self.remote_ud is not None:
-    for filename in os.listdir(temp_save_dir):
-        remote_file_name = os.path.join(save_dir, filename)
-        remote_file_uri = self.remote_ud.remote_backend.get_uri(
-            remote_file_name,
-        )
-        log.info(
-            f'Uploading HuggingFace formatted checkpoint to {remote_file_uri}',
-        )
-        self.remote_ud.upload_file(
-            state=state,
-            remote_file_name=remote_file_name,
-            file_path=Path(
-                os.path.join(temp_save_dir, filename),
-            ),
-            overwrite=self.overwrite,
-        )
+# TODO: Log the model without registering
+for i, mlflow_logger in enumerate(self.mlflow_loggers):
+    process = SpawnProcess(
+        target=_log_model_multiprocess,
+        kwargs={
+            'mlflow_logger': mlflow_logger,
+            'flavor': 'peft' if self.using_peft else 'transformers',
+            'composer_logging_level': logging.getLogger('composer').level,
+            'task': self.mlflow_logging_config['metadata']['task'],
+            'log_model_metadata': self.mlflow_logging_config['metadata'],
+            'model_name': self.pretrained_model_name,
+            'input_example': self.mlflow_logging_config['input_example'],
+            'await_creation_for': 3600,
+        },
+    )
+    process.start()
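process.start() here is fire-and-forget; nothing joins the children before dist.barrier(). One possible tightening, sketched under the assumption that a new attribute (self._log_model_processes, a hypothetical name) is acceptable, so a later hook can reap the children:

```python
# Hypothetical: retain handles so the spawned log_model children can be
# joined later (e.g. in fit_end) instead of being left unmonitored.
self._log_model_processes: list[SpawnProcess] = []
for mlflow_logger in self.mlflow_loggers:
    process = SpawnProcess(
        target=_log_model_multiprocess,
        kwargs={
            'mlflow_logger': mlflow_logger,
            'flavor': 'peft' if self.using_peft else 'transformers',
            'composer_logging_level': logging.getLogger('composer').level,
            'task': self.mlflow_logging_config['metadata']['task'],
            'log_model_metadata': self.mlflow_logging_config['metadata'],
            'model_name': self.pretrained_model_name,
            'input_example': self.mlflow_logging_config['input_example'],
            'await_creation_for': 3600,
        },
    )
    process.start()
    self._log_model_processes.append(process)

# Later, e.g. at the end of training:
for process in self._log_model_processes:
    process.join()
```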

dist.barrier()

@@ -729,22 +820,35 @@ def tensor_hook(
monitor_process = None

# Spawn a new process to register the model.
-process = SpawnProcess(
-    target=_register_model_with_run_id_multiprocess,
-    kwargs={
-        'mlflow_logger': mlflow_logger,
-        'composer_logging_level': logging.getLogger('composer').level,
-        'model_uri': local_save_path,
-        'name': self.mlflow_registered_model_name,
-        'await_creation_for': 3600,
-    },
-)
-process.start()
+# TODO: how do we fix intermediate checkpointing logic to use this too but not register
+# the model with that param
+# TODO: pass in model correctly
+# Slower method to register the model via log_model.
+if self.use_mlflow_log_model:
+    print("----------------- REACHED MLFLOW LOG MODEL -----------------")
+    process = SpawnProcess(
+        target=_log_model_multiprocess,
+        kwargs={
+            'mlflow_logger': mlflow_logger,
+            'flavor': 'peft' if self.using_peft else 'transformers',
+            'composer_logging_level': logging.getLogger('composer').level,
+            'task': self.mlflow_logging_config['metadata']['task'],
+            'log_model_metadata': self.mlflow_logging_config['metadata'],
+            'registered_model_name': self.mlflow_registered_model_name,
+            'model_name': self.pretrained_model_name,
+            'input_example': self.mlflow_logging_config['input_example'],
+            'await_creation_for': 3600,
+        },
+    )
+    process.start()

# Restore the monitor process.
if monitor_process is not None: