diff --git a/comfy/cmd/execution.py b/comfy/cmd/execution.py index 7e55bdf3c..b1a07273d 100644 --- a/comfy/cmd/execution.py +++ b/comfy/cmd/execution.py @@ -19,7 +19,8 @@ from .. import interruption from .. import model_management from ..component_model.abstract_prompt_queue import AbstractPromptQueue from ..component_model.executor_types import ExecutorToClientProgress, ValidationTuple, ValidateInputsTuple, \ - ValidationErrorDict, NodeErrorsDictValue, ValidationErrorExtraInfoDict + ValidationErrorDict, NodeErrorsDictValue, ValidationErrorExtraInfoDict, FormattedValue, RecursiveExecutionTuple, \ + RecursiveExecutionErrorDetails, RecursiveExecutionErrorDetailsInterrupted from ..component_model.queue_types import QueueTuple, HistoryEntry, QueueItem, MAXIMUM_HISTORY_SIZE, ExecutionStatus from ..execution_context import new_execution_context, ExecutionContext from ..nodes.package import import_all_nodes_in_workspace @@ -163,7 +164,7 @@ def get_output_data(obj, input_data_all) -> Tuple[List[typing.Any], typing.Dict[ return output, ui -def format_value(x): +def format_value(x) -> FormattedValue: if x is None: return None elif isinstance(x, (int, float, bool, str)): @@ -172,7 +173,6 @@ def format_value(x): return str(x) -@tracer.start_as_current_span("Recursive Execute") def recursive_execute(server: ExecutorToClientProgress, prompt, outputs, @@ -181,15 +181,13 @@ def recursive_execute(server: ExecutorToClientProgress, executed, prompt_id, outputs_ui, - object_storage): - span = get_current_span() + object_storage) -> RecursiveExecutionTuple: unique_id = current_item inputs = prompt[unique_id]['inputs'] class_type = prompt[unique_id]['class_type'] - span.set_attribute("class_type", class_type) class_def = nodes.NODE_CLASS_MAPPINGS[class_type] if unique_id in outputs: - return (True, None, None) + return RecursiveExecutionTuple(True, None, None) for x in inputs: input_data = inputs[x] @@ -227,11 +225,11 @@ def recursive_execute(server: ExecutorToClientProgress, logging.info("Processing interrupted") # skip formatting inputs/outputs - error_details = { + error_details: RecursiveExecutionErrorDetailsInterrupted = { "node_id": unique_id, } - return (False, error_details, iex) + return RecursiveExecutionTuple(False, error_details, iex) except Exception as ex: typ, _, tb = sys.exc_info() exception_type = full_type_name(typ) @@ -248,7 +246,7 @@ def recursive_execute(server: ExecutorToClientProgress, logging.error(f"!!! Exception during processing!!! {ex}") logging.error(traceback.format_exc()) - error_details = { + error_details: RecursiveExecutionErrorDetails = { "node_id": unique_id, "exception_message": str(ex), "exception_type": exception_type, @@ -256,14 +254,19 @@ def recursive_execute(server: ExecutorToClientProgress, "current_inputs": input_data_formatted, "current_outputs": output_data_formatted } - return (False, error_details, ex) + return RecursiveExecutionTuple(False, error_details, ex) executed.add(unique_id) - return (True, None, None) + return RecursiveExecutionTuple(True, None, None) -def recursive_will_execute(prompt, outputs, current_item, memo={}): +_memo = {} + + +def recursive_will_execute(prompt, outputs, current_item, memo: dict[str, list[str]] = None) -> list[str]: + if memo is None: + memo = _memo unique_id = current_item if unique_id in memo: @@ -733,12 +736,12 @@ def validate_prompt(prompt: typing.Mapping[str, typing.Any]) -> ValidationTuple: span.set_status(Status(StatusCode.ERROR)) if res.error is not None and len(res.error) > 0: span.set_attributes({ - f"error.{k}": v for k, v in res.error.items() if isinstance(v, (bool, str, bytes, int, float, list[str], list[int], list[float])) + f"error.{k}": v for k, v in res.error.items() if isinstance(v, (bool, str, bytes, int, float, list)) }) if "extra_info" in res.error and isinstance(res.error["extra_info"], dict): extra_info: ValidationErrorExtraInfoDict = res.error["extra_info"] span.set_attributes({ - f"error.extra_info.{k}": v for k, v in extra_info.items() if isinstance(v, (str, list[str])) + f"error.extra_info.{k}": v for k, v in extra_info.items() if isinstance(v, (str, list)) }) if len(res.node_errors) > 0: for node_id, node_error in res.node_errors.items(): diff --git a/comfy/component_model/executor_types.py b/comfy/component_model/executor_types.py index a0b35d7f3..d75f8e212 100644 --- a/comfy/component_model/executor_types.py +++ b/comfy/component_model/executor_types.py @@ -90,6 +90,7 @@ class ExecutorToClientProgress(Protocol): ExceptionTypes = Literal["custom_validation_failed", "value_not_in_list", "value_bigger_than_max", "value_smaller_than_min", "invalid_input_type", "exception_during_inner_validation", "return_type_mismatch", "bad_linked_input", "required_input_missing", "invalid_prompt", "prompt_no_outputs", "exception_during_validation", "prompt_outputs_failed_validation"] +FormattedValue = str | int | bool | float | None class ValidationErrorExtraInfoDict(TypedDict, total=False): @@ -101,7 +102,7 @@ class ValidationErrorExtraInfoDict(TypedDict, total=False): input_config: NotRequired[typing.Dict[str, InputTypeSpec]] received_value: NotRequired[typing.Any] linked_node: NotRequired[str] - traceback: NotRequired[str] + traceback: NotRequired[list[str]] exception_message: NotRequired[str] exception_type: NotRequired[str] @@ -130,3 +131,22 @@ class ValidateInputsTuple(typing.NamedTuple): valid: bool errors: List[ValidationErrorDict] unique_id: str + + +class RecursiveExecutionErrorDetailsInterrupted(TypedDict, total=True): + node_id: str + + +class RecursiveExecutionErrorDetails(TypedDict, total=True): + node_id: str + exception_message: str + exception_type: str + traceback: list[str] + current_inputs: NotRequired[dict[str, FormattedValue]] + current_outputs: NotRequired[dict[str, list[list[FormattedValue]]]] + + +class RecursiveExecutionTuple(typing.NamedTuple): + valid: bool + error_details: Optional[RecursiveExecutionErrorDetails | RecursiveExecutionErrorDetailsInterrupted] + exc_info: Optional[Exception] diff --git a/requirements.txt b/requirements.txt index b5d40ee31..88696a9af 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,6 @@ torch torchvision -torchaudio +torchaudio>=2.1.2 torchdiffeq>=0.2.3 torchsde>=0.2.6 einops>=0.6.0 @@ -8,11 +8,10 @@ open-clip-torch>=2.24.0 transformers>=4.29.1 tokenizers>=0.13.3 sentencepiece -peft +peft==* torchinfo safetensors>=0.4.2 bitsandbytes -pytorch-lightning>=2.0.0 aiohttp>=3.8.4 accelerate>=0.25.0 pyyaml>=6.0 @@ -35,7 +34,7 @@ psutil ConfigArgParse aio-pika pyjwt[crypto] -kornia>=0.7.1 +kornia==* mpmath>=1.0,!=1.4.0a0 huggingface_hub lazy-object-proxy @@ -54,7 +53,7 @@ opentelemetry-semantic-conventions huggingface_extra_chat_templates @ git+https://github.com/AppMana/appmana-comfyui-chat-templates.git wrapt>=1.16.0 certifi -spandrel +spandrel==* numpy>=1.26.3,<2.0.0 soundfile watchdog