From ab484ef4e6fa66c9e84275d732376cd04e47b8a8 Mon Sep 17 00:00:00 2001 From: Varun Chawla Date: Tue, 10 Feb 2026 02:50:45 -0800 Subject: [PATCH] refactor: use PurePosixPath for path traversal validation per review feedback --- server.py | 11 ++++++---- tests-unit/server_test/test_view_endpoint.py | 21 ++++++++++++++------ 2 files changed, 22 insertions(+), 10 deletions(-) diff --git a/server.py b/server.py index f49d9e604..ec442bf8b 100644 --- a/server.py +++ b/server.py @@ -3,6 +3,7 @@ import sys import asyncio import traceback import time +from pathlib import PurePosixPath import nodes import folder_paths @@ -441,9 +442,10 @@ class PromptServer(): return web.Response(status=400) # validation for security: prevent accessing arbitrary path - # Normalize backslashes to forward slashes to handle Windows-style path traversal (e.g., folder\..\secret) + # Normalize backslashes and use standard library to parse path components normalized = filename.replace('\\', '/') - if normalized[0] == '/' or '/..' in normalized or normalized.startswith('..'): + path = PurePosixPath(normalized) + if path.is_absolute() or '..' in path.parts: return web.Response(status=400) if output_dir is None: @@ -487,9 +489,10 @@ class PromptServer(): return web.Response(status=400) # validation for security: prevent accessing arbitrary path - # Normalize backslashes to forward slashes to handle Windows-style path traversal (e.g., folder\..\secret) + # Normalize backslashes and use standard library to parse path components normalized = filename.replace('\\', '/') - if normalized[0] == '/' or '/..' in normalized or normalized.startswith('..'): + path = PurePosixPath(normalized) + if path.is_absolute() or '..' in path.parts: return web.Response(status=400) if output_dir is None: diff --git a/tests-unit/server_test/test_view_endpoint.py b/tests-unit/server_test/test_view_endpoint.py index 1f7ca083c..622d96003 100644 --- a/tests-unit/server_test/test_view_endpoint.py +++ b/tests-unit/server_test/test_view_endpoint.py @@ -5,7 +5,7 @@ from aiohttp import web from aiohttp.test_utils import AioHTTPTestCase, unittest_run_loop import os import tempfile -from pathlib import Path +from pathlib import Path, PurePosixPath class TestViewEndpointSecurity(AioHTTPTestCase): @@ -19,10 +19,12 @@ class TestViewEndpointSecurity(AioHTTPTestCase): self.test_dir = tempfile.mkdtemp() self.test_file_valid = os.path.join(self.test_dir, "test..png") self.test_file_normal = os.path.join(self.test_dir, "normal.png") + self.test_file_dotsecret = os.path.join(self.test_dir, "..secret") # Create test files Path(self.test_file_valid).touch() Path(self.test_file_normal).touch() + Path(self.test_file_dotsecret).touch() async def view_handler(request): """Simplified /view endpoint handler for testing""" @@ -35,9 +37,10 @@ class TestViewEndpointSecurity(AioHTTPTestCase): return web.Response(status=400) # validation for security: prevent accessing arbitrary path - # Normalize backslashes to forward slashes to handle Windows-style path traversal + # Normalize backslashes and use standard library to parse path components normalized = filename.replace('\\', '/') - if normalized[0] == '/' or '/..' in normalized or normalized.startswith('..'): + path = PurePosixPath(normalized) + if path.is_absolute() or '..' in path.parts: return web.Response(status=400) # For testing, just check if file exists in test directory @@ -89,10 +92,16 @@ class TestViewEndpointSecurity(AioHTTPTestCase): assert resp.status == 400, "Should block absolute paths" @unittest_run_loop - async def test_blocks_dotdot_at_start(self): - """Test that filenames starting with .. are blocked""" + async def test_allows_dotdot_prefix_in_filename(self): + """Test that filenames starting with .. but not as a path component are allowed""" resp = await self.client.request("GET", "/view?filename=..secret") - assert resp.status == 400, "Should block filenames starting with .." + assert resp.status == 200, "Should allow filenames starting with .. that aren't path traversal" + + @unittest_run_loop + async def test_blocks_bare_dotdot(self): + """Test that bare .. as a path component is blocked""" + resp = await self.client.request("GET", "/view?filename=..") + assert resp.status == 400, "Should block bare .. path component" @unittest_run_loop async def test_multiple_consecutive_dots(self):