mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-06-26 01:39:25 +08:00
Merge branch 'master' into cloud-openapi-projection
This commit is contained in:
commit
2c313ae675
@ -219,7 +219,6 @@ async def list_assets_route(request: web.Request) -> web.Response:
|
|||||||
exclude_tags=q.exclude_tags,
|
exclude_tags=q.exclude_tags,
|
||||||
name_contains=q.name_contains,
|
name_contains=q.name_contains,
|
||||||
metadata_filter=q.metadata_filter,
|
metadata_filter=q.metadata_filter,
|
||||||
job_ids=q.job_ids,
|
|
||||||
limit=q.limit,
|
limit=q.limit,
|
||||||
offset=q.offset,
|
offset=q.offset,
|
||||||
sort=sort,
|
sort=sort,
|
||||||
|
|||||||
@ -1,5 +1,4 @@
|
|||||||
import json
|
import json
|
||||||
import uuid
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, Literal
|
from typing import Any, Literal
|
||||||
|
|
||||||
@ -54,7 +53,6 @@ class ListAssetsQuery(BaseModel):
|
|||||||
include_tags: list[str] = Field(default_factory=list)
|
include_tags: list[str] = Field(default_factory=list)
|
||||||
exclude_tags: list[str] = Field(default_factory=list)
|
exclude_tags: list[str] = Field(default_factory=list)
|
||||||
name_contains: str | None = None
|
name_contains: str | None = None
|
||||||
job_ids: list[str] = Field(default_factory=list, max_length=500)
|
|
||||||
|
|
||||||
# Accept either a JSON string (query param) or a dict
|
# Accept either a JSON string (query param) or a dict
|
||||||
metadata_filter: dict[str, Any] | None = None
|
metadata_filter: dict[str, Any] | None = None
|
||||||
@ -88,40 +86,6 @@ class ListAssetsQuery(BaseModel):
|
|||||||
return out
|
return out
|
||||||
return v
|
return v
|
||||||
|
|
||||||
@field_validator("job_ids", mode="before")
|
|
||||||
@classmethod
|
|
||||||
def _split_and_validate_job_ids(cls, v):
|
|
||||||
# Accept "uuid1,uuid2" or ["uuid1","uuid2"] or repeated query params.
|
|
||||||
# Each entry must parse as a UUID; canonicalized to lowercase hyphenated form.
|
|
||||||
if v is None:
|
|
||||||
return []
|
|
||||||
if isinstance(v, str):
|
|
||||||
raw = [t.strip() for t in v.split(",") if t.strip()]
|
|
||||||
elif isinstance(v, list):
|
|
||||||
raw = []
|
|
||||||
for item in v:
|
|
||||||
if not isinstance(item, str):
|
|
||||||
raise ValueError(
|
|
||||||
f"job_ids entries must be strings, got {type(item).__name__}"
|
|
||||||
)
|
|
||||||
raw.extend([t.strip() for t in item.split(",") if t.strip()])
|
|
||||||
else:
|
|
||||||
raise ValueError(
|
|
||||||
f"job_ids must be a string or list of strings, got {type(v).__name__}"
|
|
||||||
)
|
|
||||||
|
|
||||||
out: list[str] = []
|
|
||||||
seen: set[str] = set()
|
|
||||||
for s in raw:
|
|
||||||
try:
|
|
||||||
canonical = str(uuid.UUID(s))
|
|
||||||
except ValueError as e:
|
|
||||||
raise ValueError(f"job_ids must be UUIDs: {s!r}") from e
|
|
||||||
if canonical not in seen:
|
|
||||||
seen.add(canonical)
|
|
||||||
out.append(canonical)
|
|
||||||
return out
|
|
||||||
|
|
||||||
@field_validator("metadata_filter", mode="before")
|
@field_validator("metadata_filter", mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def _parse_metadata_json(cls, v):
|
def _parse_metadata_json(cls, v):
|
||||||
|
|||||||
@ -264,7 +264,6 @@ def list_references_page(
|
|||||||
include_tags: Sequence[str] | None = None,
|
include_tags: Sequence[str] | None = None,
|
||||||
exclude_tags: Sequence[str] | None = None,
|
exclude_tags: Sequence[str] | None = None,
|
||||||
metadata_filter: dict | None = None,
|
metadata_filter: dict | None = None,
|
||||||
job_ids: Sequence[str] | None = None,
|
|
||||||
sort: str | None = None,
|
sort: str | None = None,
|
||||||
order: str | None = None,
|
order: str | None = None,
|
||||||
after_cursor_value: object | None = None,
|
after_cursor_value: object | None = None,
|
||||||
@ -294,9 +293,6 @@ def list_references_page(
|
|||||||
escaped, esc = escape_sql_like_string(name_contains)
|
escaped, esc = escape_sql_like_string(name_contains)
|
||||||
base = base.where(AssetReference.name.ilike(f"%{escaped}%", escape=esc))
|
base = base.where(AssetReference.name.ilike(f"%{escaped}%", escape=esc))
|
||||||
|
|
||||||
if job_ids:
|
|
||||||
base = base.where(AssetReference.job_id.in_(list(job_ids)))
|
|
||||||
|
|
||||||
base = apply_tag_filters(base, include_tags, exclude_tags)
|
base = apply_tag_filters(base, include_tags, exclude_tags)
|
||||||
base = apply_metadata_filter(base, metadata_filter)
|
base = apply_metadata_filter(base, metadata_filter)
|
||||||
|
|
||||||
@ -349,8 +345,6 @@ def list_references_page(
|
|||||||
count_stmt = count_stmt.where(
|
count_stmt = count_stmt.where(
|
||||||
AssetReference.name.ilike(f"%{escaped}%", escape=esc)
|
AssetReference.name.ilike(f"%{escaped}%", escape=esc)
|
||||||
)
|
)
|
||||||
if job_ids:
|
|
||||||
count_stmt = count_stmt.where(AssetReference.job_id.in_(list(job_ids)))
|
|
||||||
count_stmt = apply_tag_filters(count_stmt, include_tags, exclude_tags)
|
count_stmt = apply_tag_filters(count_stmt, include_tags, exclude_tags)
|
||||||
count_stmt = apply_metadata_filter(count_stmt, metadata_filter)
|
count_stmt = apply_metadata_filter(count_stmt, metadata_filter)
|
||||||
|
|
||||||
|
|||||||
@ -274,7 +274,6 @@ def list_assets_page(
|
|||||||
exclude_tags: Sequence[str] | None = None,
|
exclude_tags: Sequence[str] | None = None,
|
||||||
name_contains: str | None = None,
|
name_contains: str | None = None,
|
||||||
metadata_filter: dict | None = None,
|
metadata_filter: dict | None = None,
|
||||||
job_ids: Sequence[str] | None = None,
|
|
||||||
limit: int = 20,
|
limit: int = 20,
|
||||||
offset: int = 0,
|
offset: int = 0,
|
||||||
sort: str = "created_at",
|
sort: str = "created_at",
|
||||||
@ -320,7 +319,6 @@ def list_assets_page(
|
|||||||
exclude_tags=exclude_tags,
|
exclude_tags=exclude_tags,
|
||||||
name_contains=name_contains,
|
name_contains=name_contains,
|
||||||
metadata_filter=metadata_filter,
|
metadata_filter=metadata_filter,
|
||||||
job_ids=job_ids,
|
|
||||||
limit=fetch_limit,
|
limit=fetch_limit,
|
||||||
offset=offset,
|
offset=offset,
|
||||||
sort=sort,
|
sort=sort,
|
||||||
|
|||||||
66
comfy_execution/asset_enrichment.py
Normal file
66
comfy_execution/asset_enrichment.py
Normal file
@ -0,0 +1,66 @@
|
|||||||
|
"""Enrich executed-node output entries with asset id."""
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
|
||||||
|
|
||||||
|
def enrich_output_with_assets(output_ui: dict) -> dict:
|
||||||
|
"""Register file-type output entries as assets and inject their ``id``.
|
||||||
|
|
||||||
|
Runs at output-processing time, once per produced output, when
|
||||||
|
--enable-assets is set. Returns a new dict; entries without a resolvable
|
||||||
|
on-disk file path are left unchanged. Errors are caught per-entry so a
|
||||||
|
failure never blocks execution or the other entries.
|
||||||
|
"""
|
||||||
|
from comfy.cli_args import args
|
||||||
|
if not args.enable_assets:
|
||||||
|
return output_ui
|
||||||
|
|
||||||
|
import folder_paths
|
||||||
|
from app.assets.services.ingest import register_file_in_place, DependencyMissingError
|
||||||
|
|
||||||
|
enriched = {}
|
||||||
|
for key, entries in output_ui.items():
|
||||||
|
if not isinstance(entries, list):
|
||||||
|
enriched[key] = entries
|
||||||
|
continue
|
||||||
|
new_entries = []
|
||||||
|
for entry in entries:
|
||||||
|
if not isinstance(entry, dict) or "filename" not in entry or "type" not in entry:
|
||||||
|
new_entries.append(entry)
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
base = folder_paths.get_directory_by_type(entry["type"])
|
||||||
|
if base is None:
|
||||||
|
new_entries.append(entry)
|
||||||
|
continue
|
||||||
|
base_abs = os.path.abspath(base)
|
||||||
|
abs_path = os.path.abspath(os.path.join(base_abs, entry.get("subfolder") or "", entry["filename"]))
|
||||||
|
try:
|
||||||
|
if os.path.commonpath([base_abs, abs_path]) != base_abs:
|
||||||
|
raise ValueError("escapes base")
|
||||||
|
except ValueError:
|
||||||
|
logging.warning("Asset enrichment skipped (path escapes base): %s", entry.get("filename"))
|
||||||
|
new_entries.append(entry)
|
||||||
|
continue
|
||||||
|
if not os.path.isfile(abs_path):
|
||||||
|
new_entries.append(entry)
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Register unconditionally: the file was just produced, and
|
||||||
|
# register_file_in_place re-hashes so an overwritten path can
|
||||||
|
# never carry a stale id.
|
||||||
|
result = register_file_in_place(
|
||||||
|
abs_path=abs_path,
|
||||||
|
name=entry["filename"],
|
||||||
|
tags=[entry["type"]],
|
||||||
|
)
|
||||||
|
|
||||||
|
entry = dict(entry)
|
||||||
|
entry["id"] = result.ref.id
|
||||||
|
except DependencyMissingError:
|
||||||
|
logging.warning("Asset enrichment skipped (blake3 not available): %s", entry.get("filename"))
|
||||||
|
except Exception:
|
||||||
|
logging.warning("Failed to enrich output entry with asset id: %s", entry.get("filename"), exc_info=True)
|
||||||
|
new_entries.append(entry)
|
||||||
|
enriched[key] = new_entries
|
||||||
|
return enriched
|
||||||
@ -25,10 +25,9 @@ def validate_job_id(value) -> str:
|
|||||||
|
|
||||||
Job ids must be UUIDs in the canonical lowercase hyphenated form. The id
|
Job ids must be UUIDs in the canonical lowercase hyphenated form. The id
|
||||||
is stored and compared verbatim everywhere downstream — history keys,
|
is stored and compared verbatim everywhere downstream — history keys,
|
||||||
websocket events, /interrupt matching, and the assets ``job_ids`` filter
|
websocket events, and /interrupt matching — so accepting another spelling
|
||||||
(a String(36) column matched exactly) — so accepting another spelling
|
would silently rewrite the client's id and then miss every exact-match
|
||||||
would either rewrite the client's id behind its back or mint a job whose
|
lookup. Rejecting loudly beats that.
|
||||||
outputs the filter can never find. Rejecting loudly beats both.
|
|
||||||
|
|
||||||
Returns the id unchanged. Raises ValueError when the value is not a
|
Returns the id unchanged. Raises ValueError when the value is not a
|
||||||
string in canonical UUID form.
|
string in canonical UUID form.
|
||||||
|
|||||||
@ -40,6 +40,7 @@ from comfy_execution.graph_utils import GraphBuilder, is_link
|
|||||||
from comfy_execution.validation import validate_node_input
|
from comfy_execution.validation import validate_node_input
|
||||||
from comfy_execution.progress import get_progress_state, reset_progress_state, add_progress_handler, WebUIProgressHandler
|
from comfy_execution.progress import get_progress_state, reset_progress_state, add_progress_handler, WebUIProgressHandler
|
||||||
from comfy_execution.utils import CurrentNodeContext
|
from comfy_execution.utils import CurrentNodeContext
|
||||||
|
from comfy_execution.asset_enrichment import enrich_output_with_assets
|
||||||
from comfy_api.internal import _ComfyNodeInternal, _NodeOutputInternal, first_real_override, is_class, make_locked_method_func
|
from comfy_api.internal import _ComfyNodeInternal, _NodeOutputInternal, first_real_override, is_class, make_locked_method_func
|
||||||
from comfy_api.latest import io, _io
|
from comfy_api.latest import io, _io
|
||||||
from comfy_execution.cache_provider import _has_cache_providers, _get_cache_providers, _logger as _cache_logger
|
from comfy_execution.cache_provider import _has_cache_providers, _get_cache_providers, _logger as _cache_logger
|
||||||
@ -418,6 +419,7 @@ def _is_intermediate_output(dynprompt, node_id):
|
|||||||
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
|
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
|
||||||
return getattr(class_def, 'HAS_INTERMEDIATE_OUTPUT', False)
|
return getattr(class_def, 'HAS_INTERMEDIATE_OUTPUT', False)
|
||||||
|
|
||||||
|
|
||||||
def _send_cached_ui(server, node_id, display_node_id, cached, prompt_id, ui_outputs):
|
def _send_cached_ui(server, node_id, display_node_id, cached, prompt_id, ui_outputs):
|
||||||
if server.client_id is None:
|
if server.client_id is None:
|
||||||
return
|
return
|
||||||
@ -552,6 +554,10 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
|
|||||||
asyncio.create_task(await_completion())
|
asyncio.create_task(await_completion())
|
||||||
return (ExecutionResult.PENDING, None, None)
|
return (ExecutionResult.PENDING, None, None)
|
||||||
if len(output_ui) > 0:
|
if len(output_ui) > 0:
|
||||||
|
# Enrich at output-processing time (not in the send path) so assets
|
||||||
|
# are registered even when no client is connected, and the asset id
|
||||||
|
# flows into ui_outputs and the cache alongside the raw entries.
|
||||||
|
output_ui = enrich_output_with_assets(output_ui)
|
||||||
ui_outputs[unique_id] = {
|
ui_outputs[unique_id] = {
|
||||||
"meta": {
|
"meta": {
|
||||||
"node_id": unique_id,
|
"node_id": unique_id,
|
||||||
|
|||||||
@ -158,56 +158,6 @@ class TestListReferencesPage:
|
|||||||
refs, _, _ = list_references_page(session, sort="name", order="asc")
|
refs, _, _ = list_references_page(session, sort="name", order="asc")
|
||||||
assert refs[0].name == "large"
|
assert refs[0].name == "large"
|
||||||
|
|
||||||
def test_job_ids_filter(self, session: Session):
|
|
||||||
asset = _make_asset(session, "hash1")
|
|
||||||
job_a = str(uuid.uuid4())
|
|
||||||
job_b = str(uuid.uuid4())
|
|
||||||
ref_a = _make_reference(session, asset, name="from_job_a")
|
|
||||||
ref_a.job_id = job_a
|
|
||||||
ref_b = _make_reference(session, asset, name="from_job_b")
|
|
||||||
ref_b.job_id = job_b
|
|
||||||
_make_reference(session, asset, name="no_job")
|
|
||||||
session.commit()
|
|
||||||
|
|
||||||
# Single job filter
|
|
||||||
refs, _, total = list_references_page(session, job_ids=[job_a])
|
|
||||||
assert total == 1
|
|
||||||
assert refs[0].name == "from_job_a"
|
|
||||||
|
|
||||||
# Multi-job filter (IN)
|
|
||||||
refs, _, total = list_references_page(session, job_ids=[job_a, job_b])
|
|
||||||
names = sorted(r.name for r in refs)
|
|
||||||
assert total == 2
|
|
||||||
assert names == ["from_job_a", "from_job_b"]
|
|
||||||
|
|
||||||
# Unknown job id matches nothing
|
|
||||||
refs, _, total = list_references_page(session, job_ids=[str(uuid.uuid4())])
|
|
||||||
assert total == 0
|
|
||||||
assert refs == []
|
|
||||||
|
|
||||||
# Empty/None means no filter -> all three references
|
|
||||||
refs, _, total = list_references_page(session, job_ids=[])
|
|
||||||
assert total == 3
|
|
||||||
refs, _, total = list_references_page(session, job_ids=None)
|
|
||||||
assert total == 3
|
|
||||||
|
|
||||||
def test_job_ids_combined_with_other_filters(self, session: Session):
|
|
||||||
asset = _make_asset(session, "hash1")
|
|
||||||
job_a = str(uuid.uuid4())
|
|
||||||
ref_match = _make_reference(session, asset, name="match.bin")
|
|
||||||
ref_match.job_id = job_a
|
|
||||||
ref_wrong_name = _make_reference(session, asset, name="other.bin")
|
|
||||||
ref_wrong_name.job_id = job_a
|
|
||||||
ref_wrong_job = _make_reference(session, asset, name="match.bin")
|
|
||||||
ref_wrong_job.job_id = str(uuid.uuid4())
|
|
||||||
session.commit()
|
|
||||||
|
|
||||||
refs, _, total = list_references_page(
|
|
||||||
session, job_ids=[job_a], name_contains="match"
|
|
||||||
)
|
|
||||||
assert total == 1
|
|
||||||
assert refs[0].id == ref_match.id
|
|
||||||
|
|
||||||
|
|
||||||
class TestFetchReferenceAssetAndTags:
|
class TestFetchReferenceAssetAndTags:
|
||||||
def test_returns_none_for_nonexistent(self, session: Session):
|
def test_returns_none_for_nonexistent(self, session: Session):
|
||||||
|
|||||||
@ -1,60 +0,0 @@
|
|||||||
"""Schema-level unit tests for ListAssetsQuery (no DB required)."""
|
|
||||||
import uuid
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
from pydantic import ValidationError
|
|
||||||
|
|
||||||
from app.assets.api.schemas_in import ListAssetsQuery
|
|
||||||
|
|
||||||
|
|
||||||
class TestJobIdsValidator:
|
|
||||||
def test_csv_string_parses_and_canonicalizes(self):
|
|
||||||
a = "AAAAAAAA-BBBB-CCCC-DDDD-EEEEEEEEEEEE"
|
|
||||||
b = "11111111-2222-3333-4444-555555555555"
|
|
||||||
q = ListAssetsQuery.model_validate({"job_ids": f"{a},{b}"})
|
|
||||||
# Canonicalized to lowercase
|
|
||||||
assert q.job_ids == [a.lower(), b]
|
|
||||||
|
|
||||||
def test_repeated_query_params_as_list(self):
|
|
||||||
a = "11111111-1111-1111-1111-111111111111"
|
|
||||||
b = "22222222-2222-2222-2222-222222222222"
|
|
||||||
q = ListAssetsQuery.model_validate({"job_ids": [a, b]})
|
|
||||||
assert q.job_ids == [a, b]
|
|
||||||
|
|
||||||
def test_dedup_preserves_first_seen_order(self):
|
|
||||||
a = "11111111-1111-1111-1111-111111111111"
|
|
||||||
b = "22222222-2222-2222-2222-222222222222"
|
|
||||||
q = ListAssetsQuery.model_validate({"job_ids": [a, b, a]})
|
|
||||||
assert q.job_ids == [a, b]
|
|
||||||
|
|
||||||
def test_default_empty(self):
|
|
||||||
q = ListAssetsQuery.model_validate({})
|
|
||||||
assert q.job_ids == []
|
|
||||||
|
|
||||||
def test_invalid_uuid_rejected(self):
|
|
||||||
with pytest.raises(ValidationError) as exc:
|
|
||||||
ListAssetsQuery.model_validate({"job_ids": "not-a-uuid"})
|
|
||||||
assert "must be UUIDs" in str(exc.value)
|
|
||||||
|
|
||||||
def test_non_string_list_item_rejected(self):
|
|
||||||
with pytest.raises(ValidationError) as exc:
|
|
||||||
ListAssetsQuery.model_validate(
|
|
||||||
{"job_ids": ["11111111-1111-1111-1111-111111111111", 42]}
|
|
||||||
)
|
|
||||||
assert "must be strings" in str(exc.value)
|
|
||||||
|
|
||||||
def test_non_string_non_list_value_rejected(self):
|
|
||||||
with pytest.raises(ValidationError) as exc:
|
|
||||||
ListAssetsQuery.model_validate({"job_ids": {"bad": "shape"}})
|
|
||||||
assert "must be a string or list of strings" in str(exc.value)
|
|
||||||
|
|
||||||
def test_max_length_enforced(self):
|
|
||||||
too_many = [str(uuid.uuid4()) for _ in range(501)]
|
|
||||||
with pytest.raises(ValidationError) as exc:
|
|
||||||
ListAssetsQuery.model_validate({"job_ids": too_many})
|
|
||||||
assert exc.value.errors()[0]["type"] == "too_long"
|
|
||||||
|
|
||||||
def test_max_length_boundary_accepted(self):
|
|
||||||
at_cap = [str(uuid.uuid4()) for _ in range(500)]
|
|
||||||
q = ListAssetsQuery.model_validate({"job_ids": at_cap})
|
|
||||||
assert len(q.job_ids) == 500
|
|
||||||
@ -1,9 +1,9 @@
|
|||||||
"""POST /prompt enforces canonical-UUID job ids at creation time.
|
"""POST /prompt enforces canonical-UUID job ids at creation time.
|
||||||
|
|
||||||
Lives in assets_test because it uses this suite's booted-server fixture and
|
Lives in assets_test because it uses this suite's booted-server fixture. The
|
||||||
because the invariant exists for the assets pipeline: the GET /api/assets
|
invariant itself is pipeline-wide: a job id is stored and compared verbatim
|
||||||
``job_ids`` filter matches stored job ids exactly, so a job minted with a
|
downstream — history keys, websocket correlation, and /interrupt matching —
|
||||||
non-canonical id would produce assets the filter can never find.
|
so a job minted with a non-canonical id would miss every exact-match lookup.
|
||||||
|
|
||||||
The prompt bodies here are intentionally invalid workflows — prompt_id
|
The prompt bodies here are intentionally invalid workflows — prompt_id
|
||||||
validation happens before workflow validation, so a rejected id returns
|
validation happens before workflow validation, so a rejected id returns
|
||||||
|
|||||||
205
tests-unit/execution_test/test_enrich_output.py
Normal file
205
tests-unit/execution_test/test_enrich_output.py
Normal file
@ -0,0 +1,205 @@
|
|||||||
|
"""Tests for enrich_output_with_assets in comfy_execution/asset_enrichment.py."""
|
||||||
|
import os
|
||||||
|
import types
|
||||||
|
import unittest
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
|
||||||
|
def _make_args(enable_assets: bool):
|
||||||
|
a = types.SimpleNamespace()
|
||||||
|
a.enable_assets = enable_assets
|
||||||
|
return a
|
||||||
|
|
||||||
|
|
||||||
|
def _make_register_result(ref_id="ref-id-2"):
|
||||||
|
result = MagicMock()
|
||||||
|
result.ref.id = ref_id
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
# Platform-appropriate absolute base. tempfile.gettempdir() returns C:\... on
|
||||||
|
# Windows and /tmp on POSIX, so containment via commonpath behaves naturally.
|
||||||
|
_DEFAULT_BASE = os.path.join(__import__("tempfile").gettempdir(), "asset-enrichment-test-base")
|
||||||
|
|
||||||
|
|
||||||
|
def _mocked_modules(*, enable_assets=True, register_file_in_place=None, directory=_DEFAULT_BASE):
|
||||||
|
return {
|
||||||
|
"comfy.cli_args": MagicMock(args=_make_args(enable_assets)),
|
||||||
|
"folder_paths": MagicMock(get_directory_by_type=MagicMock(return_value=directory)),
|
||||||
|
"app.assets.services.ingest": MagicMock(
|
||||||
|
register_file_in_place=register_file_in_place or MagicMock(return_value=_make_register_result()),
|
||||||
|
DependencyMissingError=type("DependencyMissingError", (Exception,), {}),
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _call(output_ui, *, enable_assets=True, file_exists=True, register_result=None, directory=_DEFAULT_BASE):
|
||||||
|
register_mock = MagicMock(return_value=register_result or _make_register_result())
|
||||||
|
mocked = _mocked_modules(
|
||||||
|
enable_assets=enable_assets,
|
||||||
|
register_file_in_place=register_mock,
|
||||||
|
directory=directory,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Only os.path.isfile is patched — abspath/join must run natively so the
|
||||||
|
# containment check sees real platform paths.
|
||||||
|
with patch.dict("sys.modules", mocked), \
|
||||||
|
patch("os.path.isfile", return_value=file_exists):
|
||||||
|
import importlib
|
||||||
|
import comfy_execution.asset_enrichment as mod
|
||||||
|
importlib.reload(mod)
|
||||||
|
return mod.enrich_output_with_assets(output_ui)
|
||||||
|
|
||||||
|
|
||||||
|
class TestEnrichOutputWithAssets(unittest.TestCase):
|
||||||
|
|
||||||
|
def test_disabled_returns_unchanged(self):
|
||||||
|
output = {"images": [{"filename": "a.png", "subfolder": "", "type": "output"}]}
|
||||||
|
result = _call(output, enable_assets=False)
|
||||||
|
self.assertNotIn("id", result["images"][0])
|
||||||
|
|
||||||
|
def test_non_list_value_passed_through(self):
|
||||||
|
output = {"text": "hello"}
|
||||||
|
result = _call(output)
|
||||||
|
self.assertEqual(result["text"], "hello")
|
||||||
|
|
||||||
|
def test_entry_without_filename_unchanged(self):
|
||||||
|
output = {"latent": [{"subfolder": "", "type": "output"}]}
|
||||||
|
result = _call(output)
|
||||||
|
self.assertNotIn("id", result["latent"][0])
|
||||||
|
|
||||||
|
def test_entry_without_type_unchanged(self):
|
||||||
|
output = {"data": [{"filename": "a.png", "subfolder": ""}]}
|
||||||
|
result = _call(output)
|
||||||
|
self.assertNotIn("id", result["data"][0])
|
||||||
|
|
||||||
|
def test_file_not_on_disk_unchanged(self):
|
||||||
|
output = {"images": [{"filename": "missing.png", "subfolder": "", "type": "output"}]}
|
||||||
|
result = _call(output, file_exists=False)
|
||||||
|
self.assertNotIn("id", result["images"][0])
|
||||||
|
|
||||||
|
def test_unknown_type_returns_none_directory_unchanged(self):
|
||||||
|
output = {"images": [{"filename": "a.png", "subfolder": "", "type": "unknown"}]}
|
||||||
|
result = _call(output, directory=None)
|
||||||
|
self.assertNotIn("id", result["images"][0])
|
||||||
|
|
||||||
|
def test_register_injects_only_id(self):
|
||||||
|
reg = _make_register_result(ref_id="inline-ref")
|
||||||
|
output = {"images": [{"filename": "new.png", "subfolder": "", "type": "output"}]}
|
||||||
|
result = _call(output, register_result=reg)
|
||||||
|
img = result["images"][0]
|
||||||
|
self.assertEqual(img["id"], "inline-ref")
|
||||||
|
# Only id is injected — no asset_hash, name, preview_url, size
|
||||||
|
self.assertNotIn("asset_hash", img)
|
||||||
|
self.assertNotIn("name", img)
|
||||||
|
self.assertNotIn("preview_url", img)
|
||||||
|
self.assertNotIn("size", img)
|
||||||
|
|
||||||
|
def test_register_called_per_entry(self):
|
||||||
|
register_mock = MagicMock(return_value=_make_register_result())
|
||||||
|
mocked = _mocked_modules(register_file_in_place=register_mock)
|
||||||
|
output = {
|
||||||
|
"images": [
|
||||||
|
{"filename": "a.png", "subfolder": "", "type": "output"},
|
||||||
|
{"filename": "b.png", "subfolder": "", "type": "output"},
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
with patch.dict("sys.modules", mocked), \
|
||||||
|
patch("os.path.isfile", return_value=True):
|
||||||
|
import importlib
|
||||||
|
import comfy_execution.asset_enrichment as mod
|
||||||
|
importlib.reload(mod)
|
||||||
|
mod.enrich_output_with_assets(output)
|
||||||
|
|
||||||
|
self.assertEqual(register_mock.call_count, 2)
|
||||||
|
|
||||||
|
def test_original_entry_not_mutated(self):
|
||||||
|
orig = {"filename": "a.png", "subfolder": "", "type": "output"}
|
||||||
|
output = {"images": [orig]}
|
||||||
|
_call(output)
|
||||||
|
self.assertNotIn("id", orig)
|
||||||
|
|
||||||
|
def test_enrichment_error_does_not_block_sibling_entries(self):
|
||||||
|
call_count = [0]
|
||||||
|
good_reg = _make_register_result(ref_id="good-ref")
|
||||||
|
|
||||||
|
def register_side_effect(abs_path, name, tags):
|
||||||
|
call_count[0] += 1
|
||||||
|
if call_count[0] == 1:
|
||||||
|
raise RuntimeError("boom")
|
||||||
|
return good_reg
|
||||||
|
|
||||||
|
mocked = _mocked_modules(register_file_in_place=register_side_effect)
|
||||||
|
|
||||||
|
output = {
|
||||||
|
"images": [
|
||||||
|
{"filename": "bad.png", "subfolder": "", "type": "output"},
|
||||||
|
{"filename": "good.png", "subfolder": "", "type": "output"},
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
with patch.dict("sys.modules", mocked), \
|
||||||
|
patch("os.path.isfile", return_value=True):
|
||||||
|
import importlib
|
||||||
|
import comfy_execution.asset_enrichment as mod
|
||||||
|
importlib.reload(mod)
|
||||||
|
result = mod.enrich_output_with_assets(output)
|
||||||
|
|
||||||
|
imgs = result["images"]
|
||||||
|
self.assertNotIn("id", imgs[0])
|
||||||
|
self.assertEqual(imgs[1]["id"], "good-ref")
|
||||||
|
|
||||||
|
def test_multiple_output_keys_all_enriched(self):
|
||||||
|
output = {
|
||||||
|
"images": [{"filename": "a.png", "subfolder": "", "type": "output"}],
|
||||||
|
"videos": [{"filename": "b.mp4", "subfolder": "", "type": "output"}],
|
||||||
|
}
|
||||||
|
result = _call(output)
|
||||||
|
self.assertIn("id", result["images"][0])
|
||||||
|
self.assertIn("id", result["videos"][0])
|
||||||
|
|
||||||
|
def test_none_entry_in_list_unchanged(self):
|
||||||
|
output = {"images": [None, {"filename": "a.png", "subfolder": "", "type": "output"}]}
|
||||||
|
result = _call(output)
|
||||||
|
self.assertIsNone(result["images"][0])
|
||||||
|
self.assertIn("id", result["images"][1])
|
||||||
|
|
||||||
|
def test_path_traversal_subfolder_skipped(self):
|
||||||
|
register_mock = MagicMock(return_value=_make_register_result())
|
||||||
|
mocked = _mocked_modules(register_file_in_place=register_mock)
|
||||||
|
|
||||||
|
output = {"images": [{"filename": "passwd", "subfolder": "../../etc", "type": "output"}]}
|
||||||
|
|
||||||
|
# Do NOT patch os.path.abspath — real resolution is required for the containment check.
|
||||||
|
with patch.dict("sys.modules", mocked), \
|
||||||
|
patch("os.path.isfile", return_value=True):
|
||||||
|
import importlib
|
||||||
|
import comfy_execution.asset_enrichment as mod
|
||||||
|
importlib.reload(mod)
|
||||||
|
result = mod.enrich_output_with_assets(output)
|
||||||
|
|
||||||
|
self.assertNotIn("id", result["images"][0])
|
||||||
|
register_mock.assert_not_called()
|
||||||
|
|
||||||
|
def test_absolute_filename_skipped(self):
|
||||||
|
register_mock = MagicMock(return_value=_make_register_result())
|
||||||
|
mocked = _mocked_modules(register_file_in_place=register_mock)
|
||||||
|
|
||||||
|
# Absolute filename — os.path.join discards earlier components when a later one is absolute.
|
||||||
|
absolute_filename = os.path.abspath(os.sep + "etc" + os.sep + "passwd")
|
||||||
|
output = {"images": [{"filename": absolute_filename, "subfolder": "", "type": "output"}]}
|
||||||
|
|
||||||
|
with patch.dict("sys.modules", mocked), \
|
||||||
|
patch("os.path.isfile", return_value=True):
|
||||||
|
import importlib
|
||||||
|
import comfy_execution.asset_enrichment as mod
|
||||||
|
importlib.reload(mod)
|
||||||
|
result = mod.enrich_output_with_assets(output)
|
||||||
|
|
||||||
|
self.assertNotIn("id", result["images"][0])
|
||||||
|
register_mock.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
@ -35,8 +35,8 @@ class TestValidateJobId:
|
|||||||
)
|
)
|
||||||
def test_non_canonical_spellings_rejected(self, variant):
|
def test_non_canonical_spellings_rejected(self, variant):
|
||||||
# uuid.UUID parses all of these, but accepting them would silently
|
# uuid.UUID parses all of these, but accepting them would silently
|
||||||
# rewrite the client's id (history keys, websocket events, and the
|
# rewrite the client's id (history keys, websocket events, and
|
||||||
# assets job_ids filter all match the stored form exactly).
|
# /interrupt matching all match the stored form exactly).
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
validate_job_id(variant)
|
validate_job_id(variant)
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user