mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-12-19 11:03:00 +08:00
implemented assets scaner
This commit is contained in:
parent
a82577f64a
commit
d7464e9e73
@ -111,12 +111,12 @@ def upgrade() -> None:
|
||||
op.bulk_insert(
|
||||
tags_table,
|
||||
[
|
||||
# Core concept tags
|
||||
# Root folder tags
|
||||
{"name": "models", "tag_type": "system"},
|
||||
{"name": "input", "tag_type": "system"},
|
||||
{"name": "output", "tag_type": "system"},
|
||||
|
||||
# Canonical single-word types
|
||||
# Core tags
|
||||
{"name": "checkpoint", "tag_type": "system"},
|
||||
{"name": "lora", "tag_type": "system"},
|
||||
{"name": "vae", "tag_type": "system"},
|
||||
@ -130,9 +130,10 @@ def upgrade() -> None:
|
||||
{"name": "vae-approx", "tag_type": "system"},
|
||||
{"name": "gligen", "tag_type": "system"},
|
||||
{"name": "style-model", "tag_type": "system"},
|
||||
{"name": "photomaker", "tag_type": "system"},
|
||||
{"name": "classifier", "tag_type": "system"},
|
||||
{"name": "encoder", "tag_type": "system"},
|
||||
{"name": "decoder", "tag_type": "system"},
|
||||
# TODO: decide what to do with: photomaker, classifiers
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
@ -4,7 +4,7 @@ from typing import Optional
|
||||
from aiohttp import web
|
||||
from pydantic import ValidationError
|
||||
|
||||
from .. import assets_manager
|
||||
from .. import assets_manager, assets_scanner
|
||||
from . import schemas_in
|
||||
|
||||
|
||||
@ -225,6 +225,31 @@ async def delete_asset_tags(request: web.Request) -> web.Response:
|
||||
return web.json_response(result.model_dump(mode="json"), status=200)
|
||||
|
||||
|
||||
@ROUTES.post("/api/assets/scan/schedule")
|
||||
async def schedule_asset_scan(request: web.Request) -> web.Response:
|
||||
try:
|
||||
payload = await request.json()
|
||||
except Exception:
|
||||
payload = {}
|
||||
|
||||
try:
|
||||
body = schemas_in.ScheduleAssetScanBody.model_validate(payload)
|
||||
except ValidationError as ve:
|
||||
return _validation_error_response("INVALID_BODY", ve)
|
||||
|
||||
states = await assets_scanner.schedule_scans(body.roots)
|
||||
return web.json_response(states.model_dump(mode="json"), status=202)
|
||||
|
||||
|
||||
@ROUTES.get("/api/assets/scan")
|
||||
async def get_asset_scan_status(request: web.Request) -> web.Response:
|
||||
root = request.query.get("root", "").strip().lower()
|
||||
states = assets_scanner.current_statuses()
|
||||
if root in {"models", "input", "output"}:
|
||||
states = [s for s in states.scans if s.root == root] # type: ignore
|
||||
return web.json_response(states.model_dump(mode="json"), status=200)
|
||||
|
||||
|
||||
def register_assets_routes(app: web.Application) -> None:
|
||||
app.add_routes(ROUTES)
|
||||
|
||||
|
||||
@ -146,3 +146,29 @@ class TagsAdd(BaseModel):
|
||||
|
||||
class TagsRemove(TagsAdd):
|
||||
pass
|
||||
|
||||
|
||||
class ScheduleAssetScanBody(BaseModel):
|
||||
roots: list[Literal["models","input","output"]] = Field(default_factory=list)
|
||||
|
||||
@field_validator("roots", mode="before")
|
||||
@classmethod
|
||||
def _normalize_roots(cls, v):
|
||||
if v is None:
|
||||
return []
|
||||
if isinstance(v, str):
|
||||
items = [x.strip().lower() for x in v.split(",")]
|
||||
elif isinstance(v, list):
|
||||
items = []
|
||||
for x in v:
|
||||
if isinstance(x, str):
|
||||
items.extend([p.strip().lower() for p in x.split(",")])
|
||||
else:
|
||||
return []
|
||||
out = []
|
||||
seen = set()
|
||||
for r in items:
|
||||
if r in {"models","input","output"} and r not in seen:
|
||||
out.append(r)
|
||||
seen.add(r)
|
||||
return out
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from datetime import datetime
|
||||
from typing import Any, Optional
|
||||
from typing import Any, Literal, Optional
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_serializer
|
||||
|
||||
|
||||
@ -87,3 +87,20 @@ class TagsRemove(BaseModel):
|
||||
removed: list[str] = Field(default_factory=list)
|
||||
not_present: list[str] = Field(default_factory=list)
|
||||
total_tags: list[str] = Field(default_factory=list)
|
||||
|
||||
|
||||
class AssetScanStatus(BaseModel):
|
||||
scan_id: str
|
||||
root: Literal["models","input","output"]
|
||||
status: Literal["scheduled","running","completed","failed","cancelled"]
|
||||
scheduled_at: Optional[str] = None
|
||||
started_at: Optional[str] = None
|
||||
finished_at: Optional[str] = None
|
||||
discovered: int = 0
|
||||
processed: int = 0
|
||||
errors: int = 0
|
||||
last_error: Optional[str] = None
|
||||
|
||||
|
||||
class AssetScanStatusResponse(BaseModel):
|
||||
scans: list[AssetScanStatus] = Field(default_factory=list)
|
||||
|
||||
319
app/assets_scanner.py
Normal file
319
app/assets_scanner.py
Normal file
@ -0,0 +1,319 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import uuid
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timezone
|
||||
from typing import Literal, Optional, Sequence
|
||||
|
||||
from . import assets_manager
|
||||
from .api import schemas_out
|
||||
|
||||
import folder_paths
|
||||
|
||||
LOGGER = logging.getLogger(__name__)
|
||||
|
||||
RootType = Literal["models", "input", "output"]
|
||||
ALLOWED_ROOTS: tuple[RootType, ...] = ("models", "input", "output")
|
||||
|
||||
# We run at most one scan per root; overall max parallelism is therefore 3
|
||||
# We also bound per-scan ingestion concurrency to avoid swamping threads/DB
|
||||
DEFAULT_PER_SCAN_CONCURRENCY = 1
|
||||
|
||||
|
||||
@dataclass
|
||||
class ScanProgress:
|
||||
scan_id: str
|
||||
root: RootType
|
||||
status: Literal["scheduled", "running", "completed", "failed", "cancelled"] = "scheduled"
|
||||
scheduled_at: float = field(default_factory=lambda: time.time())
|
||||
started_at: Optional[float] = None
|
||||
finished_at: Optional[float] = None
|
||||
|
||||
discovered: int = 0
|
||||
processed: int = 0
|
||||
errors: int = 0
|
||||
last_error: Optional[str] = None
|
||||
|
||||
# Optional details for diagnostics
|
||||
details: dict[str, int] = field(default_factory=dict)
|
||||
|
||||
|
||||
RUNNING_TASKS: dict[RootType, asyncio.Task] = {}
|
||||
PROGRESS_BY_ROOT: dict[RootType, ScanProgress] = {}
|
||||
|
||||
|
||||
def _new_scan_id(root: RootType) -> str:
|
||||
return f"scan-{root}-{uuid.uuid4().hex[:8]}"
|
||||
|
||||
|
||||
def current_statuses() -> schemas_out.AssetScanStatusResponse:
|
||||
# make shallow copies to avoid external mutation
|
||||
states = [PROGRESS_BY_ROOT[r] for r in ALLOWED_ROOTS if r in PROGRESS_BY_ROOT]
|
||||
return schemas_out.AssetScanStatusResponse(
|
||||
scans=[
|
||||
schemas_out.AssetScanStatus(
|
||||
scan_id=s.scan_id,
|
||||
root=s.root,
|
||||
status=s.status,
|
||||
scheduled_at=_ts_to_iso(s.scheduled_at),
|
||||
started_at=_ts_to_iso(s.started_at),
|
||||
finished_at=_ts_to_iso(s.finished_at),
|
||||
discovered=s.discovered,
|
||||
processed=s.processed,
|
||||
errors=s.errors,
|
||||
last_error=s.last_error,
|
||||
)
|
||||
for s in states
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
async def schedule_scans(roots: Sequence[str]) -> schemas_out.AssetScanStatusResponse:
|
||||
"""Schedule scans for the provided roots; returns progress snapshots.
|
||||
|
||||
Rules:
|
||||
- Only roots in {models, input, output} are accepted.
|
||||
- If a root is already scanning, we do NOT enqueue another one. Status returned as-is.
|
||||
- Otherwise a new task is created and started immediately.
|
||||
- Files with zero size are skipped.
|
||||
"""
|
||||
normalized: list[RootType] = []
|
||||
seen = set()
|
||||
for r in roots or []:
|
||||
if not isinstance(r, str):
|
||||
continue
|
||||
rr = r.strip().lower()
|
||||
if rr in ALLOWED_ROOTS and rr not in seen:
|
||||
normalized.append(rr) # type: ignore
|
||||
seen.add(rr)
|
||||
if not normalized:
|
||||
normalized = list(ALLOWED_ROOTS) # schedule all by default
|
||||
|
||||
results: list[ScanProgress] = []
|
||||
for root in normalized:
|
||||
if root in RUNNING_TASKS and not RUNNING_TASKS[root].done():
|
||||
# already running; return the live progress object
|
||||
results.append(PROGRESS_BY_ROOT[root])
|
||||
continue
|
||||
|
||||
# Create fresh progress
|
||||
prog = ScanProgress(scan_id=_new_scan_id(root), root=root, status="scheduled")
|
||||
PROGRESS_BY_ROOT[root] = prog
|
||||
|
||||
# Start task
|
||||
task = asyncio.create_task(_run_scan_for_root(root, prog), name=f"asset-scan:{root}")
|
||||
RUNNING_TASKS[root] = task
|
||||
results.append(prog)
|
||||
|
||||
return schemas_out.AssetScanStatusResponse(
|
||||
scans=[
|
||||
schemas_out.AssetScanStatus(
|
||||
scan_id=s.scan_id,
|
||||
root=s.root,
|
||||
status=s.status,
|
||||
scheduled_at=_ts_to_iso(s.scheduled_at),
|
||||
started_at=_ts_to_iso(s.started_at),
|
||||
finished_at=_ts_to_iso(s.finished_at),
|
||||
discovered=s.discovered,
|
||||
processed=s.processed,
|
||||
errors=s.errors,
|
||||
last_error=s.last_error,
|
||||
)
|
||||
for s in results
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
async def _run_scan_for_root(root: RootType, prog: ScanProgress) -> None:
|
||||
prog.started_at = time.time()
|
||||
prog.status = "running"
|
||||
try:
|
||||
if root == "models":
|
||||
await _scan_models(prog)
|
||||
elif root == "input":
|
||||
base = folder_paths.get_input_directory()
|
||||
await _scan_directory_tree(base, root, prog)
|
||||
elif root == "output":
|
||||
base = folder_paths.get_output_directory()
|
||||
await _scan_directory_tree(base, root, prog)
|
||||
else:
|
||||
raise RuntimeError(f"Unsupported root: {root}")
|
||||
prog.status = "completed"
|
||||
except asyncio.CancelledError:
|
||||
prog.status = "cancelled"
|
||||
raise
|
||||
except Exception as exc:
|
||||
LOGGER.exception("Asset scan failed for %s", root)
|
||||
prog.status = "failed"
|
||||
prog.errors += 1
|
||||
prog.last_error = str(exc)
|
||||
finally:
|
||||
prog.finished_at = time.time()
|
||||
# Drop the task entry if it's the current one
|
||||
t = RUNNING_TASKS.get(root)
|
||||
if t and t.done():
|
||||
RUNNING_TASKS.pop(root, None)
|
||||
|
||||
|
||||
async def _scan_models(prog: ScanProgress) -> None:
|
||||
# Iterate all folder_names whose base paths lie under the Comfy 'models' directory
|
||||
models_root = os.path.abspath(os.path.join(folder_paths.base_path, "models"))
|
||||
|
||||
# Build list of (folder_name, base_paths[]) that are configured for this category.
|
||||
# If any path for the category lies under 'models', include the category.
|
||||
targets: list[tuple[str, list[str]]] = []
|
||||
for name, (paths, _exts) in folder_paths.folder_names_and_paths.items():
|
||||
if any(os.path.abspath(p).startswith(models_root + os.sep) for p in paths):
|
||||
targets.append((name, paths))
|
||||
|
||||
plans: list[tuple[str, str]] = [] # (abs_path, file_name_for_tags)
|
||||
per_bucket: dict[str, int] = {}
|
||||
|
||||
for folder_name, bases in targets:
|
||||
rel_files = folder_paths.get_filename_list(folder_name) or []
|
||||
count_valid = 0
|
||||
|
||||
for rel_path in rel_files:
|
||||
abs_path = folder_paths.get_full_path(folder_name, rel_path)
|
||||
if not abs_path:
|
||||
continue
|
||||
abs_path = os.path.abspath(abs_path)
|
||||
|
||||
# Extra safety: ensure file is inside one of the allowed base paths
|
||||
allowed = False
|
||||
for base in bases:
|
||||
base_abs = os.path.abspath(base)
|
||||
try:
|
||||
common = os.path.commonpath([abs_path, base_abs])
|
||||
except ValueError:
|
||||
common = "" # Different drives on Windows
|
||||
if common == base_abs:
|
||||
allowed = True
|
||||
break
|
||||
if not allowed:
|
||||
LOGGER.warning("Skipping file outside models base: %s", abs_path)
|
||||
continue
|
||||
|
||||
try:
|
||||
if not os.path.getsize(abs_path):
|
||||
continue
|
||||
except OSError as e:
|
||||
LOGGER.warning("Could not stat %s: %s – skipping", abs_path, e)
|
||||
continue
|
||||
|
||||
file_name_for_tags = os.path.join(folder_name, rel_path)
|
||||
plans.append((abs_path, file_name_for_tags))
|
||||
count_valid += 1
|
||||
|
||||
if count_valid:
|
||||
per_bucket[folder_name] = per_bucket.get(folder_name, 0) + count_valid
|
||||
|
||||
prog.discovered = len(plans)
|
||||
for k, v in per_bucket.items():
|
||||
prog.details[k] = prog.details.get(k, 0) + v
|
||||
|
||||
if not plans:
|
||||
LOGGER.info("Model scan %s: nothing to ingest", prog.scan_id)
|
||||
return
|
||||
|
||||
sem = asyncio.Semaphore(DEFAULT_PER_SCAN_CONCURRENCY)
|
||||
tasks: list[asyncio.Task] = []
|
||||
|
||||
for abs_path, name_for_tags in plans:
|
||||
async def worker(fp_abs: str = abs_path, fn_rel: str = name_for_tags):
|
||||
try:
|
||||
# Offload sync ingestion into a thread
|
||||
await asyncio.to_thread(
|
||||
assets_manager.populate_db_with_asset,
|
||||
["models"],
|
||||
fn_rel,
|
||||
fp_abs,
|
||||
)
|
||||
except Exception as e:
|
||||
prog.errors += 1
|
||||
prog.last_error = str(e)
|
||||
LOGGER.debug("Error ingesting %s: %s", fp_abs, e)
|
||||
finally:
|
||||
prog.processed += 1
|
||||
sem.release()
|
||||
|
||||
await sem.acquire()
|
||||
tasks.append(asyncio.create_task(worker()))
|
||||
|
||||
if tasks:
|
||||
await asyncio.gather(*tasks)
|
||||
LOGGER.info(
|
||||
"Model scan %s finished: discovered=%d processed=%d errors=%d",
|
||||
prog.scan_id, prog.discovered, prog.processed, prog.errors
|
||||
)
|
||||
|
||||
|
||||
def _count_files_in_tree(base_abs: str) -> int:
|
||||
if not os.path.isdir(base_abs):
|
||||
return 0
|
||||
total = 0
|
||||
for _dirpath, _subdirs, filenames in os.walk(base_abs, topdown=True, followlinks=False):
|
||||
total += len(filenames)
|
||||
return total
|
||||
|
||||
|
||||
async def _scan_directory_tree(base_dir: str, root: RootType, prog: ScanProgress) -> None:
|
||||
# Guard: base_dir must be a directory
|
||||
base_abs = os.path.abspath(base_dir)
|
||||
if not os.path.isdir(base_abs):
|
||||
LOGGER.info("Scan root %s skipped: base directory missing: %s", root, base_abs)
|
||||
return
|
||||
|
||||
prog.discovered = _count_files_in_tree(base_abs)
|
||||
|
||||
sem = asyncio.Semaphore(DEFAULT_PER_SCAN_CONCURRENCY)
|
||||
tasks: list[asyncio.Task] = []
|
||||
for dirpath, _subdirs, filenames in os.walk(base_abs, topdown=True, followlinks=False):
|
||||
for name in filenames:
|
||||
rel = os.path.relpath(os.path.join(dirpath, name), base_abs)
|
||||
abs_path = os.path.join(base_abs, rel)
|
||||
# Safety: ensure within base
|
||||
try:
|
||||
if os.path.commonpath([os.path.abspath(abs_path), base_abs]) != base_abs:
|
||||
LOGGER.warning("Skipping path outside root %s: %s", root, abs_path)
|
||||
continue
|
||||
except ValueError:
|
||||
continue
|
||||
|
||||
async def worker(fp_abs: str = abs_path, fn_rel: str = rel):
|
||||
try:
|
||||
await asyncio.to_thread(
|
||||
assets_manager.populate_db_with_asset,
|
||||
[root],
|
||||
fn_rel,
|
||||
fp_abs,
|
||||
)
|
||||
except Exception as e:
|
||||
prog.errors += 1
|
||||
prog.last_error = str(e)
|
||||
finally:
|
||||
prog.processed += 1
|
||||
sem.release()
|
||||
|
||||
await sem.acquire()
|
||||
tasks.append(asyncio.create_task(worker()))
|
||||
|
||||
if tasks:
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
LOGGER.info(
|
||||
"%s scan %s finished: discovered=%d processed=%d errors=%d",
|
||||
root.capitalize(), prog.scan_id, prog.discovered, prog.processed, prog.errors
|
||||
)
|
||||
|
||||
|
||||
def _ts_to_iso(ts: Optional[float]) -> Optional[str]:
|
||||
if ts is None:
|
||||
return None
|
||||
# interpret ts as seconds since epoch UTC and return naive UTC (consistent with other models)
|
||||
try:
|
||||
return datetime.fromtimestamp(float(ts), tz=timezone.utc).replace(tzinfo=None).isoformat()
|
||||
except Exception:
|
||||
return None
|
||||
Loading…
Reference in New Issue
Block a user