Fix inference mode execution issues

doctorpangloss 2024-10-10 21:00:09 -07:00
parent a38968f098
commit f3da381869
5 changed files with 37 additions and 20 deletions

View File

@@ -24,7 +24,7 @@ from ..component_model.abstract_prompt_queue import AbstractPromptQueue
 from ..component_model.executor_types import ExecutorToClientProgress, ValidationTuple, ValidateInputsTuple, \
     ValidationErrorDict, NodeErrorsDictValue, ValidationErrorExtraInfoDict, FormattedValue, RecursiveExecutionTuple, \
     RecursiveExecutionErrorDetails, RecursiveExecutionErrorDetailsInterrupted, ExecutionResult, DuplicateNodeError, \
-    HistoryResultDict
+    HistoryResultDict, ExecutionErrorMessage, ExecutionInterruptedMessage
 from ..component_model.files import canonicalize_path
 from ..component_model.queue_types import QueueTuple, HistoryEntry, QueueItem, MAXIMUM_HISTORY_SIZE, ExecutionStatus
 from ..execution_context import new_execution_context, context_execute_node, ExecutionContext
@@ -311,6 +311,7 @@ def execute(server: ExecutorToClientProgress, dynprompt: DynamicPrompt, caches,
     with context_execute_node(_node_id, prompt_id):
         return _execute(server, dynprompt, caches, _node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results)

 def _execute(server, dynprompt, caches, current_item: str, extra_data, executed, prompt_id, execution_list, pending_subgraph_results) -> RecursiveExecutionTuple:
     unique_id = current_item
     real_node_id = dynprompt.get_real_node_id(unique_id)
@@ -368,11 +369,11 @@ def _execute(server, dynprompt, caches, current_item: str, extra_data, executed,
            if len(required_inputs) > 0:
                for i in required_inputs:
                    execution_list.make_input_strong_link(unique_id, i)
-                return (ExecutionResult.PENDING, None, None)
+                return RecursiveExecutionTuple(ExecutionResult.PENDING, None, None)

        def execution_block_cb(block):
            if block.message is not None:
-                mes = {
+                mes: ExecutionErrorMessage = {
                    "prompt_id": prompt_id,
                    "node_id": unique_id,
                    "node_type": class_type,
@@ -442,7 +443,7 @@ def _execute(server, dynprompt, caches, current_item: str, extra_data, executed,
                for link in new_output_links:
                    execution_list.add_strong_link(link[0], link[1], unique_id)
                pending_subgraph_results[unique_id] = cached_outputs
-                return (ExecutionResult.PENDING, None, None)
+                return RecursiveExecutionTuple(ExecutionResult.PENDING, None, None)
            caches.outputs.set(unique_id, output_data)
    except interruption.InterruptProcessingException as iex:
        logging.info("Processing interrupted")
@@ -481,7 +482,7 @@ def _execute(server, dynprompt, caches, current_item: str, extra_data, executed,
    executed.add(unique_id)
-    return ExecutionResult.SUCCESS, None, None
+    return RecursiveExecutionTuple(ExecutionResult.SUCCESS, None, None)

 class PromptExecutor:
@@ -519,7 +520,7 @@
        # First, send back the status to the frontend depending
        # on the exception type
        if isinstance(ex, interruption.InterruptProcessingException):
-            mes = {
+            mes: ExecutionInterruptedMessage = {
                "prompt_id": prompt_id,
                "node_id": node_id,
                "node_type": class_type,
@@ -527,7 +528,7 @@
            }
            self.add_message("execution_interrupted", mes, broadcast=True)
        else:
-            mes = {
+            mes: ExecutionErrorMessage = {
                "prompt_id": prompt_id,
                "node_id": node_id,
                "node_type": class_type,
@@ -544,10 +545,13 @@
        raise ex

    def execute(self, prompt, prompt_id, extra_data=None, execute_outputs: List[str] = None):
-        with new_execution_context(ExecutionContext(self.server, task_id=prompt_id)):
+        # torchao and potentially other optimization approaches break when the models are created in inference mode
+        # todo: this should really be backpropagated to code which creates ModelPatchers via lazy evaluation rather than globally checked here
+        inference_mode = all(not hasattr(node_class, "INFERENCE_MODE") or node_class.INFERENCE_MODE for node_class in iterate_obj_classes(prompt))
+        with new_execution_context(ExecutionContext(self.server, task_id=prompt_id, inference_mode=inference_mode)):
            self._execute_inner(prompt, prompt_id, extra_data, execute_outputs)

-    def _execute_inner(self, prompt, prompt_id, extra_data=None, execute_outputs: List[str] = None):
+    def _execute_inner(self, prompt, prompt_id, extra_data=None, execute_outputs: List[str] = None, inference_mode: bool = True):
        if execute_outputs is None:
            execute_outputs = []
        if extra_data is None:
@@ -562,7 +566,7 @@
        self.status_messages = []
        self.add_message("execution_start", {"prompt_id": prompt_id}, broadcast=False)

-        with torch.inference_mode() if all(not hasattr(node_class, "INFERENCE_MODE") or node_class.INFERENCE_MODE for node_class in iterate_obj_classes(prompt)) else nullcontext():
+        with torch.inference_mode() if inference_mode else nullcontext():
            dynamic_prompt = DynamicPrompt(prompt)
            is_changed_cache = IsChangedCache(dynamic_prompt, self.caches.outputs)
            for cache in self.caches.all:
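
The all(...) check above means a prompt runs under torch.inference_mode() unless at least one node class in the prompt sets INFERENCE_MODE = False. A minimal sketch of how a custom node might opt out, assuming the usual ComfyUI node-class conventions (INPUT_TYPES, RETURN_TYPES, FUNCTION, CATEGORY); the class name and body are hypothetical:

class TorchAoQuantizeModel:
    # Any node class with INFERENCE_MODE = False forces the whole prompt to run
    # outside torch.inference_mode(), so code that creates or patches model
    # weights (e.g. torchao quantization) does not trip over inference tensors.
    INFERENCE_MODE = False

    @classmethod
    def INPUT_TYPES(cls):
        return {"required": {"model": ("MODEL",)}}

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "quantize"
    CATEGORY = "advanced/model"

    def quantize(self, model):
        # hypothetical weight-mutating work that inference mode would forbid
        return (model,)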

View File

@@ -50,6 +50,13 @@ class UnencodedPreviewImageMessage(NamedTuple):
     task_id: str = ""

+class ExecutionInterruptedMessage(TypedDict):
+    prompt_id: str
+    node_id: str
+    node_type: str
+    executed: list[str]
+
 class ExecutionErrorMessage(TypedDict):
     prompt_id: str
     node_id: str
@@ -74,7 +81,7 @@ ExecutedMessage = ExecutingMessage
 SendSyncEvent = Union[Literal["status", "execution_error", "executing", "progress", "executed"], BinaryEventTypes, None]
-SendSyncData = Union[StatusMessage, ExecutingMessage, ExecutionErrorMessage, ProgressMessage, UnencodedPreviewImageMessage, bytes, bytearray, str, None]
+SendSyncData = Union[StatusMessage, ExecutingMessage, DependencyExecutionErrorMessage, ExecutionErrorMessage, ExecutionInterruptedMessage, ProgressMessage, UnencodedPreviewImageMessage, bytes, bytearray, str, None]

 class ExecutorToClientProgress(Protocol):
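
With the new TypedDict, the mes: ExecutionInterruptedMessage annotations added in execution.py above let a static checker verify the payload keys. A small illustrative sketch; the literal values are placeholders, not taken from a real run:

from typing import TypedDict


class ExecutionInterruptedMessage(TypedDict):
    prompt_id: str
    node_id: str
    node_type: str
    executed: list[str]


# mypy/pyright now flag a missing or misspelled key in the payload.
mes: ExecutionInterruptedMessage = {
    "prompt_id": "example-prompt-id",
    "node_id": "12",
    "node_type": "KSampler",
    "executed": ["3", "5"],
}
# self.add_message("execution_interrupted", mes, broadcast=True)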

View File

@@ -22,7 +22,7 @@ class ProcessPoolExecutor(ProcessPool, Executor):
     def schedule(self, function: Callable,
                  args: list = (),
-                 kwargs: dict = {},
+                 kwargs=None,
                  timeout: float = None) -> ProcessFuture:
         # todo: restart worker when there is insufficient VRAM or the workflows are sufficiently different
         # try:
@@ -31,6 +31,8 @@ class ProcessPoolExecutor(ProcessPool, Executor):
         #
         # except ValueError:
         #     pass
+        if kwargs is None:
+            kwargs = {}
         return super().schedule(function, args, kwargs, timeout)

     def submit(self, fn, /, *args, **kwargs) -> concurrent.futures.Future:
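
The kwargs: dict = {} default that this hunk removes is the classic mutable-default pitfall: the dict is created once at function definition time and shared by every call that omits kwargs, so one call's arguments can leak into the next. A short sketch of the failure mode and of the None-sentinel pattern the patch switches to (function names are illustrative):

def schedule_shared(kwargs: dict = {}):  # one dict, created at def time
    kwargs.setdefault("calls", 0)
    kwargs["calls"] += 1
    return kwargs


schedule_shared()  # {'calls': 1}
schedule_shared()  # {'calls': 2}  <- state leaked between unrelated calls


def schedule_fresh(kwargs=None):  # the pattern used in the patched schedule()
    if kwargs is None:
        kwargs = {}  # a new dict for every call
    kwargs.setdefault("calls", 0)
    kwargs["calls"] += 1
    return kwargs


schedule_fresh()  # {'calls': 1}
schedule_fresh()  # {'calls': 1}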

View File

@@ -2,21 +2,24 @@ from __future__ import annotations
 from contextlib import contextmanager
 from contextvars import ContextVar
-from typing import NamedTuple, Optional
+from dataclasses import dataclass, replace
+from typing import Optional, Final

 from .component_model.executor_types import ExecutorToClientProgress
 from .distributed.server_stub import ServerStub

-_current_context = ContextVar("comfyui_execution_context")
+_current_context: Final[ContextVar] = ContextVar("comfyui_execution_context")

-class ExecutionContext(NamedTuple):
+@dataclass(frozen=True)
+class ExecutionContext:
     server: ExecutorToClientProgress
     node_id: Optional[str] = None
     task_id: Optional[str] = None
+    inference_mode: bool = True

-_empty_execution_context = ExecutionContext(ServerStub())
+_empty_execution_context: Final[ExecutionContext] = ExecutionContext(server=ServerStub())

 def current_execution_context() -> ExecutionContext:
@@ -30,7 +33,7 @@ def current_execution_context() -> ExecutionContext:
 def new_execution_context(ctx: ExecutionContext):
     token = _current_context.set(ctx)
     try:
-        yield
+        yield ctx
     finally:
         _current_context.reset(token)
@@ -38,6 +41,6 @@ def new_execution_context(ctx: ExecutionContext):
 @contextmanager
 def context_execute_node(node_id: str, prompt_id: str):
     current_ctx = current_execution_context()
-    new_ctx = ExecutionContext(current_ctx.server, node_id, prompt_id)
+    new_ctx = replace(current_ctx, node_id=node_id, task_id=prompt_id)
     with new_execution_context(new_ctx):
-        yield
+        yield new_ctx
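
Turning ExecutionContext into a frozen dataclass lets context_execute_node derive a per-node context with dataclasses.replace while preserving the remaining fields, notably the new inference_mode flag, and both context managers now yield the context so it can be bound with "with ... as ctx". A sketch of the propagation, assuming the package is importable as comfy as in this repository:

from comfy.distributed.server_stub import ServerStub
from comfy.execution_context import (ExecutionContext, context_execute_node,
                                     current_execution_context,
                                     new_execution_context)

root = ExecutionContext(server=ServerStub(), task_id="prompt-1", inference_mode=False)
with new_execution_context(root) as ctx:
    assert current_execution_context() is ctx
    with context_execute_node("5", "prompt-1") as node_ctx:
        # replace() keeps server and inference_mode, overriding only node_id/task_id
        assert node_ctx.node_id == "5"
        assert node_ctx.inference_mode is False
# outside these blocks the module falls back to a default context with inference_mode=True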

View File

@@ -21,6 +21,7 @@ import torch
 from . import model_management
 from .cli_args import args
+from .execution_context import current_execution_context

 def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False):
@@ -352,7 +353,7 @@ class fp8_ops(manual_cast):
 def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, inference_mode: Optional[bool] = None):
     if inference_mode is None:
-        # todo: check a context here, since this isn't being used by any callers yet
-        inference_mode = False
+        inference_mode = current_execution_context().inference_mode
     if compute_dtype is None or weight_dtype == compute_dtype:
         # disable_weight_init seems to interact poorly with some other optimization code
         return disable_weight_init if inference_mode else skip_init
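
With the fallback now read from the execution context, callers that leave inference_mode as None inherit whatever the executing prompt decided. A sketch of the intended behaviour, assuming pick_operations, disable_weight_init and skip_init are exposed by the ops module shown above and the package imports as comfy:

import torch

from comfy import ops
from comfy.distributed.server_stub import ServerStub
from comfy.execution_context import ExecutionContext, new_execution_context

# weight_dtype == compute_dtype, so the result reduces to the inference_mode branch.
with new_execution_context(ExecutionContext(server=ServerStub(), inference_mode=False)):
    assert ops.pick_operations(torch.float16, torch.float16) is ops.skip_init

with new_execution_context(ExecutionContext(server=ServerStub(), inference_mode=True)):
    assert ops.pick_operations(torch.float16, torch.float16) is ops.disable_weight_init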