Skip to content

Commit

Permalink
print the registered model name
Browse files Browse the repository at this point in the history
  • Loading branch information
nancyhung committed Oct 26, 2024
1 parent be04e3d commit 5ab2cc7
Showing 1 changed file with 8 additions and 7 deletions.
15 changes: 8 additions & 7 deletions llmfoundry/callbacks/hf_checkpointer.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def _maybe_get_license_filename(

def _log_model_multiprocess(
mlflow_logger: MLFlowLogger,
composer_logging_level: int,
python_logging_level: int,
transformers_model_path: str,
flavor: str,
input_example: dict[str, Any],
Expand All @@ -130,7 +130,7 @@ def _log_model_multiprocess(
Inputs:
- mlflow_logger: MLFlowLogger: MLflow logger object
- composer_logging_level: int: logging level for composer
- python_logging_level: int: logging level
- flavor: str: transformers or peft
- input_example: dict[str, Any]: model serving input example for model
- log_model_metadata: dict[str, str]: This metadata is currently needed for optimized serving
Expand All @@ -139,13 +139,13 @@ def _log_model_multiprocess(
- registered_model_name: Optional
"""
# Setup logging for child process. This ensures that any logs from composer are surfaced.
if composer_logging_level > 0:
if python_logging_level > 0:
# If logging_level is 0, then the composer logger was unset.
logging.basicConfig(
format=
f'%(asctime)s: rank{dist.get_global_rank()}[%(process)d][%(threadName)s]: %(levelname)s: %(name)s: %(message)s',
)
logging.getLogger('composer').setLevel(composer_logging_level)
logging.getLogger('llmfoundry').setLevel(python_logging_level)

log.info("----------------- REACHED MLFLOW LOG MODEL -----------------")
# monkey patch to prevent duplicate tokenizer upload
Expand Down Expand Up @@ -678,6 +678,7 @@ def tensor_hook(

log.debug('Saving Hugging Face checkpoint to disk')

log.debug(f"UPLOAD_TO_SAVE_FOLDER: {upload_to_save_folder}")
if upload_to_save_folder:
# This context manager casts the TE extra state in io.BytesIO format to tensor format
# Needed for proper hf ckpt saving.
Expand Down Expand Up @@ -785,7 +786,7 @@ def tensor_hook(

# Spawn a new process to register the model.
# Slower method to register the model via log_model.
log.info('USING MY BRANCH!!!!!!!!!!!!!!')
log.info(f'USING MY BRANCH!!!!!!!!!!!!!! REGISTERED MODEL NAME: {self.mlflow_registered_model_name}')
process = SpawnProcess(
target=_log_model_multiprocess,
kwargs={
Expand All @@ -795,8 +796,8 @@ def tensor_hook(
temp_save_dir,
'flavor':
'peft' if self.using_peft else 'transformers',
'composer_logging_level':
logging.getLogger('composer').level,
'python_logging_level':
logging.getLogger('llmfoundry').level,
'task':
self.mlflow_logging_config['metadata']['task'],
'log_model_metadata': self.mlflow_logging_config['metadata'],
Expand Down

0 comments on commit 5ab2cc7

Please sign in to comment.