diff --git a/app/model_downloader/manager.py b/app/model_downloader/manager.py index d5d5734be..0f358c9d5 100644 --- a/app/model_downloader/manager.py +++ b/app/model_downloader/manager.py @@ -14,10 +14,17 @@ from typing import Callable, Optional from app.model_downloader.constants import DownloadStatus from app.model_downloader.database import queries +from app.model_downloader.net.probe import probe from app.model_downloader.scheduler import SCHEDULER from app.model_downloader.security import paths from app.model_downloader.net.http import redact_url -from app.model_downloader.security.allowlist import is_url_allowed +from app.model_downloader.security.allowlist import ( + ALLOWED_MODEL_EXTENSIONS, + filename_extension, + is_host_allowed_url, + is_url_downloadable, + url_path_extension, +) from app.model_downloader.security.paths import InvalidModelId # Non-terminal statuses: an existing row in one of these blocks a re-enqueue. @@ -70,11 +77,30 @@ class DownloadManager: allow_any_extension: bool = False, credential_id: Optional[str] = None, ) -> str: - if not is_url_allowed(url, allow_any_extension): + # Coarse gate first: host/scheme must be allowlisted, and any extension + # present in the URL path must be a known model type. A URL whose path + # carries NO extension (e.g. Civitai's ``/api/download/models/``) is + # admitted here and its real extension is resolved from the network + # below before the download is finally accepted. + if allow_any_extension: + if not is_host_allowed_url(url): + raise DownloadError( + "URL_NOT_ALLOWED", + "URL is not on the download allowlist (host/scheme).", + ) + elif not is_url_downloadable(url): raise DownloadError( "URL_NOT_ALLOWED", "URL is not on the download allowlist (host/scheme/extension).", ) + + # When the URL path has no extension, follow it to where it resolves and + # adopt the real extension from the response, forcing the stored + # filename to match. Skipped when the caller opted into any extension. 
async def _resolve_extension(
    self, url: str, credential_id: Optional[str]
) -> str:
    """Probe ``url`` and return the extension of the file it resolves to.

    Intended for allowlisted URLs whose path carries no extension: the
    probe result's ``filename`` is used to derive the extension.

    Args:
        url: The download URL to probe.
        credential_id: Optional credential identifier forwarded to the probe.

    Returns:
        The lowercased extension (leading dot included) of the resolved
        filename; always a member of ``ALLOWED_MODEL_EXTENSIONS``.

    Raises:
        DownloadError: ``CREDENTIALS_REQUIRED`` (status 401) when the probe
            failed and reported the URL as gated; ``URL_RESOLVE_FAILED``
            (status 502) for any other probe failure; ``URL_NOT_ALLOWED``
            when the resolved filename is missing or its extension is not a
            known model extension.
    """
    result = await probe(url, credential_id=credential_id)

    if result.ok:
        resolved_name = result.filename
        # A missing filename is treated as "no extension" and rejected below.
        ext = filename_extension(resolved_name) if resolved_name else ""
        if ext in ALLOWED_MODEL_EXTENSIONS:
            return ext
        raise DownloadError(
            "URL_NOT_ALLOWED",
            f"URL resolves to {resolved_name or ''!r}, which is not a "
            f"known model file type {ALLOWED_MODEL_EXTENSIONS}.",
        )

    # Probe failed: never echo the raw URL (query strings may carry tokens).
    safe_url = redact_url(url)
    if result.gated:
        raise DownloadError(
            "CREDENTIALS_REQUIRED",
            f"{safe_url} requires authentication to resolve. Add an "
            f"API key for this host at /api/download/credentials and retry.",
            status=401,
        )
    raise DownloadError(
        "URL_RESOLVE_FAILED",
        f"Could not resolve {safe_url}: {result.error or 'unknown error'}",
        status=502,
    )
