from __future__ import annotations import os import base64 import json import time import logging import folder_paths import glob import comfy.utils import uuid from urllib.parse import urlparse import aiohttp from aiohttp import web from PIL import Image from io import BytesIO from folder_paths import map_legacy, filter_files_extensions, filter_files_content_types ALLOWED_MODEL_SOURCES = ( "https://civitai.com/", "https://huggingface.co/", "http://localhost:", ) ALLOWED_MODEL_SUFFIXES = (".safetensors", ".sft") WHITELISTED_MODEL_URLS = { "https://huggingface.co/stabilityai/stable-zero123/resolve/main/stable_zero123.ckpt", "https://huggingface.co/TencentARC/T2I-Adapter/resolve/main/models/t2iadapter_depth_sd14v1.pth?download=true", "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth", } DOWNLOAD_CHUNK_SIZE = 1024 * 1024 MAX_BULK_MODEL_DOWNLOADS = 200 class ModelFileManager: def __init__(self, prompt_server) -> None: self.cache: dict[str, tuple[list[dict], dict[str, float], float]] = {} self.prompt_server = prompt_server def get_cache(self, key: str, default=None) -> tuple[list[dict], dict[str, float], float] | None: return self.cache.get(key, default) def set_cache(self, key: str, value: tuple[list[dict], dict[str, float], float]): self.cache[key] = value def clear_cache(self): self.cache.clear() def add_routes(self, routes): # NOTE: This is an experiment to replace `/models` @routes.get("/experiment/models") async def get_model_folders(request): model_types = list(folder_paths.folder_names_and_paths.keys()) folder_black_list = ["configs", "custom_nodes"] output_folders: list[dict] = [] for folder in model_types: if folder in folder_black_list: continue output_folders.append({"name": folder, "folders": folder_paths.get_folder_paths(folder)}) return web.json_response(output_folders) # NOTE: This is an experiment to replace `/models/{folder}` @routes.get("/experiment/models/{folder}") async def get_all_models(request): folder = request.match_info.get("folder", None) if folder not in folder_paths.folder_names_and_paths: return web.Response(status=404) files = self.get_model_file_list(folder) return web.json_response(files) @routes.get("/experiment/models/preview/{folder}/{path_index}/{filename:.*}") async def get_model_preview(request): folder_name = request.match_info.get("folder", None) path_index = int(request.match_info.get("path_index", None)) filename = request.match_info.get("filename", None) if folder_name not in folder_paths.folder_names_and_paths: return web.Response(status=404) folders = folder_paths.folder_names_and_paths[folder_name] folder = folders[0][path_index] full_filename = os.path.join(folder, filename) previews = self.get_model_previews(full_filename) default_preview = previews[0] if len(previews) > 0 else None if default_preview is None or (isinstance(default_preview, str) and not os.path.isfile(default_preview)): return web.Response(status=404) try: with Image.open(default_preview) as img: img_bytes = BytesIO() img.save(img_bytes, format="WEBP") img_bytes.seek(0) return web.Response(body=img_bytes.getvalue(), content_type="image/webp") except: return web.Response(status=404) @routes.post("/experiment/models/download_missing") async def download_missing_models(request: web.Request) -> web.Response: print("download_missing_models") try: payload = await request.json() except Exception: return web.json_response({"error": "Invalid JSON body"}, status=400) print("download_missing_models") print(payload) models = payload.get("models") if not isinstance(models, list): return web.json_response({"error": "Field 'models' must be a list"}, status=400) if len(models) > MAX_BULK_MODEL_DOWNLOADS: return web.json_response( {"error": f"Maximum of {MAX_BULK_MODEL_DOWNLOADS} models allowed per request"}, status=400, ) results = [] downloaded = 0 skipped = 0 failed = 0 session = self.prompt_server.client_session owns_session = False if session is None or session.closed: timeout = aiohttp.ClientTimeout(total=None) session = aiohttp.ClientSession(timeout=timeout) owns_session = True try: for model_entry in models: model_name, model_directory, model_url = _normalize_model_entry(model_entry) if not model_name or not model_directory or not model_url: failed += 1 results.append( { "name": model_name or "", "directory": model_directory or "", "url": model_url or "", "status": "failed", "error": "Each model must include non-empty name, directory, and url", } ) continue if not _is_http_url(model_url): failed += 1 results.append( { "name": model_name, "directory": model_directory, "url": model_url, "status": "failed", "error": "URL must be an absolute HTTP/HTTPS URL", } ) continue allowed, reason = _is_model_download_allowed(model_name, model_url) if not allowed: failed += 1 results.append( { "name": model_name, "directory": model_directory, "url": model_url, "status": "blocked", "error": reason, } ) continue try: destination = _resolve_download_destination(model_directory, model_name) except Exception as exc: failed += 1 results.append( { "name": model_name, "directory": model_directory, "url": model_url, "status": "failed", "error": str(exc), } ) continue if os.path.exists(destination): skipped += 1 results.append( { "name": model_name, "directory": model_directory, "url": model_url, "status": "skipped_existing", } ) continue try: await _download_file(session, model_url, destination) downloaded += 1 results.append( { "name": model_name, "directory": model_directory, "url": model_url, "status": "downloaded", } ) except Exception as exc: failed += 1 results.append( { "name": model_name, "directory": model_directory, "url": model_url, "status": "failed", "error": str(exc), } ) finally: if owns_session: await session.close() return web.json_response( { "downloaded": downloaded, "skipped": skipped, "failed": failed, "results": results, }, status=200, ) def get_model_file_list(self, folder_name: str): folder_name = map_legacy(folder_name) folders = folder_paths.folder_names_and_paths[folder_name] output_list: list[dict] = [] for index, folder in enumerate(folders[0]): if not os.path.isdir(folder): continue out = self.cache_model_file_list_(folder) if out is None: out = self.recursive_search_models_(folder, index) self.set_cache(folder, out) output_list.extend(out[0]) return output_list def cache_model_file_list_(self, folder: str): model_file_list_cache = self.get_cache(folder) if model_file_list_cache is None: return None if not os.path.isdir(folder): return None if os.path.getmtime(folder) != model_file_list_cache[1]: return None for x in model_file_list_cache[1]: time_modified = model_file_list_cache[1][x] folder = x if os.path.getmtime(folder) != time_modified: return None return model_file_list_cache def recursive_search_models_(self, directory: str, pathIndex: int) -> tuple[list[str], dict[str, float], float]: if not os.path.isdir(directory): return [], {}, time.perf_counter() excluded_dir_names = [".git"] # TODO use settings include_hidden_files = False result: list[str] = [] dirs: dict[str, float] = {} for dirpath, subdirs, filenames in os.walk(directory, followlinks=True, topdown=True): subdirs[:] = [d for d in subdirs if d not in excluded_dir_names] if not include_hidden_files: subdirs[:] = [d for d in subdirs if not d.startswith(".")] filenames = [f for f in filenames if not f.startswith(".")] filenames = filter_files_extensions(filenames, folder_paths.supported_pt_extensions) for file_name in filenames: try: full_path = os.path.join(dirpath, file_name) relative_path = os.path.relpath(full_path, directory) # Get file metadata file_info = { "name": relative_path, "pathIndex": pathIndex, "modified": os.path.getmtime(full_path), # Add modification time "created": os.path.getctime(full_path), # Add creation time "size": os.path.getsize(full_path) # Add file size } result.append(file_info) except Exception as e: logging.warning(f"Warning: Unable to access {file_name}. Error: {e}. Skipping this file.") continue for d in subdirs: path: str = os.path.join(dirpath, d) try: dirs[path] = os.path.getmtime(path) except FileNotFoundError: logging.warning(f"Warning: Unable to access {path}. Skipping this path.") continue return result, dirs, time.perf_counter() def get_model_previews(self, filepath: str) -> list[str | BytesIO]: dirname = os.path.dirname(filepath) if not os.path.exists(dirname): return [] basename = os.path.splitext(filepath)[0] match_files = glob.glob(f"{basename}.*", recursive=False) image_files = filter_files_content_types(match_files, "image") safetensors_file = next(filter(lambda x: x.endswith(".safetensors"), match_files), None) safetensors_metadata = {} result: list[str | BytesIO] = [] for filename in image_files: _basename = os.path.splitext(filename)[0] if _basename == basename: result.append(filename) if _basename == f"{basename}.preview": result.append(filename) if safetensors_file: safetensors_filepath = os.path.join(dirname, safetensors_file) header = comfy.utils.safetensors_header(safetensors_filepath, max_size=8*1024*1024) if header: safetensors_metadata = json.loads(header) safetensors_images = safetensors_metadata.get("__metadata__", {}).get("ssmd_cover_images", None) if safetensors_images: safetensors_images = json.loads(safetensors_images) for image in safetensors_images: result.append(BytesIO(base64.b64decode(image))) return result def __exit__(self, exc_type, exc_value, traceback): self.clear_cache() def _is_model_download_allowed(model_name: str, model_url: str) -> tuple[bool, str | None]: if model_url in WHITELISTED_MODEL_URLS: return True, None if not any(model_url.startswith(source) for source in ALLOWED_MODEL_SOURCES): return ( False, f"Download not allowed from source '{model_url}'.", ) if not any(model_name.endswith(suffix) for suffix in ALLOWED_MODEL_SUFFIXES): return ( False, f"Only allowed suffixes are: {', '.join(ALLOWED_MODEL_SUFFIXES)}", ) return True, None def _resolve_download_destination(directory: str, model_name: str) -> str: if directory not in folder_paths.folder_names_and_paths: raise ValueError(f"Unknown model directory '{directory}'") model_paths = folder_paths.folder_names_and_paths[directory][0] if not model_paths: raise ValueError(f"No filesystem paths configured for '{directory}'") base_path = os.path.abspath(model_paths[0]) normalized_name = os.path.normpath(model_name).lstrip("/\\") if not normalized_name or normalized_name == ".": raise ValueError("Model name cannot be empty") destination = os.path.abspath(os.path.join(base_path, normalized_name)) if os.path.commonpath((base_path, destination)) != base_path: raise ValueError("Model path escapes configured model directory") destination_parent = os.path.dirname(destination) if destination_parent: os.makedirs(destination_parent, exist_ok=True) return destination async def _download_file(session: aiohttp.ClientSession, url: str, destination: str) -> None: temp_file = f"{destination}.part-{uuid.uuid4().hex}" try: async with session.get(url, allow_redirects=True) as response: response.raise_for_status() with open(temp_file, "wb") as file_handle: async for chunk in response.content.iter_chunked(DOWNLOAD_CHUNK_SIZE): if chunk: file_handle.write(chunk) os.replace(temp_file, destination) finally: if os.path.exists(temp_file): os.remove(temp_file) def _normalize_model_entry(model_entry: object) -> tuple[str | None, str | None, str | None]: if not isinstance(model_entry, dict): return None, None, None model_name = str(model_entry.get("name", "")).strip() model_directory = str(model_entry.get("directory", "")).strip() model_url = str(model_entry.get("url", "")).strip() return model_name, model_directory, model_url def _is_http_url(url: str) -> bool: parsed = urlparse(url) return parsed.scheme in ("http", "https") and bool(parsed.netloc)