mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2026-01-21 20:00:17 +08:00)

Commit a779e34c5b: Merge branch 'comfyanonymous:master' into master

.github/workflows/test-build.yml (vendored)
@@ -18,7 +18,7 @@ jobs:
     strategy:
       fail-fast: false
       matrix:
-        python-version: ["3.8", "3.9", "3.10", "3.11"]
+        python-version: ["3.9", "3.10", "3.11", "3.12"]
     steps:
     - uses: actions/checkout@v4
     - name: Set up Python ${{ matrix.python-version }}
@@ -28,4 +28,4 @@ jobs:
     - name: Install dependencies
       run: |
        python -m pip install --upgrade pip
        pip install -r requirements.txt
@@ -168,14 +168,18 @@ class Attention(nn.Module):
         k = self.to_k[1](k)
         v = self.to_v[1](v)
         if self.is_selfattn and rope_emb is not None:  # only apply to self-attention!
-            q = apply_rotary_pos_emb(q, rope_emb)
-            k = apply_rotary_pos_emb(k, rope_emb)
-        return q, k, v
-
-    def cal_attn(self, q, k, v, mask=None):
-        out = optimized_attention(q, k, v, self.heads, skip_reshape=True, mask=mask, skip_output_reshape=True)
-        out = rearrange(out, " b n s c -> s b (n c)")
-        return self.to_out(out)
+            # apply_rotary_pos_emb inlined
+            q_shape = q.shape
+            q = q.reshape(*q.shape[:-1], 2, -1).movedim(-2, -1).unsqueeze(-2)
+            q = rope_emb[..., 0] * q[..., 0] + rope_emb[..., 1] * q[..., 1]
+            q = q.movedim(-1, -2).reshape(*q_shape).to(x.dtype)
+
+            # apply_rotary_pos_emb inlined
+            k_shape = k.shape
+            k = k.reshape(*k.shape[:-1], 2, -1).movedim(-2, -1).unsqueeze(-2)
+            k = rope_emb[..., 0] * k[..., 0] + rope_emb[..., 1] * k[..., 1]
+            k = k.movedim(-1, -2).reshape(*k_shape).to(x.dtype)
+        return q, k, v

     def forward(
         self,
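The rotation that this hunk inlines can be read as a small standalone helper. Below is a minimal, self-contained sketch with toy tensors; the helper name, the sizes, and the rope_emb layout (seq, 1, 1, head_dim/2, 2, 2) are illustrative assumptions, not the repository's apply_rotary_pos_emb.

import torch

def rotate(t: torch.Tensor, rope_emb: torch.Tensor) -> torch.Tensor:
    t_shape = t.shape
    # (..., d) -> (..., 2, d/2) -> (..., d/2, 2) -> (..., d/2, 1, 2)
    t = t.reshape(*t.shape[:-1], 2, -1).movedim(-2, -1).unsqueeze(-2)
    # linear combination of the two halves with the two slices of the rotation table
    t = rope_emb[..., 0] * t[..., 0] + rope_emb[..., 1] * t[..., 1]
    # undo the reshuffle so the output shape matches the input
    return t.movedim(-1, -2).reshape(*t_shape)

q = torch.randn(8, 1, 4, 16)              # seq, batch, heads, head_dim (toy sizes)
rope_emb = torch.randn(8, 1, 1, 8, 2, 2)  # assumed broadcastable rotation table
print(rotate(q, rope_emb).shape)          # torch.Size([8, 1, 4, 16])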
@@ -191,7 +195,10 @@ class Attention(nn.Module):
             context (Optional[Tensor]): The key tensor of shape [B, Mk, K] or use x as context [self attention] if None
         """
         q, k, v = self.cal_qkv(x, context, mask, rope_emb=rope_emb, **kwargs)
-        return self.cal_attn(q, k, v, mask)
+        out = optimized_attention(q, k, v, self.heads, skip_reshape=True, mask=mask, skip_output_reshape=True)
+        del q, k, v
+        out = rearrange(out, " b n s c -> s b (n c)")
+        return self.to_out(out)


 class FeedForward(nn.Module):
@@ -788,10 +795,7 @@ class GeneralDITTransformerBlock(nn.Module):
         crossattn_mask: Optional[torch.Tensor] = None,
         rope_emb_L_1_1_D: Optional[torch.Tensor] = None,
         adaln_lora_B_3D: Optional[torch.Tensor] = None,
-        extra_per_block_pos_emb: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        if extra_per_block_pos_emb is not None:
-            x = x + extra_per_block_pos_emb
         for block in self.blocks:
             x = block(
                 x,
@@ -30,6 +30,8 @@ import torch.nn as nn
 import torch.nn.functional as F
 import logging

+from comfy.ldm.modules.diffusionmodules.model import vae_attention
+
 from .patching import (
     Patcher,
     Patcher3D,
@@ -400,6 +402,8 @@ class CausalAttnBlock(nn.Module):
             in_channels, in_channels, kernel_size=1, stride=1, padding=0
         )

+        self.optimized_attention = vae_attention()
+
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         h_ = x
         h_ = self.norm(h_)
@@ -413,18 +417,7 @@ class CausalAttnBlock(nn.Module):
         v, batch_size = time2batch(v)

         b, c, h, w = q.shape
-        q = q.reshape(b, c, h * w)
-        q = q.permute(0, 2, 1)
-        k = k.reshape(b, c, h * w)
-        w_ = torch.bmm(q, k)
-        w_ = w_ * (int(c) ** (-0.5))
-        w_ = F.softmax(w_, dim=2)
-
-        # attend to values
-        v = v.reshape(b, c, h * w)
-        w_ = w_.permute(0, 2, 1)
-        h_ = torch.bmm(v, w_)
-        h_ = h_.reshape(b, c, h, w)
+        h_ = self.optimized_attention(q, k, v)

         h_ = batch2time(h_, batch_size)
         h_ = self.proj_out(h_)
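For reference, the bmm/softmax block removed above computes plain scaled dot-product attention over the H*W positions of an NCHW feature map, which is what the pooled backend behind self.optimized_attention computes as well. A toy single-head sketch (not the repository's code path) checks the two against each other:

import torch
import torch.nn.functional as F

def manual_nchw_attention(q, k, v):
    # same math as the removed lines: flatten spatial dims, attend, restore the map
    b, c, h, w = q.shape
    q = q.reshape(b, c, h * w).permute(0, 2, 1)   # b, hw, c
    k = k.reshape(b, c, h * w)                    # b, c, hw
    w_ = F.softmax(torch.bmm(q, k) * (int(c) ** (-0.5)), dim=2)
    v = v.reshape(b, c, h * w)
    h_ = torch.bmm(v, w_.permute(0, 2, 1))        # b, c, hw
    return h_.reshape(b, c, h, w)

q = k = v = torch.randn(2, 64, 8, 8)
ref = manual_nchw_attention(q, k, v)
# same result from PyTorch's fused kernel, treating each pixel as a token
sdpa = F.scaled_dot_product_attention(
    q.flatten(2).transpose(1, 2),
    k.flatten(2).transpose(1, 2),
    v.flatten(2).transpose(1, 2),
).transpose(1, 2).reshape_as(v)
print(torch.allclose(ref, sdpa, atol=1e-5))       # True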
@@ -871,18 +864,16 @@ class EncoderFactorized(nn.Module):
         x = self.patcher3d(x)

         # downsampling
-        hs = [self.conv_in(x)]
+        h = self.conv_in(x)
         for i_level in range(self.num_resolutions):
             for i_block in range(self.num_res_blocks):
-                h = self.down[i_level].block[i_block](hs[-1])
+                h = self.down[i_level].block[i_block](h)
                 if len(self.down[i_level].attn) > 0:
                     h = self.down[i_level].attn[i_block](h)
-                hs.append(h)
             if i_level != self.num_resolutions - 1:
-                hs.append(self.down[i_level].downsample(hs[-1]))
+                h = self.down[i_level].downsample(h)

         # middle
-        h = hs[-1]
         h = self.mid.block_1(h)
         h = self.mid.attn_1(h)
         h = self.mid.block_2(h)
@@ -281,54 +281,76 @@ class UnPatcher3D(UnPatcher):
         hh = hh.to(dtype=dtype)

         xlll, xllh, xlhl, xlhh, xhll, xhlh, xhhl, xhhh = torch.chunk(x, 8, dim=1)
+        del x

         # Height height transposed convolutions.
         xll = F.conv_transpose3d(
             xlll, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)
         )
+        del xlll

         xll += F.conv_transpose3d(
             xllh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)
         )
+        del xllh

         xlh = F.conv_transpose3d(
             xlhl, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)
         )
+        del xlhl

         xlh += F.conv_transpose3d(
             xlhh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)
         )
+        del xlhh

         xhl = F.conv_transpose3d(
             xhll, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)
         )
+        del xhll

         xhl += F.conv_transpose3d(
             xhlh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)
         )
+        del xhlh

         xhh = F.conv_transpose3d(
             xhhl, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)
         )
+        del xhhl

         xhh += F.conv_transpose3d(
             xhhh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)
         )
+        del xhhh

         # Handles width transposed convolutions.
         xl = F.conv_transpose3d(
             xll, hl.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1)
         )
+        del xll

         xl += F.conv_transpose3d(
             xlh, hh.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1)
         )
+        del xlh

         xh = F.conv_transpose3d(
             xhl, hl.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1)
         )
+        del xhl

         xh += F.conv_transpose3d(
             xhh, hh.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1)
         )
+        del xhh

         # Handles time axis transposed convolutions.
         x = F.conv_transpose3d(
             xl, hl.unsqueeze(3).unsqueeze(4), groups=g, stride=(2, 1, 1)
         )
+        del xl

         x += F.conv_transpose3d(
             xh, hh.unsqueeze(3).unsqueeze(4), groups=g, stride=(2, 1, 1)
         )
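The grouped stride-2 transposed convolutions above can be read as the synthesis half of a separable two-band (wavelet-style) transform applied one axis at a time: a low band and a high band are each upsampled by 2 and summed, and the added del statements simply release each sub-band once it has been consumed. A 1D analogue with illustrative Haar-like filter values (not the module's actual buffers):

import torch
import torch.nn.functional as F

g = 3                                       # one group per channel, as in the grouped conv
hl = torch.tensor([1.0, 1.0]) / 2 ** 0.5    # assumed low-pass synthesis filter
hh = torch.tensor([1.0, -1.0]) / 2 ** 0.5   # assumed high-pass synthesis filter

xl = torch.randn(1, g, 8)                   # low-frequency band
xh = torch.randn(1, g, 8)                   # high-frequency band

wl = hl.view(1, 1, 2).repeat(g, 1, 1)       # (in_channels, out_channels/groups, kernel)
wh = hh.view(1, 1, 2).repeat(g, 1, 1)

x = F.conv_transpose1d(xl, wl, groups=g, stride=2)
x += F.conv_transpose1d(xh, wh, groups=g, stride=2)
del xl, xh                                  # same early-release pattern as the diff
print(x.shape)                              # torch.Size([1, 3, 16]): the axis length doubles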
@@ -168,7 +168,7 @@ class GeneralDIT(nn.Module):
             operations=operations,
         )

-        self.build_pos_embed(device=device)
+        self.build_pos_embed(device=device, dtype=dtype)
         self.block_x_format = block_x_format
         self.use_adaln_lora = use_adaln_lora
         self.adaln_lora_dim = adaln_lora_dim
@@ -210,7 +210,7 @@ class GeneralDIT(nn.Module):
             operations=operations,
         )

-    def build_pos_embed(self, device=None):
+    def build_pos_embed(self, device=None, dtype=None):
         if self.pos_emb_cls == "rope3d":
             cls_type = VideoRopePosition3DEmb
         else:
@@ -242,6 +242,7 @@ class GeneralDIT(nn.Module):
             kwargs["w_extrapolation_ratio"] = self.extra_w_extrapolation_ratio
             kwargs["t_extrapolation_ratio"] = self.extra_t_extrapolation_ratio
             kwargs["device"] = device
+            kwargs["dtype"] = dtype
             self.extra_pos_embedder = LearnablePosEmbAxis(
                 **kwargs,
             )
@@ -476,6 +477,8 @@ class GeneralDIT(nn.Module):
             inputs["original_shape"],
         )
         extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = inputs["extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D"].to(x.dtype)
+        del inputs
+
         if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None:
             assert (
                 x.shape == extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.shape
@@ -486,6 +489,8 @@ class GeneralDIT(nn.Module):
                 self.blocks["block0"].x_format == block.x_format
             ), f"First block has x_format {self.blocks[0].x_format}, got {block.x_format}"

+            if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None:
+                x += extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D
             x = block(
                 x,
                 affline_emb_B_D,
@@ -493,7 +498,6 @@ class GeneralDIT(nn.Module):
                 crossattn_mask,
                 rope_emb_L_1_1_D=rope_emb_L_1_1_D,
                 adaln_lora_B_3D=adaln_lora_B_3D,
-                extra_per_block_pos_emb=extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D,
             )

         x_B_T_H_W_D = rearrange(x, "T H W B D -> B T H W D")
@@ -173,6 +173,7 @@ class LearnablePosEmbAxis(VideoPositionEmb):
         len_w: int,
         len_t: int,
         device=None,
+        dtype=None,
         **kwargs,
     ):
         """
@@ -184,9 +185,9 @@ class LearnablePosEmbAxis(VideoPositionEmb):
         self.interpolation = interpolation
         assert self.interpolation in ["crop"], f"Unknown interpolation method {self.interpolation}"

-        self.pos_emb_h = nn.Parameter(torch.empty(len_h, model_channels, device=device))
-        self.pos_emb_w = nn.Parameter(torch.empty(len_w, model_channels, device=device))
-        self.pos_emb_t = nn.Parameter(torch.empty(len_t, model_channels, device=device))
+        self.pos_emb_h = nn.Parameter(torch.empty(len_h, model_channels, device=device, dtype=dtype))
+        self.pos_emb_w = nn.Parameter(torch.empty(len_w, model_channels, device=device, dtype=dtype))
+        self.pos_emb_t = nn.Parameter(torch.empty(len_t, model_channels, device=device, dtype=dtype))


     def generate_embeddings(self, B_T_H_W_C: torch.Size, fps=Optional[torch.Tensor], device=None) -> torch.Tensor:
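The dtype plumbing in the hunks above amounts to allocating the learnable position tables directly in the model's dtype instead of the float32 default. A minimal sketch of the pattern with an illustrative class (not the repository's):

import torch
import torch.nn as nn

class TinyPosEmbAxis(nn.Module):
    def __init__(self, length: int, channels: int, device=None, dtype=None):
        super().__init__()
        # device/dtype are threaded straight into parameter construction
        self.pos_emb = nn.Parameter(torch.empty(length, channels, device=device, dtype=dtype))

emb = TinyPosEmbAxis(16, 32, device="cpu", dtype=torch.float16)
print(emb.pos_emb.dtype)  # torch.float16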
@@ -5,8 +5,15 @@ from torch import Tensor
 from comfy.ldm.modules.attention import optimized_attention
 import comfy.model_management


 def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None) -> Tensor:
-    q, k = apply_rope(q, k, pe)
+    q_shape = q.shape
+    k_shape = k.shape
+
+    q = q.float().reshape(*q.shape[:-1], -1, 1, 2)
+    k = k.float().reshape(*k.shape[:-1], -1, 1, 2)
+    q = (pe[..., 0] * q[..., 0] + pe[..., 1] * q[..., 1]).reshape(*q_shape).type_as(v)
+    k = (pe[..., 0] * k[..., 0] + pe[..., 1] * k[..., 1]).reshape(*k_shape).type_as(v)
+
     heads = q.shape[1]
     x = optimized_attention(q, k, v, heads, skip_reshape=True, mask=mask)
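As in the earlier hunks, the rotation is inlined here; note that it is still computed in float32 (q.float(), k.float()) and cast back to the activation dtype afterwards. A small shape/dtype sketch with toy tensors; the pe layout (b, 1, seq, head_dim/2, 2, 2) is an assumption:

import torch

q = torch.randn(1, 8, 16, 64, dtype=torch.float16)  # b, heads, seq, head_dim (toy sizes)
pe = torch.randn(1, 1, 16, 32, 2, 2)                 # float32 rotation table (assumed layout)

q_shape = q.shape
q32 = q.float().reshape(*q.shape[:-1], -1, 1, 2)
q_rot = (pe[..., 0] * q32[..., 0] + pe[..., 1] * q32[..., 1]).reshape(*q_shape).type_as(q)
print(q_rot.dtype)  # torch.float16: rotated in fp32, stored back in the input dtype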
@@ -293,6 +293,17 @@ def pytorch_attention(q, k, v):
     return out


+def vae_attention():
+    if model_management.xformers_enabled_vae():
+        logging.info("Using xformers attention in VAE")
+        return xformers_attention
+    elif model_management.pytorch_attention_enabled():
+        logging.info("Using pytorch attention in VAE")
+        return pytorch_attention
+    else:
+        logging.info("Using split attention in VAE")
+        return normal_attention
+
 class AttnBlock(nn.Module):
     def __init__(self, in_channels, conv_op=ops.Conv2d):
         super().__init__()
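A hedged usage sketch for the new helper (requires ComfyUI on the import path; the NCHW shape below is an assumption based on the surrounding VAE code): vae_attention() selects a backend once, at construction time, and returns a callable that maps three equally shaped feature maps to an output of the same shape.

import torch
from comfy.ldm.modules.diffusionmodules.model import vae_attention

attn = vae_attention()                   # logs which backend was selected
q = k = v = torch.randn(1, 512, 32, 32)  # toy NCHW feature maps
out = attn(q, k, v)
print(out.shape)                         # expected: torch.Size([1, 512, 32, 32])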
@@ -320,15 +331,7 @@ class AttnBlock(nn.Module):
                                 stride=1,
                                 padding=0)

-        if model_management.xformers_enabled_vae():
-            logging.info("Using xformers attention in VAE")
-            self.optimized_attention = xformers_attention
-        elif model_management.pytorch_attention_enabled():
-            logging.info("Using pytorch attention in VAE")
-            self.optimized_attention = pytorch_attention
-        else:
-            logging.info("Using split attention in VAE")
-            self.optimized_attention = normal_attention
+        self.optimized_attention = vae_attention()

     def forward(self, x):
         h_ = x
@@ -2,6 +2,7 @@ torch
 torchsde
 torchvision
 torchaudio
+numpy>=1.25.0
 einops
 transformers>=4.28.1
 tokenizers>=0.13.3