Skip to content

Commit

Permalink
De-duplicate get_ts_context usages and move to ts_utils.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 686540442
  • Loading branch information
cpgaffney1 authored and t5-copybara committed Oct 21, 2024
1 parent b642f30 commit 6390fa1
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions t5x/checkpoint_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,8 +280,12 @@ def get_restore_parameters(
restore_args = jax.tree.map(lambda x: ocp.RestoreArgs(), structure)
flat_param_infos = {}
is_ocdbt_checkpoint = ocp.type_handlers.is_ocdbt_checkpoint(directory)
ts_context = ocp.type_handlers.get_ts_context()

if hasattr(ocp.serialization, 'ts_utils'):
ts_context = ocp.serialization.ts_utils.get_ts_context(
use_ocdbt=is_ocdbt_checkpoint
)
else:
ts_context = ocp.type_handlers.get_ts_context()
def _get_param_info(
name: str,
meta_or_value: Union[Any, ocp.metadata.tree.ValueMetadataEntry],
Expand Down

0 comments on commit 6390fa1

Please sign in to comment.