missing models download: add progress tracking

This commit is contained in:
teddav 2026-02-17 12:06:30 +01:00
parent f075b35fd1
commit f4c0e1d269

View File

@@ -10,6 +10,7 @@ import glob
import comfy.utils import comfy.utils
import uuid import uuid
from urllib.parse import urlparse from urllib.parse import urlparse
from typing import Awaitable, Callable
import aiohttp import aiohttp
from aiohttp import web from aiohttp import web
from PIL import Image from PIL import Image
@@ -30,11 +31,21 @@ WHITELISTED_MODEL_URLS = {
} }
DOWNLOAD_CHUNK_SIZE = 1024 * 1024 DOWNLOAD_CHUNK_SIZE = 1024 * 1024
MAX_BULK_MODEL_DOWNLOADS = 200 MAX_BULK_MODEL_DOWNLOADS = 200
# Async callback invoked with the cumulative number of bytes downloaded so far.
DownloadProgressCallback = Callable[[int], Awaitable[None]]
# Predicate polled between chunks; returning True aborts the transfer.
DownloadShouldCancel = Callable[[], bool]
# Progress-event throttling: emit when at least this much time has passed...
DOWNLOAD_PROGRESS_MIN_INTERVAL = 0.25
# ...or when at least this many new bytes have arrived since the last event.
DOWNLOAD_PROGRESS_MIN_BYTES = 4 * 1024 * 1024
class DownloadCancelledError(Exception):
    """Raised inside _download_file when its should_cancel predicate reports True."""
    pass
class ModelFileManager: class ModelFileManager:
def __init__(self, prompt_server) -> None:
    """Hold the model-list cache and per-task download cancellation flags."""
    # folder key -> (file entries, per-folder timing info, cache timestamp)
    # — exact component semantics live in the cache writers; TODO confirm.
    self.cache: dict[str, tuple[list[dict], dict[str, float], float]] = {}
    self.prompt_server = prompt_server
    # task_ids flagged by the /download_missing/cancel route; the download
    # loop polls this set and aborts matching transfers, discarding the id
    # once the task finishes.
    self._cancelled_missing_model_downloads: set[str] = set()
def get_cache(self, key: str, default=None) -> tuple[list[dict], dict[str, float], float] | None:
    """Return the cached entry for *key*, or *default* when no entry exists."""
    return self.cache.get(key, default)
@@ -96,13 +107,10 @@ class ModelFileManager:
@routes.post("/experiment/models/download_missing") @routes.post("/experiment/models/download_missing")
async def download_missing_models(request: web.Request) -> web.Response: async def download_missing_models(request: web.Request) -> web.Response:
print("download_missing_models")
try: try:
payload = await request.json() payload = await request.json()
except Exception: except Exception:
return web.json_response({"error": "Invalid JSON body"}, status=400) return web.json_response({"error": "Invalid JSON body"}, status=400)
print("download_missing_models")
print(payload)
models = payload.get("models") models = payload.get("models")
if not isinstance(models, list): if not isinstance(models, list):
@@ -113,9 +121,37 @@ class ModelFileManager:
status=400, status=400,
) )
target_client_id = str(payload.get("client_id", "")).strip() or None
batch_id = str(payload.get("batch_id", "")).strip() or uuid.uuid4().hex
def emit_download_event(
    *,
    task_id: str,
    model_name: str,
    model_directory: str,
    model_url: str,
    status: str,
    bytes_downloaded: int = 0,
    error: str | None = None,
) -> None:
    """Send a 'missing_model_download' event describing one download task.

    Delivered via prompt_server.send_sync to *target_client_id* when set;
    presumably broadcast when it is None — confirm against send_sync.
    """
    message = {
        "batch_id": batch_id,
        "task_id": task_id,
        "name": model_name,
        "directory": model_directory,
        "url": model_url,
        "status": status,
        "bytes_downloaded": bytes_downloaded
    }
    # Only attach an error field when there is a non-empty message to report.
    if error:
        message["error"] = error
    self.prompt_server.send_sync("missing_model_download", message, target_client_id)
