mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-06 20:20:54 +08:00
added other types of attention + compatibility
with images
This commit is contained in:
parent
4d7012ecda
commit
9b573da39b
@ -19,9 +19,15 @@ if model_management.xformers_enabled():
|
||||
import xformers.ops
|
||||
|
||||
SAGE_ATTENTION_IS_AVAILABLE = False
|
||||
SAGE_ATTENTION_VAR_LENGTH_AVAILABLE = False
|
||||
try:
|
||||
from sageattention import sageattn
|
||||
SAGE_ATTENTION_IS_AVAILABLE = True
|
||||
try:
|
||||
from sageattention import sageattn_varlen
|
||||
SAGE_ATTENTION_VAR_LENGTH_AVAILABLE = True
|
||||
except:
|
||||
pass
|
||||
except ImportError as e:
|
||||
if model_management.sage_attention_enabled():
|
||||
if e.name == "sageattention":
|
||||
@ -80,7 +86,13 @@ def default(val, d):
|
||||
return val
|
||||
return d
|
||||
|
||||
|
||||
def var_attn_arg(kwargs):
|
||||
cu_seqlens_q = kwargs.get("cu_seqlens_q", None)
|
||||
cu_seqlens_k = kwargs.get("cu_seqlens_k", cu_seqlens_q)
|
||||
max_seqlen_q = kwargs.get("max_seqlen_q", None)
|
||||
max_seqlen_k = kwargs.get("max_seqlen_k", max_seqlen_q)
|
||||
assert cu_seqlens_q != None, "cu_seqlens_q shouldn't be None when var_length is True"
|
||||
return cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k
|
||||
# feedforward
|
||||
class GEGLU(nn.Module):
|
||||
def __init__(self, dim_in, dim_out, dtype=None, device=None, operations=ops):
|
||||
@ -404,14 +416,14 @@ except:
|
||||
pass
|
||||
|
||||
@wrap_attn
|
||||
def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
|
||||
def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, var_length=False, **kwargs):
|
||||
b = q.shape[0]
|
||||
dim_head = q.shape[-1]
|
||||
# check to make sure xformers isn't broken
|
||||
disabled_xformers = False
|
||||
|
||||
if BROKEN_XFORMERS:
|
||||
if b * heads > 65535:
|
||||
if b * heads > 65535 and not var_length:
|
||||
disabled_xformers = True
|
||||
|
||||
if not disabled_xformers:
|
||||
@ -419,9 +431,24 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_resh
|
||||
disabled_xformers = True
|
||||
|
||||
if disabled_xformers:
|
||||
return attention_pytorch(q, k, v, heads, mask, skip_reshape=skip_reshape, **kwargs)
|
||||
return attention_pytorch(q, k, v, heads, mask, skip_reshape=skip_reshape, var_length=var_length, **kwargs)
|
||||
|
||||
if skip_reshape:
|
||||
if var_length:
|
||||
if not skip_reshape:
|
||||
total_tokens, hidden_dim = q.shape
|
||||
dim_head = hidden_dim // heads
|
||||
q = q.view(1, total_tokens, heads, dim_head)
|
||||
k = k.view(1, total_tokens, heads, dim_head)
|
||||
v = v.view(1, total_tokens, heads, dim_head)
|
||||
else:
|
||||
if q.ndim == 3: q = q.unsqueeze(0)
|
||||
if k.ndim == 3: k = k.unsqueeze(0)
|
||||
if v.ndim == 3: v = v.unsqueeze(0)
|
||||
dim_head = q.shape[-1]
|
||||
|
||||
target_output_shape = (q.shape[1], -1)
|
||||
b = 1
|
||||
elif skip_reshape:
|
||||
# b h k d -> b k h d
|
||||
q, k, v = map(
|
||||
lambda t: t.permute(0, 2, 1, 3),
|
||||
@ -435,7 +462,11 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_resh
|
||||
(q, k, v),
|
||||
)
|
||||
|
||||
if mask is not None:
|
||||
if var_length:
|
||||
cu_seqlens_q, _, _, _ = var_attn_arg(kwargs)
|
||||
seq_lens = cu_seqlens_q[1:] - cu_seqlens_q[:-1]
|
||||
mask = xformers.ops.BlockDiagonalMask.from_seqlens(seq_lens_q=seq_lens, seq_lens_k=seq_lens)
|
||||
elif mask is not None:
|
||||
# add a singleton batch dimension
|
||||
if mask.ndim == 2:
|
||||
mask = mask.unsqueeze(0)
|
||||
@ -457,6 +488,8 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_resh
|
||||
|
||||
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=mask)
|
||||
|
||||
if var_length:
|
||||
return out.reshape(*target_output_shape)
|
||||
if skip_output_reshape:
|
||||
out = out.permute(0, 2, 1, 3)
|
||||
else:
|
||||
@ -475,9 +508,7 @@ else:
|
||||
@wrap_attn
|
||||
def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, var_length=False, **kwargs):
|
||||
if var_length:
|
||||
cu_seqlens_q = kwargs.get("cu_seqlens_q", None)
|
||||
cu_seqlens_k = kwargs.get("cu_seqlens_k", cu_seqlens_q)
|
||||
assert cu_seqlens_q != None, "cu_seqlens_q shouldn't be None when var_length is True"
|
||||
cu_seqlens_q, cu_seqlens_k, _, _ = var_attn_arg(kwargs)
|
||||
if not skip_reshape:
|
||||
# assumes 2D q, k,v [total_tokens, embed_dim]
|
||||
total_tokens, embed_dim = q.shape
|
||||
@ -539,9 +570,19 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
|
||||
return out
|
||||
|
||||
@wrap_attn
|
||||
def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
|
||||
def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, var_length = False, **kwargs):
|
||||
exception_fallback = False
|
||||
if skip_reshape:
|
||||
if var_length:
|
||||
if not skip_reshape:
|
||||
total_tokens, hidden_dim = q.shape
|
||||
dim_head = hidden_dim // heads
|
||||
q, k, v = [t.view(total_tokens, heads, dim_head) for t in (q, k, v)]
|
||||
b, _, dim_head = q.shape
|
||||
# skips batched code
|
||||
mask = None
|
||||
tensor_layout = "VAR"
|
||||
target_output_shape = (q.shape[0], -1)
|
||||
elif skip_reshape:
|
||||
b, _, _, dim_head = q.shape
|
||||
tensor_layout = "HND"
|
||||
else:
|
||||
@ -562,7 +603,14 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=
|
||||
mask = mask.unsqueeze(1)
|
||||
|
||||
try:
|
||||
out = sageattn(q, k, v, attn_mask=mask, is_causal=False, tensor_layout=tensor_layout)
|
||||
if var_length and not SAGE_ATTENTION_VAR_LENGTH_AVAILABLE:
|
||||
raise ValueError("Sage Attention two is required to run variable length attention.")
|
||||
elif var_length:
|
||||
cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k = var_attn_arg(kwargs)
|
||||
sm_scale = 1.0 / (q.shape[-1] ** 0.5)
|
||||
out = sageattn_varlen(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, is_causal=False, sm_scale=sm_scale)
|
||||
else:
|
||||
out = sageattn(q, k, v, attn_mask=mask, is_causal=False, tensor_layout=tensor_layout)
|
||||
except Exception as e:
|
||||
logging.error("Error running sage attention: {}, using pytorch attention instead.".format(e))
|
||||
exception_fallback = True
|
||||
@ -572,7 +620,7 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=
|
||||
lambda t: t.transpose(1, 2),
|
||||
(q, k, v),
|
||||
)
|
||||
return attention_pytorch(q, k, v, heads, mask=mask, skip_reshape=True, skip_output_reshape=skip_output_reshape, **kwargs)
|
||||
return attention_pytorch(q, k, v, heads, mask=mask, skip_reshape=True, skip_output_reshape=skip_output_reshape, var_length=var_length, **kwargs)
|
||||
|
||||
if tensor_layout == "HND":
|
||||
if not skip_output_reshape:
|
||||
@ -583,6 +631,8 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=
|
||||
if skip_output_reshape:
|
||||
out = out.transpose(1, 2)
|
||||
else:
|
||||
if var_length:
|
||||
return out.view(*target_output_shape)
|
||||
out = out.reshape(b, -1, heads * dim_head)
|
||||
return out
|
||||
|
||||
@ -608,12 +658,7 @@ except AttributeError as error:
|
||||
@wrap_attn
|
||||
def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, var_length=False, **kwargs):
|
||||
if var_length:
|
||||
cu_seqlens_q = kwargs.get("cu_seqlens_q", None)
|
||||
cu_seqlens_k = kwargs.get("cu_seqlens_k", cu_seqlens_q)
|
||||
max_seqlen_q = kwargs.get("max_seqlen_q", None)
|
||||
max_seqlen_k = kwargs.get("max_seqlen_k", max_seqlen_q)
|
||||
assert max_seqlen_q != None, "max_seqlen_q shouldn't be None when var_length is True"
|
||||
assert cu_seqlens_q != None, "cu_seqlens_q shouldn't be None when var_length is True"
|
||||
cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k = var_attn_arg(kwargs)
|
||||
return flash_attn_varlen_func(
|
||||
q=q, k=k, v=v,
|
||||
cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k,
|
||||
|
||||
@ -499,6 +499,8 @@ class InflatedCausalConv3d(ops.Conv3d):
|
||||
|
||||
def pad_and_forward():
|
||||
padded = safe_pad_operation(x_concat, padding, mode='constant', value=0.0)
|
||||
if not padded.is_contiguous():
|
||||
padded = padded.contiguous()
|
||||
with ignore_padding(self):
|
||||
return torch.nn.Conv3d.forward(self, padded)
|
||||
|
||||
@ -1726,7 +1728,7 @@ class VideoAutoencoderKL(nn.Module):
|
||||
return decoded
|
||||
|
||||
def _encode(
|
||||
self, x, memory_state
|
||||
self, x, memory_state = MemoryState.DISABLED
|
||||
) -> torch.Tensor:
|
||||
_x = x.to(self.device)
|
||||
h = self.encoder(_x, memory_state=memory_state)
|
||||
@ -1737,7 +1739,7 @@ class VideoAutoencoderKL(nn.Module):
|
||||
return output.to(x.device)
|
||||
|
||||
def _decode(
|
||||
self, z, memory_state
|
||||
self, z, memory_state = MemoryState.DISABLED
|
||||
) -> torch.Tensor:
|
||||
_z = z.to(self.device)
|
||||
|
||||
@ -1892,9 +1894,16 @@ class VideoAutoencoderKLWrapper(VideoAutoencoderKL):
|
||||
|
||||
# in case of padded frames
|
||||
t = input.size(0)
|
||||
x = x[:, :, :t]
|
||||
if t != 1:
|
||||
x = x[:, :, :t]
|
||||
if t == 1 and x.size(2) == 4:
|
||||
x = x[:, :, :t]
|
||||
|
||||
x = rearrange(x, "b c t h w -> (b t) c h w")
|
||||
if x.size(1) == 1:
|
||||
exp = "b t c h w -> (b t) c h w"
|
||||
else:
|
||||
exp = "b c t h w -> (b t) c h w"
|
||||
x = rearrange(x, exp)
|
||||
|
||||
input = input.to(x.device)
|
||||
x = wavelet_reconstruction(x, input)
|
||||
|
||||
@ -270,12 +270,11 @@ class SeedVR2InputProcessing(io.ComfyNode):
|
||||
inputs = [
|
||||
io.Image.Input("images"),
|
||||
io.Vae.Input("vae"),
|
||||
io.Int.Input("resolution_height", default = 1280, min = 120), # //
|
||||
io.Int.Input("resolution_width", default = 720, min = 120), # just non-zero value
|
||||
io.Int.Input("resolution", default = 1280, min = 120), # just non-zero value
|
||||
io.Int.Input("spatial_tile_size", default = 512, min = 1),
|
||||
io.Int.Input("temporal_tile_size", default = 8, min = 1),
|
||||
io.Int.Input("spatial_overlap", default = 64, min = 1),
|
||||
io.Boolean.Input("enable_tiling", default=False)
|
||||
io.Int.Input("temporal_tile_size", default = 8, min = 1),
|
||||
io.Boolean.Input("enable_tiling", default=False),
|
||||
],
|
||||
outputs = [
|
||||
io.Latent.Output("vae_conditioning")
|
||||
@ -283,7 +282,7 @@ class SeedVR2InputProcessing(io.ComfyNode):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, images, vae, resolution_height, resolution_width, spatial_tile_size, temporal_tile_size, spatial_overlap, enable_tiling):
|
||||
def execute(cls, images, vae, resolution, spatial_tile_size, temporal_tile_size, spatial_overlap, enable_tiling):
|
||||
device = vae.patcher.load_device
|
||||
|
||||
offload_device = comfy.model_management.intermediate_device()
|
||||
@ -298,11 +297,9 @@ class SeedVR2InputProcessing(io.ComfyNode):
|
||||
b, t, c, h, w = images.shape
|
||||
images = images.reshape(b * t, c, h, w)
|
||||
|
||||
#max_area = ((resolution_height * resolution_width)** 0.5) ** 2
|
||||
clip = Lambda(lambda x: torch.clamp(x, 0.0, 1.0))
|
||||
normalize = Normalize(0.5, 0.5)
|
||||
#images = area_resize(images, max_area)
|
||||
images = side_resize(images, resolution_height)
|
||||
images = side_resize(images, resolution)
|
||||
|
||||
images = clip(images)
|
||||
o_h, o_w = images.shape[-2:]
|
||||
@ -317,6 +314,17 @@ class SeedVR2InputProcessing(io.ComfyNode):
|
||||
images = images.to(device)
|
||||
vae_model = vae_model.to(device)
|
||||
|
||||
# in case users a non-compatiable number for tiling
|
||||
def make_divisible(val, divisor):
|
||||
return max(divisor, round(val / divisor) * divisor)
|
||||
|
||||
temporal_tile_size = make_divisible(temporal_tile_size, 4)
|
||||
spatial_tile_size = make_divisible(spatial_tile_size, 32)
|
||||
spatial_overlap = make_divisible(spatial_overlap, 32)
|
||||
|
||||
if spatial_overlap >= spatial_tile_size:
|
||||
spatial_overlap = max(0, spatial_tile_size - 8)
|
||||
|
||||
args = {"tile_size": (spatial_tile_size, spatial_tile_size), "tile_overlap": (spatial_overlap, spatial_overlap),
|
||||
"temporal_size":temporal_tile_size}
|
||||
if enable_tiling:
|
||||
|
||||
Loading…
Reference in New Issue
Block a user