Add extension check on the final resolved url -> fix downloading from civitAI.
Some checks are pending
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run

This commit is contained in:
Talmaj Marinc 2026-07-01 15:37:30 +02:00
parent 4c82c708a7
commit c98a212589
7 changed files with 411 additions and 9 deletions

View File

@ -14,10 +14,17 @@ from typing import Callable, Optional
from app.model_downloader.constants import DownloadStatus
from app.model_downloader.database import queries
from app.model_downloader.net.probe import probe
from app.model_downloader.scheduler import SCHEDULER
from app.model_downloader.security import paths
from app.model_downloader.net.http import redact_url
from app.model_downloader.security.allowlist import is_url_allowed
from app.model_downloader.security.allowlist import (
ALLOWED_MODEL_EXTENSIONS,
filename_extension,
is_host_allowed_url,
is_url_downloadable,
url_path_extension,
)
from app.model_downloader.security.paths import InvalidModelId
# Non-terminal statuses: an existing row in one of these blocks a re-enqueue.
@ -70,11 +77,30 @@ class DownloadManager:
allow_any_extension: bool = False,
credential_id: Optional[str] = None,
) -> str:
if not is_url_allowed(url, allow_any_extension):
# Coarse gate first: host/scheme must be allowlisted, and any extension
# present in the URL path must be a known model type. A URL whose path
# carries NO extension (e.g. Civitai's ``/api/download/models/<id>``) is
# admitted here and its real extension is resolved from the network
# below before the download is finally accepted.
if allow_any_extension:
if not is_host_allowed_url(url):
raise DownloadError(
"URL_NOT_ALLOWED",
"URL is not on the download allowlist (host/scheme).",
)
elif not is_url_downloadable(url):
raise DownloadError(
"URL_NOT_ALLOWED",
"URL is not on the download allowlist (host/scheme/extension).",
)
# When the URL path has no extension, follow it to where it resolves and
# adopt the real extension from the response, forcing the stored
# filename to match. Skipped when the caller opted into any extension.
if not allow_any_extension and url_path_extension(url) == "":
resolved_ext = await self._resolve_extension(url, credential_id)
model_id = paths.apply_extension(model_id, resolved_ext)
try:
paths.parse_model_id(model_id, allow_any_extension)
dest_path, temp_path = paths.resolve_destination(model_id, allow_any_extension)
@ -119,6 +145,40 @@ class DownloadManager:
await self._scheduler.pump()
return download_id
async def _resolve_extension(
self, url: str, credential_id: Optional[str]
) -> str:
"""Follow ``url`` to its final response and return the real extension.
Used for allowlisted URLs whose path has no extension (e.g. Civitai
download endpoints): the filename lives in the ``Content-Disposition``
header or the post-redirect URL. Raises :class:`DownloadError` when the
URL can't be resolved, needs credentials, or resolves to something that
is not a known model file so we never persist a bogus destination.
"""
pr = await probe(url, credential_id=credential_id)
if not pr.ok:
if pr.gated:
raise DownloadError(
"CREDENTIALS_REQUIRED",
f"{redact_url(url)} requires authentication to resolve. Add an "
f"API key for this host at /api/download/credentials and retry.",
status=401,
)
raise DownloadError(
"URL_RESOLVE_FAILED",
f"Could not resolve {redact_url(url)}: {pr.error or 'unknown error'}",
status=502,
)
ext = filename_extension(pr.filename) if pr.filename else ""
if ext not in ALLOWED_MODEL_EXTENSIONS:
raise DownloadError(
"URL_NOT_ALLOWED",
f"URL resolves to {pr.filename or '<unknown>'!r}, which is not a "
f"known model file type {ALLOWED_MODEL_EXTENSIONS}.",
)
return ext
def _model_lock(self, model_id: str) -> asyncio.Lock:
# Lazily create one lock per model_id. There is no ``await`` between the
# lookup and the insert, so under the single asyncio thread this is
@ -362,22 +422,25 @@ class DownloadManager:
if r.status in _LIVE_STATUSES or r.model_id not in by_model:
by_model[r.model_id] = r
# ``url_allowed`` mirrors the coarse enqueue gate (host/scheme + a
# non-disallowed extension); URLs whose extension is only known after a
# network resolve — e.g. Civitai download endpoints — report allowed.
out: dict[str, dict] = {}
for model_id, url in models.items():
try:
exists = await asyncio.to_thread(paths.resolve_existing, model_id)
except InvalidModelId:
out[model_id] = {"state": "missing", "url_allowed": is_url_allowed(url)}
out[model_id] = {"state": "missing", "url_allowed": is_url_downloadable(url)}
continue
if exists:
out[model_id] = {"state": "available", "url_allowed": is_url_allowed(url)}
out[model_id] = {"state": "available", "url_allowed": is_url_downloadable(url)}
continue
row = by_model.get(model_id)
if row is not None and row.status in _LIVE_STATUSES:
view = self._view(row)
out[model_id] = {
"state": "downloading",
"url_allowed": is_url_allowed(url),
"url_allowed": is_url_downloadable(url),
"download_id": view["download_id"],
"progress": view["progress"],
"bytes_done": view["bytes_done"],
@ -385,7 +448,7 @@ class DownloadManager:
"speed_bps": view["speed_bps"],
}
else:
out[model_id] = {"state": "missing", "url_allowed": is_url_allowed(url)}
out[model_id] = {"state": "missing", "url_allowed": is_url_downloadable(url)}
return out

