fixed metadata[filename] feature + new tests for this
commit 6cfa94ec58 (parent a2ec1f7637)
@@ -140,7 +140,7 @@ def ensure_within_base(candidate: str, base: str) -> None:
         raise ValueError("invalid destination path")
 
 
-def compute_model_relative_filename(file_path: str) -> Optional[str]:
+def compute_relative_filename(file_path: str) -> Optional[str]:
     """
     Return the model's path relative to the last well-known folder (the model category),
     using forward slashes, eg:
@@ -155,16 +155,16 @@ def compute_model_relative_filename(file_path: str) -> Optional[str]:
     except ValueError:
         return None
 
-    if root_category != "models":
-        return None
-
     p = Path(rel_path)
-    # parts[0] is the well-known category (eg "checkpoints" or "text_encoders")
     parts = [seg for seg in p.parts if seg not in (".", "..", p.anchor)]
     if not parts:
         return None
-    inside = parts[1:] if len(parts) > 1 else [parts[0]]
-    return "/".join(inside)  # normalize to POSIX style for portability
+
+    if root_category == "models":
+        # parts[0] is the category ("checkpoints", "vae", etc) – drop it
+        inside = parts[1:] if len(parts) > 1 else [parts[0]]
+        return "/".join(inside)
+    return "/".join(parts)  # input/output: keep all parts
 
 
 def list_tree(base_dir: str) -> list[str]:
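Note: the rename from compute_model_relative_filename to compute_relative_filename tracks a behavior change: instead of returning None for anything outside "models", the function now also handles the input/output roots by keeping the full relative path. A minimal standalone sketch of the mapping (the real function derives root_category and rel_path from the configured base folders; that derivation is elided here, so this is an illustration, not the actual implementation):

    from pathlib import PurePosixPath

    def relative_filename_sketch(root_category: str, rel_path: str) -> str | None:
        # rel_path is the portion below the root, e.g. "checkpoints/sd15/model.safetensors"
        parts = [seg for seg in PurePosixPath(rel_path).parts if seg not in (".", "..")]
        if not parts:
            return None
        if root_category == "models":
            # drop the category folder ("checkpoints", "vae", ...)
            inside = parts[1:] if len(parts) > 1 else [parts[0]]
            return "/".join(inside)
        return "/".join(parts)  # input/output: keep all parts

    assert relative_filename_sketch("models", "checkpoints/sd15/model.safetensors") == "sd15/model.safetensors"
    assert relative_filename_sketch("output", "unit-tests/a/b/img.png") == "unit-tests/a/b/img.png"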
@@ -30,6 +30,7 @@ from .database.services import (
     list_asset_infos_page,
     list_cache_states_by_asset_id,
     list_tags_with_usage,
+    pick_best_live_path,
     remove_tags_from_asset_info,
     set_asset_info_preview,
     touch_asset_info_by_id,
@@ -177,11 +178,7 @@ async def resolve_asset_content_for_download(
 
     info, asset = pair
     states = await list_cache_states_by_asset_id(session, asset_id=asset.id)
-    abs_path = ""
-    for s in states:
-        if s and s.file_path and os.path.isfile(s.file_path):
-            abs_path = s.file_path
-            break
+    abs_path = pick_best_live_path(states)
     if not abs_path:
         raise FileNotFoundError
 
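Note: besides removing the inline loop, this slightly changes selection semantics: the old code returned the first cache-state path that exists on disk, while pick_best_live_path (defined later in this commit) prefers a path whose state has needs_verify == False. With two live paths where only the second is verified, the download now resolves to the verified one:

    # hypothetical cache states for one asset:
    #   states[0].file_path exists, needs_verify=True
    #   states[1].file_path exists, needs_verify=False
    # old loop            -> states[0].file_path
    # pick_best_live_path -> states[1].file_path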
@@ -32,6 +32,7 @@ from .queries import (
     get_asset_info_by_id,
     get_cache_state_by_asset_id,
     list_cache_states_by_asset_id,
+    pick_best_live_path,
 )
 
 __all__ = [
@@ -39,6 +40,7 @@ __all__ = [
     "asset_exists_by_hash", "get_asset_by_hash", "get_asset_info_by_id", "asset_info_exists_for_asset_id",
     "get_cache_state_by_asset_id",
     "list_cache_states_by_asset_id",
+    "pick_best_live_path",
     # info
     "list_asset_infos_page", "create_asset_info_for_existing_asset", "set_asset_info_tags",
     "update_asset_info_full", "replace_asset_info_metadata_projection",
@@ -12,7 +12,7 @@ from sqlalchemy.exc import IntegrityError
 from sqlalchemy.ext.asyncio import AsyncSession
 from sqlalchemy.orm import noload
 
-from ..._assets_helpers import compute_model_relative_filename, normalize_tags
+from ..._assets_helpers import compute_relative_filename, normalize_tags
 from ...storage import hashing as hashing_mod
 from ..helpers import (
     ensure_tags_exist,
@@ -21,6 +21,7 @@ from ..helpers import (
 from ..models import Asset, AssetCacheState, AssetInfo, AssetInfoTag, Tag
 from ..timeutil import utcnow
 from .info import replace_asset_info_metadata_projection
+from .queries import list_cache_states_by_asset_id, pick_best_live_path
 
 
 async def check_fs_asset_exists_quick(
@@ -106,6 +107,15 @@ async def ensure_seed_for_path(
         session.add(info)
         await session.flush()
 
+        with contextlib.suppress(Exception):
+            computed = compute_relative_filename(locator)
+            if computed:
+                await replace_asset_info_metadata_projection(
+                    session,
+                    asset_info_id=info.id,
+                    user_metadata={"filename": computed},
+                )
+
         want = normalize_tags(tags)
         if want:
             await ensure_tags_exist(session, want, tag_type="user")
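Note: at seed time the asset has no hash yet, so the filename projection is strictly best-effort: contextlib.suppress(Exception) keeps a metadata failure from aborting the seed itself. Assuming locator is the absolute on-disk path, a seed under the input root would project, for example:

    # locator = "<base>/input/unit-tests/s1/a/b/f.bin"   (hypothetical path)
    # compute_relative_filename(locator) -> "unit-tests/s1/a/b/f.bin"
    # resulting projection: user_metadata == {"filename": "unit-tests/s1/a/b/f.bin"}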
@@ -265,6 +275,8 @@ async def compute_hash_and_dedup_for_cache_state(
             if int(remaining or 0) == 0:
                 await session.delete(asset)
                 await session.flush()
+            else:
+                await _recompute_and_apply_filename_for_asset(session, asset_id=asset.id)
             return None
 
         digest = await hashing_mod.blake3_hash(path)
@@ -316,6 +328,7 @@ async def compute_hash_and_dedup_for_cache_state(
             state.needs_verify = False
             with contextlib.suppress(Exception):
                 await remove_missing_tag_for_asset_id(session, asset_id=canonical.id)
+            await _recompute_and_apply_filename_for_asset(session, asset_id=canonical.id)
             await session.flush()
             return canonical.id
 
@@ -343,6 +356,7 @@ async def compute_hash_and_dedup_for_cache_state(
                 state.needs_verify = False
                 with contextlib.suppress(Exception):
                     await remove_missing_tag_for_asset_id(session, asset_id=canonical.id)
+                await _recompute_and_apply_filename_for_asset(session, asset_id=canonical.id)
                 await session.flush()
                 return canonical.id
             # If we got here, the integrity error was not about hash uniqueness
@@ -353,6 +367,7 @@ async def compute_hash_and_dedup_for_cache_state(
             state.needs_verify = False
             with contextlib.suppress(Exception):
                 await remove_missing_tag_for_asset_id(session, asset_id=this_asset.id)
+            await _recompute_and_apply_filename_for_asset(session, asset_id=this_asset.id)
             await session.flush()
             return this_asset.id
 
@@ -364,6 +379,7 @@ async def compute_hash_and_dedup_for_cache_state(
             state.needs_verify = False
             with contextlib.suppress(Exception):
                 await remove_missing_tag_for_asset_id(session, asset_id=this_asset.id)
+            await _recompute_and_apply_filename_for_asset(session, asset_id=this_asset.id)
             await session.flush()
             return this_asset.id
 
@@ -385,11 +401,10 @@ async def compute_hash_and_dedup_for_cache_state(
             state.needs_verify = False
             with contextlib.suppress(Exception):
                 await remove_missing_tag_for_asset_id(session, asset_id=target_id)
+            await _recompute_and_apply_filename_for_asset(session, asset_id=target_id)
             await session.flush()
             return target_id
 
     except Exception:
         # Propagate; caller records the error and continues the worker.
         raise
@@ -663,15 +678,8 @@ async def ingest_fs_asset(
 
         # metadata["filename"] hack
         if out["asset_info_id"] is not None:
-            primary_path = (
-                await session.execute(
-                    select(AssetCacheState.file_path)
-                    .where(AssetCacheState.asset_id == asset.id)
-                    .order_by(AssetCacheState.id.asc())
-                    .limit(1)
-                )
-            ).scalars().first()
-            computed_filename = compute_model_relative_filename(primary_path) if primary_path else None
+            primary_path = pick_best_live_path(await list_cache_states_by_asset_id(session, asset_id=asset.id))
+            computed_filename = compute_relative_filename(primary_path) if primary_path else None
 
             current_meta = existing_info.user_metadata or {}
             new_meta = dict(current_meta)
@@ -760,3 +768,26 @@ async def list_cache_states_with_asset_under_prefixes(
         )
     ).all()
     return [(r[0], r[1], int(r[2] or 0)) for r in rows]
+
+
+async def _recompute_and_apply_filename_for_asset(session: AsyncSession, *, asset_id: str) -> None:
+    """Compute filename from the first *existing* cache state path and apply it to all AssetInfo (if changed)."""
+    try:
+        primary_path = pick_best_live_path(await list_cache_states_by_asset_id(session, asset_id=asset_id))
+        if not primary_path:
+            return
+        new_filename = compute_relative_filename(primary_path)
+        if not new_filename:
+            return
+        infos = (
+            await session.execute(select(AssetInfo).where(AssetInfo.asset_id == asset_id))
+        ).scalars().all()
+        for info in infos:
+            current_meta = info.user_metadata or {}
+            if current_meta.get("filename") == new_filename:
+                continue
+            updated = dict(current_meta)
+            updated["filename"] = new_filename
+            await replace_asset_info_metadata_projection(session, asset_info_id=info.id, user_metadata=updated)
+    except Exception:
+        logging.exception("Failed to recompute filename metadata for asset %s", asset_id)
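Note: this helper is what the dedup/verify paths above call after re-pointing cache states. It recomputes the filename from the best live path and fans it out to every AssetInfo row of the asset, skipping rows whose stored filename already matches, so unchanged assets cost no extra projection writes. In pseudocode:

    # for each AssetInfo of asset_id:
    #     if meta["filename"] != compute_relative_filename(best_live_path):
    #         rewrite the metadata projection with the new filename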
@@ -8,7 +8,7 @@ from sqlalchemy.exc import IntegrityError
 from sqlalchemy.ext.asyncio import AsyncSession
 from sqlalchemy.orm import contains_eager, noload
 
-from ..._assets_helpers import compute_model_relative_filename, normalize_tags
+from ..._assets_helpers import compute_relative_filename, normalize_tags
 from ..helpers import (
     apply_metadata_filter,
     apply_tag_filters,
@@ -18,7 +18,11 @@ from ..helpers import (
 )
 from ..models import Asset, AssetInfo, AssetInfoMeta, AssetInfoTag, Tag
 from ..timeutil import utcnow
-from .queries import get_asset_by_hash, get_cache_state_by_asset_id
+from .queries import (
+    get_asset_by_hash,
+    list_cache_states_by_asset_id,
+    pick_best_live_path,
+)
 
 
 async def list_asset_infos_page(
@@ -196,9 +200,9 @@ async def create_asset_info_for_existing_asset(
     new_meta = dict(user_metadata or {})
     computed_filename = None
     try:
-        state = await get_cache_state_by_asset_id(session, asset_id=asset.id)
-        if state and state.file_path:
-            computed_filename = compute_model_relative_filename(state.file_path)
+        p = pick_best_live_path(await list_cache_states_by_asset_id(session, asset_id=asset.id))
+        if p:
+            computed_filename = compute_relative_filename(p)
     except Exception:
         computed_filename = None
     if computed_filename:
@@ -280,9 +284,9 @@ async def update_asset_info_full(
 
     computed_filename = None
     try:
-        state = await get_cache_state_by_asset_id(session, asset_id=info.asset_id)
-        if state and state.file_path:
-            computed_filename = compute_model_relative_filename(state.file_path)
+        p = pick_best_live_path(await list_cache_states_by_asset_id(session, asset_id=info.asset_id))
+        if p:
+            computed_filename = compute_relative_filename(p)
     except Exception:
         computed_filename = None
 
@@ -1,3 +1,4 @@
+import os
 from typing import Optional, Sequence, Union
 
 import sqlalchemy as sa
@@ -57,3 +58,19 @@ async def list_cache_states_by_asset_id(
             .order_by(AssetCacheState.id.asc())
         )
     ).scalars().all()
+
+
+def pick_best_live_path(states: Union[list[AssetCacheState], Sequence[AssetCacheState]]) -> str:
+    """
+    Return the best on-disk path among cache states:
+      1) Prefer a path that exists with needs_verify == False (already verified).
+      2) Otherwise, pick the first path that exists.
+      3) Otherwise return empty string.
+    """
+    alive = [s for s in states if getattr(s, "file_path", None) and os.path.isfile(s.file_path)]
+    if not alive:
+        return ""
+    for s in alive:
+        if not getattr(s, "needs_verify", False):
+            return s.file_path
+    return alive[0].file_path
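Note: a self-contained check of the selection order, with a stand-in state class, a local copy of the helper, and a patched os.path.isfile (the stub and the local mirror below are for illustration only, not part of the commit):

    import os
    from dataclasses import dataclass
    from typing import Sequence
    from unittest import mock

    @dataclass
    class StubState:  # hypothetical stand-in for AssetCacheState
        file_path: str
        needs_verify: bool

    def pick_best_live_path(states: Sequence[StubState]) -> str:
        # local mirror of the helper added above
        alive = [s for s in states if s.file_path and os.path.isfile(s.file_path)]
        if not alive:
            return ""
        for s in alive:
            if not s.needs_verify:
                return s.file_path
        return alive[0].file_path

    states = [
        StubState("/tmp/a.bin", True),     # exists, still awaiting verification
        StubState("/tmp/b.bin", False),    # exists and verified -> preferred
        StubState("/tmp/missing.bin", False),
    ]
    with mock.patch("os.path.isfile", side_effect=lambda p: p != "/tmp/missing.bin"):
        assert pick_best_live_path(states) == "/tmp/b.bin"
    assert pick_best_live_path([]) == ""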
@@ -1,8 +1,10 @@
 import asyncio
 import uuid
+from pathlib import Path
 
 import aiohttp
 import pytest
+from conftest import trigger_sync_seed_assets
 
 
 @pytest.mark.asyncio
@@ -218,3 +220,97 @@ async def test_concurrent_delete_same_asset_info_single_204(
     # The resource must be gone.
     async with http.get(f"{api_base}/api/assets/{aid}") as rg:
         assert rg.status == 404
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("root", ["input", "output"])
+async def test_metadata_filename_is_set_for_seed_asset_without_hash(
+    root: str,
+    http: aiohttp.ClientSession,
+    api_base: str,
+    comfy_tmp_base_dir: Path,
+):
+    """Seed ingest (no hash yet) must compute user_metadata['filename'] immediately."""
+    scope = f"seedmeta-{uuid.uuid4().hex[:6]}"
+    name = "seed_filename.bin"
+
+    base = comfy_tmp_base_dir / root / "unit-tests" / scope / "a" / "b"
+    base.mkdir(parents=True, exist_ok=True)
+    fp = base / name
+    fp.write_bytes(b"Z" * 2048)
+
+    await trigger_sync_seed_assets(http, api_base)
+
+    async with http.get(
+        api_base + "/api/assets",
+        params={"include_tags": f"unit-tests,{scope}", "name_contains": name},
+    ) as r1:
+        body = await r1.json()
+        assert r1.status == 200, body
+    matches = [a for a in body.get("assets", []) if a.get("name") == name]
+    assert matches, "Seed asset should be visible after sync"
+    assert matches[0].get("asset_hash") is None  # still a seed
+    aid = matches[0]["id"]
+
+    async with http.get(f"{api_base}/api/assets/{aid}") as r2:
+        detail = await r2.json()
+        assert r2.status == 200, detail
+    filename = (detail.get("user_metadata") or {}).get("filename")
+    expected = str(fp.relative_to(comfy_tmp_base_dir / root)).replace("\\", "/")
+    assert filename == expected, f"expected filename={expected}, got {filename!r}"
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("root", ["input", "output"])
+async def test_metadata_filename_computed_and_updated_on_retarget(
+    root: str,
+    http: aiohttp.ClientSession,
+    api_base: str,
+    comfy_tmp_base_dir: Path,
+    asset_factory,
+    make_asset_bytes,
+    run_scan_and_wait,
+):
+    """
+    1) Ingest under {root}/unit-tests/<scope>/a/b/<name> -> filename reflects relative path.
+    2) Retarget by copying to {root}/unit-tests/<scope>/x/<new_name>, remove old file,
+       run fast pass + scan -> filename updates to new relative path.
+    """
+    scope = f"meta-fn-{uuid.uuid4().hex[:6]}"
+    name1 = "compute_metadata_filename.png"
+    name2 = "compute_changed_metadata_filename.png"
+    data = make_asset_bytes(name1, 2100)
+
+    # Upload into nested path a/b
+    a = await asset_factory(name1, [root, "unit-tests", scope, "a", "b"], {}, data)
+    aid = a["id"]
+
+    root_base = comfy_tmp_base_dir / root
+    p1 = root_base / "unit-tests" / scope / "a" / "b" / name1
+    assert p1.exists()
+
+    # filename at ingest should be the path relative to root
+    rel1 = str(p1.relative_to(root_base)).replace("\\", "/")
+    async with http.get(f"{api_base}/api/assets/{aid}") as g1:
+        d1 = await g1.json()
+        assert g1.status == 200, d1
+    fn1 = d1["user_metadata"].get("filename")
+    assert fn1 == rel1
+
+    # Retarget: copy to x/<name2>, remove old, then sync+scan
+    p2 = root_base / "unit-tests" / scope / "x" / name2
+    p2.parent.mkdir(parents=True, exist_ok=True)
+    p2.write_bytes(data)
+    if p1.exists():
+        p1.unlink()
+
+    await trigger_sync_seed_assets(http, api_base)  # seed the new path
+    await run_scan_and_wait(root)  # verify/hash and reconcile
+
+    # filename should now point at x/<name2>
+    rel2 = str(p2.relative_to(root_base)).replace("\\", "/")
+    async with http.get(f"{api_base}/api/assets/{aid}") as g2:
+        d2 = await g2.json()
+        assert g2.status == 200, d2
+    fn2 = d2["user_metadata"].get("filename")
+    assert fn2 == rel2
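Note: assuming the integration-test harness these files already use (a running server plus the http/api_base/asset_factory fixtures from conftest), the two new cases can be run in isolation with pytest's keyword filter:

    pytest -k "metadata_filename"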