spike: expose plural model folder memberships
Some checks are pending
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run

Amp-Thread-ID: https://ampcode.com/threads/T-019e5117-c707-729d-bf98-dce718fe64d5
Co-authored-by: Amp <amp@ampcode.com>
This commit is contained in:
Simon Pinfold 2026-06-11 12:56:13 +12:00
parent 70de84cac7
commit f74df348b6
5 changed files with 134 additions and 0 deletions

View File

@ -207,11 +207,13 @@ def _build_asset_response(result: schemas.AssetDetailResult | schemas.UploadResu
asset_type = None asset_type = None
model_folder = None model_folder = None
model_folders = None
file_path = None file_path = None
display_name = None display_name = None
if path_info: if path_info:
asset_type = path_info.asset_type asset_type = path_info.asset_type
model_folder = path_info.model_folder model_folder = path_info.model_folder
model_folders = path_info.model_folders
file_path = path_info.file_path file_path = path_info.file_path
display_name = path_info.display_name display_name = path_info.display_name
@ -224,6 +226,7 @@ def _build_asset_response(result: schemas.AssetDetailResult | schemas.UploadResu
size=int(result.asset.size_bytes) if result.asset else None, size=int(result.asset.size_bytes) if result.asset else None,
mime_type=result.asset.mime_type if result.asset else None, mime_type=result.asset.mime_type if result.asset else None,
model_folder=model_folder, model_folder=model_folder,
model_folders=model_folders,
asset_type=asset_type, asset_type=asset_type,
tags=result.tags, tags=result.tags,
preview_url=preview_url, preview_url=preview_url,

View File

@ -25,6 +25,10 @@ class Asset(BaseModel):
default=None, default=None,
description="Exact, case-sensitive registered ComfyUI model folder name. Present only when asset_type is `model`.", description="Exact, case-sensitive registered ComfyUI model folder name. Present only when asset_type is `model`.",
) )
model_folders: list[str] | None = Field(
default=None,
description="Exact, case-sensitive registered ComfyUI model folder names whose roots contain this asset. Present only when asset_type is `model`. This is plural membership for shared-root spike cases; `model_folder` remains the primary classification.",
)
asset_type: Literal["model", "input", "output", "temp"] | None = None asset_type: Literal["model", "input", "output", "temp"] | None = None
tags: list[str] = Field(default_factory=list) tags: list[str] = Field(default_factory=list)
preview_url: str | None = None preview_url: str | None = None

View File

