mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-07-03 13:19:23 +08:00
Add extension check on the final resolved url -> fix downloading from civitAI.
This commit is contained in:
parent
4c82c708a7
commit
c98a212589
@ -14,10 +14,17 @@ from typing import Callable, Optional
|
||||
|
||||
from app.model_downloader.constants import DownloadStatus
|
||||
from app.model_downloader.database import queries
|
||||
from app.model_downloader.net.probe import probe
|
||||
from app.model_downloader.scheduler import SCHEDULER
|
||||
from app.model_downloader.security import paths
|
||||
from app.model_downloader.net.http import redact_url
|
||||
from app.model_downloader.security.allowlist import is_url_allowed
|
||||
from app.model_downloader.security.allowlist import (
|
||||
ALLOWED_MODEL_EXTENSIONS,
|
||||
filename_extension,
|
||||
is_host_allowed_url,
|
||||
is_url_downloadable,
|
||||
url_path_extension,
|
||||
)
|
||||
from app.model_downloader.security.paths import InvalidModelId
|
||||
|
||||
# Non-terminal statuses: an existing row in one of these blocks a re-enqueue.
|
||||
@ -70,11 +77,30 @@ class DownloadManager:
|
||||
allow_any_extension: bool = False,
|
||||
credential_id: Optional[str] = None,
|
||||
) -> str:
|
||||
if not is_url_allowed(url, allow_any_extension):
|
||||
# Coarse gate first: host/scheme must be allowlisted, and any extension
|
||||
# present in the URL path must be a known model type. A URL whose path
|
||||
# carries NO extension (e.g. Civitai's ``/api/download/models/<id>``) is
|
||||
# admitted here and its real extension is resolved from the network
|
||||
# below before the download is finally accepted.
|
||||
if allow_any_extension:
|
||||
if not is_host_allowed_url(url):
|
||||
raise DownloadError(
|
||||
"URL_NOT_ALLOWED",
|
||||
"URL is not on the download allowlist (host/scheme).",
|
||||
)
|
||||
elif not is_url_downloadable(url):
|
||||
raise DownloadError(
|
||||
"URL_NOT_ALLOWED",
|
||||
"URL is not on the download allowlist (host/scheme/extension).",
|
||||
)
|
||||
|
||||
# When the URL path has no extension, follow it to where it resolves and
|
||||
# adopt the real extension from the response, forcing the stored
|
||||
# filename to match. Skipped when the caller opted into any extension.
|
||||
if not allow_any_extension and url_path_extension(url) == "":
|
||||
resolved_ext = await self._resolve_extension(url, credential_id)
|
||||
model_id = paths.apply_extension(model_id, resolved_ext)
|
||||
|
||||
try:
|
||||
paths.parse_model_id(model_id, allow_any_extension)
|
||||
dest_path, temp_path = paths.resolve_destination(model_id, allow_any_extension)
|
||||
@ -119,6 +145,40 @@ class DownloadManager:
|
||||
await self._scheduler.pump()
|
||||
return download_id
|
||||
|
||||
async def _resolve_extension(
    self, url: str, credential_id: Optional[str]
) -> str:
    """Probe ``url`` and return the model extension of the file it serves.

    Intended for allowlisted URLs whose path carries no extension: the
    filename reported by the probe (``pr.filename``) supplies the extension
    instead.

    Args:
        url: The download URL to probe.
        credential_id: Optional credential identifier forwarded to the probe.

    Returns:
        The lowercased extension (with leading dot) of the probed filename.

    Raises:
        DownloadError: ``CREDENTIALS_REQUIRED`` (status 401) when the probe
            failed and reported the URL as gated; ``URL_RESOLVE_FAILED``
            (status 502) for any other probe failure; ``URL_NOT_ALLOWED``
            when the resolved extension is not in
            ``ALLOWED_MODEL_EXTENSIONS`` (including when no filename was
            discovered at all).
    """
    pr = await probe(url, credential_id=credential_id)

    if pr.ok:
        # No filename at all is treated the same as "no extension".
        ext = filename_extension(pr.filename) if pr.filename else ""
        if ext in ALLOWED_MODEL_EXTENSIONS:
            return ext
        raise DownloadError(
            "URL_NOT_ALLOWED",
            f"URL resolves to {pr.filename or '<unknown>'!r}, which is not a "
            f"known model file type {ALLOWED_MODEL_EXTENSIONS}.",
        )

    # Probe failed: distinguish "needs credentials" from every other failure.
    if pr.gated:
        raise DownloadError(
            "CREDENTIALS_REQUIRED",
            f"{redact_url(url)} requires authentication to resolve. Add an "
            f"API key for this host at /api/download/credentials and retry.",
            status=401,
        )
    raise DownloadError(
        "URL_RESOLVE_FAILED",
        f"Could not resolve {redact_url(url)}: {pr.error or 'unknown error'}",
        status=502,
    )
