mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-06-23 16:29:25 +08:00
Self-contained package under app/model_downloader/: - Allowlist + path-validated downloads (SSRF guard: HF/Civitai/localhost + model extension). - Streaming worker: writes to <final>.tmp, atomic rename on success, cooperative cancellation with epoch-based session identity, orphan .tmp sweep. - Unified availability probe with per-URL gated/size caching; is_hf_downloadable recomputed per call so login/license changes surface within one poll. - HuggingFace OAuth 2.0 PKCE flow with loopback callback server and on-disk (0600) token storage + transparent refresh. - Pydantic request/response schemas and aiohttp routes under api/. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
239 lines
8.5 KiB
Python
239 lines
8.5 KiB
Python
"""Per-URL probes for the unified availability endpoint.
|
|
|
|
Three cached/derived facts per URL:
|
|
|
|
- ``is_gated`` intrinsic to the model; cached forever once known.
|
|
Determined by ``auth_check(repo_id, token=None)``:
|
|
``GatedRepoError`` → True, success → False.
|
|
|
|
- ``is_hf_downloadable`` depends on the *current* token; recomputed every
|
|
call. For non-gated URLs this is trivially True
|
|
(no HF call needed). For gated URLs we run
|
|
``auth_check`` with the stored token each call.
|
|
|
|
- ``file_size`` intrinsic to the file. Cached forever once
|
|
determined (including ``None`` on transient
|
|
failure — we don't retry). We only attempt the
|
|
HEAD when we already know the URL is downloadable
|
|
to us; that way a failed-because-gated probe
|
|
never lands as a cached ``None``.
|
|
|
|
Caches are per-process, in-memory; small, no eviction needed for the
|
|
workflow-scale (~tens of URLs). Concurrent calls for the same URL
|
|
deduplicate via per-URL ``asyncio.Lock``.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import logging
|
|
from dataclasses import dataclass
|
|
from typing import Optional
|
|
|
|
import aiohttp
|
|
from huggingface_hub import HfApi
|
|
from huggingface_hub.errors import (
|
|
GatedRepoError,
|
|
HfHubHTTPError,
|
|
RepositoryNotFoundError,
|
|
)
|
|
|
|
from app.model_downloader.hf_auth.auth_store import HF_AUTH_STORE
|
|
from app.model_downloader.hf_url import is_hf_url, repo_id_from_url
|
|
from app.model_downloader.http_client import get_session, parse_content_length
|
|
|
|
|
|
# Overall time budget (``total``, in seconds) for one size-probing HEAD
# request issued by ``_probe_size_once``.
_HEAD_TIMEOUT: aiohttp.ClientTimeout = aiohttp.ClientTimeout(total=15)
|
|
|
|
|
|
@dataclass()
class ProbeResult:
    """Outcome of probing one model URL for the availability endpoint.

    ``file_size`` is the file's size in bytes, or ``None`` when it is unknown.
    ``is_hf_downloadable`` is ``True``/``False`` for HuggingFace URLs (whether
    the currently stored token can fetch the file) and ``None`` when the
    question does not apply (non-HF URL) or could not be answered.
    """

    # Size in bytes; None when no size could be determined.
    file_size: Optional[int]
    # None for non-HF URLs or when the HF access check was inconclusive.
    is_hf_downloadable: Optional[bool]
|
|
|
|
|
|
# --- caches -------------------------------------------------------------- #
|
|
|
|
|
|
# url → bool. Whether this URL's HF repo gates access. Intrinsic to the
|
|
# model — never changes for a given URL.
|
|
_is_gated_cache: dict[str, bool] = {}
|
|
|
|
# url → Optional[int]. The file's size in bytes, ``None`` if a probe
|
|
# was attempted and produced no answer. **Only populated when we knew
|
|
# the URL was downloadable to us at probe time** — so gated-without-
|
|
# access never lands a ``None`` here that we'd be stuck with after login.
|
|
_file_size_cache: dict[str, Optional[int]] = {}
|
|
|
|
# Per-URL locks for single-flight probes — when multiple polls arrive
|
|
# in the same tick for the same URL, exactly one of them runs the HF
|
|
# call and the others wait on the result.
|
|
_locks: dict[str, asyncio.Lock] = {}
|
|
|
|
|
|
def _lock_for(url: str) -> asyncio.Lock:
    """Return the lock dedicated to ``url``, creating it on first request.

    Every call with the same ``url`` returns the same lock object until the
    lock table is cleared.
    """
    try:
        return _locks[url]
    except KeyError:
        fresh = _locks[url] = asyncio.Lock()
        return fresh
|
|
|
|
|
|
def clear_caches_for_tests() -> None:
    """Reset every module-level probe table; for test isolation only."""
    for table in (_is_gated_cache, _file_size_cache, _locks):
        table.clear()
|
|
|
|
|
|
# --- public entrypoint --------------------------------------------------- #
|
|
|
|
|
|
async def probe_url(url: str) -> ProbeResult:
    """Report size and HF downloadability for ``url``, reusing cached facts.

    - Non-HF URL: the size is probed once and cached;
      ``is_hf_downloadable`` is ``None`` ("not applicable").
    - HF URL whose repo id cannot be extracted, or whose gating status could
      not be determined: both fields are ``None``.
    - Otherwise downloadability is re-evaluated against the currently stored
      token on every call, and the size is probed only when the URL is known
      to be downloadable.
    """
    if not is_hf_url(url):
        # Downloadability is "not applicable"; the size is still cached so
        # repeated polls don't issue a HEAD every time.
        return ProbeResult(
            file_size=await _get_or_probe_size(url, token=None),
            is_hf_downloadable=None,
        )

    repo_id = repo_id_from_url(url)
    if repo_id is None:
        return ProbeResult(file_size=None, is_hf_downloadable=None)

    # Gating is intrinsic to the repo: resolved once, then served from cache.
    gated = await _resolve_is_gated(url, repo_id)
    if gated is None:
        return ProbeResult(file_size=None, is_hf_downloadable=None)

    # Downloadability depends on whichever token is stored right now.
    stored = HF_AUTH_STORE.get_token_sync()
    token_str: Optional[str] = None if not stored else stored.access_token

    downloadable: Optional[bool] = (
        await _auth_check_with_token(repo_id, token_str) if gated else True
    )

    if downloadable is not True:
        # Skip the HEAD: without access it could fail, and the resulting
        # cached None would survive a later login.
        return ProbeResult(file_size=None, is_hf_downloadable=downloadable)

    size = await _get_or_probe_size(url, token=token_str)
    return ProbeResult(file_size=size, is_hf_downloadable=downloadable)
|
|
|
|
|
|
# --- gated/auth probes --------------------------------------------------- #
|
|
|
|
|
|
async def _resolve_is_gated(url: str, repo_id: str) -> Optional[bool]:
    """Determine (once per URL) whether the HF repo behind ``url`` is gated.

    The check calls ``auth_check`` without the app's stored token
    (``token=None`` is passed to ``_auth_check_sync``).

    Args:
        url: The model URL; used as the cache key.
        repo_id: The HF repo id extracted from ``url``.

    Returns:
        ``True`` if the check raises ``GatedRepoError`` or
        ``RepositoryNotFoundError``, ``False`` if it succeeds, and ``None``
        if it failed for any other reason. ``True``/``False`` are cached per
        URL; ``None`` is not cached, so the next call retries.
    """
    cached = _is_gated_cache.get(url)
    if cached is not None:
        return cached

    async with _lock_for(url):
        # Another coroutine may have resolved this URL while we waited.
        cached = _is_gated_cache.get(url)
        if cached is not None:
            return cached
        try:
            await asyncio.to_thread(_auth_check_sync, repo_id, None)
        except GatedRepoError:
            _is_gated_cache[url] = True
            return True
        except RepositoryNotFoundError:
            # Repo isn't visible without auth. Treat as gated: we can't serve
            # it anonymously, and an authenticated check might still succeed
            # if it's a private repo the user can see.
            _is_gated_cache[url] = True
            return True
        except Exception as e:
            # Includes HfHubHTTPError (a subclass of Exception) and network
            # errors; treated as transient.
            logging.debug(
                "[hf_auth] is_gated probe failed for %s (will retry): %s",
                repo_id, e,
            )
            return None  # don't cache; retry next call
        _is_gated_cache[url] = False
        return False
|
|
|
|
|
|
async def _auth_check_with_token(
    repo_id: str, token: Optional[str]
) -> Optional[bool]:
    """Check whether ``token`` currently grants access to ``repo_id``.

    Returns:
        ``True`` when the check succeeds; ``False`` for gated-without-access,
        not-found, and HTTP 401/403 responses; ``None`` for any other failure
        (inconclusive; logged).
    """
    try:
        await asyncio.to_thread(_auth_check_sync, repo_id, token)
    except (GatedRepoError, RepositoryNotFoundError):
        return False
    except HfHubHTTPError as err:
        response = getattr(err, "response", None)
        # 401/403 covers org-SSO-required, revoked tokens, and similar: from
        # the user's point of view the file can't be fetched right now.
        if getattr(response, "status_code", None) in (401, 403):
            return False
        logging.debug(
            "[hf_auth] auth_check transient failure for %s: %s", repo_id, err,
        )
        return None
    except Exception as err:
        logging.warning("[hf_auth] unexpected auth_check error for %s: %s", repo_id, err)
        return None
    else:
        return True
|
|
|
|
|
|
def _auth_check_sync(repo_id: str, token: Optional[str]) -> None:
    """Blocking ``HfApi.auth_check`` call, meant to run via ``asyncio.to_thread``.

    A falsy ``token`` (``None`` or ``""``) means "check anonymously". It is
    passed as ``token=False`` because ``huggingface_hub`` treats
    ``token=None`` as "use the locally saved token" (``HF_TOKEN`` or a
    ``huggingface-cli login``). Without this mapping, a CLI login would leak
    into the anonymous gating probe and into the no-stored-token
    downloadability check. That would mark gated repos as public, or as
    downloadable, even though the size HEAD is sent without credentials.

    Raises:
        Whatever ``HfApi.auth_check`` raises (e.g. ``GatedRepoError``,
        ``RepositoryNotFoundError``, ``HfHubHTTPError``); callers classify
        these.
    """
    HfApi().auth_check(repo_id, token=token or False)
|
|
|
|
|
|
# --- size probe ---------------------------------------------------------- #
|
|
|
|
|
|
async def _get_or_probe_size(url: str, token: Optional[str]) -> Optional[int]:
    """Return ``url``'s size, issuing at most one HEAD per URL per process.

    The first probe result, including ``None`` for a failed probe, is cached
    and returned on every later call. The per-URL lock keeps concurrent
    callers from issuing duplicate requests.
    """
    try:
        return _file_size_cache[url]
    except KeyError:
        pass

    async with _lock_for(url):
        # Re-check: a concurrent caller may have filled the entry while we waited.
        if url not in _file_size_cache:
            _file_size_cache[url] = await _probe_size_once(url, token=token)
        return _file_size_cache[url]
|
|
|
|
|
|
async def _probe_size_once(url: str, token: Optional[str]) -> Optional[int]:
    """HEAD ``url`` and return the file size in bytes, or ``None`` on failure.

    HuggingFace serves LFS-tracked files via a 302 to a signed CDN URL. The
    real file size lives in the non-standard ``X-Linked-Size`` header on that
    302 response (its ``Content-Length`` is the redirect body's length).
    Redirects are not followed, so either header can be read from the same
    response:

    - LFS files: 302 + ``X-Linked-Size``
    - Small/non-LFS files: 200 + ``Content-Length``

    Any other response (including a redirect without ``X-Linked-Size``)
    yields ``None``. So do network errors and timeouts; this function does
    not raise for them.

    Args:
        url: The URL to probe.
        token: Optional bearer token, sent as ``Authorization`` when truthy.
    """
    headers = {
        # Request an uncompressed representation so Content-Length is the
        # real file size, not the size of a gzip/deflate-encoded body.
        "Accept-Encoding": "identity",
    }
    if token:
        headers["Authorization"] = f"Bearer {token}"
    try:
        session = await get_session()
        async with session.head(
            url, allow_redirects=False, timeout=_HEAD_TIMEOUT, headers=headers,
        ) as resp:
            linked = parse_content_length(resp.headers.get("X-Linked-Size"))
            if linked is not None:
                return linked
            if resp.status == 200:
                return parse_content_length(resp.headers.get("Content-Length"))
            return None
    except (aiohttp.ClientError, asyncio.TimeoutError, TimeoutError, OSError):
        # Before Python 3.11, asyncio.TimeoutError (which aiohttp raises when
        # a ClientTimeout expires) is not the builtin TimeoutError. Catch both
        # so a slow host yields None instead of escaping the probe.
        return None
|
|
|
|
|
|
# Backward-compatible alias for the former result-type name. Modules that
# still import ``MetadataProbeResult`` keep working while they migrate to
# ``ProbeResult``. Remove it once no importer references the old name.
MetadataProbeResult = ProbeResult
|