View File

@ -10,9 +10,10 @@ that attaches credentials, so a token can never ride a redirect to a CDN host.
from __future__ import annotations
import logging
import re
from contextlib import asynccontextmanager
from typing import AsyncIterator, Optional
from urllib.parse import urljoin, urlsplit, urlunsplit
from urllib.parse import unquote, urljoin, urlsplit, urlunsplit
import aiohttp
@ -37,6 +38,43 @@ def redact_url(url: str) -> str:
return urlunsplit(parts._replace(query=""))
_CD_FILENAME_STAR = re.compile(
r"filename\*\s*=\s*[^']*'[^']*'([^;]+)", re.IGNORECASE
)
_CD_FILENAME_QUOTED = re.compile(r'filename\s*=\s*"([^"]+)"', re.IGNORECASE)
_CD_FILENAME_BARE = re.compile(r"filename\s*=\s*([^;]+)", re.IGNORECASE)
def filename_from_content_disposition(value: Optional[str]) -> Optional[str]:
"""Extract the download filename from a ``Content-Disposition`` header.
Prefers the RFC 5987 ``filename*=`` form (percent-decoded) over the plain
``filename=`` form. Any directory components in the value are stripped so a
hostile header can only influence the *name*, never the target directory.
Returns ``None`` when no filename is present.
"""
if not value:
return None
for pat, decode in (
(_CD_FILENAME_STAR, True),
(_CD_FILENAME_QUOTED, False),
(_CD_FILENAME_BARE, False),
):
m = pat.search(value)
if not m:
continue
raw = m.group(1).strip().strip('"')
if decode:
try:
raw = unquote(raw)
except Exception:
pass
name = raw.replace("\\", "/").rsplit("/", 1)[-1].strip()
if name:
return name
return None
async def _resolve_final_response(
method: str,
url: str,

View File

@ -12,11 +12,14 @@ from __future__ import annotations
import logging
from dataclasses import dataclass
from typing import Optional
from urllib.parse import urlparse
from urllib.parse import urlparse, urlsplit
import aiohttp
from app.model_downloader.net.http import open_validated
from app.model_downloader.net.http import (
filename_from_content_disposition,
open_validated,
)
from app.model_downloader.net.session import parse_int_header
_PROBE_TIMEOUT = aiohttp.ClientTimeout(total=60, sock_connect=30, sock_read=30)
@ -33,6 +36,11 @@ class ProbeResult:
last_modified: Optional[str] = None
gated: bool = False # 401/403 — needs (or has wrong) credentials
error: Optional[str] = None
# Filename the server intends this response to be saved as: the
# ``Content-Disposition`` name if present, else the post-redirect URL's
# basename. Used to resolve the real extension for URLs (e.g. Civitai's
# ``/api/download`` endpoints) that carry no extension in their path.
filename: Optional[str] = None
def _total_from_content_range(value: Optional[str]) -> Optional[int]:
@ -43,6 +51,19 @@ def _total_from_content_range(value: Optional[str]) -> Optional[int]:
return parse_int_header(total)
def _filename_from_response(
content_disposition: Optional[str], final_url: Optional[str]
) -> Optional[str]:
name = filename_from_content_disposition(content_disposition)
if name:
return name
if final_url:
base = urlsplit(final_url).path.rsplit("/", 1)[-1]
if base:
return base
return None
async def probe(url: str, *, credential_id: Optional[str] = None) -> ProbeResult:
"""Probe ``url`` and return discovered metadata, failing soft."""
try:
@ -85,6 +106,9 @@ async def probe(url: str, *, credential_id: Optional[str] = None) -> ProbeResult
accept_ranges=accept_ranges,
etag=headers.get("ETag"),
last_modified=headers.get("Last-Modified"),
filename=_filename_from_response(
headers.get("Content-Disposition"), final_url
),
)
except Exception as e: # network / SSRF / timeout
host = urlparse(url).netloc or "<unknown>"

