refactor: use PurePosixPath for path traversal validation per review feedback

This commit is contained in:
Varun Chawla 2026-02-10 02:50:45 -08:00 committed by Test User
parent ae575a8670
commit ab484ef4e6
No known key found for this signature in database
2 changed files with 22 additions and 10 deletions

View File

@ -3,6 +3,7 @@ import sys
import asyncio import asyncio
import traceback import traceback
import time import time
from pathlib import PurePosixPath
import nodes import nodes
import folder_paths import folder_paths
@ -441,9 +442,10 @@ class PromptServer():
return web.Response(status=400) return web.Response(status=400)
# validation for security: prevent accessing arbitrary path # validation for security: prevent accessing arbitrary path
# Normalize backslashes to forward slashes to handle Windows-style path traversal (e.g., folder\..\secret) # Normalize backslashes and use standard library to parse path components
normalized = filename.replace('\\', '/') normalized = filename.replace('\\', '/')
if normalized[0] == '/' or '/..' in normalized or normalized.startswith('..'): path = PurePosixPath(normalized)
if path.is_absolute() or '..' in path.parts:
return web.Response(status=400) return web.Response(status=400)
if output_dir is None: if output_dir is None:
@ -487,9 +489,10 @@ class PromptServer():
return web.Response(status=400) return web.Response(status=400)
# validation for security: prevent accessing arbitrary path # validation for security: prevent accessing arbitrary path
# Normalize backslashes to forward slashes to handle Windows-style path traversal (e.g., folder\..\secret) # Normalize backslashes and use standard library to parse path components
normalized = filename.replace('\\', '/') normalized = filename.replace('\\', '/')
if normalized[0] == '/' or '/..' in normalized or normalized.startswith('..'): path = PurePosixPath(normalized)
if path.is_absolute() or '..' in path.parts:
return web.Response(status=400) return web.Response(status=400)
if output_dir is None: if output_dir is None:

View File

@ -5,7 +5,7 @@ from aiohttp import web
from aiohttp.test_utils import AioHTTPTestCase, unittest_run_loop from aiohttp.test_utils import AioHTTPTestCase, unittest_run_loop
import os import os
import tempfile import tempfile
from pathlib import Path from pathlib import Path, PurePosixPath
class TestViewEndpointSecurity(AioHTTPTestCase): class TestViewEndpointSecurity(AioHTTPTestCase):
@ -19,10 +19,12 @@ class TestViewEndpointSecurity(AioHTTPTestCase):
self.test_dir = tempfile.mkdtemp() self.test_dir = tempfile.mkdtemp()
self.test_file_valid = os.path.join(self.test_dir, "test..png") self.test_file_valid = os.path.join(self.test_dir, "test..png")
self.test_file_normal = os.path.join(self.test_dir, "normal.png") self.test_file_normal = os.path.join(self.test_dir, "normal.png")
self.test_file_dotsecret = os.path.join(self.test_dir, "..secret")
# Create test files # Create test files
Path(self.test_file_valid).touch() Path(self.test_file_valid).touch()
Path(self.test_file_normal).touch() Path(self.test_file_normal).touch()
Path(self.test_file_dotsecret).touch()
async def view_handler(request): async def view_handler(request):
"""Simplified /view endpoint handler for testing""" """Simplified /view endpoint handler for testing"""
@ -35,9 +37,10 @@ class TestViewEndpointSecurity(AioHTTPTestCase):
return web.Response(status=400) return web.Response(status=400)
# validation for security: prevent accessing arbitrary path # validation for security: prevent accessing arbitrary path
# Normalize backslashes to forward slashes to handle Windows-style path traversal # Normalize backslashes and use standard library to parse path components
normalized = filename.replace('\\', '/') normalized = filename.replace('\\', '/')
if normalized[0] == '/' or '/..' in normalized or normalized.startswith('..'): path = PurePosixPath(normalized)
if path.is_absolute() or '..' in path.parts:
return web.Response(status=400) return web.Response(status=400)
# For testing, just check if file exists in test directory # For testing, just check if file exists in test directory
@ -89,10 +92,16 @@ class TestViewEndpointSecurity(AioHTTPTestCase):
assert resp.status == 400, "Should block absolute paths" assert resp.status == 400, "Should block absolute paths"
@unittest_run_loop @unittest_run_loop
async def test_blocks_dotdot_at_start(self): async def test_allows_dotdot_prefix_in_filename(self):
"""Test that filenames starting with .. are blocked""" """Test that filenames starting with .. but not as a path component are allowed"""
resp = await self.client.request("GET", "/view?filename=..secret") resp = await self.client.request("GET", "/view?filename=..secret")
assert resp.status == 400, "Should block filenames starting with .." assert resp.status == 200, "Should allow filenames starting with .. that aren't path traversal"
@unittest_run_loop
async def test_blocks_bare_dotdot(self):
"""Test that bare .. as a path component is blocked"""
resp = await self.client.request("GET", "/view?filename=..")
assert resp.status == 400, "Should block bare .. path component"
@unittest_run_loop @unittest_run_loop
async def test_multiple_consecutive_dots(self): async def test_multiple_consecutive_dots(self):