Improvements for Wan 2.2 support

- add xet support and add the xet cache to the manageable directories (see the sketch after this list)
  - xet is enabled by default
- fix logging to the root logger in various places
- improve logging about model unloading and loading
- TorchCompileModel now supports the VAE
- a missing torchaudio now causes less noise in the logs
- feature flags are assumed to support everything in the distributed progress context
- fix progress notifications
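
For example, the xet cache location can be steered with an environment variable and is then registered as a manageable directory. A minimal sketch: the HF_XET_CACHE handling and the "xet" folder name come from the diff below; the /data path is illustrative.

import os

# Registered as the "xet" manageable directory by init_default_paths() (see the diff below).
os.environ["HF_XET_CACHE"] = "/data/hf-xet-cache"

# xet transfers are on by default: huggingface_hub[hf_xet]>=0.32.0 is now a dependency and
# HF_XET_HIGH_PERFORMANCE=True is set at startup (see the environment hunks below).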
doctorpangloss 2025-07-28 14:36:27 -07:00
parent b5a50301f6
commit 03e5430121
19 changed files with 192 additions and 82 deletions

View File

@@ -84,7 +84,7 @@ def setup_logger(log_level: str = 'INFO', capacity: int = 300, use_stdout: bool
     stream_handler = logging.StreamHandler()
     stream_handler.setFormatter(logging.Formatter(
-        "%(asctime)s [%(name)s] [%(levelname)s] [%(filename)s:%(lineno)d] %(message)s",
+        "%(asctime)s [%(levelname)s] [%(name)s] [%(filename)s:%(lineno)d] %(message)s",
         datefmt="%Y-%m-%d %H:%M:%S"
     ))

View File

@@ -86,6 +86,9 @@ def init_default_paths(folder_names_and_paths: FolderNames, configuration: Optio
     if "HF_HUB_CACHE" in os.environ:
         hf_cache_paths.additional_absolute_directory_paths.append(os.environ.get("HF_HUB_CACHE"))
+    hf_xet = ModelPaths(["xet"], supported_extensions=set())
+    if "HF_XET_CACHE" in os.environ:
+        hf_xet.additional_absolute_directory_paths.append(os.environ.get("HF_XET_CACHE"))
     model_paths_to_add = [
         ModelPaths(["checkpoints"], supported_extensions=set(supported_pt_extensions)),
         ModelPaths(["configs"], additional_absolute_directory_paths=[get_package_as_path("comfy.configs")], supported_extensions={".yaml"}),
@@ -107,6 +110,7 @@ def init_default_paths(folder_names_and_paths: FolderNames, configuration: Optio
         ModelPaths(["classifiers"], supported_extensions=set()),
         ModelPaths(["huggingface"], supported_extensions=set()),
         hf_cache_paths,
+        hf_xet,
     ]
     for model_paths in model_paths_to_add:
         if replace_existing:

View File

@@ -18,7 +18,7 @@ from .. import options
 from ..app import logger

 os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"
-os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
+os.environ["HF_XET_HIGH_PERFORMANCE"] = "True"
 os.environ["TORCHINDUCTOR_FX_GRAPH_CACHE"] = "1"
 os.environ["TORCHINDUCTOR_AUTOGRAD_CACHE"] = "1"
 os.environ["BITSANDBYTES_NOWELCOME"] = "1"

View File

@@ -221,7 +221,7 @@ class PromptServer(ExecutorToClientProgress):
                                handler_args={'max_field_size': 16380},
                                middlewares=middlewares)
         self.sockets = dict()
-        self.sockets_metadata = dict()
+        self._sockets_metadata = dict()
         self.web_root = (
             FrontendManager.init_frontend(args.front_end_version)
             if args.front_end_root is None
@@ -278,16 +278,16 @@ class PromptServer(ExecutorToClientProgress):
                                 sid,
                             )
-                            logging.info(
+                            logger.info(
                                 f"Feature flags negotiated for client {sid}: {client_flags}"
                             )
                             first_message = False
                     except json.JSONDecodeError:
-                        logging.warning(
+                        logger.warning(
                             f"Invalid JSON received from client {sid}: {msg.data}"
                         )
                     except Exception as e:
-                        logging.error(f"Error processing WebSocket message: {e}")
+                        logger.error(f"Error processing WebSocket message: {e}")
                 finally:
                     self.sockets.pop(sid, None)
                     self.sockets_metadata.pop(sid, None)
@@ -1236,3 +1236,11 @@ class PromptServer(ExecutorToClientProgress):
         message = encode_text_for_progress(node_id, text)
         self.send_sync(BinaryEventTypes.TEXT, message, sid)
+
+    @property
+    def sockets_metadata(self):
+        return self._sockets_metadata
+
+    @sockets_metadata.setter
+    def sockets_metadata(self, value):
+        self._sockets_metadata = value

View File

@@ -1,10 +1,11 @@
 from __future__ import annotations  # for Python 3.7-3.9

+import PIL.Image
 import concurrent.futures
 import typing
 from enum import Enum
 from typing import Optional, Literal, Protocol, Union, NamedTuple, List

-import PIL.Image
 from typing_extensions import NotRequired, TypedDict

 from .encode_text_for_progress import encode_text_for_progress
@@ -79,11 +80,39 @@ class DependencyExecutionErrorMessage(TypedDict):
     current_inputs: list[typing.Never]


+class ActiveNodeProgressState(TypedDict, total=True):
+    value: float
+    max: float
+    # a string value from the NodeState enum
+    state: Literal["pending", "running", "finished", "error"]
+    node_id: str
+    prompt_id: str
+    display_node_id: str
+    parent_node_id: str
+    real_node_id: str
+
+
+class ProgressStateMessage(TypedDict, total=True):
+    prompt_id: str
+    nodes: dict[str, ActiveNodeProgressState]
+
+
 ExecutedMessage = ExecutingMessage

-SendSyncEvent = Union[Literal["status", "execution_error", "executing", "progress", "executed"], BinaryEventTypes, None]
+SendSyncEvent = Union[Literal["status", "execution_error", "executing", "progress", "executed", "progress_state"], BinaryEventTypes, None]

-SendSyncData = Union[StatusMessage, ExecutingMessage, DependencyExecutionErrorMessage, ExecutionErrorMessage, ExecutionInterruptedMessage, ProgressMessage, UnencodedPreviewImageMessage, tuple[UnencodedPreviewImageMessage, PreviewImageMetadata], bytes, bytearray, str, None]
+SendSyncData = Union[ProgressStateMessage, StatusMessage, ExecutingMessage, DependencyExecutionErrorMessage, ExecutionErrorMessage, ExecutionInterruptedMessage, ProgressMessage, UnencodedPreviewImageMessage, tuple[PIL.Image.Image, PreviewImageMetadata], bytes, bytearray, str, None]
+
+
+class SocketsMetadata(TypedDict, total=True):
+    feature_flags: dict[str, typing.Any]
+
+
+class DefaultSocketsMetadata(TypedDict, total=True):
+    __unimplemented: Literal[True]
+
+
+SocketsMetadataType = dict[str, SocketsMetadata] | DefaultSocketsMetadata
+

 class ExecutorToClientProgress(Protocol):
@@ -108,6 +137,10 @@ class ExecutorToClientProgress(Protocol):
         """
         return False

+    @property
+    def sockets_metadata(self) -> SocketsMetadataType:
+        return {"__unimplemented": True}
+
     def send_sync(self,
                   event: SendSyncEvent,
                   data: SendSyncData,
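
For reference, a minimal sketch of a progress_state payload built from the TypedDicts added above. It assumes those types are imported from this module; all IDs and values are made up.

progress_state: ProgressStateMessage = {
    "prompt_id": "prompt-123",
    "nodes": {
        "7": {
            "value": 5.0,
            "max": 20.0,
            "state": "running",
            "node_id": "7",
            "prompt_id": "prompt-123",
            "display_node_id": "7",
            "parent_node_id": "",
            "real_node_id": "7",
        }
    },
}
# A server implementation would forward this with send_sync("progress_state", progress_state, sid).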

View File

@@ -9,7 +9,7 @@ logger = logging.getLogger(__name__)
 try:
     import torchaudio  # pylint: disable=import-error
 except:
-    logger.warning("torchaudio missing, ACE model will be broken")
+    logger.debug("torchaudio missing, ACE model will be broken")
 import torchvision.transforms as transforms

 from .music_vocoder import ADaMoSHiFiGANV1

View File

@@ -10,7 +10,7 @@ logger = logging.getLogger(__name__)
 try:
     from torchaudio.transforms import MelScale  # pylint: disable=import-error
 except:
-    logger.warning("torchaudio missing, ACE model will be broken")
+    logger.debug("torchaudio missing, ACE model will be broken")

 from .... import model_management

View File

@@ -10,6 +10,8 @@ from .... import ops
 ops = ops.disable_weight_init

+logger = logging.getLogger(__name__)
+
 if model_management.xformers_enabled_vae():
     import xformers  # pylint: disable=import-error
     import xformers.ops  # pylint: disable=import-error
@@ -242,7 +244,7 @@ def slice_attention(q, k, v):
                 steps *= 2
                 if steps > 128:
                     raise e
-                logging.warning("out of memory error, increasing steps and trying again {}".format(steps))
+                logger.warning("out of memory error, increasing steps and trying again {}".format(steps))

     return r1
@@ -296,20 +298,20 @@ def pytorch_attention(q, k, v):
         out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
         out = out.transpose(2, 3).reshape(orig_shape)
     except model_management.OOM_EXCEPTION:
-        logging.warning("scaled_dot_product_attention OOMed: switched to slice attention")
+        logger.warning("scaled_dot_product_attention OOMed: switched to slice attention")
         out = slice_attention(q.view(B, -1, C), k.view(B, -1, C).transpose(1, 2), v.view(B, -1, C).transpose(1, 2)).reshape(orig_shape)
     return out


 def vae_attention():
     if model_management.xformers_enabled_vae():
-        logging.debug("Using xformers attention in VAE")
+        logger.debug("Using xformers attention in VAE")
         return xformers_attention
     elif model_management.pytorch_attention_enabled_vae():
-        logging.debug("Using pytorch attention in VAE")
+        logger.debug("Using pytorch attention in VAE")
         return pytorch_attention
     else:
-        logging.debug("Using split attention in VAE")
+        logger.debug("Using split attention in VAE")
         return normal_attention
@@ -650,7 +652,7 @@ class Decoder(nn.Module):
         block_in = ch * ch_mult[self.num_resolutions - 1]
         curr_res = resolution // 2 ** (self.num_resolutions - 1)
         self.z_shape = (1, z_channels, curr_res, curr_res)
-        logging.debug("Working with z of shape {} = {} dimensions.".format(
+        logger.debug("Working with z of shape {} = {} dimensions.".format(
             self.z_shape, np.prod(self.z_shape)))

         # z to block_in

View File

@@ -16,13 +16,13 @@
     along with this program. If not, see <https://www.gnu.org/licenses/>.
 """
-import math
 import logging
-import torch
+import math
 from enum import Enum
 from typing import TypeVar, Type, Protocol, Any, Optional

+import torch
+
 from . import conds
 from . import latent_formats
 from . import model_management
@@ -50,9 +50,9 @@ from .ldm.modules.diffusionmodules.mmdit import OpenAISignatureMMDITWrapper
 from .ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep
 from .ldm.modules.diffusionmodules.upscaling import ImageConcatWithNoiseAugmentation
 from .ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugmentation
+from .ldm.omnigen.omnigen2 import OmniGen2Transformer2DModel
 from .ldm.pixart.pixartms import PixArtMS
 from .ldm.wan.model import WanModel, VaceWanModel, CameraWanModel
-from .ldm.omnigen.omnigen2 import OmniGen2Transformer2DModel
 from .model_management_types import ModelManageable
 from .model_sampling import CONST, ModelSamplingDiscreteFlow, ModelSamplingFlux, IMG_TO_IMG
 from .model_sampling import StableCascadeSampling, COSMOS_RFLOW, ModelSamplingCosmosRFlow, V_PREDICTION, \
@@ -60,6 +60,8 @@ from .model_sampling import StableCascadeSampling, COSMOS_RFLOW, ModelSamplingCo
 from .ops import Operations
 from .patcher_extension import WrapperExecutor, WrappersMP, get_all_wrappers

+logger = logging.getLogger(__name__)
+

 class ModelType(Enum):
     EPS = 1
@@ -149,8 +151,8 @@ class BaseModel(torch.nn.Module):
             self.diffusion_model = unet_model(**unet_config, device=device, operations=operations)
             if model_management.force_channels_last():
                 self.diffusion_model.to(memory_format=torch.channels_last)
-                logging.debug("using channels last mode for diffusion model")
-            logging.debug("model weight dtype {}, manual cast: {}".format(self.get_dtype(), self.manual_cast_dtype))
+                logger.debug("using channels last mode for diffusion model")
+            logger.debug("model weight dtype {}, manual cast: {}".format(self.get_dtype(), self.manual_cast_dtype))
         else:
             self.operations = None
         self.model_type = model_type
@@ -161,8 +163,8 @@ class BaseModel(torch.nn.Module):
         self.adm_channels = 0

         self.concat_keys = ()
-        logging.debug("model_type {}".format(model_type.name))
-        logging.debug("adm {}".format(self.adm_channels))
+        logger.debug("model_type {}".format(model_type.name))
+        logger.debug("adm {}".format(self.adm_channels))
         self.memory_usage_factor = model_config.memory_usage_factor
         self.memory_usage_factor_conds = ()
         self.training = False
@@ -310,10 +312,10 @@ class BaseModel(torch.nn.Module):
         to_load = self.model_config.process_unet_state_dict(to_load)
         m, u = self.diffusion_model.load_state_dict(to_load, strict=False)
         if len(m) > 0:
-            logging.warning("unet missing: {}".format(m))
+            logger.warning("unet missing: {}".format(m))
         if len(u) > 0:
-            logging.warning("unet unexpected: {}".format(u))
+            logger.warning("unet unexpected: {}".format(u))
         del to_load
         return self
@@ -1227,6 +1229,7 @@ class WAN21_Camera(WAN21):
         out['camera_conditions'] = conds.CONDRegular(camera_conditions)
         return out

+
 class WAN22(BaseModel):
     def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
         super().__init__(model_config, model_type, device=device, unet_model=WanModel)
@@ -1252,6 +1255,7 @@ class WAN22(BaseModel):
     def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs):
         return latent_image

+
 class Hunyuan3Dv2(BaseModel):
     def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
         super().__init__(model_config, model_type, device=device, unet_model=Hunyuan3Dv2Model)
@@ -1321,6 +1325,7 @@ class ACEStep(BaseModel):
         out['lyrics_strength'] = conds.CONDConstant(kwargs.get("lyrics_strength", 1.0))
         return out

+
 class Omnigen2(BaseModel):
     def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
         super().__init__(model_config, model_type, device=device, unet_model=OmniGen2Transformer2DModel)

View File

@@ -14,7 +14,7 @@ from pathlib import Path
 from typing import List, Optional, Final, Set

 # enable better transfer
-os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
+os.environ["HF_XET_HIGH_PERFORMANCE"] = "True"

 import tqdm
 from huggingface_hub import hf_hub_download, scan_cache_dir, snapshot_download, HfFileSystem
@@ -486,6 +486,7 @@ KNOWN_VAES: Final[KnownDownloadables] = KnownDownloadables([
     HuggingFile("comfyanonymous/cosmos_1.0_text_encoder_and_VAE_ComfyUI", "vae/cosmos_cv8x8x8_1.0.safetensors"),
     HuggingFile("Comfy-Org/Lumina_Image_2.0_Repackaged", "split_files/vae/ae.safetensors", save_with_filename="lumina_image_2.0-ae.safetensors"),
     HuggingFile("Comfy-Org/Wan_2.1_ComfyUI_repackaged", "split_files/vae/wan_2.1_vae.safetensors"),
+    HuggingFile("Comfy-Org/Wan_2.2_ComfyUI_Repackaged", "split_files/vae/wan2.2_vae.safetensors"),
 ], folder_name="vae")

 KNOWN_HUGGINGFACE_MODEL_REPOS: Final[Set[str]] = {
@@ -546,6 +547,15 @@ KNOWN_UNET_MODELS: Final[KnownDownloadables] = KnownDownloadables([
     HuggingFile("Comfy-Org/Cosmos_Predict2_repackaged", "cosmos_predict2_2B_video2world_480p_16fps.safetensors"),
     HuggingFile("Comfy-Org/Wan_2.1_ComfyUI_repackaged", "split_files/diffusion_models/wan2.1_vace_14B_fp16.safetensors"),
     HuggingFile("Comfy-Org/Wan_2.1_ComfyUI_repackaged", "split_files/diffusion_models/wan2.1_fun_camera_v1.1_1.3B_bf16.safetensors"),
+    HuggingFile("Comfy-Org/Wan_2.2_ComfyUI_Repackaged", "split_files/diffusion_models/wan2.2_i2v_high_noise_14B_fp16.safetensors"),
+    HuggingFile("Comfy-Org/Wan_2.2_ComfyUI_Repackaged", "split_files/diffusion_models/wan2.2_i2v_high_noise_14B_fp8_scaled.safetensors"),
+    HuggingFile("Comfy-Org/Wan_2.2_ComfyUI_Repackaged", "split_files/diffusion_models/wan2.2_i2v_low_noise_14B_fp16.safetensors"),
+    HuggingFile("Comfy-Org/Wan_2.2_ComfyUI_Repackaged", "split_files/diffusion_models/wan2.2_i2v_low_noise_14B_fp8_scaled.safetensors"),
+    HuggingFile("Comfy-Org/Wan_2.2_ComfyUI_Repackaged", "split_files/diffusion_models/wan2.2_t2v_high_noise_14B_fp16.safetensors"),
+    HuggingFile("Comfy-Org/Wan_2.2_ComfyUI_Repackaged", "split_files/diffusion_models/wan2.2_t2v_high_noise_14B_fp8_scaled.safetensors"),
+    HuggingFile("Comfy-Org/Wan_2.2_ComfyUI_Repackaged", "split_files/diffusion_models/wan2.2_t2v_low_noise_14B_fp16.safetensors"),
+    HuggingFile("Comfy-Org/Wan_2.2_ComfyUI_Repackaged", "split_files/diffusion_models/wan2.2_t2v_low_noise_14B_fp8_scaled.safetensors"),
+    HuggingFile("Comfy-Org/Wan_2.2_ComfyUI_Repackaged", "split_files/diffusion_models/wan2.2_ti2v_5B_fp16.safetensors"),
     HuggingFile("lodestones/Chroma", "chroma-unlocked-v37.safetensors"),
 ], folder_names=["diffusion_models", "unet"])

View File

@@ -742,7 +742,6 @@ def _load_models_gpu(models: Sequence[ModelManageable], memory_required: int = 0
             loaded_model.model_load(lowvram_model_memory, force_patch_weights=force_patch_weights)
             current_loaded_models.insert(0, loaded_model)
-            logger.debug(f"Loaded {loaded_model}")
     span = get_current_span()
     span.set_attribute("models_to_load", list(map(str, models_to_load)))

View File

@@ -29,6 +29,7 @@ from typing import Callable, Optional
 import torch
 import torch.nn
 from humanize import naturalsize
+from natsort import natsorted

 from . import model_management, lora
 from . import patcher_extension
@@ -119,6 +120,7 @@ def wipe_lowvram_weight(m):
     if hasattr(m, "bias_function"):
         m.bias_function = []

+
 def move_weight_functions(m, device):
     if device is None:
         return 0
@@ -289,7 +291,7 @@ class ModelPatcher(ModelManageable):
         return self._force_cast_weights

     @force_cast_weights.setter
-    def force_cast_weights(self, value:bool) -> None:
+    def force_cast_weights(self, value: bool) -> None:
         self._force_cast_weights = value

     def lowvram_patch_counter(self):
@@ -475,7 +477,7 @@ class ModelPatcher(ModelManageable):
             self.add_object_patch("manual_cast_dtype", dtype)
             if dtype is not None:
                 self.force_cast_weights = True
-        self.patches_uuid = uuid.uuid4() #TODO: optimize by preventing a full model reload for this
+        self.patches_uuid = uuid.uuid4()  # TODO: optimize by preventing a full model reload for this

     def add_weight_wrapper(self, name, function):
         self.weight_wrapper_patches[name] = self.weight_wrapper_patches.get(name, []) + [function]
@@ -630,7 +632,6 @@ class ModelPatcher(ModelManageable):
             else:
                 set_func(out_weight, inplace_update=inplace_update, seed=string_to_seed(key))
-
     def _load_list(self) -> list[LoadingListItem]:
         loading = []
         for n, m in self.model.named_modules():
@@ -715,6 +716,7 @@ class ModelPatcher(ModelManageable):
                     mem_counter += move_weight_functions(m, device_to)

             load_completely.sort(reverse=True)
+            models_loaded_regularly: list[str] = []
             for x in load_completely:
                 n = x.name
                 m = x.module
@@ -726,17 +728,17 @@ class ModelPatcher(ModelManageable):
                     for param in params:
                         self.patch_weight_to_device("{}.{}".format(n, param), device_to=device_to)

-                    logger.debug("lowvram: loaded module regularly {} {}".format(n, m))
+                    models_loaded_regularly.append("name={} module={}".format(n, m))
                     m.comfy_patched_weights = True

+            logger.debug("lowvram: loaded module regularly: {}".format(", ".join(models_loaded_regularly)))
             for x in load_completely:
                 x.module.to(device_to)

             if lowvram_counter > 0:
-                logger.debug("loaded partially {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), patch_counter))
+                logger.debug("loaded partially lowvram_model_memory={}MB mem_counter={}MB patch_counter={}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), patch_counter))
                 self._memory_measurements.model_lowvram = True
             else:
-                logger.debug("loaded completely {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), full_load))
+                logger.debug("loaded completely lowvram_model_memory={}MB mem_counter={}MB full_load={}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), full_load))
                 self._memory_measurements.model_lowvram = False
                 if full_load:
                     self.model.to(device_to)
@@ -812,6 +814,7 @@ class ModelPatcher(ModelManageable):
         self.object_patches_backup.clear()

     def partially_unload(self, device_to, memory_to_free=0):
+        freed_layers: list[str] = []
         with self.use_ejected():
             hooks_unpatched = False
             memory_freed = 0
@@ -867,7 +870,9 @@ class ModelPatcher(ModelManageable):
                     m.comfy_cast_weights = True
                     m.comfy_patched_weights = False
                     memory_freed += module_mem
-                    logging.debug("freed {}".format(n))
+                    freed_layers.append(n)

+            logger.debug("freed {}".format(natsorted(freed_layers)))
             self._memory_measurements.model_lowvram = True
             self._memory_measurements.lowvram_patch_counter += patch_counter
@@ -1190,7 +1195,7 @@ class ModelPatcher(ModelManageable):
         model_sd_keys_set = set(model_sd_keys)
         for key in cached_weights:
             if key not in model_sd_keys:
-                logging.warning(f"Cached hook could not patch. Key does not exist in model: {key}")
+                logger.warning(f"Cached hook could not patch. Key does not exist in model: {key}")
                 continue
             self.patch_cached_hook_weights(cached_weights=cached_weights, key=key, memory_counter=memory_counter)
             model_sd_keys_set.remove(key)
@@ -1203,7 +1208,7 @@ class ModelPatcher(ModelManageable):
             original_weights = self.get_key_patches()
             for key in relevant_patches:
                 if key not in model_sd_keys:
-                    logging.warning(f"Cached hook would not patch. Key does not exist in model: {key}")
+                    logger.warning(f"Cached hook would not patch. Key does not exist in model: {key}")
                     continue
                 self.patch_hook_weight_to_device(hooks=hooks, combined_patches=relevant_patches, key=key, original_weights=original_weights,
                                                  memory_counter=memory_counter)
@@ -1265,7 +1270,7 @@ class ModelPatcher(ModelManageable):
             del out_weight
         del weight

-    def unpatch_hooks(self, whitelist_keys_set: set[str]=None) -> None:
+    def unpatch_hooks(self, whitelist_keys_set: set[str] = None) -> None:
         with self.use_ejected():
             if len(self.hook_backup) == 0:
                 self.current_hooks = None

View File

@@ -7,9 +7,10 @@ from PIL import Image
 from tqdm import tqdm
 from typing_extensions import override

+from .component_model.executor_types import ExecutorToClientProgress
 from .component_model.module_property import create_module_properties
 from .execution_context import current_execution_context
-from .progress_types import AbstractProgressRegistry
+from .progress_types import AbstractProgressRegistry, PreviewImageMetadata

 if TYPE_CHECKING:
     from comfy_execution.graph import DynamicPrompt
@@ -157,7 +158,7 @@ class WebUIProgressHandler(ProgressHandler):
     Handler that sends progress updates to the WebUI via WebSockets.
     """

-    def __init__(self, server_instance):
+    def __init__(self, server_instance: ExecutorToClientProgress):
         super().__init__("webui")
         self.server_instance = server_instance
@@ -216,7 +217,7 @@ class WebUIProgressHandler(ProgressHandler):
                 self.server_instance.client_id,
                 "supports_preview_metadata",
             ):
-                metadata = {
+                metadata: PreviewImageMetadata = {
                     "node_id": node_id,
                     "prompt_id": prompt_id,
                     "display_node_id": self.registry.dynprompt.get_display_node_id(
@@ -327,7 +328,7 @@ class ProgressRegistry(AbstractProgressRegistry):

 # Global registry instance
 @_module_properties.getter
-def _global_progress_registry() -> ProgressRegistry:
+def _global_progress_registry() -> AbstractProgressRegistry | None:
     return current_execution_context().progress_registry

View File

@@ -41,23 +41,23 @@ from .model_management import load_models_gpu
 from .model_patcher import ModelPatcher
 from .t2i_adapter import adapter
 from .taesd import taesd
+from .text_encoders import ace
 from .text_encoders import aura_t5
-from .text_encoders import hidream
 from .text_encoders import cosmos
 from .text_encoders import flux
 from .text_encoders import genmo
+from .text_encoders import hidream
 from .text_encoders import hunyuan_video
 from .text_encoders import hydit
 from .text_encoders import long_clipl
 from .text_encoders import lt
 from .text_encoders import lumina2
+from .text_encoders import omnigen2
 from .text_encoders import pixart_t5
 from .text_encoders import sa_t5
 from .text_encoders import sd2_clip
 from .text_encoders import sd3_clip
 from .text_encoders import wan
-from .text_encoders import ace
-from .text_encoders import omnigen2
 from .utils import ProgressBar

 logger = logging.getLogger(__name__)
@@ -280,7 +280,9 @@ class CLIP:

 class VAE:
-    def __init__(self, sd=None, device=None, config=None, dtype=None, metadata=None):
+    def __init__(self, sd=None, device=None, config=None, dtype=None, metadata=None, no_init=False):
+        if no_init:
+            return
         if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys():  # diffusers format
             sd = diffusers_convert.convert_vae_state_dict(sd)
@@ -469,7 +471,7 @@ class VAE:
             ddconfig = {"embed_dim": 64, "num_freqs": 8, "include_pi": False, "heads": 16, "width": 1024, "num_decoder_layers": 16, "qkv_bias": False, "qk_norm": True, "geo_decoder_mlp_expand_ratio": mlp_expand, "geo_decoder_downsample_ratio": downsample_ratio, "geo_decoder_ln_post": ln_post}
             self.first_stage_model = ShapeVAE(**ddconfig)
             self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
-        elif "vocoder.backbone.channel_layers.0.0.bias" in sd: #Ace Step Audio
+        elif "vocoder.backbone.channel_layers.0.0.bias" in sd:  # Ace Step Audio
             self.first_stage_model = MusicDCAE(source_sample_rate=44100)
             self.memory_used_encode = lambda shape, dtype: (shape[2] * 330) * model_management.dtype_size(dtype)
             self.memory_used_decode = lambda shape, dtype: (shape[2] * shape[3] * 87000) * model_management.dtype_size(dtype)
@@ -511,6 +513,29 @@ class VAE:
             self.patcher = model_patcher.ModelPatcher(self.first_stage_model, load_device=self.device, offload_device=offload_device)
             logger.debug("VAE load device: {}, offload device: {}, dtype: {}".format(self.device, offload_device, self.vae_dtype))

+    def clone(self):
+        n = VAE(no_init=True)
+        n.memory_used_encode = self.memory_used_encode
+        n.memory_used_decode = self.memory_used_decode
+        n.downscale_ratio = self.downscale_ratio
+        n.upscale_ratio = self.upscale_ratio
+        n.latent_channels = self.latent_channels
+        n.latent_dim = self.latent_dim
+        n.output_channels = self.output_channels
+        n.process_input = self.process_input
+        n.process_output = self.process_output
+        n.working_dtypes = self.working_dtypes.copy()
+        n.disable_offload = self.disable_offload
+        n.downscale_index_formula = self.downscale_index_formula
+        n.upscale_index_formula = self.upscale_index_formula
+        n.extra_1d_channel = self.extra_1d_channel
+        n.first_stage_model = self.first_stage_model
+        n.device = self.device
+        n.vae_dtype = self.vae_dtype
+        n.output_device = self.output_device
+        n.patcher = self.patcher.clone()
+        return n
+
     def throw_exception_if_invalid(self):
         if self.first_stage_model is None:
             raise RuntimeError("ERROR: VAE is invalid: None\n\nIf the VAE is from a checkpoint loader node your checkpoint does not contain a valid VAE.")
@@ -920,7 +945,7 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
             tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None)
         elif clip_type == CLIPType.HIDREAM:
             clip_target.clip = hidream.hidream_clip(**t5xxl_detect(clip_data),
                                                     clip_l=False, clip_g=False, t5=True, llama=False, dtype_llama=None, llama_scaled_fp8=None)
             clip_target.tokenizer = hidream.HiDreamTokenizer
         else:  # CLIPType.MOCHI
             clip_target.clip = genmo.mochi_te(**t5xxl_detect(clip_data))
@@ -945,7 +970,7 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
             tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None)
         elif te_model == TEModel.LLAMA3_8:
             clip_target.clip = hidream.hidream_clip(**llama_detect(clip_data),
                                                     clip_l=False, clip_g=False, t5=False, llama=True, dtype_t5=None, t5xxl_scaled_fp8=None)
             clip_target.tokenizer = hidream.HiDreamTokenizer
         elif te_model == TEModel.QWEN25_3B:
             clip_target.clip = omnigen2.te(**llama_detect(clip_data))
@@ -1033,6 +1058,7 @@ def model_detection_error_hint(path, state_dict):
         return "\nHINT: This seems to be a Lora file and Lora files should be put in the lora folder and loaded with a lora loader node.."
     return ""

+
 def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_clip=True, embedding_directory=None, state_dict=None, config=None):
     logger.warning("Warning: The load checkpoint with config function is deprecated and will eventually be removed, please use the other one.")
     model, clip, vae, _ = load_checkpoint_guess_config(ckpt_path, output_vae=output_vae, output_clip=output_clip, output_clipvision=False, embedding_directory=embedding_directory, output_model=True)
@@ -1097,7 +1123,6 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
             return None
         return (diffusion_model, None, VAE(sd={}), None)  # The VAE object is there to throw an exception if it's actually used'
-
     unet_weight_dtype = list(model_config.supported_inference_dtypes)
     if model_config.scaled_fp8 is not None:
         weight_dtype = None

View File

@@ -41,6 +41,7 @@ from einops import rearrange
 from torch.nn.functional import interpolate
 from tqdm import tqdm

+from comfy_api import feature_flags
 from . import interruption, checkpoint_pickle
 from .cli_args import args
 from .component_model import files
@@ -48,6 +49,7 @@ from .component_model.deprecation import _deprecate_method
 from .component_model.executor_types import ExecutorToClientProgress, ProgressMessage
 from .component_model.queue_types import BinaryEventTypes
 from .execution_context import current_execution_context
+from .progress import get_progress_state

 MMAP_TORCH_FILES = args.mmap_torch_files
 DISABLE_MMAP = args.disable_mmap
@@ -1106,22 +1108,23 @@ def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap=8, upscale_amou
     return tiled_scale_multidim(samples, function, (tile_y, tile_x), overlap=overlap, upscale_amount=upscale_amount, out_channels=out_channels, output_device=output_device, pbar=pbar)


-def _progress_bar_update(value: float, total: float, preview_image_or_data: Optional[Any] = None, client_id: Optional[str] = None, server: Optional[ExecutorToClientProgress] = None, node_id: str = None):
-    server = server or current_execution_context().server
-    # todo: this should really be from the context. right now the server is behaving like a context
-    client_id = client_id or server.client_id
+def _progress_bar_update(value: float, total: float, preview_image_or_data: Optional[Any] = None, client_id: Optional[str] = None, server: Optional[ExecutorToClientProgress] = None, node_id: str = None, prompt_id: str = None):
+    context = current_execution_context()
+    server = server or context.server
+    executing_context = context
+    prompt_id = prompt_id or executing_context.task_id or server.last_prompt_id
+    node_id = node_id or executing_context.node_id or server.last_node_id
     interruption.throw_exception_if_processing_interrupted()

-    progress: ProgressMessage = {"value": value, "max": total, "prompt_id": server.last_prompt_id, "node": node_id or server.last_node_id}
+    progress: ProgressMessage = {"value": value, "max": total, "prompt_id": prompt_id, "node": node_id}
+    # todo: is this still necessary?
     if isinstance(preview_image_or_data, dict):
         progress["output"] = preview_image_or_data
+    # this is responsible for encoding the image
+    get_progress_state().update_progress(node_id, value, total, preview_image_or_data)
     server.send_sync("progress", progress, client_id)
-    # todo: investigate a better way to send the image data, since it needs the node ID
-    if preview_image_or_data is not None and not isinstance(preview_image_or_data, dict):
-        server.send_sync(BinaryEventTypes.UNENCODED_PREVIEW_IMAGE, preview_image_or_data, client_id)


 def set_progress_bar_enabled(enabled: bool):
     warnings.warn(
         "The global method 'set_progress_bar_enabled' is deprecated and will be removed in a future version. Use current_execution_context().server.receive_all_progress_notifications instead.",

View File

@@ -43,7 +43,8 @@ def get_connection_feature(

 def supports_feature(
     sockets_metadata: Dict[str, Dict[str, Any]],
     sid: str,
-    feature_name: str
+    feature_name: str,
+    force=True,
 ) -> bool:
     """
     Check if a connection supports a specific feature.
@@ -52,10 +53,13 @@ def supports_feature(
         sockets_metadata: Dictionary of socket metadata
         sid: Session ID of the connection
         feature_name: Name of the feature to check
+        force (bool): If it cannot be determined, assume True

     Returns:
         Boolean indicating if feature is supported
     """
+    if sockets_metadata is None or "__unimplemented" in sockets_metadata:
+        return force
     return get_connection_feature(sockets_metadata, sid, feature_name, False) is True
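
To illustrate the new fallback in isolation: when a server does not track per-connection metadata (the distributed progress context returns the {"__unimplemented": True} default shown earlier), every feature is treated as supported. A minimal self-contained sketch, with the get_connection_feature lookup inlined as an assumption:

from typing import Any, Dict


def supports_feature(sockets_metadata: Dict[str, Any], sid: str, feature_name: str, force: bool = True) -> bool:
    # Fallback added in this commit: unknown metadata means "assume supported".
    if sockets_metadata is None or "__unimplemented" in sockets_metadata:
        return force
    # Simplified stand-in for get_connection_feature(): look up the per-connection flag.
    return sockets_metadata.get(sid, {}).get("feature_flags", {}).get(feature_name, False) is True


print(supports_feature({"__unimplemented": True}, "sid-1", "supports_preview_metadata"))          # True
print(supports_feature({"sid-1": {"feature_flags": {}}}, "sid-1", "supports_preview_metadata"))   # False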

View File

@@ -11,6 +11,7 @@ from comfy import model_management
 from comfy.language.transformers_model_management import TransformersManagedModel
 from comfy.model_patcher import ModelPatcher
 from comfy.nodes.package_typing import CustomNode, InputTypes
+from comfy.sd import VAE
 from comfy_api.torch_helpers import set_torch_compile_wrapper

 logger = logging.getLogger(__name__)
@@ -45,6 +46,7 @@ def write_atomic(

 torch._inductor.codecache.write_atomic = write_atomic

+# torch._inductor.utils.is_big_gpu = lambda *args: True

 class TorchCompileModel(CustomNode):
@@ -52,10 +54,10 @@ class TorchCompileModel(CustomNode):
     def INPUT_TYPES(s):
         return {
             "required": {
-                "model": ("MODEL",),
+                "model": ("MODEL,VAE",),
             },
             "optional": {
-                "object_patch": ("STRING", {"default": DIFFUSION_MODEL}),
+                "object_patch": ("STRING", {"default": ""}),
                 "fullgraph": ("BOOLEAN", {"default": False}),
                 "dynamic": ("BOOLEAN", {"default": False}),
                 "backend": (TORCH_COMPILE_BACKENDS, {"default": "inductor"}),
@@ -64,15 +66,14 @@ class TorchCompileModel(CustomNode):
             }
         }

-    RETURN_TYPES = ("MODEL",)
+    RETURN_TYPES = ("MODEL,VAE",)
     FUNCTION = "patch"
+    RETURN_NAMES = ("model or vae",)

     CATEGORY = "_for_testing"
     EXPERIMENTAL = True

-    def patch(self, model: ModelPatcher, object_patch: str | None = DIFFUSION_MODEL, fullgraph: bool = False, dynamic: bool = False, backend: str = "inductor", mode: str = "max-autotune", torch_tensorrt_optimization_level: int = 3) -> tuple[Callable]:
-        if object_patch is None:
-            object_patch = DIFFUSION_MODEL
+    def patch(self, model: ModelPatcher | VAE | torch.nn.Module, object_patch: str | None = "", fullgraph: bool = False, dynamic: bool = False, backend: str = "inductor", mode: str = "max-autotune", torch_tensorrt_optimization_level: int = 3) -> tuple[Callable]:
         compile_kwargs = {
             "fullgraph": fullgraph,
             "dynamic": dynamic,
@@ -99,17 +100,26 @@ class TorchCompileModel(CustomNode):
                 }
                 move_to_gpu = True
                 del compile_kwargs["mode"]
-            if isinstance(model, ModelPatcher) or isinstance(model, TransformersManagedModel):
-                m = model.clone()
+            if isinstance(model, (ModelPatcher, TransformersManagedModel, VAE)):
+                to_return = model.clone()
+                object_patches = [p.strip() for p in object_patch.split(",")]
+                patcher: ModelPatcher
+                if isinstance(to_return, VAE):
+                    patcher = to_return.patcher
+                    object_patches = ["encoder", "decoder"]
+                else:
+                    patcher = to_return
+                if object_patch is None or len(object_patches) == 0:
+                    object_patches = [DIFFUSION_MODEL]
                 if move_to_gpu:
                     model_management.unload_all_models()
-                    model_management.load_models_gpu([m])
-                set_torch_compile_wrapper(m, object_patch=object_patch, **compile_kwargs)
+                    model_management.load_models_gpu([patcher])
+                set_torch_compile_wrapper(patcher, keys=object_patches, **compile_kwargs)
+                # m.add_object_patch(object_patch, torch.compile(model=m.get_model_object(object_patch), **compile_kwargs))
                 # todo: do we want to move something back off the GPU?
                 # if move_to_gpu:
                 #     model_management.unload_all_models()
-                return m,
+                return to_return,
             elif isinstance(model, torch.nn.Module):
                 if move_to_gpu:
                     model_management.unload_all_models()
@@ -119,7 +129,7 @@ class TorchCompileModel(CustomNode):
                 model.to(device=model_management.unet_offload_device())
                 return res,
             else:
-                logging.warning("Encountered a model that cannot be compiled")
+                logger.warning("Encountered a model that cannot be compiled")
                 return model,
         except OSError as os_error:
             try:
@@ -132,7 +142,7 @@ class TorchCompileModel(CustomNode):
                 torch._inductor.utils.clear_inductor_caches()  # pylint: disable=no-member
             except Exception:
                 pass
-            logging.error(f"An exception occurred while trying to compile {str(model)}, gracefully skipping compilation", exc_info=exc_info)
+            logger.error(f"An exception occurred while trying to compile {str(model)}, gracefully skipping compilation", exc_info=exc_info)
             return model,
@@ -160,7 +170,7 @@ class QuantizeModel(CustomNode):
     RETURN_TYPES = ("MODEL",)

     def warn_in_place(self, model: ModelPatcher):
-        logging.warning(f"Quantizing {model} this way quantizes it in place, making it insuitable for cloning. All uses of this model will be quantized.")
+        logger.warning(f"Quantizing {model} this way quantizes it in place, making it insuitable for cloning. All uses of this model will be quantized.")

     def execute(self, model: ModelPatcher, strategy: str = _QUANTIZATION_STRATEGIES[0]) -> tuple[ModelPatcher]:
         model = model.clone()
@@ -179,7 +189,7 @@ class QuantizeModel(CustomNode):
             "final_layer",
         }
         if strategy == "quanto":
-            logging.warning(f"Quantizing {model} will produce poor results due to Optimum's limitations")
+            logger.warning(f"Quantizing {model} will produce poor results due to Optimum's limitations")
             self.warn_in_place(model)
             from optimum.quanto import quantize, qint8  # pylint: disable=import-error
             exclusion_list = [
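
A hedged usage sketch of the updated node, reflecting the hunk above: passing a VAE compiles its encoder and decoder through the VAE's patcher, while a regular model falls back to the diffusion model object patch. The vae and model inputs are assumed to come from ordinary loader nodes and are not defined here.

node = TorchCompileModel()

# VAE input: object_patch is ignored and ["encoder", "decoder"] are compiled on vae.patcher.
compiled_vae, = node.patch(vae, backend="inductor", mode="max-autotune")

# MODEL input with the new empty default for object_patch: falls back to DIFFUSION_MODEL.
compiled_model, = node.patch(model)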

View File

@@ -57,7 +57,7 @@ dependencies = [
     "pyjwt[crypto]",
     "kornia>=0.7.0",
     "mpmath>=1.0,!=1.4.0a0",
-    "huggingface_hub[hf_transfer]>0.20",
+    "huggingface_hub[hf_xet]>=0.32.0",
     "lazy-object-proxy",
     "lazy_loader>=0.3",
     "can_ada",
@@ -76,7 +76,8 @@ dependencies = [
     "wrapt>=1.16.0",
     "certifi",
     "spandrel>=0.3.4",
-    "numpy>=1.24.4",
+    # https://github.com/conda-forge/numba-feedstock/issues/158 until numba is released with support for a later version of numpy
+    "numpy>=1.24.4,<2.3",
     "soundfile",
     "watchdog",
     "PySoundFile",

View File

@@ -15,7 +15,7 @@ import requests

 os.environ['OTEL_METRICS_EXPORTER'] = 'none'
 os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"
-os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
+os.environ["HF_XET_HIGH_PERFORMANCE"] = "True"
 # fixes issues with running the testcontainers rabbitmqcontainer on Windows
 os.environ["TC_HOST"] = "localhost"