mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-06 16:56:44 +08:00
Add a new server-side download API that allows frontends and desktop apps
to download models directly into ComfyUI's models directory, eliminating
the need for DOM scraping of the frontend UI.
New files:
- app/download_manager.py: Async download manager with streaming downloads,
pause/resume/cancel, manual redirect following with per-hop host validation,
sidecar metadata for safe resume, and concurrency limiting.
API endpoints (all under /download/, also mirrored at /api/download/):
- POST /download/model - Start a download (url, directory, filename)
- GET /download/status - List all downloads (filterable by client_id)
- GET /download/status/{id} - Get single download status
- POST /download/pause/{id} - Pause (cancels transfer, keeps temp)
- POST /download/resume/{id} - Resume (new request with Range header)
- POST /download/cancel/{id} - Cancel and clean up temp files
Security:
- Gated behind --enable-download-api CLI flag (403 if disabled)
- HTTPS-only with exact host allowlist (huggingface.co, civitai.com + CDNs)
- Manual redirect following with per-hop host validation (no SSRF)
- Path traversal protection via realpath + commonpath
- Extension allowlist (.safetensors, .sft)
- Filename sanitization (no separators, .., control chars)
- Destination re-checked before final rename
- Progress events scoped to initiating client_id
Closes Comfy-Org/ComfyUI-Desktop-2.0-Beta#293
Amp-Thread-ID: https://ampcode.com/threads/T-019d2344-139e-77a5-9f24-1cbb3b26a8ec
Co-authored-by: Amp <amp@ampcode.com>
508 lines
18 KiB
Python
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import json
|
|
import logging
|
|
import os
|
|
import time
|
|
import uuid
|
|
from dataclasses import dataclass, field
|
|
from enum import Enum
|
|
from typing import Optional, TYPE_CHECKING
|
|
from urllib.parse import urlsplit
|
|
|
|
import aiohttp
|
|
from yarl import URL
|
|
|
|
import folder_paths
|
|
|
|
if TYPE_CHECKING:
|
|
from server import PromptServer
|
|
|
|
logger = logging.getLogger(__name__)

# Exact-match allowlist of hosts that downloads may contact.  Includes the
# HuggingFace LFS CDN hostnames that huggingface.co redirects to; every
# redirect hop is re-validated against this set (SSRF defense).
ALLOWED_HTTPS_HOSTS = frozenset({
    "huggingface.co",
    "cdn-lfs.huggingface.co",
    "cdn-lfs-us-1.huggingface.co",
    "cdn-lfs-eu-1.huggingface.co",
    "civitai.com",
    "api.civitai.com",
})

# Only safetensors-style model files may be downloaded.
ALLOWED_EXTENSIONS = frozenset({".safetensors", ".sft"})

# Limits: simultaneous transfers, retained finished task records, and
# redirect hops followed per request.
MAX_CONCURRENT_DOWNLOADS = 3
MAX_TERMINAL_TASKS = 50
MAX_REDIRECTS = 10

# Suffixes for the in-progress temp file and its sidecar metadata file,
# both placed next to the final destination path.
DOWNLOAD_TEMP_SUFFIX = ".download_tmp"
DOWNLOAD_META_SUFFIX = ".download_meta"
|
|
class DownloadStatus(str, Enum):
    """Lifecycle states of a download task (str-valued for JSON payloads)."""

    PENDING = "pending"          # queued; worker not yet streaming
    DOWNLOADING = "downloading"  # actively transferring bytes
    PAUSED = "paused"            # transfer cancelled, temp file kept for resume
    COMPLETED = "completed"      # file committed to its final path
    ERROR = "error"              # failed; see DownloadTask.error
    CANCELLED = "cancelled"      # cancelled; temp files removed
|
|
# States in which a task still "owns" its destination (used to reject
# duplicate downloads for the same URL or save path).
ACTIVE_STATUSES = frozenset({
    DownloadStatus.PENDING,
    DownloadStatus.DOWNLOADING,
    DownloadStatus.PAUSED,
})