results = [] results = []
downloaded = 0 downloaded = 0
skipped = 0 skipped = 0
canceled = 0
failed = 0 failed = 0
session = self.prompt_server.client_session session = self.prompt_server.client_session
@@ -128,31 +164,51 @@ class ModelFileManager:
try: try:
for model_entry in models: for model_entry in models:
model_name, model_directory, model_url = _normalize_model_entry(model_entry) model_name, model_directory, model_url = _normalize_model_entry(model_entry)
task_id = uuid.uuid4().hex
self._cancelled_missing_model_downloads.discard(task_id)
if not model_name or not model_directory or not model_url: if not model_name or not model_directory or not model_url:
failed += 1 failed += 1
error = "Each model must include non-empty name, directory, and url"
results.append( results.append(
{ {
"name": model_name or "", "name": model_name or "",
"directory": model_directory or "", "directory": model_directory or "",
"url": model_url or "", "url": model_url or "",
"status": "failed", "status": "failed",
"error": "Each model must include non-empty name, directory, and url", "error": error,
} }
) )
emit_download_event(
task_id=task_id,
model_name=model_name or "",
model_directory=model_directory or "",
model_url=model_url or "",
status="failed",
error=error,
)
continue continue
if not _is_http_url(model_url): if not _is_http_url(model_url):
failed += 1 failed += 1
error = "URL must be an absolute HTTP/HTTPS URL"
results.append( results.append(
{ {
"name": model_name, "name": model_name,
"directory": model_directory, "directory": model_directory,
"url": model_url, "url": model_url,
"status": "failed", "status": "failed",
"error": "URL must be an absolute HTTP/HTTPS URL", "error": error,
} }
) )
emit_download_event(
task_id=task_id,
model_name=model_name,
model_directory=model_directory,
model_url=model_url,
status="failed",
error=error,
)
continue continue
allowed, reason = _is_model_download_allowed(model_name, model_url) allowed, reason = _is_model_download_allowed(model_name, model_url)
@@ -167,21 +223,38 @@ class ModelFileManager:
"error": reason, "error": reason,
} }
) )
emit_download_event(
task_id=task_id,
model_name=model_name,
model_directory=model_directory,
model_url=model_url,
status="blocked",
error=reason,
)
continue continue
try: try:
destination = _resolve_download_destination(model_directory, model_name) destination = _resolve_download_destination(model_directory, model_name)
except Exception as exc: except Exception as exc:
failed += 1 failed += 1
error = str(exc)
results.append( results.append(
{ {
"name": model_name, "name": model_name,
"directory": model_directory, "directory": model_directory,
"url": model_url, "url": model_url,
"status": "failed", "status": "failed",
"error": str(exc), "error": error,
} }
) )
emit_download_event(
task_id=task_id,
model_name=model_name,
model_directory=model_directory,
model_url=model_url,
status="failed",
error=error,
)
continue continue
if os.path.exists(destination): if os.path.exists(destination):
@@ -194,11 +267,39 @@ class ModelFileManager:
"status": "skipped_existing", "status": "skipped_existing",
} }
) )
emit_download_event(
task_id=task_id,
model_name=model_name,
model_directory=model_directory,
model_url=model_url,
status="skipped_existing",
bytes_downloaded=0
)
continue continue
try: try:
await _download_file(session, model_url, destination) latest_downloaded = 0
async def on_progress(bytes_downloaded: int) -> None:
nonlocal latest_downloaded
latest_downloaded = bytes_downloaded
emit_download_event(
task_id=task_id,
model_name=model_name,
model_directory=model_directory,
model_url=model_url,
status="running",
bytes_downloaded=bytes_downloaded
)
await _download_file(
session,
model_url,
destination,
progress_callback=on_progress,
should_cancel=lambda: task_id in self._cancelled_missing_model_downloads
)
downloaded += 1 downloaded += 1
final_size = os.path.getsize(destination) if os.path.exists(destination) else latest_downloaded
results.append( results.append(
{ {
"name": model_name, "name": model_name,
@@ -207,17 +308,56 @@ class ModelFileManager:
"status": "downloaded", "status": "downloaded",
} }
) )
emit_download_event(
task_id=task_id,
model_name=model_name,
model_directory=model_directory,
model_url=model_url,
status="completed",
bytes_downloaded=final_size
)
except DownloadCancelledError:
canceled += 1
results.append(
{
"name": model_name,
"directory": model_directory,
"url": model_url,
"status": "canceled",
"error": "Download canceled",
}
)
emit_download_event(
task_id=task_id,
model_name=model_name,
model_directory=model_directory,
model_url=model_url,
status="canceled",
bytes_downloaded=latest_downloaded,
error="Download canceled",
)
except Exception as exc: except Exception as exc:
failed += 1 failed += 1
error = str(exc)
results.append( results.append(
{ {
"name": model_name, "name": model_name,
"directory": model_directory, "directory": model_directory,
"url": model_url, "url": model_url,
"status": "failed", "status": "failed",
"error": str(exc), "error": error,
} }
) )
emit_download_event(
task_id=task_id,
model_name=model_name,
model_directory=model_directory,
model_url=model_url,
status="failed",
error=error
)
finally:
self._cancelled_missing_model_downloads.discard(task_id)
finally: finally:
if owns_session: if owns_session:
await session.close() await session.close()
@@ -226,12 +366,27 @@ class ModelFileManager:
{ {
"downloaded": downloaded, "downloaded": downloaded,
"skipped": skipped, "skipped": skipped,
"canceled": canceled,
"failed": failed, "failed": failed,
"results": results, "results": results,
}, },
status=200, status=200,
) )
@routes.post("/experiment/models/download_missing/cancel")
async def cancel_download_missing_model(request: web.Request) -> web.Response:
    """Flag a missing-model download task for cancellation.

    Expects a JSON body with a non-empty 'task_id'. The download loop
    polls self._cancelled_missing_model_downloads between chunks and
    aborts the matching transfer.
    """
    try:
        payload = await request.json()
    except Exception:
        return web.json_response({"error": "Invalid JSON body"}, status=400)
    task_id = str(payload.get("task_id", "")).strip()
    if not task_id:
        return web.json_response({"error": "Field 'task_id' is required"}, status=400)
    # NOTE(review): task_id is not checked against active tasks — unknown
    # ids are accepted here and discarded when a matching task finishes.
    self._cancelled_missing_model_downloads.add(task_id)
    return web.json_response({"ok": True, "task_id": task_id}, status=200)
