mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-03-23 18:13:28 +08:00
Merge f4c0e1d269 into 16cd8d8a8f
This commit is contained in:
commit
b096bb27ad
@ -8,15 +8,44 @@ import logging
|
|||||||
import folder_paths
|
import folder_paths
|
||||||
import glob
|
import glob
|
||||||
import comfy.utils
|
import comfy.utils
|
||||||
|
import uuid
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
from typing import Awaitable, Callable
|
||||||
|
import aiohttp
|
||||||
from aiohttp import web
|
from aiohttp import web
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
from folder_paths import map_legacy, filter_files_extensions, filter_files_content_types
|
from folder_paths import map_legacy, filter_files_extensions, filter_files_content_types
|
||||||
|
|
||||||
|
|
||||||
|
ALLOWED_MODEL_SOURCES = (
|
||||||
|
"https://civitai.com/",
|
||||||
|
"https://huggingface.co/",
|
||||||
|
"http://localhost:",
|
||||||
|
)
|
||||||
|
ALLOWED_MODEL_SUFFIXES = (".safetensors", ".sft")
|
||||||
|
WHITELISTED_MODEL_URLS = {
|
||||||
|
"https://huggingface.co/stabilityai/stable-zero123/resolve/main/stable_zero123.ckpt",
|
||||||
|
"https://huggingface.co/TencentARC/T2I-Adapter/resolve/main/models/t2iadapter_depth_sd14v1.pth?download=true",
|
||||||
|
"https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
|
||||||
|
}
|
||||||
|
DOWNLOAD_CHUNK_SIZE = 1024 * 1024
|
||||||
|
MAX_BULK_MODEL_DOWNLOADS = 200
|
||||||
|
DownloadProgressCallback = Callable[[int], Awaitable[None]]
|
||||||
|
DownloadShouldCancel = Callable[[], bool]
|
||||||
|
DOWNLOAD_PROGRESS_MIN_INTERVAL = 0.25
|
||||||
|
DOWNLOAD_PROGRESS_MIN_BYTES = 4 * 1024 * 1024
|
||||||
|
|
||||||
|
|
||||||
|
class DownloadCancelledError(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class ModelFileManager:
|
class ModelFileManager:
|
||||||
def __init__(self) -> None:
|
def __init__(self, prompt_server) -> None:
|
||||||
self.cache: dict[str, tuple[list[dict], dict[str, float], float]] = {}
|
self.cache: dict[str, tuple[list[dict], dict[str, float], float]] = {}
|
||||||
|
self.prompt_server = prompt_server
|
||||||
|
self._cancelled_missing_model_downloads: set[str] = set()
|
||||||
|
|
||||||
def get_cache(self, key: str, default=None) -> tuple[list[dict], dict[str, float], float] | None:
|
def get_cache(self, key: str, default=None) -> tuple[list[dict], dict[str, float], float] | None:
|
||||||
return self.cache.get(key, default)
|
return self.cache.get(key, default)
|
||||||
@ -75,6 +104,288 @@ class ModelFileManager:
|
|||||||
return web.Response(body=img_bytes.getvalue(), content_type="image/webp")
|
return web.Response(body=img_bytes.getvalue(), content_type="image/webp")
|
||||||
except:
|
except:
|
||||||
return web.Response(status=404)
|
return web.Response(status=404)
|
||||||
|
|
||||||
|
@routes.post("/experiment/models/download_missing")
|
||||||
|
async def download_missing_models(request: web.Request) -> web.Response:
|
||||||
|
try:
|
||||||
|
payload = await request.json()
|
||||||
|
except Exception:
|
||||||
|
return web.json_response({"error": "Invalid JSON body"}, status=400)
|
||||||
|
|
||||||
|
models = payload.get("models")
|
||||||
|
if not isinstance(models, list):
|
||||||
|
return web.json_response({"error": "Field 'models' must be a list"}, status=400)
|
||||||
|
if len(models) > MAX_BULK_MODEL_DOWNLOADS:
|
||||||
|
return web.json_response(
|
||||||
|
{"error": f"Maximum of {MAX_BULK_MODEL_DOWNLOADS} models allowed per request"},
|
||||||
|
status=400,
|
||||||
|
)
|
||||||
|
|
||||||
|
target_client_id = str(payload.get("client_id", "")).strip() or None
|
||||||
|
batch_id = str(payload.get("batch_id", "")).strip() or uuid.uuid4().hex
|
||||||
|
|
||||||
|
def emit_download_event(
|
||||||
|
*,
|
||||||
|
task_id: str,
|
||||||
|
model_name: str,
|
||||||
|
model_directory: str,
|
||||||
|
model_url: str,
|
||||||
|
status: str,
|
||||||
|
bytes_downloaded: int = 0,
|
||||||
|
error: str | None = None,
|
||||||
|
) -> None:
|
||||||
|
message = {
|
||||||
|
"batch_id": batch_id,
|
||||||
|
"task_id": task_id,
|
||||||
|
"name": model_name,
|
||||||
|
"directory": model_directory,
|
||||||
|
"url": model_url,
|
||||||
|
"status": status,
|
||||||
|
"bytes_downloaded": bytes_downloaded
|
||||||
|
}
|
||||||
|
if error:
|
||||||
|
message["error"] = error
|
||||||
|
|
||||||
|
self.prompt_server.send_sync("missing_model_download", message, target_client_id)
|
||||||
|
|
||||||
|
results = []
|
||||||
|
downloaded = 0
|
||||||
|
skipped = 0
|
||||||
|
canceled = 0
|
||||||
|
failed = 0
|
||||||
|
|
||||||
|
session = self.prompt_server.client_session
|
||||||
|
owns_session = False
|
||||||
|
if session is None or session.closed:
|
||||||
|
timeout = aiohttp.ClientTimeout(total=None)
|
||||||
|
session = aiohttp.ClientSession(timeout=timeout)
|
||||||
|
owns_session = True
|
||||||
|
|
||||||
|
try:
|
||||||
|
for model_entry in models:
|
||||||
|
model_name, model_directory, model_url = _normalize_model_entry(model_entry)
|
||||||
|
task_id = uuid.uuid4().hex
|
||||||
|
self._cancelled_missing_model_downloads.discard(task_id)
|
||||||
|
|
||||||
|
if not model_name or not model_directory or not model_url:
|
||||||
|
failed += 1
|
||||||
|
error = "Each model must include non-empty name, directory, and url"
|
||||||
|
results.append(
|
||||||
|
{
|
||||||
|
"name": model_name or "",
|
||||||
|
"directory": model_directory or "",
|
||||||
|
"url": model_url or "",
|
||||||
|
"status": "failed",
|
||||||
|
"error": error,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
emit_download_event(
|
||||||
|
task_id=task_id,
|
||||||
|
model_name=model_name or "",
|
||||||
|
model_directory=model_directory or "",
|
||||||
|
model_url=model_url or "",
|
||||||
|
status="failed",
|
||||||
|
error=error,
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
if not _is_http_url(model_url):
|
||||||
|
failed += 1
|
||||||
|
error = "URL must be an absolute HTTP/HTTPS URL"
|
||||||
|
results.append(
|
||||||
|
{
|
||||||
|
"name": model_name,
|
||||||
|
"directory": model_directory,
|
||||||
|
"url": model_url,
|
||||||
|
"status": "failed",
|
||||||
|
"error": error,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
emit_download_event(
|
||||||
|
task_id=task_id,
|
||||||
|
model_name=model_name,
|
||||||
|
model_directory=model_directory,
|
||||||
|
model_url=model_url,
|
||||||
|
status="failed",
|
||||||
|
error=error,
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
allowed, reason = _is_model_download_allowed(model_name, model_url)
|
||||||
|
if not allowed:
|
||||||
|
failed += 1
|
||||||
|
results.append(
|
||||||
|
{
|
||||||
|
"name": model_name,
|
||||||
|
"directory": model_directory,
|
||||||
|
"url": model_url,
|
||||||
|
"status": "blocked",
|
||||||
|
"error": reason,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
emit_download_event(
|
||||||
|
task_id=task_id,
|
||||||
|
model_name=model_name,
|
||||||
|
model_directory=model_directory,
|
||||||
|
model_url=model_url,
|
||||||
|
status="blocked",
|
||||||
|
error=reason,
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
try:
|
||||||
|
destination = _resolve_download_destination(model_directory, model_name)
|
||||||
|
except Exception as exc:
|
||||||
|
failed += 1
|
||||||
|
error = str(exc)
|
||||||
|
results.append(
|
||||||
|
{
|
||||||
|
"name": model_name,
|
||||||
|
"directory": model_directory,
|
||||||
|
"url": model_url,
|
||||||
|
"status": "failed",
|
||||||
|
"error": error,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
emit_download_event(
|
||||||
|
task_id=task_id,
|
||||||
|
model_name=model_name,
|
||||||
|
model_directory=model_directory,
|
||||||
|
model_url=model_url,
|
||||||
|
status="failed",
|
||||||
|
error=error,
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
if os.path.exists(destination):
|
||||||
|
skipped += 1
|
||||||
|
results.append(
|
||||||
|
{
|
||||||
|
"name": model_name,
|
||||||
|
"directory": model_directory,
|
||||||
|
"url": model_url,
|
||||||
|
"status": "skipped_existing",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
emit_download_event(
|
||||||
|
task_id=task_id,
|
||||||
|
model_name=model_name,
|
||||||
|
model_directory=model_directory,
|
||||||
|
model_url=model_url,
|
||||||
|
status="skipped_existing",
|
||||||
|
bytes_downloaded=0
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
try:
|
||||||
|
latest_downloaded = 0
|
||||||
|
async def on_progress(bytes_downloaded: int) -> None:
|
||||||
|
nonlocal latest_downloaded
|
||||||
|
latest_downloaded = bytes_downloaded
|
||||||
|
emit_download_event(
|
||||||
|
task_id=task_id,
|
||||||
|
model_name=model_name,
|
||||||
|
model_directory=model_directory,
|
||||||
|
model_url=model_url,
|
||||||
|
status="running",
|
||||||
|
bytes_downloaded=bytes_downloaded
|
||||||
|
)
|
||||||
|
|
||||||
|
await _download_file(
|
||||||
|
session,
|
||||||
|
model_url,
|
||||||
|
destination,
|
||||||
|
progress_callback=on_progress,
|
||||||
|
should_cancel=lambda: task_id in self._cancelled_missing_model_downloads
|
||||||
|
)
|
||||||
|
downloaded += 1
|
||||||
|
final_size = os.path.getsize(destination) if os.path.exists(destination) else latest_downloaded
|
||||||
|
results.append(
|
||||||
|
{
|
||||||
|
"name": model_name,
|
||||||
|
"directory": model_directory,
|
||||||
|
"url": model_url,
|
||||||
|
"status": "downloaded",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
emit_download_event(
|
||||||
|
task_id=task_id,
|
||||||
|
model_name=model_name,
|
||||||
|
model_directory=model_directory,
|
||||||
|
model_url=model_url,
|
||||||
|
status="completed",
|
||||||
|
bytes_downloaded=final_size
|
||||||
|
)
|
||||||
|
except DownloadCancelledError:
|
||||||
|
canceled += 1
|
||||||
|
results.append(
|
||||||
|
{
|
||||||
|
"name": model_name,
|
||||||
|
"directory": model_directory,
|
||||||
|
"url": model_url,
|
||||||
|
"status": "canceled",
|
||||||
|
"error": "Download canceled",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
emit_download_event(
|
||||||
|
task_id=task_id,
|
||||||
|
model_name=model_name,
|
||||||
|
model_directory=model_directory,
|
||||||
|
model_url=model_url,
|
||||||
|
status="canceled",
|
||||||
|
bytes_downloaded=latest_downloaded,
|
||||||
|
error="Download canceled",
|
||||||
|
)
|
||||||
|
except Exception as exc:
|
||||||
|
failed += 1
|
||||||
|
error = str(exc)
|
||||||
|
results.append(
|
||||||
|
{
|
||||||
|
"name": model_name,
|
||||||
|
"directory": model_directory,
|
||||||
|
"url": model_url,
|
||||||
|
"status": "failed",
|
||||||
|
"error": error,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
emit_download_event(
|
||||||
|
task_id=task_id,
|
||||||
|
model_name=model_name,
|
||||||
|
model_directory=model_directory,
|
||||||
|
model_url=model_url,
|
||||||
|
status="failed",
|
||||||
|
error=error
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
self._cancelled_missing_model_downloads.discard(task_id)
|
||||||
|
finally:
|
||||||
|
if owns_session:
|
||||||
|
await session.close()
|
||||||
|
|
||||||
|
return web.json_response(
|
||||||
|
{
|
||||||
|
"downloaded": downloaded,
|
||||||
|
"skipped": skipped,
|
||||||
|
"canceled": canceled,
|
||||||
|
"failed": failed,
|
||||||
|
"results": results,
|
||||||
|
},
|
||||||
|
status=200,
|
||||||
|
)
|
||||||
|
|
||||||
|
@routes.post("/experiment/models/download_missing/cancel")
|
||||||
|
async def cancel_download_missing_model(request: web.Request) -> web.Response:
|
||||||
|
try:
|
||||||
|
payload = await request.json()
|
||||||
|
except Exception:
|
||||||
|
return web.json_response({"error": "Invalid JSON body"}, status=400)
|
||||||
|
|
||||||
|
task_id = str(payload.get("task_id", "")).strip()
|
||||||
|
if not task_id:
|
||||||
|
return web.json_response({"error": "Field 'task_id' is required"}, status=400)
|
||||||
|
|
||||||
|
self._cancelled_missing_model_downloads.add(task_id)
|
||||||
|
return web.json_response({"ok": True, "task_id": task_id}, status=200)
|
||||||
|
|
||||||
def get_model_file_list(self, folder_name: str):
|
def get_model_file_list(self, folder_name: str):
|
||||||
folder_name = map_legacy(folder_name)
|
folder_name = map_legacy(folder_name)
|
||||||
@ -193,3 +504,105 @@ class ModelFileManager:
|
|||||||
|
|
||||||
def __exit__(self, exc_type, exc_value, traceback):
|
def __exit__(self, exc_type, exc_value, traceback):
|
||||||
self.clear_cache()
|
self.clear_cache()
|
||||||
|
|
||||||
|
def _is_model_download_allowed(model_name: str, model_url: str) -> tuple[bool, str | None]:
|
||||||
|
if model_url in WHITELISTED_MODEL_URLS:
|
||||||
|
return True, None
|
||||||
|
|
||||||
|
if not any(model_url.startswith(source) for source in ALLOWED_MODEL_SOURCES):
|
||||||
|
return (
|
||||||
|
False,
|
||||||
|
f"Download not allowed from source '{model_url}'.",
|
||||||
|
)
|
||||||
|
|
||||||
|
if not any(model_name.endswith(suffix) for suffix in ALLOWED_MODEL_SUFFIXES):
|
||||||
|
return (
|
||||||
|
False,
|
||||||
|
f"Only allowed suffixes are: {', '.join(ALLOWED_MODEL_SUFFIXES)}",
|
||||||
|
)
|
||||||
|
|
||||||
|
return True, None
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_download_destination(directory: str, model_name: str) -> str:
|
||||||
|
if directory not in folder_paths.folder_names_and_paths:
|
||||||
|
raise ValueError(f"Unknown model directory '{directory}'")
|
||||||
|
|
||||||
|
model_paths = folder_paths.folder_names_and_paths[directory][0]
|
||||||
|
if not model_paths:
|
||||||
|
raise ValueError(f"No filesystem paths configured for '{directory}'")
|
||||||
|
|
||||||
|
base_path = os.path.abspath(model_paths[0])
|
||||||
|
normalized_name = os.path.normpath(model_name).lstrip("/\\")
|
||||||
|
if not normalized_name or normalized_name == ".":
|
||||||
|
raise ValueError("Model name cannot be empty")
|
||||||
|
|
||||||
|
destination = os.path.abspath(os.path.join(base_path, normalized_name))
|
||||||
|
if os.path.commonpath((base_path, destination)) != base_path:
|
||||||
|
raise ValueError("Model path escapes configured model directory")
|
||||||
|
|
||||||
|
destination_parent = os.path.dirname(destination)
|
||||||
|
if destination_parent:
|
||||||
|
os.makedirs(destination_parent, exist_ok=True)
|
||||||
|
|
||||||
|
return destination
|
||||||
|
|
||||||
|
|
||||||
|
async def _download_file(
|
||||||
|
session: aiohttp.ClientSession,
|
||||||
|
url: str,
|
||||||
|
destination: str,
|
||||||
|
progress_callback: DownloadProgressCallback | None = None,
|
||||||
|
should_cancel: DownloadShouldCancel | None = None,
|
||||||
|
progress_min_interval: float = DOWNLOAD_PROGRESS_MIN_INTERVAL,
|
||||||
|
progress_min_bytes: int = DOWNLOAD_PROGRESS_MIN_BYTES,
|
||||||
|
) -> None:
|
||||||
|
temp_file = f"{destination}.{uuid.uuid4().hex}.temp"
|
||||||
|
try:
|
||||||
|
if should_cancel is not None and should_cancel():
|
||||||
|
raise DownloadCancelledError("Download canceled")
|
||||||
|
async with session.get(url, allow_redirects=True) as response:
|
||||||
|
response.raise_for_status()
|
||||||
|
bytes_downloaded = 0
|
||||||
|
if progress_callback is not None:
|
||||||
|
await progress_callback(bytes_downloaded)
|
||||||
|
last_progress_emit_time = time.monotonic()
|
||||||
|
last_progress_emit_bytes = 0
|
||||||
|
with open(temp_file, "wb") as file_handle:
|
||||||
|
async for chunk in response.content.iter_chunked(DOWNLOAD_CHUNK_SIZE):
|
||||||
|
if should_cancel is not None and should_cancel():
|
||||||
|
raise DownloadCancelledError("Download canceled")
|
||||||
|
if chunk:
|
||||||
|
file_handle.write(chunk)
|
||||||
|
bytes_downloaded += len(chunk)
|
||||||
|
if progress_callback is not None:
|
||||||
|
now = time.monotonic()
|
||||||
|
should_emit = (
|
||||||
|
bytes_downloaded - last_progress_emit_bytes >= progress_min_bytes
|
||||||
|
or now - last_progress_emit_time >= progress_min_interval
|
||||||
|
)
|
||||||
|
if should_emit:
|
||||||
|
await progress_callback(bytes_downloaded)
|
||||||
|
last_progress_emit_time = now
|
||||||
|
last_progress_emit_bytes = bytes_downloaded
|
||||||
|
if progress_callback is not None and bytes_downloaded != last_progress_emit_bytes:
|
||||||
|
await progress_callback(bytes_downloaded)
|
||||||
|
os.replace(temp_file, destination)
|
||||||
|
finally:
|
||||||
|
if os.path.exists(temp_file):
|
||||||
|
os.remove(temp_file)
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_model_entry(model_entry: object) -> tuple[str | None, str | None, str | None]:
|
||||||
|
if not isinstance(model_entry, dict):
|
||||||
|
return None, None, None
|
||||||
|
|
||||||
|
model_name = str(model_entry.get("name", "")).strip()
|
||||||
|
model_directory = str(model_entry.get("directory", "")).strip()
|
||||||
|
model_url = str(model_entry.get("url", "")).strip()
|
||||||
|
return model_name, model_directory, model_url
|
||||||
|
|
||||||
|
|
||||||
|
def _is_http_url(url: str) -> bool:
|
||||||
|
parsed = urlparse(url)
|
||||||
|
return parsed.scheme in ("http", "https") and bool(parsed.netloc)
|
||||||
|
|||||||
@ -198,7 +198,7 @@ class PromptServer():
|
|||||||
PromptServer.instance = self
|
PromptServer.instance = self
|
||||||
|
|
||||||
self.user_manager = UserManager()
|
self.user_manager = UserManager()
|
||||||
self.model_file_manager = ModelFileManager()
|
self.model_file_manager = ModelFileManager(self)
|
||||||
self.custom_node_manager = CustomNodeManager()
|
self.custom_node_manager = CustomNodeManager()
|
||||||
self.subgraph_manager = SubgraphManager()
|
self.subgraph_manager = SubgraphManager()
|
||||||
self.node_replace_manager = NodeReplaceManager()
|
self.node_replace_manager = NodeReplaceManager()
|
||||||
|
|||||||
@ -12,9 +12,14 @@ pytestmark = (
|
|||||||
pytest.mark.asyncio
|
pytest.mark.asyncio
|
||||||
) # This applies the asyncio mark to all test functions in the module
|
) # This applies the asyncio mark to all test functions in the module
|
||||||
|
|
||||||
|
class DummyPromptServer:
|
||||||
|
def __init__(self):
|
||||||
|
self.client_session = None
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def model_manager():
|
def model_manager():
|
||||||
return ModelFileManager()
|
prompt_server = DummyPromptServer()
|
||||||
|
return ModelFileManager(prompt_server)
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def app(model_manager):
|
def app(model_manager):
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user