mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-03-18 07:35:05 +08:00
Refactor Helios integration and latent processing with new T2V support.
This commit is contained in:
parent
ae36a9d4fd
commit
d93133ee53
@ -783,3 +783,11 @@ class ZImagePixelSpace(ChromaRadiance):
|
||||
No VAE encoding/decoding — the model operates directly on RGB pixels.
|
||||
"""
|
||||
pass
|
||||
|
||||
class Helios(Wan21):
    """Latent format used by the Helios video model.

    Helios shares the Wan21 latent format (same VAE architecture), so this
    subclass adds nothing of its own: ``latents_mean``, ``latents_std`` and the
    latent processing methods are all inherited from ``Wan21``. It exists so
    Helios has a dedicated latent-format type to refer to.
    """
|
||||
|
||||
@ -6,11 +6,12 @@ import torch.nn as nn
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
from comfy.ldm.flux.layers import EmbedND
|
||||
from comfy.ldm.flux.math import apply_rope1
|
||||
from comfy.ldm.wan.model import sinusoidal_embedding_1d, repeat_e
|
||||
from comfy.ldm.wan.model import sinusoidal_embedding_1d
|
||||
import comfy.ldm.common_dit
|
||||
import comfy.patcher_extension
|
||||
|
||||
|
||||
|
||||
def pad_for_3d_conv(x, kernel_size):
|
||||
b, c, t, h, w = x.shape
|
||||
pt, ph, pw = kernel_size
|
||||
@ -20,6 +21,10 @@ def pad_for_3d_conv(x, kernel_size):
|
||||
return torch.nn.functional.pad(x, (0, pad_w, 0, pad_h, 0, pad_t), mode="replicate")
|
||||
|
||||
|
||||
def center_down_sample_3d(x, kernel_size):
    """Average-pool ``x`` over non-overlapping 3D windows.

    The stride equals ``kernel_size``, so every output element is the mean of
    one ``kernel_size`` window and windows never overlap.

    Args:
        x: tensor accepted by ``torch.nn.functional.avg_pool3d``.
        kernel_size: window size, reused as the stride.

    Returns:
        The pooled tensor.
    """
    pool = torch.nn.functional.avg_pool3d
    return pool(x, kernel_size=kernel_size, stride=kernel_size)
|
||||
|
||||
|
||||
class OutputNorm(nn.Module):
|
||||
|
||||
def __init__(self, dim, eps=1e-6, operation_settings={}):
|
||||
@ -50,7 +55,8 @@ class OutputNorm(nn.Module):
|
||||
shift = shift.squeeze(2).to(hidden_states.device)
|
||||
scale = scale.squeeze(2).to(hidden_states.device)
|
||||
hidden_states = hidden_states[:, -original_context_length:, :]
|
||||
hidden_states = self.norm(hidden_states) * (1 + scale) + shift
|
||||
# Use float32 for numerical stability like diffusers
|
||||
hidden_states = (self.norm(hidden_states.float()) * (1 + scale) + shift).type_as(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
@ -272,36 +278,69 @@ class HeliosAttentionBlock(nn.Module):
|
||||
|
||||
def forward(self, x, context, e, freqs, original_context_length=None, transformer_options={}):
    """Run self-attention, cross-attention and the feed-forward network on ``x``.

    Modulation (shift / scale / gate) for the self-attention and FFN branches is
    obtained by adding ``self.scale_shift_table`` to ``e`` and splitting the sum
    into six chunks. Normalisation and the gated residual additions are done in
    float32 and cast back to the dtype of ``x``.

    Args:
        x: token tensor; dimension 1 is the sequence axis.
        context: passed through to ``self.attn2`` as ``context``.
        e: modulation input. When ``e.ndim == 4`` the six chunks are taken along
            dimension 2 and that dimension is squeezed away; otherwise they are
            taken along dimension 1.
        freqs: passed through to ``self.attn1``.
        original_context_length: when set and ``self.guidance_cross_attn`` is
            true, only the last ``original_context_length`` tokens go through
            cross-attention; the leading tokens are carried over unchanged.
        transformer_options: forwarded to both attention modules (never mutated
            here, so the shared default is not modified by this method).

    Returns:
        The updated token tensor, same dtype as ``x``.
    """

    def _fp32_norm2(tokens):
        # norm2 carries affine parameters; apply it in float32 by hand and cast
        # the result back to the incoming dtype.
        weight = self.norm2.weight
        bias = self.norm2.bias
        return torch.nn.functional.layer_norm(
            tokens.float(),
            self.norm2.normalized_shape,
            weight.to(tokens.device).float() if weight is not None else None,
            bias.to(tokens.device).float() if bias is not None else None,
            self.norm2.eps,
        ).type_as(tokens)

    if e.ndim == 4:
        modulation = (self.scale_shift_table.unsqueeze(0).to(e.device) + e.float()).chunk(6, dim=2)
        modulation = [part.squeeze(2) for part in modulation]
    else:
        modulation = (self.scale_shift_table.to(e.device) + e.float()).chunk(6, dim=1)
    shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = modulation

    # self-attn
    # norm1 is applied to a float32 copy of x, then modulated and cast back.
    norm_x = self.norm1(x.float())
    norm_x = (norm_x * (1 + scale_msa) + shift_msa).type_as(x)
    y = self.attn1(
        norm_x,
        freqs=freqs,
        original_context_length=original_context_length,
        transformer_options=transformer_options,
    )
    x = (x.float() + y.float() * gate_msa).type_as(x)

    # cross-attn
    if self.guidance_cross_attn and original_context_length is not None:
        # Only the trailing `original_context_length` tokens attend to the context.
        history_seq_len = x.shape[1] - original_context_length
        history_x, x_main = torch.split(x, [history_seq_len, original_context_length], dim=1)
        x_main = x_main + self.attn2(
            _fp32_norm2(x_main),
            context=context,
            transformer_options=transformer_options,
        )
        x = torch.cat([history_x, x_main], dim=1)
    else:
        x = x + self.attn2(_fp32_norm2(x), context=context, transformer_options=transformer_options)

    # ffn
    # norm3 is applied to a float32 copy of x, then modulated and cast back.
    norm_x = self.norm3(x.float())
    norm_x = (norm_x * (1 + c_scale_msa) + c_shift_msa).type_as(x)
    y = self.ffn(norm_x)
    x = (x.float() + y.float() * c_gate_msa).type_as(x)
    return x
|
||||
|
||||
|
||||
@ -358,7 +397,7 @@ class HeliosModel(torch.nn.Module):
|
||||
kernel_size=patch_size,
|
||||
stride=patch_size,
|
||||
device=operation_settings.get("device"),
|
||||
dtype=torch.float32,
|
||||
dtype=operation_settings.get("dtype"),
|
||||
)
|
||||
self.text_embedding = nn.Sequential(
|
||||
operations.Linear(
|
||||
@ -411,7 +450,7 @@ class HeliosModel(torch.nn.Module):
|
||||
kernel_size=patch_size,
|
||||
stride=patch_size,
|
||||
device=operation_settings.get("device"),
|
||||
dtype=torch.float32,
|
||||
dtype=operation_settings.get("dtype"),
|
||||
)
|
||||
self.patch_mid = operations.Conv3d(
|
||||
in_channels,
|
||||
@ -419,7 +458,7 @@ class HeliosModel(torch.nn.Module):
|
||||
kernel_size=tuple(2 * p for p in patch_size),
|
||||
stride=tuple(2 * p for p in patch_size),
|
||||
device=operation_settings.get("device"),
|
||||
dtype=torch.float32,
|
||||
dtype=operation_settings.get("dtype"),
|
||||
)
|
||||
self.patch_long = operations.Conv3d(
|
||||
in_channels,
|
||||
@ -427,7 +466,7 @@ class HeliosModel(torch.nn.Module):
|
||||
kernel_size=tuple(4 * p for p in patch_size),
|
||||
stride=tuple(4 * p for p in patch_size),
|
||||
device=operation_settings.get("device"),
|
||||
dtype=torch.float32,
|
||||
dtype=operation_settings.get("dtype"),
|
||||
)
|
||||
|
||||
# blocks
|
||||
@ -592,7 +631,7 @@ class HeliosModel(torch.nn.Module):
|
||||
p_t, p_h, p_w = self.patch_size
|
||||
|
||||
# embeddings
|
||||
hidden_states = self.patch_embedding(hidden_states.float()).to(hidden_states.dtype)
|
||||
hidden_states = self.patch_embedding(hidden_states)
|
||||
_, _, post_t, post_h, post_w = hidden_states.shape
|
||||
hidden_states = hidden_states.flatten(2).transpose(1, 2)
|
||||
|
||||
@ -614,7 +653,7 @@ class HeliosModel(torch.nn.Module):
|
||||
original_context_length = hidden_states.shape[1]
|
||||
|
||||
if (latents_history_short is not None and indices_latents_history_short is not None and hasattr(self, "patch_short")):
|
||||
x_short = self.patch_short(latents_history_short.float()).to(hidden_states.dtype)
|
||||
x_short = self.patch_short(latents_history_short).to(hidden_states.dtype)
|
||||
_, _, ts, hs, ws = x_short.shape
|
||||
x_short = x_short.flatten(2).transpose(1, 2)
|
||||
f_short = self.rope_encode(
|
||||
@ -633,44 +672,70 @@ class HeliosModel(torch.nn.Module):
|
||||
freqs = torch.cat([f_short, freqs], dim=1)
|
||||
|
||||
if (latents_history_mid is not None and indices_latents_history_mid is not None and hasattr(self, "patch_mid")):
|
||||
x_mid = self.patch_mid(pad_for_3d_conv(latents_history_mid, (2, 4, 4)).float()).to(hidden_states.dtype)
|
||||
x_mid = self.patch_mid(pad_for_3d_conv(latents_history_mid, (2, 4, 4))).to(hidden_states.dtype)
|
||||
_, _, tm, hm, wm = x_mid.shape
|
||||
x_mid = x_mid.flatten(2).transpose(1, 2)
|
||||
mid_t = indices_latents_history_mid.shape[1]
|
||||
if ("hs" in locals()) and ("ws" in locals()):
|
||||
mid_h, mid_w = hs, ws
|
||||
else:
|
||||
mid_h, mid_w = hm * 2, wm * 2
|
||||
f_mid = self.rope_encode(
|
||||
t=tm * self.patch_size[0],
|
||||
h=hm * self.patch_size[1],
|
||||
w=wm * self.patch_size[2],
|
||||
steps_t=tm,
|
||||
steps_h=hm,
|
||||
steps_w=wm,
|
||||
t=mid_t * self.patch_size[0],
|
||||
h=mid_h * self.patch_size[1],
|
||||
w=mid_w * self.patch_size[2],
|
||||
steps_t=mid_t,
|
||||
steps_h=mid_h,
|
||||
steps_w=mid_w,
|
||||
device=x_mid.device,
|
||||
dtype=x_mid.dtype,
|
||||
transformer_options=transformer_options,
|
||||
frame_indices=indices_latents_history_mid,
|
||||
)
|
||||
f_mid = self._rope_downsample_3d(f_mid, (mid_t, mid_h, mid_w), (2, 2, 2))
|
||||
if f_mid.shape[1] != x_mid.shape[1]:
|
||||
f_mid = f_mid[:, :x_mid.shape[1]]
|
||||
hidden_states = torch.cat([x_mid, hidden_states], dim=1)
|
||||
freqs = torch.cat([f_mid, freqs], dim=1)
|
||||
|
||||
if (latents_history_long is not None and indices_latents_history_long is not None and hasattr(self, "patch_long")):
|
||||
x_long = self.patch_long(pad_for_3d_conv(latents_history_long, (4, 8, 8)).float()).to(hidden_states.dtype)
|
||||
x_long = self.patch_long(pad_for_3d_conv(latents_history_long, (4, 8, 8))).to(hidden_states.dtype)
|
||||
_, _, tl, hl, wl = x_long.shape
|
||||
x_long = x_long.flatten(2).transpose(1, 2)
|
||||
long_t = indices_latents_history_long.shape[1]
|
||||
if ("hs" in locals()) and ("ws" in locals()):
|
||||
long_h, long_w = hs, ws
|
||||
else:
|
||||
long_h, long_w = hl * 4, wl * 4
|
||||
f_long = self.rope_encode(
|
||||
t=tl * self.patch_size[0],
|
||||
h=hl * self.patch_size[1],
|
||||
w=wl * self.patch_size[2],
|
||||
steps_t=tl,
|
||||
steps_h=hl,
|
||||
steps_w=wl,
|
||||
t=long_t * self.patch_size[0],
|
||||
h=long_h * self.patch_size[1],
|
||||
w=long_w * self.patch_size[2],
|
||||
steps_t=long_t,
|
||||
steps_h=long_h,
|
||||
steps_w=long_w,
|
||||
device=x_long.device,
|
||||
dtype=x_long.dtype,
|
||||
transformer_options=transformer_options,
|
||||
frame_indices=indices_latents_history_long,
|
||||
)
|
||||
f_long = self._rope_downsample_3d(f_long, (long_t, long_h, long_w), (4, 4, 4))
|
||||
if f_long.shape[1] != x_long.shape[1]:
|
||||
f_long = f_long[:, :x_long.shape[1]]
|
||||
hidden_states = torch.cat([x_long, hidden_states], dim=1)
|
||||
freqs = torch.cat([f_long, freqs], dim=1)
|
||||
|
||||
history_context_length = hidden_states.shape[1] - original_context_length
|
||||
mismatch = hidden_states.shape[1] != freqs.shape[1]
|
||||
summary_key = (
|
||||
int(post_t),
|
||||
int(post_h),
|
||||
int(post_w),
|
||||
int(original_context_length),
|
||||
int(hidden_states.shape[1]),
|
||||
int(freqs.shape[1]),
|
||||
int(history_context_length),
|
||||
)
|
||||
|
||||
if timestep.ndim == 0:
|
||||
timestep = timestep.unsqueeze(0)
|
||||
@ -682,7 +747,7 @@ class HeliosModel(torch.nn.Module):
|
||||
e = self.time_embedding(sinusoidal_embedding_1d(self.freq_dim, timestep.flatten()).to(dtype=hidden_states.dtype))
|
||||
e = e.reshape(batch_size, -1, e.shape[-1])
|
||||
e0 = self.time_projection(e).unflatten(2, (6, self.dim))
|
||||
context = self.text_embedding(context)
|
||||
context = self.text_embedding(context.to(dtype=hidden_states.dtype))
|
||||
|
||||
if self.zero_history_timestep and history_context_length > 0:
|
||||
timestep_t0 = torch.zeros((1, ), dtype=timestep.dtype, device=timestep.device)
|
||||
@ -701,7 +766,7 @@ class HeliosModel(torch.nn.Module):
|
||||
|
||||
e0 = e0.permute(0, 2, 1, 3)
|
||||
|
||||
for block in self.blocks:
|
||||
for i_b, block in enumerate(self.blocks):
|
||||
hidden_states = block(
|
||||
hidden_states,
|
||||
context,
|
||||
@ -710,35 +775,46 @@ class HeliosModel(torch.nn.Module):
|
||||
original_context_length=original_context_length,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
|
||||
hidden_states = self.norm_out(hidden_states, e, original_context_length)
|
||||
hidden_states = self.proj_out(hidden_states)
|
||||
|
||||
return self.unpatchify(hidden_states, (post_t, post_h, post_w))
|
||||
|
||||
def unpatchify(self, x, grid_sizes):
    """Rearrange projected patch tokens back into a 5D video-shaped tensor.

    Args:
        x: ``[batch, T*H*W, out_dim * p_t * p_h * p_w]`` where ``(T, H, W)`` is
            ``grid_sizes`` and ``(p_t, p_h, p_w)`` is ``self.patch_size``.
        grid_sizes: ``(num_frames, height, width)`` counted in patches.

    Returns:
        Tensor of shape ``[batch, out_dim, T*p_t, H*p_h, W*p_w]``.
    """
    batch = x.shape[0]
    post_t, post_h, post_w = grid_sizes
    p_t, p_h, p_w = self.patch_size

    # [B, T*H*W, C*p_t*p_h*p_w] -> [B, T, H, W, p_t, p_h, p_w, C]
    # The channel count is inferred (-1) rather than read from the module.
    patches = x.reshape(batch, post_t, post_h, post_w, p_t, p_h, p_w, -1)

    # [B, T, H, W, p_t, p_h, p_w, C] -> [B, C, T, p_t, H, p_h, W, p_w]
    patches = patches.permute(0, 7, 1, 4, 2, 5, 3, 6)

    # Merge each (grid, patch) axis pair, innermost pair first so the earlier
    # axis indices stay valid: -> [B, C, T*p_t, H*p_h, W*p_w]
    return patches.flatten(6, 7).flatten(4, 5).flatten(2, 3)
|
||||
def _rope_downsample_3d(self, freqs, grid_sizes, kernel_size):
    """Average-pool a flattened frequency tensor over its 3D token grid.

    ``freqs`` is a 6D tensor whose dimension 1 enumerates the tokens of a
    ``grid_sizes`` = ``(t, h, w)`` grid. The trailing four dimensions are
    folded into one feature axis, the grid is padded with
    ``pad_for_3d_conv`` and pooled with ``center_down_sample_3d`` using
    ``kernel_size``, and the result is flattened back to the 6D layout.

    Returns:
        Tensor of shape ``[b, t' * h' * w', one, d, i2, j2]`` where
        ``(t', h', w')`` is the pooled grid.
    """
    batch, _, one, d, i2, j2 = freqs.shape
    grid_t, grid_h, grid_w = grid_sizes
    features = one * d * i2 * j2

    # Token axis -> explicit (t, h, w) grid, features moved to the channel slot.
    volume = freqs.reshape(batch, grid_t, grid_h, grid_w, features).permute(0, 4, 1, 2, 3)
    volume = center_down_sample_3d(pad_for_3d_conv(volume, kernel_size), kernel_size)

    out_t, out_h, out_w = volume.shape[2:]
    pooled = volume.permute(0, 2, 3, 4, 1)
    return pooled.reshape(batch, out_t * out_h * out_w, one, d, i2, j2)
|
||||
|
||||
# Backward-compatible alias for existing integration points.
|
||||
HeliosTransformer3DModel = HeliosModel
|
||||
|
||||
@ -1287,17 +1287,52 @@ class Helios(BaseModel):
|
||||
"latents_history_short",
|
||||
"latents_history_mid",
|
||||
"latents_history_long",
|
||||
"helios_stage_sigmas",
|
||||
"helios_stage_timesteps",
|
||||
)
|
||||
|
||||
for key in cond_keys:
|
||||
value = kwargs.get(key, None)
|
||||
if value is None:
|
||||
continue
|
||||
if key.startswith("latents_"):
|
||||
value = self.process_latent_in(value)
|
||||
out[key] = comfy.conds.CONDRegular(value)
|
||||
# Diffusers forwards Helios history latents without latent-format re-normalization.
|
||||
# Keep raw history tensors to match transformer inputs across frameworks.
|
||||
if key in ("helios_stage_sigmas", "helios_stage_timesteps"):
|
||||
out[key] = comfy.conds.CONDConstant(value)
|
||||
else:
|
||||
out[key] = comfy.conds.CONDRegular(value)
|
||||
return out
|
||||
|
||||
def process_timestep(self, timestep, **kwargs):
    """Map the sampler timestep onto the Helios per-stage timestep table.

    Reads ``helios_stage_sigmas`` and ``helios_stage_timesteps`` from
    ``kwargs``. ``timestep`` is divided by ``self.model_sampling.multiplier``
    (1000.0 when the attribute is absent) to get a sigma, the closest stage
    sigma is located, and the stage timestep at that index is returned
    (floored when the result dtype is floating point).

    Returns:
        ``timestep`` unchanged when either table is missing or empty;
        otherwise the mapped timesteps in ``timestep``'s dtype.
    """
    stage_sigmas = kwargs.get("helios_stage_sigmas", None)
    stage_timesteps = kwargs.get("helios_stage_timesteps", None)
    if stage_sigmas is None or stage_timesteps is None:
        return timestep

    # Batched tables: use the first row.
    if stage_sigmas.ndim > 1:
        stage_sigmas = stage_sigmas[0]
    if stage_timesteps.ndim > 1:
        stage_timesteps = stage_timesteps[0]

    if stage_timesteps.numel() == 0 or stage_sigmas.numel() == 0:
        return timestep

    step_count = stage_timesteps.numel()
    if stage_sigmas.numel() == step_count + 1:
        # One extra trailing sigma: drop it so sigmas pair up with timesteps.
        sigma_candidates = stage_sigmas[:-1]
    else:
        sigma_candidates = stage_sigmas[:step_count]

    if sigma_candidates.numel() == 0:
        return timestep

    # The stage tables may have been built on a different device than the
    # sampler timestep; align them before comparing and indexing.
    sigma_candidates = sigma_candidates.to(timestep.device)
    stage_timesteps = stage_timesteps.to(timestep.device)

    multiplier = float(getattr(self.model_sampling, "multiplier", 1000.0))
    sigma_in = timestep / multiplier
    distance = torch.abs(sigma_in.unsqueeze(-1) - sigma_candidates.unsqueeze(0))
    idx = torch.argmin(distance, dim=-1)
    mapped = stage_timesteps[idx].to(dtype=timestep.dtype)
    if mapped.dtype.is_floating_point:
        mapped = torch.floor(mapped)
    return mapped
|
||||
|
||||
class WAN21(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel)
|
||||
|
||||
@ -489,7 +489,11 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
|
||||
return dit_config
|
||||
|
||||
if '{}condition_embedder.time_proj.weight'.format(key_prefix) in state_dict_keys and '{}patch_embedding.weight'.format(key_prefix) in state_dict_keys: # Helios
|
||||
helios_required_keys = (
|
||||
'{}patch_mid.weight'.format(key_prefix),
|
||||
'{}patch_long.weight'.format(key_prefix),
|
||||
)
|
||||
if all(k in state_dict_keys for k in helios_required_keys): # Helios
|
||||
dit_config = {}
|
||||
dit_config["image_model"] = "helios"
|
||||
|
||||
@ -501,8 +505,10 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
dit_config["patch_size"] = patch_size
|
||||
dit_config["in_channels"] = patch_weight.shape[1]
|
||||
dit_config["out_channels"] = out_proj.shape[0] // math.prod(patch_size)
|
||||
dit_config["text_dim"] = state_dict['{}condition_embedder.text_embedder.linear_1.weight'.format(key_prefix)].shape[1]
|
||||
dit_config["freq_dim"] = state_dict['{}condition_embedder.time_embedder.linear_1.weight'.format(key_prefix)].shape[1]
|
||||
text_w = state_dict['{}text_embedding.0.weight'.format(key_prefix)]
|
||||
time_w = state_dict['{}time_embedding.0.weight'.format(key_prefix)]
|
||||
dit_config["text_dim"] = text_w.shape[1]
|
||||
dit_config["freq_dim"] = time_w.shape[1]
|
||||
dit_config["num_layers"] = count_blocks(state_dict_keys, '{}blocks.'.format(key_prefix) + '{}.')
|
||||
dit_config["num_attention_heads"] = inner_dim // 128
|
||||
dit_config["attention_head_dim"] = 128
|
||||
|
||||
@ -1143,7 +1143,7 @@ class Helios(supported_models_base.BASE):
|
||||
}
|
||||
|
||||
unet_extra_config = {}
|
||||
latent_format = latent_formats.Wan21
|
||||
latent_format = latent_formats.Helios
|
||||
memory_usage_factor = 1.8
|
||||
supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]
|
||||
|
||||
|
||||
@ -14,6 +14,9 @@ from typing_extensions import override
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
def _parse_int_list(values, default):
|
||||
if values is None:
|
||||
return default
|
||||
@ -72,15 +75,73 @@ def _extract_condition_value(conditioning, key):
|
||||
return None
|
||||
|
||||
|
||||
def _process_latent_in_preserve_zero_frames(model, latent, valid_mask=None):
|
||||
if latent is None or len(latent.shape) != 5:
|
||||
return latent
|
||||
if valid_mask is None:
|
||||
raise ValueError("Helios requires `helios_history_valid_mask` for history latent conversion.")
|
||||
vm = valid_mask
|
||||
if not torch.is_tensor(vm):
|
||||
vm = torch.tensor(vm, device=latent.device)
|
||||
vm = vm.to(device=latent.device)
|
||||
if vm.ndim == 2:
|
||||
nonzero = vm.any(dim=0)
|
||||
else:
|
||||
nonzero = vm.reshape(-1)
|
||||
nonzero = nonzero.bool()
|
||||
|
||||
if nonzero.numel() == 0 or (not torch.any(nonzero)):
|
||||
return latent
|
||||
|
||||
if nonzero.shape[0] != latent.shape[2]:
|
||||
# Keep behavior safe when mask length does not match temporal length.
|
||||
nonzero = torch.zeros((latent.shape[2],), device=latent.device, dtype=torch.bool)
|
||||
|
||||
converted = model.model.process_latent_in(latent)
|
||||
out = latent.clone()
|
||||
out[:, :, nonzero, :, :] = converted[:, :, nonzero, :, :]
|
||||
return out
|
||||
|
||||
|
||||
def _upsample_latent_5d(latent, scale=2):
    """Spatially upsample a 5D latent by ``scale``, one frame at a time.

    The frame axis (dimension 2) is folded into the batch axis so each frame
    is resized as a 4D image batch by ``comfy.utils.common_upscale`` using the
    ``"nearest"`` method with cropping ``"disabled"``, then the original
    ``[b, c, t, h, w]`` layout is restored.

    Args:
        latent: tensor of shape ``[b, c, t, h, w]``.
        scale: multiplier applied to both height and width.

    Returns:
        Tensor of shape ``[b, c, t, h * scale, w * scale]``.
    """
    b, c, t, h, w = latent.shape
    out_h, out_w = h * scale, w * scale
    frames = latent.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
    # Resize exactly once; a second same-size pass would be a no-op and would
    # leave the interpolation mode of the first pass in effect.
    frames = comfy.utils.common_upscale(frames, out_w, out_h, "nearest", "disabled")
    return frames.reshape(b, t, c, out_h, out_w).permute(0, 2, 1, 3, 4)
|
||||
|
||||
|
||||
def _sample_block_noise_like(latent, gamma, patch_size=(1, 2, 2)):
|
||||
def _downsample_latent_5d_bilinear_x2(latent):
    """Halve the spatial size of a 5D latent with bilinear resizing.

    Each frame is resized to ``(max(1, h // 2), max(1, w // 2))`` via
    ``comfy.utils.common_upscale`` (``"bilinear"``, cropping ``"disabled"``)
    and the resized values are multiplied by 2.0.

    Args:
        latent: tensor of shape ``[b, c, t, h, w]``.

    Returns:
        Tensor of shape ``[b, c, t, max(1, h // 2), max(1, w // 2)]``.
    """
    b, c, t, h, w = latent.shape
    half_h = max(1, h // 2)
    half_w = max(1, w // 2)
    # Fold frames into the batch axis so the resize sees 4D input.
    frames = latent.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
    frames = comfy.utils.common_upscale(frames, half_w, half_h, "bilinear", "disabled") * 2.0
    return frames.reshape(b, t, c, half_h, half_w).permute(0, 2, 1, 3, 4)
|
||||
|
||||
|
||||
def _prepare_stage0_latent(batch, channels, frames, height, width, stage_count, add_noise, seed, dtype, layout, device):
|
||||
"""Prepare initial latent for stage 0 with optional noise"""
|
||||
full_latent = torch.zeros((batch, channels, frames, height, width), dtype=dtype, layout=layout, device=device)
|
||||
if add_noise:
|
||||
full_latent = comfy.sample.prepare_noise(full_latent, seed).to(dtype)
|
||||
|
||||
# Downsample to stage 0 resolution
|
||||
stage_latent = full_latent
|
||||
for _ in range(max(0, int(stage_count) - 1)):
|
||||
stage_latent = _downsample_latent_5d_bilinear_x2(stage_latent)
|
||||
return stage_latent
|
||||
|
||||
|
||||
def _downsample_latent_for_stage0(latent, stage_count):
|
||||
"""Downsample latent to stage 0 resolution (like Diffusers does)"""
|
||||
stage_latent = latent
|
||||
for _ in range(max(0, int(stage_count) - 1)):
|
||||
stage_latent = _downsample_latent_5d_bilinear_x2(stage_latent)
|
||||
return stage_latent
|
||||
|
||||
|
||||
|
||||
def _sample_block_noise_like(latent, gamma, patch_size=(1, 2, 2), generator=None, seed=None):
|
||||
b, c, t, h, w = latent.shape
|
||||
_, ph, pw = patch_size
|
||||
block_size = ph * pw
|
||||
@ -88,13 +149,38 @@ def _sample_block_noise_like(latent, gamma, patch_size=(1, 2, 2)):
|
||||
cov = torch.eye(block_size, device=latent.device) * (1.0 + gamma) - torch.ones(block_size, block_size, device=latent.device) * gamma
|
||||
cov += torch.eye(block_size, device=latent.device) * 1e-6
|
||||
|
||||
dist = torch.distributions.MultivariateNormal(torch.zeros(block_size, device=latent.device), covariance_matrix=cov)
|
||||
block_number = b * c * t * max(1, h // ph) * max(1, w // pw)
|
||||
h_blocks = h // ph
|
||||
w_blocks = w // pw
|
||||
block_number = b * c * t * h_blocks * w_blocks
|
||||
|
||||
noise = dist.sample((block_number,))
|
||||
noise = noise.view(b, c, t, max(1, h // ph), max(1, w // pw), ph, pw)
|
||||
noise = noise.permute(0, 1, 2, 3, 5, 4, 6).reshape(b, c, t, max(1, h // ph) * ph, max(1, w // pw) * pw)
|
||||
noise = noise[:, :, :, :h, :w]
|
||||
if generator is not None:
|
||||
# Exact Diffusers sampling path (MultivariateNormal.sample), while consuming
|
||||
# from an explicit generator by temporarily swapping default RNG state.
|
||||
with torch.random.fork_rng(devices=[latent.device] if latent.device.type == "cuda" else []):
|
||||
if latent.device.type == "cuda":
|
||||
torch.cuda.set_rng_state(generator.get_state(), device=latent.device)
|
||||
else:
|
||||
torch.random.set_rng_state(generator.get_state())
|
||||
dist = torch.distributions.MultivariateNormal(
|
||||
torch.zeros(block_size, device=latent.device),
|
||||
covariance_matrix=cov,
|
||||
)
|
||||
noise = dist.sample((block_number,))
|
||||
if latent.device.type == "cuda":
|
||||
generator.set_state(torch.cuda.get_rng_state(device=latent.device))
|
||||
else:
|
||||
generator.set_state(torch.random.get_rng_state())
|
||||
elif seed is None:
|
||||
dist = torch.distributions.MultivariateNormal(torch.zeros(block_size, device=latent.device), covariance_matrix=cov)
|
||||
noise = dist.sample((block_number,))
|
||||
else:
|
||||
# Use deterministic RNG when seed is provided (for cross-framework alignment).
|
||||
with torch.random.fork_rng(devices=[latent.device] if latent.device.type == "cuda" else []):
|
||||
torch.manual_seed(int(seed))
|
||||
dist = torch.distributions.MultivariateNormal(torch.zeros(block_size, device=latent.device), covariance_matrix=cov)
|
||||
noise = dist.sample((block_number,))
|
||||
noise = noise.view(b, c, t, h_blocks, w_blocks, ph, pw)
|
||||
noise = noise.permute(0, 1, 2, 3, 5, 4, 6).reshape(b, c, t, h, w)
|
||||
return noise
|
||||
|
||||
|
||||
@ -144,8 +230,10 @@ def _helios_stage_tables(stage_count, stage_range, gamma, num_train_timesteps=10
|
||||
|
||||
tmax = min(float(sigmas[int(start_ratio * num_train_timesteps)].item() * num_train_timesteps), 999.0)
|
||||
tmin = float(sigmas[min(int(end_ratio * num_train_timesteps), num_train_timesteps - 1)].item() * num_train_timesteps)
|
||||
timesteps_per_stage[i] = torch.linspace(tmax, tmin, num_train_timesteps)
|
||||
sigmas_per_stage[i] = torch.linspace(0.999, 0.0, num_train_timesteps)
|
||||
timesteps_per_stage[i] = torch.linspace(tmax, tmin, num_train_timesteps + 1)[:-1]
|
||||
# Fixed: Use same sigma range [0.999, 0] for all stages like Diffusers
|
||||
sigmas_per_stage[i] = torch.linspace(0.999, 0.0, num_train_timesteps + 1)[:-1]
|
||||
|
||||
|
||||
return {
|
||||
"ori_start_sigmas": ori_start_sigmas,
|
||||
@ -163,7 +251,8 @@ def _helios_stage_sigmas(stage_idx, stage_steps, stage_tables, is_distilled=Fals
|
||||
stage_steps = stage_steps * 2 if (is_amplify_first_stage and stage_idx == 0) else stage_steps
|
||||
|
||||
stage_sigma_src = stage_tables["sigmas_per_stage"][stage_idx]
|
||||
sigmas = torch.linspace(float(stage_sigma_src[0].item()), float(stage_sigma_src[-1].item()), stage_steps + 1)
|
||||
sigmas = torch.linspace(float(stage_sigma_src[0].item()), float(stage_sigma_src[-1].item()), stage_steps)
|
||||
sigmas = torch.cat([sigmas, torch.zeros(1, dtype=sigmas.dtype, device=sigmas.device)], dim=0)
|
||||
return sigmas
|
||||
|
||||
|
||||
@ -213,23 +302,37 @@ def _build_cfg_zero_star_pre_cfg(stage_idx, zero_steps, use_zero_init):
|
||||
state["i"] += 1
|
||||
return conds_out
|
||||
|
||||
noise_pred_text = conds_out[0]
|
||||
noise_uncond = conds_out[1]
|
||||
denoised_text = conds_out[0] # apply_model 返回的 denoised
|
||||
denoised_uncond = conds_out[1]
|
||||
cfg = float(args.get("cond_scale", 1.0))
|
||||
x = args["input"] # 当前的 noisy latent
|
||||
sigma = args["sigma"] # 当前的 sigma
|
||||
|
||||
positive_flat = noise_pred_text.view(noise_pred_text.shape[0], -1)
|
||||
negative_flat = noise_uncond.view(noise_uncond.shape[0], -1)
|
||||
# Key fix: convert denoised to flow
|
||||
# denoised = x - flow * sigma => flow = (x - denoised) / sigma
|
||||
sigma_reshaped = sigma.reshape(sigma.shape[0], *([1] * (denoised_text.ndim - 1)))
|
||||
sigma_safe = torch.clamp(sigma_reshaped, min=1e-8)
|
||||
|
||||
flow_text = (x - denoised_text) / sigma_safe
|
||||
flow_uncond = (x - denoised_uncond) / sigma_safe
|
||||
|
||||
# Apply CFG Zero Star in flow space
|
||||
positive_flat = flow_text.reshape(flow_text.shape[0], -1)
|
||||
negative_flat = flow_uncond.reshape(flow_uncond.shape[0], -1)
|
||||
alpha = _optimized_scale(positive_flat, negative_flat)
|
||||
alpha = alpha.view(noise_pred_text.shape[0], *([1] * (noise_pred_text.ndim - 1))).to(noise_pred_text.dtype)
|
||||
alpha = alpha.reshape(flow_text.shape[0], *([1] * (flow_text.ndim - 1))).to(flow_text.dtype)
|
||||
|
||||
if stage_idx == 0 and state["i"] <= int(zero_steps) and bool(use_zero_init):
|
||||
final = noise_pred_text * 0.0
|
||||
flow_final = flow_text * 0.0
|
||||
else:
|
||||
final = noise_uncond * alpha + cfg * (noise_pred_text - noise_uncond * alpha)
|
||||
flow_final = flow_uncond * alpha + cfg * (flow_text - flow_uncond * alpha)
|
||||
|
||||
# Convert flow back to denoised
|
||||
denoised_final = x - flow_final * sigma_safe
|
||||
|
||||
state["i"] += 1
|
||||
# Return identical cond/uncond so downstream cfg_function keeps `final` unchanged.
|
||||
return [final, final]
|
||||
return [denoised_final, denoised_final]
|
||||
|
||||
return pre_cfg_fn
|
||||
|
||||
@ -310,6 +413,8 @@ def _set_helios_history_values(positive, negative, history_latent, history_sizes
|
||||
latent = history_latent
|
||||
if latent is None or len(latent.shape) != 5:
|
||||
return positive, negative
|
||||
if prefix_latent is not None and (latent.device != prefix_latent.device or latent.dtype != prefix_latent.dtype):
|
||||
latent = latent.to(device=prefix_latent.device, dtype=prefix_latent.dtype)
|
||||
|
||||
sizes = list(history_sizes)
|
||||
if len(sizes) != 3:
|
||||
@ -342,13 +447,15 @@ def _set_helios_history_values(positive, negative, history_latent, history_sizes
|
||||
prefix = latent[:, :, :1]
|
||||
else:
|
||||
prefix = torch.zeros(latent.shape[0], latent.shape[1], 1, latent.shape[3], latent.shape[4], device=latent.device, dtype=latent.dtype)
|
||||
if prefix.device != latents_history_short_base.device or prefix.dtype != latents_history_short_base.dtype:
|
||||
prefix = prefix.to(device=latents_history_short_base.device, dtype=latents_history_short_base.dtype)
|
||||
latents_history_short = torch.cat([prefix, latents_history_short_base], dim=2)
|
||||
else:
|
||||
latents_history_short = latents_history_short_base
|
||||
|
||||
idx_short = torch.arange(latents_history_short.shape[2], device=latent.device, dtype=latent.dtype).unsqueeze(0).expand(latent.shape[0], -1)
|
||||
idx_mid = torch.arange(latents_history_mid.shape[2], device=latent.device, dtype=latent.dtype).unsqueeze(0).expand(latent.shape[0], -1)
|
||||
idx_long = torch.arange(latents_history_long.shape[2], device=latent.device, dtype=latent.dtype).unsqueeze(0).expand(latent.shape[0], -1)
|
||||
idx_short = torch.arange(latents_history_short.shape[2], device=latent.device, dtype=torch.int64).unsqueeze(0).expand(latent.shape[0], -1)
|
||||
idx_mid = torch.arange(latents_history_mid.shape[2], device=latent.device, dtype=torch.int64).unsqueeze(0).expand(latent.shape[0], -1)
|
||||
idx_long = torch.arange(latents_history_long.shape[2], device=latent.device, dtype=torch.int64).unsqueeze(0).expand(latent.shape[0], -1)
|
||||
|
||||
values = {
|
||||
"latents_history_short": latents_history_short,
|
||||
@ -364,7 +471,7 @@ def _set_helios_history_values(positive, negative, history_latent, history_sizes
|
||||
return positive, negative
|
||||
|
||||
|
||||
def _build_helios_indices(batch, history_sizes, keep_first_frame, hidden_frames, device, dtype):
|
||||
def _build_helios_indices(batch, history_sizes, keep_first_frame, hidden_frames, device):
|
||||
sizes = list(history_sizes)
|
||||
if len(sizes) != 3:
|
||||
sizes = [16, 2, 1]
|
||||
@ -373,13 +480,13 @@ def _build_helios_indices(batch, history_sizes, keep_first_frame, hidden_frames,
|
||||
|
||||
if keep_first_frame:
|
||||
total = 1 + long_size + mid_size + short_base_size + hidden_frames
|
||||
indices = torch.arange(total, device=device, dtype=dtype)
|
||||
indices = torch.arange(total, device=device, dtype=torch.int64)
|
||||
splits = [1, long_size, mid_size, short_base_size, hidden_frames]
|
||||
indices_prefix, idx_long, idx_mid, idx_1x, idx_hidden = torch.split(indices, splits, dim=0)
|
||||
idx_short = torch.cat([indices_prefix, idx_1x], dim=0)
|
||||
else:
|
||||
total = long_size + mid_size + short_base_size + hidden_frames
|
||||
indices = torch.arange(total, device=device, dtype=dtype)
|
||||
indices = torch.arange(total, device=device, dtype=torch.int64)
|
||||
splits = [long_size, mid_size, short_base_size, hidden_frames]
|
||||
idx_long, idx_mid, idx_short, idx_hidden = torch.split(indices, splits, dim=0)
|
||||
|
||||
@ -450,7 +557,9 @@ class HeliosImageToVideo(io.ComfyNode):
|
||||
sizes = sorted([max(0, int(v)) for v in sizes], reverse=True)
|
||||
hist_len = max(1, sum(sizes))
|
||||
history_latent = torch.zeros([batch_size, latent_channels, hist_len, latent.shape[-2], latent.shape[-1]], device=latent.device, dtype=latent.dtype)
|
||||
history_valid_mask = torch.zeros((batch_size, hist_len), device=latent.device, dtype=torch.bool)
|
||||
image_latent_prefix = None
|
||||
i2v_noise_gen = None
|
||||
|
||||
if start_image is not None:
|
||||
image = comfy.utils.common_upscale(start_image[:1].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
|
||||
@ -459,20 +568,36 @@ class HeliosImageToVideo(io.ComfyNode):
|
||||
image_latent_prefix = img_latent[:, :, :1]
|
||||
|
||||
if add_noise_to_image_latents:
|
||||
g = torch.Generator(device=img_latent.device)
|
||||
g.manual_seed(int(noise_seed))
|
||||
i2v_noise_gen = torch.Generator(device=img_latent.device)
|
||||
i2v_noise_gen.manual_seed(int(noise_seed))
|
||||
sigma = (
|
||||
torch.rand((img_latent.shape[0], 1, 1, 1, 1), device=img_latent.device, generator=g, dtype=img_latent.dtype)
|
||||
torch.rand((img_latent.shape[0], 1, 1, 1, 1), device=img_latent.device, generator=i2v_noise_gen, dtype=img_latent.dtype)
|
||||
* (float(image_noise_sigma_max) - float(image_noise_sigma_min))
|
||||
+ float(image_noise_sigma_min)
|
||||
)
|
||||
image_latent_prefix = sigma * torch.randn_like(image_latent_prefix, generator=g) + (1.0 - sigma) * image_latent_prefix
|
||||
image_latent_prefix = sigma * torch.randn_like(image_latent_prefix, generator=i2v_noise_gen) + (1.0 - sigma) * image_latent_prefix
|
||||
|
||||
min_frames = max(1, (int(num_latent_frames_per_chunk) - 1) * 4 + 1)
|
||||
fake_video = image.repeat(min_frames, 1, 1, 1)
|
||||
fake_latents_full = vae.encode(fake_video)
|
||||
fake_latent = comfy.utils.repeat_to_batch_size(fake_latents_full[:, :, -1:], batch_size)
|
||||
# Diffusers parity for I2V:
|
||||
# when adding noise to image latents, fake_image_latents used for history are also noised.
|
||||
if add_noise_to_image_latents:
|
||||
if i2v_noise_gen is None:
|
||||
i2v_noise_gen = torch.Generator(device=fake_latent.device)
|
||||
i2v_noise_gen.manual_seed(int(noise_seed))
|
||||
# Keep backward compatibility with existing I2V node inputs:
|
||||
# this node exposes only image sigma controls, while fake history
|
||||
# latents follow the video-noise path in Diffusers.
|
||||
fake_sigma = (
|
||||
torch.rand((fake_latent.shape[0], 1, 1, 1, 1), device=fake_latent.device, generator=i2v_noise_gen, dtype=fake_latent.dtype)
|
||||
* (float(image_noise_sigma_max) - float(image_noise_sigma_min))
|
||||
+ float(image_noise_sigma_min)
|
||||
)
|
||||
fake_latent = fake_sigma * torch.randn_like(fake_latent, generator=i2v_noise_gen) + (1.0 - fake_sigma) * fake_latent
|
||||
history_latent[:, :, -1:] = fake_latent
|
||||
history_valid_mask[:, -1] = True
|
||||
|
||||
positive, negative = _set_helios_history_values(positive, negative, history_latent, sizes, keep_first_frame, prefix_latent=image_latent_prefix)
|
||||
return io.NodeOutput(
|
||||
@ -482,6 +607,85 @@ class HeliosImageToVideo(io.ComfyNode):
|
||||
"samples": latent,
|
||||
"helios_history_latent": history_latent,
|
||||
"helios_image_latent_prefix": image_latent_prefix,
|
||||
"helios_history_valid_mask": history_valid_mask,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
class HeliosTextToVideo(io.ComfyNode):
    """Conditioning node that sets up a Helios text-to-video run.

    The node produces an all-zero video latent and an all-zero history
    buffer whose validity mask is entirely ``False``, and attaches the
    history to both conditionings through ``_set_helios_history_values``.
    """

    @classmethod
    def define_schema(cls):
        """Return the schema (id, category, inputs, outputs) for this node."""
        resolution_cap = nodes.MAX_RESOLUTION
        node_inputs = [
            io.Conditioning.Input("positive"),
            io.Conditioning.Input("negative"),
            io.Vae.Input("vae"),
            io.Int.Input("width", default=640, min=16, max=resolution_cap, step=16),
            io.Int.Input("height", default=384, min=16, max=resolution_cap, step=16),
            io.Int.Input("length", default=132, min=1, max=resolution_cap, step=4),
            io.Int.Input("batch_size", default=1, min=1, max=4096),
            io.String.Input("history_sizes", default="16,2,1", advanced=True),
        ]
        node_outputs = [
            io.Conditioning.Output(display_name="positive"),
            io.Conditioning.Output(display_name="negative"),
            io.Latent.Output(display_name="latent"),
        ]
        return io.Schema(
            node_id="HeliosTextToVideo",
            category="conditioning/video_models",
            inputs=node_inputs,
            outputs=node_outputs,
        )

    @classmethod
    def execute(
        cls,
        positive,
        negative,
        vae,
        width,
        height,
        length,
        batch_size,
        history_sizes="16,2,1",
    ) -> io.NodeOutput:
        """Build the placeholder latent and empty history for text-to-video.

        Args:
            positive: Positive conditioning to receive the history values.
            negative: Negative conditioning to receive the history values.
            vae: Only queried for its spatial compression factor and its
                latent channel count; nothing is encoded here.
            width: Target width in pixels.
            height: Target height in pixels.
            length: Requested frame count; mapped to
                ``(length - 1) // 4 + 1`` latent frames.
            batch_size: Number of samples in the batch.
            history_sizes: Comma-separated history sizes. Anything that does
                not parse to exactly three entries falls back to
                ``[16, 2, 1]``.

        Returns:
            ``io.NodeOutput`` carrying the updated positive and negative
            conditionings and a latent dict with the zero samples, the zero
            history latent, a ``None`` image prefix, and the all-``False``
            history validity mask.
        """
        downscale = vae.spacial_compression_encode()
        channels = vae.latent_channels
        frames = (length - 1) // 4 + 1
        rows = height // downscale
        cols = width // downscale

        # Zero latent acts as a shape placeholder (noise will be generated in sampler).
        latent = torch.zeros(
            [batch_size, channels, frames, rows, cols],
            device=comfy.model_management.intermediate_device(),
        )

        fallback_sizes = [16, 2, 1]
        parsed_sizes = _parse_int_list(history_sizes, fallback_sizes)
        if len(parsed_sizes) != 3:
            parsed_sizes = [16, 2, 1]
        # Clamp negatives to zero and order the sizes largest-first.
        sizes = sorted((max(0, int(entry)) for entry in parsed_sizes), reverse=True)
        hist_len = max(1, sum(sizes))

        # History starts as zeros (no history yet), so no frame is marked valid.
        history_latent = torch.zeros(
            [batch_size, channels, hist_len, rows, cols],
            device=latent.device,
            dtype=latent.dtype,
        )
        history_valid_mask = torch.zeros(
            (batch_size, hist_len),
            device=latent.device,
            dtype=torch.bool,
        )

        # Fifth positional argument is keep_first_frame: there is no first
        # frame to keep in text-to-video, and likewise no prefix latent.
        positive, negative = _set_helios_history_values(
            positive,
            negative,
            history_latent,
            sizes,
            False,
            prefix_latent=None,
        )

        latent_out = {
            "samples": latent,
            "helios_history_latent": history_latent,
            "helios_image_latent_prefix": None,
            "helios_history_valid_mask": history_valid_mask,
        }
        return io.NodeOutput(positive, negative, latent_out)
|
||||
|
||||
@ -544,6 +748,7 @@ class HeliosVideoToVideo(io.ComfyNode):
|
||||
sizes = sorted([max(0, int(v)) for v in sizes], reverse=True)
|
||||
hist_len = max(1, sum(sizes))
|
||||
history_latent = torch.zeros([batch_size, latent_channels, hist_len, latent.shape[-2], latent.shape[-1]], device=latent.device, dtype=latent.dtype)
|
||||
history_valid_mask = torch.zeros((batch_size, hist_len), device=latent.device, dtype=torch.bool)
|
||||
image_latent_prefix = None
|
||||
|
||||
if video is not None:
|
||||
@ -559,11 +764,14 @@ class HeliosVideoToVideo(io.ComfyNode):
|
||||
)
|
||||
vid_latent = frame_sigmas * torch.randn_like(vid_latent, generator=g) + (1.0 - frame_sigmas) * vid_latent
|
||||
vid_latent = vid_latent[:, :, :hist_len]
|
||||
if vid_latent.shape[2] < hist_len:
|
||||
pad = vid_latent[:, :, -1:].repeat(1, 1, hist_len - vid_latent.shape[2], 1, 1)
|
||||
vid_latent = torch.cat([vid_latent, pad], dim=2)
|
||||
vid_latent = comfy.utils.repeat_to_batch_size(vid_latent, batch_size)
|
||||
history_latent = vid_latent
|
||||
if vid_latent.shape[2] < hist_len:
|
||||
keep_frames = hist_len - vid_latent.shape[2]
|
||||
history_latent = torch.cat([history_latent[:, :, :keep_frames], vid_latent], dim=2)
|
||||
history_valid_mask[:, keep_frames:] = True
|
||||
else:
|
||||
history_latent = vid_latent[:, :, -hist_len:]
|
||||
history_valid_mask[:] = True
|
||||
image_latent_prefix = history_latent[:, :, :1]
|
||||
|
||||
positive, negative = _set_helios_history_values(positive, negative, history_latent, sizes, keep_first_frame, prefix_latent=image_latent_prefix)
|
||||
@ -574,6 +782,7 @@ class HeliosVideoToVideo(io.ComfyNode):
|
||||
"samples": latent,
|
||||
"helios_history_latent": history_latent,
|
||||
"helios_image_latent_prefix": image_latent_prefix,
|
||||
"helios_history_valid_mask": history_valid_mask,
|
||||
},
|
||||
)
|
||||
|
||||
@ -625,25 +834,16 @@ class HeliosPyramidSampler(io.ComfyNode):
|
||||
io.Latent.Input("latent_image"),
|
||||
io.String.Input("pyramid_steps", default="10,10,10"),
|
||||
io.String.Input("stage_range", default="0,0.333333,0.666667,1"),
|
||||
io.Boolean.Input("is_distilled", default=False),
|
||||
io.Boolean.Input("is_amplify_first_stage", default=False),
|
||||
io.Combo.Input("scheduler_mode", options=["euler", "unipc_bh2"]),
|
||||
io.Boolean.Input("distilled", default=False),
|
||||
io.Boolean.Input("amplify_first_stage", default=False),
|
||||
io.Float.Input("gamma", default=1.0 / 3.0, min=0.0001, max=10.0, step=0.0001, round=False),
|
||||
io.Float.Input("shift", default=1.0, min=0.001, max=100.0, step=0.001, round=False, advanced=True),
|
||||
io.Boolean.Input("use_dynamic_shifting", default=False, advanced=True),
|
||||
io.Combo.Input("time_shift_type", options=["exponential", "linear"], advanced=True),
|
||||
io.Int.Input("base_image_seq_len", default=256, min=1, max=65536, advanced=True),
|
||||
io.Int.Input("max_image_seq_len", default=4096, min=1, max=65536, advanced=True),
|
||||
io.Float.Input("base_shift", default=0.5, min=0.0, max=10.0, step=0.0001, round=False, advanced=True),
|
||||
io.Float.Input("max_shift", default=1.15, min=0.0, max=10.0, step=0.0001, round=False, advanced=True),
|
||||
io.Int.Input("num_train_timesteps", default=1000, min=10, max=100000, advanced=True),
|
||||
io.String.Input("history_sizes", default="16,2,1", advanced=True),
|
||||
io.Boolean.Input("keep_first_frame", default=True, advanced=True),
|
||||
io.Int.Input("num_latent_frames_per_chunk", default=9, min=1, max=256, advanced=True),
|
||||
io.Boolean.Input("is_cfg_zero_star", default=False, advanced=True),
|
||||
io.Boolean.Input("cfg_zero_star", default=True, advanced=True),
|
||||
io.Boolean.Input("use_zero_init", default=True, advanced=True),
|
||||
io.Int.Input("zero_steps", default=1, min=0, max=10000, advanced=True),
|
||||
io.Boolean.Input("is_skip_first_chunk", default=False, advanced=True),
|
||||
io.Boolean.Input("skip_first_chunk", default=False, advanced=True),
|
||||
],
|
||||
outputs=[
|
||||
io.Latent.Output(display_name="output"),
|
||||
@ -663,33 +863,40 @@ class HeliosPyramidSampler(io.ComfyNode):
|
||||
latent_image,
|
||||
pyramid_steps,
|
||||
stage_range,
|
||||
is_distilled,
|
||||
is_amplify_first_stage,
|
||||
scheduler_mode,
|
||||
distilled,
|
||||
amplify_first_stage,
|
||||
gamma,
|
||||
shift,
|
||||
use_dynamic_shifting,
|
||||
time_shift_type,
|
||||
base_image_seq_len,
|
||||
max_image_seq_len,
|
||||
base_shift,
|
||||
max_shift,
|
||||
num_train_timesteps,
|
||||
history_sizes,
|
||||
keep_first_frame,
|
||||
num_latent_frames_per_chunk,
|
||||
is_cfg_zero_star,
|
||||
cfg_zero_star,
|
||||
use_zero_init,
|
||||
zero_steps,
|
||||
is_skip_first_chunk,
|
||||
skip_first_chunk,
|
||||
) -> io.NodeOutput:
|
||||
# Keep these scheduler knobs internal (not exposed in node UI).
|
||||
shift = 1.0
|
||||
num_train_timesteps = 1000
|
||||
# Keep dynamic shifting always on for Helios parity; not exposed in node UI.
|
||||
use_dynamic_shifting = True
|
||||
time_shift_type = "exponential"
|
||||
base_image_seq_len = 256
|
||||
max_image_seq_len = 4096
|
||||
base_shift = 0.5
|
||||
max_shift = 1.15
|
||||
|
||||
latent = latent_image.copy()
|
||||
latent_samples = comfy.sample.fix_empty_latent_channels(model, latent["samples"], latent.get("downscale_ratio_spacial", None))
|
||||
if not add_noise:
|
||||
latent_samples = _process_latent_in_preserve_zero_frames(model, latent_samples)
|
||||
|
||||
stage_steps = _parse_int_list(pyramid_steps, [10, 10, 10])
|
||||
stage_steps = [max(1, int(s)) for s in stage_steps]
|
||||
stage_count = len(stage_steps)
|
||||
history_sizes_list = sorted([max(0, int(v)) for v in _parse_int_list(history_sizes, [16, 2, 1])], reverse=True)
|
||||
# Diffusers parity: if not keeping first frame, fold prefix slot into short history size.
|
||||
if not keep_first_frame and len(history_sizes_list) > 0:
|
||||
history_sizes_list[-1] += 1
|
||||
|
||||
stage_range_values = _parse_float_list(stage_range, [0.0, 1.0 / 3.0, 2.0 / 3.0, 1.0])
|
||||
if len(stage_range_values) != stage_count + 1:
|
||||
@ -706,29 +913,41 @@ class HeliosPyramidSampler(io.ComfyNode):
|
||||
b, c, t, h, w = latent_samples.shape
|
||||
chunk_t = max(1, int(num_latent_frames_per_chunk))
|
||||
chunk_count = max(1, (t + chunk_t - 1) // chunk_t)
|
||||
low_scale = 2 ** max(0, stage_count - 1)
|
||||
low_h = max(1, h // low_scale)
|
||||
low_w = max(1, w // low_scale)
|
||||
|
||||
base_latent = torch.zeros((b, c, chunk_t, low_h, low_w), dtype=latent_samples.dtype, layout=latent_samples.layout, device=latent_samples.device)
|
||||
|
||||
if add_noise:
|
||||
stage_latent = comfy.sample.prepare_noise(base_latent, noise_seed)
|
||||
else:
|
||||
stage_latent = torch.zeros_like(base_latent, device="cpu")
|
||||
|
||||
stage_latent = stage_latent.to(base_latent.dtype).to(comfy.model_management.intermediate_device())
|
||||
euler_sampler = comfy.samplers.KSAMPLER(_helios_euler_sample)
|
||||
target_device = comfy.model_management.get_torch_device()
|
||||
noise_gen = torch.Generator(device=target_device)
|
||||
noise_gen.manual_seed(int(noise_seed))
|
||||
|
||||
image_latent_prefix = latent.get("helios_image_latent_prefix", None)
|
||||
history_valid_mask = latent.get("helios_history_valid_mask", None)
|
||||
if history_valid_mask is None:
|
||||
raise ValueError("Helios sampler requires `helios_history_valid_mask` in latent input.")
|
||||
history_from_latent_applied = False
|
||||
if image_latent_prefix is not None:
|
||||
image_latent_prefix = model.model.process_latent_in(image_latent_prefix)
|
||||
if "helios_history_latent" in latent:
|
||||
history_in = _process_latent_in_preserve_zero_frames(model, latent["helios_history_latent"], valid_mask=history_valid_mask)
|
||||
positive, negative = _set_helios_history_values(
|
||||
positive,
|
||||
negative,
|
||||
history_in,
|
||||
history_sizes_list,
|
||||
keep_first_frame,
|
||||
prefix_latent=image_latent_prefix,
|
||||
)
|
||||
history_from_latent_applied = True
|
||||
|
||||
latents_history_short = _extract_condition_value(positive, "latents_history_short")
|
||||
latents_history_mid = _extract_condition_value(positive, "latents_history_mid")
|
||||
latents_history_long = _extract_condition_value(positive, "latents_history_long")
|
||||
image_latent_prefix = latent.get("helios_image_latent_prefix", None)
|
||||
if (not history_from_latent_applied) and latents_history_short is not None and latents_history_mid is not None and latents_history_long is not None:
|
||||
raise ValueError("Helios requires `helios_history_latent` + `helios_history_valid_mask`; direct history conditioning is not supported.")
|
||||
if latents_history_short is None and "helios_history_latent" in latent:
|
||||
history_in = _process_latent_in_preserve_zero_frames(model, latent["helios_history_latent"], valid_mask=history_valid_mask)
|
||||
positive, negative = _set_helios_history_values(
|
||||
positive,
|
||||
negative,
|
||||
latent["helios_history_latent"],
|
||||
history_in,
|
||||
history_sizes_list,
|
||||
keep_first_frame,
|
||||
prefix_latent=image_latent_prefix,
|
||||
@ -740,18 +959,100 @@ class HeliosPyramidSampler(io.ComfyNode):
|
||||
x0_output = {}
|
||||
generated_chunks = []
|
||||
if latents_history_short is not None and latents_history_mid is not None and latents_history_long is not None:
|
||||
rolling_history = torch.cat([latents_history_long, latents_history_mid, latents_history_short], dim=2)
|
||||
# Diffusers parity: `history_latents` storage does NOT include the keep_first_frame prefix slot.
|
||||
# `latents_history_short` in conditioning may include [prefix + short_base], so strip prefix here.
|
||||
short_base_size = history_sizes_list[-1] if len(history_sizes_list) > 0 else latents_history_short.shape[2]
|
||||
if keep_first_frame and latents_history_short.shape[2] > short_base_size:
|
||||
short_for_history = latents_history_short[:, :, -short_base_size:]
|
||||
else:
|
||||
short_for_history = latents_history_short
|
||||
rolling_history = torch.cat([latents_history_long, latents_history_mid, short_for_history], dim=2)
|
||||
elif "helios_history_latent" in latent:
|
||||
rolling_history = latent["helios_history_latent"]
|
||||
rolling_history = _process_latent_in_preserve_zero_frames(model, rolling_history, valid_mask=history_valid_mask)
|
||||
else:
|
||||
hist_len = max(1, sum(history_sizes_list))
|
||||
rolling_history = torch.zeros((b, c, hist_len, h, w), device=latent_samples.device, dtype=latent_samples.dtype)
|
||||
|
||||
for chunk_idx in range(chunk_count):
|
||||
if add_noise:
|
||||
stage_latent = comfy.sample.prepare_noise(base_latent, noise_seed + chunk_idx).to(base_latent.dtype).to(comfy.model_management.intermediate_device())
|
||||
# Align with Diffusers behavior: when initial video latents are provided, seed history buffer
|
||||
# with those latents before the first denoising chunk.
|
||||
if not add_noise:
|
||||
hist_len = max(1, sum(history_sizes_list))
|
||||
rolling_history = rolling_history.to(device=latent_samples.device, dtype=latent_samples.dtype)
|
||||
video_latents = latent_samples
|
||||
video_frames = video_latents.shape[2]
|
||||
if video_frames < hist_len:
|
||||
keep_frames = hist_len - video_frames
|
||||
rolling_history = torch.cat([rolling_history[:, :, :keep_frames], video_latents], dim=2)
|
||||
else:
|
||||
stage_latent = torch.zeros_like(base_latent, device=comfy.model_management.intermediate_device())
|
||||
rolling_history = video_latents[:, :, -hist_len:]
|
||||
|
||||
# Keep history/prefix on the same device/dtype as denoising latents.
|
||||
rolling_history = rolling_history.to(device=target_device, dtype=latent_samples.dtype)
|
||||
if image_latent_prefix is not None:
|
||||
image_latent_prefix = image_latent_prefix.to(device=target_device, dtype=latent_samples.dtype)
|
||||
|
||||
for chunk_idx in range(chunk_count):
|
||||
# Extract chunk from input latents
|
||||
chunk_start = chunk_idx * chunk_t
|
||||
chunk_end = min(chunk_start + chunk_t, t)
|
||||
latent_chunk = latent_samples[:, :, chunk_start:chunk_end, :, :]
|
||||
|
||||
# Prepare initial latent for this chunk
|
||||
if add_noise:
|
||||
# Diffusers parity: each chunk denoises a fixed latent window size.
|
||||
# Keep chunk temporal length constant and crop only after all chunks.
|
||||
noise_shape = (
|
||||
latent_samples.shape[0],
|
||||
latent_samples.shape[1],
|
||||
chunk_t,
|
||||
latent_samples.shape[3],
|
||||
latent_samples.shape[4],
|
||||
)
|
||||
stage_latent = torch.randn(noise_shape, device=target_device, dtype=latent_samples.dtype, generator=noise_gen)
|
||||
else:
|
||||
# Use actual input latents; pad final short chunk to fixed size like Diffusers windowing.
|
||||
stage_latent = latent_chunk.clone()
|
||||
if stage_latent.shape[2] < chunk_t:
|
||||
if stage_latent.shape[2] == 0:
|
||||
stage_latent = torch.zeros(
|
||||
(
|
||||
latent_samples.shape[0],
|
||||
latent_samples.shape[1],
|
||||
chunk_t,
|
||||
latent_samples.shape[3],
|
||||
latent_samples.shape[4],
|
||||
),
|
||||
device=latent_samples.device,
|
||||
dtype=latent_samples.dtype,
|
||||
)
|
||||
else:
|
||||
pad = stage_latent[:, :, -1:].repeat(1, 1, chunk_t - stage_latent.shape[2], 1, 1)
|
||||
stage_latent = torch.cat([stage_latent, pad], dim=2)
|
||||
|
||||
# Downsample to stage 0 resolution
|
||||
for _ in range(max(0, int(stage_count) - 1)):
|
||||
stage_latent = _downsample_latent_5d_bilinear_x2(stage_latent)
|
||||
|
||||
# Keep stage latents on model device for parity with Diffusers scheduler/noise path.
|
||||
stage_latent = stage_latent.to(target_device)
|
||||
|
||||
# Diffusers parity:
|
||||
# keep_first_frame=True and no image_latent_prefix on the first chunk
|
||||
# should use an all-zero prefix frame, not history[:, :, :1].
|
||||
chunk_prefix = image_latent_prefix
|
||||
if keep_first_frame and image_latent_prefix is None and chunk_idx == 0:
|
||||
chunk_prefix = torch.zeros(
|
||||
(
|
||||
rolling_history.shape[0],
|
||||
rolling_history.shape[1],
|
||||
1,
|
||||
rolling_history.shape[3],
|
||||
rolling_history.shape[4],
|
||||
),
|
||||
device=rolling_history.device,
|
||||
dtype=rolling_history.dtype,
|
||||
)
|
||||
|
||||
positive_chunk, negative_chunk = _set_helios_history_values(
|
||||
positive,
|
||||
@ -759,37 +1060,28 @@ class HeliosPyramidSampler(io.ComfyNode):
|
||||
rolling_history,
|
||||
history_sizes_list,
|
||||
keep_first_frame,
|
||||
prefix_latent=image_latent_prefix,
|
||||
prefix_latent=chunk_prefix,
|
||||
)
|
||||
latents_history_short = _extract_condition_value(positive_chunk, "latents_history_short")
|
||||
latents_history_mid = _extract_condition_value(positive_chunk, "latents_history_mid")
|
||||
latents_history_long = _extract_condition_value(positive_chunk, "latents_history_long")
|
||||
|
||||
for stage_idx in range(stage_count):
|
||||
if stage_idx > 0:
|
||||
stage_latent = _upsample_latent_5d(stage_latent, scale=2)
|
||||
|
||||
ori_sigma = 1.0 - float(stage_tables["ori_start_sigmas"][stage_idx])
|
||||
alpha = 1.0 / (math.sqrt(1.0 + (1.0 / gamma)) * (1.0 - ori_sigma) + ori_sigma)
|
||||
beta = alpha * (1.0 - ori_sigma) / math.sqrt(gamma)
|
||||
|
||||
noise = _sample_block_noise_like(stage_latent, gamma, patch_size=(1, 2, 2)).to(stage_latent)
|
||||
stage_latent = alpha * stage_latent + beta * noise
|
||||
|
||||
stage_latent = stage_latent.to(comfy.model_management.get_torch_device())
|
||||
sigmas = _helios_stage_sigmas(
|
||||
stage_idx=stage_idx,
|
||||
stage_steps=stage_steps[stage_idx],
|
||||
stage_tables=stage_tables,
|
||||
is_distilled=is_distilled,
|
||||
is_amplify_first_stage=is_amplify_first_stage and chunk_idx == 0,
|
||||
).to(stage_latent.dtype)
|
||||
is_distilled=distilled,
|
||||
is_amplify_first_stage=amplify_first_stage and chunk_idx == 0,
|
||||
).to(device=stage_latent.device, dtype=torch.float32)
|
||||
timesteps = _helios_stage_timesteps(
|
||||
stage_idx=stage_idx,
|
||||
stage_steps=stage_steps[stage_idx],
|
||||
stage_tables=stage_tables,
|
||||
is_distilled=is_distilled,
|
||||
is_amplify_first_stage=is_amplify_first_stage and chunk_idx == 0,
|
||||
).to(stage_latent.dtype)
|
||||
is_distilled=distilled,
|
||||
is_amplify_first_stage=amplify_first_stage and chunk_idx == 0,
|
||||
).to(device=stage_latent.device, dtype=torch.float32)
|
||||
if use_dynamic_shifting:
|
||||
patch_size = (1, 2, 2)
|
||||
image_seq_len = (stage_latent.shape[-1] * stage_latent.shape[-2] * stage_latent.shape[-3]) // (patch_size[0] * patch_size[1] * patch_size[2])
|
||||
@ -800,10 +1092,24 @@ class HeliosPyramidSampler(io.ComfyNode):
|
||||
base_shift=base_shift,
|
||||
max_shift=max_shift,
|
||||
)
|
||||
sigmas = _time_shift(sigmas, mu=mu, sigma=1.0, mode=time_shift_type).to(stage_latent.dtype)
|
||||
sigmas = _time_shift(sigmas, mu=mu, sigma=1.0, mode=time_shift_type).to(torch.float32)
|
||||
tmin = torch.min(timesteps)
|
||||
tmax = torch.max(timesteps)
|
||||
timesteps = tmin + sigmas[:-1] * (tmax - tmin)
|
||||
else:
|
||||
pass
|
||||
|
||||
# Keep parity with Diffusers pipeline order:
|
||||
# stage timesteps are computed before upsampling/renoise for stage > 0.
|
||||
if stage_idx > 0:
|
||||
stage_latent = _upsample_latent_5d(stage_latent, scale=2)
|
||||
|
||||
ori_sigma = 1.0 - float(stage_tables["ori_start_sigmas"][stage_idx])
|
||||
alpha = 1.0 / (math.sqrt(1.0 + (1.0 / gamma)) * (1.0 - ori_sigma) + ori_sigma)
|
||||
beta = alpha * (1.0 - ori_sigma) / math.sqrt(gamma)
|
||||
|
||||
noise = _sample_block_noise_like(stage_latent, gamma, patch_size=(1, 2, 2), generator=noise_gen).to(stage_latent)
|
||||
stage_latent = alpha * stage_latent + beta * noise
|
||||
|
||||
indices_hidden_states, idx_short, idx_mid, idx_long = _build_helios_indices(
|
||||
batch=stage_latent.shape[0],
|
||||
@ -811,7 +1117,6 @@ class HeliosPyramidSampler(io.ComfyNode):
|
||||
keep_first_frame=keep_first_frame,
|
||||
hidden_frames=stage_latent.shape[2],
|
||||
device=stage_latent.device,
|
||||
dtype=stage_latent.dtype,
|
||||
)
|
||||
positive_stage = node_helpers.conditioning_set_values(positive_chunk, {"indices_hidden_states": indices_hidden_states})
|
||||
negative_stage = node_helpers.conditioning_set_values(negative_chunk, {"indices_hidden_states": indices_hidden_states})
|
||||
@ -831,19 +1136,22 @@ class HeliosPyramidSampler(io.ComfyNode):
|
||||
positive_stage = node_helpers.conditioning_set_values(positive_stage, values)
|
||||
negative_stage = node_helpers.conditioning_set_values(negative_stage, values)
|
||||
|
||||
cfg_use = 1.0 if is_distilled else cfg
|
||||
stage_time_values = {
|
||||
"helios_stage_sigmas": sigmas,
|
||||
"helios_stage_timesteps": timesteps,
|
||||
}
|
||||
positive_stage = node_helpers.conditioning_set_values(positive_stage, stage_time_values)
|
||||
negative_stage = node_helpers.conditioning_set_values(negative_stage, stage_time_values)
|
||||
|
||||
if stage_idx == 0 and add_noise:
|
||||
noise = comfy.sample.prepare_noise(stage_latent, noise_seed + chunk_idx * 100 + stage_idx)
|
||||
latent_start = torch.zeros_like(stage_latent)
|
||||
else:
|
||||
sigma0 = max(float(sigmas[0].item()), 1e-6)
|
||||
noise = (stage_latent / sigma0).to("cpu")
|
||||
latent_start = torch.zeros_like(stage_latent)
|
||||
cfg_use = 1.0 if distilled else cfg
|
||||
|
||||
sigma0 = max(float(sigmas[0].item()), 1e-6)
|
||||
noise = stage_latent / sigma0
|
||||
latent_start = torch.zeros_like(stage_latent)
|
||||
|
||||
stage_start_for_dmd = stage_latent.clone()
|
||||
|
||||
if is_distilled:
|
||||
if distilled:
|
||||
sampler = comfy.samplers.KSAMPLER(
|
||||
_helios_dmd_sample,
|
||||
extra_options={
|
||||
@ -854,14 +1162,11 @@ class HeliosPyramidSampler(io.ComfyNode):
|
||||
},
|
||||
)
|
||||
else:
|
||||
if scheduler_mode == "unipc_bh2":
|
||||
sampler = comfy.samplers.ksampler("uni_pc_bh2")
|
||||
else:
|
||||
sampler = euler_sampler
|
||||
sampler = euler_sampler
|
||||
|
||||
callback = latent_preview.prepare_callback(model, sigmas.shape[-1] - 1, x0_output)
|
||||
stage_model = model
|
||||
if is_cfg_zero_star and not is_distilled:
|
||||
if cfg_zero_star and not distilled:
|
||||
stage_model = model.clone()
|
||||
stage_model.model_options = comfy.model_patcher.set_model_options_pre_cfg_function(
|
||||
stage_model.model_options,
|
||||
@ -882,6 +1187,10 @@ class HeliosPyramidSampler(io.ComfyNode):
|
||||
disable_pbar=not comfy.utils.PROGRESS_BAR_ENABLED,
|
||||
seed=noise_seed + chunk_idx * 100 + stage_idx,
|
||||
)
|
||||
# sample_custom returns latent_format.process_out(samples); convert back to model-space
|
||||
# so subsequent pyramid stages and history conditioning stay in the same latent space
|
||||
# as Diffusers' internal denoising latents.
|
||||
stage_latent = model.model.process_latent_in(stage_latent)
|
||||
|
||||
if stage_latent.shape[-2] != h or stage_latent.shape[-1] != w:
|
||||
b2, c2, t2, h2, w2 = stage_latent.shape
|
||||
@ -891,7 +1200,7 @@ class HeliosPyramidSampler(io.ComfyNode):
|
||||
stage_latent = stage_latent[:, :, :, :h, :w]
|
||||
|
||||
generated_chunks.append(stage_latent)
|
||||
if keep_first_frame and ((chunk_idx == 0 and image_latent_prefix is None) or (is_skip_first_chunk and chunk_idx == 1)):
|
||||
if keep_first_frame and ((chunk_idx == 0 and image_latent_prefix is None) or (skip_first_chunk and chunk_idx == 1)):
|
||||
image_latent_prefix = stage_latent[:, :, :1]
|
||||
rolling_history = torch.cat([rolling_history, stage_latent.to(rolling_history.device, rolling_history.dtype)], dim=2)
|
||||
keep_hist = max(1, sum(history_sizes_list))
|
||||
@ -901,7 +1210,7 @@ class HeliosPyramidSampler(io.ComfyNode):
|
||||
|
||||
out = latent.copy()
|
||||
out.pop("downscale_ratio_spacial", None)
|
||||
out["samples"] = stage_latent
|
||||
out["samples"] = model.model.process_latent_out(stage_latent)
|
||||
|
||||
if "x0" in x0_output:
|
||||
x0_out = model.model.process_latent_out(x0_output["x0"].cpu())
|
||||
@ -917,6 +1226,7 @@ class HeliosExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||
return [
|
||||
HeliosTextToVideo,
|
||||
HeliosImageToVideo,
|
||||
HeliosVideoToVideo,
|
||||
HeliosHistoryConditioning,
|
||||
|
||||
Loading…
Reference in New Issue
Block a user