From f74df348b6409e885ce7c495ab1aea5cadc077aa Mon Sep 17 00:00:00 2001 From: Simon Pinfold Date: Thu, 11 Jun 2026 12:56:13 +1200 Subject: [PATCH] spike: expose plural model folder memberships Amp-Thread-ID: https://ampcode.com/threads/T-019e5117-c707-729d-bf98-dce718fe64d5 Co-authored-by: Amp --- app/assets/api/routes.py | 3 ++ app/assets/api/schemas_out.py | 4 ++ app/assets/services/path_utils.py | 39 +++++++++++++++ .../assets_test/queries/test_asset_info.py | 38 ++++++++++++++ .../assets_test/services/test_path_utils.py | 50 +++++++++++++++++++ 5 files changed, 134 insertions(+) diff --git a/app/assets/api/routes.py b/app/assets/api/routes.py index 7a0121e8b..02988318a 100644 --- a/app/assets/api/routes.py +++ b/app/assets/api/routes.py @@ -207,11 +207,13 @@ def _build_asset_response(result: schemas.AssetDetailResult | schemas.UploadResu asset_type = None model_folder = None + model_folders = None file_path = None display_name = None if path_info: asset_type = path_info.asset_type model_folder = path_info.model_folder + model_folders = path_info.model_folders file_path = path_info.file_path display_name = path_info.display_name @@ -224,6 +226,7 @@ def _build_asset_response(result: schemas.AssetDetailResult | schemas.UploadResu size=int(result.asset.size_bytes) if result.asset else None, mime_type=result.asset.mime_type if result.asset else None, model_folder=model_folder, + model_folders=model_folders, asset_type=asset_type, tags=result.tags, preview_url=preview_url, diff --git a/app/assets/api/schemas_out.py b/app/assets/api/schemas_out.py index 9850ee2e6..48c715f88 100644 --- a/app/assets/api/schemas_out.py +++ b/app/assets/api/schemas_out.py @@ -25,6 +25,10 @@ class Asset(BaseModel): default=None, description="Exact, case-sensitive registered ComfyUI model folder name. Present only when asset_type is `model`.", ) + model_folders: list[str] | None = Field( + default=None, + description="Exact, case-sensitive registered ComfyUI model folder names whose roots contain this asset. Present only when asset_type is `model`. This is plural membership for shared-root spike cases; `model_folder` remains the primary classification.", + ) asset_type: Literal["model", "input", "output", "temp"] | None = None tags: list[str] = Field(default_factory=list) preview_url: str | None = None diff --git a/app/assets/services/path_utils.py b/app/assets/services/path_utils.py index c4ba4e21d..c3a8a5c2c 100644 --- a/app/assets/services/path_utils.py +++ b/app/assets/services/path_utils.py @@ -24,6 +24,7 @@ class AssetPathInfo: class AssetResponsePathInfo(AssetPathInfo): file_path: str display_name: str | None + model_folders: list[str] | None = None @dataclass(frozen=True) @@ -118,6 +119,39 @@ def _normalize_relative_path(relative_path: str) -> str | None: return "/".join(parts) +def get_model_folder_matches( + file_path: str, + primary_model_folder: str | None = None, +) -> list[str]: + """Return all registered model-folder names whose roots contain ``file_path``. + + This is the plural-membership spike counterpart to the singular + ``model_folder`` classification. The singular classification still chooses + one primary folder (deepest root, then registry order), but broad shared + roots can make a physical file visible through several registered buckets. + + Results preserve registry order, with ``primary_model_folder`` moved to the + front when provided. Duplicate folder names are collapsed. + """ + fp_abs = os.path.abspath(file_path) + matches: list[str] = [] + seen: set[str] = set() + + for model_folder, bases in get_comfy_models_folders(): + for base in bases: + if Path(fp_abs).is_relative_to(os.path.abspath(base)): + if model_folder not in seen: + matches.append(model_folder) + seen.add(model_folder) + break + + if primary_model_folder in seen and matches[0] != primary_model_folder: + matches.remove(primary_model_folder) + matches.insert(0, primary_model_folder) + + return matches + + def resolve_asset_path_context(file_path: str) -> AssetPathContext: """Resolve a path against Core's asset roots and model-folder registration. @@ -305,6 +339,9 @@ def get_asset_response_path_info(file_path: str) -> AssetResponsePathInfo: model_folder=context.model_folder, file_path=logical_file_path, display_name=display_name, + model_folders=get_model_folder_matches(file_path, context.model_folder) + if context.asset_type == "model" + else None, ) @@ -362,6 +399,7 @@ def get_stored_asset_response_path_info( model_folder=model_folder, file_path=logical_file_path, display_name=display_name, + model_folders=get_model_folder_matches(file_path, model_folder), ) root_by_type = { @@ -379,6 +417,7 @@ def get_stored_asset_response_path_info( model_folder=None, file_path=logical_file_path, display_name=display_name, + model_folders=None, ) diff --git a/tests-unit/assets_test/queries/test_asset_info.py b/tests-unit/assets_test/queries/test_asset_info.py index 8fcc01d06..51e0736d9 100644 --- a/tests-unit/assets_test/queries/test_asset_info.py +++ b/tests-unit/assets_test/queries/test_asset_info.py @@ -239,9 +239,44 @@ class TestBuildAssetResponsePathFields: assert asset.asset_type == "model" assert asset.model_folder == "checkpoints" + assert asset.model_folders == ["checkpoints"] assert asset.display_name == "sub/model.safetensors" assert asset.file_path == "models/checkpoints/sub/model.safetensors" + def test_model_response_includes_plural_model_folder_memberships( + self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch + ): + shared_dir = tmp_path / "models" + shared_dir.mkdir() + model_path = shared_dir / "checkpoints" / "model.safetensors" + model_path.parent.mkdir() + model_path.write_text("data") + monkeypatch.setattr( + "app.assets.services.path_utils.get_comfy_models_folders", + lambda: [ + ("checkpoints", [str(shared_dir)]), + ("loras", [str(shared_dir)]), + ("vae", [str(shared_dir)]), + ], + ) + + asset = _build_asset_response( + _asset_detail_result( + _reference_data( + name="model.safetensors", + file_path=str(model_path), + asset_type="model", + model_folder="checkpoints", + ) + ) + ) + + assert asset.asset_type == "model" + assert asset.model_folder == "checkpoints" + assert asset.model_folders == ["checkpoints", "loras", "vae"] + assert asset.display_name == "checkpoints/model.safetensors" + assert asset.file_path == "models/checkpoints/checkpoints/model.safetensors" + def test_input_output_response_fields_use_persisted_classification( self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch ): @@ -284,10 +319,12 @@ class TestBuildAssetResponsePathFields: assert input_asset.asset_type == "input" assert input_asset.model_folder is None + assert input_asset.model_folders is None assert input_asset.display_name == "sub/image.png" assert input_asset.file_path == "input/sub/image.png" assert output_asset.asset_type == "output" assert output_asset.model_folder is None + assert output_asset.model_folders is None assert output_asset.display_name == "result.png" assert output_asset.file_path == "output/result.png" @@ -298,6 +335,7 @@ class TestBuildAssetResponsePathFields: assert asset.asset_type is None assert asset.model_folder is None + assert asset.model_folders is None assert asset.display_name is None assert asset.file_path is None diff --git a/tests-unit/assets_test/services/test_path_utils.py b/tests-unit/assets_test/services/test_path_utils.py index 8735718b1..6e1d04590 100644 --- a/tests-unit/assets_test/services/test_path_utils.py +++ b/tests-unit/assets_test/services/test_path_utils.py @@ -9,6 +9,7 @@ import pytest from app.assets.services.path_utils import ( compute_relative_filename, get_comfy_models_folders, + get_model_folder_matches, get_asset_category_and_relative_path, get_asset_path_info, get_asset_response_path_info, @@ -135,6 +136,7 @@ class TestGetAssetPathInfo: assert response_info.asset_type == "model" assert response_info.model_folder == "controlnet" + assert response_info.model_folders == ["controlnet"] assert response_info.file_path == "models/controlnet/pose.safetensors" assert response_info.display_name == "pose.safetensors" @@ -158,6 +160,7 @@ class TestGetAssetPathInfo: assert response_a.asset_type == response_b.asset_type == "model" assert response_a.model_folder == response_b.model_folder == "checkpoints" + assert response_a.model_folders == response_b.model_folders == ["checkpoints"] assert response_a.file_path == "models/checkpoints/subdir/model_a.safetensors" assert response_b.file_path == "models/checkpoints/subdir/model_b.safetensors" assert response_a.display_name == "subdir/model_a.safetensors" @@ -196,6 +199,7 @@ class TestGetAssetPathInfo: response_info = get_asset_response_path_info(str(f)) assert response_info.file_path == "input/subdir/photo.png" assert response_info.display_name == "subdir/photo.png" + assert response_info.model_folders is None def test_output_backed_registered_model_folder_is_model(self, fake_dirs): output_checkpoints_dir = fake_dirs["output"] / "checkpoints" @@ -216,6 +220,48 @@ class TestGetAssetPathInfo: assert response_info.file_path == "models/checkpoints/saved.safetensors" assert response_info.display_name == "saved.safetensors" + assert response_info.model_folders == ["checkpoints"] + + def test_shared_root_returns_all_matching_model_folders(self, fake_dirs): + shared_root = fake_dirs["models"].parent / "shared" + shared_root.mkdir() + f = shared_root / "checkpoints" / "foo.safetensors" + f.parent.mkdir() + f.touch() + + with patch( + "app.assets.services.path_utils.get_comfy_models_folders", + return_value=[ + ("checkpoints", [str(shared_root)]), + ("loras", [str(shared_root)]), + ("vae", [str(shared_root)]), + ], + ): + context = resolve_asset_path_context(str(f)) + response_info = get_asset_response_path_info(str(f)) + + assert context.model_folder == "checkpoints" + assert response_info.model_folder == "checkpoints" + assert response_info.model_folders == ["checkpoints", "loras", "vae"] + assert response_info.display_name == "checkpoints/foo.safetensors" + assert response_info.file_path == "models/checkpoints/checkpoints/foo.safetensors" + + def test_model_folder_matches_can_move_primary_first(self, fake_dirs): + shared_root = fake_dirs["models"].parent / "shared" + shared_root.mkdir() + f = shared_root / "foo.safetensors" + f.touch() + + with patch( + "app.assets.services.path_utils.get_comfy_models_folders", + return_value=[ + ("checkpoints", [str(shared_root)]), + ("loras", [str(shared_root)]), + ], + ): + matches = get_model_folder_matches(str(f), primary_model_folder="loras") + + assert matches == ["loras", "checkpoints"] def test_registered_model_folder_can_contain_slash(self, fake_dirs): nested_model_dir = fake_dirs["models"].parent / "text_encoders" / "clip" @@ -264,11 +310,15 @@ class TestGetAssetPathInfo: ], ): context = resolve_asset_path_context(str(f)) + response_info = get_asset_response_path_info(str(f)) assert context.asset_type == "model" assert context.model_folder == "text_encoders/clip" assert context.relative_path == "clip.safetensors" + assert response_info.model_folder == "text_encoders/clip" + assert response_info.model_folders == ["text_encoders/clip", "text_encoders"] + def test_deepest_registered_model_base_wins_independent_of_registration_order( self, fake_dirs ):