mirror of https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-21 20:00:17 +08:00

Mochi and SageAttention improvements

This commit is contained in:
parent 264d84db39
commit 8ba412897e
README.md | 36
@@ -333,6 +333,42 @@ pip install git+https://github.com/AppMana/appmana-comfyui-nodes-controlnet-aux.

Start creating an AnimateDiff workflow. When using these packages, the appropriate models will download automatically.

## SageAttention

Improve the performance of your Mochi model video generation using **Sage Attention**:

| Device | PyTorch 2.5.1 | SageAttention | S.A. + TorchCompileModel |
|--------|---------------|---------------|--------------------------|
| A5000  | 7.52s/it      | 5.81s/it      | 5.00s/it (but corrupted) |

[Use the default Mochi Workflow.](https://github.com/comfyanonymous/ComfyUI_examples/raw/refs/heads/master/mochi/mochi_text_to_video_example.webp) This does not require any custom nodes or any change to your workflow.

Install the dependencies for Windows or Linux using the `withtriton` component, or install the specific dependencies you need from [requirements-triton.txt](./requirements-triton.txt):

```shell
pip install "comfyui[withtriton]@git+https://github.com/hiddenswitch/ComfyUI.git"
```

If you have `xformers` installed, disable it, as it will be preferred over Sage Attention:

```shell
comfyui --disable-xformers
```

If you want to use **TorchCompileModel** to further improve performance, do not reserve VRAM:

```shell
comfyui --disable-xformers --reserve-vram=-1.0
```

Sage Attention is not compatible with Flux. It does not appear to be compatible with Mochi when using `torch.compile`.


**With SageAttention**



**With PyTorch Attention**
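For orientation only (this sketch is not part of the commit; the `attention` helper and the shape assumptions are mine), enabling Sage Attention amounts to routing the attention call through the `sageattention` kernel instead of PyTorch's scaled dot-product attention:

```python
# Illustrative sketch: Sage Attention as a drop-in for PyTorch SDPA.
# Assumes the sageattention package is available (the `withtriton` extra pulls in
# the triton/sageattention dependencies) and that q, k, v are CUDA tensors shaped
# (batch, heads, seq_len, head_dim), the layout SDPA accepts by default.
import torch
import torch.nn.functional as F

try:
    from sageattention import sageattn
except ImportError:  # fall back to stock PyTorch attention
    sageattn = None


def attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    if sageattn is not None and q.is_cuda:
        # Quantized Sage Attention kernel used in place of SDPA.
        return sageattn(q, k, v, is_causal=False)
    # Default path: PyTorch scaled dot-product attention.
    return F.scaled_dot_product_attention(q, k, v)
```

This is also why `--disable-xformers` matters: as noted above, `xformers` is preferred over Sage Attention when it is installed, so the substitution never takes effect unless it is disabled.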
# Custom Nodes

Custom Nodes can be added to ComfyUI by copying and pasting Python files into your `./custom_nodes` directory.
comfy/ldm/genmo/__init__.py | 0 (new file)
comfy/ldm/genmo/joint_model/__init__.py | 0 (new file)
@@ -1,14 +1,11 @@
-#original code from https://github.com/genmoai/models under apache 2.0 license
-#adapted to ComfyUI
-
-from typing import Dict, List, Optional, Tuple
+# original code from https://github.com/genmoai/models under apache 2.0 license
+# adapted to ComfyUI
+from typing import Tuple, List, Dict, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
-# from flash_attn import flash_attn_varlen_qkvpacked_func
-from comfy.ldm.modules.attention import optimized_attention

from .layers import (
    FeedForward,
@@ -16,7 +13,6 @@ from .layers import (
    RMSNorm,
    TimestepEmbedder,
)
-
from .rope_mixed import (
    compute_mixed_rotation,
    create_position_matrix,
@@ -26,14 +22,15 @@ from .utils import (
    AttentionPool,
    modulate,
)
-import comfy.ldm.common_dit
-import comfy.ops
+from ...common_dit import rms_norm
+# from flash_attn import flash_attn_varlen_qkvpacked_func
+from ...modules.attention import optimized_attention
+from .... import ops


def modulated_rmsnorm(x, scale, eps=1e-6):
    # Normalize and modulate
-    x_normed = comfy.ldm.common_dit.rms_norm(x, eps=eps)
+    x_normed = rms_norm(x, eps=eps)
    x_modulated = x_normed * (1 + scale.unsqueeze(1))

    return x_modulated
@@ -44,29 +41,30 @@ def residual_tanh_gated_rmsnorm(x, x_res, gate, eps=1e-6):
    tanh_gate = torch.tanh(gate).unsqueeze(1)

    # Normalize and apply gated scaling
-    x_normed = comfy.ldm.common_dit.rms_norm(x_res, eps=eps) * tanh_gate
+    x_normed = rms_norm(x_res, eps=eps) * tanh_gate

    # Apply residual connection
    output = x + x_normed

    return output


class AsymmetricAttention(nn.Module):
    def __init__(
        self,
        dim_x: int,
        dim_y: int,
        num_heads: int = 8,
        qkv_bias: bool = True,
        qk_norm: bool = False,
        attn_drop: float = 0.0,
        update_y: bool = True,
        out_bias: bool = True,
        attend_to_padding: bool = False,
        softmax_scale: Optional[float] = None,
        device: Optional[torch.device] = None,
        dtype=None,
        operations=None,
    ):
        super().__init__()
        self.dim_x = dim_x
@@ -104,13 +102,13 @@ class AsymmetricAttention(nn.Module):
        )

    def forward(
        self,
        x: torch.Tensor, # (B, N, dim_x)
        y: torch.Tensor, # (B, L, dim_y)
        scale_x: torch.Tensor, # (B, dim_x), modulation for pre-RMSNorm.
        scale_y: torch.Tensor, # (B, dim_y), modulation for pre-RMSNorm.
        crop_y,
        **rope_rotation,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        rope_cos = rope_rotation.get("rope_cos")
        rope_sin = rope_rotation.get("rope_sin")
@@ -159,18 +157,18 @@

class AsymmetricJointBlock(nn.Module):
    def __init__(
        self,
        hidden_size_x: int,
        hidden_size_y: int,
        num_heads: int,
        *,
        mlp_ratio_x: float = 8.0, # Ratio of hidden size to d_model for MLP for visual tokens.
        mlp_ratio_y: float = 4.0, # Ratio of hidden size to d_model for MLP for text tokens.
        update_y: bool = True, # Whether to update text tokens in this block.
        device: Optional[torch.device] = None,
        dtype=None,
        operations=None,
        **block_kwargs,
    ):
        super().__init__()
        self.update_y = update_y
@@ -221,11 +219,11 @@
        )

    def forward(
        self,
        x: torch.Tensor,
        c: torch.Tensor,
        y: torch.Tensor,
        **attn_kwargs,
    ):
        """Forward pass of a block.

@@ -291,13 +289,13 @@ class FinalLayer(nn.Module):
    """

    def __init__(
        self,
        hidden_size,
        patch_size,
        out_channels,
        device: Optional[torch.device] = None,
        dtype=None,
        operations=None,
    ):
        super().__init__()
        self.norm_final = operations.LayerNorm(
@@ -324,32 +322,32 @@ class AsymmDiTJoint(nn.Module):
    """

    def __init__(
        self,
        *,
        patch_size=2,
        in_channels=4,
        hidden_size_x=1152,
        hidden_size_y=1152,
        depth=48,
        num_heads=16,
        mlp_ratio_x=8.0,
        mlp_ratio_y=4.0,
        use_t5: bool = False,
        t5_feat_dim: int = 4096,
        t5_token_length: int = 256,
        learn_sigma=True,
        patch_embed_bias: bool = True,
        timestep_mlp_bias: bool = True,
        attend_to_padding: bool = False,
        timestep_scale: Optional[float] = None,
        use_extended_posenc: bool = False,
        posenc_preserve_area: bool = False,
        rope_theta: float = 10000.0,
        image_model=None,
        device: Optional[torch.device] = None,
        dtype=None,
        operations=None,
        **block_kwargs,
    ):
        super().__init__()

@@ -362,7 +360,7 @@
        self.hidden_size_x = hidden_size_x
        self.hidden_size_y = hidden_size_y
        self.head_dim = (
            hidden_size_x // num_heads
        ) # Head dimension and count is determined by visual.
        self.attend_to_padding = attend_to_padding
        self.use_extended_posenc = use_extended_posenc
@@ -449,11 +447,11 @@
        return self.x_embedder(x) # Convert BcTHW to BCN

    def prepare(
        self,
        x: torch.Tensor,
        sigma: torch.Tensor,
        t5_feat: torch.Tensor,
        t5_mask: torch.Tensor,
    ):
        """Prepare input and conditioning embeddings."""
        # Visual patch embeddings with positional encoding.
@@ -463,7 +461,6 @@
        assert x.ndim == 3
        B = x.size(0)
-

        pH, pW = H // self.patch_size, W // self.patch_size
        N = T * pH * pW
        assert x.size(1) == N
@@ -471,7 +468,7 @@
            T, pH=pH, pW=pW, device=x.device, dtype=torch.float32
        ) # (N, 3)
        rope_cos, rope_sin = compute_mixed_rotation(
-            freqs=comfy.ops.cast_to(self.pos_frequencies, dtype=x.dtype, device=x.device), pos=pos
+            freqs=ops.cast_to(self.pos_frequencies, dtype=x.dtype, device=x.device), pos=pos
        ) # Each are (N, num_heads, dim // 2)

        c_t = self.t_embedder(1 - sigma, out_dtype=x.dtype) # (B, D)
@@ -485,17 +482,19 @@
        return x, c, y_feat, rope_cos, rope_sin

    def forward(
        self,
        x: torch.Tensor,
        timestep: torch.Tensor,
        context: List[torch.Tensor],
        attention_mask: List[torch.Tensor],
        num_tokens=256,
        packed_indices: Dict[str, torch.Tensor] = None,
        rope_cos: torch.Tensor = None,
        rope_sin: torch.Tensor = None,
-        control=None, transformer_options={}, **kwargs
+        control=None, transformer_options=None, **kwargs
    ):
+        if transformer_options is None:
+            transformer_options = {}
        patches_replace = transformer_options.get("patches_replace", {})
        y_feat = context
        y_mask = attention_mask
@@ -522,14 +521,15 @@
                def block_wrap(args):
                    out = {}
                    out["img"], out["txt"] = block(
                        args["img"],
                        args["vec"],
                        args["txt"],
                        rope_cos=args["rope_cos"],
                        rope_sin=args["rope_sin"],
                        crop_y=args["num_tokens"]
                    )
                    return out

                out = blocks_replace[("double_block", i)]({"img": x, "txt": y_feat, "vec": c, "rope_cos": rope_cos, "rope_sin": rope_sin, "num_tokens": num_tokens}, {"original_block": block_wrap})
                y_feat = out["txt"]
                x = out["img"]
@@ -1,5 +1,5 @@
-#original code from https://github.com/genmoai/models under apache 2.0 license
-#adapted to ComfyUI
+# original code from https://github.com/genmoai/models under apache 2.0 license
+# adapted to ComfyUI

import collections.abc
import math
@@ -10,7 +10,8 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
-import comfy.ldm.common_dit
+from ...common_dit import pad_to_patch_size, rms_norm


# From PyTorch internals
@@ -28,15 +29,15 @@ to_2tuple = _ntuple(2)

class TimestepEmbedder(nn.Module):
    def __init__(
        self,
        hidden_size: int,
        frequency_embedding_size: int = 256,
        *,
        bias: bool = True,
        timestep_scale: Optional[float] = None,
        dtype=None,
        device=None,
        operations=None,
    ):
        super().__init__()
        self.mlp = nn.Sequential(
@@ -70,14 +71,14 @@

class FeedForward(nn.Module):
    def __init__(
        self,
        in_features: int,
        hidden_size: int,
        multiple_of: int,
        ffn_dim_multiplier: Optional[float],
        device: Optional[torch.device] = None,
        dtype=None,
        operations=None,
    ):
        super().__init__()
        # keep parameter count and computation constant compared to standard FFN
@@ -99,17 +100,17 @@

class PatchEmbed(nn.Module):
    def __init__(
        self,
        patch_size: int = 16,
        in_chans: int = 3,
        embed_dim: int = 768,
        norm_layer: Optional[Callable] = None,
        flatten: bool = True,
        bias: bool = True,
        dynamic_img_pad: bool = False,
        dtype=None,
        device=None,
        operations=None,
    ):
        super().__init__()
        self.patch_size = to_2tuple(patch_size)
@@ -141,7 +142,7 @@
        x = F.pad(x, (0, pad_w, 0, pad_h))

        x = rearrange(x, "B C T H W -> (B T) C H W", B=B, T=T)
-        x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size, padding_mode='circular')
+        x = pad_to_patch_size(x, self.patch_size, padding_mode='circular')
        x = self.proj(x)

        # Flatten temporal and spatial dimensions.
@@ -161,4 +162,4 @@ class RMSNorm(torch.nn.Module):
        self.register_parameter("bias", None)

    def forward(self, x):
-        return comfy.ldm.common_dit.rms_norm(x, self.weight, self.eps)
+        return rms_norm(x, self.weight, self.eps)
@@ -59,8 +59,8 @@ def create_position_matrix(

    # Stack and reshape the grids.
    pos = torch.stack([grid_t, grid_h, grid_w], dim=-1) # [T, pH, pW, 3]
-    pos = pos.view(-1, 3) # [T * pH * pW, 3]
    pos = pos.to(dtype=dtype, device=device)
+    pos = pos.view(-1, 3) # [T * pH * pW, 3]

    return pos
comfy/ldm/genmo/vae/__init__.py | 0 (new file)
@@ -1,19 +1,18 @@
-#original code from https://github.com/genmoai/models under apache 2.0 license
-#adapted to ComfyUI
+# original code from https://github.com/genmoai/models under apache 2.0 license
+# adapted to ComfyUI

-from typing import Callable, List, Optional, Tuple, Union
-from functools import partial
import math
+from functools import partial
+from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange

-from comfy.ldm.modules.attention import optimized_attention
-
-import comfy.ops
-ops = comfy.ops.disable_weight_init
+from ...modules.attention import optimized_attention
+from ....ops import disable_weight_init as ops

# import mochi_preview.dit.joint_model.context_parallel as cp
# from mochi_preview.vae.cp_conv import cp_pass_frames, gather_all_frames
@@ -34,19 +33,20 @@ class GroupNormSpatial(ops.GroupNorm):
        # Run group norm in chunks.
        output = torch.empty_like(x)
        for b in range(0, B * T, chunk_size):
-            output[b : b + chunk_size] = super().forward(x[b : b + chunk_size])
+            output[b: b + chunk_size] = super().forward(x[b: b + chunk_size])
        return rearrange(output, "(B T) C H W -> B C T H W", B=B, T=T)


class PConv3d(ops.Conv3d):
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size: Union[int, Tuple[int, int, int]],
        stride: Union[int, Tuple[int, int, int]],
        causal: bool = True,
        context_parallel: bool = True,
        **kwargs,
    ):
        self.causal = causal
        self.context_parallel = context_parallel
@@ -105,9 +105,9 @@ class Conv1x1(ops.Linear):

class DepthToSpaceTime(nn.Module):
    def __init__(
        self,
        temporal_expansion: int,
        spatial_expansion: int,
    ):
        super().__init__()
        self.temporal_expansion = temporal_expansion
@@ -135,20 +135,20 @@ class DepthToSpaceTime(nn.Module):
        )

        # cp_rank, _ = cp.get_cp_rank_size()
        if self.temporal_expansion > 1: # and cp_rank == 0:
            # Drop the first self.temporal_expansion - 1 frames.
            # This is because we always want the 3x3x3 conv filter to only apply
            # to the first frame, and the first frame doesn't need to be repeated.
            assert all(x.shape)
-            x = x[:, :, self.temporal_expansion - 1 :]
+            x = x[:, :, self.temporal_expansion - 1:]
            assert all(x.shape)

        return x


def norm_fn(
    in_channels: int,
    affine: bool = True,
):
    return GroupNormSpatial(affine=affine, num_groups=32, num_channels=in_channels)

@@ -157,15 +157,15 @@ class ResBlock(nn.Module):
    """Residual block that preserves the spatial dimensions."""

    def __init__(
        self,
        channels: int,
        *,
        affine: bool = True,
        attn_block: Optional[nn.Module] = None,
        causal: bool = True,
        prune_bottleneck: bool = False,
        padding_mode: str,
        bias: bool = True,
    ):
        super().__init__()
        self.channels = channels
@@ -214,12 +214,12 @@

class Attention(nn.Module):
    def __init__(
        self,
        dim: int,
        head_dim: int = 32,
        qkv_bias: bool = False,
        out_bias: bool = True,
        qk_norm: bool = True,
    ) -> None:
        super().__init__()
        self.head_dim = head_dim
@@ -230,8 +230,8 @@
        self.out = nn.Linear(dim, dim, bias=out_bias)

    def forward(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:
        """Compute temporal self-attention.

@@ -275,9 +275,9 @@

class AttentionBlock(nn.Module):
    def __init__(
        self,
        dim: int,
        **attn_kwargs,
    ) -> None:
        super().__init__()
        self.norm = norm_fn(dim)
@@ -289,14 +289,14 @@

class CausalUpsampleBlock(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        num_res_blocks: int,
        *,
        temporal_expansion: int = 2,
        spatial_expansion: int = 2,
        **block_kwargs,
    ):
        super().__init__()

@@ -311,7 +311,7 @@ class CausalUpsampleBlock(nn.Module):
        # Change channels in the final convolution layer.
        self.proj = Conv1x1(
            in_channels,
-            out_channels * temporal_expansion * (spatial_expansion**2),
+            out_channels * temporal_expansion * (spatial_expansion ** 2),
        )

        self.d2st = DepthToSpaceTime(
@@ -332,14 +332,14 @@ def block_fn(channels, *, affine: bool = True, has_attention: bool = False, **bl

class DownsampleBlock(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        num_res_blocks,
        *,
        temporal_reduction=2,
        spatial_reduction=2,
        **block_kwargs,
    ):
        """
        Downsample block for the VAE encoder.
@@ -427,21 +427,21 @@ class FourierFeatures(nn.Module):

class Decoder(nn.Module):
    def __init__(
        self,
        *,
        out_channels: int = 3,
        latent_dim: int,
        base_channels: int,
        channel_multipliers: List[int],
        num_res_blocks: List[int],
        temporal_expansions: Optional[List[int]] = None,
        spatial_expansions: Optional[List[int]] = None,
        has_attention: List[bool],
        output_norm: bool = True,
        nonlinearity: str = "silu",
        output_nonlinearity: str = "silu",
        causal: bool = True,
        **block_kwargs,
    ):
        super().__init__()
        self.input_channels = latent_dim
@@ -529,6 +529,7 @@ class Decoder(nn.Module):

        return self.output_proj(x).contiguous()


class LatentDistribution:
    def __init__(self, mean: torch.Tensor, logvar: torch.Tensor):
        """Initialize latent distribution.
@@ -560,23 +561,24 @@ class LatentDistribution:
    def mode(self):
        return self.mean


class Encoder(nn.Module):
    def __init__(
        self,
        *,
        in_channels: int,
        base_channels: int,
        channel_multipliers: List[int],
        num_res_blocks: List[int],
        latent_dim: int,
        temporal_reductions: List[int],
        spatial_reductions: List[int],
        prune_bottlenecks: List[bool],
        has_attentions: List[bool],
        affine: bool = True,
        bias: bool = True,
        input_is_conv_1x1: bool = False,
        padding_mode: str,
    ):
        super().__init__()
        self.temporal_reductions = temporal_reductions
@@ -103,8 +103,10 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
    def __init__(self, device="cpu", max_length=77,
                 freeze=True, layer="last", layer_idx=None, textmodel_json_config: str | dict | None = None, dtype=None, model_class=clip_model.CLIPTextModel,
                 special_tokens=None, layer_norm_hidden_state=True, enable_attention_masks=False, zero_out_masked=False,
-                 return_projected_pooled=True, return_attention_masks=False, model_options={}): # clip-vit-base-patch32
+                 return_projected_pooled=True, return_attention_masks=False, model_options=None): # clip-vit-base-patch32
        super().__init__()
+        if model_options is None:
+            model_options = {}
        if special_tokens is None:
            special_tokens = {"start": 49406, "end": 49407, "pad": 49407}
        assert layer in self.LAYERS
@@ -663,7 +665,9 @@ SD1TokenizerT = TypeVar("SD1TokenizerT", bound="SD1Tokenizer")


class SD1Tokenizer:
-    def __init__(self, embedding_directory=None, tokenizer_data={}, clip_name="l", tokenizer=SDTokenizer):
+    def __init__(self, embedding_directory=None, tokenizer_data=None, clip_name="l", tokenizer=SDTokenizer):
+        if tokenizer_data is None:
+            tokenizer_data = {}
        self.clip_name = clip_name
        self.clip = "clip_{}".format(self.clip_name)
        tokenizer = tokenizer_data.get("{}_tokenizer_class".format(self.clip), tokenizer)
@@ -694,13 +698,18 @@ class SD1Tokenizer:
        return {}


class SD1CheckpointClipModel(SDClipModel):
-    def __init__(self, device="cpu", dtype=None, model_options={}, textmodel_json_config=None):
+    def __init__(self, device="cpu", dtype=None, model_options=None, textmodel_json_config=None):
        super().__init__(device=device, return_projected_pooled=False, dtype=dtype, model_options=model_options, textmodel_json_config=textmodel_json_config)
+        if model_options is None:
+            model_options = {}


class SD1ClipModel(torch.nn.Module):
-    def __init__(self, device="cpu", dtype=None, model_options={}, clip_name="l", clip_model=SD1CheckpointClipModel, textmodel_json_config=None, name=None, **kwargs):
+    def __init__(self, device="cpu", dtype=None, model_options=None, clip_name="l", clip_model=SD1CheckpointClipModel, textmodel_json_config=None, name=None, **kwargs):
        super().__init__()

+        if model_options is None:
+            model_options = {}
        if name is not None:
            self.clip_name = name
            self.clip = "{}".format(self.clip_name)
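Many of the signature changes in these files replace a mutable default argument such as `model_options={}` or `tokenizer_data={}` with `None` plus an explicit check. A standalone sketch of the pitfall this avoids (illustrative example, not code from this repository):

```python
# A mutable default is evaluated once, at function definition time, so every
# call that relies on it shares the same dict object.
def risky(options={}):
    options.setdefault("calls", 0)
    options["calls"] += 1
    return options


# The pattern used in this commit: default to None and build a fresh dict per call.
def safe(options=None):
    if options is None:
        options = {}
    options.setdefault("calls", 0)
    options["calls"] += 1
    return options


print(risky())  # {'calls': 1}
print(risky())  # {'calls': 2}  (state leaks across calls)
print(safe())   # {'calls': 1}
print(safe())   # {'calls': 1}
```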
@@ -697,7 +697,9 @@ class GenmoMochi(supported_models_base.BASE):
        out = model_base.GenmoMochi(self, device=device)
        return out

-    def clip_target(self, state_dict={}):
+    def clip_target(self, state_dict=None):
+        if state_dict is None:
+            state_dict = {}
        pref = self.text_encoder_key_prefix[0]
        t5_detect = sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
        return supported_models_base.ClipTarget(genmo.MochiT5Tokenizer, genmo.mochi_te(**t5_detect))
@@ -3,6 +3,8 @@ import comfy.text_encoders.sd3_clip
import os
from transformers import T5TokenizerFast

+from comfy.component_model import files


class T5XXLModel(comfy.text_encoders.sd3_clip.T5XXLModel):
    def __init__(self, **kwargs):
@@ -11,24 +13,32 @@ class T5XXLModel(comfy.text_encoders.sd3_clip.T5XXLModel):


class MochiT5XXL(sd1_clip.SD1ClipModel):
-    def __init__(self, device="cpu", dtype=None, model_options={}):
+    def __init__(self, device="cpu", dtype=None, model_options=None):
+        if model_options is None:
+            model_options = {}
        super().__init__(device=device, dtype=dtype, name="t5xxl", clip_model=T5XXLModel, model_options=model_options)


class T5XXLTokenizer(sd1_clip.SDTokenizer):
-    def __init__(self, embedding_directory=None, tokenizer_data={}):
-        tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
+    def __init__(self, embedding_directory=None, tokenizer_data=None):
+        if tokenizer_data is None:
+            tokenizer_data = {}
+        tokenizer_path = files.get_package_as_path("comfy.text_encoders.t5_tokenizer")
        super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=256)


class MochiT5Tokenizer(sd1_clip.SD1Tokenizer):
-    def __init__(self, embedding_directory=None, tokenizer_data={}):
+    def __init__(self, embedding_directory=None, tokenizer_data=None):
        super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="t5xxl", tokenizer=T5XXLTokenizer)
+        if tokenizer_data is None:
+            tokenizer_data = {}


def mochi_te(dtype_t5=None, t5xxl_scaled_fp8=None):
    class MochiTEModel_(MochiT5XXL):
-        def __init__(self, device="cpu", dtype=None, model_options={}):
+        def __init__(self, device="cpu", dtype=None, model_options=None):
+            if model_options is None:
+                model_options = {}
            if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options:
                model_options = model_options.copy()
                model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8
@@ -20,7 +20,7 @@ class T5XXLModel(sd1_clip.SDClipModel):
            model_options = model_options.copy()
            model_options["scaled_fp8"] = t5xxl_scaled_fp8

-        super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=T5, model_options=model_options)
+        super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=T5, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)


def t5_xxl_detect(state_dict, prefix=""):
@@ -35,10 +35,11 @@ def t5_xxl_detect(state_dict, prefix=""):

    return out


class T5XXLTokenizer(sd1_clip.SDTokenizer):
    def __init__(self, embedding_directory=None, tokenizer_data=None):
        if tokenizer_data is None:
-            tokenizer_data = dict()
+            tokenizer_data = {}
        tokenizer_path = files.get_package_as_path("comfy.text_encoders.t5_tokenizer")
        super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=77)

@@ -46,7 +47,7 @@ class T5XXLTokenizer(sd1_clip.SDTokenizer):
class SD3Tokenizer:
    def __init__(self, embedding_directory=None, tokenizer_data=None):
        if tokenizer_data is None:
-            tokenizer_data = dict()
+            tokenizer_data = {}
        clip_l_tokenizer_class = tokenizer_data.get("clip_l_tokenizer_class", sd1_clip.SDTokenizer)
        self.clip_l = clip_l_tokenizer_class(embedding_directory=embedding_directory)
        self.clip_g = sdxl_clip.SDXLClipGTokenizer(embedding_directory=embedding_directory)
@@ -63,15 +64,17 @@ class SD3Tokenizer:
        return self.clip_g.untokenize(token_weight_pair)

    def state_dict(self):
-        return dict()
+        return {}

    def clone(self):
        return copy.copy(self)


class SD3ClipModel(torch.nn.Module):
-    def __init__(self, clip_l=True, clip_g=True, t5=True, dtype_t5=None, t5_attention_mask=False, device="cpu", dtype=None, model_options={}):
+    def __init__(self, clip_l=True, clip_g=True, t5=True, dtype_t5=None, t5_attention_mask=False, device="cpu", dtype=None, model_options=None):
        super().__init__()
+        if model_options is None:
+            model_options = {}
        self.dtypes = set()
        if clip_l:
            clip_l_class = model_options.get("clip_l_class", sd1_clip.SDClipModel)
@@ -180,4 +183,5 @@ def sd3_clip(clip_l=True, clip_g=True, t5=True, dtype_t5=None, t5xxl_scaled_fp8=
            model_options = model_options.copy()
            model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8
        super().__init__(clip_l=clip_l, clip_g=clip_g, t5=t5, dtype_t5=dtype_t5, t5_attention_mask=t5_attention_mask, device=device, dtype=dtype, model_options=model_options)

    return SD3ClipModel_
docs/assets/with_pytorch_attention.webp | BIN (new file, 1.9 MiB; binary file not shown)
docs/assets/with_sage_attention.webp | BIN (new file, 1.9 MiB; binary file not shown)