View File

@ -71,6 +71,62 @@ def has_allowed_extension(path: str, allow_any_extension: bool = False) -> bool:
return path.lower().endswith(ALLOWED_MODEL_EXTENSIONS)
def filename_extension(name: str) -> str:
"""Lowercased extension (including the leading dot) of a bare filename.
Returns ``""`` when there is no extension. A leading-dot name
(``.safetensors``) is treated as having no extension (all stem), matching
``os.path.splitext`` semantics so dotfiles aren't mistaken for typed files.
"""
base = name.replace("\\", "/").rsplit("/", 1)[-1]
dot = base.rfind(".")
if dot <= 0:
return ""
return base[dot:].lower()
def is_allowed_extension_name(name: str) -> bool:
"""True iff ``name`` ends in one of the known model extensions."""
return name.lower().endswith(ALLOWED_MODEL_EXTENSIONS)
def is_host_allowed_url(url: str) -> bool:
"""True iff ``url`` parses and its host+scheme are allowlisted."""
if not isinstance(url, str) or not url:
return False
try:
parsed = urlparse(url)
except ValueError:
return False
return is_host_allowed(parsed.hostname, parsed.scheme)
def url_path_extension(url: str) -> str:
"""Extension of the URL *path* basename (query ignored), or ``""``."""
try:
parsed = urlparse(url)
except ValueError:
return ""
return filename_extension(parsed.path)
def is_url_downloadable(url: str) -> bool:
"""Coarse enqueue gate: host/scheme allowed and extension not disallowed.
Unlike :func:`is_url_allowed` (which demands a known extension *in the URL*),
this also admits URLs whose path carries no extension at all e.g. a Civitai
``/api/download/models/<id>`` endpoint whose real filename only shows up in
the redirect target / ``Content-Disposition``. The true extension is then
resolved from the network and re-validated before the download is admitted.
A path bearing an explicit *non-model* extension (``.zip``, ``.html``, ...)
is still rejected here.
"""
if not is_host_allowed_url(url):
return False
ext = url_path_extension(url)
return ext == "" or ext in ALLOWED_MODEL_EXTENSIONS
def is_url_allowed(url: str, allow_any_extension: bool = False) -> bool:
"""Check whether ``url`` is permitted as a server-side download source."""
if not isinstance(url, str) or not url:

View File

