diff --git a/api_server/routes/internal/internal_routes.py b/api_server/routes/internal/internal_routes.py
index 2d24ecdf5..89af3777a 100644
--- a/api_server/routes/internal/internal_routes.py
+++ b/api_server/routes/internal/internal_routes.py
@@ -72,8 +72,8 @@ class InternalRoutes:
                 )
             except ModelDownloadError as err:
                 return web.json_response({"error": str(err)}, status=err.status)
-            except Exception:
-                logging.exception("Failed to download model")
+            except Exception as err:
+                logging.exception("Failed to download model: %s", err)
                 return web.json_response({"error": "Failed to download model."}, status=500)
 
             response_status = 200 if result["status"] == "already_exists" else 201
diff --git a/app/model_download.py b/app/model_download.py
index 959786a0a..d2fb9ae1b 100644
--- a/app/model_download.py
+++ b/app/model_download.py
@@ -4,9 +4,9 @@ import os
 import posixpath
 from dataclasses import dataclass
 from pathlib import PurePosixPath
-from urllib.parse import urlparse
+from urllib.parse import urljoin, urlparse
 
-from aiohttp import ClientSession
+from aiohttp import ClientSession, ClientTimeout
 
 import folder_paths
 
@@ -16,6 +16,18 @@ ALLOWED_DOWNLOAD_SUFFIXES = (".safetensors", ".sft", ".ckpt", ".pth", ".pt")
 BLOCKED_MODEL_FOLDERS = {"configs", "custom_nodes"}
 CHUNK_SIZE = 1024 * 1024
 