# Finished states; tasks in these states are eligible for pruning.
TERMINAL_STATUSES = frozenset({
    DownloadStatus.COMPLETED,
    DownloadStatus.ERROR,
    DownloadStatus.CANCELLED,
})
|
|
@dataclass
class DownloadTask:
    """Mutable state of a single download; to_dict() is its wire format."""

    id: str  # short unique identifier (uuid4 hex prefix)
    url: str  # original (pre-redirect) source URL
    filename: str  # sanitized destination filename
    directory: str  # folder_paths model-directory key (e.g. "checkpoints")
    save_path: str  # final destination path on disk
    temp_path: str  # in-progress temp file (save_path + DOWNLOAD_TEMP_SUFFIX)
    meta_path: str  # sidecar metadata file used to validate resumes
    status: DownloadStatus = DownloadStatus.PENDING
    progress: float = 0.0  # 0.0..1.0; stays 0 when total size is unknown
    received_bytes: int = 0
    total_bytes: int = 0  # 0 when the server did not report a length
    speed_bytes_per_sec: float = 0.0
    eta_seconds: float = 0.0
    error: Optional[str] = None  # populated when status == ERROR
    created_at: float = field(default_factory=time.time)
    client_id: Optional[str] = None  # progress events are scoped to this client
    # Internal: the asyncio task running the transfer, and why it was stopped
    # ("pause" or "cancel") so the CancelledError handler can tell them apart.
    _worker: Optional[asyncio.Task] = field(default=None, repr=False)
    _stop_reason: Optional[str] = field(default=None, repr=False)

    def to_dict(self) -> dict:
        """Return the JSON-safe public view of this task (no private fields)."""
        return {
            "id": self.id,
            "url": self.url,
            "filename": self.filename,
            "directory": self.directory,
            "status": self.status.value,
            "progress": self.progress,
            "received_bytes": self.received_bytes,
            "total_bytes": self.total_bytes,
            "speed_bytes_per_sec": self.speed_bytes_per_sec,
            "eta_seconds": self.eta_seconds,
            "error": self.error,
            "created_at": self.created_at,
        }
