mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-03-22 09:33:29 +08:00
Basic support helios
This commit is contained in:
parent
9b85cf9558
commit
ae36a9d4fd
744
comfy/ldm/helios/model.py
Normal file
744
comfy/ldm/helios/model.py
Normal file
@ -0,0 +1,744 @@
|
|||||||
|
import math
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
from comfy.ldm.modules.attention import optimized_attention
|
||||||
|
from comfy.ldm.flux.layers import EmbedND
|
||||||
|
from comfy.ldm.flux.math import apply_rope1
|
||||||
|
from comfy.ldm.wan.model import sinusoidal_embedding_1d, repeat_e
|
||||||
|
import comfy.ldm.common_dit
|
||||||
|
import comfy.patcher_extension
|
||||||
|
|
||||||
|
|
||||||
|
def pad_for_3d_conv(x, kernel_size):
    """Right-pad the last three dims of a 5-D tensor up to multiples of ``kernel_size``.

    Args:
        x: tensor whose shape unpacks into five dims; the last three are
            padded (the first two are left untouched).
        kernel_size: three integers, one per padded dim, in the same order
            as the last three dims of ``x``.

    Returns:
        ``x`` padded at the end of each of the last three dims using
        replicate (edge-value) padding, so each becomes divisible by the
        matching ``kernel_size`` entry.
    """
    _, _, frames, height, width = x.shape
    k_t, k_h, k_w = kernel_size

    # ``-n % k`` is the distance from n up to the next multiple of k
    # (zero when n is already aligned).
    extra_t = -frames % k_t
    extra_h = -height % k_h
    extra_w = -width % k_w

    # F.pad takes (left, right) pairs starting from the LAST dimension.
    padding = (0, extra_w, 0, extra_h, 0, extra_t)
    return torch.nn.functional.pad(x, padding, mode="replicate")
|
||||||
|
|
||||||
|
|
||||||
|
class OutputNorm(nn.Module):
    """Final modulated LayerNorm applied to the trailing tokens of the sequence.

    Holds a learned ``(1, 2, dim)`` shift/scale table that is added to the
    per-token time embedding; the result modulates a non-affine LayerNorm.
    """

    def __init__(self, dim, eps=1e-6, operation_settings={}):
        super().__init__()
        ops = operation_settings.get("operations")
        device = operation_settings.get("device")
        dtype = operation_settings.get("dtype")

        table = torch.randn(1, 2, dim, device=device, dtype=dtype)
        self.scale_shift_table = nn.Parameter(table / dim**0.5)
        self.norm = ops.LayerNorm(
            dim,
            eps,
            elementwise_affine=False,
            device=device,
            dtype=dtype,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        temb: torch.Tensor,
        original_context_length: int,
    ):
        """Normalize and modulate the last ``original_context_length`` tokens.

        Args:
            hidden_states: token features; sliced on dim 1.
            temb: time embedding, sliced on dim 1 the same way
                (assumed ``(batch, tokens, dim)`` -- TODO confirm with caller).
            original_context_length: number of trailing tokens to keep.

        Returns:
            ``norm(tail) * (1 + scale) + shift`` for the kept tokens.
        """
        tail = slice(-original_context_length, None)

        temb = temb[:, tail, :]
        # (1, 1, 2, dim) table + (B, L, 1, dim) embedding -> (B, L, 2, dim)
        modulation = self.scale_shift_table.unsqueeze(0).to(temb.device) + temb.unsqueeze(2)
        shift, scale = modulation.chunk(2, dim=2)

        target_device = hidden_states.device
        shift = shift.squeeze(2).to(target_device)
        scale = scale.squeeze(2).to(target_device)

        normed = self.norm(hidden_states[:, tail, :])
        return normed * (1 + scale) + shift
|
||||||
|
|
||||||
|
|
||||||
|
class HeliosSelfAttention(nn.Module):
    """Multi-head attention layer used for both self- and cross-attention.

    Queries come from ``x``; keys/values come from ``context`` (or ``x`` when
    no context is given). Optionally amplifies the keys of the leading
    "history" tokens by a learned, bounded scale.
    """

    def __init__(
        self,
        dim,
        num_heads,
        qk_norm=True,
        eps=1e-6,
        is_cross_attention=False,
        is_amplify_history=False,
        history_scale_mode="per_head",
        operation_settings={},
    ):
        """Build the projections, optional q/k RMSNorm and history scale.

        Args:
            dim: model width; split evenly across ``num_heads``.
            num_heads: number of attention heads.
            qk_norm: apply RMSNorm to projected q and k when True.
            eps: epsilon for the q/k norms.
            is_cross_attention: disables history amplification when True.
            is_amplify_history: create the learned history key scale.
            history_scale_mode: ``"scalar"`` for a single scale, anything
                else for one scale per head.
            operation_settings: dict with ``operations``, ``device``, ``dtype``.
        """
        super().__init__()
        ops = operation_settings.get("operations")
        device = operation_settings.get("device")
        dtype = operation_settings.get("dtype")

        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.is_cross_attention = is_cross_attention
        self.is_amplify_history = is_amplify_history

        self.to_q = ops.Linear(dim, dim, bias=True, device=device, dtype=dtype)
        self.to_k = ops.Linear(dim, dim, bias=True, device=device, dtype=dtype)
        self.to_v = ops.Linear(dim, dim, bias=True, device=device, dtype=dtype)
        # Kept as a ModuleList (projection + no-op dropout) so parameter keys
        # are ``to_out.0.*``.
        self.to_out = nn.ModuleList([
            ops.Linear(dim, dim, bias=True, device=device, dtype=dtype),
            nn.Dropout(0.0),
        ])

        if qk_norm:
            # The norm spans the full ``dim`` (all heads together), applied
            # before the per-head reshape in forward().
            self.norm_q = ops.RMSNorm(dim, eps=eps, elementwise_affine=True, device=device, dtype=dtype)
            self.norm_k = ops.RMSNorm(dim, eps=eps, elementwise_affine=True, device=device, dtype=dtype)
        else:
            self.norm_q = nn.Identity()
            self.norm_k = nn.Identity()

        # Always defined so forward() never depends on construction flags.
        self.history_scale_mode = history_scale_mode
        self.max_scale = 10.0

        if is_amplify_history:
            scale_count = 1 if history_scale_mode == "scalar" else num_heads
            self.history_key_scale = nn.Parameter(torch.ones(
                scale_count,
                device=device,
                dtype=dtype,
            ))

    def forward(
        self,
        x,
        context=None,
        freqs=None,
        original_context_length=None,
        transformer_options={},
    ):
        """Run attention.

        Args:
            x: query tokens; its shape unpacks as ``(batch, tokens, dim)``.
            context: key/value tokens; defaults to ``x``.
            freqs: rotary frequencies applied to q and k when given.
            original_context_length: number of trailing non-history tokens
                in ``x``; tokens before them are treated as history.
            transformer_options: forwarded to ``optimized_attention``.

        Returns:
            The output-projected attention result.
        """
        if context is None:
            context = x

        b, sq, _ = x.shape
        sk = context.shape[1]

        q = self.norm_q(self.to_q(x)).view(b, sq, self.num_heads, self.head_dim)
        k = self.norm_k(self.to_k(context)).view(b, sk, self.num_heads, self.head_dim)
        v = self.to_v(context).view(b, sk, self.num_heads, self.head_dim)

        if freqs is not None:
            q = apply_rope1(q, freqs)
            k = apply_rope1(k, freqs)

        # Rope may have changed the dtype; attention needs q/k/v to agree.
        if q.dtype != v.dtype:
            q = q.to(v.dtype)
        if k.dtype != v.dtype:
            k = k.to(v.dtype)

        amplify = (
            not self.is_cross_attention
            and self.is_amplify_history
            and original_context_length is not None
        )
        if amplify:
            history_seq_len = sq - original_context_length
            if history_seq_len > 0:
                # Bounded scale in [1, max_scale]. Computed in float32 on k's
                # device, then cast to k's dtype so the multiply/cat below
                # cannot promote k away from the dtype shared with q and v.
                raw_scale = self.history_key_scale.to(device=k.device, dtype=torch.float32)
                scale_key = 1.0 + torch.sigmoid(raw_scale) * (self.max_scale - 1.0)
                if self.history_scale_mode == "per_head":
                    scale_key = scale_key.view(1, 1, -1, 1)
                scale_key = scale_key.to(k.dtype)
                k = torch.cat([k[:, :history_seq_len] * scale_key, k[:, history_seq_len:]], dim=1)

        y = optimized_attention(
            q.view(b, sq, -1),
            k.view(b, sk, -1),
            v.view(b, sk, -1),
            heads=self.num_heads,
            transformer_options=transformer_options,
        )
        y = self.to_out[0](y)
        y = self.to_out[1](y)
        return y
|
||||||
|
|
||||||
|
|
||||||
|
class HeliosAttentionBlock(nn.Module):
    """One transformer block: modulated self-attention, cross-attention, modulated FFN.

    A learned ``(1, 6, dim)`` table is added to the incoming time modulation
    and split into six terms: shift/scale/gate for self-attention followed by
    shift/scale/gate for the feed-forward network.
    """

    def __init__(
        self,
        dim,
        ffn_dim,
        num_heads,
        qk_norm=True,
        cross_attn_norm=True,
        eps=1e-6,
        guidance_cross_attn=False,
        is_amplify_history=False,
        history_scale_mode="per_head",
        operation_settings={},
    ):
        """Create the norms, attention layers, FFN and modulation table.

        Args:
            dim: model width.
            ffn_dim: hidden width of the feed-forward network.
            num_heads: attention heads for both attention layers.
            qk_norm: forwarded to both attention layers.
            cross_attn_norm: use an affine LayerNorm before cross-attention
                (identity otherwise).
            eps: epsilon for all norms.
            guidance_cross_attn: when True, cross-attention is applied only to
                the trailing ``original_context_length`` tokens.
            is_amplify_history: forwarded to the self-attention layer.
            history_scale_mode: forwarded to the self-attention layer.
            operation_settings: dict with ``operations``, ``device``, ``dtype``.
        """
        super().__init__()
        ops = operation_settings.get("operations")
        device = operation_settings.get("device")
        dtype = operation_settings.get("dtype")

        self.norm1 = ops.LayerNorm(dim, eps, elementwise_affine=False, device=device, dtype=dtype)
        self.attn1 = HeliosSelfAttention(
            dim,
            num_heads,
            qk_norm=qk_norm,
            eps=eps,
            is_cross_attention=False,
            is_amplify_history=is_amplify_history,
            history_scale_mode=history_scale_mode,
            operation_settings=operation_settings,
        )

        if cross_attn_norm:
            self.norm2 = ops.LayerNorm(dim, eps, elementwise_affine=True, device=device, dtype=dtype)
        else:
            self.norm2 = nn.Identity()
        self.attn2 = HeliosSelfAttention(
            dim,
            num_heads,
            qk_norm=qk_norm,
            eps=eps,
            is_cross_attention=True,
            operation_settings=operation_settings,
        )

        self.norm3 = ops.LayerNorm(dim, eps, elementwise_affine=False, device=device, dtype=dtype)
        self.ffn = nn.Sequential(
            ops.Linear(dim, ffn_dim, device=device, dtype=dtype),
            nn.GELU(approximate="tanh"),
            ops.Linear(ffn_dim, dim, device=device, dtype=dtype),
        )

        self.scale_shift_table = nn.Parameter(torch.randn(
            1,
            6,
            dim,
            device=device,
            dtype=dtype,
        ) / dim**0.5)
        self.guidance_cross_attn = guidance_cross_attn

    def forward(self, x, context, e, freqs, original_context_length=None, transformer_options={}):
        """Apply the block to ``x``.

        Args:
            x: token features; dim 1 is the token axis.
            context: conditioning tokens for cross-attention.
            e: time modulation. A 4-D input is chunked on dim 2 (per-token
                modulation); otherwise it is chunked on dim 1.
            freqs: rotary frequencies for self-attention.
            original_context_length: number of trailing non-history tokens.
            transformer_options: forwarded to the attention layers.

        Returns:
            Updated token features with the same token count as ``x``.
        """
        # Align the raw parameter with the activation device before adding,
        # matching how OutputNorm treats its own table.
        table = self.scale_shift_table.to(device=e.device)
        if e.ndim == 4:
            e = (table.unsqueeze(0) + e.float()).chunk(6, dim=2)
            e = [v.squeeze(2) for v in e]
        else:
            e = (table + e.float()).chunk(6, dim=1)

        # self-attn: shift e[0], scale e[1], gate e[2]
        y = self.attn1(
            torch.addcmul(repeat_e(e[0], x), self.norm1(x), 1 + repeat_e(e[1], x)),
            freqs=freqs,
            original_context_length=original_context_length,
            transformer_options=transformer_options,
        )
        x = torch.addcmul(x, y, repeat_e(e[2], x))

        # cross-attn
        if self.guidance_cross_attn and original_context_length is not None:
            # Only the current (non-history) tokens attend to the context;
            # history tokens pass through unchanged.
            history_seq_len = x.shape[1] - original_context_length
            history_x, x_main = torch.split(x, [history_seq_len, original_context_length], dim=1)
            x_main = x_main + self.attn2(
                self.norm2(x_main),
                context=context,
                transformer_options=transformer_options,
            )
            x = torch.cat([history_x, x_main], dim=1)
        else:
            x = x + self.attn2(self.norm2(x), context=context, transformer_options=transformer_options)

        # ffn: shift e[3], scale e[4], gate e[5]
        y = self.ffn(torch.addcmul(repeat_e(e[3], x), self.norm3(x), 1 + repeat_e(e[4], x)))
        x = torch.addcmul(x, y, repeat_e(e[5], x))
        return x