There is no ``await`` between the # lookup and the insert, so under the single asyncio thread this is @@ -362,22 +422,25 @@ class DownloadManager: if r.status in _LIVE_STATUSES or r.model_id not in by_model: by_model[r.model_id] = r + # ``url_allowed`` mirrors the coarse enqueue gate (host/scheme + a + # non-disallowed extension); URLs whose extension is only known after a + # network resolve — e.g. Civitai download endpoints — report allowed. out: dict[str, dict] = {} for model_id, url in models.items(): try: exists = await asyncio.to_thread(paths.resolve_existing, model_id) except InvalidModelId: - out[model_id] = {"state": "missing", "url_allowed": is_url_allowed(url)} + out[model_id] = {"state": "missing", "url_allowed": is_url_downloadable(url)} continue if exists: - out[model_id] = {"state": "available", "url_allowed": is_url_allowed(url)} + out[model_id] = {"state": "available", "url_allowed": is_url_downloadable(url)} continue row = by_model.get(model_id) if row is not None and row.status in _LIVE_STATUSES: view = self._view(row) out[model_id] = { "state": "downloading", - "url_allowed": is_url_allowed(url), + "url_allowed": is_url_downloadable(url), "download_id": view["download_id"], "progress": view["progress"], "bytes_done": view["bytes_done"], @@ -385,7 +448,7 @@ class DownloadManager: "speed_bps": view["speed_bps"], } else: - out[model_id] = {"state": "missing", "url_allowed": is_url_allowed(url)} + out[model_id] = {"state": "missing", "url_allowed": is_url_downloadable(url)} return out diff --git a/app/model_downloader/net/http.py b/app/model_downloader/net/http.py index e5cb874a1..8112b7a20 100644 --- a/app/model_downloader/net/http.py +++ b/app/model_downloader/net/http.py @@ -10,9 +10,10 @@ that attaches credentials, so a token can never ride a redirect to a CDN host. 
# RFC 5987 extended parameter: ``filename*=charset'language'pct-encoded-value``.
# The charset is captured so the value can be decoded with the encoding the
# server declared instead of always assuming UTF-8.
_CD_FILENAME_STAR = re.compile(
    r"filename\*\s*=\s*(?P<charset>[^']*)'[^']*'(?P<value>[^;]+)", re.IGNORECASE
)
_CD_FILENAME_QUOTED = re.compile(r'filename\s*=\s*"([^"]+)"', re.IGNORECASE)
_CD_FILENAME_BARE = re.compile(r"filename\s*=\s*([^;]+)", re.IGNORECASE)


def _decode_ext_value(raw: str, charset: str) -> str:
    """Percent-decode an RFC 5987 value using its declared ``charset``.

    Falls back to UTF-8 when the charset is empty or unknown to Python.
    Undecodable bytes are replaced rather than raising.
    """
    try:
        return unquote(raw, encoding=charset or "utf-8", errors="replace")
    except LookupError:  # server sent a charset label Python doesn't know
        return unquote(raw, encoding="utf-8", errors="replace")


def filename_from_content_disposition(value: Optional[str]) -> Optional[str]:
    """Extract the download filename from a ``Content-Disposition`` header.

    Prefers the RFC 5987 ``filename*=`` form (percent-decoded with its declared
    charset) over the quoted and then the bare ``filename=`` forms. Directory
    components (``/`` or ``\\`` separated) are stripped so a hostile header can
    only influence the *name*, never the target directory; the path-traversal
    names ``.`` and ``..`` are never returned.

    Args:
        value: Raw header value, or ``None`` when the header is absent.

    Returns:
        The bare filename, or ``None`` when no usable filename is present.
    """
    if not value:
        return None
    for pattern in (_CD_FILENAME_STAR, _CD_FILENAME_QUOTED, _CD_FILENAME_BARE):
        match = pattern.search(value)
        if match is None:
            continue
        if pattern is _CD_FILENAME_STAR:
            raw = match.group("value").strip().strip('"')
            raw = _decode_ext_value(raw, match.group("charset").strip())
        else:
            raw = match.group(1).strip().strip('"')
        # Normalise Windows separators, then keep only the last path segment.
        name = raw.replace("\\", "/").rsplit("/", 1)[-1].strip()
        # "." / ".." are directory references, not filenames; try the next form.
        if name and name not in (".", ".."):
            return name
    return None
def filename_extension(name: str) -> str:
    """Return the lowercased extension of ``name``, leading dot included.

    Only the final path segment is considered (both ``/`` and ``\\`` count as
    separators). The result is ``""`` when that segment has no dot, or when its
    only dot is the first character (``.safetensors`` is all stem, as with
    ``os.path.splitext``).
    """
    leaf = name.replace("\\", "/").rpartition("/")[2]
    stem, dot, suffix = leaf.rpartition(".")
    # No dot at all, or nothing before the last dot -> no extension.
    if not dot or not stem:
        return ""
    return f".{suffix}".lower()
