mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-03-19 08:05:05 +08:00
missing models download: add progress tracking
This commit is contained in:
parent
f075b35fd1
commit
f4c0e1d269
@ -10,6 +10,7 @@ import glob
|
||||
import comfy.utils
|
||||
import uuid
|
||||
from urllib.parse import urlparse
|
||||
from typing import Awaitable, Callable
|
||||
import aiohttp
|
||||
from aiohttp import web
|
||||
from PIL import Image
|
||||
@ -30,11 +31,21 @@ WHITELISTED_MODEL_URLS = {
|
||||
}
|
||||
DOWNLOAD_CHUNK_SIZE = 1024 * 1024
|
||||
MAX_BULK_MODEL_DOWNLOADS = 200
|
||||
DownloadProgressCallback = Callable[[int], Awaitable[None]]
|
||||
DownloadShouldCancel = Callable[[], bool]
|
||||
DOWNLOAD_PROGRESS_MIN_INTERVAL = 0.25
|
||||
DOWNLOAD_PROGRESS_MIN_BYTES = 4 * 1024 * 1024
|
||||
|
||||
|
||||
class DownloadCancelledError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class ModelFileManager:
|
||||
def __init__(self, prompt_server) -> None:
|
||||
self.cache: dict[str, tuple[list[dict], dict[str, float], float]] = {}
|
||||
self.prompt_server = prompt_server
|
||||
self._cancelled_missing_model_downloads: set[str] = set()
|
||||
|
||||
def get_cache(self, key: str, default=None) -> tuple[list[dict], dict[str, float], float] | None:
|
||||
return self.cache.get(key, default)
|
||||
@ -96,13 +107,10 @@ class ModelFileManager:
|
||||
|
||||
@routes.post("/experiment/models/download_missing")
|
||||
async def download_missing_models(request: web.Request) -> web.Response:
|
||||
print("download_missing_models")
|
||||
try:
|
||||
payload = await request.json()
|
||||
except Exception:
|
||||
return web.json_response({"error": "Invalid JSON body"}, status=400)
|
||||
print("download_missing_models")
|
||||
print(payload)
|
||||
|
||||
models = payload.get("models")
|
||||
if not isinstance(models, list):
|
||||
@ -113,9 +121,37 @@ class ModelFileManager:
|
||||
status=400,
|
||||
)
|
||||
|
||||
target_client_id = str(payload.get("client_id", "")).strip() or None
|
||||
batch_id = str(payload.get("batch_id", "")).strip() or uuid.uuid4().hex
|
||||
|
||||
def emit_download_event(
|
||||
*,
|
||||
task_id: str,
|
||||
model_name: str,
|
||||
model_directory: str,
|
||||
model_url: str,
|
||||
status: str,
|
||||
bytes_downloaded: int = 0,
|
||||
error: str | None = None,
|
||||
) -> None:
|
||||
message = {
|
||||
"batch_id": batch_id,
|
||||
"task_id": task_id,
|
||||
"name": model_name,
|
||||
"directory": model_directory,
|
||||
"url": model_url,
|
||||
"status": status,
|
||||
"bytes_downloaded": bytes_downloaded
|
||||
}
|
||||
if error:
|
||||
message["error"] = error
|
||||
|
||||
self.prompt_server.send_sync("missing_model_download", message, target_client_id)
|
||||
|
||||
results = []
|
||||
downloaded = 0
|
||||
skipped = 0
|
||||
canceled = 0
|
||||
failed = 0
|
||||
|
||||
session = self.prompt_server.client_session
|
||||
@ -128,31 +164,51 @@ class ModelFileManager:
|
||||
try:
|
||||
for model_entry in models:
|
||||
model_name, model_directory, model_url = _normalize_model_entry(model_entry)
|
||||
task_id = uuid.uuid4().hex
|
||||
self._cancelled_missing_model_downloads.discard(task_id)
|
||||
|
||||
if not model_name or not model_directory or not model_url:
|
||||
failed += 1
|
||||
error = "Each model must include non-empty name, directory, and url"
|
||||
results.append(
|
||||
{
|
||||
"name": model_name or "",
|
||||
"directory": model_directory or "",
|
||||
"url": model_url or "",
|
||||
"status": "failed",
|
||||
"error": "Each model must include non-empty name, directory, and url",
|
||||
"error": error,
|
||||
}
|
||||
)
|
||||
emit_download_event(
|
||||
task_id=task_id,
|
||||
model_name=model_name or "",
|
||||
model_directory=model_directory or "",
|
||||
model_url=model_url or "",
|
||||
status="failed",
|
||||
error=error,
|
||||
)
|
||||
continue
|
||||
|
||||
if not _is_http_url(model_url):
|
||||
failed += 1
|
||||
error = "URL must be an absolute HTTP/HTTPS URL"
|
||||
results.append(
|
||||
{
|
||||
"name": model_name,
|
||||
"directory": model_directory,
|
||||
"url": model_url,
|
||||
"status": "failed",
|
||||
"error": "URL must be an absolute HTTP/HTTPS URL",
|
||||
"error": error,
|
||||
}
|
||||
)
|
||||
emit_download_event(
|
||||
task_id=task_id,
|
||||
model_name=model_name,
|
||||
model_directory=model_directory,
|
||||
model_url=model_url,
|
||||
status="failed",
|
||||
error=error,
|
||||
)
|
||||
continue
|
||||
|
||||
allowed, reason = _is_model_download_allowed(model_name, model_url)
|
||||
@ -167,21 +223,38 @@ class ModelFileManager:
|
||||
"error": reason,
|
||||
}
|
||||
)
|
||||
emit_download_event(
|
||||
task_id=task_id,
|
||||
model_name=model_name,
|
||||
model_directory=model_directory,
|
||||
model_url=model_url,
|
||||
status="blocked",
|
||||
error=reason,
|
||||
)
|
||||
continue
|
||||
|
||||
try:
|
||||
destination = _resolve_download_destination(model_directory, model_name)
|
||||
except Exception as exc:
|
||||
failed += 1
|
||||
error = str(exc)
|
||||
results.append(
|
||||
{
|
||||
"name": model_name,
|
||||
"directory": model_directory,
|
||||
"url": model_url,
|
||||
"status": "failed",
|
||||
"error": str(exc),
|
||||
"error": error,
|
||||
}
|
||||
)
|
||||
emit_download_event(
|
||||
task_id=task_id,
|
||||
model_name=model_name,
|
||||
model_directory=model_directory,
|
||||
model_url=model_url,
|
||||
status="failed",
|
||||
error=error,
|
||||
)
|
||||
continue
|
||||
|
||||
if os.path.exists(destination):
|
||||
@ -194,11 +267,39 @@ class ModelFileManager:
|
||||
"status": "skipped_existing",
|
||||
}
|
||||
)
|
||||
emit_download_event(
|
||||
task_id=task_id,
|
||||
model_name=model_name,
|
||||
model_directory=model_directory,
|
||||
model_url=model_url,
|
||||
status="skipped_existing",
|
||||
bytes_downloaded=0
|
||||
)
|
||||
continue
|
||||
|
||||
try:
|
||||
await _download_file(session, model_url, destination)
|
||||
latest_downloaded = 0
|
||||
async def on_progress(bytes_downloaded: int) -> None:
|
||||
nonlocal latest_downloaded
|
||||
latest_downloaded = bytes_downloaded
|
||||
emit_download_event(
|
||||
task_id=task_id,
|
||||
model_name=model_name,
|
||||
model_directory=model_directory,
|
||||
model_url=model_url,
|
||||
status="running",
|
||||
bytes_downloaded=bytes_downloaded
|
||||
)
|
||||
|
||||
await _download_file(
|
||||
session,
|
||||
model_url,
|
||||
destination,
|
||||
progress_callback=on_progress,
|
||||
should_cancel=lambda: task_id in self._cancelled_missing_model_downloads
|
||||
)
|
||||
downloaded += 1
|
||||
final_size = os.path.getsize(destination) if os.path.exists(destination) else latest_downloaded
|
||||
results.append(
|
||||
{
|
||||
"name": model_name,
|
||||
@ -207,17 +308,56 @@ class ModelFileManager:
|
||||
"status": "downloaded",
|
||||
}
|
||||
)
|
||||
emit_download_event(
|
||||
task_id=task_id,
|
||||
model_name=model_name,
|
||||
model_directory=model_directory,
|
||||
model_url=model_url,
|
||||
status="completed",
|
||||
bytes_downloaded=final_size
|
||||
)
|
||||
except DownloadCancelledError:
|
||||
canceled += 1
|
||||
results.append(
|
||||
{
|
||||
"name": model_name,
|
||||
"directory": model_directory,
|
||||
"url": model_url,
|
||||
"status": "canceled",
|
||||
"error": "Download canceled",
|
||||
}
|
||||
)
|
||||
emit_download_event(
|
||||
task_id=task_id,
|
||||
model_name=model_name,
|
||||
model_directory=model_directory,
|
||||
model_url=model_url,
|
||||
status="canceled",
|
||||
bytes_downloaded=latest_downloaded,
|
||||
error="Download canceled",
|
||||
)
|
||||
except Exception as exc:
|
||||
failed += 1
|
||||
error = str(exc)
|
||||
results.append(
|
||||
{
|
||||
"name": model_name,
|
||||
"directory": model_directory,
|
||||
"url": model_url,
|
||||
"status": "failed",
|
||||
"error": str(exc),
|
||||
"error": error,
|
||||
}
|
||||
)
|
||||
emit_download_event(
|
||||
task_id=task_id,
|
||||
model_name=model_name,
|
||||
model_directory=model_directory,
|
||||
model_url=model_url,
|
||||
status="failed",
|
||||
error=error
|
||||
)
|
||||
finally:
|
||||
self._cancelled_missing_model_downloads.discard(task_id)
|
||||
finally:
|
||||
if owns_session:
|
||||
await session.close()
|
||||
@ -226,12 +366,27 @@ class ModelFileManager:
|
||||
{
|
||||
"downloaded": downloaded,
|
||||
"skipped": skipped,
|
||||
"canceled": canceled,
|
||||
"failed": failed,
|
||||
"results": results,
|
||||
},
|
||||
status=200,
|
||||
)
|
||||
|
||||
@routes.post("/experiment/models/download_missing/cancel")
|
||||
async def cancel_download_missing_model(request: web.Request) -> web.Response:
|
||||
try:
|
||||
payload = await request.json()
|
||||
except Exception:
|
||||
return web.json_response({"error": "Invalid JSON body"}, status=400)
|
||||
|
||||
task_id = str(payload.get("task_id", "")).strip()
|
||||
if not task_id:
|
||||
return web.json_response({"error": "Field 'task_id' is required"}, status=400)
|
||||
|
||||
self._cancelled_missing_model_downloads.add(task_id)
|
||||
return web.json_response({"ok": True, "task_id": task_id}, status=200)
|
||||
|
||||
def get_model_file_list(self, folder_name: str):
|
||||
folder_name = map_legacy(folder_name)
|
||||
folders = folder_paths.folder_names_and_paths[folder_name]
|
||||
@ -393,15 +548,45 @@ def _resolve_download_destination(directory: str, model_name: str) -> str:
|
||||
return destination
|
||||
|
||||
|
||||
async def _download_file(session: aiohttp.ClientSession, url: str, destination: str) -> None:
|
||||
temp_file = f"{destination}.part-{uuid.uuid4().hex}"
|
||||
async def _download_file(
|
||||
session: aiohttp.ClientSession,
|
||||
url: str,
|
||||
destination: str,
|
||||
progress_callback: DownloadProgressCallback | None = None,
|
||||
should_cancel: DownloadShouldCancel | None = None,
|
||||
progress_min_interval: float = DOWNLOAD_PROGRESS_MIN_INTERVAL,
|
||||
progress_min_bytes: int = DOWNLOAD_PROGRESS_MIN_BYTES,
|
||||
) -> None:
|
||||
temp_file = f"{destination}.{uuid.uuid4().hex}.temp"
|
||||
try:
|
||||
if should_cancel is not None and should_cancel():
|
||||
raise DownloadCancelledError("Download canceled")
|
||||
async with session.get(url, allow_redirects=True) as response:
|
||||
response.raise_for_status()
|
||||
bytes_downloaded = 0
|
||||
if progress_callback is not None:
|
||||
await progress_callback(bytes_downloaded)
|
||||
last_progress_emit_time = time.monotonic()
|
||||
last_progress_emit_bytes = 0
|
||||
with open(temp_file, "wb") as file_handle:
|
||||
async for chunk in response.content.iter_chunked(DOWNLOAD_CHUNK_SIZE):
|
||||
if should_cancel is not None and should_cancel():
|
||||
raise DownloadCancelledError("Download canceled")
|
||||
if chunk:
|
||||
file_handle.write(chunk)
|
||||
bytes_downloaded += len(chunk)
|
||||
if progress_callback is not None:
|
||||
now = time.monotonic()
|
||||
should_emit = (
|
||||
bytes_downloaded - last_progress_emit_bytes >= progress_min_bytes
|
||||
or now - last_progress_emit_time >= progress_min_interval
|
||||
)
|
||||
if should_emit:
|
||||
await progress_callback(bytes_downloaded)
|
||||
last_progress_emit_time = now
|
||||
last_progress_emit_bytes = bytes_downloaded
|
||||
if progress_callback is not None and bytes_downloaded != last_progress_emit_bytes:
|
||||
await progress_callback(bytes_downloaded)
|
||||
os.replace(temp_file, destination)
|
||||
finally:
|
||||
if os.path.exists(temp_file):
|
||||
@ -420,4 +605,4 @@ def _normalize_model_entry(model_entry: object) -> tuple[str | None, str | None,
|
||||
|
||||
def _is_http_url(url: str) -> bool:
|
||||
parsed = urlparse(url)
|
||||
return parsed.scheme in ("http", "https") and bool(parsed.netloc)
|
||||
return parsed.scheme in ("http", "https") and bool(parsed.netloc)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user