|
||||||
|
|
||||||
|
|
||||||
|
class HeliosModel(torch.nn.Module):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model_type="t2v",
|
||||||
|
patch_size=(1, 2, 2),
|
||||||
|
num_attention_heads=40,
|
||||||
|
attention_head_dim=128,
|
||||||
|
in_channels=16,
|
||||||
|
out_channels=16,
|
||||||
|
text_dim=4096,
|
||||||
|
freq_dim=256,
|
||||||
|
ffn_dim=13824,
|
||||||
|
num_layers=40,
|
||||||
|
cross_attn_norm=True,
|
||||||
|
qk_norm=True,
|
||||||
|
eps=1e-6,
|
||||||
|
rope_dim=(44, 42, 42),
|
||||||
|
rope_theta=10000.0,
|
||||||
|
guidance_cross_attn=True,
|
||||||
|
zero_history_timestep=True,
|
||||||
|
has_multi_term_memory_patch=True,
|
||||||
|
is_amplify_history=False,
|
||||||
|
history_scale_mode="per_head",
|
||||||
|
image_model=None,
|
||||||
|
device=None,
|
||||||
|
dtype=None,
|
||||||
|
operations=None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
del model_type, image_model, kwargs
|
||||||
|
super().__init__()
|
||||||
|
self.dtype = dtype
|
||||||
|
operation_settings = {
|
||||||
|
"operations": operations,
|
||||||
|
"device": device,
|
||||||
|
"dtype": dtype,
|
||||||
|
}
|
||||||
|
|
||||||
|
dim = num_attention_heads * attention_head_dim
|
||||||
|
self.patch_size = patch_size
|
||||||
|
self.out_dim = out_channels or in_channels
|
||||||
|
self.dim = dim
|
||||||
|
self.freq_dim = freq_dim
|
||||||
|
self.zero_history_timestep = zero_history_timestep
|
||||||
|
|
||||||
|
# embeddings
|
||||||
|
self.patch_embedding = operations.Conv3d(
|
||||||
|
in_channels,
|
||||||
|
dim,
|
||||||
|
kernel_size=patch_size,
|
||||||
|
stride=patch_size,
|
||||||
|
device=operation_settings.get("device"),
|
||||||
|
dtype=torch.float32,
|
||||||
|
)
|
||||||
|
self.text_embedding = nn.Sequential(
|
||||||
|
operations.Linear(
|
||||||
|
text_dim,
|
||||||
|
dim,
|
||||||
|
device=operation_settings.get("device"),
|
||||||
|
dtype=operation_settings.get("dtype"),
|
||||||
|
),
|
||||||
|
nn.GELU(approximate="tanh"),
|
||||||
|
operations.Linear(
|
||||||
|
dim,
|
||||||
|
dim,
|
||||||
|
device=operation_settings.get("device"),
|
||||||
|
dtype=operation_settings.get("dtype"),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
self.time_embedding = nn.Sequential(
|
||||||
|
operations.Linear(
|
||||||
|
freq_dim,
|
||||||
|
dim,
|
||||||
|
device=operation_settings.get("device"),
|
||||||
|
dtype=operation_settings.get("dtype"),
|
||||||
|
),
|
||||||
|
nn.SiLU(),
|
||||||
|
operations.Linear(
|
||||||
|
dim,
|
||||||
|
dim,
|
||||||
|
device=operation_settings.get("device"),
|
||||||
|
dtype=operation_settings.get("dtype"),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
self.time_projection = nn.Sequential(
|
||||||
|
nn.SiLU(),
|
||||||
|
operations.Linear(
|
||||||
|
dim,
|
||||||
|
dim * 6,
|
||||||
|
device=operation_settings.get("device"),
|
||||||
|
dtype=operation_settings.get("dtype"),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
d = dim // num_attention_heads
|
||||||
|
self.rope_embedder = EmbedND(dim=d, theta=rope_theta, axes_dim=list(rope_dim))
|
||||||
|
|
||||||
|
# pyramidal embedding
|
||||||
|
if has_multi_term_memory_patch:
|
||||||
|
self.patch_short = operations.Conv3d(
|
||||||
|
in_channels,
|
||||||
|
dim,
|
||||||
|
kernel_size=patch_size,
|
||||||
|
stride=patch_size,
|
||||||
|
device=operation_settings.get("device"),
|
||||||
|
dtype=torch.float32,
|
||||||
|
)
|
||||||
|
self.patch_mid = operations.Conv3d(
|
||||||
|
in_channels,
|
||||||
|
dim,
|
||||||
|
kernel_size=tuple(2 * p for p in patch_size),
|
||||||
|
stride=tuple(2 * p for p in patch_size),
|
||||||
|
device=operation_settings.get("device"),
|
||||||
|
dtype=torch.float32,
|
||||||
|
)
|
||||||
|
self.patch_long = operations.Conv3d(
|
||||||
|
in_channels,
|
||||||
|
dim,
|
||||||
|
kernel_size=tuple(4 * p for p in patch_size),
|
||||||
|
stride=tuple(4 * p for p in patch_size),
|
||||||
|
device=operation_settings.get("device"),
|
||||||
|
dtype=torch.float32,
|
||||||
|
)
|
||||||
|
|
||||||
|
# blocks
|
||||||
|
self.blocks = nn.ModuleList([HeliosAttentionBlock(
|
||||||
|
dim,
|
||||||
|
ffn_dim,
|
||||||
|
num_attention_heads,
|
||||||
|
qk_norm=qk_norm,
|
||||||
|
cross_attn_norm=cross_attn_norm,
|
||||||
|
eps=eps,
|
||||||
|
guidance_cross_attn=guidance_cross_attn,
|
||||||
|
is_amplify_history=is_amplify_history,
|
||||||
|
history_scale_mode=history_scale_mode,
|
||||||
|
operation_settings=operation_settings,
|
||||||
|
) for _ in range(num_layers)])
|
||||||
|
|
||||||
|
# head
|
||||||
|
self.norm_out = OutputNorm(dim, eps=eps, operation_settings=operation_settings)
|
||||||
|
self.proj_out = operations.Linear(
|
||||||
|
dim,
|
||||||
|
self.out_dim * math.prod(patch_size),
|
||||||
|
device=operation_settings.get("device"),
|
||||||
|
dtype=operation_settings.get("dtype"),
|
||||||
|
)
|
||||||
|
|
||||||
|
def rope_encode(
|
||||||
|
self,
|
||||||
|
t,
|
||||||
|
h,
|
||||||
|
w,
|
||||||
|
t_start=0,
|
||||||
|
steps_t=None,
|
||||||
|
steps_h=None,
|
||||||
|
steps_w=None,
|
||||||
|
device=None,
|
||||||
|
dtype=None,
|
||||||
|
transformer_options={},
|
||||||
|
frame_indices=None,
|
||||||
|
):
|
||||||
|
patch_size = self.patch_size
|
||||||
|
t_len = (t + (patch_size[0] // 2)) // patch_size[0]
|
||||||
|
h_len = (h + (patch_size[1] // 2)) // patch_size[1]
|
||||||
|
w_len = (w + (patch_size[2] // 2)) // patch_size[2]
|
||||||
|
|
||||||
|
if steps_t is None:
|
||||||
|
steps_t = t_len
|
||||||
|
if steps_h is None:
|
||||||
|
steps_h = h_len
|
||||||
|
if steps_w is None:
|
||||||
|
steps_w = w_len
|
||||||
|
|
||||||
|
h_start = 0
|
||||||
|
w_start = 0
|
||||||
|
rope_options = transformer_options.get("rope_options", None)
|
||||||
|
if rope_options is not None:
|
||||||
|
t_len = (t_len - 1.0) * rope_options.get("scale_t", 1.0) + 1.0
|
||||||
|
h_len = (h_len - 1.0) * rope_options.get("scale_y", 1.0) + 1.0
|
||||||
|
w_len = (w_len - 1.0) * rope_options.get("scale_x", 1.0) + 1.0
|
||||||
|
|
||||||
|
t_start += rope_options.get("shift_t", 0.0)
|
||||||
|
h_start += rope_options.get("shift_y", 0.0)
|
||||||
|
w_start += rope_options.get("shift_x", 0.0)
|
||||||
|
|
||||||
|
if frame_indices is None:
|
||||||
|
t_coords = torch.linspace(
|
||||||
|
t_start,
|
||||||
|
t_start + (t_len - 1),
|
||||||
|
steps=steps_t,
|
||||||
|
device=device,
|
||||||
|
dtype=dtype,
|
||||||
|
).reshape(1, -1, 1, 1)
|
||||||
|
batch_size = 1
|
||||||
|
else:
|
||||||
|
batch_size = frame_indices.shape[0]
|
||||||
|
t_coords = frame_indices.to(device=device, dtype=dtype)
|
||||||
|
if t_coords.shape[1] != steps_t:
|
||||||
|
t_coords = torch.nn.functional.interpolate(
|
||||||
|
t_coords.unsqueeze(1),
|
||||||
|
size=steps_t,
|
||||||
|
mode="linear",
|
||||||
|
align_corners=False,
|
||||||
|
).squeeze(1)
|
||||||
|
t_coords = (t_coords + t_start)[:, :, None, None]
|
||||||
|
|
||||||
|
img_ids = torch.zeros((batch_size, steps_t, steps_h, steps_w, 3), device=device, dtype=dtype)
|
||||||
|
img_ids[:, :, :, :, 0] = img_ids[:, :, :, :, 0] + t_coords.expand(batch_size, steps_t, steps_h, steps_w)
|
||||||
|
img_ids[:, :, :, :, 1] = img_ids[:, :, :, :, 1] + torch.linspace(h_start, h_start + (h_len - 1), steps=steps_h, device=device, dtype=dtype).reshape(1, 1, -1, 1)
|
||||||
|
img_ids[:, :, :, :, 2] = img_ids[:, :, :, :, 2] + torch.linspace(w_start, w_start + (w_len - 1), steps=steps_w, device=device, dtype=dtype).reshape(1, 1, 1, -1)
|
||||||
|
img_ids = img_ids.reshape(batch_size, -1, img_ids.shape[-1])
|
||||||
|
return self.rope_embedder(img_ids).movedim(1, 2)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
x,
|
||||||
|
timestep,
|
||||||
|
context,
|
||||||
|
clip_fea=None,
|
||||||
|
time_dim_concat=None,
|
||||||
|
transformer_options={},
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||||
|
self._forward,
|
||||||
|
self,
|
||||||
|
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options),
|
||||||
|
).execute(
|
||||||
|
x,
|
||||||
|
timestep,
|
||||||
|
context,
|
||||||
|
clip_fea,
|
||||||
|
time_dim_concat,
|
||||||
|
transformer_options,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _forward(
|
||||||
|
self,
|
||||||
|
x,
|
||||||
|
timestep,
|
||||||
|
context,
|
||||||
|
clip_fea=None,
|
||||||
|
time_dim_concat=None,
|
||||||
|
transformer_options={},
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
del clip_fea, time_dim_concat
|
||||||
|
|
||||||
|
_, _, t, h, w = x.shape
|
||||||
|
x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size)
|
||||||
|
|
||||||
|
out = self.forward_orig(
|
||||||
|
hidden_states=x,
|
||||||
|
timestep=timestep,
|
||||||
|
context=context,
|
||||||
|
indices_hidden_states=kwargs.get("indices_hidden_states", None),
|
||||||
|
indices_latents_history_short=kwargs.get("indices_latents_history_short", None),
|
||||||
|
indices_latents_history_mid=kwargs.get("indices_latents_history_mid", None),
|
||||||
|
indices_latents_history_long=kwargs.get("indices_latents_history_long", None),
|
||||||
|
latents_history_short=kwargs.get("latents_history_short", None),
|
||||||
|
latents_history_mid=kwargs.get("latents_history_mid", None),
|
||||||
|
latents_history_long=kwargs.get("latents_history_long", None),
|
||||||
|
transformer_options=transformer_options,
|
||||||
|
)
|
||||||
|
|
||||||
|
return out[:, :, :t, :h, :w]
|
||||||
|
|
||||||
|
def forward_orig(
|
||||||
|
self,
|
||||||
|
hidden_states,
|
||||||
|
timestep,
|
||||||
|
context,
|
||||||
|
indices_hidden_states=None,
|
||||||
|
indices_latents_history_short=None,
|
||||||
|
indices_latents_history_mid=None,
|
||||||
|
indices_latents_history_long=None,
|
||||||
|
latents_history_short=None,
|
||||||
|
latents_history_mid=None,
|
||||||
|
latents_history_long=None,
|
||||||
|
transformer_options={},
|
||||||
|
):
|
||||||
|
batch_size = hidden_states.shape[0]
|
||||||
|
p_t, p_h, p_w = self.patch_size
|
||||||
|
|
||||||
|
# embeddings
|
||||||
|
hidden_states = self.patch_embedding(hidden_states.float()).to(hidden_states.dtype)
|
||||||
|
_, _, post_t, post_h, post_w = hidden_states.shape
|
||||||
|
hidden_states = hidden_states.flatten(2).transpose(1, 2)
|
||||||
|
|
||||||
|
if indices_hidden_states is None:
|
||||||
|
indices_hidden_states = (torch.arange(0, post_t, device=hidden_states.device).unsqueeze(0).expand(batch_size, -1))
|
||||||
|
|
||||||
|
freqs = self.rope_encode(
|
||||||
|
t=post_t * self.patch_size[0],
|
||||||
|
h=post_h * self.patch_size[1],
|
||||||
|
w=post_w * self.patch_size[2],
|
||||||
|
steps_t=post_t,
|
||||||
|
steps_h=post_h,
|
||||||
|
steps_w=post_w,
|
||||||
|
device=hidden_states.device,
|
||||||
|
dtype=hidden_states.dtype,
|
||||||
|
transformer_options=transformer_options,
|
||||||
|
frame_indices=indices_hidden_states,
|
||||||
|
)
|
||||||
|
original_context_length = hidden_states.shape[1]
|
||||||
|
|
||||||
|
if (latents_history_short is not None and indices_latents_history_short is not None and hasattr(self, "patch_short")):
|
||||||
|
x_short = self.patch_short(latents_history_short.float()).to(hidden_states.dtype)
|
||||||
|
_, _, ts, hs, ws = x_short.shape
|
||||||
|
x_short = x_short.flatten(2).transpose(1, 2)
|
||||||
|
f_short = self.rope_encode(
|
||||||
|
t=ts * self.patch_size[0],
|
||||||
|
h=hs * self.patch_size[1],
|
||||||
|
w=ws * self.patch_size[2],
|
||||||
|
steps_t=ts,
|
||||||
|
steps_h=hs,
|
||||||
|
steps_w=ws,
|
||||||
|
device=x_short.device,
|
||||||
|
dtype=x_short.dtype,
|
||||||
|
transformer_options=transformer_options,
|
||||||
|
frame_indices=indices_latents_history_short,
|
||||||
|
)
|
||||||
|
hidden_states = torch.cat([x_short, hidden_states], dim=1)
|
||||||
|
freqs = torch.cat([f_short, freqs], dim=1)
|
||||||
|
|
||||||
|
if (latents_history_mid is not None and indices_latents_history_mid is not None and hasattr(self, "patch_mid")):
|
||||||
|
x_mid = self.patch_mid(pad_for_3d_conv(latents_history_mid, (2, 4, 4)).float()).to(hidden_states.dtype)
|
||||||
|
_, _, tm, hm, wm = x_mid.shape
|
||||||
|
x_mid = x_mid.flatten(2).transpose(1, 2)
|
||||||
|
f_mid = self.rope_encode(
|
||||||
|
t=tm * self.patch_size[0],
|
||||||
|
h=hm * self.patch_size[1],
|
||||||
|
w=wm * self.patch_size[2],
|
||||||
|
steps_t=tm,
|
||||||
|
steps_h=hm,
|
||||||
|
steps_w=wm,
|
||||||
|
device=x_mid.device,
|
||||||
|
dtype=x_mid.dtype,
|
||||||
|
transformer_options=transformer_options,
|
||||||
|
frame_indices=indices_latents_history_mid,
|
||||||
|
)
|
||||||
|
hidden_states = torch.cat([x_mid, hidden_states], dim=1)
|
||||||
|
freqs = torch.cat([f_mid, freqs], dim=1)
|
||||||
|
|
||||||
|
if (latents_history_long is not None and indices_latents_history_long is not None and hasattr(self, "patch_long")):
|
||||||
|
x_long = self.patch_long(pad_for_3d_conv(latents_history_long, (4, 8, 8)).float()).to(hidden_states.dtype)
|
||||||
|
_, _, tl, hl, wl = x_long.shape
|
||||||
|
x_long = x_long.flatten(2).transpose(1, 2)
|
||||||
|
f_long = self.rope_encode(
|
||||||
|
t=tl * self.patch_size[0],
|
||||||
|
h=hl * self.patch_size[1],
|
||||||
|
w=wl * self.patch_size[2],
|
||||||
|
steps_t=tl,
|
||||||
|
steps_h=hl,
|
||||||
|
steps_w=wl,
|
||||||
|
device=x_long.device,
|
||||||
|
dtype=x_long.dtype,
|
||||||
|
transformer_options=transformer_options,
|
||||||
|
frame_indices=indices_latents_history_long,
|
||||||
|
)
|
||||||
|
hidden_states = torch.cat([x_long, hidden_states], dim=1)
|
||||||
|
freqs = torch.cat([f_long, freqs], dim=1)
|
||||||
|
|
||||||
|
history_context_length = hidden_states.shape[1] - original_context_length
|
||||||
|
|
||||||
|
if timestep.ndim == 0:
|
||||||
|
timestep = timestep.unsqueeze(0)
|
||||||
|
timestep = timestep.to(hidden_states.device)
|
||||||
|
if timestep.shape[0] != batch_size:
|
||||||
|
timestep = timestep[:1].expand(batch_size)
|
||||||
|
|
||||||
|
# time embeddings
|
||||||
|
e = self.time_embedding(sinusoidal_embedding_1d(self.freq_dim, timestep.flatten()).to(dtype=hidden_states.dtype))
|
||||||
|
e = e.reshape(batch_size, -1, e.shape[-1])
|
||||||
|
e0 = self.time_projection(e).unflatten(2, (6, self.dim))
|
||||||
|
context = self.text_embedding(context)
|
||||||
|
|
||||||
|
if self.zero_history_timestep and history_context_length > 0:
|
||||||
|
timestep_t0 = torch.zeros((1, ), dtype=timestep.dtype, device=timestep.device)
|
||||||
|
e_t0 = self.time_embedding(sinusoidal_embedding_1d(self.freq_dim, timestep_t0.flatten()).to(dtype=hidden_states.dtype))
|
||||||
|
e_t0 = e_t0.reshape(1, -1, e_t0.shape[-1]).expand(batch_size, history_context_length, -1)
|
||||||
|
e0_t0 = self.time_projection(e_t0[:, :1]).unflatten(2, (6, self.dim))
|
||||||
|
e0_t0 = (e0_t0.view(batch_size, 1, 6, self.dim).permute(0, 2, 1, 3).expand(batch_size, 6, history_context_length, self.dim))
|
||||||
|
|
||||||
|
e = e.expand(batch_size, original_context_length, -1)
|
||||||
|
e0 = (e0.view(batch_size, 1, 6, self.dim).permute(0, 2, 1, 3).expand(batch_size, 6, original_context_length, self.dim))
|
||||||
|
e = torch.cat([e_t0, e], dim=1)
|
||||||
|
e0 = torch.cat([e0_t0, e0], dim=2)
|
||||||
|
else:
|
||||||
|
e = e.expand(batch_size, hidden_states.shape[1], -1)
|
||||||
|
e0 = (e0.view(batch_size, 1, 6, self.dim).permute(0, 2, 1, 3).expand(batch_size, 6, hidden_states.shape[1], self.dim))
|
||||||
|
|
||||||
|
e0 = e0.permute(0, 2, 1, 3)
|
||||||
|
|
||||||
|
for block in self.blocks:
|
||||||
|
hidden_states = block(
|
||||||
|
hidden_states,
|
||||||
|
context,
|
||||||
|
e0,
|
||||||
|
freqs,
|
||||||
|
original_context_length=original_context_length,
|
||||||
|
transformer_options=transformer_options,
|
||||||
|
)
|
||||||
|
|
||||||
|
hidden_states = self.norm_out(hidden_states, e, original_context_length)
|
||||||
|
hidden_states = self.proj_out(hidden_states)
|
||||||
|
|
||||||
|
return self.unpatchify(hidden_states, (post_t, post_h, post_w))
|
||||||
|
|
||||||
|
def unpatchify(self, x, grid_sizes):
|
||||||
|
c = self.out_dim
|
||||||
|
b = x.shape[0]
|
||||||
|
u = x[:, :math.prod(grid_sizes)].view(b, *grid_sizes, *self.patch_size, c)
|
||||||
|
u = torch.einsum("bfhwpqrc->bcfphqwr", u)
|
||||||
|
u = u.reshape(b, c, *[i * j for i, j in zip(grid_sizes, self.patch_size)])
|
||||||
|
return u
|
||||||
|
|
||||||
|
def load_state_dict(self, state_dict, strict=True, assign=False):
|
||||||
|
# Keep compatibility with reference diffusers key names.
|
||||||
|
remapped = {}
|
||||||
|
for k, v in state_dict.items():
|
||||||
|
nk = k
|
||||||
|
nk = nk.replace("condition_embedder.time_embedder.linear_1.", "time_embedding.0.")
|
||||||
|
nk = nk.replace("condition_embedder.time_embedder.linear_2.", "time_embedding.2.")
|
||||||
|
nk = nk.replace("condition_embedder.time_proj.", "time_projection.1.")
|
||||||
|
nk = nk.replace("condition_embedder.text_embedder.linear_1.", "text_embedding.0.")
|
||||||
|
nk = nk.replace("condition_embedder.text_embedder.linear_2.", "text_embedding.2.")
|
||||||
|
nk = nk.replace("blocks.", "blocks.")
|
||||||
|
remapped[nk] = v
|
||||||
|
|
||||||
|
return super().load_state_dict(remapped, strict=strict, assign=assign)
|
||||||
|
|
||||||
|
|
||||||
|
# Backward-compatible alias for existing integration points.
|
||||||
|
HeliosTransformer3DModel = HeliosModel
|
||||||
@ -41,6 +41,7 @@ import comfy.ldm.cosmos.predict2
|
|||||||
import comfy.ldm.lumina.model
|
import comfy.ldm.lumina.model
|
||||||
import comfy.ldm.wan.model
|
import comfy.ldm.wan.model
|
||||||
import comfy.ldm.wan.model_animate
|
import comfy.ldm.wan.model_animate
|
||||||
|
import comfy.ldm.helios.model
|
||||||
import comfy.ldm.hunyuan3d.model
|
import comfy.ldm.hunyuan3d.model
|
||||||
import comfy.ldm.hidream.model
|
import comfy.ldm.hidream.model
|
||||||
import comfy.ldm.chroma.model
|
import comfy.ldm.chroma.model
|
||||||
@ -1268,6 +1269,35 @@ class ZImagePixelSpace(Lumina2):
|
|||||||
BaseModel.__init__(self, model_config, model_type, device=device, unet_model=comfy.ldm.lumina.model.NextDiTPixelSpace)
|
BaseModel.__init__(self, model_config, model_type, device=device, unet_model=comfy.ldm.lumina.model.NextDiTPixelSpace)
|
||||||
self.memory_usage_factor_conds = ("ref_latents",)
|
self.memory_usage_factor_conds = ("ref_latents",)
|
||||||
|
|
||||||
|
class Helios(BaseModel):
    """BaseModel wrapper that runs the Helios transformer as its diffusion model."""

    def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
        super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.helios.model.HeliosTransformer3DModel)

    def extra_conds(self, **kwargs):
        """Collect the conditioning entries handed to the Helios model.

        Starts from the base-class conds, adds the text conditioning as
        ``c_crossattn`` when ``cross_attn`` is supplied, then forwards each
        Helios history input that is present in ``kwargs``. History latents
        (keys starting with ``latents_``) go through ``process_latent_in``
        first; index entries are passed as given.
        """
        conds = super().extra_conds(**kwargs)

        text_cond = kwargs.get("cross_attn", None)
        if text_cond is not None:
            conds["c_crossattn"] = comfy.conds.CONDRegular(text_cond)

        history_keys = (
            "indices_hidden_states",
            "indices_latents_history_short",
            "indices_latents_history_mid",
            "indices_latents_history_long",
            "latents_history_short",
            "latents_history_mid",
            "latents_history_long",
        )
        for name in history_keys:
            tensor = kwargs.get(name, None)
            if tensor is not None:
                if name.startswith("latents_"):
                    tensor = self.process_latent_in(tensor)
                conds[name] = comfy.conds.CONDRegular(tensor)

        return conds
|
||||||
|
|
||||||
class WAN21(BaseModel):
|
class WAN21(BaseModel):
|
||||||
def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
|
def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
|
||||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel)
|
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel)
|
||||||
|
|||||||
@ -489,6 +489,48 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
|||||||
|
|
||||||
return dit_config
|
return dit_config
|
||||||
|
|
||||||
|
if '{}condition_embedder.time_proj.weight'.format(key_prefix) in state_dict_keys and '{}patch_embedding.weight'.format(key_prefix) in state_dict_keys: # Helios
|
||||||
|
dit_config = {}
|
||||||
|
dit_config["image_model"] = "helios"
|
||||||
|
|
||||||
|
patch_weight = state_dict['{}patch_embedding.weight'.format(key_prefix)]
|
||||||
|
inner_dim = patch_weight.shape[0]
|
||||||
|
patch_size = tuple(patch_weight.shape[2:])
|
||||||
|
out_proj = state_dict['{}proj_out.weight'.format(key_prefix)]
|
||||||
|
|
||||||
|
dit_config["patch_size"] = patch_size
|
||||||
|
dit_config["in_channels"] = patch_weight.shape[1]
|
||||||
|
dit_config["out_channels"] = out_proj.shape[0] // math.prod(patch_size)
|
||||||
|
dit_config["text_dim"] = state_dict['{}condition_embedder.text_embedder.linear_1.weight'.format(key_prefix)].shape[1]
|
||||||
|
dit_config["freq_dim"] = state_dict['{}condition_embedder.time_embedder.linear_1.weight'.format(key_prefix)].shape[1]
|
||||||
|
dit_config["num_layers"] = count_blocks(state_dict_keys, '{}blocks.'.format(key_prefix) + '{}.')
|
||||||
|
dit_config["num_attention_heads"] = inner_dim // 128
|
||||||
|
dit_config["attention_head_dim"] = 128
|
||||||
|
|
||||||
|
ffn_in = state_dict.get('{}blocks.0.ffn.net.0.proj.weight'.format(key_prefix), None)
|
||||||
|
if ffn_in is None:
|
||||||
|
ffn_in = state_dict.get('{}blocks.0.ffn.0.weight'.format(key_prefix), None)
|
||||||
|
if ffn_in is not None:
|
||||||
|
dit_config["ffn_dim"] = ffn_in.shape[0]
|
||||||
|
|
||||||
|
if '{}blocks.0.attn2.add_k_proj.weight'.format(key_prefix) in state_dict_keys:
|
||||||
|
dit_config["added_kv_proj_dim"] = state_dict['{}blocks.0.attn2.add_k_proj.weight'.format(key_prefix)].shape[1]
|
||||||
|
|
||||||
|
if '{}patch_short.weight'.format(key_prefix) in state_dict_keys:
|
||||||
|
dit_config["has_multi_term_memory_patch"] = True
|
||||||
|
else:
|
||||||
|
dit_config["has_multi_term_memory_patch"] = False
|
||||||
|
|
||||||
|
if '{}blocks.0.attn1.history_key_scale'.format(key_prefix) in state_dict_keys:
|
||||||
|
dit_config["is_amplify_history"] = True
|
||||||
|
hk = state_dict['{}blocks.0.attn1.history_key_scale'.format(key_prefix)]
|
||||||
|
dit_config["history_scale_mode"] = "per_head" if len(hk.shape) > 0 and hk.numel() > 1 else "scalar"
|
||||||
|
|
||||||
|
if metadata is not None and "config" in metadata:
|
||||||
|
dit_config.update(json.loads(metadata["config"]).get("transformer", {}))
|
||||||
|
|
||||||
|
return dit_config
|
||||||
|
|
||||||
if '{}head.modulation'.format(key_prefix) in state_dict_keys: # Wan 2.1
|
if '{}head.modulation'.format(key_prefix) in state_dict_keys: # Wan 2.1
|
||||||
dit_config = {}
|
dit_config = {}
|
||||||
dit_config["image_model"] = "wan2.1"
|
dit_config["image_model"] = "wan2.1"
|
||||||
|
|||||||
@ -48,6 +48,7 @@ import comfy.text_encoders.hunyuan_video
|
|||||||
import comfy.text_encoders.cosmos
|
import comfy.text_encoders.cosmos
|
||||||
import comfy.text_encoders.lumina2
|
import comfy.text_encoders.lumina2
|
||||||
import comfy.text_encoders.wan
|
import comfy.text_encoders.wan
|
||||||
|
import comfy.text_encoders.helios
|
||||||
import comfy.text_encoders.hidream
|
import comfy.text_encoders.hidream
|
||||||
import comfy.text_encoders.ace
|
import comfy.text_encoders.ace
|
||||||
import comfy.text_encoders.omnigen2
|
import comfy.text_encoders.omnigen2
|
||||||
@ -1163,6 +1164,7 @@ class CLIPType(Enum):
|
|||||||
NEWBIE = 24
|
NEWBIE = 24
|
||||||
FLUX2 = 25
|
FLUX2 = 25
|
||||||
LONGCAT_IMAGE = 26
|
LONGCAT_IMAGE = 26
|
||||||
|
HELIOS = 27
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@ -1334,6 +1336,10 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
|
|||||||
clip_target.clip = comfy.text_encoders.wan.te(**t5xxl_detect(clip_data))
|
clip_target.clip = comfy.text_encoders.wan.te(**t5xxl_detect(clip_data))
|
||||||
clip_target.tokenizer = comfy.text_encoders.wan.WanT5Tokenizer
|
clip_target.tokenizer = comfy.text_encoders.wan.WanT5Tokenizer
|
||||||
tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None)
|
tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None)
|
||||||
|
elif clip_type == CLIPType.HELIOS:
|
||||||
|
clip_target.clip = comfy.text_encoders.helios.te(**t5xxl_detect(clip_data))
|
||||||
|
clip_target.tokenizer = comfy.text_encoders.helios.HeliosT5Tokenizer
|
||||||
|
tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None)
|
||||||
elif clip_type == CLIPType.HIDREAM:
|
elif clip_type == CLIPType.HIDREAM:
|
||||||
clip_target.clip = comfy.text_encoders.hidream.hidream_clip(**t5xxl_detect(clip_data),
|
clip_target.clip = comfy.text_encoders.hidream.hidream_clip(**t5xxl_detect(clip_data),
|
||||||
clip_l=False, clip_g=False, t5=True, llama=False, dtype_llama=None)
|
clip_l=False, clip_g=False, t5=True, llama=False, dtype_llama=None)
|
||||||
|
|||||||
@ -17,6 +17,7 @@ import comfy.text_encoders.hunyuan_video
|
|||||||
import comfy.text_encoders.cosmos
|
import comfy.text_encoders.cosmos
|
||||||
import comfy.text_encoders.lumina2
|
import comfy.text_encoders.lumina2
|
||||||
import comfy.text_encoders.wan
|
import comfy.text_encoders.wan
|
||||||
|
import comfy.text_encoders.helios
|
||||||
import comfy.text_encoders.ace
|
import comfy.text_encoders.ace
|
||||||
import comfy.text_encoders.omnigen2
|
import comfy.text_encoders.omnigen2
|
||||||
import comfy.text_encoders.qwen_image
|
import comfy.text_encoders.qwen_image
|
||||||
@ -1132,6 +1133,35 @@ class ZImagePixelSpace(ZImage):
|
|||||||
def get_model(self, state_dict, prefix="", device=None):
|
def get_model(self, state_dict, prefix="", device=None):
|
||||||
return model_base.ZImagePixelSpace(self, device=device)
|
return model_base.ZImagePixelSpace(self, device=device)
|
||||||
|
|
||||||
|
class Helios(supported_models_base.BASE):
    """Supported-model entry for Helios (``unet_config["image_model"] == "helios"``)."""

    unet_config = {"image_model": "helios"}
    sampling_settings = {"shift": 1.0}
    unet_extra_config = {}

    latent_format = latent_formats.Wan21
    # Class-level default; __init__ overwrites it with a per-instance value.
    memory_usage_factor = 1.8
    supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]

    vae_key_prefix = ["vae."]
    text_encoder_key_prefix = ["text_encoders."]

    def __init__(self, unet_config):
        """Scale the memory factor by layers x heads relative to a 40 x 40 reference."""
        super().__init__(unet_config)
        layer_count = self.unet_config.get("num_layers", 40)
        head_count = self.unet_config.get("num_attention_heads", 40)
        self.memory_usage_factor = (layer_count * head_count) / (40 * 40) * 1.8

    def get_model(self, state_dict, prefix="", device=None):
        """Instantiate ``model_base.Helios``; ``state_dict`` and ``prefix`` are not used here."""
        model = model_base.Helios(self, device=device)
        return model

    def clip_target(self, state_dict={}):
        """Return the tokenizer / text-encoder pair, with T5 options detected from ``state_dict``."""
        te_prefix = self.text_encoder_key_prefix[0]
        detected = comfy.text_encoders.sd3_clip.t5_xxl_detect(
            state_dict,
            f"{te_prefix}umt5xxl.transformer.",
        )
        encoder_cls = comfy.text_encoders.helios.te(**detected)
        return supported_models_base.ClipTarget(comfy.text_encoders.helios.HeliosT5Tokenizer, encoder_cls)
|
||||||
|
|
||||||
class WAN21_T2V(supported_models_base.BASE):
|
class WAN21_T2V(supported_models_base.BASE):
|
||||||
unet_config = {
|
unet_config = {
|
||||||
"image_model": "wan2.1",
|
"image_model": "wan2.1",
|
||||||
@ -1734,6 +1764,6 @@ class LongCatImage(supported_models_base.BASE):
|
|||||||
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref))
|
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref))
|
||||||
return supported_models_base.ClipTarget(comfy.text_encoders.longcat_image.LongCatImageTokenizer, comfy.text_encoders.longcat_image.te(**hunyuan_detect))
|
return supported_models_base.ClipTarget(comfy.text_encoders.longcat_image.LongCatImageTokenizer, comfy.text_encoders.longcat_image.te(**hunyuan_detect))
|
||||||
|
|
||||||
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, LongCatImage, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImagePixelSpace, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, WAN21_FlowRVS, WAN21_SCAIL, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima]
|
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, LongCatImage, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImagePixelSpace, ZImage, Lumina2, Helios, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, WAN21_FlowRVS, WAN21_SCAIL, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima]
|
||||||
|
|
||||||
models += [SVD_img2vid]
|
models += [SVD_img2vid]
|
||||||
|
|||||||
41
comfy/text_encoders/helios.py
Normal file
41
comfy/text_encoders/helios.py
Normal file
@ -0,0 +1,41 @@
|
|||||||
|
from comfy import sd1_clip
|
||||||
|
from .spiece_tokenizer import SPieceTokenizer
|
||||||
|
import comfy.text_encoders.t5
|
||||||
|
import os
|
||||||
|
|
||||||
|
|
||||||
|
class UMT5XXlModel(sd1_clip.SDClipModel):
    """T5 text model configured from ``umt5_config_xxl.json`` located beside this module."""

    def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, model_options={}):
        module_dir = os.path.dirname(os.path.realpath(__file__))
        config_path = os.path.join(module_dir, "umt5_config_xxl.json")
        super().__init__(
            device=device,
            layer=layer,
            layer_idx=layer_idx,
            textmodel_json_config=config_path,
            dtype=dtype,
            special_tokens={"end": 1, "pad": 0},
            model_class=comfy.text_encoders.t5.T5,
            enable_attention_masks=True,
            zero_out_masked=True,
            model_options=model_options,
        )
