mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-26 14:20:27 +08:00
Improve logging and tracing for validation errors
This commit is contained in:
parent
a20bf8134d
commit
72baecad87
@ -1,13 +1,15 @@
|
|||||||
|
import glob
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
import uuid
|
|
||||||
import glob
|
|
||||||
import shutil
|
import shutil
|
||||||
|
import uuid
|
||||||
|
|
||||||
from aiohttp import web
|
from aiohttp import web
|
||||||
|
|
||||||
|
from .app_settings import AppSettings
|
||||||
from ..cli_args import args
|
from ..cli_args import args
|
||||||
from ..cmd.folder_paths import user_directory
|
from ..cmd.folder_paths import user_directory
|
||||||
from .app_settings import AppSettings
|
|
||||||
|
|
||||||
|
|
||||||
class UserManager():
|
class UserManager():
|
||||||
@ -17,9 +19,6 @@ class UserManager():
|
|||||||
self.settings = AppSettings(self)
|
self.settings = AppSettings(self)
|
||||||
if not os.path.exists(user_directory):
|
if not os.path.exists(user_directory):
|
||||||
os.mkdir(user_directory)
|
os.mkdir(user_directory)
|
||||||
if not args.multi_user:
|
|
||||||
print("****** User settings have been changed to be stored on the server instead of browser storage. ******")
|
|
||||||
print("****** For multi-user setups add the --multi-user CLI argument to enable multiple user profiles. ******")
|
|
||||||
|
|
||||||
if args.multi_user:
|
if args.multi_user:
|
||||||
if os.path.isfile(self.users_file):
|
if os.path.isfile(self.users_file):
|
||||||
@ -129,7 +128,7 @@ class UserManager():
|
|||||||
|
|
||||||
return web.json_response(results)
|
return web.json_response(results)
|
||||||
|
|
||||||
def get_user_data_path(request, check_exists = False, param = "file"):
|
def get_user_data_path(request, check_exists=False, param="file"):
|
||||||
file = request.match_info.get(param, None)
|
file = request.match_info.get(param, None)
|
||||||
if not file:
|
if not file:
|
||||||
return web.Response(status=400)
|
return web.Response(status=400)
|
||||||
|
|||||||
@ -8,22 +8,22 @@ import sys
|
|||||||
import threading
|
import threading
|
||||||
import traceback
|
import traceback
|
||||||
import typing
|
import typing
|
||||||
from typing import List, Optional, Tuple, Union
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
import lazy_object_proxy
|
import lazy_object_proxy
|
||||||
import torch
|
import torch
|
||||||
from opentelemetry.trace import get_current_span, StatusCode, Status
|
from opentelemetry.trace import get_current_span, StatusCode, Status
|
||||||
from typing_extensions import TypedDict
|
|
||||||
|
|
||||||
from .main_pre import tracer
|
from .main_pre import tracer
|
||||||
from .. import interruption
|
from .. import interruption
|
||||||
from .. import model_management
|
from .. import model_management
|
||||||
from ..component_model.abstract_prompt_queue import AbstractPromptQueue
|
from ..component_model.abstract_prompt_queue import AbstractPromptQueue
|
||||||
from ..component_model.executor_types import ExecutorToClientProgress
|
from ..component_model.executor_types import ExecutorToClientProgress, ValidationTuple, ValidateInputsTuple, \
|
||||||
|
ValidationErrorDict, NodeErrorsDictValue
|
||||||
from ..component_model.queue_types import QueueTuple, HistoryEntry, QueueItem, MAXIMUM_HISTORY_SIZE, ExecutionStatus
|
from ..component_model.queue_types import QueueTuple, HistoryEntry, QueueItem, MAXIMUM_HISTORY_SIZE, ExecutionStatus
|
||||||
from ..execution_context import new_execution_context, ExecutionContext
|
from ..execution_context import new_execution_context, ExecutionContext
|
||||||
from ..nodes.package import import_all_nodes_in_workspace
|
from ..nodes.package import import_all_nodes_in_workspace
|
||||||
from ..nodes.package_typing import ExportedNodes
|
from ..nodes.package_typing import ExportedNodes, InputTypeSpec, FloatSpecOptions, IntSpecOptions
|
||||||
|
|
||||||
# ideally this would be passed in from main, but the way this is authored, we can't easily pass nodes down to the
|
# ideally this would be passed in from main, but the way this is authored, we can't easily pass nodes down to the
|
||||||
# various functions that are declared here. It should have been a context in the first place.
|
# various functions that are declared here. It should have been a context in the first place.
|
||||||
@ -492,7 +492,7 @@ class PromptExecutor:
|
|||||||
model_management.unload_all_models()
|
model_management.unload_all_models()
|
||||||
|
|
||||||
|
|
||||||
def validate_inputs(prompt, item, validated) -> Tuple[bool, typing.List[dict], typing.Any]:
|
def validate_inputs(prompt, item, validated: typing.Dict[str, ValidateInputsTuple]) -> ValidateInputsTuple:
|
||||||
# todo: this should check if LoadImage / LoadImageMask paths exist
|
# todo: this should check if LoadImage / LoadImageMask paths exist
|
||||||
# todo: or, nodes should provide a way to validate their values
|
# todo: or, nodes should provide a way to validate their values
|
||||||
unique_id = item
|
unique_id = item
|
||||||
@ -506,11 +506,12 @@ def validate_inputs(prompt, item, validated) -> Tuple[bool, typing.List[dict], t
|
|||||||
class_inputs = obj_class.INPUT_TYPES()
|
class_inputs = obj_class.INPUT_TYPES()
|
||||||
required_inputs = class_inputs['required']
|
required_inputs = class_inputs['required']
|
||||||
|
|
||||||
|
error: ValidationErrorDict
|
||||||
errors = []
|
errors = []
|
||||||
valid = True
|
valid = True
|
||||||
|
|
||||||
# todo: investigate if these are at the right indent level
|
# todo: investigate if these are at the right indent level
|
||||||
info = None
|
info: Optional[InputTypeSpec] = None
|
||||||
val = None
|
val = None
|
||||||
|
|
||||||
validate_function_inputs = []
|
validate_function_inputs = []
|
||||||
@ -531,7 +532,7 @@ def validate_inputs(prompt, item, validated) -> Tuple[bool, typing.List[dict], t
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
val = inputs[x]
|
val = inputs[x]
|
||||||
info = required_inputs[x]
|
info: InputTypeSpec = required_inputs[x]
|
||||||
type_input = info[0]
|
type_input = info[0]
|
||||||
if isinstance(val, list):
|
if isinstance(val, list):
|
||||||
if len(val) != 2:
|
if len(val) != 2:
|
||||||
@ -593,7 +594,7 @@ def validate_inputs(prompt, item, validated) -> Tuple[bool, typing.List[dict], t
|
|||||||
"linked_node": val
|
"linked_node": val
|
||||||
}
|
}
|
||||||
}]
|
}]
|
||||||
validated[o_id] = (False, reasons, o_id)
|
validated[o_id] = ValidateInputsTuple(False, reasons, o_id)
|
||||||
continue
|
continue
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
@ -622,10 +623,11 @@ def validate_inputs(prompt, item, validated) -> Tuple[bool, typing.List[dict], t
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
if len(info) > 1:
|
if len(info) > 1:
|
||||||
if "min" in info[1] and val < info[1]["min"]:
|
has_min_max: IntSpecOptions | FloatSpecOptions = info[1]
|
||||||
|
if "min" in has_min_max and val < has_min_max["min"]:
|
||||||
error = {
|
error = {
|
||||||
"type": "value_smaller_than_min",
|
"type": "value_smaller_than_min",
|
||||||
"message": "Value {} smaller than min of {}".format(val, info[1]["min"]),
|
"message": "Value {} smaller than min of {}".format(val, has_min_max["min"]),
|
||||||
"details": f"{x}",
|
"details": f"{x}",
|
||||||
"extra_info": {
|
"extra_info": {
|
||||||
"input_name": x,
|
"input_name": x,
|
||||||
@ -635,10 +637,10 @@ def validate_inputs(prompt, item, validated) -> Tuple[bool, typing.List[dict], t
|
|||||||
}
|
}
|
||||||
errors.append(error)
|
errors.append(error)
|
||||||
continue
|
continue
|
||||||
if "max" in info[1] and val > info[1]["max"]:
|
if "max" in has_min_max and val > has_min_max["max"]:
|
||||||
error = {
|
error = {
|
||||||
"type": "value_bigger_than_max",
|
"type": "value_bigger_than_max",
|
||||||
"message": "Value {} bigger than max of {}".format(val, info[1]["max"]),
|
"message": "Value {} bigger than max of {}".format(val, has_min_max["max"]),
|
||||||
"details": f"{x}",
|
"details": f"{x}",
|
||||||
"extra_info": {
|
"extra_info": {
|
||||||
"input_name": x,
|
"input_name": x,
|
||||||
@ -706,9 +708,9 @@ def validate_inputs(prompt, item, validated) -> Tuple[bool, typing.List[dict], t
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
if len(errors) > 0 or valid is not True:
|
if len(errors) > 0 or valid is not True:
|
||||||
ret = (False, errors, unique_id)
|
ret = ValidateInputsTuple(False, errors, unique_id)
|
||||||
else:
|
else:
|
||||||
ret = (True, [], unique_id)
|
ret = ValidateInputsTuple(True, [], unique_id)
|
||||||
|
|
||||||
validated[unique_id] = ret
|
validated[unique_id] = ret
|
||||||
return ret
|
return ret
|
||||||
@ -721,23 +723,25 @@ def full_type_name(klass):
|
|||||||
return module + '.' + klass.__qualname__
|
return module + '.' + klass.__qualname__
|
||||||
|
|
||||||
|
|
||||||
class ValidationErrorExtraInfoDict(TypedDict):
|
|
||||||
exception_type: str
|
|
||||||
traceback: List[str]
|
|
||||||
|
|
||||||
|
|
||||||
class ValidationErrorDict(TypedDict):
|
|
||||||
type: str
|
|
||||||
message: str
|
|
||||||
details: str
|
|
||||||
extra_info: ValidationErrorExtraInfoDict | dict
|
|
||||||
|
|
||||||
|
|
||||||
ValidationTuple = typing.Tuple[bool, Optional[ValidationErrorDict], typing.List[str], Union[dict, list]]
|
|
||||||
|
|
||||||
|
|
||||||
@tracer.start_as_current_span("Validate Prompt")
|
@tracer.start_as_current_span("Validate Prompt")
|
||||||
def validate_prompt(prompt: typing.Mapping[str, typing.Any]) -> ValidationTuple:
|
def validate_prompt(prompt: typing.Mapping[str, typing.Any]) -> ValidationTuple:
|
||||||
|
res = _validate_prompt(prompt)
|
||||||
|
if not res.valid:
|
||||||
|
span = get_current_span()
|
||||||
|
span.set_status(Status(StatusCode.ERROR))
|
||||||
|
if res.error is not None and len(res.error) > 0:
|
||||||
|
span.set_attributes({
|
||||||
|
f"error.{k}": v for k, v in res.error.items()
|
||||||
|
})
|
||||||
|
if len(res.node_errors) > 0:
|
||||||
|
for node_id, node_error in res.node_errors.items():
|
||||||
|
for node_error_field, node_error_value in node_error.items():
|
||||||
|
if isinstance(node_error_value, (str, bool, int, float)):
|
||||||
|
span.set_attribute("node_errors.{node_id}.{node_error_field}", node_error_value)
|
||||||
|
return res
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_prompt(prompt: typing.Mapping[str, typing.Any]) -> ValidationTuple:
|
||||||
outputs = set()
|
outputs = set()
|
||||||
for x in prompt:
|
for x in prompt:
|
||||||
if 'class_type' not in prompt[x]:
|
if 'class_type' not in prompt[x]:
|
||||||
@ -747,7 +751,7 @@ def validate_prompt(prompt: typing.Mapping[str, typing.Any]) -> ValidationTuple:
|
|||||||
"details": f"Node ID '#{x}'",
|
"details": f"Node ID '#{x}'",
|
||||||
"extra_info": {}
|
"extra_info": {}
|
||||||
}
|
}
|
||||||
return (False, error, [], [])
|
return ValidationTuple(False, error, [], [])
|
||||||
|
|
||||||
class_type = prompt[x]['class_type']
|
class_type = prompt[x]['class_type']
|
||||||
class_ = nodes.NODE_CLASS_MAPPINGS.get(class_type, None)
|
class_ = nodes.NODE_CLASS_MAPPINGS.get(class_type, None)
|
||||||
@ -758,7 +762,7 @@ def validate_prompt(prompt: typing.Mapping[str, typing.Any]) -> ValidationTuple:
|
|||||||
"details": f"Node ID '#{x}'",
|
"details": f"Node ID '#{x}'",
|
||||||
"extra_info": {}
|
"extra_info": {}
|
||||||
}
|
}
|
||||||
return (False, error, [], [])
|
return ValidationTuple(False, error, [], [])
|
||||||
|
|
||||||
if hasattr(class_, 'OUTPUT_NODE') and class_.OUTPUT_NODE is True:
|
if hasattr(class_, 'OUTPUT_NODE') and class_.OUTPUT_NODE is True:
|
||||||
outputs.add(x)
|
outputs.add(x)
|
||||||
@ -770,15 +774,15 @@ def validate_prompt(prompt: typing.Mapping[str, typing.Any]) -> ValidationTuple:
|
|||||||
"details": "",
|
"details": "",
|
||||||
"extra_info": {}
|
"extra_info": {}
|
||||||
}
|
}
|
||||||
return False, error, [], []
|
return ValidationTuple(False, error, [], [])
|
||||||
|
|
||||||
good_outputs = set()
|
good_outputs = set()
|
||||||
errors = []
|
errors = []
|
||||||
node_errors = {}
|
node_errors: typing.Dict[str, NodeErrorsDictValue] = {}
|
||||||
validated = {}
|
validated: typing.Dict[str, ValidateInputsTuple] = {}
|
||||||
for o in outputs:
|
for o in outputs:
|
||||||
valid = False
|
valid = False
|
||||||
reasons = []
|
reasons: List[ValidationErrorDict] = []
|
||||||
try:
|
try:
|
||||||
m = validate_inputs(prompt, o, validated)
|
m = validate_inputs(prompt, o, validated)
|
||||||
valid = m[0]
|
valid = m[0]
|
||||||
@ -796,7 +800,7 @@ def validate_prompt(prompt: typing.Mapping[str, typing.Any]) -> ValidationTuple:
|
|||||||
"traceback": traceback.format_tb(tb)
|
"traceback": traceback.format_tb(tb)
|
||||||
}
|
}
|
||||||
}]
|
}]
|
||||||
validated[o] = (False, reasons, o)
|
validated[o] = ValidateInputsTuple(False, reasons, o)
|
||||||
|
|
||||||
if valid is True:
|
if valid is True:
|
||||||
good_outputs.add(o)
|
good_outputs.add(o)
|
||||||
@ -841,9 +845,9 @@ def validate_prompt(prompt: typing.Mapping[str, typing.Any]) -> ValidationTuple:
|
|||||||
"extra_info": {}
|
"extra_info": {}
|
||||||
}
|
}
|
||||||
|
|
||||||
return False, error, list(good_outputs), node_errors
|
return ValidationTuple(False, error, list(good_outputs), node_errors)
|
||||||
|
|
||||||
return True, None, list(good_outputs), node_errors
|
return ValidationTuple(True, None, list(good_outputs), node_errors)
|
||||||
|
|
||||||
|
|
||||||
class PromptQueue(AbstractPromptQueue):
|
class PromptQueue(AbstractPromptQueue):
|
||||||
|
|||||||
@ -7,9 +7,9 @@ import shutil
|
|||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
|
|
||||||
|
from .extra_model_paths import load_extra_path_config
|
||||||
# main_pre must be the earliest import since it suppresses some spurious warnings
|
# main_pre must be the earliest import since it suppresses some spurious warnings
|
||||||
from .main_pre import args
|
from .main_pre import args
|
||||||
from .extra_model_paths import load_extra_path_config
|
|
||||||
from .. import model_management
|
from .. import model_management
|
||||||
from ..analytics.analytics import initialize_event_tracking
|
from ..analytics.analytics import initialize_event_tracking
|
||||||
from ..cmd import cuda_malloc
|
from ..cmd import cuda_malloc
|
||||||
@ -223,7 +223,10 @@ async def main():
|
|||||||
|
|
||||||
|
|
||||||
def entrypoint():
|
def entrypoint():
|
||||||
asyncio.run(main())
|
try:
|
||||||
|
asyncio.run(main())
|
||||||
|
except KeyboardInterrupt as keyboard_interrupt:
|
||||||
|
logging.info(f"Gracefully shutting down due to {keyboard_interrupt}")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
@ -492,9 +492,6 @@ class PromptServer(ExecutorToClientProgress):
|
|||||||
|
|
||||||
@routes.post("/prompt")
|
@routes.post("/prompt")
|
||||||
async def post_prompt(request):
|
async def post_prompt(request):
|
||||||
logging.info("got prompt")
|
|
||||||
resp_code = 200
|
|
||||||
out_string = ""
|
|
||||||
json_data = await request.json()
|
json_data = await request.json()
|
||||||
json_data = self.trigger_on_prompt(json_data)
|
json_data = self.trigger_on_prompt(json_data)
|
||||||
|
|
||||||
|
|||||||
@ -1,11 +1,13 @@
|
|||||||
from __future__ import annotations # for Python 3.7-3.9
|
from __future__ import annotations # for Python 3.7-3.9
|
||||||
|
|
||||||
from typing import Optional, Literal, Protocol, TypeAlias, Union, NamedTuple
|
import typing
|
||||||
|
from typing import Optional, Literal, Protocol, TypeAlias, Union, NamedTuple, List
|
||||||
|
|
||||||
import PIL.Image
|
import PIL.Image
|
||||||
from typing_extensions import NotRequired, TypedDict
|
from typing_extensions import NotRequired, TypedDict
|
||||||
|
|
||||||
from .queue_types import BinaryEventTypes
|
from .queue_types import BinaryEventTypes
|
||||||
|
from ..nodes.package_typing import InputTypeSpec
|
||||||
|
|
||||||
|
|
||||||
class ExecInfo(TypedDict):
|
class ExecInfo(TypedDict):
|
||||||
@ -85,3 +87,46 @@ class ExecutorToClientProgress(Protocol):
|
|||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
ExceptionTypes = Literal["custom_validation_failed", "value_not_in_list", "value_bigger_than_max", "value_smaller_than_min", "invalid_input_type", "exception_during_inner_validation", "return_type_mismatch", "bad_linked_input", "required_input_missing", "invalid_prompt", "prompt_no_outputs", "exception_during_validation", "prompt_outputs_failed_validation"]
|
||||||
|
|
||||||
|
|
||||||
|
class ValidationErrorExtraInfoDict(TypedDict, total=False):
|
||||||
|
exception_type: NotRequired[str]
|
||||||
|
traceback: NotRequired[List[str]]
|
||||||
|
dependent_outputs: NotRequired[List[str]]
|
||||||
|
class_type: NotRequired[str]
|
||||||
|
input_name: NotRequired[str]
|
||||||
|
input_config: NotRequired[typing.Dict[str, InputTypeSpec]]
|
||||||
|
received_value: NotRequired[typing.Any]
|
||||||
|
linked_node: NotRequired[str]
|
||||||
|
traceback: NotRequired[str]
|
||||||
|
exception_message: NotRequired[str]
|
||||||
|
exception_type: NotRequired[str]
|
||||||
|
|
||||||
|
|
||||||
|
class ValidationErrorDict(TypedDict):
|
||||||
|
type: str
|
||||||
|
message: str
|
||||||
|
details: str
|
||||||
|
extra_info: list[typing.Never] | ValidationErrorExtraInfoDict
|
||||||
|
|
||||||
|
|
||||||
|
class NodeErrorsDictValue(TypedDict, total=False):
|
||||||
|
dependent_outputs: NotRequired[List[str]]
|
||||||
|
errors: List[ValidationErrorDict]
|
||||||
|
class_type: str
|
||||||
|
|
||||||
|
|
||||||
|
class ValidationTuple(typing.NamedTuple):
|
||||||
|
valid: bool
|
||||||
|
error: Optional[ValidationErrorDict]
|
||||||
|
good_output_node_ids: List[str]
|
||||||
|
node_errors: list[typing.Never] | typing.Dict[str, NodeErrorsDictValue]
|
||||||
|
|
||||||
|
|
||||||
|
class ValidateInputsTuple(typing.NamedTuple):
|
||||||
|
valid: bool
|
||||||
|
errors: List[ValidationErrorDict]
|
||||||
|
unique_id: str
|
||||||
|
|||||||
@ -89,8 +89,6 @@ class BaseModel(torch.nn.Module):
|
|||||||
self.adm_channels = 0
|
self.adm_channels = 0
|
||||||
|
|
||||||
self.concat_keys = ()
|
self.concat_keys = ()
|
||||||
logging.info("model_type {}".format(model_type.name))
|
|
||||||
logging.debug("adm {}".format(self.adm_channels))
|
|
||||||
|
|
||||||
def apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs):
|
def apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs):
|
||||||
sigma = t
|
sigma = t
|
||||||
|
|||||||
@ -315,6 +315,8 @@ KNOWN_CONTROLNETS: Final[List[Downloadable]] = [
|
|||||||
HuggingFile("xinsir/controlnet-openpose-sdxl-1.0", "diffusion_pytorch_model.safetensors", save_with_filename="xinsir-controlnet-openpose-sdxl-1.0.safetensors"),
|
HuggingFile("xinsir/controlnet-openpose-sdxl-1.0", "diffusion_pytorch_model.safetensors", save_with_filename="xinsir-controlnet-openpose-sdxl-1.0.safetensors"),
|
||||||
HuggingFile("xinsir/anime-painter", "diffusion_pytorch_model.safetensors", save_with_filename="xinsir-anime-painter-scribble-sdxl-1.0.safetensors"),
|
HuggingFile("xinsir/anime-painter", "diffusion_pytorch_model.safetensors", save_with_filename="xinsir-anime-painter-scribble-sdxl-1.0.safetensors"),
|
||||||
HuggingFile("TheMistoAI/MistoLine", "mistoLine_rank256.safetensors"),
|
HuggingFile("TheMistoAI/MistoLine", "mistoLine_rank256.safetensors"),
|
||||||
|
HuggingFile("xinsir/controlnet-union-sdxl-1.0", "diffusion_pytorch_model_promax.safetensors", save_with_filename="xinsir-controlnet-union-sdxl-1.0-promax.safetensors"),
|
||||||
|
HuggingFile("xinsir/controlnet-union-sdxl-1.0", "diffusion_pytorch_model.safetensors", save_with_filename="xinsir-controlnet-union-sdxl-1.0.safetensors"),
|
||||||
]
|
]
|
||||||
|
|
||||||
KNOWN_DIFF_CONTROLNETS: Final[List[Downloadable]] = [
|
KNOWN_DIFF_CONTROLNETS: Final[List[Downloadable]] = [
|
||||||
|
|||||||
@ -10,9 +10,11 @@ from typing import Literal, List
|
|||||||
|
|
||||||
import psutil
|
import psutil
|
||||||
import torch
|
import torch
|
||||||
|
from opentelemetry.trace import get_current_span
|
||||||
|
|
||||||
from . import interruption
|
from . import interruption
|
||||||
from .cli_args import args
|
from .cli_args import args
|
||||||
|
from .cmd.main_pre import tracer
|
||||||
from .model_management_types import ModelManageable
|
from .model_management_types import ModelManageable
|
||||||
|
|
||||||
model_management_lock = RLock()
|
model_management_lock = RLock()
|
||||||
@ -356,6 +358,12 @@ class LoadedModel:
|
|||||||
def __eq__(self, other):
|
def __eq__(self, other):
|
||||||
return self.model is other.model
|
return self.model is other.model
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
if self.model is not None:
|
||||||
|
return f"<LoadedModel {str(self.model)}>"
|
||||||
|
else:
|
||||||
|
return f"<LoadedModel>"
|
||||||
|
|
||||||
|
|
||||||
def minimum_inference_memory():
|
def minimum_inference_memory():
|
||||||
return (1024 * 1024 * 1024)
|
return (1024 * 1024 * 1024)
|
||||||
@ -392,9 +400,12 @@ def unload_model_clones(model, unload_weights_only=True, force_unload=True) -> b
|
|||||||
return unload_weight
|
return unload_weight
|
||||||
|
|
||||||
|
|
||||||
def free_memory(memory_required, device, keep_loaded=[]):
|
@tracer.start_as_current_span("Free Memory")
|
||||||
|
def free_memory(memory_required, device, keep_loaded=[]) -> List[LoadedModel]:
|
||||||
|
span = get_current_span()
|
||||||
|
span.set_attribute("memory_required", memory_required)
|
||||||
with model_management_lock:
|
with model_management_lock:
|
||||||
unloaded_model = []
|
unloaded_models: List[LoadedModel] = []
|
||||||
can_unload = []
|
can_unload = []
|
||||||
|
|
||||||
for i in range(len(current_loaded_models) - 1, -1, -1):
|
for i in range(len(current_loaded_models) - 1, -1, -1):
|
||||||
@ -410,12 +421,12 @@ def free_memory(memory_required, device, keep_loaded=[]):
|
|||||||
if get_free_memory(device) > memory_required:
|
if get_free_memory(device) > memory_required:
|
||||||
break
|
break
|
||||||
current_loaded_models[i].model_unload()
|
current_loaded_models[i].model_unload()
|
||||||
unloaded_model.append(i)
|
unloaded_models.append(i)
|
||||||
|
|
||||||
for i in sorted(unloaded_model, reverse=True):
|
for i in sorted(unloaded_models, reverse=True):
|
||||||
current_loaded_models.pop(i)
|
current_loaded_models.pop(i)
|
||||||
|
|
||||||
if len(unloaded_model) > 0:
|
if len(unloaded_models) > 0:
|
||||||
soft_empty_cache()
|
soft_empty_cache()
|
||||||
else:
|
else:
|
||||||
if vram_state != VRAMState.HIGH_VRAM:
|
if vram_state != VRAMState.HIGH_VRAM:
|
||||||
@ -423,10 +434,16 @@ def free_memory(memory_required, device, keep_loaded=[]):
|
|||||||
if mem_free_torch > mem_free_total * 0.25:
|
if mem_free_torch > mem_free_total * 0.25:
|
||||||
soft_empty_cache()
|
soft_empty_cache()
|
||||||
|
|
||||||
|
span.set_attribute("unloaded_models", list(map(str, unloaded_models)))
|
||||||
|
return unloaded_models
|
||||||
|
|
||||||
|
|
||||||
|
@tracer.start_as_current_span("Load Models GPU")
|
||||||
def load_models_gpu(models, memory_required=0, force_patch_weights=False):
|
def load_models_gpu(models, memory_required=0, force_patch_weights=False):
|
||||||
global vram_state
|
global vram_state
|
||||||
|
span = get_current_span()
|
||||||
|
if memory_required != 0:
|
||||||
|
span.set_attribute("memory_required", memory_required)
|
||||||
with model_management_lock:
|
with model_management_lock:
|
||||||
inference_memory = minimum_inference_memory()
|
inference_memory = minimum_inference_memory()
|
||||||
extra_mem = max(inference_memory, memory_required)
|
extra_mem = max(inference_memory, memory_required)
|
||||||
@ -452,19 +469,16 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False):
|
|||||||
loaded.currently_used = True
|
loaded.currently_used = True
|
||||||
models_already_loaded.append(loaded)
|
models_already_loaded.append(loaded)
|
||||||
if loaded is None:
|
if loaded is None:
|
||||||
if hasattr(x, "model"):
|
|
||||||
logging.info(f"Requested to load {x.model.__class__.__name__}")
|
|
||||||
models_to_load.append(loaded_model)
|
models_to_load.append(loaded_model)
|
||||||
|
|
||||||
|
models_freed: List[LoadedModel] = []
|
||||||
if len(models_to_load) == 0:
|
if len(models_to_load) == 0:
|
||||||
devs = set(map(lambda a: a.device, models_already_loaded))
|
devs = set(map(lambda a: a.device, models_already_loaded))
|
||||||
for d in devs:
|
for d in devs:
|
||||||
if d != torch.device("cpu"):
|
if d != torch.device("cpu"):
|
||||||
free_memory(extra_mem, d, models_already_loaded)
|
models_freed += free_memory(extra_mem, d, models_already_loaded)
|
||||||
return
|
return
|
||||||
|
|
||||||
logging.info(f"Loading {len(models_to_load)} new model{'s' if len(models_to_load) > 1 else ''}")
|
|
||||||
|
|
||||||
total_memory_required = {}
|
total_memory_required = {}
|
||||||
for loaded_model in models_to_load:
|
for loaded_model in models_to_load:
|
||||||
if unload_model_clones(loaded_model.model, unload_weights_only=True, force_unload=False): # unload clones where the weights are different
|
if unload_model_clones(loaded_model.model, unload_weights_only=True, force_unload=False): # unload clones where the weights are different
|
||||||
@ -472,7 +486,12 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False):
|
|||||||
|
|
||||||
for device in total_memory_required:
|
for device in total_memory_required:
|
||||||
if device != torch.device("cpu"):
|
if device != torch.device("cpu"):
|
||||||
free_memory(total_memory_required[device] * 1.3 + extra_mem, device, models_already_loaded)
|
models_freed += free_memory(total_memory_required[device] * 1.3 + extra_mem, device, models_already_loaded)
|
||||||
|
|
||||||
|
span.set_attribute("models_to_load", list(map(str, models_to_load)))
|
||||||
|
span.set_attribute("models_freed", list(map(str, models_freed)))
|
||||||
|
|
||||||
|
logging.info(f"Models loaded: {','.join(map(str, models_to_load))}, models freed: {','.join(map(str, models_freed))}")
|
||||||
|
|
||||||
for loaded_model in models_to_load:
|
for loaded_model in models_to_load:
|
||||||
weights_unloaded = unload_model_clones(loaded_model.model, unload_weights_only=False, force_unload=False) # unload the rest of the clones where the weights can stay loaded
|
weights_unloaded = unload_model_clones(loaded_model.model, unload_weights_only=False, force_unload=False) # unload the rest of the clones where the weights can stay loaded
|
||||||
|
|||||||
@ -2,13 +2,14 @@ import copy
|
|||||||
import inspect
|
import inspect
|
||||||
import logging
|
import logging
|
||||||
import uuid
|
import uuid
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from . import model_management
|
from . import model_management
|
||||||
from . import utils
|
from . import utils
|
||||||
from .types import UnetWrapperFunction
|
|
||||||
from .model_management_types import ModelManageable
|
from .model_management_types import ModelManageable
|
||||||
|
from .types import UnetWrapperFunction
|
||||||
|
|
||||||
|
|
||||||
def weight_decompose(dora_scale, weight, lora_diff, alpha, strength):
|
def weight_decompose(dora_scale, weight, lora_diff, alpha, strength):
|
||||||
@ -60,14 +61,16 @@ def set_model_options_post_cfg_function(model_options, post_cfg_function, disabl
|
|||||||
model_options["disable_cfg1_optimization"] = True
|
model_options["disable_cfg1_optimization"] = True
|
||||||
return model_options
|
return model_options
|
||||||
|
|
||||||
|
|
||||||
def set_model_options_pre_cfg_function(model_options, pre_cfg_function, disable_cfg1_optimization=False):
|
def set_model_options_pre_cfg_function(model_options, pre_cfg_function, disable_cfg1_optimization=False):
|
||||||
model_options["sampler_pre_cfg_function"] = model_options.get("sampler_pre_cfg_function", []) + [pre_cfg_function]
|
model_options["sampler_pre_cfg_function"] = model_options.get("sampler_pre_cfg_function", []) + [pre_cfg_function]
|
||||||
if disable_cfg1_optimization:
|
if disable_cfg1_optimization:
|
||||||
model_options["disable_cfg1_optimization"] = True
|
model_options["disable_cfg1_optimization"] = True
|
||||||
return model_options
|
return model_options
|
||||||
|
|
||||||
|
|
||||||
class ModelPatcher(ModelManageable):
|
class ModelPatcher(ModelManageable):
|
||||||
def __init__(self, model, load_device, offload_device, size=0, current_device=None, weight_inplace_update=False):
|
def __init__(self, model, load_device, offload_device, size=0, current_device=None, weight_inplace_update=False, ckpt_name: Optional[str] = None):
|
||||||
self.size = size
|
self.size = size
|
||||||
self.model = model
|
self.model = model
|
||||||
self.patches = {}
|
self.patches = {}
|
||||||
@ -87,6 +90,7 @@ class ModelPatcher(ModelManageable):
|
|||||||
self.weight_inplace_update = weight_inplace_update
|
self.weight_inplace_update = weight_inplace_update
|
||||||
self.model_lowvram = False
|
self.model_lowvram = False
|
||||||
self.patches_uuid = uuid.uuid4()
|
self.patches_uuid = uuid.uuid4()
|
||||||
|
self.ckpt_name = ckpt_name
|
||||||
self._lowvram_patch_counter = 0
|
self._lowvram_patch_counter = 0
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@ -105,6 +109,7 @@ class ModelPatcher(ModelManageable):
|
|||||||
|
|
||||||
def clone(self):
|
def clone(self):
|
||||||
n = ModelPatcher(self.model, self.load_device, self.offload_device, self.size, self._current_device, weight_inplace_update=self.weight_inplace_update)
|
n = ModelPatcher(self.model, self.load_device, self.offload_device, self.size, self._current_device, weight_inplace_update=self.weight_inplace_update)
|
||||||
|
n.ckpt_name = self.ckpt_name
|
||||||
n.patches = {}
|
n.patches = {}
|
||||||
for k in self.patches:
|
for k in self.patches:
|
||||||
n.patches[k] = self.patches[k][:]
|
n.patches[k] = self.patches[k][:]
|
||||||
@ -578,3 +583,9 @@ class ModelPatcher(ModelManageable):
|
|||||||
@property
|
@property
|
||||||
def current_device(self) -> torch.device:
|
def current_device(self) -> torch.device:
|
||||||
return self._current_device
|
return self._current_device
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
if self.ckpt_name is not None:
|
||||||
|
return f"<ModelPatcher for {self.ckpt_name} ({self.model.__class__.__name__})>"
|
||||||
|
else:
|
||||||
|
return f"<ModelPatcher for {self.model.__class__.__name__}>"
|
||||||
|
|||||||
@ -1,10 +1,11 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing_extensions import TypedDict, NotRequired, Generic
|
|
||||||
from typing import Union, Optional, Sequence, Dict, ClassVar, Protocol, Tuple, TypeVar, Any, Literal, \
|
from typing import Union, Optional, Sequence, Dict, ClassVar, Protocol, Tuple, TypeVar, Any, Literal, \
|
||||||
Callable, List, Type
|
Callable, List, Type
|
||||||
|
|
||||||
|
from typing_extensions import TypedDict, NotRequired
|
||||||
|
|
||||||
T = TypeVar('T')
|
T = TypeVar('T')
|
||||||
|
|
||||||
|
|
||||||
@ -51,7 +52,7 @@ StringSpec = Tuple[Literal["STRING"], StringSpecOptions]
|
|||||||
|
|
||||||
BooleanSpec = Tuple[Literal["BOOLEAN"], BoolSpecOptions]
|
BooleanSpec = Tuple[Literal["BOOLEAN"], BoolSpecOptions]
|
||||||
|
|
||||||
ChoiceSpec = Tuple[Union[Sequence[str], Sequence[float], Sequence[int]]]
|
ChoiceSpec = Tuple[Union[List[str], List[float], List[int], Tuple[str, ...], Tuple[float, ...], Tuple[int, ...]]]
|
||||||
|
|
||||||
NonPrimitiveTypeSpec = Tuple[CommonReturnTypes, Any]
|
NonPrimitiveTypeSpec = Tuple[CommonReturnTypes, Any]
|
||||||
|
|
||||||
@ -73,6 +74,7 @@ ValidateInputsMethod = Optional[Callable[..., Union[bool, str]]]
|
|||||||
|
|
||||||
IsChangedMethod = Callable[[Type[Any], ...], str]
|
IsChangedMethod = Callable[[Type[Any], ...], str]
|
||||||
|
|
||||||
|
|
||||||
class FunctionReturnsUIVariables(TypedDict):
|
class FunctionReturnsUIVariables(TypedDict):
|
||||||
ui: dict
|
ui: dict
|
||||||
result: NotRequired[Sequence[Any]]
|
result: NotRequired[Sequence[Any]]
|
||||||
@ -123,6 +125,10 @@ class CustomNode(Protocol):
|
|||||||
|
|
||||||
IS_CHANGED: Optional[ClassVar[IsChangedMethod]]
|
IS_CHANGED: Optional[ClassVar[IsChangedMethod]]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def __call__(cls, *args, **kwargs) -> 'CustomNode':
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ExportedNodes:
|
class ExportedNodes:
|
||||||
|
|||||||
18
comfy/sd.py
18
comfy/sd.py
@ -2,6 +2,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import dataclasses
|
import dataclasses
|
||||||
import logging
|
import logging
|
||||||
|
import os.path
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Any, Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
@ -16,18 +17,18 @@ from . import model_detection
|
|||||||
from . import model_management
|
from . import model_management
|
||||||
from . import model_patcher
|
from . import model_patcher
|
||||||
from . import model_sampling
|
from . import model_sampling
|
||||||
|
from . import sa_t5
|
||||||
from . import sd1_clip
|
from . import sd1_clip
|
||||||
from . import sd2_clip
|
from . import sd2_clip
|
||||||
|
from . import sd3_clip
|
||||||
from . import sdxl_clip
|
from . import sdxl_clip
|
||||||
from . import utils
|
from . import utils
|
||||||
|
from .ldm.audio.autoencoder import AudioOobleckVAE
|
||||||
from .ldm.cascade.stage_a import StageA
|
from .ldm.cascade.stage_a import StageA
|
||||||
from .ldm.cascade.stage_c_coder import StageC_coder
|
from .ldm.cascade.stage_c_coder import StageC_coder
|
||||||
from .ldm.models.autoencoder import AutoencoderKL, AutoencodingEngine
|
from .ldm.models.autoencoder import AutoencoderKL, AutoencodingEngine
|
||||||
from .ldm.audio.autoencoder import AudioOobleckVAE
|
|
||||||
from .t2i_adapter import adapter
|
from .t2i_adapter import adapter
|
||||||
from .taesd import taesd
|
from .taesd import taesd
|
||||||
from . import sd3_clip
|
|
||||||
from . import sa_t5
|
|
||||||
from .text_encoders import aura_t5
|
from .text_encoders import aura_t5
|
||||||
|
|
||||||
|
|
||||||
@ -228,7 +229,7 @@ class VAE:
|
|||||||
self.latent_channels = 64
|
self.latent_channels = 64
|
||||||
self.output_channels = 2
|
self.output_channels = 2
|
||||||
self.upscale_ratio = 2048
|
self.upscale_ratio = 2048
|
||||||
self.downscale_ratio = 2048
|
self.downscale_ratio = 2048
|
||||||
self.process_output = lambda audio: audio
|
self.process_output = lambda audio: audio
|
||||||
self.process_input = lambda audio: audio
|
self.process_input = lambda audio: audio
|
||||||
self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
|
self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
|
||||||
@ -302,7 +303,7 @@ class VAE:
|
|||||||
|
|
||||||
def encode_tiled_1d(self, samples, tile_x=128 * 2048, overlap=32 * 2048):
|
def encode_tiled_1d(self, samples, tile_x=128 * 2048, overlap=32 * 2048):
|
||||||
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).float()
|
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).float()
|
||||||
return utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=(1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device)
|
return utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=(1 / self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device)
|
||||||
|
|
||||||
def decode(self, samples_in):
|
def decode(self, samples_in):
|
||||||
try:
|
try:
|
||||||
@ -558,9 +559,8 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
|
|||||||
logging.debug("left over keys: {}".format(left_over))
|
logging.debug("left over keys: {}".format(left_over))
|
||||||
|
|
||||||
if output_model:
|
if output_model:
|
||||||
_model_patcher = model_patcher.ModelPatcher(model, load_device=load_device, offload_device=model_management.unet_offload_device(), current_device=inital_load_device)
|
_model_patcher = model_patcher.ModelPatcher(model, load_device=load_device, offload_device=model_management.unet_offload_device(), current_device=inital_load_device, ckpt_name=os.path.basename(ckpt_path))
|
||||||
if inital_load_device != torch.device("cpu"):
|
if inital_load_device != torch.device("cpu"):
|
||||||
logging.info("loaded straight to GPU")
|
|
||||||
model_management.load_model_gpu(_model_patcher)
|
model_management.load_model_gpu(_model_patcher)
|
||||||
|
|
||||||
return (_model_patcher, clip, vae, clipvision)
|
return (_model_patcher, clip, vae, clipvision)
|
||||||
@ -568,7 +568,7 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
|
|||||||
|
|
||||||
def load_unet_state_dict(sd): # load unet in diffusers or regular format
|
def load_unet_state_dict(sd): # load unet in diffusers or regular format
|
||||||
|
|
||||||
#Allow loading unets from checkpoint files
|
# Allow loading unets from checkpoint files
|
||||||
diffusion_model_prefix = model_detection.unet_prefix_from_state_dict(sd)
|
diffusion_model_prefix = model_detection.unet_prefix_from_state_dict(sd)
|
||||||
temp_sd = utils.state_dict_prefix_replace(sd, {diffusion_model_prefix: ""}, filter_keys=True)
|
temp_sd = utils.state_dict_prefix_replace(sd, {diffusion_model_prefix: ""}, filter_keys=True)
|
||||||
if len(temp_sd) > 0:
|
if len(temp_sd) > 0:
|
||||||
@ -583,7 +583,7 @@ def load_unet_state_dict(sd): # load unet in diffusers or regular format
|
|||||||
new_sd = sd
|
new_sd = sd
|
||||||
else:
|
else:
|
||||||
new_sd = model_detection.convert_diffusers_mmdit(sd, "")
|
new_sd = model_detection.convert_diffusers_mmdit(sd, "")
|
||||||
if new_sd is not None: #diffusers mmdit
|
if new_sd is not None: # diffusers mmdit
|
||||||
model_config = model_detection.model_config_from_unet(new_sd, "")
|
model_config = model_detection.model_config_from_unet(new_sd, "")
|
||||||
if model_config is None:
|
if model_config is None:
|
||||||
return None
|
return None
|
||||||
|
|||||||
@ -46,6 +46,9 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
def test_known_repos(tmp_path_factory):
|
def test_known_repos(tmp_path_factory):
|
||||||
|
prev_hub_cache = os.getenv("HF_HUB_CACHE")
|
||||||
|
os.environ["HF_HUB_CACHE"] = str(tmp_path_factory.mktemp("huggingface_root_cache"))
|
||||||
|
|
||||||
from comfy.cmd import folder_paths
|
from comfy.cmd import folder_paths
|
||||||
from comfy.cmd.folder_paths import FolderPathsTuple
|
from comfy.cmd.folder_paths import FolderPathsTuple
|
||||||
from comfy.model_downloader import get_huggingface_repo_list, \
|
from comfy.model_downloader import get_huggingface_repo_list, \
|
||||||
@ -57,9 +60,10 @@ def test_known_repos(tmp_path_factory):
|
|||||||
test_repo_id = "doctorpangloss/comfyui_downloader_test"
|
test_repo_id = "doctorpangloss/comfyui_downloader_test"
|
||||||
prev_huggingface = folder_paths.folder_names_and_paths["huggingface"]
|
prev_huggingface = folder_paths.folder_names_and_paths["huggingface"]
|
||||||
prev_huggingface_cache = folder_paths.folder_names_and_paths["huggingface_cache"]
|
prev_huggingface_cache = folder_paths.folder_names_and_paths["huggingface_cache"]
|
||||||
prev_hub_cache = os.getenv("HF_HUB_CACHE")
|
|
||||||
_delete_repo_from_huggingface_cache(test_repo_id)
|
_delete_repo_from_huggingface_cache(test_repo_id)
|
||||||
_delete_repo_from_huggingface_cache(test_repo_id, test_cache_dir)
|
_delete_repo_from_huggingface_cache(test_repo_id, test_cache_dir)
|
||||||
|
args.disable_known_models = False
|
||||||
try:
|
try:
|
||||||
folder_paths.folder_names_and_paths["huggingface"] += FolderPathsTuple("huggingface", [test_local_dir], {""})
|
folder_paths.folder_names_and_paths["huggingface"] += FolderPathsTuple("huggingface", [test_local_dir], {""})
|
||||||
folder_paths.folder_names_and_paths["huggingface_cache"] += FolderPathsTuple("huggingface_cache", [test_cache_dir], {""})
|
folder_paths.folder_names_and_paths["huggingface_cache"] += FolderPathsTuple("huggingface_cache", [test_cache_dir], {""})
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user