mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2026-01-29 15:50:22 +08:00)

commit 4f6615e939: Merge branch 'develop'
.github/workflows/docker-build-amd.yml (vendored): 4 changes
@@ -34,7 +34,9 @@ jobs:
          tags: |
            type=raw,value=latest-rocm,enable={{is_default_branch}}
            type=sha,prefix=,suffix=-rocm
-           type=semver,pattern={{version}},suffix=-rocm
+           type=match,pattern=v?(\d+\.\d+\.\d+\.\d+),group=1,suffix=-rocm
+           type=match,pattern=v?(\d+\.\d+\.\d+),group=1,suffix=-rocm
+           type=match,pattern=v?(\d+\.\d+),group=1,suffix=-rocm
      - name: Build and push ROCm (AMD) image
        uses: docker/build-push-action@v6
        with:
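Note on the tag rules above: `type=semver` rejects four-component versions such as 1.2.3.4, which is why the workflow now extracts version tags with `type=match` regexes instead. A minimal Python sketch (illustrative, not part of the diff) of what the three patterns capture for a hypothetical ref `v1.2.3.4`:

import re

ref = "v1.2.3.4"  # four-component tags are not valid semver, hence the switch
for pattern in (r"v?(\d+\.\d+\.\d+\.\d+)", r"v?(\d+\.\d+\.\d+)", r"v?(\d+\.\d+)"):
    match = re.match(pattern, ref)
    if match:
        print(match.group(1) + "-rocm")  # 1.2.3.4-rocm, 1.2.3-rocm, 1.2-rocm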
.github/workflows/docker-build.yml (vendored): 8 changes
@@ -41,8 +41,12 @@ jobs:
            type=raw,value=latest-cuda,enable={{is_default_branch}}
            type=sha,prefix=
            type=sha,prefix=,suffix=-cuda
-           type=semver,pattern={{version}}
-           type=semver,pattern={{version}},suffix=-cuda
+           type=match,pattern=v?(\d+\.\d+\.\d+\.\d+),group=1
+           type=match,pattern=v?(\d+\.\d+\.\d+\.\d+),group=1,suffix=-cuda
+           type=match,pattern=v?(\d+\.\d+\.\d+),group=1
+           type=match,pattern=v?(\d+\.\d+\.\d+),group=1,suffix=-cuda
+           type=match,pattern=v?(\d+\.\d+),group=1
+           type=match,pattern=v?(\d+\.\d+),group=1,suffix=-cuda
      - name: Build and push CUDA (NVIDIA) image
        uses: docker/build-push-action@v6
        with:
@@ -33,7 +33,7 @@ RUN pip install uv && uv --version && \

 # install sageattention
 ADD pkg/sageattention-2.2.0-cp312-cp312-linux_x86_64.whl /workspace/pkg/sageattention-2.2.0-cp312-cp312-linux_x86_64.whl
-RUN uv pip install -U --no-deps --no-build-isolation spandrel timm tensorboard poetry flash-attn "xformers==0.0.31.post1" "file:./pkg/sageattention-2.2.0-cp312-cp312-linux_x86_64.whl"
+RUN uv pip install -U --no-deps --no-build-isolation spandrel timm tensorboard poetry "flash-attn<=2.8.0" "xformers==0.0.31.post1" "file:./pkg/sageattention-2.2.0-cp312-cp312-linux_x86_64.whl"
 # this exotic command will determine the correct torchaudio to install for the image
 RUN <<-EOF
 python -c 'import torch, re, subprocess
@@ -66,7 +66,7 @@ WORKDIR /workspace
 # addresses https://github.com/pytorch/pytorch/issues/104801
 # and issues reported by importing nodes_canny
 # smoke test
-RUN python -c "import torch; import xformers; import sageattention; import cv2" && comfyui --quick-test-for-ci --cpu --cwd /workspace
+RUN python -c "import torch; import xformers; import sageattention; import cv2; import diffusers.hooks" && comfyui --quick-test-for-ci --cpu --cwd /workspace

 EXPOSE 8188
 CMD ["python", "-m", "comfy.cmd.main", "--listen", "--use-sage-attention", "--reserve-vram=0", "--logging-level=INFO", "--enable-cors"]
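The smoke test above now also imports diffusers.hooks, so a missing or incompatible diffusers install fails the image build rather than the first workflow run. A rough local equivalent (an illustration, not taken from the image; it assumes the same dependencies are installed):

import importlib

# Fail fast if any runtime dependency cannot be imported.
for module in ("torch", "xformers", "sageattention", "cv2", "diffusers.hooks"):
    importlib.import_module(module)  # raises ImportError on a broken install
print("smoke test passed")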
@@ -354,19 +354,29 @@ paths:
        required: true
        description: |
          The ID of the prompt to query.
      responses:
        204:
          description: |
            The prompt is still in progress
        200:
          description: |
            Prompt outputs
          content:
            application/json:
-             $ref: "#/components/schemas/Outputs"
+             schema:
+               $ref: "#/components/schemas/Outputs"
        404:
          description: |
            The prompt was not found
+       500:
+         description: |
+           An execution error occurred while processing the prompt.
+         content:
+           application/json:
+             description:
+               An execution status directly from the workers
+             schema:
+               $ref: "#/components/schemas/ExecutionStatusAsDict"
  /api/v1/prompts:
    get:
      operationId: list_prompts
@@ -861,6 +871,22 @@ components:
          type: array
          items:
            type: string
+       node_errors:
+         type: object
+         description: "Detailed validation errors per node"
+         additionalProperties:
+           type: object
+           properties:
+             errors:
+               type: array
+               items:
+                 $ref: "#/components/schemas/ValidationErrorDict"
+             dependent_outputs:
+               type: array
+               items:
+                 type: string
+             class_type:
+               type: string
      required:
        - type
        - details
@@ -156,6 +156,7 @@ def _create_parser() -> EnhancedConfigArgParser:
     parser.add_argument("--whitelist-custom-nodes", type=str, action=FlattenAndAppendAction, nargs='+', default=[], help="Specify custom node folders to load even when --disable-all-custom-nodes is enabled.")
     parser.add_argument("--blacklist-custom-nodes", type=str, action=FlattenAndAppendAction, nargs='+', default=[], help="Specify custom node folders to never load. Accepts shell-style globs.")
     parser.add_argument("--disable-api-nodes", action="store_true", help="Disable loading all api nodes.")
+    parser.add_argument("--enable-eval", action="store_true", help="Enable nodes that can evaluate Python code in workflows.")

     parser.add_argument("--multi-user", action="store_true", help="Enables per-user storage.")
     parser.add_argument("--create-directories", action="store_true",
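The new --enable-eval flag is a plain store_true argument, so eval-capable nodes stay disabled unless the flag is passed explicitly. A minimal standalone sketch of that parsing behavior (for illustration only):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--enable-eval", action="store_true",
                    help="Enable nodes that can evaluate Python code in workflows.")
print(parser.parse_args([]).enable_eval)                 # False
print(parser.parse_args(["--enable-eval"]).enable_eval)  # True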
@@ -169,6 +169,7 @@ class Configuration(dict):
        whitelist_custom_nodes (list[str]): Specify custom node folders to load even when --disable-all-custom-nodes is enabled.
        default_device (Optional[int]): Set the id of the default device, all other devices will stay visible.
        block_runtime_package_installation (Optional[bool]): When set, custom nodes like ComfyUI Manager, Easy Use, Nunchaku and others will not be able to use pip or uv to install packages at runtime (experimental).
+       enable_eval (Optional[bool]): Enable nodes that can evaluate Python code in workflows.
    """

    def __init__(self, **kwargs):
@@ -288,6 +289,7 @@ class Configuration(dict):
        self.database_url: str = db_config()
        self.default_device: Optional[int] = None
        self.block_runtime_package_installation = None
+       self.enable_eval: Optional[bool] = False

        for key, value in kwargs.items():
            self[key] = value
@@ -420,6 +422,7 @@ class FlattenAndAppendAction(argparse.Action):
    Custom action to handle comma-separated values and multiple invocations
    of the same argument, flattening them into a single list.
    """

    def __call__(self, parser, namespace, values, option_string=None):
        items = getattr(namespace, self.dest, None)
        if items is None:
@@ -1,23 +1,23 @@
-from asyncio import AbstractEventLoop
-from collections import defaultdict
-
-import aiohttp
 import asyncio
 import uuid
-from aiohttp import WSMessage, ClientResponse, ClientTimeout
-from pathlib import Path
+from asyncio import AbstractEventLoop
 from typing import Optional, List
 from urllib.parse import urlparse, urljoin

+import aiohttp
+from aiohttp import WSMessage, ClientResponse, ClientTimeout
+from opentelemetry import trace
+
 from .client_types import V1QueuePromptResponse
 from ..api.api_client import JSONEncoder
 from ..api.components.schema.prompt import PromptDict
 from ..api.components.schema.prompt_request import PromptRequest
 from ..api.paths.history.get.responses.response_200.content.application_json.schema import Schema as GetHistoryDict
 from ..api.schemas import immutabledict
-from ..component_model.file_output_path import file_output_path
 from ..component_model.outputs_types import OutputsDict

+tracer = trace.get_tracer(__name__)
+

 class AsyncRemoteComfyClient:
     """
@@ -57,6 +57,27 @@ class AsyncRemoteComfyClient:
     def session(self) -> aiohttp.ClientSession:
         return self._ensure_session()

+    def _build_headers(self, accept_header: str, prefer_header: Optional[str] = None, content_type: str = "application/json") -> dict:
+        """Build HTTP headers for requests."""
+        headers = {'Content-Type': content_type, 'Accept': accept_header}
+        if prefer_header:
+            headers['Prefer'] = prefer_header
+        return headers
+
+    @tracer.start_as_current_span("Post Prompt")
+    async def _post_prompt(self, prompt: PromptDict, endpoint: str, accept_header: str, prefer_header: Optional[str] = None) -> ClientResponse:
+        """
+        Common method to POST a prompt to a given endpoint.
+        :param prompt: The prompt to send
+        :param endpoint: The API endpoint (e.g., "/api/v1/prompts")
+        :param accept_header: The Accept header value
+        :param prefer_header: Optional Prefer header value
+        :return: The response object
+        """
+        prompt_json = AsyncRemoteComfyClient.__json_encoder.encode(prompt)
+        headers = self._build_headers(accept_header, prefer_header)
+        return await self.session.post(urljoin(self.server_address, endpoint), data=prompt_json, headers=headers)
+
     async def len_queue(self) -> int:
         async with self.session.get(urljoin(self.server_address, "/prompt"), headers={'Accept': 'application/json'}) as response:
             if response.status == 200:
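_post_prompt is a coroutine that returns an already-issued ClientResponse, which is why the call sites below read `async with await self._post_prompt(...)`: the await runs the coroutine, and the async with closes the response. A standalone sketch of the same pattern (the URL is a placeholder):

import aiohttp

async def post(session: aiohttp.ClientSession) -> aiohttp.ClientResponse:
    # mirrors _post_prompt: issue the request, hand back the open response
    return await session.post("http://localhost:8188/api/v1/prompts", json={})

async def main():
    async with aiohttp.ClientSession() as session:
        async with await post(session) as response:  # await, then async with
            print(response.status)

# asyncio.run(main())  # needs a reachable server; shown for the call pattern only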
@@ -65,34 +86,30 @@ class AsyncRemoteComfyClient:
             else:
                 raise RuntimeError(f"unexpected response: {response.status}: {await response.text()}")

-    async def queue_and_forget_prompt_api(self, prompt: PromptDict) -> str:
+    async def queue_and_forget_prompt_api(self, prompt: PromptDict, prefer_header: Optional[str] = "respond-async", accept_header: str = "application/json") -> str:
         """
         Calls the API to queue a prompt, and forgets about it
         :param prompt:
+        :param prefer_header: The Prefer header value (e.g., "respond-async" or None)
+        :param accept_header: The Accept header value (e.g., "application/json", "application/json+respond-async")
         :return: the task ID
         """
-        prompt_json = AsyncRemoteComfyClient.__json_encoder.encode(prompt)
-        response: ClientResponse
-        async with self.session.post(urljoin(self.server_address, "/api/v1/prompts"), data=prompt_json,
-                                     headers={'Content-Type': 'application/json', 'Accept': 'application/json', 'Prefer': 'respond-async'}) as response:
+        async with await self._post_prompt(prompt, "/api/v1/prompts", accept_header, prefer_header) as response:
             if 200 <= response.status < 400:
                 response_json = await response.json()
                 return response_json["prompt_id"]
             else:
-                raise RuntimeError(f"could not prompt: {response.status}: {await response.text()}")
+                raise RuntimeError(f"could not prompt: {response.status}, reason={response.reason}: {await response.text()}")

-    async def queue_prompt_api(self, prompt: PromptDict) -> V1QueuePromptResponse:
+    async def queue_prompt_api(self, prompt: PromptDict, prefer_header: Optional[str] = None, accept_header: str = "application/json") -> V1QueuePromptResponse:
         """
         Calls the API to queue a prompt.
         :param prompt:
+        :param prefer_header: The Prefer header value (e.g., "respond-async" or None)
+        :param accept_header: The Accept header value (e.g., "application/json", "application/json+respond-async")
         :return: the API response from the server containing URLs and the outputs for the UI (nodes with OUTPUT_NODE == true)
         """
-        prompt_json = AsyncRemoteComfyClient.__json_encoder.encode(prompt)
-        response: ClientResponse
-        async with self.session.post(urljoin(self.server_address, "/api/v1/prompts"), data=prompt_json,
-                                     headers={'Content-Type': 'application/json', 'Accept': 'application/json'}) as response:
+        async with await self._post_prompt(prompt, "/api/v1/prompts", accept_header, prefer_header) as response:
             if 200 <= response.status < 400:
                 return V1QueuePromptResponse(**(await response.json()))
             else:
@@ -112,17 +129,13 @@ class AsyncRemoteComfyClient:
         :param prompt:
         :return:
         """
-        prompt_json = AsyncRemoteComfyClient.__json_encoder.encode(prompt)
-        response: ClientResponse
-        headers = {'Content-Type': 'application/json', 'Accept': 'image/png'}
-        async with self.session.post(urljoin(self.server_address, "/api/v1/prompts"), data=prompt_json,
-                                     headers=headers) as response:
+        async with await self._post_prompt(prompt, "/api/v1/prompts", "image/png") as response:
             if 200 <= response.status < 400:
                 return await response.read()
             else:
                 raise RuntimeError(f"could not prompt: {response.status}: {await response.text()}")

+    @tracer.start_as_current_span("Post Prompt (UI)")
     async def queue_prompt_ui(self, prompt: PromptDict) -> OutputsDict:
         """
         Uses the comfyui UI API calls to retrieve the outputs dictionary
@@ -160,3 +173,41 @@ class AsyncRemoteComfyClient:
         # images have filename, subfolder, type keys
         # todo: use the OpenAPI spec for this when I get around to updating it
         return history_json[prompt_id].outputs
+
+    async def get_prompt_status(self, prompt_id: str) -> ClientResponse:
+        """
+        Get the status of a prompt by ID using the API endpoint.
+        :param prompt_id: The prompt ID to query
+        :return: The ClientResponse object (caller should check status and read body)
+        """
+        return await self.session.get(urljoin(self.server_address, f"/api/v1/prompts/{prompt_id}"))
+
+    @tracer.start_as_current_span("Poll Prompt Until Done")
+    async def poll_prompt_until_done(self, prompt_id: str, max_attempts: int = 60, poll_interval: float = 1.0) -> tuple[int, dict | None]:
+        """
+        Poll a prompt until it's done (200), errors (500), or times out.
+        :param prompt_id: The prompt ID to poll
+        :param max_attempts: Maximum number of polling attempts
+        :param poll_interval: Time to wait between polls in seconds
+        :return: Tuple of (status_code, response_json or None)
+        """
+        span = trace.get_current_span()
+        span.set_attribute("prompt_id", prompt_id)
+        span.set_attribute("max_attempts", max_attempts)
+
+        for _ in range(max_attempts):
+            async with await self.get_prompt_status(prompt_id) as response:
+                if response.status == 200:
+                    return response.status, await response.json()
+                elif response.status == 500:
+                    return response.status, await response.json()
+                elif response.status == 404:
+                    return response.status, None
+                elif response.status == 204:
+                    # Still in progress
+                    await asyncio.sleep(poll_interval)
+                else:
+                    # Unexpected status
+                    return response.status, None
+        # Timeout
+        return 408, None
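A possible end-to-end use of the new client methods (a sketch: the client instance and the PromptDict are assumed to be constructed elsewhere):

async def run_prompt(client, prompt):  # client: AsyncRemoteComfyClient, prompt: PromptDict
    prompt_id = await client.queue_and_forget_prompt_api(prompt)  # sends Prefer: respond-async
    status, body = await client.poll_prompt_until_done(prompt_id)
    if status in (200, 500):
        print(status, body)  # outputs on 200, ExecutionStatusAsDict with error_details on 500
    else:
        print("gave up:", status)  # 404 unknown id, 408 polling timed out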
@@ -1,8 +1,6 @@
 import copy
 from typing import TypeAlias, Union

-from ..api.components.schema.prompt import PromptDict, Prompt
-
 JSON: TypeAlias = Union[dict[str, "JSON"], list["JSON"], str, int, float, bool, None]
 _BASE_PROMPT: JSON = {
     "4": {
@@ -161,7 +159,7 @@ def sdxl_workflow_with_refiner(prompt: str,
                                sampler="euler_ancestral",
                                scheduler="normal",
                                filename_prefix="sdxl_",
-                               seed=42) -> PromptDict:
+                               seed=42) -> dict:
     prompt_dict: JSON = copy.deepcopy(_BASE_PROMPT)
     prompt_dict["17"]["inputs"]["text"] = prompt
     prompt_dict["20"]["inputs"]["text"] = negative_prompt
@@ -188,4 +186,4 @@ def sdxl_workflow_with_refiner(prompt: str,
     prompt_dict["14"]["inputs"]["scheduler"] = scheduler

     prompt_dict["13"]["inputs"]["filename_prefix"] = filename_prefix
-    return Prompt.validate(prompt_dict)
+    return prompt_dict
@@ -1246,16 +1246,36 @@ async def _validate_prompt(prompt_id: typing.Any, prompt: typing.Mapping[str, ty

     if len(good_outputs) == 0:
         errors_list = []
+        extra_info = {}
         for o, _errors in errors:
             for error in _errors:
                 errors_list.append(f"{error['message']}: {error['details']}")
+                # Aggregate exception_type and traceback from validation errors
+                if 'extra_info' in error and error['extra_info']:
+                    if 'exception_type' in error['extra_info'] and 'exception_type' not in extra_info:
+                        extra_info['exception_type'] = error['extra_info']['exception_type']
+                    if 'traceback' in error['extra_info'] and 'traceback' not in extra_info:
+                        extra_info['traceback'] = error['extra_info']['traceback']
+
+        # Per OpenAPI spec, extra_info must have exception_type and traceback
+        # For non-exception validation errors, provide synthetic values
+        if 'exception_type' not in extra_info:
+            extra_info['exception_type'] = 'ValidationError'
+        if 'traceback' not in extra_info:
+            # Capture current stack for validation errors that don't have their own traceback
+            extra_info['traceback'] = traceback.format_stack()
+
+        # Include detailed node_errors for actionable debugging information
+        if node_errors:
+            extra_info['node_errors'] = node_errors
+
         errors_list = "\n".join(errors_list)

         error = {
             "type": "prompt_outputs_failed_validation",
             "message": "Prompt outputs failed validation",
             "details": errors_list,
-            "extra_info": {}
+            "extra_info": extra_info
         }

         return ValidationTuple(False, error, list(good_outputs), node_errors)
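For reference, the aggregation above produces a payload shaped roughly like this (all values invented for illustration, following the new node_errors schema):

error = {
    "type": "prompt_outputs_failed_validation",
    "message": "Prompt outputs failed validation",
    "details": "Required input is missing: images",
    "extra_info": {
        "exception_type": "ValidationError",  # synthetic when no exception was raised
        "traceback": ["  File example.py, line 1, in main"],  # format_stack() lines
        "node_errors": {  # per-node details for actionable debugging
            "9": {"errors": [], "dependent_outputs": ["9"], "class_type": "SaveImage"},
        },
    },
}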
@@ -1301,7 +1321,7 @@ class PromptQueue(AbstractPromptQueue):
             return copy.deepcopy(item_with_future.queue_tuple), task_id

     def task_done(self, item_id: str, outputs: HistoryResultDict,
-                  status: Optional[ExecutionStatus]):
+                  status: Optional[ExecutionStatus], error_details: Optional[ExecutionErrorMessage] = None):
         history_result = outputs
         with self.mutex:
             queue_item = self.currently_running.pop(item_id)
@@ -1311,7 +1331,7 @@ class PromptQueue(AbstractPromptQueue):

             status_dict = None
             if status is not None:
-                status_dict: Optional[ExecutionStatusAsDict] = status.as_dict()
+                status_dict: Optional[ExecutionStatusAsDict] = status.as_dict(error_details=error_details)

             outputs_ = history_result["outputs"]
             # Remove sensitive data from extra_data before storing in history
@@ -79,12 +79,25 @@ async def _prompt_worker(q: AbstractPromptQueue, server_instance: server_module.

             await e.execute_async(item[2], prompt_id, item[3], item[4])
             need_gc = True
+
+            # Extract error details from status_messages if there's an error
+            error_details = None
+            if not e.success:
+                for event, data in e.status_messages:
+                    if event == "execution_error":
+                        error_details = data
+                        break
+
+            # Convert status_messages tuples to string messages for backward compatibility
+            messages = [f"{event}: {data.get('exception_message', str(data))}" if isinstance(data, dict) and 'exception_message' in data else f"{event}" for event, data in e.status_messages]
+
             q.task_done(item_id,
                         e.history_result,
                         status=queue_types.ExecutionStatus(
                             status_str='success' if e.success else 'error',
                             completed=e.success,
-                            messages=e.status_messages))
+                            messages=messages),
+                        error_details=error_details)
             if server_instance.client_id is not None:
                 server_instance.send_sync("executing", {"node": None, "prompt_id": prompt_id},
                                           server_instance.client_id)
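The comprehension above flattens (event, data) tuples into the plain strings the existing messages field expects. A behavior sketch, runnable in isolation (event names and data invented for illustration):

status_messages = [
    ("execution_error", {"exception_message": "CUDA out of memory", "node_id": "3"}),
    ("execution_interrupted", {}),
]
messages = [
    f"{event}: {data.get('exception_message', str(data))}"
    if isinstance(data, dict) and 'exception_message' in data else f"{event}"
    for event, data in status_messages
]
print(messages)  # ['execution_error: CUDA out of memory', 'execution_interrupted']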
@@ -15,6 +15,7 @@ import shutil
 import warnings

 import fsspec
+from opentelemetry.instrumentation.urllib3 import URLLib3Instrumentor

 from .. import options
 from ..app import logger
@@ -41,6 +42,7 @@ warnings.filterwarnings("ignore", message="torch.utils._pytree._register_pytree_
 warnings.filterwarnings("ignore", message="Torch was not compiled with flash attention.")
 warnings.filterwarnings("ignore", message=".*Torch was not compiled with flash attention.*")
 warnings.filterwarnings('ignore', category=FutureWarning, message=r'`torch\.cuda\.amp\.custom_fwd.*')
+warnings.filterwarnings("ignore", category=UserWarning, message="Please use the new API settings to control TF32 behavior.*")
 warnings.filterwarnings("ignore", message="Importing from timm.models.registry is deprecated, please import via timm.models", category=FutureWarning)
 warnings.filterwarnings("ignore", message="Importing from timm.models.layers is deprecated, please import via timm.layers", category=FutureWarning)
 warnings.filterwarnings("ignore", message="Inheritance class _InstrumentedApplication from web.Application is discouraged", category=DeprecationWarning)
@@ -125,9 +127,11 @@ def _create_tracer():
     from opentelemetry.sdk.resources import Resource
     from opentelemetry.sdk.trace import TracerProvider
     from opentelemetry.sdk.trace.export import BatchSpanProcessor, SpanExporter
+    from opentelemetry.processor.baggage import BaggageSpanProcessor, ALLOW_ALL_BAGGAGE_KEYS
+    from opentelemetry.instrumentation.aiohttp_server import AioHttpServerInstrumentor
+    from opentelemetry.instrumentation.aiohttp_client import AioHttpClientInstrumentor
     from ..tracing_compatibility import ProgressSpanSampler
     from ..tracing_compatibility import patch_spanbuilder_set_channel
-    from ..vendor.aiohttp_server_instrumentation import AioHttpServerInstrumentor

     resource = Resource.create({
         service_attributes.SERVICE_NAME: args.otel_service_name,
@@ -141,18 +145,24 @@ def _create_tracer():
     has_endpoint = args.otel_exporter_otlp_endpoint is not None

     if has_endpoint:
-        otlp_exporter = OTLPSpanExporter()
+        exporter = OTLPSpanExporter()
     else:
-        otlp_exporter = SpanExporter()
+        exporter = SpanExporter()

-    processor = BatchSpanProcessor(otlp_exporter)
+    processor = BatchSpanProcessor(exporter)
     provider.add_span_processor(processor)

     # enable instrumentation
     patch_spanbuilder_set_channel()

     AioPikaInstrumentor().instrument()
     AioHttpServerInstrumentor().instrument()
+    AioHttpClientInstrumentor().instrument()
     RequestsInstrumentor().instrument()
+    URLLib3Instrumentor().instrument()
+
+
+    provider.add_span_processor(BaggageSpanProcessor(ALLOW_ALL_BAGGAGE_KEYS))
     # makes this behave better as a library
     return trace.get_tracer(args.otel_service_name, tracer_provider=provider)
@@ -116,6 +116,24 @@ async def compress_body(request: web.Request, handler):
     return response


+@web.middleware
+async def opentelemetry_middleware(request: web.Request, handler):
+    """Middleware to extract and propagate OpenTelemetry context from request headers"""
+    from opentelemetry import propagate, context
+
+    # Extract OpenTelemetry context from headers
+    carrier = dict(request.headers)
+    ctx = propagate.extract(carrier)
+
+    # Attach context and execute handler
+    token = context.attach(ctx)
+    try:
+        response = await handler(request)
+        return response
+    finally:
+        context.detach(token)
+
+
 def create_cors_middleware(allowed_origin: str):
     @web.middleware
     async def cors_middleware(request: web.Request, handler):
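Together with the CORS header change below (which allowlists traceparent and tracestate), this middleware lets callers carry W3C trace context into the server's spans. A client-side sketch (the OpenTelemetry SDK setup is omitted and the endpoint is a placeholder):

import requests
from opentelemetry import propagate, trace

tracer = trace.get_tracer(__name__)
with tracer.start_as_current_span("queue prompt"):
    headers = {}
    propagate.inject(headers)  # writes traceparent/tracestate into the carrier
    requests.get("http://localhost:8188/prompt", headers=headers)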
@@ -127,7 +145,7 @@ def create_cors_middleware(allowed_origin: str):

         response.headers['Access-Control-Allow-Origin'] = allowed_origin
         response.headers['Access-Control-Allow-Methods'] = 'POST, GET, DELETE, PUT, OPTIONS'
-        response.headers['Access-Control-Allow-Headers'] = 'Content-Type, Authorization'
+        response.headers['Access-Control-Allow-Headers'] = 'Content-Type, Authorization, traceparent, tracestate'
         response.headers['Access-Control-Allow-Credentials'] = 'true'
         return response

@@ -224,7 +242,7 @@ class PromptServer(ExecutorToClientProgress):
         self._external_address: Optional[str] = None
         self.background_tasks: dict[str, Task] = dict()

-        middlewares = [cache_control, deprecation_warning]
+        middlewares = [opentelemetry_middleware, cache_control, deprecation_warning]
         if args.enable_compress_response_body:
             middlewares.append(compress_body)

@@ -867,9 +885,19 @@ class PromptServer(ExecutorToClientProgress):
                 return web.json_response(status=404)
             elif prompt_id in history_items:
                 history_entry = history_items[prompt_id]
+                # Check if execution resulted in an error
+                if "status" in history_entry:
+                    status = history_entry["status"]
+                    if isinstance(status, dict) and status.get("status_str") == "error":
+                        # Return ExecutionStatusAsDict format with status 500, matching POST /api/v1/prompts behavior
+                        return web.Response(
+                            body=json.dumps(status),
+                            status=500,
+                            content_type="application/json"
+                        )
                 return web.json_response(history_entry["outputs"])
             else:
-                return web.json_response(status=500)
+                return web.Response(status=404, reason="prompt not found in expected state")

         @routes.post("/api/v1/prompts")
         async def post_api_prompt(request: web.Request) -> web.Response | web.FileResponse:
@@ -877,9 +905,13 @@ class PromptServer(ExecutorToClientProgress):
             if accept == '*/*':
                 accept = "application/json"
             content_type = request.headers.get("content-type", "application/json")
-            preferences = request.headers.get("prefer", "") + request.query.get("prefer", "") + " " + content_type
+            preferences = request.headers.get("prefer", "") + request.query.get("prefer", "") + " " + content_type + " " + accept
+
+            # handle media type parameters like "application/json+respond-async"
             if "+" in content_type:
                 content_type = content_type.split("+")[0]
+            if "+" in accept:
+                accept = accept.split("+")[0]

             wait = not "respond-async" in preferences

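The Accept header now participates in the respond-async decision and may carry a media type parameter that must be stripped before content negotiation. A small standalone sketch of the new parsing (header values invented):

accept = "application/json+respond-async"
preferences = "" + " " + "application/json" + " " + accept
if "+" in accept:
    accept = accept.split("+")[0]
print(accept)                          # application/json
print("respond-async" in preferences)  # True, so the server will not wait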
@@ -965,7 +997,8 @@ class PromptServer(ExecutorToClientProgress):
                 return web.Response(body=str(ex), status=500)

             if result.status is not None and result.status.status_str == "error":
-                return web.Response(body=json.dumps(result.status._asdict()), status=500, content_type="application/json")
+                status_dict = result.status.as_dict(error_details=result.error_details)
+                return web.Response(body=json.dumps(status_dict), status=500, content_type="application/json")
             # find images and read them
             output_images: List[FileOutput] = []
             for node_id, node in result.outputs.items():
@@ -199,9 +199,8 @@ class ValidationErrorExtraInfoDict(TypedDict, total=False):
     input_config: NotRequired[Dict[str, InputTypeSpec]]
     received_value: NotRequired[Any]
     linked_node: NotRequired[str]
-    traceback: NotRequired[list[str]]
     exception_message: NotRequired[str]
-    exception_type: NotRequired[str]
+    node_errors: NotRequired[Dict[str, 'NodeErrorsDictValue']]


 class ValidationErrorDict(TypedDict):
@@ -18,6 +18,7 @@ class TaskInvocation(NamedTuple):
     item_id: int | str
     outputs: OutputsDict
     status: Optional[ExecutionStatus]
+    error_details: Optional['ExecutionErrorMessage'] = None


 class ExecutionStatus(NamedTuple):
@@ -25,12 +26,15 @@ class ExecutionStatus(NamedTuple):
     completed: bool
     messages: List[str]

-    def as_dict(self) -> ExecutionStatusAsDict:
-        return {
+    def as_dict(self, error_details: Optional['ExecutionErrorMessage'] = None) -> ExecutionStatusAsDict:
+        result: ExecutionStatusAsDict = {
             "status_str": self.status_str,
             "completed": self.completed,
             "messages": copy.copy(self.messages),
         }
+        if error_details is not None:
+            result["error_details"] = error_details
+        return result


 class ExecutionError(RuntimeError):
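Illustrative output of the extended as_dict, as a self-contained sketch (the class body mirrors the diff, trimmed; field values invented):

import copy
from typing import List, NamedTuple, Optional

class ExecutionStatus(NamedTuple):  # mirrors the diff's class, trimmed for illustration
    status_str: str
    completed: bool
    messages: List[str]

    def as_dict(self, error_details: Optional[dict] = None) -> dict:
        result = {"status_str": self.status_str, "completed": self.completed,
                  "messages": copy.copy(self.messages)}
        if error_details is not None:
            result["error_details"] = error_details
        return result

status = ExecutionStatus('error', False, ['execution_error: CUDA out of memory'])
print(status.as_dict(error_details={'node_id': '3', 'exception_message': 'CUDA out of memory'}))
# {'status_str': 'error', 'completed': False, 'messages': [...], 'error_details': {...}}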
@@ -162,7 +162,7 @@ class DistributedPromptQueue(AbstractPromptQueue, AsyncAbstractPromptQueue):

         return item, item[1]

-    def task_done(self, item_id: int, outputs: dict, status: Optional[ExecutionStatus]):
+    def task_done(self, item_id: int, outputs: dict, status: Optional[ExecutionStatus], error_details: Optional['ExecutionErrorMessage'] = None):
         # callee: executed on the worker thread
         if "outputs" in outputs:
             outputs: HistoryResultDict
@@ -173,7 +173,7 @@ class DistributedPromptQueue(AbstractPromptQueue, AsyncAbstractPromptQueue):
         assert pending.completed is not None
         assert not pending.completed.done()
         # finish the task. status will transmit the errors in comfy's domain-specific way
-        pending.completed.set_result(TaskInvocation(item_id=item_id, outputs=outputs, status=status))
+        pending.completed.set_result(TaskInvocation(item_id=item_id, outputs=outputs, status=status, error_details=error_details))
         # todo: the caller is responsible for sending a websocket message right now that the UI expects for updates

     def get_current_queue(self) -> Tuple[List[QueueTuple], List[QueueTuple]]:
@@ -3,7 +3,7 @@ from typing import Optional, Sequence
 from aio_pika.abc import AbstractChannel
 from opentelemetry.context import Context
 from opentelemetry.sdk.trace.sampling import Sampler, SamplingResult, Decision
-from opentelemetry.semconv.trace import SpanAttributes
+from opentelemetry.semconv.attributes.network_attributes import NETWORK_PEER_ADDRESS, NETWORK_PEER_PORT
 from opentelemetry.trace import SpanKind, Link, TraceState
 from opentelemetry.util.types import Attributes

@@ -22,8 +22,8 @@ def patch_spanbuilder_set_channel() -> None:
         port = url.port or 5672
         self._attributes.update(
             {
-                SpanAttributes.NET_PEER_NAME: url.host,
+                NETWORK_PEER_ADDRESS: url.host,
-                SpanAttributes.NET_PEER_PORT: port,
+                NETWORK_PEER_PORT: port,
             }
         )

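The deprecated SpanAttributes.NET_PEER_* constants are replaced with the stable semantic-convention names imported above. A quick check of their values (the module and names are the ones the diff itself imports):

from opentelemetry.semconv.attributes.network_attributes import (
    NETWORK_PEER_ADDRESS,
    NETWORK_PEER_PORT,
)
print(NETWORK_PEER_ADDRESS, NETWORK_PEER_PORT)  # network.peer.address network.peer.port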
comfy/vendor/aiohttp_server_instrumentation.py (vendored): 271 deletions
@@ -1,271 +0,0 @@
-# Copyright 2020, OpenTelemetry Authors
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import urllib
-from timeit import default_timer
-from typing import Dict, List, Tuple, Union
-
-from aiohttp import web
-from multidict import CIMultiDictProxy
-
-from opentelemetry import metrics, trace
-
-_instruments = ("aiohttp ~= 3.0",)
-__version__ = "0.49b0.dev"
-
-from opentelemetry.instrumentation.instrumentor import BaseInstrumentor
-from opentelemetry.instrumentation.utils import (
-    http_status_to_status_code,
-    is_http_instrumentation_enabled,
-)
-from opentelemetry.propagate import extract
-from opentelemetry.propagators.textmap import Getter
-from opentelemetry.semconv.metrics import MetricInstruments
-from opentelemetry.semconv.trace import SpanAttributes
-from opentelemetry.trace.status import Status, StatusCode
-from opentelemetry.util.http import get_excluded_urls, remove_url_credentials
-
-_duration_attrs = [
-    SpanAttributes.HTTP_METHOD,
-    SpanAttributes.HTTP_HOST,
-    SpanAttributes.HTTP_SCHEME,
-    SpanAttributes.HTTP_STATUS_CODE,
-    SpanAttributes.HTTP_FLAVOR,
-    SpanAttributes.HTTP_SERVER_NAME,
-    SpanAttributes.NET_HOST_NAME,
-    SpanAttributes.NET_HOST_PORT,
-    SpanAttributes.HTTP_ROUTE,
-]
-
-_active_requests_count_attrs = [
-    SpanAttributes.HTTP_METHOD,
-    SpanAttributes.HTTP_HOST,
-    SpanAttributes.HTTP_SCHEME,
-    SpanAttributes.HTTP_FLAVOR,
-    SpanAttributes.HTTP_SERVER_NAME,
-]
-
-tracer = trace.get_tracer(__name__)
-meter = metrics.get_meter(__name__, __version__)
-_excluded_urls = get_excluded_urls("AIOHTTP_SERVER")
-
-
-def _parse_duration_attrs(req_attrs):
-    duration_attrs = {}
-    for attr_key in _duration_attrs:
-        if req_attrs.get(attr_key) is not None:
-            duration_attrs[attr_key] = req_attrs[attr_key]
-    return duration_attrs
-
-
-def _parse_active_request_count_attrs(req_attrs):
-    active_requests_count_attrs = {}
-    for attr_key in _active_requests_count_attrs:
-        if req_attrs.get(attr_key) is not None:
-            active_requests_count_attrs[attr_key] = req_attrs[attr_key]
-    return active_requests_count_attrs
-
-
-def get_default_span_details(request: web.Request) -> Tuple[str, dict]:
-    """Default implementation for get_default_span_details
-    Args:
-        request: the request object itself.
-    Returns:
-        a tuple of the span name, and any attributes to attach to the span.
-    """
-    span_name = request.path.strip() or f"HTTP {request.method}"
-    return span_name, {}
-
-
-def _get_view_func(request: web.Request) -> str:
-    """Returns the name of the request handler.
-    Args:
-        request: the request object itself.
-    Returns:
-        a string containing the name of the handler function
-    """
-    try:
-        return request.match_info.handler.__name__
-    except AttributeError:
-        return "unknown"
-
-
-def collect_request_attributes(request: web.Request) -> Dict:
-    """Collects HTTP request attributes from the ASGI scope and returns a
-    dictionary to be used as span creation attributes."""
-
-    server_host, port, http_url = (
-        request.url.host,
-        request.url.port,
-        str(request.url),
-    )
-    query_string = request.query_string
-    if query_string and http_url:
-        if isinstance(query_string, bytes):
-            query_string = query_string.decode("utf8")
-        http_url += "?" + urllib.parse.unquote(query_string)
-
-    result = {
-        SpanAttributes.HTTP_SCHEME: request.scheme,
-        SpanAttributes.HTTP_HOST: server_host,
-        SpanAttributes.NET_HOST_PORT: port,
-        SpanAttributes.HTTP_ROUTE: _get_view_func(request),
-        SpanAttributes.HTTP_FLAVOR: f"{request.version.major}.{request.version.minor}",
-        SpanAttributes.HTTP_TARGET: request.path,
-        SpanAttributes.HTTP_URL: remove_url_credentials(http_url),
-    }
-
-    http_method = request.method
-    if http_method:
-        result[SpanAttributes.HTTP_METHOD] = http_method
-
-    http_host_value_list = (
-        [request.host] if not isinstance(request.host, list) else request.host
-    )
-    if http_host_value_list:
-        result[SpanAttributes.HTTP_SERVER_NAME] = ",".join(
-            http_host_value_list
-        )
-    http_user_agent = request.headers.get("user-agent")
-    if http_user_agent:
-        result[SpanAttributes.HTTP_USER_AGENT] = http_user_agent
-
-    # remove None values
-    result = {k: v for k, v in result.items() if v is not None}
-
-    return result
-
-
-def set_status_code(span, status_code: int) -> None:
-    """Adds HTTP response attributes to span using the status_code argument."""
-
-    try:
-        status_code = int(status_code)
-    except ValueError:
-        span.set_status(
-            Status(
-                StatusCode.ERROR,
-                "Non-integer HTTP status: " + repr(status_code),
-            )
-        )
-    else:
-        span.set_attribute(SpanAttributes.HTTP_STATUS_CODE, status_code)
-        span.set_status(
-            Status(http_status_to_status_code(status_code, server_span=True))
-        )
-
-
-class AiohttpGetter(Getter):
-    """Extract current trace from headers"""
-
-    def get(self, carrier, key: str) -> Union[List, None]:
-        """Getter implementation to retrieve an HTTP header value from the ASGI
-        scope.
-
-        Args:
-            carrier: ASGI scope object
-            key: header name in scope
-        Returns:
-            A list of all header values matching the key, or None if the key
-            does not match any header.
-        """
-        headers: CIMultiDictProxy = carrier.headers
-        if not headers:
-            return None
-        return headers.getall(key, None)
-
-    def keys(self, carrier: Dict) -> List:
-        return list(carrier.keys())
-
-
-getter = AiohttpGetter()
-
-
-@web.middleware
-async def middleware(request, handler):
-    """Middleware for aiohttp implementing tracing logic"""
-    if not is_http_instrumentation_enabled() or _excluded_urls.url_disabled(
-        request.url.path
-    ):
-        return await handler(request)
-
-    span_name, additional_attributes = get_default_span_details(request)
-
-    req_attrs = collect_request_attributes(request)
-    duration_attrs = _parse_duration_attrs(req_attrs)
-    active_requests_count_attrs = _parse_active_request_count_attrs(req_attrs)
-
-    duration_histogram = meter.create_histogram(
-        name=MetricInstruments.HTTP_SERVER_DURATION,
-        unit="ms",
-        description="Measures the duration of inbound HTTP requests.",
-    )
-
-    active_requests_counter = meter.create_up_down_counter(
-        name=MetricInstruments.HTTP_SERVER_ACTIVE_REQUESTS,
-        unit="requests",
-        description="measures the number of concurrent HTTP requests those are currently in flight",
-    )
-
-    with tracer.start_as_current_span(
-        span_name,
-        context=extract(request, getter=getter),
-        kind=trace.SpanKind.SERVER,
-    ) as span:
-        attributes = collect_request_attributes(request)
-        attributes.update(additional_attributes)
-        span.set_attributes(attributes)
-        start = default_timer()
-        active_requests_counter.add(1, active_requests_count_attrs)
-        try:
-            resp = await handler(request)
-            set_status_code(span, resp.status)
-        except web.HTTPException as ex:
-            set_status_code(span, ex.status_code)
-            raise
-        except AttributeError:
-            # No response was returned or a NoneType response was returned, handle gracefully
-            set_status_code(span, 500)
-            raise
-        finally:
-            duration = max((default_timer() - start) * 1000, 0)
-            duration_histogram.record(duration, duration_attrs)
-            active_requests_counter.add(-1, active_requests_count_attrs)
-        return resp
-
-
-class _InstrumentedApplication(web.Application):
-    """Insert tracing middleware"""
-
-    def __init__(self, *args, **kwargs):
-        middlewares = kwargs.pop("middlewares", [])
-        middlewares.insert(0, middleware)
-        kwargs["middlewares"] = middlewares
-        super().__init__(*args, **kwargs)
-
-
-class AioHttpServerInstrumentor(BaseInstrumentor):
-    # pylint: disable=protected-access,attribute-defined-outside-init
-    """An instrumentor for aiohttp.web.Application
-
-    See `BaseInstrumentor`
-    """
-
-    def _instrument(self, **kwargs):
-        self._original_app = web.Application
-        setattr(web, "Application", _InstrumentedApplication)
-
-    def _uninstrument(self, **kwargs):
-        setattr(web, "Application", self._original_app)
-
-    def instrumentation_dependencies(self):
-        return _instruments
comfy_extras/eval_web/__init__.py: new file, empty
242
comfy_extras/eval_web/eval_python.js
Normal file
242
comfy_extras/eval_web/eval_python.js
Normal file
@@ -0,0 +1,242 @@
/**
 * Uses code adapted from https://github.com/yorkane/ComfyUI-KYNode
 *
 * MIT License
 *
 * Copyright (c) 2024 Kevin Yuan
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
import { app } from "../../scripts/app.js";

// Load Ace editor using script tag for Safari compatibility
// The noconflict build includes AMD loader that works in all browsers
let ace;
const aceLoadPromise = new Promise((resolve) => {
  if (window.ace) {
    ace = window.ace;
    resolve();
  } else {
    const script = document.createElement("script");
    script.src = "https://cdn.jsdelivr.net/npm/ace-builds@1.43.4/src-noconflict/ace.js";
    script.onload = () => {
      ace = window.ace;
      ace.config.set("basePath", "https://cdn.jsdelivr.net/npm/ace-builds@1.43.4/src-noconflict");
      resolve();
    };
    document.head.appendChild(script);
  }
});

// todo: do we really want to do this here?
await aceLoadPromise;

const findWidget = (node, value, attr = "name", func = "find") => {
  return node?.widgets ? node.widgets[func]((w) => (Array.isArray(value) ? value.includes(w[attr]) : w[attr] === value)) : null;
};

const makeElement = (tag, attrs = {}) => {
  if (!tag) tag = "div";
  const element = document.createElement(tag);
  Object.keys(attrs).forEach((key) => {
    const currValue = attrs[key];
    if (key === "class") {
      if (Array.isArray(currValue)) {
        element.classList.add(...currValue);
      } else if (currValue instanceof String || typeof currValue === "string") {
        element.className = currValue;
      }
    } else if (key === "dataset") {
      try {
        if (Array.isArray(currValue)) {
          currValue.forEach((datasetArr) => {
            const [prop, propval] = Object.entries(datasetArr)[0];
            element.dataset[prop] = propval;
          });
        } else {
          Object.entries(currValue).forEach((datasetArr) => {
            const [prop, propval] = datasetArr;
            element.dataset[prop] = propval;
          });
        }
      } catch (err) {
        // todo: what is this trying to do?
      }
    } else if (key === "style") {
      if (typeof currValue === "object" && !Array.isArray(currValue) && Object.keys(currValue).length) {
        Object.assign(element[key], currValue);
      } else if (typeof currValue === "object" && Array.isArray(currValue) && currValue.length) {
        element[key] = [...currValue];
      } else if (currValue instanceof String || typeof currValue === "string") {
        element[key] = currValue;
      }
    } else if (["for"].includes(key)) {
      element.setAttribute(key, currValue);
    } else if (key === "children") {
      element.append(...(currValue instanceof Array ? currValue : [currValue]));
    } else if (key === "parent") {
      currValue.append(element);
    } else {
      element[key] = currValue;
    }
  });
  return element;
};

const getPosition = (node, ctx, w_width, y, n_height) => {
  const margin = 5;

  const rect = ctx.canvas.getBoundingClientRect();
  const transform = ctx.getTransform();
  const scale = app.canvas.ds.scale;

  // The context is already transformed to draw at the widget position
  // transform.e and transform.f give us the canvas coordinates (in canvas pixels)
  // We need to convert these to screen pixels by accounting for the canvas scale
  // rect gives us the canvas element's position on the page

  // The transform matrix has scale baked in (transform.a = transform.d = scale)
  // transform.e and transform.f are the translation in canvas-pixel space
  const canvasPixelToScreenPixel = rect.width / ctx.canvas.width;

  const x = transform.e * canvasPixelToScreenPixel + rect.left;
  const y_pos = transform.f * canvasPixelToScreenPixel + rect.top;

  // Convert widget dimensions from canvas coordinates to screen pixels
  const scaledWidth = w_width * scale;
  const scaledHeight = (n_height - y - 15) * scale;
  const scaledMargin = margin * scale;
  const scaledY = y * scale;

  return {
    left: `${x + scaledMargin}px`,
    top: `${y_pos + scaledY + scaledMargin}px`,
    width: `${scaledWidth - scaledMargin * 2}px`,
    maxWidth: `${scaledWidth - scaledMargin * 2}px`,
    height: `${scaledHeight - scaledMargin * 2}px`,
    maxHeight: `${scaledHeight - scaledMargin * 2}px`,
    position: "absolute",
    scrollbarColor: "var(--descrip-text) var(--bg-color)",
    scrollbarWidth: "thin",
    zIndex: app.graph._nodes.indexOf(node),
  };
};

// Create code editor widget
const codeEditor = (node, inputName, inputData) => {
  const widget = {
    type: "code_block_python",
    name: inputName,
    options: { hideOnZoom: true },
    value: inputData[1]?.default || "",
    draw(ctx, node, widgetWidth, y) {
      const hidden = node.flags?.collapsed || (!!this.options.hideOnZoom && app.canvas.ds.scale < 0.5) || this.type === "converted-widget" || this.type === "hidden";

      this.codeElement.hidden = hidden;

      if (hidden) {
        this.options.onHide?.(this);
        return;
      }

      Object.assign(this.codeElement.style, getPosition(node, ctx, widgetWidth, y, node.size[1]));
    },
    computeSize() {
      return [500, 250];
    },
  };

  widget.codeElement = makeElement("pre", {
    innerHTML: widget.value,
  });

  widget.editor = ace.edit(widget.codeElement);
  widget.editor.setTheme("ace/theme/monokai");
  widget.editor.session.setMode("ace/mode/python");
  widget.editor.setOptions({
    enableAutoIndent: true,
    enableLiveAutocompletion: true,
    enableBasicAutocompletion: true,
    fontFamily: "monospace",
  });
  widget.codeElement.hidden = true;

  document.body.appendChild(widget.codeElement);

  const originalCollapse = node.collapse;
  node.collapse = function () {
    originalCollapse.apply(this, arguments);
    widget.codeElement.hidden = !!this.flags?.collapsed;
  };

  return widget;
};

// Trigger workflow change tracking
const markWorkflowChanged = () => {
  app?.extensionManager?.workflow?.activeWorkflow?.changeTracker?.checkState();
};

// Register extensions
app.registerExtension({
  name: "Comfy.EvalPython",
  getCustomWidgets(app) {
    return {
      CODE_BLOCK_PYTHON: (node, inputName, inputData) => {
        const widget = codeEditor(node, inputName, inputData);

        widget.editor.getSession().on("change", () => {
          widget.value = widget.editor.getValue();
          markWorkflowChanged();
        });

        node.onRemoved = function () {
          for (const w of this.widgets) {
            if (w?.codeElement) {
              w.codeElement.remove();
            }
          }
        };

        node.addCustomWidget(widget);

        return widget;
      },
    };
  },

  async beforeRegisterNodeDef(nodeType, nodeData) {
    // Handle all EvalPython node variants
    if (nodeData.name.startsWith("EvalPython")) {
      const originalOnConfigure = nodeType.prototype.onConfigure;
      nodeType.prototype.onConfigure = function (info) {
        originalOnConfigure?.apply(this, arguments);

        if (info?.widgets_values?.length) {
          const widgetCodeIndex = findWidget(this, "code_block_python", "type", "findIndex");
          const editor = this.widgets[widgetCodeIndex]?.editor;

          if (editor) {
            editor.setValue(info.widgets_values[widgetCodeIndex]);
            editor.clearSelection();
          }
        }
      };
    }
  },
});
119 comfy_extras/nodes/nodes_eval.py Normal file
@@ -0,0 +1,119 @@
import logging

from comfy.comfy_types import IO
from comfy.execution_context import current_execution_context
from comfy.node_helpers import export_package_as_web_directory, export_custom_nodes
from comfy.nodes.package_typing import CustomNode

logger = logging.getLogger(__name__)


def eval_python(inputs=5, outputs=5, name=None, input_is_list=None, output_is_list=None):
    """
    Factory function to create EvalPython node classes with configurable input/output counts.

    Args:
        inputs: Number of input value slots (default: 5)
        outputs: Number of output item slots (default: 5)
        name: Class name (default: f"EvalPython_{inputs}_{outputs}")
        input_is_list: Optional list of bools indicating which inputs accept lists (default: None, meaning all scalar)
        output_is_list: Optional tuple of bools indicating which outputs return lists (default: None, meaning all scalar)

    Returns:
        A CustomNode subclass configured with the specified inputs/outputs
    """
    if name is None:
        name = f"EvalPython_{inputs}_{outputs}"

    default_code = f"""
print("Hello World!")
return {", ".join([f"value{i}" for i in range(inputs)])}
"""

    class EvalPythonNode(CustomNode):
        @classmethod
        def INPUT_TYPES(cls):
            return {
                "required": {
                    "pycode": (
                        "CODE_BLOCK_PYTHON",
                        {
                            "default": default_code
                        },
                    ),
                },
                "optional": {f"value{i}": (IO.ANY, {}) for i in range(inputs)},
            }

        RETURN_TYPES = tuple(IO.ANY for _ in range(outputs))
        RETURN_NAMES = tuple(f"item{i}" for i in range(outputs))
        FUNCTION = "exec_py"
        DESCRIPTION = ""
        CATEGORY = "eval"

        @classmethod
        def VALIDATE_INPUTS(cls, *args, **kwargs):
            ctx = current_execution_context()

            return ctx.configuration.enable_eval

        def exec_py(self, pycode, **kwargs):
            ctx = current_execution_context()

            kwargs = {
                **{f"value{i}": None for i in range(inputs)},
                **kwargs,
            }

            def print(*args):
                ctx.server.send_progress_text(" ".join(map(str, args)), ctx.node_id)

            if not ctx.configuration.enable_eval:
                raise ValueError("Python eval is disabled")

            value_args = [kwargs.pop(f"value{i}") for i in range(inputs)]
            arg_names = ", ".join(f"value{i}=None" for i in range(inputs))

            wrapped_code = f"def _eval_func({arg_names}):\n"
            for line in pycode.splitlines():
                wrapped_code += "    " + line + "\n"

            globals_for_eval = {
                **kwargs,
                "logger": logger,
                "print": print,
            }

            exec(wrapped_code, globals_for_eval)
            results = globals_for_eval["_eval_func"](*value_args)
            if not isinstance(results, tuple):
                results = (results,)

            if len(results) < outputs:
                results += (None,) * (outputs - len(results))
            elif len(results) > outputs:
                results = results[:outputs]

            return results

    # todo: interact better with the weird comfyui machinery for this
    if input_is_list is not None:
        setattr(EvalPythonNode, "INPUT_IS_LIST", input_is_list)

    if output_is_list is not None:
        setattr(EvalPythonNode, "OUTPUT_IS_LIST", output_is_list)

    EvalPythonNode.__name__ = name
    EvalPythonNode.__qualname__ = name

    return EvalPythonNode


EvalPython_1_1 = eval_python(inputs=1, outputs=1, name="EvalPython_1_1")
EvalPython_5_5 = eval_python(inputs=5, outputs=5, name="EvalPython_5_5")
EvalPython_List_1 = eval_python(inputs=1, outputs=1, name="EvalPython_List_1", input_is_list=True, output_is_list=None)
EvalPython_1_List = eval_python(inputs=1, outputs=1, name="EvalPython_1_List", input_is_list=None, output_is_list=(True,))
EvalPython_List_List = eval_python(inputs=1, outputs=1, name="EvalPython_List_List", input_is_list=True, output_is_list=(True,))

export_custom_nodes()
export_package_as_web_directory("comfy_extras.eval_web")
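Note: for reference, `exec_py` turns a `pycode` body into a function before running it. A sketch of what a one-input node builds from a body like the default above (the real wrapper also injects `logger` and the progress-reporting `print` into the exec globals):

    # user-supplied body:
    #     print("Hello World!")
    #     return value0
    # is wrapped into roughly:
    def _eval_func(value0=None):
        print("Hello World!")
        return value0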
@@ -65,13 +65,18 @@ dependencies = [
     "natsort",
     "OpenEXR",
     "opentelemetry-distro",
-    "opentelemetry-sdk<1.34.0",
-    "opentelemetry-exporter-otlp<=1.27.0",
+    "opentelemetry-sdk",
+    "opentelemetry-exporter-otlp",
     "opentelemetry-propagator-jaeger",
     "opentelemetry-instrumentation",
     "opentelemetry-util-http",
     "opentelemetry-instrumentation-aio-pika",
     "opentelemetry-instrumentation-requests",
+    "opentelemetry-instrumentation-aiohttp-server",
+    "opentelemetry-instrumentation-aiohttp-client",
+    "opentelemetry-instrumentation-asyncio",
+    "opentelemetry-instrumentation-urllib3",
+    "opentelemetry-processor-baggage",
     "opentelemetry-semantic-conventions",
     "wrapt>=1.16.0",
     "certifi",
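Note: the newly added instrumentation packages expose entry points that opentelemetry-distro can enable automatically at startup. A manual equivalent for the client side (a sketch, not necessarily how this repo wires it up) is:

    from opentelemetry.instrumentation.aiohttp_client import AioHttpClientInstrumentor

    AioHttpClientInstrumentor().instrument()  # traces outgoing aiohttp requests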
@@ -3,8 +3,6 @@ import multiprocessing
 import os
 import pathlib
 import subprocess
-import sys
-import time
 import urllib
 from contextvars import ContextVar
 from multiprocessing import Process
@@ -12,9 +10,9 @@ from typing import List, Any, Generator

 import pytest
 import requests
+import sys
+import time

-from comfy.cli_args import default_configuration
-from comfy.execution_context import context_configuration

 os.environ['OTEL_METRICS_EXPORTER'] = 'none'
 os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"
@@ -22,6 +20,7 @@ os.environ["HF_XET_HIGH_PERFORMANCE"] = "True"
 # fixes issues with running the testcontainers rabbitmqcontainer on Windows
 os.environ["TC_HOST"] = "localhost"

+from comfy.cli_args import default_configuration
 from comfy.cli_args_types import Configuration

 logging.getLogger("pika").setLevel(logging.CRITICAL + 1)
|
|||||||
params = rabbitmq.get_connection_params()
|
params = rabbitmq.get_connection_params()
|
||||||
connection_uri = f"amqp://guest:guest@127.0.0.1:{params.port}"
|
connection_uri = f"amqp://guest:guest@127.0.0.1:{params.port}"
|
||||||
|
|
||||||
|
# Check if OTEL endpoint is configured for integration testing
|
||||||
|
otel_endpoint = os.environ.get("OTEL_EXPORTER_OTLP_ENDPOINT")
|
||||||
|
|
||||||
|
env = os.environ.copy()
|
||||||
|
if otel_endpoint:
|
||||||
|
env["OTEL_EXPORTER_OTLP_ENDPOINT"] = otel_endpoint
|
||||||
|
logging.info(f"Configuring services to export traces to: {otel_endpoint}")
|
||||||
|
|
||||||
frontend_command = [
|
frontend_command = [
|
||||||
"comfyui",
|
"comfyui",
|
||||||
"--listen=127.0.0.1",
|
"--listen=0.0.0.0",
|
||||||
"--port=19001",
|
"--port=19001",
|
||||||
"--cpu",
|
"--cpu",
|
||||||
"--distributed-queue-frontend",
|
"--distributed-queue-frontend",
|
||||||
@ -101,7 +108,7 @@ def frontend_backend_worker_with_rabbitmq(request, tmp_path_factory, num_workers
|
|||||||
f"--distributed-queue-connection-uri={connection_uri}",
|
f"--distributed-queue-connection-uri={connection_uri}",
|
||||||
]
|
]
|
||||||
|
|
||||||
processes_to_close.append(subprocess.Popen(frontend_command, stdout=sys.stdout, stderr=sys.stderr))
|
processes_to_close.append(subprocess.Popen(frontend_command, stdout=sys.stdout, stderr=sys.stderr, env=env))
|
||||||
|
|
||||||
# Start multiple workers
|
# Start multiple workers
|
||||||
for i in range(num_workers):
|
for i in range(num_workers):
|
||||||
@ -112,7 +119,7 @@ def frontend_backend_worker_with_rabbitmq(request, tmp_path_factory, num_workers
|
|||||||
f"--distributed-queue-connection-uri={connection_uri}",
|
f"--distributed-queue-connection-uri={connection_uri}",
|
||||||
f"--executor-factory={executor_factory}"
|
f"--executor-factory={executor_factory}"
|
||||||
]
|
]
|
||||||
processes_to_close.append(subprocess.Popen(backend_command, stdout=sys.stdout, stderr=sys.stderr))
|
processes_to_close.append(subprocess.Popen(backend_command, stdout=sys.stdout, stderr=sys.stderr, env=env))
|
||||||
|
|
||||||
try:
|
try:
|
||||||
server_address = f"http://127.0.0.1:19001"
|
server_address = f"http://127.0.0.1:19001"
|
||||||
|
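Note: since the fixture now copies the parent environment into the spawned frontend and workers, exporting traces from a local run only requires setting the endpoint before the fixture starts its subprocesses (sketch; 4318 is the conventional OTLP/HTTP port):

    import os

    # must be set before frontend_backend_worker_with_rabbitmq spawns its subprocesses
    os.environ["OTEL_EXPORTER_OTLP_ENDPOINT"] = "http://localhost:4318"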
@@ -243,3 +243,368 @@ async def test_two_workers_distinct_requests():
         all_workflows.update(worker.processed_workflows)

     assert len(all_workflows) == 2, f"Expected 2 distinct workflows, but got {len(all_workflows)}"
+
+
+@pytest.mark.asyncio
+async def test_api_error_reporting_blocking_request(frontend_backend_worker_with_rabbitmq):
+    """Test error reporting with blocking request (no async preference)"""
+    async with AsyncRemoteComfyClient(server_address=frontend_backend_worker_with_rabbitmq) as client:
+        # Create an invalid prompt that will cause a validation error
+        prompt = sdxl_workflow_with_refiner("test", inference_steps=1, refiner_steps=1)
+        # Make the prompt invalid by referencing a non-existent checkpoint
+        prompt["4"]["inputs"]["ckpt_name"] = "nonexistent_checkpoint.safetensors"
+
+        # Post with blocking behavior (no prefer header for async)
+        prompt_json = client._AsyncRemoteComfyClient__json_encoder.encode(prompt)
+        async with client.session.post(
+            f"{frontend_backend_worker_with_rabbitmq}/api/v1/prompts",
+            data=prompt_json,
+            headers={'Content-Type': 'application/json', 'Accept': 'application/json'}
+        ) as response:
+            # Should return 400 for validation error (invalid checkpoint)
+            assert response.status == 400, f"Expected 400, got {response.status}"
+            error_body = await response.json()
+
+            # Verify ValidationErrorDict structure per OpenAPI spec
+            assert "type" in error_body, "Missing 'type' field in error response"
+            assert "message" in error_body, "Missing 'message' field in error response"
+            assert "details" in error_body, "Missing 'details' field in error response"
+            assert "extra_info" in error_body, "Missing 'extra_info' field in error response"
+
+
+@pytest.mark.asyncio
+async def test_api_error_reporting_async_prefer_header(frontend_backend_worker_with_rabbitmq):
+    """Test error reporting with Prefer: respond-async header"""
+    async with AsyncRemoteComfyClient(server_address=frontend_backend_worker_with_rabbitmq) as client:
+        # Create a valid prompt structure but with invalid checkpoint
+        prompt = sdxl_workflow_with_refiner("test", inference_steps=1, refiner_steps=1)
+        prompt["4"]["inputs"]["ckpt_name"] = "nonexistent.safetensors"
+
+        # Post with Prefer: respond-async header
+        prompt_json = client._AsyncRemoteComfyClient__json_encoder.encode(prompt)
+        async with client.session.post(
+            f"{frontend_backend_worker_with_rabbitmq}/api/v1/prompts",
+            data=prompt_json,
+            headers={
+                'Content-Type': 'application/json',
+                'Accept': 'application/json',
+                'Prefer': 'respond-async'
+            }
+        ) as response:
+            # Should return 400 immediately for validation error
+            assert response.status == 400, f"Expected 400 for validation error, got {response.status}"
+            error_body = await response.json()
+            assert "type" in error_body
+
+
+@pytest.mark.asyncio
+async def test_api_error_reporting_async_accept_mimetype(frontend_backend_worker_with_rabbitmq):
+    """Test error reporting with +respond-async in Accept mimetype"""
+    async with AsyncRemoteComfyClient(server_address=frontend_backend_worker_with_rabbitmq) as client:
+        # Create a prompt with validation error
+        prompt = sdxl_workflow_with_refiner("test", inference_steps=1, refiner_steps=1)
+        prompt["4"]["inputs"]["ckpt_name"] = "invalid_model.safetensors"
+
+        # Post with +respond-async in Accept header
+        prompt_json = client._AsyncRemoteComfyClient__json_encoder.encode(prompt)
+        async with client.session.post(
+            f"{frontend_backend_worker_with_rabbitmq}/api/v1/prompts",
+            data=prompt_json,
+            headers={
+                'Content-Type': 'application/json',
+                'Accept': 'application/json+respond-async'
+            }
+        ) as response:
+            # Should return 400 for validation error (happens before queuing)
+            assert response.status == 400, f"Expected 400, got {response.status}"
+            error_body = await response.json()
+            assert "type" in error_body
+
+
+@pytest.mark.asyncio
+async def test_api_get_prompt_status_success(frontend_backend_worker_with_rabbitmq):
+    """Test GET /api/v1/prompts/{prompt_id} returns 200 with Outputs on success"""
+    async with AsyncRemoteComfyClient(server_address=frontend_backend_worker_with_rabbitmq) as client:
+        # Create a valid prompt
+        prompt = sdxl_workflow_with_refiner("test", inference_steps=1, refiner_steps=1)
+
+        # Queue async to get prompt_id
+        task_id = await client.queue_and_forget_prompt_api(prompt, prefer_header="respond-async")
+        assert task_id is not None
+
+        # Poll until done
+        status_code, result = await client.poll_prompt_until_done(task_id, max_attempts=60, poll_interval=1.0)
+
+        # For a valid prompt, should get 200
+        assert status_code == 200, f"Expected 200 for successful execution, got {status_code}"
+        assert result is not None
+
+        # Verify it returns outputs structure (dict with node IDs)
+        assert isinstance(result, dict)
+        assert len(result) > 0, "Expected non-empty outputs"
+
+
+@pytest.mark.asyncio
+async def test_api_get_prompt_status_404(frontend_backend_worker_with_rabbitmq):
+    """Test GET /api/v1/prompts/{prompt_id} returns 404 for non-existent prompt"""
+    async with AsyncRemoteComfyClient(server_address=frontend_backend_worker_with_rabbitmq) as client:
+        # Request a non-existent prompt ID
+        fake_prompt_id = str(uuid.uuid4())
+
+        async with await client.get_prompt_status(fake_prompt_id) as response:
+            assert response.status == 404, f"Expected 404 for non-existent prompt, got {response.status}"
+
+
+@pytest.mark.asyncio
+async def test_api_get_prompt_status_204_in_progress(frontend_backend_worker_with_rabbitmq):
+    """Test GET /api/v1/prompts/{prompt_id} returns 204 while prompt is in progress"""
+    async with AsyncRemoteComfyClient(server_address=frontend_backend_worker_with_rabbitmq) as client:
+        # Create a prompt that takes some time to execute
+        prompt = sdxl_workflow_with_refiner("test", inference_steps=10, refiner_steps=10)
+
+        # Queue async
+        task_id = await client.queue_and_forget_prompt_api(prompt, prefer_header="respond-async")
+
+        # Immediately check status (should be 204 or 200 if very fast)
+        async with await client.get_prompt_status(task_id) as response:
+            # Should be either 204 (in progress) or 200 (completed very fast)
+            assert response.status in [200, 204], f"Expected 200 or 204, got {response.status}"
+
+            if response.status == 204:
+                # No content for in-progress
+                content = await response.read()
+                assert len(content) == 0 or content == b'', "Expected no content for 204 response"
+
+
+@pytest.mark.asyncio
+async def test_api_async_workflow_both_methods(frontend_backend_worker_with_rabbitmq):
+    """Test full async workflow: queue with respond-async, then poll for completion"""
+    async with AsyncRemoteComfyClient(server_address=frontend_backend_worker_with_rabbitmq) as client:
+        # Create a valid prompt
+        prompt = sdxl_workflow_with_refiner("test", inference_steps=1, refiner_steps=1)
+
+        # Method 1: Prefer header
+        task_id_1 = await client.queue_and_forget_prompt_api(prompt, prefer_header="respond-async")
+        assert task_id_1 is not None
+
+        # Method 2: +respond-async in Accept header
+        task_id_2 = await client.queue_and_forget_prompt_api(
+            prompt, prefer_header=None, accept_header="application/json+respond-async"
+        )
+        assert task_id_2 is not None
+
+        # Poll both until done
+        status_1, result_1 = await client.poll_prompt_until_done(task_id_1, max_attempts=60, poll_interval=1.0)
+        status_2, result_2 = await client.poll_prompt_until_done(task_id_2, max_attempts=60, poll_interval=1.0)
+
+        # Both should succeed
+        assert status_1 == 200, f"Task 1 failed with status {status_1}"
+        assert status_2 == 200, f"Task 2 failed with status {status_2}"
+
+        # Both should have outputs
+        assert result_1 is not None and len(result_1) > 0
+        assert result_2 is not None and len(result_2) > 0
+
+
+@pytest.mark.asyncio
+async def test_api_validation_error_structure(frontend_backend_worker_with_rabbitmq):
+    """Test that validation errors return proper ValidationErrorDict structure"""
+    async with AsyncRemoteComfyClient(server_address=frontend_backend_worker_with_rabbitmq) as client:
+        # Create an invalid prompt (invalid checkpoint name)
+        prompt = sdxl_workflow_with_refiner("test", "", 1, refiner_steps=1)
+        prompt["4"]["inputs"]["ckpt_name"] = "fake.safetensors"
+
+        prompt_json = client._AsyncRemoteComfyClient__json_encoder.encode(prompt)
+
+        async with client.session.post(
+            f"{frontend_backend_worker_with_rabbitmq}/api/v1/prompts",
+            data=prompt_json,
+            headers={'Content-Type': 'application/json', 'Accept': 'application/json'}
+        ) as response:
+            assert response.status == 400, f"Expected 400, got {response.status}"
+
+            error_body = await response.json()
+
+            # Verify ValidationErrorDict structure per OpenAPI spec
+            assert "type" in error_body, "Missing 'type'"
+            assert "message" in error_body, "Missing 'message'"
+            assert "details" in error_body, "Missing 'details'"
+            assert "extra_info" in error_body, "Missing 'extra_info'"
+
+            assert error_body["type"] == "prompt_outputs_failed_validation", "unexpected type"
+
+            # extra_info should have exception_type and traceback
+            assert "exception_type" in error_body["extra_info"], "Missing 'exception_type' in extra_info"
+            assert "traceback" in error_body["extra_info"], "Missing 'traceback' in extra_info"
+            assert isinstance(error_body["extra_info"]["traceback"], list), "traceback should be a list"
+
+            # extra_info should have node_errors with detailed validation information
+            assert "node_errors" in error_body["extra_info"], "Missing 'node_errors' in extra_info"
+            node_errors = error_body["extra_info"]["node_errors"]
+            assert isinstance(node_errors, dict), "node_errors should be a dict"
+            assert len(node_errors) > 0, "node_errors should contain at least one node"
+
+            # Verify node_errors structure for node "4" (CheckpointLoaderSimple with invalid ckpt_name)
+            assert "4" in node_errors, "Node '4' should have validation errors"
+            node_4_errors = node_errors["4"]
+            assert "errors" in node_4_errors, "Node '4' should have 'errors' field"
+            assert "class_type" in node_4_errors, "Node '4' should have 'class_type' field"
+            assert "dependent_outputs" in node_4_errors, "Node '4' should have 'dependent_outputs' field"
+
+            assert node_4_errors["class_type"] == "CheckpointLoaderSimple", "Node '4' class_type should be CheckpointLoaderSimple"
+            assert len(node_4_errors["errors"]) > 0, "Node '4' should have at least one error"
+
+            # Verify the error details include the validation error type and message
+            first_error = node_4_errors["errors"][0]
+            assert "type" in first_error, "Error should have 'type' field"
+            assert "message" in first_error, "Error should have 'message' field"
+            assert "details" in first_error, "Error should have 'details' field"
+            assert first_error["type"] == "value_not_in_list", f"Expected 'value_not_in_list' error, got {first_error['type']}"
+            assert "fake.safetensors" in first_error["details"], "Error details should mention 'fake.safetensors'"
+
+
+@pytest.mark.asyncio
+async def test_api_success_response_contract(frontend_backend_worker_with_rabbitmq):
+    """Test that successful execution returns proper response structure"""
+    async with AsyncRemoteComfyClient(server_address=frontend_backend_worker_with_rabbitmq) as client:
+        # Create a valid prompt
+        prompt = sdxl_workflow_with_refiner("test", inference_steps=1, refiner_steps=1)
+
+        # Queue and wait for blocking response
+        prompt_json = client._AsyncRemoteComfyClient__json_encoder.encode(prompt)
+        async with client.session.post(
+            f"{frontend_backend_worker_with_rabbitmq}/api/v1/prompts",
+            data=prompt_json,
+            headers={'Content-Type': 'application/json', 'Accept': 'application/json'}
+        ) as response:
+            assert response.status == 200, f"Expected 200, got {response.status}"
+
+            result = await response.json()
+
+            # Should have 'outputs' key (and deprecated 'urls' key)
+            assert "outputs" in result, "Missing 'outputs' in response"
+
+            # outputs should be a dict with node IDs as keys
+            outputs = result["outputs"]
+            assert isinstance(outputs, dict), "outputs should be a dict"
+            assert len(outputs) > 0, "outputs should not be empty"
+
+            # Each output should follow the Output schema
+            for node_id, output in outputs.items():
+                assert isinstance(output, dict), f"Output for node {node_id} should be a dict"
+                # Should have images or other output types
+                if "images" in output:
+                    assert isinstance(output["images"], list), f"images for node {node_id} should be a list"
+                    for image in output["images"]:
+                        assert "filename" in image, f"image missing 'filename' in node {node_id}"
+                        assert "subfolder" in image, f"image missing 'subfolder' in node {node_id}"
+                        assert "type" in image, f"image missing 'type' in node {node_id}"
+
+
+@pytest.mark.asyncio
+async def test_api_get_prompt_returns_outputs_directly(frontend_backend_worker_with_rabbitmq):
+    """Test GET /api/v1/prompts/{prompt_id} returns Outputs directly (not wrapped in history entry)"""
+    async with AsyncRemoteComfyClient(server_address=frontend_backend_worker_with_rabbitmq) as client:
+        # Create and queue a prompt
+        prompt = sdxl_workflow_with_refiner("test", inference_steps=1, refiner_steps=1)
+        task_id = await client.queue_and_forget_prompt_api(prompt)
+
+        # Poll until done
+        status_code, result = await client.poll_prompt_until_done(task_id, max_attempts=60, poll_interval=1.0)
+
+        assert status_code == 200, f"Expected 200, got {status_code}"
+        assert result is not None, "Result should not be None"
+
+        # Per OpenAPI spec, GET should return Outputs directly, not wrapped
+        # result should be a dict with node IDs as keys
+        assert isinstance(result, dict), "Result should be a dict (Outputs)"
+
+        # Should NOT have 'prompt', 'outputs', 'status' keys (those are in history entry)
+        # Should have node IDs directly
+        for key in result.keys():
+            # Node IDs are typically numeric strings like "4", "13", etc.
+            # Should not be "prompt", "outputs", "status"
+            assert key not in ["prompt", "status"], \
+                f"GET endpoint should return Outputs directly, not history entry. Found key: {key}"
+
+
+@pytest.mark.asyncio
+async def test_api_execution_error_blocking_mode(frontend_backend_worker_with_rabbitmq):
+    """Test that execution errors (not validation) return proper error structure in blocking mode"""
+    from comfy_execution.graph_utils import GraphBuilder
+
+    async with AsyncRemoteComfyClient(server_address=frontend_backend_worker_with_rabbitmq) as client:
+        # Create a prompt that will fail during execution (not validation)
+        # Use Regex with a group name that doesn't exist - validation passes but execution fails
+        g = GraphBuilder()
+        regex_match = g.node("Regex", pattern="hello", string="hello world")
+        # Request a non-existent group name - this will pass validation but fail during execution
+        match_group = g.node("RegexMatchGroupByName", match=regex_match.out(0), name="nonexistent_group")
+        g.node("SaveString", value=match_group.out(0), filename_prefix="test")
+
+        prompt = g.finalize()
+        prompt_json = client._AsyncRemoteComfyClient__json_encoder.encode(prompt)
+
+        async with client.session.post(
+            f"{frontend_backend_worker_with_rabbitmq}/api/v1/prompts",
+            data=prompt_json,
+            headers={'Content-Type': 'application/json', 'Accept': 'application/json'}
+        ) as response:
+            # Execution errors return 500
+            assert response.status == 500, f"Expected 500 for execution error, got {response.status}"
+
+            error_body = await response.json()
+
+            # Verify ExecutionStatus structure
+            assert "status_str" in error_body, "Missing 'status_str'"
+            assert "completed" in error_body, "Missing 'completed'"
+            assert "messages" in error_body, "Missing 'messages'"
+
+            assert error_body["status_str"] == "error", f"Expected 'error', got {error_body['status_str']}"
+            assert error_body["completed"] == False, "completed should be False for errors"
+            assert isinstance(error_body["messages"], list), "messages should be a list"
+            assert len(error_body["messages"]) > 0, "messages should contain error details"
+
+
+@pytest.mark.asyncio
+async def test_api_execution_error_async_mode(frontend_backend_worker_with_rabbitmq):
+    """Test that execution errors return proper error structure in respond-async mode"""
+    from comfy_execution.graph_utils import GraphBuilder
+
+    async with AsyncRemoteComfyClient(server_address=frontend_backend_worker_with_rabbitmq) as client:
+        # Create a prompt that will fail during execution (not validation)
+        # Use Regex with a group name that doesn't exist - validation passes but execution fails
+        g = GraphBuilder()
+        regex_match = g.node("Regex", pattern="hello", string="hello world")
+        # Request a non-existent group name - this will pass validation but fail during execution
+        match_group = g.node("RegexMatchGroupByName", match=regex_match.out(0), name="nonexistent_group")
+        g.node("SaveString", value=match_group.out(0), filename_prefix="test")
+
+        prompt = g.finalize()
+
+        # Queue with respond-async
+        task_id = await client.queue_and_forget_prompt_api(prompt, prefer_header="respond-async")
+        assert task_id is not None, "Should get task_id in async mode"
+
+        # Poll for completion
+        status_code, result = await client.poll_prompt_until_done(task_id, max_attempts=60, poll_interval=1.0)
+
+        # In async mode with polling, errors come back as 200 with error in the response body
+        # because the prompt was accepted (202) and we're just retrieving the completed result
+        assert status_code in (200, 500), f"Expected 200 or 500, got {status_code}"
+
+        if status_code == 500:
+            # Error returned directly - should be ExecutionStatus
+            assert "status_str" in result, "Missing 'status_str'"
+            assert "completed" in result, "Missing 'completed'"
+            assert "messages" in result, "Missing 'messages'"
+            assert result["status_str"] == "error"
+            assert result["completed"] == False
+            assert len(result["messages"]) > 0
+        else:
+            # Error in successful response - result might be ExecutionStatus or empty outputs
+            # If it's a dict with status info, verify it
+            if "status_str" in result:
+                assert result["status_str"] == "error"
+                assert result["completed"] == False
+                assert len(result["messages"]) > 0
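Note: the request/response contract these tests assert can be exercised without the test client; a minimal sketch with plain aiohttp, assuming a frontend at 127.0.0.1:19001 and a `prompt` dict shaped like the workflows above:

    import json

    import aiohttp


    async def submit(prompt: dict):
        async with aiohttp.ClientSession() as session:
            async with session.post(
                "http://127.0.0.1:19001/api/v1/prompts",
                data=json.dumps(prompt),
                headers={
                    "Content-Type": "application/json",
                    "Accept": "application/json",
                    "Prefer": "respond-async",  # or put +respond-async in the Accept mimetype
                },
            ) as response:
                # 400 carries a ValidationErrorDict body; success carries the queued/completed body
                return response.status, await response.json()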
128 tests/distributed/test_tracing.py Normal file
@@ -0,0 +1,128 @@
import asyncio
import logging

logging.basicConfig(level=logging.ERROR)

import uuid

import pytest
from testcontainers.rabbitmq import RabbitMqContainer
from opentelemetry import trace, propagate, context
from opentelemetry.trace import SpanKind
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import SimpleSpanProcessor
from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter

from comfy.client.sdxl_with_refiner_workflow import sdxl_workflow_with_refiner
from comfy.component_model.make_mutable import make_mutable
from comfy.component_model.queue_types import QueueItem, QueueTuple, ExecutionStatus
from comfy.distributed.server_stub import ServerStub


async def create_test_prompt() -> QueueItem:
    from comfy.cmd.execution import validate_prompt

    prompt = make_mutable(sdxl_workflow_with_refiner("test", inference_steps=1, refiner_steps=1))
    item_id = str(uuid.uuid4())

    validation_tuple = await validate_prompt(item_id, prompt)
    queue_tuple: QueueTuple = (0, item_id, prompt, {}, validation_tuple[2])
    return QueueItem(queue_tuple, None)


@pytest.mark.asyncio
async def test_rabbitmq_message_properties_contain_trace_context():
    with RabbitMqContainer("rabbitmq:latest") as rabbitmq:
        params = rabbitmq.get_connection_params()
        connection_uri = f"amqp://guest:guest@127.0.0.1:{params.port}"

        from comfy.distributed.distributed_prompt_queue import DistributedPromptQueue
        import aio_pika

        exporter = InMemorySpanExporter()
        provider = TracerProvider()
        provider.add_span_processor(SimpleSpanProcessor(exporter))
        tracer = trace.get_tracer(__name__, tracer_provider=provider)

        with tracer.start_as_current_span("test_message_headers", kind=SpanKind.PRODUCER):
            async with DistributedPromptQueue(ServerStub(), is_callee=False, is_caller=True, connection_uri=connection_uri) as frontend:
                async with DistributedPromptQueue(ServerStub(), is_callee=True, is_caller=False, connection_uri=connection_uri) as worker:
                    queue_item = await create_test_prompt()

                    put_task = asyncio.create_task(frontend.put_async(queue_item))

                    incoming, incoming_prompt_id = await worker.get_async(timeout=5.0)
                    assert incoming is not None, "Worker should receive message"

                    worker.task_done(incoming_prompt_id, {}, ExecutionStatus("success", True, []))

                    result = await put_task
                    assert result is not None, "Frontend should get result"

                    # Now inspect the RabbitMQ queue directly to see message structure
                    connection = await aio_pika.connect_robust(connection_uri)
                    channel = await connection.channel()

                    # Declare a test queue to inspect message format
                    test_queue = await channel.declare_queue("test_inspection_queue", durable=False, auto_delete=True)

                    # Publish a test message with trace context
                    carrier = {}
                    propagate.inject(carrier)

                    test_message = aio_pika.Message(
                        body=b"test",
                        headers=carrier
                    )

                    await channel.default_exchange.publish(
                        test_message,
                        routing_key=test_queue.name
                    )

                    # Get and inspect the message
                    received = await test_queue.get(timeout=2, fail=False)
                    if received:
                        headers = received.headers or {}

                        # Document what trace headers should be present
                        # OpenTelemetry uses 'traceparent' header for W3C Trace Context
                        has_traceparent = "traceparent" in headers

                        assert has_traceparent

                        await received.ack()

                    await connection.close()


@pytest.mark.asyncio
async def test_distributed_queue_uses_async_interface():
    """
    Test that demonstrates the correct way to use DistributedPromptQueue in async context.
    The synchronous get() method cannot be used in async tests due to event loop assertions.
    """
    with RabbitMqContainer("rabbitmq:latest") as rabbitmq:
        params = rabbitmq.get_connection_params()
        connection_uri = f"amqp://guest:guest@127.0.0.1:{params.port}"

        from comfy.distributed.distributed_prompt_queue import DistributedPromptQueue

        async with DistributedPromptQueue(ServerStub(), is_callee=False, is_caller=True, connection_uri=connection_uri) as frontend:
            async with DistributedPromptQueue(ServerStub(), is_callee=True, is_caller=False, connection_uri=connection_uri) as worker:
                queue_item = await create_test_prompt()

                # Start consuming in background
                result_future = asyncio.create_task(frontend.put_async(queue_item))

                # Worker gets item asynchronously (not using blocking get())
                incoming, incoming_prompt_id = await worker.get_async(timeout=5.0)
                assert incoming is not None, "Should receive a queue item"

                # Complete the work
                worker.task_done(incoming_prompt_id, {}, ExecutionStatus("success", True, []))

                # Wait for frontend to complete
                result = await result_future
                assert result is not None, "Should get result from worker"
                assert result.status.status_str == "success"
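Note: the `traceparent` header asserted above follows the W3C Trace Context format: a version byte, a 16-byte trace id, an 8-byte parent span id, and a flags byte, all hex-encoded and dash-separated. For reference:

    # version - trace-id (32 hex) - parent span-id (16 hex) - flags
    TRACEPARENT_EXAMPLE = "00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01"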
569 tests/distributed/test_tracing_integration.py Normal file
@@ -0,0 +1,569 @@
|
|||||||
|
"""
|
||||||
|
Integration tests for distributed tracing across RabbitMQ and services.
|
||||||
|
|
||||||
|
These tests validate that trace context propagates correctly from frontend
|
||||||
|
to backend workers through RabbitMQ, and that Jaeger can reconstruct the
|
||||||
|
full distributed trace.
|
||||||
|
"""
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import tempfile
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import requests
|
||||||
|
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
|
||||||
|
from opentelemetry.instrumentation.aiohttp_client import AioHttpClientInstrumentor
|
||||||
|
from opentelemetry.sdk.resources import Resource
|
||||||
|
from opentelemetry.sdk.trace import TracerProvider
|
||||||
|
from opentelemetry.sdk.trace.export import BatchSpanProcessor
|
||||||
|
from opentelemetry.semconv.attributes import service_attributes
|
||||||
|
from testcontainers.core.container import DockerContainer
|
||||||
|
from testcontainers.core.waiting_utils import wait_for_logs
|
||||||
|
from testcontainers.nginx import NginxContainer
|
||||||
|
|
||||||
|
from comfy.client.sdxl_with_refiner_workflow import sdxl_workflow_with_refiner
|
||||||
|
|
||||||
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class JaegerContainer(DockerContainer):
|
||||||
|
"""Testcontainer for Jaeger all-in-one with OTLP support."""
|
||||||
|
|
||||||
|
def __init__(self, image: str = "jaegertracing/all-in-one:latest"):
|
||||||
|
super().__init__(image)
|
||||||
|
self.with_exposed_ports(16686, 4318, 14268) # UI, OTLP HTTP, Jaeger HTTP
|
||||||
|
self.with_env("COLLECTOR_OTLP_ENABLED", "true")
|
||||||
|
|
||||||
|
def get_query_url(self) -> str:
|
||||||
|
"""Get Jaeger Query API URL."""
|
||||||
|
host = self.get_container_host_ip()
|
||||||
|
port = self.get_exposed_port(16686)
|
||||||
|
return f"http://{host}:{port}"
|
||||||
|
|
||||||
|
def get_otlp_endpoint(self) -> str:
|
||||||
|
"""Get OTLP HTTP endpoint for sending traces."""
|
||||||
|
host = self.get_container_host_ip()
|
||||||
|
port = self.get_exposed_port(4318)
|
||||||
|
return f"http://{host}:{port}"
|
||||||
|
|
||||||
|
def start(self):
|
||||||
|
super().start()
|
||||||
|
wait_for_logs(self, ".*Starting GRPC server.*", timeout=30)
|
||||||
|
return self
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="function")
|
||||||
|
def nginx_proxy(frontend_backend_worker_with_rabbitmq):
|
||||||
|
"""
|
||||||
|
Provide an nginx proxy in front of the ComfyUI frontend.
|
||||||
|
This tests if nginx is blocking W3C trace context propagation.
|
||||||
|
"""
|
||||||
|
import socket
|
||||||
|
import subprocess
|
||||||
|
|
||||||
|
# Extract host and port from frontend address
|
||||||
|
frontend_url = frontend_backend_worker_with_rabbitmq
|
||||||
|
# frontend_url is like "http://127.0.0.1:19001"
|
||||||
|
import re
|
||||||
|
match = re.match(r'http://([^:]+):(\d+)', frontend_url)
|
||||||
|
if not match:
|
||||||
|
raise ValueError(f"Could not parse frontend URL: {frontend_url}")
|
||||||
|
|
||||||
|
frontend_host = match.group(1)
|
||||||
|
frontend_port = match.group(2)
|
||||||
|
nginx_port = 8085
|
||||||
|
|
||||||
|
# Get the Docker bridge gateway IP (this is how containers reach the host on Linux)
|
||||||
|
# Try to get the default Docker bridge gateway
|
||||||
|
try:
|
||||||
|
result = subprocess.run(
|
||||||
|
["docker", "network", "inspect", "bridge", "-f", "{{range .IPAM.Config}}{{.Gateway}}{{end}}"],
|
||||||
|
capture_output=True,
|
||||||
|
text=True,
|
||||||
|
check=True
|
||||||
|
)
|
||||||
|
docker_gateway = result.stdout.strip()
|
||||||
|
logger.info(f"Using Docker gateway IP: {docker_gateway}")
|
||||||
|
except Exception as e:
|
||||||
|
# Fallback: try common gateway IPs
|
||||||
|
docker_gateway = "172.17.0.1" # Default Docker bridge gateway on Linux
|
||||||
|
logger.warning(f"Could not detect Docker gateway, using default: {docker_gateway}")
|
||||||
|
|
||||||
|
# Create nginx config that proxies to the frontend and passes trace headers
|
||||||
|
nginx_conf = f"""
|
||||||
|
events {{
|
||||||
|
worker_connections 1024;
|
||||||
|
}}
|
||||||
|
|
||||||
|
http {{
|
||||||
|
upstream backend {{
|
||||||
|
server {docker_gateway}:{frontend_port};
|
||||||
|
}}
|
||||||
|
|
||||||
|
server {{
|
||||||
|
listen {nginx_port};
|
||||||
|
|
||||||
|
location / {{
|
||||||
|
proxy_pass http://backend;
|
||||||
|
proxy_set_header Host $host;
|
||||||
|
proxy_set_header X-Real-IP $remote_addr;
|
||||||
|
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
|
||||||
|
proxy_set_header X-Forwarded-Proto $scheme;
|
||||||
|
}}
|
||||||
|
}}
|
||||||
|
}}
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Write config to a temporary file
|
||||||
|
with tempfile.NamedTemporaryFile(mode='w', suffix='.conf', delete=False) as f:
|
||||||
|
f.write(nginx_conf)
|
||||||
|
nginx_conf_path = f.name
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Start nginx container with the config
|
||||||
|
nginx = NginxContainer(port=nginx_port)
|
||||||
|
nginx.with_volume_mapping(nginx_conf_path, "/etc/nginx/nginx.conf")
|
||||||
|
nginx.start()
|
||||||
|
|
||||||
|
# Get the nginx URL
|
||||||
|
host = nginx.get_container_host_ip()
|
||||||
|
port = nginx.get_exposed_port(nginx_port)
|
||||||
|
nginx_url = f"http://{host}:{port}"
|
||||||
|
|
||||||
|
logger.info(f"Nginx proxy started at {nginx_url} -> {frontend_url}")
|
||||||
|
|
||||||
|
# Wait for nginx to be ready
|
||||||
|
for _ in range(30):
|
||||||
|
try:
|
||||||
|
response = requests.get(nginx_url, timeout=1)
|
||||||
|
if response.status_code:
|
||||||
|
break
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
time.sleep(0.5)
|
||||||
|
|
||||||
|
yield nginx_url
|
||||||
|
finally:
|
||||||
|
nginx.stop()
|
||||||
|
os.unlink(nginx_conf_path)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def jaeger_container():
|
||||||
|
"""
|
||||||
|
Provide a Jaeger container for collecting traces.
|
||||||
|
|
||||||
|
This fixture automatically sets OTEL_EXPORTER_OTLP_ENDPOINT to point to the
|
||||||
|
Jaeger container, and cleans it up when the container stops.
|
||||||
|
"""
|
||||||
|
container = JaegerContainer()
|
||||||
|
container.start()
|
||||||
|
|
||||||
|
# Wait for Jaeger to be fully ready
|
||||||
|
query_url = container.get_query_url()
|
||||||
|
otlp_endpoint = container.get_otlp_endpoint()
|
||||||
|
|
||||||
|
for _ in range(30):
|
||||||
|
try:
|
||||||
|
response = requests.get(f"{query_url}/api/services")
|
||||||
|
if response.status_code == 200:
|
||||||
|
logger.info(f"Jaeger ready at {query_url}")
|
||||||
|
logger.info(f"OTLP endpoint: {otlp_endpoint}")
|
||||||
|
break
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
time.sleep(1)
|
||||||
|
|
||||||
|
# Set OTEL_EXPORTER_OTLP_ENDPOINT for the duration of the test
|
||||||
|
old_endpoint = os.environ.get("OTEL_EXPORTER_OTLP_ENDPOINT")
|
||||||
|
os.environ["OTEL_EXPORTER_OTLP_ENDPOINT"] = otlp_endpoint
|
||||||
|
logger.info(f"Set OTEL_EXPORTER_OTLP_ENDPOINT={otlp_endpoint}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
yield container
|
||||||
|
finally:
|
||||||
|
# Restore original OTEL_EXPORTER_OTLP_ENDPOINT
|
||||||
|
if old_endpoint is not None:
|
||||||
|
os.environ["OTEL_EXPORTER_OTLP_ENDPOINT"] = old_endpoint
|
||||||
|
logger.info(f"Restored OTEL_EXPORTER_OTLP_ENDPOINT={old_endpoint}")
|
||||||
|
else:
|
||||||
|
os.environ.pop("OTEL_EXPORTER_OTLP_ENDPOINT", None)
|
||||||
|
logger.info("Removed OTEL_EXPORTER_OTLP_ENDPOINT")
|
||||||
|
|
||||||
|
container.stop()
|
||||||
|
|
||||||
|
|
||||||
|
def query_jaeger_traces(jaeger_url: str, service: str, operation: str = None,
|
||||||
|
lookback: str = "1h", limit: int = 100) -> dict:
|
||||||
|
"""
|
||||||
|
Query Jaeger for traces.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
jaeger_url: Base URL of Jaeger query service
|
||||||
|
service: Service name to query
|
||||||
|
operation: Optional operation name filter
|
||||||
|
lookback: Lookback period (e.g., "1h", "30m")
|
||||||
|
limit: Maximum number of traces to return
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
JSON response from Jaeger API
|
||||||
|
"""
|
||||||
|
params = {
|
||||||
|
"service": service,
|
||||||
|
"lookback": lookback,
|
||||||
|
"limit": limit
|
||||||
|
}
|
||||||
|
if operation:
|
||||||
|
params["operation"] = operation
|
||||||
|
|
||||||
|
response = requests.get(f"{jaeger_url}/api/traces", params=params)
|
||||||
|
response.raise_for_status()
|
||||||
|
return response.json()
|
||||||
|
|
||||||
|
|
||||||
|
def find_trace_by_operation(traces_response: dict, operation_name: str) -> dict:
|
||||||
|
"""Find a specific trace by operation name."""
|
||||||
|
for trace in traces_response.get("data", []):
|
||||||
|
for span in trace.get("spans", []):
|
||||||
|
if span.get("operationName") == operation_name:
|
||||||
|
return trace
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def verify_trace_continuity(trace: dict, expected_services: list[str]) -> bool:
|
||||||
|
"""
|
||||||
|
Verify that a trace spans multiple services and maintains parent-child relationships.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
trace: Jaeger trace object
|
||||||
|
expected_services: List of service names expected in the trace
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if trace shows proper distributed tracing across services
|
||||||
|
"""
|
||||||
|
if not trace:
|
||||||
|
return False
|
||||||
|
|
||||||
|
spans = trace.get("spans", [])
|
||||||
|
if not spans:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Check that all expected services are present
|
||||||
|
trace_services = set()
|
||||||
|
for span in spans:
|
||||||
|
process_id = span.get("processID")
|
||||||
|
if process_id:
|
||||||
|
process = trace.get("processes", {}).get(process_id, {})
|
||||||
|
service_name = process.get("serviceName")
|
||||||
|
if service_name:
|
||||||
|
trace_services.add(service_name)
|
||||||
|
|
||||||
|
logger.info(f"Trace contains services: {trace_services}")
|
||||||
|
logger.info(f"Expected services: {set(expected_services)}")
|
||||||
|
|
||||||
|
# Verify all expected services are present
|
||||||
|
for service in expected_services:
|
||||||
|
if service not in trace_services:
|
||||||
|
logger.warning(f"Expected service '{service}' not found in trace")
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Verify all spans share the same trace ID
|
||||||
|
trace_ids = set(span.get("traceID") for span in spans)
|
||||||
|
if len(trace_ids) != 1:
|
||||||
|
logger.warning(f"Multiple trace IDs found: {trace_ids}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Verify parent-child relationships exist
|
||||||
|
span_ids = {span.get("spanID") for span in spans}
|
||||||
|
has_parent_refs = False
|
||||||
|
|
||||||
|
for span in spans:
|
||||||
|
references = span.get("references", [])
|
||||||
|
for ref in references:
|
||||||
|
if ref.get("refType") == "CHILD_OF":
|
||||||
|
parent_span_id = ref.get("spanID")
|
||||||
|
if parent_span_id in span_ids:
|
||||||
|
has_parent_refs = True
|
||||||
|
logger.info(f"Found parent-child relationship: {parent_span_id} -> {span.get('spanID')}")
|
||||||
|
|
||||||
|
if not has_parent_refs:
|
||||||
|
logger.warning("No parent-child relationships found in trace")
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|


# order matters: execute jaeger_container first
@pytest.mark.asyncio
async def test_tracing_integration(jaeger_container, nginx_proxy):
    """
    Integration test for distributed tracing across services with an nginx proxy.

    This test:
    1. Starts the ComfyUI frontend and worker with RabbitMQ
    2. Starts an nginx proxy in front of the frontend to test trace context propagation through nginx
    3. Configures OTLP export to the Jaeger testcontainer
    4. Submits a workflow through the nginx proxy
    5. Queries Jaeger to verify trace propagation
    6. Validates that the trace spans multiple services with proper relationships

    This specifically tests whether nginx blocks W3C trace context (traceparent/tracestate headers).
    """
    server_address = nginx_proxy
    jaeger_url = jaeger_container.get_query_url()
    otlp_endpoint = jaeger_container.get_otlp_endpoint()

    logger.info(f"Frontend server: {server_address}")
    logger.info(f"Jaeger UI: {jaeger_url}")
    logger.info(f"OTLP endpoint: {otlp_endpoint}")

    # Set up tracing for the async HTTP client
    resource = Resource.create({
        service_attributes.SERVICE_NAME: "comfyui-client",
    })
    provider = TracerProvider(resource=resource)
    exporter = OTLPSpanExporter(endpoint=f"{otlp_endpoint}/v1/traces")
    processor = BatchSpanProcessor(exporter)
    provider.add_span_processor(processor)
    from opentelemetry import trace

    trace.set_tracer_provider(provider)

    # Instrument the aiohttp client
    AioHttpClientInstrumentor().instrument()

    # we have to import this very late, so that the instrumentation isn't initialized too early
    from comfy.client.aio_client import AsyncRemoteComfyClient

    # Note: In a real integration test, you'd need to configure the ComfyUI
    # services to export traces to this Jaeger instance. For now, this test
    # documents the expected behavior.

    # Create a unique prompt to identify our trace
    test_id = str(uuid.uuid4())[:8]
    prompt = sdxl_workflow_with_refiner(f"test_trace_{test_id}", inference_steps=1, refiner_steps=1)

    # Get the tracer for the client
    client_tracer = trace.get_tracer("test_tracing_integration")

    # Submit the workflow - wrap it in a span to capture the trace ID
    with client_tracer.start_as_current_span("submit_workflow") as workflow_span:
        trace_id = format(workflow_span.get_span_context().trace_id, '032x')
        logger.info(f"Started trace with trace_id: {trace_id}")

        async with AsyncRemoteComfyClient(server_address=server_address) as client:
            logger.info(f"Submitting workflow with test_id: {test_id}")

            # Queue the prompt with an async response
            task_id = await client.queue_and_forget_prompt_api(prompt, prefer_header="respond-async")
            assert task_id is not None, "Failed to get task ID"

            logger.info(f"Queued task: {task_id}")

            # Poll for completion
            status_code, result = await client.poll_prompt_until_done(task_id, max_attempts=60, poll_interval=1.0)
            assert status_code == 200, f"Task failed with status {status_code}"
            logger.info("Task completed successfully")

    # Give Jaeger time to receive and process spans
    await asyncio.sleep(5)

    # Query Jaeger for traces
    # Note: The actual service names depend on how your services are configured.
    # Common service names might be: "slack-bot", "comfyui-frontend", "comfyui-worker"
    expected_services = ["comfyui", "comfyui-client"]  # Adjust based on actual service names

    logger.info(f"Querying Jaeger for traces with trace_id: {trace_id}...")

    # First, try to find our specific trace by trace_id from the client service
    our_trace = None
    for service in expected_services:
        try:
            traces_response = query_jaeger_traces(jaeger_url, service, lookback="5m")
            if traces_response.get("data"):
                logger.info(f"Found {len(traces_response['data'])} traces for service '{service}'")
                # named candidate_trace to avoid shadowing the opentelemetry `trace` module
                for candidate_trace in traces_response["data"]:
                    if candidate_trace.get("traceID") == trace_id:
                        our_trace = candidate_trace
                        logger.info(f"Found our trace in service '{service}'")
                        break
            if our_trace:
                break
        except Exception as e:
            logger.warning(f"Could not query traces for service '{service}': {e}")

    # Assert we can find the trace we just created
    assert our_trace is not None, (
        f"Could not find trace with trace_id {trace_id} in Jaeger. "
        f"This indicates that spans from comfyui-client are not being exported correctly."
    )

    logger.info(f"Successfully found trace with trace_id {trace_id}")

    # Extract services from the trace
    trace_services = set()
    for span in our_trace.get("spans", []):
        process_id = span.get("processID")
        if process_id:
            process = our_trace.get("processes", {}).get(process_id, {})
            service_name = process.get("serviceName")
            if service_name:
                trace_services.add(service_name)

    logger.info(f"Services found in trace: {trace_services}")

    # Assert that the comfyui-client service is present (since we instrumented it)
    assert "comfyui-client" in trace_services, (
        f"Expected 'comfyui-client' service in trace, but found only: {trace_services}. "
        f"This indicates the client instrumentation is not working."
    )

    # Validate the trace structure
    logger.info(f"Analyzing trace with {len(our_trace.get('spans', []))} spans")

    # Log all spans for debugging
    for span in our_trace.get("spans", []):
        process_id = span.get("processID")
        process = our_trace.get("processes", {}).get(process_id, {})
        service_name = process.get("serviceName", "unknown")
        operation = span.get("operationName", "unknown")
        logger.info(f"  Span: {service_name}.{operation}")

    # Verify trace continuity - only if both services are present
    assert "comfyui" in trace_services
    is_continuous = verify_trace_continuity(our_trace, expected_services)

    # This assertion documents what SHOULD happen when distributed tracing works
    assert is_continuous, (
        "Trace does not show proper distributed tracing. "
        "Expected to see spans from multiple services with parent-child relationships. "
        "This indicates that trace context is not being propagated correctly through RabbitMQ."
    )
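

# Sketch of a deterministic alternative to the fixed sleep above: BatchSpanProcessor
# exports in batches, so flushing the provider before querying Jaeger removes the
# race. `provider` is assumed to be the TracerProvider constructed in the test.
def _flush_client_spans(provider: TracerProvider) -> None:
    # blocks until queued spans have been handed to the exporter
    provider.force_flush()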


@pytest.mark.asyncio
async def test_trace_context_in_http_headers(frontend_backend_worker_with_rabbitmq):
    """
    Test that HTTP requests include traceparent headers.

    This validates that the HTTP layer is properly instrumented for tracing.
    """
    server_address = frontend_backend_worker_with_rabbitmq

    # Make a simple HTTP request and check for trace headers
    # Note: We're checking the server's response headers to see if it's trace-aware
    response = requests.get(f"{server_address}/system_stats")

    logger.info(f"Response headers: {dict(response.headers)}")

    # The server should be instrumented and may include trace context in responses,
    # or at minimum, should accept traceparent headers in requests.

    # Test sending a traceparent header
    test_traceparent = "00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01"
    response_with_trace = requests.get(
        f"{server_address}/system_stats",
        headers={"traceparent": test_traceparent}
    )

    # Should not error when traceparent is provided
    assert response_with_trace.status_code == 200, "Server should accept traceparent header"

    logger.info("✓ Server accepts traceparent headers in HTTP requests")
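

# The hard-coded value above follows the W3C Trace Context format
# "version-traceid-parentid-flags". A small parsing sketch (the helper name is ours):
def _parse_traceparent(header: str) -> dict:
    version, trace_id, parent_id, flags = header.split("-")
    return {"version": version, "trace_id": trace_id, "parent_id": parent_id, "flags": flags}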


@pytest.mark.asyncio
async def test_multiple_requests_different_traces(frontend_backend_worker_with_rabbitmq, jaeger_container):
    """
    Test that multiple independent requests create separate traces.

    This validates that trace context is properly scoped per request.
    """
    server_address = frontend_backend_worker_with_rabbitmq

    # Submit multiple workflows
    task_ids = []

    from comfy.client.aio_client import AsyncRemoteComfyClient
    async with AsyncRemoteComfyClient(server_address=server_address) as client:
        for i in range(3):
            prompt = sdxl_workflow_with_refiner(f"test_{i}", inference_steps=1, refiner_steps=1)
            task_id = await client.queue_and_forget_prompt_api(prompt, prefer_header="respond-async")
            task_ids.append(task_id)
            logger.info(f"Queued task {i}: {task_id}")

        # Wait for all to complete
        for i, task_id in enumerate(task_ids):
            status_code, result = await client.poll_prompt_until_done(task_id, max_attempts=60, poll_interval=1.0)
            assert status_code == 200, f"Task {i} failed"
            logger.info(f"Task {i} completed")

    # Give Jaeger time to receive spans
    await asyncio.sleep(5)

    # Query Jaeger and verify we have multiple distinct traces
    jaeger_url = jaeger_container.get_query_url()

    traces_response = query_jaeger_traces(jaeger_url, "comfyui", lookback="5m", limit=10)
    traces = traces_response.get("data", [])

    assert len(traces) >= 2, f"Expected at least 2 traces in Jaeger, found {len(traces)}"

    # Get trace IDs
    trace_ids = [trace.get("traceID") for trace in traces]
    unique_trace_ids = set(trace_ids)

    logger.info(f"Found {len(unique_trace_ids)} unique traces")

    # Verify we have multiple distinct traces
    assert len(unique_trace_ids) >= 2, (
        f"Expected at least 2 distinct traces, found {len(unique_trace_ids)}. "
        "Each request should create its own trace."
    )

    logger.info("✓ Multiple requests created distinct traces")
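

# Sketch: Jaeger's query service can also return a single trace directly by ID,
# which avoids scanning per-service listings (endpoint assumed from the Jaeger
# HTTP API; the helper name is ours):
def _get_trace_by_id(jaeger_url: str, trace_id: str) -> dict:
    response = requests.get(f"{jaeger_url}/api/traces/{trace_id}")
    response.raise_for_status()
    return response.json()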


@pytest.mark.asyncio
@pytest.mark.skip(reason="rabbitmq has to be configured for observability?")
async def test_trace_contains_rabbitmq_operations(frontend_backend_worker_with_rabbitmq, jaeger_container):
    """
    Test that traces include RabbitMQ publish/consume operations.

    This is critical for distributed tracing - the RabbitMQ operations
    are what link the frontend and backend spans together.
    """
    server_address = frontend_backend_worker_with_rabbitmq
    jaeger_url = jaeger_container.get_query_url()

    # Submit a workflow
    from comfy.client.aio_client import AsyncRemoteComfyClient
    async with AsyncRemoteComfyClient(server_address=server_address) as client:
        prompt = sdxl_workflow_with_refiner("test_rmq", inference_steps=1, refiner_steps=1)
        task_id = await client.queue_and_forget_prompt_api(prompt)
        status_code, result = await client.poll_prompt_until_done(task_id, max_attempts=60)
        assert status_code == 200

    await asyncio.sleep(5)

    traces_response = query_jaeger_traces(jaeger_url, "comfyui", lookback="5m")
    traces = traces_response.get("data", [])

    # Look for RabbitMQ-related operations in any trace
    rabbitmq_operations = [
        "publish", "consume", "amq_queue_publish", "amq_queue_consume",
        "amq.basic.publish", "amq.basic.consume", "send", "receive"
    ]

    found_rabbitmq_ops = []
    for trace in traces:
        for span in trace.get("spans", []):
            op_name = span.get("operationName", "").lower()
            for rmq_op in rabbitmq_operations:
                if rmq_op in op_name:
                    found_rabbitmq_ops.append(op_name)

    assert found_rabbitmq_ops, "No RabbitMQ-related operations found in traces"
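

# Sketch of how W3C trace context would typically cross RabbitMQ: inject the
# current span context into the message headers on publish and extract it on
# consume. Illustrative only - the real wiring lives in the frontend/worker code.
from opentelemetry.propagate import inject


def _headers_with_trace_context() -> dict:
    headers: dict = {}
    inject(headers)  # adds the "traceparent" (and possibly "tracestate") keys
    return headers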

683
tests/unit/test_eval_nodes.py
Normal file
@ -0,0 +1,683 @@
import pytest
from unittest.mock import Mock, patch

from comfy.cli_args import default_configuration
from comfy.execution_context import context_configuration
from comfy_extras.nodes.nodes_eval import (
    eval_python,
    EvalPython_5_5,
    EvalPython_List_1,
    EvalPython_1_List,
    EvalPython_List_List,
)


@pytest.fixture
def eval_context():
    """Fixture that sets up execution context with eval enabled"""
    config = default_configuration()
    config.enable_eval = True
    with context_configuration(config):
        yield
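

# The snippets below use bare `return` statements, which plain exec() rejects.
# A minimal sketch of one way a node could support that (an assumption about the
# mechanism, not the actual nodes_eval implementation): wrap the snippet in a
# function before executing it.
def _exec_with_return_sketch(pycode: str, **variables):
    import textwrap
    wrapped = "def __snippet__():\n" + textwrap.indent(pycode, "    ")
    namespace = dict(variables)
    exec(wrapped, namespace)  # defines __snippet__ with `variables` as its globals
    return namespace["__snippet__"]()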


def test_eval_python_basic_return(eval_context):
    """Test basic return statement with single value"""
    node = EvalPython_5_5()
    result = node.exec_py(pycode="return 42", value0=0, value1=1, value2=2, value3=3, value4=4)
    assert result == (42, None, None, None, None)


def test_eval_python_multiple_returns(eval_context):
    """Test return statement with tuple of values"""
    node = EvalPython_5_5()
    result = node.exec_py(
        pycode="return 1, 2, 3",
        value0=0, value1=0, value2=0, value3=0, value4=0
    )
    assert result == (1, 2, 3, None, None)


def test_eval_python_all_five_returns(eval_context):
    """Test return statement with all five values"""
    node = EvalPython_5_5()
    result = node.exec_py(
        pycode="return 'a', 'b', 'c', 'd', 'e'",
        value0=0, value1=0, value2=0, value3=0, value4=0
    )
    assert result == ('a', 'b', 'c', 'd', 'e')


def test_eval_python_excess_returns(eval_context):
    """Test that excess return values are truncated to 5"""
    node = EvalPython_5_5()
    result = node.exec_py(
        pycode="return 1, 2, 3, 4, 5, 6, 7",
        value0=0, value1=0, value2=0, value3=0, value4=0
    )
    assert result == (1, 2, 3, 4, 5)
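

# The padding and truncation asserted above suggest a normalization step like
# this sketch (inferred from the assertions, not taken from the real source):
def _normalize_outputs_sketch(result, outputs: int = 5) -> tuple:
    values = result if isinstance(result, tuple) else (result,)
    return (values + (None,) * outputs)[:outputs]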


def test_eval_python_use_value_args(eval_context):
    """Test that value arguments are accessible in pycode"""
    node = EvalPython_5_5()
    result = node.exec_py(
        pycode="return value0 + value1 + value2",
        value0=10, value1=20, value2=30, value3=0, value4=0
    )
    assert result == (60, None, None, None, None)


def test_eval_python_all_value_args(eval_context):
    """Test all value arguments are accessible"""
    node = EvalPython_5_5()
    result = node.exec_py(
        pycode="return value0, value1, value2, value3, value4",
        value0=1, value1=2, value2=3, value3=4, value4=5
    )
    assert result == (1, 2, 3, 4, 5)


def test_eval_python_computation(eval_context):
    """Test computation with value arguments"""
    node = EvalPython_5_5()
    code = """
x = value0 * 2
y = value1 * 3
z = x + y
return z
"""
    result = node.exec_py(
        pycode=code,
        value0=5, value1=10, value2=0, value3=0, value4=0
    )
    assert result == (40, None, None, None, None)


def test_eval_python_multiline(eval_context):
    """Test multiline code with conditionals"""
    node = EvalPython_5_5()
    code = """
if value0 > 10:
    result = "large"
else:
    result = "small"
return result, value0
"""
    result = node.exec_py(
        pycode=code,
        value0=15, value1=0, value2=0, value3=0, value4=0
    )
    assert result == ("large", 15, None, None, None)


def test_eval_python_list_comprehension(eval_context):
    """Test list comprehension and iteration"""
    node = EvalPython_5_5()
    code = """
numbers = [value0, value1, value2]
doubled = [x * 2 for x in numbers]
return sum(doubled)
"""
    result = node.exec_py(
        pycode=code,
        value0=1, value1=2, value2=3, value3=0, value4=0
    )
    assert result == (12, None, None, None, None)


def test_eval_python_string_operations(eval_context):
    """Test string operations"""
    node = EvalPython_5_5()
    code = """
s1 = str(value0)
s2 = str(value1)
return s1 + s2, len(s1 + s2)
"""
    result = node.exec_py(
        pycode=code,
        value0=123, value1=456, value2=0, value3=0, value4=0
    )
    assert result == ("123456", 6, None, None, None)


def test_eval_python_type_mixing(eval_context):
    """Test mixing different types"""
    node = EvalPython_5_5()
    code = """
return value0, str(value1), float(value2), bool(value3)
"""
    result = node.exec_py(
        pycode=code,
        value0=42, value1=100, value2=3, value3=1, value4=0
    )
    assert result == (42, "100", 3.0, True, None)


def test_eval_python_logger_available(eval_context):
    """Test that logger is available in eval context"""
    node = EvalPython_5_5()
    code = """
logger.info("test log")
return "success"
"""
    result = node.exec_py(
        pycode=code,
        value0=0, value1=0, value2=0, value3=0, value4=0
    )
    assert result == ("success", None, None, None, None)


def test_eval_python_print_available(eval_context):
    """Test that print function is available"""
    node = EvalPython_5_5()
    code = """
print("Hello World!")
return "printed"
"""
    result = node.exec_py(
        pycode=code,
        value0=0, value1=0, value2=0, value3=0, value4=0
    )
    assert result == ("printed", None, None, None, None)


def test_eval_python_print_is_called(eval_context):
    """Test that print function is called and receives correct arguments"""
    node = EvalPython_5_5()

    # Track print calls
    print_calls = []

    code = """
print("Hello", "World")
print("Line 2")
return "done"
"""

    # Mock exec to capture the globals dict and verify print is there
    original_exec = exec
    captured_globals = {}

    def mock_exec(code_str, globals_dict, *args, **kwargs):
        # Capture the globals dict
        captured_globals.update(globals_dict)

        # Wrap the print function to track calls
        original_print = globals_dict.get('print')
        if original_print:
            def tracked_print(*args):
                print_calls.append(args)
                return original_print(*args)
            globals_dict['print'] = tracked_print

        # Run the original exec
        return original_exec(code_str, globals_dict, *args, **kwargs)

    with patch('builtins.exec', side_effect=mock_exec):
        result = node.exec_py(
            pycode=code,
            value0=0, value1=0, value2=0, value3=0, value4=0
        )

    # Verify the result
    assert result == ("done", None, None, None, None)

    # Verify print was in the globals
    assert 'print' in captured_globals

    # Verify print was called twice with correct arguments
    assert len(print_calls) == 2
    assert print_calls[0] == ("Hello", "World")
    assert print_calls[1] == ("Line 2",)


def test_eval_python_print_sends_to_server(eval_context):
    """Test that print sends messages to PromptServer via context"""
    from comfy.execution_context import current_execution_context

    node = EvalPython_5_5()
    ctx = current_execution_context()

    # Mock the server's send_progress_text method
    original_send = ctx.server.send_progress_text if hasattr(ctx.server, 'send_progress_text') else None
    mock_send = Mock()
    ctx.server.send_progress_text = mock_send

    code = """
print("Hello", "World")
print("Value:", value0)
return "done"
"""

    try:
        result = node.exec_py(
            pycode=code,
            value0=42, value1=0, value2=0, value3=0, value4=0
        )

        # Verify the result
        assert result == ("done", None, None, None, None)

        # Verify print messages were sent to server
        assert mock_send.call_count == 2

        # Verify the messages sent
        calls = mock_send.call_args_list
        assert calls[0][0][0] == "Hello World"
        assert calls[0][0][1] == ctx.node_id
        assert calls[1][0][0] == "Value: 42"
        assert calls[1][0][1] == ctx.node_id
    finally:
        # Restore original
        if original_send:
            ctx.server.send_progress_text = original_send
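

# Sketch of the print shim the two tests above exercise (assumed wiring): join
# the arguments the way built-in print does and forward the text to the server
# for the node that is currently executing.
def _make_node_print_sketch(server, node_id):
    def _print(*args):
        server.send_progress_text(" ".join(str(a) for a in args), node_id)
    return _print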


def test_eval_python_config_disabled_raises():
    """Test that enable_eval=False raises an error"""
    node = EvalPython_5_5()
    config = default_configuration()
    config.enable_eval = False
    with context_configuration(config):
        with pytest.raises(ValueError, match="Python eval is disabled"):
            node.exec_py(
                pycode="return 42",
                value0=0, value1=0, value2=0, value3=0, value4=0
            )


def test_eval_python_config_enabled_works(eval_context):
    """Test that enable_eval=True allows execution"""
    node = EvalPython_5_5()
    result = node.exec_py(
        pycode="return 42",
        value0=0, value1=0, value2=0, value3=0, value4=0
    )
    assert result == (42, None, None, None, None)


def test_eval_python_default_code(eval_context):
    """Test the default code example works"""
    node = EvalPython_5_5()
    # Get the default code from INPUT_TYPES
    default_code = EvalPython_5_5.INPUT_TYPES()["required"]["pycode"][1]["default"]

    result = node.exec_py(
        pycode=default_code,
        value0=1, value1=2, value2=3, value3=4, value4=5
    )
    # Default code prints and returns the values
    assert result == (1, 2, 3, 4, 5)


def test_eval_python_function_definition(eval_context):
    """Test defining and using functions"""
    node = EvalPython_5_5()
    code = """
def multiply(a, b):
    return a * b

result = multiply(value0, value1)
return result
"""
    result = node.exec_py(
        pycode=code,
        value0=7, value1=8, value2=0, value3=0, value4=0
    )
    assert result == (56, None, None, None, None)


def test_eval_python_nested_functions(eval_context):
    """Test nested function definitions"""
    node = EvalPython_5_5()
    code = """
def outer(x):
    def inner(y):
        return y * 2
    return inner(x) + 10

result = outer(value0)
return result
"""
    result = node.exec_py(
        pycode=code,
        value0=5, value1=0, value2=0, value3=0, value4=0
    )
    assert result == (20, None, None, None, None)


def test_eval_python_dict_operations(eval_context):
    """Test dictionary creation and operations"""
    node = EvalPython_5_5()
    code = """
data = {
    'a': value0,
    'b': value1,
    'c': value2
}
return sum(data.values()), len(data)
"""
    result = node.exec_py(
        pycode=code,
        value0=10, value1=20, value2=30, value3=0, value4=0
    )
    assert result == (60, 3, None, None, None)


def test_eval_python_list_operations(eval_context):
    """Test list creation and operations"""
    node = EvalPython_5_5()
    code = """
items = [value0, value1, value2, value3, value4]
filtered = [x for x in items if x > 5]
return len(filtered), sum(filtered)
"""
    result = node.exec_py(
        pycode=code,
        value0=1, value1=10, value2=3, value3=15, value4=2
    )
    assert result == (2, 25, None, None, None)


def test_eval_python_early_return(eval_context):
    """Test early return in conditional"""
    node = EvalPython_5_5()
    code = """
if value0 > 100:
    return "large"
return "small"
"""
    result = node.exec_py(
        pycode=code,
        value0=150, value1=0, value2=0, value3=0, value4=0
    )
    assert result == ("large", None, None, None, None)


def test_eval_python_loop_with_return(eval_context):
    """Test loop with return statement"""
    node = EvalPython_5_5()
    code = """
total = 0
for i in range(value0):
    total += i
return total
"""
    result = node.exec_py(
        pycode=code,
        value0=10, value1=0, value2=0, value3=0, value4=0
    )
    assert result == (45, None, None, None, None)


def test_eval_python_exception_handling(eval_context):
    """Test try/except blocks"""
    node = EvalPython_5_5()
    code = """
try:
    result = value0 / value1
except ZeroDivisionError:
    result = float('inf')
return result
"""
    result = node.exec_py(
        pycode=code,
        value0=10, value1=0, value2=0, value3=0, value4=0
    )
    assert result == (float('inf'), None, None, None, None)


def test_eval_python_none_values(eval_context):
    """Test handling None values in inputs"""
    node = EvalPython_5_5()
    code = """
return value0, value1 is None, value2 is None
"""
    result = node.exec_py(
        pycode=code,
        value0=42, value1=None, value2=None, value3=0, value4=0
    )
    assert result == (42, True, True, None, None)


def test_eval_python_input_types():
    """Test that INPUT_TYPES returns correct structure"""
    input_types = EvalPython_5_5.INPUT_TYPES()
    assert "required" in input_types
    assert "optional" in input_types
    assert "pycode" in input_types["required"]
    assert input_types["required"]["pycode"][0] == "CODE_BLOCK_PYTHON"

    # Check optional inputs
    for i in range(5):
        assert f"value{i}" in input_types["optional"]


def test_eval_python_metadata():
    """Test node metadata"""
    assert EvalPython_5_5.FUNCTION == "exec_py"
    assert EvalPython_5_5.CATEGORY == "eval"
    assert len(EvalPython_5_5.RETURN_TYPES) == 5
    assert len(EvalPython_5_5.RETURN_NAMES) == 5
    assert all(name.startswith("item") for name in EvalPython_5_5.RETURN_NAMES)


def test_eval_python_factory_custom_inputs_outputs(eval_context):
    """Test creating nodes with custom input/output counts"""
    # Create a node with 3 inputs and 2 outputs
    CustomNode = eval_python(inputs=3, outputs=2)

    node = CustomNode()

    # Verify INPUT_TYPES has correct number of inputs
    input_types = CustomNode.INPUT_TYPES()
    assert len(input_types["optional"]) == 3
    assert "value0" in input_types["optional"]
    assert "value1" in input_types["optional"]
    assert "value2" in input_types["optional"]
    assert "value3" not in input_types["optional"]

    # Verify RETURN_TYPES has correct number of outputs
    assert len(CustomNode.RETURN_TYPES) == 2
    assert len(CustomNode.RETURN_NAMES) == 2

    # Test execution
    result = node.exec_py(
        pycode="return value0 + value1 + value2, value0 * 2",
        value0=1, value1=2, value2=3
    )
    assert result == (6, 2)


def test_eval_python_factory_custom_name(eval_context):
    """Test creating nodes with custom names"""
    CustomNode = eval_python(inputs=2, outputs=2, name="MyCustomEval")

    assert CustomNode.__name__ == "MyCustomEval"
    assert CustomNode.__qualname__ == "MyCustomEval"


def test_eval_python_factory_default_name(eval_context):
    """Test that default name follows pattern"""
    CustomNode = eval_python(inputs=3, outputs=4)

    assert CustomNode.__name__ == "EvalPython_3_4"
    assert CustomNode.__qualname__ == "EvalPython_3_4"
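

# A factory like eval_python() can be sketched with type(); the attribute values
# below mirror the assertions in these tests rather than the real implementation:
def _make_eval_node_sketch(inputs: int, outputs: int, name: str = None):
    name = name or f"EvalPython_{inputs}_{outputs}"
    attrs = {
        "FUNCTION": "exec_py",
        "CATEGORY": "eval",
        "RETURN_TYPES": ("*",) * outputs,
        "RETURN_NAMES": tuple(f"item{i}" for i in range(outputs)),
    }
    # type() sets both __name__ and __qualname__ to the given name
    return type(name, (), attrs)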


def test_eval_python_factory_single_output(eval_context):
    """Test node with single output"""
    SingleOutputNode = eval_python(inputs=2, outputs=1)

    node = SingleOutputNode()
    result = node.exec_py(
        pycode="return value0 + value1",
        value0=10, value1=20
    )
    assert result == (30,)


def test_eval_python_factory_many_outputs(eval_context):
    """Test node with many outputs"""
    ManyOutputNode = eval_python(inputs=1, outputs=10)

    node = ManyOutputNode()
    result = node.exec_py(
        pycode="return tuple(range(10))",
        value0=0
    )
    assert result == (0, 1, 2, 3, 4, 5, 6, 7, 8, 9)


def test_eval_python_factory_fewer_returns_than_outputs(eval_context):
    """Test that fewer returns are padded with None"""
    Node = eval_python(inputs=2, outputs=5)

    node = Node()
    result = node.exec_py(
        pycode="return value0, value1",
        value0=1, value1=2
    )
    assert result == (1, 2, None, None, None)


def test_eval_python_factory_more_returns_than_outputs(eval_context):
    """Test that excess returns are truncated"""
    Node = eval_python(inputs=2, outputs=3)

    node = Node()
    result = node.exec_py(
        pycode="return 1, 2, 3, 4, 5",
        value0=0, value1=0
    )
    assert result == (1, 2, 3)


def test_eval_python_list_1_input_is_list(eval_context):
    """Test EvalPython_List_1 with list input"""
    node = EvalPython_List_1()

    # Verify INPUT_IS_LIST is set
    assert EvalPython_List_1.INPUT_IS_LIST is True

    # Test that value0 receives a list
    result = node.exec_py(
        pycode="return sum(value0)",
        value0=[1, 2, 3, 4, 5]
    )
    assert result == (15,)


def test_eval_python_list_1_iterate_list(eval_context):
    """Test EvalPython_List_1 iterating over list input"""
    node = EvalPython_List_1()

    result = node.exec_py(
        pycode="return [x * 2 for x in value0]",
        value0=[1, 2, 3]
    )
    assert result == ([2, 4, 6],)
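

# INPUT_IS_LIST and OUTPUT_IS_LIST are ComfyUI execution flags: with INPUT_IS_LIST
# the executor hands the node the whole input list in one call instead of invoking
# it once per element, and OUTPUT_IS_LIST marks return slots that already hold
# per-element lists. A sketch of the behavioral difference:
def _list_flag_semantics_sketch(values: list) -> tuple:
    per_element = [v * 2 for v in values]  # what per-element invocation builds up
    whole_list = len(values)               # what one call with INPUT_IS_LIST can compute
    return per_element, whole_list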


def test_eval_python_1_list_output_is_list(eval_context):
    """Test EvalPython_1_List with list output"""
    node = EvalPython_1_List()

    # Verify OUTPUT_IS_LIST is set
    assert EvalPython_1_List.OUTPUT_IS_LIST == (True,)

    # Test that it returns a list
    result = node.exec_py(
        pycode="return list(range(value0))",
        value0=5
    )
    assert result == ([0, 1, 2, 3, 4],)


def test_eval_python_1_list_multiple_items(eval_context):
    """Test EvalPython_1_List returning multiple items in list"""
    node = EvalPython_1_List()

    result = node.exec_py(
        pycode="return ['a', 'b', 'c']",
        value0=0
    )
    assert result == (['a', 'b', 'c'],)


def test_eval_python_list_list_both(eval_context):
    """Test EvalPython_List_List with both list input and output"""
    node = EvalPython_List_List()

    # Verify both are set
    assert EvalPython_List_List.INPUT_IS_LIST is True
    assert EvalPython_List_List.OUTPUT_IS_LIST == (True,)

    # Test processing list input and returning list output
    result = node.exec_py(
        pycode="return [x ** 2 for x in value0]",
        value0=[1, 2, 3, 4]
    )
    assert result == ([1, 4, 9, 16],)


def test_eval_python_list_list_filter(eval_context):
    """Test EvalPython_List_List filtering a list"""
    node = EvalPython_List_List()

    result = node.exec_py(
        pycode="return [x for x in value0 if x > 5]",
        value0=[1, 3, 5, 7, 9, 11]
    )
    assert result == ([7, 9, 11],)


def test_eval_python_list_list_transform(eval_context):
    """Test EvalPython_List_List transforming list elements"""
    node = EvalPython_List_List()

    result = node.exec_py(
        pycode="return [str(x).upper() for x in value0]",
        value0=['hello', 'world', 'python']
    )
    assert result == (['HELLO', 'WORLD', 'PYTHON'],)


def test_eval_python_factory_with_list_flags(eval_context):
    """Test factory function with custom list flags"""
    # Create a node with input as list but scalar output
    ListInputNode = eval_python(inputs=1, outputs=1, input_is_list=True, output_is_list=None)

    assert ListInputNode.INPUT_IS_LIST is True

    node = ListInputNode()
    result = node.exec_py(
        pycode="return len(value0)",
        value0=[1, 2, 3, 4, 5]
    )
    assert result == (5,)


def test_eval_python_factory_scalar_output_list(eval_context):
    """Test factory function with scalar input and list output"""
    ScalarToListNode = eval_python(inputs=1, outputs=1, input_is_list=None, output_is_list=(True,))

    assert ScalarToListNode.OUTPUT_IS_LIST == (True,)

    node = ScalarToListNode()
    result = node.exec_py(
        pycode="return [value0] * 3",
        value0='x'
    )
    assert result == (['x', 'x', 'x'],)


def test_eval_python_list_empty_list(eval_context):
    """Test list nodes with empty lists"""
    node = EvalPython_List_List()

    result = node.exec_py(
        pycode="return []",
        value0=[]
    )
    assert result == ([],)