Fix inference mode execution issues

doctorpangloss 2024-10-10 21:00:09 -07:00
parent a38968f098
commit f3da381869
5 changed files with 37 additions and 20 deletions

View File

@@ -24,7 +24,7 @@ from ..component_model.abstract_prompt_queue import AbstractPromptQueue
from ..component_model.executor_types import ExecutorToClientProgress, ValidationTuple, ValidateInputsTuple, \
ValidationErrorDict, NodeErrorsDictValue, ValidationErrorExtraInfoDict, FormattedValue, RecursiveExecutionTuple, \
RecursiveExecutionErrorDetails, RecursiveExecutionErrorDetailsInterrupted, ExecutionResult, DuplicateNodeError, \
HistoryResultDict
HistoryResultDict, ExecutionErrorMessage, ExecutionInterruptedMessage
from ..component_model.files import canonicalize_path
from ..component_model.queue_types import QueueTuple, HistoryEntry, QueueItem, MAXIMUM_HISTORY_SIZE, ExecutionStatus
from ..execution_context import new_execution_context, context_execute_node, ExecutionContext
@@ -311,6 +311,7 @@ def execute(server: ExecutorToClientProgress, dynprompt: DynamicPrompt, caches,
with context_execute_node(_node_id, prompt_id):
return _execute(server, dynprompt, caches, _node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results)
def _execute(server, dynprompt, caches, current_item: str, extra_data, executed, prompt_id, execution_list, pending_subgraph_results) -> RecursiveExecutionTuple:
unique_id = current_item
real_node_id = dynprompt.get_real_node_id(unique_id)
@@ -368,11 +369,11 @@ def _execute(server, dynprompt, caches, current_item: str, extra_data, executed,
if len(required_inputs) > 0:
for i in required_inputs:
execution_list.make_input_strong_link(unique_id, i)
return (ExecutionResult.PENDING, None, None)
return RecursiveExecutionTuple(ExecutionResult.PENDING, None, None)
def execution_block_cb(block):
if block.message is not None:
mes = {
mes: ExecutionErrorMessage = {
"prompt_id": prompt_id,
"node_id": unique_id,
"node_type": class_type,
@@ -442,7 +443,7 @@ def _execute(server, dynprompt, caches, current_item: str, extra_data, executed,
for link in new_output_links:
execution_list.add_strong_link(link[0], link[1], unique_id)
pending_subgraph_results[unique_id] = cached_outputs
return (ExecutionResult.PENDING, None, None)
return RecursiveExecutionTuple(ExecutionResult.PENDING, None, None)
caches.outputs.set(unique_id, output_data)
except interruption.InterruptProcessingException as iex:
logging.info("Processing interrupted")
@@ -481,7 +482,7 @@ def _execute(server, dynprompt, caches, current_item: str, extra_data, executed,
executed.add(unique_id)
return ExecutionResult.SUCCESS, None, None
return RecursiveExecutionTuple(ExecutionResult.SUCCESS, None, None)
class PromptExecutor:
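
Note: the bare-tuple returns in this file are replaced with the RecursiveExecutionTuple named tuple already imported from component_model.executor_types, so callers can unpack results by field name and a type checker can verify the arity. A minimal sketch of the pattern follows; the real RecursiveExecutionTuple is defined in component_model.executor_types and its field names may differ from the ones assumed here.

# Sketch only: the field names below are assumptions, not the repository's definition.
from enum import Enum
from typing import NamedTuple, Optional

class ExecutionResult(Enum):
    SUCCESS = 0
    PENDING = 1

class RecursiveExecutionTuple(NamedTuple):
    result: ExecutionResult
    error_details: Optional[dict] = None
    exception: Optional[Exception] = None

def step_ok() -> RecursiveExecutionTuple:
    return RecursiveExecutionTuple(ExecutionResult.SUCCESS, None, None)

outcome = step_ok()
# Callers can read outcome.result instead of outcome[0], and a type checker
# flags returns with the wrong arity or element types.
assert outcome.result is ExecutionResult.SUCCESS
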
@@ -519,7 +520,7 @@ class PromptExecutor:
# First, send back the status to the frontend depending
# on the exception type
if isinstance(ex, interruption.InterruptProcessingException):
mes = {
mes: ExecutionInterruptedMessage = {
"prompt_id": prompt_id,
"node_id": node_id,
"node_type": class_type,
@@ -527,7 +528,7 @@ class PromptExecutor:
}
self.add_message("execution_interrupted", mes, broadcast=True)
else:
mes = {
mes: ExecutionErrorMessage = {
"prompt_id": prompt_id,
"node_id": node_id,
"node_type": class_type,
@@ -544,10 +545,13 @@ class PromptExecutor:
raise ex
def execute(self, prompt, prompt_id, extra_data=None, execute_outputs: List[str] = None):
with new_execution_context(ExecutionContext(self.server, task_id=prompt_id)):
# torchao and potentially other optimization approaches break when the models are created in inference mode
# todo: this should really be backpropagated to code which creates ModelPatchers via lazy evaluation rather than globally checked here
inference_mode = all(not hasattr(node_class, "INFERENCE_MODE") or node_class.INFERENCE_MODE for node_class in iterate_obj_classes(prompt))
with new_execution_context(ExecutionContext(self.server, task_id=prompt_id, inference_mode=inference_mode)):
self._execute_inner(prompt, prompt_id, extra_data, execute_outputs)
def _execute_inner(self, prompt, prompt_id, extra_data=None, execute_outputs: List[str] = None):
def _execute_inner(self, prompt, prompt_id, extra_data=None, execute_outputs: List[str] = None, inference_mode: bool = True):
if execute_outputs is None:
execute_outputs = []
if extra_data is None:
@@ -562,7 +566,7 @@ class PromptExecutor:
self.status_messages = []
self.add_message("execution_start", {"prompt_id": prompt_id}, broadcast=False)
with torch.inference_mode() if all(not hasattr(node_class, "INFERENCE_MODE") or node_class.INFERENCE_MODE for node_class in iterate_obj_classes(prompt)) else nullcontext():
with torch.inference_mode() if inference_mode else nullcontext():
dynamic_prompt = DynamicPrompt(prompt)
is_changed_cache = IsChangedCache(dynamic_prompt, self.caches.outputs)
for cache in self.caches.all:
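
Note: the change above computes inference_mode once per prompt by checking every node class for an opt-out INFERENCE_MODE attribute, records it on the ExecutionContext, and wraps graph execution in torch.inference_mode() only when every node allows it. A minimal sketch of that gating logic, assuming the computed flag reaches the executing code; the node classes and the run_prompt helper are illustrative, not part of this repository.

# Sketch only: _NoGradNode, _TrainingNode and run_prompt are hypothetical stand-ins.
from contextlib import nullcontext
import torch

class _NoGradNode:
    pass  # no INFERENCE_MODE attribute -> treated as allowing inference mode

class _TrainingNode:
    INFERENCE_MODE = False  # e.g. nodes that create or patch models with torchao opt out

def _prompt_allows_inference_mode(node_classes) -> bool:
    # Mirrors the all(...) expression in execute(): inference mode is used only
    # when every node class either omits the attribute or sets it truthy.
    return all(getattr(cls, "INFERENCE_MODE", True) for cls in node_classes)

def run_prompt(node_classes, fn):
    inference_mode = _prompt_allows_inference_mode(node_classes)
    with torch.inference_mode() if inference_mode else nullcontext():
        return fn()

# run_prompt([_NoGradNode], lambda: torch.ones(1) * 2)   # executes under inference_mode
# run_prompt([_NoGradNode, _TrainingNode], lambda: ...)  # falls back to a normal context

Because the flag is also stored on the ExecutionContext, code that creates models lazily, such as pick_operations in the last file of this commit, can read it at call time without any new parameters.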

View File

@@ -50,6 +50,13 @@ class UnencodedPreviewImageMessage(NamedTuple):
task_id: str = ""
class ExecutionInterruptedMessage(TypedDict):
prompt_id: str
node_id: str
node_type: str
executed: list[str]
class ExecutionErrorMessage(TypedDict):
prompt_id: str
node_id: str
@@ -74,7 +81,7 @@ ExecutedMessage = ExecutingMessage
SendSyncEvent = Union[Literal["status", "execution_error", "executing", "progress", "executed"], BinaryEventTypes, None]
SendSyncData = Union[StatusMessage, ExecutingMessage, ExecutionErrorMessage, ProgressMessage, UnencodedPreviewImageMessage, bytes, bytearray, str, None]
SendSyncData = Union[StatusMessage, ExecutingMessage, DependencyExecutionErrorMessage, ExecutionErrorMessage, ExecutionInterruptedMessage, ProgressMessage, UnencodedPreviewImageMessage, bytes, bytearray, str, None]
class ExecutorToClientProgress(Protocol):
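
Note: ExecutionInterruptedMessage gives the interruption payload a static shape, which is what the mes: ExecutionInterruptedMessage annotation in the executor changes above relies on. A small sketch of building and checking such a payload; the field names mirror the diff, while the concrete values and the send_sync_stub helper are illustrative.

# Sketch only: send_sync_stub and the example values are hypothetical.
from typing import TypedDict

class ExecutionInterruptedMessage(TypedDict):
    prompt_id: str
    node_id: str
    node_type: str
    executed: list[str]

def send_sync_stub(event: str, data: ExecutionInterruptedMessage) -> None:
    print(event, data)

mes: ExecutionInterruptedMessage = {
    "prompt_id": "abc123",
    "node_id": "7",
    "node_type": "KSampler",
    "executed": ["3", "5"],
}
send_sync_stub("execution_interrupted", mes)
# A type checker (mypy/pyright) would reject a payload with a missing key or a
# wrongly typed value, e.g. {"prompt_id": "abc123"} alone.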

View File

@@ -22,7 +22,7 @@ class ProcessPoolExecutor(ProcessPool, Executor):
def schedule(self, function: Callable,
args: list = (),
kwargs: dict = {},
kwargs=None,
timeout: float = None) -> ProcessFuture:
# todo: restart worker when there is insufficient VRAM or the workflows are sufficiently different
# try:
@@ -31,6 +31,8 @@ class ProcessPoolExecutor(ProcessPool, Executor):
#
# except ValueError:
# pass
if kwargs is None:
kwargs = {}
return super().schedule(function, args, kwargs, timeout)
def submit(self, fn, /, *args, **kwargs) -> concurrent.futures.Future:
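
Note: schedule() previously declared a mutable default argument (kwargs: dict = {}), which is a single dict shared by every call that omits kwargs; the patch switches to kwargs=None with an in-body guard. A short illustration of the failure mode and of the fix; the helper functions are illustrative, not from this repository.

# Sketch only: schedule_buggy/schedule_fixed illustrate the default-argument change.
def schedule_buggy(fn, kwargs: dict = {}):
    # The same dict object is reused across calls, so mutations leak between callers.
    kwargs.setdefault("timeout", 30)
    return fn, kwargs

def schedule_fixed(fn, kwargs=None):
    if kwargs is None:
        kwargs = {}  # fresh dict per call, matching the patched schedule()
    kwargs.setdefault("timeout", 30)
    return fn, kwargs

_, a = schedule_buggy(print)
a["extra"] = "leaks"
_, b = schedule_buggy(print)
assert "extra" in b       # surprising: state from the first call is still visible

_, c = schedule_fixed(print)
c["extra"] = "isolated"
_, d = schedule_fixed(print)
assert "extra" not in d   # each call gets its own dict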

View File

@@ -2,21 +2,24 @@ from __future__ import annotations
from contextlib import contextmanager
from contextvars import ContextVar
from typing import NamedTuple, Optional
from dataclasses import dataclass, replace
from typing import Optional, Final
from .component_model.executor_types import ExecutorToClientProgress
from .distributed.server_stub import ServerStub
_current_context = ContextVar("comfyui_execution_context")
_current_context: Final[ContextVar] = ContextVar("comfyui_execution_context")
class ExecutionContext(NamedTuple):
@dataclass(frozen=True)
class ExecutionContext:
server: ExecutorToClientProgress
node_id: Optional[str] = None
task_id: Optional[str] = None
inference_mode: bool = True
_empty_execution_context = ExecutionContext(ServerStub())
_empty_execution_context: Final[ExecutionContext] = ExecutionContext(server=ServerStub())
def current_execution_context() -> ExecutionContext:
@@ -30,7 +33,7 @@ def current_execution_context() -> ExecutionContext:
def new_execution_context(ctx: ExecutionContext):
token = _current_context.set(ctx)
try:
yield
yield ctx
finally:
_current_context.reset(token)
@@ -38,6 +41,6 @@ def new_execution_context(ctx: ExecutionContext):
@contextmanager
def context_execute_node(node_id: str, prompt_id: str):
current_ctx = current_execution_context()
new_ctx = ExecutionContext(current_ctx.server, node_id, prompt_id)
new_ctx = replace(current_ctx, node_id=node_id, task_id=prompt_id)
with new_execution_context(new_ctx):
yield
yield new_ctx
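
Note: turning ExecutionContext into a frozen dataclass keeps it immutable while letting context_execute_node derive a child context with dataclasses.replace, so newly added fields such as inference_mode are carried forward rather than being reset by positional construction, and yield ctx lets callers bind the active context with "with ... as ctx". A condensed, self-contained sketch of the pattern; the None server value is a placeholder.

# Sketch only: condensed version of the context machinery in this file.
from contextlib import contextmanager
from contextvars import ContextVar
from dataclasses import dataclass, replace
from typing import Optional

@dataclass(frozen=True)
class ExecutionContext:
    server: object
    node_id: Optional[str] = None
    task_id: Optional[str] = None
    inference_mode: bool = True

_current: ContextVar = ContextVar("ctx", default=ExecutionContext(server=None))

@contextmanager
def new_execution_context(ctx: ExecutionContext):
    token = _current.set(ctx)
    try:
        yield ctx
    finally:
        _current.reset(token)  # restore the previous context even on error

@contextmanager
def context_execute_node(node_id: str, prompt_id: str):
    # replace() copies every existing field (including inference_mode) and
    # overrides only the two that change for this node.
    with new_execution_context(replace(_current.get(), node_id=node_id, task_id=prompt_id)) as ctx:
        yield ctx

with new_execution_context(ExecutionContext(server=None, task_id="prompt-1", inference_mode=False)):
    with context_execute_node("4", "prompt-1") as ctx:
        assert ctx.inference_mode is False  # inherited from the prompt context, not reset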

View File

@@ -21,6 +21,7 @@ import torch
from . import model_management
from .cli_args import args
from .execution_context import current_execution_context
def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False):
@@ -352,7 +353,7 @@ class fp8_ops(manual_cast):
def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, inference_mode: Optional[bool] = None):
if inference_mode is None:
# todo: check a context here, since this isn't being used by any callers yet
inference_mode = False
inference_mode = current_execution_context().inference_mode
if compute_dtype is None or weight_dtype == compute_dtype:
# disable_weight_init seems to interact poorly with some other optimization code
return disable_weight_init if inference_mode else skip_init
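
Note: with inference_mode stored on the execution context, pick_operations can resolve the flag at call time instead of hard-coding False, so model-creation code deep in the call stack inherits whatever the executing prompt decided. A hedged sketch of that lazy lookup; the stub context and the placeholder op classes stand in for the real current_execution_context() and the operation classes in this module.

# Sketch only: current_execution_context_stub and the placeholder classes are hypothetical.
from types import SimpleNamespace
from typing import Optional

def current_execution_context_stub():
    # In the real code this returns the ExecutionContext set by PromptExecutor.execute().
    return SimpleNamespace(inference_mode=True)

class disable_weight_init: ...  # placeholder for the op class chosen in this module
class skip_init: ...            # placeholder

def pick_operations_sketch(weight_dtype, compute_dtype, inference_mode: Optional[bool] = None):
    if inference_mode is None:
        # None means "not specified by the caller": fall back to the ambient
        # execution context instead of a hard-coded False.
        inference_mode = current_execution_context_stub().inference_mode
    if compute_dtype is None or weight_dtype == compute_dtype:
        return disable_weight_init if inference_mode else skip_init
    # ... remaining dtype branches unchanged from the real pick_operations
    return skip_init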