This commit is contained in:
Nicolò Paternoster 2026-05-19 09:27:05 +00:00 committed by GitHub
commit 2f54dd88cc
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 553 additions and 0 deletions

View File

@ -2,7 +2,9 @@ from aiohttp import web
from typing import Optional
from folder_paths import folder_names_and_paths, get_directory_by_type
from api_server.services.terminal_service import TerminalService
from app.model_download import ModelDownloadError, download_model_to_destination, parse_model_download_request, resolve_model_download_destination
import app.logger
import logging
import os
class InternalRoutes:
@ -51,6 +53,32 @@ class InternalRoutes:
response[key] = folder_names_and_paths[key][0]
return web.json_response(response)
@self.routes.post('/models/download')
async def download_model(request):
    """Download a model described by the JSON request body.

    Responds with 200 when the model was already present, 201 after a
    fresh download, the status carried by ``ModelDownloadError`` for
    expected failures, and 500 for anything unexpected.
    """
    try:
        # Any failure to decode the body is reported as a client error.
        try:
            payload = await request.json()
        except Exception as err:
            raise ModelDownloadError("Expected a JSON request body.") from err

        download_request = parse_model_download_request(payload)
        destination = resolve_model_download_destination(download_request)

        session = self.prompt_server.client_session
        if session is None:
            raise ModelDownloadError("HTTP client session is not ready.", status=503)

        result = await download_model_to_destination(session, download_request, destination)
    except ModelDownloadError as err:
        return web.json_response({"error": str(err)}, status=err.status)
    except Exception as err:
        # Unexpected failure: log the details, return a generic message.
        logging.exception("Failed to download model: %s", err)
        return web.json_response({"error": "Failed to download model."}, status=500)

    if result["status"] == "already_exists":
        response_status = 200
    else:
        response_status = 201
    return web.json_response(result, status=response_status)
@self.routes.get('/files/{directory_type}')
async def get_files(request: web.Request) -> web.Response:
directory_type = request.match_info['directory_type']

281
app/model_download.py Normal file
View File

@ -0,0 +1,281 @@
from __future__ import annotations
import os
import posixpath
from dataclasses import dataclass
from pathlib import PurePosixPath
from urllib.parse import urljoin, urlparse
from aiohttp import ClientSession, ClientTimeout
import folder_paths
# Hosts that model files may be fetched from. is_allowed_model_download_url
# also accepts any subdomain of these hosts.
ALLOWED_DOWNLOAD_HOSTS = {"huggingface.co", "civitai.com", "civitai.red"}
# Extensions a requested model name must end with (compared case-insensitively
# by parse_model_download_request).
ALLOWED_DOWNLOAD_SUFFIXES = (".safetensors", ".sft", ".ckpt", ".pth", ".pt")
# Folder keys that are never accepted as a download destination.
BLOCKED_MODEL_FOLDERS = {"configs", "custom_nodes"}
# Number of bytes requested per read while streaming a response body (1 MiB).
CHUNK_SIZE = 1024 * 1024
# Bound the network call so a hung remote eventually surfaces an error
# instead of blocking the request handler forever. ``sock_read`` is the
# inter-chunk read timeout, which is the right knob for long downloads:
# a slow-but-progressing transfer is never cut off by a total deadline,
# while a stalled socket fails predictably.
DOWNLOAD_TIMEOUT = ClientTimeout(total=None, connect=30, sock_connect=30, sock_read=300)
# Maximum number of redirects we follow manually. Hugging Face typically
# redirects ``/resolve/main/...`` to a single CDN URL, so a small budget
# is enough while still preventing redirect loops.
MAX_DOWNLOAD_REDIRECTS = 5
# Exact URLs accepted by is_allowed_model_download_url before any scheme or
# host check is applied.
# NOTE(review): the github.com entry is exempted only by exact string match.
# If that URL answers with a redirect to a host outside
# ALLOWED_DOWNLOAD_HOSTS, the redirect validation in
# open_model_download_response will reject it - confirm this entry can
# actually be downloaded.
WHITE_LISTED_DOWNLOAD_URLS = {
    "https://huggingface.co/stabilityai/stable-zero123/resolve/main/stable_zero123.ckpt",
    "https://huggingface.co/TencentARC/T2I-Adapter/resolve/main/models/t2iadapter_depth_sd14v1.pth?download=true",
    "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
}
@dataclass(frozen=True)
class ModelDownloadRequest:
    """Immutable description of one requested model download."""

    # Relative model path; parse_model_download_request stores the normalized
    # POSIX-style form here.
    name: str
    # Source URL of the model file.
    url: str
    # Model folder key the file should be stored under.
    directory: str
