mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-11 23:00:51 +08:00
Improve logging and tracing for validation errors
This commit is contained in:
parent
a20bf8134d
commit
72baecad87
@ -1,13 +1,15 @@
|
||||
import glob
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import uuid
|
||||
import glob
|
||||
import shutil
|
||||
import uuid
|
||||
|
||||
from aiohttp import web
|
||||
|
||||
from .app_settings import AppSettings
|
||||
from ..cli_args import args
|
||||
from ..cmd.folder_paths import user_directory
|
||||
from .app_settings import AppSettings
|
||||
|
||||
|
||||
class UserManager():
|
||||
@ -17,9 +19,6 @@ class UserManager():
|
||||
self.settings = AppSettings(self)
|
||||
if not os.path.exists(user_directory):
|
||||
os.mkdir(user_directory)
|
||||
if not args.multi_user:
|
||||
print("****** User settings have been changed to be stored on the server instead of browser storage. ******")
|
||||
print("****** For multi-user setups add the --multi-user CLI argument to enable multiple user profiles. ******")
|
||||
|
||||
if args.multi_user:
|
||||
if os.path.isfile(self.users_file):
|
||||
@ -129,7 +128,7 @@ class UserManager():
|
||||
|
||||
return web.json_response(results)
|
||||
|
||||
def get_user_data_path(request, check_exists = False, param = "file"):
|
||||
def get_user_data_path(request, check_exists=False, param="file"):
|
||||
file = request.match_info.get(param, None)
|
||||
if not file:
|
||||
return web.Response(status=400)
|
||||
|
||||
@ -8,22 +8,22 @@ import sys
|
||||
import threading
|
||||
import traceback
|
||||
import typing
|
||||
from typing import List, Optional, Tuple, Union
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import lazy_object_proxy
|
||||
import torch
|
||||
from opentelemetry.trace import get_current_span, StatusCode, Status
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from .main_pre import tracer
|
||||
from .. import interruption
|
||||
from .. import model_management
|
||||
from ..component_model.abstract_prompt_queue import AbstractPromptQueue
|
||||
from ..component_model.executor_types import ExecutorToClientProgress
|
||||
from ..component_model.executor_types import ExecutorToClientProgress, ValidationTuple, ValidateInputsTuple, \
|
||||
ValidationErrorDict, NodeErrorsDictValue
|
||||
from ..component_model.queue_types import QueueTuple, HistoryEntry, QueueItem, MAXIMUM_HISTORY_SIZE, ExecutionStatus
|
||||
from ..execution_context import new_execution_context, ExecutionContext
|
||||
from ..nodes.package import import_all_nodes_in_workspace
|
||||
from ..nodes.package_typing import ExportedNodes
|
||||
from ..nodes.package_typing import ExportedNodes, InputTypeSpec, FloatSpecOptions, IntSpecOptions
|
||||
|
||||
# ideally this would be passed in from main, but the way this is authored, we can't easily pass nodes down to the
|
||||
# various functions that are declared here. It should have been a context in the first place.
|
||||
@ -492,7 +492,7 @@ class PromptExecutor:
|
||||
model_management.unload_all_models()
|
||||
|
||||
|
||||
def validate_inputs(prompt, item, validated) -> Tuple[bool, typing.List[dict], typing.Any]:
|
||||
def validate_inputs(prompt, item, validated: typing.Dict[str, ValidateInputsTuple]) -> ValidateInputsTuple:
|
||||
# todo: this should check if LoadImage / LoadImageMask paths exist
|
||||
# todo: or, nodes should provide a way to validate their values
|
||||
unique_id = item
|
||||
@ -506,11 +506,12 @@ def validate_inputs(prompt, item, validated) -> Tuple[bool, typing.List[dict], t
|
||||
class_inputs = obj_class.INPUT_TYPES()
|
||||
required_inputs = class_inputs['required']
|
||||
|
||||
error: ValidationErrorDict
|
||||
errors = []
|
||||
valid = True
|
||||
|
||||
# todo: investigate if these are at the right indent level
|
||||
info = None
|
||||
info: Optional[InputTypeSpec] = None
|
||||
val = None
|
||||
|
||||
validate_function_inputs = []
|
||||
@ -531,7 +532,7 @@ def validate_inputs(prompt, item, validated) -> Tuple[bool, typing.List[dict], t
|
||||
continue
|
||||
|
||||
val = inputs[x]
|
||||
info = required_inputs[x]
|
||||
info: InputTypeSpec = required_inputs[x]
|
||||
type_input = info[0]
|
||||
if isinstance(val, list):
|
||||
if len(val) != 2:
|
||||
@ -593,7 +594,7 @@ def validate_inputs(prompt, item, validated) -> Tuple[bool, typing.List[dict], t
|
||||
"linked_node": val
|
||||
}
|
||||
}]
|
||||
validated[o_id] = (False, reasons, o_id)
|
||||
validated[o_id] = ValidateInputsTuple(False, reasons, o_id)
|
||||
continue
|
||||
else:
|
||||
try:
|
||||
@ -622,10 +623,11 @@ def validate_inputs(prompt, item, validated) -> Tuple[bool, typing.List[dict], t
|
||||
continue
|
||||
|
||||
if len(info) > 1:
|
||||
if "min" in info[1] and val < info[1]["min"]:
|
||||
has_min_max: IntSpecOptions | FloatSpecOptions = info[1]
|
||||
if "min" in has_min_max and val < has_min_max["min"]:
|
||||
error = {
|
||||
"type": "value_smaller_than_min",
|
||||
"message": "Value {} smaller than min of {}".format(val, info[1]["min"]),
|
||||
"message": "Value {} smaller than min of {}".format(val, has_min_max["min"]),
|
||||
"details": f"{x}",
|
||||
"extra_info": {
|
||||
"input_name": x,
|
||||
@ -635,10 +637,10 @@ def validate_inputs(prompt, item, validated) -> Tuple[bool, typing.List[dict], t
|
||||
}
|
||||
errors.append(error)
|
||||
continue
|
||||
if "max" in info[1] and val > info[1]["max"]:
|
||||
if "max" in has_min_max and val > has_min_max["max"]:
|
||||
error = {
|
||||
"type": "value_bigger_than_max",
|
||||
"message": "Value {} bigger than max of {}".format(val, info[1]["max"]),
|
||||
"message": "Value {} bigger than max of {}".format(val, has_min_max["max"]),
|
||||
"details": f"{x}",
|
||||
"extra_info": {
|
||||
"input_name": x,
|
||||
@ -706,9 +708,9 @@ def validate_inputs(prompt, item, validated) -> Tuple[bool, typing.List[dict], t
|
||||
continue
|
||||
|
||||
if len(errors) > 0 or valid is not True:
|
||||
ret = (False, errors, unique_id)
|
||||
ret = ValidateInputsTuple(False, errors, unique_id)
|
||||
else:
|
||||
ret = (True, [], unique_id)
|
||||
ret = ValidateInputsTuple(True, [], unique_id)
|
||||
|
||||
validated[unique_id] = ret
|
||||
return ret
|
||||
@ -721,23 +723,25 @@ def full_type_name(klass):
|
||||
return module + '.' + klass.__qualname__
|
||||
|
||||
|
||||
class ValidationErrorExtraInfoDict(TypedDict):
|
||||
exception_type: str
|
||||
traceback: List[str]
|
||||
|
||||
|
||||
class ValidationErrorDict(TypedDict):
|
||||
type: str
|
||||
message: str
|
||||
details: str
|
||||
extra_info: ValidationErrorExtraInfoDict | dict
|
||||
|
||||
|
||||
ValidationTuple = typing.Tuple[bool, Optional[ValidationErrorDict], typing.List[str], Union[dict, list]]
|
||||
|
||||
|
||||
@tracer.start_as_current_span("Validate Prompt")
|
||||
def validate_prompt(prompt: typing.Mapping[str, typing.Any]) -> ValidationTuple:
|
||||
res = _validate_prompt(prompt)
|
||||
if not res.valid:
|
||||
span = get_current_span()
|
||||
span.set_status(Status(StatusCode.ERROR))
|
||||
if res.error is not None and len(res.error) > 0:
|
||||
span.set_attributes({
|
||||
f"error.{k}": v for k, v in res.error.items()
|
||||
})
|
||||
if len(res.node_errors) > 0:
|
||||
for node_id, node_error in res.node_errors.items():
|
||||
for node_error_field, node_error_value in node_error.items():
|
||||
if isinstance(node_error_value, (str, bool, int, float)):
|
||||
span.set_attribute("node_errors.{node_id}.{node_error_field}", node_error_value)
|
||||
return res
|
||||
|
||||
|
||||
def _validate_prompt(prompt: typing.Mapping[str, typing.Any]) -> ValidationTuple:
|
||||
outputs = set()
|
||||
for x in prompt:
|
||||
if 'class_type' not in prompt[x]:
|
||||
@ -747,7 +751,7 @@ def validate_prompt(prompt: typing.Mapping[str, typing.Any]) -> ValidationTuple:
|
||||
"details": f"Node ID '#{x}'",
|
||||
"extra_info": {}
|
||||
}
|
||||
return (False, error, [], [])
|
||||
return ValidationTuple(False, error, [], [])
|
||||
|
||||
class_type = prompt[x]['class_type']
|
||||
class_ = nodes.NODE_CLASS_MAPPINGS.get(class_type, None)
|
||||
@ -758,7 +762,7 @@ def validate_prompt(prompt: typing.Mapping[str, typing.Any]) -> ValidationTuple:
|
||||
"details": f"Node ID '#{x}'",
|
||||
"extra_info": {}
|
||||
}
|
||||
return (False, error, [], [])
|
||||
return ValidationTuple(False, error, [], [])
|
||||
|
||||
if hasattr(class_, 'OUTPUT_NODE') and class_.OUTPUT_NODE is True:
|
||||
outputs.add(x)
|
||||
@ -770,15 +774,15 @@ def validate_prompt(prompt: typing.Mapping[str, typing.Any]) -> ValidationTuple:
|
||||
"details": "",
|
||||
"extra_info": {}
|
||||
}
|
||||
return False, error, [], []
|
||||
return ValidationTuple(False, error, [], [])
|
||||
|
||||
good_outputs = set()
|
||||
errors = []
|
||||
node_errors = {}
|
||||
validated = {}
|
||||
node_errors: typing.Dict[str, NodeErrorsDictValue] = {}
|
||||
validated: typing.Dict[str, ValidateInputsTuple] = {}
|
||||
for o in outputs:
|
||||
valid = False
|
||||
reasons = []
|
||||
reasons: List[ValidationErrorDict] = []
|
||||
try:
|
||||
m = validate_inputs(prompt, o, validated)
|
||||
valid = m[0]
|
||||
@ -796,7 +800,7 @@ def validate_prompt(prompt: typing.Mapping[str, typing.Any]) -> ValidationTuple:
|
||||
"traceback": traceback.format_tb(tb)
|
||||
}
|
||||
}]
|
||||
validated[o] = (False, reasons, o)
|
||||
validated[o] = ValidateInputsTuple(False, reasons, o)
|
||||
|
||||
if valid is True:
|
||||
good_outputs.add(o)
|
||||
@ -841,9 +845,9 @@ def validate_prompt(prompt: typing.Mapping[str, typing.Any]) -> ValidationTuple:
|
||||
"extra_info": {}
|
||||
}
|
||||
|
||||
return False, error, list(good_outputs), node_errors
|
||||
return ValidationTuple(False, error, list(good_outputs), node_errors)
|
||||
|
||||
return True, None, list(good_outputs), node_errors
|
||||
return ValidationTuple(True, None, list(good_outputs), node_errors)
|
||||
|
||||
|
||||
class PromptQueue(AbstractPromptQueue):
|
||||
|
||||
@ -7,9 +7,9 @@ import shutil
|
||||
import threading
|
||||
import time
|
||||
|
||||
from .extra_model_paths import load_extra_path_config
|
||||
# main_pre must be the earliest import since it suppresses some spurious warnings
|
||||
from .main_pre import args
|
||||
from .extra_model_paths import load_extra_path_config
|
||||
from .. import model_management
|
||||
from ..analytics.analytics import initialize_event_tracking
|
||||
from ..cmd import cuda_malloc
|
||||
@ -223,7 +223,10 @@ async def main():
|
||||
|
||||
|
||||
def entrypoint():
|
||||
asyncio.run(main())
|
||||
try:
|
||||
asyncio.run(main())
|
||||
except KeyboardInterrupt as keyboard_interrupt:
|
||||
logging.info(f"Gracefully shutting down due to {keyboard_interrupt}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@ -492,9 +492,6 @@ class PromptServer(ExecutorToClientProgress):
|
||||
|
||||
@routes.post("/prompt")
|
||||
async def post_prompt(request):
|
||||
logging.info("got prompt")
|
||||
resp_code = 200
|
||||
out_string = ""
|
||||
json_data = await request.json()
|
||||
json_data = self.trigger_on_prompt(json_data)
|
||||
|
||||
|
||||
@ -1,11 +1,13 @@
|
||||
from __future__ import annotations # for Python 3.7-3.9
|
||||
|
||||
from typing import Optional, Literal, Protocol, TypeAlias, Union, NamedTuple
|
||||
import typing
|
||||
from typing import Optional, Literal, Protocol, TypeAlias, Union, NamedTuple, List
|
||||
|
||||
import PIL.Image
|
||||
from typing_extensions import NotRequired, TypedDict
|
||||
|
||||
from .queue_types import BinaryEventTypes
|
||||
from ..nodes.package_typing import InputTypeSpec
|
||||
|
||||
|
||||
class ExecInfo(TypedDict):
|
||||
@ -85,3 +87,46 @@ class ExecutorToClientProgress(Protocol):
|
||||
:return:
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
ExceptionTypes = Literal["custom_validation_failed", "value_not_in_list", "value_bigger_than_max", "value_smaller_than_min", "invalid_input_type", "exception_during_inner_validation", "return_type_mismatch", "bad_linked_input", "required_input_missing", "invalid_prompt", "prompt_no_outputs", "exception_during_validation", "prompt_outputs_failed_validation"]
|
||||
|
||||
|
||||
class ValidationErrorExtraInfoDict(TypedDict, total=False):
|
||||
exception_type: NotRequired[str]
|
||||
traceback: NotRequired[List[str]]
|
||||
dependent_outputs: NotRequired[List[str]]
|
||||
class_type: NotRequired[str]
|
||||
input_name: NotRequired[str]
|
||||
input_config: NotRequired[typing.Dict[str, InputTypeSpec]]
|
||||
received_value: NotRequired[typing.Any]
|
||||
linked_node: NotRequired[str]
|
||||
traceback: NotRequired[str]
|
||||
exception_message: NotRequired[str]
|
||||
exception_type: NotRequired[str]
|
||||
|
||||
|
||||
class ValidationErrorDict(TypedDict):
|
||||
type: str
|
||||
message: str
|
||||
details: str
|
||||
extra_info: list[typing.Never] | ValidationErrorExtraInfoDict
|
||||
|
||||
|
||||
class NodeErrorsDictValue(TypedDict, total=False):
|
||||
dependent_outputs: NotRequired[List[str]]
|
||||
errors: List[ValidationErrorDict]
|
||||
class_type: str
|
||||
|
||||
|
||||
class ValidationTuple(typing.NamedTuple):
|
||||
valid: bool
|
||||
error: Optional[ValidationErrorDict]
|
||||
good_output_node_ids: List[str]
|
||||
node_errors: list[typing.Never] | typing.Dict[str, NodeErrorsDictValue]
|
||||
|
||||
|
||||
class ValidateInputsTuple(typing.NamedTuple):
|
||||
valid: bool
|
||||
errors: List[ValidationErrorDict]
|
||||
unique_id: str
|
||||
|
||||
@ -89,8 +89,6 @@ class BaseModel(torch.nn.Module):
|
||||
self.adm_channels = 0
|
||||
|
||||
self.concat_keys = ()
|
||||
logging.info("model_type {}".format(model_type.name))
|
||||
logging.debug("adm {}".format(self.adm_channels))
|
||||
|
||||
def apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs):
|
||||
sigma = t
|
||||
|
||||
@ -315,6 +315,8 @@ KNOWN_CONTROLNETS: Final[List[Downloadable]] = [
|
||||
HuggingFile("xinsir/controlnet-openpose-sdxl-1.0", "diffusion_pytorch_model.safetensors", save_with_filename="xinsir-controlnet-openpose-sdxl-1.0.safetensors"),
|
||||
HuggingFile("xinsir/anime-painter", "diffusion_pytorch_model.safetensors", save_with_filename="xinsir-anime-painter-scribble-sdxl-1.0.safetensors"),
|
||||
HuggingFile("TheMistoAI/MistoLine", "mistoLine_rank256.safetensors"),
|
||||
HuggingFile("xinsir/controlnet-union-sdxl-1.0", "diffusion_pytorch_model_promax.safetensors", save_with_filename="xinsir-controlnet-union-sdxl-1.0-promax.safetensors"),
|
||||
HuggingFile("xinsir/controlnet-union-sdxl-1.0", "diffusion_pytorch_model.safetensors", save_with_filename="xinsir-controlnet-union-sdxl-1.0.safetensors"),
|
||||
]
|
||||
|
||||
KNOWN_DIFF_CONTROLNETS: Final[List[Downloadable]] = [
|
||||
|
||||
@ -10,9 +10,11 @@ from typing import Literal, List
|
||||
|
||||
import psutil
|
||||
import torch
|
||||
from opentelemetry.trace import get_current_span
|
||||
|
||||
from . import interruption
|
||||
from .cli_args import args
|
||||
from .cmd.main_pre import tracer
|
||||
from .model_management_types import ModelManageable
|
||||
|
||||
model_management_lock = RLock()
|
||||
@ -356,6 +358,12 @@ class LoadedModel:
|
||||
def __eq__(self, other):
|
||||
return self.model is other.model
|
||||
|
||||
def __str__(self):
|
||||
if self.model is not None:
|
||||
return f"<LoadedModel {str(self.model)}>"
|
||||
else:
|
||||
return f"<LoadedModel>"
|
||||
|
||||
|
||||
def minimum_inference_memory():
|
||||
return (1024 * 1024 * 1024)
|
||||
@ -392,9 +400,12 @@ def unload_model_clones(model, unload_weights_only=True, force_unload=True) -> b
|
||||
return unload_weight
|
||||
|
||||
|
||||
def free_memory(memory_required, device, keep_loaded=[]):
|
||||
@tracer.start_as_current_span("Free Memory")
|
||||
def free_memory(memory_required, device, keep_loaded=[]) -> List[LoadedModel]:
|
||||
span = get_current_span()
|
||||
span.set_attribute("memory_required", memory_required)
|
||||
with model_management_lock:
|
||||
unloaded_model = []
|
||||
unloaded_models: List[LoadedModel] = []
|
||||
can_unload = []
|
||||
|
||||
for i in range(len(current_loaded_models) - 1, -1, -1):
|
||||
@ -410,12 +421,12 @@ def free_memory(memory_required, device, keep_loaded=[]):
|
||||
if get_free_memory(device) > memory_required:
|
||||
break
|
||||
current_loaded_models[i].model_unload()
|
||||
unloaded_model.append(i)
|
||||
unloaded_models.append(i)
|
||||
|
||||
for i in sorted(unloaded_model, reverse=True):
|
||||
for i in sorted(unloaded_models, reverse=True):
|
||||
current_loaded_models.pop(i)
|
||||
|
||||
if len(unloaded_model) > 0:
|
||||
if len(unloaded_models) > 0:
|
||||
soft_empty_cache()
|
||||
else:
|
||||
if vram_state != VRAMState.HIGH_VRAM:
|
||||
@ -423,10 +434,16 @@ def free_memory(memory_required, device, keep_loaded=[]):
|
||||
if mem_free_torch > mem_free_total * 0.25:
|
||||
soft_empty_cache()
|
||||
|
||||
span.set_attribute("unloaded_models", list(map(str, unloaded_models)))
|
||||
return unloaded_models
|
||||
|
||||
|
||||
@tracer.start_as_current_span("Load Models GPU")
|
||||
def load_models_gpu(models, memory_required=0, force_patch_weights=False):
|
||||
global vram_state
|
||||
|
||||
span = get_current_span()
|
||||
if memory_required != 0:
|
||||
span.set_attribute("memory_required", memory_required)
|
||||
with model_management_lock:
|
||||
inference_memory = minimum_inference_memory()
|
||||
extra_mem = max(inference_memory, memory_required)
|
||||
@ -452,19 +469,16 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False):
|
||||
loaded.currently_used = True
|
||||
models_already_loaded.append(loaded)
|
||||
if loaded is None:
|
||||
if hasattr(x, "model"):
|
||||
logging.info(f"Requested to load {x.model.__class__.__name__}")
|
||||
models_to_load.append(loaded_model)
|
||||
|
||||
models_freed: List[LoadedModel] = []
|
||||
if len(models_to_load) == 0:
|
||||
devs = set(map(lambda a: a.device, models_already_loaded))
|
||||
for d in devs:
|
||||
if d != torch.device("cpu"):
|
||||
free_memory(extra_mem, d, models_already_loaded)
|
||||
models_freed += free_memory(extra_mem, d, models_already_loaded)
|
||||
return
|
||||
|
||||
logging.info(f"Loading {len(models_to_load)} new model{'s' if len(models_to_load) > 1 else ''}")
|
||||
|
||||
total_memory_required = {}
|
||||
for loaded_model in models_to_load:
|
||||
if unload_model_clones(loaded_model.model, unload_weights_only=True, force_unload=False): # unload clones where the weights are different
|
||||
@ -472,7 +486,12 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False):
|
||||
|
||||
for device in total_memory_required:
|
||||
if device != torch.device("cpu"):
|
||||
free_memory(total_memory_required[device] * 1.3 + extra_mem, device, models_already_loaded)
|
||||
models_freed += free_memory(total_memory_required[device] * 1.3 + extra_mem, device, models_already_loaded)
|
||||
|
||||
span.set_attribute("models_to_load", list(map(str, models_to_load)))
|
||||
span.set_attribute("models_freed", list(map(str, models_freed)))
|
||||
|
||||
logging.info(f"Models loaded: {','.join(map(str, models_to_load))}, models freed: {','.join(map(str, models_freed))}")
|
||||
|
||||
for loaded_model in models_to_load:
|
||||
weights_unloaded = unload_model_clones(loaded_model.model, unload_weights_only=False, force_unload=False) # unload the rest of the clones where the weights can stay loaded
|
||||
|
||||
@ -2,13 +2,14 @@ import copy
|
||||
import inspect
|
||||
import logging
|
||||
import uuid
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from . import model_management
|
||||
from . import utils
|
||||
from .types import UnetWrapperFunction
|
||||
from .model_management_types import ModelManageable
|
||||
from .types import UnetWrapperFunction
|
||||
|
||||
|
||||
def weight_decompose(dora_scale, weight, lora_diff, alpha, strength):
|
||||
@ -60,14 +61,16 @@ def set_model_options_post_cfg_function(model_options, post_cfg_function, disabl
|
||||
model_options["disable_cfg1_optimization"] = True
|
||||
return model_options
|
||||
|
||||
|
||||
def set_model_options_pre_cfg_function(model_options, pre_cfg_function, disable_cfg1_optimization=False):
|
||||
model_options["sampler_pre_cfg_function"] = model_options.get("sampler_pre_cfg_function", []) + [pre_cfg_function]
|
||||
if disable_cfg1_optimization:
|
||||
model_options["disable_cfg1_optimization"] = True
|
||||
return model_options
|
||||
|
||||
|
||||
class ModelPatcher(ModelManageable):
|
||||
def __init__(self, model, load_device, offload_device, size=0, current_device=None, weight_inplace_update=False):
|
||||
def __init__(self, model, load_device, offload_device, size=0, current_device=None, weight_inplace_update=False, ckpt_name: Optional[str] = None):
|
||||
self.size = size
|
||||
self.model = model
|
||||
self.patches = {}
|
||||
@ -87,6 +90,7 @@ class ModelPatcher(ModelManageable):
|
||||
self.weight_inplace_update = weight_inplace_update
|
||||
self.model_lowvram = False
|
||||
self.patches_uuid = uuid.uuid4()
|
||||
self.ckpt_name = ckpt_name
|
||||
self._lowvram_patch_counter = 0
|
||||
|
||||
@property
|
||||
@ -105,6 +109,7 @@ class ModelPatcher(ModelManageable):
|
||||
|
||||
def clone(self):
|
||||
n = ModelPatcher(self.model, self.load_device, self.offload_device, self.size, self._current_device, weight_inplace_update=self.weight_inplace_update)
|
||||
n.ckpt_name = self.ckpt_name
|
||||
n.patches = {}
|
||||
for k in self.patches:
|
||||
n.patches[k] = self.patches[k][:]
|
||||
@ -578,3 +583,9 @@ class ModelPatcher(ModelManageable):
|
||||
@property
|
||||
def current_device(self) -> torch.device:
|
||||
return self._current_device
|
||||
|
||||
def __str__(self):
|
||||
if self.ckpt_name is not None:
|
||||
return f"<ModelPatcher for {self.ckpt_name} ({self.model.__class__.__name__})>"
|
||||
else:
|
||||
return f"<ModelPatcher for {self.model.__class__.__name__}>"
|
||||
|
||||
@ -1,10 +1,11 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing_extensions import TypedDict, NotRequired, Generic
|
||||
from typing import Union, Optional, Sequence, Dict, ClassVar, Protocol, Tuple, TypeVar, Any, Literal, \
|
||||
Callable, List, Type
|
||||
|
||||
from typing_extensions import TypedDict, NotRequired
|
||||
|
||||
T = TypeVar('T')
|
||||
|
||||
|
||||
@ -51,7 +52,7 @@ StringSpec = Tuple[Literal["STRING"], StringSpecOptions]
|
||||
|
||||
BooleanSpec = Tuple[Literal["BOOLEAN"], BoolSpecOptions]
|
||||
|
||||
ChoiceSpec = Tuple[Union[Sequence[str], Sequence[float], Sequence[int]]]
|
||||
ChoiceSpec = Tuple[Union[List[str], List[float], List[int], Tuple[str, ...], Tuple[float, ...], Tuple[int, ...]]]
|
||||
|
||||
NonPrimitiveTypeSpec = Tuple[CommonReturnTypes, Any]
|
||||
|
||||
@ -73,6 +74,7 @@ ValidateInputsMethod = Optional[Callable[..., Union[bool, str]]]
|
||||
|
||||
IsChangedMethod = Callable[[Type[Any], ...], str]
|
||||
|
||||
|
||||
class FunctionReturnsUIVariables(TypedDict):
|
||||
ui: dict
|
||||
result: NotRequired[Sequence[Any]]
|
||||
@ -123,6 +125,10 @@ class CustomNode(Protocol):
|
||||
|
||||
IS_CHANGED: Optional[ClassVar[IsChangedMethod]]
|
||||
|
||||
@classmethod
|
||||
def __call__(cls, *args, **kwargs) -> 'CustomNode':
|
||||
...
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExportedNodes:
|
||||
|
||||
18
comfy/sd.py
18
comfy/sd.py
@ -2,6 +2,7 @@ from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
import logging
|
||||
import os.path
|
||||
from enum import Enum
|
||||
from typing import Any, Optional
|
||||
|
||||
@ -16,18 +17,18 @@ from . import model_detection
|
||||
from . import model_management
|
||||
from . import model_patcher
|
||||
from . import model_sampling
|
||||
from . import sa_t5
|
||||
from . import sd1_clip
|
||||
from . import sd2_clip
|
||||
from . import sd3_clip
|
||||
from . import sdxl_clip
|
||||
from . import utils
|
||||
from .ldm.audio.autoencoder import AudioOobleckVAE
|
||||
from .ldm.cascade.stage_a import StageA
|
||||
from .ldm.cascade.stage_c_coder import StageC_coder
|
||||
from .ldm.models.autoencoder import AutoencoderKL, AutoencodingEngine
|
||||
from .ldm.audio.autoencoder import AudioOobleckVAE
|
||||
from .t2i_adapter import adapter
|
||||
from .taesd import taesd
|
||||
from . import sd3_clip
|
||||
from . import sa_t5
|
||||
from .text_encoders import aura_t5
|
||||
|
||||
|
||||
@ -228,7 +229,7 @@ class VAE:
|
||||
self.latent_channels = 64
|
||||
self.output_channels = 2
|
||||
self.upscale_ratio = 2048
|
||||
self.downscale_ratio = 2048
|
||||
self.downscale_ratio = 2048
|
||||
self.process_output = lambda audio: audio
|
||||
self.process_input = lambda audio: audio
|
||||
self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
|
||||
@ -302,7 +303,7 @@ class VAE:
|
||||
|
||||
def encode_tiled_1d(self, samples, tile_x=128 * 2048, overlap=32 * 2048):
|
||||
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).float()
|
||||
return utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=(1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device)
|
||||
return utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=(1 / self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device)
|
||||
|
||||
def decode(self, samples_in):
|
||||
try:
|
||||
@ -558,9 +559,8 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
|
||||
logging.debug("left over keys: {}".format(left_over))
|
||||
|
||||
if output_model:
|
||||
_model_patcher = model_patcher.ModelPatcher(model, load_device=load_device, offload_device=model_management.unet_offload_device(), current_device=inital_load_device)
|
||||
_model_patcher = model_patcher.ModelPatcher(model, load_device=load_device, offload_device=model_management.unet_offload_device(), current_device=inital_load_device, ckpt_name=os.path.basename(ckpt_path))
|
||||
if inital_load_device != torch.device("cpu"):
|
||||
logging.info("loaded straight to GPU")
|
||||
model_management.load_model_gpu(_model_patcher)
|
||||
|
||||
return (_model_patcher, clip, vae, clipvision)
|
||||
@ -568,7 +568,7 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
|
||||
|
||||
def load_unet_state_dict(sd): # load unet in diffusers or regular format
|
||||
|
||||
#Allow loading unets from checkpoint files
|
||||
# Allow loading unets from checkpoint files
|
||||
diffusion_model_prefix = model_detection.unet_prefix_from_state_dict(sd)
|
||||
temp_sd = utils.state_dict_prefix_replace(sd, {diffusion_model_prefix: ""}, filter_keys=True)
|
||||
if len(temp_sd) > 0:
|
||||
@ -583,7 +583,7 @@ def load_unet_state_dict(sd): # load unet in diffusers or regular format
|
||||
new_sd = sd
|
||||
else:
|
||||
new_sd = model_detection.convert_diffusers_mmdit(sd, "")
|
||||
if new_sd is not None: #diffusers mmdit
|
||||
if new_sd is not None: # diffusers mmdit
|
||||
model_config = model_detection.model_config_from_unet(new_sd, "")
|
||||
if model_config is None:
|
||||
return None
|
||||
|
||||
@ -46,6 +46,9 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
||||
|
||||
@pytest.mark.asyncio
|
||||
def test_known_repos(tmp_path_factory):
|
||||
prev_hub_cache = os.getenv("HF_HUB_CACHE")
|
||||
os.environ["HF_HUB_CACHE"] = str(tmp_path_factory.mktemp("huggingface_root_cache"))
|
||||
|
||||
from comfy.cmd import folder_paths
|
||||
from comfy.cmd.folder_paths import FolderPathsTuple
|
||||
from comfy.model_downloader import get_huggingface_repo_list, \
|
||||
@ -57,9 +60,10 @@ def test_known_repos(tmp_path_factory):
|
||||
test_repo_id = "doctorpangloss/comfyui_downloader_test"
|
||||
prev_huggingface = folder_paths.folder_names_and_paths["huggingface"]
|
||||
prev_huggingface_cache = folder_paths.folder_names_and_paths["huggingface_cache"]
|
||||
prev_hub_cache = os.getenv("HF_HUB_CACHE")
|
||||
|
||||
_delete_repo_from_huggingface_cache(test_repo_id)
|
||||
_delete_repo_from_huggingface_cache(test_repo_id, test_cache_dir)
|
||||
args.disable_known_models = False
|
||||
try:
|
||||
folder_paths.folder_names_and_paths["huggingface"] += FolderPathsTuple("huggingface", [test_local_dir], {""})
|
||||
folder_paths.folder_names_and_paths["huggingface_cache"] += FolderPathsTuple("huggingface_cache", [test_cache_dir], {""})
|
||||
|
||||
Loading…
Reference in New Issue
Block a user