Add server-side missing model downloads

This commit is contained in:
adv0r 2026-05-18 15:29:15 +02:00
parent d4c6c9eff8
commit f9eac7477a
3 changed files with 366 additions and 0 deletions

View File

@ -2,7 +2,9 @@ from aiohttp import web
from typing import Optional
from folder_paths import folder_names_and_paths, get_directory_by_type
from api_server.services.terminal_service import TerminalService
from app.model_download import ModelDownloadError, download_model_to_destination, parse_model_download_request, resolve_model_download_destination
import app.logger
import logging
import os
class InternalRoutes:
@ -51,6 +53,32 @@ class InternalRoutes:
response[key] = folder_names_and_paths[key][0]
return web.json_response(response)
@self.routes.post('/models/download')
async def download_model(request):
try:
try:
json_data = await request.json()
except Exception as err:
raise ModelDownloadError("Expected a JSON request body.") from err
download_request = parse_model_download_request(json_data)
destination = resolve_model_download_destination(download_request)
if self.prompt_server.client_session is None:
raise ModelDownloadError("HTTP client session is not ready.", status=503)
result = await download_model_to_destination(
self.prompt_server.client_session,
download_request,
destination,
)
except ModelDownloadError as err:
return web.json_response({"error": str(err)}, status=err.status)
except Exception:
logging.exception("Failed to download model")
return web.json_response({"error": "Failed to download model."}, status=500)
response_status = 200 if result["status"] == "already_exists" else 201
return web.json_response(result, status=response_status)
@self.routes.get('/files/{directory_type}')
async def get_files(request: web.Request) -> web.Response:
directory_type = request.match_info['directory_type']

224
app/model_download.py Normal file
View File