+# Bound the network call so a hung remote eventually surfaces an error
+# instead of blocking the request handler forever. ``sock_read`` is the
+# inter-chunk read timeout, which is the right knob for long downloads:
+# a slow-but-progressing transfer keeps making progress, while a stalled
+# socket fails predictably.
+DOWNLOAD_TIMEOUT = ClientTimeout(total=None, connect=30, sock_connect=30, sock_read=300)
+
+# Maximum number of redirects we follow manually. Hugging Face typically
+# redirects ``/resolve/main/...`` to a single CDN URL, so a small budget
+# is enough while still preventing redirect loops.
+MAX_DOWNLOAD_REDIRECTS = 5
+
 WHITE_LISTED_DOWNLOAD_URLS = {
     "https://huggingface.co/stabilityai/stable-zero123/resolve/main/stable_zero123.ckpt",
     "https://huggingface.co/TencentARC/T2I-Adapter/resolve/main/models/t2iadapter_depth_sd14v1.pth?download=true",
@@ -71,6 +83,14 @@ def parse_model_download_request(data) -> ModelDownloadRequest:
 
 
 def is_allowed_model_download_url(url: str) -> bool:
+    """Return True for URLs we are willing to fetch on behalf of the user.
+
+    The same predicate is applied to the user-supplied URL and to every
+    redirect target, so SSRF via redirects on an allowed host is contained
+    to the same allowlist. Subdomains of allowlisted hosts are accepted
+    because Hugging Face and Civitai both serve actual file payloads from
+    CDN subdomains (e.g. ``cdn-lfs.huggingface.co``).
+    """
     if url in WHITE_LISTED_DOWNLOAD_URLS:
         return True
 
@@ -82,7 +102,14 @@ def is_allowed_model_download_url(url: str) -> bool:
     if parsed.scheme != "https":
         return False
 
-    return (parsed.hostname or "").lower() in ALLOWED_DOWNLOAD_HOSTS
+    host = (parsed.hostname or "").lower()
+    if not host:
+        return False
+
+    for allowed in ALLOWED_DOWNLOAD_HOSTS:
+        if host == allowed or host.endswith("." + allowed):
+            return True
+    return False
 
 
 def normalize_model_relative_path(name: str) -> str:
@@ -164,6 +191,36 @@ def safe_join(root: str, relative_path: str) -> str:
     return full_path
 
 
+async def open_model_download_response(session: ClientSession, url: str):
+    """GET ``url`` with explicit timeout and an allowlist-checked redirect chain.
+
+    aiohttp follows redirects by default, which would let an allowed host
+    redirect to an arbitrary internal target (SSRF). We disable automatic
+    following and validate every ``Location`` against the same allowlist
+    used for the initial URL.
+    """
+    current_url = url
+    for _ in range(MAX_DOWNLOAD_REDIRECTS + 1):
+        response = await session.get(
+            current_url,
+            allow_redirects=False,
+            timeout=DOWNLOAD_TIMEOUT,
+        )
+        if response.status not in (301, 302, 303, 307, 308):
+            return response
+
+        location = response.headers.get("Location", "").strip()
+        response.release()
+        if not location:
+            raise ModelDownloadError("Redirect response missing Location header.", status=502)
+        next_url = urljoin(current_url, location)
+        if not is_allowed_model_download_url(next_url):
+            raise ModelDownloadError("Model download redirect target is not allowed.", status=403)
+        current_url = next_url
+
+    raise ModelDownloadError("Too many redirects while downloading model.", status=502)
+
+
 async def download_model_to_destination(
     session: ClientSession,
     request: ModelDownloadRequest,
@@ -183,7 +240,7 @@ async def download_model_to_destination(
     bytes_written = 0
     try:
         with os.fdopen(fd, "wb") as output:
-            async with session.get(request.url) as response:
+            async with await open_model_download_response(session, request.url) as response:
                 if response.status >= 400:
                     raise ModelDownloadError(f"Model download failed with HTTP {response.status}.", status=502)
 
diff --git a/tests-unit/app_test/model_download_test.py b/tests-unit/app_test/model_download_test.py
index 3dece976c..c4a8a3424 100644
--- a/tests-unit/app_test/model_download_test.py
+++ b/tests-unit/app_test/model_download_test.py
@@ -8,11 +8,44 @@ from app.model_download import (
     ModelDownloadRequest,
     is_allowed_model_download_url,
     normalize_model_relative_path,
+    open_model_download_response,
     parse_model_download_request,
     resolve_model_download_destination,
 )
 
 
+class _FakeResponse:
+    """Minimal stand-in for ``aiohttp.ClientResponse`` for the redirect tests."""
+
+    def __init__(self, status, headers=None):
+        self.status = status
+        self.headers = headers or {}
+        self.released = False
+
+    def release(self):
+        self.released = True
+
+    async def __aenter__(self):
+        return self
+
+    async def __aexit__(self, exc_type, exc, tb):
+        self.released = True
+
+
+class _FakeSession:
+    """Hands out queued ``_FakeResponse`` objects in order."""
+
+    def __init__(self, responses):
+        self._responses = list(responses)
+        self.calls = []
+
+    async def get(self, url, allow_redirects, timeout):
+        self.calls.append((url, allow_redirects))
+        if not self._responses:
+            raise AssertionError("Unexpected extra session.get call")
+        return self._responses.pop(0)
+
+
 def test_parse_model_download_request_allows_huggingface_model_url():
     request = parse_model_download_request({
         "name": "nested/model.safetensors",
@@ -33,12 +66,46 @@ def test_parse_model_download_request_allows_huggingface_model_url():
         "http://localhost:8000/model.safetensors",
         "http://huggingface.co/org/repo/resolve/main/model.safetensors",
         "https://example.com/model.safetensors",
+        "https://huggingface.co.evil.com/model.safetensors",
     ],
 )
 def test_download_url_allowlist_rejects_untrusted_or_plain_http_urls(url):
     assert is_allowed_model_download_url(url) is False
 
 
+@pytest.mark.parametrize(
+    "url",
+    [
+        # Direct HF model URLs.
+        "https://huggingface.co/org/repo/resolve/main/model.safetensors",
+        # HF LFS CDN subdomains: this is where `/resolve/main/...` redirects
+        # land, so the allowlist must accept them or downloads break.
+        "https://cdn-lfs.huggingface.co/repos/abc/def/model.safetensors",
+        "https://cdn-lfs-us-1.huggingface.co/repos/abc/def/model.safetensors",
+        # Civitai download endpoints (PR objective: support Civitai too).
+        "https://civitai.com/api/download/models/12345",
+        "https://civitai.red/api/download/models/12345",
+    ],
+)
+def test_download_url_allowlist_accepts_huggingface_and_civitai_urls(url):
+    assert is_allowed_model_download_url(url) is True
+
+
+@pytest.mark.parametrize(
+    "name, expected",
+    [
+        ("model.safetensors", "model.safetensors"),
+        ("sub/model.safetensors", "sub/model.safetensors"),
+        ("nested/dir/model.safetensors", "nested/dir/model.safetensors"),
+        # Backslashes are normalized to forward slashes so Windows-style
+        # paths land in the same place as the POSIX equivalents.
+        ("nested\\dir\\model.safetensors", "nested/dir/model.safetensors"),
+    ],
+)
+def test_normalize_model_relative_path_accepts_safe_paths(name, expected):
+    assert normalize_model_relative_path(name) == expected
+
+
 @pytest.mark.parametrize(
     "name",
     [
@@ -112,3 +179,66 @@ def test_resolve_model_download_destination_rejects_blocked_or_unknown_directori
             url="https://huggingface.co/org/repo/resolve/main/model.safetensors",
             directory=directory,
         ))
+
+
+@pytest.mark.asyncio
+async def test_open_model_download_response_follows_allowed_subdomain_redirect():
+    """HF redirects /resolve/main/... to cdn-lfs.huggingface.co; that must work."""
+    session = _FakeSession([
+        _FakeResponse(302, {"Location": "https://cdn-lfs.huggingface.co/repos/abc/model.safetensors"}),
+        _FakeResponse(200),
+    ])
+
+    response = await open_model_download_response(
+        session, "https://huggingface.co/org/repo/resolve/main/model.safetensors"
+    )
+
+    assert response.status == 200
+    assert session.calls == [
+        ("https://huggingface.co/org/repo/resolve/main/model.safetensors", False),
+        ("https://cdn-lfs.huggingface.co/repos/abc/model.safetensors", False),
+    ]
+
+
+@pytest.mark.asyncio
+async def test_open_model_download_response_rejects_offsite_redirect():
+    """A redirect leaving the allowlist must surface as a 403 instead of being followed."""
+    session = _FakeSession([
+        _FakeResponse(302, {"Location": "https://attacker.example.com/payload"}),
+    ])
+
+    with pytest.raises(ModelDownloadError) as exc_info:
+        await open_model_download_response(
+            session, "https://huggingface.co/org/repo/resolve/main/model.safetensors"
+        )
+
+    assert exc_info.value.status == 403
+    # The initial request was issued with redirects disabled, otherwise
+    # the validation above would be a no-op.
+    assert session.calls[0][1] is False
+
+
+@pytest.mark.asyncio
+async def test_open_model_download_response_rejects_redirect_without_location():
+    session = _FakeSession([_FakeResponse(302)])
+
+    with pytest.raises(ModelDownloadError) as exc_info:
+        await open_model_download_response(
+            session, "https://huggingface.co/org/repo/resolve/main/model.safetensors"
+        )
+
+    assert exc_info.value.status == 502
+
+
+@pytest.mark.asyncio
+async def test_open_model_download_response_stops_after_too_many_redirects():
+    session = _FakeSession(
+        [_FakeResponse(302, {"Location": "https://cdn-lfs.huggingface.co/loop"})] * 10
+    )
+
+    with pytest.raises(ModelDownloadError) as exc_info:
+        await open_model_download_response(
+            session, "https://huggingface.co/org/repo/resolve/main/model.safetensors"
+        )
+
+    assert exc_info.value.status == 502