From f5e29f0e6189e7ecad2f412f2b85056ff215728d Mon Sep 17 00:00:00 2001
From: doctorpangloss <@hiddenswitch.com>
Date: Tue, 9 Sep 2025 12:58:23 -0700
Subject: [PATCH] Fix sage attention for qwen-image, fix Qwen Image VAE memory
 usage, improve compatibility with hooks when using KSampler based workflows
 and non-ModelPatcher ModelManageable implementations

---
 comfy/ldm/modules/attention.py                |  4 +-
 .../ldm/modules/sage_attention_dispatcher.py  | 93 +++++++++++++++++++
 comfy/ldm/qwen_image/model.py                 |  2 +-
 comfy/model_management_types.py               | 75 +++++++++++++--
 comfy/nodes/base_nodes.py                     |  9 +-
 comfy/sd.py                                   | 25 ++++-
 6 files changed, 187 insertions(+), 21 deletions(-)
 create mode 100644 comfy/ldm/modules/sage_attention_dispatcher.py

diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py
index 9e2c03c65..d927a735f 100644
--- a/comfy/ldm/modules/attention.py
+++ b/comfy/ldm/modules/attention.py
@@ -22,11 +22,9 @@ if model_management.xformers_enabled():
 sageattn = None
 if model_management.sage_attention_enabled():
     try:
-        from sageattention import sageattn  # pylint: disable=import-error
+        from .sage_attention_dispatcher import sageattn
     except ModuleNotFoundError as e:
         if e.name == "sageattention":
-            import sys
-
             logger.error(f"To use the `--use-sage-attention` feature, the `sageattention` package must be installed first.")
         else:
             raise e
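A minimal sketch of exercising the dispatcher introduced below. It assumes a CUDA build of torch with the `sageattention` package installed; the tensor shapes and the fp16 dtype are illustrative choices, not values this patch prescribes:

    import torch
    from comfy.ldm.modules.sage_attention_dispatcher import get_cuda_arch_versions, sageattn

    # "HND" layout: [batch_size, num_qo_heads, qo_len, head_dim]
    q = torch.randn(1, 8, 128, 64, dtype=torch.float16, device="cuda")
    k = torch.randn(1, 8, 128, 64, dtype=torch.float16, device="cuda")
    v = torch.randn(1, 8, 128, 64, dtype=torch.float16, device="cuda")

    print(get_cuda_arch_versions())  # e.g. ["sm89"] on an Ada (RTX 40xx) card
    out = sageattn(q, k, v, tensor_layout="HND", is_causal=False)
    print(out.shape)  # torch.Size([1, 8, 128, 64])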
diff --git a/comfy/ldm/modules/sage_attention_dispatcher.py b/comfy/ldm/modules/sage_attention_dispatcher.py
new file mode 100644
index 000000000..88f71fb20
--- /dev/null
+++ b/comfy/ldm/modules/sage_attention_dispatcher.py
@@ -0,0 +1,93 @@
+from typing import Optional, Any
+
+import torch
+# only imported when sage attention is enabled
+from sageattention import *  # pylint: disable=import-error
+
+
+def get_cuda_arch_versions():
+    cuda_archs = []
+    for i in range(torch.cuda.device_count()):
+        major, minor = torch.cuda.get_device_capability(i)
+        cuda_archs.append(f"sm{major}{minor}")
+    return cuda_archs
+
+
+def sageattn(
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        tensor_layout: str = "HND",
+        is_causal: bool = False,
+        sm_scale: Optional[float] = None,
+        return_lse: bool = False,
+        **kwargs: Any,
+):
+    """
+    Automatically selects the appropriate implementation of the SageAttention kernel based on the GPU compute capability.
+
+    Parameters
+    ----------
+    q : torch.Tensor
+        The query tensor. Shape:
+        - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``.
+        - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``.
+
+    k : torch.Tensor
+        The key tensor. Shape:
+        - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``.
+        - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``.
+
+    v : torch.Tensor
+        The value tensor. Shape:
+        - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``.
+        - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``.
+
+    tensor_layout : str
+        The tensor layout, either "HND" or "NHD".
+        Default: "HND".
+
+    is_causal : bool
+        Whether to apply a causal mask to the attention matrix. Only applicable when qo_len == kv_len.
+        Default: False.
+
+    sm_scale : Optional[float]
+        The scale used in softmax; if not provided, it will be set to ``1.0 / sqrt(head_dim)``.
+
+    return_lse : bool
+        Whether to return the log-sum-exp of the attention weights. Used for cases like Ring Attention.
+        Default: False.
+
+    Returns
+    -------
+    torch.Tensor
+        The output tensor. Shape:
+        - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``.
+        - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``.
+
+    torch.Tensor
+        The logsumexp of each row of the matrix QK^T * scaling (i.e., the log of the softmax normalization factor).
+        Shape: ``[batch_size, num_qo_heads, qo_len]``.
+        Only returned if `return_lse` is True.
+
+    Note
+    ----
+    - ``num_qo_heads`` must be divisible by ``num_kv_heads``.
+    - The tensors `q`, `k`, and `v` must have the dtype ``torch.float16`` or ``torch.bfloat16``.
+    - All tensors must be on the same CUDA device.
+    """
+
+    arch = get_cuda_arch_versions()[q.device.index]
+    if arch in ("sm80", "sm86"):
+        return sageattn_qk_int8_pv_fp16_cuda(q, k, v, tensor_layout=tensor_layout, is_causal=is_causal, sm_scale=sm_scale, return_lse=return_lse, pv_accum_dtype="fp32")
+    # todo: the triton kernel is broken on Ampere, so it is disabled and sm86 takes the CUDA path above
+    # elif arch == "sm86":
+    #     return sageattn_qk_int8_pv_fp16_triton(q, k, v, tensor_layout=tensor_layout, is_causal=is_causal, sm_scale=sm_scale, return_lse=return_lse)
+    elif arch == "sm89":
+        return sageattn_qk_int8_pv_fp8_cuda(q, k, v, tensor_layout=tensor_layout, is_causal=is_causal, sm_scale=sm_scale, return_lse=return_lse, pv_accum_dtype="fp32+fp16")
+    elif arch == "sm90":
+        return sageattn_qk_int8_pv_fp8_cuda_sm90(q, k, v, tensor_layout=tensor_layout, is_causal=is_causal, sm_scale=sm_scale, return_lse=return_lse, pv_accum_dtype="fp32+fp32")
+    elif arch == "sm120":
+        return sageattn_qk_int8_pv_fp8_cuda(q, k, v, tensor_layout=tensor_layout, is_causal=is_causal, qk_quant_gran="per_warp", sm_scale=sm_scale, return_lse=return_lse, pv_accum_dtype="fp32+fp16")  # sm120 has an accurate fp32 accumulator for fp8 mma, and the triton kernel is currently not usable on sm120
+    else:
+        raise ValueError(f"Unsupported CUDA architecture: {arch}")
diff --git a/comfy/ldm/qwen_image/model.py b/comfy/ldm/qwen_image/model.py
index 927469aa3..5b44a3835 100644
--- a/comfy/ldm/qwen_image/model.py
+++ b/comfy/ldm/qwen_image/model.py
@@ -8,7 +8,7 @@ from typing import Optional, Tuple
 from ..common_dit import pad_to_patch_size
 from ..flux.layers import EmbedND
 from ..lightricks.model import TimestepEmbedding, Timesteps
-from ..modules.attention import optimized_attention_no_sage_masked as optimized_attention_masked
+from ..modules.attention import optimized_attention_masked
 from ...patcher_extension import WrapperExecutor, get_all_wrappers, WrappersMP

 class GELU(nn.Module):
diff --git a/comfy/model_management_types.py b/comfy/model_management_types.py
index b5e73d811..57b66517f 100644
--- a/comfy/model_management_types.py
+++ b/comfy/model_management_types.py
@@ -30,6 +30,68 @@ class HooksSupport(Protocol, metaclass=ABCMeta):
     def prepare_hook_patches_current_keyframe(self, t, hook_group, model_options):
         return

+    def model_patches_models(self) -> list[ModelManageableT]:
+        """
+        Returns additional models patched in alongside this one. Used to implement Qwen DiffSynth ControlNets.
+
+        :return: the additional models, or an empty list
+        """
+        return []
+
+    @property
+    def hook_mode(self):
+        from .hooks import EnumHookMode
+        if not hasattr(self, "_hook_mode"):
+            setattr(self, "_hook_mode", EnumHookMode.MaxSpeed)
+        return getattr(self, "_hook_mode")
+
+    @hook_mode.setter
+    def hook_mode(self, value):
+        setattr(self, "_hook_mode", value)
+
+    def restore_hook_patches(self):
+        return
+
+    @property
+    def wrappers(self):
+        if not hasattr(self, "_wrappers"):
+            setattr(self, "_wrappers", {})
+        return getattr(self, "_wrappers")
+
+    @wrappers.setter
+    def wrappers(self, value):
+        setattr(self, "_wrappers", value)
+
+    @property
+    def callbacks(self) -> dict:
+        if not hasattr(self, "_callbacks"):
+            setattr(self, "_callbacks", {})
+        return getattr(self, "_callbacks")
+
+    @callbacks.setter
+    def callbacks(self, value):
+        setattr(self, "_callbacks", value)
+
+    def cleanup(self):
+        pass
+
+    def pre_run(self):
+        from .model_base import BaseModel
+        if hasattr(self, "model") and isinstance(self.model, BaseModel):
+            self.model.current_patcher = self
+
+    def prepare_state(self, *args, **kwargs):
+        pass
+
+    def register_all_hook_patches(self, a, b, c, d):
+        pass
+
+    def get_nested_additional_models(self):
+        return []
+
+    def apply_hooks(self, *args, **kwargs):
+        return {}
+

 class TrainingSupport(Protocol, metaclass=ABCMeta):
     def set_model_compute_dtype(self, dtype: torch.dtype):
@@ -146,14 +208,14 @@ class ModelManageable(ModelManageableRequired, ModelManageableExtras, Protocol,
         self.unpatch_model(device_to)
         return self.model_size()

-    def memory_required(self, input_shape) -> int:
+    def memory_required(self, input_shape: torch.Size) -> int:
         from .model_base import BaseModel

         if isinstance(self.model, BaseModel):
             return self.model.memory_required(input_shape=input_shape)
         else:
-            # todo: why isn't this true?
-            return self.model_size()
+            # todo: we need a real implementation of this
+            return 0

     def loaded_size(self) -> int:
         if self.current_loaded_device() == self.load_device:
@@ -198,13 +260,6 @@ class ModelManageable(ModelManageableRequired, ModelManageableExtras, Protocol,
             self.unpatch_model(self.offload_device, unpatch_weights=unpatch_all)
         return self.model

-    def model_patches_models(self) -> list[ModelManageableT]:
-        """
-        Used to implement Qwen DiffSynth Controlnets (?)
-
-        :return:
-        """
-        return []

 @dataclasses.dataclass
 class MemoryMeasurements:
diff --git a/comfy/nodes/base_nodes.py b/comfy/nodes/base_nodes.py
index 2dec6b669..518911791 100644
--- a/comfy/nodes/base_nodes.py
+++ b/comfy/nodes/base_nodes.py
@@ -812,10 +812,11 @@ class VAELoader:
     def load_vae(self, vae_name):
         if vae_name in ["taesd", "taesdxl", "taesd3", "taef1"]:
             sd_ = self.load_taesd(vae_name)
+            metadata = {}
         else:
             vae_path = get_or_download("vae", vae_name, KNOWN_VAES)
-            sd_ = utils.load_torch_file(vae_path)
-        vae = sd.VAE(sd=sd_)
+            sd_, metadata = utils.load_torch_file(vae_path, return_metadata=True)
+        vae = sd.VAE(sd=sd_, metadata=metadata, ckpt_name=vae_name)
         vae.throw_exception_if_invalid()
         return (vae,)

@@ -1321,12 +1322,12 @@ class RepeatLatentBatch:
         s = samples.copy()
         s_in = samples["samples"]

-        s["samples"] = s_in.repeat((amount,) + ((1, ) * (s_in.ndim - 1)))
+        s["samples"] = s_in.repeat((amount,) + ((1,) * (s_in.ndim - 1)))
         if "noise_mask" in samples and samples["noise_mask"].shape[0] > 1:
             masks = samples["noise_mask"]
             if masks.shape[0] < s_in.shape[0]:
                 masks = masks.repeat((math.ceil(s_in.shape[0] / masks.shape[0]),) + ((1,) * (masks.ndim - 1)))[:s_in.shape[0]]
-            s["noise_mask"] = samples["noise_mask"].repeat((amount,) + ((1, ) * (samples["noise_mask"].ndim - 1)))
+            s["noise_mask"] = samples["noise_mask"].repeat((amount,) + ((1,) * (samples["noise_mask"].ndim - 1)))
         if "batch_index" in s:
             offset = max(s["batch_index"]) - min(s["batch_index"]) + 1
             s["batch_index"] = s["batch_index"] + [x + (i * offset) for i in range(1, amount) for x in s["batch_index"]]
diff --git a/comfy/sd.py b/comfy/sd.py
index 9fae588f7..50decf18d 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -11,6 +11,8 @@ import yaml
 from enum import Enum
 from typing import Any, Optional

+from humanize import naturalsize
+
 from . import clip_vision
 from . import diffusers_convert
 from . import gligen
@@ -281,7 +283,8 @@ class CLIP:


 class VAE:
-    def __init__(self, sd=None, device=None, config=None, dtype=None, metadata=None, no_init=False):
+    def __init__(self, sd=None, device=None, config=None, dtype=None, metadata=None, no_init=False, ckpt_name: Optional[str] = ""):
+        self.ckpt_name = ckpt_name
         if no_init:
             return
         if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): # diffusers format
@@ -459,8 +462,16 @@ class VAE:
             ddconfig = {"dim": 96, "z_dim": self.latent_channels, "dim_mult": [1, 2, 4, 4], "num_res_blocks": 2, "attn_scales": [], "temperal_downsample": [False, True, True], "dropout": 0.0}
             self.first_stage_model = wan_vae.WanVAE(**ddconfig)
             self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32]
-            self.memory_used_encode = lambda shape, dtype: 6000 * shape[3] * shape[4] * model_management.dtype_size(dtype)
-            self.memory_used_decode = lambda shape, dtype: 7000 * shape[3] * shape[4] * (8 * 8) * model_management.dtype_size(dtype)
+
+            # todo: detect the Qwen Image VAE from the state dict instead of relying on the checkpoint name
+            wan_21_decode = 7000
+            wan_21_encode = wan_21_decode - 1000
+            qwen_vae_decode = int(wan_21_decode / 3)
+            qwen_vae_encode = int(wan_21_encode / 3)
+            encode_const = qwen_vae_encode if "qwen" in self.ckpt_name.lower() else wan_21_encode
+            decode_const = qwen_vae_decode if "qwen" in self.ckpt_name.lower() else wan_21_decode
+            self.memory_used_encode = lambda shape, dtype: encode_const * shape[3] * shape[4] * model_management.dtype_size(dtype)
+            self.memory_used_decode = lambda shape, dtype: decode_const * shape[3] * shape[4] * (8 * 8) * model_management.dtype_size(dtype)
         elif "geo_decoder.cross_attn_decoder.ln_1.bias" in sd:
             self.latent_dim = 1
             ln_post = "geo_decoder.ln_post.weight" in sd
@@ -777,6 +788,14 @@ class VAE:
         except:
             return None

+    def __str__(self):
+        info_str = f"dtype={self.vae_dtype} device={self.device}"
+
+        if self.ckpt_name == "":
+            return f"<VAE {info_str}>"
+        else:
+            return f"<VAE ckpt_name={self.ckpt_name} {info_str}>"
+

 class StyleModel:
     def __init__(self, model, device="cpu"):
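A back-of-envelope check of the new decode heuristic. The 1328x1328 Qwen-Image resolution, its 166x166 latent, and the fp16 dtype are assumed example values, not something this patch pins down:

    wan_21_decode = 7000
    qwen_vae_decode = int(wan_21_decode / 3)        # 2333
    shape = (1, 16, 1, 166, 166)                    # hypothetical [B, C, T, H, W] latent for 1328x1328
    dtype_size = 2                                  # bytes per element for fp16
    decode_bytes = qwen_vae_decode * shape[3] * shape[4] * (8 * 8) * dtype_size
    print(f"{decode_bytes / 2**30:.1f} GiB")        # ~7.7 GiB, vs ~23.0 GiB under the Wan 2.1 constant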
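A minimal sketch of the new VAE loading path. The file name `qwen_image_vae.safetensors` is hypothetical; it only illustrates how `ckpt_name` flows from `VAELoader.load_vae` into the Qwen-vs-Wan memory heuristic and the new `__str__`:

    from comfy import sd, utils

    vae_path = "models/vae/qwen_image_vae.safetensors"  # hypothetical path
    sd_, metadata = utils.load_torch_file(vae_path, return_metadata=True)
    vae = sd.VAE(sd=sd_, metadata=metadata, ckpt_name="qwen_image_vae.safetensors")
    vae.throw_exception_if_invalid()
    print(vae)  # <VAE ckpt_name=qwen_image_vae.safetensors dtype=... device=...>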
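A hypothetical sketch of what the `HooksSupport` defaults buy a non-ModelPatcher manageable in KSampler-based workflows. `MinimalManageable` is illustrative and assumes the members shown in this patch are the only ones such a class must supply:

    from comfy.model_management_types import HooksSupport

    class MinimalManageable(HooksSupport):
        def __init__(self, model=None):
            self.model = model

    m = MinimalManageable()
    print(m.hook_mode)                         # EnumHookMode.MaxSpeed, lazily initialized
    m.wrappers.setdefault("outer_sample", [])  # backing dict is created on first access
    print(m.get_nested_additional_models())    # []
    print(m.apply_hooks())                     # {} for models that carry no hook patches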