|
||||||
|
|
||||||
|
|
||||||
|
class UMT5XXlTokenizer(sd1_clip.SDTokenizer):
    """SentencePiece-backed tokenizer whose model comes from ``tokenizer_data["spiece_model"]``."""

    def __init__(self, embedding_directory=None, tokenizer_data={}):
        # NOTE: embedding_directory is accepted but not forwarded to the parent.
        spiece_model = tokenizer_data.get("spiece_model", None)
        super().__init__(
            spiece_model,
            pad_with_end=False,
            embedding_size=4096,
            embedding_key="umt5xxl",
            tokenizer_class=SPieceTokenizer,
            has_start_token=False,
            pad_to_max_length=False,
            max_length=99999999,
            min_length=512,
            pad_token=0,
            tokenizer_data=tokenizer_data,
        )

    def state_dict(self):
        """Return the serialized SentencePiece model under the ``spiece_model`` key."""
        serialized = self.tokenizer.serialize_model()
        return {"spiece_model": serialized}
|
||||||
|
|
||||||
|
|
||||||
|
class HeliosT5Tokenizer(sd1_clip.SD1Tokenizer):
    """``SD1Tokenizer`` wrapper that registers ``UMT5XXlTokenizer`` under the name ``umt5xxl``."""

    def __init__(self, embedding_directory=None, tokenizer_data={}):
        super().__init__(
            embedding_directory=embedding_directory,
            tokenizer_data=tokenizer_data,
            clip_name="umt5xxl",
            tokenizer=UMT5XXlTokenizer,
        )
