add route to automatically download missing models

This commit is contained in:
teddav 2026-02-16 12:14:56 +01:00
parent 88e6370527
commit f075b35fd1
3 changed files with 236 additions and 3 deletions

View File

@ -8,15 +8,33 @@ import logging
import folder_paths import folder_paths
import glob import glob
import comfy.utils import comfy.utils
import uuid
from urllib.parse import urlparse
import aiohttp
from aiohttp import web from aiohttp import web
from PIL import Image from PIL import Image
from io import BytesIO from io import BytesIO
from folder_paths import map_legacy, filter_files_extensions, filter_files_content_types from folder_paths import map_legacy, filter_files_extensions, filter_files_content_types
ALLOWED_MODEL_SOURCES = (
"https://civitai.com/",
"https://huggingface.co/",
"http://localhost:",
)
ALLOWED_MODEL_SUFFIXES = (".safetensors", ".sft")
WHITELISTED_MODEL_URLS = {
"https://huggingface.co/stabilityai/stable-zero123/resolve/main/stable_zero123.ckpt",
"https://huggingface.co/TencentARC/T2I-Adapter/resolve/main/models/t2iadapter_depth_sd14v1.pth?download=true",
"https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
}
DOWNLOAD_CHUNK_SIZE = 1024 * 1024
MAX_BULK_MODEL_DOWNLOADS = 200
class ModelFileManager: class ModelFileManager:
def __init__(self) -> None: def __init__(self, prompt_server) -> None:
self.cache: dict[str, tuple[list[dict], dict[str, float], float]] = {} self.cache: dict[str, tuple[list[dict], dict[str, float], float]] = {}
self.prompt_server = prompt_server
def get_cache(self, key: str, default=None) -> tuple[list[dict], dict[str, float], float] | None: def get_cache(self, key: str, default=None) -> tuple[list[dict], dict[str, float], float] | None:
return self.cache.get(key, default) return self.cache.get(key, default)
@ -75,6 +93,144 @@ class ModelFileManager:
return web.Response(body=img_bytes.getvalue(), content_type="image/webp") return web.Response(body=img_bytes.getvalue(), content_type="image/webp")
except: except:
return web.Response(status=404) return web.Response(status=404)
@routes.post("/experiment/models/download_missing")
async def download_missing_models(request: web.Request) -> web.Response:
print("download_missing_models")
try:
payload = await request.json()
except Exception:
return web.json_response({"error": "Invalid JSON body"}, status=400)
print("download_missing_models")
print(payload)
models = payload.get("models")
if not isinstance(models, list):
return web.json_response({"error": "Field 'models' must be a list"}, status=400)
if len(models) > MAX_BULK_MODEL_DOWNLOADS:
return web.json_response(
{"error": f"Maximum of {MAX_BULK_MODEL_DOWNLOADS} models allowed per request"},
status=400,
)
results = []
downloaded = 0
skipped = 0
failed = 0
session = self.prompt_server.client_session
owns_session = False
if session is None or session.closed:
timeout = aiohttp.ClientTimeout(total=None)
session = aiohttp.ClientSession(timeout=timeout)
owns_session = True
try:
for model_entry in models:
model_name, model_directory, model_url = _normalize_model_entry(model_entry)
if not model_name or not model_directory or not model_url:
failed += 1
results.append(
{
"name": model_name or "",
"directory": model_directory or "",
"url": model_url or "",
"status": "failed",
"error": "Each model must include non-empty name, directory, and url",
}
)
continue
if not _is_http_url(model_url):
failed += 1
results.append(
{
"name": model_name,
"directory": model_directory,
"url": model_url,
"status": "failed",
"error": "URL must be an absolute HTTP/HTTPS URL",
}
)
continue
allowed, reason = _is_model_download_allowed(model_name, model_url)
if not allowed:
failed += 1
results.append(
{
"name": model_name,
"directory": model_directory,
"url": model_url,
"status": "blocked",
"error": reason,
}
)
continue
try:
destination = _resolve_download_destination(model_directory, model_name)
except Exception as exc:
failed += 1
results.append(
{
"name": model_name,
"directory": model_directory,
"url": model_url,
"status": "failed",
"error": str(exc),
}
)
continue
if os.path.exists(destination):
skipped += 1
results.append(
{
"name": model_name,
"directory": model_directory,
"url": model_url,
"status": "skipped_existing",
}
)
continue
try:
await _download_file(session, model_url, destination)
downloaded += 1
results.append(
{
"name": model_name,
"directory": model_directory,
"url": model_url,
"status": "downloaded",
}
)
except Exception as exc:
failed += 1
results.append(
{
"name": model_name,
"directory": model_directory,
"url": model_url,
"status": "failed",
"error": str(exc),
}
)
finally:
if owns_session:
await session.close()
return web.json_response(
{
"downloaded": downloaded,
"skipped": skipped,
"failed": failed,
"results": results,
},
status=200,
)
def get_model_file_list(self, folder_name: str): def get_model_file_list(self, folder_name: str):
folder_name = map_legacy(folder_name) folder_name = map_legacy(folder_name)
@ -193,3 +349,75 @@ class ModelFileManager:
def __exit__(self, exc_type, exc_value, traceback): def __exit__(self, exc_type, exc_value, traceback):
self.clear_cache() self.clear_cache()
def _is_model_download_allowed(model_name: str, model_url: str) -> tuple[bool, str | None]:
if model_url in WHITELISTED_MODEL_URLS:
return True, None
if not any(model_url.startswith(source) for source in ALLOWED_MODEL_SOURCES):
return (
False,
f"Download not allowed from source '{model_url}'.",
)
if not any(model_name.endswith(suffix) for suffix in ALLOWED_MODEL_SUFFIXES):
return (
False,
f"Only allowed suffixes are: {', '.join(ALLOWED_MODEL_SUFFIXES)}",
)
return True, None
def _resolve_download_destination(directory: str, model_name: str) -> str:
if directory not in folder_paths.folder_names_and_paths:
raise ValueError(f"Unknown model directory '{directory}'")
model_paths = folder_paths.folder_names_and_paths[directory][0]
if not model_paths:
raise ValueError(f"No filesystem paths configured for '{directory}'")
base_path = os.path.abspath(model_paths[0])
normalized_name = os.path.normpath(model_name).lstrip("/\\")
if not normalized_name or normalized_name == ".":
raise ValueError("Model name cannot be empty")
destination = os.path.abspath(os.path.join(base_path, normalized_name))
if os.path.commonpath((base_path, destination)) != base_path:
raise ValueError("Model path escapes configured model directory")
destination_parent = os.path.dirname(destination)
if destination_parent:
os.makedirs(destination_parent, exist_ok=True)
return destination
async def _download_file(session: aiohttp.ClientSession, url: str, destination: str) -> None:
temp_file = f"{destination}.part-{uuid.uuid4().hex}"
try:
async with session.get(url, allow_redirects=True) as response:
response.raise_for_status()
with open(temp_file, "wb") as file_handle:
async for chunk in response.content.iter_chunked(DOWNLOAD_CHUNK_SIZE):
if chunk:
file_handle.write(chunk)
os.replace(temp_file, destination)
finally:
if os.path.exists(temp_file):
os.remove(temp_file)
def _normalize_model_entry(model_entry: object) -> tuple[str | None, str | None, str | None]:
if not isinstance(model_entry, dict):
return None, None, None
model_name = str(model_entry.get("name", "")).strip()
model_directory = str(model_entry.get("directory", "")).strip()
model_url = str(model_entry.get("url", "")).strip()
return model_name, model_directory, model_url
def _is_http_url(url: str) -> bool:
parsed = urlparse(url)
return parsed.scheme in ("http", "https") and bool(parsed.netloc)

View File

@ -202,7 +202,7 @@ class PromptServer():
mimetypes.add_type('image/webp', '.webp') mimetypes.add_type('image/webp', '.webp')
self.user_manager = UserManager() self.user_manager = UserManager()
self.model_file_manager = ModelFileManager() self.model_file_manager = ModelFileManager(self)
self.custom_node_manager = CustomNodeManager() self.custom_node_manager = CustomNodeManager()
self.subgraph_manager = SubgraphManager() self.subgraph_manager = SubgraphManager()
self.node_replace_manager = NodeReplaceManager() self.node_replace_manager = NodeReplaceManager()

View File

@ -12,9 +12,14 @@ pytestmark = (
pytest.mark.asyncio pytest.mark.asyncio
) # This applies the asyncio mark to all test functions in the module ) # This applies the asyncio mark to all test functions in the module
class DummyPromptServer:
def __init__(self):
self.client_session = None
@pytest.fixture @pytest.fixture
def model_manager(): def model_manager():
return ModelFileManager() prompt_server = DummyPromptServer()
return ModelFileManager(prompt_server)
@pytest.fixture @pytest.fixture
def app(model_manager): def app(model_manager):