mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-03-16 14:47:33 +08:00
Merge 4f23ca8785 into 0904cc3fe5
This commit is contained in:
commit
c0b0601378
811
comfy/ldm/helios/model.py
Normal file
811
comfy/ldm/helios/model.py
Normal file
@ -0,0 +1,811 @@
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
from comfy.ldm.flux.layers import EmbedND
|
||||
from comfy.ldm.flux.math import apply_rope1
|
||||
from comfy.ldm.wan.model import sinusoidal_embedding_1d
|
||||
import comfy.ldm.common_dit
|
||||
import comfy.patcher_extension
|
||||
|
||||
|
||||
|
||||
def pad_for_3d_conv(x, kernel_size):
|
||||
b, c, t, h, w = x.shape
|
||||
pt, ph, pw = kernel_size
|
||||
pad_t = (pt - (t % pt)) % pt
|
||||
pad_h = (ph - (h % ph)) % ph
|
||||
pad_w = (pw - (w % pw)) % pw
|
||||
return torch.nn.functional.pad(x, (0, pad_w, 0, pad_h, 0, pad_t), mode="replicate")
|
||||
|
||||
|
||||
def center_down_sample_3d(x, kernel_size):
|
||||
return torch.nn.functional.avg_pool3d(x, kernel_size, stride=kernel_size)
|
||||
|
||||
|
||||
class OutputNorm(nn.Module):
|
||||
|
||||
def __init__(self, dim, eps=1e-6, operation_settings={}):
|
||||
super().__init__()
|
||||
self.scale_shift_table = nn.Parameter(torch.randn(
|
||||
1,
|
||||
2,
|
||||
dim,
|
||||
device=operation_settings.get("device"),
|
||||
dtype=operation_settings.get("dtype"),
|
||||
) / dim**0.5)
|
||||
self.norm = operation_settings.get("operations").LayerNorm(
|
||||
dim,
|
||||
eps,
|
||||
elementwise_affine=False,
|
||||
device=operation_settings.get("device"),
|
||||
dtype=operation_settings.get("dtype"),
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
temb: torch.Tensor,
|
||||
original_context_length: int,
|
||||
):
|
||||
temb = temb[:, -original_context_length:, :]
|
||||
shift, scale = (self.scale_shift_table.unsqueeze(0).to(temb.device) + temb.unsqueeze(2)).chunk(2, dim=2)
|
||||
shift = shift.squeeze(2).to(hidden_states.device)
|
||||
scale = scale.squeeze(2).to(hidden_states.device)
|
||||
hidden_states = hidden_states[:, -original_context_length:, :]
|
||||
# Use float32 for numerical stability like diffusers
|
||||
hidden_states = (self.norm(hidden_states.float()) * (1 + scale) + shift).type_as(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class HeliosSelfAttention(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
num_heads,
|
||||
qk_norm=True,
|
||||
eps=1e-6,
|
||||
is_cross_attention=False,
|
||||
is_amplify_history=False,
|
||||
history_scale_mode="per_head",
|
||||
operation_settings={},
|
||||
):
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = dim // num_heads
|
||||
self.is_cross_attention = is_cross_attention
|
||||
self.is_amplify_history = is_amplify_history
|
||||
|
||||
self.to_q = operation_settings.get("operations").Linear(
|
||||
dim,
|
||||
dim,
|
||||
bias=True,
|
||||
device=operation_settings.get("device"),
|
||||
dtype=operation_settings.get("dtype"),
|
||||
)
|
||||
self.to_k = operation_settings.get("operations").Linear(
|
||||
dim,
|
||||
dim,
|
||||
bias=True,
|
||||
device=operation_settings.get("device"),
|
||||
dtype=operation_settings.get("dtype"),
|
||||
)
|
||||
self.to_v = operation_settings.get("operations").Linear(
|
||||
dim,
|
||||
dim,
|
||||
bias=True,
|
||||
device=operation_settings.get("device"),
|
||||
dtype=operation_settings.get("dtype"),
|
||||
)
|
||||
self.to_out = nn.ModuleList([
|
||||
operation_settings.get("operations").Linear(
|
||||
dim,
|
||||
dim,
|
||||
bias=True,
|
||||
device=operation_settings.get("device"),
|
||||
dtype=operation_settings.get("dtype"),
|
||||
),
|
||||
nn.Dropout(0.0),
|
||||
])
|
||||
|
||||
if qk_norm:
|
||||
self.norm_q = operation_settings.get("operations").RMSNorm(
|
||||
dim,
|
||||
eps=eps,
|
||||
elementwise_affine=True,
|
||||
device=operation_settings.get("device"),
|
||||
dtype=operation_settings.get("dtype"),
|
||||
)
|
||||
self.norm_k = operation_settings.get("operations").RMSNorm(
|
||||
dim,
|
||||
eps=eps,
|
||||
elementwise_affine=True,
|
||||
device=operation_settings.get("device"),
|
||||
dtype=operation_settings.get("dtype"),
|
||||
)
|
||||
else:
|
||||
self.norm_q = nn.Identity()
|
||||
self.norm_k = nn.Identity()
|
||||
|
||||
if is_amplify_history:
|
||||
if history_scale_mode == "scalar":
|
||||
self.history_key_scale = nn.Parameter(torch.ones(
|
||||
1,
|
||||
device=operation_settings.get("device"),
|
||||
dtype=operation_settings.get("dtype"),
|
||||
))
|
||||
else:
|
||||
self.history_key_scale = nn.Parameter(torch.ones(
|
||||
num_heads,
|
||||
device=operation_settings.get("device"),
|
||||
dtype=operation_settings.get("dtype"),
|
||||
))
|
||||
self.history_scale_mode = history_scale_mode
|
||||
self.max_scale = 10.0
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x,
|
||||
context=None,
|
||||
freqs=None,
|
||||
original_context_length=None,
|
||||
transformer_options={},
|
||||
):
|
||||
if context is None:
|
||||
context = x
|
||||
|
||||
b, sq, _ = x.shape
|
||||
sk = context.shape[1]
|
||||
|
||||
q = self.norm_q(self.to_q(x)).view(b, sq, self.num_heads, self.head_dim)
|
||||
k = self.norm_k(self.to_k(context)).view(b, sk, self.num_heads, self.head_dim)
|
||||
v = self.to_v(context).view(b, sk, self.num_heads, self.head_dim)
|
||||
|
||||
if freqs is not None:
|
||||
q = apply_rope1(q, freqs)
|
||||
k = apply_rope1(k, freqs)
|
||||
|
||||
if q.dtype != v.dtype:
|
||||
q = q.to(v.dtype)
|
||||
if k.dtype != v.dtype:
|
||||
k = k.to(v.dtype)
|
||||
|
||||
if (not self.is_cross_attention and self.is_amplify_history and original_context_length is not None):
|
||||
history_seq_len = sq - original_context_length
|
||||
if history_seq_len > 0:
|
||||
scale_key = 1.0 + torch.sigmoid(self.history_key_scale) * (self.max_scale - 1.0)
|
||||
if self.history_scale_mode == "per_head":
|
||||
scale_key = scale_key.view(1, 1, -1, 1)
|
||||
k = torch.cat([k[:, :history_seq_len] * scale_key, k[:, history_seq_len:]], dim=1)
|
||||
|
||||
y = optimized_attention(
|
||||
q.view(b, sq, -1),
|
||||
k.view(b, sk, -1),
|
||||
v.view(b, sk, -1),
|
||||
heads=self.num_heads,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
y = self.to_out[0](y)
|
||||
y = self.to_out[1](y)
|
||||
return y
|
||||
|
||||
|
||||
class HeliosAttentionBlock(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
ffn_dim,
|
||||
num_heads,
|
||||
qk_norm=True,
|
||||
cross_attn_norm=True,
|
||||
eps=1e-6,
|
||||
guidance_cross_attn=False,
|
||||
is_amplify_history=False,
|
||||
history_scale_mode="per_head",
|
||||
operation_settings={},
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.norm1 = operation_settings.get("operations").LayerNorm(
|
||||
dim,
|
||||
eps,
|
||||
elementwise_affine=False,
|
||||
device=operation_settings.get("device"),
|
||||
dtype=operation_settings.get("dtype"),
|
||||
)
|
||||
self.attn1 = HeliosSelfAttention(
|
||||
dim,
|
||||
num_heads,
|
||||
qk_norm=qk_norm,
|
||||
eps=eps,
|
||||
is_cross_attention=False,
|
||||
is_amplify_history=is_amplify_history,
|
||||
history_scale_mode=history_scale_mode,
|
||||
operation_settings=operation_settings,
|
||||
)
|
||||
|
||||
self.cross_attn_norm = bool(cross_attn_norm)
|
||||
self.norm2 = (operation_settings.get("operations").LayerNorm(
|
||||
dim,
|
||||
eps,
|
||||
elementwise_affine=True,
|
||||
device=operation_settings.get("device"),
|
||||
dtype=operation_settings.get("dtype"),
|
||||
) if self.cross_attn_norm else nn.Identity())
|
||||
self.attn2 = HeliosSelfAttention(
|
||||
dim,
|
||||
num_heads,
|
||||
qk_norm=qk_norm,
|
||||
eps=eps,
|
||||
is_cross_attention=True,
|
||||
operation_settings=operation_settings,
|
||||
)
|
||||
|
||||
self.norm3 = operation_settings.get("operations").LayerNorm(
|
||||
dim,
|
||||
eps,
|
||||
elementwise_affine=False,
|
||||
device=operation_settings.get("device"),
|
||||
dtype=operation_settings.get("dtype"),
|
||||
)
|
||||
self.ffn = nn.Sequential(
|
||||
operation_settings.get("operations").Linear(
|
||||
dim,
|
||||
ffn_dim,
|
||||
device=operation_settings.get("device"),
|
||||
dtype=operation_settings.get("dtype"),
|
||||
),
|
||||
nn.GELU(approximate="tanh"),
|
||||
operation_settings.get("operations").Linear(
|
||||
ffn_dim,
|
||||
dim,
|
||||
device=operation_settings.get("device"),
|
||||
dtype=operation_settings.get("dtype"),
|
||||
),
|
||||
)
|
||||
|
||||
self.scale_shift_table = nn.Parameter(torch.randn(
|
||||
1,
|
||||
6,
|
||||
dim,
|
||||
device=operation_settings.get("device"),
|
||||
dtype=operation_settings.get("dtype"),
|
||||
) / dim**0.5)
|
||||
self.guidance_cross_attn = guidance_cross_attn
|
||||
|
||||
def forward(self, x, context, e, freqs, original_context_length=None, transformer_options={}):
|
||||
if e.ndim == 4:
|
||||
shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
|
||||
self.scale_shift_table.unsqueeze(0).to(e.device) + e.float()
|
||||
).chunk(6, dim=2)
|
||||
shift_msa = shift_msa.squeeze(2)
|
||||
scale_msa = scale_msa.squeeze(2)
|
||||
gate_msa = gate_msa.squeeze(2)
|
||||
c_shift_msa = c_shift_msa.squeeze(2)
|
||||
c_scale_msa = c_scale_msa.squeeze(2)
|
||||
c_gate_msa = c_gate_msa.squeeze(2)
|
||||
else:
|
||||
shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
|
||||
self.scale_shift_table.to(e.device) + e.float()
|
||||
).chunk(6, dim=1)
|
||||
|
||||
# self-attn
|
||||
# Use float32 for numerical stability like diffusers
|
||||
# norm1 has elementwise_affine=False, so we can safely convert to float32
|
||||
norm_x = self.norm1(x.float())
|
||||
norm_x = (norm_x * (1 + scale_msa) + shift_msa).type_as(x)
|
||||
y = self.attn1(
|
||||
norm_x,
|
||||
freqs=freqs,
|
||||
original_context_length=original_context_length,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
x = (x.float() + y.float() * gate_msa).type_as(x)
|
||||
|
||||
# cross-attn
|
||||
if self.guidance_cross_attn and original_context_length is not None:
|
||||
history_seq_len = x.shape[1] - original_context_length
|
||||
history_x, x_main = torch.split(x, [history_seq_len, original_context_length], dim=1)
|
||||
if self.cross_attn_norm:
|
||||
# norm2 has elementwise_affine=True, manually do FP32LayerNorm behavior
|
||||
norm_x_main = torch.nn.functional.layer_norm(
|
||||
x_main.float(),
|
||||
self.norm2.normalized_shape,
|
||||
self.norm2.weight.to(x_main.device).float() if self.norm2.weight is not None else None,
|
||||
self.norm2.bias.to(x_main.device).float() if self.norm2.bias is not None else None,
|
||||
self.norm2.eps,
|
||||
).type_as(x_main)
|
||||
else:
|
||||
norm_x_main = x_main
|
||||
x_main = x_main + self.attn2(
|
||||
norm_x_main,
|
||||
context=context,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
x = torch.cat([history_x, x_main], dim=1)
|
||||
else:
|
||||
if self.cross_attn_norm:
|
||||
# norm2 has elementwise_affine=True, manually do FP32LayerNorm behavior
|
||||
norm_x = torch.nn.functional.layer_norm(
|
||||
x.float(),
|
||||
self.norm2.normalized_shape,
|
||||
self.norm2.weight.to(x.device).float() if self.norm2.weight is not None else None,
|
||||
self.norm2.bias.to(x.device).float() if self.norm2.bias is not None else None,
|
||||
self.norm2.eps,
|
||||
).type_as(x)
|
||||
else:
|
||||
norm_x = x
|
||||
x = x + self.attn2(norm_x, context=context, transformer_options=transformer_options)
|
||||
|
||||
# ffn
|
||||
# Use float32 for numerical stability like diffusers
|
||||
# norm3 has elementwise_affine=False, so we can safely convert to float32
|
||||
norm_x = self.norm3(x.float())
|
||||
norm_x = (norm_x * (1 + c_scale_msa) + c_shift_msa).type_as(x)
|
||||
y = self.ffn(norm_x)
|
||||
x = (x.float() + y.float() * c_gate_msa).type_as(x)
|
||||
return x
|
||||
|
||||
|
||||
class HeliosModel(torch.nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_type="t2v",
|
||||
patch_size=(1, 2, 2),
|
||||
num_attention_heads=40,
|
||||
attention_head_dim=128,
|
||||
in_channels=16,
|
||||
out_channels=16,
|
||||
text_dim=4096,
|
||||
freq_dim=256,
|
||||
ffn_dim=13824,
|
||||
num_layers=40,
|
||||
cross_attn_norm=True,
|
||||
qk_norm=True,
|
||||
eps=1e-6,
|
||||
rope_dim=(44, 42, 42),
|
||||
rope_theta=10000.0,
|
||||
guidance_cross_attn=True,
|
||||
zero_history_timestep=True,
|
||||
has_multi_term_memory_patch=True,
|
||||
is_amplify_history=False,
|
||||
history_scale_mode="per_head",
|
||||
image_model=None,
|
||||
device=None,
|
||||
dtype=None,
|
||||
operations=None,
|
||||
**kwargs,
|
||||
):
|
||||
del model_type, image_model, kwargs
|
||||
super().__init__()
|
||||
self.dtype = dtype
|
||||
operation_settings = {
|
||||
"operations": operations,
|
||||
"device": device,
|
||||
"dtype": dtype,
|
||||
}
|
||||
|
||||
dim = num_attention_heads * attention_head_dim
|
||||
self.patch_size = patch_size
|
||||
self.out_dim = out_channels or in_channels
|
||||
self.dim = dim
|
||||
self.freq_dim = freq_dim
|
||||
self.zero_history_timestep = zero_history_timestep
|
||||
|
||||
# embeddings
|
||||
self.patch_embedding = operations.Conv3d(
|
||||
in_channels,
|
||||
dim,
|
||||
kernel_size=patch_size,
|
||||
stride=patch_size,
|
||||
device=operation_settings.get("device"),
|
||||
dtype=operation_settings.get("dtype"),
|
||||
)
|
||||
self.text_embedding = nn.Sequential(
|
||||
operations.Linear(
|
||||
text_dim,
|
||||
dim,
|
||||
device=operation_settings.get("device"),
|
||||
dtype=operation_settings.get("dtype"),
|
||||
),
|
||||
nn.GELU(approximate="tanh"),
|
||||
operations.Linear(
|
||||
dim,
|
||||
dim,
|
||||
device=operation_settings.get("device"),
|
||||
dtype=operation_settings.get("dtype"),
|
||||
),
|
||||
)
|
||||
self.time_embedding = nn.Sequential(
|
||||
operations.Linear(
|
||||
freq_dim,
|
||||
dim,
|
||||
device=operation_settings.get("device"),
|
||||
dtype=operation_settings.get("dtype"),
|
||||
),
|
||||
nn.SiLU(),
|
||||
operations.Linear(
|
||||
dim,
|
||||
dim,
|
||||
device=operation_settings.get("device"),
|
||||
dtype=operation_settings.get("dtype"),
|
||||
),
|
||||
)
|
||||
self.time_projection = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
operations.Linear(
|
||||
dim,
|
||||
dim * 6,
|
||||
device=operation_settings.get("device"),
|
||||
dtype=operation_settings.get("dtype"),
|
||||
),
|
||||
)
|
||||
|
||||
d = dim // num_attention_heads
|
||||
self.rope_embedder = EmbedND(dim=d, theta=rope_theta, axes_dim=list(rope_dim))
|
||||
|
||||
# pyramidal embedding
|
||||
if has_multi_term_memory_patch:
|
||||
self.patch_short = operations.Conv3d(
|
||||
in_channels,
|
||||
dim,
|
||||
kernel_size=patch_size,
|
||||
stride=patch_size,
|
||||
device=operation_settings.get("device"),
|
||||
dtype=operation_settings.get("dtype"),
|
||||
)
|
||||
self.patch_mid = operations.Conv3d(
|
||||
in_channels,
|
||||
dim,
|
||||
kernel_size=tuple(2 * p for p in patch_size),
|
||||
stride=tuple(2 * p for p in patch_size),
|
||||
device=operation_settings.get("device"),
|
||||
dtype=operation_settings.get("dtype"),
|
||||
)
|
||||
self.patch_long = operations.Conv3d(
|
||||
in_channels,
|
||||
dim,
|
||||
kernel_size=tuple(4 * p for p in patch_size),
|
||||
stride=tuple(4 * p for p in patch_size),
|
||||
device=operation_settings.get("device"),
|
||||
dtype=operation_settings.get("dtype"),
|
||||
)
|
||||
|
||||
# blocks
|
||||
self.blocks = nn.ModuleList([HeliosAttentionBlock(
|
||||
dim,
|
||||
ffn_dim,
|
||||
num_attention_heads,
|
||||
qk_norm=qk_norm,
|
||||
cross_attn_norm=cross_attn_norm,
|
||||
eps=eps,
|
||||
guidance_cross_attn=guidance_cross_attn,
|
||||
is_amplify_history=is_amplify_history,
|
||||
history_scale_mode=history_scale_mode,
|
||||
operation_settings=operation_settings,
|
||||
) for _ in range(num_layers)])
|
||||
|
||||
# head
|
||||
self.norm_out = OutputNorm(dim, eps=eps, operation_settings=operation_settings)
|
||||
self.proj_out = operations.Linear(
|
||||
dim,
|
||||
self.out_dim * math.prod(patch_size),
|
||||
device=operation_settings.get("device"),
|
||||
dtype=operation_settings.get("dtype"),
|
||||
)
|
||||
|
||||
def rope_encode(
|
||||
self,
|
||||
t,
|
||||
h,
|
||||
w,
|
||||
t_start=0,
|
||||
steps_t=None,
|
||||
steps_h=None,
|
||||
steps_w=None,
|
||||
device=None,
|
||||
dtype=None,
|
||||
transformer_options={},
|
||||
frame_indices=None,
|
||||
):
|
||||
patch_size = self.patch_size
|
||||
t_len = (t + (patch_size[0] // 2)) // patch_size[0]
|
||||
h_len = (h + (patch_size[1] // 2)) // patch_size[1]
|
||||
w_len = (w + (patch_size[2] // 2)) // patch_size[2]
|
||||
|
||||
if steps_t is None:
|
||||
steps_t = t_len
|
||||
if steps_h is None:
|
||||
steps_h = h_len
|
||||
if steps_w is None:
|
||||
steps_w = w_len
|
||||
|
||||
h_start = 0
|
||||
w_start = 0
|
||||
rope_options = transformer_options.get("rope_options", None)
|
||||
if rope_options is not None:
|
||||
t_len = (t_len - 1.0) * rope_options.get("scale_t", 1.0) + 1.0
|
||||
h_len = (h_len - 1.0) * rope_options.get("scale_y", 1.0) + 1.0
|
||||
w_len = (w_len - 1.0) * rope_options.get("scale_x", 1.0) + 1.0
|
||||
|
||||
t_start += rope_options.get("shift_t", 0.0)
|
||||
h_start += rope_options.get("shift_y", 0.0)
|
||||
w_start += rope_options.get("shift_x", 0.0)
|
||||
|
||||
if frame_indices is None:
|
||||
t_coords = torch.linspace(
|
||||
t_start,
|
||||
t_start + (t_len - 1),
|
||||
steps=steps_t,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
).reshape(1, -1, 1, 1)
|
||||
batch_size = 1
|
||||
else:
|
||||
batch_size = frame_indices.shape[0]
|
||||
t_coords = frame_indices.to(device=device, dtype=dtype)
|
||||
if t_coords.shape[1] != steps_t:
|
||||
t_coords = torch.nn.functional.interpolate(
|
||||
t_coords.unsqueeze(1),
|
||||
size=steps_t,
|
||||
mode="linear",
|
||||
align_corners=False,
|
||||
).squeeze(1)
|
||||
t_coords = (t_coords + t_start)[:, :, None, None]
|
||||
|
||||
img_ids = torch.zeros((batch_size, steps_t, steps_h, steps_w, 3), device=device, dtype=dtype)
|
||||
img_ids[:, :, :, :, 0] = img_ids[:, :, :, :, 0] + t_coords.expand(batch_size, steps_t, steps_h, steps_w)
|
||||
img_ids[:, :, :, :, 1] = img_ids[:, :, :, :, 1] + torch.linspace(h_start, h_start + (h_len - 1), steps=steps_h, device=device, dtype=dtype).reshape(1, 1, -1, 1)
|
||||
img_ids[:, :, :, :, 2] = img_ids[:, :, :, :, 2] + torch.linspace(w_start, w_start + (w_len - 1), steps=steps_w, device=device, dtype=dtype).reshape(1, 1, 1, -1)
|
||||
img_ids = img_ids.reshape(batch_size, -1, img_ids.shape[-1])
|
||||
return self.rope_embedder(img_ids).movedim(1, 2)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x,
|
||||
timestep,
|
||||
context,
|
||||
clip_fea=None,
|
||||
time_dim_concat=None,
|
||||
transformer_options={},
|
||||
**kwargs,
|
||||
):
|
||||
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||
self._forward,
|
||||
self,
|
||||
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options),
|
||||
).execute(
|
||||
x,
|
||||
timestep,
|
||||
context,
|
||||
clip_fea,
|
||||
time_dim_concat,
|
||||
transformer_options,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def _forward(
|
||||
self,
|
||||
x,
|
||||
timestep,
|
||||
context,
|
||||
clip_fea=None,
|
||||
time_dim_concat=None,
|
||||
transformer_options={},
|
||||
**kwargs,
|
||||
):
|
||||
del clip_fea, time_dim_concat
|
||||
|
||||
_, _, t, h, w = x.shape
|
||||
x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size)
|
||||
|
||||
out = self.forward_orig(
|
||||
hidden_states=x,
|
||||
timestep=timestep,
|
||||
context=context,
|
||||
indices_hidden_states=kwargs.get("indices_hidden_states", None),
|
||||
indices_latents_history_short=kwargs.get("indices_latents_history_short", None),
|
||||
indices_latents_history_mid=kwargs.get("indices_latents_history_mid", None),
|
||||
indices_latents_history_long=kwargs.get("indices_latents_history_long", None),
|
||||
latents_history_short=kwargs.get("latents_history_short", None),
|
||||
latents_history_mid=kwargs.get("latents_history_mid", None),
|
||||
latents_history_long=kwargs.get("latents_history_long", None),
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
|
||||
return out[:, :, :t, :h, :w]
|
||||
|
||||
def forward_orig(
|
||||
self,
|
||||
hidden_states,
|
||||
timestep,
|
||||
context,
|
||||
indices_hidden_states=None,
|
||||
indices_latents_history_short=None,
|
||||
indices_latents_history_mid=None,
|
||||
indices_latents_history_long=None,
|
||||
latents_history_short=None,
|
||||
latents_history_mid=None,
|
||||
latents_history_long=None,
|
||||
transformer_options={},
|
||||
):
|
||||
batch_size = hidden_states.shape[0]
|
||||
p_t, p_h, p_w = self.patch_size
|
||||
|
||||
# embeddings
|
||||
hidden_states = self.patch_embedding(hidden_states)
|
||||
_, _, post_t, post_h, post_w = hidden_states.shape
|
||||
hidden_states = hidden_states.flatten(2).transpose(1, 2)
|
||||
|
||||
if indices_hidden_states is None:
|
||||
indices_hidden_states = (torch.arange(0, post_t, device=hidden_states.device).unsqueeze(0).expand(batch_size, -1))
|
||||
|
||||
freqs = self.rope_encode(
|
||||
t=post_t * self.patch_size[0],
|
||||
h=post_h * self.patch_size[1],
|
||||
w=post_w * self.patch_size[2],
|
||||
steps_t=post_t,
|
||||
steps_h=post_h,
|
||||
steps_w=post_w,
|
||||
device=hidden_states.device,
|
||||
dtype=hidden_states.dtype,
|
||||
transformer_options=transformer_options,
|
||||
frame_indices=indices_hidden_states,
|
||||
)
|
||||
original_context_length = hidden_states.shape[1]
|
||||
|
||||
if latents_history_short is not None and indices_latents_history_short is not None:
|
||||
x_short = self.patch_short(latents_history_short)
|
||||
_, _, ts, hs, ws = x_short.shape
|
||||
x_short = x_short.flatten(2).transpose(1, 2)
|
||||
f_short = self.rope_encode(
|
||||
t=ts * self.patch_size[0],
|
||||
h=hs * self.patch_size[1],
|
||||
w=ws * self.patch_size[2],
|
||||
steps_t=ts,
|
||||
steps_h=hs,
|
||||
steps_w=ws,
|
||||
device=x_short.device,
|
||||
dtype=x_short.dtype,
|
||||
transformer_options=transformer_options,
|
||||
frame_indices=indices_latents_history_short,
|
||||
)
|
||||
hidden_states = torch.cat([x_short, hidden_states], dim=1)
|
||||
freqs = torch.cat([f_short, freqs], dim=1)
|
||||
|
||||
if latents_history_mid is not None and indices_latents_history_mid is not None:
|
||||
x_mid = self.patch_mid(pad_for_3d_conv(latents_history_mid, (2, 4, 4)))
|
||||
_, _, tm, hm, wm = x_mid.shape
|
||||
x_mid = x_mid.flatten(2).transpose(1, 2)
|
||||
mid_t = indices_latents_history_mid.shape[1]
|
||||
# patch_mid downsamples by 2 in (t, h, w); build RoPE on the pre-downsample grid.
|
||||
mid_h = hm * 2
|
||||
mid_w = wm * 2
|
||||
f_mid = self.rope_encode(
|
||||
t=mid_t * self.patch_size[0],
|
||||
h=mid_h * self.patch_size[1],
|
||||
w=mid_w * self.patch_size[2],
|
||||
steps_t=mid_t,
|
||||
steps_h=mid_h,
|
||||
steps_w=mid_w,
|
||||
device=x_mid.device,
|
||||
dtype=x_mid.dtype,
|
||||
transformer_options=transformer_options,
|
||||
frame_indices=indices_latents_history_mid,
|
||||
)
|
||||
f_mid = self._rope_downsample_3d(f_mid, (mid_t, mid_h, mid_w), (2, 2, 2))
|
||||
hidden_states = torch.cat([x_mid, hidden_states], dim=1)
|
||||
freqs = torch.cat([f_mid, freqs], dim=1)
|
||||
|
||||
if latents_history_long is not None and indices_latents_history_long is not None:
|
||||
x_long = self.patch_long(pad_for_3d_conv(latents_history_long, (4, 8, 8)))
|
||||
_, _, tl, hl, wl = x_long.shape
|
||||
x_long = x_long.flatten(2).transpose(1, 2)
|
||||
long_t = indices_latents_history_long.shape[1]
|
||||
# patch_long downsamples by 4 in (t, h, w); build RoPE on the pre-downsample grid.
|
||||
long_h = hl * 4
|
||||
long_w = wl * 4
|
||||
f_long = self.rope_encode(
|
||||
t=long_t * self.patch_size[0],
|
||||
h=long_h * self.patch_size[1],
|
||||
w=long_w * self.patch_size[2],
|
||||
steps_t=long_t,
|
||||
steps_h=long_h,
|
||||
steps_w=long_w,
|
||||
device=x_long.device,
|
||||
dtype=x_long.dtype,
|
||||
transformer_options=transformer_options,
|
||||
frame_indices=indices_latents_history_long,
|
||||
)
|
||||
f_long = self._rope_downsample_3d(f_long, (long_t, long_h, long_w), (4, 4, 4))
|
||||
hidden_states = torch.cat([x_long, hidden_states], dim=1)
|
||||
freqs = torch.cat([f_long, freqs], dim=1)
|
||||
|
||||
history_context_length = hidden_states.shape[1] - original_context_length
|
||||
|
||||
if timestep.ndim == 0:
|
||||
timestep = timestep.unsqueeze(0)
|
||||
timestep = timestep.to(hidden_states.device)
|
||||
if timestep.shape[0] != batch_size:
|
||||
timestep = timestep[:1].expand(batch_size)
|
||||
|
||||
# time embeddings
|
||||
e = self.time_embedding(sinusoidal_embedding_1d(self.freq_dim, timestep.flatten()).to(dtype=hidden_states.dtype))
|
||||
e = e.reshape(batch_size, -1, e.shape[-1])
|
||||
e0 = self.time_projection(e).unflatten(2, (6, self.dim))
|
||||
context = self.text_embedding(context.to(dtype=hidden_states.dtype))
|
||||
|
||||
if self.zero_history_timestep and history_context_length > 0:
|
||||
timestep_t0 = torch.zeros((1, ), dtype=timestep.dtype, device=timestep.device)
|
||||
e_t0 = self.time_embedding(sinusoidal_embedding_1d(self.freq_dim, timestep_t0.flatten()).to(dtype=hidden_states.dtype))
|
||||
e_t0 = e_t0.reshape(1, -1, e_t0.shape[-1]).expand(batch_size, history_context_length, -1)
|
||||
e0_t0 = self.time_projection(e_t0[:, :1]).unflatten(2, (6, self.dim))
|
||||
e0_t0 = (e0_t0.view(batch_size, 1, 6, self.dim).permute(0, 2, 1, 3).expand(batch_size, 6, history_context_length, self.dim))
|
||||
|
||||
e = e.expand(batch_size, original_context_length, -1)
|
||||
e0 = (e0.view(batch_size, 1, 6, self.dim).permute(0, 2, 1, 3).expand(batch_size, 6, original_context_length, self.dim))
|
||||
e = torch.cat([e_t0, e], dim=1)
|
||||
e0 = torch.cat([e0_t0, e0], dim=2)
|
||||
else:
|
||||
e = e.expand(batch_size, hidden_states.shape[1], -1)
|
||||
e0 = (e0.view(batch_size, 1, 6, self.dim).permute(0, 2, 1, 3).expand(batch_size, 6, hidden_states.shape[1], self.dim))
|
||||
|
||||
e0 = e0.permute(0, 2, 1, 3)
|
||||
|
||||
for i_b, block in enumerate(self.blocks):
|
||||
hidden_states = block(
|
||||
hidden_states,
|
||||
context,
|
||||
e0,
|
||||
freqs,
|
||||
original_context_length=original_context_length,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
hidden_states = self.norm_out(hidden_states, e, original_context_length)
|
||||
hidden_states = self.proj_out(hidden_states)
|
||||
return self.unpatchify(hidden_states, (post_t, post_h, post_w))
|
||||
|
||||
def unpatchify(self, x, grid_sizes):
|
||||
"""
|
||||
Unpatchify the output from proj_out back to video format.
|
||||
|
||||
Args:
|
||||
x: [batch, num_patches, out_dim * prod(patch_size)]
|
||||
grid_sizes: (num_frames, height, width) in patch space
|
||||
|
||||
Returns:
|
||||
[batch, out_dim, num_frames, height, width] in pixel space
|
||||
"""
|
||||
b = x.shape[0]
|
||||
post_t, post_h, post_w = grid_sizes
|
||||
p_t, p_h, p_w = self.patch_size
|
||||
|
||||
# Reshape: [B, T*H*W, out_dim*p_t*p_h*p_w] -> [B, T, H, W, p_t, p_h, p_w, out_dim]
|
||||
# Use -1 to let PyTorch infer the channel dimension (out_dim)
|
||||
hidden_states = x.reshape(b, post_t, post_h, post_w, p_t, p_h, p_w, -1)
|
||||
|
||||
# Permute: [B, T, H, W, p_t, p_h, p_w, C] -> [B, C, T, p_t, H, p_h, W, p_w]
|
||||
hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6)
|
||||
|
||||
# Flatten patches: [B, C, T, p_t, H, p_h, W, p_w] -> [B, C, T*p_t, H*p_h, W*p_w]
|
||||
output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)
|
||||
|
||||
return output
|
||||
def _rope_downsample_3d(self, freqs, grid_sizes, kernel_size):
|
||||
b, _, one, d, i2, j2 = freqs.shape
|
||||
gt, gh, gw = grid_sizes
|
||||
c = one * d * i2 * j2
|
||||
freqs_3d = freqs.reshape(b, gt, gh, gw, c).permute(0, 4, 1, 2, 3)
|
||||
freqs_3d = pad_for_3d_conv(freqs_3d, kernel_size)
|
||||
freqs_3d = center_down_sample_3d(freqs_3d, kernel_size)
|
||||
dt, dh, dw = freqs_3d.shape[2:]
|
||||
freqs_3d = freqs_3d.permute(0, 2, 3, 4, 1).reshape(b, dt * dh * dw, one, d, i2, j2)
|
||||
return freqs_3d
|
||||
|
||||
# Backward-compatible alias for existing integration points.
|
||||
HeliosTransformer3DModel = HeliosModel
|
||||
@ -41,6 +41,7 @@ import comfy.ldm.cosmos.predict2
|
||||
import comfy.ldm.lumina.model
|
||||
import comfy.ldm.wan.model
|
||||
import comfy.ldm.wan.model_animate
|
||||
import comfy.ldm.helios.model
|
||||
import comfy.ldm.hunyuan3d.model
|
||||
import comfy.ldm.hidream.model
|
||||
import comfy.ldm.chroma.model
|
||||
@ -1268,6 +1269,70 @@ class ZImagePixelSpace(Lumina2):
|
||||
BaseModel.__init__(self, model_config, model_type, device=device, unet_model=comfy.ldm.lumina.model.NextDiTPixelSpace)
|
||||
self.memory_usage_factor_conds = ("ref_latents",)
|
||||
|
||||
class Helios(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.helios.model.HeliosTransformer3DModel)
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = super().extra_conds(**kwargs)
|
||||
cross_attn = kwargs.get("cross_attn", None)
|
||||
if cross_attn is not None:
|
||||
out["c_crossattn"] = comfy.conds.CONDRegular(cross_attn)
|
||||
|
||||
cond_keys = (
|
||||
"indices_hidden_states",
|
||||
"indices_latents_history_short",
|
||||
"indices_latents_history_mid",
|
||||
"indices_latents_history_long",
|
||||
"latents_history_short",
|
||||
"latents_history_mid",
|
||||
"latents_history_long",
|
||||
"helios_stage_sigmas",
|
||||
"helios_stage_timesteps",
|
||||
)
|
||||
|
||||
for key in cond_keys:
|
||||
value = kwargs.get(key, None)
|
||||
if value is None:
|
||||
continue
|
||||
# Diffusers forwards Helios history latents without latent-format re-normalization.
|
||||
# Keep raw history tensors to match transformer inputs across frameworks.
|
||||
if key in ("helios_stage_sigmas", "helios_stage_timesteps"):
|
||||
out[key] = comfy.conds.CONDConstant(value)
|
||||
else:
|
||||
out[key] = comfy.conds.CONDRegular(value)
|
||||
return out
|
||||
|
||||
def process_timestep(self, timestep, **kwargs):
|
||||
stage_sigmas = kwargs.get("helios_stage_sigmas", None)
|
||||
stage_timesteps = kwargs.get("helios_stage_timesteps", None)
|
||||
if stage_sigmas is None or stage_timesteps is None:
|
||||
return timestep
|
||||
|
||||
if stage_sigmas.ndim > 1:
|
||||
stage_sigmas = stage_sigmas[0]
|
||||
if stage_timesteps.ndim > 1:
|
||||
stage_timesteps = stage_timesteps[0]
|
||||
|
||||
if stage_timesteps.numel() == 0 or stage_sigmas.numel() == 0:
|
||||
return timestep
|
||||
|
||||
if stage_sigmas.numel() == stage_timesteps.numel() + 1:
|
||||
sigma_candidates = stage_sigmas[:-1]
|
||||
else:
|
||||
sigma_candidates = stage_sigmas[: stage_timesteps.numel()]
|
||||
|
||||
if sigma_candidates.numel() == 0:
|
||||
return timestep
|
||||
|
||||
multiplier = float(getattr(self.model_sampling, "multiplier", 1000.0))
|
||||
sigma_in = timestep / multiplier
|
||||
idx = torch.argmin(torch.abs(sigma_in.unsqueeze(-1) - sigma_candidates.unsqueeze(0)), dim=-1)
|
||||
mapped = stage_timesteps[idx].to(dtype=timestep.dtype)
|
||||
if mapped.dtype.is_floating_point:
|
||||
mapped = torch.floor(mapped)
|
||||
return mapped
|
||||
|
||||
class WAN21(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel)
|
||||
|
||||
@ -490,6 +490,54 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
|
||||
return dit_config
|
||||
|
||||
helios_required_keys = (
|
||||
'{}patch_mid.weight'.format(key_prefix),
|
||||
'{}patch_long.weight'.format(key_prefix),
|
||||
)
|
||||
if all(k in state_dict_keys for k in helios_required_keys): # Helios
|
||||
dit_config = {}
|
||||
dit_config["image_model"] = "helios"
|
||||
|
||||
patch_weight = state_dict['{}patch_embedding.weight'.format(key_prefix)]
|
||||
inner_dim = patch_weight.shape[0]
|
||||
patch_size = tuple(patch_weight.shape[2:])
|
||||
out_proj = state_dict['{}proj_out.weight'.format(key_prefix)]
|
||||
|
||||
dit_config["patch_size"] = patch_size
|
||||
dit_config["in_channels"] = patch_weight.shape[1]
|
||||
dit_config["out_channels"] = out_proj.shape[0] // math.prod(patch_size)
|
||||
text_w = state_dict['{}text_embedding.0.weight'.format(key_prefix)]
|
||||
time_w = state_dict['{}time_embedding.0.weight'.format(key_prefix)]
|
||||
dit_config["text_dim"] = text_w.shape[1]
|
||||
dit_config["freq_dim"] = time_w.shape[1]
|
||||
dit_config["num_layers"] = count_blocks(state_dict_keys, '{}blocks.'.format(key_prefix) + '{}.')
|
||||
dit_config["num_attention_heads"] = inner_dim // 128
|
||||
dit_config["attention_head_dim"] = 128
|
||||
|
||||
ffn_in = state_dict.get('{}blocks.0.ffn.net.0.proj.weight'.format(key_prefix), None)
|
||||
if ffn_in is None:
|
||||
ffn_in = state_dict.get('{}blocks.0.ffn.0.weight'.format(key_prefix), None)
|
||||
if ffn_in is not None:
|
||||
dit_config["ffn_dim"] = ffn_in.shape[0]
|
||||
|
||||
if '{}blocks.0.attn2.add_k_proj.weight'.format(key_prefix) in state_dict_keys:
|
||||
dit_config["added_kv_proj_dim"] = state_dict['{}blocks.0.attn2.add_k_proj.weight'.format(key_prefix)].shape[1]
|
||||
|
||||
if '{}patch_short.weight'.format(key_prefix) in state_dict_keys:
|
||||
dit_config["has_multi_term_memory_patch"] = True
|
||||
else:
|
||||
dit_config["has_multi_term_memory_patch"] = False
|
||||
|
||||
if '{}blocks.0.attn1.history_key_scale'.format(key_prefix) in state_dict_keys:
|
||||
dit_config["is_amplify_history"] = True
|
||||
hk = state_dict['{}blocks.0.attn1.history_key_scale'.format(key_prefix)]
|
||||
dit_config["history_scale_mode"] = "per_head" if len(hk.shape) > 0 and hk.numel() > 1 else "scalar"
|
||||
|
||||
if metadata is not None and "config" in metadata:
|
||||
dit_config.update(json.loads(metadata["config"]).get("transformer", {}))
|
||||
|
||||
return dit_config
|
||||
|
||||
if '{}head.modulation'.format(key_prefix) in state_dict_keys: # Wan 2.1
|
||||
dit_config = {}
|
||||
dit_config["image_model"] = "wan2.1"
|
||||
|
||||
@ -1168,6 +1168,7 @@ class CLIPType(Enum):
|
||||
NEWBIE = 24
|
||||
FLUX2 = 25
|
||||
LONGCAT_IMAGE = 26
|
||||
HELIOS = 27
|
||||
|
||||
|
||||
|
||||
@ -1339,6 +1340,11 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
|
||||
clip_target.clip = comfy.text_encoders.wan.te(**t5xxl_detect(clip_data))
|
||||
clip_target.tokenizer = comfy.text_encoders.wan.WanT5Tokenizer
|
||||
tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None)
|
||||
elif clip_type == CLIPType.HELIOS:
|
||||
# Helios reuses the WAN UMT5-XXL text encoder stack.
|
||||
clip_target.clip = comfy.text_encoders.wan.te(**t5xxl_detect(clip_data))
|
||||
clip_target.tokenizer = comfy.text_encoders.wan.WanT5Tokenizer
|
||||
tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None)
|
||||
elif clip_type == CLIPType.HIDREAM:
|
||||
clip_target.clip = comfy.text_encoders.hidream.hidream_clip(**t5xxl_detect(clip_data),
|
||||
clip_l=False, clip_g=False, t5=True, llama=False, dtype_llama=None)
|
||||
|
||||
@ -1132,6 +1132,42 @@ class ZImagePixelSpace(ZImage):
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
return model_base.ZImagePixelSpace(self, device=device)
|
||||
|
||||
class Helios(supported_models_base.BASE):
|
||||
unet_config = {
|
||||
"image_model": "helios",
|
||||
}
|
||||
|
||||
sampling_settings = {
|
||||
"shift": 1.0,
|
||||
}
|
||||
|
||||
unet_extra_config = {}
|
||||
latent_format = latent_formats.Wan21
|
||||
memory_usage_factor = 1.8
|
||||
supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]
|
||||
|
||||
vae_key_prefix = ["vae."]
|
||||
text_encoder_key_prefix = ["text_encoders."]
|
||||
|
||||
def __init__(self, unet_config):
|
||||
super().__init__(unet_config)
|
||||
self.memory_usage_factor = (self.unet_config.get("num_layers", 40) * self.unet_config.get("num_attention_heads", 40)) / (40 * 40) * 1.8
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
return model_base.Helios(self, device=device)
|
||||
|
||||
def clip_target(self, state_dict={}):
|
||||
pref = self.text_encoder_key_prefix[0]
|
||||
t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(
|
||||
state_dict,
|
||||
"{}umt5xxl.transformer.".format(pref),
|
||||
)
|
||||
# Directly reuse WAN text encoder stack; no Helios-specific TE.
|
||||
return supported_models_base.ClipTarget(
|
||||
comfy.text_encoders.wan.WanT5Tokenizer,
|
||||
comfy.text_encoders.wan.te(**t5_detect),
|
||||
)
|
||||
|
||||
class WAN21_T2V(supported_models_base.BASE):
|
||||
unet_config = {
|
||||
"image_model": "wan2.1",
|
||||
@ -1734,6 +1770,6 @@ class LongCatImage(supported_models_base.BASE):
|
||||
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref))
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.longcat_image.LongCatImageTokenizer, comfy.text_encoders.longcat_image.te(**hunyuan_detect))
|
||||
|
||||
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, LongCatImage, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImagePixelSpace, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, WAN21_FlowRVS, WAN21_SCAIL, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima]
|
||||
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, LongCatImage, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImagePixelSpace, ZImage, Lumina2, Helios, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, WAN21_FlowRVS, WAN21_SCAIL, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima]
|
||||
|
||||
models += [SVD_img2vid]
|
||||
|
||||
1291
comfy_extras/nodes_helios.py
Normal file
1291
comfy_extras/nodes_helios.py
Normal file
File diff suppressed because it is too large
Load Diff
5
nodes.py
5
nodes.py
@ -976,7 +976,7 @@ class CLIPLoader:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "clip_name": (folder_paths.get_filename_list("text_encoders"), ),
|
||||
"type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2", "qwen_image", "hunyuan_image", "flux2", "ovis", "longcat_image"], ),
|
||||
"type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "helios", "hidream", "chroma", "ace", "omnigen2", "qwen_image", "hunyuan_image", "flux2", "ovis", "longcat_image"], ),
|
||||
},
|
||||
"optional": {
|
||||
"device": (["default", "cpu"], {"advanced": True}),
|
||||
@ -986,7 +986,7 @@ class CLIPLoader:
|
||||
|
||||
CATEGORY = "advanced/loaders"
|
||||
|
||||
DESCRIPTION = "[Recipes]\n\nstable_diffusion: clip-l\nstable_cascade: clip-g\nsd3: t5 xxl/ clip-g / clip-l\nstable_audio: t5 base\nmochi: t5 xxl\ncosmos: old t5 xxl\nlumina2: gemma 2 2B\nwan: umt5 xxl\n hidream: llama-3.1 (Recommend) or t5\nomnigen2: qwen vl 2.5 3B"
|
||||
DESCRIPTION = "[Recipes]\n\nstable_diffusion: clip-l\nstable_cascade: clip-g\nsd3: t5 xxl/ clip-g / clip-l\nstable_audio: t5 base\nmochi: t5 xxl\ncosmos: old t5 xxl\nlumina2: gemma 2 2B\nwan: umt5 xxl\nhelios: umt5 xxl\n hidream: llama-3.1 (Recommend) or t5\nomnigen2: qwen vl 2.5 3B"
|
||||
|
||||
def load_clip(self, clip_name, type="stable_diffusion", device="default"):
|
||||
clip_type = getattr(comfy.sd.CLIPType, type.upper(), comfy.sd.CLIPType.STABLE_DIFFUSION)
|
||||
@ -2414,6 +2414,7 @@ async def init_builtin_extra_nodes():
|
||||
"nodes_cosmos.py",
|
||||
"nodes_video.py",
|
||||
"nodes_lumina2.py",
|
||||
"nodes_helios.py",
|
||||
"nodes_wan.py",
|
||||
"nodes_lotus.py",
|
||||
"nodes_hunyuan3d.py",
|
||||
|
||||
Loading…
Reference in New Issue
Block a user