mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-06-26 01:39:25 +08:00
spike: expose plural model folder memberships
Amp-Thread-ID: https://ampcode.com/threads/T-019e5117-c707-729d-bf98-dce718fe64d5 Co-authored-by: Amp <amp@ampcode.com>
This commit is contained in:
parent
70de84cac7
commit
f74df348b6
@ -207,11 +207,13 @@ def _build_asset_response(result: schemas.AssetDetailResult | schemas.UploadResu
|
||||
|
||||
asset_type = None
|
||||
model_folder = None
|
||||
model_folders = None
|
||||
file_path = None
|
||||
display_name = None
|
||||
if path_info:
|
||||
asset_type = path_info.asset_type
|
||||
model_folder = path_info.model_folder
|
||||
model_folders = path_info.model_folders
|
||||
file_path = path_info.file_path
|
||||
display_name = path_info.display_name
|
||||
|
||||
@ -224,6 +226,7 @@ def _build_asset_response(result: schemas.AssetDetailResult | schemas.UploadResu
|
||||
size=int(result.asset.size_bytes) if result.asset else None,
|
||||
mime_type=result.asset.mime_type if result.asset else None,
|
||||
model_folder=model_folder,
|
||||
model_folders=model_folders,
|
||||
asset_type=asset_type,
|
||||
tags=result.tags,
|
||||
preview_url=preview_url,
|
||||
|
||||
@ -25,6 +25,10 @@ class Asset(BaseModel):
|
||||
default=None,
|
||||
description="Exact, case-sensitive registered ComfyUI model folder name. Present only when asset_type is `model`.",
|
||||
)
|
||||
model_folders: list[str] | None = Field(
|
||||
default=None,
|
||||
description="Exact, case-sensitive registered ComfyUI model folder names whose roots contain this asset. Present only when asset_type is `model`. This is plural membership for shared-root spike cases; `model_folder` remains the primary classification.",
|
||||
)
|
||||
asset_type: Literal["model", "input", "output", "temp"] | None = None
|
||||
tags: list[str] = Field(default_factory=list)
|
||||
preview_url: str | None = None
|
||||
|
||||
@ -24,6 +24,7 @@ class AssetPathInfo:
|
||||
class AssetResponsePathInfo(AssetPathInfo):
|
||||
file_path: str
|
||||
display_name: str | None
|
||||
model_folders: list[str] | None = None
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
@ -118,6 +119,39 @@ def _normalize_relative_path(relative_path: str) -> str | None:
|
||||
return "/".join(parts)
|
||||
|
||||
|
||||
def get_model_folder_matches(
|
||||
file_path: str,
|
||||
primary_model_folder: str | None = None,
|
||||
) -> list[str]:
|
||||
"""Return all registered model-folder names whose roots contain ``file_path``.
|
||||
|
||||
This is the plural-membership spike counterpart to the singular
|
||||
``model_folder`` classification. The singular classification still chooses
|
||||
one primary folder (deepest root, then registry order), but broad shared
|
||||
roots can make a physical file visible through several registered buckets.
|
||||
|
||||
Results preserve registry order, with ``primary_model_folder`` moved to the
|
||||
front when provided. Duplicate folder names are collapsed.
|
||||
"""
|
||||
fp_abs = os.path.abspath(file_path)
|
||||
matches: list[str] = []
|
||||
seen: set[str] = set()
|
||||
|
||||
for model_folder, bases in get_comfy_models_folders():
|
||||
for base in bases:
|
||||
if Path(fp_abs).is_relative_to(os.path.abspath(base)):
|
||||
if model_folder not in seen:
|
||||
matches.append(model_folder)
|
||||
seen.add(model_folder)
|
||||
break
|
||||
|
||||
if primary_model_folder in seen and matches[0] != primary_model_folder:
|
||||
matches.remove(primary_model_folder)
|
||||
matches.insert(0, primary_model_folder)
|
||||
|
||||
return matches
|
||||
|
||||
|
||||
def resolve_asset_path_context(file_path: str) -> AssetPathContext:
|
||||
"""Resolve a path against Core's asset roots and model-folder registration.
|
||||
|
||||
@ -305,6 +339,9 @@ def get_asset_response_path_info(file_path: str) -> AssetResponsePathInfo:
|
||||
model_folder=context.model_folder,
|
||||
file_path=logical_file_path,
|
||||
display_name=display_name,
|
||||
model_folders=get_model_folder_matches(file_path, context.model_folder)
|
||||
if context.asset_type == "model"
|
||||
else None,
|
||||
)
|
||||
|
||||
|
||||
@ -362,6 +399,7 @@ def get_stored_asset_response_path_info(
|
||||
model_folder=model_folder,
|
||||
file_path=logical_file_path,
|
||||
display_name=display_name,
|
||||
model_folders=get_model_folder_matches(file_path, model_folder),
|
||||
)
|
||||
|
||||
root_by_type = {
|
||||
@ -379,6 +417,7 @@ def get_stored_asset_response_path_info(
|
||||
model_folder=None,
|
||||
file_path=logical_file_path,
|
||||
display_name=display_name,
|
||||
model_folders=None,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@ -239,9 +239,44 @@ class TestBuildAssetResponsePathFields:
|
||||
|
||||
assert asset.asset_type == "model"
|
||||
assert asset.model_folder == "checkpoints"
|
||||
assert asset.model_folders == ["checkpoints"]
|
||||
assert asset.display_name == "sub/model.safetensors"
|
||||
assert asset.file_path == "models/checkpoints/sub/model.safetensors"
|
||||
|
||||
def test_model_response_includes_plural_model_folder_memberships(
|
||||
self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch
|
||||
):
|
||||
shared_dir = tmp_path / "models"
|
||||
shared_dir.mkdir()
|
||||
model_path = shared_dir / "checkpoints" / "model.safetensors"
|
||||
model_path.parent.mkdir()
|
||||
model_path.write_text("data")
|
||||
monkeypatch.setattr(
|
||||
"app.assets.services.path_utils.get_comfy_models_folders",
|
||||
lambda: [
|
||||
("checkpoints", [str(shared_dir)]),
|
||||
("loras", [str(shared_dir)]),
|
||||
("vae", [str(shared_dir)]),
|
||||
],
|
||||
)
|
||||
|
||||
asset = _build_asset_response(
|
||||
_asset_detail_result(
|
||||
_reference_data(
|
||||
name="model.safetensors",
|
||||
file_path=str(model_path),
|
||||
asset_type="model",
|
||||
model_folder="checkpoints",
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
assert asset.asset_type == "model"
|
||||
assert asset.model_folder == "checkpoints"
|
||||
assert asset.model_folders == ["checkpoints", "loras", "vae"]
|
||||
assert asset.display_name == "checkpoints/model.safetensors"
|
||||
assert asset.file_path == "models/checkpoints/checkpoints/model.safetensors"
|
||||
|
||||
def test_input_output_response_fields_use_persisted_classification(
|
||||
self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch
|
||||
):
|
||||
@ -284,10 +319,12 @@ class TestBuildAssetResponsePathFields:
|
||||
|
||||
assert input_asset.asset_type == "input"
|
||||
assert input_asset.model_folder is None
|
||||
assert input_asset.model_folders is None
|
||||
assert input_asset.display_name == "sub/image.png"
|
||||
assert input_asset.file_path == "input/sub/image.png"
|
||||
assert output_asset.asset_type == "output"
|
||||
assert output_asset.model_folder is None
|
||||
assert output_asset.model_folders is None
|
||||
assert output_asset.display_name == "result.png"
|
||||
assert output_asset.file_path == "output/result.png"
|
||||
|
||||
@ -298,6 +335,7 @@ class TestBuildAssetResponsePathFields:
|
||||
|
||||
assert asset.asset_type is None
|
||||
assert asset.model_folder is None
|
||||
assert asset.model_folders is None
|
||||
assert asset.display_name is None
|
||||
assert asset.file_path is None
|
||||
|
||||
|
||||
@ -9,6 +9,7 @@ import pytest
|
||||
from app.assets.services.path_utils import (
|
||||
compute_relative_filename,
|
||||
get_comfy_models_folders,
|
||||
get_model_folder_matches,
|
||||
get_asset_category_and_relative_path,
|
||||
get_asset_path_info,
|
||||
get_asset_response_path_info,
|
||||
@ -135,6 +136,7 @@ class TestGetAssetPathInfo:
|
||||
|
||||
assert response_info.asset_type == "model"
|
||||
assert response_info.model_folder == "controlnet"
|
||||
assert response_info.model_folders == ["controlnet"]
|
||||
assert response_info.file_path == "models/controlnet/pose.safetensors"
|
||||
assert response_info.display_name == "pose.safetensors"
|
||||
|
||||
@ -158,6 +160,7 @@ class TestGetAssetPathInfo:
|
||||
|
||||
assert response_a.asset_type == response_b.asset_type == "model"
|
||||
assert response_a.model_folder == response_b.model_folder == "checkpoints"
|
||||
assert response_a.model_folders == response_b.model_folders == ["checkpoints"]
|
||||
assert response_a.file_path == "models/checkpoints/subdir/model_a.safetensors"
|
||||
assert response_b.file_path == "models/checkpoints/subdir/model_b.safetensors"
|
||||
assert response_a.display_name == "subdir/model_a.safetensors"
|
||||
@ -196,6 +199,7 @@ class TestGetAssetPathInfo:
|
||||
response_info = get_asset_response_path_info(str(f))
|
||||
assert response_info.file_path == "input/subdir/photo.png"
|
||||
assert response_info.display_name == "subdir/photo.png"
|
||||
assert response_info.model_folders is None
|
||||
|
||||
def test_output_backed_registered_model_folder_is_model(self, fake_dirs):
|
||||
output_checkpoints_dir = fake_dirs["output"] / "checkpoints"
|
||||
@ -216,6 +220,48 @@ class TestGetAssetPathInfo:
|
||||
|
||||
assert response_info.file_path == "models/checkpoints/saved.safetensors"
|
||||
assert response_info.display_name == "saved.safetensors"
|
||||
assert response_info.model_folders == ["checkpoints"]
|
||||
|
||||
def test_shared_root_returns_all_matching_model_folders(self, fake_dirs):
|
||||
shared_root = fake_dirs["models"].parent / "shared"
|
||||
shared_root.mkdir()
|
||||
f = shared_root / "checkpoints" / "foo.safetensors"
|
||||
f.parent.mkdir()
|
||||
f.touch()
|
||||
|
||||
with patch(
|
||||
"app.assets.services.path_utils.get_comfy_models_folders",
|
||||
return_value=[
|
||||
("checkpoints", [str(shared_root)]),
|
||||
("loras", [str(shared_root)]),
|
||||
("vae", [str(shared_root)]),
|
||||
],
|
||||
):
|
||||
context = resolve_asset_path_context(str(f))
|
||||
response_info = get_asset_response_path_info(str(f))
|
||||
|
||||
assert context.model_folder == "checkpoints"
|
||||
assert response_info.model_folder == "checkpoints"
|
||||
assert response_info.model_folders == ["checkpoints", "loras", "vae"]
|
||||
assert response_info.display_name == "checkpoints/foo.safetensors"
|
||||
assert response_info.file_path == "models/checkpoints/checkpoints/foo.safetensors"
|
||||
|
||||
def test_model_folder_matches_can_move_primary_first(self, fake_dirs):
|
||||
shared_root = fake_dirs["models"].parent / "shared"
|
||||
shared_root.mkdir()
|
||||
f = shared_root / "foo.safetensors"
|
||||
f.touch()
|
||||
|
||||
with patch(
|
||||
"app.assets.services.path_utils.get_comfy_models_folders",
|
||||
return_value=[
|
||||
("checkpoints", [str(shared_root)]),
|
||||
("loras", [str(shared_root)]),
|
||||
],
|
||||
):
|
||||
matches = get_model_folder_matches(str(f), primary_model_folder="loras")
|
||||
|
||||
assert matches == ["loras", "checkpoints"]
|
||||
|
||||
def test_registered_model_folder_can_contain_slash(self, fake_dirs):
|
||||
nested_model_dir = fake_dirs["models"].parent / "text_encoders" / "clip"
|
||||
@ -264,11 +310,15 @@ class TestGetAssetPathInfo:
|
||||
],
|
||||
):
|
||||
context = resolve_asset_path_context(str(f))
|
||||
response_info = get_asset_response_path_info(str(f))
|
||||
|
||||
assert context.asset_type == "model"
|
||||
assert context.model_folder == "text_encoders/clip"
|
||||
assert context.relative_path == "clip.safetensors"
|
||||
|
||||
assert response_info.model_folder == "text_encoders/clip"
|
||||
assert response_info.model_folders == ["text_encoders/clip", "text_encoders"]
|
||||
|
||||
def test_deepest_registered_model_base_wins_independent_of_registration_order(
|
||||
self, fake_dirs
|
||||
):
|
||||
|
||||
Loading…
Reference in New Issue
Block a user