|
||||
|
||||
def _model_lock(self, model_id: str) -> asyncio.Lock:
|
||||
# Lazily create one lock per model_id. There is no ``await`` between the
|
||||
# lookup and the insert, so under the single asyncio thread this is
|
||||
@ -362,22 +422,25 @@ class DownloadManager:
|
||||
if r.status in _LIVE_STATUSES or r.model_id not in by_model:
|
||||
by_model[r.model_id] = r
|
||||
|
||||
# ``url_allowed`` mirrors the coarse enqueue gate (host/scheme + a
|
||||
# non-disallowed extension); URLs whose extension is only known after a
|
||||
# network resolve — e.g. Civitai download endpoints — report allowed.
|
||||
out: dict[str, dict] = {}
|
||||
for model_id, url in models.items():
|
||||
try:
|
||||
exists = await asyncio.to_thread(paths.resolve_existing, model_id)
|
||||
except InvalidModelId:
|
||||
out[model_id] = {"state": "missing", "url_allowed": is_url_allowed(url)}
|
||||
out[model_id] = {"state": "missing", "url_allowed": is_url_downloadable(url)}
|
||||
continue
|
||||
if exists:
|
||||
out[model_id] = {"state": "available", "url_allowed": is_url_allowed(url)}
|
||||
out[model_id] = {"state": "available", "url_allowed": is_url_downloadable(url)}
|
||||
continue
|
||||
row = by_model.get(model_id)
|
||||
if row is not None and row.status in _LIVE_STATUSES:
|
||||
view = self._view(row)
|
||||
out[model_id] = {
|
||||
"state": "downloading",
|
||||
"url_allowed": is_url_allowed(url),
|
||||
"url_allowed": is_url_downloadable(url),
|
||||
"download_id": view["download_id"],
|
||||
"progress": view["progress"],
|
||||
"bytes_done": view["bytes_done"],
|
||||
@ -385,7 +448,7 @@ class DownloadManager:
|
||||
"speed_bps": view["speed_bps"],
|
||||
}
|
||||
else:
|
||||
out[model_id] = {"state": "missing", "url_allowed": is_url_allowed(url)}
|
||||
out[model_id] = {"state": "missing", "url_allowed": is_url_downloadable(url)}
|
||||
return out
|
||||
|
||||
|
||||
|
||||
@ -10,9 +10,10 @@ that attaches credentials, so a token can never ride a redirect to a CDN host.
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import re
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import AsyncIterator, Optional
|
||||
from urllib.parse import urljoin, urlsplit, urlunsplit
|
||||
from urllib.parse import unquote, urljoin, urlsplit, urlunsplit
|
||||
|
||||
import aiohttp
|
||||
|
||||
@ -37,6 +38,43 @@ def redact_url(url: str) -> str:
|
||||
return urlunsplit(parts._replace(query=""))
|
||||
|
||||
|
||||
# ``filename*=charset'lang'value`` (RFC 5987 extended form); group 1 is the
# still percent-encoded value.
_CD_FILENAME_STAR = re.compile(
    r"filename\*\s*=\s*[^']*'[^']*'([^;]+)", re.IGNORECASE
)
# ``filename="value"``; group 1 is the text between the quotes.
_CD_FILENAME_QUOTED = re.compile(r'filename\s*=\s*"([^"]+)"', re.IGNORECASE)
# ``filename=value`` without quotes; group 1 runs up to the next ``;``.
_CD_FILENAME_BARE = re.compile(r"filename\s*=\s*([^;]+)", re.IGNORECASE)


