Address review feedback on /internal/models/download

- Disable aiohttp auto-redirects and re-validate every Location target
  against the same allowlist used for the initial URL, closing an SSRF
  vector where an allowed host could redirect to an arbitrary internal
  endpoint.
- Accept subdomains of allowlisted hosts so Hugging Face's LFS CDN
  (cdn-lfs.huggingface.co et al.) keeps working under the stricter
  redirect handling.
- Pass an explicit ClientTimeout (connect/sock_read) so hung remotes
  surface as errors instead of blocking the request handler forever.
- Log the exception value alongside the traceback on the 500 fallback.
- Add positive coverage for normalize_model_relative_path, Civitai URL
  allowlisting, and the redirect-following / SSRF-rejection branches of
  open_model_download_response.

Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
adv0r 2026-05-19 11:26:53 +02:00
parent f9eac7477a
commit 15d49a61b8
3 changed files with 193 additions and 6 deletions

View File

@ -72,8 +72,8 @@ class InternalRoutes:
)
except ModelDownloadError as err:
return web.json_response({"error": str(err)}, status=err.status)
except Exception:
logging.exception("Failed to download model")
except Exception as err:
logging.exception("Failed to download model: %s", err)
return web.json_response({"error": "Failed to download model."}, status=500)
response_status = 200 if result["status"] == "already_exists" else 201

View File

