ComfyUI/app/model_downloader/gated_detection.py
DoronGenzelHass fdd84d04a0 feat(model_downloader): server-side model download + HuggingFace OAuth subsystem
Self-contained package under app/model_downloader/:
- Allowlist + path-validated downloads (SSRF guard: HF/Civitai/localhost + model extension).
- Streaming worker: writes to <final>.tmp, atomic rename on success, cooperative cancellation with epoch-based session identity, orphan .tmp sweep.
- Unified availability probe with per-URL gated/size caching; is_hf_downloadable recomputed per call so login/license changes surface within one poll.
- HuggingFace OAuth 2.0 PKCE flow with loopback callback server and on-disk (0600) token storage + transparent refresh.
- Pydantic request/response schemas and aiohttp routes under api/.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
2026-06-22 15:16:59 +03:00

239 lines
8.5 KiB
Python

"""Per-URL probes for the unified availability endpoint.
Three cached/derived facts per URL:

- ``is_gated``: intrinsic to the model; cached forever once known.
  Determined by ``auth_check(repo_id, token=None)``:
  ``GatedRepoError`` → True, success → False
  (``RepositoryNotFoundError`` is also treated as gated).
- ``is_hf_downloadable``: depends on the *current* token; recomputed every
  call. For non-gated URLs this is trivially True (no HF call needed).
  For gated URLs we run ``auth_check`` with the stored token each call.
- ``file_size``: intrinsic to the file. Cached forever once determined
  (including ``None`` on transient failure — we don't retry). We only
  attempt the HEAD when we already know the URL is downloadable to us;
  that way a failed-because-gated probe never lands as a cached ``None``.

Caches are per-process, in-memory; small, no eviction needed for the
workflow-scale (~tens of URLs). Concurrent calls for the same URL
deduplicate via per-URL ``asyncio.Lock``.
"""
from __future__ import annotations
import asyncio
import logging
from dataclasses import dataclass
from typing import Optional
import aiohttp
from huggingface_hub import HfApi
from huggingface_hub.errors import (
GatedRepoError,
HfHubHTTPError,
RepositoryNotFoundError,
)
from app.model_downloader.hf_auth.auth_store import HF_AUTH_STORE
from app.model_downloader.hf_url import is_hf_url, repo_id_from_url
from app.model_downloader.http_client import get_session, parse_content_length
# Timeout applied to the size-probe HEAD request issued by
# ``_probe_size_once`` (aiohttp ``total`` budget for the whole request).
_HEAD_TIMEOUT = aiohttp.ClientTimeout(total=15)
@dataclass
class ProbeResult:
    """Outcome of probing one URL, as returned by ``probe_url``."""

    # Size of the file in bytes; ``None`` when no size was obtained
    # (probe skipped, probe failed, or the URL could not be resolved).
    file_size: Optional[int]

    # ``True``/``False`` when downloadability could be decided for an HF URL;
    # ``None`` when it is not applicable (non-HF URL) or could not be decided.
    is_hf_downloadable: Optional[bool]
# --- caches -------------------------------------------------------------- #
# url → bool. Whether this URL's HF repo gates access. Intrinsic to the
# model — never changes for a given URL. A repo that the unauthenticated
# check reports as not found is stored as ``True`` as well (see
# ``_resolve_is_gated``). Failed probes are never stored, so a missing key
# means "not determined yet".
_is_gated_cache: dict[str, bool] = {}
# url → Optional[int]. The file's size in bytes, ``None`` if a probe
# was attempted and produced no answer. **Only populated when we knew
# the URL was downloadable to us at probe time** — so gated-without-
# access never lands a ``None`` here that we'd be stuck with after login.
# Because ``None`` is a legitimate stored value, presence must be tested
# with ``in`` rather than ``.get()``.
_file_size_cache: dict[str, Optional[int]] = {}
# Per-URL locks for single-flight probes — when multiple polls arrive
# in the same tick for the same URL, exactly one of them runs the HF
# call and the others wait on the result. The same lock object serves both
# the gated probe and the size probe for a URL; entries are only removed
# by ``clear_caches_for_tests``.
_locks: dict[str, asyncio.Lock] = {}
def _lock_for(url: str) -> asyncio.Lock:
    """Return the ``asyncio.Lock`` registered for ``url``, creating it on first use.

    The lock is stored in the module-level ``_locks`` table so that every
    caller asking for the same URL receives the same lock object.
    """
    try:
        return _locks[url]
    except KeyError:
        fresh = asyncio.Lock()
        _locks[url] = fresh
        return fresh
