mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-07 21:00:49 +08:00
rope, attetntion update | vae on cpu warning
This commit is contained in:
parent
49febe15c3
commit
31d358c78c
@ -416,7 +416,8 @@ except:
|
||||
pass
|
||||
|
||||
@wrap_attn
|
||||
def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, var_length=False, **kwargs):
|
||||
def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
|
||||
var_length = kwargs.get("var_length", False)
|
||||
b = q.shape[0]
|
||||
dim_head = q.shape[-1]
|
||||
# check to make sure xformers isn't broken
|
||||
@ -506,7 +507,8 @@ else:
|
||||
SDP_BATCH_LIMIT = 2**31
|
||||
|
||||
@wrap_attn
|
||||
def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, var_length=False, **kwargs):
|
||||
def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
|
||||
var_length = kwargs.get("var_length", False)
|
||||
if var_length:
|
||||
cu_seqlens_q, cu_seqlens_k, _, _ = var_attn_arg(kwargs)
|
||||
if not skip_reshape:
|
||||
@ -570,7 +572,8 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
|
||||
return out
|
||||
|
||||
@wrap_attn
|
||||
def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, var_length = False, **kwargs):
|
||||
def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
|
||||
var_length = kwargs.get("var_length", False)
|
||||
exception_fallback = False
|
||||
if var_length:
|
||||
if not skip_reshape:
|
||||
@ -656,7 +659,8 @@ except AttributeError as error:
|
||||
assert False, f"Could not define flash_attn_wrapper: {FLASH_ATTN_ERROR}"
|
||||
|
||||
@wrap_attn
|
||||
def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, var_length=False, **kwargs):
|
||||
def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
|
||||
var_length = kwargs.get("var_length", False)
|
||||
if var_length:
|
||||
cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k = var_attn_arg(kwargs)
|
||||
return flash_attn_varlen_func(
|
||||
|
||||
@ -13,6 +13,7 @@ from comfy.rmsnorm import RMSNorm
|
||||
from torch.nn.modules.utils import _triple
|
||||
from torch import nn
|
||||
import math
|
||||
from comfy.ldm.flux.math import apply_rope1
|
||||
|
||||
class Cache:
|
||||
def __init__(self, disable=False, prefix="", cache=None):
|
||||
@ -443,7 +444,6 @@ def apply_rotary_emb(
|
||||
freqs_seq_dim = None
|
||||
):
|
||||
dtype = t.dtype
|
||||
|
||||
if not exists(freqs_seq_dim):
|
||||
if freqs.ndim == 2 or t.ndim == 3:
|
||||
freqs_seq_dim = 0
|
||||
@ -452,20 +452,23 @@ def apply_rotary_emb(
|
||||
seq_len = t.shape[seq_dim]
|
||||
freqs = slice_at_dim(freqs, slice(-seq_len, None), dim = freqs_seq_dim)
|
||||
|
||||
rot_dim = freqs.shape[-1]
|
||||
end_index = start_index + rot_dim
|
||||
|
||||
assert rot_dim <= t.shape[-1], f'feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}'
|
||||
rot_feats = freqs.shape[-1]
|
||||
end_index = start_index + rot_feats
|
||||
|
||||
t_left = t[..., :start_index]
|
||||
t_middle = t[..., start_index:end_index]
|
||||
t_right = t[..., end_index:]
|
||||
|
||||
freqs = freqs.to(t_middle.device)
|
||||
t_transformed = (t_middle * freqs.cos() * scale) + (rotate_half(t_middle) * freqs.sin() * scale)
|
||||
angles = freqs.to(t_middle.device)[..., ::2]
|
||||
cos = torch.cos(angles) * scale
|
||||
sin = torch.sin(angles) * scale
|
||||
|
||||
out = torch.cat((t_left, t_transformed, t_right), dim=-1)
|
||||
col0 = torch.stack([cos, sin], dim=-1)
|
||||
col1 = torch.stack([-sin, cos], dim=-1)
|
||||
freqs_mat = torch.stack([col0, col1], dim=-1)
|
||||
|
||||
t_middle_out = apply_rope1(t_middle, freqs_mat)
|
||||
out = torch.cat((t_left, t_middle_out, t_right), dim=-1)
|
||||
return out.type(dtype)
|
||||
|
||||
class NaMMRotaryEmbedding3d(MMRotaryEmbeddingBase):
|
||||
|
||||
@ -16,6 +16,7 @@ import math
|
||||
from enum import Enum
|
||||
from comfy.ops import NVIDIA_MEMORY_CONV_BUG_WORKAROUND
|
||||
|
||||
import logging
|
||||
import comfy.ops
|
||||
ops = comfy.ops.disable_weight_init
|
||||
|
||||
@ -446,6 +447,7 @@ class InflatedCausalConv3d(ops.Conv3d):
|
||||
self.memory_device = memory_device
|
||||
self.padding = (0, *self.padding[1:])
|
||||
self.memory_limit = float("inf")
|
||||
self.logged_once = False
|
||||
|
||||
def set_memory_limit(self, value: float):
|
||||
self.memory_limit = value
|
||||
@ -469,8 +471,16 @@ class InflatedCausalConv3d(ops.Conv3d):
|
||||
return out
|
||||
except RuntimeError:
|
||||
pass
|
||||
|
||||
return super()._conv_forward(input, weight, bias, *args, **kwargs)
|
||||
except NotImplementedError:
|
||||
pass
|
||||
try:
|
||||
return super()._conv_forward(input, weight, bias, *args, **kwargs)
|
||||
except NotImplementedError:
|
||||
# for: Could not run 'aten::cudnn_convolution' with arguments from the 'CPU' backend
|
||||
if not self.logged_once:
|
||||
logging.warning("VAE is on CPU for decoding. This is most likely due to being not enough memory")
|
||||
self.logged_once = True
|
||||
return F.conv3d(input, weight, bias, *args, **kwargs)
|
||||
|
||||
def memory_limit_conv(
|
||||
self,
|
||||
|
||||
@ -382,8 +382,8 @@ class VAE:
|
||||
self.latent_channels = 16
|
||||
elif "decoder.up_blocks.2.upsamplers.0.upscale_conv.weight" in sd: # seedvr2
|
||||
self.first_stage_model = comfy.ldm.seedvr.vae.VideoAutoencoderKLWrapper()
|
||||
self.memory_used_decode = lambda shape, dtype: (10 * shape[1] * shape[2] * shape[3] * (4 * 8 * 8)) * model_management.dtype_size(dtype)
|
||||
self.memory_used_encode = lambda shape, dtype: (10 * max(shape[1], 5) * shape[2] * shape[3]) * model_management.dtype_size(dtype)
|
||||
self.memory_used_decode = lambda shape, dtype: (shape[1] * shape[2] * shape[3] * (4 * 8 * 8)) * model_management.dtype_size(dtype)
|
||||
self.memory_used_encode = lambda shape, dtype: (max(shape[1], 5) * shape[2] * shape[3]) * model_management.dtype_size(dtype)
|
||||
self.working_dtypes = [torch.bfloat16, torch.float32]
|
||||
self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 8, 8)
|
||||
self.downscale_index_formula = (4, 8, 8)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user