mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-14 04:22:31 +08:00
feat: add model download API gated behind --enable-download-api
Add a new server-side download API that allows frontends and desktop apps
to download models directly into ComfyUI's models directory, eliminating
the need for DOM scraping of the frontend UI.
New files:
- app/download_manager.py: Async download manager with streaming downloads,
pause/resume/cancel, manual redirect following with per-hop host validation,
sidecar metadata for safe resume, and concurrency limiting.
API endpoints (all under /download/, also mirrored at /api/download/):
- POST /download/model - Start a download (url, directory, filename)
- GET /download/status - List all downloads (filterable by client_id)
- GET /download/status/{id} - Get single download status
- POST /download/pause/{id} - Pause (cancels transfer, keeps temp)
- POST /download/resume/{id} - Resume (new request with Range header)
- POST /download/cancel/{id} - Cancel and clean up temp files
Security:
- Gated behind --enable-download-api CLI flag (403 if disabled)
- HTTPS-only with exact host allowlist (huggingface.co, civitai.com + CDNs)
- Manual redirect following with per-hop host validation (no SSRF)
- Path traversal protection via realpath + commonpath
- Extension allowlist (.safetensors, .sft)
- Filename sanitization (no separators, .., control chars)
- Destination re-checked before final rename
- Progress events scoped to initiating client_id
Closes Comfy-Org/ComfyUI-Desktop-2.0-Beta#293
Amp-Thread-ID: https://ampcode.com/threads/T-019d2344-139e-77a5-9f24-1cbb3b26a8ec
Co-authored-by: Amp <amp@ampcode.com>
This commit is contained in:
parent
b53b10ea61
commit
2f7b77f341
507
app/download_manager.py
Normal file
507
app/download_manager.py
Normal file
@ -0,0 +1,507 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Optional, TYPE_CHECKING
|
||||||
|
from urllib.parse import urlsplit
|
||||||
|
|
||||||
|
import aiohttp
|
||||||
|
from yarl import URL
|
||||||
|
|
||||||
|
import folder_paths
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from server import PromptServer
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Exact-match HTTPS host allowlist for the download API.  Covers the main
# HuggingFace / CivitAI hosts plus the cdn-lfs.* hosts that HuggingFace
# redirects large-file downloads to (redirect hops are re-validated against
# this same set).
ALLOWED_HTTPS_HOSTS = frozenset((
    "huggingface.co",
    "cdn-lfs.huggingface.co",
    "cdn-lfs-us-1.huggingface.co",
    "cdn-lfs-eu-1.huggingface.co",
    "civitai.com",
    "api.civitai.com",
))

# Only safetensors-style model files may be downloaded.
ALLOWED_EXTENSIONS = frozenset((".safetensors", ".sft"))

# At most this many transfers run at once (enforced via a semaphore).
MAX_CONCURRENT_DOWNLOADS = 3
# Finished (completed/error/cancelled) tasks kept in memory before pruning.
MAX_TERMINAL_TASKS = 50
# Redirect hops followed before giving up.
MAX_REDIRECTS = 10

# Suffixes appended to the destination path for the in-progress temp file
# and its sidecar metadata file.
DOWNLOAD_TEMP_SUFFIX = ".download_tmp"
DOWNLOAD_META_SUFFIX = ".download_meta"
|
||||||
|
class DownloadStatus(str, Enum):
    """Lifecycle states of a download task.

    Subclasses ``str`` so the values serialize directly in JSON API
    responses and progress events.
    """

    PENDING = "pending"          # queued; worker not transferring yet
    DOWNLOADING = "downloading"  # transfer in progress
    PAUSED = "paused"            # stopped by user; temp file kept for resume
    COMPLETED = "completed"
    ERROR = "error"
    CANCELLED = "cancelled"


# States treated as "in progress" for duplicate-download detection.
ACTIVE_STATUSES = frozenset((
    DownloadStatus.PENDING,
    DownloadStatus.DOWNLOADING,
    DownloadStatus.PAUSED,
))

