From 31d358c78c4414090e191b3aaceea42378de8c58 Mon Sep 17 00:00:00 2001
From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com>
Date: Sun, 4 Jan 2026 19:15:53 +0200
Subject: [PATCH] rope, attention update | vae on cpu warning

---
 comfy/ldm/modules/attention.py | 12 ++++++++----
 comfy/ldm/seedvr/model.py      | 19 +++++++++++--------
 comfy/ldm/seedvr/vae.py        | 14 ++++++++++++--
 comfy/sd.py                    |  4 ++--
 4 files changed, 33 insertions(+), 16 deletions(-)

diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py
index 6163aec22..be253a010 100644
--- a/comfy/ldm/modules/attention.py
+++ b/comfy/ldm/modules/attention.py
@@ -416,7 +416,8 @@ except:
     pass
 
 @wrap_attn
-def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, var_length=False, **kwargs):
+def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
+    var_length = kwargs.get("var_length", False)
     b = q.shape[0]
     dim_head = q.shape[-1]
     # check to make sure xformers isn't broken
@@ -506,7 +507,8 @@ else:
     SDP_BATCH_LIMIT = 2**31
 
 @wrap_attn
-def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, var_length=False, **kwargs):
+def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
+    var_length = kwargs.get("var_length", False)
     if var_length:
         cu_seqlens_q, cu_seqlens_k, _, _ = var_attn_arg(kwargs)
     if not skip_reshape:
@@ -570,7 +572,8 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
     return out
 
 @wrap_attn
-def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, var_length = False, **kwargs):
+def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
+    var_length = kwargs.get("var_length", False)
     exception_fallback = False
     if var_length:
         if not skip_reshape:
@@ -656,7 +659,8 @@ except AttributeError as error:
     assert False, f"Could not define flash_attn_wrapper: {FLASH_ATTN_ERROR}"
 
 @wrap_attn
-def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, var_length=False, **kwargs):
+def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
+    var_length = kwargs.get("var_length", False)
     if var_length:
         cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k = var_attn_arg(kwargs)
         return flash_attn_varlen_func(
diff --git a/comfy/ldm/seedvr/model.py b/comfy/ldm/seedvr/model.py
index bd0057332..6c3e9c526 100644
--- a/comfy/ldm/seedvr/model.py
+++ b/comfy/ldm/seedvr/model.py
@@ -13,6 +13,7 @@ from comfy.rmsnorm import RMSNorm
 from torch.nn.modules.utils import _triple
 from torch import nn
 import math
+from comfy.ldm.flux.math import apply_rope1
 
 class Cache:
     def __init__(self, disable=False, prefix="", cache=None):
@@ -443,7 +444,6 @@ def apply_rotary_emb(
     freqs_seq_dim = None
 ):
     dtype = t.dtype
-
     if not exists(freqs_seq_dim):
         if freqs.ndim == 2 or t.ndim == 3:
             freqs_seq_dim = 0
@@ -452,20 +452,23 @@ def apply_rotary_emb(
         seq_len = t.shape[seq_dim]
         freqs = slice_at_dim(freqs, slice(-seq_len, None), dim = freqs_seq_dim)
 
-    rot_dim = freqs.shape[-1]
-    end_index = start_index + rot_dim
-
-    assert rot_dim <= t.shape[-1], f'feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}'
+    rot_feats = freqs.shape[-1]
+    end_index = start_index + rot_feats
 
     t_left = t[..., :start_index]
     t_middle = t[..., start_index:end_index]
     t_right = t[..., end_index:]
 
-    freqs = freqs.to(t_middle.device)
-    t_transformed = (t_middle * freqs.cos() * scale) + (rotate_half(t_middle) * freqs.sin() * scale)
+    angles = freqs.to(t_middle.device)[..., ::2]
+    cos = torch.cos(angles) * scale
+    sin = torch.sin(angles) * scale
 
-    out = torch.cat((t_left, t_transformed, t_right), dim=-1)
+    col0 = torch.stack([cos, sin], dim=-1)
+    col1 = torch.stack([-sin, cos], dim=-1)
+    freqs_mat = torch.stack([col0, col1], dim=-1)
+    t_middle_out = apply_rope1(t_middle, freqs_mat)
+    out = torch.cat((t_left, t_middle_out, t_right), dim=-1)
 
     return out.type(dtype)
 
 class NaMMRotaryEmbedding3d(MMRotaryEmbeddingBase):
diff --git a/comfy/ldm/seedvr/vae.py b/comfy/ldm/seedvr/vae.py
index 292958a88..d218b90e9 100644
--- a/comfy/ldm/seedvr/vae.py
+++ b/comfy/ldm/seedvr/vae.py
@@ -16,6 +16,7 @@ import math
 from enum import Enum
 
 from comfy.ops import NVIDIA_MEMORY_CONV_BUG_WORKAROUND
+import logging
 
 import comfy.ops
 ops = comfy.ops.disable_weight_init
@@ -446,6 +447,7 @@ class InflatedCausalConv3d(ops.Conv3d):
         self.memory_device = memory_device
         self.padding = (0, *self.padding[1:])
         self.memory_limit = float("inf")
+        self.logged_once = False
 
     def set_memory_limit(self, value: float):
         self.memory_limit = value
@@ -469,8 +471,16 @@ class InflatedCausalConv3d(ops.Conv3d):
                 return out
             except RuntimeError:
                 pass
-
-        return super()._conv_forward(input, weight, bias, *args, **kwargs)
+            except NotImplementedError:
+                pass
+        try:
+            return super()._conv_forward(input, weight, bias, *args, **kwargs)
+        except NotImplementedError:
+            # fallback for: Could not run 'aten::cudnn_convolution' with arguments from the 'CPU' backend
+            if not self.logged_once:
+                logging.warning("VAE is running on the CPU for decoding. This is most likely due to not having enough free memory.")
+                self.logged_once = True
+            return F.conv3d(input, weight, bias, *args, **kwargs)
 
     def memory_limit_conv(
         self,
diff --git a/comfy/sd.py b/comfy/sd.py
index 69ec40756..102d1a026 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -382,8 +382,8 @@ class VAE:
                 self.latent_channels = 16
             elif "decoder.up_blocks.2.upsamplers.0.upscale_conv.weight" in sd: # seedvr2
                 self.first_stage_model = comfy.ldm.seedvr.vae.VideoAutoencoderKLWrapper()
-                self.memory_used_decode = lambda shape, dtype: (10 * shape[1] * shape[2] * shape[3] * (4 * 8 * 8)) * model_management.dtype_size(dtype)
-                self.memory_used_encode = lambda shape, dtype: (10 * max(shape[1], 5) * shape[2] * shape[3]) * model_management.dtype_size(dtype)
+                self.memory_used_decode = lambda shape, dtype: (shape[1] * shape[2] * shape[3] * (4 * 8 * 8)) * model_management.dtype_size(dtype)
+                self.memory_used_encode = lambda shape, dtype: (max(shape[1], 5) * shape[2] * shape[3]) * model_management.dtype_size(dtype)
                 self.working_dtypes = [torch.bfloat16, torch.float32]
                 self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 8, 8)
                 self.downscale_index_formula = (4, 8, 8)
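
Reviewer note (not part of the patch): the RoPE change in comfy/ldm/seedvr/model.py replaces the explicit t * cos + rotate_half(t) * sin rotation with a per-pair 2x2 rotation matrix handed to apply_rope1. The snippet below is a minimal, self-contained equivalence check under two assumptions the patch implies but does not state: freqs stores each angle twice for consecutive feature pairs (hence the [..., ::2] stride), and rotate_half pairs consecutive features as in rotary-embedding-torch. The 2x2 product is written out inline rather than importing comfy.ldm.flux.math.apply_rope1, scale is taken as 1, and the tensor shapes and angle values are hypothetical.

    import torch

    def rotate_half(x):
        # mirrors the removed helper: consecutive pairs (x1, x2) -> (-x2, x1)
        x = x.reshape(*x.shape[:-1], -1, 2)
        x1, x2 = x.unbind(dim=-1)
        return torch.stack((-x2, x1), dim=-1).reshape(*x.shape[:-2], -1)

    torch.manual_seed(0)
    b, seq, dim = 2, 7, 64
    t = torch.randn(b, seq, dim)
    angles = torch.randn(seq, dim // 2)              # one angle per feature pair (hypothetical)
    freqs = angles.repeat_interleave(2, dim=-1)      # old layout: each angle stored twice

    # old path (the removed lines, with scale = 1)
    old = t * freqs.cos() + rotate_half(t) * freqs.sin()

    # new path: per-pair 2x2 rotation, the same layout the patch builds for apply_rope1
    cos, sin = torch.cos(angles), torch.sin(angles)  # angles == freqs[..., ::2]
    pairs = t.reshape(b, seq, -1, 2)
    rot_x = pairs[..., 0] * cos - pairs[..., 1] * sin
    rot_y = pairs[..., 0] * sin + pairs[..., 1] * cos
    new = torch.stack((rot_x, rot_y), dim=-1).reshape(b, seq, dim)

    print(torch.allclose(old, new, atol=1e-6))       # expected: True

If this prints True, the refactor is a pure reformulation of the same rotation and should not change model outputs.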