def _decode_ext_value(matched: str, raw: str) -> str:
    """Percent-decode an RFC 5987 value using the charset it declares.

    ``matched`` is the full ``filename*=charset'lang'value`` text that
    ``_CD_FILENAME_STAR`` matched; ``raw`` is the encoded value part. An
    empty or unknown charset label falls back to UTF-8. Undecodable bytes
    are replaced rather than raising.
    """
    charset = matched.split("=", 1)[1].split("'", 1)[0].strip() or "utf-8"
    try:
        return unquote(raw, encoding=charset, errors="replace")
    except (LookupError, ValueError):
        # Unknown / malformed charset label from the server: best effort.
        return unquote(raw)


def filename_from_content_disposition(value: Optional[str]) -> Optional[str]:
    """Extract the download filename from a ``Content-Disposition`` header.

    The RFC 5987 ``filename*=`` form is tried first and percent-decoded with
    the charset named in the header (UTF-8 when absent or unknown); the
    quoted and then the bare ``filename=`` forms are used as fallbacks.

    Directory components (``/`` or ``\\`` separated) are stripped so only the
    final path segment is returned, and the special segments ``.`` and ``..``
    are never returned as a name.

    Args:
        value: The raw header value, or ``None`` / empty when absent.

    Returns:
        The bare filename, or ``None`` when no usable filename is present.
    """
    if not value:
        return None
    for pat, is_ext_value in (
        (_CD_FILENAME_STAR, True),
        (_CD_FILENAME_QUOTED, False),
        (_CD_FILENAME_BARE, False),
    ):
        m = pat.search(value)
        if not m:
            continue
        raw = m.group(1).strip().strip('"')
        if is_ext_value:
            raw = _decode_ext_value(m.group(0), raw)
        # Decode first, then strip directories, so an encoded ``%2F`` cannot
        # smuggle a path separator past the basename step.
        name = raw.replace("\\", "/").rsplit("/", 1)[-1].strip()
        if name and name not in (".", ".."):
            return name
    return None
