Fix sage attention for Qwen Image, fix Qwen Image VAE memory usage, and improve hook compatibility when using KSampler-based workflows with non-ModelPatcher ModelManageable models

This commit is contained in:
doctorpangloss 2025-09-09 12:58:23 -07:00
parent b62d4f05e1
commit f5e29f0e61
6 changed files with 187 additions and 21 deletions

View File

@@ -22,11 +22,9 @@ if model_management.xformers_enabled():
sageattn = None
if model_management.sage_attention_enabled():
try:
from sageattention import sageattn # pylint: disable=import-error
from .sage_attention_dispatcher import sageattn
except ModuleNotFoundError as e:
if e.name == "sageattention":
import sys
logger.error(f"To use the `--use-sage-attention` feature, the `sageattention` package must be installed first.")
else:
raise e

View File

@@ -0,0 +1,93 @@
from typing import Optional, Any
import torch
# only imported when sage attention is enabled
from sageattention import * # pylint: disable=import-error
def get_cuda_arch_versions():
cuda_archs = []
for i in range(torch.cuda.device_count()):
major, minor = torch.cuda.get_device_capability(i)
cuda_archs.append(f"sm{major}{minor}")
return cuda_archs
def sageattn(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
tensor_layout: str = "HND",
is_causal: bool = False,
sm_scale: Optional[float] = None,
return_lse: bool = False,
**kwargs: Any,
):
"""
Automatically selects the appropriate implementation of the SageAttention kernel based on the GPU compute capability.
Parameters
----------
q : torch.Tensor
The query tensor. Shape:
- If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``.
- If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``.
k : torch.Tensor
The key tensor. Shape:
- If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``.
- If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``.
v : torch.Tensor
The value tensor. Shape:
- If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``.
- If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``.
tensor_layout : str
The tensor layout, either "HND" or "NHD".
Default: "HND".
is_causal : bool
Whether to apply causal mask to the attention matrix. Only applicable when qo_len == kv_len.
Default: False.
sm_scale : Optional[float]
The scale used in softmax, if not provided, will be set to ``1.0 / sqrt(head_dim)``.
return_lse : bool
Whether to return the log sum of the exponentiated attention weights. Used for cases like Ring Attention.
Default: False.
Returns
-------
torch.Tensor
The output tensor. Shape:
- If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``.
- If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``.
torch.Tensor
The logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax normalization factor).
Shape: ``[batch_size, num_qo_heads, qo_len]``.
Only returned if `return_lse` is True.
Note
----
- ``num_qo_heads`` must be divisible by ``num_kv_heads``.
- The tensors `q`, `k`, and `v` must have the dtype ``torch.float16`` or ``torch.bfloat16``.
- All tensors must be on the same cuda device.
"""
arch = get_cuda_arch_versions()[q.device.index]
if arch in ("sm80", "sm86"):
return sageattn_qk_int8_pv_fp16_cuda(q, k, v, tensor_layout=tensor_layout, is_causal=is_causal, sm_scale=sm_scale, return_lse=return_lse, pv_accum_dtype="fp32")
# todo: the triton kernel is broken on ampere, so disable it
# elif arch == "sm86":
# return sageattn_qk_int8_pv_fp16_triton(q, k, v, tensor_layout=tensor_layout, is_causal=is_causal, sm_scale=sm_scale, return_lse=return_lse)
elif arch == "sm89":
return sageattn_qk_int8_pv_fp8_cuda(q, k, v, tensor_layout=tensor_layout, is_causal=is_causal, sm_scale=sm_scale, return_lse=return_lse, pv_accum_dtype="fp32+fp16")
elif arch == "sm90":
return sageattn_qk_int8_pv_fp8_cuda_sm90(q, k, v, tensor_layout=tensor_layout, is_causal=is_causal, sm_scale=sm_scale, return_lse=return_lse, pv_accum_dtype="fp32+fp32")
elif arch == "sm120":
return sageattn_qk_int8_pv_fp8_cuda(q, k, v, tensor_layout=tensor_layout, is_causal=is_causal, qk_quant_gran="per_warp", sm_scale=sm_scale, return_lse=return_lse, pv_accum_dtype="fp32+fp16") # sm120 has accurate fp32 accumulator for fp8 mma and triton kernel is currently not usable on sm120.
else:
raise ValueError(f"Unsupported CUDA architecture: {arch}")
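A minimal usage sketch for the dispatcher follows. It assumes the sageattention package is installed, a CUDA device is present, and that the module is importable as sage_attention_dispatcher (file paths are not shown in this diff); tensor sizes are illustrative only.

import torch
from sage_attention_dispatcher import sageattn  # hypothetical import path; file names are not shown in this diff

# Illustrative shapes in the default "HND" layout: [batch, heads, seq_len, head_dim]
q = torch.randn(1, 24, 4096, 128, dtype=torch.float16, device="cuda")
k = torch.randn(1, 24, 4096, 128, dtype=torch.float16, device="cuda")
v = torch.randn(1, 24, 4096, 128, dtype=torch.float16, device="cuda")

# The dispatcher reads the device's compute capability (sm80/sm86, sm89, sm90, sm120)
# and routes to the matching sageattn_qk_int8_* kernel with a suitable accumulator dtype.
out = sageattn(q, k, v, tensor_layout="HND", is_causal=False)
print(out.shape)  # torch.Size([1, 24, 4096, 128])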

View File

@@ -8,7 +8,7 @@ from typing import Optional, Tuple
from ..common_dit import pad_to_patch_size
from ..flux.layers import EmbedND
from ..lightricks.model import TimestepEmbedding, Timesteps
from ..modules.attention import optimized_attention_no_sage_masked as optimized_attention_masked
from ..modules.attention import optimized_attention_masked
from ...patcher_extension import WrapperExecutor, get_all_wrappers, WrappersMP
class GELU(nn.Module):

View File

@@ -30,6 +30,68 @@ class HooksSupport(Protocol, metaclass=ABCMeta):
def prepare_hook_patches_current_keyframe(self, t, hook_group, model_options):
return
def model_patches_models(self) -> list[ModelManageableT]:
"""
Used to implement Qwen DiffSynth Controlnets (?)
:return:
"""
return []
@property
def hook_mode(self):
from .hooks import EnumHookMode
if not hasattr(self, "_hook_mode"):
setattr(self, "_hook_mode", EnumHookMode.MaxSpeed)
return getattr(self, "_hook_mode")
@hook_mode.setter
def hook_mode(self, value):
setattr(self, "_hook_mode", value)
def restore_hook_patches(self):
return
@property
def wrappers(self):
if not hasattr(self, "_wrappers"):
setattr(self, "_wrappers", {})
return getattr(self, "_wrappers")
@wrappers.setter
def wrappers(self, value):
setattr(self, "_wrappers", value)
@property
def callbacks(self) -> dict:
if not hasattr(self, "_callbacks"):
setattr(self, "_callbacks", {})
return getattr(self, "_callbacks")
@callbacks.setter
def callbacks(self, value):
setattr(self, "_callbacks", value)
def cleanup(self):
pass
def pre_run(self):
from .model_base import BaseModel
if hasattr(self, "model") and isinstance(self.model, BaseModel):
self.model.current_patcher = self
def prepare_state(self, *args, **kwargs):
pass
def register_all_hook_patches(self, a, b, c, d):
pass
def get_nested_additional_models(self):
return []
def apply_hooks(self, *args, **kwargs):
return {}
class TrainingSupport(Protocol, metaclass=ABCMeta):
def set_model_compute_dtype(self, dtype: torch.dtype):
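For orientation, here is a minimal sketch of how a non-ModelPatcher object might lean on these defaults. The class name and model instance are hypothetical, and it assumes any protocol members not shown in this hunk are also defaulted or implemented by the wrapper.

# HooksSupport is assumed to be imported from its module (file path not shown in this diff).
class SimpleManagedModel(HooksSupport):
    def __init__(self, model):
        self.model = model  # pre_run() assigns model.current_patcher when model is a BaseModel

wrapper = SimpleManagedModel(my_diffusion_model)   # my_diffusion_model is assumed to exist
wrapper.pre_run()                                  # wires model.current_patcher if applicable
print(wrapper.hook_mode)                           # EnumHookMode.MaxSpeed, created lazily on first access
wrapper.register_all_hook_patches(None, None, None, None)  # no-op default
assert wrapper.apply_hooks() == {}                 # no hooks applied by default
wrapper.cleanup()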
@@ -146,14 +208,14 @@ class ModelManageable(ModelManageableRequired, ModelManageableExtras, Protocol,
self.unpatch_model(device_to)
return self.model_size()
def memory_required(self, input_shape) -> int:
def memory_required(self, input_shape: torch.Size) -> int:
from .model_base import BaseModel
if isinstance(self.model, BaseModel):
return self.model.memory_required(input_shape=input_shape)
else:
# todo: why isn't this true?
return self.model_size()
# todo: we need a real implementation of this
return 0
def loaded_size(self) -> int:
if self.current_loaded_device() == self.load_device:
@@ -198,13 +260,6 @@ class ModelManageable(ModelManageableRequired, ModelManageableExtras, Protocol,
self.unpatch_model(self.offload_device, unpatch_weights=unpatch_all)
return self.model
def model_patches_models(self) -> list[ModelManageableT]:
"""
Used to implement Qwen DiffSynth Controlnets (?)
:return:
"""
return []
@dataclasses.dataclass
class MemoryMeasurements:

View File

@@ -812,10 +812,11 @@ class VAELoader:
def load_vae(self, vae_name):
if vae_name in ["taesd", "taesdxl", "taesd3", "taef1"]:
sd_ = self.load_taesd(vae_name)
metadata = {}
else:
vae_path = get_or_download("vae", vae_name, KNOWN_VAES)
sd_ = utils.load_torch_file(vae_path)
vae = sd.VAE(sd=sd_)
sd_, metadata = utils.load_torch_file(vae_path, return_metadata=True)
vae = sd.VAE(sd=sd_, metadata=metadata, ckpt_name=vae_name)
vae.throw_exception_if_invalid()
return (vae,)
@@ -1321,12 +1322,12 @@ class RepeatLatentBatch:
s = samples.copy()
s_in = samples["samples"]
s["samples"] = s_in.repeat((amount,) + ((1, ) * (s_in.ndim - 1)))
s["samples"] = s_in.repeat((amount,) + ((1,) * (s_in.ndim - 1)))
if "noise_mask" in samples and samples["noise_mask"].shape[0] > 1:
masks = samples["noise_mask"]
if masks.shape[0] < s_in.shape[0]:
masks = masks.repeat((math.ceil(s_in.shape[0] / masks.shape[0]),) + ((1,) * (masks.ndim - 1)))[:s_in.shape[0]]
s["noise_mask"] = samples["noise_mask"].repeat((amount,) + ((1, ) * (samples["noise_mask"].ndim - 1)))
s["noise_mask"] = samples["noise_mask"].repeat((amount,) + ((1,) * (samples["noise_mask"].ndim - 1)))
if "batch_index" in s:
offset = max(s["batch_index"]) - min(s["batch_index"]) + 1
s["batch_index"] = s["batch_index"] + [x + (i * offset) for i in range(1, amount) for x in s["batch_index"]]

View File

@@ -11,6 +11,8 @@ import yaml
from enum import Enum
from typing import Any, Optional
from humanize import naturalsize
from . import clip_vision
from . import diffusers_convert
from . import gligen
@@ -281,7 +283,8 @@ class CLIP:
class VAE:
def __init__(self, sd=None, device=None, config=None, dtype=None, metadata=None, no_init=False):
def __init__(self, sd=None, device=None, config=None, dtype=None, metadata=None, no_init=False, ckpt_name: Optional[str] = ""):
self.ckpt_name = ckpt_name
if no_init:
return
if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): # diffusers format
@@ -459,8 +462,16 @@ class VAE:
ddconfig = {"dim": 96, "z_dim": self.latent_channels, "dim_mult": [1, 2, 4, 4], "num_res_blocks": 2, "attn_scales": [], "temperal_downsample": [False, True, True], "dropout": 0.0}
self.first_stage_model = wan_vae.WanVAE(**ddconfig)
self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32]
self.memory_used_encode = lambda shape, dtype: 6000 * shape[3] * shape[4] * model_management.dtype_size(dtype)
self.memory_used_decode = lambda shape, dtype: 7000 * shape[3] * shape[4] * (8 * 8) * model_management.dtype_size(dtype)
# todo: not sure how to detect qwen here
wan_21_decode = 7000
wan_21_encode = wan_21_decode - 1000
qwen_vae_decode = int(wan_21_decode / 3)
qwen_vae_encode = int(wan_21_encode / 3)
encode_const = qwen_vae_encode if "qwen" in self.ckpt_name.lower() else wan_21_encode
decode_const = qwen_vae_decode if "qwen" in self.ckpt_name.lower() else wan_21_decode
self.memory_used_encode = lambda shape, dtype: encode_const * shape[3] * shape[4] * model_management.dtype_size(dtype)
self.memory_used_decode = lambda shape, dtype: decode_const * shape[3] * shape[4] * (8 * 8) * model_management.dtype_size(dtype)
elif "geo_decoder.cross_attn_decoder.ln_1.bias" in sd:
self.latent_dim = 1
ln_post = "geo_decoder.ln_post.weight" in sd
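To see what the Wan/Qwen constants above amount to, here is a rough worked estimate. It assumes a 5-dimensional latent shape whose last two entries (shape[3], shape[4]) are the spatial latent dimensions; the numbers are illustrative only.

import torch

def decode_memory_bytes(decode_const, latent_shape, dtype=torch.bfloat16):
    # mirrors memory_used_decode: const * latent_h * latent_w * (8 * 8) * dtype_size
    dtype_size = torch.finfo(dtype).bits // 8
    return decode_const * latent_shape[3] * latent_shape[4] * 64 * dtype_size

latent = (1, 16, 1, 128, 128)                      # e.g. a 1024x1024 image -> 128x128 latent
wan_21 = decode_memory_bytes(7000, latent)         # ~14.7 GB budgeted with the Wan 2.1 constant
qwen = decode_memory_bytes(int(7000 / 3), latent)  # ~4.9 GB with the Qwen constant
print(f"{wan_21 / 1e9:.1f} GB vs {qwen / 1e9:.1f} GB")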
@@ -777,6 +788,14 @@ class VAE:
except:
return None
def __str__(self):
info_str = f"dtype={self.vae_dtype} device={self.device}"
if self.ckpt_name == "":
return f"<VAE for {self.first_stage_model.__class__.__name__} {info_str}>"
else:
return f"<VAE for {self.ckpt_name} ({self.first_stage_model.__class__.__name__} {info_str})>"
class StyleModel:
def __init__(self, model, device="cpu"):