Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add security check to see if a URL is pointing to internal IP #421

Open
wants to merge 12 commits into
base: main
Choose a base branch
from
6 changes: 2 additions & 4 deletions src/agentscope/manager/_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,21 +34,19 @@ def _get_text_embedding_record_hash(
if isinstance(embedding_model, dict):
# Format the dict to avoid duplicate keys
embedding_model = json.dumps(embedding_model, sort_keys=True)
elif isinstance(embedding_model, str):
embedding_model_hash = _hash_string(embedding_model, hash_method)
else:
elif not isinstance(embedding_model, str):
raise RuntimeError(
f"The embedding model must be a string or a dict, got "
f"{type(embedding_model)}.",
)
embedding_model_hash = _hash_string(embedding_model, hash_method)

# Calculate the embedding id by hashing the hash codes of the
# original data and the embedding model
record_hash = _hash_string(
original_data_hash + embedding_model_hash,
hash_method,
)

return record_hash


Expand Down
3 changes: 2 additions & 1 deletion src/agentscope/rag/llama_index_knowledge.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,8 +203,9 @@ def __init__(
)

if persist_root is None:
persist_root = FileManager.get_instance().run_dir or "./"
persist_root = FileManager.get_instance().cache_dir or "./"
self.persist_dir = os.path.join(persist_root, knowledge_id)
logger.info(f"** persist_dir: {self.persist_dir}")
self.emb_model = emb_model
self.overwrite_index = overwrite_index
self.showprogress = showprogress
Expand Down
45 changes: 45 additions & 0 deletions src/agentscope/service/web/web_digest.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import json
from urllib.parse import urlparse
from typing import Optional, Callable, Sequence, Any
import socket
import ipaddress
import requests
from loguru import logger

Expand Down Expand Up @@ -37,12 +39,46 @@ def is_valid_url(url: str) -> bool:
return False # A ValueError indicates that the URL is not valid.


def is_internal_ip_address(url: str) -> bool:
    """
    Check whether a URL points to an internal (non-public) IP address.

    The hostname is extracted from the URL. If it is already an IP literal
    (IPv4 or IPv6) it is checked directly without touching the resolver;
    otherwise it is resolved and every returned address is checked, so a
    hostname with both public and internal records is still rejected.

    Args:
        url (str): url to be checked

    Returns:
        bool: True if the url points to an internal IP address (private,
            loopback, link-local or unspecified), False otherwise.
            False is also returned when the URL is malformed, has no
            hostname, or the hostname cannot be resolved; such URLs are
            not treated as internal by this function.
    """
    try:
        hostname = urlparse(url).hostname
    except ValueError:
        # e.g. an unterminated IPv6 bracket; malformed URLs are ignored here
        return False
    if not hostname:
        # illegal hostname is ignored in this function
        return False

    try:
        # IP literals need no name resolution
        candidates = [ipaddress.ip_address(hostname)]
    except ValueError:
        try:
            # getaddrinfo returns both IPv4 and IPv6 results, whereas
            # gethostbyname only supports IPv4
            addr_infos = socket.getaddrinfo(hostname, None)
        except (socket.gaierror, UnicodeError):
            # The hostname cannot be resolved (or cannot be IDNA-encoded)
            return False

        candidates = []
        for addr_info in addr_infos:
            # sockaddr[0] is the textual address; drop an IPv6 zone id
            raw_ip = str(addr_info[4][0]).split("%", 1)[0]
            try:
                candidates.append(ipaddress.ip_address(raw_ip))
            except ValueError:
                continue

    for address in candidates:
        # Judge IPv4-mapped IPv6 addresses (::ffff:a.b.c.d) by the IPv4
        # address they embed
        mapped = getattr(address, "ipv4_mapped", None)
        if mapped is not None:
            address = mapped

        if (
            address.is_private
            or address.is_loopback
            or address.is_link_local
            or address.is_unspecified
        ):
            logger.warning(
                f"Access to this URL {url} is "
                f"restricted because it is private",
            )
            return True

    return False


def load_web(
url: str,
keep_raw: bool = True,
html_selected_tags: Optional[Sequence[str]] = None,
self_parse_func: Optional[Callable[[requests.Response], Any]] = None,
timeout: int = 5,
exclude_internal_ips: bool = True,
) -> ServiceResponse:
"""Function for parsing and digesting the web page.

Expand All @@ -62,6 +98,8 @@ def load_web(
The result is stored with `self_define_func`
key
timeout (int): timeout parameter for requests.
exclude_internal_ips (bool):
whether prevent the function access internal_ips

Returns:
`ServiceResponse`: If successful, `ServiceResponse` object is returned
Expand All @@ -87,6 +125,13 @@ def load_web(
"selected_tags_text": xxxxx
}
"""
if exclude_internal_ips and is_internal_ip_address(url):
return ServiceResponse(
ServiceExecStatus.ERROR,
content=f"Access to this URL {url} is restricted "
f"because it is private",
)

header = {
"Accept-Language": "zh-CN,zh;q=0.9,en;q=0.8,en-GB;q=0.7,en-US;q=0.6",
"Cache-Control": "max-age=0",
Expand Down
8 changes: 7 additions & 1 deletion tests/web_digest_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def test_web_load(self, mock_get: MagicMock) -> None:
mock_get.return_value = mock_response

# set parameters
fake_url = "fake-url"
fake_url = "http://fake-url.com"

results = load_web(
url=fake_url,
Expand Down Expand Up @@ -100,6 +100,12 @@ def format(
expected_result,
)

def test_block_internal_ips(self) -> None:
    """Loading a URL that targets localhost must yield an ERROR status."""
    # Default arguments are used, so the internal-IP check is active.
    result = load_web("http://localhost:8080/some/path")
    self.assertEqual(ServiceExecStatus.ERROR, result.status)


# This allows the tests to be run from the command line
if __name__ == "__main__":
Expand Down