Skip to content

Commit

Permalink
Addition of unit tests for tar extraction
Browse files Browse the repository at this point in the history
  • Loading branch information
Jake Skipper committed Jul 31, 2023
1 parent 75b45c0 commit f955d9e
Show file tree
Hide file tree
Showing 2 changed files with 102 additions and 48 deletions.
98 changes: 50 additions & 48 deletions runway/cfngin/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -518,6 +518,54 @@ def parse_cloudformation_template(template: str) -> Dict[str, Any]:
return yaml_parse(template)


def is_within_directory(directory: Path | str, target: str) -> bool:
    """Check if file is in directory.

    Determines if the provided path is within a specific directory or its
    subdirectories. Both paths are made absolute and normalized (``..``
    segments are collapsed) before they are compared. Symbolic links are
    not resolved.

    Args:
        directory: Path of the directory we're checking.
        target: Path of the file we're checking for containment.

    Returns:
        True if the target is the directory itself or is located beneath it,
        False otherwise.

    """
    abs_directory = os.path.abspath(directory)
    abs_target = os.path.abspath(target)
    try:
        # commonpath compares whole path components. commonprefix compares
        # character by character, which would wrongly report a sibling such
        # as "/data/app-evil/file" as being inside "/data/app".
        shared = os.path.commonpath([abs_directory, abs_target])
    except ValueError:
        # Raised when the two paths cannot share an ancestor (for example,
        # different drives on Windows); the target is then not contained.
        return False
    return shared == abs_directory


def safe_tar_extract(
    tar: tarfile.TarFile,
    path: Path | str = ".",
    members: list[tarfile.TarInfo] | None = None,
    *,
    numeric_owner: bool = False,
):
    """Extract a tar archive after verifying no member escapes the destination.

    Adapted from a PR contributed to the Runway repo by Trellix as a
    mitigation for CVE-2007-4559.

    Every member of the archive is checked, regardless of ``members``. Only
    the member *name* is validated; link targets are not inspected here.

    Args:
        tar: The open tar file object to extract.
        path: Directory the archive is extracted into.
        members: Optional subset of TarInfo objects to extract; passed
            through to ``TarFile.extractall``.
        numeric_owner: Passed through to ``TarFile.extractall``.

    Raises:
        Exception: If the name of any archive member resolves to a location
            outside of ``path``. Nothing is extracted in that case.

    """
    escapes_destination = any(
        not is_within_directory(path, os.path.join(path, entry.name))
        for entry in tar.getmembers()
    )
    if escapes_destination:
        raise Exception("Attempted Path Traversal in Tar File")

    tar.extractall(path, members, numeric_owner=numeric_owner)


class Extractor:
"""Base class for extractors."""

Expand Down Expand Up @@ -550,30 +598,7 @@ class TarExtractor(Extractor):
def extract(self, destination: Path) -> None:
"""Extract the archive."""
with tarfile.open(self.archive, "r:") as tar:
# tar.extractall(path=destination)
#
# Implementing fix suggested by Trellix for
# CVE-2007-4559.
def is_within_directory(directory: Path | str, target: str) -> bool:
abs_directory = os.path.abspath(directory)
abs_target = os.path.abspath(target)
prefix = os.path.commonprefix([abs_directory, abs_target])
return prefix == abs_directory

def safe_extract(
tar: tarfile.TarFile,
path: Path | str = ".",
members: list[tarfile.TarInfo] | None = None,
*,
numeric_owner: bool = False,
):
for member in tar.getmembers():
member_path = os.path.join(path, member.name)
if not is_within_directory(path, member_path):
raise Exception("Attempted Path Traversal in Tar File")
tar.extractall(path, members, numeric_owner=numeric_owner)

safe_extract(tar, path=destination)
safe_tar_extract(tar, path=destination)

Check warning on line 601 in runway/cfngin/utils.py

View check run for this annotation

Codecov / codecov/patch

runway/cfngin/utils.py#L601

Added line #L601 was not covered by tests


class TarGzipExtractor(Extractor):
Expand All @@ -584,30 +609,7 @@ class TarGzipExtractor(Extractor):
def extract(self, destination: Path) -> None:
"""Extract the archive."""
with tarfile.open(self.archive, "r:gz") as tar:
# tar.extractall(path=destination)
#
# Implementing fix suggested by Trellix for
# CVE-2007-4559.
def is_within_directory(directory: Path | str, target: str) -> bool:
abs_directory = os.path.abspath(directory)
abs_target = os.path.abspath(target)
prefix = os.path.commonprefix([abs_directory, abs_target])
return prefix == abs_directory

