mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-30 19:07:25 +08:00
Co-authored-by: Amp <amp@ampcode.com> Amp-Thread-ID: https://ampcode.com/threads/T-019e5117-c707-729d-bf98-dce718fe64d5
216 lines
7.6 KiB
Python
216 lines
7.6 KiB
Python
"""Shared utilities for database query modules."""
|
|
|
|
import os
|
|
from decimal import Decimal
|
|
from pathlib import Path
|
|
from typing import Iterable, Sequence
|
|
|
|
import sqlalchemy as sa
|
|
from sqlalchemy import exists
|
|
|
|
import folder_paths
|
|
from app.assets.database.models import AssetReference, AssetReferenceMeta, AssetReferenceTag
|
|
from app.assets.helpers import normalize_tags
|
|
|
|
# Bind-parameter budget for a single multi-row statement;
# calculate_rows_per_statement divides it by the per-row column count.
# NOTE(review): presumably chosen to stay under SQLite's historical
# 999-variable default — confirm before raising it.
MAX_BIND_PARAMS = 800

# Mirrors app.model_manager.MODEL_FOLDER_BLACKLIST: these names are bootstrapped
# into folder_names_and_paths by core, but /api/experiment/models does not expose
# them as model folders. Keep this local to avoid reintroducing the asset query
# package initialization cycle from importing service/model-manager code here.
_NON_MODEL_FOLDER_NAMES = frozenset(("configs", "custom_nodes"))
|
|
|
|
|
|
def calculate_rows_per_statement(cols: int) -> int:
    """Return how many rows of ``cols`` bound columns fit in one statement.

    The result is ``MAX_BIND_PARAMS // cols`` but never less than 1.
    A ``cols`` value below 1 is treated as 1.
    """
    per_row = cols if cols > 1 else 1
    rows = MAX_BIND_PARAMS // per_row
    return rows if rows > 1 else 1
|
|
|
|
|
|
def iter_chunks(seq, n: int):
    """Lazily yield consecutive slices of ``seq``, each of length ``n``.

    The final slice may be shorter. Slices keep the type of ``seq``
    (list -> list, str -> str, ...).
    """
    yield from (seq[start : start + n] for start in range(0, len(seq), n))
|
|
|
|
|
|
def iter_row_chunks(rows: list[dict], cols_per_row: int) -> Iterable[list[dict]]:
    """Split ``rows`` into batches that respect the bind-parameter budget.

    Each batch holds at most ``calculate_rows_per_statement(cols_per_row)``
    rows. Nothing is yielded when ``rows`` is empty (or otherwise falsy).
    """
    if rows:
        batch_size = calculate_rows_per_statement(cols_per_row)
        yield from iter_chunks(rows, batch_size)
|
|
|
|
|
|
def build_visible_owner_clause(owner_id: str) -> sa.sql.ClauseElement:
    """Return the read-visibility predicate for ``owner_id``.

    References with an empty ``owner_id`` (owner-less rows) are visible to
    everyone. A missing or whitespace-only ``owner_id`` therefore sees only
    those rows. Any other owner sees owner-less rows plus its own.
    """
    normalized = owner_id.strip() if owner_id else ""
    if not normalized:
        return AssetReference.owner_id == ""
    return AssetReference.owner_id.in_(["", normalized])
|
|
|
|
|
|
def build_prefix_like_conditions(
    prefixes: list[str],
) -> list[sa.sql.ColumnElement]:
    """Return one predicate per prefix matching file paths under that directory.

    Each prefix is made absolute and given a trailing separator. The
    predicate is then true when ``AssetReference.file_path`` starts with it.
    The match uses ``substr`` equality rather than ``LIKE``. It is therefore
    case-exact, and ``%``/``_`` characters in paths need no escaping.
    """
    bases = []
    for raw in prefixes:
        absolute = os.path.abspath(raw)
        # Trailing separator so "/models/a" does not also match "/models/ab".
        bases.append(absolute if absolute.endswith(os.sep) else absolute + os.sep)
    return [
        sa.func.substr(AssetReference.file_path, 1, len(base)) == base
        for base in bases
    ]