@dataclass(frozen=True)
class ModelDownloadDestination:
    """Immutable result of resolving where a requested model lives on disk."""

    # Folder key after legacy-name mapping.
    directory: str
    # Path of the model relative to its model folder.
    relative_path: str
    # Path of the existing file, or the path the download will be written to.
    full_path: str
    # True when the model was already found on disk and nothing must be fetched.
    already_exists: bool = False
class ModelDownloadError(Exception):
    """Download failure that carries the HTTP status to report to the client.

    ``status`` defaults to 400 (bad request) when the caller does not
    supply a more specific code.
    """

    def __init__(self, message: str, status: int = 400):
        self.status = status
        super().__init__(message)
def parse_model_download_request(data) -> ModelDownloadRequest:
    """Validate a decoded JSON payload and turn it into a request object.

    The payload must be a dict whose ``name``, ``url`` and ``directory``
    entries are non-blank strings. The URL must pass the download
    allowlist and the normalized name must end with an allowed extension.

    Raises:
        ModelDownloadError: if any of the checks above fails.
    """
    if not isinstance(data, dict):
        raise ModelDownloadError("Expected a JSON object.")

    raw_fields = (data.get("name"), data.get("url"), data.get("directory"))
    if not all(isinstance(value, str) for value in raw_fields):
        raise ModelDownloadError("Missing model name, URL, or directory.")

    name, url, directory = (value.strip() for value in raw_fields)
    if not (name and url and directory):
        raise ModelDownloadError("Model name, URL, and directory are required.")

    if not is_allowed_model_download_url(url):
        raise ModelDownloadError("Model download URL is not allowed.")

    relative_path = normalize_model_relative_path(name)
    has_allowed_suffix = relative_path.lower().endswith(ALLOWED_DOWNLOAD_SUFFIXES)
    if not has_allowed_suffix:
        raise ModelDownloadError("Model filename extension is not allowed.")

    return ModelDownloadRequest(name=relative_path, url=url, directory=directory)
def is_allowed_model_download_url(url: str) -> bool:
    """Tell whether ``url`` may be fetched on behalf of the user.

    A URL is accepted when it is one of the exact whitelisted URLs, or
    when it uses ``https`` and its host is an allowlisted host or a
    subdomain of one. Subdomains are accepted because Hugging Face and
    Civitai serve file payloads from CDN subdomains
    (e.g. ``cdn-lfs.huggingface.co``).

    The same predicate is applied to the user-supplied URL and to every
    redirect target, so redirects cannot leave the allowlist.
    """
    if url in WHITE_LISTED_DOWNLOAD_URLS:
        return True

    try:
        pieces = urlparse(url)
    except ValueError:
        return False

    if pieces.scheme != "https":
        return False

    host = (pieces.hostname or "").lower()
    if not host:
        return False

    # Exact host match, or a dot-separated subdomain of an allowed host.
    return any(
        host == allowed or host.endswith("." + allowed)
        for allowed in ALLOWED_DOWNLOAD_HOSTS
    )
def normalize_model_relative_path(name: str) -> str:
    """Return ``name`` as a normalized POSIX-style relative path.

    Backslashes are treated as path separators. NUL bytes, absolute
    paths and any ``..`` component are rejected so the result cannot
    point outside the model folder it is later joined to.

    Raises:
        ModelDownloadError: if the name is invalid, absolute, or escapes
            the model folder.
    """
    if "\x00" in name:
        raise ModelDownloadError("Model filename is invalid.")

    posix_name = name.replace("\\", "/")
    if posix_name.startswith("/"):
        raise ModelDownloadError("Model filename must be relative.")

    segments = PurePosixPath(posix_name).parts
    if not segments or not {"", ".", ".."}.isdisjoint(segments):
        raise ModelDownloadError("Model filename must stay inside the model folder.")

    collapsed = posixpath.normpath(posix_name)
    if collapsed in ("", ".") or collapsed.startswith("../"):
        raise ModelDownloadError("Model filename must stay inside the model folder.")
    return collapsed
