diff --git a/.github/workflows/test-execution.yml b/.github/workflows/test-execution.yml
new file mode 100644
index 000000000..00ef07ebf
--- /dev/null
+++ b/.github/workflows/test-execution.yml
@@ -0,0 +1,30 @@
+name: Execution Tests
+
+on:
+ push:
+ branches: [ main, master ]
+ pull_request:
+ branches: [ main, master ]
+
+jobs:
+ test:
+ strategy:
+ matrix:
+ os: [ubuntu-latest, windows-latest, macos-latest]
+ runs-on: ${{ matrix.os }}
+ continue-on-error: true
+ steps:
+ - uses: actions/checkout@v4
+ - name: Set up Python
+ uses: actions/setup-python@v4
+ with:
+ python-version: '3.12'
+ - name: Install requirements
+ run: |
+ python -m pip install --upgrade pip
+ pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
+ pip install -r requirements.txt
+ pip install -r tests-unit/requirements.txt
+ - name: Run Execution Tests
+ run: |
+ python -m pytest tests/execution -v --skip-timing-checks
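The new workflow runs the execution suite on all three OSes against a CPU-only torch build. Below is a minimal local-run sketch of the same step, assuming the repository's test configuration registers the custom `--skip-timing-checks` pytest option that the workflow passes verbatim.

```python
# Minimal local-run sketch of the CI step above; assumes the repo's test
# configuration registers the custom --skip-timing-checks pytest option.
import sys
import pytest

def run_execution_tests() -> int:
    # Mirrors: python -m pytest tests/execution -v --skip-timing-checks
    return pytest.main(["tests/execution", "-v", "--skip-timing-checks"])

if __name__ == "__main__":
    sys.exit(run_execution_tests())
```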
diff --git a/README.md b/README.md
index fa99a8cbe..3f6cfc2ed 100644
--- a/README.md
+++ b/README.md
@@ -65,18 +65,18 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith
- [Flux](https://comfyanonymous.github.io/ComfyUI_examples/flux/)
- [Lumina Image 2.0](https://comfyanonymous.github.io/ComfyUI_examples/lumina2/)
- [HiDream](https://comfyanonymous.github.io/ComfyUI_examples/hidream/)
- - [Cosmos Predict2](https://comfyanonymous.github.io/ComfyUI_examples/cosmos_predict2/)
- [Qwen Image](https://comfyanonymous.github.io/ComfyUI_examples/qwen_image/)
+ - [Hunyuan Image 2.1](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_image/)
- Image Editing Models
- [Omnigen 2](https://comfyanonymous.github.io/ComfyUI_examples/omnigen/)
- [Flux Kontext](https://comfyanonymous.github.io/ComfyUI_examples/flux/#flux-kontext-image-editing-model)
- [HiDream E1.1](https://comfyanonymous.github.io/ComfyUI_examples/hidream/#hidream-e11)
+ - [Qwen Image Edit](https://comfyanonymous.github.io/ComfyUI_examples/qwen_image/#edit-model)
- Video Models
- [Stable Video Diffusion](https://comfyanonymous.github.io/ComfyUI_examples/video/)
- [Mochi](https://comfyanonymous.github.io/ComfyUI_examples/mochi/)
- [LTX-Video](https://comfyanonymous.github.io/ComfyUI_examples/ltxv/)
- [Hunyuan Video](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_video/)
- - [Nvidia Cosmos](https://comfyanonymous.github.io/ComfyUI_examples/cosmos/) and [Cosmos Predict2](https://comfyanonymous.github.io/ComfyUI_examples/cosmos_predict2/)
- [Wan 2.1](https://comfyanonymous.github.io/ComfyUI_examples/wan/)
- [Wan 2.2](https://comfyanonymous.github.io/ComfyUI_examples/wan22/)
- Audio Models
@@ -191,7 +191,7 @@ comfy install
## Manual Install (Windows, Linux)
-python 3.13 is supported but using 3.12 is recommended because some custom nodes and their dependencies might not support it yet.
+Python 3.13 is very well supported. If you have trouble with some custom node dependencies, you can try 3.12.
Git clone this repo.
diff --git a/app/user_manager.py b/app/user_manager.py
index 0ec3e46ea..a2d376c0c 100644
--- a/app/user_manager.py
+++ b/app/user_manager.py
@@ -363,10 +363,17 @@ class UserManager():
if not overwrite and os.path.exists(path):
return web.Response(status=409, text="File already exists")
- body = await request.read()
+ try:
+ body = await request.read()
- with open(path, "wb") as f:
- f.write(body)
+ with open(path, "wb") as f:
+ f.write(body)
+ except OSError as e:
+ logging.warning(f"Error saving file '{path}': {e}")
+ return web.Response(
+ status=400,
+ reason="Invalid filename. Please avoid special characters like :\\/*?\"<>|"
+ )
user_path = self.get_request_user_filepath(request, None)
if full_info:
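The change above maps filesystem errors during upload to a 400 response instead of an unhandled exception. A standalone sketch of the failure mode it guards against (characters that are invalid in Windows filenames) follows; the helper name is hypothetical.

```python
# Standalone sketch of the failure mode the new except-block handles: on
# Windows, characters such as : * ? " < > | in a filename make open() raise
# OSError, which the endpoint now reports as a 400 with a helpful reason.
import os
import tempfile

def try_save(filename: str, body: bytes) -> str:  # hypothetical helper
    path = os.path.join(tempfile.gettempdir(), filename)
    try:
        with open(path, "wb") as f:
            f.write(body)
        return "saved"
    except OSError as e:
        return f"rejected: {e}"

print(try_save("ok.txt", b"data"))         # saved
print(try_save('bad:name?.txt', b"data"))  # rejected on Windows; may still save on Linux
```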
diff --git a/comfy/audio_encoders/audio_encoders.py b/comfy/audio_encoders/audio_encoders.py
new file mode 100644
index 000000000..46ef21c95
--- /dev/null
+++ b/comfy/audio_encoders/audio_encoders.py
@@ -0,0 +1,91 @@
+from .wav2vec2 import Wav2Vec2Model
+from .whisper import WhisperLargeV3
+import comfy.model_management
+import comfy.model_patcher
+import comfy.ops
+import comfy.utils
+import logging
+import torchaudio
+
+
+class AudioEncoderModel():
+ def __init__(self, config):
+ self.load_device = comfy.model_management.text_encoder_device()
+ offload_device = comfy.model_management.text_encoder_offload_device()
+ self.dtype = comfy.model_management.text_encoder_dtype(self.load_device)
+ model_type = config.pop("model_type")
+ model_config = dict(config)
+ model_config.update({
+ "dtype": self.dtype,
+ "device": offload_device,
+ "operations": comfy.ops.manual_cast
+ })
+
+ if model_type == "wav2vec2":
+ self.model = Wav2Vec2Model(**model_config)
+ elif model_type == "whisper3":
+ self.model = WhisperLargeV3(**model_config)
+ self.model.eval()
+ self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
+ self.model_sample_rate = 16000
+
+ def load_sd(self, sd):
+ return self.model.load_state_dict(sd, strict=False)
+
+ def get_sd(self):
+ return self.model.state_dict()
+
+ def encode_audio(self, audio, sample_rate):
+ comfy.model_management.load_model_gpu(self.patcher)
+ audio = torchaudio.functional.resample(audio, sample_rate, self.model_sample_rate)
+ out, all_layers = self.model(audio.to(self.load_device))
+ outputs = {}
+ outputs["encoded_audio"] = out
+ outputs["encoded_audio_all_layers"] = all_layers
+ outputs["audio_samples"] = audio.shape[2]
+ return outputs
+
+
+def load_audio_encoder_from_sd(sd, prefix=""):
+ sd = comfy.utils.state_dict_prefix_replace(sd, {"wav2vec2.": ""})
+ if "encoder.layer_norm.bias" in sd: #wav2vec2
+ embed_dim = sd["encoder.layer_norm.bias"].shape[0]
+ if embed_dim == 1024:# large
+ config = {
+ "model_type": "wav2vec2",
+ "embed_dim": 1024,
+ "num_heads": 16,
+ "num_layers": 24,
+ "conv_norm": True,
+ "conv_bias": True,
+ "do_normalize": True,
+ "do_stable_layer_norm": True
+ }
+ elif embed_dim == 768: # base
+ config = {
+ "model_type": "wav2vec2",
+ "embed_dim": 768,
+ "num_heads": 12,
+ "num_layers": 12,
+ "conv_norm": False,
+ "conv_bias": False,
+ "do_normalize": False, # chinese-wav2vec2-base has this False
+ "do_stable_layer_norm": False
+ }
+ else:
+ raise RuntimeError("ERROR: audio encoder file is invalid or unsupported embed_dim: {}".format(embed_dim))
+ elif "model.encoder.embed_positions.weight" in sd:
+ sd = comfy.utils.state_dict_prefix_replace(sd, {"model.": ""})
+ config = {
+ "model_type": "whisper3",
+ }
+ else:
+ raise RuntimeError("ERROR: audio encoder not supported.")
+
+ audio_encoder = AudioEncoderModel(config)
+ m, u = audio_encoder.load_sd(sd)
+ if len(m) > 0:
+ logging.warning("missing audio encoder: {}".format(m))
+ if len(u) > 0:
+ logging.warning("unexpected audio encoder: {}".format(u))
+
+ return audio_encoder
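A hedged usage sketch for the loader above: the checkpoint path is a placeholder, and the audio layout ([batch, channels, samples]) is inferred from `encode_audio`, which averages channels and resamples to 16 kHz internally.

```python
# Hedged usage sketch; the checkpoint path is hypothetical and must point to
# a wav2vec2 or Whisper-large-v3 style state dict supported by the sniffing
# logic in load_audio_encoder_from_sd.
import torch
import comfy.utils
from comfy.audio_encoders.audio_encoders import load_audio_encoder_from_sd

ckpt_path = "models/audio_encoders/wav2vec2_large.safetensors"  # placeholder path
sd = comfy.utils.load_torch_file(ckpt_path)
encoder = load_audio_encoder_from_sd(sd)

waveform = torch.randn(1, 2, 48000)  # [batch, channels, samples]: 1 s of stereo at 48 kHz
out = encoder.encode_audio(waveform, sample_rate=48000)
print(out["encoded_audio"].shape, out["audio_samples"])
```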
diff --git a/comfy/audio_encoders/wav2vec2.py b/comfy/audio_encoders/wav2vec2.py
new file mode 100644
index 000000000..4e34a40a7
--- /dev/null
+++ b/comfy/audio_encoders/wav2vec2.py
@@ -0,0 +1,252 @@
+import torch
+import torch.nn as nn
+from comfy.ldm.modules.attention import optimized_attention_masked
+
+
+class LayerNormConv(nn.Module):
+ def __init__(self, in_channels, out_channels, kernel_size, stride, bias=False, dtype=None, device=None, operations=None):
+ super().__init__()
+ self.conv = operations.Conv1d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, bias=bias, device=device, dtype=dtype)
+ self.layer_norm = operations.LayerNorm(out_channels, elementwise_affine=True, device=device, dtype=dtype)
+
+ def forward(self, x):
+ x = self.conv(x)
+ return torch.nn.functional.gelu(self.layer_norm(x.transpose(-2, -1)).transpose(-2, -1))
+
+class LayerGroupNormConv(nn.Module):
+ def __init__(self, in_channels, out_channels, kernel_size, stride, bias=False, dtype=None, device=None, operations=None):
+ super().__init__()
+ self.conv = operations.Conv1d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, bias=bias, device=device, dtype=dtype)
+ self.layer_norm = operations.GroupNorm(num_groups=out_channels, num_channels=out_channels, affine=True, device=device, dtype=dtype)
+
+ def forward(self, x):
+ x = self.conv(x)
+ return torch.nn.functional.gelu(self.layer_norm(x))
+
+class ConvNoNorm(nn.Module):
+ def __init__(self, in_channels, out_channels, kernel_size, stride, bias=False, dtype=None, device=None, operations=None):
+ super().__init__()
+ self.conv = operations.Conv1d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, bias=bias, device=device, dtype=dtype)
+
+ def forward(self, x):
+ x = self.conv(x)
+ return torch.nn.functional.gelu(x)
+
+
+class ConvFeatureEncoder(nn.Module):
+ def __init__(self, conv_dim, conv_bias=False, conv_norm=True, dtype=None, device=None, operations=None):
+ super().__init__()
+ if conv_norm:
+ self.conv_layers = nn.ModuleList([
+ LayerNormConv(1, conv_dim, kernel_size=10, stride=5, bias=True, device=device, dtype=dtype, operations=operations),
+ LayerNormConv(conv_dim, conv_dim, kernel_size=3, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
+ LayerNormConv(conv_dim, conv_dim, kernel_size=3, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
+ LayerNormConv(conv_dim, conv_dim, kernel_size=3, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
+ LayerNormConv(conv_dim, conv_dim, kernel_size=3, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
+ LayerNormConv(conv_dim, conv_dim, kernel_size=2, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
+ LayerNormConv(conv_dim, conv_dim, kernel_size=2, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
+ ])
+ else:
+ self.conv_layers = nn.ModuleList([
+ LayerGroupNormConv(1, conv_dim, kernel_size=10, stride=5, bias=conv_bias, device=device, dtype=dtype, operations=operations),
+ ConvNoNorm(conv_dim, conv_dim, kernel_size=3, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
+ ConvNoNorm(conv_dim, conv_dim, kernel_size=3, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
+ ConvNoNorm(conv_dim, conv_dim, kernel_size=3, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
+ ConvNoNorm(conv_dim, conv_dim, kernel_size=3, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
+ ConvNoNorm(conv_dim, conv_dim, kernel_size=2, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
+ ConvNoNorm(conv_dim, conv_dim, kernel_size=2, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
+ ])
+
+ def forward(self, x):
+ x = x.unsqueeze(1)
+
+ for conv in self.conv_layers:
+ x = conv(x)
+
+ return x.transpose(1, 2)
+
+
+class FeatureProjection(nn.Module):
+ def __init__(self, conv_dim, embed_dim, dtype=None, device=None, operations=None):
+ super().__init__()
+ self.layer_norm = operations.LayerNorm(conv_dim, eps=1e-05, device=device, dtype=dtype)
+ self.projection = operations.Linear(conv_dim, embed_dim, device=device, dtype=dtype)
+
+ def forward(self, x):
+ x = self.layer_norm(x)
+ x = self.projection(x)
+ return x
+
+
+class PositionalConvEmbedding(nn.Module):
+ def __init__(self, embed_dim=768, kernel_size=128, groups=16):
+ super().__init__()
+ self.conv = nn.Conv1d(
+ embed_dim,
+ embed_dim,
+ kernel_size=kernel_size,
+ padding=kernel_size // 2,
+ groups=groups,
+ )
+ self.conv = torch.nn.utils.parametrizations.weight_norm(self.conv, name="weight", dim=2)
+ self.activation = nn.GELU()
+
+ def forward(self, x):
+ x = x.transpose(1, 2)
+ x = self.conv(x)[:, :, :-1]
+ x = self.activation(x)
+ x = x.transpose(1, 2)
+ return x
+
+
+class TransformerEncoder(nn.Module):
+ def __init__(
+ self,
+ embed_dim=768,
+ num_heads=12,
+ num_layers=12,
+ mlp_ratio=4.0,
+ do_stable_layer_norm=True,
+ dtype=None, device=None, operations=None
+ ):
+ super().__init__()
+
+ self.pos_conv_embed = PositionalConvEmbedding(embed_dim=embed_dim)
+ self.layers = nn.ModuleList([
+ TransformerEncoderLayer(
+ embed_dim=embed_dim,
+ num_heads=num_heads,
+ mlp_ratio=mlp_ratio,
+ do_stable_layer_norm=do_stable_layer_norm,
+ device=device, dtype=dtype, operations=operations
+ )
+ for _ in range(num_layers)
+ ])
+
+ self.layer_norm = operations.LayerNorm(embed_dim, eps=1e-05, device=device, dtype=dtype)
+ self.do_stable_layer_norm = do_stable_layer_norm
+
+ def forward(self, x, mask=None):
+ x = x + self.pos_conv_embed(x)
+ all_x = ()
+ if not self.do_stable_layer_norm:
+ x = self.layer_norm(x)
+ for layer in self.layers:
+ all_x += (x,)
+ x = layer(x, mask)
+ if self.do_stable_layer_norm:
+ x = self.layer_norm(x)
+ all_x += (x,)
+ return x, all_x
+
+
+class Attention(nn.Module):
+ def __init__(self, embed_dim, num_heads, bias=True, dtype=None, device=None, operations=None):
+ super().__init__()
+ self.embed_dim = embed_dim
+ self.num_heads = num_heads
+ self.head_dim = embed_dim // num_heads
+
+ self.k_proj = operations.Linear(embed_dim, embed_dim, bias=bias, device=device, dtype=dtype)
+ self.v_proj = operations.Linear(embed_dim, embed_dim, bias=bias, device=device, dtype=dtype)
+ self.q_proj = operations.Linear(embed_dim, embed_dim, bias=bias, device=device, dtype=dtype)
+ self.out_proj = operations.Linear(embed_dim, embed_dim, bias=bias, device=device, dtype=dtype)
+
+ def forward(self, x, mask=None):
+ assert (mask is None) # TODO?
+ q = self.q_proj(x)
+ k = self.k_proj(x)
+ v = self.v_proj(x)
+
+ out = optimized_attention_masked(q, k, v, self.num_heads)
+ return self.out_proj(out)
+
+
+class FeedForward(nn.Module):
+ def __init__(self, embed_dim, mlp_ratio, dtype=None, device=None, operations=None):
+ super().__init__()
+ self.intermediate_dense = operations.Linear(embed_dim, int(embed_dim * mlp_ratio), device=device, dtype=dtype)
+ self.output_dense = operations.Linear(int(embed_dim * mlp_ratio), embed_dim, device=device, dtype=dtype)
+
+ def forward(self, x):
+ x = self.intermediate_dense(x)
+ x = torch.nn.functional.gelu(x)
+ x = self.output_dense(x)
+ return x
+
+
+class TransformerEncoderLayer(nn.Module):
+ def __init__(
+ self,
+ embed_dim=768,
+ num_heads=12,
+ mlp_ratio=4.0,
+ do_stable_layer_norm=True,
+ dtype=None, device=None, operations=None
+ ):
+ super().__init__()
+
+ self.attention = Attention(embed_dim, num_heads, device=device, dtype=dtype, operations=operations)
+
+ self.layer_norm = operations.LayerNorm(embed_dim, device=device, dtype=dtype)
+ self.feed_forward = FeedForward(embed_dim, mlp_ratio, device=device, dtype=dtype, operations=operations)
+ self.final_layer_norm = operations.LayerNorm(embed_dim, device=device, dtype=dtype)
+ self.do_stable_layer_norm = do_stable_layer_norm
+
+ def forward(self, x, mask=None):
+ residual = x
+ if self.do_stable_layer_norm:
+ x = self.layer_norm(x)
+ x = self.attention(x, mask=mask)
+ x = residual + x
+ if not self.do_stable_layer_norm:
+ x = self.layer_norm(x)
+ return self.final_layer_norm(x + self.feed_forward(x))
+ else:
+ return x + self.feed_forward(self.final_layer_norm(x))
+
+
+class Wav2Vec2Model(nn.Module):
+ """Complete Wav2Vec 2.0 model."""
+
+ def __init__(
+ self,
+ embed_dim=1024,
+ final_dim=256,
+ num_heads=16,
+ num_layers=24,
+ conv_norm=True,
+ conv_bias=True,
+ do_normalize=True,
+ do_stable_layer_norm=True,
+ dtype=None, device=None, operations=None
+ ):
+ super().__init__()
+
+ conv_dim = 512
+ self.feature_extractor = ConvFeatureEncoder(conv_dim, conv_norm=conv_norm, conv_bias=conv_bias, device=device, dtype=dtype, operations=operations)
+ self.feature_projection = FeatureProjection(conv_dim, embed_dim, device=device, dtype=dtype, operations=operations)
+
+ self.masked_spec_embed = nn.Parameter(torch.empty(embed_dim, device=device, dtype=dtype))
+ self.do_normalize = do_normalize
+
+ self.encoder = TransformerEncoder(
+ embed_dim=embed_dim,
+ num_heads=num_heads,
+ num_layers=num_layers,
+ do_stable_layer_norm=do_stable_layer_norm,
+ device=device, dtype=dtype, operations=operations
+ )
+
+ def forward(self, x, mask_time_indices=None, return_dict=False):
+ x = torch.mean(x, dim=1)
+
+ if self.do_normalize:
+ x = (x - x.mean()) / torch.sqrt(x.var() + 1e-7)
+
+ features = self.feature_extractor(x)
+ features = self.feature_projection(features)
+ batch_size, seq_len, _ = features.shape
+
+ x, all_x = self.encoder(features)
+ return x, all_x
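For reference, a small standalone check of the downsampling performed by `ConvFeatureEncoder`: seven strided Conv1d layers reduce raw 16 kHz samples to roughly one frame per 320 samples (about 20 ms).

```python
# Back-of-the-envelope check of ConvFeatureEncoder's downsampling: kernels
# (10,3,3,3,3,2,2) with strides (5,2,2,2,2,2,2), no padding.
def conv_out_len(length: int,
                 kernels=(10, 3, 3, 3, 3, 2, 2),
                 strides=(5, 2, 2, 2, 2, 2, 2)) -> int:
    for k, s in zip(kernels, strides):
        length = (length - k) // s + 1
    return length

print(conv_out_len(16000))  # 49 frames for one second of 16 kHz audio
```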
diff --git a/comfy/audio_encoders/whisper.py b/comfy/audio_encoders/whisper.py
new file mode 100755
index 000000000..93d3782f1
--- /dev/null
+++ b/comfy/audio_encoders/whisper.py
@@ -0,0 +1,186 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torchaudio
+from typing import Optional
+from comfy.ldm.modules.attention import optimized_attention_masked
+import comfy.ops
+
+class WhisperFeatureExtractor(nn.Module):
+ def __init__(self, n_mels=128, device=None):
+ super().__init__()
+ self.sample_rate = 16000
+ self.n_fft = 400
+ self.hop_length = 160
+ self.n_mels = n_mels
+ self.chunk_length = 30
+ self.n_samples = 480000
+
+ self.mel_spectrogram = torchaudio.transforms.MelSpectrogram(
+ sample_rate=self.sample_rate,
+ n_fft=self.n_fft,
+ hop_length=self.hop_length,
+ n_mels=self.n_mels,
+ f_min=0,
+ f_max=8000,
+ norm="slaney",
+ mel_scale="slaney",
+ ).to(device)
+
+ def __call__(self, audio):
+ audio = torch.mean(audio, dim=1)
+ batch_size = audio.shape[0]
+ processed_audio = []
+
+ for i in range(batch_size):
+ aud = audio[i]
+ if aud.shape[0] > self.n_samples:
+ aud = aud[:self.n_samples]
+ elif aud.shape[0] < self.n_samples:
+ aud = F.pad(aud, (0, self.n_samples - aud.shape[0]))
+ processed_audio.append(aud)
+
+ audio = torch.stack(processed_audio)
+
+ mel_spec = self.mel_spectrogram(audio.to(self.mel_spectrogram.spectrogram.window.device))[:, :, :-1].to(audio.device)
+
+ log_mel_spec = torch.clamp(mel_spec, min=1e-10).log10()
+ log_mel_spec = torch.maximum(log_mel_spec, log_mel_spec.max() - 8.0)
+ log_mel_spec = (log_mel_spec + 4.0) / 4.0
+
+ return log_mel_spec
+
+
+class MultiHeadAttention(nn.Module):
+ def __init__(self, d_model: int, n_heads: int, dtype=None, device=None, operations=None):
+ super().__init__()
+ assert d_model % n_heads == 0
+
+ self.d_model = d_model
+ self.n_heads = n_heads
+ self.d_k = d_model // n_heads
+
+ self.q_proj = operations.Linear(d_model, d_model, dtype=dtype, device=device)
+ self.k_proj = operations.Linear(d_model, d_model, bias=False, dtype=dtype, device=device)
+ self.v_proj = operations.Linear(d_model, d_model, dtype=dtype, device=device)
+ self.out_proj = operations.Linear(d_model, d_model, dtype=dtype, device=device)
+
+ def forward(
+ self,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ mask: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ batch_size, seq_len, _ = query.shape
+
+ q = self.q_proj(query)
+ k = self.k_proj(key)
+ v = self.v_proj(value)
+
+ attn_output = optimized_attention_masked(q, k, v, self.n_heads, mask)
+ attn_output = self.out_proj(attn_output)
+
+ return attn_output
+
+
+class EncoderLayer(nn.Module):
+ def __init__(self, d_model: int, n_heads: int, d_ff: int, dtype=None, device=None, operations=None):
+ super().__init__()
+
+ self.self_attn = MultiHeadAttention(d_model, n_heads, dtype=dtype, device=device, operations=operations)
+ self.self_attn_layer_norm = operations.LayerNorm(d_model, dtype=dtype, device=device)
+
+ self.fc1 = operations.Linear(d_model, d_ff, dtype=dtype, device=device)
+ self.fc2 = operations.Linear(d_ff, d_model, dtype=dtype, device=device)
+ self.final_layer_norm = operations.LayerNorm(d_model, dtype=dtype, device=device)
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None
+ ) -> torch.Tensor:
+ residual = x
+ x = self.self_attn_layer_norm(x)
+ x = self.self_attn(x, x, x, attention_mask)
+ x = residual + x
+
+ residual = x
+ x = self.final_layer_norm(x)
+ x = self.fc1(x)
+ x = F.gelu(x)
+ x = self.fc2(x)
+ x = residual + x
+
+ return x
+
+
+class AudioEncoder(nn.Module):
+ def __init__(
+ self,
+ n_mels: int = 128,
+ n_ctx: int = 1500,
+ n_state: int = 1280,
+ n_head: int = 20,
+ n_layer: int = 32,
+ dtype=None,
+ device=None,
+ operations=None
+ ):
+ super().__init__()
+
+ self.conv1 = operations.Conv1d(n_mels, n_state, kernel_size=3, padding=1, dtype=dtype, device=device)
+ self.conv2 = operations.Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1, dtype=dtype, device=device)
+
+ self.embed_positions = operations.Embedding(n_ctx, n_state, dtype=dtype, device=device)
+
+ self.layers = nn.ModuleList([
+ EncoderLayer(n_state, n_head, n_state * 4, dtype=dtype, device=device, operations=operations)
+ for _ in range(n_layer)
+ ])
+
+ self.layer_norm = operations.LayerNorm(n_state, dtype=dtype, device=device)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x = F.gelu(self.conv1(x))
+ x = F.gelu(self.conv2(x))
+
+ x = x.transpose(1, 2)
+
+ x = x + comfy.ops.cast_to_input(self.embed_positions.weight[:, :x.shape[1]], x)
+
+ all_x = ()
+ for layer in self.layers:
+ all_x += (x,)
+ x = layer(x)
+
+ x = self.layer_norm(x)
+ all_x += (x,)
+ return x, all_x
+
+
+class WhisperLargeV3(nn.Module):
+ def __init__(
+ self,
+ n_mels: int = 128,
+ n_audio_ctx: int = 1500,
+ n_audio_state: int = 1280,
+ n_audio_head: int = 20,
+ n_audio_layer: int = 32,
+ dtype=None,
+ device=None,
+ operations=None
+ ):
+ super().__init__()
+
+ self.feature_extractor = WhisperFeatureExtractor(n_mels=n_mels, device=device)
+
+ self.encoder = AudioEncoder(
+ n_mels, n_audio_ctx, n_audio_state, n_audio_head, n_audio_layer,
+ dtype=dtype, device=device, operations=operations
+ )
+
+ def forward(self, audio):
+ mel = self.feature_extractor(audio)
+ x, all_x = self.encoder(mel)
+ return x, all_x
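The feature extractor reproduces Whisper's log-mel front end. A self-contained sketch of the same pipeline (pad/crop to 30 s at 16 kHz, 128 mel bins, Whisper's log scaling), using only the torchaudio calls already used above:

```python
# Standalone sketch of the WhisperFeatureExtractor pipeline above.
import torch
import torch.nn.functional as F
import torchaudio

mel = torchaudio.transforms.MelSpectrogram(
    sample_rate=16000, n_fft=400, hop_length=160, n_mels=128,
    f_min=0, f_max=8000, norm="slaney", mel_scale="slaney",
)

audio = torch.randn(1, 16000)                        # 1 s of mono audio
audio = F.pad(audio, (0, 480000 - audio.shape[-1]))  # pad to 30 s (480000 samples)
spec = mel(audio)[:, :, :-1]                         # drop trailing frame -> 3000 frames
log_spec = torch.clamp(spec, min=1e-10).log10()
log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
log_spec = (log_spec + 4.0) / 4.0
print(log_spec.shape)  # torch.Size([1, 128, 3000])
```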
diff --git a/comfy/cli_args.py b/comfy/cli_args.py
index d814e453a..7955cc763 100644
--- a/comfy/cli_args.py
+++ b/comfy/cli_args.py
@@ -143,8 +143,9 @@ class PerformanceFeature(enum.Enum):
Fp16Accumulation = "fp16_accumulation"
Fp8MatrixMultiplication = "fp8_matrix_mult"
CublasOps = "cublas_ops"
+ AutoTune = "autotune"
-parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. --fast with no arguments enables everything. You can pass a list specific optimizations if you only want to enable specific ones. Current valid optimizations: fp16_accumulation fp8_matrix_mult cublas_ops")
+parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. --fast with no arguments enables everything. You can pass a list specific optimizations if you only want to enable specific ones. Current valid optimizations: {}".format(" ".join(map(lambda c: c.value, PerformanceFeature))))
parser.add_argument("--mmap-torch-files", action="store_true", help="Use mmap when loading ckpt/pt files.")
parser.add_argument("--disable-mmap", action="store_true", help="Don't use mmap when loading safetensors.")
diff --git a/comfy/clip_model.py b/comfy/clip_model.py
index c8294d483..7c0cadab5 100644
--- a/comfy/clip_model.py
+++ b/comfy/clip_model.py
@@ -61,8 +61,12 @@ class CLIPEncoder(torch.nn.Module):
def forward(self, x, mask=None, intermediate_output=None):
optimized_attention = optimized_attention_for_device(x.device, mask=mask is not None, small_input=True)
+ all_intermediate = None
if intermediate_output is not None:
- if intermediate_output < 0:
+ if intermediate_output == "all":
+ all_intermediate = []
+ intermediate_output = None
+ elif intermediate_output < 0:
intermediate_output = len(self.layers) + intermediate_output
intermediate = None
@@ -70,6 +74,12 @@ class CLIPEncoder(torch.nn.Module):
x = l(x, mask, optimized_attention)
if i == intermediate_output:
intermediate = x.clone()
+ if all_intermediate is not None:
+ all_intermediate.append(x.unsqueeze(1).clone())
+
+ if all_intermediate is not None:
+ intermediate = torch.cat(all_intermediate, dim=1)
+
return x, intermediate
class CLIPEmbeddings(torch.nn.Module):
@@ -97,7 +107,7 @@ class CLIPTextModel_(torch.nn.Module):
self.encoder = CLIPEncoder(num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations)
self.final_layer_norm = operations.LayerNorm(embed_dim, dtype=dtype, device=device)
- def forward(self, input_tokens=None, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=torch.float32):
+ def forward(self, input_tokens=None, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=torch.float32, embeds_info=[]):
if embeds is not None:
x = embeds + comfy.ops.cast_to(self.embeddings.position_embedding.weight, dtype=dtype, device=embeds.device)
else:
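A minimal sketch (outside ComfyUI, with toy layers) of the new `intermediate_output="all"` path: each layer's output is kept with an extra layer axis and concatenated to `[batch, num_layers, seq, dim]`, which is how the SigLIP path in clip_vision below recovers the penultimate hidden state with `[:, -2]`.

```python
# Toy illustration of collecting all hidden states the way CLIPEncoder now does.
import torch
import torch.nn as nn

layers = nn.ModuleList([nn.Linear(8, 8) for _ in range(4)])
x = torch.randn(2, 3, 8)  # [batch, seq, dim]

all_intermediate = []
for layer in layers:
    x = layer(x)
    all_intermediate.append(x.unsqueeze(1).clone())

stacked = torch.cat(all_intermediate, dim=1)
print(stacked.shape)          # torch.Size([2, 4, 3, 8])
penultimate = stacked[:, -2]  # what clip_vision exposes as "penultimate_hidden_states"
print(penultimate.shape)      # torch.Size([2, 3, 8])
```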
diff --git a/comfy/clip_vision.py b/comfy/clip_vision.py
index 00aab9164..447b1ce4a 100644
--- a/comfy/clip_vision.py
+++ b/comfy/clip_vision.py
@@ -50,7 +50,13 @@ class ClipVisionModel():
self.image_size = config.get("image_size", 224)
self.image_mean = config.get("image_mean", [0.48145466, 0.4578275, 0.40821073])
self.image_std = config.get("image_std", [0.26862954, 0.26130258, 0.27577711])
- model_class = IMAGE_ENCODERS.get(config.get("model_type", "clip_vision_model"))
+ model_type = config.get("model_type", "clip_vision_model")
+ model_class = IMAGE_ENCODERS.get(model_type)
+ if model_type == "siglip_vision_model":
+ self.return_all_hidden_states = True
+ else:
+ self.return_all_hidden_states = False
+
self.load_device = comfy.model_management.text_encoder_device()
offload_device = comfy.model_management.text_encoder_offload_device()
self.dtype = comfy.model_management.text_encoder_dtype(self.load_device)
@@ -68,12 +74,18 @@ class ClipVisionModel():
def encode_image(self, image, crop=True):
comfy.model_management.load_model_gpu(self.patcher)
pixel_values = clip_preprocess(image.to(self.load_device), size=self.image_size, mean=self.image_mean, std=self.image_std, crop=crop).float()
- out = self.model(pixel_values=pixel_values, intermediate_output=-2)
+ out = self.model(pixel_values=pixel_values, intermediate_output='all' if self.return_all_hidden_states else -2)
outputs = Output()
outputs["last_hidden_state"] = out[0].to(comfy.model_management.intermediate_device())
outputs["image_embeds"] = out[2].to(comfy.model_management.intermediate_device())
- outputs["penultimate_hidden_states"] = out[1].to(comfy.model_management.intermediate_device())
+ if self.return_all_hidden_states:
+ all_hs = out[1].to(comfy.model_management.intermediate_device())
+ outputs["penultimate_hidden_states"] = all_hs[:, -2]
+ outputs["all_hidden_states"] = all_hs
+ else:
+ outputs["penultimate_hidden_states"] = out[1].to(comfy.model_management.intermediate_device())
+
outputs["mm_projected"] = out[3]
return outputs
@@ -124,8 +136,12 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False):
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl_336.json")
else:
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl.json")
- elif "embeddings.patch_embeddings.projection.weight" in sd:
+
+ # Dinov2
+ elif 'encoder.layer.39.layer_scale2.lambda1' in sd:
json_config = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "image_encoders"), "dino2_giant.json")
+ elif 'encoder.layer.23.layer_scale2.lambda1' in sd:
+ json_config = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "image_encoders"), "dino2_large.json")
else:
return None
diff --git a/comfy/controlnet.py b/comfy/controlnet.py
index 988acdb57..f08ff4b36 100644
--- a/comfy/controlnet.py
+++ b/comfy/controlnet.py
@@ -36,6 +36,7 @@ import comfy.ldm.cascade.controlnet
import comfy.cldm.mmdit
import comfy.ldm.hydit.controlnet
import comfy.ldm.flux.controlnet
+import comfy.ldm.qwen_image.controlnet
import comfy.cldm.dit_embedder
from typing import TYPE_CHECKING
if TYPE_CHECKING:
@@ -236,11 +237,11 @@ class ControlNet(ControlBase):
self.cond_hint = None
compression_ratio = self.compression_ratio
if self.vae is not None:
- compression_ratio *= self.vae.downscale_ratio
+ compression_ratio *= self.vae.spacial_compression_encode()
else:
if self.latent_format is not None:
raise ValueError("This Controlnet needs a VAE but none was provided, please use a ControlNetApply node with a VAE input and connect it.")
- self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * compression_ratio, x_noisy.shape[2] * compression_ratio, self.upscale_algorithm, "center")
+ self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, x_noisy.shape[-1] * compression_ratio, x_noisy.shape[-2] * compression_ratio, self.upscale_algorithm, "center")
self.cond_hint = self.preprocess_image(self.cond_hint)
if self.vae is not None:
loaded_models = comfy.model_management.loaded_models(only_currently_used=True)
@@ -252,7 +253,10 @@ class ControlNet(ControlBase):
to_concat = []
for c in self.extra_concat_orig:
c = c.to(self.cond_hint.device)
- c = comfy.utils.common_upscale(c, self.cond_hint.shape[3], self.cond_hint.shape[2], self.upscale_algorithm, "center")
+ c = comfy.utils.common_upscale(c, self.cond_hint.shape[-1], self.cond_hint.shape[-2], self.upscale_algorithm, "center")
+ if c.ndim < self.cond_hint.ndim:
+ c = c.unsqueeze(2)
+ c = comfy.utils.repeat_to_batch_size(c, self.cond_hint.shape[2], dim=2)
to_concat.append(comfy.utils.repeat_to_batch_size(c, self.cond_hint.shape[0]))
self.cond_hint = torch.cat([self.cond_hint] + to_concat, dim=1)
@@ -582,6 +586,22 @@ def load_controlnet_flux_instantx(sd, model_options={}):
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, concat_mask=concat_mask, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
return control
+def load_controlnet_qwen_instantx(sd, model_options={}):
+ model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(sd, model_options=model_options)
+ control_latent_channels = sd.get("controlnet_x_embedder.weight").shape[1]
+
+ extra_condition_channels = 0
+ concat_mask = False
+ if control_latent_channels == 68: #inpaint controlnet
+ extra_condition_channels = control_latent_channels - 64
+ concat_mask = True
+ control_model = comfy.ldm.qwen_image.controlnet.QwenImageControlNetModel(extra_condition_channels=extra_condition_channels, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
+ control_model = controlnet_load_state_dict(control_model, sd)
+ latent_format = comfy.latent_formats.Wan21()
+ extra_conds = []
+ control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, concat_mask=concat_mask, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
+ return control
+
def convert_mistoline(sd):
return comfy.utils.state_dict_prefix_replace(sd, {"single_controlnet_blocks.": "controlnet_single_blocks."})
@@ -655,8 +675,11 @@ def load_controlnet_state_dict(state_dict, model=None, model_options={}):
return load_controlnet_sd35(controlnet_data, model_options=model_options) #Stability sd3.5 format
else:
return load_controlnet_mmdit(controlnet_data, model_options=model_options) #SD3 diffusers controlnet
+ elif "transformer_blocks.0.img_mlp.net.0.proj.weight" in controlnet_data:
+ return load_controlnet_qwen_instantx(controlnet_data, model_options=model_options)
elif "controlnet_x_embedder.weight" in controlnet_data:
return load_controlnet_flux_instantx(controlnet_data, model_options=model_options)
+
elif "controlnet_blocks.0.linear.weight" in controlnet_data: #mistoline flux
return load_controlnet_flux_xlabs_mistoline(convert_mistoline(controlnet_data), mistoline=True, model_options=model_options)
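ControlNet loading dispatches by sniffing state-dict keys, and the new Qwen InstantX branch has to be checked before the generic `controlnet_x_embedder.weight` test that routes to Flux InstantX. A hedged sketch of that ordering (function name and labels are illustrative, tensor shapes are placeholders):

```python
# Illustrative key-sniffing order; the Qwen InstantX check must precede the
# generic Flux InstantX check because both checkpoints contain
# "controlnet_x_embedder.weight".
import torch

def sniff_controlnet(sd: dict) -> str:  # hypothetical helper
    if "transformer_blocks.0.img_mlp.net.0.proj.weight" in sd:
        channels = sd["controlnet_x_embedder.weight"].shape[1]
        # 68 input channels -> 4 extra condition channels + concatenated mask (inpaint variant)
        return "qwen_instantx_inpaint" if channels == 68 else "qwen_instantx"
    if "controlnet_x_embedder.weight" in sd:
        return "flux_instantx"
    return "unknown"

fake_sd = {
    "transformer_blocks.0.img_mlp.net.0.proj.weight": torch.zeros(1, 1),
    "controlnet_x_embedder.weight": torch.zeros(3072, 68),
}
print(sniff_controlnet(fake_sd))  # qwen_instantx_inpaint
```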
diff --git a/comfy/image_encoders/dino2.py b/comfy/image_encoders/dino2.py
index 976f98c65..9b6dace9d 100644
--- a/comfy/image_encoders/dino2.py
+++ b/comfy/image_encoders/dino2.py
@@ -31,6 +31,20 @@ class LayerScale(torch.nn.Module):
def forward(self, x):
return x * comfy.model_management.cast_to_device(self.lambda1, x.device, x.dtype)
+class Dinov2MLP(torch.nn.Module):
+ def __init__(self, hidden_size: int, dtype, device, operations):
+ super().__init__()
+
+ mlp_ratio = 4
+ hidden_features = int(hidden_size * mlp_ratio)
+ self.fc1 = operations.Linear(hidden_size, hidden_features, bias = True, device=device, dtype=dtype)
+ self.fc2 = operations.Linear(hidden_features, hidden_size, bias = True, device=device, dtype=dtype)
+
+ def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
+ hidden_state = self.fc1(hidden_state)
+ hidden_state = torch.nn.functional.gelu(hidden_state)
+ hidden_state = self.fc2(hidden_state)
+ return hidden_state
class SwiGLUFFN(torch.nn.Module):
def __init__(self, dim, dtype, device, operations):
@@ -50,12 +64,15 @@ class SwiGLUFFN(torch.nn.Module):
class Dino2Block(torch.nn.Module):
- def __init__(self, dim, num_heads, layer_norm_eps, dtype, device, operations):
+ def __init__(self, dim, num_heads, layer_norm_eps, dtype, device, operations, use_swiglu_ffn):
super().__init__()
self.attention = Dino2AttentionBlock(dim, num_heads, layer_norm_eps, dtype, device, operations)
self.layer_scale1 = LayerScale(dim, dtype, device, operations)
self.layer_scale2 = LayerScale(dim, dtype, device, operations)
- self.mlp = SwiGLUFFN(dim, dtype, device, operations)
+ if use_swiglu_ffn:
+ self.mlp = SwiGLUFFN(dim, dtype, device, operations)
+ else:
+ self.mlp = Dinov2MLP(dim, dtype, device, operations)
self.norm1 = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device)
self.norm2 = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device)
@@ -66,9 +83,10 @@ class Dino2Block(torch.nn.Module):
class Dino2Encoder(torch.nn.Module):
- def __init__(self, dim, num_heads, layer_norm_eps, num_layers, dtype, device, operations):
+ def __init__(self, dim, num_heads, layer_norm_eps, num_layers, dtype, device, operations, use_swiglu_ffn):
super().__init__()
- self.layer = torch.nn.ModuleList([Dino2Block(dim, num_heads, layer_norm_eps, dtype, device, operations) for _ in range(num_layers)])
+ self.layer = torch.nn.ModuleList([Dino2Block(dim, num_heads, layer_norm_eps, dtype, device, operations, use_swiglu_ffn = use_swiglu_ffn)
+ for _ in range(num_layers)])
def forward(self, x, intermediate_output=None):
optimized_attention = optimized_attention_for_device(x.device, False, small_input=True)
@@ -78,8 +96,8 @@ class Dino2Encoder(torch.nn.Module):
intermediate_output = len(self.layer) + intermediate_output
intermediate = None
- for i, l in enumerate(self.layer):
- x = l(x, optimized_attention)
+ for i, layer in enumerate(self.layer):
+ x = layer(x, optimized_attention)
if i == intermediate_output:
intermediate = x.clone()
return x, intermediate
@@ -128,9 +146,10 @@ class Dinov2Model(torch.nn.Module):
dim = config_dict["hidden_size"]
heads = config_dict["num_attention_heads"]
layer_norm_eps = config_dict["layer_norm_eps"]
+ use_swiglu_ffn = config_dict["use_swiglu_ffn"]
self.embeddings = Dino2Embeddings(dim, dtype, device, operations)
- self.encoder = Dino2Encoder(dim, heads, layer_norm_eps, num_layers, dtype, device, operations)
+ self.encoder = Dino2Encoder(dim, heads, layer_norm_eps, num_layers, dtype, device, operations, use_swiglu_ffn = use_swiglu_ffn)
self.layernorm = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device)
def forward(self, pixel_values, attention_mask=None, intermediate_output=None):
diff --git a/comfy/image_encoders/dino2_large.json b/comfy/image_encoders/dino2_large.json
new file mode 100644
index 000000000..43fbb58ff
--- /dev/null
+++ b/comfy/image_encoders/dino2_large.json
@@ -0,0 +1,22 @@
+{
+ "hidden_size": 1024,
+ "use_mask_token": true,
+ "patch_size": 14,
+ "image_size": 518,
+ "num_channels": 3,
+ "num_attention_heads": 16,
+ "initializer_range": 0.02,
+ "attention_probs_dropout_prob": 0.0,
+ "hidden_dropout_prob": 0.0,
+ "hidden_act": "gelu",
+ "mlp_ratio": 4,
+ "model_type": "dinov2",
+ "num_hidden_layers": 24,
+ "layer_norm_eps": 1e-6,
+ "qkv_bias": true,
+ "use_swiglu_ffn": false,
+ "layerscale_value": 1.0,
+ "drop_path_rate": 0.0,
+ "image_mean": [0.485, 0.456, 0.406],
+ "image_std": [0.229, 0.224, 0.225]
+}
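The new `use_swiglu_ffn` flag is what distinguishes this config from the existing giant one: DINOv2-large uses the plain GELU MLP while the giant variant keeps SwiGLU. A tiny sketch of the selection (hidden sizes shown only for illustration):

```python
# Sketch of the feed-forward selection driven by use_swiglu_ffn; class names
# mirror the ones defined in comfy/image_encoders/dino2.py.
def pick_mlp(config: dict) -> str:
    return "SwiGLUFFN" if config["use_swiglu_ffn"] else "Dinov2MLP"

giant = {"hidden_size": 1536, "use_swiglu_ffn": True}   # illustrative
large = {"hidden_size": 1024, "use_swiglu_ffn": False}  # matches dino2_large.json
print(pick_mlp(giant), pick_mlp(large))  # SwiGLUFFN Dinov2MLP
```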
diff --git a/comfy/k_diffusion/sampling.py b/comfy/k_diffusion/sampling.py
index a2bc492fd..0e2cda291 100644
--- a/comfy/k_diffusion/sampling.py
+++ b/comfy/k_diffusion/sampling.py
@@ -86,24 +86,24 @@ class BatchedBrownianTree:
"""A wrapper around torchsde.BrownianTree that enables batches of entropy."""
def __init__(self, x, t0, t1, seed=None, **kwargs):
- self.cpu_tree = True
- if "cpu" in kwargs:
- self.cpu_tree = kwargs.pop("cpu")
+ self.cpu_tree = kwargs.pop("cpu", True)
t0, t1, self.sign = self.sort(t0, t1)
- w0 = kwargs.get('w0', torch.zeros_like(x))
+ w0 = kwargs.pop('w0', None)
+ if w0 is None:
+ w0 = torch.zeros_like(x)
+ self.batched = False
if seed is None:
- seed = torch.randint(0, 2 ** 63 - 1, []).item()
- self.batched = True
- try:
- assert len(seed) == x.shape[0]
+ seed = (torch.randint(0, 2 ** 63 - 1, ()).item(),)
+ elif isinstance(seed, (tuple, list)):
+ if len(seed) != x.shape[0]:
+ raise ValueError("Passing a list or tuple of seeds to BatchedBrownianTree requires a length matching the batch size.")
+ self.batched = True
w0 = w0[0]
- except TypeError:
- seed = [seed]
- self.batched = False
- if self.cpu_tree:
- self.trees = [torchsde.BrownianTree(t0.cpu(), w0.cpu(), t1.cpu(), entropy=s, **kwargs) for s in seed]
else:
- self.trees = [torchsde.BrownianTree(t0, w0, t1, entropy=s, **kwargs) for s in seed]
+ seed = (seed,)
+ if self.cpu_tree:
+ t0, w0, t1 = t0.detach().cpu(), w0.detach().cpu(), t1.detach().cpu()
+ self.trees = tuple(torchsde.BrownianTree(t0, w0, t1, entropy=s, **kwargs) for s in seed)
@staticmethod
def sort(a, b):
@@ -111,11 +111,10 @@ class BatchedBrownianTree:
def __call__(self, t0, t1):
t0, t1, sign = self.sort(t0, t1)
+ device, dtype = t0.device, t0.dtype
if self.cpu_tree:
- w = torch.stack([tree(t0.cpu().float(), t1.cpu().float()).to(t0.dtype).to(t0.device) for tree in self.trees]) * (self.sign * sign)
- else:
- w = torch.stack([tree(t0, t1) for tree in self.trees]) * (self.sign * sign)
-
+ t0, t1 = t0.detach().cpu().float(), t1.detach().cpu().float()
+ w = torch.stack([tree(t0, t1) for tree in self.trees]).to(device=device, dtype=dtype) * (self.sign * sign)
return w if self.batched else w[0]
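The rewritten constructor normalizes the seed argument in one place: `None` becomes a fresh random seed, a list or tuple must match the batch size and turns on batched mode, and anything else is wrapped in a 1-tuple. A standalone sketch of just that logic:

```python
# Standalone sketch of the seed normalization now done in BatchedBrownianTree.
import torch

def normalize_seed(seed, batch_size: int):
    batched = False
    if seed is None:
        seed = (torch.randint(0, 2 ** 63 - 1, ()).item(),)
    elif isinstance(seed, (tuple, list)):
        if len(seed) != batch_size:
            raise ValueError("seed list length must match the batch size")
        batched = True
        seed = tuple(seed)
    else:
        seed = (seed,)
    return seed, batched

print(normalize_seed(123, 4))           # ((123,), False)
print(normalize_seed([1, 2, 3, 4], 4))  # ((1, 2, 3, 4), True)
```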
@@ -171,6 +170,16 @@ def offset_first_sigma_for_snr(sigmas, model_sampling, percent_offset=1e-4):
return sigmas
+def ei_h_phi_1(h: torch.Tensor) -> torch.Tensor:
+ """Compute the result of h*phi_1(h) in exponential integrator methods."""
+ return torch.expm1(h)
+
+
+def ei_h_phi_2(h: torch.Tensor) -> torch.Tensor:
+ """Compute the result of h*phi_2(h) in exponential integrator methods."""
+ return (torch.expm1(h) - h) / h
+
+
@torch.no_grad()
def sample_euler(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
"""Implements Algorithm 2 (Euler steps) from Karras et al. (2022)."""
@@ -853,6 +862,11 @@ def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
return x
+@torch.no_grad()
+def sample_dpmpp_2m_sde_heun(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='heun'):
+ return sample_dpmpp_2m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, solver_type=solver_type)
+
+
@torch.no_grad()
def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
"""DPM-Solver++(3M) SDE."""
@@ -925,6 +939,16 @@ def sample_dpmpp_3m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, di
return sample_dpmpp_3m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler)
+@torch.no_grad()
+def sample_dpmpp_2m_sde_heun_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='heun'):
+ if len(sigmas) <= 1:
+ return x
+ extra_args = {} if extra_args is None else extra_args
+ sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
+ noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler
+ return sample_dpmpp_2m_sde_heun(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, solver_type=solver_type)
+
+
@torch.no_grad()
def sample_dpmpp_2m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='midpoint'):
if len(sigmas) <= 1:
@@ -1535,13 +1559,12 @@ def sample_er_sde(model, x, sigmas, extra_args=None, callback=None, disable=None
@torch.no_grad()
def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=0.5):
"""SEEDS-2 - Stochastic Explicit Exponential Derivative-free Solvers (VP Data Prediction) stage 2.
- arXiv: https://arxiv.org/abs/2305.14267
+ arXiv: https://arxiv.org/abs/2305.14267 (NeurIPS 2023)
"""
extra_args = {} if extra_args is None else extra_args
seed = extra_args.get("seed", None)
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
s_in = x.new_ones([x.shape[0]])
-
inject_noise = eta > 0 and s_noise > 0
model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling')
@@ -1549,55 +1572,53 @@ def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=Non
lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling)
sigmas = offset_first_sigma_for_snr(sigmas, model_sampling)
+ fac = 1 / (2 * r)
+
for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args)
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
+
if sigmas[i + 1] == 0:
x = denoised
- else:
- lambda_s, lambda_t = lambda_fn(sigmas[i]), lambda_fn(sigmas[i + 1])
- h = lambda_t - lambda_s
- h_eta = h * (eta + 1)
- lambda_s_1 = lambda_s + r * h
- fac = 1 / (2 * r)
- sigma_s_1 = sigma_fn(lambda_s_1)
+ continue
- # alpha_t = sigma_t * exp(log(alpha_t / sigma_t)) = sigma_t * exp(lambda_t)
- alpha_s_1 = sigma_s_1 * lambda_s_1.exp()
- alpha_t = sigmas[i + 1] * lambda_t.exp()
+ lambda_s, lambda_t = lambda_fn(sigmas[i]), lambda_fn(sigmas[i + 1])
+ h = lambda_t - lambda_s
+ h_eta = h * (eta + 1)
+ lambda_s_1 = torch.lerp(lambda_s, lambda_t, r)
+ sigma_s_1 = sigma_fn(lambda_s_1)
- coeff_1, coeff_2 = (-r * h_eta).expm1(), (-h_eta).expm1()
- if inject_noise:
- # 0 < r < 1
- noise_coeff_1 = (-2 * r * h * eta).expm1().neg().sqrt()
- noise_coeff_2 = (-r * h * eta).exp() * (-2 * (1 - r) * h * eta).expm1().neg().sqrt()
- noise_1, noise_2 = noise_sampler(sigmas[i], sigma_s_1), noise_sampler(sigma_s_1, sigmas[i + 1])
+ alpha_s_1 = sigma_s_1 * lambda_s_1.exp()
+ alpha_t = sigmas[i + 1] * lambda_t.exp()
- # Step 1
- x_2 = sigma_s_1 / sigmas[i] * (-r * h * eta).exp() * x - alpha_s_1 * coeff_1 * denoised
- if inject_noise:
- x_2 = x_2 + sigma_s_1 * (noise_coeff_1 * noise_1) * s_noise
- denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args)
+ # Step 1
+ x_2 = sigma_s_1 / sigmas[i] * (-r * h * eta).exp() * x - alpha_s_1 * ei_h_phi_1(-r * h_eta) * denoised
+ if inject_noise:
+ sde_noise = (-2 * r * h * eta).expm1().neg().sqrt() * noise_sampler(sigmas[i], sigma_s_1)
+ x_2 = x_2 + sde_noise * sigma_s_1 * s_noise
+ denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args)
- # Step 2
- denoised_d = (1 - fac) * denoised + fac * denoised_2
- x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * coeff_2 * denoised_d
- if inject_noise:
- x = x + sigmas[i + 1] * (noise_coeff_2 * noise_1 + noise_coeff_1 * noise_2) * s_noise
+ # Step 2
+ denoised_d = torch.lerp(denoised, denoised_2, fac)
+ x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * ei_h_phi_1(-h_eta) * denoised_d
+ if inject_noise:
+ segment_factor = (r - 1) * h * eta
+ sde_noise = sde_noise * segment_factor.exp()
+ sde_noise = sde_noise + segment_factor.mul(2).expm1().neg().sqrt() * noise_sampler(sigma_s_1, sigmas[i + 1])
+ x = x + sde_noise * sigmas[i + 1] * s_noise
return x
@torch.no_grad()
def sample_seeds_3(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r_1=1./3, r_2=2./3):
"""SEEDS-3 - Stochastic Explicit Exponential Derivative-free Solvers (VP Data Prediction) stage 3.
- arXiv: https://arxiv.org/abs/2305.14267
+ arXiv: https://arxiv.org/abs/2305.14267 (NeurIPS 2023)
"""
extra_args = {} if extra_args is None else extra_args
seed = extra_args.get("seed", None)
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
s_in = x.new_ones([x.shape[0]])
-
inject_noise = eta > 0 and s_noise > 0
model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling')
@@ -1609,45 +1630,49 @@ def sample_seeds_3(model, x, sigmas, extra_args=None, callback=None, disable=Non
denoised = model(x, sigmas[i] * s_in, **extra_args)
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
+
if sigmas[i + 1] == 0:
x = denoised
- else:
- lambda_s, lambda_t = lambda_fn(sigmas[i]), lambda_fn(sigmas[i + 1])
- h = lambda_t - lambda_s
- h_eta = h * (eta + 1)
- lambda_s_1 = lambda_s + r_1 * h
- lambda_s_2 = lambda_s + r_2 * h
- sigma_s_1, sigma_s_2 = sigma_fn(lambda_s_1), sigma_fn(lambda_s_2)
+ continue
- # alpha_t = sigma_t * exp(log(alpha_t / sigma_t)) = sigma_t * exp(lambda_t)
- alpha_s_1 = sigma_s_1 * lambda_s_1.exp()
- alpha_s_2 = sigma_s_2 * lambda_s_2.exp()
- alpha_t = sigmas[i + 1] * lambda_t.exp()
+ lambda_s, lambda_t = lambda_fn(sigmas[i]), lambda_fn(sigmas[i + 1])
+ h = lambda_t - lambda_s
+ h_eta = h * (eta + 1)
+ lambda_s_1 = torch.lerp(lambda_s, lambda_t, r_1)
+ lambda_s_2 = torch.lerp(lambda_s, lambda_t, r_2)
+ sigma_s_1, sigma_s_2 = sigma_fn(lambda_s_1), sigma_fn(lambda_s_2)
- coeff_1, coeff_2, coeff_3 = (-r_1 * h_eta).expm1(), (-r_2 * h_eta).expm1(), (-h_eta).expm1()
- if inject_noise:
- # 0 < r_1 < r_2 < 1
- noise_coeff_1 = (-2 * r_1 * h * eta).expm1().neg().sqrt()
- noise_coeff_2 = (-r_1 * h * eta).exp() * (-2 * (r_2 - r_1) * h * eta).expm1().neg().sqrt()
- noise_coeff_3 = (-r_2 * h * eta).exp() * (-2 * (1 - r_2) * h * eta).expm1().neg().sqrt()
- noise_1, noise_2, noise_3 = noise_sampler(sigmas[i], sigma_s_1), noise_sampler(sigma_s_1, sigma_s_2), noise_sampler(sigma_s_2, sigmas[i + 1])
+ alpha_s_1 = sigma_s_1 * lambda_s_1.exp()
+ alpha_s_2 = sigma_s_2 * lambda_s_2.exp()
+ alpha_t = sigmas[i + 1] * lambda_t.exp()
- # Step 1
- x_2 = sigma_s_1 / sigmas[i] * (-r_1 * h * eta).exp() * x - alpha_s_1 * coeff_1 * denoised
- if inject_noise:
- x_2 = x_2 + sigma_s_1 * (noise_coeff_1 * noise_1) * s_noise
- denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args)
+ # Step 1
+ x_2 = sigma_s_1 / sigmas[i] * (-r_1 * h * eta).exp() * x - alpha_s_1 * ei_h_phi_1(-r_1 * h_eta) * denoised
+ if inject_noise:
+ sde_noise = (-2 * r_1 * h * eta).expm1().neg().sqrt() * noise_sampler(sigmas[i], sigma_s_1)
+ x_2 = x_2 + sde_noise * sigma_s_1 * s_noise
+ denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args)
- # Step 2
- x_3 = sigma_s_2 / sigmas[i] * (-r_2 * h * eta).exp() * x - alpha_s_2 * coeff_2 * denoised + (r_2 / r_1) * alpha_s_2 * (coeff_2 / (r_2 * h_eta) + 1) * (denoised_2 - denoised)
- if inject_noise:
- x_3 = x_3 + sigma_s_2 * (noise_coeff_2 * noise_1 + noise_coeff_1 * noise_2) * s_noise
- denoised_3 = model(x_3, sigma_s_2 * s_in, **extra_args)
+ # Step 2
+ a3_2 = r_2 / r_1 * ei_h_phi_2(-r_2 * h_eta)
+ a3_1 = ei_h_phi_1(-r_2 * h_eta) - a3_2
+ x_3 = sigma_s_2 / sigmas[i] * (-r_2 * h * eta).exp() * x - alpha_s_2 * (a3_1 * denoised + a3_2 * denoised_2)
+ if inject_noise:
+ segment_factor = (r_1 - r_2) * h * eta
+ sde_noise = sde_noise * segment_factor.exp()
+ sde_noise = sde_noise + segment_factor.mul(2).expm1().neg().sqrt() * noise_sampler(sigma_s_1, sigma_s_2)
+ x_3 = x_3 + sde_noise * sigma_s_2 * s_noise
+ denoised_3 = model(x_3, sigma_s_2 * s_in, **extra_args)
- # Step 3
- x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * coeff_3 * denoised + (1. / r_2) * alpha_t * (coeff_3 / h_eta + 1) * (denoised_3 - denoised)
- if inject_noise:
- x = x + sigmas[i + 1] * (noise_coeff_3 * noise_1 + noise_coeff_2 * noise_2 + noise_coeff_1 * noise_3) * s_noise
+ # Step 3
+ b3 = ei_h_phi_2(-h_eta) / r_2
+ b1 = ei_h_phi_1(-h_eta) - b3
+ x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * (b1 * denoised + b3 * denoised_3)
+ if inject_noise:
+ segment_factor = (r_2 - 1) * h * eta
+ sde_noise = sde_noise * segment_factor.exp()
+ sde_noise = sde_noise + segment_factor.mul(2).expm1().neg().sqrt() * noise_sampler(sigma_s_2, sigmas[i + 1])
+ x = x + sde_noise * sigmas[i + 1] * s_noise
return x
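The refactored SEEDS samplers express their coefficients through the exponential-integrator identities h·phi_1(h) = e^h - 1 and h·phi_2(h) = (e^h - 1 - h)/h. A quick standalone check that the helpers match those definitions:

```python
# Numerical sanity check of the ei_h_phi_* helpers introduced above.
import torch

def ei_h_phi_1(h):
    return torch.expm1(h)

def ei_h_phi_2(h):
    return (torch.expm1(h) - h) / h

h = torch.tensor(-0.3)
phi_1 = torch.expm1(h) / h   # phi_1(h) = (e^h - 1) / h
phi_2 = (phi_1 - 1) / h      # phi_2(h) = (phi_1(h) - 1) / h
print(torch.allclose(ei_h_phi_1(h), h * phi_1))  # True
print(torch.allclose(ei_h_phi_2(h), h * phi_2))  # True
```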
diff --git a/comfy/latent_formats.py b/comfy/latent_formats.py
index caf4991fc..77e642a94 100644
--- a/comfy/latent_formats.py
+++ b/comfy/latent_formats.py
@@ -533,11 +533,94 @@ class Wan22(Wan21):
0.3971, 1.0600, 0.3943, 0.5537, 0.5444, 0.4089, 0.7468, 0.7744
]).view(1, self.latent_channels, 1, 1, 1)
+class HunyuanImage21(LatentFormat):
+ latent_channels = 64
+ latent_dimensions = 2
+ scale_factor = 0.75289
+
+ latent_rgb_factors = [
+ [-0.0154, -0.0397, -0.0521],
+ [ 0.0005, 0.0093, 0.0006],
+ [-0.0805, -0.0773, -0.0586],
+ [-0.0494, -0.0487, -0.0498],
+ [-0.0212, -0.0076, -0.0261],
+ [-0.0179, -0.0417, -0.0505],
+ [ 0.0158, 0.0310, 0.0239],
+ [ 0.0409, 0.0516, 0.0201],
+ [ 0.0350, 0.0553, 0.0036],
+ [-0.0447, -0.0327, -0.0479],
+ [-0.0038, -0.0221, -0.0365],
+ [-0.0423, -0.0718, -0.0654],
+ [ 0.0039, 0.0368, 0.0104],
+ [ 0.0655, 0.0217, 0.0122],
+ [ 0.0490, 0.1638, 0.2053],
+ [ 0.0932, 0.0829, 0.0650],
+ [-0.0186, -0.0209, -0.0135],
+ [-0.0080, -0.0076, -0.0148],
+ [-0.0284, -0.0201, 0.0011],
+ [-0.0642, -0.0294, -0.0777],
+ [-0.0035, 0.0076, -0.0140],
+ [ 0.0519, 0.0731, 0.0887],
+ [-0.0102, 0.0095, 0.0704],
+ [ 0.0068, 0.0218, -0.0023],
+ [-0.0726, -0.0486, -0.0519],
+ [ 0.0260, 0.0295, 0.0263],
+ [ 0.0250, 0.0333, 0.0341],
+ [ 0.0168, -0.0120, -0.0174],
+ [ 0.0226, 0.1037, 0.0114],
+ [ 0.2577, 0.1906, 0.1604],
+ [-0.0646, -0.0137, -0.0018],
+ [-0.0112, 0.0309, 0.0358],
+ [-0.0347, 0.0146, -0.0481],
+ [ 0.0234, 0.0179, 0.0201],
+ [ 0.0157, 0.0313, 0.0225],
+ [ 0.0423, 0.0675, 0.0524],
+ [-0.0031, 0.0027, -0.0255],
+ [ 0.0447, 0.0555, 0.0330],
+ [-0.0152, 0.0103, 0.0299],
+ [-0.0755, -0.0489, -0.0635],
+ [ 0.0853, 0.0788, 0.1017],
+ [-0.0272, -0.0294, -0.0471],
+ [ 0.0440, 0.0400, -0.0137],
+ [ 0.0335, 0.0317, -0.0036],
+ [-0.0344, -0.0621, -0.0984],
+ [-0.0127, -0.0630, -0.0620],
+ [-0.0648, 0.0360, 0.0924],
+ [-0.0781, -0.0801, -0.0409],
+ [ 0.0363, 0.0613, 0.0499],
+ [ 0.0238, 0.0034, 0.0041],
+ [-0.0135, 0.0258, 0.0310],
+ [ 0.0614, 0.1086, 0.0589],
+ [ 0.0428, 0.0350, 0.0205],
+ [ 0.0153, 0.0173, -0.0018],
+ [-0.0288, -0.0455, -0.0091],
+ [ 0.0344, 0.0109, -0.0157],
+ [-0.0205, -0.0247, -0.0187],
+ [ 0.0487, 0.0126, 0.0064],
+ [-0.0220, -0.0013, 0.0074],
+ [-0.0203, -0.0094, -0.0048],
+ [-0.0719, 0.0429, -0.0442],
+ [ 0.1042, 0.0497, 0.0356],
+ [-0.0659, -0.0578, -0.0280],
+ [-0.0060, -0.0322, -0.0234]]
+
+ latent_rgb_factors_bias = [0.0007, -0.0256, -0.0206]
+
+class HunyuanImage21Refiner(LatentFormat):
+ latent_channels = 64
+ latent_dimensions = 3
+ scale_factor = 1.03682
+
class Hunyuan3Dv2(LatentFormat):
latent_channels = 64
latent_dimensions = 1
scale_factor = 0.9990943042622529
+class Hunyuan3Dv2_1(LatentFormat):
+ scale_factor = 1.0039506158752403
+ latent_channels = 64
+ latent_dimensions = 1
+
class Hunyuan3Dv2mini(LatentFormat):
latent_channels = 64
latent_dimensions = 1
@@ -546,3 +629,20 @@ class Hunyuan3Dv2mini(LatentFormat):
class ACEAudio(LatentFormat):
latent_channels = 8
latent_dimensions = 2
+
+class ChromaRadiance(LatentFormat):
+ latent_channels = 3
+
+ def __init__(self):
+ self.latent_rgb_factors = [
+ # R G B
+ [ 1.0, 0.0, 0.0 ],
+ [ 0.0, 1.0, 0.0 ],
+ [ 0.0, 0.0, 1.0 ]
+ ]
+
+ def process_in(self, latent):
+ return latent
+
+ def process_out(self, latent):
+ return latent
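The `latent_rgb_factors` / `latent_rgb_factors_bias` tables exist so previews can project latents to RGB with a single linear map. A shape-level sketch of that projection (random stand-in factors, values are placeholders):

```python
# Shape-level sketch of the linear latent->RGB preview projection these
# factor tables are meant for; the factor values here are placeholders.
import torch

latent = torch.randn(1, 64, 32, 32)  # HunyuanImage21-style latent: 64 channels
rgb_factors = torch.randn(64, 3)     # stand-in for latent_rgb_factors (64 x 3)
rgb_bias = torch.tensor([0.0007, -0.0256, -0.0206])  # latent_rgb_factors_bias

preview = torch.einsum("bchw,cr->brhw", latent, rgb_factors) + rgb_bias.view(1, 3, 1, 1)
print(preview.shape)  # torch.Size([1, 3, 32, 32])
```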
diff --git a/comfy/ldm/ace/attention.py b/comfy/ldm/ace/attention.py
index f20a01669..670eb9783 100644
--- a/comfy/ldm/ace/attention.py
+++ b/comfy/ldm/ace/attention.py
@@ -133,6 +133,7 @@ class Attention(nn.Module):
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
+ transformer_options={},
**cross_attention_kwargs,
) -> torch.Tensor:
return self.processor(
@@ -140,6 +141,7 @@ class Attention(nn.Module):
hidden_states,
encoder_hidden_states=encoder_hidden_states,
attention_mask=attention_mask,
+ transformer_options=transformer_options,
**cross_attention_kwargs,
)
@@ -366,6 +368,7 @@ class CustomerAttnProcessor2_0:
encoder_attention_mask: Optional[torch.FloatTensor] = None,
rotary_freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
rotary_freqs_cis_cross: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
+ transformer_options={},
*args,
**kwargs,
) -> torch.Tensor:
@@ -433,7 +436,7 @@ class CustomerAttnProcessor2_0:
# the output of sdp = (batch, num_heads, seq_len, head_dim)
hidden_states = optimized_attention(
- query, key, value, heads=query.shape[1], mask=attention_mask, skip_reshape=True,
+ query, key, value, heads=query.shape[1], mask=attention_mask, skip_reshape=True, transformer_options=transformer_options,
).to(query.dtype)
# linear proj
@@ -697,6 +700,7 @@ class LinearTransformerBlock(nn.Module):
rotary_freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
rotary_freqs_cis_cross: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
temb: torch.FloatTensor = None,
+ transformer_options={},
):
N = hidden_states.shape[0]
@@ -720,6 +724,7 @@ class LinearTransformerBlock(nn.Module):
encoder_attention_mask=encoder_attention_mask,
rotary_freqs_cis=rotary_freqs_cis,
rotary_freqs_cis_cross=rotary_freqs_cis_cross,
+ transformer_options=transformer_options,
)
else:
attn_output, _ = self.attn(
@@ -729,6 +734,7 @@ class LinearTransformerBlock(nn.Module):
encoder_attention_mask=None,
rotary_freqs_cis=rotary_freqs_cis,
rotary_freqs_cis_cross=None,
+ transformer_options=transformer_options,
)
if self.use_adaln_single:
@@ -743,6 +749,7 @@ class LinearTransformerBlock(nn.Module):
encoder_attention_mask=encoder_attention_mask,
rotary_freqs_cis=rotary_freqs_cis,
rotary_freqs_cis_cross=rotary_freqs_cis_cross,
+ transformer_options=transformer_options,
)
hidden_states = attn_output + hidden_states
diff --git a/comfy/ldm/ace/model.py b/comfy/ldm/ace/model.py
index 12c524701..399329853 100644
--- a/comfy/ldm/ace/model.py
+++ b/comfy/ldm/ace/model.py
@@ -19,6 +19,7 @@ import torch
from torch import nn
import comfy.model_management
+import comfy.patcher_extension
from comfy.ldm.lightricks.model import TimestepEmbedding, Timesteps
from .attention import LinearTransformerBlock, t2i_modulate
@@ -313,6 +314,7 @@ class ACEStepTransformer2DModel(nn.Module):
output_length: int = 0,
block_controlnet_hidden_states: Optional[Union[List[torch.Tensor], torch.Tensor]] = None,
controlnet_scale: Union[float, torch.Tensor] = 1.0,
+ transformer_options={},
):
embedded_timestep = self.timestep_embedder(self.time_proj(timestep).to(dtype=hidden_states.dtype))
temb = self.t_block(embedded_timestep)
@@ -338,12 +340,34 @@ class ACEStepTransformer2DModel(nn.Module):
rotary_freqs_cis=rotary_freqs_cis,
rotary_freqs_cis_cross=encoder_rotary_freqs_cis,
temb=temb,
+ transformer_options=transformer_options,
)
output = self.final_layer(hidden_states, embedded_timestep, output_length)
return output
- def forward(
+ def forward(self,
+ x,
+ timestep,
+ attention_mask=None,
+ context: Optional[torch.Tensor] = None,
+ text_attention_mask: Optional[torch.LongTensor] = None,
+ speaker_embeds: Optional[torch.FloatTensor] = None,
+ lyric_token_idx: Optional[torch.LongTensor] = None,
+ lyric_mask: Optional[torch.LongTensor] = None,
+ block_controlnet_hidden_states: Optional[Union[List[torch.Tensor], torch.Tensor]] = None,
+ controlnet_scale: Union[float, torch.Tensor] = 1.0,
+ lyrics_strength=1.0,
+ **kwargs
+ ):
+ return comfy.patcher_extension.WrapperExecutor.new_class_executor(
+ self._forward,
+ self,
+ comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, kwargs.get("transformer_options", {}))
+ ).execute(x, timestep, attention_mask, context, text_attention_mask, speaker_embeds, lyric_token_idx, lyric_mask, block_controlnet_hidden_states,
+ controlnet_scale, lyrics_strength, **kwargs)
+
+ def _forward(
self,
x,
timestep,
@@ -371,6 +395,7 @@ class ACEStepTransformer2DModel(nn.Module):
output_length = hidden_states.shape[-1]
+ transformer_options = kwargs.get("transformer_options", {})
output = self.decode(
hidden_states=hidden_states,
attention_mask=attention_mask,
@@ -380,6 +405,7 @@ class ACEStepTransformer2DModel(nn.Module):
output_length=output_length,
block_controlnet_hidden_states=block_controlnet_hidden_states,
controlnet_scale=controlnet_scale,
+ transformer_options=transformer_options,
)
return output
diff --git a/comfy/ldm/audio/dit.py b/comfy/ldm/audio/dit.py
index 179c5b67e..ca865189e 100644
--- a/comfy/ldm/audio/dit.py
+++ b/comfy/ldm/audio/dit.py
@@ -298,7 +298,8 @@ class Attention(nn.Module):
mask = None,
context_mask = None,
rotary_pos_emb = None,
- causal = None
+ causal = None,
+ transformer_options={},
):
h, kv_h, has_context = self.num_heads, self.kv_heads, context is not None
@@ -363,7 +364,7 @@ class Attention(nn.Module):
heads_per_kv_head = h // kv_h
k, v = map(lambda t: t.repeat_interleave(heads_per_kv_head, dim = 1), (k, v))
- out = optimized_attention(q, k, v, h, skip_reshape=True)
+ out = optimized_attention(q, k, v, h, skip_reshape=True, transformer_options=transformer_options)
out = self.to_out(out)
if mask is not None:
@@ -488,7 +489,8 @@ class TransformerBlock(nn.Module):
global_cond=None,
mask = None,
context_mask = None,
- rotary_pos_emb = None
+ rotary_pos_emb = None,
+ transformer_options={}
):
if self.global_cond_dim is not None and self.global_cond_dim > 0 and global_cond is not None:
@@ -498,12 +500,12 @@ class TransformerBlock(nn.Module):
residual = x
x = self.pre_norm(x)
x = x * (1 + scale_self) + shift_self
- x = self.self_attn(x, mask = mask, rotary_pos_emb = rotary_pos_emb)
+ x = self.self_attn(x, mask = mask, rotary_pos_emb = rotary_pos_emb, transformer_options=transformer_options)
x = x * torch.sigmoid(1 - gate_self)
x = x + residual
if context is not None:
- x = x + self.cross_attn(self.cross_attend_norm(x), context = context, context_mask = context_mask)
+ x = x + self.cross_attn(self.cross_attend_norm(x), context = context, context_mask = context_mask, transformer_options=transformer_options)
if self.conformer is not None:
x = x + self.conformer(x)
@@ -517,10 +519,10 @@ class TransformerBlock(nn.Module):
x = x + residual
else:
- x = x + self.self_attn(self.pre_norm(x), mask = mask, rotary_pos_emb = rotary_pos_emb)
+ x = x + self.self_attn(self.pre_norm(x), mask = mask, rotary_pos_emb = rotary_pos_emb, transformer_options=transformer_options)
if context is not None:
- x = x + self.cross_attn(self.cross_attend_norm(x), context = context, context_mask = context_mask)
+ x = x + self.cross_attn(self.cross_attend_norm(x), context = context, context_mask = context_mask, transformer_options=transformer_options)
if self.conformer is not None:
x = x + self.conformer(x)
@@ -606,7 +608,8 @@ class ContinuousTransformer(nn.Module):
return_info = False,
**kwargs
):
- patches_replace = kwargs.get("transformer_options", {}).get("patches_replace", {})
+ transformer_options = kwargs.get("transformer_options", {})
+ patches_replace = transformer_options.get("patches_replace", {})
batch, seq, device = *x.shape[:2], x.device
context = kwargs["context"]
@@ -632,7 +635,7 @@ class ContinuousTransformer(nn.Module):
# Attention layers
if self.rotary_pos_emb is not None:
- rotary_pos_emb = self.rotary_pos_emb.forward_from_seq_len(x.shape[1], dtype=x.dtype, device=x.device)
+ rotary_pos_emb = self.rotary_pos_emb.forward_from_seq_len(x.shape[1], dtype=torch.float, device=x.device)
else:
rotary_pos_emb = None
@@ -645,13 +648,13 @@ class ContinuousTransformer(nn.Module):
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
- out["img"] = layer(args["img"], rotary_pos_emb=args["pe"], global_cond=args["vec"], context=args["txt"])
+ out["img"] = layer(args["img"], rotary_pos_emb=args["pe"], global_cond=args["vec"], context=args["txt"], transformer_options=args["transformer_options"])
return out
- out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": global_cond, "pe": rotary_pos_emb}, {"original_block": block_wrap})
+ out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": global_cond, "pe": rotary_pos_emb, "transformer_options": transformer_options}, {"original_block": block_wrap})
x = out["img"]
else:
- x = layer(x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, context=context)
+ x = layer(x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, context=context, transformer_options=transformer_options)
# x = checkpoint(layer, x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, **kwargs)
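+ # Each entry in blocks_replace (populated from the "patches_replace" transformer option
+ # read above) is called with the block inputs as a dict plus {"original_block": block_wrap}
+ # and must return a dict with the same keys. A minimal, hypothetical replacement patch:
+ #
+ #   def my_double_block_patch(args, extra):
+ #       out = extra["original_block"](args)   # run the unmodified layer
+ #       out["img"] = out["img"] * 1.0         # then post-process its output
+ #       return out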
if return_info:
diff --git a/comfy/ldm/aura/mmdit.py b/comfy/ldm/aura/mmdit.py
index 1258ae11f..66d9613b6 100644
--- a/comfy/ldm/aura/mmdit.py
+++ b/comfy/ldm/aura/mmdit.py
@@ -9,6 +9,7 @@ import torch.nn.functional as F
from comfy.ldm.modules.attention import optimized_attention
import comfy.ops
+import comfy.patcher_extension
import comfy.ldm.common_dit
def modulate(x, shift, scale):
@@ -84,7 +85,7 @@ class SingleAttention(nn.Module):
)
#@torch.compile()
- def forward(self, c):
+ def forward(self, c, transformer_options={}):
bsz, seqlen1, _ = c.shape
@@ -94,7 +95,7 @@ class SingleAttention(nn.Module):
v = v.view(bsz, seqlen1, self.n_heads, self.head_dim)
q, k = self.q_norm1(q), self.k_norm1(k)
- output = optimized_attention(q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3), self.n_heads, skip_reshape=True)
+ output = optimized_attention(q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3), self.n_heads, skip_reshape=True, transformer_options=transformer_options)
c = self.w1o(output)
return c
@@ -143,7 +144,7 @@ class DoubleAttention(nn.Module):
#@torch.compile()
- def forward(self, c, x):
+ def forward(self, c, x, transformer_options={}):
bsz, seqlen1, _ = c.shape
bsz, seqlen2, _ = x.shape
@@ -167,7 +168,7 @@ class DoubleAttention(nn.Module):
torch.cat([cv, xv], dim=1),
)
- output = optimized_attention(q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3), self.n_heads, skip_reshape=True)
+ output = optimized_attention(q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3), self.n_heads, skip_reshape=True, transformer_options=transformer_options)
c, x = output.split([seqlen1, seqlen2], dim=1)
c = self.w1o(c)
@@ -206,7 +207,7 @@ class MMDiTBlock(nn.Module):
self.is_last = is_last
#@torch.compile()
- def forward(self, c, x, global_cond, **kwargs):
+ def forward(self, c, x, global_cond, transformer_options={}, **kwargs):
cres, xres = c, x
@@ -224,7 +225,7 @@ class MMDiTBlock(nn.Module):
x = modulate(self.normX1(x), xshift_msa, xscale_msa)
# attention
- c, x = self.attn(c, x)
+ c, x = self.attn(c, x, transformer_options=transformer_options)
c = self.normC2(cres + cgate_msa.unsqueeze(1) * c)
@@ -254,13 +255,13 @@ class DiTBlock(nn.Module):
self.mlp = MLP(dim, hidden_dim=dim * 4, dtype=dtype, device=device, operations=operations)
#@torch.compile()
- def forward(self, cx, global_cond, **kwargs):
+ def forward(self, cx, global_cond, transformer_options={}, **kwargs):
cxres = cx
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.modCX(
global_cond
).chunk(6, dim=1)
cx = modulate(self.norm1(cx), shift_msa, scale_msa)
- cx = self.attn(cx)
+ cx = self.attn(cx, transformer_options=transformer_options)
cx = self.norm2(cxres + gate_msa.unsqueeze(1) * cx)
mlpout = self.mlp(modulate(cx, shift_mlp, scale_mlp))
cx = gate_mlp.unsqueeze(1) * mlpout
@@ -436,6 +437,13 @@ class MMDiT(nn.Module):
return x + pos_encoding.reshape(1, -1, self.positional_encoding.shape[-1])
def forward(self, x, timestep, context, transformer_options={}, **kwargs):
+ return comfy.patcher_extension.WrapperExecutor.new_class_executor(
+ self._forward,
+ self,
+ comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
+ ).execute(x, timestep, context, transformer_options, **kwargs)
+
+ def _forward(self, x, timestep, context, transformer_options={}, **kwargs):
patches_replace = transformer_options.get("patches_replace", {})
# patchify x, add PE
b, c, h, w = x.shape
@@ -465,13 +473,14 @@ class MMDiT(nn.Module):
out = {}
out["txt"], out["img"] = layer(args["txt"],
args["img"],
- args["vec"])
+ args["vec"],
+ transformer_options=args["transformer_options"])
return out
- out = blocks_replace[("double_block", i)]({"img": x, "txt": c, "vec": global_cond}, {"original_block": block_wrap})
+ out = blocks_replace[("double_block", i)]({"img": x, "txt": c, "vec": global_cond, "transformer_options": transformer_options}, {"original_block": block_wrap})
c = out["txt"]
x = out["img"]
else:
- c, x = layer(c, x, global_cond, **kwargs)
+ c, x = layer(c, x, global_cond, transformer_options=transformer_options, **kwargs)
if len(self.single_layers) > 0:
c_len = c.size(1)
@@ -480,13 +489,13 @@ class MMDiT(nn.Module):
if ("single_block", i) in blocks_replace:
def block_wrap(args):
out = {}
- out["img"] = layer(args["img"], args["vec"])
+ out["img"] = layer(args["img"], args["vec"], transformer_options=args["transformer_options"])
return out
- out = blocks_replace[("single_block", i)]({"img": cx, "vec": global_cond}, {"original_block": block_wrap})
+ out = blocks_replace[("single_block", i)]({"img": cx, "vec": global_cond, "transformer_options": transformer_options}, {"original_block": block_wrap})
cx = out["img"]
else:
- cx = layer(cx, global_cond, **kwargs)
+ cx = layer(cx, global_cond, transformer_options=transformer_options, **kwargs)
x = cx[:, c_len:]
diff --git a/comfy/ldm/cascade/common.py b/comfy/ldm/cascade/common.py
index 3eaa0c821..42ef98c7a 100644
--- a/comfy/ldm/cascade/common.py
+++ b/comfy/ldm/cascade/common.py
@@ -32,12 +32,12 @@ class OptimizedAttention(nn.Module):
self.out_proj = operations.Linear(c, c, bias=True, dtype=dtype, device=device)
- def forward(self, q, k, v):
+ def forward(self, q, k, v, transformer_options={}):
q = self.to_q(q)
k = self.to_k(k)
v = self.to_v(v)
- out = optimized_attention(q, k, v, self.heads)
+ out = optimized_attention(q, k, v, self.heads, transformer_options=transformer_options)
return self.out_proj(out)
@@ -47,13 +47,13 @@ class Attention2D(nn.Module):
self.attn = OptimizedAttention(c, nhead, dtype=dtype, device=device, operations=operations)
# self.attn = nn.MultiheadAttention(c, nhead, dropout=dropout, bias=True, batch_first=True, dtype=dtype, device=device)
- def forward(self, x, kv, self_attn=False):
+ def forward(self, x, kv, self_attn=False, transformer_options={}):
orig_shape = x.shape
x = x.view(x.size(0), x.size(1), -1).permute(0, 2, 1) # Bx4xHxW -> Bx(HxW)x4
if self_attn:
kv = torch.cat([x, kv], dim=1)
# x = self.attn(x, kv, kv, need_weights=False)[0]
- x = self.attn(x, kv, kv)
+ x = self.attn(x, kv, kv, transformer_options=transformer_options)
x = x.permute(0, 2, 1).view(*orig_shape)
return x
@@ -114,9 +114,9 @@ class AttnBlock(nn.Module):
operations.Linear(c_cond, c, dtype=dtype, device=device)
)
- def forward(self, x, kv):
+ def forward(self, x, kv, transformer_options={}):
kv = self.kv_mapper(kv)
- x = x + self.attention(self.norm(x), kv, self_attn=self.self_attn)
+ x = x + self.attention(self.norm(x), kv, self_attn=self.self_attn, transformer_options=transformer_options)
return x
diff --git a/comfy/ldm/cascade/stage_b.py b/comfy/ldm/cascade/stage_b.py
index 773830956..428c67fdf 100644
--- a/comfy/ldm/cascade/stage_b.py
+++ b/comfy/ldm/cascade/stage_b.py
@@ -173,7 +173,7 @@ class StageB(nn.Module):
clip = self.clip_norm(clip)
return clip
- def _down_encode(self, x, r_embed, clip):
+ def _down_encode(self, x, r_embed, clip, transformer_options={}):
level_outputs = []
block_group = zip(self.down_blocks, self.down_downscalers, self.down_repeat_mappers)
for down_block, downscaler, repmap in block_group:
@@ -187,7 +187,7 @@ class StageB(nn.Module):
elif isinstance(block, AttnBlock) or (
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
AttnBlock)):
- x = block(x, clip)
+ x = block(x, clip, transformer_options=transformer_options)
elif isinstance(block, TimestepBlock) or (
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
TimestepBlock)):
@@ -199,7 +199,7 @@ class StageB(nn.Module):
level_outputs.insert(0, x)
return level_outputs
- def _up_decode(self, level_outputs, r_embed, clip):
+ def _up_decode(self, level_outputs, r_embed, clip, transformer_options={}):
x = level_outputs[0]
block_group = zip(self.up_blocks, self.up_upscalers, self.up_repeat_mappers)
for i, (up_block, upscaler, repmap) in enumerate(block_group):
@@ -216,7 +216,7 @@ class StageB(nn.Module):
elif isinstance(block, AttnBlock) or (
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
AttnBlock)):
- x = block(x, clip)
+ x = block(x, clip, transformer_options=transformer_options)
elif isinstance(block, TimestepBlock) or (
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
TimestepBlock)):
@@ -228,7 +228,7 @@ class StageB(nn.Module):
x = upscaler(x)
return x
- def forward(self, x, r, effnet, clip, pixels=None, **kwargs):
+ def forward(self, x, r, effnet, clip, pixels=None, transformer_options={}, **kwargs):
if pixels is None:
pixels = x.new_zeros(x.size(0), 3, 8, 8)
@@ -245,8 +245,8 @@ class StageB(nn.Module):
nn.functional.interpolate(effnet, size=x.shape[-2:], mode='bilinear', align_corners=True))
x = x + nn.functional.interpolate(self.pixels_mapper(pixels), size=x.shape[-2:], mode='bilinear',
align_corners=True)
- level_outputs = self._down_encode(x, r_embed, clip)
- x = self._up_decode(level_outputs, r_embed, clip)
+ level_outputs = self._down_encode(x, r_embed, clip, transformer_options=transformer_options)
+ x = self._up_decode(level_outputs, r_embed, clip, transformer_options=transformer_options)
return self.clf(x)
def update_weights_ema(self, src_model, beta=0.999):
diff --git a/comfy/ldm/cascade/stage_c.py b/comfy/ldm/cascade/stage_c.py
index b952d0349..ebc4434e2 100644
--- a/comfy/ldm/cascade/stage_c.py
+++ b/comfy/ldm/cascade/stage_c.py
@@ -182,7 +182,7 @@ class StageC(nn.Module):
clip = self.clip_norm(clip)
return clip
- def _down_encode(self, x, r_embed, clip, cnet=None):
+ def _down_encode(self, x, r_embed, clip, cnet=None, transformer_options={}):
level_outputs = []
block_group = zip(self.down_blocks, self.down_downscalers, self.down_repeat_mappers)
for down_block, downscaler, repmap in block_group:
@@ -201,7 +201,7 @@ class StageC(nn.Module):
elif isinstance(block, AttnBlock) or (
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
AttnBlock)):
- x = block(x, clip)
+ x = block(x, clip, transformer_options=transformer_options)
elif isinstance(block, TimestepBlock) or (
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
TimestepBlock)):
@@ -213,7 +213,7 @@ class StageC(nn.Module):
level_outputs.insert(0, x)
return level_outputs
- def _up_decode(self, level_outputs, r_embed, clip, cnet=None):
+ def _up_decode(self, level_outputs, r_embed, clip, cnet=None, transformer_options={}):
x = level_outputs[0]
block_group = zip(self.up_blocks, self.up_upscalers, self.up_repeat_mappers)
for i, (up_block, upscaler, repmap) in enumerate(block_group):
@@ -235,7 +235,7 @@ class StageC(nn.Module):
elif isinstance(block, AttnBlock) or (
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
AttnBlock)):
- x = block(x, clip)
+ x = block(x, clip, transformer_options=transformer_options)
elif isinstance(block, TimestepBlock) or (
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
TimestepBlock)):
@@ -247,7 +247,7 @@ class StageC(nn.Module):
x = upscaler(x)
return x
- def forward(self, x, r, clip_text, clip_text_pooled, clip_img, control=None, **kwargs):
+ def forward(self, x, r, clip_text, clip_text_pooled, clip_img, control=None, transformer_options={}, **kwargs):
# Process the conditioning embeddings
r_embed = self.gen_r_embedding(r).to(dtype=x.dtype)
for c in self.t_conds:
@@ -262,8 +262,8 @@ class StageC(nn.Module):
# Model Blocks
x = self.embedding(x)
- level_outputs = self._down_encode(x, r_embed, clip, cnet)
- x = self._up_decode(level_outputs, r_embed, clip, cnet)
+ level_outputs = self._down_encode(x, r_embed, clip, cnet, transformer_options=transformer_options)
+ x = self._up_decode(level_outputs, r_embed, clip, cnet, transformer_options=transformer_options)
return self.clf(x)
def update_weights_ema(self, src_model, beta=0.999):
diff --git a/comfy/ldm/chroma/layers.py b/comfy/ldm/chroma/layers.py
index 2a0dec606..fc7110cce 100644
--- a/comfy/ldm/chroma/layers.py
+++ b/comfy/ldm/chroma/layers.py
@@ -76,7 +76,7 @@ class DoubleStreamBlock(nn.Module):
)
self.flipped_img_txt = flipped_img_txt
- def forward(self, img: Tensor, txt: Tensor, pe: Tensor, vec: Tensor, attn_mask=None):
+ def forward(self, img: Tensor, txt: Tensor, pe: Tensor, vec: Tensor, attn_mask=None, transformer_options={}):
(img_mod1, img_mod2), (txt_mod1, txt_mod2) = vec
# prepare image for attention
@@ -95,7 +95,7 @@ class DoubleStreamBlock(nn.Module):
attn = attention(torch.cat((txt_q, img_q), dim=2),
torch.cat((txt_k, img_k), dim=2),
torch.cat((txt_v, img_v), dim=2),
- pe=pe, mask=attn_mask)
+ pe=pe, mask=attn_mask, transformer_options=transformer_options)
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
@@ -148,7 +148,7 @@ class SingleStreamBlock(nn.Module):
self.mlp_act = nn.GELU(approximate="tanh")
- def forward(self, x: Tensor, pe: Tensor, vec: Tensor, attn_mask=None) -> Tensor:
+ def forward(self, x: Tensor, pe: Tensor, vec: Tensor, attn_mask=None, transformer_options={}) -> Tensor:
mod = vec
x_mod = torch.addcmul(mod.shift, 1 + mod.scale, self.pre_norm(x))
qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
@@ -157,7 +157,7 @@ class SingleStreamBlock(nn.Module):
q, k = self.norm(q, k, v)
# compute attention
- attn = attention(q, k, v, pe=pe, mask=attn_mask)
+ attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options)
# compute activation in mlp stream, cat again and run second linear layer
output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
x.addcmul_(mod.gate, output)
diff --git a/comfy/ldm/chroma/model.py b/comfy/ldm/chroma/model.py
index 06021d4f2..ad1c523fe 100644
--- a/comfy/ldm/chroma/model.py
+++ b/comfy/ldm/chroma/model.py
@@ -5,6 +5,7 @@ from dataclasses import dataclass
import torch
from torch import Tensor, nn
from einops import rearrange, repeat
+import comfy.patcher_extension
import comfy.ldm.common_dit
from comfy.ldm.flux.layers import (
@@ -150,8 +151,6 @@ class Chroma(nn.Module):
attn_mask: Tensor = None,
) -> Tensor:
patches_replace = transformer_options.get("patches_replace", {})
- if img.ndim != 3 or txt.ndim != 3:
- raise ValueError("Input img and txt tensors must have 3 dimensions.")
# running on sequences img
img = self.img_in(img)
@@ -192,14 +191,16 @@ class Chroma(nn.Module):
txt=args["txt"],
vec=args["vec"],
pe=args["pe"],
- attn_mask=args.get("attn_mask"))
+ attn_mask=args.get("attn_mask"),
+ transformer_options=args.get("transformer_options"))
return out
out = blocks_replace[("double_block", i)]({"img": img,
"txt": txt,
"vec": double_mod,
"pe": pe,
- "attn_mask": attn_mask},
+ "attn_mask": attn_mask,
+ "transformer_options": transformer_options},
{"original_block": block_wrap})
txt = out["txt"]
img = out["img"]
@@ -208,7 +209,8 @@ class Chroma(nn.Module):
txt=txt,
vec=double_mod,
pe=pe,
- attn_mask=attn_mask)
+ attn_mask=attn_mask,
+ transformer_options=transformer_options)
if control is not None: # Controlnet
control_i = control.get("input")
@@ -228,17 +230,19 @@ class Chroma(nn.Module):
out["img"] = block(args["img"],
vec=args["vec"],
pe=args["pe"],
- attn_mask=args.get("attn_mask"))
+ attn_mask=args.get("attn_mask"),
+ transformer_options=args.get("transformer_options"))
return out
out = blocks_replace[("single_block", i)]({"img": img,
"vec": single_mod,
"pe": pe,
- "attn_mask": attn_mask},
+ "attn_mask": attn_mask,
+ "transformer_options": transformer_options},
{"original_block": block_wrap})
img = out["img"]
else:
- img = block(img, vec=single_mod, pe=pe, attn_mask=attn_mask)
+ img = block(img, vec=single_mod, pe=pe, attn_mask=attn_mask, transformer_options=transformer_options)
if control is not None: # Controlnet
control_o = control.get("output")
@@ -248,16 +252,27 @@ class Chroma(nn.Module):
img[:, txt.shape[1] :, ...] += add
img = img[:, txt.shape[1] :, ...]
- final_mod = self.get_modulations(mod_vectors, "final")
- img = self.final_layer(img, vec=final_mod) # (N, T, patch_size ** 2 * out_channels)
+ if hasattr(self, "final_layer"):
+ final_mod = self.get_modulations(mod_vectors, "final")
+ img = self.final_layer(img, vec=final_mod) # (N, T, patch_size ** 2 * out_channels)
return img
def forward(self, x, timestep, context, guidance, control=None, transformer_options={}, **kwargs):
+ return comfy.patcher_extension.WrapperExecutor.new_class_executor(
+ self._forward,
+ self,
+ comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
+ ).execute(x, timestep, context, guidance, control, transformer_options, **kwargs)
+
+ def _forward(self, x, timestep, context, guidance, control=None, transformer_options={}, **kwargs):
bs, c, h, w = x.shape
x = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size))
img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=self.patch_size, pw=self.patch_size)
+ if img.ndim != 3 or context.ndim != 3:
+ raise ValueError("Input img and txt tensors must have 3 dimensions.")
+
h_len = ((h + (self.patch_size // 2)) // self.patch_size)
w_len = ((w + (self.patch_size // 2)) // self.patch_size)
img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
diff --git a/comfy/ldm/chroma_radiance/layers.py b/comfy/ldm/chroma_radiance/layers.py
new file mode 100644
index 000000000..3c7bc9b6b
--- /dev/null
+++ b/comfy/ldm/chroma_radiance/layers.py
@@ -0,0 +1,206 @@
+# Adapted from https://github.com/lodestone-rock/flow
+from functools import lru_cache
+
+import torch
+from torch import nn
+
+from comfy.ldm.flux.layers import RMSNorm
+
+
+class NerfEmbedder(nn.Module):
+ """
+ An embedder module that combines input features with a 2D positional
+ encoding that mimics the Discrete Cosine Transform (DCT).
+
+ This module takes an input tensor of shape (B, P^2, C), where P is the
+ patch size, and enriches it with positional information before projecting
+ it to a new hidden size.
+ """
+ def __init__(
+ self,
+ in_channels: int,
+ hidden_size_input: int,
+ max_freqs: int,
+ dtype=None,
+ device=None,
+ operations=None,
+ ):
+ """
+ Initializes the NerfEmbedder.
+
+ Args:
+ in_channels (int): The number of channels in the input tensor.
+ hidden_size_input (int): The desired dimension of the output embedding.
+ max_freqs (int): The number of frequency components to use for both
+ the x and y dimensions of the positional encoding.
+ The total number of positional features will be max_freqs^2.
+ """
+ super().__init__()
+ self.dtype = dtype
+ self.max_freqs = max_freqs
+ self.hidden_size_input = hidden_size_input
+
+ # A linear layer to project the concatenated input features and
+ # positional encodings to the final output dimension.
+ self.embedder = nn.Sequential(
+ operations.Linear(in_channels + max_freqs**2, hidden_size_input, dtype=dtype, device=device)
+ )
+
+ @lru_cache(maxsize=4)
+ def fetch_pos(self, patch_size: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
+ """
+ Generates and caches 2D DCT-like positional embeddings for a given patch size.
+
+ The LRU cache is a performance optimization that avoids recomputing the
+ same positional grid on every forward pass.
+
+ Args:
+ patch_size (int): The side length of the square input patch.
+ device: The torch device to create the tensors on.
+ dtype: The torch dtype for the tensors.
+
+ Returns:
+ A tensor of shape (1, patch_size^2, max_freqs^2) containing the
+ positional embeddings.
+ """
+ # Create normalized 1D coordinate grids from 0 to 1.
+ pos_x = torch.linspace(0, 1, patch_size, device=device, dtype=dtype)
+ pos_y = torch.linspace(0, 1, patch_size, device=device, dtype=dtype)
+
+ # Create a 2D meshgrid of coordinates.
+ pos_y, pos_x = torch.meshgrid(pos_y, pos_x, indexing="ij")
+
+ # Reshape positions to be broadcastable with frequencies.
+ # Shape becomes (patch_size^2, 1, 1).
+ pos_x = pos_x.reshape(-1, 1, 1)
+ pos_y = pos_y.reshape(-1, 1, 1)
+
+ # Create a 1D tensor of frequency values from 0 to max_freqs-1.
+ freqs = torch.linspace(0, self.max_freqs - 1, self.max_freqs, dtype=dtype, device=device)
+
+ # Reshape frequencies to be broadcastable for creating 2D basis functions.
+ # freqs_x shape: (1, max_freqs, 1)
+ # freqs_y shape: (1, 1, max_freqs)
+ freqs_x = freqs[None, :, None]
+ freqs_y = freqs[None, None, :]
+
+ # A custom weighting coefficient, not part of standard DCT.
+ # This seems to down-weight the contribution of higher-frequency interactions.
+ coeffs = (1 + freqs_x * freqs_y) ** -1
+
+ # Calculate the 1D cosine basis functions for x and y coordinates.
+ # This is the core of the DCT formulation.
+ dct_x = torch.cos(pos_x * freqs_x * torch.pi)
+ dct_y = torch.cos(pos_y * freqs_y * torch.pi)
+
+ # Combine the 1D basis functions to create 2D basis functions by element-wise
+ # multiplication, and apply the custom coefficients. Broadcasting handles the
+ # combination of all (pos_x, freqs_x) with all (pos_y, freqs_y).
+ # The result is flattened into a feature vector for each position.
+ dct = (dct_x * dct_y * coeffs).view(1, -1, self.max_freqs ** 2)
+
+ return dct
+
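+ # Worked shape example for fetch_pos (illustrative): with patch_size=2 and max_freqs=4,
+ # pos_x/pos_y flatten to (4, 1, 1), freqs_x is (1, 4, 1) and freqs_y is (1, 1, 4), so
+ # dct_x * dct_y * coeffs broadcasts to (4, 4, 4) and the returned tensor has shape
+ # (1, patch_size**2, max_freqs**2) = (1, 4, 16).
+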
+ def forward(self, inputs: torch.Tensor) -> torch.Tensor:
+ """
+ Forward pass for the embedder.
+
+ Args:
+ inputs (Tensor): The input tensor of shape (B, P^2, C).
+
+ Returns:
+ Tensor: The output tensor of shape (B, P^2, hidden_size_input).
+ """
+ # Get the batch size, number of pixels, and number of channels.
+ B, P2, C = inputs.shape
+
+ # Infer the patch side length from the number of pixels (P^2).
+ patch_size = int(P2 ** 0.5)
+
+ input_dtype = inputs.dtype
+ inputs = inputs.to(dtype=self.dtype)
+
+ # Fetch the pre-computed or cached positional embeddings.
+ dct = self.fetch_pos(patch_size, inputs.device, self.dtype)
+
+ # Repeat the positional embeddings for each item in the batch.
+ dct = dct.repeat(B, 1, 1)
+
+ # Concatenate the original input features with the positional embeddings
+ # along the feature dimension.
+ inputs = torch.cat((inputs, dct), dim=-1)
+
+ # Project the combined tensor to the target hidden size.
+ return self.embedder(inputs).to(dtype=input_dtype)
+
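+ # Usage sketch (illustrative; assumes comfy.ops.disable_weight_init for `operations`):
+ #
+ #   import comfy.ops
+ #   emb = NerfEmbedder(in_channels=3, hidden_size_input=64, max_freqs=8,
+ #                      dtype=torch.float32, operations=comfy.ops.disable_weight_init)
+ #   patches = torch.randn(2, 16 * 16, 3)   # (B, P*P, C) with P=16
+ #   out = emb(patches)                     # -> (2, 256, 64)
+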
+
+class NerfGLUBlock(nn.Module):
+ """
+ A NeRF block using a Gated Linear Unit (GLU)-style MLP whose weights are generated per
+ sample from the conditioning vector s.
+ """
+ def __init__(self, hidden_size_s: int, hidden_size_x: int, mlp_ratio, dtype=None, device=None, operations=None):
+ super().__init__()
+ # The generated parameter budget covers three matrices (gate, value, and output
+ # projection), which is more than a plain two-layer MLP would need.
+ total_params = 3 * hidden_size_x**2 * mlp_ratio
+ self.param_generator = operations.Linear(hidden_size_s, total_params, dtype=dtype, device=device)
+ self.norm = RMSNorm(hidden_size_x, dtype=dtype, device=device, operations=operations)
+ self.mlp_ratio = mlp_ratio
+
+
+ def forward(self, x: torch.Tensor, s: torch.Tensor) -> torch.Tensor:
+ batch_size, num_x, hidden_size_x = x.shape
+ mlp_params = self.param_generator(s)
+
+ # Split the generated parameters into three parts for the gate, value, and output projection.
+ fc1_gate_params, fc1_value_params, fc2_params = mlp_params.chunk(3, dim=-1)
+
+ # Reshape the parameters into matrices for batch matrix multiplication.
+ fc1_gate = fc1_gate_params.view(batch_size, hidden_size_x, hidden_size_x * self.mlp_ratio)
+ fc1_value = fc1_value_params.view(batch_size, hidden_size_x, hidden_size_x * self.mlp_ratio)
+ fc2 = fc2_params.view(batch_size, hidden_size_x * self.mlp_ratio, hidden_size_x)
+
+ # Normalize the generated weight matrices as in the original implementation.
+ fc1_gate = torch.nn.functional.normalize(fc1_gate, dim=-2)
+ fc1_value = torch.nn.functional.normalize(fc1_value, dim=-2)
+ fc2 = torch.nn.functional.normalize(fc2, dim=-2)
+
+ res_x = x
+ x = self.norm(x)
+
+ # Gated MLP: SiLU(x @ fc1_gate) gates (x @ fc1_value), then fc2 projects back to hidden_size_x.
+ x = torch.bmm(torch.nn.functional.silu(torch.bmm(x, fc1_gate)) * torch.bmm(x, fc1_value), fc2)
+
+ return x + res_x
+
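+ # Shape note (illustrative): with hidden_size_s=S, hidden_size_x=D and mlp_ratio=r,
+ # param_generator maps s of shape (B*NumPatches, S) to 3*D*D*r values per row; these are
+ # split and reshaped into fc1_gate/fc1_value of shape (B*NumPatches, D, D*r) and fc2 of
+ # shape (B*NumPatches, D*r, D), so torch.bmm applies a different generated MLP to the
+ # pixels of every patch.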
+
+class NerfFinalLayer(nn.Module):
+ def __init__(self, hidden_size, out_channels, dtype=None, device=None, operations=None):
+ super().__init__()
+ self.norm = RMSNorm(hidden_size, dtype=dtype, device=device, operations=operations)
+ self.linear = operations.Linear(hidden_size, out_channels, dtype=dtype, device=device)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ # RMSNorm normalizes over the last dimension, but our channel dim (C) is at dim=1.
+ # So we temporarily move the channel dimension to the end for the norm operation.
+ return self.linear(self.norm(x.movedim(1, -1))).movedim(-1, 1)
+
+
+class NerfFinalLayerConv(nn.Module):
+ def __init__(self, hidden_size: int, out_channels: int, dtype=None, device=None, operations=None):
+ super().__init__()
+ self.norm = RMSNorm(hidden_size, dtype=dtype, device=device, operations=operations)
+ self.conv = operations.Conv2d(
+ in_channels=hidden_size,
+ out_channels=out_channels,
+ kernel_size=3,
+ padding=1,
+ dtype=dtype,
+ device=device,
+ )
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ # RMSNorm normalizes over the last dimension, but our channel dim (C) is at dim=1.
+ # So we temporarily move the channel dimension to the end for the norm operation.
+ return self.conv(self.norm(x.movedim(1, -1)).movedim(-1, 1))
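+
+
+ # Usage sketch for the final heads (illustrative; `operations` as in the NerfEmbedder
+ # sketch above):
+ #
+ #   head = NerfFinalLayerConv(hidden_size=64, out_channels=3, dtype=torch.float32,
+ #                             operations=comfy.ops.disable_weight_init)
+ #   feats = torch.randn(1, 64, 32, 32)   # (B, hidden, H, W)
+ #   rgb = head(feats)                    # -> (1, 3, 32, 32)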
diff --git a/comfy/ldm/chroma_radiance/model.py b/comfy/ldm/chroma_radiance/model.py
new file mode 100644
index 000000000..47aa11b04
--- /dev/null
+++ b/comfy/ldm/chroma_radiance/model.py
@@ -0,0 +1,329 @@
+# Credits:
+# Original Flux code can be found on: https://github.com/black-forest-labs/flux
+# Chroma Radiance adaptation referenced from https://github.com/lodestone-rock/flow
+
+from dataclasses import dataclass
+from typing import Optional
+
+import torch
+from torch import Tensor, nn
+from einops import repeat
+import comfy.ldm.common_dit
+
+from comfy.ldm.flux.layers import EmbedND
+
+from comfy.ldm.chroma.model import Chroma, ChromaParams
+from comfy.ldm.chroma.layers import (
+ DoubleStreamBlock,
+ SingleStreamBlock,
+ Approximator,
+)
+from .layers import (
+ NerfEmbedder,
+ NerfGLUBlock,
+ NerfFinalLayer,
+ NerfFinalLayerConv,
+)
+
+
+@dataclass
+class ChromaRadianceParams(ChromaParams):
+ patch_size: int
+ nerf_hidden_size: int
+ nerf_mlp_ratio: int
+ nerf_depth: int
+ nerf_max_freqs: int
+ # Setting nerf_tile_size to 0 disables tiling.
+ nerf_tile_size: int
+ # Currently one of linear (legacy) or conv.
+ nerf_final_head_type: str
+ # None means use the same dtype as the model.
+ nerf_embedder_dtype: Optional[torch.dtype]
+
+
+class ChromaRadiance(Chroma):
+ """
+ Transformer model for flow matching on sequences.
+ """
+
+ def __init__(self, image_model=None, final_layer=True, dtype=None, device=None, operations=None, **kwargs):
+ if operations is None:
+ raise RuntimeError("Attempt to create ChromaRadiance object without setting operations")
+ nn.Module.__init__(self)
+ self.dtype = dtype
+ params = ChromaRadianceParams(**kwargs)
+ self.params = params
+ self.patch_size = params.patch_size
+ self.in_channels = params.in_channels
+ self.out_channels = params.out_channels
+ if params.hidden_size % params.num_heads != 0:
+ raise ValueError(
+ f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
+ )
+ pe_dim = params.hidden_size // params.num_heads
+ if sum(params.axes_dim) != pe_dim:
+ raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
+ self.hidden_size = params.hidden_size
+ self.num_heads = params.num_heads
+ self.in_dim = params.in_dim
+ self.out_dim = params.out_dim
+ self.hidden_dim = params.hidden_dim
+ self.n_layers = params.n_layers
+ self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
+ self.img_in_patch = operations.Conv2d(
+ params.in_channels,
+ params.hidden_size,
+ kernel_size=params.patch_size,
+ stride=params.patch_size,
+ bias=True,
+ dtype=dtype,
+ device=device,
+ )
+ self.txt_in = operations.Linear(params.context_in_dim, self.hidden_size, dtype=dtype, device=device)
+ # distilled guidance approximator, mirroring the parent Chroma model.
+ self.distilled_guidance_layer = Approximator(
+ in_dim=self.in_dim,
+ hidden_dim=self.hidden_dim,
+ out_dim=self.out_dim,
+ n_layers=self.n_layers,
+ dtype=dtype, device=device, operations=operations
+ )
+
+
+ self.double_blocks = nn.ModuleList(
+ [
+ DoubleStreamBlock(
+ self.hidden_size,
+ self.num_heads,
+ mlp_ratio=params.mlp_ratio,
+ qkv_bias=params.qkv_bias,
+ dtype=dtype, device=device, operations=operations
+ )
+ for _ in range(params.depth)
+ ]
+ )
+
+ self.single_blocks = nn.ModuleList(
+ [
+ SingleStreamBlock(
+ self.hidden_size,
+ self.num_heads,
+ mlp_ratio=params.mlp_ratio,
+ dtype=dtype, device=device, operations=operations,
+ )
+ for _ in range(params.depth_single_blocks)
+ ]
+ )
+
+ # pixel channel concat with DCT
+ self.nerf_image_embedder = NerfEmbedder(
+ in_channels=params.in_channels,
+ hidden_size_input=params.nerf_hidden_size,
+ max_freqs=params.nerf_max_freqs,
+ dtype=params.nerf_embedder_dtype or dtype,
+ device=device,
+ operations=operations,
+ )
+
+ self.nerf_blocks = nn.ModuleList([
+ NerfGLUBlock(
+ hidden_size_s=params.hidden_size,
+ hidden_size_x=params.nerf_hidden_size,
+ mlp_ratio=params.nerf_mlp_ratio,
+ dtype=dtype,
+ device=device,
+ operations=operations,
+ ) for _ in range(params.nerf_depth)
+ ])
+
+ if params.nerf_final_head_type == "linear":
+ self.nerf_final_layer = NerfFinalLayer(
+ params.nerf_hidden_size,
+ out_channels=params.in_channels,
+ dtype=dtype,
+ device=device,
+ operations=operations,
+ )
+ elif params.nerf_final_head_type == "conv":
+ self.nerf_final_layer_conv = NerfFinalLayerConv(
+ params.nerf_hidden_size,
+ out_channels=params.in_channels,
+ dtype=dtype,
+ device=device,
+ operations=operations,
+ )
+ else:
+ errstr = f"Unsupported nerf_final_head_type {params.nerf_final_head_type}"
+ raise ValueError(errstr)
+
+ self.skip_mmdit = []
+ self.skip_dit = []
+ self.lite = False
+
+ @property
+ def _nerf_final_layer(self) -> nn.Module:
+ if self.params.nerf_final_head_type == "linear":
+ return self.nerf_final_layer
+ if self.params.nerf_final_head_type == "conv":
+ return self.nerf_final_layer_conv
+ # Unreachable: an unsupported nerf_final_head_type already raised ValueError at initialization.
+ raise NotImplementedError
+
+ def img_in(self, img: Tensor) -> Tensor:
+ img = self.img_in_patch(img) # -> [B, Hidden, H/P, W/P]
+ # flatten into a sequence for the transformer.
+ return img.flatten(2).transpose(1, 2) # -> [B, NumPatches, Hidden]
+
+ def forward_nerf(
+ self,
+ img_orig: Tensor,
+ img_out: Tensor,
+ params: ChromaRadianceParams,
+ ) -> Tensor:
+ B, C, H, W = img_orig.shape
+ num_patches = img_out.shape[1]
+ patch_size = params.patch_size
+
+ # Store the raw pixel values of each patch for the NeRF head later.
+ # unfold creates patches: [B, C * P * P, NumPatches]
+ nerf_pixels = nn.functional.unfold(img_orig, kernel_size=patch_size, stride=patch_size)
+ nerf_pixels = nerf_pixels.transpose(1, 2) # -> [B, NumPatches, C * P * P]
+
+ if params.nerf_tile_size > 0 and num_patches > params.nerf_tile_size:
+ # Enable tiling if nerf_tile_size isn't 0 and we actually have more patches than
+ # the tile size.
+ img_dct = self.forward_tiled_nerf(img_out, nerf_pixels, B, C, num_patches, patch_size, params)
+ else:
+ # Reshape for per-patch processing
+ nerf_hidden = img_out.reshape(B * num_patches, params.hidden_size)
+ nerf_pixels = nerf_pixels.reshape(B * num_patches, C, patch_size**2).transpose(1, 2)
+
+ # Get DCT-encoded pixel embeddings [pixel-dct]
+ img_dct = self.nerf_image_embedder(nerf_pixels)
+
+ # Pass through the dynamic MLP blocks (the NeRF)
+ for block in self.nerf_blocks:
+ img_dct = block(img_dct, nerf_hidden)
+
+ # Reassemble the patches into the final image.
+ img_dct = img_dct.transpose(1, 2) # -> [B*NumPatches, C, P*P]
+ # Reshape to combine with batch dimension for fold
+ img_dct = img_dct.reshape(B, num_patches, -1) # -> [B, NumPatches, C*P*P]
+ img_dct = img_dct.transpose(1, 2) # -> [B, C*P*P, NumPatches]
+ img_dct = nn.functional.fold(
+ img_dct,
+ output_size=(H, W),
+ kernel_size=patch_size,
+ stride=patch_size,
+ )
+ return self._nerf_final_layer(img_dct)
+
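+ # unfold/fold round trip used above (illustrative): with non-overlapping patches,
+ # fold(unfold(x)) reproduces x exactly because every pixel is covered exactly once.
+ #
+ #   x = torch.randn(1, 3, 8, 8)
+ #   cols = torch.nn.functional.unfold(x, kernel_size=4, stride=4)  # (1, 3*4*4, 4)
+ #   y = torch.nn.functional.fold(cols, output_size=(8, 8), kernel_size=4, stride=4)
+ #   assert torch.equal(x, y)
+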
+ def forward_tiled_nerf(
+ self,
+ nerf_hidden: Tensor,
+ nerf_pixels: Tensor,
+ batch: int,
+ channels: int,
+ num_patches: int,
+ patch_size: int,
+ params: ChromaRadianceParams,
+ ) -> Tensor:
+ """
+ Processes the NeRF head in tiles to save memory.
+ nerf_hidden has shape [B, L, D]
+ nerf_pixels has shape [B, L, C * P * P]
+ """
+ tile_size = params.nerf_tile_size
+ output_tiles = []
+ # Iterate over the patches in tiles. The dimension L (num_patches) is at index 1.
+ for i in range(0, num_patches, tile_size):
+ end = min(i + tile_size, num_patches)
+
+ # Slice the current tile from the input tensors
+ nerf_hidden_tile = nerf_hidden[:, i:end, :]
+ nerf_pixels_tile = nerf_pixels[:, i:end, :]
+
+ # Get the actual number of patches in this tile (can be smaller for the last tile)
+ num_patches_tile = nerf_hidden_tile.shape[1]
+
+ # Reshape the tile for per-patch processing
+ # [B, NumPatches_tile, D] -> [B * NumPatches_tile, D]
+ nerf_hidden_tile = nerf_hidden_tile.reshape(batch * num_patches_tile, params.hidden_size)
+ # [B, NumPatches_tile, C*P*P] -> [B*NumPatches_tile, C, P*P] -> [B*NumPatches_tile, P*P, C]
+ nerf_pixels_tile = nerf_pixels_tile.reshape(batch * num_patches_tile, channels, patch_size**2).transpose(1, 2)
+
+ # get DCT-encoded pixel embeddings [pixel-dct]
+ img_dct_tile = self.nerf_image_embedder(nerf_pixels_tile)
+
+ # pass through the dynamic MLP blocks (the NeRF)
+ for block in self.nerf_blocks:
+ img_dct_tile = block(img_dct_tile, nerf_hidden_tile)
+
+ output_tiles.append(img_dct_tile)
+
+ # Concatenate the processed tiles along the flattened batch*patch dimension (dim 0)
+ return torch.cat(output_tiles, dim=0)
+
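+ # Memory note (rough, illustrative numbers): the dominant intermediate inside the NeRF
+ # blocks has about N * P*P * D*mlp_ratio elements, where N is the number of patches fed
+ # through at once. Tiling caps N at nerf_tile_size; e.g. with tile_size=128, P=16 and
+ # D*mlp_ratio=256 that is 128 * 256 * 256 ≈ 8.4M values per tile instead of the full
+ # num_patches worth in one shot.
+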
+ def radiance_get_override_params(self, overrides: dict) -> ChromaRadianceParams:
+ params = self.params
+ if not overrides:
+ return params
+ params_dict = {k: getattr(params, k) for k in params.__dataclass_fields__}
+ nullable_keys = frozenset(("nerf_embedder_dtype",))
+ bad_keys = tuple(k for k in overrides if k not in params_dict)
+ if bad_keys:
+ e = f"Unknown key(s) in transformer_options chroma_radiance_options: {', '.join(bad_keys)}"
+ raise ValueError(e)
+ bad_keys = tuple(
+ k
+ for k, v in overrides.items()
+ if type(v) != type(getattr(params, k)) and (v is not None or k not in nullable_keys)
+ )
+ if bad_keys:
+ e = f"Invalid value(s) in transformer_options chroma_radiance_options: {', '.join(bad_keys)}"
+ raise ValueError(e)
+ # At this point it's all valid keys and values so we can merge with the existing params.
+ params_dict |= overrides
+ return params.__class__(**params_dict)
+
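+ # Override sketch (hypothetical values): callers can adjust the radiance head per run by
+ # putting a partial dict under transformer_options, e.g.
+ #
+ #   transformer_options["chroma_radiance_options"] = {
+ #       "nerf_tile_size": 64,         # must match the dataclass field's type (int)
+ #       "nerf_embedder_dtype": None,  # one of the nullable keys, so None is accepted
+ #   }
+ #
+ # Unknown keys or mismatched types raise ValueError, as implemented above.
+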
+ def _forward(
+ self,
+ x: Tensor,
+ timestep: Tensor,
+ context: Tensor,
+ guidance: Optional[Tensor],
+ control: Optional[dict]=None,
+ transformer_options: dict={},
+ **kwargs: dict,
+ ) -> Tensor:
+ bs, c, h, w = x.shape
+ img = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size))
+
+ if img.ndim != 4:
+ raise ValueError("Input img tensor must be in [B, C, H, W] format.")
+ if context.ndim != 3:
+ raise ValueError("Input txt tensors must have 3 dimensions.")
+
+ params = self.radiance_get_override_params(transformer_options.get("chroma_radiance_options", {}))
+
+ h_len = (img.shape[-2] // self.patch_size)
+ w_len = (img.shape[-1] // self.patch_size)
+
+ img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
+ img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1)
+ img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
+ img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
+ txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
+
+ img_out = self.forward_orig(
+ img,
+ img_ids,
+ context,
+ txt_ids,
+ timestep,
+ guidance,
+ control,
+ transformer_options,
+ attn_mask=kwargs.get("attention_mask", None),
+ )
+ return self.forward_nerf(img, img_out, params)[:, :, :h, :w]
diff --git a/comfy/ldm/cosmos/blocks.py b/comfy/ldm/cosmos/blocks.py
index 5c4356a3f..afb43d469 100644
--- a/comfy/ldm/cosmos/blocks.py
+++ b/comfy/ldm/cosmos/blocks.py
@@ -176,6 +176,7 @@ class Attention(nn.Module):
context=None,
mask=None,
rope_emb=None,
+ transformer_options={},
**kwargs,
):
"""
@@ -184,7 +185,7 @@ class Attention(nn.Module):
context (Optional[Tensor]): The key tensor of shape [B, Mk, K] or use x as context [self attention] if None
"""
q, k, v = self.cal_qkv(x, context, mask, rope_emb=rope_emb, **kwargs)
- out = optimized_attention(q, k, v, self.heads, skip_reshape=True, mask=mask, skip_output_reshape=True)
+ out = optimized_attention(q, k, v, self.heads, skip_reshape=True, mask=mask, skip_output_reshape=True, transformer_options=transformer_options)
del q, k, v
out = rearrange(out, " b n s c -> s b (n c)")
return self.to_out(out)
@@ -546,6 +547,7 @@ class VideoAttn(nn.Module):
context: Optional[torch.Tensor] = None,
crossattn_mask: Optional[torch.Tensor] = None,
rope_emb_L_1_1_D: Optional[torch.Tensor] = None,
+ transformer_options: Optional[dict] = {},
) -> torch.Tensor:
"""
Forward pass for video attention.
@@ -571,6 +573,7 @@ class VideoAttn(nn.Module):
context_M_B_D,
crossattn_mask,
rope_emb=rope_emb_L_1_1_D,
+ transformer_options=transformer_options,
)
x_T_H_W_B_D = rearrange(x_THW_B_D, "(t h w) b d -> t h w b d", h=H, w=W)
return x_T_H_W_B_D
@@ -665,6 +668,7 @@ class DITBuildingBlock(nn.Module):
crossattn_mask: Optional[torch.Tensor] = None,
rope_emb_L_1_1_D: Optional[torch.Tensor] = None,
adaln_lora_B_3D: Optional[torch.Tensor] = None,
+ transformer_options: Optional[dict] = {},
) -> torch.Tensor:
"""
Forward pass for dynamically configured blocks with adaptive normalization.
@@ -702,6 +706,7 @@ class DITBuildingBlock(nn.Module):
adaln_norm_state(self.norm_state, x, scale_1_1_1_B_D, shift_1_1_1_B_D),
context=None,
rope_emb_L_1_1_D=rope_emb_L_1_1_D,
+ transformer_options=transformer_options,
)
elif self.block_type in ["cross_attn", "ca"]:
x = x + gate_1_1_1_B_D * self.block(
@@ -709,6 +714,7 @@ class DITBuildingBlock(nn.Module):
context=crossattn_emb,
crossattn_mask=crossattn_mask,
rope_emb_L_1_1_D=rope_emb_L_1_1_D,
+ transformer_options=transformer_options,
)
else:
raise ValueError(f"Unknown block type: {self.block_type}")
@@ -784,6 +790,7 @@ class GeneralDITTransformerBlock(nn.Module):
crossattn_mask: Optional[torch.Tensor] = None,
rope_emb_L_1_1_D: Optional[torch.Tensor] = None,
adaln_lora_B_3D: Optional[torch.Tensor] = None,
+ transformer_options: Optional[dict] = {},
) -> torch.Tensor:
for block in self.blocks:
x = block(
@@ -793,5 +800,6 @@ class GeneralDITTransformerBlock(nn.Module):
crossattn_mask,
rope_emb_L_1_1_D=rope_emb_L_1_1_D,
adaln_lora_B_3D=adaln_lora_B_3D,
+ transformer_options=transformer_options,
)
return x
diff --git a/comfy/ldm/cosmos/model.py b/comfy/ldm/cosmos/model.py
index 4836e0b69..52ef7ef43 100644
--- a/comfy/ldm/cosmos/model.py
+++ b/comfy/ldm/cosmos/model.py
@@ -27,6 +27,8 @@ from torchvision import transforms
from enum import Enum
import logging
+import comfy.patcher_extension
+
from .blocks import (
FinalLayer,
GeneralDITTransformerBlock,
@@ -435,6 +437,42 @@ class GeneralDIT(nn.Module):
latent_condition_sigma: Optional[torch.Tensor] = None,
condition_video_augment_sigma: Optional[torch.Tensor] = None,
**kwargs,
+ ):
+ return comfy.patcher_extension.WrapperExecutor.new_class_executor(
+ self._forward,
+ self,
+ comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, kwargs.get("transformer_options", {}))
+ ).execute(x,
+ timesteps,
+ context,
+ attention_mask,
+ fps,
+ image_size,
+ padding_mask,
+ scalar_feature,
+ data_type,
+ latent_condition,
+ latent_condition_sigma,
+ condition_video_augment_sigma,
+ **kwargs)
+
+ def _forward(
+ self,
+ x: torch.Tensor,
+ timesteps: torch.Tensor,
+ context: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ # crossattn_emb: torch.Tensor,
+ # crossattn_mask: Optional[torch.Tensor] = None,
+ fps: Optional[torch.Tensor] = None,
+ image_size: Optional[torch.Tensor] = None,
+ padding_mask: Optional[torch.Tensor] = None,
+ scalar_feature: Optional[torch.Tensor] = None,
+ data_type: Optional[DataType] = DataType.VIDEO,
+ latent_condition: Optional[torch.Tensor] = None,
+ latent_condition_sigma: Optional[torch.Tensor] = None,
+ condition_video_augment_sigma: Optional[torch.Tensor] = None,
+ **kwargs,
):
"""
Args:
@@ -482,6 +520,7 @@ class GeneralDIT(nn.Module):
x.shape == extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.shape
), f"{x.shape} != {extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.shape} {original_shape}"
+ transformer_options = kwargs.get("transformer_options", {})
for _, block in self.blocks.items():
assert (
self.blocks["block0"].x_format == block.x_format
@@ -496,6 +535,7 @@ class GeneralDIT(nn.Module):
crossattn_mask,
rope_emb_L_1_1_D=rope_emb_L_1_1_D,
adaln_lora_B_3D=adaln_lora_B_3D,
+ transformer_options=transformer_options,
)
x_B_T_H_W_D = rearrange(x, "T H W B D -> B T H W D")
diff --git a/comfy/ldm/cosmos/predict2.py b/comfy/ldm/cosmos/predict2.py
index 316117f77..07a4fc79f 100644
--- a/comfy/ldm/cosmos/predict2.py
+++ b/comfy/ldm/cosmos/predict2.py
@@ -11,6 +11,7 @@ import math
from .position_embedding import VideoRopePosition3DEmb, LearnablePosEmbAxis
from torchvision import transforms
+import comfy.patcher_extension
from comfy.ldm.modules.attention import optimized_attention
def apply_rotary_pos_emb(
@@ -43,7 +44,7 @@ class GPT2FeedForward(nn.Module):
return x
-def torch_attention_op(q_B_S_H_D: torch.Tensor, k_B_S_H_D: torch.Tensor, v_B_S_H_D: torch.Tensor) -> torch.Tensor:
+def torch_attention_op(q_B_S_H_D: torch.Tensor, k_B_S_H_D: torch.Tensor, v_B_S_H_D: torch.Tensor, transformer_options: Optional[dict] = {}) -> torch.Tensor:
"""Computes multi-head attention using PyTorch's native implementation.
This function provides a PyTorch backend alternative to Transformer Engine's attention operation.
@@ -70,7 +71,7 @@ def torch_attention_op(q_B_S_H_D: torch.Tensor, k_B_S_H_D: torch.Tensor, v_B_S_H
q_B_H_S_D = rearrange(q_B_S_H_D, "b ... h k -> b h ... k").view(in_q_shape[0], in_q_shape[-2], -1, in_q_shape[-1])
k_B_H_S_D = rearrange(k_B_S_H_D, "b ... h v -> b h ... v").view(in_k_shape[0], in_k_shape[-2], -1, in_k_shape[-1])
v_B_H_S_D = rearrange(v_B_S_H_D, "b ... h v -> b h ... v").view(in_k_shape[0], in_k_shape[-2], -1, in_k_shape[-1])
- return optimized_attention(q_B_H_S_D, k_B_H_S_D, v_B_H_S_D, in_q_shape[-2], skip_reshape=True)
+ return optimized_attention(q_B_H_S_D, k_B_H_S_D, v_B_H_S_D, in_q_shape[-2], skip_reshape=True, transformer_options=transformer_options)
class Attention(nn.Module):
@@ -179,8 +180,8 @@ class Attention(nn.Module):
return q, k, v
- def compute_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
- result = self.attn_op(q, k, v) # [B, S, H, D]
+ def compute_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, transformer_options: Optional[dict] = {}) -> torch.Tensor:
+ result = self.attn_op(q, k, v, transformer_options=transformer_options) # [B, S, H, D]
return self.output_dropout(self.output_proj(result))
def forward(
@@ -188,6 +189,7 @@ class Attention(nn.Module):
x: torch.Tensor,
context: Optional[torch.Tensor] = None,
rope_emb: Optional[torch.Tensor] = None,
+ transformer_options: Optional[dict] = {},
) -> torch.Tensor:
"""
Args:
@@ -195,7 +197,7 @@ class Attention(nn.Module):
context (Optional[Tensor]): The key tensor of shape [B, Mk, K] or use x as context [self attention] if None
"""
q, k, v = self.compute_qkv(x, context, rope_emb=rope_emb)
- return self.compute_attention(q, k, v)
+ return self.compute_attention(q, k, v, transformer_options=transformer_options)
class Timesteps(nn.Module):
@@ -458,6 +460,7 @@ class Block(nn.Module):
rope_emb_L_1_1_D: Optional[torch.Tensor] = None,
adaln_lora_B_T_3D: Optional[torch.Tensor] = None,
extra_per_block_pos_emb: Optional[torch.Tensor] = None,
+ transformer_options: Optional[dict] = {},
) -> torch.Tensor:
if extra_per_block_pos_emb is not None:
x_B_T_H_W_D = x_B_T_H_W_D + extra_per_block_pos_emb
@@ -511,6 +514,7 @@ class Block(nn.Module):
rearrange(normalized_x_B_T_H_W_D, "b t h w d -> b (t h w) d"),
None,
rope_emb=rope_emb_L_1_1_D,
+ transformer_options=transformer_options,
),
"b (t h w) d -> b t h w d",
t=T,
@@ -524,6 +528,7 @@ class Block(nn.Module):
layer_norm_cross_attn: Callable,
_scale_cross_attn_B_T_1_1_D: torch.Tensor,
_shift_cross_attn_B_T_1_1_D: torch.Tensor,
+ transformer_options: Optional[dict] = {},
) -> torch.Tensor:
_normalized_x_B_T_H_W_D = _fn(
_x_B_T_H_W_D, layer_norm_cross_attn, _scale_cross_attn_B_T_1_1_D, _shift_cross_attn_B_T_1_1_D
@@ -533,6 +538,7 @@ class Block(nn.Module):
rearrange(_normalized_x_B_T_H_W_D, "b t h w d -> b (t h w) d"),
crossattn_emb,
rope_emb=rope_emb_L_1_1_D,
+ transformer_options=transformer_options,
),
"b (t h w) d -> b t h w d",
t=T,
@@ -546,6 +552,7 @@ class Block(nn.Module):
self.layer_norm_cross_attn,
scale_cross_attn_B_T_1_1_D,
shift_cross_attn_B_T_1_1_D,
+ transformer_options=transformer_options,
)
x_B_T_H_W_D = result_B_T_H_W_D * gate_cross_attn_B_T_1_1_D + x_B_T_H_W_D
@@ -805,7 +812,21 @@ class MiniTrainDIT(nn.Module):
)
return x_B_C_Tt_Hp_Wp
- def forward(
+ def forward(self,
+ x: torch.Tensor,
+ timesteps: torch.Tensor,
+ context: torch.Tensor,
+ fps: Optional[torch.Tensor] = None,
+ padding_mask: Optional[torch.Tensor] = None,
+ **kwargs,
+ ):
+ return comfy.patcher_extension.WrapperExecutor.new_class_executor(
+ self._forward,
+ self,
+ comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, kwargs.get("transformer_options", {}))
+ ).execute(x, timesteps, context, fps, padding_mask, **kwargs)
+
+ def _forward(
self,
x: torch.Tensor,
timesteps: torch.Tensor,
@@ -850,6 +871,7 @@ class MiniTrainDIT(nn.Module):
"rope_emb_L_1_1_D": rope_emb_L_1_1_D.unsqueeze(1).unsqueeze(0),
"adaln_lora_B_T_3D": adaln_lora_B_T_3D,
"extra_per_block_pos_emb": extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D,
+ "transformer_options": kwargs.get("transformer_options", {}),
}
for block in self.blocks:
x_B_T_H_W_D = block(
diff --git a/comfy/ldm/flux/layers.py b/comfy/ldm/flux/layers.py
index 113eb2096..ef21b416b 100644
--- a/comfy/ldm/flux/layers.py
+++ b/comfy/ldm/flux/layers.py
@@ -159,7 +159,7 @@ class DoubleStreamBlock(nn.Module):
)
self.flipped_img_txt = flipped_img_txt
- def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims_img=None, modulation_dims_txt=None):
+ def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims_img=None, modulation_dims_txt=None, transformer_options={}):
img_mod1, img_mod2 = self.img_mod(vec)
txt_mod1, txt_mod2 = self.txt_mod(vec)
@@ -182,7 +182,7 @@ class DoubleStreamBlock(nn.Module):
attn = attention(torch.cat((img_q, txt_q), dim=2),
torch.cat((img_k, txt_k), dim=2),
torch.cat((img_v, txt_v), dim=2),
- pe=pe, mask=attn_mask)
+ pe=pe, mask=attn_mask, transformer_options=transformer_options)
img_attn, txt_attn = attn[:, : img.shape[1]], attn[:, img.shape[1]:]
else:
@@ -190,7 +190,7 @@ class DoubleStreamBlock(nn.Module):
attn = attention(torch.cat((txt_q, img_q), dim=2),
torch.cat((txt_k, img_k), dim=2),
torch.cat((txt_v, img_v), dim=2),
- pe=pe, mask=attn_mask)
+ pe=pe, mask=attn_mask, transformer_options=transformer_options)
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:]
@@ -244,7 +244,7 @@ class SingleStreamBlock(nn.Module):
self.mlp_act = nn.GELU(approximate="tanh")
self.modulation = Modulation(hidden_size, double=False, dtype=dtype, device=device, operations=operations)
- def forward(self, x: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims=None) -> Tensor:
+ def forward(self, x: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims=None, transformer_options={}) -> Tensor:
mod, _ = self.modulation(vec)
qkv, mlp = torch.split(self.linear1(apply_mod(self.pre_norm(x), (1 + mod.scale), mod.shift, modulation_dims)), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
@@ -252,7 +252,7 @@ class SingleStreamBlock(nn.Module):
q, k = self.norm(q, k, v)
# compute attention
- attn = attention(q, k, v, pe=pe, mask=attn_mask)
+ attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options)
# compute activation in mlp stream, cat again and run second linear layer
output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
x += apply_mod(output, mod.gate, None, modulation_dims)
diff --git a/comfy/ldm/flux/math.py b/comfy/ldm/flux/math.py
index 3e0978176..fb7cd7586 100644
--- a/comfy/ldm/flux/math.py
+++ b/comfy/ldm/flux/math.py
@@ -6,7 +6,7 @@ from comfy.ldm.modules.attention import optimized_attention
import comfy.model_management
-def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None) -> Tensor:
+def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None, transformer_options={}) -> Tensor:
q_shape = q.shape
k_shape = k.shape
@@ -17,7 +17,7 @@ def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None) -> Tensor:
k = (pe[..., 0] * k[..., 0] + pe[..., 1] * k[..., 1]).reshape(*k_shape).type_as(v)
heads = q.shape[1]
- x = optimized_attention(q, k, v, heads, skip_reshape=True, mask=mask)
+ x = optimized_attention(q, k, v, heads, skip_reshape=True, mask=mask, transformer_options=transformer_options)
return x
@@ -35,11 +35,10 @@ def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
return out.to(dtype=torch.float32, device=pos.device)
+def apply_rope1(x: Tensor, freqs_cis: Tensor):
+ x_ = x.to(dtype=freqs_cis.dtype).reshape(*x.shape[:-1], -1, 1, 2)
+ x_out = freqs_cis[..., 0] * x_[..., 0] + freqs_cis[..., 1] * x_[..., 1]
+ return x_out.reshape(*x.shape).type_as(x)
def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor):
- xq_ = xq.to(dtype=freqs_cis.dtype).reshape(*xq.shape[:-1], -1, 1, 2)
- xk_ = xk.to(dtype=freqs_cis.dtype).reshape(*xk.shape[:-1], -1, 1, 2)
- xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
- xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
- return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
-
+ return apply_rope1(xq, freqs_cis), apply_rope1(xk, freqs_cis)
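+
+# Written out per channel pair (x1, x2) at angle t, apply_rope1 computes
+# (cos(t) * x1 - sin(t) * x2, sin(t) * x1 + cos(t) * x2), the standard RoPE rotation:
+# freqs_cis holds the [[cos, -sin], [sin, cos]] matrices built by rope() above, and the
+# two terms select their columns.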
diff --git a/comfy/ldm/flux/model.py b/comfy/ldm/flux/model.py
index c4de82795..14f90cea5 100644
--- a/comfy/ldm/flux/model.py
+++ b/comfy/ldm/flux/model.py
@@ -6,6 +6,7 @@ import torch
from torch import Tensor, nn
from einops import rearrange, repeat
import comfy.ldm.common_dit
+import comfy.patcher_extension
from .layers import (
DoubleStreamBlock,
@@ -105,6 +106,7 @@ class Flux(nn.Module):
if y is None:
y = torch.zeros((img.shape[0], self.params.vec_in_dim), device=img.device, dtype=img.dtype)
+ patches = transformer_options.get("patches", {})
patches_replace = transformer_options.get("patches_replace", {})
if img.ndim != 3 or txt.ndim != 3:
raise ValueError("Input img and txt tensors must have 3 dimensions.")
@@ -116,9 +118,17 @@ class Flux(nn.Module):
if guidance is not None:
vec = vec + self.guidance_in(timestep_embedding(guidance, 256).to(img.dtype))
- vec = vec + self.vector_in(y[:,:self.params.vec_in_dim])
+ vec = vec + self.vector_in(y[:, :self.params.vec_in_dim])
txt = self.txt_in(txt)
+ if "post_input" in patches:
+ for p in patches["post_input"]:
+ out = p({"img": img, "txt": txt, "img_ids": img_ids, "txt_ids": txt_ids})
+ img = out["img"]
+ txt = out["txt"]
+ img_ids = out["img_ids"]
+ txt_ids = out["txt_ids"]
+
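+ # "post_input" is a new patch point: each callable registered under
+ # transformer_options["patches"]["post_input"] receives the projected img/txt tokens and
+ # their ids as a dict and must return a dict with the same four keys. A hypothetical
+ # patch that appends extra reference tokens (extra_tokens/extra_ids are placeholders):
+ #
+ #   def my_post_input_patch(d):
+ #       d["img"] = torch.cat((d["img"], extra_tokens), dim=1)
+ #       d["img_ids"] = torch.cat((d["img_ids"], extra_ids), dim=1)
+ #       return d
+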
if img_ids is not None:
ids = torch.cat((txt_ids, img_ids), dim=1)
pe = self.pe_embedder(ids)
@@ -134,14 +144,16 @@ class Flux(nn.Module):
txt=args["txt"],
vec=args["vec"],
pe=args["pe"],
- attn_mask=args.get("attn_mask"))
+ attn_mask=args.get("attn_mask"),
+ transformer_options=args.get("transformer_options"))
return out
out = blocks_replace[("double_block", i)]({"img": img,
"txt": txt,
"vec": vec,
"pe": pe,
- "attn_mask": attn_mask},
+ "attn_mask": attn_mask,
+ "transformer_options": transformer_options},
{"original_block": block_wrap})
txt = out["txt"]
img = out["img"]
@@ -150,14 +162,15 @@ class Flux(nn.Module):
txt=txt,
vec=vec,
pe=pe,
- attn_mask=attn_mask)
+ attn_mask=attn_mask,
+ transformer_options=transformer_options)
if control is not None: # Controlnet
control_i = control.get("input")
if i < len(control_i):
add = control_i[i]
if add is not None:
- img += add
+ img[:, :add.shape[1]] += add
if img.dtype == torch.float16:
img = torch.nan_to_num(img, nan=0.0, posinf=65504, neginf=-65504)
@@ -171,24 +184,26 @@ class Flux(nn.Module):
out["img"] = block(args["img"],
vec=args["vec"],
pe=args["pe"],
- attn_mask=args.get("attn_mask"))
+ attn_mask=args.get("attn_mask"),
+ transformer_options=args.get("transformer_options"))
return out
out = blocks_replace[("single_block", i)]({"img": img,
"vec": vec,
"pe": pe,
- "attn_mask": attn_mask},
+ "attn_mask": attn_mask,
+ "transformer_options": transformer_options},
{"original_block": block_wrap})
img = out["img"]
else:
- img = block(img, vec=vec, pe=pe, attn_mask=attn_mask)
+ img = block(img, vec=vec, pe=pe, attn_mask=attn_mask, transformer_options=transformer_options)
if control is not None: # Controlnet
control_o = control.get("output")
if i < len(control_o):
add = control_o[i]
if add is not None:
- img[:, txt.shape[1] :, ...] += add
+ img[:, txt.shape[1] : txt.shape[1] + add.shape[1], ...] += add
img = img[:, txt.shape[1] :, ...]
@@ -214,6 +229,13 @@ class Flux(nn.Module):
return img, repeat(img_ids, "h w c -> b (h w) c", b=bs)
def forward(self, x, timestep, context, y=None, guidance=None, ref_latents=None, control=None, transformer_options={}, **kwargs):
+ return comfy.patcher_extension.WrapperExecutor.new_class_executor(
+ self._forward,
+ self,
+ comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
+ ).execute(x, timestep, context, y, guidance, ref_latents, control, transformer_options, **kwargs)
+
+ def _forward(self, x, timestep, context, y=None, guidance=None, ref_latents=None, control=None, transformer_options={}, **kwargs):
bs, c, h_orig, w_orig = x.shape
patch_size = self.patch_size
@@ -225,12 +247,18 @@ class Flux(nn.Module):
h = 0
w = 0
index = 0
- index_ref_method = kwargs.get("ref_latents_method", "offset") == "index"
+ ref_latents_method = kwargs.get("ref_latents_method", "offset")
for ref in ref_latents:
- if index_ref_method:
+ if ref_latents_method == "index":
index += 1
h_offset = 0
w_offset = 0
+ elif ref_latents_method == "uxo":
+ index = 0
+ h_offset = h_len * patch_size + h
+ w_offset = w_len * patch_size + w
+ h += ref.shape[-2]
+ w += ref.shape[-1]
else:
index = 1
h_offset = 0
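Reviewer note on the new "post_input" hook added earlier in this file: each patch is called with a dict of the four token streams and must return the same keys. A hedged sketch of the callable contract; the function name and the exact registration path into transformer_options["patches"] are illustrative assumptions, only the dict keys come from this diff.

```python
# hypothetical patch: receives and returns the four token streams by key
def my_post_input_patch(kwargs):
    img, txt = kwargs["img"], kwargs["txt"]
    img_ids, txt_ids = kwargs["img_ids"], kwargs["txt_ids"]
    # a real patch could append or reorder tokens here; this sketch is a no-op
    return {"img": img, "txt": txt, "img_ids": img_ids, "txt_ids": txt_ids}

# assumed registration shape, matching transformer_options.get("patches", {})
transformer_options = {"patches": {"post_input": [my_post_input_patch]}}
```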
diff --git a/comfy/ldm/genmo/joint_model/asymm_models_joint.py b/comfy/ldm/genmo/joint_model/asymm_models_joint.py
index 366a8b713..5c1bb4d42 100644
--- a/comfy/ldm/genmo/joint_model/asymm_models_joint.py
+++ b/comfy/ldm/genmo/joint_model/asymm_models_joint.py
@@ -109,6 +109,7 @@ class AsymmetricAttention(nn.Module):
scale_x: torch.Tensor, # (B, dim_x), modulation for pre-RMSNorm.
scale_y: torch.Tensor, # (B, dim_y), modulation for pre-RMSNorm.
crop_y,
+ transformer_options={},
**rope_rotation,
) -> Tuple[torch.Tensor, torch.Tensor]:
rope_cos = rope_rotation.get("rope_cos")
@@ -143,7 +144,7 @@ class AsymmetricAttention(nn.Module):
xy = optimized_attention(q,
k,
- v, self.num_heads, skip_reshape=True)
+ v, self.num_heads, skip_reshape=True, transformer_options=transformer_options)
x, y = torch.tensor_split(xy, (q_x.shape[1],), dim=1)
x = self.proj_x(x)
@@ -224,6 +225,7 @@ class AsymmetricJointBlock(nn.Module):
x: torch.Tensor,
c: torch.Tensor,
y: torch.Tensor,
+ transformer_options={},
**attn_kwargs,
):
"""Forward pass of a block.
@@ -256,6 +258,7 @@ class AsymmetricJointBlock(nn.Module):
y,
scale_x=scale_msa_x,
scale_y=scale_msa_y,
+ transformer_options=transformer_options,
**attn_kwargs,
)
@@ -524,10 +527,11 @@ class AsymmDiTJoint(nn.Module):
args["txt"],
rope_cos=args["rope_cos"],
rope_sin=args["rope_sin"],
- crop_y=args["num_tokens"]
+ crop_y=args["num_tokens"],
+ transformer_options=args["transformer_options"]
)
return out
- out = blocks_replace[("double_block", i)]({"img": x, "txt": y_feat, "vec": c, "rope_cos": rope_cos, "rope_sin": rope_sin, "num_tokens": num_tokens}, {"original_block": block_wrap})
+ out = blocks_replace[("double_block", i)]({"img": x, "txt": y_feat, "vec": c, "rope_cos": rope_cos, "rope_sin": rope_sin, "num_tokens": num_tokens, "transformer_options": transformer_options}, {"original_block": block_wrap})
y_feat = out["txt"]
x = out["img"]
else:
@@ -538,6 +542,7 @@ class AsymmDiTJoint(nn.Module):
rope_cos=rope_cos,
rope_sin=rope_sin,
crop_y=num_tokens,
+ transformer_options=transformer_options,
) # (B, M, D), (B, L, D)
del y_feat # Final layers don't use dense text features.
diff --git a/comfy/ldm/hidream/model.py b/comfy/ldm/hidream/model.py
index 0305747bf..28d81c79e 100644
--- a/comfy/ldm/hidream/model.py
+++ b/comfy/ldm/hidream/model.py
@@ -13,6 +13,7 @@ from comfy.ldm.flux.layers import LastLayer
from comfy.ldm.modules.attention import optimized_attention
import comfy.model_management
+import comfy.patcher_extension
import comfy.ldm.common_dit
@@ -71,8 +72,8 @@ class TimestepEmbed(nn.Module):
return t_emb
-def attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor):
- return optimized_attention(query.view(query.shape[0], -1, query.shape[-1] * query.shape[-2]), key.view(key.shape[0], -1, key.shape[-1] * key.shape[-2]), value.view(value.shape[0], -1, value.shape[-1] * value.shape[-2]), query.shape[2])
+def attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, transformer_options={}):
+ return optimized_attention(query.view(query.shape[0], -1, query.shape[-1] * query.shape[-2]), key.view(key.shape[0], -1, key.shape[-1] * key.shape[-2]), value.view(value.shape[0], -1, value.shape[-1] * value.shape[-2]), query.shape[2], transformer_options=transformer_options)
class HiDreamAttnProcessor_flashattn:
@@ -85,6 +86,7 @@ class HiDreamAttnProcessor_flashattn:
image_tokens_masks: Optional[torch.FloatTensor] = None,
text_tokens: Optional[torch.FloatTensor] = None,
rope: torch.FloatTensor = None,
+ transformer_options={},
*args,
**kwargs,
) -> torch.FloatTensor:
@@ -132,7 +134,7 @@ class HiDreamAttnProcessor_flashattn:
query = torch.cat([query_1, query_2], dim=-1)
key = torch.cat([key_1, key_2], dim=-1)
- hidden_states = attention(query, key, value)
+ hidden_states = attention(query, key, value, transformer_options=transformer_options)
if not attn.single:
hidden_states_i, hidden_states_t = torch.split(hidden_states, [num_image_tokens, num_text_tokens], dim=1)
@@ -198,6 +200,7 @@ class HiDreamAttention(nn.Module):
image_tokens_masks: torch.FloatTensor = None,
norm_text_tokens: torch.FloatTensor = None,
rope: torch.FloatTensor = None,
+ transformer_options={},
) -> torch.Tensor:
return self.processor(
self,
@@ -205,6 +208,7 @@ class HiDreamAttention(nn.Module):
image_tokens_masks = image_tokens_masks,
text_tokens = norm_text_tokens,
rope = rope,
+ transformer_options=transformer_options,
)
@@ -405,7 +409,7 @@ class HiDreamImageSingleTransformerBlock(nn.Module):
text_tokens: Optional[torch.FloatTensor] = None,
adaln_input: Optional[torch.FloatTensor] = None,
rope: torch.FloatTensor = None,
-
+ transformer_options={},
) -> torch.FloatTensor:
wtype = image_tokens.dtype
shift_msa_i, scale_msa_i, gate_msa_i, shift_mlp_i, scale_mlp_i, gate_mlp_i = \
@@ -418,6 +422,7 @@ class HiDreamImageSingleTransformerBlock(nn.Module):
norm_image_tokens,
image_tokens_masks,
rope = rope,
+ transformer_options=transformer_options,
)
image_tokens = gate_msa_i * attn_output_i + image_tokens
@@ -482,6 +487,7 @@ class HiDreamImageTransformerBlock(nn.Module):
text_tokens: Optional[torch.FloatTensor] = None,
adaln_input: Optional[torch.FloatTensor] = None,
rope: torch.FloatTensor = None,
+ transformer_options={},
) -> torch.FloatTensor:
wtype = image_tokens.dtype
shift_msa_i, scale_msa_i, gate_msa_i, shift_mlp_i, scale_mlp_i, gate_mlp_i, \
@@ -499,6 +505,7 @@ class HiDreamImageTransformerBlock(nn.Module):
image_tokens_masks,
norm_text_tokens,
rope = rope,
+ transformer_options=transformer_options,
)
image_tokens = gate_msa_i * attn_output_i + image_tokens
@@ -549,6 +556,7 @@ class HiDreamImageBlock(nn.Module):
text_tokens: Optional[torch.FloatTensor] = None,
adaln_input: torch.FloatTensor = None,
rope: torch.FloatTensor = None,
+ transformer_options={},
) -> torch.FloatTensor:
return self.block(
image_tokens,
@@ -556,6 +564,7 @@ class HiDreamImageBlock(nn.Module):
text_tokens,
adaln_input,
rope,
+ transformer_options=transformer_options,
)
@@ -692,7 +701,23 @@ class HiDreamImageTransformer2DModel(nn.Module):
raise NotImplementedError
return x, x_masks, img_sizes
- def forward(
+ def forward(self,
+ x: torch.Tensor,
+ t: torch.Tensor,
+ y: Optional[torch.Tensor] = None,
+ context: Optional[torch.Tensor] = None,
+ encoder_hidden_states_llama3=None,
+ image_cond=None,
+ control = None,
+ transformer_options = {},
+ ):
+ return comfy.patcher_extension.WrapperExecutor.new_class_executor(
+ self._forward,
+ self,
+ comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
+ ).execute(x, t, y, context, encoder_hidden_states_llama3, image_cond, control, transformer_options)
+
+ def _forward(
self,
x: torch.Tensor,
t: torch.Tensor,
@@ -769,6 +794,7 @@ class HiDreamImageTransformer2DModel(nn.Module):
text_tokens = cur_encoder_hidden_states,
adaln_input = adaln_input,
rope = rope,
+ transformer_options=transformer_options,
)
initial_encoder_hidden_states = initial_encoder_hidden_states[:, :initial_encoder_hidden_states_seq_len]
block_id += 1
@@ -792,6 +818,7 @@ class HiDreamImageTransformer2DModel(nn.Module):
text_tokens=None,
adaln_input=adaln_input,
rope=rope,
+ transformer_options=transformer_options,
)
hidden_states = hidden_states[:, :hidden_states_seq_len]
block_id += 1
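Reviewer note: this file adopts the same WrapperExecutor pattern as the Flux and Hunyuan3D models, so DIFFUSION_MODEL wrappers now see the whole _forward call. A hedged sketch of what such a wrapper could look like; the (executor, *args, **kwargs) call contract is inferred from how WrapperExecutor is invoked in this diff, and the registration path (normally the ModelPatcher wrapper helpers) is left out of the sketch.

```python
import time

# hypothetical wrapper around the diffusion model's _forward
def timing_wrapper(executor, *args, **kwargs):
    start = time.perf_counter()
    out = executor(*args, **kwargs)  # calls the next wrapper, or the wrapped _forward
    print(f"diffusion model forward took {time.perf_counter() - start:.3f}s")
    return out
```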
diff --git a/comfy/ldm/hunyuan3d/model.py b/comfy/ldm/hunyuan3d/model.py
index 4e18358f0..4991b1645 100644
--- a/comfy/ldm/hunyuan3d/model.py
+++ b/comfy/ldm/hunyuan3d/model.py
@@ -7,6 +7,7 @@ from comfy.ldm.flux.layers import (
SingleStreamBlock,
timestep_embedding,
)
+import comfy.patcher_extension
class Hunyuan3Dv2(nn.Module):
@@ -67,6 +68,13 @@ class Hunyuan3Dv2(nn.Module):
self.final_layer = LastLayer(hidden_size, 1, in_channels, dtype=dtype, device=device, operations=operations)
def forward(self, x, timestep, context, guidance=None, transformer_options={}, **kwargs):
+ return comfy.patcher_extension.WrapperExecutor.new_class_executor(
+ self._forward,
+ self,
+ comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
+ ).execute(x, timestep, context, guidance, transformer_options, **kwargs)
+
+ def _forward(self, x, timestep, context, guidance=None, transformer_options={}, **kwargs):
x = x.movedim(-1, -2)
timestep = 1.0 - timestep
txt = context
@@ -91,14 +99,16 @@ class Hunyuan3Dv2(nn.Module):
txt=args["txt"],
vec=args["vec"],
pe=args["pe"],
- attn_mask=args.get("attn_mask"))
+ attn_mask=args.get("attn_mask"),
+ transformer_options=args["transformer_options"])
return out
out = blocks_replace[("double_block", i)]({"img": img,
"txt": txt,
"vec": vec,
"pe": pe,
- "attn_mask": attn_mask},
+ "attn_mask": attn_mask,
+ "transformer_options": transformer_options},
{"original_block": block_wrap})
txt = out["txt"]
img = out["img"]
@@ -107,7 +117,8 @@ class Hunyuan3Dv2(nn.Module):
txt=txt,
vec=vec,
pe=pe,
- attn_mask=attn_mask)
+ attn_mask=attn_mask,
+ transformer_options=transformer_options)
img = torch.cat((txt, img), 1)
@@ -118,17 +129,19 @@ class Hunyuan3Dv2(nn.Module):
out["img"] = block(args["img"],
vec=args["vec"],
pe=args["pe"],
- attn_mask=args.get("attn_mask"))
+ attn_mask=args.get("attn_mask"),
+ transformer_options=args["transformer_options"])
return out
out = blocks_replace[("single_block", i)]({"img": img,
"vec": vec,
"pe": pe,
- "attn_mask": attn_mask},
+ "attn_mask": attn_mask,
+ "transformer_options": transformer_options},
{"original_block": block_wrap})
img = out["img"]
else:
- img = block(img, vec=vec, pe=pe, attn_mask=attn_mask)
+ img = block(img, vec=vec, pe=pe, attn_mask=attn_mask, transformer_options=transformer_options)
img = img[:, txt.shape[1]:, ...]
img = self.final_layer(img, vec)
diff --git a/comfy/ldm/hunyuan3d/vae.py b/comfy/ldm/hunyuan3d/vae.py
index 6e8cbf1d9..760944827 100644
--- a/comfy/ldm/hunyuan3d/vae.py
+++ b/comfy/ldm/hunyuan3d/vae.py
@@ -4,81 +4,458 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
-
-
-from typing import Union, Tuple, List, Callable, Optional
-
import numpy as np
-from einops import repeat, rearrange
+import math
from tqdm import tqdm
+
+from typing import Optional
+
import logging
import comfy.ops
ops = comfy.ops.disable_weight_init
-def generate_dense_grid_points(
- bbox_min: np.ndarray,
- bbox_max: np.ndarray,
- octree_resolution: int,
- indexing: str = "ij",
-):
- length = bbox_max - bbox_min
- num_cells = octree_resolution
+def fps(src: torch.Tensor, batch: torch.Tensor, sampling_ratio: float, start_random: bool = True):
- x = np.linspace(bbox_min[0], bbox_max[0], int(num_cells) + 1, dtype=np.float32)
- y = np.linspace(bbox_min[1], bbox_max[1], int(num_cells) + 1, dtype=np.float32)
- z = np.linspace(bbox_min[2], bbox_max[2], int(num_cells) + 1, dtype=np.float32)
- [xs, ys, zs] = np.meshgrid(x, y, z, indexing=indexing)
- xyz = np.stack((xs, ys, zs), axis=-1)
- grid_size = [int(num_cells) + 1, int(num_cells) + 1, int(num_cells) + 1]
+ # manually create the pointer vector
+ assert src.size(0) == batch.numel()
- return xyz, grid_size, length
+ batch_size = int(batch.max()) + 1
+ deg = src.new_zeros(batch_size, dtype = torch.long)
+
+ deg.scatter_add_(0, batch, torch.ones_like(batch))
+
+ ptr_vec = deg.new_zeros(batch_size + 1)
+ torch.cumsum(deg, 0, out=ptr_vec[1:])
+
+ #return fps_sampling(src, ptr_vec, ratio)
+ sampled_indicies = []
+
+ for b in range(batch_size):
+ # start and the end of each batch
+ start, end = ptr_vec[b].item(), ptr_vec[b + 1].item()
+ # points from the point cloud
+ points = src[start:end]
+
+ num_points = points.size(0)
+ num_samples = max(1, math.ceil(num_points * sampling_ratio))
+
+ selected = torch.zeros(num_samples, device = src.device, dtype = torch.long)
+ distances = torch.full((num_points,), float("inf"), device = src.device)
+
+ # select a random start point
+ if start_random:
+ farthest = torch.randint(0, num_points, (1,), device = src.device)
+ else:
+ farthest = torch.tensor([0], device = src.device, dtype = torch.long)
+
+ for i in range(num_samples):
+ selected[i] = farthest
+ centroid = points[farthest].squeeze(0)
+ dist = torch.norm(points - centroid, dim = 1) # compute euclidean distance
+ distances = torch.minimum(distances, dist)
+ farthest = torch.argmax(distances)
+
+ sampled_indicies.append(torch.arange(start, end)[selected])
+
+ return torch.cat(sampled_indicies, dim = 0)
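A hedged usage sketch of the fps() helper above; ComfyUI is assumed to be importable, and the tensors and the 0.25 ratio are made up for illustration.

```python
import torch
from comfy.ldm.hunyuan3d.vae import fps

B, N, D = 2, 1024, 3
pc = torch.rand(B * N, D)                            # two point clouds flattened into one tensor
batch = torch.repeat_interleave(torch.arange(B), N)  # which cloud each point belongs to
idx = fps(pc, batch, sampling_ratio=0.25)            # farthest point sampling per cloud
sampled = pc[idx]                                    # (2 * 256, 3)
```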
+class PointCrossAttention(nn.Module):
+ def __init__(self,
+ num_latents: int,
+ downsample_ratio: float,
+ pc_size: int,
+ pc_sharpedge_size: int,
+ point_feats: int,
+ width: int,
+ heads: int,
+ layers: int,
+ fourier_embedder,
+ normal_pe: bool = False,
+ qkv_bias: bool = False,
+ use_ln_post: bool = True,
+ qk_norm: bool = True):
+
+ super().__init__()
+
+ self.fourier_embedder = fourier_embedder
+
+ self.pc_size = pc_size
+ self.normal_pe = normal_pe
+ self.downsample_ratio = downsample_ratio
+ self.pc_sharpedge_size = pc_sharpedge_size
+ self.num_latents = num_latents
+ self.point_feats = point_feats
+
+ self.input_proj = nn.Linear(self.fourier_embedder.out_dim + point_feats, width)
+
+ self.cross_attn = ResidualCrossAttentionBlock(
+ width = width,
+ heads = heads,
+ qkv_bias = qkv_bias,
+ qk_norm = qk_norm
+ )
+
+ self.self_attn = None
+ if layers > 0:
+ self.self_attn = Transformer(
+ width = width,
+ heads = heads,
+ qkv_bias = qkv_bias,
+ qk_norm = qk_norm,
+ layers = layers
+ )
+
+ if use_ln_post:
+ self.ln_post = nn.LayerNorm(width)
+ else:
+ self.ln_post = None
+
+ def sample_points_and_latents(self, point_cloud: torch.Tensor, features: torch.Tensor):
+
+        """
+        Randomly subsample points from the point cloud (input_pc),
+        then FPS-sample the subsampled points to get query_pc, and
+        take the Fourier embeddings of both the input and query points.
+
+        Mental note: the FPS-sampled points (query_pc) act as latent tokens that attend to and learn from the broader context in input_pc.
+        Goal: obtain a smaller representation (query_pc) of the entire scene structure by learning from a broader subset (input_pc),
+        which is more computationally efficient.
+
+        Features carry additional information for each point in the cloud.
+        """
+
+ B, _, D = point_cloud.shape
+
+ num_latents = int(self.num_latents)
+
+ num_random_query = self.pc_size / (self.pc_size + self.pc_sharpedge_size) * num_latents
+ num_sharpedge_query = num_latents - num_random_query
+
+ # Split random and sharpedge surface points
+ random_pc, sharpedge_pc = torch.split(point_cloud, [self.pc_size, self.pc_sharpedge_size], dim=1)
+
+ # assert statements
+ assert random_pc.shape[1] <= self.pc_size, "Random surface points size must be less than or equal to pc_size"
+ assert sharpedge_pc.shape[1] <= self.pc_sharpedge_size, "Sharpedge surface points size must be less than or equal to pc_sharpedge_size"
+
+ input_random_pc_size = int(num_random_query * self.downsample_ratio)
+ random_query_pc, random_input_pc, random_idx_pc, random_idx_query = \
+ self.subsample(pc = random_pc, num_query = num_random_query, input_pc_size = input_random_pc_size)
+
+ input_sharpedge_pc_size = int(num_sharpedge_query * self.downsample_ratio)
+
+ if input_sharpedge_pc_size == 0:
+ sharpedge_input_pc = torch.zeros(B, 0, D, dtype = random_input_pc.dtype).to(point_cloud.device)
+ sharpedge_query_pc = torch.zeros(B, 0, D, dtype= random_query_pc.dtype).to(point_cloud.device)
+
+ else:
+ sharpedge_query_pc, sharpedge_input_pc, sharpedge_idx_pc, sharpedge_idx_query = \
+ self.subsample(pc = sharpedge_pc, num_query = num_sharpedge_query, input_pc_size = input_sharpedge_pc_size)
+
+ # concat the random and sharpedges
+ query_pc = torch.cat([random_query_pc, sharpedge_query_pc], dim = 1)
+ input_pc = torch.cat([random_input_pc, sharpedge_input_pc], dim = 1)
+
+ query = self.fourier_embedder(query_pc)
+ data = self.fourier_embedder(input_pc)
+
+ if self.point_feats > 0:
+ random_surface_features, sharpedge_surface_features = torch.split(features, [self.pc_size, self.pc_sharpedge_size], dim = 1)
+
+ input_random_surface_features, query_random_features = \
+ self.handle_features(features = random_surface_features, idx_pc = random_idx_pc, batch_size = B,
+ input_pc_size = input_random_pc_size, idx_query = random_idx_query)
+
+ if input_sharpedge_pc_size == 0:
+ input_sharpedge_surface_features = torch.zeros(B, 0, self.point_feats,
+ dtype = input_random_surface_features.dtype, device = point_cloud.device)
+
+ query_sharpedge_features = torch.zeros(B, 0, self.point_feats,
+ dtype = query_random_features.dtype, device = point_cloud.device)
+ else:
+
+ input_sharpedge_surface_features, query_sharpedge_features = \
+ self.handle_features(idx_pc = sharpedge_idx_pc, features = sharpedge_surface_features,
+ batch_size = B, idx_query = sharpedge_idx_query, input_pc_size = input_sharpedge_pc_size)
+
+ query_features = torch.cat([query_random_features, query_sharpedge_features], dim = 1)
+ input_features = torch.cat([input_random_surface_features, input_sharpedge_surface_features], dim = 1)
+
+ if self.normal_pe:
+ # apply the fourier embeddings on the first 3 dims (xyz)
+ input_features_pe = self.fourier_embedder(input_features[..., :3])
+ query_features_pe = self.fourier_embedder(query_features[..., :3])
+ # replace the first 3 dims with the new PE ones
+ input_features = torch.cat([input_features_pe, input_features[..., :3]], dim = -1)
+ query_features = torch.cat([query_features_pe, query_features[..., :3]], dim = -1)
+
+ # concat at the channels dim
+ query = torch.cat([query, query_features], dim = -1)
+ data = torch.cat([data, input_features], dim = -1)
+
+        # don't return pc_info, to avoid unnecessary memory usage
+ return query.view(B, -1, query.shape[-1]), data.view(B, -1, data.shape[-1])
+
+ def forward(self, point_cloud: torch.Tensor, features: torch.Tensor):
+
+ query, data = self.sample_points_and_latents(point_cloud = point_cloud, features = features)
+
+ # apply projections
+ query = self.input_proj(query)
+ data = self.input_proj(data)
+
+ # apply cross attention between query and data
+ latents = self.cross_attn(query, data)
+
+ if self.self_attn is not None:
+ latents = self.self_attn(latents)
+
+ if self.ln_post is not None:
+ latents = self.ln_post(latents)
+
+ return latents
-class VanillaVolumeDecoder:
+ def subsample(self, pc, num_query, input_pc_size: int):
+
+ """
+ num_query: number of points to keep after FPS
+ input_pc_size: number of points to select before FPS
+ """
+
+ B, _, D = pc.shape
+ query_ratio = num_query / input_pc_size
+
+ # random subsampling of points inside the point cloud
+ idx_pc = torch.randperm(pc.shape[1], device = pc.device)[:input_pc_size]
+ input_pc = pc[:, idx_pc, :]
+
+ # flatten to allow applying fps across the whole batch
+ flattent_input_pc = input_pc.view(B * input_pc_size, D)
+
+ # construct a batch_down tensor to tell fps
+ # which points belong to which batch
+ N_down = int(flattent_input_pc.shape[0] / B)
+ batch_down = torch.arange(B).to(pc.device)
+ batch_down = torch.repeat_interleave(batch_down, N_down)
+
+ idx_query = fps(flattent_input_pc, batch_down, sampling_ratio = query_ratio)
+ query_pc = flattent_input_pc[idx_query].view(B, -1, D)
+
+ return query_pc, input_pc, idx_pc, idx_query
+
+ def handle_features(self, features, idx_pc, input_pc_size, batch_size: int, idx_query):
+
+ B = batch_size
+
+ input_surface_features = features[:, idx_pc, :]
+ flattent_input_features = input_surface_features.view(B * input_pc_size, -1)
+ query_features = flattent_input_features[idx_query].view(B, -1,
+ flattent_input_features.shape[-1])
+
+ return input_surface_features, query_features
+
+def normalize_mesh(mesh, scale = 0.9999):
+ """Normalize mesh to fit in [-scale, scale]. Translate mesh so its center is [0,0,0]"""
+
+ bbox = mesh.bounds
+ center = (bbox[1] + bbox[0]) / 2
+
+ max_extent = (bbox[1] - bbox[0]).max()
+ mesh.apply_translation(-center)
+ mesh.apply_scale((2 * scale) / max_extent)
+
+ return mesh
+
+def sample_pointcloud(mesh, num = 200000):
+ """ Uniformly sample points from the surface of the mesh """
+
+ points, face_idx = mesh.sample(num, return_index = True)
+ normals = mesh.face_normals[face_idx]
+ return torch.from_numpy(points.astype(np.float32)), torch.from_numpy(normals.astype(np.float32))
+
+def detect_sharp_edges(mesh, threshold=0.985):
+ """Return edge indices (a, b) that lie on sharp boundaries of the mesh."""
+
+ V, F = mesh.vertices, mesh.faces
+ VN, FN = mesh.vertex_normals, mesh.face_normals
+
+ sharp_mask = np.ones(V.shape[0])
+ for i in range(3):
+ indices = F[:, i]
+ alignment = np.einsum('ij,ij->i', VN[indices], FN)
+ dot_stack = np.stack((sharp_mask[indices], alignment), axis=-1)
+ sharp_mask[indices] = np.min(dot_stack, axis=-1)
+
+ edge_a = np.concatenate([F[:, 0], F[:, 1], F[:, 2]])
+ edge_b = np.concatenate([F[:, 1], F[:, 2], F[:, 0]])
+ sharp_edges = (sharp_mask[edge_a] < threshold) & (sharp_mask[edge_b] < threshold)
+
+ return edge_a[sharp_edges], edge_b[sharp_edges]
+
+
+def sharp_sample_pointcloud(mesh, num = 16384):
+ """ Sample points preferentially from sharp edges in the mesh. """
+
+ edge_a, edge_b = detect_sharp_edges(mesh)
+ V, VN = mesh.vertices, mesh.vertex_normals
+
+ va, vb = V[edge_a], V[edge_b]
+ na, nb = VN[edge_a], VN[edge_b]
+
+ edge_lengths = np.linalg.norm(vb - va, axis=-1)
+ weights = edge_lengths / edge_lengths.sum()
+
+ indices = np.searchsorted(np.cumsum(weights), np.random.rand(num))
+ t = np.random.rand(num, 1)
+
+ samples = t * va[indices] + (1 - t) * vb[indices]
+ normals = t * na[indices] + (1 - t) * nb[indices]
+
+ return samples.astype(np.float32), normals.astype(np.float32)
+
+def load_surface_sharpedge(mesh, num_points=4096, num_sharp_points=4096, sharpedge_flag = True, device = "cuda"):
+ """Load a surface with optional sharp-edge annotations from a trimesh mesh."""
+
+ import trimesh
+
+ try:
+ mesh_full = trimesh.util.concatenate(mesh.dump())
+ except Exception:
+ mesh_full = trimesh.util.concatenate(mesh)
+
+ mesh_full = normalize_mesh(mesh_full)
+
+ faces = mesh_full.faces
+ vertices = mesh_full.vertices
+ origin_face_count = faces.shape[0]
+
+ mesh_surface = trimesh.Trimesh(vertices=vertices, faces=faces[:origin_face_count])
+ mesh_fill = trimesh.Trimesh(vertices=vertices, faces=faces[origin_face_count:])
+
+ area_surface = mesh_surface.area
+ area_fill = mesh_fill.area
+ total_area = area_surface + area_fill
+
+ sample_num = 499712 // 2
+ fill_ratio = area_fill / total_area if total_area > 0 else 0
+
+ num_fill = int(sample_num * fill_ratio)
+ num_surface = sample_num - num_fill
+
+ surf_pts, surf_normals = sample_pointcloud(mesh_surface, num_surface)
+ fill_pts, fill_normals = (torch.zeros(0, 3), torch.zeros(0, 3)) if num_fill == 0 else sample_pointcloud(mesh_fill, num_fill)
+
+ sharp_pts, sharp_normals = sharp_sample_pointcloud(mesh_surface, sample_num)
+
+ def assemble_tensor(points, normals, label=None):
+
+ data = torch.cat([points, normals], dim=1).half().to(device)
+
+ if label is not None:
+ label_tensor = torch.full((data.shape[0], 1), float(label), dtype=torch.float16).to(device)
+ data = torch.cat([data, label_tensor], dim=1)
+
+ return data
+
+ surface = assemble_tensor(torch.cat([surf_pts.to(device), fill_pts.to(device)], dim=0),
+ torch.cat([surf_normals.to(device), fill_normals.to(device)], dim=0),
+ label = 0 if sharpedge_flag else None)
+
+ sharp_surface = assemble_tensor(torch.from_numpy(sharp_pts), torch.from_numpy(sharp_normals),
+ label = 1 if sharpedge_flag else None)
+
+ rng = np.random.default_rng()
+
+ surface = surface[rng.choice(surface.shape[0], num_points, replace = False)]
+ sharp_surface = sharp_surface[rng.choice(sharp_surface.shape[0], num_sharp_points, replace = False)]
+
+ full = torch.cat([surface, sharp_surface], dim = 0).unsqueeze(0)
+
+ return full
+
+class SharpEdgeSurfaceLoader:
+ """ Load mesh surface and sharp edge samples. """
+
+ def __init__(self, num_uniform_points = 8192, num_sharp_points = 8192):
+
+ self.num_uniform_points = num_uniform_points
+ self.num_sharp_points = num_sharp_points
+ self.total_points = num_uniform_points + num_sharp_points
+
+ def __call__(self, mesh_input, device = "cuda"):
+ mesh = self._load_mesh(mesh_input)
+ return load_surface_sharpedge(mesh, self.num_uniform_points, self.num_sharp_points, device = device)
+
+ @staticmethod
+ def _load_mesh(mesh_input):
+ import trimesh
+
+ if isinstance(mesh_input, str):
+ mesh = trimesh.load(mesh_input, force="mesh", merge_primitives = True)
+ else:
+ mesh = mesh_input
+
+ if isinstance(mesh, trimesh.Scene):
+ combined = None
+ for obj in mesh.geometry.values():
+ combined = obj if combined is None else combined + obj
+ return combined
+
+ return mesh
+
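A hedged usage sketch of SharpEdgeSurfaceLoader; trimesh and an importable ComfyUI are assumed, the mesh path is a placeholder, and the mesh is assumed to actually contain sharp edges.

```python
import trimesh
from comfy.ldm.hunyuan3d.vae import SharpEdgeSurfaceLoader

loader = SharpEdgeSurfaceLoader(num_uniform_points=8192, num_sharp_points=8192)
mesh = trimesh.load("model.glb", force="mesh")  # placeholder path
surface = loader(mesh, device="cpu")            # (1, 16384, 7): xyz, normal, sharp-edge label
```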
+class DiagonalGaussianDistribution:
+ def __init__(self, params: torch.Tensor, feature_dim: int = -1):
+
+        # split the projected channels into mean and log variance
+ self.mean, self.logvar = torch.chunk(params, 2, dim = feature_dim)
+
+ self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
+ self.std = torch.exp(0.5 * self.logvar)
+
+ def sample(self):
+
+ eps = torch.randn_like(self.std)
+ z = self.mean + eps * self.std
+
+ return z
+
+################################################
+# Volume Decoder
+################################################
+
+class VanillaVolumeDecoder():
@torch.no_grad()
- def __call__(
- self,
- latents: torch.FloatTensor,
- geo_decoder: Callable,
- bounds: Union[Tuple[float], List[float], float] = 1.01,
- num_chunks: int = 10000,
- octree_resolution: int = None,
- enable_pbar: bool = True,
- **kwargs,
- ):
- device = latents.device
- dtype = latents.dtype
- batch_size = latents.shape[0]
+ def __call__(self, latents: torch.Tensor, geo_decoder: callable, octree_resolution: int, bounds = 1.01,
+ num_chunks: int = 10_000, enable_pbar: bool = True, **kwargs):
- # 1. generate query points
if isinstance(bounds, float):
bounds = [-bounds, -bounds, -bounds, bounds, bounds, bounds]
- bbox_min, bbox_max = np.array(bounds[0:3]), np.array(bounds[3:6])
- xyz_samples, grid_size, length = generate_dense_grid_points(
- bbox_min=bbox_min,
- bbox_max=bbox_max,
- octree_resolution=octree_resolution,
- indexing="ij"
- )
- xyz_samples = torch.from_numpy(xyz_samples).to(device, dtype=dtype).contiguous().reshape(-1, 3)
+ bbox_min, bbox_max = torch.tensor(bounds[:3]), torch.tensor(bounds[3:])
+
+ x = torch.linspace(bbox_min[0], bbox_max[0], int(octree_resolution) + 1, dtype = torch.float32)
+ y = torch.linspace(bbox_min[1], bbox_max[1], int(octree_resolution) + 1, dtype = torch.float32)
+ z = torch.linspace(bbox_min[2], bbox_max[2], int(octree_resolution) + 1, dtype = torch.float32)
+
+ [xs, ys, zs] = torch.meshgrid(x, y, z, indexing = "ij")
+ xyz = torch.stack((xs, ys, zs), axis=-1).to(latents.device, dtype = latents.dtype).contiguous().reshape(-1, 3)
+ grid_size = [int(octree_resolution) + 1, int(octree_resolution) + 1, int(octree_resolution) + 1]
- # 2. latents to 3d volume
batch_logits = []
- for start in tqdm(range(0, xyz_samples.shape[0], num_chunks), desc="Volume Decoding",
+ for start in tqdm(range(0, xyz.shape[0], num_chunks), desc="Volume Decoding",
disable=not enable_pbar):
- chunk_queries = xyz_samples[start: start + num_chunks, :]
- chunk_queries = repeat(chunk_queries, "p c -> b p c", b=batch_size)
- logits = geo_decoder(queries=chunk_queries, latents=latents)
+
+ chunk_queries = xyz[start: start + num_chunks, :]
+ chunk_queries = chunk_queries.unsqueeze(0).repeat(latents.shape[0], 1, 1)
+ logits = geo_decoder(queries = chunk_queries, latents = latents)
batch_logits.append(logits)
- grid_logits = torch.cat(batch_logits, dim=1)
- grid_logits = grid_logits.view((batch_size, *grid_size)).float()
+ grid_logits = torch.cat(batch_logits, dim = 1)
+ grid_logits = grid_logits.view((latents.shape[0], *grid_size)).float()
return grid_logits
-
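A hedged sketch of the rewritten VanillaVolumeDecoder call, using a toy geo_decoder (a unit-sphere occupancy stand-in) instead of the real cross-attention decoder; ComfyUI is assumed importable and the latents are dummies.

```python
import torch
from comfy.ldm.hunyuan3d.vae import VanillaVolumeDecoder

decoder = VanillaVolumeDecoder()
latents = torch.zeros(1, 4096, 64)  # dummy latents, only batch/device/dtype matter here

def sphere_decoder(queries, latents):
    # positive inside the unit sphere, negative outside
    return 0.5 - queries.norm(dim=-1)

grid = decoder(latents, geo_decoder=sphere_decoder, octree_resolution=64,
               num_chunks=100_000, enable_pbar=False)
print(grid.shape)  # torch.Size([1, 65, 65, 65])
```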
class FourierEmbedder(nn.Module):
"""The sin/cosine positional embedding. Given an input tensor `x` of shape [n_batch, ..., c_dim], it converts
each feature dimension of `x[..., i]` into:
@@ -175,13 +552,11 @@ class FourierEmbedder(nn.Module):
else:
return x
-
class CrossAttentionProcessor:
def __call__(self, attn, q, k, v):
out = comfy.ops.scaled_dot_product_attention(q, k, v)
return out
-
class DropPath(nn.Module):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
"""
@@ -232,38 +607,41 @@ class MLP(nn.Module):
def forward(self, x):
return self.drop_path(self.c_proj(self.gelu(self.c_fc(x))))
-
class QKVMultiheadCrossAttention(nn.Module):
def __init__(
self,
- *,
heads: int,
+ n_data = None,
width=None,
qk_norm=False,
norm_layer=ops.LayerNorm
):
super().__init__()
self.heads = heads
+ self.n_data = n_data
self.q_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
self.k_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
- self.attn_processor = CrossAttentionProcessor()
-
def forward(self, q, kv):
+
_, n_ctx, _ = q.shape
bs, n_data, width = kv.shape
+
attn_ch = width // self.heads // 2
q = q.view(bs, n_ctx, self.heads, -1)
+
kv = kv.view(bs, n_data, self.heads, -1)
k, v = torch.split(kv, attn_ch, dim=-1)
q = self.q_norm(q)
k = self.k_norm(k)
- q, k, v = map(lambda t: rearrange(t, 'b n h d -> b h n d', h=self.heads), (q, k, v))
- out = self.attn_processor(self, q, k, v)
- out = out.transpose(1, 2).reshape(bs, n_ctx, -1)
- return out
+ q, k, v = [t.permute(0, 2, 1, 3) for t in (q, k, v)]
+ out = F.scaled_dot_product_attention(q, k, v)
+
+ out = out.transpose(1, 2).reshape(bs, n_ctx, -1)
+
+ return out
class MultiheadCrossAttention(nn.Module):
def __init__(
@@ -306,7 +684,6 @@ class MultiheadCrossAttention(nn.Module):
x = self.c_proj(x)
return x
-
class ResidualCrossAttentionBlock(nn.Module):
def __init__(
self,
@@ -366,7 +743,7 @@ class QKVMultiheadAttention(nn.Module):
q = self.q_norm(q)
k = self.k_norm(k)
- q, k, v = map(lambda t: rearrange(t, 'b n h d -> b h n d', h=self.heads), (q, k, v))
+ q, k, v = [t.permute(0, 2, 1, 3) for t in (q, k, v)]
out = F.scaled_dot_product_attention(q, k, v).transpose(1, 2).reshape(bs, n_ctx, -1)
return out
@@ -383,8 +760,7 @@ class MultiheadAttention(nn.Module):
drop_path_rate: float = 0.0
):
super().__init__()
- self.width = width
- self.heads = heads
+
self.c_qkv = ops.Linear(width, width * 3, bias=qkv_bias)
self.c_proj = ops.Linear(width, width)
self.attention = QKVMultiheadAttention(
@@ -491,7 +867,7 @@ class CrossAttentionDecoder(nn.Module):
self.query_proj = ops.Linear(self.fourier_embedder.out_dim, width)
if self.downsample_ratio != 1:
self.latents_proj = ops.Linear(width * downsample_ratio, width)
- if self.enable_ln_post == False:
+ if not self.enable_ln_post:
qk_norm = False
self.cross_attn_decoder = ResidualCrossAttentionBlock(
width=width,
@@ -522,28 +898,44 @@ class CrossAttentionDecoder(nn.Module):
class ShapeVAE(nn.Module):
def __init__(
- self,
- *,
- embed_dim: int,
- width: int,
- heads: int,
- num_decoder_layers: int,
- geo_decoder_downsample_ratio: int = 1,
- geo_decoder_mlp_expand_ratio: int = 4,
- geo_decoder_ln_post: bool = True,
- num_freqs: int = 8,
- include_pi: bool = True,
- qkv_bias: bool = True,
- qk_norm: bool = False,
- label_type: str = "binary",
- drop_path_rate: float = 0.0,
- scale_factor: float = 1.0,
+ self,
+ *,
+ num_latents: int = 4096,
+ embed_dim: int = 64,
+ width: int = 1024,
+ heads: int = 16,
+ num_decoder_layers: int = 16,
+ num_encoder_layers: int = 8,
+ pc_size: int = 81920,
+ pc_sharpedge_size: int = 0,
+ point_feats: int = 4,
+ downsample_ratio: int = 20,
+ geo_decoder_downsample_ratio: int = 1,
+ geo_decoder_mlp_expand_ratio: int = 4,
+ geo_decoder_ln_post: bool = True,
+ num_freqs: int = 8,
+ qkv_bias: bool = False,
+ qk_norm: bool = True,
+ drop_path_rate: float = 0.0,
+ include_pi: bool = False,
+ scale_factor: float = 1.0039506158752403,
+ label_type: str = "binary",
):
super().__init__()
self.geo_decoder_ln_post = geo_decoder_ln_post
self.fourier_embedder = FourierEmbedder(num_freqs=num_freqs, include_pi=include_pi)
+ self.encoder = PointCrossAttention(layers = num_encoder_layers,
+ num_latents = num_latents,
+ downsample_ratio = downsample_ratio,
+ heads = heads,
+ pc_size = pc_size,
+ width = width,
+ point_feats = point_feats,
+ fourier_embedder = self.fourier_embedder,
+ pc_sharpedge_size = pc_sharpedge_size)
+
self.post_kl = ops.Linear(embed_dim, width)
self.transformer = Transformer(
@@ -583,5 +975,14 @@ class ShapeVAE(nn.Module):
grid_logits = self.volume_decoder(latents, self.geo_decoder, bounds=bounds, num_chunks=num_chunks, octree_resolution=octree_resolution, enable_pbar=enable_pbar)
return grid_logits.movedim(-2, -1)
- def encode(self, x):
- return None
+ def encode(self, surface):
+
+ pc, feats = surface[:, :, :3], surface[:, :, 3:]
+ latents = self.encoder(pc, feats)
+
+ moments = self.pre_kl(latents)
+ posterior = DiagonalGaussianDistribution(moments, feature_dim = -1)
+
+ latents = posterior.sample()
+
+ return latents
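A hedged sketch of the sampling step the new ShapeVAE.encode performs after pre_kl, using the DiagonalGaussianDistribution added above with dummy moments; the shapes follow the class defaults (4096 latents, embed_dim 64) and ComfyUI is assumed importable.

```python
import torch
from comfy.ldm.hunyuan3d.vae import DiagonalGaussianDistribution

moments = torch.randn(1, 4096, 2 * 64)  # dummy pre_kl output: mean and logvar along the last dim
posterior = DiagonalGaussianDistribution(moments, feature_dim=-1)
latents = posterior.sample()            # (1, 4096, 64) reparameterised sample
```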
diff --git a/comfy/ldm/hunyuan3dv2_1/hunyuandit.py b/comfy/ldm/hunyuan3dv2_1/hunyuandit.py
new file mode 100644
index 000000000..d48d9d642
--- /dev/null
+++ b/comfy/ldm/hunyuan3dv2_1/hunyuandit.py
@@ -0,0 +1,659 @@
+import math
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from comfy.ldm.modules.attention import optimized_attention
+import comfy.model_management
+
+class GELU(nn.Module):
+
+ def __init__(self, dim_in: int, dim_out: int, operations, device, dtype):
+ super().__init__()
+ self.proj = operations.Linear(dim_in, dim_out, device = device, dtype = dtype)
+
+ def gelu(self, gate: torch.Tensor) -> torch.Tensor:
+
+ if gate.device.type == "mps":
+ return F.gelu(gate.to(dtype = torch.float32)).to(dtype = gate.dtype)
+
+ return F.gelu(gate)
+
+ def forward(self, hidden_states):
+
+ hidden_states = self.proj(hidden_states)
+ hidden_states = self.gelu(hidden_states)
+
+ return hidden_states
+
+class FeedForward(nn.Module):
+
+ def __init__(self, dim: int, dim_out = None, mult: int = 4,
+ dropout: float = 0.0, inner_dim = None, operations = None, device = None, dtype = None):
+
+ super().__init__()
+ if inner_dim is None:
+ inner_dim = int(dim * mult)
+
+ dim_out = dim_out if dim_out is not None else dim
+
+ act_fn = GELU(dim, inner_dim, operations = operations, device = device, dtype = dtype)
+
+ self.net = nn.ModuleList([])
+ self.net.append(act_fn)
+
+ self.net.append(nn.Dropout(dropout))
+ self.net.append(operations.Linear(inner_dim, dim_out, device = device, dtype = dtype))
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ for module in self.net:
+ hidden_states = module(hidden_states)
+ return hidden_states
+
+class AddAuxLoss(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx, x, loss):
+ # do nothing in forward (no computation)
+ ctx.requires_aux_loss = loss.requires_grad
+ ctx.dtype = loss.dtype
+
+ return x
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ # add the aux loss gradients
+ grad_loss = None
+        # give the aux loss a gradient of one, matching the main loss gradient,
+        # so the aux loss contributes equally during backprop
+ if ctx.requires_aux_loss:
+ grad_loss = torch.ones(1, dtype = ctx.dtype, device = grad_output.device)
+
+ return grad_output, grad_loss
+
+class MoEGate(nn.Module):
+
+ def __init__(self, embed_dim, num_experts=16, num_experts_per_tok=2, aux_loss_alpha=0.01, device = None, dtype = None):
+
+ super().__init__()
+ self.top_k = num_experts_per_tok
+ self.n_routed_experts = num_experts
+
+ self.alpha = aux_loss_alpha
+
+ self.gating_dim = embed_dim
+ self.weight = nn.Parameter(torch.empty((self.n_routed_experts, self.gating_dim), device = device, dtype = dtype))
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+
+ # flatten hidden states
+ hidden_states = hidden_states.view(-1, hidden_states.size(-1))
+
+        # get logits and pass them through a softmax
+ logits = F.linear(hidden_states, comfy.model_management.cast_to(self.weight, dtype=hidden_states.dtype, device=hidden_states.device), bias = None)
+ scores = logits.softmax(dim = -1)
+
+ topk_weight, topk_idx = torch.topk(scores, k = self.top_k, dim = -1, sorted = False)
+
+ if self.training and self.alpha > 0.0:
+ scores_for_aux = scores
+
+            # use bincount instead of one-hot encoding
+ counts = torch.bincount(topk_idx.view(-1), minlength = self.n_routed_experts).float()
+ ce = counts / topk_idx.numel() # normalized expert usage
+
+ # mean expert score
+ Pi = scores_for_aux.mean(0)
+
+ # expert balance loss
+ aux_loss = (Pi * ce * self.n_routed_experts).sum() * self.alpha
+ else:
+ aux_loss = None
+
+ return topk_idx, topk_weight, aux_loss
+
+class MoEBlock(nn.Module):
+ def __init__(self, dim, num_experts: int = 6, moe_top_k: int = 2, dropout: float = 0.0,
+ ff_inner_dim: int = None, operations = None, device = None, dtype = None):
+ super().__init__()
+
+ self.moe_top_k = moe_top_k
+ self.num_experts = num_experts
+
+ self.experts = nn.ModuleList([
+ FeedForward(dim, dropout = dropout, inner_dim = ff_inner_dim, operations = operations, device = device, dtype = dtype)
+ for _ in range(num_experts)
+ ])
+
+ self.gate = MoEGate(dim, num_experts = num_experts, num_experts_per_tok = moe_top_k, device = device, dtype = dtype)
+ self.shared_experts = FeedForward(dim, dropout = dropout, inner_dim = ff_inner_dim, operations = operations, device = device, dtype = dtype)
+
+ def forward(self, hidden_states) -> torch.Tensor:
+
+ identity = hidden_states
+ orig_shape = hidden_states.shape
+ topk_idx, topk_weight, aux_loss = self.gate(hidden_states)
+
+ hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
+ flat_topk_idx = topk_idx.view(-1)
+
+ if self.training:
+
+ hidden_states = hidden_states.repeat_interleave(self.moe_top_k, dim = 0)
+ y = torch.empty_like(hidden_states, dtype = hidden_states.dtype)
+
+ for i, expert in enumerate(self.experts):
+ tmp = expert(hidden_states[flat_topk_idx == i])
+ y[flat_topk_idx == i] = tmp.to(hidden_states.dtype)
+
+ y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim = 1)
+ y = y.view(*orig_shape)
+
+ y = AddAuxLoss.apply(y, aux_loss)
+ else:
+ y = self.moe_infer(hidden_states, flat_expert_indices = flat_topk_idx,flat_expert_weights = topk_weight.view(-1, 1)).view(*orig_shape)
+
+ y = y + self.shared_experts(identity)
+
+ return y
+
+ @torch.no_grad()
+ def moe_infer(self, x, flat_expert_indices, flat_expert_weights):
+
+ expert_cache = torch.zeros_like(x)
+ idxs = flat_expert_indices.argsort()
+
+        # no need for a .cpu().numpy() round trip here
+ tokens_per_expert = flat_expert_indices.bincount().cumsum(0)
+ token_idxs = idxs // self.moe_top_k
+
+ for i, end_idx in enumerate(tokens_per_expert):
+
+ start_idx = 0 if i == 0 else tokens_per_expert[i-1]
+
+ if start_idx == end_idx:
+ continue
+
+ expert = self.experts[i]
+ exp_token_idx = token_idxs[start_idx:end_idx]
+
+ expert_tokens = x[exp_token_idx]
+ expert_out = expert(expert_tokens)
+
+ expert_out.mul_(flat_expert_weights[idxs[start_idx:end_idx]])
+
+            # using index_add_ with a 1-D index tensor avoids building a large [N, D] index map and the extra memcopy required by scatter_reduce_,
+            # and it also avoids a dtype conversion
+ expert_cache.index_add_(0, exp_token_idx, expert_out)
+
+ return expert_cache
+
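A tiny standalone illustration (not part of the patch) of the index_add_ pattern used in moe_infer above: each expert's weighted outputs are accumulated back into their token rows through a 1-D index.

```python
import torch

cache = torch.zeros(4, 3)                   # 4 tokens, hidden size 3
token_idx = torch.tensor([0, 2, 2])         # rows handled by one expert (token 2 routed twice)
expert_out = torch.ones(3, 3)               # that expert's weighted outputs
cache.index_add_(0, token_idx, expert_out)  # row 2 accumulates both contributions
```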
+class Timesteps(nn.Module):
+ def __init__(self, num_channels: int, downscale_freq_shift: float = 0.0,
+ scale: float = 1.0, max_period: int = 10000):
+ super().__init__()
+
+ self.num_channels = num_channels
+ half_dim = num_channels // 2
+
+ # precompute the “inv_freq” vector once
+ exponent = -math.log(max_period) * torch.arange(
+ half_dim, dtype=torch.float32
+ ) / (half_dim - downscale_freq_shift)
+
+ inv_freq = torch.exp(exponent)
+
+ # pad
+ if num_channels % 2 == 1:
+ # we’ll pad a zero at the end of the cos-half
+ inv_freq = torch.cat([inv_freq, inv_freq.new_zeros(1)])
+
+ # register to buffer so it moves with the device
+ self.register_buffer("inv_freq", inv_freq, persistent = False)
+ self.scale = scale
+
+ def forward(self, timesteps: torch.Tensor):
+
+ x = timesteps.float().unsqueeze(1) * self.inv_freq.to(timesteps.device).unsqueeze(0)
+
+
+ # fused CUDA kernels for sin and cos
+ sin_emb = x.sin()
+ cos_emb = x.cos()
+
+ emb = torch.cat([sin_emb, cos_emb], dim = 1)
+
+ # scale factor
+ if self.scale != 1.0:
+ emb = emb * self.scale
+
+ # If we padded inv_freq for odd, emb is already wide enough; otherwise:
+ if emb.shape[1] > self.num_channels:
+ emb = emb[:, :self.num_channels]
+
+ return emb
+
+class TimestepEmbedder(nn.Module):
+ def __init__(self, hidden_size, frequency_embedding_size = 256, cond_proj_dim = None, operations = None, device = None, dtype = None):
+ super().__init__()
+
+ self.mlp = nn.Sequential(
+ operations.Linear(hidden_size, frequency_embedding_size, bias=True, device = device, dtype = dtype),
+ nn.GELU(),
+ operations.Linear(frequency_embedding_size, hidden_size, bias=True, device = device, dtype = dtype),
+ )
+ self.frequency_embedding_size = frequency_embedding_size
+
+ if cond_proj_dim is not None:
+ self.cond_proj = operations.Linear(cond_proj_dim, frequency_embedding_size, bias=False, device = device, dtype = dtype)
+
+ self.time_embed = Timesteps(hidden_size)
+
+ def forward(self, timesteps, condition):
+
+ timestep_embed = self.time_embed(timesteps).type(self.mlp[0].weight.dtype)
+
+ if condition is not None:
+ cond_embed = self.cond_proj(condition)
+ timestep_embed = timestep_embed + cond_embed
+
+ time_conditioned = self.mlp(timestep_embed)
+
+ # for broadcasting with image tokens
+ return time_conditioned.unsqueeze(1)
+
+class MLP(nn.Module):
+ def __init__(self, *, width: int, operations = None, device = None, dtype = None):
+ super().__init__()
+ self.width = width
+ self.fc1 = operations.Linear(width, width * 4, device = device, dtype = dtype)
+ self.fc2 = operations.Linear(width * 4, width, device = device, dtype = dtype)
+ self.gelu = nn.GELU()
+
+ def forward(self, x):
+ return self.fc2(self.gelu(self.fc1(x)))
+
+class CrossAttention(nn.Module):
+ def __init__(
+ self,
+ qdim,
+ kdim,
+ num_heads,
+ qkv_bias=True,
+ qk_norm=False,
+ norm_layer=nn.LayerNorm,
+ use_fp16: bool = False,
+ operations = None,
+ dtype = None,
+ device = None,
+ **kwargs,
+ ):
+ super().__init__()
+ self.qdim = qdim
+ self.kdim = kdim
+
+ self.num_heads = num_heads
+ self.head_dim = self.qdim // num_heads
+
+ self.scale = self.head_dim ** -0.5
+
+ self.to_q = operations.Linear(qdim, qdim, bias=qkv_bias, device = device, dtype = dtype)
+ self.to_k = operations.Linear(kdim, qdim, bias=qkv_bias, device = device, dtype = dtype)
+ self.to_v = operations.Linear(kdim, qdim, bias=qkv_bias, device = device, dtype = dtype)
+
+ if use_fp16:
+ eps = 1.0 / 65504
+ else:
+ eps = 1e-6
+
+ if norm_layer == nn.LayerNorm:
+ norm_layer = operations.LayerNorm
+ else:
+ norm_layer = operations.RMSNorm
+
+ self.q_norm = norm_layer(self.head_dim, elementwise_affine=True, eps = eps, device = device, dtype = dtype) if qk_norm else nn.Identity()
+ self.k_norm = norm_layer(self.head_dim, elementwise_affine=True, eps = eps, device = device, dtype = dtype) if qk_norm else nn.Identity()
+ self.out_proj = operations.Linear(qdim, qdim, bias=True, device = device, dtype = dtype)
+
+ def forward(self, x, y):
+
+ b, s1, _ = x.shape
+ _, s2, _ = y.shape
+
+ y = y.to(next(self.to_k.parameters()).dtype)
+
+ q = self.to_q(x)
+ k = self.to_k(y)
+ v = self.to_v(y)
+
+ kv = torch.cat((k, v), dim=-1)
+ split_size = kv.shape[-1] // self.num_heads // 2
+
+ kv = kv.view(1, -1, self.num_heads, split_size * 2)
+ k, v = torch.split(kv, split_size, dim=-1)
+
+ q = q.view(b, s1, self.num_heads, self.head_dim)
+ k = k.view(b, s2, self.num_heads, self.head_dim)
+ v = v.reshape(b, s2, self.num_heads * self.head_dim)
+
+ q = self.q_norm(q)
+ k = self.k_norm(k)
+
+ x = optimized_attention(
+ q.reshape(b, s1, self.num_heads * self.head_dim),
+ k.reshape(b, s2, self.num_heads * self.head_dim),
+ v,
+ heads=self.num_heads,
+ )
+
+ out = self.out_proj(x)
+
+ return out
+
+class Attention(nn.Module):
+
+ def __init__(
+ self,
+ dim,
+ num_heads,
+ qkv_bias = True,
+ qk_norm = False,
+ norm_layer = nn.LayerNorm,
+ use_fp16: bool = False,
+ operations = None,
+ device = None,
+ dtype = None
+ ):
+ super().__init__()
+ self.dim = dim
+ self.num_heads = num_heads
+ self.head_dim = self.dim // num_heads
+ self.scale = self.head_dim ** -0.5
+
+ self.to_q = operations.Linear(dim, dim, bias = qkv_bias, device = device, dtype = dtype)
+ self.to_k = operations.Linear(dim, dim, bias = qkv_bias, device = device, dtype = dtype)
+ self.to_v = operations.Linear(dim, dim, bias = qkv_bias, device = device, dtype = dtype)
+
+ if use_fp16:
+ eps = 1.0 / 65504
+ else:
+ eps = 1e-6
+
+ if norm_layer == nn.LayerNorm:
+ norm_layer = operations.LayerNorm
+ else:
+ norm_layer = operations.RMSNorm
+
+ self.q_norm = norm_layer(self.head_dim, elementwise_affine=True, eps = eps, device = device, dtype = dtype) if qk_norm else nn.Identity()
+ self.k_norm = norm_layer(self.head_dim, elementwise_affine=True, eps = eps, device = device, dtype = dtype) if qk_norm else nn.Identity()
+ self.out_proj = operations.Linear(dim, dim, device = device, dtype = dtype)
+
+ def forward(self, x):
+ B, N, _ = x.shape
+
+ query = self.to_q(x)
+ key = self.to_k(x)
+ value = self.to_v(x)
+
+ qkv_combined = torch.cat((query, key, value), dim=-1)
+ split_size = qkv_combined.shape[-1] // self.num_heads // 3
+
+ qkv = qkv_combined.view(1, -1, self.num_heads, split_size * 3)
+ query, key, value = torch.split(qkv, split_size, dim=-1)
+
+ query = query.reshape(B, N, self.num_heads, self.head_dim)
+ key = key.reshape(B, N, self.num_heads, self.head_dim)
+ value = value.reshape(B, N, self.num_heads * self.head_dim)
+
+ query = self.q_norm(query)
+ key = self.k_norm(key)
+
+ x = optimized_attention(
+ query.reshape(B, N, self.num_heads * self.head_dim),
+ key.reshape(B, N, self.num_heads * self.head_dim),
+ value,
+ heads=self.num_heads,
+ )
+
+ x = self.out_proj(x)
+ return x
+
+class HunYuanDiTBlock(nn.Module):
+ def __init__(
+ self,
+ hidden_size,
+ c_emb_size,
+ num_heads,
+ text_states_dim=1024,
+ qk_norm=False,
+ norm_layer=nn.LayerNorm,
+ qk_norm_layer=True,
+ qkv_bias=True,
+ skip_connection=True,
+ timested_modulate=False,
+ use_moe: bool = False,
+ num_experts: int = 8,
+ moe_top_k: int = 2,
+ use_fp16: bool = False,
+ operations = None,
+ device = None, dtype = None
+ ):
+ super().__init__()
+
+ # eps can't be 1e-6 in fp16 mode because of numerical stability issues
+ if use_fp16:
+ eps = 1.0 / 65504
+ else:
+ eps = 1e-6
+
+ self.norm1 = norm_layer(hidden_size, elementwise_affine = True, eps = eps, device = device, dtype = dtype)
+
+ self.attn1 = Attention(hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, qk_norm=qk_norm,
+ norm_layer=qk_norm_layer, use_fp16 = use_fp16, device = device, dtype = dtype, operations = operations)
+
+ self.norm2 = norm_layer(hidden_size, elementwise_affine = True, eps = eps, device = device, dtype = dtype)
+
+ self.timested_modulate = timested_modulate
+ if self.timested_modulate:
+ self.default_modulation = nn.Sequential(
+ nn.SiLU(),
+ operations.Linear(c_emb_size, hidden_size, bias=True, device = device, dtype = dtype)
+ )
+
+ self.attn2 = CrossAttention(hidden_size, text_states_dim, num_heads=num_heads, qkv_bias=qkv_bias,
+ qk_norm=qk_norm, norm_layer=qk_norm_layer, use_fp16 = use_fp16,
+ device = device, dtype = dtype, operations = operations)
+
+ self.norm3 = norm_layer(hidden_size, elementwise_affine = True, eps = eps, device = device, dtype = dtype)
+
+ if skip_connection:
+ self.skip_norm = norm_layer(hidden_size, elementwise_affine = True, eps = eps, device = device, dtype = dtype)
+ self.skip_linear = operations.Linear(2 * hidden_size, hidden_size, device = device, dtype = dtype)
+ else:
+ self.skip_linear = None
+
+ self.use_moe = use_moe
+
+ if self.use_moe:
+ self.moe = MoEBlock(
+ hidden_size,
+ num_experts = num_experts,
+ moe_top_k = moe_top_k,
+ dropout = 0.0,
+ ff_inner_dim = int(hidden_size * 4.0),
+ device = device, dtype = dtype,
+ operations = operations
+ )
+ else:
+ self.mlp = MLP(width=hidden_size, operations=operations, device = device, dtype = dtype)
+
+ def forward(self, hidden_states, conditioning=None, text_states=None, skip_tensor=None):
+
+ if self.skip_linear is not None:
+ combined = torch.cat([skip_tensor, hidden_states], dim=-1)
+ hidden_states = self.skip_linear(combined)
+ hidden_states = self.skip_norm(hidden_states)
+
+ # self attention
+ if self.timested_modulate:
+ modulation_shift = self.default_modulation(conditioning).unsqueeze(dim=1)
+ hidden_states = hidden_states + modulation_shift
+
+ self_attn_out = self.attn1(self.norm1(hidden_states))
+ hidden_states = hidden_states + self_attn_out
+
+ # cross attention
+ hidden_states = hidden_states + self.attn2(self.norm2(hidden_states), text_states)
+
+ # MLP Layer
+ mlp_input = self.norm3(hidden_states)
+
+ if self.use_moe:
+ hidden_states = hidden_states + self.moe(mlp_input)
+ else:
+ hidden_states = hidden_states + self.mlp(mlp_input)
+
+ return hidden_states
+
+class FinalLayer(nn.Module):
+
+ def __init__(self, final_hidden_size, out_channels, operations, use_fp16: bool = False, device = None, dtype = None):
+ super().__init__()
+
+ if use_fp16:
+ eps = 1.0 / 65504
+ else:
+ eps = 1e-6
+
+ self.norm_final = operations.LayerNorm(final_hidden_size, elementwise_affine = True, eps = eps, device = device, dtype = dtype)
+ self.linear = operations.Linear(final_hidden_size, out_channels, bias = True, device = device, dtype = dtype)
+
+ def forward(self, x):
+ x = self.norm_final(x)
+ x = x[:, 1:]
+ x = self.linear(x)
+ return x
+
+class HunYuanDiTPlain(nn.Module):
+
+    # initialized with the default values from https://huggingface.co/tencent/Hunyuan3D-2.1/blob/main/hunyuan3d-dit-v2-1/config.yaml
+ def __init__(
+ self,
+ in_channels: int = 64,
+ hidden_size: int = 2048,
+ context_dim: int = 1024,
+ depth: int = 21,
+ num_heads: int = 16,
+ qk_norm: bool = True,
+ qkv_bias: bool = False,
+ num_moe_layers: int = 6,
+ guidance_cond_proj_dim = 2048,
+ norm_type = 'layer',
+ num_experts: int = 8,
+ moe_top_k: int = 2,
+ use_fp16: bool = False,
+ dtype = None,
+ device = None,
+ operations = None,
+ **kwargs
+ ):
+
+ self.dtype = dtype
+
+ super().__init__()
+
+ self.depth = depth
+
+ self.in_channels = in_channels
+ self.out_channels = in_channels
+
+ self.num_heads = num_heads
+ self.hidden_size = hidden_size
+
+ norm = operations.LayerNorm if norm_type == 'layer' else operations.RMSNorm
+ qk_norm = operations.RMSNorm
+
+ self.context_dim = context_dim
+ self.guidance_cond_proj_dim = guidance_cond_proj_dim
+
+ self.x_embedder = operations.Linear(in_channels, hidden_size, bias = True, device = device, dtype = dtype)
+ self.t_embedder = TimestepEmbedder(hidden_size, hidden_size * 4, cond_proj_dim = guidance_cond_proj_dim, device = device, dtype = dtype, operations = operations)
+
+
+        # HunYuanDiT blocks
+ self.blocks = nn.ModuleList([
+ HunYuanDiTBlock(hidden_size=hidden_size,
+ c_emb_size=hidden_size,
+ num_heads=num_heads,
+ text_states_dim=context_dim,
+ qk_norm=qk_norm,
+ norm_layer = norm,
+ qk_norm_layer = qk_norm,
+ skip_connection=layer > depth // 2,
+ qkv_bias=qkv_bias,
+ use_moe=True if depth - layer <= num_moe_layers else False,
+ num_experts=num_experts,
+ moe_top_k=moe_top_k,
+ use_fp16 = use_fp16,
+ device = device, dtype = dtype, operations = operations)
+ for layer in range(depth)
+ ])
+
+ self.depth = depth
+
+ self.final_layer = FinalLayer(hidden_size, self.out_channels, use_fp16 = use_fp16, operations = operations, device = device, dtype = dtype)
+
+ def forward(self, x, t, context, transformer_options = {}, **kwargs):
+
+ x = x.movedim(-1, -2)
+ uncond_emb, cond_emb = context.chunk(2, dim = 0)
+
+ context = torch.cat([cond_emb, uncond_emb], dim = 0)
+ main_condition = context
+
+ t = 1.0 - t
+
+ time_embedded = self.t_embedder(t, condition = kwargs.get('guidance_cond'))
+
+ x = x.to(dtype = next(self.x_embedder.parameters()).dtype)
+ x_embedded = self.x_embedder(x)
+
+ combined = torch.cat([time_embedded, x_embedded], dim=1)
+
+ def block_wrap(args):
+ return block(
+ args["x"],
+ args["t"],
+ args["cond"],
+ skip_tensor=args.get("skip"),)
+
+ skip_stack = []
+ patches_replace = transformer_options.get("patches_replace", {})
+ blocks_replace = patches_replace.get("dit", {})
+ for idx, block in enumerate(self.blocks):
+ if idx <= self.depth // 2:
+ skip_input = None
+ else:
+ skip_input = skip_stack.pop()
+
+ if ("block", idx) in blocks_replace:
+
+ combined = blocks_replace[("block", idx)](
+ {
+ "x": combined,
+ "t": time_embedded,
+ "cond": main_condition,
+ "skip": skip_input,
+ },
+ {"original_block": block_wrap},
+ )
+ else:
+ combined = block(combined, time_embedded, main_condition, skip_tensor=skip_input)
+
+ if idx < self.depth // 2:
+ skip_stack.append(combined)
+
+ output = self.final_layer(combined)
+ output = output.movedim(-2, -1) * (-1.0)
+
+ cond_emb, uncond_emb = output.chunk(2, dim = 0)
+ return torch.cat([uncond_emb, cond_emb])
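A quick sanity sketch (plain Python, not part of the patch) of the skip-connection schedule in the forward pass above: with the default depth of 21, blocks 0-9 push their outputs, block 10 neither pushes nor pops, and blocks 11-20 each pop one tensor, so the stack is exactly emptied.

```python
depth = 21
stack = []
for idx in range(depth):
    if idx > depth // 2:   # second half: consume a skip tensor
        stack.pop()
    if idx < depth // 2:   # first half: produce a skip tensor
        stack.append(idx)
assert not stack
```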
diff --git a/comfy/ldm/hunyuan_video/model.py b/comfy/ldm/hunyuan_video/model.py
index fbd8d4196..5132e6c07 100644
--- a/comfy/ldm/hunyuan_video/model.py
+++ b/comfy/ldm/hunyuan_video/model.py
@@ -1,6 +1,7 @@
#Based on Flux code because of weird hunyuan video code license.
import torch
+import comfy.patcher_extension
import comfy.ldm.flux.layers
import comfy.ldm.modules.diffusionmodules.mmdit
from comfy.ldm.modules.attention import optimized_attention
@@ -39,6 +40,8 @@ class HunyuanVideoParams:
patch_size: list
qkv_bias: bool
guidance_embed: bool
+ byt5: bool
+ meanflow: bool
class SelfAttentionRef(nn.Module):
@@ -77,13 +80,13 @@ class TokenRefinerBlock(nn.Module):
operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
)
- def forward(self, x, c, mask):
+ def forward(self, x, c, mask, transformer_options={}):
mod1, mod2 = self.adaLN_modulation(c).chunk(2, dim=1)
norm_x = self.norm1(x)
qkv = self.self_attn.qkv(norm_x)
q, k, v = qkv.reshape(qkv.shape[0], qkv.shape[1], 3, self.heads, -1).permute(2, 0, 3, 1, 4)
- attn = optimized_attention(q, k, v, self.heads, mask=mask, skip_reshape=True)
+ attn = optimized_attention(q, k, v, self.heads, mask=mask, skip_reshape=True, transformer_options=transformer_options)
x = x + self.self_attn.proj(attn) * mod1.unsqueeze(1)
x = x + self.mlp(self.norm2(x)) * mod2.unsqueeze(1)
@@ -114,14 +117,14 @@ class IndividualTokenRefiner(nn.Module):
]
)
- def forward(self, x, c, mask):
+ def forward(self, x, c, mask, transformer_options={}):
m = None
if mask is not None:
m = mask.view(mask.shape[0], 1, 1, mask.shape[1]).repeat(1, 1, mask.shape[1], 1)
m = m + m.transpose(2, 3)
for block in self.blocks:
- x = block(x, c, m)
+ x = block(x, c, m, transformer_options=transformer_options)
return x
@@ -149,6 +152,7 @@ class TokenRefiner(nn.Module):
x,
timesteps,
mask,
+ transformer_options={},
):
t = self.t_embedder(timestep_embedding(timesteps, 256, time_factor=1.0).to(x.dtype))
# m = mask.float().unsqueeze(-1)
@@ -157,9 +161,33 @@ class TokenRefiner(nn.Module):
c = t + self.c_embedder(c.to(x.dtype))
x = self.input_embedder(x)
- x = self.individual_token_refiner(x, c, mask)
+ x = self.individual_token_refiner(x, c, mask, transformer_options=transformer_options)
return x
+
+class ByT5Mapper(nn.Module):
+ def __init__(self, in_dim, out_dim, hidden_dim, out_dim1, use_res=False, dtype=None, device=None, operations=None):
+ super().__init__()
+ self.layernorm = operations.LayerNorm(in_dim, dtype=dtype, device=device)
+ self.fc1 = operations.Linear(in_dim, hidden_dim, dtype=dtype, device=device)
+ self.fc2 = operations.Linear(hidden_dim, out_dim, dtype=dtype, device=device)
+ self.fc3 = operations.Linear(out_dim, out_dim1, dtype=dtype, device=device)
+ self.use_res = use_res
+ self.act_fn = nn.GELU()
+
+ def forward(self, x):
+ if self.use_res:
+ res = x
+ x = self.layernorm(x)
+ x = self.fc1(x)
+ x = self.act_fn(x)
+ x = self.fc2(x)
+ x2 = self.act_fn(x)
+ x2 = self.fc3(x2)
+ if self.use_res:
+ x2 = x2 + res
+ return x2
+
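ByT5Mapper is a LayerNorm plus a three-layer GELU MLP that lifts byte-level ByT5 text features into the transformer width, with an optional residual when `use_res=True`. A standalone sketch of the same computation; the 1472/2048 widths come from the patch, while the 3072 output width is only an illustrative stand-in for `hidden_size`:

```python
import torch
import torch.nn as nn

class ToyByT5Mapper(nn.Module):
    """Sketch of ByT5Mapper (use_res=False path): LayerNorm -> fc1 -> GELU -> fc2 -> GELU -> fc3."""
    def __init__(self, in_dim=1472, hidden_dim=2048, out_dim=2048, out_dim1=3072):
        super().__init__()
        self.layernorm = nn.LayerNorm(in_dim)
        self.fc1 = nn.Linear(in_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, out_dim)
        self.fc3 = nn.Linear(out_dim, out_dim1)
        self.act = nn.GELU()

    def forward(self, x):
        x = self.act(self.fc1(self.layernorm(x)))
        x = self.act(self.fc2(x))
        return self.fc3(x)

tokens = torch.randn(1, 77, 1472)            # (batch, byt5 tokens, byt5 feature width)
print(ToyByT5Mapper()(tokens).shape)         # torch.Size([1, 77, 3072])
```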
class HunyuanVideo(nn.Module):
"""
Transformer model for flow matching on sequences.
@@ -184,9 +212,13 @@ class HunyuanVideo(nn.Module):
self.num_heads = params.num_heads
self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
- self.img_in = comfy.ldm.modules.diffusionmodules.mmdit.PatchEmbed(None, self.patch_size, self.in_channels, self.hidden_size, conv3d=True, dtype=dtype, device=device, operations=operations)
+ self.img_in = comfy.ldm.modules.diffusionmodules.mmdit.PatchEmbed(None, self.patch_size, self.in_channels, self.hidden_size, conv3d=len(self.patch_size) == 3, dtype=dtype, device=device, operations=operations)
self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, dtype=dtype, device=device, operations=operations)
- self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size, dtype=dtype, device=device, operations=operations)
+ if params.vec_in_dim is not None:
+ self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size, dtype=dtype, device=device, operations=operations)
+ else:
+ self.vector_in = None
+
self.guidance_in = (
MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, dtype=dtype, device=device, operations=operations) if params.guidance_embed else nn.Identity()
)
@@ -214,6 +246,23 @@ class HunyuanVideo(nn.Module):
]
)
+ if params.byt5:
+ self.byt5_in = ByT5Mapper(
+ in_dim=1472,
+ out_dim=2048,
+ hidden_dim=2048,
+ out_dim1=self.hidden_size,
+ use_res=False,
+ dtype=dtype, device=device, operations=operations
+ )
+ else:
+ self.byt5_in = None
+
+ if params.meanflow:
+ self.time_r_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, dtype=dtype, device=device, operations=operations)
+ else:
+ self.time_r_in = None
+
if final_layer:
self.final_layer = LastLayer(self.hidden_size, self.patch_size[-1], self.out_channels, dtype=dtype, device=device, operations=operations)
@@ -225,10 +274,12 @@ class HunyuanVideo(nn.Module):
txt_ids: Tensor,
txt_mask: Tensor,
timesteps: Tensor,
- y: Tensor,
+ y: Tensor = None,
+ txt_byt5=None,
guidance: Tensor = None,
guiding_frame_index=None,
ref_latent=None,
+ disable_time_r=False,
control=None,
transformer_options={},
) -> Tensor:
@@ -239,6 +290,14 @@ class HunyuanVideo(nn.Module):
img = self.img_in(img)
vec = self.time_in(timestep_embedding(timesteps, 256, time_factor=1.0).to(img.dtype))
+ if (self.time_r_in is not None) and (not disable_time_r):
+ w = torch.where(transformer_options['sigmas'][0] == transformer_options['sample_sigmas'])[0] # This most likely could be improved
+ if len(w) > 0:
+ timesteps_r = transformer_options['sample_sigmas'][w[0] + 1]
+ timesteps_r = timesteps_r.unsqueeze(0).to(device=timesteps.device, dtype=timesteps.dtype)
+ vec_r = self.time_r_in(timestep_embedding(timesteps_r, 256, time_factor=1000.0).to(img.dtype))
+ vec = (vec + vec_r) / 2
+
if ref_latent is not None:
ref_latent_ids = self.img_ids(ref_latent)
ref_latent = self.img_in(ref_latent)
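The meanflow branch above needs the sigma of the *next* sampler step: it locates the current sigma inside `sample_sigmas` and reads the following entry, which then feeds `time_r_in`. A small sketch of that lookup in isolation (it assumes the current sigma appears exactly once in the schedule):

```python
import torch

def next_sigma(current_sigma: torch.Tensor, sample_sigmas: torch.Tensor):
    """Return the sigma the sampler will use on the next step, or None at the last step."""
    idx = torch.where(current_sigma == sample_sigmas)[0]
    if len(idx) == 0 or idx[0] + 1 >= len(sample_sigmas):
        return None
    return sample_sigmas[idx[0] + 1]

sigmas = torch.tensor([1.0, 0.75, 0.5, 0.25, 0.0])
print(next_sigma(torch.tensor(0.75), sigmas))  # tensor(0.5000)
print(next_sigma(torch.tensor(0.0), sigmas))   # None (last step)
```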
@@ -249,13 +308,17 @@ class HunyuanVideo(nn.Module):
if guiding_frame_index is not None:
token_replace_vec = self.time_in(timestep_embedding(guiding_frame_index, 256, time_factor=1.0))
- vec_ = self.vector_in(y[:, :self.params.vec_in_dim])
- vec = torch.cat([(vec_ + token_replace_vec).unsqueeze(1), (vec_ + vec).unsqueeze(1)], dim=1)
+ if self.vector_in is not None:
+ vec_ = self.vector_in(y[:, :self.params.vec_in_dim])
+ vec = torch.cat([(vec_ + token_replace_vec).unsqueeze(1), (vec_ + vec).unsqueeze(1)], dim=1)
+ else:
+ vec = torch.cat([(token_replace_vec).unsqueeze(1), (vec).unsqueeze(1)], dim=1)
frame_tokens = (initial_shape[-1] // self.patch_size[-1]) * (initial_shape[-2] // self.patch_size[-2])
modulation_dims = [(0, frame_tokens, 0), (frame_tokens, None, 1)]
modulation_dims_txt = [(0, None, 1)]
else:
- vec = vec + self.vector_in(y[:, :self.params.vec_in_dim])
+ if self.vector_in is not None:
+ vec = vec + self.vector_in(y[:, :self.params.vec_in_dim])
modulation_dims = None
modulation_dims_txt = None
@@ -266,7 +329,13 @@ class HunyuanVideo(nn.Module):
if txt_mask is not None and not torch.is_floating_point(txt_mask):
txt_mask = (txt_mask - 1).to(img.dtype) * torch.finfo(img.dtype).max
- txt = self.txt_in(txt, timesteps, txt_mask)
+ txt = self.txt_in(txt, timesteps, txt_mask, transformer_options=transformer_options)
+
+ if self.byt5_in is not None and txt_byt5 is not None:
+ txt_byt5 = self.byt5_in(txt_byt5)
+ txt_byt5_ids = torch.zeros((txt_ids.shape[0], txt_byt5.shape[1], txt_ids.shape[-1]), device=txt_ids.device, dtype=txt_ids.dtype)
+ txt = torch.cat((txt, txt_byt5), dim=1)
+ txt_ids = torch.cat((txt_ids, txt_byt5_ids), dim=1)
ids = torch.cat((img_ids, txt_ids), dim=1)
pe = self.pe_embedder(ids)
@@ -284,14 +353,14 @@ class HunyuanVideo(nn.Module):
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
- out["img"], out["txt"] = block(img=args["img"], txt=args["txt"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"], modulation_dims_img=args["modulation_dims_img"], modulation_dims_txt=args["modulation_dims_txt"])
+ out["img"], out["txt"] = block(img=args["img"], txt=args["txt"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"], modulation_dims_img=args["modulation_dims_img"], modulation_dims_txt=args["modulation_dims_txt"], transformer_options=args["transformer_options"])
return out
- out = blocks_replace[("double_block", i)]({"img": img, "txt": txt, "vec": vec, "pe": pe, "attention_mask": attn_mask, 'modulation_dims_img': modulation_dims, 'modulation_dims_txt': modulation_dims_txt}, {"original_block": block_wrap})
+ out = blocks_replace[("double_block", i)]({"img": img, "txt": txt, "vec": vec, "pe": pe, "attention_mask": attn_mask, 'modulation_dims_img': modulation_dims, 'modulation_dims_txt': modulation_dims_txt, 'transformer_options': transformer_options}, {"original_block": block_wrap})
txt = out["txt"]
img = out["img"]
else:
- img, txt = block(img=img, txt=txt, vec=vec, pe=pe, attn_mask=attn_mask, modulation_dims_img=modulation_dims, modulation_dims_txt=modulation_dims_txt)
+ img, txt = block(img=img, txt=txt, vec=vec, pe=pe, attn_mask=attn_mask, modulation_dims_img=modulation_dims, modulation_dims_txt=modulation_dims_txt, transformer_options=transformer_options)
if control is not None: # Controlnet
control_i = control.get("input")
@@ -306,13 +375,13 @@ class HunyuanVideo(nn.Module):
if ("single_block", i) in blocks_replace:
def block_wrap(args):
out = {}
- out["img"] = block(args["img"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"], modulation_dims=args["modulation_dims"])
+ out["img"] = block(args["img"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"], modulation_dims=args["modulation_dims"], transformer_options=args["transformer_options"])
return out
- out = blocks_replace[("single_block", i)]({"img": img, "vec": vec, "pe": pe, "attention_mask": attn_mask, 'modulation_dims': modulation_dims}, {"original_block": block_wrap})
+ out = blocks_replace[("single_block", i)]({"img": img, "vec": vec, "pe": pe, "attention_mask": attn_mask, 'modulation_dims': modulation_dims, 'transformer_options': transformer_options}, {"original_block": block_wrap})
img = out["img"]
else:
- img = block(img, vec=vec, pe=pe, attn_mask=attn_mask, modulation_dims=modulation_dims)
+ img = block(img, vec=vec, pe=pe, attn_mask=attn_mask, modulation_dims=modulation_dims, transformer_options=transformer_options)
if control is not None: # Controlnet
control_o = control.get("output")
@@ -327,12 +396,16 @@ class HunyuanVideo(nn.Module):
img = self.final_layer(img, vec, modulation_dims=modulation_dims) # (N, T, patch_size ** 2 * out_channels)
- shape = initial_shape[-3:]
+ shape = initial_shape[-len(self.patch_size):]
for i in range(len(shape)):
shape[i] = shape[i] // self.patch_size[i]
img = img.reshape([img.shape[0]] + shape + [self.out_channels] + self.patch_size)
- img = img.permute(0, 4, 1, 5, 2, 6, 3, 7)
- img = img.reshape(initial_shape[0], self.out_channels, initial_shape[2], initial_shape[3], initial_shape[4])
+ if img.ndim == 8:
+ img = img.permute(0, 4, 1, 5, 2, 6, 3, 7)
+ img = img.reshape(initial_shape[0], self.out_channels, initial_shape[2], initial_shape[3], initial_shape[4])
+ else:
+ img = img.permute(0, 3, 1, 4, 2, 5)
+ img = img.reshape(initial_shape[0], self.out_channels, initial_shape[2], initial_shape[3])
return img
def img_ids(self, x):
@@ -347,9 +420,30 @@ class HunyuanVideo(nn.Module):
img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).reshape(1, 1, -1)
return repeat(img_ids, "t h w c -> b (t h w) c", b=bs)
- def forward(self, x, timestep, context, y, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, control=None, transformer_options={}, **kwargs):
- bs, c, t, h, w = x.shape
- img_ids = self.img_ids(x)
- txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
- out = self.forward_orig(x, img_ids, context, txt_ids, attention_mask, timestep, y, guidance, guiding_frame_index, ref_latent, control=control, transformer_options=transformer_options)
+ def img_ids_2d(self, x):
+ bs, c, h, w = x.shape
+ patch_size = self.patch_size
+ h_len = ((h + (patch_size[0] // 2)) // patch_size[0])
+ w_len = ((w + (patch_size[1] // 2)) // patch_size[1])
+ img_ids = torch.zeros((h_len, w_len, 2), device=x.device, dtype=x.dtype)
+ img_ids[:, :, 0] = img_ids[:, :, 0] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1)
+ img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
+ return repeat(img_ids, "h w c -> b (h w) c", b=bs)
+
+ def forward(self, x, timestep, context, y=None, txt_byt5=None, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, disable_time_r=False, control=None, transformer_options={}, **kwargs):
+ return comfy.patcher_extension.WrapperExecutor.new_class_executor(
+ self._forward,
+ self,
+ comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
+ ).execute(x, timestep, context, y, txt_byt5, guidance, attention_mask, guiding_frame_index, ref_latent, disable_time_r, control, transformer_options, **kwargs)
+
+ def _forward(self, x, timestep, context, y=None, txt_byt5=None, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, disable_time_r=False, control=None, transformer_options={}, **kwargs):
+ bs = x.shape[0]
+ if len(self.patch_size) == 3:
+ img_ids = self.img_ids(x)
+ txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
+ else:
+ img_ids = self.img_ids_2d(x)
+ txt_ids = torch.zeros((bs, context.shape[1], 2), device=x.device, dtype=x.dtype)
+ out = self.forward_orig(x, img_ids, context, txt_ids, attention_mask, timestep, y, txt_byt5, guidance, guiding_frame_index, ref_latent, disable_time_r=disable_time_r, control=control, transformer_options=transformer_options)
return out
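With a 2-element `patch_size` the model now runs in pure image mode: `img_ids_2d` builds one (row, col) positional id per patch, with the `+ patch // 2` term accounting for inputs that get padded up to a multiple of the patch size, and text ids get 2 axes instead of 3. A tiny sketch of the id grid:

```python
import torch

def make_img_ids_2d(h, w, patch_size=(2, 2)):
    """Toy version of img_ids_2d: one (row, col) id per patch, flattened to (h_len * w_len, 2)."""
    h_len = (h + patch_size[0] // 2) // patch_size[0]
    w_len = (w + patch_size[1] // 2) // patch_size[1]
    ids = torch.zeros(h_len, w_len, 2)
    ids[:, :, 0] = torch.arange(h_len).unsqueeze(1)  # patch row index
    ids[:, :, 1] = torch.arange(w_len).unsqueeze(0)  # patch column index
    return ids.reshape(-1, 2)

print(make_img_ids_2d(8, 8).shape)  # torch.Size([16, 2]) for a 4x4 grid of patches
```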
diff --git a/comfy/ldm/hunyuan_video/vae.py b/comfy/ldm/hunyuan_video/vae.py
new file mode 100644
index 000000000..40c12b183
--- /dev/null
+++ b/comfy/ldm/hunyuan_video/vae.py
@@ -0,0 +1,136 @@
+import torch.nn as nn
+import torch.nn.functional as F
+from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, AttnBlock
+import comfy.ops
+ops = comfy.ops.disable_weight_init
+
+
+class PixelShuffle2D(nn.Module):
+ def __init__(self, in_dim, out_dim, op=ops.Conv2d):
+ super().__init__()
+ self.conv = op(in_dim, out_dim >> 2, 3, 1, 1)
+ self.ratio = (in_dim << 2) // out_dim
+
+ def forward(self, x):
+ b, c, h, w = x.shape
+ h2, w2 = h >> 1, w >> 1
+ y = self.conv(x).view(b, -1, h2, 2, w2, 2).permute(0, 3, 5, 1, 2, 4).reshape(b, -1, h2, w2)
+ r = x.view(b, c, h2, 2, w2, 2).permute(0, 3, 5, 1, 2, 4).reshape(b, c << 2, h2, w2)
+ return y + r.view(b, y.shape[1], self.ratio, h2, w2).mean(2)
+
+
+class PixelUnshuffle2D(nn.Module):
+ def __init__(self, in_dim, out_dim, op=ops.Conv2d):
+ super().__init__()
+ self.conv = op(in_dim, out_dim << 2, 3, 1, 1)
+ self.scale = (out_dim << 2) // in_dim
+
+ def forward(self, x):
+ b, c, h, w = x.shape
+ h2, w2 = h << 1, w << 1
+ y = self.conv(x).view(b, 2, 2, -1, h, w).permute(0, 3, 4, 1, 5, 2).reshape(b, -1, h2, w2)
+ r = x.repeat_interleave(self.scale, 1).view(b, 2, 2, -1, h, w).permute(0, 3, 4, 1, 5, 2).reshape(b, -1, h2, w2)
+ return y + r
+
+
+class Encoder(nn.Module):
+ def __init__(self, in_channels, z_channels, block_out_channels, num_res_blocks,
+ ffactor_spatial, downsample_match_channel=True, **_):
+ super().__init__()
+ self.z_channels = z_channels
+ self.block_out_channels = block_out_channels
+ self.num_res_blocks = num_res_blocks
+ self.conv_in = ops.Conv2d(in_channels, block_out_channels[0], 3, 1, 1)
+
+ self.down = nn.ModuleList()
+ ch = block_out_channels[0]
+ depth = (ffactor_spatial >> 1).bit_length()
+
+ for i, tgt in enumerate(block_out_channels):
+ stage = nn.Module()
+ stage.block = nn.ModuleList([ResnetBlock(in_channels=ch if j == 0 else tgt,
+ out_channels=tgt,
+ temb_channels=0,
+ conv_op=ops.Conv2d)
+ for j in range(num_res_blocks)])
+ ch = tgt
+ if i < depth:
+ nxt = block_out_channels[i + 1] if i + 1 < len(block_out_channels) and downsample_match_channel else ch
+ stage.downsample = PixelShuffle2D(ch, nxt, ops.Conv2d)
+ ch = nxt
+ self.down.append(stage)
+
+ self.mid = nn.Module()
+ self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=ops.Conv2d)
+ self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv2d)
+ self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=ops.Conv2d)
+
+ self.norm_out = ops.GroupNorm(32, ch, 1e-6, True)
+ self.conv_out = ops.Conv2d(ch, z_channels << 1, 3, 1, 1)
+
+ def forward(self, x):
+ x = self.conv_in(x)
+
+ for stage in self.down:
+ for blk in stage.block:
+ x = blk(x)
+ if hasattr(stage, 'downsample'):
+ x = stage.downsample(x)
+
+ x = self.mid.block_2(self.mid.attn_1(self.mid.block_1(x)))
+
+ b, c, h, w = x.shape
+ grp = c // (self.z_channels << 1)
+ skip = x.view(b, c // grp, grp, h, w).mean(2)
+
+ return self.conv_out(F.silu(self.norm_out(x))) + skip
+
+
+class Decoder(nn.Module):
+ def __init__(self, z_channels, out_channels, block_out_channels, num_res_blocks,
+ ffactor_spatial, upsample_match_channel=True, **_):
+ super().__init__()
+ block_out_channels = block_out_channels[::-1]
+ self.z_channels = z_channels
+ self.block_out_channels = block_out_channels
+ self.num_res_blocks = num_res_blocks
+
+ ch = block_out_channels[0]
+ self.conv_in = ops.Conv2d(z_channels, ch, 3, 1, 1)
+
+ self.mid = nn.Module()
+ self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=ops.Conv2d)
+ self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv2d)
+ self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=ops.Conv2d)
+
+ self.up = nn.ModuleList()
+ depth = (ffactor_spatial >> 1).bit_length()
+
+ for i, tgt in enumerate(block_out_channels):
+ stage = nn.Module()
+ stage.block = nn.ModuleList([ResnetBlock(in_channels=ch if j == 0 else tgt,
+ out_channels=tgt,
+ temb_channels=0,
+ conv_op=ops.Conv2d)
+ for j in range(num_res_blocks + 1)])
+ ch = tgt
+ if i < depth:
+ nxt = block_out_channels[i + 1] if i + 1 < len(block_out_channels) and upsample_match_channel else ch
+ stage.upsample = PixelUnshuffle2D(ch, nxt, ops.Conv2d)
+ ch = nxt
+ self.up.append(stage)
+
+ self.norm_out = ops.GroupNorm(32, ch, 1e-6, True)
+ self.conv_out = ops.Conv2d(ch, out_channels, 3, 1, 1)
+
+ def forward(self, z):
+ x = self.conv_in(z) + z.repeat_interleave(self.block_out_channels[0] // self.z_channels, 1)
+ x = self.mid.block_2(self.mid.attn_1(self.mid.block_1(x)))
+
+ for stage in self.up:
+ for blk in stage.block:
+ x = blk(x)
+ if hasattr(stage, 'upsample'):
+ x = stage.upsample(x)
+
+ return self.conv_out(F.silu(self.norm_out(x)))
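PixelShuffle2D halves the resolution with a conv to `out_dim // 4` channels followed by a space-to-depth rearrangement, and adds a parameter-free shortcut: the input is space-to-depth'd as well and groups of its channels are averaged down to `out_dim`. A minimal re-derivation with toy sizes (requires `out_dim % 4 == 0` and `4 * in_dim % out_dim == 0`):

```python
import torch
import torch.nn as nn

def space_to_depth(x):
    """(b, c, h, w) -> (b, 4c, h/2, w/2), same .view/.permute order as PixelShuffle2D."""
    b, c, h, w = x.shape
    return (x.view(b, c, h // 2, 2, w // 2, 2)
             .permute(0, 3, 5, 1, 2, 4)
             .reshape(b, 4 * c, h // 2, w // 2))

class ToyPixelShuffleDown(nn.Module):
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.conv = nn.Conv2d(in_dim, out_dim // 4, 3, 1, 1)
        self.ratio = 4 * in_dim // out_dim  # shortcut channels averaged into each output channel

    def forward(self, x):
        y = space_to_depth(self.conv(x))    # learned path: (b, out_dim, h/2, w/2)
        r = space_to_depth(x)               # shortcut:     (b, 4*in_dim, h/2, w/2)
        r = r.view(r.shape[0], y.shape[1], self.ratio, *y.shape[2:]).mean(2)
        return y + r                        # residual downsample

down = ToyPixelShuffleDown(64, 128)
print(down(torch.randn(1, 64, 32, 32)).shape)  # torch.Size([1, 128, 16, 16])
```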
diff --git a/comfy/ldm/hunyuan_video/vae_refiner.py b/comfy/ldm/hunyuan_video/vae_refiner.py
new file mode 100644
index 000000000..c6f742710
--- /dev/null
+++ b/comfy/ldm/hunyuan_video/vae_refiner.py
@@ -0,0 +1,267 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, AttnBlock, VideoConv3d
+import comfy.ops
+import comfy.ldm.models.autoencoder
+ops = comfy.ops.disable_weight_init
+
+class RMS_norm(nn.Module):
+ def __init__(self, dim):
+ super().__init__()
+ shape = (dim, 1, 1, 1)
+ self.scale = dim**0.5
+ self.gamma = nn.Parameter(torch.empty(shape))
+
+ def forward(self, x):
+ return F.normalize(x, dim=1) * self.scale * self.gamma
+
+class DnSmpl(nn.Module):
+ def __init__(self, ic, oc, tds=True):
+ super().__init__()
+ fct = 2 * 2 * 2 if tds else 1 * 2 * 2
+ assert oc % fct == 0
+ self.conv = VideoConv3d(ic, oc // fct, kernel_size=3)
+
+ self.tds = tds
+ self.gs = fct * ic // oc
+
+ def forward(self, x):
+ r1 = 2 if self.tds else 1
+ h = self.conv(x)
+
+ if self.tds:
+ hf = h[:, :, :1, :, :]
+ b, c, f, ht, wd = hf.shape
+ hf = hf.reshape(b, c, f, ht // 2, 2, wd // 2, 2)
+ hf = hf.permute(0, 4, 6, 1, 2, 3, 5)
+ hf = hf.reshape(b, 2 * 2 * c, f, ht // 2, wd // 2)
+ hf = torch.cat([hf, hf], dim=1)
+
+ hn = h[:, :, 1:, :, :]
+ b, c, frms, ht, wd = hn.shape
+ nf = frms // r1
+ hn = hn.reshape(b, c, nf, r1, ht // 2, 2, wd // 2, 2)
+ hn = hn.permute(0, 3, 5, 7, 1, 2, 4, 6)
+ hn = hn.reshape(b, r1 * 2 * 2 * c, nf, ht // 2, wd // 2)
+
+ h = torch.cat([hf, hn], dim=2)
+
+ xf = x[:, :, :1, :, :]
+ b, ci, f, ht, wd = xf.shape
+ xf = xf.reshape(b, ci, f, ht // 2, 2, wd // 2, 2)
+ xf = xf.permute(0, 4, 6, 1, 2, 3, 5)
+ xf = xf.reshape(b, 2 * 2 * ci, f, ht // 2, wd // 2)
+ B, C, T, H, W = xf.shape
+ xf = xf.view(B, h.shape[1], self.gs // 2, T, H, W).mean(dim=2)
+
+ xn = x[:, :, 1:, :, :]
+ b, ci, frms, ht, wd = xn.shape
+ nf = frms // r1
+ xn = xn.reshape(b, ci, nf, r1, ht // 2, 2, wd // 2, 2)
+ xn = xn.permute(0, 3, 5, 7, 1, 2, 4, 6)
+ xn = xn.reshape(b, r1 * 2 * 2 * ci, nf, ht // 2, wd // 2)
+ B, C, T, H, W = xn.shape
+ xn = xn.view(B, h.shape[1], self.gs, T, H, W).mean(dim=2)
+ sc = torch.cat([xf, xn], dim=2)
+ else:
+ b, c, frms, ht, wd = h.shape
+ nf = frms // r1
+ h = h.reshape(b, c, nf, r1, ht // 2, 2, wd // 2, 2)
+ h = h.permute(0, 3, 5, 7, 1, 2, 4, 6)
+ h = h.reshape(b, r1 * 2 * 2 * c, nf, ht // 2, wd // 2)
+
+ b, ci, frms, ht, wd = x.shape
+ nf = frms // r1
+ sc = x.reshape(b, ci, nf, r1, ht // 2, 2, wd // 2, 2)
+ sc = sc.permute(0, 3, 5, 7, 1, 2, 4, 6)
+ sc = sc.reshape(b, r1 * 2 * 2 * ci, nf, ht // 2, wd // 2)
+ B, C, T, H, W = sc.shape
+ sc = sc.view(B, h.shape[1], self.gs, T, H, W).mean(dim=2)
+
+ return h + sc
+
+
+class UpSmpl(nn.Module):
+ def __init__(self, ic, oc, tus=True):
+ super().__init__()
+ fct = 2 * 2 * 2 if tus else 1 * 2 * 2
+ self.conv = VideoConv3d(ic, oc * fct, kernel_size=3)
+
+ self.tus = tus
+ self.rp = fct * oc // ic
+
+ def forward(self, x):
+ r1 = 2 if self.tus else 1
+ h = self.conv(x)
+
+ if self.tus:
+ hf = h[:, :, :1, :, :]
+ b, c, f, ht, wd = hf.shape
+ nc = c // (2 * 2)
+ hf = hf.reshape(b, 2, 2, nc, f, ht, wd)
+ hf = hf.permute(0, 3, 4, 5, 1, 6, 2)
+ hf = hf.reshape(b, nc, f, ht * 2, wd * 2)
+ hf = hf[:, : hf.shape[1] // 2]
+
+ hn = h[:, :, 1:, :, :]
+ b, c, frms, ht, wd = hn.shape
+ nc = c // (r1 * 2 * 2)
+ hn = hn.reshape(b, r1, 2, 2, nc, frms, ht, wd)
+ hn = hn.permute(0, 4, 5, 1, 6, 2, 7, 3)
+ hn = hn.reshape(b, nc, frms * r1, ht * 2, wd * 2)
+
+ h = torch.cat([hf, hn], dim=2)
+
+ xf = x[:, :, :1, :, :]
+ b, ci, f, ht, wd = xf.shape
+ xf = xf.repeat_interleave(repeats=self.rp // 2, dim=1)
+ b, c, f, ht, wd = xf.shape
+ nc = c // (2 * 2)
+ xf = xf.reshape(b, 2, 2, nc, f, ht, wd)
+ xf = xf.permute(0, 3, 4, 5, 1, 6, 2)
+ xf = xf.reshape(b, nc, f, ht * 2, wd * 2)
+
+ xn = x[:, :, 1:, :, :]
+ xn = xn.repeat_interleave(repeats=self.rp, dim=1)
+ b, c, frms, ht, wd = xn.shape
+ nc = c // (r1 * 2 * 2)
+ xn = xn.reshape(b, r1, 2, 2, nc, frms, ht, wd)
+ xn = xn.permute(0, 4, 5, 1, 6, 2, 7, 3)
+ xn = xn.reshape(b, nc, frms * r1, ht * 2, wd * 2)
+ sc = torch.cat([xf, xn], dim=2)
+ else:
+ b, c, frms, ht, wd = h.shape
+ nc = c // (r1 * 2 * 2)
+ h = h.reshape(b, r1, 2, 2, nc, frms, ht, wd)
+ h = h.permute(0, 4, 5, 1, 6, 2, 7, 3)
+ h = h.reshape(b, nc, frms * r1, ht * 2, wd * 2)
+
+ sc = x.repeat_interleave(repeats=self.rp, dim=1)
+ b, c, frms, ht, wd = sc.shape
+ nc = c // (r1 * 2 * 2)
+ sc = sc.reshape(b, r1, 2, 2, nc, frms, ht, wd)
+ sc = sc.permute(0, 4, 5, 1, 6, 2, 7, 3)
+ sc = sc.reshape(b, nc, frms * r1, ht * 2, wd * 2)
+
+ return h + sc
+
+class Encoder(nn.Module):
+ def __init__(self, in_channels, z_channels, block_out_channels, num_res_blocks,
+ ffactor_spatial, ffactor_temporal, downsample_match_channel=True, **_):
+ super().__init__()
+ self.z_channels = z_channels
+ self.block_out_channels = block_out_channels
+ self.num_res_blocks = num_res_blocks
+ self.conv_in = VideoConv3d(in_channels, block_out_channels[0], 3, 1, 1)
+
+ self.down = nn.ModuleList()
+ ch = block_out_channels[0]
+ depth = (ffactor_spatial >> 1).bit_length()
+ depth_temporal = ((ffactor_spatial // ffactor_temporal) >> 1).bit_length()
+
+ for i, tgt in enumerate(block_out_channels):
+ stage = nn.Module()
+ stage.block = nn.ModuleList([ResnetBlock(in_channels=ch if j == 0 else tgt,
+ out_channels=tgt,
+ temb_channels=0,
+ conv_op=VideoConv3d, norm_op=RMS_norm)
+ for j in range(num_res_blocks)])
+ ch = tgt
+ if i < depth:
+ nxt = block_out_channels[i + 1] if i + 1 < len(block_out_channels) and downsample_match_channel else ch
+ stage.downsample = DnSmpl(ch, nxt, tds=i >= depth_temporal)
+ ch = nxt
+ self.down.append(stage)
+
+ self.mid = nn.Module()
+ self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=VideoConv3d, norm_op=RMS_norm)
+ self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv3d, norm_op=RMS_norm)
+ self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=VideoConv3d, norm_op=RMS_norm)
+
+ self.norm_out = RMS_norm(ch)
+ self.conv_out = VideoConv3d(ch, z_channels << 1, 3, 1, 1)
+
+ self.regul = comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer()
+
+ def forward(self, x):
+ x = self.conv_in(x)
+
+ for stage in self.down:
+ for blk in stage.block:
+ x = blk(x)
+ if hasattr(stage, 'downsample'):
+ x = stage.downsample(x)
+
+ x = self.mid.block_2(self.mid.attn_1(self.mid.block_1(x)))
+
+ b, c, t, h, w = x.shape
+ grp = c // (self.z_channels << 1)
+ skip = x.view(b, c // grp, grp, t, h, w).mean(2)
+
+ out = self.conv_out(F.silu(self.norm_out(x))) + skip
+ out = self.regul(out)[0]
+
+ out = torch.cat((out[:, :, :1], out), dim=2)
+ out = out.permute(0, 2, 1, 3, 4)
+ b, f_times_2, c, h, w = out.shape
+ out = out.reshape(b, f_times_2 // 2, 2 * c, h, w)
+ out = out.permute(0, 2, 1, 3, 4).contiguous()
+ return out
+
+class Decoder(nn.Module):
+ def __init__(self, z_channels, out_channels, block_out_channels, num_res_blocks,
+ ffactor_spatial, ffactor_temporal, upsample_match_channel=True, **_):
+ super().__init__()
+ block_out_channels = block_out_channels[::-1]
+ self.z_channels = z_channels
+ self.block_out_channels = block_out_channels
+ self.num_res_blocks = num_res_blocks
+
+ ch = block_out_channels[0]
+ self.conv_in = VideoConv3d(z_channels, ch, 3)
+
+ self.mid = nn.Module()
+ self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=VideoConv3d, norm_op=RMS_norm)
+ self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv3d, norm_op=RMS_norm)
+ self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=VideoConv3d, norm_op=RMS_norm)
+
+ self.up = nn.ModuleList()
+ depth = (ffactor_spatial >> 1).bit_length()
+ depth_temporal = (ffactor_temporal >> 1).bit_length()
+
+ for i, tgt in enumerate(block_out_channels):
+ stage = nn.Module()
+ stage.block = nn.ModuleList([ResnetBlock(in_channels=ch if j == 0 else tgt,
+ out_channels=tgt,
+ temb_channels=0,
+ conv_op=VideoConv3d, norm_op=RMS_norm)
+ for j in range(num_res_blocks + 1)])
+ ch = tgt
+ if i < depth:
+ nxt = block_out_channels[i + 1] if i + 1 < len(block_out_channels) and upsample_match_channel else ch
+ stage.upsample = UpSmpl(ch, nxt, tus=i < depth_temporal)
+ ch = nxt
+ self.up.append(stage)
+
+ self.norm_out = RMS_norm(ch)
+ self.conv_out = VideoConv3d(ch, out_channels, 3)
+
+ def forward(self, z):
+ z = z.permute(0, 2, 1, 3, 4)
+ b, f, c, h, w = z.shape
+ z = z.reshape(b, f, 2, c // 2, h, w)
+ z = z.reshape(b, f * 2, c // 2, h, w)
+ z = z.permute(0, 2, 1, 3, 4)
+ z = z[:, :, 1:]
+
+ x = self.conv_in(z) + z.repeat_interleave(self.block_out_channels[0] // self.z_channels, 1)
+ x = self.mid.block_2(self.mid.attn_1(self.mid.block_1(x)))
+
+ for stage in self.up:
+ for blk in stage.block:
+ x = blk(x)
+ if hasattr(stage, 'upsample'):
+ x = stage.upsample(x)
+
+ return self.conv_out(F.silu(self.norm_out(x)))
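The refiner encoder ends by duplicating the first latent frame and folding consecutive frame pairs into the channel dimension, so a latent with T frames becomes (T + 1) / 2 frames with twice the channels; the decoder undoes this and drops the duplicated frame (`z[:, :, 1:]`). A compact sketch of the packing step (assumes an odd number of latent frames so the pairing works out):

```python
import torch

def pack_temporal(z):
    """(b, c, t, h, w) -> (b, 2c, (t + 1) // 2, h, w): duplicate frame 0, then fold
    each pair of consecutive frames into the channel dim (t is assumed odd)."""
    z = torch.cat((z[:, :, :1], z), dim=2)       # t -> t + 1 (now even)
    z = z.permute(0, 2, 1, 3, 4)                 # (b, t+1, c, h, w)
    b, t2, c, h, w = z.shape
    z = z.reshape(b, t2 // 2, 2 * c, h, w)       # frames (0,1), (2,3), ... stacked on channels
    return z.permute(0, 2, 1, 3, 4).contiguous()

print(pack_temporal(torch.randn(1, 32, 5, 8, 8)).shape)  # torch.Size([1, 64, 3, 8, 8])
```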
diff --git a/comfy/ldm/lightricks/model.py b/comfy/ldm/lightricks/model.py
index ad9a7daea..def365ba7 100644
--- a/comfy/ldm/lightricks/model.py
+++ b/comfy/ldm/lightricks/model.py
@@ -1,5 +1,6 @@
import torch
from torch import nn
+import comfy.patcher_extension
import comfy.ldm.modules.attention
import comfy.ldm.common_dit
from einops import rearrange
@@ -270,7 +271,7 @@ class CrossAttention(nn.Module):
self.to_out = nn.Sequential(operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout))
- def forward(self, x, context=None, mask=None, pe=None):
+ def forward(self, x, context=None, mask=None, pe=None, transformer_options={}):
q = self.to_q(x)
context = x if context is None else context
k = self.to_k(context)
@@ -284,9 +285,9 @@ class CrossAttention(nn.Module):
k = apply_rotary_emb(k, pe)
if mask is None:
- out = comfy.ldm.modules.attention.optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision)
+ out = comfy.ldm.modules.attention.optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision, transformer_options=transformer_options)
else:
- out = comfy.ldm.modules.attention.optimized_attention_masked(q, k, v, self.heads, mask, attn_precision=self.attn_precision)
+ out = comfy.ldm.modules.attention.optimized_attention_masked(q, k, v, self.heads, mask, attn_precision=self.attn_precision, transformer_options=transformer_options)
return self.to_out(out)
@@ -302,12 +303,12 @@ class BasicTransformerBlock(nn.Module):
self.scale_shift_table = nn.Parameter(torch.empty(6, dim, device=device, dtype=dtype))
- def forward(self, x, context=None, attention_mask=None, timestep=None, pe=None):
+ def forward(self, x, context=None, attention_mask=None, timestep=None, pe=None, transformer_options={}):
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + timestep.reshape(x.shape[0], timestep.shape[1], self.scale_shift_table.shape[0], -1)).unbind(dim=2)
- x += self.attn1(comfy.ldm.common_dit.rms_norm(x) * (1 + scale_msa) + shift_msa, pe=pe) * gate_msa
+ x += self.attn1(comfy.ldm.common_dit.rms_norm(x) * (1 + scale_msa) + shift_msa, pe=pe, transformer_options=transformer_options) * gate_msa
- x += self.attn2(x, context=context, mask=attention_mask)
+ x += self.attn2(x, context=context, mask=attention_mask, transformer_options=transformer_options)
y = comfy.ldm.common_dit.rms_norm(x) * (1 + scale_mlp) + shift_mlp
x += self.ff(y) * gate_mlp
@@ -420,6 +421,13 @@ class LTXVModel(torch.nn.Module):
self.patchifier = SymmetricPatchifier(1)
def forward(self, x, timestep, context, attention_mask, frame_rate=25, transformer_options={}, keyframe_idxs=None, **kwargs):
+ return comfy.patcher_extension.WrapperExecutor.new_class_executor(
+ self._forward,
+ self,
+ comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
+ ).execute(x, timestep, context, attention_mask, frame_rate, transformer_options, keyframe_idxs, **kwargs)
+
+ def _forward(self, x, timestep, context, attention_mask, frame_rate=25, transformer_options={}, keyframe_idxs=None, **kwargs):
patches_replace = transformer_options.get("patches_replace", {})
orig_shape = list(x.shape)
@@ -471,10 +479,10 @@ class LTXVModel(torch.nn.Module):
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
- out["img"] = block(args["img"], context=args["txt"], attention_mask=args["attention_mask"], timestep=args["vec"], pe=args["pe"])
+ out["img"] = block(args["img"], context=args["txt"], attention_mask=args["attention_mask"], timestep=args["vec"], pe=args["pe"], transformer_options=args["transformer_options"])
return out
- out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "attention_mask": attention_mask, "vec": timestep, "pe": pe}, {"original_block": block_wrap})
+ out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "attention_mask": attention_mask, "vec": timestep, "pe": pe, "transformer_options": transformer_options}, {"original_block": block_wrap})
x = out["img"]
else:
x = block(
@@ -482,7 +490,8 @@ class LTXVModel(torch.nn.Module):
context=context,
attention_mask=attention_mask,
timestep=timestep,
- pe=pe
+ pe=pe,
+ transformer_options=transformer_options,
)
# 3. Output
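Several models in this patch route `forward` through `WrapperExecutor` so extensions registered under `WrappersMP.DIFFUSION_MODEL` can run code around the entire diffusion forward pass. The snippet below is only a toy illustration of that wrapping idea, not the `comfy.patcher_extension` implementation: each wrapper receives the next callable in the chain plus the original arguments and may act before or after it.

```python
from typing import Callable, List

def chain_wrappers(base: Callable, wrappers: List[Callable]) -> Callable:
    """Toy wrapper chain: wrappers[0] runs outermost, each receives the next callable."""
    call = base
    for wrapper in reversed(wrappers):
        call = (lambda w, nxt: (lambda *a, **kw: w(nxt, *a, **kw)))(wrapper, call)
    return call

def timing_wrapper(next_fn, x):
    print("before diffusion forward")
    out = next_fn(x)
    print("after diffusion forward")
    return out

forward = chain_wrappers(lambda x: x * 2, [timing_wrapper])
print(forward(3))  # prints the two messages, then 6
```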
diff --git a/comfy/ldm/lumina/model.py b/comfy/ldm/lumina/model.py
index f8dc4d7db..f87d98ac0 100644
--- a/comfy/ldm/lumina/model.py
+++ b/comfy/ldm/lumina/model.py
@@ -11,6 +11,7 @@ import comfy.ldm.common_dit
from comfy.ldm.modules.diffusionmodules.mmdit import TimestepEmbedder
from comfy.ldm.modules.attention import optimized_attention_masked
from comfy.ldm.flux.layers import EmbedND
+import comfy.patcher_extension
def modulate(x, scale):
@@ -103,6 +104,7 @@ class JointAttention(nn.Module):
x: torch.Tensor,
x_mask: torch.Tensor,
freqs_cis: torch.Tensor,
+ transformer_options={},
) -> torch.Tensor:
"""
@@ -139,7 +141,7 @@ class JointAttention(nn.Module):
if n_rep >= 1:
xk = xk.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
xv = xv.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
- output = optimized_attention_masked(xq.movedim(1, 2), xk.movedim(1, 2), xv.movedim(1, 2), self.n_local_heads, x_mask, skip_reshape=True)
+ output = optimized_attention_masked(xq.movedim(1, 2), xk.movedim(1, 2), xv.movedim(1, 2), self.n_local_heads, x_mask, skip_reshape=True, transformer_options=transformer_options)
return self.out(output)
@@ -267,6 +269,7 @@ class JointTransformerBlock(nn.Module):
x_mask: torch.Tensor,
freqs_cis: torch.Tensor,
adaln_input: Optional[torch.Tensor]=None,
+ transformer_options={},
):
"""
Perform a forward pass through the TransformerBlock.
@@ -289,6 +292,7 @@ class JointTransformerBlock(nn.Module):
modulate(self.attention_norm1(x), scale_msa),
x_mask,
freqs_cis,
+ transformer_options=transformer_options,
)
)
x = x + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2(
@@ -303,6 +307,7 @@ class JointTransformerBlock(nn.Module):
self.attention_norm1(x),
x_mask,
freqs_cis,
+ transformer_options=transformer_options,
)
)
x = x + self.ffn_norm2(
@@ -493,7 +498,7 @@ class NextDiT(nn.Module):
return imgs
def patchify_and_embed(
- self, x: List[torch.Tensor] | torch.Tensor, cap_feats: torch.Tensor, cap_mask: torch.Tensor, t: torch.Tensor, num_tokens
+ self, x: List[torch.Tensor] | torch.Tensor, cap_feats: torch.Tensor, cap_mask: torch.Tensor, t: torch.Tensor, num_tokens, transformer_options={}
) -> Tuple[torch.Tensor, torch.Tensor, List[Tuple[int, int]], List[int], torch.Tensor]:
bsz = len(x)
pH = pW = self.patch_size
@@ -553,7 +558,7 @@ class NextDiT(nn.Module):
# refine context
for layer in self.context_refiner:
- cap_feats = layer(cap_feats, cap_mask, cap_freqs_cis)
+ cap_feats = layer(cap_feats, cap_mask, cap_freqs_cis, transformer_options=transformer_options)
# refine image
flat_x = []
@@ -572,7 +577,7 @@ class NextDiT(nn.Module):
padded_img_embed = self.x_embedder(padded_img_embed)
padded_img_mask = padded_img_mask.unsqueeze(1)
for layer in self.noise_refiner:
- padded_img_embed = layer(padded_img_embed, padded_img_mask, img_freqs_cis, t)
+ padded_img_embed = layer(padded_img_embed, padded_img_mask, img_freqs_cis, t, transformer_options=transformer_options)
if cap_mask is not None:
mask = torch.zeros(bsz, max_seq_len, dtype=dtype, device=device)
@@ -590,8 +595,15 @@ class NextDiT(nn.Module):
return padded_full_embed, mask, img_sizes, l_effective_cap_len, freqs_cis
- # def forward(self, x, t, cap_feats, cap_mask):
def forward(self, x, timesteps, context, num_tokens, attention_mask=None, **kwargs):
+ return comfy.patcher_extension.WrapperExecutor.new_class_executor(
+ self._forward,
+ self,
+ comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, kwargs.get("transformer_options", {}))
+ ).execute(x, timesteps, context, num_tokens, attention_mask, **kwargs)
+
+ # def forward(self, x, t, cap_feats, cap_mask):
+ def _forward(self, x, timesteps, context, num_tokens, attention_mask=None, **kwargs):
t = 1.0 - timesteps
cap_feats = context
cap_mask = attention_mask
@@ -608,12 +620,13 @@ class NextDiT(nn.Module):
cap_feats = self.cap_embedder(cap_feats) # (N, L, D) # todo check if able to batchify w.o. redundant compute
+ transformer_options = kwargs.get("transformer_options", {})
x_is_tensor = isinstance(x, torch.Tensor)
- x, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(x, cap_feats, cap_mask, t, num_tokens)
+ x, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(x, cap_feats, cap_mask, t, num_tokens, transformer_options=transformer_options)
freqs_cis = freqs_cis.to(x.device)
for layer in self.layers:
- x = layer(x, mask, freqs_cis, adaln_input)
+ x = layer(x, mask, freqs_cis, adaln_input, transformer_options=transformer_options)
x = self.final_layer(x, adaln_input)
x = self.unpatchify(x, img_size, cap_size, return_tensor=x_is_tensor)[:,:,:h,:w]
diff --git a/comfy/ldm/models/autoencoder.py b/comfy/ldm/models/autoencoder.py
index 13bd6e16b..611d36a1b 100644
--- a/comfy/ldm/models/autoencoder.py
+++ b/comfy/ldm/models/autoencoder.py
@@ -26,6 +26,12 @@ class DiagonalGaussianRegularizer(torch.nn.Module):
z = posterior.mode()
return z, None
+class EmptyRegularizer(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]:
+ return z, None
class AbstractAutoencoder(torch.nn.Module):
"""
diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py
index 043df28df..7437e0567 100644
--- a/comfy/ldm/modules/attention.py
+++ b/comfy/ldm/modules/attention.py
@@ -5,8 +5,9 @@ import torch
import torch.nn.functional as F
from torch import nn, einsum
from einops import rearrange, repeat
-from typing import Optional
+from typing import Optional, Any, Callable, Union
import logging
+import functools
from .diffusionmodules.util import AlphaBlender, timestep_embedding
from .sub_quadratic_attention import efficient_dot_product_attention
@@ -17,23 +18,45 @@ if model_management.xformers_enabled():
import xformers
import xformers.ops
-if model_management.sage_attention_enabled():
- try:
- from sageattention import sageattn
- except ModuleNotFoundError as e:
+SAGE_ATTENTION_IS_AVAILABLE = False
+try:
+ from sageattention import sageattn
+ SAGE_ATTENTION_IS_AVAILABLE = True
+except ImportError as e:
+ if model_management.sage_attention_enabled():
if e.name == "sageattention":
logging.error(f"\n\nTo use the `--use-sage-attention` feature, the `sageattention` package must be installed first.\ncommand:\n\t{sys.executable} -m pip install sageattention")
else:
raise e
exit(-1)
-if model_management.flash_attention_enabled():
- try:
- from flash_attn import flash_attn_func
- except ModuleNotFoundError:
+FLASH_ATTENTION_IS_AVAILABLE = False
+try:
+ from flash_attn import flash_attn_func
+ FLASH_ATTENTION_IS_AVAILABLE = True
+except ImportError:
+ if model_management.flash_attention_enabled():
logging.error(f"\n\nTo use the `--use-flash-attention` feature, the `flash-attn` package must be installed first.\ncommand:\n\t{sys.executable} -m pip install flash-attn")
exit(-1)
+REGISTERED_ATTENTION_FUNCTIONS = {}
+def register_attention_function(name: str, func: Callable):
+ # avoid replacing existing functions
+ if name not in REGISTERED_ATTENTION_FUNCTIONS:
+ REGISTERED_ATTENTION_FUNCTIONS[name] = func
+ else:
+ logging.warning(f"Attention function {name} already registered, skipping registration.")
+
+def get_attention_function(name: str, default: Any=...) -> Union[Callable, None]:
+ if name == "optimized":
+ return optimized_attention
+ elif name not in REGISTERED_ATTENTION_FUNCTIONS:
+ if default is ...:
+ raise KeyError(f"Attention function {name} not found.")
+ else:
+ return default
+ return REGISTERED_ATTENTION_FUNCTIONS[name]
+
from comfy.cli_args import args
import comfy.ops
ops = comfy.ops.disable_weight_init
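The new registry lets callers look attention kernels up by name: `register_attention_function` refuses to overwrite an existing name, and `get_attention_function` treats `"optimized"` as an alias for whatever `optimized_attention` resolved to and raises `KeyError` unless a default is supplied. A self-contained toy mirroring those semantics:

```python
import logging

REGISTRY = {}

def register(name, func):
    if name in REGISTRY:              # first registration wins, mirroring the behaviour above
        logging.warning("attention function %s already registered, skipping", name)
        return
    REGISTRY[name] = func

def get(name, default=...):
    if name not in REGISTRY:
        if default is ...:            # Ellipsis sentinel: "no default supplied"
            raise KeyError(f"attention function {name} not found")
        return default
    return REGISTRY[name]

register("pytorch", lambda q, k, v, heads, **kw: "sdpa result")
register("pytorch", lambda q, k, v, heads, **kw: "ignored")  # skipped with a warning
print(get("pytorch")(None, None, None, 8))                   # sdpa result
print(get("missing", default=None))                          # None
```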
@@ -91,7 +114,27 @@ class FeedForward(nn.Module):
def Normalize(in_channels, dtype=None, device=None):
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device)
-def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
+
+def wrap_attn(func):
+ @functools.wraps(func)
+ def wrapper(*args, **kwargs):
+ remove_attn_wrapper_key = False
+ try:
+ if "_inside_attn_wrapper" not in kwargs:
+ transformer_options = kwargs.get("transformer_options", None)
+ remove_attn_wrapper_key = True
+ kwargs["_inside_attn_wrapper"] = True
+ if transformer_options is not None:
+ if "optimized_attention_override" in transformer_options:
+ return transformer_options["optimized_attention_override"](func, *args, **kwargs)
+ return func(*args, **kwargs)
+ finally:
+ if remove_attn_wrapper_key:
+ del kwargs["_inside_attn_wrapper"]
+ return wrapper
+
+@wrap_attn
+def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
attn_precision = get_attn_precision(attn_precision, q.dtype)
if skip_reshape:
@@ -159,8 +202,8 @@ def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
)
return out
-
-def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
+@wrap_attn
+def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
attn_precision = get_attn_precision(attn_precision, query.dtype)
if skip_reshape:
@@ -230,7 +273,8 @@ def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None,
hidden_states = hidden_states.unflatten(0, (-1, heads)).transpose(1,2).flatten(start_dim=2)
return hidden_states
-def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
+@wrap_attn
+def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
attn_precision = get_attn_precision(attn_precision, q.dtype)
if skip_reshape:
@@ -359,7 +403,8 @@ try:
except:
pass
-def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
+@wrap_attn
+def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
b = q.shape[0]
dim_head = q.shape[-1]
# check to make sure xformers isn't broken
@@ -374,7 +419,7 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_resh
disabled_xformers = True
if disabled_xformers:
- return attention_pytorch(q, k, v, heads, mask, skip_reshape=skip_reshape)
+ return attention_pytorch(q, k, v, heads, mask, skip_reshape=skip_reshape, **kwargs)
if skip_reshape:
# b h k d -> b k h d
@@ -427,8 +472,8 @@ else:
#TODO: other GPUs ?
SDP_BATCH_LIMIT = 2**31
-
-def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
+@wrap_attn
+def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
if skip_reshape:
b, _, _, dim_head = q.shape
else:
@@ -470,8 +515,8 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
).transpose(1, 2).reshape(-1, q.shape[2], heads * dim_head)
return out
-
-def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
+@wrap_attn
+def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
if skip_reshape:
b, _, _, dim_head = q.shape
tensor_layout = "HND"
@@ -501,7 +546,7 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=
lambda t: t.transpose(1, 2),
(q, k, v),
)
- return attention_pytorch(q, k, v, heads, mask=mask, skip_reshape=True, skip_output_reshape=skip_output_reshape)
+ return attention_pytorch(q, k, v, heads, mask=mask, skip_reshape=True, skip_output_reshape=skip_output_reshape, **kwargs)
if tensor_layout == "HND":
if not skip_output_reshape:
@@ -534,8 +579,8 @@ except AttributeError as error:
dropout_p: float = 0.0, causal: bool = False) -> torch.Tensor:
assert False, f"Could not define flash_attn_wrapper: {FLASH_ATTN_ERROR}"
-
-def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
+@wrap_attn
+def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
if skip_reshape:
b, _, _, dim_head = q.shape
else:
@@ -555,7 +600,8 @@ def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
mask = mask.unsqueeze(1)
try:
- assert mask is None
+ if mask is not None:
+ raise RuntimeError("Mask must not be set for Flash attention")
out = flash_attn_wrapper(
q.transpose(1, 2),
k.transpose(1, 2),
@@ -597,6 +643,19 @@ else:
optimized_attention_masked = optimized_attention
+
+# register core-supported attention functions
+if SAGE_ATTENTION_IS_AVAILABLE:
+ register_attention_function("sage", attention_sage)
+if FLASH_ATTENTION_IS_AVAILABLE:
+ register_attention_function("flash", attention_flash)
+if model_management.xformers_enabled():
+ register_attention_function("xformers", attention_xformers)
+register_attention_function("pytorch", attention_pytorch)
+register_attention_function("sub_quad", attention_sub_quad)
+register_attention_function("split", attention_split)
+
+
def optimized_attention_for_device(device, mask=False, small_input=False):
if small_input:
if model_management.pytorch_attention_enabled():
@@ -629,7 +688,7 @@ class CrossAttention(nn.Module):
self.to_out = nn.Sequential(operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout))
- def forward(self, x, context=None, value=None, mask=None):
+ def forward(self, x, context=None, value=None, mask=None, transformer_options={}):
q = self.to_q(x)
context = default(context, x)
k = self.to_k(context)
@@ -640,9 +699,9 @@ class CrossAttention(nn.Module):
v = self.to_v(context)
if mask is None:
- out = optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision)
+ out = optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision, transformer_options=transformer_options)
else:
- out = optimized_attention_masked(q, k, v, self.heads, mask, attn_precision=self.attn_precision)
+ out = optimized_attention_masked(q, k, v, self.heads, mask, attn_precision=self.attn_precision, transformer_options=transformer_options)
return self.to_out(out)
@@ -746,7 +805,7 @@ class BasicTransformerBlock(nn.Module):
n = attn1_replace_patch[block_attn1](n, context_attn1, value_attn1, extra_options)
n = self.attn1.to_out(n)
else:
- n = self.attn1(n, context=context_attn1, value=value_attn1)
+ n = self.attn1(n, context=context_attn1, value=value_attn1, transformer_options=transformer_options)
if "attn1_output_patch" in transformer_patches:
patch = transformer_patches["attn1_output_patch"]
@@ -786,7 +845,7 @@ class BasicTransformerBlock(nn.Module):
n = attn2_replace_patch[block_attn2](n, context_attn2, value_attn2, extra_options)
n = self.attn2.to_out(n)
else:
- n = self.attn2(n, context=context_attn2, value=value_attn2)
+ n = self.attn2(n, context=context_attn2, value=value_attn2, transformer_options=transformer_options)
if "attn2_output_patch" in transformer_patches:
patch = transformer_patches["attn2_output_patch"]
@@ -1017,7 +1076,7 @@ class SpatialVideoTransformer(SpatialTransformer):
B, S, C = x_mix.shape
x_mix = rearrange(x_mix, "(b t) s c -> (b s) t c", t=timesteps)
- x_mix = mix_block(x_mix, context=time_context) #TODO: transformer_options
+ x_mix = mix_block(x_mix, context=time_context, transformer_options=transformer_options)
x_mix = rearrange(
x_mix, "(b s) t c -> (b t) s c", s=S, b=B // timesteps, c=C, t=timesteps
)
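Every attention backend is now wrapped by `wrap_attn`, so a caller can drop an `optimized_attention_override` into `transformer_options` and intercept each attention call: the hook receives the backend function plus all of its arguments. A toy reconstruction of the dispatch (simplified: it omits the `_inside_attn_wrapper` re-entrancy guard):

```python
import functools
import torch
import torch.nn.functional as F

def wrap_attn_toy(func):
    """Simplified stand-in for wrap_attn: delegate to the override if one is present."""
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        opts = kwargs.get("transformer_options") or {}
        override = opts.get("optimized_attention_override")
        if override is not None:
            return override(func, *args, **kwargs)
        return func(*args, **kwargs)
    return wrapper

@wrap_attn_toy
def attention_sdpa(q, k, v, heads, mask=None, transformer_options=None, **kwargs):
    # stand-in backend, (b, heads, tokens, dim_head) in, same layout out
    return F.scaled_dot_product_attention(q, k, v, attn_mask=mask)

def logging_override(default_attention, q, k, v, *args, **kwargs):
    print("attention call, q:", tuple(q.shape))
    return default_attention(q, k, v, *args, **kwargs)

q = k = v = torch.randn(1, 4, 16, 8)
out = attention_sdpa(q, k, v, 4,
                     transformer_options={"optimized_attention_override": logging_override})
print(out.shape)  # torch.Size([1, 4, 16, 8])
```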
diff --git a/comfy/ldm/modules/diffusionmodules/mmdit.py b/comfy/ldm/modules/diffusionmodules/mmdit.py
index eaf3e73a4..42f406f1a 100644
--- a/comfy/ldm/modules/diffusionmodules/mmdit.py
+++ b/comfy/ldm/modules/diffusionmodules/mmdit.py
@@ -109,7 +109,7 @@ class PatchEmbed(nn.Module):
def modulate(x, shift, scale):
if shift is None:
shift = torch.zeros_like(scale)
- return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
+ return torch.addcmul(shift.unsqueeze(1), x, 1 + scale.unsqueeze(1))
#################################################################################
@@ -564,10 +564,7 @@ class DismantledBlock(nn.Module):
assert not self.pre_only
attn1 = self.attn.post_attention(attn)
attn2 = self.attn2.post_attention(attn2)
- out1 = gate_msa.unsqueeze(1) * attn1
- out2 = gate_msa2.unsqueeze(1) * attn2
- x = x + out1
- x = x + out2
+ x = gate_cat(x, gate_msa, gate_msa2, attn1, attn2)
x = x + gate_mlp.unsqueeze(1) * self.mlp(
modulate(self.norm2(x), shift_mlp, scale_mlp)
)
@@ -594,6 +591,11 @@ class DismantledBlock(nn.Module):
)
return self.post_attention(attn, *intermediates)
+def gate_cat(x, gate_msa, gate_msa2, attn1, attn2):
+ out1 = gate_msa.unsqueeze(1) * attn1
+ out2 = gate_msa2.unsqueeze(1) * attn2
+ x = torch.stack([x, out1, out2], dim=0).sum(dim=0)
+ return x
def block_mixing(*args, use_checkpoint=True, **kwargs):
if use_checkpoint:
@@ -604,7 +606,7 @@ def block_mixing(*args, use_checkpoint=True, **kwargs):
return _block_mixing(*args, **kwargs)
-def _block_mixing(context, x, context_block, x_block, c):
+def _block_mixing(context, x, context_block, x_block, c, transformer_options={}):
context_qkv, context_intermediates = context_block.pre_attention(context, c)
if x_block.x_block_self_attn:
@@ -620,6 +622,7 @@ def _block_mixing(context, x, context_block, x_block, c):
attn = optimized_attention(
qkv[0], qkv[1], qkv[2],
heads=x_block.attn.num_heads,
+ transformer_options=transformer_options,
)
context_attn, x_attn = (
attn[:, : context_qkv[0].shape[1]],
@@ -635,6 +638,7 @@ def _block_mixing(context, x, context_block, x_block, c):
attn2 = optimized_attention(
x_qkv2[0], x_qkv2[1], x_qkv2[2],
heads=x_block.attn2.num_heads,
+ transformer_options=transformer_options,
)
x = x_block.post_attention_x(x_attn, attn2, *x_intermediates)
else:
@@ -956,10 +960,10 @@ class MMDiT(nn.Module):
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
- out["txt"], out["img"] = self.joint_blocks[i](args["txt"], args["img"], c=args["vec"])
+ out["txt"], out["img"] = self.joint_blocks[i](args["txt"], args["img"], c=args["vec"], transformer_options=args["transformer_options"])
return out
- out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": c_mod}, {"original_block": block_wrap})
+ out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": c_mod, "transformer_options": transformer_options}, {"original_block": block_wrap})
context = out["txt"]
x = out["img"]
else:
@@ -968,6 +972,7 @@ class MMDiT(nn.Module):
x,
c=c_mod,
use_checkpoint=self.use_checkpoint,
+ transformer_options=transformer_options,
)
if control is not None:
control_o = control.get("output")
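The `modulate` change is purely a fusion: `torch.addcmul(shift, x, 1 + scale)` computes `shift + x * (1 + scale)` in a single fused op instead of a separate multiply and add. A quick equivalence check:

```python
import torch

x = torch.randn(2, 16, 8)
shift = torch.randn(2, 8)
scale = torch.randn(2, 8)

a = x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)             # original formulation
b = torch.addcmul(shift.unsqueeze(1), x, 1 + scale.unsqueeze(1))  # fused form
print(torch.allclose(a, b))  # True
```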
diff --git a/comfy/ldm/modules/diffusionmodules/model.py b/comfy/ldm/modules/diffusionmodules/model.py
index 1fd12b35a..4245eedca 100644
--- a/comfy/ldm/modules/diffusionmodules/model.py
+++ b/comfy/ldm/modules/diffusionmodules/model.py
@@ -145,7 +145,7 @@ class Downsample(nn.Module):
class ResnetBlock(nn.Module):
def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
- dropout, temb_channels=512, conv_op=ops.Conv2d):
+ dropout=0.0, temb_channels=512, conv_op=ops.Conv2d, norm_op=Normalize):
super().__init__()
self.in_channels = in_channels
out_channels = in_channels if out_channels is None else out_channels
@@ -153,7 +153,7 @@ class ResnetBlock(nn.Module):
self.use_conv_shortcut = conv_shortcut
self.swish = torch.nn.SiLU(inplace=True)
- self.norm1 = Normalize(in_channels)
+ self.norm1 = norm_op(in_channels)
self.conv1 = conv_op(in_channels,
out_channels,
kernel_size=3,
@@ -162,7 +162,7 @@ class ResnetBlock(nn.Module):
if temb_channels > 0:
self.temb_proj = ops.Linear(temb_channels,
out_channels)
- self.norm2 = Normalize(out_channels)
+ self.norm2 = norm_op(out_channels)
self.dropout = torch.nn.Dropout(dropout, inplace=True)
self.conv2 = conv_op(out_channels,
out_channels,
@@ -183,7 +183,7 @@ class ResnetBlock(nn.Module):
stride=1,
padding=0)
- def forward(self, x, temb):
+ def forward(self, x, temb=None):
h = x
h = self.norm1(h)
h = self.swish(h)
@@ -305,11 +305,11 @@ def vae_attention():
return normal_attention
class AttnBlock(nn.Module):
- def __init__(self, in_channels, conv_op=ops.Conv2d):
+ def __init__(self, in_channels, conv_op=ops.Conv2d, norm_op=Normalize):
super().__init__()
self.in_channels = in_channels
- self.norm = Normalize(in_channels)
+ self.norm = norm_op(in_channels)
self.q = conv_op(in_channels,
in_channels,
kernel_size=1,
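ResnetBlock and AttnBlock now accept a `norm_op`, which is how the Hunyuan refiner VAE swaps the default GroupNorm for the channel-wise `RMS_norm` defined in vae_refiner.py. That norm is just the input divided by its RMS over the channel dimension, times a learned per-channel gain; a quick check of the equivalence:

```python
import torch
import torch.nn.functional as F

def channel_rms_norm(x, gamma, eps=1e-12):
    """Sketch of RMS_norm: RMS-normalize a (b, c, ...) tensor over dim 1, apply per-channel gain."""
    scale = x.shape[1] ** 0.5
    return F.normalize(x, dim=1, eps=eps) * scale * gamma

x = torch.randn(1, 8, 4, 4)
gamma = torch.ones(8, 1, 1)
y = channel_rms_norm(x, gamma)
rms = x.pow(2).mean(dim=1, keepdim=True).sqrt()
print(torch.allclose(y, x / rms, atol=1e-5))  # True: equivalent to dividing by the channel RMS
```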
diff --git a/comfy/ldm/omnigen/omnigen2.py b/comfy/ldm/omnigen/omnigen2.py
index 4884449f8..82edc92da 100644
--- a/comfy/ldm/omnigen/omnigen2.py
+++ b/comfy/ldm/omnigen/omnigen2.py
@@ -120,7 +120,7 @@ class Attention(nn.Module):
nn.Dropout(0.0)
)
- def forward(self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, image_rotary_emb: Optional[torch.Tensor] = None) -> torch.Tensor:
+ def forward(self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, image_rotary_emb: Optional[torch.Tensor] = None, transformer_options={}) -> torch.Tensor:
batch_size, sequence_length, _ = hidden_states.shape
query = self.to_q(hidden_states)
@@ -146,7 +146,7 @@ class Attention(nn.Module):
key = key.repeat_interleave(self.heads // self.kv_heads, dim=1)
value = value.repeat_interleave(self.heads // self.kv_heads, dim=1)
- hidden_states = optimized_attention_masked(query, key, value, self.heads, attention_mask, skip_reshape=True)
+ hidden_states = optimized_attention_masked(query, key, value, self.heads, attention_mask, skip_reshape=True, transformer_options=transformer_options)
hidden_states = self.to_out[0](hidden_states)
return hidden_states
@@ -182,16 +182,16 @@ class OmniGen2TransformerBlock(nn.Module):
self.norm2 = operations.RMSNorm(dim, eps=norm_eps, dtype=dtype, device=device)
self.ffn_norm2 = operations.RMSNorm(dim, eps=norm_eps, dtype=dtype, device=device)
- def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, image_rotary_emb: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
+ def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, image_rotary_emb: torch.Tensor, temb: Optional[torch.Tensor] = None, transformer_options={}) -> torch.Tensor:
if self.modulation:
norm_hidden_states, gate_msa, scale_mlp, gate_mlp = self.norm1(hidden_states, temb)
- attn_output = self.attn(norm_hidden_states, norm_hidden_states, attention_mask, image_rotary_emb)
+ attn_output = self.attn(norm_hidden_states, norm_hidden_states, attention_mask, image_rotary_emb, transformer_options=transformer_options)
hidden_states = hidden_states + gate_msa.unsqueeze(1).tanh() * self.norm2(attn_output)
mlp_output = self.feed_forward(self.ffn_norm1(hidden_states) * (1 + scale_mlp.unsqueeze(1)))
hidden_states = hidden_states + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2(mlp_output)
else:
norm_hidden_states = self.norm1(hidden_states)
- attn_output = self.attn(norm_hidden_states, norm_hidden_states, attention_mask, image_rotary_emb)
+ attn_output = self.attn(norm_hidden_states, norm_hidden_states, attention_mask, image_rotary_emb, transformer_options=transformer_options)
hidden_states = hidden_states + self.norm2(attn_output)
mlp_output = self.feed_forward(self.ffn_norm1(hidden_states))
hidden_states = hidden_states + self.ffn_norm2(mlp_output)
@@ -390,7 +390,7 @@ class OmniGen2Transformer2DModel(nn.Module):
ref_img_sizes, img_sizes,
)
- def img_patch_embed_and_refine(self, hidden_states, ref_image_hidden_states, padded_img_mask, padded_ref_img_mask, noise_rotary_emb, ref_img_rotary_emb, l_effective_ref_img_len, l_effective_img_len, temb):
+ def img_patch_embed_and_refine(self, hidden_states, ref_image_hidden_states, padded_img_mask, padded_ref_img_mask, noise_rotary_emb, ref_img_rotary_emb, l_effective_ref_img_len, l_effective_img_len, temb, transformer_options={}):
batch_size = len(hidden_states)
hidden_states = self.x_embedder(hidden_states)
@@ -405,17 +405,17 @@ class OmniGen2Transformer2DModel(nn.Module):
shift += ref_img_len
for layer in self.noise_refiner:
- hidden_states = layer(hidden_states, padded_img_mask, noise_rotary_emb, temb)
+ hidden_states = layer(hidden_states, padded_img_mask, noise_rotary_emb, temb, transformer_options=transformer_options)
if ref_image_hidden_states is not None:
for layer in self.ref_image_refiner:
- ref_image_hidden_states = layer(ref_image_hidden_states, padded_ref_img_mask, ref_img_rotary_emb, temb)
+ ref_image_hidden_states = layer(ref_image_hidden_states, padded_ref_img_mask, ref_img_rotary_emb, temb, transformer_options=transformer_options)
hidden_states = torch.cat([ref_image_hidden_states, hidden_states], dim=1)
return hidden_states
- def forward(self, x, timesteps, context, num_tokens, ref_latents=None, attention_mask=None, **kwargs):
+ def forward(self, x, timesteps, context, num_tokens, ref_latents=None, attention_mask=None, transformer_options={}, **kwargs):
B, C, H, W = x.shape
hidden_states = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size))
_, _, H_padded, W_padded = hidden_states.shape
@@ -444,7 +444,7 @@ class OmniGen2Transformer2DModel(nn.Module):
)
for layer in self.context_refiner:
- text_hidden_states = layer(text_hidden_states, text_attention_mask, context_rotary_emb)
+ text_hidden_states = layer(text_hidden_states, text_attention_mask, context_rotary_emb, transformer_options=transformer_options)
img_len = hidden_states.shape[1]
combined_img_hidden_states = self.img_patch_embed_and_refine(
@@ -453,13 +453,14 @@ class OmniGen2Transformer2DModel(nn.Module):
noise_rotary_emb, ref_img_rotary_emb,
l_effective_ref_img_len, l_effective_img_len,
temb,
+ transformer_options=transformer_options,
)
hidden_states = torch.cat([text_hidden_states, combined_img_hidden_states], dim=1)
attention_mask = None
for layer in self.layers:
- hidden_states = layer(hidden_states, attention_mask, rotary_emb, temb)
+ hidden_states = layer(hidden_states, attention_mask, rotary_emb, temb, transformer_options=transformer_options)
hidden_states = self.norm_out(hidden_states, temb)
diff --git a/comfy/ldm/qwen_image/controlnet.py b/comfy/ldm/qwen_image/controlnet.py
new file mode 100644
index 000000000..92ac3cf0a
--- /dev/null
+++ b/comfy/ldm/qwen_image/controlnet.py
@@ -0,0 +1,77 @@
+import torch
+import math
+
+from .model import QwenImageTransformer2DModel
+
+
+class QwenImageControlNetModel(QwenImageTransformer2DModel):
+ def __init__(
+ self,
+ extra_condition_channels=0,
+ dtype=None,
+ device=None,
+ operations=None,
+ **kwargs
+ ):
+ super().__init__(final_layer=False, dtype=dtype, device=device, operations=operations, **kwargs)
+ self.main_model_double = 60
+
+ # controlnet_blocks
+ self.controlnet_blocks = torch.nn.ModuleList([])
+ for _ in range(len(self.transformer_blocks)):
+ self.controlnet_blocks.append(operations.Linear(self.inner_dim, self.inner_dim, device=device, dtype=dtype))
+ self.controlnet_x_embedder = operations.Linear(self.in_channels + extra_condition_channels, self.inner_dim, device=device, dtype=dtype)
+
+ def forward(
+ self,
+ x,
+ timesteps,
+ context,
+ attention_mask=None,
+ guidance: torch.Tensor = None,
+ ref_latents=None,
+ hint=None,
+ transformer_options={},
+ **kwargs
+ ):
+ timestep = timesteps
+ encoder_hidden_states = context
+ encoder_hidden_states_mask = attention_mask
+
+ hidden_states, img_ids, orig_shape = self.process_img(x)
+ hint, _, _ = self.process_img(hint)
+
+ txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size) // 2, ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size) // 2))
+ txt_ids = torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 3)
+ ids = torch.cat((txt_ids, img_ids), dim=1)
+ image_rotary_emb = self.pe_embedder(ids).squeeze(1).unsqueeze(2).to(x.dtype)
+ del ids, txt_ids, img_ids
+
+ hidden_states = self.img_in(hidden_states) + self.controlnet_x_embedder(hint)
+ encoder_hidden_states = self.txt_norm(encoder_hidden_states)
+ encoder_hidden_states = self.txt_in(encoder_hidden_states)
+
+ if guidance is not None:
+ guidance = guidance * 1000
+
+ temb = (
+ self.time_text_embed(timestep, hidden_states)
+ if guidance is None
+ else self.time_text_embed(timestep, guidance, hidden_states)
+ )
+
+ repeat = math.ceil(self.main_model_double / len(self.controlnet_blocks))
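+ # each controlnet block's output is repeated so the residual list covers all 60 double blocks of the base model (truncated on return)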
+
+ controlnet_block_samples = ()
+ for i, block in enumerate(self.transformer_blocks):
+ encoder_hidden_states, hidden_states = block(
+ hidden_states=hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_hidden_states_mask=encoder_hidden_states_mask,
+ temb=temb,
+ image_rotary_emb=image_rotary_emb,
+ )
+
+ controlnet_block_samples = controlnet_block_samples + (self.controlnet_blocks[i](hidden_states),) * repeat
+
+ return {"input": controlnet_block_samples[:self.main_model_double]}
diff --git a/comfy/ldm/qwen_image/model.py b/comfy/ldm/qwen_image/model.py
index a3c726299..b9f60c2b7 100644
--- a/comfy/ldm/qwen_image/model.py
+++ b/comfy/ldm/qwen_image/model.py
@@ -9,6 +9,7 @@ from comfy.ldm.lightricks.model import TimestepEmbedding, Timesteps
from comfy.ldm.modules.attention import optimized_attention_masked
from comfy.ldm.flux.layers import EmbedND
import comfy.ldm.common_dit
+import comfy.patcher_extension
class GELU(nn.Module):
def __init__(self, dim_in: int, dim_out: int, approximate: str = "none", bias: bool = True, dtype=None, device=None, operations=None):
@@ -131,6 +132,7 @@ class Attention(nn.Module):
encoder_hidden_states_mask: torch.FloatTensor = None,
attention_mask: Optional[torch.FloatTensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
+ transformer_options={},
) -> Tuple[torch.Tensor, torch.Tensor]:
seq_txt = encoder_hidden_states.shape[1]
@@ -158,7 +160,7 @@ class Attention(nn.Module):
joint_key = joint_key.flatten(start_dim=2)
joint_value = joint_value.flatten(start_dim=2)
- joint_hidden_states = optimized_attention_masked(joint_query, joint_key, joint_value, self.heads, attention_mask)
+ joint_hidden_states = optimized_attention_masked(joint_query, joint_key, joint_value, self.heads, attention_mask, transformer_options=transformer_options)
txt_attn_output = joint_hidden_states[:, :seq_txt, :]
img_attn_output = joint_hidden_states[:, seq_txt:, :]
@@ -214,9 +216,9 @@ class QwenImageTransformerBlock(nn.Module):
operations=operations,
)
- def _modulate(self, x, mod_params):
- shift, scale, gate = mod_params.chunk(3, dim=-1)
- return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1), gate.unsqueeze(1)
+ def _modulate(self, x: torch.Tensor, mod_params: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+ shift, scale, gate = torch.chunk(mod_params, 3, dim=-1)
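+ # addcmul fuses the modulation into one op: shift + x * (1 + scale)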
+ return torch.addcmul(shift.unsqueeze(1), x, 1 + scale.unsqueeze(1)), gate.unsqueeze(1)
def forward(
self,
@@ -225,6 +227,7 @@ class QwenImageTransformerBlock(nn.Module):
encoder_hidden_states_mask: torch.Tensor,
temb: torch.Tensor,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ transformer_options={},
) -> Tuple[torch.Tensor, torch.Tensor]:
img_mod_params = self.img_mod(temb)
txt_mod_params = self.txt_mod(temb)
@@ -241,6 +244,7 @@ class QwenImageTransformerBlock(nn.Module):
encoder_hidden_states=txt_modulated,
encoder_hidden_states_mask=encoder_hidden_states_mask,
image_rotary_emb=image_rotary_emb,
+ transformer_options=transformer_options,
)
hidden_states = hidden_states + img_gate1 * img_attn_output
@@ -248,11 +252,11 @@ class QwenImageTransformerBlock(nn.Module):
img_normed2 = self.img_norm2(hidden_states)
img_modulated2, img_gate2 = self._modulate(img_normed2, img_mod2)
- hidden_states = hidden_states + img_gate2 * self.img_mlp(img_modulated2)
+ hidden_states = torch.addcmul(hidden_states, img_gate2, self.img_mlp(img_modulated2))
txt_normed2 = self.txt_norm2(encoder_hidden_states)
txt_modulated2, txt_gate2 = self._modulate(txt_normed2, txt_mod2)
- encoder_hidden_states = encoder_hidden_states + txt_gate2 * self.txt_mlp(txt_modulated2)
+ encoder_hidden_states = torch.addcmul(encoder_hidden_states, txt_gate2, self.txt_mlp(txt_modulated2))
return encoder_hidden_states, hidden_states
@@ -275,7 +279,7 @@ class LastLayer(nn.Module):
def forward(self, x: torch.Tensor, conditioning_embedding: torch.Tensor) -> torch.Tensor:
emb = self.linear(self.silu(conditioning_embedding))
scale, shift = torch.chunk(emb, 2, dim=1)
- x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
+ x = torch.addcmul(shift[:, None, :], self.norm(x), (1 + scale)[:, None, :])
return x
@@ -293,6 +297,7 @@ class QwenImageTransformer2DModel(nn.Module):
guidance_embeds: bool = False,
axes_dims_rope: Tuple[int, int, int] = (16, 56, 56),
image_model=None,
+ final_layer=True,
dtype=None,
device=None,
operations=None,
@@ -300,6 +305,7 @@ class QwenImageTransformer2DModel(nn.Module):
super().__init__()
self.dtype = dtype
self.patch_size = patch_size
+ self.in_channels = in_channels
self.out_channels = out_channels or in_channels
self.inner_dim = num_attention_heads * attention_head_dim
@@ -329,9 +335,9 @@ class QwenImageTransformer2DModel(nn.Module):
for _ in range(num_layers)
])
- self.norm_out = LastLayer(self.inner_dim, self.inner_dim, dtype=dtype, device=device, operations=operations)
- self.proj_out = operations.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True, dtype=dtype, device=device)
- self.gradient_checkpointing = False
+ if final_layer:
+ self.norm_out = LastLayer(self.inner_dim, self.inner_dim, dtype=dtype, device=device, operations=operations)
+ self.proj_out = operations.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True, dtype=dtype, device=device)
def process_img(self, x, index=0, h_offset=0, w_offset=0):
bs, c, t, h, w = x.shape
@@ -347,13 +353,20 @@ class QwenImageTransformer2DModel(nn.Module):
h_offset = ((h_offset + (patch_size // 2)) // patch_size)
w_offset = ((w_offset + (patch_size // 2)) // patch_size)
- img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
+ img_ids = torch.zeros((h_len, w_len, 3), device=x.device)
img_ids[:, :, 0] = img_ids[:, :, 1] + index
- img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1)
- img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
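+ # center the spatial position ids by subtracting half the grid size in each axis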
+ img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1) - (h_len // 2)
+ img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0) - (w_len // 2)
return hidden_states, repeat(img_ids, "h w c -> b (h w) c", b=bs), orig_shape
- def forward(
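+ # forward() runs _forward() through any DIFFUSION_MODEL wrappers registered in transformer_options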
+ def forward(self, x, timestep, context, attention_mask=None, guidance=None, ref_latents=None, transformer_options={}, **kwargs):
+ return comfy.patcher_extension.WrapperExecutor.new_class_executor(
+ self._forward,
+ self,
+ comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
+ ).execute(x, timestep, context, attention_mask, guidance, ref_latents, transformer_options, **kwargs)
+
+ def _forward(
self,
x,
timesteps,
@@ -362,6 +375,7 @@ class QwenImageTransformer2DModel(nn.Module):
guidance: torch.Tensor = None,
ref_latents=None,
transformer_options={},
+ control=None,
**kwargs
):
timestep = timesteps
@@ -396,10 +410,11 @@ class QwenImageTransformer2DModel(nn.Module):
hidden_states = torch.cat([hidden_states, kontext], dim=1)
img_ids = torch.cat([img_ids, kontext_ids], dim=1)
- txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size), ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size)))
- txt_ids = torch.linspace(txt_start, txt_start + context.shape[1], steps=context.shape[1], device=x.device, dtype=x.dtype).reshape(1, -1, 1).repeat(x.shape[0], 1, 3)
+ txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size) // 2, ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size) // 2))
+ txt_ids = torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 3)
ids = torch.cat((txt_ids, img_ids), dim=1)
image_rotary_emb = self.pe_embedder(ids).squeeze(1).unsqueeze(2).to(x.dtype)
+ del ids, txt_ids, img_ids
hidden_states = self.img_in(hidden_states)
encoder_hidden_states = self.txt_norm(encoder_hidden_states)
@@ -415,15 +430,16 @@ class QwenImageTransformer2DModel(nn.Module):
)
patches_replace = transformer_options.get("patches_replace", {})
+ patches = transformer_options.get("patches", {})
blocks_replace = patches_replace.get("dit", {})
for i, block in enumerate(self.transformer_blocks):
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
- out["txt"], out["img"] = block(hidden_states=args["img"], encoder_hidden_states=args["txt"], encoder_hidden_states_mask=encoder_hidden_states_mask, temb=args["vec"], image_rotary_emb=args["pe"])
+ out["txt"], out["img"] = block(hidden_states=args["img"], encoder_hidden_states=args["txt"], encoder_hidden_states_mask=encoder_hidden_states_mask, temb=args["vec"], image_rotary_emb=args["pe"], transformer_options=args["transformer_options"])
return out
- out = blocks_replace[("double_block", i)]({"img": hidden_states, "txt": encoder_hidden_states, "vec": temb, "pe": image_rotary_emb}, {"original_block": block_wrap})
+ out = blocks_replace[("double_block", i)]({"img": hidden_states, "txt": encoder_hidden_states, "vec": temb, "pe": image_rotary_emb, "transformer_options": transformer_options}, {"original_block": block_wrap})
hidden_states = out["img"]
encoder_hidden_states = out["txt"]
else:
@@ -433,8 +449,22 @@ class QwenImageTransformer2DModel(nn.Module):
encoder_hidden_states_mask=encoder_hidden_states_mask,
temb=temb,
image_rotary_emb=image_rotary_emb,
+ transformer_options=transformer_options,
)
+ if "double_block" in patches:
+ for p in patches["double_block"]:
+ out = p({"img": hidden_states, "txt": encoder_hidden_states, "x": x, "block_index": i, "transformer_options": transformer_options})
+ hidden_states = out["img"]
+ encoder_hidden_states = out["txt"]
+
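+ # add the controlnet residuals (the "input" list from QwenImageControlNetModel) to the image tokens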
+ if control is not None: # Controlnet
+ control_i = control.get("input")
+ if i < len(control_i):
+ add = control_i[i]
+ if add is not None:
+ hidden_states[:, :add.shape[1]] += add
+
hidden_states = self.norm_out(hidden_states, temb)
hidden_states = self.proj_out(hidden_states)
diff --git a/comfy/ldm/wan/model.py b/comfy/ldm/wan/model.py
index 9d3741be3..9cf3c171d 100644
--- a/comfy/ldm/wan/model.py
+++ b/comfy/ldm/wan/model.py
@@ -4,13 +4,14 @@ import math
import torch
import torch.nn as nn
-from einops import repeat
+from einops import rearrange
from comfy.ldm.modules.attention import optimized_attention
from comfy.ldm.flux.layers import EmbedND
-from comfy.ldm.flux.math import apply_rope
+from comfy.ldm.flux.math import apply_rope1
import comfy.ldm.common_dit
import comfy.model_management
+import comfy.patcher_extension
def sinusoidal_embedding_1d(dim, position):
@@ -33,7 +34,9 @@ class WanSelfAttention(nn.Module):
num_heads,
window_size=(-1, -1),
qk_norm=True,
- eps=1e-6, operation_settings={}):
+ eps=1e-6,
+ kv_dim=None,
+ operation_settings={}):
assert dim % num_heads == 0
super().__init__()
self.dim = dim
@@ -42,16 +45,18 @@ class WanSelfAttention(nn.Module):
self.window_size = window_size
self.qk_norm = qk_norm
self.eps = eps
+ if kv_dim is None:
+ kv_dim = dim
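+ # kv_dim lets subclasses project keys/values from a different feature size (e.g. 1536-dim audio tokens)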
# layers
self.q = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
- self.k = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
- self.v = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
+ self.k = operation_settings.get("operations").Linear(kv_dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
+ self.v = operation_settings.get("operations").Linear(kv_dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
self.o = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
self.norm_q = operation_settings.get("operations").RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity()
self.norm_k = operation_settings.get("operations").RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity()
- def forward(self, x, freqs):
+ def forward(self, x, freqs, transformer_options={}):
r"""
Args:
x(Tensor): Shape [B, L, num_heads, C / num_heads]
@@ -59,21 +64,26 @@ class WanSelfAttention(nn.Module):
"""
b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
- # query, key, value function
- def qkv_fn(x):
+ def qkv_fn_q(x):
q = self.norm_q(self.q(x)).view(b, s, n, d)
- k = self.norm_k(self.k(x)).view(b, s, n, d)
- v = self.v(x).view(b, s, n * d)
- return q, k, v
+ return apply_rope1(q, freqs)
- q, k, v = qkv_fn(x)
- q, k = apply_rope(q, k, freqs)
+ def qkv_fn_k(x):
+ k = self.norm_k(self.k(x)).view(b, s, n, d)
+ return apply_rope1(k, freqs)
+
+ # These two are VRAM hogs, so do all of the q computation first and let
+ # pytorch garbage collect the intermediates when the sub function
+ # returns, before we touch k
+ q = qkv_fn_q(x)
+ k = qkv_fn_k(x)
x = optimized_attention(
q.view(b, s, n * d),
k.view(b, s, n * d),
- v,
+ self.v(x).view(b, s, n * d),
heads=self.num_heads,
+ transformer_options=transformer_options,
)
x = self.o(x)
@@ -82,7 +92,7 @@ class WanSelfAttention(nn.Module):
class WanT2VCrossAttention(WanSelfAttention):
- def forward(self, x, context, **kwargs):
+ def forward(self, x, context, transformer_options={}, **kwargs):
r"""
Args:
x(Tensor): Shape [B, L1, C]
@@ -94,7 +104,7 @@ class WanT2VCrossAttention(WanSelfAttention):
v = self.v(context)
# compute attention
- x = optimized_attention(q, k, v, heads=self.num_heads)
+ x = optimized_attention(q, k, v, heads=self.num_heads, transformer_options=transformer_options)
x = self.o(x)
return x
@@ -115,7 +125,7 @@ class WanI2VCrossAttention(WanSelfAttention):
# self.alpha = nn.Parameter(torch.zeros((1, )))
self.norm_k_img = operation_settings.get("operations").RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity()
- def forward(self, x, context, context_img_len):
+ def forward(self, x, context, context_img_len, transformer_options={}):
r"""
Args:
x(Tensor): Shape [B, L1, C]
@@ -130,9 +140,9 @@ class WanI2VCrossAttention(WanSelfAttention):
v = self.v(context)
k_img = self.norm_k_img(self.k_img(context_img))
v_img = self.v_img(context_img)
- img_x = optimized_attention(q, k_img, v_img, heads=self.num_heads)
+ img_x = optimized_attention(q, k_img, v_img, heads=self.num_heads, transformer_options=transformer_options)
# compute attention
- x = optimized_attention(q, k, v, heads=self.num_heads)
+ x = optimized_attention(q, k, v, heads=self.num_heads, transformer_options=transformer_options)
# output
x = x + img_x
@@ -148,11 +158,14 @@ WAN_CROSSATTENTION_CLASSES = {
def repeat_e(e, x):
repeats = 1
- if e.shape[1] > 1:
- repeats = x.shape[1] // e.shape[1]
+ if e.size(1) > 1:
+ repeats = x.size(1) // e.size(1)
if repeats == 1:
return e
- return torch.repeat_interleave(e, repeats, dim=1)
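+ # if the lengths do not divide evenly, repeat once more and truncate to x's length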
+ if repeats * e.size(1) == x.size(1):
+ return torch.repeat_interleave(e, repeats, dim=1)
+ else:
+ return torch.repeat_interleave(e, repeats + 1, dim=1)[:, :x.size(1)]
class WanAttentionBlock(nn.Module):
@@ -202,6 +215,7 @@ class WanAttentionBlock(nn.Module):
freqs,
context,
context_img_len=257,
+ transformer_options={},
):
r"""
Args:
@@ -219,15 +233,15 @@ class WanAttentionBlock(nn.Module):
# self-attention
y = self.self_attn(
- self.norm1(x) * (1 + repeat_e(e[1], x)) + repeat_e(e[0], x),
- freqs)
+ torch.addcmul(repeat_e(e[0], x), self.norm1(x), 1 + repeat_e(e[1], x)),
+ freqs, transformer_options=transformer_options)
- x = x + y * repeat_e(e[2], x)
+ x = torch.addcmul(x, y, repeat_e(e[2], x))
# cross-attention & ffn
- x = x + self.cross_attn(self.norm3(x), context, context_img_len=context_img_len)
- y = self.ffn(self.norm2(x) * (1 + repeat_e(e[4], x)) + repeat_e(e[3], x))
- x = x + y * repeat_e(e[5], x)
+ x = x + self.cross_attn(self.norm3(x), context, context_img_len=context_img_len, transformer_options=transformer_options)
+ y = self.ffn(torch.addcmul(repeat_e(e[3], x), self.norm2(x), 1 + repeat_e(e[4], x)))
+ x = torch.addcmul(x, y, repeat_e(e[5], x))
return x
@@ -342,7 +356,7 @@ class Head(nn.Module):
else:
e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device).unsqueeze(0) + e.unsqueeze(2)).unbind(2)
- x = (self.head(self.norm(x) * (1 + repeat_e(e[1], x)) + repeat_e(e[0], x)))
+ x = (self.head(torch.addcmul(repeat_e(e[0], x), self.norm(x), 1 + repeat_e(e[1], x))))
return x
@@ -392,6 +406,7 @@ class WanModel(torch.nn.Module):
eps=1e-6,
flf_pos_embed_token_number=None,
in_dim_ref_conv=None,
+ wan_attn_block_class=WanAttentionBlock,
image_model=None,
device=None,
dtype=None,
@@ -469,8 +484,8 @@ class WanModel(torch.nn.Module):
# blocks
cross_attn_type = 't2v_cross_attn' if model_type == 't2v' else 'i2v_cross_attn'
self.blocks = nn.ModuleList([
- WanAttentionBlock(cross_attn_type, dim, ffn_dim, num_heads,
- window_size, qk_norm, cross_attn_norm, eps, operation_settings=operation_settings)
+ wan_attn_block_class(cross_attn_type, dim, ffn_dim, num_heads,
+ window_size, qk_norm, cross_attn_norm, eps, operation_settings=operation_settings)
for _ in range(num_layers)
])
@@ -555,12 +570,12 @@ class WanModel(torch.nn.Module):
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
- out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len)
+ out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len, transformer_options=args["transformer_options"])
return out
- out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs}, {"original_block": block_wrap})
+ out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs, "transformer_options": transformer_options}, {"original_block": block_wrap})
x = out["img"]
else:
- x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len)
+ x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len, transformer_options=transformer_options)
# head
x = self.head(x, e)
@@ -572,30 +587,49 @@ class WanModel(torch.nn.Module):
x = self.unpatchify(x, grid_sizes)
return x
- def forward(self, x, timestep, context, clip_fea=None, time_dim_concat=None, transformer_options={}, **kwargs):
- bs, c, t, h, w = x.shape
- x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size)
-
+ def rope_encode(self, t, h, w, t_start=0, steps_t=None, steps_h=None, steps_w=None, device=None, dtype=None):
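+ # build (t, h, w) position ids and run them through the rope embedder;
+ # t_start offsets the temporal axis, steps_* allow a coarser sampling of the grid (used for packed motion latents)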
patch_size = self.patch_size
t_len = ((t + (patch_size[0] // 2)) // patch_size[0])
h_len = ((h + (patch_size[1] // 2)) // patch_size[1])
w_len = ((w + (patch_size[2] // 2)) // patch_size[2])
+ if steps_t is None:
+ steps_t = t_len
+ if steps_h is None:
+ steps_h = h_len
+ if steps_w is None:
+ steps_w = w_len
+
+ img_ids = torch.zeros((steps_t, steps_h, steps_w, 3), device=device, dtype=dtype)
+ img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + torch.linspace(t_start, t_start + (t_len - 1), steps=steps_t, device=device, dtype=dtype).reshape(-1, 1, 1)
+ img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(0, h_len - 1, steps=steps_h, device=device, dtype=dtype).reshape(1, -1, 1)
+ img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(0, w_len - 1, steps=steps_w, device=device, dtype=dtype).reshape(1, 1, -1)
+ img_ids = img_ids.reshape(1, -1, img_ids.shape[-1])
+
+ freqs = self.rope_embedder(img_ids).movedim(1, 2)
+ return freqs
+
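+ # forward() delegates to _forward() through any registered DIFFUSION_MODEL wrappers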
+ def forward(self, x, timestep, context, clip_fea=None, time_dim_concat=None, transformer_options={}, **kwargs):
+ return comfy.patcher_extension.WrapperExecutor.new_class_executor(
+ self._forward,
+ self,
+ comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
+ ).execute(x, timestep, context, clip_fea, time_dim_concat, transformer_options, **kwargs)
+
+ def _forward(self, x, timestep, context, clip_fea=None, time_dim_concat=None, transformer_options={}, **kwargs):
+ bs, c, t, h, w = x.shape
+ x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size)
+
+ t_len = t
if time_dim_concat is not None:
time_dim_concat = comfy.ldm.common_dit.pad_to_patch_size(time_dim_concat, self.patch_size)
x = torch.cat([x, time_dim_concat], dim=2)
- t_len = ((x.shape[2] + (patch_size[0] // 2)) // patch_size[0])
+ t_len = x.shape[2]
if self.ref_conv is not None and "reference_latent" in kwargs:
t_len += 1
- img_ids = torch.zeros((t_len, h_len, w_len, 3), device=x.device, dtype=x.dtype)
- img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + torch.linspace(0, t_len - 1, steps=t_len, device=x.device, dtype=x.dtype).reshape(-1, 1, 1)
- img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).reshape(1, -1, 1)
- img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).reshape(1, 1, -1)
- img_ids = repeat(img_ids, "t h w c -> b (t h w) c", b=bs)
-
- freqs = self.rope_embedder(img_ids).movedim(1, 2)
+ freqs = self.rope_encode(t_len, h, w, device=x.device, dtype=x.dtype)
return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs, transformer_options=transformer_options, **kwargs)[:, :, :t, :h, :w]
def unpatchify(self, x, grid_sizes):
@@ -719,17 +753,17 @@ class VaceWanModel(WanModel):
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
- out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len)
+ out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len, transformer_options=args["transformer_options"])
return out
- out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs}, {"original_block": block_wrap})
+ out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs, "transformer_options": transformer_options}, {"original_block": block_wrap})
x = out["img"]
else:
- x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len)
+ x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len, transformer_options=transformer_options)
ii = self.vace_layers_mapping.get(i, None)
if ii is not None:
for iii in range(len(c)):
- c_skip, c[iii] = self.vace_blocks[ii](c[iii], x=x_orig, e=e0, freqs=freqs, context=context, context_img_len=context_img_len)
+ c_skip, c[iii] = self.vace_blocks[ii](c[iii], x=x_orig, e=e0, freqs=freqs, context=context, context_img_len=context_img_len, transformer_options=transformer_options)
x += c_skip * vace_strength[iii]
del c_skip
# head
@@ -818,12 +852,721 @@ class CameraWanModel(WanModel):
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
- out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len)
+ out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len, transformer_options=args["transformer_options"])
return out
- out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs}, {"original_block": block_wrap})
+ out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs, "transformer_options": transformer_options}, {"original_block": block_wrap})
x = out["img"]
else:
- x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len)
+ x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len, transformer_options=transformer_options)
+
+ # head
+ x = self.head(x, e)
+
+ # unpatchify
+ x = self.unpatchify(x, grid_sizes)
+ return x
+
+
+class CausalConv1d(nn.Module):
+
+ def __init__(self,
+ chan_in,
+ chan_out,
+ kernel_size=3,
+ stride=1,
+ dilation=1,
+ pad_mode='replicate',
+ operations=None,
+ **kwargs):
+ super().__init__()
+
+ self.pad_mode = pad_mode
+ padding = (kernel_size - 1, 0) # T
+ self.time_causal_padding = padding
+
+ self.conv = operations.Conv1d(
+ chan_in,
+ chan_out,
+ kernel_size,
+ stride=stride,
+ dilation=dilation,
+ **kwargs)
+
+ def forward(self, x):
+ x = torch.nn.functional.pad(x, self.time_causal_padding, mode=self.pad_mode)
+ return self.conv(x)
+
+
+class MotionEncoder_tc(nn.Module):
+
+ def __init__(self,
+ in_dim: int,
+ hidden_dim: int,
+ num_heads=int,
+ need_global=True,
+ dtype=None,
+ device=None,
+ operations=None,):
+ factory_kwargs = {"dtype": dtype, "device": device}
+ super().__init__()
+
+ self.num_heads = num_heads
+ self.need_global = need_global
+ self.conv1_local = CausalConv1d(in_dim, hidden_dim // 4 * num_heads, 3, stride=1, operations=operations, **factory_kwargs)
+ if need_global:
+ self.conv1_global = CausalConv1d(
+ in_dim, hidden_dim // 4, 3, stride=1, operations=operations, **factory_kwargs)
+ self.norm1 = operations.LayerNorm(
+ hidden_dim // 4,
+ elementwise_affine=False,
+ eps=1e-6,
+ **factory_kwargs)
+ self.act = nn.SiLU()
+ self.conv2 = CausalConv1d(hidden_dim // 4, hidden_dim // 2, 3, stride=2, operations=operations, **factory_kwargs)
+ self.conv3 = CausalConv1d(hidden_dim // 2, hidden_dim, 3, stride=2, operations=operations, **factory_kwargs)
+
+ if need_global:
+ self.final_linear = operations.Linear(hidden_dim, hidden_dim, **factory_kwargs)
+
+ self.norm1 = operations.LayerNorm(
+ hidden_dim // 4,
+ elementwise_affine=False,
+ eps=1e-6,
+ **factory_kwargs)
+
+ self.norm2 = operations.LayerNorm(
+ hidden_dim // 2,
+ elementwise_affine=False,
+ eps=1e-6,
+ **factory_kwargs)
+
+ self.norm3 = operations.LayerNorm(
+ hidden_dim, elementwise_affine=False, eps=1e-6, **factory_kwargs)
+
+ self.padding_tokens = nn.Parameter(torch.empty(1, 1, 1, hidden_dim, **factory_kwargs))
+
+ def forward(self, x):
+ x = rearrange(x, 'b t c -> b c t')
+ x_ori = x.clone()
+ b, c, t = x.shape
+ x = self.conv1_local(x)
+ x = rearrange(x, 'b (n c) t -> (b n) t c', n=self.num_heads)
+ x = self.norm1(x)
+ x = self.act(x)
+ x = rearrange(x, 'b t c -> b c t')
+ x = self.conv2(x)
+ x = rearrange(x, 'b c t -> b t c')
+ x = self.norm2(x)
+ x = self.act(x)
+ x = rearrange(x, 'b t c -> b c t')
+ x = self.conv3(x)
+ x = rearrange(x, 'b c t -> b t c')
+ x = self.norm3(x)
+ x = self.act(x)
+ x = rearrange(x, '(b n) t c -> b t n c', b=b)
+ padding = comfy.model_management.cast_to(self.padding_tokens, dtype=x.dtype, device=x.device).repeat(b, x.shape[1], 1, 1)
+ x = torch.cat([x, padding], dim=-2)
+ x_local = x.clone()
+
+ if not self.need_global:
+ return x_local
+
+ x = self.conv1_global(x_ori)
+ x = rearrange(x, 'b c t -> b t c')
+ x = self.norm1(x)
+ x = self.act(x)
+ x = rearrange(x, 'b t c -> b c t')
+ x = self.conv2(x)
+ x = rearrange(x, 'b c t -> b t c')
+ x = self.norm2(x)
+ x = self.act(x)
+ x = rearrange(x, 'b t c -> b c t')
+ x = self.conv3(x)
+ x = rearrange(x, 'b c t -> b t c')
+ x = self.norm3(x)
+ x = self.act(x)
+ x = self.final_linear(x)
+ x = rearrange(x, '(b n) t c -> b t n c', b=b)
+
+ return x, x_local
+
+
+class CausalAudioEncoder(nn.Module):
+
+ def __init__(self,
+ dim=5120,
+ num_layers=25,
+ out_dim=2048,
+ video_rate=8,
+ num_token=4,
+ need_global=False,
+ dtype=None,
+ device=None,
+ operations=None):
+ super().__init__()
+ self.encoder = MotionEncoder_tc(
+ in_dim=dim,
+ hidden_dim=out_dim,
+ num_heads=num_token,
+ need_global=need_global, dtype=dtype, device=device, operations=operations)
+ weight = torch.empty((1, num_layers, 1, 1), dtype=dtype, device=device)
+
+ self.weights = torch.nn.Parameter(weight)
+ self.act = torch.nn.SiLU()
+
+ def forward(self, features):
+ # features B * num_layers * dim * video_length
+ weights = self.act(comfy.model_management.cast_to(self.weights, dtype=features.dtype, device=features.device))
+ weights_sum = weights.sum(dim=1, keepdims=True)
+ weighted_feat = ((features * weights) / weights_sum).sum(
+ dim=1) # b dim f
+ weighted_feat = weighted_feat.permute(0, 2, 1) # b f dim
+ res = self.encoder(weighted_feat) # b f n dim
+ return res # b f n dim
+
+
+class AdaLayerNorm(nn.Module):
+ def __init__(self, embedding_dim, output_dim=None, norm_elementwise_affine=False, norm_eps=1e-5, dtype=None, device=None, operations=None):
+ super().__init__()
+
+ output_dim = output_dim or embedding_dim * 2
+
+ self.silu = nn.SiLU()
+ self.linear = operations.Linear(embedding_dim, output_dim, dtype=dtype, device=device)
+ self.norm = operations.LayerNorm(output_dim // 2, norm_eps, norm_elementwise_affine, dtype=dtype, device=device)
+
+ def forward(self, x, temb):
+ temb = self.linear(self.silu(temb))
+ shift, scale = temb.chunk(2, dim=1)
+ shift = shift[:, None, :]
+ scale = scale[:, None, :]
+ x = self.norm(x) * (1 + scale) + shift
+ return x
+
+
+class AudioInjector_WAN(nn.Module):
+
+ def __init__(self,
+ dim=2048,
+ num_heads=32,
+ inject_layer=[0, 27],
+ root_net=None,
+ enable_adain=False,
+ adain_dim=2048,
+ adain_mode=None,
+ dtype=None,
+ device=None,
+ operations=None):
+ super().__init__()
+ self.enable_adain = enable_adain
+ self.adain_mode = adain_mode
+ self.injected_block_id = {}
+ audio_injector_id = 0
+ for inject_id in inject_layer:
+ self.injected_block_id[inject_id] = audio_injector_id
+ audio_injector_id += 1
+
+ self.injector = nn.ModuleList([
+ WanT2VCrossAttention(
+ dim=dim,
+ num_heads=num_heads,
+ qk_norm=True, operation_settings={"operations": operations, "device": device, "dtype": dtype}
+ ) for _ in range(audio_injector_id)
+ ])
+ self.injector_pre_norm_feat = nn.ModuleList([
+ operations.LayerNorm(
+ dim,
+ elementwise_affine=False,
+ eps=1e-6, dtype=dtype, device=device
+ ) for _ in range(audio_injector_id)
+ ])
+ self.injector_pre_norm_vec = nn.ModuleList([
+ operations.LayerNorm(
+ dim,
+ elementwise_affine=False,
+ eps=1e-6, dtype=dtype, device=device
+ ) for _ in range(audio_injector_id)
+ ])
+ if enable_adain:
+ self.injector_adain_layers = nn.ModuleList([
+ AdaLayerNorm(
+ output_dim=dim * 2, embedding_dim=adain_dim, dtype=dtype, device=device, operations=operations)
+ for _ in range(audio_injector_id)
+ ])
+ if adain_mode != "attn_norm":
+ self.injector_adain_output_layers = nn.ModuleList(
+ [operations.Linear(dim, dim, dtype=dtype, device=device) for _ in range(audio_injector_id)])
+
+ def forward(self, x, block_id, audio_emb, audio_emb_global, seq_len):
+ audio_attn_id = self.injected_block_id.get(block_id, None)
+ if audio_attn_id is None:
+ return x
+
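+ # regroup video tokens per frame so each frame cross-attends only to its own audio tokens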
+ num_frames = audio_emb.shape[1]
+ input_hidden_states = rearrange(x[:, :seq_len], "b (t n) c -> (b t) n c", t=num_frames)
+ if self.enable_adain and self.adain_mode == "attn_norm":
+ audio_emb_global = rearrange(audio_emb_global, "b t n c -> (b t) n c")
+ adain_hidden_states = self.injector_adain_layers[audio_attn_id](input_hidden_states, temb=audio_emb_global[:, 0])
+ attn_hidden_states = adain_hidden_states
+ else:
+ attn_hidden_states = self.injector_pre_norm_feat[audio_attn_id](input_hidden_states)
+ audio_emb = rearrange(audio_emb, "b t n c -> (b t) n c", t=num_frames)
+ attn_audio_emb = audio_emb
+ residual_out = self.injector[audio_attn_id](x=attn_hidden_states, context=attn_audio_emb)
+ residual_out = rearrange(
+ residual_out, "(b t) n c -> b (t n) c", t=num_frames)
+ x[:, :seq_len] = x[:, :seq_len] + residual_out
+ return x
+
+
+class FramePackMotioner(nn.Module):
+ def __init__(
+ self,
+ inner_dim=1024,
+ num_heads=16, # Used to indicate the number of heads in the backbone network; unrelated to this module's design
+ zip_frame_buckets=[
+ 1, 2, 16
+ ], # number of frames sampled for patching in each bucket, from the nearest to the farthest frames
+ drop_mode="drop", # anything other than "drop" is treated as "padd": pad instead of deleting
+ dtype=None,
+ device=None,
+ operations=None):
+ super().__init__()
+ self.proj = operations.Conv3d(16, inner_dim, kernel_size=(1, 2, 2), stride=(1, 2, 2), dtype=dtype, device=device)
+ self.proj_2x = operations.Conv3d(16, inner_dim, kernel_size=(2, 4, 4), stride=(2, 4, 4), dtype=dtype, device=device)
+ self.proj_4x = operations.Conv3d(16, inner_dim, kernel_size=(4, 8, 8), stride=(4, 8, 8), dtype=dtype, device=device)
+ self.zip_frame_buckets = zip_frame_buckets
+
+ self.inner_dim = inner_dim
+ self.num_heads = num_heads
+
+ self.drop_mode = drop_mode
+
+ def forward(self, motion_latents, rope_embedder, add_last_motion=2):
+ lat_height, lat_width = motion_latents.shape[3], motion_latents.shape[4]
+ padd_lat = torch.zeros(motion_latents.shape[0], 16, sum(self.zip_frame_buckets), lat_height, lat_width).to(device=motion_latents.device, dtype=motion_latents.dtype)
+ overlap_frame = min(padd_lat.shape[2], motion_latents.shape[2])
+ if overlap_frame > 0:
+ padd_lat[:, :, -overlap_frame:] = motion_latents[:, :, -overlap_frame:]
+
+ if add_last_motion < 2 and self.drop_mode != "drop":
+ zero_end_frame = sum(self.zip_frame_buckets[:len(self.zip_frame_buckets) - add_last_motion - 1])
+ padd_lat[:, :, -zero_end_frame:] = 0
+
+ clean_latents_4x, clean_latents_2x, clean_latents_post = padd_lat[:, :, -sum(self.zip_frame_buckets):, :, :].split(self.zip_frame_buckets[::-1], dim=2) # 16, 2, 1 frames (farthest to nearest)
+
+ # patchify
+ clean_latents_post = self.proj(clean_latents_post).flatten(2).transpose(1, 2)
+ clean_latents_2x = self.proj_2x(clean_latents_2x)
+ l_2x_shape = clean_latents_2x.shape
+ clean_latents_2x = clean_latents_2x.flatten(2).transpose(1, 2)
+ clean_latents_4x = self.proj_4x(clean_latents_4x)
+ l_4x_shape = clean_latents_4x.shape
+ clean_latents_4x = clean_latents_4x.flatten(2).transpose(1, 2)
+
+ if add_last_motion < 2 and self.drop_mode == "drop":
+ clean_latents_post = clean_latents_post[:, :0] if add_last_motion < 2 else clean_latents_post
+ clean_latents_2x = clean_latents_2x[:, :0] if add_last_motion < 1 else clean_latents_2x
+
+ motion_lat = torch.cat([clean_latents_post, clean_latents_2x, clean_latents_4x], dim=1)
+
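+ # negative t_start values place the packed motion tokens before the generated clip; steps_h/steps_w match the 2x/4x downsampled grids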
+ rope_post = rope_embedder.rope_encode(1, lat_height, lat_width, t_start=-1, device=motion_latents.device, dtype=motion_latents.dtype)
+ rope_2x = rope_embedder.rope_encode(1, lat_height, lat_width, t_start=-3, steps_h=l_2x_shape[-2], steps_w=l_2x_shape[-1], device=motion_latents.device, dtype=motion_latents.dtype)
+ rope_4x = rope_embedder.rope_encode(4, lat_height, lat_width, t_start=-19, steps_h=l_4x_shape[-2], steps_w=l_4x_shape[-1], device=motion_latents.device, dtype=motion_latents.dtype)
+
+ rope = torch.cat([rope_post, rope_2x, rope_4x], dim=1)
+ return motion_lat, rope
+
+
+class WanModel_S2V(WanModel):
+ def __init__(self,
+ model_type='s2v',
+ patch_size=(1, 2, 2),
+ text_len=512,
+ in_dim=16,
+ dim=2048,
+ ffn_dim=8192,
+ freq_dim=256,
+ text_dim=4096,
+ out_dim=16,
+ num_heads=16,
+ num_layers=32,
+ window_size=(-1, -1),
+ qk_norm=True,
+ cross_attn_norm=True,
+ eps=1e-6,
+ audio_dim=1024,
+ num_audio_token=4,
+ enable_adain=True,
+ cond_dim=16,
+ audio_inject_layers=[0, 4, 8, 12, 16, 20, 24, 27, 30, 33, 36, 39],
+ adain_mode="attn_norm",
+ framepack_drop_mode="padd",
+ image_model=None,
+ device=None,
+ dtype=None,
+ operations=None,
+ ):
+
+ super().__init__(model_type='t2v', patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, image_model=image_model, device=device, dtype=dtype, operations=operations)
+
+ self.trainable_cond_mask = operations.Embedding(3, self.dim, device=device, dtype=dtype)
+
+ self.casual_audio_encoder = CausalAudioEncoder(
+ dim=audio_dim,
+ out_dim=self.dim,
+ num_token=num_audio_token,
+ need_global=enable_adain, dtype=dtype, device=device, operations=operations)
+
+ if cond_dim > 0:
+ self.cond_encoder = operations.Conv3d(
+ cond_dim,
+ self.dim,
+ kernel_size=self.patch_size,
+ stride=self.patch_size, device=device, dtype=dtype)
+
+ self.audio_injector = AudioInjector_WAN(
+ dim=self.dim,
+ num_heads=self.num_heads,
+ inject_layer=audio_inject_layers,
+ root_net=self,
+ enable_adain=enable_adain,
+ adain_dim=self.dim,
+ adain_mode=adain_mode,
+ dtype=dtype, device=device, operations=operations
+ )
+
+ self.frame_packer = FramePackMotioner(
+ inner_dim=self.dim,
+ num_heads=self.num_heads,
+ zip_frame_buckets=[1, 2, 16],
+ drop_mode=framepack_drop_mode,
+ dtype=dtype, device=device, operations=operations)
+
+ def forward_orig(
+ self,
+ x,
+ t,
+ context,
+ audio_embed=None,
+ reference_latent=None,
+ control_video=None,
+ reference_motion=None,
+ clip_fea=None,
+ freqs=None,
+ transformer_options={},
+ **kwargs,
+ ):
+ if audio_embed is not None:
+ num_embeds = x.shape[-3] * 4
+ audio_emb_global, audio_emb = self.casual_audio_encoder(audio_embed[:, :, :, :num_embeds])
+ else:
+ audio_emb = None
+
+ # embeddings
+ bs, _, time, height, width = x.shape
+ x = self.patch_embedding(x.float()).to(x.dtype)
+ if control_video is not None:
+ x = x + self.cond_encoder(control_video)
+
+ if t.ndim == 1:
+ t = t.unsqueeze(1).repeat(1, x.shape[2])
+
+ grid_sizes = x.shape[2:]
+ x = x.flatten(2).transpose(1, 2)
+ seq_len = x.size(1)
+
+ cond_mask_weight = comfy.model_management.cast_to(self.trainable_cond_mask.weight, dtype=x.dtype, device=x.device).unsqueeze(1).unsqueeze(1)
+ x = x + cond_mask_weight[0]
+
+ if reference_latent is not None:
+ ref = self.patch_embedding(reference_latent.float()).to(x.dtype)
+ ref = ref.flatten(2).transpose(1, 2)
+ freqs_ref = self.rope_encode(reference_latent.shape[-3], reference_latent.shape[-2], reference_latent.shape[-1], t_start=max(30, time + 9), device=x.device, dtype=x.dtype)
+ ref = ref + cond_mask_weight[1]
+ x = torch.cat([x, ref], dim=1)
+ freqs = torch.cat([freqs, freqs_ref], dim=1)
+ t = torch.cat([t, torch.zeros((t.shape[0], reference_latent.shape[-3]), device=t.device, dtype=t.dtype)], dim=1)
+ del ref, freqs_ref
+
+ if reference_motion is not None:
+ motion_encoded, freqs_motion = self.frame_packer(reference_motion, self)
+ motion_encoded = motion_encoded + cond_mask_weight[2]
+ x = torch.cat([x, motion_encoded], dim=1)
+ freqs = torch.cat([freqs, freqs_motion], dim=1)
+
+ t = torch.repeat_interleave(t, 2, dim=1)
+ t = torch.cat([t, torch.zeros((t.shape[0], 3), device=t.device, dtype=t.dtype)], dim=1)
+ del motion_encoded, freqs_motion
+
+ # time embeddings
+ e = self.time_embedding(
+ sinusoidal_embedding_1d(self.freq_dim, t.flatten()).to(dtype=x[0].dtype))
+ e = e.reshape(t.shape[0], -1, e.shape[-1])
+ e0 = self.time_projection(e).unflatten(2, (6, self.dim))
+
+ # context
+ context = self.text_embedding(context)
+
+ patches_replace = transformer_options.get("patches_replace", {})
+ blocks_replace = patches_replace.get("dit", {})
+ for i, block in enumerate(self.blocks):
+ if ("double_block", i) in blocks_replace:
+ def block_wrap(args):
+ out = {}
+ out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"])
+ return out
+ out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs}, {"original_block": block_wrap})
+ x = out["img"]
+ else:
+ x = block(x, e=e0, freqs=freqs, context=context)
+ if audio_emb is not None:
+ x = self.audio_injector(x, i, audio_emb, audio_emb_global, seq_len)
+ # head
+ x = self.head(x, e)
+
+ # unpatchify
+ x = self.unpatchify(x, grid_sizes)
+ return x
+
+
+class WanT2VCrossAttentionGather(WanSelfAttention):
+
+ def forward(self, x, context, transformer_options={}, **kwargs):
+ r"""
+ Args:
+ x(Tensor): Shape [B, L1, C] - video tokens
+ context(Tensor): Shape [B, L2, C] - audio tokens with shape [B, frames*16, 1536]
+ """
+ b, n, d = x.size(0), self.num_heads, self.head_dim
+
+ q = self.norm_q(self.q(x))
+ k = self.norm_k(self.k(context))
+ v = self.v(context)
+
+ # Handle audio temporal structure (16 tokens per frame)
+ k = k.reshape(-1, 16, n, d).transpose(1, 2)
+ v = v.reshape(-1, 16, n, d).transpose(1, 2)
+
+ # Handle video spatial structure
+ q = q.reshape(k.shape[0], -1, n, d).transpose(1, 2)
+
+ x = optimized_attention(q, k, v, heads=self.num_heads, skip_reshape=True, skip_output_reshape=True, transformer_options=transformer_options)
+
+ x = x.transpose(1, 2).view(b, -1, n, d).flatten(2)
+ x = self.o(x)
+ return x
+
+
+class AudioCrossAttentionWrapper(nn.Module):
+ def __init__(self, dim, kv_dim, num_heads, qk_norm=True, eps=1e-6, operation_settings={}):
+ super().__init__()
+
+ self.audio_cross_attn = WanT2VCrossAttentionGather(dim, num_heads, qk_norm=qk_norm, kv_dim=kv_dim, eps=eps, operation_settings=operation_settings)
+ self.norm1_audio = operation_settings.get("operations").LayerNorm(dim, eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
+
+ def forward(self, x, audio, transformer_options={}):
+ x = x + self.audio_cross_attn(self.norm1_audio(x), audio, transformer_options=transformer_options)
+ return x
+
+
+class WanAttentionBlockAudio(WanAttentionBlock):
+
+ def __init__(self,
+ cross_attn_type,
+ dim,
+ ffn_dim,
+ num_heads,
+ window_size=(-1, -1),
+ qk_norm=True,
+ cross_attn_norm=False,
+ eps=1e-6, operation_settings={}):
+ super().__init__(cross_attn_type, dim, ffn_dim, num_heads, window_size, qk_norm, cross_attn_norm, eps, operation_settings)
+ self.audio_cross_attn_wrapper = AudioCrossAttentionWrapper(dim, 1536, num_heads, qk_norm, eps, operation_settings=operation_settings)
+
+ def forward(
+ self,
+ x,
+ e,
+ freqs,
+ context,
+ context_img_len=257,
+ audio=None,
+ transformer_options={},
+ ):
+ r"""
+ Args:
+ x(Tensor): Shape [B, L, C]
+ e(Tensor): Shape [B, 6, C]
+ freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
+ """
+ # assert e.dtype == torch.float32
+
+ if e.ndim < 4:
+ e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device) + e).chunk(6, dim=1)
+ else:
+ e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device).unsqueeze(0) + e).unbind(2)
+ # assert e[0].dtype == torch.float32
+
+ # self-attention
+ y = self.self_attn(
+ torch.addcmul(repeat_e(e[0], x), self.norm1(x), 1 + repeat_e(e[1], x)),
+ freqs, transformer_options=transformer_options)
+
+ x = torch.addcmul(x, y, repeat_e(e[2], x))
+
+ # cross-attention & ffn
+ x = x + self.cross_attn(self.norm3(x), context, context_img_len=context_img_len, transformer_options=transformer_options)
+ if audio is not None:
+ x = self.audio_cross_attn_wrapper(x, audio, transformer_options=transformer_options)
+ y = self.ffn(torch.addcmul(repeat_e(e[3], x), self.norm2(x), 1 + repeat_e(e[4], x)))
+ x = torch.addcmul(x, y, repeat_e(e[5], x))
+ return x
+
+class DummyAdapterLayer(nn.Module):
+ def __init__(self, layer):
+ super().__init__()
+ self.layer = layer
+
+ def forward(self, *args, **kwargs):
+ return self.layer(*args, **kwargs)
+
+
+class AudioProjModel(nn.Module):
+ def __init__(
+ self,
+ seq_len=5,
+ blocks=13, # number of blocks in each audio feature window
+ channels=768, # channels per block
+ intermediate_dim=512,
+ output_dim=1536,
+ context_tokens=16,
+ device=None,
+ dtype=None,
+ operations=None,
+ ):
+ super().__init__()
+
+ self.seq_len = seq_len
+ self.blocks = blocks
+ self.channels = channels
+ self.input_dim = seq_len * blocks * channels # input is the flattened (seq_len, blocks, channels) window
+ self.intermediate_dim = intermediate_dim
+ self.context_tokens = context_tokens
+ self.output_dim = output_dim
+
+ # define multiple linear layers
+ self.audio_proj_glob_1 = DummyAdapterLayer(operations.Linear(self.input_dim, intermediate_dim, dtype=dtype, device=device))
+ self.audio_proj_glob_2 = DummyAdapterLayer(operations.Linear(intermediate_dim, intermediate_dim, dtype=dtype, device=device))
+ self.audio_proj_glob_3 = DummyAdapterLayer(operations.Linear(intermediate_dim, context_tokens * output_dim, dtype=dtype, device=device))
+
+ self.audio_proj_glob_norm = DummyAdapterLayer(operations.LayerNorm(output_dim, dtype=dtype, device=device))
+
+ def forward(self, audio_embeds):
+ video_length = audio_embeds.shape[1]
+ audio_embeds = rearrange(audio_embeds, "bz f w b c -> (bz f) w b c")
+ batch_size, window_size, blocks, channels = audio_embeds.shape
+ audio_embeds = audio_embeds.view(batch_size, window_size * blocks * channels)
+
+ audio_embeds = torch.relu(self.audio_proj_glob_1(audio_embeds))
+ audio_embeds = torch.relu(self.audio_proj_glob_2(audio_embeds))
+
+ context_tokens = self.audio_proj_glob_3(audio_embeds).reshape(batch_size, self.context_tokens, self.output_dim)
+
+ context_tokens = self.audio_proj_glob_norm(context_tokens)
+ context_tokens = rearrange(context_tokens, "(bz f) m c -> bz f m c", f=video_length)
+
+ return context_tokens
+
+
+class HumoWanModel(WanModel):
+ r"""
+ Wan diffusion backbone supporting both text-to-video and image-to-video.
+ """
+
+ def __init__(self,
+ model_type='humo',
+ patch_size=(1, 2, 2),
+ text_len=512,
+ in_dim=16,
+ dim=2048,
+ ffn_dim=8192,
+ freq_dim=256,
+ text_dim=4096,
+ out_dim=16,
+ num_heads=16,
+ num_layers=32,
+ window_size=(-1, -1),
+ qk_norm=True,
+ cross_attn_norm=True,
+ eps=1e-6,
+ flf_pos_embed_token_number=None,
+ image_model=None,
+ audio_token_num=16,
+ device=None,
+ dtype=None,
+ operations=None,
+ ):
+
+ super().__init__(model_type='t2v', patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, flf_pos_embed_token_number=flf_pos_embed_token_number, wan_attn_block_class=WanAttentionBlockAudio, image_model=image_model, device=device, dtype=dtype, operations=operations)
+
+ self.audio_proj = AudioProjModel(seq_len=8, blocks=5, channels=1280, intermediate_dim=512, output_dim=1536, context_tokens=audio_token_num, dtype=dtype, device=device, operations=operations)
+
+ def forward_orig(
+ self,
+ x,
+ t,
+ context,
+ freqs=None,
+ audio_embed=None,
+ reference_latent=None,
+ transformer_options={},
+ **kwargs,
+ ):
+ bs, _, time, height, width = x.shape
+
+ # embeddings
+ x = self.patch_embedding(x.float()).to(x.dtype)
+ grid_sizes = x.shape[2:]
+ x = x.flatten(2).transpose(1, 2)
+
+ # time embeddings
+ e = self.time_embedding(
+ sinusoidal_embedding_1d(self.freq_dim, t.flatten()).to(dtype=x[0].dtype))
+ e = e.reshape(t.shape[0], -1, e.shape[-1])
+ e0 = self.time_projection(e).unflatten(2, (6, self.dim))
+
+ if reference_latent is not None:
+ ref = self.patch_embedding(reference_latent.float()).to(x.dtype)
+ ref = ref.flatten(2).transpose(1, 2)
+ freqs_ref = self.rope_encode(reference_latent.shape[-3], reference_latent.shape[-2], reference_latent.shape[-1], t_start=time, device=x.device, dtype=x.dtype)
+ x = torch.cat([x, ref], dim=1)
+ freqs = torch.cat([freqs, freqs_ref], dim=1)
+ del ref, freqs_ref
+
+ # context
+ context = self.text_embedding(context)
+ context_img_len = None
+
+ if audio_embed is not None:
+ audio = self.audio_proj(audio_embed).permute(0, 3, 1, 2).flatten(2).transpose(1, 2)
+ else:
+ audio = None
+
+ patches_replace = transformer_options.get("patches_replace", {})
+ blocks_replace = patches_replace.get("dit", {})
+ for i, block in enumerate(self.blocks):
+ if ("double_block", i) in blocks_replace:
+ def block_wrap(args):
+ out = {}
+ out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len, audio=audio, transformer_options=args["transformer_options"])
+ return out
+ out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs, "transformer_options": transformer_options}, {"original_block": block_wrap})
+ x = out["img"]
+ else:
+ x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len, audio=audio, transformer_options=transformer_options)
# head
x = self.head(x, e)
diff --git a/comfy/lora.py b/comfy/lora.py
index 00358884b..36d26293a 100644
--- a/comfy/lora.py
+++ b/comfy/lora.py
@@ -260,6 +260,10 @@ def model_lora_keys_unet(model, key_map={}):
key_map["transformer.{}".format(k[:-len(".weight")])] = to #simpletrainer and probably regular diffusers flux lora format
key_map["lycoris_{}".format(k[:-len(".weight")].replace(".", "_"))] = to #simpletrainer lycoris
key_map["lora_transformer_{}".format(k[:-len(".weight")].replace(".", "_"))] = to #onetrainer
+ for k in sdk:
+ hidden_size = model.model_config.unet_config.get("hidden_size", 0)
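+ # expose a ".linear1_qkv" alias that targets the first hidden_size * 3 output rows of the fused linear1 weight (the qkv part)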
+ if k.endswith(".weight") and ".linear1." in k:
+ key_map["{}".format(k.replace(".linear1.weight", ".linear1_qkv"))] = (k, (0, 0, hidden_size * 3))
if isinstance(model, comfy.model_base.GenmoMochi):
for k in sdk:
@@ -293,6 +297,12 @@ def model_lora_keys_unet(model, key_map={}):
key_lora = k[len("diffusion_model."):-len(".weight")]
key_map["{}".format(key_lora)] = k
+ if isinstance(model, comfy.model_base.Omnigen2):
+ for k in sdk:
+ if k.startswith("diffusion_model.") and k.endswith(".weight"):
+ key_lora = k[len("diffusion_model."):-len(".weight")]
+ key_map["{}".format(key_lora)] = k
+
if isinstance(model, comfy.model_base.QwenImage):
for k in sdk:
if k.startswith("diffusion_model.") and k.endswith(".weight"): #QwenImage lora format
diff --git a/comfy/lora_convert.py b/comfy/lora_convert.py
index 3e00b63db..9d8d21efe 100644
--- a/comfy/lora_convert.py
+++ b/comfy/lora_convert.py
@@ -15,10 +15,29 @@ def convert_lora_bfl_control(sd): #BFL loras for Flux
def convert_lora_wan_fun(sd): #Wan Fun loras
return comfy.utils.state_dict_prefix_replace(sd, {"lora_unet__": "lora_unet_"})
+def convert_uso_lora(sd):
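+ # USO loras use ".processor"-style key names; rename them to the flux double/single block modules used here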
+ sd_out = {}
+ for k in sd:
+ tensor = sd[k]
+ k_to = "diffusion_model.{}".format(k.replace(".down.weight", ".lora_down.weight")
+ .replace(".up.weight", ".lora_up.weight")
+ .replace(".qkv_lora2.", ".txt_attn.qkv.")
+ .replace(".qkv_lora1.", ".img_attn.qkv.")
+ .replace(".proj_lora1.", ".img_attn.proj.")
+ .replace(".proj_lora2.", ".txt_attn.proj.")
+ .replace(".qkv_lora.", ".linear1_qkv.")
+ .replace(".proj_lora.", ".linear2.")
+ .replace(".processor.", ".")
+ )
+ sd_out[k_to] = tensor
+ return sd_out
+
def convert_lora(sd):
if "img_in.lora_A.weight" in sd and "single_blocks.0.norm.key_norm.scale" in sd:
return convert_lora_bfl_control(sd)
if "lora_unet__blocks_0_cross_attn_k.lora_down.weight" in sd:
return convert_lora_wan_fun(sd)
+ if "single_blocks.37.processor.qkv_lora.up.weight" in sd and "double_blocks.18.processor.qkv_lora2.up.weight" in sd:
+ return convert_uso_lora(sd)
return sd
diff --git a/comfy/model_base.py b/comfy/model_base.py
index 15bd7abef..70b67b7c1 100644
--- a/comfy/model_base.py
+++ b/comfy/model_base.py
@@ -16,6 +16,8 @@
along with this program. If not, see <https://www.gnu.org/licenses/>.
"""
+import comfy.ldm.hunyuan3dv2_1
+import comfy.ldm.hunyuan3dv2_1.hunyuandit
import torch
import logging
from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep
@@ -40,6 +42,7 @@ import comfy.ldm.wan.model
import comfy.ldm.hunyuan3d.model
import comfy.ldm.hidream.model
import comfy.ldm.chroma.model
+import comfy.ldm.chroma_radiance.model
import comfy.ldm.ace.model
import comfy.ldm.omnigen.omnigen2
import comfy.ldm.qwen_image.model
@@ -150,6 +153,7 @@ class BaseModel(torch.nn.Module):
logging.debug("adm {}".format(self.adm_channels))
self.memory_usage_factor = model_config.memory_usage_factor
self.memory_usage_factor_conds = ()
+ self.memory_usage_shape_process = {}
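+ # optional per-cond callbacks that rewrite a cond's shape before it is counted toward memory usage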
def apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs):
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
@@ -350,8 +354,15 @@ class BaseModel(torch.nn.Module):
input_shapes = [input_shape]
for c in self.memory_usage_factor_conds:
shape = cond_shapes.get(c, None)
- if shape is not None and len(shape) > 0:
- input_shapes += shape
+ if shape is not None:
+ if c in self.memory_usage_shape_process:
+ out = []
+ for s in shape:
+ out.append(self.memory_usage_shape_process[c](s))
+ shape = out
+
+ if len(shape) > 0:
+ input_shapes += shape
if comfy.model_management.xformers_enabled() or comfy.model_management.pytorch_attention_flash_attention():
dtype = self.get_dtype()
@@ -1102,9 +1113,10 @@ class WAN21(BaseModel):
shape_image[1] = extra_channels
image = torch.zeros(shape_image, dtype=noise.dtype, layout=noise.layout, device=noise.device)
else:
+ latent_dim = self.latent_format.latent_channels
image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
- for i in range(0, image.shape[1], 16):
- image[:, i: i + 16] = self.process_latent_in(image[:, i: i + 16])
+ for i in range(0, image.shape[1], latent_dim):
+ image[:, i: i + latent_dim] = self.process_latent_in(image[:, i: i + latent_dim])
image = utils.resize_to_batch_size(image, noise.shape[0])
if extra_channels != image.shape[1] + 4:
@@ -1201,18 +1213,90 @@ class WAN21_Camera(WAN21):
out['camera_conditions'] = comfy.conds.CONDRegular(camera_conditions)
return out
-class WAN22(BaseModel):
+class WAN21_HuMo(WAN21):
def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
- super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel)
+ super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.HumoWanModel)
self.image_to_video = image_to_video
def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
- cross_attn = kwargs.get("cross_attn", None)
- if cross_attn is not None:
- out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
+ noise = kwargs.get("noise", None)
- denoise_mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None))
+ audio_embed = kwargs.get("audio_embed", None)
+ if audio_embed is not None:
+ out['audio_embed'] = comfy.conds.CONDRegular(audio_embed)
+
+ if "c_concat" not in out: # 1.7B model
+ reference_latents = kwargs.get("reference_latents", None)
+ if reference_latents is not None:
+ out['reference_latent'] = comfy.conds.CONDRegular(self.process_latent_in(reference_latents[-1]))
+ else:
+ noise_shape = list(noise.shape)
+ noise_shape[1] += 4
+ concat_latent = torch.zeros(noise_shape, device=noise.device, dtype=noise.dtype)
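+ # fill the extra concat latent channels with fixed values; the first two temporal positions get their own constants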
+ zero_vae_values_first = torch.tensor([0.8660, -0.4326, -0.0017, -0.4884, -0.5283, 0.9207, -0.9896, 0.4433, -0.5543, -0.0113, 0.5753, -0.6000, -0.8346, -0.3497, -0.1926, -0.6938]).view(1, 16, 1, 1, 1)
+ zero_vae_values_second = torch.tensor([1.0869, -1.2370, 0.0206, -0.4357, -0.6411, 2.0307, -1.5972, 1.2659, -0.8595, -0.4654, 0.9638, -1.6330, -1.4310, -0.1098, -0.3856, -1.4583]).view(1, 16, 1, 1, 1)
+ zero_vae_values = torch.tensor([0.8642, -1.8583, 0.1577, 0.1350, -0.3641, 2.5863, -1.9670, 1.6065, -1.0475, -0.8678, 1.1734, -1.8138, -1.5933, -0.7721, -0.3289, -1.3745]).view(1, 16, 1, 1, 1)
+ concat_latent[:, 4:] = zero_vae_values
+ concat_latent[:, 4:, :1] = zero_vae_values_first
+ concat_latent[:, 4:, 1:2] = zero_vae_values_second
+ out['c_concat'] = comfy.conds.CONDNoiseShape(concat_latent)
+ reference_latents = kwargs.get("reference_latents", None)
+ if reference_latents is not None:
+ ref_latent = self.process_latent_in(reference_latents[-1])
+ ref_latent_shape = list(ref_latent.shape)
+ ref_latent_shape[1] += 4 + ref_latent_shape[1]
+ ref_latent_full = torch.zeros(ref_latent_shape, device=ref_latent.device, dtype=ref_latent.dtype)
+ ref_latent_full[:, 20:] = ref_latent
+ ref_latent_full[:, 16:20] = 1.0
+ out['reference_latent'] = comfy.conds.CONDRegular(ref_latent_full)
+
+ return out
+
+class WAN22_S2V(WAN21):
+ def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
+ super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel_S2V)
+ self.memory_usage_factor_conds = ("reference_latent", "reference_motion")
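+        # For memory estimation only: the hook below weights reference_motion as roughly 1.5 latent frames; it is an approximation, not a real tensor shape.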
+ self.memory_usage_shape_process = {"reference_motion": lambda shape: [shape[0], shape[1], 1.5, shape[-2], shape[-1]]}
+
+ def extra_conds(self, **kwargs):
+ out = super().extra_conds(**kwargs)
+ audio_embed = kwargs.get("audio_embed", None)
+ if audio_embed is not None:
+ out['audio_embed'] = comfy.conds.CONDRegular(audio_embed)
+
+ reference_latents = kwargs.get("reference_latents", None)
+ if reference_latents is not None:
+ out['reference_latent'] = comfy.conds.CONDRegular(self.process_latent_in(reference_latents[-1]))
+
+ reference_motion = kwargs.get("reference_motion", None)
+ if reference_motion is not None:
+ out['reference_motion'] = comfy.conds.CONDRegular(self.process_latent_in(reference_motion))
+
+ control_video = kwargs.get("control_video", None)
+ if control_video is not None:
+ out['control_video'] = comfy.conds.CONDRegular(self.process_latent_in(control_video))
+ return out
+
+ def extra_conds_shapes(self, **kwargs):
+ out = {}
+ ref_latents = kwargs.get("reference_latents", None)
+ if ref_latents is not None:
+ out['reference_latent'] = list([1, 16, sum(map(lambda a: math.prod(a.size()), ref_latents)) // 16])
+
+ reference_motion = kwargs.get("reference_motion", None)
+ if reference_motion is not None:
+ out['reference_motion'] = reference_motion.shape
+ return out
+
+class WAN22(WAN21):
+ def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
+ super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel)
+ self.image_to_video = image_to_video
+
+ def extra_conds(self, **kwargs):
+ out = super().extra_conds(**kwargs)
+ denoise_mask = kwargs.get("denoise_mask", None)
if denoise_mask is not None:
out["denoise_mask"] = comfy.conds.CONDRegular(denoise_mask)
return out
@@ -1241,6 +1325,21 @@ class Hunyuan3Dv2(BaseModel):
out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance]))
return out
+class Hunyuan3Dv2_1(BaseModel):
+ def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
+ super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan3dv2_1.hunyuandit.HunYuanDiTPlain)
+
+ def extra_conds(self, **kwargs):
+ out = super().extra_conds(**kwargs)
+ cross_attn = kwargs.get("cross_attn", None)
+ if cross_attn is not None:
+ out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
+
+ guidance = kwargs.get("guidance", 5.0)
+ if guidance is not None:
+ out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance]))
+ return out
+
class HiDream(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hidream.model.HiDreamImageTransformer2DModel)
@@ -1262,8 +1361,8 @@ class HiDream(BaseModel):
return out
class Chroma(Flux):
- def __init__(self, model_config, model_type=ModelType.FLUX, device=None):
- super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.chroma.model.Chroma)
+ def __init__(self, model_config, model_type=ModelType.FLUX, device=None, unet_model=comfy.ldm.chroma.model.Chroma):
+ super().__init__(model_config, model_type, device=device, unet_model=unet_model)
def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
@@ -1273,6 +1372,10 @@ class Chroma(Flux):
out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance]))
return out
+class ChromaRadiance(Chroma):
+ def __init__(self, model_config, model_type=ModelType.FLUX, device=None):
+ super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.chroma_radiance.model.ChromaRadiance)
+
class ACEStep(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.ace.model.ACEStepTransformer2DModel)
@@ -1325,6 +1428,7 @@ class Omnigen2(BaseModel):
class QwenImage(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLUX, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.qwen_image.model.QwenImageTransformer2DModel)
+ self.memory_usage_factor_conds = ("ref_latents",)
def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
@@ -1342,3 +1446,62 @@ class QwenImage(BaseModel):
if ref_latents_method is not None:
out['ref_latents_method'] = comfy.conds.CONDConstant(ref_latents_method)
return out
+
+ def extra_conds_shapes(self, **kwargs):
+ out = {}
+ ref_latents = kwargs.get("reference_latents", None)
+ if ref_latents is not None:
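+            # Report all reference latents as a single flattened [1, 16, N] shape so their combined size is counted during memory estimation.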
+ out['ref_latents'] = list([1, 16, sum(map(lambda a: math.prod(a.size()), ref_latents)) // 16])
+ return out
+
+class HunyuanImage21(BaseModel):
+ def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
+ super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan_video.model.HunyuanVideo)
+
+ def extra_conds(self, **kwargs):
+ out = super().extra_conds(**kwargs)
+ attention_mask = kwargs.get("attention_mask", None)
+ if attention_mask is not None:
+ if torch.numel(attention_mask) != attention_mask.sum():
+ out['attention_mask'] = comfy.conds.CONDRegular(attention_mask)
+ cross_attn = kwargs.get("cross_attn", None)
+ if cross_attn is not None:
+ out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
+
+ conditioning_byt5small = kwargs.get("conditioning_byt5small", None)
+ if conditioning_byt5small is not None:
+ out['txt_byt5'] = comfy.conds.CONDRegular(conditioning_byt5small)
+
+ guidance = kwargs.get("guidance", 6.0)
+ if guidance is not None:
+ out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance]))
+
+ return out
+
+class HunyuanImage21Refiner(HunyuanImage21):
+ def concat_cond(self, **kwargs):
+ noise = kwargs.get("noise", None)
+ image = kwargs.get("concat_latent_image", None)
+ noise_augmentation = kwargs.get("noise_augmentation", 0.0)
+ device = kwargs["device"]
+
+ if image is None:
+ shape_image = list(noise.shape)
+ image = torch.zeros(shape_image, dtype=noise.dtype, layout=noise.layout, device=noise.device)
+ else:
+ image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
+ image = self.process_latent_in(image)
+ image = utils.resize_to_batch_size(image, noise.shape[0])
+ if noise_augmentation > 0:
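+            # Deterministic augmentation noise: a CPU generator seeded from the sampling seed, offset so it presumably stays distinct from the noise used for sampling itself.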
+ generator = torch.Generator(device="cpu")
+ generator.manual_seed(kwargs.get("seed", 0) - 10)
+ noise = torch.randn(image.shape, generator=generator, dtype=image.dtype, device="cpu").to(image.device)
+ image = noise_augmentation * noise + min(1.0 - noise_augmentation, 0.75) * image
+ else:
+ image = 0.75 * image
+ return image
+
+ def extra_conds(self, **kwargs):
+ out = super().extra_conds(**kwargs)
+ out['disable_time_r'] = comfy.conds.CONDConstant(True)
+ return out
diff --git a/comfy/model_detection.py b/comfy/model_detection.py
index 2bec0541e..72621bed6 100644
--- a/comfy/model_detection.py
+++ b/comfy/model_detection.py
@@ -136,25 +136,45 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
if '{}txt_in.individual_token_refiner.blocks.0.norm1.weight'.format(key_prefix) in state_dict_keys: #Hunyuan Video
dit_config = {}
+ in_w = state_dict['{}img_in.proj.weight'.format(key_prefix)]
+ out_w = state_dict['{}final_layer.linear.weight'.format(key_prefix)]
dit_config["image_model"] = "hunyuan_video"
- dit_config["in_channels"] = state_dict['{}img_in.proj.weight'.format(key_prefix)].shape[1] #SkyReels img2video has 32 input channels
- dit_config["patch_size"] = [1, 2, 2]
- dit_config["out_channels"] = 16
- dit_config["vec_in_dim"] = 768
- dit_config["context_in_dim"] = 4096
- dit_config["hidden_size"] = 3072
+ dit_config["in_channels"] = in_w.shape[1] #SkyReels img2video has 32 input channels
+ dit_config["patch_size"] = list(in_w.shape[2:])
+ dit_config["out_channels"] = out_w.shape[0] // math.prod(dit_config["patch_size"])
+ if any(s.startswith('{}vector_in.'.format(key_prefix)) for s in state_dict_keys):
+ dit_config["vec_in_dim"] = 768
+ else:
+ dit_config["vec_in_dim"] = None
+
+ if len(dit_config["patch_size"]) == 2:
+ dit_config["axes_dim"] = [64, 64]
+ else:
+ dit_config["axes_dim"] = [16, 56, 56]
+
+ if any(s.startswith('{}time_r_in.'.format(key_prefix)) for s in state_dict_keys):
+ dit_config["meanflow"] = True
+ else:
+ dit_config["meanflow"] = False
+
+ dit_config["context_in_dim"] = state_dict['{}txt_in.input_embedder.weight'.format(key_prefix)].shape[1]
+ dit_config["hidden_size"] = in_w.shape[0]
dit_config["mlp_ratio"] = 4.0
- dit_config["num_heads"] = 24
+ dit_config["num_heads"] = in_w.shape[0] // 128
dit_config["depth"] = count_blocks(state_dict_keys, '{}double_blocks.'.format(key_prefix) + '{}.')
dit_config["depth_single_blocks"] = count_blocks(state_dict_keys, '{}single_blocks.'.format(key_prefix) + '{}.')
- dit_config["axes_dim"] = [16, 56, 56]
dit_config["theta"] = 256
dit_config["qkv_bias"] = True
+ if '{}byt5_in.fc1.weight'.format(key_prefix) in state_dict:
+ dit_config["byt5"] = True
+ else:
+ dit_config["byt5"] = False
+
guidance_keys = list(filter(lambda a: a.startswith("{}guidance_in.".format(key_prefix)), state_dict_keys))
dit_config["guidance_embed"] = len(guidance_keys) > 0
return dit_config
- if '{}double_blocks.0.img_attn.norm.key_norm.scale'.format(key_prefix) in state_dict_keys and '{}img_in.weight'.format(key_prefix) in state_dict_keys: #Flux
+ if '{}double_blocks.0.img_attn.norm.key_norm.scale'.format(key_prefix) in state_dict_keys and ('{}img_in.weight'.format(key_prefix) in state_dict_keys or f"{key_prefix}distilled_guidance_layer.norms.0.scale" in state_dict_keys): #Flux, Chroma or Chroma Radiance (has no img_in.weight)
dit_config = {}
dit_config["image_model"] = "flux"
dit_config["in_channels"] = 16
@@ -184,6 +204,18 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["out_dim"] = 3072
dit_config["hidden_dim"] = 5120
dit_config["n_layers"] = 5
+ if f"{key_prefix}nerf_blocks.0.norm.scale" in state_dict_keys: #Chroma Radiance
+ dit_config["image_model"] = "chroma_radiance"
+ dit_config["in_channels"] = 3
+ dit_config["out_channels"] = 3
+ dit_config["patch_size"] = 16
+ dit_config["nerf_hidden_size"] = 64
+ dit_config["nerf_mlp_ratio"] = 4
+ dit_config["nerf_depth"] = 4
+ dit_config["nerf_max_freqs"] = 8
+ dit_config["nerf_tile_size"] = 32
+ dit_config["nerf_final_head_type"] = "conv" if f"{key_prefix}nerf_final_layer_conv.norm.scale" in state_dict_keys else "linear"
+ dit_config["nerf_embedder_dtype"] = torch.float32
else:
dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys
return dit_config
@@ -368,6 +400,10 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["model_type"] = "camera"
else:
dit_config["model_type"] = "camera_2.2"
+ elif '{}casual_audio_encoder.encoder.final_linear.weight'.format(key_prefix) in state_dict_keys:
+ dit_config["model_type"] = "s2v"
+ elif '{}audio_proj.audio_proj_glob_1.layer.bias'.format(key_prefix) in state_dict_keys:
+ dit_config["model_type"] = "humo"
else:
if '{}img_emb.proj.0.bias'.format(key_prefix) in state_dict_keys:
dit_config["model_type"] = "i2v"
@@ -398,6 +434,20 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys
return dit_config
+ if f"{key_prefix}t_embedder.mlp.2.weight" in state_dict_keys: # Hunyuan 3D 2.1
+
+ dit_config = {}
+ dit_config["image_model"] = "hunyuan3d2_1"
+ dit_config["in_channels"] = state_dict[f"{key_prefix}x_embedder.weight"].shape[1]
+ dit_config["context_dim"] = 1024
+ dit_config["hidden_size"] = state_dict[f"{key_prefix}x_embedder.weight"].shape[0]
+ dit_config["mlp_ratio"] = 4.0
+ dit_config["num_heads"] = 16
+ dit_config["depth"] = count_blocks(state_dict_keys, f"{key_prefix}blocks.{{}}")
+ dit_config["qkv_bias"] = False
+ dit_config["guidance_cond_proj_dim"] = None#f"{key_prefix}t_embedder.cond_proj.weight" in state_dict_keys
+ return dit_config
+
if '{}caption_projection.0.linear.weight'.format(key_prefix) in state_dict_keys: # HiDream
dit_config = {}
dit_config["image_model"] = "hidream"
@@ -492,6 +542,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
if '{}txt_norm.weight'.format(key_prefix) in state_dict_keys: # Qwen Image
dit_config = {}
dit_config["image_model"] = "qwen_image"
+ dit_config["in_channels"] = state_dict['{}img_in.weight'.format(key_prefix)].shape[1]
+ dit_config["num_layers"] = count_blocks(state_dict_keys, '{}transformer_blocks.'.format(key_prefix) + '{}.')
return dit_config
if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys:
diff --git a/comfy/model_management.py b/comfy/model_management.py
index 2a9f18068..bbfc3c7a1 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -22,6 +22,7 @@ from enum import Enum
from comfy.cli_args import args, PerformanceFeature
import torch
import sys
+import importlib
import platform
import weakref
import gc
@@ -289,6 +290,24 @@ def is_amd():
return True
return False
+def amd_min_version(device=None, min_rdna_version=0):
+ if not is_amd():
+ return False
+
+ if is_device_cpu(device):
+ return False
+
+ arch = torch.cuda.get_device_properties(device).gcnArchName
+ if arch.startswith('gfx') and len(arch) == 7:
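+        # gfx1<N>xx arch names map to RDNA generation N + 2 (gfx1100 -> RDNA3, gfx1201 -> RDNA4); other names such as gfx90a fall through and return False.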
+ try:
+ cmp_rdna_version = int(arch[4]) + 2
+ except:
+ cmp_rdna_version = 0
+ if cmp_rdna_version >= min_rdna_version:
+ return True
+
+ return False
+
MIN_WEIGHT_MEMORY_RATIO = 0.4
if is_nvidia():
MIN_WEIGHT_MEMORY_RATIO = 0.0
@@ -321,12 +340,13 @@ try:
logging.info("AMD arch: {}".format(arch))
logging.info("ROCm version: {}".format(rocm_version))
if args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
- if torch_version_numeric >= (2, 7): # works on 2.6 but doesn't actually seem to improve much
- if any((a in arch) for a in ["gfx90a", "gfx942", "gfx1100", "gfx1101", "gfx1151"]): # TODO: more arches, TODO: gfx950
- ENABLE_PYTORCH_ATTENTION = True
-# if torch_version_numeric >= (2, 8):
-# if any((a in arch) for a in ["gfx1201"]):
-# ENABLE_PYTORCH_ATTENTION = True
+ if importlib.util.find_spec('triton') is not None: # AMD efficient attention implementation depends on triton. TODO: better way of detecting if it's compiled in or not.
+ if torch_version_numeric >= (2, 7): # works on 2.6 but doesn't actually seem to improve much
+ if any((a in arch) for a in ["gfx90a", "gfx942", "gfx1100", "gfx1101", "gfx1151"]): # TODO: more arches, TODO: gfx950
+ ENABLE_PYTORCH_ATTENTION = True
+# if torch_version_numeric >= (2, 8):
+# if any((a in arch) for a in ["gfx1201"]):
+# ENABLE_PYTORCH_ATTENTION = True
if torch_version_numeric >= (2, 7) and rocm_version >= (6, 4):
if any((a in arch) for a in ["gfx1201", "gfx942", "gfx950"]): # TODO: more arches
SUPPORT_FP8_OPS = True
@@ -593,7 +613,13 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
else:
minimum_memory_required = max(inference_memory, minimum_memory_required + extra_reserved_memory())
- models = set(models)
+ models_temp = set()
+ for m in models:
+ models_temp.add(m)
+ for mm in m.model_patches_models():
+ models_temp.add(mm)
+
+ models = models_temp
models_to_load = []
@@ -899,7 +925,9 @@ def vae_dtype(device=None, allowed_dtypes=[]):
# NOTE: bfloat16 seems to work on AMD for the VAE but is extremely slow in some cases compared to fp32
# slowness still a problem on pytorch nightly 2.9.0.dev20250720+rocm6.4 tested on RDNA3
- if d == torch.bfloat16 and (not is_amd()) and should_use_bf16(device):
+ # also a problem on RDNA4 except fp32 is also slow there.
+ # This is due to large bf16 convolutions being extremely slow.
+ if d == torch.bfloat16 and ((not is_amd()) or amd_min_version(device, min_rdna_version=4)) and should_use_bf16(device):
return d
return torch.float32
diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py
index 52e76b5f3..1fd03d9d1 100644
--- a/comfy/model_patcher.py
+++ b/comfy/model_patcher.py
@@ -430,6 +430,12 @@ class ModelPatcher:
def set_model_forward_timestep_embed_patch(self, patch):
self.set_model_patch(patch, "forward_timestep_embed_patch")
+ def set_model_double_block_patch(self, patch):
+ self.set_model_patch(patch, "double_block")
+
+ def set_model_post_input_patch(self, patch):
+ self.set_model_patch(patch, "post_input")
+
def add_object_patch(self, name, obj):
self.object_patches[name] = obj
@@ -486,6 +492,30 @@ class ModelPatcher:
if hasattr(wrap_func, "to"):
self.model_options["model_function_wrapper"] = wrap_func.to(device)
+ def model_patches_models(self):
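+        # Collect any extra models referenced by transformer_options patches, patch replacements or the model function wrapper so load_models_gpu can load them together with this model.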
+ to = self.model_options["transformer_options"]
+ models = []
+ if "patches" in to:
+ patches = to["patches"]
+ for name in patches:
+ patch_list = patches[name]
+ for i in range(len(patch_list)):
+ if hasattr(patch_list[i], "models"):
+ models += patch_list[i].models()
+ if "patches_replace" in to:
+ patches = to["patches_replace"]
+ for name in patches:
+ patch_list = patches[name]
+ for k in patch_list:
+ if hasattr(patch_list[k], "models"):
+ models += patch_list[k].models()
+ if "model_function_wrapper" in self.model_options:
+ wrap_func = self.model_options["model_function_wrapper"]
+ if hasattr(wrap_func, "models"):
+ models += wrap_func.models()
+
+ return models
+
def model_dtype(self):
if hasattr(self.model, "get_dtype"):
return self.model.get_dtype()
diff --git a/comfy/ops.py b/comfy/ops.py
index 18e7db705..55e958adb 100644
--- a/comfy/ops.py
+++ b/comfy/ops.py
@@ -52,6 +52,9 @@ except (ModuleNotFoundError, TypeError):
cast_to = comfy.model_management.cast_to #TODO: remove once no more references
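+# With the AutoTune performance feature enabled, let cuDNN benchmark convolution algorithms per input shape; the first calls are slower while profiling, later ones can be faster.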
+if torch.cuda.is_available() and torch.backends.cudnn.is_available() and PerformanceFeature.AutoTune in args.fast:
+ torch.backends.cudnn.benchmark = True
+
def cast_to_input(weight, input, non_blocking=False, copy=True):
return comfy.model_management.cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy)
diff --git a/comfy/patcher_extension.py b/comfy/patcher_extension.py
index 965958f4c..46cc7b2a8 100644
--- a/comfy/patcher_extension.py
+++ b/comfy/patcher_extension.py
@@ -50,6 +50,7 @@ class WrappersMP:
OUTER_SAMPLE = "outer_sample"
PREPARE_SAMPLING = "prepare_sampling"
SAMPLER_SAMPLE = "sampler_sample"
+ PREDICT_NOISE = "predict_noise"
CALC_COND_BATCH = "calc_cond_batch"
APPLY_MODEL = "apply_model"
DIFFUSION_MODEL = "diffusion_model"
diff --git a/comfy/pixel_space_convert.py b/comfy/pixel_space_convert.py
new file mode 100644
index 000000000..049bbcfb4
--- /dev/null
+++ b/comfy/pixel_space_convert.py
@@ -0,0 +1,16 @@
+import torch
+
+
+# "Fake" VAE that converts from IMAGE B, H, W, C and values on the scale of 0..1
+# to LATENT B, C, H, W and values on the scale of -1..1.
+class PixelspaceConversionVAE(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
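+        # Dummy parameter whose only purpose is to put a "pixel_space_vae" key in the state dict so the VAE detection in comfy/sd.py recognizes this conversion VAE.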
+ self.pixel_space_vae = torch.nn.Parameter(torch.tensor(1.0))
+
+ def encode(self, pixels: torch.Tensor, *_args, **_kwargs) -> torch.Tensor:
+ return pixels
+
+ def decode(self, samples: torch.Tensor, *_args, **_kwargs) -> torch.Tensor:
+ return samples
+
diff --git a/comfy/samplers.py b/comfy/samplers.py
old mode 100644
new mode 100755
index d5390d64e..b3202cec6
--- a/comfy/samplers.py
+++ b/comfy/samplers.py
@@ -17,6 +17,7 @@ import comfy.model_patcher
import comfy.patcher_extension
import comfy.hooks
import comfy.context_windows
+import comfy.utils
import scipy.stats
import numpy
@@ -61,7 +62,7 @@ def get_area_and_mult(conds, x_in, timestep_in):
if "mask_strength" in conds:
mask_strength = conds["mask_strength"]
mask = conds['mask']
- assert (mask.shape[1:] == x_in.shape[2:])
+ # assert (mask.shape[1:] == x_in.shape[2:])
mask = mask[:input_x.shape[0]]
if area is not None:
@@ -69,7 +70,7 @@ def get_area_and_mult(conds, x_in, timestep_in):
mask = mask.narrow(i + 1, area[len(dims) + i], area[i])
mask = mask * mask_strength
- mask = mask.unsqueeze(1).repeat(input_x.shape[0] // mask.shape[0], input_x.shape[1], 1, 1)
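+        # Repeat across the batch and channel dims while leaving any number of trailing spatial dims untouched, so both 2D image and 3D video latents work.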
+ mask = mask.unsqueeze(1).repeat((input_x.shape[0] // mask.shape[0], input_x.shape[1]) + (1, ) * (mask.ndim - 1))
else:
mask = torch.ones_like(input_x)
mult = mask * strength
@@ -553,7 +554,10 @@ def resolve_areas_and_cond_masks_multidim(conditions, dims, device):
if len(mask.shape) == len(dims):
mask = mask.unsqueeze(0)
if mask.shape[1:] != dims:
- mask = torch.nn.functional.interpolate(mask.unsqueeze(1), size=dims, mode='bilinear', align_corners=False).squeeze(1)
+ if mask.ndim < 4:
+ mask = comfy.utils.common_upscale(mask.unsqueeze(1), dims[-1], dims[-2], 'bilinear', 'none').squeeze(1)
+ else:
+ mask = comfy.utils.common_upscale(mask, dims[-1], dims[-2], 'bilinear', 'none')
if modified.get("set_area_to_bounds", False): #TODO: handle dim != 2
bounds = torch.max(torch.abs(mask),dim=0).values.unsqueeze(0)
@@ -725,7 +729,7 @@ class Sampler:
KSAMPLER_NAMES = ["euler", "euler_cfg_pp", "euler_ancestral", "euler_ancestral_cfg_pp", "heun", "heunpp2","dpm_2", "dpm_2_ancestral",
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_2s_ancestral_cfg_pp", "dpmpp_sde", "dpmpp_sde_gpu",
- "dpmpp_2m", "dpmpp_2m_cfg_pp", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm",
+ "dpmpp_2m", "dpmpp_2m_cfg_pp", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_2m_sde_heun", "dpmpp_2m_sde_heun_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm",
"ipndm", "ipndm_v", "deis", "res_multistep", "res_multistep_cfg_pp", "res_multistep_ancestral", "res_multistep_ancestral_cfg_pp",
"gradient_estimation", "gradient_estimation_cfg_pp", "er_sde", "seeds_2", "seeds_3", "sa_solver", "sa_solver_pece"]
@@ -953,7 +957,14 @@ class CFGGuider:
self.original_conds[k] = comfy.sampler_helpers.convert_cond(conds[k])
def __call__(self, *args, **kwargs):
- return self.predict_noise(*args, **kwargs)
+ return self.outer_predict_noise(*args, **kwargs)
+
+ def outer_predict_noise(self, x, timestep, model_options={}, seed=None):
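+        # Run predict_noise through any registered PREDICT_NOISE wrappers, mirroring the other wrapper entry points on this guider.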
+ return comfy.patcher_extension.WrapperExecutor.new_class_executor(
+ self.predict_noise,
+ self,
+ comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.PREDICT_NOISE, self.model_options, is_model_options=True)
+ ).execute(x, timestep, model_options, seed)
def predict_noise(self, x, timestep, model_options={}, seed=None):
return sampling_function(self.inner_model, x, timestep, self.conds.get("negative", None), self.conds.get("positive", None), self.cfg, model_options=model_options, seed=seed)
diff --git a/comfy/sd.py b/comfy/sd.py
index bb5d61fb3..2df340739 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -17,6 +17,8 @@ import comfy.ldm.wan.vae
import comfy.ldm.wan.vae2_2
import comfy.ldm.hunyuan3d.vae
import comfy.ldm.ace.vae.music_dcae_pipeline
+import comfy.ldm.hunyuan_video.vae
+import comfy.pixel_space_convert
import yaml
import math
import os
@@ -48,6 +50,7 @@ import comfy.text_encoders.hidream
import comfy.text_encoders.ace
import comfy.text_encoders.omnigen2
import comfy.text_encoders.qwen_image
+import comfy.text_encoders.hunyuan_image
import comfy.model_patcher
import comfy.lora
@@ -283,6 +286,7 @@ class VAE:
self.process_output = lambda image: torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)
self.working_dtypes = [torch.bfloat16, torch.float32]
self.disable_offload = False
+ self.not_video = False
self.downscale_index_formula = None
self.upscale_index_formula = None
@@ -328,6 +332,19 @@ class VAE:
self.first_stage_model = StageC_coder()
self.downscale_ratio = 32
self.latent_channels = 16
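+        # Appears to be the Hunyuan Image 2.1 VAE: 64 latent channels with 32x spatial compression.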
+ elif "decoder.conv_in.weight" in sd and sd['decoder.conv_in.weight'].shape[1] == 64:
+ ddconfig = {"block_out_channels": [128, 256, 512, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 32, "downsample_match_channel": True, "upsample_match_channel": True}
+ self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.weight"].shape[1]
+ self.downscale_ratio = 32
+ self.upscale_ratio = 32
+ self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
+ self.first_stage_model = AutoencodingEngine(regularizer_config={'target': "comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer"},
+ encoder_config={'target': "comfy.ldm.hunyuan_video.vae.Encoder", 'params': ddconfig},
+ decoder_config={'target': "comfy.ldm.hunyuan_video.vae.Decoder", 'params': ddconfig})
+
+ self.memory_used_encode = lambda shape, dtype: (700 * shape[2] * shape[3]) * model_management.dtype_size(dtype)
+ self.memory_used_decode = lambda shape, dtype: (700 * shape[2] * shape[3] * 32 * 32) * model_management.dtype_size(dtype)
+
elif "decoder.conv_in.weight" in sd:
#default SD1.x/SD2.x VAE parameters
ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
@@ -394,6 +411,23 @@ class VAE:
self.downscale_ratio = (lambda a: max(0, math.floor((a + 7) / 8)), 32, 32)
self.downscale_index_formula = (8, 32, 32)
self.working_dtypes = [torch.bfloat16, torch.float32]
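+        # Appears to be the Hunyuan Image 2.1 refiner VAE: 3D conv weights but single-image use (not_video), 16x spatial / 4x temporal compression.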
+ elif "decoder.conv_in.conv.weight" in sd and sd['decoder.conv_in.conv.weight'].shape[1] == 32:
+ ddconfig = {"block_out_channels": [128, 256, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 16, "ffactor_temporal": 4, "downsample_match_channel": True, "upsample_match_channel": True}
+ ddconfig['z_channels'] = sd["decoder.conv_in.conv.weight"].shape[1]
+ self.latent_channels = 64
+ self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 16, 16)
+ self.upscale_index_formula = (4, 16, 16)
+ self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 16, 16)
+ self.downscale_index_formula = (4, 16, 16)
+ self.latent_dim = 3
+ self.not_video = True
+ self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
+ self.first_stage_model = AutoencodingEngine(regularizer_config={'target': "comfy.ldm.models.autoencoder.EmptyRegularizer"},
+ encoder_config={'target': "comfy.ldm.hunyuan_video.vae_refiner.Encoder", 'params': ddconfig},
+ decoder_config={'target': "comfy.ldm.hunyuan_video.vae_refiner.Decoder", 'params': ddconfig})
+
+ self.memory_used_encode = lambda shape, dtype: (1400 * shape[-2] * shape[-1]) * model_management.dtype_size(dtype)
+ self.memory_used_decode = lambda shape, dtype: (1400 * shape[-3] * shape[-2] * shape[-1] * 16 * 16) * model_management.dtype_size(dtype)
elif "decoder.conv_in.conv.weight" in sd:
ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
ddconfig["conv3d"] = True
@@ -446,17 +480,29 @@ class VAE:
self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32]
self.memory_used_encode = lambda shape, dtype: 6000 * shape[3] * shape[4] * model_management.dtype_size(dtype)
self.memory_used_decode = lambda shape, dtype: 7000 * shape[3] * shape[4] * (8 * 8) * model_management.dtype_size(dtype)
+ # Hunyuan 3d v2 2.0 & 2.1
elif "geo_decoder.cross_attn_decoder.ln_1.bias" in sd:
+
self.latent_dim = 1
- ln_post = "geo_decoder.ln_post.weight" in sd
- inner_size = sd["geo_decoder.output_proj.weight"].shape[1]
- downsample_ratio = sd["post_kl.weight"].shape[0] // inner_size
- mlp_expand = sd["geo_decoder.cross_attn_decoder.mlp.c_fc.weight"].shape[0] // inner_size
- self.memory_used_encode = lambda shape, dtype: (1000 * shape[2]) * model_management.dtype_size(dtype) # TODO
- self.memory_used_decode = lambda shape, dtype: (1024 * 1024 * 1024 * 2.0) * model_management.dtype_size(dtype) # TODO
- ddconfig = {"embed_dim": 64, "num_freqs": 8, "include_pi": False, "heads": 16, "width": 1024, "num_decoder_layers": 16, "qkv_bias": False, "qk_norm": True, "geo_decoder_mlp_expand_ratio": mlp_expand, "geo_decoder_downsample_ratio": downsample_ratio, "geo_decoder_ln_post": ln_post}
- self.first_stage_model = comfy.ldm.hunyuan3d.vae.ShapeVAE(**ddconfig)
+
+ def estimate_memory(shape, dtype, num_layers = 16, kv_cache_multiplier = 2):
+ batch, num_tokens, hidden_dim = shape
+ dtype_size = model_management.dtype_size(dtype)
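+                # One activation tensor of shape [batch, num_tokens, hidden_dim] plus an approximate K/V cache term per transformer layer.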
+
+ total_mem = batch * num_tokens * hidden_dim * dtype_size * (1 + kv_cache_multiplier * num_layers)
+ return total_mem
+
+ # better memory estimations
+ self.memory_used_encode = lambda shape, dtype, num_layers = 8, kv_cache_multiplier = 0:\
+ estimate_memory(shape, dtype, num_layers, kv_cache_multiplier)
+
+ self.memory_used_decode = lambda shape, dtype, num_layers = 16, kv_cache_multiplier = 2: \
+ estimate_memory(shape, dtype, num_layers, kv_cache_multiplier)
+
+ self.first_stage_model = comfy.ldm.hunyuan3d.vae.ShapeVAE()
self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
+
+
elif "vocoder.backbone.channel_layers.0.0.bias" in sd: #Ace Step Audio
self.first_stage_model = comfy.ldm.ace.vae.music_dcae_pipeline.MusicDCAE(source_sample_rate=44100)
self.memory_used_encode = lambda shape, dtype: (shape[2] * 330) * model_management.dtype_size(dtype)
@@ -471,6 +517,15 @@ class VAE:
self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32]
self.disable_offload = True
self.extra_1d_channel = 16
+ elif "pixel_space_vae" in sd:
+ self.first_stage_model = comfy.pixel_space_convert.PixelspaceConversionVAE()
+ self.memory_used_encode = lambda shape, dtype: (1 * shape[2] * shape[3]) * model_management.dtype_size(dtype)
+ self.memory_used_decode = lambda shape, dtype: (1 * shape[2] * shape[3]) * model_management.dtype_size(dtype)
+ self.downscale_ratio = 1
+ self.upscale_ratio = 1
+ self.latent_channels = 3
+ self.latent_dim = 2
+ self.output_channels = 3
else:
logging.warning("WARNING: No VAE weights detected, VAE not initalized.")
self.first_stage_model = None
@@ -643,7 +698,10 @@ class VAE:
pixel_samples = self.vae_encode_crop_pixels(pixel_samples)
pixel_samples = pixel_samples.movedim(-1, 1)
if self.latent_dim == 3 and pixel_samples.ndim < 5:
- pixel_samples = pixel_samples.movedim(1, 0).unsqueeze(0)
+ if not self.not_video:
+ pixel_samples = pixel_samples.movedim(1, 0).unsqueeze(0)
+ else:
+ pixel_samples = pixel_samples.unsqueeze(2)
try:
memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype)
model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
@@ -677,7 +735,10 @@ class VAE:
dims = self.latent_dim
pixel_samples = pixel_samples.movedim(-1, 1)
if dims == 3:
- pixel_samples = pixel_samples.movedim(1, 0).unsqueeze(0)
+ if not self.not_video:
+ pixel_samples = pixel_samples.movedim(1, 0).unsqueeze(0)
+ else:
+ pixel_samples = pixel_samples.unsqueeze(2)
memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype) # TODO: calculate mem required for tile
model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
@@ -734,6 +795,7 @@ class VAE:
except:
return None
+
class StyleModel:
def __init__(self, model, device="cpu"):
self.model = model
@@ -773,6 +835,7 @@ class CLIPType(Enum):
ACE = 16
OMNIGEN2 = 17
QWEN_IMAGE = 18
+ HUNYUAN_IMAGE = 19
def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
@@ -794,6 +857,7 @@ class TEModel(Enum):
GEMMA_2_2B = 9
QWEN25_3B = 10
QWEN25_7B = 11
+ BYT5_SMALL_GLYPH = 12
def detect_te_model(sd):
if "text_model.encoder.layers.30.mlp.fc1.weight" in sd:
@@ -811,6 +875,9 @@ def detect_te_model(sd):
if 'encoder.block.23.layer.1.DenseReluDense.wi.weight' in sd:
return TEModel.T5_XXL_OLD
if "encoder.block.0.layer.0.SelfAttention.k.weight" in sd:
+ weight = sd['encoder.block.0.layer.0.SelfAttention.k.weight']
+ if weight.shape[0] == 384:
+ return TEModel.BYT5_SMALL_GLYPH
return TEModel.T5_BASE
if 'model.layers.0.post_feedforward_layernorm.weight' in sd:
return TEModel.GEMMA_2_2B
@@ -925,8 +992,12 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
clip_target.clip = comfy.text_encoders.omnigen2.te(**llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.omnigen2.Omnigen2Tokenizer
elif te_model == TEModel.QWEN25_7B:
- clip_target.clip = comfy.text_encoders.qwen_image.te(**llama_detect(clip_data))
- clip_target.tokenizer = comfy.text_encoders.qwen_image.QwenImageTokenizer
+ if clip_type == CLIPType.HUNYUAN_IMAGE:
+ clip_target.clip = comfy.text_encoders.hunyuan_image.te(byt5=False, **llama_detect(clip_data))
+ clip_target.tokenizer = comfy.text_encoders.hunyuan_image.HunyuanImageTokenizer
+ else:
+ clip_target.clip = comfy.text_encoders.qwen_image.te(**llama_detect(clip_data))
+ clip_target.tokenizer = comfy.text_encoders.qwen_image.QwenImageTokenizer
else:
# clip_l
if clip_type == CLIPType.SD3:
@@ -970,6 +1041,9 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
clip_target.clip = comfy.text_encoders.hidream.hidream_clip(clip_l=clip_l, clip_g=clip_g, t5=t5, llama=llama, **t5_kwargs, **llama_kwargs)
clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer
+ elif clip_type == CLIPType.HUNYUAN_IMAGE:
+ clip_target.clip = comfy.text_encoders.hunyuan_image.te(**llama_detect(clip_data))
+ clip_target.tokenizer = comfy.text_encoders.hunyuan_image.HunyuanImageTokenizer
else:
clip_target.clip = sdxl_clip.SDXLClipModel
clip_target.tokenizer = sdxl_clip.SDXLTokenizer
diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py
index ade340fd1..f8a7c2a1b 100644
--- a/comfy/sd1_clip.py
+++ b/comfy/sd1_clip.py
@@ -204,17 +204,19 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
tokens_embed = self.transformer.get_input_embeddings()(tokens_embed, out_dtype=torch.float32)
index = 0
pad_extra = 0
+ embeds_info = []
for o in other_embeds:
emb = o[1]
if torch.is_tensor(emb):
emb = {"type": "embedding", "data": emb}
+ extra = None
emb_type = emb.get("type", None)
if emb_type == "embedding":
emb = emb.get("data", None)
else:
if hasattr(self.transformer, "preprocess_embed"):
- emb = self.transformer.preprocess_embed(emb, device=device)
+ emb, extra = self.transformer.preprocess_embed(emb, device=device)
else:
emb = None
@@ -229,6 +231,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
tokens_embed = torch.cat([tokens_embed[:, :ind], emb, tokens_embed[:, ind:]], dim=1)
attention_mask = attention_mask[:ind] + [1] * emb_shape + attention_mask[ind:]
index += emb_shape - 1
+ embeds_info.append({"type": emb_type, "index": ind, "size": emb_shape, "extra": extra})
else:
index += -1
pad_extra += emb_shape
@@ -243,11 +246,11 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
attention_masks.append(attention_mask)
num_tokens.append(sum(attention_mask))
- return torch.cat(embeds_out), torch.tensor(attention_masks, device=device, dtype=torch.long), num_tokens
+ return torch.cat(embeds_out), torch.tensor(attention_masks, device=device, dtype=torch.long), num_tokens, embeds_info
def forward(self, tokens):
device = self.transformer.get_input_embeddings().weight.device
- embeds, attention_mask, num_tokens = self.process_tokens(tokens, device)
+ embeds, attention_mask, num_tokens, embeds_info = self.process_tokens(tokens, device)
attention_mask_model = None
if self.enable_attention_masks:
@@ -258,7 +261,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
else:
intermediate_output = self.layer_idx
- outputs = self.transformer(None, attention_mask_model, embeds=embeds, num_tokens=num_tokens, intermediate_output=intermediate_output, final_layer_norm_intermediate=self.layer_norm_hidden_state, dtype=torch.float32)
+ outputs = self.transformer(None, attention_mask_model, embeds=embeds, num_tokens=num_tokens, intermediate_output=intermediate_output, final_layer_norm_intermediate=self.layer_norm_hidden_state, dtype=torch.float32, embeds_info=embeds_info)
if self.layer == "last":
z = outputs[0].float()
@@ -531,7 +534,10 @@ class SDTokenizer:
min_padding = tokenizer_options.get("{}_min_padding".format(self.embedding_key), self.min_padding)
text = escape_important(text)
- parsed_weights = token_weights(text, 1.0)
+ if kwargs.get("disable_weights", False):
+ parsed_weights = [(text, 1.0)]
+ else:
+ parsed_weights = token_weights(text, 1.0)
# tokenize words
tokens = []
diff --git a/comfy/supported_models.py b/comfy/supported_models.py
index 7ed6dfd69..213b5b92c 100644
--- a/comfy/supported_models.py
+++ b/comfy/supported_models.py
@@ -20,6 +20,7 @@ import comfy.text_encoders.wan
import comfy.text_encoders.ace
import comfy.text_encoders.omnigen2
import comfy.text_encoders.qwen_image
+import comfy.text_encoders.hunyuan_image
from . import supported_models_base
from . import latent_formats
@@ -700,7 +701,7 @@ class Flux(supported_models_base.BASE):
unet_extra_config = {}
latent_format = latent_formats.Flux
- memory_usage_factor = 2.8
+ memory_usage_factor = 3.1 # TODO: debug why flux mem usage is so weird on windows.
supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]
@@ -1072,6 +1073,29 @@ class WAN21_Vace(WAN21_T2V):
out = model_base.WAN21_Vace(self, image_to_video=False, device=device)
return out
+class WAN21_HuMo(WAN21_T2V):
+ unet_config = {
+ "image_model": "wan2.1",
+ "model_type": "humo",
+ }
+
+ def get_model(self, state_dict, prefix="", device=None):
+ out = model_base.WAN21_HuMo(self, image_to_video=False, device=device)
+ return out
+
+class WAN22_S2V(WAN21_T2V):
+ unet_config = {
+ "image_model": "wan2.1",
+ "model_type": "s2v",
+ }
+
+ def __init__(self, unet_config):
+ super().__init__(unet_config)
+
+ def get_model(self, state_dict, prefix="", device=None):
+ out = model_base.WAN22_S2V(self, device=device)
+ return out
+
class WAN22_T2V(WAN21_T2V):
unet_config = {
"image_model": "wan2.1",
@@ -1115,6 +1139,17 @@ class Hunyuan3Dv2(supported_models_base.BASE):
def clip_target(self, state_dict={}):
return None
+class Hunyuan3Dv2_1(Hunyuan3Dv2):
+ unet_config = {
+ "image_model": "hunyuan3d2_1",
+ }
+
+ latent_format = latent_formats.Hunyuan3Dv2_1
+
+ def get_model(self, state_dict, prefix="", device=None):
+ out = model_base.Hunyuan3Dv2_1(self, device = device)
+ return out
+
class Hunyuan3Dv2mini(Hunyuan3Dv2):
unet_config = {
"image_model": "hunyuan3d2",
@@ -1180,6 +1215,19 @@ class Chroma(supported_models_base.BASE):
t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
return supported_models_base.ClipTarget(comfy.text_encoders.pixart_t5.PixArtTokenizer, comfy.text_encoders.pixart_t5.pixart_te(**t5_detect))
+class ChromaRadiance(Chroma):
+ unet_config = {
+ "image_model": "chroma_radiance",
+ }
+
+ latent_format = comfy.latent_formats.ChromaRadiance
+
+ # Pixel-space model, no spatial compression for model input.
+ memory_usage_factor = 0.038
+
+ def get_model(self, state_dict, prefix="", device=None):
+ return model_base.ChromaRadiance(self, device=device)
+
class ACEStep(supported_models_base.BASE):
unet_config = {
"audio_model": "ace",
@@ -1271,7 +1319,48 @@ class QwenImage(supported_models_base.BASE):
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref))
return supported_models_base.ClipTarget(comfy.text_encoders.qwen_image.QwenImageTokenizer, comfy.text_encoders.qwen_image.te(**hunyuan_detect))
+class HunyuanImage21(HunyuanVideo):
+ unet_config = {
+ "image_model": "hunyuan_video",
+ "vec_in_dim": None,
+ }
-models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream, Chroma, ACEStep, Omnigen2, QwenImage]
+ sampling_settings = {
+ "shift": 5.0,
+ }
+
+ latent_format = latent_formats.HunyuanImage21
+
+ memory_usage_factor = 7.7
+
+ supported_inference_dtypes = [torch.bfloat16, torch.float32]
+
+ def get_model(self, state_dict, prefix="", device=None):
+ out = model_base.HunyuanImage21(self, device=device)
+ return out
+
+ def clip_target(self, state_dict={}):
+ pref = self.text_encoder_key_prefix[0]
+ hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref))
+ return supported_models_base.ClipTarget(comfy.text_encoders.hunyuan_image.HunyuanImageTokenizer, comfy.text_encoders.hunyuan_image.te(**hunyuan_detect))
+
+class HunyuanImage21Refiner(HunyuanVideo):
+ unet_config = {
+ "image_model": "hunyuan_video",
+ "patch_size": [1, 1, 1],
+ "vec_in_dim": None,
+ }
+
+ sampling_settings = {
+ "shift": 4.0,
+ }
+
+ latent_format = latent_formats.HunyuanImage21Refiner
+
+ def get_model(self, state_dict, prefix="", device=None):
+ out = model_base.HunyuanImage21Refiner(self, device=device)
+ return out
+
+models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage]
models += [SVD_img2vid]
diff --git a/comfy/text_encoders/bert.py b/comfy/text_encoders/bert.py
index 551b03162..ed4638a9a 100644
--- a/comfy/text_encoders/bert.py
+++ b/comfy/text_encoders/bert.py
@@ -116,7 +116,7 @@ class BertModel_(torch.nn.Module):
self.embeddings = BertEmbeddings(config_dict["vocab_size"], config_dict["max_position_embeddings"], config_dict["type_vocab_size"], config_dict["pad_token_id"], embed_dim, layer_norm_eps, dtype, device, operations)
self.encoder = BertEncoder(config_dict["num_hidden_layers"], embed_dim, config_dict["intermediate_size"], config_dict["num_attention_heads"], layer_norm_eps, dtype, device, operations)
- def forward(self, input_tokens, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None):
+ def forward(self, input_tokens, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, embeds_info=[]):
x = self.embeddings(input_tokens, embeds=embeds, dtype=dtype)
mask = None
if attention_mask is not None:
diff --git a/comfy/text_encoders/byt5_config_small_glyph.json b/comfy/text_encoders/byt5_config_small_glyph.json
new file mode 100644
index 000000000..0239c7164
--- /dev/null
+++ b/comfy/text_encoders/byt5_config_small_glyph.json
@@ -0,0 +1,22 @@
+{
+ "d_ff": 3584,
+ "d_kv": 64,
+ "d_model": 1472,
+ "decoder_start_token_id": 0,
+ "dropout_rate": 0.1,
+ "eos_token_id": 1,
+ "dense_act_fn": "gelu_pytorch_tanh",
+ "initializer_factor": 1.0,
+ "is_encoder_decoder": true,
+ "is_gated_act": true,
+ "layer_norm_epsilon": 1e-06,
+ "model_type": "t5",
+ "num_decoder_layers": 4,
+ "num_heads": 6,
+ "num_layers": 12,
+ "output_past": true,
+ "pad_token_id": 0,
+ "relative_attention_num_buckets": 32,
+ "tie_word_embeddings": false,
+ "vocab_size": 1510
+}
diff --git a/comfy/text_encoders/byt5_tokenizer/added_tokens.json b/comfy/text_encoders/byt5_tokenizer/added_tokens.json
new file mode 100644
index 000000000..93c190b56
--- /dev/null
+++ b/comfy/text_encoders/byt5_tokenizer/added_tokens.json
@@ -0,0 +1,127 @@
+{
+  "<extra_id_0>": 259,
+  "<extra_id_100>": 359,
+  "<extra_id_101>": 360,
+  "<extra_id_102>": 361,
+  "<extra_id_103>": 362,
+  "<extra_id_104>": 363,
+  "<extra_id_105>": 364,
+  "<extra_id_106>": 365,
+  "<extra_id_107>": 366,
+  "<extra_id_108>": 367,
+  "<extra_id_109>": 368,
+  "<extra_id_10>": 269,
+  "<extra_id_110>": 369,
+  "<extra_id_111>": 370,
+  "<extra_id_112>": 371,
+  "<extra_id_113>": 372,
+  "<extra_id_114>": 373,
+  "<extra_id_115>": 374,
+  "<extra_id_116>": 375,
+  "<extra_id_117>": 376,
+  "<extra_id_118>": 377,
+  "<extra_id_119>": 378,
+  "<extra_id_11>": 270,
+  "<extra_id_120>": 379,
+  "<extra_id_121>": 380,
+  "<extra_id_122>": 381,
+  "<extra_id_123>": 382,
+  "<extra_id_124>": 383,
+  "<extra_id_12>": 271,
+  "<extra_id_13>": 272,
+  "<extra_id_14>": 273,
+  "<extra_id_15>": 274,
+  "<extra_id_16>": 275,
+  "<extra_id_17>": 276,
+  "<extra_id_18>": 277,
+  "<extra_id_19>": 278,
+  "<extra_id_1>": 260,
+  "<extra_id_20>": 279,
+  "<extra_id_21>": 280,
+  "<extra_id_22>": 281,
+  "<extra_id_23>": 282,
+  "<extra_id_24>": 283,
+  "<extra_id_25>": 284,
+  "<extra_id_26>": 285,
+  "<extra_id_27>": 286,
+  "<extra_id_28>": 287,
+  "<extra_id_29>": 288,
+  "<extra_id_2>": 261,
+  "<extra_id_30>": 289,
+  "<extra_id_31>": 290,
+  "<extra_id_32>": 291,
+  "<extra_id_33>": 292,
+  "<extra_id_34>": 293,
+  "<extra_id_35>": 294,
+  "<extra_id_36>": 295,
+  "<extra_id_37>": 296,
+  "<extra_id_38>": 297,
+  "<extra_id_39>": 298,
+  "<extra_id_3>": 262,
+  "<extra_id_40>": 299,
+  "<extra_id_41>": 300,
+  "<extra_id_42>": 301,
+  "<extra_id_43>": 302,
+  "<extra_id_44>": 303,
+  "<extra_id_45>": 304,
+  "<extra_id_46>": 305,
+  "<extra_id_47>": 306,
+  "<extra_id_48>": 307,
+  "<extra_id_49>": 308,
+  "<extra_id_4>": 263,
+  "<extra_id_50>": 309,
+  "<extra_id_51>": 310,
+  "<extra_id_52>": 311,
+  "<extra_id_53>": 312,
+  "<extra_id_54>": 313,
+  "<extra_id_55>": 314,
+  "<extra_id_56>": 315,
+  "<extra_id_57>": 316,
+  "<extra_id_58>": 317,
+  "<extra_id_59>": 318,
+  "<extra_id_5>": 264,
+  "<extra_id_60>": 319,
+  "<extra_id_61>": 320,
+  "<extra_id_62>": 321,
+  "<extra_id_63>": 322,
+  "<extra_id_64>": 323,
+  "<extra_id_65>": 324,
+  "<extra_id_66>": 325,
+  "<extra_id_67>": 326,
+  "<extra_id_68>": 327,
+  "<extra_id_69>": 328,
+  "<extra_id_6>": 265,
+  "<extra_id_70>": 329,
+  "<extra_id_71>": 330,
+  "<extra_id_72>": 331,
+  "<extra_id_73>": 332,
+  "<extra_id_74>": 333,
+  "<extra_id_75>": 334,
+  "<extra_id_76>": 335,
+  "<extra_id_77>": 336,
+  "<extra_id_78>": 337,
+  "<extra_id_79>": 338,
+  "<extra_id_7>": 266,
+  "<extra_id_80>": 339,
+  "<extra_id_81>": 340,
+  "<extra_id_82>": 341,
+  "<extra_id_83>": 342,
+  "<extra_id_84>": 343,
+  "<extra_id_85>": 344,
+  "<extra_id_86>": 345,
+  "<extra_id_87>": 346,
+  "<extra_id_88>": 347,
+  "<extra_id_89>": 348,
+  "<extra_id_8>": 267,
+  "<extra_id_90>": 349,
+  "<extra_id_91>": 350,
+  "<extra_id_92>": 351,
+  "<extra_id_93>": 352,
+  "<extra_id_94>": 353,
+  "<extra_id_95>": 354,
+  "<extra_id_96>": 355,
+  "<extra_id_97>": 356,
+  "<extra_id_98>": 357,
+  "<extra_id_99>": 358,
+  "<extra_id_9>": 268
+}
diff --git a/comfy/text_encoders/byt5_tokenizer/special_tokens_map.json b/comfy/text_encoders/byt5_tokenizer/special_tokens_map.json
new file mode 100644
index 000000000..04fd58b5f
--- /dev/null
+++ b/comfy/text_encoders/byt5_tokenizer/special_tokens_map.json
@@ -0,0 +1,150 @@
+{
+ "additional_special_tokens": [
+    "<extra_id_0>",
+    "<extra_id_1>",
+    "<extra_id_2>",
+    "<extra_id_3>",
+    "<extra_id_4>",
+    "<extra_id_5>",
+    "<extra_id_6>",
+    "<extra_id_7>",
+    "<extra_id_8>",
+    "<extra_id_9>",
+    "<extra_id_10>",
+    "<extra_id_11>",
+    "<extra_id_12>",
+    "<extra_id_13>",
+    "<extra_id_14>",
+    "<extra_id_15>",
+    "<extra_id_16>",
+    "<extra_id_17>",
+    "<extra_id_18>",
+    "<extra_id_19>",
+    "<extra_id_20>",
+    "<extra_id_21>",
+    "<extra_id_22>",
+    "<extra_id_23>",
+    "<extra_id_24>",
+    "<extra_id_25>",
+    "<extra_id_26>",
+    "<extra_id_27>",
+    "<extra_id_28>",
+    "<extra_id_29>",
+    "<extra_id_30>",
+    "<extra_id_31>",
+    "<extra_id_32>",
+    "<extra_id_33>",
+    "<extra_id_34>",
+    "<extra_id_35>",
+    "<extra_id_36>",
+    "<extra_id_37>",
+    "<extra_id_38>",
+    "<extra_id_39>",
+    "<extra_id_40>",
+    "<extra_id_41>",
+    "<extra_id_42>",
+    "<extra_id_43>",
+    "<extra_id_44>",
+    "<extra_id_45>",
+    "<extra_id_46>",
+    "<extra_id_47>",
+    "<extra_id_48>",
+    "<extra_id_49>",
+    "<extra_id_50>",
+    "<extra_id_51>",
+    "<extra_id_52>",
+    "<extra_id_53>",
+    "<extra_id_54>",
+    "<extra_id_55>",
+    "<extra_id_56>",
+    "<extra_id_57>",
+    "<extra_id_58>",
+    "<extra_id_59>",
+    "<extra_id_60>",
+    "<extra_id_61>",
+    "<extra_id_62>",
+    "<extra_id_63>",
+    "<extra_id_64>",
+    "<extra_id_65>",
+    "<extra_id_66>",
+    "<extra_id_67>",
+    "<extra_id_68>",
+    "<extra_id_69>",
+    "<extra_id_70>",
+    "<extra_id_71>",
+    "<extra_id_72>",
+    "<extra_id_73>",
+    "<extra_id_74>",
+    "<extra_id_75>",
+    "<extra_id_76>",
+    "<extra_id_77>",
+    "<extra_id_78>",
+    "<extra_id_79>",
+    "<extra_id_80>",
+    "<extra_id_81>",
+    "<extra_id_82>",
+    "<extra_id_83>",
+    "<extra_id_84>",
+    "<extra_id_85>",
+    "<extra_id_86>",
+    "<extra_id_87>",
+    "<extra_id_88>",
+    "<extra_id_89>",
+    "<extra_id_90>",
+    "<extra_id_91>",
+    "<extra_id_92>",
+    "<extra_id_93>",
+    "<extra_id_94>",
+    "<extra_id_95>",
+    "<extra_id_96>",
+    "<extra_id_97>",
+    "<extra_id_98>",
+    "<extra_id_99>",
+    "<extra_id_100>",
+    "<extra_id_101>",
+    "<extra_id_102>",
+    "<extra_id_103>",
+    "<extra_id_104>",
+    "<extra_id_105>",
+    "<extra_id_106>",
+    "<extra_id_107>",
+    "<extra_id_108>",
+    "<extra_id_109>",
+    "<extra_id_110>",
+    "<extra_id_111>",
+    "<extra_id_112>",
+    "<extra_id_113>",
+    "<extra_id_114>",
+    "<extra_id_115>",
+    "<extra_id_116>",
+    "<extra_id_117>",
+    "<extra_id_118>",
+    "<extra_id_119>",
+    "<extra_id_120>",
+    "<extra_id_121>",
+    "<extra_id_122>",
+    "<extra_id_123>",
+    "<extra_id_124>"
+ ],
+ "eos_token": {
+    "content": "</s>",
+ "lstrip": false,
+ "normalized": true,
+ "rstrip": false,
+ "single_word": false
+ },
+ "pad_token": {
+    "content": "<pad>",
+ "lstrip": false,
+ "normalized": true,
+ "rstrip": false,
+ "single_word": false
+ },
+ "unk_token": {
+    "content": "<unk>",
+ "lstrip": false,
+ "normalized": true,
+ "rstrip": false,
+ "single_word": false
+ }
+}
diff --git a/comfy/text_encoders/byt5_tokenizer/tokenizer_config.json b/comfy/text_encoders/byt5_tokenizer/tokenizer_config.json
new file mode 100644
index 000000000..5b1fe24c1
--- /dev/null
+++ b/comfy/text_encoders/byt5_tokenizer/tokenizer_config.json
@@ -0,0 +1,1163 @@
+{
+ "added_tokens_decoder": {
+ "0": {
+      "content": "<pad>",
+ "lstrip": false,
+ "normalized": true,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ },
+ "1": {
+      "content": "</s>",
+ "lstrip": false,
+ "normalized": true,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ },
+ "2": {
+      "content": "<unk>",
+ "lstrip": false,
+ "normalized": true,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ },
+ "259": {
+ "content": "