Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2026-01-11 06:40:48 +08:00)

rope, attention update | vae on cpu warning

Commit 31d358c78c (parent 49febe15c3)
@@ -416,7 +416,8 @@ except:
     pass

 @wrap_attn
-def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, var_length=False, **kwargs):
+def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
+    var_length = kwargs.get("var_length", False)
     b = q.shape[0]
     dim_head = q.shape[-1]
     # check to make sure xformers isn't broken
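The same signature change repeats for every attention backend in this file: var_length leaves the explicit parameter list and is read lazily from **kwargs, so the wrap_attn layer can forward backend-specific options without each implementation declaring them. A minimal sketch of that pattern, with illustrative names rather than ComfyUI's actual wrapper logic:

# Sketch: optional backend flags travel through **kwargs and are read with .get(),
# so a dispatcher can pass them uniformly to any backend. Names are illustrative.
import functools

def wrap_attn(fn):
    @functools.wraps(fn)
    def inner(q, k, v, heads, **kwargs):
        # a real wrapper could log, override, or strip options here
        return fn(q, k, v, heads, **kwargs)
    return inner

@wrap_attn
def attention_backend(q, k, v, heads, mask=None, **kwargs):
    var_length = kwargs.get("var_length", False)   # absent -> default False
    if var_length:
        return f"varlen path, cu_seqlens={kwargs.get('cu_seqlens_q')}"
    return "fixed-length path"

print(attention_backend(None, None, None, 8))                                          # fixed-length path
print(attention_backend(None, None, None, 8, var_length=True, cu_seqlens_q=[0, 3, 7]))  # varlen path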
@@ -506,7 +507,8 @@ else:
     SDP_BATCH_LIMIT = 2**31

 @wrap_attn
-def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, var_length=False, **kwargs):
+def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
+    var_length = kwargs.get("var_length", False)
     if var_length:
         cu_seqlens_q, cu_seqlens_k, _, _ = var_attn_arg(kwargs)
     if not skip_reshape:
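For context on the fixed-length branch: attention_pytorch ultimately hands (batch, heads, seq, dim_head) tensors to torch.nn.functional.scaled_dot_product_attention. A rough sketch of that reshape-and-call, simplified and not the exact ComfyUI implementation:

# Sketch of the reshape around F.scaled_dot_product_attention used by the PyTorch backend:
# (b, n, heads*dim_head) -> (b, heads, n, dim_head) -> SDPA -> back. Simplified.
import torch
import torch.nn.functional as F

def sdpa_attention(q, k, v, heads, mask=None):
    b, n, inner_dim = q.shape
    dim_head = inner_dim // heads
    q, k, v = (t.view(b, -1, heads, dim_head).transpose(1, 2) for t in (q, k, v))
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
    return out.transpose(1, 2).reshape(b, -1, heads * dim_head)

q = k = v = torch.randn(2, 77, 8 * 64)
print(sdpa_attention(q, k, v, heads=8).shape)  # torch.Size([2, 77, 512])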
@@ -570,7 +572,8 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
     return out

 @wrap_attn
-def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, var_length = False, **kwargs):
+def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
+    var_length = kwargs.get("var_length", False)
     exception_fallback = False
     if var_length:
         if not skip_reshape:
@@ -656,7 +659,8 @@ except AttributeError as error:
     assert False, f"Could not define flash_attn_wrapper: {FLASH_ATTN_ERROR}"

 @wrap_attn
-def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, var_length=False, **kwargs):
+def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
+    var_length = kwargs.get("var_length", False)
     if var_length:
         cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k = var_attn_arg(kwargs)
         return flash_attn_varlen_func(
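In the var_length branch, all sequences are packed along one token dimension and described by cumulative sequence lengths; per the diff, var_attn_arg supplies cu_seqlens_q, cu_seqlens_k, max_seqlen_q and max_seqlen_k, which map directly onto flash_attn_varlen_func. A sketch of how such cu_seqlens are typically built (everything beyond the names in the diff is an assumption):

# Sketch: building cumulative sequence lengths for a packed (varlen) attention call.
# flash_attn_varlen_func expects q/k/v packed as (total_tokens, heads, dim_head) plus
# cu_seqlens boundaries; the call is guarded since flash-attn may not be installed.
import torch
import torch.nn.functional as F

seqlens = torch.tensor([3, 7, 5], dtype=torch.int32)                    # per-sample lengths
cu_seqlens = F.pad(torch.cumsum(seqlens, 0, dtype=torch.int32), (1, 0))
max_seqlen = int(seqlens.max())
print(cu_seqlens.tolist(), max_seqlen)                                  # [0, 3, 10, 15] 7

try:
    from flash_attn import flash_attn_varlen_func                       # optional dependency
    heads, dim_head = 8, 64
    total = int(seqlens.sum())
    q = k = v = torch.randn(total, heads, dim_head, device="cuda", dtype=torch.float16)
    out = flash_attn_varlen_func(q, k, v, cu_seqlens.cuda(), cu_seqlens.cuda(),
                                 max_seqlen, max_seqlen)
    print(out.shape)                                                     # torch.Size([15, 8, 64])
except Exception:
    pass                                                                 # flash-attn or CUDA not available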
@@ -13,6 +13,7 @@ from comfy.rmsnorm import RMSNorm
 from torch.nn.modules.utils import _triple
 from torch import nn
 import math
+from comfy.ldm.flux.math import apply_rope1

 class Cache:
     def __init__(self, disable=False, prefix="", cache=None):
@@ -443,7 +444,6 @@ def apply_rotary_emb(
     freqs_seq_dim = None
 ):
     dtype = t.dtype

     if not exists(freqs_seq_dim):
         if freqs.ndim == 2 or t.ndim == 3:
             freqs_seq_dim = 0
@@ -452,20 +452,23 @@ def apply_rotary_emb(
     seq_len = t.shape[seq_dim]
     freqs = slice_at_dim(freqs, slice(-seq_len, None), dim = freqs_seq_dim)

-    rot_dim = freqs.shape[-1]
-    end_index = start_index + rot_dim
+    rot_feats = freqs.shape[-1]
+    end_index = start_index + rot_feats

-    assert rot_dim <= t.shape[-1], f'feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}'

     t_left = t[..., :start_index]
     t_middle = t[..., start_index:end_index]
     t_right = t[..., end_index:]

-    freqs = freqs.to(t_middle.device)
-    t_transformed = (t_middle * freqs.cos() * scale) + (rotate_half(t_middle) * freqs.sin() * scale)
+    angles = freqs.to(t_middle.device)[..., ::2]
+    cos = torch.cos(angles) * scale
+    sin = torch.sin(angles) * scale

-    out = torch.cat((t_left, t_transformed, t_right), dim=-1)
+    col0 = torch.stack([cos, sin], dim=-1)
+    col1 = torch.stack([-sin, cos], dim=-1)
+    freqs_mat = torch.stack([col0, col1], dim=-1)

+    t_middle_out = apply_rope1(t_middle, freqs_mat)
+    out = torch.cat((t_left, t_middle_out, t_right), dim=-1)
     return out.type(dtype)

 class NaMMRotaryEmbedding3d(MMRotaryEmbeddingBase):
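The apply_rotary_emb rewrite deduplicates the per-pair angles (each angle was previously repeated for both features of a pair, hence freqs[..., ::2]), packs cos/sin into explicit 2x2 rotation matrices, and lets apply_rope1 from comfy.ldm.flux.math rotate the middle slice pairwise. A toy equivalence check of the old rotate_half formulation against the new matrix form (scale omitted, i.e. scale = 1); apply_rope1's contract is assumed here and stubbed locally:

# Toy check, assuming apply_rope1 consumes (..., d/2, 2, 2) rotation matrices and
# pairs features as interleaved (even, odd) couples; a local stand-in is used.
import torch

def rotate_half(x):                        # interleaved-pair version, as in the old path
    x = x.unflatten(-1, (-1, 2))           # (..., d/2, 2)
    x1, x2 = x.unbind(-1)
    return torch.stack((-x2, x1), dim=-1).flatten(-2)

def apply_rope1_standin(x, freqs_mat):     # hypothetical stand-in for comfy.ldm.flux.math.apply_rope1
    x_ = x.reshape(*x.shape[:-1], -1, 1, 2)
    out = freqs_mat[..., 0] * x_[..., 0] + freqs_mat[..., 1] * x_[..., 1]
    return out.reshape(*x.shape)

t = torch.randn(2, 8, 16)                  # (batch, seq, dim)
angles = torch.randn(8, 8)                 # one angle per feature pair
freqs = angles.repeat_interleave(2, -1)    # duplicated angles, old layout

# old formulation
old = t * freqs.cos() + rotate_half(t) * freqs.sin()

# new formulation: pack 2x2 rotation matrices and apply them pairwise
cos, sin = angles.cos(), angles.sin()
col0 = torch.stack([cos, sin], dim=-1)
col1 = torch.stack([-sin, cos], dim=-1)
freqs_mat = torch.stack([col0, col1], dim=-1)
new = apply_rope1_standin(t, freqs_mat)

print(torch.allclose(old, new, atol=1e-6))  # True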
@@ -16,6 +16,7 @@ import math
 from enum import Enum
 from comfy.ops import NVIDIA_MEMORY_CONV_BUG_WORKAROUND

+import logging
 import comfy.ops
 ops = comfy.ops.disable_weight_init

@@ -446,6 +447,7 @@ class InflatedCausalConv3d(ops.Conv3d):
         self.memory_device = memory_device
         self.padding = (0, *self.padding[1:])
         self.memory_limit = float("inf")
+        self.logged_once = False

     def set_memory_limit(self, value: float):
         self.memory_limit = value
@@ -469,8 +471,16 @@ class InflatedCausalConv3d(ops.Conv3d):
                 return out
             except RuntimeError:
                 pass
+            except NotImplementedError:
+                pass
+        try:
             return super()._conv_forward(input, weight, bias, *args, **kwargs)
+        except NotImplementedError:
+            # for: Could not run 'aten::cudnn_convolution' with arguments from the 'CPU' backend
+            if not self.logged_once:
+                logging.warning("VAE is on CPU for decoding. This is most likely due to being not enough memory")
+                self.logged_once = True
+            return F.conv3d(input, weight, bias, *args, **kwargs)

     def memory_limit_conv(
         self,
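The new except branch catches the NotImplementedError raised when a cuDNN-bound convolution is reached with CPU tensors (for example, the VAE was pushed to CPU for lack of VRAM), warns once, and reruns the op through F.conv3d. A simplified, standalone illustration of the same warn-once-and-fall-back pattern, not the actual InflatedCausalConv3d class:

# Minimal illustration of "warn once, then fall back": if the inherited conv path raises
# NotImplementedError, log a single warning and run the plain F.conv3d kernel instead.
import logging
import torch
import torch.nn as nn
import torch.nn.functional as F

class FallbackConv3d(nn.Conv3d):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.logged_once = False

    def _conv_forward(self, input, weight, bias):
        try:
            return super()._conv_forward(input, weight, bias)
        except NotImplementedError:
            if not self.logged_once:
                logging.warning("conv fell back to F.conv3d on CPU; likely not enough GPU memory")
                self.logged_once = True
            return F.conv3d(input, weight, bias, self.stride, self.padding,
                            self.dilation, self.groups)

conv = FallbackConv3d(3, 8, kernel_size=3, padding=1)
out = conv(torch.randn(1, 3, 4, 16, 16))
print(out.shape)  # torch.Size([1, 8, 4, 16, 16])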
@@ -382,8 +382,8 @@ class VAE:
             self.latent_channels = 16
         elif "decoder.up_blocks.2.upsamplers.0.upscale_conv.weight" in sd: # seedvr2
             self.first_stage_model = comfy.ldm.seedvr.vae.VideoAutoencoderKLWrapper()
-            self.memory_used_decode = lambda shape, dtype: (10 * shape[1] * shape[2] * shape[3] * (4 * 8 * 8)) * model_management.dtype_size(dtype)
-            self.memory_used_encode = lambda shape, dtype: (10 * max(shape[1], 5) * shape[2] * shape[3]) * model_management.dtype_size(dtype)
+            self.memory_used_decode = lambda shape, dtype: (shape[1] * shape[2] * shape[3] * (4 * 8 * 8)) * model_management.dtype_size(dtype)
+            self.memory_used_encode = lambda shape, dtype: (max(shape[1], 5) * shape[2] * shape[3]) * model_management.dtype_size(dtype)
             self.working_dtypes = [torch.bfloat16, torch.float32]
             self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 8, 8)
             self.downscale_index_formula = (4, 8, 8)
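The seedvr2 estimates drop the factor of 10, so decode memory is now roughly latent extent times the pixel expansion factor (4 * 8 * 8) times the element size. A rough evaluation of the new formula; the latent axis layout and the dtype_size stand-in are assumptions:

# Rough sanity check of the new decode estimate, assuming shape[1:4] are the latent
# temporal/spatial extents and dtype_size returns bytes per element (as in
# comfy.model_management.dtype_size); the exact latent layout is an assumption.
import torch

def dtype_size(dtype):                      # stand-in for model_management.dtype_size
    return torch.empty((), dtype=dtype).element_size()

memory_used_decode = lambda shape, dtype: (shape[1] * shape[2] * shape[3] * (4 * 8 * 8)) * dtype_size(dtype)

latent_shape = (1, 8, 64, 64)               # hypothetical latent extents
est = memory_used_decode(latent_shape, torch.bfloat16)
print(f"{est / 2**20:.0f} MiB")             # 8*64*64*256 elements * 2 bytes = 16 MiB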