From f075b35fd1a138373d4371966317b55061ee1c9c Mon Sep 17 00:00:00 2001 From: teddav Date: Mon, 16 Feb 2026 12:14:56 +0100 Subject: [PATCH 1/2] add route to automatically download missing models --- app/model_manager.py | 230 +++++++++++++++++++++- server.py | 2 +- tests-unit/app_test/model_manager_test.py | 7 +- 3 files changed, 236 insertions(+), 3 deletions(-) diff --git a/app/model_manager.py b/app/model_manager.py index f124d1117..ff145c1a6 100644 --- a/app/model_manager.py +++ b/app/model_manager.py @@ -8,15 +8,33 @@ import logging import folder_paths import glob import comfy.utils +import uuid +from urllib.parse import urlparse +import aiohttp from aiohttp import web from PIL import Image from io import BytesIO from folder_paths import map_legacy, filter_files_extensions, filter_files_content_types +ALLOWED_MODEL_SOURCES = ( + "https://civitai.com/", + "https://huggingface.co/", + "http://localhost:", +) +ALLOWED_MODEL_SUFFIXES = (".safetensors", ".sft") +WHITELISTED_MODEL_URLS = { + "https://huggingface.co/stabilityai/stable-zero123/resolve/main/stable_zero123.ckpt", + "https://huggingface.co/TencentARC/T2I-Adapter/resolve/main/models/t2iadapter_depth_sd14v1.pth?download=true", + "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth", +} +DOWNLOAD_CHUNK_SIZE = 1024 * 1024 +MAX_BULK_MODEL_DOWNLOADS = 200 + class ModelFileManager: - def __init__(self) -> None: + def __init__(self, prompt_server) -> None: self.cache: dict[str, tuple[list[dict], dict[str, float], float]] = {} + self.prompt_server = prompt_server def get_cache(self, key: str, default=None) -> tuple[list[dict], dict[str, float], float] | None: return self.cache.get(key, default) @@ -75,6 +93,144 @@ class ModelFileManager: return web.Response(body=img_bytes.getvalue(), content_type="image/webp") except: return web.Response(status=404) + + @routes.post("/experiment/models/download_missing") + async def download_missing_models(request: web.Request) -> web.Response: + print("download_missing_models") + try: + payload = await request.json() + except Exception: + return web.json_response({"error": "Invalid JSON body"}, status=400) + print("download_missing_models") + print(payload) + + models = payload.get("models") + if not isinstance(models, list): + return web.json_response({"error": "Field 'models' must be a list"}, status=400) + if len(models) > MAX_BULK_MODEL_DOWNLOADS: + return web.json_response( + {"error": f"Maximum of {MAX_BULK_MODEL_DOWNLOADS} models allowed per request"}, + status=400, + ) + + results = [] + downloaded = 0 + skipped = 0 + failed = 0 + + session = self.prompt_server.client_session + owns_session = False + if session is None or session.closed: + timeout = aiohttp.ClientTimeout(total=None) + session = aiohttp.ClientSession(timeout=timeout) + owns_session = True + + try: + for model_entry in models: + model_name, model_directory, model_url = _normalize_model_entry(model_entry) + + if not model_name or not model_directory or not model_url: + failed += 1 + results.append( + { + "name": model_name or "", + "directory": model_directory or "", + "url": model_url or "", + "status": "failed", + "error": "Each model must include non-empty name, directory, and url", + } + ) + continue + + if not _is_http_url(model_url): + failed += 1 + results.append( + { + "name": model_name, + "directory": model_directory, + "url": model_url, + "status": "failed", + "error": "URL must be an absolute HTTP/HTTPS URL", + } + ) + continue + + allowed, reason = _is_model_download_allowed(model_name, model_url) + if not allowed: + failed += 1 + results.append( + { + "name": model_name, + "directory": model_directory, + "url": model_url, + "status": "blocked", + "error": reason, + } + ) + continue + + try: + destination = _resolve_download_destination(model_directory, model_name) + except Exception as exc: + failed += 1 + results.append( + { + "name": model_name, + "directory": model_directory, + "url": model_url, + "status": "failed", + "error": str(exc), + } + ) + continue + + if os.path.exists(destination): + skipped += 1 + results.append( + { + "name": model_name, + "directory": model_directory, + "url": model_url, + "status": "skipped_existing", + } + ) + continue + + try: + await _download_file(session, model_url, destination) + downloaded += 1 + results.append( + { + "name": model_name, + "directory": model_directory, + "url": model_url, + "status": "downloaded", + } + ) + except Exception as exc: + failed += 1 + results.append( + { + "name": model_name, + "directory": model_directory, + "url": model_url, + "status": "failed", + "error": str(exc), + } + ) + finally: + if owns_session: + await session.close() + + return web.json_response( + { + "downloaded": downloaded, + "skipped": skipped, + "failed": failed, + "results": results, + }, + status=200, + ) def get_model_file_list(self, folder_name: str): folder_name = map_legacy(folder_name) @@ -193,3 +349,75 @@ class ModelFileManager: def __exit__(self, exc_type, exc_value, traceback): self.clear_cache() + +def _is_model_download_allowed(model_name: str, model_url: str) -> tuple[bool, str | None]: + if model_url in WHITELISTED_MODEL_URLS: + return True, None + + if not any(model_url.startswith(source) for source in ALLOWED_MODEL_SOURCES): + return ( + False, + f"Download not allowed from source '{model_url}'.", + ) + + if not any(model_name.endswith(suffix) for suffix in ALLOWED_MODEL_SUFFIXES): + return ( + False, + f"Only allowed suffixes are: {', '.join(ALLOWED_MODEL_SUFFIXES)}", + ) + + return True, None + + +def _resolve_download_destination(directory: str, model_name: str) -> str: + if directory not in folder_paths.folder_names_and_paths: + raise ValueError(f"Unknown model directory '{directory}'") + + model_paths = folder_paths.folder_names_and_paths[directory][0] + if not model_paths: + raise ValueError(f"No filesystem paths configured for '{directory}'") + + base_path = os.path.abspath(model_paths[0]) + normalized_name = os.path.normpath(model_name).lstrip("/\\") + if not normalized_name or normalized_name == ".": + raise ValueError("Model name cannot be empty") + + destination = os.path.abspath(os.path.join(base_path, normalized_name)) + if os.path.commonpath((base_path, destination)) != base_path: + raise ValueError("Model path escapes configured model directory") + + destination_parent = os.path.dirname(destination) + if destination_parent: + os.makedirs(destination_parent, exist_ok=True) + + return destination + + +async def _download_file(session: aiohttp.ClientSession, url: str, destination: str) -> None: + temp_file = f"{destination}.part-{uuid.uuid4().hex}" + try: + async with session.get(url, allow_redirects=True) as response: + response.raise_for_status() + with open(temp_file, "wb") as file_handle: + async for chunk in response.content.iter_chunked(DOWNLOAD_CHUNK_SIZE): + if chunk: + file_handle.write(chunk) + os.replace(temp_file, destination) + finally: + if os.path.exists(temp_file): + os.remove(temp_file) + + +def _normalize_model_entry(model_entry: object) -> tuple[str | None, str | None, str | None]: + if not isinstance(model_entry, dict): + return None, None, None + + model_name = str(model_entry.get("name", "")).strip() + model_directory = str(model_entry.get("directory", "")).strip() + model_url = str(model_entry.get("url", "")).strip() + return model_name, model_directory, model_url + + +def _is_http_url(url: str) -> bool: + parsed = urlparse(url) + return parsed.scheme in ("http", "https") and bool(parsed.netloc) \ No newline at end of file diff --git a/server.py b/server.py index 8882e43c4..c8f579548 100644 --- a/server.py +++ b/server.py @@ -202,7 +202,7 @@ class PromptServer(): mimetypes.add_type('image/webp', '.webp') self.user_manager = UserManager() - self.model_file_manager = ModelFileManager() + self.model_file_manager = ModelFileManager(self) self.custom_node_manager = CustomNodeManager() self.subgraph_manager = SubgraphManager() self.node_replace_manager = NodeReplaceManager() diff --git a/tests-unit/app_test/model_manager_test.py b/tests-unit/app_test/model_manager_test.py index ae59206f6..12a559fa7 100644 --- a/tests-unit/app_test/model_manager_test.py +++ b/tests-unit/app_test/model_manager_test.py @@ -12,9 +12,14 @@ pytestmark = ( pytest.mark.asyncio ) # This applies the asyncio mark to all test functions in the module +class DummyPromptServer: + def __init__(self): + self.client_session = None + @pytest.fixture def model_manager(): - return ModelFileManager() + prompt_server = DummyPromptServer() + return ModelFileManager(prompt_server) @pytest.fixture def app(model_manager): From f4c0e1d2694989a96bc872e9635cda07aeaded5e Mon Sep 17 00:00:00 2001 From: teddav Date: Tue, 17 Feb 2026 12:06:30 +0100 Subject: [PATCH 2/2] missing models download: add progress tracking --- app/model_manager.py | 207 ++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 196 insertions(+), 11 deletions(-) diff --git a/app/model_manager.py b/app/model_manager.py index ff145c1a6..40e5b39a7 100644 --- a/app/model_manager.py +++ b/app/model_manager.py @@ -10,6 +10,7 @@ import glob import comfy.utils import uuid from urllib.parse import urlparse +from typing import Awaitable, Callable import aiohttp from aiohttp import web from PIL import Image @@ -30,11 +31,21 @@ WHITELISTED_MODEL_URLS = { } DOWNLOAD_CHUNK_SIZE = 1024 * 1024 MAX_BULK_MODEL_DOWNLOADS = 200 +DownloadProgressCallback = Callable[[int], Awaitable[None]] +DownloadShouldCancel = Callable[[], bool] +DOWNLOAD_PROGRESS_MIN_INTERVAL = 0.25 +DOWNLOAD_PROGRESS_MIN_BYTES = 4 * 1024 * 1024 + + +class DownloadCancelledError(Exception): + pass + class ModelFileManager: def __init__(self, prompt_server) -> None: self.cache: dict[str, tuple[list[dict], dict[str, float], float]] = {} self.prompt_server = prompt_server + self._cancelled_missing_model_downloads: set[str] = set() def get_cache(self, key: str, default=None) -> tuple[list[dict], dict[str, float], float] | None: return self.cache.get(key, default) @@ -96,13 +107,10 @@ class ModelFileManager: @routes.post("/experiment/models/download_missing") async def download_missing_models(request: web.Request) -> web.Response: - print("download_missing_models") try: payload = await request.json() except Exception: return web.json_response({"error": "Invalid JSON body"}, status=400) - print("download_missing_models") - print(payload) models = payload.get("models") if not isinstance(models, list): @@ -113,9 +121,37 @@ class ModelFileManager: status=400, ) + target_client_id = str(payload.get("client_id", "")).strip() or None + batch_id = str(payload.get("batch_id", "")).strip() or uuid.uuid4().hex + + def emit_download_event( + *, + task_id: str, + model_name: str, + model_directory: str, + model_url: str, + status: str, + bytes_downloaded: int = 0, + error: str | None = None, + ) -> None: + message = { + "batch_id": batch_id, + "task_id": task_id, + "name": model_name, + "directory": model_directory, + "url": model_url, + "status": status, + "bytes_downloaded": bytes_downloaded + } + if error: + message["error"] = error + + self.prompt_server.send_sync("missing_model_download", message, target_client_id) + results = [] downloaded = 0 skipped = 0 + canceled = 0 failed = 0 session = self.prompt_server.client_session @@ -128,31 +164,51 @@ class ModelFileManager: try: for model_entry in models: model_name, model_directory, model_url = _normalize_model_entry(model_entry) + task_id = uuid.uuid4().hex + self._cancelled_missing_model_downloads.discard(task_id) if not model_name or not model_directory or not model_url: failed += 1 + error = "Each model must include non-empty name, directory, and url" results.append( { "name": model_name or "", "directory": model_directory or "", "url": model_url or "", "status": "failed", - "error": "Each model must include non-empty name, directory, and url", + "error": error, } ) + emit_download_event( + task_id=task_id, + model_name=model_name or "", + model_directory=model_directory or "", + model_url=model_url or "", + status="failed", + error=error, + ) continue if not _is_http_url(model_url): failed += 1 + error = "URL must be an absolute HTTP/HTTPS URL" results.append( { "name": model_name, "directory": model_directory, "url": model_url, "status": "failed", - "error": "URL must be an absolute HTTP/HTTPS URL", + "error": error, } ) + emit_download_event( + task_id=task_id, + model_name=model_name, + model_directory=model_directory, + model_url=model_url, + status="failed", + error=error, + ) continue allowed, reason = _is_model_download_allowed(model_name, model_url) @@ -167,21 +223,38 @@ class ModelFileManager: "error": reason, } ) + emit_download_event( + task_id=task_id, + model_name=model_name, + model_directory=model_directory, + model_url=model_url, + status="blocked", + error=reason, + ) continue try: destination = _resolve_download_destination(model_directory, model_name) except Exception as exc: failed += 1 + error = str(exc) results.append( { "name": model_name, "directory": model_directory, "url": model_url, "status": "failed", - "error": str(exc), + "error": error, } ) + emit_download_event( + task_id=task_id, + model_name=model_name, + model_directory=model_directory, + model_url=model_url, + status="failed", + error=error, + ) continue if os.path.exists(destination): @@ -194,11 +267,39 @@ class ModelFileManager: "status": "skipped_existing", } ) + emit_download_event( + task_id=task_id, + model_name=model_name, + model_directory=model_directory, + model_url=model_url, + status="skipped_existing", + bytes_downloaded=0 + ) continue try: - await _download_file(session, model_url, destination) + latest_downloaded = 0 + async def on_progress(bytes_downloaded: int) -> None: + nonlocal latest_downloaded + latest_downloaded = bytes_downloaded + emit_download_event( + task_id=task_id, + model_name=model_name, + model_directory=model_directory, + model_url=model_url, + status="running", + bytes_downloaded=bytes_downloaded + ) + + await _download_file( + session, + model_url, + destination, + progress_callback=on_progress, + should_cancel=lambda: task_id in self._cancelled_missing_model_downloads + ) downloaded += 1 + final_size = os.path.getsize(destination) if os.path.exists(destination) else latest_downloaded results.append( { "name": model_name, @@ -207,17 +308,56 @@ class ModelFileManager: "status": "downloaded", } ) + emit_download_event( + task_id=task_id, + model_name=model_name, + model_directory=model_directory, + model_url=model_url, + status="completed", + bytes_downloaded=final_size + ) + except DownloadCancelledError: + canceled += 1 + results.append( + { + "name": model_name, + "directory": model_directory, + "url": model_url, + "status": "canceled", + "error": "Download canceled", + } + ) + emit_download_event( + task_id=task_id, + model_name=model_name, + model_directory=model_directory, + model_url=model_url, + status="canceled", + bytes_downloaded=latest_downloaded, + error="Download canceled", + ) except Exception as exc: failed += 1 + error = str(exc) results.append( { "name": model_name, "directory": model_directory, "url": model_url, "status": "failed", - "error": str(exc), + "error": error, } ) + emit_download_event( + task_id=task_id, + model_name=model_name, + model_directory=model_directory, + model_url=model_url, + status="failed", + error=error + ) + finally: + self._cancelled_missing_model_downloads.discard(task_id) finally: if owns_session: await session.close() @@ -226,12 +366,27 @@ class ModelFileManager: { "downloaded": downloaded, "skipped": skipped, + "canceled": canceled, "failed": failed, "results": results, }, status=200, ) + @routes.post("/experiment/models/download_missing/cancel") + async def cancel_download_missing_model(request: web.Request) -> web.Response: + try: + payload = await request.json() + except Exception: + return web.json_response({"error": "Invalid JSON body"}, status=400) + + task_id = str(payload.get("task_id", "")).strip() + if not task_id: + return web.json_response({"error": "Field 'task_id' is required"}, status=400) + + self._cancelled_missing_model_downloads.add(task_id) + return web.json_response({"ok": True, "task_id": task_id}, status=200) + def get_model_file_list(self, folder_name: str): folder_name = map_legacy(folder_name) folders = folder_paths.folder_names_and_paths[folder_name] @@ -393,15 +548,45 @@ def _resolve_download_destination(directory: str, model_name: str) -> str: return destination -async def _download_file(session: aiohttp.ClientSession, url: str, destination: str) -> None: - temp_file = f"{destination}.part-{uuid.uuid4().hex}" +async def _download_file( + session: aiohttp.ClientSession, + url: str, + destination: str, + progress_callback: DownloadProgressCallback | None = None, + should_cancel: DownloadShouldCancel | None = None, + progress_min_interval: float = DOWNLOAD_PROGRESS_MIN_INTERVAL, + progress_min_bytes: int = DOWNLOAD_PROGRESS_MIN_BYTES, +) -> None: + temp_file = f"{destination}.{uuid.uuid4().hex}.temp" try: + if should_cancel is not None and should_cancel(): + raise DownloadCancelledError("Download canceled") async with session.get(url, allow_redirects=True) as response: response.raise_for_status() + bytes_downloaded = 0 + if progress_callback is not None: + await progress_callback(bytes_downloaded) + last_progress_emit_time = time.monotonic() + last_progress_emit_bytes = 0 with open(temp_file, "wb") as file_handle: async for chunk in response.content.iter_chunked(DOWNLOAD_CHUNK_SIZE): + if should_cancel is not None and should_cancel(): + raise DownloadCancelledError("Download canceled") if chunk: file_handle.write(chunk) + bytes_downloaded += len(chunk) + if progress_callback is not None: + now = time.monotonic() + should_emit = ( + bytes_downloaded - last_progress_emit_bytes >= progress_min_bytes + or now - last_progress_emit_time >= progress_min_interval + ) + if should_emit: + await progress_callback(bytes_downloaded) + last_progress_emit_time = now + last_progress_emit_bytes = bytes_downloaded + if progress_callback is not None and bytes_downloaded != last_progress_emit_bytes: + await progress_callback(bytes_downloaded) os.replace(temp_file, destination) finally: if os.path.exists(temp_file): @@ -420,4 +605,4 @@ def _normalize_model_entry(model_entry: object) -> tuple[str | None, str | None, def _is_http_url(url: str) -> bool: parsed = urlparse(url) - return parsed.scheme in ("http", "https") and bool(parsed.netloc) \ No newline at end of file + return parsed.scheme in ("http", "https") and bool(parsed.netloc)