From 03e54301216ee98b6bc211ea6ef842366a3daf94 Mon Sep 17 00:00:00 2001 From: doctorpangloss <@hiddenswitch.com> Date: Mon, 28 Jul 2025 14:36:27 -0700 Subject: [PATCH] Improvements for Wan 2.2 support - add xet support and register the xet cache as a manageable directory - xet is enabled by default - fix logging that went to the root logger in various places - improve logging about model loading and unloading - TorchCompileModel now supports the VAE - a missing torchaudio now causes less noise in the logs - feature flags are assumed to support everything in the distributed progress context - fix progress notifications --- comfy/app/logger.py | 2 +- comfy/cmd/folder_paths.py | 4 ++ comfy/cmd/main_pre.py | 2 +- comfy/cmd/server.py | 16 ++++++-- comfy/component_model/executor_types.py | 39 +++++++++++++++++-- comfy/ldm/ace/vae/music_dcae_pipeline.py | 2 +- comfy/ldm/ace/vae/music_log_mel.py | 2 +- comfy/ldm/modules/diffusionmodules/model.py | 14 ++++--- comfy/model_base.py | 25 +++++++----- comfy/model_downloader.py | 12 +++++- comfy/model_management.py | 1 - comfy/model_patcher.py | 27 +++++++------ comfy/progress.py | 9 +++-- comfy/sd.py | 41 ++++++++++++++++---- comfy/utils.py | 23 ++++++----- comfy_api/feature_flags.py | 6 ++- comfy_extras/nodes/nodes_torch_compile.py | 42 +++++++++++++-------- pyproject.toml | 5 ++- tests/conftest.py | 2 +- 19 files changed, 192 insertions(+), 82 deletions(-) diff --git a/comfy/app/logger.py b/comfy/app/logger.py index 2ee8c2ebf..82482240a 100644 --- a/comfy/app/logger.py +++ b/comfy/app/logger.py @@ -84,7 +84,7 @@ def setup_logger(log_level: str = 'INFO', capacity: int = 300, use_stdout: bool stream_handler = logging.StreamHandler() stream_handler.setFormatter(logging.Formatter( - "%(asctime)s [%(name)s] [%(levelname)s] [%(filename)s:%(lineno)d] %(message)s", + "%(asctime)s [%(levelname)s] [%(name)s] [%(filename)s:%(lineno)d] %(message)s", datefmt="%Y-%m-%d %H:%M:%S" )) diff --git a/comfy/cmd/folder_paths.py b/comfy/cmd/folder_paths.py index a50da45c3..59832b638 100644 --- a/comfy/cmd/folder_paths.py +++ b/comfy/cmd/folder_paths.py @@ -86,6 +86,9 @@ def init_default_paths(folder_names_and_paths: FolderNames, configuration: Optio if "HF_HUB_CACHE" in os.environ: hf_cache_paths.additional_absolute_directory_paths.append(os.environ.get("HF_HUB_CACHE")) + hf_xet = ModelPaths(["xet"], supported_extensions=set()) + if "HF_XET_CACHE" in os.environ: + hf_xet.additional_absolute_directory_paths.append(os.environ.get("HF_XET_CACHE")) model_paths_to_add = [ ModelPaths(["checkpoints"], supported_extensions=set(supported_pt_extensions)), ModelPaths(["configs"], additional_absolute_directory_paths=[get_package_as_path("comfy.configs")], supported_extensions={".yaml"}), @@ -107,6 +110,7 @@ def init_default_paths(folder_names_and_paths: FolderNames, configuration: Optio ModelPaths(["classifiers"], supported_extensions=set()), ModelPaths(["huggingface"], supported_extensions=set()), hf_cache_paths, + hf_xet, ] for model_paths in model_paths_to_add: if replace_existing: diff --git a/comfy/cmd/main_pre.py b/comfy/cmd/main_pre.py index 6a0a616b8..cc944f666 100644 --- a/comfy/cmd/main_pre.py +++ b/comfy/cmd/main_pre.py @@ -18,7 +18,7 @@ from ..
import options from ..app import logger os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1" -os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1" +os.environ["HF_XET_HIGH_PERFORMANCE"] = "True" os.environ["TORCHINDUCTOR_FX_GRAPH_CACHE"] = "1" os.environ["TORCHINDUCTOR_AUTOGRAD_CACHE"] = "1" os.environ["BITSANDBYTES_NOWELCOME"] = "1" diff --git a/comfy/cmd/server.py b/comfy/cmd/server.py index 7522732be..b456eae17 100644 --- a/comfy/cmd/server.py +++ b/comfy/cmd/server.py @@ -221,7 +221,7 @@ class PromptServer(ExecutorToClientProgress): handler_args={'max_field_size': 16380}, middlewares=middlewares) self.sockets = dict() - self.sockets_metadata = dict() + self._sockets_metadata = dict() self.web_root = ( FrontendManager.init_frontend(args.front_end_version) if args.front_end_root is None @@ -278,16 +278,16 @@ class PromptServer(ExecutorToClientProgress): sid, ) - logging.info( + logger.info( f"Feature flags negotiated for client {sid}: {client_flags}" ) first_message = False except json.JSONDecodeError: - logging.warning( + logger.warning( f"Invalid JSON received from client {sid}: {msg.data}" ) except Exception as e: - logging.error(f"Error processing WebSocket message: {e}") + logger.error(f"Error processing WebSocket message: {e}") finally: self.sockets.pop(sid, None) self.sockets_metadata.pop(sid, None) @@ -1236,3 +1236,11 @@ class PromptServer(ExecutorToClientProgress): message = encode_text_for_progress(node_id, text) self.send_sync(BinaryEventTypes.TEXT, message, sid) + + @property + def sockets_metadata(self): + return self._sockets_metadata + + @sockets_metadata.setter + def sockets_metadata(self, value): + self._sockets_metadata = value diff --git a/comfy/component_model/executor_types.py b/comfy/component_model/executor_types.py index 44b033ed6..8440a2e4a 100644 --- a/comfy/component_model/executor_types.py +++ b/comfy/component_model/executor_types.py @@ -1,10 +1,11 @@ from __future__ import annotations # for Python 3.7-3.9 -import PIL.Image import concurrent.futures import typing from enum import Enum from typing import Optional, Literal, Protocol, Union, NamedTuple, List + +import PIL.Image from typing_extensions import NotRequired, TypedDict from .encode_text_for_progress import encode_text_for_progress @@ -79,11 +80,39 @@ class DependencyExecutionErrorMessage(TypedDict): current_inputs: list[typing.Never] +class ActiveNodeProgressState(TypedDict, total=True): + value: float + max: float + # a string value from the NodeState enum + state: Literal["pending", "running", "finished", "error"] + node_id: str + prompt_id: str + display_node_id: str + parent_node_id: str + real_node_id: str + + +class ProgressStateMessage(TypedDict, total=True): + prompt_id: str + nodes: dict[str, ActiveNodeProgressState] + + ExecutedMessage = ExecutingMessage -SendSyncEvent = Union[Literal["status", "execution_error", "executing", "progress", "executed"], BinaryEventTypes, None] +SendSyncEvent = Union[Literal["status", "execution_error", "executing", "progress", "executed", "progress_state"], BinaryEventTypes, None] -SendSyncData = Union[StatusMessage, ExecutingMessage, DependencyExecutionErrorMessage, ExecutionErrorMessage, ExecutionInterruptedMessage, ProgressMessage, UnencodedPreviewImageMessage, tuple[UnencodedPreviewImageMessage, PreviewImageMetadata], bytes, bytearray, str, None] +SendSyncData = Union[ProgressStateMessage, StatusMessage, ExecutingMessage, DependencyExecutionErrorMessage, ExecutionErrorMessage, ExecutionInterruptedMessage, ProgressMessage, UnencodedPreviewImageMessage, 
tuple[PIL.Image.Image, PreviewImageMetadata], bytes, bytearray, str, None] + + +class SocketsMetadata(TypedDict, total=True): + feature_flags: dict[str, typing.Any] + + +class DefaultSocketsMetadata(TypedDict, total=True): + __unimplemented: Literal[True] + + +SocketsMetadataType = dict[str, SocketsMetadata] | DefaultSocketsMetadata class ExecutorToClientProgress(Protocol): @@ -108,6 +137,10 @@ class ExecutorToClientProgress(Protocol): """ return False + @property + def sockets_metadata(self) -> SocketsMetadataType: + return {"__unimplemented": True} + def send_sync(self, event: SendSyncEvent, data: SendSyncData, diff --git a/comfy/ldm/ace/vae/music_dcae_pipeline.py b/comfy/ldm/ace/vae/music_dcae_pipeline.py index e29e8fc2c..adce17b5e 100644 --- a/comfy/ldm/ace/vae/music_dcae_pipeline.py +++ b/comfy/ldm/ace/vae/music_dcae_pipeline.py @@ -9,7 +9,7 @@ logger = logging.getLogger(__name__) try: import torchaudio # pylint: disable=import-error except: - logger.warning("torchaudio missing, ACE model will be broken") + logger.debug("torchaudio missing, ACE model will be broken") import torchvision.transforms as transforms from .music_vocoder import ADaMoSHiFiGANV1 diff --git a/comfy/ldm/ace/vae/music_log_mel.py b/comfy/ldm/ace/vae/music_log_mel.py index 8cc07084d..87a38be5b 100755 --- a/comfy/ldm/ace/vae/music_log_mel.py +++ b/comfy/ldm/ace/vae/music_log_mel.py @@ -10,7 +10,7 @@ logger = logging.getLogger(__name__) try: from torchaudio.transforms import MelScale # pylint: disable=import-error except: - logger.warning("torchaudio missing, ACE model will be broken") + logger.debug("torchaudio missing, ACE model will be broken") from .... import model_management diff --git a/comfy/ldm/modules/diffusionmodules/model.py b/comfy/ldm/modules/diffusionmodules/model.py index 9eab1262f..c66ad327c 100644 --- a/comfy/ldm/modules/diffusionmodules/model.py +++ b/comfy/ldm/modules/diffusionmodules/model.py @@ -10,6 +10,8 @@ from .... 
import ops ops = ops.disable_weight_init +logger = logging.getLogger(__name__) + if model_management.xformers_enabled_vae(): import xformers # pylint: disable=import-error import xformers.ops # pylint: disable=import-error @@ -242,7 +244,7 @@ def slice_attention(q, k, v): steps *= 2 if steps > 128: raise e - logging.warning("out of memory error, increasing steps and trying again {}".format(steps)) + logger.warning("out of memory error, increasing steps and trying again {}".format(steps)) return r1 @@ -296,20 +298,20 @@ def pytorch_attention(q, k, v): out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False) out = out.transpose(2, 3).reshape(orig_shape) except model_management.OOM_EXCEPTION: - logging.warning("scaled_dot_product_attention OOMed: switched to slice attention") + logger.warning("scaled_dot_product_attention OOMed: switched to slice attention") out = slice_attention(q.view(B, -1, C), k.view(B, -1, C).transpose(1, 2), v.view(B, -1, C).transpose(1, 2)).reshape(orig_shape) return out def vae_attention(): if model_management.xformers_enabled_vae(): - logging.debug("Using xformers attention in VAE") + logger.debug("Using xformers attention in VAE") return xformers_attention elif model_management.pytorch_attention_enabled_vae(): - logging.debug("Using pytorch attention in VAE") + logger.debug("Using pytorch attention in VAE") return pytorch_attention else: - logging.debug("Using split attention in VAE") + logger.debug("Using split attention in VAE") return normal_attention @@ -650,7 +652,7 @@ class Decoder(nn.Module): block_in = ch * ch_mult[self.num_resolutions - 1] curr_res = resolution // 2 ** (self.num_resolutions - 1) self.z_shape = (1, z_channels, curr_res, curr_res) - logging.debug("Working with z of shape {} = {} dimensions.".format( + logger.debug("Working with z of shape {} = {} dimensions.".format( self.z_shape, np.prod(self.z_shape))) # z to block_in diff --git a/comfy/model_base.py b/comfy/model_base.py index b5f89d342..d90e50c1d 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -16,13 +16,13 @@ along with this program. If not, see . """ -import math - import logging -import torch +import math from enum import Enum from typing import TypeVar, Type, Protocol, Any, Optional +import torch + from . import conds from . import latent_formats from . 
import model_management @@ -50,9 +50,9 @@ from .ldm.modules.diffusionmodules.mmdit import OpenAISignatureMMDITWrapper from .ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep from .ldm.modules.diffusionmodules.upscaling import ImageConcatWithNoiseAugmentation from .ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugmentation +from .ldm.omnigen.omnigen2 import OmniGen2Transformer2DModel from .ldm.pixart.pixartms import PixArtMS from .ldm.wan.model import WanModel, VaceWanModel, CameraWanModel -from .ldm.omnigen.omnigen2 import OmniGen2Transformer2DModel from .model_management_types import ModelManageable from .model_sampling import CONST, ModelSamplingDiscreteFlow, ModelSamplingFlux, IMG_TO_IMG from .model_sampling import StableCascadeSampling, COSMOS_RFLOW, ModelSamplingCosmosRFlow, V_PREDICTION, \ @@ -60,6 +60,8 @@ from .model_sampling import StableCascadeSampling, COSMOS_RFLOW, ModelSamplingCo from .ops import Operations from .patcher_extension import WrapperExecutor, WrappersMP, get_all_wrappers +logger = logging.getLogger(__name__) + class ModelType(Enum): EPS = 1 @@ -149,8 +151,8 @@ class BaseModel(torch.nn.Module): self.diffusion_model = unet_model(**unet_config, device=device, operations=operations) if model_management.force_channels_last(): self.diffusion_model.to(memory_format=torch.channels_last) - logging.debug("using channels last mode for diffusion model") - logging.debug("model weight dtype {}, manual cast: {}".format(self.get_dtype(), self.manual_cast_dtype)) + logger.debug("using channels last mode for diffusion model") + logger.debug("model weight dtype {}, manual cast: {}".format(self.get_dtype(), self.manual_cast_dtype)) else: self.operations = None self.model_type = model_type @@ -161,8 +163,8 @@ class BaseModel(torch.nn.Module): self.adm_channels = 0 self.concat_keys = () - logging.debug("model_type {}".format(model_type.name)) - logging.debug("adm {}".format(self.adm_channels)) + logger.debug("model_type {}".format(model_type.name)) + logger.debug("adm {}".format(self.adm_channels)) self.memory_usage_factor = model_config.memory_usage_factor self.memory_usage_factor_conds = () self.training = False @@ -310,10 +312,10 @@ class BaseModel(torch.nn.Module): to_load = self.model_config.process_unet_state_dict(to_load) m, u = self.diffusion_model.load_state_dict(to_load, strict=False) if len(m) > 0: - logging.warning("unet missing: {}".format(m)) + logger.warning("unet missing: {}".format(m)) if len(u) > 0: - logging.warning("unet unexpected: {}".format(u)) + logger.warning("unet unexpected: {}".format(u)) del to_load return self @@ -1227,6 +1229,7 @@ class WAN21_Camera(WAN21): out['camera_conditions'] = conds.CONDRegular(camera_conditions) return out + class WAN22(BaseModel): def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None): super().__init__(model_config, model_type, device=device, unet_model=WanModel) @@ -1252,6 +1255,7 @@ class WAN22(BaseModel): def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs): return latent_image + class Hunyuan3Dv2(BaseModel): def __init__(self, model_config, model_type=ModelType.FLOW, device=None): super().__init__(model_config, model_type, device=device, unet_model=Hunyuan3Dv2Model) @@ -1321,6 +1325,7 @@ class ACEStep(BaseModel): out['lyrics_strength'] = conds.CONDConstant(kwargs.get("lyrics_strength", 1.0)) return out + class Omnigen2(BaseModel): def __init__(self, model_config, model_type=ModelType.FLOW, device=None): super().__init__(model_config, 
model_type, device=device, unet_model=OmniGen2Transformer2DModel) diff --git a/comfy/model_downloader.py b/comfy/model_downloader.py index 741240460..316cec02e 100644 --- a/comfy/model_downloader.py +++ b/comfy/model_downloader.py @@ -14,7 +14,7 @@ from pathlib import Path from typing import List, Optional, Final, Set # enable better transfer -os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1" +os.environ["HF_XET_HIGH_PERFORMANCE"] = "True" import tqdm from huggingface_hub import hf_hub_download, scan_cache_dir, snapshot_download, HfFileSystem @@ -486,6 +486,7 @@ KNOWN_VAES: Final[KnownDownloadables] = KnownDownloadables([ HuggingFile("comfyanonymous/cosmos_1.0_text_encoder_and_VAE_ComfyUI", "vae/cosmos_cv8x8x8_1.0.safetensors"), HuggingFile("Comfy-Org/Lumina_Image_2.0_Repackaged", "split_files/vae/ae.safetensors", save_with_filename="lumina_image_2.0-ae.safetensors"), HuggingFile("Comfy-Org/Wan_2.1_ComfyUI_repackaged", "split_files/vae/wan_2.1_vae.safetensors"), + HuggingFile("Comfy-Org/Wan_2.2_ComfyUI_Repackaged", "split_files/vae/wan2.2_vae.safetensors"), ], folder_name="vae") KNOWN_HUGGINGFACE_MODEL_REPOS: Final[Set[str]] = { @@ -546,6 +547,15 @@ KNOWN_UNET_MODELS: Final[KnownDownloadables] = KnownDownloadables([ HuggingFile("Comfy-Org/Cosmos_Predict2_repackaged", "cosmos_predict2_2B_video2world_480p_16fps.safetensors"), HuggingFile("Comfy-Org/Wan_2.1_ComfyUI_repackaged", "split_files/diffusion_models/wan2.1_vace_14B_fp16.safetensors"), HuggingFile("Comfy-Org/Wan_2.1_ComfyUI_repackaged", "split_files/diffusion_models/wan2.1_fun_camera_v1.1_1.3B_bf16.safetensors"), + HuggingFile("Comfy-Org/Wan_2.2_ComfyUI_Repackaged", "split_files/diffusion_models/wan2.2_i2v_high_noise_14B_fp16.safetensors"), + HuggingFile("Comfy-Org/Wan_2.2_ComfyUI_Repackaged", "split_files/diffusion_models/wan2.2_i2v_high_noise_14B_fp8_scaled.safetensors"), + HuggingFile("Comfy-Org/Wan_2.2_ComfyUI_Repackaged", "split_files/diffusion_models/wan2.2_i2v_low_noise_14B_fp16.safetensors"), + HuggingFile("Comfy-Org/Wan_2.2_ComfyUI_Repackaged", "split_files/diffusion_models/wan2.2_i2v_low_noise_14B_fp8_scaled.safetensors"), + HuggingFile("Comfy-Org/Wan_2.2_ComfyUI_Repackaged", "split_files/diffusion_models/wan2.2_t2v_high_noise_14B_fp16.safetensors"), + HuggingFile("Comfy-Org/Wan_2.2_ComfyUI_Repackaged", "split_files/diffusion_models/wan2.2_t2v_high_noise_14B_fp8_scaled.safetensors"), + HuggingFile("Comfy-Org/Wan_2.2_ComfyUI_Repackaged", "split_files/diffusion_models/wan2.2_t2v_low_noise_14B_fp16.safetensors"), + HuggingFile("Comfy-Org/Wan_2.2_ComfyUI_Repackaged", "split_files/diffusion_models/wan2.2_t2v_low_noise_14B_fp8_scaled.safetensors"), + HuggingFile("Comfy-Org/Wan_2.2_ComfyUI_Repackaged", "split_files/diffusion_models/wan2.2_ti2v_5B_fp16.safetensors"), HuggingFile("lodestones/Chroma", "chroma-unlocked-v37.safetensors"), ], folder_names=["diffusion_models", "unet"]) diff --git a/comfy/model_management.py b/comfy/model_management.py index 961e64ecd..185f6ba7e 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -742,7 +742,6 @@ def _load_models_gpu(models: Sequence[ModelManageable], memory_required: int = 0 loaded_model.model_load(lowvram_model_memory, force_patch_weights=force_patch_weights) current_loaded_models.insert(0, loaded_model) - logger.debug(f"Loaded {loaded_model}") span = get_current_span() span.set_attribute("models_to_load", list(map(str, models_to_load))) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 852c8c7de..4e11eb403 100644 --- a/comfy/model_patcher.py +++ 
b/comfy/model_patcher.py @@ -29,6 +29,7 @@ from typing import Callable, Optional import torch import torch.nn from humanize import naturalsize +from natsort import natsorted from . import model_management, lora from . import patcher_extension @@ -119,6 +120,7 @@ def wipe_lowvram_weight(m): if hasattr(m, "bias_function"): m.bias_function = [] + def move_weight_functions(m, device): if device is None: return 0 @@ -289,7 +291,7 @@ class ModelPatcher(ModelManageable): return self._force_cast_weights @force_cast_weights.setter - def force_cast_weights(self, value:bool) -> None: + def force_cast_weights(self, value: bool) -> None: self._force_cast_weights = value def lowvram_patch_counter(self): @@ -475,7 +477,7 @@ class ModelPatcher(ModelManageable): self.add_object_patch("manual_cast_dtype", dtype) if dtype is not None: self.force_cast_weights = True - self.patches_uuid = uuid.uuid4() #TODO: optimize by preventing a full model reload for this + self.patches_uuid = uuid.uuid4() # TODO: optimize by preventing a full model reload for this def add_weight_wrapper(self, name, function): self.weight_wrapper_patches[name] = self.weight_wrapper_patches.get(name, []) + [function] @@ -630,7 +632,6 @@ class ModelPatcher(ModelManageable): else: set_func(out_weight, inplace_update=inplace_update, seed=string_to_seed(key)) - def _load_list(self) -> list[LoadingListItem]: loading = [] for n, m in self.model.named_modules(): @@ -715,6 +716,7 @@ class ModelPatcher(ModelManageable): mem_counter += move_weight_functions(m, device_to) load_completely.sort(reverse=True) + models_loaded_regularly: list[str] = [] for x in load_completely: n = x.name m = x.module @@ -726,17 +728,17 @@ class ModelPatcher(ModelManageable): for param in params: self.patch_weight_to_device("{}.{}".format(n, param), device_to=device_to) - logger.debug("lowvram: loaded module regularly {} {}".format(n, m)) + models_loaded_regularly.append("name={} module={}".format(n, m)) m.comfy_patched_weights = True - + logger.debug("lowvram: loaded module regularly: {}".format(", ".join(models_loaded_regularly))) for x in load_completely: x.module.to(device_to) if lowvram_counter > 0: - logger.debug("loaded partially {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), patch_counter)) + logger.debug("loaded partially lowvram_model_memory={}MB mem_counter={}MB patch_counter={}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), patch_counter)) self._memory_measurements.model_lowvram = True else: - logger.debug("loaded completely {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), full_load)) + logger.debug("loaded completely lowvram_model_memory={}MB mem_counter={}MB full_load={}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), full_load)) self._memory_measurements.model_lowvram = False if full_load: self.model.to(device_to) @@ -812,6 +814,7 @@ class ModelPatcher(ModelManageable): self.object_patches_backup.clear() def partially_unload(self, device_to, memory_to_free=0): + freed_layers: list[str] = [] with self.use_ejected(): hooks_unpatched = False memory_freed = 0 @@ -867,7 +870,9 @@ class ModelPatcher(ModelManageable): m.comfy_cast_weights = True m.comfy_patched_weights = False memory_freed += module_mem - logging.debug("freed {}".format(n)) + freed_layers.append(n) + + logger.debug("freed {}".format(natsorted(freed_layers))) self._memory_measurements.model_lowvram = True self._memory_measurements.lowvram_patch_counter += patch_counter @@ 
-1190,7 +1195,7 @@ class ModelPatcher(ModelManageable): model_sd_keys_set = set(model_sd_keys) for key in cached_weights: if key not in model_sd_keys: - logging.warning(f"Cached hook could not patch. Key does not exist in model: {key}") + logger.warning(f"Cached hook could not patch. Key does not exist in model: {key}") continue self.patch_cached_hook_weights(cached_weights=cached_weights, key=key, memory_counter=memory_counter) model_sd_keys_set.remove(key) @@ -1203,7 +1208,7 @@ class ModelPatcher(ModelManageable): original_weights = self.get_key_patches() for key in relevant_patches: if key not in model_sd_keys: - logging.warning(f"Cached hook would not patch. Key does not exist in model: {key}") + logger.warning(f"Cached hook would not patch. Key does not exist in model: {key}") continue self.patch_hook_weight_to_device(hooks=hooks, combined_patches=relevant_patches, key=key, original_weights=original_weights, memory_counter=memory_counter) @@ -1265,7 +1270,7 @@ class ModelPatcher(ModelManageable): del out_weight del weight - def unpatch_hooks(self, whitelist_keys_set: set[str]=None) -> None: + def unpatch_hooks(self, whitelist_keys_set: set[str] = None) -> None: with self.use_ejected(): if len(self.hook_backup) == 0: self.current_hooks = None diff --git a/comfy/progress.py b/comfy/progress.py index 1711711f6..84d7e61fb 100644 --- a/comfy/progress.py +++ b/comfy/progress.py @@ -7,9 +7,10 @@ from PIL import Image from tqdm import tqdm from typing_extensions import override +from .component_model.executor_types import ExecutorToClientProgress from .component_model.module_property import create_module_properties from .execution_context import current_execution_context -from .progress_types import AbstractProgressRegistry +from .progress_types import AbstractProgressRegistry, PreviewImageMetadata if TYPE_CHECKING: from comfy_execution.graph import DynamicPrompt @@ -157,7 +158,7 @@ class WebUIProgressHandler(ProgressHandler): Handler that sends progress updates to the WebUI via WebSockets. 
""" - def __init__(self, server_instance): + def __init__(self, server_instance: ExecutorToClientProgress): super().__init__("webui") self.server_instance = server_instance @@ -216,7 +217,7 @@ class WebUIProgressHandler(ProgressHandler): self.server_instance.client_id, "supports_preview_metadata", ): - metadata = { + metadata: PreviewImageMetadata = { "node_id": node_id, "prompt_id": prompt_id, "display_node_id": self.registry.dynprompt.get_display_node_id( @@ -327,7 +328,7 @@ class ProgressRegistry(AbstractProgressRegistry): # Global registry instance @_module_properties.getter -def _global_progress_registry() -> ProgressRegistry: +def _global_progress_registry() -> AbstractProgressRegistry | None: return current_execution_context().progress_registry diff --git a/comfy/sd.py b/comfy/sd.py index d1ea579c0..74569c501 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -41,23 +41,23 @@ from .model_management import load_models_gpu from .model_patcher import ModelPatcher from .t2i_adapter import adapter from .taesd import taesd +from .text_encoders import ace from .text_encoders import aura_t5 -from .text_encoders import hidream from .text_encoders import cosmos from .text_encoders import flux from .text_encoders import genmo +from .text_encoders import hidream from .text_encoders import hunyuan_video from .text_encoders import hydit from .text_encoders import long_clipl from .text_encoders import lt from .text_encoders import lumina2 +from .text_encoders import omnigen2 from .text_encoders import pixart_t5 from .text_encoders import sa_t5 from .text_encoders import sd2_clip from .text_encoders import sd3_clip from .text_encoders import wan -from .text_encoders import ace -from .text_encoders import omnigen2 from .utils import ProgressBar logger = logging.getLogger(__name__) @@ -280,7 +280,9 @@ class CLIP: class VAE: - def __init__(self, sd=None, device=None, config=None, dtype=None, metadata=None): + def __init__(self, sd=None, device=None, config=None, dtype=None, metadata=None, no_init=False): + if no_init: + return if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): # diffusers format sd = diffusers_convert.convert_vae_state_dict(sd) @@ -469,7 +471,7 @@ class VAE: ddconfig = {"embed_dim": 64, "num_freqs": 8, "include_pi": False, "heads": 16, "width": 1024, "num_decoder_layers": 16, "qkv_bias": False, "qk_norm": True, "geo_decoder_mlp_expand_ratio": mlp_expand, "geo_decoder_downsample_ratio": downsample_ratio, "geo_decoder_ln_post": ln_post} self.first_stage_model = ShapeVAE(**ddconfig) self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32] - elif "vocoder.backbone.channel_layers.0.0.bias" in sd: #Ace Step Audio + elif "vocoder.backbone.channel_layers.0.0.bias" in sd: # Ace Step Audio self.first_stage_model = MusicDCAE(source_sample_rate=44100) self.memory_used_encode = lambda shape, dtype: (shape[2] * 330) * model_management.dtype_size(dtype) self.memory_used_decode = lambda shape, dtype: (shape[2] * shape[3] * 87000) * model_management.dtype_size(dtype) @@ -511,6 +513,29 @@ class VAE: self.patcher = model_patcher.ModelPatcher(self.first_stage_model, load_device=self.device, offload_device=offload_device) logger.debug("VAE load device: {}, offload device: {}, dtype: {}".format(self.device, offload_device, self.vae_dtype)) + def clone(self): + n = VAE(no_init=True) + n.memory_used_encode = self.memory_used_encode + n.memory_used_decode = self.memory_used_decode + n.downscale_ratio = self.downscale_ratio + n.upscale_ratio = self.upscale_ratio + n.latent_channels = 
self.latent_channels + n.latent_dim = self.latent_dim + n.output_channels = self.output_channels + n.process_input = self.process_input + n.process_output = self.process_output + n.working_dtypes = self.working_dtypes.copy() + n.disable_offload = self.disable_offload + n.downscale_index_formula = self.downscale_index_formula + n.upscale_index_formula = self.upscale_index_formula + n.extra_1d_channel = self.extra_1d_channel + n.first_stage_model = self.first_stage_model + n.device = self.device + n.vae_dtype = self.vae_dtype + n.output_device = self.output_device + n.patcher = self.patcher.clone() + return n + def throw_exception_if_invalid(self): if self.first_stage_model is None: raise RuntimeError("ERROR: VAE is invalid: None\n\nIf the VAE is from a checkpoint loader node your checkpoint does not contain a valid VAE.") @@ -920,7 +945,7 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None) elif clip_type == CLIPType.HIDREAM: clip_target.clip = hidream.hidream_clip(**t5xxl_detect(clip_data), - clip_l=False, clip_g=False, t5=True, llama=False, dtype_llama=None, llama_scaled_fp8=None) + clip_l=False, clip_g=False, t5=True, llama=False, dtype_llama=None, llama_scaled_fp8=None) clip_target.tokenizer = hidream.HiDreamTokenizer else: # CLIPType.MOCHI clip_target.clip = genmo.mochi_te(**t5xxl_detect(clip_data)) @@ -945,7 +970,7 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None) elif te_model == TEModel.LLAMA3_8: clip_target.clip = hidream.hidream_clip(**llama_detect(clip_data), - clip_l=False, clip_g=False, t5=False, llama=True, dtype_t5=None, t5xxl_scaled_fp8=None) + clip_l=False, clip_g=False, t5=False, llama=True, dtype_t5=None, t5xxl_scaled_fp8=None) clip_target.tokenizer = hidream.HiDreamTokenizer elif te_model == TEModel.QWEN25_3B: clip_target.clip = omnigen2.te(**llama_detect(clip_data)) @@ -1033,6 +1058,7 @@ def model_detection_error_hint(path, state_dict): return "\nHINT: This seems to be a Lora file and Lora files should be put in the lora folder and loaded with a lora loader node.." return "" + def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_clip=True, embedding_directory=None, state_dict=None, config=None): logger.warning("Warning: The load checkpoint with config function is deprecated and will eventually be removed, please use the other one.") model, clip, vae, _ = load_checkpoint_guess_config(ckpt_path, output_vae=output_vae, output_clip=output_clip, output_clipvision=False, embedding_directory=embedding_directory, output_model=True) @@ -1097,7 +1123,6 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c return None return (diffusion_model, None, VAE(sd={}), None) # The VAE object is there to throw an exception if it's actually used' - unet_weight_dtype = list(model_config.supported_inference_dtypes) if model_config.scaled_fp8 is not None: weight_dtype = None diff --git a/comfy/utils.py b/comfy/utils.py index 9844c3ac2..92dcf71e4 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -41,6 +41,7 @@ from einops import rearrange from torch.nn.functional import interpolate from tqdm import tqdm +from comfy_api import feature_flags from . 
import interruption, checkpoint_pickle from .cli_args import args from .component_model import files @@ -48,6 +49,7 @@ from .component_model.deprecation import _deprecate_method from .component_model.executor_types import ExecutorToClientProgress, ProgressMessage from .component_model.queue_types import BinaryEventTypes from .execution_context import current_execution_context +from .progress import get_progress_state MMAP_TORCH_FILES = args.mmap_torch_files DISABLE_MMAP = args.disable_mmap @@ -1106,22 +1108,23 @@ def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap=8, upscale_amou return tiled_scale_multidim(samples, function, (tile_y, tile_x), overlap=overlap, upscale_amount=upscale_amount, out_channels=out_channels, output_device=output_device, pbar=pbar) -def _progress_bar_update(value: float, total: float, preview_image_or_data: Optional[Any] = None, client_id: Optional[str] = None, server: Optional[ExecutorToClientProgress] = None, node_id: str = None): - server = server or current_execution_context().server - # todo: this should really be from the context. right now the server is behaving like a context - client_id = client_id or server.client_id +def _progress_bar_update(value: float, total: float, preview_image_or_data: Optional[Any] = None, client_id: Optional[str] = None, server: Optional[ExecutorToClientProgress] = None, node_id: str = None, prompt_id: str = None): + context = current_execution_context() + server = server or context.server + executing_context = context + prompt_id = prompt_id or executing_context.task_id or server.last_prompt_id + node_id = node_id or executing_context.node_id or server.last_node_id interruption.throw_exception_if_processing_interrupted() - progress: ProgressMessage = {"value": value, "max": total, "prompt_id": server.last_prompt_id, "node": node_id or server.last_node_id} + + progress: ProgressMessage = {"value": value, "max": total, "prompt_id": prompt_id, "node": node_id} + # todo: is this still necessary? if isinstance(preview_image_or_data, dict): progress["output"] = preview_image_or_data + # this is responsible for encoding the image + get_progress_state().update_progress(node_id, value, total, preview_image_or_data) server.send_sync("progress", progress, client_id) - # todo: investigate a better way to send the image data, since it needs the node ID - if preview_image_or_data is not None and not isinstance(preview_image_or_data, dict): - server.send_sync(BinaryEventTypes.UNENCODED_PREVIEW_IMAGE, preview_image_or_data, client_id) - - def set_progress_bar_enabled(enabled: bool): warnings.warn( "The global method 'set_progress_bar_enabled' is deprecated and will be removed in a future version. Use current_execution_context().server.receive_all_progress_notifications instead.", diff --git a/comfy_api/feature_flags.py b/comfy_api/feature_flags.py index 0d4389a6e..c6ccc4ec2 100644 --- a/comfy_api/feature_flags.py +++ b/comfy_api/feature_flags.py @@ -43,7 +43,8 @@ def get_connection_feature( def supports_feature( sockets_metadata: Dict[str, Dict[str, Any]], sid: str, - feature_name: str + feature_name: str, + force=True, ) -> bool: """ Check if a connection supports a specific feature. 
@@ -52,10 +53,13 @@ def supports_feature( sockets_metadata: Dictionary of socket metadata sid: Session ID of the connection feature_name: Name of the feature to check + force (bool): Value to assume when support cannot be determined Returns: Boolean indicating if feature is supported """ + if sockets_metadata is None or "__unimplemented" in sockets_metadata: + return force return get_connection_feature(sockets_metadata, sid, feature_name, False) is True diff --git a/comfy_extras/nodes/nodes_torch_compile.py b/comfy_extras/nodes/nodes_torch_compile.py index 9bac7854e..e93a06ce9 100644 --- a/comfy_extras/nodes/nodes_torch_compile.py +++ b/comfy_extras/nodes/nodes_torch_compile.py @@ -11,6 +11,7 @@ from comfy import model_management from comfy.language.transformers_model_management import TransformersManagedModel from comfy.model_patcher import ModelPatcher from comfy.nodes.package_typing import CustomNode, InputTypes +from comfy.sd import VAE from comfy_api.torch_helpers import set_torch_compile_wrapper logger = logging.getLogger(__name__) @@ -45,6 +46,7 @@ def write_atomic( torch._inductor.codecache.write_atomic = write_atomic +# torch._inductor.utils.is_big_gpu = lambda *args: True class TorchCompileModel(CustomNode): @@ -52,10 +54,10 @@ def INPUT_TYPES(s): return { "required": { - "model": ("MODEL",), + "model": ("MODEL,VAE",), }, "optional": { - "object_patch": ("STRING", {"default": DIFFUSION_MODEL}), + "object_patch": ("STRING", {"default": ""}), "fullgraph": ("BOOLEAN", {"default": False}), "dynamic": ("BOOLEAN", {"default": False}), "backend": (TORCH_COMPILE_BACKENDS, {"default": "inductor"}), @@ -64,15 +66,14 @@ } } - RETURN_TYPES = ("MODEL",) + RETURN_TYPES = ("MODEL,VAE",) FUNCTION = "patch" + RETURN_NAMES = ("model or vae",) CATEGORY = "_for_testing" EXPERIMENTAL = True - def patch(self, model: ModelPatcher, object_patch: str | None = DIFFUSION_MODEL, fullgraph: bool = False, dynamic: bool = False, backend: str = "inductor", mode: str = "max-autotune", torch_tensorrt_optimization_level: int = 3) -> tuple[Callable]: - if object_patch is None: - object_patch = DIFFUSION_MODEL + def patch(self, model: ModelPatcher | VAE | torch.nn.Module, object_patch: str | None = "", fullgraph: bool = False, dynamic: bool = False, backend: str = "inductor", mode: str = "max-autotune", torch_tensorrt_optimization_level: int = 3) -> tuple[Callable]: compile_kwargs = { "fullgraph": fullgraph, "dynamic": dynamic, @@ -99,17 +100,26 @@ } move_to_gpu = True del compile_kwargs["mode"] - if isinstance(model, ModelPatcher) or isinstance(model, TransformersManagedModel): - m = model.clone() + if isinstance(model, (ModelPatcher, TransformersManagedModel, VAE)): + to_return = model.clone() + object_patches = [p.strip() for p in object_patch.split(",") if p.strip()] if object_patch else [] + patcher: ModelPatcher + if isinstance(to_return, VAE): + patcher = to_return.patcher + object_patches = ["encoder", "decoder"] + else: + patcher = to_return + if len(object_patches) == 0: + object_patches = [DIFFUSION_MODEL] if move_to_gpu: model_management.unload_all_models() - model_management.load_models_gpu([m]) - set_torch_compile_wrapper(m, object_patch=object_patch, **compile_kwargs) - m.add_object_patch(object_patch, torch.compile(model=m.get_model_object(object_patch), **compile_kwargs)) + model_management.load_models_gpu([patcher]) + set_torch_compile_wrapper(patcher, keys=object_patches, **compile_kwargs) + #
m.add_object_patch(object_patch, torch.compile(model=m.get_model_object(object_patch), **compile_kwargs)) # todo: do we want to move something back off the GPU? # if move_to_gpu: # model_management.unload_all_models() - return m, + return to_return, elif isinstance(model, torch.nn.Module): if move_to_gpu: model_management.unload_all_models() @@ -119,7 +129,7 @@ class TorchCompileModel(CustomNode): model.to(device=model_management.unet_offload_device()) return res, else: - logging.warning("Encountered a model that cannot be compiled") + logger.warning("Encountered a model that cannot be compiled") return model, except OSError as os_error: try: @@ -132,7 +142,7 @@ class TorchCompileModel(CustomNode): torch._inductor.utils.clear_inductor_caches() # pylint: disable=no-member except Exception: pass - logging.error(f"An exception occurred while trying to compile {str(model)}, gracefully skipping compilation", exc_info=exc_info) + logger.error(f"An exception occurred while trying to compile {str(model)}, gracefully skipping compilation", exc_info=exc_info) return model, @@ -160,7 +170,7 @@ class QuantizeModel(CustomNode): RETURN_TYPES = ("MODEL",) def warn_in_place(self, model: ModelPatcher): - logging.warning(f"Quantizing {model} this way quantizes it in place, making it insuitable for cloning. All uses of this model will be quantized.") + logger.warning(f"Quantizing {model} this way quantizes it in place, making it unsuitable for cloning. All uses of this model will be quantized.") def execute(self, model: ModelPatcher, strategy: str = _QUANTIZATION_STRATEGIES[0]) -> tuple[ModelPatcher]: model = model.clone() @@ -179,7 +189,7 @@ class QuantizeModel(CustomNode): "final_layer", } if strategy == "quanto": - logging.warning(f"Quantizing {model} will produce poor results due to Optimum's limitations") + logger.warning(f"Quantizing {model} will produce poor results due to Optimum's limitations") self.warn_in_place(model) from optimum.quanto import quantize, qint8 # pylint: disable=import-error exclusion_list = [ diff --git a/pyproject.toml b/pyproject.toml index 77e4eefac..e610094e4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -57,7 +57,7 @@ dependencies = [ "pyjwt[crypto]", "kornia>=0.7.0", "mpmath>=1.0,!=1.4.0a0", - "huggingface_hub[hf_transfer]>0.20", + "huggingface_hub[hf_xet]>=0.32.0", "lazy-object-proxy", "lazy_loader>=0.3", "can_ada", @@ -76,7 +76,8 @@ dependencies = [ "wrapt>=1.16.0", "certifi", "spandrel>=0.3.4", - "numpy>=1.24.4", + # https://github.com/conda-forge/numba-feedstock/issues/158 until numba is released with support for a later version of numpy + "numpy>=1.24.4,<2.3", "soundfile", "watchdog", "PySoundFile", diff --git a/tests/conftest.py b/tests/conftest.py index 552813bea..0d78a3613 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -15,7 +15,7 @@ import requests os.environ['OTEL_METRICS_EXPORTER'] = 'none' os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1" -os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1" +os.environ["HF_XET_HIGH_PERFORMANCE"] = "True" # fixes issues with running the testcontainers rabbitmqcontainer on Windows os.environ["TC_HOST"] = "localhost"
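
A quick sketch of how the relaxed feature-flag check behaves after this patch. The metadata dictionaries below are illustrative: only supports_feature(), the SocketsMetadata shape, and the "__unimplemented" sentinel come from the changes above, and the expected values assume get_connection_feature() resolves sockets_metadata[sid]["feature_flags"] as before.

    from comfy_api import feature_flags

    # A client that negotiated its flags over the websocket, shaped like SocketsMetadata.
    negotiated = {"sid-1": {"feature_flags": {"supports_preview_metadata": True}}}
    print(feature_flags.supports_feature(negotiated, "sid-1", "supports_preview_metadata"))  # expected: True

    # The distributed progress context returns the DefaultSocketsMetadata sentinel instead of
    # per-socket metadata, so support cannot be determined and the check falls back to force (True by default).
    print(feature_flags.supports_feature({"__unimplemented": True}, "sid-1", "supports_preview_metadata"))  # expected: True

    # Callers that prefer the conservative behaviour can pass force=False.
    print(feature_flags.supports_feature({"__unimplemented": True}, "sid-1", "supports_preview_metadata", force=False))  # expected: False

This matches the commit note that feature flags are assumed to support everything in the distributed progress context, so previews and progress_state messages are not silently dropped there.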