# ComfyUI/app/model_manager.py
from __future__ import annotations
import os
import base64
import json
import time
import logging
import folder_paths
import glob
import comfy.utils
import uuid
from urllib.parse import urlparse
import aiohttp
from aiohttp import web
from PIL import Image
from io import BytesIO
from folder_paths import map_legacy, filter_files_extensions, filter_files_content_types
# URL prefixes from which model downloads are permitted, unless the exact URL
# appears in WHITELISTED_MODEL_URLS below.
ALLOWED_MODEL_SOURCES = (
    "https://civitai.com/",
    "https://huggingface.co/",
    "http://localhost:",
)
# Only these file suffixes are accepted for downloaded model names.
ALLOWED_MODEL_SUFFIXES = (".safetensors", ".sft")
# Exact URLs that bypass both the source-prefix and suffix restrictions.
WHITELISTED_MODEL_URLS = {
    "https://huggingface.co/stabilityai/stable-zero123/resolve/main/stable_zero123.ckpt",
    "https://huggingface.co/TencentARC/T2I-Adapter/resolve/main/models/t2iadapter_depth_sd14v1.pth?download=true",
    "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
}
# Streaming chunk size used while downloading (1 MiB).
DOWNLOAD_CHUNK_SIZE = 1024 * 1024
# Maximum number of model entries accepted in one bulk-download request.
MAX_BULK_MODEL_DOWNLOADS = 200
class ModelFileManager:
def __init__(self, prompt_server) -> None:
self.cache: dict[str, tuple[list[dict], dict[str, float], float]] = {}
self.prompt_server = prompt_server
def get_cache(self, key: str, default=None) -> tuple[list[dict], dict[str, float], float] | None:
return self.cache.get(key, default)
def set_cache(self, key: str, value: tuple[list[dict], dict[str, float], float]):
self.cache[key] = value
def clear_cache(self):
self.cache.clear()
def add_routes(self, routes):
# NOTE: This is an experiment to replace `/models`
@routes.get("/experiment/models")
async def get_model_folders(request):
model_types = list(folder_paths.folder_names_and_paths.keys())
folder_black_list = ["configs", "custom_nodes"]
output_folders: list[dict] = []
for folder in model_types:
if folder in folder_black_list:
continue
output_folders.append({"name": folder, "folders": folder_paths.get_folder_paths(folder)})
return web.json_response(output_folders)
# NOTE: This is an experiment to replace `/models/{folder}`
@routes.get("/experiment/models/{folder}")
async def get_all_models(request):
folder = request.match_info.get("folder", None)
if folder not in folder_paths.folder_names_and_paths:
return web.Response(status=404)
files = self.get_model_file_list(folder)
return web.json_response(files)
@routes.get("/experiment/models/preview/{folder}/{path_index}/{filename:.*}")
async def get_model_preview(request):
folder_name = request.match_info.get("folder", None)
path_index = int(request.match_info.get("path_index", None))
filename = request.match_info.get("filename", None)
if folder_name not in folder_paths.folder_names_and_paths:
return web.Response(status=404)
folders = folder_paths.folder_names_and_paths[folder_name]
folder = folders[0][path_index]
full_filename = os.path.join(folder, filename)
previews = self.get_model_previews(full_filename)
default_preview = previews[0] if len(previews) > 0 else None
if default_preview is None or (isinstance(default_preview, str) and not os.path.isfile(default_preview)):
return web.Response(status=404)
try:
with Image.open(default_preview) as img:
img_bytes = BytesIO()
img.save(img_bytes, format="WEBP")
img_bytes.seek(0)
return web.Response(body=img_bytes.getvalue(), content_type="image/webp")
except:
return web.Response(status=404)
@routes.post("/experiment/models/download_missing")
async def download_missing_models(request: web.Request) -> web.Response:
print("download_missing_models")
try:
payload = await request.json()
except Exception:
return web.json_response({"error": "Invalid JSON body"}, status=400)
print("download_missing_models")
print(payload)
models = payload.get("models")
if not isinstance(models, list):
return web.json_response({"error": "Field 'models' must be a list"}, status=400)
if len(models) > MAX_BULK_MODEL_DOWNLOADS:
return web.json_response(
{"error": f"Maximum of {MAX_BULK_MODEL_DOWNLOADS} models allowed per request"},
status=400,
)
results = []
downloaded = 0
skipped = 0
failed = 0
session = self.prompt_server.client_session
owns_session = False
if session is None or session.closed:
timeout = aiohttp.ClientTimeout(total=None)
session = aiohttp.ClientSession(timeout=timeout)
owns_session = True
try:
for model_entry in models:
model_name, model_directory, model_url = _normalize_model_entry(model_entry)
if not model_name or not model_directory or not model_url:
failed += 1
results.append(
{
"name": model_name or "",
"directory": model_directory or "",
"url": model_url or "",
"status": "failed",
"error": "Each model must include non-empty name, directory, and url",
}
)
continue
if not _is_http_url(model_url):
failed += 1
results.append(
{
"name": model_name,
"directory": model_directory,
"url": model_url,
"status": "failed",
"error": "URL must be an absolute HTTP/HTTPS URL",
}
)
continue
allowed, reason = _is_model_download_allowed(model_name, model_url)
if not allowed:
failed += 1
results.append(
{
"name": model_name,
"directory": model_directory,
"url": model_url,
"status": "blocked",
"error": reason,
}
)
continue
try:
destination = _resolve_download_destination(model_directory, model_name)
except Exception as exc:
failed += 1
results.append(
{
"name": model_name,
"directory": model_directory,
"url": model_url,
"status": "failed",
"error": str(exc),
}
)
continue
if os.path.exists(destination):
skipped += 1
results.append(
{
"name": model_name,
"directory": model_directory,
"url": model_url,
"status": "skipped_existing",
}
)
continue
try:
await _download_file(session, model_url, destination)
downloaded += 1
results.append(
{
"name": model_name,
"directory": model_directory,
"url": model_url,
"status": "downloaded",
}
)
except Exception as exc:
failed += 1
results.append(
{
"name": model_name,
"directory": model_directory,
"url": model_url,
"status": "failed",
"error": str(exc),
}
)
finally:
if owns_session:
await session.close()
return web.json_response(
{
"downloaded": downloaded,
"skipped": skipped,
"failed": failed,
"results": results,
},
status=200,
)
def get_model_file_list(self, folder_name: str):
folder_name = map_legacy(folder_name)
folders = folder_paths.folder_names_and_paths[folder_name]
output_list: list[dict] = []
for index, folder in enumerate(folders[0]):
if not os.path.isdir(folder):
continue
out = self.cache_model_file_list_(folder)
if out is None:
out = self.recursive_search_models_(folder, index)
self.set_cache(folder, out)
output_list.extend(out[0])
return output_list
def cache_model_file_list_(self, folder: str):
model_file_list_cache = self.get_cache(folder)
if model_file_list_cache is None:
return None
if not os.path.isdir(folder):
return None
if os.path.getmtime(folder) != model_file_list_cache[1]:
return None
for x in model_file_list_cache[1]:
time_modified = model_file_list_cache[1][x]
folder = x
if os.path.getmtime(folder) != time_modified:
return None
return model_file_list_cache
def recursive_search_models_(self, directory: str, pathIndex: int) -> tuple[list[str], dict[str, float], float]:
if not os.path.isdir(directory):
return [], {}, time.perf_counter()
excluded_dir_names = [".git"]
# TODO use settings
include_hidden_files = False
result: list[str] = []
dirs: dict[str, float] = {}
for dirpath, subdirs, filenames in os.walk(directory, followlinks=True, topdown=True):
subdirs[:] = [d for d in subdirs if d not in excluded_dir_names]
if not include_hidden_files:
subdirs[:] = [d for d in subdirs if not d.startswith(".")]
filenames = [f for f in filenames if not f.startswith(".")]
filenames = filter_files_extensions(filenames, folder_paths.supported_pt_extensions)
for file_name in filenames:
try:
full_path = os.path.join(dirpath, file_name)
relative_path = os.path.relpath(full_path, directory)
# Get file metadata
file_info = {
"name": relative_path,
"pathIndex": pathIndex,
"modified": os.path.getmtime(full_path), # Add modification time
"created": os.path.getctime(full_path), # Add creation time
"size": os.path.getsize(full_path) # Add file size
}
result.append(file_info)
except Exception as e:
logging.warning(f"Warning: Unable to access {file_name}. Error: {e}. Skipping this file.")
continue
for d in subdirs:
path: str = os.path.join(dirpath, d)
try:
dirs[path] = os.path.getmtime(path)
except FileNotFoundError:
logging.warning(f"Warning: Unable to access {path}. Skipping this path.")
continue
return result, dirs, time.perf_counter()
def get_model_previews(self, filepath: str) -> list[str | BytesIO]:
dirname = os.path.dirname(filepath)
if not os.path.exists(dirname):
return []
basename = os.path.splitext(filepath)[0]
match_files = glob.glob(f"{basename}.*", recursive=False)
image_files = filter_files_content_types(match_files, "image")
safetensors_file = next(filter(lambda x: x.endswith(".safetensors"), match_files), None)
safetensors_metadata = {}
result: list[str | BytesIO] = []
for filename in image_files:
_basename = os.path.splitext(filename)[0]
if _basename == basename:
result.append(filename)
if _basename == f"{basename}.preview":
result.append(filename)
if safetensors_file:
safetensors_filepath = os.path.join(dirname, safetensors_file)
header = comfy.utils.safetensors_header(safetensors_filepath, max_size=8*1024*1024)
if header:
safetensors_metadata = json.loads(header)
safetensors_images = safetensors_metadata.get("__metadata__", {}).get("ssmd_cover_images", None)
if safetensors_images:
safetensors_images = json.loads(safetensors_images)
for image in safetensors_images:
result.append(BytesIO(base64.b64decode(image)))
return result
def __exit__(self, exc_type, exc_value, traceback):
self.clear_cache()
def _is_model_download_allowed(model_name: str, model_url: str) -> tuple[bool, str | None]:
    """Decide whether a (name, url) pair may be downloaded.

    Returns ``(True, None)`` when allowed, otherwise ``(False, reason)``.
    Exact whitelisted URLs bypass both the source and suffix checks.
    """
    if model_url in WHITELISTED_MODEL_URLS:
        return True, None
    # startswith/endswith accept a tuple of alternatives directly.
    if not model_url.startswith(ALLOWED_MODEL_SOURCES):
        return False, f"Download not allowed from source '{model_url}'."
    if not model_name.endswith(ALLOWED_MODEL_SUFFIXES):
        return False, f"Only allowed suffixes are: {', '.join(ALLOWED_MODEL_SUFFIXES)}"
    return True, None
