diff --git a/app/assets/api/routes.py b/app/assets/api/routes.py index 68126b6a5..1155fa503 100644 --- a/app/assets/api/routes.py +++ b/app/assets/api/routes.py @@ -401,12 +401,16 @@ async def upload_asset(request: web.Request) -> web.Response: ) if spec.tags and spec.tags[0] == "models": + # tag[1] may be the standalone category ("checkpoints") or the + # slash-joined shape ("checkpoints/flux/...") that + # `get_name_and_tags_from_asset_path` and cloud both emit. Match + # `resolve_destination_from_tags` by extracting the first segment. + category = spec.tags[1].split("/", 1)[0] if len(spec.tags) >= 2 else "" if ( len(spec.tags) < 2 - or spec.tags[1] not in folder_paths.folder_names_and_paths + or category not in folder_paths.folder_names_and_paths ): delete_temp_file_if_exists(parsed.tmp_path) - category = spec.tags[1] if len(spec.tags) >= 2 else "" return _build_error_response( 400, "INVALID_BODY", f"unknown models category '{category}'" ) diff --git a/app/assets/database/queries/tags.py b/app/assets/database/queries/tags.py index 2501665ac..c891364b4 100644 --- a/app/assets/database/queries/tags.py +++ b/app/assets/database/queries/tags.py @@ -21,7 +21,12 @@ from app.assets.database.queries.common import ( build_visible_owner_clause, iter_row_chunks, ) -from app.assets.helpers import escape_sql_like_string, get_utc_now, normalize_tags +from app.assets.helpers import ( + escape_sql_like_string, + expand_bucket_prefixes, + get_utc_now, + normalize_tags, +) @dataclass(frozen=True) @@ -96,7 +101,7 @@ def set_reference_tags( tags: Sequence[str], origin: str = "manual", ) -> SetTagsResult: - desired = normalize_tags(tags) + desired = expand_bucket_prefixes(normalize_tags(tags)) current = set(get_reference_tags(session, reference_id)) @@ -149,7 +154,7 @@ def add_tags_to_reference( if not ref: raise ValueError(f"AssetReference {reference_id} not found") - norm = normalize_tags(tags) + norm = expand_bucket_prefixes(normalize_tags(tags)) if not norm: total = get_reference_tags(session, reference_id=reference_id) return AddTagsResult(added=[], already_present=[], total_tags=total) diff --git a/app/assets/helpers.py b/app/assets/helpers.py index 3798f3933..2f9e9a0ce 100644 --- a/app/assets/helpers.py +++ b/app/assets/helpers.py @@ -47,6 +47,50 @@ def normalize_tags(tags: list[str] | None) -> list[str]: return list(dict.fromkeys(t.strip().lower() for t in (tags or []) if (t or "").strip())) +def _known_bucket_prefixes() -> set[str]: + """Lowercased model-category names eligible for standalone-prefix + expansion. Tags whose first slash segment matches one of these get + the bucket inserted as a separate token, so FE filters like + ``include_tags=models,checkpoints`` keep matching even when the + asset lives in a nested subfolder (`models/checkpoints/flux/foo`). + + Bare user labels with slashes whose first segment is not a registered + bucket (e.g. ``my-org/team-a``) pass through unchanged. + """ + try: + import folder_paths + + return { + name.lower() + for name in folder_paths.folder_names_and_paths.keys() + if name != "custom_nodes" + } + except Exception: + return set() + + +def expand_bucket_prefixes(tags: list[str]) -> list[str]: + """Insert standalone bucket tokens after any slash-joined tag whose + first segment is a registered model category. Preserves caller order + and is idempotent (existing bucket tokens are not duplicated). + """ + if not tags: + return list(tags) + buckets = _known_bucket_prefixes() + if not buckets: + return list(tags) + seen = set(tags) + result: list[str] = [] + for t in tags: + result.append(t) + if "/" in t: + prefix = t.split("/", 1)[0] + if prefix.lower() in buckets and prefix not in seen: + result.append(prefix) + seen.add(prefix) + return result + + def validate_blake3_hash(s: str) -> str: """Validate and normalize a blake3 hash string. diff --git a/app/assets/services/bulk_ingest.py b/app/assets/services/bulk_ingest.py index 7c673632f..a5ff19c6b 100644 --- a/app/assets/services/bulk_ingest.py +++ b/app/assets/services/bulk_ingest.py @@ -13,13 +13,14 @@ from app.assets.database.queries import ( bulk_insert_references_ignore_conflicts, bulk_insert_tags_and_meta, delete_assets_by_ids, + ensure_tags_exist, get_existing_asset_ids, get_reference_ids_by_ids, get_references_by_paths_and_asset_ids, get_unreferenced_unhashed_asset_ids, restore_references_by_paths, ) -from app.assets.helpers import get_utc_now +from app.assets.helpers import expand_bucket_prefixes, get_utc_now if TYPE_CHECKING: from app.assets.services.metadata_extract import ExtractedMetadata @@ -239,7 +240,8 @@ def batch_insert_seed_assets( # this, every tag in a bulk-insert batch shares current_time and # the tag_name tiebreaker sorts them alphabetically — putting the # subpath tag ahead of "models" since "c"/"d"/"l" < "m". - for tag_idx, tag in enumerate(ref_data["tags"]): + ref_tags = expand_bucket_prefixes(ref_data["tags"]) + for tag_idx, tag in enumerate(ref_tags): tag_rows.append( { "asset_reference_id": ref_id, @@ -267,6 +269,16 @@ def batch_insert_seed_assets( } ) + if tag_rows: + # Bucket-prefix expansion may have introduced tags the caller did + # not register via the upstream tag_pool (e.g. `checkpoints` for a + # nested `checkpoints/flux/foo` path). Pre-register the full set so + # the AssetReferenceTag.tag_name FK is satisfied; the underlying + # insert is ON CONFLICT DO NOTHING so re-registration is idempotent. + ensure_tags_exist( + session, {row["tag_name"] for row in tag_rows}, tag_type="user" + ) + bulk_insert_tags_and_meta(session, tag_rows=tag_rows, meta_rows=metadata_rows) return BulkInsertResult( diff --git a/tests-unit/assets_test/queries/test_asset_info.py b/tests-unit/assets_test/queries/test_asset_info.py index 784030ba4..9d75fd42d 100644 --- a/tests-unit/assets_test/queries/test_asset_info.py +++ b/tests-unit/assets_test/queries/test_asset_info.py @@ -197,7 +197,10 @@ class TestTagRetrievalOrder: result = fetch_reference_asset_and_tags(session, ref.id) assert result is not None _, _, tags = result - assert tags == ["models", "diffusers/kolors/text_encoder"] + # Bucket-prefix expansion appends the standalone `diffusers` token + # at path-tier (microsecond stagger) so FE set-membership filters + # match nested category paths. + assert tags == ["models", "diffusers/kolors/text_encoder", "diffusers"] def test_add_tags_to_reference_lands_after_path_tags(self, session: Session): ref = self._make_ref(session) @@ -256,7 +259,14 @@ class TestTagRetrievalOrder: session.commit() _, tag_map, _ = list_references_page(session) - assert tag_map[ref.id] == ["models", "loras/my/custom/path", "second-tag"] + # `loras` is expanded from the nested category path; user-added + # tags trail behind it via the microsecond stagger. + assert tag_map[ref.id] == [ + "models", + "loras/my/custom/path", + "loras", + "second-tag", + ] class TestFetchReferenceAssetAndTags: diff --git a/tests-unit/assets_test/queries/test_tags.py b/tests-unit/assets_test/queries/test_tags.py index 4ed99aa37..bf71d2962 100644 --- a/tests-unit/assets_test/queries/test_tags.py +++ b/tests-unit/assets_test/queries/test_tags.py @@ -160,6 +160,120 @@ class TestAddTagsToReference: add_tags_to_reference(session, reference_id="nonexistent", tags=["x"]) +class TestBucketPrefixExpansion: + """The standalone bucket token must appear in the asset's tag set for + nested category paths so FE filters like + `include_tags=models,checkpoints` continue to match. + """ + + def test_set_reference_tags_inserts_bucket_for_nested_path( + self, session: Session + ): + asset = _make_asset(session, "hash-nested") + ref = _make_reference(session, asset) + + result = set_reference_tags( + session, + reference_id=ref.id, + tags=["models", "checkpoints/flux"], + ) + session.commit() + + assert set(result.total) == {"models", "checkpoints/flux", "checkpoints"} + stored = get_reference_tags(session, reference_id=ref.id) + # tag[1] keeps the slash-joined positional contract; the standalone + # bucket lands after it via path-tier microsecond stagger so user + # tags remain at the tail. + assert stored[:3] == ["models", "checkpoints/flux", "checkpoints"] + + def test_set_reference_tags_idempotent_on_replay(self, session: Session): + asset = _make_asset(session, "hash-replay") + ref = _make_reference(session, asset) + + set_reference_tags( + session, + reference_id=ref.id, + tags=["models", "checkpoints/flux"], + ) + # Replay with the same caller-supplied set; expansion is already + # baked in, so nothing should be added or removed. + result = set_reference_tags( + session, + reference_id=ref.id, + tags=["models", "checkpoints/flux"], + ) + session.commit() + + assert result.added == [] + assert result.removed == [] + assert set(result.total) == {"models", "checkpoints/flux", "checkpoints"} + + def test_add_tags_to_reference_expands_bucket(self, session: Session): + asset = _make_asset(session, "hash-add") + ref = _make_reference(session, asset) + + result = add_tags_to_reference( + session, + reference_id=ref.id, + tags=["loras/style/v2"], + ) + session.commit() + + assert set(result.added) == {"loras/style/v2", "loras"} + stored = get_reference_tags(session, reference_id=ref.id) + assert "loras" in stored + assert "loras/style/v2" in stored + + def test_add_tags_does_not_duplicate_existing_bucket(self, session: Session): + asset = _make_asset(session, "hash-dedupe") + ref = _make_reference(session, asset) + + add_tags_to_reference( + session, reference_id=ref.id, tags=["models", "checkpoints"] + ) + result = add_tags_to_reference( + session, reference_id=ref.id, tags=["checkpoints/flux"] + ) + session.commit() + + # `checkpoints` was already there from the first add; only the + # slash-joined token is genuinely new. + assert result.added == ["checkpoints/flux"] + assert "checkpoints" in result.already_present + + def test_flat_category_is_unaffected(self, session: Session): + asset = _make_asset(session, "hash-flat") + ref = _make_reference(session, asset) + + result = set_reference_tags( + session, + reference_id=ref.id, + tags=["models", "checkpoints"], + ) + session.commit() + + assert set(result.total) == {"models", "checkpoints"} + assert get_reference_tags(session, reference_id=ref.id) == [ + "models", + "checkpoints", + ] + + def test_unknown_prefix_passes_through(self, session: Session): + asset = _make_asset(session, "hash-user") + ref = _make_reference(session, asset) + + # `my-org` isn't a registered bucket — the slash-joined user tag + # should not trigger bucket expansion. + result = set_reference_tags( + session, + reference_id=ref.id, + tags=["my-org/team-a"], + ) + session.commit() + + assert result.total == ["my-org/team-a"] + + class TestRemoveTagsFromReference: def test_removes_tags(self, session: Session): asset = _make_asset(session, "hash1") diff --git a/tests-unit/assets_test/services/test_bulk_ingest.py b/tests-unit/assets_test/services/test_bulk_ingest.py index 26e22a01d..4ba6db717 100644 --- a/tests-unit/assets_test/services/test_bulk_ingest.py +++ b/tests-unit/assets_test/services/test_bulk_ingest.py @@ -4,7 +4,7 @@ from pathlib import Path from sqlalchemy.orm import Session -from app.assets.database.models import Asset, AssetReference +from app.assets.database.models import Asset, AssetReference, AssetReferenceTag from app.assets.services.bulk_ingest import SeedAssetSpec, batch_insert_seed_assets @@ -102,6 +102,82 @@ class TestBatchInsertSeedAssets: assert asset.mime_type == expected_mime, f"Expected {expected_mime} for {filename}, got {asset.mime_type}" +class TestBucketPrefixExpansionOnIngest: + """Path-scanning ingest must persist the standalone bucket token for + nested category paths so the FE set-membership filter + (`include_tags=models,checkpoints`) matches assets organized into + subfolders (`models/checkpoints/flux/foo.safetensors`). + """ + + def test_nested_path_inserts_standalone_bucket( + self, session: Session, temp_dir: Path + ): + file_path = temp_dir / "flux.safetensors" + file_path.write_bytes(b"content") + + specs: list[SeedAssetSpec] = [ + { + "abs_path": str(file_path), + "size_bytes": 7, + "mtime_ns": 1234567890000000000, + "info_name": "flux", + # Shape emitted by get_name_and_tags_from_asset_path for a + # nested model path. + "tags": ["models", "checkpoints/flux"], + "fname": "flux.safetensors", + "metadata": None, + "hash": None, + "mime_type": "application/safetensors", + } + ] + + result = batch_insert_seed_assets(session, specs=specs, owner_id="") + + assert result.inserted_refs == 1 + ref = session.query(AssetReference).filter_by(name="flux").one() + stored = [ + row.tag_name + for row in session.query(AssetReferenceTag) + .filter_by(asset_reference_id=ref.id) + .order_by(AssetReferenceTag.added_at.asc()) + .all() + ] + assert stored == ["models", "checkpoints/flux", "checkpoints"] + + def test_flat_path_remains_two_tags( + self, session: Session, temp_dir: Path + ): + file_path = temp_dir / "vanilla.safetensors" + file_path.write_bytes(b"content") + + specs: list[SeedAssetSpec] = [ + { + "abs_path": str(file_path), + "size_bytes": 7, + "mtime_ns": 1234567890000000000, + "info_name": "vanilla", + "tags": ["models", "checkpoints"], + "fname": "vanilla.safetensors", + "metadata": None, + "hash": None, + "mime_type": "application/safetensors", + } + ] + + batch_insert_seed_assets(session, specs=specs, owner_id="") + + ref = session.query(AssetReference).filter_by(name="vanilla").one() + stored = { + row.tag_name + for row in session.query(AssetReferenceTag) + .filter_by(asset_reference_id=ref.id) + .all() + } + # Dedupe means flat layouts don't pick up a redundant `checkpoints` + # row — tag[1] already serves both positional and set-membership. + assert stored == {"models", "checkpoints"} + + class TestMetadataExtraction: def test_extracts_mime_type_for_model_files(self, temp_dir: Path): """Verify metadata extraction returns correct mime_type for model files.""" diff --git a/tests-unit/assets_test/test_helpers.py b/tests-unit/assets_test/test_helpers.py new file mode 100644 index 000000000..c950b726b --- /dev/null +++ b/tests-unit/assets_test/test_helpers.py @@ -0,0 +1,69 @@ +"""Unit tests for app.assets.helpers.""" + +from app.assets.helpers import expand_bucket_prefixes + + +class TestExpandBucketPrefixes: + def test_flat_category_unchanged(self): + # `checkpoints` is already a standalone token, no expansion needed. + assert expand_bucket_prefixes(["models", "checkpoints"]) == [ + "models", + "checkpoints", + ] + + def test_nested_category_inserts_bucket(self): + # Path-derived shape for `models/checkpoints/flux/foo.safetensors` — + # the standalone bucket has to be present so the FE set-membership + # filter (`include_tags=models,checkpoints`) matches the asset. + assert expand_bucket_prefixes(["models", "checkpoints/flux"]) == [ + "models", + "checkpoints/flux", + "checkpoints", + ] + + def test_deeply_nested_only_first_segment_expands(self): + # Only the FIRST slash segment ever gets emitted as a standalone — + # intermediate path segments don't have routing significance. + assert expand_bucket_prefixes( + ["models", "diffusers/kolors/text_encoder"] + ) == ["models", "diffusers/kolors/text_encoder", "diffusers"] + + def test_unknown_prefix_does_not_expand(self): + # Free-form user labels with slashes whose first segment is not a + # registered bucket pass through opaquely. + assert expand_bucket_prefixes(["models", "my-org/team-a"]) == [ + "models", + "my-org/team-a", + ] + + def test_idempotent(self): + # Re-applying the helper is a no-op once the bucket is in the set. + expanded = expand_bucket_prefixes(["models", "checkpoints/flux"]) + assert expand_bucket_prefixes(expanded) == expanded + + def test_does_not_duplicate_existing_bucket(self): + # If the caller already supplied the standalone bucket, don't add a + # second copy. + assert expand_bucket_prefixes( + ["models", "checkpoints/flux", "checkpoints"] + ) == ["models", "checkpoints/flux", "checkpoints"] + + def test_preserves_caller_order(self): + # User tags after path tags must stay after; the inserted bucket + # token slots in immediately after its slash-joined parent so the + # microsecond stagger lands it at path-tier before user-tier. + assert expand_bucket_prefixes( + ["models", "loras/style", "favorite", "v2"] + ) == ["models", "loras/style", "loras", "favorite", "v2"] + + def test_empty_input(self): + assert expand_bucket_prefixes([]) == [] + + def test_input_root_with_subpath_no_expansion(self): + # `portraits` isn't a registered model category, so the input + # subpath stays opaque (FE filter doesn't have a checkpoint-loader + # analogue for input subfolders). + assert expand_bucket_prefixes(["input", "portraits/2026"]) == [ + "input", + "portraits/2026", + ] diff --git a/tests-unit/assets_test/test_user_tag_http_smoke.py b/tests-unit/assets_test/test_user_tag_http_smoke.py index e12a8b9d0..c461f5a05 100644 --- a/tests-unit/assets_test/test_user_tag_http_smoke.py +++ b/tests-unit/assets_test/test_user_tag_http_smoke.py @@ -87,3 +87,49 @@ def test_user_tag_batch_lands_after_path_tags_via_http( # Critically: alphabetical sort would put 'aaa-experiment' at position 0. assert tags_after.index("aaa-experiment") > tags_after.index("models") assert tags_after.index("aaa-experiment") > tags_after.index("checkpoints") + + +@pytest.fixture +def nested_checkpoint_asset(http: requests.Session, api_base: str): + """Upload a checkpoint at the slash-joined path shape cloud emits + (`models/checkpoints/flux/...`), then delete it on teardown. + """ + name = "nested_checkpoint.safetensors" + tags = ["models", "checkpoints/flux"] + files = {"file": (name, b"S" * 4096, "application/octet-stream")} + form_data = { + "tags": json.dumps(tags), + "name": name, + "user_metadata": json.dumps({}), + } + r = http.post(api_base + "/api/assets", files=files, data=form_data, timeout=120) + assert r.status_code == 201, r.text + body = r.json() + yield body + http.delete( + f"{api_base}/api/assets/{body['id']}?delete_content=true", timeout=30 + ) + + +def test_nested_checkpoint_satisfies_fe_set_filter( + http: requests.Session, api_base: str, nested_checkpoint_asset: dict +): + """The case Simon flagged: a nested-path checkpoint must still match + `include_tags=models,checkpoints` — the FE combo-widget filter. + """ + ref_id = nested_checkpoint_asset["id"] + + stored = _fetch_asset_tags(http, api_base, ref_id) + # tag[1] keeps cloud's slash-joined positional contract; tag[2] holds + # the standalone bucket the FE filter looks for. + assert stored[:3] == ["models", "checkpoints/flux", "checkpoints"] + + # The actual FE query — exact set-membership across both tokens. + r = http.get( + f"{api_base}/api/assets", + params=[("include_tags", "models"), ("include_tags", "checkpoints")], + timeout=30, + ) + assert r.status_code == 200, r.text + returned_ids = {a["id"] for a in r.json()["assets"]} + assert ref_id in returned_ids