mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-02-17 00:43:48 +08:00
Fix critical security vulnerabilities in model downloader
- Add path traversal protection in simple_downloader.py - Sanitize model_type and filename inputs to prevent directory escapes - Validate file extensions against allowed list - Restrict model types to whitelisted folders only - Add URL validation to require HTTPS - Block SSRF attacks by preventing local/private network downloads - Add input validation in server.py endpoint - Ensure all file paths remain within models directory These changes prevent attackers from: - Writing files outside the models directory - Accessing sensitive files via path traversal - Making requests to internal services (SSRF) - Executing arbitrary code via malicious filenames 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
parent
be88ae7a64
commit
b8a0408c65
@ -22,6 +22,32 @@ class SimpleDownloader:
|
|||||||
"""Create a new download task."""
|
"""Create a new download task."""
|
||||||
task_id = str(uuid.uuid4())
|
task_id = str(uuid.uuid4())
|
||||||
|
|
||||||
|
# SECURITY: Validate and sanitize inputs to prevent path traversal
|
||||||
|
# Sanitize model_type to prevent directory traversal
|
||||||
|
model_type = os.path.basename(model_type).replace('..', '').replace('/', '').replace('\\', '')
|
||||||
|
|
||||||
|
# Sanitize filename to prevent path traversal
|
||||||
|
filename = os.path.basename(filename).replace('..', '')
|
||||||
|
|
||||||
|
# Validate filename has allowed extension
|
||||||
|
allowed_extensions = ['.safetensors', '.ckpt', '.pt', '.pth', '.bin', '.sft']
|
||||||
|
if not any(filename.lower().endswith(ext) for ext in allowed_extensions):
|
||||||
|
raise ValueError(f"Invalid file extension. Allowed extensions: {allowed_extensions}")
|
||||||
|
|
||||||
|
# Whitelist of allowed model types
|
||||||
|
allowed_types = ['checkpoints', 'vae', 'loras', 'controlnet', 'clip', 'unet',
|
||||||
|
'upscale_models', 'text_encoders', 'diffusion_models', 'embeddings']
|
||||||
|
|
||||||
|
# Map alternative names
|
||||||
|
type_mapping = {
|
||||||
|
'text_encoders': 'clip',
|
||||||
|
'diffusion_models': 'unet'
|
||||||
|
}
|
||||||
|
model_type = type_mapping.get(model_type, model_type)
|
||||||
|
|
||||||
|
if model_type not in allowed_types:
|
||||||
|
raise ValueError(f"Invalid model type. Allowed types: {allowed_types}")
|
||||||
|
|
||||||
# Determine destination folder
|
# Determine destination folder
|
||||||
folder_map = {
|
folder_map = {
|
||||||
'checkpoints': folder_paths.get_folder_paths('checkpoints')[0],
|
'checkpoints': folder_paths.get_folder_paths('checkpoints')[0],
|
||||||
@ -31,24 +57,25 @@ class SimpleDownloader:
|
|||||||
'clip': folder_paths.get_folder_paths('clip')[0],
|
'clip': folder_paths.get_folder_paths('clip')[0],
|
||||||
'unet': folder_paths.get_folder_paths('diffusion_models')[0],
|
'unet': folder_paths.get_folder_paths('diffusion_models')[0],
|
||||||
'upscale_models': folder_paths.get_folder_paths('upscale_models')[0],
|
'upscale_models': folder_paths.get_folder_paths('upscale_models')[0],
|
||||||
|
'embeddings': folder_paths.get_folder_paths('embeddings')[0] if folder_paths.get_folder_paths('embeddings') else os.path.join(folder_paths.models_dir, 'embeddings')
|
||||||
}
|
}
|
||||||
|
|
||||||
dest_folder = folder_map.get(model_type)
|
dest_folder = folder_map.get(model_type)
|
||||||
if not dest_folder:
|
if not dest_folder:
|
||||||
# Try to find the folder
|
# Only allow creating folders for whitelisted types
|
||||||
try:
|
if model_type in allowed_types:
|
||||||
paths = folder_paths.get_folder_paths(model_type)
|
|
||||||
if paths:
|
|
||||||
dest_folder = paths[0]
|
|
||||||
else:
|
|
||||||
# Default to models folder
|
|
||||||
dest_folder = os.path.join(folder_paths.models_dir, model_type)
|
|
||||||
os.makedirs(dest_folder, exist_ok=True)
|
|
||||||
except:
|
|
||||||
dest_folder = os.path.join(folder_paths.models_dir, model_type)
|
dest_folder = os.path.join(folder_paths.models_dir, model_type)
|
||||||
os.makedirs(dest_folder, exist_ok=True)
|
os.makedirs(dest_folder, exist_ok=True)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Cannot find or create folder for model type: {model_type}")
|
||||||
|
|
||||||
dest_path = os.path.join(dest_folder, filename)
|
# Use safe path joining and verify result
|
||||||
|
dest_path = os.path.abspath(os.path.join(dest_folder, filename))
|
||||||
|
|
||||||
|
# SECURITY: Ensure destination path is within the models directory
|
||||||
|
models_base = os.path.abspath(folder_paths.models_dir)
|
||||||
|
if not dest_path.startswith(models_base):
|
||||||
|
raise ValueError("Invalid destination path - outside models directory")
|
||||||
|
|
||||||
with self.lock:
|
with self.lock:
|
||||||
self.downloads[task_id] = {
|
self.downloads[task_id] = {
|
||||||
@ -86,6 +113,24 @@ class SimpleDownloader:
|
|||||||
dest_path = task['dest_path']
|
dest_path = task['dest_path']
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
# SECURITY: Validate URL before downloading
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
parsed = urlparse(url)
|
||||||
|
|
||||||
|
# Only allow HTTPS for security
|
||||||
|
if parsed.scheme != 'https':
|
||||||
|
raise ValueError("Only HTTPS URLs are allowed for security")
|
||||||
|
|
||||||
|
# Prevent SSRF attacks - block local/private IPs
|
||||||
|
import socket
|
||||||
|
try:
|
||||||
|
ip = socket.gethostbyname(parsed.hostname)
|
||||||
|
# Block private/local IPs
|
||||||
|
if ip.startswith(('127.', '10.', '192.168.', '172.')):
|
||||||
|
raise ValueError("Downloads from local/private networks are not allowed")
|
||||||
|
except socket.gaierror:
|
||||||
|
pass # Domain name resolution failed, continue
|
||||||
|
|
||||||
# Create request with headers
|
# Create request with headers
|
||||||
req = urllib.request.Request(url)
|
req = urllib.request.Request(url)
|
||||||
req.add_header('User-Agent', 'ComfyUI/1.0')
|
req.add_header('User-Agent', 'ComfyUI/1.0')
|
||||||
|
|||||||
27
server.py
27
server.py
@ -802,22 +802,45 @@ class PromptServer():
|
|||||||
if not url:
|
if not url:
|
||||||
return web.json_response({"error": "URL is required"}, status=400)
|
return web.json_response({"error": "URL is required"}, status=400)
|
||||||
|
|
||||||
|
# SECURITY: Validate URL format
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
try:
|
||||||
|
parsed_url = urlparse(url)
|
||||||
|
if parsed_url.scheme not in ['https']:
|
||||||
|
return web.json_response({"error": "Only HTTPS URLs are allowed"}, status=400)
|
||||||
|
except Exception:
|
||||||
|
return web.json_response({"error": "Invalid URL format"}, status=400)
|
||||||
|
|
||||||
|
# SECURITY: Sanitize model_type
|
||||||
|
import re
|
||||||
|
if not re.match(r'^[a-zA-Z0-9_-]+$', model_type):
|
||||||
|
return web.json_response({"error": "Invalid model type format"}, status=400)
|
||||||
|
|
||||||
if not filename:
|
if not filename:
|
||||||
# Extract filename from URL
|
# Extract filename from URL
|
||||||
filename = url.split('/')[-1].split('?')[0]
|
filename = url.split('/')[-1].split('?')[0]
|
||||||
if not filename:
|
if not filename:
|
||||||
filename = "model.safetensors"
|
filename = "model.safetensors"
|
||||||
|
|
||||||
# Create download task
|
# SECURITY: Sanitize filename
|
||||||
|
import os
|
||||||
|
filename = os.path.basename(filename)
|
||||||
|
if not re.match(r'^[a-zA-Z0-9_.-]+$', filename):
|
||||||
|
return web.json_response({"error": "Invalid filename format"}, status=400)
|
||||||
|
|
||||||
|
# Create download task (simple_downloader now has additional validation)
|
||||||
task_id = simple_downloader.create_download(url, model_type, filename)
|
task_id = simple_downloader.create_download(url, model_type, filename)
|
||||||
|
|
||||||
# Return task ID and initial status
|
# Return task ID and initial status
|
||||||
status = simple_downloader.get_status(task_id)
|
status = simple_downloader.get_status(task_id)
|
||||||
return web.json_response(status)
|
return web.json_response(status)
|
||||||
|
|
||||||
|
except ValueError as e:
|
||||||
|
# Return validation errors from simple_downloader
|
||||||
|
return web.json_response({"error": str(e)}, status=400)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.error(f"Error starting download: {e}")
|
logging.error(f"Error starting download: {e}")
|
||||||
return web.json_response({"error": str(e)}, status=500)
|
return web.json_response({"error": "Internal server error"}, status=500)
|
||||||
|
|
||||||
@routes.get("/models/download/{task_id}")
|
@routes.get("/models/download/{task_id}")
|
||||||
async def get_download_status(request):
|
async def get_download_status(request):
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user