diff --git a/server.py b/server.py index d606bb97c..f49d9e604 100644 --- a/server.py +++ b/server.py @@ -441,8 +441,9 @@ class PromptServer(): return web.Response(status=400) # validation for security: prevent accessing arbitrary path - # Check for path traversal patterns (../ or /..) but allow consecutive dots in filename - if filename[0] == '/' or '/..' in filename or filename.startswith('..'): + # Normalize backslashes to forward slashes to handle Windows-style path traversal (e.g., folder\..\secret) + normalized = filename.replace('\\', '/') + if normalized[0] == '/' or '/..' in normalized or normalized.startswith('..'): return web.Response(status=400) if output_dir is None: @@ -486,8 +487,9 @@ class PromptServer(): return web.Response(status=400) # validation for security: prevent accessing arbitrary path - # Check for path traversal patterns (../ or /..) but allow consecutive dots in filename - if filename[0] == '/' or '/..' in filename or filename.startswith('..'): + # Normalize backslashes to forward slashes to handle Windows-style path traversal (e.g., folder\..\secret) + normalized = filename.replace('\\', '/') + if normalized[0] == '/' or '/..' in normalized or normalized.startswith('..'): return web.Response(status=400) if output_dir is None: diff --git a/tests-unit/server_test/test_view_endpoint.py b/tests-unit/server_test/test_view_endpoint.py index 19d156233..1f7ca083c 100644 --- a/tests-unit/server_test/test_view_endpoint.py +++ b/tests-unit/server_test/test_view_endpoint.py @@ -35,8 +35,9 @@ class TestViewEndpointSecurity(AioHTTPTestCase): return web.Response(status=400) # validation for security: prevent accessing arbitrary path - # Check for path traversal patterns (../ or /..) but allow consecutive dots in filename - if filename[0] == '/' or '/..' in filename or filename.startswith('..'): + # Normalize backslashes to forward slashes to handle Windows-style path traversal + normalized = filename.replace('\\', '/') + if normalized[0] == '/' or '/..' in normalized or normalized.startswith('..'): return web.Response(status=400) # For testing, just check if file exists in test directory @@ -112,3 +113,27 @@ class TestViewEndpointSecurity(AioHTTPTestCase): resp = await self.client.request("GET", "/view?filename=my..file..name.png") assert resp.status == 200, "Should allow dots in middle of filename" + + @unittest_run_loop + async def test_blocks_backslash_path_traversal(self): + """Test that Windows-style backslash path traversal (\\..) is blocked""" + resp = await self.client.request("GET", "/view?filename=folder%5C..%5Csecret") + assert resp.status == 400, "Should block backslash path traversal" + + @unittest_run_loop + async def test_blocks_backslash_dotdot_at_start(self): + """Test that backslash path traversal starting with ..\\ is blocked""" + resp = await self.client.request("GET", "/view?filename=..%5Cetc%5Cpasswd") + assert resp.status == 400, "Should block ..\\ path traversal at start" + + @unittest_run_loop + async def test_blocks_mixed_slash_backslash_traversal(self): + """Test that mixed forward/backslash path traversal is blocked""" + resp = await self.client.request("GET", "/view?filename=folder/..%5Csecret") + assert resp.status == 400, "Should block mixed slash/backslash path traversal" + + @unittest_run_loop + async def test_blocks_backslash_absolute_path(self): + """Test that Windows absolute paths with backslash are blocked (e.g., C:\\)""" + resp = await self.client.request("GET", "/view?filename=%5Cetc%5Cpasswd") + assert resp.status == 400, "Should block backslash absolute paths"