@ -4,9 +4,9 @@ import os
import posixpath
from dataclasses import dataclass
from pathlib import PurePosixPath
from urllib.parse import urlparse
from urllib.parse import urljoin, urlparse
from aiohttp import ClientSession
from aiohttp import ClientSession, ClientTimeout
import folder_paths
@ -16,6 +16,18 @@ ALLOWED_DOWNLOAD_SUFFIXES = (".safetensors", ".sft", ".ckpt", ".pth", ".pt")
BLOCKED_MODEL_FOLDERS = {"configs", "custom_nodes"}
CHUNK_SIZE = 1024 * 1024
# Bound the network call so a hung remote eventually surfaces an error
# instead of blocking the request handler forever. ``sock_read`` is the
# inter-chunk read timeout, which is the right knob for long downloads:
# a slow-but-progressing transfer keeps making progress, while a stalled
# socket fails predictably. ``total=None`` leaves the overall transfer
# duration uncapped, so only connection setup (``connect`` /
# ``sock_connect``) and per-read stalls are bounded. All values are seconds.
DOWNLOAD_TIMEOUT = ClientTimeout(total=None, connect=30, sock_connect=30, sock_read=300)
# Maximum number of redirects we follow manually. Hugging Face typically
# redirects ``/resolve/main/...`` to a single CDN URL, so a small budget
# is enough while still preventing redirect loops. The redirect-following
# helper issues at most ``MAX_DOWNLOAD_REDIRECTS + 1`` requests (the initial
# one plus one per followed redirect) before giving up.
MAX_DOWNLOAD_REDIRECTS = 5
WHITE_LISTED_DOWNLOAD_URLS = {
"https://huggingface.co/stabilityai/stable-zero123/resolve/main/stable_zero123.ckpt",
"https://huggingface.co/TencentARC/T2I-Adapter/resolve/main/models/t2iadapter_depth_sd14v1.pth?download=true",
@ -71,6 +83,14 @@ def parse_model_download_request(data) -> ModelDownloadRequest:
def is_allowed_model_download_url(url: str) -> bool:
"""Return True for URLs we are willing to fetch on behalf of the user.
The same predicate is applied to the user-supplied URL and to every
redirect target, so SSRF via redirects on an allowed host is contained
to the same allowlist. Subdomains of allowlisted hosts are accepted
because Hugging Face and Civitai both serve actual file payloads from
CDN subdomains (e.g. ``cdn-lfs.huggingface.co``).
"""
if url in WHITE_LISTED_DOWNLOAD_URLS:
return True
@ -82,7 +102,14 @@ def is_allowed_model_download_url(url: str) -> bool:
if parsed.scheme != "https":
return False
return (parsed.hostname or "").lower() in ALLOWED_DOWNLOAD_HOSTS
host = (parsed.hostname or "").lower()
if not host:
return False
for allowed in ALLOWED_DOWNLOAD_HOSTS:
if host == allowed or host.endswith("." + allowed):
return True
return False
def normalize_model_relative_path(name: str) -> str:
@ -164,6 +191,36 @@ def safe_join(root: str, relative_path: str) -> str:
return full_path
async def open_model_download_response(session: ClientSession, url: str):
    """Fetch ``url`` with redirects handled by hand and vetted per hop.

    aiohttp would otherwise follow redirects on its own, which lets an
    allowlisted host bounce the request to an arbitrary internal target
    (SSRF). Automatic following is therefore switched off and each
    ``Location`` value is resolved against the URL that produced it and
    checked with ``is_allowed_model_download_url`` before it is requested.

    Args:
        session: Client session used to issue the GET requests.
        url: Initial download URL. It is requested as-is; only redirect
            targets are validated here.

    Returns:
        The first response whose status is not 301, 302, 303, 307 or 308.
        It is returned unreleased, so the caller owns it.

    Raises:
        ModelDownloadError: With status 502 when a redirect carries no
            ``Location`` header or the redirect budget is exhausted, and
            with status 403 when a redirect target is not allowlisted.
    """
    redirect_statuses = (301, 302, 303, 307, 308)
    target = url
    requests_left = MAX_DOWNLOAD_REDIRECTS + 1  # initial request + redirects

    while requests_left > 0:
        requests_left -= 1
        response = await session.get(
            target,
            allow_redirects=False,
            timeout=DOWNLOAD_TIMEOUT,
        )
        if response.status not in redirect_statuses:
            return response

        # Grab the header first, then give the connection back: the
        # redirect response itself is never handed to the caller.
        location = response.headers.get("Location", "").strip()
        response.release()

        if not location:
            raise ModelDownloadError("Redirect response missing Location header.", status=502)

        # ``Location`` may be relative, so resolve it against the current hop.
        candidate = urljoin(target, location)
        if is_allowed_model_download_url(candidate):
            target = candidate
            continue
        raise ModelDownloadError("Model download redirect target is not allowed.", status=403)

    raise ModelDownloadError("Too many redirects while downloading model.", status=502)
async def download_model_to_destination(
session: ClientSession,
request: ModelDownloadRequest,
@ -183,7 +240,7 @@ async def download_model_to_destination(
bytes_written = 0
try:
with os.fdopen(fd, "wb") as output:
async with session.get(request.url) as response:
async with await open_model_download_response(session, request.url) as response:
if response.status >= 400:
raise ModelDownloadError(f"Model download failed with HTTP {response.status}.", status=502)

View File

@ -8,11 +8,44 @@ from app.model_download import (
ModelDownloadRequest,
is_allowed_model_download_url,
normalize_model_relative_path,
open_model_download_response,
parse_model_download_request,
resolve_model_download_destination,
)
class _FakeResponse:
    """Bare-bones substitute for ``aiohttp.ClientResponse`` in redirect tests.

    Exposes only ``status``, ``headers``, ``release()`` and the async
    context-manager protocol, and records in ``released`` whether the
    response was released or its context exited.
    """

    def __init__(self, status, headers=None):
        self.status = status
        # A missing (or empty) mapping becomes a fresh dict per instance.
        self.headers = headers if headers else {}
        self.released = False

    def release(self):
        """Mark the response as released."""
        self.released = True

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc, tb):
        # Leaving the ``async with`` block counts as a release too.
        self.released = True
class _FakeSession:
    """Test double for a client session that replays canned responses.

    Responses are returned in the order given, one per ``get`` call, and
    every call is recorded in ``calls`` as ``(url, allow_redirects)``.
    """

    def __init__(self, responses):
        # Copy so the caller's sequence is not consumed.
        self._responses = list(responses)
        self.calls = []

    async def get(self, url, allow_redirects, timeout):
        """Record the call and return the next queued response."""
        self.calls.append((url, allow_redirects))
        if self._responses:
            return self._responses.pop(0)
        # More requests than the test queued up.
        raise AssertionError("Unexpected extra session.get call")
def test_parse_model_download_request_allows_huggingface_model_url():
request = parse_model_download_request({
"name": "nested/model.safetensors",
@ -33,12 +66,46 @@ def test_parse_model_download_request_allows_huggingface_model_url():
"http://localhost:8000/model.safetensors",
"http://huggingface.co/org/repo/resolve/main/model.safetensors",
"https://example.com/model.safetensors",
"https://huggingface.co.evil.com/model.safetensors",
],
)
def test_download_url_allowlist_rejects_untrusted_or_plain_http_urls(url):
assert is_allowed_model_download_url(url) is False
@pytest.mark.parametrize(
    "url",
    [
        # Plain Hugging Face model URL.
        "https://huggingface.co/org/repo/resolve/main/model.safetensors",
        # Hugging Face LFS CDN subdomains. `/resolve/main/...` redirects end
        # up here, so rejecting them would break downloads.
        "https://cdn-lfs.huggingface.co/repos/abc/def/model.safetensors",
        "https://cdn-lfs-us-1.huggingface.co/repos/abc/def/model.safetensors",
        # Civitai download endpoints, which this PR is meant to support too.
        "https://civitai.com/api/download/models/12345",
        "https://civitai.red/api/download/models/12345",
    ],
)
def test_download_url_allowlist_accepts_huggingface_and_civitai_urls(url):
    """Each trusted Hugging Face / Civitai URL passes the allowlist check."""
    allowed = is_allowed_model_download_url(url)
    assert allowed is True
@pytest.mark.parametrize(
    ("name", "expected"),
    [
        ("model.safetensors", "model.safetensors"),
        ("sub/model.safetensors", "sub/model.safetensors"),
        ("nested/dir/model.safetensors", "nested/dir/model.safetensors"),
        # Windows-style separators are expected to come back as forward
        # slashes, matching the POSIX spelling of the same path.
        ("nested\\dir\\model.safetensors", "nested/dir/model.safetensors"),
    ],
)
def test_normalize_model_relative_path_accepts_safe_paths(name, expected):
    """Safe relative paths come back in their normalized form."""
    normalized = normalize_model_relative_path(name)
    assert normalized == expected
@pytest.mark.parametrize(
"name",
[
@ -112,3 +179,66 @@ def test_resolve_model_download_destination_rejects_blocked_or_unknown_directori
url="https://huggingface.co/org/repo/resolve/main/model.safetensors",
directory=directory,
))
@pytest.mark.asyncio
async def test_open_model_download_response_follows_allowed_subdomain_redirect():
    """HF redirects /resolve/main/... to cdn-lfs.huggingface.co; that must work."""
    origin_url = "https://huggingface.co/org/repo/resolve/main/model.safetensors"
    cdn_url = "https://cdn-lfs.huggingface.co/repos/abc/model.safetensors"
    session = _FakeSession([
        _FakeResponse(302, {"Location": cdn_url}),
        _FakeResponse(200),
    ])

    response = await open_model_download_response(session, origin_url)

    assert response.status == 200
    # Both hops were requested, each with automatic redirects disabled.
    assert session.calls == [
        (origin_url, False),
        (cdn_url, False),
    ]
@pytest.mark.asyncio
async def test_open_model_download_response_rejects_offsite_redirect():
    """A redirect leaving the allowlist must surface as a 403 instead of being followed."""
    origin_url = "https://huggingface.co/org/repo/resolve/main/model.safetensors"
    offsite_redirect = _FakeResponse(302, {"Location": "https://attacker.example.com/payload"})
    session = _FakeSession([offsite_redirect])

    with pytest.raises(ModelDownloadError) as exc_info:
        await open_model_download_response(session, origin_url)

    assert exc_info.value.status == 403
    # Redirects must have been disabled on the first request; if aiohttp had
    # followed them itself, the allowlist check would never get a say.
    first_call = session.calls[0]
    assert first_call[1] is False
@pytest.mark.asyncio
async def test_open_model_download_response_rejects_redirect_without_location():
    """A 302 with no ``Location`` header is reported as a 502."""
    origin_url = "https://huggingface.co/org/repo/resolve/main/model.safetensors"
    headerless_redirect = _FakeResponse(302)
    session = _FakeSession([headerless_redirect])

    with pytest.raises(ModelDownloadError) as exc_info:
        await open_model_download_response(session, origin_url)

    assert exc_info.value.status == 502
@pytest.mark.asyncio
async def test_open_model_download_response_stops_after_too_many_redirects():
    """An endless chain of allowed redirects ends in a 502, not a hang."""
    origin_url = "https://huggingface.co/org/repo/resolve/main/model.safetensors"
    # The same redirect object is queued ten times, so the session always
    # has another redirect to hand out while the helper is still asking.
    looping_redirect = _FakeResponse(302, {"Location": "https://cdn-lfs.huggingface.co/loop"})
    session = _FakeSession([looping_redirect] * 10)

    with pytest.raises(ModelDownloadError) as exc_info:
        await open_model_download_response(session, origin_url)

    assert exc_info.value.status == 502