|
||||
|
||||
|
||||
async def _resolve_final_response(
|
||||
method: str,
|
||||
url: str,
|
||||
|
||||
@ -12,11 +12,14 @@ from __future__ import annotations
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
from urllib.parse import urlparse
|
||||
from urllib.parse import urlparse, urlsplit
|
||||
|
||||
import aiohttp
|
||||
|
||||
from app.model_downloader.net.http import open_validated
|
||||
from app.model_downloader.net.http import (
|
||||
filename_from_content_disposition,
|
||||
open_validated,
|
||||
)
|
||||
from app.model_downloader.net.session import parse_int_header
|
||||
|
||||
_PROBE_TIMEOUT = aiohttp.ClientTimeout(total=60, sock_connect=30, sock_read=30)
|
||||
@ -33,6 +36,11 @@ class ProbeResult:
|
||||
last_modified: Optional[str] = None
|
||||
gated: bool = False # 401/403 — needs (or has wrong) credentials
|
||||
error: Optional[str] = None
|
||||
# Filename the server intends this response to be saved as: the
|
||||
# ``Content-Disposition`` name if present, else the post-redirect URL's
|
||||
# basename. Used to resolve the real extension for URLs (e.g. Civitai's
|
||||
# ``/api/download`` endpoints) that carry no extension in their path.
|
||||
filename: Optional[str] = None
|
||||
|
||||
|
||||
def _total_from_content_range(value: Optional[str]) -> Optional[int]:
|
||||
@ -43,6 +51,19 @@ def _total_from_content_range(value: Optional[str]) -> Optional[int]:
|
||||
return parse_int_header(total)
|
||||
|
||||
|
||||
def _filename_from_response(
    content_disposition: Optional[str], final_url: Optional[str]
) -> Optional[str]:
    """Pick the filename a response should be saved as.

    The ``Content-Disposition`` name wins when one can be parsed; otherwise
    the last path segment of ``final_url`` is used. Returns ``None`` when
    neither source yields a non-empty name.
    """
    header_name = filename_from_content_disposition(content_disposition)
    if header_name:
        return header_name
    if not final_url:
        return None
    # Last path segment of the URL (query and fragment are not part of .path).
    _, _, tail = urlsplit(final_url).path.rpartition("/")
    return tail if tail else None
|
||||
|
||||
|
||||
async def probe(url: str, *, credential_id: Optional[str] = None) -> ProbeResult:
|
||||
"""Probe ``url`` and return discovered metadata, failing soft."""
|
||||
try:
|
||||
@ -85,6 +106,9 @@ async def probe(url: str, *, credential_id: Optional[str] = None) -> ProbeResult
|
||||
accept_ranges=accept_ranges,
|
||||
etag=headers.get("ETag"),
|
||||
last_modified=headers.get("Last-Modified"),
|
||||
filename=_filename_from_response(
|
||||
headers.get("Content-Disposition"), final_url
|
||||
),
|
||||
)
|
||||
except Exception as e: # network / SSRF / timeout
|
||||
host = urlparse(url).netloc or "<unknown>"
|
||||
|
||||
@ -71,6 +71,62 @@ def has_allowed_extension(path: str, allow_any_extension: bool = False) -> bool:
|
||||
return path.lower().endswith(ALLOWED_MODEL_EXTENSIONS)
|
||||
|
||||
|
||||
def filename_extension(name: str) -> str:
    """Return the lowercased extension (leading dot included) of ``name``.

    Only the final path component is examined; both ``/`` and ``\\`` count
    as separators. The result is ``""`` when that component contains no dot
    or when its only dot is the first character (``.safetensors`` is all
    stem, like a dotfile).
    """
    tail = name.replace("\\", "/").split("/")[-1]
    stem, dot, suffix = tail.rpartition(".")
    # No dot at all, or nothing before the last dot -> no extension.
    if not dot or not stem:
        return ""
    return f".{suffix}".lower()
|
||||
|
||||
|
||||
def is_allowed_extension_name(name: str) -> bool:
    """Report whether ``name`` ends, case-insensitively, in a known model extension."""
    lowered = name.lower()
    return lowered.endswith(ALLOWED_MODEL_EXTENSIONS)
|
||||
|
||||
|
||||
def is_host_allowed_url(url: str) -> bool:
    """Report whether ``url`` parses and its host and scheme pass the allowlist.

    Non-string or empty input, and URLs that ``urlparse`` rejects with
    ``ValueError``, are never allowed. Everything else is decided by
    ``is_host_allowed``.
    """
    usable = isinstance(url, str) and bool(url)
    if not usable:
        return False
    try:
        parts = urlparse(url)
    except ValueError:
        return False
    else:
        return is_host_allowed(parts.hostname, parts.scheme)
|
||||
|
||||
|
||||
def url_path_extension(url: str) -> str:
    """Return the extension of the URL's path basename, or ``""``.

    Only the parsed path is inspected, so the query string never
    contributes. A URL that ``urlparse`` rejects with ``ValueError`` yields
    ``""``.
    """
    try:
        path = urlparse(url).path
    except ValueError:
        return ""
    return filename_extension(path)
