ComfyUI/app/model_manager.py

333 lines
14 KiB
Python

import os
import base64
import json
import time
import logging
import asyncio
import requests
from threading import Lock
from tqdm.auto import tqdm
from urllib.parse import unquote, urlparse
from typing import Callable
import folder_paths
import glob
import comfy.utils
from aiohttp import web
from PIL import Image
from io import BytesIO
from folder_paths import map_legacy, filter_files_extensions, filter_files_content_types
class ModelFileManager:
def __init__(self, is_download_model_enabled: Callable[[], bool] = lambda: True) -> None:
self.cache: dict[str, tuple[list[dict], dict[str, float], float]] = {}
self.is_download_model_enabled = is_download_model_enabled
self._download_progress: dict[str, dict] = {}
self._download_lock = Lock()
def get_cache(self, key: str, default=None) -> tuple[list[dict], dict[str, float], float] | None:
return self.cache.get(key, default)
def set_cache(self, key: str, value: tuple[list[dict], dict[str, float], float]):
self.cache[key] = value
def clear_cache(self):
self.cache.clear()
def _download_progress_key(self, save_dir: str, filename: str) -> str:
return f"{save_dir}/{filename}"
def _set_download_progress(self, key: str, **fields) -> None:
with self._download_lock:
entry = dict(self._download_progress.get(key, {}))
entry.update(fields)
self._download_progress[key] = entry
def _download_file_sync(self, url: str, headers: dict, tmp_path: str, save_path: str, key: str) -> None:
with requests.get(url, headers=headers, stream=True, timeout=(30, 3600)) as r:
r.raise_for_status()
total_size = int(r.headers.get("content-length", 0) or 0)
os.makedirs(os.path.dirname(save_path), exist_ok=True)
self._set_download_progress(
key,
status="running",
bytes_downloaded=0,
bytes_total=total_size,
filename=os.path.basename(save_path),
)
with open(tmp_path, "wb") as f:
downloaded = 0
for chunk in r.iter_content(chunk_size=1024 * 1024):
if not chunk:
continue
f.write(chunk)
downloaded += len(chunk)
self._set_download_progress(key, bytes_downloaded=downloaded, bytes_total=total_size or downloaded)
os.replace(tmp_path, save_path)
def add_routes(self, routes):
# NOTE: This is an experiment to replace `/models`
@routes.get("/experiment/models")
async def get_model_folders(request):
model_types = list(folder_paths.folder_names_and_paths.keys())
folder_black_list = ["configs", "custom_nodes"]
output_folders: list[dict] = []
for folder in model_types:
if folder in folder_black_list:
continue
output_folders.append({"name": folder, "folders": folder_paths.get_folder_paths(folder)})
return web.json_response(output_folders)
# NOTE: This is an experiment to replace `/models/{folder}`
@routes.get("/experiment/models/{folder}")
async def get_all_models(request):
folder = request.match_info.get("folder", None)
if folder not in folder_paths.folder_names_and_paths:
return web.Response(status=404)
files = self.get_model_file_list(folder)
return web.json_response(files)
@routes.post("/download_model")
async def post_download_model(request):
if not self.is_download_model_enabled():
logging.error("Download Model endpoint is disabled")
return web.Response(status=403)
json_data = await request.json()
url = json_data.get("url")
if not url:
return web.json_response({"error": "url required"}, status=400)
save_dir = json_data.get("save_dir")
if save_dir not in folder_paths.folder_names_and_paths:
return web.json_response({"error": "invalid save_dir"}, status=400)
default_filename = unquote(urlparse(url).path.split("/")[-1].split("?")[0])
filename = json_data.get("filename") or default_filename
if not filename or filename in (".", "..") or "/" in filename or "\\" in filename:
return web.json_response({"error": "invalid filename"}, status=400)
allowed_sources = (
"https://civitai.com/",
"https://civitai.red/",
"https://huggingface.co/",
"https://github.com/",
"http://localhost:",
)
if not any(url.startswith(src) for src in allowed_sources):
return web.json_response({"error": "url not allowed"}, status=400)
save_root = folder_paths.folder_names_and_paths[save_dir][0][0]
save_path = os.path.join(save_root, filename)
save_real = os.path.realpath(save_path)
root_real = os.path.realpath(save_root)
if not save_real.startswith(root_real + os.sep) and save_real != root_real:
return web.json_response({"error": "invalid path"}, status=400)
tmp_path = save_path + ".tmp"
token = json_data.get("token")
headers = {"Authorization": f"Bearer {token}"} if token else {}
key = self._download_progress_key(save_dir, filename)
loop = asyncio.get_running_loop()
try:
await loop.run_in_executor(
None,
self._download_file_sync,
url,
headers,
tmp_path,
save_path,
key,
)
self._set_download_progress(key, status="completed")
logging.info("Downloaded model to %s", save_path)
return web.json_response({"ok": True, "path": save_path, "save_dir": save_dir, "filename": filename})
except Exception as e:
logging.error("Failed to download model: %s", e)
self._set_download_progress(key, status="failed", error=str(e))
if os.path.exists(tmp_path):
os.remove(tmp_path)
return web.json_response({"error": str(e)}, status=500)
@routes.get("/download_model/progress")
async def get_download_model_progress(request):
save_dir = request.rel_url.query.get("save_dir")
filename = request.rel_url.query.get("filename")
if not save_dir or not filename:
return web.json_response({"error": "save_dir and filename required"}, status=400)
key = self._download_progress_key(save_dir, filename)
with self._download_lock:
entry = dict(self._download_progress.get(key, {}))
if not entry and save_dir in folder_paths.folder_names_and_paths:
root = folder_paths.folder_names_and_paths[save_dir][0][0]
tmp_path = os.path.join(root, filename + ".tmp")
final_path = os.path.join(root, filename)
if os.path.isfile(final_path):
size = os.path.getsize(final_path)
entry = {"status": "completed", "bytes_downloaded": size, "bytes_total": size}
elif os.path.isfile(tmp_path):
size = os.path.getsize(tmp_path)
entry = {"status": "running", "bytes_downloaded": size, "bytes_total": 0}
bytes_total = int(entry.get("bytes_total") or 0)
bytes_downloaded = int(entry.get("bytes_downloaded") or 0)
progress = (bytes_downloaded / bytes_total) if bytes_total > 0 else 0
return web.json_response({
"status": entry.get("status", "unknown"),
"bytes_downloaded": bytes_downloaded,
"bytes_total": bytes_total,
"progress": progress,
"error": entry.get("error"),
})
@routes.get("/experiment/models/preview/{folder}/{path_index}/{filename:.*}")
async def get_model_preview(request):
folder_name = request.match_info.get("folder", None)
path_index = int(request.match_info.get("path_index", None))
filename = request.match_info.get("filename", None)
if folder_name not in folder_paths.folder_names_and_paths:
return web.Response(status=404)
folders = folder_paths.folder_names_and_paths[folder_name]
folder = folders[0][path_index]
full_filename = os.path.join(folder, filename)
previews = self.get_model_previews(full_filename)
default_preview = previews[0] if len(previews) > 0 else None
if default_preview is None or (isinstance(default_preview, str) and not os.path.isfile(default_preview)):
return web.Response(status=404)
try:
with Image.open(default_preview) as img:
img_bytes = BytesIO()
img.save(img_bytes, format="WEBP")
img_bytes.seek(0)
return web.Response(body=img_bytes.getvalue(), content_type="image/webp")
except:
return web.Response(status=404)
def get_model_file_list(self, folder_name: str):
folder_name = map_legacy(folder_name)
folders = folder_paths.folder_names_and_paths[folder_name]
output_list: list[dict] = []
for index, folder in enumerate(folders[0]):
if not os.path.isdir(folder):
continue
out = self.cache_model_file_list_(folder)
if out is None:
out = self.recursive_search_models_(folder, index)
self.set_cache(folder, out)
output_list.extend(out[0])
return output_list
def cache_model_file_list_(self, folder: str):
model_file_list_cache = self.get_cache(folder)
if model_file_list_cache is None:
return None
if not os.path.isdir(folder):
return None
if os.path.getmtime(folder) != model_file_list_cache[1]:
return None
for x in model_file_list_cache[1]:
time_modified = model_file_list_cache[1][x]
folder = x
if os.path.getmtime(folder) != time_modified:
return None
return model_file_list_cache
def recursive_search_models_(self, directory: str, pathIndex: int) -> tuple[list[str], dict[str, float], float]:
if not os.path.isdir(directory):
return [], {}, time.perf_counter()
excluded_dir_names = [".git"]
# TODO use settings
include_hidden_files = False
result: list[str] = []
dirs: dict[str, float] = {}
for dirpath, subdirs, filenames in os.walk(directory, followlinks=True, topdown=True):
subdirs[:] = [d for d in subdirs if d not in excluded_dir_names]
if not include_hidden_files:
subdirs[:] = [d for d in subdirs if not d.startswith(".")]
filenames = [f for f in filenames if not f.startswith(".")]
filenames = filter_files_extensions(filenames, folder_paths.supported_pt_extensions)
for file_name in filenames:
try:
full_path = os.path.join(dirpath, file_name)
relative_path = os.path.relpath(full_path, directory)
# Get file metadata
file_info = {
"name": relative_path,
"pathIndex": pathIndex,
"modified": os.path.getmtime(full_path), # Add modification time
"created": os.path.getctime(full_path), # Add creation time
"size": os.path.getsize(full_path) # Add file size
}
result.append(file_info)
except Exception as e:
logging.warning(f"Warning: Unable to access {file_name}. Error: {e}. Skipping this file.")
continue
for d in subdirs:
path: str = os.path.join(dirpath, d)
try:
dirs[path] = os.path.getmtime(path)
except FileNotFoundError:
logging.warning(f"Warning: Unable to access {path}. Skipping this path.")
continue
return result, dirs, time.perf_counter()
def get_model_previews(self, filepath: str) -> list[str | BytesIO]:
dirname = os.path.dirname(filepath)
if not os.path.exists(dirname):
return []
basename = os.path.splitext(filepath)[0]
match_files = glob.glob(f"{basename}.*", recursive=False)
image_files = filter_files_content_types(match_files, "image")
safetensors_file = next(filter(lambda x: x.endswith(".safetensors"), match_files), None)
safetensors_metadata = {}
result: list[str | BytesIO] = []
for filename in image_files:
_basename = os.path.splitext(filename)[0]
if _basename == basename:
result.append(filename)
if _basename == f"{basename}.preview":
result.append(filename)
if safetensors_file:
safetensors_filepath = os.path.join(dirname, safetensors_file)
header = comfy.utils.safetensors_header(safetensors_filepath, max_size=8*1024*1024)
if header:
safetensors_metadata = json.loads(header)
safetensors_images = safetensors_metadata.get("__metadata__", {}).get("ssmd_cover_images", None)
if safetensors_images:
safetensors_images = json.loads(safetensors_images)
for image in safetensors_images:
result.append(BytesIO(base64.b64decode(image)))
return result
def __exit__(self, exc_type, exc_value, traceback):
self.clear_cache()