diff --git a/app/model_downloader/engine/job.py b/app/model_downloader/engine/job.py index 93dcf269c..c0f5e9c2d 100644 --- a/app/model_downloader/engine/job.py +++ b/app/model_downloader/engine/job.py @@ -28,7 +28,7 @@ from app.model_downloader.engine.planner import ( ) from app.model_downloader.engine.writer import FileWriter from app.model_downloader.net.http import open_validated, redact_url -from app.model_downloader.net.probe import probe +from app.model_downloader.net.probe import gated_error_message, probe from app.model_downloader.verify import checksum, dedup, structural _RETRYABLE_STATUSES = {408, 429, 500, 502, 503, 504} @@ -191,10 +191,7 @@ class DownloadJob: pr = await probe(self.spec.url, credential_id=self.spec.credential_id) if not pr.ok: if pr.gated: - raise FatalError( - f"{redact_url(self.spec.url)} requires authentication. Add an API key for " - f"this host at /api/download/credentials and retry." - ) + raise FatalError(gated_error_message(self.spec.url, pr)) if pr.status == 0 or pr.status in _RETRYABLE_STATUSES: raise RetryableError(pr.error or "probe failed") raise FatalError(pr.error or f"probe returned HTTP {pr.status}") diff --git a/app/model_downloader/manager.py b/app/model_downloader/manager.py index 0f358c9d5..60f597237 100644 --- a/app/model_downloader/manager.py +++ b/app/model_downloader/manager.py @@ -14,7 +14,7 @@ from typing import Callable, Optional from app.model_downloader.constants import DownloadStatus from app.model_downloader.database import queries -from app.model_downloader.net.probe import probe +from app.model_downloader.net.probe import gated_error_message, probe from app.model_downloader.scheduler import SCHEDULER from app.model_downloader.security import paths from app.model_downloader.net.http import redact_url @@ -160,9 +160,8 @@ class DownloadManager: if not pr.ok: if pr.gated: raise DownloadError( - "CREDENTIALS_REQUIRED", - f"{redact_url(url)} requires authentication to resolve. Add an " - f"API key for this host at /api/download/credentials and retry.", + "GATED_REPO" if pr.is_gated_repo else "CREDENTIALS_REQUIRED", + gated_error_message(url, pr), status=401, ) raise DownloadError( diff --git a/app/model_downloader/net/probe.py b/app/model_downloader/net/probe.py index 7ed65b855..eca0c7fbb 100644 --- a/app/model_downloader/net/probe.py +++ b/app/model_downloader/net/probe.py @@ -19,6 +19,7 @@ import aiohttp from app.model_downloader.net.http import ( filename_from_content_disposition, open_validated, + redact_url, ) from app.model_downloader.net.session import parse_int_header @@ -36,12 +37,25 @@ class ProbeResult: last_modified: Optional[str] = None gated: bool = False # 401/403 — needs (or has wrong) credentials error: Optional[str] = None + # HuggingFace's ``X-Error-Code`` header (e.g. ``GatedRepo``, + # ``RepoNotFound``) when the host reports one. Lets us tell "this repo is + # gated — request access" apart from "you just need a token". + error_code: Optional[str] = None # Filename the server intends this response to be saved as: the # ``Content-Disposition`` name if present, else the post-redirect URL's # basename. Used to resolve the real extension for URLs (e.g. Civitai's # ``/api/download`` endpoints) that carry no extension in their path. filename: Optional[str] = None + @property + def is_gated_repo(self) -> bool: + """True when the host says the repo is gated (access must be granted). + + Distinct from a plain missing/invalid token: even a valid credential + won't help until the user accepts the model's terms on its page. + """ + return (self.error_code or "").lower() == "gatedrepo" + def _total_from_content_range(value: Optional[str]) -> Optional[int]: # "bytes 0-0/12345" -> 12345 ; "bytes 0-0/*" -> None @@ -75,9 +89,15 @@ async def probe(url: str, *, credential_id: Optional[str] = None) -> ProbeResult timeout=_PROBE_TIMEOUT, ) as (resp, final_url): if resp.status in (401, 403): + error_code = resp.headers.get("X-Error-Code") + error_message = resp.headers.get("X-Error-Message") return ProbeResult( ok=False, status=resp.status, final_url=final_url, gated=True, - error=f"host returned {resp.status} (authentication required)", + error_code=error_code, + error=( + error_message + or f"host returned {resp.status} (authentication required)" + ), ) if resp.status not in (200, 206): return ProbeResult( @@ -114,3 +134,24 @@ async def probe(url: str, *, credential_id: Optional[str] = None) -> ProbeResult host = urlparse(url).netloc or "" logging.debug("[model_downloader] probe failed for %s: %s", host, type(e).__name__) return ProbeResult(ok=False, status=0, error="probe failed: network error") + + +def gated_error_message(url: str, pr: ProbeResult) -> str: + """Build a user-facing message for a gated/auth-required probe result. + + Distinguishes a *gated* repo (access must be requested/granted on the model + page — a token alone is not enough) from a plain missing/invalid credential. + """ + redacted = redact_url(url) + if pr.is_gated_repo: + detail = (pr.error or "access is restricted").rstrip() + if detail and not detail.endswith((".", "!", "?")): + detail += "." + return ( + f"{redacted} is a gated model — {detail} Request access on the model's " + f"page, add an API key for this host at /api/download/credentials, and retry." + ) + return ( + f"{redacted} requires authentication. Add an API key for this host at " + f"/api/download/credentials and retry." + )