Mirror of https://github.com/comfyanonymous/ComfyUI.git
Fix inference mode execution issues
parent a38968f098
commit f3da381869
@@ -24,7 +24,7 @@ from ..component_model.abstract_prompt_queue import AbstractPromptQueue
 from ..component_model.executor_types import ExecutorToClientProgress, ValidationTuple, ValidateInputsTuple, \
     ValidationErrorDict, NodeErrorsDictValue, ValidationErrorExtraInfoDict, FormattedValue, RecursiveExecutionTuple, \
     RecursiveExecutionErrorDetails, RecursiveExecutionErrorDetailsInterrupted, ExecutionResult, DuplicateNodeError, \
-    HistoryResultDict
+    HistoryResultDict, ExecutionErrorMessage, ExecutionInterruptedMessage
 from ..component_model.files import canonicalize_path
 from ..component_model.queue_types import QueueTuple, HistoryEntry, QueueItem, MAXIMUM_HISTORY_SIZE, ExecutionStatus
 from ..execution_context import new_execution_context, context_execute_node, ExecutionContext
@@ -311,6 +311,7 @@ def execute(server: ExecutorToClientProgress, dynprompt: DynamicPrompt, caches,
     with context_execute_node(_node_id, prompt_id):
         return _execute(server, dynprompt, caches, _node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results)

+
 def _execute(server, dynprompt, caches, current_item: str, extra_data, executed, prompt_id, execution_list, pending_subgraph_results) -> RecursiveExecutionTuple:
     unique_id = current_item
     real_node_id = dynprompt.get_real_node_id(unique_id)
@@ -368,11 +369,11 @@ def _execute(server, dynprompt, caches, current_item: str, extra_data, executed,
                 if len(required_inputs) > 0:
                     for i in required_inputs:
                         execution_list.make_input_strong_link(unique_id, i)
-                    return (ExecutionResult.PENDING, None, None)
+                    return RecursiveExecutionTuple(ExecutionResult.PENDING, None, None)

         def execution_block_cb(block):
             if block.message is not None:
-                mes = {
+                mes: ExecutionErrorMessage = {
                     "prompt_id": prompt_id,
                     "node_id": unique_id,
                     "node_type": class_type,
@@ -442,7 +443,7 @@ def _execute(server, dynprompt, caches, current_item: str, extra_data, executed,
                 for link in new_output_links:
                     execution_list.add_strong_link(link[0], link[1], unique_id)
                 pending_subgraph_results[unique_id] = cached_outputs
-                return (ExecutionResult.PENDING, None, None)
+                return RecursiveExecutionTuple(ExecutionResult.PENDING, None, None)
             caches.outputs.set(unique_id, output_data)
     except interruption.InterruptProcessingException as iex:
         logging.info("Processing interrupted")
@@ -481,7 +482,7 @@ def _execute(server, dynprompt, caches, current_item: str, extra_data, executed,

     executed.add(unique_id)

-    return ExecutionResult.SUCCESS, None, None
+    return RecursiveExecutionTuple(ExecutionResult.SUCCESS, None, None)


 class PromptExecutor:
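A note on the bare-tuple replacements above: RecursiveExecutionTuple is imported from component_model/executor_types.py and its field names are not visible in this diff. A minimal sketch of why a NamedTuple return type reads better than an anonymous 3-tuple; the field names and enum values below are hypothetical stand-ins, not the project's actual definition:

    from enum import Enum
    from typing import Any, NamedTuple, Optional

    class ExecutionResult(Enum):
        # stand-in for the real enum; SUCCESS and PENDING are the members used in the diff
        SUCCESS = 0
        PENDING = 1

    class RecursiveExecutionTuple(NamedTuple):
        # hypothetical field names, for illustration only
        result: ExecutionResult
        error_details: Optional[Any] = None
        exception: Optional[Any] = None

    # Callers can still unpack it like a plain tuple...
    result, error_details, ex = RecursiveExecutionTuple(ExecutionResult.SUCCESS, None, None)
    # ...or read named fields, which is what the annotated return type buys over `return (x, None, None)`.
    assert RecursiveExecutionTuple(ExecutionResult.PENDING, None, None).result is ExecutionResult.PENDING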
@@ -519,7 +520,7 @@ class PromptExecutor:
         # First, send back the status to the frontend depending
         # on the exception type
         if isinstance(ex, interruption.InterruptProcessingException):
-            mes = {
+            mes: ExecutionInterruptedMessage = {
                 "prompt_id": prompt_id,
                 "node_id": node_id,
                 "node_type": class_type,
@@ -527,7 +528,7 @@ class PromptExecutor:
             }
             self.add_message("execution_interrupted", mes, broadcast=True)
         else:
-            mes = {
+            mes: ExecutionErrorMessage = {
                 "prompt_id": prompt_id,
                 "node_id": node_id,
                 "node_type": class_type,
@@ -544,10 +545,13 @@ class PromptExecutor:
         raise ex

     def execute(self, prompt, prompt_id, extra_data=None, execute_outputs: List[str] = None):
-        with new_execution_context(ExecutionContext(self.server, task_id=prompt_id)):
+        # torchao and potentially other optimization approaches break when the models are created in inference mode
+        # todo: this should really be backpropagated to code which creates ModelPatchers via lazy evaluation rather than globally checked here
+        inference_mode = all(not hasattr(node_class, "INFERENCE_MODE") or node_class.INFERENCE_MODE for node_class in iterate_obj_classes(prompt))
+        with new_execution_context(ExecutionContext(self.server, task_id=prompt_id, inference_mode=inference_mode)):
             self._execute_inner(prompt, prompt_id, extra_data, execute_outputs)

-    def _execute_inner(self, prompt, prompt_id, extra_data=None, execute_outputs: List[str] = None):
+    def _execute_inner(self, prompt, prompt_id, extra_data=None, execute_outputs: List[str] = None, inference_mode: bool = True):
         if execute_outputs is None:
             execute_outputs = []
         if extra_data is None:
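The new all(...) check lets any node class opt a whole workflow out of inference mode by declaring INFERENCE_MODE = False on the class. A small self-contained sketch of that rule, assuming iterate_obj_classes yields the node classes referenced by the prompt (the helper and example node classes below are stand-ins, not part of the codebase):

    class OrdinaryNode:
        pass  # no attribute: treated as inference-mode friendly

    class TorchAOQuantizeNode:
        # hypothetical example of a node whose models must not be created under torch.inference_mode()
        INFERENCE_MODE = False

    def workflow_runs_in_inference_mode(node_classes) -> bool:
        # inference mode is used only when every node either omits the attribute or sets it truthy
        return all(not hasattr(cls, "INFERENCE_MODE") or cls.INFERENCE_MODE for cls in node_classes)

    assert workflow_runs_in_inference_mode([OrdinaryNode]) is True
    assert workflow_runs_in_inference_mode([OrdinaryNode, TorchAOQuantizeNode]) is False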
@@ -562,7 +566,7 @@ class PromptExecutor:
         self.status_messages = []
         self.add_message("execution_start", {"prompt_id": prompt_id}, broadcast=False)

-        with torch.inference_mode() if all(not hasattr(node_class, "INFERENCE_MODE") or node_class.INFERENCE_MODE for node_class in iterate_obj_classes(prompt)) else nullcontext():
+        with torch.inference_mode() if inference_mode else nullcontext():
             dynamic_prompt = DynamicPrompt(prompt)
             is_changed_cache = IsChangedCache(dynamic_prompt, self.caches.outputs)
             for cache in self.caches.all:
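The executor now selects its context manager from the precomputed flag instead of re-evaluating the node classes inline. A sketch of the same torch.inference_mode-or-nullcontext pattern in isolation (requires torch; the function name is illustrative):

    from contextlib import nullcontext

    import torch

    def build_tensor(inference_mode: bool) -> torch.Tensor:
        # torch.inference_mode() skips autograd and version-counter tracking for tensors created inside it;
        # nullcontext() keeps normal semantics for workflows that opt out (e.g. torchao-style model rewrites).
        with torch.inference_mode() if inference_mode else nullcontext():
            return torch.zeros(4)

    assert build_tensor(True).is_inference()
    assert not build_tensor(False).is_inference()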
@@ -50,6 +50,13 @@ class UnencodedPreviewImageMessage(NamedTuple):
     task_id: str = ""


+class ExecutionInterruptedMessage(TypedDict):
+    prompt_id: str
+    node_id: str
+    node_type: str
+    executed: list[str]
+
+
 class ExecutionErrorMessage(TypedDict):
     prompt_id: str
     node_id: str
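Annotating the message literals with these TypedDicts is what lets a static type checker flag missing or misspelled keys at the call sites changed in execution.py, while the runtime objects stay plain dicts. A small self-contained version of the same idea (the field set is copied from the diff; the helper function name is illustrative):

    from typing import TypedDict

    class ExecutionInterruptedMessage(TypedDict):
        prompt_id: str
        node_id: str
        node_type: str
        executed: list[str]

    def interrupted_message(prompt_id: str, node_id: str, node_type: str, executed: set[str]) -> ExecutionInterruptedMessage:
        # the annotation on the literal is what mypy/pyright check; at runtime `mes` is still a dict
        mes: ExecutionInterruptedMessage = {
            "prompt_id": prompt_id,
            "node_id": node_id,
            "node_type": node_type,
            "executed": list(executed),
        }
        return mes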
@@ -74,7 +81,7 @@ ExecutedMessage = ExecutingMessage

 SendSyncEvent = Union[Literal["status", "execution_error", "executing", "progress", "executed"], BinaryEventTypes, None]

-SendSyncData = Union[StatusMessage, ExecutingMessage, ExecutionErrorMessage, ProgressMessage, UnencodedPreviewImageMessage, bytes, bytearray, str, None]
+SendSyncData = Union[StatusMessage, ExecutingMessage, DependencyExecutionErrorMessage, ExecutionErrorMessage, ExecutionInterruptedMessage, ProgressMessage, UnencodedPreviewImageMessage, bytes, bytearray, str, None]


 class ExecutorToClientProgress(Protocol):
@@ -22,7 +22,7 @@ class ProcessPoolExecutor(ProcessPool, Executor):

     def schedule(self, function: Callable,
                  args: list = (),
-                 kwargs: dict = {},
+                 kwargs=None,
                  timeout: float = None) -> ProcessFuture:
         # todo: restart worker when there is insufficient VRAM or the workflows are sufficiently different
         # try:
@@ -31,6 +31,8 @@ class ProcessPoolExecutor(ProcessPool, Executor):
         #
         # except ValueError:
         # pass
+        if kwargs is None:
+            kwargs = {}
         return super().schedule(function, args, kwargs, timeout)

     def submit(self, fn, /, *args, **kwargs) -> concurrent.futures.Future:
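The schedule() change replaces a mutable default argument (kwargs: dict = {}) with the None-sentinel idiom. A minimal demonstration of why the shared default dict is a problem, independent of the pebble ProcessPool (function names are illustrative):

    def bad_schedule(task, kwargs: dict = {}):
        # the same dict object is reused by every call that omits kwargs
        kwargs.setdefault("seen", []).append(task)
        return kwargs

    def good_schedule(task, kwargs=None):
        if kwargs is None:
            kwargs = {}  # fresh dict per call, as in the fixed ProcessPoolExecutor.schedule
        kwargs.setdefault("seen", []).append(task)
        return kwargs

    assert bad_schedule("a") is bad_schedule("b")        # state leaks across calls
    assert good_schedule("a") is not good_schedule("b")  # each call gets its own dict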
@@ -2,21 +2,24 @@ from __future__ import annotations

 from contextlib import contextmanager
 from contextvars import ContextVar
-from typing import NamedTuple, Optional
+from dataclasses import dataclass, replace
+from typing import Optional, Final

 from .component_model.executor_types import ExecutorToClientProgress
 from .distributed.server_stub import ServerStub

-_current_context = ContextVar("comfyui_execution_context")
+_current_context: Final[ContextVar] = ContextVar("comfyui_execution_context")


-class ExecutionContext(NamedTuple):
+@dataclass(frozen=True)
+class ExecutionContext:
     server: ExecutorToClientProgress
     node_id: Optional[str] = None
     task_id: Optional[str] = None
+    inference_mode: bool = True


-_empty_execution_context = ExecutionContext(ServerStub())
+_empty_execution_context: Final[ExecutionContext] = ExecutionContext(server=ServerStub())


 def current_execution_context() -> ExecutionContext:
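Moving ExecutionContext from a NamedTuple to a frozen dataclass keeps it immutable while allowing keyword construction, new fields with defaults such as inference_mode, and dataclasses.replace. A condensed sketch of the ContextVar pattern this module uses, with the server-typed fields stripped out (class and function names here are stand-ins):

    from contextlib import contextmanager
    from contextvars import ContextVar
    from dataclasses import dataclass
    from typing import Optional

    @dataclass(frozen=True)
    class Ctx:
        node_id: Optional[str] = None
        task_id: Optional[str] = None
        inference_mode: bool = True

    _current: ContextVar[Ctx] = ContextVar("ctx", default=Ctx())

    @contextmanager
    def push(ctx: Ctx):
        token = _current.set(ctx)
        try:
            yield ctx
        finally:
            _current.reset(token)  # restore the previous context even if the body raises

    with push(Ctx(task_id="prompt-1", inference_mode=False)):
        assert _current.get().inference_mode is False
    assert _current.get().inference_mode is True  # back to the default after the block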
@@ -30,7 +33,7 @@ def current_execution_context() -> ExecutionContext:
 def new_execution_context(ctx: ExecutionContext):
     token = _current_context.set(ctx)
     try:
-        yield
+        yield ctx
     finally:
         _current_context.reset(token)

@@ -38,6 +41,6 @@ def new_execution_context(ctx: ExecutionContext):
 @contextmanager
 def context_execute_node(node_id: str, prompt_id: str):
     current_ctx = current_execution_context()
-    new_ctx = ExecutionContext(current_ctx.server, node_id, prompt_id)
+    new_ctx = replace(current_ctx, node_id=node_id, task_id=prompt_id)
     with new_execution_context(new_ctx):
-        yield
+        yield new_ctx
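Switching context_execute_node to dataclasses.replace matters for this fix: the old positional ExecutionContext(current_ctx.server, node_id, prompt_id) rebuilds the context from scratch and silently resets any field it does not pass, including the new inference_mode. A short self-contained sketch of the difference (Ctx is a stand-in for ExecutionContext):

    from dataclasses import dataclass, replace
    from typing import Optional

    @dataclass(frozen=True)
    class Ctx:
        server: object = None
        node_id: Optional[str] = None
        task_id: Optional[str] = None
        inference_mode: bool = True

    parent = Ctx(server="stub", inference_mode=False)

    # Old style: rebuild positionally and lose every field not spelled out.
    lossy = Ctx(parent.server, "node-7", "prompt-1")
    assert lossy.inference_mode is True   # silently flipped back to the default

    # New style: copy the parent and override only what changes.
    kept = replace(parent, node_id="node-7", task_id="prompt-1")
    assert kept.inference_mode is False   # inherited from the parent context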
@@ -21,6 +21,7 @@ import torch

 from . import model_management
 from .cli_args import args
+from .execution_context import current_execution_context


 def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False):
@@ -352,7 +353,7 @@ class fp8_ops(manual_cast):
 def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, inference_mode: Optional[bool] = None):
     if inference_mode is None:
-        # todo: check a context here, since this isn't being used by any callers yet
-        inference_mode = False
+        inference_mode = current_execution_context().inference_mode
     if compute_dtype is None or weight_dtype == compute_dtype:
+        # disable_weight_init seems to interact poorly with some other optimization code
         return disable_weight_init if inference_mode else skip_init

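pick_operations now resolves an unspecified inference_mode from the ambient execution context instead of hard-coding False, so an explicit argument still wins but the executor's per-prompt decision flows through by default. A sketch of that parameter-defaults-to-context pattern with a stand-in context variable (the real lookup is current_execution_context().inference_mode; the return values name the two operation sets from the diff):

    from contextvars import ContextVar
    from typing import Optional

    _inference_mode: ContextVar[bool] = ContextVar("inference_mode", default=True)

    def pick_operations_sketch(inference_mode: Optional[bool] = None) -> str:
        if inference_mode is None:
            # explicit argument wins; otherwise fall back to whatever the executor set for this prompt
            inference_mode = _inference_mode.get()
        return "disable_weight_init" if inference_mode else "skip_init"

    assert pick_operations_sketch() == "disable_weight_init"
    assert pick_operations_sketch(inference_mode=False) == "skip_init"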