Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2026-01-10 06:10:50 +08:00)
Fix sage attention for Qwen Image, fix Qwen Image VAE memory usage, and improve hook compatibility when using KSampler-based workflows and non-ModelPatcher ModelManageable implementations
parent: b62d4f05e1
commit: f5e29f0e61
```diff
@@ -22,11 +22,9 @@ if model_management.xformers_enabled():
 sageattn = None
 if model_management.sage_attention_enabled():
     try:
-        from sageattention import sageattn  # pylint: disable=import-error
+        from .sage_attention_dispatcher import sageattn
     except ModuleNotFoundError as e:
         if e.name == "sageattention":
             import sys

             logger.error(f"To use the `--use-sage-attention` feature, the `sageattention` package must be installed first.")
         else:
             raise e
```
comfy/ldm/modules/sage_attention_dispatcher.py (new file, 93 lines)
@@ -0,0 +1,93 @@
```python
from typing import Optional, Any

import torch
# only imported when sage attention is enabled
from sageattention import *  # pylint: disable=import-error


def get_cuda_arch_versions():
    cuda_archs = []
    for i in range(torch.cuda.device_count()):
        major, minor = torch.cuda.get_device_capability(i)
        cuda_archs.append(f"sm{major}{minor}")
    return cuda_archs


def sageattn(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    tensor_layout: str = "HND",
    is_causal: bool = False,
    sm_scale: Optional[float] = None,
    return_lse: bool = False,
    **kwargs: Any,
):
    """
    Automatically selects the appropriate implementation of the SageAttention kernel based on the GPU compute capability.

    Parameters
    ----------
    q : torch.Tensor
        The query tensor. Shape:
        - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``.
        - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``.

    k : torch.Tensor
        The key tensor. Shape:
        - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``.
        - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``.

    v : torch.Tensor
        The value tensor. Shape:
        - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``.
        - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``.

    tensor_layout : str
        The tensor layout, either "HND" or "NHD".
        Default: "HND".

    is_causal : bool
        Whether to apply a causal mask to the attention matrix. Only applicable when qo_len == kv_len.
        Default: False.

    sm_scale : Optional[float]
        The scale used in softmax; if not provided, it will be set to ``1.0 / sqrt(head_dim)``.

    return_lse : bool
        Whether to return the log sum of the exponentiated attention weights. Used for cases like Ring Attention.
        Default: False.

    Returns
    -------
    torch.Tensor
        The output tensor. Shape:
        - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``.
        - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``.

    torch.Tensor
        The logsumexp of each row of the matrix QK^T * scaling (e.g., the log of the softmax normalization factor).
        Shape: ``[batch_size, num_qo_heads, qo_len]``.
        Only returned if `return_lse` is True.

    Note
    ----
    - ``num_qo_heads`` must be divisible by ``num_kv_heads``.
    - The tensors `q`, `k`, and `v` must have the dtype ``torch.float16`` or ``torch.bfloat16``.
    - All tensors must be on the same CUDA device.
    """
    arch = get_cuda_arch_versions()[q.device.index]
    if arch in ("sm80", "sm86"):
        return sageattn_qk_int8_pv_fp16_cuda(q, k, v, tensor_layout=tensor_layout, is_causal=is_causal, sm_scale=sm_scale, return_lse=return_lse, pv_accum_dtype="fp32")
    # todo: the triton kernel is broken on ampere, so disable it
    # elif arch == "sm86":
    #     return sageattn_qk_int8_pv_fp16_triton(q, k, v, tensor_layout=tensor_layout, is_causal=is_causal, sm_scale=sm_scale, return_lse=return_lse)
    elif arch == "sm89":
        return sageattn_qk_int8_pv_fp8_cuda(q, k, v, tensor_layout=tensor_layout, is_causal=is_causal, sm_scale=sm_scale, return_lse=return_lse, pv_accum_dtype="fp32+fp16")
    elif arch == "sm90":
        return sageattn_qk_int8_pv_fp8_cuda_sm90(q, k, v, tensor_layout=tensor_layout, is_causal=is_causal, sm_scale=sm_scale, return_lse=return_lse, pv_accum_dtype="fp32+fp32")
    elif arch == "sm120":
        return sageattn_qk_int8_pv_fp8_cuda(q, k, v, tensor_layout=tensor_layout, is_causal=is_causal, qk_quant_gran="per_warp", sm_scale=sm_scale, return_lse=return_lse, pv_accum_dtype="fp32+fp16")  # sm120 has an accurate fp32 accumulator for fp8 mma, and the triton kernel is currently not usable on sm120.
    else:
        raise ValueError(f"Unsupported CUDA architecture: {arch}")
```
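A quick way to sanity-check the dispatcher in isolation; the shapes and sizes below are illustrative, and this assumes a CUDA GPU with one of the supported architectures and the `sageattention` package installed:

```python
import torch

# assumes the module path from the new file above and a supported CUDA device
from comfy.ldm.modules.sage_attention_dispatcher import sageattn

batch, heads, seq_len, head_dim = 1, 8, 1024, 64  # illustrative sizes
q = torch.randn(batch, heads, seq_len, head_dim, dtype=torch.float16, device="cuda")
k = torch.randn_like(q)
v = torch.randn_like(q)

# "HND" layout means [batch, heads, seq, head_dim]; fp16/bf16 inputs are required
out = sageattn(q, k, v, tensor_layout="HND", is_causal=False)
print(out.shape)  # torch.Size([1, 8, 1024, 64])
```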
```diff
@@ -8,7 +8,7 @@ from typing import Optional, Tuple
 from ..common_dit import pad_to_patch_size
 from ..flux.layers import EmbedND
 from ..lightricks.model import TimestepEmbedding, Timesteps
-from ..modules.attention import optimized_attention_no_sage_masked as optimized_attention_masked
+from ..modules.attention import optimized_attention_masked
 from ...patcher_extension import WrapperExecutor, get_all_wrappers, WrappersMP


 class GELU(nn.Module):
```
```diff
@@ -30,6 +30,68 @@ class HooksSupport(Protocol, metaclass=ABCMeta):
     def prepare_hook_patches_current_keyframe(self, t, hook_group, model_options):
         return

+    def model_patches_models(self) -> list[ModelManageableT]:
+        """
+        Used to implement Qwen DiffSynth Controlnets (?)
+        :return:
+        """
+        return []
+
+    @property
+    def hook_mode(self):
+        from .hooks import EnumHookMode
+        if not hasattr(self, "_hook_mode"):
+            setattr(self, "_hook_mode", EnumHookMode.MaxSpeed)
+        return getattr(self, "_hook_mode")
+
+
+    @hook_mode.setter
+    def hook_mode(self, value):
+        setattr(self, "_hook_mode", value)
+
+    def restore_hook_patches(self):
+        return
+
+    @property
+    def wrappers(self):
+        if not hasattr(self, "_wrappers"):
+            setattr(self, "_wrappers", {})
+        return getattr(self, "_wrappers")
+
+    @wrappers.setter
+    def wrappers(self, value):
+        setattr(self, "_wrappers", value)
+
+    @property
+    def callbacks(self) -> dict:
+        if not hasattr(self, "_callbacks"):
+            setattr(self, "_callbacks", {})
+        return getattr(self, "_callbacks")
+
+    @callbacks.setter
+    def callbacks(self, value):
+        setattr(self, "_callbacks", value)
+
+    def cleanup(self):
+        pass
+
+    def pre_run(self):
+        from .model_base import BaseModel
+        if hasattr(self, "model") and isinstance(self.model, BaseModel):
+            self.model.current_patcher = self
+
+    def prepare_state(self, *args, **kwargs):
+        pass
+
+    def register_all_hook_patches(self, a, b, c, d):
+        pass
+
+    def get_nested_additional_models(self):
+        return []
+
+    def apply_hooks(self, *args, **kwargs):
+        return {}
+

 class TrainingSupport(Protocol, metaclass=ABCMeta):
     def set_model_compute_dtype(self, dtype: torch.dtype):
```
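All of the new defaults rely on the same lazy-attribute idiom: a property creates its backing attribute on first access, so objects that never configure hooks still satisfy the interface without extra `__init__` work. A minimal self-contained sketch of the idiom follows (the class and attribute names here are illustrative, not part of the ComfyUI API):

```python
from enum import Enum, auto


class HookMode(Enum):  # stand-in for comfy's EnumHookMode
    MinVram = auto()
    MaxSpeed = auto()


class LazyHookDefaults:
    """Illustrative mixin: hook state materializes lazily on first access."""

    @property
    def hook_mode(self) -> HookMode:
        if not hasattr(self, "_hook_mode"):
            self._hook_mode = HookMode.MaxSpeed  # same default as the diff above
        return self._hook_mode

    @hook_mode.setter
    def hook_mode(self, value: HookMode):
        self._hook_mode = value

    @property
    def wrappers(self) -> dict:
        if not hasattr(self, "_wrappers"):
            self._wrappers = {}
        return self._wrappers


class NotAModelPatcher(LazyHookDefaults):
    """A hypothetical non-ModelPatcher object that still works with hook-aware samplers."""


obj = NotAModelPatcher()
print(obj.hook_mode)           # HookMode.MaxSpeed, created on first access
obj.wrappers["sampling"] = []  # backing dict created lazily as well
```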
```diff
@@ -146,14 +208,14 @@ class ModelManageable(ModelManageableRequired, ModelManageableExtras, Protocol,
         self.unpatch_model(device_to)
         return self.model_size()

-    def memory_required(self, input_shape) -> int:
+    def memory_required(self, input_shape: torch.Size) -> int:
         from .model_base import BaseModel

         if isinstance(self.model, BaseModel):
             return self.model.memory_required(input_shape=input_shape)
         else:
-            # todo: why isn't this true?
-            return self.model_size()
+            # todo: we need a real implementation of this
+            return 0

     def loaded_size(self) -> int:
         if self.current_loaded_device() == self.load_device:
```
```diff
@@ -198,13 +260,6 @@ class ModelManageable(ModelManageableRequired, ModelManageableExtras, Protocol,
         self.unpatch_model(self.offload_device, unpatch_weights=unpatch_all)
         return self.model

-    def model_patches_models(self) -> list[ModelManageableT]:
-        """
-        Used to implement Qwen DiffSynth Controlnets (?)
-        :return:
-        """
-        return []
-

 @dataclasses.dataclass
 class MemoryMeasurements:
```
```diff
@@ -812,10 +812,11 @@ class VAELoader:
     def load_vae(self, vae_name):
         if vae_name in ["taesd", "taesdxl", "taesd3", "taef1"]:
             sd_ = self.load_taesd(vae_name)
+            metadata = {}
         else:
             vae_path = get_or_download("vae", vae_name, KNOWN_VAES)
-            sd_ = utils.load_torch_file(vae_path)
-        vae = sd.VAE(sd=sd_)
+            sd_, metadata = utils.load_torch_file(vae_path, return_metadata=True)
+        vae = sd.VAE(sd=sd_, metadata=metadata, ckpt_name=vae_name)
         vae.throw_exception_if_invalid()
         return (vae,)
```
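`load_torch_file(..., return_metadata=True)` hands back the state dict together with the checkpoint's metadata, which the VAE constructor now receives along with the checkpoint name. For a `.safetensors` file that metadata presumably corresponds to the `__metadata__` block of the file header, which can also be inspected directly with the `safetensors` library (the path below is hypothetical):

```python
from safetensors import safe_open

# peek at the metadata block of a VAE checkpoint (hypothetical path)
with safe_open("models/vae/qwen_image_vae.safetensors", framework="pt", device="cpu") as f:
    metadata = f.metadata()  # dict[str, str], or None if the file carries no metadata
    print(metadata)
```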
```diff
@@ -1321,12 +1322,12 @@ class RepeatLatentBatch:
         s = samples.copy()
         s_in = samples["samples"]

-        s["samples"] = s_in.repeat((amount,) + ((1, ) * (s_in.ndim - 1)))
+        s["samples"] = s_in.repeat((amount,) + ((1,) * (s_in.ndim - 1)))
         if "noise_mask" in samples and samples["noise_mask"].shape[0] > 1:
             masks = samples["noise_mask"]
             if masks.shape[0] < s_in.shape[0]:
                 masks = masks.repeat((math.ceil(s_in.shape[0] / masks.shape[0]),) + ((1,) * (masks.ndim - 1)))[:s_in.shape[0]]
-            s["noise_mask"] = samples["noise_mask"].repeat((amount,) + ((1, ) * (samples["noise_mask"].ndim - 1)))
+            s["noise_mask"] = samples["noise_mask"].repeat((amount,) + ((1,) * (samples["noise_mask"].ndim - 1)))
         if "batch_index" in s:
             offset = max(s["batch_index"]) - min(s["batch_index"]) + 1
             s["batch_index"] = s["batch_index"] + [x + (i * offset) for i in range(1, amount) for x in s["batch_index"]]
```
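The change in this hunk is whitespace only, but the repeat expression is worth spelling out: `(amount,) + (1,) * (ndim - 1)` repeats a tensor along the batch dimension only, regardless of its rank. A standalone illustration:

```python
import torch

amount = 3
latent = torch.randn(2, 4, 64, 64)  # [batch, channels, height, width], illustrative sizes

# repeat only dim 0: the tuple is (3, 1, 1, 1) for a 4-D latent, (3, 1, 1, 1, 1) for a 5-D video latent
repeated = latent.repeat((amount,) + (1,) * (latent.ndim - 1))
print(repeated.shape)  # torch.Size([6, 4, 64, 64])
```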
comfy/sd.py (25 changed lines)
```diff
@@ -11,6 +11,8 @@ import yaml
 from enum import Enum
 from typing import Any, Optional

+from humanize import naturalsize
+
 from . import clip_vision
 from . import diffusers_convert
 from . import gligen
```
```diff
@@ -281,7 +283,8 @@ class CLIP:


 class VAE:
-    def __init__(self, sd=None, device=None, config=None, dtype=None, metadata=None, no_init=False):
+    def __init__(self, sd=None, device=None, config=None, dtype=None, metadata=None, no_init=False, ckpt_name: Optional[str] = ""):
+        self.ckpt_name = ckpt_name
         if no_init:
             return
         if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys():  # diffusers format
```
```diff
@@ -459,8 +462,16 @@
             ddconfig = {"dim": 96, "z_dim": self.latent_channels, "dim_mult": [1, 2, 4, 4], "num_res_blocks": 2, "attn_scales": [], "temperal_downsample": [False, True, True], "dropout": 0.0}
             self.first_stage_model = wan_vae.WanVAE(**ddconfig)
             self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32]
-            self.memory_used_encode = lambda shape, dtype: 6000 * shape[3] * shape[4] * model_management.dtype_size(dtype)
-            self.memory_used_decode = lambda shape, dtype: 7000 * shape[3] * shape[4] * (8 * 8) * model_management.dtype_size(dtype)
+
+            # todo: not sure how to detect qwen here
+            wan_21_decode = 7000
+            wan_21_encode = wan_21_decode - 1000
+            qwen_vae_decode = int(wan_21_decode / 3)
+            qwen_vae_encode = int(wan_21_encode / 3)
+            encode_const = qwen_vae_encode if "qwen" in self.ckpt_name.lower() else wan_21_encode
+            decode_const = qwen_vae_decode if "qwen" in self.ckpt_name.lower() else wan_21_decode
+            self.memory_used_encode = lambda shape, dtype: encode_const * shape[3] * shape[4] * model_management.dtype_size(dtype)
+            self.memory_used_decode = lambda shape, dtype: decode_const * shape[3] * shape[4] * (8 * 8) * model_management.dtype_size(dtype)
         elif "geo_decoder.cross_attn_decoder.ln_1.bias" in sd:
             self.latent_dim = 1
             ln_post = "geo_decoder.ln_post.weight" in sd
```
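With these constants, the decode estimate for a Qwen Image checkpoint is a third of the Wan 2.1 estimate. A rough worked example of the arithmetic, using an illustrative latent spatial size for `shape[3]` and `shape[4]` (the concrete shape and dtype are assumptions, not values from the commit):

```python
# illustrative arithmetic only; the constants are copied from the hunk above
dtype_size = 2                  # bytes per element for fp16/bf16
latent_h, latent_w = 128, 128   # hypothetical shape[3], shape[4]

wan_21_decode = 7000
qwen_vae_decode = int(wan_21_decode / 3)  # 2333

wan_bytes = wan_21_decode * latent_h * latent_w * (8 * 8) * dtype_size
qwen_bytes = qwen_vae_decode * latent_h * latent_w * (8 * 8) * dtype_size

print(f"wan 2.1 decode estimate:    {wan_bytes / 1e9:.1f} GB")   # ~14.7 GB
print(f"qwen image decode estimate: {qwen_bytes / 1e9:.1f} GB")  # ~4.9 GB
```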
```diff
@@ -777,6 +788,14 @@ class VAE:
         except:
             return None

+    def __str__(self):
+        info_str = f"dtype={self.vae_dtype} device={self.device}"
+
+        if self.ckpt_name == "":
+            return f"<VAE for {self.first_stage_model.__class__.__name__} {info_str}>"
+        else:
+            return f"<VAE for {self.ckpt_name} ({self.first_stage_model.__class__.__name__} {info_str})>"
+

 class StyleModel:
     def __init__(self, model, device="cpu"):
```
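Threading `ckpt_name` through from the loader means a logged VAE now identifies the file it was loaded from. The snippet below only reproduces the repr format; the checkpoint name, dtype, and device are made up, and per the earlier hunk a Qwen Image VAE takes the WanVAE code path:

```python
# reproduce the new __str__ format with hypothetical values
ckpt_name = "qwen_image_vae.safetensors"
model_class = "WanVAE"
info_str = "dtype=torch.bfloat16 device=cuda:0"
print(f"<VAE for {ckpt_name} ({model_class} {info_str})>")
# <VAE for qwen_image_vae.safetensors (WanVAE dtype=torch.bfloat16 device=cuda:0)>
```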