def resolve_model_download_destination(request: ModelDownloadRequest) -> ModelDownloadDestination:
    """Work out where the requested model lives or should be written.

    If the model is already found in the mapped folder, the returned
    destination points at the existing file and is flagged
    ``already_exists``. Otherwise the destination is the requested
    relative path joined onto the first writable root of that folder.

    Raises:
        ModelDownloadError: with status 404 when the folder is blocked or
            unknown; also propagated from root lookup and path joining.
    """
    folder_key = folder_paths.map_legacy(request.directory)
    if folder_key in BLOCKED_MODEL_FOLDERS or folder_key not in folder_paths.folder_names_and_paths:
        raise ModelDownloadError("Model directory is not allowed.", status=404)

    found_path = folder_paths.get_full_path(folder_key, request.name)
    if found_path is None:
        writable_root = find_writable_model_root(folder_key)
        return ModelDownloadDestination(
            directory=folder_key,
            relative_path=request.name,
            full_path=safe_join(writable_root, request.name),
        )

    return ModelDownloadDestination(
        directory=folder_key,
        relative_path=request.name,
        full_path=found_path,
        already_exists=True,
    )
def find_writable_model_root(directory: str) -> str:
    """Return the absolute path of the first writable root of ``directory``.

    Each configured root is created if missing. Roots that cannot be
    created or do not pass the write probe are skipped.

    Raises:
        ModelDownloadError: with status 403 when no root is writable.
    """
    for candidate in folder_paths.get_folder_paths(directory):
        try:
            os.makedirs(candidate, exist_ok=True)
        except OSError:
            continue
        if not is_writable_directory(candidate):
            continue
        return os.path.abspath(candidate)
    raise ModelDownloadError("No writable model folder is configured.", status=403)
def is_writable_directory(path: str) -> bool:
    """Probe ``path`` by exclusively creating and deleting a marker file.

    Returns True only when both the creation and the removal succeed.
    On any ``OSError`` a leftover marker is removed on a best-effort
    basis and False is returned.
    """
    marker = os.path.join(path, ".comfy-download-write-test")
    try:
        with open(marker, "xb"):
            pass
        os.remove(marker)
    except OSError:
        # Best-effort cleanup; a failure here must not mask the result.
        try:
            if os.path.exists(marker):
                os.remove(marker)
        except OSError:
            pass
        return False
    else:
        return True
def safe_join(root: str, relative_path: str) -> str:
    """Join ``relative_path`` onto ``root`` and refuse results outside ``root``.

    Both paths are made absolute before they are compared.

    Returns:
        The absolute joined path.

    Raises:
        ModelDownloadError: if the joined path is not located under ``root``.
    """
    root = os.path.abspath(root)
    full_path = os.path.abspath(os.path.join(root, relative_path))
    try:
        is_inside = os.path.commonpath((root, full_path)) == root
    except ValueError:
        # commonpath raises ValueError when the two paths cannot share a
        # prefix at all (e.g. different drives on Windows). That is an
        # escape attempt and must be reported like one, not as a crash.
        is_inside = False
    if not is_inside:
        raise ModelDownloadError("Model filename must stay inside the model folder.")
    return full_path
async def open_model_download_response(session: ClientSession, url: str):
    """Issue a GET for ``url``, following redirects by hand.

    Automatic redirect following is switched off so that every
    ``Location`` target can be checked against the same allowlist as the
    initial URL. Otherwise an allowed host could redirect the request to
    an arbitrary internal target (SSRF).

    Returns:
        The first response whose status is not a redirect; the caller
        owns it and must release it.

    Raises:
        ModelDownloadError: 502 for a redirect without ``Location`` or too
            many redirects, 403 for a redirect target outside the allowlist.
    """
    redirect_statuses = (301, 302, 303, 307, 308)
    target = url
    attempts_left = MAX_DOWNLOAD_REDIRECTS + 1
    while attempts_left > 0:
        attempts_left -= 1
        response = await session.get(
            target,
            allow_redirects=False,
            timeout=DOWNLOAD_TIMEOUT,
        )
        if response.status not in redirect_statuses:
            return response

        # Read the header first, then give the connection back before validating.
        location = response.headers.get("Location", "").strip()
        response.release()
        if not location:
            raise ModelDownloadError("Redirect response missing Location header.", status=502)

        # Location may be relative; resolve it against the URL just requested.
        candidate = urljoin(target, location)
        if not is_allowed_model_download_url(candidate):
            raise ModelDownloadError("Model download redirect target is not allowed.", status=403)
        target = candidate

    raise ModelDownloadError("Too many redirects while downloading model.", status=502)
