From 9b573da39b5ca9d08104840206ec76d4b6601c8e Mon Sep 17 00:00:00 2001
From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com>
Date: Fri, 26 Dec 2025 21:16:36 +0200
Subject: [PATCH] added other types of attention + compatibility with images

---
 comfy/ldm/modules/attention.py | 83 ++++++++++++++++++++++++++--------
 comfy/ldm/seedvr/vae.py        | 17 +++--
 comfy_extras/nodes_seedvr.py   | 24 ++++----
 3 files changed, 93 insertions(+), 31 deletions(-)

diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py
index 332c65ffb..c7a15a5c8 100644
--- a/comfy/ldm/modules/attention.py
+++ b/comfy/ldm/modules/attention.py
@@ -19,9 +19,15 @@ if model_management.xformers_enabled():
     import xformers.ops
 
 SAGE_ATTENTION_IS_AVAILABLE = False
+SAGE_ATTENTION_VAR_LENGTH_AVAILABLE = False
 try:
     from sageattention import sageattn
     SAGE_ATTENTION_IS_AVAILABLE = True
+    try:
+        from sageattention import sageattn_varlen
+        SAGE_ATTENTION_VAR_LENGTH_AVAILABLE = True
+    except ImportError:
+        pass
 except ImportError as e:
     if model_management.sage_attention_enabled():
         if e.name == "sageattention":
@@ -80,7 +86,13 @@ def default(val, d):
         return val
     return d
 
-
+def var_attn_arg(kwargs):
+    cu_seqlens_q = kwargs.get("cu_seqlens_q", None)
+    cu_seqlens_k = kwargs.get("cu_seqlens_k", cu_seqlens_q)
+    max_seqlen_q = kwargs.get("max_seqlen_q", None)
+    max_seqlen_k = kwargs.get("max_seqlen_k", max_seqlen_q)
+    assert cu_seqlens_q is not None, "cu_seqlens_q shouldn't be None when var_length is True"
+    return cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k
 # feedforward
 class GEGLU(nn.Module):
     def __init__(self, dim_in, dim_out, dtype=None, device=None, operations=ops):
@@ -404,14 +416,14 @@ except:
     pass
 
 @wrap_attn
-def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
+def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, var_length=False, **kwargs):
     b = q.shape[0]
     dim_head = q.shape[-1]
     # check to make sure xformers isn't broken
     disabled_xformers = False
 
     if BROKEN_XFORMERS:
-        if b * heads > 65535:
+        if b * heads > 65535 and not var_length:
             disabled_xformers = True
 
     if not disabled_xformers:
@@ -419,9 +431,24 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_resh
             disabled_xformers = True
 
     if disabled_xformers:
-        return attention_pytorch(q, k, v, heads, mask, skip_reshape=skip_reshape, **kwargs)
+        return attention_pytorch(q, k, v, heads, mask, skip_reshape=skip_reshape, var_length=var_length, **kwargs)
 
-    if skip_reshape:
+    if var_length:
+        if not skip_reshape:
+            total_tokens, hidden_dim = q.shape
+            dim_head = hidden_dim // heads
+            q = q.view(1, total_tokens, heads, dim_head)
+            k = k.view(1, total_tokens, heads, dim_head)
+            v = v.view(1, total_tokens, heads, dim_head)
+        else:
+            if q.ndim == 3: q = q.unsqueeze(0)
+            if k.ndim == 3: k = k.unsqueeze(0)
+            if v.ndim == 3: v = v.unsqueeze(0)
+            dim_head = q.shape[-1]
+
+        target_output_shape = (q.shape[1], -1)
+        b = 1
+    elif skip_reshape:
         # b h k d -> b k h d
         q, k, v = map(
             lambda t: t.permute(0, 2, 1, 3),
@@ -435,7 +462,11 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_resh
             (q, k, v),
         )
 
-    if mask is not None:
+    if var_length:
+        cu_seqlens_q, _, _, _ = var_attn_arg(kwargs)
+        seq_lens = cu_seqlens_q[1:] - cu_seqlens_q[:-1]
+        mask = xformers.ops.BlockDiagonalMask.from_seqlens(seq_lens_q=seq_lens, seq_lens_k=seq_lens)
+    elif mask is not None:
         # add a singleton batch dimension
         if mask.ndim == 2:
             mask = mask.unsqueeze(0)
@@ -457,6 +488,8 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_resh
 
     out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=mask)
 
+    if var_length:
+        return out.reshape(*target_output_shape)
     if skip_output_reshape:
         out = out.permute(0, 2, 1, 3)
     else:
@@ -475,9 +508,7 @@ else:
 @wrap_attn
 def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, var_length=False, **kwargs):
     if var_length:
-        cu_seqlens_q = kwargs.get("cu_seqlens_q", None)
-        cu_seqlens_k = kwargs.get("cu_seqlens_k", cu_seqlens_q)
-        assert cu_seqlens_q != None, "cu_seqlens_q shouldn't be None when var_length is True"
+        cu_seqlens_q, cu_seqlens_k, _, _ = var_attn_arg(kwargs)
         if not skip_reshape:
             # assumes 2D q, k,v [total_tokens, embed_dim]
             total_tokens, embed_dim = q.shape
@@ -539,9 +570,19 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
     return out
 
 @wrap_attn
-def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
+def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, var_length=False, **kwargs):
     exception_fallback = False
-    if skip_reshape:
+    if var_length:
+        if not skip_reshape:
+            total_tokens, hidden_dim = q.shape
+            dim_head = hidden_dim // heads
+            q, k, v = [t.view(total_tokens, heads, dim_head) for t in (q, k, v)]
+        b, _, dim_head = q.shape
+        # skip the batched mask handling below
+        mask = None
+        tensor_layout = "VAR"
+        target_output_shape = (q.shape[0], -1)
+    elif skip_reshape:
         b, _, _, dim_head = q.shape
         tensor_layout = "HND"
     else:
@@ -562,7 +603,14 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=
             mask = mask.unsqueeze(1)
 
     try:
-        out = sageattn(q, k, v, attn_mask=mask, is_causal=False, tensor_layout=tensor_layout)
+        if var_length and not SAGE_ATTENTION_VAR_LENGTH_AVAILABLE:
+            raise ValueError("SageAttention 2 is required to run variable length attention.")
+        elif var_length:
+            cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k = var_attn_arg(kwargs)
+            sm_scale = 1.0 / (q.shape[-1] ** 0.5)
+            out = sageattn_varlen(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, is_causal=False, sm_scale=sm_scale)
+        else:
+            out = sageattn(q, k, v, attn_mask=mask, is_causal=False, tensor_layout=tensor_layout)
     except Exception as e:
         logging.error("Error running sage attention: {}, using pytorch attention instead.".format(e))
         exception_fallback = True
@@ -572,7 +620,7 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=
                 lambda t: t.transpose(1, 2),
                 (q, k, v),
             )
-        return attention_pytorch(q, k, v, heads, mask=mask, skip_reshape=True, skip_output_reshape=skip_output_reshape, **kwargs)
+        return attention_pytorch(q, k, v, heads, mask=mask, skip_reshape=True, skip_output_reshape=skip_output_reshape, var_length=var_length, **kwargs)
 
     if tensor_layout == "HND":
         if not skip_output_reshape:
@@ -583,6 +631,8 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=
         if skip_output_reshape:
             out = out.transpose(1, 2)
         else:
+            if var_length:
+                return out.view(*target_output_shape)
             out = out.reshape(b, -1, heads * dim_head)
 
     return out
@@ -608,12 +658,7 @@ except AttributeError as error:
 @wrap_attn
 def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, var_length=False, **kwargs):
     if var_length:
-        cu_seqlens_q = kwargs.get("cu_seqlens_q", None)
-        cu_seqlens_k = kwargs.get("cu_seqlens_k", cu_seqlens_q)
-        max_seqlen_q = kwargs.get("max_seqlen_q", None)
-        max_seqlen_k = kwargs.get("max_seqlen_k", max_seqlen_q)
-        assert max_seqlen_q != None, "max_seqlen_q shouldn't be None when var_length is True"
-        assert cu_seqlens_q != None, "cu_seqlens_q shouldn't be None when var_length is True"
+        cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k = var_attn_arg(kwargs)
         return flash_attn_varlen_func(
             q=q, k=k, v=v,
             cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k,
diff --git a/comfy/ldm/seedvr/vae.py b/comfy/ldm/seedvr/vae.py
index 9fcea60ad..c9fef0677 100644
--- a/comfy/ldm/seedvr/vae.py
+++ b/comfy/ldm/seedvr/vae.py
@@ -499,6 +499,8 @@ class InflatedCausalConv3d(ops.Conv3d):
 
         def pad_and_forward():
             padded = safe_pad_operation(x_concat, padding, mode='constant', value=0.0)
+            if not padded.is_contiguous():
+                padded = padded.contiguous()
             with ignore_padding(self):
                 return torch.nn.Conv3d.forward(self, padded)
 
@@ -1726,7 +1728,7 @@ class VideoAutoencoderKL(nn.Module):
         return decoded
 
     def _encode(
-        self, x, memory_state
+        self, x, memory_state=MemoryState.DISABLED
     ) -> torch.Tensor:
         _x = x.to(self.device)
         h = self.encoder(_x, memory_state=memory_state)
@@ -1737,7 +1739,7 @@ class VideoAutoencoderKL(nn.Module):
         return output.to(x.device)
 
     def _decode(
-        self, z, memory_state
+        self, z, memory_state=MemoryState.DISABLED
     ) -> torch.Tensor:
 
         _z = z.to(self.device)
@@ -1892,9 +1894,16 @@ class VideoAutoencoderKLWrapper(VideoAutoencoderKL):
 
         # in case of padded frames
        t = input.size(0)
-        x = x[:, :, :t]
+        if t != 1:
+            x = x[:, :, :t]
+        if t == 1 and x.size(2) == 4:
+            x = x[:, :, :t]
 
-        x = rearrange(x, "b c t h w -> (b t) c h w")
+        if x.size(1) == 1:
+            exp = "b t c h w -> (b t) c h w"
+        else:
+            exp = "b c t h w -> (b t) c h w"
+        x = rearrange(x, exp)
 
         input = input.to(x.device)
         x = wavelet_reconstruction(x, input)
diff --git a/comfy_extras/nodes_seedvr.py b/comfy_extras/nodes_seedvr.py
index 22b117872..4ec089dde 100644
--- a/comfy_extras/nodes_seedvr.py
+++ b/comfy_extras/nodes_seedvr.py
@@ -270,12 +270,11 @@ class SeedVR2InputProcessing(io.ComfyNode):
             inputs = [
                 io.Image.Input("images"),
                 io.Vae.Input("vae"),
-                io.Int.Input("resolution_height", default = 1280, min = 120), # //
-                io.Int.Input("resolution_width", default = 720, min = 120), # just non-zero value
+                io.Int.Input("resolution", default = 1280, min = 120), # just non-zero value
                 io.Int.Input("spatial_tile_size", default = 512, min = 1),
-                io.Int.Input("temporal_tile_size", default = 8, min = 1),
                 io.Int.Input("spatial_overlap", default = 64, min = 1),
-                io.Boolean.Input("enable_tiling", default=False)
+                io.Int.Input("temporal_tile_size", default = 8, min = 1),
+                io.Boolean.Input("enable_tiling", default=False),
             ],
             outputs = [
                 io.Latent.Output("vae_conditioning")
@@ -283,7 +282,7 @@ class SeedVR2InputProcessing(io.ComfyNode):
         )
 
     @classmethod
-    def execute(cls, images, vae, resolution_height, resolution_width, spatial_tile_size, temporal_tile_size, spatial_overlap, enable_tiling):
+    def execute(cls, images, vae, resolution, spatial_tile_size, temporal_tile_size, spatial_overlap, enable_tiling):
         device = vae.patcher.load_device
         offload_device = comfy.model_management.intermediate_device()
 
@@ -298,11 +297,9 @@ class SeedVR2InputProcessing(io.ComfyNode):
         b, t, c, h, w = images.shape
         images = images.reshape(b * t, c, h, w)
 
-        #max_area = ((resolution_height * resolution_width)** 0.5) ** 2
         clip = Lambda(lambda x: torch.clamp(x, 0.0, 1.0))
         normalize = Normalize(0.5, 0.5)
-        #images = area_resize(images, max_area)
-        images = side_resize(images, resolution_height)
+        images = side_resize(images, resolution)
         images = clip(images)
 
         o_h, o_w = images.shape[-2:]
@@ -317,6 +314,17 @@ class SeedVR2InputProcessing(io.ComfyNode):
         images = images.to(device)
         vae_model = vae_model.to(device)
 
+        # in case the user passes a non-compatible number for tiling
+        def make_divisible(val, divisor):
+            return max(divisor, round(val / divisor) * divisor)
+
+        temporal_tile_size = make_divisible(temporal_tile_size, 4)
+        spatial_tile_size = make_divisible(spatial_tile_size, 32)
+        spatial_overlap = make_divisible(spatial_overlap, 32)
+
+        if spatial_overlap >= spatial_tile_size:
+            spatial_overlap = max(0, spatial_tile_size - 8)
+
         args = {"tile_size": (spatial_tile_size, spatial_tile_size), "tile_overlap": (spatial_overlap, spatial_overlap), "temporal_size":temporal_tile_size}
 
         if enable_tiling:
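Reviewer note, not part of the patch: a minimal sketch of how the packed var_length path is expected to be driven, assuming a ComfyUI checkout is importable. The head count, sequence lengths, and tensor sizes below are made up for illustration; the keyword names match what var_attn_arg() reads, and the unspecified *_k values default to their *_q counterparts.

import torch
from comfy.ldm.modules.attention import attention_pytorch

heads, dim_head = 8, 64
seq_lens = [37, 91]                  # two variable-length sequences packed together
total_tokens = sum(seq_lens)

# packed 2D layout [total_tokens, embed_dim], as assumed by the var_length branch
q = torch.randn(total_tokens, heads * dim_head)
k = torch.randn(total_tokens, heads * dim_head)
v = torch.randn(total_tokens, heads * dim_head)

# cumulative sequence lengths in the flash-attn convention: [0, 37, 128]
cu_seqlens = torch.tensor([0, 37, 128], dtype=torch.int32)

out = attention_pytorch(
    q, k, v, heads,
    var_length=True,
    cu_seqlens_q=cu_seqlens,        # cu_seqlens_k falls back to cu_seqlens_q in var_attn_arg()
    max_seqlen_q=max(seq_lens),     # max_seqlen_k falls back to max_seqlen_q
)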