def safe_extract(
tar: tarfile.TarFile,
path: Path | str = ".",
members: list[tarfile.TarInfo] | None = None,
*,
numeric_owner: bool = False,
):
for member in tar.getmembers():
member_path = os.path.join(path, member.name)
if not is_within_directory(path, member_path):
raise Exception("Attempted Path Traversal in Tar File")
tar.extractall(path, members, numeric_owner=numeric_owner)

safe_extract(tar, path=destination)
safe_tar_extract(tar, path=destination)

Check warning on line 612 in runway/cfngin/utils.py

View check run for this annotation

Codecov / codecov/patch

runway/cfngin/utils.py#L612

Added line #L612 was not covered by tests


class ZipExtractor(Extractor):
Expand Down
52 changes: 52 additions & 0 deletions tests/unit/cfngin/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
from __future__ import annotations

import logging
import os
import shutil
import tarfile
import tempfile
import unittest
from pathlib import Path
Expand All @@ -28,9 +30,11 @@
ensure_s3_bucket,
get_client_region,
get_s3_endpoint,
is_within_directory,
parse_cloudformation_template,
read_value_from_path,
s3_bucket_location_constraint,
safe_tar_extract,
yaml_to_ordered_dict,
)
from runway.config.models.cfngin import GitCfnginPackageSourceDefinitionModel
Expand Down Expand Up @@ -267,6 +271,19 @@ def setUp(self) -> None:
"""Set up test case."""
self.tmp_path = Path(tempfile.mkdtemp())

# Set up files for testing tar file extraction
self.tar_file = Path("test_tar_file.tar")
test_tar_file1 = self.tmp_path / "file1.txt"
test_tar_file1.write_text("Content of file1")
test_tar_subdir = self.tmp_path / "subdir"
test_tar_subdir.mkdir()
test_tar_file2 = test_tar_subdir / "file2.txt"
test_tar_file2.write_text("Content of file2")

# Create a tar file using the temporary directory
with tarfile.open(self.tmp_path / self.tar_file, "w") as tar:
tar.add(self.tmp_path, arcname=os.path.basename(self.tmp_path))

def tearDown(self) -> None:
"""Tear down test case."""
shutil.rmtree(self.tmp_path, ignore_errors=True)
Expand Down Expand Up @@ -358,6 +375,41 @@ def test_parse_cloudformation_template(self) -> None:
}
self.assertEqual(parse_cloudformation_template(template), parsed_template)

def test_is_within_directory(self):
    """Test is_within_directory with nested, unrelated, and identical paths."""
    base = Path("my_directory")

    # (target, expected containment), evaluated in this order.
    expectations = (
        # Target nested below the directory.
        ("my_directory/sub_directory/file.txt", True),
        # Target located in an unrelated directory.
        ("other_directory/file.txt", False),
        # Target is the directory itself.
        ("my_directory", True),
    )
    for candidate, contained in expectations:
        check = self.assertTrue if contained else self.assertFalse
        check(is_within_directory(base, candidate))

def test_safe_tar_extract_all_within(self):
    """Extraction succeeds when every member stays inside the destination."""
    destination = self.tmp_path / "my_directory"
    archive = self.tmp_path / self.tar_file
    with tarfile.open(archive, "r") as tar:
        # safe_tar_extract has no return value; not raising is the success signal.
        outcome = safe_tar_extract(tar, destination)
        self.assertIsNone(outcome)

def test_safe_tar_extract_path_traversal(self):
    """Test when a tar file tries to go outside the specified area."""
    path = self.tmp_path / "my_directory"
    with tarfile.open(self.tmp_path / self.tar_file, "r") as tar:
        # Rewrite each member name so it points at the parent of the
        # extraction directory.
        for member in tar.getmembers():
            member.name = f"../{member.name}"

        with self.assertRaises(Exception) as context:
            safe_tar_extract(tar, path)

    # Checked after the assertRaises block has exited: any statement placed
    # after the raising call inside that block would never execute.
    self.assertEqual(
        str(context.exception), "Attempted Path Traversal in Tar File"
    )

def test_extractors(self):
"""Test extractors."""
self.assertEqual(Extractor(Path("test.zip")).archive, Path("test.zip"))
Expand Down

0 comments on commit f955d9e

Please sign in to comment.