|
||||||
|
|
||||||
|
|
||||||
|
class HeliosT5Model(sd1_clip.SD1ClipModel):
    """``SD1ClipModel`` wrapper that registers ``UMT5XXlModel`` under the name ``umt5xxl``."""

    def __init__(self, device="cpu", dtype=None, model_options={}, **kwargs):
        super().__init__(
            device=device,
            dtype=dtype,
            model_options=model_options,
            name="umt5xxl",
            clip_model=UMT5XXlModel,
            **kwargs,
        )
|
||||||
|
|
||||||
|
|
||||||
|
def te(dtype_t5=None, t5_quantization_metadata=None):
    """Create a ``HeliosT5Model`` subclass with the given overrides baked in.

    The returned class forces ``dtype_t5`` as dtype when it is not ``None`` and
    adds ``quantization_metadata`` to a copy of ``model_options`` when
    ``t5_quantization_metadata`` is not ``None``.
    """

    class HeliosTEModel(HeliosT5Model):
        def __init__(self, device="cpu", dtype=None, model_options={}):
            options = model_options
            if t5_quantization_metadata is not None:
                # Copy first so the caller's dict is left untouched.
                options = model_options.copy()
                options["quantization_metadata"] = t5_quantization_metadata
            chosen_dtype = dtype if dtype_t5 is None else dtype_t5
            super().__init__(device=device, dtype=chosen_dtype, model_options=options)

    return HeliosTEModel
|
||||||
928
comfy_extras/nodes_helios.py
Normal file
928
comfy_extras/nodes_helios.py
Normal file
@ -0,0 +1,928 @@
|
|||||||
|
import math
|
||||||
|
import torch
|
||||||
|
|
||||||
|
import nodes
|
||||||
|
import comfy.model_management
|
||||||
|
import comfy.model_patcher
|
||||||
|
import comfy.sample
|
||||||
|
import comfy.samplers
|
||||||
|
import comfy.utils
|
||||||
|
import latent_preview
|
||||||
|
import node_helpers
|
||||||
|
|
||||||
|
from typing_extensions import override
|
||||||
|
from comfy_api.latest import ComfyExtension, io
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_int_list(values, default):
|
||||||
|
if values is None:
|
||||||
|
return default
|
||||||
|
if isinstance(values, (list, tuple)):
|
||||||
|
out = []
|
||||||
|
for v in values:
|
||||||
|
try:
|
||||||
|
out.append(int(v))
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
return out if len(out) > 0 else default
|
||||||
|
|
||||||
|
parts = [x.strip() for x in str(values).replace(";", ",").split(",")]
|
||||||
|
out = []
|
||||||
|
for p in parts:
|
||||||
|
if len(p) == 0:
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
out.append(int(p))
|
||||||
|
except Exception:
|
||||||
|
continue
|
||||||
|
return out if len(out) > 0 else default
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_float_list(values, default):
|
||||||
|
if values is None:
|
||||||
|
return default
|
||||||
|
if isinstance(values, (list, tuple)):
|
||||||
|
out = []
|
||||||
|
for v in values:
|
||||||
|
try:
|
||||||
|
out.append(float(v))
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
return out if len(out) > 0 else default
|
||||||
|
|
||||||
|
parts = [x.strip() for x in str(values).replace(";", ",").split(",")]
|
||||||
|
out = []
|
||||||
|
for p in parts:
|
||||||
|
if len(p) == 0:
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
out.append(float(p))
|
||||||
|
except Exception:
|
||||||
|
continue
|
||||||
|
return out if len(out) > 0 else default
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_condition_value(conditioning, key):
|
||||||
|
for c in conditioning:
|
||||||
|
if len(c) < 2:
|
||||||
|
continue
|
||||||
|
value = c[1].get(key, None)
|
||||||
|
if value is not None:
|
||||||
|
return value
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _upsample_latent_5d(latent, scale=2):
    """Upscale the last two dims of a 5-D latent by ``scale`` using nearest-exact.

    Dim 2 is folded into the batch dim so ``comfy.utils.common_upscale`` can
    work on 4-D input, then the original 5-D layout is restored.
    """
    batch, channels, frames, height, width = latent.shape
    out_h = height * scale
    out_w = width * scale

    stacked = latent.transpose(1, 2).reshape(batch * frames, channels, height, width)
    stacked = comfy.utils.common_upscale(stacked, out_w, out_h, "nearest-exact", "disabled")
    return stacked.reshape(batch, frames, channels, out_h, out_w).transpose(1, 2)
