diff --git a/server.py b/server.py index 80d4e7102..4e28c7c25 100644 --- a/server.py +++ b/server.py @@ -1,7 +1,7 @@ import os import sys import asyncio -from pathlib import PurePosixPath +from pathlib import PurePosixPath, PureWindowsPath import traceback import time @@ -464,7 +464,8 @@ class PromptServer(): # Normalize backslashes and use standard library to parse path components normalized = filename.replace('\\', '/') path = PurePosixPath(normalized) - if path.is_absolute() or '..' in path.parts: + win_path = PureWindowsPath(normalized) + if path.is_absolute() or win_path.is_absolute() or win_path.drive or '..' in path.parts: return web.Response(status=400) if output_dir is None: diff --git a/tests-unit/server_test/test_view_endpoint.py b/tests-unit/server_test/test_view_endpoint.py index 622d96003..79b2ff98c 100644 --- a/tests-unit/server_test/test_view_endpoint.py +++ b/tests-unit/server_test/test_view_endpoint.py @@ -5,7 +5,7 @@ from aiohttp import web from aiohttp.test_utils import AioHTTPTestCase, unittest_run_loop import os import tempfile -from pathlib import Path, PurePosixPath +from pathlib import Path, PurePosixPath, PureWindowsPath class TestViewEndpointSecurity(AioHTTPTestCase): @@ -40,7 +40,8 @@ class TestViewEndpointSecurity(AioHTTPTestCase): # Normalize backslashes and use standard library to parse path components normalized = filename.replace('\\', '/') path = PurePosixPath(normalized) - if path.is_absolute() or '..' in path.parts: + win_path = PureWindowsPath(normalized) + if path.is_absolute() or win_path.is_absolute() or win_path.drive or '..' in path.parts: return web.Response(status=400) # For testing, just check if file exists in test directory @@ -146,3 +147,21 @@ class TestViewEndpointSecurity(AioHTTPTestCase): """Test that Windows absolute paths with backslash are blocked (e.g., C:\\)""" resp = await self.client.request("GET", "/view?filename=%5Cetc%5Cpasswd") assert resp.status == 400, "Should block backslash absolute paths" + + @unittest_run_loop + async def test_blocks_windows_drive_absolute_path(self): + """Test that Windows drive-qualified absolute paths (C:/...) are blocked""" + resp = await self.client.request("GET", "/view?filename=C:/Windows/secret.png") + assert resp.status == 400, "Should block Windows drive-qualified absolute paths" + + @unittest_run_loop + async def test_blocks_windows_drive_relative_path(self): + """Test that Windows drive-qualified relative paths (C:secret.png) are blocked""" + resp = await self.client.request("GET", "/view?filename=C:secret.png") + assert resp.status == 400, "Should block Windows drive-qualified relative paths" + + @unittest_run_loop + async def test_blocks_windows_drive_backslash_path(self): + """Test that Windows drive paths with backslashes are blocked""" + resp = await self.client.request("GET", "/view?filename=C:%5CWindows%5Csecret.png") + assert resp.status == 400, "Should block Windows drive backslash paths"