Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2025-12-17 10:02:59 +08:00)

commit bb44c2ecb9
Merge branch 'master' into worksplit-multigpu
@@ -66,6 +66,7 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith
    - [Lumina Image 2.0](https://comfyanonymous.github.io/ComfyUI_examples/lumina2/)
    - [HiDream](https://comfyanonymous.github.io/ComfyUI_examples/hidream/)
    - [Qwen Image](https://comfyanonymous.github.io/ComfyUI_examples/qwen_image/)
+   - [Hunyuan Image 2.1](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_image/)
 - Image Editing Models
    - [Omnigen 2](https://comfyanonymous.github.io/ComfyUI_examples/omnigen/)
    - [Flux Kontext](https://comfyanonymous.github.io/ComfyUI_examples/flux/#flux-kontext-image-editing-model)
@@ -1,4 +1,5 @@
 from .wav2vec2 import Wav2Vec2Model
+from .whisper import WhisperLargeV3
 import comfy.model_management
 import comfy.ops
 import comfy.utils
@@ -11,7 +12,18 @@ class AudioEncoderModel():
         self.load_device = comfy.model_management.text_encoder_device()
         offload_device = comfy.model_management.text_encoder_offload_device()
         self.dtype = comfy.model_management.text_encoder_dtype(self.load_device)
-        self.model = Wav2Vec2Model(dtype=self.dtype, device=offload_device, operations=comfy.ops.manual_cast)
+        model_type = config.pop("model_type")
+        model_config = dict(config)
+        model_config.update({
+            "dtype": self.dtype,
+            "device": offload_device,
+            "operations": comfy.ops.manual_cast
+        })
+
+        if model_type == "wav2vec2":
+            self.model = Wav2Vec2Model(**model_config)
+        elif model_type == "whisper3":
+            self.model = WhisperLargeV3(**model_config)
         self.model.eval()
         self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
         self.model_sample_rate = 16000
@@ -29,14 +41,51 @@ class AudioEncoderModel():
         outputs = {}
         outputs["encoded_audio"] = out
         outputs["encoded_audio_all_layers"] = all_layers
+        outputs["audio_samples"] = audio.shape[2]
         return outputs


 def load_audio_encoder_from_sd(sd, prefix=""):
-    audio_encoder = AudioEncoderModel(None)
     sd = comfy.utils.state_dict_prefix_replace(sd, {"wav2vec2.": ""})
+    if "encoder.layer_norm.bias" in sd: #wav2vec2
+        embed_dim = sd["encoder.layer_norm.bias"].shape[0]
+        if embed_dim == 1024: # large
+            config = {
+                "model_type": "wav2vec2",
+                "embed_dim": 1024,
+                "num_heads": 16,
+                "num_layers": 24,
+                "conv_norm": True,
+                "conv_bias": True,
+                "do_normalize": True,
+                "do_stable_layer_norm": True
+            }
+        elif embed_dim == 768: # base
+            config = {
+                "model_type": "wav2vec2",
+                "embed_dim": 768,
+                "num_heads": 12,
+                "num_layers": 12,
+                "conv_norm": False,
+                "conv_bias": False,
+                "do_normalize": False, # chinese-wav2vec2-base has this False
+                "do_stable_layer_norm": False
+            }
+        else:
+            raise RuntimeError("ERROR: audio encoder file is invalid or unsupported embed_dim: {}".format(embed_dim))
+    elif "model.encoder.embed_positions.weight" in sd:
+        sd = comfy.utils.state_dict_prefix_replace(sd, {"model.": ""})
+        config = {
+            "model_type": "whisper3",
+        }
+    else:
+        raise RuntimeError("ERROR: audio encoder not supported.")
+
+    audio_encoder = AudioEncoderModel(config)
     m, u = audio_encoder.load_sd(sd)
     if len(m) > 0:
         logging.warning("missing audio encoder: {}".format(m))
+    if len(u) > 0:
+        logging.warning("unexpected audio encoder: {}".format(u))

     return audio_encoder
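The loader above now picks the encoder type from the checkpoint itself: wav2vec2 state dicts are recognized by the `encoder.layer_norm.bias` key (with the embedding width distinguishing base from large), Whisper checkpoints by `model.encoder.embed_positions.weight`, and `AudioEncoderModel` then routes on `config["model_type"]`. A minimal standalone sketch of that key-sniffing pattern; `detect_audio_encoder` is illustrative only, not a ComfyUI API:

# Illustrative sketch of the dispatch used by load_audio_encoder_from_sd above.
def detect_audio_encoder(sd: dict) -> dict:
    if "encoder.layer_norm.bias" in sd:  # wav2vec2 checkpoints carry this key
        embed_dim = sd["encoder.layer_norm.bias"].shape[0]
        if embed_dim == 1024:
            return {"model_type": "wav2vec2", "embed_dim": 1024, "num_heads": 16, "num_layers": 24}
        if embed_dim == 768:
            return {"model_type": "wav2vec2", "embed_dim": 768, "num_heads": 12, "num_layers": 12}
        raise RuntimeError("unsupported embed_dim: {}".format(embed_dim))
    if "model.encoder.embed_positions.weight" in sd:  # Whisper-style checkpoints
        return {"model_type": "whisper3"}
    raise RuntimeError("audio encoder not supported")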
@@ -13,18 +13,48 @@ class LayerNormConv(nn.Module):
         x = self.conv(x)
         return torch.nn.functional.gelu(self.layer_norm(x.transpose(-2, -1)).transpose(-2, -1))

+class LayerGroupNormConv(nn.Module):
+    def __init__(self, in_channels, out_channels, kernel_size, stride, bias=False, dtype=None, device=None, operations=None):
+        super().__init__()
+        self.conv = operations.Conv1d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, bias=bias, device=device, dtype=dtype)
+        self.layer_norm = operations.GroupNorm(num_groups=out_channels, num_channels=out_channels, affine=True, device=device, dtype=dtype)
+
+    def forward(self, x):
+        x = self.conv(x)
+        return torch.nn.functional.gelu(self.layer_norm(x))
+
+class ConvNoNorm(nn.Module):
+    def __init__(self, in_channels, out_channels, kernel_size, stride, bias=False, dtype=None, device=None, operations=None):
+        super().__init__()
+        self.conv = operations.Conv1d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, bias=bias, device=device, dtype=dtype)
+
+    def forward(self, x):
+        x = self.conv(x)
+        return torch.nn.functional.gelu(x)
+

 class ConvFeatureEncoder(nn.Module):
-    def __init__(self, conv_dim, dtype=None, device=None, operations=None):
+    def __init__(self, conv_dim, conv_bias=False, conv_norm=True, dtype=None, device=None, operations=None):
         super().__init__()
-        self.conv_layers = nn.ModuleList([
-            LayerNormConv(1, conv_dim, kernel_size=10, stride=5, bias=True, device=device, dtype=dtype, operations=operations),
-            LayerNormConv(conv_dim, conv_dim, kernel_size=3, stride=2, bias=True, device=device, dtype=dtype, operations=operations),
-            LayerNormConv(conv_dim, conv_dim, kernel_size=3, stride=2, bias=True, device=device, dtype=dtype, operations=operations),
-            LayerNormConv(conv_dim, conv_dim, kernel_size=3, stride=2, bias=True, device=device, dtype=dtype, operations=operations),
-            LayerNormConv(conv_dim, conv_dim, kernel_size=3, stride=2, bias=True, device=device, dtype=dtype, operations=operations),
-            LayerNormConv(conv_dim, conv_dim, kernel_size=2, stride=2, bias=True, device=device, dtype=dtype, operations=operations),
-            LayerNormConv(conv_dim, conv_dim, kernel_size=2, stride=2, bias=True, device=device, dtype=dtype, operations=operations),
-        ])
+        if conv_norm:
+            self.conv_layers = nn.ModuleList([
+                LayerNormConv(1, conv_dim, kernel_size=10, stride=5, bias=True, device=device, dtype=dtype, operations=operations),
+                LayerNormConv(conv_dim, conv_dim, kernel_size=3, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
+                LayerNormConv(conv_dim, conv_dim, kernel_size=3, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
+                LayerNormConv(conv_dim, conv_dim, kernel_size=3, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
+                LayerNormConv(conv_dim, conv_dim, kernel_size=3, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
+                LayerNormConv(conv_dim, conv_dim, kernel_size=2, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
+                LayerNormConv(conv_dim, conv_dim, kernel_size=2, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
+            ])
+        else:
+            self.conv_layers = nn.ModuleList([
+                LayerGroupNormConv(1, conv_dim, kernel_size=10, stride=5, bias=conv_bias, device=device, dtype=dtype, operations=operations),
+                ConvNoNorm(conv_dim, conv_dim, kernel_size=3, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
+                ConvNoNorm(conv_dim, conv_dim, kernel_size=3, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
+                ConvNoNorm(conv_dim, conv_dim, kernel_size=3, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
+                ConvNoNorm(conv_dim, conv_dim, kernel_size=3, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
+                ConvNoNorm(conv_dim, conv_dim, kernel_size=2, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
+                ConvNoNorm(conv_dim, conv_dim, kernel_size=2, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
+            ])

     def forward(self, x):
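Both branches above keep the same seven kernel/stride pairs, so the temporal downsampling is identical whichever normalization variant `conv_norm` selects; only the norm placement differs. A small sketch of the resulting frame rate, using the kernels and strides from the diff:

# Downsampling arithmetic for the 7-layer conv stack above (kernels/strides from the diff).
def conv_out_len(n_samples: int) -> int:
    for kernel, stride in [(10, 5), (3, 2), (3, 2), (3, 2), (3, 2), (2, 2), (2, 2)]:
        n_samples = (n_samples - kernel) // stride + 1
    return n_samples

print(conv_out_len(16000))  # 49 -> roughly 49 feature frames per second of 16 kHz audio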
@@ -76,6 +106,7 @@ class TransformerEncoder(nn.Module):
         num_heads=12,
         num_layers=12,
         mlp_ratio=4.0,
+        do_stable_layer_norm=True,
         dtype=None, device=None, operations=None
     ):
         super().__init__()
@@ -86,19 +117,24 @@ class TransformerEncoder(nn.Module):
                 embed_dim=embed_dim,
                 num_heads=num_heads,
                 mlp_ratio=mlp_ratio,
+                do_stable_layer_norm=do_stable_layer_norm,
                 device=device, dtype=dtype, operations=operations
             )
             for _ in range(num_layers)
         ])

         self.layer_norm = operations.LayerNorm(embed_dim, eps=1e-05, device=device, dtype=dtype)
+        self.do_stable_layer_norm = do_stable_layer_norm

     def forward(self, x, mask=None):
         x = x + self.pos_conv_embed(x)
         all_x = ()
+        if not self.do_stable_layer_norm:
+            x = self.layer_norm(x)
         for layer in self.layers:
             all_x += (x,)
             x = layer(x, mask)
-        x = self.layer_norm(x)
+        if self.do_stable_layer_norm:
+            x = self.layer_norm(x)
         all_x += (x,)
         return x, all_x
@@ -145,6 +181,7 @@ class TransformerEncoderLayer(nn.Module):
         embed_dim=768,
         num_heads=12,
         mlp_ratio=4.0,
+        do_stable_layer_norm=True,
         dtype=None, device=None, operations=None
     ):
         super().__init__()
@@ -154,15 +191,19 @@ class TransformerEncoderLayer(nn.Module):
         self.layer_norm = operations.LayerNorm(embed_dim, device=device, dtype=dtype)
         self.feed_forward = FeedForward(embed_dim, mlp_ratio, device=device, dtype=dtype, operations=operations)
         self.final_layer_norm = operations.LayerNorm(embed_dim, device=device, dtype=dtype)
+        self.do_stable_layer_norm = do_stable_layer_norm

     def forward(self, x, mask=None):
         residual = x
-        x = self.layer_norm(x)
+        if self.do_stable_layer_norm:
+            x = self.layer_norm(x)
         x = self.attention(x, mask=mask)
         x = residual + x
-        x = x + self.feed_forward(self.final_layer_norm(x))
-        return x
+        if not self.do_stable_layer_norm:
+            x = self.layer_norm(x)
+            return self.final_layer_norm(x + self.feed_forward(x))
+        else:
+            return x + self.feed_forward(self.final_layer_norm(x))


 class Wav2Vec2Model(nn.Module):
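The `do_stable_layer_norm` flag switches each encoder layer between a pre-norm ordering (normalize before attention and feed-forward, as the existing large model does) and a post-norm ordering (normalize after each residual add, used by the base-sized checkpoints the new config targets). A toy sketch of the two orderings, with plain linear layers standing in for the real attention and MLP sub-blocks:

import torch
import torch.nn as nn

# Toy illustration of the two residual orderings toggled by do_stable_layer_norm above.
class ToyEncoderLayer(nn.Module):
    def __init__(self, dim, do_stable_layer_norm=True):
        super().__init__()
        self.do_stable_layer_norm = do_stable_layer_norm
        self.layer_norm = nn.LayerNorm(dim)
        self.final_layer_norm = nn.LayerNorm(dim)
        self.attention = nn.Linear(dim, dim)      # stand-in for self-attention
        self.feed_forward = nn.Linear(dim, dim)   # stand-in for the MLP

    def forward(self, x):
        residual = x
        if self.do_stable_layer_norm:             # pre-norm ("stable layer norm")
            x = self.layer_norm(x)
        x = self.attention(x)
        x = residual + x
        if not self.do_stable_layer_norm:         # post-norm
            x = self.layer_norm(x)
            return self.final_layer_norm(x + self.feed_forward(x))
        return x + self.feed_forward(self.final_layer_norm(x))

x = torch.randn(1, 8, 16)
print(ToyEncoderLayer(16, True)(x).shape, ToyEncoderLayer(16, False)(x).shape)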
@@ -174,34 +215,38 @@ class Wav2Vec2Model(nn.Module):
         final_dim=256,
         num_heads=16,
         num_layers=24,
+        conv_norm=True,
+        conv_bias=True,
+        do_normalize=True,
+        do_stable_layer_norm=True,
         dtype=None, device=None, operations=None
     ):
         super().__init__()

         conv_dim = 512
-        self.feature_extractor = ConvFeatureEncoder(conv_dim, device=device, dtype=dtype, operations=operations)
+        self.feature_extractor = ConvFeatureEncoder(conv_dim, conv_norm=conv_norm, conv_bias=conv_bias, device=device, dtype=dtype, operations=operations)
         self.feature_projection = FeatureProjection(conv_dim, embed_dim, device=device, dtype=dtype, operations=operations)

         self.masked_spec_embed = nn.Parameter(torch.empty(embed_dim, device=device, dtype=dtype))
+        self.do_normalize = do_normalize

         self.encoder = TransformerEncoder(
             embed_dim=embed_dim,
             num_heads=num_heads,
             num_layers=num_layers,
+            do_stable_layer_norm=do_stable_layer_norm,
             device=device, dtype=dtype, operations=operations
         )

     def forward(self, x, mask_time_indices=None, return_dict=False):

         x = torch.mean(x, dim=1)

-        x = (x - x.mean()) / torch.sqrt(x.var() + 1e-7)
+        if self.do_normalize:
+            x = (x - x.mean()) / torch.sqrt(x.var() + 1e-7)

         features = self.feature_extractor(x)
         features = self.feature_projection(features)

         batch_size, seq_len, _ = features.shape

         x, all_x = self.encoder(features)

         return x, all_x
comfy/audio_encoders/whisper.py (new executable file, 186 lines)
@@ -0,0 +1,186 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torchaudio
+from typing import Optional
+from comfy.ldm.modules.attention import optimized_attention_masked
+import comfy.ops
+
+class WhisperFeatureExtractor(nn.Module):
+    def __init__(self, n_mels=128, device=None):
+        super().__init__()
+        self.sample_rate = 16000
+        self.n_fft = 400
+        self.hop_length = 160
+        self.n_mels = n_mels
+        self.chunk_length = 30
+        self.n_samples = 480000
+
+        self.mel_spectrogram = torchaudio.transforms.MelSpectrogram(
+            sample_rate=self.sample_rate,
+            n_fft=self.n_fft,
+            hop_length=self.hop_length,
+            n_mels=self.n_mels,
+            f_min=0,
+            f_max=8000,
+            norm="slaney",
+            mel_scale="slaney",
+        ).to(device)
+
+    def __call__(self, audio):
+        audio = torch.mean(audio, dim=1)
+        batch_size = audio.shape[0]
+        processed_audio = []
+
+        for i in range(batch_size):
+            aud = audio[i]
+            if aud.shape[0] > self.n_samples:
+                aud = aud[:self.n_samples]
+            elif aud.shape[0] < self.n_samples:
+                aud = F.pad(aud, (0, self.n_samples - aud.shape[0]))
+            processed_audio.append(aud)
+
+        audio = torch.stack(processed_audio)
+
+        mel_spec = self.mel_spectrogram(audio.to(self.mel_spectrogram.spectrogram.window.device))[:, :, :-1].to(audio.device)
+
+        log_mel_spec = torch.clamp(mel_spec, min=1e-10).log10()
+        log_mel_spec = torch.maximum(log_mel_spec, log_mel_spec.max() - 8.0)
+        log_mel_spec = (log_mel_spec + 4.0) / 4.0
+
+        return log_mel_spec
+
+
+class MultiHeadAttention(nn.Module):
+    def __init__(self, d_model: int, n_heads: int, dtype=None, device=None, operations=None):
+        super().__init__()
+        assert d_model % n_heads == 0
+
+        self.d_model = d_model
+        self.n_heads = n_heads
+        self.d_k = d_model // n_heads
+
+        self.q_proj = operations.Linear(d_model, d_model, dtype=dtype, device=device)
+        self.k_proj = operations.Linear(d_model, d_model, bias=False, dtype=dtype, device=device)
+        self.v_proj = operations.Linear(d_model, d_model, dtype=dtype, device=device)
+        self.out_proj = operations.Linear(d_model, d_model, dtype=dtype, device=device)
+
+    def forward(
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        mask: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        batch_size, seq_len, _ = query.shape
+
+        q = self.q_proj(query)
+        k = self.k_proj(key)
+        v = self.v_proj(value)
+
+        attn_output = optimized_attention_masked(q, k, v, self.n_heads, mask)
+        attn_output = self.out_proj(attn_output)
+
+        return attn_output
+
+
+class EncoderLayer(nn.Module):
+    def __init__(self, d_model: int, n_heads: int, d_ff: int, dtype=None, device=None, operations=None):
+        super().__init__()
+
+        self.self_attn = MultiHeadAttention(d_model, n_heads, dtype=dtype, device=device, operations=operations)
+        self.self_attn_layer_norm = operations.LayerNorm(d_model, dtype=dtype, device=device)
+
+        self.fc1 = operations.Linear(d_model, d_ff, dtype=dtype, device=device)
+        self.fc2 = operations.Linear(d_ff, d_model, dtype=dtype, device=device)
+        self.final_layer_norm = operations.LayerNorm(d_model, dtype=dtype, device=device)
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None
+    ) -> torch.Tensor:
+        residual = x
+        x = self.self_attn_layer_norm(x)
+        x = self.self_attn(x, x, x, attention_mask)
+        x = residual + x
+
+        residual = x
+        x = self.final_layer_norm(x)
+        x = self.fc1(x)
+        x = F.gelu(x)
+        x = self.fc2(x)
+        x = residual + x
+
+        return x
+
+
+class AudioEncoder(nn.Module):
+    def __init__(
+        self,
+        n_mels: int = 128,
+        n_ctx: int = 1500,
+        n_state: int = 1280,
+        n_head: int = 20,
+        n_layer: int = 32,
+        dtype=None,
+        device=None,
+        operations=None
+    ):
+        super().__init__()
+
+        self.conv1 = operations.Conv1d(n_mels, n_state, kernel_size=3, padding=1, dtype=dtype, device=device)
+        self.conv2 = operations.Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1, dtype=dtype, device=device)
+
+        self.embed_positions = operations.Embedding(n_ctx, n_state, dtype=dtype, device=device)
+
+        self.layers = nn.ModuleList([
+            EncoderLayer(n_state, n_head, n_state * 4, dtype=dtype, device=device, operations=operations)
+            for _ in range(n_layer)
+        ])
+
+        self.layer_norm = operations.LayerNorm(n_state, dtype=dtype, device=device)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = F.gelu(self.conv1(x))
+        x = F.gelu(self.conv2(x))

+        x = x.transpose(1, 2)
+
+        x = x + comfy.ops.cast_to_input(self.embed_positions.weight[:, :x.shape[1]], x)
+
+        all_x = ()
+        for layer in self.layers:
+            all_x += (x,)
+            x = layer(x)
+
+        x = self.layer_norm(x)
+        all_x += (x,)
+        return x, all_x
+
+
+class WhisperLargeV3(nn.Module):
+    def __init__(
+        self,
+        n_mels: int = 128,
+        n_audio_ctx: int = 1500,
+        n_audio_state: int = 1280,
+        n_audio_head: int = 20,
+        n_audio_layer: int = 32,
+        dtype=None,
+        device=None,
+        operations=None
+    ):
+        super().__init__()
+
+        self.feature_extractor = WhisperFeatureExtractor(n_mels=n_mels, device=device)
+
+        self.encoder = AudioEncoder(
+            n_mels, n_audio_ctx, n_audio_state, n_audio_head, n_audio_layer,
+            dtype=dtype, device=device, operations=operations
+        )
+
+    def forward(self, audio):
+        mel = self.feature_extractor(audio)
+        x, all_x = self.encoder(mel)
+        return x, all_x
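The new encoder mirrors the Whisper large-v3 front end: audio is averaged to mono, padded or cropped to 30 s at 16 kHz, converted to a 128-bin log-mel spectrogram, and run through a conv stem plus transformer stack that returns the final hidden states together with every intermediate layer. A shape-only sketch of calling it directly, with untrained weights and a reduced layer count so it runs quickly; in ComfyUI the model is normally built through `load_audio_encoder_from_sd` instead:

import torch
import comfy.ops
from comfy.audio_encoders.whisper import WhisperLargeV3

# Shape-only sketch: untrained weights, 2 layers instead of 32. Requires torchaudio.
model = WhisperLargeV3(n_audio_layer=2, dtype=torch.float32, device="cpu",
                       operations=comfy.ops.manual_cast)
audio = torch.randn(1, 2, 16000 * 5)   # (batch, channels, samples) at 16 kHz
x, all_x = model(audio)
print(x.shape)      # (1, 1500, 1280): 30 s of padded audio after the stride-2 conv stem
print(len(all_x))   # one entry per layer input plus the final normalized output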
@@ -86,24 +86,24 @@ class BatchedBrownianTree:
     """A wrapper around torchsde.BrownianTree that enables batches of entropy."""

     def __init__(self, x, t0, t1, seed=None, **kwargs):
-        self.cpu_tree = True
-        if "cpu" in kwargs:
-            self.cpu_tree = kwargs.pop("cpu")
+        self.cpu_tree = kwargs.pop("cpu", True)
         t0, t1, self.sign = self.sort(t0, t1)
-        w0 = kwargs.get('w0', torch.zeros_like(x))
-        if seed is None:
-            seed = torch.randint(0, 2 ** 63 - 1, []).item()
-        self.batched = True
-        try:
-            assert len(seed) == x.shape[0]
-            w0 = w0[0]
-        except TypeError:
-            seed = [seed]
-            self.batched = False
-        if self.cpu_tree:
-            self.trees = [torchsde.BrownianTree(t0.cpu(), w0.cpu(), t1.cpu(), entropy=s, **kwargs) for s in seed]
-        else:
-            self.trees = [torchsde.BrownianTree(t0, w0, t1, entropy=s, **kwargs) for s in seed]
+        w0 = kwargs.pop('w0', None)
+        if w0 is None:
+            w0 = torch.zeros_like(x)
+        self.batched = False
+        if seed is None:
+            seed = (torch.randint(0, 2 ** 63 - 1, ()).item(),)
+        elif isinstance(seed, (tuple, list)):
+            if len(seed) != x.shape[0]:
+                raise ValueError("Passing a list or tuple of seeds to BatchedBrownianTree requires a length matching the batch size.")
+            self.batched = True
+            w0 = w0[0]
+        else:
+            seed = (seed,)
+        if self.cpu_tree:
+            t0, w0, t1 = t0.detach().cpu(), w0.detach().cpu(), t1.detach().cpu()
+        self.trees = tuple(torchsde.BrownianTree(t0, w0, t1, entropy=s, **kwargs) for s in seed)

     @staticmethod
     def sort(a, b):
@@ -111,11 +111,10 @@ class BatchedBrownianTree:

     def __call__(self, t0, t1):
         t0, t1, sign = self.sort(t0, t1)
+        device, dtype = t0.device, t0.dtype
         if self.cpu_tree:
-            w = torch.stack([tree(t0.cpu().float(), t1.cpu().float()).to(t0.dtype).to(t0.device) for tree in self.trees]) * (self.sign * sign)
-        else:
-            w = torch.stack([tree(t0, t1) for tree in self.trees]) * (self.sign * sign)
+            t0, t1 = t0.detach().cpu().float(), t1.detach().cpu().float()
+        w = torch.stack([tree(t0, t1) for tree in self.trees]).to(device=device, dtype=dtype) * (self.sign * sign)

         return w if self.batched else w[0]
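The rewrite above replaces the old try/except seed probing with explicit branches: `seed=None` or a single int builds one shared tree (and the batch dimension is folded back in on return), while a list or tuple of seeds must match the batch size and builds one independent tree per sample. A hedged usage sketch (requires torchsde; the import path assumes the class stays in comfy's k_diffusion sampling module):

import torch
from comfy.k_diffusion.sampling import BatchedBrownianTree

x = torch.zeros(2, 4)                                     # a batch of 2 latents
t0, t1 = torch.tensor(1.0), torch.tensor(0.0)

shared = BatchedBrownianTree(x, t0, t1, seed=42)          # one tree shared by the batch
per_sample = BatchedBrownianTree(x, t0, t1, seed=[1, 2])  # len(seed) must equal x.shape[0]

print(shared(t0, t1).shape)      # torch.Size([2, 4])
print(per_sample(t0, t1).shape)  # torch.Size([2, 4]), but each row comes from its own tree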
@@ -606,6 +606,11 @@ class HunyuanImage21(LatentFormat):

     latent_rgb_factors_bias = [0.0007, -0.0256, -0.0206]

+class HunyuanImage21Refiner(LatentFormat):
+    latent_channels = 64
+    latent_dimensions = 3
+    scale_factor = 1.03682
+
 class Hunyuan3Dv2(LatentFormat):
     latent_channels = 64
     latent_dimensions = 1
@@ -624,3 +629,20 @@ class Hunyuan3Dv2mini(LatentFormat):
 class ACEAudio(LatentFormat):
     latent_channels = 8
     latent_dimensions = 2
+
+class ChromaRadiance(LatentFormat):
+    latent_channels = 3
+
+    def __init__(self):
+        self.latent_rgb_factors = [
+            #   R    G    B
+            [ 1.0, 0.0, 0.0 ],
+            [ 0.0, 1.0, 0.0 ],
+            [ 0.0, 0.0, 1.0 ]
+        ]
+
+    def process_in(self, latent):
+        return latent
+
+    def process_out(self, latent):
+        return latent
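`HunyuanImage21Refiner` only declares its channel count, dimensionality, and scale factor, while `ChromaRadiance` works in a plain 3-channel space: its preview factors are the identity matrix and both conversion hooks are passthroughs. A hedged sketch of the convention these hooks follow, under the assumption that the base `LatentFormat` simply multiplies or divides by `scale_factor` (which is why overriding them with identities disables scaling):

# Hedged sketch of the process_in/process_out convention the new classes rely on.
class LatentFormatSketch:
    scale_factor = 1.0

    def process_in(self, latent):      # VAE latent -> model working space
        return latent * self.scale_factor

    def process_out(self, latent):     # model working space -> VAE latent
        return latent / self.scale_factor

class ChromaRadianceSketch(LatentFormatSketch):
    latent_channels = 3

    def process_in(self, latent):      # identity: the "latent" is already RGB-like
        return latent

    def process_out(self, latent):
        return latent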
@@ -133,6 +133,7 @@ class Attention(nn.Module):
         hidden_states: torch.Tensor,
         encoder_hidden_states: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
+        transformer_options={},
         **cross_attention_kwargs,
     ) -> torch.Tensor:
         return self.processor(
@@ -140,6 +141,7 @@ class Attention(nn.Module):
             hidden_states,
             encoder_hidden_states=encoder_hidden_states,
             attention_mask=attention_mask,
+            transformer_options=transformer_options,
             **cross_attention_kwargs,
         )

@@ -366,6 +368,7 @@ class CustomerAttnProcessor2_0:
         encoder_attention_mask: Optional[torch.FloatTensor] = None,
         rotary_freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
         rotary_freqs_cis_cross: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
+        transformer_options={},
         *args,
         **kwargs,
     ) -> torch.Tensor:
@@ -433,7 +436,7 @@ class CustomerAttnProcessor2_0:

         # the output of sdp = (batch, num_heads, seq_len, head_dim)
         hidden_states = optimized_attention(
-            query, key, value, heads=query.shape[1], mask=attention_mask, skip_reshape=True,
+            query, key, value, heads=query.shape[1], mask=attention_mask, skip_reshape=True, transformer_options=transformer_options,
         ).to(query.dtype)

         # linear proj
@@ -697,6 +700,7 @@ class LinearTransformerBlock(nn.Module):
         rotary_freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
         rotary_freqs_cis_cross: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
         temb: torch.FloatTensor = None,
+        transformer_options={},
     ):

         N = hidden_states.shape[0]
@@ -720,6 +724,7 @@ class LinearTransformerBlock(nn.Module):
                 encoder_attention_mask=encoder_attention_mask,
                 rotary_freqs_cis=rotary_freqs_cis,
                 rotary_freqs_cis_cross=rotary_freqs_cis_cross,
+                transformer_options=transformer_options,
             )
         else:
             attn_output, _ = self.attn(
@@ -729,6 +734,7 @@ class LinearTransformerBlock(nn.Module):
                 encoder_attention_mask=None,
                 rotary_freqs_cis=rotary_freqs_cis,
                 rotary_freqs_cis_cross=None,
+                transformer_options=transformer_options,
             )

         if self.use_adaln_single:
@@ -743,6 +749,7 @@ class LinearTransformerBlock(nn.Module):
                 encoder_attention_mask=encoder_attention_mask,
                 rotary_freqs_cis=rotary_freqs_cis,
                 rotary_freqs_cis_cross=rotary_freqs_cis_cross,
+                transformer_options=transformer_options,
             )
             hidden_states = attn_output + hidden_states

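The change in this file, and in most of the hunks that follow, is the same mechanical plumbing: each attention wrapper gains a `transformer_options={}` keyword and forwards it into `optimized_attention`, so per-sampling options (patches, overrides, debugging flags) reach the innermost attention call instead of stopping at the outer model. A toy sketch of the pattern; `toy_attention` stands in for comfy's `optimized_attention`, and the `"scale_override"` key is purely hypothetical:

import torch
import torch.nn as nn

def toy_attention(q, k, v, heads, transformer_options={}):
    # stand-in for optimized_attention: reads an option instead of hard-coding it
    scale = transformer_options.get("scale_override", q.shape[-1] ** -0.5)
    return torch.softmax(q @ k.transpose(-2, -1) * scale, dim=-1) @ v

class ToyBlock(nn.Module):
    def forward(self, x, transformer_options={}):
        # the dict is passed through untouched, exactly as in the hunks above
        return toy_attention(x, x, x, heads=1, transformer_options=transformer_options)

x = torch.randn(1, 4, 8)
print(ToyBlock()(x, transformer_options={"scale_override": 0.1}).shape)  # (1, 4, 8)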
@@ -314,6 +314,7 @@ class ACEStepTransformer2DModel(nn.Module):
         output_length: int = 0,
         block_controlnet_hidden_states: Optional[Union[List[torch.Tensor], torch.Tensor]] = None,
         controlnet_scale: Union[float, torch.Tensor] = 1.0,
+        transformer_options={},
     ):
         embedded_timestep = self.timestep_embedder(self.time_proj(timestep).to(dtype=hidden_states.dtype))
         temb = self.t_block(embedded_timestep)
@@ -339,6 +340,7 @@ class ACEStepTransformer2DModel(nn.Module):
                 rotary_freqs_cis=rotary_freqs_cis,
                 rotary_freqs_cis_cross=encoder_rotary_freqs_cis,
                 temb=temb,
+                transformer_options=transformer_options,
             )

         output = self.final_layer(hidden_states, embedded_timestep, output_length)
@@ -393,6 +395,7 @@ class ACEStepTransformer2DModel(nn.Module):

         output_length = hidden_states.shape[-1]

+        transformer_options = kwargs.get("transformer_options", {})
         output = self.decode(
             hidden_states=hidden_states,
             attention_mask=attention_mask,
@@ -402,6 +405,7 @@ class ACEStepTransformer2DModel(nn.Module):
             output_length=output_length,
             block_controlnet_hidden_states=block_controlnet_hidden_states,
             controlnet_scale=controlnet_scale,
+            transformer_options=transformer_options,
         )

         return output
@@ -298,7 +298,8 @@ class Attention(nn.Module):
         mask = None,
         context_mask = None,
         rotary_pos_emb = None,
-        causal = None
+        causal = None,
+        transformer_options={},
     ):
         h, kv_h, has_context = self.num_heads, self.kv_heads, context is not None

@@ -363,7 +364,7 @@ class Attention(nn.Module):
             heads_per_kv_head = h // kv_h
             k, v = map(lambda t: t.repeat_interleave(heads_per_kv_head, dim = 1), (k, v))

-        out = optimized_attention(q, k, v, h, skip_reshape=True)
+        out = optimized_attention(q, k, v, h, skip_reshape=True, transformer_options=transformer_options)
         out = self.to_out(out)

         if mask is not None:
@@ -488,7 +489,8 @@ class TransformerBlock(nn.Module):
         global_cond=None,
         mask = None,
         context_mask = None,
-        rotary_pos_emb = None
+        rotary_pos_emb = None,
+        transformer_options={}
     ):
         if self.global_cond_dim is not None and self.global_cond_dim > 0 and global_cond is not None:

@@ -498,12 +500,12 @@ class TransformerBlock(nn.Module):
             residual = x
             x = self.pre_norm(x)
             x = x * (1 + scale_self) + shift_self
-            x = self.self_attn(x, mask = mask, rotary_pos_emb = rotary_pos_emb)
+            x = self.self_attn(x, mask = mask, rotary_pos_emb = rotary_pos_emb, transformer_options=transformer_options)
             x = x * torch.sigmoid(1 - gate_self)
             x = x + residual

             if context is not None:
-                x = x + self.cross_attn(self.cross_attend_norm(x), context = context, context_mask = context_mask)
+                x = x + self.cross_attn(self.cross_attend_norm(x), context = context, context_mask = context_mask, transformer_options=transformer_options)

             if self.conformer is not None:
                 x = x + self.conformer(x)
@@ -517,10 +519,10 @@ class TransformerBlock(nn.Module):
             x = x + residual

         else:
-            x = x + self.self_attn(self.pre_norm(x), mask = mask, rotary_pos_emb = rotary_pos_emb)
+            x = x + self.self_attn(self.pre_norm(x), mask = mask, rotary_pos_emb = rotary_pos_emb, transformer_options=transformer_options)

             if context is not None:
-                x = x + self.cross_attn(self.cross_attend_norm(x), context = context, context_mask = context_mask)
+                x = x + self.cross_attn(self.cross_attend_norm(x), context = context, context_mask = context_mask, transformer_options=transformer_options)

             if self.conformer is not None:
                 x = x + self.conformer(x)
@@ -606,7 +608,8 @@ class ContinuousTransformer(nn.Module):
         return_info = False,
         **kwargs
     ):
-        patches_replace = kwargs.get("transformer_options", {}).get("patches_replace", {})
+        transformer_options = kwargs.get("transformer_options", {})
+        patches_replace = transformer_options.get("patches_replace", {})
         batch, seq, device = *x.shape[:2], x.device
         context = kwargs["context"]

@@ -645,13 +648,13 @@ class ContinuousTransformer(nn.Module):
             if ("double_block", i) in blocks_replace:
                 def block_wrap(args):
                     out = {}
-                    out["img"] = layer(args["img"], rotary_pos_emb=args["pe"], global_cond=args["vec"], context=args["txt"])
+                    out["img"] = layer(args["img"], rotary_pos_emb=args["pe"], global_cond=args["vec"], context=args["txt"], transformer_options=args["transformer_options"])
                     return out

-                out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": global_cond, "pe": rotary_pos_emb}, {"original_block": block_wrap})
+                out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": global_cond, "pe": rotary_pos_emb, "transformer_options": transformer_options}, {"original_block": block_wrap})
                 x = out["img"]
             else:
-                x = layer(x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, context=context)
+                x = layer(x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, context=context, transformer_options=transformer_options)
                 # x = checkpoint(layer, x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, **kwargs)

             if return_info:
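Beyond the attention plumbing, `ContinuousTransformer` now keeps the full `transformer_options` dict and includes it in the argument dict handed to `("double_block", i)` replacement patches, so a patch can read those options as well. A hedged sketch of what such a replacement callback sees after this change; registering it on a model is ComfyUI-specific and not shown here:

# Sketch of a block-replacement callback matching the block_wrap convention above.
def my_double_block_patch(args, extra):
    opts = args["transformer_options"]        # newly available to the patch
    out = extra["original_block"](args)       # run the wrapped layer unchanged
    if opts.get("debug_blocks"):              # hypothetical option, for illustration only
        print("double block output:", out["img"].shape)
    return out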
@@ -85,7 +85,7 @@ class SingleAttention(nn.Module):
         )

     #@torch.compile()
-    def forward(self, c):
+    def forward(self, c, transformer_options={}):

         bsz, seqlen1, _ = c.shape

@@ -95,7 +95,7 @@ class SingleAttention(nn.Module):
         v = v.view(bsz, seqlen1, self.n_heads, self.head_dim)
         q, k = self.q_norm1(q), self.k_norm1(k)

-        output = optimized_attention(q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3), self.n_heads, skip_reshape=True)
+        output = optimized_attention(q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3), self.n_heads, skip_reshape=True, transformer_options=transformer_options)
         c = self.w1o(output)
         return c

@@ -144,7 +144,7 @@ class DoubleAttention(nn.Module):


     #@torch.compile()
-    def forward(self, c, x):
+    def forward(self, c, x, transformer_options={}):

         bsz, seqlen1, _ = c.shape
         bsz, seqlen2, _ = x.shape
@@ -168,7 +168,7 @@ class DoubleAttention(nn.Module):
             torch.cat([cv, xv], dim=1),
         )

-        output = optimized_attention(q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3), self.n_heads, skip_reshape=True)
+        output = optimized_attention(q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3), self.n_heads, skip_reshape=True, transformer_options=transformer_options)

         c, x = output.split([seqlen1, seqlen2], dim=1)
         c = self.w1o(c)
@@ -207,7 +207,7 @@ class MMDiTBlock(nn.Module):
         self.is_last = is_last

     #@torch.compile()
-    def forward(self, c, x, global_cond, **kwargs):
+    def forward(self, c, x, global_cond, transformer_options={}, **kwargs):

         cres, xres = c, x

@@ -225,7 +225,7 @@ class MMDiTBlock(nn.Module):
         x = modulate(self.normX1(x), xshift_msa, xscale_msa)

         # attention
-        c, x = self.attn(c, x)
+        c, x = self.attn(c, x, transformer_options=transformer_options)


         c = self.normC2(cres + cgate_msa.unsqueeze(1) * c)
@@ -255,13 +255,13 @@ class DiTBlock(nn.Module):
         self.mlp = MLP(dim, hidden_dim=dim * 4, dtype=dtype, device=device, operations=operations)

     #@torch.compile()
-    def forward(self, cx, global_cond, **kwargs):
+    def forward(self, cx, global_cond, transformer_options={}, **kwargs):
         cxres = cx
         shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.modCX(
             global_cond
         ).chunk(6, dim=1)
         cx = modulate(self.norm1(cx), shift_msa, scale_msa)
-        cx = self.attn(cx)
+        cx = self.attn(cx, transformer_options=transformer_options)
         cx = self.norm2(cxres + gate_msa.unsqueeze(1) * cx)
         mlpout = self.mlp(modulate(cx, shift_mlp, scale_mlp))
         cx = gate_mlp.unsqueeze(1) * mlpout
@@ -473,13 +473,14 @@ class MMDiT(nn.Module):
                     out = {}
                     out["txt"], out["img"] = layer(args["txt"],
                                                    args["img"],
-                                                   args["vec"])
+                                                   args["vec"],
+                                                   transformer_options=args["transformer_options"])
                     return out
-                out = blocks_replace[("double_block", i)]({"img": x, "txt": c, "vec": global_cond}, {"original_block": block_wrap})
+                out = blocks_replace[("double_block", i)]({"img": x, "txt": c, "vec": global_cond, "transformer_options": transformer_options}, {"original_block": block_wrap})
                 c = out["txt"]
                 x = out["img"]
             else:
-                c, x = layer(c, x, global_cond, **kwargs)
+                c, x = layer(c, x, global_cond, transformer_options=transformer_options, **kwargs)

         if len(self.single_layers) > 0:
             c_len = c.size(1)
@@ -488,13 +489,13 @@ class MMDiT(nn.Module):
                 if ("single_block", i) in blocks_replace:
                     def block_wrap(args):
                         out = {}
-                        out["img"] = layer(args["img"], args["vec"])
+                        out["img"] = layer(args["img"], args["vec"], transformer_options=args["transformer_options"])
                         return out

-                    out = blocks_replace[("single_block", i)]({"img": cx, "vec": global_cond}, {"original_block": block_wrap})
+                    out = blocks_replace[("single_block", i)]({"img": cx, "vec": global_cond, "transformer_options": transformer_options}, {"original_block": block_wrap})
                     cx = out["img"]
                 else:
-                    cx = layer(cx, global_cond, **kwargs)
+                    cx = layer(cx, global_cond, transformer_options=transformer_options, **kwargs)

             x = cx[:, c_len:]

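In the `SingleAttention`/`DoubleAttention` hunks above, `optimized_attention` is called with `skip_reshape=True` on tensors permuted from `(batch, seq, heads, head_dim)` to `(batch, heads, seq, head_dim)`; the new keyword simply rides along. A layout-only sketch using plain scaled dot-product attention in place of comfy's helper (which, unlike raw SDPA here, merges the heads back into the last dimension before returning):

import torch
import torch.nn.functional as F

bsz, seqlen, n_heads, head_dim = 2, 16, 8, 64
q = torch.randn(bsz, seqlen, n_heads, head_dim)
k, v = torch.randn_like(q), torch.randn_like(q)

# same permute as the diff: (batch, seq, heads, head_dim) -> (batch, heads, seq, head_dim)
out = F.scaled_dot_product_attention(
    q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3))
print(out.shape)  # torch.Size([2, 8, 16, 64])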
@@ -32,12 +32,12 @@ class OptimizedAttention(nn.Module):

         self.out_proj = operations.Linear(c, c, bias=True, dtype=dtype, device=device)

-    def forward(self, q, k, v):
+    def forward(self, q, k, v, transformer_options={}):
         q = self.to_q(q)
         k = self.to_k(k)
         v = self.to_v(v)

-        out = optimized_attention(q, k, v, self.heads)
+        out = optimized_attention(q, k, v, self.heads, transformer_options=transformer_options)

         return self.out_proj(out)

@@ -47,13 +47,13 @@ class Attention2D(nn.Module):
         self.attn = OptimizedAttention(c, nhead, dtype=dtype, device=device, operations=operations)
         # self.attn = nn.MultiheadAttention(c, nhead, dropout=dropout, bias=True, batch_first=True, dtype=dtype, device=device)

-    def forward(self, x, kv, self_attn=False):
+    def forward(self, x, kv, self_attn=False, transformer_options={}):
         orig_shape = x.shape
         x = x.view(x.size(0), x.size(1), -1).permute(0, 2, 1)  # Bx4xHxW -> Bx(HxW)x4
         if self_attn:
             kv = torch.cat([x, kv], dim=1)
         # x = self.attn(x, kv, kv, need_weights=False)[0]
-        x = self.attn(x, kv, kv)
+        x = self.attn(x, kv, kv, transformer_options=transformer_options)
         x = x.permute(0, 2, 1).view(*orig_shape)
         return x

@@ -114,9 +114,9 @@ class AttnBlock(nn.Module):
             operations.Linear(c_cond, c, dtype=dtype, device=device)
         )

-    def forward(self, x, kv):
+    def forward(self, x, kv, transformer_options={}):
         kv = self.kv_mapper(kv)
-        x = x + self.attention(self.norm(x), kv, self_attn=self.self_attn)
+        x = x + self.attention(self.norm(x), kv, self_attn=self.self_attn, transformer_options=transformer_options)
         return x


@@ -173,7 +173,7 @@ class StageB(nn.Module):
         clip = self.clip_norm(clip)
         return clip

-    def _down_encode(self, x, r_embed, clip):
+    def _down_encode(self, x, r_embed, clip, transformer_options={}):
         level_outputs = []
         block_group = zip(self.down_blocks, self.down_downscalers, self.down_repeat_mappers)
         for down_block, downscaler, repmap in block_group:
@@ -187,7 +187,7 @@ class StageB(nn.Module):
                     elif isinstance(block, AttnBlock) or (
                             hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
                                                                                   AttnBlock)):
-                        x = block(x, clip)
+                        x = block(x, clip, transformer_options=transformer_options)
                     elif isinstance(block, TimestepBlock) or (
                             hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
                                                                                   TimestepBlock)):
@@ -199,7 +199,7 @@ class StageB(nn.Module):
             level_outputs.insert(0, x)
         return level_outputs

-    def _up_decode(self, level_outputs, r_embed, clip):
+    def _up_decode(self, level_outputs, r_embed, clip, transformer_options={}):
         x = level_outputs[0]
         block_group = zip(self.up_blocks, self.up_upscalers, self.up_repeat_mappers)
         for i, (up_block, upscaler, repmap) in enumerate(block_group):
@@ -216,7 +216,7 @@ class StageB(nn.Module):
                     elif isinstance(block, AttnBlock) or (
                             hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
                                                                                   AttnBlock)):
-                        x = block(x, clip)
+                        x = block(x, clip, transformer_options=transformer_options)
                     elif isinstance(block, TimestepBlock) or (
                             hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
                                                                                   TimestepBlock)):
@@ -228,7 +228,7 @@ class StageB(nn.Module):
             x = upscaler(x)
         return x

-    def forward(self, x, r, effnet, clip, pixels=None, **kwargs):
+    def forward(self, x, r, effnet, clip, pixels=None, transformer_options={}, **kwargs):
         if pixels is None:
             pixels = x.new_zeros(x.size(0), 3, 8, 8)

@@ -245,8 +245,8 @@ class StageB(nn.Module):
                                nn.functional.interpolate(effnet, size=x.shape[-2:], mode='bilinear', align_corners=True))
         x = x + nn.functional.interpolate(self.pixels_mapper(pixels), size=x.shape[-2:], mode='bilinear',
                                           align_corners=True)
-        level_outputs = self._down_encode(x, r_embed, clip)
-        x = self._up_decode(level_outputs, r_embed, clip)
+        level_outputs = self._down_encode(x, r_embed, clip, transformer_options=transformer_options)
+        x = self._up_decode(level_outputs, r_embed, clip, transformer_options=transformer_options)
         return self.clf(x)

     def update_weights_ema(self, src_model, beta=0.999):
@@ -182,7 +182,7 @@ class StageC(nn.Module):
         clip = self.clip_norm(clip)
         return clip

-    def _down_encode(self, x, r_embed, clip, cnet=None):
+    def _down_encode(self, x, r_embed, clip, cnet=None, transformer_options={}):
         level_outputs = []
         block_group = zip(self.down_blocks, self.down_downscalers, self.down_repeat_mappers)
         for down_block, downscaler, repmap in block_group:
@@ -201,7 +201,7 @@ class StageC(nn.Module):
                     elif isinstance(block, AttnBlock) or (
                             hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
                                                                                   AttnBlock)):
-                        x = block(x, clip)
+                        x = block(x, clip, transformer_options=transformer_options)
                     elif isinstance(block, TimestepBlock) or (
                             hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
                                                                                   TimestepBlock)):
@@ -213,7 +213,7 @@ class StageC(nn.Module):
             level_outputs.insert(0, x)
         return level_outputs

-    def _up_decode(self, level_outputs, r_embed, clip, cnet=None):
+    def _up_decode(self, level_outputs, r_embed, clip, cnet=None, transformer_options={}):
         x = level_outputs[0]
         block_group = zip(self.up_blocks, self.up_upscalers, self.up_repeat_mappers)
         for i, (up_block, upscaler, repmap) in enumerate(block_group):
@@ -235,7 +235,7 @@ class StageC(nn.Module):
                     elif isinstance(block, AttnBlock) or (
                             hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
                                                                                   AttnBlock)):
-                        x = block(x, clip)
+                        x = block(x, clip, transformer_options=transformer_options)
                     elif isinstance(block, TimestepBlock) or (
                             hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
                                                                                   TimestepBlock)):
@@ -247,7 +247,7 @@ class StageC(nn.Module):
             x = upscaler(x)
         return x

-    def forward(self, x, r, clip_text, clip_text_pooled, clip_img, control=None, **kwargs):
+    def forward(self, x, r, clip_text, clip_text_pooled, clip_img, control=None, transformer_options={}, **kwargs):
         # Process the conditioning embeddings
         r_embed = self.gen_r_embedding(r).to(dtype=x.dtype)
         for c in self.t_conds:
@@ -262,8 +262,8 @@ class StageC(nn.Module):

         # Model Blocks
         x = self.embedding(x)
-        level_outputs = self._down_encode(x, r_embed, clip, cnet)
-        x = self._up_decode(level_outputs, r_embed, clip, cnet)
+        level_outputs = self._down_encode(x, r_embed, clip, cnet, transformer_options=transformer_options)
+        x = self._up_decode(level_outputs, r_embed, clip, cnet, transformer_options=transformer_options)
         return self.clf(x)

     def update_weights_ema(self, src_model, beta=0.999):
@@ -76,7 +76,7 @@ class DoubleStreamBlock(nn.Module):
         )
         self.flipped_img_txt = flipped_img_txt

-    def forward(self, img: Tensor, txt: Tensor, pe: Tensor, vec: Tensor, attn_mask=None):
+    def forward(self, img: Tensor, txt: Tensor, pe: Tensor, vec: Tensor, attn_mask=None, transformer_options={}):
         (img_mod1, img_mod2), (txt_mod1, txt_mod2) = vec

         # prepare image for attention
@@ -95,7 +95,7 @@ class DoubleStreamBlock(nn.Module):
         attn = attention(torch.cat((txt_q, img_q), dim=2),
                          torch.cat((txt_k, img_k), dim=2),
                          torch.cat((txt_v, img_v), dim=2),
-                         pe=pe, mask=attn_mask)
+                         pe=pe, mask=attn_mask, transformer_options=transformer_options)

         txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]

@@ -148,7 +148,7 @@ class SingleStreamBlock(nn.Module):

         self.mlp_act = nn.GELU(approximate="tanh")

-    def forward(self, x: Tensor, pe: Tensor, vec: Tensor, attn_mask=None) -> Tensor:
+    def forward(self, x: Tensor, pe: Tensor, vec: Tensor, attn_mask=None, transformer_options={}) -> Tensor:
         mod = vec
         x_mod = torch.addcmul(mod.shift, 1 + mod.scale, self.pre_norm(x))
         qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
@@ -157,7 +157,7 @@ class SingleStreamBlock(nn.Module):
         q, k = self.norm(q, k, v)

         # compute attention
-        attn = attention(q, k, v, pe=pe, mask=attn_mask)
+        attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options)
         # compute activation in mlp stream, cat again and run second linear layer
         output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
         x.addcmul_(mod.gate, output)
@ -151,8 +151,6 @@ class Chroma(nn.Module):
attn_mask: Tensor = None,
) -> Tensor:
patches_replace = transformer_options.get("patches_replace", {})
- if img.ndim != 3 or txt.ndim != 3:
- raise ValueError("Input img and txt tensors must have 3 dimensions.")

# running on sequences img
img = self.img_in(img)
@ -193,14 +191,16 @@ class Chroma(nn.Module):
txt=args["txt"],
vec=args["vec"],
pe=args["pe"],
- attn_mask=args.get("attn_mask"))
+ attn_mask=args.get("attn_mask"),
+ transformer_options=args.get("transformer_options"))
return out

out = blocks_replace[("double_block", i)]({"img": img,
"txt": txt,
"vec": double_mod,
"pe": pe,
- "attn_mask": attn_mask},
+ "attn_mask": attn_mask,
+ "transformer_options": transformer_options},
{"original_block": block_wrap})
txt = out["txt"]
img = out["img"]
@ -209,7 +209,8 @@ class Chroma(nn.Module):
txt=txt,
vec=double_mod,
pe=pe,
- attn_mask=attn_mask)
+ attn_mask=attn_mask,
+ transformer_options=transformer_options)

if control is not None: # Controlnet
control_i = control.get("input")
@ -229,17 +230,19 @@ class Chroma(nn.Module):
out["img"] = block(args["img"],
vec=args["vec"],
pe=args["pe"],
- attn_mask=args.get("attn_mask"))
+ attn_mask=args.get("attn_mask"),
+ transformer_options=args.get("transformer_options"))
return out

out = blocks_replace[("single_block", i)]({"img": img,
"vec": single_mod,
"pe": pe,
- "attn_mask": attn_mask},
+ "attn_mask": attn_mask,
+ "transformer_options": transformer_options},
{"original_block": block_wrap})
img = out["img"]
else:
- img = block(img, vec=single_mod, pe=pe, attn_mask=attn_mask)
+ img = block(img, vec=single_mod, pe=pe, attn_mask=attn_mask, transformer_options=transformer_options)

if control is not None: # Controlnet
control_o = control.get("output")
@ -249,6 +252,7 @@ class Chroma(nn.Module):
img[:, txt.shape[1] :, ...] += add

img = img[:, txt.shape[1] :, ...]
+ if hasattr(self, "final_layer"):
final_mod = self.get_modulations(mod_vectors, "final")
img = self.final_layer(img, vec=final_mod) # (N, T, patch_size ** 2 * out_channels)
return img
@ -266,6 +270,9 @@ class Chroma(nn.Module):

img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=self.patch_size, pw=self.patch_size)

+ if img.ndim != 3 or context.ndim != 3:
+ raise ValueError("Input img and txt tensors must have 3 dimensions.")
+
h_len = ((h + (self.patch_size // 2)) // self.patch_size)
w_len = ((w + (self.patch_size // 2)) // self.patch_size)
img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
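For reference, the hunks above thread transformer_options into each blocks_replace callback through the args dict. A minimal sketch of such a callback follows; the callback name, the my_img_scale key, and the commented-out registration path are illustrative assumptions, not part of this diff:

def scaled_double_block_patch(args, extra):
    # "transformer_options" now arrives alongside img/txt/vec/pe/attn_mask.
    opts = args.get("transformer_options", {}) or {}
    out = extra["original_block"](args)  # run the unmodified double block
    # Illustrative use: scale the image stream by a user-supplied factor.
    out["img"] = out["img"] * opts.get("my_img_scale", 1.0)
    return out

# Hypothetical registration (exact key layout assumed):
# transformer_options["patches_replace"][("double_block", 0)] = scaled_double_block_patch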
206  comfy/ldm/chroma_radiance/layers.py  Normal file
@ -0,0 +1,206 @@
# Adapted from https://github.com/lodestone-rock/flow
from functools import lru_cache

import torch
from torch import nn

from comfy.ldm.flux.layers import RMSNorm


class NerfEmbedder(nn.Module):
    """
    An embedder module that combines input features with a 2D positional
    encoding that mimics the Discrete Cosine Transform (DCT).

    This module takes an input tensor of shape (B, P^2, C), where P is the
    patch size, and enriches it with positional information before projecting
    it to a new hidden size.
    """
    def __init__(
        self,
        in_channels: int,
        hidden_size_input: int,
        max_freqs: int,
        dtype=None,
        device=None,
        operations=None,
    ):
        """
        Initializes the NerfEmbedder.

        Args:
            in_channels (int): The number of channels in the input tensor.
            hidden_size_input (int): The desired dimension of the output embedding.
            max_freqs (int): The number of frequency components to use for both
                the x and y dimensions of the positional encoding.
                The total number of positional features will be max_freqs^2.
        """
        super().__init__()
        self.dtype = dtype
        self.max_freqs = max_freqs
        self.hidden_size_input = hidden_size_input

        # A linear layer to project the concatenated input features and
        # positional encodings to the final output dimension.
        self.embedder = nn.Sequential(
            operations.Linear(in_channels + max_freqs**2, hidden_size_input, dtype=dtype, device=device)
        )

    @lru_cache(maxsize=4)
    def fetch_pos(self, patch_size: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
        """
        Generates and caches 2D DCT-like positional embeddings for a given patch size.

        The LRU cache is a performance optimization that avoids recomputing the
        same positional grid on every forward pass.

        Args:
            patch_size (int): The side length of the square input patch.
            device: The torch device to create the tensors on.
            dtype: The torch dtype for the tensors.

        Returns:
            A tensor of shape (1, patch_size^2, max_freqs^2) containing the
            positional embeddings.
        """
        # Create normalized 1D coordinate grids from 0 to 1.
        pos_x = torch.linspace(0, 1, patch_size, device=device, dtype=dtype)
        pos_y = torch.linspace(0, 1, patch_size, device=device, dtype=dtype)

        # Create a 2D meshgrid of coordinates.
        pos_y, pos_x = torch.meshgrid(pos_y, pos_x, indexing="ij")

        # Reshape positions to be broadcastable with frequencies.
        # Shape becomes (patch_size^2, 1, 1).
        pos_x = pos_x.reshape(-1, 1, 1)
        pos_y = pos_y.reshape(-1, 1, 1)

        # Create a 1D tensor of frequency values from 0 to max_freqs-1.
        freqs = torch.linspace(0, self.max_freqs - 1, self.max_freqs, dtype=dtype, device=device)

        # Reshape frequencies to be broadcastable for creating 2D basis functions.
        # freqs_x shape: (1, max_freqs, 1)
        # freqs_y shape: (1, 1, max_freqs)
        freqs_x = freqs[None, :, None]
        freqs_y = freqs[None, None, :]

        # A custom weighting coefficient, not part of standard DCT.
        # This seems to down-weight the contribution of higher-frequency interactions.
        coeffs = (1 + freqs_x * freqs_y) ** -1

        # Calculate the 1D cosine basis functions for x and y coordinates.
        # This is the core of the DCT formulation.
        dct_x = torch.cos(pos_x * freqs_x * torch.pi)
        dct_y = torch.cos(pos_y * freqs_y * torch.pi)

        # Combine the 1D basis functions to create 2D basis functions by element-wise
        # multiplication, and apply the custom coefficients. Broadcasting handles the
        # combination of all (pos_x, freqs_x) with all (pos_y, freqs_y).
        # The result is flattened into a feature vector for each position.
        dct = (dct_x * dct_y * coeffs).view(1, -1, self.max_freqs ** 2)

        return dct

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """
        Forward pass for the embedder.

        Args:
            inputs (Tensor): The input tensor of shape (B, P^2, C).

        Returns:
            Tensor: The output tensor of shape (B, P^2, hidden_size_input).
        """
        # Get the batch size, number of pixels, and number of channels.
        B, P2, C = inputs.shape

        # Infer the patch side length from the number of pixels (P^2).
        patch_size = int(P2 ** 0.5)

        input_dtype = inputs.dtype
        inputs = inputs.to(dtype=self.dtype)

        # Fetch the pre-computed or cached positional embeddings.
        dct = self.fetch_pos(patch_size, inputs.device, self.dtype)

        # Repeat the positional embeddings for each item in the batch.
        dct = dct.repeat(B, 1, 1)

        # Concatenate the original input features with the positional embeddings
        # along the feature dimension.
        inputs = torch.cat((inputs, dct), dim=-1)

        # Project the combined tensor to the target hidden size.
        return self.embedder(inputs).to(dtype=input_dtype)

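A quick illustration of the DCT-like table that fetch_pos builds above, shapes only, with assumed toy sizes patch_size=4 and max_freqs=8 (this is a standalone sketch, not part of the file):

import torch

patch_size, max_freqs = 4, 8  # assumed toy sizes
pos = torch.linspace(0, 1, patch_size)
pos_y, pos_x = torch.meshgrid(pos, pos, indexing="ij")
pos_x, pos_y = pos_x.reshape(-1, 1, 1), pos_y.reshape(-1, 1, 1)
freqs = torch.arange(max_freqs, dtype=torch.float32)
fx, fy = freqs[None, :, None], freqs[None, None, :]
coeffs = (1 + fx * fy) ** -1
dct = torch.cos(pos_x * fx * torch.pi) * torch.cos(pos_y * fy * torch.pi) * coeffs
print(dct.view(1, -1, max_freqs ** 2).shape)  # torch.Size([1, 16, 64]) == (1, P^2, max_freqs^2)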
class NerfGLUBlock(nn.Module):
    """
    A NerfBlock using a Gated Linear Unit (GLU) like MLP.
    """
    def __init__(self, hidden_size_s: int, hidden_size_x: int, mlp_ratio, dtype=None, device=None, operations=None):
        super().__init__()
        # The total number of parameters for the MLP is increased to accommodate
        # the gate, value, and output projection matrices.
        # We now need to generate parameters for 3 matrices.
        total_params = 3 * hidden_size_x**2 * mlp_ratio
        self.param_generator = operations.Linear(hidden_size_s, total_params, dtype=dtype, device=device)
        self.norm = RMSNorm(hidden_size_x, dtype=dtype, device=device, operations=operations)
        self.mlp_ratio = mlp_ratio


    def forward(self, x: torch.Tensor, s: torch.Tensor) -> torch.Tensor:
        batch_size, num_x, hidden_size_x = x.shape
        mlp_params = self.param_generator(s)

        # Split the generated parameters into three parts for the gate, value, and output projection.
        fc1_gate_params, fc1_value_params, fc2_params = mlp_params.chunk(3, dim=-1)

        # Reshape the parameters into matrices for batch matrix multiplication.
        fc1_gate = fc1_gate_params.view(batch_size, hidden_size_x, hidden_size_x * self.mlp_ratio)
        fc1_value = fc1_value_params.view(batch_size, hidden_size_x, hidden_size_x * self.mlp_ratio)
        fc2 = fc2_params.view(batch_size, hidden_size_x * self.mlp_ratio, hidden_size_x)

        # Normalize the generated weight matrices as in the original implementation.
        fc1_gate = torch.nn.functional.normalize(fc1_gate, dim=-2)
        fc1_value = torch.nn.functional.normalize(fc1_value, dim=-2)
        fc2 = torch.nn.functional.normalize(fc2, dim=-2)

        res_x = x
        x = self.norm(x)

        # Apply the final output projection.
        x = torch.bmm(torch.nn.functional.silu(torch.bmm(x, fc1_gate)) * torch.bmm(x, fc1_value), fc2)

        return x + res_x


class NerfFinalLayer(nn.Module):
    def __init__(self, hidden_size, out_channels, dtype=None, device=None, operations=None):
        super().__init__()
        self.norm = RMSNorm(hidden_size, dtype=dtype, device=device, operations=operations)
        self.linear = operations.Linear(hidden_size, out_channels, dtype=dtype, device=device)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # RMSNorm normalizes over the last dimension, but our channel dim (C) is at dim=1.
        # So we temporarily move the channel dimension to the end for the norm operation.
        return self.linear(self.norm(x.movedim(1, -1))).movedim(-1, 1)


class NerfFinalLayerConv(nn.Module):
    def __init__(self, hidden_size: int, out_channels: int, dtype=None, device=None, operations=None):
        super().__init__()
        self.norm = RMSNorm(hidden_size, dtype=dtype, device=device, operations=operations)
        self.conv = operations.Conv2d(
            in_channels=hidden_size,
            out_channels=out_channels,
            kernel_size=3,
            padding=1,
            dtype=dtype,
            device=device,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # RMSNorm normalizes over the last dimension, but our channel dim (C) is at dim=1.
        # So we temporarily move the channel dimension to the end for the norm operation.
        return self.conv(self.norm(x.movedim(1, -1)).movedim(-1, 1))
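The NerfGLUBlock defined above is effectively a small hyper-network: param_generator emits per-sample weight matrices that are then applied with batched matmuls. A minimal sketch of that pattern with toy sizes (all names and sizes below are illustrative, not taken from the file):

import torch
import torch.nn.functional as F

B, N, D, R = 2, 16, 8, 2            # batch, tokens per patch, hidden_size_x, mlp_ratio
s = torch.randn(B, 3 * D * D * R)   # stand-in for the output of param_generator(s)
x = torch.randn(B, N, D)

gate_w, value_w, out_w = s.chunk(3, dim=-1)
gate_w = F.normalize(gate_w.view(B, D, D * R), dim=-2)
value_w = F.normalize(value_w.view(B, D, D * R), dim=-2)
out_w = F.normalize(out_w.view(B, D * R, D), dim=-2)

# Gated MLP with per-sample weights: silu(x @ W_gate) * (x @ W_value) @ W_out, plus residual.
out = torch.bmm(F.silu(torch.bmm(x, gate_w)) * torch.bmm(x, value_w), out_w) + x
print(out.shape)  # torch.Size([2, 16, 8])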
329  comfy/ldm/chroma_radiance/model.py  Normal file
@ -0,0 +1,329 @@
# Credits:
# Original Flux code can be found on: https://github.com/black-forest-labs/flux
# Chroma Radiance adaption referenced from https://github.com/lodestone-rock/flow

from dataclasses import dataclass
from typing import Optional

import torch
from torch import Tensor, nn
from einops import repeat
import comfy.ldm.common_dit

from comfy.ldm.flux.layers import EmbedND

from comfy.ldm.chroma.model import Chroma, ChromaParams
from comfy.ldm.chroma.layers import (
    DoubleStreamBlock,
    SingleStreamBlock,
    Approximator,
)
from .layers import (
    NerfEmbedder,
    NerfGLUBlock,
    NerfFinalLayer,
    NerfFinalLayerConv,
)


@dataclass
class ChromaRadianceParams(ChromaParams):
    patch_size: int
    nerf_hidden_size: int
    nerf_mlp_ratio: int
    nerf_depth: int
    nerf_max_freqs: int
    # Setting nerf_tile_size to 0 disables tiling.
    nerf_tile_size: int
    # Currently one of linear (legacy) or conv.
    nerf_final_head_type: str
    # None means use the same dtype as the model.
    nerf_embedder_dtype: Optional[torch.dtype]


class ChromaRadiance(Chroma):
    """
    Transformer model for flow matching on sequences.
    """

    def __init__(self, image_model=None, final_layer=True, dtype=None, device=None, operations=None, **kwargs):
        if operations is None:
            raise RuntimeError("Attempt to create ChromaRadiance object without setting operations")
        nn.Module.__init__(self)
        self.dtype = dtype
        params = ChromaRadianceParams(**kwargs)
        self.params = params
        self.patch_size = params.patch_size
        self.in_channels = params.in_channels
        self.out_channels = params.out_channels
        if params.hidden_size % params.num_heads != 0:
            raise ValueError(
                f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
            )
        pe_dim = params.hidden_size // params.num_heads
        if sum(params.axes_dim) != pe_dim:
            raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
        self.hidden_size = params.hidden_size
        self.num_heads = params.num_heads
        self.in_dim = params.in_dim
        self.out_dim = params.out_dim
        self.hidden_dim = params.hidden_dim
        self.n_layers = params.n_layers
        self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
        self.img_in_patch = operations.Conv2d(
            params.in_channels,
            params.hidden_size,
            kernel_size=params.patch_size,
            stride=params.patch_size,
            bias=True,
            dtype=dtype,
            device=device,
        )
        self.txt_in = operations.Linear(params.context_in_dim, self.hidden_size, dtype=dtype, device=device)
        # set as nn identity for now, will overwrite it later.
        self.distilled_guidance_layer = Approximator(
            in_dim=self.in_dim,
            hidden_dim=self.hidden_dim,
            out_dim=self.out_dim,
            n_layers=self.n_layers,
            dtype=dtype, device=device, operations=operations
        )

        self.double_blocks = nn.ModuleList(
            [
                DoubleStreamBlock(
                    self.hidden_size,
                    self.num_heads,
                    mlp_ratio=params.mlp_ratio,
                    qkv_bias=params.qkv_bias,
                    dtype=dtype, device=device, operations=operations
                )
                for _ in range(params.depth)
            ]
        )

        self.single_blocks = nn.ModuleList(
            [
                SingleStreamBlock(
                    self.hidden_size,
                    self.num_heads,
                    mlp_ratio=params.mlp_ratio,
                    dtype=dtype, device=device, operations=operations,
                )
                for _ in range(params.depth_single_blocks)
            ]
        )

        # pixel channel concat with DCT
        self.nerf_image_embedder = NerfEmbedder(
            in_channels=params.in_channels,
            hidden_size_input=params.nerf_hidden_size,
            max_freqs=params.nerf_max_freqs,
            dtype=params.nerf_embedder_dtype or dtype,
            device=device,
            operations=operations,
        )

        self.nerf_blocks = nn.ModuleList([
            NerfGLUBlock(
                hidden_size_s=params.hidden_size,
                hidden_size_x=params.nerf_hidden_size,
                mlp_ratio=params.nerf_mlp_ratio,
                dtype=dtype,
                device=device,
                operations=operations,
            ) for _ in range(params.nerf_depth)
        ])

        if params.nerf_final_head_type == "linear":
            self.nerf_final_layer = NerfFinalLayer(
                params.nerf_hidden_size,
                out_channels=params.in_channels,
                dtype=dtype,
                device=device,
                operations=operations,
            )
        elif params.nerf_final_head_type == "conv":
            self.nerf_final_layer_conv = NerfFinalLayerConv(
                params.nerf_hidden_size,
                out_channels=params.in_channels,
                dtype=dtype,
                device=device,
                operations=operations,
            )
        else:
            errstr = f"Unsupported nerf_final_head_type {params.nerf_final_head_type}"
            raise ValueError(errstr)

        self.skip_mmdit = []
        self.skip_dit = []
        self.lite = False

    @property
    def _nerf_final_layer(self) -> nn.Module:
        if self.params.nerf_final_head_type == "linear":
            return self.nerf_final_layer
        if self.params.nerf_final_head_type == "conv":
            return self.nerf_final_layer_conv
        # Impossible to get here as we raise an error on unexpected types on initialization.
        raise NotImplementedError

    def img_in(self, img: Tensor) -> Tensor:
        img = self.img_in_patch(img) # -> [B, Hidden, H/P, W/P]
        # flatten into a sequence for the transformer.
        return img.flatten(2).transpose(1, 2) # -> [B, NumPatches, Hidden]

    def forward_nerf(
        self,
        img_orig: Tensor,
        img_out: Tensor,
        params: ChromaRadianceParams,
    ) -> Tensor:
        B, C, H, W = img_orig.shape
        num_patches = img_out.shape[1]
        patch_size = params.patch_size

        # Store the raw pixel values of each patch for the NeRF head later.
        # unfold creates patches: [B, C * P * P, NumPatches]
        nerf_pixels = nn.functional.unfold(img_orig, kernel_size=patch_size, stride=patch_size)
        nerf_pixels = nerf_pixels.transpose(1, 2) # -> [B, NumPatches, C * P * P]

        if params.nerf_tile_size > 0 and num_patches > params.nerf_tile_size:
            # Enable tiling if nerf_tile_size isn't 0 and we actually have more patches than
            # the tile size.
            img_dct = self.forward_tiled_nerf(img_out, nerf_pixels, B, C, num_patches, patch_size, params)
        else:
            # Reshape for per-patch processing
            nerf_hidden = img_out.reshape(B * num_patches, params.hidden_size)
            nerf_pixels = nerf_pixels.reshape(B * num_patches, C, patch_size**2).transpose(1, 2)

            # Get DCT-encoded pixel embeddings [pixel-dct]
            img_dct = self.nerf_image_embedder(nerf_pixels)

            # Pass through the dynamic MLP blocks (the NeRF)
            for block in self.nerf_blocks:
                img_dct = block(img_dct, nerf_hidden)

        # Reassemble the patches into the final image.
        img_dct = img_dct.transpose(1, 2) # -> [B*NumPatches, C, P*P]
        # Reshape to combine with batch dimension for fold
        img_dct = img_dct.reshape(B, num_patches, -1) # -> [B, NumPatches, C*P*P]
        img_dct = img_dct.transpose(1, 2) # -> [B, C*P*P, NumPatches]
        img_dct = nn.functional.fold(
            img_dct,
            output_size=(H, W),
            kernel_size=patch_size,
            stride=patch_size,
        )
        return self._nerf_final_layer(img_dct)

    def forward_tiled_nerf(
        self,
        nerf_hidden: Tensor,
        nerf_pixels: Tensor,
        batch: int,
        channels: int,
        num_patches: int,
        patch_size: int,
        params: ChromaRadianceParams,
    ) -> Tensor:
        """
        Processes the NeRF head in tiles to save memory.
        nerf_hidden has shape [B, L, D]
        nerf_pixels has shape [B, L, C * P * P]
        """
        tile_size = params.nerf_tile_size
        output_tiles = []
        # Iterate over the patches in tiles. The dimension L (num_patches) is at index 1.
        for i in range(0, num_patches, tile_size):
            end = min(i + tile_size, num_patches)

            # Slice the current tile from the input tensors
            nerf_hidden_tile = nerf_hidden[:, i:end, :]
            nerf_pixels_tile = nerf_pixels[:, i:end, :]

            # Get the actual number of patches in this tile (can be smaller for the last tile)
            num_patches_tile = nerf_hidden_tile.shape[1]

            # Reshape the tile for per-patch processing
            # [B, NumPatches_tile, D] -> [B * NumPatches_tile, D]
            nerf_hidden_tile = nerf_hidden_tile.reshape(batch * num_patches_tile, params.hidden_size)
            # [B, NumPatches_tile, C*P*P] -> [B*NumPatches_tile, C, P*P] -> [B*NumPatches_tile, P*P, C]
            nerf_pixels_tile = nerf_pixels_tile.reshape(batch * num_patches_tile, channels, patch_size**2).transpose(1, 2)

            # get DCT-encoded pixel embeddings [pixel-dct]
            img_dct_tile = self.nerf_image_embedder(nerf_pixels_tile)

            # pass through the dynamic MLP blocks (the NeRF)
            for block in self.nerf_blocks:
                img_dct_tile = block(img_dct_tile, nerf_hidden_tile)

            output_tiles.append(img_dct_tile)

        # Concatenate the processed tiles along the patch dimension
        return torch.cat(output_tiles, dim=0)

    def radiance_get_override_params(self, overrides: dict) -> ChromaRadianceParams:
        params = self.params
        if not overrides:
            return params
        params_dict = {k: getattr(params, k) for k in params.__dataclass_fields__}
        nullable_keys = frozenset(("nerf_embedder_dtype",))
        bad_keys = tuple(k for k in overrides if k not in params_dict)
        if bad_keys:
            e = f"Unknown key(s) in transformer_options chroma_radiance_options: {', '.join(bad_keys)}"
            raise ValueError(e)
        bad_keys = tuple(
            k
            for k, v in overrides.items()
            if type(v) != type(getattr(params, k)) and (v is not None or k not in nullable_keys)
        )
        if bad_keys:
            e = f"Invalid value(s) in transformer_options chroma_radiance_options: {', '.join(bad_keys)}"
            raise ValueError(e)
        # At this point it's all valid keys and values so we can merge with the existing params.
        params_dict |= overrides
        return params.__class__(**params_dict)

    def _forward(
        self,
        x: Tensor,
        timestep: Tensor,
        context: Tensor,
        guidance: Optional[Tensor],
        control: Optional[dict]=None,
        transformer_options: dict={},
        **kwargs: dict,
    ) -> Tensor:
        bs, c, h, w = x.shape
        img = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size))

        if img.ndim != 4:
            raise ValueError("Input img tensor must be in [B, C, H, W] format.")
        if context.ndim != 3:
            raise ValueError("Input txt tensors must have 3 dimensions.")

        params = self.radiance_get_override_params(transformer_options.get("chroma_radiance_options", {}))

        h_len = (img.shape[-2] // self.patch_size)
        w_len = (img.shape[-1] // self.patch_size)

        img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
        img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1)
        img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
        img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
        txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)

        img_out = self.forward_orig(
            img,
            img_ids,
            context,
            txt_ids,
            timestep,
            guidance,
            control,
            transformer_options,
            attn_mask=kwargs.get("attention_mask", None),
        )
        return self.forward_nerf(img, img_out, params)[:, :, :h, :w]
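radiance_get_override_params above validates per-call overrides supplied under transformer_options["chroma_radiance_options"]. A hedged sketch of the caller side; only the nested dict layout and the key names come from this diff, while the commented-out call site is assumed:

transformer_options = {
    "chroma_radiance_options": {
        "nerf_tile_size": 32,          # run the NeRF head 32 patches at a time
        "nerf_embedder_dtype": None,   # None falls back to the model dtype (nullable key)
    },
}
# model._forward(x, timestep, context, guidance, transformer_options=transformer_options)  # assumed call site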
@ -176,6 +176,7 @@ class Attention(nn.Module):
context=None,
mask=None,
rope_emb=None,
+ transformer_options={},
**kwargs,
):
"""
@ -184,7 +185,7 @@ class Attention(nn.Module):
context (Optional[Tensor]): The key tensor of shape [B, Mk, K] or use x as context [self attention] if None
"""
q, k, v = self.cal_qkv(x, context, mask, rope_emb=rope_emb, **kwargs)
- out = optimized_attention(q, k, v, self.heads, skip_reshape=True, mask=mask, skip_output_reshape=True)
+ out = optimized_attention(q, k, v, self.heads, skip_reshape=True, mask=mask, skip_output_reshape=True, transformer_options=transformer_options)
del q, k, v
out = rearrange(out, " b n s c -> s b (n c)")
return self.to_out(out)
@ -546,6 +547,7 @@ class VideoAttn(nn.Module):
context: Optional[torch.Tensor] = None,
crossattn_mask: Optional[torch.Tensor] = None,
rope_emb_L_1_1_D: Optional[torch.Tensor] = None,
+ transformer_options: Optional[dict] = {},
) -> torch.Tensor:
"""
Forward pass for video attention.
@ -571,6 +573,7 @@ class VideoAttn(nn.Module):
context_M_B_D,
crossattn_mask,
rope_emb=rope_emb_L_1_1_D,
+ transformer_options=transformer_options,
)
x_T_H_W_B_D = rearrange(x_THW_B_D, "(t h w) b d -> t h w b d", h=H, w=W)
return x_T_H_W_B_D
@ -665,6 +668,7 @@ class DITBuildingBlock(nn.Module):
crossattn_mask: Optional[torch.Tensor] = None,
rope_emb_L_1_1_D: Optional[torch.Tensor] = None,
adaln_lora_B_3D: Optional[torch.Tensor] = None,
+ transformer_options: Optional[dict] = {},
) -> torch.Tensor:
"""
Forward pass for dynamically configured blocks with adaptive normalization.
@ -702,6 +706,7 @@ class DITBuildingBlock(nn.Module):
adaln_norm_state(self.norm_state, x, scale_1_1_1_B_D, shift_1_1_1_B_D),
context=None,
rope_emb_L_1_1_D=rope_emb_L_1_1_D,
+ transformer_options=transformer_options,
)
elif self.block_type in ["cross_attn", "ca"]:
x = x + gate_1_1_1_B_D * self.block(
@ -709,6 +714,7 @@ class DITBuildingBlock(nn.Module):
context=crossattn_emb,
crossattn_mask=crossattn_mask,
rope_emb_L_1_1_D=rope_emb_L_1_1_D,
+ transformer_options=transformer_options,
)
else:
raise ValueError(f"Unknown block type: {self.block_type}")
@ -784,6 +790,7 @@ class GeneralDITTransformerBlock(nn.Module):
crossattn_mask: Optional[torch.Tensor] = None,
rope_emb_L_1_1_D: Optional[torch.Tensor] = None,
adaln_lora_B_3D: Optional[torch.Tensor] = None,
+ transformer_options: Optional[dict] = {},
) -> torch.Tensor:
for block in self.blocks:
x = block(
@ -793,5 +800,6 @@ class GeneralDITTransformerBlock(nn.Module):
crossattn_mask,
rope_emb_L_1_1_D=rope_emb_L_1_1_D,
adaln_lora_B_3D=adaln_lora_B_3D,
+ transformer_options=transformer_options,
)
return x
@ -520,6 +520,7 @@ class GeneralDIT(nn.Module):
x.shape == extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.shape
), f"{x.shape} != {extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.shape} {original_shape}"

+ transformer_options = kwargs.get("transformer_options", {})
for _, block in self.blocks.items():
assert (
self.blocks["block0"].x_format == block.x_format
@ -534,6 +535,7 @@ class GeneralDIT(nn.Module):
crossattn_mask,
rope_emb_L_1_1_D=rope_emb_L_1_1_D,
adaln_lora_B_3D=adaln_lora_B_3D,
+ transformer_options=transformer_options,
)

x_B_T_H_W_D = rearrange(x, "T H W B D -> B T H W D")
@ -44,7 +44,7 @@ class GPT2FeedForward(nn.Module):
return x


- def torch_attention_op(q_B_S_H_D: torch.Tensor, k_B_S_H_D: torch.Tensor, v_B_S_H_D: torch.Tensor) -> torch.Tensor:
+ def torch_attention_op(q_B_S_H_D: torch.Tensor, k_B_S_H_D: torch.Tensor, v_B_S_H_D: torch.Tensor, transformer_options: Optional[dict] = {}) -> torch.Tensor:
"""Computes multi-head attention using PyTorch's native implementation.

This function provides a PyTorch backend alternative to Transformer Engine's attention operation.
@ -71,7 +71,7 @@ def torch_attention_op(q_B_S_H_D: torch.Tensor, k_B_S_H_D: torch.Tensor, v_B_S_H
q_B_H_S_D = rearrange(q_B_S_H_D, "b ... h k -> b h ... k").view(in_q_shape[0], in_q_shape[-2], -1, in_q_shape[-1])
k_B_H_S_D = rearrange(k_B_S_H_D, "b ... h v -> b h ... v").view(in_k_shape[0], in_k_shape[-2], -1, in_k_shape[-1])
v_B_H_S_D = rearrange(v_B_S_H_D, "b ... h v -> b h ... v").view(in_k_shape[0], in_k_shape[-2], -1, in_k_shape[-1])
- return optimized_attention(q_B_H_S_D, k_B_H_S_D, v_B_H_S_D, in_q_shape[-2], skip_reshape=True)
+ return optimized_attention(q_B_H_S_D, k_B_H_S_D, v_B_H_S_D, in_q_shape[-2], skip_reshape=True, transformer_options=transformer_options)


class Attention(nn.Module):
@ -180,8 +180,8 @@ class Attention(nn.Module):

return q, k, v

- def compute_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
+ def compute_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, transformer_options: Optional[dict] = {}) -> torch.Tensor:
- result = self.attn_op(q, k, v) # [B, S, H, D]
+ result = self.attn_op(q, k, v, transformer_options=transformer_options) # [B, S, H, D]
return self.output_dropout(self.output_proj(result))

def forward(
@ -189,6 +189,7 @@ class Attention(nn.Module):
x: torch.Tensor,
context: Optional[torch.Tensor] = None,
rope_emb: Optional[torch.Tensor] = None,
+ transformer_options: Optional[dict] = {},
) -> torch.Tensor:
"""
Args:
@ -196,7 +197,7 @@ class Attention(nn.Module):
context (Optional[Tensor]): The key tensor of shape [B, Mk, K] or use x as context [self attention] if None
"""
q, k, v = self.compute_qkv(x, context, rope_emb=rope_emb)
- return self.compute_attention(q, k, v)
+ return self.compute_attention(q, k, v, transformer_options=transformer_options)


class Timesteps(nn.Module):
@ -459,6 +460,7 @@ class Block(nn.Module):
rope_emb_L_1_1_D: Optional[torch.Tensor] = None,
adaln_lora_B_T_3D: Optional[torch.Tensor] = None,
extra_per_block_pos_emb: Optional[torch.Tensor] = None,
+ transformer_options: Optional[dict] = {},
) -> torch.Tensor:
if extra_per_block_pos_emb is not None:
x_B_T_H_W_D = x_B_T_H_W_D + extra_per_block_pos_emb
@ -512,6 +514,7 @@ class Block(nn.Module):
rearrange(normalized_x_B_T_H_W_D, "b t h w d -> b (t h w) d"),
None,
rope_emb=rope_emb_L_1_1_D,
+ transformer_options=transformer_options,
),
"b (t h w) d -> b t h w d",
t=T,
@ -525,6 +528,7 @@ class Block(nn.Module):
layer_norm_cross_attn: Callable,
_scale_cross_attn_B_T_1_1_D: torch.Tensor,
_shift_cross_attn_B_T_1_1_D: torch.Tensor,
+ transformer_options: Optional[dict] = {},
) -> torch.Tensor:
_normalized_x_B_T_H_W_D = _fn(
_x_B_T_H_W_D, layer_norm_cross_attn, _scale_cross_attn_B_T_1_1_D, _shift_cross_attn_B_T_1_1_D
@ -534,6 +538,7 @@ class Block(nn.Module):
rearrange(_normalized_x_B_T_H_W_D, "b t h w d -> b (t h w) d"),
crossattn_emb,
rope_emb=rope_emb_L_1_1_D,
+ transformer_options=transformer_options,
),
"b (t h w) d -> b t h w d",
t=T,
@ -547,6 +552,7 @@ class Block(nn.Module):
self.layer_norm_cross_attn,
scale_cross_attn_B_T_1_1_D,
shift_cross_attn_B_T_1_1_D,
+ transformer_options=transformer_options,
)
x_B_T_H_W_D = result_B_T_H_W_D * gate_cross_attn_B_T_1_1_D + x_B_T_H_W_D

@ -865,6 +871,7 @@ class MiniTrainDIT(nn.Module):
"rope_emb_L_1_1_D": rope_emb_L_1_1_D.unsqueeze(1).unsqueeze(0),
"adaln_lora_B_T_3D": adaln_lora_B_T_3D,
"extra_per_block_pos_emb": extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D,
+ "transformer_options": kwargs.get("transformer_options", {}),
}
for block in self.blocks:
x_B_T_H_W_D = block(
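The torch_attention_op hunk above rearranges [B, S, H, D] tensors into the [B, H, S, D] layout that optimized_attention expects when skip_reshape=True. A small standalone sketch of that layout change, with toy shapes and PyTorch's scaled_dot_product_attention standing in for ComfyUI's optimized_attention:

import torch
from einops import rearrange

B, S, H, D = 2, 128, 8, 64                   # toy sizes
q = k = v = torch.randn(B, S, H, D)          # [B, S, H, D], as produced upstream

q_bhsd = rearrange(q, "b s h d -> b h s d")  # same move as "b ... h k -> b h ... k" here
k_bhsd = rearrange(k, "b s h d -> b h s d")
v_bhsd = rearrange(v, "b s h d -> b h s d")

out = torch.nn.functional.scaled_dot_product_attention(q_bhsd, k_bhsd, v_bhsd)
print(out.shape)  # torch.Size([2, 8, 128, 64])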
@ -159,7 +159,7 @@ class DoubleStreamBlock(nn.Module):
)
self.flipped_img_txt = flipped_img_txt

- def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims_img=None, modulation_dims_txt=None):
+ def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims_img=None, modulation_dims_txt=None, transformer_options={}):
img_mod1, img_mod2 = self.img_mod(vec)
txt_mod1, txt_mod2 = self.txt_mod(vec)

@ -182,7 +182,7 @@ class DoubleStreamBlock(nn.Module):
attn = attention(torch.cat((img_q, txt_q), dim=2),
torch.cat((img_k, txt_k), dim=2),
torch.cat((img_v, txt_v), dim=2),
- pe=pe, mask=attn_mask)
+ pe=pe, mask=attn_mask, transformer_options=transformer_options)

img_attn, txt_attn = attn[:, : img.shape[1]], attn[:, img.shape[1]:]
else:
@ -190,7 +190,7 @@ class DoubleStreamBlock(nn.Module):
attn = attention(torch.cat((txt_q, img_q), dim=2),
torch.cat((txt_k, img_k), dim=2),
torch.cat((txt_v, img_v), dim=2),
- pe=pe, mask=attn_mask)
+ pe=pe, mask=attn_mask, transformer_options=transformer_options)

txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:]

@ -244,7 +244,7 @@ class SingleStreamBlock(nn.Module):
self.mlp_act = nn.GELU(approximate="tanh")
self.modulation = Modulation(hidden_size, double=False, dtype=dtype, device=device, operations=operations)

- def forward(self, x: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims=None) -> Tensor:
+ def forward(self, x: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims=None, transformer_options={}) -> Tensor:
mod, _ = self.modulation(vec)
qkv, mlp = torch.split(self.linear1(apply_mod(self.pre_norm(x), (1 + mod.scale), mod.shift, modulation_dims)), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)

@ -252,7 +252,7 @@ class SingleStreamBlock(nn.Module):
q, k = self.norm(q, k, v)

# compute attention
- attn = attention(q, k, v, pe=pe, mask=attn_mask)
+ attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options)
# compute activation in mlp stream, cat again and run second linear layer
output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
x += apply_mod(output, mod.gate, None, modulation_dims)
@ -6,7 +6,7 @@ from comfy.ldm.modules.attention import optimized_attention
import comfy.model_management


- def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None) -> Tensor:
+ def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None, transformer_options={}) -> Tensor:
q_shape = q.shape
k_shape = k.shape

@ -17,7 +17,7 @@ def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None) -> Tensor:
k = (pe[..., 0] * k[..., 0] + pe[..., 1] * k[..., 1]).reshape(*k_shape).type_as(v)

heads = q.shape[1]
- x = optimized_attention(q, k, v, heads, skip_reshape=True, mask=mask)
+ x = optimized_attention(q, k, v, heads, skip_reshape=True, mask=mask, transformer_options=transformer_options)
return x


@ -35,11 +35,10 @@ def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
return out.to(dtype=torch.float32, device=pos.device)

+ def apply_rope1(x: Tensor, freqs_cis: Tensor):
+ x_ = x.to(dtype=freqs_cis.dtype).reshape(*x.shape[:-1], -1, 1, 2)
+ x_out = freqs_cis[..., 0] * x_[..., 0] + freqs_cis[..., 1] * x_[..., 1]
+ return x_out.reshape(*x.shape).type_as(x)
+
def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor):
- xq_ = xq.to(dtype=freqs_cis.dtype).reshape(*xq.shape[:-1], -1, 1, 2)
- xk_ = xk.to(dtype=freqs_cis.dtype).reshape(*xk.shape[:-1], -1, 1, 2)
- xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
- xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
- return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
+ return apply_rope1(xq, freqs_cis), apply_rope1(xk, freqs_cis)

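The hunk above factors the rotary math into apply_rope1 so q and k can be rotated independently. A self-contained check that the refactor matches the removed inline version (both functions are re-implemented here only for the comparison, and the rope tensor layout is an assumption):

import torch

def apply_rope1(x, freqs_cis):
    x_ = x.to(dtype=freqs_cis.dtype).reshape(*x.shape[:-1], -1, 1, 2)
    x_out = freqs_cis[..., 0] * x_[..., 0] + freqs_cis[..., 1] * x_[..., 1]
    return x_out.reshape(*x.shape).type_as(x)

def apply_rope_old(xq, xk, freqs_cis):
    # the removed inline version, kept only for this comparison
    xq_ = xq.to(dtype=freqs_cis.dtype).reshape(*xq.shape[:-1], -1, 1, 2)
    xk_ = xk.to(dtype=freqs_cis.dtype).reshape(*xk.shape[:-1], -1, 1, 2)
    xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
    xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
    return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)

B, H, S, D = 1, 2, 4, 8                         # toy sizes, D must be even
xq, xk = torch.randn(B, H, S, D), torch.randn(B, H, S, D)
freqs_cis = torch.randn(B, 1, S, D // 2, 2, 2)  # assumed rope tensor layout

q_old, k_old = apply_rope_old(xq, xk, freqs_cis)
assert torch.allclose(q_old, apply_rope1(xq, freqs_cis))
assert torch.allclose(k_old, apply_rope1(xk, freqs_cis))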
@ -144,14 +144,16 @@ class Flux(nn.Module):
|
|||||||
txt=args["txt"],
|
txt=args["txt"],
|
||||||
vec=args["vec"],
|
vec=args["vec"],
|
||||||
pe=args["pe"],
|
pe=args["pe"],
|
||||||
attn_mask=args.get("attn_mask"))
|
attn_mask=args.get("attn_mask"),
|
||||||
|
transformer_options=args.get("transformer_options"))
|
||||||
return out
|
return out
|
||||||
|
|
||||||
out = blocks_replace[("double_block", i)]({"img": img,
|
out = blocks_replace[("double_block", i)]({"img": img,
|
||||||
"txt": txt,
|
"txt": txt,
|
||||||
"vec": vec,
|
"vec": vec,
|
||||||
"pe": pe,
|
"pe": pe,
|
||||||
"attn_mask": attn_mask},
|
"attn_mask": attn_mask,
|
||||||
|
"transformer_options": transformer_options},
|
||||||
{"original_block": block_wrap})
|
{"original_block": block_wrap})
|
||||||
txt = out["txt"]
|
txt = out["txt"]
|
||||||
img = out["img"]
|
img = out["img"]
|
||||||
@ -160,7 +162,8 @@ class Flux(nn.Module):
|
|||||||
txt=txt,
|
txt=txt,
|
||||||
vec=vec,
|
vec=vec,
|
||||||
pe=pe,
|
pe=pe,
|
||||||
attn_mask=attn_mask)
|
attn_mask=attn_mask,
|
||||||
|
transformer_options=transformer_options)
|
||||||
|
|
||||||
if control is not None: # Controlnet
|
if control is not None: # Controlnet
|
||||||
control_i = control.get("input")
|
control_i = control.get("input")
|
||||||
@ -181,17 +184,19 @@ class Flux(nn.Module):
|
|||||||
out["img"] = block(args["img"],
|
out["img"] = block(args["img"],
|
||||||
vec=args["vec"],
|
vec=args["vec"],
|
||||||
pe=args["pe"],
|
pe=args["pe"],
|
||||||
attn_mask=args.get("attn_mask"))
|
attn_mask=args.get("attn_mask"),
|
||||||
|
transformer_options=args.get("transformer_options"))
|
||||||
return out
|
return out
|
||||||
|
|
||||||
out = blocks_replace[("single_block", i)]({"img": img,
|
out = blocks_replace[("single_block", i)]({"img": img,
|
||||||
"vec": vec,
|
"vec": vec,
|
||||||
"pe": pe,
|
"pe": pe,
|
||||||
"attn_mask": attn_mask},
|
"attn_mask": attn_mask,
|
||||||
|
"transformer_options": transformer_options},
|
||||||
{"original_block": block_wrap})
|
{"original_block": block_wrap})
|
||||||
img = out["img"]
|
img = out["img"]
|
||||||
else:
|
else:
|
||||||
img = block(img, vec=vec, pe=pe, attn_mask=attn_mask)
|
img = block(img, vec=vec, pe=pe, attn_mask=attn_mask, transformer_options=transformer_options)
|
||||||
|
|
||||||
if control is not None: # Controlnet
|
if control is not None: # Controlnet
|
||||||
control_o = control.get("output")
|
control_o = control.get("output")
|
||||||
|
|||||||
@ -109,6 +109,7 @@ class AsymmetricAttention(nn.Module):
|
|||||||
scale_x: torch.Tensor, # (B, dim_x), modulation for pre-RMSNorm.
|
scale_x: torch.Tensor, # (B, dim_x), modulation for pre-RMSNorm.
|
||||||
scale_y: torch.Tensor, # (B, dim_y), modulation for pre-RMSNorm.
|
scale_y: torch.Tensor, # (B, dim_y), modulation for pre-RMSNorm.
|
||||||
crop_y,
|
crop_y,
|
||||||
|
transformer_options={},
|
||||||
**rope_rotation,
|
**rope_rotation,
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
rope_cos = rope_rotation.get("rope_cos")
|
rope_cos = rope_rotation.get("rope_cos")
|
||||||
@ -143,7 +144,7 @@ class AsymmetricAttention(nn.Module):
|
|||||||
|
|
||||||
xy = optimized_attention(q,
|
xy = optimized_attention(q,
|
||||||
k,
|
k,
|
||||||
v, self.num_heads, skip_reshape=True)
|
v, self.num_heads, skip_reshape=True, transformer_options=transformer_options)
|
||||||
|
|
||||||
x, y = torch.tensor_split(xy, (q_x.shape[1],), dim=1)
|
x, y = torch.tensor_split(xy, (q_x.shape[1],), dim=1)
|
||||||
x = self.proj_x(x)
|
x = self.proj_x(x)
|
||||||
@ -224,6 +225,7 @@ class AsymmetricJointBlock(nn.Module):
|
|||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
c: torch.Tensor,
|
c: torch.Tensor,
|
||||||
y: torch.Tensor,
|
y: torch.Tensor,
|
||||||
|
transformer_options={},
|
||||||
**attn_kwargs,
|
**attn_kwargs,
|
||||||
):
|
):
|
||||||
"""Forward pass of a block.
|
"""Forward pass of a block.
|
||||||
@ -256,6 +258,7 @@ class AsymmetricJointBlock(nn.Module):
|
|||||||
y,
|
y,
|
||||||
scale_x=scale_msa_x,
|
scale_x=scale_msa_x,
|
||||||
scale_y=scale_msa_y,
|
scale_y=scale_msa_y,
|
||||||
|
transformer_options=transformer_options,
|
||||||
**attn_kwargs,
|
**attn_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -524,10 +527,11 @@ class AsymmDiTJoint(nn.Module):
|
|||||||
args["txt"],
|
args["txt"],
|
||||||
rope_cos=args["rope_cos"],
|
rope_cos=args["rope_cos"],
|
||||||
rope_sin=args["rope_sin"],
|
rope_sin=args["rope_sin"],
|
||||||
crop_y=args["num_tokens"]
|
crop_y=args["num_tokens"],
|
||||||
|
transformer_options=args["transformer_options"]
|
||||||
)
|
)
|
||||||
return out
|
return out
|
||||||
out = blocks_replace[("double_block", i)]({"img": x, "txt": y_feat, "vec": c, "rope_cos": rope_cos, "rope_sin": rope_sin, "num_tokens": num_tokens}, {"original_block": block_wrap})
|
out = blocks_replace[("double_block", i)]({"img": x, "txt": y_feat, "vec": c, "rope_cos": rope_cos, "rope_sin": rope_sin, "num_tokens": num_tokens, "transformer_options": transformer_options}, {"original_block": block_wrap})
|
||||||
y_feat = out["txt"]
|
y_feat = out["txt"]
|
||||||
x = out["img"]
|
x = out["img"]
|
||||||
else:
|
else:
|
||||||
@ -538,6 +542,7 @@ class AsymmDiTJoint(nn.Module):
|
|||||||
rope_cos=rope_cos,
|
rope_cos=rope_cos,
|
||||||
rope_sin=rope_sin,
|
rope_sin=rope_sin,
|
||||||
crop_y=num_tokens,
|
crop_y=num_tokens,
|
||||||
|
transformer_options=transformer_options,
|
||||||
) # (B, M, D), (B, L, D)
|
) # (B, M, D), (B, L, D)
|
||||||
del y_feat # Final layers don't use dense text features.
|
del y_feat # Final layers don't use dense text features.
|
||||||
|
|
||||||
|
|||||||
@@ -72,8 +72,8 @@ class TimestepEmbed(nn.Module):
 return t_emb


-def attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor):
+def attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, transformer_options={}):
-return optimized_attention(query.view(query.shape[0], -1, query.shape[-1] * query.shape[-2]), key.view(key.shape[0], -1, key.shape[-1] * key.shape[-2]), value.view(value.shape[0], -1, value.shape[-1] * value.shape[-2]), query.shape[2])
+return optimized_attention(query.view(query.shape[0], -1, query.shape[-1] * query.shape[-2]), key.view(key.shape[0], -1, key.shape[-1] * key.shape[-2]), value.view(value.shape[0], -1, value.shape[-1] * value.shape[-2]), query.shape[2], transformer_options=transformer_options)


 class HiDreamAttnProcessor_flashattn:
@@ -86,6 +86,7 @@ class HiDreamAttnProcessor_flashattn:
 image_tokens_masks: Optional[torch.FloatTensor] = None,
 text_tokens: Optional[torch.FloatTensor] = None,
 rope: torch.FloatTensor = None,
+transformer_options={},
 *args,
 **kwargs,
 ) -> torch.FloatTensor:
@@ -133,7 +134,7 @@ class HiDreamAttnProcessor_flashattn:
 query = torch.cat([query_1, query_2], dim=-1)
 key = torch.cat([key_1, key_2], dim=-1)

-hidden_states = attention(query, key, value)
+hidden_states = attention(query, key, value, transformer_options=transformer_options)

 if not attn.single:
 hidden_states_i, hidden_states_t = torch.split(hidden_states, [num_image_tokens, num_text_tokens], dim=1)
@@ -199,6 +200,7 @@ class HiDreamAttention(nn.Module):
 image_tokens_masks: torch.FloatTensor = None,
 norm_text_tokens: torch.FloatTensor = None,
 rope: torch.FloatTensor = None,
+transformer_options={},
 ) -> torch.Tensor:
 return self.processor(
 self,
@@ -206,6 +208,7 @@ class HiDreamAttention(nn.Module):
 image_tokens_masks = image_tokens_masks,
 text_tokens = norm_text_tokens,
 rope = rope,
+transformer_options=transformer_options,
 )


@@ -406,7 +409,7 @@ class HiDreamImageSingleTransformerBlock(nn.Module):
 text_tokens: Optional[torch.FloatTensor] = None,
 adaln_input: Optional[torch.FloatTensor] = None,
 rope: torch.FloatTensor = None,
+transformer_options={},
 ) -> torch.FloatTensor:
 wtype = image_tokens.dtype
 shift_msa_i, scale_msa_i, gate_msa_i, shift_mlp_i, scale_mlp_i, gate_mlp_i = \
@@ -419,6 +422,7 @@ class HiDreamImageSingleTransformerBlock(nn.Module):
 norm_image_tokens,
 image_tokens_masks,
 rope = rope,
+transformer_options=transformer_options,
 )
 image_tokens = gate_msa_i * attn_output_i + image_tokens

@@ -483,6 +487,7 @@ class HiDreamImageTransformerBlock(nn.Module):
 text_tokens: Optional[torch.FloatTensor] = None,
 adaln_input: Optional[torch.FloatTensor] = None,
 rope: torch.FloatTensor = None,
+transformer_options={},
 ) -> torch.FloatTensor:
 wtype = image_tokens.dtype
 shift_msa_i, scale_msa_i, gate_msa_i, shift_mlp_i, scale_mlp_i, gate_mlp_i, \
@@ -500,6 +505,7 @@ class HiDreamImageTransformerBlock(nn.Module):
 image_tokens_masks,
 norm_text_tokens,
 rope = rope,
+transformer_options=transformer_options,
 )

 image_tokens = gate_msa_i * attn_output_i + image_tokens
@@ -550,6 +556,7 @@ class HiDreamImageBlock(nn.Module):
 text_tokens: Optional[torch.FloatTensor] = None,
 adaln_input: torch.FloatTensor = None,
 rope: torch.FloatTensor = None,
+transformer_options={},
 ) -> torch.FloatTensor:
 return self.block(
 image_tokens,
@@ -557,6 +564,7 @@ class HiDreamImageBlock(nn.Module):
 text_tokens,
 adaln_input,
 rope,
+transformer_options=transformer_options,
 )


@@ -786,6 +794,7 @@ class HiDreamImageTransformer2DModel(nn.Module):
 text_tokens = cur_encoder_hidden_states,
 adaln_input = adaln_input,
 rope = rope,
+transformer_options=transformer_options,
 )
 initial_encoder_hidden_states = initial_encoder_hidden_states[:, :initial_encoder_hidden_states_seq_len]
 block_id += 1
@@ -809,6 +818,7 @@ class HiDreamImageTransformer2DModel(nn.Module):
 text_tokens=None,
 adaln_input=adaln_input,
 rope=rope,
+transformer_options=transformer_options,
 )
 hidden_states = hidden_states[:, :hidden_states_seq_len]
 block_id += 1

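The HiDream hunks above follow the same plumbing pattern used throughout this merge: every attention-bearing forward() gains a transformer_options={} keyword and passes it down unchanged until it reaches optimized_attention. A minimal sketch of that pattern with a made-up module (not code from the repository):

import torch.nn as nn
from comfy.ldm.modules.attention import optimized_attention

class TinySelfAttention(nn.Module):
    def __init__(self, dim, heads):
        super().__init__()
        self.heads = heads
        self.qkv = nn.Linear(dim, dim * 3)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x, transformer_options={}):
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        # transformer_options rides along so the attention wrapper can see per-call overrides
        out = optimized_attention(q, k, v, self.heads, transformer_options=transformer_options)
        return self.proj(out)
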
@@ -99,14 +99,16 @@ class Hunyuan3Dv2(nn.Module):
 txt=args["txt"],
 vec=args["vec"],
 pe=args["pe"],
-attn_mask=args.get("attn_mask"))
+attn_mask=args.get("attn_mask"),
+transformer_options=args["transformer_options"])
 return out

 out = blocks_replace[("double_block", i)]({"img": img,
 "txt": txt,
 "vec": vec,
 "pe": pe,
-"attn_mask": attn_mask},
+"attn_mask": attn_mask,
+"transformer_options": transformer_options},
 {"original_block": block_wrap})
 txt = out["txt"]
 img = out["img"]
@@ -115,7 +117,8 @@ class Hunyuan3Dv2(nn.Module):
 txt=txt,
 vec=vec,
 pe=pe,
-attn_mask=attn_mask)
+attn_mask=attn_mask,
+transformer_options=transformer_options)

 img = torch.cat((txt, img), 1)

@@ -126,17 +129,19 @@ class Hunyuan3Dv2(nn.Module):
 out["img"] = block(args["img"],
 vec=args["vec"],
 pe=args["pe"],
-attn_mask=args.get("attn_mask"))
+attn_mask=args.get("attn_mask"),
+transformer_options=args["transformer_options"])
 return out

 out = blocks_replace[("single_block", i)]({"img": img,
 "vec": vec,
 "pe": pe,
-"attn_mask": attn_mask},
+"attn_mask": attn_mask,
+"transformer_options": transformer_options},
 {"original_block": block_wrap})
 img = out["img"]
 else:
-img = block(img, vec=vec, pe=pe, attn_mask=attn_mask)
+img = block(img, vec=vec, pe=pe, attn_mask=attn_mask, transformer_options=transformer_options)

 img = img[:, txt.shape[1]:, ...]
 img = self.final_layer(img, vec)

@@ -80,13 +80,13 @@ class TokenRefinerBlock(nn.Module):
 operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
 )

-def forward(self, x, c, mask):
+def forward(self, x, c, mask, transformer_options={}):
 mod1, mod2 = self.adaLN_modulation(c).chunk(2, dim=1)

 norm_x = self.norm1(x)
 qkv = self.self_attn.qkv(norm_x)
 q, k, v = qkv.reshape(qkv.shape[0], qkv.shape[1], 3, self.heads, -1).permute(2, 0, 3, 1, 4)
-attn = optimized_attention(q, k, v, self.heads, mask=mask, skip_reshape=True)
+attn = optimized_attention(q, k, v, self.heads, mask=mask, skip_reshape=True, transformer_options=transformer_options)

 x = x + self.self_attn.proj(attn) * mod1.unsqueeze(1)
 x = x + self.mlp(self.norm2(x)) * mod2.unsqueeze(1)
@@ -117,14 +117,14 @@ class IndividualTokenRefiner(nn.Module):
 ]
 )

-def forward(self, x, c, mask):
+def forward(self, x, c, mask, transformer_options={}):
 m = None
 if mask is not None:
 m = mask.view(mask.shape[0], 1, 1, mask.shape[1]).repeat(1, 1, mask.shape[1], 1)
 m = m + m.transpose(2, 3)

 for block in self.blocks:
-x = block(x, c, m)
+x = block(x, c, m, transformer_options=transformer_options)
 return x


@@ -152,6 +152,7 @@ class TokenRefiner(nn.Module):
 x,
 timesteps,
 mask,
+transformer_options={},
 ):
 t = self.t_embedder(timestep_embedding(timesteps, 256, time_factor=1.0).to(x.dtype))
 # m = mask.float().unsqueeze(-1)
@@ -160,7 +161,7 @@ class TokenRefiner(nn.Module):

 c = t + self.c_embedder(c.to(x.dtype))
 x = self.input_embedder(x)
-x = self.individual_token_refiner(x, c, mask)
+x = self.individual_token_refiner(x, c, mask, transformer_options=transformer_options)
 return x


@@ -278,6 +279,7 @@ class HunyuanVideo(nn.Module):
 guidance: Tensor = None,
 guiding_frame_index=None,
 ref_latent=None,
+disable_time_r=False,
 control=None,
 transformer_options={},
 ) -> Tensor:
@@ -288,7 +290,7 @@ class HunyuanVideo(nn.Module):
 img = self.img_in(img)
 vec = self.time_in(timestep_embedding(timesteps, 256, time_factor=1.0).to(img.dtype))

-if self.time_r_in is not None:
+if (self.time_r_in is not None) and (not disable_time_r):
 w = torch.where(transformer_options['sigmas'][0] == transformer_options['sample_sigmas'])[0] # This most likely could be improved
 if len(w) > 0:
 timesteps_r = transformer_options['sample_sigmas'][w[0] + 1]
@@ -327,7 +329,7 @@ class HunyuanVideo(nn.Module):
 if txt_mask is not None and not torch.is_floating_point(txt_mask):
 txt_mask = (txt_mask - 1).to(img.dtype) * torch.finfo(img.dtype).max

-txt = self.txt_in(txt, timesteps, txt_mask)
+txt = self.txt_in(txt, timesteps, txt_mask, transformer_options=transformer_options)

 if self.byt5_in is not None and txt_byt5 is not None:
 txt_byt5 = self.byt5_in(txt_byt5)
@@ -351,14 +353,14 @@ class HunyuanVideo(nn.Module):
 if ("double_block", i) in blocks_replace:
 def block_wrap(args):
 out = {}
-out["img"], out["txt"] = block(img=args["img"], txt=args["txt"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"], modulation_dims_img=args["modulation_dims_img"], modulation_dims_txt=args["modulation_dims_txt"])
+out["img"], out["txt"] = block(img=args["img"], txt=args["txt"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"], modulation_dims_img=args["modulation_dims_img"], modulation_dims_txt=args["modulation_dims_txt"], transformer_options=args["transformer_options"])
 return out

-out = blocks_replace[("double_block", i)]({"img": img, "txt": txt, "vec": vec, "pe": pe, "attention_mask": attn_mask, 'modulation_dims_img': modulation_dims, 'modulation_dims_txt': modulation_dims_txt}, {"original_block": block_wrap})
+out = blocks_replace[("double_block", i)]({"img": img, "txt": txt, "vec": vec, "pe": pe, "attention_mask": attn_mask, 'modulation_dims_img': modulation_dims, 'modulation_dims_txt': modulation_dims_txt, 'transformer_options': transformer_options}, {"original_block": block_wrap})
 txt = out["txt"]
 img = out["img"]
 else:
-img, txt = block(img=img, txt=txt, vec=vec, pe=pe, attn_mask=attn_mask, modulation_dims_img=modulation_dims, modulation_dims_txt=modulation_dims_txt)
+img, txt = block(img=img, txt=txt, vec=vec, pe=pe, attn_mask=attn_mask, modulation_dims_img=modulation_dims, modulation_dims_txt=modulation_dims_txt, transformer_options=transformer_options)

 if control is not None: # Controlnet
 control_i = control.get("input")
@@ -373,13 +375,13 @@ class HunyuanVideo(nn.Module):
 if ("single_block", i) in blocks_replace:
 def block_wrap(args):
 out = {}
-out["img"] = block(args["img"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"], modulation_dims=args["modulation_dims"])
+out["img"] = block(args["img"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"], modulation_dims=args["modulation_dims"], transformer_options=args["transformer_options"])
 return out

-out = blocks_replace[("single_block", i)]({"img": img, "vec": vec, "pe": pe, "attention_mask": attn_mask, 'modulation_dims': modulation_dims}, {"original_block": block_wrap})
+out = blocks_replace[("single_block", i)]({"img": img, "vec": vec, "pe": pe, "attention_mask": attn_mask, 'modulation_dims': modulation_dims, 'transformer_options': transformer_options}, {"original_block": block_wrap})
 img = out["img"]
 else:
-img = block(img, vec=vec, pe=pe, attn_mask=attn_mask, modulation_dims=modulation_dims)
+img = block(img, vec=vec, pe=pe, attn_mask=attn_mask, modulation_dims=modulation_dims, transformer_options=transformer_options)

 if control is not None: # Controlnet
 control_o = control.get("output")
@@ -428,14 +430,14 @@ class HunyuanVideo(nn.Module):
 img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
 return repeat(img_ids, "h w c -> b (h w) c", b=bs)

-def forward(self, x, timestep, context, y=None, txt_byt5=None, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, control=None, transformer_options={}, **kwargs):
+def forward(self, x, timestep, context, y=None, txt_byt5=None, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, disable_time_r=False, control=None, transformer_options={}, **kwargs):
 return comfy.patcher_extension.WrapperExecutor.new_class_executor(
 self._forward,
 self,
 comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
-).execute(x, timestep, context, y, txt_byt5, guidance, attention_mask, guiding_frame_index, ref_latent, control, transformer_options, **kwargs)
+).execute(x, timestep, context, y, txt_byt5, guidance, attention_mask, guiding_frame_index, ref_latent, disable_time_r, control, transformer_options, **kwargs)

-def _forward(self, x, timestep, context, y=None, txt_byt5=None, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, control=None, transformer_options={}, **kwargs):
+def _forward(self, x, timestep, context, y=None, txt_byt5=None, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, disable_time_r=False, control=None, transformer_options={}, **kwargs):
 bs = x.shape[0]
 if len(self.patch_size) == 3:
 img_ids = self.img_ids(x)
@@ -443,5 +445,5 @@ class HunyuanVideo(nn.Module):
 else:
 img_ids = self.img_ids_2d(x)
 txt_ids = torch.zeros((bs, context.shape[1], 2), device=x.device, dtype=x.dtype)
-out = self.forward_orig(x, img_ids, context, txt_ids, attention_mask, timestep, y, txt_byt5, guidance, guiding_frame_index, ref_latent, control=control, transformer_options=transformer_options)
+out = self.forward_orig(x, img_ids, context, txt_ids, attention_mask, timestep, y, txt_byt5, guidance, guiding_frame_index, ref_latent, disable_time_r=disable_time_r, control=control, transformer_options=transformer_options)
 return out

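For readers of the time_r_in branch above: the model locates the current sigma inside the full schedule carried in transformer_options and embeds the next sigma as a second time signal; the new disable_time_r flag simply bypasses that branch. A rough standalone illustration of the lookup with hypothetical values (a sketch, not repository code):

import torch

sample_sigmas = torch.tensor([1.0, 0.8, 0.6, 0.4, 0.2, 0.0])  # assumed full sampling schedule
sigmas = torch.tensor([0.6])                                   # sigma of the current step

w = torch.where(sigmas[0] == sample_sigmas)[0]
if len(w) > 0 and w[0] + 1 < len(sample_sigmas):
    timesteps_r = sample_sigmas[w[0] + 1]  # next sigma in the schedule -> tensor(0.4000)
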
comfy/ldm/hunyuan_video/vae_refiner.py (new file, 267 lines)
@@ -0,0 +1,267 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, AttnBlock, VideoConv3d
+import comfy.ops
+import comfy.ldm.models.autoencoder
+ops = comfy.ops.disable_weight_init
+
+class RMS_norm(nn.Module):
+    def __init__(self, dim):
+        super().__init__()
+        shape = (dim, 1, 1, 1)
+        self.scale = dim**0.5
+        self.gamma = nn.Parameter(torch.empty(shape))
+
+    def forward(self, x):
+        return F.normalize(x, dim=1) * self.scale * self.gamma
+
+class DnSmpl(nn.Module):
+    def __init__(self, ic, oc, tds=True):
+        super().__init__()
+        fct = 2 * 2 * 2 if tds else 1 * 2 * 2
+        assert oc % fct == 0
+        self.conv = VideoConv3d(ic, oc // fct, kernel_size=3)
+
+        self.tds = tds
+        self.gs = fct * ic // oc
+
+    def forward(self, x):
+        r1 = 2 if self.tds else 1
+        h = self.conv(x)
+
+        if self.tds:
+            hf = h[:, :, :1, :, :]
+            b, c, f, ht, wd = hf.shape
+            hf = hf.reshape(b, c, f, ht // 2, 2, wd // 2, 2)
+            hf = hf.permute(0, 4, 6, 1, 2, 3, 5)
+            hf = hf.reshape(b, 2 * 2 * c, f, ht // 2, wd // 2)
+            hf = torch.cat([hf, hf], dim=1)
+
+            hn = h[:, :, 1:, :, :]
+            b, c, frms, ht, wd = hn.shape
+            nf = frms // r1
+            hn = hn.reshape(b, c, nf, r1, ht // 2, 2, wd // 2, 2)
+            hn = hn.permute(0, 3, 5, 7, 1, 2, 4, 6)
+            hn = hn.reshape(b, r1 * 2 * 2 * c, nf, ht // 2, wd // 2)
+
+            h = torch.cat([hf, hn], dim=2)
+
+            xf = x[:, :, :1, :, :]
+            b, ci, f, ht, wd = xf.shape
+            xf = xf.reshape(b, ci, f, ht // 2, 2, wd // 2, 2)
+            xf = xf.permute(0, 4, 6, 1, 2, 3, 5)
+            xf = xf.reshape(b, 2 * 2 * ci, f, ht // 2, wd // 2)
+            B, C, T, H, W = xf.shape
+            xf = xf.view(B, h.shape[1], self.gs // 2, T, H, W).mean(dim=2)
+
+            xn = x[:, :, 1:, :, :]
+            b, ci, frms, ht, wd = xn.shape
+            nf = frms // r1
+            xn = xn.reshape(b, ci, nf, r1, ht // 2, 2, wd // 2, 2)
+            xn = xn.permute(0, 3, 5, 7, 1, 2, 4, 6)
+            xn = xn.reshape(b, r1 * 2 * 2 * ci, nf, ht // 2, wd // 2)
+            B, C, T, H, W = xn.shape
+            xn = xn.view(B, h.shape[1], self.gs, T, H, W).mean(dim=2)
+            sc = torch.cat([xf, xn], dim=2)
+        else:
+            b, c, frms, ht, wd = h.shape
+            nf = frms // r1
+            h = h.reshape(b, c, nf, r1, ht // 2, 2, wd // 2, 2)
+            h = h.permute(0, 3, 5, 7, 1, 2, 4, 6)
+            h = h.reshape(b, r1 * 2 * 2 * c, nf, ht // 2, wd // 2)
+
+            b, ci, frms, ht, wd = x.shape
+            nf = frms // r1
+            sc = x.reshape(b, ci, nf, r1, ht // 2, 2, wd // 2, 2)
+            sc = sc.permute(0, 3, 5, 7, 1, 2, 4, 6)
+            sc = sc.reshape(b, r1 * 2 * 2 * ci, nf, ht // 2, wd // 2)
+            B, C, T, H, W = sc.shape
+            sc = sc.view(B, h.shape[1], self.gs, T, H, W).mean(dim=2)
+
+        return h + sc
+
+
+class UpSmpl(nn.Module):
+    def __init__(self, ic, oc, tus=True):
+        super().__init__()
+        fct = 2 * 2 * 2 if tus else 1 * 2 * 2
+        self.conv = VideoConv3d(ic, oc * fct, kernel_size=3)
+
+        self.tus = tus
+        self.rp = fct * oc // ic
+
+    def forward(self, x):
+        r1 = 2 if self.tus else 1
+        h = self.conv(x)
+
+        if self.tus:
+            hf = h[:, :, :1, :, :]
+            b, c, f, ht, wd = hf.shape
+            nc = c // (2 * 2)
+            hf = hf.reshape(b, 2, 2, nc, f, ht, wd)
+            hf = hf.permute(0, 3, 4, 5, 1, 6, 2)
+            hf = hf.reshape(b, nc, f, ht * 2, wd * 2)
+            hf = hf[:, : hf.shape[1] // 2]
+
+            hn = h[:, :, 1:, :, :]
+            b, c, frms, ht, wd = hn.shape
+            nc = c // (r1 * 2 * 2)
+            hn = hn.reshape(b, r1, 2, 2, nc, frms, ht, wd)
+            hn = hn.permute(0, 4, 5, 1, 6, 2, 7, 3)
+            hn = hn.reshape(b, nc, frms * r1, ht * 2, wd * 2)
+
+            h = torch.cat([hf, hn], dim=2)
+
+            xf = x[:, :, :1, :, :]
+            b, ci, f, ht, wd = xf.shape
+            xf = xf.repeat_interleave(repeats=self.rp // 2, dim=1)
+            b, c, f, ht, wd = xf.shape
+            nc = c // (2 * 2)
+            xf = xf.reshape(b, 2, 2, nc, f, ht, wd)
+            xf = xf.permute(0, 3, 4, 5, 1, 6, 2)
+            xf = xf.reshape(b, nc, f, ht * 2, wd * 2)
+
+            xn = x[:, :, 1:, :, :]
+            xn = xn.repeat_interleave(repeats=self.rp, dim=1)
+            b, c, frms, ht, wd = xn.shape
+            nc = c // (r1 * 2 * 2)
+            xn = xn.reshape(b, r1, 2, 2, nc, frms, ht, wd)
+            xn = xn.permute(0, 4, 5, 1, 6, 2, 7, 3)
+            xn = xn.reshape(b, nc, frms * r1, ht * 2, wd * 2)
+            sc = torch.cat([xf, xn], dim=2)
+        else:
+            b, c, frms, ht, wd = h.shape
+            nc = c // (r1 * 2 * 2)
+            h = h.reshape(b, r1, 2, 2, nc, frms, ht, wd)
+            h = h.permute(0, 4, 5, 1, 6, 2, 7, 3)
+            h = h.reshape(b, nc, frms * r1, ht * 2, wd * 2)
+
+            sc = x.repeat_interleave(repeats=self.rp, dim=1)
+            b, c, frms, ht, wd = sc.shape
+            nc = c // (r1 * 2 * 2)
+            sc = sc.reshape(b, r1, 2, 2, nc, frms, ht, wd)
+            sc = sc.permute(0, 4, 5, 1, 6, 2, 7, 3)
+            sc = sc.reshape(b, nc, frms * r1, ht * 2, wd * 2)
+
+        return h + sc
+
+class Encoder(nn.Module):
+    def __init__(self, in_channels, z_channels, block_out_channels, num_res_blocks,
+                 ffactor_spatial, ffactor_temporal, downsample_match_channel=True, **_):
+        super().__init__()
+        self.z_channels = z_channels
+        self.block_out_channels = block_out_channels
+        self.num_res_blocks = num_res_blocks
+        self.conv_in = VideoConv3d(in_channels, block_out_channels[0], 3, 1, 1)
+
+        self.down = nn.ModuleList()
+        ch = block_out_channels[0]
+        depth = (ffactor_spatial >> 1).bit_length()
+        depth_temporal = ((ffactor_spatial // ffactor_temporal) >> 1).bit_length()
+
+        for i, tgt in enumerate(block_out_channels):
+            stage = nn.Module()
+            stage.block = nn.ModuleList([ResnetBlock(in_channels=ch if j == 0 else tgt,
+                                                     out_channels=tgt,
+                                                     temb_channels=0,
+                                                     conv_op=VideoConv3d, norm_op=RMS_norm)
+                                         for j in range(num_res_blocks)])
+            ch = tgt
+            if i < depth:
+                nxt = block_out_channels[i + 1] if i + 1 < len(block_out_channels) and downsample_match_channel else ch
+                stage.downsample = DnSmpl(ch, nxt, tds=i >= depth_temporal)
+                ch = nxt
+            self.down.append(stage)
+
+        self.mid = nn.Module()
+        self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=VideoConv3d, norm_op=RMS_norm)
+        self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv3d, norm_op=RMS_norm)
+        self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=VideoConv3d, norm_op=RMS_norm)
+
+        self.norm_out = RMS_norm(ch)
+        self.conv_out = VideoConv3d(ch, z_channels << 1, 3, 1, 1)
+
+        self.regul = comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer()
+
+    def forward(self, x):
+        x = self.conv_in(x)
+
+        for stage in self.down:
+            for blk in stage.block:
+                x = blk(x)
+            if hasattr(stage, 'downsample'):
+                x = stage.downsample(x)
+
+        x = self.mid.block_2(self.mid.attn_1(self.mid.block_1(x)))
+
+        b, c, t, h, w = x.shape
+        grp = c // (self.z_channels << 1)
+        skip = x.view(b, c // grp, grp, t, h, w).mean(2)
+
+        out = self.conv_out(F.silu(self.norm_out(x))) + skip
+        out = self.regul(out)[0]
+
+        out = torch.cat((out[:, :, :1], out), dim=2)
+        out = out.permute(0, 2, 1, 3, 4)
+        b, f_times_2, c, h, w = out.shape
+        out = out.reshape(b, f_times_2 // 2, 2 * c, h, w)
+        out = out.permute(0, 2, 1, 3, 4).contiguous()
+        return out
+
+class Decoder(nn.Module):
+    def __init__(self, z_channels, out_channels, block_out_channels, num_res_blocks,
+                 ffactor_spatial, ffactor_temporal, upsample_match_channel=True, **_):
+        super().__init__()
+        block_out_channels = block_out_channels[::-1]
+        self.z_channels = z_channels
+        self.block_out_channels = block_out_channels
+        self.num_res_blocks = num_res_blocks
+
+        ch = block_out_channels[0]
+        self.conv_in = VideoConv3d(z_channels, ch, 3)
+
+        self.mid = nn.Module()
+        self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=VideoConv3d, norm_op=RMS_norm)
+        self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv3d, norm_op=RMS_norm)
+        self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=VideoConv3d, norm_op=RMS_norm)
+
+        self.up = nn.ModuleList()
+        depth = (ffactor_spatial >> 1).bit_length()
+        depth_temporal = (ffactor_temporal >> 1).bit_length()
+
+        for i, tgt in enumerate(block_out_channels):
+            stage = nn.Module()
+            stage.block = nn.ModuleList([ResnetBlock(in_channels=ch if j == 0 else tgt,
+                                                     out_channels=tgt,
+                                                     temb_channels=0,
+                                                     conv_op=VideoConv3d, norm_op=RMS_norm)
+                                         for j in range(num_res_blocks + 1)])
+            ch = tgt
+            if i < depth:
+                nxt = block_out_channels[i + 1] if i + 1 < len(block_out_channels) and upsample_match_channel else ch
+                stage.upsample = UpSmpl(ch, nxt, tus=i < depth_temporal)
+                ch = nxt
+            self.up.append(stage)
+
+        self.norm_out = RMS_norm(ch)
+        self.conv_out = VideoConv3d(ch, out_channels, 3)
+
+    def forward(self, z):
+        z = z.permute(0, 2, 1, 3, 4)
+        b, f, c, h, w = z.shape
+        z = z.reshape(b, f, 2, c // 2, h, w)
+        z = z.permute(0, 1, 2, 3, 4, 5).reshape(b, f * 2, c // 2, h, w)
+        z = z.permute(0, 2, 1, 3, 4)
+        z = z[:, :, 1:]
+
+        x = self.conv_in(z) + z.repeat_interleave(self.block_out_channels[0] // self.z_channels, 1)
+        x = self.mid.block_2(self.mid.attn_1(self.mid.block_1(x)))
+
+        for stage in self.up:
+            for blk in stage.block:
+                x = blk(x)
+            if hasattr(stage, 'upsample'):
+                x = stage.upsample(x)
+
+        return self.conv_out(F.silu(self.norm_out(x)))
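The DnSmpl/UpSmpl blocks in the new VAE refiner trade spatial (and optionally temporal) resolution for channels with reshape/permute, i.e. a 3D pixel-(un)shuffle, and add a channel-averaged skip path on top. A small standalone shape check of the spatial part of that rearrangement (a sketch, not repository code):

import torch

b, c, f, h, w = 1, 8, 5, 16, 16
x = torch.randn(b, c, f, h, w)

# space-to-channel: every 2x2 spatial patch becomes 4 extra channels
y = x.reshape(b, c, f, h // 2, 2, w // 2, 2)
y = y.permute(0, 4, 6, 1, 2, 3, 5)
y = y.reshape(b, 2 * 2 * c, f, h // 2, w // 2)
print(y.shape)  # torch.Size([1, 32, 5, 8, 8])
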
@@ -271,7 +271,7 @@ class CrossAttention(nn.Module):

 self.to_out = nn.Sequential(operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout))

-def forward(self, x, context=None, mask=None, pe=None):
+def forward(self, x, context=None, mask=None, pe=None, transformer_options={}):
 q = self.to_q(x)
 context = x if context is None else context
 k = self.to_k(context)
@@ -285,9 +285,9 @@
 k = apply_rotary_emb(k, pe)

 if mask is None:
-out = comfy.ldm.modules.attention.optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision)
+out = comfy.ldm.modules.attention.optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision, transformer_options=transformer_options)
 else:
-out = comfy.ldm.modules.attention.optimized_attention_masked(q, k, v, self.heads, mask, attn_precision=self.attn_precision)
+out = comfy.ldm.modules.attention.optimized_attention_masked(q, k, v, self.heads, mask, attn_precision=self.attn_precision, transformer_options=transformer_options)
 return self.to_out(out)


@@ -303,12 +303,12 @@ class BasicTransformerBlock(nn.Module):

 self.scale_shift_table = nn.Parameter(torch.empty(6, dim, device=device, dtype=dtype))

-def forward(self, x, context=None, attention_mask=None, timestep=None, pe=None):
+def forward(self, x, context=None, attention_mask=None, timestep=None, pe=None, transformer_options={}):
 shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + timestep.reshape(x.shape[0], timestep.shape[1], self.scale_shift_table.shape[0], -1)).unbind(dim=2)

-x += self.attn1(comfy.ldm.common_dit.rms_norm(x) * (1 + scale_msa) + shift_msa, pe=pe) * gate_msa
+x += self.attn1(comfy.ldm.common_dit.rms_norm(x) * (1 + scale_msa) + shift_msa, pe=pe, transformer_options=transformer_options) * gate_msa

-x += self.attn2(x, context=context, mask=attention_mask)
+x += self.attn2(x, context=context, mask=attention_mask, transformer_options=transformer_options)

 y = comfy.ldm.common_dit.rms_norm(x) * (1 + scale_mlp) + shift_mlp
 x += self.ff(y) * gate_mlp
@@ -479,10 +479,10 @@ class LTXVModel(torch.nn.Module):
 if ("double_block", i) in blocks_replace:
 def block_wrap(args):
 out = {}
-out["img"] = block(args["img"], context=args["txt"], attention_mask=args["attention_mask"], timestep=args["vec"], pe=args["pe"])
+out["img"] = block(args["img"], context=args["txt"], attention_mask=args["attention_mask"], timestep=args["vec"], pe=args["pe"], transformer_options=args["transformer_options"])
 return out

-out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "attention_mask": attention_mask, "vec": timestep, "pe": pe}, {"original_block": block_wrap})
+out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "attention_mask": attention_mask, "vec": timestep, "pe": pe, "transformer_options": transformer_options}, {"original_block": block_wrap})
 x = out["img"]
 else:
 x = block(
@@ -490,7 +490,8 @@ class LTXVModel(torch.nn.Module):
 context=context,
 attention_mask=attention_mask,
 timestep=timestep,
-pe=pe
+pe=pe,
+transformer_options=transformer_options,
 )

 # 3. Output

@@ -104,6 +104,7 @@ class JointAttention(nn.Module):
 x: torch.Tensor,
 x_mask: torch.Tensor,
 freqs_cis: torch.Tensor,
+transformer_options={},
 ) -> torch.Tensor:
 """

@@ -140,7 +141,7 @@ class JointAttention(nn.Module):
 if n_rep >= 1:
 xk = xk.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
 xv = xv.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
-output = optimized_attention_masked(xq.movedim(1, 2), xk.movedim(1, 2), xv.movedim(1, 2), self.n_local_heads, x_mask, skip_reshape=True)
+output = optimized_attention_masked(xq.movedim(1, 2), xk.movedim(1, 2), xv.movedim(1, 2), self.n_local_heads, x_mask, skip_reshape=True, transformer_options=transformer_options)

 return self.out(output)

@@ -268,6 +269,7 @@ class JointTransformerBlock(nn.Module):
 x_mask: torch.Tensor,
 freqs_cis: torch.Tensor,
 adaln_input: Optional[torch.Tensor]=None,
+transformer_options={},
 ):
 """
 Perform a forward pass through the TransformerBlock.
@@ -290,6 +292,7 @@ class JointTransformerBlock(nn.Module):
 modulate(self.attention_norm1(x), scale_msa),
 x_mask,
 freqs_cis,
+transformer_options=transformer_options,
 )
 )
 x = x + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2(
@@ -304,6 +307,7 @@ class JointTransformerBlock(nn.Module):
 self.attention_norm1(x),
 x_mask,
 freqs_cis,
+transformer_options=transformer_options,
 )
 )
 x = x + self.ffn_norm2(
@@ -494,7 +498,7 @@ class NextDiT(nn.Module):
 return imgs

 def patchify_and_embed(
-self, x: List[torch.Tensor] | torch.Tensor, cap_feats: torch.Tensor, cap_mask: torch.Tensor, t: torch.Tensor, num_tokens
+self, x: List[torch.Tensor] | torch.Tensor, cap_feats: torch.Tensor, cap_mask: torch.Tensor, t: torch.Tensor, num_tokens, transformer_options={}
 ) -> Tuple[torch.Tensor, torch.Tensor, List[Tuple[int, int]], List[int], torch.Tensor]:
 bsz = len(x)
 pH = pW = self.patch_size
@@ -554,7 +558,7 @@ class NextDiT(nn.Module):

 # refine context
 for layer in self.context_refiner:
-cap_feats = layer(cap_feats, cap_mask, cap_freqs_cis)
+cap_feats = layer(cap_feats, cap_mask, cap_freqs_cis, transformer_options=transformer_options)

 # refine image
 flat_x = []
@@ -573,7 +577,7 @@ class NextDiT(nn.Module):
 padded_img_embed = self.x_embedder(padded_img_embed)
 padded_img_mask = padded_img_mask.unsqueeze(1)
 for layer in self.noise_refiner:
-padded_img_embed = layer(padded_img_embed, padded_img_mask, img_freqs_cis, t)
+padded_img_embed = layer(padded_img_embed, padded_img_mask, img_freqs_cis, t, transformer_options=transformer_options)

 if cap_mask is not None:
 mask = torch.zeros(bsz, max_seq_len, dtype=dtype, device=device)
@@ -616,12 +620,13 @@ class NextDiT(nn.Module):

 cap_feats = self.cap_embedder(cap_feats) # (N, L, D) # todo check if able to batchify w.o. redundant compute

+transformer_options = kwargs.get("transformer_options", {})
 x_is_tensor = isinstance(x, torch.Tensor)
-x, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(x, cap_feats, cap_mask, t, num_tokens)
+x, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(x, cap_feats, cap_mask, t, num_tokens, transformer_options=transformer_options)
 freqs_cis = freqs_cis.to(x.device)

 for layer in self.layers:
-x = layer(x, mask, freqs_cis, adaln_input)
+x = layer(x, mask, freqs_cis, adaln_input, transformer_options=transformer_options)

 x = self.final_layer(x, adaln_input)
 x = self.unpatchify(x, img_size, cap_size, return_tensor=x_is_tensor)[:,:,:h,:w]

@@ -26,6 +26,12 @@ class DiagonalGaussianRegularizer(torch.nn.Module):
 z = posterior.mode()
 return z, None

+class EmptyRegularizer(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]:
+        return z, None

 class AbstractAutoencoder(torch.nn.Module):
 """

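EmptyRegularizer mirrors the regularizer interface the autoencoder already uses: it is called on the encoder output and must return (z, extra), so the no-op version just passes the tensor through where DiagonalGaussianRegularizer would split mean/logvar and sample. A hedged usage sketch, assuming this hunk belongs to comfy/ldm/models/autoencoder.py:

import torch
from comfy.ldm.models.autoencoder import DiagonalGaussianRegularizer, EmptyRegularizer

z = torch.randn(1, 8, 4, 4)                       # pretend encoder output (2x latent channels)
z_sampled, _ = DiagonalGaussianRegularizer()(z)   # -> (1, 4, 4, 4), sampled from mean/logvar
z_passthrough, _ = EmptyRegularizer()(z)          # -> (1, 8, 4, 4), unchanged
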
@@ -5,8 +5,9 @@ import torch
 import torch.nn.functional as F
 from torch import nn, einsum
 from einops import rearrange, repeat
-from typing import Optional
+from typing import Optional, Any, Callable, Union
 import logging
+import functools

 from .diffusionmodules.util import AlphaBlender, timestep_embedding
 from .sub_quadratic_attention import efficient_dot_product_attention
@@ -17,23 +18,45 @@ if model_management.xformers_enabled():
 import xformers
 import xformers.ops

-if model_management.sage_attention_enabled():
+SAGE_ATTENTION_IS_AVAILABLE = False
 try:
 from sageattention import sageattn
-except ModuleNotFoundError as e:
+SAGE_ATTENTION_IS_AVAILABLE = True
+except ImportError as e:
+if model_management.sage_attention_enabled():
 if e.name == "sageattention":
 logging.error(f"\n\nTo use the `--use-sage-attention` feature, the `sageattention` package must be installed first.\ncommand:\n\t{sys.executable} -m pip install sageattention")
 else:
 raise e
 exit(-1)

-if model_management.flash_attention_enabled():
+FLASH_ATTENTION_IS_AVAILABLE = False
 try:
 from flash_attn import flash_attn_func
-except ModuleNotFoundError:
+FLASH_ATTENTION_IS_AVAILABLE = True
+except ImportError:
+if model_management.flash_attention_enabled():
 logging.error(f"\n\nTo use the `--use-flash-attention` feature, the `flash-attn` package must be installed first.\ncommand:\n\t{sys.executable} -m pip install flash-attn")
 exit(-1)

+REGISTERED_ATTENTION_FUNCTIONS = {}
+def register_attention_function(name: str, func: Callable):
+    # avoid replacing existing functions
+    if name not in REGISTERED_ATTENTION_FUNCTIONS:
+        REGISTERED_ATTENTION_FUNCTIONS[name] = func
+    else:
+        logging.warning(f"Attention function {name} already registered, skipping registration.")
+
+def get_attention_function(name: str, default: Any=...) -> Union[Callable, None]:
+    if name == "optimized":
+        return optimized_attention
+    elif name not in REGISTERED_ATTENTION_FUNCTIONS:
+        if default is ...:
+            raise KeyError(f"Attention function {name} not found.")
+        else:
+            return default
+    return REGISTERED_ATTENTION_FUNCTIONS[name]
+
 from comfy.cli_args import args
 import comfy.ops
 ops = comfy.ops.disable_weight_init
@@ -91,7 +114,27 @@ class FeedForward(nn.Module):
 def Normalize(in_channels, dtype=None, device=None):
 return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device)

-def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
+def wrap_attn(func):
+    @functools.wraps(func)
+    def wrapper(*args, **kwargs):
+        remove_attn_wrapper_key = False
+        try:
+            if "_inside_attn_wrapper" not in kwargs:
+                transformer_options = kwargs.get("transformer_options", None)
+                remove_attn_wrapper_key = True
+                kwargs["_inside_attn_wrapper"] = True
+                if transformer_options is not None:
+                    if "optimized_attention_override" in transformer_options:
+                        return transformer_options["optimized_attention_override"](func, *args, **kwargs)
+            return func(*args, **kwargs)
+        finally:
+            if remove_attn_wrapper_key:
+                del kwargs["_inside_attn_wrapper"]
+    return wrapper
+
+@wrap_attn
+def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
 attn_precision = get_attn_precision(attn_precision, q.dtype)

 if skip_reshape:
@@ -159,8 +202,8 @@ def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
 )
 return out

-def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
+@wrap_attn
+def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
 attn_precision = get_attn_precision(attn_precision, query.dtype)

 if skip_reshape:
@@ -230,7 +273,8 @@ def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None,
 hidden_states = hidden_states.unflatten(0, (-1, heads)).transpose(1,2).flatten(start_dim=2)
 return hidden_states

-def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
+@wrap_attn
+def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
 attn_precision = get_attn_precision(attn_precision, q.dtype)

 if skip_reshape:
@@ -359,7 +403,8 @@ try:
 except:
 pass

-def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
+@wrap_attn
+def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
 b = q.shape[0]
 dim_head = q.shape[-1]
 # check to make sure xformers isn't broken
@@ -374,7 +419,7 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_resh
 disabled_xformers = True

 if disabled_xformers:
-return attention_pytorch(q, k, v, heads, mask, skip_reshape=skip_reshape)
+return attention_pytorch(q, k, v, heads, mask, skip_reshape=skip_reshape, **kwargs)

 if skip_reshape:
 # b h k d -> b k h d
@@ -427,8 +472,8 @@ else:
 #TODO: other GPUs ?
 SDP_BATCH_LIMIT = 2**31

-def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
+@wrap_attn
+def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
 if skip_reshape:
 b, _, _, dim_head = q.shape
 else:
@@ -470,8 +515,8 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
 ).transpose(1, 2).reshape(-1, q.shape[2], heads * dim_head)
 return out

-def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
+@wrap_attn
+def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
 if skip_reshape:
 b, _, _, dim_head = q.shape
 tensor_layout = "HND"
@@ -501,7 +546,7 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=
 lambda t: t.transpose(1, 2),
 (q, k, v),
 )
-return attention_pytorch(q, k, v, heads, mask=mask, skip_reshape=True, skip_output_reshape=skip_output_reshape)
+return attention_pytorch(q, k, v, heads, mask=mask, skip_reshape=True, skip_output_reshape=skip_output_reshape, **kwargs)

 if tensor_layout == "HND":
 if not skip_output_reshape:
@@ -534,8 +579,8 @@ except AttributeError as error:
 dropout_p: float = 0.0, causal: bool = False) -> torch.Tensor:
 assert False, f"Could not define flash_attn_wrapper: {FLASH_ATTN_ERROR}"

-def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
+@wrap_attn
+def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
 if skip_reshape:
 b, _, _, dim_head = q.shape
 else:
@@ -555,7 +600,8 @@ def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
 mask = mask.unsqueeze(1)

 try:
-assert mask is None
+if mask is not None:
+raise RuntimeError("Mask must not be set for Flash attention")
 out = flash_attn_wrapper(
 q.transpose(1, 2),
 k.transpose(1, 2),
@@ -597,6 +643,19 @@ else:

 optimized_attention_masked = optimized_attention


+# register core-supported attention functions
+if SAGE_ATTENTION_IS_AVAILABLE:
+    register_attention_function("sage", attention_sage)
+if FLASH_ATTENTION_IS_AVAILABLE:
+    register_attention_function("flash", attention_flash)
+if model_management.xformers_enabled():
+    register_attention_function("xformers", attention_xformers)
+register_attention_function("pytorch", attention_pytorch)
+register_attention_function("sub_quad", attention_sub_quad)
+register_attention_function("split", attention_split)
+

 def optimized_attention_for_device(device, mask=False, small_input=False):
 if small_input:
 if model_management.pytorch_attention_enabled():
@@ -629,7 +688,7 @@ class CrossAttention(nn.Module):

 self.to_out = nn.Sequential(operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout))

-def forward(self, x, context=None, value=None, mask=None):
+def forward(self, x, context=None, value=None, mask=None, transformer_options={}):
 q = self.to_q(x)
 context = default(context, x)
 k = self.to_k(context)
@@ -640,9 +699,9 @@ class CrossAttention(nn.Module):
 v = self.to_v(context)

 if mask is None:
-out = optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision)
+out = optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision, transformer_options=transformer_options)
 else:
-out = optimized_attention_masked(q, k, v, self.heads, mask, attn_precision=self.attn_precision)
+out = optimized_attention_masked(q, k, v, self.heads, mask, attn_precision=self.attn_precision, transformer_options=transformer_options)
 return self.to_out(out)


@@ -746,7 +805,7 @@ class BasicTransformerBlock(nn.Module):
 n = attn1_replace_patch[block_attn1](n, context_attn1, value_attn1, extra_options)
 n = self.attn1.to_out(n)
 else:
-n = self.attn1(n, context=context_attn1, value=value_attn1)
+n = self.attn1(n, context=context_attn1, value=value_attn1, transformer_options=transformer_options)

 if "attn1_output_patch" in transformer_patches:
 patch = transformer_patches["attn1_output_patch"]
@@ -786,7 +845,7 @@ class BasicTransformerBlock(nn.Module):
 n = attn2_replace_patch[block_attn2](n, context_attn2, value_attn2, extra_options)
 n = self.attn2.to_out(n)
 else:
-n = self.attn2(n, context=context_attn2, value=value_attn2)
+n = self.attn2(n, context=context_attn2, value=value_attn2, transformer_options=transformer_options)

 if "attn2_output_patch" in transformer_patches:
 patch = transformer_patches["attn2_output_patch"]
@@ -1017,7 +1076,7 @@ class SpatialVideoTransformer(SpatialTransformer):

 B, S, C = x_mix.shape
 x_mix = rearrange(x_mix, "(b t) s c -> (b s) t c", t=timesteps)
-x_mix = mix_block(x_mix, context=time_context) #TODO: transformer_options
+x_mix = mix_block(x_mix, context=time_context, transformer_options=transformer_options)
 x_mix = rearrange(
 x_mix, "(b s) t c -> (b t) s c", s=S, b=B // timesteps, c=C, t=timesteps
 )

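register_attention_function and get_attention_function give extensions a stable, name-based way to expose or look up an attention backend without monkey-patching; "optimized" is reserved for whatever backend ComfyUI selected at startup. A hedged usage sketch (attention_logged and "logged_example" are made-up names):

import logging
import torch
from comfy.ldm.modules import attention

def attention_logged(q, k, v, heads, mask=None, **kwargs):
    # illustrative backend: log the call, then delegate to the stock PyTorch SDPA path
    logging.info("attention call: q=%s heads=%d", tuple(q.shape), heads)
    return attention.attention_pytorch(q, k, v, heads, mask=mask, **kwargs)

attention.register_attention_function("logged_example", attention_logged)   # warns if the name is taken
backend = attention.get_attention_function("logged_example", default=None)  # returns default instead of raising

q = k = v = torch.randn(1, 16, 64)  # (batch, tokens, heads * dim_head)
out = backend(q, k, v, heads=4)
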
@@ -606,7 +606,7 @@ def block_mixing(*args, use_checkpoint=True, **kwargs):
 return _block_mixing(*args, **kwargs)


-def _block_mixing(context, x, context_block, x_block, c):
+def _block_mixing(context, x, context_block, x_block, c, transformer_options={}):
 context_qkv, context_intermediates = context_block.pre_attention(context, c)

 if x_block.x_block_self_attn:
@@ -622,6 +622,7 @@ def _block_mixing(context, x, context_block, x_block, c):
 attn = optimized_attention(
 qkv[0], qkv[1], qkv[2],
 heads=x_block.attn.num_heads,
+transformer_options=transformer_options,
 )
 context_attn, x_attn = (
 attn[:, : context_qkv[0].shape[1]],
@@ -637,6 +638,7 @@ def _block_mixing(context, x, context_block, x_block, c):
 attn2 = optimized_attention(
 x_qkv2[0], x_qkv2[1], x_qkv2[2],
 heads=x_block.attn2.num_heads,
+transformer_options=transformer_options,
 )
 x = x_block.post_attention_x(x_attn, attn2, *x_intermediates)
 else:
@@ -958,10 +960,10 @@ class MMDiT(nn.Module):
 if ("double_block", i) in blocks_replace:
 def block_wrap(args):
 out = {}
-out["txt"], out["img"] = self.joint_blocks[i](args["txt"], args["img"], c=args["vec"])
+out["txt"], out["img"] = self.joint_blocks[i](args["txt"], args["img"], c=args["vec"], transformer_options=args["transformer_options"])
 return out

-out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": c_mod}, {"original_block": block_wrap})
+out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": c_mod, "transformer_options": transformer_options}, {"original_block": block_wrap})
 context = out["txt"]
 x = out["img"]
 else:
@@ -970,6 +972,7 @@ class MMDiT(nn.Module):
 x,
 c=c_mod,
 use_checkpoint=self.use_checkpoint,
+transformer_options=transformer_options,
 )
 if control is not None:
 control_o = control.get("output")
@@ -145,7 +145,7 @@ class Downsample(nn.Module):

 class ResnetBlock(nn.Module):
 def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
-dropout=0.0, temb_channels=512, conv_op=ops.Conv2d):
+dropout=0.0, temb_channels=512, conv_op=ops.Conv2d, norm_op=Normalize):
 super().__init__()
 self.in_channels = in_channels
 out_channels = in_channels if out_channels is None else out_channels
@@ -153,7 +153,7 @@ class ResnetBlock(nn.Module):
 self.use_conv_shortcut = conv_shortcut

 self.swish = torch.nn.SiLU(inplace=True)
-self.norm1 = Normalize(in_channels)
+self.norm1 = norm_op(in_channels)
 self.conv1 = conv_op(in_channels,
 out_channels,
 kernel_size=3,
@@ -162,7 +162,7 @@ class ResnetBlock(nn.Module):
 if temb_channels > 0:
 self.temb_proj = ops.Linear(temb_channels,
 out_channels)
-self.norm2 = Normalize(out_channels)
+self.norm2 = norm_op(out_channels)
 self.dropout = torch.nn.Dropout(dropout, inplace=True)
 self.conv2 = conv_op(out_channels,
 out_channels,
@@ -305,11 +305,11 @@ def vae_attention():
 return normal_attention

 class AttnBlock(nn.Module):
-def __init__(self, in_channels, conv_op=ops.Conv2d):
+def __init__(self, in_channels, conv_op=ops.Conv2d, norm_op=Normalize):
 super().__init__()
 self.in_channels = in_channels

-self.norm = Normalize(in_channels)
+self.norm = norm_op(in_channels)
 self.q = conv_op(in_channels,
 in_channels,
 kernel_size=1,
@@ -120,7 +120,7 @@ class Attention(nn.Module):
 nn.Dropout(0.0)
 )

-def forward(self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, image_rotary_emb: Optional[torch.Tensor] = None) -> torch.Tensor:
+def forward(self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, image_rotary_emb: Optional[torch.Tensor] = None, transformer_options={}) -> torch.Tensor:
 batch_size, sequence_length, _ = hidden_states.shape

 query = self.to_q(hidden_states)
@@ -146,7 +146,7 @@ class Attention(nn.Module):
 key = key.repeat_interleave(self.heads // self.kv_heads, dim=1)
 value = value.repeat_interleave(self.heads // self.kv_heads, dim=1)

-hidden_states = optimized_attention_masked(query, key, value, self.heads, attention_mask, skip_reshape=True)
+hidden_states = optimized_attention_masked(query, key, value, self.heads, attention_mask, skip_reshape=True, transformer_options=transformer_options)
 hidden_states = self.to_out[0](hidden_states)
 return hidden_states

@@ -182,16 +182,16 @@ class OmniGen2TransformerBlock(nn.Module):
 self.norm2 = operations.RMSNorm(dim, eps=norm_eps, dtype=dtype, device=device)
 self.ffn_norm2 = operations.RMSNorm(dim, eps=norm_eps, dtype=dtype, device=device)

-def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, image_rotary_emb: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
+def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, image_rotary_emb: torch.Tensor, temb: Optional[torch.Tensor] = None, transformer_options={}) -> torch.Tensor:
 if self.modulation:
 norm_hidden_states, gate_msa, scale_mlp, gate_mlp = self.norm1(hidden_states, temb)
-attn_output = self.attn(norm_hidden_states, norm_hidden_states, attention_mask, image_rotary_emb)
+attn_output = self.attn(norm_hidden_states, norm_hidden_states, attention_mask, image_rotary_emb, transformer_options=transformer_options)
 hidden_states = hidden_states + gate_msa.unsqueeze(1).tanh() * self.norm2(attn_output)
 mlp_output = self.feed_forward(self.ffn_norm1(hidden_states) * (1 + scale_mlp.unsqueeze(1)))
 hidden_states = hidden_states + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2(mlp_output)
 else:
 norm_hidden_states = self.norm1(hidden_states)
-attn_output = self.attn(norm_hidden_states, norm_hidden_states, attention_mask, image_rotary_emb)
+attn_output = self.attn(norm_hidden_states, norm_hidden_states, attention_mask, image_rotary_emb, transformer_options=transformer_options)
 hidden_states = hidden_states + self.norm2(attn_output)
 mlp_output = self.feed_forward(self.ffn_norm1(hidden_states))
 hidden_states = hidden_states + self.ffn_norm2(mlp_output)
@@ -390,7 +390,7 @@ class OmniGen2Transformer2DModel(nn.Module):
 ref_img_sizes, img_sizes,
 )

-def img_patch_embed_and_refine(self, hidden_states, ref_image_hidden_states, padded_img_mask, padded_ref_img_mask, noise_rotary_emb, ref_img_rotary_emb, l_effective_ref_img_len, l_effective_img_len, temb):
+def img_patch_embed_and_refine(self, hidden_states, ref_image_hidden_states, padded_img_mask, padded_ref_img_mask, noise_rotary_emb, ref_img_rotary_emb, l_effective_ref_img_len, l_effective_img_len, temb, transformer_options={}):
 batch_size = len(hidden_states)

 hidden_states = self.x_embedder(hidden_states)
@@ -405,17 +405,17 @@ class OmniGen2Transformer2DModel(nn.Module):
 shift += ref_img_len

 for layer in self.noise_refiner:
-hidden_states = layer(hidden_states, padded_img_mask, noise_rotary_emb, temb)
+hidden_states = layer(hidden_states, padded_img_mask, noise_rotary_emb, temb, transformer_options=transformer_options)

 if ref_image_hidden_states is not None:
 for layer in self.ref_image_refiner:
-ref_image_hidden_states = layer(ref_image_hidden_states, padded_ref_img_mask, ref_img_rotary_emb, temb)
+ref_image_hidden_states = layer(ref_image_hidden_states, padded_ref_img_mask, ref_img_rotary_emb, temb, transformer_options=transformer_options)

 hidden_states = torch.cat([ref_image_hidden_states, hidden_states], dim=1)

 return hidden_states

-def forward(self, x, timesteps, context, num_tokens, ref_latents=None, attention_mask=None, **kwargs):
+def forward(self, x, timesteps, context, num_tokens, ref_latents=None, attention_mask=None, transformer_options={}, **kwargs):
 B, C, H, W = x.shape
 hidden_states = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size))
 _, _, H_padded, W_padded = hidden_states.shape
@@ -444,7 +444,7 @@ class OmniGen2Transformer2DModel(nn.Module):
 )

 for layer in self.context_refiner:
-text_hidden_states = layer(text_hidden_states, text_attention_mask, context_rotary_emb)
+text_hidden_states = layer(text_hidden_states, text_attention_mask, context_rotary_emb, transformer_options=transformer_options)

 img_len = hidden_states.shape[1]
 combined_img_hidden_states = self.img_patch_embed_and_refine(
@@ -453,13 +453,14 @@ class OmniGen2Transformer2DModel(nn.Module):
 noise_rotary_emb, ref_img_rotary_emb,
 l_effective_ref_img_len, l_effective_img_len,
 temb,
+transformer_options=transformer_options,
 )

 hidden_states = torch.cat([text_hidden_states, combined_img_hidden_states], dim=1)
 attention_mask = None

 for layer in self.layers:
-hidden_states = layer(hidden_states, attention_mask, rotary_emb, temb)
+hidden_states = layer(hidden_states, attention_mask, rotary_emb, temb, transformer_options=transformer_options)

 hidden_states = self.norm_out(hidden_states, temb)

@@ -132,6 +132,7 @@ class Attention(nn.Module):
 encoder_hidden_states_mask: torch.FloatTensor = None,
 attention_mask: Optional[torch.FloatTensor] = None,
 image_rotary_emb: Optional[torch.Tensor] = None,
+transformer_options={},
 ) -> Tuple[torch.Tensor, torch.Tensor]:
 seq_txt = encoder_hidden_states.shape[1]

@@ -159,7 +160,7 @@ class Attention(nn.Module):
 joint_key = joint_key.flatten(start_dim=2)
 joint_value = joint_value.flatten(start_dim=2)

-joint_hidden_states = optimized_attention_masked(joint_query, joint_key, joint_value, self.heads, attention_mask)
+joint_hidden_states = optimized_attention_masked(joint_query, joint_key, joint_value, self.heads, attention_mask, transformer_options=transformer_options)

 txt_attn_output = joint_hidden_states[:, :seq_txt, :]
 img_attn_output = joint_hidden_states[:, seq_txt:, :]
@@ -226,6 +227,7 @@ class QwenImageTransformerBlock(nn.Module):
 encoder_hidden_states_mask: torch.Tensor,
 temb: torch.Tensor,
 image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+transformer_options={},
 ) -> Tuple[torch.Tensor, torch.Tensor]:
 img_mod_params = self.img_mod(temb)
 txt_mod_params = self.txt_mod(temb)
@@ -242,6 +244,7 @@ class QwenImageTransformerBlock(nn.Module):
 encoder_hidden_states=txt_modulated,
 encoder_hidden_states_mask=encoder_hidden_states_mask,
 image_rotary_emb=image_rotary_emb,
+transformer_options=transformer_options,
 )

 hidden_states = hidden_states + img_gate1 * img_attn_output
@@ -434,9 +437,9 @@ class QwenImageTransformer2DModel(nn.Module):
 if ("double_block", i) in blocks_replace:
 def block_wrap(args):
 out = {}
-out["txt"], out["img"] = block(hidden_states=args["img"], encoder_hidden_states=args["txt"], encoder_hidden_states_mask=encoder_hidden_states_mask, temb=args["vec"], image_rotary_emb=args["pe"])
+out["txt"], out["img"] = block(hidden_states=args["img"], encoder_hidden_states=args["txt"], encoder_hidden_states_mask=encoder_hidden_states_mask, temb=args["vec"], image_rotary_emb=args["pe"], transformer_options=args["transformer_options"])
 return out
-out = blocks_replace[("double_block", i)]({"img": hidden_states, "txt": encoder_hidden_states, "vec": temb, "pe": image_rotary_emb}, {"original_block": block_wrap})
+out = blocks_replace[("double_block", i)]({"img": hidden_states, "txt": encoder_hidden_states, "vec": temb, "pe": image_rotary_emb, "transformer_options": transformer_options}, {"original_block": block_wrap})
 hidden_states = out["img"]
 encoder_hidden_states = out["txt"]
 else:
@@ -446,11 +449,12 @@ class QwenImageTransformer2DModel(nn.Module):
 encoder_hidden_states_mask=encoder_hidden_states_mask,
 temb=temb,
 image_rotary_emb=image_rotary_emb,
+transformer_options=transformer_options,
 )

 if "double_block" in patches:
 for p in patches["double_block"]:
-out = p({"img": hidden_states, "txt": encoder_hidden_states, "x": x, "block_index": i})
+out = p({"img": hidden_states, "txt": encoder_hidden_states, "x": x, "block_index": i, "transformer_options": transformer_options})
 hidden_states = out["img"]
 encoder_hidden_states = out["txt"]

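For custom node authors, the practical effect of the "double_block" hunks above is that a replace patch now receives the transformer_options dict inside its args. A hedged sketch of such a patch (the function name and the registration comment are illustrative, not from this commit; the lookup path follows the patches_replace/"dit" code visible in the model hunks):

def my_double_block_patch(args, extra_args):
    opts = args.get("transformer_options", {})   # newly available to the patch
    out = extra_args["original_block"](args)     # run the unmodified block
    # inspect or modify out["img"] / out["txt"] here, optionally using opts
    return out

# Hypothetical registration point, following the lookup in the model code:
# transformer_options["patches_replace"]["dit"][("double_block", 0)] = my_double_block_patch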
@@ -8,7 +8,7 @@ from einops import rearrange

 from comfy.ldm.modules.attention import optimized_attention
 from comfy.ldm.flux.layers import EmbedND
-from comfy.ldm.flux.math import apply_rope
+from comfy.ldm.flux.math import apply_rope1
 import comfy.ldm.common_dit
 import comfy.model_management
 import comfy.patcher_extension
@@ -34,7 +34,9 @@ class WanSelfAttention(nn.Module):
 num_heads,
 window_size=(-1, -1),
 qk_norm=True,
-eps=1e-6, operation_settings={}):
+eps=1e-6,
+kv_dim=None,
+operation_settings={}):
 assert dim % num_heads == 0
 super().__init__()
 self.dim = dim
@@ -43,16 +45,18 @@ class WanSelfAttention(nn.Module):
 self.window_size = window_size
 self.qk_norm = qk_norm
 self.eps = eps
+if kv_dim is None:
+kv_dim = dim

 # layers
 self.q = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
-self.k = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
+self.k = operation_settings.get("operations").Linear(kv_dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
-self.v = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
+self.v = operation_settings.get("operations").Linear(kv_dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
 self.o = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
 self.norm_q = operation_settings.get("operations").RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity()
 self.norm_k = operation_settings.get("operations").RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity()

-def forward(self, x, freqs):
+def forward(self, x, freqs, transformer_options={}):
 r"""
 Args:
 x(Tensor): Shape [B, L, num_heads, C / num_heads]
@@ -60,21 +64,26 @@ class WanSelfAttention(nn.Module):
 """
 b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim

-# query, key, value function
-def qkv_fn(x):
+def qkv_fn_q(x):
 q = self.norm_q(self.q(x)).view(b, s, n, d)
-k = self.norm_k(self.k(x)).view(b, s, n, d)
-v = self.v(x).view(b, s, n * d)
-return q, k, v
+return apply_rope1(q, freqs)

-q, k, v = qkv_fn(x)
-q, k = apply_rope(q, k, freqs)
+def qkv_fn_k(x):
+k = self.norm_k(self.k(x)).view(b, s, n, d)
+return apply_rope1(k, freqs)

+#These two are VRAM hogs, so we want to do all of q computation and
+#have pytorch garbage collect the intermediates on the sub function
+#return before we touch k
+q = qkv_fn_q(x)
+k = qkv_fn_k(x)

 x = optimized_attention(
 q.view(b, s, n * d),
 k.view(b, s, n * d),
-v,
+self.v(x).view(b, s, n * d),
 heads=self.num_heads,
+transformer_options=transformer_options,
 )

 x = self.o(x)
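The comment introduced in this hunk explains the point of splitting qkv_fn into qkv_fn_q and qkv_fn_k: the projection and norm intermediates for q become collectable as soon as the helper returns, before the k projection allocates its own. A minimal sketch of that scoping pattern (illustrative only; the names and the rope callable are placeholders, not ComfyUI APIs):

import torch

def project_q(x, w_q, rope):
    q = x @ w_q        # large intermediate, only lives inside this call
    return rope(q)     # only the rotated result escapes

def project_k(x, w_k, rope):
    k = x @ w_k
    return rope(k)

def attention_inputs(x, w_q, w_k, rope):
    q = project_q(x, w_q, rope)   # q's temporaries are already collectable here...
    k = project_k(x, w_k, rope)   # ...before k's temporaries get allocated
    return q, k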
@@ -83,7 +92,7 @@ class WanSelfAttention(nn.Module):

 class WanT2VCrossAttention(WanSelfAttention):

-def forward(self, x, context, **kwargs):
+def forward(self, x, context, transformer_options={}, **kwargs):
 r"""
 Args:
 x(Tensor): Shape [B, L1, C]
@@ -95,7 +104,7 @@ class WanT2VCrossAttention(WanSelfAttention):
 v = self.v(context)

 # compute attention
-x = optimized_attention(q, k, v, heads=self.num_heads)
+x = optimized_attention(q, k, v, heads=self.num_heads, transformer_options=transformer_options)

 x = self.o(x)
 return x
@@ -116,7 +125,7 @@ class WanI2VCrossAttention(WanSelfAttention):
 # self.alpha = nn.Parameter(torch.zeros((1, )))
 self.norm_k_img = operation_settings.get("operations").RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity()

-def forward(self, x, context, context_img_len):
+def forward(self, x, context, context_img_len, transformer_options={}):
 r"""
 Args:
 x(Tensor): Shape [B, L1, C]
@@ -131,9 +140,9 @@ class WanI2VCrossAttention(WanSelfAttention):
 v = self.v(context)
 k_img = self.norm_k_img(self.k_img(context_img))
 v_img = self.v_img(context_img)
-img_x = optimized_attention(q, k_img, v_img, heads=self.num_heads)
+img_x = optimized_attention(q, k_img, v_img, heads=self.num_heads, transformer_options=transformer_options)
 # compute attention
-x = optimized_attention(q, k, v, heads=self.num_heads)
+x = optimized_attention(q, k, v, heads=self.num_heads, transformer_options=transformer_options)

 # output
 x = x + img_x
@@ -206,6 +215,7 @@ class WanAttentionBlock(nn.Module):
 freqs,
 context,
 context_img_len=257,
+transformer_options={},
 ):
 r"""
 Args:
@@ -224,12 +234,12 @@ class WanAttentionBlock(nn.Module):
 # self-attention
 y = self.self_attn(
 torch.addcmul(repeat_e(e[0], x), self.norm1(x), 1 + repeat_e(e[1], x)),
-freqs)
+freqs, transformer_options=transformer_options)

 x = torch.addcmul(x, y, repeat_e(e[2], x))

 # cross-attention & ffn
-x = x + self.cross_attn(self.norm3(x), context, context_img_len=context_img_len)
+x = x + self.cross_attn(self.norm3(x), context, context_img_len=context_img_len, transformer_options=transformer_options)
 y = self.ffn(torch.addcmul(repeat_e(e[3], x), self.norm2(x), 1 + repeat_e(e[4], x)))
 x = torch.addcmul(x, y, repeat_e(e[5], x))
 return x
@@ -396,6 +406,7 @@ class WanModel(torch.nn.Module):
 eps=1e-6,
 flf_pos_embed_token_number=None,
 in_dim_ref_conv=None,
+wan_attn_block_class=WanAttentionBlock,
 image_model=None,
 device=None,
 dtype=None,
@@ -473,7 +484,7 @@ class WanModel(torch.nn.Module):
 # blocks
 cross_attn_type = 't2v_cross_attn' if model_type == 't2v' else 'i2v_cross_attn'
 self.blocks = nn.ModuleList([
-WanAttentionBlock(cross_attn_type, dim, ffn_dim, num_heads,
+wan_attn_block_class(cross_attn_type, dim, ffn_dim, num_heads,
 window_size, qk_norm, cross_attn_norm, eps, operation_settings=operation_settings)
 for _ in range(num_layers)
 ])
@@ -559,12 +570,12 @@ class WanModel(torch.nn.Module):
 if ("double_block", i) in blocks_replace:
 def block_wrap(args):
 out = {}
-out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len)
+out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len, transformer_options=args["transformer_options"])
 return out
-out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs}, {"original_block": block_wrap})
+out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs, "transformer_options": transformer_options}, {"original_block": block_wrap})
 x = out["img"]
 else:
-x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len)
+x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len, transformer_options=transformer_options)

 # head
 x = self.head(x, e)
@@ -742,17 +753,17 @@ class VaceWanModel(WanModel):
 if ("double_block", i) in blocks_replace:
 def block_wrap(args):
 out = {}
-out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len)
+out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len, transformer_options=args["transformer_options"])
 return out
-out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs}, {"original_block": block_wrap})
+out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs, "transformer_options": transformer_options}, {"original_block": block_wrap})
 x = out["img"]
 else:
-x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len)
+x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len, transformer_options=transformer_options)

 ii = self.vace_layers_mapping.get(i, None)
 if ii is not None:
 for iii in range(len(c)):
-c_skip, c[iii] = self.vace_blocks[ii](c[iii], x=x_orig, e=e0, freqs=freqs, context=context, context_img_len=context_img_len)
+c_skip, c[iii] = self.vace_blocks[ii](c[iii], x=x_orig, e=e0, freqs=freqs, context=context, context_img_len=context_img_len, transformer_options=transformer_options)
 x += c_skip * vace_strength[iii]
 del c_skip
 # head
@@ -841,12 +852,12 @@ class CameraWanModel(WanModel):
 if ("double_block", i) in blocks_replace:
 def block_wrap(args):
 out = {}
-out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len)
+out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len, transformer_options=args["transformer_options"])
 return out
-out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs}, {"original_block": block_wrap})
+out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs, "transformer_options": transformer_options}, {"original_block": block_wrap})
 x = out["img"]
 else:
-x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len)
+x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len, transformer_options=transformer_options)

 # head
 x = self.head(x, e)
@@ -1319,3 +1330,247 @@ class WanModel_S2V(WanModel):
 # unpatchify
 x = self.unpatchify(x, grid_sizes)
 return x
+
+
+class WanT2VCrossAttentionGather(WanSelfAttention):
+
+def forward(self, x, context, transformer_options={}, **kwargs):
+r"""
+Args:
+x(Tensor): Shape [B, L1, C] - video tokens
+context(Tensor): Shape [B, L2, C] - audio tokens with shape [B, frames*16, 1536]
+"""
+b, n, d = x.size(0), self.num_heads, self.head_dim
+
+q = self.norm_q(self.q(x))
+k = self.norm_k(self.k(context))
+v = self.v(context)
+
+# Handle audio temporal structure (16 tokens per frame)
+k = k.reshape(-1, 16, n, d).transpose(1, 2)
+v = v.reshape(-1, 16, n, d).transpose(1, 2)
+
+# Handle video spatial structure
+q = q.reshape(k.shape[0], -1, n, d).transpose(1, 2)
+
+x = optimized_attention(q, k, v, heads=self.num_heads, skip_reshape=True, skip_output_reshape=True, transformer_options=transformer_options)
+
+x = x.transpose(1, 2).view(b, -1, n, d).flatten(2)
+x = self.o(x)
+return x
+
+
+class AudioCrossAttentionWrapper(nn.Module):
+def __init__(self, dim, kv_dim, num_heads, qk_norm=True, eps=1e-6, operation_settings={}):
+super().__init__()
+
+self.audio_cross_attn = WanT2VCrossAttentionGather(dim, num_heads, qk_norm=qk_norm, kv_dim=kv_dim, eps=eps, operation_settings=operation_settings)
+self.norm1_audio = operation_settings.get("operations").LayerNorm(dim, eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
+
+def forward(self, x, audio, transformer_options={}):
+x = x + self.audio_cross_attn(self.norm1_audio(x), audio, transformer_options=transformer_options)
+return x
+
+
+class WanAttentionBlockAudio(WanAttentionBlock):
+
+def __init__(self,
+cross_attn_type,
+dim,
+ffn_dim,
+num_heads,
+window_size=(-1, -1),
+qk_norm=True,
+cross_attn_norm=False,
+eps=1e-6, operation_settings={}):
+super().__init__(cross_attn_type, dim, ffn_dim, num_heads, window_size, qk_norm, cross_attn_norm, eps, operation_settings)
+self.audio_cross_attn_wrapper = AudioCrossAttentionWrapper(dim, 1536, num_heads, qk_norm, eps, operation_settings=operation_settings)
+
+def forward(
+self,
+x,
+e,
+freqs,
+context,
+context_img_len=257,
+audio=None,
+transformer_options={},
+):
+r"""
+Args:
+x(Tensor): Shape [B, L, C]
+e(Tensor): Shape [B, 6, C]
+freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
+"""
+# assert e.dtype == torch.float32
+
+if e.ndim < 4:
+e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device) + e).chunk(6, dim=1)
+else:
+e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device).unsqueeze(0) + e).unbind(2)
+# assert e[0].dtype == torch.float32
+
+# self-attention
+y = self.self_attn(
+torch.addcmul(repeat_e(e[0], x), self.norm1(x), 1 + repeat_e(e[1], x)),
+freqs, transformer_options=transformer_options)
+
+x = torch.addcmul(x, y, repeat_e(e[2], x))
+
+# cross-attention & ffn
+x = x + self.cross_attn(self.norm3(x), context, context_img_len=context_img_len, transformer_options=transformer_options)
+if audio is not None:
+x = self.audio_cross_attn_wrapper(x, audio, transformer_options=transformer_options)
+y = self.ffn(torch.addcmul(repeat_e(e[3], x), self.norm2(x), 1 + repeat_e(e[4], x)))
+x = torch.addcmul(x, y, repeat_e(e[5], x))
+return x
+
+class DummyAdapterLayer(nn.Module):
+def __init__(self, layer):
+super().__init__()
+self.layer = layer
+
+def forward(self, *args, **kwargs):
+return self.layer(*args, **kwargs)
+
+
+class AudioProjModel(nn.Module):
+def __init__(
+self,
+seq_len=5,
+blocks=13, # add a new parameter blocks
+channels=768, # add a new parameter channels
+intermediate_dim=512,
+output_dim=1536,
+context_tokens=16,
+device=None,
+dtype=None,
+operations=None,
+):
+super().__init__()
+
+self.seq_len = seq_len
+self.blocks = blocks
+self.channels = channels
+self.input_dim = seq_len * blocks * channels # update input_dim to be the product of blocks and channels.
+self.intermediate_dim = intermediate_dim
+self.context_tokens = context_tokens
+self.output_dim = output_dim
+
+# define multiple linear layers
+self.audio_proj_glob_1 = DummyAdapterLayer(operations.Linear(self.input_dim, intermediate_dim, dtype=dtype, device=device))
+self.audio_proj_glob_2 = DummyAdapterLayer(operations.Linear(intermediate_dim, intermediate_dim, dtype=dtype, device=device))
+self.audio_proj_glob_3 = DummyAdapterLayer(operations.Linear(intermediate_dim, context_tokens * output_dim, dtype=dtype, device=device))
+
+self.audio_proj_glob_norm = DummyAdapterLayer(operations.LayerNorm(output_dim, dtype=dtype, device=device))
+
+def forward(self, audio_embeds):
+video_length = audio_embeds.shape[1]
+audio_embeds = rearrange(audio_embeds, "bz f w b c -> (bz f) w b c")
+batch_size, window_size, blocks, channels = audio_embeds.shape
+audio_embeds = audio_embeds.view(batch_size, window_size * blocks * channels)
+
+audio_embeds = torch.relu(self.audio_proj_glob_1(audio_embeds))
+audio_embeds = torch.relu(self.audio_proj_glob_2(audio_embeds))
+
+context_tokens = self.audio_proj_glob_3(audio_embeds).reshape(batch_size, self.context_tokens, self.output_dim)
+
+context_tokens = self.audio_proj_glob_norm(context_tokens)
+context_tokens = rearrange(context_tokens, "(bz f) m c -> bz f m c", f=video_length)
+
+return context_tokens
+
+
+class HumoWanModel(WanModel):
+r"""
+Wan diffusion backbone supporting both text-to-video and image-to-video.
+"""
+
+def __init__(self,
+model_type='humo',
+patch_size=(1, 2, 2),
+text_len=512,
+in_dim=16,
+dim=2048,
+ffn_dim=8192,
+freq_dim=256,
+text_dim=4096,
+out_dim=16,
+num_heads=16,
+num_layers=32,
+window_size=(-1, -1),
+qk_norm=True,
+cross_attn_norm=True,
+eps=1e-6,
+flf_pos_embed_token_number=None,
+image_model=None,
+audio_token_num=16,
+device=None,
+dtype=None,
+operations=None,
+):
+
+super().__init__(model_type='t2v', patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, flf_pos_embed_token_number=flf_pos_embed_token_number, wan_attn_block_class=WanAttentionBlockAudio, image_model=image_model, device=device, dtype=dtype, operations=operations)
+
+self.audio_proj = AudioProjModel(seq_len=8, blocks=5, channels=1280, intermediate_dim=512, output_dim=1536, context_tokens=audio_token_num, dtype=dtype, device=device, operations=operations)
+
+def forward_orig(
+self,
+x,
+t,
+context,
+freqs=None,
+audio_embed=None,
+reference_latent=None,
+transformer_options={},
+**kwargs,
+):
+bs, _, time, height, width = x.shape
+
+# embeddings
+x = self.patch_embedding(x.float()).to(x.dtype)
+grid_sizes = x.shape[2:]
+x = x.flatten(2).transpose(1, 2)
+
+# time embeddings
+e = self.time_embedding(
+sinusoidal_embedding_1d(self.freq_dim, t.flatten()).to(dtype=x[0].dtype))
+e = e.reshape(t.shape[0], -1, e.shape[-1])
+e0 = self.time_projection(e).unflatten(2, (6, self.dim))
+
+if reference_latent is not None:
+ref = self.patch_embedding(reference_latent.float()).to(x.dtype)
+ref = ref.flatten(2).transpose(1, 2)
+freqs_ref = self.rope_encode(reference_latent.shape[-3], reference_latent.shape[-2], reference_latent.shape[-1], t_start=time, device=x.device, dtype=x.dtype)
+x = torch.cat([x, ref], dim=1)
+freqs = torch.cat([freqs, freqs_ref], dim=1)
+del ref, freqs_ref
+
+# context
+context = self.text_embedding(context)
+context_img_len = None
+
+if audio_embed is not None:
+audio = self.audio_proj(audio_embed).permute(0, 3, 1, 2).flatten(2).transpose(1, 2)
+else:
+audio = None
+
+patches_replace = transformer_options.get("patches_replace", {})
+blocks_replace = patches_replace.get("dit", {})
+for i, block in enumerate(self.blocks):
+if ("double_block", i) in blocks_replace:
+def block_wrap(args):
+out = {}
+out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len, audio=audio, transformer_options=args["transformer_options"])
+return out
+out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs, "transformer_options": transformer_options}, {"original_block": block_wrap})
+x = out["img"]
+else:
+x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len, audio=audio, transformer_options=transformer_options)
+
+# head
+x = self.head(x, e)
+
+# unpatchify
+x = self.unpatchify(x, grid_sizes)
+return x
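As a reading aid for the AudioProjModel added above: a small, hedged shape walk-through using plain nn.Linear stand-ins for the DummyAdapterLayer-wrapped ops, with the defaults HumoWanModel passes (seq_len=8, blocks=5, channels=1280, context_tokens=16, output_dim=1536). The batch and frame counts are made up for illustration.

import torch
import torch.nn as nn

bz, f, w, blocks, channels = 2, 4, 8, 5, 1280     # hypothetical batch/frames; w matches seq_len
proj1 = nn.Linear(w * blocks * channels, 512)     # audio_proj_glob_1
proj2 = nn.Linear(512, 512)                       # audio_proj_glob_2
proj3 = nn.Linear(512, 16 * 1536)                 # audio_proj_glob_3
norm = nn.LayerNorm(1536)                         # audio_proj_glob_norm

audio = torch.randn(bz, f, w, blocks, channels)
x = audio.reshape(bz * f, w * blocks * channels)  # "bz f w b c -> (bz f) (w b c)"
x = torch.relu(proj2(torch.relu(proj1(x))))
tokens = norm(proj3(x).reshape(bz * f, 16, 1536)) # 16 context tokens per frame
tokens = tokens.reshape(bz, f, 16, 1536)          # "(bz f) m c -> bz f m c"
print(tokens.shape)                               # torch.Size([2, 4, 16, 1536])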
@@ -297,6 +297,12 @@ def model_lora_keys_unet(model, key_map={}):
 key_lora = k[len("diffusion_model."):-len(".weight")]
 key_map["{}".format(key_lora)] = k

+if isinstance(model, comfy.model_base.Omnigen2):
+for k in sdk:
+if k.startswith("diffusion_model.") and k.endswith(".weight"):
+key_lora = k[len("diffusion_model."):-len(".weight")]
+key_map["{}".format(key_lora)] = k
+
 if isinstance(model, comfy.model_base.QwenImage):
 for k in sdk:
 if k.startswith("diffusion_model.") and k.endswith(".weight"): #QwenImage lora format
@@ -42,6 +42,7 @@ import comfy.ldm.wan.model
 import comfy.ldm.hunyuan3d.model
 import comfy.ldm.hidream.model
 import comfy.ldm.chroma.model
+import comfy.ldm.chroma_radiance.model
 import comfy.ldm.ace.model
 import comfy.ldm.omnigen.omnigen2
 import comfy.ldm.qwen_image.model
@@ -1212,6 +1213,46 @@ class WAN21_Camera(WAN21):
 out['camera_conditions'] = comfy.conds.CONDRegular(camera_conditions)
 return out

+class WAN21_HuMo(WAN21):
+def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
+super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.HumoWanModel)
+self.image_to_video = image_to_video
+
+def extra_conds(self, **kwargs):
+out = super().extra_conds(**kwargs)
+noise = kwargs.get("noise", None)
+
+audio_embed = kwargs.get("audio_embed", None)
+if audio_embed is not None:
+out['audio_embed'] = comfy.conds.CONDRegular(audio_embed)
+
+if "c_concat" not in out: # 1.7B model
+reference_latents = kwargs.get("reference_latents", None)
+if reference_latents is not None:
+out['reference_latent'] = comfy.conds.CONDRegular(self.process_latent_in(reference_latents[-1]))
+else:
+noise_shape = list(noise.shape)
+noise_shape[1] += 4
+concat_latent = torch.zeros(noise_shape, device=noise.device, dtype=noise.dtype)
+zero_vae_values_first = torch.tensor([0.8660, -0.4326, -0.0017, -0.4884, -0.5283, 0.9207, -0.9896, 0.4433, -0.5543, -0.0113, 0.5753, -0.6000, -0.8346, -0.3497, -0.1926, -0.6938]).view(1, 16, 1, 1, 1)
+zero_vae_values_second = torch.tensor([1.0869, -1.2370, 0.0206, -0.4357, -0.6411, 2.0307, -1.5972, 1.2659, -0.8595, -0.4654, 0.9638, -1.6330, -1.4310, -0.1098, -0.3856, -1.4583]).view(1, 16, 1, 1, 1)
+zero_vae_values = torch.tensor([0.8642, -1.8583, 0.1577, 0.1350, -0.3641, 2.5863, -1.9670, 1.6065, -1.0475, -0.8678, 1.1734, -1.8138, -1.5933, -0.7721, -0.3289, -1.3745]).view(1, 16, 1, 1, 1)
+concat_latent[:, 4:] = zero_vae_values
+concat_latent[:, 4:, :1] = zero_vae_values_first
+concat_latent[:, 4:, 1:2] = zero_vae_values_second
+out['c_concat'] = comfy.conds.CONDNoiseShape(concat_latent)
+reference_latents = kwargs.get("reference_latents", None)
+if reference_latents is not None:
+ref_latent = self.process_latent_in(reference_latents[-1])
+ref_latent_shape = list(ref_latent.shape)
+ref_latent_shape[1] += 4 + ref_latent_shape[1]
+ref_latent_full = torch.zeros(ref_latent_shape, device=ref_latent.device, dtype=ref_latent.dtype)
+ref_latent_full[:, 20:] = ref_latent
+ref_latent_full[:, 16:20] = 1.0
+out['reference_latent'] = comfy.conds.CONDRegular(ref_latent_full)
+
+return out
+
 class WAN22_S2V(WAN21):
 def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
 super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel_S2V)
@@ -1320,8 +1361,8 @@ class HiDream(BaseModel):
 return out

 class Chroma(Flux):
-def __init__(self, model_config, model_type=ModelType.FLUX, device=None):
+def __init__(self, model_config, model_type=ModelType.FLUX, device=None, unet_model=comfy.ldm.chroma.model.Chroma):
-super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.chroma.model.Chroma)
+super().__init__(model_config, model_type, device=device, unet_model=unet_model)

 def extra_conds(self, **kwargs):
 out = super().extra_conds(**kwargs)
@@ -1331,6 +1372,10 @@ class Chroma(Flux):
 out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance]))
 return out

+class ChromaRadiance(Chroma):
+def __init__(self, model_config, model_type=ModelType.FLUX, device=None):
+super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.chroma_radiance.model.ChromaRadiance)
+
 class ACEStep(BaseModel):
 def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
 super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.ace.model.ACEStepTransformer2DModel)
@@ -1432,3 +1477,31 @@ class HunyuanImage21(BaseModel):
 out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance]))

 return out
+
+class HunyuanImage21Refiner(HunyuanImage21):
+def concat_cond(self, **kwargs):
+noise = kwargs.get("noise", None)
+image = kwargs.get("concat_latent_image", None)
+noise_augmentation = kwargs.get("noise_augmentation", 0.0)
+device = kwargs["device"]
+
+if image is None:
+shape_image = list(noise.shape)
+image = torch.zeros(shape_image, dtype=noise.dtype, layout=noise.layout, device=noise.device)
+else:
+image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
+image = self.process_latent_in(image)
+image = utils.resize_to_batch_size(image, noise.shape[0])
+if noise_augmentation > 0:
+generator = torch.Generator(device="cpu")
+generator.manual_seed(kwargs.get("seed", 0) - 10)
+noise = torch.randn(image.shape, generator=generator, dtype=image.dtype, device="cpu").to(image.device)
+image = noise_augmentation * noise + min(1.0 - noise_augmentation, 0.75) * image
+else:
+image = 0.75 * image
+return image
+
+def extra_conds(self, **kwargs):
+out = super().extra_conds(**kwargs)
+out['disable_time_r'] = comfy.conds.CONDConstant(True)
+return out
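For the HunyuanImage21Refiner.concat_cond added above, the conditioning latent is blended with seeded noise as noise_augmentation * noise + min(1 - noise_augmentation, 0.75) * latent, falling back to a plain 0.75 * latent when noise_augmentation is 0. A tiny illustrative check of the resulting weights (not part of the patch):

for a in (0.0, 0.1, 0.5):
    noise_w = a if a > 0 else 0.0
    latent_w = min(1.0 - a, 0.75) if a > 0 else 0.75
    print(a, noise_w, latent_w)   # 0.0 -> (0.0, 0.75), 0.1 -> (0.1, 0.75), 0.5 -> (0.5, 0.5)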
@@ -174,7 +174,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
 dit_config["guidance_embed"] = len(guidance_keys) > 0
 return dit_config

-if '{}double_blocks.0.img_attn.norm.key_norm.scale'.format(key_prefix) in state_dict_keys and '{}img_in.weight'.format(key_prefix) in state_dict_keys: #Flux
+if '{}double_blocks.0.img_attn.norm.key_norm.scale'.format(key_prefix) in state_dict_keys and ('{}img_in.weight'.format(key_prefix) in state_dict_keys or f"{key_prefix}distilled_guidance_layer.norms.0.scale" in state_dict_keys): #Flux, Chroma or Chroma Radiance (has no img_in.weight)
 dit_config = {}
 dit_config["image_model"] = "flux"
 dit_config["in_channels"] = 16
@@ -204,6 +204,18 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
 dit_config["out_dim"] = 3072
 dit_config["hidden_dim"] = 5120
 dit_config["n_layers"] = 5
+if f"{key_prefix}nerf_blocks.0.norm.scale" in state_dict_keys: #Chroma Radiance
+dit_config["image_model"] = "chroma_radiance"
+dit_config["in_channels"] = 3
+dit_config["out_channels"] = 3
+dit_config["patch_size"] = 16
+dit_config["nerf_hidden_size"] = 64
+dit_config["nerf_mlp_ratio"] = 4
+dit_config["nerf_depth"] = 4
+dit_config["nerf_max_freqs"] = 8
+dit_config["nerf_tile_size"] = 32
+dit_config["nerf_final_head_type"] = "conv" if f"{key_prefix}nerf_final_layer_conv.norm.scale" in state_dict_keys else "linear"
+dit_config["nerf_embedder_dtype"] = torch.float32
 else:
 dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys
 return dit_config
@@ -390,6 +402,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
 dit_config["model_type"] = "camera_2.2"
 elif '{}casual_audio_encoder.encoder.final_linear.weight'.format(key_prefix) in state_dict_keys:
 dit_config["model_type"] = "s2v"
+elif '{}audio_proj.audio_proj_glob_1.layer.bias'.format(key_prefix) in state_dict_keys:
+dit_config["model_type"] = "humo"
 else:
 if '{}img_emb.proj.0.bias'.format(key_prefix) in state_dict_keys:
 dit_config["model_type"] = "i2v"
comfy/pixel_space_convert.py (new file, 16 lines)
@@ -0,0 +1,16 @@
+import torch
+
+
+# "Fake" VAE that converts from IMAGE B, H, W, C and values on the scale of 0..1
+# to LATENT B, C, H, W and values on the scale of -1..1.
+class PixelspaceConversionVAE(torch.nn.Module):
+def __init__(self):
+super().__init__()
+self.pixel_space_vae = torch.nn.Parameter(torch.tensor(1.0))
+
+def encode(self, pixels: torch.Tensor, *_args, **_kwargs) -> torch.Tensor:
+return pixels
+
+def decode(self, samples: torch.Tensor, *_args, **_kwargs) -> torch.Tensor:
+return samples
+
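A hedged sketch of what the pixel-space "fake" VAE amounts to end to end: the module itself is an identity, and the scale/layout conversion described in its comment is handled by the surrounding VAE wrapper (see the process_output lambda visible in the comfy/sd.py hunk below). The helper names here are illustrative, not ComfyUI APIs:

import torch

def pixel_encode(images_bhwc: torch.Tensor) -> torch.Tensor:
    # IMAGE [B, H, W, C] in 0..1 -> LATENT [B, C, H, W] in -1..1
    return (images_bhwc * 2.0 - 1.0).movedim(-1, 1)

def pixel_decode(latent_bchw: torch.Tensor) -> torch.Tensor:
    # LATENT [B, C, H, W] in -1..1 -> IMAGE [B, H, W, C] in 0..1
    return torch.clamp((latent_bchw + 1.0) / 2.0, 0.0, 1.0).movedim(1, -1)

img = torch.rand(1, 64, 64, 3)
assert torch.allclose(pixel_decode(pixel_encode(img)), img, atol=1e-6)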
comfy/sd.py (35 lines changed)
@@ -18,6 +18,7 @@ import comfy.ldm.wan.vae2_2
 import comfy.ldm.hunyuan3d.vae
 import comfy.ldm.ace.vae.music_dcae_pipeline
 import comfy.ldm.hunyuan_video.vae
+import comfy.pixel_space_convert
 import yaml
 import math
 import os
@@ -285,6 +286,7 @@ class VAE:
 self.process_output = lambda image: torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)
 self.working_dtypes = [torch.bfloat16, torch.float32]
 self.disable_offload = False
+self.not_video = False

 self.downscale_index_formula = None
 self.upscale_index_formula = None
@@ -409,6 +411,23 @@ class VAE:
 self.downscale_ratio = (lambda a: max(0, math.floor((a + 7) / 8)), 32, 32)
 self.downscale_index_formula = (8, 32, 32)
 self.working_dtypes = [torch.bfloat16, torch.float32]
+elif "decoder.conv_in.conv.weight" in sd and sd['decoder.conv_in.conv.weight'].shape[1] == 32:
+ddconfig = {"block_out_channels": [128, 256, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 16, "ffactor_temporal": 4, "downsample_match_channel": True, "upsample_match_channel": True}
+ddconfig['z_channels'] = sd["decoder.conv_in.conv.weight"].shape[1]
+self.latent_channels = 64
+self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 16, 16)
+self.upscale_index_formula = (4, 16, 16)
+self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 16, 16)
+self.downscale_index_formula = (4, 16, 16)
+self.latent_dim = 3
+self.not_video = True
+self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
+self.first_stage_model = AutoencodingEngine(regularizer_config={'target': "comfy.ldm.models.autoencoder.EmptyRegularizer"},
+encoder_config={'target': "comfy.ldm.hunyuan_video.vae_refiner.Encoder", 'params': ddconfig},
+decoder_config={'target': "comfy.ldm.hunyuan_video.vae_refiner.Decoder", 'params': ddconfig})
+
+self.memory_used_encode = lambda shape, dtype: (1400 * shape[-2] * shape[-1]) * model_management.dtype_size(dtype)
+self.memory_used_decode = lambda shape, dtype: (1400 * shape[-3] * shape[-2] * shape[-1] * 16 * 16) * model_management.dtype_size(dtype)
 elif "decoder.conv_in.conv.weight" in sd:
 ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
 ddconfig["conv3d"] = True
@@ -498,6 +517,15 @@ class VAE:
 self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32]
 self.disable_offload = True
 self.extra_1d_channel = 16
+elif "pixel_space_vae" in sd:
+self.first_stage_model = comfy.pixel_space_convert.PixelspaceConversionVAE()
+self.memory_used_encode = lambda shape, dtype: (1 * shape[2] * shape[3]) * model_management.dtype_size(dtype)
+self.memory_used_decode = lambda shape, dtype: (1 * shape[2] * shape[3]) * model_management.dtype_size(dtype)
+self.downscale_ratio = 1
+self.upscale_ratio = 1
+self.latent_channels = 3
+self.latent_dim = 2
+self.output_channels = 3
 else:
 logging.warning("WARNING: No VAE weights detected, VAE not initalized.")
|
||||||
self.first_stage_model = None
|
self.first_stage_model = None
|
||||||
@ -670,7 +698,10 @@ class VAE:
|
|||||||
pixel_samples = self.vae_encode_crop_pixels(pixel_samples)
|
pixel_samples = self.vae_encode_crop_pixels(pixel_samples)
|
||||||
pixel_samples = pixel_samples.movedim(-1, 1)
|
pixel_samples = pixel_samples.movedim(-1, 1)
|
||||||
if self.latent_dim == 3 and pixel_samples.ndim < 5:
|
if self.latent_dim == 3 and pixel_samples.ndim < 5:
|
||||||
|
if not self.not_video:
|
||||||
pixel_samples = pixel_samples.movedim(1, 0).unsqueeze(0)
|
pixel_samples = pixel_samples.movedim(1, 0).unsqueeze(0)
|
||||||
|
else:
|
||||||
|
pixel_samples = pixel_samples.unsqueeze(2)
|
||||||
try:
|
try:
|
||||||
memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype)
|
memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype)
|
||||||
model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
|
model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
|
||||||
@ -704,7 +735,10 @@ class VAE:
|
|||||||
dims = self.latent_dim
|
dims = self.latent_dim
|
||||||
pixel_samples = pixel_samples.movedim(-1, 1)
|
pixel_samples = pixel_samples.movedim(-1, 1)
|
||||||
if dims == 3:
|
if dims == 3:
|
||||||
|
if not self.not_video:
|
||||||
pixel_samples = pixel_samples.movedim(1, 0).unsqueeze(0)
|
pixel_samples = pixel_samples.movedim(1, 0).unsqueeze(0)
|
||||||
|
else:
|
||||||
|
pixel_samples = pixel_samples.unsqueeze(2)
|
||||||
|
|
||||||
memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype) # TODO: calculate mem required for tile
|
memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype) # TODO: calculate mem required for tile
|
||||||
model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
|
model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
|
||||||
@ -761,6 +795,7 @@ class VAE:
|
|||||||
except:
|
except:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
class StyleModel:
|
class StyleModel:
|
||||||
def __init__(self, model, device="cpu"):
|
def __init__(self, model, device="cpu"):
|
||||||
self.model = model
|
self.model = model
|
||||||
|
|||||||
@ -1073,6 +1073,16 @@ class WAN21_Vace(WAN21_T2V):
|
|||||||
out = model_base.WAN21_Vace(self, image_to_video=False, device=device)
|
out = model_base.WAN21_Vace(self, image_to_video=False, device=device)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
class WAN21_HuMo(WAN21_T2V):
|
||||||
|
unet_config = {
|
||||||
|
"image_model": "wan2.1",
|
||||||
|
"model_type": "humo",
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_model(self, state_dict, prefix="", device=None):
|
||||||
|
out = model_base.WAN21_HuMo(self, image_to_video=False, device=device)
|
||||||
|
return out
|
||||||
|
|
||||||
class WAN22_S2V(WAN21_T2V):
|
class WAN22_S2V(WAN21_T2V):
|
||||||
unet_config = {
|
unet_config = {
|
||||||
"image_model": "wan2.1",
|
"image_model": "wan2.1",
|
||||||
@ -1205,6 +1215,19 @@ class Chroma(supported_models_base.BASE):
|
|||||||
t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
|
t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
|
||||||
return supported_models_base.ClipTarget(comfy.text_encoders.pixart_t5.PixArtTokenizer, comfy.text_encoders.pixart_t5.pixart_te(**t5_detect))
|
return supported_models_base.ClipTarget(comfy.text_encoders.pixart_t5.PixArtTokenizer, comfy.text_encoders.pixart_t5.pixart_te(**t5_detect))
|
||||||
|
|
||||||
|
class ChromaRadiance(Chroma):
|
||||||
|
unet_config = {
|
||||||
|
"image_model": "chroma_radiance",
|
||||||
|
}
|
||||||
|
|
||||||
|
latent_format = comfy.latent_formats.ChromaRadiance
|
||||||
|
|
||||||
|
# Pixel-space model, no spatial compression for model input.
|
||||||
|
memory_usage_factor = 0.038
|
||||||
|
|
||||||
|
def get_model(self, state_dict, prefix="", device=None):
|
||||||
|
return model_base.ChromaRadiance(self, device=device)
|
||||||
|
|
||||||
class ACEStep(supported_models_base.BASE):
|
class ACEStep(supported_models_base.BASE):
|
||||||
unet_config = {
|
unet_config = {
|
||||||
"audio_model": "ace",
|
"audio_model": "ace",
|
||||||
@ -1321,6 +1344,23 @@ class HunyuanImage21(HunyuanVideo):
|
|||||||
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref))
|
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref))
|
||||||
return supported_models_base.ClipTarget(comfy.text_encoders.hunyuan_image.HunyuanImageTokenizer, comfy.text_encoders.hunyuan_image.te(**hunyuan_detect))
|
return supported_models_base.ClipTarget(comfy.text_encoders.hunyuan_image.HunyuanImageTokenizer, comfy.text_encoders.hunyuan_image.te(**hunyuan_detect))
|
||||||
|
|
||||||
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ACEStep, Omnigen2, QwenImage]
|
class HunyuanImage21Refiner(HunyuanVideo):
|
||||||
|
unet_config = {
|
||||||
|
"image_model": "hunyuan_video",
|
||||||
|
"patch_size": [1, 1, 1],
|
||||||
|
"vec_in_dim": None,
|
||||||
|
}
|
||||||
|
|
||||||
|
sampling_settings = {
|
||||||
|
"shift": 4.0,
|
||||||
|
}
|
||||||
|
|
||||||
|
latent_format = latent_formats.HunyuanImage21Refiner
|
||||||
|
|
||||||
|
def get_model(self, state_dict, prefix="", device=None):
|
||||||
|
out = model_base.HunyuanImage21Refiner(self, device=device)
|
||||||
|
return out
|
||||||
|
|
||||||
|
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage]
|
||||||
|
|
||||||
models += [SVD_img2vid]
|
models += [SVD_img2vid]
|
||||||
|
|||||||
@ -22,17 +22,14 @@ class HunyuanImageTokenizer(QwenImageTokenizer):
|
|||||||
|
|
||||||
# ByT5 processing for HunyuanImage
|
# ByT5 processing for HunyuanImage
|
||||||
text_prompt_texts = []
|
text_prompt_texts = []
|
||||||
pattern_quote_single = r'\'(.*?)\''
|
|
||||||
pattern_quote_double = r'\"(.*?)\"'
|
pattern_quote_double = r'\"(.*?)\"'
|
||||||
pattern_quote_chinese_single = r'‘(.*?)’'
|
pattern_quote_chinese_single = r'‘(.*?)’'
|
||||||
pattern_quote_chinese_double = r'“(.*?)”'
|
pattern_quote_chinese_double = r'“(.*?)”'
|
||||||
|
|
||||||
matches_quote_single = re.findall(pattern_quote_single, text)
|
|
||||||
matches_quote_double = re.findall(pattern_quote_double, text)
|
matches_quote_double = re.findall(pattern_quote_double, text)
|
||||||
matches_quote_chinese_single = re.findall(pattern_quote_chinese_single, text)
|
matches_quote_chinese_single = re.findall(pattern_quote_chinese_single, text)
|
||||||
matches_quote_chinese_double = re.findall(pattern_quote_chinese_double, text)
|
matches_quote_chinese_double = re.findall(pattern_quote_chinese_double, text)
|
||||||
|
|
||||||
text_prompt_texts.extend(matches_quote_single)
|
|
||||||
text_prompt_texts.extend(matches_quote_double)
|
text_prompt_texts.extend(matches_quote_double)
|
||||||
text_prompt_texts.extend(matches_quote_chinese_single)
|
text_prompt_texts.extend(matches_quote_chinese_single)
|
||||||
text_prompt_texts.extend(matches_quote_chinese_double)
|
text_prompt_texts.extend(matches_quote_chinese_double)
|
||||||
|
|||||||
@ -331,7 +331,7 @@ class String(ComfyTypeIO):
|
|||||||
})
|
})
|
||||||
|
|
||||||
@comfytype(io_type="COMBO")
|
@comfytype(io_type="COMBO")
|
||||||
class Combo(ComfyTypeI):
|
class Combo(ComfyTypeIO):
|
||||||
Type = str
|
Type = str
|
||||||
class Input(WidgetInput):
|
class Input(WidgetInput):
|
||||||
"""Combo input (dropdown)."""
|
"""Combo input (dropdown)."""
|
||||||
@ -360,6 +360,14 @@ class Combo(ComfyTypeI):
|
|||||||
"remote": self.remote.as_dict() if self.remote else None,
|
"remote": self.remote.as_dict() if self.remote else None,
|
||||||
})
|
})
|
||||||
|
|
||||||
|
class Output(Output):
|
||||||
|
def __init__(self, id: str=None, display_name: str=None, options: list[str]=None, tooltip: str=None, is_output_list=False):
|
||||||
|
super().__init__(id, display_name, tooltip, is_output_list)
|
||||||
|
self.options = options if options is not None else []
|
||||||
|
|
||||||
|
@property
|
||||||
|
def io_type(self):
|
||||||
|
return self.options
|
||||||
|
|
||||||
@comfytype(io_type="COMBO")
|
@comfytype(io_type="COMBO")
|
||||||
class MultiCombo(ComfyTypeI):
|
class MultiCombo(ComfyTypeI):
|
||||||
|
|||||||
@ -846,6 +846,8 @@ class KlingStartEndFrameNode(KlingImage2VideoNode):
|
|||||||
"pro mode / 10s duration / kling-v1-5": ("pro", "10", "kling-v1-5"),
|
"pro mode / 10s duration / kling-v1-5": ("pro", "10", "kling-v1-5"),
|
||||||
"pro mode / 5s duration / kling-v1-6": ("pro", "5", "kling-v1-6"),
|
"pro mode / 5s duration / kling-v1-6": ("pro", "5", "kling-v1-6"),
|
||||||
"pro mode / 10s duration / kling-v1-6": ("pro", "10", "kling-v1-6"),
|
"pro mode / 10s duration / kling-v1-6": ("pro", "10", "kling-v1-6"),
|
||||||
|
"pro mode / 5s duration / kling-v2-1": ("pro", "5", "kling-v2-1"),
|
||||||
|
"pro mode / 10s duration / kling-v2-1": ("pro", "10", "kling-v2-1"),
|
||||||
}
|
}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|||||||
@ -1,9 +1,10 @@
|
|||||||
from inspect import cleandoc
|
from inspect import cleandoc
|
||||||
from typing import Union
|
from typing import Optional
|
||||||
import logging
|
import logging
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from comfy.comfy_types.node_typing import IO
|
from typing_extensions import override
|
||||||
|
from comfy_api.latest import ComfyExtension, io as comfy_io
|
||||||
from comfy_api.input_impl.video_types import VideoFromFile
|
from comfy_api.input_impl.video_types import VideoFromFile
|
||||||
from comfy_api_nodes.apis import (
|
from comfy_api_nodes.apis import (
|
||||||
MinimaxVideoGenerationRequest,
|
MinimaxVideoGenerationRequest,
|
||||||
@ -11,7 +12,7 @@ from comfy_api_nodes.apis import (
|
|||||||
MinimaxFileRetrieveResponse,
|
MinimaxFileRetrieveResponse,
|
||||||
MinimaxTaskResultResponse,
|
MinimaxTaskResultResponse,
|
||||||
SubjectReferenceItem,
|
SubjectReferenceItem,
|
||||||
MiniMaxModel
|
MiniMaxModel,
|
||||||
)
|
)
|
||||||
from comfy_api_nodes.apis.client import (
|
from comfy_api_nodes.apis.client import (
|
||||||
ApiEndpoint,
|
ApiEndpoint,
|
||||||
@ -31,85 +32,29 @@ from server import PromptServer
|
|||||||
I2V_AVERAGE_DURATION = 114
|
I2V_AVERAGE_DURATION = 114
|
||||||
T2V_AVERAGE_DURATION = 234
|
T2V_AVERAGE_DURATION = 234
|
||||||
|
|
||||||
class MinimaxTextToVideoNode:
|
|
||||||
"""
|
|
||||||
Generates videos synchronously based on a prompt, and optional parameters using MiniMax's API.
|
|
||||||
"""
|
|
||||||
|
|
||||||
AVERAGE_DURATION = T2V_AVERAGE_DURATION
|
async def _generate_mm_video(
|
||||||
|
*,
|
||||||
@classmethod
|
auth: dict[str, str],
|
||||||
def INPUT_TYPES(s):
|
node_id: str,
|
||||||
return {
|
prompt_text: str,
|
||||||
"required": {
|
seed: int,
|
||||||
"prompt_text": (
|
model: str,
|
||||||
"STRING",
|
image: Optional[torch.Tensor] = None, # used for ImageToVideo
|
||||||
{
|
subject: Optional[torch.Tensor] = None, # used for SubjectToVideo
|
||||||
"multiline": True,
|
average_duration: Optional[int] = None,
|
||||||
"default": "",
|
) -> comfy_io.NodeOutput:
|
||||||
"tooltip": "Text prompt to guide the video generation",
|
|
||||||
},
|
|
||||||
),
|
|
||||||
"model": (
|
|
||||||
[
|
|
||||||
"T2V-01",
|
|
||||||
"T2V-01-Director",
|
|
||||||
],
|
|
||||||
{
|
|
||||||
"default": "T2V-01",
|
|
||||||
"tooltip": "Model to use for video generation",
|
|
||||||
},
|
|
||||||
),
|
|
||||||
},
|
|
||||||
"optional": {
|
|
||||||
"seed": (
|
|
||||||
IO.INT,
|
|
||||||
{
|
|
||||||
"default": 0,
|
|
||||||
"min": 0,
|
|
||||||
"max": 0xFFFFFFFFFFFFFFFF,
|
|
||||||
"control_after_generate": True,
|
|
||||||
"tooltip": "The random seed used for creating the noise.",
|
|
||||||
},
|
|
||||||
),
|
|
||||||
},
|
|
||||||
"hidden": {
|
|
||||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
|
||||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
|
||||||
"unique_id": "UNIQUE_ID",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
RETURN_TYPES = ("VIDEO",)
|
|
||||||
DESCRIPTION = "Generates videos from prompts using MiniMax's API"
|
|
||||||
FUNCTION = "generate_video"
|
|
||||||
CATEGORY = "api node/video/MiniMax"
|
|
||||||
API_NODE = True
|
|
||||||
|
|
||||||
async def generate_video(
|
|
||||||
self,
|
|
||||||
prompt_text,
|
|
||||||
seed=0,
|
|
||||||
model="T2V-01",
|
|
||||||
image: torch.Tensor=None, # used for ImageToVideo
|
|
||||||
subject: torch.Tensor=None, # used for SubjectToVideo
|
|
||||||
unique_id: Union[str, None]=None,
|
|
||||||
**kwargs,
|
|
||||||
):
|
|
||||||
'''
|
|
||||||
Function used between MiniMax nodes - supports T2V, I2V, and S2V, based on provided arguments.
|
|
||||||
'''
|
|
||||||
if image is None:
|
if image is None:
|
||||||
validate_string(prompt_text, field_name="prompt_text")
|
validate_string(prompt_text, field_name="prompt_text")
|
||||||
# upload image, if passed in
|
# upload image, if passed in
|
||||||
image_url = None
|
image_url = None
|
||||||
if image is not None:
|
if image is not None:
|
||||||
image_url = (await upload_images_to_comfyapi(image, max_images=1, auth_kwargs=kwargs))[0]
|
image_url = (await upload_images_to_comfyapi(image, max_images=1, auth_kwargs=auth))[0]
|
||||||
|
|
||||||
# TODO: figure out how to deal with subject properly, API returns invalid params when using S2V-01 model
|
# TODO: figure out how to deal with subject properly, API returns invalid params when using S2V-01 model
|
||||||
subject_reference = None
|
subject_reference = None
|
||||||
if subject is not None:
|
if subject is not None:
|
||||||
subject_url = (await upload_images_to_comfyapi(subject, max_images=1, auth_kwargs=kwargs))[0]
|
subject_url = (await upload_images_to_comfyapi(subject, max_images=1, auth_kwargs=auth))[0]
|
||||||
subject_reference = [SubjectReferenceItem(image=subject_url)]
|
subject_reference = [SubjectReferenceItem(image=subject_url)]
|
||||||
|
|
||||||
|
|
||||||
@ -128,7 +73,7 @@ class MinimaxTextToVideoNode:
|
|||||||
subject_reference=subject_reference,
|
subject_reference=subject_reference,
|
||||||
prompt_optimizer=None,
|
prompt_optimizer=None,
|
||||||
),
|
),
|
||||||
auth_kwargs=kwargs,
|
auth_kwargs=auth,
|
||||||
)
|
)
|
||||||
response = await video_generate_operation.execute()
|
response = await video_generate_operation.execute()
|
||||||
|
|
||||||
@ -147,9 +92,9 @@ class MinimaxTextToVideoNode:
|
|||||||
completed_statuses=["Success"],
|
completed_statuses=["Success"],
|
||||||
failed_statuses=["Fail"],
|
failed_statuses=["Fail"],
|
||||||
status_extractor=lambda x: x.status.value,
|
status_extractor=lambda x: x.status.value,
|
||||||
estimated_duration=self.AVERAGE_DURATION,
|
estimated_duration=average_duration,
|
||||||
node_id=unique_id,
|
node_id=node_id,
|
||||||
auth_kwargs=kwargs,
|
auth_kwargs=auth,
|
||||||
)
|
)
|
||||||
task_result = await video_generate_operation.execute()
|
task_result = await video_generate_operation.execute()
|
||||||
|
|
||||||
@ -165,7 +110,7 @@ class MinimaxTextToVideoNode:
|
|||||||
query_params={"file_id": int(file_id)},
|
query_params={"file_id": int(file_id)},
|
||||||
),
|
),
|
||||||
request=EmptyRequest(),
|
request=EmptyRequest(),
|
||||||
auth_kwargs=kwargs,
|
auth_kwargs=auth,
|
||||||
)
|
)
|
||||||
file_result = await file_retrieve_operation.execute()
|
file_result = await file_retrieve_operation.execute()
|
||||||
|
|
||||||
@ -174,229 +119,311 @@ class MinimaxTextToVideoNode:
|
|||||||
raise Exception(
|
raise Exception(
|
||||||
f"No video was found in the response. Full response: {file_result.model_dump()}"
|
f"No video was found in the response. Full response: {file_result.model_dump()}"
|
||||||
)
|
)
|
||||||
logging.info(f"Generated video URL: {file_url}")
|
logging.info("Generated video URL: %s", file_url)
|
||||||
if unique_id:
|
if node_id:
|
||||||
if hasattr(file_result.file, "backup_download_url"):
|
if hasattr(file_result.file, "backup_download_url"):
|
||||||
message = f"Result URL: {file_url}\nBackup URL: {file_result.file.backup_download_url}"
|
message = f"Result URL: {file_url}\nBackup URL: {file_result.file.backup_download_url}"
|
||||||
else:
|
else:
|
||||||
message = f"Result URL: {file_url}"
|
message = f"Result URL: {file_url}"
|
||||||
PromptServer.instance.send_progress_text(message, unique_id)
|
PromptServer.instance.send_progress_text(message, node_id)
|
||||||
|
|
||||||
|
# Download and return as VideoFromFile
|
||||||
video_io = await download_url_to_bytesio(file_url)
|
video_io = await download_url_to_bytesio(file_url)
|
||||||
if video_io is None:
|
if video_io is None:
|
||||||
error_msg = f"Failed to download video from {file_url}"
|
error_msg = f"Failed to download video from {file_url}"
|
||||||
logging.error(error_msg)
|
logging.error(error_msg)
|
||||||
raise Exception(error_msg)
|
raise Exception(error_msg)
|
||||||
return (VideoFromFile(video_io),)
|
return comfy_io.NodeOutput(VideoFromFile(video_io))
|
||||||
|
|
||||||
|
|
||||||
class MinimaxImageToVideoNode(MinimaxTextToVideoNode):
|
class MinimaxTextToVideoNode(comfy_io.ComfyNode):
|
||||||
|
"""
|
||||||
|
Generates videos synchronously based on a prompt, and optional parameters using MiniMax's API.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls) -> comfy_io.Schema:
|
||||||
|
return comfy_io.Schema(
|
||||||
|
node_id="MinimaxTextToVideoNode",
|
||||||
|
display_name="MiniMax Text to Video",
|
||||||
|
category="api node/video/MiniMax",
|
||||||
|
description=cleandoc(cls.__doc__ or ""),
|
||||||
|
inputs=[
|
||||||
|
comfy_io.String.Input(
|
||||||
|
"prompt_text",
|
||||||
|
multiline=True,
|
||||||
|
default="",
|
||||||
|
tooltip="Text prompt to guide the video generation",
|
||||||
|
),
|
||||||
|
comfy_io.Combo.Input(
|
||||||
|
"model",
|
||||||
|
options=["T2V-01", "T2V-01-Director"],
|
||||||
|
default="T2V-01",
|
||||||
|
tooltip="Model to use for video generation",
|
||||||
|
),
|
||||||
|
comfy_io.Int.Input(
|
||||||
|
"seed",
|
||||||
|
default=0,
|
||||||
|
min=0,
|
||||||
|
max=0xFFFFFFFFFFFFFFFF,
|
||||||
|
step=1,
|
||||||
|
control_after_generate=True,
|
||||||
|
tooltip="The random seed used for creating the noise.",
|
||||||
|
optional=True,
|
||||||
|
),
|
||||||
|
],
|
||||||
|
outputs=[comfy_io.Video.Output()],
|
||||||
|
hidden=[
|
||||||
|
comfy_io.Hidden.auth_token_comfy_org,
|
||||||
|
comfy_io.Hidden.api_key_comfy_org,
|
||||||
|
comfy_io.Hidden.unique_id,
|
||||||
|
],
|
||||||
|
is_api_node=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def execute(
|
||||||
|
cls,
|
||||||
|
prompt_text: str,
|
||||||
|
model: str = "T2V-01",
|
||||||
|
seed: int = 0,
|
||||||
|
) -> comfy_io.NodeOutput:
|
||||||
|
return await _generate_mm_video(
|
||||||
|
auth={
|
||||||
|
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||||
|
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||||
|
},
|
||||||
|
node_id=cls.hidden.unique_id,
|
||||||
|
prompt_text=prompt_text,
|
||||||
|
seed=seed,
|
||||||
|
model=model,
|
||||||
|
image=None,
|
||||||
|
subject=None,
|
||||||
|
average_duration=T2V_AVERAGE_DURATION,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class MinimaxImageToVideoNode(comfy_io.ComfyNode):
|
||||||
"""
|
"""
|
||||||
Generates videos synchronously based on an image and prompt, and optional parameters using MiniMax's API.
|
Generates videos synchronously based on an image and prompt, and optional parameters using MiniMax's API.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
AVERAGE_DURATION = I2V_AVERAGE_DURATION
|
@classmethod
|
||||||
|
def define_schema(cls) -> comfy_io.Schema:
|
||||||
|
return comfy_io.Schema(
|
||||||
|
node_id="MinimaxImageToVideoNode",
|
||||||
|
display_name="MiniMax Image to Video",
|
||||||
|
category="api node/video/MiniMax",
|
||||||
|
description=cleandoc(cls.__doc__ or ""),
|
||||||
|
inputs=[
|
||||||
|
comfy_io.Image.Input(
|
||||||
|
"image",
|
||||||
|
tooltip="Image to use as first frame of video generation",
|
||||||
|
),
|
||||||
|
comfy_io.String.Input(
|
||||||
|
"prompt_text",
|
||||||
|
multiline=True,
|
||||||
|
default="",
|
||||||
|
tooltip="Text prompt to guide the video generation",
|
||||||
|
),
|
||||||
|
comfy_io.Combo.Input(
|
||||||
|
"model",
|
||||||
|
options=["I2V-01-Director", "I2V-01", "I2V-01-live"],
|
||||||
|
default="I2V-01",
|
||||||
|
tooltip="Model to use for video generation",
|
||||||
|
),
|
||||||
|
comfy_io.Int.Input(
|
||||||
|
"seed",
|
||||||
|
default=0,
|
||||||
|
min=0,
|
||||||
|
max=0xFFFFFFFFFFFFFFFF,
|
||||||
|
step=1,
|
||||||
|
control_after_generate=True,
|
||||||
|
tooltip="The random seed used for creating the noise.",
|
||||||
|
optional=True,
|
||||||
|
),
|
||||||
|
],
|
||||||
|
outputs=[comfy_io.Video.Output()],
|
||||||
|
hidden=[
|
||||||
|
comfy_io.Hidden.auth_token_comfy_org,
|
||||||
|
comfy_io.Hidden.api_key_comfy_org,
|
||||||
|
comfy_io.Hidden.unique_id,
|
||||||
|
],
|
||||||
|
is_api_node=True,
|
||||||
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
async def execute(
|
||||||
return {
|
cls,
|
||||||
"required": {
|
image: torch.Tensor,
|
||||||
"image": (
|
prompt_text: str,
|
||||||
IO.IMAGE,
|
model: str = "I2V-01",
|
||||||
{
|
seed: int = 0,
|
||||||
"tooltip": "Image to use as first frame of video generation"
|
) -> comfy_io.NodeOutput:
|
||||||
|
return await _generate_mm_video(
|
||||||
|
auth={
|
||||||
|
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||||
|
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||||
},
|
},
|
||||||
),
|
node_id=cls.hidden.unique_id,
|
||||||
"prompt_text": (
|
prompt_text=prompt_text,
|
||||||
"STRING",
|
seed=seed,
|
||||||
{
|
model=model,
|
||||||
"multiline": True,
|
image=image,
|
||||||
"default": "",
|
subject=None,
|
||||||
"tooltip": "Text prompt to guide the video generation",
|
average_duration=I2V_AVERAGE_DURATION,
|
||||||
},
|
)
|
||||||
),
|
|
||||||
"model": (
|
|
||||||
[
|
|
||||||
"I2V-01-Director",
|
|
||||||
"I2V-01",
|
|
||||||
"I2V-01-live",
|
|
||||||
],
|
|
||||||
{
|
|
||||||
"default": "I2V-01",
|
|
||||||
"tooltip": "Model to use for video generation",
|
|
||||||
},
|
|
||||||
),
|
|
||||||
},
|
|
||||||
"optional": {
|
|
||||||
"seed": (
|
|
||||||
IO.INT,
|
|
||||||
{
|
|
||||||
"default": 0,
|
|
||||||
"min": 0,
|
|
||||||
"max": 0xFFFFFFFFFFFFFFFF,
|
|
||||||
"control_after_generate": True,
|
|
||||||
"tooltip": "The random seed used for creating the noise.",
|
|
||||||
},
|
|
||||||
),
|
|
||||||
},
|
|
||||||
"hidden": {
|
|
||||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
|
||||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
|
||||||
"unique_id": "UNIQUE_ID",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
RETURN_TYPES = ("VIDEO",)
|
|
||||||
DESCRIPTION = "Generates videos from an image and prompts using MiniMax's API"
|
|
||||||
FUNCTION = "generate_video"
|
|
||||||
CATEGORY = "api node/video/MiniMax"
|
|
||||||
API_NODE = True
|
|
||||||
|
|
||||||
|
|
||||||
class MinimaxSubjectToVideoNode(MinimaxTextToVideoNode):
|
class MinimaxSubjectToVideoNode(comfy_io.ComfyNode):
|
||||||
"""
|
"""
|
||||||
Generates videos synchronously based on an image and prompt, and optional parameters using MiniMax's API.
|
Generates videos synchronously based on an image and prompt, and optional parameters using MiniMax's API.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
AVERAGE_DURATION = T2V_AVERAGE_DURATION
|
@classmethod
|
||||||
|
def define_schema(cls) -> comfy_io.Schema:
|
||||||
|
return comfy_io.Schema(
|
||||||
|
node_id="MinimaxSubjectToVideoNode",
|
||||||
|
display_name="MiniMax Subject to Video",
|
||||||
|
category="api node/video/MiniMax",
|
||||||
|
description=cleandoc(cls.__doc__ or ""),
|
||||||
|
inputs=[
|
||||||
|
comfy_io.Image.Input(
|
||||||
|
"subject",
|
||||||
|
tooltip="Image of subject to reference for video generation",
|
||||||
|
),
|
||||||
|
comfy_io.String.Input(
|
||||||
|
"prompt_text",
|
||||||
|
multiline=True,
|
||||||
|
default="",
|
||||||
|
tooltip="Text prompt to guide the video generation",
|
||||||
|
),
|
||||||
|
comfy_io.Combo.Input(
|
||||||
|
"model",
|
||||||
|
options=["S2V-01"],
|
||||||
|
default="S2V-01",
|
||||||
|
tooltip="Model to use for video generation",
|
||||||
|
),
|
||||||
|
comfy_io.Int.Input(
|
||||||
|
"seed",
|
||||||
|
default=0,
|
||||||
|
min=0,
|
||||||
|
max=0xFFFFFFFFFFFFFFFF,
|
||||||
|
step=1,
|
||||||
|
control_after_generate=True,
|
||||||
|
tooltip="The random seed used for creating the noise.",
|
||||||
|
optional=True,
|
||||||
|
),
|
||||||
|
],
|
||||||
|
outputs=[comfy_io.Video.Output()],
|
||||||
|
hidden=[
|
||||||
|
comfy_io.Hidden.auth_token_comfy_org,
|
||||||
|
comfy_io.Hidden.api_key_comfy_org,
|
||||||
|
comfy_io.Hidden.unique_id,
|
||||||
|
],
|
||||||
|
is_api_node=True,
|
||||||
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
async def execute(
|
||||||
return {
|
cls,
|
||||||
"required": {
|
subject: torch.Tensor,
|
||||||
"subject": (
|
prompt_text: str,
|
||||||
IO.IMAGE,
|
model: str = "S2V-01",
|
||||||
{
|
seed: int = 0,
|
||||||
"tooltip": "Image of subject to reference video generation"
|
) -> comfy_io.NodeOutput:
|
||||||
|
return await _generate_mm_video(
|
||||||
|
auth={
|
||||||
|
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||||
|
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||||
},
|
},
|
||||||
),
|
node_id=cls.hidden.unique_id,
|
||||||
"prompt_text": (
|
prompt_text=prompt_text,
|
||||||
"STRING",
|
seed=seed,
|
||||||
{
|
model=model,
|
||||||
"multiline": True,
|
image=None,
|
||||||
"default": "",
|
subject=subject,
|
||||||
"tooltip": "Text prompt to guide the video generation",
|
average_duration=T2V_AVERAGE_DURATION,
|
||||||
},
|
)
|
||||||
),
|
|
||||||
"model": (
|
|
||||||
[
|
|
||||||
"S2V-01",
|
|
||||||
],
|
|
||||||
{
|
|
||||||
"default": "S2V-01",
|
|
||||||
"tooltip": "Model to use for video generation",
|
|
||||||
},
|
|
||||||
),
|
|
||||||
},
|
|
||||||
"optional": {
|
|
||||||
"seed": (
|
|
||||||
IO.INT,
|
|
||||||
{
|
|
||||||
"default": 0,
|
|
||||||
"min": 0,
|
|
||||||
"max": 0xFFFFFFFFFFFFFFFF,
|
|
||||||
"control_after_generate": True,
|
|
||||||
"tooltip": "The random seed used for creating the noise.",
|
|
||||||
},
|
|
||||||
),
|
|
||||||
},
|
|
||||||
"hidden": {
|
|
||||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
|
||||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
|
||||||
"unique_id": "UNIQUE_ID",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
RETURN_TYPES = ("VIDEO",)
|
|
||||||
DESCRIPTION = "Generates videos from an image and prompts using MiniMax's API"
|
|
||||||
FUNCTION = "generate_video"
|
|
||||||
CATEGORY = "api node/video/MiniMax"
|
|
||||||
API_NODE = True
|
|
||||||
|
|
||||||
|
|
||||||
class MinimaxHailuoVideoNode:
|
class MinimaxHailuoVideoNode(comfy_io.ComfyNode):
|
||||||
"""Generates videos from prompt, with optional start frame using the new MiniMax Hailuo-02 model."""
|
"""Generates videos from prompt, with optional start frame using the new MiniMax Hailuo-02 model."""
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls) -> comfy_io.Schema:
|
||||||
return {
|
return comfy_io.Schema(
|
||||||
"required": {
|
node_id="MinimaxHailuoVideoNode",
|
||||||
"prompt_text": (
|
display_name="MiniMax Hailuo Video",
|
||||||
"STRING",
|
category="api node/video/MiniMax",
|
||||||
{
|
description=cleandoc(cls.__doc__ or ""),
|
||||||
"multiline": True,
|
inputs=[
|
||||||
"default": "",
|
comfy_io.String.Input(
|
||||||
"tooltip": "Text prompt to guide the video generation.",
|
"prompt_text",
|
||||||
},
|
multiline=True,
|
||||||
|
default="",
|
||||||
|
tooltip="Text prompt to guide the video generation.",
|
||||||
),
|
),
|
||||||
},
|
comfy_io.Int.Input(
|
||||||
"optional": {
|
"seed",
|
||||||
"seed": (
|
default=0,
|
||||||
IO.INT,
|
min=0,
|
||||||
{
|
max=0xFFFFFFFFFFFFFFFF,
|
||||||
"default": 0,
|
step=1,
|
||||||
"min": 0,
|
control_after_generate=True,
|
||||||
"max": 0xFFFFFFFFFFFFFFFF,
|
tooltip="The random seed used for creating the noise.",
|
||||||
"control_after_generate": True,
|
optional=True,
|
||||||
"tooltip": "The random seed used for creating the noise.",
|
|
||||||
},
|
|
||||||
),
|
),
|
||||||
"first_frame_image": (
|
comfy_io.Image.Input(
|
||||||
IO.IMAGE,
|
"first_frame_image",
|
||||||
{
|
tooltip="Optional image to use as the first frame to generate a video.",
|
||||||
"tooltip": "Optional image to use as the first frame to generate a video."
|
optional=True,
|
||||||
},
|
|
||||||
),
|
),
|
||||||
"prompt_optimizer": (
|
comfy_io.Boolean.Input(
|
||||||
IO.BOOLEAN,
|
"prompt_optimizer",
|
||||||
{
|
default=True,
|
||||||
"tooltip": "Optimize prompt to improve generation quality when needed.",
|
tooltip="Optimize prompt to improve generation quality when needed.",
|
||||||
"default": True,
|
optional=True,
|
||||||
},
|
|
||||||
),
|
),
|
||||||
"duration": (
|
comfy_io.Combo.Input(
|
||||||
IO.COMBO,
|
"duration",
|
||||||
{
|
options=[6, 10],
|
||||||
"tooltip": "The length of the output video in seconds.",
|
default=6,
|
||||||
"default": 6,
|
tooltip="The length of the output video in seconds.",
|
||||||
"options": [6, 10],
|
optional=True,
|
||||||
},
|
|
||||||
),
|
),
|
||||||
"resolution": (
|
comfy_io.Combo.Input(
|
||||||
IO.COMBO,
|
"resolution",
|
||||||
{
|
options=["768P", "1080P"],
|
||||||
"tooltip": "The dimensions of the video display. "
|
default="768P",
|
||||||
"1080p corresponds to 1920 x 1080 pixels, 768p corresponds to 1366 x 768 pixels.",
|
tooltip="The dimensions of the video display. 1080p is 1920x1080, 768p is 1366x768.",
|
||||||
"default": "768P",
|
optional=True,
|
||||||
"options": ["768P", "1080P"],
|
|
||||||
},
|
|
||||||
),
|
),
|
||||||
},
|
],
|
||||||
"hidden": {
|
outputs=[comfy_io.Video.Output()],
|
||||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
hidden=[
|
||||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
comfy_io.Hidden.auth_token_comfy_org,
|
||||||
"unique_id": "UNIQUE_ID",
|
comfy_io.Hidden.api_key_comfy_org,
|
||||||
},
|
comfy_io.Hidden.unique_id,
|
||||||
|
],
|
||||||
|
is_api_node=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def execute(
|
||||||
|
cls,
|
||||||
|
prompt_text: str,
|
||||||
|
seed: int = 0,
|
||||||
|
first_frame_image: Optional[torch.Tensor] = None, # used for ImageToVideo
|
||||||
|
prompt_optimizer: bool = True,
|
||||||
|
duration: int = 6,
|
||||||
|
resolution: str = "768P",
|
||||||
|
model: str = "MiniMax-Hailuo-02",
|
||||||
|
) -> comfy_io.NodeOutput:
|
||||||
|
auth = {
|
||||||
|
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||||
|
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||||
}
|
}
|
||||||
|
|
||||||
RETURN_TYPES = ("VIDEO",)
|
|
||||||
DESCRIPTION = cleandoc(__doc__ or "")
|
|
||||||
FUNCTION = "generate_video"
|
|
||||||
CATEGORY = "api node/video/MiniMax"
|
|
||||||
API_NODE = True
|
|
||||||
|
|
||||||
async def generate_video(
|
|
||||||
self,
|
|
||||||
prompt_text,
|
|
||||||
seed=0,
|
|
||||||
first_frame_image: torch.Tensor=None, # used for ImageToVideo
|
|
||||||
prompt_optimizer=True,
|
|
||||||
duration=6,
|
|
||||||
resolution="768P",
|
|
||||||
model="MiniMax-Hailuo-02",
|
|
||||||
unique_id: Union[str, None]=None,
|
|
||||||
**kwargs,
|
|
||||||
):
|
|
||||||
if first_frame_image is None:
|
if first_frame_image is None:
|
||||||
validate_string(prompt_text, field_name="prompt_text")
|
validate_string(prompt_text, field_name="prompt_text")
|
||||||
|
|
||||||
@ -408,7 +435,7 @@ class MinimaxHailuoVideoNode:
|
|||||||
# upload image, if passed in
|
# upload image, if passed in
|
||||||
image_url = None
|
image_url = None
|
||||||
if first_frame_image is not None:
|
if first_frame_image is not None:
|
||||||
image_url = (await upload_images_to_comfyapi(first_frame_image, max_images=1, auth_kwargs=kwargs))[0]
|
image_url = (await upload_images_to_comfyapi(first_frame_image, max_images=1, auth_kwargs=auth))[0]
|
||||||
|
|
||||||
video_generate_operation = SynchronousOperation(
|
video_generate_operation = SynchronousOperation(
|
||||||
endpoint=ApiEndpoint(
|
endpoint=ApiEndpoint(
|
||||||
@ -426,7 +453,7 @@ class MinimaxHailuoVideoNode:
|
|||||||
duration=duration,
|
duration=duration,
|
||||||
resolution=resolution,
|
resolution=resolution,
|
||||||
),
|
),
|
||||||
auth_kwargs=kwargs,
|
auth_kwargs=auth,
|
||||||
)
|
)
|
||||||
response = await video_generate_operation.execute()
|
response = await video_generate_operation.execute()
|
||||||
|
|
||||||
@ -447,8 +474,8 @@ class MinimaxHailuoVideoNode:
|
|||||||
failed_statuses=["Fail"],
|
failed_statuses=["Fail"],
|
||||||
status_extractor=lambda x: x.status.value,
|
status_extractor=lambda x: x.status.value,
|
||||||
estimated_duration=average_duration,
|
estimated_duration=average_duration,
|
||||||
node_id=unique_id,
|
node_id=cls.hidden.unique_id,
|
||||||
auth_kwargs=kwargs,
|
auth_kwargs=auth,
|
||||||
)
|
)
|
||||||
task_result = await video_generate_operation.execute()
|
task_result = await video_generate_operation.execute()
|
||||||
|
|
||||||
@ -464,7 +491,7 @@ class MinimaxHailuoVideoNode:
|
|||||||
query_params={"file_id": int(file_id)},
|
query_params={"file_id": int(file_id)},
|
||||||
),
|
),
|
||||||
request=EmptyRequest(),
|
request=EmptyRequest(),
|
||||||
auth_kwargs=kwargs,
|
auth_kwargs=auth,
|
||||||
)
|
)
|
||||||
file_result = await file_retrieve_operation.execute()
|
file_result = await file_retrieve_operation.execute()
|
||||||
|
|
||||||
@ -474,34 +501,31 @@ class MinimaxHailuoVideoNode:
|
|||||||
f"No video was found in the response. Full response: {file_result.model_dump()}"
|
f"No video was found in the response. Full response: {file_result.model_dump()}"
|
||||||
)
|
)
|
||||||
logging.info(f"Generated video URL: {file_url}")
|
logging.info(f"Generated video URL: {file_url}")
|
||||||
if unique_id:
|
if cls.hidden.unique_id:
|
||||||
if hasattr(file_result.file, "backup_download_url"):
|
if hasattr(file_result.file, "backup_download_url"):
|
||||||
message = f"Result URL: {file_url}\nBackup URL: {file_result.file.backup_download_url}"
|
message = f"Result URL: {file_url}\nBackup URL: {file_result.file.backup_download_url}"
|
||||||
else:
|
else:
|
||||||
message = f"Result URL: {file_url}"
|
message = f"Result URL: {file_url}"
|
||||||
PromptServer.instance.send_progress_text(message, unique_id)
|
PromptServer.instance.send_progress_text(message, cls.hidden.unique_id)
|
||||||
|
|
||||||
video_io = await download_url_to_bytesio(file_url)
|
video_io = await download_url_to_bytesio(file_url)
|
||||||
if video_io is None:
|
if video_io is None:
|
||||||
error_msg = f"Failed to download video from {file_url}"
|
error_msg = f"Failed to download video from {file_url}"
|
||||||
logging.error(error_msg)
|
logging.error(error_msg)
|
||||||
raise Exception(error_msg)
|
raise Exception(error_msg)
|
||||||
return (VideoFromFile(video_io),)
|
return comfy_io.NodeOutput(VideoFromFile(video_io))
|
||||||
|
|
||||||
|
|
||||||
# A dictionary that contains all nodes you want to export with their names
|
class MinimaxExtension(ComfyExtension):
|
||||||
# NOTE: names should be globally unique
|
@override
|
||||||
NODE_CLASS_MAPPINGS = {
|
async def get_node_list(self) -> list[type[comfy_io.ComfyNode]]:
|
||||||
"MinimaxTextToVideoNode": MinimaxTextToVideoNode,
|
return [
|
||||||
"MinimaxImageToVideoNode": MinimaxImageToVideoNode,
|
MinimaxTextToVideoNode,
|
||||||
# "MinimaxSubjectToVideoNode": MinimaxSubjectToVideoNode,
|
MinimaxImageToVideoNode,
|
||||||
"MinimaxHailuoVideoNode": MinimaxHailuoVideoNode,
|
# MinimaxSubjectToVideoNode,
|
||||||
}
|
MinimaxHailuoVideoNode,
|
||||||
|
]
|
||||||
|
|
||||||
# A dictionary that contains the friendly/humanly readable titles for the nodes
|
|
||||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
async def comfy_entrypoint() -> MinimaxExtension:
|
||||||
"MinimaxTextToVideoNode": "MiniMax Text to Video",
|
return MinimaxExtension()
|
||||||
"MinimaxImageToVideoNode": "MiniMax Image to Video",
|
|
||||||
"MinimaxSubjectToVideoNode": "MiniMax Subject to Video",
|
|
||||||
"MinimaxHailuoVideoNode": "MiniMax Hailuo Video",
|
|
||||||
}
|
|
||||||
|
|||||||
@ -1,6 +1,7 @@
|
|||||||
import logging
|
import logging
|
||||||
from typing import Any, Callable, Optional, TypeVar
|
from typing import Any, Callable, Optional, TypeVar
|
||||||
import torch
|
import torch
|
||||||
|
from typing_extensions import override
|
||||||
from comfy_api_nodes.util.validation_utils import (
|
from comfy_api_nodes.util.validation_utils import (
|
||||||
get_image_dimensions,
|
get_image_dimensions,
|
||||||
validate_image_dimensions,
|
validate_image_dimensions,
|
||||||
@ -26,11 +27,9 @@ from comfy_api_nodes.apinode_utils import (
|
|||||||
upload_images_to_comfyapi,
|
upload_images_to_comfyapi,
|
||||||
upload_video_to_comfyapi,
|
upload_video_to_comfyapi,
|
||||||
)
|
)
|
||||||
from comfy_api_nodes.mapper_utils import model_field_to_node_input
|
|
||||||
|
|
||||||
from comfy_api.input.video_types import VideoInput
|
from comfy_api.input import VideoInput
|
||||||
from comfy.comfy_types.node_typing import IO
|
from comfy_api.latest import ComfyExtension, InputImpl, io as comfy_io
|
||||||
from comfy_api.input_impl import VideoFromFile
|
|
||||||
import av
|
import av
|
||||||
import io
|
import io
|
||||||
|
|
||||||
@ -362,7 +361,7 @@ def trim_video(video: VideoInput, duration_sec: float) -> VideoInput:
|
|||||||
|
|
||||||
# Return as VideoFromFile using the buffer
|
# Return as VideoFromFile using the buffer
|
||||||
output_buffer.seek(0)
|
output_buffer.seek(0)
|
||||||
return VideoFromFile(output_buffer)
|
return InputImpl.VideoFromFile(output_buffer)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# Clean up on error
|
# Clean up on error
|
||||||
@ -373,9 +372,7 @@ def trim_video(video: VideoInput, duration_sec: float) -> VideoInput:
|
|||||||
raise RuntimeError(f"Failed to trim video: {str(e)}") from e
|
raise RuntimeError(f"Failed to trim video: {str(e)}") from e
|
||||||
|
|
||||||
|
|
||||||
# --- BaseMoonvalleyVideoNode ---
|
def parse_width_height_from_res(resolution: str):
|
||||||
class BaseMoonvalleyVideoNode:
|
|
||||||
def parseWidthHeightFromRes(self, resolution: str):
|
|
||||||
# Accepts a string like "16:9 (1920 x 1080)" and returns width, height as a dict
|
# Accepts a string like "16:9 (1920 x 1080)" and returns width, height as a dict
|
||||||
res_map = {
|
res_map = {
|
||||||
"16:9 (1920 x 1080)": {"width": 1920, "height": 1080},
|
"16:9 (1920 x 1080)": {"width": 1920, "height": 1080},
|
||||||
@ -385,27 +382,22 @@ class BaseMoonvalleyVideoNode:
|
|||||||
"3:4 (1152 x 1536)": {"width": 1152, "height": 1536},
|
"3:4 (1152 x 1536)": {"width": 1152, "height": 1536},
|
||||||
"21:9 (2560 x 1080)": {"width": 2560, "height": 1080},
|
"21:9 (2560 x 1080)": {"width": 2560, "height": 1080},
|
||||||
}
|
}
|
||||||
if resolution in res_map:
|
return res_map.get(resolution, {"width": 1920, "height": 1080})
|
||||||
return res_map[resolution]
|
|
||||||
else:
|
|
||||||
# Default to 1920x1080 if unknown
|
|
||||||
return {"width": 1920, "height": 1080}
|
|
||||||
|
|
||||||
def parseControlParameter(self, value):
|
|
||||||
|
def parse_control_parameter(value):
|
||||||
control_map = {
|
control_map = {
|
||||||
"Motion Transfer": "motion_control",
|
"Motion Transfer": "motion_control",
|
||||||
"Canny": "canny_control",
|
"Canny": "canny_control",
|
||||||
"Pose Transfer": "pose_control",
|
"Pose Transfer": "pose_control",
|
||||||
"Depth": "depth_control",
|
"Depth": "depth_control",
|
||||||
}
|
}
|
||||||
if value in control_map:
|
return control_map.get(value, control_map["Motion Transfer"])
|
||||||
return control_map[value]
|
|
||||||
else:
|
|
||||||
return control_map["Motion Transfer"]
|
|
||||||
|
|
||||||
async def get_response(
|
|
||||||
self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None
|
async def get_response(
|
||||||
) -> MoonvalleyPromptResponse:
|
task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None
|
||||||
|
) -> MoonvalleyPromptResponse:
|
||||||
return await poll_until_finished(
|
return await poll_until_finished(
|
||||||
auth_kwargs,
|
auth_kwargs,
|
||||||
ApiEndpoint(
|
ApiEndpoint(
|
||||||
@ -418,121 +410,112 @@ class BaseMoonvalleyVideoNode:
|
|||||||
node_id=node_id,
|
node_id=node_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class MoonvalleyImg2VideoNode(comfy_io.ComfyNode):
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(cls):
|
def define_schema(cls) -> comfy_io.Schema:
|
||||||
return {
|
return comfy_io.Schema(
|
||||||
"required": {
|
node_id="MoonvalleyImg2VideoNode",
|
||||||
"prompt": model_field_to_node_input(
|
display_name="Moonvalley Marey Image to Video",
|
||||||
IO.STRING,
|
category="api node/video/Moonvalley Marey",
|
||||||
MoonvalleyTextToVideoRequest,
|
description="Moonvalley Marey Image to Video Node",
|
||||||
"prompt_text",
|
inputs=[
|
||||||
|
comfy_io.Image.Input(
|
||||||
|
"image",
|
||||||
|
tooltip="The reference image used to generate the video",
|
||||||
|
),
|
||||||
|
comfy_io.String.Input(
|
||||||
|
"prompt",
|
||||||
multiline=True,
|
multiline=True,
|
||||||
),
|
),
|
||||||
"negative_prompt": model_field_to_node_input(
|
comfy_io.String.Input(
|
||||||
IO.STRING,
|
|
||||||
MoonvalleyTextToVideoInferenceParams,
|
|
||||||
"negative_prompt",
|
"negative_prompt",
|
||||||
multiline=True,
|
multiline=True,
|
||||||
default="<synthetic> <scene cut> gopro, bright, contrast, static, overexposed, vignette, artifacts, still, noise, texture, scanlines, videogame, 360 camera, VR, transition, flare, saturation, distorted, warped, wide angle, saturated, vibrant, glowing, cross dissolve, cheesy, ugly hands, mutated hands, mutant, disfigured, extra fingers, blown out, horrible, blurry, worst quality, bad, dissolve, melt, fade in, fade out, wobbly, weird, low quality, plastic, stock footage, video camera, boring",
|
default="<synthetic> <scene cut> gopro, bright, contrast, static, overexposed, vignette, "
|
||||||
|
"artifacts, still, noise, texture, scanlines, videogame, 360 camera, VR, transition, "
|
||||||
|
"flare, saturation, distorted, warped, wide angle, saturated, vibrant, glowing, "
|
||||||
|
"cross dissolve, cheesy, ugly hands, mutated hands, mutant, disfigured, extra fingers, "
|
||||||
|
"blown out, horrible, blurry, worst quality, bad, dissolve, melt, fade in, fade out, "
|
||||||
|
"wobbly, weird, low quality, plastic, stock footage, video camera, boring",
|
||||||
|
tooltip="Negative prompt text",
|
||||||
),
|
),
|
||||||
"resolution": (
|
comfy_io.Combo.Input(
|
||||||
IO.COMBO,
|
"resolution",
|
||||||
{
|
options=[
|
||||||
"options": [
|
|
||||||
"16:9 (1920 x 1080)",
|
"16:9 (1920 x 1080)",
|
||||||
"9:16 (1080 x 1920)",
|
"9:16 (1080 x 1920)",
|
||||||
"1:1 (1152 x 1152)",
|
"1:1 (1152 x 1152)",
|
||||||
"4:3 (1440 x 1080)",
|
"4:3 (1536 x 1152)",
|
||||||
"3:4 (1080 x 1440)",
|
"3:4 (1152 x 1536)",
|
||||||
"21:9 (2560 x 1080)",
|
"21:9 (2560 x 1080)",
|
||||||
],
|
],
|
||||||
"default": "16:9 (1920 x 1080)",
|
default="16:9 (1920 x 1080)",
|
||||||
"tooltip": "Resolution of the output video",
|
tooltip="Resolution of the output video",
|
||||||
},
|
|
||||||
),
|
),
|
||||||
"prompt_adherence": model_field_to_node_input(
|
comfy_io.Float.Input(
|
||||||
IO.FLOAT,
|
"prompt_adherence",
|
||||||
MoonvalleyTextToVideoInferenceParams,
|
|
||||||
"guidance_scale",
|
|
||||||
default=10.0,
|
default=10.0,
|
||||||
step=1,
|
min=1.0,
|
||||||
min=1,
|
max=20.0,
|
||||||
max=20,
|
step=1.0,
|
||||||
|
tooltip="Guidance scale for generation control",
|
||||||
),
|
),
|
||||||
"seed": model_field_to_node_input(
|
comfy_io.Int.Input(
|
||||||
IO.INT,
|
|
||||||
MoonvalleyTextToVideoInferenceParams,
|
|
||||||
"seed",
|
"seed",
|
||||||
default=9,
|
default=9,
|
||||||
min=0,
|
min=0,
|
||||||
max=4294967295,
|
max=4294967295,
|
||||||
step=1,
|
step=1,
|
||||||
display="number",
|
display_mode=comfy_io.NumberDisplay.number,
|
||||||
tooltip="Random seed value",
|
tooltip="Random seed value",
|
||||||
),
|
),
|
||||||
"steps": model_field_to_node_input(
|
comfy_io.Int.Input(
|
||||||
IO.INT,
|
|
||||||
MoonvalleyTextToVideoInferenceParams,
|
|
||||||
"steps",
|
"steps",
|
||||||
default=100,
|
default=100,
|
||||||
min=1,
|
min=1,
|
||||||
max=100,
|
max=100,
|
||||||
|
step=1,
|
||||||
|
tooltip="Number of denoising steps",
|
||||||
),
|
),
|
||||||
},
|
],
|
||||||
"hidden": {
|
outputs=[comfy_io.Video.Output()],
|
||||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
hidden=[
|
||||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
comfy_io.Hidden.auth_token_comfy_org,
|
||||||
"unique_id": "UNIQUE_ID",
|
comfy_io.Hidden.api_key_comfy_org,
|
||||||
},
|
comfy_io.Hidden.unique_id,
|
||||||
"optional": {
|
],
|
||||||
"image": model_field_to_node_input(
|
is_api_node=True,
|
||||||
IO.IMAGE,
|
)
|
||||||
MoonvalleyTextToVideoRequest,
|
|
||||||
"image_url",
|
|
||||||
tooltip="The reference image used to generate the video",
|
|
||||||
),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
RETURN_TYPES = ("STRING",)
|
|
||||||
FUNCTION = "generate"
|
|
||||||
CATEGORY = "api node/video/Moonvalley Marey"
|
|
||||||
API_NODE = True
|
|
||||||
|
|
||||||
def generate(self, **kwargs):
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
# --- MoonvalleyImg2VideoNode ---
|
|
||||||
class MoonvalleyImg2VideoNode(BaseMoonvalleyVideoNode):
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(cls):
|
async def execute(
|
||||||
return super().INPUT_TYPES()
|
cls,
|
||||||
|
image: torch.Tensor,
|
||||||
RETURN_TYPES = ("VIDEO",)
|
prompt: str,
|
||||||
RETURN_NAMES = ("video",)
|
negative_prompt: str,
|
||||||
DESCRIPTION = "Moonvalley Marey Image to Video Node"
|
resolution: str,
|
||||||
|
prompt_adherence: float,
|
||||||
async def generate(
|
seed: int,
|
||||||
self, prompt, negative_prompt, unique_id: Optional[str] = None, **kwargs
|
steps: int,
|
||||||
):
|
) -> comfy_io.NodeOutput:
|
||||||
image = kwargs.get("image", None)
|
|
||||||
if image is None:
|
|
||||||
raise MoonvalleyApiError("image is required")
|
|
||||||
|
|
||||||
validate_input_image(image, True)
|
validate_input_image(image, True)
|
||||||
validate_prompts(prompt, negative_prompt, MOONVALLEY_MAREY_MAX_PROMPT_LENGTH)
|
validate_prompts(prompt, negative_prompt, MOONVALLEY_MAREY_MAX_PROMPT_LENGTH)
|
||||||
width_height = self.parseWidthHeightFromRes(kwargs.get("resolution"))
|
width_height = parse_width_height_from_res(resolution)
|
||||||
|
|
||||||
|
auth = {
|
||||||
|
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||||
|
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||||
|
}
|
||||||
|
|
||||||
inference_params = MoonvalleyTextToVideoInferenceParams(
|
inference_params = MoonvalleyTextToVideoInferenceParams(
|
||||||
negative_prompt=negative_prompt,
|
negative_prompt=negative_prompt,
|
||||||
steps=kwargs.get("steps"),
|
steps=steps,
|
||||||
seed=kwargs.get("seed"),
|
seed=seed,
|
||||||
guidance_scale=kwargs.get("prompt_adherence"),
|
guidance_scale=prompt_adherence,
|
||||||
num_frames=128,
|
num_frames=128,
|
||||||
width=width_height.get("width"),
|
width=width_height["width"],
|
||||||
height=width_height.get("height"),
|
height=width_height["height"],
|
||||||
use_negative_prompts=True,
|
use_negative_prompts=True,
|
||||||
)
|
)
|
||||||
"""Upload image to comfy backend to have a URL available for further processing"""
|
"""Upload image to comfy backend to have a URL available for further processing"""
|
||||||
@ -541,7 +524,7 @@ class MoonvalleyImg2VideoNode(BaseMoonvalleyVideoNode):
|
|||||||
|
|
||||||
image_url = (
|
image_url = (
|
||||||
await upload_images_to_comfyapi(
|
await upload_images_to_comfyapi(
|
||||||
image, max_images=1, auth_kwargs=kwargs, mime_type=mime_type
|
image, max_images=1, auth_kwargs=auth, mime_type=mime_type
|
||||||
)
|
)
|
||||||
)[0]
|
)[0]
|
||||||
|
|
||||||
@ -556,127 +539,102 @@ class MoonvalleyImg2VideoNode(BaseMoonvalleyVideoNode):
|
|||||||
response_model=MoonvalleyPromptResponse,
|
response_model=MoonvalleyPromptResponse,
|
||||||
),
|
),
|
||||||
request=request,
|
request=request,
|
||||||
auth_kwargs=kwargs,
|
auth_kwargs=auth,
|
||||||
)
|
)
|
||||||
task_creation_response = await initial_operation.execute()
|
task_creation_response = await initial_operation.execute()
|
||||||
validate_task_creation_response(task_creation_response)
|
validate_task_creation_response(task_creation_response)
|
||||||
task_id = task_creation_response.id
|
task_id = task_creation_response.id
|
||||||
|
|
||||||
final_response = await self.get_response(
|
final_response = await get_response(
|
||||||
task_id, auth_kwargs=kwargs, node_id=unique_id
|
task_id, auth_kwargs=auth, node_id=cls.hidden.unique_id
|
||||||
)
|
)
|
||||||
video = await download_url_to_video_output(final_response.output_url)
|
video = await download_url_to_video_output(final_response.output_url)
|
||||||
return (video,)
|
return comfy_io.NodeOutput(video)
|
||||||
|
|
||||||
|
|
||||||
# --- MoonvalleyVid2VidNode ---
|
class MoonvalleyVideo2VideoNode(comfy_io.ComfyNode):
|
||||||
class MoonvalleyVideo2VideoNode(BaseMoonvalleyVideoNode):
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(cls):
|
def define_schema(cls) -> comfy_io.Schema:
|
||||||
return {
|
return comfy_io.Schema(
|
||||||
"required": {
|
node_id="MoonvalleyVideo2VideoNode",
|
||||||
"prompt": model_field_to_node_input(
|
display_name="Moonvalley Marey Video to Video",
|
||||||
IO.STRING,
|
category="api node/video/Moonvalley Marey",
|
||||||
MoonvalleyVideoToVideoRequest,
|
description="",
|
||||||
"prompt_text",
|
inputs=[
|
||||||
|
comfy_io.String.Input(
|
||||||
|
"prompt",
|
||||||
multiline=True,
|
multiline=True,
|
||||||
|
tooltip="Describes the video to generate",
|
||||||
),
|
),
|
||||||
"negative_prompt": model_field_to_node_input(
|
comfy_io.String.Input(
|
||||||
IO.STRING,
|
|
||||||
MoonvalleyVideoToVideoInferenceParams,
|
|
||||||
"negative_prompt",
|
"negative_prompt",
|
||||||
multiline=True,
|
multiline=True,
|
||||||
default="<synthetic> <scene cut> gopro, bright, contrast, static, overexposed, vignette, artifacts, still, noise, texture, scanlines, videogame, 360 camera, VR, transition, flare, saturation, distorted, warped, wide angle, saturated, vibrant, glowing, cross dissolve, cheesy, ugly hands, mutated hands, mutant, disfigured, extra fingers, blown out, horrible, blurry, worst quality, bad, dissolve, melt, fade in, fade out, wobbly, weird, low quality, plastic, stock footage, video camera, boring",
|
default="<synthetic> <scene cut> gopro, bright, contrast, static, overexposed, vignette, "
|
||||||
|
"artifacts, still, noise, texture, scanlines, videogame, 360 camera, VR, transition, "
|
||||||
|
"flare, saturation, distorted, warped, wide angle, saturated, vibrant, glowing, "
|
||||||
|
"cross dissolve, cheesy, ugly hands, mutated hands, mutant, disfigured, extra fingers, "
|
||||||
|
"blown out, horrible, blurry, worst quality, bad, dissolve, melt, fade in, fade out, "
|
||||||
|
"wobbly, weird, low quality, plastic, stock footage, video camera, boring",
|
||||||
|
tooltip="Negative prompt text",
|
||||||
),
|
),
|
||||||
"seed": model_field_to_node_input(
|
comfy_io.Int.Input(
|
||||||
IO.INT,
|
|
||||||
MoonvalleyVideoToVideoInferenceParams,
|
|
||||||
"seed",
|
"seed",
|
||||||
default=9,
|
default=9,
|
||||||
min=0,
|
min=0,
|
||||||
max=4294967295,
|
max=4294967295,
|
||||||
step=1,
|
step=1,
|
||||||
display="number",
|
display_mode=comfy_io.NumberDisplay.number,
|
||||||
tooltip="Random seed value",
|
tooltip="Random seed value",
|
||||||
control_after_generate=False,
|
control_after_generate=False,
|
||||||
),
|
),
|
||||||
"prompt_adherence": model_field_to_node_input(
|
comfy_io.Video.Input(
|
||||||
IO.FLOAT,
|
"video",
|
||||||
MoonvalleyVideoToVideoInferenceParams,
|
tooltip="The reference video used to generate the output video. Must be at least 5 seconds long. "
|
||||||
"guidance_scale",
|
"Videos longer than 5s will be automatically trimmed. Only MP4 format supported.",
|
||||||
default=10.0,
|
),
|
||||||
|
comfy_io.Combo.Input(
|
||||||
|
"control_type",
|
||||||
|
options=["Motion Transfer", "Pose Transfer"],
|
||||||
|
default="Motion Transfer",
|
||||||
|
optional=True,
|
||||||
|
),
|
||||||
|
comfy_io.Int.Input(
|
||||||
|
"motion_intensity",
|
||||||
|
default=100,
|
||||||
|
min=0,
|
||||||
|
max=100,
|
||||||
step=1,
|
step=1,
|
||||||
min=1,
|
tooltip="Only used if control_type is 'Motion Transfer'",
|
||||||
max=20,
|
optional=True,
|
||||||
),
|
),
|
||||||
},
|
],
|
||||||
"hidden": {
|
outputs=[comfy_io.Video.Output()],
|
||||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
hidden=[
|
||||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
comfy_io.Hidden.auth_token_comfy_org,
|
||||||
"unique_id": "UNIQUE_ID",
|
comfy_io.Hidden.api_key_comfy_org,
|
||||||
},
|
comfy_io.Hidden.unique_id,
|
||||||
"optional": {
|
],
|
||||||
"video": (
|
is_api_node=True,
|
||||||
IO.VIDEO,
|
)
|
||||||
{
|
|
||||||
"default": "",
|
@classmethod
|
||||||
"multiline": False,
|
async def execute(
|
||||||
"tooltip": "The reference video used to generate the output video. Must be at least 5 seconds long. Videos longer than 5s will be automatically trimmed. Only MP4 format supported.",
|
cls,
|
||||||
},
|
prompt: str,
|
||||||
),
|
negative_prompt: str,
|
||||||
"control_type": (
|
seed: int,
|
||||||
["Motion Transfer", "Pose Transfer"],
|
video: Optional[VideoInput] = None,
|
||||||
{"default": "Motion Transfer"},
|
control_type: str = "Motion Transfer",
|
||||||
),
|
motion_intensity: Optional[int] = 100,
|
||||||
"motion_intensity": (
|
) -> comfy_io.NodeOutput:
|
||||||
"INT",
|
auth = {
|
||||||
{
|
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||||
"default": 100,
|
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||||
"step": 1,
|
|
||||||
"min": 0,
|
|
||||||
"max": 100,
|
|
||||||
"tooltip": "Only used if control_type is 'Motion Transfer'",
|
|
||||||
},
|
|
||||||
),
|
|
||||||
"image": model_field_to_node_input(
|
|
||||||
IO.IMAGE,
|
|
||||||
MoonvalleyTextToVideoRequest,
|
|
||||||
"image_url",
|
|
||||||
tooltip="The reference image used to generate the video",
|
|
||||||
),
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
RETURN_TYPES = ("VIDEO",)
|
|
||||||
RETURN_NAMES = ("video",)
|
|
||||||
|
|
||||||
async def generate(
|
|
||||||
self, prompt, negative_prompt, unique_id: Optional[str] = None, **kwargs
|
|
||||||
):
|
|
||||||
video = kwargs.get("video")
|
|
||||||
image = kwargs.get("image", None)
|
|
||||||
|
|
||||||
if not video:
|
|
||||||
raise MoonvalleyApiError("video is required")
|
|
||||||
|
|
||||||
video_url = ""
|
|
||||||
if video:
|
|
||||||
validated_video = validate_video_to_video_input(video)
|
validated_video = validate_video_to_video_input(video)
|
||||||
video_url = await upload_video_to_comfyapi(
|
video_url = await upload_video_to_comfyapi(validated_video, auth_kwargs=auth)
|
||||||
validated_video, auth_kwargs=kwargs
|
|
||||||
)
|
|
||||||
mime_type = "image/png"
|
|
||||||
|
|
||||||
if not image is None:
|
|
||||||
validate_input_image(image, with_frame_conditioning=True)
|
|
||||||
image_url = await upload_images_to_comfyapi(
|
|
||||||
image=image, auth_kwargs=kwargs, max_images=1, mime_type=mime_type
|
|
||||||
)
|
|
||||||
control_type = kwargs.get("control_type")
|
|
||||||
motion_intensity = kwargs.get("motion_intensity")
|
|
||||||
|
|
||||||
"""Validate prompts and inference input"""
|
"""Validate prompts and inference input"""
|
||||||
validate_prompts(prompt, negative_prompt)
|
validate_prompts(prompt, negative_prompt)
|
||||||
@ -688,11 +646,11 @@ class MoonvalleyVideo2VideoNode(BaseMoonvalleyVideoNode):
|
|||||||
|
|
||||||
inference_params = MoonvalleyVideoToVideoInferenceParams(
|
inference_params = MoonvalleyVideoToVideoInferenceParams(
|
||||||
negative_prompt=negative_prompt,
|
negative_prompt=negative_prompt,
|
||||||
seed=kwargs.get("seed"),
|
seed=seed,
|
||||||
control_params=control_params,
|
control_params=control_params,
|
||||||
)
|
)
|
||||||
|
|
||||||
control = self.parseControlParameter(control_type)
|
control = parse_control_parameter(control_type)
|
||||||
|
|
||||||
request = MoonvalleyVideoToVideoRequest(
|
request = MoonvalleyVideoToVideoRequest(
|
||||||
control_type=control,
|
control_type=control,
|
||||||
@ -700,7 +658,6 @@ class MoonvalleyVideo2VideoNode(BaseMoonvalleyVideoNode):
|
|||||||
prompt_text=prompt,
|
prompt_text=prompt,
|
||||||
inference_params=inference_params,
|
inference_params=inference_params,
|
||||||
)
|
)
|
||||||
request.image_url = image_url if not image is None else None
|
|
||||||
|
|
||||||
initial_operation = SynchronousOperation(
|
initial_operation = SynchronousOperation(
|
||||||
endpoint=ApiEndpoint(
|
endpoint=ApiEndpoint(
|
||||||
@ -710,58 +667,125 @@ class MoonvalleyVideo2VideoNode(BaseMoonvalleyVideoNode):
|
|||||||
response_model=MoonvalleyPromptResponse,
|
response_model=MoonvalleyPromptResponse,
|
||||||
),
|
),
|
||||||
request=request,
|
request=request,
|
||||||
auth_kwargs=kwargs,
|
auth_kwargs=auth,
|
||||||
)
|
)
|
||||||
task_creation_response = await initial_operation.execute()
|
task_creation_response = await initial_operation.execute()
|
||||||
validate_task_creation_response(task_creation_response)
|
validate_task_creation_response(task_creation_response)
|
||||||
task_id = task_creation_response.id
|
task_id = task_creation_response.id
|
||||||
|
|
||||||
final_response = await self.get_response(
|
final_response = await get_response(
|
||||||
task_id, auth_kwargs=kwargs, node_id=unique_id
|
task_id, auth_kwargs=auth, node_id=cls.hidden.unique_id
|
||||||
)
|
)
|
||||||
|
|
||||||
video = await download_url_to_video_output(final_response.output_url)
|
video = await download_url_to_video_output(final_response.output_url)
|
||||||
|
return comfy_io.NodeOutput(video)
|
||||||
return (video,)
|
|
||||||
|
|
||||||
|
|
||||||
# --- MoonvalleyTxt2VideoNode ---
|
class MoonvalleyTxt2VideoNode(comfy_io.ComfyNode):
|
||||||
class MoonvalleyTxt2VideoNode(BaseMoonvalleyVideoNode):
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
RETURN_TYPES = ("VIDEO",)
|
|
||||||
RETURN_NAMES = ("video",)
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(cls):
|
def define_schema(cls) -> comfy_io.Schema:
|
||||||
input_types = super().INPUT_TYPES()
|
return comfy_io.Schema(
|
||||||
# Remove image-specific parameters
|
node_id="MoonvalleyTxt2VideoNode",
|
||||||
for param in ["image"]:
|
display_name="Moonvalley Marey Text to Video",
|
||||||
if param in input_types["optional"]:
|
category="api node/video/Moonvalley Marey",
|
||||||
del input_types["optional"][param]
|
description="",
|
||||||
return input_types
|
inputs=[
|
||||||
|
comfy_io.String.Input(
|
||||||
|
"prompt",
|
||||||
|
multiline=True,
|
||||||
|
),
|
||||||
|
comfy_io.String.Input(
|
||||||
|
"negative_prompt",
|
||||||
|
multiline=True,
|
||||||
|
default="<synthetic> <scene cut> gopro, bright, contrast, static, overexposed, vignette, "
|
||||||
|
"artifacts, still, noise, texture, scanlines, videogame, 360 camera, VR, transition, "
|
||||||
|
"flare, saturation, distorted, warped, wide angle, saturated, vibrant, glowing, "
|
||||||
|
"cross dissolve, cheesy, ugly hands, mutated hands, mutant, disfigured, extra fingers, "
|
||||||
|
"blown out, horrible, blurry, worst quality, bad, dissolve, melt, fade in, fade out, "
|
||||||
|
"wobbly, weird, low quality, plastic, stock footage, video camera, boring",
|
||||||
|
tooltip="Negative prompt text",
|
||||||
|
),
|
||||||
|
comfy_io.Combo.Input(
|
||||||
|
"resolution",
|
||||||
|
options=[
|
||||||
|
"16:9 (1920 x 1080)",
|
||||||
|
"9:16 (1080 x 1920)",
|
||||||
|
"1:1 (1152 x 1152)",
|
||||||
|
"4:3 (1536 x 1152)",
|
||||||
|
"3:4 (1152 x 1536)",
|
||||||
|
"21:9 (2560 x 1080)",
|
||||||
|
],
|
||||||
|
default="16:9 (1920 x 1080)",
|
||||||
|
tooltip="Resolution of the output video",
|
||||||
|
),
|
||||||
|
comfy_io.Float.Input(
|
||||||
|
"prompt_adherence",
|
||||||
|
default=10.0,
|
||||||
|
min=1.0,
|
||||||
|
max=20.0,
|
||||||
|
step=1.0,
|
||||||
|
tooltip="Guidance scale for generation control",
|
||||||
|
),
|
||||||
|
comfy_io.Int.Input(
|
||||||
|
"seed",
|
||||||
|
default=9,
|
||||||
|
min=0,
|
||||||
|
max=4294967295,
|
||||||
|
step=1,
|
||||||
|
display_mode=comfy_io.NumberDisplay.number,
|
||||||
|
tooltip="Random seed value",
|
||||||
|
),
|
||||||
|
comfy_io.Int.Input(
|
||||||
|
"steps",
|
||||||
|
default=100,
|
||||||
|
min=1,
|
||||||
|
max=100,
|
||||||
|
step=1,
|
||||||
|
tooltip="Inference steps",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
outputs=[comfy_io.Video.Output()],
|
||||||
|
hidden=[
|
||||||
|
comfy_io.Hidden.auth_token_comfy_org,
|
||||||
|
comfy_io.Hidden.api_key_comfy_org,
|
||||||
|
comfy_io.Hidden.unique_id,
|
||||||
|
],
|
||||||
|
is_api_node=True,
|
||||||
|
)
|
||||||
|
|
||||||
async def generate(
|
@classmethod
|
||||||
self, prompt, negative_prompt, unique_id: Optional[str] = None, **kwargs
|
async def execute(
|
||||||
):
|
cls,
|
||||||
|
prompt: str,
|
||||||
|
negative_prompt: str,
|
||||||
|
resolution: str,
|
||||||
|
prompt_adherence: float,
|
||||||
|
seed: int,
|
||||||
|
steps: int,
|
||||||
|
) -> comfy_io.NodeOutput:
|
||||||
validate_prompts(prompt, negative_prompt, MOONVALLEY_MAREY_MAX_PROMPT_LENGTH)
|
validate_prompts(prompt, negative_prompt, MOONVALLEY_MAREY_MAX_PROMPT_LENGTH)
|
||||||
width_height = self.parseWidthHeightFromRes(kwargs.get("resolution"))
|
width_height = parse_width_height_from_res(resolution)
|
||||||
|
|
||||||
|
auth = {
|
||||||
|
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||||
|
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||||
|
}
|
||||||
|
|
||||||
inference_params = MoonvalleyTextToVideoInferenceParams(
|
inference_params = MoonvalleyTextToVideoInferenceParams(
|
||||||
negative_prompt=negative_prompt,
|
negative_prompt=negative_prompt,
|
||||||
steps=kwargs.get("steps"),
|
steps=steps,
|
||||||
seed=kwargs.get("seed"),
|
seed=seed,
|
||||||
guidance_scale=kwargs.get("prompt_adherence"),
|
guidance_scale=prompt_adherence,
|
||||||
num_frames=128,
|
num_frames=128,
|
||||||
width=width_height.get("width"),
|
width=width_height["width"],
|
||||||
height=width_height.get("height"),
|
height=width_height["height"],
|
||||||
)
|
)
|
||||||
request = MoonvalleyTextToVideoRequest(
|
request = MoonvalleyTextToVideoRequest(
|
||||||
prompt_text=prompt, inference_params=inference_params
|
prompt_text=prompt, inference_params=inference_params
|
||||||
)
|
)
|
||||||
|
|
||||||
initial_operation = SynchronousOperation(
|
init_op = SynchronousOperation(
|
||||||
endpoint=ApiEndpoint(
|
endpoint=ApiEndpoint(
|
||||||
path=API_TXT2VIDEO_ENDPOINT,
|
path=API_TXT2VIDEO_ENDPOINT,
|
||||||
method=HttpMethod.POST,
|
method=HttpMethod.POST,
|
||||||
@ -769,29 +793,29 @@ class MoonvalleyTxt2VideoNode(BaseMoonvalleyVideoNode):
|
|||||||
response_model=MoonvalleyPromptResponse,
|
response_model=MoonvalleyPromptResponse,
|
||||||
),
|
),
|
||||||
request=request,
|
request=request,
|
||||||
auth_kwargs=kwargs,
|
auth_kwargs=auth,
|
||||||
)
|
)
|
||||||
task_creation_response = await initial_operation.execute()
|
task_creation_response = await init_op.execute()
|
||||||
validate_task_creation_response(task_creation_response)
|
validate_task_creation_response(task_creation_response)
|
||||||
task_id = task_creation_response.id
|
task_id = task_creation_response.id
|
||||||
|
|
||||||
final_response = await self.get_response(
|
final_response = await get_response(
|
||||||
task_id, auth_kwargs=kwargs, node_id=unique_id
|
task_id, auth_kwargs=auth, node_id=cls.hidden.unique_id
|
||||||
)
|
)
|
||||||
|
|
||||||
video = await download_url_to_video_output(final_response.output_url)
|
video = await download_url_to_video_output(final_response.output_url)
|
||||||
return (video,)
|
return comfy_io.NodeOutput(video)
|
||||||
|
|
||||||
|
|
||||||
NODE_CLASS_MAPPINGS = {
|
class MoonvalleyExtension(ComfyExtension):
|
||||||
"MoonvalleyImg2VideoNode": MoonvalleyImg2VideoNode,
|
@override
|
||||||
"MoonvalleyTxt2VideoNode": MoonvalleyTxt2VideoNode,
|
async def get_node_list(self) -> list[type[comfy_io.ComfyNode]]:
|
||||||
"MoonvalleyVideo2VideoNode": MoonvalleyVideo2VideoNode,
|
return [
|
||||||
}
|
MoonvalleyImg2VideoNode,
|
||||||
|
MoonvalleyTxt2VideoNode,
|
||||||
|
MoonvalleyVideo2VideoNode,
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
async def comfy_entrypoint() -> MoonvalleyExtension:
|
||||||
"MoonvalleyImg2VideoNode": "Moonvalley Marey Image to Video",
|
return MoonvalleyExtension()
|
||||||
"MoonvalleyTxt2VideoNode": "Moonvalley Marey Text to Video",
|
|
||||||
"MoonvalleyVideo2VideoNode": "Moonvalley Marey Video to Video",
|
|
||||||
}
|
|
||||||
|
|||||||
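All three Moonvalley nodes in the hunks above follow the same submit-poll-download flow. The condensed sketch below is illustrative only: it reuses the helper names visible in the diff (SynchronousOperation, ApiEndpoint, get_response, download_url_to_video_output), while the endpoint constant and the request_model keyword are assumptions not confirmed by this diff.

# Illustrative sketch of the shared Moonvalley task flow; helper names are the
# ones used in the nodes above, the endpoint and request_model kwarg are assumptions.
async def run_moonvalley_task(request, auth, node_id):
    initial_operation = SynchronousOperation(
        endpoint=ApiEndpoint(
            path=API_TXT2VIDEO_ENDPOINT,            # assumption: each node picks its own endpoint
            method=HttpMethod.POST,
            request_model=type(request),            # assumption: kwarg name
            response_model=MoonvalleyPromptResponse,
        ),
        request=request,
        auth_kwargs=auth,
    )
    task_creation_response = await initial_operation.execute()
    validate_task_creation_response(task_creation_response)

    # Poll until the task finishes, then fetch the generated video.
    final_response = await get_response(
        task_creation_response.id, auth_kwargs=auth, node_id=node_id
    )
    return await download_url_to_video_output(final_response.output_url)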
@ -2,12 +2,12 @@ import nodes
|
|||||||
import torch
|
import torch
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from einops import rearrange
|
from einops import rearrange
|
||||||
|
from typing_extensions import override
|
||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
|
|
||||||
|
from comfy_api.latest import ComfyExtension, io
|
||||||
|
|
||||||
|
|
||||||
MAX_RESOLUTION = nodes.MAX_RESOLUTION
|
|
||||||
|
|
||||||
CAMERA_DICT = {
|
CAMERA_DICT = {
|
||||||
"base_T_norm": 1.5,
|
"base_T_norm": 1.5,
|
||||||
"base_angle": np.pi/3,
|
"base_angle": np.pi/3,
|
||||||
@ -148,32 +148,47 @@ def get_camera_motion(angle, T, speed, n=81):
|
|||||||
RT = np.stack(RT)
|
RT = np.stack(RT)
|
||||||
return RT
|
return RT
|
||||||
|
|
||||||
class WanCameraEmbedding:
|
class WanCameraEmbedding(io.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(cls):
|
def define_schema(cls):
|
||||||
return {
|
return io.Schema(
|
||||||
"required": {
|
node_id="WanCameraEmbedding",
|
||||||
"camera_pose":(["Static","Pan Up","Pan Down","Pan Left","Pan Right","Zoom In","Zoom Out","Anti Clockwise (ACW)", "ClockWise (CW)"],{"default":"Static"}),
|
category="camera",
|
||||||
"width": ("INT", {"default": 832, "min": 16, "max": MAX_RESOLUTION, "step": 16}),
|
inputs=[
|
||||||
"height": ("INT", {"default": 480, "min": 16, "max": MAX_RESOLUTION, "step": 16}),
|
io.Combo.Input(
|
||||||
"length": ("INT", {"default": 81, "min": 1, "max": MAX_RESOLUTION, "step": 4}),
|
"camera_pose",
|
||||||
},
|
options=[
|
||||||
"optional":{
|
"Static",
|
||||||
"speed":("FLOAT",{"default":1.0, "min": 0, "max": 10.0, "step": 0.1}),
|
"Pan Up",
|
||||||
"fx":("FLOAT",{"default":0.5, "min": 0, "max": 1, "step": 0.000000001}),
|
"Pan Down",
|
||||||
"fy":("FLOAT",{"default":0.5, "min": 0, "max": 1, "step": 0.000000001}),
|
"Pan Left",
|
||||||
"cx":("FLOAT",{"default":0.5, "min": 0, "max": 1, "step": 0.01}),
|
"Pan Right",
|
||||||
"cy":("FLOAT",{"default":0.5, "min": 0, "max": 1, "step": 0.01}),
|
"Zoom In",
|
||||||
}
|
"Zoom Out",
|
||||||
|
"Anti Clockwise (ACW)",
|
||||||
|
"ClockWise (CW)",
|
||||||
|
],
|
||||||
|
default="Static",
|
||||||
|
),
|
||||||
|
io.Int.Input("width", default=832, min=16, max=nodes.MAX_RESOLUTION, step=16),
|
||||||
|
io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16),
|
||||||
|
io.Int.Input("length", default=81, min=1, max=nodes.MAX_RESOLUTION, step=4),
|
||||||
|
io.Float.Input("speed", default=1.0, min=0, max=10.0, step=0.1, optional=True),
|
||||||
|
io.Float.Input("fx", default=0.5, min=0, max=1, step=0.000000001, optional=True),
|
||||||
|
io.Float.Input("fy", default=0.5, min=0, max=1, step=0.000000001, optional=True),
|
||||||
|
io.Float.Input("cx", default=0.5, min=0, max=1, step=0.01, optional=True),
|
||||||
|
io.Float.Input("cy", default=0.5, min=0, max=1, step=0.01, optional=True),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.WanCameraEmbedding.Output(display_name="camera_embedding"),
|
||||||
|
io.Int.Output(display_name="width"),
|
||||||
|
io.Int.Output(display_name="height"),
|
||||||
|
io.Int.Output(display_name="length"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
}
|
@classmethod
|
||||||
|
def execute(cls, camera_pose, width, height, length, speed=1.0, fx=0.5, fy=0.5, cx=0.5, cy=0.5) -> io.NodeOutput:
|
||||||
RETURN_TYPES = ("WAN_CAMERA_EMBEDDING","INT","INT","INT")
|
|
||||||
RETURN_NAMES = ("camera_embedding","width","height","length")
|
|
||||||
FUNCTION = "run"
|
|
||||||
CATEGORY = "camera"
|
|
||||||
|
|
||||||
def run(self, camera_pose, width, height, length, speed=1.0, fx=0.5, fy=0.5, cx=0.5, cy=0.5):
|
|
||||||
"""
|
"""
|
||||||
Use Camera trajectory as extrinsic parameters to calculate Plücker embeddings (Sitzmann et al., 2021)
|
Use Camera trajectory as extrinsic parameters to calculate Plücker embeddings (Sitzmann et al., 2021)
|
||||||
Adapted from https://github.com/aigc-apps/VideoX-Fun/blob/main/comfyui/comfyui_nodes.py
|
Adapted from https://github.com/aigc-apps/VideoX-Fun/blob/main/comfyui/comfyui_nodes.py
|
||||||
@ -210,9 +225,15 @@ class WanCameraEmbedding:
|
|||||||
control_camera_video = control_camera_video.contiguous().view(b, f // 4, 4, c, h, w).transpose(2, 3)
|
control_camera_video = control_camera_video.contiguous().view(b, f // 4, 4, c, h, w).transpose(2, 3)
|
||||||
control_camera_video = control_camera_video.contiguous().view(b, f // 4, c * 4, h, w).transpose(1, 2)
|
control_camera_video = control_camera_video.contiguous().view(b, f // 4, c * 4, h, w).transpose(1, 2)
|
||||||
|
|
||||||
return (control_camera_video, width, height, length)
|
return io.NodeOutput(control_camera_video, width, height, length)
|
||||||
|
|
||||||
|
|
||||||
NODE_CLASS_MAPPINGS = {
|
class CameraTrajectoryExtension(ComfyExtension):
|
||||||
"WanCameraEmbedding": WanCameraEmbedding,
|
@override
|
||||||
}
|
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||||
|
return [
|
||||||
|
WanCameraEmbedding,
|
||||||
|
]
|
||||||
|
|
||||||
|
async def comfy_entrypoint() -> CameraTrajectoryExtension:
|
||||||
|
return CameraTrajectoryExtension()
|
||||||
|
|||||||
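For context on the docstring above: a Plücker embedding assigns each pixel a 6-vector made of its ray direction and the ray's moment about the world origin. The sketch below shows only that general definition, with an assumed pinhole intrinsic built from the node's normalized fx/fy/cx/cy defaults and 832x480 size; the node's elided body may differ in detail.

# General-definition sketch (not the node's exact implementation) of the
# Plücker ray embedding referenced above. K, R, t are illustrative assumptions.
import torch

def plucker_embedding(K, R, t, height, width):
    # Pixel grid in homogeneous coordinates.
    v, u = torch.meshgrid(torch.arange(height), torch.arange(width), indexing="ij")
    pix = torch.stack([u, v, torch.ones_like(u)], dim=-1).float()       # [H, W, 3]
    # Back-project to world-space ray directions: d = R^T K^-1 [u, v, 1].
    dirs = pix @ torch.inverse(K).T @ R                                  # [H, W, 3]
    dirs = dirs / dirs.norm(dim=-1, keepdim=True)
    # Camera center in world space for a world-to-camera [R|t].
    origin = (-R.T @ t).expand_as(dirs)                                  # [H, W, 3]
    moment = torch.cross(origin, dirs, dim=-1)
    return torch.cat([dirs, moment], dim=-1)                             # [H, W, 6]

K = torch.tensor([[0.5 * 832, 0.0, 0.5 * 832],     # fx*W, cx*W from the 0.5 defaults
                  [0.0, 0.5 * 480, 0.5 * 480],     # fy*H, cy*H
                  [0.0, 0.0, 1.0]])
emb = plucker_embedding(K, torch.eye(3), torch.zeros(3), 480, 832)
print(emb.shape)  # torch.Size([480, 832, 6])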
@ -1,25 +1,41 @@
|
|||||||
from kornia.filters import canny
|
from kornia.filters import canny
|
||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
|
from comfy_api.latest import ComfyExtension, io
|
||||||
|
|
||||||
|
|
||||||
class Canny:
|
class Canny(io.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls):
|
||||||
return {"required": {"image": ("IMAGE",),
|
return io.Schema(
|
||||||
"low_threshold": ("FLOAT", {"default": 0.4, "min": 0.01, "max": 0.99, "step": 0.01}),
|
node_id="Canny",
|
||||||
"high_threshold": ("FLOAT", {"default": 0.8, "min": 0.01, "max": 0.99, "step": 0.01})
|
category="image/preprocessors",
|
||||||
}}
|
inputs=[
|
||||||
|
io.Image.Input("image"),
|
||||||
|
io.Float.Input("low_threshold", default=0.4, min=0.01, max=0.99, step=0.01),
|
||||||
|
io.Float.Input("high_threshold", default=0.8, min=0.01, max=0.99, step=0.01),
|
||||||
|
],
|
||||||
|
outputs=[io.Image.Output()],
|
||||||
|
)
|
||||||
|
|
||||||
RETURN_TYPES = ("IMAGE",)
|
@classmethod
|
||||||
FUNCTION = "detect_edge"
|
def detect_edge(cls, image, low_threshold, high_threshold):
|
||||||
|
# Deprecated: use the V3 schema's `execute` method instead of this.
|
||||||
|
return cls.execute(image, low_threshold, high_threshold)
|
||||||
|
|
||||||
CATEGORY = "image/preprocessors"
|
@classmethod
|
||||||
|
def execute(cls, image, low_threshold, high_threshold) -> io.NodeOutput:
|
||||||
def detect_edge(self, image, low_threshold, high_threshold):
|
|
||||||
output = canny(image.to(comfy.model_management.get_torch_device()).movedim(-1, 1), low_threshold, high_threshold)
|
output = canny(image.to(comfy.model_management.get_torch_device()).movedim(-1, 1), low_threshold, high_threshold)
|
||||||
img_out = output[1].to(comfy.model_management.intermediate_device()).repeat(1, 3, 1, 1).movedim(1, -1)
|
img_out = output[1].to(comfy.model_management.intermediate_device()).repeat(1, 3, 1, 1).movedim(1, -1)
|
||||||
return (img_out,)
|
return io.NodeOutput(img_out)
|
||||||
|
|
||||||
NODE_CLASS_MAPPINGS = {
|
|
||||||
"Canny": Canny,
|
class CannyExtension(ComfyExtension):
|
||||||
}
|
@override
|
||||||
|
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||||
|
return [Canny]
|
||||||
|
|
||||||
|
|
||||||
|
async def comfy_entrypoint() -> CannyExtension:
|
||||||
|
return CannyExtension()
|
||||||
|
|||||||
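A small standalone check of the layout handling the Canny node above relies on: ComfyUI images are NHWC floats in [0, 1], while kornia expects NCHW, and kornia.filters.canny returns (magnitudes, edges), of which the node keeps the edges. The dummy input below is an assumption for illustration.

import torch
from kornia.filters import canny

image = torch.rand(1, 512, 512, 3)                   # assumed NHWC dummy batch
magnitudes, edges = canny(image.movedim(-1, 1),      # NHWC -> NCHW for kornia
                          low_threshold=0.4, high_threshold=0.8)
img_out = edges.repeat(1, 3, 1, 1).movedim(1, -1)    # single-channel edges -> NHWC RGB
print(img_out.shape)                                 # torch.Size([1, 512, 512, 3])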
@ -1,5 +1,10 @@
|
|||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from comfy_api.latest import ComfyExtension, io
|
||||||
|
|
||||||
|
|
||||||
# https://github.com/WeichenFan/CFG-Zero-star
|
# https://github.com/WeichenFan/CFG-Zero-star
|
||||||
def optimized_scale(positive, negative):
|
def optimized_scale(positive, negative):
|
||||||
positive_flat = positive.reshape(positive.shape[0], -1)
|
positive_flat = positive.reshape(positive.shape[0], -1)
|
||||||
@ -16,17 +21,20 @@ def optimized_scale(positive, negative):
|
|||||||
|
|
||||||
return st_star.reshape([positive.shape[0]] + [1] * (positive.ndim - 1))
|
return st_star.reshape([positive.shape[0]] + [1] * (positive.ndim - 1))
|
||||||
|
|
||||||
class CFGZeroStar:
|
class CFGZeroStar(io.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls) -> io.Schema:
|
||||||
return {"required": {"model": ("MODEL",),
|
return io.Schema(
|
||||||
}}
|
node_id="CFGZeroStar",
|
||||||
RETURN_TYPES = ("MODEL",)
|
category="advanced/guidance",
|
||||||
RETURN_NAMES = ("patched_model",)
|
inputs=[
|
||||||
FUNCTION = "patch"
|
io.Model.Input("model"),
|
||||||
CATEGORY = "advanced/guidance"
|
],
|
||||||
|
outputs=[io.Model.Output(display_name="patched_model")],
|
||||||
|
)
|
||||||
|
|
||||||
def patch(self, model):
|
@classmethod
|
||||||
|
def execute(cls, model) -> io.NodeOutput:
|
||||||
m = model.clone()
|
m = model.clone()
|
||||||
def cfg_zero_star(args):
|
def cfg_zero_star(args):
|
||||||
guidance_scale = args['cond_scale']
|
guidance_scale = args['cond_scale']
|
||||||
@ -38,21 +46,24 @@ class CFGZeroStar:
|
|||||||
|
|
||||||
return out + uncond_p * (alpha - 1.0) + guidance_scale * uncond_p * (1.0 - alpha)
|
return out + uncond_p * (alpha - 1.0) + guidance_scale * uncond_p * (1.0 - alpha)
|
||||||
m.set_model_sampler_post_cfg_function(cfg_zero_star)
|
m.set_model_sampler_post_cfg_function(cfg_zero_star)
|
||||||
return (m, )
|
return io.NodeOutput(m)
|
||||||
|
|
||||||
class CFGNorm:
|
class CFGNorm(io.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls) -> io.Schema:
|
||||||
return {"required": {"model": ("MODEL",),
|
return io.Schema(
|
||||||
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step": 0.01}),
|
node_id="CFGNorm",
|
||||||
}}
|
category="advanced/guidance",
|
||||||
RETURN_TYPES = ("MODEL",)
|
inputs=[
|
||||||
RETURN_NAMES = ("patched_model",)
|
io.Model.Input("model"),
|
||||||
FUNCTION = "patch"
|
io.Float.Input("strength", default=1.0, min=0.0, max=100.0, step=0.01),
|
||||||
CATEGORY = "advanced/guidance"
|
],
|
||||||
EXPERIMENTAL = True
|
outputs=[io.Model.Output(display_name="patched_model")],
|
||||||
|
is_experimental=True,
|
||||||
|
)
|
||||||
|
|
||||||
def patch(self, model, strength):
|
@classmethod
|
||||||
|
def execute(cls, model, strength) -> io.NodeOutput:
|
||||||
m = model.clone()
|
m = model.clone()
|
||||||
def cfg_norm(args):
|
def cfg_norm(args):
|
||||||
cond_p = args['cond_denoised']
|
cond_p = args['cond_denoised']
|
||||||
@ -64,9 +75,17 @@ class CFGNorm:
|
|||||||
return pred_text_ * scale * strength
|
return pred_text_ * scale * strength
|
||||||
|
|
||||||
m.set_model_sampler_post_cfg_function(cfg_norm)
|
m.set_model_sampler_post_cfg_function(cfg_norm)
|
||||||
return (m, )
|
return io.NodeOutput(m)
|
||||||
|
|
||||||
NODE_CLASS_MAPPINGS = {
|
|
||||||
"CFGZeroStar": CFGZeroStar,
|
class CfgExtension(ComfyExtension):
|
||||||
"CFGNorm": CFGNorm,
|
@override
|
||||||
}
|
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||||
|
return [
|
||||||
|
CFGZeroStar,
|
||||||
|
CFGNorm,
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
async def comfy_entrypoint() -> CfgExtension:
|
||||||
|
return CfgExtension()
|
||||||
|
|||||||
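The hunk above elides the middle of optimized_scale(). For readers of the CFG-Zero* change, this is the standard definition it implements: the per-sample projection coefficient of the conditional prediction onto the unconditional one. Only the first and last lines are confirmed by the diff; the epsilon term below is an assumption.

import torch

def optimized_scale(positive, negative):
    positive_flat = positive.reshape(positive.shape[0], -1)
    negative_flat = negative.reshape(negative.shape[0], -1)
    dot_product = torch.sum(positive_flat * negative_flat, dim=1, keepdim=True)
    squared_norm = torch.sum(negative_flat ** 2, dim=1, keepdim=True) + 1e-8  # eps assumed
    st_star = dot_product / squared_norm           # <cond, uncond> / ||uncond||^2
    return st_star.reshape([positive.shape[0]] + [1] * (positive.ndim - 1))

cond, uncond = torch.randn(2, 4, 8, 8), torch.randn(2, 4, 8, 8)
print(optimized_scale(cond, uncond).shape)         # torch.Size([2, 1, 1, 1])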
comfy_extras/nodes_chroma_radiance.py (new file, 114 lines)
@ -0,0 +1,114 @@
|
|||||||
|
from typing_extensions import override
|
||||||
|
from typing import Callable
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
import comfy.model_management
|
||||||
|
from comfy_api.latest import ComfyExtension, io
|
||||||
|
|
||||||
|
import nodes
|
||||||
|
|
||||||
|
class EmptyChromaRadianceLatentImage(io.ComfyNode):
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls) -> io.Schema:
|
||||||
|
return io.Schema(
|
||||||
|
node_id="EmptyChromaRadianceLatentImage",
|
||||||
|
category="latent/chroma_radiance",
|
||||||
|
inputs=[
|
||||||
|
io.Int.Input(id="width", default=1024, min=16, max=nodes.MAX_RESOLUTION, step=16),
|
||||||
|
io.Int.Input(id="height", default=1024, min=16, max=nodes.MAX_RESOLUTION, step=16),
|
||||||
|
io.Int.Input(id="batch_size", default=1, min=1, max=4096),
|
||||||
|
],
|
||||||
|
outputs=[io.Latent().Output()],
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def execute(cls, *, width: int, height: int, batch_size: int=1) -> io.NodeOutput:
|
||||||
|
latent = torch.zeros((batch_size, 3, height, width), device=comfy.model_management.intermediate_device())
|
||||||
|
return io.NodeOutput({"samples":latent})
|
||||||
|
|
||||||
|
|
||||||
|
class ChromaRadianceOptions(io.ComfyNode):
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls) -> io.Schema:
|
||||||
|
return io.Schema(
|
||||||
|
node_id="ChromaRadianceOptions",
|
||||||
|
category="model_patches/chroma_radiance",
|
||||||
|
description="Allows setting advanced options for the Chroma Radiance model.",
|
||||||
|
inputs=[
|
||||||
|
io.Model.Input(id="model"),
|
||||||
|
io.Boolean.Input(
|
||||||
|
id="preserve_wrapper",
|
||||||
|
default=True,
|
||||||
|
tooltip="When enabled, will delegate to an existing model function wrapper if it exists. Generally should be left enabled.",
|
||||||
|
),
|
||||||
|
io.Float.Input(
|
||||||
|
id="start_sigma",
|
||||||
|
default=1.0,
|
||||||
|
min=0.0,
|
||||||
|
max=1.0,
|
||||||
|
tooltip="First sigma that these options will be in effect.",
|
||||||
|
),
|
||||||
|
io.Float.Input(
|
||||||
|
id="end_sigma",
|
||||||
|
default=0.0,
|
||||||
|
min=0.0,
|
||||||
|
max=1.0,
|
||||||
|
tooltip="Last sigma that these options will be in effect.",
|
||||||
|
),
|
||||||
|
io.Int.Input(
|
||||||
|
id="nerf_tile_size",
|
||||||
|
default=-1,
|
||||||
|
min=-1,
|
||||||
|
tooltip="Allows overriding the default NeRF tile size. -1 means use the default (32). 0 means use non-tiling mode (may require a lot of VRAM).",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
outputs=[io.Model.Output()],
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def execute(
|
||||||
|
cls,
|
||||||
|
*,
|
||||||
|
model: io.Model.Type,
|
||||||
|
preserve_wrapper: bool,
|
||||||
|
start_sigma: float,
|
||||||
|
end_sigma: float,
|
||||||
|
nerf_tile_size: int,
|
||||||
|
) -> io.NodeOutput:
|
||||||
|
radiance_options = {}
|
||||||
|
if nerf_tile_size >= 0:
|
||||||
|
radiance_options["nerf_tile_size"] = nerf_tile_size
|
||||||
|
|
||||||
|
if not radiance_options:
|
||||||
|
return io.NodeOutput(model)
|
||||||
|
|
||||||
|
old_wrapper = model.model_options.get("model_function_wrapper")
|
||||||
|
|
||||||
|
def model_function_wrapper(apply_model: Callable, args: dict) -> torch.Tensor:
|
||||||
|
c = args["c"].copy()
|
||||||
|
sigma = args["timestep"].max().detach().cpu().item()
|
||||||
|
if end_sigma <= sigma <= start_sigma:
|
||||||
|
transformer_options = c.get("transformer_options", {}).copy()
|
||||||
|
transformer_options["chroma_radiance_options"] = radiance_options.copy()
|
||||||
|
c["transformer_options"] = transformer_options
|
||||||
|
if not (preserve_wrapper and old_wrapper):
|
||||||
|
return apply_model(args["input"], args["timestep"], **c)
|
||||||
|
return old_wrapper(apply_model, args | {"c": c})
|
||||||
|
|
||||||
|
model = model.clone()
|
||||||
|
model.set_model_unet_function_wrapper(model_function_wrapper)
|
||||||
|
return io.NodeOutput(model)
|
||||||
|
|
||||||
|
|
||||||
|
class ChromaRadianceExtension(ComfyExtension):
|
||||||
|
@override
|
||||||
|
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||||
|
return [
|
||||||
|
EmptyChromaRadianceLatentImage,
|
||||||
|
ChromaRadianceOptions,
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
async def comfy_entrypoint() -> ChromaRadianceExtension:
|
||||||
|
return ChromaRadianceExtension()
|
||||||
@ -1,15 +1,25 @@
|
|||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
|
from comfy_api.latest import ComfyExtension, io
|
||||||
|
|
||||||
|
|
||||||
class CLIPTextEncodeControlnet:
|
class CLIPTextEncodeControlnet(io.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls) -> io.Schema:
|
||||||
return {"required": {"clip": ("CLIP", ), "conditioning": ("CONDITIONING", ), "text": ("STRING", {"multiline": True, "dynamicPrompts": True})}}
|
return io.Schema(
|
||||||
RETURN_TYPES = ("CONDITIONING",)
|
node_id="CLIPTextEncodeControlnet",
|
||||||
FUNCTION = "encode"
|
category="_for_testing/conditioning",
|
||||||
|
inputs=[
|
||||||
|
io.Clip.Input("clip"),
|
||||||
|
io.Conditioning.Input("conditioning"),
|
||||||
|
io.String.Input("text", multiline=True, dynamic_prompts=True),
|
||||||
|
],
|
||||||
|
outputs=[io.Conditioning.Output()],
|
||||||
|
is_experimental=True,
|
||||||
|
)
|
||||||
|
|
||||||
CATEGORY = "_for_testing/conditioning"
|
@classmethod
|
||||||
|
def execute(cls, clip, conditioning, text) -> io.NodeOutput:
|
||||||
def encode(self, clip, conditioning, text):
|
|
||||||
tokens = clip.tokenize(text)
|
tokens = clip.tokenize(text)
|
||||||
cond, pooled = clip.encode_from_tokens(tokens, return_pooled=True)
|
cond, pooled = clip.encode_from_tokens(tokens, return_pooled=True)
|
||||||
c = []
|
c = []
|
||||||
@ -18,32 +28,41 @@ class CLIPTextEncodeControlnet:
|
|||||||
n[1]['cross_attn_controlnet'] = cond
|
n[1]['cross_attn_controlnet'] = cond
|
||||||
n[1]['pooled_output_controlnet'] = pooled
|
n[1]['pooled_output_controlnet'] = pooled
|
||||||
c.append(n)
|
c.append(n)
|
||||||
return (c, )
|
return io.NodeOutput(c)
|
||||||
|
|
||||||
class T5TokenizerOptions:
|
class T5TokenizerOptions(io.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls) -> io.Schema:
|
||||||
return {
|
return io.Schema(
|
||||||
"required": {
|
node_id="T5TokenizerOptions",
|
||||||
"clip": ("CLIP", ),
|
category="_for_testing/conditioning",
|
||||||
"min_padding": ("INT", {"default": 0, "min": 0, "max": 10000, "step": 1}),
|
inputs=[
|
||||||
"min_length": ("INT", {"default": 0, "min": 0, "max": 10000, "step": 1}),
|
io.Clip.Input("clip"),
|
||||||
}
|
io.Int.Input("min_padding", default=0, min=0, max=10000, step=1),
|
||||||
}
|
io.Int.Input("min_length", default=0, min=0, max=10000, step=1),
|
||||||
|
],
|
||||||
|
outputs=[io.Clip.Output()],
|
||||||
|
is_experimental=True,
|
||||||
|
)
|
||||||
|
|
||||||
CATEGORY = "_for_testing/conditioning"
|
@classmethod
|
||||||
RETURN_TYPES = ("CLIP",)
|
def execute(cls, clip, min_padding, min_length) -> io.NodeOutput:
|
||||||
FUNCTION = "set_options"
|
|
||||||
|
|
||||||
def set_options(self, clip, min_padding, min_length):
|
|
||||||
clip = clip.clone()
|
clip = clip.clone()
|
||||||
for t5_type in ["t5xxl", "pile_t5xl", "t5base", "mt5xl", "umt5xxl"]:
|
for t5_type in ["t5xxl", "pile_t5xl", "t5base", "mt5xl", "umt5xxl"]:
|
||||||
clip.set_tokenizer_option("{}_min_padding".format(t5_type), min_padding)
|
clip.set_tokenizer_option("{}_min_padding".format(t5_type), min_padding)
|
||||||
clip.set_tokenizer_option("{}_min_length".format(t5_type), min_length)
|
clip.set_tokenizer_option("{}_min_length".format(t5_type), min_length)
|
||||||
|
|
||||||
return (clip, )
|
return io.NodeOutput(clip)
|
||||||
|
|
||||||
NODE_CLASS_MAPPINGS = {
|
|
||||||
"CLIPTextEncodeControlnet": CLIPTextEncodeControlnet,
|
class CondExtension(ComfyExtension):
|
||||||
"T5TokenizerOptions": T5TokenizerOptions,
|
@override
|
||||||
}
|
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||||
|
return [
|
||||||
|
CLIPTextEncodeControlnet,
|
||||||
|
T5TokenizerOptions,
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
async def comfy_entrypoint() -> CondExtension:
|
||||||
|
return CondExtension()
|
||||||
|
|||||||
@ -1,25 +1,32 @@
|
|||||||
|
from typing_extensions import override
|
||||||
import nodes
|
import nodes
|
||||||
import torch
|
import torch
|
||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
import comfy.utils
|
import comfy.utils
|
||||||
import comfy.latent_formats
|
import comfy.latent_formats
|
||||||
|
|
||||||
|
from comfy_api.latest import ComfyExtension, io
|
||||||
|
|
||||||
class EmptyCosmosLatentVideo:
|
|
||||||
|
class EmptyCosmosLatentVideo(io.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls) -> io.Schema:
|
||||||
return {"required": { "width": ("INT", {"default": 1280, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
|
return io.Schema(
|
||||||
"height": ("INT", {"default": 704, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
|
node_id="EmptyCosmosLatentVideo",
|
||||||
"length": ("INT", {"default": 121, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 8}),
|
category="latent/video",
|
||||||
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096})}}
|
inputs=[
|
||||||
RETURN_TYPES = ("LATENT",)
|
io.Int.Input("width", default=1280, min=16, max=nodes.MAX_RESOLUTION, step=16),
|
||||||
FUNCTION = "generate"
|
io.Int.Input("height", default=704, min=16, max=nodes.MAX_RESOLUTION, step=16),
|
||||||
|
io.Int.Input("length", default=121, min=1, max=nodes.MAX_RESOLUTION, step=8),
|
||||||
|
io.Int.Input("batch_size", default=1, min=1, max=4096),
|
||||||
|
],
|
||||||
|
outputs=[io.Latent.Output()],
|
||||||
|
)
|
||||||
|
|
||||||
CATEGORY = "latent/video"
|
@classmethod
|
||||||
|
def execute(cls, width, height, length, batch_size=1) -> io.NodeOutput:
|
||||||
def generate(self, width, height, length, batch_size=1):
|
|
||||||
latent = torch.zeros([batch_size, 16, ((length - 1) // 8) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
|
latent = torch.zeros([batch_size, 16, ((length - 1) // 8) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
|
||||||
return ({"samples": latent}, )
|
return io.NodeOutput({"samples": latent})
|
||||||
|
|
||||||
|
|
||||||
def vae_encode_with_padding(vae, image, width, height, length, padding=0):
|
def vae_encode_with_padding(vae, image, width, height, length, padding=0):
|
||||||
@ -33,31 +40,31 @@ def vae_encode_with_padding(vae, image, width, height, length, padding=0):
|
|||||||
return latent_temp[:, :, :latent_len]
|
return latent_temp[:, :, :latent_len]
|
||||||
|
|
||||||
|
|
||||||
class CosmosImageToVideoLatent:
|
class CosmosImageToVideoLatent(io.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls) -> io.Schema:
|
||||||
return {"required": {"vae": ("VAE", ),
|
return io.Schema(
|
||||||
"width": ("INT", {"default": 1280, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
|
node_id="CosmosImageToVideoLatent",
|
||||||
"height": ("INT", {"default": 704, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
|
category="conditioning/inpaint",
|
||||||
"length": ("INT", {"default": 121, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 8}),
|
inputs=[
|
||||||
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
|
io.Vae.Input("vae"),
|
||||||
},
|
io.Int.Input("width", default=1280, min=16, max=nodes.MAX_RESOLUTION, step=16),
|
||||||
"optional": {"start_image": ("IMAGE", ),
|
io.Int.Input("height", default=704, min=16, max=nodes.MAX_RESOLUTION, step=16),
|
||||||
"end_image": ("IMAGE", ),
|
io.Int.Input("length", default=121, min=1, max=nodes.MAX_RESOLUTION, step=8),
|
||||||
}}
|
io.Int.Input("batch_size", default=1, min=1, max=4096),
|
||||||
|
io.Image.Input("start_image", optional=True),
|
||||||
|
io.Image.Input("end_image", optional=True),
|
||||||
|
],
|
||||||
|
outputs=[io.Latent.Output()],
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
RETURN_TYPES = ("LATENT",)
|
def execute(cls, vae, width, height, length, batch_size, start_image=None, end_image=None) -> io.NodeOutput:
|
||||||
FUNCTION = "encode"
|
|
||||||
|
|
||||||
CATEGORY = "conditioning/inpaint"
|
|
||||||
|
|
||||||
def encode(self, vae, width, height, length, batch_size, start_image=None, end_image=None):
|
|
||||||
latent = torch.zeros([1, 16, ((length - 1) // 8) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
|
latent = torch.zeros([1, 16, ((length - 1) // 8) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
|
||||||
if start_image is None and end_image is None:
|
if start_image is None and end_image is None:
|
||||||
out_latent = {}
|
out_latent = {}
|
||||||
out_latent["samples"] = latent
|
out_latent["samples"] = latent
|
||||||
return (out_latent,)
|
return io.NodeOutput(out_latent)
|
||||||
|
|
||||||
mask = torch.ones([latent.shape[0], 1, ((length - 1) // 8) + 1, latent.shape[-2], latent.shape[-1]], device=comfy.model_management.intermediate_device())
|
mask = torch.ones([latent.shape[0], 1, ((length - 1) // 8) + 1, latent.shape[-2], latent.shape[-1]], device=comfy.model_management.intermediate_device())
|
||||||
|
|
||||||
@ -74,33 +81,33 @@ class CosmosImageToVideoLatent:
|
|||||||
out_latent = {}
|
out_latent = {}
|
||||||
out_latent["samples"] = latent.repeat((batch_size, ) + (1,) * (latent.ndim - 1))
|
out_latent["samples"] = latent.repeat((batch_size, ) + (1,) * (latent.ndim - 1))
|
||||||
out_latent["noise_mask"] = mask.repeat((batch_size, ) + (1,) * (mask.ndim - 1))
|
out_latent["noise_mask"] = mask.repeat((batch_size, ) + (1,) * (mask.ndim - 1))
|
||||||
return (out_latent,)
|
return io.NodeOutput(out_latent)
|
||||||
|
|
||||||
class CosmosPredict2ImageToVideoLatent:
|
class CosmosPredict2ImageToVideoLatent(io.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls) -> io.Schema:
|
||||||
return {"required": {"vae": ("VAE", ),
|
return io.Schema(
|
||||||
"width": ("INT", {"default": 848, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
|
node_id="CosmosPredict2ImageToVideoLatent",
|
||||||
"height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
|
category="conditioning/inpaint",
|
||||||
"length": ("INT", {"default": 93, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
|
inputs=[
|
||||||
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
|
io.Vae.Input("vae"),
|
||||||
},
|
io.Int.Input("width", default=848, min=16, max=nodes.MAX_RESOLUTION, step=16),
|
||||||
"optional": {"start_image": ("IMAGE", ),
|
io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16),
|
||||||
"end_image": ("IMAGE", ),
|
io.Int.Input("length", default=93, min=1, max=nodes.MAX_RESOLUTION, step=4),
|
||||||
}}
|
io.Int.Input("batch_size", default=1, min=1, max=4096),
|
||||||
|
io.Image.Input("start_image", optional=True),
|
||||||
|
io.Image.Input("end_image", optional=True),
|
||||||
|
],
|
||||||
|
outputs=[io.Latent.Output()],
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
RETURN_TYPES = ("LATENT",)
|
def execute(cls, vae, width, height, length, batch_size, start_image=None, end_image=None) -> io.NodeOutput:
|
||||||
FUNCTION = "encode"
|
|
||||||
|
|
||||||
CATEGORY = "conditioning/inpaint"
|
|
||||||
|
|
||||||
def encode(self, vae, width, height, length, batch_size, start_image=None, end_image=None):
|
|
||||||
latent = torch.zeros([1, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
|
latent = torch.zeros([1, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
|
||||||
if start_image is None and end_image is None:
|
if start_image is None and end_image is None:
|
||||||
out_latent = {}
|
out_latent = {}
|
||||||
out_latent["samples"] = latent
|
out_latent["samples"] = latent
|
||||||
return (out_latent,)
|
return io.NodeOutput(out_latent)
|
||||||
|
|
||||||
mask = torch.ones([latent.shape[0], 1, ((length - 1) // 4) + 1, latent.shape[-2], latent.shape[-1]], device=comfy.model_management.intermediate_device())
|
mask = torch.ones([latent.shape[0], 1, ((length - 1) // 4) + 1, latent.shape[-2], latent.shape[-1]], device=comfy.model_management.intermediate_device())
|
||||||
|
|
||||||
@ -119,10 +126,18 @@ class CosmosPredict2ImageToVideoLatent:
|
|||||||
latent = latent_format.process_out(latent) * mask + latent * (1.0 - mask)
|
latent = latent_format.process_out(latent) * mask + latent * (1.0 - mask)
|
||||||
out_latent["samples"] = latent.repeat((batch_size, ) + (1,) * (latent.ndim - 1))
|
out_latent["samples"] = latent.repeat((batch_size, ) + (1,) * (latent.ndim - 1))
|
||||||
out_latent["noise_mask"] = mask.repeat((batch_size, ) + (1,) * (mask.ndim - 1))
|
out_latent["noise_mask"] = mask.repeat((batch_size, ) + (1,) * (mask.ndim - 1))
|
||||||
return (out_latent,)
|
return io.NodeOutput(out_latent)
|
||||||
|
|
||||||
NODE_CLASS_MAPPINGS = {
|
|
||||||
"EmptyCosmosLatentVideo": EmptyCosmosLatentVideo,
|
class CosmosExtension(ComfyExtension):
|
||||||
"CosmosImageToVideoLatent": CosmosImageToVideoLatent,
|
@override
|
||||||
"CosmosPredict2ImageToVideoLatent": CosmosPredict2ImageToVideoLatent,
|
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||||
}
|
return [
|
||||||
|
EmptyCosmosLatentVideo,
|
||||||
|
CosmosImageToVideoLatent,
|
||||||
|
CosmosPredict2ImageToVideoLatent,
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
async def comfy_entrypoint() -> CosmosExtension:
|
||||||
|
return CosmosExtension()
|
||||||
|
|||||||
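A quick arithmetic check of the temporal strides used by the two Cosmos latent nodes above, at their default lengths (reading the third latent dimension as "latent frames" is an assumption drawn from the tensor shapes):

length_cosmos, length_predict2 = 121, 93
print(((length_cosmos - 1) // 8) + 1)    # 16 latent frames (stride-8 temporal compression)
print(((length_predict2 - 1) // 4) + 1)  # 24 latent frames (stride-4 in the Predict2 variant)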
@ -128,6 +128,28 @@ class EmptyHunyuanImageLatent:
|
|||||||
latent = torch.zeros([batch_size, 64, height // 32, width // 32], device=comfy.model_management.intermediate_device())
|
latent = torch.zeros([batch_size, 64, height // 32, width // 32], device=comfy.model_management.intermediate_device())
|
||||||
return ({"samples":latent}, )
|
return ({"samples":latent}, )
|
||||||
|
|
||||||
|
class HunyuanRefinerLatent:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {"required": {"positive": ("CONDITIONING", ),
|
||||||
|
"negative": ("CONDITIONING", ),
|
||||||
|
"latent": ("LATENT", ),
|
||||||
|
"noise_augmentation": ("FLOAT", {"default": 0.10, "min": 0.0, "max": 1.0, "step": 0.01}),
|
||||||
|
}}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
|
||||||
|
RETURN_NAMES = ("positive", "negative", "latent")
|
||||||
|
|
||||||
|
FUNCTION = "execute"
|
||||||
|
|
||||||
|
def execute(self, positive, negative, latent, noise_augmentation):
|
||||||
|
latent = latent["samples"]
|
||||||
|
positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": latent, "noise_augmentation": noise_augmentation})
|
||||||
|
negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": latent, "noise_augmentation": noise_augmentation})
|
||||||
|
out_latent = {}
|
||||||
|
out_latent["samples"] = torch.zeros([latent.shape[0], 32, latent.shape[-3], latent.shape[-2], latent.shape[-1]], device=comfy.model_management.intermediate_device())
|
||||||
|
return (positive, negative, out_latent)
|
||||||
|
|
||||||
|
|
||||||
NODE_CLASS_MAPPINGS = {
|
NODE_CLASS_MAPPINGS = {
|
||||||
"CLIPTextEncodeHunyuanDiT": CLIPTextEncodeHunyuanDiT,
|
"CLIPTextEncodeHunyuanDiT": CLIPTextEncodeHunyuanDiT,
|
||||||
@ -135,4 +157,5 @@ NODE_CLASS_MAPPINGS = {
|
|||||||
"EmptyHunyuanLatentVideo": EmptyHunyuanLatentVideo,
|
"EmptyHunyuanLatentVideo": EmptyHunyuanLatentVideo,
|
||||||
"HunyuanImageToVideo": HunyuanImageToVideo,
|
"HunyuanImageToVideo": HunyuanImageToVideo,
|
||||||
"EmptyHunyuanImageLatent": EmptyHunyuanImageLatent,
|
"EmptyHunyuanImageLatent": EmptyHunyuanImageLatent,
|
||||||
|
"HunyuanRefinerLatent": HunyuanRefinerLatent,
|
||||||
}
|
}
|
||||||
|
|||||||
@ -1015,6 +1015,103 @@ class WanSoundImageToVideoExtend(io.ComfyNode):
|
|||||||
return io.NodeOutput(positive, negative, out_latent)
|
return io.NodeOutput(positive, negative, out_latent)
|
||||||
|
|
||||||
|
|
||||||
|
def get_audio_emb_window(audio_emb, frame_num, frame0_idx, audio_shift=2):
|
||||||
|
zero_audio_embed = torch.zeros((audio_emb.shape[1], audio_emb.shape[2]), dtype=audio_emb.dtype, device=audio_emb.device)
|
||||||
|
zero_audio_embed_3 = torch.zeros((3, audio_emb.shape[1], audio_emb.shape[2]), dtype=audio_emb.dtype, device=audio_emb.device) # device=audio_emb.device
|
||||||
|
iter_ = 1 + (frame_num - 1) // 4
|
||||||
|
audio_emb_wind = []
|
||||||
|
for lt_i in range(iter_):
|
||||||
|
if lt_i == 0:
|
||||||
|
st = frame0_idx + lt_i - 2
|
||||||
|
ed = frame0_idx + lt_i + 3
|
||||||
|
wind_feat = torch.stack([
|
||||||
|
audio_emb[i] if (0 <= i < audio_emb.shape[0]) else zero_audio_embed
|
||||||
|
for i in range(st, ed)
|
||||||
|
], dim=0)
|
||||||
|
wind_feat = torch.cat((zero_audio_embed_3, wind_feat), dim=0)
|
||||||
|
else:
|
||||||
|
st = frame0_idx + 1 + 4 * (lt_i - 1) - audio_shift
|
||||||
|
ed = frame0_idx + 1 + 4 * lt_i + audio_shift
|
||||||
|
wind_feat = torch.stack([
|
||||||
|
audio_emb[i] if (0 <= i < audio_emb.shape[0]) else zero_audio_embed
|
||||||
|
for i in range(st, ed)
|
||||||
|
], dim=0)
|
||||||
|
audio_emb_wind.append(wind_feat)
|
||||||
|
audio_emb_wind = torch.stack(audio_emb_wind, dim=0)
|
||||||
|
|
||||||
|
return audio_emb_wind, ed - audio_shift
|
||||||
|
|
||||||
|
|
||||||
|
class WanHuMoImageToVideo(io.ComfyNode):
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls):
|
||||||
|
return io.Schema(
|
||||||
|
node_id="WanHuMoImageToVideo",
|
||||||
|
category="conditioning/video_models",
|
||||||
|
inputs=[
|
||||||
|
io.Conditioning.Input("positive"),
|
||||||
|
io.Conditioning.Input("negative"),
|
||||||
|
io.Vae.Input("vae"),
|
||||||
|
io.Int.Input("width", default=832, min=16, max=nodes.MAX_RESOLUTION, step=16),
|
||||||
|
io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16),
|
||||||
|
io.Int.Input("length", default=97, min=1, max=nodes.MAX_RESOLUTION, step=4),
|
||||||
|
io.Int.Input("batch_size", default=1, min=1, max=4096),
|
||||||
|
io.AudioEncoderOutput.Input("audio_encoder_output", optional=True),
|
||||||
|
io.Image.Input("ref_image", optional=True),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.Conditioning.Output(display_name="positive"),
|
||||||
|
io.Conditioning.Output(display_name="negative"),
|
||||||
|
io.Latent.Output(display_name="latent"),
|
||||||
|
],
|
||||||
|
is_experimental=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def execute(cls, positive, negative, vae, width, height, length, batch_size, ref_image=None, audio_encoder_output=None) -> io.NodeOutput:
|
||||||
|
latent_t = ((length - 1) // 4) + 1
|
||||||
|
latent = torch.zeros([batch_size, 16, latent_t, height // 8, width // 8], device=comfy.model_management.intermediate_device())
|
||||||
|
|
||||||
|
if ref_image is not None:
|
||||||
|
ref_image = comfy.utils.common_upscale(ref_image[:1].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
|
||||||
|
ref_latent = vae.encode(ref_image[:, :, :, :3])
|
||||||
|
positive = node_helpers.conditioning_set_values(positive, {"reference_latents": [ref_latent]}, append=True)
|
||||||
|
negative = node_helpers.conditioning_set_values(negative, {"reference_latents": [torch.zeros_like(ref_latent)]}, append=True)
|
||||||
|
else:
|
||||||
|
zero_latent = torch.zeros([batch_size, 16, 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
|
||||||
|
positive = node_helpers.conditioning_set_values(positive, {"reference_latents": [zero_latent]}, append=True)
|
||||||
|
negative = node_helpers.conditioning_set_values(negative, {"reference_latents": [zero_latent]}, append=True)
|
||||||
|
|
||||||
|
if audio_encoder_output is not None:
|
||||||
|
audio_emb = torch.stack(audio_encoder_output["encoded_audio_all_layers"], dim=2)
|
||||||
|
audio_len = audio_encoder_output["audio_samples"] // 640
|
||||||
|
audio_emb = audio_emb[:, :audio_len * 2]
|
||||||
|
|
||||||
|
feat0 = linear_interpolation(audio_emb[:, :, 0: 8].mean(dim=2), 50, 25)
|
||||||
|
feat1 = linear_interpolation(audio_emb[:, :, 8: 16].mean(dim=2), 50, 25)
|
||||||
|
feat2 = linear_interpolation(audio_emb[:, :, 16: 24].mean(dim=2), 50, 25)
|
||||||
|
feat3 = linear_interpolation(audio_emb[:, :, 24: 32].mean(dim=2), 50, 25)
|
||||||
|
feat4 = linear_interpolation(audio_emb[:, :, 32], 50, 25)
|
||||||
|
audio_emb = torch.stack([feat0, feat1, feat2, feat3, feat4], dim=2)[0] # [T, 5, 1280]
|
||||||
|
audio_emb, _ = get_audio_emb_window(audio_emb, length, frame0_idx=0)
|
||||||
|
|
||||||
|
# pad for ref latent
|
||||||
|
zero_audio_pad = torch.zeros(ref_latent.shape[2], *audio_emb.shape[1:], device=audio_emb.device, dtype=audio_emb.dtype)
|
||||||
|
audio_emb = torch.cat([audio_emb, zero_audio_pad], dim=0)
|
||||||
|
|
||||||
|
audio_emb = audio_emb.unsqueeze(0)
|
||||||
|
audio_emb_neg = torch.zeros_like(audio_emb)
|
||||||
|
positive = node_helpers.conditioning_set_values(positive, {"audio_embed": audio_emb})
|
||||||
|
negative = node_helpers.conditioning_set_values(negative, {"audio_embed": audio_emb_neg})
|
||||||
|
else:
|
||||||
|
zero_audio = torch.zeros([batch_size, latent_t + 1, 8, 5, 1280], device=comfy.model_management.intermediate_device())
|
||||||
|
positive = node_helpers.conditioning_set_values(positive, {"audio_embed": zero_audio})
|
||||||
|
negative = node_helpers.conditioning_set_values(negative, {"audio_embed": zero_audio})
|
||||||
|
|
||||||
|
out_latent = {}
|
||||||
|
out_latent["samples"] = latent
|
||||||
|
return io.NodeOutput(positive, negative, out_latent)
|
||||||
|
|
||||||
class Wan22ImageToVideoLatent(io.ComfyNode):
|
class Wan22ImageToVideoLatent(io.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def define_schema(cls):
|
def define_schema(cls):
|
||||||
@ -1075,6 +1172,7 @@ class WanExtension(ComfyExtension):
|
|||||||
WanPhantomSubjectToVideo,
|
WanPhantomSubjectToVideo,
|
||||||
WanSoundImageToVideo,
|
WanSoundImageToVideo,
|
||||||
WanSoundImageToVideoExtend,
|
WanSoundImageToVideoExtend,
|
||||||
|
WanHuMoImageToVideo,
|
||||||
Wan22ImageToVideoLatent,
|
Wan22ImageToVideoLatent,
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|||||||
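Worked example for get_audio_emb_window() in the WanHuMo hunk above: with the node's default length of 97 frames and audio_shift=2, every latent frame gets an 8-row window of audio features, the first one padded with three zero rows. The numbers are read off the code above; treating them as "rows per window" is the only interpretive assumption.

frame_num, audio_shift = 97, 2
latent_frames = 1 + (frame_num - 1) // 4   # one audio window per latent frame
first_window = 3 + 5                       # 3 zero rows + frames [f0-2, f0+3)
later_window = 4 + 2 * audio_shift         # 4-frame step widened by audio_shift on each side
print(latent_frames, first_window, later_window)   # 25 8 8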
nodes.py (7 lines changed)
@ -730,6 +730,7 @@ class VAELoader:
|
|||||||
vaes.append("taesd3")
|
vaes.append("taesd3")
|
||||||
if f1_taesd_dec and f1_taesd_enc:
|
if f1_taesd_dec and f1_taesd_enc:
|
||||||
vaes.append("taef1")
|
vaes.append("taef1")
|
||||||
|
vaes.append("pixel_space")
|
||||||
return vaes
|
return vaes
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@ -772,7 +773,10 @@ class VAELoader:
|
|||||||
|
|
||||||
#TODO: scale factor?
|
#TODO: scale factor?
|
||||||
def load_vae(self, vae_name):
|
def load_vae(self, vae_name):
|
||||||
if vae_name in ["taesd", "taesdxl", "taesd3", "taef1"]:
|
if vae_name == "pixel_space":
|
||||||
|
sd = {}
|
||||||
|
sd["pixel_space_vae"] = torch.tensor(1.0)
|
||||||
|
elif vae_name in ["taesd", "taesdxl", "taesd3", "taef1"]:
|
||||||
sd = self.load_taesd(vae_name)
|
sd = self.load_taesd(vae_name)
|
||||||
else:
|
else:
|
||||||
vae_path = folder_paths.get_full_path_or_raise("vae", vae_name)
|
vae_path = folder_paths.get_full_path_or_raise("vae", vae_name)
|
||||||
@ -2323,6 +2327,7 @@ async def init_builtin_extra_nodes():
|
|||||||
"nodes_tcfg.py",
|
"nodes_tcfg.py",
|
||||||
"nodes_context_windows.py",
|
"nodes_context_windows.py",
|
||||||
"nodes_qwen.py",
|
"nodes_qwen.py",
|
||||||
|
"nodes_chroma_radiance.py",
|
||||||
"nodes_model_patch.py",
|
"nodes_model_patch.py",
|
||||||
"nodes_easycache.py",
|
"nodes_easycache.py",
|
||||||
"nodes_audio_encoder.py",
|
"nodes_audio_encoder.py",
|
||||||
|
|||||||
@ -1,4 +1,4 @@
|
|||||||
comfyui-frontend-package==1.25.11
|
comfyui-frontend-package==1.26.11
|
||||||
comfyui-workflow-templates==0.1.81
|
comfyui-workflow-templates==0.1.81
|
||||||
comfyui-embedded-docs==0.2.6
|
comfyui-embedded-docs==0.2.6
|
||||||
torch
|
torch
|
||||||
|
|||||||