async def download_model_to_destination(
    session: ClientSession,
    request: ModelDownloadRequest,
    destination: ModelDownloadDestination,
) -> dict:
    """Stream the requested model into ``destination.full_path``.

    The body is written to ``<full_path>.part``, which is created
    exclusively so a second request for the same model is refused. The
    file is renamed into place only after the whole body has arrived.

    Returns:
        The ``download_response`` payload: status ``"already_exists"``
        when nothing had to be fetched, otherwise ``"downloaded"`` with
        the number of bytes written.

    Raises:
        ModelDownloadError: 409 if a ``.part`` file already exists, 502 if
            the remote answers with an HTTP error or the body is shorter or
            longer than announced; also propagated from the redirect handling.
    """
    if destination.already_exists:
        return download_response("already_exists", request, destination)

    os.makedirs(os.path.dirname(destination.full_path), exist_ok=True)
    partial_path = f"{destination.full_path}.part"
    try:
        # O_EXCL makes creation atomic: the ``.part`` file doubles as a lock.
        fd = os.open(partial_path, os.O_WRONLY | os.O_CREAT | os.O_EXCL, 0o644)
    except FileExistsError as err:
        raise ModelDownloadError("A download for this model is already in progress.", status=409) from err

    bytes_written = 0
    try:
        with os.fdopen(fd, "wb") as output:
            async with await open_model_download_response(session, request.url) as response:
                if response.status >= 400:
                    raise ModelDownloadError(f"Model download failed with HTTP {response.status}.", status=502)
                expected_size = response.content_length
                # Content-Length counts the encoded bytes on the wire, while
                # the chunks read below are already decoded by the client.
                # The two can only be compared for an unencoded body.
                encoding = response.headers.get("Content-Encoding", "").strip().lower()
                if encoding not in ("", "identity"):
                    expected_size = None
                # NOTE(review): output.write is a blocking call made from the
                # event loop - consider offloading it for very large models.
                async for chunk in response.content.iter_chunked(CHUNK_SIZE):
                    output.write(chunk)
                    bytes_written += len(chunk)
        if expected_size is not None and bytes_written != expected_size:
            raise ModelDownloadError("Model download ended before all bytes were received.", status=502)
        # The file is closed at this point, so the rename also works on Windows.
        os.replace(partial_path, destination.full_path)
    except BaseException:
        # BaseException, not Exception: a cancelled request (client
        # disconnect) raises asyncio.CancelledError, and leaving the
        # ``.part`` file behind would make every retry fail with 409.
        try:
            os.remove(partial_path)
        except OSError:
            pass
        raise
    return download_response("downloaded", request, destination, bytes_written)
def download_response(
    status: str,
    request: ModelDownloadRequest,
    destination: ModelDownloadDestination,
    size: int | None = None,
) -> dict:
    """Build the JSON-serializable payload describing a download outcome.

    The ``size`` key is included only when ``size`` is not None.
    """
    payload = dict(
        status=status,
        name=request.name,
        directory=destination.directory,
        path=destination.full_path,
    )
    if size is None:
        return payload
    payload["size"] = size
    return payload

View File

@ -0,0 +1,244 @@
import os
import pytest
import folder_paths
from app.model_download import (
ModelDownloadError,
ModelDownloadRequest,
is_allowed_model_download_url,
normalize_model_relative_path,
open_model_download_response,
parse_model_download_request,
resolve_model_download_destination,
)
class _FakeResponse:
"""Minimal stand-in for ``aiohttp.ClientResponse`` for the redirect tests."""
def __init__(self, status, headers=None):
self.status = status
self.headers = headers or {}
self.released = False
def release(self):
self.released = True
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc, tb):
self.released = True
class _FakeSession:
"""Hands out queued ``_FakeResponse`` objects in order."""
def __init__(self, responses):
self._responses = list(responses)
self.calls = []
async def get(self, url, allow_redirects, timeout):
self.calls.append((url, allow_redirects))
if not self._responses:
raise AssertionError("Unexpected extra session.get call")
return self._responses.pop(0)
def test_parse_model_download_request_allows_huggingface_model_url():
    """A well-formed Hugging Face request parses into the matching dataclass."""
    model_url = "https://huggingface.co/org/repo/resolve/main/model.safetensors"
    payload = {
        "name": "nested/model.safetensors",
        "url": model_url,
        "directory": "checkpoints",
    }
    expected = ModelDownloadRequest(
        name="nested/model.safetensors",
        url=model_url,
        directory="checkpoints",
    )

    parsed = parse_model_download_request(payload)

    assert parsed == expected
