Mirror of https://github.com/comfyanonymous/ComfyUI.git, synced 2026-01-25 05:40:15 +08:00
Mochi and SageAttention improvements
This commit is contained in:
parent 264d84db39
commit 8ba412897e
README.md (36 changed lines)

@@ -333,6 +333,42 @@ pip install git+https://github.com/AppMana/appmana-comfyui-nodes-controlnet-aux.

Start creating an AnimateDiff workflow. When using these packages, the appropriate models will download automatically.

## SageAttention

Improve the performance of your Mochi video generation using **Sage Attention**:

| Device | PyTorch 2.5.1 | SageAttention | S.A. + TorchCompileModel  |
|--------|---------------|---------------|---------------------------|
| A5000  | 7.52s/it      | 5.81s/it      | 5.00s/it (but corrupted)  |

[Use the default Mochi Workflow.](https://github.com/comfyanonymous/ComfyUI_examples/raw/refs/heads/master/mochi/mochi_text_to_video_example.webp) This does not require any custom nodes or any changes to your workflow.

Install the dependencies for Windows or Linux using the `withtriton` extra, or install the specific dependencies you need from [requirements-triton.txt](./requirements-triton.txt):

```shell
pip install "comfyui[withtriton]@git+https://github.com/hiddenswitch/ComfyUI.git"
```
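
Optionally, confirm the kernels imported correctly before benchmarking. This is a minimal sketch and assumes SageAttention is installed as the `sageattention` Python package:

```python
# Quick sanity check (assumption: SageAttention is distributed as the `sageattention` package).
try:
    import sageattention  # noqa: F401
    print("SageAttention is importable; run `comfyui --disable-xformers` to use it.")
except ImportError:
    print("SageAttention is not installed; ComfyUI will fall back to another attention backend.")
```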

If you have `xformers` installed, disable it, as it will be preferred over Sage Attention:

```shell
comfyui --disable-xformers
```

If you want to use **TorchCompileModel** to further improve performance, do not reserve VRAM:

```shell
comfyui --disable-xformers --reserve-vram=-1.0
```

Sage Attention is not compatible with Flux, and it does not appear to be compatible with Mochi when using `torch.compile`.

![With SageAttention](docs/assets/with_sage_attention.webp)

**With SageAttention**

![With PyTorch attention](docs/assets/with_pytorch_attention.webp)

**With PyTorch Attention**

# Custom Nodes

Custom Nodes can be added to ComfyUI by copying and pasting Python files into your `./custom_nodes` directory.
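
As an illustration of what such a file can look like, here is a minimal sketch; the node name and behavior are hypothetical and are not part of this commit:

```python
# custom_nodes/invert_example.py, a hypothetical example node (not shipped with ComfyUI).
import torch


class InvertImageExample:
    @classmethod
    def INPUT_TYPES(cls):
        # One required IMAGE input named "image".
        return {"required": {"image": ("IMAGE",)}}

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "invert"
    CATEGORY = "examples"

    def invert(self, image: torch.Tensor):
        # ComfyUI IMAGE tensors are batched [B, H, W, C] floats in [0, 1].
        return (1.0 - image,)


# ComfyUI discovers nodes through these mappings when it scans ./custom_nodes.
NODE_CLASS_MAPPINGS = {"InvertImageExample": InvertImageExample}
NODE_DISPLAY_NAME_MAPPINGS = {"InvertImageExample": "Invert Image (Example)"}
```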

comfy/ldm/genmo/__init__.py (new empty file)
comfy/ldm/genmo/joint_model/__init__.py (new empty file)
@@ -1,14 +1,11 @@
 # original code from https://github.com/genmoai/models under apache 2.0 license
-#adapted to ComfyUI
-
-from typing import Tuple, List, Dict, Optional
+# adapted to ComfyUI
+from typing import Dict, List, Optional, Tuple
 
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from einops import rearrange
-# from flash_attn import flash_attn_varlen_qkvpacked_func
-from comfy.ldm.modules.attention import optimized_attention
 
 from .layers import (
     FeedForward,
@@ -16,7 +13,6 @@ from .layers import (
     RMSNorm,
     TimestepEmbedder,
 )
-
 from .rope_mixed import (
     compute_mixed_rotation,
     create_position_matrix,
@@ -26,14 +22,15 @@ from .utils import (
     AttentionPool,
     modulate,
 )
-
-import comfy.ldm.common_dit
-import comfy.ops
+from ...common_dit import rms_norm
+# from flash_attn import flash_attn_varlen_qkvpacked_func
+from ...modules.attention import optimized_attention
+from .... import ops
 
 
 def modulated_rmsnorm(x, scale, eps=1e-6):
     # Normalize and modulate
-    x_normed = comfy.ldm.common_dit.rms_norm(x, eps=eps)
+    x_normed = rms_norm(x, eps=eps)
     x_modulated = x_normed * (1 + scale.unsqueeze(1))
 
     return x_modulated
@@ -44,13 +41,14 @@ def residual_tanh_gated_rmsnorm(x, x_res, gate, eps=1e-6):
     tanh_gate = torch.tanh(gate).unsqueeze(1)
 
     # Normalize and apply gated scaling
-    x_normed = comfy.ldm.common_dit.rms_norm(x_res, eps=eps) * tanh_gate
+    x_normed = rms_norm(x_res, eps=eps) * tanh_gate
 
     # Apply residual connection
     output = x + x_normed
 
     return output
 
+
 class AsymmetricAttention(nn.Module):
     def __init__(
         self,
@@ -463,7 +461,6 @@ class AsymmDiTJoint(nn.Module):
         assert x.ndim == 3
         B = x.size(0)
 
-
         pH, pW = H // self.patch_size, W // self.patch_size
         N = T * pH * pW
         assert x.size(1) == N
@@ -471,7 +468,7 @@ class AsymmDiTJoint(nn.Module):
             T, pH=pH, pW=pW, device=x.device, dtype=torch.float32
         )  # (N, 3)
         rope_cos, rope_sin = compute_mixed_rotation(
-            freqs=comfy.ops.cast_to(self.pos_frequencies, dtype=x.dtype, device=x.device), pos=pos
+            freqs=ops.cast_to(self.pos_frequencies, dtype=x.dtype, device=x.device), pos=pos
         )  # Each are (N, num_heads, dim // 2)
 
         c_t = self.t_embedder(1 - sigma, out_dtype=x.dtype)  # (B, D)
@@ -494,8 +491,10 @@ class AsymmDiTJoint(nn.Module):
         packed_indices: Dict[str, torch.Tensor] = None,
         rope_cos: torch.Tensor = None,
         rope_sin: torch.Tensor = None,
-        control=None, transformer_options={}, **kwargs
+        control=None, transformer_options=None, **kwargs
     ):
+        if transformer_options is None:
+            transformer_options = {}
         patches_replace = transformer_options.get("patches_replace", {})
         y_feat = context
         y_mask = attention_mask
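
The `transformer_options=None` change above recurs throughout this commit: mutable default arguments (`{}`) are replaced by `None` plus an in-body check. A minimal sketch of the pitfall being avoided (illustrative only, not code from this repository):

```python
def broken(options={}):
    # The same dict object is shared by every call that omits `options`,
    # so one caller's mutation leaks into the next call.
    options["seen"] = options.get("seen", 0) + 1
    return options["seen"]


def fixed(options=None):
    # A fresh dict is created per call, which is the pattern this commit switches to.
    if options is None:
        options = {}
    options["seen"] = options.get("seen", 0) + 1
    return options["seen"]


print(broken(), broken())  # 1 2  (state leaks across calls)
print(fixed(), fixed())    # 1 1
```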
@@ -530,6 +529,7 @@ class AsymmDiTJoint(nn.Module):
                         crop_y=args["num_tokens"]
                     )
                     return out
+
                 out = blocks_replace[("double_block", i)]({"img": x, "txt": y_feat, "vec": c, "rope_cos": rope_cos, "rope_sin": rope_sin, "num_tokens": num_tokens}, {"original_block": block_wrap})
                 y_feat = out["txt"]
                 x = out["img"]
@@ -10,7 +10,8 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from einops import rearrange
-import comfy.ldm.common_dit
+
+from ...common_dit import pad_to_patch_size, rms_norm
 
 
 # From PyTorch internals
@@ -141,7 +142,7 @@ class PatchEmbed(nn.Module):
             x = F.pad(x, (0, pad_w, 0, pad_h))
 
         x = rearrange(x, "B C T H W -> (B T) C H W", B=B, T=T)
-        x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size, padding_mode='circular')
+        x = pad_to_patch_size(x, self.patch_size, padding_mode='circular')
         x = self.proj(x)
 
         # Flatten temporal and spatial dimensions.
@@ -161,4 +162,4 @@ class RMSNorm(torch.nn.Module):
             self.register_parameter("bias", None)
 
     def forward(self, x):
-        return comfy.ldm.common_dit.rms_norm(x, self.weight, self.eps)
+        return rms_norm(x, self.weight, self.eps)
@@ -59,8 +59,8 @@ def create_position_matrix(
 
     # Stack and reshape the grids.
     pos = torch.stack([grid_t, grid_h, grid_w], dim=-1)  # [T, pH, pW, 3]
-    pos = pos.view(-1, 3)  # [T * pH * pW, 3]
     pos = pos.to(dtype=dtype, device=device)
+    pos = pos.view(-1, 3)  # [T * pH * pW, 3]
 
     return pos
 
comfy/ldm/genmo/vae/__init__.py (new empty file)
@@ -1,19 +1,18 @@
 # original code from https://github.com/genmoai/models under apache 2.0 license
 # adapted to ComfyUI
 
-from typing import Callable, List, Optional, Tuple, Union
-from functools import partial
 import math
+from functools import partial
+from typing import List, Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from einops import rearrange
 
-from comfy.ldm.modules.attention import optimized_attention
+from ...modules.attention import optimized_attention
+from ....ops import disable_weight_init as ops
 
-import comfy.ops
-ops = comfy.ops.disable_weight_init
 
 # import mochi_preview.dit.joint_model.context_parallel as cp
 # from mochi_preview.vae.cp_conv import cp_pass_frames, gather_all_frames
@@ -37,6 +36,7 @@ class GroupNormSpatial(ops.GroupNorm):
             output[b: b + chunk_size] = super().forward(x[b: b + chunk_size])
         return rearrange(output, "(B T) C H W -> B C T H W", B=B, T=T)
 
+
 class PConv3d(ops.Conv3d):
     def __init__(
         self,
@@ -529,6 +529,7 @@ class Decoder(nn.Module):
 
         return self.output_proj(x).contiguous()
 
+
 class LatentDistribution:
     def __init__(self, mean: torch.Tensor, logvar: torch.Tensor):
         """Initialize latent distribution.
@@ -560,6 +561,7 @@ class LatentDistribution:
     def mode(self):
         return self.mean
 
+
 class Encoder(nn.Module):
     def __init__(
         self,
@@ -103,8 +103,10 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
     def __init__(self, device="cpu", max_length=77,
                  freeze=True, layer="last", layer_idx=None, textmodel_json_config: str | dict | None = None, dtype=None, model_class=clip_model.CLIPTextModel,
                  special_tokens=None, layer_norm_hidden_state=True, enable_attention_masks=False, zero_out_masked=False,
-                 return_projected_pooled=True, return_attention_masks=False, model_options={}):  # clip-vit-base-patch32
+                 return_projected_pooled=True, return_attention_masks=False, model_options=None):  # clip-vit-base-patch32
         super().__init__()
+        if model_options is None:
+            model_options = {}
         if special_tokens is None:
             special_tokens = {"start": 49406, "end": 49407, "pad": 49407}
         assert layer in self.LAYERS
@@ -663,7 +665,9 @@ SD1TokenizerT = TypeVar("SD1TokenizerT", bound="SD1Tokenizer")
 
 
 class SD1Tokenizer:
-    def __init__(self, embedding_directory=None, tokenizer_data={}, clip_name="l", tokenizer=SDTokenizer):
+    def __init__(self, embedding_directory=None, tokenizer_data=None, clip_name="l", tokenizer=SDTokenizer):
+        if tokenizer_data is None:
+            tokenizer_data = {}
         self.clip_name = clip_name
         self.clip = "clip_{}".format(self.clip_name)
         tokenizer = tokenizer_data.get("{}_tokenizer_class".format(self.clip), tokenizer)
@@ -694,13 +698,18 @@ class SD1Tokenizer:
         return {}
 
 class SD1CheckpointClipModel(SDClipModel):
-    def __init__(self, device="cpu", dtype=None, model_options={}, textmodel_json_config=None):
+    def __init__(self, device="cpu", dtype=None, model_options=None, textmodel_json_config=None):
         super().__init__(device=device, return_projected_pooled=False, dtype=dtype, model_options=model_options, textmodel_json_config=textmodel_json_config)
+        if model_options is None:
+            model_options = {}
 
 
 class SD1ClipModel(torch.nn.Module):
-    def __init__(self, device="cpu", dtype=None, model_options={}, clip_name="l", clip_model=SD1CheckpointClipModel, textmodel_json_config=None, name=None, **kwargs):
+    def __init__(self, device="cpu", dtype=None, model_options=None, clip_name="l", clip_model=SD1CheckpointClipModel, textmodel_json_config=None, name=None, **kwargs):
         super().__init__()
+
+        if model_options is None:
+            model_options = {}
         if name is not None:
             self.clip_name = name
             self.clip = "{}".format(self.clip_name)
@@ -697,7 +697,9 @@ class GenmoMochi(supported_models_base.BASE):
         out = model_base.GenmoMochi(self, device=device)
         return out
 
-    def clip_target(self, state_dict={}):
+    def clip_target(self, state_dict=None):
+        if state_dict is None:
+            state_dict = {}
         pref = self.text_encoder_key_prefix[0]
         t5_detect = sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
         return supported_models_base.ClipTarget(genmo.MochiT5Tokenizer, genmo.mochi_te(**t5_detect))
@@ -3,6 +3,8 @@ import comfy.text_encoders.sd3_clip
 import os
 from transformers import T5TokenizerFast
 
+from comfy.component_model import files
+
 
 class T5XXLModel(comfy.text_encoders.sd3_clip.T5XXLModel):
     def __init__(self, **kwargs):
@@ -11,24 +13,32 @@ class T5XXLModel(comfy.text_encoders.sd3_clip.T5XXLModel):
 
 
 class MochiT5XXL(sd1_clip.SD1ClipModel):
-    def __init__(self, device="cpu", dtype=None, model_options={}):
+    def __init__(self, device="cpu", dtype=None, model_options=None):
+        if model_options is None:
+            model_options = {}
         super().__init__(device=device, dtype=dtype, name="t5xxl", clip_model=T5XXLModel, model_options=model_options)
 
 
 class T5XXLTokenizer(sd1_clip.SDTokenizer):
-    def __init__(self, embedding_directory=None, tokenizer_data={}):
-        tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
+    def __init__(self, embedding_directory=None, tokenizer_data=None):
+        if tokenizer_data is None:
+            tokenizer_data = {}
+        tokenizer_path = files.get_package_as_path("comfy.text_encoders.t5_tokenizer")
         super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=256)
 
 
 class MochiT5Tokenizer(sd1_clip.SD1Tokenizer):
-    def __init__(self, embedding_directory=None, tokenizer_data={}):
+    def __init__(self, embedding_directory=None, tokenizer_data=None):
         super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="t5xxl", tokenizer=T5XXLTokenizer)
+        if tokenizer_data is None:
+            tokenizer_data = {}
 
 
 def mochi_te(dtype_t5=None, t5xxl_scaled_fp8=None):
     class MochiTEModel_(MochiT5XXL):
-        def __init__(self, device="cpu", dtype=None, model_options={}):
+        def __init__(self, device="cpu", dtype=None, model_options=None):
+            if model_options is None:
+                model_options = {}
             if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options:
                 model_options = model_options.copy()
                 model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8
|
|||||||
model_options = model_options.copy()
|
model_options = model_options.copy()
|
||||||
model_options["scaled_fp8"] = t5xxl_scaled_fp8
|
model_options["scaled_fp8"] = t5xxl_scaled_fp8
|
||||||
|
|
||||||
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=T5, model_options=model_options)
|
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=T5, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
|
||||||
|
|
||||||
|
|
||||||
def t5_xxl_detect(state_dict, prefix=""):
|
def t5_xxl_detect(state_dict, prefix=""):
|
||||||
@@ -35,10 +35,11 @@ def t5_xxl_detect(state_dict, prefix=""):
 
     return out
 
+
 class T5XXLTokenizer(sd1_clip.SDTokenizer):
     def __init__(self, embedding_directory=None, tokenizer_data=None):
         if tokenizer_data is None:
-            tokenizer_data = dict()
+            tokenizer_data = {}
         tokenizer_path = files.get_package_as_path("comfy.text_encoders.t5_tokenizer")
         super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=77)
 
@@ -46,7 +47,7 @@ class T5XXLTokenizer(sd1_clip.SDTokenizer):
 class SD3Tokenizer:
     def __init__(self, embedding_directory=None, tokenizer_data=None):
         if tokenizer_data is None:
-            tokenizer_data = dict()
+            tokenizer_data = {}
         clip_l_tokenizer_class = tokenizer_data.get("clip_l_tokenizer_class", sd1_clip.SDTokenizer)
         self.clip_l = clip_l_tokenizer_class(embedding_directory=embedding_directory)
         self.clip_g = sdxl_clip.SDXLClipGTokenizer(embedding_directory=embedding_directory)
@@ -63,15 +64,17 @@ class SD3Tokenizer:
         return self.clip_g.untokenize(token_weight_pair)
 
     def state_dict(self):
-        return dict()
+        return {}
 
     def clone(self):
         return copy.copy(self)
 
 
 class SD3ClipModel(torch.nn.Module):
-    def __init__(self, clip_l=True, clip_g=True, t5=True, dtype_t5=None, t5_attention_mask=False, device="cpu", dtype=None, model_options={}):
+    def __init__(self, clip_l=True, clip_g=True, t5=True, dtype_t5=None, t5_attention_mask=False, device="cpu", dtype=None, model_options=None):
         super().__init__()
+        if model_options is None:
+            model_options = {}
         self.dtypes = set()
         if clip_l:
             clip_l_class = model_options.get("clip_l_class", sd1_clip.SDClipModel)
@@ -180,4 +183,5 @@ def sd3_clip(clip_l=True, clip_g=True, t5=True, dtype_t5=None, t5xxl_scaled_fp8=
             model_options = model_options.copy()
             model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8
             super().__init__(clip_l=clip_l, clip_g=clip_g, t5=t5, dtype_t5=dtype_t5, t5_attention_mask=t5_attention_mask, device=device, dtype=dtype, model_options=model_options)
+
     return SD3ClipModel_
docs/assets/with_pytorch_attention.webp (new binary file, 1.9 MiB, not shown)
docs/assets/with_sage_attention.webp (new binary file, 1.9 MiB, not shown)