mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-12-16 01:37:04 +08:00
- Add set_preview_method() to override live preview method per queue item - Read extra_data.preview_method from /prompt request - Support values: taesd, latent2rgb, none, auto, default - "default" or unset uses server's CLI --preview-method setting - Add 44 tests (37 unit + 7 E2E)
359 lines
13 KiB
Python
359 lines
13 KiB
Python
"""
|
|
E2E tests for Queue-specific Preview Method Override feature.
|
|
|
|
Tests actual execution with different preview_method values.
|
|
Requires a running ComfyUI server with models.
|
|
|
|
Usage:
|
|
COMFYUI_SERVER=http://localhost:8988 pytest test_preview_method_e2e.py -v -m preview_method
|
|
|
|
Note:
|
|
These tests execute actual image generation and wait for completion.
|
|
Tests verify preview image transmission based on preview_method setting.
|
|
"""
|
|
import os
|
|
import json
|
|
import pytest
|
|
import uuid
|
|
import time
|
|
import random
|
|
import websocket
|
|
import urllib.request
|
|
from pathlib import Path
|
|
|
|
|
|
# Server configuration
|
|
SERVER_URL = os.environ.get("COMFYUI_SERVER", "http://localhost:8988")
|
|
SERVER_HOST = SERVER_URL.replace("http://", "").replace("https://", "")
|
|
|
|
# Use existing inference graph fixture
|
|
GRAPH_FILE = Path(__file__).parent.parent / "inference" / "graphs" / "default_graph_sdxl1_0.json"
|
|
|
|
|
|
def is_server_running() -> bool:
|
|
"""Check if ComfyUI server is running."""
|
|
try:
|
|
request = urllib.request.Request(f"{SERVER_URL}/system_stats")
|
|
with urllib.request.urlopen(request, timeout=2.0):
|
|
return True
|
|
except Exception:
|
|
return False
|
|
|
|
|
|
def prepare_graph_for_test(graph: dict, steps: int = 5) -> dict:
|
|
"""Prepare graph for testing: randomize seeds and reduce steps."""
|
|
adapted = json.loads(json.dumps(graph)) # Deep copy
|
|
for node_id, node in adapted.items():
|
|
inputs = node.get("inputs", {})
|
|
# Handle both "seed" and "noise_seed" (used by KSamplerAdvanced)
|
|
if "seed" in inputs:
|
|
inputs["seed"] = random.randint(0, 2**32 - 1)
|
|
if "noise_seed" in inputs:
|
|
inputs["noise_seed"] = random.randint(0, 2**32 - 1)
|
|
# Reduce steps for faster testing (default 20 -> 5)
|
|
if "steps" in inputs:
|
|
inputs["steps"] = steps
|
|
return adapted
|
|
|
|
|
|
# Alias for backward compatibility
|
|
randomize_seed = prepare_graph_for_test
|
|
|
|
|
|
class PreviewMethodClient:
|
|
"""Client for testing preview_method with WebSocket execution tracking."""
|
|
|
|
def __init__(self, server_address: str):
|
|
self.server_address = server_address
|
|
self.client_id = str(uuid.uuid4())
|
|
self.ws = None
|
|
|
|
def connect(self):
|
|
"""Connect to WebSocket."""
|
|
self.ws = websocket.WebSocket()
|
|
self.ws.settimeout(120) # 2 minute timeout for sampling
|
|
self.ws.connect(f"ws://{self.server_address}/ws?clientId={self.client_id}")
|
|
|
|
def close(self):
|
|
"""Close WebSocket connection."""
|
|
if self.ws:
|
|
self.ws.close()
|
|
|
|
def queue_prompt(self, prompt: dict, extra_data: dict = None) -> dict:
|
|
"""Queue a prompt and return response with prompt_id."""
|
|
data = {
|
|
"prompt": prompt,
|
|
"client_id": self.client_id,
|
|
"extra_data": extra_data or {}
|
|
}
|
|
req = urllib.request.Request(
|
|
f"http://{self.server_address}/prompt",
|
|
data=json.dumps(data).encode("utf-8"),
|
|
headers={"Content-Type": "application/json"}
|
|
)
|
|
return json.loads(urllib.request.urlopen(req).read())
|
|
|
|
def wait_for_execution(self, prompt_id: str, timeout: float = 120.0) -> dict:
|
|
"""
|
|
Wait for execution to complete via WebSocket.
|
|
|
|
Returns:
|
|
dict with keys: completed, error, preview_count, execution_time
|
|
"""
|
|
result = {
|
|
"completed": False,
|
|
"error": None,
|
|
"preview_count": 0,
|
|
"execution_time": 0.0
|
|
}
|
|
|
|
start_time = time.time()
|
|
self.ws.settimeout(timeout)
|
|
|
|
try:
|
|
while True:
|
|
out = self.ws.recv()
|
|
elapsed = time.time() - start_time
|
|
|
|
if isinstance(out, str):
|
|
message = json.loads(out)
|
|
msg_type = message.get("type")
|
|
data = message.get("data", {})
|
|
|
|
if data.get("prompt_id") != prompt_id:
|
|
continue
|
|
|
|
if msg_type == "executing":
|
|
if data.get("node") is None:
|
|
# Execution complete
|
|
result["completed"] = True
|
|
result["execution_time"] = elapsed
|
|
break
|
|
|
|
elif msg_type == "execution_error":
|
|
result["error"] = data
|
|
result["execution_time"] = elapsed
|
|
break
|
|
|
|
elif msg_type == "progress":
|
|
# Progress update during sampling
|
|
pass
|
|
|
|
elif isinstance(out, bytes):
|
|
# Binary data = preview image
|
|
result["preview_count"] += 1
|
|
|
|
except websocket.WebSocketTimeoutException:
|
|
result["error"] = "Timeout waiting for execution"
|
|
result["execution_time"] = time.time() - start_time
|
|
|
|
return result
|
|
|
|
|
|
def load_graph() -> dict:
|
|
"""Load the SDXL graph fixture with randomized seed."""
|
|
with open(GRAPH_FILE) as f:
|
|
graph = json.load(f)
|
|
return randomize_seed(graph) # Avoid caching
|
|
|
|
|
|
# Skip all tests if server is not running
|
|
pytestmark = [
|
|
pytest.mark.skipif(
|
|
not is_server_running(),
|
|
reason=f"ComfyUI server not running at {SERVER_URL}"
|
|
),
|
|
pytest.mark.preview_method,
|
|
pytest.mark.execution,
|
|
]
|
|
|
|
|
|
@pytest.fixture
|
|
def client():
|
|
"""Create and connect a test client."""
|
|
c = PreviewMethodClient(SERVER_HOST)
|
|
c.connect()
|
|
yield c
|
|
c.close()
|
|
|
|
|
|
@pytest.fixture
|
|
def graph():
|
|
"""Load the test graph."""
|
|
return load_graph()
|
|
|
|
|
|
class TestPreviewMethodExecution:
|
|
"""Test actual execution with different preview methods."""
|
|
|
|
def test_execution_with_latent2rgb(self, client, graph):
|
|
"""
|
|
Execute with preview_method=latent2rgb.
|
|
Should complete and potentially receive preview images.
|
|
"""
|
|
extra_data = {"preview_method": "latent2rgb"}
|
|
|
|
response = client.queue_prompt(graph, extra_data)
|
|
assert "prompt_id" in response
|
|
|
|
result = client.wait_for_execution(response["prompt_id"])
|
|
|
|
# Should complete (may error if model missing, but that's separate)
|
|
assert result["completed"] or result["error"] is not None
|
|
# Execution should take some time (sampling)
|
|
if result["completed"]:
|
|
assert result["execution_time"] > 0.5, "Execution too fast - likely didn't run"
|
|
# latent2rgb should produce previews
|
|
print(f"latent2rgb: {result['preview_count']} previews in {result['execution_time']:.2f}s") # noqa: T201
|
|
|
|
def test_execution_with_taesd(self, client, graph):
|
|
"""
|
|
Execute with preview_method=taesd.
|
|
TAESD provides higher quality previews.
|
|
"""
|
|
extra_data = {"preview_method": "taesd"}
|
|
|
|
response = client.queue_prompt(graph, extra_data)
|
|
assert "prompt_id" in response
|
|
|
|
result = client.wait_for_execution(response["prompt_id"])
|
|
|
|
assert result["completed"] or result["error"] is not None
|
|
if result["completed"]:
|
|
assert result["execution_time"] > 0.5
|
|
# taesd should also produce previews
|
|
print(f"taesd: {result['preview_count']} previews in {result['execution_time']:.2f}s") # noqa: T201
|
|
|
|
def test_execution_with_none_preview(self, client, graph):
|
|
"""
|
|
Execute with preview_method=none.
|
|
No preview images should be generated.
|
|
"""
|
|
extra_data = {"preview_method": "none"}
|
|
|
|
response = client.queue_prompt(graph, extra_data)
|
|
assert "prompt_id" in response
|
|
|
|
result = client.wait_for_execution(response["prompt_id"])
|
|
|
|
assert result["completed"] or result["error"] is not None
|
|
if result["completed"]:
|
|
# With "none", should receive no preview images
|
|
assert result["preview_count"] == 0, \
|
|
f"Expected no previews with 'none', got {result['preview_count']}"
|
|
print(f"none: {result['preview_count']} previews in {result['execution_time']:.2f}s") # noqa: T201
|
|
|
|
def test_execution_with_default(self, client, graph):
|
|
"""
|
|
Execute with preview_method=default.
|
|
Should use server's CLI default setting.
|
|
"""
|
|
extra_data = {"preview_method": "default"}
|
|
|
|
response = client.queue_prompt(graph, extra_data)
|
|
assert "prompt_id" in response
|
|
|
|
result = client.wait_for_execution(response["prompt_id"])
|
|
|
|
assert result["completed"] or result["error"] is not None
|
|
if result["completed"]:
|
|
print(f"default: {result['preview_count']} previews in {result['execution_time']:.2f}s") # noqa: T201
|
|
|
|
def test_execution_without_preview_method(self, client, graph):
|
|
"""
|
|
Execute without preview_method in extra_data.
|
|
Should use server's default preview method.
|
|
"""
|
|
extra_data = {} # No preview_method
|
|
|
|
response = client.queue_prompt(graph, extra_data)
|
|
assert "prompt_id" in response
|
|
|
|
result = client.wait_for_execution(response["prompt_id"])
|
|
|
|
assert result["completed"] or result["error"] is not None
|
|
if result["completed"]:
|
|
print(f"(no override): {result['preview_count']} previews in {result['execution_time']:.2f}s") # noqa: T201
|
|
|
|
|
|
class TestPreviewMethodComparison:
|
|
"""Compare preview behavior between different methods."""
|
|
|
|
def test_none_vs_latent2rgb_preview_count(self, client, graph):
|
|
"""
|
|
Compare preview counts: 'none' should have 0, others should have >0.
|
|
This is the key verification that preview_method actually works.
|
|
"""
|
|
results = {}
|
|
|
|
# Run with none (randomize seed to avoid caching)
|
|
graph_none = randomize_seed(graph)
|
|
extra_data_none = {"preview_method": "none"}
|
|
response = client.queue_prompt(graph_none, extra_data_none)
|
|
results["none"] = client.wait_for_execution(response["prompt_id"])
|
|
|
|
# Run with latent2rgb (randomize seed again)
|
|
graph_rgb = randomize_seed(graph)
|
|
extra_data_rgb = {"preview_method": "latent2rgb"}
|
|
response = client.queue_prompt(graph_rgb, extra_data_rgb)
|
|
results["latent2rgb"] = client.wait_for_execution(response["prompt_id"])
|
|
|
|
# Verify both completed
|
|
assert results["none"]["completed"], f"'none' execution failed: {results['none']['error']}"
|
|
assert results["latent2rgb"]["completed"], f"'latent2rgb' execution failed: {results['latent2rgb']['error']}"
|
|
|
|
# Key assertion: 'none' should have 0 previews
|
|
assert results["none"]["preview_count"] == 0, \
|
|
f"'none' should have 0 previews, got {results['none']['preview_count']}"
|
|
|
|
# 'latent2rgb' should have at least 1 preview (depends on steps)
|
|
assert results["latent2rgb"]["preview_count"] > 0, \
|
|
f"'latent2rgb' should have >0 previews, got {results['latent2rgb']['preview_count']}"
|
|
|
|
print("\nPreview count comparison:") # noqa: T201
|
|
print(f" none: {results['none']['preview_count']} previews") # noqa: T201
|
|
print(f" latent2rgb: {results['latent2rgb']['preview_count']} previews") # noqa: T201
|
|
|
|
|
|
class TestPreviewMethodSequential:
|
|
"""Test sequential execution with different preview methods."""
|
|
|
|
def test_sequential_different_methods(self, client, graph):
|
|
"""
|
|
Execute multiple prompts sequentially with different preview methods.
|
|
Each should complete independently with correct preview behavior.
|
|
"""
|
|
methods = ["latent2rgb", "none", "default"]
|
|
results = []
|
|
|
|
for method in methods:
|
|
# Randomize seed for each execution to avoid caching
|
|
graph_run = randomize_seed(graph)
|
|
extra_data = {"preview_method": method}
|
|
response = client.queue_prompt(graph_run, extra_data)
|
|
|
|
result = client.wait_for_execution(response["prompt_id"])
|
|
results.append({
|
|
"method": method,
|
|
"completed": result["completed"],
|
|
"preview_count": result["preview_count"],
|
|
"execution_time": result["execution_time"],
|
|
"error": result["error"]
|
|
})
|
|
|
|
# All should complete or have clear errors
|
|
for r in results:
|
|
assert r["completed"] or r["error"] is not None, \
|
|
f"Method {r['method']} neither completed nor errored"
|
|
|
|
# "none" should have zero previews if completed
|
|
none_result = next(r for r in results if r["method"] == "none")
|
|
if none_result["completed"]:
|
|
assert none_result["preview_count"] == 0, \
|
|
f"'none' should have 0 previews, got {none_result['preview_count']}"
|
|
|
|
print("\nSequential execution results:") # noqa: T201
|
|
for r in results:
|
|
status = "✓" if r["completed"] else f"✗ ({r['error']})"
|
|
print(f" {r['method']}: {status}, {r['preview_count']} previews, {r['execution_time']:.2f}s") # noqa: T201
|