|
||||||
|
|
||||||
|
|
||||||
|
def _sample_block_noise_like(latent, gamma, patch_size=(1, 2, 2)):
|
||||||
|
b, c, t, h, w = latent.shape
|
||||||
|
_, ph, pw = patch_size
|
||||||
|
block_size = ph * pw
|
||||||
|
|
||||||
|
cov = torch.eye(block_size, device=latent.device) * (1.0 + gamma) - torch.ones(block_size, block_size, device=latent.device) * gamma
|
||||||
|
cov += torch.eye(block_size, device=latent.device) * 1e-6
|
||||||
|
|
||||||
|
dist = torch.distributions.MultivariateNormal(torch.zeros(block_size, device=latent.device), covariance_matrix=cov)
|
||||||
|
block_number = b * c * t * max(1, h // ph) * max(1, w // pw)
|
||||||
|
|
||||||
|
noise = dist.sample((block_number,))
|
||||||
|
noise = noise.view(b, c, t, max(1, h // ph), max(1, w // pw), ph, pw)
|
||||||
|
noise = noise.permute(0, 1, 2, 3, 5, 4, 6).reshape(b, c, t, max(1, h // ph) * ph, max(1, w // pw) * pw)
|
||||||
|
noise = noise[:, :, :, :h, :w]
|
||||||
|
return noise
|
||||||
|
|
||||||
|
|
||||||
|
def _helios_global_sigmas(num_train_timesteps=1000, shift=1.0):
|
||||||
|
alphas = torch.linspace(1.0, 1.0 / float(num_train_timesteps), num_train_timesteps + 1)
|
||||||
|
sigmas = 1.0 - alphas
|
||||||
|
if abs(shift - 1.0) > 1e-8:
|
||||||
|
sigmas = shift * sigmas / (1.0 + (shift - 1.0) * sigmas)
|
||||||
|
return torch.flip(sigmas, dims=[0])[:-1]
|
||||||
|
|
||||||
|
|
||||||
|
def _helios_stage_tables(stage_count, stage_range, gamma, num_train_timesteps=1000, shift=1.0):
    """Build the per-stage sigma / timestep lookup tables.

    Args:
        stage_count: Number of stages to tabulate.
        stage_range: Sequence of stage boundaries; entries ``i`` and ``i + 1``
            are read for stage ``i`` (so at least ``stage_count + 1`` entries
            are required) and each is clamped to ``[0, 1]``.
        gamma: Used in the start-sigma correction of every stage after the
            first; it is divided by, so it must be non-zero when
            ``stage_count > 1``.
        num_train_timesteps: Length of the global sigma table.
        shift: Forwarded to ``_helios_global_sigmas``.

    Returns:
        A dict with the keys ``ori_start_sigmas``, ``start_sigmas``,
        ``end_sigmas``, ``timestep_ratios``, ``timesteps_per_stage`` and
        ``sigmas_per_stage``; each value is a dict keyed by stage index.
    """
    sigmas = _helios_global_sigmas(num_train_timesteps=num_train_timesteps, shift=shift)

    ori_start_sigmas = {}
    start_sigmas = {}
    end_sigmas = {}
    timestep_ratios = {}
    timesteps_per_stage = {}
    sigmas_per_stage = {}

    # Pass 1: start/end sigma of every stage and the sigma distance it covers.
    stage_distance = []
    for i in range(stage_count):
        # Boundary fractions -> indices into the global sigma table.
        start_indice = int(max(0.0, min(1.0, stage_range[i])) * num_train_timesteps)
        end_indice = int(max(0.0, min(1.0, stage_range[i + 1])) * num_train_timesteps)
        # The start must be a valid index; the end may equal the table length,
        # which is treated as sigma 0.0 below.
        start_indice = max(0, min(num_train_timesteps - 1, start_indice))
        end_indice = max(0, min(num_train_timesteps, end_indice))

        start_sigma = float(sigmas[start_indice].item())
        end_sigma = float(sigmas[end_indice].item()) if end_indice < num_train_timesteps else 0.0
        # Recorded before the correction applied to later stages.
        ori_start_sigmas[i] = start_sigma

        if i != 0:
            # Gamma-dependent correction of the start sigma for stages > 0.
            ori_sigma = 1.0 - start_sigma
            corrected_sigma = (1.0 / (math.sqrt(1.0 + (1.0 / gamma)) * (1.0 - ori_sigma) + ori_sigma)) * ori_sigma
            start_sigma = 1.0 - corrected_sigma

        stage_distance.append(start_sigma - end_sigma)
        start_sigmas[i] = start_sigma
        end_sigmas[i] = end_sigma

    # Pass 2: share of the total distance per stage -> timestep window.
    # A (near-)zero total falls back to 1.0 to avoid dividing by zero.
    tot_distance = sum(stage_distance) if sum(stage_distance) > 1e-12 else 1.0
    for i in range(stage_count):
        start_ratio = 0.0 if i == 0 else sum(stage_distance[:i]) / tot_distance
        # The last stage ends just below 1.0.
        end_ratio = 0.9999999999999999 if i == stage_count - 1 else sum(stage_distance[: i + 1]) / tot_distance
        timestep_ratios[i] = (start_ratio, end_ratio)

        # NOTE(review): the tmax index is not clamped the way the tmin index
        # is; start_ratio == 1.0 would index past the table -- confirm this
        # cannot happen for supported inputs.
        tmax = min(float(sigmas[int(start_ratio * num_train_timesteps)].item() * num_train_timesteps), 999.0)
        tmin = float(sigmas[min(int(end_ratio * num_train_timesteps), num_train_timesteps - 1)].item() * num_train_timesteps)
        timesteps_per_stage[i] = torch.linspace(tmax, tmin, num_train_timesteps)
        # Every stage gets the same sigma ramp.
        sigmas_per_stage[i] = torch.linspace(0.999, 0.0, num_train_timesteps)

    return {
        "ori_start_sigmas": ori_start_sigmas,
        "start_sigmas": start_sigmas,
        "end_sigmas": end_sigmas,
        "timestep_ratios": timestep_ratios,
        "timesteps_per_stage": timesteps_per_stage,
        "sigmas_per_stage": sigmas_per_stage,
    }
|
||||||
|
|
||||||
|
|
||||||
|
def _helios_stage_sigmas(stage_idx, stage_steps, stage_tables, is_distilled=False, is_amplify_first_stage=False):
|
||||||
|
stage_steps = max(1, int(stage_steps))
|
||||||
|
if is_distilled:
|
||||||
|
stage_steps = stage_steps * 2 if (is_amplify_first_stage and stage_idx == 0) else stage_steps
|
||||||
|
|
||||||
|
stage_sigma_src = stage_tables["sigmas_per_stage"][stage_idx]
|
||||||
|
sigmas = torch.linspace(float(stage_sigma_src[0].item()), float(stage_sigma_src[-1].item()), stage_steps + 1)
|
||||||
|
return sigmas
|
||||||
|
|
||||||
|
|
||||||
|
def _helios_stage_timesteps(stage_idx, stage_steps, stage_tables, is_distilled=False, is_amplify_first_stage=False):
|
||||||
|
stage_steps = max(1, int(stage_steps))
|
||||||
|
if is_distilled:
|
||||||
|
stage_steps = stage_steps * 2 if (is_amplify_first_stage and stage_idx == 0) else stage_steps
|
||||||
|
|
||||||
|
stage_timestep_src = stage_tables["timesteps_per_stage"][stage_idx]
|
||||||
|
timesteps = torch.linspace(float(stage_timestep_src[0].item()), float(stage_timestep_src[-1].item()), stage_steps)
|
||||||
|
return timesteps
|
||||||
|
|
||||||
|
|
||||||
|
def _calculate_shift(image_seq_len, base_seq_len=256, max_seq_len=4096, base_shift=0.5, max_shift=1.15):
|
||||||
|
m = (max_shift - base_shift) / float(max_seq_len - base_seq_len)
|
||||||
|
b = base_shift - m * float(base_seq_len)
|
||||||
|
return float(image_seq_len) * m + b
|
||||||
|
|
||||||
|
|
||||||
|
def _time_shift_linear(mu, sigma, t):
|
||||||
|
return mu / (mu + (1.0 / t - 1.0) ** sigma)
|
||||||
|
|
||||||
|
|
||||||
|
def _time_shift_exponential(mu, sigma, t):
|
||||||
|
return math.exp(mu) / (math.exp(mu) + (1.0 / t - 1.0) ** sigma)
|
||||||
|
|
||||||
|
|
||||||
|
def _time_shift(t, mu, sigma=1.0, mode="exponential"):
    """Apply the linear or exponential time shift to ``t``.

    ``t`` is clamped to ``[1e-6, 0.999999]`` first, so the ``1/t - 1`` term in
    the shift formulas stays finite and positive. Any ``mode`` other than
    ``"linear"`` selects the exponential variant.
    """
    clamped = torch.clamp(t, min=1e-6, max=0.999999)
    shift_fn = _time_shift_linear if mode == "linear" else _time_shift_exponential
    return shift_fn(mu, sigma, clamped)
|
||||||
|
|
||||||
|
|
||||||
|
def _optimized_scale(positive_flat, negative_flat):
|
||||||
|
dot_product = torch.sum(positive_flat * negative_flat, dim=1, keepdim=True)
|
||||||
|
squared_norm = torch.sum(negative_flat * negative_flat, dim=1, keepdim=True) + 1e-8
|
||||||
|
return dot_product / squared_norm
|
||||||
|
|
||||||
|
|
||||||
|
def _build_cfg_zero_star_pre_cfg(stage_idx, zero_steps, use_zero_init):
|
||||||
|
state = {"i": 0}
|
||||||
|
|
||||||
|
def pre_cfg_fn(args):
|
||||||
|
conds_out = args["conds_out"]
|
||||||
|
if len(conds_out) < 2 or conds_out[1] is None:
|
||||||
|
state["i"] += 1
|
||||||
|
return conds_out
|
||||||
|
|
||||||
|
noise_pred_text = conds_out[0]
|
||||||
|
noise_uncond = conds_out[1]
|
||||||
|
cfg = float(args.get("cond_scale", 1.0))
|
||||||
|
|
||||||
|
positive_flat = noise_pred_text.view(noise_pred_text.shape[0], -1)
|
||||||
|
negative_flat = noise_uncond.view(noise_uncond.shape[0], -1)
|
||||||
|
alpha = _optimized_scale(positive_flat, negative_flat)
|
||||||
|
alpha = alpha.view(noise_pred_text.shape[0], *([1] * (noise_pred_text.ndim - 1))).to(noise_pred_text.dtype)
|
||||||
|
|
||||||
|
if stage_idx == 0 and state["i"] <= int(zero_steps) and bool(use_zero_init):
|
||||||
|
final = noise_pred_text * 0.0
|
||||||
|
else:
|
||||||
|
final = noise_uncond * alpha + cfg * (noise_pred_text - noise_uncond * alpha)
|
||||||
|
|
||||||
|
state["i"] += 1
|
||||||
|
# Return identical cond/uncond so downstream cfg_function keeps `final` unchanged.
|
||||||
|
return [final, final]
|
||||||
|
|
||||||
|
return pre_cfg_fn
|
||||||
|
|
||||||
|
|
||||||
|
def _helios_euler_sample(model, x, sigmas, extra_args=None, callback=None, disable=None):
|
||||||
|
extra_args = {} if extra_args is None else extra_args
|
||||||
|
s_in = x.new_ones([x.shape[0]])
|
||||||
|
|
||||||
|
for i in range(len(sigmas) - 1):
|
||||||
|
sigma = sigmas[i]
|
||||||
|
sigma_next = sigmas[i + 1]
|
||||||
|
denoised = model(x, sigma * s_in, **extra_args)
|
||||||
|
|
||||||
|
sigma_safe = sigma if float(sigma) > 1e-8 else sigma.new_tensor(1e-8)
|
||||||
|
flow_pred = (x - denoised) / sigma_safe
|
||||||
|
|
||||||
|
if callback is not None:
|
||||||
|
callback({"x": x, "i": i, "sigma": sigma, "sigma_hat": sigma, "denoised": denoised})
|
||||||
|
|
||||||
|
x = x + (sigma_next - sigma) * flow_pred
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def _helios_dmd_sample(
|
||||||
|
model,
|
||||||
|
x,
|
||||||
|
sigmas,
|
||||||
|
extra_args=None,
|
||||||
|
callback=None,
|
||||||
|
disable=None,
|
||||||
|
dmd_noisy_tensor=None,
|
||||||
|
dmd_sigmas=None,
|
||||||
|
dmd_timesteps=None,
|
||||||
|
all_timesteps=None,
|
||||||
|
):
|
||||||
|
extra_args = {} if extra_args is None else extra_args
|
||||||
|
s_in = x.new_ones([x.shape[0]])
|
||||||
|
if dmd_noisy_tensor is None:
|
||||||
|
dmd_noisy_tensor = x
|
||||||
|
dmd_noisy_tensor = dmd_noisy_tensor.to(device=x.device, dtype=x.dtype)
|
||||||
|
if dmd_sigmas is None:
|
||||||
|
dmd_sigmas = sigmas
|
||||||
|
if dmd_timesteps is None:
|
||||||
|
dmd_timesteps = torch.arange(len(sigmas) - 1, device=sigmas.device, dtype=sigmas.dtype)
|
||||||
|
if all_timesteps is None:
|
||||||
|
all_timesteps = dmd_timesteps
|
||||||
|
|
||||||
|
def timestep_to_sigma(t):
|
||||||
|
dt = dmd_timesteps.to(device=x.device, dtype=x.dtype)
|
||||||
|
ds = dmd_sigmas.to(device=x.device, dtype=x.dtype)
|
||||||
|
tid = torch.argmin(torch.abs(dt - t))
|
||||||
|
tid = torch.clamp(tid, min=0, max=ds.shape[0] - 1)
|
||||||
|
return ds[tid]
|
||||||
|
|
||||||
|
for i in range(len(sigmas) - 1):
|
||||||
|
sigma = sigmas[i]
|
||||||
|
sigma_next = sigmas[i + 1]
|
||||||
|
timestep = all_timesteps[i] if i < len(all_timesteps) else i
|
||||||
|
denoised = model(x, sigma * s_in, **extra_args)
|
||||||
|
|
||||||
|
if callback is not None:
|
||||||
|
callback({"x": x, "i": i, "sigma": sigma, "sigma_hat": sigma, "denoised": denoised})
|
||||||
|
|
||||||
|
if i < (len(sigmas) - 2):
|
||||||
|
timestep_next = all_timesteps[i + 1] if i + 1 < len(all_timesteps) else (i + 1)
|
||||||
|
sigma_t = timestep_to_sigma(torch.as_tensor(timestep, device=x.device, dtype=x.dtype))
|
||||||
|
sigma_next_t = timestep_to_sigma(torch.as_tensor(timestep_next, device=x.device, dtype=x.dtype))
|
||||||
|
x0_pred = x - sigma_t * ((x - denoised) / torch.clamp(sigma_t, min=1e-8))
|
||||||
|
x = (1.0 - sigma_next_t) * x0_pred + sigma_next_t * dmd_noisy_tensor
|
||||||
|
else:
|
||||||
|
x = denoised
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def _set_helios_history_values(positive, negative, history_latent, history_sizes, keep_first_frame, prefix_latent=None):
|
||||||
|
latent = history_latent
|
||||||
|
if latent is None or len(latent.shape) != 5:
|
||||||
|
return positive, negative
|
||||||
|
|
||||||
|
sizes = list(history_sizes)
|
||||||
|
if len(sizes) != 3:
|
||||||
|
sizes = [16, 2, 1]
|
||||||
|
sizes = [max(0, int(v)) for v in sizes]
|
||||||
|
total = sum(sizes)
|
||||||
|
if total <= 0:
|
||||||
|
return positive, negative
|
||||||
|
|
||||||
|
if latent.shape[2] < total:
|
||||||
|
pad = torch.zeros(
|
||||||
|
latent.shape[0],
|
||||||
|
latent.shape[1],
|
||||||
|
total - latent.shape[2],
|
||||||
|
latent.shape[3],
|
||||||
|
latent.shape[4],
|
||||||
|
device=latent.device,
|
||||||
|
dtype=latent.dtype,
|
||||||
|
)
|
||||||
|
hist = torch.cat([pad, latent], dim=2)
|
||||||
|
else:
|
||||||
|
hist = latent[:, :, -total:]
|
||||||
|
|
||||||
|
latents_history_long, latents_history_mid, latents_history_short_base = hist.split(sizes, dim=2)
|
||||||
|
|
||||||
|
if keep_first_frame:
|
||||||
|
if prefix_latent is not None:
|
||||||
|
prefix = prefix_latent
|
||||||
|
elif latent.shape[2] > 0:
|
||||||
|
prefix = latent[:, :, :1]
|
||||||
|
else:
|
||||||
|
prefix = torch.zeros(latent.shape[0], latent.shape[1], 1, latent.shape[3], latent.shape[4], device=latent.device, dtype=latent.dtype)
|
||||||
|
latents_history_short = torch.cat([prefix, latents_history_short_base], dim=2)
|
||||||
|
else:
|
||||||
|
latents_history_short = latents_history_short_base
|
||||||
|
|
||||||
|
idx_short = torch.arange(latents_history_short.shape[2], device=latent.device, dtype=latent.dtype).unsqueeze(0).expand(latent.shape[0], -1)
|
||||||
|
idx_mid = torch.arange(latents_history_mid.shape[2], device=latent.device, dtype=latent.dtype).unsqueeze(0).expand(latent.shape[0], -1)
|
||||||
|
idx_long = torch.arange(latents_history_long.shape[2], device=latent.device, dtype=latent.dtype).unsqueeze(0).expand(latent.shape[0], -1)
|
||||||
|
|
||||||
|
values = {
|
||||||
|
"latents_history_short": latents_history_short,
|
||||||
|
"latents_history_mid": latents_history_mid,
|
||||||
|
"latents_history_long": latents_history_long,
|
||||||
|
"indices_latents_history_short": idx_short,
|
||||||
|
"indices_latents_history_mid": idx_mid,
|
||||||
|
"indices_latents_history_long": idx_long,
|
||||||
|
}
|
||||||
|
|
||||||
|
positive = node_helpers.conditioning_set_values(positive, values)
|
||||||
|
negative = node_helpers.conditioning_set_values(negative, values)
|
||||||
|
return positive, negative
|
||||||
|
|
||||||
|
|
||||||
|
def _build_helios_indices(batch, history_sizes, keep_first_frame, hidden_frames, device, dtype):
|
||||||
|
sizes = list(history_sizes)
|
||||||
|
if len(sizes) != 3:
|
||||||
|
sizes = [16, 2, 1]
|
||||||
|
sizes = [max(0, int(v)) for v in sizes]
|
||||||
|
long_size, mid_size, short_base_size = sizes
|
||||||
|
|
||||||
|
if keep_first_frame:
|
||||||
|
total = 1 + long_size + mid_size + short_base_size + hidden_frames
|
||||||
|
indices = torch.arange(total, device=device, dtype=dtype)
|
||||||
|
splits = [1, long_size, mid_size, short_base_size, hidden_frames]
|
||||||
|
indices_prefix, idx_long, idx_mid, idx_1x, idx_hidden = torch.split(indices, splits, dim=0)
|
||||||
|
idx_short = torch.cat([indices_prefix, idx_1x], dim=0)
|
||||||
|
else:
|
||||||
|
total = long_size + mid_size + short_base_size + hidden_frames
|
||||||
|
indices = torch.arange(total, device=device, dtype=dtype)
|
||||||
|
splits = [long_size, mid_size, short_base_size, hidden_frames]
|
||||||
|
idx_long, idx_mid, idx_short, idx_hidden = torch.split(indices, splits, dim=0)
|
||||||
|
|
||||||
|
idx_hidden = idx_hidden.unsqueeze(0).expand(batch, -1)
|
||||||
|
idx_short = idx_short.unsqueeze(0).expand(batch, -1)
|
||||||
|
idx_mid = idx_mid.unsqueeze(0).expand(batch, -1)
|
||||||
|
idx_long = idx_long.unsqueeze(0).expand(batch, -1)
|
||||||
|
return idx_hidden, idx_short, idx_mid, idx_long
|
||||||
|
|
||||||
|
|
||||||
|
class HeliosImageToVideo(io.ComfyNode):
    """Build an empty Helios video latent and history conditioning, optionally seeded by a start image."""

    @classmethod
    def define_schema(cls):
        """Declare the node id, category, inputs and outputs."""
        return io.Schema(
            node_id="HeliosImageToVideo",
            category="conditioning/video_models",
            inputs=[
                io.Conditioning.Input("positive"),
                io.Conditioning.Input("negative"),
                io.Vae.Input("vae"),
                io.Int.Input("width", default=640, min=16, max=nodes.MAX_RESOLUTION, step=16),
                io.Int.Input("height", default=384, min=16, max=nodes.MAX_RESOLUTION, step=16),
                io.Int.Input("length", default=132, min=1, max=nodes.MAX_RESOLUTION, step=4),
                io.Int.Input("batch_size", default=1, min=1, max=4096),
                io.Image.Input("start_image", optional=True),
                io.String.Input("history_sizes", default="16,2,1", advanced=True),
                io.Boolean.Input("keep_first_frame", default=True, advanced=True),
                io.Int.Input("num_latent_frames_per_chunk", default=9, min=1, max=256, advanced=True),
                io.Boolean.Input("add_noise_to_image_latents", default=True, advanced=True),
                io.Float.Input("image_noise_sigma_min", default=0.111, min=0.0, max=1.0, step=0.0001, round=False, advanced=True),
                io.Float.Input("image_noise_sigma_max", default=0.135, min=0.0, max=1.0, step=0.0001, round=False, advanced=True),
                io.Int.Input("noise_seed", default=0, min=0, max=0xFFFFFFFFFFFFFFFF, advanced=True),
            ],
            outputs=[
                io.Conditioning.Output(display_name="positive"),
                io.Conditioning.Output(display_name="negative"),
                io.Latent.Output(display_name="latent"),
            ],
        )

    @classmethod
    def execute(
        cls,
        positive,
        negative,
        vae,
        width,
        height,
        length,
        batch_size,
        start_image=None,
        history_sizes="16,2,1",
        keep_first_frame=True,
        num_latent_frames_per_chunk=9,
        add_noise_to_image_latents=True,
        image_noise_sigma_min=0.111,
        image_noise_sigma_max=0.135,
        noise_seed=0,
    ) -> io.NodeOutput:
        """Create the zero latent, the history latent and the conditioning values.

        A zero latent with ``((length - 1) // 4) + 1`` frames is allocated at
        ``height // scale`` x ``width // scale`` (scale comes from the VAE), and
        a zero history latent of ``max(1, sum(history_sizes))`` frames.  When
        ``start_image`` is given, only its first image is used: it is resized,
        encoded, and its first latent frame becomes the image prefix (optionally
        blended with seeded Gaussian noise).  The image is also repeated into a
        still clip, encoded, and the last latent frame of that encoding is
        written into the last history slot.

        Returns:
            ``io.NodeOutput(positive, negative, latent_dict)`` where the latent
            dict holds ``samples``, ``helios_history_latent`` and
            ``helios_image_latent_prefix`` (``None`` without a start image).
        """
        spacial_scale = vae.spacial_compression_encode()
        latent_channels = vae.latent_channels
        latent_t = ((length - 1) // 4) + 1
        latent = torch.zeros(
            [batch_size, latent_channels, latent_t, height // spacial_scale, width // spacial_scale],
            device=comfy.model_management.intermediate_device(),
        )

        # Exactly three history sizes are required; anything else falls back to the default.
        sizes = _parse_int_list(history_sizes, [16, 2, 1])
        if len(sizes) != 3:
            sizes = [16, 2, 1]
        sizes = sorted([max(0, int(v)) for v in sizes], reverse=True)
        hist_len = max(1, sum(sizes))
        history_latent = torch.zeros(
            [batch_size, latent_channels, hist_len, latent.shape[-2], latent.shape[-1]],
            device=latent.device,
            dtype=latent.dtype,
        )
        image_latent_prefix = None

        if start_image is not None:
            image = comfy.utils.common_upscale(start_image[:1].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
            # Drop any alpha channel once so both encodes below see the same RGB input.
            rgb_image = image[:, :, :, :3]
            img_latent = vae.encode(rgb_image)
            img_latent = comfy.utils.repeat_to_batch_size(img_latent, batch_size)
            image_latent_prefix = img_latent[:, :, :1]

            if add_noise_to_image_latents:
                g = torch.Generator(device=img_latent.device)
                g.manual_seed(int(noise_seed))
                sigma_min = float(image_noise_sigma_min)
                sigma_max = float(image_noise_sigma_max)
                # One sigma per batch element, uniformly drawn from [sigma_min, sigma_max).
                sigma = (
                    torch.rand((img_latent.shape[0], 1, 1, 1, 1), device=img_latent.device, generator=g, dtype=img_latent.dtype)
                    * (sigma_max - sigma_min)
                    + sigma_min
                )
                # torch.randn accepts `generator` on every supported PyTorch release,
                # whereas torch.randn_like only gained it recently.
                prefix_noise = torch.randn(
                    image_latent_prefix.shape,
                    generator=g,
                    device=image_latent_prefix.device,
                    dtype=image_latent_prefix.dtype,
                )
                image_latent_prefix = sigma * prefix_noise + (1.0 - sigma) * image_latent_prefix

            # Encode a still clip long enough to produce one chunk of latent frames
            # and keep only its last latent frame as the most recent history entry.
            min_frames = max(1, (int(num_latent_frames_per_chunk) - 1) * 4 + 1)
            fake_video = rgb_image.repeat(min_frames, 1, 1, 1)
            fake_latents_full = vae.encode(fake_video)
            fake_latent = comfy.utils.repeat_to_batch_size(fake_latents_full[:, :, -1:], batch_size)
            history_latent[:, :, -1:] = fake_latent

        positive, negative = _set_helios_history_values(positive, negative, history_latent, sizes, keep_first_frame, prefix_latent=image_latent_prefix)
        return io.NodeOutput(
            positive,
            negative,
            {
                "samples": latent,
                "helios_history_latent": history_latent,
                "helios_image_latent_prefix": image_latent_prefix,
            },
        )
|
||||||
|
|
||||||
|
|
||||||
|
class HeliosVideoToVideo(io.ComfyNode):
    """Build an empty Helios video latent and history conditioning, optionally seeded by an input video."""

    @classmethod
    def define_schema(cls):
        """Declare the node id, category, inputs and outputs."""
        return io.Schema(
            node_id="HeliosVideoToVideo",
            category="conditioning/video_models",
            inputs=[
                io.Conditioning.Input("positive"),
                io.Conditioning.Input("negative"),
                io.Vae.Input("vae"),
                io.Int.Input("width", default=640, min=16, max=nodes.MAX_RESOLUTION, step=16),
                io.Int.Input("height", default=384, min=16, max=nodes.MAX_RESOLUTION, step=16),
                io.Int.Input("length", default=132, min=1, max=nodes.MAX_RESOLUTION, step=4),
                io.Int.Input("batch_size", default=1, min=1, max=4096),
                io.Image.Input("video", optional=True),
                io.String.Input("history_sizes", default="16,2,1", advanced=True),
                io.Boolean.Input("keep_first_frame", default=True, advanced=True),
                io.Boolean.Input("add_noise_to_video_latents", default=True, advanced=True),
                io.Float.Input("video_noise_sigma_min", default=0.111, min=0.0, max=1.0, step=0.0001, round=False, advanced=True),
                io.Float.Input("video_noise_sigma_max", default=0.135, min=0.0, max=1.0, step=0.0001, round=False, advanced=True),
                io.Int.Input("noise_seed", default=0, min=0, max=0xFFFFFFFFFFFFFFFF, advanced=True),
            ],
            outputs=[
                io.Conditioning.Output(display_name="positive"),
                io.Conditioning.Output(display_name="negative"),
                io.Latent.Output(display_name="latent"),
            ],
        )

    @classmethod
    def execute(
        cls,
        positive,
        negative,
        vae,
        width,
        height,
        length,
        batch_size,
        video=None,
        history_sizes="16,2,1",
        keep_first_frame=True,
        add_noise_to_video_latents=True,
        video_noise_sigma_min=0.111,
        video_noise_sigma_max=0.135,
        noise_seed=0,
    ) -> io.NodeOutput:
        """Create the zero latent, the history latent and the conditioning values.

        Without ``video`` the history latent is all zeros and the prefix is
        ``None``.  With ``video``, its first ``length`` frames are resized and
        encoded, optionally blended with seeded Gaussian noise using one sigma
        per latent frame, then truncated or padded (by repeating the last latent
        frame) to ``max(1, sum(history_sizes))`` frames.  That tensor becomes the
        history latent and its first frame becomes the prefix.

        Returns:
            ``io.NodeOutput(positive, negative, latent_dict)`` where the latent
            dict holds ``samples``, ``helios_history_latent`` and
            ``helios_image_latent_prefix``.
        """
        spacial_scale = vae.spacial_compression_encode()
        latent_channels = vae.latent_channels
        latent_t = ((length - 1) // 4) + 1
        latent = torch.zeros(
            [batch_size, latent_channels, latent_t, height // spacial_scale, width // spacial_scale],
            device=comfy.model_management.intermediate_device(),
        )

        # Exactly three history sizes are required; anything else falls back to the default.
        sizes = _parse_int_list(history_sizes, [16, 2, 1])
        if len(sizes) != 3:
            sizes = [16, 2, 1]
        sizes = sorted([max(0, int(v)) for v in sizes], reverse=True)
        hist_len = max(1, sum(sizes))
        history_latent = torch.zeros(
            [batch_size, latent_channels, hist_len, latent.shape[-2], latent.shape[-1]],
            device=latent.device,
            dtype=latent.dtype,
        )
        image_latent_prefix = None

        if video is not None:
            video = comfy.utils.common_upscale(video[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
            vid_latent = vae.encode(video[:, :, :, :3])
            if add_noise_to_video_latents:
                g = torch.Generator(device=vid_latent.device)
                g.manual_seed(int(noise_seed))
                sigma_min = float(video_noise_sigma_min)
                sigma_max = float(video_noise_sigma_max)
                # One sigma per latent frame, shared across batch, channels and space.
                frame_sigmas = (
                    torch.rand((1, 1, vid_latent.shape[2], 1, 1), device=vid_latent.device, generator=g, dtype=vid_latent.dtype)
                    * (sigma_max - sigma_min)
                    + sigma_min
                )
                # torch.randn accepts `generator` on every supported PyTorch release,
                # whereas torch.randn_like only gained it recently.
                video_noise = torch.randn(
                    vid_latent.shape,
                    generator=g,
                    device=vid_latent.device,
                    dtype=vid_latent.dtype,
                )
                vid_latent = frame_sigmas * video_noise + (1.0 - frame_sigmas) * vid_latent
            # NOTE(review): the truncate/pad steps are reconstructed as running
            # regardless of the noise toggle — confirm against the upstream file.
            vid_latent = vid_latent[:, :, :hist_len]
            if vid_latent.shape[2] < hist_len:
                pad = vid_latent[:, :, -1:].repeat(1, 1, hist_len - vid_latent.shape[2], 1, 1)
                vid_latent = torch.cat([vid_latent, pad], dim=2)
            vid_latent = comfy.utils.repeat_to_batch_size(vid_latent, batch_size)
            history_latent = vid_latent
            image_latent_prefix = history_latent[:, :, :1]

        positive, negative = _set_helios_history_values(positive, negative, history_latent, sizes, keep_first_frame, prefix_latent=image_latent_prefix)
        return io.NodeOutput(
            positive,
            negative,
            {
                "samples": latent,
                "helios_history_latent": history_latent,
                "helios_image_latent_prefix": image_latent_prefix,
            },
        )
|
||||||
|
|
||||||
|
|
||||||
|
class HeliosHistoryConditioning(io.ComfyNode):
    """Attach Helios history values taken from a latent to positive/negative conditioning."""

    @classmethod
    def define_schema(cls):
        """Declare the node id, category, inputs and outputs."""
        return io.Schema(
            node_id="HeliosHistoryConditioning",
            category="conditioning/video_models",
            inputs=[
                io.Conditioning.Input("positive"),
                io.Conditioning.Input("negative"),
                io.Latent.Input("history_latent"),
                io.String.Input("history_sizes", default="16,2,1"),
                io.Boolean.Input("keep_first_frame", default=True),
            ],
            outputs=[
                io.Conditioning.Output(display_name="positive"),
                io.Conditioning.Output(display_name="negative"),
            ],
        )

    @classmethod
    def execute(cls, positive, negative, history_latent, history_sizes, keep_first_frame) -> io.NodeOutput:
        """Return conditioning updated with history derived from ``history_latent``.

        The conditioning is returned unchanged when the latent dict has no
        ``samples`` tensor or that tensor is not 5-dimensional.  Otherwise the
        samples, the parsed history sizes and the optional
        ``helios_image_latent_prefix`` entry are passed to
        ``_set_helios_history_values``.
        """
        # .get keeps the None guard reachable when "samples" is missing.
        latent = history_latent.get("samples", None)
        if latent is None or len(latent.shape) != 5:
            return io.NodeOutput(positive, negative)

        # Same three-entry fallback as the image/video conditioning nodes.
        sizes = _parse_int_list(history_sizes, [16, 2, 1])
        if len(sizes) != 3:
            sizes = [16, 2, 1]
        sizes = sorted([max(0, int(v)) for v in sizes], reverse=True)

        prefix = history_latent.get("helios_image_latent_prefix", None)
        positive, negative = _set_helios_history_values(positive, negative, latent, sizes, keep_first_frame, prefix_latent=prefix)
        return io.NodeOutput(positive, negative)
|
||||||
|
|
||||||
|
|
||||||
|
class HeliosPyramidSampler(io.ComfyNode):
    """Sample a Helios video latent chunk by chunk, each chunk through a coarse-to-fine stage pyramid."""

    @classmethod
    def define_schema(cls):
        """Declare the node id, category, inputs and outputs."""
        return io.Schema(
            node_id="HeliosPyramidSampler",
            category="sampling/video_models",
            inputs=[
                io.Model.Input("model"),
                io.Boolean.Input("add_noise", default=True, advanced=True),
                io.Int.Input("noise_seed", default=0, min=0, max=0xFFFFFFFFFFFFFFFF, control_after_generate=True),
                io.Float.Input("cfg", default=5.0, min=0.0, max=100.0, step=0.1, round=0.01),
                io.Conditioning.Input("positive"),
                io.Conditioning.Input("negative"),
                io.Latent.Input("latent_image"),
                io.String.Input("pyramid_steps", default="10,10,10"),
                io.String.Input("stage_range", default="0,0.333333,0.666667,1"),
                io.Boolean.Input("is_distilled", default=False),
                io.Boolean.Input("is_amplify_first_stage", default=False),
                io.Combo.Input("scheduler_mode", options=["euler", "unipc_bh2"]),
                io.Float.Input("gamma", default=1.0 / 3.0, min=0.0001, max=10.0, step=0.0001, round=False),
                io.Float.Input("shift", default=1.0, min=0.001, max=100.0, step=0.001, round=False, advanced=True),
                io.Boolean.Input("use_dynamic_shifting", default=False, advanced=True),
                io.Combo.Input("time_shift_type", options=["exponential", "linear"], advanced=True),
                io.Int.Input("base_image_seq_len", default=256, min=1, max=65536, advanced=True),
                io.Int.Input("max_image_seq_len", default=4096, min=1, max=65536, advanced=True),
                io.Float.Input("base_shift", default=0.5, min=0.0, max=10.0, step=0.0001, round=False, advanced=True),
                io.Float.Input("max_shift", default=1.15, min=0.0, max=10.0, step=0.0001, round=False, advanced=True),
                io.Int.Input("num_train_timesteps", default=1000, min=10, max=100000, advanced=True),
                io.String.Input("history_sizes", default="16,2,1", advanced=True),
                io.Boolean.Input("keep_first_frame", default=True, advanced=True),
                io.Int.Input("num_latent_frames_per_chunk", default=9, min=1, max=256, advanced=True),
                io.Boolean.Input("is_cfg_zero_star", default=False, advanced=True),
                io.Boolean.Input("use_zero_init", default=True, advanced=True),
                io.Int.Input("zero_steps", default=1, min=0, max=10000, advanced=True),
                io.Boolean.Input("is_skip_first_chunk", default=False, advanced=True),
            ],
            outputs=[
                io.Latent.Output(display_name="output"),
                io.Latent.Output(display_name="denoised_output"),
            ],
        )

    @classmethod
    def execute(
        cls,
        model,
        add_noise,
        noise_seed,
        cfg,
        positive,
        negative,
        latent_image,
        pyramid_steps,
        stage_range,
        is_distilled,
        is_amplify_first_stage,
        scheduler_mode,
        gamma,
        shift,
        use_dynamic_shifting,
        time_shift_type,
        base_image_seq_len,
        max_image_seq_len,
        base_shift,
        max_shift,
        num_train_timesteps,
        history_sizes,
        keep_first_frame,
        num_latent_frames_per_chunk,
        is_cfg_zero_star,
        use_zero_init,
        zero_steps,
        is_skip_first_chunk,
    ) -> io.NodeOutput:
        """Run chunked, multi-stage sampling and return the result latents.

        The target latent is split along time into chunks of
        ``num_latent_frames_per_chunk`` frames.  Each chunk starts at a spatial
        resolution reduced by ``2 ** (stage_count - 1)`` and is sampled once per
        entry of ``pyramid_steps``, being upsampled by 2 between stages.  After a
        chunk finishes it is appended to a rolling history that is trimmed to
        ``max(1, sum(history_sizes))`` frames and fed back as conditioning for
        the next chunk.

        Returns:
            ``io.NodeOutput(out, out_denoised)``.  ``out["samples"]`` is the
            concatenation of all chunks cut to the input frame count.
            ``out_denoised`` carries the processed ``x0`` captured by the preview
            callback when one exists, otherwise it is ``out`` itself.
        """
        # torch.manual_seed only accepts values up to 2**64 - 1, so derived seeds
        # are wrapped back into that range.
        seed_mask = 0xFFFFFFFFFFFFFFFF

        latent = latent_image.copy()
        latent_samples = comfy.sample.fix_empty_latent_channels(model, latent["samples"], latent.get("downscale_ratio_spacial", None))

        stage_steps = _parse_int_list(pyramid_steps, [10, 10, 10])
        stage_steps = [max(1, int(s)) for s in stage_steps]
        stage_count = len(stage_steps)
        history_sizes_list = sorted([max(0, int(v)) for v in _parse_int_list(history_sizes, [16, 2, 1])], reverse=True)

        # The stage range needs one boundary more than there are stages; fall
        # back to an even split when the user-supplied list does not fit.
        stage_range_values = _parse_float_list(stage_range, [0.0, 1.0 / 3.0, 2.0 / 3.0, 1.0])
        if len(stage_range_values) != stage_count + 1:
            stage_range_values = [float(i) / float(stage_count) for i in range(stage_count + 1)]

        stage_tables = _helios_stage_tables(
            stage_count=stage_count,
            stage_range=stage_range_values,
            gamma=float(gamma),
            num_train_timesteps=int(num_train_timesteps),
            shift=float(shift),
        )

        b, c, t, h, w = latent_samples.shape
        chunk_t = max(1, int(num_latent_frames_per_chunk))
        chunk_count = max(1, (t + chunk_t - 1) // chunk_t)
        low_scale = 2 ** max(0, stage_count - 1)
        low_h = max(1, h // low_scale)
        low_w = max(1, w // low_scale)

        # Shape/dtype template for the lowest-resolution stage of every chunk.
        base_latent = torch.zeros((b, c, chunk_t, low_h, low_w), dtype=latent_samples.dtype, layout=latent_samples.layout, device=latent_samples.device)

        euler_sampler = comfy.samplers.KSAMPLER(_helios_euler_sample)

        latents_history_short = _extract_condition_value(positive, "latents_history_short")
        latents_history_mid = _extract_condition_value(positive, "latents_history_mid")
        latents_history_long = _extract_condition_value(positive, "latents_history_long")
        image_latent_prefix = latent.get("helios_image_latent_prefix", None)
        if latents_history_short is None and "helios_history_latent" in latent:
            # Conditioning carries no history yet: derive it from the latent dict.
            positive, negative = _set_helios_history_values(
                positive,
                negative,
                latent["helios_history_latent"],
                history_sizes_list,
                keep_first_frame,
                prefix_latent=image_latent_prefix,
            )
            latents_history_short = _extract_condition_value(positive, "latents_history_short")
            latents_history_mid = _extract_condition_value(positive, "latents_history_mid")
            latents_history_long = _extract_condition_value(positive, "latents_history_long")

        x0_output = {}
        generated_chunks = []
        if latents_history_short is not None and latents_history_mid is not None and latents_history_long is not None:
            rolling_history = torch.cat([latents_history_long, latents_history_mid, latents_history_short], dim=2)
        elif "helios_history_latent" in latent:
            rolling_history = latent["helios_history_latent"]
        else:
            hist_len = max(1, sum(history_sizes_list))
            rolling_history = torch.zeros((b, c, hist_len, h, w), device=latent_samples.device, dtype=latent_samples.dtype)

        for chunk_idx in range(chunk_count):
            if add_noise:
                chunk_seed = (noise_seed + chunk_idx) & seed_mask
                stage_latent = comfy.sample.prepare_noise(base_latent, chunk_seed).to(base_latent.dtype).to(comfy.model_management.intermediate_device())
            else:
                stage_latent = torch.zeros_like(base_latent, device=comfy.model_management.intermediate_device())

            positive_chunk, negative_chunk = _set_helios_history_values(
                positive,
                negative,
                rolling_history,
                history_sizes_list,
                keep_first_frame,
                prefix_latent=image_latent_prefix,
            )
            latents_history_short = _extract_condition_value(positive_chunk, "latents_history_short")
            latents_history_mid = _extract_condition_value(positive_chunk, "latents_history_mid")
            latents_history_long = _extract_condition_value(positive_chunk, "latents_history_long")

            # Amplification only ever applies to the first chunk.
            amplify_stage = is_amplify_first_stage and chunk_idx == 0

            for stage_idx in range(stage_count):
                stage_seed = (noise_seed + chunk_idx * 100 + stage_idx) & seed_mask

                if stage_idx > 0:
                    stage_latent = _upsample_latent_5d(stage_latent, scale=2)

                    # NOTE(review): the re-noising below is reconstructed as part of
                    # the stage_idx > 0 branch — confirm against the upstream file.
                    ori_sigma = 1.0 - float(stage_tables["ori_start_sigmas"][stage_idx])
                    alpha = 1.0 / (math.sqrt(1.0 + (1.0 / gamma)) * (1.0 - ori_sigma) + ori_sigma)
                    beta = alpha * (1.0 - ori_sigma) / math.sqrt(gamma)

                    noise = _sample_block_noise_like(stage_latent, gamma, patch_size=(1, 2, 2)).to(stage_latent)
                    stage_latent = alpha * stage_latent + beta * noise

                sigmas = _helios_stage_sigmas(
                    stage_idx=stage_idx,
                    stage_steps=stage_steps[stage_idx],
                    stage_tables=stage_tables,
                    is_distilled=is_distilled,
                    is_amplify_first_stage=amplify_stage,
                ).to(stage_latent.dtype)
                timesteps = _helios_stage_timesteps(
                    stage_idx=stage_idx,
                    stage_steps=stage_steps[stage_idx],
                    stage_tables=stage_tables,
                    is_distilled=is_distilled,
                    is_amplify_first_stage=amplify_stage,
                ).to(stage_latent.dtype)
                if use_dynamic_shifting:
                    patch_size = (1, 2, 2)
                    image_seq_len = (stage_latent.shape[-1] * stage_latent.shape[-2] * stage_latent.shape[-3]) // (patch_size[0] * patch_size[1] * patch_size[2])
                    mu = _calculate_shift(
                        image_seq_len=image_seq_len,
                        base_seq_len=base_image_seq_len,
                        max_seq_len=max_image_seq_len,
                        base_shift=base_shift,
                        max_shift=max_shift,
                    )
                    sigmas = _time_shift(sigmas, mu=mu, sigma=1.0, mode=time_shift_type).to(stage_latent.dtype)
                    # Re-map timesteps onto the shifted sigmas while keeping the original range.
                    tmin = torch.min(timesteps)
                    tmax = torch.max(timesteps)
                    timesteps = tmin + sigmas[:-1] * (tmax - tmin)

                indices_hidden_states, idx_short, idx_mid, idx_long = _build_helios_indices(
                    batch=stage_latent.shape[0],
                    history_sizes=history_sizes_list,
                    keep_first_frame=keep_first_frame,
                    hidden_frames=stage_latent.shape[2],
                    device=stage_latent.device,
                    dtype=stage_latent.dtype,
                )
                positive_stage = node_helpers.conditioning_set_values(positive_chunk, {"indices_hidden_states": indices_hidden_states})
                negative_stage = node_helpers.conditioning_set_values(negative_chunk, {"indices_hidden_states": indices_hidden_states})

                if latents_history_short is not None:
                    values = {"latents_history_short": latents_history_short, "indices_latents_history_short": idx_short}
                    positive_stage = node_helpers.conditioning_set_values(positive_stage, values)
                    negative_stage = node_helpers.conditioning_set_values(negative_stage, values)

                if latents_history_mid is not None:
                    values = {"latents_history_mid": latents_history_mid, "indices_latents_history_mid": idx_mid}
                    positive_stage = node_helpers.conditioning_set_values(positive_stage, values)
                    negative_stage = node_helpers.conditioning_set_values(negative_stage, values)

                if latents_history_long is not None:
                    values = {"latents_history_long": latents_history_long, "indices_latents_history_long": idx_long}
                    positive_stage = node_helpers.conditioning_set_values(positive_stage, values)
                    negative_stage = node_helpers.conditioning_set_values(negative_stage, values)

                cfg_use = 1.0 if is_distilled else cfg

                if stage_idx == 0 and add_noise:
                    noise = comfy.sample.prepare_noise(stage_latent, stage_seed)
                    latent_start = torch.zeros_like(stage_latent)
                else:
                    # Later stages hand the current latent to the sampler as
                    # pre-scaled "noise" on top of a zero start latent.
                    sigma0 = max(float(sigmas[0].item()), 1e-6)
                    noise = (stage_latent / sigma0).to("cpu")
                    latent_start = torch.zeros_like(stage_latent)

                stage_start_for_dmd = stage_latent.clone()

                if is_distilled:
                    sampler = comfy.samplers.KSAMPLER(
                        _helios_dmd_sample,
                        extra_options={
                            "dmd_noisy_tensor": stage_start_for_dmd,
                            "dmd_sigmas": sigmas,
                            "dmd_timesteps": timesteps,
                            "all_timesteps": timesteps,
                        },
                    )
                else:
                    if scheduler_mode == "unipc_bh2":
                        sampler = comfy.samplers.ksampler("uni_pc_bh2")
                    else:
                        sampler = euler_sampler

                callback = latent_preview.prepare_callback(model, sigmas.shape[-1] - 1, x0_output)
                stage_model = model
                if is_cfg_zero_star and not is_distilled:
                    # Patch a clone so the caller's model options stay untouched.
                    stage_model = model.clone()
                    stage_model.model_options = comfy.model_patcher.set_model_options_pre_cfg_function(
                        stage_model.model_options,
                        _build_cfg_zero_star_pre_cfg(stage_idx=stage_idx, zero_steps=zero_steps, use_zero_init=use_zero_init),
                        disable_cfg1_optimization=True,
                    )
                stage_latent = comfy.sample.sample_custom(
                    stage_model,
                    noise,
                    cfg_use,
                    sampler,
                    sigmas,
                    positive_stage,
                    negative_stage,
                    latent_start,
                    noise_mask=None,
                    callback=callback,
                    disable_pbar=not comfy.utils.PROGRESS_BAR_ENABLED,
                    seed=stage_seed,
                )

            # Bring the finished chunk to the requested resolution if the pyramid
            # did not land exactly on it (e.g. sizes not divisible by low_scale).
            if stage_latent.shape[-2] != h or stage_latent.shape[-1] != w:
                b2, c2, t2, h2, w2 = stage_latent.shape
                x = stage_latent.permute(0, 2, 1, 3, 4).reshape(b2 * t2, c2, h2, w2)
                x = comfy.utils.common_upscale(x, w, h, "nearest-exact", "disabled")
                stage_latent = x.reshape(b2, t2, c2, h, w).permute(0, 2, 1, 3, 4)
                stage_latent = stage_latent[:, :, :, :h, :w]

            generated_chunks.append(stage_latent)
            if keep_first_frame and ((chunk_idx == 0 and image_latent_prefix is None) or (is_skip_first_chunk and chunk_idx == 1)):
                image_latent_prefix = stage_latent[:, :, :1]
            rolling_history = torch.cat([rolling_history, stage_latent.to(rolling_history.device, rolling_history.dtype)], dim=2)
            keep_hist = max(1, sum(history_sizes_list))
            rolling_history = rolling_history[:, :, -keep_hist:]

        # Chunks may overshoot the requested length; cut back to the input frame count.
        stage_latent = torch.cat(generated_chunks, dim=2)[:, :, :t]

        out = latent.copy()
        out.pop("downscale_ratio_spacial", None)
        out["samples"] = stage_latent

        if "x0" in x0_output:
            x0_out = model.model.process_latent_out(x0_output["x0"].cpu())
            out_denoised = latent.copy()
            out_denoised["samples"] = x0_out
        else:
            out_denoised = out

        return io.NodeOutput(out, out_denoised)
|
||||||
|
|
||||||
|
|
||||||
|
class HeliosExtension(ComfyExtension):
    """Extension that exposes the Helios node classes defined in this module."""

    @override
    async def get_node_list(self) -> list[type[io.ComfyNode]]:
        """Return the Helios node classes provided by this extension."""
        node_classes = [
            HeliosImageToVideo,
            HeliosVideoToVideo,
            HeliosHistoryConditioning,
            HeliosPyramidSampler,
        ]
        return node_classes
|
||||||
|
|
||||||
|
|
||||||
|
async def comfy_entrypoint() -> HeliosExtension:
    """Create and return a new ``HeliosExtension`` instance."""
    extension = HeliosExtension()
    return extension
|
||||||
5
nodes.py
5
nodes.py
@ -976,7 +976,7 @@ class CLIPLoader:
|
|||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
return {"required": { "clip_name": (folder_paths.get_filename_list("text_encoders"), ),
|
return {"required": { "clip_name": (folder_paths.get_filename_list("text_encoders"), ),
|
||||||
"type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2", "qwen_image", "hunyuan_image", "flux2", "ovis", "longcat_image"], ),
|
"type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "helios", "hidream", "chroma", "ace", "omnigen2", "qwen_image", "hunyuan_image", "flux2", "ovis", "longcat_image"], ),
|
||||||
},
|
},
|
||||||
"optional": {
|
"optional": {
|
||||||
"device": (["default", "cpu"], {"advanced": True}),
|
"device": (["default", "cpu"], {"advanced": True}),
|
||||||
@ -986,7 +986,7 @@ class CLIPLoader:
|
|||||||
|
|
||||||
CATEGORY = "advanced/loaders"
|
CATEGORY = "advanced/loaders"
|
||||||
|
|
||||||
DESCRIPTION = "[Recipes]\n\nstable_diffusion: clip-l\nstable_cascade: clip-g\nsd3: t5 xxl/ clip-g / clip-l\nstable_audio: t5 base\nmochi: t5 xxl\ncosmos: old t5 xxl\nlumina2: gemma 2 2B\nwan: umt5 xxl\n hidream: llama-3.1 (Recommend) or t5\nomnigen2: qwen vl 2.5 3B"
|
DESCRIPTION = "[Recipes]\n\nstable_diffusion: clip-l\nstable_cascade: clip-g\nsd3: t5 xxl/ clip-g / clip-l\nstable_audio: t5 base\nmochi: t5 xxl\ncosmos: old t5 xxl\nlumina2: gemma 2 2B\nwan: umt5 xxl\nhelios: umt5 xxl\n hidream: llama-3.1 (Recommend) or t5\nomnigen2: qwen vl 2.5 3B"
|
||||||
|
|
||||||
def load_clip(self, clip_name, type="stable_diffusion", device="default"):
|
def load_clip(self, clip_name, type="stable_diffusion", device="default"):
|
||||||
clip_type = getattr(comfy.sd.CLIPType, type.upper(), comfy.sd.CLIPType.STABLE_DIFFUSION)
|
clip_type = getattr(comfy.sd.CLIPType, type.upper(), comfy.sd.CLIPType.STABLE_DIFFUSION)
|
||||||
@ -2412,6 +2412,7 @@ async def init_builtin_extra_nodes():
|
|||||||
"nodes_cosmos.py",
|
"nodes_cosmos.py",
|
||||||
"nodes_video.py",
|
"nodes_video.py",
|
||||||
"nodes_lumina2.py",
|
"nodes_lumina2.py",
|
||||||
|
"nodes_helios.py",
|
||||||
"nodes_wan.py",
|
"nodes_wan.py",
|
||||||
"nodes_lotus.py",
|
"nodes_lotus.py",
|
||||||
"nodes_hunyuan3d.py",
|
"nodes_hunyuan3d.py",
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user