mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-03-20 16:43:45 +08:00
refactor: use PurePosixPath for path traversal validation per review feedback
This commit is contained in:
parent
ae575a8670
commit
ab484ef4e6
11
server.py
11
server.py
@ -3,6 +3,7 @@ import sys
|
|||||||
import asyncio
|
import asyncio
|
||||||
import traceback
|
import traceback
|
||||||
import time
|
import time
|
||||||
|
from pathlib import PurePosixPath
|
||||||
|
|
||||||
import nodes
|
import nodes
|
||||||
import folder_paths
|
import folder_paths
|
||||||
@ -441,9 +442,10 @@ class PromptServer():
|
|||||||
return web.Response(status=400)
|
return web.Response(status=400)
|
||||||
|
|
||||||
# validation for security: prevent accessing arbitrary path
|
# validation for security: prevent accessing arbitrary path
|
||||||
# Normalize backslashes to forward slashes to handle Windows-style path traversal (e.g., folder\..\secret)
|
# Normalize backslashes and use standard library to parse path components
|
||||||
normalized = filename.replace('\\', '/')
|
normalized = filename.replace('\\', '/')
|
||||||
if normalized[0] == '/' or '/..' in normalized or normalized.startswith('..'):
|
path = PurePosixPath(normalized)
|
||||||
|
if path.is_absolute() or '..' in path.parts:
|
||||||
return web.Response(status=400)
|
return web.Response(status=400)
|
||||||
|
|
||||||
if output_dir is None:
|
if output_dir is None:
|
||||||
@ -487,9 +489,10 @@ class PromptServer():
|
|||||||
return web.Response(status=400)
|
return web.Response(status=400)
|
||||||
|
|
||||||
# validation for security: prevent accessing arbitrary path
|
# validation for security: prevent accessing arbitrary path
|
||||||
# Normalize backslashes to forward slashes to handle Windows-style path traversal (e.g., folder\..\secret)
|
# Normalize backslashes and use standard library to parse path components
|
||||||
normalized = filename.replace('\\', '/')
|
normalized = filename.replace('\\', '/')
|
||||||
if normalized[0] == '/' or '/..' in normalized or normalized.startswith('..'):
|
path = PurePosixPath(normalized)
|
||||||
|
if path.is_absolute() or '..' in path.parts:
|
||||||
return web.Response(status=400)
|
return web.Response(status=400)
|
||||||
|
|
||||||
if output_dir is None:
|
if output_dir is None:
|
||||||
|
|||||||
@ -5,7 +5,7 @@ from aiohttp import web
|
|||||||
from aiohttp.test_utils import AioHTTPTestCase, unittest_run_loop
|
from aiohttp.test_utils import AioHTTPTestCase, unittest_run_loop
|
||||||
import os
|
import os
|
||||||
import tempfile
|
import tempfile
|
||||||
from pathlib import Path
|
from pathlib import Path, PurePosixPath
|
||||||
|
|
||||||
|
|
||||||
class TestViewEndpointSecurity(AioHTTPTestCase):
|
class TestViewEndpointSecurity(AioHTTPTestCase):
|
||||||
@ -19,10 +19,12 @@ class TestViewEndpointSecurity(AioHTTPTestCase):
|
|||||||
self.test_dir = tempfile.mkdtemp()
|
self.test_dir = tempfile.mkdtemp()
|
||||||
self.test_file_valid = os.path.join(self.test_dir, "test..png")
|
self.test_file_valid = os.path.join(self.test_dir, "test..png")
|
||||||
self.test_file_normal = os.path.join(self.test_dir, "normal.png")
|
self.test_file_normal = os.path.join(self.test_dir, "normal.png")
|
||||||
|
self.test_file_dotsecret = os.path.join(self.test_dir, "..secret")
|
||||||
|
|
||||||
# Create test files
|
# Create test files
|
||||||
Path(self.test_file_valid).touch()
|
Path(self.test_file_valid).touch()
|
||||||
Path(self.test_file_normal).touch()
|
Path(self.test_file_normal).touch()
|
||||||
|
Path(self.test_file_dotsecret).touch()
|
||||||
|
|
||||||
async def view_handler(request):
|
async def view_handler(request):
|
||||||
"""Simplified /view endpoint handler for testing"""
|
"""Simplified /view endpoint handler for testing"""
|
||||||
@ -35,9 +37,10 @@ class TestViewEndpointSecurity(AioHTTPTestCase):
|
|||||||
return web.Response(status=400)
|
return web.Response(status=400)
|
||||||
|
|
||||||
# validation for security: prevent accessing arbitrary path
|
# validation for security: prevent accessing arbitrary path
|
||||||
# Normalize backslashes to forward slashes to handle Windows-style path traversal
|
# Normalize backslashes and use standard library to parse path components
|
||||||
normalized = filename.replace('\\', '/')
|
normalized = filename.replace('\\', '/')
|
||||||
if normalized[0] == '/' or '/..' in normalized or normalized.startswith('..'):
|
path = PurePosixPath(normalized)
|
||||||
|
if path.is_absolute() or '..' in path.parts:
|
||||||
return web.Response(status=400)
|
return web.Response(status=400)
|
||||||
|
|
||||||
# For testing, just check if file exists in test directory
|
# For testing, just check if file exists in test directory
|
||||||
@ -89,10 +92,16 @@ class TestViewEndpointSecurity(AioHTTPTestCase):
|
|||||||
assert resp.status == 400, "Should block absolute paths"
|
assert resp.status == 400, "Should block absolute paths"
|
||||||
|
|
||||||
@unittest_run_loop
|
@unittest_run_loop
|
||||||
async def test_blocks_dotdot_at_start(self):
|
async def test_allows_dotdot_prefix_in_filename(self):
|
||||||
"""Test that filenames starting with .. are blocked"""
|
"""Test that filenames starting with .. but not as a path component are allowed"""
|
||||||
resp = await self.client.request("GET", "/view?filename=..secret")
|
resp = await self.client.request("GET", "/view?filename=..secret")
|
||||||
assert resp.status == 400, "Should block filenames starting with .."
|
assert resp.status == 200, "Should allow filenames starting with .. that aren't path traversal"
|
||||||
|
|
||||||
|
@unittest_run_loop
|
||||||
|
async def test_blocks_bare_dotdot(self):
|
||||||
|
"""Test that bare .. as a path component is blocked"""
|
||||||
|
resp = await self.client.request("GET", "/view?filename=..")
|
||||||
|
assert resp.status == 400, "Should block bare .. path component"
|
||||||
|
|
||||||
@unittest_run_loop
|
@unittest_run_loop
|
||||||
async def test_multiple_consecutive_dots(self):
|
async def test_multiple_consecutive_dots(self):
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user