diff --git a/comfy/app/user_manager.py b/comfy/app/user_manager.py index 9f34c8b44..fb202eb88 100644 --- a/comfy/app/user_manager.py +++ b/comfy/app/user_manager.py @@ -1,13 +1,15 @@ +import glob import json import os import re -import uuid -import glob import shutil +import uuid + from aiohttp import web + +from .app_settings import AppSettings from ..cli_args import args from ..cmd.folder_paths import user_directory -from .app_settings import AppSettings class UserManager(): @@ -17,9 +19,6 @@ class UserManager(): self.settings = AppSettings(self) if not os.path.exists(user_directory): os.mkdir(user_directory) - if not args.multi_user: - print("****** User settings have been changed to be stored on the server instead of browser storage. ******") - print("****** For multi-user setups add the --multi-user CLI argument to enable multiple user profiles. ******") if args.multi_user: if os.path.isfile(self.users_file): @@ -129,7 +128,7 @@ class UserManager(): return web.json_response(results) - def get_user_data_path(request, check_exists = False, param = "file"): + def get_user_data_path(request, check_exists=False, param="file"): file = request.match_info.get(param, None) if not file: return web.Response(status=400) diff --git a/comfy/cmd/execution.py b/comfy/cmd/execution.py index 8310fb2a7..c7744ba25 100644 --- a/comfy/cmd/execution.py +++ b/comfy/cmd/execution.py @@ -8,22 +8,22 @@ import sys import threading import traceback import typing -from typing import List, Optional, Tuple, Union +from typing import List, Optional, Tuple import lazy_object_proxy import torch from opentelemetry.trace import get_current_span, StatusCode, Status -from typing_extensions import TypedDict from .main_pre import tracer from .. import interruption from .. import model_management from ..component_model.abstract_prompt_queue import AbstractPromptQueue -from ..component_model.executor_types import ExecutorToClientProgress +from ..component_model.executor_types import ExecutorToClientProgress, ValidationTuple, ValidateInputsTuple, \ + ValidationErrorDict, NodeErrorsDictValue from ..component_model.queue_types import QueueTuple, HistoryEntry, QueueItem, MAXIMUM_HISTORY_SIZE, ExecutionStatus from ..execution_context import new_execution_context, ExecutionContext from ..nodes.package import import_all_nodes_in_workspace -from ..nodes.package_typing import ExportedNodes +from ..nodes.package_typing import ExportedNodes, InputTypeSpec, FloatSpecOptions, IntSpecOptions # ideally this would be passed in from main, but the way this is authored, we can't easily pass nodes down to the # various functions that are declared here. It should have been a context in the first place. @@ -492,7 +492,7 @@ class PromptExecutor: model_management.unload_all_models() -def validate_inputs(prompt, item, validated) -> Tuple[bool, typing.List[dict], typing.Any]: +def validate_inputs(prompt, item, validated: typing.Dict[str, ValidateInputsTuple]) -> ValidateInputsTuple: # todo: this should check if LoadImage / LoadImageMask paths exist # todo: or, nodes should provide a way to validate their values unique_id = item @@ -506,11 +506,12 @@ def validate_inputs(prompt, item, validated) -> Tuple[bool, typing.List[dict], t class_inputs = obj_class.INPUT_TYPES() required_inputs = class_inputs['required'] + error: ValidationErrorDict errors = [] valid = True # todo: investigate if these are at the right indent level - info = None + info: Optional[InputTypeSpec] = None val = None validate_function_inputs = [] @@ -531,7 +532,7 @@ def validate_inputs(prompt, item, validated) -> Tuple[bool, typing.List[dict], t continue val = inputs[x] - info = required_inputs[x] + info: InputTypeSpec = required_inputs[x] type_input = info[0] if isinstance(val, list): if len(val) != 2: @@ -593,7 +594,7 @@ def validate_inputs(prompt, item, validated) -> Tuple[bool, typing.List[dict], t "linked_node": val } }] - validated[o_id] = (False, reasons, o_id) + validated[o_id] = ValidateInputsTuple(False, reasons, o_id) continue else: try: @@ -622,10 +623,11 @@ def validate_inputs(prompt, item, validated) -> Tuple[bool, typing.List[dict], t continue if len(info) > 1: - if "min" in info[1] and val < info[1]["min"]: + has_min_max: IntSpecOptions | FloatSpecOptions = info[1] + if "min" in has_min_max and val < has_min_max["min"]: error = { "type": "value_smaller_than_min", - "message": "Value {} smaller than min of {}".format(val, info[1]["min"]), + "message": "Value {} smaller than min of {}".format(val, has_min_max["min"]), "details": f"{x}", "extra_info": { "input_name": x, @@ -635,10 +637,10 @@ def validate_inputs(prompt, item, validated) -> Tuple[bool, typing.List[dict], t } errors.append(error) continue - if "max" in info[1] and val > info[1]["max"]: + if "max" in has_min_max and val > has_min_max["max"]: error = { "type": "value_bigger_than_max", - "message": "Value {} bigger than max of {}".format(val, info[1]["max"]), + "message": "Value {} bigger than max of {}".format(val, has_min_max["max"]), "details": f"{x}", "extra_info": { "input_name": x, @@ -706,9 +708,9 @@ def validate_inputs(prompt, item, validated) -> Tuple[bool, typing.List[dict], t continue if len(errors) > 0 or valid is not True: - ret = (False, errors, unique_id) + ret = ValidateInputsTuple(False, errors, unique_id) else: - ret = (True, [], unique_id) + ret = ValidateInputsTuple(True, [], unique_id) validated[unique_id] = ret return ret @@ -721,23 +723,25 @@ def full_type_name(klass): return module + '.' + klass.__qualname__ -class ValidationErrorExtraInfoDict(TypedDict): - exception_type: str - traceback: List[str] - - -class ValidationErrorDict(TypedDict): - type: str - message: str - details: str - extra_info: ValidationErrorExtraInfoDict | dict - - -ValidationTuple = typing.Tuple[bool, Optional[ValidationErrorDict], typing.List[str], Union[dict, list]] - - @tracer.start_as_current_span("Validate Prompt") def validate_prompt(prompt: typing.Mapping[str, typing.Any]) -> ValidationTuple: + res = _validate_prompt(prompt) + if not res.valid: + span = get_current_span() + span.set_status(Status(StatusCode.ERROR)) + if res.error is not None and len(res.error) > 0: + span.set_attributes({ + f"error.{k}": v for k, v in res.error.items() + }) + if len(res.node_errors) > 0: + for node_id, node_error in res.node_errors.items(): + for node_error_field, node_error_value in node_error.items(): + if isinstance(node_error_value, (str, bool, int, float)): + span.set_attribute("node_errors.{node_id}.{node_error_field}", node_error_value) + return res + + +def _validate_prompt(prompt: typing.Mapping[str, typing.Any]) -> ValidationTuple: outputs = set() for x in prompt: if 'class_type' not in prompt[x]: @@ -747,7 +751,7 @@ def validate_prompt(prompt: typing.Mapping[str, typing.Any]) -> ValidationTuple: "details": f"Node ID '#{x}'", "extra_info": {} } - return (False, error, [], []) + return ValidationTuple(False, error, [], []) class_type = prompt[x]['class_type'] class_ = nodes.NODE_CLASS_MAPPINGS.get(class_type, None) @@ -758,7 +762,7 @@ def validate_prompt(prompt: typing.Mapping[str, typing.Any]) -> ValidationTuple: "details": f"Node ID '#{x}'", "extra_info": {} } - return (False, error, [], []) + return ValidationTuple(False, error, [], []) if hasattr(class_, 'OUTPUT_NODE') and class_.OUTPUT_NODE is True: outputs.add(x) @@ -770,15 +774,15 @@ def validate_prompt(prompt: typing.Mapping[str, typing.Any]) -> ValidationTuple: "details": "", "extra_info": {} } - return False, error, [], [] + return ValidationTuple(False, error, [], []) good_outputs = set() errors = [] - node_errors = {} - validated = {} + node_errors: typing.Dict[str, NodeErrorsDictValue] = {} + validated: typing.Dict[str, ValidateInputsTuple] = {} for o in outputs: valid = False - reasons = [] + reasons: List[ValidationErrorDict] = [] try: m = validate_inputs(prompt, o, validated) valid = m[0] @@ -796,7 +800,7 @@ def validate_prompt(prompt: typing.Mapping[str, typing.Any]) -> ValidationTuple: "traceback": traceback.format_tb(tb) } }] - validated[o] = (False, reasons, o) + validated[o] = ValidateInputsTuple(False, reasons, o) if valid is True: good_outputs.add(o) @@ -841,9 +845,9 @@ def validate_prompt(prompt: typing.Mapping[str, typing.Any]) -> ValidationTuple: "extra_info": {} } - return False, error, list(good_outputs), node_errors + return ValidationTuple(False, error, list(good_outputs), node_errors) - return True, None, list(good_outputs), node_errors + return ValidationTuple(True, None, list(good_outputs), node_errors) class PromptQueue(AbstractPromptQueue): diff --git a/comfy/cmd/main.py b/comfy/cmd/main.py index b4585206e..249dda843 100644 --- a/comfy/cmd/main.py +++ b/comfy/cmd/main.py @@ -7,9 +7,9 @@ import shutil import threading import time +from .extra_model_paths import load_extra_path_config # main_pre must be the earliest import since it suppresses some spurious warnings from .main_pre import args -from .extra_model_paths import load_extra_path_config from .. import model_management from ..analytics.analytics import initialize_event_tracking from ..cmd import cuda_malloc @@ -223,7 +223,10 @@ async def main(): def entrypoint(): - asyncio.run(main()) + try: + asyncio.run(main()) + except KeyboardInterrupt as keyboard_interrupt: + logging.info(f"Gracefully shutting down due to {keyboard_interrupt}") if __name__ == "__main__": diff --git a/comfy/cmd/server.py b/comfy/cmd/server.py index 573395841..3287eb55a 100644 --- a/comfy/cmd/server.py +++ b/comfy/cmd/server.py @@ -492,9 +492,6 @@ class PromptServer(ExecutorToClientProgress): @routes.post("/prompt") async def post_prompt(request): - logging.info("got prompt") - resp_code = 200 - out_string = "" json_data = await request.json() json_data = self.trigger_on_prompt(json_data) diff --git a/comfy/component_model/executor_types.py b/comfy/component_model/executor_types.py index 0a35297e0..a0b35d7f3 100644 --- a/comfy/component_model/executor_types.py +++ b/comfy/component_model/executor_types.py @@ -1,11 +1,13 @@ from __future__ import annotations # for Python 3.7-3.9 -from typing import Optional, Literal, Protocol, TypeAlias, Union, NamedTuple +import typing +from typing import Optional, Literal, Protocol, TypeAlias, Union, NamedTuple, List import PIL.Image from typing_extensions import NotRequired, TypedDict from .queue_types import BinaryEventTypes +from ..nodes.package_typing import InputTypeSpec class ExecInfo(TypedDict): @@ -85,3 +87,46 @@ class ExecutorToClientProgress(Protocol): :return: """ pass + + +ExceptionTypes = Literal["custom_validation_failed", "value_not_in_list", "value_bigger_than_max", "value_smaller_than_min", "invalid_input_type", "exception_during_inner_validation", "return_type_mismatch", "bad_linked_input", "required_input_missing", "invalid_prompt", "prompt_no_outputs", "exception_during_validation", "prompt_outputs_failed_validation"] + + +class ValidationErrorExtraInfoDict(TypedDict, total=False): + exception_type: NotRequired[str] + traceback: NotRequired[List[str]] + dependent_outputs: NotRequired[List[str]] + class_type: NotRequired[str] + input_name: NotRequired[str] + input_config: NotRequired[typing.Dict[str, InputTypeSpec]] + received_value: NotRequired[typing.Any] + linked_node: NotRequired[str] + traceback: NotRequired[str] + exception_message: NotRequired[str] + exception_type: NotRequired[str] + + +class ValidationErrorDict(TypedDict): + type: str + message: str + details: str + extra_info: list[typing.Never] | ValidationErrorExtraInfoDict + + +class NodeErrorsDictValue(TypedDict, total=False): + dependent_outputs: NotRequired[List[str]] + errors: List[ValidationErrorDict] + class_type: str + + +class ValidationTuple(typing.NamedTuple): + valid: bool + error: Optional[ValidationErrorDict] + good_output_node_ids: List[str] + node_errors: list[typing.Never] | typing.Dict[str, NodeErrorsDictValue] + + +class ValidateInputsTuple(typing.NamedTuple): + valid: bool + errors: List[ValidationErrorDict] + unique_id: str diff --git a/comfy/model_base.py b/comfy/model_base.py index e68d6b5e0..c2cb0117f 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -89,8 +89,6 @@ class BaseModel(torch.nn.Module): self.adm_channels = 0 self.concat_keys = () - logging.info("model_type {}".format(model_type.name)) - logging.debug("adm {}".format(self.adm_channels)) def apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs): sigma = t diff --git a/comfy/model_downloader.py b/comfy/model_downloader.py index dafc4cc76..eaad659be 100644 --- a/comfy/model_downloader.py +++ b/comfy/model_downloader.py @@ -315,6 +315,8 @@ KNOWN_CONTROLNETS: Final[List[Downloadable]] = [ HuggingFile("xinsir/controlnet-openpose-sdxl-1.0", "diffusion_pytorch_model.safetensors", save_with_filename="xinsir-controlnet-openpose-sdxl-1.0.safetensors"), HuggingFile("xinsir/anime-painter", "diffusion_pytorch_model.safetensors", save_with_filename="xinsir-anime-painter-scribble-sdxl-1.0.safetensors"), HuggingFile("TheMistoAI/MistoLine", "mistoLine_rank256.safetensors"), + HuggingFile("xinsir/controlnet-union-sdxl-1.0", "diffusion_pytorch_model_promax.safetensors", save_with_filename="xinsir-controlnet-union-sdxl-1.0-promax.safetensors"), + HuggingFile("xinsir/controlnet-union-sdxl-1.0", "diffusion_pytorch_model.safetensors", save_with_filename="xinsir-controlnet-union-sdxl-1.0.safetensors"), ] KNOWN_DIFF_CONTROLNETS: Final[List[Downloadable]] = [ diff --git a/comfy/model_management.py b/comfy/model_management.py index 9d061600d..b8250af85 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -10,9 +10,11 @@ from typing import Literal, List import psutil import torch +from opentelemetry.trace import get_current_span from . import interruption from .cli_args import args +from .cmd.main_pre import tracer from .model_management_types import ModelManageable model_management_lock = RLock() @@ -356,6 +358,12 @@ class LoadedModel: def __eq__(self, other): return self.model is other.model + def __str__(self): + if self.model is not None: + return f"" + else: + return f"" + def minimum_inference_memory(): return (1024 * 1024 * 1024) @@ -392,9 +400,12 @@ def unload_model_clones(model, unload_weights_only=True, force_unload=True) -> b return unload_weight -def free_memory(memory_required, device, keep_loaded=[]): +@tracer.start_as_current_span("Free Memory") +def free_memory(memory_required, device, keep_loaded=[]) -> List[LoadedModel]: + span = get_current_span() + span.set_attribute("memory_required", memory_required) with model_management_lock: - unloaded_model = [] + unloaded_models: List[LoadedModel] = [] can_unload = [] for i in range(len(current_loaded_models) - 1, -1, -1): @@ -410,12 +421,12 @@ def free_memory(memory_required, device, keep_loaded=[]): if get_free_memory(device) > memory_required: break current_loaded_models[i].model_unload() - unloaded_model.append(i) + unloaded_models.append(i) - for i in sorted(unloaded_model, reverse=True): + for i in sorted(unloaded_models, reverse=True): current_loaded_models.pop(i) - if len(unloaded_model) > 0: + if len(unloaded_models) > 0: soft_empty_cache() else: if vram_state != VRAMState.HIGH_VRAM: @@ -423,10 +434,16 @@ def free_memory(memory_required, device, keep_loaded=[]): if mem_free_torch > mem_free_total * 0.25: soft_empty_cache() + span.set_attribute("unloaded_models", list(map(str, unloaded_models))) + return unloaded_models + +@tracer.start_as_current_span("Load Models GPU") def load_models_gpu(models, memory_required=0, force_patch_weights=False): global vram_state - + span = get_current_span() + if memory_required != 0: + span.set_attribute("memory_required", memory_required) with model_management_lock: inference_memory = minimum_inference_memory() extra_mem = max(inference_memory, memory_required) @@ -452,19 +469,16 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False): loaded.currently_used = True models_already_loaded.append(loaded) if loaded is None: - if hasattr(x, "model"): - logging.info(f"Requested to load {x.model.__class__.__name__}") models_to_load.append(loaded_model) + models_freed: List[LoadedModel] = [] if len(models_to_load) == 0: devs = set(map(lambda a: a.device, models_already_loaded)) for d in devs: if d != torch.device("cpu"): - free_memory(extra_mem, d, models_already_loaded) + models_freed += free_memory(extra_mem, d, models_already_loaded) return - logging.info(f"Loading {len(models_to_load)} new model{'s' if len(models_to_load) > 1 else ''}") - total_memory_required = {} for loaded_model in models_to_load: if unload_model_clones(loaded_model.model, unload_weights_only=True, force_unload=False): # unload clones where the weights are different @@ -472,7 +486,12 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False): for device in total_memory_required: if device != torch.device("cpu"): - free_memory(total_memory_required[device] * 1.3 + extra_mem, device, models_already_loaded) + models_freed += free_memory(total_memory_required[device] * 1.3 + extra_mem, device, models_already_loaded) + + span.set_attribute("models_to_load", list(map(str, models_to_load))) + span.set_attribute("models_freed", list(map(str, models_freed))) + + logging.info(f"Models loaded: {','.join(map(str, models_to_load))}, models freed: {','.join(map(str, models_freed))}") for loaded_model in models_to_load: weights_unloaded = unload_model_clones(loaded_model.model, unload_weights_only=False, force_unload=False) # unload the rest of the clones where the weights can stay loaded diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 6e9c40a95..1fbcb0808 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -2,13 +2,14 @@ import copy import inspect import logging import uuid +from typing import Optional import torch from . import model_management from . import utils -from .types import UnetWrapperFunction from .model_management_types import ModelManageable +from .types import UnetWrapperFunction def weight_decompose(dora_scale, weight, lora_diff, alpha, strength): @@ -60,14 +61,16 @@ def set_model_options_post_cfg_function(model_options, post_cfg_function, disabl model_options["disable_cfg1_optimization"] = True return model_options + def set_model_options_pre_cfg_function(model_options, pre_cfg_function, disable_cfg1_optimization=False): model_options["sampler_pre_cfg_function"] = model_options.get("sampler_pre_cfg_function", []) + [pre_cfg_function] if disable_cfg1_optimization: model_options["disable_cfg1_optimization"] = True return model_options + class ModelPatcher(ModelManageable): - def __init__(self, model, load_device, offload_device, size=0, current_device=None, weight_inplace_update=False): + def __init__(self, model, load_device, offload_device, size=0, current_device=None, weight_inplace_update=False, ckpt_name: Optional[str] = None): self.size = size self.model = model self.patches = {} @@ -87,6 +90,7 @@ class ModelPatcher(ModelManageable): self.weight_inplace_update = weight_inplace_update self.model_lowvram = False self.patches_uuid = uuid.uuid4() + self.ckpt_name = ckpt_name self._lowvram_patch_counter = 0 @property @@ -105,6 +109,7 @@ class ModelPatcher(ModelManageable): def clone(self): n = ModelPatcher(self.model, self.load_device, self.offload_device, self.size, self._current_device, weight_inplace_update=self.weight_inplace_update) + n.ckpt_name = self.ckpt_name n.patches = {} for k in self.patches: n.patches[k] = self.patches[k][:] @@ -578,3 +583,9 @@ class ModelPatcher(ModelManageable): @property def current_device(self) -> torch.device: return self._current_device + + def __str__(self): + if self.ckpt_name is not None: + return f"" + else: + return f"" diff --git a/comfy/nodes/package_typing.py b/comfy/nodes/package_typing.py index 4eb30ac3d..e78aa6a58 100644 --- a/comfy/nodes/package_typing.py +++ b/comfy/nodes/package_typing.py @@ -1,10 +1,11 @@ from __future__ import annotations from dataclasses import dataclass, field -from typing_extensions import TypedDict, NotRequired, Generic from typing import Union, Optional, Sequence, Dict, ClassVar, Protocol, Tuple, TypeVar, Any, Literal, \ Callable, List, Type +from typing_extensions import TypedDict, NotRequired + T = TypeVar('T') @@ -51,7 +52,7 @@ StringSpec = Tuple[Literal["STRING"], StringSpecOptions] BooleanSpec = Tuple[Literal["BOOLEAN"], BoolSpecOptions] -ChoiceSpec = Tuple[Union[Sequence[str], Sequence[float], Sequence[int]]] +ChoiceSpec = Tuple[Union[List[str], List[float], List[int], Tuple[str, ...], Tuple[float, ...], Tuple[int, ...]]] NonPrimitiveTypeSpec = Tuple[CommonReturnTypes, Any] @@ -73,6 +74,7 @@ ValidateInputsMethod = Optional[Callable[..., Union[bool, str]]] IsChangedMethod = Callable[[Type[Any], ...], str] + class FunctionReturnsUIVariables(TypedDict): ui: dict result: NotRequired[Sequence[Any]] @@ -123,6 +125,10 @@ class CustomNode(Protocol): IS_CHANGED: Optional[ClassVar[IsChangedMethod]] + @classmethod + def __call__(cls, *args, **kwargs) -> 'CustomNode': + ... + @dataclass class ExportedNodes: diff --git a/comfy/sd.py b/comfy/sd.py index ef9b73bd5..f9f0288a3 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -2,6 +2,7 @@ from __future__ import annotations import dataclasses import logging +import os.path from enum import Enum from typing import Any, Optional @@ -16,18 +17,18 @@ from . import model_detection from . import model_management from . import model_patcher from . import model_sampling +from . import sa_t5 from . import sd1_clip from . import sd2_clip +from . import sd3_clip from . import sdxl_clip from . import utils +from .ldm.audio.autoencoder import AudioOobleckVAE from .ldm.cascade.stage_a import StageA from .ldm.cascade.stage_c_coder import StageC_coder from .ldm.models.autoencoder import AutoencoderKL, AutoencodingEngine -from .ldm.audio.autoencoder import AudioOobleckVAE from .t2i_adapter import adapter from .taesd import taesd -from . import sd3_clip -from . import sa_t5 from .text_encoders import aura_t5 @@ -228,7 +229,7 @@ class VAE: self.latent_channels = 64 self.output_channels = 2 self.upscale_ratio = 2048 - self.downscale_ratio = 2048 + self.downscale_ratio = 2048 self.process_output = lambda audio: audio self.process_input = lambda audio: audio self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32] @@ -302,7 +303,7 @@ class VAE: def encode_tiled_1d(self, samples, tile_x=128 * 2048, overlap=32 * 2048): encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).float() - return utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=(1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device) + return utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=(1 / self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device) def decode(self, samples_in): try: @@ -558,9 +559,8 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o logging.debug("left over keys: {}".format(left_over)) if output_model: - _model_patcher = model_patcher.ModelPatcher(model, load_device=load_device, offload_device=model_management.unet_offload_device(), current_device=inital_load_device) + _model_patcher = model_patcher.ModelPatcher(model, load_device=load_device, offload_device=model_management.unet_offload_device(), current_device=inital_load_device, ckpt_name=os.path.basename(ckpt_path)) if inital_load_device != torch.device("cpu"): - logging.info("loaded straight to GPU") model_management.load_model_gpu(_model_patcher) return (_model_patcher, clip, vae, clipvision) @@ -568,7 +568,7 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o def load_unet_state_dict(sd): # load unet in diffusers or regular format - #Allow loading unets from checkpoint files + # Allow loading unets from checkpoint files diffusion_model_prefix = model_detection.unet_prefix_from_state_dict(sd) temp_sd = utils.state_dict_prefix_replace(sd, {diffusion_model_prefix: ""}, filter_keys=True) if len(temp_sd) > 0: @@ -583,7 +583,7 @@ def load_unet_state_dict(sd): # load unet in diffusers or regular format new_sd = sd else: new_sd = model_detection.convert_diffusers_mmdit(sd, "") - if new_sd is not None: #diffusers mmdit + if new_sd is not None: # diffusers mmdit model_config = model_detection.model_config_from_unet(new_sd, "") if model_config is None: return None diff --git a/tests/downloader/test_huggingface_downloads.py b/tests/downloader/test_huggingface_downloads.py index 6b64dcb5a..d1182770b 100644 --- a/tests/downloader/test_huggingface_downloads.py +++ b/tests/downloader/test_huggingface_downloads.py @@ -46,6 +46,9 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text @pytest.mark.asyncio def test_known_repos(tmp_path_factory): + prev_hub_cache = os.getenv("HF_HUB_CACHE") + os.environ["HF_HUB_CACHE"] = str(tmp_path_factory.mktemp("huggingface_root_cache")) + from comfy.cmd import folder_paths from comfy.cmd.folder_paths import FolderPathsTuple from comfy.model_downloader import get_huggingface_repo_list, \ @@ -57,9 +60,10 @@ def test_known_repos(tmp_path_factory): test_repo_id = "doctorpangloss/comfyui_downloader_test" prev_huggingface = folder_paths.folder_names_and_paths["huggingface"] prev_huggingface_cache = folder_paths.folder_names_and_paths["huggingface_cache"] - prev_hub_cache = os.getenv("HF_HUB_CACHE") + _delete_repo_from_huggingface_cache(test_repo_id) _delete_repo_from_huggingface_cache(test_repo_id, test_cache_dir) + args.disable_known_models = False try: folder_paths.folder_names_and_paths["huggingface"] += FolderPathsTuple("huggingface", [test_local_dir], {""}) folder_paths.folder_names_and_paths["huggingface_cache"] += FolderPathsTuple("huggingface_cache", [test_cache_dir], {""})