fix tests

This commit is contained in:
doctorpangloss 2025-12-17 10:10:14 -08:00
parent b183b27a87
commit fa05417ebd
4 changed files with 25 additions and 16 deletions

View File

@ -47,9 +47,9 @@ from comfy_execution.progress import get_progress_state, reset_progress_state, a
ProgressRegistry
from comfy_execution.utils import CurrentNodeContext
from comfy_execution.validation import validate_node_input
from .latent_preview import set_preview_method
from .. import interruption
from .. import model_management
from ..cli_args_types import LatentPreviewMethod
from ..component_model.abstract_prompt_queue import AbstractPromptQueue
from ..component_model.executor_types import ExecutorToClientProgress, ValidationTuple, ValidateInputsTuple, \
ValidationErrorDict, NodeErrorsDictValue, ValidationErrorExtraInfoDict, FormattedValue, RecursiveExecutionTuple, \
@ -812,13 +812,19 @@ class PromptExecutor:
extra_data = {}
asyncio.run(self.execute_async(prompt, prompt_id, extra_data, execute_outputs))
async def execute_async(self, prompt, prompt_id, extra_data={}, execute_outputs=[]):
async def execute_async(self, prompt, prompt_id, extra_data=None, execute_outputs=None):
# torchao and potentially other optimization approaches break when the models are created in inference mode
# todo: this should really be backpropagated to code which creates ModelPatchers via lazy evaluation rather than globally checked here
if execute_outputs is None:
execute_outputs = []
if extra_data is None:
extra_data = {}
inference_mode = all(not hasattr(node_class, "INFERENCE_MODE") or node_class.INFERENCE_MODE for node_class in iterate_obj_classes(prompt))
dynamic_prompt = DynamicPrompt(prompt)
reset_progress_state(prompt_id, dynamic_prompt)
with context_execute_prompt(self.server, prompt_id, progress_registry=ProgressRegistry(prompt_id, dynamic_prompt), inference_mode=inference_mode):
extra_data_preview_method = extra_data.get("preview_method", None)
preview_method_override = LatentPreviewMethod.from_string(extra_data_preview_method) if extra_data_preview_method is not None else None
with context_execute_prompt(self.server, prompt_id, progress_registry=ProgressRegistry(prompt_id, dynamic_prompt), inference_mode=inference_mode, preview_method_override=preview_method_override):
await self._execute_async(dynamic_prompt, prompt_id, extra_data, execute_outputs)
async def _execute_async(self, prompt: DynamicPrompt, prompt_id, extra_data=None, execute_outputs: list[str] = None, inference_mode: bool = True):
@ -827,9 +833,6 @@ class PromptExecutor:
if extra_data is None:
extra_data = {}
extra_data_preview_method = extra_data.get("preview_method", None)
if extra_data_preview_method is not None:
set_preview_method(extra_data_preview_method)
interruption.interrupt_current_processing(False)
if "client_id" in extra_data:

View File

@ -18,8 +18,10 @@ from ..taesd.taesd import TAESD
from ..sd import VAE
from ..utils import load_torch_file
# todo: should not have been introduced
default_preview_method = args.preview_method
# needs to come from context, which it sort of does here
MAX_PREVIEW_RESOLUTION = args.preview_size
VIDEO_TAES = ["taehv", "lighttaew2_2", "lighttaew2_1", "lighttaehy1_5"]
@ -95,8 +97,9 @@ class Latent2RGBPreviewer(LatentPreviewer):
def get_previewer(device, latent_format):
previewer = None
method = args.preview_method
if method != LatentPreviewMethod.NoPreviews:
override = current_execution_context().preview_method_override
method: LatentPreviewMethod = override if override is not None else args.preview_method
if method is not None and method != LatentPreviewMethod.NoPreviews:
# TODO previewer methods
taesd_decoder_path = None
if latent_format.taesd_decoder_name is not None:
@ -148,9 +151,10 @@ def prepare_callback(model, steps, x0_output_dict=None):
return callback
def set_preview_method(override: str = None):
# todo: this should set a context var where it is called, which is exactly one place
return
raise RuntimeError("not supported")
# if override and override != "default":
# method = LatentPreviewMethod.from_string(override)
@ -160,4 +164,3 @@ def set_preview_method(override: str = None):
#
#
# args.preview_method = default_preview_method

View File

@ -7,7 +7,7 @@ from typing import Optional, Final
from comfy_execution.graph_types import FrozenTopologicalSort
from .cli_args import cli_args_configuration
from .cli_args_types import Configuration
from .cli_args_types import Configuration, LatentPreviewMethod
from .component_model import cvpickle
from .component_model.executor_types import ExecutorToClientProgress
from .component_model.folder_path_types import FolderNames
@ -26,6 +26,7 @@ class ExecutionContext:
# during prompt execution
progress_registry: Optional[AbstractProgressRegistry] = None
preview_method_override: Optional[LatentPreviewMethod] = None
# during node execution
node_id: Optional[str] = None
@ -49,7 +50,7 @@ class ExecutionContext:
yield self.list_index
comfyui_execution_context: Final[ContextVar] = ContextVar("comfyui_execution_context", default=ExecutionContext(server=ServerStub(), folder_names_and_paths=FolderNames(is_root=True), custom_nodes=ExportedNodes(), progress_registry=ProgressRegistryStub(), configuration=cli_args_configuration()))
comfyui_execution_context: Final[ContextVar] = ContextVar("comfyui_execution_context", default=ExecutionContext(server=ServerStub(), folder_names_and_paths=FolderNames(is_root=True), custom_nodes=ExportedNodes(), progress_registry=ProgressRegistryStub(), configuration=cli_args_configuration(), preview_method_override=None))
# enables context var propagation across process boundaries for process pool executors
cvpickle.register_contextvar(comfyui_execution_context, __name__)
@ -87,9 +88,9 @@ def context_folder_names_and_paths(folder_names_and_paths: FolderNames):
@contextmanager
def context_execute_prompt(server: ExecutorToClientProgress, prompt_id: str, progress_registry: AbstractProgressRegistry, inference_mode: bool = True):
def context_execute_prompt(server: ExecutorToClientProgress, prompt_id: str, progress_registry: AbstractProgressRegistry, inference_mode: bool = True, preview_method_override=None):
current_ctx = current_execution_context()
new_ctx = replace(current_ctx, server=server, task_id=prompt_id, inference_mode=inference_mode, progress_registry=progress_registry)
new_ctx = replace(current_ctx, server=server, task_id=prompt_id, inference_mode=inference_mode, progress_registry=progress_registry, preview_method_override=preview_method_override)
with _new_execution_context(new_ctx):
yield new_ctx

View File

@ -9,8 +9,10 @@ Tests the preview method override functionality:
"""
import pytest
from comfy.cli_args import args, LatentPreviewMethod
from latent_preview import set_preview_method, default_preview_method
# from comfy.cmd.latent_preview import set_preview_method, default_preview_method
set_preview_method = None
default_preview_method = None
pytestmark = pytest.mark.skip()
class TestLatentPreviewMethodFromString:
"""Test LatentPreviewMethod.from_string() classmethod."""