def apply_extension(model_id: str, ext: str) -> str:
    """Force the filename part of ``model_id`` to end in ``ext``.

    ``ext`` carries its leading dot (e.g. ``".safetensors"``). Everything after
    the first ``/`` is treated as the filename. If it ends (case-insensitively)
    in one of ``ALLOWED_MODEL_EXTENSIONS`` that suffix is removed first; any
    other suffix is kept as part of the stem. ``ext`` is then appended, so
    ``loras/m`` and ``loras/m.ckpt`` both become ``loras/m.safetensors`` while
    ``loras/my.model.v2`` becomes ``loras/my.model.v2.safetensors``.

    A ``model_id`` without a ``/`` is returned unchanged. No validation is done
    here.
    """
    head, slash, leaf = model_id.partition("/")
    if slash == "":
        # No directory separator: hand the id back untouched.
        return model_id

    lowered = leaf.lower()
    # First known extension (in tuple order) that the filename ends with.
    stale = next(
        (known for known in ALLOWED_MODEL_EXTENSIONS if lowered.endswith(known)),
        None,
    )
    if stale is not None:
        leaf = leaf[: -len(stale)]
    return head + slash + leaf + ext
def _content_disposition_handler(payload: bytes, filename: str):
    """Build a range-capable handler that names its file only via a header.

    Mimics a Civitai-style ``/api/download/...`` endpoint: nothing in the URL
    path reveals the extension; the real filename is sent in the response's
    ``Content-Disposition`` header. Requests with a ``Range`` header get a 206
    with the requested slice; all others get the full payload with a 200.
    """

    async def handler(request: web.Request) -> web.Response:
        response_headers = {
            "Accept-Ranges": "bytes",
            "ETag": PAYLOAD_ETAG,
            "Content-Disposition": f'attachment; filename="{filename}"',
        }
        range_header = request.headers.get("Range")
        if not range_header:
            return web.Response(status=200, body=payload, headers=response_headers)

        # "bytes=<start>-<end>"; an omitted end means "through the last byte".
        first, _, last = range_header.split("=", 1)[1].partition("-")
        start = int(first)
        end = int(last) if last else len(payload) - 1
        response_headers["Content-Range"] = f"bytes {start}-{end}/{len(payload)}"
        return web.Response(
            status=206,
            body=payload[start : end + 1],
            headers=response_headers,
        )

    return handler
def test_manager_resolves_extensionless_url(model_root):
    """An allowlisted URL with no extension in its path is resolved from the
    response, and the stored file adopts the resolved extension."""
    payload = _safetensors_payload(1 * 1024 * 1024)

    async def _run():
        await close_session()
        from app.model_downloader.manager import DOWNLOAD_MANAGER

        runner, port = await _serve(
            _content_disposition_handler(payload, "RealModel.safetensors")
        )
        try:
            # No extension in the path (Civitai-style) and none in the model_id.
            url = f"http://127.0.0.1:{port}/api/download/models/12345"
            did = await DOWNLOAD_MANAGER.enqueue(url, "loras/my_civitai_model")

            row = queries.get_download(did)
            # The resolved extension was appended to the model_id + destination.
            assert row.model_id == "loras/my_civitai_model.safetensors"
            assert row.dest_path.endswith("my_civitai_model.safetensors")

            final_path, _ = paths.resolve_destination(
                "loras/my_civitai_model.safetensors"
            )
            # Poll until the download reaches a terminal status (~10 s budget).
            for _ in range(500):
                await asyncio.sleep(0.02)
                row = queries.get_download(did)
                if row.status in DownloadStatus.TERMINAL:
                    break
            assert row.status == DownloadStatus.COMPLETED, row.error
            # Context manager so the handle is closed even if the assert fails.
            with open(final_path, "rb") as fh:
                assert fh.read() == payload
        finally:
            await runner.cleanup()
            await close_session()

    asyncio.run(_run())
