Add distinction in error messaging for gated models.
Some checks are pending
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run

This commit is contained in:
Talmaj Marinc 2026-07-02 12:41:23 +02:00
parent c98a212589
commit d657a40681
3 changed files with 47 additions and 10 deletions

View File

@ -28,7 +28,7 @@ from app.model_downloader.engine.planner import (
) )
from app.model_downloader.engine.writer import FileWriter from app.model_downloader.engine.writer import FileWriter
from app.model_downloader.net.http import open_validated, redact_url from app.model_downloader.net.http import open_validated, redact_url
from app.model_downloader.net.probe import probe from app.model_downloader.net.probe import gated_error_message, probe
from app.model_downloader.verify import checksum, dedup, structural from app.model_downloader.verify import checksum, dedup, structural
_RETRYABLE_STATUSES = {408, 429, 500, 502, 503, 504} _RETRYABLE_STATUSES = {408, 429, 500, 502, 503, 504}
@ -191,10 +191,7 @@ class DownloadJob:
pr = await probe(self.spec.url, credential_id=self.spec.credential_id) pr = await probe(self.spec.url, credential_id=self.spec.credential_id)
if not pr.ok: if not pr.ok:
if pr.gated: if pr.gated:
raise FatalError( raise FatalError(gated_error_message(self.spec.url, pr))
f"{redact_url(self.spec.url)} requires authentication. Add an API key for "
f"this host at /api/download/credentials and retry."
)
if pr.status == 0 or pr.status in _RETRYABLE_STATUSES: if pr.status == 0 or pr.status in _RETRYABLE_STATUSES:
raise RetryableError(pr.error or "probe failed") raise RetryableError(pr.error or "probe failed")
raise FatalError(pr.error or f"probe returned HTTP {pr.status}") raise FatalError(pr.error or f"probe returned HTTP {pr.status}")

View File

@ -14,7 +14,7 @@ from typing import Callable, Optional
from app.model_downloader.constants import DownloadStatus from app.model_downloader.constants import DownloadStatus
from app.model_downloader.database import queries from app.model_downloader.database import queries
from app.model_downloader.net.probe import probe from app.model_downloader.net.probe import gated_error_message, probe
from app.model_downloader.scheduler import SCHEDULER from app.model_downloader.scheduler import SCHEDULER
from app.model_downloader.security import paths from app.model_downloader.security import paths
from app.model_downloader.net.http import redact_url from app.model_downloader.net.http import redact_url
@ -160,9 +160,8 @@ class DownloadManager:
if not pr.ok: if not pr.ok:
if pr.gated: if pr.gated:
raise DownloadError( raise DownloadError(
"CREDENTIALS_REQUIRED", "GATED_REPO" if pr.is_gated_repo else "CREDENTIALS_REQUIRED",
f"{redact_url(url)} requires authentication to resolve. Add an " gated_error_message(url, pr),
f"API key for this host at /api/download/credentials and retry.",
status=401, status=401,
) )
raise DownloadError( raise DownloadError(

View File

@ -19,6 +19,7 @@ import aiohttp
from app.model_downloader.net.http import ( from app.model_downloader.net.http import (
filename_from_content_disposition, filename_from_content_disposition,
open_validated, open_validated,
redact_url,
) )
from app.model_downloader.net.session import parse_int_header from app.model_downloader.net.session import parse_int_header
@ -36,12 +37,25 @@ class ProbeResult:
last_modified: Optional[str] = None last_modified: Optional[str] = None
gated: bool = False # 401/403 — needs (or has wrong) credentials gated: bool = False # 401/403 — needs (or has wrong) credentials
error: Optional[str] = None error: Optional[str] = None
# HuggingFace's ``X-Error-Code`` header (e.g. ``GatedRepo``,
# ``RepoNotFound``) when the host reports one. Lets us tell "this repo is
# gated — request access" apart from "you just need a token".
error_code: Optional[str] = None
# Filename the server intends this response to be saved as: the # Filename the server intends this response to be saved as: the
# ``Content-Disposition`` name if present, else the post-redirect URL's # ``Content-Disposition`` name if present, else the post-redirect URL's
# basename. Used to resolve the real extension for URLs (e.g. Civitai's # basename. Used to resolve the real extension for URLs (e.g. Civitai's
# ``/api/download`` endpoints) that carry no extension in their path. # ``/api/download`` endpoints) that carry no extension in their path.
filename: Optional[str] = None filename: Optional[str] = None
@property
def is_gated_repo(self) -> bool:
"""True when the host says the repo is gated (access must be granted).
Distinct from a plain missing/invalid token: even a valid credential
won't help until the user accepts the model's terms on its page.
"""
return (self.error_code or "").lower() == "gatedrepo"
def _total_from_content_range(value: Optional[str]) -> Optional[int]: def _total_from_content_range(value: Optional[str]) -> Optional[int]:
# "bytes 0-0/12345" -> 12345 ; "bytes 0-0/*" -> None # "bytes 0-0/12345" -> 12345 ; "bytes 0-0/*" -> None
@ -75,9 +89,15 @@ async def probe(url: str, *, credential_id: Optional[str] = None) -> ProbeResult
timeout=_PROBE_TIMEOUT, timeout=_PROBE_TIMEOUT,
) as (resp, final_url): ) as (resp, final_url):
if resp.status in (401, 403): if resp.status in (401, 403):
error_code = resp.headers.get("X-Error-Code")
error_message = resp.headers.get("X-Error-Message")
return ProbeResult( return ProbeResult(
ok=False, status=resp.status, final_url=final_url, gated=True, ok=False, status=resp.status, final_url=final_url, gated=True,
error=f"host returned {resp.status} (authentication required)", error_code=error_code,
error=(
error_message
or f"host returned {resp.status} (authentication required)"
),
) )
if resp.status not in (200, 206): if resp.status not in (200, 206):
return ProbeResult( return ProbeResult(
@ -114,3 +134,24 @@ async def probe(url: str, *, credential_id: Optional[str] = None) -> ProbeResult
host = urlparse(url).netloc or "<unknown>" host = urlparse(url).netloc or "<unknown>"
logging.debug("[model_downloader] probe failed for %s: %s", host, type(e).__name__) logging.debug("[model_downloader] probe failed for %s: %s", host, type(e).__name__)
return ProbeResult(ok=False, status=0, error="probe failed: network error") return ProbeResult(ok=False, status=0, error="probe failed: network error")
def gated_error_message(url: str, pr: ProbeResult) -> str:
"""Build a user-facing message for a gated/auth-required probe result.
Distinguishes a *gated* repo (access must be requested/granted on the model
page a token alone is not enough) from a plain missing/invalid credential.
"""
redacted = redact_url(url)
if pr.is_gated_repo:
detail = (pr.error or "access is restricted").rstrip()
if detail and not detail.endswith((".", "!", "?")):
detail += "."
return (
f"{redacted} is a gated model — {detail} Request access on the model's "
f"page, add an API key for this host at /api/download/credentials, and retry."
)
return (
f"{redacted} requires authentication. Add an API key for this host at "
f"/api/download/credentials and retry."
)