commit 9b573da39b
parent 4d7012ecda

    added other types of attention + compatibility with images
@@ -19,9 +19,15 @@ if model_management.xformers_enabled():
     import xformers.ops

 SAGE_ATTENTION_IS_AVAILABLE = False
+SAGE_ATTENTION_VAR_LENGTH_AVAILABLE = False
 try:
     from sageattention import sageattn
     SAGE_ATTENTION_IS_AVAILABLE = True
+    try:
+        from sageattention import sageattn_varlen
+        SAGE_ATTENTION_VAR_LENGTH_AVAILABLE = True
+    except:
+        pass
 except ImportError as e:
     if model_management.sage_attention_enabled():
         if e.name == "sageattention":
@@ -80,7 +86,13 @@ def default(val, d):
         return val
     return d

+def var_attn_arg(kwargs):
+    cu_seqlens_q = kwargs.get("cu_seqlens_q", None)
+    cu_seqlens_k = kwargs.get("cu_seqlens_k", cu_seqlens_q)
+    max_seqlen_q = kwargs.get("max_seqlen_q", None)
+    max_seqlen_k = kwargs.get("max_seqlen_k", max_seqlen_q)
+    assert cu_seqlens_q != None, "cu_seqlens_q shouldn't be None when var_length is True"
+    return cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k
 # feedforward
 class GEGLU(nn.Module):
     def __init__(self, dim_in, dim_out, dtype=None, device=None, operations=ops):
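For orientation (not part of the commit): the kwargs that `var_attn_arg` reads follow the usual packed/varlen convention, where `cu_seqlens_*` are cumulative token offsets with a leading zero and `max_seqlen_*` is the longest per-sample length. A minimal sketch of how a caller might assemble them; `build_varlen_args` is a hypothetical helper, not something defined in this patch:

```python
# Hypothetical helper illustrating the kwargs var_attn_arg expects; not part of the commit.
import torch

def build_varlen_args(seq_lens, device="cpu"):
    # seq_lens: per-sample token counts, e.g. [128, 77, 256]
    lens = torch.tensor(seq_lens, dtype=torch.int32, device=device)
    # cumulative offsets with a leading zero, e.g. [0, 128, 205, 461]
    cu = torch.nn.functional.pad(torch.cumsum(lens, dim=0, dtype=torch.int32), (1, 0))
    return {
        "cu_seqlens_q": cu, "cu_seqlens_k": cu,
        "max_seqlen_q": int(lens.max()), "max_seqlen_k": int(lens.max()),
    }
```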
@@ -404,14 +416,14 @@ except:
     pass

 @wrap_attn
-def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
+def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, var_length=False, **kwargs):
     b = q.shape[0]
     dim_head = q.shape[-1]
     # check to make sure xformers isn't broken
     disabled_xformers = False

     if BROKEN_XFORMERS:
-        if b * heads > 65535:
+        if b * heads > 65535 and not var_length:
             disabled_xformers = True

     if not disabled_xformers:
@@ -419,9 +431,24 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_resh
         disabled_xformers = True

     if disabled_xformers:
-        return attention_pytorch(q, k, v, heads, mask, skip_reshape=skip_reshape, **kwargs)
+        return attention_pytorch(q, k, v, heads, mask, skip_reshape=skip_reshape, var_length=var_length, **kwargs)

-    if skip_reshape:
+    if var_length:
+        if not skip_reshape:
+            total_tokens, hidden_dim = q.shape
+            dim_head = hidden_dim // heads
+            q = q.view(1, total_tokens, heads, dim_head)
+            k = k.view(1, total_tokens, heads, dim_head)
+            v = v.view(1, total_tokens, heads, dim_head)
+        else:
+            if q.ndim == 3: q = q.unsqueeze(0)
+            if k.ndim == 3: k = k.unsqueeze(0)
+            if v.ndim == 3: v = v.unsqueeze(0)
+            dim_head = q.shape[-1]
+
+        target_output_shape = (q.shape[1], -1)
+        b = 1
+    elif skip_reshape:
         # b h k d -> b k h d
         q, k, v = map(
             lambda t: t.permute(0, 2, 1, 3),
@@ -435,7 +462,11 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_resh
             (q, k, v),
         )

-    if mask is not None:
+    if var_length:
+        cu_seqlens_q, _, _, _ = var_attn_arg(kwargs)
+        seq_lens = cu_seqlens_q[1:] - cu_seqlens_q[:-1]
+        mask = xformers.ops.BlockDiagonalMask.from_seqlens(seq_lens_q=seq_lens, seq_lens_k=seq_lens)
+    elif mask is not None:
         # add a singleton batch dimension
         if mask.ndim == 2:
             mask = mask.unsqueeze(0)
@@ -457,6 +488,8 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_resh

     out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=mask)

+    if var_length:
+        return out.reshape(*target_output_shape)
     if skip_output_reshape:
         out = out.permute(0, 2, 1, 3)
     else:
@@ -475,9 +508,7 @@ else:
 @wrap_attn
 def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, var_length=False, **kwargs):
     if var_length:
-        cu_seqlens_q = kwargs.get("cu_seqlens_q", None)
-        cu_seqlens_k = kwargs.get("cu_seqlens_k", cu_seqlens_q)
-        assert cu_seqlens_q != None, "cu_seqlens_q shouldn't be None when var_length is True"
+        cu_seqlens_q, cu_seqlens_k, _, _ = var_attn_arg(kwargs)
         if not skip_reshape:
             # assumes 2D q, k,v [total_tokens, embed_dim]
             total_tokens, embed_dim = q.shape
@@ -539,9 +570,19 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
     return out

 @wrap_attn
-def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
+def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, var_length = False, **kwargs):
     exception_fallback = False
-    if skip_reshape:
+    if var_length:
+        if not skip_reshape:
+            total_tokens, hidden_dim = q.shape
+            dim_head = hidden_dim // heads
+            q, k, v = [t.view(total_tokens, heads, dim_head) for t in (q, k, v)]
+        b, _, dim_head = q.shape
+        # skips batched code
+        mask = None
+        tensor_layout = "VAR"
+        target_output_shape = (q.shape[0], -1)
+    elif skip_reshape:
         b, _, _, dim_head = q.shape
         tensor_layout = "HND"
     else:
@@ -562,7 +603,14 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=
             mask = mask.unsqueeze(1)

     try:
-        out = sageattn(q, k, v, attn_mask=mask, is_causal=False, tensor_layout=tensor_layout)
+        if var_length and not SAGE_ATTENTION_VAR_LENGTH_AVAILABLE:
+            raise ValueError("Sage Attention two is required to run variable length attention.")
+        elif var_length:
+            cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k = var_attn_arg(kwargs)
+            sm_scale = 1.0 / (q.shape[-1] ** 0.5)
+            out = sageattn_varlen(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, is_causal=False, sm_scale=sm_scale)
+        else:
+            out = sageattn(q, k, v, attn_mask=mask, is_causal=False, tensor_layout=tensor_layout)
     except Exception as e:
         logging.error("Error running sage attention: {}, using pytorch attention instead.".format(e))
         exception_fallback = True
@@ -572,7 +620,7 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=
                 lambda t: t.transpose(1, 2),
                 (q, k, v),
             )
-        return attention_pytorch(q, k, v, heads, mask=mask, skip_reshape=True, skip_output_reshape=skip_output_reshape, **kwargs)
+        return attention_pytorch(q, k, v, heads, mask=mask, skip_reshape=True, skip_output_reshape=skip_output_reshape, var_length=var_length, **kwargs)

     if tensor_layout == "HND":
         if not skip_output_reshape:
@@ -583,6 +631,8 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=
         if skip_output_reshape:
             out = out.transpose(1, 2)
         else:
+            if var_length:
+                return out.view(*target_output_shape)
             out = out.reshape(b, -1, heads * dim_head)
     return out

@@ -608,12 +658,7 @@ except AttributeError as error:
 @wrap_attn
 def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, var_length=False, **kwargs):
     if var_length:
-        cu_seqlens_q = kwargs.get("cu_seqlens_q", None)
-        cu_seqlens_k = kwargs.get("cu_seqlens_k", cu_seqlens_q)
-        max_seqlen_q = kwargs.get("max_seqlen_q", None)
-        max_seqlen_k = kwargs.get("max_seqlen_k", max_seqlen_q)
-        assert max_seqlen_q != None, "max_seqlen_q shouldn't be None when var_length is True"
-        assert cu_seqlens_q != None, "cu_seqlens_q shouldn't be None when var_length is True"
+        cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k = var_attn_arg(kwargs)
         return flash_attn_varlen_func(
             q=q, k=k, v=v,
             cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k,
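All of the `var_length` branches above implement the same contract: several sequences are packed along one token dimension, and the kernel is told via `cu_seqlens` where each sequence starts and ends, so attention never crosses a boundary. As a rough reference for what that means (illustrative plain PyTorch, not how the patched kernels execute internally):

```python
# Reference semantics of packed / variable-length attention; illustrative only.
import torch
import torch.nn.functional as F

def packed_attention_reference(q, k, v, heads, cu_seqlens):
    # q, k, v: [total_tokens, heads * dim_head]; cu_seqlens: int32 tensor [0, n1, n1 + n2, ...]
    total_tokens, hidden = q.shape
    dim_head = hidden // heads
    outs = []
    for start, end in zip(cu_seqlens[:-1].tolist(), cu_seqlens[1:].tolist()):
        # slice one sequence and reshape to [1, heads, seq_len, dim_head]
        qs, ks, vs = (t[start:end].view(1, end - start, heads, dim_head).transpose(1, 2)
                      for t in (q, k, v))
        o = F.scaled_dot_product_attention(qs, ks, vs)  # attention stays inside this sequence
        outs.append(o.transpose(1, 2).reshape(end - start, hidden))
    return torch.cat(outs, dim=0)  # packed output, [total_tokens, heads * dim_head]
```

The patched functions obtain the same result from a single fused call (flash-attn's `flash_attn_varlen_func`, SageAttention's `sageattn_varlen`, or xformers with a block-diagonal bias) instead of a Python loop.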
@@ -499,6 +499,8 @@ class InflatedCausalConv3d(ops.Conv3d):

         def pad_and_forward():
             padded = safe_pad_operation(x_concat, padding, mode='constant', value=0.0)
+            if not padded.is_contiguous():
+                padded = padded.contiguous()
             with ignore_padding(self):
                 return torch.nn.Conv3d.forward(self, padded)

@@ -1726,7 +1728,7 @@ class VideoAutoencoderKL(nn.Module):
         return decoded

     def _encode(
-        self, x, memory_state
+        self, x, memory_state = MemoryState.DISABLED
     ) -> torch.Tensor:
         _x = x.to(self.device)
         h = self.encoder(_x, memory_state=memory_state)
@@ -1737,7 +1739,7 @@ class VideoAutoencoderKL(nn.Module):
         return output.to(x.device)

     def _decode(
-        self, z, memory_state
+        self, z, memory_state = MemoryState.DISABLED
     ) -> torch.Tensor:
         _z = z.to(self.device)

@@ -1892,9 +1894,16 @@ class VideoAutoencoderKLWrapper(VideoAutoencoderKL):

         # in case of padded frames
         t = input.size(0)
-        x = x[:, :, :t]
+        if t != 1:
+            x = x[:, :, :t]
+        if t == 1 and x.size(2) == 4:
+            x = x[:, :, :t]

-        x = rearrange(x, "b c t h w -> (b t) c h w")
+        if x.size(1) == 1:
+            exp = "b t c h w -> (b t) c h w"
+        else:
+            exp = "b c t h w -> (b t) c h w"
+        x = rearrange(x, exp)

         input = input.to(x.device)
         x = wavelet_reconstruction(x, input)
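The layout switch above appears to be what the commit title calls compatibility with images: when dim 1 is a singleton, the tensor is presumably a single-frame `b t c h w` batch, otherwise it is treated as video-style `b c t h w`. A quick shape check of the two patterns (made-up sizes, not from the commit):

```python
import torch
from einops import rearrange

x_image = torch.randn(2, 1, 3, 64, 64)   # b t c h w with t == 1 -> x.size(1) == 1
print(rearrange(x_image, "b t c h w -> (b t) c h w").shape)   # torch.Size([2, 3, 64, 64])

x_video = torch.randn(1, 16, 8, 64, 64)  # b c t h w
print(rearrange(x_video, "b c t h w -> (b t) c h w").shape)   # torch.Size([8, 16, 64, 64])
```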
@@ -270,12 +270,11 @@ class SeedVR2InputProcessing(io.ComfyNode):
             inputs = [
                 io.Image.Input("images"),
                 io.Vae.Input("vae"),
-                io.Int.Input("resolution_height", default = 1280, min = 120), # //
-                io.Int.Input("resolution_width", default = 720, min = 120), # just non-zero value
+                io.Int.Input("resolution", default = 1280, min = 120), # just non-zero value
                 io.Int.Input("spatial_tile_size", default = 512, min = 1),
-                io.Int.Input("temporal_tile_size", default = 8, min = 1),
                 io.Int.Input("spatial_overlap", default = 64, min = 1),
-                io.Boolean.Input("enable_tiling", default=False)
+                io.Int.Input("temporal_tile_size", default = 8, min = 1),
+                io.Boolean.Input("enable_tiling", default=False),
             ],
             outputs = [
                 io.Latent.Output("vae_conditioning")
@@ -283,7 +282,7 @@ class SeedVR2InputProcessing(io.ComfyNode):
         )

     @classmethod
-    def execute(cls, images, vae, resolution_height, resolution_width, spatial_tile_size, temporal_tile_size, spatial_overlap, enable_tiling):
+    def execute(cls, images, vae, resolution, spatial_tile_size, temporal_tile_size, spatial_overlap, enable_tiling):
         device = vae.patcher.load_device

         offload_device = comfy.model_management.intermediate_device()
@@ -298,11 +297,9 @@ class SeedVR2InputProcessing(io.ComfyNode):
         b, t, c, h, w = images.shape
         images = images.reshape(b * t, c, h, w)

-        #max_area = ((resolution_height * resolution_width)** 0.5) ** 2
         clip = Lambda(lambda x: torch.clamp(x, 0.0, 1.0))
         normalize = Normalize(0.5, 0.5)
-        #images = area_resize(images, max_area)
-        images = side_resize(images, resolution_height)
+        images = side_resize(images, resolution)

         images = clip(images)
         o_h, o_w = images.shape[-2:]
@@ -317,6 +314,17 @@ class SeedVR2InputProcessing(io.ComfyNode):
         images = images.to(device)
         vae_model = vae_model.to(device)

+        # in case the user passes a non-compatible number for tiling
+        def make_divisible(val, divisor):
+            return max(divisor, round(val / divisor) * divisor)
+
+        temporal_tile_size = make_divisible(temporal_tile_size, 4)
+        spatial_tile_size = make_divisible(spatial_tile_size, 32)
+        spatial_overlap = make_divisible(spatial_overlap, 32)
+
+        if spatial_overlap >= spatial_tile_size:
+            spatial_overlap = max(0, spatial_tile_size - 8)
+
         args = {"tile_size": (spatial_tile_size, spatial_tile_size), "tile_overlap": (spatial_overlap, spatial_overlap),
                 "temporal_size":temporal_tile_size}
         if enable_tiling:
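For reference, the snapping helper added above rounds to the nearest multiple of the divisor and never returns less than one divisor, so (with made-up inputs):

```python
def make_divisible(val, divisor):
    return max(divisor, round(val / divisor) * divisor)

print(make_divisible(500, 32))  # 512
print(make_divisible(70, 32))   # 64
print(make_divisible(7, 4))     # 8
print(make_divisible(2, 4))     # 4 (clamped up to the divisor)
```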