def clear_caches_for_tests() -> None:
    """Test-only: empty the gated cache, the size cache and the lock table."""
    for table in (_is_gated_cache, _file_size_cache, _locks):
        table.clear()
# --- public entrypoint --------------------------------------------------- #
async def probe_url(url: str) -> ProbeResult:
    """Probe one URL and report its size and HF downloadability.

    Outcomes:
    - Non-HF URL: size is looked up (cached) without a token and
      ``is_hf_downloadable`` is ``None`` ("not applicable").
    - HF URL whose repo id cannot be extracted, or whose gated status could
      not be determined: both fields are ``None``.
    - Otherwise ``is_hf_downloadable`` is ``True`` for non-gated repos, or the
      result of an auth check with the currently stored token for gated ones;
      the size is only looked up when that value is exactly ``True``.
    """
    if not is_hf_url(url):
        # Size is still cached for non-HF URLs so we don't HEAD on every poll.
        return ProbeResult(
            file_size=await _get_or_probe_size(url, token=None),
            is_hf_downloadable=None,
        )

    repo = repo_id_from_url(url)
    if repo is None:
        return ProbeResult(file_size=None, is_hf_downloadable=None)

    # Intrinsic gating is resolved once and cached; ``None`` means the probe
    # failed and nothing can be said about this URL right now.
    repo_is_gated = await _resolve_is_gated(url, repo)
    if repo_is_gated is None:
        return ProbeResult(file_size=None, is_hf_downloadable=None)

    # Downloadability depends on the token held *now*, so it is recomputed
    # on every call rather than cached.
    stored = HF_AUTH_STORE.get_token_sync()
    bearer: Optional[str] = stored.access_token if stored else None
    downloadable: Optional[bool] = (
        await _auth_check_with_token(repo, bearer) if repo_is_gated else True
    )

    if downloadable is not True:
        # Skip the HEAD entirely — it would be rejected and we'd be stuck
        # with a cached ``None`` that survives a later login.
        return ProbeResult(file_size=None, is_hf_downloadable=downloadable)

    return ProbeResult(
        file_size=await _get_or_probe_size(url, token=bearer),
        is_hf_downloadable=True,
    )
# --- gated/auth probes --------------------------------------------------- #
async def _resolve_is_gated(url: str, repo_id: str) -> Optional[bool]:
    """Decide once whether ``repo_id`` is a gated repo.

    The answer is cached per ``url`` in ``_is_gated_cache``. The cache is
    checked before taking the per-URL lock and again after acquiring it, so
    concurrent callers for the same URL trigger at most one auth check.

    Returns:
        ``True`` if the token-less auth check raised ``GatedRepoError`` or
        ``RepositoryNotFoundError``, ``False`` if it succeeded, and ``None``
        if it failed for any other reason (nothing is cached in that case,
        so the next call retries).
    """
    cached = _is_gated_cache.get(url)
    if cached is not None:
        return cached
    async with _lock_for(url):
        # Re-check: another waiter may have filled the cache while we were
        # queued on the lock.
        cached = _is_gated_cache.get(url)
        if cached is not None:
            return cached
        try:
            await asyncio.to_thread(_auth_check_sync, repo_id, None)
        except GatedRepoError:
            gated = True
        except RepositoryNotFoundError:
            # Repo doesn't exist publicly. Treat as gated — we can't
            # serve it without auth, and an authenticated check might
            # still succeed if it's a private repo the user can see.
            gated = True
        except Exception as e:
            # ``Exception`` already covers ``HfHubHTTPError``; listing both
            # was redundant.
            logging.debug(
                "[hf_auth] is_gated probe failed for %s (will retry): %s",
                repo_id, e,
            )
            return None  # don't cache; retry next call
        else:
            gated = False
        _is_gated_cache[url] = gated
        return gated
