mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-06-26 01:39:25 +08:00
spike: add asset system tags and duplicate refs
Amp-Thread-ID: https://ampcode.com/threads/T-019e5117-c707-729d-bf98-dce718fe64d5 Co-authored-by: Amp <amp@ampcode.com>
This commit is contained in:
parent
f74df348b6
commit
133e0a6d92
226
BE-1092-compare_asset_mappings.py
Normal file
226
BE-1092-compare_asset_mappings.py
Normal file
@ -0,0 +1,226 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Compare FE MODEL_NODE_MAPPINGS against Core /object_info and /api/assets."""
|
||||
# ruff: noqa: T201
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import re
|
||||
import sys
|
||||
import urllib.error
|
||||
import urllib.parse
|
||||
import urllib.request
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
DEFAULT_FRONTEND = Path(
|
||||
"/home/simon/comfy/ComfyUI_frontend/.wt/dante01yoon/"
|
||||
"pr-12411-integration-do-not-merge-m1-fe-asset-sta"
|
||||
)
|
||||
DEFAULT_MAPPING = DEFAULT_FRONTEND / "src/platform/assets/mappings/modelNodeMappings.ts"
|
||||
|
||||
|
||||
def fetch_json(base_url: str, path: str, params: dict[str, Any] | None = None) -> Any:
|
||||
url = base_url.rstrip("/") + path
|
||||
if params:
|
||||
url += "?" + urllib.parse.urlencode(params)
|
||||
with urllib.request.urlopen(url, timeout=20) as response:
|
||||
return json.load(response)
|
||||
|
||||
|
||||
def fetch_object_info(base_url: str) -> dict[str, Any]:
|
||||
try:
|
||||
return fetch_json(base_url, "/object_info")
|
||||
except urllib.error.HTTPError as error:
|
||||
if error.code != 404:
|
||||
raise
|
||||
return fetch_json(base_url, "/api/object_info")
|
||||
|
||||
|
||||
def fetch_all_model_assets(base_url: str) -> list[dict[str, Any]]:
|
||||
assets: list[dict[str, Any]] = []
|
||||
offset = 0
|
||||
limit = 500
|
||||
while True:
|
||||
page = fetch_json(
|
||||
base_url,
|
||||
"/api/assets",
|
||||
{"asset_type": "model", "limit": limit, "offset": offset},
|
||||
)
|
||||
batch = page.get("assets", [])
|
||||
assets.extend(batch)
|
||||
if not page.get("has_more") or not batch:
|
||||
return assets
|
||||
offset += len(batch)
|
||||
|
||||
|
||||
def parse_model_node_mappings(path: Path) -> list[tuple[str, str, str]]:
|
||||
text = path.read_text()
|
||||
pattern = re.compile(
|
||||
r"\[\s*(['\"])(.*?)\1\s*,\s*(['\"])(.*?)\3\s*,\s*(['\"])(.*?)\5\s*\]",
|
||||
re.DOTALL,
|
||||
)
|
||||
return [(m.group(2), m.group(4), m.group(6)) for m in pattern.finditer(text)]
|
||||
|
||||
|
||||
def get_combo_options(node_def: dict[str, Any], input_key: str) -> list[str] | None:
|
||||
inputs = node_def.get("input", {})
|
||||
for section in ("required", "optional"):
|
||||
spec = inputs.get(section, {}).get(input_key)
|
||||
if spec is None:
|
||||
continue
|
||||
if not isinstance(spec, list) or not spec:
|
||||
return None
|
||||
input_type = spec[0]
|
||||
options = (
|
||||
spec[1].get("options")
|
||||
if len(spec) > 1 and isinstance(spec[1], dict)
|
||||
else None
|
||||
)
|
||||
if isinstance(input_type, list):
|
||||
return [str(item) for item in input_type]
|
||||
if input_type == "COMBO" and isinstance(options, list):
|
||||
return [str(item) for item in options]
|
||||
return None
|
||||
return None
|
||||
|
||||
|
||||
def asset_values_by_folder(assets: list[dict[str, Any]]) -> dict[str, set[str]]:
|
||||
values: dict[str, set[str]] = defaultdict(set)
|
||||
for asset in assets:
|
||||
display_name = asset.get("display_name") or asset.get("name")
|
||||
if not display_name:
|
||||
continue
|
||||
folders = asset.get("model_folders") or []
|
||||
if not folders and asset.get("model_folder"):
|
||||
folders = [asset["model_folder"]]
|
||||
for folder in folders:
|
||||
values[str(folder)].add(str(display_name))
|
||||
return values
|
||||
|
||||
|
||||
def main() -> int:
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--base-url", default="http://127.0.0.1:6410")
|
||||
parser.add_argument("--mapping", type=Path, default=DEFAULT_MAPPING)
|
||||
parser.add_argument(
|
||||
"--json", action="store_true", help="Emit machine-readable JSON"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
mappings = parse_model_node_mappings(args.mapping)
|
||||
try:
|
||||
object_info = fetch_object_info(args.base_url)
|
||||
assets = fetch_all_model_assets(args.base_url)
|
||||
except urllib.error.URLError as error:
|
||||
print(f"Failed to reach {args.base_url}: {error}", file=sys.stderr)
|
||||
return 2
|
||||
|
||||
assets_by_folder = asset_values_by_folder(assets)
|
||||
rows: list[dict[str, Any]] = []
|
||||
|
||||
for model_folder, node_class, input_key in mappings:
|
||||
node_def = object_info.get(node_class)
|
||||
if not node_def:
|
||||
rows.append(
|
||||
{
|
||||
"status": "missing_node",
|
||||
"model_folder": model_folder,
|
||||
"node_class": node_class,
|
||||
"input_key": input_key,
|
||||
}
|
||||
)
|
||||
continue
|
||||
if not input_key:
|
||||
rows.append(
|
||||
{
|
||||
"status": "no_input_key",
|
||||
"model_folder": model_folder,
|
||||
"node_class": node_class,
|
||||
"input_key": input_key,
|
||||
}
|
||||
)
|
||||
continue
|
||||
options = get_combo_options(node_def, input_key)
|
||||
if options is None:
|
||||
rows.append(
|
||||
{
|
||||
"status": "missing_or_non_combo_input",
|
||||
"model_folder": model_folder,
|
||||
"node_class": node_class,
|
||||
"input_key": input_key,
|
||||
}
|
||||
)
|
||||
continue
|
||||
|
||||
object_values = set(options)
|
||||
asset_values = assets_by_folder.get(model_folder, set())
|
||||
rows.append(
|
||||
{
|
||||
"status": "match" if object_values == asset_values else "diff",
|
||||
"model_folder": model_folder,
|
||||
"node_class": node_class,
|
||||
"input_key": input_key,
|
||||
"object_info_count": len(object_values),
|
||||
"asset_count": len(asset_values),
|
||||
"missing_from_assets": sorted(object_values - asset_values),
|
||||
"extra_in_assets": sorted(asset_values - object_values),
|
||||
}
|
||||
)
|
||||
|
||||
statuses = dict(
|
||||
sorted(
|
||||
(status, sum(1 for r in rows if r["status"] == status))
|
||||
for status in {r["status"] for r in rows}
|
||||
)
|
||||
)
|
||||
summary = {
|
||||
"mapping_file": str(args.mapping),
|
||||
"base_url": args.base_url,
|
||||
"mapping_rows": len(mappings),
|
||||
"model_assets": len(assets),
|
||||
"asset_folders": sorted(assets_by_folder),
|
||||
"statuses": statuses,
|
||||
"rows": rows,
|
||||
}
|
||||
|
||||
if args.json:
|
||||
print(json.dumps(summary, indent=2, sort_keys=True))
|
||||
return 0
|
||||
|
||||
print(f"Mapping file: {summary['mapping_file']}")
|
||||
print(f"Base URL: {args.base_url}")
|
||||
print(f"Mappings: {len(mappings)} model assets: {len(assets)}")
|
||||
print(f"Asset folders: {', '.join(summary['asset_folders']) or '(none)'}")
|
||||
print("Statuses: " + ", ".join(f"{k}={v}" for k, v in statuses.items()))
|
||||
print()
|
||||
|
||||
for row in rows:
|
||||
if row["status"] == "match":
|
||||
print(
|
||||
f"MATCH {row['model_folder']} -> {row['node_class']}.{row['input_key']} ({row['asset_count']})"
|
||||
)
|
||||
elif row["status"] == "diff":
|
||||
print(
|
||||
f"DIFF {row['model_folder']} -> {row['node_class']}.{row['input_key']} "
|
||||
f"object_info={row['object_info_count']} assets={row['asset_count']}"
|
||||
)
|
||||
if row["missing_from_assets"]:
|
||||
print(
|
||||
" missing_from_assets: "
|
||||
+ ", ".join(row["missing_from_assets"][:10])
|
||||
)
|
||||
if row["extra_in_assets"]:
|
||||
print(" extra_in_assets: " + ", ".join(row["extra_in_assets"][:10]))
|
||||
else:
|
||||
print(
|
||||
f"{row['status'].upper()} {row['model_folder']} -> {row['node_class']}.{row['input_key']}"
|
||||
)
|
||||
|
||||
return 1 if any(row["status"] == "diff" for row in rows) else 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise SystemExit(main())
|
||||
@ -42,6 +42,7 @@ from app.assets.services import (
|
||||
upload_from_temp_path,
|
||||
)
|
||||
from app.assets.services.path_utils import (
|
||||
get_asset_system_tags,
|
||||
get_comfy_models_folders,
|
||||
get_stored_asset_response_path_info,
|
||||
)
|
||||
@ -146,21 +147,24 @@ def _get_asset_path_info(
|
||||
def _build_preview_url_from_view(
|
||||
asset_type: str | None,
|
||||
user_metadata: dict[str, Any] | None,
|
||||
display_name: str | None = None,
|
||||
fallback_tags: list[str] | None = None,
|
||||
) -> str | None:
|
||||
"""Build a /api/view preview URL from path-derived type and filename metadata."""
|
||||
if not user_metadata:
|
||||
return None
|
||||
filename = user_metadata.get("filename")
|
||||
filename = display_name
|
||||
if not filename and user_metadata:
|
||||
filename = user_metadata.get("filename")
|
||||
if not filename:
|
||||
return None
|
||||
|
||||
if asset_type in {"input", "output"}:
|
||||
if asset_type in {"input", "output", "temp"}:
|
||||
view_type = asset_type
|
||||
elif fallback_tags and "input" in fallback_tags:
|
||||
view_type = "input"
|
||||
elif fallback_tags and "output" in fallback_tags:
|
||||
view_type = "output"
|
||||
elif fallback_tags and "temp" in fallback_tags:
|
||||
view_type = "temp"
|
||||
else:
|
||||
return None
|
||||
|
||||
@ -194,6 +198,7 @@ def _build_asset_response(result: schemas.AssetDetailResult | schemas.UploadResu
|
||||
preview_url = _build_preview_url_from_view(
|
||||
preview_path_info.asset_type if preview_path_info else None,
|
||||
preview_detail.ref.user_metadata,
|
||||
display_name=preview_path_info.display_name if preview_path_info else None,
|
||||
fallback_tags=preview_detail.tags,
|
||||
)
|
||||
else:
|
||||
@ -202,6 +207,7 @@ def _build_asset_response(result: schemas.AssetDetailResult | schemas.UploadResu
|
||||
preview_url = _build_preview_url_from_view(
|
||||
path_info.asset_type if path_info else None,
|
||||
result.ref.user_metadata,
|
||||
display_name=path_info.display_name if path_info else None,
|
||||
fallback_tags=result.tags,
|
||||
)
|
||||
|
||||
@ -216,6 +222,14 @@ def _build_asset_response(result: schemas.AssetDetailResult | schemas.UploadResu
|
||||
model_folders = path_info.model_folders
|
||||
file_path = path_info.file_path
|
||||
display_name = path_info.display_name
|
||||
tags = list(
|
||||
dict.fromkeys(
|
||||
[
|
||||
*result.tags,
|
||||
*get_asset_system_tags(asset_type, model_folder, model_folders),
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
return schemas_out.Asset(
|
||||
id=result.ref.id,
|
||||
@ -228,7 +242,7 @@ def _build_asset_response(result: schemas.AssetDetailResult | schemas.UploadResu
|
||||
model_folder=model_folder,
|
||||
model_folders=model_folders,
|
||||
asset_type=asset_type,
|
||||
tags=result.tags,
|
||||
tags=tags,
|
||||
preview_url=preview_url,
|
||||
preview_id=result.ref.preview_id,
|
||||
user_metadata=result.ref.user_metadata or {},
|
||||
|
||||
@ -29,7 +29,7 @@ from app.assets.database.queries import (
|
||||
update_reference_updated_at,
|
||||
)
|
||||
from app.assets.helpers import select_best_live_path
|
||||
from app.assets.services.path_utils import compute_relative_filename
|
||||
from app.assets.services.path_utils import compute_relative_filename, get_asset_system_tags
|
||||
from app.assets.services.schemas import (
|
||||
AssetData,
|
||||
AssetDetailResult,
|
||||
@ -104,7 +104,10 @@ def update_asset_metadata(
|
||||
set_reference_tags(
|
||||
session,
|
||||
reference_id=reference_id,
|
||||
tags=tags,
|
||||
tags=[
|
||||
*tags,
|
||||
*get_asset_system_tags(ref.asset_type, ref.model_folder),
|
||||
],
|
||||
origin=tag_origin,
|
||||
)
|
||||
touched = True
|
||||
|
||||
@ -6,8 +6,10 @@ from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING, Any, TypedDict
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.assets.database.models import Asset
|
||||
from app.assets.database.queries import (
|
||||
bulk_insert_assets,
|
||||
bulk_insert_references_ignore_conflicts,
|
||||
@ -91,6 +93,13 @@ class MetadataRow(TypedDict):
|
||||
val_json: dict[str, Any] | None
|
||||
|
||||
|
||||
def _get_asset_ids_by_hashes(session: Session, hashes: list[str]) -> dict[str, str]:
|
||||
if not hashes:
|
||||
return {}
|
||||
rows = session.execute(select(Asset.hash, Asset.id).where(Asset.hash.in_(hashes)))
|
||||
return {hash_value: asset_id for hash_value, asset_id in rows if hash_value}
|
||||
|
||||
|
||||
@dataclass
|
||||
class BulkInsertResult:
|
||||
"""Result of bulk asset insertion."""
|
||||
@ -144,7 +153,9 @@ def batch_insert_seed_assets(
|
||||
asset_rows: list[AssetRow] = []
|
||||
reference_rows: list[ReferenceRow] = []
|
||||
path_to_asset_id: dict[str, str] = {}
|
||||
asset_id_to_ref_data: dict[str, dict] = {}
|
||||
path_to_created_asset_id: dict[str, str] = {}
|
||||
path_to_hash: dict[str, str | None] = {}
|
||||
path_to_ref_data: dict[str, dict] = {}
|
||||
absolute_path_list: list[str] = []
|
||||
|
||||
for spec in specs:
|
||||
@ -153,6 +164,8 @@ def batch_insert_seed_assets(
|
||||
reference_id = str(uuid.uuid4())
|
||||
absolute_path_list.append(absolute_path)
|
||||
path_to_asset_id[absolute_path] = asset_id
|
||||
path_to_created_asset_id[absolute_path] = asset_id
|
||||
path_to_hash[absolute_path] = spec.get("hash")
|
||||
|
||||
mime_type = spec.get("mime_type")
|
||||
try:
|
||||
@ -200,7 +213,7 @@ def batch_insert_seed_assets(
|
||||
}
|
||||
)
|
||||
|
||||
asset_id_to_ref_data[asset_id] = {
|
||||
path_to_ref_data[absolute_path] = {
|
||||
"reference_id": reference_id,
|
||||
"tags": spec["tags"],
|
||||
"filename": spec["fname"],
|
||||
@ -209,24 +222,35 @@ def batch_insert_seed_assets(
|
||||
|
||||
bulk_insert_assets(session, asset_rows)
|
||||
|
||||
# Filter reference rows to only those whose assets were actually inserted
|
||||
# (assets with duplicate hashes are silently dropped by ON CONFLICT DO NOTHING)
|
||||
inserted_asset_ids = get_existing_asset_ids(
|
||||
session, [r["asset_id"] for r in reference_rows]
|
||||
)
|
||||
reference_rows = [r for r in reference_rows if r["asset_id"] in inserted_asset_ids]
|
||||
asset_ids_by_hash = _get_asset_ids_by_hashes(
|
||||
session, [h for h in path_to_hash.values() if h]
|
||||
)
|
||||
resolved_reference_rows: list[ReferenceRow] = []
|
||||
for row in reference_rows:
|
||||
if row["asset_id"] in inserted_asset_ids:
|
||||
resolved_reference_rows.append(row)
|
||||
continue
|
||||
existing_asset_id = asset_ids_by_hash.get(path_to_hash[row["file_path"]])
|
||||
if existing_asset_id:
|
||||
row["asset_id"] = existing_asset_id
|
||||
path_to_asset_id[row["file_path"]] = existing_asset_id
|
||||
resolved_reference_rows.append(row)
|
||||
reference_rows = resolved_reference_rows
|
||||
|
||||
bulk_insert_references_ignore_conflicts(session, reference_rows)
|
||||
restore_references_by_paths(session, absolute_path_list)
|
||||
winning_paths = get_references_by_paths_and_asset_ids(session, path_to_asset_id)
|
||||
|
||||
inserted_paths = {
|
||||
path
|
||||
for path in absolute_path_list
|
||||
if path_to_asset_id[path] in inserted_asset_ids
|
||||
}
|
||||
inserted_paths = {row["file_path"] for row in reference_rows}
|
||||
losing_paths = inserted_paths - winning_paths
|
||||
lost_asset_ids = [path_to_asset_id[path] for path in losing_paths]
|
||||
lost_asset_ids = [
|
||||
path_to_created_asset_id[path]
|
||||
for path in losing_paths
|
||||
if path_to_created_asset_id[path] in inserted_asset_ids
|
||||
]
|
||||
|
||||
if lost_asset_ids:
|
||||
delete_assets_by_ids(session, lost_asset_ids)
|
||||
@ -240,7 +264,7 @@ def batch_insert_seed_assets(
|
||||
|
||||
# Get reference IDs for winners
|
||||
winning_ref_ids = [
|
||||
asset_id_to_ref_data[path_to_asset_id[path]]["reference_id"]
|
||||
path_to_ref_data[path]["reference_id"]
|
||||
for path in winning_paths
|
||||
]
|
||||
inserted_ref_ids = get_reference_ids_by_ids(session, winning_ref_ids)
|
||||
@ -250,8 +274,7 @@ def batch_insert_seed_assets(
|
||||
|
||||
if inserted_ref_ids:
|
||||
for path in winning_paths:
|
||||
asset_id = path_to_asset_id[path]
|
||||
ref_data = asset_id_to_ref_data[asset_id]
|
||||
ref_data = path_to_ref_data[path]
|
||||
ref_id = ref_data["reference_id"]
|
||||
|
||||
if ref_id not in inserted_ref_ids:
|
||||
|
||||
@ -16,7 +16,7 @@ from app.assets.database.queries import (
|
||||
get_asset_by_hash,
|
||||
get_reference_by_file_path,
|
||||
get_reference_tags,
|
||||
get_or_create_reference,
|
||||
insert_reference,
|
||||
reference_exists,
|
||||
remove_missing_tag_for_asset_id,
|
||||
set_reference_metadata,
|
||||
@ -32,6 +32,9 @@ from app.assets.services.file_utils import get_size_and_mtime_ns
|
||||
from app.assets.services.path_utils import (
|
||||
compute_relative_filename,
|
||||
get_asset_path_info,
|
||||
get_asset_response_path_info,
|
||||
get_asset_system_tags,
|
||||
get_asset_system_tags_from_tags,
|
||||
get_name_and_tags_from_asset_path,
|
||||
resolve_destination_from_tags,
|
||||
validate_path_within_base,
|
||||
@ -71,10 +74,16 @@ def _ingest_file_from_path(
|
||||
reference_id: str | None = None
|
||||
|
||||
with create_session() as session:
|
||||
system_tags: list[str] = []
|
||||
try:
|
||||
path_info = get_asset_path_info(locator)
|
||||
path_info = get_asset_response_path_info(locator)
|
||||
asset_type = path_info.asset_type
|
||||
model_folder = path_info.model_folder
|
||||
system_tags = get_asset_system_tags(
|
||||
path_info.asset_type,
|
||||
path_info.model_folder,
|
||||
path_info.model_folders,
|
||||
)
|
||||
except ValueError:
|
||||
asset_type = None
|
||||
model_folder = None
|
||||
@ -109,7 +118,7 @@ def _ingest_file_from_path(
|
||||
if preview_id and ref.preview_id != preview_id:
|
||||
ref.preview_id = preview_id
|
||||
|
||||
norm = normalize_tags(list(tags))
|
||||
norm = normalize_tags([*list(tags), *system_tags])
|
||||
if norm:
|
||||
if require_existing_tags:
|
||||
validate_tags_exist(session, norm)
|
||||
@ -259,6 +268,7 @@ def _register_existing_asset(
|
||||
preview_id: str | None = None,
|
||||
) -> RegisterAssetResult:
|
||||
user_metadata = user_metadata or {}
|
||||
tags = normalize_tags([*(tags or []), *get_asset_system_tags_from_tags(tags or [])])
|
||||
|
||||
with create_session() as session:
|
||||
asset = get_asset_by_hash(session, asset_hash=asset_hash)
|
||||
@ -272,27 +282,15 @@ def _register_existing_asset(
|
||||
if not reference_exists(session, preview_id):
|
||||
preview_id = None
|
||||
|
||||
ref, ref_created = get_or_create_reference(
|
||||
ref = insert_reference(
|
||||
session,
|
||||
asset_id=asset.id,
|
||||
owner_id=owner_id,
|
||||
name=name,
|
||||
preview_id=preview_id,
|
||||
)
|
||||
|
||||
if not ref_created:
|
||||
if preview_id and ref.preview_id != preview_id:
|
||||
ref.preview_id = preview_id
|
||||
|
||||
tag_names = get_reference_tags(session, reference_id=ref.id)
|
||||
result = RegisterAssetResult(
|
||||
ref=extract_reference_data(ref),
|
||||
asset=extract_asset_data(asset),
|
||||
tags=tag_names,
|
||||
created=False,
|
||||
)
|
||||
session.commit()
|
||||
return result
|
||||
if not ref:
|
||||
raise RuntimeError("Failed to create AssetReference for existing asset")
|
||||
|
||||
new_meta = dict(user_metadata)
|
||||
computed_filename = compute_relative_filename(ref.file_path) if ref.file_path else None
|
||||
@ -306,7 +304,7 @@ def _register_existing_asset(
|
||||
user_metadata=new_meta,
|
||||
)
|
||||
|
||||
if tags is not None:
|
||||
if tags:
|
||||
set_reference_tags(
|
||||
session,
|
||||
reference_id=ref.id,
|
||||
|
||||
@ -33,6 +33,32 @@ class AssetPathContext(AssetPathInfo):
|
||||
relative_path: str
|
||||
|
||||
|
||||
def get_asset_system_tags(
|
||||
asset_type: str | None,
|
||||
model_folder: str | None = None,
|
||||
model_folders: list[str] | None = None,
|
||||
) -> list[str]:
|
||||
tags: list[str] = []
|
||||
if asset_type:
|
||||
tags.append(f"asset_type:{asset_type}")
|
||||
if asset_type == "model":
|
||||
folders = model_folders or ([model_folder] if model_folder else [])
|
||||
tags.extend(f"model_folder:{folder}" for folder in folders if folder)
|
||||
return normalize_tags(tags)
|
||||
|
||||
|
||||
def get_asset_system_tags_from_tags(tags: list[str] | None) -> list[str]:
|
||||
if not tags:
|
||||
return []
|
||||
root = tags[0].strip().lower()
|
||||
if root == "models":
|
||||
model_folder = tags[1] if len(tags) > 1 else None
|
||||
return get_asset_system_tags("model", model_folder=model_folder)
|
||||
if root in {"input", "output", "temp"}:
|
||||
return get_asset_system_tags(root)
|
||||
return []
|
||||
|
||||
|
||||
def get_comfy_models_folders() -> list[tuple[str, list[str]]]:
|
||||
"""Build list of (folder_name, base_paths[]) for all model locations.
|
||||
|
||||
@ -435,4 +461,16 @@ def get_name_and_tags_from_asset_path(file_path: str) -> tuple[str, list[str]]:
|
||||
parent_parts = [
|
||||
part for part in p.parent.parts if part not in (".", "..", p.anchor)
|
||||
]
|
||||
return p.name, list(dict.fromkeys(normalize_tags([root_category, *parent_parts])))
|
||||
response_info = get_asset_response_path_info(file_path)
|
||||
tags = normalize_tags(
|
||||
[
|
||||
root_category,
|
||||
*parent_parts,
|
||||
*get_asset_system_tags(
|
||||
response_info.asset_type,
|
||||
response_info.model_folder,
|
||||
response_info.model_folders,
|
||||
),
|
||||
]
|
||||
)
|
||||
return p.name, list(dict.fromkeys(tags))
|
||||
|
||||
@ -242,6 +242,8 @@ class TestBuildAssetResponsePathFields:
|
||||
assert asset.model_folders == ["checkpoints"]
|
||||
assert asset.display_name == "sub/model.safetensors"
|
||||
assert asset.file_path == "models/checkpoints/sub/model.safetensors"
|
||||
assert "asset_type:model" in asset.tags
|
||||
assert "model_folder:checkpoints" in asset.tags
|
||||
|
||||
def test_model_response_includes_plural_model_folder_memberships(
|
||||
self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch
|
||||
@ -276,6 +278,9 @@ class TestBuildAssetResponsePathFields:
|
||||
assert asset.model_folders == ["checkpoints", "loras", "vae"]
|
||||
assert asset.display_name == "checkpoints/model.safetensors"
|
||||
assert asset.file_path == "models/checkpoints/checkpoints/model.safetensors"
|
||||
assert "model_folder:checkpoints" in asset.tags
|
||||
assert "model_folder:loras" in asset.tags
|
||||
assert "model_folder:vae" in asset.tags
|
||||
|
||||
def test_input_output_response_fields_use_persisted_classification(
|
||||
self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch
|
||||
|
||||
@ -102,6 +102,53 @@ class TestBatchInsertSeedAssets:
|
||||
assert len(refs) == 1
|
||||
assert refs[0].name == "first"
|
||||
|
||||
def test_duplicate_hashes_preserve_distinct_file_references(
|
||||
self, session: Session, temp_dir: Path
|
||||
):
|
||||
path_a = temp_dir / "models" / "checkpoints" / "same.safetensors"
|
||||
path_b = temp_dir / "models" / "loras" / "same.safetensors"
|
||||
path_a.parent.mkdir(parents=True)
|
||||
path_b.parent.mkdir(parents=True)
|
||||
path_a.write_bytes(b"same content")
|
||||
path_b.write_bytes(b"same content")
|
||||
asset_hash = "blake3:" + "a" * 64
|
||||
|
||||
specs: list[SeedAssetSpec] = [
|
||||
{
|
||||
"abs_path": str(path_a),
|
||||
"size_bytes": 12,
|
||||
"mtime_ns": 123,
|
||||
"info_name": "checkpoint copy",
|
||||
"tags": ["models", "checkpoints", "asset_type:model"],
|
||||
"fname": "checkpoints/same.safetensors",
|
||||
"metadata": None,
|
||||
"hash": asset_hash,
|
||||
"mime_type": "application/safetensors",
|
||||
},
|
||||
{
|
||||
"abs_path": str(path_b),
|
||||
"size_bytes": 12,
|
||||
"mtime_ns": 456,
|
||||
"info_name": "lora copy",
|
||||
"tags": ["models", "loras", "asset_type:model"],
|
||||
"fname": "loras/same.safetensors",
|
||||
"metadata": None,
|
||||
"hash": asset_hash,
|
||||
"mime_type": "application/safetensors",
|
||||
},
|
||||
]
|
||||
|
||||
result = batch_insert_seed_assets(session, specs=specs, owner_id="")
|
||||
|
||||
assert result.inserted_refs == 2
|
||||
assets = session.query(Asset).all()
|
||||
refs = session.query(AssetReference).order_by(AssetReference.name).all()
|
||||
assert len(assets) == 1
|
||||
assert len(refs) == 2
|
||||
assert {ref.asset_id for ref in refs} == {assets[0].id}
|
||||
assert {ref.file_path for ref in refs} == {str(path_a), str(path_b)}
|
||||
assert {ref.name for ref in refs} == {"checkpoint copy", "lora copy"}
|
||||
|
||||
def test_various_model_mime_types(self, session: Session, temp_dir: Path):
|
||||
"""Verify various model file types get correct mime_type."""
|
||||
test_cases = [
|
||||
|
||||
@ -9,6 +9,7 @@ import pytest
|
||||
from app.assets.services.path_utils import (
|
||||
compute_relative_filename,
|
||||
get_comfy_models_folders,
|
||||
get_name_and_tags_from_asset_path,
|
||||
get_model_folder_matches,
|
||||
get_asset_category_and_relative_path,
|
||||
get_asset_path_info,
|
||||
@ -84,6 +85,27 @@ class TestGetAssetCategoryAndRelativePath:
|
||||
cat, rel = get_asset_category_and_relative_path(str(f))
|
||||
assert cat == "models"
|
||||
|
||||
def test_model_path_tags_include_namespaced_memberships(self, fake_dirs):
|
||||
f = fake_dirs["models"] / "subdir" / "model.safetensors"
|
||||
f.parent.mkdir()
|
||||
f.touch()
|
||||
|
||||
_name, tags = get_name_and_tags_from_asset_path(str(f))
|
||||
|
||||
assert "models" in tags
|
||||
assert "checkpoints" in tags
|
||||
assert "asset_type:model" in tags
|
||||
assert "model_folder:checkpoints" in tags
|
||||
|
||||
def test_output_path_tags_include_namespaced_asset_type(self, fake_dirs):
|
||||
f = fake_dirs["output"] / "result.png"
|
||||
f.touch()
|
||||
|
||||
_name, tags = get_name_and_tags_from_asset_path(str(f))
|
||||
|
||||
assert "output" in tags
|
||||
assert "asset_type:output" in tags
|
||||
|
||||
def test_unknown_path_raises(self, fake_dirs):
|
||||
with pytest.raises(ValueError, match="not within"):
|
||||
get_asset_category_and_relative_path("/some/random/path.png")
|
||||
|
||||
@ -65,6 +65,61 @@ def test_upload_fastpath_from_existing_hash_no_file(http: requests.Session, api_
|
||||
assert b2["asset_hash"] == h
|
||||
|
||||
|
||||
def test_upload_fastpath_same_hash_and_name_creates_distinct_references(
|
||||
http: requests.Session, api_base: str
|
||||
):
|
||||
files = {"file": ("seed.bin", b"same-content" * 64, "application/octet-stream")}
|
||||
form = {
|
||||
"tags": json.dumps(["output", "unit-tests", "seed"]),
|
||||
"name": "same-output.png",
|
||||
"user_metadata": json.dumps({"filename": "same-output.png", "run": "seed"}),
|
||||
}
|
||||
r1 = http.post(api_base + "/api/assets", data=form, files=files, timeout=120)
|
||||
seed = r1.json()
|
||||
assert r1.status_code == 201, seed
|
||||
asset_hash = seed["asset_hash"]
|
||||
|
||||
def create_output_copy(run: str):
|
||||
parts = [
|
||||
("hash", (None, asset_hash)),
|
||||
("tags", (None, json.dumps(["output", "unit-tests", run]))),
|
||||
("name", (None, "same-output.png")),
|
||||
(
|
||||
"user_metadata",
|
||||
(None, json.dumps({"filename": "same-output.png", "run": run})),
|
||||
),
|
||||
]
|
||||
response = http.post(api_base + "/api/assets", files=parts, timeout=120)
|
||||
body = response.json()
|
||||
assert response.status_code == 200, body
|
||||
return body
|
||||
|
||||
first = create_output_copy("run-a")
|
||||
second = create_output_copy("run-b")
|
||||
|
||||
assert first["asset_hash"] == second["asset_hash"] == asset_hash
|
||||
assert first["name"] == second["name"] == "same-output.png"
|
||||
assert first["id"] != second["id"]
|
||||
|
||||
detail_a = http.get(f"{api_base}/api/assets/{first['id']}", timeout=120).json()
|
||||
detail_b = http.get(f"{api_base}/api/assets/{second['id']}", timeout=120).json()
|
||||
assert detail_a["user_metadata"]["run"] == "run-a"
|
||||
assert detail_b["user_metadata"]["run"] == "run-b"
|
||||
assert "run-a" in detail_a["tags"]
|
||||
assert "run-b" in detail_b["tags"]
|
||||
assert "asset_type:output" in detail_a["tags"]
|
||||
assert "asset_type:output" in detail_b["tags"]
|
||||
|
||||
filtered = http.get(
|
||||
api_base + "/api/assets",
|
||||
params={"include_tags": "asset_type:output,unit-tests"},
|
||||
timeout=120,
|
||||
).json()
|
||||
filtered_ids = {asset["id"] for asset in filtered["assets"]}
|
||||
assert first["id"] in filtered_ids
|
||||
assert second["id"] in filtered_ids
|
||||
|
||||
|
||||
def test_upload_fastpath_with_known_hash_and_file(
|
||||
http: requests.Session, api_base: str
|
||||
):
|
||||
|
||||
Loading…
Reference in New Issue
Block a user