mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-06-15 04:19:43 +08:00
Merge branch 'master' into matt/asset-enrichment-executed-ws
This commit is contained in:
commit
2cb7cda57c
@ -219,6 +219,7 @@ async def list_assets_route(request: web.Request) -> web.Response:
|
|||||||
exclude_tags=q.exclude_tags,
|
exclude_tags=q.exclude_tags,
|
||||||
name_contains=q.name_contains,
|
name_contains=q.name_contains,
|
||||||
metadata_filter=q.metadata_filter,
|
metadata_filter=q.metadata_filter,
|
||||||
|
job_ids=q.job_ids,
|
||||||
limit=q.limit,
|
limit=q.limit,
|
||||||
offset=q.offset,
|
offset=q.offset,
|
||||||
sort=sort,
|
sort=sort,
|
||||||
|
|||||||
@ -1,4 +1,5 @@
|
|||||||
import json
|
import json
|
||||||
|
import uuid
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, Literal
|
from typing import Any, Literal
|
||||||
|
|
||||||
@ -53,6 +54,7 @@ class ListAssetsQuery(BaseModel):
|
|||||||
include_tags: list[str] = Field(default_factory=list)
|
include_tags: list[str] = Field(default_factory=list)
|
||||||
exclude_tags: list[str] = Field(default_factory=list)
|
exclude_tags: list[str] = Field(default_factory=list)
|
||||||
name_contains: str | None = None
|
name_contains: str | None = None
|
||||||
|
job_ids: list[str] = Field(default_factory=list, max_length=500)
|
||||||
|
|
||||||
# Accept either a JSON string (query param) or a dict
|
# Accept either a JSON string (query param) or a dict
|
||||||
metadata_filter: dict[str, Any] | None = None
|
metadata_filter: dict[str, Any] | None = None
|
||||||
@ -86,6 +88,40 @@ class ListAssetsQuery(BaseModel):
|
|||||||
return out
|
return out
|
||||||
return v
|
return v
|
||||||
|
|
||||||
|
@field_validator("job_ids", mode="before")
|
||||||
|
@classmethod
|
||||||
|
def _split_and_validate_job_ids(cls, v):
|
||||||
|
# Accept "uuid1,uuid2" or ["uuid1","uuid2"] or repeated query params.
|
||||||
|
# Each entry must parse as a UUID; canonicalized to lowercase hyphenated form.
|
||||||
|
if v is None:
|
||||||
|
return []
|
||||||
|
if isinstance(v, str):
|
||||||
|
raw = [t.strip() for t in v.split(",") if t.strip()]
|
||||||
|
elif isinstance(v, list):
|
||||||
|
raw = []
|
||||||
|
for item in v:
|
||||||
|
if not isinstance(item, str):
|
||||||
|
raise ValueError(
|
||||||
|
f"job_ids entries must be strings, got {type(item).__name__}"
|
||||||
|
)
|
||||||
|
raw.extend([t.strip() for t in item.split(",") if t.strip()])
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"job_ids must be a string or list of strings, got {type(v).__name__}"
|
||||||
|
)
|
||||||
|
|
||||||
|
out: list[str] = []
|
||||||
|
seen: set[str] = set()
|
||||||
|
for s in raw:
|
||||||
|
try:
|
||||||
|
canonical = str(uuid.UUID(s))
|
||||||
|
except ValueError as e:
|
||||||
|
raise ValueError(f"job_ids must be UUIDs: {s!r}") from e
|
||||||
|
if canonical not in seen:
|
||||||
|
seen.add(canonical)
|
||||||
|
out.append(canonical)
|
||||||
|
return out
|
||||||
|
|
||||||
@field_validator("metadata_filter", mode="before")
|
@field_validator("metadata_filter", mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def _parse_metadata_json(cls, v):
|
def _parse_metadata_json(cls, v):
|
||||||
|
|||||||
@ -264,6 +264,7 @@ def list_references_page(
|
|||||||
include_tags: Sequence[str] | None = None,
|
include_tags: Sequence[str] | None = None,
|
||||||
exclude_tags: Sequence[str] | None = None,
|
exclude_tags: Sequence[str] | None = None,
|
||||||
metadata_filter: dict | None = None,
|
metadata_filter: dict | None = None,
|
||||||
|
job_ids: Sequence[str] | None = None,
|
||||||
sort: str | None = None,
|
sort: str | None = None,
|
||||||
order: str | None = None,
|
order: str | None = None,
|
||||||
after_cursor_value: object | None = None,
|
after_cursor_value: object | None = None,
|
||||||
@ -293,6 +294,9 @@ def list_references_page(
|
|||||||
escaped, esc = escape_sql_like_string(name_contains)
|
escaped, esc = escape_sql_like_string(name_contains)
|
||||||
base = base.where(AssetReference.name.ilike(f"%{escaped}%", escape=esc))
|
base = base.where(AssetReference.name.ilike(f"%{escaped}%", escape=esc))
|
||||||
|
|
||||||
|
if job_ids:
|
||||||
|
base = base.where(AssetReference.job_id.in_(list(job_ids)))
|
||||||
|
|
||||||
base = apply_tag_filters(base, include_tags, exclude_tags)
|
base = apply_tag_filters(base, include_tags, exclude_tags)
|
||||||
base = apply_metadata_filter(base, metadata_filter)
|
base = apply_metadata_filter(base, metadata_filter)
|
||||||
|
|
||||||
@ -345,6 +349,8 @@ def list_references_page(
|
|||||||
count_stmt = count_stmt.where(
|
count_stmt = count_stmt.where(
|
||||||
AssetReference.name.ilike(f"%{escaped}%", escape=esc)
|
AssetReference.name.ilike(f"%{escaped}%", escape=esc)
|
||||||
)
|
)
|
||||||
|
if job_ids:
|
||||||
|
count_stmt = count_stmt.where(AssetReference.job_id.in_(list(job_ids)))
|
||||||
count_stmt = apply_tag_filters(count_stmt, include_tags, exclude_tags)
|
count_stmt = apply_tag_filters(count_stmt, include_tags, exclude_tags)
|
||||||
count_stmt = apply_metadata_filter(count_stmt, metadata_filter)
|
count_stmt = apply_metadata_filter(count_stmt, metadata_filter)
|
||||||
|
|
||||||
|
|||||||
@ -274,6 +274,7 @@ def list_assets_page(
|
|||||||
exclude_tags: Sequence[str] | None = None,
|
exclude_tags: Sequence[str] | None = None,
|
||||||
name_contains: str | None = None,
|
name_contains: str | None = None,
|
||||||
metadata_filter: dict | None = None,
|
metadata_filter: dict | None = None,
|
||||||
|
job_ids: Sequence[str] | None = None,
|
||||||
limit: int = 20,
|
limit: int = 20,
|
||||||
offset: int = 0,
|
offset: int = 0,
|
||||||
sort: str = "created_at",
|
sort: str = "created_at",
|
||||||
@ -319,6 +320,7 @@ def list_assets_page(
|
|||||||
exclude_tags=exclude_tags,
|
exclude_tags=exclude_tags,
|
||||||
name_contains=name_contains,
|
name_contains=name_contains,
|
||||||
metadata_filter=metadata_filter,
|
metadata_filter=metadata_filter,
|
||||||
|
job_ids=job_ids,
|
||||||
limit=fetch_limit,
|
limit=fetch_limit,
|
||||||
offset=offset,
|
offset=offset,
|
||||||
sort=sort,
|
sort=sort,
|
||||||
|
|||||||
@ -534,8 +534,10 @@ try:
|
|||||||
except:
|
except:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
if torch.cuda.is_available() and torch.backends.cudnn.is_available() and PerformanceFeature.AutoTune in args.fast:
|
|
||||||
torch.backends.cudnn.benchmark = True
|
def set_cudnn_benchmark():
|
||||||
|
if torch.cuda.is_available() and torch.backends.cudnn.is_available():
|
||||||
|
torch.backends.cudnn.benchmark = PerformanceFeature.AutoTune in args.fast
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if torch_version_numeric >= (2, 5):
|
if torch_version_numeric >= (2, 5):
|
||||||
|
|||||||
@ -3,6 +3,7 @@ Job utilities for the /api/jobs endpoint.
|
|||||||
Provides normalization and helper functions for job status tracking.
|
Provides normalization and helper functions for job status tracking.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import uuid
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from comfy_api.internal import prune_dict
|
from comfy_api.internal import prune_dict
|
||||||
@ -19,6 +20,26 @@ class JobStatus:
|
|||||||
ALL = [PENDING, IN_PROGRESS, COMPLETED, FAILED, CANCELLED]
|
ALL = [PENDING, IN_PROGRESS, COMPLETED, FAILED, CANCELLED]
|
||||||
|
|
||||||
|
|
||||||
|
def validate_job_id(value) -> str:
|
||||||
|
"""Validate a client-supplied job (prompt) id.
|
||||||
|
|
||||||
|
Job ids must be UUIDs in the canonical lowercase hyphenated form. The id
|
||||||
|
is stored and compared verbatim everywhere downstream — history keys,
|
||||||
|
websocket events, /interrupt matching, and the assets ``job_ids`` filter
|
||||||
|
(a String(36) column matched exactly) — so accepting another spelling
|
||||||
|
would either rewrite the client's id behind its back or mint a job whose
|
||||||
|
outputs the filter can never find. Rejecting loudly beats both.
|
||||||
|
|
||||||
|
Returns the id unchanged. Raises ValueError when the value is not a
|
||||||
|
string in canonical UUID form.
|
||||||
|
"""
|
||||||
|
if not isinstance(value, str):
|
||||||
|
raise ValueError(f"job id must be a string, got {type(value).__name__}")
|
||||||
|
if str(uuid.UUID(value)) != value:
|
||||||
|
raise ValueError("job id must be a UUID in canonical lowercase hyphenated form")
|
||||||
|
return value
|
||||||
|
|
||||||
|
|
||||||
# Media types that can be previewed in the frontend
|
# Media types that can be previewed in the frontend
|
||||||
PREVIEWABLE_MEDIA_TYPES = frozenset({'images', 'video', 'audio', '3d', 'text'})
|
PREVIEWABLE_MEDIA_TYPES = frozenset({'images', 'video', 'audio', '3d', 'text'})
|
||||||
|
|
||||||
|
|||||||
5
main.py
5
main.py
@ -490,6 +490,11 @@ def start_comfyui(asyncio_loop=None):
|
|||||||
init_custom_nodes=(not args.disable_all_custom_nodes) or len(args.whitelist_custom_nodes) > 0,
|
init_custom_nodes=(not args.disable_all_custom_nodes) or len(args.whitelist_custom_nodes) > 0,
|
||||||
init_api_nodes=not args.disable_api_nodes
|
init_api_nodes=not args.disable_api_nodes
|
||||||
))
|
))
|
||||||
|
|
||||||
|
# Re-apply Comfy's cuDNN benchmark policy after custom-node imports. Benchmark
|
||||||
|
# mode can request near-card-sized autotune workspaces, and some custom nodes set it at import time.
|
||||||
|
comfy.model_management.set_cudnn_benchmark()
|
||||||
|
|
||||||
hook_breaker_ac10a0.restore_functions()
|
hook_breaker_ac10a0.restore_functions()
|
||||||
|
|
||||||
cuda_malloc_warning()
|
cuda_malloc_warning()
|
||||||
|
|||||||
@ -896,6 +896,11 @@ components:
|
|||||||
additionalProperties: true
|
additionalProperties: true
|
||||||
description: The workflow graph to execute
|
description: The workflow graph to execute
|
||||||
type: object
|
type: object
|
||||||
|
prompt_id:
|
||||||
|
description: Optional client-supplied job id. Must be a UUID in canonical lowercase hyphenated form; it is echoed back in the response. Omitted or null means the server generates one.
|
||||||
|
format: uuid
|
||||||
|
nullable: true
|
||||||
|
type: string
|
||||||
workflow_id:
|
workflow_id:
|
||||||
description: UUID identifying the cloud workflow entity to associate with this job
|
description: UUID identifying the cloud workflow entity to associate with this job
|
||||||
type: string
|
type: string
|
||||||
|
|||||||
18
server.py
18
server.py
@ -8,7 +8,7 @@ import time
|
|||||||
import nodes
|
import nodes
|
||||||
import folder_paths
|
import folder_paths
|
||||||
import execution
|
import execution
|
||||||
from comfy_execution.jobs import JobStatus, get_job, get_all_jobs
|
from comfy_execution.jobs import JobStatus, get_job, get_all_jobs, validate_job_id
|
||||||
import uuid
|
import uuid
|
||||||
import urllib
|
import urllib
|
||||||
import json
|
import json
|
||||||
@ -942,7 +942,21 @@ class PromptServer():
|
|||||||
|
|
||||||
if "prompt" in json_data:
|
if "prompt" in json_data:
|
||||||
prompt = json_data["prompt"]
|
prompt = json_data["prompt"]
|
||||||
prompt_id = str(json_data.get("prompt_id", uuid.uuid4()))
|
client_prompt_id = json_data.get("prompt_id")
|
||||||
|
if client_prompt_id is None:
|
||||||
|
# Absent or explicit null: the server mints the id.
|
||||||
|
prompt_id = str(uuid.uuid4())
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
prompt_id = validate_job_id(client_prompt_id)
|
||||||
|
except ValueError:
|
||||||
|
error = {
|
||||||
|
"type": "invalid_prompt_id",
|
||||||
|
"message": "prompt_id must be a valid UUID",
|
||||||
|
"details": "prompt_id must be a UUID string in canonical lowercase hyphenated form; omit it to let the server generate one",
|
||||||
|
"extra_info": {}
|
||||||
|
}
|
||||||
|
return web.json_response({"error": error, "node_errors": {}}, status=400)
|
||||||
|
|
||||||
partial_execution_targets = None
|
partial_execution_targets = None
|
||||||
if "partial_execution_targets" in json_data:
|
if "partial_execution_targets" in json_data:
|
||||||
|
|||||||
@ -158,6 +158,56 @@ class TestListReferencesPage:
|
|||||||
refs, _, _ = list_references_page(session, sort="name", order="asc")
|
refs, _, _ = list_references_page(session, sort="name", order="asc")
|
||||||
assert refs[0].name == "large"
|
assert refs[0].name == "large"
|
||||||
|
|
||||||
|
def test_job_ids_filter(self, session: Session):
|
||||||
|
asset = _make_asset(session, "hash1")
|
||||||
|
job_a = str(uuid.uuid4())
|
||||||
|
job_b = str(uuid.uuid4())
|
||||||
|
ref_a = _make_reference(session, asset, name="from_job_a")
|
||||||
|
ref_a.job_id = job_a
|
||||||
|
ref_b = _make_reference(session, asset, name="from_job_b")
|
||||||
|
ref_b.job_id = job_b
|
||||||
|
_make_reference(session, asset, name="no_job")
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
# Single job filter
|
||||||
|
refs, _, total = list_references_page(session, job_ids=[job_a])
|
||||||
|
assert total == 1
|
||||||
|
assert refs[0].name == "from_job_a"
|
||||||
|
|
||||||
|
# Multi-job filter (IN)
|
||||||
|
refs, _, total = list_references_page(session, job_ids=[job_a, job_b])
|
||||||
|
names = sorted(r.name for r in refs)
|
||||||
|
assert total == 2
|
||||||
|
assert names == ["from_job_a", "from_job_b"]
|
||||||
|
|
||||||
|
# Unknown job id matches nothing
|
||||||
|
refs, _, total = list_references_page(session, job_ids=[str(uuid.uuid4())])
|
||||||
|
assert total == 0
|
||||||
|
assert refs == []
|
||||||
|
|
||||||
|
# Empty/None means no filter -> all three references
|
||||||
|
refs, _, total = list_references_page(session, job_ids=[])
|
||||||
|
assert total == 3
|
||||||
|
refs, _, total = list_references_page(session, job_ids=None)
|
||||||
|
assert total == 3
|
||||||
|
|
||||||
|
def test_job_ids_combined_with_other_filters(self, session: Session):
|
||||||
|
asset = _make_asset(session, "hash1")
|
||||||
|
job_a = str(uuid.uuid4())
|
||||||
|
ref_match = _make_reference(session, asset, name="match.bin")
|
||||||
|
ref_match.job_id = job_a
|
||||||
|
ref_wrong_name = _make_reference(session, asset, name="other.bin")
|
||||||
|
ref_wrong_name.job_id = job_a
|
||||||
|
ref_wrong_job = _make_reference(session, asset, name="match.bin")
|
||||||
|
ref_wrong_job.job_id = str(uuid.uuid4())
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
refs, _, total = list_references_page(
|
||||||
|
session, job_ids=[job_a], name_contains="match"
|
||||||
|
)
|
||||||
|
assert total == 1
|
||||||
|
assert refs[0].id == ref_match.id
|
||||||
|
|
||||||
|
|
||||||
class TestFetchReferenceAssetAndTags:
|
class TestFetchReferenceAssetAndTags:
|
||||||
def test_returns_none_for_nonexistent(self, session: Session):
|
def test_returns_none_for_nonexistent(self, session: Session):
|
||||||
|
|||||||
60
tests-unit/assets_test/queries/test_list_assets_query.py
Normal file
60
tests-unit/assets_test/queries/test_list_assets_query.py
Normal file
@ -0,0 +1,60 @@
|
|||||||
|
"""Schema-level unit tests for ListAssetsQuery (no DB required)."""
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from pydantic import ValidationError
|
||||||
|
|
||||||
|
from app.assets.api.schemas_in import ListAssetsQuery
|
||||||
|
|
||||||
|
|
||||||
|
class TestJobIdsValidator:
|
||||||
|
def test_csv_string_parses_and_canonicalizes(self):
|
||||||
|
a = "AAAAAAAA-BBBB-CCCC-DDDD-EEEEEEEEEEEE"
|
||||||
|
b = "11111111-2222-3333-4444-555555555555"
|
||||||
|
q = ListAssetsQuery.model_validate({"job_ids": f"{a},{b}"})
|
||||||
|
# Canonicalized to lowercase
|
||||||
|
assert q.job_ids == [a.lower(), b]
|
||||||
|
|
||||||
|
def test_repeated_query_params_as_list(self):
|
||||||
|
a = "11111111-1111-1111-1111-111111111111"
|
||||||
|
b = "22222222-2222-2222-2222-222222222222"
|
||||||
|
q = ListAssetsQuery.model_validate({"job_ids": [a, b]})
|
||||||
|
assert q.job_ids == [a, b]
|
||||||
|
|
||||||
|
def test_dedup_preserves_first_seen_order(self):
|
||||||
|
a = "11111111-1111-1111-1111-111111111111"
|
||||||
|
b = "22222222-2222-2222-2222-222222222222"
|
||||||
|
q = ListAssetsQuery.model_validate({"job_ids": [a, b, a]})
|
||||||
|
assert q.job_ids == [a, b]
|
||||||
|
|
||||||
|
def test_default_empty(self):
|
||||||
|
q = ListAssetsQuery.model_validate({})
|
||||||
|
assert q.job_ids == []
|
||||||
|
|
||||||
|
def test_invalid_uuid_rejected(self):
|
||||||
|
with pytest.raises(ValidationError) as exc:
|
||||||
|
ListAssetsQuery.model_validate({"job_ids": "not-a-uuid"})
|
||||||
|
assert "must be UUIDs" in str(exc.value)
|
||||||
|
|
||||||
|
def test_non_string_list_item_rejected(self):
|
||||||
|
with pytest.raises(ValidationError) as exc:
|
||||||
|
ListAssetsQuery.model_validate(
|
||||||
|
{"job_ids": ["11111111-1111-1111-1111-111111111111", 42]}
|
||||||
|
)
|
||||||
|
assert "must be strings" in str(exc.value)
|
||||||
|
|
||||||
|
def test_non_string_non_list_value_rejected(self):
|
||||||
|
with pytest.raises(ValidationError) as exc:
|
||||||
|
ListAssetsQuery.model_validate({"job_ids": {"bad": "shape"}})
|
||||||
|
assert "must be a string or list of strings" in str(exc.value)
|
||||||
|
|
||||||
|
def test_max_length_enforced(self):
|
||||||
|
too_many = [str(uuid.uuid4()) for _ in range(501)]
|
||||||
|
with pytest.raises(ValidationError) as exc:
|
||||||
|
ListAssetsQuery.model_validate({"job_ids": too_many})
|
||||||
|
assert exc.value.errors()[0]["type"] == "too_long"
|
||||||
|
|
||||||
|
def test_max_length_boundary_accepted(self):
|
||||||
|
at_cap = [str(uuid.uuid4()) for _ in range(500)]
|
||||||
|
q = ListAssetsQuery.model_validate({"job_ids": at_cap})
|
||||||
|
assert len(q.job_ids) == 500
|
||||||
69
tests-unit/assets_test/test_prompt_id_enforcement.py
Normal file
69
tests-unit/assets_test/test_prompt_id_enforcement.py
Normal file
@ -0,0 +1,69 @@
|
|||||||
|
"""POST /prompt enforces canonical-UUID job ids at creation time.
|
||||||
|
|
||||||
|
Lives in assets_test because it uses this suite's booted-server fixture and
|
||||||
|
because the invariant exists for the assets pipeline: the GET /api/assets
|
||||||
|
``job_ids`` filter matches stored job ids exactly, so a job minted with a
|
||||||
|
non-canonical id would produce assets the filter can never find.
|
||||||
|
|
||||||
|
The prompt bodies here are intentionally invalid workflows — prompt_id
|
||||||
|
validation happens before workflow validation, so a rejected id returns
|
||||||
|
``invalid_prompt_id`` while an accepted id falls through to the ordinary
|
||||||
|
workflow-validation error (proving it cleared the id check).
|
||||||
|
"""
|
||||||
|
import requests
|
||||||
|
|
||||||
|
|
||||||
|
def _post_prompt(http: requests.Session, api_base: str, body: dict) -> requests.Response:
|
||||||
|
return http.post(api_base + "/prompt", json=body, timeout=30)
|
||||||
|
|
||||||
|
|
||||||
|
def _error_type(r: requests.Response) -> str:
|
||||||
|
return r.json()["error"]["type"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_non_uuid_prompt_id_rejected(http: requests.Session, api_base: str):
|
||||||
|
r = _post_prompt(http, api_base, {"prompt": {}, "prompt_id": "not-a-uuid"})
|
||||||
|
assert r.status_code == 400, r.text
|
||||||
|
assert _error_type(r) == "invalid_prompt_id"
|
||||||
|
|
||||||
|
|
||||||
|
def test_non_string_prompt_id_rejected(http: requests.Session, api_base: str):
|
||||||
|
# Previously str()-coerced (123 became the job id "123"); must now be a 400,
|
||||||
|
# not a 500 from uuid.UUID choking on a non-string.
|
||||||
|
r = _post_prompt(http, api_base, {"prompt": {}, "prompt_id": 123})
|
||||||
|
assert r.status_code == 400, r.text
|
||||||
|
assert _error_type(r) == "invalid_prompt_id"
|
||||||
|
|
||||||
|
|
||||||
|
def test_non_canonical_uuid_rejected(http: requests.Session, api_base: str):
|
||||||
|
# Parseable as a UUID, but not the canonical lowercase form: rejected
|
||||||
|
# loudly rather than silently rewritten (downstream lookups match the
|
||||||
|
# stored id exactly).
|
||||||
|
r = _post_prompt(
|
||||||
|
http,
|
||||||
|
api_base,
|
||||||
|
{"prompt": {}, "prompt_id": "AAAAAAAA-BBBB-4CCC-8DDD-EEEEEEEEEEEE"},
|
||||||
|
)
|
||||||
|
assert r.status_code == 400, r.text
|
||||||
|
assert _error_type(r) == "invalid_prompt_id"
|
||||||
|
|
||||||
|
|
||||||
|
def test_canonical_uuid_accepted(http: requests.Session, api_base: str):
|
||||||
|
# The id clears validation; the empty workflow then fails ordinary prompt
|
||||||
|
# validation, proving the request got past the id check.
|
||||||
|
r = _post_prompt(
|
||||||
|
http,
|
||||||
|
api_base,
|
||||||
|
{"prompt": {}, "prompt_id": "aaaaaaaa-bbbb-4ccc-8ddd-eeeeeeeeeeee"},
|
||||||
|
)
|
||||||
|
assert r.status_code == 400, r.text
|
||||||
|
assert _error_type(r) != "invalid_prompt_id"
|
||||||
|
|
||||||
|
|
||||||
|
def test_null_prompt_id_not_rejected(http: requests.Session, api_base: str):
|
||||||
|
# Explicit null means "server generates" and must not be rejected as an
|
||||||
|
# invalid id. (The minted id itself is not observable here because the
|
||||||
|
# workflow is invalid; unit tests cover validate_job_id directly.)
|
||||||
|
r = _post_prompt(http, api_base, {"prompt": {}, "prompt_id": None})
|
||||||
|
assert r.status_code == 400, r.text
|
||||||
|
assert _error_type(r) != "invalid_prompt_id"
|
||||||
@ -1,5 +1,7 @@
|
|||||||
"""Unit tests for comfy_execution/jobs.py"""
|
"""Unit tests for comfy_execution/jobs.py"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
from comfy_execution.jobs import (
|
from comfy_execution.jobs import (
|
||||||
JobStatus,
|
JobStatus,
|
||||||
is_previewable,
|
is_previewable,
|
||||||
@ -10,9 +12,50 @@ from comfy_execution.jobs import (
|
|||||||
get_outputs_summary,
|
get_outputs_summary,
|
||||||
apply_sorting,
|
apply_sorting,
|
||||||
has_3d_extension,
|
has_3d_extension,
|
||||||
|
validate_job_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestValidateJobId:
|
||||||
|
"""validate_job_id guards job creation: POST /prompt rejects ids it raises on."""
|
||||||
|
|
||||||
|
def test_canonical_form_passes_through(self):
|
||||||
|
cid = "a1b2c3d4-e5f6-7a89-b0c1-d2e3f4a5b6c7"
|
||||||
|
assert validate_job_id(cid) == cid
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"variant",
|
||||||
|
[
|
||||||
|
"A1B2C3D4-E5F6-7A89-B0C1-D2E3F4A5B6C7", # uppercase
|
||||||
|
"{a1b2c3d4-e5f6-7a89-b0c1-d2e3f4a5b6c7}", # braced
|
||||||
|
"urn:uuid:a1b2c3d4-e5f6-7a89-b0c1-d2e3f4a5b6c7", # URN
|
||||||
|
"a1b2c3d4e5f67a89b0c1d2e3f4a5b6c7", # bare hex
|
||||||
|
" a1b2c3d4-e5f6-7a89-b0c1-d2e3f4a5b6c7 ", # padded
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_non_canonical_spellings_rejected(self, variant):
|
||||||
|
# uuid.UUID parses all of these, but accepting them would silently
|
||||||
|
# rewrite the client's id (history keys, websocket events, and the
|
||||||
|
# assets job_ids filter all match the stored form exactly).
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
validate_job_id(variant)
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"bad",
|
||||||
|
["", "not-a-uuid", "prompt-123", "a1b2c3d4-e5f6-7a89-b0c1", "None"],
|
||||||
|
)
|
||||||
|
def test_non_uuid_strings_rejected(self, bad):
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
validate_job_id(bad)
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("bad", [123, 1.5, True, None, ["a"], {"id": "x"}])
|
||||||
|
def test_non_strings_rejected(self, bad):
|
||||||
|
# uuid.UUID raises AttributeError/TypeError on non-strings; the helper
|
||||||
|
# must normalize those to ValueError so callers need one except clause.
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
validate_job_id(bad)
|
||||||
|
|
||||||
|
|
||||||
class TestJobStatus:
|
class TestJobStatus:
|
||||||
"""Test JobStatus constants."""
|
"""Test JobStatus constants."""
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user