From f4c0e1d2694989a96bc872e9635cda07aeaded5e Mon Sep 17 00:00:00 2001 From: teddav Date: Tue, 17 Feb 2026 12:06:30 +0100 Subject: [PATCH] missing models download: add progress tracking --- app/model_manager.py | 207 ++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 196 insertions(+), 11 deletions(-) diff --git a/app/model_manager.py b/app/model_manager.py index ff145c1a6..40e5b39a7 100644 --- a/app/model_manager.py +++ b/app/model_manager.py @@ -10,6 +10,7 @@ import glob import comfy.utils import uuid from urllib.parse import urlparse +from typing import Awaitable, Callable import aiohttp from aiohttp import web from PIL import Image @@ -30,11 +31,21 @@ WHITELISTED_MODEL_URLS = { } DOWNLOAD_CHUNK_SIZE = 1024 * 1024 MAX_BULK_MODEL_DOWNLOADS = 200 +DownloadProgressCallback = Callable[[int], Awaitable[None]] +DownloadShouldCancel = Callable[[], bool] +DOWNLOAD_PROGRESS_MIN_INTERVAL = 0.25 +DOWNLOAD_PROGRESS_MIN_BYTES = 4 * 1024 * 1024 + + +class DownloadCancelledError(Exception): + pass + class ModelFileManager: def __init__(self, prompt_server) -> None: self.cache: dict[str, tuple[list[dict], dict[str, float], float]] = {} self.prompt_server = prompt_server + self._cancelled_missing_model_downloads: set[str] = set() def get_cache(self, key: str, default=None) -> tuple[list[dict], dict[str, float], float] | None: return self.cache.get(key, default) @@ -96,13 +107,10 @@ class ModelFileManager: @routes.post("/experiment/models/download_missing") async def download_missing_models(request: web.Request) -> web.Response: - print("download_missing_models") try: payload = await request.json() except Exception: return web.json_response({"error": "Invalid JSON body"}, status=400) - print("download_missing_models") - print(payload) models = payload.get("models") if not isinstance(models, list): @@ -113,9 +121,37 @@ class ModelFileManager: status=400, ) + target_client_id = str(payload.get("client_id", "")).strip() or None + batch_id = str(payload.get("batch_id", "")).strip() or uuid.uuid4().hex + + def emit_download_event( + *, + task_id: str, + model_name: str, + model_directory: str, + model_url: str, + status: str, + bytes_downloaded: int = 0, + error: str | None = None, + ) -> None: + message = { + "batch_id": batch_id, + "task_id": task_id, + "name": model_name, + "directory": model_directory, + "url": model_url, + "status": status, + "bytes_downloaded": bytes_downloaded + } + if error: + message["error"] = error + + self.prompt_server.send_sync("missing_model_download", message, target_client_id) + results = [] downloaded = 0 skipped = 0 + canceled = 0 failed = 0 session = self.prompt_server.client_session @@ -128,31 +164,51 @@ class ModelFileManager: try: for model_entry in models: model_name, model_directory, model_url = _normalize_model_entry(model_entry) + task_id = uuid.uuid4().hex + self._cancelled_missing_model_downloads.discard(task_id) if not model_name or not model_directory or not model_url: failed += 1 + error = "Each model must include non-empty name, directory, and url" results.append( { "name": model_name or "", "directory": model_directory or "", "url": model_url or "", "status": "failed", - "error": "Each model must include non-empty name, directory, and url", + "error": error, } ) + emit_download_event( + task_id=task_id, + model_name=model_name or "", + model_directory=model_directory or "", + model_url=model_url or "", + status="failed", + error=error, + ) continue if not _is_http_url(model_url): failed += 1 + error = "URL must be an absolute HTTP/HTTPS URL" results.append( { "name": model_name, "directory": model_directory, "url": model_url, "status": "failed", - "error": "URL must be an absolute HTTP/HTTPS URL", + "error": error, } ) + emit_download_event( + task_id=task_id, + model_name=model_name, + model_directory=model_directory, + model_url=model_url, + status="failed", + error=error, + ) continue allowed, reason = _is_model_download_allowed(model_name, model_url) @@ -167,21 +223,38 @@ class ModelFileManager: "error": reason, } ) + emit_download_event( + task_id=task_id, + model_name=model_name, + model_directory=model_directory, + model_url=model_url, + status="blocked", + error=reason, + ) continue try: destination = _resolve_download_destination(model_directory, model_name) except Exception as exc: failed += 1 + error = str(exc) results.append( { "name": model_name, "directory": model_directory, "url": model_url, "status": "failed", - "error": str(exc), + "error": error, } ) + emit_download_event( + task_id=task_id, + model_name=model_name, + model_directory=model_directory, + model_url=model_url, + status="failed", + error=error, + ) continue if os.path.exists(destination): @@ -194,11 +267,39 @@ class ModelFileManager: "status": "skipped_existing", } ) + emit_download_event( + task_id=task_id, + model_name=model_name, + model_directory=model_directory, + model_url=model_url, + status="skipped_existing", + bytes_downloaded=0 + ) continue try: - await _download_file(session, model_url, destination) + latest_downloaded = 0 + async def on_progress(bytes_downloaded: int) -> None: + nonlocal latest_downloaded + latest_downloaded = bytes_downloaded + emit_download_event( + task_id=task_id, + model_name=model_name, + model_directory=model_directory, + model_url=model_url, + status="running", + bytes_downloaded=bytes_downloaded + ) + + await _download_file( + session, + model_url, + destination, + progress_callback=on_progress, + should_cancel=lambda: task_id in self._cancelled_missing_model_downloads + ) downloaded += 1 + final_size = os.path.getsize(destination) if os.path.exists(destination) else latest_downloaded results.append( { "name": model_name, @@ -207,17 +308,56 @@ class ModelFileManager: "status": "downloaded", } ) + emit_download_event( + task_id=task_id, + model_name=model_name, + model_directory=model_directory, + model_url=model_url, + status="completed", + bytes_downloaded=final_size + ) + except DownloadCancelledError: + canceled += 1 + results.append( + { + "name": model_name, + "directory": model_directory, + "url": model_url, + "status": "canceled", + "error": "Download canceled", + } + ) + emit_download_event( + task_id=task_id, + model_name=model_name, + model_directory=model_directory, + model_url=model_url, + status="canceled", + bytes_downloaded=latest_downloaded, + error="Download canceled", + ) except Exception as exc: failed += 1 + error = str(exc) results.append( { "name": model_name, "directory": model_directory, "url": model_url, "status": "failed", - "error": str(exc), + "error": error, } ) + emit_download_event( + task_id=task_id, + model_name=model_name, + model_directory=model_directory, + model_url=model_url, + status="failed", + error=error + ) + finally: + self._cancelled_missing_model_downloads.discard(task_id) finally: if owns_session: await session.close() @@ -226,12 +366,27 @@ class ModelFileManager: { "downloaded": downloaded, "skipped": skipped, + "canceled": canceled, "failed": failed, "results": results, }, status=200, ) + @routes.post("/experiment/models/download_missing/cancel") + async def cancel_download_missing_model(request: web.Request) -> web.Response: + try: + payload = await request.json() + except Exception: + return web.json_response({"error": "Invalid JSON body"}, status=400) + + task_id = str(payload.get("task_id", "")).strip() + if not task_id: + return web.json_response({"error": "Field 'task_id' is required"}, status=400) + + self._cancelled_missing_model_downloads.add(task_id) + return web.json_response({"ok": True, "task_id": task_id}, status=200) + def get_model_file_list(self, folder_name: str): folder_name = map_legacy(folder_name) folders = folder_paths.folder_names_and_paths[folder_name] @@ -393,15 +548,45 @@ def _resolve_download_destination(directory: str, model_name: str) -> str: return destination -async def _download_file(session: aiohttp.ClientSession, url: str, destination: str) -> None: - temp_file = f"{destination}.part-{uuid.uuid4().hex}" +async def _download_file( + session: aiohttp.ClientSession, + url: str, + destination: str, + progress_callback: DownloadProgressCallback | None = None, + should_cancel: DownloadShouldCancel | None = None, + progress_min_interval: float = DOWNLOAD_PROGRESS_MIN_INTERVAL, + progress_min_bytes: int = DOWNLOAD_PROGRESS_MIN_BYTES, +) -> None: + temp_file = f"{destination}.{uuid.uuid4().hex}.temp" try: + if should_cancel is not None and should_cancel(): + raise DownloadCancelledError("Download canceled") async with session.get(url, allow_redirects=True) as response: response.raise_for_status() + bytes_downloaded = 0 + if progress_callback is not None: + await progress_callback(bytes_downloaded) + last_progress_emit_time = time.monotonic() + last_progress_emit_bytes = 0 with open(temp_file, "wb") as file_handle: async for chunk in response.content.iter_chunked(DOWNLOAD_CHUNK_SIZE): + if should_cancel is not None and should_cancel(): + raise DownloadCancelledError("Download canceled") if chunk: file_handle.write(chunk) + bytes_downloaded += len(chunk) + if progress_callback is not None: + now = time.monotonic() + should_emit = ( + bytes_downloaded - last_progress_emit_bytes >= progress_min_bytes + or now - last_progress_emit_time >= progress_min_interval + ) + if should_emit: + await progress_callback(bytes_downloaded) + last_progress_emit_time = now + last_progress_emit_bytes = bytes_downloaded + if progress_callback is not None and bytes_downloaded != last_progress_emit_bytes: + await progress_callback(bytes_downloaded) os.replace(temp_file, destination) finally: if os.path.exists(temp_file): @@ -420,4 +605,4 @@ def _normalize_model_entry(model_entry: object) -> tuple[str | None, str | None, def _is_http_url(url: str) -> bool: parsed = urlparse(url) - return parsed.scheme in ("http", "https") and bool(parsed.netloc) \ No newline at end of file + return parsed.scheme in ("http", "https") and bool(parsed.netloc)