Improve requirements.txt for faster installation, improve validation error reporting

This commit is contained in:
doctorpangloss 2024-07-19 09:16:18 -07:00
parent 6a6f1e4960
commit 499545c373
3 changed files with 43 additions and 21 deletions

View File

@ -19,7 +19,8 @@ from .. import interruption
from .. import model_management
from ..component_model.abstract_prompt_queue import AbstractPromptQueue
from ..component_model.executor_types import ExecutorToClientProgress, ValidationTuple, ValidateInputsTuple, \
ValidationErrorDict, NodeErrorsDictValue, ValidationErrorExtraInfoDict
ValidationErrorDict, NodeErrorsDictValue, ValidationErrorExtraInfoDict, FormattedValue, RecursiveExecutionTuple, \
RecursiveExecutionErrorDetails, RecursiveExecutionErrorDetailsInterrupted
from ..component_model.queue_types import QueueTuple, HistoryEntry, QueueItem, MAXIMUM_HISTORY_SIZE, ExecutionStatus
from ..execution_context import new_execution_context, ExecutionContext
from ..nodes.package import import_all_nodes_in_workspace
@ -163,7 +164,7 @@ def get_output_data(obj, input_data_all) -> Tuple[List[typing.Any], typing.Dict[
return output, ui
def format_value(x):
def format_value(x) -> FormattedValue:
if x is None:
return None
elif isinstance(x, (int, float, bool, str)):
@ -172,7 +173,6 @@ def format_value(x):
return str(x)
@tracer.start_as_current_span("Recursive Execute")
def recursive_execute(server: ExecutorToClientProgress,
prompt,
outputs,
@ -181,15 +181,13 @@ def recursive_execute(server: ExecutorToClientProgress,
executed,
prompt_id,
outputs_ui,
object_storage):
span = get_current_span()
object_storage) -> RecursiveExecutionTuple:
unique_id = current_item
inputs = prompt[unique_id]['inputs']
class_type = prompt[unique_id]['class_type']
span.set_attribute("class_type", class_type)
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
if unique_id in outputs:
return (True, None, None)
return RecursiveExecutionTuple(True, None, None)
for x in inputs:
input_data = inputs[x]
@ -227,11 +225,11 @@ def recursive_execute(server: ExecutorToClientProgress,
logging.info("Processing interrupted")
# skip formatting inputs/outputs
error_details = {
error_details: RecursiveExecutionErrorDetailsInterrupted = {
"node_id": unique_id,
}
return (False, error_details, iex)
return RecursiveExecutionTuple(False, error_details, iex)
except Exception as ex:
typ, _, tb = sys.exc_info()
exception_type = full_type_name(typ)
@ -248,7 +246,7 @@ def recursive_execute(server: ExecutorToClientProgress,
logging.error(f"!!! Exception during processing!!! {ex}")
logging.error(traceback.format_exc())
error_details = {
error_details: RecursiveExecutionErrorDetails = {
"node_id": unique_id,
"exception_message": str(ex),
"exception_type": exception_type,
@ -256,14 +254,19 @@ def recursive_execute(server: ExecutorToClientProgress,
"current_inputs": input_data_formatted,
"current_outputs": output_data_formatted
}
return (False, error_details, ex)
return RecursiveExecutionTuple(False, error_details, ex)
executed.add(unique_id)
return (True, None, None)
return RecursiveExecutionTuple(True, None, None)
def recursive_will_execute(prompt, outputs, current_item, memo={}):
_memo = {}
def recursive_will_execute(prompt, outputs, current_item, memo: dict[str, list[str]] = None) -> list[str]:
if memo is None:
memo = _memo
unique_id = current_item
if unique_id in memo:
@ -733,12 +736,12 @@ def validate_prompt(prompt: typing.Mapping[str, typing.Any]) -> ValidationTuple:
span.set_status(Status(StatusCode.ERROR))
if res.error is not None and len(res.error) > 0:
span.set_attributes({
f"error.{k}": v for k, v in res.error.items() if isinstance(v, (bool, str, bytes, int, float, list[str], list[int], list[float]))
f"error.{k}": v for k, v in res.error.items() if isinstance(v, (bool, str, bytes, int, float, list))
})
if "extra_info" in res.error and isinstance(res.error["extra_info"], dict):
extra_info: ValidationErrorExtraInfoDict = res.error["extra_info"]
span.set_attributes({
f"error.extra_info.{k}": v for k, v in extra_info.items() if isinstance(v, (str, list[str]))
f"error.extra_info.{k}": v for k, v in extra_info.items() if isinstance(v, (str, list))
})
if len(res.node_errors) > 0:
for node_id, node_error in res.node_errors.items():

View File

@ -90,6 +90,7 @@ class ExecutorToClientProgress(Protocol):
ExceptionTypes = Literal["custom_validation_failed", "value_not_in_list", "value_bigger_than_max", "value_smaller_than_min", "invalid_input_type", "exception_during_inner_validation", "return_type_mismatch", "bad_linked_input", "required_input_missing", "invalid_prompt", "prompt_no_outputs", "exception_during_validation", "prompt_outputs_failed_validation"]
FormattedValue = str | int | bool | float | None
class ValidationErrorExtraInfoDict(TypedDict, total=False):
@ -101,7 +102,7 @@ class ValidationErrorExtraInfoDict(TypedDict, total=False):
input_config: NotRequired[typing.Dict[str, InputTypeSpec]]
received_value: NotRequired[typing.Any]
linked_node: NotRequired[str]
traceback: NotRequired[str]
traceback: NotRequired[list[str]]
exception_message: NotRequired[str]
exception_type: NotRequired[str]
@ -130,3 +131,22 @@ class ValidateInputsTuple(typing.NamedTuple):
valid: bool
errors: List[ValidationErrorDict]
unique_id: str
class RecursiveExecutionErrorDetailsInterrupted(TypedDict, total=True):
node_id: str
class RecursiveExecutionErrorDetails(TypedDict, total=True):
node_id: str
exception_message: str
exception_type: str
traceback: list[str]
current_inputs: NotRequired[dict[str, FormattedValue]]
current_outputs: NotRequired[dict[str, list[list[FormattedValue]]]]
class RecursiveExecutionTuple(typing.NamedTuple):
valid: bool
error_details: Optional[RecursiveExecutionErrorDetails | RecursiveExecutionErrorDetailsInterrupted]
exc_info: Optional[Exception]

View File

@ -1,6 +1,6 @@
torch
torchvision
torchaudio
torchaudio>=2.1.2
torchdiffeq>=0.2.3
torchsde>=0.2.6
einops>=0.6.0
@ -8,11 +8,10 @@ open-clip-torch>=2.24.0
transformers>=4.29.1
tokenizers>=0.13.3
sentencepiece
peft
peft==*
torchinfo
safetensors>=0.4.2
bitsandbytes
pytorch-lightning>=2.0.0
aiohttp>=3.8.4
accelerate>=0.25.0
pyyaml>=6.0
@ -35,7 +34,7 @@ psutil
ConfigArgParse
aio-pika
pyjwt[crypto]
kornia>=0.7.1
kornia==*
mpmath>=1.0,!=1.4.0a0
huggingface_hub
lazy-object-proxy
@ -54,7 +53,7 @@ opentelemetry-semantic-conventions
huggingface_extra_chat_templates @ git+https://github.com/AppMana/appmana-comfyui-chat-templates.git
wrapt>=1.16.0
certifi
spandrel
spandrel==*
numpy>=1.26.3,<2.0.0
soundfile
watchdog