mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-12-16 01:37:04 +08:00
Compare commits
3 Commits
f2039b999b
...
894a3e914a
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
894a3e914a | ||
|
|
6592bffc60 | ||
|
|
a4872fc717 |
@ -97,6 +97,13 @@ class LatentPreviewMethod(enum.Enum):
|
|||||||
Latent2RGB = "latent2rgb"
|
Latent2RGB = "latent2rgb"
|
||||||
TAESD = "taesd"
|
TAESD = "taesd"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_string(cls, value: str):
|
||||||
|
for member in cls:
|
||||||
|
if member.value == value:
|
||||||
|
return member
|
||||||
|
return None
|
||||||
|
|
||||||
parser.add_argument("--preview-method", type=LatentPreviewMethod, default=LatentPreviewMethod.NoPreviews, help="Default preview method for sampler nodes.", action=EnumAction)
|
parser.add_argument("--preview-method", type=LatentPreviewMethod, default=LatentPreviewMethod.NoPreviews, help="Default preview method for sampler nodes.", action=EnumAction)
|
||||||
|
|
||||||
parser.add_argument("--preview-size", type=int, default=512, help="Sets the maximum preview size for sampler nodes.")
|
parser.add_argument("--preview-size", type=int, default=512, help="Sets the maximum preview size for sampler nodes.")
|
||||||
|
|||||||
@ -1557,10 +1557,13 @@ def sample_er_sde(model, x, sigmas, extra_args=None, callback=None, disable=None
|
|||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=0.5):
|
def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=0.5, solver_type="phi_1"):
|
||||||
"""SEEDS-2 - Stochastic Explicit Exponential Derivative-free Solvers (VP Data Prediction) stage 2.
|
"""SEEDS-2 - Stochastic Explicit Exponential Derivative-free Solvers (VP Data Prediction) stage 2.
|
||||||
arXiv: https://arxiv.org/abs/2305.14267 (NeurIPS 2023)
|
arXiv: https://arxiv.org/abs/2305.14267 (NeurIPS 2023)
|
||||||
"""
|
"""
|
||||||
|
if solver_type not in {"phi_1", "phi_2"}:
|
||||||
|
raise ValueError("solver_type must be 'phi_1' or 'phi_2'")
|
||||||
|
|
||||||
extra_args = {} if extra_args is None else extra_args
|
extra_args = {} if extra_args is None else extra_args
|
||||||
seed = extra_args.get("seed", None)
|
seed = extra_args.get("seed", None)
|
||||||
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
|
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
|
||||||
@ -1600,8 +1603,14 @@ def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=Non
|
|||||||
denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args)
|
denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args)
|
||||||
|
|
||||||
# Step 2
|
# Step 2
|
||||||
denoised_d = torch.lerp(denoised, denoised_2, fac)
|
if solver_type == "phi_1":
|
||||||
x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * ei_h_phi_1(-h_eta) * denoised_d
|
denoised_d = torch.lerp(denoised, denoised_2, fac)
|
||||||
|
x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * ei_h_phi_1(-h_eta) * denoised_d
|
||||||
|
elif solver_type == "phi_2":
|
||||||
|
b2 = ei_h_phi_2(-h_eta) / r
|
||||||
|
b1 = ei_h_phi_1(-h_eta) - b2
|
||||||
|
x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * (b1 * denoised + b2 * denoised_2)
|
||||||
|
|
||||||
if inject_noise:
|
if inject_noise:
|
||||||
segment_factor = (r - 1) * h * eta
|
segment_factor = (r - 1) * h * eta
|
||||||
sde_noise = sde_noise * segment_factor.exp()
|
sde_noise = sde_noise * segment_factor.exp()
|
||||||
|
|||||||
@ -659,6 +659,31 @@ class SamplerSASolver(io.ComfyNode):
|
|||||||
get_sampler = execute
|
get_sampler = execute
|
||||||
|
|
||||||
|
|
||||||
|
class SamplerSEEDS2(io.ComfyNode):
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls):
|
||||||
|
return io.Schema(
|
||||||
|
node_id="SamplerSEEDS2",
|
||||||
|
category="sampling/custom_sampling/samplers",
|
||||||
|
inputs=[
|
||||||
|
io.Combo.Input("solver_type", options=["phi_1", "phi_2"]),
|
||||||
|
io.Float.Input("eta", default=1.0, min=0.0, max=100.0, step=0.01, round=False, tooltip="Stochastic strength"),
|
||||||
|
io.Float.Input("s_noise", default=1.0, min=0.0, max=100.0, step=0.01, round=False, tooltip="SDE noise multiplier"),
|
||||||
|
io.Float.Input("r", default=0.5, min=0.01, max=1.0, step=0.01, round=False, tooltip="Relative step size for the intermediate stage (c2 node)"),
|
||||||
|
],
|
||||||
|
outputs=[io.Sampler.Output()]
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def execute(cls, solver_type, eta, s_noise, r) -> io.NodeOutput:
|
||||||
|
sampler_name = "seeds_2"
|
||||||
|
sampler = comfy.samplers.ksampler(
|
||||||
|
sampler_name,
|
||||||
|
{"eta": eta, "s_noise": s_noise, "r": r, "solver_type": solver_type},
|
||||||
|
)
|
||||||
|
return io.NodeOutput(sampler)
|
||||||
|
|
||||||
|
|
||||||
class Noise_EmptyNoise:
|
class Noise_EmptyNoise:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.seed = 0
|
self.seed = 0
|
||||||
@ -996,6 +1021,7 @@ class CustomSamplersExtension(ComfyExtension):
|
|||||||
SamplerDPMAdaptative,
|
SamplerDPMAdaptative,
|
||||||
SamplerER_SDE,
|
SamplerER_SDE,
|
||||||
SamplerSASolver,
|
SamplerSASolver,
|
||||||
|
SamplerSEEDS2,
|
||||||
SplitSigmas,
|
SplitSigmas,
|
||||||
SplitSigmasDenoise,
|
SplitSigmasDenoise,
|
||||||
FlipSigmas,
|
FlipSigmas,
|
||||||
|
|||||||
@ -13,6 +13,7 @@ import asyncio
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
|
from latent_preview import set_preview_method
|
||||||
import nodes
|
import nodes
|
||||||
from comfy_execution.caching import (
|
from comfy_execution.caching import (
|
||||||
BasicCache,
|
BasicCache,
|
||||||
@ -669,6 +670,8 @@ class PromptExecutor:
|
|||||||
asyncio.run(self.execute_async(prompt, prompt_id, extra_data, execute_outputs))
|
asyncio.run(self.execute_async(prompt, prompt_id, extra_data, execute_outputs))
|
||||||
|
|
||||||
async def execute_async(self, prompt, prompt_id, extra_data={}, execute_outputs=[]):
|
async def execute_async(self, prompt, prompt_id, extra_data={}, execute_outputs=[]):
|
||||||
|
set_preview_method(extra_data.get("preview_method"))
|
||||||
|
|
||||||
nodes.interrupt_processing(False)
|
nodes.interrupt_processing(False)
|
||||||
|
|
||||||
if "client_id" in extra_data:
|
if "client_id" in extra_data:
|
||||||
|
|||||||
@ -8,6 +8,8 @@ import folder_paths
|
|||||||
import comfy.utils
|
import comfy.utils
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
default_preview_method = args.preview_method
|
||||||
|
|
||||||
MAX_PREVIEW_RESOLUTION = args.preview_size
|
MAX_PREVIEW_RESOLUTION = args.preview_size
|
||||||
VIDEO_TAES = ["taehv", "lighttaew2_2", "lighttaew2_1", "lighttaehy1_5"]
|
VIDEO_TAES = ["taehv", "lighttaew2_2", "lighttaew2_1", "lighttaehy1_5"]
|
||||||
|
|
||||||
@ -125,3 +127,11 @@ def prepare_callback(model, steps, x0_output_dict=None):
|
|||||||
pbar.update_absolute(step + 1, total_steps, preview_bytes)
|
pbar.update_absolute(step + 1, total_steps, preview_bytes)
|
||||||
return callback
|
return callback
|
||||||
|
|
||||||
|
def set_preview_method(override: str = None):
|
||||||
|
if override and override != "default":
|
||||||
|
method = LatentPreviewMethod.from_string(override)
|
||||||
|
if method is not None:
|
||||||
|
args.preview_method = method
|
||||||
|
return
|
||||||
|
args.preview_method = default_preview_method
|
||||||
|
|
||||||
|
|||||||
352
tests-unit/execution_test/preview_method_override_test.py
Normal file
352
tests-unit/execution_test/preview_method_override_test.py
Normal file
@ -0,0 +1,352 @@
|
|||||||
|
"""
|
||||||
|
Unit tests for Queue-specific Preview Method Override feature.
|
||||||
|
|
||||||
|
Tests the preview method override functionality:
|
||||||
|
- LatentPreviewMethod.from_string() method
|
||||||
|
- set_preview_method() function in latent_preview.py
|
||||||
|
- default_preview_method variable
|
||||||
|
- Integration with args.preview_method
|
||||||
|
"""
|
||||||
|
import pytest
|
||||||
|
from comfy.cli_args import args, LatentPreviewMethod
|
||||||
|
from latent_preview import set_preview_method, default_preview_method
|
||||||
|
|
||||||
|
|
||||||
|
class TestLatentPreviewMethodFromString:
|
||||||
|
"""Test LatentPreviewMethod.from_string() classmethod."""
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("value,expected", [
|
||||||
|
("auto", LatentPreviewMethod.Auto),
|
||||||
|
("latent2rgb", LatentPreviewMethod.Latent2RGB),
|
||||||
|
("taesd", LatentPreviewMethod.TAESD),
|
||||||
|
("none", LatentPreviewMethod.NoPreviews),
|
||||||
|
])
|
||||||
|
def test_valid_values_return_enum(self, value, expected):
|
||||||
|
"""Valid string values should return corresponding enum."""
|
||||||
|
assert LatentPreviewMethod.from_string(value) == expected
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("invalid", [
|
||||||
|
"invalid",
|
||||||
|
"TAESD", # Case sensitive
|
||||||
|
"AUTO", # Case sensitive
|
||||||
|
"Latent2RGB", # Case sensitive
|
||||||
|
"latent",
|
||||||
|
"",
|
||||||
|
"default", # default is special, not a method
|
||||||
|
])
|
||||||
|
def test_invalid_values_return_none(self, invalid):
|
||||||
|
"""Invalid string values should return None."""
|
||||||
|
assert LatentPreviewMethod.from_string(invalid) is None
|
||||||
|
|
||||||
|
|
||||||
|
class TestLatentPreviewMethodEnumValues:
|
||||||
|
"""Test LatentPreviewMethod enum has expected values."""
|
||||||
|
|
||||||
|
def test_enum_values(self):
|
||||||
|
"""Verify enum values match expected strings."""
|
||||||
|
assert LatentPreviewMethod.NoPreviews.value == "none"
|
||||||
|
assert LatentPreviewMethod.Auto.value == "auto"
|
||||||
|
assert LatentPreviewMethod.Latent2RGB.value == "latent2rgb"
|
||||||
|
assert LatentPreviewMethod.TAESD.value == "taesd"
|
||||||
|
|
||||||
|
def test_enum_count(self):
|
||||||
|
"""Verify exactly 4 preview methods exist."""
|
||||||
|
assert len(LatentPreviewMethod) == 4
|
||||||
|
|
||||||
|
|
||||||
|
class TestSetPreviewMethod:
|
||||||
|
"""Test set_preview_method() function from latent_preview.py."""
|
||||||
|
|
||||||
|
def setup_method(self):
|
||||||
|
"""Store original value before each test."""
|
||||||
|
self.original = args.preview_method
|
||||||
|
|
||||||
|
def teardown_method(self):
|
||||||
|
"""Restore original value after each test."""
|
||||||
|
args.preview_method = self.original
|
||||||
|
|
||||||
|
def test_override_with_taesd(self):
|
||||||
|
"""'taesd' should set args.preview_method to TAESD."""
|
||||||
|
set_preview_method("taesd")
|
||||||
|
assert args.preview_method == LatentPreviewMethod.TAESD
|
||||||
|
|
||||||
|
def test_override_with_latent2rgb(self):
|
||||||
|
"""'latent2rgb' should set args.preview_method to Latent2RGB."""
|
||||||
|
set_preview_method("latent2rgb")
|
||||||
|
assert args.preview_method == LatentPreviewMethod.Latent2RGB
|
||||||
|
|
||||||
|
def test_override_with_auto(self):
|
||||||
|
"""'auto' should set args.preview_method to Auto."""
|
||||||
|
set_preview_method("auto")
|
||||||
|
assert args.preview_method == LatentPreviewMethod.Auto
|
||||||
|
|
||||||
|
def test_override_with_none_value(self):
|
||||||
|
"""'none' should set args.preview_method to NoPreviews."""
|
||||||
|
set_preview_method("none")
|
||||||
|
assert args.preview_method == LatentPreviewMethod.NoPreviews
|
||||||
|
|
||||||
|
def test_default_restores_original(self):
|
||||||
|
"""'default' should restore to default_preview_method."""
|
||||||
|
# First override to something else
|
||||||
|
set_preview_method("taesd")
|
||||||
|
assert args.preview_method == LatentPreviewMethod.TAESD
|
||||||
|
|
||||||
|
# Then use 'default' to restore
|
||||||
|
set_preview_method("default")
|
||||||
|
assert args.preview_method == default_preview_method
|
||||||
|
|
||||||
|
def test_none_param_restores_original(self):
|
||||||
|
"""None parameter should restore to default_preview_method."""
|
||||||
|
# First override to something else
|
||||||
|
set_preview_method("taesd")
|
||||||
|
assert args.preview_method == LatentPreviewMethod.TAESD
|
||||||
|
|
||||||
|
# Then use None to restore
|
||||||
|
set_preview_method(None)
|
||||||
|
assert args.preview_method == default_preview_method
|
||||||
|
|
||||||
|
def test_empty_string_restores_original(self):
|
||||||
|
"""Empty string should restore to default_preview_method."""
|
||||||
|
set_preview_method("taesd")
|
||||||
|
set_preview_method("")
|
||||||
|
assert args.preview_method == default_preview_method
|
||||||
|
|
||||||
|
def test_invalid_value_restores_original(self):
|
||||||
|
"""Invalid value should restore to default_preview_method."""
|
||||||
|
set_preview_method("taesd")
|
||||||
|
set_preview_method("invalid_method")
|
||||||
|
assert args.preview_method == default_preview_method
|
||||||
|
|
||||||
|
def test_case_sensitive_invalid_restores(self):
|
||||||
|
"""Case-mismatched values should restore to default."""
|
||||||
|
set_preview_method("taesd")
|
||||||
|
set_preview_method("TAESD") # Wrong case
|
||||||
|
assert args.preview_method == default_preview_method
|
||||||
|
|
||||||
|
|
||||||
|
class TestDefaultPreviewMethod:
|
||||||
|
"""Test default_preview_method module variable."""
|
||||||
|
|
||||||
|
def test_default_is_not_none(self):
|
||||||
|
"""default_preview_method should not be None."""
|
||||||
|
assert default_preview_method is not None
|
||||||
|
|
||||||
|
def test_default_is_enum_member(self):
|
||||||
|
"""default_preview_method should be a LatentPreviewMethod enum."""
|
||||||
|
assert isinstance(default_preview_method, LatentPreviewMethod)
|
||||||
|
|
||||||
|
def test_default_matches_args_initial(self):
|
||||||
|
"""default_preview_method should match CLI default or user setting."""
|
||||||
|
# This tests that default_preview_method was captured at module load
|
||||||
|
# After set_preview_method(None), args should equal default
|
||||||
|
original = args.preview_method
|
||||||
|
set_preview_method("taesd")
|
||||||
|
set_preview_method(None)
|
||||||
|
assert args.preview_method == default_preview_method
|
||||||
|
args.preview_method = original
|
||||||
|
|
||||||
|
|
||||||
|
class TestArgsPreviewMethodModification:
|
||||||
|
"""Test args.preview_method can be modified correctly."""
|
||||||
|
|
||||||
|
def setup_method(self):
|
||||||
|
"""Store original value before each test."""
|
||||||
|
self.original = args.preview_method
|
||||||
|
|
||||||
|
def teardown_method(self):
|
||||||
|
"""Restore original value after each test."""
|
||||||
|
args.preview_method = self.original
|
||||||
|
|
||||||
|
def test_args_accepts_all_enum_values(self):
|
||||||
|
"""args.preview_method should accept all LatentPreviewMethod values."""
|
||||||
|
for method in LatentPreviewMethod:
|
||||||
|
args.preview_method = method
|
||||||
|
assert args.preview_method == method
|
||||||
|
|
||||||
|
def test_args_modification_and_restoration(self):
|
||||||
|
"""args.preview_method should be modifiable and restorable."""
|
||||||
|
original = args.preview_method
|
||||||
|
|
||||||
|
args.preview_method = LatentPreviewMethod.TAESD
|
||||||
|
assert args.preview_method == LatentPreviewMethod.TAESD
|
||||||
|
|
||||||
|
args.preview_method = original
|
||||||
|
assert args.preview_method == original
|
||||||
|
|
||||||
|
|
||||||
|
class TestExecutionFlow:
|
||||||
|
"""Test the execution flow pattern used in execution.py."""
|
||||||
|
|
||||||
|
def setup_method(self):
|
||||||
|
"""Store original value before each test."""
|
||||||
|
self.original = args.preview_method
|
||||||
|
|
||||||
|
def teardown_method(self):
|
||||||
|
"""Restore original value after each test."""
|
||||||
|
args.preview_method = self.original
|
||||||
|
|
||||||
|
def test_sequential_executions_with_different_methods(self):
|
||||||
|
"""Simulate multiple queue executions with different preview methods."""
|
||||||
|
# Execution 1: taesd
|
||||||
|
set_preview_method("taesd")
|
||||||
|
assert args.preview_method == LatentPreviewMethod.TAESD
|
||||||
|
|
||||||
|
# Execution 2: none
|
||||||
|
set_preview_method("none")
|
||||||
|
assert args.preview_method == LatentPreviewMethod.NoPreviews
|
||||||
|
|
||||||
|
# Execution 3: default (restore)
|
||||||
|
set_preview_method("default")
|
||||||
|
assert args.preview_method == default_preview_method
|
||||||
|
|
||||||
|
# Execution 4: auto
|
||||||
|
set_preview_method("auto")
|
||||||
|
assert args.preview_method == LatentPreviewMethod.Auto
|
||||||
|
|
||||||
|
# Execution 5: no override (None)
|
||||||
|
set_preview_method(None)
|
||||||
|
assert args.preview_method == default_preview_method
|
||||||
|
|
||||||
|
def test_override_then_default_pattern(self):
|
||||||
|
"""Test the pattern: override -> execute -> next call restores."""
|
||||||
|
# First execution with override
|
||||||
|
set_preview_method("latent2rgb")
|
||||||
|
assert args.preview_method == LatentPreviewMethod.Latent2RGB
|
||||||
|
|
||||||
|
# Second execution without override restores default
|
||||||
|
set_preview_method(None)
|
||||||
|
assert args.preview_method == default_preview_method
|
||||||
|
|
||||||
|
def test_extra_data_simulation(self):
|
||||||
|
"""Simulate extra_data.get('preview_method') patterns."""
|
||||||
|
# Simulate: extra_data = {"preview_method": "taesd"}
|
||||||
|
extra_data = {"preview_method": "taesd"}
|
||||||
|
set_preview_method(extra_data.get("preview_method"))
|
||||||
|
assert args.preview_method == LatentPreviewMethod.TAESD
|
||||||
|
|
||||||
|
# Simulate: extra_data = {}
|
||||||
|
extra_data = {}
|
||||||
|
set_preview_method(extra_data.get("preview_method"))
|
||||||
|
assert args.preview_method == default_preview_method
|
||||||
|
|
||||||
|
# Simulate: extra_data = {"preview_method": "default"}
|
||||||
|
extra_data = {"preview_method": "default"}
|
||||||
|
set_preview_method(extra_data.get("preview_method"))
|
||||||
|
assert args.preview_method == default_preview_method
|
||||||
|
|
||||||
|
|
||||||
|
class TestRealWorldScenarios:
|
||||||
|
"""Tests using real-world prompt data patterns."""
|
||||||
|
|
||||||
|
def setup_method(self):
|
||||||
|
"""Store original value before each test."""
|
||||||
|
self.original = args.preview_method
|
||||||
|
|
||||||
|
def teardown_method(self):
|
||||||
|
"""Restore original value after each test."""
|
||||||
|
args.preview_method = self.original
|
||||||
|
|
||||||
|
def test_captured_prompt_without_preview_method(self):
|
||||||
|
"""
|
||||||
|
Test with captured prompt that has no preview_method.
|
||||||
|
Based on: tests-unit/execution_test/fixtures/default_prompt.json
|
||||||
|
"""
|
||||||
|
# Real captured extra_data structure (preview_method absent)
|
||||||
|
extra_data = {
|
||||||
|
"extra_pnginfo": {"workflow": {}},
|
||||||
|
"client_id": "271314f0dabd48e5aaa488ed7a4ceb0d",
|
||||||
|
"create_time": 1765416558179
|
||||||
|
}
|
||||||
|
|
||||||
|
set_preview_method(extra_data.get("preview_method"))
|
||||||
|
assert args.preview_method == default_preview_method
|
||||||
|
|
||||||
|
def test_captured_prompt_with_preview_method_taesd(self):
|
||||||
|
"""Test captured prompt with preview_method: taesd."""
|
||||||
|
extra_data = {
|
||||||
|
"extra_pnginfo": {"workflow": {}},
|
||||||
|
"client_id": "271314f0dabd48e5aaa488ed7a4ceb0d",
|
||||||
|
"preview_method": "taesd"
|
||||||
|
}
|
||||||
|
|
||||||
|
set_preview_method(extra_data.get("preview_method"))
|
||||||
|
assert args.preview_method == LatentPreviewMethod.TAESD
|
||||||
|
|
||||||
|
def test_captured_prompt_with_preview_method_none(self):
|
||||||
|
"""Test captured prompt with preview_method: none (disable preview)."""
|
||||||
|
extra_data = {
|
||||||
|
"extra_pnginfo": {"workflow": {}},
|
||||||
|
"client_id": "test-client",
|
||||||
|
"preview_method": "none"
|
||||||
|
}
|
||||||
|
|
||||||
|
set_preview_method(extra_data.get("preview_method"))
|
||||||
|
assert args.preview_method == LatentPreviewMethod.NoPreviews
|
||||||
|
|
||||||
|
def test_captured_prompt_with_preview_method_latent2rgb(self):
|
||||||
|
"""Test captured prompt with preview_method: latent2rgb."""
|
||||||
|
extra_data = {
|
||||||
|
"extra_pnginfo": {"workflow": {}},
|
||||||
|
"client_id": "test-client",
|
||||||
|
"preview_method": "latent2rgb"
|
||||||
|
}
|
||||||
|
|
||||||
|
set_preview_method(extra_data.get("preview_method"))
|
||||||
|
assert args.preview_method == LatentPreviewMethod.Latent2RGB
|
||||||
|
|
||||||
|
def test_captured_prompt_with_preview_method_auto(self):
|
||||||
|
"""Test captured prompt with preview_method: auto."""
|
||||||
|
extra_data = {
|
||||||
|
"extra_pnginfo": {"workflow": {}},
|
||||||
|
"client_id": "test-client",
|
||||||
|
"preview_method": "auto"
|
||||||
|
}
|
||||||
|
|
||||||
|
set_preview_method(extra_data.get("preview_method"))
|
||||||
|
assert args.preview_method == LatentPreviewMethod.Auto
|
||||||
|
|
||||||
|
def test_captured_prompt_with_preview_method_default(self):
|
||||||
|
"""Test captured prompt with preview_method: default (use CLI setting)."""
|
||||||
|
# First set to something else
|
||||||
|
set_preview_method("taesd")
|
||||||
|
assert args.preview_method == LatentPreviewMethod.TAESD
|
||||||
|
|
||||||
|
# Then simulate a prompt with "default"
|
||||||
|
extra_data = {
|
||||||
|
"extra_pnginfo": {"workflow": {}},
|
||||||
|
"client_id": "test-client",
|
||||||
|
"preview_method": "default"
|
||||||
|
}
|
||||||
|
|
||||||
|
set_preview_method(extra_data.get("preview_method"))
|
||||||
|
assert args.preview_method == default_preview_method
|
||||||
|
|
||||||
|
def test_sequential_queue_with_different_preview_methods(self):
|
||||||
|
"""
|
||||||
|
Simulate real queue scenario: multiple prompts with different settings.
|
||||||
|
This tests the actual usage pattern in ComfyUI.
|
||||||
|
"""
|
||||||
|
# Queue 1: User wants TAESD preview
|
||||||
|
extra_data_1 = {"client_id": "client-1", "preview_method": "taesd"}
|
||||||
|
set_preview_method(extra_data_1.get("preview_method"))
|
||||||
|
assert args.preview_method == LatentPreviewMethod.TAESD
|
||||||
|
|
||||||
|
# Queue 2: User wants no preview (faster execution)
|
||||||
|
extra_data_2 = {"client_id": "client-2", "preview_method": "none"}
|
||||||
|
set_preview_method(extra_data_2.get("preview_method"))
|
||||||
|
assert args.preview_method == LatentPreviewMethod.NoPreviews
|
||||||
|
|
||||||
|
# Queue 3: User doesn't specify (use server default)
|
||||||
|
extra_data_3 = {"client_id": "client-3"}
|
||||||
|
set_preview_method(extra_data_3.get("preview_method"))
|
||||||
|
assert args.preview_method == default_preview_method
|
||||||
|
|
||||||
|
# Queue 4: User explicitly wants default
|
||||||
|
extra_data_4 = {"client_id": "client-4", "preview_method": "default"}
|
||||||
|
set_preview_method(extra_data_4.get("preview_method"))
|
||||||
|
assert args.preview_method == default_preview_method
|
||||||
|
|
||||||
|
# Queue 5: User wants latent2rgb
|
||||||
|
extra_data_5 = {"client_id": "client-5", "preview_method": "latent2rgb"}
|
||||||
|
set_preview_method(extra_data_5.get("preview_method"))
|
||||||
|
assert args.preview_method == LatentPreviewMethod.Latent2RGB
|
||||||
358
tests/execution/test_preview_method.py
Normal file
358
tests/execution/test_preview_method.py
Normal file
@ -0,0 +1,358 @@
|
|||||||
|
"""
|
||||||
|
E2E tests for Queue-specific Preview Method Override feature.
|
||||||
|
|
||||||
|
Tests actual execution with different preview_method values.
|
||||||
|
Requires a running ComfyUI server with models.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
COMFYUI_SERVER=http://localhost:8988 pytest test_preview_method_e2e.py -v -m preview_method
|
||||||
|
|
||||||
|
Note:
|
||||||
|
These tests execute actual image generation and wait for completion.
|
||||||
|
Tests verify preview image transmission based on preview_method setting.
|
||||||
|
"""
|
||||||
|
import os
|
||||||
|
import json
|
||||||
|
import pytest
|
||||||
|
import uuid
|
||||||
|
import time
|
||||||
|
import random
|
||||||
|
import websocket
|
||||||
|
import urllib.request
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
|
||||||
|
# Server configuration
|
||||||
|
SERVER_URL = os.environ.get("COMFYUI_SERVER", "http://localhost:8988")
|
||||||
|
SERVER_HOST = SERVER_URL.replace("http://", "").replace("https://", "")
|
||||||
|
|
||||||
|
# Use existing inference graph fixture
|
||||||
|
GRAPH_FILE = Path(__file__).parent.parent / "inference" / "graphs" / "default_graph_sdxl1_0.json"
|
||||||
|
|
||||||
|
|
||||||
|
def is_server_running() -> bool:
|
||||||
|
"""Check if ComfyUI server is running."""
|
||||||
|
try:
|
||||||
|
request = urllib.request.Request(f"{SERVER_URL}/system_stats")
|
||||||
|
with urllib.request.urlopen(request, timeout=2.0):
|
||||||
|
return True
|
||||||
|
except Exception:
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_graph_for_test(graph: dict, steps: int = 5) -> dict:
|
||||||
|
"""Prepare graph for testing: randomize seeds and reduce steps."""
|
||||||
|
adapted = json.loads(json.dumps(graph)) # Deep copy
|
||||||
|
for node_id, node in adapted.items():
|
||||||
|
inputs = node.get("inputs", {})
|
||||||
|
# Handle both "seed" and "noise_seed" (used by KSamplerAdvanced)
|
||||||
|
if "seed" in inputs:
|
||||||
|
inputs["seed"] = random.randint(0, 2**32 - 1)
|
||||||
|
if "noise_seed" in inputs:
|
||||||
|
inputs["noise_seed"] = random.randint(0, 2**32 - 1)
|
||||||
|
# Reduce steps for faster testing (default 20 -> 5)
|
||||||
|
if "steps" in inputs:
|
||||||
|
inputs["steps"] = steps
|
||||||
|
return adapted
|
||||||
|
|
||||||
|
|
||||||
|
# Alias for backward compatibility
|
||||||
|
randomize_seed = prepare_graph_for_test
|
||||||
|
|
||||||
|
|
||||||
|
class PreviewMethodClient:
|
||||||
|
"""Client for testing preview_method with WebSocket execution tracking."""
|
||||||
|
|
||||||
|
def __init__(self, server_address: str):
|
||||||
|
self.server_address = server_address
|
||||||
|
self.client_id = str(uuid.uuid4())
|
||||||
|
self.ws = None
|
||||||
|
|
||||||
|
def connect(self):
|
||||||
|
"""Connect to WebSocket."""
|
||||||
|
self.ws = websocket.WebSocket()
|
||||||
|
self.ws.settimeout(120) # 2 minute timeout for sampling
|
||||||
|
self.ws.connect(f"ws://{self.server_address}/ws?clientId={self.client_id}")
|
||||||
|
|
||||||
|
def close(self):
|
||||||
|
"""Close WebSocket connection."""
|
||||||
|
if self.ws:
|
||||||
|
self.ws.close()
|
||||||
|
|
||||||
|
def queue_prompt(self, prompt: dict, extra_data: dict = None) -> dict:
|
||||||
|
"""Queue a prompt and return response with prompt_id."""
|
||||||
|
data = {
|
||||||
|
"prompt": prompt,
|
||||||
|
"client_id": self.client_id,
|
||||||
|
"extra_data": extra_data or {}
|
||||||
|
}
|
||||||
|
req = urllib.request.Request(
|
||||||
|
f"http://{self.server_address}/prompt",
|
||||||
|
data=json.dumps(data).encode("utf-8"),
|
||||||
|
headers={"Content-Type": "application/json"}
|
||||||
|
)
|
||||||
|
return json.loads(urllib.request.urlopen(req).read())
|
||||||
|
|
||||||
|
def wait_for_execution(self, prompt_id: str, timeout: float = 120.0) -> dict:
|
||||||
|
"""
|
||||||
|
Wait for execution to complete via WebSocket.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict with keys: completed, error, preview_count, execution_time
|
||||||
|
"""
|
||||||
|
result = {
|
||||||
|
"completed": False,
|
||||||
|
"error": None,
|
||||||
|
"preview_count": 0,
|
||||||
|
"execution_time": 0.0
|
||||||
|
}
|
||||||
|
|
||||||
|
start_time = time.time()
|
||||||
|
self.ws.settimeout(timeout)
|
||||||
|
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
out = self.ws.recv()
|
||||||
|
elapsed = time.time() - start_time
|
||||||
|
|
||||||
|
if isinstance(out, str):
|
||||||
|
message = json.loads(out)
|
||||||
|
msg_type = message.get("type")
|
||||||
|
data = message.get("data", {})
|
||||||
|
|
||||||
|
if data.get("prompt_id") != prompt_id:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if msg_type == "executing":
|
||||||
|
if data.get("node") is None:
|
||||||
|
# Execution complete
|
||||||
|
result["completed"] = True
|
||||||
|
result["execution_time"] = elapsed
|
||||||
|
break
|
||||||
|
|
||||||
|
elif msg_type == "execution_error":
|
||||||
|
result["error"] = data
|
||||||
|
result["execution_time"] = elapsed
|
||||||
|
break
|
||||||
|
|
||||||
|
elif msg_type == "progress":
|
||||||
|
# Progress update during sampling
|
||||||
|
pass
|
||||||
|
|
||||||
|
elif isinstance(out, bytes):
|
||||||
|
# Binary data = preview image
|
||||||
|
result["preview_count"] += 1
|
||||||
|
|
||||||
|
except websocket.WebSocketTimeoutException:
|
||||||
|
result["error"] = "Timeout waiting for execution"
|
||||||
|
result["execution_time"] = time.time() - start_time
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def load_graph() -> dict:
|
||||||
|
"""Load the SDXL graph fixture with randomized seed."""
|
||||||
|
with open(GRAPH_FILE) as f:
|
||||||
|
graph = json.load(f)
|
||||||
|
return randomize_seed(graph) # Avoid caching
|
||||||
|
|
||||||
|
|
||||||
|
# Skip all tests if server is not running
|
||||||
|
pytestmark = [
|
||||||
|
pytest.mark.skipif(
|
||||||
|
not is_server_running(),
|
||||||
|
reason=f"ComfyUI server not running at {SERVER_URL}"
|
||||||
|
),
|
||||||
|
pytest.mark.preview_method,
|
||||||
|
pytest.mark.execution,
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def client():
|
||||||
|
"""Create and connect a test client."""
|
||||||
|
c = PreviewMethodClient(SERVER_HOST)
|
||||||
|
c.connect()
|
||||||
|
yield c
|
||||||
|
c.close()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def graph():
|
||||||
|
"""Load the test graph."""
|
||||||
|
return load_graph()
|
||||||
|
|
||||||
|
|
||||||
|
class TestPreviewMethodExecution:
|
||||||
|
"""Test actual execution with different preview methods."""
|
||||||
|
|
||||||
|
def test_execution_with_latent2rgb(self, client, graph):
|
||||||
|
"""
|
||||||
|
Execute with preview_method=latent2rgb.
|
||||||
|
Should complete and potentially receive preview images.
|
||||||
|
"""
|
||||||
|
extra_data = {"preview_method": "latent2rgb"}
|
||||||
|
|
||||||
|
response = client.queue_prompt(graph, extra_data)
|
||||||
|
assert "prompt_id" in response
|
||||||
|
|
||||||
|
result = client.wait_for_execution(response["prompt_id"])
|
||||||
|
|
||||||
|
# Should complete (may error if model missing, but that's separate)
|
||||||
|
assert result["completed"] or result["error"] is not None
|
||||||
|
# Execution should take some time (sampling)
|
||||||
|
if result["completed"]:
|
||||||
|
assert result["execution_time"] > 0.5, "Execution too fast - likely didn't run"
|
||||||
|
# latent2rgb should produce previews
|
||||||
|
print(f"latent2rgb: {result['preview_count']} previews in {result['execution_time']:.2f}s") # noqa: T201
|
||||||
|
|
||||||
|
def test_execution_with_taesd(self, client, graph):
|
||||||
|
"""
|
||||||
|
Execute with preview_method=taesd.
|
||||||
|
TAESD provides higher quality previews.
|
||||||
|
"""
|
||||||
|
extra_data = {"preview_method": "taesd"}
|
||||||
|
|
||||||
|
response = client.queue_prompt(graph, extra_data)
|
||||||
|
assert "prompt_id" in response
|
||||||
|
|
||||||
|
result = client.wait_for_execution(response["prompt_id"])
|
||||||
|
|
||||||
|
assert result["completed"] or result["error"] is not None
|
||||||
|
if result["completed"]:
|
||||||
|
assert result["execution_time"] > 0.5
|
||||||
|
# taesd should also produce previews
|
||||||
|
print(f"taesd: {result['preview_count']} previews in {result['execution_time']:.2f}s") # noqa: T201
|
||||||
|
|
||||||
|
def test_execution_with_none_preview(self, client, graph):
|
||||||
|
"""
|
||||||
|
Execute with preview_method=none.
|
||||||
|
No preview images should be generated.
|
||||||
|
"""
|
||||||
|
extra_data = {"preview_method": "none"}
|
||||||
|
|
||||||
|
response = client.queue_prompt(graph, extra_data)
|
||||||
|
assert "prompt_id" in response
|
||||||
|
|
||||||
|
result = client.wait_for_execution(response["prompt_id"])
|
||||||
|
|
||||||
|
assert result["completed"] or result["error"] is not None
|
||||||
|
if result["completed"]:
|
||||||
|
# With "none", should receive no preview images
|
||||||
|
assert result["preview_count"] == 0, \
|
||||||
|
f"Expected no previews with 'none', got {result['preview_count']}"
|
||||||
|
print(f"none: {result['preview_count']} previews in {result['execution_time']:.2f}s") # noqa: T201
|
||||||
|
|
||||||
|
def test_execution_with_default(self, client, graph):
|
||||||
|
"""
|
||||||
|
Execute with preview_method=default.
|
||||||
|
Should use server's CLI default setting.
|
||||||
|
"""
|
||||||
|
extra_data = {"preview_method": "default"}
|
||||||
|
|
||||||
|
response = client.queue_prompt(graph, extra_data)
|
||||||
|
assert "prompt_id" in response
|
||||||
|
|
||||||
|
result = client.wait_for_execution(response["prompt_id"])
|
||||||
|
|
||||||
|
assert result["completed"] or result["error"] is not None
|
||||||
|
if result["completed"]:
|
||||||
|
print(f"default: {result['preview_count']} previews in {result['execution_time']:.2f}s") # noqa: T201
|
||||||
|
|
||||||
|
def test_execution_without_preview_method(self, client, graph):
|
||||||
|
"""
|
||||||
|
Execute without preview_method in extra_data.
|
||||||
|
Should use server's default preview method.
|
||||||
|
"""
|
||||||
|
extra_data = {} # No preview_method
|
||||||
|
|
||||||
|
response = client.queue_prompt(graph, extra_data)
|
||||||
|
assert "prompt_id" in response
|
||||||
|
|
||||||
|
result = client.wait_for_execution(response["prompt_id"])
|
||||||
|
|
||||||
|
assert result["completed"] or result["error"] is not None
|
||||||
|
if result["completed"]:
|
||||||
|
print(f"(no override): {result['preview_count']} previews in {result['execution_time']:.2f}s") # noqa: T201
|
||||||
|
|
||||||
|
|
||||||
|
class TestPreviewMethodComparison:
|
||||||
|
"""Compare preview behavior between different methods."""
|
||||||
|
|
||||||
|
def test_none_vs_latent2rgb_preview_count(self, client, graph):
|
||||||
|
"""
|
||||||
|
Compare preview counts: 'none' should have 0, others should have >0.
|
||||||
|
This is the key verification that preview_method actually works.
|
||||||
|
"""
|
||||||
|
results = {}
|
||||||
|
|
||||||
|
# Run with none (randomize seed to avoid caching)
|
||||||
|
graph_none = randomize_seed(graph)
|
||||||
|
extra_data_none = {"preview_method": "none"}
|
||||||
|
response = client.queue_prompt(graph_none, extra_data_none)
|
||||||
|
results["none"] = client.wait_for_execution(response["prompt_id"])
|
||||||
|
|
||||||
|
# Run with latent2rgb (randomize seed again)
|
||||||
|
graph_rgb = randomize_seed(graph)
|
||||||
|
extra_data_rgb = {"preview_method": "latent2rgb"}
|
||||||
|
response = client.queue_prompt(graph_rgb, extra_data_rgb)
|
||||||
|
results["latent2rgb"] = client.wait_for_execution(response["prompt_id"])
|
||||||
|
|
||||||
|
# Verify both completed
|
||||||
|
assert results["none"]["completed"], f"'none' execution failed: {results['none']['error']}"
|
||||||
|
assert results["latent2rgb"]["completed"], f"'latent2rgb' execution failed: {results['latent2rgb']['error']}"
|
||||||
|
|
||||||
|
# Key assertion: 'none' should have 0 previews
|
||||||
|
assert results["none"]["preview_count"] == 0, \
|
||||||
|
f"'none' should have 0 previews, got {results['none']['preview_count']}"
|
||||||
|
|
||||||
|
# 'latent2rgb' should have at least 1 preview (depends on steps)
|
||||||
|
assert results["latent2rgb"]["preview_count"] > 0, \
|
||||||
|
f"'latent2rgb' should have >0 previews, got {results['latent2rgb']['preview_count']}"
|
||||||
|
|
||||||
|
print("\nPreview count comparison:") # noqa: T201
|
||||||
|
print(f" none: {results['none']['preview_count']} previews") # noqa: T201
|
||||||
|
print(f" latent2rgb: {results['latent2rgb']['preview_count']} previews") # noqa: T201
|
||||||
|
|
||||||
|
|
||||||
|
class TestPreviewMethodSequential:
|
||||||
|
"""Test sequential execution with different preview methods."""
|
||||||
|
|
||||||
|
def test_sequential_different_methods(self, client, graph):
|
||||||
|
"""
|
||||||
|
Execute multiple prompts sequentially with different preview methods.
|
||||||
|
Each should complete independently with correct preview behavior.
|
||||||
|
"""
|
||||||
|
methods = ["latent2rgb", "none", "default"]
|
||||||
|
results = []
|
||||||
|
|
||||||
|
for method in methods:
|
||||||
|
# Randomize seed for each execution to avoid caching
|
||||||
|
graph_run = randomize_seed(graph)
|
||||||
|
extra_data = {"preview_method": method}
|
||||||
|
response = client.queue_prompt(graph_run, extra_data)
|
||||||
|
|
||||||
|
result = client.wait_for_execution(response["prompt_id"])
|
||||||
|
results.append({
|
||||||
|
"method": method,
|
||||||
|
"completed": result["completed"],
|
||||||
|
"preview_count": result["preview_count"],
|
||||||
|
"execution_time": result["execution_time"],
|
||||||
|
"error": result["error"]
|
||||||
|
})
|
||||||
|
|
||||||
|
# All should complete or have clear errors
|
||||||
|
for r in results:
|
||||||
|
assert r["completed"] or r["error"] is not None, \
|
||||||
|
f"Method {r['method']} neither completed nor errored"
|
||||||
|
|
||||||
|
# "none" should have zero previews if completed
|
||||||
|
none_result = next(r for r in results if r["method"] == "none")
|
||||||
|
if none_result["completed"]:
|
||||||
|
assert none_result["preview_count"] == 0, \
|
||||||
|
f"'none' should have 0 previews, got {none_result['preview_count']}"
|
||||||
|
|
||||||
|
print("\nSequential execution results:") # noqa: T201
|
||||||
|
for r in results:
|
||||||
|
status = "✓" if r["completed"] else f"✗ ({r['error']})"
|
||||||
|
print(f" {r['method']}: {status}, {r['preview_count']} previews, {r['execution_time']:.2f}s") # noqa: T201
|
||||||
Loading…
Reference in New Issue
Block a user