Improvements for Wan 2.2 support

- add xet support and add the xet cache to the manageable directories
  - xet is enabled by default
- fix places that logged to the root logger instead of a module logger
- improve logging around model loading and unloading
- TorchCompileModel now supports compiling the VAE
- a missing torchaudio now logs at debug level instead of a warning, producing less log noise
- feature flags are assumed to support everything when socket metadata is unavailable, e.g. in the distributed progress context
- fix progress notifications
doctorpangloss 2025-07-28 14:36:27 -07:00
parent b5a50301f6
commit 03e5430121
19 changed files with 192 additions and 82 deletions

View File

@ -84,7 +84,7 @@ def setup_logger(log_level: str = 'INFO', capacity: int = 300, use_stdout: bool
stream_handler = logging.StreamHandler()
stream_handler.setFormatter(logging.Formatter(
"%(asctime)s [%(name)s] [%(levelname)s] [%(filename)s:%(lineno)d] %(message)s",
"%(asctime)s [%(levelname)s] [%(name)s] [%(filename)s:%(lineno)d] %(message)s",
datefmt="%Y-%m-%d %H:%M:%S"
))
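For reference, a record emitted under the reordered format would look roughly like the following (logger name, file and line are hypothetical); the level now precedes the logger name:

2025-07-28 14:36:27 [INFO] [comfy.cmd.server] [server.py:281] Feature flags negotiated for client abc123: {'supports_preview_metadata': True}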

View File

@ -86,6 +86,9 @@ def init_default_paths(folder_names_and_paths: FolderNames, configuration: Optio
if "HF_HUB_CACHE" in os.environ:
hf_cache_paths.additional_absolute_directory_paths.append(os.environ.get("HF_HUB_CACHE"))
hf_xet = ModelPaths(["xet"], supported_extensions=set())
if "HF_XET_CACHE" in os.environ:
hf_xet.additional_absolute_directory_paths.append(os.environ.get("HF_XET_CACHE"))
model_paths_to_add = [
ModelPaths(["checkpoints"], supported_extensions=set(supported_pt_extensions)),
ModelPaths(["configs"], additional_absolute_directory_paths=[get_package_as_path("comfy.configs")], supported_extensions={".yaml"}),
@ -107,6 +110,7 @@ def init_default_paths(folder_names_and_paths: FolderNames, configuration: Optio
ModelPaths(["classifiers"], supported_extensions=set()),
ModelPaths(["huggingface"], supported_extensions=set()),
hf_cache_paths,
hf_xet,
]
for model_paths in model_paths_to_add:
if replace_existing:
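With both environment variables honored, the Hugging Face caches can be relocated onto storage that ComfyUI manages; a minimal sketch, set before launching (the paths are hypothetical, and the Hugging Face defaults are used when the variables are unset):

import os

os.environ["HF_HUB_CACHE"] = "/data/hf/hub"   # hypothetical location for the hub cache
os.environ["HF_XET_CACHE"] = "/data/hf/xet"   # hypothetical location for the xet chunk cache
# init_default_paths() then registers both directories as manageable model paths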

View File

@ -18,7 +18,7 @@ from .. import options
from ..app import logger
os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
os.environ["HF_XET_HIGH_PERFORMANCE"] = "True"
os.environ["TORCHINDUCTOR_FX_GRAPH_CACHE"] = "1"
os.environ["TORCHINDUCTOR_AUTOGRAD_CACHE"] = "1"
os.environ["BITSANDBYTES_NOWELCOME"] = "1"

View File

@ -221,7 +221,7 @@ class PromptServer(ExecutorToClientProgress):
handler_args={'max_field_size': 16380},
middlewares=middlewares)
self.sockets = dict()
self.sockets_metadata = dict()
self._sockets_metadata = dict()
self.web_root = (
FrontendManager.init_frontend(args.front_end_version)
if args.front_end_root is None
@ -278,16 +278,16 @@ class PromptServer(ExecutorToClientProgress):
sid,
)
logging.info(
logger.info(
f"Feature flags negotiated for client {sid}: {client_flags}"
)
first_message = False
except json.JSONDecodeError:
logging.warning(
logger.warning(
f"Invalid JSON received from client {sid}: {msg.data}"
)
except Exception as e:
logging.error(f"Error processing WebSocket message: {e}")
logger.error(f"Error processing WebSocket message: {e}")
finally:
self.sockets.pop(sid, None)
self.sockets_metadata.pop(sid, None)
@ -1236,3 +1236,11 @@ class PromptServer(ExecutorToClientProgress):
message = encode_text_for_progress(node_id, text)
self.send_sync(BinaryEventTypes.TEXT, message, sid)
@property
def sockets_metadata(self):
return self._sockets_metadata
@sockets_metadata.setter
def sockets_metadata(self, value):
self._sockets_metadata = value

View File

@ -1,10 +1,11 @@
from __future__ import annotations # for Python 3.7-3.9
import PIL.Image
import concurrent.futures
import typing
from enum import Enum
from typing import Optional, Literal, Protocol, Union, NamedTuple, List
import PIL.Image
from typing_extensions import NotRequired, TypedDict
from .encode_text_for_progress import encode_text_for_progress
@ -79,11 +80,39 @@ class DependencyExecutionErrorMessage(TypedDict):
current_inputs: list[typing.Never]
class ActiveNodeProgressState(TypedDict, total=True):
value: float
max: float
# a string value from the NodeState enum
state: Literal["pending", "running", "finished", "error"]
node_id: str
prompt_id: str
display_node_id: str
parent_node_id: str
real_node_id: str
class ProgressStateMessage(TypedDict, total=True):
prompt_id: str
nodes: dict[str, ActiveNodeProgressState]
ExecutedMessage = ExecutingMessage
SendSyncEvent = Union[Literal["status", "execution_error", "executing", "progress", "executed"], BinaryEventTypes, None]
SendSyncEvent = Union[Literal["status", "execution_error", "executing", "progress", "executed", "progress_state"], BinaryEventTypes, None]
SendSyncData = Union[StatusMessage, ExecutingMessage, DependencyExecutionErrorMessage, ExecutionErrorMessage, ExecutionInterruptedMessage, ProgressMessage, UnencodedPreviewImageMessage, tuple[UnencodedPreviewImageMessage, PreviewImageMetadata], bytes, bytearray, str, None]
SendSyncData = Union[ProgressStateMessage, StatusMessage, ExecutingMessage, DependencyExecutionErrorMessage, ExecutionErrorMessage, ExecutionInterruptedMessage, ProgressMessage, UnencodedPreviewImageMessage, tuple[PIL.Image.Image, PreviewImageMetadata], bytes, bytearray, str, None]
class SocketsMetadata(TypedDict, total=True):
feature_flags: dict[str, typing.Any]
class DefaultSocketsMetadata(TypedDict, total=True):
__unimplemented: Literal[True]
SocketsMetadataType = dict[str, SocketsMetadata] | DefaultSocketsMetadata
class ExecutorToClientProgress(Protocol):
@ -108,6 +137,10 @@ class ExecutorToClientProgress(Protocol):
"""
return False
@property
def sockets_metadata(self) -> SocketsMetadataType:
return {"__unimplemented": True}
def send_sync(self,
event: SendSyncEvent,
data: SendSyncData,

View File

@ -9,7 +9,7 @@ logger = logging.getLogger(__name__)
try:
import torchaudio # pylint: disable=import-error
except:
logger.warning("torchaudio missing, ACE model will be broken")
logger.debug("torchaudio missing, ACE model will be broken")
import torchvision.transforms as transforms
from .music_vocoder import ADaMoSHiFiGANV1

View File

@ -10,7 +10,7 @@ logger = logging.getLogger(__name__)
try:
from torchaudio.transforms import MelScale # pylint: disable=import-error
except:
logger.warning("torchaudio missing, ACE model will be broken")
logger.debug("torchaudio missing, ACE model will be broken")
from .... import model_management

View File

@ -10,6 +10,8 @@ from .... import ops
ops = ops.disable_weight_init
logger = logging.getLogger(__name__)
if model_management.xformers_enabled_vae():
import xformers # pylint: disable=import-error
import xformers.ops # pylint: disable=import-error
@ -242,7 +244,7 @@ def slice_attention(q, k, v):
steps *= 2
if steps > 128:
raise e
logging.warning("out of memory error, increasing steps and trying again {}".format(steps))
logger.warning("out of memory error, increasing steps and trying again {}".format(steps))
return r1
@ -296,20 +298,20 @@ def pytorch_attention(q, k, v):
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
out = out.transpose(2, 3).reshape(orig_shape)
except model_management.OOM_EXCEPTION:
logging.warning("scaled_dot_product_attention OOMed: switched to slice attention")
logger.warning("scaled_dot_product_attention OOMed: switched to slice attention")
out = slice_attention(q.view(B, -1, C), k.view(B, -1, C).transpose(1, 2), v.view(B, -1, C).transpose(1, 2)).reshape(orig_shape)
return out
def vae_attention():
if model_management.xformers_enabled_vae():
logging.debug("Using xformers attention in VAE")
logger.debug("Using xformers attention in VAE")
return xformers_attention
elif model_management.pytorch_attention_enabled_vae():
logging.debug("Using pytorch attention in VAE")
logger.debug("Using pytorch attention in VAE")
return pytorch_attention
else:
logging.debug("Using split attention in VAE")
logger.debug("Using split attention in VAE")
return normal_attention
@ -650,7 +652,7 @@ class Decoder(nn.Module):
block_in = ch * ch_mult[self.num_resolutions - 1]
curr_res = resolution // 2 ** (self.num_resolutions - 1)
self.z_shape = (1, z_channels, curr_res, curr_res)
logging.debug("Working with z of shape {} = {} dimensions.".format(
logger.debug("Working with z of shape {} = {} dimensions.".format(
self.z_shape, np.prod(self.z_shape)))
# z to block_in

View File

@ -16,13 +16,13 @@
along with this program. If not, see <https://www.gnu.org/licenses/>.
"""
import math
import logging
import torch
import math
from enum import Enum
from typing import TypeVar, Type, Protocol, Any, Optional
import torch
from . import conds
from . import latent_formats
from . import model_management
@ -50,9 +50,9 @@ from .ldm.modules.diffusionmodules.mmdit import OpenAISignatureMMDITWrapper
from .ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep
from .ldm.modules.diffusionmodules.upscaling import ImageConcatWithNoiseAugmentation
from .ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugmentation
from .ldm.omnigen.omnigen2 import OmniGen2Transformer2DModel
from .ldm.pixart.pixartms import PixArtMS
from .ldm.wan.model import WanModel, VaceWanModel, CameraWanModel
from .ldm.omnigen.omnigen2 import OmniGen2Transformer2DModel
from .model_management_types import ModelManageable
from .model_sampling import CONST, ModelSamplingDiscreteFlow, ModelSamplingFlux, IMG_TO_IMG
from .model_sampling import StableCascadeSampling, COSMOS_RFLOW, ModelSamplingCosmosRFlow, V_PREDICTION, \
@ -60,6 +60,8 @@ from .model_sampling import StableCascadeSampling, COSMOS_RFLOW, ModelSamplingCo
from .ops import Operations
from .patcher_extension import WrapperExecutor, WrappersMP, get_all_wrappers
logger = logging.getLogger(__name__)
class ModelType(Enum):
EPS = 1
@ -149,8 +151,8 @@ class BaseModel(torch.nn.Module):
self.diffusion_model = unet_model(**unet_config, device=device, operations=operations)
if model_management.force_channels_last():
self.diffusion_model.to(memory_format=torch.channels_last)
logging.debug("using channels last mode for diffusion model")
logging.debug("model weight dtype {}, manual cast: {}".format(self.get_dtype(), self.manual_cast_dtype))
logger.debug("using channels last mode for diffusion model")
logger.debug("model weight dtype {}, manual cast: {}".format(self.get_dtype(), self.manual_cast_dtype))
else:
self.operations = None
self.model_type = model_type
@ -161,8 +163,8 @@ class BaseModel(torch.nn.Module):
self.adm_channels = 0
self.concat_keys = ()
logging.debug("model_type {}".format(model_type.name))
logging.debug("adm {}".format(self.adm_channels))
logger.debug("model_type {}".format(model_type.name))
logger.debug("adm {}".format(self.adm_channels))
self.memory_usage_factor = model_config.memory_usage_factor
self.memory_usage_factor_conds = ()
self.training = False
@ -310,10 +312,10 @@ class BaseModel(torch.nn.Module):
to_load = self.model_config.process_unet_state_dict(to_load)
m, u = self.diffusion_model.load_state_dict(to_load, strict=False)
if len(m) > 0:
logging.warning("unet missing: {}".format(m))
logger.warning("unet missing: {}".format(m))
if len(u) > 0:
logging.warning("unet unexpected: {}".format(u))
logger.warning("unet unexpected: {}".format(u))
del to_load
return self
@ -1227,6 +1229,7 @@ class WAN21_Camera(WAN21):
out['camera_conditions'] = conds.CONDRegular(camera_conditions)
return out
class WAN22(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
super().__init__(model_config, model_type, device=device, unet_model=WanModel)
@ -1252,6 +1255,7 @@ class WAN22(BaseModel):
def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs):
return latent_image
class Hunyuan3Dv2(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device, unet_model=Hunyuan3Dv2Model)
@ -1321,6 +1325,7 @@ class ACEStep(BaseModel):
out['lyrics_strength'] = conds.CONDConstant(kwargs.get("lyrics_strength", 1.0))
return out
class Omnigen2(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device, unet_model=OmniGen2Transformer2DModel)

View File

@ -14,7 +14,7 @@ from pathlib import Path
from typing import List, Optional, Final, Set
# enable better transfer
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
os.environ["HF_XET_HIGH_PERFORMANCE"] = "True"
import tqdm
from huggingface_hub import hf_hub_download, scan_cache_dir, snapshot_download, HfFileSystem
@ -486,6 +486,7 @@ KNOWN_VAES: Final[KnownDownloadables] = KnownDownloadables([
HuggingFile("comfyanonymous/cosmos_1.0_text_encoder_and_VAE_ComfyUI", "vae/cosmos_cv8x8x8_1.0.safetensors"),
HuggingFile("Comfy-Org/Lumina_Image_2.0_Repackaged", "split_files/vae/ae.safetensors", save_with_filename="lumina_image_2.0-ae.safetensors"),
HuggingFile("Comfy-Org/Wan_2.1_ComfyUI_repackaged", "split_files/vae/wan_2.1_vae.safetensors"),
HuggingFile("Comfy-Org/Wan_2.2_ComfyUI_Repackaged", "split_files/vae/wan2.2_vae.safetensors"),
], folder_name="vae")
KNOWN_HUGGINGFACE_MODEL_REPOS: Final[Set[str]] = {
@ -546,6 +547,15 @@ KNOWN_UNET_MODELS: Final[KnownDownloadables] = KnownDownloadables([
HuggingFile("Comfy-Org/Cosmos_Predict2_repackaged", "cosmos_predict2_2B_video2world_480p_16fps.safetensors"),
HuggingFile("Comfy-Org/Wan_2.1_ComfyUI_repackaged", "split_files/diffusion_models/wan2.1_vace_14B_fp16.safetensors"),
HuggingFile("Comfy-Org/Wan_2.1_ComfyUI_repackaged", "split_files/diffusion_models/wan2.1_fun_camera_v1.1_1.3B_bf16.safetensors"),
HuggingFile("Comfy-Org/Wan_2.2_ComfyUI_Repackaged", "split_files/diffusion_models/wan2.2_i2v_high_noise_14B_fp16.safetensors"),
HuggingFile("Comfy-Org/Wan_2.2_ComfyUI_Repackaged", "split_files/diffusion_models/wan2.2_i2v_high_noise_14B_fp8_scaled.safetensors"),
HuggingFile("Comfy-Org/Wan_2.2_ComfyUI_Repackaged", "split_files/diffusion_models/wan2.2_i2v_low_noise_14B_fp16.safetensors"),
HuggingFile("Comfy-Org/Wan_2.2_ComfyUI_Repackaged", "split_files/diffusion_models/wan2.2_i2v_low_noise_14B_fp8_scaled.safetensors"),
HuggingFile("Comfy-Org/Wan_2.2_ComfyUI_Repackaged", "split_files/diffusion_models/wan2.2_t2v_high_noise_14B_fp16.safetensors"),
HuggingFile("Comfy-Org/Wan_2.2_ComfyUI_Repackaged", "split_files/diffusion_models/wan2.2_t2v_high_noise_14B_fp8_scaled.safetensors"),
HuggingFile("Comfy-Org/Wan_2.2_ComfyUI_Repackaged", "split_files/diffusion_models/wan2.2_t2v_low_noise_14B_fp16.safetensors"),
HuggingFile("Comfy-Org/Wan_2.2_ComfyUI_Repackaged", "split_files/diffusion_models/wan2.2_t2v_low_noise_14B_fp8_scaled.safetensors"),
HuggingFile("Comfy-Org/Wan_2.2_ComfyUI_Repackaged", "split_files/diffusion_models/wan2.2_ti2v_5B_fp16.safetensors"),
HuggingFile("lodestones/Chroma", "chroma-unlocked-v37.safetensors"),
], folder_names=["diffusion_models", "unet"])
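The new entries resolve through huggingface_hub like the rest of the list; a minimal sketch of fetching one of the Wan 2.2 files directly (xet-accelerated transfer applies when hf_xet is installed):

from huggingface_hub import hf_hub_download

# stored under HF_HUB_CACHE; with hf_xet active the chunk store lives in HF_XET_CACHE
path = hf_hub_download(
    "Comfy-Org/Wan_2.2_ComfyUI_Repackaged",
    "split_files/diffusion_models/wan2.2_ti2v_5B_fp16.safetensors",
)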

View File

@ -742,7 +742,6 @@ def _load_models_gpu(models: Sequence[ModelManageable], memory_required: int = 0
loaded_model.model_load(lowvram_model_memory, force_patch_weights=force_patch_weights)
current_loaded_models.insert(0, loaded_model)
logger.debug(f"Loaded {loaded_model}")
span = get_current_span()
span.set_attribute("models_to_load", list(map(str, models_to_load)))

View File

@ -29,6 +29,7 @@ from typing import Callable, Optional
import torch
import torch.nn
from humanize import naturalsize
from natsort import natsorted
from . import model_management, lora
from . import patcher_extension
@ -119,6 +120,7 @@ def wipe_lowvram_weight(m):
if hasattr(m, "bias_function"):
m.bias_function = []
def move_weight_functions(m, device):
if device is None:
return 0
@ -289,7 +291,7 @@ class ModelPatcher(ModelManageable):
return self._force_cast_weights
@force_cast_weights.setter
def force_cast_weights(self, value:bool) -> None:
def force_cast_weights(self, value: bool) -> None:
self._force_cast_weights = value
def lowvram_patch_counter(self):
@ -475,7 +477,7 @@ class ModelPatcher(ModelManageable):
self.add_object_patch("manual_cast_dtype", dtype)
if dtype is not None:
self.force_cast_weights = True
self.patches_uuid = uuid.uuid4() #TODO: optimize by preventing a full model reload for this
self.patches_uuid = uuid.uuid4() # TODO: optimize by preventing a full model reload for this
def add_weight_wrapper(self, name, function):
self.weight_wrapper_patches[name] = self.weight_wrapper_patches.get(name, []) + [function]
@ -630,7 +632,6 @@ class ModelPatcher(ModelManageable):
else:
set_func(out_weight, inplace_update=inplace_update, seed=string_to_seed(key))
def _load_list(self) -> list[LoadingListItem]:
loading = []
for n, m in self.model.named_modules():
@ -715,6 +716,7 @@ class ModelPatcher(ModelManageable):
mem_counter += move_weight_functions(m, device_to)
load_completely.sort(reverse=True)
models_loaded_regularly: list[str] = []
for x in load_completely:
n = x.name
m = x.module
@ -726,17 +728,17 @@ class ModelPatcher(ModelManageable):
for param in params:
self.patch_weight_to_device("{}.{}".format(n, param), device_to=device_to)
logger.debug("lowvram: loaded module regularly {} {}".format(n, m))
models_loaded_regularly.append("name={} module={}".format(n, m))
m.comfy_patched_weights = True
logger.debug("lowvram: loaded module regularly: {}".format(", ".join(models_loaded_regularly)))
for x in load_completely:
x.module.to(device_to)
if lowvram_counter > 0:
logger.debug("loaded partially {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), patch_counter))
logger.debug("loaded partially lowvram_model_memory={}MB mem_counter={}MB patch_counter={}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), patch_counter))
self._memory_measurements.model_lowvram = True
else:
logger.debug("loaded completely {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), full_load))
logger.debug("loaded completely lowvram_model_memory={}MB mem_counter={}MB full_load={}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), full_load))
self._memory_measurements.model_lowvram = False
if full_load:
self.model.to(device_to)
@ -812,6 +814,7 @@ class ModelPatcher(ModelManageable):
self.object_patches_backup.clear()
def partially_unload(self, device_to, memory_to_free=0):
freed_layers: list[str] = []
with self.use_ejected():
hooks_unpatched = False
memory_freed = 0
@ -867,7 +870,9 @@ class ModelPatcher(ModelManageable):
m.comfy_cast_weights = True
m.comfy_patched_weights = False
memory_freed += module_mem
logging.debug("freed {}".format(n))
freed_layers.append(n)
logger.debug("freed {}".format(natsorted(freed_layers)))
self._memory_measurements.model_lowvram = True
self._memory_measurements.lowvram_patch_counter += patch_counter
@ -1190,7 +1195,7 @@ class ModelPatcher(ModelManageable):
model_sd_keys_set = set(model_sd_keys)
for key in cached_weights:
if key not in model_sd_keys:
logging.warning(f"Cached hook could not patch. Key does not exist in model: {key}")
logger.warning(f"Cached hook could not patch. Key does not exist in model: {key}")
continue
self.patch_cached_hook_weights(cached_weights=cached_weights, key=key, memory_counter=memory_counter)
model_sd_keys_set.remove(key)
@ -1203,7 +1208,7 @@ class ModelPatcher(ModelManageable):
original_weights = self.get_key_patches()
for key in relevant_patches:
if key not in model_sd_keys:
logging.warning(f"Cached hook would not patch. Key does not exist in model: {key}")
logger.warning(f"Cached hook would not patch. Key does not exist in model: {key}")
continue
self.patch_hook_weight_to_device(hooks=hooks, combined_patches=relevant_patches, key=key, original_weights=original_weights,
memory_counter=memory_counter)
@ -1265,7 +1270,7 @@ class ModelPatcher(ModelManageable):
del out_weight
del weight
def unpatch_hooks(self, whitelist_keys_set: set[str]=None) -> None:
def unpatch_hooks(self, whitelist_keys_set: set[str] = None) -> None:
with self.use_ejected():
if len(self.hook_backup) == 0:
self.current_hooks = None

View File

@ -7,9 +7,10 @@ from PIL import Image
from tqdm import tqdm
from typing_extensions import override
from .component_model.executor_types import ExecutorToClientProgress
from .component_model.module_property import create_module_properties
from .execution_context import current_execution_context
from .progress_types import AbstractProgressRegistry
from .progress_types import AbstractProgressRegistry, PreviewImageMetadata
if TYPE_CHECKING:
from comfy_execution.graph import DynamicPrompt
@ -157,7 +158,7 @@ class WebUIProgressHandler(ProgressHandler):
Handler that sends progress updates to the WebUI via WebSockets.
"""
def __init__(self, server_instance):
def __init__(self, server_instance: ExecutorToClientProgress):
super().__init__("webui")
self.server_instance = server_instance
@ -216,7 +217,7 @@ class WebUIProgressHandler(ProgressHandler):
self.server_instance.client_id,
"supports_preview_metadata",
):
metadata = {
metadata: PreviewImageMetadata = {
"node_id": node_id,
"prompt_id": prompt_id,
"display_node_id": self.registry.dynprompt.get_display_node_id(
@ -327,7 +328,7 @@ class ProgressRegistry(AbstractProgressRegistry):
# Global registry instance
@_module_properties.getter
def _global_progress_registry() -> ProgressRegistry:
def _global_progress_registry() -> AbstractProgressRegistry | None:
return current_execution_context().progress_registry

View File

@ -41,23 +41,23 @@ from .model_management import load_models_gpu
from .model_patcher import ModelPatcher
from .t2i_adapter import adapter
from .taesd import taesd
from .text_encoders import ace
from .text_encoders import aura_t5
from .text_encoders import hidream
from .text_encoders import cosmos
from .text_encoders import flux
from .text_encoders import genmo
from .text_encoders import hidream
from .text_encoders import hunyuan_video
from .text_encoders import hydit
from .text_encoders import long_clipl
from .text_encoders import lt
from .text_encoders import lumina2
from .text_encoders import omnigen2
from .text_encoders import pixart_t5
from .text_encoders import sa_t5
from .text_encoders import sd2_clip
from .text_encoders import sd3_clip
from .text_encoders import wan
from .text_encoders import ace
from .text_encoders import omnigen2
from .utils import ProgressBar
logger = logging.getLogger(__name__)
@ -280,7 +280,9 @@ class CLIP:
class VAE:
def __init__(self, sd=None, device=None, config=None, dtype=None, metadata=None):
def __init__(self, sd=None, device=None, config=None, dtype=None, metadata=None, no_init=False):
if no_init:
return
if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): # diffusers format
sd = diffusers_convert.convert_vae_state_dict(sd)
@ -469,7 +471,7 @@ class VAE:
ddconfig = {"embed_dim": 64, "num_freqs": 8, "include_pi": False, "heads": 16, "width": 1024, "num_decoder_layers": 16, "qkv_bias": False, "qk_norm": True, "geo_decoder_mlp_expand_ratio": mlp_expand, "geo_decoder_downsample_ratio": downsample_ratio, "geo_decoder_ln_post": ln_post}
self.first_stage_model = ShapeVAE(**ddconfig)
self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
elif "vocoder.backbone.channel_layers.0.0.bias" in sd: #Ace Step Audio
elif "vocoder.backbone.channel_layers.0.0.bias" in sd: # Ace Step Audio
self.first_stage_model = MusicDCAE(source_sample_rate=44100)
self.memory_used_encode = lambda shape, dtype: (shape[2] * 330) * model_management.dtype_size(dtype)
self.memory_used_decode = lambda shape, dtype: (shape[2] * shape[3] * 87000) * model_management.dtype_size(dtype)
@ -511,6 +513,29 @@ class VAE:
self.patcher = model_patcher.ModelPatcher(self.first_stage_model, load_device=self.device, offload_device=offload_device)
logger.debug("VAE load device: {}, offload device: {}, dtype: {}".format(self.device, offload_device, self.vae_dtype))
def clone(self):
n = VAE(no_init=True)
n.memory_used_encode = self.memory_used_encode
n.memory_used_decode = self.memory_used_decode
n.downscale_ratio = self.downscale_ratio
n.upscale_ratio = self.upscale_ratio
n.latent_channels = self.latent_channels
n.latent_dim = self.latent_dim
n.output_channels = self.output_channels
n.process_input = self.process_input
n.process_output = self.process_output
n.working_dtypes = self.working_dtypes.copy()
n.disable_offload = self.disable_offload
n.downscale_index_formula = self.downscale_index_formula
n.upscale_index_formula = self.upscale_index_formula
n.extra_1d_channel = self.extra_1d_channel
n.first_stage_model = self.first_stage_model
n.device = self.device
n.vae_dtype = self.vae_dtype
n.output_device = self.output_device
n.patcher = self.patcher.clone()
return n
def throw_exception_if_invalid(self):
if self.first_stage_model is None:
raise RuntimeError("ERROR: VAE is invalid: None\n\nIf the VAE is from a checkpoint loader node your checkpoint does not contain a valid VAE.")
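A minimal sketch of what clone() provides to downstream nodes: a copy that shares the first-stage weights but carries its own ModelPatcher, so object patches on the copy do not leak into the original (the vae variable is a hypothetical, already-loaded instance):

compiled_vae = vae.clone()
assert compiled_vae.first_stage_model is vae.first_stage_model  # weights are shared
assert compiled_vae.patcher is not vae.patcher                   # patches stay isolated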
@ -920,7 +945,7 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None)
elif clip_type == CLIPType.HIDREAM:
clip_target.clip = hidream.hidream_clip(**t5xxl_detect(clip_data),
clip_l=False, clip_g=False, t5=True, llama=False, dtype_llama=None, llama_scaled_fp8=None)
clip_l=False, clip_g=False, t5=True, llama=False, dtype_llama=None, llama_scaled_fp8=None)
clip_target.tokenizer = hidream.HiDreamTokenizer
else: # CLIPType.MOCHI
clip_target.clip = genmo.mochi_te(**t5xxl_detect(clip_data))
@ -945,7 +970,7 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None)
elif te_model == TEModel.LLAMA3_8:
clip_target.clip = hidream.hidream_clip(**llama_detect(clip_data),
clip_l=False, clip_g=False, t5=False, llama=True, dtype_t5=None, t5xxl_scaled_fp8=None)
clip_l=False, clip_g=False, t5=False, llama=True, dtype_t5=None, t5xxl_scaled_fp8=None)
clip_target.tokenizer = hidream.HiDreamTokenizer
elif te_model == TEModel.QWEN25_3B:
clip_target.clip = omnigen2.te(**llama_detect(clip_data))
@ -1033,6 +1058,7 @@ def model_detection_error_hint(path, state_dict):
return "\nHINT: This seems to be a Lora file and Lora files should be put in the lora folder and loaded with a lora loader node.."
return ""
def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_clip=True, embedding_directory=None, state_dict=None, config=None):
logger.warning("Warning: The load checkpoint with config function is deprecated and will eventually be removed, please use the other one.")
model, clip, vae, _ = load_checkpoint_guess_config(ckpt_path, output_vae=output_vae, output_clip=output_clip, output_clipvision=False, embedding_directory=embedding_directory, output_model=True)
@ -1097,7 +1123,6 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
return None
return (diffusion_model, None, VAE(sd={}), None) # The VAE object is there to throw an exception if it's actually used'
unet_weight_dtype = list(model_config.supported_inference_dtypes)
if model_config.scaled_fp8 is not None:
weight_dtype = None

View File

@ -41,6 +41,7 @@ from einops import rearrange
from torch.nn.functional import interpolate
from tqdm import tqdm
from comfy_api import feature_flags
from . import interruption, checkpoint_pickle
from .cli_args import args
from .component_model import files
@ -48,6 +49,7 @@ from .component_model.deprecation import _deprecate_method
from .component_model.executor_types import ExecutorToClientProgress, ProgressMessage
from .component_model.queue_types import BinaryEventTypes
from .execution_context import current_execution_context
from .progress import get_progress_state
MMAP_TORCH_FILES = args.mmap_torch_files
DISABLE_MMAP = args.disable_mmap
@ -1106,22 +1108,23 @@ def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap=8, upscale_amou
return tiled_scale_multidim(samples, function, (tile_y, tile_x), overlap=overlap, upscale_amount=upscale_amount, out_channels=out_channels, output_device=output_device, pbar=pbar)
def _progress_bar_update(value: float, total: float, preview_image_or_data: Optional[Any] = None, client_id: Optional[str] = None, server: Optional[ExecutorToClientProgress] = None, node_id: str = None):
server = server or current_execution_context().server
# todo: this should really be from the context. right now the server is behaving like a context
client_id = client_id or server.client_id
def _progress_bar_update(value: float, total: float, preview_image_or_data: Optional[Any] = None, client_id: Optional[str] = None, server: Optional[ExecutorToClientProgress] = None, node_id: str = None, prompt_id: str = None):
context = current_execution_context()
server = server or context.server
executing_context = context
prompt_id = prompt_id or executing_context.task_id or server.last_prompt_id
node_id = node_id or executing_context.node_id or server.last_node_id
interruption.throw_exception_if_processing_interrupted()
progress: ProgressMessage = {"value": value, "max": total, "prompt_id": server.last_prompt_id, "node": node_id or server.last_node_id}
progress: ProgressMessage = {"value": value, "max": total, "prompt_id": prompt_id, "node": node_id}
# todo: is this still necessary?
if isinstance(preview_image_or_data, dict):
progress["output"] = preview_image_or_data
# this is responsible for encoding the image
get_progress_state().update_progress(node_id, value, total, preview_image_or_data)
server.send_sync("progress", progress, client_id)
# todo: investigate a better way to send the image data, since it needs the node ID
if preview_image_or_data is not None and not isinstance(preview_image_or_data, dict):
server.send_sync(BinaryEventTypes.UNENCODED_PREVIEW_IMAGE, preview_image_or_data, client_id)
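Node code is unchanged; a ProgressBar keeps its usual interface, but each update now resolves the prompt id and node id from the execution context instead of the server's last_prompt_id/last_node_id, and also feeds the progress registry. A minimal sketch inside a node body (the step count is hypothetical):

from comfy.utils import ProgressBar

pbar = ProgressBar(10)
for _ in range(10):
    # ... do one step of work ...
    pbar.update(1)  # routes through get_progress_state().update_progress(...) and a "progress" send_sync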
def set_progress_bar_enabled(enabled: bool):
warnings.warn(
"The global method 'set_progress_bar_enabled' is deprecated and will be removed in a future version. Use current_execution_context().server.receive_all_progress_notifications instead.",

View File

@ -43,7 +43,8 @@ def get_connection_feature(
def supports_feature(
sockets_metadata: Dict[str, Dict[str, Any]],
sid: str,
feature_name: str
feature_name: str,
force=True,
) -> bool:
"""
Check if a connection supports a specific feature.
@ -52,10 +53,13 @@ def supports_feature(
sockets_metadata: Dictionary of socket metadata
sid: Session ID of the connection
feature_name: Name of the feature to check
force (bool): If it cannot be determined, assume True
Returns:
Boolean indicating if feature is supported
"""
if sockets_metadata is None or "__unimplemented" in sockets_metadata:
return force
return get_connection_feature(sockets_metadata, sid, feature_name, False) is True
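A minimal sketch of the fallback behavior, assuming the per-connection metadata shape used elsewhere in this commit ({sid: {"feature_flags": {...}}}); the feature name is hypothetical:

metadata = {"sid1": {"feature_flags": {"supports_preview_metadata": True}}}
supports_feature(metadata, "sid1", "supports_preview_metadata")   # True, read from the flags
supports_feature({"__unimplemented": True}, "sid1", "anything")    # True, assumed via force
supports_feature(None, "sid1", "anything", force=False)            # False, nothing to assume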

View File

@ -11,6 +11,7 @@ from comfy import model_management
from comfy.language.transformers_model_management import TransformersManagedModel
from comfy.model_patcher import ModelPatcher
from comfy.nodes.package_typing import CustomNode, InputTypes
from comfy.sd import VAE
from comfy_api.torch_helpers import set_torch_compile_wrapper
logger = logging.getLogger(__name__)
@ -45,6 +46,7 @@ def write_atomic(
torch._inductor.codecache.write_atomic = write_atomic
# torch._inductor.utils.is_big_gpu = lambda *args: True
class TorchCompileModel(CustomNode):
@ -52,10 +54,10 @@ class TorchCompileModel(CustomNode):
def INPUT_TYPES(s):
return {
"required": {
"model": ("MODEL",),
"model": ("MODEL,VAE",),
},
"optional": {
"object_patch": ("STRING", {"default": DIFFUSION_MODEL}),
"object_patch": ("STRING", {"default": ""}),
"fullgraph": ("BOOLEAN", {"default": False}),
"dynamic": ("BOOLEAN", {"default": False}),
"backend": (TORCH_COMPILE_BACKENDS, {"default": "inductor"}),
@ -64,15 +66,14 @@ class TorchCompileModel(CustomNode):
}
}
RETURN_TYPES = ("MODEL",)
RETURN_TYPES = ("MODEL,VAE",)
FUNCTION = "patch"
RETURN_NAMES = ("model or vae",)
CATEGORY = "_for_testing"
EXPERIMENTAL = True
def patch(self, model: ModelPatcher, object_patch: str | None = DIFFUSION_MODEL, fullgraph: bool = False, dynamic: bool = False, backend: str = "inductor", mode: str = "max-autotune", torch_tensorrt_optimization_level: int = 3) -> tuple[Callable]:
if object_patch is None:
object_patch = DIFFUSION_MODEL
def patch(self, model: ModelPatcher | VAE | torch.nn.Module, object_patch: str | None = "", fullgraph: bool = False, dynamic: bool = False, backend: str = "inductor", mode: str = "max-autotune", torch_tensorrt_optimization_level: int = 3) -> tuple[Callable]:
compile_kwargs = {
"fullgraph": fullgraph,
"dynamic": dynamic,
@ -99,17 +100,26 @@ class TorchCompileModel(CustomNode):
}
move_to_gpu = True
del compile_kwargs["mode"]
if isinstance(model, ModelPatcher) or isinstance(model, TransformersManagedModel):
m = model.clone()
if isinstance(model, (ModelPatcher, TransformersManagedModel, VAE)):
to_return = model.clone()
object_patches = [p.strip() for p in object_patch.split(",")]
patcher: ModelPatcher
if isinstance(to_return, VAE):
patcher = to_return.patcher
object_patches = ["encoder", "decoder"]
else:
patcher = to_return
if object_patch is None or len(object_patches) == 0:
object_patches = [DIFFUSION_MODEL]
if move_to_gpu:
model_management.unload_all_models()
model_management.load_models_gpu([m])
set_torch_compile_wrapper(m, object_patch=object_patch, **compile_kwargs)
m.add_object_patch(object_patch, torch.compile(model=m.get_model_object(object_patch), **compile_kwargs))
model_management.load_models_gpu([patcher])
set_torch_compile_wrapper(patcher, keys=object_patches, **compile_kwargs)
# m.add_object_patch(object_patch, torch.compile(model=m.get_model_object(object_patch), **compile_kwargs))
# todo: do we want to move something back off the GPU?
# if move_to_gpu:
# model_management.unload_all_models()
return m,
return to_return,
elif isinstance(model, torch.nn.Module):
if move_to_gpu:
model_management.unload_all_models()
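In a workflow the node now accepts either a MODEL or a VAE on its input; a minimal sketch of calling it directly (model_patcher and vae are hypothetical, already-loaded instances):

node = TorchCompileModel()

# a VAE: object_patch is ignored and encoder/decoder are compiled on the cloned patcher
compiled_vae, = node.patch(vae, backend="inductor", mode="max-autotune")

# a diffusion model: pass the object patch explicitly, or leave the default
compiled_model, = node.patch(model_patcher, object_patch="diffusion_model")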
@ -119,7 +129,7 @@ class TorchCompileModel(CustomNode):
model.to(device=model_management.unet_offload_device())
return res,
else:
logging.warning("Encountered a model that cannot be compiled")
logger.warning("Encountered a model that cannot be compiled")
return model,
except OSError as os_error:
try:
@ -132,7 +142,7 @@ class TorchCompileModel(CustomNode):
torch._inductor.utils.clear_inductor_caches() # pylint: disable=no-member
except Exception:
pass
logging.error(f"An exception occurred while trying to compile {str(model)}, gracefully skipping compilation", exc_info=exc_info)
logger.error(f"An exception occurred while trying to compile {str(model)}, gracefully skipping compilation", exc_info=exc_info)
return model,
@ -160,7 +170,7 @@ class QuantizeModel(CustomNode):
RETURN_TYPES = ("MODEL",)
def warn_in_place(self, model: ModelPatcher):
logging.warning(f"Quantizing {model} this way quantizes it in place, making it insuitable for cloning. All uses of this model will be quantized.")
logger.warning(f"Quantizing {model} this way quantizes it in place, making it insuitable for cloning. All uses of this model will be quantized.")
def execute(self, model: ModelPatcher, strategy: str = _QUANTIZATION_STRATEGIES[0]) -> tuple[ModelPatcher]:
model = model.clone()
@ -179,7 +189,7 @@ class QuantizeModel(CustomNode):
"final_layer",
}
if strategy == "quanto":
logging.warning(f"Quantizing {model} will produce poor results due to Optimum's limitations")
logger.warning(f"Quantizing {model} will produce poor results due to Optimum's limitations")
self.warn_in_place(model)
from optimum.quanto import quantize, qint8 # pylint: disable=import-error
exclusion_list = [

View File

@ -57,7 +57,7 @@ dependencies = [
"pyjwt[crypto]",
"kornia>=0.7.0",
"mpmath>=1.0,!=1.4.0a0",
"huggingface_hub[hf_transfer]>0.20",
"huggingface_hub[hf_xet]>=0.32.0",
"lazy-object-proxy",
"lazy_loader>=0.3",
"can_ada",
@ -76,7 +76,8 @@ dependencies = [
"wrapt>=1.16.0",
"certifi",
"spandrel>=0.3.4",
"numpy>=1.24.4",
# https://github.com/conda-forge/numba-feedstock/issues/158 until numba is released with support for a later version of numpy
"numpy>=1.24.4,<2.3",
"soundfile",
"watchdog",
"PySoundFile",

View File

@ -15,7 +15,7 @@ import requests
os.environ['OTEL_METRICS_EXPORTER'] = 'none'
os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
os.environ["HF_XET_HIGH_PERFORMANCE"] = "True"
# fixes issues with running the testcontainers rabbitmqcontainer on Windows
os.environ["TC_HOST"] = "localhost"