diff --git a/server.py b/server.py index 268441bd1..a899bb0ea 100644 --- a/server.py +++ b/server.py @@ -47,6 +47,7 @@ from app.node_replace_manager import NodeReplaceManager from typing import Optional, Union from api_server.routes.internal.internal_routes import InternalRoutes from protocol import BinaryEventTypes +from utils.http_headers import content_disposition_header # Import cache control middleware from middleware.cache_middleware import cache_control @@ -560,7 +561,7 @@ class PromptServer(): buffer.seek(0) return web.Response(body=buffer.read(), content_type=f'image/{image_format}', - headers={"Content-Disposition": f"filename=\"{filename}\""}) + headers={"Content-Disposition": content_disposition_header(filename, "inline")}) if 'channel' not in request.rel_url.query: channel = 'rgba' @@ -580,7 +581,7 @@ class PromptServer(): buffer.seek(0) return web.Response(body=buffer.read(), content_type='image/png', - headers={"Content-Disposition": f"filename=\"{filename}\""}) + headers={"Content-Disposition": content_disposition_header(filename, "inline")}) elif channel == 'a': with Image.open(file) as img: @@ -597,7 +598,7 @@ class PromptServer(): alpha_buffer.seek(0) return web.Response(body=alpha_buffer.read(), content_type='image/png', - headers={"Content-Disposition": f"filename=\"{filename}\""}) + headers={"Content-Disposition": content_disposition_header(filename, "inline")}) else: # Use the content type from asset resolution if available, # otherwise guess from the filename. @@ -614,7 +615,7 @@ class PromptServer(): return web.FileResponse( file, headers={ - "Content-Disposition": f"filename=\"{filename}\"", + "Content-Disposition": content_disposition_header(filename, "inline"), "Content-Type": content_type } ) diff --git a/tests-unit/prompt_server_test/view_image_header_test.py b/tests-unit/prompt_server_test/view_image_header_test.py new file mode 100644 index 000000000..33697a909 --- /dev/null +++ b/tests-unit/prompt_server_test/view_image_header_test.py @@ -0,0 +1,31 @@ +from email.message import Message + +from utils.http_headers import content_disposition_header + + +def test_view_content_disposition_is_valid_inline_filename(): + header = content_disposition_header("example image.png", "inline") + + message = Message() + message["Content-Disposition"] = header + + assert message.get_content_disposition() == "inline" + assert message.get_filename() == "example image.png" + assert "filename*=UTF-8''example%20image.png" in header + + +def test_view_content_disposition_sanitizes_control_characters(): + header = content_disposition_header('bad"\r\nname.png', "inline") + + message = Message() + message["Content-Disposition"] = header + + assert message.get_content_disposition() == "inline" + assert message.get_filename() == "bad__name.png" + + +def test_view_content_disposition_adds_utf8_filename_parameter(): + header = content_disposition_header("café.png", "inline") + + assert 'filename="caf_.png"' in header + assert "filename*=UTF-8''caf%C3%A9.png" in header diff --git a/utils/http_headers.py b/utils/http_headers.py new file mode 100644 index 000000000..68195aeec --- /dev/null +++ b/utils/http_headers.py @@ -0,0 +1,21 @@ +import re +import urllib.parse + + +_CONTROL_CHARS_RE = re.compile(r"[\r\n]+") +_QUOTED_FILENAME_UNSAFE_RE = re.compile(r'["\\]') +_NON_ASCII_RE = re.compile(r"[^\x20-\x7E]") + + +def content_disposition_header(filename: str, disposition: str) -> str: + if disposition not in {"inline", "attachment"}: + raise ValueError(f"Unsupported Content-Disposition type: {disposition}") + + safe_filename = _CONTROL_CHARS_RE.sub("_", filename or "") + safe_filename = _QUOTED_FILENAME_UNSAFE_RE.sub("_", safe_filename) + fallback_filename = _NON_ASCII_RE.sub("_", safe_filename) + encoded_filename = urllib.parse.quote(safe_filename) + return ( + f'{disposition}; filename="{fallback_filename}"; ' + f"filename*=UTF-8''{encoded_filename}" + )