Caller guessed .ckpt; resolution says .safetensors -> corrected. + did = await DOWNLOAD_MANAGER.enqueue(url, "loras/guessed.ckpt") + row = queries.get_download(did) + assert row.model_id == "loras/guessed.safetensors" + finally: + await runner.cleanup() + await close_session() + + asyncio.run(_run()) + + +def test_manager_rejects_non_model_resolution(model_root): + """A URL that resolves to a non-model file is rejected, not downloaded.""" + + async def _run(): + await close_session() + from app.model_downloader.manager import DOWNLOAD_MANAGER, DownloadError + + runner, port = await _serve( + _content_disposition_handler(b"not a model", "installer.zip") + ) + try: + url = f"http://127.0.0.1:{port}/api/download/models/999" + with pytest.raises(DownloadError) as ei: + await DOWNLOAD_MANAGER.enqueue(url, "loras/whatever") + assert ei.value.code == "URL_NOT_ALLOWED" + finally: + await runner.cleanup() + await close_session() + + asyncio.run(_run()) diff --git a/tests-unit/model_downloader_test/test_security.py b/tests-unit/model_downloader_test/test_security.py index 6ce44eaa0..0d101c8cf 100644 --- a/tests-unit/model_downloader_test/test_security.py +++ b/tests-unit/model_downloader_test/test_security.py @@ -43,6 +43,42 @@ def test_allow_any_extension_relaxes_extension_only(): assert allowlist.is_url_allowed(odd, allow_any_extension=True) is True +@pytest.mark.parametrize( + "url,downloadable", + [ + # known model extension in the path -> allowed + ("https://civitai.com/x/model.safetensors", True), + # no extension in the path (Civitai download API) -> allowed, resolved later + ("https://civitai.com/api/download/models/3031464?fileId=2910346", True), + ("https://civitai.com/api/download/models/3031464", True), + # explicit non-model extension -> rejected even on an allowed host + ("https://civitai.com/api/download/models/thing.zip", False), + ("https://huggingface.co/org/repo/resolve/main/config.json", False), + # off-list host is never downloadable + 
("https://evil.example.com/api/download/models/1", False), + # http to a non-loopback allowlisted host is not permitted + ("http://civitai.com/api/download/models/1", False), + ], +) +def test_is_url_downloadable(url, downloadable): + assert allowlist.is_url_downloadable(url) is downloadable + + +@pytest.mark.parametrize( + "name,ext", + [ + ("model.safetensors", ".safetensors"), + ("model.SAFETENSORS", ".safetensors"), + ("archive.tar.gz", ".gz"), + ("noext", ""), + (".safetensors", ""), # leading-dot dotfile -> no extension + ("a/b/c/model.ckpt", ".ckpt"), + ], +) +def test_filename_extension(name, ext): + assert allowlist.filename_extension(name) == ext + + # ----- SSRF: blocked IPs ----- @@ -148,3 +184,48 @@ def test_resolve_destination_stays_in_root(model_root): assert final_path.startswith(model_root) assert temp_path.startswith(model_root) assert temp_path != final_path + + +@pytest.mark.parametrize( + "model_id,ext,expected", + [ + # no extension -> append the resolved one + ("loras/my_civitai_model", ".safetensors", "loras/my_civitai_model.safetensors"), + # different known extension -> replace it + ("loras/mymodel.ckpt", ".safetensors", "loras/mymodel.safetensors"), + # same extension -> unchanged + ("loras/mymodel.safetensors", ".safetensors", "loras/mymodel.safetensors"), + # non-model suffix is treated as a stem, extension appended + ("loras/my.model.v2", ".safetensors", "loras/my.model.v2.safetensors"), + # malformed (no slash) is returned untouched for parse_model_id to reject + ("noslash", ".safetensors", "noslash"), + ], +) +def test_apply_extension(model_id, ext, expected): + assert paths.apply_extension(model_id, ext) == expected + + +# ----- Content-Disposition filename parsing ----- + + +@pytest.mark.parametrize( + "header,expected", + [ + ('attachment; filename="model.safetensors"', "model.safetensors"), + ("attachment; filename=model.ckpt", "model.ckpt"), + # RFC 5987 form is preferred and percent-decoded + ( + "attachment; 
filename=\"fallback.bin\"; filename*=UTF-8''my%20model.safetensors", + "my model.safetensors", + ), + # directory components in a hostile header are stripped to the basename + ('attachment; filename="../../etc/passwd"', "passwd"), + ('attachment; filename="a\\\\b\\\\model.pt"', "model.pt"), + ("inline", None), + (None, None), + ], +) +def test_filename_from_content_disposition(header, expected): + from app.model_downloader.net.http import filename_from_content_disposition + + assert filename_from_content_disposition(header) == expected