@ -24,6 +24,7 @@ class AssetPathInfo:
class AssetResponsePathInfo(AssetPathInfo): class AssetResponsePathInfo(AssetPathInfo):
file_path: str file_path: str
display_name: str | None display_name: str | None
model_folders: list[str] | None = None
@dataclass(frozen=True) @dataclass(frozen=True)
@ -118,6 +119,39 @@ def _normalize_relative_path(relative_path: str) -> str | None:
return "/".join(parts) return "/".join(parts)
def get_model_folder_matches(
file_path: str,
primary_model_folder: str | None = None,
) -> list[str]:
"""Return all registered model-folder names whose roots contain ``file_path``.
This is the plural-membership spike counterpart to the singular
``model_folder`` classification. The singular classification still chooses
one primary folder (deepest root, then registry order), but broad shared
roots can make a physical file visible through several registered buckets.
Results preserve registry order, with ``primary_model_folder`` moved to the
front when provided. Duplicate folder names are collapsed.
"""
fp_abs = os.path.abspath(file_path)
matches: list[str] = []
seen: set[str] = set()
for model_folder, bases in get_comfy_models_folders():
for base in bases:
if Path(fp_abs).is_relative_to(os.path.abspath(base)):
if model_folder not in seen:
matches.append(model_folder)
seen.add(model_folder)
break
if primary_model_folder in seen and matches[0] != primary_model_folder:
matches.remove(primary_model_folder)
matches.insert(0, primary_model_folder)
return matches
def resolve_asset_path_context(file_path: str) -> AssetPathContext: def resolve_asset_path_context(file_path: str) -> AssetPathContext:
"""Resolve a path against Core's asset roots and model-folder registration. """Resolve a path against Core's asset roots and model-folder registration.
@ -305,6 +339,9 @@ def get_asset_response_path_info(file_path: str) -> AssetResponsePathInfo:
model_folder=context.model_folder, model_folder=context.model_folder,
file_path=logical_file_path, file_path=logical_file_path,
display_name=display_name, display_name=display_name,
model_folders=get_model_folder_matches(file_path, context.model_folder)
if context.asset_type == "model"
else None,
) )
@ -362,6 +399,7 @@ def get_stored_asset_response_path_info(
model_folder=model_folder, model_folder=model_folder,
file_path=logical_file_path, file_path=logical_file_path,
display_name=display_name, display_name=display_name,
model_folders=get_model_folder_matches(file_path, model_folder),
) )
root_by_type = { root_by_type = {
@ -379,6 +417,7 @@ def get_stored_asset_response_path_info(
model_folder=None, model_folder=None,
file_path=logical_file_path, file_path=logical_file_path,
display_name=display_name, display_name=display_name,
model_folders=None,
) )

View File

@ -239,9 +239,44 @@ class TestBuildAssetResponsePathFields:
assert asset.asset_type == "model" assert asset.asset_type == "model"
assert asset.model_folder == "checkpoints" assert asset.model_folder == "checkpoints"
assert asset.model_folders == ["checkpoints"]
assert asset.display_name == "sub/model.safetensors" assert asset.display_name == "sub/model.safetensors"
assert asset.file_path == "models/checkpoints/sub/model.safetensors" assert asset.file_path == "models/checkpoints/sub/model.safetensors"
def test_model_response_includes_plural_model_folder_memberships(
self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch
):
shared_dir = tmp_path / "models"
shared_dir.mkdir()
model_path = shared_dir / "checkpoints" / "model.safetensors"
model_path.parent.mkdir()
model_path.write_text("data")
monkeypatch.setattr(
"app.assets.services.path_utils.get_comfy_models_folders",
lambda: [
("checkpoints", [str(shared_dir)]),
("loras", [str(shared_dir)]),
("vae", [str(shared_dir)]),
],
)
asset = _build_asset_response(
_asset_detail_result(
_reference_data(
name="model.safetensors",
file_path=str(model_path),
asset_type="model",
model_folder="checkpoints",
)
)
)
assert asset.asset_type == "model"
assert asset.model_folder == "checkpoints"
assert asset.model_folders == ["checkpoints", "loras", "vae"]
assert asset.display_name == "checkpoints/model.safetensors"
assert asset.file_path == "models/checkpoints/checkpoints/model.safetensors"
def test_input_output_response_fields_use_persisted_classification( def test_input_output_response_fields_use_persisted_classification(
self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch
): ):
@ -284,10 +319,12 @@ class TestBuildAssetResponsePathFields:
assert input_asset.asset_type == "input" assert input_asset.asset_type == "input"
assert input_asset.model_folder is None assert input_asset.model_folder is None
assert input_asset.model_folders is None
assert input_asset.display_name == "sub/image.png" assert input_asset.display_name == "sub/image.png"
assert input_asset.file_path == "input/sub/image.png" assert input_asset.file_path == "input/sub/image.png"
assert output_asset.asset_type == "output" assert output_asset.asset_type == "output"
assert output_asset.model_folder is None assert output_asset.model_folder is None
assert output_asset.model_folders is None
assert output_asset.display_name == "result.png" assert output_asset.display_name == "result.png"
assert output_asset.file_path == "output/result.png" assert output_asset.file_path == "output/result.png"
@ -298,6 +335,7 @@ class TestBuildAssetResponsePathFields:
assert asset.asset_type is None assert asset.asset_type is None
assert asset.model_folder is None assert asset.model_folder is None
assert asset.model_folders is None
assert asset.display_name is None assert asset.display_name is None
assert asset.file_path is None assert asset.file_path is None

View File

@ -9,6 +9,7 @@ import pytest
from app.assets.services.path_utils import ( from app.assets.services.path_utils import (
compute_relative_filename, compute_relative_filename,
get_comfy_models_folders, get_comfy_models_folders,
get_model_folder_matches,
get_asset_category_and_relative_path, get_asset_category_and_relative_path,
get_asset_path_info, get_asset_path_info,
get_asset_response_path_info, get_asset_response_path_info,
@ -135,6 +136,7 @@ class TestGetAssetPathInfo:
assert response_info.asset_type == "model" assert response_info.asset_type == "model"
assert response_info.model_folder == "controlnet" assert response_info.model_folder == "controlnet"
assert response_info.model_folders == ["controlnet"]
assert response_info.file_path == "models/controlnet/pose.safetensors" assert response_info.file_path == "models/controlnet/pose.safetensors"
assert response_info.display_name == "pose.safetensors" assert response_info.display_name == "pose.safetensors"
@ -158,6 +160,7 @@ class TestGetAssetPathInfo:
assert response_a.asset_type == response_b.asset_type == "model" assert response_a.asset_type == response_b.asset_type == "model"
assert response_a.model_folder == response_b.model_folder == "checkpoints" assert response_a.model_folder == response_b.model_folder == "checkpoints"
assert response_a.model_folders == response_b.model_folders == ["checkpoints"]
assert response_a.file_path == "models/checkpoints/subdir/model_a.safetensors" assert response_a.file_path == "models/checkpoints/subdir/model_a.safetensors"
assert response_b.file_path == "models/checkpoints/subdir/model_b.safetensors" assert response_b.file_path == "models/checkpoints/subdir/model_b.safetensors"
assert response_a.display_name == "subdir/model_a.safetensors" assert response_a.display_name == "subdir/model_a.safetensors"
@ -196,6 +199,7 @@ class TestGetAssetPathInfo:
response_info = get_asset_response_path_info(str(f)) response_info = get_asset_response_path_info(str(f))
assert response_info.file_path == "input/subdir/photo.png" assert response_info.file_path == "input/subdir/photo.png"
assert response_info.display_name == "subdir/photo.png" assert response_info.display_name == "subdir/photo.png"
assert response_info.model_folders is None
def test_output_backed_registered_model_folder_is_model(self, fake_dirs): def test_output_backed_registered_model_folder_is_model(self, fake_dirs):
output_checkpoints_dir = fake_dirs["output"] / "checkpoints" output_checkpoints_dir = fake_dirs["output"] / "checkpoints"
@ -216,6 +220,48 @@ class TestGetAssetPathInfo:
assert response_info.file_path == "models/checkpoints/saved.safetensors" assert response_info.file_path == "models/checkpoints/saved.safetensors"
assert response_info.display_name == "saved.safetensors" assert response_info.display_name == "saved.safetensors"
assert response_info.model_folders == ["checkpoints"]
def test_shared_root_returns_all_matching_model_folders(self, fake_dirs):
shared_root = fake_dirs["models"].parent / "shared"
shared_root.mkdir()
f = shared_root / "checkpoints" / "foo.safetensors"
f.parent.mkdir()
f.touch()
with patch(
"app.assets.services.path_utils.get_comfy_models_folders",
return_value=[
("checkpoints", [str(shared_root)]),
("loras", [str(shared_root)]),
("vae", [str(shared_root)]),
],
):
context = resolve_asset_path_context(str(f))
response_info = get_asset_response_path_info(str(f))
assert context.model_folder == "checkpoints"
assert response_info.model_folder == "checkpoints"
assert response_info.model_folders == ["checkpoints", "loras", "vae"]
assert response_info.display_name == "checkpoints/foo.safetensors"
assert response_info.file_path == "models/checkpoints/checkpoints/foo.safetensors"
def test_model_folder_matches_can_move_primary_first(self, fake_dirs):
shared_root = fake_dirs["models"].parent / "shared"
shared_root.mkdir()
f = shared_root / "foo.safetensors"
f.touch()
with patch(
"app.assets.services.path_utils.get_comfy_models_folders",
return_value=[
("checkpoints", [str(shared_root)]),
("loras", [str(shared_root)]),
],
):
matches = get_model_folder_matches(str(f), primary_model_folder="loras")
assert matches == ["loras", "checkpoints"]
def test_registered_model_folder_can_contain_slash(self, fake_dirs): def test_registered_model_folder_can_contain_slash(self, fake_dirs):
nested_model_dir = fake_dirs["models"].parent / "text_encoders" / "clip" nested_model_dir = fake_dirs["models"].parent / "text_encoders" / "clip"
@ -264,11 +310,15 @@ class TestGetAssetPathInfo:
], ],
): ):
context = resolve_asset_path_context(str(f)) context = resolve_asset_path_context(str(f))
response_info = get_asset_response_path_info(str(f))
assert context.asset_type == "model" assert context.asset_type == "model"
assert context.model_folder == "text_encoders/clip" assert context.model_folder == "text_encoders/clip"
assert context.relative_path == "clip.safetensors" assert context.relative_path == "clip.safetensors"
assert response_info.model_folder == "text_encoders/clip"
assert response_info.model_folders == ["text_encoders/clip", "text_encoders"]
def test_deepest_registered_model_base_wins_independent_of_registration_order( def test_deepest_registered_model_base_wins_independent_of_registration_order(
self, fake_dirs self, fake_dirs
): ):