add route to automatically download missing models

This commit is contained in:
teddav 2026-02-16 12:14:56 +01:00
parent 88e6370527
commit f075b35fd1
3 changed files with 236 additions and 3 deletions

View File

@ -8,15 +8,33 @@ import logging
import folder_paths
import glob
import comfy.utils
import uuid
from urllib.parse import urlparse
import aiohttp
from aiohttp import web
from PIL import Image
from io import BytesIO
from folder_paths import map_legacy, filter_files_extensions, filter_files_content_types
# URL prefixes a model may be downloaded from (unless the exact URL is
# whitelisted below). The localhost entry allows local test servers.
ALLOWED_MODEL_SOURCES = (
    "https://civitai.com/",
    "https://huggingface.co/",
    "http://localhost:",
)
# File-name suffixes accepted for downloaded weights.
ALLOWED_MODEL_SUFFIXES = (".safetensors", ".sft")
# Exact URLs that are always allowed, bypassing the source/suffix checks above.
WHITELISTED_MODEL_URLS = {
    "https://huggingface.co/stabilityai/stable-zero123/resolve/main/stable_zero123.ckpt",
    "https://huggingface.co/TencentARC/T2I-Adapter/resolve/main/models/t2iadapter_depth_sd14v1.pth?download=true",
    "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
}
# Stream downloads in 1 MiB chunks.
DOWNLOAD_CHUNK_SIZE = 1024 * 1024
# Upper bound on entries accepted per download_missing request.
MAX_BULK_MODEL_DOWNLOADS = 200
class ModelFileManager:
    # NOTE(review): this is a diff rendering — the zero-arg constructor below is
    # the removed line, replaced by the one taking the owning PromptServer
    # (needed to reuse its shared aiohttp client_session).
    def __init__(self) -> None:
    def __init__(self, prompt_server) -> None:
        # folder name -> (file entries, per-path mtimes, scan timestamp)
        self.cache: dict[str, tuple[list[dict], dict[str, float], float]] = {}
        self.prompt_server = prompt_server

    def get_cache(self, key: str, default=None) -> tuple[list[dict], dict[str, float], float] | None:
        return self.cache.get(key, default)
@ -75,6 +93,144 @@ class ModelFileManager:
return web.Response(body=img_bytes.getvalue(), content_type="image/webp")
except:
return web.Response(status=404)
@routes.post("/experiment/models/download_missing")
async def download_missing_models(request: web.Request) -> web.Response:
    """Download a batch of missing model files from approved sources.

    Expects a JSON body ``{"models": [{"name", "directory", "url"}, ...]}``.
    Each entry is validated (shape, absolute http(s) URL, allowed
    source/suffix, resolvable destination) and downloaded unless it already
    exists on disk. Returns aggregate counters plus a per-model status list;
    per-entry failures never abort the whole batch.
    """
    try:
        payload = await request.json()
    except Exception:
        return web.json_response({"error": "Invalid JSON body"}, status=400)

    models = payload.get("models")
    if not isinstance(models, list):
        return web.json_response({"error": "Field 'models' must be a list"}, status=400)
    if len(models) > MAX_BULK_MODEL_DOWNLOADS:
        return web.json_response(
            {"error": f"Maximum of {MAX_BULK_MODEL_DOWNLOADS} models allowed per request"},
            status=400,
        )

    results = []
    downloaded = skipped = failed = 0

    # Reuse the server's shared HTTP session when available; otherwise create
    # a temporary one and close it in the finally block below. total=None
    # disables the client timeout, since large model downloads can be slow.
    session = self.prompt_server.client_session
    owns_session = False
    if session is None or session.closed:
        session = aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=None))
        owns_session = True

    try:
        for model_entry in models:
            model_name, model_directory, model_url = _normalize_model_entry(model_entry)
            # Common per-entry report; a status (and optional error) is added
            # on each exit path below.
            report = {
                "name": model_name or "",
                "directory": model_directory or "",
                "url": model_url or "",
            }

            if not model_name or not model_directory or not model_url:
                failed += 1
                report.update(
                    status="failed",
                    error="Each model must include non-empty name, directory, and url",
                )
                results.append(report)
                continue

            if not _is_http_url(model_url):
                failed += 1
                report.update(status="failed", error="URL must be an absolute HTTP/HTTPS URL")
                results.append(report)
                continue

            allowed, reason = _is_model_download_allowed(model_name, model_url)
            if not allowed:
                failed += 1
                report.update(status="blocked", error=reason)
                results.append(report)
                continue

            try:
                destination = _resolve_download_destination(model_directory, model_name)
            except Exception as exc:
                failed += 1
                report.update(status="failed", error=str(exc))
                results.append(report)
                continue

            # Never overwrite an existing file — report it as skipped.
            if os.path.exists(destination):
                skipped += 1
                report["status"] = "skipped_existing"
                results.append(report)
                continue

            try:
                await _download_file(session, model_url, destination)
            except Exception as exc:
                failed += 1
                report.update(status="failed", error=str(exc))
            else:
                downloaded += 1
                report["status"] = "downloaded"
            results.append(report)
    finally:
        if owns_session:
            await session.close()

    return web.json_response(
        {
            "downloaded": downloaded,
            "skipped": skipped,
            "failed": failed,
            "results": results,
        },
        status=200,
    )
