diff --git a/README.md b/README.md
index c3405c363..ff545ccb8 100644
--- a/README.md
+++ b/README.md
@@ -333,6 +333,42 @@ pip install git+https://github.com/AppMana/appmana-comfyui-nodes-controlnet-aux.
 
 Start creating an AnimateDiff workflow. When using these packages, the appropriate models will download automatically.
 
+## SageAttention
+
+Use **SageAttention** to speed up Mochi video generation:
+
+| Device | PyTorch 2.5.1 | SageAttention | SageAttention + TorchCompileModel |
+|--------|---------------|---------------|-----------------------------------|
+| A5000  | 7.52s/it      | 5.81s/it      | 5.00s/it (output corrupted)       |
+
+[Use the default Mochi workflow.](https://github.com/comfyanonymous/ComfyUI_examples/raw/refs/heads/master/mochi/mochi_text_to_video_example.webp) No custom nodes or workflow changes are required.
+
+On Windows or Linux, install the dependencies using the `withtriton` extra, or install only the specific dependencies you need from [requirements-triton.txt](./requirements-triton.txt):
+
+```shell
+pip install "comfyui[withtriton]@git+https://github.com/hiddenswitch/ComfyUI.git"
+```
+
+If you have `xformers` installed, disable it at launch, since it is otherwise preferred over SageAttention:
+
+```shell
+comfyui --disable-xformers
+```
+
+If you want to use **TorchCompileModel** to further improve performance, also disable VRAM reservation:
+
+```shell
+comfyui --disable-xformers --reserve-vram=-1.0
+```
+
+SageAttention is not compatible with Flux, and it does not appear to be compatible with Mochi when combined with `torch.compile`.
+
+![With SageAttention](./docs/assets/with_sage_attention.webp)
+**With SageAttention**
+
+![With PyTorch attention](./docs/assets/with_pytorch_attention.webp)
+**With PyTorch Attention**
+
 # Custom Nodes
 
 Custom Nodes can be added to ComfyUI by copying and pasting Python files into your `./custom_nodes` directory.
diff --git a/comfy/ldm/genmo/__init__.py b/comfy/ldm/genmo/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/comfy/ldm/genmo/joint_model/__init__.py b/comfy/ldm/genmo/joint_model/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/comfy/ldm/genmo/joint_model/asymm_models_joint.py b/comfy/ldm/genmo/joint_model/asymm_models_joint.py
index 45c938966..dae567896 100644
--- a/comfy/ldm/genmo/joint_model/asymm_models_joint.py
+++ b/comfy/ldm/genmo/joint_model/asymm_models_joint.py
@@ -1,14 +1,11 @@
-#original code from https://github.com/genmoai/models under apache 2.0 license
-#adapted to ComfyUI
-
-from typing import Dict, List, Optional, Tuple
+# original code from https://github.com/genmoai/models under apache 2.0 license
+# adapted to ComfyUI
+from typing import Tuple, List, Dict, Optional
 
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from einops import rearrange
-# from flash_attn import flash_attn_varlen_qkvpacked_func
-from comfy.ldm.modules.attention import optimized_attention
 
 from .layers import (
     FeedForward,
@@ -16,7 +13,6 @@ from .layers import (
     RMSNorm,
     TimestepEmbedder,
 )
-
 from .rope_mixed import (
     compute_mixed_rotation,
     create_position_matrix,
@@ -26,14 +22,15 @@ from .utils import (
     AttentionPool,
     modulate,
 )
-
-import comfy.ldm.common_dit
-import comfy.ops
+from ...common_dit import rms_norm
+# from flash_attn import flash_attn_varlen_qkvpacked_func
+from ...modules.attention import optimized_attention
+from .... 
import ops def modulated_rmsnorm(x, scale, eps=1e-6): # Normalize and modulate - x_normed = comfy.ldm.common_dit.rms_norm(x, eps=eps) + x_normed = rms_norm(x, eps=eps) x_modulated = x_normed * (1 + scale.unsqueeze(1)) return x_modulated @@ -44,29 +41,30 @@ def residual_tanh_gated_rmsnorm(x, x_res, gate, eps=1e-6): tanh_gate = torch.tanh(gate).unsqueeze(1) # Normalize and apply gated scaling - x_normed = comfy.ldm.common_dit.rms_norm(x_res, eps=eps) * tanh_gate + x_normed = rms_norm(x_res, eps=eps) * tanh_gate # Apply residual connection output = x + x_normed return output + class AsymmetricAttention(nn.Module): def __init__( - self, - dim_x: int, - dim_y: int, - num_heads: int = 8, - qkv_bias: bool = True, - qk_norm: bool = False, - attn_drop: float = 0.0, - update_y: bool = True, - out_bias: bool = True, - attend_to_padding: bool = False, - softmax_scale: Optional[float] = None, - device: Optional[torch.device] = None, - dtype=None, - operations=None, + self, + dim_x: int, + dim_y: int, + num_heads: int = 8, + qkv_bias: bool = True, + qk_norm: bool = False, + attn_drop: float = 0.0, + update_y: bool = True, + out_bias: bool = True, + attend_to_padding: bool = False, + softmax_scale: Optional[float] = None, + device: Optional[torch.device] = None, + dtype=None, + operations=None, ): super().__init__() self.dim_x = dim_x @@ -104,13 +102,13 @@ class AsymmetricAttention(nn.Module): ) def forward( - self, - x: torch.Tensor, # (B, N, dim_x) - y: torch.Tensor, # (B, L, dim_y) - scale_x: torch.Tensor, # (B, dim_x), modulation for pre-RMSNorm. - scale_y: torch.Tensor, # (B, dim_y), modulation for pre-RMSNorm. - crop_y, - **rope_rotation, + self, + x: torch.Tensor, # (B, N, dim_x) + y: torch.Tensor, # (B, L, dim_y) + scale_x: torch.Tensor, # (B, dim_x), modulation for pre-RMSNorm. + scale_y: torch.Tensor, # (B, dim_y), modulation for pre-RMSNorm. + crop_y, + **rope_rotation, ) -> Tuple[torch.Tensor, torch.Tensor]: rope_cos = rope_rotation.get("rope_cos") rope_sin = rope_rotation.get("rope_sin") @@ -159,18 +157,18 @@ class AsymmetricAttention(nn.Module): class AsymmetricJointBlock(nn.Module): def __init__( - self, - hidden_size_x: int, - hidden_size_y: int, - num_heads: int, - *, - mlp_ratio_x: float = 8.0, # Ratio of hidden size to d_model for MLP for visual tokens. - mlp_ratio_y: float = 4.0, # Ratio of hidden size to d_model for MLP for text tokens. - update_y: bool = True, # Whether to update text tokens in this block. - device: Optional[torch.device] = None, - dtype=None, - operations=None, - **block_kwargs, + self, + hidden_size_x: int, + hidden_size_y: int, + num_heads: int, + *, + mlp_ratio_x: float = 8.0, # Ratio of hidden size to d_model for MLP for visual tokens. + mlp_ratio_y: float = 4.0, # Ratio of hidden size to d_model for MLP for text tokens. + update_y: bool = True, # Whether to update text tokens in this block. + device: Optional[torch.device] = None, + dtype=None, + operations=None, + **block_kwargs, ): super().__init__() self.update_y = update_y @@ -221,11 +219,11 @@ class AsymmetricJointBlock(nn.Module): ) def forward( - self, - x: torch.Tensor, - c: torch.Tensor, - y: torch.Tensor, - **attn_kwargs, + self, + x: torch.Tensor, + c: torch.Tensor, + y: torch.Tensor, + **attn_kwargs, ): """Forward pass of a block. 
@@ -291,13 +289,13 @@ class FinalLayer(nn.Module): """ def __init__( - self, - hidden_size, - patch_size, - out_channels, - device: Optional[torch.device] = None, - dtype=None, - operations=None, + self, + hidden_size, + patch_size, + out_channels, + device: Optional[torch.device] = None, + dtype=None, + operations=None, ): super().__init__() self.norm_final = operations.LayerNorm( @@ -324,32 +322,32 @@ class AsymmDiTJoint(nn.Module): """ def __init__( - self, - *, - patch_size=2, - in_channels=4, - hidden_size_x=1152, - hidden_size_y=1152, - depth=48, - num_heads=16, - mlp_ratio_x=8.0, - mlp_ratio_y=4.0, - use_t5: bool = False, - t5_feat_dim: int = 4096, - t5_token_length: int = 256, - learn_sigma=True, - patch_embed_bias: bool = True, - timestep_mlp_bias: bool = True, - attend_to_padding: bool = False, - timestep_scale: Optional[float] = None, - use_extended_posenc: bool = False, - posenc_preserve_area: bool = False, - rope_theta: float = 10000.0, - image_model=None, - device: Optional[torch.device] = None, - dtype=None, - operations=None, - **block_kwargs, + self, + *, + patch_size=2, + in_channels=4, + hidden_size_x=1152, + hidden_size_y=1152, + depth=48, + num_heads=16, + mlp_ratio_x=8.0, + mlp_ratio_y=4.0, + use_t5: bool = False, + t5_feat_dim: int = 4096, + t5_token_length: int = 256, + learn_sigma=True, + patch_embed_bias: bool = True, + timestep_mlp_bias: bool = True, + attend_to_padding: bool = False, + timestep_scale: Optional[float] = None, + use_extended_posenc: bool = False, + posenc_preserve_area: bool = False, + rope_theta: float = 10000.0, + image_model=None, + device: Optional[torch.device] = None, + dtype=None, + operations=None, + **block_kwargs, ): super().__init__() @@ -362,7 +360,7 @@ class AsymmDiTJoint(nn.Module): self.hidden_size_x = hidden_size_x self.hidden_size_y = hidden_size_y self.head_dim = ( - hidden_size_x // num_heads + hidden_size_x // num_heads ) # Head dimension and count is determined by visual. self.attend_to_padding = attend_to_padding self.use_extended_posenc = use_extended_posenc @@ -449,11 +447,11 @@ class AsymmDiTJoint(nn.Module): return self.x_embedder(x) # Convert BcTHW to BCN def prepare( - self, - x: torch.Tensor, - sigma: torch.Tensor, - t5_feat: torch.Tensor, - t5_mask: torch.Tensor, + self, + x: torch.Tensor, + sigma: torch.Tensor, + t5_feat: torch.Tensor, + t5_mask: torch.Tensor, ): """Prepare input and conditioning embeddings.""" # Visual patch embeddings with positional encoding. 
@@ -463,7 +461,6 @@ class AsymmDiTJoint(nn.Module): assert x.ndim == 3 B = x.size(0) - pH, pW = H // self.patch_size, W // self.patch_size N = T * pH * pW assert x.size(1) == N @@ -471,7 +468,7 @@ class AsymmDiTJoint(nn.Module): T, pH=pH, pW=pW, device=x.device, dtype=torch.float32 ) # (N, 3) rope_cos, rope_sin = compute_mixed_rotation( - freqs=comfy.ops.cast_to(self.pos_frequencies, dtype=x.dtype, device=x.device), pos=pos + freqs=ops.cast_to(self.pos_frequencies, dtype=x.dtype, device=x.device), pos=pos ) # Each are (N, num_heads, dim // 2) c_t = self.t_embedder(1 - sigma, out_dtype=x.dtype) # (B, D) @@ -485,17 +482,19 @@ class AsymmDiTJoint(nn.Module): return x, c, y_feat, rope_cos, rope_sin def forward( - self, - x: torch.Tensor, - timestep: torch.Tensor, - context: List[torch.Tensor], - attention_mask: List[torch.Tensor], - num_tokens=256, - packed_indices: Dict[str, torch.Tensor] = None, - rope_cos: torch.Tensor = None, - rope_sin: torch.Tensor = None, - control=None, transformer_options={}, **kwargs + self, + x: torch.Tensor, + timestep: torch.Tensor, + context: List[torch.Tensor], + attention_mask: List[torch.Tensor], + num_tokens=256, + packed_indices: Dict[str, torch.Tensor] = None, + rope_cos: torch.Tensor = None, + rope_sin: torch.Tensor = None, + control=None, transformer_options=None, **kwargs ): + if transformer_options is None: + transformer_options = {} patches_replace = transformer_options.get("patches_replace", {}) y_feat = context y_mask = attention_mask @@ -522,14 +521,15 @@ class AsymmDiTJoint(nn.Module): def block_wrap(args): out = {} out["img"], out["txt"] = block( - args["img"], - args["vec"], - args["txt"], - rope_cos=args["rope_cos"], - rope_sin=args["rope_sin"], - crop_y=args["num_tokens"] - ) + args["img"], + args["vec"], + args["txt"], + rope_cos=args["rope_cos"], + rope_sin=args["rope_sin"], + crop_y=args["num_tokens"] + ) return out + out = blocks_replace[("double_block", i)]({"img": x, "txt": y_feat, "vec": c, "rope_cos": rope_cos, "rope_sin": rope_sin, "num_tokens": num_tokens}, {"original_block": block_wrap}) y_feat = out["txt"] x = out["img"] diff --git a/comfy/ldm/genmo/joint_model/layers.py b/comfy/ldm/genmo/joint_model/layers.py index 51d979559..3232a0df7 100644 --- a/comfy/ldm/genmo/joint_model/layers.py +++ b/comfy/ldm/genmo/joint_model/layers.py @@ -1,5 +1,5 @@ -#original code from https://github.com/genmoai/models under apache 2.0 license -#adapted to ComfyUI +# original code from https://github.com/genmoai/models under apache 2.0 license +# adapted to ComfyUI import collections.abc import math @@ -10,7 +10,8 @@ import torch import torch.nn as nn import torch.nn.functional as F from einops import rearrange -import comfy.ldm.common_dit + +from ...common_dit import pad_to_patch_size, rms_norm # From PyTorch internals @@ -28,15 +29,15 @@ to_2tuple = _ntuple(2) class TimestepEmbedder(nn.Module): def __init__( - self, - hidden_size: int, - frequency_embedding_size: int = 256, - *, - bias: bool = True, - timestep_scale: Optional[float] = None, - dtype=None, - device=None, - operations=None, + self, + hidden_size: int, + frequency_embedding_size: int = 256, + *, + bias: bool = True, + timestep_scale: Optional[float] = None, + dtype=None, + device=None, + operations=None, ): super().__init__() self.mlp = nn.Sequential( @@ -70,14 +71,14 @@ class TimestepEmbedder(nn.Module): class FeedForward(nn.Module): def __init__( - self, - in_features: int, - hidden_size: int, - multiple_of: int, - ffn_dim_multiplier: Optional[float], - device: Optional[torch.device] = 
None, - dtype=None, - operations=None, + self, + in_features: int, + hidden_size: int, + multiple_of: int, + ffn_dim_multiplier: Optional[float], + device: Optional[torch.device] = None, + dtype=None, + operations=None, ): super().__init__() # keep parameter count and computation constant compared to standard FFN @@ -99,17 +100,17 @@ class FeedForward(nn.Module): class PatchEmbed(nn.Module): def __init__( - self, - patch_size: int = 16, - in_chans: int = 3, - embed_dim: int = 768, - norm_layer: Optional[Callable] = None, - flatten: bool = True, - bias: bool = True, - dynamic_img_pad: bool = False, - dtype=None, - device=None, - operations=None, + self, + patch_size: int = 16, + in_chans: int = 3, + embed_dim: int = 768, + norm_layer: Optional[Callable] = None, + flatten: bool = True, + bias: bool = True, + dynamic_img_pad: bool = False, + dtype=None, + device=None, + operations=None, ): super().__init__() self.patch_size = to_2tuple(patch_size) @@ -141,7 +142,7 @@ class PatchEmbed(nn.Module): x = F.pad(x, (0, pad_w, 0, pad_h)) x = rearrange(x, "B C T H W -> (B T) C H W", B=B, T=T) - x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size, padding_mode='circular') + x = pad_to_patch_size(x, self.patch_size, padding_mode='circular') x = self.proj(x) # Flatten temporal and spatial dimensions. @@ -161,4 +162,4 @@ class RMSNorm(torch.nn.Module): self.register_parameter("bias", None) def forward(self, x): - return comfy.ldm.common_dit.rms_norm(x, self.weight, self.eps) + return rms_norm(x, self.weight, self.eps) diff --git a/comfy/ldm/genmo/joint_model/rope_mixed.py b/comfy/ldm/genmo/joint_model/rope_mixed.py index dee3fa21f..48d01c20d 100644 --- a/comfy/ldm/genmo/joint_model/rope_mixed.py +++ b/comfy/ldm/genmo/joint_model/rope_mixed.py @@ -59,8 +59,8 @@ def create_position_matrix( # Stack and reshape the grids. pos = torch.stack([grid_t, grid_h, grid_w], dim=-1) # [T, pH, pW, 3] - pos = pos.view(-1, 3) # [T * pH * pW, 3] pos = pos.to(dtype=dtype, device=device) + pos = pos.view(-1, 3) # [T * pH * pW, 3] return pos diff --git a/comfy/ldm/genmo/vae/__init__.py b/comfy/ldm/genmo/vae/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/comfy/ldm/genmo/vae/model.py b/comfy/ldm/genmo/vae/model.py index b68d48ae5..c3bccdd5a 100644 --- a/comfy/ldm/genmo/vae/model.py +++ b/comfy/ldm/genmo/vae/model.py @@ -1,19 +1,18 @@ -#original code from https://github.com/genmoai/models under apache 2.0 license -#adapted to ComfyUI +# original code from https://github.com/genmoai/models under apache 2.0 license +# adapted to ComfyUI -from typing import Callable, List, Optional, Tuple, Union -from functools import partial import math +from functools import partial +from typing import List, Optional, Tuple, Union import torch import torch.nn as nn import torch.nn.functional as F from einops import rearrange -from comfy.ldm.modules.attention import optimized_attention +from ...modules.attention import optimized_attention +from ....ops import disable_weight_init as ops -import comfy.ops -ops = comfy.ops.disable_weight_init # import mochi_preview.dit.joint_model.context_parallel as cp # from mochi_preview.vae.cp_conv import cp_pass_frames, gather_all_frames @@ -34,19 +33,20 @@ class GroupNormSpatial(ops.GroupNorm): # Run group norm in chunks. 
output = torch.empty_like(x) for b in range(0, B * T, chunk_size): - output[b : b + chunk_size] = super().forward(x[b : b + chunk_size]) + output[b: b + chunk_size] = super().forward(x[b: b + chunk_size]) return rearrange(output, "(B T) C H W -> B C T H W", B=B, T=T) + class PConv3d(ops.Conv3d): def __init__( - self, - in_channels, - out_channels, - kernel_size: Union[int, Tuple[int, int, int]], - stride: Union[int, Tuple[int, int, int]], - causal: bool = True, - context_parallel: bool = True, - **kwargs, + self, + in_channels, + out_channels, + kernel_size: Union[int, Tuple[int, int, int]], + stride: Union[int, Tuple[int, int, int]], + causal: bool = True, + context_parallel: bool = True, + **kwargs, ): self.causal = causal self.context_parallel = context_parallel @@ -105,9 +105,9 @@ class Conv1x1(ops.Linear): class DepthToSpaceTime(nn.Module): def __init__( - self, - temporal_expansion: int, - spatial_expansion: int, + self, + temporal_expansion: int, + spatial_expansion: int, ): super().__init__() self.temporal_expansion = temporal_expansion @@ -135,20 +135,20 @@ class DepthToSpaceTime(nn.Module): ) # cp_rank, _ = cp.get_cp_rank_size() - if self.temporal_expansion > 1: # and cp_rank == 0: + if self.temporal_expansion > 1: # and cp_rank == 0: # Drop the first self.temporal_expansion - 1 frames. # This is because we always want the 3x3x3 conv filter to only apply # to the first frame, and the first frame doesn't need to be repeated. assert all(x.shape) - x = x[:, :, self.temporal_expansion - 1 :] + x = x[:, :, self.temporal_expansion - 1:] assert all(x.shape) return x def norm_fn( - in_channels: int, - affine: bool = True, + in_channels: int, + affine: bool = True, ): return GroupNormSpatial(affine=affine, num_groups=32, num_channels=in_channels) @@ -157,15 +157,15 @@ class ResBlock(nn.Module): """Residual block that preserves the spatial dimensions.""" def __init__( - self, - channels: int, - *, - affine: bool = True, - attn_block: Optional[nn.Module] = None, - causal: bool = True, - prune_bottleneck: bool = False, - padding_mode: str, - bias: bool = True, + self, + channels: int, + *, + affine: bool = True, + attn_block: Optional[nn.Module] = None, + causal: bool = True, + prune_bottleneck: bool = False, + padding_mode: str, + bias: bool = True, ): super().__init__() self.channels = channels @@ -214,12 +214,12 @@ class ResBlock(nn.Module): class Attention(nn.Module): def __init__( - self, - dim: int, - head_dim: int = 32, - qkv_bias: bool = False, - out_bias: bool = True, - qk_norm: bool = True, + self, + dim: int, + head_dim: int = 32, + qkv_bias: bool = False, + out_bias: bool = True, + qk_norm: bool = True, ) -> None: super().__init__() self.head_dim = head_dim @@ -230,8 +230,8 @@ class Attention(nn.Module): self.out = nn.Linear(dim, dim, bias=out_bias) def forward( - self, - x: torch.Tensor, + self, + x: torch.Tensor, ) -> torch.Tensor: """Compute temporal self-attention. 
@@ -275,9 +275,9 @@ class Attention(nn.Module): class AttentionBlock(nn.Module): def __init__( - self, - dim: int, - **attn_kwargs, + self, + dim: int, + **attn_kwargs, ) -> None: super().__init__() self.norm = norm_fn(dim) @@ -289,14 +289,14 @@ class AttentionBlock(nn.Module): class CausalUpsampleBlock(nn.Module): def __init__( - self, - in_channels: int, - out_channels: int, - num_res_blocks: int, - *, - temporal_expansion: int = 2, - spatial_expansion: int = 2, - **block_kwargs, + self, + in_channels: int, + out_channels: int, + num_res_blocks: int, + *, + temporal_expansion: int = 2, + spatial_expansion: int = 2, + **block_kwargs, ): super().__init__() @@ -311,7 +311,7 @@ class CausalUpsampleBlock(nn.Module): # Change channels in the final convolution layer. self.proj = Conv1x1( in_channels, - out_channels * temporal_expansion * (spatial_expansion**2), + out_channels * temporal_expansion * (spatial_expansion ** 2), ) self.d2st = DepthToSpaceTime( @@ -332,14 +332,14 @@ def block_fn(channels, *, affine: bool = True, has_attention: bool = False, **bl class DownsampleBlock(nn.Module): def __init__( - self, - in_channels: int, - out_channels: int, - num_res_blocks, - *, - temporal_reduction=2, - spatial_reduction=2, - **block_kwargs, + self, + in_channels: int, + out_channels: int, + num_res_blocks, + *, + temporal_reduction=2, + spatial_reduction=2, + **block_kwargs, ): """ Downsample block for the VAE encoder. @@ -427,21 +427,21 @@ class FourierFeatures(nn.Module): class Decoder(nn.Module): def __init__( - self, - *, - out_channels: int = 3, - latent_dim: int, - base_channels: int, - channel_multipliers: List[int], - num_res_blocks: List[int], - temporal_expansions: Optional[List[int]] = None, - spatial_expansions: Optional[List[int]] = None, - has_attention: List[bool], - output_norm: bool = True, - nonlinearity: str = "silu", - output_nonlinearity: str = "silu", - causal: bool = True, - **block_kwargs, + self, + *, + out_channels: int = 3, + latent_dim: int, + base_channels: int, + channel_multipliers: List[int], + num_res_blocks: List[int], + temporal_expansions: Optional[List[int]] = None, + spatial_expansions: Optional[List[int]] = None, + has_attention: List[bool], + output_norm: bool = True, + nonlinearity: str = "silu", + output_nonlinearity: str = "silu", + causal: bool = True, + **block_kwargs, ): super().__init__() self.input_channels = latent_dim @@ -529,6 +529,7 @@ class Decoder(nn.Module): return self.output_proj(x).contiguous() + class LatentDistribution: def __init__(self, mean: torch.Tensor, logvar: torch.Tensor): """Initialize latent distribution. 
@@ -560,23 +561,24 @@ class LatentDistribution:
     def mode(self):
         return self.mean
 
+
 class Encoder(nn.Module):
     def __init__(
-        self,
-        *,
-        in_channels: int,
-        base_channels: int,
-        channel_multipliers: List[int],
-        num_res_blocks: List[int],
-        latent_dim: int,
-        temporal_reductions: List[int],
-        spatial_reductions: List[int],
-        prune_bottlenecks: List[bool],
-        has_attentions: List[bool],
-        affine: bool = True,
-        bias: bool = True,
-        input_is_conv_1x1: bool = False,
-        padding_mode: str,
+            self,
+            *,
+            in_channels: int,
+            base_channels: int,
+            channel_multipliers: List[int],
+            num_res_blocks: List[int],
+            latent_dim: int,
+            temporal_reductions: List[int],
+            spatial_reductions: List[int],
+            prune_bottlenecks: List[bool],
+            has_attentions: List[bool],
+            affine: bool = True,
+            bias: bool = True,
+            input_is_conv_1x1: bool = False,
+            padding_mode: str,
     ):
         super().__init__()
         self.temporal_reductions = temporal_reductions
diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py
index 3985e6bf4..6e4ae493f 100644
--- a/comfy/sd1_clip.py
+++ b/comfy/sd1_clip.py
@@ -103,8 +103,10 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
     def __init__(self, device="cpu", max_length=77,
                  freeze=True, layer="last", layer_idx=None, textmodel_json_config: str | dict | None = None, dtype=None, model_class=clip_model.CLIPTextModel,
                  special_tokens=None, layer_norm_hidden_state=True, enable_attention_masks=False, zero_out_masked=False,
-                 return_projected_pooled=True, return_attention_masks=False, model_options={}):  # clip-vit-base-patch32
+                 return_projected_pooled=True, return_attention_masks=False, model_options=None):  # clip-vit-base-patch32
         super().__init__()
+        if model_options is None:
+            model_options = {}
         if special_tokens is None:
             special_tokens = {"start": 49406, "end": 49407, "pad": 49407}
         assert layer in self.LAYERS
@@ -663,7 +665,9 @@ SD1TokenizerT = TypeVar("SD1TokenizerT", bound="SD1Tokenizer")
 
 
 class SD1Tokenizer:
-    def __init__(self, embedding_directory=None, tokenizer_data={}, clip_name="l", tokenizer=SDTokenizer):
+    def __init__(self, embedding_directory=None, tokenizer_data=None, clip_name="l", tokenizer=SDTokenizer):
+        if tokenizer_data is None:
+            tokenizer_data = {}
         self.clip_name = clip_name
         self.clip = "clip_{}".format(self.clip_name)
         tokenizer = tokenizer_data.get("{}_tokenizer_class".format(self.clip), tokenizer)
@@ -694,13 +698,18 @@ class SD1Tokenizer:
         return {}
 
 class SD1CheckpointClipModel(SDClipModel):
-    def __init__(self, device="cpu", dtype=None, model_options={}, textmodel_json_config=None):
+    def __init__(self, device="cpu", dtype=None, model_options=None, textmodel_json_config=None):
+        if model_options is None:
+            model_options = {}
         super().__init__(device=device, return_projected_pooled=False, dtype=dtype, model_options=model_options, textmodel_json_config=textmodel_json_config)
+
 
 
 class SD1ClipModel(torch.nn.Module):
-    def __init__(self, device="cpu", dtype=None, model_options={}, clip_name="l", clip_model=SD1CheckpointClipModel, textmodel_json_config=None, name=None, **kwargs):
+    def __init__(self, device="cpu", dtype=None, model_options=None, clip_name="l", clip_model=SD1CheckpointClipModel, textmodel_json_config=None, name=None, **kwargs):
         super().__init__()
+        if model_options is None:
+            model_options = {}
         if name is not None:
             self.clip_name = name
             self.clip = "{}".format(self.clip_name)
diff --git a/comfy/supported_models.py b/comfy/supported_models.py
index 2d0bc2a70..597ef5e86 100644
--- a/comfy/supported_models.py
+++ b/comfy/supported_models.py
@@ -697,7 +697,9 @@ class 
GenmoMochi(supported_models_base.BASE):
         out = model_base.GenmoMochi(self, device=device)
         return out
 
-    def clip_target(self, state_dict={}):
+    def clip_target(self, state_dict=None):
+        if state_dict is None:
+            state_dict = {}
         pref = self.text_encoder_key_prefix[0]
         t5_detect = sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
         return supported_models_base.ClipTarget(genmo.MochiT5Tokenizer, genmo.mochi_te(**t5_detect))
diff --git a/comfy/text_encoders/genmo.py b/comfy/text_encoders/genmo.py
index 45987a480..bd992093e 100644
--- a/comfy/text_encoders/genmo.py
+++ b/comfy/text_encoders/genmo.py
@@ -3,6 +3,8 @@ import comfy.text_encoders.sd3_clip
 import os
 
 from transformers import T5TokenizerFast
+from comfy.component_model import files
+
 
 class T5XXLModel(comfy.text_encoders.sd3_clip.T5XXLModel):
     def __init__(self, **kwargs):
@@ -11,24 +13,32 @@ class T5XXLModel(comfy.text_encoders.sd3_clip.T5XXLModel):
 
 
 class MochiT5XXL(sd1_clip.SD1ClipModel):
-    def __init__(self, device="cpu", dtype=None, model_options={}):
+    def __init__(self, device="cpu", dtype=None, model_options=None):
+        if model_options is None:
+            model_options = {}
         super().__init__(device=device, dtype=dtype, name="t5xxl", clip_model=T5XXLModel, model_options=model_options)
 
 
 class T5XXLTokenizer(sd1_clip.SDTokenizer):
-    def __init__(self, embedding_directory=None, tokenizer_data={}):
-        tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
+    def __init__(self, embedding_directory=None, tokenizer_data=None):
+        if tokenizer_data is None:
+            tokenizer_data = {}
+        tokenizer_path = files.get_package_as_path("comfy.text_encoders.t5_tokenizer")
         super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=256)
 
 
 class MochiT5Tokenizer(sd1_clip.SD1Tokenizer):
-    def __init__(self, embedding_directory=None, tokenizer_data={}):
+    def __init__(self, embedding_directory=None, tokenizer_data=None):
+        if tokenizer_data is None:
+            tokenizer_data = {}
         super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="t5xxl", tokenizer=T5XXLTokenizer)
 
 
 def mochi_te(dtype_t5=None, t5xxl_scaled_fp8=None):
     class MochiTEModel_(MochiT5XXL):
-        def __init__(self, device="cpu", dtype=None, model_options={}):
+        def __init__(self, device="cpu", dtype=None, model_options=None):
+            if model_options is None:
+                model_options = {}
             if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options:
                 model_options = model_options.copy()
                 model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8
diff --git a/comfy/text_encoders/sd3_clip.py b/comfy/text_encoders/sd3_clip.py
index a2d048212..4080744a1 100644
--- a/comfy/text_encoders/sd3_clip.py
+++ b/comfy/text_encoders/sd3_clip.py
@@ -20,7 +20,7 @@ class T5XXLModel(sd1_clip.SDClipModel):
             model_options = model_options.copy()
             model_options["scaled_fp8"] = t5xxl_scaled_fp8
 
-        super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=T5, model_options=model_options)
+        super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=T5, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, 
model_options=model_options) def t5_xxl_detect(state_dict, prefix=""): @@ -35,10 +35,11 @@ def t5_xxl_detect(state_dict, prefix=""): return out + class T5XXLTokenizer(sd1_clip.SDTokenizer): def __init__(self, embedding_directory=None, tokenizer_data=None): if tokenizer_data is None: - tokenizer_data = dict() + tokenizer_data = {} tokenizer_path = files.get_package_as_path("comfy.text_encoders.t5_tokenizer") super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=77) @@ -46,7 +47,7 @@ class T5XXLTokenizer(sd1_clip.SDTokenizer): class SD3Tokenizer: def __init__(self, embedding_directory=None, tokenizer_data=None): if tokenizer_data is None: - tokenizer_data = dict() + tokenizer_data = {} clip_l_tokenizer_class = tokenizer_data.get("clip_l_tokenizer_class", sd1_clip.SDTokenizer) self.clip_l = clip_l_tokenizer_class(embedding_directory=embedding_directory) self.clip_g = sdxl_clip.SDXLClipGTokenizer(embedding_directory=embedding_directory) @@ -63,15 +64,17 @@ class SD3Tokenizer: return self.clip_g.untokenize(token_weight_pair) def state_dict(self): - return dict() + return {} def clone(self): return copy.copy(self) class SD3ClipModel(torch.nn.Module): - def __init__(self, clip_l=True, clip_g=True, t5=True, dtype_t5=None, t5_attention_mask=False, device="cpu", dtype=None, model_options={}): + def __init__(self, clip_l=True, clip_g=True, t5=True, dtype_t5=None, t5_attention_mask=False, device="cpu", dtype=None, model_options=None): super().__init__() + if model_options is None: + model_options = {} self.dtypes = set() if clip_l: clip_l_class = model_options.get("clip_l_class", sd1_clip.SDClipModel) @@ -180,4 +183,5 @@ def sd3_clip(clip_l=True, clip_g=True, t5=True, dtype_t5=None, t5xxl_scaled_fp8= model_options = model_options.copy() model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8 super().__init__(clip_l=clip_l, clip_g=clip_g, t5=t5, dtype_t5=dtype_t5, t5_attention_mask=t5_attention_mask, device=device, dtype=dtype, model_options=model_options) + return SD3ClipModel_ diff --git a/docs/assets/with_pytorch_attention.webp b/docs/assets/with_pytorch_attention.webp new file mode 100644 index 000000000..acfa691a1 Binary files /dev/null and b/docs/assets/with_pytorch_attention.webp differ diff --git a/docs/assets/with_sage_attention.webp b/docs/assets/with_sage_attention.webp new file mode 100644 index 000000000..c6563beb5 Binary files /dev/null and b/docs/assets/with_sage_attention.webp differ
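
For orientation, here is a minimal sketch of the substitution the SageAttention section above describes: running the same query/key/value tensors through PyTorch's scaled-dot-product attention and through `sageattn`. The shapes and the `sageattn` call are illustrative assumptions, not part of this diff; verify the exact signature against the installed `sageattention` package.

```python
# Illustrative only: compares PyTorch SDPA against SageAttention on identical inputs.
# Assumes `pip install sageattention`, a CUDA device, and fp16 tensors shaped
# (batch, heads, tokens, head_dim); check the installed package for the exact API.
import torch
import torch.nn.functional as F

try:
    from sageattention import sageattn
except ImportError:
    sageattn = None

if torch.cuda.is_available():
    q = torch.randn(1, 24, 4096, 128, dtype=torch.float16, device="cuda")
    k = torch.randn_like(q)
    v = torch.randn_like(q)

    # Baseline: PyTorch's fused scaled-dot-product attention.
    reference = F.scaled_dot_product_attention(q, k, v)

    if sageattn is not None:
        # Drop-in call with the same tensors; non-causal, as in diffusion transformers.
        candidate = sageattn(q, k, v, is_causal=False)
        print("max abs difference:", (reference - candidate).abs().max().item())
```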