|
|
|
|
|
|
def _get_comfy_model_folders() -> list[tuple[str, list[str]]]:
    """Return ``(folder_name, roots)`` for each registered model folder.

    Entries of ``folder_paths.folder_names_and_paths`` whose name is in
    ``_NON_MODEL_FOLDER_NAMES``, or whose root list is empty, are skipped.
    The root list object is returned as registered (not copied).

    This intentionally stays local to the query layer. Importing
    ``app.assets.services`` from ``app.assets.database.queries.common``
    would create a package initialization cycle.
    """
    registry = folder_paths.folder_names_and_paths
    found: list[tuple[str, list[str]]] = []
    for folder_name, entry in registry.items():
        if folder_name not in _NON_MODEL_FOLDER_NAMES:
            roots, _extensions = entry[0], entry[1]
            if roots:
                found.append((folder_name, roots))
    return found
|
|
|
|
|
|
def _nested_model_roots(
    model_folder: str,
    target_paths: list[str],
    registered_model_folders: list[tuple[str, list[str]]],
) -> list[str]:
    """Return roots of other model folders that lie strictly inside a target root.

    Args:
        model_folder: Name of the folder being listed; its own roots are skipped.
        target_paths: Roots registered for ``model_folder``.
        registered_model_folders: ``(name, roots)`` pairs to scan.

    Returns:
        The matching roots, as registered (not normalized). The caller
        excludes files under these roots, so files that belong to a nested
        folder are not also listed under the enclosing one.
    """
    # Compare as pathlib paths so the "strictly inside" test uses the same
    # path semantics as is_relative_to (e.g. case-insensitive on Windows).
    target_bases = [Path(os.path.abspath(p)) for p in target_paths]
    nested: list[str] = []
    for folder_name, paths in registered_model_folders:
        if folder_name == model_folder:
            continue
        for path in paths:
            candidate = Path(os.path.abspath(path))
            if any(
                candidate != base and candidate.is_relative_to(base)
                for base in target_bases
            ):
                nested.append(path)
    return nested


def apply_asset_path_filters(
    stmt: sa.sql.Select,
    asset_type: str | None = None,
    model_folder: str | None = None,
) -> sa.sql.Select:
    """Filter references using their real filesystem/root registration context.

    Matching is done on ``AssetReference.file_path`` against absolute
    directory prefixes (via ``build_prefix_like_conditions``).

    Args:
        stmt: Select statement to narrow.
        asset_type: ``"model"``, ``"input"``, ``"output"`` or ``"temp"``.
            Any other value matches nothing. ``None`` means no asset-type
            restriction.
        model_folder: Registered model folder name. Only valid together with
            ``asset_type="model"``. ``None`` or ``""`` means no folder
            restriction.

    Returns:
        ``stmt`` unchanged when neither filter is given. A statement that
        matches no rows when the requested roots resolve to nothing.
        Otherwise ``stmt`` with an extra WHERE clause.

    Raises:
        ValueError: if ``model_folder`` is given without ``asset_type="model"``.
    """
    # An empty model_folder is treated like None, matching the truthiness
    # checks below. Previously ``model_folder=""`` with no asset_type fell
    # through every branch and filtered out all rows.
    if asset_type is None and not model_folder:
        return stmt
    if model_folder and asset_type != "model":
        raise ValueError("model_folder can only be used with asset_type=model")

    registered_model_folders = _get_comfy_model_folders()
    prefixes: list[str] = []
    exclude_prefixes: list[str] = []

    if model_folder:
        for folder_name, paths in registered_model_folders:
            if folder_name == model_folder:
                prefixes = list(paths)
                break
        if not prefixes:
            # Folder not registered (or it has no roots): match nothing.
            return stmt.where(sa.false())
        exclude_prefixes = _nested_model_roots(
            model_folder, prefixes, registered_model_folders
        )
    elif asset_type == "model":
        prefixes = [
            path for _name, paths in registered_model_folders for path in paths
        ]
    elif asset_type == "input":
        prefixes = [folder_paths.get_input_directory()]
    elif asset_type == "output":
        prefixes = [folder_paths.get_output_directory()]
    elif asset_type == "temp":
        prefixes = [folder_paths.get_temp_directory()]
    # Any other asset_type leaves prefixes empty and matches nothing below.

    conditions = build_prefix_like_conditions(prefixes)
    if not conditions:
        return stmt.where(sa.false())
    clause = sa.or_(*conditions)

    if asset_type in {"input", "output", "temp"}:
        # Files under a registered model root are excluded from
        # input/output/temp results; asset_type="model" lists them instead.
        exclude_prefixes = [
            path for _name, paths in registered_model_folders for path in paths
        ]
    exclude_conditions = build_prefix_like_conditions(exclude_prefixes)
    if exclude_conditions:
        clause = sa.and_(clause, sa.not_(sa.or_(*exclude_conditions)))

    return stmt.where(clause)