# Final states: the task will never transfer again and may be pruned.
TERMINAL_STATUSES = frozenset((
    DownloadStatus.COMPLETED,
    DownloadStatus.ERROR,
    DownloadStatus.CANCELLED,
))
|
@dataclass
class DownloadTask:
    """State record for a single model download.

    Holds the public, JSON-serializable progress fields plus two private
    handles (the worker task and the requested stop reason) used by
    ``DownloadManager`` and excluded from serialization.
    """

    # Identity and request parameters.
    id: str
    url: str
    filename: str
    directory: str
    # Filesystem locations: final destination, in-progress temp file,
    # and the sidecar metadata file used to validate resumes.
    save_path: str
    temp_path: str
    meta_path: str
    # Live progress state, updated by the download worker.
    status: DownloadStatus = DownloadStatus.PENDING
    progress: float = 0.0
    received_bytes: int = 0
    total_bytes: int = 0
    speed_bytes_per_sec: float = 0.0
    eta_seconds: float = 0.0
    error: Optional[str] = None
    created_at: float = field(default_factory=time.time)
    client_id: Optional[str] = None
    # Internal handles; kept out of repr and never serialized.
    _worker: Optional[asyncio.Task] = field(default=None, repr=False)
    _stop_reason: Optional[str] = field(default=None, repr=False)

    def to_dict(self) -> dict:
        """Serialize the public fields for API responses and progress events."""
        out = {key: getattr(self, key) for key in ("id", "url", "filename", "directory")}
        # Enum -> plain string for JSON.
        out["status"] = self.status.value
        for key in (
            "progress",
            "received_bytes",
            "total_bytes",
            "speed_bytes_per_sec",
            "eta_seconds",
            "error",
            "created_at",
        ):
            out[key] = getattr(self, key)
        return out
