rope, attetntion update | vae on cpu warning

This commit is contained in:
Yousef Rafat 2026-01-04 19:15:53 +02:00
parent 49febe15c3
commit 31d358c78c
4 changed files with 33 additions and 16 deletions

View File

@ -416,7 +416,8 @@ except:
pass
@wrap_attn
def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, var_length=False, **kwargs):
def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
var_length = kwargs.get("var_length", False)
b = q.shape[0]
dim_head = q.shape[-1]
# check to make sure xformers isn't broken
@ -506,7 +507,8 @@ else:
SDP_BATCH_LIMIT = 2**31
@wrap_attn
def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, var_length=False, **kwargs):
def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
var_length = kwargs.get("var_length", False)
if var_length:
cu_seqlens_q, cu_seqlens_k, _, _ = var_attn_arg(kwargs)
if not skip_reshape:
@ -570,7 +572,8 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
return out
@wrap_attn
def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, var_length = False, **kwargs):
def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
var_length = kwargs.get("var_length", False)
exception_fallback = False
if var_length:
if not skip_reshape:
@ -656,7 +659,8 @@ except AttributeError as error:
assert False, f"Could not define flash_attn_wrapper: {FLASH_ATTN_ERROR}"
@wrap_attn
def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, var_length=False, **kwargs):
def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
var_length = kwargs.get("var_length", False)
if var_length:
cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k = var_attn_arg(kwargs)
return flash_attn_varlen_func(

View File

@ -13,6 +13,7 @@ from comfy.rmsnorm import RMSNorm
from torch.nn.modules.utils import _triple
from torch import nn
import math
from comfy.ldm.flux.math import apply_rope1
class Cache:
def __init__(self, disable=False, prefix="", cache=None):
@ -443,7 +444,6 @@ def apply_rotary_emb(
freqs_seq_dim = None
):
dtype = t.dtype
if not exists(freqs_seq_dim):
if freqs.ndim == 2 or t.ndim == 3:
freqs_seq_dim = 0
@ -452,20 +452,23 @@ def apply_rotary_emb(
seq_len = t.shape[seq_dim]
freqs = slice_at_dim(freqs, slice(-seq_len, None), dim = freqs_seq_dim)
rot_dim = freqs.shape[-1]
end_index = start_index + rot_dim
assert rot_dim <= t.shape[-1], f'feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}'
rot_feats = freqs.shape[-1]
end_index = start_index + rot_feats
t_left = t[..., :start_index]
t_middle = t[..., start_index:end_index]
t_right = t[..., end_index:]
freqs = freqs.to(t_middle.device)
t_transformed = (t_middle * freqs.cos() * scale) + (rotate_half(t_middle) * freqs.sin() * scale)
angles = freqs.to(t_middle.device)[..., ::2]
cos = torch.cos(angles) * scale
sin = torch.sin(angles) * scale
out = torch.cat((t_left, t_transformed, t_right), dim=-1)
col0 = torch.stack([cos, sin], dim=-1)
col1 = torch.stack([-sin, cos], dim=-1)
freqs_mat = torch.stack([col0, col1], dim=-1)
t_middle_out = apply_rope1(t_middle, freqs_mat)
out = torch.cat((t_left, t_middle_out, t_right), dim=-1)
return out.type(dtype)
class NaMMRotaryEmbedding3d(MMRotaryEmbeddingBase):

View File

@ -16,6 +16,7 @@ import math
from enum import Enum
from comfy.ops import NVIDIA_MEMORY_CONV_BUG_WORKAROUND
import logging
import comfy.ops
ops = comfy.ops.disable_weight_init
@ -446,6 +447,7 @@ class InflatedCausalConv3d(ops.Conv3d):
self.memory_device = memory_device
self.padding = (0, *self.padding[1:])
self.memory_limit = float("inf")
self.logged_once = False
def set_memory_limit(self, value: float):
self.memory_limit = value
@ -469,8 +471,16 @@ class InflatedCausalConv3d(ops.Conv3d):
return out
except RuntimeError:
pass
return super()._conv_forward(input, weight, bias, *args, **kwargs)
except NotImplementedError:
pass
try:
return super()._conv_forward(input, weight, bias, *args, **kwargs)
except NotImplementedError:
# for: Could not run 'aten::cudnn_convolution' with arguments from the 'CPU' backend
if not self.logged_once:
logging.warning("VAE is on CPU for decoding. This is most likely due to being not enough memory")
self.logged_once = True
return F.conv3d(input, weight, bias, *args, **kwargs)
def memory_limit_conv(
self,

View File

@ -382,8 +382,8 @@ class VAE:
self.latent_channels = 16
elif "decoder.up_blocks.2.upsamplers.0.upscale_conv.weight" in sd: # seedvr2
self.first_stage_model = comfy.ldm.seedvr.vae.VideoAutoencoderKLWrapper()
self.memory_used_decode = lambda shape, dtype: (10 * shape[1] * shape[2] * shape[3] * (4 * 8 * 8)) * model_management.dtype_size(dtype)
self.memory_used_encode = lambda shape, dtype: (10 * max(shape[1], 5) * shape[2] * shape[3]) * model_management.dtype_size(dtype)
self.memory_used_decode = lambda shape, dtype: (shape[1] * shape[2] * shape[3] * (4 * 8 * 8)) * model_management.dtype_size(dtype)
self.memory_used_encode = lambda shape, dtype: (max(shape[1], 5) * shape[2] * shape[3]) * model_management.dtype_size(dtype)
self.working_dtypes = [torch.bfloat16, torch.float32]
self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 8, 8)
self.downscale_index_formula = (4, 8, 8)