def get_model_file_list(self, folder_name: str): def get_model_file_list(self, folder_name: str):
folder_name = map_legacy(folder_name) folder_name = map_legacy(folder_name)
folders = folder_paths.folder_names_and_paths[folder_name] folders = folder_paths.folder_names_and_paths[folder_name]
@@ -393,15 +548,45 @@ def _resolve_download_destination(directory: str, model_name: str) -> str:
return destination return destination
async def _download_file(session: aiohttp.ClientSession, url: str, destination: str) -> None: async def _download_file(
temp_file = f"{destination}.part-{uuid.uuid4().hex}" session: aiohttp.ClientSession,
url: str,
destination: str,
progress_callback: DownloadProgressCallback | None = None,
should_cancel: DownloadShouldCancel | None = None,
progress_min_interval: float = DOWNLOAD_PROGRESS_MIN_INTERVAL,
progress_min_bytes: int = DOWNLOAD_PROGRESS_MIN_BYTES,
) -> None:
temp_file = f"{destination}.{uuid.uuid4().hex}.temp"
try: try:
if should_cancel is not None and should_cancel():
raise DownloadCancelledError("Download canceled")
async with session.get(url, allow_redirects=True) as response: async with session.get(url, allow_redirects=True) as response:
response.raise_for_status() response.raise_for_status()
bytes_downloaded = 0
if progress_callback is not None:
await progress_callback(bytes_downloaded)
last_progress_emit_time = time.monotonic()
last_progress_emit_bytes = 0
with open(temp_file, "wb") as file_handle: with open(temp_file, "wb") as file_handle:
async for chunk in response.content.iter_chunked(DOWNLOAD_CHUNK_SIZE): async for chunk in response.content.iter_chunked(DOWNLOAD_CHUNK_SIZE):
if should_cancel is not None and should_cancel():
raise DownloadCancelledError("Download canceled")
if chunk: if chunk:
file_handle.write(chunk) file_handle.write(chunk)
bytes_downloaded += len(chunk)
if progress_callback is not None:
now = time.monotonic()
should_emit = (
bytes_downloaded - last_progress_emit_bytes >= progress_min_bytes
or now - last_progress_emit_time >= progress_min_interval
)
if should_emit:
await progress_callback(bytes_downloaded)
last_progress_emit_time = now
last_progress_emit_bytes = bytes_downloaded
if progress_callback is not None and bytes_downloaded != last_progress_emit_bytes:
await progress_callback(bytes_downloaded)
os.replace(temp_file, destination) os.replace(temp_file, destination)
finally: finally:
if os.path.exists(temp_file): if os.path.exists(temp_file):
@@ -420,4 +605,4 @@ def _normalize_model_entry(model_entry: object) -> tuple[str | None, str | None,
def _is_http_url(url: str) -> bool: def _is_http_url(url: str) -> bool:
parsed = urlparse(url) parsed = urlparse(url)
return parsed.scheme in ("http", "https") and bool(parsed.netloc) return parsed.scheme in ("http", "https") and bool(parsed.netloc)