async def _auth_check_with_token(
    repo_id: str, token: Optional[str]
) -> Optional[bool]:
    """Run the auth check for ``repo_id`` with ``token`` and classify the outcome.

    Returns:
        ``True`` when the check succeeds; ``False`` when it raises
        ``GatedRepoError`` / ``RepositoryNotFoundError`` or an
        ``HfHubHTTPError`` whose response status is 401 or 403; ``None`` for
        any other failure (logged, treated as undecided).
    """
    try:
        await asyncio.to_thread(_auth_check_sync, repo_id, token)
    except (GatedRepoError, RepositoryNotFoundError):
        return False
    except HfHubHTTPError as err:
        # 401/403 covers org-SSO-required, revoked tokens, and similar —
        # all of which mean "can't fetch right now" from the user's POV.
        response = getattr(err, "response", None)
        if getattr(response, "status_code", None) in (401, 403):
            return False
        logging.debug(
            "[hf_auth] auth_check transient failure for %s: %s", repo_id, err,
        )
        return None
    except Exception as err:
        logging.warning("[hf_auth] unexpected auth_check error for %s: %s", repo_id, err)
        return None
    return True
def _auth_check_sync(repo_id: str, token: Optional[str]) -> None:
    """Thin sync wrapper around ``HfApi.auth_check`` for ``asyncio.to_thread``.

    Args:
        repo_id: Repository to check.
        token: Access token to authenticate with, or ``None`` (or empty) for
            an anonymous check.

    Raises:
        Whatever ``HfApi.auth_check`` raises; callers classify the exceptions.
    """
    # Per the huggingface_hub docs, ``token=None`` does NOT mean "anonymous":
    # the library falls back to the locally saved / environment token. Passing
    # ``False`` disables authentication, which is what a token-less probe
    # needs so an ambient login cannot make a gated repo look public.
    HfApi().auth_check(repo_id, token=token or False)
# --- size probe ---------------------------------------------------------- #
async def _get_or_probe_size(url: str, token: Optional[str]) -> Optional[int]:
    """Return the size cached for ``url``, probing and caching it on first use.

    Whatever the single probe returns — including ``None`` — is stored, so
    the HEAD request is issued at most once per URL unless the probe raises.
    """
    # Fast path without the lock. ``None`` is a valid cached value, so a key
    # lookup (not ``.get``) is what distinguishes "cached" from "unknown".
    try:
        return _file_size_cache[url]
    except KeyError:
        pass

    async with _lock_for(url):
        # Another caller may have populated the entry while we waited.
        if url not in _file_size_cache:
            _file_size_cache[url] = await _probe_size_once(url, token=token)
        return _file_size_cache[url]
async def _probe_size_once(url: str, token: Optional[str]) -> Optional[int]:
    """HEAD the URL and return the file size in bytes, or None on failure.

    HuggingFace serves LFS-tracked files via a 302 to a signed CDN URL.
    The real file size lives in the non-standard ``X-Linked-Size`` header
    on that 302 response (``Content-Length`` is the redirect-body length).
    Disabling redirect-follow lets us read either header on the same
    response:

    - LFS files: 302 + ``X-Linked-Size``
    - Small/non-LFS files: 200 + ``Content-Length``

    Args:
        url: URL to probe.
        token: Bearer token to send; no ``Authorization`` header is sent
            when it is ``None`` or empty.

    Returns:
        The parsed size, or ``None`` when neither header yields one, the
        status is not 200 without ``X-Linked-Size``, or the request fails
        or times out.
    """
    headers = {"Authorization": f"Bearer {token}"} if token else {}
    try:
        session = await get_session()
        async with session.head(
            url, allow_redirects=False, timeout=_HEAD_TIMEOUT, headers=headers,
        ) as resp:
            # ``X-Linked-Size`` wins regardless of status code.
            linked = parse_content_length(resp.headers.get("X-Linked-Size"))
            if linked is not None:
                return linked
            if resp.status == 200:
                return parse_content_length(resp.headers.get("Content-Length"))
            return None
    except (aiohttp.ClientError, asyncio.TimeoutError, OSError):
        # ``asyncio.TimeoutError`` is listed explicitly: before Python 3.11 it
        # is not the builtin ``TimeoutError`` (which ``OSError`` already
        # covers), so a timed-out HEAD would otherwise propagate to callers.
        return None
# Backward-compat shim so consumers that still import the old name keep
# building during the refactor; can be removed once routes are updated.
# This is a plain alias: ``MetadataProbeResult`` and ``ProbeResult`` are the
# same class object.
MetadataProbeResult = ProbeResult