|
||||
|
||||
|
||||
def is_url_downloadable(url: str) -> bool:
    """Coarse enqueue gate: allowlisted host/scheme and no disallowed extension.

    A URL passes when ``is_host_allowed_url`` accepts it and its path
    extension is either absent or one of ``ALLOWED_MODEL_EXTENSIONS``.
    Extensionless paths (e.g. a Civitai ``/api/download/models/<id>``
    endpoint) are therefore admitted here. A path with an explicit
    non-model extension (``.zip``, ``.html``, ...) is rejected.
    """
    if is_host_allowed_url(url):
        suffix = url_path_extension(url)
        return suffix == "" or suffix in ALLOWED_MODEL_EXTENSIONS
    return False
|
||||
|
||||
|
||||
def is_url_allowed(url: str, allow_any_extension: bool = False) -> bool:
|
||||
"""Check whether ``url`` is permitted as a server-side download source."""
|
||||
if not isinstance(url, str) or not url:
|
||||
|
||||
@ -58,6 +58,28 @@ def parse_model_id(model_id: str, allow_any_extension: bool = False) -> tuple[st
|
||||
return directory, filename
|
||||
|
||||
|
||||
def apply_extension(model_id: str, ext: str) -> str:
    """Force the filename part of ``model_id`` to end in ``ext``.

    ``ext`` carries its leading dot (``".safetensors"``). ``model_id`` is
    split at its first ``/``; everything after it is the filename. If the
    filename ends (case-insensitively) in one of ``ALLOWED_MODEL_EXTENSIONS``
    that suffix is removed before ``ext`` is appended; any other suffix is
    kept as part of the stem:

    * ``loras/mymodel``      -> ``loras/mymodel.safetensors``
    * ``loras/mymodel.ckpt`` -> ``loras/mymodel.safetensors``
    * ``loras/my.model.v2``  -> ``loras/my.model.v2.safetensors``

    A ``model_id`` without any ``/`` is returned unchanged. No validation
    is performed here.
    """
    if "/" not in model_id:
        return model_id  # malformed; left for the caller's validation to reject
    head, stem = model_id.split("/", 1)
    lowered = stem.lower()
    # First known extension the filename ends with, if any.
    matched = next(
        (known for known in ALLOWED_MODEL_EXTENSIONS if lowered.endswith(known)),
        None,
    )
    if matched is not None:
        stem = stem[: -len(matched)]
    return f"{head}/{stem}{ext}"
|
||||
|
||||
|
||||
def resolve_existing(model_id: str, allow_any_extension: bool = False) -> Optional[str]:
|
||||
"""Return the absolute path of an installed model, or None if missing.
|
||||
|
||||
|
||||
@ -83,6 +83,37 @@ def _range_handler(payload: bytes):
|
||||
return handler
|
||||
|
||||
|
||||
def _content_disposition_handler(payload: bytes, filename: str):
    """Build a range-capable handler that names its file only via a header.

    Models a Civitai-style ``/api/download/...`` endpoint: the URL path has
    no extension, and the real filename (hence extension) is exposed solely
    through the response ``Content-Disposition`` header.

    Requests with a ``Range`` header get a 206 with the requested slice and a
    matching ``Content-Range``; all others get a 200 with the full payload.
    """
    total = len(payload)

    async def handler(request: web.Request) -> web.Response:
        common = {
            "Accept-Ranges": "bytes",
            "ETag": PAYLOAD_ETAG,
            "Content-Disposition": f'attachment; filename="{filename}"',
        }
        rng = request.headers.get("Range")
        if not rng:
            return web.Response(status=200, body=payload, headers=common)

        # "bytes=<start>-<end>" with an optional end (open-ended range).
        first, _, last = rng.split("=", 1)[1].partition("-")
        start = int(first)
        end = int(last) if last else total - 1
        return web.Response(
            status=206,
            body=payload[start : end + 1],
            headers={**common, "Content-Range": f"bytes {start}-{end}/{total}"},
        )

    return handler
|
||||
|
||||
|
||||
def _noranges_handler(payload: bytes):
|
||||
async def handler(request: web.Request) -> web.Response:
|
||||
# Always full body, never advertises Accept-Ranges -> single-stream.
|
||||
@ -517,3 +548,90 @@ def test_manager_rejects_disallowed_url(model_root):
|
||||
assert ei.value.code == "URL_NOT_ALLOWED"
|
||||
|
||||
asyncio.run(_run())
|
||||
|
||||
|
||||
def test_manager_resolves_extensionless_url(model_root):
    """An allowlisted URL with no extension in its path is resolved from the
    response, and the stored file adopts the resolved extension."""
    payload = _safetensors_payload(1 * 1024 * 1024)

    async def _run():
        await close_session()
        from app.model_downloader.manager import DOWNLOAD_MANAGER

        runner, port = await _serve(
            _content_disposition_handler(payload, "RealModel.safetensors")
        )
        try:
            # No extension in the path (Civitai-style) and none in the model_id.
            url = f"http://127.0.0.1:{port}/api/download/models/12345"
            did = await DOWNLOAD_MANAGER.enqueue(url, "loras/my_civitai_model")

            row = queries.get_download(did)
            # The resolved extension was appended to the model_id + destination.
            assert row.model_id == "loras/my_civitai_model.safetensors"
            assert row.dest_path.endswith("my_civitai_model.safetensors")

            final_path, _ = paths.resolve_destination(
                "loras/my_civitai_model.safetensors"
            )
            # Poll (up to ~10 s) until the download reaches a terminal status.
            for _ in range(500):
                await asyncio.sleep(0.02)
                row = queries.get_download(did)
                if row.status in DownloadStatus.TERMINAL:
                    break
            row = queries.get_download(did)
            assert row.status == DownloadStatus.COMPLETED, row.error
            # Close the file deterministically instead of leaking the handle.
            with open(final_path, "rb") as fh:
                assert fh.read() == payload
        finally:
            await runner.cleanup()
            await close_session()

    asyncio.run(_run())
|
||||
|
||||
|
||||
def test_manager_overrides_extension_from_resolution(model_root):
    """When the model_id names one known extension but the URL resolves to
    another, the stored model_id takes the resolved extension."""
    payload = _safetensors_payload(256 * 1024)

    async def _run():
        await close_session()
        from app.model_downloader.manager import DOWNLOAD_MANAGER

        handler = _content_disposition_handler(payload, "weights.safetensors")
        runner, port = await _serve(handler)
        try:
            url = f"http://127.0.0.1:{port}/api/download/models/777"
            # The caller guessed .ckpt, but the server reports .safetensors.
            download_id = await DOWNLOAD_MANAGER.enqueue(url, "loras/guessed.ckpt")
            stored = queries.get_download(download_id)
            assert stored.model_id == "loras/guessed.safetensors"
        finally:
            await runner.cleanup()
            await close_session()

    asyncio.run(_run())
|
||||
|
||||
|
||||
def test_manager_rejects_non_model_resolution(model_root):
    """Enqueueing a URL whose response names a non-model file raises
    ``URL_NOT_ALLOWED``."""

    async def _run():
        await close_session()
        from app.model_downloader.manager import DOWNLOAD_MANAGER, DownloadError

        handler = _content_disposition_handler(b"not a model", "installer.zip")
        runner, port = await _serve(handler)
        try:
            url = f"http://127.0.0.1:{port}/api/download/models/999"
            with pytest.raises(DownloadError) as excinfo:
                await DOWNLOAD_MANAGER.enqueue(url, "loras/whatever")
            assert excinfo.value.code == "URL_NOT_ALLOWED"
        finally:
            await runner.cleanup()
            await close_session()

    asyncio.run(_run())
|
||||
|
||||
@ -43,6 +43,42 @@ def test_allow_any_extension_relaxes_extension_only():
|
||||
assert allowlist.is_url_allowed(odd, allow_any_extension=True) is True
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "url,downloadable",
    [
        # A known model extension in the path is accepted.
        ("https://civitai.com/x/model.safetensors", True),
        # Extensionless paths (Civitai download API) pass the coarse gate.
        ("https://civitai.com/api/download/models/3031464?fileId=2910346", True),
        ("https://civitai.com/api/download/models/3031464", True),
        # An explicit non-model extension is refused even on an allowed host.
        ("https://civitai.com/api/download/models/thing.zip", False),
        ("https://huggingface.co/org/repo/resolve/main/config.json", False),
        # A host that is not allowlisted is refused.
        ("https://evil.example.com/api/download/models/1", False),
        # Plain http to a non-loopback allowlisted host is refused.
        ("http://civitai.com/api/download/models/1", False),
    ],
)
def test_is_url_downloadable(url, downloadable):
    """``is_url_downloadable`` returns the expected verdict for each URL."""
    verdict = allowlist.is_url_downloadable(url)
    assert verdict is downloadable
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "name,ext",
    [
        ("model.safetensors", ".safetensors"),
        ("model.SAFETENSORS", ".safetensors"),
        ("archive.tar.gz", ".gz"),
        ("noext", ""),
        # A name that is only a leading dot plus text has no extension.
        (".safetensors", ""),
        ("a/b/c/model.ckpt", ".ckpt"),
    ],
)
def test_filename_extension(name, ext):
    """``filename_extension`` returns the lowercased final suffix of ``name``."""
    actual = allowlist.filename_extension(name)
    assert actual == ext
|
||||
|
||||
|
||||
# ----- SSRF: blocked IPs -----
|
||||
|
||||
|
||||
@ -148,3 +184,48 @@ def test_resolve_destination_stays_in_root(model_root):
|
||||
assert final_path.startswith(model_root)
|
||||
assert temp_path.startswith(model_root)
|
||||
assert temp_path != final_path
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "model_id,ext,expected",
    [
        # Extensionless filename: the resolved extension is appended.
        ("loras/my_civitai_model", ".safetensors", "loras/my_civitai_model.safetensors"),
        # A different known extension is swapped for the resolved one.
        ("loras/mymodel.ckpt", ".safetensors", "loras/mymodel.safetensors"),
        # Already the resolved extension: result is unchanged.
        ("loras/mymodel.safetensors", ".safetensors", "loras/mymodel.safetensors"),
        # A non-model suffix counts as part of the stem; extension is appended.
        ("loras/my.model.v2", ".safetensors", "loras/my.model.v2.safetensors"),
        # No slash at all: returned untouched.
        ("noslash", ".safetensors", "noslash"),
    ],
)
def test_apply_extension(model_id, ext, expected):
    """``apply_extension`` rewrites the filename part of ``model_id`` as expected."""
    rewritten = paths.apply_extension(model_id, ext)
    assert rewritten == expected
|
||||
|
||||
|
||||
# ----- Content-Disposition filename parsing -----
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "header,expected",
    [
        ('attachment; filename="model.safetensors"', "model.safetensors"),
        ("attachment; filename=model.ckpt", "model.ckpt"),
        # The RFC 5987 form wins over the plain one and is percent-decoded.
        (
            "attachment; filename=\"fallback.bin\"; filename*=UTF-8''my%20model.safetensors",
            "my model.safetensors",
        ),
        # Directory components in the header value are reduced to the basename.
        ('attachment; filename="../../etc/passwd"', "passwd"),
        ('attachment; filename="a\\\\b\\\\model.pt"', "model.pt"),
        ("inline", None),
        (None, None),
    ],
)
def test_filename_from_content_disposition(header, expected):
    """``filename_from_content_disposition`` extracts the expected name."""
    from app.model_downloader.net.http import filename_from_content_disposition

    parsed = filename_from_content_disposition(header)
    assert parsed == expected
|
||||
|
||||
Loading…
Reference in New Issue
Block a user