Fix inference mode execution issues

commit f3da381869
parent a38968f098

This commit makes the decision to run a workflow under torch.inference_mode() explicit and overridable: the executor scans the prompt's node classes for an INFERENCE_MODE attribute, stores the result on the ExecutionContext (now a frozen dataclass with an inference_mode flag), lets pick_operations read that flag from the ambient context, types the execution_error and execution_interrupted messages as TypedDicts, returns RecursiveExecutionTuple instead of bare tuples, and fixes a mutable default argument in the process-pool executor.

In the execution module:
@@ -24,7 +24,7 @@ from ..component_model.abstract_prompt_queue import AbstractPromptQueue
 from ..component_model.executor_types import ExecutorToClientProgress, ValidationTuple, ValidateInputsTuple, \
     ValidationErrorDict, NodeErrorsDictValue, ValidationErrorExtraInfoDict, FormattedValue, RecursiveExecutionTuple, \
     RecursiveExecutionErrorDetails, RecursiveExecutionErrorDetailsInterrupted, ExecutionResult, DuplicateNodeError, \
-    HistoryResultDict
+    HistoryResultDict, ExecutionErrorMessage, ExecutionInterruptedMessage
 from ..component_model.files import canonicalize_path
 from ..component_model.queue_types import QueueTuple, HistoryEntry, QueueItem, MAXIMUM_HISTORY_SIZE, ExecutionStatus
 from ..execution_context import new_execution_context, context_execute_node, ExecutionContext
@@ -311,6 +311,7 @@ def execute(server: ExecutorToClientProgress, dynprompt: DynamicPrompt, caches,
     with context_execute_node(_node_id, prompt_id):
         return _execute(server, dynprompt, caches, _node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results)
 
+
 def _execute(server, dynprompt, caches, current_item: str, extra_data, executed, prompt_id, execution_list, pending_subgraph_results) -> RecursiveExecutionTuple:
     unique_id = current_item
     real_node_id = dynprompt.get_real_node_id(unique_id)
@@ -368,11 +369,11 @@ def _execute(server, dynprompt, caches, current_item: str, extra_data, executed,
             if len(required_inputs) > 0:
                 for i in required_inputs:
                     execution_list.make_input_strong_link(unique_id, i)
-                return (ExecutionResult.PENDING, None, None)
+                return RecursiveExecutionTuple(ExecutionResult.PENDING, None, None)
 
         def execution_block_cb(block):
             if block.message is not None:
-                mes = {
+                mes: ExecutionErrorMessage = {
                     "prompt_id": prompt_id,
                     "node_id": unique_id,
                     "node_type": class_type,
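The switch from bare 3-tuples to RecursiveExecutionTuple(...) matches the function's -> RecursiveExecutionTuple annotation, so a type checker can verify every return path. A minimal sketch of what such a NamedTuple might look like; the field names here (result, error_details, ex) are assumptions, the real definition lives in component_model/executor_types:

from enum import Enum
from typing import NamedTuple, Optional


class ExecutionResult(Enum):  # stand-in for the real enum
    SUCCESS = 0
    FAILURE = 1
    PENDING = 2


class RecursiveExecutionTuple(NamedTuple):  # hypothetical field names
    result: ExecutionResult
    error_details: Optional[dict] = None
    ex: Optional[Exception] = None


# Unpacks exactly like the old bare tuple, but call sites are now checked
# against the annotated return type.
result, error, ex = RecursiveExecutionTuple(ExecutionResult.PENDING, None, None)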
@@ -442,7 +443,7 @@ def _execute(server, dynprompt, caches, current_item: str, extra_data, executed,
                 for link in new_output_links:
                     execution_list.add_strong_link(link[0], link[1], unique_id)
                 pending_subgraph_results[unique_id] = cached_outputs
-                return (ExecutionResult.PENDING, None, None)
+                return RecursiveExecutionTuple(ExecutionResult.PENDING, None, None)
         caches.outputs.set(unique_id, output_data)
     except interruption.InterruptProcessingException as iex:
         logging.info("Processing interrupted")
@@ -481,7 +482,7 @@ def _execute(server, dynprompt, caches, current_item: str, extra_data, executed,
 
     executed.add(unique_id)
 
-    return ExecutionResult.SUCCESS, None, None
+    return RecursiveExecutionTuple(ExecutionResult.SUCCESS, None, None)
 
 
 class PromptExecutor:
@@ -519,7 +520,7 @@ class PromptExecutor:
         # First, send back the status to the frontend depending
         # on the exception type
         if isinstance(ex, interruption.InterruptProcessingException):
-            mes = {
+            mes: ExecutionInterruptedMessage = {
                 "prompt_id": prompt_id,
                 "node_id": node_id,
                 "node_type": class_type,
@@ -527,7 +528,7 @@ class PromptExecutor:
             }
             self.add_message("execution_interrupted", mes, broadcast=True)
         else:
-            mes = {
+            mes: ExecutionErrorMessage = {
                 "prompt_id": prompt_id,
                 "node_id": node_id,
                 "node_type": class_type,
@@ -544,10 +545,13 @@ class PromptExecutor:
         raise ex
 
     def execute(self, prompt, prompt_id, extra_data=None, execute_outputs: List[str] = None):
-        with new_execution_context(ExecutionContext(self.server, task_id=prompt_id)):
+        # torchao and potentially other optimization approaches break when the models are created in inference mode
+        # todo: this should really be backpropagated to code which creates ModelPatchers via lazy evaluation rather than globally checked here
+        inference_mode = all(not hasattr(node_class, "INFERENCE_MODE") or node_class.INFERENCE_MODE for node_class in iterate_obj_classes(prompt))
+        with new_execution_context(ExecutionContext(self.server, task_id=prompt_id, inference_mode=inference_mode)):
             self._execute_inner(prompt, prompt_id, extra_data, execute_outputs)
 
-    def _execute_inner(self, prompt, prompt_id, extra_data=None, execute_outputs: List[str] = None):
+    def _execute_inner(self, prompt, prompt_id, extra_data=None, execute_outputs: List[str] = None, inference_mode: bool = True):
         if execute_outputs is None:
             execute_outputs = []
         if extra_data is None:
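The all(...) scan means a single node class can opt an entire workflow out of inference mode by declaring INFERENCE_MODE = False, while classes that say nothing keep the default. A self-contained sketch of that rule; the node classes and the wants_inference_mode helper are stand-ins, not part of the codebase:

class KSampler:              # no attribute: counts as inference-mode friendly
    pass


class TorchAoQuantize:       # hypothetical node that must not run under inference mode
    INFERENCE_MODE = False


def wants_inference_mode(node_classes) -> bool:
    # Mirrors the executor's check: every class must either omit the attribute
    # or set it truthy for the workflow to run under torch.inference_mode().
    return all(not hasattr(cls, "INFERENCE_MODE") or cls.INFERENCE_MODE for cls in node_classes)


assert wants_inference_mode([KSampler]) is True
assert wants_inference_mode([KSampler, TorchAoQuantize]) is False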
@@ -562,7 +566,7 @@ class PromptExecutor:
         self.status_messages = []
         self.add_message("execution_start", {"prompt_id": prompt_id}, broadcast=False)
 
-        with torch.inference_mode() if all(not hasattr(node_class, "INFERENCE_MODE") or node_class.INFERENCE_MODE for node_class in iterate_obj_classes(prompt)) else nullcontext():
+        with torch.inference_mode() if inference_mode else nullcontext():
             dynamic_prompt = DynamicPrompt(prompt)
             is_changed_cache = IsChangedCache(dynamic_prompt, self.caches.outputs)
             for cache in self.caches.all:
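torch.inference_mode() turns every tensor created inside it into an inference tensor, which cannot be used in autograd and cannot be updated in place outside inference mode; that is what breaks torchao-style optimizations applied to models built this way, hence the per-workflow opt-out. nullcontext() is the no-op fallback. A small sketch of the conditional pattern, assuming torch is installed:

from contextlib import nullcontext

import torch


def build_model(inference_mode: bool) -> torch.nn.Module:
    # Weights created under inference_mode become inference tensors and cannot
    # take part in autograd or versioned in-place updates afterwards.
    with torch.inference_mode() if inference_mode else nullcontext():
        return torch.nn.Linear(8, 8)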
In the executor types module (component_model/executor_types), the interruption message gets its own TypedDict and joins the send-sync union:
@@ -50,6 +50,13 @@ class UnencodedPreviewImageMessage(NamedTuple):
     task_id: str = ""
 
 
+class ExecutionInterruptedMessage(TypedDict):
+    prompt_id: str
+    node_id: str
+    node_type: str
+    executed: list[str]
+
+
 class ExecutionErrorMessage(TypedDict):
     prompt_id: str
     node_id: str
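Annotating the message literal as ExecutionInterruptedMessage lets a type checker validate keys and value types where the dict is built, while at runtime it stays a plain dict. A standalone sketch of the same pattern with illustrative values:

from typing import TypedDict


class ExecutionInterruptedMessage(TypedDict):
    prompt_id: str
    node_id: str
    node_type: str
    executed: list[str]


# mypy/pyright flag a missing key or a typo here; at runtime this is just a dict.
mes: ExecutionInterruptedMessage = {
    "prompt_id": "abc123",
    "node_id": "7",
    "node_type": "KSampler",
    "executed": ["1", "4", "7"],
}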
@@ -74,7 +81,7 @@ ExecutedMessage = ExecutingMessage
 
 SendSyncEvent = Union[Literal["status", "execution_error", "executing", "progress", "executed"], BinaryEventTypes, None]
 
-SendSyncData = Union[StatusMessage, ExecutingMessage, ExecutionErrorMessage, ProgressMessage, UnencodedPreviewImageMessage, bytes, bytearray, str, None]
+SendSyncData = Union[StatusMessage, ExecutingMessage, DependencyExecutionErrorMessage, ExecutionErrorMessage, ExecutionInterruptedMessage, ProgressMessage, UnencodedPreviewImageMessage, bytes, bytearray, str, None]
 
 
 class ExecutorToClientProgress(Protocol):
In the process-pool executor, the mutable default argument for kwargs is replaced with a None sentinel:
@@ -22,7 +22,7 @@ class ProcessPoolExecutor(ProcessPool, Executor):
 
     def schedule(self, function: Callable,
                  args: list = (),
-                 kwargs: dict = {},
+                 kwargs=None,
                  timeout: float = None) -> ProcessFuture:
         # todo: restart worker when there is insufficient VRAM or the workflows are sufficiently different
         # try:
@@ -31,6 +31,8 @@ class ProcessPoolExecutor(ProcessPool, Executor):
         #
         # except ValueError:
         #     pass
+        if kwargs is None:
+            kwargs = {}
         return super().schedule(function, args, kwargs, timeout)
 
     def submit(self, fn, /, *args, **kwargs) -> concurrent.futures.Future:
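The old kwargs: dict = {} default is the classic mutable-default bug: one dict object is created when the function is defined and shared by every call that omits kwargs, so entries leak between schedule() calls. A minimal illustration of the pitfall and the None-sentinel fix:

def shared_default(kwargs: dict = {}):   # one dict created at def time, reused forever
    kwargs["calls"] = kwargs.get("calls", 0) + 1
    return kwargs


def fresh_default(kwargs=None):          # None sentinel, new dict per call
    if kwargs is None:
        kwargs = {}
    kwargs["calls"] = kwargs.get("calls", 0) + 1
    return kwargs


print(shared_default())  # {'calls': 1}
print(shared_default())  # {'calls': 2}  <- state leaked from the previous call
print(fresh_default())   # {'calls': 1}
print(fresh_default())   # {'calls': 1}  <- independent every time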
In the execution context module (execution_context), ExecutionContext becomes a frozen dataclass carrying the inference_mode flag, and the context managers now yield the context they install:
@@ -2,21 +2,24 @@ from __future__ import annotations
 
 from contextlib import contextmanager
 from contextvars import ContextVar
-from typing import NamedTuple, Optional
+from dataclasses import dataclass, replace
+from typing import Optional, Final
 
 from .component_model.executor_types import ExecutorToClientProgress
 from .distributed.server_stub import ServerStub
 
-_current_context = ContextVar("comfyui_execution_context")
+_current_context: Final[ContextVar] = ContextVar("comfyui_execution_context")
 
 
-class ExecutionContext(NamedTuple):
+@dataclass(frozen=True)
+class ExecutionContext:
     server: ExecutorToClientProgress
     node_id: Optional[str] = None
     task_id: Optional[str] = None
+    inference_mode: bool = True
 
 
-_empty_execution_context = ExecutionContext(ServerStub())
+_empty_execution_context: Final[ExecutionContext] = ExecutionContext(server=ServerStub())
 
 
 def current_execution_context() -> ExecutionContext:
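A frozen dataclass keeps ExecutionContext immutable like the old NamedTuple, but dataclasses.replace() copies every field it is not told to change, so the new inference_mode flag survives per-node updates instead of being dropped by a positional re-construction. A compressed sketch of the pattern with stand-in types:

from contextvars import ContextVar
from dataclasses import dataclass, replace
from typing import Optional


@dataclass(frozen=True)
class Ctx:
    server: object
    node_id: Optional[str] = None
    task_id: Optional[str] = None
    inference_mode: bool = True


_current: ContextVar[Ctx] = ContextVar("ctx", default=Ctx(server=None))

base = Ctx(server="stub", task_id="prompt-1", inference_mode=False)
token = _current.set(base)
try:
    # replace() keeps every field it is not overriding, so the inference_mode
    # chosen for the whole workflow survives the per-node node_id update.
    per_node = replace(_current.get(), node_id="7")
    assert per_node.inference_mode is False and per_node.task_id == "prompt-1"
finally:
    _current.reset(token)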
@@ -30,7 +33,7 @@ def current_execution_context() -> ExecutionContext:
 def new_execution_context(ctx: ExecutionContext):
     token = _current_context.set(ctx)
     try:
-        yield
+        yield ctx
     finally:
         _current_context.reset(token)
 
@@ -38,6 +41,6 @@ def new_execution_context(ctx: ExecutionContext):
 @contextmanager
 def context_execute_node(node_id: str, prompt_id: str):
     current_ctx = current_execution_context()
-    new_ctx = ExecutionContext(current_ctx.server, node_id, prompt_id)
+    new_ctx = replace(current_ctx, node_id=node_id, task_id=prompt_id)
     with new_execution_context(new_ctx):
-        yield
+        yield new_ctx
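Yielding the context that was just installed lets callers bind it with as instead of calling current_execution_context() again inside the block. A reduced, self-contained sketch of that shape:

from contextlib import contextmanager
from contextvars import ContextVar

_var: ContextVar[str] = ContextVar("node_id")


@contextmanager
def enter_node(node_id: str):
    # Yield the value that was installed, mirroring the yield ctx / yield new_ctx
    # change above, so callers can write "with enter_node(...) as node_id:".
    token = _var.set(node_id)
    try:
        yield node_id
    finally:
        _var.reset(token)


with enter_node("7") as node_id:
    assert node_id == "7" and _var.get() == "7"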
In the ops module, pick_operations falls back to the ambient execution context when no inference_mode is passed:
@@ -21,6 +21,7 @@ import torch
 
 from . import model_management
 from .cli_args import args
+from .execution_context import current_execution_context
 
 
 def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False):
@@ -352,7 +353,7 @@ class fp8_ops(manual_cast):
 def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, inference_mode: Optional[bool] = None):
     if inference_mode is None:
         # todo: check a context here, since this isn't being used by any callers yet
-        inference_mode = False
+        inference_mode = current_execution_context().inference_mode
     if compute_dtype is None or weight_dtype == compute_dtype:
         # disable_weight_init seems to interact poorly with some other optimization code
         return disable_weight_init if inference_mode else skip_init
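pick_operations used to hard-code inference_mode = False when the caller passed None; it now inherits whatever flag the ambient ExecutionContext was created with, so the executor's per-workflow decision reaches weight initialization without threading an argument through every call site. A reduced sketch of the default-from-context pattern; the names are stand-ins, not the real ops API:

from contextvars import ContextVar
from typing import Optional

_ctx_inference_mode: ContextVar[bool] = ContextVar("inference_mode", default=True)


def pick_operations_sketch(inference_mode: Optional[bool] = None) -> str:
    # Explicit argument wins; otherwise inherit the ambient execution context.
    if inference_mode is None:
        inference_mode = _ctx_inference_mode.get()
    return "disable_weight_init" if inference_mode else "skip_init"


token = _ctx_inference_mode.set(False)  # the executor decided: no inference mode
try:
    assert pick_operations_sketch() == "skip_init"
finally:
    _ctx_inference_mode.reset(token)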