Fix critical security vulnerabilities in model downloader

- Add path traversal protection in simple_downloader.py
- Sanitize model_type and filename inputs to prevent directory escapes
- Validate file extensions against allowed list
- Restrict model types to whitelisted folders only
- Add URL validation to require HTTPS
- Block SSRF attacks by preventing local/private network downloads
- Add input validation in server.py endpoint
- Ensure all file paths remain within models directory

These changes prevent attackers from:
- Writing files outside the models directory
- Accessing sensitive files via path traversal
- Making requests to internal services (SSRF)
- Executing arbitrary code via malicious filenames

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
fragmede 2025-09-27 06:08:32 -07:00
parent be88ae7a64
commit b8a0408c65
No known key found for this signature in database
2 changed files with 81 additions and 13 deletions

View File

@ -22,6 +22,32 @@ class SimpleDownloader:
"""Create a new download task.""" """Create a new download task."""
task_id = str(uuid.uuid4()) task_id = str(uuid.uuid4())
# SECURITY: Validate and sanitize inputs to prevent path traversal
# Sanitize model_type to prevent directory traversal
model_type = os.path.basename(model_type).replace('..', '').replace('/', '').replace('\\', '')
# Sanitize filename to prevent path traversal
filename = os.path.basename(filename).replace('..', '')
# Validate filename has allowed extension
allowed_extensions = ['.safetensors', '.ckpt', '.pt', '.pth', '.bin', '.sft']
if not any(filename.lower().endswith(ext) for ext in allowed_extensions):
raise ValueError(f"Invalid file extension. Allowed extensions: {allowed_extensions}")
# Whitelist of allowed model types
allowed_types = ['checkpoints', 'vae', 'loras', 'controlnet', 'clip', 'unet',
'upscale_models', 'text_encoders', 'diffusion_models', 'embeddings']
# Map alternative names
type_mapping = {
'text_encoders': 'clip',
'diffusion_models': 'unet'
}
model_type = type_mapping.get(model_type, model_type)
if model_type not in allowed_types:
raise ValueError(f"Invalid model type. Allowed types: {allowed_types}")
# Determine destination folder # Determine destination folder
folder_map = { folder_map = {
'checkpoints': folder_paths.get_folder_paths('checkpoints')[0], 'checkpoints': folder_paths.get_folder_paths('checkpoints')[0],
@ -31,24 +57,25 @@ class SimpleDownloader:
'clip': folder_paths.get_folder_paths('clip')[0], 'clip': folder_paths.get_folder_paths('clip')[0],
'unet': folder_paths.get_folder_paths('diffusion_models')[0], 'unet': folder_paths.get_folder_paths('diffusion_models')[0],
'upscale_models': folder_paths.get_folder_paths('upscale_models')[0], 'upscale_models': folder_paths.get_folder_paths('upscale_models')[0],
'embeddings': folder_paths.get_folder_paths('embeddings')[0] if folder_paths.get_folder_paths('embeddings') else os.path.join(folder_paths.models_dir, 'embeddings')
} }
dest_folder = folder_map.get(model_type) dest_folder = folder_map.get(model_type)
if not dest_folder: if not dest_folder:
# Try to find the folder # Only allow creating folders for whitelisted types
try: if model_type in allowed_types:
paths = folder_paths.get_folder_paths(model_type)
if paths:
dest_folder = paths[0]
else:
# Default to models folder
dest_folder = os.path.join(folder_paths.models_dir, model_type)
os.makedirs(dest_folder, exist_ok=True)
except:
dest_folder = os.path.join(folder_paths.models_dir, model_type) dest_folder = os.path.join(folder_paths.models_dir, model_type)
os.makedirs(dest_folder, exist_ok=True) os.makedirs(dest_folder, exist_ok=True)
else:
raise ValueError(f"Cannot find or create folder for model type: {model_type}")
dest_path = os.path.join(dest_folder, filename) # Use safe path joining and verify result
dest_path = os.path.abspath(os.path.join(dest_folder, filename))
# SECURITY: Ensure destination path is within the models directory
models_base = os.path.abspath(folder_paths.models_dir)
if not dest_path.startswith(models_base):
raise ValueError("Invalid destination path - outside models directory")
with self.lock: with self.lock:
self.downloads[task_id] = { self.downloads[task_id] = {
@ -86,6 +113,24 @@ class SimpleDownloader:
dest_path = task['dest_path'] dest_path = task['dest_path']
try: try:
# SECURITY: Validate URL before downloading
from urllib.parse import urlparse
parsed = urlparse(url)
# Only allow HTTPS for security
if parsed.scheme != 'https':
raise ValueError("Only HTTPS URLs are allowed for security")
# Prevent SSRF attacks - block local/private IPs
import socket
try:
ip = socket.gethostbyname(parsed.hostname)
# Block private/local IPs
if ip.startswith(('127.', '10.', '192.168.', '172.')):
raise ValueError("Downloads from local/private networks are not allowed")
except socket.gaierror:
pass # Domain name resolution failed, continue
# Create request with headers # Create request with headers
req = urllib.request.Request(url) req = urllib.request.Request(url)
req.add_header('User-Agent', 'ComfyUI/1.0') req.add_header('User-Agent', 'ComfyUI/1.0')

View File

@ -802,22 +802,45 @@ class PromptServer():
if not url: if not url:
return web.json_response({"error": "URL is required"}, status=400) return web.json_response({"error": "URL is required"}, status=400)
# SECURITY: Validate URL format
from urllib.parse import urlparse
try:
parsed_url = urlparse(url)
if parsed_url.scheme not in ['https']:
return web.json_response({"error": "Only HTTPS URLs are allowed"}, status=400)
except Exception:
return web.json_response({"error": "Invalid URL format"}, status=400)
# SECURITY: Sanitize model_type
import re
if not re.match(r'^[a-zA-Z0-9_-]+$', model_type):
return web.json_response({"error": "Invalid model type format"}, status=400)
if not filename: if not filename:
# Extract filename from URL # Extract filename from URL
filename = url.split('/')[-1].split('?')[0] filename = url.split('/')[-1].split('?')[0]
if not filename: if not filename:
filename = "model.safetensors" filename = "model.safetensors"
# Create download task # SECURITY: Sanitize filename
import os
filename = os.path.basename(filename)
if not re.match(r'^[a-zA-Z0-9_.-]+$', filename):
return web.json_response({"error": "Invalid filename format"}, status=400)
# Create download task (simple_downloader now has additional validation)
task_id = simple_downloader.create_download(url, model_type, filename) task_id = simple_downloader.create_download(url, model_type, filename)
# Return task ID and initial status # Return task ID and initial status
status = simple_downloader.get_status(task_id) status = simple_downloader.get_status(task_id)
return web.json_response(status) return web.json_response(status)
except ValueError as e:
# Return validation errors from simple_downloader
return web.json_response({"error": str(e)}, status=400)
except Exception as e: except Exception as e:
logging.error(f"Error starting download: {e}") logging.error(f"Error starting download: {e}")
return web.json_response({"error": str(e)}, status=500) return web.json_response({"error": "Internal server error"}, status=500)
@routes.get("/models/download/{task_id}") @routes.get("/models/download/{task_id}")
async def get_download_status(request): async def get_download_status(request):