mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-12-16 17:42:58 +08:00
Merge branch 'master' into v3-dynamic-combo
Some checks failed
Python Linting / Run Ruff (push) Has been cancelled
Python Linting / Run Pylint (push) Has been cancelled
Build package / Build Test (3.10) (push) Has been cancelled
Build package / Build Test (3.11) (push) Has been cancelled
Build package / Build Test (3.12) (push) Has been cancelled
Build package / Build Test (3.13) (push) Has been cancelled
Build package / Build Test (3.9) (push) Has been cancelled
Some checks failed
Python Linting / Run Ruff (push) Has been cancelled
Python Linting / Run Pylint (push) Has been cancelled
Build package / Build Test (3.10) (push) Has been cancelled
Build package / Build Test (3.11) (push) Has been cancelled
Build package / Build Test (3.12) (push) Has been cancelled
Build package / Build Test (3.13) (push) Has been cancelled
Build package / Build Test (3.9) (push) Has been cancelled
This commit is contained in:
commit
0f5741abb5
@ -413,7 +413,8 @@ class ControlNet(nn.Module):
|
||||
out_middle = []
|
||||
|
||||
if self.num_classes is not None:
|
||||
assert y.shape[0] == x.shape[0]
|
||||
if y is None:
|
||||
raise ValueError("y is None, did you try using a controlnet for SDXL on SD1?")
|
||||
emb = emb + self.label_emb(y)
|
||||
|
||||
h = x
|
||||
|
||||
@ -179,7 +179,10 @@ class Chroma(nn.Module):
|
||||
pe = self.pe_embedder(ids)
|
||||
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
transformer_options["total_blocks"] = len(self.double_blocks)
|
||||
transformer_options["block_type"] = "double"
|
||||
for i, block in enumerate(self.double_blocks):
|
||||
transformer_options["block_index"] = i
|
||||
if i not in self.skip_mmdit:
|
||||
double_mod = (
|
||||
self.get_modulations(mod_vectors, "double_img", idx=i),
|
||||
@ -222,7 +225,10 @@ class Chroma(nn.Module):
|
||||
|
||||
img = torch.cat((txt, img), 1)
|
||||
|
||||
transformer_options["total_blocks"] = len(self.single_blocks)
|
||||
transformer_options["block_type"] = "single"
|
||||
for i, block in enumerate(self.single_blocks):
|
||||
transformer_options["block_index"] = i
|
||||
if i not in self.skip_dit:
|
||||
single_mod = self.get_modulations(mod_vectors, "single", idx=i)
|
||||
if ("single_block", i) in blocks_replace:
|
||||
|
||||
@ -389,7 +389,10 @@ class HunyuanVideo(nn.Module):
|
||||
attn_mask = None
|
||||
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
transformer_options["total_blocks"] = len(self.double_blocks)
|
||||
transformer_options["block_type"] = "double"
|
||||
for i, block in enumerate(self.double_blocks):
|
||||
transformer_options["block_index"] = i
|
||||
if ("double_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
@ -411,7 +414,10 @@ class HunyuanVideo(nn.Module):
|
||||
|
||||
img = torch.cat((img, txt), 1)
|
||||
|
||||
transformer_options["total_blocks"] = len(self.single_blocks)
|
||||
transformer_options["block_type"] = "single"
|
||||
for i, block in enumerate(self.single_blocks):
|
||||
transformer_options["block_index"] = i
|
||||
if ("single_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
|
||||
@ -439,7 +439,10 @@ class QwenImageTransformer2DModel(nn.Module):
|
||||
patches = transformer_options.get("patches", {})
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
|
||||
transformer_options["total_blocks"] = len(self.transformer_blocks)
|
||||
transformer_options["block_type"] = "double"
|
||||
for i, block in enumerate(self.transformer_blocks):
|
||||
transformer_options["block_index"] = i
|
||||
if ("double_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
|
||||
@ -8,7 +8,7 @@ import os
|
||||
import textwrap
|
||||
import threading
|
||||
from enum import Enum
|
||||
from typing import Optional, Type, get_origin, get_args
|
||||
from typing import Optional, Type, get_origin, get_args, get_type_hints
|
||||
|
||||
|
||||
class TypeTracker:
|
||||
@ -220,11 +220,18 @@ class AsyncToSyncConverter:
|
||||
self._async_instance = async_class(*args, **kwargs)
|
||||
|
||||
# Handle annotated class attributes (like execution: Execution)
|
||||
# Get all annotations from the class hierarchy
|
||||
all_annotations = {}
|
||||
for base_class in reversed(inspect.getmro(async_class)):
|
||||
if hasattr(base_class, "__annotations__"):
|
||||
all_annotations.update(base_class.__annotations__)
|
||||
# Get all annotations from the class hierarchy and resolve string annotations
|
||||
try:
|
||||
# get_type_hints resolves string annotations to actual type objects
|
||||
# This handles classes using 'from __future__ import annotations'
|
||||
all_annotations = get_type_hints(async_class)
|
||||
except Exception:
|
||||
# Fallback to raw annotations if get_type_hints fails
|
||||
# (e.g., for undefined forward references)
|
||||
all_annotations = {}
|
||||
for base_class in reversed(inspect.getmro(async_class)):
|
||||
if hasattr(base_class, "__annotations__"):
|
||||
all_annotations.update(base_class.__annotations__)
|
||||
|
||||
# For each annotated attribute, check if it needs to be created or wrapped
|
||||
for attr_name, attr_type in all_annotations.items():
|
||||
@ -625,15 +632,19 @@ class AsyncToSyncConverter:
|
||||
"""Extract class attributes that are classes themselves."""
|
||||
class_attributes = []
|
||||
|
||||
# Get resolved type hints to handle string annotations
|
||||
try:
|
||||
type_hints = get_type_hints(async_class)
|
||||
except Exception:
|
||||
type_hints = {}
|
||||
|
||||
# Look for class attributes that are classes
|
||||
for name, attr in sorted(inspect.getmembers(async_class)):
|
||||
if isinstance(attr, type) and not name.startswith("_"):
|
||||
class_attributes.append((name, attr))
|
||||
elif (
|
||||
hasattr(async_class, "__annotations__")
|
||||
and name in async_class.__annotations__
|
||||
):
|
||||
annotation = async_class.__annotations__[name]
|
||||
elif name in type_hints:
|
||||
# Use resolved type hint instead of raw annotation
|
||||
annotation = type_hints[name]
|
||||
if isinstance(annotation, type):
|
||||
class_attributes.append((name, annotation))
|
||||
|
||||
@ -908,11 +919,15 @@ class AsyncToSyncConverter:
|
||||
attribute_mappings = {}
|
||||
|
||||
# First check annotations for typed attributes (including from parent classes)
|
||||
# Collect all annotations from the class hierarchy
|
||||
all_annotations = {}
|
||||
for base_class in reversed(inspect.getmro(async_class)):
|
||||
if hasattr(base_class, "__annotations__"):
|
||||
all_annotations.update(base_class.__annotations__)
|
||||
# Resolve string annotations to actual types
|
||||
try:
|
||||
all_annotations = get_type_hints(async_class)
|
||||
except Exception:
|
||||
# Fallback to raw annotations
|
||||
all_annotations = {}
|
||||
for base_class in reversed(inspect.getmro(async_class)):
|
||||
if hasattr(base_class, "__annotations__"):
|
||||
all_annotations.update(base_class.__annotations__)
|
||||
|
||||
for attr_name, attr_type in sorted(all_annotations.items()):
|
||||
for class_name, class_type in class_attributes:
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
from __future__ import annotations
|
||||
from abc import ABC, abstractmethod
|
||||
from fractions import Fraction
|
||||
from typing import Optional, Union, IO
|
||||
import io
|
||||
import av
|
||||
@ -72,6 +73,33 @@ class VideoInput(ABC):
|
||||
frame_count = components.images.shape[0]
|
||||
return float(frame_count / components.frame_rate)
|
||||
|
||||
def get_frame_count(self) -> int:
|
||||
"""
|
||||
Returns the number of frames in the video.
|
||||
|
||||
Default implementation uses :meth:`get_components`, which may require
|
||||
loading all frames into memory. File-based implementations should
|
||||
override this method and use container/stream metadata instead.
|
||||
|
||||
Returns:
|
||||
Total number of frames as an integer.
|
||||
"""
|
||||
return int(self.get_components().images.shape[0])
|
||||
|
||||
def get_frame_rate(self) -> Fraction:
|
||||
"""
|
||||
Returns the frame rate of the video.
|
||||
|
||||
Default implementation materializes the video into memory via
|
||||
`get_components()`. Subclasses that can inspect the underlying
|
||||
container (e.g. `VideoFromFile`) should override this with a more
|
||||
efficient implementation.
|
||||
|
||||
Returns:
|
||||
Frame rate as a Fraction.
|
||||
"""
|
||||
return self.get_components().frame_rate
|
||||
|
||||
def get_container_format(self) -> str:
|
||||
"""
|
||||
Returns the container format of the video (e.g., 'mp4', 'mov', 'avi').
|
||||
|
||||
@ -121,6 +121,71 @@ class VideoFromFile(VideoInput):
|
||||
|
||||
raise ValueError(f"Could not determine duration for file '{self.__file}'")
|
||||
|
||||
def get_frame_count(self) -> int:
|
||||
"""
|
||||
Returns the number of frames in the video without materializing them as
|
||||
torch tensors.
|
||||
"""
|
||||
if isinstance(self.__file, io.BytesIO):
|
||||
self.__file.seek(0)
|
||||
|
||||
with av.open(self.__file, mode="r") as container:
|
||||
video_stream = self._get_first_video_stream(container)
|
||||
# 1. Prefer the frames field if available
|
||||
if video_stream.frames and video_stream.frames > 0:
|
||||
return int(video_stream.frames)
|
||||
|
||||
# 2. Try to estimate from duration and average_rate using only metadata
|
||||
if container.duration is not None and video_stream.average_rate:
|
||||
duration_seconds = float(container.duration / av.time_base)
|
||||
estimated_frames = int(round(duration_seconds * float(video_stream.average_rate)))
|
||||
if estimated_frames > 0:
|
||||
return estimated_frames
|
||||
|
||||
if (
|
||||
getattr(video_stream, "duration", None) is not None
|
||||
and getattr(video_stream, "time_base", None) is not None
|
||||
and video_stream.average_rate
|
||||
):
|
||||
duration_seconds = float(video_stream.duration * video_stream.time_base)
|
||||
estimated_frames = int(round(duration_seconds * float(video_stream.average_rate)))
|
||||
if estimated_frames > 0:
|
||||
return estimated_frames
|
||||
|
||||
# 3. Last resort: decode frames and count them (streaming)
|
||||
frame_count = 0
|
||||
container.seek(0)
|
||||
for packet in container.demux(video_stream):
|
||||
for _ in packet.decode():
|
||||
frame_count += 1
|
||||
|
||||
if frame_count == 0:
|
||||
raise ValueError(f"Could not determine frame count for file '{self.__file}'")
|
||||
return frame_count
|
||||
|
||||
def get_frame_rate(self) -> Fraction:
|
||||
"""
|
||||
Returns the average frame rate of the video using container metadata
|
||||
without decoding all frames.
|
||||
"""
|
||||
if isinstance(self.__file, io.BytesIO):
|
||||
self.__file.seek(0)
|
||||
|
||||
with av.open(self.__file, mode="r") as container:
|
||||
video_stream = self._get_first_video_stream(container)
|
||||
# Preferred: use PyAV's average_rate (usually already a Fraction-like)
|
||||
if video_stream.average_rate:
|
||||
return Fraction(video_stream.average_rate)
|
||||
|
||||
# Fallback: estimate from frames + duration if available
|
||||
if video_stream.frames and container.duration:
|
||||
duration_seconds = float(container.duration / av.time_base)
|
||||
if duration_seconds > 0:
|
||||
return Fraction(video_stream.frames / duration_seconds).limit_denominator()
|
||||
|
||||
# Last resort: match get_components_internal default
|
||||
return Fraction(1)
|
||||
|
||||
def get_container_format(self) -> str:
|
||||
"""
|
||||
Returns the container format of the video (e.g., 'mp4', 'mov', 'avi').
|
||||
@ -238,6 +303,13 @@ class VideoFromFile(VideoInput):
|
||||
packet.stream = stream_map[packet.stream]
|
||||
output_container.mux(packet)
|
||||
|
||||
def _get_first_video_stream(self, container: InputContainer):
|
||||
video_stream = next((s for s in container.streams if s.type == "video"), None)
|
||||
if video_stream is None:
|
||||
raise ValueError(f"No video stream found in file '{self.__file}'")
|
||||
return video_stream
|
||||
|
||||
|
||||
class VideoFromComponents(VideoInput):
|
||||
"""
|
||||
Class representing video input from tensors.
|
||||
|
||||
@ -113,9 +113,9 @@ class GeminiGenerationConfig(BaseModel):
|
||||
maxOutputTokens: int | None = Field(None, ge=16, le=8192)
|
||||
seed: int | None = Field(None)
|
||||
stopSequences: list[str] | None = Field(None)
|
||||
temperature: float | None = Field(1, ge=0.0, le=2.0)
|
||||
topK: int | None = Field(40, ge=1)
|
||||
topP: float | None = Field(0.95, ge=0.0, le=1.0)
|
||||
temperature: float | None = Field(None, ge=0.0, le=2.0)
|
||||
topK: int | None = Field(None, ge=1)
|
||||
topP: float | None = Field(None, ge=0.0, le=1.0)
|
||||
|
||||
|
||||
class GeminiImageConfig(BaseModel):
|
||||
|
||||
@ -104,14 +104,14 @@ def get_parts_by_type(response: GeminiGenerateContentResponse, part_type: Litera
|
||||
List of response parts matching the requested type.
|
||||
"""
|
||||
if response.candidates is None:
|
||||
if response.promptFeedback.blockReason:
|
||||
if response.promptFeedback and response.promptFeedback.blockReason:
|
||||
feedback = response.promptFeedback
|
||||
raise ValueError(
|
||||
f"Gemini API blocked the request. Reason: {feedback.blockReason} ({feedback.blockReasonMessage})"
|
||||
)
|
||||
raise NotImplementedError(
|
||||
"Gemini returned no response candidates. "
|
||||
"Please report to ComfyUI repository with the example of workflow to reproduce this."
|
||||
raise ValueError(
|
||||
"Gemini API returned no response candidates. If you are using the `IMAGE` modality, "
|
||||
"try changing it to `IMAGE+TEXT` to view the model's reasoning and understand why image generation failed."
|
||||
)
|
||||
parts = []
|
||||
for part in response.candidates[0].content.parts:
|
||||
@ -182,11 +182,12 @@ def calculate_tokens_price(response: GeminiGenerateContentResponse) -> float | N
|
||||
else:
|
||||
return None
|
||||
final_price = response.usageMetadata.promptTokenCount * input_tokens_price
|
||||
for i in response.usageMetadata.candidatesTokensDetails:
|
||||
if i.modality == Modality.IMAGE:
|
||||
final_price += output_image_tokens_price * i.tokenCount # for Nano Banana models
|
||||
else:
|
||||
final_price += output_text_tokens_price * i.tokenCount
|
||||
if response.usageMetadata.candidatesTokensDetails:
|
||||
for i in response.usageMetadata.candidatesTokensDetails:
|
||||
if i.modality == Modality.IMAGE:
|
||||
final_price += output_image_tokens_price * i.tokenCount # for Nano Banana models
|
||||
else:
|
||||
final_price += output_text_tokens_price * i.tokenCount
|
||||
if response.usageMetadata.thoughtsTokenCount:
|
||||
final_price += output_text_tokens_price * response.usageMetadata.thoughtsTokenCount
|
||||
return final_price / 1_000_000.0
|
||||
@ -645,7 +646,7 @@ class GeminiImage2(IO.ComfyNode):
|
||||
options=["auto", "1:1", "2:3", "3:2", "3:4", "4:3", "4:5", "5:4", "9:16", "16:9", "21:9"],
|
||||
default="auto",
|
||||
tooltip="If set to 'auto', matches your input image's aspect ratio; "
|
||||
"if no image is provided, generates a 1:1 square.",
|
||||
"if no image is provided, a 16:9 square is usually generated.",
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"resolution",
|
||||
|
||||
@ -5,8 +5,7 @@ import aiohttp
|
||||
import torch
|
||||
from typing_extensions import override
|
||||
|
||||
from comfy_api.input.video_types import VideoInput
|
||||
from comfy_api.latest import IO, ComfyExtension
|
||||
from comfy_api.latest import IO, ComfyExtension, Input
|
||||
from comfy_api_nodes.apis import topaz_api
|
||||
from comfy_api_nodes.util import (
|
||||
ApiEndpoint,
|
||||
@ -282,7 +281,7 @@ class TopazVideoEnhance(IO.ComfyNode):
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
video: VideoInput,
|
||||
video: Input.Video,
|
||||
upscaler_enabled: bool,
|
||||
upscaler_model: str,
|
||||
upscaler_resolution: str,
|
||||
@ -297,12 +296,10 @@ class TopazVideoEnhance(IO.ComfyNode):
|
||||
) -> IO.NodeOutput:
|
||||
if upscaler_enabled is False and interpolation_enabled is False:
|
||||
raise ValueError("There is nothing to do: both upscaling and interpolation are disabled.")
|
||||
src_width, src_height = video.get_dimensions()
|
||||
video_components = video.get_components()
|
||||
src_frame_rate = int(video_components.frame_rate)
|
||||
duration_sec = video.get_duration()
|
||||
estimated_frames = int(duration_sec * src_frame_rate)
|
||||
validate_container_format_is_mp4(video)
|
||||
src_width, src_height = video.get_dimensions()
|
||||
src_frame_rate = int(video.get_frame_rate())
|
||||
duration_sec = video.get_duration()
|
||||
src_video_stream = video.get_stream_source()
|
||||
target_width = src_width
|
||||
target_height = src_height
|
||||
@ -338,7 +335,7 @@ class TopazVideoEnhance(IO.ComfyNode):
|
||||
container="mp4",
|
||||
size=get_fs_object_size(src_video_stream),
|
||||
duration=int(duration_sec),
|
||||
frameCount=estimated_frames,
|
||||
frameCount=video.get_frame_count(),
|
||||
frameRate=src_frame_rate,
|
||||
resolution=topaz_api.Resolution(width=src_width, height=src_height),
|
||||
),
|
||||
|
||||
@ -38,6 +38,7 @@ class EmptyHunyuanLatentVideo(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="EmptyHunyuanLatentVideo",
|
||||
display_name="Empty HunyuanVideo 1.0 Latent",
|
||||
category="latent/video",
|
||||
inputs=[
|
||||
io.Int.Input("width", default=848, min=16, max=nodes.MAX_RESOLUTION, step=16),
|
||||
@ -63,6 +64,7 @@ class EmptyHunyuanVideo15Latent(EmptyHunyuanLatentVideo):
|
||||
def define_schema(cls):
|
||||
schema = super().define_schema()
|
||||
schema.node_id = "EmptyHunyuanVideo15Latent"
|
||||
schema.display_name = "Empty HunyuanVideo 1.5 Latent"
|
||||
return schema
|
||||
|
||||
@classmethod
|
||||
@ -71,8 +73,6 @@ class EmptyHunyuanVideo15Latent(EmptyHunyuanLatentVideo):
|
||||
latent = torch.zeros([batch_size, 32, ((length - 1) // 4) + 1, height // 16, width // 16], device=comfy.model_management.intermediate_device())
|
||||
return io.NodeOutput({"samples": latent})
|
||||
|
||||
generate = execute # TODO: remove
|
||||
|
||||
|
||||
class HunyuanVideo15ImageToVideo(io.ComfyNode):
|
||||
@classmethod
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
comfyui-frontend-package==1.30.6
|
||||
comfyui-workflow-templates==0.6.0
|
||||
comfyui-workflow-templates==0.7.9
|
||||
comfyui-embedded-docs==0.3.1
|
||||
torch
|
||||
torchsde
|
||||
|
||||
153
tests/execution/test_public_api.py
Normal file
153
tests/execution/test_public_api.py
Normal file
@ -0,0 +1,153 @@
|
||||
"""
|
||||
Tests for public ComfyAPI and ComfyAPISync functions.
|
||||
|
||||
These tests verify that the public API methods work correctly in both sync and async contexts,
|
||||
ensuring that the sync wrapper generation (via get_type_hints() in async_to_sync.py) correctly
|
||||
handles string annotations from 'from __future__ import annotations'.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import time
|
||||
import subprocess
|
||||
import torch
|
||||
from pytest import fixture
|
||||
from comfy_execution.graph_utils import GraphBuilder
|
||||
from tests.execution.test_execution import ComfyClient
|
||||
|
||||
|
||||
@pytest.mark.execution
|
||||
class TestPublicAPI:
|
||||
"""Test suite for public ComfyAPI and ComfyAPISync methods."""
|
||||
|
||||
@fixture(scope="class", autouse=True)
|
||||
def _server(self, args_pytest):
|
||||
"""Start ComfyUI server for testing."""
|
||||
pargs = [
|
||||
'python', 'main.py',
|
||||
'--output-directory', args_pytest["output_dir"],
|
||||
'--listen', args_pytest["listen"],
|
||||
'--port', str(args_pytest["port"]),
|
||||
'--extra-model-paths-config', 'tests/execution/extra_model_paths.yaml',
|
||||
'--cpu',
|
||||
]
|
||||
p = subprocess.Popen(pargs)
|
||||
yield
|
||||
p.kill()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
@fixture(scope="class", autouse=True)
|
||||
def shared_client(self, args_pytest, _server):
|
||||
"""Create shared client with connection retry."""
|
||||
client = ComfyClient()
|
||||
n_tries = 5
|
||||
for i in range(n_tries):
|
||||
time.sleep(4)
|
||||
try:
|
||||
client.connect(listen=args_pytest["listen"], port=args_pytest["port"])
|
||||
break
|
||||
except ConnectionRefusedError:
|
||||
if i == n_tries - 1:
|
||||
raise
|
||||
yield client
|
||||
del client
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
@fixture
|
||||
def client(self, shared_client, request):
|
||||
"""Set test name for each test."""
|
||||
shared_client.set_test_name(f"public_api[{request.node.name}]")
|
||||
yield shared_client
|
||||
|
||||
@fixture
|
||||
def builder(self, request):
|
||||
"""Create GraphBuilder for each test."""
|
||||
yield GraphBuilder(prefix=request.node.name)
|
||||
|
||||
def test_sync_progress_update_executes(self, client: ComfyClient, builder: GraphBuilder):
|
||||
"""Test that TestSyncProgressUpdate executes without errors.
|
||||
|
||||
This test validates that api_sync.execution.set_progress() works correctly,
|
||||
which is the primary code path fixed by adding get_type_hints() to async_to_sync.py.
|
||||
"""
|
||||
g = builder
|
||||
image = g.node("StubImage", content="BLACK", height=256, width=256, batch_size=1)
|
||||
|
||||
# Use TestSyncProgressUpdate with short sleep
|
||||
progress_node = g.node("TestSyncProgressUpdate",
|
||||
value=image.out(0),
|
||||
sleep_seconds=0.5)
|
||||
output = g.node("SaveImage", images=progress_node.out(0))
|
||||
|
||||
# Execute workflow
|
||||
result = client.run(g)
|
||||
|
||||
# Verify execution
|
||||
assert result.did_run(progress_node), "Progress node should have executed"
|
||||
assert result.did_run(output), "Output node should have executed"
|
||||
|
||||
# Verify output
|
||||
images = result.get_images(output)
|
||||
assert len(images) == 1, "Should have produced 1 image"
|
||||
|
||||
def test_async_progress_update_executes(self, client: ComfyClient, builder: GraphBuilder):
|
||||
"""Test that TestAsyncProgressUpdate executes without errors.
|
||||
|
||||
This test validates that await api.execution.set_progress() works correctly
|
||||
in async contexts.
|
||||
"""
|
||||
g = builder
|
||||
image = g.node("StubImage", content="WHITE", height=256, width=256, batch_size=1)
|
||||
|
||||
# Use TestAsyncProgressUpdate with short sleep
|
||||
progress_node = g.node("TestAsyncProgressUpdate",
|
||||
value=image.out(0),
|
||||
sleep_seconds=0.5)
|
||||
output = g.node("SaveImage", images=progress_node.out(0))
|
||||
|
||||
# Execute workflow
|
||||
result = client.run(g)
|
||||
|
||||
# Verify execution
|
||||
assert result.did_run(progress_node), "Async progress node should have executed"
|
||||
assert result.did_run(output), "Output node should have executed"
|
||||
|
||||
# Verify output
|
||||
images = result.get_images(output)
|
||||
assert len(images) == 1, "Should have produced 1 image"
|
||||
|
||||
def test_sync_and_async_progress_together(self, client: ComfyClient, builder: GraphBuilder):
|
||||
"""Test both sync and async progress updates in same workflow.
|
||||
|
||||
This test ensures that both ComfyAPISync and ComfyAPI can coexist and work
|
||||
correctly in the same workflow execution.
|
||||
"""
|
||||
g = builder
|
||||
image1 = g.node("StubImage", content="BLACK", height=256, width=256, batch_size=1)
|
||||
image2 = g.node("StubImage", content="WHITE", height=256, width=256, batch_size=1)
|
||||
|
||||
# Use both types of progress nodes
|
||||
sync_progress = g.node("TestSyncProgressUpdate",
|
||||
value=image1.out(0),
|
||||
sleep_seconds=0.3)
|
||||
async_progress = g.node("TestAsyncProgressUpdate",
|
||||
value=image2.out(0),
|
||||
sleep_seconds=0.3)
|
||||
|
||||
# Create outputs
|
||||
output1 = g.node("SaveImage", images=sync_progress.out(0))
|
||||
output2 = g.node("SaveImage", images=async_progress.out(0))
|
||||
|
||||
# Execute workflow
|
||||
result = client.run(g)
|
||||
|
||||
# Both should execute successfully
|
||||
assert result.did_run(sync_progress), "Sync progress node should have executed"
|
||||
assert result.did_run(async_progress), "Async progress node should have executed"
|
||||
assert result.did_run(output1), "First output node should have executed"
|
||||
assert result.did_run(output2), "Second output node should have executed"
|
||||
|
||||
# Verify outputs
|
||||
images1 = result.get_images(output1)
|
||||
images2 = result.get_images(output2)
|
||||
assert len(images1) == 1, "Should have produced 1 image from sync node"
|
||||
assert len(images2) == 1, "Should have produced 1 image from async node"
|
||||
Loading…
Reference in New Issue
Block a user