@ -58,6 +58,28 @@ def parse_model_id(model_id: str, allow_any_extension: bool = False) -> tuple[st
return directory, filename
def apply_extension(model_id: str, ext: str) -> str:
"""Return ``model_id`` with its filename forced to end in ``ext``.
``ext`` includes the leading dot (e.g. ``".safetensors"``). If the filename
already ends in a *known model extension* it is replaced; otherwise ``ext``
is appended (so ``loras/mymodel`` -> ``loras/mymodel.safetensors`` and
``loras/mymodel.ckpt`` -> ``loras/mymodel.safetensors``). A filename with a
non-model suffix (``my.model.v2``) is treated as an extensionless stem and
``ext`` is appended. The directory part is left untouched; validation is
still the caller's job via :func:`parse_model_id`.
"""
directory, sep, filename = model_id.partition("/")
if not sep:
return model_id # malformed; parse_model_id will reject it
low = filename.lower()
for known in ALLOWED_MODEL_EXTENSIONS:
if low.endswith(known):
filename = filename[: -len(known)]
break
return f"{directory}{sep}{filename}{ext}"
def resolve_existing(model_id: str, allow_any_extension: bool = False) -> Optional[str]:
"""Return the absolute path of an installed model, or None if missing.

View File

@ -83,6 +83,37 @@ def _range_handler(payload: bytes):
return handler
def _content_disposition_handler(payload: bytes, filename: str):
"""A range-capable server that only reveals its filename via a header.
Models a Civitai-style ``/api/download/...`` endpoint: the URL path has no
extension, and the real filename (hence extension) lives in the response
``Content-Disposition`` header.
"""
async def handler(request: web.Request) -> web.Response:
headers = {
"Accept-Ranges": "bytes",
"ETag": PAYLOAD_ETAG,
"Content-Disposition": f'attachment; filename="{filename}"',
}
rng = request.headers.get("Range")
if rng:
spec = rng.split("=", 1)[1]
s, _, e = spec.partition("-")
start = int(s)
end = int(e) if e else len(payload) - 1
chunk = payload[start : end + 1]
return web.Response(
status=206,
body=chunk,
headers={**headers, "Content-Range": f"bytes {start}-{end}/{len(payload)}"},
)
return web.Response(status=200, body=payload, headers=headers)
return handler
def _noranges_handler(payload: bytes):
async def handler(request: web.Request) -> web.Response:
# Always full body, never advertises Accept-Ranges -> single-stream.
@ -517,3 +548,90 @@ def test_manager_rejects_disallowed_url(model_root):
assert ei.value.code == "URL_NOT_ALLOWED"
asyncio.run(_run())
def test_manager_resolves_extensionless_url(model_root):
"""An allowlisted URL with no extension in its path is resolved from the
response, and the stored file adopts the resolved extension."""
payload = _safetensors_payload(1 * 1024 * 1024)
async def _run():
await close_session()
from app.model_downloader.manager import DOWNLOAD_MANAGER
runner, port = await _serve(
_content_disposition_handler(payload, "RealModel.safetensors")
)
try:
# No extension in the path (Civitai-style) and none in the model_id.
url = f"http://127.0.0.1:{port}/api/download/models/12345"
did = await DOWNLOAD_MANAGER.enqueue(url, "loras/my_civitai_model")
row = queries.get_download(did)
# The resolved extension was appended to the model_id + destination.
assert row.model_id == "loras/my_civitai_model.safetensors"
assert row.dest_path.endswith("my_civitai_model.safetensors")
final_path, _ = paths.resolve_destination(
"loras/my_civitai_model.safetensors"
)
for _ in range(500):
await asyncio.sleep(0.02)
row = queries.get_download(did)
if row.status in DownloadStatus.TERMINAL:
break
row = queries.get_download(did)
assert row.status == DownloadStatus.COMPLETED, row.error
assert open(final_path, "rb").read() == payload
finally:
await runner.cleanup()
await close_session()
asyncio.run(_run())
def test_manager_overrides_extension_from_resolution(model_root):
"""A model_id carrying a different known extension is corrected to match
the resolved URL's extension."""
payload = _safetensors_payload(256 * 1024)
async def _run():
await close_session()
from app.model_downloader.manager import DOWNLOAD_MANAGER
runner, port = await _serve(
_content_disposition_handler(payload, "weights.safetensors")
)
try:
url = f"http://127.0.0.1:{port}/api/download/models/777"
# Caller guessed .ckpt; resolution says .safetensors -> corrected.
did = await DOWNLOAD_MANAGER.enqueue(url, "loras/guessed.ckpt")
row = queries.get_download(did)
assert row.model_id == "loras/guessed.safetensors"
finally:
await runner.cleanup()
await close_session()
asyncio.run(_run())
def test_manager_rejects_non_model_resolution(model_root):
"""A URL that resolves to a non-model file is rejected, not downloaded."""
async def _run():
await close_session()
from app.model_downloader.manager import DOWNLOAD_MANAGER, DownloadError
runner, port = await _serve(
_content_disposition_handler(b"not a model", "installer.zip")
)
try:
url = f"http://127.0.0.1:{port}/api/download/models/999"
with pytest.raises(DownloadError) as ei:
await DOWNLOAD_MANAGER.enqueue(url, "loras/whatever")
assert ei.value.code == "URL_NOT_ALLOWED"
finally:
await runner.cleanup()
await close_session()
asyncio.run(_run())

View File

@ -43,6 +43,42 @@ def test_allow_any_extension_relaxes_extension_only():
assert allowlist.is_url_allowed(odd, allow_any_extension=True) is True
@pytest.mark.parametrize(
"url,downloadable",
[
# known model extension in the path -> allowed
("https://civitai.com/x/model.safetensors", True),
# no extension in the path (Civitai download API) -> allowed, resolved later
("https://civitai.com/api/download/models/3031464?fileId=2910346", True),
("https://civitai.com/api/download/models/3031464", True),
# explicit non-model extension -> rejected even on an allowed host
("https://civitai.com/api/download/models/thing.zip", False),
("https://huggingface.co/org/repo/resolve/main/config.json", False),
# off-list host is never downloadable
("https://evil.example.com/api/download/models/1", False),
# http to a non-loopback allowlisted host is not permitted
("http://civitai.com/api/download/models/1", False),
],
)
def test_is_url_downloadable(url, downloadable):
assert allowlist.is_url_downloadable(url) is downloadable
@pytest.mark.parametrize(
"name,ext",
[
("model.safetensors", ".safetensors"),
("model.SAFETENSORS", ".safetensors"),
("archive.tar.gz", ".gz"),
("noext", ""),
(".safetensors", ""), # leading-dot dotfile -> no extension
("a/b/c/model.ckpt", ".ckpt"),
],
)
def test_filename_extension(name, ext):
assert allowlist.filename_extension(name) == ext
# ----- SSRF: blocked IPs -----
@ -148,3 +184,48 @@ def test_resolve_destination_stays_in_root(model_root):
assert final_path.startswith(model_root)
assert temp_path.startswith(model_root)
assert temp_path != final_path
@pytest.mark.parametrize(
"model_id,ext,expected",
[
# no extension -> append the resolved one
("loras/my_civitai_model", ".safetensors", "loras/my_civitai_model.safetensors"),
# different known extension -> replace it
("loras/mymodel.ckpt", ".safetensors", "loras/mymodel.safetensors"),
# same extension -> unchanged
("loras/mymodel.safetensors", ".safetensors", "loras/mymodel.safetensors"),
# non-model suffix is treated as a stem, extension appended
("loras/my.model.v2", ".safetensors", "loras/my.model.v2.safetensors"),
# malformed (no slash) is returned untouched for parse_model_id to reject
("noslash", ".safetensors", "noslash"),
],
)
def test_apply_extension(model_id, ext, expected):
assert paths.apply_extension(model_id, ext) == expected
# ----- Content-Disposition filename parsing -----
@pytest.mark.parametrize(
"header,expected",
[
('attachment; filename="model.safetensors"', "model.safetensors"),
("attachment; filename=model.ckpt", "model.ckpt"),
# RFC 5987 form is preferred and percent-decoded
(
"attachment; filename=\"fallback.bin\"; filename*=UTF-8''my%20model.safetensors",
"my model.safetensors",
),
# directory components in a hostile header are stripped to the basename
('attachment; filename="../../etc/passwd"', "passwd"),
('attachment; filename="a\\\\b\\\\model.pt"', "model.pt"),
("inline", None),
(None, None),
],
)
def test_filename_from_content_disposition(header, expected):
from app.model_downloader.net.http import filename_from_content_disposition
assert filename_from_content_disposition(header) == expected