diff --git a/invokeai/app/services/image_files/image_files_disk.py b/invokeai/app/services/image_files/image_files_disk.py index 135df836103..e5bfd72781d 100644 --- a/invokeai/app/services/image_files/image_files_disk.py +++ b/invokeai/app/services/image_files/image_files_disk.py @@ -110,15 +110,26 @@ def delete(self, image_name: str) -> None: except Exception as e: raise ImageFileDeleteException from e - # TODO: make this a bit more flexible for e.g. cloud storage def get_path(self, image_name: str, thumbnail: bool = False) -> Path: - path = self.__output_folder / image_name + base_folder = self.__thumbnails_folder if thumbnail else self.__output_folder + filename = get_thumbnail_name(image_name) if thumbnail else image_name - if thumbnail: - thumbnail_name = get_thumbnail_name(image_name) - path = self.__thumbnails_folder / thumbnail_name + # Strip any path information from the filename + basename = Path(filename).name + + if basename != filename: + raise ValueError("Invalid image name, potential directory traversal detected") + + image_path = base_folder / basename + + # Ensure the image path is within the base folder to prevent directory traversal + resolved_base = base_folder.resolve() + resolved_image_path = image_path.resolve() + + if not resolved_image_path.is_relative_to(resolved_base): + raise ValueError("Image path outside outputs folder, potential directory traversal detected") - return path + return resolved_image_path def validate_path(self, path: Union[str, Path]) -> bool: """Validates the path given for an image or thumbnail.""" diff --git a/tests/app/services/image_files/test_image_files_disk.py b/tests/app/services/image_files/test_image_files_disk.py new file mode 100644 index 00000000000..4304b008c87 --- /dev/null +++ b/tests/app/services/image_files/test_image_files_disk.py @@ -0,0 +1,51 @@ +import platform +from pathlib import Path + +import pytest + +from invokeai.app.services.image_files.image_files_disk import DiskImageFileStorage + + +@pytest.fixture +def image_names() -> list[str]: + # Determine the platform and return a path that matches its format + if platform.system() == "Windows": + return [ + # Relative paths + "folder\\evil.txt", + "folder\\..\\evil.txt", + # Absolute paths + "\\folder\\evil.txt", + "C:\\folder\\..\\evil.txt", + ] + else: + return [ + # Relative paths + "folder/evil.txt", + "folder/../evil.txt", + # Absolute paths + "/folder/evil.txt", + "/folder/../evil.txt", + ] + + +def test_directory_traversal_protection(tmp_path: Path, image_names: list[str]): + """Test that the image file storage prevents directory traversal attacks. + + There are two safeguards in the `DiskImageFileStorage.get_path` method: + 1. Check if the image name contains any directory traversal characters + 2. Check if the resulting path is relative to the base folder + + This test checks the first safeguard. I'd like to check the second but I cannot figure out a test case that would + pass the first check but fail the second check. + """ + image_files_disk = DiskImageFileStorage(tmp_path) + for name in image_names: + with pytest.raises(ValueError, match="Invalid image name, potential directory traversal detected"): + image_files_disk.get_path(name) + + +def test_image_paths_relative_to_storage_dir(tmp_path: Path): + image_files_disk = DiskImageFileStorage(tmp_path) + path = image_files_disk.get_path("foo.png") + assert path.is_relative_to(tmp_path)