Improve logging and tracing for validation errors

doctorpangloss 2024-07-16 12:26:30 -07:00
parent a20bf8134d
commit 72baecad87
12 changed files with 167 additions and 79 deletions

View File

@@ -1,13 +1,15 @@
+import glob
 import json
 import os
 import re
-import uuid
-import glob
 import shutil
+import uuid
 from aiohttp import web
-from .app_settings import AppSettings
 from ..cli_args import args
 from ..cmd.folder_paths import user_directory
+from .app_settings import AppSettings
 class UserManager():
@@ -17,9 +19,6 @@ class UserManager():
         self.settings = AppSettings(self)
         if not os.path.exists(user_directory):
             os.mkdir(user_directory)
-            if not args.multi_user:
-                print("****** User settings have been changed to be stored on the server instead of browser storage. ******")
-                print("****** For multi-user setups add the --multi-user CLI argument to enable multiple user profiles. ******")
         if args.multi_user:
             if os.path.isfile(self.users_file):
@@ -129,7 +128,7 @@ class UserManager():
            return web.json_response(results)
-        def get_user_data_path(request, check_exists = False, param = "file"):
+        def get_user_data_path(request, check_exists=False, param="file"):
            file = request.match_info.get(param, None)
            if not file:
                return web.Response(status=400)

View File

@@ -8,22 +8,22 @@ import sys
 import threading
 import traceback
 import typing
-from typing import List, Optional, Tuple, Union
+from typing import List, Optional, Tuple
 import lazy_object_proxy
 import torch
 from opentelemetry.trace import get_current_span, StatusCode, Status
-from typing_extensions import TypedDict
 from .main_pre import tracer
 from .. import interruption
 from .. import model_management
 from ..component_model.abstract_prompt_queue import AbstractPromptQueue
-from ..component_model.executor_types import ExecutorToClientProgress
+from ..component_model.executor_types import ExecutorToClientProgress, ValidationTuple, ValidateInputsTuple, \
+    ValidationErrorDict, NodeErrorsDictValue
 from ..component_model.queue_types import QueueTuple, HistoryEntry, QueueItem, MAXIMUM_HISTORY_SIZE, ExecutionStatus
 from ..execution_context import new_execution_context, ExecutionContext
 from ..nodes.package import import_all_nodes_in_workspace
-from ..nodes.package_typing import ExportedNodes
+from ..nodes.package_typing import ExportedNodes, InputTypeSpec, FloatSpecOptions, IntSpecOptions
 # ideally this would be passed in from main, but the way this is authored, we can't easily pass nodes down to the
 # various functions that are declared here. It should have been a context in the first place.
@@ -492,7 +492,7 @@ class PromptExecutor:
         model_management.unload_all_models()
-def validate_inputs(prompt, item, validated) -> Tuple[bool, typing.List[dict], typing.Any]:
+def validate_inputs(prompt, item, validated: typing.Dict[str, ValidateInputsTuple]) -> ValidateInputsTuple:
     # todo: this should check if LoadImage / LoadImageMask paths exist
     # todo: or, nodes should provide a way to validate their values
     unique_id = item
@@ -506,11 +506,12 @@ def validate_inputs(prompt, item, validated) -> Tuple[bool, typing.List[dict], t
     class_inputs = obj_class.INPUT_TYPES()
     required_inputs = class_inputs['required']
+    error: ValidationErrorDict
     errors = []
     valid = True
     # todo: investigate if these are at the right indent level
-    info = None
+    info: Optional[InputTypeSpec] = None
     val = None
     validate_function_inputs = []
@@ -531,7 +532,7 @@ def validate_inputs(prompt, item, validated) -> Tuple[bool, typing.List[dict], t
             continue
         val = inputs[x]
-        info = required_inputs[x]
+        info: InputTypeSpec = required_inputs[x]
         type_input = info[0]
         if isinstance(val, list):
             if len(val) != 2:
@@ -593,7 +594,7 @@ def validate_inputs(prompt, item, validated) -> Tuple[bool, typing.List[dict], t
                         "linked_node": val
                     }
                 }]
-                validated[o_id] = (False, reasons, o_id)
+                validated[o_id] = ValidateInputsTuple(False, reasons, o_id)
                 continue
         else:
             try:
@@ -622,10 +623,11 @@ def validate_inputs(prompt, item, validated) -> Tuple[bool, typing.List[dict], t
                 continue
             if len(info) > 1:
-                if "min" in info[1] and val < info[1]["min"]:
+                has_min_max: IntSpecOptions | FloatSpecOptions = info[1]
+                if "min" in has_min_max and val < has_min_max["min"]:
                     error = {
                         "type": "value_smaller_than_min",
-                        "message": "Value {} smaller than min of {}".format(val, info[1]["min"]),
+                        "message": "Value {} smaller than min of {}".format(val, has_min_max["min"]),
                         "details": f"{x}",
                         "extra_info": {
                             "input_name": x,
@@ -635,10 +637,10 @@ def validate_inputs(prompt, item, validated) -> Tuple[bool, typing.List[dict], t
                     }
                     errors.append(error)
                     continue
-                if "max" in info[1] and val > info[1]["max"]:
+                if "max" in has_min_max and val > has_min_max["max"]:
                     error = {
                         "type": "value_bigger_than_max",
-                        "message": "Value {} bigger than max of {}".format(val, info[1]["max"]),
+                        "message": "Value {} bigger than max of {}".format(val, has_min_max["max"]),
                         "details": f"{x}",
                         "extra_info": {
                             "input_name": x,
@@ -706,9 +708,9 @@ def validate_inputs(prompt, item, validated) -> Tuple[bool, typing.List[dict], t
                     continue
     if len(errors) > 0 or valid is not True:
-        ret = (False, errors, unique_id)
+        ret = ValidateInputsTuple(False, errors, unique_id)
     else:
-        ret = (True, [], unique_id)
+        ret = ValidateInputsTuple(True, [], unique_id)
     validated[unique_id] = ret
     return ret
@@ -721,23 +723,25 @@ def full_type_name(klass):
     return module + '.' + klass.__qualname__
-class ValidationErrorExtraInfoDict(TypedDict):
-    exception_type: str
-    traceback: List[str]
-class ValidationErrorDict(TypedDict):
-    type: str
-    message: str
-    details: str
-    extra_info: ValidationErrorExtraInfoDict | dict
-ValidationTuple = typing.Tuple[bool, Optional[ValidationErrorDict], typing.List[str], Union[dict, list]]
 @tracer.start_as_current_span("Validate Prompt")
 def validate_prompt(prompt: typing.Mapping[str, typing.Any]) -> ValidationTuple:
+    res = _validate_prompt(prompt)
+    if not res.valid:
+        span = get_current_span()
+        span.set_status(Status(StatusCode.ERROR))
+        if res.error is not None and len(res.error) > 0:
+            span.set_attributes({
+                f"error.{k}": v for k, v in res.error.items()
+            })
+        if len(res.node_errors) > 0:
+            for node_id, node_error in res.node_errors.items():
+                for node_error_field, node_error_value in node_error.items():
+                    if isinstance(node_error_value, (str, bool, int, float)):
+                        span.set_attribute(f"node_errors.{node_id}.{node_error_field}", node_error_value)
+    return res
+def _validate_prompt(prompt: typing.Mapping[str, typing.Any]) -> ValidationTuple:
     outputs = set()
     for x in prompt:
         if 'class_type' not in prompt[x]:
@@ -747,7 +751,7 @@ def validate_prompt(prompt: typing.Mapping[str, typing.Any]) -> ValidationTuple:
                 "details": f"Node ID '#{x}'",
                 "extra_info": {}
             }
-            return (False, error, [], [])
+            return ValidationTuple(False, error, [], [])
         class_type = prompt[x]['class_type']
         class_ = nodes.NODE_CLASS_MAPPINGS.get(class_type, None)
@@ -758,7 +762,7 @@ def validate_prompt(prompt: typing.Mapping[str, typing.Any]) -> ValidationTuple:
                 "details": f"Node ID '#{x}'",
                 "extra_info": {}
             }
-            return (False, error, [], [])
+            return ValidationTuple(False, error, [], [])
         if hasattr(class_, 'OUTPUT_NODE') and class_.OUTPUT_NODE is True:
             outputs.add(x)
@@ -770,15 +774,15 @@ def validate_prompt(prompt: typing.Mapping[str, typing.Any]) -> ValidationTuple:
             "details": "",
             "extra_info": {}
         }
-        return False, error, [], []
+        return ValidationTuple(False, error, [], [])
     good_outputs = set()
     errors = []
-    node_errors = {}
-    validated = {}
+    node_errors: typing.Dict[str, NodeErrorsDictValue] = {}
+    validated: typing.Dict[str, ValidateInputsTuple] = {}
     for o in outputs:
         valid = False
-        reasons = []
+        reasons: List[ValidationErrorDict] = []
         try:
             m = validate_inputs(prompt, o, validated)
             valid = m[0]
@@ -796,7 +800,7 @@ def validate_prompt(prompt: typing.Mapping[str, typing.Any]) -> ValidationTuple:
                     "traceback": traceback.format_tb(tb)
                 }
             }]
-            validated[o] = (False, reasons, o)
+            validated[o] = ValidateInputsTuple(False, reasons, o)
         if valid is True:
             good_outputs.add(o)
@@ -841,9 +845,9 @@ def validate_prompt(prompt: typing.Mapping[str, typing.Any]) -> ValidationTuple:
             "extra_info": {}
         }
-        return False, error, list(good_outputs), node_errors
+        return ValidationTuple(False, error, list(good_outputs), node_errors)
-    return True, None, list(good_outputs), node_errors
+    return ValidationTuple(True, None, list(good_outputs), node_errors)
 class PromptQueue(AbstractPromptQueue):
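
A minimal, self-contained sketch (not part of the commit) of the span-tagging pattern the new validate_prompt wrapper uses, assuming only that an OpenTelemetry tracer is configured; the function name and error payload below are illustrative:

from opentelemetry import trace
from opentelemetry.trace import Status, StatusCode

tracer = trace.get_tracer(__name__)

def report_validation_failure(error: dict) -> None:
    # Mirror the wrapper: mark the active span as errored and copy scalar
    # error fields into attributes so traces can be filtered by error type.
    span = trace.get_current_span()
    span.set_status(Status(StatusCode.ERROR))
    span.set_attributes({
        f"error.{k}": v for k, v in error.items()
        if isinstance(v, (str, bool, int, float))
    })

with tracer.start_as_current_span("Validate Prompt"):
    report_validation_failure({"type": "prompt_no_outputs", "message": "Prompt has no outputs"})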

View File

@@ -7,9 +7,9 @@ import shutil
 import threading
 import time
-from .extra_model_paths import load_extra_path_config
 # main_pre must be the earliest import since it suppresses some spurious warnings
 from .main_pre import args
+from .extra_model_paths import load_extra_path_config
 from .. import model_management
 from ..analytics.analytics import initialize_event_tracking
 from ..cmd import cuda_malloc
@@ -223,7 +223,10 @@
 def entrypoint():
-    asyncio.run(main())
+    try:
+        asyncio.run(main())
+    except KeyboardInterrupt as keyboard_interrupt:
+        logging.info(f"Gracefully shutting down due to {keyboard_interrupt}")
 if __name__ == "__main__":

View File

@@ -492,9 +492,6 @@ class PromptServer(ExecutorToClientProgress):
         @routes.post("/prompt")
         async def post_prompt(request):
-            logging.info("got prompt")
-            resp_code = 200
-            out_string = ""
             json_data = await request.json()
             json_data = self.trigger_on_prompt(json_data)

View File

@@ -1,11 +1,13 @@
 from __future__ import annotations  # for Python 3.7-3.9
-from typing import Optional, Literal, Protocol, TypeAlias, Union, NamedTuple
+import typing
+from typing import Optional, Literal, Protocol, TypeAlias, Union, NamedTuple, List
 import PIL.Image
 from typing_extensions import NotRequired, TypedDict
 from .queue_types import BinaryEventTypes
+from ..nodes.package_typing import InputTypeSpec
 class ExecInfo(TypedDict):
@ -85,3 +87,46 @@ class ExecutorToClientProgress(Protocol):
:return: :return:
""" """
pass pass
ExceptionTypes = Literal["custom_validation_failed", "value_not_in_list", "value_bigger_than_max", "value_smaller_than_min", "invalid_input_type", "exception_during_inner_validation", "return_type_mismatch", "bad_linked_input", "required_input_missing", "invalid_prompt", "prompt_no_outputs", "exception_during_validation", "prompt_outputs_failed_validation"]
class ValidationErrorExtraInfoDict(TypedDict, total=False):
exception_type: NotRequired[str]
traceback: NotRequired[List[str]]
dependent_outputs: NotRequired[List[str]]
class_type: NotRequired[str]
input_name: NotRequired[str]
input_config: NotRequired[typing.Dict[str, InputTypeSpec]]
received_value: NotRequired[typing.Any]
linked_node: NotRequired[str]
traceback: NotRequired[str]
exception_message: NotRequired[str]
exception_type: NotRequired[str]
class ValidationErrorDict(TypedDict):
type: str
message: str
details: str
extra_info: list[typing.Never] | ValidationErrorExtraInfoDict
class NodeErrorsDictValue(TypedDict, total=False):
dependent_outputs: NotRequired[List[str]]
errors: List[ValidationErrorDict]
class_type: str
class ValidationTuple(typing.NamedTuple):
valid: bool
error: Optional[ValidationErrorDict]
good_output_node_ids: List[str]
node_errors: list[typing.Never] | typing.Dict[str, NodeErrorsDictValue]
class ValidateInputsTuple(typing.NamedTuple):
valid: bool
errors: List[ValidationErrorDict]
unique_id: str
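
A brief, hypothetical usage sketch (not part of the commit): because ValidateInputsTuple and ValidationTuple are NamedTuples, older call sites that index into the result keep working while new code can read the named fields; the error payload below is a simplified stand-in for ValidationErrorDict.

import typing

class ValidateInputsTuple(typing.NamedTuple):
    valid: bool
    errors: typing.List[dict]  # simplified stand-in for List[ValidationErrorDict]
    unique_id: str

result = ValidateInputsTuple(False, [{"type": "required_input_missing"}], "7")

# positional access, as pre-existing tuple-style callers expect
assert result[0] is result.valid
if not result.valid:
    print(f"node {result.unique_id} failed with {len(result.errors)} error(s)")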

View File

@@ -89,8 +89,6 @@ class BaseModel(torch.nn.Module):
         self.adm_channels = 0
         self.concat_keys = ()
-        logging.info("model_type {}".format(model_type.name))
-        logging.debug("adm {}".format(self.adm_channels))
     def apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs):
         sigma = t

View File

@@ -315,6 +315,8 @@ KNOWN_CONTROLNETS: Final[List[Downloadable]] = [
     HuggingFile("xinsir/controlnet-openpose-sdxl-1.0", "diffusion_pytorch_model.safetensors", save_with_filename="xinsir-controlnet-openpose-sdxl-1.0.safetensors"),
     HuggingFile("xinsir/anime-painter", "diffusion_pytorch_model.safetensors", save_with_filename="xinsir-anime-painter-scribble-sdxl-1.0.safetensors"),
     HuggingFile("TheMistoAI/MistoLine", "mistoLine_rank256.safetensors"),
+    HuggingFile("xinsir/controlnet-union-sdxl-1.0", "diffusion_pytorch_model_promax.safetensors", save_with_filename="xinsir-controlnet-union-sdxl-1.0-promax.safetensors"),
+    HuggingFile("xinsir/controlnet-union-sdxl-1.0", "diffusion_pytorch_model.safetensors", save_with_filename="xinsir-controlnet-union-sdxl-1.0.safetensors"),
 ]
 KNOWN_DIFF_CONTROLNETS: Final[List[Downloadable]] = [

View File

@@ -10,9 +10,11 @@ from typing import Literal, List
 import psutil
 import torch
+from opentelemetry.trace import get_current_span
 from . import interruption
 from .cli_args import args
+from .cmd.main_pre import tracer
 from .model_management_types import ModelManageable
 model_management_lock = RLock()
@@ -356,6 +358,12 @@ class LoadedModel:
     def __eq__(self, other):
         return self.model is other.model
+    def __str__(self):
+        if self.model is not None:
+            return f"<LoadedModel {str(self.model)}>"
+        else:
+            return f"<LoadedModel>"
 def minimum_inference_memory():
     return (1024 * 1024 * 1024)
@@ -392,9 +400,12 @@ def unload_model_clones(model, unload_weights_only=True, force_unload=True) -> b
     return unload_weight
-def free_memory(memory_required, device, keep_loaded=[]):
+@tracer.start_as_current_span("Free Memory")
+def free_memory(memory_required, device, keep_loaded=[]) -> List[LoadedModel]:
+    span = get_current_span()
+    span.set_attribute("memory_required", memory_required)
     with model_management_lock:
-        unloaded_model = []
+        unloaded_models: List[LoadedModel] = []
         can_unload = []
         for i in range(len(current_loaded_models) - 1, -1, -1):
@@ -410,12 +421,12 @@ def free_memory(memory_required, device, keep_loaded=[]):
                 if get_free_memory(device) > memory_required:
                     break
            current_loaded_models[i].model_unload()
-            unloaded_model.append(i)
+            unloaded_models.append(i)
-        for i in sorted(unloaded_model, reverse=True):
+        for i in sorted(unloaded_models, reverse=True):
            current_loaded_models.pop(i)
-        if len(unloaded_model) > 0:
+        if len(unloaded_models) > 0:
            soft_empty_cache()
        else:
            if vram_state != VRAMState.HIGH_VRAM:
@@ -423,10 +434,16 @@ def free_memory(memory_required, device, keep_loaded=[]):
                 if mem_free_torch > mem_free_total * 0.25:
                     soft_empty_cache()
+    span.set_attribute("unloaded_models", list(map(str, unloaded_models)))
+    return unloaded_models
+@tracer.start_as_current_span("Load Models GPU")
 def load_models_gpu(models, memory_required=0, force_patch_weights=False):
     global vram_state
+    span = get_current_span()
+    if memory_required != 0:
+        span.set_attribute("memory_required", memory_required)
     with model_management_lock:
         inference_memory = minimum_inference_memory()
         extra_mem = max(inference_memory, memory_required)
@@ -452,19 +469,16 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False):
                 loaded.currently_used = True
                 models_already_loaded.append(loaded)
             if loaded is None:
-                if hasattr(x, "model"):
-                    logging.info(f"Requested to load {x.model.__class__.__name__}")
                 models_to_load.append(loaded_model)
+        models_freed: List[LoadedModel] = []
         if len(models_to_load) == 0:
             devs = set(map(lambda a: a.device, models_already_loaded))
             for d in devs:
                 if d != torch.device("cpu"):
-                    free_memory(extra_mem, d, models_already_loaded)
+                    models_freed += free_memory(extra_mem, d, models_already_loaded)
             return
-        logging.info(f"Loading {len(models_to_load)} new model{'s' if len(models_to_load) > 1 else ''}")
         total_memory_required = {}
         for loaded_model in models_to_load:
             if unload_model_clones(loaded_model.model, unload_weights_only=True, force_unload=False):  # unload clones where the weights are different
@@ -472,7 +486,12 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False):
         for device in total_memory_required:
             if device != torch.device("cpu"):
-                free_memory(total_memory_required[device] * 1.3 + extra_mem, device, models_already_loaded)
+                models_freed += free_memory(total_memory_required[device] * 1.3 + extra_mem, device, models_already_loaded)
+        span.set_attribute("models_to_load", list(map(str, models_to_load)))
+        span.set_attribute("models_freed", list(map(str, models_freed)))
+        logging.info(f"Models loaded: {','.join(map(str, models_to_load))}, models freed: {','.join(map(str, models_freed))}")
         for loaded_model in models_to_load:
             weights_unloaded = unload_model_clones(loaded_model.model, unload_weights_only=False, force_unload=False)  # unload the rest of the clones where the weights can stay loaded
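
A minimal sketch (not part of the commit) of the decorator-plus-attributes tracing pattern that free_memory and load_models_gpu now follow, assuming an OpenTelemetry tracer is available; the function and model names here are illustrative:

from opentelemetry import trace

tracer = trace.get_tracer(__name__)

@tracer.start_as_current_span("Free Memory")
def free_memory_sketch(memory_required: int, loaded: list) -> list:
    span = trace.get_current_span()
    span.set_attribute("memory_required", memory_required)
    unloaded = list(loaded)  # stand-in for the real unload loop
    # span attribute values must be primitives or sequences of primitives,
    # hence the str() mapping over the unloaded models
    span.set_attribute("unloaded_models", list(map(str, unloaded)))
    return unloaded

freed = free_memory_sketch(1024, ["<LoadedModel SDXL>"])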

View File

@@ -2,13 +2,14 @@ import copy
 import inspect
 import logging
 import uuid
+from typing import Optional
 import torch
 from . import model_management
 from . import utils
-from .types import UnetWrapperFunction
 from .model_management_types import ModelManageable
+from .types import UnetWrapperFunction
 def weight_decompose(dora_scale, weight, lora_diff, alpha, strength):
@@ -60,14 +61,16 @@ def set_model_options_post_cfg_function(model_options, post_cfg_function, disabl
         model_options["disable_cfg1_optimization"] = True
     return model_options
 def set_model_options_pre_cfg_function(model_options, pre_cfg_function, disable_cfg1_optimization=False):
     model_options["sampler_pre_cfg_function"] = model_options.get("sampler_pre_cfg_function", []) + [pre_cfg_function]
     if disable_cfg1_optimization:
         model_options["disable_cfg1_optimization"] = True
     return model_options
 class ModelPatcher(ModelManageable):
-    def __init__(self, model, load_device, offload_device, size=0, current_device=None, weight_inplace_update=False):
+    def __init__(self, model, load_device, offload_device, size=0, current_device=None, weight_inplace_update=False, ckpt_name: Optional[str] = None):
         self.size = size
         self.model = model
         self.patches = {}
@@ -87,6 +90,7 @@ class ModelPatcher(ModelManageable):
         self.weight_inplace_update = weight_inplace_update
         self.model_lowvram = False
         self.patches_uuid = uuid.uuid4()
+        self.ckpt_name = ckpt_name
         self._lowvram_patch_counter = 0
     @property
@@ -105,6 +109,7 @@ class ModelPatcher(ModelManageable):
     def clone(self):
         n = ModelPatcher(self.model, self.load_device, self.offload_device, self.size, self._current_device, weight_inplace_update=self.weight_inplace_update)
+        n.ckpt_name = self.ckpt_name
         n.patches = {}
         for k in self.patches:
             n.patches[k] = self.patches[k][:]
@@ -578,3 +583,9 @@ class ModelPatcher(ModelManageable):
     @property
     def current_device(self) -> torch.device:
         return self._current_device
+    def __str__(self):
+        if self.ckpt_name is not None:
+            return f"<ModelPatcher for {self.ckpt_name} ({self.model.__class__.__name__})>"
+        else:
+            return f"<ModelPatcher for {self.model.__class__.__name__}>"

View File

@@ -1,10 +1,11 @@
 from __future__ import annotations
 from dataclasses import dataclass, field
-from typing_extensions import TypedDict, NotRequired, Generic
 from typing import Union, Optional, Sequence, Dict, ClassVar, Protocol, Tuple, TypeVar, Any, Literal, \
     Callable, List, Type
+from typing_extensions import TypedDict, NotRequired
 T = TypeVar('T')
@@ -51,7 +52,7 @@ StringSpec = Tuple[Literal["STRING"], StringSpecOptions]
 BooleanSpec = Tuple[Literal["BOOLEAN"], BoolSpecOptions]
-ChoiceSpec = Tuple[Union[Sequence[str], Sequence[float], Sequence[int]]]
+ChoiceSpec = Tuple[Union[List[str], List[float], List[int], Tuple[str, ...], Tuple[float, ...], Tuple[int, ...]]]
 NonPrimitiveTypeSpec = Tuple[CommonReturnTypes, Any]
@@ -73,6 +74,7 @@ ValidateInputsMethod = Optional[Callable[..., Union[bool, str]]]
 IsChangedMethod = Callable[[Type[Any], ...], str]
 class FunctionReturnsUIVariables(TypedDict):
     ui: dict
     result: NotRequired[Sequence[Any]]
@@ -123,6 +125,10 @@ class CustomNode(Protocol):
     IS_CHANGED: Optional[ClassVar[IsChangedMethod]]
+    @classmethod
+    def __call__(cls, *args, **kwargs) -> 'CustomNode':
+        ...
 @dataclass
 class ExportedNodes:

View File

@@ -2,6 +2,7 @@ from __future__ import annotations
 import dataclasses
 import logging
+import os.path
 from enum import Enum
 from typing import Any, Optional
@@ -16,18 +17,18 @@ from . import model_detection
 from . import model_management
 from . import model_patcher
 from . import model_sampling
+from . import sa_t5
 from . import sd1_clip
 from . import sd2_clip
+from . import sd3_clip
 from . import sdxl_clip
 from . import utils
+from .ldm.audio.autoencoder import AudioOobleckVAE
 from .ldm.cascade.stage_a import StageA
 from .ldm.cascade.stage_c_coder import StageC_coder
 from .ldm.models.autoencoder import AutoencoderKL, AutoencodingEngine
-from .ldm.audio.autoencoder import AudioOobleckVAE
 from .t2i_adapter import adapter
 from .taesd import taesd
-from . import sd3_clip
-from . import sa_t5
 from .text_encoders import aura_t5
@@ -228,7 +229,7 @@ class VAE:
             self.latent_channels = 64
             self.output_channels = 2
             self.upscale_ratio = 2048
             self.downscale_ratio = 2048
             self.process_output = lambda audio: audio
             self.process_input = lambda audio: audio
             self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
@@ -302,7 +303,7 @@ class VAE:
     def encode_tiled_1d(self, samples, tile_x=128 * 2048, overlap=32 * 2048):
         encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).float()
-        return utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=(1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device)
+        return utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=(1 / self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device)
     def decode(self, samples_in):
         try:
@@ -558,9 +559,8 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
         logging.debug("left over keys: {}".format(left_over))
     if output_model:
-        _model_patcher = model_patcher.ModelPatcher(model, load_device=load_device, offload_device=model_management.unet_offload_device(), current_device=inital_load_device)
+        _model_patcher = model_patcher.ModelPatcher(model, load_device=load_device, offload_device=model_management.unet_offload_device(), current_device=inital_load_device, ckpt_name=os.path.basename(ckpt_path))
         if inital_load_device != torch.device("cpu"):
-            logging.info("loaded straight to GPU")
             model_management.load_model_gpu(_model_patcher)
     return (_model_patcher, clip, vae, clipvision)
@@ -568,7 +568,7 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
 def load_unet_state_dict(sd): # load unet in diffusers or regular format
-    #Allow loading unets from checkpoint files
+    # Allow loading unets from checkpoint files
     diffusion_model_prefix = model_detection.unet_prefix_from_state_dict(sd)
     temp_sd = utils.state_dict_prefix_replace(sd, {diffusion_model_prefix: ""}, filter_keys=True)
     if len(temp_sd) > 0:
@@ -583,7 +583,7 @@ def load_unet_state_dict(sd): # load unet in diffusers or regular format
             new_sd = sd
     else:
         new_sd = model_detection.convert_diffusers_mmdit(sd, "")
-        if new_sd is not None: #diffusers mmdit
+        if new_sd is not None:  # diffusers mmdit
             model_config = model_detection.model_config_from_unet(new_sd, "")
             if model_config is None:
                 return None

View File

@@ -46,6 +46,9 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 @pytest.mark.asyncio
 def test_known_repos(tmp_path_factory):
+    prev_hub_cache = os.getenv("HF_HUB_CACHE")
+    os.environ["HF_HUB_CACHE"] = str(tmp_path_factory.mktemp("huggingface_root_cache"))
     from comfy.cmd import folder_paths
     from comfy.cmd.folder_paths import FolderPathsTuple
     from comfy.model_downloader import get_huggingface_repo_list, \
@@ -57,9 +60,10 @@ def test_known_repos(tmp_path_factory):
     test_repo_id = "doctorpangloss/comfyui_downloader_test"
     prev_huggingface = folder_paths.folder_names_and_paths["huggingface"]
     prev_huggingface_cache = folder_paths.folder_names_and_paths["huggingface_cache"]
-    prev_hub_cache = os.getenv("HF_HUB_CACHE")
     _delete_repo_from_huggingface_cache(test_repo_id)
     _delete_repo_from_huggingface_cache(test_repo_id, test_cache_dir)
+    args.disable_known_models = False
     try:
         folder_paths.folder_names_and_paths["huggingface"] += FolderPathsTuple("huggingface", [test_local_dir], {""})
         folder_paths.folder_names_and_paths["huggingface_cache"] += FolderPathsTuple("huggingface_cache", [test_cache_dir], {""})