@pytest.mark.parametrize(
    "url",
    [
        "http://localhost:8000/model.safetensors",
        "http://huggingface.co/org/repo/resolve/main/model.safetensors",
        "https://example.com/model.safetensors",
        "https://huggingface.co.evil.com/model.safetensors",
    ],
)
def test_download_url_allowlist_rejects_untrusted_or_plain_http_urls(url):
    """Plain-http URLs, unknown hosts and look-alike hosts must be refused."""
    assert is_allowed_model_download_url(url) is False
@pytest.mark.parametrize(
    "url",
    [
        # Direct HF model URLs.
        "https://huggingface.co/org/repo/resolve/main/model.safetensors",
        # HF LFS CDN subdomains: this is where `/resolve/main/...` redirects
        # land, so the allowlist must accept them or downloads break.
        "https://cdn-lfs.huggingface.co/repos/abc/def/model.safetensors",
        "https://cdn-lfs-us-1.huggingface.co/repos/abc/def/model.safetensors",
        # Civitai download endpoints (PR objective: support Civitai too).
        "https://civitai.com/api/download/models/12345",
        "https://civitai.red/api/download/models/12345",
    ],
)
def test_download_url_allowlist_accepts_huggingface_and_civitai_urls(url):
    """https URLs on allowlisted hosts and their subdomains must be accepted."""
    assert is_allowed_model_download_url(url) is True
@pytest.mark.parametrize(
    "name, expected",
    [
        ("model.safetensors", "model.safetensors"),
        ("sub/model.safetensors", "sub/model.safetensors"),
        ("nested/dir/model.safetensors", "nested/dir/model.safetensors"),
        # Backslashes are normalized to forward slashes so Windows-style
        # paths land in the same place as the POSIX equivalents.
        ("nested\\dir\\model.safetensors", "nested/dir/model.safetensors"),
    ],
)
def test_normalize_model_relative_path_accepts_safe_paths(name, expected):
    """Safe relative names come back in normalized POSIX form."""
    assert normalize_model_relative_path(name) == expected
@pytest.mark.parametrize(
    "name",
    [
        "../model.safetensors",
        "nested/../../model.safetensors",
        "/absolute/model.safetensors",
        "model.safetensors\x00",
    ],
)
def test_normalize_model_relative_path_rejects_unsafe_paths(name):
    """Parent-directory components, absolute paths and NUL bytes are refused."""
    with pytest.raises(ModelDownloadError):
        normalize_model_relative_path(name)
def test_parse_model_download_request_rejects_unsupported_extensions():
    """A name whose extension is not in the allowed suffixes (here .gguf) is refused."""
    with pytest.raises(ModelDownloadError):
        parse_model_download_request({
            "name": "model.gguf",
            "url": "https://huggingface.co/org/repo/resolve/main/model.gguf",
            "directory": "checkpoints",
        })
def test_resolve_model_download_destination_uses_configured_model_folder(tmp_path, monkeypatch):
    """A model that is not on disk resolves to a path under the configured root."""
    checkpoints_root = tmp_path / "models" / "checkpoints"
    registry = {"checkpoints": ([str(checkpoints_root)], {".safetensors"})}
    monkeypatch.setattr(folder_paths, "folder_names_and_paths", registry)
    request = ModelDownloadRequest(
        name="sub/model.safetensors",
        url="https://huggingface.co/org/repo/resolve/main/model.safetensors",
        directory="checkpoints",
    )

    resolved = resolve_model_download_destination(request)

    assert resolved.directory == "checkpoints"
    assert resolved.relative_path == "sub/model.safetensors"
    assert resolved.full_path == os.path.join(str(checkpoints_root), "sub", "model.safetensors")
    assert resolved.already_exists is False
