diff --git a/app/model_manager.py b/app/model_manager.py index f124d1117..40e5b39a7 100644 --- a/app/model_manager.py +++ b/app/model_manager.py @@ -8,15 +8,44 @@ import logging import folder_paths import glob import comfy.utils +import uuid +from urllib.parse import urlparse +from typing import Awaitable, Callable +import aiohttp from aiohttp import web from PIL import Image from io import BytesIO from folder_paths import map_legacy, filter_files_extensions, filter_files_content_types +ALLOWED_MODEL_SOURCES = ( + "https://civitai.com/", + "https://huggingface.co/", + "http://localhost:", +) +ALLOWED_MODEL_SUFFIXES = (".safetensors", ".sft") +WHITELISTED_MODEL_URLS = { + "https://huggingface.co/stabilityai/stable-zero123/resolve/main/stable_zero123.ckpt", + "https://huggingface.co/TencentARC/T2I-Adapter/resolve/main/models/t2iadapter_depth_sd14v1.pth?download=true", + "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth", +} +DOWNLOAD_CHUNK_SIZE = 1024 * 1024 +MAX_BULK_MODEL_DOWNLOADS = 200 +DownloadProgressCallback = Callable[[int], Awaitable[None]] +DownloadShouldCancel = Callable[[], bool] +DOWNLOAD_PROGRESS_MIN_INTERVAL = 0.25 +DOWNLOAD_PROGRESS_MIN_BYTES = 4 * 1024 * 1024 + + +class DownloadCancelledError(Exception): + pass + + class ModelFileManager: - def __init__(self) -> None: + def __init__(self, prompt_server) -> None: self.cache: dict[str, tuple[list[dict], dict[str, float], float]] = {} + self.prompt_server = prompt_server + self._cancelled_missing_model_downloads: set[str] = set() def get_cache(self, key: str, default=None) -> tuple[list[dict], dict[str, float], float] | None: return self.cache.get(key, default) @@ -75,6 +104,288 @@ class ModelFileManager: return web.Response(body=img_bytes.getvalue(), content_type="image/webp") except: return web.Response(status=404) + + @routes.post("/experiment/models/download_missing") + async def download_missing_models(request: web.Request) -> web.Response: + try: + payload = await request.json() + except Exception: + return web.json_response({"error": "Invalid JSON body"}, status=400) + + models = payload.get("models") + if not isinstance(models, list): + return web.json_response({"error": "Field 'models' must be a list"}, status=400) + if len(models) > MAX_BULK_MODEL_DOWNLOADS: + return web.json_response( + {"error": f"Maximum of {MAX_BULK_MODEL_DOWNLOADS} models allowed per request"}, + status=400, + ) + + target_client_id = str(payload.get("client_id", "")).strip() or None + batch_id = str(payload.get("batch_id", "")).strip() or uuid.uuid4().hex + + def emit_download_event( + *, + task_id: str, + model_name: str, + model_directory: str, + model_url: str, + status: str, + bytes_downloaded: int = 0, + error: str | None = None, + ) -> None: + message = { + "batch_id": batch_id, + "task_id": task_id, + "name": model_name, + "directory": model_directory, + "url": model_url, + "status": status, + "bytes_downloaded": bytes_downloaded + } + if error: + message["error"] = error + + self.prompt_server.send_sync("missing_model_download", message, target_client_id) + + results = [] + downloaded = 0 + skipped = 0 + canceled = 0 + failed = 0 + + session = self.prompt_server.client_session + owns_session = False + if session is None or session.closed: + timeout = aiohttp.ClientTimeout(total=None) + session = aiohttp.ClientSession(timeout=timeout) + owns_session = True + + try: + for model_entry in models: + model_name, model_directory, model_url = _normalize_model_entry(model_entry) + task_id = uuid.uuid4().hex + self._cancelled_missing_model_downloads.discard(task_id) + + if not model_name or not model_directory or not model_url: + failed += 1 + error = "Each model must include non-empty name, directory, and url" + results.append( + { + "name": model_name or "", + "directory": model_directory or "", + "url": model_url or "", + "status": "failed", + "error": error, + } + ) + emit_download_event( + task_id=task_id, + model_name=model_name or "", + model_directory=model_directory or "", + model_url=model_url or "", + status="failed", + error=error, + ) + continue + + if not _is_http_url(model_url): + failed += 1 + error = "URL must be an absolute HTTP/HTTPS URL" + results.append( + { + "name": model_name, + "directory": model_directory, + "url": model_url, + "status": "failed", + "error": error, + } + ) + emit_download_event( + task_id=task_id, + model_name=model_name, + model_directory=model_directory, + model_url=model_url, + status="failed", + error=error, + ) + continue + + allowed, reason = _is_model_download_allowed(model_name, model_url) + if not allowed: + failed += 1 + results.append( + { + "name": model_name, + "directory": model_directory, + "url": model_url, + "status": "blocked", + "error": reason, + } + ) + emit_download_event( + task_id=task_id, + model_name=model_name, + model_directory=model_directory, + model_url=model_url, + status="blocked", + error=reason, + ) + continue + + try: + destination = _resolve_download_destination(model_directory, model_name) + except Exception as exc: + failed += 1 + error = str(exc) + results.append( + { + "name": model_name, + "directory": model_directory, + "url": model_url, + "status": "failed", + "error": error, + } + ) + emit_download_event( + task_id=task_id, + model_name=model_name, + model_directory=model_directory, + model_url=model_url, + status="failed", + error=error, + ) + continue + + if os.path.exists(destination): + skipped += 1 + results.append( + { + "name": model_name, + "directory": model_directory, + "url": model_url, + "status": "skipped_existing", + } + ) + emit_download_event( + task_id=task_id, + model_name=model_name, + model_directory=model_directory, + model_url=model_url, + status="skipped_existing", + bytes_downloaded=0 + ) + continue + + try: + latest_downloaded = 0 + async def on_progress(bytes_downloaded: int) -> None: + nonlocal latest_downloaded + latest_downloaded = bytes_downloaded + emit_download_event( + task_id=task_id, + model_name=model_name, + model_directory=model_directory, + model_url=model_url, + status="running", + bytes_downloaded=bytes_downloaded + ) + + await _download_file( + session, + model_url, + destination, + progress_callback=on_progress, + should_cancel=lambda: task_id in self._cancelled_missing_model_downloads + ) + downloaded += 1 + final_size = os.path.getsize(destination) if os.path.exists(destination) else latest_downloaded + results.append( + { + "name": model_name, + "directory": model_directory, + "url": model_url, + "status": "downloaded", + } + ) + emit_download_event( + task_id=task_id, + model_name=model_name, + model_directory=model_directory, + model_url=model_url, + status="completed", + bytes_downloaded=final_size + ) + except DownloadCancelledError: + canceled += 1 + results.append( + { + "name": model_name, + "directory": model_directory, + "url": model_url, + "status": "canceled", + "error": "Download canceled", + } + ) + emit_download_event( + task_id=task_id, + model_name=model_name, + model_directory=model_directory, + model_url=model_url, + status="canceled", + bytes_downloaded=latest_downloaded, + error="Download canceled", + ) + except Exception as exc: + failed += 1 + error = str(exc) + results.append( + { + "name": model_name, + "directory": model_directory, + "url": model_url, + "status": "failed", + "error": error, + } + ) + emit_download_event( + task_id=task_id, + model_name=model_name, + model_directory=model_directory, + model_url=model_url, + status="failed", + error=error + ) + finally: + self._cancelled_missing_model_downloads.discard(task_id) + finally: + if owns_session: + await session.close() + + return web.json_response( + { + "downloaded": downloaded, + "skipped": skipped, + "canceled": canceled, + "failed": failed, + "results": results, + }, + status=200, + ) + + @routes.post("/experiment/models/download_missing/cancel") + async def cancel_download_missing_model(request: web.Request) -> web.Response: + try: + payload = await request.json() + except Exception: + return web.json_response({"error": "Invalid JSON body"}, status=400) + + task_id = str(payload.get("task_id", "")).strip() + if not task_id: + return web.json_response({"error": "Field 'task_id' is required"}, status=400) + + self._cancelled_missing_model_downloads.add(task_id) + return web.json_response({"ok": True, "task_id": task_id}, status=200) def get_model_file_list(self, folder_name: str): folder_name = map_legacy(folder_name) @@ -193,3 +504,105 @@ class ModelFileManager: def __exit__(self, exc_type, exc_value, traceback): self.clear_cache() + +def _is_model_download_allowed(model_name: str, model_url: str) -> tuple[bool, str | None]: + if model_url in WHITELISTED_MODEL_URLS: + return True, None + + if not any(model_url.startswith(source) for source in ALLOWED_MODEL_SOURCES): + return ( + False, + f"Download not allowed from source '{model_url}'.", + ) + + if not any(model_name.endswith(suffix) for suffix in ALLOWED_MODEL_SUFFIXES): + return ( + False, + f"Only allowed suffixes are: {', '.join(ALLOWED_MODEL_SUFFIXES)}", + ) + + return True, None + + +def _resolve_download_destination(directory: str, model_name: str) -> str: + if directory not in folder_paths.folder_names_and_paths: + raise ValueError(f"Unknown model directory '{directory}'") + + model_paths = folder_paths.folder_names_and_paths[directory][0] + if not model_paths: + raise ValueError(f"No filesystem paths configured for '{directory}'") + + base_path = os.path.abspath(model_paths[0]) + normalized_name = os.path.normpath(model_name).lstrip("/\\") + if not normalized_name or normalized_name == ".": + raise ValueError("Model name cannot be empty") + + destination = os.path.abspath(os.path.join(base_path, normalized_name)) + if os.path.commonpath((base_path, destination)) != base_path: + raise ValueError("Model path escapes configured model directory") + + destination_parent = os.path.dirname(destination) + if destination_parent: + os.makedirs(destination_parent, exist_ok=True) + + return destination + + +async def _download_file( + session: aiohttp.ClientSession, + url: str, + destination: str, + progress_callback: DownloadProgressCallback | None = None, + should_cancel: DownloadShouldCancel | None = None, + progress_min_interval: float = DOWNLOAD_PROGRESS_MIN_INTERVAL, + progress_min_bytes: int = DOWNLOAD_PROGRESS_MIN_BYTES, +) -> None: + temp_file = f"{destination}.{uuid.uuid4().hex}.temp" + try: + if should_cancel is not None and should_cancel(): + raise DownloadCancelledError("Download canceled") + async with session.get(url, allow_redirects=True) as response: + response.raise_for_status() + bytes_downloaded = 0 + if progress_callback is not None: + await progress_callback(bytes_downloaded) + last_progress_emit_time = time.monotonic() + last_progress_emit_bytes = 0 + with open(temp_file, "wb") as file_handle: + async for chunk in response.content.iter_chunked(DOWNLOAD_CHUNK_SIZE): + if should_cancel is not None and should_cancel(): + raise DownloadCancelledError("Download canceled") + if chunk: + file_handle.write(chunk) + bytes_downloaded += len(chunk) + if progress_callback is not None: + now = time.monotonic() + should_emit = ( + bytes_downloaded - last_progress_emit_bytes >= progress_min_bytes + or now - last_progress_emit_time >= progress_min_interval + ) + if should_emit: + await progress_callback(bytes_downloaded) + last_progress_emit_time = now + last_progress_emit_bytes = bytes_downloaded + if progress_callback is not None and bytes_downloaded != last_progress_emit_bytes: + await progress_callback(bytes_downloaded) + os.replace(temp_file, destination) + finally: + if os.path.exists(temp_file): + os.remove(temp_file) + + +def _normalize_model_entry(model_entry: object) -> tuple[str | None, str | None, str | None]: + if not isinstance(model_entry, dict): + return None, None, None + + model_name = str(model_entry.get("name", "")).strip() + model_directory = str(model_entry.get("directory", "")).strip() + model_url = str(model_entry.get("url", "")).strip() + return model_name, model_directory, model_url + + +def _is_http_url(url: str) -> bool: + parsed = urlparse(url) + return parsed.scheme in ("http", "https") and bool(parsed.netloc) diff --git a/server.py b/server.py index 76904ebc9..593ef7413 100644 --- a/server.py +++ b/server.py @@ -198,7 +198,7 @@ class PromptServer(): PromptServer.instance = self self.user_manager = UserManager() - self.model_file_manager = ModelFileManager() + self.model_file_manager = ModelFileManager(self) self.custom_node_manager = CustomNodeManager() self.subgraph_manager = SubgraphManager() self.node_replace_manager = NodeReplaceManager() diff --git a/tests-unit/app_test/model_manager_test.py b/tests-unit/app_test/model_manager_test.py index ae59206f6..12a559fa7 100644 --- a/tests-unit/app_test/model_manager_test.py +++ b/tests-unit/app_test/model_manager_test.py @@ -12,9 +12,14 @@ pytestmark = ( pytest.mark.asyncio ) # This applies the asyncio mark to all test functions in the module +class DummyPromptServer: + def __init__(self): + self.client_session = None + @pytest.fixture def model_manager(): - return ModelFileManager() + prompt_server = DummyPromptServer() + return ModelFileManager(prompt_server) @pytest.fixture def app(model_manager):