|
|
|
|
|
|
def apply_tag_filters(
    stmt: sa.sql.Select,
    include_tags: Sequence[str] | None = None,
    exclude_tags: Sequence[str] | None = None,
) -> sa.sql.Select:
    """Restrict ``stmt`` by tag membership.

    Both tag lists are first passed through ``normalize_tags``. A reference
    is kept only if it meets both conditions:

    * it carries every tag in ``include_tags`` (one EXISTS subquery per tag);
    * it carries none of the tags in ``exclude_tags`` (a single NOT EXISTS
      subquery with ``IN``).
    """
    required = normalize_tags(include_tags)
    forbidden = normalize_tags(exclude_tags)

    def tagged_with(tag_condition):
        # EXISTS a tag row for the outer reference satisfying tag_condition.
        return exists().where(
            sa.and_(
                AssetReferenceTag.asset_reference_id == AssetReference.id,
                tag_condition,
            )
        )

    for tag in required or ():
        stmt = stmt.where(tagged_with(AssetReferenceTag.tag_name == tag))

    if forbidden:
        stmt = stmt.where(~tagged_with(AssetReferenceTag.tag_name.in_(forbidden)))

    return stmt
|
|
|
|
|
|
def apply_metadata_filter(
    stmt: sa.sql.Select,
    metadata_filter: dict | None = None,
) -> sa.sql.Select:
    """Narrow ``stmt`` using the asset_reference_meta projection table.

    Every ``key: value`` pair in ``metadata_filter`` adds an AND-ed condition:

    * ``None``: the reference must have no meta row for ``key``.
    * ``bool``: match ``val_bool``.
    * ``int``/``float``/``Decimal``: match ``val_num``, compared as ``Decimal``.
    * ``str``: match ``val_str``.
    * anything else: match ``val_json``.
    * a ``list``: any one of its elements may match (OR). An empty list adds
      no condition.

    An empty or ``None`` filter returns ``stmt`` unchanged.
    """
    if not metadata_filter:
        return stmt

    def has_meta(key: str, *extra) -> sa.sql.ClauseElement:
        # EXISTS a meta row for the outer reference and key satisfying ``extra``.
        return sa.exists().where(
            AssetReferenceMeta.asset_reference_id == AssetReference.id,
            AssetReferenceMeta.key == key,
            *extra,
        )

    def clause_for(key: str, value) -> sa.sql.ClauseElement:
        if value is None:
            return sa.not_(has_meta(key))
        # bool must be tested before the numeric branch: bool subclasses int.
        if isinstance(value, bool):
            return has_meta(key, AssetReferenceMeta.val_bool == bool(value))
        if isinstance(value, (int, float, Decimal)):
            # Going through str() gives e.g. Decimal("0.1"), not the float's
            # exact binary expansion.
            as_decimal = value if isinstance(value, Decimal) else Decimal(str(value))
            return has_meta(key, AssetReferenceMeta.val_num == as_decimal)
        if isinstance(value, str):
            return has_meta(key, AssetReferenceMeta.val_str == value)
        return has_meta(key, AssetReferenceMeta.val_json == value)

    for key, expected in metadata_filter.items():
        if not isinstance(expected, list):
            stmt = stmt.where(clause_for(key, expected))
            continue
        alternatives = [clause_for(key, item) for item in expected]
        if alternatives:
            stmt = stmt.where(sa.or_(*alternatives))
    return stmt
|