# Mirror of https://github.com/comfyanonymous/ComfyUI.git
# Synced 2026-06-24 00:39:30 +08:00 · 333 lines · 14 KiB · Python
import os
|
|
import base64
|
|
import json
|
|
import time
|
|
import logging
|
|
import asyncio
|
|
import requests
|
|
from threading import Lock
|
|
from tqdm.auto import tqdm
|
|
from urllib.parse import unquote, urlparse
|
|
from typing import Callable
|
|
import folder_paths
|
|
import glob
|
|
import comfy.utils
|
|
from aiohttp import web
|
|
from PIL import Image
|
|
from io import BytesIO
|
|
from folder_paths import map_legacy, filter_files_extensions, filter_files_content_types
|
|
|
|
|
|
class ModelFileManager:
|
|
def __init__(self, is_download_model_enabled: Callable[[], bool] = lambda: True) -> None:
|
|
self.cache: dict[str, tuple[list[dict], dict[str, float], float]] = {}
|
|
self.is_download_model_enabled = is_download_model_enabled
|
|
self._download_progress: dict[str, dict] = {}
|
|
self._download_lock = Lock()
|
|
|
|
def get_cache(self, key: str, default=None) -> tuple[list[dict], dict[str, float], float] | None:
|
|
return self.cache.get(key, default)
|
|
|
|
def set_cache(self, key: str, value: tuple[list[dict], dict[str, float], float]):
|
|
self.cache[key] = value
|
|
|
|
def clear_cache(self):
|
|
self.cache.clear()
|
|
|
|
|
|
def _download_progress_key(self, save_dir: str, filename: str) -> str:
|
|
return f"{save_dir}/{filename}"
|
|
|
|
def _set_download_progress(self, key: str, **fields) -> None:
|
|
with self._download_lock:
|
|
entry = dict(self._download_progress.get(key, {}))
|
|
entry.update(fields)
|
|
self._download_progress[key] = entry
|
|
|
|
def _download_file_sync(self, url: str, headers: dict, tmp_path: str, save_path: str, key: str) -> None:
|
|
with requests.get(url, headers=headers, stream=True, timeout=(30, 3600)) as r:
|
|
r.raise_for_status()
|
|
total_size = int(r.headers.get("content-length", 0) or 0)
|
|
os.makedirs(os.path.dirname(save_path), exist_ok=True)
|
|
self._set_download_progress(
|
|
key,
|
|
status="running",
|
|
bytes_downloaded=0,
|
|
bytes_total=total_size,
|
|
filename=os.path.basename(save_path),
|
|
)
|
|
with open(tmp_path, "wb") as f:
|
|
downloaded = 0
|
|
for chunk in r.iter_content(chunk_size=1024 * 1024):
|
|
if not chunk:
|
|
continue
|
|
f.write(chunk)
|
|
downloaded += len(chunk)
|
|
self._set_download_progress(key, bytes_downloaded=downloaded, bytes_total=total_size or downloaded)
|
|
os.replace(tmp_path, save_path)
|
|
|
|
|
|
def add_routes(self, routes):
|
|
# NOTE: This is an experiment to replace `/models`
|
|
@routes.get("/experiment/models")
|
|
async def get_model_folders(request):
|
|
model_types = list(folder_paths.folder_names_and_paths.keys())
|
|
folder_black_list = ["configs", "custom_nodes"]
|
|
output_folders: list[dict] = []
|
|
for folder in model_types:
|
|
if folder in folder_black_list:
|
|
continue
|
|
output_folders.append({"name": folder, "folders": folder_paths.get_folder_paths(folder)})
|
|
return web.json_response(output_folders)
|
|
|
|
# NOTE: This is an experiment to replace `/models/{folder}`
|
|
@routes.get("/experiment/models/{folder}")
|
|
async def get_all_models(request):
|
|
folder = request.match_info.get("folder", None)
|
|
if folder not in folder_paths.folder_names_and_paths:
|
|
return web.Response(status=404)
|
|
files = self.get_model_file_list(folder)
|
|
return web.json_response(files)
|
|
|
|
|
|
@routes.post("/download_model")
|
|
async def post_download_model(request):
|
|
if not self.is_download_model_enabled():
|
|
logging.error("Download Model endpoint is disabled")
|
|
return web.Response(status=403)
|
|
|
|
json_data = await request.json()
|
|
url = json_data.get("url")
|
|
if not url:
|
|
return web.json_response({"error": "url required"}, status=400)
|
|
|
|
save_dir = json_data.get("save_dir")
|
|
if save_dir not in folder_paths.folder_names_and_paths:
|
|
return web.json_response({"error": "invalid save_dir"}, status=400)
|
|
|
|
default_filename = unquote(urlparse(url).path.split("/")[-1].split("?")[0])
|
|
filename = json_data.get("filename") or default_filename
|
|
if not filename or filename in (".", "..") or "/" in filename or "\\" in filename:
|
|
return web.json_response({"error": "invalid filename"}, status=400)
|
|
|
|
allowed_sources = (
|
|
"https://civitai.com/",
|
|
"https://civitai.red/",
|
|
"https://huggingface.co/",
|
|
"https://github.com/",
|
|
"http://localhost:",
|
|
)
|
|
if not any(url.startswith(src) for src in allowed_sources):
|
|
return web.json_response({"error": "url not allowed"}, status=400)
|
|
|
|
save_root = folder_paths.folder_names_and_paths[save_dir][0][0]
|
|
save_path = os.path.join(save_root, filename)
|
|
save_real = os.path.realpath(save_path)
|
|
root_real = os.path.realpath(save_root)
|
|
if not save_real.startswith(root_real + os.sep) and save_real != root_real:
|
|
return web.json_response({"error": "invalid path"}, status=400)
|
|
|
|
tmp_path = save_path + ".tmp"
|
|
token = json_data.get("token")
|
|
headers = {"Authorization": f"Bearer {token}"} if token else {}
|
|
|
|
key = self._download_progress_key(save_dir, filename)
|
|
loop = asyncio.get_running_loop()
|
|
try:
|
|
await loop.run_in_executor(
|
|
None,
|
|
self._download_file_sync,
|
|
url,
|
|
headers,
|
|
tmp_path,
|
|
save_path,
|
|
key,
|
|
)
|
|
self._set_download_progress(key, status="completed")
|
|
logging.info("Downloaded model to %s", save_path)
|
|
return web.json_response({"ok": True, "path": save_path, "save_dir": save_dir, "filename": filename})
|
|
except Exception as e:
|
|
logging.error("Failed to download model: %s", e)
|
|
self._set_download_progress(key, status="failed", error=str(e))
|
|
if os.path.exists(tmp_path):
|
|
os.remove(tmp_path)
|
|
return web.json_response({"error": str(e)}, status=500)
|
|
|
|
|
|
|
|
@routes.get("/download_model/progress")
|
|
async def get_download_model_progress(request):
|
|
save_dir = request.rel_url.query.get("save_dir")
|
|
filename = request.rel_url.query.get("filename")
|
|
if not save_dir or not filename:
|
|
return web.json_response({"error": "save_dir and filename required"}, status=400)
|
|
key = self._download_progress_key(save_dir, filename)
|
|
with self._download_lock:
|
|
entry = dict(self._download_progress.get(key, {}))
|
|
if not entry and save_dir in folder_paths.folder_names_and_paths:
|
|
root = folder_paths.folder_names_and_paths[save_dir][0][0]
|
|
tmp_path = os.path.join(root, filename + ".tmp")
|
|
final_path = os.path.join(root, filename)
|
|
if os.path.isfile(final_path):
|
|
size = os.path.getsize(final_path)
|
|
entry = {"status": "completed", "bytes_downloaded": size, "bytes_total": size}
|
|
elif os.path.isfile(tmp_path):
|
|
size = os.path.getsize(tmp_path)
|
|
entry = {"status": "running", "bytes_downloaded": size, "bytes_total": 0}
|
|
bytes_total = int(entry.get("bytes_total") or 0)
|
|
bytes_downloaded = int(entry.get("bytes_downloaded") or 0)
|
|
progress = (bytes_downloaded / bytes_total) if bytes_total > 0 else 0
|
|
return web.json_response({
|
|
"status": entry.get("status", "unknown"),
|
|
"bytes_downloaded": bytes_downloaded,
|
|
"bytes_total": bytes_total,
|
|
"progress": progress,
|
|
"error": entry.get("error"),
|
|
})
|
|
|
|
|
|
@routes.get("/experiment/models/preview/{folder}/{path_index}/{filename:.*}")
|
|
async def get_model_preview(request):
|
|
folder_name = request.match_info.get("folder", None)
|
|
path_index = int(request.match_info.get("path_index", None))
|
|
filename = request.match_info.get("filename", None)
|
|
|
|
if folder_name not in folder_paths.folder_names_and_paths:
|
|
return web.Response(status=404)
|
|
|
|
folders = folder_paths.folder_names_and_paths[folder_name]
|
|
folder = folders[0][path_index]
|
|
full_filename = os.path.join(folder, filename)
|
|
|
|
previews = self.get_model_previews(full_filename)
|
|
default_preview = previews[0] if len(previews) > 0 else None
|
|
if default_preview is None or (isinstance(default_preview, str) and not os.path.isfile(default_preview)):
|
|
return web.Response(status=404)
|
|
|
|
try:
|
|
with Image.open(default_preview) as img:
|
|
img_bytes = BytesIO()
|
|
img.save(img_bytes, format="WEBP")
|
|
img_bytes.seek(0)
|
|
return web.Response(body=img_bytes.getvalue(), content_type="image/webp")
|
|
except:
|
|
return web.Response(status=404)
|
|
|
|
def get_model_file_list(self, folder_name: str):
|
|
folder_name = map_legacy(folder_name)
|
|
folders = folder_paths.folder_names_and_paths[folder_name]
|
|
output_list: list[dict] = []
|
|
|
|
for index, folder in enumerate(folders[0]):
|
|
if not os.path.isdir(folder):
|
|
continue
|
|
out = self.cache_model_file_list_(folder)
|
|
if out is None:
|
|
out = self.recursive_search_models_(folder, index)
|
|
self.set_cache(folder, out)
|
|
output_list.extend(out[0])
|
|
|
|
return output_list
|
|
|
|
def cache_model_file_list_(self, folder: str):
|
|
model_file_list_cache = self.get_cache(folder)
|
|
|
|
if model_file_list_cache is None:
|
|
return None
|
|
if not os.path.isdir(folder):
|
|
return None
|
|
if os.path.getmtime(folder) != model_file_list_cache[1]:
|
|
return None
|
|
for x in model_file_list_cache[1]:
|
|
time_modified = model_file_list_cache[1][x]
|
|
folder = x
|
|
if os.path.getmtime(folder) != time_modified:
|
|
return None
|
|
|
|
return model_file_list_cache
|
|
|
|
def recursive_search_models_(self, directory: str, pathIndex: int) -> tuple[list[str], dict[str, float], float]:
|
|
if not os.path.isdir(directory):
|
|
return [], {}, time.perf_counter()
|
|
|
|
excluded_dir_names = [".git"]
|
|
# TODO use settings
|
|
include_hidden_files = False
|
|
|
|
result: list[str] = []
|
|
dirs: dict[str, float] = {}
|
|
|
|
for dirpath, subdirs, filenames in os.walk(directory, followlinks=True, topdown=True):
|
|
subdirs[:] = [d for d in subdirs if d not in excluded_dir_names]
|
|
if not include_hidden_files:
|
|
subdirs[:] = [d for d in subdirs if not d.startswith(".")]
|
|
filenames = [f for f in filenames if not f.startswith(".")]
|
|
|
|
filenames = filter_files_extensions(filenames, folder_paths.supported_pt_extensions)
|
|
|
|
for file_name in filenames:
|
|
try:
|
|
full_path = os.path.join(dirpath, file_name)
|
|
relative_path = os.path.relpath(full_path, directory)
|
|
|
|
# Get file metadata
|
|
file_info = {
|
|
"name": relative_path,
|
|
"pathIndex": pathIndex,
|
|
"modified": os.path.getmtime(full_path), # Add modification time
|
|
"created": os.path.getctime(full_path), # Add creation time
|
|
"size": os.path.getsize(full_path) # Add file size
|
|
}
|
|
result.append(file_info)
|
|
|
|
except Exception as e:
|
|
logging.warning(f"Warning: Unable to access {file_name}. Error: {e}. Skipping this file.")
|
|
continue
|
|
|
|
for d in subdirs:
|
|
path: str = os.path.join(dirpath, d)
|
|
try:
|
|
dirs[path] = os.path.getmtime(path)
|
|
except FileNotFoundError:
|
|
logging.warning(f"Warning: Unable to access {path}. Skipping this path.")
|
|
continue
|
|
|
|
return result, dirs, time.perf_counter()
|
|
|
|
def get_model_previews(self, filepath: str) -> list[str | BytesIO]:
|
|
dirname = os.path.dirname(filepath)
|
|
|
|
if not os.path.exists(dirname):
|
|
return []
|
|
|
|
basename = os.path.splitext(filepath)[0]
|
|
match_files = glob.glob(f"{basename}.*", recursive=False)
|
|
image_files = filter_files_content_types(match_files, "image")
|
|
safetensors_file = next(filter(lambda x: x.endswith(".safetensors"), match_files), None)
|
|
safetensors_metadata = {}
|
|
|
|
result: list[str | BytesIO] = []
|
|
|
|
for filename in image_files:
|
|
_basename = os.path.splitext(filename)[0]
|
|
if _basename == basename:
|
|
result.append(filename)
|
|
if _basename == f"{basename}.preview":
|
|
result.append(filename)
|
|
|
|
if safetensors_file:
|
|
safetensors_filepath = os.path.join(dirname, safetensors_file)
|
|
header = comfy.utils.safetensors_header(safetensors_filepath, max_size=8*1024*1024)
|
|
if header:
|
|
safetensors_metadata = json.loads(header)
|
|
safetensors_images = safetensors_metadata.get("__metadata__", {}).get("ssmd_cover_images", None)
|
|
if safetensors_images:
|
|
safetensors_images = json.loads(safetensors_images)
|
|
for image in safetensors_images:
|
|
result.append(BytesIO(base64.b64decode(image)))
|
|
|
|
return result
|
|
|
|
def __exit__(self, exc_type, exc_value, traceback):
|
|
self.clear_cache()
|