Mochi and SageAttention improvements

doctorpangloss 2024-11-18 15:40:15 -08:00
parent 264d84db39
commit 8ba412897e
14 changed files with 311 additions and 247 deletions

View File

@ -333,6 +333,42 @@ pip install git+https://github.com/AppMana/appmana-comfyui-nodes-controlnet-aux.
Start creating an AnimateDiff workflow. When using these packages, the appropriate models will download automatically.
## SageAttention
Improve the performance of your Mochi video generation using **SageAttention**:

| Device | PyTorch 2.5.1 | SageAttention | SageAttention + TorchCompileModel |
|--------|---------------|---------------|-----------------------------------|
| A5000  | 7.52s/it      | 5.81s/it      | 5.00s/it (output corrupted)       |
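To put the iteration times in perspective, the arithmetic below converts them into total sampling time; the 30-step count is an assumption for illustration, not part of the benchmark.
```python
# Illustrative only: total sampling time implied by the table above,
# assuming a 30-step run (the step count is an assumption, not a measured value).
steps = 30
for backend, s_per_it in [("PyTorch 2.5.1", 7.52),
                          ("SageAttention", 5.81),
                          ("SageAttention + TorchCompileModel", 5.00)]:
    print(f"{backend}: {steps * s_per_it / 60:.1f} min for {steps} steps")
```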
[Use the default Mochi Workflow.](https://github.com/comfyanonymous/ComfyUI_examples/raw/refs/heads/master/mochi/mochi_text_to_video_example.webp) This does not require any custom nodes or any change to your workflow.
Install the dependencies for Windows or Linux using the `withtriton` extra, or install just the specific dependencies you need from [requirements-triton.txt](./requirements-triton.txt):
```shell
pip install "comfyui[withtriton]@git+https://github.com/hiddenswitch/ComfyUI.git"
```
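As a quick sanity check, you can confirm the two key packages are importable before launching; `triton` and `sageattention` are assumed here to be the modules the install above provides.
```python
# Sanity check: run this in the same environment you will launch comfyui from.
import importlib

for module in ("triton", "sageattention"):
    try:
        importlib.import_module(module)
        print(f"{module}: OK")
    except ImportError as exc:
        print(f"{module}: missing ({exc})")
```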
If you have `xformers` installed, disable it at launch, since it takes precedence over SageAttention:
```shell
comfyui --disable-xformers
```
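The flag matters because of the attention backend preference order. The sketch below is only a rough illustration of that ordering; the module checks and function name are assumptions for exposition, not ComfyUI's actual selection code.
```python
# Rough illustration of the backend preference the flag works around.
# NOT ComfyUI's real internals; names and ordering are assumptions.
def pick_attention_backend(xformers_enabled: bool) -> str:
    if xformers_enabled:
        try:
            import xformers  # noqa: F401  # if present and enabled, it wins
            return "xformers"
        except ImportError:
            pass
    try:
        import sageattention  # noqa: F401  # used once xformers is out of the way
        return "sage"
    except ImportError:
        return "pytorch"  # scaled_dot_product_attention fallback
```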
If you want to use **TorchCompileModel** to further improve performance, do not reserve VRAM:
```shell
comfyui --disable-xformers --reserve-vram=-1.0
```
SageAttention is not compatible with Flux, and it does not appear to be compatible with Mochi when using `torch.compile` (hence the corrupted output noted in the table above).
![with_sage_attention.webp](./docs/assets/with_sage_attention.webp)
**With SageAttention**
![with_pytorch_attention](./docs/assets/with_pytorch_attention.webp)
**With PyTorch Attention**
# Custom Nodes
Custom Nodes can be added to ComfyUI by copying and pasting Python files into your `./custom_nodes` directory.

View File

View File

View File

@ -1,14 +1,11 @@
#original code from https://github.com/genmoai/models under apache 2.0 license
# original code from https://github.com/genmoai/models under apache 2.0 license
#adapted to ComfyUI
# adapted to ComfyUI
from typing import Tuple, List, Dict, Optional
from typing import Dict, List, Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
# from flash_attn import flash_attn_varlen_qkvpacked_func
from comfy.ldm.modules.attention import optimized_attention
from .layers import (
FeedForward,
@ -16,7 +13,6 @@ from .layers import (
RMSNorm,
TimestepEmbedder,
)
from .rope_mixed import (
compute_mixed_rotation,
create_position_matrix,
@ -26,14 +22,15 @@ from .utils import (
AttentionPool,
modulate,
)
from ...common_dit import rms_norm
import comfy.ldm.common_dit
# from flash_attn import flash_attn_varlen_qkvpacked_func
import comfy.ops
from ...modules.attention import optimized_attention
from .... import ops
def modulated_rmsnorm(x, scale, eps=1e-6):
# Normalize and modulate
x_normed = comfy.ldm.common_dit.rms_norm(x, eps=eps)
x_normed = rms_norm(x, eps=eps)
x_modulated = x_normed * (1 + scale.unsqueeze(1))
return x_modulated
@ -44,29 +41,30 @@ def residual_tanh_gated_rmsnorm(x, x_res, gate, eps=1e-6):
tanh_gate = torch.tanh(gate).unsqueeze(1)
# Normalize and apply gated scaling
x_normed = comfy.ldm.common_dit.rms_norm(x_res, eps=eps) * tanh_gate
x_normed = rms_norm(x_res, eps=eps) * tanh_gate
# Apply residual connection
output = x + x_normed
return output
class AsymmetricAttention(nn.Module):
def __init__(
self,
dim_x: int,
dim_y: int,
num_heads: int = 8,
qkv_bias: bool = True,
qk_norm: bool = False,
attn_drop: float = 0.0,
update_y: bool = True,
out_bias: bool = True,
attend_to_padding: bool = False,
softmax_scale: Optional[float] = None,
device: Optional[torch.device] = None,
dtype=None,
operations=None,
):
super().__init__()
self.dim_x = dim_x
@ -104,13 +102,13 @@ class AsymmetricAttention(nn.Module):
)
def forward(
self,
x: torch.Tensor, # (B, N, dim_x)
y: torch.Tensor, # (B, L, dim_y)
scale_x: torch.Tensor, # (B, dim_x), modulation for pre-RMSNorm.
scale_y: torch.Tensor, # (B, dim_y), modulation for pre-RMSNorm.
crop_y,
**rope_rotation,
) -> Tuple[torch.Tensor, torch.Tensor]:
rope_cos = rope_rotation.get("rope_cos")
rope_sin = rope_rotation.get("rope_sin")
@ -159,18 +157,18 @@ class AsymmetricAttention(nn.Module):
class AsymmetricJointBlock(nn.Module):
def __init__(
self,
hidden_size_x: int,
hidden_size_y: int,
num_heads: int,
*,
mlp_ratio_x: float = 8.0, # Ratio of hidden size to d_model for MLP for visual tokens.
mlp_ratio_y: float = 4.0, # Ratio of hidden size to d_model for MLP for text tokens.
update_y: bool = True, # Whether to update text tokens in this block.
device: Optional[torch.device] = None,
dtype=None,
operations=None,
**block_kwargs,
):
super().__init__()
self.update_y = update_y
@ -221,11 +219,11 @@ class AsymmetricJointBlock(nn.Module):
)
def forward(
self,
x: torch.Tensor,
c: torch.Tensor,
y: torch.Tensor,
**attn_kwargs,
):
"""Forward pass of a block.
@ -291,13 +289,13 @@ class FinalLayer(nn.Module):
""" """
def __init__( def __init__(
self, self,
hidden_size, hidden_size,
patch_size, patch_size,
out_channels, out_channels,
device: Optional[torch.device] = None, device: Optional[torch.device] = None,
dtype=None, dtype=None,
operations=None, operations=None,
): ):
super().__init__() super().__init__()
self.norm_final = operations.LayerNorm( self.norm_final = operations.LayerNorm(
@ -324,32 +322,32 @@ class AsymmDiTJoint(nn.Module):
""" """
def __init__( def __init__(
self, self,
*, *,
patch_size=2, patch_size=2,
in_channels=4, in_channels=4,
hidden_size_x=1152, hidden_size_x=1152,
hidden_size_y=1152, hidden_size_y=1152,
depth=48, depth=48,
num_heads=16, num_heads=16,
mlp_ratio_x=8.0, mlp_ratio_x=8.0,
mlp_ratio_y=4.0, mlp_ratio_y=4.0,
use_t5: bool = False, use_t5: bool = False,
t5_feat_dim: int = 4096, t5_feat_dim: int = 4096,
t5_token_length: int = 256, t5_token_length: int = 256,
learn_sigma=True, learn_sigma=True,
patch_embed_bias: bool = True, patch_embed_bias: bool = True,
timestep_mlp_bias: bool = True, timestep_mlp_bias: bool = True,
attend_to_padding: bool = False, attend_to_padding: bool = False,
timestep_scale: Optional[float] = None, timestep_scale: Optional[float] = None,
use_extended_posenc: bool = False, use_extended_posenc: bool = False,
posenc_preserve_area: bool = False, posenc_preserve_area: bool = False,
rope_theta: float = 10000.0, rope_theta: float = 10000.0,
image_model=None, image_model=None,
device: Optional[torch.device] = None, device: Optional[torch.device] = None,
dtype=None, dtype=None,
operations=None, operations=None,
**block_kwargs, **block_kwargs,
): ):
super().__init__() super().__init__()
@ -362,7 +360,7 @@ class AsymmDiTJoint(nn.Module):
self.hidden_size_x = hidden_size_x
self.hidden_size_y = hidden_size_y
self.head_dim = (
hidden_size_x // num_heads
) # Head dimension and count is determined by visual.
self.attend_to_padding = attend_to_padding
self.use_extended_posenc = use_extended_posenc
@ -449,11 +447,11 @@ class AsymmDiTJoint(nn.Module):
return self.x_embedder(x) # Convert BcTHW to BCN
def prepare(
self,
x: torch.Tensor,
sigma: torch.Tensor,
t5_feat: torch.Tensor,
t5_mask: torch.Tensor,
):
"""Prepare input and conditioning embeddings."""
# Visual patch embeddings with positional encoding.
@ -463,7 +461,6 @@ class AsymmDiTJoint(nn.Module):
assert x.ndim == 3
B = x.size(0)
pH, pW = H // self.patch_size, W // self.patch_size
N = T * pH * pW
assert x.size(1) == N
@ -471,7 +468,7 @@ class AsymmDiTJoint(nn.Module):
T, pH=pH, pW=pW, device=x.device, dtype=torch.float32
) # (N, 3)
rope_cos, rope_sin = compute_mixed_rotation(
freqs=comfy.ops.cast_to(self.pos_frequencies, dtype=x.dtype, device=x.device), pos=pos
freqs=ops.cast_to(self.pos_frequencies, dtype=x.dtype, device=x.device), pos=pos
) # Each are (N, num_heads, dim // 2)
c_t = self.t_embedder(1 - sigma, out_dtype=x.dtype) # (B, D)
@ -485,17 +482,19 @@ class AsymmDiTJoint(nn.Module):
return x, c, y_feat, rope_cos, rope_sin
def forward(
self,
x: torch.Tensor,
timestep: torch.Tensor,
context: List[torch.Tensor],
attention_mask: List[torch.Tensor],
num_tokens=256,
packed_indices: Dict[str, torch.Tensor] = None,
rope_cos: torch.Tensor = None,
rope_sin: torch.Tensor = None,
control=None, transformer_options={}, **kwargs
control=None, transformer_options=None, **kwargs
):
if transformer_options is None:
transformer_options = {}
patches_replace = transformer_options.get("patches_replace", {})
y_feat = context
y_mask = attention_mask
@ -522,14 +521,15 @@ class AsymmDiTJoint(nn.Module):
def block_wrap(args):
out = {}
out["img"], out["txt"] = block(
args["img"],
args["vec"],
args["txt"],
rope_cos=args["rope_cos"],
rope_sin=args["rope_sin"],
crop_y=args["num_tokens"]
)
return out
out = blocks_replace[("double_block", i)]({"img": x, "txt": y_feat, "vec": c, "rope_cos": rope_cos, "rope_sin": rope_sin, "num_tokens": num_tokens}, {"original_block": block_wrap})
y_feat = out["txt"]
x = out["img"]

View File

@ -1,5 +1,5 @@
#original code from https://github.com/genmoai/models under apache 2.0 license
# original code from https://github.com/genmoai/models under apache 2.0 license
#adapted to ComfyUI
# adapted to ComfyUI
import collections.abc
import math
@ -10,7 +10,8 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
import comfy.ldm.common_dit
from ...common_dit import pad_to_patch_size, rms_norm
# From PyTorch internals
@ -28,15 +29,15 @@ to_2tuple = _ntuple(2)
class TimestepEmbedder(nn.Module):
def __init__(
self,
hidden_size: int,
frequency_embedding_size: int = 256,
*,
bias: bool = True,
timestep_scale: Optional[float] = None,
dtype=None,
device=None,
operations=None,
):
super().__init__()
self.mlp = nn.Sequential(
@ -70,14 +71,14 @@ class TimestepEmbedder(nn.Module):
class FeedForward(nn.Module):
def __init__(
self,
in_features: int,
hidden_size: int,
multiple_of: int,
ffn_dim_multiplier: Optional[float],
device: Optional[torch.device] = None,
dtype=None,
operations=None,
):
super().__init__()
# keep parameter count and computation constant compared to standard FFN
@ -99,17 +100,17 @@ class FeedForward(nn.Module):
class PatchEmbed(nn.Module):
def __init__(
self,
patch_size: int = 16,
in_chans: int = 3,
embed_dim: int = 768,
norm_layer: Optional[Callable] = None,
flatten: bool = True,
bias: bool = True,
dynamic_img_pad: bool = False,
dtype=None,
device=None,
operations=None,
):
super().__init__()
self.patch_size = to_2tuple(patch_size)
@ -141,7 +142,7 @@ class PatchEmbed(nn.Module):
x = F.pad(x, (0, pad_w, 0, pad_h))
x = rearrange(x, "B C T H W -> (B T) C H W", B=B, T=T)
x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size, padding_mode='circular')
x = pad_to_patch_size(x, self.patch_size, padding_mode='circular')
x = self.proj(x)
# Flatten temporal and spatial dimensions.
@ -161,4 +162,4 @@ class RMSNorm(torch.nn.Module):
self.register_parameter("bias", None)
def forward(self, x):
return comfy.ldm.common_dit.rms_norm(x, self.weight, self.eps)
return rms_norm(x, self.weight, self.eps)

View File

@ -59,8 +59,8 @@ def create_position_matrix(
# Stack and reshape the grids.
pos = torch.stack([grid_t, grid_h, grid_w], dim=-1) # [T, pH, pW, 3]
pos = pos.view(-1, 3) # [T * pH * pW, 3]
pos = pos.to(dtype=dtype, device=device)
pos = pos.view(-1, 3) # [T * pH * pW, 3]
return pos

View File

View File

@ -1,19 +1,18 @@
#original code from https://github.com/genmoai/models under apache 2.0 license
# original code from https://github.com/genmoai/models under apache 2.0 license
#adapted to ComfyUI
# adapted to ComfyUI
from typing import Callable, List, Optional, Tuple, Union
from functools import partial
import math
from functools import partial
from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from comfy.ldm.modules.attention import optimized_attention
from ...modules.attention import optimized_attention
from ....ops import disable_weight_init as ops
import comfy.ops
ops = comfy.ops.disable_weight_init
# import mochi_preview.dit.joint_model.context_parallel as cp
# from mochi_preview.vae.cp_conv import cp_pass_frames, gather_all_frames
@ -34,19 +33,20 @@ class GroupNormSpatial(ops.GroupNorm):
# Run group norm in chunks.
output = torch.empty_like(x)
for b in range(0, B * T, chunk_size):
output[b : b + chunk_size] = super().forward(x[b : b + chunk_size])
output[b: b + chunk_size] = super().forward(x[b: b + chunk_size])
return rearrange(output, "(B T) C H W -> B C T H W", B=B, T=T)
class PConv3d(ops.Conv3d):
def __init__(
self,
in_channels,
out_channels,
kernel_size: Union[int, Tuple[int, int, int]],
stride: Union[int, Tuple[int, int, int]],
causal: bool = True,
context_parallel: bool = True,
**kwargs,
):
self.causal = causal
self.context_parallel = context_parallel
@ -105,9 +105,9 @@ class Conv1x1(ops.Linear):
class DepthToSpaceTime(nn.Module):
def __init__(
self,
temporal_expansion: int,
spatial_expansion: int,
):
super().__init__()
self.temporal_expansion = temporal_expansion
@ -135,20 +135,20 @@ class DepthToSpaceTime(nn.Module):
)
# cp_rank, _ = cp.get_cp_rank_size()
if self.temporal_expansion > 1: # and cp_rank == 0:
# Drop the first self.temporal_expansion - 1 frames.
# This is because we always want the 3x3x3 conv filter to only apply
# to the first frame, and the first frame doesn't need to be repeated.
assert all(x.shape)
x = x[:, :, self.temporal_expansion - 1 :]
x = x[:, :, self.temporal_expansion - 1:]
assert all(x.shape)
return x
def norm_fn(
in_channels: int,
affine: bool = True,
):
return GroupNormSpatial(affine=affine, num_groups=32, num_channels=in_channels)
@ -157,15 +157,15 @@ class ResBlock(nn.Module):
"""Residual block that preserves the spatial dimensions.""" """Residual block that preserves the spatial dimensions."""
def __init__( def __init__(
self, self,
channels: int, channels: int,
*, *,
affine: bool = True, affine: bool = True,
attn_block: Optional[nn.Module] = None, attn_block: Optional[nn.Module] = None,
causal: bool = True, causal: bool = True,
prune_bottleneck: bool = False, prune_bottleneck: bool = False,
padding_mode: str, padding_mode: str,
bias: bool = True, bias: bool = True,
): ):
super().__init__() super().__init__()
self.channels = channels self.channels = channels
@ -214,12 +214,12 @@ class ResBlock(nn.Module):
class Attention(nn.Module):
def __init__(
self,
dim: int,
head_dim: int = 32,
qkv_bias: bool = False,
out_bias: bool = True,
qk_norm: bool = True,
) -> None:
super().__init__()
self.head_dim = head_dim
@ -230,8 +230,8 @@ class Attention(nn.Module):
self.out = nn.Linear(dim, dim, bias=out_bias)
def forward(
self,
x: torch.Tensor,
) -> torch.Tensor:
"""Compute temporal self-attention.
@ -275,9 +275,9 @@ class Attention(nn.Module):
class AttentionBlock(nn.Module):
def __init__(
self,
dim: int,
**attn_kwargs,
) -> None:
super().__init__()
self.norm = norm_fn(dim)
@ -289,14 +289,14 @@ class AttentionBlock(nn.Module):
class CausalUpsampleBlock(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
num_res_blocks: int,
*,
temporal_expansion: int = 2,
spatial_expansion: int = 2,
**block_kwargs,
):
super().__init__()
@ -311,7 +311,7 @@ class CausalUpsampleBlock(nn.Module):
# Change channels in the final convolution layer.
self.proj = Conv1x1(
in_channels,
out_channels * temporal_expansion * (spatial_expansion**2),
out_channels * temporal_expansion * (spatial_expansion ** 2),
)
self.d2st = DepthToSpaceTime(
@ -332,14 +332,14 @@ def block_fn(channels, *, affine: bool = True, has_attention: bool = False, **bl
class DownsampleBlock(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
num_res_blocks,
*,
temporal_reduction=2,
spatial_reduction=2,
**block_kwargs,
):
"""
Downsample block for the VAE encoder.
@ -427,21 +427,21 @@ class FourierFeatures(nn.Module):
class Decoder(nn.Module):
def __init__(
self,
*,
out_channels: int = 3,
latent_dim: int,
base_channels: int,
channel_multipliers: List[int],
num_res_blocks: List[int],
temporal_expansions: Optional[List[int]] = None,
spatial_expansions: Optional[List[int]] = None,
has_attention: List[bool],
output_norm: bool = True,
nonlinearity: str = "silu",
output_nonlinearity: str = "silu",
causal: bool = True,
**block_kwargs,
):
super().__init__()
self.input_channels = latent_dim
@ -529,6 +529,7 @@ class Decoder(nn.Module):
return self.output_proj(x).contiguous()
class LatentDistribution:
def __init__(self, mean: torch.Tensor, logvar: torch.Tensor):
"""Initialize latent distribution.
@ -560,23 +561,24 @@ class LatentDistribution:
def mode(self):
return self.mean
class Encoder(nn.Module):
def __init__(
self,
*,
in_channels: int,
base_channels: int,
channel_multipliers: List[int],
num_res_blocks: List[int],
latent_dim: int,
temporal_reductions: List[int],
spatial_reductions: List[int],
prune_bottlenecks: List[bool],
has_attentions: List[bool],
affine: bool = True,
bias: bool = True,
input_is_conv_1x1: bool = False,
padding_mode: str,
):
super().__init__()
self.temporal_reductions = temporal_reductions

View File

@ -103,8 +103,10 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
def __init__(self, device="cpu", max_length=77, def __init__(self, device="cpu", max_length=77,
freeze=True, layer="last", layer_idx=None, textmodel_json_config: str | dict | None = None, dtype=None, model_class=clip_model.CLIPTextModel, freeze=True, layer="last", layer_idx=None, textmodel_json_config: str | dict | None = None, dtype=None, model_class=clip_model.CLIPTextModel,
special_tokens=None, layer_norm_hidden_state=True, enable_attention_masks=False, zero_out_masked=False, special_tokens=None, layer_norm_hidden_state=True, enable_attention_masks=False, zero_out_masked=False,
return_projected_pooled=True, return_attention_masks=False, model_options={}): # clip-vit-base-patch32 return_projected_pooled=True, return_attention_masks=False, model_options=None): # clip-vit-base-patch32
super().__init__() super().__init__()
if model_options is None:
model_options = {}
if special_tokens is None: if special_tokens is None:
special_tokens = {"start": 49406, "end": 49407, "pad": 49407} special_tokens = {"start": 49406, "end": 49407, "pad": 49407}
assert layer in self.LAYERS assert layer in self.LAYERS
@ -663,7 +665,9 @@ SD1TokenizerT = TypeVar("SD1TokenizerT", bound="SD1Tokenizer")
class SD1Tokenizer:
def __init__(self, embedding_directory=None, tokenizer_data={}, clip_name="l", tokenizer=SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data=None, clip_name="l", tokenizer=SDTokenizer):
if tokenizer_data is None:
tokenizer_data = {}
self.clip_name = clip_name
self.clip = "clip_{}".format(self.clip_name)
tokenizer = tokenizer_data.get("{}_tokenizer_class".format(self.clip), tokenizer)
@ -694,13 +698,18 @@ class SD1Tokenizer:
return {}
class SD1CheckpointClipModel(SDClipModel):
def __init__(self, device="cpu", dtype=None, model_options={}, textmodel_json_config=None):
def __init__(self, device="cpu", dtype=None, model_options=None, textmodel_json_config=None):
super().__init__(device=device, return_projected_pooled=False, dtype=dtype, model_options=model_options, textmodel_json_config=textmodel_json_config)
if model_options is None:
model_options = {}
class SD1ClipModel(torch.nn.Module):
def __init__(self, device="cpu", dtype=None, model_options={}, clip_name="l", clip_model=SD1CheckpointClipModel, textmodel_json_config=None, name=None, **kwargs):
def __init__(self, device="cpu", dtype=None, model_options=None, clip_name="l", clip_model=SD1CheckpointClipModel, textmodel_json_config=None, name=None, **kwargs):
super().__init__()
if model_options is None:
model_options = {}
if name is not None:
self.clip_name = name
self.clip = "{}".format(self.clip_name)

View File

@ -697,7 +697,9 @@ class GenmoMochi(supported_models_base.BASE):
out = model_base.GenmoMochi(self, device=device)
return out
def clip_target(self, state_dict={}):
def clip_target(self, state_dict=None):
if state_dict is None:
state_dict = {}
pref = self.text_encoder_key_prefix[0]
t5_detect = sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
return supported_models_base.ClipTarget(genmo.MochiT5Tokenizer, genmo.mochi_te(**t5_detect))

View File

@ -3,6 +3,8 @@ import comfy.text_encoders.sd3_clip
import os
from transformers import T5TokenizerFast
from comfy.component_model import files
class T5XXLModel(comfy.text_encoders.sd3_clip.T5XXLModel):
def __init__(self, **kwargs):
@ -11,24 +13,32 @@ class T5XXLModel(comfy.text_encoders.sd3_clip.T5XXLModel):
class MochiT5XXL(sd1_clip.SD1ClipModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
def __init__(self, device="cpu", dtype=None, model_options=None):
if model_options is None:
model_options = {}
super().__init__(device=device, dtype=dtype, name="t5xxl", clip_model=T5XXLModel, model_options=model_options)
class T5XXLTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
def __init__(self, embedding_directory=None, tokenizer_data=None):
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
if tokenizer_data is None:
tokenizer_data = {}
tokenizer_path = files.get_package_as_path("comfy.text_encoders.t5_tokenizer")
super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=256)
class MochiT5Tokenizer(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
def __init__(self, embedding_directory=None, tokenizer_data=None):
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="t5xxl", tokenizer=T5XXLTokenizer)
if tokenizer_data is None:
tokenizer_data = {}
def mochi_te(dtype_t5=None, t5xxl_scaled_fp8=None):
class MochiTEModel_(MochiT5XXL):
def __init__(self, device="cpu", dtype=None, model_options={}):
def __init__(self, device="cpu", dtype=None, model_options=None):
if model_options is None:
model_options = {}
if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options:
model_options = model_options.copy()
model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8

View File

@ -20,7 +20,7 @@ class T5XXLModel(sd1_clip.SDClipModel):
model_options = model_options.copy()
model_options["scaled_fp8"] = t5xxl_scaled_fp8
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=T5, model_options=model_options)
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=T5, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
def t5_xxl_detect(state_dict, prefix=""):
@ -35,10 +35,11 @@ def t5_xxl_detect(state_dict, prefix=""):
return out
class T5XXLTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data=None):
if tokenizer_data is None:
tokenizer_data = dict()
tokenizer_data = {}
tokenizer_path = files.get_package_as_path("comfy.text_encoders.t5_tokenizer")
super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=77)
@ -46,7 +47,7 @@ class T5XXLTokenizer(sd1_clip.SDTokenizer):
class SD3Tokenizer:
def __init__(self, embedding_directory=None, tokenizer_data=None):
if tokenizer_data is None:
tokenizer_data = dict()
tokenizer_data = {}
clip_l_tokenizer_class = tokenizer_data.get("clip_l_tokenizer_class", sd1_clip.SDTokenizer)
self.clip_l = clip_l_tokenizer_class(embedding_directory=embedding_directory)
self.clip_g = sdxl_clip.SDXLClipGTokenizer(embedding_directory=embedding_directory)
@ -63,15 +64,17 @@ class SD3Tokenizer:
return self.clip_g.untokenize(token_weight_pair)
def state_dict(self):
return dict()
return {}
def clone(self):
return copy.copy(self)
class SD3ClipModel(torch.nn.Module):
def __init__(self, clip_l=True, clip_g=True, t5=True, dtype_t5=None, t5_attention_mask=False, device="cpu", dtype=None, model_options={}):
def __init__(self, clip_l=True, clip_g=True, t5=True, dtype_t5=None, t5_attention_mask=False, device="cpu", dtype=None, model_options=None):
super().__init__()
if model_options is None:
model_options = {}
self.dtypes = set()
if clip_l:
clip_l_class = model_options.get("clip_l_class", sd1_clip.SDClipModel)
@ -180,4 +183,5 @@ def sd3_clip(clip_l=True, clip_g=True, t5=True, dtype_t5=None, t5xxl_scaled_fp8=
model_options = model_options.copy()
model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8
super().__init__(clip_l=clip_l, clip_g=clip_g, t5=t5, dtype_t5=dtype_t5, t5_attention_mask=t5_attention_mask, device=device, dtype=dtype, model_options=model_options)
return SD3ClipModel_

Binary files not shown (two images added, 1.9 MiB each).