@ -0,0 +1,224 @@
from __future__ import annotations
import os
import posixpath
from dataclasses import dataclass
from pathlib import PurePosixPath
from urllib.parse import urlparse
from aiohttp import ClientSession
import folder_paths
ALLOWED_DOWNLOAD_HOSTS = {"huggingface.co", "civitai.com", "civitai.red"}
ALLOWED_DOWNLOAD_SUFFIXES = (".safetensors", ".sft", ".ckpt", ".pth", ".pt")
BLOCKED_MODEL_FOLDERS = {"configs", "custom_nodes"}
CHUNK_SIZE = 1024 * 1024
WHITE_LISTED_DOWNLOAD_URLS = {
"https://huggingface.co/stabilityai/stable-zero123/resolve/main/stable_zero123.ckpt",
"https://huggingface.co/TencentARC/T2I-Adapter/resolve/main/models/t2iadapter_depth_sd14v1.pth?download=true",
"https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
}
@dataclass(frozen=True)
class ModelDownloadRequest:
name: str
url: str
directory: str
@dataclass(frozen=True)
class ModelDownloadDestination:
directory: str
relative_path: str
full_path: str
already_exists: bool = False
class ModelDownloadError(Exception):
def __init__(self, message: str, status: int = 400):
super().__init__(message)
self.status = status
def parse_model_download_request(data) -> ModelDownloadRequest:
if not isinstance(data, dict):
raise ModelDownloadError("Expected a JSON object.")
name = data.get("name")
url = data.get("url")
directory = data.get("directory")
if not isinstance(name, str) or not isinstance(url, str) or not isinstance(directory, str):
raise ModelDownloadError("Missing model name, URL, or directory.")
name = name.strip()
url = url.strip()
directory = directory.strip()
if not name or not url or not directory:
raise ModelDownloadError("Model name, URL, and directory are required.")
if not is_allowed_model_download_url(url):
raise ModelDownloadError("Model download URL is not allowed.")
relative_path = normalize_model_relative_path(name)
if not relative_path.lower().endswith(ALLOWED_DOWNLOAD_SUFFIXES):
raise ModelDownloadError("Model filename extension is not allowed.")
return ModelDownloadRequest(name=relative_path, url=url, directory=directory)
def is_allowed_model_download_url(url: str) -> bool:
if url in WHITE_LISTED_DOWNLOAD_URLS:
return True
try:
parsed = urlparse(url)
except ValueError:
return False
if parsed.scheme != "https":
return False
return (parsed.hostname or "").lower() in ALLOWED_DOWNLOAD_HOSTS
def normalize_model_relative_path(name: str) -> str:
if "\x00" in name:
raise ModelDownloadError("Model filename is invalid.")
candidate = name.replace("\\", "/")
if candidate.startswith("/"):
raise ModelDownloadError("Model filename must be relative.")
parts = PurePosixPath(candidate).parts
if not parts or any(part in ("", ".", "..") for part in parts):
raise ModelDownloadError("Model filename must stay inside the model folder.")
normalized = posixpath.normpath(candidate)
if normalized in ("", ".") or normalized.startswith("../"):
raise ModelDownloadError("Model filename must stay inside the model folder.")
return normalized
def resolve_model_download_destination(request: ModelDownloadRequest) -> ModelDownloadDestination:
directory = folder_paths.map_legacy(request.directory)
if directory in BLOCKED_MODEL_FOLDERS or directory not in folder_paths.folder_names_and_paths:
raise ModelDownloadError("Model directory is not allowed.", status=404)
existing_path = folder_paths.get_full_path(directory, request.name)
if existing_path is not None:
return ModelDownloadDestination(
directory=directory,
relative_path=request.name,
full_path=existing_path,
already_exists=True,
)
destination_root = find_writable_model_root(directory)
full_path = safe_join(destination_root, request.name)
return ModelDownloadDestination(
directory=directory,
relative_path=request.name,
full_path=full_path,
)
def find_writable_model_root(directory: str) -> str:
for root in folder_paths.get_folder_paths(directory):
try:
os.makedirs(root, exist_ok=True)
except OSError:
continue
if is_writable_directory(root):
return os.path.abspath(root)
raise ModelDownloadError("No writable model folder is configured.", status=403)
def is_writable_directory(path: str) -> bool:
probe = os.path.join(path, ".comfy-download-write-test")
try:
with open(probe, "xb"):
pass
os.remove(probe)
return True
except OSError:
try:
if os.path.exists(probe):
os.remove(probe)
except OSError:
pass
return False
def safe_join(root: str, relative_path: str) -> str:
root = os.path.abspath(root)
full_path = os.path.abspath(os.path.join(root, relative_path))
if os.path.commonpath((root, full_path)) != root:
raise ModelDownloadError("Model filename must stay inside the model folder.")
return full_path
async def download_model_to_destination(
session: ClientSession,
request: ModelDownloadRequest,
destination: ModelDownloadDestination,
) -> dict:
if destination.already_exists:
return download_response("already_exists", request, destination)
os.makedirs(os.path.dirname(destination.full_path), exist_ok=True)
partial_path = f"{destination.full_path}.part"
try:
fd = os.open(partial_path, os.O_WRONLY | os.O_CREAT | os.O_EXCL, 0o644)
except FileExistsError as err:
raise ModelDownloadError("A download for this model is already in progress.", status=409) from err
bytes_written = 0
try:
with os.fdopen(fd, "wb") as output:
async with session.get(request.url) as response:
if response.status >= 400:
raise ModelDownloadError(f"Model download failed with HTTP {response.status}.", status=502)
expected_size = response.content_length
async for chunk in response.content.iter_chunked(CHUNK_SIZE):
output.write(chunk)
bytes_written += len(chunk)
if expected_size is not None and bytes_written != expected_size:
raise ModelDownloadError("Model download ended before all bytes were received.", status=502)
os.replace(partial_path, destination.full_path)
except Exception:
try:
if os.path.exists(partial_path):
os.remove(partial_path)
except OSError:
pass
raise
return download_response("downloaded", request, destination, bytes_written)
def download_response(
status: str,
request: ModelDownloadRequest,
destination: ModelDownloadDestination,
size: int | None = None,
) -> dict:
response = {
"status": status,
"name": request.name,
"directory": destination.directory,
"path": destination.full_path,
}
if size is not None:
response["size"] = size
return response