|
class DownloadManager:
    """Async download manager backing the model download API.

    Streams files from an exact-match allowlist of HTTPS hosts into
    ComfyUI's model directories, with pause/resume/cancel support,
    manual redirect following (every hop re-validated against the
    allowlist), sidecar metadata so a partial temp file is only resumed
    for the same URL, and a semaphore bounding concurrent transfers.
    """

    def __init__(self, server: PromptServer):
        self.server = server
        # All known tasks (active and recently finished), keyed by task id.
        self.tasks: dict[str, DownloadTask] = {}
        self._session: Optional[aiohttp.ClientSession] = None
        # Bounds the number of simultaneously running transfers.
        self._semaphore = asyncio.Semaphore(MAX_CONCURRENT_DOWNLOADS)

    async def _get_session(self) -> aiohttp.ClientSession:
        """Return the shared HTTP session, (re)creating it if closed."""
        if self._session is None or self._session.closed:
            # No total timeout: large model files legitimately take a long
            # time; connect/read timeouts still catch dead transfers.
            timeout = aiohttp.ClientTimeout(total=None, connect=30, sock_read=60)
            self._session = aiohttp.ClientSession(timeout=timeout)
        return self._session

    async def close(self):
        """Cancel every running worker, then close the HTTP session."""
        workers = [t._worker for t in self.tasks.values() if t._worker and not t._worker.done()]
        for w in workers:
            w.cancel()
        if workers:
            # Wait for cancellation to settle; swallow the CancelledErrors.
            await asyncio.gather(*workers, return_exceptions=True)
        if self._session and not self._session.closed:
            await self._session.close()

    # -- Validation --

    @staticmethod
    def _validate_url(url: str) -> Optional[str]:
        """Return an error message if *url* is not an allowed HTTPS URL, else None."""
        try:
            parts = urlsplit(url)
        except Exception:
            return "Invalid URL"

        if parts.username or parts.password:
            return "Credentials in URL are not allowed"

        host = (parts.hostname or "").lower()
        scheme = parts.scheme.lower()

        if scheme != "https":
            return "Only HTTPS URLs are allowed"

        if host not in ALLOWED_HTTPS_HOSTS:
            return f"Host '{host}' is not in the allowed list"

        if parts.port not in (None, 443):
            return "Custom ports are not allowed for remote downloads"

        return None

    @staticmethod
    def _validate_filename(filename: str) -> Optional[str]:
        """Return an error message if *filename* is empty, unsafe, or has a
        disallowed extension; None if it is acceptable."""
        if not filename:
            return "Filename must not be empty"
        ext = os.path.splitext(filename)[1].lower()
        if ext not in ALLOWED_EXTENSIONS:
            return f"File extension '{ext}' not allowed. Allowed: {', '.join(sorted(ALLOWED_EXTENSIONS))}"
        if os.path.sep in filename or (os.path.altsep and os.path.altsep in filename):
            return "Filename must not contain path separators"
        if ".." in filename:
            return "Filename must not contain '..'"
        for ch in filename:
            if ord(ch) < 32:
                return "Filename must not contain control characters"
        return None

    @staticmethod
    def _validate_directory(directory: str) -> Optional[str]:
        """Return an error message if *directory* is not a known model folder, else None."""
        if directory not in folder_paths.folder_names_and_paths:
            valid = ', '.join(sorted(folder_paths.folder_names_and_paths.keys()))
            return f"Unknown model directory '{directory}'. Valid directories: {valid}"
        return None

    @staticmethod
    def _resolve_save_path(directory: str, filename: str) -> tuple[str, str, str]:
        """Returns (save_path, temp_path, meta_path) for a download.

        Raises:
            ValueError: if the realpath-resolved destination escapes the
                model directory (path traversal / symlink defense).
        """
        paths = folder_paths.folder_names_and_paths[directory][0]
        base_dir = paths[0]
        os.makedirs(base_dir, exist_ok=True)

        save_path = os.path.join(base_dir, filename)
        temp_path = save_path + DOWNLOAD_TEMP_SUFFIX
        meta_path = save_path + DOWNLOAD_META_SUFFIX

        # commonpath over realpaths catches both ".."-style traversal and
        # symlinks pointing outside the base directory.
        real_save = os.path.realpath(save_path)
        real_base = os.path.realpath(base_dir)
        if os.path.commonpath([real_save, real_base]) != real_base:
            raise ValueError("Resolved path escapes the model directory")

        return save_path, temp_path, meta_path

    # -- Sidecar metadata for resume validation --

    @staticmethod
    def _write_meta(meta_path: str, url: str, task_id: str):
        """Best-effort write of the sidecar file identifying a partial download."""
        try:
            with open(meta_path, "w") as f:
                json.dump({"url": url, "task_id": task_id}, f)
        except OSError:
            # Metadata is advisory; a failed write only disables resume.
            pass

    @staticmethod
    def _read_meta(meta_path: str) -> Optional[dict]:
        """Read the sidecar file; None if missing or unparseable."""
        try:
            with open(meta_path, "r") as f:
                return json.load(f)
        except (OSError, json.JSONDecodeError):
            return None

    @staticmethod
    def _cleanup_files(*paths: str):
        """Best-effort removal of the given files; missing files are ignored."""
        for p in paths:
            try:
                if os.path.exists(p):
                    os.remove(p)
            except OSError:
                pass

    # -- Task management --

    def _prune_terminal_tasks(self):
        """Drop the oldest finished tasks beyond MAX_TERMINAL_TASKS."""
        terminal = [
            (tid, t) for tid, t in self.tasks.items()
            if t.status in TERMINAL_STATUSES
        ]
        if len(terminal) > MAX_TERMINAL_TASKS:
            terminal.sort(key=lambda x: x[1].created_at)
            to_remove = len(terminal) - MAX_TERMINAL_TASKS
            for tid, _ in terminal[:to_remove]:
                del self.tasks[tid]

    async def start_download(
        self, url: str, directory: str, filename: str, client_id: Optional[str] = None
    ) -> tuple[Optional[DownloadTask], Optional[str]]:
        """Validate the request and spawn a download worker.

        Returns:
            (task, None) on success, or (None, error_message) on any
            validation or conflict failure.
        """
        err = self._validate_url(url)
        if err:
            return None, err

        err = self._validate_filename(filename)
        if err:
            return None, err

        err = self._validate_directory(directory)
        if err:
            return None, err

        try:
            save_path, temp_path, meta_path = self._resolve_save_path(directory, filename)
        except ValueError as e:
            return None, str(e)

        if os.path.exists(save_path):
            # FIX: report the actual target instead of a literal placeholder,
            # matching the "{directory}/{filename}" form used elsewhere.
            return None, f"File already exists: {directory}/{filename}"

        # Reject duplicate active download by URL
        for task in self.tasks.values():
            if task.url == url and task.status in ACTIVE_STATUSES:
                return None, f"Download already in progress for this URL (id: {task.id})"

        # Reject duplicate active download by destination path (#4)
        for task in self.tasks.values():
            if task.save_path == save_path and task.status in ACTIVE_STATUSES:
                return None, f"Download already in progress for {directory}/{filename} (id: {task.id})"

        # Clean stale temp/meta if no active task owns them (#9)
        existing_meta = self._read_meta(meta_path)
        if existing_meta:
            owning_task = self.tasks.get(existing_meta.get("task_id", ""))
            if not owning_task or owning_task.status in TERMINAL_STATUSES:
                if existing_meta.get("url") != url:
                    self._cleanup_files(temp_path, meta_path)

        task = DownloadTask(
            id=uuid.uuid4().hex[:12],
            url=url,
            filename=filename,
            directory=directory,
            save_path=save_path,
            temp_path=temp_path,
            meta_path=meta_path,
            client_id=client_id,
        )
        self.tasks[task.id] = task
        self._prune_terminal_tasks()

        task._worker = asyncio.create_task(self._run_download(task))
        return task, None

    # -- Redirect-safe fetch (#1, #2, #3) --

    async def _fetch_with_validated_redirects(
        self, session: aiohttp.ClientSession, url: str, headers: dict
    ) -> aiohttp.ClientResponse:
        """Follow redirects manually, validating each hop against the allowlist."""
        current_url = url
        for _ in range(MAX_REDIRECTS + 1):
            resp = await session.get(current_url, headers=headers, allow_redirects=False)
            if resp.status not in (301, 302, 303, 307, 308):
                return resp

            location = resp.headers.get("Location")
            await resp.release()
            if not location:
                raise ValueError("Redirect without Location header")

            # Resolve relative Location values against the current URL.
            resolved = URL(current_url).join(URL(location))
            current_url = str(resolved)

            # Validate the redirect target host
            parts = urlsplit(current_url)
            host = (parts.hostname or "").lower()
            scheme = parts.scheme.lower()

            if scheme != "https":
                raise ValueError(f"Redirect to non-HTTPS URL: {current_url}")
            if host not in ALLOWED_HTTPS_HOSTS:
                # Allow CDN hosts that HuggingFace/CivitAI commonly redirect to
                raise ValueError(f"Redirect to disallowed host: {host}")

            # 303 means GET with no Range
            if resp.status == 303:
                headers = {k: v for k, v in headers.items() if k.lower() != "range"}

        raise ValueError(f"Too many redirects (>{MAX_REDIRECTS})")

    # -- Download worker --

    async def _run_download(self, task: DownloadTask):
        """Top-level worker: run the transfer and map outcomes to task state."""
        try:
            async with self._semaphore:
                await self._run_download_inner(task)
        except asyncio.CancelledError:
            # Cancellation is used for both pause and cancel; _stop_reason
            # records which one the user actually requested.
            if task._stop_reason == "pause":
                task.status = DownloadStatus.PAUSED
                task.speed_bytes_per_sec = 0
                task.eta_seconds = 0
                await self._send_progress(task)
            else:
                task.status = DownloadStatus.CANCELLED
                await self._send_progress(task)
                self._cleanup_files(task.temp_path, task.meta_path)
        except Exception as e:
            task.status = DownloadStatus.ERROR
            task.error = str(e)
            await self._send_progress(task)
            logger.exception("Download error for %s", task.url)

    async def _run_download_inner(self, task: DownloadTask):
        """Perform the HTTP transfer for *task*, including resume handling."""
        session = await self._get_session()
        headers = {}

        # Resume support with sidecar validation (#9)
        if os.path.exists(task.temp_path):
            meta = self._read_meta(task.meta_path)
            if meta and meta.get("url") == task.url:
                existing_size = os.path.getsize(task.temp_path)
                if existing_size > 0:
                    headers["Range"] = f"bytes={existing_size}-"
                    task.received_bytes = existing_size
            else:
                # Temp file belongs to a different URL; start over.
                self._cleanup_files(task.temp_path, task.meta_path)

        self._write_meta(task.meta_path, task.url, task.id)
        task.status = DownloadStatus.DOWNLOADING
        await self._send_progress(task)

        resp = await self._fetch_with_validated_redirects(session, task.url, headers)
        try:
            if resp.status == 416:
                # Range not satisfiable. If the temp file already holds the
                # whole file (per Content-Range total), finalize it; otherwise
                # surface the error.
                content_range = resp.headers.get("Content-Range", "")
                if content_range:
                    total_str = content_range.split("/")[-1]
                    if total_str != "*":
                        total = int(total_str)
                        if task.received_bytes >= total:
                            if not os.path.exists(task.save_path):
                                os.rename(task.temp_path, task.save_path)
                            self._cleanup_files(task.meta_path)
                            task.status = DownloadStatus.COMPLETED
                            task.progress = 1.0
                            task.total_bytes = total
                            await self._send_progress(task)
                            return
                # FIX: was an f-string with no placeholders.
                raise ValueError("HTTP 416 Range Not Satisfiable")

            if resp.status not in (200, 206):
                task.status = DownloadStatus.ERROR
                task.error = f"HTTP {resp.status}"
                await self._send_progress(task)
                return

            if resp.status == 200:
                # Server ignored our Range request; restart from scratch.
                task.received_bytes = 0

            content_length = resp.content_length
            if resp.status == 206 and content_length:
                task.total_bytes = task.received_bytes + content_length
            elif resp.status == 200 and content_length:
                task.total_bytes = content_length

            # Append when resuming a partial (206), truncate otherwise.
            mode = "ab" if resp.status == 206 else "wb"
            speed_window_start = time.monotonic()
            speed_window_bytes = 0
            last_progress_time = 0.0

            with open(task.temp_path, mode) as f:
                async for chunk in resp.content.iter_chunked(1024 * 64):
                    f.write(chunk)
                    task.received_bytes += len(chunk)
                    speed_window_bytes += len(chunk)

                    now = time.monotonic()
                    elapsed = now - speed_window_start
                    if elapsed > 0.5:
                        # Recompute speed/ETA over a ~0.5 s sliding window.
                        task.speed_bytes_per_sec = speed_window_bytes / elapsed
                        if task.total_bytes > 0 and task.speed_bytes_per_sec > 0:
                            remaining = task.total_bytes - task.received_bytes
                            task.eta_seconds = remaining / task.speed_bytes_per_sec
                        speed_window_start = now
                        speed_window_bytes = 0

                    if task.total_bytes > 0:
                        task.progress = task.received_bytes / task.total_bytes

                    # Throttle progress events to ~4 per second.
                    if now - last_progress_time >= 0.25:
                        await self._send_progress(task)
                        last_progress_time = now
        finally:
            resp.release()

        # Final cancel check before committing (#7)
        if task._stop_reason is not None:
            raise asyncio.CancelledError()

        # Re-check destination before finalizing (#10)
        if os.path.exists(task.save_path):
            task.status = DownloadStatus.ERROR
            task.error = f"Destination file appeared during download: {task.directory}/{task.filename}"
            await self._send_progress(task)
            return

        os.replace(task.temp_path, task.save_path)
        self._cleanup_files(task.meta_path)
        task.status = DownloadStatus.COMPLETED
        task.progress = 1.0
        task.speed_bytes_per_sec = 0
        task.eta_seconds = 0
        await self._send_progress(task)
        logger.info("Download complete: %s/%s", task.directory, task.filename)

    # -- Progress (#8, #14) --

    async def _send_progress(self, task: DownloadTask):
        """Emit a 'download_progress' event, scoped to the initiating client."""
        try:
            self.server.send_sync("download_progress", task.to_dict(), task.client_id)
        except Exception:
            logger.exception("Failed to send download progress event")

    # -- Control operations (#5, #6, #13) --

    def pause_download(self, task_id: str) -> Optional[str]:
        """Pause an active download; returns an error message or None."""
        task = self.tasks.get(task_id)
        if not task:
            return "Download not found"
        if task.status not in (DownloadStatus.PENDING, DownloadStatus.DOWNLOADING):
            return f"Cannot pause download in state '{task.status.value}'"
        task._stop_reason = "pause"
        if task._worker and not task._worker.done():
            task._worker.cancel()
        return None

    def resume_download(self, task_id: str) -> Optional[str]:
        """Resume a paused download; returns an error message or None."""
        task = self.tasks.get(task_id)
        if not task:
            return "Download not found"
        if task.status != DownloadStatus.PAUSED:
            return f"Cannot resume download in state '{task.status.value}'"
        task._stop_reason = None
        task.status = DownloadStatus.PENDING
        task._worker = asyncio.create_task(self._run_download(task))
        return None

    def cancel_download(self, task_id: str) -> Optional[str]:
        """Cancel a non-terminal download and clean up its temp files."""
        task = self.tasks.get(task_id)
        if not task:
            return "Download not found"
        if task.status in TERMINAL_STATUSES:
            return f"Cannot cancel download in state '{task.status.value}'"
        task._stop_reason = "cancel"
        if task._worker and not task._worker.done():
            # The worker's CancelledError handler updates state and cleans up.
            task._worker.cancel()
        else:
            task.status = DownloadStatus.CANCELLED
            self._cleanup_files(task.temp_path, task.meta_path)
        return None

    # -- Query --

    def get_all_tasks(self, client_id: Optional[str] = None) -> list[dict]:
        """List all tasks as dicts, optionally filtered by client_id."""
        tasks = self.tasks.values()
        if client_id is not None:
            tasks = [t for t in tasks if t.client_id == client_id]
        return [t.to_dict() for t in tasks]

    def get_task(self, task_id: str) -> Optional[dict]:
        """Return one task as a dict, or None if unknown."""
        task = self.tasks.get(task_id)
        return task.to_dict() if task else None