def test_resolve_model_download_destination_reuses_existing_model(tmp_path, monkeypatch):
    """A model already on disk resolves to that file and is flagged as existing."""
    checkpoints_root = tmp_path / "models" / "checkpoints"
    checkpoints_root.mkdir(parents=True)
    model_file = checkpoints_root / "model.safetensors"
    model_file.write_bytes(b"model")
    registry = {"checkpoints": ([str(checkpoints_root)], {".safetensors"})}
    monkeypatch.setattr(folder_paths, "folder_names_and_paths", registry)
    request = ModelDownloadRequest(
        name="model.safetensors",
        url="https://huggingface.co/org/repo/resolve/main/model.safetensors",
        directory="checkpoints",
    )

    resolved = resolve_model_download_destination(request)

    assert resolved.full_path == str(model_file)
    assert resolved.already_exists is True
@pytest.mark.parametrize("directory", ["configs", "custom_nodes", "unknown"])
def test_resolve_model_download_destination_rejects_blocked_or_unknown_directories(tmp_path, monkeypatch, directory):
    """Blocked folders and folders missing from the registry are refused."""
    registry = {
        "configs": ([str(tmp_path / "configs")], {".yaml"}),
        "custom_nodes": ([str(tmp_path / "custom_nodes")], set()),
    }
    monkeypatch.setattr(folder_paths, "folder_names_and_paths", registry)
    request = ModelDownloadRequest(
        name="model.safetensors",
        url="https://huggingface.co/org/repo/resolve/main/model.safetensors",
        directory=directory,
    )

    with pytest.raises(ModelDownloadError):
        resolve_model_download_destination(request)
@pytest.mark.asyncio
async def test_open_model_download_response_follows_allowed_subdomain_redirect():
    """HF redirects /resolve/main/... to cdn-lfs.huggingface.co; that must work."""
    start_url = "https://huggingface.co/org/repo/resolve/main/model.safetensors"
    cdn_url = "https://cdn-lfs.huggingface.co/repos/abc/model.safetensors"
    redirect = _FakeResponse(302, {"Location": cdn_url})
    final = _FakeResponse(200)
    session = _FakeSession([redirect, final])

    response = await open_model_download_response(session, start_url)

    assert response.status == 200
    # Both hops must have been requested with automatic redirects disabled.
    assert session.calls == [(start_url, False), (cdn_url, False)]
@pytest.mark.asyncio
async def test_open_model_download_response_rejects_offsite_redirect():
    """A redirect leaving the allowlist must surface as a 403 instead of being followed."""
    offsite = _FakeResponse(302, {"Location": "https://attacker.example.com/payload"})
    session = _FakeSession([offsite])
    start_url = "https://huggingface.co/org/repo/resolve/main/model.safetensors"

    with pytest.raises(ModelDownloadError) as exc_info:
        await open_model_download_response(session, start_url)

    assert exc_info.value.status == 403
    # The initial request was issued with redirects disabled, otherwise
    # the validation above would be a no-op.
    _, followed_redirects = session.calls[0]
    assert followed_redirects is False
@pytest.mark.asyncio
async def test_open_model_download_response_rejects_redirect_without_location():
    """A redirect status without a Location header is reported as a 502."""
    headerless_redirect = _FakeResponse(302)
    session = _FakeSession([headerless_redirect])
    start_url = "https://huggingface.co/org/repo/resolve/main/model.safetensors"

    with pytest.raises(ModelDownloadError) as exc_info:
        await open_model_download_response(session, start_url)

    assert exc_info.value.status == 502
@pytest.mark.asyncio
async def test_open_model_download_response_stops_after_too_many_redirects():
    """An endless chain of allowed redirects is cut off with a 502."""
    looping_redirect = _FakeResponse(302, {"Location": "https://cdn-lfs.huggingface.co/loop"})
    # The same response object is queued ten times, more than the redirect budget.
    session = _FakeSession([looping_redirect] * 10)
    start_url = "https://huggingface.co/org/repo/resolve/main/model.safetensors"

    with pytest.raises(ModelDownloadError) as exc_info:
        await open_model_download_response(session, start_url)

    assert exc_info.value.status == 502