Mochi and SageAttention improvements

doctorpangloss 2024-11-18 15:40:15 -08:00
parent 264d84db39
commit 8ba412897e
14 changed files with 311 additions and 247 deletions

View File

@@ -333,6 +333,42 @@ pip install git+https://github.com/AppMana/appmana-comfyui-nodes-controlnet-aux.
Start creating an AnimateDiff workflow. When using these packages, the appropriate models will download automatically.
## SageAttention
Improve the performance of your Mochi model video generation using **Sage Attention**:
| Device | PyTorch 2.5.1 | SageAttention | SageAttention + TorchCompileModel |
|--------|---------------|---------------|------------------------------------|
| A5000  | 7.52 s/it     | 5.81 s/it     | 5.00 s/it (corrupted output)       |
[Use the default Mochi workflow](https://github.com/comfyanonymous/ComfyUI_examples/raw/refs/heads/master/mochi/mochi_text_to_video_example.webp). This does not require any custom nodes or changes to your workflow.
Install the dependencies for Windows or Linux using the `withtriton` extra, or install the specific dependencies you need from [requirements-triton.txt](./requirements-triton.txt):
```shell
pip install "comfyui[withtriton]@git+https://github.com/hiddenswitch/ComfyUI.git"
```
If you have `xformers` installed, disable it, as it will be preferred over Sage Attention:
```shell
comfyui --disable-xformers
```
If you want to use **TorchCompileModel** to further improve performance, do not reserve VRAM:
```shell
comfyui --disable-xformers --reserve-vram=-1.0
```
Sage Attention is not compatible with Flux, and it does not appear to be compatible with Mochi when `torch.compile` (**TorchCompileModel**) is used; see the corrupted result in the table above.
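Conceptually, Sage Attention acts as a quantized drop-in replacement for PyTorch's scaled dot-product attention, and it is picked up automatically once installed (provided `xformers` is disabled). A rough sketch of that substitution, assuming the `sageattention` package's `sageattn(q, k, v, is_causal=...)` entry point and `(batch, heads, seq_len, head_dim)` tensors; this illustrates the idea, not the exact ComfyUI integration:
```python
import torch
import torch.nn.functional as F

try:
    # Assumed API: sageattn(q, k, v, is_causal=...) over (batch, heads, seq, head_dim) tensors.
    from sageattention import sageattn
except ImportError:
    sageattn = None

def attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    """Pick the fastest available attention kernel."""
    if sageattn is not None and q.is_cuda:
        # SageAttention: INT8-quantized Q/K attention; requires Triton and a CUDA device.
        return sageattn(q, k, v, is_causal=False)
    # Fallback: stock PyTorch scaled dot-product attention.
    return F.scaled_dot_product_attention(q, k, v, is_causal=False)
```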
![with_sage_attention.webp](./docs/assets/with_sage_attention.webp)
**With SageAttention**
![with_pytorch_attention](./docs/assets/with_pytorch_attention.webp)
**With PyTorch Attention**
# Custom Nodes
Custom Nodes can be added to ComfyUI by copying and pasting Python files into your `./custom_nodes` directory.

View File

View File

View File

@@ -1,14 +1,11 @@
# original code from https://github.com/genmoai/models under apache 2.0 license
#adapted to ComfyUI
# adapted to ComfyUI
from typing import Tuple, List, Dict, Optional
from typing import Dict, List, Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
# from flash_attn import flash_attn_varlen_qkvpacked_func
from comfy.ldm.modules.attention import optimized_attention
from .layers import (
FeedForward,
@@ -16,7 +13,6 @@ from .layers import (
RMSNorm,
TimestepEmbedder,
)
from .rope_mixed import (
compute_mixed_rotation,
create_position_matrix,
@@ -26,14 +22,15 @@ from .utils import (
AttentionPool,
modulate,
)
from ...common_dit import rms_norm
import comfy.ldm.common_dit
# from flash_attn import flash_attn_varlen_qkvpacked_func
import comfy.ops
from ...modules.attention import optimized_attention
from .... import ops
def modulated_rmsnorm(x, scale, eps=1e-6):
# Normalize and modulate
x_normed = comfy.ldm.common_dit.rms_norm(x, eps=eps)
x_normed = rms_norm(x, eps=eps)
x_modulated = x_normed * (1 + scale.unsqueeze(1))
return x_modulated
@@ -44,13 +41,14 @@ def residual_tanh_gated_rmsnorm(x, x_res, gate, eps=1e-6):
tanh_gate = torch.tanh(gate).unsqueeze(1)
# Normalize and apply gated scaling
x_normed = comfy.ldm.common_dit.rms_norm(x_res, eps=eps) * tanh_gate
x_normed = rms_norm(x_res, eps=eps) * tanh_gate
# Apply residual connection
output = x + x_normed
return output
class AsymmetricAttention(nn.Module):
def __init__(
self,
@@ -463,7 +461,6 @@ class AsymmDiTJoint(nn.Module):
assert x.ndim == 3
B = x.size(0)
pH, pW = H // self.patch_size, W // self.patch_size
N = T * pH * pW
assert x.size(1) == N
@@ -471,7 +468,7 @@ class AsymmDiTJoint(nn.Module):
T, pH=pH, pW=pW, device=x.device, dtype=torch.float32
) # (N, 3)
rope_cos, rope_sin = compute_mixed_rotation(
freqs=comfy.ops.cast_to(self.pos_frequencies, dtype=x.dtype, device=x.device), pos=pos
freqs=ops.cast_to(self.pos_frequencies, dtype=x.dtype, device=x.device), pos=pos
) # Each are (N, num_heads, dim // 2)
c_t = self.t_embedder(1 - sigma, out_dtype=x.dtype) # (B, D)
@@ -494,8 +491,10 @@ class AsymmDiTJoint(nn.Module):
packed_indices: Dict[str, torch.Tensor] = None,
rope_cos: torch.Tensor = None,
rope_sin: torch.Tensor = None,
control=None, transformer_options={}, **kwargs
control=None, transformer_options=None, **kwargs
):
if transformer_options is None:
transformer_options = {}
patches_replace = transformer_options.get("patches_replace", {})
y_feat = context
y_mask = attention_mask
@@ -530,6 +529,7 @@ class AsymmDiTJoint(nn.Module):
crop_y=args["num_tokens"]
)
return out
out = blocks_replace[("double_block", i)]({"img": x, "txt": y_feat, "vec": c, "rope_cos": rope_cos, "rope_sin": rope_sin, "num_tokens": num_tokens}, {"original_block": block_wrap})
y_feat = out["txt"]
x = out["img"]

View File

@@ -10,7 +10,8 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
import comfy.ldm.common_dit
from ...common_dit import pad_to_patch_size, rms_norm
# From PyTorch internals
@@ -141,7 +142,7 @@ class PatchEmbed(nn.Module):
x = F.pad(x, (0, pad_w, 0, pad_h))
x = rearrange(x, "B C T H W -> (B T) C H W", B=B, T=T)
x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size, padding_mode='circular')
x = pad_to_patch_size(x, self.patch_size, padding_mode='circular')
x = self.proj(x)
# Flatten temporal and spatial dimensions.
@@ -161,4 +162,4 @@ class RMSNorm(torch.nn.Module):
self.register_parameter("bias", None)
def forward(self, x):
return comfy.ldm.common_dit.rms_norm(x, self.weight, self.eps)
return rms_norm(x, self.weight, self.eps)
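For reference, the `rms_norm` helper used above computes root-mean-square normalization over the last dimension. A minimal sketch, assuming the standard RMSNorm formulation (the `common_dit` implementation may differ in details such as dtype handling):
```python
import torch

def rms_norm(x: torch.Tensor, weight: torch.Tensor | None = None, eps: float = 1e-6) -> torch.Tensor:
    # x / sqrt(mean(x^2) + eps) over the last dimension, optionally scaled by a learned weight.
    rms = torch.sqrt(torch.mean(x * x, dim=-1, keepdim=True) + eps)
    x = x / rms
    return x * weight if weight is not None else x
```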

View File

@@ -59,8 +59,8 @@ def create_position_matrix(
# Stack and reshape the grids.
pos = torch.stack([grid_t, grid_h, grid_w], dim=-1) # [T, pH, pW, 3]
pos = pos.view(-1, 3) # [T * pH * pW, 3]
pos = pos.to(dtype=dtype, device=device)
pos = pos.view(-1, 3) # [T * pH * pW, 3]
return pos

View File

View File

@@ -1,19 +1,18 @@
# original code from https://github.com/genmoai/models under apache 2.0 license
# adapted to ComfyUI
from typing import Callable, List, Optional, Tuple, Union
from functools import partial
import math
from functools import partial
from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from comfy.ldm.modules.attention import optimized_attention
from ...modules.attention import optimized_attention
from ....ops import disable_weight_init as ops
import comfy.ops
ops = comfy.ops.disable_weight_init
# import mochi_preview.dit.joint_model.context_parallel as cp
# from mochi_preview.vae.cp_conv import cp_pass_frames, gather_all_frames
@@ -37,6 +36,7 @@ class GroupNormSpatial(ops.GroupNorm):
output[b: b + chunk_size] = super().forward(x[b: b + chunk_size])
return rearrange(output, "(B T) C H W -> B C T H W", B=B, T=T)
class PConv3d(ops.Conv3d):
def __init__(
self,
@@ -529,6 +529,7 @@ class Decoder(nn.Module):
return self.output_proj(x).contiguous()
class LatentDistribution:
def __init__(self, mean: torch.Tensor, logvar: torch.Tensor):
"""Initialize latent distribution.
@@ -560,6 +561,7 @@ class LatentDistribution:
def mode(self):
return self.mean
class Encoder(nn.Module):
def __init__(
self,

View File

@@ -103,8 +103,10 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
def __init__(self, device="cpu", max_length=77,
freeze=True, layer="last", layer_idx=None, textmodel_json_config: str | dict | None = None, dtype=None, model_class=clip_model.CLIPTextModel,
special_tokens=None, layer_norm_hidden_state=True, enable_attention_masks=False, zero_out_masked=False,
return_projected_pooled=True, return_attention_masks=False, model_options={}): # clip-vit-base-patch32
return_projected_pooled=True, return_attention_masks=False, model_options=None): # clip-vit-base-patch32
super().__init__()
if model_options is None:
model_options = {}
if special_tokens is None:
special_tokens = {"start": 49406, "end": 49407, "pad": 49407}
assert layer in self.LAYERS
@@ -663,7 +665,9 @@ SD1TokenizerT = TypeVar("SD1TokenizerT", bound="SD1Tokenizer")
class SD1Tokenizer:
def __init__(self, embedding_directory=None, tokenizer_data={}, clip_name="l", tokenizer=SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data=None, clip_name="l", tokenizer=SDTokenizer):
if tokenizer_data is None:
tokenizer_data = {}
self.clip_name = clip_name
self.clip = "clip_{}".format(self.clip_name)
tokenizer = tokenizer_data.get("{}_tokenizer_class".format(self.clip), tokenizer)
@@ -694,13 +698,18 @@ class SD1Tokenizer:
return {}
class SD1CheckpointClipModel(SDClipModel):
def __init__(self, device="cpu", dtype=None, model_options={}, textmodel_json_config=None):
def __init__(self, device="cpu", dtype=None, model_options=None, textmodel_json_config=None):
super().__init__(device=device, return_projected_pooled=False, dtype=dtype, model_options=model_options, textmodel_json_config=textmodel_json_config)
if model_options is None:
model_options = {}
class SD1ClipModel(torch.nn.Module):
def __init__(self, device="cpu", dtype=None, model_options={}, clip_name="l", clip_model=SD1CheckpointClipModel, textmodel_json_config=None, name=None, **kwargs):
def __init__(self, device="cpu", dtype=None, model_options=None, clip_name="l", clip_model=SD1CheckpointClipModel, textmodel_json_config=None, name=None, **kwargs):
super().__init__()
if model_options is None:
model_options = {}
if name is not None:
self.clip_name = name
self.clip = "{}".format(self.clip_name)
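A recurring change in this commit replaces mutable default arguments (`model_options={}`, `tokenizer_data={}`) with `None` defaults that are normalized inside the function body. A minimal sketch of why, using hypothetical function names and a hypothetical `"attention"` option key; note the guard should run before the dict is used or passed anywhere else:
```python
# Anti-pattern: the default dict is created once at definition time and shared by every call.
def configure_bad(model_options={}):
    model_options.setdefault("attention", "sage")  # hypothetical option key
    return model_options

# Pattern used in this commit: default to None, then normalize before use.
def configure_good(model_options=None):
    if model_options is None:
        model_options = {}
    model_options.setdefault("attention", "sage")
    return model_options

assert configure_bad() is configure_bad()        # same shared dict every call
assert configure_good() is not configure_good()  # fresh dict every call
```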

View File

@@ -697,7 +697,9 @@ class GenmoMochi(supported_models_base.BASE):
out = model_base.GenmoMochi(self, device=device)
return out
def clip_target(self, state_dict={}):
def clip_target(self, state_dict=None):
if state_dict is None:
state_dict = {}
pref = self.text_encoder_key_prefix[0]
t5_detect = sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
return supported_models_base.ClipTarget(genmo.MochiT5Tokenizer, genmo.mochi_te(**t5_detect))

View File

@@ -3,6 +3,8 @@ import comfy.text_encoders.sd3_clip
import os
from transformers import T5TokenizerFast
from comfy.component_model import files
class T5XXLModel(comfy.text_encoders.sd3_clip.T5XXLModel):
def __init__(self, **kwargs):
@@ -11,24 +13,32 @@ class T5XXLModel(comfy.text_encoders.sd3_clip.T5XXLModel):
class MochiT5XXL(sd1_clip.SD1ClipModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
def __init__(self, device="cpu", dtype=None, model_options=None):
if model_options is None:
model_options = {}
super().__init__(device=device, dtype=dtype, name="t5xxl", clip_model=T5XXLModel, model_options=model_options)
class T5XXLTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
def __init__(self, embedding_directory=None, tokenizer_data=None):
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
if tokenizer_data is None:
tokenizer_data = {}
tokenizer_path = files.get_package_as_path("comfy.text_encoders.t5_tokenizer")
super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=256)
class MochiT5Tokenizer(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
def __init__(self, embedding_directory=None, tokenizer_data=None):
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="t5xxl", tokenizer=T5XXLTokenizer)
if tokenizer_data is None:
tokenizer_data = {}
def mochi_te(dtype_t5=None, t5xxl_scaled_fp8=None):
class MochiTEModel_(MochiT5XXL):
def __init__(self, device="cpu", dtype=None, model_options={}):
def __init__(self, device="cpu", dtype=None, model_options=None):
if model_options is None:
model_options = {}
if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options:
model_options = model_options.copy()
model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8

View File

@@ -20,7 +20,7 @@ class T5XXLModel(sd1_clip.SDClipModel):
model_options = model_options.copy()
model_options["scaled_fp8"] = t5xxl_scaled_fp8
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=T5, model_options=model_options)
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=T5, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
def t5_xxl_detect(state_dict, prefix=""):
@@ -35,10 +35,11 @@ def t5_xxl_detect(state_dict, prefix=""):
return out
class T5XXLTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data=None):
if tokenizer_data is None:
tokenizer_data = dict()
tokenizer_data = {}
tokenizer_path = files.get_package_as_path("comfy.text_encoders.t5_tokenizer")
super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=77)
@@ -46,7 +47,7 @@ class T5XXLTokenizer(sd1_clip.SDTokenizer):
class SD3Tokenizer:
def __init__(self, embedding_directory=None, tokenizer_data=None):
if tokenizer_data is None:
tokenizer_data = dict()
tokenizer_data = {}
clip_l_tokenizer_class = tokenizer_data.get("clip_l_tokenizer_class", sd1_clip.SDTokenizer)
self.clip_l = clip_l_tokenizer_class(embedding_directory=embedding_directory)
self.clip_g = sdxl_clip.SDXLClipGTokenizer(embedding_directory=embedding_directory)
@@ -63,15 +64,17 @@ class SD3Tokenizer:
return self.clip_g.untokenize(token_weight_pair)
def state_dict(self):
return dict()
return {}
def clone(self):
return copy.copy(self)
class SD3ClipModel(torch.nn.Module):
def __init__(self, clip_l=True, clip_g=True, t5=True, dtype_t5=None, t5_attention_mask=False, device="cpu", dtype=None, model_options={}):
def __init__(self, clip_l=True, clip_g=True, t5=True, dtype_t5=None, t5_attention_mask=False, device="cpu", dtype=None, model_options=None):
super().__init__()
if model_options is None:
model_options = {}
self.dtypes = set()
if clip_l:
clip_l_class = model_options.get("clip_l_class", sd1_clip.SDClipModel)
@@ -180,4 +183,5 @@ def sd3_clip(clip_l=True, clip_g=True, t5=True, dtype_t5=None, t5xxl_scaled_fp8=
model_options = model_options.copy()
model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8
super().__init__(clip_l=clip_l, clip_g=clip_g, t5=t5, dtype_t5=dtype_t5, t5_attention_mask=t5_attention_mask, device=device, dtype=dtype, model_options=model_options)
return SD3ClipModel_

Binary file not shown (new image, 1.9 MiB)

Binary file not shown (new image, 1.9 MiB)