From f9eac7477a5b8371ca0af06b26f8f3889066a88c Mon Sep 17 00:00:00 2001 From: adv0r Date: Mon, 18 May 2026 15:29:15 +0200 Subject: [PATCH 1/2] Add server-side missing model downloads --- api_server/routes/internal/internal_routes.py | 28 +++ app/model_download.py | 224 ++++++++++++++++++ tests-unit/app_test/model_download_test.py | 114 +++++++++ 3 files changed, 366 insertions(+) create mode 100644 app/model_download.py create mode 100644 tests-unit/app_test/model_download_test.py diff --git a/api_server/routes/internal/internal_routes.py b/api_server/routes/internal/internal_routes.py index 1477afa01..2d24ecdf5 100644 --- a/api_server/routes/internal/internal_routes.py +++ b/api_server/routes/internal/internal_routes.py @@ -2,7 +2,9 @@ from aiohttp import web from typing import Optional from folder_paths import folder_names_and_paths, get_directory_by_type from api_server.services.terminal_service import TerminalService +from app.model_download import ModelDownloadError, download_model_to_destination, parse_model_download_request, resolve_model_download_destination import app.logger +import logging import os class InternalRoutes: @@ -51,6 +53,32 @@ class InternalRoutes: response[key] = folder_names_and_paths[key][0] return web.json_response(response) + @self.routes.post('/models/download') + async def download_model(request): + try: + try: + json_data = await request.json() + except Exception as err: + raise ModelDownloadError("Expected a JSON request body.") from err + + download_request = parse_model_download_request(json_data) + destination = resolve_model_download_destination(download_request) + if self.prompt_server.client_session is None: + raise ModelDownloadError("HTTP client session is not ready.", status=503) + result = await download_model_to_destination( + self.prompt_server.client_session, + download_request, + destination, + ) + except ModelDownloadError as err: + return web.json_response({"error": str(err)}, status=err.status) + except Exception: + logging.exception("Failed to download model") + return web.json_response({"error": "Failed to download model."}, status=500) + + response_status = 200 if result["status"] == "already_exists" else 201 + return web.json_response(result, status=response_status) + @self.routes.get('/files/{directory_type}') async def get_files(request: web.Request) -> web.Response: directory_type = request.match_info['directory_type'] diff --git a/app/model_download.py b/app/model_download.py new file mode 100644 index 000000000..959786a0a --- /dev/null +++ b/app/model_download.py @@ -0,0 +1,224 @@ +from __future__ import annotations + +import os +import posixpath +from dataclasses import dataclass +from pathlib import PurePosixPath +from urllib.parse import urlparse + +from aiohttp import ClientSession + +import folder_paths + + +ALLOWED_DOWNLOAD_HOSTS = {"huggingface.co", "civitai.com", "civitai.red"} +ALLOWED_DOWNLOAD_SUFFIXES = (".safetensors", ".sft", ".ckpt", ".pth", ".pt") +BLOCKED_MODEL_FOLDERS = {"configs", "custom_nodes"} +CHUNK_SIZE = 1024 * 1024 + +WHITE_LISTED_DOWNLOAD_URLS = { + "https://huggingface.co/stabilityai/stable-zero123/resolve/main/stable_zero123.ckpt", + "https://huggingface.co/TencentARC/T2I-Adapter/resolve/main/models/t2iadapter_depth_sd14v1.pth?download=true", + "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth", +} + + +@dataclass(frozen=True) +class ModelDownloadRequest: + name: str + url: str + directory: str + + +@dataclass(frozen=True) +class ModelDownloadDestination: + directory: str + relative_path: str + full_path: str + already_exists: bool = False + + +class ModelDownloadError(Exception): + def __init__(self, message: str, status: int = 400): + super().__init__(message) + self.status = status + + +def parse_model_download_request(data) -> ModelDownloadRequest: + if not isinstance(data, dict): + raise ModelDownloadError("Expected a JSON object.") + + name = data.get("name") + url = data.get("url") + directory = data.get("directory") + if not isinstance(name, str) or not isinstance(url, str) or not isinstance(directory, str): + raise ModelDownloadError("Missing model name, URL, or directory.") + + name = name.strip() + url = url.strip() + directory = directory.strip() + if not name or not url or not directory: + raise ModelDownloadError("Model name, URL, and directory are required.") + + if not is_allowed_model_download_url(url): + raise ModelDownloadError("Model download URL is not allowed.") + + relative_path = normalize_model_relative_path(name) + if not relative_path.lower().endswith(ALLOWED_DOWNLOAD_SUFFIXES): + raise ModelDownloadError("Model filename extension is not allowed.") + + return ModelDownloadRequest(name=relative_path, url=url, directory=directory) + + +def is_allowed_model_download_url(url: str) -> bool: + if url in WHITE_LISTED_DOWNLOAD_URLS: + return True + + try: + parsed = urlparse(url) + except ValueError: + return False + + if parsed.scheme != "https": + return False + + return (parsed.hostname or "").lower() in ALLOWED_DOWNLOAD_HOSTS + + +def normalize_model_relative_path(name: str) -> str: + if "\x00" in name: + raise ModelDownloadError("Model filename is invalid.") + + candidate = name.replace("\\", "/") + if candidate.startswith("/"): + raise ModelDownloadError("Model filename must be relative.") + + parts = PurePosixPath(candidate).parts + if not parts or any(part in ("", ".", "..") for part in parts): + raise ModelDownloadError("Model filename must stay inside the model folder.") + + normalized = posixpath.normpath(candidate) + if normalized in ("", ".") or normalized.startswith("../"): + raise ModelDownloadError("Model filename must stay inside the model folder.") + + return normalized + + +def resolve_model_download_destination(request: ModelDownloadRequest) -> ModelDownloadDestination: + directory = folder_paths.map_legacy(request.directory) + if directory in BLOCKED_MODEL_FOLDERS or directory not in folder_paths.folder_names_and_paths: + raise ModelDownloadError("Model directory is not allowed.", status=404) + + existing_path = folder_paths.get_full_path(directory, request.name) + if existing_path is not None: + return ModelDownloadDestination( + directory=directory, + relative_path=request.name, + full_path=existing_path, + already_exists=True, + ) + + destination_root = find_writable_model_root(directory) + full_path = safe_join(destination_root, request.name) + return ModelDownloadDestination( + directory=directory, + relative_path=request.name, + full_path=full_path, + ) + + +def find_writable_model_root(directory: str) -> str: + for root in folder_paths.get_folder_paths(directory): + try: + os.makedirs(root, exist_ok=True) + except OSError: + continue + + if is_writable_directory(root): + return os.path.abspath(root) + + raise ModelDownloadError("No writable model folder is configured.", status=403) + + +def is_writable_directory(path: str) -> bool: + probe = os.path.join(path, ".comfy-download-write-test") + try: + with open(probe, "xb"): + pass + os.remove(probe) + return True + except OSError: + try: + if os.path.exists(probe): + os.remove(probe) + except OSError: + pass + return False + + +def safe_join(root: str, relative_path: str) -> str: + root = os.path.abspath(root) + full_path = os.path.abspath(os.path.join(root, relative_path)) + if os.path.commonpath((root, full_path)) != root: + raise ModelDownloadError("Model filename must stay inside the model folder.") + return full_path + + +async def download_model_to_destination( + session: ClientSession, + request: ModelDownloadRequest, + destination: ModelDownloadDestination, +) -> dict: + if destination.already_exists: + return download_response("already_exists", request, destination) + + os.makedirs(os.path.dirname(destination.full_path), exist_ok=True) + partial_path = f"{destination.full_path}.part" + + try: + fd = os.open(partial_path, os.O_WRONLY | os.O_CREAT | os.O_EXCL, 0o644) + except FileExistsError as err: + raise ModelDownloadError("A download for this model is already in progress.", status=409) from err + + bytes_written = 0 + try: + with os.fdopen(fd, "wb") as output: + async with session.get(request.url) as response: + if response.status >= 400: + raise ModelDownloadError(f"Model download failed with HTTP {response.status}.", status=502) + + expected_size = response.content_length + async for chunk in response.content.iter_chunked(CHUNK_SIZE): + output.write(chunk) + bytes_written += len(chunk) + + if expected_size is not None and bytes_written != expected_size: + raise ModelDownloadError("Model download ended before all bytes were received.", status=502) + + os.replace(partial_path, destination.full_path) + except Exception: + try: + if os.path.exists(partial_path): + os.remove(partial_path) + except OSError: + pass + raise + + return download_response("downloaded", request, destination, bytes_written) + + +def download_response( + status: str, + request: ModelDownloadRequest, + destination: ModelDownloadDestination, + size: int | None = None, +) -> dict: + response = { + "status": status, + "name": request.name, + "directory": destination.directory, + "path": destination.full_path, + } + if size is not None: + response["size"] = size + return response diff --git a/tests-unit/app_test/model_download_test.py b/tests-unit/app_test/model_download_test.py new file mode 100644 index 000000000..3dece976c --- /dev/null +++ b/tests-unit/app_test/model_download_test.py @@ -0,0 +1,114 @@ +import os + +import pytest + +import folder_paths +from app.model_download import ( + ModelDownloadError, + ModelDownloadRequest, + is_allowed_model_download_url, + normalize_model_relative_path, + parse_model_download_request, + resolve_model_download_destination, +) + + +def test_parse_model_download_request_allows_huggingface_model_url(): + request = parse_model_download_request({ + "name": "nested/model.safetensors", + "url": "https://huggingface.co/org/repo/resolve/main/model.safetensors", + "directory": "checkpoints", + }) + + assert request == ModelDownloadRequest( + name="nested/model.safetensors", + url="https://huggingface.co/org/repo/resolve/main/model.safetensors", + directory="checkpoints", + ) + + +@pytest.mark.parametrize( + "url", + [ + "http://localhost:8000/model.safetensors", + "http://huggingface.co/org/repo/resolve/main/model.safetensors", + "https://example.com/model.safetensors", + ], +) +def test_download_url_allowlist_rejects_untrusted_or_plain_http_urls(url): + assert is_allowed_model_download_url(url) is False + + +@pytest.mark.parametrize( + "name", + [ + "../model.safetensors", + "nested/../../model.safetensors", + "/absolute/model.safetensors", + "model.safetensors\x00", + ], +) +def test_normalize_model_relative_path_rejects_unsafe_paths(name): + with pytest.raises(ModelDownloadError): + normalize_model_relative_path(name) + + +def test_parse_model_download_request_rejects_unsupported_extensions(): + with pytest.raises(ModelDownloadError): + parse_model_download_request({ + "name": "model.gguf", + "url": "https://huggingface.co/org/repo/resolve/main/model.gguf", + "directory": "checkpoints", + }) + + +def test_resolve_model_download_destination_uses_configured_model_folder(tmp_path, monkeypatch): + model_root = tmp_path / "models" / "checkpoints" + monkeypatch.setattr(folder_paths, "folder_names_and_paths", { + "checkpoints": ([str(model_root)], {".safetensors"}), + }) + + destination = resolve_model_download_destination(ModelDownloadRequest( + name="sub/model.safetensors", + url="https://huggingface.co/org/repo/resolve/main/model.safetensors", + directory="checkpoints", + )) + + assert destination.directory == "checkpoints" + assert destination.relative_path == "sub/model.safetensors" + assert destination.full_path == os.path.join(str(model_root), "sub", "model.safetensors") + assert destination.already_exists is False + + +def test_resolve_model_download_destination_reuses_existing_model(tmp_path, monkeypatch): + model_root = tmp_path / "models" / "checkpoints" + model_root.mkdir(parents=True) + existing = model_root / "model.safetensors" + existing.write_bytes(b"model") + monkeypatch.setattr(folder_paths, "folder_names_and_paths", { + "checkpoints": ([str(model_root)], {".safetensors"}), + }) + + destination = resolve_model_download_destination(ModelDownloadRequest( + name="model.safetensors", + url="https://huggingface.co/org/repo/resolve/main/model.safetensors", + directory="checkpoints", + )) + + assert destination.full_path == str(existing) + assert destination.already_exists is True + + +@pytest.mark.parametrize("directory", ["configs", "custom_nodes", "unknown"]) +def test_resolve_model_download_destination_rejects_blocked_or_unknown_directories(tmp_path, monkeypatch, directory): + monkeypatch.setattr(folder_paths, "folder_names_and_paths", { + "configs": ([str(tmp_path / "configs")], {".yaml"}), + "custom_nodes": ([str(tmp_path / "custom_nodes")], set()), + }) + + with pytest.raises(ModelDownloadError): + resolve_model_download_destination(ModelDownloadRequest( + name="model.safetensors", + url="https://huggingface.co/org/repo/resolve/main/model.safetensors", + directory=directory, + )) From 15d49a61b8e960939aef5de8e96985b80e5a3ad2 Mon Sep 17 00:00:00 2001 From: adv0r <> Date: Tue, 19 May 2026 11:26:53 +0200 Subject: [PATCH 2/2] Address review feedback on /internal/models/download - Disable aiohttp auto-redirects and re-validate every Location target against the same allowlist used for the initial URL, closing an SSRF vector where an allowed host could redirect to an arbitrary internal endpoint. - Accept subdomains of allowlisted hosts so Hugging Face's LFS CDN (cdn-lfs.huggingface.co et al.) keeps working under the stricter redirect handling. - Pass an explicit ClientTimeout (connect/sock_read) so hung remotes surface as errors instead of blocking the request handler forever. - Log the exception value alongside the traceback on the 500 fallback. - Add positive coverage for normalize_model_relative_path, Civitai URL allowlisting, and the redirect-following / SSRF-rejection branches of open_model_download_response. Co-authored-by: Cursor --- api_server/routes/internal/internal_routes.py | 4 +- app/model_download.py | 65 ++++++++- tests-unit/app_test/model_download_test.py | 130 ++++++++++++++++++ 3 files changed, 193 insertions(+), 6 deletions(-) diff --git a/api_server/routes/internal/internal_routes.py b/api_server/routes/internal/internal_routes.py index 2d24ecdf5..89af3777a 100644 --- a/api_server/routes/internal/internal_routes.py +++ b/api_server/routes/internal/internal_routes.py @@ -72,8 +72,8 @@ class InternalRoutes: ) except ModelDownloadError as err: return web.json_response({"error": str(err)}, status=err.status) - except Exception: - logging.exception("Failed to download model") + except Exception as err: + logging.exception("Failed to download model: %s", err) return web.json_response({"error": "Failed to download model."}, status=500) response_status = 200 if result["status"] == "already_exists" else 201 diff --git a/app/model_download.py b/app/model_download.py index 959786a0a..d2fb9ae1b 100644 --- a/app/model_download.py +++ b/app/model_download.py @@ -4,9 +4,9 @@ import os import posixpath from dataclasses import dataclass from pathlib import PurePosixPath -from urllib.parse import urlparse +from urllib.parse import urljoin, urlparse -from aiohttp import ClientSession +from aiohttp import ClientSession, ClientTimeout import folder_paths @@ -16,6 +16,18 @@ ALLOWED_DOWNLOAD_SUFFIXES = (".safetensors", ".sft", ".ckpt", ".pth", ".pt") BLOCKED_MODEL_FOLDERS = {"configs", "custom_nodes"} CHUNK_SIZE = 1024 * 1024 +# Bound the network call so a hung remote eventually surfaces an error +# instead of blocking the request handler forever. ``sock_read`` is the +# inter-chunk read timeout, which is the right knob for long downloads: +# a slow-but-progressing transfer keeps making progress, while a stalled +# socket fails predictably. +DOWNLOAD_TIMEOUT = ClientTimeout(total=None, connect=30, sock_connect=30, sock_read=300) + +# Maximum number of redirects we follow manually. Hugging Face typically +# redirects ``/resolve/main/...`` to a single CDN URL, so a small budget +# is enough while still preventing redirect loops. +MAX_DOWNLOAD_REDIRECTS = 5 + WHITE_LISTED_DOWNLOAD_URLS = { "https://huggingface.co/stabilityai/stable-zero123/resolve/main/stable_zero123.ckpt", "https://huggingface.co/TencentARC/T2I-Adapter/resolve/main/models/t2iadapter_depth_sd14v1.pth?download=true", @@ -71,6 +83,14 @@ def parse_model_download_request(data) -> ModelDownloadRequest: def is_allowed_model_download_url(url: str) -> bool: + """Return True for URLs we are willing to fetch on behalf of the user. + + The same predicate is applied to the user-supplied URL and to every + redirect target, so SSRF via redirects on an allowed host is contained + to the same allowlist. Subdomains of allowlisted hosts are accepted + because Hugging Face and Civitai both serve actual file payloads from + CDN subdomains (e.g. ``cdn-lfs.huggingface.co``). + """ if url in WHITE_LISTED_DOWNLOAD_URLS: return True @@ -82,7 +102,14 @@ def is_allowed_model_download_url(url: str) -> bool: if parsed.scheme != "https": return False - return (parsed.hostname or "").lower() in ALLOWED_DOWNLOAD_HOSTS + host = (parsed.hostname or "").lower() + if not host: + return False + + for allowed in ALLOWED_DOWNLOAD_HOSTS: + if host == allowed or host.endswith("." + allowed): + return True + return False def normalize_model_relative_path(name: str) -> str: @@ -164,6 +191,36 @@ def safe_join(root: str, relative_path: str) -> str: return full_path +async def open_model_download_response(session: ClientSession, url: str): + """GET ``url`` with explicit timeout and an allowlist-checked redirect chain. + + aiohttp follows redirects by default, which would let an allowed host + redirect to an arbitrary internal target (SSRF). We disable automatic + following and validate every ``Location`` against the same allowlist + used for the initial URL. + """ + current_url = url + for _ in range(MAX_DOWNLOAD_REDIRECTS + 1): + response = await session.get( + current_url, + allow_redirects=False, + timeout=DOWNLOAD_TIMEOUT, + ) + if response.status not in (301, 302, 303, 307, 308): + return response + + location = response.headers.get("Location", "").strip() + response.release() + if not location: + raise ModelDownloadError("Redirect response missing Location header.", status=502) + next_url = urljoin(current_url, location) + if not is_allowed_model_download_url(next_url): + raise ModelDownloadError("Model download redirect target is not allowed.", status=403) + current_url = next_url + + raise ModelDownloadError("Too many redirects while downloading model.", status=502) + + async def download_model_to_destination( session: ClientSession, request: ModelDownloadRequest, @@ -183,7 +240,7 @@ async def download_model_to_destination( bytes_written = 0 try: with os.fdopen(fd, "wb") as output: - async with session.get(request.url) as response: + async with await open_model_download_response(session, request.url) as response: if response.status >= 400: raise ModelDownloadError(f"Model download failed with HTTP {response.status}.", status=502) diff --git a/tests-unit/app_test/model_download_test.py b/tests-unit/app_test/model_download_test.py index 3dece976c..c4a8a3424 100644 --- a/tests-unit/app_test/model_download_test.py +++ b/tests-unit/app_test/model_download_test.py @@ -8,11 +8,44 @@ from app.model_download import ( ModelDownloadRequest, is_allowed_model_download_url, normalize_model_relative_path, + open_model_download_response, parse_model_download_request, resolve_model_download_destination, ) +class _FakeResponse: + """Minimal stand-in for ``aiohttp.ClientResponse`` for the redirect tests.""" + + def __init__(self, status, headers=None): + self.status = status + self.headers = headers or {} + self.released = False + + def release(self): + self.released = True + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc, tb): + self.released = True + + +class _FakeSession: + """Hands out queued ``_FakeResponse`` objects in order.""" + + def __init__(self, responses): + self._responses = list(responses) + self.calls = [] + + async def get(self, url, allow_redirects, timeout): + self.calls.append((url, allow_redirects)) + if not self._responses: + raise AssertionError("Unexpected extra session.get call") + return self._responses.pop(0) + + def test_parse_model_download_request_allows_huggingface_model_url(): request = parse_model_download_request({ "name": "nested/model.safetensors", @@ -33,12 +66,46 @@ def test_parse_model_download_request_allows_huggingface_model_url(): "http://localhost:8000/model.safetensors", "http://huggingface.co/org/repo/resolve/main/model.safetensors", "https://example.com/model.safetensors", + "https://huggingface.co.evil.com/model.safetensors", ], ) def test_download_url_allowlist_rejects_untrusted_or_plain_http_urls(url): assert is_allowed_model_download_url(url) is False +@pytest.mark.parametrize( + "url", + [ + # Direct HF model URLs. + "https://huggingface.co/org/repo/resolve/main/model.safetensors", + # HF LFS CDN subdomains: this is where `/resolve/main/...` redirects + # land, so the allowlist must accept them or downloads break. + "https://cdn-lfs.huggingface.co/repos/abc/def/model.safetensors", + "https://cdn-lfs-us-1.huggingface.co/repos/abc/def/model.safetensors", + # Civitai download endpoints (PR objective: support Civitai too). + "https://civitai.com/api/download/models/12345", + "https://civitai.red/api/download/models/12345", + ], +) +def test_download_url_allowlist_accepts_huggingface_and_civitai_urls(url): + assert is_allowed_model_download_url(url) is True + + +@pytest.mark.parametrize( + "name, expected", + [ + ("model.safetensors", "model.safetensors"), + ("sub/model.safetensors", "sub/model.safetensors"), + ("nested/dir/model.safetensors", "nested/dir/model.safetensors"), + # Backslashes are normalized to forward slashes so Windows-style + # paths land in the same place as the POSIX equivalents. + ("nested\\dir\\model.safetensors", "nested/dir/model.safetensors"), + ], +) +def test_normalize_model_relative_path_accepts_safe_paths(name, expected): + assert normalize_model_relative_path(name) == expected + + @pytest.mark.parametrize( "name", [ @@ -112,3 +179,66 @@ def test_resolve_model_download_destination_rejects_blocked_or_unknown_directori url="https://huggingface.co/org/repo/resolve/main/model.safetensors", directory=directory, )) + + +@pytest.mark.asyncio +async def test_open_model_download_response_follows_allowed_subdomain_redirect(): + """HF redirects /resolve/main/... to cdn-lfs.huggingface.co; that must work.""" + session = _FakeSession([ + _FakeResponse(302, {"Location": "https://cdn-lfs.huggingface.co/repos/abc/model.safetensors"}), + _FakeResponse(200), + ]) + + response = await open_model_download_response( + session, "https://huggingface.co/org/repo/resolve/main/model.safetensors" + ) + + assert response.status == 200 + assert session.calls == [ + ("https://huggingface.co/org/repo/resolve/main/model.safetensors", False), + ("https://cdn-lfs.huggingface.co/repos/abc/model.safetensors", False), + ] + + +@pytest.mark.asyncio +async def test_open_model_download_response_rejects_offsite_redirect(): + """A redirect leaving the allowlist must surface as a 403 instead of being followed.""" + session = _FakeSession([ + _FakeResponse(302, {"Location": "https://attacker.example.com/payload"}), + ]) + + with pytest.raises(ModelDownloadError) as exc_info: + await open_model_download_response( + session, "https://huggingface.co/org/repo/resolve/main/model.safetensors" + ) + + assert exc_info.value.status == 403 + # The initial request was issued with redirects disabled, otherwise + # the validation above would be a no-op. + assert session.calls[0][1] is False + + +@pytest.mark.asyncio +async def test_open_model_download_response_rejects_redirect_without_location(): + session = _FakeSession([_FakeResponse(302)]) + + with pytest.raises(ModelDownloadError) as exc_info: + await open_model_download_response( + session, "https://huggingface.co/org/repo/resolve/main/model.safetensors" + ) + + assert exc_info.value.status == 502 + + +@pytest.mark.asyncio +async def test_open_model_download_response_stops_after_too_many_redirects(): + session = _FakeSession( + [_FakeResponse(302, {"Location": "https://cdn-lfs.huggingface.co/loop"})] * 10 + ) + + with pytest.raises(ModelDownloadError) as exc_info: + await open_model_download_response( + session, "https://huggingface.co/org/repo/resolve/main/model.safetensors" + ) + + assert exc_info.value.status == 502