From 2f7b77f341c985a9f3f57c041f3266aaf4dce6a9 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Tue, 24 Mar 2026 23:47:59 -0700 Subject: [PATCH] feat: add model download API gated behind --enable-download-api Add a new server-side download API that allows frontends and desktop apps to download models directly into ComfyUI's models directory, eliminating the need for DOM scraping of the frontend UI. New files: - app/download_manager.py: Async download manager with streaming downloads, pause/resume/cancel, manual redirect following with per-hop host validation, sidecar metadata for safe resume, and concurrency limiting. API endpoints (all under /download/, also mirrored at /api/download/): - POST /download/model - Start a download (url, directory, filename) - GET /download/status - List all downloads (filterable by client_id) - GET /download/status/{id} - Get single download status - POST /download/pause/{id} - Pause (cancels transfer, keeps temp) - POST /download/resume/{id} - Resume (new request with Range header) - POST /download/cancel/{id} - Cancel and clean up temp files Security: - Gated behind --enable-download-api CLI flag (403 if disabled) - HTTPS-only with exact host allowlist (huggingface.co, civitai.com + CDNs) - Manual redirect following with per-hop host validation (no SSRF) - Path traversal protection via realpath + commonpath - Extension allowlist (.safetensors, .sft) - Filename sanitization (no separators, .., control chars) - Destination re-checked before final rename - Progress events scoped to initiating client_id Closes Comfy-Org/ComfyUI-Desktop-2.0-Beta#293 Amp-Thread-ID: https://ampcode.com/threads/T-019d2344-139e-77a5-9f24-1cbb3b26a8ec Co-authored-by: Amp --- app/download_manager.py | 507 +++++++++++++++++++++++++++++++++++++ comfy/cli_args.py | 2 + comfy_api/feature_flags.py | 1 + server.py | 84 ++++++ 4 files changed, 594 insertions(+) create mode 100644 app/download_manager.py diff --git a/app/download_manager.py 
b/app/download_manager.py
new file mode 100644
index 000000000..8bb52497f
--- /dev/null
+++ b/app/download_manager.py
@@ -0,0 +1,507 @@
+from __future__ import annotations
+
+import asyncio
+import json
+import logging
+import os
+import time
+import uuid
+from dataclasses import dataclass, field
+from enum import Enum
+from typing import Optional, TYPE_CHECKING
+from urllib.parse import urlsplit
+
+import aiohttp
+from yarl import URL
+
+import folder_paths
+
+if TYPE_CHECKING:
+    from server import PromptServer
+
+logger = logging.getLogger(__name__)
+
+# Exact-match allowlist of hosts a download URL may point at.
+ALLOWED_HTTPS_HOSTS = frozenset({
+    "huggingface.co",
+    "cdn-lfs.huggingface.co",
+    "cdn-lfs-us-1.huggingface.co",
+    "cdn-lfs-eu-1.huggingface.co",
+    "civitai.com",
+    "api.civitai.com",
+})
+
+# File extensions (lower-case, with leading dot) accepted for the target filename.
+ALLOWED_EXTENSIONS = frozenset({".safetensors", ".sft"})
+
+MAX_CONCURRENT_DOWNLOADS = 3  # size of the worker semaphore
+MAX_TERMINAL_TASKS = 50  # finished tasks kept in memory before pruning
+MAX_REDIRECTS = 10
+
+# Suffixes appended to the final path for the partial file and its sidecar.
+DOWNLOAD_TEMP_SUFFIX = ".download_tmp"
+DOWNLOAD_META_SUFFIX = ".download_meta"
+
+
+class DownloadStatus(str, Enum):
+    """Lifecycle state of a download task; the value is the wire string."""
+
+    PENDING = "pending"
+    DOWNLOADING = "downloading"
+    PAUSED = "paused"
+    COMPLETED = "completed"
+    ERROR = "error"
+    CANCELLED = "cancelled"
+
+
+# States in which a task still owns its destination / temp files.
+ACTIVE_STATUSES = frozenset({
+    DownloadStatus.PENDING,
+    DownloadStatus.DOWNLOADING,
+    DownloadStatus.PAUSED,
+})
+
+# States in which a task is finished and eligible for pruning.
+TERMINAL_STATUSES = frozenset({
+    DownloadStatus.COMPLETED,
+    DownloadStatus.ERROR,
+    DownloadStatus.CANCELLED,
+})
+
+
+@dataclass
+class DownloadTask:
+    """Mutable record describing one download and its progress."""
+
+    id: str
+    url: str
+    filename: str
+    directory: str  # key into folder_paths.folder_names_and_paths
+    save_path: str  # final destination
+    temp_path: str  # save_path + DOWNLOAD_TEMP_SUFFIX
+    meta_path: str  # save_path + DOWNLOAD_META_SUFFIX
+    status: DownloadStatus = DownloadStatus.PENDING
+    progress: float = 0.0  # received_bytes / total_bytes when the total is known
+    received_bytes: int = 0
+    total_bytes: int = 0  # 0 while the total size is unknown
+    speed_bytes_per_sec: float = 0.0
+    eta_seconds: float = 0.0
+    error: Optional[str] = None
+    created_at: float = field(default_factory=time.time)
+    client_id: Optional[str] = None
+    # asyncio task currently running this download, if any.
+    _worker: Optional[asyncio.Task] = field(default=None, repr=False)
+    # "pause" or "cancel" while a stop has been requested, otherwise None.
+    _stop_reason: Optional[str] = field(default=None, repr=False)
+
+    def to_dict(self) -> dict:
+        """Return the public, JSON-serializable view of the task.
+
+        Filesystem paths, ``client_id`` and the private worker fields are
+        deliberately left out.
+        """
+        return {
+            "id": self.id,
+            "url": self.url,
+            "filename": self.filename,
+            "directory": self.directory,
+            "status": self.status.value,
+            "progress": self.progress,
+            "received_bytes": self.received_bytes,
+            "total_bytes": self.total_bytes,
+            "speed_bytes_per_sec": self.speed_bytes_per_sec,
+            "eta_seconds": self.eta_seconds,
+            "error": self.error,
+            "created_at": self.created_at,
+        }
+
+
+class DownloadManager:
+    """Tracks download tasks and runs them as asyncio workers."""
+
+    def __init__(self, server: PromptServer):
+        self.server = server
+        self.tasks: dict[str, DownloadTask] = {}
+        self._session: Optional[aiohttp.ClientSession] = None
+        self._semaphore = asyncio.Semaphore(MAX_CONCURRENT_DOWNLOADS)
+
+    async def _get_session(self) -> aiohttp.ClientSession:
+        """Return the shared HTTP session, creating it on first use or after close.
+
+        There is no overall timeout (downloads can be long); connecting is
+        capped at 30 s and each socket read at 60 s.
+        """
+        if self._session is None or self._session.closed:
+            timeout = aiohttp.ClientTimeout(total=None, connect=30, sock_read=60)
+            self._session = aiohttp.ClientSession(timeout=timeout)
+        return self._session
+
+    async def close(self):
+        """Cancel every running worker, wait for them, then close the session."""
+        workers = [t._worker for t in self.tasks.values() if t._worker and not t._worker.done()]
+        for w in workers:
+            w.cancel()
+        if workers:
+            await asyncio.gather(*workers, return_exceptions=True)
+        if self._session and not self._session.closed:
+            await self._session.close()
+
+    # -- Validation --
+
+    @staticmethod
+    def _validate_url(url: str) -> Optional[str]:
+        """Check that ``url`` is an HTTPS URL on an allowed host.
+
+        Returns:
+            ``None`` when the URL is acceptable, otherwise a human-readable
+            error message. Never raises for malformed input.
+        """
+        try:
+            parts = urlsplit(url)
+            # The port is parsed lazily and raises ValueError when it is not
+            # a valid number, so it must be read inside the try block.
+            port = parts.port
+            host = (parts.hostname or "").lower()
+            scheme = parts.scheme.lower()
+            has_credentials = bool(parts.username or parts.password)
+        except Exception:
+            return "Invalid URL"
+
+        if has_credentials:
+            return "Credentials in URL are not allowed"
+
+        if scheme != "https":
+            return "Only HTTPS URLs are allowed"
+
+        if host not in ALLOWED_HTTPS_HOSTS:
+            return f"Host '{host}' is not in the allowed list"
+
+        if port not in (None, 443):
+            return "Custom ports are not allowed for remote downloads"
+
+        return None
+
+    @staticmethod
+    def _validate_filename(filename: str) -> Optional[str]:
+        """Check that ``filename`` is a bare file name with an allowed extension.
+
+        Returns:
+            ``None`` when the name is acceptable, otherwise an error message.
+        """
+        if not filename:
+            return "Filename must not be empty"
+        if not isinstance(filename, str):
+            return "Filename must be a string"
+        ext = os.path.splitext(filename)[1].lower()
+        if ext not in ALLOWED_EXTENSIONS:
+            return f"File extension '{ext}' not allowed. Allowed: {', '.join(sorted(ALLOWED_EXTENSIONS))}"
+        if os.path.sep in filename or (os.path.altsep and os.path.altsep in filename):
+            return "Filename must not contain path separators"
+        if ".." in filename:
+            return "Filename must not contain '..'"
+        if any(ord(ch) < 32 for ch in filename):
+            return "Filename must not contain control characters"
+        return None
+
+    @staticmethod
+    def _validate_directory(directory: str) -> Optional[str]:
+        """Check that ``directory`` is a key of ``folder_paths.folder_names_and_paths``.
+
+        Returns:
+            ``None`` when the directory is known, otherwise an error message
+            listing the valid keys.
+        """
+        # The isinstance guard keeps unhashable JSON values (lists, dicts)
+        # from raising TypeError in the membership test.
+        if not isinstance(directory, str) or directory not in folder_paths.folder_names_and_paths:
+            valid = ', '.join(sorted(folder_paths.folder_names_and_paths.keys()))
+            return f"Unknown model directory '{directory}'. Valid directories: {valid}"
+        return None
+
+    @staticmethod
+    def _resolve_save_path(directory: str, filename: str) -> tuple[str, str, str]:
+        """Returns (save_path, temp_path, meta_path) for a download.
+
+        The first configured path of ``directory`` is used as the base and is
+        created if missing.
+
+        Raises:
+            ValueError: if the directory has no configured path, or the
+                resolved destination falls outside the base directory.
+        """
+        paths = folder_paths.folder_names_and_paths[directory][0]
+        if not paths:
+            # Raise the error type callers already handle instead of IndexError.
+            raise ValueError(f"No path configured for model directory '{directory}'")
+        base_dir = paths[0]
+        os.makedirs(base_dir, exist_ok=True)
+
+        save_path = os.path.join(base_dir, filename)
+        temp_path = save_path + DOWNLOAD_TEMP_SUFFIX
+        meta_path = save_path + DOWNLOAD_META_SUFFIX
+
+        # realpath resolves symlinks, so a link pointing outside base_dir is caught.
+        real_save = os.path.realpath(save_path)
+        real_base = os.path.realpath(base_dir)
+        if os.path.commonpath([real_save, real_base]) != real_base:
+            raise ValueError("Resolved path escapes the model directory")
+
+        return save_path, temp_path, meta_path
+
+    # -- Sidecar metadata for resume validation --
+
+    @staticmethod
+    def _write_meta(meta_path: str, url: str, task_id: str):
+        """Write the sidecar JSON (``url`` and ``task_id``) next to the temp file.
+
+        Best effort: a failure is logged and otherwise ignored.
+        """
+        try:
+            with open(meta_path, "w") as f:
+                json.dump({"url": url, "task_id": task_id}, f)
+        except OSError as e:
+            logger.warning("Could not write download metadata %s: %s", meta_path, e)
+
+    @staticmethod
+    def _read_meta(meta_path: str) -> Optional[dict]:
+        """Return the sidecar contents, or ``None`` if missing, unreadable or not an object."""
+        try:
+            with open(meta_path, "r") as f:
+                meta = json.load(f)
+        except (OSError, ValueError):
+            # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
+            return None
+        # Callers use .get(); anything other than a JSON object is unusable.
+        return meta if isinstance(meta, dict) else None
+
+    @staticmethod
+    def _cleanup_files(*paths: str):
+        """Delete each given path if present; failures are logged, never raised."""
+        for p in paths:
+            try:
+                os.remove(p)
+            except FileNotFoundError:
+                pass
+            except OSError as e:
+                logger.warning("Could not remove %s: %s", p, e)
+
+    # -- 
Task management --
+
+    def _prune_terminal_tasks(self):
+        """Forget the oldest finished tasks once more than MAX_TERMINAL_TASKS are held."""
+        terminal = [
+            (tid, t) for tid, t in self.tasks.items()
+            if t.status in TERMINAL_STATUSES
+        ]
+        if len(terminal) > MAX_TERMINAL_TASKS:
+            terminal.sort(key=lambda x: x[1].created_at)
+            to_remove = len(terminal) - MAX_TERMINAL_TASKS
+            for tid, _ in terminal[:to_remove]:
+                del self.tasks[tid]
+
+    async def start_download(
+        self, url: str, directory: str, filename: str, client_id: Optional[str] = None
+    ) -> tuple[Optional[DownloadTask], Optional[str]]:
+        """Validate a request, register a task and launch its worker.
+
+        Args:
+            url: Source URL; must pass ``_validate_url``.
+            directory: Model directory key; must pass ``_validate_directory``.
+            filename: Target file name; must pass ``_validate_filename``.
+            client_id: Identifier passed along with progress events.
+
+        Returns:
+            ``(task, None)`` on success, or ``(None, message)`` when the
+            request is rejected.
+        """
+        err = self._validate_url(url)
+        if err:
+            return None, err
+
+        err = self._validate_filename(filename)
+        if err:
+            return None, err
+
+        err = self._validate_directory(directory)
+        if err:
+            return None, err
+
+        try:
+            save_path, temp_path, meta_path = self._resolve_save_path(directory, filename)
+        except ValueError as e:
+            return None, str(e)
+
+        if os.path.exists(save_path):
+            return None, f"File already exists: {directory}/{filename}"
+
+        # Reject duplicate active download by URL
+        for task in self.tasks.values():
+            if task.url == url and task.status in ACTIVE_STATUSES:
+                return None, f"Download already in progress for this URL (id: {task.id})"
+
+        # Reject duplicate active download by destination path (#4)
+        for task in self.tasks.values():
+            if task.save_path == save_path and task.status in ACTIVE_STATUSES:
+                return None, f"Download already in progress for {directory}/{filename} (id: {task.id})"
+
+        # Clean stale temp/meta if no active task owns them (#9)
+        existing_meta = self._read_meta(meta_path)
+        if isinstance(existing_meta, dict) and existing_meta:
+            owner_id = existing_meta.get("task_id", "")
+            # A non-string id cannot match any task (and may be unhashable).
+            owning_task = self.tasks.get(owner_id) if isinstance(owner_id, str) else None
+            if not owning_task or owning_task.status in TERMINAL_STATUSES:
+                # Leftovers from the same URL are kept so the worker can resume.
+                if existing_meta.get("url") != url:
+                    self._cleanup_files(temp_path, meta_path)
+
+        task = DownloadTask(
+            id=uuid.uuid4().hex[:12],
+            url=url,
+            filename=filename,
+            directory=directory,
+            save_path=save_path,
+            temp_path=temp_path,
+            meta_path=meta_path,
+            client_id=client_id,
+        )
+        self.tasks[task.id] = task
+        self._prune_terminal_tasks()
+
+        task._worker = asyncio.create_task(self._run_download(task))
+        return task, None
+
+    # -- Redirect-safe fetch (#1, #2, #3) --
+
+    async def _fetch_with_validated_redirects(
+        self, session: aiohttp.ClientSession, url: str, headers: dict
+    ) -> aiohttp.ClientResponse:
+        """Follow redirects manually, validating each hop against the allowlist.
+
+        Every redirect target must satisfy the same constraints as the
+        initial URL: HTTPS, allowlisted host, no credentials, default port.
+
+        Returns:
+            The first non-redirect response; the caller must release it.
+
+        Raises:
+            ValueError: on a redirect without ``Location``, a target that
+                fails validation, or more than MAX_REDIRECTS redirects.
+        """
+        current_url = url
+        for _ in range(MAX_REDIRECTS + 1):
+            resp = await session.get(current_url, headers=headers, allow_redirects=False)
+            if resp.status not in (301, 302, 303, 307, 308):
+                return resp
+
+            location = resp.headers.get("Location")
+            await resp.release()
+            if not location:
+                raise ValueError("Redirect without Location header")
+
+            # Location may be relative; resolve it against the current URL.
+            resolved = URL(current_url).join(URL(location))
+            current_url = str(resolved)
+
+            # Validate the redirect target
+            parts = urlsplit(current_url)
+            host = (parts.hostname or "").lower()
+            scheme = parts.scheme.lower()
+
+            if scheme != "https":
+                raise ValueError(f"Redirect to non-HTTPS URL: {current_url}")
+            if host not in ALLOWED_HTTPS_HOSTS:
+                # Hosts outside the allowlist are refused, CDNs included.
+                raise ValueError(f"Redirect to disallowed host: {host}")
+            if parts.username or parts.password:
+                raise ValueError("Redirect to URL with credentials is not allowed")
+            # parts.port itself raises ValueError for a malformed port.
+            if parts.port not in (None, 443):
+                raise ValueError(f"Redirect to disallowed port: {parts.port}")
+
+            # 303 means GET with no Range
+            if resp.status == 303:
+                headers = {k: v for k, v in headers.items() if k.lower() != "range"}
+
+        raise ValueError(f"Too many redirects (>{MAX_REDIRECTS})")
+
+    # -- Download worker --
+
+    async def _run_download(self, task: DownloadTask):
+        """Worker entry point: run one download under the concurrency limit.
+
+        Cancellation is turned into a final status here rather than being
+        propagated: PAUSED (partial files kept) when a pause was requested,
+        otherwise CANCELLED (partial files removed). Any other exception
+        marks the task as ERROR.
+        """
+        try:
+            async with self._semaphore:
+                await self._run_download_inner(task)
+        except asyncio.CancelledError:
+            if task._stop_reason == "pause":
+                task.status = DownloadStatus.PAUSED
+                task.speed_bytes_per_sec = 0
+                task.eta_seconds = 0
+                await self._send_progress(task)
+            else:
+                task.status = DownloadStatus.CANCELLED
+                await self._send_progress(task)
+                self._cleanup_files(task.temp_path, task.meta_path)
+        except Exception as e:
+            task.status = DownloadStatus.ERROR
+            task.error = str(e)
+            await self._send_progress(task)
+            logger.exception("Download error for %s", task.url)
+
+    async def _run_download_inner(self, task: DownloadTask):
+        """Stream ``task.url`` into the temp file, then move it into place.
+
+        An existing temp file is resumed with a ``Range`` request only when
+        its sidecar records the same URL; otherwise it is discarded.
+        """
+        session = await self._get_session()
+        headers = {}
+
+        # Resume support with sidecar validation (#9)
+        if os.path.exists(task.temp_path):
+            meta = self._read_meta(task.meta_path)
+            if isinstance(meta, dict) and meta.get("url") == task.url:
+                existing_size = os.path.getsize(task.temp_path)
+                if existing_size > 0:
+                    headers["Range"] = f"bytes={existing_size}-"
+                    task.received_bytes = existing_size
+            else:
+                self._cleanup_files(task.temp_path, task.meta_path)
+
+        self._write_meta(task.meta_path, task.url, task.id)
+        task.status = DownloadStatus.DOWNLOADING
+        await self._send_progress(task)
+
+        resp = await self._fetch_with_validated_redirects(session, task.url, headers)
+        try:
+            if resp.status == 416:
+                # Range not satisfiable: if the temp file already holds the
+                # whole payload ("bytes */<total>"), finish without refetching.
+                content_range = resp.headers.get("Content-Range", "")
+                if content_range:
+                    total_str = content_range.split("/")[-1].strip()
+                    # isdigit() also rejects "*" (unknown size) and garbage.
+                    if total_str.isdigit():
+                        total = int(total_str)
+                        if task.received_bytes >= total:
+                            if os.path.exists(task.save_path):
+                                # Destination already present; drop the temp
+                                # file so it is not orphaned.
+                                self._cleanup_files(task.temp_path)
+                            else:
+                                os.rename(task.temp_path, task.save_path)
+                            self._cleanup_files(task.meta_path)
+                            task.status = DownloadStatus.COMPLETED
+                            task.progress = 1.0
+                            task.total_bytes = total
+                            await self._send_progress(task)
+                            return
+                raise ValueError("HTTP 416 Range Not Satisfiable")
+
+            if resp.status not in (200, 206):
+                task.status = DownloadStatus.ERROR
+                task.error = f"HTTP {resp.status}"
+                await self._send_progress(task)
+                return
+
+            # 200 means the server ignored the Range header: start over.
+            if resp.status == 200:
+                task.received_bytes = 0
+
+            content_length = resp.content_length
+            if resp.status == 206 and content_length:
+                task.total_bytes = task.received_bytes + content_length
+            elif resp.status == 200 and content_length:
+                task.total_bytes = content_length
+
+            mode = "ab" if resp.status == 206 else "wb"
+            speed_window_start = time.monotonic()
+            speed_window_bytes = 0
+            last_progress_time = 0.0
+
+            with open(task.temp_path, mode) as f:
+                async for chunk in resp.content.iter_chunked(1024 * 64):
+                    f.write(chunk)
+                    task.received_bytes += len(chunk)
+                    speed_window_bytes += len(chunk)
+
+                    now = time.monotonic()
+                    elapsed = now - speed_window_start
+                    # Recompute speed and ETA over windows of at least 0.5 s.
+                    if elapsed > 0.5:
+                        task.speed_bytes_per_sec = speed_window_bytes / elapsed
+                        if task.total_bytes > 0 and task.speed_bytes_per_sec > 0:
+                            remaining = task.total_bytes - task.received_bytes
+                            task.eta_seconds = remaining / task.speed_bytes_per_sec
+                        speed_window_start = now
+                        speed_window_bytes = 0
+
+                    if task.total_bytes > 0:
+                        task.progress = task.received_bytes / task.total_bytes
+
+                    # Emit progress events at most every 0.25 s.
+                    if now - last_progress_time >= 0.25:
+                        await self._send_progress(task)
+                        last_progress_time = now
+        finally:
+            resp.release()
+
+        # Final cancel check before committing (#7)
+        if task._stop_reason is not None:
+            raise asyncio.CancelledError()
+
+        # Re-check destination before finalizing (#10)
+        if os.path.exists(task.save_path):
+            task.status = DownloadStatus.ERROR
+            task.error = f"Destination file appeared during download: {task.directory}/{task.filename}"
+            await self._send_progress(task)
+            return
+
+        os.replace(task.temp_path, task.save_path)
+        self._cleanup_files(task.meta_path)
+        task.status = DownloadStatus.COMPLETED
+        task.progress = 1.0
+        task.speed_bytes_per_sec = 0
+        task.eta_seconds = 0
+        await self._send_progress(task)
+        logger.info("Download complete: %s/%s", task.directory, task.filename)
+
+    # -- Progress (#8, #14) --
+
+    async def _send_progress(self, task: DownloadTask):
+        """Emit a ``download_progress`` event carrying ``task.to_dict()``.
+
+        ``task.client_id`` is passed as the third argument of
+        ``server.send_sync``. Failures are logged and never propagate.
+        """
+        # NOTE(review): presumably a client_id of None makes send_sync
+        # broadcast to every client; confirm against PromptServer.send_sync.
+        try:
+            self.server.send_sync("download_progress", task.to_dict(), task.client_id)
+        except Exception:
+            logger.exception("Failed to send download progress event")
+
+    # -- Control operations (#5, #6, #13) --
+
+    def pause_download(self, task_id: str) -> Optional[str]:
+        """Request a pause of a pending or downloading task.
+
+        The status changes to PAUSED only once the worker has processed the
+        cancellation.
+
+        Returns:
+            ``None`` on success, otherwise an error message.
+        """
+        task = self.tasks.get(task_id)
+        if not task:
+            return "Download not found"
+        if task.status not in (DownloadStatus.PENDING, DownloadStatus.DOWNLOADING):
+            return f"Cannot pause download in state '{task.status.value}'"
+        task._stop_reason = "pause"
+        if task._worker and not task._worker.done():
+            task._worker.cancel()
+        return None
+
+    def resume_download(self, task_id: str) -> Optional[str]:
+        """Restart a paused task with a fresh worker.
+
+        Returns:
+            ``None`` on success, otherwise an error message.
+        """
+        task = self.tasks.get(task_id)
+        if not task:
+            return "Download not found"
+        if task.status != DownloadStatus.PAUSED:
+            return f"Cannot resume download in state '{task.status.value}'"
+        task._stop_reason = None
+        task.status = DownloadStatus.PENDING
+        task._worker = asyncio.create_task(self._run_download(task))
+        return None
+
+    def cancel_download(self, task_id: str) -> Optional[str]:
+        """Cancel a non-terminal task and discard its partial files.
+
+        With a live worker the cleanup happens in the worker's cancellation
+        handler; otherwise (for example a paused task) it happens here.
+
+        Returns:
+            ``None`` on success, otherwise an error message.
+        """
+        task = self.tasks.get(task_id)
+        if not task:
+            return "Download not found"
+        if task.status in TERMINAL_STATUSES:
+            return f"Cannot cancel download in state '{task.status.value}'"
+        task._stop_reason = "cancel"
+        if task._worker and not task._worker.done():
+            task._worker.cancel()
+        else:
+            task.status = DownloadStatus.CANCELLED
+            self._cleanup_files(task.temp_path, task.meta_path)
+        return None
+
+    # -- Query --
+
+    def get_all_tasks(self, client_id: Optional[str] = None) -> list[dict]:
+        """Return every task as a dict, limited to ``client_id`` when one is given."""
+        tasks = self.tasks.values()
+        if client_id is not None:
+            tasks = [t for t in tasks if t.client_id == client_id]
+        return [t.to_dict() for t in tasks]
+
+    def get_task(self, task_id: str) -> Optional[dict]:
+        """Return one task as a dict, or ``None`` if the id is unknown."""
+        task = self.tasks.get(task_id)
+        return task.to_dict() if task else None
diff --git a/comfy/cli_args.py b/comfy/cli_args.py
index 13612175e..a70f0eaba 100644
--- a/comfy/cli_args.py
+++ b/comfy/cli_args.py
@@ -224,6 +224,8 @@ parser.add_argument("--user-directory", type=is_valid_directory, default=None, h
 
 parser.add_argument("--enable-compress-response-body", action="store_true", help="Enable compressing response body.")
 
+parser.add_argument("--enable-download-api", action="store_true", help="Enable the model download API. When set, ComfyUI exposes endpoints that allow downloading model files directly into the models directory. 
Only HTTPS downloads from allowed hosts (huggingface.co, civitai.com) are permitted.")
+
 parser.add_argument(
     "--comfy-api-base",
     type=str,
diff --git a/comfy_api/feature_flags.py b/comfy_api/feature_flags.py
index 9f6918315..cca94ca4b 100644
--- a/comfy_api/feature_flags.py
+++ b/comfy_api/feature_flags.py
@@ -16,6 +16,7 @@ SERVER_FEATURE_FLAGS: dict[str, Any] = {
     "extension": {"manager": {"supports_v4": True}},
     "node_replacements": True,
     "assets": args.enable_assets,
+    "download_api": args.enable_download_api,
 }
 
 
diff --git a/server.py b/server.py
index 173a28376..c1201c688 100644
--- a/server.py
+++ b/server.py
@@ -43,6 +43,7 @@ from app.model_manager import ModelFileManager
 from app.custom_node_manager import CustomNodeManager
 from app.subgraph_manager import SubgraphManager
 from app.node_replace_manager import NodeReplaceManager
+from app.download_manager import DownloadManager
 from typing import Optional, Union
 from api_server.routes.internal.internal_routes import InternalRoutes
 from protocol import BinaryEventTypes
@@ -205,6 +206,7 @@ class PromptServer():
         self.subgraph_manager = SubgraphManager()
         self.node_replace_manager = NodeReplaceManager()
         self.internal_routes = InternalRoutes(self)
+        self.download_manager = DownloadManager(self) if args.enable_download_api else None
         self.supports = ["custom_nodes_from_web"]
         self.prompt_queue = execution.PromptQueue(self)
         self.loop = loop
@@ -1028,9 +1030,91 @@ class PromptServer():
 
             return web.Response(status=200)
 
+        # -- Download API (gated behind --enable-download-api) --
+
+        def _require_download_api(handler):
+            """Wrap a handler so it answers 403 when the download API is disabled."""
+            async def wrapper(request):
+                if self.download_manager is None:
+                    return web.json_response(
+                        {"error": "Download API is not enabled. Start ComfyUI with --enable-download-api."},
+                        status=403,
+                    )
+                return await handler(request)
+            return wrapper
+
+        @routes.post("/download/model")
+        @_require_download_api
+        async def post_download_model(request):
+            """Start a download from a JSON body with url, directory, filename and optional client_id.
+
+            Responds 201 with the task on success, 400 for a malformed or
+            rejected request, and 409 when the file or download already exists.
+            """
+            # A body that is not valid JSON is a client error, not a server fault.
+            try:
+                json_data = await request.json()
+            except ValueError:
+                return web.json_response({"error": "Request body must be valid JSON"}, status=400)
+            if not isinstance(json_data, dict):
+                return web.json_response({"error": "Request body must be a JSON object"}, status=400)
+
+            url = json_data.get("url")
+            directory = json_data.get("directory")
+            filename = json_data.get("filename")
+            client_id = json_data.get("client_id")
+
+            if not url or not directory or not filename:
+                return web.json_response(
+                    {"error": "Missing required fields: url, directory, filename"},
+                    status=400,
+                )
+
+            # The validators expect strings; reject other JSON types up front.
+            if not all(isinstance(value, str) for value in (url, directory, filename)):
+                return web.json_response(
+                    {"error": "Fields url, directory and filename must be strings"},
+                    status=400,
+                )
+
+            task, err = await self.download_manager.start_download(url, directory, filename, client_id=client_id)
+            if err:
+                # "File already exists" and "Download already in progress" are conflicts.
+                status = 409 if "already" in err.lower() else 400
+                return web.json_response({"error": err}, status=status)
+
+            return web.json_response(task.to_dict(), status=201)
+
+        @routes.get("/download/status")
+        @_require_download_api
+        async def get_download_status(request):
+            """List downloads, filtered by the optional client_id query parameter."""
+            client_id = request.rel_url.query.get("client_id")
+            return web.json_response(self.download_manager.get_all_tasks(client_id=client_id))
+
+        @routes.get("/download/status/{task_id}")
+        @_require_download_api
+        async def get_download_task_status(request):
+            """Return one download, or 404 if the id is unknown."""
+            task_id = request.match_info["task_id"]
+            task_data = self.download_manager.get_task(task_id)
+            if task_data is None:
+                return web.json_response({"error": "Download not found"}, status=404)
+            return web.json_response(task_data)
+
+        @routes.post("/download/pause/{task_id}")
+        @_require_download_api
+        async def post_download_pause(request):
+            """Pause a download; 400 with the reason if it cannot be paused."""
+            task_id = request.match_info["task_id"]
+            err = self.download_manager.pause_download(task_id)
+            if err:
+                return web.json_response({"error": err}, status=400)
+            return web.json_response({"status": "paused"})
+
+        @routes.post("/download/resume/{task_id}")
+        @_require_download_api
+        async def post_download_resume(request):
+            """Resume a paused download; 400 with the reason if it cannot be resumed."""
+            task_id = request.match_info["task_id"]
+            err = self.download_manager.resume_download(task_id)
+            if err:
+                return web.json_response({"error": err}, status=400)
+            return web.json_response({"status": "resumed"})
+
+        @routes.post("/download/cancel/{task_id}")
+        @_require_download_api
+        async def post_download_cancel(request):
+            """Cancel a download; 400 with the reason if it cannot be cancelled."""
+            task_id = request.match_info["task_id"]
+            err = self.download_manager.cancel_download(task_id)
+            if err:
+                return web.json_response({"error": err}, status=400)
+            return web.json_response({"status": "cancelled"})
+
     async def setup(self):
         timeout = aiohttp.ClientTimeout(total=None) # no timeout
         self.client_session = aiohttp.ClientSession(timeout=timeout)
+        # Stop download workers and close their session on app cleanup.
+        if self.download_manager is not None:
+            self.app.on_cleanup.append(self._cleanup_download_manager)
+
+    async def _cleanup_download_manager(self, app):
+        """aiohttp cleanup hook: shut the download manager down."""
+        if self.download_manager is not None:
+            await self.download_manager.close()
 
     def add_routes(self):
         self.user_manager.add_routes(self.routes)