Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(file_upload): validate mimetype as configured #1459

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 56 additions & 1 deletion backend/chainlit/server.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import asyncio
import fnmatch
import glob
import json
import mimetypes
Expand All @@ -9,7 +10,7 @@
import webbrowser
from contextlib import asynccontextmanager
from pathlib import Path
from typing import Any, Optional, Union
from typing import Any, List, Optional, Union

import socketio
from chainlit.auth import create_jwt, get_configuration, get_current_user
Expand Down Expand Up @@ -870,13 +871,67 @@ async def upload_file(
assert file.filename, "No filename for uploaded file"
assert file.content_type, "No content type for uploaded file"

validate_file_upload(file)

file_response = await session.persist_file(
name=file.filename, content=content, mime=file.content_type
)

return JSONResponse(content=file_response)


def validate_file_upload(file: UploadFile):
    """Check an uploaded file against the spontaneous-file-upload settings.

    Reads ``config.features.spontaneous_file_upload`` and, when it is set,
    delegates to ``validate_file_mime_type`` and ``validate_file_size``.

    Args:
        file: The uploaded file to check.

    Raises:
        HTTPException: With status 400 when the feature's ``enabled`` flag
            is explicitly ``False``. The delegated validators raise their
            own 400 errors.
    """
    upload_settings = config.features.spontaneous_file_upload

    if upload_settings is None:
        # TODO: if it is not configured what should happen?
        # For now a missing configuration lets every upload through.
        return

    # Only an explicit ``False`` disables uploads; an unset (``None``)
    # flag falls through to the remaining checks.
    if upload_settings.enabled is False:
        raise HTTPException(status_code=400, detail="File upload is not enabled")

    validate_file_mime_type(file)
    validate_file_size(file)


def validate_file_mime_type(file: UploadFile):
    """Reject uploads whose MIME type (and extension) is not accepted.

    ``config.features.spontaneous_file_upload.accept`` is handled in two
    shapes:

    * a list of MIME glob patterns (e.g. ``["image/*"]``) -- the upload is
      accepted when its content type matches any pattern;
    * otherwise a mapping of MIME glob pattern to a list of filename
      extensions -- the content type must match a pattern and, unless that
      pattern's extension list is empty, the filename must end with one of
      the listed extensions.

    MIME types and extensions are compared case-insensitively so the result
    does not depend on the platform-specific case handling of
    ``fnmatch.fnmatch``.

    Args:
        file: The uploaded file to check.

    Raises:
        HTTPException: With status 400 when the file matches no accepted
            pattern, or when it carries no content type at all.
    """
    upload_settings = config.features.spontaneous_file_upload
    # Nothing configured (feature or accept list) means no restriction.
    # Guarding the feature itself keeps a direct call from crashing.
    if upload_settings is None or upload_settings.accept is None:
        return

    accept = upload_settings.accept
    content_type = file.content_type

    # A missing content type cannot satisfy any pattern; skip matching so
    # fnmatch is never handed ``None`` (which would raise TypeError).
    if content_type is not None:
        normalized_type = content_type.lower()

        def mime_matches(pattern):
            # fnmatchcase never applies OS-specific normalisation, so both
            # sides are lower-cased explicitly.
            return fnmatch.fnmatchcase(normalized_type, pattern.lower())

        if isinstance(accept, list):
            if any(mime_matches(pattern) for pattern in accept):
                return
        else:
            filename = (file.filename or "").lower()
            for pattern, extensions in accept.items():
                if not mime_matches(pattern):
                    continue
                # An empty extension list accepts every filename.
                if not extensions:
                    return
                allowed_suffixes = tuple(ext.lower() for ext in extensions)
                if file.filename is not None and filename.endswith(allowed_suffixes):
                    return

    raise HTTPException(
        status_code=400,
        detail="File type not allowed",
    )


def validate_file_size(file: UploadFile):
    """Reject uploads larger than the configured ``max_size_mb``.

    Args:
        file: The uploaded file to check.

    Raises:
        HTTPException: With status 400 when ``file.size`` is known and
            exceeds ``max_size_mb`` mebibytes.
    """
    upload_settings = config.features.spontaneous_file_upload
    # No feature config or no limit configured means nothing to enforce.
    # Guarding the feature itself keeps a direct call from crashing.
    if upload_settings is None or upload_settings.max_size_mb is None:
        return

    max_size_bytes = upload_settings.max_size_mb * 1024 * 1024

    # When the size is unknown (``None``) the check is skipped, as before.
    # A file exactly at the limit is allowed (strict ``>``).
    if file.size is not None and file.size > max_size_bytes:
        raise HTTPException(
            status_code=400,
            detail="File size too large",
        )


@router.get("/project/file/{file_id}")
async def get_file(
file_id: str,
Expand Down
186 changes: 180 additions & 6 deletions backend/tests/test_server.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,24 @@
import datetime # Added import for datetime
import os
from pathlib import Path
import pathlib
import tempfile
from pathlib import Path
from typing import Callable
from unittest.mock import AsyncMock, Mock, create_autospec, mock_open
import datetime # Added import for datetime

import pytest
import tempfile
from chainlit.session import WebsocketSession
from chainlit.auth import get_current_user
from chainlit.config import APP_ROOT, ChainlitConfig, load_config
from chainlit.config import (
APP_ROOT,
ChainlitConfig,
SpontaneousFileUploadFeature,
load_config,
)
from chainlit.server import app
from fastapi.testclient import TestClient
from chainlit.session import WebsocketSession
from chainlit.types import FileReference
from chainlit.user import PersistedUser # Added import for PersistedUser
from fastapi.testclient import TestClient


@pytest.fixture
Expand Down Expand Up @@ -509,6 +514,175 @@ def test_upload_file_unauthorized(
assert response.status_code == 422


def test_upload_file_disabled(
    test_client: TestClient,
    test_config: ChainlitConfig,
    mock_session_get_by_id_patched: Mock,
    monkeypatch: pytest.MonkeyPatch,
):
    """An upload is answered with HTTP 400 when the feature is disabled."""

    # Switch spontaneous file upload off for the duration of this test.
    disabled_feature = SpontaneousFileUploadFeature(enabled=False)
    monkeypatch.setattr(
        test_config.features, "spontaneous_file_upload", disabled_feature
    )

    # Post a small plain-text file to the upload endpoint.
    payload = b"Sample file content"
    response = test_client.post(
        "/project/file",
        files={"file": ("test_upload.txt", payload, "text/plain")},
        params={"session_id": mock_session_get_by_id_patched.id},
    )

    # The server must refuse the upload.
    assert response.status_code == 400


@pytest.mark.parametrize(
    "accept_pattern, mime_type, expected_status",
    [
        # The uploaded file is always named "test_upload.txt".
        # MIME type matches, but ".txt" is not among the listed extensions.
        ({"image/*": [".png", ".gif", ".jpeg", ".jpg"]}, "image/jpeg", 400),
        # List form: MIME type matches none of the patterns.
        (["image/*"], "text/plain", 400),
        (["image/*", "application/*"], "text/plain", 400),
        (["image/png", "application/pdf"], "image/jpeg", 400),
        # List form: MIME type matches at least one pattern.
        (["text/*"], "text/plain", 200),
        (["application/*"], "application/pdf", 200),
        (["image/*"], "image/jpeg", 200),
        (["image/*", "text/*"], "text/plain", 200),
        (["*/*"], "text/plain", 200),
        (["*/*"], "image/jpeg", 200),
        (["*/*"], "application/pdf", 200),
        (["image/*", "application/*"], "application/pdf", 200),
        (["image/*", "application/*"], "image/jpeg", 200),
        (["image/png", "application/pdf"], "image/png", 200),
        (["image/png", "application/pdf"], "application/pdf", 200),
        # Mapping form with an empty extension list: any filename passes.
        ({"image/*": []}, "image/jpeg", 200),
        # Mapping form: mime type not allowed.
        ({"image/*": [".png", ".gif", ".jpeg", ".jpg"]}, "text/plain", 400),
        # Mapping form: extension allowed.
        ({"*/*": [".txt", ".gif", ".jpeg", ".jpg"]}, "text/plain", 200),
        # Mapping form: extension not allowed.
        ({"*/*": [".gif", ".jpeg", ".jpg"]}, "text/plain", 400),
    ],
)
def test_upload_file_mime_type_check(
    test_client: TestClient,
    test_config: ChainlitConfig,
    mock_session_get_by_id_patched: Mock,
    monkeypatch: pytest.MonkeyPatch,
    accept_pattern: list[str],
    mime_type: str,
    expected_status: int,
):
    """Uploads are accepted or refused according to the ``accept`` setting."""

    payload = b"Sample file content"

    # Stub out persistence so accepted uploads return a known response.
    mock_session_get_by_id_patched.persist_file = AsyncMock(
        return_value={
            "id": "mocked_file_id",
            "name": "test_upload.txt",
            "type": "text/plain",
            "size": len(payload),
        }
    )

    # Enable uploads with the accept configuration under test.
    accept_feature = SpontaneousFileUploadFeature(
        enabled=True, accept=accept_pattern
    )
    monkeypatch.setattr(
        test_config.features, "spontaneous_file_upload", accept_feature
    )

    # Upload the file, declaring the parametrized MIME type.
    response = test_client.post(
        "/project/file",
        files={"file": ("test_upload.txt", payload, mime_type)},
        params={"session_id": mock_session_get_by_id_patched.id},
    )

    assert response.status_code == expected_status


@pytest.mark.parametrize(
    "file_content, content_multiplier, max_size_mb, expected_status",
    [
        # One byte is far below the 1 MiB limit.
        (b"1", 1, 1, 200),
        # Exactly 1 MiB: the limit is exclusive, so this is still accepted.
        (b"1", 1024 * 1024, 1, 200),
        # 2 MiB exceeds the 1 MiB limit.
        (b"11", 1024 * 1024, 1, 400),
    ],
)
def test_upload_file_size_check(
    test_client: TestClient,
    test_config: ChainlitConfig,
    mock_session_get_by_id_patched: Mock,
    monkeypatch: pytest.MonkeyPatch,
    file_content: bytes,
    content_multiplier: int,
    max_size_mb: int,
    expected_status: int,
):
    """Test check of max_size_mb.

    This test previously shared the name ``test_upload_file_mime_type_check``
    with the MIME-type test defined before it. The later definition replaced
    the earlier one in the module namespace, so pytest never collected the
    MIME-type test. The distinct name lets both run.
    """

    file_content = file_content * content_multiplier

    # Set max_size_mb in config
    monkeypatch.setattr(
        test_config.features,
        "spontaneous_file_upload",
        SpontaneousFileUploadFeature(max_size_mb=max_size_mb),
    )

    # Prepare the files to upload
    files = {
        "file": ("test_upload.txt", file_content, "text/plain"),
    }

    # Mock the persist_file method to return a known value
    mock_session_get_by_id_patched.persist_file = AsyncMock(
        return_value={
            "id": "mocked_file_id",
            "name": "test_upload.txt",
            "type": "text/plain",
            "size": len(file_content),
        }
    )

    # Make the POST request to upload the file
    response = test_client.post(
        "/project/file",
        files=files,
        params={"session_id": mock_session_get_by_id_patched.id},
    )

    # Verify the response
    assert response.status_code == expected_status


def test_project_translations_file_path_traversal(
test_client: TestClient, monkeypatch: pytest.MonkeyPatch
):
Expand Down