from __future__ import annotations

import asyncio
import mimetypes
import os
import tempfile
from typing import Optional

import aiohttp

from .storage.hashing import blake3_hash_sync
from .database.db import create_session
from .database.services import ingest_fs_asset, list_cache_states_by_asset_hash
from .resolvers import resolve_asset
from ._assets_helpers import resolve_destination_from_tags, ensure_within_base

_FETCH_LOCKS: dict[str, asyncio.Lock] = {}
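# One asyncio.Lock per asset hash: concurrent callers requesting the same asset
# serialize on it, so a given asset is downloaded at most once per process.
# Note that entries are never evicted from this map.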


def _sanitize_filename(name: str) -> str:
    # basename strips any directory components (so "../evil/name" becomes
    # "name"); empty input falls back to "file".
    return os.path.basename((name or "").strip()) or "file"


async def ensure_asset_cached(
    asset_hash: str,
    *,
    preferred_name: Optional[str] = None,
    tags_hint: Optional[list[str]] = None,
) -> str:
    """
    Ensure there is a verified local file for asset_hash in the correct Comfy folder.

    Fast path:
      - If any cache_state row has a file_path that exists, return it immediately.
        Preference order is oldest ID first, for stability.

    Slow path:
      - Resolve remote location + placement tags.
      - Download to the correct folder, verify hash, move into place.
      - Ingest identity + cache state so future fast passes can skip hashing.
    """
    lock = _FETCH_LOCKS.setdefault(asset_hash, asyncio.Lock())
    async with lock:
        # 1) If we already have any cache_state path present on disk, use it (oldest-first)
        async with await create_session() as sess:
            states = await list_cache_states_by_asset_hash(sess, asset_hash=asset_hash)
            for s in states:
                if s and s.file_path and os.path.isfile(s.file_path):
                    return s.file_path
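        # Paths recorded in the DB but missing on disk are simply skipped here;
        # the slow path below re-downloads and re-ingests a fresh cache state.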

        # 2) Resolve remote location + placement hints (must include valid tags)
        res = await resolve_asset(asset_hash)
        if not res:
            raise FileNotFoundError(f"No resolver/locations for {asset_hash}")

        placement_tags = tags_hint or res.tags
        if not placement_tags:
            raise ValueError(f"Resolver did not provide placement tags for {asset_hash}")

        name_hint = res.filename or preferred_name or asset_hash.replace(":", "_")
        safe_name = _sanitize_filename(name_hint)

        # 3) Map tags -> destination (strict: raises if invalid root or models category)
        base_dir, subdirs = resolve_destination_from_tags(placement_tags)  # may raise
        dest_dir = os.path.join(base_dir, *subdirs) if subdirs else base_dir
        os.makedirs(dest_dir, exist_ok=True)

        final_path = os.path.abspath(os.path.join(dest_dir, safe_name))
        ensure_within_base(final_path, base_dir)

        # 4) If target path exists, try to reuse; else delete invalid cache
        if os.path.exists(final_path) and os.path.isfile(final_path):
            existing_digest = blake3_hash_sync(final_path)
            if f"blake3:{existing_digest}" == asset_hash:
                size_bytes = os.path.getsize(final_path)
                mtime_ns = getattr(os.stat(final_path), "st_mtime_ns", int(os.path.getmtime(final_path) * 1_000_000_000))
                async with await create_session() as sess:
                    await ingest_fs_asset(
                        sess,
                        asset_hash=asset_hash,
                        abs_path=final_path,
                        size_bytes=size_bytes,
                        mtime_ns=mtime_ns,
                        mime_type=None,
                        info_name=None,
                        tags=(),
                    )
                    await sess.commit()
                return final_path
            else:
                # Invalid cache: remove before re-downloading
                os.remove(final_path)

        # 5) Download to temp next to destination
        timeout = aiohttp.ClientTimeout(total=60 * 30)  # 30 minutes total
        async with aiohttp.ClientSession(timeout=timeout) as session:
            async with session.get(res.download_url, headers=dict(res.headers)) as resp:
                resp.raise_for_status()
                cl = resp.headers.get("Content-Length")
                if res.expected_size and cl and int(cl) != int(res.expected_size):
                    raise ValueError("server Content-Length does not match expected size")
                with tempfile.NamedTemporaryFile("wb", delete=False, dir=dest_dir) as tmp:
                    tmp_path = tmp.name
                    async for chunk in resp.content.iter_chunked(8 * 1024 * 1024):
                        if chunk:
                            tmp.write(chunk)

        # 6) Verify content hash
        digest = blake3_hash_sync(tmp_path)
        canonical = f"blake3:{digest}"
        if canonical != asset_hash:
            try:
                os.remove(tmp_path)
            finally:
                # Raising inside finally keeps the mismatch error even if
                # removing the temp file itself fails.
                raise ValueError(f"Hash mismatch: expected {asset_hash}, got {canonical}")

        # 7) Atomically move into place
        if os.path.exists(final_path):
            os.remove(final_path)
        os.replace(tmp_path, final_path)
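        # os.replace is atomic when source and destination share a filesystem,
        # which holds here because the temp file was created inside dest_dir.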

        # 8) Record identity + cache state (+ mime type)
        size_bytes = os.path.getsize(final_path)
        mtime_ns = getattr(os.stat(final_path), "st_mtime_ns", int(os.path.getmtime(final_path) * 1_000_000_000))
        mime_type = mimetypes.guess_type(safe_name, strict=False)[0]
        async with await create_session() as sess:
            await ingest_fs_asset(
                sess,
                asset_hash=asset_hash,
                abs_path=final_path,
                size_bytes=size_bytes,
                mtime_ns=mtime_ns,
                mime_type=mime_type,
                info_name=None,
                tags=(),
            )
            await sess.commit()

        return final_path
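

# Example usage (a minimal sketch; the hash and tag values below are
# illustrative, and the call assumes it runs inside the app's own async
# context with the resolver able to locate the asset):
#
#     path = await ensure_asset_cached(
#         "blake3:<hex digest>",
#         preferred_name="model.safetensors",
#         tags_hint=["models", "checkpoints"],
#     )
#     # `path` now points at a hash-verified file under the mapped Comfy folder.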