View File

@ -0,0 +1,114 @@
import os
import pytest
import folder_paths
from app.model_download import (
ModelDownloadError,
ModelDownloadRequest,
is_allowed_model_download_url,
normalize_model_relative_path,
parse_model_download_request,
resolve_model_download_destination,
)
def test_parse_model_download_request_allows_huggingface_model_url():
request = parse_model_download_request({
"name": "nested/model.safetensors",
"url": "https://huggingface.co/org/repo/resolve/main/model.safetensors",
"directory": "checkpoints",
})
assert request == ModelDownloadRequest(
name="nested/model.safetensors",
url="https://huggingface.co/org/repo/resolve/main/model.safetensors",
directory="checkpoints",
)
@pytest.mark.parametrize(
"url",
[
"http://localhost:8000/model.safetensors",
"http://huggingface.co/org/repo/resolve/main/model.safetensors",
"https://example.com/model.safetensors",
],
)
def test_download_url_allowlist_rejects_untrusted_or_plain_http_urls(url):
assert is_allowed_model_download_url(url) is False
@pytest.mark.parametrize(
"name",
[
"../model.safetensors",
"nested/../../model.safetensors",
"/absolute/model.safetensors",
"model.safetensors\x00",
],
)
def test_normalize_model_relative_path_rejects_unsafe_paths(name):
with pytest.raises(ModelDownloadError):
normalize_model_relative_path(name)
def test_parse_model_download_request_rejects_unsupported_extensions():
with pytest.raises(ModelDownloadError):
parse_model_download_request({
"name": "model.gguf",
"url": "https://huggingface.co/org/repo/resolve/main/model.gguf",
"directory": "checkpoints",
})
def test_resolve_model_download_destination_uses_configured_model_folder(tmp_path, monkeypatch):
model_root = tmp_path / "models" / "checkpoints"
monkeypatch.setattr(folder_paths, "folder_names_and_paths", {
"checkpoints": ([str(model_root)], {".safetensors"}),
})
destination = resolve_model_download_destination(ModelDownloadRequest(
name="sub/model.safetensors",
url="https://huggingface.co/org/repo/resolve/main/model.safetensors",
directory="checkpoints",
))
assert destination.directory == "checkpoints"
assert destination.relative_path == "sub/model.safetensors"
assert destination.full_path == os.path.join(str(model_root), "sub", "model.safetensors")
assert destination.already_exists is False
def test_resolve_model_download_destination_reuses_existing_model(tmp_path, monkeypatch):
model_root = tmp_path / "models" / "checkpoints"
model_root.mkdir(parents=True)
existing = model_root / "model.safetensors"
existing.write_bytes(b"model")
monkeypatch.setattr(folder_paths, "folder_names_and_paths", {
"checkpoints": ([str(model_root)], {".safetensors"}),
})
destination = resolve_model_download_destination(ModelDownloadRequest(
name="model.safetensors",
url="https://huggingface.co/org/repo/resolve/main/model.safetensors",
directory="checkpoints",
))
assert destination.full_path == str(existing)
assert destination.already_exists is True
@pytest.mark.parametrize("directory", ["configs", "custom_nodes", "unknown"])
def test_resolve_model_download_destination_rejects_blocked_or_unknown_directories(tmp_path, monkeypatch, directory):
monkeypatch.setattr(folder_paths, "folder_names_and_paths", {
"configs": ([str(tmp_path / "configs")], {".yaml"}),
"custom_nodes": ([str(tmp_path / "custom_nodes")], set()),
})
with pytest.raises(ModelDownloadError):
resolve_model_download_destination(ModelDownloadRequest(
name="model.safetensors",
url="https://huggingface.co/org/repo/resolve/main/model.safetensors",
directory=directory,
))