||||||
@ -224,6 +224,8 @@ parser.add_argument("--user-directory", type=is_valid_directory, default=None, h
|
|||||||
|
|
||||||
parser.add_argument("--enable-compress-response-body", action="store_true", help="Enable compressing response body.")
|
parser.add_argument("--enable-compress-response-body", action="store_true", help="Enable compressing response body.")
|
||||||
|
|
||||||
|
parser.add_argument("--enable-download-api", action="store_true", help="Enable the model download API. When set, ComfyUI exposes endpoints that allow downloading model files directly into the models directory. Only HTTPS downloads from allowed hosts (huggingface.co, civitai.com) are permitted.")
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--comfy-api-base",
|
"--comfy-api-base",
|
||||||
type=str,
|
type=str,
|
||||||
|
|||||||
@ -16,6 +16,7 @@ SERVER_FEATURE_FLAGS: dict[str, Any] = {
|
|||||||
"extension": {"manager": {"supports_v4": True}},
|
"extension": {"manager": {"supports_v4": True}},
|
||||||
"node_replacements": True,
|
"node_replacements": True,
|
||||||
"assets": args.enable_assets,
|
"assets": args.enable_assets,
|
||||||
|
"download_api": args.enable_download_api,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
84
server.py
84
server.py
@ -43,6 +43,7 @@ from app.model_manager import ModelFileManager
|
|||||||
from app.custom_node_manager import CustomNodeManager
|
from app.custom_node_manager import CustomNodeManager
|
||||||
from app.subgraph_manager import SubgraphManager
|
from app.subgraph_manager import SubgraphManager
|
||||||
from app.node_replace_manager import NodeReplaceManager
|
from app.node_replace_manager import NodeReplaceManager
|
||||||
|
from app.download_manager import DownloadManager
|
||||||
from typing import Optional, Union
|
from typing import Optional, Union
|
||||||
from api_server.routes.internal.internal_routes import InternalRoutes
|
from api_server.routes.internal.internal_routes import InternalRoutes
|
||||||
from protocol import BinaryEventTypes
|
from protocol import BinaryEventTypes
|
||||||
@ -205,6 +206,7 @@ class PromptServer():
|
|||||||
self.subgraph_manager = SubgraphManager()
|
self.subgraph_manager = SubgraphManager()
|
||||||
self.node_replace_manager = NodeReplaceManager()
|
self.node_replace_manager = NodeReplaceManager()
|
||||||
self.internal_routes = InternalRoutes(self)
|
self.internal_routes = InternalRoutes(self)
|
||||||
|
self.download_manager = DownloadManager(self) if args.enable_download_api else None
|
||||||
self.supports = ["custom_nodes_from_web"]
|
self.supports = ["custom_nodes_from_web"]
|
||||||
self.prompt_queue = execution.PromptQueue(self)
|
self.prompt_queue = execution.PromptQueue(self)
|
||||||
self.loop = loop
|
self.loop = loop
|
||||||
@ -1028,9 +1030,91 @@ class PromptServer():
|
|||||||
|
|
||||||
return web.Response(status=200)
|
return web.Response(status=200)
|
||||||
|
|
||||||
|
# -- Download API (gated behind --enable-download-api) --
|
||||||
|
|
||||||
|
def _require_download_api(handler):
|
||||||
|
async def wrapper(request):
|
||||||
|
if self.download_manager is None:
|
||||||
|
return web.json_response(
|
||||||
|
{"error": "Download API is not enabled. Start ComfyUI with --enable-download-api."},
|
||||||
|
status=403,
|
||||||
|
)
|
||||||
|
return await handler(request)
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
@routes.post("/download/model")
|
||||||
|
@_require_download_api
|
||||||
|
async def post_download_model(request):
|
||||||
|
json_data = await request.json()
|
||||||
|
url = json_data.get("url")
|
||||||
|
directory = json_data.get("directory")
|
||||||
|
filename = json_data.get("filename")
|
||||||
|
client_id = json_data.get("client_id")
|
||||||
|
|
||||||
|
if not url or not directory or not filename:
|
||||||
|
return web.json_response(
|
||||||
|
{"error": "Missing required fields: url, directory, filename"},
|
||||||
|
status=400,
|
||||||
|
)
|
||||||
|
|
||||||
|
task, err = await self.download_manager.start_download(url, directory, filename, client_id=client_id)
|
||||||
|
if err:
|
||||||
|
status = 409 if "already" in err.lower() else 400
|
||||||
|
return web.json_response({"error": err}, status=status)
|
||||||
|
|
||||||
|
return web.json_response(task.to_dict(), status=201)
|
||||||
|
|
||||||
|
@routes.get("/download/status")
|
||||||
|
@_require_download_api
|
||||||
|
async def get_download_status(request):
|
||||||
|
client_id = request.rel_url.query.get("client_id")
|
||||||
|
return web.json_response(self.download_manager.get_all_tasks(client_id=client_id))
|
||||||
|
|
||||||
|
@routes.get("/download/status/{task_id}")
|
||||||
|
@_require_download_api
|
||||||
|
async def get_download_task_status(request):
|
||||||
|
task_id = request.match_info["task_id"]
|
||||||
|
task_data = self.download_manager.get_task(task_id)
|
||||||
|
if task_data is None:
|
||||||
|
return web.json_response({"error": "Download not found"}, status=404)
|
||||||
|
return web.json_response(task_data)
|
||||||
|
|
||||||
|
@routes.post("/download/pause/{task_id}")
|
||||||
|
@_require_download_api
|
||||||
|
async def post_download_pause(request):
|
||||||
|
task_id = request.match_info["task_id"]
|
||||||
|
err = self.download_manager.pause_download(task_id)
|
||||||
|
if err:
|
||||||
|
return web.json_response({"error": err}, status=400)
|
||||||
|
return web.json_response({"status": "paused"})
|
||||||
|
|
||||||
|
@routes.post("/download/resume/{task_id}")
|
||||||
|
@_require_download_api
|
||||||
|
async def post_download_resume(request):
|
||||||
|
task_id = request.match_info["task_id"]
|
||||||
|
err = self.download_manager.resume_download(task_id)
|
||||||
|
if err:
|
||||||
|
return web.json_response({"error": err}, status=400)
|
||||||
|
return web.json_response({"status": "resumed"})
|
||||||
|
|
||||||
|
@routes.post("/download/cancel/{task_id}")
|
||||||
|
@_require_download_api
|
||||||
|
async def post_download_cancel(request):
|
||||||
|
task_id = request.match_info["task_id"]
|
||||||
|
err = self.download_manager.cancel_download(task_id)
|
||||||
|
if err:
|
||||||
|
return web.json_response({"error": err}, status=400)
|
||||||
|
return web.json_response({"status": "cancelled"})
|
||||||
|
|
||||||
async def setup(self):
|
async def setup(self):
|
||||||
timeout = aiohttp.ClientTimeout(total=None) # no timeout
|
timeout = aiohttp.ClientTimeout(total=None) # no timeout
|
||||||
self.client_session = aiohttp.ClientSession(timeout=timeout)
|
self.client_session = aiohttp.ClientSession(timeout=timeout)
|
||||||
|
if self.download_manager is not None:
|
||||||
|
self.app.on_cleanup.append(self._cleanup_download_manager)
|
||||||
|
|
||||||
|
async def _cleanup_download_manager(self, app):
|
||||||
|
if self.download_manager is not None:
|
||||||
|
await self.download_manager.close()
|
||||||
|
|
||||||
def add_routes(self):
|
def add_routes(self):
|
||||||
self.user_manager.add_routes(self.routes)
|
self.user_manager.add_routes(self.routes)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user