def _resolve_download_destination(directory: str, model_name: str) -> str:
    """Map a (model directory key, model name) pair to an absolute file path.

    Uses the first configured filesystem path for `directory`, creates missing
    parent folders, and raises ValueError for unknown directory keys, empty
    names, or names that would escape the configured model root.
    """
    if directory not in folder_paths.folder_names_and_paths:
        raise ValueError(f"Unknown model directory '{directory}'")

    configured_paths = folder_paths.folder_names_and_paths[directory][0]
    if not configured_paths:
        raise ValueError(f"No filesystem paths configured for '{directory}'")

    root = os.path.abspath(configured_paths[0])
    relative_name = os.path.normpath(model_name).lstrip("/\\")
    if relative_name in ("", "."):
        raise ValueError("Model name cannot be empty")

    target = os.path.abspath(os.path.join(root, relative_name))
    # Reject any name whose resolved path leaves the model root (traversal).
    if os.path.commonpath((root, target)) != root:
        raise ValueError("Model path escapes configured model directory")

    parent = os.path.dirname(target)
    if parent:
        os.makedirs(parent, exist_ok=True)
    return target
async def _download_file(session: aiohttp.ClientSession, url: str, destination: str) -> None:
    """Stream `url` into `destination` atomically.

    Data is written to a uniquely named temp file which is renamed into place
    only on success, so an interrupted download never leaves a partial model;
    the temp file is removed on any failure.
    """
    partial_path = f"{destination}.part-{uuid.uuid4().hex}"
    try:
        async with session.get(url, allow_redirects=True) as response:
            response.raise_for_status()
            with open(partial_path, "wb") as out:
                async for piece in response.content.iter_chunked(DOWNLOAD_CHUNK_SIZE):
                    if not piece:
                        continue
                    out.write(piece)
        os.replace(partial_path, destination)
    finally:
        if os.path.exists(partial_path):
            os.remove(partial_path)
def _normalize_model_entry(model_entry: object) -> tuple[str | None, str | None, str | None]:
if not isinstance(model_entry, dict):
return None, None, None
model_name = str(model_entry.get("name", "")).strip()
model_directory = str(model_entry.get("directory", "")).strip()
model_url = str(model_entry.get("url", "")).strip()
return model_name, model_directory, model_url
def _is_http_url(url: str) -> bool:
parsed = urlparse(url)
return parsed.scheme in ("http", "https") and bool(parsed.netloc)