diff --git a/backend/chainlit/server.py b/backend/chainlit/server.py index 4893c8d03c..55abc8c0c3 100644 --- a/backend/chainlit/server.py +++ b/backend/chainlit/server.py @@ -1,4 +1,5 @@ import asyncio +import fnmatch import glob import json import mimetypes @@ -9,7 +10,7 @@ import webbrowser from contextlib import asynccontextmanager from pathlib import Path -from typing import Any, Optional, Union +from typing import Any, List, Optional, Union import socketio from chainlit.auth import create_jwt, get_configuration, get_current_user @@ -870,6 +871,8 @@ async def upload_file( assert file.filename, "No filename for uploaded file" assert file.content_type, "No content type for uploaded file" + validate_file_upload(file) + file_response = await session.persist_file( name=file.filename, content=content, mime=file.content_type ) @@ -877,6 +880,58 @@ async def upload_file( return JSONResponse(content=file_response) +def validate_file_upload(file: UploadFile): + if config.features.spontaneous_file_upload is None: + return # TODO: if it is not configured what should happen? + + if config.features.spontaneous_file_upload.enabled is False: + raise HTTPException( + status_code=400, + detail="File upload is not enabled", + ) + + validate_file_mime_type(file) + validate_file_size(file) + + +def validate_file_mime_type(file: UploadFile): + if config.features.spontaneous_file_upload.accept is None: + return + + accept = config.features.spontaneous_file_upload.accept + if isinstance(accept, List): + for pattern in accept: + if fnmatch.fnmatch(file.content_type, pattern): + return + else: + for pattern, extensions in accept.items(): + if fnmatch.fnmatch(file.content_type, pattern): + if len(extensions) == 0: + return + for extension in extensions: + if file.filename is not None and file.filename.endswith(extension): + return + raise HTTPException( + status_code=400, + detail="File type not allowed", + ) + + +def validate_file_size(file: UploadFile): + if config.features.spontaneous_file_upload.max_size_mb is None: + return + + if ( + file.size is not None + and file.size + > config.features.spontaneous_file_upload.max_size_mb * 1024 * 1024 + ): + raise HTTPException( + status_code=400, + detail="File size too large", + ) + + @router.get("/project/file/{file_id}") async def get_file( file_id: str, diff --git a/backend/tests/test_server.py b/backend/tests/test_server.py index 36c65124d6..ce49892428 100644 --- a/backend/tests/test_server.py +++ b/backend/tests/test_server.py @@ -1,19 +1,24 @@ +import datetime # Added import for datetime import os -from pathlib import Path import pathlib +import tempfile +from pathlib import Path from typing import Callable from unittest.mock import AsyncMock, Mock, create_autospec, mock_open -import datetime # Added import for datetime import pytest -import tempfile -from chainlit.session import WebsocketSession from chainlit.auth import get_current_user -from chainlit.config import APP_ROOT, ChainlitConfig, load_config +from chainlit.config import ( + APP_ROOT, + ChainlitConfig, + SpontaneousFileUploadFeature, + load_config, +) from chainlit.server import app -from fastapi.testclient import TestClient +from chainlit.session import WebsocketSession from chainlit.types import FileReference from chainlit.user import PersistedUser # Added import for PersistedUser +from fastapi.testclient import TestClient @pytest.fixture @@ -509,6 +514,175 @@ def test_upload_file_unauthorized( assert response.status_code == 422 +def test_upload_file_disabled( + test_client: TestClient, + test_config: ChainlitConfig, + mock_session_get_by_id_patched: Mock, + monkeypatch: pytest.MonkeyPatch, +): + """Test file upload being disabled by config.""" + + # Set accept in config + monkeypatch.setattr( + test_config.features, + "spontaneous_file_upload", + SpontaneousFileUploadFeature(enabled=False), + ) + + # Prepare the files to upload + file_content = b"Sample file content" + files = { + "file": ("test_upload.txt", file_content, "text/plain"), + } + + # Make the POST request to upload the file + response = test_client.post( + "/project/file", + files=files, + params={"session_id": mock_session_get_by_id_patched.id}, + ) + + # Verify the response + assert response.status_code == 400 + + +@pytest.mark.parametrize( + "accept_pattern, mime_type, expected_status", + [ + ({"image/*": [".png", ".gif", ".jpeg", ".jpg"]}, "image/jpeg", 400), + (["image/*"], "text/plain", 400), + (["image/*", "application/*"], "text/plain", 400), + (["image/png", "application/pdf"], "image/jpeg", 400), + (["text/*"], "text/plain", 200), + (["application/*"], "application/pdf", 200), + (["image/*"], "image/jpeg", 200), + (["image/*", "text/*"], "text/plain", 200), + (["*/*"], "text/plain", 200), + (["*/*"], "image/jpeg", 200), + (["*/*"], "application/pdf", 200), + (["image/*", "application/*"], "application/pdf", 200), + (["image/*", "application/*"], "image/jpeg", 200), + (["image/png", "application/pdf"], "image/png", 200), + (["image/png", "application/pdf"], "application/pdf", 200), + ({"image/*": []}, "image/jpeg", 200), + ( + {"image/*": [".png", ".gif", ".jpeg", ".jpg"]}, + "text/plain", + 400, + ), # mime type not allowed + ( + {"*/*": [".txt", ".gif", ".jpeg", ".jpg"]}, + "text/plain", + 200, + ), # extension allowed + ( + {"*/*": [".gif", ".jpeg", ".jpg"]}, + "text/plain", + 400, + ), # extension not allowed + ], +) +def test_upload_file_mime_type_check( + test_client: TestClient, + test_config: ChainlitConfig, + mock_session_get_by_id_patched: Mock, + monkeypatch: pytest.MonkeyPatch, + accept_pattern: list[str], + mime_type: str, + expected_status: int, +): + """Test check of mime_type.""" + + # Set accept in config + monkeypatch.setattr( + test_config.features, + "spontaneous_file_upload", + SpontaneousFileUploadFeature(enabled=True, accept=accept_pattern), + ) + + # Prepare the files to upload + file_content = b"Sample file content" + files = { + "file": ("test_upload.txt", file_content, mime_type), + } + + # Mock the persist_file method to return a known value + expected_file_id = "mocked_file_id" + mock_session_get_by_id_patched.persist_file = AsyncMock( + return_value={ + "id": expected_file_id, + "name": "test_upload.txt", + "type": "text/plain", + "size": len(file_content), + } + ) + + # Make the POST request to upload the file + response = test_client.post( + "/project/file", + files=files, + params={"session_id": mock_session_get_by_id_patched.id}, + ) + + # Verify the response + assert response.status_code == expected_status + + +@pytest.mark.parametrize( + "file_content, content_multiplier, max_size_mb, expected_status", + [ + (b"1", 1, 1, 200), + (b"11", 1024 * 1024, 1, 400), + ], +) +def test_upload_file_mime_type_check( + test_client: TestClient, + test_config: ChainlitConfig, + mock_session_get_by_id_patched: Mock, + monkeypatch: pytest.MonkeyPatch, + file_content: bytes, + content_multiplier: int, + max_size_mb: int, + expected_status: int, +): + """Test check of max_size_mb.""" + + file_content = file_content * content_multiplier + + # Set accept in config + monkeypatch.setattr( + test_config.features, + "spontaneous_file_upload", + SpontaneousFileUploadFeature(max_size_mb=max_size_mb), + ) + + # Prepare the files to upload + files = { + "file": ("test_upload.txt", file_content, "text/plain"), + } + + # Mock the persist_file method to return a known value + expected_file_id = "mocked_file_id" + mock_session_get_by_id_patched.persist_file = AsyncMock( + return_value={ + "id": expected_file_id, + "name": "test_upload.txt", + "type": "text/plain", + "size": len(file_content), + } + ) + + # Make the POST request to upload the file + response = test_client.post( + "/project/file", + files=files, + params={"session_id": mock_session_get_by_id_patched.id}, + ) + + # Verify the response + assert response.status_code == expected_status + + def test_project_translations_file_path_traversal( test_client: TestClient, monkeypatch: pytest.MonkeyPatch ):