added other types of attention + compatibility with images
Yousef Rafat 2025-12-26 21:16:36 +02:00
parent 4d7012ecda
commit 9b573da39b
3 changed files with 93 additions and 31 deletions

File 1 of 3

@@ -19,9 +19,15 @@ if model_management.xformers_enabled():
     import xformers.ops
 
 SAGE_ATTENTION_IS_AVAILABLE = False
+SAGE_ATTENTION_VAR_LENGTH_AVAILABLE = False
 try:
     from sageattention import sageattn
     SAGE_ATTENTION_IS_AVAILABLE = True
+    try:
+        from sageattention import sageattn_varlen
+        SAGE_ATTENTION_VAR_LENGTH_AVAILABLE = True
+    except ImportError:
+        pass
 except ImportError as e:
     if model_management.sage_attention_enabled():
         if e.name == "sageattention":
@@ -80,7 +86,13 @@ def default(val, d):
         return val
     return d
 
+def var_attn_arg(kwargs):
+    cu_seqlens_q = kwargs.get("cu_seqlens_q", None)
+    cu_seqlens_k = kwargs.get("cu_seqlens_k", cu_seqlens_q)
+    max_seqlen_q = kwargs.get("max_seqlen_q", None)
+    max_seqlen_k = kwargs.get("max_seqlen_k", max_seqlen_q)
+    assert cu_seqlens_q is not None, "cu_seqlens_q shouldn't be None when var_length is True"
+    return cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k
+
 # feedforward
 class GEGLU(nn.Module):
     def __init__(self, dim_in, dim_out, dtype=None, device=None, operations=ops):
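Note (not part of the commit): the new var_attn_arg helper only unpacks kwargs, so callers are expected to supply the cumulative sequence lengths and maxima themselves. A minimal sketch, assuming three packed samples of arbitrary lengths, of how those kwargs could be built in plain PyTorch:

    import torch

    # Hypothetical per-sample token counts for one packed variable-length batch.
    seq_lens = torch.tensor([7, 12, 5], dtype=torch.int32)

    # Cumulative boundaries with a leading zero: [0, 7, 19, 24].
    cu_seqlens = torch.nn.functional.pad(torch.cumsum(seq_lens, dim=0), (1, 0)).to(torch.int32)
    max_seqlen = int(seq_lens.max())

    # These are the keys var_attn_arg() reads back out of **kwargs.
    varlen_kwargs = {
        "cu_seqlens_q": cu_seqlens,
        "cu_seqlens_k": cu_seqlens,
        "max_seqlen_q": max_seqlen,
        "max_seqlen_k": max_seqlen,
    }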
@@ -404,14 +416,14 @@ except:
     pass
 
 @wrap_attn
-def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
+def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, var_length=False, **kwargs):
     b = q.shape[0]
     dim_head = q.shape[-1]
     # check to make sure xformers isn't broken
     disabled_xformers = False
 
     if BROKEN_XFORMERS:
-        if b * heads > 65535:
+        if b * heads > 65535 and not var_length:
             disabled_xformers = True
 
     if not disabled_xformers:
@@ -419,9 +431,24 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_resh
             disabled_xformers = True
 
     if disabled_xformers:
-        return attention_pytorch(q, k, v, heads, mask, skip_reshape=skip_reshape, **kwargs)
+        return attention_pytorch(q, k, v, heads, mask, skip_reshape=skip_reshape, var_length=var_length, **kwargs)
 
-    if skip_reshape:
+    if var_length:
+        if not skip_reshape:
+            total_tokens, hidden_dim = q.shape
+            dim_head = hidden_dim // heads
+            q = q.view(1, total_tokens, heads, dim_head)
+            k = k.view(1, total_tokens, heads, dim_head)
+            v = v.view(1, total_tokens, heads, dim_head)
+        else:
+            if q.ndim == 3: q = q.unsqueeze(0)
+            if k.ndim == 3: k = k.unsqueeze(0)
+            if v.ndim == 3: v = v.unsqueeze(0)
+            dim_head = q.shape[-1]
+        target_output_shape = (q.shape[1], -1)
+        b = 1
+    elif skip_reshape:
         # b h k d -> b k h d
         q, k, v = map(
             lambda t: t.permute(0, 2, 1, 3),
@@ -435,7 +462,11 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_resh
             (q, k, v),
         )
 
-    if mask is not None:
+    if var_length:
+        cu_seqlens_q, _, _, _ = var_attn_arg(kwargs)
+        seq_lens = cu_seqlens_q[1:] - cu_seqlens_q[:-1]
+        mask = xformers.ops.BlockDiagonalMask.from_seqlens(seq_lens_q=seq_lens, seq_lens_k=seq_lens)
+    elif mask is not None:
         # add a singleton batch dimension
         if mask.ndim == 2:
             mask = mask.unsqueeze(0)
@@ -457,6 +488,8 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_resh
     out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=mask)
 
+    if var_length:
+        return out.reshape(*target_output_shape)
+
     if skip_output_reshape:
         out = out.permute(0, 2, 1, 3)
     else:
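Aside (an illustration, not commit code): the block-diagonal bias built in the var_length branch above restricts attention so that tokens only attend to tokens from the same packed sample. A rough pure-PyTorch equivalent, with made-up shapes, for readers without xformers installed:

    import torch
    import torch.nn.functional as F

    heads, dim_head = 8, 64
    seq_lens = [7, 12, 5]                       # per-sample lengths in the packed batch
    total = sum(seq_lens)

    # Boolean mask: True where attention is allowed (same-sample pairs only).
    allow = torch.zeros(total, total, dtype=torch.bool)
    offset = 0
    for n in seq_lens:
        allow[offset:offset + n, offset:offset + n] = True
        offset += n

    q = torch.randn(1, heads, total, dim_head)  # (batch=1, heads, tokens, dim_head)
    k = torch.randn(1, heads, total, dim_head)
    v = torch.randn(1, heads, total, dim_head)
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=allow)  # (1, heads, total, dim_head)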
@@ -475,9 +508,7 @@ else:
 
 @wrap_attn
 def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, var_length=False, **kwargs):
     if var_length:
-        cu_seqlens_q = kwargs.get("cu_seqlens_q", None)
-        cu_seqlens_k = kwargs.get("cu_seqlens_k", cu_seqlens_q)
-        assert cu_seqlens_q != None, "cu_seqlens_q shouldn't be None when var_length is True"
+        cu_seqlens_q, cu_seqlens_k, _, _ = var_attn_arg(kwargs)
         if not skip_reshape:
             # assumes 2D q, k,v [total_tokens, embed_dim]
             total_tokens, embed_dim = q.shape
@@ -539,9 +570,19 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
     return out
 
 @wrap_attn
-def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
+def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, var_length=False, **kwargs):
     exception_fallback = False
-    if skip_reshape:
+    if var_length:
+        if not skip_reshape:
+            total_tokens, hidden_dim = q.shape
+            dim_head = hidden_dim // heads
+            q, k, v = [t.view(total_tokens, heads, dim_head) for t in (q, k, v)]
+        b, _, dim_head = q.shape
+        # skips batched code
+        mask = None
+        tensor_layout = "VAR"
+        target_output_shape = (q.shape[0], -1)
+    elif skip_reshape:
         b, _, _, dim_head = q.shape
         tensor_layout = "HND"
     else:
@@ -562,7 +603,14 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=
             mask = mask.unsqueeze(1)
 
     try:
-        out = sageattn(q, k, v, attn_mask=mask, is_causal=False, tensor_layout=tensor_layout)
+        if var_length and not SAGE_ATTENTION_VAR_LENGTH_AVAILABLE:
+            raise ValueError("sageattn_varlen from sageattention is required to run variable length attention.")
+        elif var_length:
+            cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k = var_attn_arg(kwargs)
+            sm_scale = 1.0 / (q.shape[-1] ** 0.5)
+            out = sageattn_varlen(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, is_causal=False, sm_scale=sm_scale)
+        else:
+            out = sageattn(q, k, v, attn_mask=mask, is_causal=False, tensor_layout=tensor_layout)
     except Exception as e:
         logging.error("Error running sage attention: {}, using pytorch attention instead.".format(e))
         exception_fallback = True
@@ -572,7 +620,7 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=
                 lambda t: t.transpose(1, 2),
                 (q, k, v),
             )
-        return attention_pytorch(q, k, v, heads, mask=mask, skip_reshape=True, skip_output_reshape=skip_output_reshape, **kwargs)
+        return attention_pytorch(q, k, v, heads, mask=mask, skip_reshape=True, skip_output_reshape=skip_output_reshape, var_length=var_length, **kwargs)
 
     if tensor_layout == "HND":
         if not skip_output_reshape:
@@ -583,6 +631,8 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=
         if skip_output_reshape:
             out = out.transpose(1, 2)
         else:
+            if var_length:
+                return out.view(*target_output_shape)
             out = out.reshape(b, -1, heads * dim_head)
 
     return out
@@ -608,12 +658,7 @@ except AttributeError as error:
 
 @wrap_attn
 def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, var_length=False, **kwargs):
     if var_length:
-        cu_seqlens_q = kwargs.get("cu_seqlens_q", None)
-        cu_seqlens_k = kwargs.get("cu_seqlens_k", cu_seqlens_q)
-        max_seqlen_q = kwargs.get("max_seqlen_q", None)
-        max_seqlen_k = kwargs.get("max_seqlen_k", max_seqlen_q)
-        assert max_seqlen_q != None, "max_seqlen_q shouldn't be None when var_length is True"
-        assert cu_seqlens_q != None, "cu_seqlens_q shouldn't be None when var_length is True"
+        cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k = var_attn_arg(kwargs)
         return flash_attn_varlen_func(
             q=q, k=k, v=v,
             cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k,
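Illustrative usage only (shapes and kwargs assumed, not shown in this commit): with the changes above, a caller holding packed 2D q/k/v of shape (total_tokens, heads * dim_head) could route through any of the patched backends the same way:

    # Assumed: q, k, v are (total_tokens, heads * dim_head) tensors and
    # varlen_kwargs holds cu_seqlens_q/k and max_seqlen_q/k as built earlier.
    out = attention_pytorch(
        q, k, v, heads=8,
        var_length=True,
        **varlen_kwargs,
    )
    # out is expected back in packed form (one row per token); the exact layout
    # depends on the backend. attention_xformers, attention_sage and
    # attention_flash accept the same var_length kwargs.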

File 2 of 3

@@ -499,6 +499,8 @@ class InflatedCausalConv3d(ops.Conv3d):
 
         def pad_and_forward():
             padded = safe_pad_operation(x_concat, padding, mode='constant', value=0.0)
+            if not padded.is_contiguous():
+                padded = padded.contiguous()
             with ignore_padding(self):
                 return torch.nn.Conv3d.forward(self, padded)
@@ -1726,7 +1728,7 @@ class VideoAutoencoderKL(nn.Module):
         return decoded
 
     def _encode(
-        self, x, memory_state
+        self, x, memory_state=MemoryState.DISABLED
     ) -> torch.Tensor:
         _x = x.to(self.device)
         h = self.encoder(_x, memory_state=memory_state)
@@ -1737,7 +1739,7 @@ class VideoAutoencoderKL(nn.Module):
         return output.to(x.device)
 
     def _decode(
-        self, z, memory_state
+        self, z, memory_state=MemoryState.DISABLED
     ) -> torch.Tensor:
         _z = z.to(self.device)
@@ -1892,9 +1894,16 @@ class VideoAutoencoderKLWrapper(VideoAutoencoderKL):
         # in case of padded frames
         t = input.size(0)
-        x = x[:, :, :t]
+        if t != 1:
+            x = x[:, :, :t]
+        if t == 1 and x.size(2) == 4:
+            x = x[:, :, :t]
 
-        x = rearrange(x, "b c t h w -> (b t) c h w")
+        if x.size(1) == 1:
+            exp = "b t c h w -> (b t) c h w"
+        else:
+            exp = "b c t h w -> (b t) c h w"
+        x = rearrange(x, exp)
 
         input = input.to(x.device)
         x = wavelet_reconstruction(x, input)
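Side note (an illustration, not commit code): the layout switch above matters because the frame axis can arrive either next to the batch axis (b t c h w) or after channels (b c t h w), and rearrange needs the matching pattern to fold frames into the batch dimension. A quick check with dummy tensors:

    import torch
    from einops import rearrange

    # Video layout used by the VAE path: (batch, channels, time, height, width).
    x_bcthw = torch.randn(2, 3, 5, 16, 16)
    print(rearrange(x_bcthw, "b c t h w -> (b t) c h w").shape)  # torch.Size([10, 3, 16, 16])

    # Single-image style layout: (batch, time, channels, height, width).
    x_btchw = torch.randn(2, 1, 3, 16, 16)
    print(rearrange(x_btchw, "b t c h w -> (b t) c h w").shape)  # torch.Size([2, 3, 16, 16])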

File 3 of 3

@@ -270,12 +270,11 @@ class SeedVR2InputProcessing(io.ComfyNode):
             inputs = [
                 io.Image.Input("images"),
                 io.Vae.Input("vae"),
-                io.Int.Input("resolution_height", default = 1280, min = 120),
-                io.Int.Input("resolution_width", default = 720, min = 120), # just non-zero value
+                io.Int.Input("resolution", default = 1280, min = 120), # just non-zero value
                 io.Int.Input("spatial_tile_size", default = 512, min = 1),
-                io.Int.Input("temporal_tile_size", default = 8, min = 1),
                 io.Int.Input("spatial_overlap", default = 64, min = 1),
-                io.Boolean.Input("enable_tiling", default=False)
+                io.Int.Input("temporal_tile_size", default = 8, min = 1),
+                io.Boolean.Input("enable_tiling", default=False),
             ],
             outputs = [
                 io.Latent.Output("vae_conditioning")
@@ -283,7 +282,7 @@ class SeedVR2InputProcessing(io.ComfyNode):
         )
 
     @classmethod
-    def execute(cls, images, vae, resolution_height, resolution_width, spatial_tile_size, temporal_tile_size, spatial_overlap, enable_tiling):
+    def execute(cls, images, vae, resolution, spatial_tile_size, temporal_tile_size, spatial_overlap, enable_tiling):
         device = vae.patcher.load_device
         offload_device = comfy.model_management.intermediate_device()
@@ -298,11 +297,9 @@ class SeedVR2InputProcessing(io.ComfyNode):
         b, t, c, h, w = images.shape
         images = images.reshape(b * t, c, h, w)
-        #max_area = ((resolution_height * resolution_width)** 0.5) ** 2
 
         clip = Lambda(lambda x: torch.clamp(x, 0.0, 1.0))
         normalize = Normalize(0.5, 0.5)
-        #images = area_resize(images, max_area)
-        images = side_resize(images, resolution_height)
+        images = side_resize(images, resolution)
 
         images = clip(images)
         o_h, o_w = images.shape[-2:]
@@ -317,6 +314,17 @@ class SeedVR2InputProcessing(io.ComfyNode):
         images = images.to(device)
         vae_model = vae_model.to(device)
 
+        # in case the user passes a non-compatible number for tiling
+        def make_divisible(val, divisor):
+            return max(divisor, round(val / divisor) * divisor)
+
+        temporal_tile_size = make_divisible(temporal_tile_size, 4)
+        spatial_tile_size = make_divisible(spatial_tile_size, 32)
+        spatial_overlap = make_divisible(spatial_overlap, 32)
+
+        if spatial_overlap >= spatial_tile_size:
+            spatial_overlap = max(0, spatial_tile_size - 8)
+
         args = {"tile_size": (spatial_tile_size, spatial_tile_size), "tile_overlap": (spatial_overlap, spatial_overlap),
                 "temporal_size":temporal_tile_size}
         if enable_tiling:
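For reference (input values are made up, not from the commit): make_divisible snaps user-provided tile settings to the nearest multiple of the stride the tiled VAE expects, never going below one divisor, and the overlap is then clamped below the tile size. A few sample inputs and what they become:

    def make_divisible(val, divisor):
        # Same rounding as in the node above: nearest multiple, but never below one divisor.
        return max(divisor, round(val / divisor) * divisor)

    print(make_divisible(7, 4))     # 8   -> temporal_tile_size snaps to multiples of 4
    print(make_divisible(500, 32))  # 512 -> spatial_tile_size snaps to multiples of 32
    print(make_divisible(100, 32))  # 96  -> spatial_overlap snaps to multiples of 32
    print(make_divisible(1, 32))    # 32  -> floor of one divisor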