diff --git a/src/fairseq2/assets/download_manager.py b/src/fairseq2/assets/download_manager.py
index 2900f4d41..fc97e1d7e 100644
--- a/src/fairseq2/assets/download_manager.py
+++ b/src/fairseq2/assets/download_manager.py
@@ -39,6 +39,7 @@ class AssetDownloadManager(ABC):
     def download_checkpoint(
         self,
         uri: str,
+        checksum: str | None,
         model_name: str,
         *,
         shard_idx: int | None = None,
@@ -66,6 +67,7 @@ def download_checkpoint(
     def download_tokenizer(
         self,
         uri: str,
+        checksum: str | None,
         model_name: str,
         *,
         tokenizer_name: str | None = None,
@@ -93,6 +95,7 @@ def download_tokenizer(
     def download_dataset(
         self,
         uri: str,
+        checksum: str | None,
         dataset_name: str,
         *,
         force: bool = False,
@@ -135,6 +138,7 @@ def __init__(self) -> None:
     def download_checkpoint(
         self,
         uri: str,
+        checksum: str | None,
         model_name: str,
         *,
         shard_idx: int | None = None,
@@ -150,12 +154,16 @@ def download_checkpoint(
             self._cache_dir, uri, display_name, force, progress, shard_idx
         )
 
-        return op.run()
+        path = op.run()
+        self._validate_asset_integrity(path, checksum)
+
+        return path
 
     @override
     def download_tokenizer(
         self,
         uri: str,
+        checksum: str | None,
         model_name: str,
         *,
         tokenizer_name: str | None = None,
@@ -169,12 +177,16 @@ def download_tokenizer(
 
         op = _AssetDownloadOp(self._cache_dir, uri, display_name, force, progress)
 
-        return op.run()
+        path = op.run()
+        self._validate_asset_integrity(path, checksum)
+
+        return path
 
     @override
     def download_dataset(
         self,
         uri: str,
+        checksum: str | None,
         dataset_name: str,
         *,
         force: bool = False,
@@ -184,8 +196,38 @@ def download_dataset(
 
         op = _AssetDownloadOp(self._cache_dir, uri, display_name, force, progress)
 
-        return op.run()
+        path = op.run()
+        self._validate_asset_integrity(path, checksum)
+
+        return path
 
+    def _validate_asset_integrity(self, path: Path, checksum: str | None) -> None:
+        """Verify the file at ``path`` against its expected hex digest.
+
+        The check is skipped with a warning if ``checksum`` is ``None``.
+
+        :raises AssetDownloadError: If the computed digest differs from
+            ``checksum``.
+        """
+        if checksum is None:
+            log.warning(
+                f"Asset at {path} has no recorded checksum, skipping integrity check."
+            )
+            return
+
+        # Read in fixed-size chunks rather than loading the whole file at once.
+        BYTES_PER_CHUNK = 65536
+        sha = sha1()
+
+        with open(path, "rb") as file:
+            while data := file.read(BYTES_PER_CHUNK):
+                sha.update(data)
+
+        if sha.hexdigest() != checksum:
+            raise AssetDownloadError(
+                f"Checksum for {path} does not match the expected checksum."
+            )
+
 
 class _AssetDownloadOp:
     _cache_dir: Path
diff --git a/src/fairseq2/data/text/text_tokenizer.py b/src/fairseq2/data/text/text_tokenizer.py
index 62dedfe58..bc80ba307 100644
--- a/src/fairseq2/data/text/text_tokenizer.py
+++ b/src/fairseq2/data/text/text_tokenizer.py
@@ -223,10 +223,15 @@ def __call__(
             return self(tokenizer_ref, force=force, progress=progress)
 
         tokenizer_uri = card.field("tokenizer").as_uri()
+        tokenizer_checksum = card.field("checksum").get_as_(str)
 
         try:
             path = self._download_manager.download_tokenizer(
-                tokenizer_uri, card.name, force=force, progress=progress
+                tokenizer_uri,
+                tokenizer_checksum,
+                card.name,
+                force=force,
+                progress=progress,
             )
         except ValueError as ex:
             raise AssetCardError(
diff --git a/src/fairseq2/datasets/loader.py b/src/fairseq2/datasets/loader.py
index 973a1f6cd..3d92a0e45 100644
--- a/src/fairseq2/datasets/loader.py
+++ b/src/fairseq2/datasets/loader.py
@@ -85,10 +85,11 @@ def __call__(
             card = self._asset_store.retrieve_card(dataset_name_or_card)
 
         dataset_uri = card.field("data").as_uri()
+        dataset_checksum = card.field("checksum").get_as_(str)
 
         try:
             path = self._download_manager.download_dataset(
-                dataset_uri, card.name, force=force, progress=progress
+                dataset_uri, dataset_checksum, card.name, force=force, progress=progress
             )
         except ValueError as ex:
             raise AssetCardError(
diff --git a/src/fairseq2/models/loader.py b/src/fairseq2/models/loader.py
index 83cdb125a..6c6685b0d 100644
--- a/src/fairseq2/models/loader.py
+++ b/src/fairseq2/models/loader.py
@@ -275,12 +275,14 @@ def __call__(
 
         # Load the checkpoint.
         checkpoint_uri = card.field("checkpoint").as_uri()
+        checkpoint_checksum = card.field("checksum").get_as_(str)
 
         shard_idx = gang.rank if gang is not None and gang.size != 1 else None
 
         try:
             path = self._download_manager.download_checkpoint(
                 checkpoint_uri,
+                checkpoint_checksum,
                 card.name,
                 shard_idx=shard_idx,
                 force=force,
diff --git a/tests/integration/models/test_llama.py b/tests/integration/models/test_llama.py
index 45a110546..5148ebbd7 100644
--- a/tests/integration/models/test_llama.py
+++ b/tests/integration/models/test_llama.py
@@ -25,8 +25,11 @@ def test_convert_to_reference_checkpoint() -> None:
 
     card = default_asset_store.retrieve_card("llama2_7b")
 
+    checkpoint_uri = card.field("checkpoint").as_uri()
+    checkpoint_checksum = card.field("checksum").get_as_(str)
+
     path = default_download_manager.download_checkpoint(
-        card.field("checkpoint").as_uri(), model_name="llama2_7b", progress=False
+        checkpoint_uri, checkpoint_checksum, model_name="llama2_7b", progress=False
     )
 
     tensor_loader = StandardTensorLoader()