mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-07-03 13:19:23 +08:00
Add distinction in error messaging for gated models.
This commit is contained in:
parent
c98a212589
commit
d657a40681
@ -28,7 +28,7 @@ from app.model_downloader.engine.planner import (
|
|||||||
)
|
)
|
||||||
from app.model_downloader.engine.writer import FileWriter
|
from app.model_downloader.engine.writer import FileWriter
|
||||||
from app.model_downloader.net.http import open_validated, redact_url
|
from app.model_downloader.net.http import open_validated, redact_url
|
||||||
from app.model_downloader.net.probe import probe
|
from app.model_downloader.net.probe import gated_error_message, probe
|
||||||
from app.model_downloader.verify import checksum, dedup, structural
|
from app.model_downloader.verify import checksum, dedup, structural
|
||||||
|
|
||||||
_RETRYABLE_STATUSES = {408, 429, 500, 502, 503, 504}
|
_RETRYABLE_STATUSES = {408, 429, 500, 502, 503, 504}
|
||||||
@ -191,10 +191,7 @@ class DownloadJob:
|
|||||||
pr = await probe(self.spec.url, credential_id=self.spec.credential_id)
|
pr = await probe(self.spec.url, credential_id=self.spec.credential_id)
|
||||||
if not pr.ok:
|
if not pr.ok:
|
||||||
if pr.gated:
|
if pr.gated:
|
||||||
raise FatalError(
|
raise FatalError(gated_error_message(self.spec.url, pr))
|
||||||
f"{redact_url(self.spec.url)} requires authentication. Add an API key for "
|
|
||||||
f"this host at /api/download/credentials and retry."
|
|
||||||
)
|
|
||||||
if pr.status == 0 or pr.status in _RETRYABLE_STATUSES:
|
if pr.status == 0 or pr.status in _RETRYABLE_STATUSES:
|
||||||
raise RetryableError(pr.error or "probe failed")
|
raise RetryableError(pr.error or "probe failed")
|
||||||
raise FatalError(pr.error or f"probe returned HTTP {pr.status}")
|
raise FatalError(pr.error or f"probe returned HTTP {pr.status}")
|
||||||
|
|||||||
@ -14,7 +14,7 @@ from typing import Callable, Optional
|
|||||||
|
|
||||||
from app.model_downloader.constants import DownloadStatus
|
from app.model_downloader.constants import DownloadStatus
|
||||||
from app.model_downloader.database import queries
|
from app.model_downloader.database import queries
|
||||||
from app.model_downloader.net.probe import probe
|
from app.model_downloader.net.probe import gated_error_message, probe
|
||||||
from app.model_downloader.scheduler import SCHEDULER
|
from app.model_downloader.scheduler import SCHEDULER
|
||||||
from app.model_downloader.security import paths
|
from app.model_downloader.security import paths
|
||||||
from app.model_downloader.net.http import redact_url
|
from app.model_downloader.net.http import redact_url
|
||||||
@ -160,9 +160,8 @@ class DownloadManager:
|
|||||||
if not pr.ok:
|
if not pr.ok:
|
||||||
if pr.gated:
|
if pr.gated:
|
||||||
raise DownloadError(
|
raise DownloadError(
|
||||||
"CREDENTIALS_REQUIRED",
|
"GATED_REPO" if pr.is_gated_repo else "CREDENTIALS_REQUIRED",
|
||||||
f"{redact_url(url)} requires authentication to resolve. Add an "
|
gated_error_message(url, pr),
|
||||||
f"API key for this host at /api/download/credentials and retry.",
|
|
||||||
status=401,
|
status=401,
|
||||||
)
|
)
|
||||||
raise DownloadError(
|
raise DownloadError(
|
||||||
|
|||||||
@ -19,6 +19,7 @@ import aiohttp
|
|||||||
from app.model_downloader.net.http import (
|
from app.model_downloader.net.http import (
|
||||||
filename_from_content_disposition,
|
filename_from_content_disposition,
|
||||||
open_validated,
|
open_validated,
|
||||||
|
redact_url,
|
||||||
)
|
)
|
||||||
from app.model_downloader.net.session import parse_int_header
|
from app.model_downloader.net.session import parse_int_header
|
||||||
|
|
||||||
@ -36,12 +37,25 @@ class ProbeResult:
|
|||||||
last_modified: Optional[str] = None
|
last_modified: Optional[str] = None
|
||||||
gated: bool = False # 401/403 — needs (or has wrong) credentials
|
gated: bool = False # 401/403 — needs (or has wrong) credentials
|
||||||
error: Optional[str] = None
|
error: Optional[str] = None
|
||||||
|
# HuggingFace's ``X-Error-Code`` header (e.g. ``GatedRepo``,
|
||||||
|
# ``RepoNotFound``) when the host reports one. Lets us tell "this repo is
|
||||||
|
# gated — request access" apart from "you just need a token".
|
||||||
|
error_code: Optional[str] = None
|
||||||
# Filename the server intends this response to be saved as: the
|
# Filename the server intends this response to be saved as: the
|
||||||
# ``Content-Disposition`` name if present, else the post-redirect URL's
|
# ``Content-Disposition`` name if present, else the post-redirect URL's
|
||||||
# basename. Used to resolve the real extension for URLs (e.g. Civitai's
|
# basename. Used to resolve the real extension for URLs (e.g. Civitai's
|
||||||
# ``/api/download`` endpoints) that carry no extension in their path.
|
# ``/api/download`` endpoints) that carry no extension in their path.
|
||||||
filename: Optional[str] = None
|
filename: Optional[str] = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_gated_repo(self) -> bool:
|
||||||
|
"""True when the host says the repo is gated (access must be granted).
|
||||||
|
|
||||||
|
Distinct from a plain missing/invalid token: even a valid credential
|
||||||
|
won't help until the user accepts the model's terms on its page.
|
||||||
|
"""
|
||||||
|
return (self.error_code or "").lower() == "gatedrepo"
|
||||||
|
|
||||||
|
|
||||||
def _total_from_content_range(value: Optional[str]) -> Optional[int]:
|
def _total_from_content_range(value: Optional[str]) -> Optional[int]:
|
||||||
# "bytes 0-0/12345" -> 12345 ; "bytes 0-0/*" -> None
|
# "bytes 0-0/12345" -> 12345 ; "bytes 0-0/*" -> None
|
||||||
@ -75,9 +89,15 @@ async def probe(url: str, *, credential_id: Optional[str] = None) -> ProbeResult
|
|||||||
timeout=_PROBE_TIMEOUT,
|
timeout=_PROBE_TIMEOUT,
|
||||||
) as (resp, final_url):
|
) as (resp, final_url):
|
||||||
if resp.status in (401, 403):
|
if resp.status in (401, 403):
|
||||||
|
error_code = resp.headers.get("X-Error-Code")
|
||||||
|
error_message = resp.headers.get("X-Error-Message")
|
||||||
return ProbeResult(
|
return ProbeResult(
|
||||||
ok=False, status=resp.status, final_url=final_url, gated=True,
|
ok=False, status=resp.status, final_url=final_url, gated=True,
|
||||||
error=f"host returned {resp.status} (authentication required)",
|
error_code=error_code,
|
||||||
|
error=(
|
||||||
|
error_message
|
||||||
|
or f"host returned {resp.status} (authentication required)"
|
||||||
|
),
|
||||||
)
|
)
|
||||||
if resp.status not in (200, 206):
|
if resp.status not in (200, 206):
|
||||||
return ProbeResult(
|
return ProbeResult(
|
||||||
@ -114,3 +134,24 @@ async def probe(url: str, *, credential_id: Optional[str] = None) -> ProbeResult
|
|||||||
host = urlparse(url).netloc or "<unknown>"
|
host = urlparse(url).netloc or "<unknown>"
|
||||||
logging.debug("[model_downloader] probe failed for %s: %s", host, type(e).__name__)
|
logging.debug("[model_downloader] probe failed for %s: %s", host, type(e).__name__)
|
||||||
return ProbeResult(ok=False, status=0, error="probe failed: network error")
|
return ProbeResult(ok=False, status=0, error="probe failed: network error")
|
||||||
|
|
||||||
|
|
||||||
|
def gated_error_message(url: str, pr: ProbeResult) -> str:
|
||||||
|
"""Build a user-facing message for a gated/auth-required probe result.
|
||||||
|
|
||||||
|
Distinguishes a *gated* repo (access must be requested/granted on the model
|
||||||
|
page — a token alone is not enough) from a plain missing/invalid credential.
|
||||||
|
"""
|
||||||
|
redacted = redact_url(url)
|
||||||
|
if pr.is_gated_repo:
|
||||||
|
detail = (pr.error or "access is restricted").rstrip()
|
||||||
|
if detail and not detail.endswith((".", "!", "?")):
|
||||||
|
detail += "."
|
||||||
|
return (
|
||||||
|
f"{redacted} is a gated model — {detail} Request access on the model's "
|
||||||
|
f"page, add an API key for this host at /api/download/credentials, and retry."
|
||||||
|
)
|
||||||
|
return (
|
||||||
|
f"{redacted} requires authentication. Add an API key for this host at "
|
||||||
|
f"/api/download/credentials and retry."
|
||||||
|
)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user