Mochi and SageAttention improvements

doctorpangloss 2024-11-18 15:40:15 -08:00
parent 264d84db39
commit 8ba412897e
14 changed files with 311 additions and 247 deletions

View File

@ -333,6 +333,42 @@ pip install git+https://github.com/AppMana/appmana-comfyui-nodes-controlnet-aux.
Start creating an AnimateDiff workflow. When using these packages, the appropriate models will download automatically.
## SageAttention
Improve the performance of Mochi video generation using **Sage Attention**:
| Device | PyTorch 2.5.1 | SageAttention | S.A. + TorchCompileModel |
|--------|---------------|---------------|--------------------------|
| A5000  | 7.52s/it      | 5.81s/it      | 5.00s/it (output corrupted) |
[Use the default Mochi workflow](https://github.com/comfyanonymous/ComfyUI_examples/raw/refs/heads/master/mochi/mochi_text_to_video_example.webp); it requires no custom nodes and no changes to your existing workflow.
Install the dependencies for Windows or Linux using the `withtriton` extra, or install the specific dependencies you need from [requirements-triton.txt](./requirements-triton.txt):
```shell
pip install "comfyui[withtriton]@git+https://github.com/hiddenswitch/ComfyUI.git"
```
If you have `xformers` installed, disable it, as it will be preferred over Sage Attention:
```shell
comfyui --disable-xformers
```
If you want to use **TorchCompileModel** to further improve performance, do not reserve VRAM:
```shell
comfyui --disable-xformers --reserve-vram=-1.0
```
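Before launching ComfyUI you can sanity-check that Sage Attention imports and runs on your GPU. This is a minimal smoke test, assuming the upstream `sageattention` package exposes `sageattn(q, k, v, tensor_layout=..., is_causal=...)` as its README describes; adjust it to the version you installed:
```python
import torch
from sageattention import sageattn  # upstream SageAttention package; adjust if your version differs

# Dummy attention inputs in (batch, heads, sequence, head_dim) layout, half precision on the GPU.
q = torch.randn(1, 24, 4096, 128, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

out = sageattn(q, k, v, tensor_layout="HND", is_causal=False)
print(out.shape, out.dtype)  # should match q
```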
Sage Attention is not compatible with Flux, and it does not appear to be compatible with Mochi when combined with `torch.compile` (see the corrupted **S.A. + TorchCompileModel** result in the table above).
![with_sage_attention.webp](./docs/assets/with_sage_attention.webp)
**With SageAttention**
![with_pytorch_attention](./docs/assets/with_pytorch_attention.webp)
**With PyTorch Attention**
# Custom Nodes
Custom Nodes can be added to ComfyUI by copying Python files into your `./custom_nodes` directory; a minimal node is sketched below.
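A custom node is an ordinary Python module that exposes its node classes through a `NODE_CLASS_MAPPINGS` dictionary. The sketch below is a hypothetical minimal node; the file, class, and display names are illustrative:
```python
# custom_nodes/example_invert.py (hypothetical file name)
class ExampleInvert:
    """Inverts an IMAGE tensor whose values are expected to be in [0, 1]."""

    @classmethod
    def INPUT_TYPES(cls):
        return {"required": {"image": ("IMAGE",)}}

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "invert"
    CATEGORY = "examples"

    def invert(self, image):
        # Nodes return a tuple matching RETURN_TYPES.
        return (1.0 - image,)


# ComfyUI scans ./custom_nodes at startup and registers everything listed here.
NODE_CLASS_MAPPINGS = {"ExampleInvert": ExampleInvert}
NODE_DISPLAY_NAME_MAPPINGS = {"ExampleInvert": "Example: Invert Image"}
```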

View File

@ -1,14 +1,11 @@
#original code from https://github.com/genmoai/models under apache 2.0 license
#adapted to ComfyUI
from typing import Dict, List, Optional, Tuple
# original code from https://github.com/genmoai/models under apache 2.0 license
# adapted to ComfyUI
from typing import Tuple, List, Dict, Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
# from flash_attn import flash_attn_varlen_qkvpacked_func
from comfy.ldm.modules.attention import optimized_attention
from .layers import (
FeedForward,
@ -16,7 +13,6 @@ from .layers import (
RMSNorm,
TimestepEmbedder,
)
from .rope_mixed import (
compute_mixed_rotation,
create_position_matrix,
@ -26,14 +22,15 @@ from .utils import (
AttentionPool,
modulate,
)
import comfy.ldm.common_dit
import comfy.ops
from ...common_dit import rms_norm
# from flash_attn import flash_attn_varlen_qkvpacked_func
from ...modules.attention import optimized_attention
from .... import ops
def modulated_rmsnorm(x, scale, eps=1e-6):
# Normalize and modulate
x_normed = comfy.ldm.common_dit.rms_norm(x, eps=eps)
x_normed = rms_norm(x, eps=eps)
x_modulated = x_normed * (1 + scale.unsqueeze(1))
return x_modulated
@ -44,29 +41,30 @@ def residual_tanh_gated_rmsnorm(x, x_res, gate, eps=1e-6):
tanh_gate = torch.tanh(gate).unsqueeze(1)
# Normalize and apply gated scaling
x_normed = comfy.ldm.common_dit.rms_norm(x_res, eps=eps) * tanh_gate
x_normed = rms_norm(x_res, eps=eps) * tanh_gate
# Apply residual connection
output = x + x_normed
return output
class AsymmetricAttention(nn.Module):
def __init__(
self,
dim_x: int,
dim_y: int,
num_heads: int = 8,
qkv_bias: bool = True,
qk_norm: bool = False,
attn_drop: float = 0.0,
update_y: bool = True,
out_bias: bool = True,
attend_to_padding: bool = False,
softmax_scale: Optional[float] = None,
device: Optional[torch.device] = None,
dtype=None,
operations=None,
self,
dim_x: int,
dim_y: int,
num_heads: int = 8,
qkv_bias: bool = True,
qk_norm: bool = False,
attn_drop: float = 0.0,
update_y: bool = True,
out_bias: bool = True,
attend_to_padding: bool = False,
softmax_scale: Optional[float] = None,
device: Optional[torch.device] = None,
dtype=None,
operations=None,
):
super().__init__()
self.dim_x = dim_x
@ -104,13 +102,13 @@ class AsymmetricAttention(nn.Module):
)
def forward(
self,
x: torch.Tensor, # (B, N, dim_x)
y: torch.Tensor, # (B, L, dim_y)
scale_x: torch.Tensor, # (B, dim_x), modulation for pre-RMSNorm.
scale_y: torch.Tensor, # (B, dim_y), modulation for pre-RMSNorm.
crop_y,
**rope_rotation,
self,
x: torch.Tensor, # (B, N, dim_x)
y: torch.Tensor, # (B, L, dim_y)
scale_x: torch.Tensor, # (B, dim_x), modulation for pre-RMSNorm.
scale_y: torch.Tensor, # (B, dim_y), modulation for pre-RMSNorm.
crop_y,
**rope_rotation,
) -> Tuple[torch.Tensor, torch.Tensor]:
rope_cos = rope_rotation.get("rope_cos")
rope_sin = rope_rotation.get("rope_sin")
@ -159,18 +157,18 @@ class AsymmetricAttention(nn.Module):
class AsymmetricJointBlock(nn.Module):
def __init__(
self,
hidden_size_x: int,
hidden_size_y: int,
num_heads: int,
*,
mlp_ratio_x: float = 8.0, # Ratio of hidden size to d_model for MLP for visual tokens.
mlp_ratio_y: float = 4.0, # Ratio of hidden size to d_model for MLP for text tokens.
update_y: bool = True, # Whether to update text tokens in this block.
device: Optional[torch.device] = None,
dtype=None,
operations=None,
**block_kwargs,
self,
hidden_size_x: int,
hidden_size_y: int,
num_heads: int,
*,
mlp_ratio_x: float = 8.0, # Ratio of hidden size to d_model for MLP for visual tokens.
mlp_ratio_y: float = 4.0, # Ratio of hidden size to d_model for MLP for text tokens.
update_y: bool = True, # Whether to update text tokens in this block.
device: Optional[torch.device] = None,
dtype=None,
operations=None,
**block_kwargs,
):
super().__init__()
self.update_y = update_y
@ -221,11 +219,11 @@ class AsymmetricJointBlock(nn.Module):
)
def forward(
self,
x: torch.Tensor,
c: torch.Tensor,
y: torch.Tensor,
**attn_kwargs,
self,
x: torch.Tensor,
c: torch.Tensor,
y: torch.Tensor,
**attn_kwargs,
):
"""Forward pass of a block.
@ -291,13 +289,13 @@ class FinalLayer(nn.Module):
"""
def __init__(
self,
hidden_size,
patch_size,
out_channels,
device: Optional[torch.device] = None,
dtype=None,
operations=None,
self,
hidden_size,
patch_size,
out_channels,
device: Optional[torch.device] = None,
dtype=None,
operations=None,
):
super().__init__()
self.norm_final = operations.LayerNorm(
@ -324,32 +322,32 @@ class AsymmDiTJoint(nn.Module):
"""
def __init__(
self,
*,
patch_size=2,
in_channels=4,
hidden_size_x=1152,
hidden_size_y=1152,
depth=48,
num_heads=16,
mlp_ratio_x=8.0,
mlp_ratio_y=4.0,
use_t5: bool = False,
t5_feat_dim: int = 4096,
t5_token_length: int = 256,
learn_sigma=True,
patch_embed_bias: bool = True,
timestep_mlp_bias: bool = True,
attend_to_padding: bool = False,
timestep_scale: Optional[float] = None,
use_extended_posenc: bool = False,
posenc_preserve_area: bool = False,
rope_theta: float = 10000.0,
image_model=None,
device: Optional[torch.device] = None,
dtype=None,
operations=None,
**block_kwargs,
self,
*,
patch_size=2,
in_channels=4,
hidden_size_x=1152,
hidden_size_y=1152,
depth=48,
num_heads=16,
mlp_ratio_x=8.0,
mlp_ratio_y=4.0,
use_t5: bool = False,
t5_feat_dim: int = 4096,
t5_token_length: int = 256,
learn_sigma=True,
patch_embed_bias: bool = True,
timestep_mlp_bias: bool = True,
attend_to_padding: bool = False,
timestep_scale: Optional[float] = None,
use_extended_posenc: bool = False,
posenc_preserve_area: bool = False,
rope_theta: float = 10000.0,
image_model=None,
device: Optional[torch.device] = None,
dtype=None,
operations=None,
**block_kwargs,
):
super().__init__()
@ -362,7 +360,7 @@ class AsymmDiTJoint(nn.Module):
self.hidden_size_x = hidden_size_x
self.hidden_size_y = hidden_size_y
self.head_dim = (
hidden_size_x // num_heads
hidden_size_x // num_heads
) # Head dimension and count is determined by visual.
self.attend_to_padding = attend_to_padding
self.use_extended_posenc = use_extended_posenc
@ -449,11 +447,11 @@ class AsymmDiTJoint(nn.Module):
return self.x_embedder(x) # Convert BcTHW to BCN
def prepare(
self,
x: torch.Tensor,
sigma: torch.Tensor,
t5_feat: torch.Tensor,
t5_mask: torch.Tensor,
self,
x: torch.Tensor,
sigma: torch.Tensor,
t5_feat: torch.Tensor,
t5_mask: torch.Tensor,
):
"""Prepare input and conditioning embeddings."""
# Visual patch embeddings with positional encoding.
@ -463,7 +461,6 @@ class AsymmDiTJoint(nn.Module):
assert x.ndim == 3
B = x.size(0)
pH, pW = H // self.patch_size, W // self.patch_size
N = T * pH * pW
assert x.size(1) == N
@ -471,7 +468,7 @@ class AsymmDiTJoint(nn.Module):
T, pH=pH, pW=pW, device=x.device, dtype=torch.float32
) # (N, 3)
rope_cos, rope_sin = compute_mixed_rotation(
freqs=comfy.ops.cast_to(self.pos_frequencies, dtype=x.dtype, device=x.device), pos=pos
freqs=ops.cast_to(self.pos_frequencies, dtype=x.dtype, device=x.device), pos=pos
) # Each are (N, num_heads, dim // 2)
c_t = self.t_embedder(1 - sigma, out_dtype=x.dtype) # (B, D)
@ -485,17 +482,19 @@ class AsymmDiTJoint(nn.Module):
return x, c, y_feat, rope_cos, rope_sin
def forward(
self,
x: torch.Tensor,
timestep: torch.Tensor,
context: List[torch.Tensor],
attention_mask: List[torch.Tensor],
num_tokens=256,
packed_indices: Dict[str, torch.Tensor] = None,
rope_cos: torch.Tensor = None,
rope_sin: torch.Tensor = None,
control=None, transformer_options={}, **kwargs
self,
x: torch.Tensor,
timestep: torch.Tensor,
context: List[torch.Tensor],
attention_mask: List[torch.Tensor],
num_tokens=256,
packed_indices: Dict[str, torch.Tensor] = None,
rope_cos: torch.Tensor = None,
rope_sin: torch.Tensor = None,
control=None, transformer_options=None, **kwargs
):
if transformer_options is None:
transformer_options = {}
patches_replace = transformer_options.get("patches_replace", {})
y_feat = context
y_mask = attention_mask
@ -522,14 +521,15 @@ class AsymmDiTJoint(nn.Module):
def block_wrap(args):
out = {}
out["img"], out["txt"] = block(
args["img"],
args["vec"],
args["txt"],
rope_cos=args["rope_cos"],
rope_sin=args["rope_sin"],
crop_y=args["num_tokens"]
)
args["img"],
args["vec"],
args["txt"],
rope_cos=args["rope_cos"],
rope_sin=args["rope_sin"],
crop_y=args["num_tokens"]
)
return out
out = blocks_replace[("double_block", i)]({"img": x, "txt": y_feat, "vec": c, "rope_cos": rope_cos, "rope_sin": rope_sin, "num_tokens": num_tokens}, {"original_block": block_wrap})
y_feat = out["txt"]
x = out["img"]

View File

@ -1,5 +1,5 @@
#original code from https://github.com/genmoai/models under apache 2.0 license
#adapted to ComfyUI
# original code from https://github.com/genmoai/models under apache 2.0 license
# adapted to ComfyUI
import collections.abc
import math
@ -10,7 +10,8 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
import comfy.ldm.common_dit
from ...common_dit import pad_to_patch_size, rms_norm
# From PyTorch internals
@ -28,15 +29,15 @@ to_2tuple = _ntuple(2)
class TimestepEmbedder(nn.Module):
def __init__(
self,
hidden_size: int,
frequency_embedding_size: int = 256,
*,
bias: bool = True,
timestep_scale: Optional[float] = None,
dtype=None,
device=None,
operations=None,
self,
hidden_size: int,
frequency_embedding_size: int = 256,
*,
bias: bool = True,
timestep_scale: Optional[float] = None,
dtype=None,
device=None,
operations=None,
):
super().__init__()
self.mlp = nn.Sequential(
@ -70,14 +71,14 @@ class TimestepEmbedder(nn.Module):
class FeedForward(nn.Module):
def __init__(
self,
in_features: int,
hidden_size: int,
multiple_of: int,
ffn_dim_multiplier: Optional[float],
device: Optional[torch.device] = None,
dtype=None,
operations=None,
self,
in_features: int,
hidden_size: int,
multiple_of: int,
ffn_dim_multiplier: Optional[float],
device: Optional[torch.device] = None,
dtype=None,
operations=None,
):
super().__init__()
# keep parameter count and computation constant compared to standard FFN
@ -99,17 +100,17 @@ class FeedForward(nn.Module):
class PatchEmbed(nn.Module):
def __init__(
self,
patch_size: int = 16,
in_chans: int = 3,
embed_dim: int = 768,
norm_layer: Optional[Callable] = None,
flatten: bool = True,
bias: bool = True,
dynamic_img_pad: bool = False,
dtype=None,
device=None,
operations=None,
self,
patch_size: int = 16,
in_chans: int = 3,
embed_dim: int = 768,
norm_layer: Optional[Callable] = None,
flatten: bool = True,
bias: bool = True,
dynamic_img_pad: bool = False,
dtype=None,
device=None,
operations=None,
):
super().__init__()
self.patch_size = to_2tuple(patch_size)
@ -141,7 +142,7 @@ class PatchEmbed(nn.Module):
x = F.pad(x, (0, pad_w, 0, pad_h))
x = rearrange(x, "B C T H W -> (B T) C H W", B=B, T=T)
x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size, padding_mode='circular')
x = pad_to_patch_size(x, self.patch_size, padding_mode='circular')
x = self.proj(x)
# Flatten temporal and spatial dimensions.
@ -161,4 +162,4 @@ class RMSNorm(torch.nn.Module):
self.register_parameter("bias", None)
def forward(self, x):
return comfy.ldm.common_dit.rms_norm(x, self.weight, self.eps)
return rms_norm(x, self.weight, self.eps)

View File

@ -59,8 +59,8 @@ def create_position_matrix(
# Stack and reshape the grids.
pos = torch.stack([grid_t, grid_h, grid_w], dim=-1) # [T, pH, pW, 3]
pos = pos.view(-1, 3) # [T * pH * pW, 3]
pos = pos.to(dtype=dtype, device=device)
pos = pos.view(-1, 3) # [T * pH * pW, 3]
return pos

View File

@ -1,19 +1,18 @@
#original code from https://github.com/genmoai/models under apache 2.0 license
#adapted to ComfyUI
# original code from https://github.com/genmoai/models under apache 2.0 license
# adapted to ComfyUI
from typing import Callable, List, Optional, Tuple, Union
from functools import partial
import math
from functools import partial
from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from comfy.ldm.modules.attention import optimized_attention
from ...modules.attention import optimized_attention
from ....ops import disable_weight_init as ops
import comfy.ops
ops = comfy.ops.disable_weight_init
# import mochi_preview.dit.joint_model.context_parallel as cp
# from mochi_preview.vae.cp_conv import cp_pass_frames, gather_all_frames
@ -34,19 +33,20 @@ class GroupNormSpatial(ops.GroupNorm):
# Run group norm in chunks.
output = torch.empty_like(x)
for b in range(0, B * T, chunk_size):
output[b : b + chunk_size] = super().forward(x[b : b + chunk_size])
output[b: b + chunk_size] = super().forward(x[b: b + chunk_size])
return rearrange(output, "(B T) C H W -> B C T H W", B=B, T=T)
class PConv3d(ops.Conv3d):
def __init__(
self,
in_channels,
out_channels,
kernel_size: Union[int, Tuple[int, int, int]],
stride: Union[int, Tuple[int, int, int]],
causal: bool = True,
context_parallel: bool = True,
**kwargs,
self,
in_channels,
out_channels,
kernel_size: Union[int, Tuple[int, int, int]],
stride: Union[int, Tuple[int, int, int]],
causal: bool = True,
context_parallel: bool = True,
**kwargs,
):
self.causal = causal
self.context_parallel = context_parallel
@ -105,9 +105,9 @@ class Conv1x1(ops.Linear):
class DepthToSpaceTime(nn.Module):
def __init__(
self,
temporal_expansion: int,
spatial_expansion: int,
self,
temporal_expansion: int,
spatial_expansion: int,
):
super().__init__()
self.temporal_expansion = temporal_expansion
@ -135,20 +135,20 @@ class DepthToSpaceTime(nn.Module):
)
# cp_rank, _ = cp.get_cp_rank_size()
if self.temporal_expansion > 1: # and cp_rank == 0:
if self.temporal_expansion > 1: # and cp_rank == 0:
# Drop the first self.temporal_expansion - 1 frames.
# This is because we always want the 3x3x3 conv filter to only apply
# to the first frame, and the first frame doesn't need to be repeated.
assert all(x.shape)
x = x[:, :, self.temporal_expansion - 1 :]
x = x[:, :, self.temporal_expansion - 1:]
assert all(x.shape)
return x
def norm_fn(
in_channels: int,
affine: bool = True,
in_channels: int,
affine: bool = True,
):
return GroupNormSpatial(affine=affine, num_groups=32, num_channels=in_channels)
@ -157,15 +157,15 @@ class ResBlock(nn.Module):
"""Residual block that preserves the spatial dimensions."""
def __init__(
self,
channels: int,
*,
affine: bool = True,
attn_block: Optional[nn.Module] = None,
causal: bool = True,
prune_bottleneck: bool = False,
padding_mode: str,
bias: bool = True,
self,
channels: int,
*,
affine: bool = True,
attn_block: Optional[nn.Module] = None,
causal: bool = True,
prune_bottleneck: bool = False,
padding_mode: str,
bias: bool = True,
):
super().__init__()
self.channels = channels
@ -214,12 +214,12 @@ class ResBlock(nn.Module):
class Attention(nn.Module):
def __init__(
self,
dim: int,
head_dim: int = 32,
qkv_bias: bool = False,
out_bias: bool = True,
qk_norm: bool = True,
self,
dim: int,
head_dim: int = 32,
qkv_bias: bool = False,
out_bias: bool = True,
qk_norm: bool = True,
) -> None:
super().__init__()
self.head_dim = head_dim
@ -230,8 +230,8 @@ class Attention(nn.Module):
self.out = nn.Linear(dim, dim, bias=out_bias)
def forward(
self,
x: torch.Tensor,
self,
x: torch.Tensor,
) -> torch.Tensor:
"""Compute temporal self-attention.
@ -275,9 +275,9 @@ class Attention(nn.Module):
class AttentionBlock(nn.Module):
def __init__(
self,
dim: int,
**attn_kwargs,
self,
dim: int,
**attn_kwargs,
) -> None:
super().__init__()
self.norm = norm_fn(dim)
@ -289,14 +289,14 @@ class AttentionBlock(nn.Module):
class CausalUpsampleBlock(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
num_res_blocks: int,
*,
temporal_expansion: int = 2,
spatial_expansion: int = 2,
**block_kwargs,
self,
in_channels: int,
out_channels: int,
num_res_blocks: int,
*,
temporal_expansion: int = 2,
spatial_expansion: int = 2,
**block_kwargs,
):
super().__init__()
@ -311,7 +311,7 @@ class CausalUpsampleBlock(nn.Module):
# Change channels in the final convolution layer.
self.proj = Conv1x1(
in_channels,
out_channels * temporal_expansion * (spatial_expansion**2),
out_channels * temporal_expansion * (spatial_expansion ** 2),
)
self.d2st = DepthToSpaceTime(
@ -332,14 +332,14 @@ def block_fn(channels, *, affine: bool = True, has_attention: bool = False, **bl
class DownsampleBlock(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
num_res_blocks,
*,
temporal_reduction=2,
spatial_reduction=2,
**block_kwargs,
self,
in_channels: int,
out_channels: int,
num_res_blocks,
*,
temporal_reduction=2,
spatial_reduction=2,
**block_kwargs,
):
"""
Downsample block for the VAE encoder.
@ -427,21 +427,21 @@ class FourierFeatures(nn.Module):
class Decoder(nn.Module):
def __init__(
self,
*,
out_channels: int = 3,
latent_dim: int,
base_channels: int,
channel_multipliers: List[int],
num_res_blocks: List[int],
temporal_expansions: Optional[List[int]] = None,
spatial_expansions: Optional[List[int]] = None,
has_attention: List[bool],
output_norm: bool = True,
nonlinearity: str = "silu",
output_nonlinearity: str = "silu",
causal: bool = True,
**block_kwargs,
self,
*,
out_channels: int = 3,
latent_dim: int,
base_channels: int,
channel_multipliers: List[int],
num_res_blocks: List[int],
temporal_expansions: Optional[List[int]] = None,
spatial_expansions: Optional[List[int]] = None,
has_attention: List[bool],
output_norm: bool = True,
nonlinearity: str = "silu",
output_nonlinearity: str = "silu",
causal: bool = True,
**block_kwargs,
):
super().__init__()
self.input_channels = latent_dim
@ -529,6 +529,7 @@ class Decoder(nn.Module):
return self.output_proj(x).contiguous()
class LatentDistribution:
def __init__(self, mean: torch.Tensor, logvar: torch.Tensor):
"""Initialize latent distribution.
@ -560,23 +561,24 @@ class LatentDistribution:
def mode(self):
return self.mean
class Encoder(nn.Module):
def __init__(
self,
*,
in_channels: int,
base_channels: int,
channel_multipliers: List[int],
num_res_blocks: List[int],
latent_dim: int,
temporal_reductions: List[int],
spatial_reductions: List[int],
prune_bottlenecks: List[bool],
has_attentions: List[bool],
affine: bool = True,
bias: bool = True,
input_is_conv_1x1: bool = False,
padding_mode: str,
self,
*,
in_channels: int,
base_channels: int,
channel_multipliers: List[int],
num_res_blocks: List[int],
latent_dim: int,
temporal_reductions: List[int],
spatial_reductions: List[int],
prune_bottlenecks: List[bool],
has_attentions: List[bool],
affine: bool = True,
bias: bool = True,
input_is_conv_1x1: bool = False,
padding_mode: str,
):
super().__init__()
self.temporal_reductions = temporal_reductions

View File

@ -103,8 +103,10 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
def __init__(self, device="cpu", max_length=77,
freeze=True, layer="last", layer_idx=None, textmodel_json_config: str | dict | None = None, dtype=None, model_class=clip_model.CLIPTextModel,
special_tokens=None, layer_norm_hidden_state=True, enable_attention_masks=False, zero_out_masked=False,
return_projected_pooled=True, return_attention_masks=False, model_options={}): # clip-vit-base-patch32
return_projected_pooled=True, return_attention_masks=False, model_options=None): # clip-vit-base-patch32
super().__init__()
if model_options is None:
model_options = {}
if special_tokens is None:
special_tokens = {"start": 49406, "end": 49407, "pad": 49407}
assert layer in self.LAYERS
@ -663,7 +665,9 @@ SD1TokenizerT = TypeVar("SD1TokenizerT", bound="SD1Tokenizer")
class SD1Tokenizer:
def __init__(self, embedding_directory=None, tokenizer_data={}, clip_name="l", tokenizer=SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data=None, clip_name="l", tokenizer=SDTokenizer):
if tokenizer_data is None:
tokenizer_data = {}
self.clip_name = clip_name
self.clip = "clip_{}".format(self.clip_name)
tokenizer = tokenizer_data.get("{}_tokenizer_class".format(self.clip), tokenizer)
@ -694,13 +698,18 @@ class SD1Tokenizer:
return {}
class SD1CheckpointClipModel(SDClipModel):
def __init__(self, device="cpu", dtype=None, model_options={}, textmodel_json_config=None):
def __init__(self, device="cpu", dtype=None, model_options=None, textmodel_json_config=None):
super().__init__(device=device, return_projected_pooled=False, dtype=dtype, model_options=model_options, textmodel_json_config=textmodel_json_config)
if model_options is None:
model_options = {}
class SD1ClipModel(torch.nn.Module):
def __init__(self, device="cpu", dtype=None, model_options={}, clip_name="l", clip_model=SD1CheckpointClipModel, textmodel_json_config=None, name=None, **kwargs):
def __init__(self, device="cpu", dtype=None, model_options=None, clip_name="l", clip_model=SD1CheckpointClipModel, textmodel_json_config=None, name=None, **kwargs):
super().__init__()
if model_options is None:
model_options = {}
if name is not None:
self.clip_name = name
self.clip = "{}".format(self.clip_name)
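A change that recurs in this file and the text-encoder files below replaces mutable default arguments such as `model_options={}` and `tokenizer_data={}` with `None` plus an explicit check inside the function. A standalone sketch of the pitfall this avoids (function and key names are illustrative):
```python
def configure_bad(options={}):  # one dict object is shared by every call
    options["calls"] = options.get("calls", 0) + 1
    return options


def configure_good(options=None):  # each call gets a fresh dict
    if options is None:
        options = {}
    options["calls"] = options.get("calls", 0) + 1
    return options


print(configure_bad()["calls"], configure_bad()["calls"])    # 1 2  (state leaks across calls)
print(configure_good()["calls"], configure_good()["calls"])  # 1 1
```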

View File

@ -697,7 +697,9 @@ class GenmoMochi(supported_models_base.BASE):
out = model_base.GenmoMochi(self, device=device)
return out
def clip_target(self, state_dict={}):
def clip_target(self, state_dict=None):
if state_dict is None:
state_dict = {}
pref = self.text_encoder_key_prefix[0]
t5_detect = sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
return supported_models_base.ClipTarget(genmo.MochiT5Tokenizer, genmo.mochi_te(**t5_detect))

View File

@ -3,6 +3,8 @@ import comfy.text_encoders.sd3_clip
import os
from transformers import T5TokenizerFast
from comfy.component_model import files
class T5XXLModel(comfy.text_encoders.sd3_clip.T5XXLModel):
def __init__(self, **kwargs):
@ -11,24 +13,32 @@ class T5XXLModel(comfy.text_encoders.sd3_clip.T5XXLModel):
class MochiT5XXL(sd1_clip.SD1ClipModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
def __init__(self, device="cpu", dtype=None, model_options=None):
if model_options is None:
model_options = {}
super().__init__(device=device, dtype=dtype, name="t5xxl", clip_model=T5XXLModel, model_options=model_options)
class T5XXLTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
def __init__(self, embedding_directory=None, tokenizer_data=None):
if tokenizer_data is None:
tokenizer_data = {}
tokenizer_path = files.get_package_as_path("comfy.text_encoders.t5_tokenizer")
super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=256)
class MochiT5Tokenizer(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
def __init__(self, embedding_directory=None, tokenizer_data=None):
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="t5xxl", tokenizer=T5XXLTokenizer)
if tokenizer_data is None:
tokenizer_data = {}
def mochi_te(dtype_t5=None, t5xxl_scaled_fp8=None):
class MochiTEModel_(MochiT5XXL):
def __init__(self, device="cpu", dtype=None, model_options={}):
def __init__(self, device="cpu", dtype=None, model_options=None):
if model_options is None:
model_options = {}
if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options:
model_options = model_options.copy()
model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8

View File

@ -20,7 +20,7 @@ class T5XXLModel(sd1_clip.SDClipModel):
model_options = model_options.copy()
model_options["scaled_fp8"] = t5xxl_scaled_fp8
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=T5, model_options=model_options)
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=T5, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
def t5_xxl_detect(state_dict, prefix=""):
@ -35,10 +35,11 @@ def t5_xxl_detect(state_dict, prefix=""):
return out
class T5XXLTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data=None):
if tokenizer_data is None:
tokenizer_data = dict()
tokenizer_data = {}
tokenizer_path = files.get_package_as_path("comfy.text_encoders.t5_tokenizer")
super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=77)
@ -46,7 +47,7 @@ class T5XXLTokenizer(sd1_clip.SDTokenizer):
class SD3Tokenizer:
def __init__(self, embedding_directory=None, tokenizer_data=None):
if tokenizer_data is None:
tokenizer_data = dict()
tokenizer_data = {}
clip_l_tokenizer_class = tokenizer_data.get("clip_l_tokenizer_class", sd1_clip.SDTokenizer)
self.clip_l = clip_l_tokenizer_class(embedding_directory=embedding_directory)
self.clip_g = sdxl_clip.SDXLClipGTokenizer(embedding_directory=embedding_directory)
@ -63,15 +64,17 @@ class SD3Tokenizer:
return self.clip_g.untokenize(token_weight_pair)
def state_dict(self):
return dict()
return {}
def clone(self):
return copy.copy(self)
class SD3ClipModel(torch.nn.Module):
def __init__(self, clip_l=True, clip_g=True, t5=True, dtype_t5=None, t5_attention_mask=False, device="cpu", dtype=None, model_options={}):
def __init__(self, clip_l=True, clip_g=True, t5=True, dtype_t5=None, t5_attention_mask=False, device="cpu", dtype=None, model_options=None):
super().__init__()
if model_options is None:
model_options = {}
self.dtypes = set()
if clip_l:
clip_l_class = model_options.get("clip_l_class", sd1_clip.SDClipModel)
@ -180,4 +183,5 @@ def sd3_clip(clip_l=True, clip_g=True, t5=True, dtype_t5=None, t5xxl_scaled_fp8=
model_options = model_options.copy()
model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8
super().__init__(clip_l=clip_l, clip_g=clip_g, t5=t5, dtype_t5=dtype_t5, t5_attention_mask=t5_attention_mask, device=device, dtype=dtype, model_options=model_options)
return SD3ClipModel_

Binary file not shown (image added, 1.9 MiB)

Binary file not shown (image added, 1.9 MiB)