def get_model_file_list(self, folder_name: str):
folder_name = map_legacy(folder_name)
@ -193,3 +349,75 @@ class ModelFileManager:
def __exit__(self, exc_type, exc_value, traceback):
    # Context-manager exit: drop all cached folder listings regardless of
    # whether an exception occurred (exceptions are not suppressed).
    self.clear_cache()
def _is_model_download_allowed(model_name: str, model_url: str) -> tuple[bool, str | None]:
    """Decide whether *model_url* may be fetched as *model_name*.

    Exactly whitelisted URLs are always permitted. Otherwise the URL must
    begin with one of the approved source prefixes and the file name must
    carry an approved weights suffix.

    Returns:
        (True, None) when allowed, otherwise (False, human-readable reason).
    """
    if model_url in WHITELISTED_MODEL_URLS:
        return True, None
    # str.startswith / str.endswith accept a tuple of candidates, which both
    # constants already are.
    if not model_url.startswith(ALLOWED_MODEL_SOURCES):
        return False, f"Download not allowed from source '{model_url}'."
    if not model_name.endswith(ALLOWED_MODEL_SUFFIXES):
        return False, f"Only allowed suffixes are: {', '.join(ALLOWED_MODEL_SUFFIXES)}"
    return True, None
def _resolve_download_destination(directory: str, model_name: str) -> str:
    """Map (directory, model_name) to an absolute path inside the configured
    model folder, rejecting names that would escape it.

    Raises:
        ValueError: for unknown directories, directories with no configured
            paths, empty names, or path-traversal attempts.

    Side effect: creates any missing parent directories of the destination.
    """
    if directory not in folder_paths.folder_names_and_paths:
        raise ValueError(f"Unknown model directory '{directory}'")
    model_paths = folder_paths.folder_names_and_paths[directory][0]
    if not model_paths:
        raise ValueError(f"No filesystem paths configured for '{directory}'")
    # Downloads always land under the first configured path for this folder type.
    base_path = os.path.abspath(model_paths[0])
    # normpath collapses "a/../b" segments; stripping leading separators keeps
    # the join below relative to base_path instead of being absolute.
    normalized_name = os.path.normpath(model_name).lstrip("/\\")
    if not normalized_name or normalized_name == ".":
        raise ValueError("Model name cannot be empty")
    destination = os.path.abspath(os.path.join(base_path, normalized_name))
    # Traversal guard: the fully-resolved destination must remain under base_path.
    if os.path.commonpath((base_path, destination)) != base_path:
        raise ValueError("Model path escapes configured model directory")
    destination_parent = os.path.dirname(destination)
    if destination_parent:
        os.makedirs(destination_parent, exist_ok=True)
    return destination
async def _download_file(session: aiohttp.ClientSession, url: str, destination: str) -> None:
    """Stream *url* into *destination*, atomically via a unique ``.part`` file.

    The response is written chunk-by-chunk to a temp file next to the target
    and renamed into place only on success, so a partial download never
    appears under the final name. HTTP errors raise via raise_for_status.
    """
    partial_path = f"{destination}.part-{uuid.uuid4().hex}"
    try:
        async with session.get(url, allow_redirects=True) as response:
            response.raise_for_status()
            with open(partial_path, "wb") as out:
                async for piece in response.content.iter_chunked(DOWNLOAD_CHUNK_SIZE):
                    if piece:
                        out.write(piece)
        os.replace(partial_path, destination)
    finally:
        # After a successful rename the temp file no longer exists; on any
        # failure path it does, and is removed here.
        if os.path.exists(partial_path):
            os.remove(partial_path)
def _normalize_model_entry(model_entry: object) -> tuple[str | None, str | None, str | None]:
if not isinstance(model_entry, dict):
return None, None, None
model_name = str(model_entry.get("name", "")).strip()
model_directory = str(model_entry.get("directory", "")).strip()
model_url = str(model_entry.get("url", "")).strip()
return model_name, model_directory, model_url
def _is_http_url(url: str) -> bool:
parsed = urlparse(url)
return parsed.scheme in ("http", "https") and bool(parsed.netloc)

View File

@ -202,7 +202,7 @@ class PromptServer():
mimetypes.add_type('image/webp', '.webp')
self.user_manager = UserManager()
self.model_file_manager = ModelFileManager()
self.model_file_manager = ModelFileManager(self)
self.custom_node_manager = CustomNodeManager()
self.subgraph_manager = SubgraphManager()
self.node_replace_manager = NodeReplaceManager()

View File

@ -12,9 +12,14 @@ pytestmark = (
pytest.mark.asyncio
) # This applies the asyncio mark to all test functions in the module
class DummyPromptServer:
    """Minimal PromptServer stand-in for tests.

    Exposes only ``client_session`` (left unset), the one attribute
    ModelFileManager reads from its prompt_server.
    """

    def __init__(self):
        self.client_session = None
@pytest.fixture
def model_manager():
    # NOTE(review): diff rendering — the bare-constructor return below is the
    # removed line; the fixture now wires the manager to a DummyPromptServer.
    return ModelFileManager()
    prompt_server = DummyPromptServer()
    return ModelFileManager(prompt_server)
@pytest.fixture
def app(model_manager):