|
|
class DownloadManager:
    """Async manager for server-side model downloads.

    Responsibilities:
      * Validate URL (HTTPS + exact host allowlist), filename (extension
        allowlist, no separators / '..' / control chars) and destination
        directory before any network I/O.
      * Stream each download to ``<dest>.download_tmp`` and commit it with an
        atomic rename; a ``<dest>.download_meta`` sidecar records the source
        URL so a partial file is only resumed for that same URL.
      * Follow redirects manually so every hop is re-validated against the
        host allowlist (SSRF defense).
      * Limit concurrency with a semaphore and publish progress events only
        to the initiating client.
    """

    def __init__(self, server: PromptServer):
        self.server = server
        # All known tasks (active and terminal), keyed by task id.
        self.tasks: dict[str, DownloadTask] = {}
        self._session: Optional[aiohttp.ClientSession] = None
        # Caps the number of simultaneously-running download workers.
        self._semaphore = asyncio.Semaphore(MAX_CONCURRENT_DOWNLOADS)

    async def _get_session(self) -> aiohttp.ClientSession:
        """Return the shared aiohttp session, (re)creating it lazily."""
        if self._session is None or self._session.closed:
            # total=None: large model files may take arbitrarily long.
            timeout = aiohttp.ClientTimeout(total=None, connect=30, sock_read=60)
            self._session = aiohttp.ClientSession(timeout=timeout)
        return self._session

    async def close(self):
        """Cancel every running worker, wait for them, and close the session."""
        workers = [t._worker for t in self.tasks.values() if t._worker and not t._worker.done()]
        for w in workers:
            w.cancel()
        if workers:
            await asyncio.gather(*workers, return_exceptions=True)
        if self._session and not self._session.closed:
            await self._session.close()

    # -- Validation --

    @staticmethod
    def _validate_url(url: str) -> Optional[str]:
        """Return an error message if ``url`` is not allowed, else None."""
        try:
            parts = urlsplit(url)
        except Exception:
            return "Invalid URL"

        if parts.username or parts.password:
            return "Credentials in URL are not allowed"

        host = (parts.hostname or "").lower()
        scheme = parts.scheme.lower()

        if scheme != "https":
            return "Only HTTPS URLs are allowed"

        if host not in ALLOWED_HTTPS_HOSTS:
            return f"Host '{host}' is not in the allowed list"

        # Only the default HTTPS port; e.g. https://huggingface.co:8443/ fails.
        if parts.port not in (None, 443):
            return "Custom ports are not allowed for remote downloads"

        return None

    @staticmethod
    def _validate_filename(filename: str) -> Optional[str]:
        """Return an error message if ``filename`` is unsafe, else None."""
        if not filename:
            return "Filename must not be empty"
        ext = os.path.splitext(filename)[1].lower()
        if ext not in ALLOWED_EXTENSIONS:
            return f"File extension '{ext}' not allowed. Allowed: {', '.join(sorted(ALLOWED_EXTENSIONS))}"
        if os.path.sep in filename or (os.path.altsep and os.path.altsep in filename):
            return "Filename must not contain path separators"
        if ".." in filename:
            return "Filename must not contain '..'"
        for ch in filename:
            if ord(ch) < 32:
                return "Filename must not contain control characters"
        return None

    @staticmethod
    def _validate_directory(directory: str) -> Optional[str]:
        """Return an error message if ``directory`` is not a known model folder."""
        if directory not in folder_paths.folder_names_and_paths:
            valid = ', '.join(sorted(folder_paths.folder_names_and_paths.keys()))
            return f"Unknown model directory '{directory}'. Valid directories: {valid}"
        return None

    @staticmethod
    def _resolve_save_path(directory: str, filename: str) -> tuple[str, str, str]:
        """Returns (save_path, temp_path, meta_path) for a download.

        Raises ValueError if the resolved destination escapes the model
        directory (path traversal / symlink defense via realpath+commonpath).
        """
        paths = folder_paths.folder_names_and_paths[directory][0]
        base_dir = paths[0]  # first registered path is the canonical target
        os.makedirs(base_dir, exist_ok=True)

        save_path = os.path.join(base_dir, filename)
        temp_path = save_path + DOWNLOAD_TEMP_SUFFIX
        meta_path = save_path + DOWNLOAD_META_SUFFIX

        real_save = os.path.realpath(save_path)
        real_base = os.path.realpath(base_dir)
        if os.path.commonpath([real_save, real_base]) != real_base:
            raise ValueError("Resolved path escapes the model directory")

        return save_path, temp_path, meta_path

    # -- Sidecar metadata for resume validation --

    @staticmethod
    def _write_meta(meta_path: str, url: str, task_id: str):
        """Best-effort write of the resume-validation sidecar file."""
        try:
            with open(meta_path, "w") as f:
                json.dump({"url": url, "task_id": task_id}, f)
        except OSError:
            # Non-fatal: without a sidecar the partial file simply will not
            # be resumed later.
            pass

    @staticmethod
    def _read_meta(meta_path: str) -> Optional[dict]:
        """Read the sidecar file; None if missing or unparseable."""
        try:
            with open(meta_path, "r") as f:
                return json.load(f)
        except (OSError, json.JSONDecodeError):
            return None

    @staticmethod
    def _cleanup_files(*paths: str):
        """Best-effort removal of the given files (missing files are ignored)."""
        for p in paths:
            try:
                if os.path.exists(p):
                    os.remove(p)
            except OSError:
                pass

    # -- Task management --

    def _prune_terminal_tasks(self):
        """Drop the oldest finished tasks beyond MAX_TERMINAL_TASKS."""
        terminal = [
            (tid, t) for tid, t in self.tasks.items()
            if t.status in TERMINAL_STATUSES
        ]
        if len(terminal) > MAX_TERMINAL_TASKS:
            terminal.sort(key=lambda x: x[1].created_at)  # oldest first
            to_remove = len(terminal) - MAX_TERMINAL_TASKS
            for tid, _ in terminal[:to_remove]:
                del self.tasks[tid]

    async def start_download(
        self, url: str, directory: str, filename: str, client_id: Optional[str] = None
    ) -> tuple[Optional[DownloadTask], Optional[str]]:
        """Validate inputs and spawn a download worker task.

        Returns (task, None) on success, or (None, error_message) on any
        validation or duplicate-download failure.
        """
        err = self._validate_url(url)
        if err:
            return None, err

        err = self._validate_filename(filename)
        if err:
            return None, err

        err = self._validate_directory(directory)
        if err:
            return None, err

        try:
            save_path, temp_path, meta_path = self._resolve_save_path(directory, filename)
        except ValueError as e:
            return None, str(e)

        if os.path.exists(save_path):
            return None, f"File already exists: {directory}/{filename}"

        # Reject duplicate active download by URL.
        for task in self.tasks.values():
            if task.url == url and task.status in ACTIVE_STATUSES:
                return None, f"Download already in progress for this URL (id: {task.id})"

        # Reject duplicate active download by destination path.
        for task in self.tasks.values():
            if task.save_path == save_path and task.status in ACTIVE_STATUSES:
                return None, f"Download already in progress for {directory}/{filename} (id: {task.id})"

        # Clean stale temp/meta files when no active task owns them AND they
        # were written for a different URL (same-URL leftovers are kept so
        # the new task can resume them).
        existing_meta = self._read_meta(meta_path)
        if existing_meta:
            owning_task = self.tasks.get(existing_meta.get("task_id", ""))
            if not owning_task or owning_task.status in TERMINAL_STATUSES:
                if existing_meta.get("url") != url:
                    self._cleanup_files(temp_path, meta_path)

        task = DownloadTask(
            id=uuid.uuid4().hex[:12],
            url=url,
            filename=filename,
            directory=directory,
            save_path=save_path,
            temp_path=temp_path,
            meta_path=meta_path,
            client_id=client_id,
        )
        self.tasks[task.id] = task
        self._prune_terminal_tasks()

        task._worker = asyncio.create_task(self._run_download(task))
        return task, None

    # -- Redirect-safe fetch --

    async def _fetch_with_validated_redirects(
        self, session: aiohttp.ClientSession, url: str, headers: dict
    ) -> aiohttp.ClientResponse:
        """Follow redirects manually, validating each hop against the allowlist.

        Raises ValueError on a missing Location header, a non-HTTPS or
        disallowed redirect target, or too many hops.
        """
        current_url = url
        for _ in range(MAX_REDIRECTS + 1):
            resp = await session.get(current_url, headers=headers, allow_redirects=False)
            if resp.status not in (301, 302, 303, 307, 308):
                return resp

            location = resp.headers.get("Location")
            await resp.release()
            if not location:
                raise ValueError("Redirect without Location header")

            # Resolve relative redirects against the current URL.
            resolved = URL(current_url).join(URL(location))
            current_url = str(resolved)

            # Validate the redirect target host (per-hop SSRF defense).
            parts = urlsplit(current_url)
            host = (parts.hostname or "").lower()
            scheme = parts.scheme.lower()

            if scheme != "https":
                raise ValueError(f"Redirect to non-HTTPS URL: {current_url}")
            if host not in ALLOWED_HTTPS_HOSTS:
                raise ValueError(f"Redirect to disallowed host: {host}")

            # 303 switches the method to a plain GET - drop any Range header
            # so we do not request a partial body from the new target.
            if resp.status == 303:
                headers = {k: v for k, v in headers.items() if k.lower() != "range"}

        raise ValueError(f"Too many redirects (>{MAX_REDIRECTS})")

    # -- Download worker --

    async def _run_download(self, task: DownloadTask):
        """Worker wrapper: maps cancellation/pause/errors onto task state."""
        try:
            async with self._semaphore:
                await self._run_download_inner(task)
        except asyncio.CancelledError:
            if task._stop_reason == "pause":
                # Keep temp/meta files so the download can be resumed later.
                task.status = DownloadStatus.PAUSED
                task.speed_bytes_per_sec = 0
                task.eta_seconds = 0
                await self._send_progress(task)
            else:
                task.status = DownloadStatus.CANCELLED
                await self._send_progress(task)
                self._cleanup_files(task.temp_path, task.meta_path)
        except Exception as e:
            task.status = DownloadStatus.ERROR
            task.error = str(e)
            await self._send_progress(task)
            logger.exception("Download error for %s", task.url)

    async def _run_download_inner(self, task: DownloadTask):
        """Perform the actual transfer for ``task``, resuming a partial file
        when a matching sidecar exists."""
        session = await self._get_session()
        headers = {}

        # Resume support: only reuse a partial temp file when the sidecar
        # confirms it was written for this exact URL.
        if os.path.exists(task.temp_path):
            meta = self._read_meta(task.meta_path)
            if meta and meta.get("url") == task.url:
                existing_size = os.path.getsize(task.temp_path)
                if existing_size > 0:
                    headers["Range"] = f"bytes={existing_size}-"
                    task.received_bytes = existing_size
            else:
                self._cleanup_files(task.temp_path, task.meta_path)

        self._write_meta(task.meta_path, task.url, task.id)
        task.status = DownloadStatus.DOWNLOADING
        await self._send_progress(task)

        resp = await self._fetch_with_validated_redirects(session, task.url, headers)
        try:
            if resp.status == 416:
                # Range not satisfiable: if our partial file already covers
                # the full length reported by Content-Range, the download is
                # effectively complete.
                content_range = resp.headers.get("Content-Range", "")
                if content_range:
                    total_str = content_range.split("/")[-1]
                    if total_str != "*":
                        total = int(total_str)
                        if task.received_bytes >= total:
                            # NOTE(review): if save_path already exists the
                            # temp file is left behind here - verify whether
                            # cleanup is wanted in that corner case.
                            if not os.path.exists(task.save_path):
                                os.rename(task.temp_path, task.save_path)
                            self._cleanup_files(task.meta_path)
                            task.status = DownloadStatus.COMPLETED
                            task.progress = 1.0
                            task.total_bytes = total
                            await self._send_progress(task)
                            return
                raise ValueError("HTTP 416 Range Not Satisfiable")

            if resp.status not in (200, 206):
                task.status = DownloadStatus.ERROR
                task.error = f"HTTP {resp.status}"
                await self._send_progress(task)
                return

            if resp.status == 200:
                # Server ignored our Range request - restart from scratch.
                task.received_bytes = 0

            content_length = resp.content_length
            if resp.status == 206 and content_length:
                task.total_bytes = task.received_bytes + content_length
            elif resp.status == 200 and content_length:
                task.total_bytes = content_length

            # 206 appends to the partial file; 200 truncates and rewrites.
            mode = "ab" if resp.status == 206 else "wb"
            speed_window_start = time.monotonic()
            speed_window_bytes = 0
            last_progress_time = 0.0

            with open(task.temp_path, mode) as f:
                async for chunk in resp.content.iter_chunked(1024 * 64):
                    f.write(chunk)
                    task.received_bytes += len(chunk)
                    speed_window_bytes += len(chunk)

                    now = time.monotonic()
                    elapsed = now - speed_window_start
                    if elapsed > 0.5:
                        # Speed/ETA computed over a sliding ~0.5s window.
                        task.speed_bytes_per_sec = speed_window_bytes / elapsed
                        if task.total_bytes > 0 and task.speed_bytes_per_sec > 0:
                            remaining = task.total_bytes - task.received_bytes
                            task.eta_seconds = remaining / task.speed_bytes_per_sec
                        speed_window_start = now
                        speed_window_bytes = 0

                    if task.total_bytes > 0:
                        task.progress = task.received_bytes / task.total_bytes

                    # Throttle progress events to roughly 4 per second.
                    if now - last_progress_time >= 0.25:
                        await self._send_progress(task)
                        last_progress_time = now
        finally:
            resp.release()

        # Final pause/cancel check before committing.
        if task._stop_reason is not None:
            raise asyncio.CancelledError()

        # Re-check the destination: a file may have appeared while streaming.
        if os.path.exists(task.save_path):
            task.status = DownloadStatus.ERROR
            task.error = f"Destination file appeared during download: {task.directory}/{task.filename}"
            await self._send_progress(task)
            return

        os.replace(task.temp_path, task.save_path)  # atomic commit
        self._cleanup_files(task.meta_path)
        task.status = DownloadStatus.COMPLETED
        task.progress = 1.0
        task.speed_bytes_per_sec = 0
        task.eta_seconds = 0
        await self._send_progress(task)
        logger.info("Download complete: %s/%s", task.directory, task.filename)

    # -- Progress --

    async def _send_progress(self, task: DownloadTask):
        """Send a progress event to the initiating client; never raises."""
        try:
            self.server.send_sync("download_progress", task.to_dict(), task.client_id)
        except Exception:
            logger.exception("Failed to send download progress event")

    # -- Control operations --

    def pause_download(self, task_id: str) -> Optional[str]:
        """Pause an active download. Returns an error message or None."""
        task = self.tasks.get(task_id)
        if not task:
            return "Download not found"
        if task.status not in (DownloadStatus.PENDING, DownloadStatus.DOWNLOADING):
            return f"Cannot pause download in state '{task.status.value}'"
        task._stop_reason = "pause"
        if task._worker and not task._worker.done():
            task._worker.cancel()
        return None

    def resume_download(self, task_id: str) -> Optional[str]:
        """Resume a paused download. Returns an error message or None."""
        task = self.tasks.get(task_id)
        if not task:
            return "Download not found"
        if task.status != DownloadStatus.PAUSED:
            return f"Cannot resume download in state '{task.status.value}'"
        task._stop_reason = None
        task.status = DownloadStatus.PENDING
        task._worker = asyncio.create_task(self._run_download(task))
        return None

    def cancel_download(self, task_id: str) -> Optional[str]:
        """Cancel a non-terminal download and delete its temp files."""
        task = self.tasks.get(task_id)
        if not task:
            return "Download not found"
        if task.status in TERMINAL_STATUSES:
            return f"Cannot cancel download in state '{task.status.value}'"
        task._stop_reason = "cancel"
        if task._worker and not task._worker.done():
            # The worker's CancelledError handler performs the cleanup.
            task._worker.cancel()
        else:
            # No running worker (e.g. paused): finalize here.
            task.status = DownloadStatus.CANCELLED
            self._cleanup_files(task.temp_path, task.meta_path)
        return None

    # -- Query --

    def get_all_tasks(self, client_id: Optional[str] = None) -> list[dict]:
        """List all tasks, optionally filtered to one client's downloads."""
        tasks = self.tasks.values()
        if client_id is not None:
            tasks = [t for t in tasks if t.client_id == client_id]
        return [t.to_dict() for t in tasks]

    def get_task(self, task_id: str) -> Optional[dict]:
        """Return one task's public dict, or None if unknown."""
        task = self.tasks.get(task_id)
        return task.to_dict() if task else None