diff --git a/app/simple_downloader.py b/app/simple_downloader.py index fb7f86df3..2b18d15c3 100644 --- a/app/simple_downloader.py +++ b/app/simple_downloader.py @@ -22,6 +22,32 @@ class SimpleDownloader: """Create a new download task.""" task_id = str(uuid.uuid4()) + # SECURITY: Validate and sanitize inputs to prevent path traversal + # Sanitize model_type to prevent directory traversal + model_type = os.path.basename(model_type).replace('..', '').replace('/', '').replace('\\', '') + + # Sanitize filename to prevent path traversal + filename = os.path.basename(filename).replace('..', '') + + # Validate filename has allowed extension + allowed_extensions = ['.safetensors', '.ckpt', '.pt', '.pth', '.bin', '.sft'] + if not any(filename.lower().endswith(ext) for ext in allowed_extensions): + raise ValueError(f"Invalid file extension. Allowed extensions: {allowed_extensions}") + + # Whitelist of allowed model types + allowed_types = ['checkpoints', 'vae', 'loras', 'controlnet', 'clip', 'unet', + 'upscale_models', 'text_encoders', 'diffusion_models', 'embeddings'] + + # Map alternative names + type_mapping = { + 'text_encoders': 'clip', + 'diffusion_models': 'unet' + } + model_type = type_mapping.get(model_type, model_type) + + if model_type not in allowed_types: + raise ValueError(f"Invalid model type. Allowed types: {allowed_types}") + # Determine destination folder folder_map = { 'checkpoints': folder_paths.get_folder_paths('checkpoints')[0], @@ -31,24 +57,25 @@ class SimpleDownloader: 'clip': folder_paths.get_folder_paths('clip')[0], 'unet': folder_paths.get_folder_paths('diffusion_models')[0], 'upscale_models': folder_paths.get_folder_paths('upscale_models')[0], + 'embeddings': folder_paths.get_folder_paths('embeddings')[0] if folder_paths.get_folder_paths('embeddings') else os.path.join(folder_paths.models_dir, 'embeddings') } dest_folder = folder_map.get(model_type) if not dest_folder: - # Try to find the folder - try: - paths = folder_paths.get_folder_paths(model_type) - if paths: - dest_folder = paths[0] - else: - # Default to models folder - dest_folder = os.path.join(folder_paths.models_dir, model_type) - os.makedirs(dest_folder, exist_ok=True) - except: + # Only allow creating folders for whitelisted types + if model_type in allowed_types: dest_folder = os.path.join(folder_paths.models_dir, model_type) os.makedirs(dest_folder, exist_ok=True) + else: + raise ValueError(f"Cannot find or create folder for model type: {model_type}") - dest_path = os.path.join(dest_folder, filename) + # Use safe path joining and verify result + dest_path = os.path.abspath(os.path.join(dest_folder, filename)) + + # SECURITY: Ensure destination path is within the models directory + models_base = os.path.abspath(folder_paths.models_dir) + if not dest_path.startswith(models_base): + raise ValueError("Invalid destination path - outside models directory") with self.lock: self.downloads[task_id] = { @@ -86,6 +113,24 @@ class SimpleDownloader: dest_path = task['dest_path'] try: + # SECURITY: Validate URL before downloading + from urllib.parse import urlparse + parsed = urlparse(url) + + # Only allow HTTPS for security + if parsed.scheme != 'https': + raise ValueError("Only HTTPS URLs are allowed for security") + + # Prevent SSRF attacks - block local/private IPs + import socket + try: + ip = socket.gethostbyname(parsed.hostname) + # Block private/local IPs + if ip.startswith(('127.', '10.', '192.168.', '172.')): + raise ValueError("Downloads from local/private networks are not allowed") + except socket.gaierror: + pass # Domain name resolution failed, continue + # Create request with headers req = urllib.request.Request(url) req.add_header('User-Agent', 'ComfyUI/1.0') diff --git a/server.py b/server.py index 35b0defe5..01f4a6926 100644 --- a/server.py +++ b/server.py @@ -802,22 +802,45 @@ class PromptServer(): if not url: return web.json_response({"error": "URL is required"}, status=400) + # SECURITY: Validate URL format + from urllib.parse import urlparse + try: + parsed_url = urlparse(url) + if parsed_url.scheme not in ['https']: + return web.json_response({"error": "Only HTTPS URLs are allowed"}, status=400) + except Exception: + return web.json_response({"error": "Invalid URL format"}, status=400) + + # SECURITY: Sanitize model_type + import re + if not re.match(r'^[a-zA-Z0-9_-]+$', model_type): + return web.json_response({"error": "Invalid model type format"}, status=400) + if not filename: # Extract filename from URL filename = url.split('/')[-1].split('?')[0] if not filename: filename = "model.safetensors" - # Create download task + # SECURITY: Sanitize filename + import os + filename = os.path.basename(filename) + if not re.match(r'^[a-zA-Z0-9_.-]+$', filename): + return web.json_response({"error": "Invalid filename format"}, status=400) + + # Create download task (simple_downloader now has additional validation) task_id = simple_downloader.create_download(url, model_type, filename) # Return task ID and initial status status = simple_downloader.get_status(task_id) return web.json_response(status) + except ValueError as e: + # Return validation errors from simple_downloader + return web.json_response({"error": str(e)}, status=400) except Exception as e: logging.error(f"Error starting download: {e}") - return web.json_response({"error": str(e)}, status=500) + return web.json_response({"error": "Internal server error"}, status=500) @routes.get("/models/download/{task_id}") async def get_download_status(request):