From ae36a9d4fdd7289e82fcc0ee83e7e9b5ddd27a0a Mon Sep 17 00:00:00 2001 From: qqingzheng <2533221180@qq.com> Date: Thu, 5 Mar 2026 03:50:36 +0800 Subject: [PATCH 01/10] Basic support helios --- comfy/ldm/helios/model.py | 744 +++++++++++++++++++++++++++ comfy/model_base.py | 30 ++ comfy/model_detection.py | 42 ++ comfy/sd.py | 6 + comfy/supported_models.py | 32 +- comfy/text_encoders/helios.py | 41 ++ comfy_extras/nodes_helios.py | 928 ++++++++++++++++++++++++++++++++++ nodes.py | 5 +- 8 files changed, 1825 insertions(+), 3 deletions(-) create mode 100644 comfy/ldm/helios/model.py create mode 100644 comfy/text_encoders/helios.py create mode 100644 comfy_extras/nodes_helios.py diff --git a/comfy/ldm/helios/model.py b/comfy/ldm/helios/model.py new file mode 100644 index 000000000..5ffc91129 --- /dev/null +++ b/comfy/ldm/helios/model.py @@ -0,0 +1,744 @@ +import math + +import torch +import torch.nn as nn + +from comfy.ldm.modules.attention import optimized_attention +from comfy.ldm.flux.layers import EmbedND +from comfy.ldm.flux.math import apply_rope1 +from comfy.ldm.wan.model import sinusoidal_embedding_1d, repeat_e +import comfy.ldm.common_dit +import comfy.patcher_extension + + +def pad_for_3d_conv(x, kernel_size): + b, c, t, h, w = x.shape + pt, ph, pw = kernel_size + pad_t = (pt - (t % pt)) % pt + pad_h = (ph - (h % ph)) % ph + pad_w = (pw - (w % pw)) % pw + return torch.nn.functional.pad(x, (0, pad_w, 0, pad_h, 0, pad_t), mode="replicate") + + +class OutputNorm(nn.Module): + + def __init__(self, dim, eps=1e-6, operation_settings={}): + super().__init__() + self.scale_shift_table = nn.Parameter(torch.randn( + 1, + 2, + dim, + device=operation_settings.get("device"), + dtype=operation_settings.get("dtype"), + ) / dim**0.5) + self.norm = operation_settings.get("operations").LayerNorm( + dim, + eps, + elementwise_affine=False, + device=operation_settings.get("device"), + dtype=operation_settings.get("dtype"), + ) + + def forward( + self, + hidden_states: torch.Tensor, + temb: torch.Tensor, + original_context_length: int, + ): + temb = temb[:, -original_context_length:, :] + shift, scale = (self.scale_shift_table.unsqueeze(0).to(temb.device) + temb.unsqueeze(2)).chunk(2, dim=2) + shift = shift.squeeze(2).to(hidden_states.device) + scale = scale.squeeze(2).to(hidden_states.device) + hidden_states = hidden_states[:, -original_context_length:, :] + hidden_states = self.norm(hidden_states) * (1 + scale) + shift + return hidden_states + + +class HeliosSelfAttention(nn.Module): + + def __init__( + self, + dim, + num_heads, + qk_norm=True, + eps=1e-6, + is_cross_attention=False, + is_amplify_history=False, + history_scale_mode="per_head", + operation_settings={}, + ): + super().__init__() + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.is_cross_attention = is_cross_attention + self.is_amplify_history = is_amplify_history + + self.to_q = operation_settings.get("operations").Linear( + dim, + dim, + bias=True, + device=operation_settings.get("device"), + dtype=operation_settings.get("dtype"), + ) + self.to_k = operation_settings.get("operations").Linear( + dim, + dim, + bias=True, + device=operation_settings.get("device"), + dtype=operation_settings.get("dtype"), + ) + self.to_v = operation_settings.get("operations").Linear( + dim, + dim, + bias=True, + device=operation_settings.get("device"), + dtype=operation_settings.get("dtype"), + ) + self.to_out = nn.ModuleList([ + operation_settings.get("operations").Linear( + dim, + dim, + bias=True, + device=operation_settings.get("device"), + dtype=operation_settings.get("dtype"), + ), + nn.Dropout(0.0), + ]) + + if qk_norm: + self.norm_q = operation_settings.get("operations").RMSNorm( + dim, + eps=eps, + elementwise_affine=True, + device=operation_settings.get("device"), + dtype=operation_settings.get("dtype"), + ) + self.norm_k = operation_settings.get("operations").RMSNorm( + dim, + eps=eps, + elementwise_affine=True, + device=operation_settings.get("device"), + dtype=operation_settings.get("dtype"), + ) + else: + self.norm_q = nn.Identity() + self.norm_k = nn.Identity() + + if is_amplify_history: + if history_scale_mode == "scalar": + self.history_key_scale = nn.Parameter(torch.ones( + 1, + device=operation_settings.get("device"), + dtype=operation_settings.get("dtype"), + )) + else: + self.history_key_scale = nn.Parameter(torch.ones( + num_heads, + device=operation_settings.get("device"), + dtype=operation_settings.get("dtype"), + )) + self.history_scale_mode = history_scale_mode + self.max_scale = 10.0 + + def forward( + self, + x, + context=None, + freqs=None, + original_context_length=None, + transformer_options={}, + ): + if context is None: + context = x + + b, sq, _ = x.shape + sk = context.shape[1] + + q = self.norm_q(self.to_q(x)).view(b, sq, self.num_heads, self.head_dim) + k = self.norm_k(self.to_k(context)).view(b, sk, self.num_heads, self.head_dim) + v = self.to_v(context).view(b, sk, self.num_heads, self.head_dim) + + if freqs is not None: + q = apply_rope1(q, freqs) + k = apply_rope1(k, freqs) + + if q.dtype != v.dtype: + q = q.to(v.dtype) + if k.dtype != v.dtype: + k = k.to(v.dtype) + + if (not self.is_cross_attention and self.is_amplify_history and original_context_length is not None): + history_seq_len = sq - original_context_length + if history_seq_len > 0: + scale_key = 1.0 + torch.sigmoid(self.history_key_scale) * (self.max_scale - 1.0) + if self.history_scale_mode == "per_head": + scale_key = scale_key.view(1, 1, -1, 1) + k = torch.cat([k[:, :history_seq_len] * scale_key, k[:, history_seq_len:]], dim=1) + + y = optimized_attention( + q.view(b, sq, -1), + k.view(b, sk, -1), + v.view(b, sk, -1), + heads=self.num_heads, + transformer_options=transformer_options, + ) + y = self.to_out[0](y) + y = self.to_out[1](y) + return y + + +class HeliosAttentionBlock(nn.Module): + + def __init__( + self, + dim, + ffn_dim, + num_heads, + qk_norm=True, + cross_attn_norm=True, + eps=1e-6, + guidance_cross_attn=False, + is_amplify_history=False, + history_scale_mode="per_head", + operation_settings={}, + ): + super().__init__() + + self.norm1 = operation_settings.get("operations").LayerNorm( + dim, + eps, + elementwise_affine=False, + device=operation_settings.get("device"), + dtype=operation_settings.get("dtype"), + ) + self.attn1 = HeliosSelfAttention( + dim, + num_heads, + qk_norm=qk_norm, + eps=eps, + is_cross_attention=False, + is_amplify_history=is_amplify_history, + history_scale_mode=history_scale_mode, + operation_settings=operation_settings, + ) + + self.norm2 = (operation_settings.get("operations").LayerNorm( + dim, + eps, + elementwise_affine=True, + device=operation_settings.get("device"), + dtype=operation_settings.get("dtype"), + ) if cross_attn_norm else nn.Identity()) + self.attn2 = HeliosSelfAttention( + dim, + num_heads, + qk_norm=qk_norm, + eps=eps, + is_cross_attention=True, + operation_settings=operation_settings, + ) + + self.norm3 = operation_settings.get("operations").LayerNorm( + dim, + eps, + elementwise_affine=False, + device=operation_settings.get("device"), + dtype=operation_settings.get("dtype"), + ) + self.ffn = nn.Sequential( + operation_settings.get("operations").Linear( + dim, + ffn_dim, + device=operation_settings.get("device"), + dtype=operation_settings.get("dtype"), + ), + nn.GELU(approximate="tanh"), + operation_settings.get("operations").Linear( + ffn_dim, + dim, + device=operation_settings.get("device"), + dtype=operation_settings.get("dtype"), + ), + ) + + self.scale_shift_table = nn.Parameter(torch.randn( + 1, + 6, + dim, + device=operation_settings.get("device"), + dtype=operation_settings.get("dtype"), + ) / dim**0.5) + self.guidance_cross_attn = guidance_cross_attn + + def forward(self, x, context, e, freqs, original_context_length=None, transformer_options={}): + if e.ndim == 4: + e = (self.scale_shift_table.unsqueeze(0) + e.float()).chunk(6, dim=2) + e = [v.squeeze(2) for v in e] + else: + e = (self.scale_shift_table + e.float()).chunk(6, dim=1) + + # self-attn + y = self.attn1( + torch.addcmul(repeat_e(e[0], x), self.norm1(x), 1 + repeat_e(e[1], x)), + freqs=freqs, + original_context_length=original_context_length, + transformer_options=transformer_options, + ) + x = torch.addcmul(x, y, repeat_e(e[2], x)) + + # cross-attn + if self.guidance_cross_attn and original_context_length is not None: + history_seq_len = x.shape[1] - original_context_length + history_x, x_main = torch.split(x, [history_seq_len, original_context_length], dim=1) + x_main = x_main + self.attn2( + self.norm2(x_main), + context=context, + transformer_options=transformer_options, + ) + x = torch.cat([history_x, x_main], dim=1) + else: + x = x + self.attn2(self.norm2(x), context=context, transformer_options=transformer_options) + + # ffn + y = self.ffn(torch.addcmul(repeat_e(e[3], x), self.norm3(x), 1 + repeat_e(e[4], x))) + x = torch.addcmul(x, y, repeat_e(e[5], x)) + return x + + +class HeliosModel(torch.nn.Module): + + def __init__( + self, + model_type="t2v", + patch_size=(1, 2, 2), + num_attention_heads=40, + attention_head_dim=128, + in_channels=16, + out_channels=16, + text_dim=4096, + freq_dim=256, + ffn_dim=13824, + num_layers=40, + cross_attn_norm=True, + qk_norm=True, + eps=1e-6, + rope_dim=(44, 42, 42), + rope_theta=10000.0, + guidance_cross_attn=True, + zero_history_timestep=True, + has_multi_term_memory_patch=True, + is_amplify_history=False, + history_scale_mode="per_head", + image_model=None, + device=None, + dtype=None, + operations=None, + **kwargs, + ): + del model_type, image_model, kwargs + super().__init__() + self.dtype = dtype + operation_settings = { + "operations": operations, + "device": device, + "dtype": dtype, + } + + dim = num_attention_heads * attention_head_dim + self.patch_size = patch_size + self.out_dim = out_channels or in_channels + self.dim = dim + self.freq_dim = freq_dim + self.zero_history_timestep = zero_history_timestep + + # embeddings + self.patch_embedding = operations.Conv3d( + in_channels, + dim, + kernel_size=patch_size, + stride=patch_size, + device=operation_settings.get("device"), + dtype=torch.float32, + ) + self.text_embedding = nn.Sequential( + operations.Linear( + text_dim, + dim, + device=operation_settings.get("device"), + dtype=operation_settings.get("dtype"), + ), + nn.GELU(approximate="tanh"), + operations.Linear( + dim, + dim, + device=operation_settings.get("device"), + dtype=operation_settings.get("dtype"), + ), + ) + self.time_embedding = nn.Sequential( + operations.Linear( + freq_dim, + dim, + device=operation_settings.get("device"), + dtype=operation_settings.get("dtype"), + ), + nn.SiLU(), + operations.Linear( + dim, + dim, + device=operation_settings.get("device"), + dtype=operation_settings.get("dtype"), + ), + ) + self.time_projection = nn.Sequential( + nn.SiLU(), + operations.Linear( + dim, + dim * 6, + device=operation_settings.get("device"), + dtype=operation_settings.get("dtype"), + ), + ) + + d = dim // num_attention_heads + self.rope_embedder = EmbedND(dim=d, theta=rope_theta, axes_dim=list(rope_dim)) + + # pyramidal embedding + if has_multi_term_memory_patch: + self.patch_short = operations.Conv3d( + in_channels, + dim, + kernel_size=patch_size, + stride=patch_size, + device=operation_settings.get("device"), + dtype=torch.float32, + ) + self.patch_mid = operations.Conv3d( + in_channels, + dim, + kernel_size=tuple(2 * p for p in patch_size), + stride=tuple(2 * p for p in patch_size), + device=operation_settings.get("device"), + dtype=torch.float32, + ) + self.patch_long = operations.Conv3d( + in_channels, + dim, + kernel_size=tuple(4 * p for p in patch_size), + stride=tuple(4 * p for p in patch_size), + device=operation_settings.get("device"), + dtype=torch.float32, + ) + + # blocks + self.blocks = nn.ModuleList([HeliosAttentionBlock( + dim, + ffn_dim, + num_attention_heads, + qk_norm=qk_norm, + cross_attn_norm=cross_attn_norm, + eps=eps, + guidance_cross_attn=guidance_cross_attn, + is_amplify_history=is_amplify_history, + history_scale_mode=history_scale_mode, + operation_settings=operation_settings, + ) for _ in range(num_layers)]) + + # head + self.norm_out = OutputNorm(dim, eps=eps, operation_settings=operation_settings) + self.proj_out = operations.Linear( + dim, + self.out_dim * math.prod(patch_size), + device=operation_settings.get("device"), + dtype=operation_settings.get("dtype"), + ) + + def rope_encode( + self, + t, + h, + w, + t_start=0, + steps_t=None, + steps_h=None, + steps_w=None, + device=None, + dtype=None, + transformer_options={}, + frame_indices=None, + ): + patch_size = self.patch_size + t_len = (t + (patch_size[0] // 2)) // patch_size[0] + h_len = (h + (patch_size[1] // 2)) // patch_size[1] + w_len = (w + (patch_size[2] // 2)) // patch_size[2] + + if steps_t is None: + steps_t = t_len + if steps_h is None: + steps_h = h_len + if steps_w is None: + steps_w = w_len + + h_start = 0 + w_start = 0 + rope_options = transformer_options.get("rope_options", None) + if rope_options is not None: + t_len = (t_len - 1.0) * rope_options.get("scale_t", 1.0) + 1.0 + h_len = (h_len - 1.0) * rope_options.get("scale_y", 1.0) + 1.0 + w_len = (w_len - 1.0) * rope_options.get("scale_x", 1.0) + 1.0 + + t_start += rope_options.get("shift_t", 0.0) + h_start += rope_options.get("shift_y", 0.0) + w_start += rope_options.get("shift_x", 0.0) + + if frame_indices is None: + t_coords = torch.linspace( + t_start, + t_start + (t_len - 1), + steps=steps_t, + device=device, + dtype=dtype, + ).reshape(1, -1, 1, 1) + batch_size = 1 + else: + batch_size = frame_indices.shape[0] + t_coords = frame_indices.to(device=device, dtype=dtype) + if t_coords.shape[1] != steps_t: + t_coords = torch.nn.functional.interpolate( + t_coords.unsqueeze(1), + size=steps_t, + mode="linear", + align_corners=False, + ).squeeze(1) + t_coords = (t_coords + t_start)[:, :, None, None] + + img_ids = torch.zeros((batch_size, steps_t, steps_h, steps_w, 3), device=device, dtype=dtype) + img_ids[:, :, :, :, 0] = img_ids[:, :, :, :, 0] + t_coords.expand(batch_size, steps_t, steps_h, steps_w) + img_ids[:, :, :, :, 1] = img_ids[:, :, :, :, 1] + torch.linspace(h_start, h_start + (h_len - 1), steps=steps_h, device=device, dtype=dtype).reshape(1, 1, -1, 1) + img_ids[:, :, :, :, 2] = img_ids[:, :, :, :, 2] + torch.linspace(w_start, w_start + (w_len - 1), steps=steps_w, device=device, dtype=dtype).reshape(1, 1, 1, -1) + img_ids = img_ids.reshape(batch_size, -1, img_ids.shape[-1]) + return self.rope_embedder(img_ids).movedim(1, 2) + + def forward( + self, + x, + timestep, + context, + clip_fea=None, + time_dim_concat=None, + transformer_options={}, + **kwargs, + ): + return comfy.patcher_extension.WrapperExecutor.new_class_executor( + self._forward, + self, + comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options), + ).execute( + x, + timestep, + context, + clip_fea, + time_dim_concat, + transformer_options, + **kwargs, + ) + + def _forward( + self, + x, + timestep, + context, + clip_fea=None, + time_dim_concat=None, + transformer_options={}, + **kwargs, + ): + del clip_fea, time_dim_concat + + _, _, t, h, w = x.shape + x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size) + + out = self.forward_orig( + hidden_states=x, + timestep=timestep, + context=context, + indices_hidden_states=kwargs.get("indices_hidden_states", None), + indices_latents_history_short=kwargs.get("indices_latents_history_short", None), + indices_latents_history_mid=kwargs.get("indices_latents_history_mid", None), + indices_latents_history_long=kwargs.get("indices_latents_history_long", None), + latents_history_short=kwargs.get("latents_history_short", None), + latents_history_mid=kwargs.get("latents_history_mid", None), + latents_history_long=kwargs.get("latents_history_long", None), + transformer_options=transformer_options, + ) + + return out[:, :, :t, :h, :w] + + def forward_orig( + self, + hidden_states, + timestep, + context, + indices_hidden_states=None, + indices_latents_history_short=None, + indices_latents_history_mid=None, + indices_latents_history_long=None, + latents_history_short=None, + latents_history_mid=None, + latents_history_long=None, + transformer_options={}, + ): + batch_size = hidden_states.shape[0] + p_t, p_h, p_w = self.patch_size + + # embeddings + hidden_states = self.patch_embedding(hidden_states.float()).to(hidden_states.dtype) + _, _, post_t, post_h, post_w = hidden_states.shape + hidden_states = hidden_states.flatten(2).transpose(1, 2) + + if indices_hidden_states is None: + indices_hidden_states = (torch.arange(0, post_t, device=hidden_states.device).unsqueeze(0).expand(batch_size, -1)) + + freqs = self.rope_encode( + t=post_t * self.patch_size[0], + h=post_h * self.patch_size[1], + w=post_w * self.patch_size[2], + steps_t=post_t, + steps_h=post_h, + steps_w=post_w, + device=hidden_states.device, + dtype=hidden_states.dtype, + transformer_options=transformer_options, + frame_indices=indices_hidden_states, + ) + original_context_length = hidden_states.shape[1] + + if (latents_history_short is not None and indices_latents_history_short is not None and hasattr(self, "patch_short")): + x_short = self.patch_short(latents_history_short.float()).to(hidden_states.dtype) + _, _, ts, hs, ws = x_short.shape + x_short = x_short.flatten(2).transpose(1, 2) + f_short = self.rope_encode( + t=ts * self.patch_size[0], + h=hs * self.patch_size[1], + w=ws * self.patch_size[2], + steps_t=ts, + steps_h=hs, + steps_w=ws, + device=x_short.device, + dtype=x_short.dtype, + transformer_options=transformer_options, + frame_indices=indices_latents_history_short, + ) + hidden_states = torch.cat([x_short, hidden_states], dim=1) + freqs = torch.cat([f_short, freqs], dim=1) + + if (latents_history_mid is not None and indices_latents_history_mid is not None and hasattr(self, "patch_mid")): + x_mid = self.patch_mid(pad_for_3d_conv(latents_history_mid, (2, 4, 4)).float()).to(hidden_states.dtype) + _, _, tm, hm, wm = x_mid.shape + x_mid = x_mid.flatten(2).transpose(1, 2) + f_mid = self.rope_encode( + t=tm * self.patch_size[0], + h=hm * self.patch_size[1], + w=wm * self.patch_size[2], + steps_t=tm, + steps_h=hm, + steps_w=wm, + device=x_mid.device, + dtype=x_mid.dtype, + transformer_options=transformer_options, + frame_indices=indices_latents_history_mid, + ) + hidden_states = torch.cat([x_mid, hidden_states], dim=1) + freqs = torch.cat([f_mid, freqs], dim=1) + + if (latents_history_long is not None and indices_latents_history_long is not None and hasattr(self, "patch_long")): + x_long = self.patch_long(pad_for_3d_conv(latents_history_long, (4, 8, 8)).float()).to(hidden_states.dtype) + _, _, tl, hl, wl = x_long.shape + x_long = x_long.flatten(2).transpose(1, 2) + f_long = self.rope_encode( + t=tl * self.patch_size[0], + h=hl * self.patch_size[1], + w=wl * self.patch_size[2], + steps_t=tl, + steps_h=hl, + steps_w=wl, + device=x_long.device, + dtype=x_long.dtype, + transformer_options=transformer_options, + frame_indices=indices_latents_history_long, + ) + hidden_states = torch.cat([x_long, hidden_states], dim=1) + freqs = torch.cat([f_long, freqs], dim=1) + + history_context_length = hidden_states.shape[1] - original_context_length + + if timestep.ndim == 0: + timestep = timestep.unsqueeze(0) + timestep = timestep.to(hidden_states.device) + if timestep.shape[0] != batch_size: + timestep = timestep[:1].expand(batch_size) + + # time embeddings + e = self.time_embedding(sinusoidal_embedding_1d(self.freq_dim, timestep.flatten()).to(dtype=hidden_states.dtype)) + e = e.reshape(batch_size, -1, e.shape[-1]) + e0 = self.time_projection(e).unflatten(2, (6, self.dim)) + context = self.text_embedding(context) + + if self.zero_history_timestep and history_context_length > 0: + timestep_t0 = torch.zeros((1, ), dtype=timestep.dtype, device=timestep.device) + e_t0 = self.time_embedding(sinusoidal_embedding_1d(self.freq_dim, timestep_t0.flatten()).to(dtype=hidden_states.dtype)) + e_t0 = e_t0.reshape(1, -1, e_t0.shape[-1]).expand(batch_size, history_context_length, -1) + e0_t0 = self.time_projection(e_t0[:, :1]).unflatten(2, (6, self.dim)) + e0_t0 = (e0_t0.view(batch_size, 1, 6, self.dim).permute(0, 2, 1, 3).expand(batch_size, 6, history_context_length, self.dim)) + + e = e.expand(batch_size, original_context_length, -1) + e0 = (e0.view(batch_size, 1, 6, self.dim).permute(0, 2, 1, 3).expand(batch_size, 6, original_context_length, self.dim)) + e = torch.cat([e_t0, e], dim=1) + e0 = torch.cat([e0_t0, e0], dim=2) + else: + e = e.expand(batch_size, hidden_states.shape[1], -1) + e0 = (e0.view(batch_size, 1, 6, self.dim).permute(0, 2, 1, 3).expand(batch_size, 6, hidden_states.shape[1], self.dim)) + + e0 = e0.permute(0, 2, 1, 3) + + for block in self.blocks: + hidden_states = block( + hidden_states, + context, + e0, + freqs, + original_context_length=original_context_length, + transformer_options=transformer_options, + ) + + hidden_states = self.norm_out(hidden_states, e, original_context_length) + hidden_states = self.proj_out(hidden_states) + + return self.unpatchify(hidden_states, (post_t, post_h, post_w)) + + def unpatchify(self, x, grid_sizes): + c = self.out_dim + b = x.shape[0] + u = x[:, :math.prod(grid_sizes)].view(b, *grid_sizes, *self.patch_size, c) + u = torch.einsum("bfhwpqrc->bcfphqwr", u) + u = u.reshape(b, c, *[i * j for i, j in zip(grid_sizes, self.patch_size)]) + return u + + def load_state_dict(self, state_dict, strict=True, assign=False): + # Keep compatibility with reference diffusers key names. + remapped = {} + for k, v in state_dict.items(): + nk = k + nk = nk.replace("condition_embedder.time_embedder.linear_1.", "time_embedding.0.") + nk = nk.replace("condition_embedder.time_embedder.linear_2.", "time_embedding.2.") + nk = nk.replace("condition_embedder.time_proj.", "time_projection.1.") + nk = nk.replace("condition_embedder.text_embedder.linear_1.", "text_embedding.0.") + nk = nk.replace("condition_embedder.text_embedder.linear_2.", "text_embedding.2.") + nk = nk.replace("blocks.", "blocks.") + remapped[nk] = v + + return super().load_state_dict(remapped, strict=strict, assign=assign) + + +# Backward-compatible alias for existing integration points. +HeliosTransformer3DModel = HeliosModel diff --git a/comfy/model_base.py b/comfy/model_base.py index 1e01e9edc..9bee3049a 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -41,6 +41,7 @@ import comfy.ldm.cosmos.predict2 import comfy.ldm.lumina.model import comfy.ldm.wan.model import comfy.ldm.wan.model_animate +import comfy.ldm.helios.model import comfy.ldm.hunyuan3d.model import comfy.ldm.hidream.model import comfy.ldm.chroma.model @@ -1268,6 +1269,35 @@ class ZImagePixelSpace(Lumina2): BaseModel.__init__(self, model_config, model_type, device=device, unet_model=comfy.ldm.lumina.model.NextDiTPixelSpace) self.memory_usage_factor_conds = ("ref_latents",) +class Helios(BaseModel): + def __init__(self, model_config, model_type=ModelType.FLOW, device=None): + super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.helios.model.HeliosTransformer3DModel) + + def extra_conds(self, **kwargs): + out = super().extra_conds(**kwargs) + cross_attn = kwargs.get("cross_attn", None) + if cross_attn is not None: + out["c_crossattn"] = comfy.conds.CONDRegular(cross_attn) + + cond_keys = ( + "indices_hidden_states", + "indices_latents_history_short", + "indices_latents_history_mid", + "indices_latents_history_long", + "latents_history_short", + "latents_history_mid", + "latents_history_long", + ) + + for key in cond_keys: + value = kwargs.get(key, None) + if value is None: + continue + if key.startswith("latents_"): + value = self.process_latent_in(value) + out[key] = comfy.conds.CONDRegular(value) + return out + class WAN21(BaseModel): def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None): super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel) diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 6eace4628..7a130c02d 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -489,6 +489,48 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): return dit_config + if '{}condition_embedder.time_proj.weight'.format(key_prefix) in state_dict_keys and '{}patch_embedding.weight'.format(key_prefix) in state_dict_keys: # Helios + dit_config = {} + dit_config["image_model"] = "helios" + + patch_weight = state_dict['{}patch_embedding.weight'.format(key_prefix)] + inner_dim = patch_weight.shape[0] + patch_size = tuple(patch_weight.shape[2:]) + out_proj = state_dict['{}proj_out.weight'.format(key_prefix)] + + dit_config["patch_size"] = patch_size + dit_config["in_channels"] = patch_weight.shape[1] + dit_config["out_channels"] = out_proj.shape[0] // math.prod(patch_size) + dit_config["text_dim"] = state_dict['{}condition_embedder.text_embedder.linear_1.weight'.format(key_prefix)].shape[1] + dit_config["freq_dim"] = state_dict['{}condition_embedder.time_embedder.linear_1.weight'.format(key_prefix)].shape[1] + dit_config["num_layers"] = count_blocks(state_dict_keys, '{}blocks.'.format(key_prefix) + '{}.') + dit_config["num_attention_heads"] = inner_dim // 128 + dit_config["attention_head_dim"] = 128 + + ffn_in = state_dict.get('{}blocks.0.ffn.net.0.proj.weight'.format(key_prefix), None) + if ffn_in is None: + ffn_in = state_dict.get('{}blocks.0.ffn.0.weight'.format(key_prefix), None) + if ffn_in is not None: + dit_config["ffn_dim"] = ffn_in.shape[0] + + if '{}blocks.0.attn2.add_k_proj.weight'.format(key_prefix) in state_dict_keys: + dit_config["added_kv_proj_dim"] = state_dict['{}blocks.0.attn2.add_k_proj.weight'.format(key_prefix)].shape[1] + + if '{}patch_short.weight'.format(key_prefix) in state_dict_keys: + dit_config["has_multi_term_memory_patch"] = True + else: + dit_config["has_multi_term_memory_patch"] = False + + if '{}blocks.0.attn1.history_key_scale'.format(key_prefix) in state_dict_keys: + dit_config["is_amplify_history"] = True + hk = state_dict['{}blocks.0.attn1.history_key_scale'.format(key_prefix)] + dit_config["history_scale_mode"] = "per_head" if len(hk.shape) > 0 and hk.numel() > 1 else "scalar" + + if metadata is not None and "config" in metadata: + dit_config.update(json.loads(metadata["config"]).get("transformer", {})) + + return dit_config + if '{}head.modulation'.format(key_prefix) in state_dict_keys: # Wan 2.1 dit_config = {} dit_config["image_model"] = "wan2.1" diff --git a/comfy/sd.py b/comfy/sd.py index a9ad7c2d2..b05d55474 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -48,6 +48,7 @@ import comfy.text_encoders.hunyuan_video import comfy.text_encoders.cosmos import comfy.text_encoders.lumina2 import comfy.text_encoders.wan +import comfy.text_encoders.helios import comfy.text_encoders.hidream import comfy.text_encoders.ace import comfy.text_encoders.omnigen2 @@ -1163,6 +1164,7 @@ class CLIPType(Enum): NEWBIE = 24 FLUX2 = 25 LONGCAT_IMAGE = 26 + HELIOS = 27 @@ -1334,6 +1336,10 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip clip_target.clip = comfy.text_encoders.wan.te(**t5xxl_detect(clip_data)) clip_target.tokenizer = comfy.text_encoders.wan.WanT5Tokenizer tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None) + elif clip_type == CLIPType.HELIOS: + clip_target.clip = comfy.text_encoders.helios.te(**t5xxl_detect(clip_data)) + clip_target.tokenizer = comfy.text_encoders.helios.HeliosT5Tokenizer + tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None) elif clip_type == CLIPType.HIDREAM: clip_target.clip = comfy.text_encoders.hidream.hidream_clip(**t5xxl_detect(clip_data), clip_l=False, clip_g=False, t5=True, llama=False, dtype_llama=None) diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 07feb31b3..2035f25b8 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -17,6 +17,7 @@ import comfy.text_encoders.hunyuan_video import comfy.text_encoders.cosmos import comfy.text_encoders.lumina2 import comfy.text_encoders.wan +import comfy.text_encoders.helios import comfy.text_encoders.ace import comfy.text_encoders.omnigen2 import comfy.text_encoders.qwen_image @@ -1132,6 +1133,35 @@ class ZImagePixelSpace(ZImage): def get_model(self, state_dict, prefix="", device=None): return model_base.ZImagePixelSpace(self, device=device) +class Helios(supported_models_base.BASE): + unet_config = { + "image_model": "helios", + } + + sampling_settings = { + "shift": 1.0, + } + + unet_extra_config = {} + latent_format = latent_formats.Wan21 + memory_usage_factor = 1.8 + supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32] + + vae_key_prefix = ["vae."] + text_encoder_key_prefix = ["text_encoders."] + + def __init__(self, unet_config): + super().__init__(unet_config) + self.memory_usage_factor = (self.unet_config.get("num_layers", 40) * self.unet_config.get("num_attention_heads", 40)) / (40 * 40) * 1.8 + + def get_model(self, state_dict, prefix="", device=None): + return model_base.Helios(self, device=device) + + def clip_target(self, state_dict={}): + pref = self.text_encoder_key_prefix[0] + t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}umt5xxl.transformer.".format(pref)) + return supported_models_base.ClipTarget(comfy.text_encoders.helios.HeliosT5Tokenizer, comfy.text_encoders.helios.te(**t5_detect)) + class WAN21_T2V(supported_models_base.BASE): unet_config = { "image_model": "wan2.1", @@ -1734,6 +1764,6 @@ class LongCatImage(supported_models_base.BASE): hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref)) return supported_models_base.ClipTarget(comfy.text_encoders.longcat_image.LongCatImageTokenizer, comfy.text_encoders.longcat_image.te(**hunyuan_detect)) -models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, LongCatImage, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImagePixelSpace, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, WAN21_FlowRVS, WAN21_SCAIL, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima] +models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, LongCatImage, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImagePixelSpace, ZImage, Lumina2, Helios, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, WAN21_FlowRVS, WAN21_SCAIL, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima] models += [SVD_img2vid] diff --git a/comfy/text_encoders/helios.py b/comfy/text_encoders/helios.py new file mode 100644 index 000000000..dc4b38b13 --- /dev/null +++ b/comfy/text_encoders/helios.py @@ -0,0 +1,41 @@ +from comfy import sd1_clip +from .spiece_tokenizer import SPieceTokenizer +import comfy.text_encoders.t5 +import os + + +class UMT5XXlModel(sd1_clip.SDClipModel): + def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, model_options={}): + textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "umt5_config_xxl.json") + super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=True, zero_out_masked=True, model_options=model_options) + + +class UMT5XXlTokenizer(sd1_clip.SDTokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}): + tokenizer = tokenizer_data.get("spiece_model", None) + super().__init__(tokenizer, pad_with_end=False, embedding_size=4096, embedding_key="umt5xxl", tokenizer_class=SPieceTokenizer, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=512, pad_token=0, tokenizer_data=tokenizer_data) + + def state_dict(self): + return {"spiece_model": self.tokenizer.serialize_model()} + + +class HeliosT5Tokenizer(sd1_clip.SD1Tokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}): + super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="umt5xxl", tokenizer=UMT5XXlTokenizer) + + +class HeliosT5Model(sd1_clip.SD1ClipModel): + def __init__(self, device="cpu", dtype=None, model_options={}, **kwargs): + super().__init__(device=device, dtype=dtype, model_options=model_options, name="umt5xxl", clip_model=UMT5XXlModel, **kwargs) + + +def te(dtype_t5=None, t5_quantization_metadata=None): + class HeliosTEModel(HeliosT5Model): + def __init__(self, device="cpu", dtype=None, model_options={}): + if t5_quantization_metadata is not None: + model_options = model_options.copy() + model_options["quantization_metadata"] = t5_quantization_metadata + if dtype_t5 is not None: + dtype = dtype_t5 + super().__init__(device=device, dtype=dtype, model_options=model_options) + return HeliosTEModel diff --git a/comfy_extras/nodes_helios.py b/comfy_extras/nodes_helios.py new file mode 100644 index 000000000..6c1fd7e20 --- /dev/null +++ b/comfy_extras/nodes_helios.py @@ -0,0 +1,928 @@ +import math +import torch + +import nodes +import comfy.model_management +import comfy.model_patcher +import comfy.sample +import comfy.samplers +import comfy.utils +import latent_preview +import node_helpers + +from typing_extensions import override +from comfy_api.latest import ComfyExtension, io + + +def _parse_int_list(values, default): + if values is None: + return default + if isinstance(values, (list, tuple)): + out = [] + for v in values: + try: + out.append(int(v)) + except Exception: + pass + return out if len(out) > 0 else default + + parts = [x.strip() for x in str(values).replace(";", ",").split(",")] + out = [] + for p in parts: + if len(p) == 0: + continue + try: + out.append(int(p)) + except Exception: + continue + return out if len(out) > 0 else default + + +def _parse_float_list(values, default): + if values is None: + return default + if isinstance(values, (list, tuple)): + out = [] + for v in values: + try: + out.append(float(v)) + except Exception: + pass + return out if len(out) > 0 else default + + parts = [x.strip() for x in str(values).replace(";", ",").split(",")] + out = [] + for p in parts: + if len(p) == 0: + continue + try: + out.append(float(p)) + except Exception: + continue + return out if len(out) > 0 else default + + +def _extract_condition_value(conditioning, key): + for c in conditioning: + if len(c) < 2: + continue + value = c[1].get(key, None) + if value is not None: + return value + return None + + +def _upsample_latent_5d(latent, scale=2): + b, c, t, h, w = latent.shape + x = latent.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w) + x = comfy.utils.common_upscale(x, w * scale, h * scale, "nearest-exact", "disabled") + x = x.reshape(b, t, c, h * scale, w * scale).permute(0, 2, 1, 3, 4) + return x + + +def _sample_block_noise_like(latent, gamma, patch_size=(1, 2, 2)): + b, c, t, h, w = latent.shape + _, ph, pw = patch_size + block_size = ph * pw + + cov = torch.eye(block_size, device=latent.device) * (1.0 + gamma) - torch.ones(block_size, block_size, device=latent.device) * gamma + cov += torch.eye(block_size, device=latent.device) * 1e-6 + + dist = torch.distributions.MultivariateNormal(torch.zeros(block_size, device=latent.device), covariance_matrix=cov) + block_number = b * c * t * max(1, h // ph) * max(1, w // pw) + + noise = dist.sample((block_number,)) + noise = noise.view(b, c, t, max(1, h // ph), max(1, w // pw), ph, pw) + noise = noise.permute(0, 1, 2, 3, 5, 4, 6).reshape(b, c, t, max(1, h // ph) * ph, max(1, w // pw) * pw) + noise = noise[:, :, :, :h, :w] + return noise + + +def _helios_global_sigmas(num_train_timesteps=1000, shift=1.0): + alphas = torch.linspace(1.0, 1.0 / float(num_train_timesteps), num_train_timesteps + 1) + sigmas = 1.0 - alphas + if abs(shift - 1.0) > 1e-8: + sigmas = shift * sigmas / (1.0 + (shift - 1.0) * sigmas) + return torch.flip(sigmas, dims=[0])[:-1] + + +def _helios_stage_tables(stage_count, stage_range, gamma, num_train_timesteps=1000, shift=1.0): + sigmas = _helios_global_sigmas(num_train_timesteps=num_train_timesteps, shift=shift) + + ori_start_sigmas = {} + start_sigmas = {} + end_sigmas = {} + timestep_ratios = {} + timesteps_per_stage = {} + sigmas_per_stage = {} + + stage_distance = [] + for i in range(stage_count): + start_indice = int(max(0.0, min(1.0, stage_range[i])) * num_train_timesteps) + end_indice = int(max(0.0, min(1.0, stage_range[i + 1])) * num_train_timesteps) + start_indice = max(0, min(num_train_timesteps - 1, start_indice)) + end_indice = max(0, min(num_train_timesteps, end_indice)) + + start_sigma = float(sigmas[start_indice].item()) + end_sigma = float(sigmas[end_indice].item()) if end_indice < num_train_timesteps else 0.0 + ori_start_sigmas[i] = start_sigma + + if i != 0: + ori_sigma = 1.0 - start_sigma + corrected_sigma = (1.0 / (math.sqrt(1.0 + (1.0 / gamma)) * (1.0 - ori_sigma) + ori_sigma)) * ori_sigma + start_sigma = 1.0 - corrected_sigma + + stage_distance.append(start_sigma - end_sigma) + start_sigmas[i] = start_sigma + end_sigmas[i] = end_sigma + + tot_distance = sum(stage_distance) if sum(stage_distance) > 1e-12 else 1.0 + for i in range(stage_count): + start_ratio = 0.0 if i == 0 else sum(stage_distance[:i]) / tot_distance + end_ratio = 0.9999999999999999 if i == stage_count - 1 else sum(stage_distance[: i + 1]) / tot_distance + timestep_ratios[i] = (start_ratio, end_ratio) + + tmax = min(float(sigmas[int(start_ratio * num_train_timesteps)].item() * num_train_timesteps), 999.0) + tmin = float(sigmas[min(int(end_ratio * num_train_timesteps), num_train_timesteps - 1)].item() * num_train_timesteps) + timesteps_per_stage[i] = torch.linspace(tmax, tmin, num_train_timesteps) + sigmas_per_stage[i] = torch.linspace(0.999, 0.0, num_train_timesteps) + + return { + "ori_start_sigmas": ori_start_sigmas, + "start_sigmas": start_sigmas, + "end_sigmas": end_sigmas, + "timestep_ratios": timestep_ratios, + "timesteps_per_stage": timesteps_per_stage, + "sigmas_per_stage": sigmas_per_stage, + } + + +def _helios_stage_sigmas(stage_idx, stage_steps, stage_tables, is_distilled=False, is_amplify_first_stage=False): + stage_steps = max(1, int(stage_steps)) + if is_distilled: + stage_steps = stage_steps * 2 if (is_amplify_first_stage and stage_idx == 0) else stage_steps + + stage_sigma_src = stage_tables["sigmas_per_stage"][stage_idx] + sigmas = torch.linspace(float(stage_sigma_src[0].item()), float(stage_sigma_src[-1].item()), stage_steps + 1) + return sigmas + + +def _helios_stage_timesteps(stage_idx, stage_steps, stage_tables, is_distilled=False, is_amplify_first_stage=False): + stage_steps = max(1, int(stage_steps)) + if is_distilled: + stage_steps = stage_steps * 2 if (is_amplify_first_stage and stage_idx == 0) else stage_steps + + stage_timestep_src = stage_tables["timesteps_per_stage"][stage_idx] + timesteps = torch.linspace(float(stage_timestep_src[0].item()), float(stage_timestep_src[-1].item()), stage_steps) + return timesteps + + +def _calculate_shift(image_seq_len, base_seq_len=256, max_seq_len=4096, base_shift=0.5, max_shift=1.15): + m = (max_shift - base_shift) / float(max_seq_len - base_seq_len) + b = base_shift - m * float(base_seq_len) + return float(image_seq_len) * m + b + + +def _time_shift_linear(mu, sigma, t): + return mu / (mu + (1.0 / t - 1.0) ** sigma) + + +def _time_shift_exponential(mu, sigma, t): + return math.exp(mu) / (math.exp(mu) + (1.0 / t - 1.0) ** sigma) + + +def _time_shift(t, mu, sigma=1.0, mode="exponential"): + t = torch.clamp(t, min=1e-6, max=0.999999) + if mode == "linear": + return _time_shift_linear(mu, sigma, t) + return _time_shift_exponential(mu, sigma, t) + + +def _optimized_scale(positive_flat, negative_flat): + dot_product = torch.sum(positive_flat * negative_flat, dim=1, keepdim=True) + squared_norm = torch.sum(negative_flat * negative_flat, dim=1, keepdim=True) + 1e-8 + return dot_product / squared_norm + + +def _build_cfg_zero_star_pre_cfg(stage_idx, zero_steps, use_zero_init): + state = {"i": 0} + + def pre_cfg_fn(args): + conds_out = args["conds_out"] + if len(conds_out) < 2 or conds_out[1] is None: + state["i"] += 1 + return conds_out + + noise_pred_text = conds_out[0] + noise_uncond = conds_out[1] + cfg = float(args.get("cond_scale", 1.0)) + + positive_flat = noise_pred_text.view(noise_pred_text.shape[0], -1) + negative_flat = noise_uncond.view(noise_uncond.shape[0], -1) + alpha = _optimized_scale(positive_flat, negative_flat) + alpha = alpha.view(noise_pred_text.shape[0], *([1] * (noise_pred_text.ndim - 1))).to(noise_pred_text.dtype) + + if stage_idx == 0 and state["i"] <= int(zero_steps) and bool(use_zero_init): + final = noise_pred_text * 0.0 + else: + final = noise_uncond * alpha + cfg * (noise_pred_text - noise_uncond * alpha) + + state["i"] += 1 + # Return identical cond/uncond so downstream cfg_function keeps `final` unchanged. + return [final, final] + + return pre_cfg_fn + + +def _helios_euler_sample(model, x, sigmas, extra_args=None, callback=None, disable=None): + extra_args = {} if extra_args is None else extra_args + s_in = x.new_ones([x.shape[0]]) + + for i in range(len(sigmas) - 1): + sigma = sigmas[i] + sigma_next = sigmas[i + 1] + denoised = model(x, sigma * s_in, **extra_args) + + sigma_safe = sigma if float(sigma) > 1e-8 else sigma.new_tensor(1e-8) + flow_pred = (x - denoised) / sigma_safe + + if callback is not None: + callback({"x": x, "i": i, "sigma": sigma, "sigma_hat": sigma, "denoised": denoised}) + + x = x + (sigma_next - sigma) * flow_pred + + return x + + +def _helios_dmd_sample( + model, + x, + sigmas, + extra_args=None, + callback=None, + disable=None, + dmd_noisy_tensor=None, + dmd_sigmas=None, + dmd_timesteps=None, + all_timesteps=None, +): + extra_args = {} if extra_args is None else extra_args + s_in = x.new_ones([x.shape[0]]) + if dmd_noisy_tensor is None: + dmd_noisy_tensor = x + dmd_noisy_tensor = dmd_noisy_tensor.to(device=x.device, dtype=x.dtype) + if dmd_sigmas is None: + dmd_sigmas = sigmas + if dmd_timesteps is None: + dmd_timesteps = torch.arange(len(sigmas) - 1, device=sigmas.device, dtype=sigmas.dtype) + if all_timesteps is None: + all_timesteps = dmd_timesteps + + def timestep_to_sigma(t): + dt = dmd_timesteps.to(device=x.device, dtype=x.dtype) + ds = dmd_sigmas.to(device=x.device, dtype=x.dtype) + tid = torch.argmin(torch.abs(dt - t)) + tid = torch.clamp(tid, min=0, max=ds.shape[0] - 1) + return ds[tid] + + for i in range(len(sigmas) - 1): + sigma = sigmas[i] + sigma_next = sigmas[i + 1] + timestep = all_timesteps[i] if i < len(all_timesteps) else i + denoised = model(x, sigma * s_in, **extra_args) + + if callback is not None: + callback({"x": x, "i": i, "sigma": sigma, "sigma_hat": sigma, "denoised": denoised}) + + if i < (len(sigmas) - 2): + timestep_next = all_timesteps[i + 1] if i + 1 < len(all_timesteps) else (i + 1) + sigma_t = timestep_to_sigma(torch.as_tensor(timestep, device=x.device, dtype=x.dtype)) + sigma_next_t = timestep_to_sigma(torch.as_tensor(timestep_next, device=x.device, dtype=x.dtype)) + x0_pred = x - sigma_t * ((x - denoised) / torch.clamp(sigma_t, min=1e-8)) + x = (1.0 - sigma_next_t) * x0_pred + sigma_next_t * dmd_noisy_tensor + else: + x = denoised + + return x + + +def _set_helios_history_values(positive, negative, history_latent, history_sizes, keep_first_frame, prefix_latent=None): + latent = history_latent + if latent is None or len(latent.shape) != 5: + return positive, negative + + sizes = list(history_sizes) + if len(sizes) != 3: + sizes = [16, 2, 1] + sizes = [max(0, int(v)) for v in sizes] + total = sum(sizes) + if total <= 0: + return positive, negative + + if latent.shape[2] < total: + pad = torch.zeros( + latent.shape[0], + latent.shape[1], + total - latent.shape[2], + latent.shape[3], + latent.shape[4], + device=latent.device, + dtype=latent.dtype, + ) + hist = torch.cat([pad, latent], dim=2) + else: + hist = latent[:, :, -total:] + + latents_history_long, latents_history_mid, latents_history_short_base = hist.split(sizes, dim=2) + + if keep_first_frame: + if prefix_latent is not None: + prefix = prefix_latent + elif latent.shape[2] > 0: + prefix = latent[:, :, :1] + else: + prefix = torch.zeros(latent.shape[0], latent.shape[1], 1, latent.shape[3], latent.shape[4], device=latent.device, dtype=latent.dtype) + latents_history_short = torch.cat([prefix, latents_history_short_base], dim=2) + else: + latents_history_short = latents_history_short_base + + idx_short = torch.arange(latents_history_short.shape[2], device=latent.device, dtype=latent.dtype).unsqueeze(0).expand(latent.shape[0], -1) + idx_mid = torch.arange(latents_history_mid.shape[2], device=latent.device, dtype=latent.dtype).unsqueeze(0).expand(latent.shape[0], -1) + idx_long = torch.arange(latents_history_long.shape[2], device=latent.device, dtype=latent.dtype).unsqueeze(0).expand(latent.shape[0], -1) + + values = { + "latents_history_short": latents_history_short, + "latents_history_mid": latents_history_mid, + "latents_history_long": latents_history_long, + "indices_latents_history_short": idx_short, + "indices_latents_history_mid": idx_mid, + "indices_latents_history_long": idx_long, + } + + positive = node_helpers.conditioning_set_values(positive, values) + negative = node_helpers.conditioning_set_values(negative, values) + return positive, negative + + +def _build_helios_indices(batch, history_sizes, keep_first_frame, hidden_frames, device, dtype): + sizes = list(history_sizes) + if len(sizes) != 3: + sizes = [16, 2, 1] + sizes = [max(0, int(v)) for v in sizes] + long_size, mid_size, short_base_size = sizes + + if keep_first_frame: + total = 1 + long_size + mid_size + short_base_size + hidden_frames + indices = torch.arange(total, device=device, dtype=dtype) + splits = [1, long_size, mid_size, short_base_size, hidden_frames] + indices_prefix, idx_long, idx_mid, idx_1x, idx_hidden = torch.split(indices, splits, dim=0) + idx_short = torch.cat([indices_prefix, idx_1x], dim=0) + else: + total = long_size + mid_size + short_base_size + hidden_frames + indices = torch.arange(total, device=device, dtype=dtype) + splits = [long_size, mid_size, short_base_size, hidden_frames] + idx_long, idx_mid, idx_short, idx_hidden = torch.split(indices, splits, dim=0) + + idx_hidden = idx_hidden.unsqueeze(0).expand(batch, -1) + idx_short = idx_short.unsqueeze(0).expand(batch, -1) + idx_mid = idx_mid.unsqueeze(0).expand(batch, -1) + idx_long = idx_long.unsqueeze(0).expand(batch, -1) + return idx_hidden, idx_short, idx_mid, idx_long + + +class HeliosImageToVideo(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="HeliosImageToVideo", + category="conditioning/video_models", + inputs=[ + io.Conditioning.Input("positive"), + io.Conditioning.Input("negative"), + io.Vae.Input("vae"), + io.Int.Input("width", default=640, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("height", default=384, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("length", default=132, min=1, max=nodes.MAX_RESOLUTION, step=4), + io.Int.Input("batch_size", default=1, min=1, max=4096), + io.Image.Input("start_image", optional=True), + io.String.Input("history_sizes", default="16,2,1", advanced=True), + io.Boolean.Input("keep_first_frame", default=True, advanced=True), + io.Int.Input("num_latent_frames_per_chunk", default=9, min=1, max=256, advanced=True), + io.Boolean.Input("add_noise_to_image_latents", default=True, advanced=True), + io.Float.Input("image_noise_sigma_min", default=0.111, min=0.0, max=1.0, step=0.0001, round=False, advanced=True), + io.Float.Input("image_noise_sigma_max", default=0.135, min=0.0, max=1.0, step=0.0001, round=False, advanced=True), + io.Int.Input("noise_seed", default=0, min=0, max=0xFFFFFFFFFFFFFFFF, advanced=True), + ], + outputs=[ + io.Conditioning.Output(display_name="positive"), + io.Conditioning.Output(display_name="negative"), + io.Latent.Output(display_name="latent"), + ], + ) + + @classmethod + def execute( + cls, + positive, + negative, + vae, + width, + height, + length, + batch_size, + start_image=None, + history_sizes="16,2,1", + keep_first_frame=True, + num_latent_frames_per_chunk=9, + add_noise_to_image_latents=True, + image_noise_sigma_min=0.111, + image_noise_sigma_max=0.135, + noise_seed=0, + ) -> io.NodeOutput: + spacial_scale = vae.spacial_compression_encode() + latent_channels = vae.latent_channels + latent_t = ((length - 1) // 4) + 1 + latent = torch.zeros([batch_size, latent_channels, latent_t, height // spacial_scale, width // spacial_scale], device=comfy.model_management.intermediate_device()) + + sizes = _parse_int_list(history_sizes, [16, 2, 1]) + if len(sizes) != 3: + sizes = [16, 2, 1] + sizes = sorted([max(0, int(v)) for v in sizes], reverse=True) + hist_len = max(1, sum(sizes)) + history_latent = torch.zeros([batch_size, latent_channels, hist_len, latent.shape[-2], latent.shape[-1]], device=latent.device, dtype=latent.dtype) + image_latent_prefix = None + + if start_image is not None: + image = comfy.utils.common_upscale(start_image[:1].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) + img_latent = vae.encode(image[:, :, :, :3]) + img_latent = comfy.utils.repeat_to_batch_size(img_latent, batch_size) + image_latent_prefix = img_latent[:, :, :1] + + if add_noise_to_image_latents: + g = torch.Generator(device=img_latent.device) + g.manual_seed(int(noise_seed)) + sigma = ( + torch.rand((img_latent.shape[0], 1, 1, 1, 1), device=img_latent.device, generator=g, dtype=img_latent.dtype) + * (float(image_noise_sigma_max) - float(image_noise_sigma_min)) + + float(image_noise_sigma_min) + ) + image_latent_prefix = sigma * torch.randn_like(image_latent_prefix, generator=g) + (1.0 - sigma) * image_latent_prefix + + min_frames = max(1, (int(num_latent_frames_per_chunk) - 1) * 4 + 1) + fake_video = image.repeat(min_frames, 1, 1, 1) + fake_latents_full = vae.encode(fake_video) + fake_latent = comfy.utils.repeat_to_batch_size(fake_latents_full[:, :, -1:], batch_size) + history_latent[:, :, -1:] = fake_latent + + positive, negative = _set_helios_history_values(positive, negative, history_latent, sizes, keep_first_frame, prefix_latent=image_latent_prefix) + return io.NodeOutput( + positive, + negative, + { + "samples": latent, + "helios_history_latent": history_latent, + "helios_image_latent_prefix": image_latent_prefix, + }, + ) + + +class HeliosVideoToVideo(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="HeliosVideoToVideo", + category="conditioning/video_models", + inputs=[ + io.Conditioning.Input("positive"), + io.Conditioning.Input("negative"), + io.Vae.Input("vae"), + io.Int.Input("width", default=640, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("height", default=384, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("length", default=132, min=1, max=nodes.MAX_RESOLUTION, step=4), + io.Int.Input("batch_size", default=1, min=1, max=4096), + io.Image.Input("video", optional=True), + io.String.Input("history_sizes", default="16,2,1", advanced=True), + io.Boolean.Input("keep_first_frame", default=True, advanced=True), + io.Boolean.Input("add_noise_to_video_latents", default=True, advanced=True), + io.Float.Input("video_noise_sigma_min", default=0.111, min=0.0, max=1.0, step=0.0001, round=False, advanced=True), + io.Float.Input("video_noise_sigma_max", default=0.135, min=0.0, max=1.0, step=0.0001, round=False, advanced=True), + io.Int.Input("noise_seed", default=0, min=0, max=0xFFFFFFFFFFFFFFFF, advanced=True), + ], + outputs=[ + io.Conditioning.Output(display_name="positive"), + io.Conditioning.Output(display_name="negative"), + io.Latent.Output(display_name="latent"), + ], + ) + + @classmethod + def execute( + cls, + positive, + negative, + vae, + width, + height, + length, + batch_size, + video=None, + history_sizes="16,2,1", + keep_first_frame=True, + add_noise_to_video_latents=True, + video_noise_sigma_min=0.111, + video_noise_sigma_max=0.135, + noise_seed=0, + ) -> io.NodeOutput: + spacial_scale = vae.spacial_compression_encode() + latent_channels = vae.latent_channels + latent_t = ((length - 1) // 4) + 1 + latent = torch.zeros([batch_size, latent_channels, latent_t, height // spacial_scale, width // spacial_scale], device=comfy.model_management.intermediate_device()) + + sizes = _parse_int_list(history_sizes, [16, 2, 1]) + if len(sizes) != 3: + sizes = [16, 2, 1] + sizes = sorted([max(0, int(v)) for v in sizes], reverse=True) + hist_len = max(1, sum(sizes)) + history_latent = torch.zeros([batch_size, latent_channels, hist_len, latent.shape[-2], latent.shape[-1]], device=latent.device, dtype=latent.dtype) + image_latent_prefix = None + + if video is not None: + video = comfy.utils.common_upscale(video[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) + vid_latent = vae.encode(video[:, :, :, :3]) + if add_noise_to_video_latents: + g = torch.Generator(device=vid_latent.device) + g.manual_seed(int(noise_seed)) + frame_sigmas = ( + torch.rand((1, 1, vid_latent.shape[2], 1, 1), device=vid_latent.device, generator=g, dtype=vid_latent.dtype) + * (float(video_noise_sigma_max) - float(video_noise_sigma_min)) + + float(video_noise_sigma_min) + ) + vid_latent = frame_sigmas * torch.randn_like(vid_latent, generator=g) + (1.0 - frame_sigmas) * vid_latent + vid_latent = vid_latent[:, :, :hist_len] + if vid_latent.shape[2] < hist_len: + pad = vid_latent[:, :, -1:].repeat(1, 1, hist_len - vid_latent.shape[2], 1, 1) + vid_latent = torch.cat([vid_latent, pad], dim=2) + vid_latent = comfy.utils.repeat_to_batch_size(vid_latent, batch_size) + history_latent = vid_latent + image_latent_prefix = history_latent[:, :, :1] + + positive, negative = _set_helios_history_values(positive, negative, history_latent, sizes, keep_first_frame, prefix_latent=image_latent_prefix) + return io.NodeOutput( + positive, + negative, + { + "samples": latent, + "helios_history_latent": history_latent, + "helios_image_latent_prefix": image_latent_prefix, + }, + ) + + +class HeliosHistoryConditioning(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="HeliosHistoryConditioning", + category="conditioning/video_models", + inputs=[ + io.Conditioning.Input("positive"), + io.Conditioning.Input("negative"), + io.Latent.Input("history_latent"), + io.String.Input("history_sizes", default="16,2,1"), + io.Boolean.Input("keep_first_frame", default=True), + ], + outputs=[ + io.Conditioning.Output(display_name="positive"), + io.Conditioning.Output(display_name="negative"), + ], + ) + + @classmethod + def execute(cls, positive, negative, history_latent, history_sizes, keep_first_frame) -> io.NodeOutput: + latent = history_latent["samples"] + if latent is None or len(latent.shape) != 5: + return io.NodeOutput(positive, negative) + sizes = _parse_int_list(history_sizes, [16, 2, 1]) + sizes = sorted([max(0, int(v)) for v in sizes], reverse=True) + prefix = history_latent.get("helios_image_latent_prefix", None) + positive, negative = _set_helios_history_values(positive, negative, latent, sizes, keep_first_frame, prefix_latent=prefix) + return io.NodeOutput(positive, negative) + + +class HeliosPyramidSampler(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="HeliosPyramidSampler", + category="sampling/video_models", + inputs=[ + io.Model.Input("model"), + io.Boolean.Input("add_noise", default=True, advanced=True), + io.Int.Input("noise_seed", default=0, min=0, max=0xFFFFFFFFFFFFFFFF, control_after_generate=True), + io.Float.Input("cfg", default=5.0, min=0.0, max=100.0, step=0.1, round=0.01), + io.Conditioning.Input("positive"), + io.Conditioning.Input("negative"), + io.Latent.Input("latent_image"), + io.String.Input("pyramid_steps", default="10,10,10"), + io.String.Input("stage_range", default="0,0.333333,0.666667,1"), + io.Boolean.Input("is_distilled", default=False), + io.Boolean.Input("is_amplify_first_stage", default=False), + io.Combo.Input("scheduler_mode", options=["euler", "unipc_bh2"]), + io.Float.Input("gamma", default=1.0 / 3.0, min=0.0001, max=10.0, step=0.0001, round=False), + io.Float.Input("shift", default=1.0, min=0.001, max=100.0, step=0.001, round=False, advanced=True), + io.Boolean.Input("use_dynamic_shifting", default=False, advanced=True), + io.Combo.Input("time_shift_type", options=["exponential", "linear"], advanced=True), + io.Int.Input("base_image_seq_len", default=256, min=1, max=65536, advanced=True), + io.Int.Input("max_image_seq_len", default=4096, min=1, max=65536, advanced=True), + io.Float.Input("base_shift", default=0.5, min=0.0, max=10.0, step=0.0001, round=False, advanced=True), + io.Float.Input("max_shift", default=1.15, min=0.0, max=10.0, step=0.0001, round=False, advanced=True), + io.Int.Input("num_train_timesteps", default=1000, min=10, max=100000, advanced=True), + io.String.Input("history_sizes", default="16,2,1", advanced=True), + io.Boolean.Input("keep_first_frame", default=True, advanced=True), + io.Int.Input("num_latent_frames_per_chunk", default=9, min=1, max=256, advanced=True), + io.Boolean.Input("is_cfg_zero_star", default=False, advanced=True), + io.Boolean.Input("use_zero_init", default=True, advanced=True), + io.Int.Input("zero_steps", default=1, min=0, max=10000, advanced=True), + io.Boolean.Input("is_skip_first_chunk", default=False, advanced=True), + ], + outputs=[ + io.Latent.Output(display_name="output"), + io.Latent.Output(display_name="denoised_output"), + ], + ) + + @classmethod + def execute( + cls, + model, + add_noise, + noise_seed, + cfg, + positive, + negative, + latent_image, + pyramid_steps, + stage_range, + is_distilled, + is_amplify_first_stage, + scheduler_mode, + gamma, + shift, + use_dynamic_shifting, + time_shift_type, + base_image_seq_len, + max_image_seq_len, + base_shift, + max_shift, + num_train_timesteps, + history_sizes, + keep_first_frame, + num_latent_frames_per_chunk, + is_cfg_zero_star, + use_zero_init, + zero_steps, + is_skip_first_chunk, + ) -> io.NodeOutput: + latent = latent_image.copy() + latent_samples = comfy.sample.fix_empty_latent_channels(model, latent["samples"], latent.get("downscale_ratio_spacial", None)) + + stage_steps = _parse_int_list(pyramid_steps, [10, 10, 10]) + stage_steps = [max(1, int(s)) for s in stage_steps] + stage_count = len(stage_steps) + history_sizes_list = sorted([max(0, int(v)) for v in _parse_int_list(history_sizes, [16, 2, 1])], reverse=True) + + stage_range_values = _parse_float_list(stage_range, [0.0, 1.0 / 3.0, 2.0 / 3.0, 1.0]) + if len(stage_range_values) != stage_count + 1: + stage_range_values = [float(i) / float(stage_count) for i in range(stage_count + 1)] + + stage_tables = _helios_stage_tables( + stage_count=stage_count, + stage_range=stage_range_values, + gamma=float(gamma), + num_train_timesteps=int(num_train_timesteps), + shift=float(shift), + ) + + b, c, t, h, w = latent_samples.shape + chunk_t = max(1, int(num_latent_frames_per_chunk)) + chunk_count = max(1, (t + chunk_t - 1) // chunk_t) + low_scale = 2 ** max(0, stage_count - 1) + low_h = max(1, h // low_scale) + low_w = max(1, w // low_scale) + + base_latent = torch.zeros((b, c, chunk_t, low_h, low_w), dtype=latent_samples.dtype, layout=latent_samples.layout, device=latent_samples.device) + + if add_noise: + stage_latent = comfy.sample.prepare_noise(base_latent, noise_seed) + else: + stage_latent = torch.zeros_like(base_latent, device="cpu") + + stage_latent = stage_latent.to(base_latent.dtype).to(comfy.model_management.intermediate_device()) + euler_sampler = comfy.samplers.KSAMPLER(_helios_euler_sample) + + latents_history_short = _extract_condition_value(positive, "latents_history_short") + latents_history_mid = _extract_condition_value(positive, "latents_history_mid") + latents_history_long = _extract_condition_value(positive, "latents_history_long") + image_latent_prefix = latent.get("helios_image_latent_prefix", None) + if latents_history_short is None and "helios_history_latent" in latent: + positive, negative = _set_helios_history_values( + positive, + negative, + latent["helios_history_latent"], + history_sizes_list, + keep_first_frame, + prefix_latent=image_latent_prefix, + ) + latents_history_short = _extract_condition_value(positive, "latents_history_short") + latents_history_mid = _extract_condition_value(positive, "latents_history_mid") + latents_history_long = _extract_condition_value(positive, "latents_history_long") + + x0_output = {} + generated_chunks = [] + if latents_history_short is not None and latents_history_mid is not None and latents_history_long is not None: + rolling_history = torch.cat([latents_history_long, latents_history_mid, latents_history_short], dim=2) + elif "helios_history_latent" in latent: + rolling_history = latent["helios_history_latent"] + else: + hist_len = max(1, sum(history_sizes_list)) + rolling_history = torch.zeros((b, c, hist_len, h, w), device=latent_samples.device, dtype=latent_samples.dtype) + + for chunk_idx in range(chunk_count): + if add_noise: + stage_latent = comfy.sample.prepare_noise(base_latent, noise_seed + chunk_idx).to(base_latent.dtype).to(comfy.model_management.intermediate_device()) + else: + stage_latent = torch.zeros_like(base_latent, device=comfy.model_management.intermediate_device()) + + positive_chunk, negative_chunk = _set_helios_history_values( + positive, + negative, + rolling_history, + history_sizes_list, + keep_first_frame, + prefix_latent=image_latent_prefix, + ) + latents_history_short = _extract_condition_value(positive_chunk, "latents_history_short") + latents_history_mid = _extract_condition_value(positive_chunk, "latents_history_mid") + latents_history_long = _extract_condition_value(positive_chunk, "latents_history_long") + + for stage_idx in range(stage_count): + if stage_idx > 0: + stage_latent = _upsample_latent_5d(stage_latent, scale=2) + + ori_sigma = 1.0 - float(stage_tables["ori_start_sigmas"][stage_idx]) + alpha = 1.0 / (math.sqrt(1.0 + (1.0 / gamma)) * (1.0 - ori_sigma) + ori_sigma) + beta = alpha * (1.0 - ori_sigma) / math.sqrt(gamma) + + noise = _sample_block_noise_like(stage_latent, gamma, patch_size=(1, 2, 2)).to(stage_latent) + stage_latent = alpha * stage_latent + beta * noise + + sigmas = _helios_stage_sigmas( + stage_idx=stage_idx, + stage_steps=stage_steps[stage_idx], + stage_tables=stage_tables, + is_distilled=is_distilled, + is_amplify_first_stage=is_amplify_first_stage and chunk_idx == 0, + ).to(stage_latent.dtype) + timesteps = _helios_stage_timesteps( + stage_idx=stage_idx, + stage_steps=stage_steps[stage_idx], + stage_tables=stage_tables, + is_distilled=is_distilled, + is_amplify_first_stage=is_amplify_first_stage and chunk_idx == 0, + ).to(stage_latent.dtype) + if use_dynamic_shifting: + patch_size = (1, 2, 2) + image_seq_len = (stage_latent.shape[-1] * stage_latent.shape[-2] * stage_latent.shape[-3]) // (patch_size[0] * patch_size[1] * patch_size[2]) + mu = _calculate_shift( + image_seq_len=image_seq_len, + base_seq_len=base_image_seq_len, + max_seq_len=max_image_seq_len, + base_shift=base_shift, + max_shift=max_shift, + ) + sigmas = _time_shift(sigmas, mu=mu, sigma=1.0, mode=time_shift_type).to(stage_latent.dtype) + tmin = torch.min(timesteps) + tmax = torch.max(timesteps) + timesteps = tmin + sigmas[:-1] * (tmax - tmin) + + indices_hidden_states, idx_short, idx_mid, idx_long = _build_helios_indices( + batch=stage_latent.shape[0], + history_sizes=history_sizes_list, + keep_first_frame=keep_first_frame, + hidden_frames=stage_latent.shape[2], + device=stage_latent.device, + dtype=stage_latent.dtype, + ) + positive_stage = node_helpers.conditioning_set_values(positive_chunk, {"indices_hidden_states": indices_hidden_states}) + negative_stage = node_helpers.conditioning_set_values(negative_chunk, {"indices_hidden_states": indices_hidden_states}) + + if latents_history_short is not None: + values = {"latents_history_short": latents_history_short, "indices_latents_history_short": idx_short} + positive_stage = node_helpers.conditioning_set_values(positive_stage, values) + negative_stage = node_helpers.conditioning_set_values(negative_stage, values) + + if latents_history_mid is not None: + values = {"latents_history_mid": latents_history_mid, "indices_latents_history_mid": idx_mid} + positive_stage = node_helpers.conditioning_set_values(positive_stage, values) + negative_stage = node_helpers.conditioning_set_values(negative_stage, values) + + if latents_history_long is not None: + values = {"latents_history_long": latents_history_long, "indices_latents_history_long": idx_long} + positive_stage = node_helpers.conditioning_set_values(positive_stage, values) + negative_stage = node_helpers.conditioning_set_values(negative_stage, values) + + cfg_use = 1.0 if is_distilled else cfg + + if stage_idx == 0 and add_noise: + noise = comfy.sample.prepare_noise(stage_latent, noise_seed + chunk_idx * 100 + stage_idx) + latent_start = torch.zeros_like(stage_latent) + else: + sigma0 = max(float(sigmas[0].item()), 1e-6) + noise = (stage_latent / sigma0).to("cpu") + latent_start = torch.zeros_like(stage_latent) + + stage_start_for_dmd = stage_latent.clone() + + if is_distilled: + sampler = comfy.samplers.KSAMPLER( + _helios_dmd_sample, + extra_options={ + "dmd_noisy_tensor": stage_start_for_dmd, + "dmd_sigmas": sigmas, + "dmd_timesteps": timesteps, + "all_timesteps": timesteps, + }, + ) + else: + if scheduler_mode == "unipc_bh2": + sampler = comfy.samplers.ksampler("uni_pc_bh2") + else: + sampler = euler_sampler + + callback = latent_preview.prepare_callback(model, sigmas.shape[-1] - 1, x0_output) + stage_model = model + if is_cfg_zero_star and not is_distilled: + stage_model = model.clone() + stage_model.model_options = comfy.model_patcher.set_model_options_pre_cfg_function( + stage_model.model_options, + _build_cfg_zero_star_pre_cfg(stage_idx=stage_idx, zero_steps=zero_steps, use_zero_init=use_zero_init), + disable_cfg1_optimization=True, + ) + stage_latent = comfy.sample.sample_custom( + stage_model, + noise, + cfg_use, + sampler, + sigmas, + positive_stage, + negative_stage, + latent_start, + noise_mask=None, + callback=callback, + disable_pbar=not comfy.utils.PROGRESS_BAR_ENABLED, + seed=noise_seed + chunk_idx * 100 + stage_idx, + ) + + if stage_latent.shape[-2] != h or stage_latent.shape[-1] != w: + b2, c2, t2, h2, w2 = stage_latent.shape + x = stage_latent.permute(0, 2, 1, 3, 4).reshape(b2 * t2, c2, h2, w2) + x = comfy.utils.common_upscale(x, w, h, "nearest-exact", "disabled") + stage_latent = x.reshape(b2, t2, c2, h, w).permute(0, 2, 1, 3, 4) + stage_latent = stage_latent[:, :, :, :h, :w] + + generated_chunks.append(stage_latent) + if keep_first_frame and ((chunk_idx == 0 and image_latent_prefix is None) or (is_skip_first_chunk and chunk_idx == 1)): + image_latent_prefix = stage_latent[:, :, :1] + rolling_history = torch.cat([rolling_history, stage_latent.to(rolling_history.device, rolling_history.dtype)], dim=2) + keep_hist = max(1, sum(history_sizes_list)) + rolling_history = rolling_history[:, :, -keep_hist:] + + stage_latent = torch.cat(generated_chunks, dim=2)[:, :, :t] + + out = latent.copy() + out.pop("downscale_ratio_spacial", None) + out["samples"] = stage_latent + + if "x0" in x0_output: + x0_out = model.model.process_latent_out(x0_output["x0"].cpu()) + out_denoised = latent.copy() + out_denoised["samples"] = x0_out + else: + out_denoised = out + + return io.NodeOutput(out, out_denoised) + + +class HeliosExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + HeliosImageToVideo, + HeliosVideoToVideo, + HeliosHistoryConditioning, + HeliosPyramidSampler, + ] + + +async def comfy_entrypoint() -> HeliosExtension: + return HeliosExtension() diff --git a/nodes.py b/nodes.py index 5be9b16f9..734642915 100644 --- a/nodes.py +++ b/nodes.py @@ -976,7 +976,7 @@ class CLIPLoader: @classmethod def INPUT_TYPES(s): return {"required": { "clip_name": (folder_paths.get_filename_list("text_encoders"), ), - "type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2", "qwen_image", "hunyuan_image", "flux2", "ovis", "longcat_image"], ), + "type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "helios", "hidream", "chroma", "ace", "omnigen2", "qwen_image", "hunyuan_image", "flux2", "ovis", "longcat_image"], ), }, "optional": { "device": (["default", "cpu"], {"advanced": True}), @@ -986,7 +986,7 @@ class CLIPLoader: CATEGORY = "advanced/loaders" - DESCRIPTION = "[Recipes]\n\nstable_diffusion: clip-l\nstable_cascade: clip-g\nsd3: t5 xxl/ clip-g / clip-l\nstable_audio: t5 base\nmochi: t5 xxl\ncosmos: old t5 xxl\nlumina2: gemma 2 2B\nwan: umt5 xxl\n hidream: llama-3.1 (Recommend) or t5\nomnigen2: qwen vl 2.5 3B" + DESCRIPTION = "[Recipes]\n\nstable_diffusion: clip-l\nstable_cascade: clip-g\nsd3: t5 xxl/ clip-g / clip-l\nstable_audio: t5 base\nmochi: t5 xxl\ncosmos: old t5 xxl\nlumina2: gemma 2 2B\nwan: umt5 xxl\nhelios: umt5 xxl\n hidream: llama-3.1 (Recommend) or t5\nomnigen2: qwen vl 2.5 3B" def load_clip(self, clip_name, type="stable_diffusion", device="default"): clip_type = getattr(comfy.sd.CLIPType, type.upper(), comfy.sd.CLIPType.STABLE_DIFFUSION) @@ -2412,6 +2412,7 @@ async def init_builtin_extra_nodes(): "nodes_cosmos.py", "nodes_video.py", "nodes_lumina2.py", + "nodes_helios.py", "nodes_wan.py", "nodes_lotus.py", "nodes_hunyuan3d.py", From d93133ee5350c14e776f0560a975903dab2a60f7 Mon Sep 17 00:00:00 2001 From: qqingzheng <2533221180@qq.com> Date: Sun, 8 Mar 2026 03:44:13 +0800 Subject: [PATCH 02/10] Refactor Helios integration and latent processing with new T2V support. --- comfy/latent_formats.py | 8 + comfy/ldm/helios/model.py | 188 ++++++++---- comfy/model_base.py | 41 ++- comfy/model_detection.py | 12 +- comfy/supported_models.py | 2 +- comfy_extras/nodes_helios.py | 536 +++++++++++++++++++++++++++-------- 6 files changed, 611 insertions(+), 176 deletions(-) diff --git a/comfy/latent_formats.py b/comfy/latent_formats.py index 6a57bca1c..91db60ab5 100644 --- a/comfy/latent_formats.py +++ b/comfy/latent_formats.py @@ -783,3 +783,11 @@ class ZImagePixelSpace(ChromaRadiance): No VAE encoding/decoding — the model operates directly on RGB pixels. """ pass + +class Helios(Wan21): + """Helios video model latent format + + Helios uses the same latent format as Wan21 (same VAE architecture). + Inherits latents_mean, latents_std, and processing methods from Wan21. + """ + pass diff --git a/comfy/ldm/helios/model.py b/comfy/ldm/helios/model.py index 5ffc91129..6fd37b875 100644 --- a/comfy/ldm/helios/model.py +++ b/comfy/ldm/helios/model.py @@ -6,11 +6,12 @@ import torch.nn as nn from comfy.ldm.modules.attention import optimized_attention from comfy.ldm.flux.layers import EmbedND from comfy.ldm.flux.math import apply_rope1 -from comfy.ldm.wan.model import sinusoidal_embedding_1d, repeat_e +from comfy.ldm.wan.model import sinusoidal_embedding_1d import comfy.ldm.common_dit import comfy.patcher_extension + def pad_for_3d_conv(x, kernel_size): b, c, t, h, w = x.shape pt, ph, pw = kernel_size @@ -20,6 +21,10 @@ def pad_for_3d_conv(x, kernel_size): return torch.nn.functional.pad(x, (0, pad_w, 0, pad_h, 0, pad_t), mode="replicate") +def center_down_sample_3d(x, kernel_size): + return torch.nn.functional.avg_pool3d(x, kernel_size, stride=kernel_size) + + class OutputNorm(nn.Module): def __init__(self, dim, eps=1e-6, operation_settings={}): @@ -50,7 +55,8 @@ class OutputNorm(nn.Module): shift = shift.squeeze(2).to(hidden_states.device) scale = scale.squeeze(2).to(hidden_states.device) hidden_states = hidden_states[:, -original_context_length:, :] - hidden_states = self.norm(hidden_states) * (1 + scale) + shift + # Use float32 for numerical stability like diffusers + hidden_states = (self.norm(hidden_states.float()) * (1 + scale) + shift).type_as(hidden_states) return hidden_states @@ -272,36 +278,69 @@ class HeliosAttentionBlock(nn.Module): def forward(self, x, context, e, freqs, original_context_length=None, transformer_options={}): if e.ndim == 4: - e = (self.scale_shift_table.unsqueeze(0) + e.float()).chunk(6, dim=2) - e = [v.squeeze(2) for v in e] + shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = ( + self.scale_shift_table.unsqueeze(0).to(e.device) + e.float() + ).chunk(6, dim=2) + shift_msa = shift_msa.squeeze(2) + scale_msa = scale_msa.squeeze(2) + gate_msa = gate_msa.squeeze(2) + c_shift_msa = c_shift_msa.squeeze(2) + c_scale_msa = c_scale_msa.squeeze(2) + c_gate_msa = c_gate_msa.squeeze(2) else: - e = (self.scale_shift_table + e.float()).chunk(6, dim=1) + shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = ( + self.scale_shift_table.to(e.device) + e.float() + ).chunk(6, dim=1) # self-attn + # Use float32 for numerical stability like diffusers + # norm1 has elementwise_affine=False, so we can safely convert to float32 + norm_x = self.norm1(x.float()) + norm_x = (norm_x * (1 + scale_msa) + shift_msa).type_as(x) y = self.attn1( - torch.addcmul(repeat_e(e[0], x), self.norm1(x), 1 + repeat_e(e[1], x)), + norm_x, freqs=freqs, original_context_length=original_context_length, transformer_options=transformer_options, ) - x = torch.addcmul(x, y, repeat_e(e[2], x)) + x = (x.float() + y.float() * gate_msa).type_as(x) # cross-attn if self.guidance_cross_attn and original_context_length is not None: history_seq_len = x.shape[1] - original_context_length history_x, x_main = torch.split(x, [history_seq_len, original_context_length], dim=1) + # norm2 has elementwise_affine=True, manually do FP32LayerNorm behavior + norm_x_main = torch.nn.functional.layer_norm( + x_main.float(), + self.norm2.normalized_shape, + self.norm2.weight.to(x_main.device).float() if self.norm2.weight is not None else None, + self.norm2.bias.to(x_main.device).float() if self.norm2.bias is not None else None, + self.norm2.eps, + ).type_as(x_main) x_main = x_main + self.attn2( - self.norm2(x_main), + norm_x_main, context=context, transformer_options=transformer_options, ) x = torch.cat([history_x, x_main], dim=1) else: - x = x + self.attn2(self.norm2(x), context=context, transformer_options=transformer_options) + # norm2 has elementwise_affine=True, manually do FP32LayerNorm behavior + norm_x = torch.nn.functional.layer_norm( + x.float(), + self.norm2.normalized_shape, + self.norm2.weight.to(x.device).float() if self.norm2.weight is not None else None, + self.norm2.bias.to(x.device).float() if self.norm2.bias is not None else None, + self.norm2.eps, + ).type_as(x) + x = x + self.attn2(norm_x, context=context, transformer_options=transformer_options) # ffn - y = self.ffn(torch.addcmul(repeat_e(e[3], x), self.norm3(x), 1 + repeat_e(e[4], x))) - x = torch.addcmul(x, y, repeat_e(e[5], x)) + # Use float32 for numerical stability like diffusers + # norm3 has elementwise_affine=False, so we can safely convert to float32 + norm_x = self.norm3(x.float()) + norm_x = (norm_x * (1 + c_scale_msa) + c_shift_msa).type_as(x) + y = self.ffn(norm_x) + x = (x.float() + y.float() * c_gate_msa).type_as(x) return x @@ -358,7 +397,7 @@ class HeliosModel(torch.nn.Module): kernel_size=patch_size, stride=patch_size, device=operation_settings.get("device"), - dtype=torch.float32, + dtype=operation_settings.get("dtype"), ) self.text_embedding = nn.Sequential( operations.Linear( @@ -411,7 +450,7 @@ class HeliosModel(torch.nn.Module): kernel_size=patch_size, stride=patch_size, device=operation_settings.get("device"), - dtype=torch.float32, + dtype=operation_settings.get("dtype"), ) self.patch_mid = operations.Conv3d( in_channels, @@ -419,7 +458,7 @@ class HeliosModel(torch.nn.Module): kernel_size=tuple(2 * p for p in patch_size), stride=tuple(2 * p for p in patch_size), device=operation_settings.get("device"), - dtype=torch.float32, + dtype=operation_settings.get("dtype"), ) self.patch_long = operations.Conv3d( in_channels, @@ -427,7 +466,7 @@ class HeliosModel(torch.nn.Module): kernel_size=tuple(4 * p for p in patch_size), stride=tuple(4 * p for p in patch_size), device=operation_settings.get("device"), - dtype=torch.float32, + dtype=operation_settings.get("dtype"), ) # blocks @@ -592,7 +631,7 @@ class HeliosModel(torch.nn.Module): p_t, p_h, p_w = self.patch_size # embeddings - hidden_states = self.patch_embedding(hidden_states.float()).to(hidden_states.dtype) + hidden_states = self.patch_embedding(hidden_states) _, _, post_t, post_h, post_w = hidden_states.shape hidden_states = hidden_states.flatten(2).transpose(1, 2) @@ -614,7 +653,7 @@ class HeliosModel(torch.nn.Module): original_context_length = hidden_states.shape[1] if (latents_history_short is not None and indices_latents_history_short is not None and hasattr(self, "patch_short")): - x_short = self.patch_short(latents_history_short.float()).to(hidden_states.dtype) + x_short = self.patch_short(latents_history_short).to(hidden_states.dtype) _, _, ts, hs, ws = x_short.shape x_short = x_short.flatten(2).transpose(1, 2) f_short = self.rope_encode( @@ -633,44 +672,70 @@ class HeliosModel(torch.nn.Module): freqs = torch.cat([f_short, freqs], dim=1) if (latents_history_mid is not None and indices_latents_history_mid is not None and hasattr(self, "patch_mid")): - x_mid = self.patch_mid(pad_for_3d_conv(latents_history_mid, (2, 4, 4)).float()).to(hidden_states.dtype) + x_mid = self.patch_mid(pad_for_3d_conv(latents_history_mid, (2, 4, 4))).to(hidden_states.dtype) _, _, tm, hm, wm = x_mid.shape x_mid = x_mid.flatten(2).transpose(1, 2) + mid_t = indices_latents_history_mid.shape[1] + if ("hs" in locals()) and ("ws" in locals()): + mid_h, mid_w = hs, ws + else: + mid_h, mid_w = hm * 2, wm * 2 f_mid = self.rope_encode( - t=tm * self.patch_size[0], - h=hm * self.patch_size[1], - w=wm * self.patch_size[2], - steps_t=tm, - steps_h=hm, - steps_w=wm, + t=mid_t * self.patch_size[0], + h=mid_h * self.patch_size[1], + w=mid_w * self.patch_size[2], + steps_t=mid_t, + steps_h=mid_h, + steps_w=mid_w, device=x_mid.device, dtype=x_mid.dtype, transformer_options=transformer_options, frame_indices=indices_latents_history_mid, ) + f_mid = self._rope_downsample_3d(f_mid, (mid_t, mid_h, mid_w), (2, 2, 2)) + if f_mid.shape[1] != x_mid.shape[1]: + f_mid = f_mid[:, :x_mid.shape[1]] hidden_states = torch.cat([x_mid, hidden_states], dim=1) freqs = torch.cat([f_mid, freqs], dim=1) if (latents_history_long is not None and indices_latents_history_long is not None and hasattr(self, "patch_long")): - x_long = self.patch_long(pad_for_3d_conv(latents_history_long, (4, 8, 8)).float()).to(hidden_states.dtype) + x_long = self.patch_long(pad_for_3d_conv(latents_history_long, (4, 8, 8))).to(hidden_states.dtype) _, _, tl, hl, wl = x_long.shape x_long = x_long.flatten(2).transpose(1, 2) + long_t = indices_latents_history_long.shape[1] + if ("hs" in locals()) and ("ws" in locals()): + long_h, long_w = hs, ws + else: + long_h, long_w = hl * 4, wl * 4 f_long = self.rope_encode( - t=tl * self.patch_size[0], - h=hl * self.patch_size[1], - w=wl * self.patch_size[2], - steps_t=tl, - steps_h=hl, - steps_w=wl, + t=long_t * self.patch_size[0], + h=long_h * self.patch_size[1], + w=long_w * self.patch_size[2], + steps_t=long_t, + steps_h=long_h, + steps_w=long_w, device=x_long.device, dtype=x_long.dtype, transformer_options=transformer_options, frame_indices=indices_latents_history_long, ) + f_long = self._rope_downsample_3d(f_long, (long_t, long_h, long_w), (4, 4, 4)) + if f_long.shape[1] != x_long.shape[1]: + f_long = f_long[:, :x_long.shape[1]] hidden_states = torch.cat([x_long, hidden_states], dim=1) freqs = torch.cat([f_long, freqs], dim=1) history_context_length = hidden_states.shape[1] - original_context_length + mismatch = hidden_states.shape[1] != freqs.shape[1] + summary_key = ( + int(post_t), + int(post_h), + int(post_w), + int(original_context_length), + int(hidden_states.shape[1]), + int(freqs.shape[1]), + int(history_context_length), + ) if timestep.ndim == 0: timestep = timestep.unsqueeze(0) @@ -682,7 +747,7 @@ class HeliosModel(torch.nn.Module): e = self.time_embedding(sinusoidal_embedding_1d(self.freq_dim, timestep.flatten()).to(dtype=hidden_states.dtype)) e = e.reshape(batch_size, -1, e.shape[-1]) e0 = self.time_projection(e).unflatten(2, (6, self.dim)) - context = self.text_embedding(context) + context = self.text_embedding(context.to(dtype=hidden_states.dtype)) if self.zero_history_timestep and history_context_length > 0: timestep_t0 = torch.zeros((1, ), dtype=timestep.dtype, device=timestep.device) @@ -701,7 +766,7 @@ class HeliosModel(torch.nn.Module): e0 = e0.permute(0, 2, 1, 3) - for block in self.blocks: + for i_b, block in enumerate(self.blocks): hidden_states = block( hidden_states, context, @@ -710,35 +775,46 @@ class HeliosModel(torch.nn.Module): original_context_length=original_context_length, transformer_options=transformer_options, ) - hidden_states = self.norm_out(hidden_states, e, original_context_length) hidden_states = self.proj_out(hidden_states) - return self.unpatchify(hidden_states, (post_t, post_h, post_w)) def unpatchify(self, x, grid_sizes): - c = self.out_dim + """ + Unpatchify the output from proj_out back to video format. + + Args: + x: [batch, num_patches, out_dim * prod(patch_size)] + grid_sizes: (num_frames, height, width) in patch space + + Returns: + [batch, out_dim, num_frames, height, width] in pixel space + """ b = x.shape[0] - u = x[:, :math.prod(grid_sizes)].view(b, *grid_sizes, *self.patch_size, c) - u = torch.einsum("bfhwpqrc->bcfphqwr", u) - u = u.reshape(b, c, *[i * j for i, j in zip(grid_sizes, self.patch_size)]) - return u - - def load_state_dict(self, state_dict, strict=True, assign=False): - # Keep compatibility with reference diffusers key names. - remapped = {} - for k, v in state_dict.items(): - nk = k - nk = nk.replace("condition_embedder.time_embedder.linear_1.", "time_embedding.0.") - nk = nk.replace("condition_embedder.time_embedder.linear_2.", "time_embedding.2.") - nk = nk.replace("condition_embedder.time_proj.", "time_projection.1.") - nk = nk.replace("condition_embedder.text_embedder.linear_1.", "text_embedding.0.") - nk = nk.replace("condition_embedder.text_embedder.linear_2.", "text_embedding.2.") - nk = nk.replace("blocks.", "blocks.") - remapped[nk] = v - - return super().load_state_dict(remapped, strict=strict, assign=assign) - + post_t, post_h, post_w = grid_sizes + p_t, p_h, p_w = self.patch_size + + # Reshape: [B, T*H*W, out_dim*p_t*p_h*p_w] -> [B, T, H, W, p_t, p_h, p_w, out_dim] + # Use -1 to let PyTorch infer the channel dimension (out_dim) + hidden_states = x.reshape(b, post_t, post_h, post_w, p_t, p_h, p_w, -1) + + # Permute: [B, T, H, W, p_t, p_h, p_w, C] -> [B, C, T, p_t, H, p_h, W, p_w] + hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6) + + # Flatten patches: [B, C, T, p_t, H, p_h, W, p_w] -> [B, C, T*p_t, H*p_h, W*p_w] + output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3) + + return output + def _rope_downsample_3d(self, freqs, grid_sizes, kernel_size): + b, _, one, d, i2, j2 = freqs.shape + gt, gh, gw = grid_sizes + c = one * d * i2 * j2 + freqs_3d = freqs.reshape(b, gt, gh, gw, c).permute(0, 4, 1, 2, 3) + freqs_3d = pad_for_3d_conv(freqs_3d, kernel_size) + freqs_3d = center_down_sample_3d(freqs_3d, kernel_size) + dt, dh, dw = freqs_3d.shape[2:] + freqs_3d = freqs_3d.permute(0, 2, 3, 4, 1).reshape(b, dt * dh * dw, one, d, i2, j2) + return freqs_3d # Backward-compatible alias for existing integration points. HeliosTransformer3DModel = HeliosModel diff --git a/comfy/model_base.py b/comfy/model_base.py index 9bee3049a..d2d178b48 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -1287,17 +1287,52 @@ class Helios(BaseModel): "latents_history_short", "latents_history_mid", "latents_history_long", + "helios_stage_sigmas", + "helios_stage_timesteps", ) for key in cond_keys: value = kwargs.get(key, None) if value is None: continue - if key.startswith("latents_"): - value = self.process_latent_in(value) - out[key] = comfy.conds.CONDRegular(value) + # Diffusers forwards Helios history latents without latent-format re-normalization. + # Keep raw history tensors to match transformer inputs across frameworks. + if key in ("helios_stage_sigmas", "helios_stage_timesteps"): + out[key] = comfy.conds.CONDConstant(value) + else: + out[key] = comfy.conds.CONDRegular(value) return out + def process_timestep(self, timestep, **kwargs): + stage_sigmas = kwargs.get("helios_stage_sigmas", None) + stage_timesteps = kwargs.get("helios_stage_timesteps", None) + if stage_sigmas is None or stage_timesteps is None: + return timestep + + if stage_sigmas.ndim > 1: + stage_sigmas = stage_sigmas[0] + if stage_timesteps.ndim > 1: + stage_timesteps = stage_timesteps[0] + + if stage_timesteps.numel() == 0 or stage_sigmas.numel() == 0: + return timestep + + if stage_sigmas.numel() == stage_timesteps.numel() + 1: + sigma_candidates = stage_sigmas[:-1] + else: + sigma_candidates = stage_sigmas[: stage_timesteps.numel()] + + if sigma_candidates.numel() == 0: + return timestep + + multiplier = float(getattr(self.model_sampling, "multiplier", 1000.0)) + sigma_in = timestep / multiplier + idx = torch.argmin(torch.abs(sigma_in.unsqueeze(-1) - sigma_candidates.unsqueeze(0)), dim=-1) + mapped = stage_timesteps[idx].to(dtype=timestep.dtype) + if mapped.dtype.is_floating_point: + mapped = torch.floor(mapped) + return mapped + class WAN21(BaseModel): def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None): super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel) diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 7a130c02d..ae4f254ef 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -489,7 +489,11 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): return dit_config - if '{}condition_embedder.time_proj.weight'.format(key_prefix) in state_dict_keys and '{}patch_embedding.weight'.format(key_prefix) in state_dict_keys: # Helios + helios_required_keys = ( + '{}patch_mid.weight'.format(key_prefix), + '{}patch_long.weight'.format(key_prefix), + ) + if all(k in state_dict_keys for k in helios_required_keys): # Helios dit_config = {} dit_config["image_model"] = "helios" @@ -501,8 +505,10 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["patch_size"] = patch_size dit_config["in_channels"] = patch_weight.shape[1] dit_config["out_channels"] = out_proj.shape[0] // math.prod(patch_size) - dit_config["text_dim"] = state_dict['{}condition_embedder.text_embedder.linear_1.weight'.format(key_prefix)].shape[1] - dit_config["freq_dim"] = state_dict['{}condition_embedder.time_embedder.linear_1.weight'.format(key_prefix)].shape[1] + text_w = state_dict['{}text_embedding.0.weight'.format(key_prefix)] + time_w = state_dict['{}time_embedding.0.weight'.format(key_prefix)] + dit_config["text_dim"] = text_w.shape[1] + dit_config["freq_dim"] = time_w.shape[1] dit_config["num_layers"] = count_blocks(state_dict_keys, '{}blocks.'.format(key_prefix) + '{}.') dit_config["num_attention_heads"] = inner_dim // 128 dit_config["attention_head_dim"] = 128 diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 2035f25b8..b0fb3ce3d 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -1143,7 +1143,7 @@ class Helios(supported_models_base.BASE): } unet_extra_config = {} - latent_format = latent_formats.Wan21 + latent_format = latent_formats.Helios memory_usage_factor = 1.8 supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32] diff --git a/comfy_extras/nodes_helios.py b/comfy_extras/nodes_helios.py index 6c1fd7e20..13894082a 100644 --- a/comfy_extras/nodes_helios.py +++ b/comfy_extras/nodes_helios.py @@ -14,6 +14,9 @@ from typing_extensions import override from comfy_api.latest import ComfyExtension, io + + + def _parse_int_list(values, default): if values is None: return default @@ -72,15 +75,73 @@ def _extract_condition_value(conditioning, key): return None +def _process_latent_in_preserve_zero_frames(model, latent, valid_mask=None): + if latent is None or len(latent.shape) != 5: + return latent + if valid_mask is None: + raise ValueError("Helios requires `helios_history_valid_mask` for history latent conversion.") + vm = valid_mask + if not torch.is_tensor(vm): + vm = torch.tensor(vm, device=latent.device) + vm = vm.to(device=latent.device) + if vm.ndim == 2: + nonzero = vm.any(dim=0) + else: + nonzero = vm.reshape(-1) + nonzero = nonzero.bool() + + if nonzero.numel() == 0 or (not torch.any(nonzero)): + return latent + + if nonzero.shape[0] != latent.shape[2]: + # Keep behavior safe when mask length does not match temporal length. + nonzero = torch.zeros((latent.shape[2],), device=latent.device, dtype=torch.bool) + + converted = model.model.process_latent_in(latent) + out = latent.clone() + out[:, :, nonzero, :, :] = converted[:, :, nonzero, :, :] + return out + + def _upsample_latent_5d(latent, scale=2): b, c, t, h, w = latent.shape x = latent.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w) - x = comfy.utils.common_upscale(x, w * scale, h * scale, "nearest-exact", "disabled") + x = comfy.utils.common_upscale(x, w * scale, h * scale, "nearest", "disabled") x = x.reshape(b, t, c, h * scale, w * scale).permute(0, 2, 1, 3, 4) return x -def _sample_block_noise_like(latent, gamma, patch_size=(1, 2, 2)): +def _downsample_latent_5d_bilinear_x2(latent): + b, c, t, h, w = latent.shape + x = latent.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w) + x = comfy.utils.common_upscale(x, max(1, w // 2), max(1, h // 2), "bilinear", "disabled") * 2.0 + x = x.reshape(b, t, c, max(1, h // 2), max(1, w // 2)).permute(0, 2, 1, 3, 4) + return x + + +def _prepare_stage0_latent(batch, channels, frames, height, width, stage_count, add_noise, seed, dtype, layout, device): + """Prepare initial latent for stage 0 with optional noise""" + full_latent = torch.zeros((batch, channels, frames, height, width), dtype=dtype, layout=layout, device=device) + if add_noise: + full_latent = comfy.sample.prepare_noise(full_latent, seed).to(dtype) + + # Downsample to stage 0 resolution + stage_latent = full_latent + for _ in range(max(0, int(stage_count) - 1)): + stage_latent = _downsample_latent_5d_bilinear_x2(stage_latent) + return stage_latent + + +def _downsample_latent_for_stage0(latent, stage_count): + """Downsample latent to stage 0 resolution (like Diffusers does)""" + stage_latent = latent + for _ in range(max(0, int(stage_count) - 1)): + stage_latent = _downsample_latent_5d_bilinear_x2(stage_latent) + return stage_latent + + + +def _sample_block_noise_like(latent, gamma, patch_size=(1, 2, 2), generator=None, seed=None): b, c, t, h, w = latent.shape _, ph, pw = patch_size block_size = ph * pw @@ -88,13 +149,38 @@ def _sample_block_noise_like(latent, gamma, patch_size=(1, 2, 2)): cov = torch.eye(block_size, device=latent.device) * (1.0 + gamma) - torch.ones(block_size, block_size, device=latent.device) * gamma cov += torch.eye(block_size, device=latent.device) * 1e-6 - dist = torch.distributions.MultivariateNormal(torch.zeros(block_size, device=latent.device), covariance_matrix=cov) - block_number = b * c * t * max(1, h // ph) * max(1, w // pw) + h_blocks = h // ph + w_blocks = w // pw + block_number = b * c * t * h_blocks * w_blocks - noise = dist.sample((block_number,)) - noise = noise.view(b, c, t, max(1, h // ph), max(1, w // pw), ph, pw) - noise = noise.permute(0, 1, 2, 3, 5, 4, 6).reshape(b, c, t, max(1, h // ph) * ph, max(1, w // pw) * pw) - noise = noise[:, :, :, :h, :w] + if generator is not None: + # Exact Diffusers sampling path (MultivariateNormal.sample), while consuming + # from an explicit generator by temporarily swapping default RNG state. + with torch.random.fork_rng(devices=[latent.device] if latent.device.type == "cuda" else []): + if latent.device.type == "cuda": + torch.cuda.set_rng_state(generator.get_state(), device=latent.device) + else: + torch.random.set_rng_state(generator.get_state()) + dist = torch.distributions.MultivariateNormal( + torch.zeros(block_size, device=latent.device), + covariance_matrix=cov, + ) + noise = dist.sample((block_number,)) + if latent.device.type == "cuda": + generator.set_state(torch.cuda.get_rng_state(device=latent.device)) + else: + generator.set_state(torch.random.get_rng_state()) + elif seed is None: + dist = torch.distributions.MultivariateNormal(torch.zeros(block_size, device=latent.device), covariance_matrix=cov) + noise = dist.sample((block_number,)) + else: + # Use deterministic RNG when seed is provided (for cross-framework alignment). + with torch.random.fork_rng(devices=[latent.device] if latent.device.type == "cuda" else []): + torch.manual_seed(int(seed)) + dist = torch.distributions.MultivariateNormal(torch.zeros(block_size, device=latent.device), covariance_matrix=cov) + noise = dist.sample((block_number,)) + noise = noise.view(b, c, t, h_blocks, w_blocks, ph, pw) + noise = noise.permute(0, 1, 2, 3, 5, 4, 6).reshape(b, c, t, h, w) return noise @@ -144,8 +230,10 @@ def _helios_stage_tables(stage_count, stage_range, gamma, num_train_timesteps=10 tmax = min(float(sigmas[int(start_ratio * num_train_timesteps)].item() * num_train_timesteps), 999.0) tmin = float(sigmas[min(int(end_ratio * num_train_timesteps), num_train_timesteps - 1)].item() * num_train_timesteps) - timesteps_per_stage[i] = torch.linspace(tmax, tmin, num_train_timesteps) - sigmas_per_stage[i] = torch.linspace(0.999, 0.0, num_train_timesteps) + timesteps_per_stage[i] = torch.linspace(tmax, tmin, num_train_timesteps + 1)[:-1] + # Fixed: Use same sigma range [0.999, 0] for all stages like Diffusers + sigmas_per_stage[i] = torch.linspace(0.999, 0.0, num_train_timesteps + 1)[:-1] + return { "ori_start_sigmas": ori_start_sigmas, @@ -163,7 +251,8 @@ def _helios_stage_sigmas(stage_idx, stage_steps, stage_tables, is_distilled=Fals stage_steps = stage_steps * 2 if (is_amplify_first_stage and stage_idx == 0) else stage_steps stage_sigma_src = stage_tables["sigmas_per_stage"][stage_idx] - sigmas = torch.linspace(float(stage_sigma_src[0].item()), float(stage_sigma_src[-1].item()), stage_steps + 1) + sigmas = torch.linspace(float(stage_sigma_src[0].item()), float(stage_sigma_src[-1].item()), stage_steps) + sigmas = torch.cat([sigmas, torch.zeros(1, dtype=sigmas.dtype, device=sigmas.device)], dim=0) return sigmas @@ -213,23 +302,37 @@ def _build_cfg_zero_star_pre_cfg(stage_idx, zero_steps, use_zero_init): state["i"] += 1 return conds_out - noise_pred_text = conds_out[0] - noise_uncond = conds_out[1] + denoised_text = conds_out[0] # apply_model 返回的 denoised + denoised_uncond = conds_out[1] cfg = float(args.get("cond_scale", 1.0)) + x = args["input"] # 当前的 noisy latent + sigma = args["sigma"] # 当前的 sigma - positive_flat = noise_pred_text.view(noise_pred_text.shape[0], -1) - negative_flat = noise_uncond.view(noise_uncond.shape[0], -1) + # 关键修复:将 denoised 转换为 flow + # denoised = x - flow * sigma => flow = (x - denoised) / sigma + sigma_reshaped = sigma.reshape(sigma.shape[0], *([1] * (denoised_text.ndim - 1))) + sigma_safe = torch.clamp(sigma_reshaped, min=1e-8) + + flow_text = (x - denoised_text) / sigma_safe + flow_uncond = (x - denoised_uncond) / sigma_safe + + # 在 flow 空间做 CFG Zero Star + positive_flat = flow_text.reshape(flow_text.shape[0], -1) + negative_flat = flow_uncond.reshape(flow_uncond.shape[0], -1) alpha = _optimized_scale(positive_flat, negative_flat) - alpha = alpha.view(noise_pred_text.shape[0], *([1] * (noise_pred_text.ndim - 1))).to(noise_pred_text.dtype) + alpha = alpha.reshape(flow_text.shape[0], *([1] * (flow_text.ndim - 1))).to(flow_text.dtype) if stage_idx == 0 and state["i"] <= int(zero_steps) and bool(use_zero_init): - final = noise_pred_text * 0.0 + flow_final = flow_text * 0.0 else: - final = noise_uncond * alpha + cfg * (noise_pred_text - noise_uncond * alpha) + flow_final = flow_uncond * alpha + cfg * (flow_text - flow_uncond * alpha) + + # 将 flow 转回 denoised + denoised_final = x - flow_final * sigma_safe state["i"] += 1 # Return identical cond/uncond so downstream cfg_function keeps `final` unchanged. - return [final, final] + return [denoised_final, denoised_final] return pre_cfg_fn @@ -310,6 +413,8 @@ def _set_helios_history_values(positive, negative, history_latent, history_sizes latent = history_latent if latent is None or len(latent.shape) != 5: return positive, negative + if prefix_latent is not None and (latent.device != prefix_latent.device or latent.dtype != prefix_latent.dtype): + latent = latent.to(device=prefix_latent.device, dtype=prefix_latent.dtype) sizes = list(history_sizes) if len(sizes) != 3: @@ -342,13 +447,15 @@ def _set_helios_history_values(positive, negative, history_latent, history_sizes prefix = latent[:, :, :1] else: prefix = torch.zeros(latent.shape[0], latent.shape[1], 1, latent.shape[3], latent.shape[4], device=latent.device, dtype=latent.dtype) + if prefix.device != latents_history_short_base.device or prefix.dtype != latents_history_short_base.dtype: + prefix = prefix.to(device=latents_history_short_base.device, dtype=latents_history_short_base.dtype) latents_history_short = torch.cat([prefix, latents_history_short_base], dim=2) else: latents_history_short = latents_history_short_base - idx_short = torch.arange(latents_history_short.shape[2], device=latent.device, dtype=latent.dtype).unsqueeze(0).expand(latent.shape[0], -1) - idx_mid = torch.arange(latents_history_mid.shape[2], device=latent.device, dtype=latent.dtype).unsqueeze(0).expand(latent.shape[0], -1) - idx_long = torch.arange(latents_history_long.shape[2], device=latent.device, dtype=latent.dtype).unsqueeze(0).expand(latent.shape[0], -1) + idx_short = torch.arange(latents_history_short.shape[2], device=latent.device, dtype=torch.int64).unsqueeze(0).expand(latent.shape[0], -1) + idx_mid = torch.arange(latents_history_mid.shape[2], device=latent.device, dtype=torch.int64).unsqueeze(0).expand(latent.shape[0], -1) + idx_long = torch.arange(latents_history_long.shape[2], device=latent.device, dtype=torch.int64).unsqueeze(0).expand(latent.shape[0], -1) values = { "latents_history_short": latents_history_short, @@ -364,7 +471,7 @@ def _set_helios_history_values(positive, negative, history_latent, history_sizes return positive, negative -def _build_helios_indices(batch, history_sizes, keep_first_frame, hidden_frames, device, dtype): +def _build_helios_indices(batch, history_sizes, keep_first_frame, hidden_frames, device): sizes = list(history_sizes) if len(sizes) != 3: sizes = [16, 2, 1] @@ -373,13 +480,13 @@ def _build_helios_indices(batch, history_sizes, keep_first_frame, hidden_frames, if keep_first_frame: total = 1 + long_size + mid_size + short_base_size + hidden_frames - indices = torch.arange(total, device=device, dtype=dtype) + indices = torch.arange(total, device=device, dtype=torch.int64) splits = [1, long_size, mid_size, short_base_size, hidden_frames] indices_prefix, idx_long, idx_mid, idx_1x, idx_hidden = torch.split(indices, splits, dim=0) idx_short = torch.cat([indices_prefix, idx_1x], dim=0) else: total = long_size + mid_size + short_base_size + hidden_frames - indices = torch.arange(total, device=device, dtype=dtype) + indices = torch.arange(total, device=device, dtype=torch.int64) splits = [long_size, mid_size, short_base_size, hidden_frames] idx_long, idx_mid, idx_short, idx_hidden = torch.split(indices, splits, dim=0) @@ -450,7 +557,9 @@ class HeliosImageToVideo(io.ComfyNode): sizes = sorted([max(0, int(v)) for v in sizes], reverse=True) hist_len = max(1, sum(sizes)) history_latent = torch.zeros([batch_size, latent_channels, hist_len, latent.shape[-2], latent.shape[-1]], device=latent.device, dtype=latent.dtype) + history_valid_mask = torch.zeros((batch_size, hist_len), device=latent.device, dtype=torch.bool) image_latent_prefix = None + i2v_noise_gen = None if start_image is not None: image = comfy.utils.common_upscale(start_image[:1].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) @@ -459,20 +568,36 @@ class HeliosImageToVideo(io.ComfyNode): image_latent_prefix = img_latent[:, :, :1] if add_noise_to_image_latents: - g = torch.Generator(device=img_latent.device) - g.manual_seed(int(noise_seed)) + i2v_noise_gen = torch.Generator(device=img_latent.device) + i2v_noise_gen.manual_seed(int(noise_seed)) sigma = ( - torch.rand((img_latent.shape[0], 1, 1, 1, 1), device=img_latent.device, generator=g, dtype=img_latent.dtype) + torch.rand((img_latent.shape[0], 1, 1, 1, 1), device=img_latent.device, generator=i2v_noise_gen, dtype=img_latent.dtype) * (float(image_noise_sigma_max) - float(image_noise_sigma_min)) + float(image_noise_sigma_min) ) - image_latent_prefix = sigma * torch.randn_like(image_latent_prefix, generator=g) + (1.0 - sigma) * image_latent_prefix + image_latent_prefix = sigma * torch.randn_like(image_latent_prefix, generator=i2v_noise_gen) + (1.0 - sigma) * image_latent_prefix min_frames = max(1, (int(num_latent_frames_per_chunk) - 1) * 4 + 1) fake_video = image.repeat(min_frames, 1, 1, 1) fake_latents_full = vae.encode(fake_video) fake_latent = comfy.utils.repeat_to_batch_size(fake_latents_full[:, :, -1:], batch_size) + # Diffusers parity for I2V: + # when adding noise to image latents, fake_image_latents used for history are also noised. + if add_noise_to_image_latents: + if i2v_noise_gen is None: + i2v_noise_gen = torch.Generator(device=fake_latent.device) + i2v_noise_gen.manual_seed(int(noise_seed)) + # Keep backward compatibility with existing I2V node inputs: + # this node exposes only image sigma controls, while fake history + # latents follow the video-noise path in Diffusers. + fake_sigma = ( + torch.rand((fake_latent.shape[0], 1, 1, 1, 1), device=fake_latent.device, generator=i2v_noise_gen, dtype=fake_latent.dtype) + * (float(image_noise_sigma_max) - float(image_noise_sigma_min)) + + float(image_noise_sigma_min) + ) + fake_latent = fake_sigma * torch.randn_like(fake_latent, generator=i2v_noise_gen) + (1.0 - fake_sigma) * fake_latent history_latent[:, :, -1:] = fake_latent + history_valid_mask[:, -1] = True positive, negative = _set_helios_history_values(positive, negative, history_latent, sizes, keep_first_frame, prefix_latent=image_latent_prefix) return io.NodeOutput( @@ -482,6 +607,85 @@ class HeliosImageToVideo(io.ComfyNode): "samples": latent, "helios_history_latent": history_latent, "helios_image_latent_prefix": image_latent_prefix, + "helios_history_valid_mask": history_valid_mask, + }, + ) + + +class HeliosTextToVideo(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="HeliosTextToVideo", + category="conditioning/video_models", + inputs=[ + io.Conditioning.Input("positive"), + io.Conditioning.Input("negative"), + io.Vae.Input("vae"), + io.Int.Input("width", default=640, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("height", default=384, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("length", default=132, min=1, max=nodes.MAX_RESOLUTION, step=4), + io.Int.Input("batch_size", default=1, min=1, max=4096), + io.String.Input("history_sizes", default="16,2,1", advanced=True), + ], + outputs=[ + io.Conditioning.Output(display_name="positive"), + io.Conditioning.Output(display_name="negative"), + io.Latent.Output(display_name="latent"), + ], + ) + + @classmethod + def execute( + cls, + positive, + negative, + vae, + width, + height, + length, + batch_size, + history_sizes="16,2,1", + ) -> io.NodeOutput: + spacial_scale = vae.spacial_compression_encode() + latent_channels = vae.latent_channels + latent_t = ((length - 1) // 4) + 1 + + # Create zero latent as shape placeholder (noise will be generated in sampler) + latent = torch.zeros( + [batch_size, latent_channels, latent_t, height // spacial_scale, width // spacial_scale], + device=comfy.model_management.intermediate_device(), + ) + + sizes = _parse_int_list(history_sizes, [16, 2, 1]) + if len(sizes) != 3: + sizes = [16, 2, 1] + sizes = sorted([max(0, int(v)) for v in sizes], reverse=True) + hist_len = max(1, sum(sizes)) + # History latent starts as zeros (no history yet) + history_latent = torch.zeros( + [batch_size, latent_channels, hist_len, latent.shape[-2], latent.shape[-1]], + device=latent.device, + dtype=latent.dtype, + ) + history_valid_mask = torch.zeros((batch_size, hist_len), device=latent.device, dtype=torch.bool) + + positive, negative = _set_helios_history_values( + positive, + negative, + history_latent, + sizes, + False, + prefix_latent=None, + ) + return io.NodeOutput( + positive, + negative, + { + "samples": latent, + "helios_history_latent": history_latent, + "helios_image_latent_prefix": None, + "helios_history_valid_mask": history_valid_mask, }, ) @@ -544,6 +748,7 @@ class HeliosVideoToVideo(io.ComfyNode): sizes = sorted([max(0, int(v)) for v in sizes], reverse=True) hist_len = max(1, sum(sizes)) history_latent = torch.zeros([batch_size, latent_channels, hist_len, latent.shape[-2], latent.shape[-1]], device=latent.device, dtype=latent.dtype) + history_valid_mask = torch.zeros((batch_size, hist_len), device=latent.device, dtype=torch.bool) image_latent_prefix = None if video is not None: @@ -559,11 +764,14 @@ class HeliosVideoToVideo(io.ComfyNode): ) vid_latent = frame_sigmas * torch.randn_like(vid_latent, generator=g) + (1.0 - frame_sigmas) * vid_latent vid_latent = vid_latent[:, :, :hist_len] - if vid_latent.shape[2] < hist_len: - pad = vid_latent[:, :, -1:].repeat(1, 1, hist_len - vid_latent.shape[2], 1, 1) - vid_latent = torch.cat([vid_latent, pad], dim=2) vid_latent = comfy.utils.repeat_to_batch_size(vid_latent, batch_size) - history_latent = vid_latent + if vid_latent.shape[2] < hist_len: + keep_frames = hist_len - vid_latent.shape[2] + history_latent = torch.cat([history_latent[:, :, :keep_frames], vid_latent], dim=2) + history_valid_mask[:, keep_frames:] = True + else: + history_latent = vid_latent[:, :, -hist_len:] + history_valid_mask[:] = True image_latent_prefix = history_latent[:, :, :1] positive, negative = _set_helios_history_values(positive, negative, history_latent, sizes, keep_first_frame, prefix_latent=image_latent_prefix) @@ -574,6 +782,7 @@ class HeliosVideoToVideo(io.ComfyNode): "samples": latent, "helios_history_latent": history_latent, "helios_image_latent_prefix": image_latent_prefix, + "helios_history_valid_mask": history_valid_mask, }, ) @@ -625,25 +834,16 @@ class HeliosPyramidSampler(io.ComfyNode): io.Latent.Input("latent_image"), io.String.Input("pyramid_steps", default="10,10,10"), io.String.Input("stage_range", default="0,0.333333,0.666667,1"), - io.Boolean.Input("is_distilled", default=False), - io.Boolean.Input("is_amplify_first_stage", default=False), - io.Combo.Input("scheduler_mode", options=["euler", "unipc_bh2"]), + io.Boolean.Input("distilled", default=False), + io.Boolean.Input("amplify_first_stage", default=False), io.Float.Input("gamma", default=1.0 / 3.0, min=0.0001, max=10.0, step=0.0001, round=False), - io.Float.Input("shift", default=1.0, min=0.001, max=100.0, step=0.001, round=False, advanced=True), - io.Boolean.Input("use_dynamic_shifting", default=False, advanced=True), - io.Combo.Input("time_shift_type", options=["exponential", "linear"], advanced=True), - io.Int.Input("base_image_seq_len", default=256, min=1, max=65536, advanced=True), - io.Int.Input("max_image_seq_len", default=4096, min=1, max=65536, advanced=True), - io.Float.Input("base_shift", default=0.5, min=0.0, max=10.0, step=0.0001, round=False, advanced=True), - io.Float.Input("max_shift", default=1.15, min=0.0, max=10.0, step=0.0001, round=False, advanced=True), - io.Int.Input("num_train_timesteps", default=1000, min=10, max=100000, advanced=True), io.String.Input("history_sizes", default="16,2,1", advanced=True), io.Boolean.Input("keep_first_frame", default=True, advanced=True), io.Int.Input("num_latent_frames_per_chunk", default=9, min=1, max=256, advanced=True), - io.Boolean.Input("is_cfg_zero_star", default=False, advanced=True), + io.Boolean.Input("cfg_zero_star", default=True, advanced=True), io.Boolean.Input("use_zero_init", default=True, advanced=True), io.Int.Input("zero_steps", default=1, min=0, max=10000, advanced=True), - io.Boolean.Input("is_skip_first_chunk", default=False, advanced=True), + io.Boolean.Input("skip_first_chunk", default=False, advanced=True), ], outputs=[ io.Latent.Output(display_name="output"), @@ -663,33 +863,40 @@ class HeliosPyramidSampler(io.ComfyNode): latent_image, pyramid_steps, stage_range, - is_distilled, - is_amplify_first_stage, - scheduler_mode, + distilled, + amplify_first_stage, gamma, - shift, - use_dynamic_shifting, - time_shift_type, - base_image_seq_len, - max_image_seq_len, - base_shift, - max_shift, - num_train_timesteps, history_sizes, keep_first_frame, num_latent_frames_per_chunk, - is_cfg_zero_star, + cfg_zero_star, use_zero_init, zero_steps, - is_skip_first_chunk, + skip_first_chunk, ) -> io.NodeOutput: + # Keep these scheduler knobs internal (not exposed in node UI). + shift = 1.0 + num_train_timesteps = 1000 + # Keep dynamic shifting always on for Helios parity; not exposed in node UI. + use_dynamic_shifting = True + time_shift_type = "exponential" + base_image_seq_len = 256 + max_image_seq_len = 4096 + base_shift = 0.5 + max_shift = 1.15 + latent = latent_image.copy() latent_samples = comfy.sample.fix_empty_latent_channels(model, latent["samples"], latent.get("downscale_ratio_spacial", None)) + if not add_noise: + latent_samples = _process_latent_in_preserve_zero_frames(model, latent_samples) stage_steps = _parse_int_list(pyramid_steps, [10, 10, 10]) stage_steps = [max(1, int(s)) for s in stage_steps] stage_count = len(stage_steps) history_sizes_list = sorted([max(0, int(v)) for v in _parse_int_list(history_sizes, [16, 2, 1])], reverse=True) + # Diffusers parity: if not keeping first frame, fold prefix slot into short history size. + if not keep_first_frame and len(history_sizes_list) > 0: + history_sizes_list[-1] += 1 stage_range_values = _parse_float_list(stage_range, [0.0, 1.0 / 3.0, 2.0 / 3.0, 1.0]) if len(stage_range_values) != stage_count + 1: @@ -706,29 +913,41 @@ class HeliosPyramidSampler(io.ComfyNode): b, c, t, h, w = latent_samples.shape chunk_t = max(1, int(num_latent_frames_per_chunk)) chunk_count = max(1, (t + chunk_t - 1) // chunk_t) - low_scale = 2 ** max(0, stage_count - 1) - low_h = max(1, h // low_scale) - low_w = max(1, w // low_scale) - - base_latent = torch.zeros((b, c, chunk_t, low_h, low_w), dtype=latent_samples.dtype, layout=latent_samples.layout, device=latent_samples.device) - - if add_noise: - stage_latent = comfy.sample.prepare_noise(base_latent, noise_seed) - else: - stage_latent = torch.zeros_like(base_latent, device="cpu") - - stage_latent = stage_latent.to(base_latent.dtype).to(comfy.model_management.intermediate_device()) euler_sampler = comfy.samplers.KSAMPLER(_helios_euler_sample) + target_device = comfy.model_management.get_torch_device() + noise_gen = torch.Generator(device=target_device) + noise_gen.manual_seed(int(noise_seed)) + + image_latent_prefix = latent.get("helios_image_latent_prefix", None) + history_valid_mask = latent.get("helios_history_valid_mask", None) + if history_valid_mask is None: + raise ValueError("Helios sampler requires `helios_history_valid_mask` in latent input.") + history_from_latent_applied = False + if image_latent_prefix is not None: + image_latent_prefix = model.model.process_latent_in(image_latent_prefix) + if "helios_history_latent" in latent: + history_in = _process_latent_in_preserve_zero_frames(model, latent["helios_history_latent"], valid_mask=history_valid_mask) + positive, negative = _set_helios_history_values( + positive, + negative, + history_in, + history_sizes_list, + keep_first_frame, + prefix_latent=image_latent_prefix, + ) + history_from_latent_applied = True latents_history_short = _extract_condition_value(positive, "latents_history_short") latents_history_mid = _extract_condition_value(positive, "latents_history_mid") latents_history_long = _extract_condition_value(positive, "latents_history_long") - image_latent_prefix = latent.get("helios_image_latent_prefix", None) + if (not history_from_latent_applied) and latents_history_short is not None and latents_history_mid is not None and latents_history_long is not None: + raise ValueError("Helios requires `helios_history_latent` + `helios_history_valid_mask`; direct history conditioning is not supported.") if latents_history_short is None and "helios_history_latent" in latent: + history_in = _process_latent_in_preserve_zero_frames(model, latent["helios_history_latent"], valid_mask=history_valid_mask) positive, negative = _set_helios_history_values( positive, negative, - latent["helios_history_latent"], + history_in, history_sizes_list, keep_first_frame, prefix_latent=image_latent_prefix, @@ -740,18 +959,100 @@ class HeliosPyramidSampler(io.ComfyNode): x0_output = {} generated_chunks = [] if latents_history_short is not None and latents_history_mid is not None and latents_history_long is not None: - rolling_history = torch.cat([latents_history_long, latents_history_mid, latents_history_short], dim=2) + # Diffusers parity: `history_latents` storage does NOT include the keep_first_frame prefix slot. + # `latents_history_short` in conditioning may include [prefix + short_base], so strip prefix here. + short_base_size = history_sizes_list[-1] if len(history_sizes_list) > 0 else latents_history_short.shape[2] + if keep_first_frame and latents_history_short.shape[2] > short_base_size: + short_for_history = latents_history_short[:, :, -short_base_size:] + else: + short_for_history = latents_history_short + rolling_history = torch.cat([latents_history_long, latents_history_mid, short_for_history], dim=2) elif "helios_history_latent" in latent: rolling_history = latent["helios_history_latent"] + rolling_history = _process_latent_in_preserve_zero_frames(model, rolling_history, valid_mask=history_valid_mask) else: hist_len = max(1, sum(history_sizes_list)) rolling_history = torch.zeros((b, c, hist_len, h, w), device=latent_samples.device, dtype=latent_samples.dtype) - for chunk_idx in range(chunk_count): - if add_noise: - stage_latent = comfy.sample.prepare_noise(base_latent, noise_seed + chunk_idx).to(base_latent.dtype).to(comfy.model_management.intermediate_device()) + # Align with Diffusers behavior: when initial video latents are provided, seed history buffer + # with those latents before the first denoising chunk. + if not add_noise: + hist_len = max(1, sum(history_sizes_list)) + rolling_history = rolling_history.to(device=latent_samples.device, dtype=latent_samples.dtype) + video_latents = latent_samples + video_frames = video_latents.shape[2] + if video_frames < hist_len: + keep_frames = hist_len - video_frames + rolling_history = torch.cat([rolling_history[:, :, :keep_frames], video_latents], dim=2) else: - stage_latent = torch.zeros_like(base_latent, device=comfy.model_management.intermediate_device()) + rolling_history = video_latents[:, :, -hist_len:] + + # Keep history/prefix on the same device/dtype as denoising latents. + rolling_history = rolling_history.to(device=target_device, dtype=latent_samples.dtype) + if image_latent_prefix is not None: + image_latent_prefix = image_latent_prefix.to(device=target_device, dtype=latent_samples.dtype) + + for chunk_idx in range(chunk_count): + # Extract chunk from input latents + chunk_start = chunk_idx * chunk_t + chunk_end = min(chunk_start + chunk_t, t) + latent_chunk = latent_samples[:, :, chunk_start:chunk_end, :, :] + + # Prepare initial latent for this chunk + if add_noise: + # Diffusers parity: each chunk denoises a fixed latent window size. + # Keep chunk temporal length constant and crop only after all chunks. + noise_shape = ( + latent_samples.shape[0], + latent_samples.shape[1], + chunk_t, + latent_samples.shape[3], + latent_samples.shape[4], + ) + stage_latent = torch.randn(noise_shape, device=target_device, dtype=latent_samples.dtype, generator=noise_gen) + else: + # Use actual input latents; pad final short chunk to fixed size like Diffusers windowing. + stage_latent = latent_chunk.clone() + if stage_latent.shape[2] < chunk_t: + if stage_latent.shape[2] == 0: + stage_latent = torch.zeros( + ( + latent_samples.shape[0], + latent_samples.shape[1], + chunk_t, + latent_samples.shape[3], + latent_samples.shape[4], + ), + device=latent_samples.device, + dtype=latent_samples.dtype, + ) + else: + pad = stage_latent[:, :, -1:].repeat(1, 1, chunk_t - stage_latent.shape[2], 1, 1) + stage_latent = torch.cat([stage_latent, pad], dim=2) + + # Downsample to stage 0 resolution + for _ in range(max(0, int(stage_count) - 1)): + stage_latent = _downsample_latent_5d_bilinear_x2(stage_latent) + + # Keep stage latents on model device for parity with Diffusers scheduler/noise path. + stage_latent = stage_latent.to(target_device) + + # Diffusers parity: + # keep_first_frame=True and no image_latent_prefix on the first chunk + # should use an all-zero prefix frame, not history[:, :, :1]. + chunk_prefix = image_latent_prefix + if keep_first_frame and image_latent_prefix is None and chunk_idx == 0: + chunk_prefix = torch.zeros( + ( + rolling_history.shape[0], + rolling_history.shape[1], + 1, + rolling_history.shape[3], + rolling_history.shape[4], + ), + device=rolling_history.device, + dtype=rolling_history.dtype, + ) positive_chunk, negative_chunk = _set_helios_history_values( positive, @@ -759,37 +1060,28 @@ class HeliosPyramidSampler(io.ComfyNode): rolling_history, history_sizes_list, keep_first_frame, - prefix_latent=image_latent_prefix, + prefix_latent=chunk_prefix, ) latents_history_short = _extract_condition_value(positive_chunk, "latents_history_short") latents_history_mid = _extract_condition_value(positive_chunk, "latents_history_mid") latents_history_long = _extract_condition_value(positive_chunk, "latents_history_long") for stage_idx in range(stage_count): - if stage_idx > 0: - stage_latent = _upsample_latent_5d(stage_latent, scale=2) - - ori_sigma = 1.0 - float(stage_tables["ori_start_sigmas"][stage_idx]) - alpha = 1.0 / (math.sqrt(1.0 + (1.0 / gamma)) * (1.0 - ori_sigma) + ori_sigma) - beta = alpha * (1.0 - ori_sigma) / math.sqrt(gamma) - - noise = _sample_block_noise_like(stage_latent, gamma, patch_size=(1, 2, 2)).to(stage_latent) - stage_latent = alpha * stage_latent + beta * noise - + stage_latent = stage_latent.to(comfy.model_management.get_torch_device()) sigmas = _helios_stage_sigmas( stage_idx=stage_idx, stage_steps=stage_steps[stage_idx], stage_tables=stage_tables, - is_distilled=is_distilled, - is_amplify_first_stage=is_amplify_first_stage and chunk_idx == 0, - ).to(stage_latent.dtype) + is_distilled=distilled, + is_amplify_first_stage=amplify_first_stage and chunk_idx == 0, + ).to(device=stage_latent.device, dtype=torch.float32) timesteps = _helios_stage_timesteps( stage_idx=stage_idx, stage_steps=stage_steps[stage_idx], stage_tables=stage_tables, - is_distilled=is_distilled, - is_amplify_first_stage=is_amplify_first_stage and chunk_idx == 0, - ).to(stage_latent.dtype) + is_distilled=distilled, + is_amplify_first_stage=amplify_first_stage and chunk_idx == 0, + ).to(device=stage_latent.device, dtype=torch.float32) if use_dynamic_shifting: patch_size = (1, 2, 2) image_seq_len = (stage_latent.shape[-1] * stage_latent.shape[-2] * stage_latent.shape[-3]) // (patch_size[0] * patch_size[1] * patch_size[2]) @@ -800,10 +1092,24 @@ class HeliosPyramidSampler(io.ComfyNode): base_shift=base_shift, max_shift=max_shift, ) - sigmas = _time_shift(sigmas, mu=mu, sigma=1.0, mode=time_shift_type).to(stage_latent.dtype) + sigmas = _time_shift(sigmas, mu=mu, sigma=1.0, mode=time_shift_type).to(torch.float32) tmin = torch.min(timesteps) tmax = torch.max(timesteps) timesteps = tmin + sigmas[:-1] * (tmax - tmin) + else: + pass + + # Keep parity with Diffusers pipeline order: + # stage timesteps are computed before upsampling/renoise for stage > 0. + if stage_idx > 0: + stage_latent = _upsample_latent_5d(stage_latent, scale=2) + + ori_sigma = 1.0 - float(stage_tables["ori_start_sigmas"][stage_idx]) + alpha = 1.0 / (math.sqrt(1.0 + (1.0 / gamma)) * (1.0 - ori_sigma) + ori_sigma) + beta = alpha * (1.0 - ori_sigma) / math.sqrt(gamma) + + noise = _sample_block_noise_like(stage_latent, gamma, patch_size=(1, 2, 2), generator=noise_gen).to(stage_latent) + stage_latent = alpha * stage_latent + beta * noise indices_hidden_states, idx_short, idx_mid, idx_long = _build_helios_indices( batch=stage_latent.shape[0], @@ -811,7 +1117,6 @@ class HeliosPyramidSampler(io.ComfyNode): keep_first_frame=keep_first_frame, hidden_frames=stage_latent.shape[2], device=stage_latent.device, - dtype=stage_latent.dtype, ) positive_stage = node_helpers.conditioning_set_values(positive_chunk, {"indices_hidden_states": indices_hidden_states}) negative_stage = node_helpers.conditioning_set_values(negative_chunk, {"indices_hidden_states": indices_hidden_states}) @@ -831,19 +1136,22 @@ class HeliosPyramidSampler(io.ComfyNode): positive_stage = node_helpers.conditioning_set_values(positive_stage, values) negative_stage = node_helpers.conditioning_set_values(negative_stage, values) - cfg_use = 1.0 if is_distilled else cfg + stage_time_values = { + "helios_stage_sigmas": sigmas, + "helios_stage_timesteps": timesteps, + } + positive_stage = node_helpers.conditioning_set_values(positive_stage, stage_time_values) + negative_stage = node_helpers.conditioning_set_values(negative_stage, stage_time_values) - if stage_idx == 0 and add_noise: - noise = comfy.sample.prepare_noise(stage_latent, noise_seed + chunk_idx * 100 + stage_idx) - latent_start = torch.zeros_like(stage_latent) - else: - sigma0 = max(float(sigmas[0].item()), 1e-6) - noise = (stage_latent / sigma0).to("cpu") - latent_start = torch.zeros_like(stage_latent) + cfg_use = 1.0 if distilled else cfg + + sigma0 = max(float(sigmas[0].item()), 1e-6) + noise = stage_latent / sigma0 + latent_start = torch.zeros_like(stage_latent) stage_start_for_dmd = stage_latent.clone() - if is_distilled: + if distilled: sampler = comfy.samplers.KSAMPLER( _helios_dmd_sample, extra_options={ @@ -854,14 +1162,11 @@ class HeliosPyramidSampler(io.ComfyNode): }, ) else: - if scheduler_mode == "unipc_bh2": - sampler = comfy.samplers.ksampler("uni_pc_bh2") - else: - sampler = euler_sampler + sampler = euler_sampler callback = latent_preview.prepare_callback(model, sigmas.shape[-1] - 1, x0_output) stage_model = model - if is_cfg_zero_star and not is_distilled: + if cfg_zero_star and not distilled: stage_model = model.clone() stage_model.model_options = comfy.model_patcher.set_model_options_pre_cfg_function( stage_model.model_options, @@ -882,6 +1187,10 @@ class HeliosPyramidSampler(io.ComfyNode): disable_pbar=not comfy.utils.PROGRESS_BAR_ENABLED, seed=noise_seed + chunk_idx * 100 + stage_idx, ) + # sample_custom returns latent_format.process_out(samples); convert back to model-space + # so subsequent pyramid stages and history conditioning stay in the same latent space + # as Diffusers' internal denoising latents. + stage_latent = model.model.process_latent_in(stage_latent) if stage_latent.shape[-2] != h or stage_latent.shape[-1] != w: b2, c2, t2, h2, w2 = stage_latent.shape @@ -891,7 +1200,7 @@ class HeliosPyramidSampler(io.ComfyNode): stage_latent = stage_latent[:, :, :, :h, :w] generated_chunks.append(stage_latent) - if keep_first_frame and ((chunk_idx == 0 and image_latent_prefix is None) or (is_skip_first_chunk and chunk_idx == 1)): + if keep_first_frame and ((chunk_idx == 0 and image_latent_prefix is None) or (skip_first_chunk and chunk_idx == 1)): image_latent_prefix = stage_latent[:, :, :1] rolling_history = torch.cat([rolling_history, stage_latent.to(rolling_history.device, rolling_history.dtype)], dim=2) keep_hist = max(1, sum(history_sizes_list)) @@ -901,7 +1210,7 @@ class HeliosPyramidSampler(io.ComfyNode): out = latent.copy() out.pop("downscale_ratio_spacial", None) - out["samples"] = stage_latent + out["samples"] = model.model.process_latent_out(stage_latent) if "x0" in x0_output: x0_out = model.model.process_latent_out(x0_output["x0"].cpu()) @@ -917,6 +1226,7 @@ class HeliosExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[io.ComfyNode]]: return [ + HeliosTextToVideo, HeliosImageToVideo, HeliosVideoToVideo, HeliosHistoryConditioning, From 26f44ab7708827321756a158344be80a01676ad9 Mon Sep 17 00:00:00 2001 From: qqingzheng <2533221180@qq.com> Date: Sun, 8 Mar 2026 23:08:57 +0800 Subject: [PATCH 03/10] Enhance Helios model with latent space noise application and debugging options --- comfy/ldm/helios/model.py | 48 ++---- comfy_extras/nodes_helios.py | 324 +++++++++++++++++++++++++++++------ 2 files changed, 287 insertions(+), 85 deletions(-) diff --git a/comfy/ldm/helios/model.py b/comfy/ldm/helios/model.py index 6fd37b875..c1ea5f595 100644 --- a/comfy/ldm/helios/model.py +++ b/comfy/ldm/helios/model.py @@ -652,8 +652,8 @@ class HeliosModel(torch.nn.Module): ) original_context_length = hidden_states.shape[1] - if (latents_history_short is not None and indices_latents_history_short is not None and hasattr(self, "patch_short")): - x_short = self.patch_short(latents_history_short).to(hidden_states.dtype) + if latents_history_short is not None and indices_latents_history_short is not None: + x_short = self.patch_short(latents_history_short) _, _, ts, hs, ws = x_short.shape x_short = x_short.flatten(2).transpose(1, 2) f_short = self.rope_encode( @@ -671,57 +671,45 @@ class HeliosModel(torch.nn.Module): hidden_states = torch.cat([x_short, hidden_states], dim=1) freqs = torch.cat([f_short, freqs], dim=1) - if (latents_history_mid is not None and indices_latents_history_mid is not None and hasattr(self, "patch_mid")): - x_mid = self.patch_mid(pad_for_3d_conv(latents_history_mid, (2, 4, 4))).to(hidden_states.dtype) - _, _, tm, hm, wm = x_mid.shape + if latents_history_mid is not None and indices_latents_history_mid is not None: + x_mid = self.patch_mid(pad_for_3d_conv(latents_history_mid, (2, 4, 4))) + _, _, tm, _, _ = x_mid.shape x_mid = x_mid.flatten(2).transpose(1, 2) mid_t = indices_latents_history_mid.shape[1] - if ("hs" in locals()) and ("ws" in locals()): - mid_h, mid_w = hs, ws - else: - mid_h, mid_w = hm * 2, wm * 2 f_mid = self.rope_encode( t=mid_t * self.patch_size[0], - h=mid_h * self.patch_size[1], - w=mid_w * self.patch_size[2], + h=hs * self.patch_size[1], + w=ws * self.patch_size[2], steps_t=mid_t, - steps_h=mid_h, - steps_w=mid_w, + steps_h=hs, + steps_w=ws, device=x_mid.device, dtype=x_mid.dtype, transformer_options=transformer_options, frame_indices=indices_latents_history_mid, ) - f_mid = self._rope_downsample_3d(f_mid, (mid_t, mid_h, mid_w), (2, 2, 2)) - if f_mid.shape[1] != x_mid.shape[1]: - f_mid = f_mid[:, :x_mid.shape[1]] + f_mid = self._rope_downsample_3d(f_mid, (mid_t, hs, ws), (2, 2, 2)) hidden_states = torch.cat([x_mid, hidden_states], dim=1) freqs = torch.cat([f_mid, freqs], dim=1) - if (latents_history_long is not None and indices_latents_history_long is not None and hasattr(self, "patch_long")): - x_long = self.patch_long(pad_for_3d_conv(latents_history_long, (4, 8, 8))).to(hidden_states.dtype) - _, _, tl, hl, wl = x_long.shape + if latents_history_long is not None and indices_latents_history_long is not None: + x_long = self.patch_long(pad_for_3d_conv(latents_history_long, (4, 8, 8))) + _, _, tl, _, _ = x_long.shape x_long = x_long.flatten(2).transpose(1, 2) long_t = indices_latents_history_long.shape[1] - if ("hs" in locals()) and ("ws" in locals()): - long_h, long_w = hs, ws - else: - long_h, long_w = hl * 4, wl * 4 f_long = self.rope_encode( t=long_t * self.patch_size[0], - h=long_h * self.patch_size[1], - w=long_w * self.patch_size[2], + h=hs * self.patch_size[1], + w=ws * self.patch_size[2], steps_t=long_t, - steps_h=long_h, - steps_w=long_w, + steps_h=hs, + steps_w=ws, device=x_long.device, dtype=x_long.dtype, transformer_options=transformer_options, frame_indices=indices_latents_history_long, ) - f_long = self._rope_downsample_3d(f_long, (long_t, long_h, long_w), (4, 4, 4)) - if f_long.shape[1] != x_long.shape[1]: - f_long = f_long[:, :x_long.shape[1]] + f_long = self._rope_downsample_3d(f_long, (long_t, hs, ws), (4, 4, 4)) hidden_states = torch.cat([x_long, hidden_states], dim=1) freqs = torch.cat([f_long, freqs], dim=1) diff --git a/comfy_extras/nodes_helios.py b/comfy_extras/nodes_helios.py index 13894082a..3d5f80e76 100644 --- a/comfy_extras/nodes_helios.py +++ b/comfy_extras/nodes_helios.py @@ -7,6 +7,7 @@ import comfy.model_patcher import comfy.sample import comfy.samplers import comfy.utils +import comfy.latent_formats import latent_preview import node_helpers @@ -41,6 +42,37 @@ def _parse_int_list(values, default): return out if len(out) > 0 else default +_HELIOS_LATENT_FORMAT = comfy.latent_formats.Helios() + + +def _apply_helios_latent_space_noise(latent, sigma, generator=None): + """Apply noise in Helios model latent space, then map back to VAE latent space.""" + latent_in = _HELIOS_LATENT_FORMAT.process_in(latent) + noise = torch.randn( + latent_in.shape, + device=latent_in.device, + dtype=latent_in.dtype, + generator=generator, + ) + noised_in = sigma * noise + (1.0 - sigma) * latent_in + return _HELIOS_LATENT_FORMAT.process_out(noised_in).to(device=latent.device, dtype=latent.dtype) + + +def _tensor_stats_str(x): + if x is None: + return "None" + if not torch.is_tensor(x): + return f"non-tensor type={type(x)}" + if x.numel() == 0: + return f"shape={tuple(x.shape)} empty" + xf = x.detach().to(torch.float32) + return ( + f"shape={tuple(x.shape)} " + f"mean={xf.mean().item():.6f} std={xf.std(unbiased=False).item():.6f} " + f"min={xf.min().item():.6f} max={xf.max().item():.6f}" + ) + + def _parse_float_list(values, default): if values is None: return default @@ -65,6 +97,15 @@ def _parse_float_list(values, default): return out if len(out) > 0 else default +def _strict_bool(value, default=False): + if isinstance(value, bool): + return value + if isinstance(value, int): + return value != 0 + # Reject non-bool numerics from stale workflows (e.g. 0.135). + return bool(default) + + def _extract_condition_value(conditioning, key): for c in conditioning: if len(c) < 2: @@ -94,8 +135,9 @@ def _process_latent_in_preserve_zero_frames(model, latent, valid_mask=None): return latent if nonzero.shape[0] != latent.shape[2]: - # Keep behavior safe when mask length does not match temporal length. - nonzero = torch.zeros((latent.shape[2],), device=latent.device, dtype=torch.bool) + raise ValueError( + f"Helios history mask length mismatch: mask_t={nonzero.shape[0]} latent_t={latent.shape[2]}" + ) converted = model.model.process_latent_in(latent) out = latent.clone() @@ -133,7 +175,7 @@ def _prepare_stage0_latent(batch, channels, frames, height, width, stage_count, def _downsample_latent_for_stage0(latent, stage_count): - """Downsample latent to stage 0 resolution (like Diffusers does)""" + """Downsample latent to stage 0 resolution.""" stage_latent = latent for _ in range(max(0, int(stage_count) - 1)): stage_latent = _downsample_latent_5d_bilinear_x2(stage_latent) @@ -154,7 +196,7 @@ def _sample_block_noise_like(latent, gamma, patch_size=(1, 2, 2), generator=None block_number = b * c * t * h_blocks * w_blocks if generator is not None: - # Exact Diffusers sampling path (MultivariateNormal.sample), while consuming + # Exact sampling path (MultivariateNormal.sample), while consuming # from an explicit generator by temporarily swapping default RNG state. with torch.random.fork_rng(devices=[latent.device] if latent.device.type == "cuda" else []): if latent.device.type == "cuda": @@ -231,7 +273,7 @@ def _helios_stage_tables(stage_count, stage_range, gamma, num_train_timesteps=10 tmax = min(float(sigmas[int(start_ratio * num_train_timesteps)].item() * num_train_timesteps), 999.0) tmin = float(sigmas[min(int(end_ratio * num_train_timesteps), num_train_timesteps - 1)].item() * num_train_timesteps) timesteps_per_stage[i] = torch.linspace(tmax, tmin, num_train_timesteps + 1)[:-1] - # Fixed: Use same sigma range [0.999, 0] for all stages like Diffusers + # Fixed: use the same sigma range [0.999, 0] for all stages. sigmas_per_stage[i] = torch.linspace(0.999, 0.0, num_train_timesteps + 1)[:-1] @@ -302,21 +344,18 @@ def _build_cfg_zero_star_pre_cfg(stage_idx, zero_steps, use_zero_init): state["i"] += 1 return conds_out - denoised_text = conds_out[0] # apply_model 返回的 denoised + denoised_text = conds_out[0] denoised_uncond = conds_out[1] cfg = float(args.get("cond_scale", 1.0)) - x = args["input"] # 当前的 noisy latent - sigma = args["sigma"] # 当前的 sigma + x = args["input"] + sigma = args["sigma"] - # 关键修复:将 denoised 转换为 flow - # denoised = x - flow * sigma => flow = (x - denoised) / sigma sigma_reshaped = sigma.reshape(sigma.shape[0], *([1] * (denoised_text.ndim - 1))) sigma_safe = torch.clamp(sigma_reshaped, min=1e-8) flow_text = (x - denoised_text) / sigma_safe flow_uncond = (x - denoised_uncond) / sigma_safe - # 在 flow 空间做 CFG Zero Star positive_flat = flow_text.reshape(flow_text.shape[0], -1) negative_flat = flow_uncond.reshape(flow_uncond.shape[0], -1) alpha = _optimized_scale(positive_flat, negative_flat) @@ -327,11 +366,9 @@ def _build_cfg_zero_star_pre_cfg(stage_idx, zero_steps, use_zero_init): else: flow_final = flow_uncond * alpha + cfg * (flow_text - flow_uncond * alpha) - # 将 flow 转回 denoised denoised_final = x - flow_final * sigma_safe state["i"] += 1 - # Return identical cond/uncond so downstream cfg_function keeps `final` unchanged. return [denoised_final, denoised_final] return pre_cfg_fn @@ -519,6 +556,8 @@ class HeliosImageToVideo(io.ComfyNode): io.Float.Input("image_noise_sigma_min", default=0.111, min=0.0, max=1.0, step=0.0001, round=False, advanced=True), io.Float.Input("image_noise_sigma_max", default=0.135, min=0.0, max=1.0, step=0.0001, round=False, advanced=True), io.Int.Input("noise_seed", default=0, min=0, max=0xFFFFFFFFFFFFFFFF, advanced=True), + io.Boolean.Input("include_history_in_output", default=False, advanced=True), + io.Boolean.Input("debug_latent_stats", default=False, advanced=True), ], outputs=[ io.Conditioning.Output(display_name="positive"), @@ -545,7 +584,11 @@ class HeliosImageToVideo(io.ComfyNode): image_noise_sigma_min=0.111, image_noise_sigma_max=0.135, noise_seed=0, + include_history_in_output=False, + debug_latent_stats=False, ) -> io.NodeOutput: + video_noise_sigma_min = 0.111 + video_noise_sigma_max = 0.135 spacial_scale = vae.spacial_compression_encode() latent_channels = vae.latent_channels latent_t = ((length - 1) // 4) + 1 @@ -560,10 +603,11 @@ class HeliosImageToVideo(io.ComfyNode): history_valid_mask = torch.zeros((batch_size, hist_len), device=latent.device, dtype=torch.bool) image_latent_prefix = None i2v_noise_gen = None + noise_gen_state = None if start_image is not None: image = comfy.utils.common_upscale(start_image[:1].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) - img_latent = vae.encode(image[:, :, :, :3]) + img_latent = vae.encode(image[:, :, :, :3]).to(device=latent.device, dtype=torch.float32) img_latent = comfy.utils.repeat_to_batch_size(img_latent, batch_size) image_latent_prefix = img_latent[:, :, :1] @@ -571,33 +615,38 @@ class HeliosImageToVideo(io.ComfyNode): i2v_noise_gen = torch.Generator(device=img_latent.device) i2v_noise_gen.manual_seed(int(noise_seed)) sigma = ( - torch.rand((img_latent.shape[0], 1, 1, 1, 1), device=img_latent.device, generator=i2v_noise_gen, dtype=img_latent.dtype) + torch.rand((1,), device=img_latent.device, generator=i2v_noise_gen, dtype=img_latent.dtype).view(1, 1, 1, 1, 1) * (float(image_noise_sigma_max) - float(image_noise_sigma_min)) + float(image_noise_sigma_min) ) - image_latent_prefix = sigma * torch.randn_like(image_latent_prefix, generator=i2v_noise_gen) + (1.0 - sigma) * image_latent_prefix + image_latent_prefix = _apply_helios_latent_space_noise(image_latent_prefix, sigma, generator=i2v_noise_gen) min_frames = max(1, (int(num_latent_frames_per_chunk) - 1) * 4 + 1) fake_video = image.repeat(min_frames, 1, 1, 1) - fake_latents_full = vae.encode(fake_video) + fake_latents_full = vae.encode(fake_video).to(device=latent.device, dtype=torch.float32) fake_latent = comfy.utils.repeat_to_batch_size(fake_latents_full[:, :, -1:], batch_size) - # Diffusers parity for I2V: # when adding noise to image latents, fake_image_latents used for history are also noised. if add_noise_to_image_latents: if i2v_noise_gen is None: i2v_noise_gen = torch.Generator(device=fake_latent.device) i2v_noise_gen.manual_seed(int(noise_seed)) # Keep backward compatibility with existing I2V node inputs: - # this node exposes only image sigma controls, while fake history - # latents follow the video-noise path in Diffusers. + # this node exposes only image sigma controls; fake history latents + # follow the video-noise defaults. fake_sigma = ( - torch.rand((fake_latent.shape[0], 1, 1, 1, 1), device=fake_latent.device, generator=i2v_noise_gen, dtype=fake_latent.dtype) - * (float(image_noise_sigma_max) - float(image_noise_sigma_min)) - + float(image_noise_sigma_min) + torch.rand((1,), device=fake_latent.device, generator=i2v_noise_gen, dtype=fake_latent.dtype).view(1, 1, 1, 1, 1) + * (float(video_noise_sigma_max) - float(video_noise_sigma_min)) + + float(video_noise_sigma_min) ) - fake_latent = fake_sigma * torch.randn_like(fake_latent, generator=i2v_noise_gen) + (1.0 - fake_sigma) * fake_latent + fake_latent = _apply_helios_latent_space_noise(fake_latent, fake_sigma, generator=i2v_noise_gen) history_latent[:, :, -1:] = fake_latent history_valid_mask[:, -1] = True + if i2v_noise_gen is not None: + noise_gen_state = i2v_noise_gen.get_state().clone() + if debug_latent_stats: + print(f"[HeliosDebug][I2V] image_latent_prefix: {_tensor_stats_str(image_latent_prefix)}") + print(f"[HeliosDebug][I2V] fake_latent: {_tensor_stats_str(fake_latent)}") + print(f"[HeliosDebug][I2V] history_latent: {_tensor_stats_str(history_latent)}") positive, negative = _set_helios_history_values(positive, negative, history_latent, sizes, keep_first_frame, prefix_latent=image_latent_prefix) return io.NodeOutput( @@ -608,6 +657,10 @@ class HeliosImageToVideo(io.ComfyNode): "helios_history_latent": history_latent, "helios_image_latent_prefix": image_latent_prefix, "helios_history_valid_mask": history_valid_mask, + "helios_num_frames": int(length), + "helios_noise_gen_state": noise_gen_state, + "helios_include_history_in_output": _strict_bool(include_history_in_output, default=False), + "helios_debug_latent_stats": bool(debug_latent_stats), }, ) @@ -686,6 +739,7 @@ class HeliosTextToVideo(io.ComfyNode): "helios_history_latent": history_latent, "helios_image_latent_prefix": None, "helios_history_valid_mask": history_valid_mask, + "helios_num_frames": int(length), }, ) @@ -707,10 +761,13 @@ class HeliosVideoToVideo(io.ComfyNode): io.Image.Input("video", optional=True), io.String.Input("history_sizes", default="16,2,1", advanced=True), io.Boolean.Input("keep_first_frame", default=True, advanced=True), + io.Int.Input("num_latent_frames_per_chunk", default=9, min=1, max=256, advanced=True), io.Boolean.Input("add_noise_to_video_latents", default=True, advanced=True), io.Float.Input("video_noise_sigma_min", default=0.111, min=0.0, max=1.0, step=0.0001, round=False, advanced=True), io.Float.Input("video_noise_sigma_max", default=0.135, min=0.0, max=1.0, step=0.0001, round=False, advanced=True), io.Int.Input("noise_seed", default=0, min=0, max=0xFFFFFFFFFFFFFFFF, advanced=True), + io.Boolean.Input("include_history_in_output", default=True, advanced=True), + io.Boolean.Input("debug_latent_stats", default=False, advanced=True), ], outputs=[ io.Conditioning.Output(display_name="positive"), @@ -732,10 +789,13 @@ class HeliosVideoToVideo(io.ComfyNode): video=None, history_sizes="16,2,1", keep_first_frame=True, + num_latent_frames_per_chunk=9, add_noise_to_video_latents=True, video_noise_sigma_min=0.111, video_noise_sigma_max=0.135, noise_seed=0, + include_history_in_output=True, + debug_latent_stats=False, ) -> io.NodeOutput: spacial_scale = vae.spacial_compression_encode() latent_channels = vae.latent_channels @@ -750,29 +810,81 @@ class HeliosVideoToVideo(io.ComfyNode): history_latent = torch.zeros([batch_size, latent_channels, hist_len, latent.shape[-2], latent.shape[-1]], device=latent.device, dtype=latent.dtype) history_valid_mask = torch.zeros((batch_size, hist_len), device=latent.device, dtype=torch.bool) image_latent_prefix = None + noise_gen_state = None + history_latent_output = history_latent if video is not None: video = comfy.utils.common_upscale(video[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) - vid_latent = vae.encode(video[:, :, :, :3]) + num_frames = int(video.shape[0]) + min_frames = max(1, (int(num_latent_frames_per_chunk) - 1) * 4 + 1) + num_chunks = num_frames // min_frames + if num_chunks == 0: + raise ValueError( + f"Video must have at least {min_frames} frames (got {num_frames} frames). " + f"Required: (num_latent_frames_per_chunk - 1) * 4 + 1 = ({int(num_latent_frames_per_chunk)} - 1) * 4 + 1 = {min_frames}" + ) + + first_frame = video[:1] + first_frame_latent = vae.encode(first_frame[:, :, :, :3]).to(device=latent.device, dtype=torch.float32) + + total_valid_frames = num_chunks * min_frames + start_frame = num_frames - total_valid_frames + latents_chunks = [] + for i in range(num_chunks): + chunk_start = start_frame + i * min_frames + chunk_end = chunk_start + min_frames + video_chunk = video[chunk_start:chunk_end] + chunk_latents = vae.encode(video_chunk[:, :, :, :3]).to(device=latent.device, dtype=torch.float32) + latents_chunks.append(chunk_latents) + vid_latent = torch.cat(latents_chunks, dim=2) + vid_latent_clean = vid_latent.clone() + if add_noise_to_video_latents: g = torch.Generator(device=vid_latent.device) g.manual_seed(int(noise_seed)) - frame_sigmas = ( - torch.rand((1, 1, vid_latent.shape[2], 1, 1), device=vid_latent.device, generator=g, dtype=vid_latent.dtype) + + image_sigma = ( + torch.rand((1,), device=first_frame_latent.device, generator=g, dtype=first_frame_latent.dtype).view(1, 1, 1, 1, 1) * (float(video_noise_sigma_max) - float(video_noise_sigma_min)) + float(video_noise_sigma_min) ) - vid_latent = frame_sigmas * torch.randn_like(vid_latent, generator=g) + (1.0 - frame_sigmas) * vid_latent - vid_latent = vid_latent[:, :, :hist_len] + first_frame_latent = _apply_helios_latent_space_noise(first_frame_latent, image_sigma, generator=g) + + noisy_chunks = [] + num_latent_chunks = max(1, vid_latent.shape[2] // int(num_latent_frames_per_chunk)) + for i in range(num_latent_chunks): + chunk_start = i * int(num_latent_frames_per_chunk) + chunk_end = chunk_start + int(num_latent_frames_per_chunk) + latent_chunk = vid_latent[:, :, chunk_start:chunk_end, :, :] + if latent_chunk.shape[2] == 0: + continue + chunk_frames = latent_chunk.shape[2] + frame_sigmas = ( + torch.rand((chunk_frames,), device=vid_latent.device, generator=g, dtype=vid_latent.dtype) + * (float(video_noise_sigma_max) - float(video_noise_sigma_min)) + + float(video_noise_sigma_min) + ).view(1, 1, chunk_frames, 1, 1) + noisy_chunk = _apply_helios_latent_space_noise(latent_chunk, frame_sigmas, generator=g) + noisy_chunks.append(noisy_chunk) + if len(noisy_chunks) > 0: + vid_latent = torch.cat(noisy_chunks, dim=2) + noise_gen_state = g.get_state().clone() + if debug_latent_stats: + print(f"[HeliosDebug][V2V] first_frame_latent: {_tensor_stats_str(first_frame_latent)}") + print(f"[HeliosDebug][V2V] video_latent: {_tensor_stats_str(vid_latent)}") + vid_latent = comfy.utils.repeat_to_batch_size(vid_latent, batch_size) - if vid_latent.shape[2] < hist_len: - keep_frames = hist_len - vid_latent.shape[2] + image_latent_prefix = comfy.utils.repeat_to_batch_size(first_frame_latent, batch_size) + video_frames = vid_latent.shape[2] + if video_frames < hist_len: + keep_frames = hist_len - video_frames history_latent = torch.cat([history_latent[:, :, :keep_frames], vid_latent], dim=2) + history_latent_output = torch.cat([history_latent_output[:, :, :keep_frames], comfy.utils.repeat_to_batch_size(vid_latent_clean, batch_size)], dim=2) history_valid_mask[:, keep_frames:] = True else: - history_latent = vid_latent[:, :, -hist_len:] - history_valid_mask[:] = True - image_latent_prefix = history_latent[:, :, :1] + history_latent = vid_latent + history_latent_output = comfy.utils.repeat_to_batch_size(vid_latent_clean, batch_size) + history_valid_mask = torch.ones((batch_size, video_frames), device=latent.device, dtype=torch.bool) positive, negative = _set_helios_history_values(positive, negative, history_latent, sizes, keep_first_frame, prefix_latent=image_latent_prefix) return io.NodeOutput( @@ -781,8 +893,14 @@ class HeliosVideoToVideo(io.ComfyNode): { "samples": latent, "helios_history_latent": history_latent, + "helios_history_latent_output": history_latent_output, "helios_image_latent_prefix": image_latent_prefix, "helios_history_valid_mask": history_valid_mask, + "helios_num_frames": int(length), + "helios_noise_gen_state": noise_gen_state, + # Keep initial history segment and generated chunks together in sampler output. + "helios_include_history_in_output": _strict_bool(include_history_in_output, default=True), + "helios_debug_latent_stats": bool(debug_latent_stats), }, ) @@ -894,7 +1012,6 @@ class HeliosPyramidSampler(io.ComfyNode): stage_steps = [max(1, int(s)) for s in stage_steps] stage_count = len(stage_steps) history_sizes_list = sorted([max(0, int(v)) for v in _parse_int_list(history_sizes, [16, 2, 1])], reverse=True) - # Diffusers parity: if not keeping first frame, fold prefix slot into short history size. if not keep_first_frame and len(history_sizes_list) > 0: history_sizes_list[-1] += 1 @@ -912,21 +1029,32 @@ class HeliosPyramidSampler(io.ComfyNode): b, c, t, h, w = latent_samples.shape chunk_t = max(1, int(num_latent_frames_per_chunk)) - chunk_count = max(1, (t + chunk_t - 1) // chunk_t) + num_frames = int(latent.get("helios_num_frames", max(1, (int(t) - 1) * 4 + 1))) + window_num_frames = (chunk_t - 1) * 4 + 1 + chunk_count = max(1, (num_frames + window_num_frames - 1) // window_num_frames) euler_sampler = comfy.samplers.KSAMPLER(_helios_euler_sample) target_device = comfy.model_management.get_torch_device() noise_gen = torch.Generator(device=target_device) noise_gen.manual_seed(int(noise_seed)) + noise_gen_state = latent.get("helios_noise_gen_state", None) + if noise_gen_state is not None: + try: + noise_gen.set_state(noise_gen_state) + except Exception: + pass + debug_latent_stats = bool(latent.get("helios_debug_latent_stats", False)) image_latent_prefix = latent.get("helios_image_latent_prefix", None) history_valid_mask = latent.get("helios_history_valid_mask", None) if history_valid_mask is None: raise ValueError("Helios sampler requires `helios_history_valid_mask` in latent input.") + history_full = None history_from_latent_applied = False if image_latent_prefix is not None: image_latent_prefix = model.model.process_latent_in(image_latent_prefix) if "helios_history_latent" in latent: history_in = _process_latent_in_preserve_zero_frames(model, latent["helios_history_latent"], valid_mask=history_valid_mask) + history_full = history_in positive, negative = _set_helios_history_values( positive, negative, @@ -959,8 +1087,6 @@ class HeliosPyramidSampler(io.ComfyNode): x0_output = {} generated_chunks = [] if latents_history_short is not None and latents_history_mid is not None and latents_history_long is not None: - # Diffusers parity: `history_latents` storage does NOT include the keep_first_frame prefix slot. - # `latents_history_short` in conditioning may include [prefix + short_base], so strip prefix here. short_base_size = history_sizes_list[-1] if len(history_sizes_list) > 0 else latents_history_short.shape[2] if keep_first_frame and latents_history_short.shape[2] > short_base_size: short_for_history = latents_history_short[:, :, -short_base_size:] @@ -974,7 +1100,7 @@ class HeliosPyramidSampler(io.ComfyNode): hist_len = max(1, sum(history_sizes_list)) rolling_history = torch.zeros((b, c, hist_len, h, w), device=latent_samples.device, dtype=latent_samples.dtype) - # Align with Diffusers behavior: when initial video latents are provided, seed history buffer + # When initial video latents are provided, seed history buffer # with those latents before the first denoising chunk. if not add_noise: hist_len = max(1, sum(history_sizes_list)) @@ -988,9 +1114,29 @@ class HeliosPyramidSampler(io.ComfyNode): rolling_history = video_latents[:, :, -hist_len:] # Keep history/prefix on the same device/dtype as denoising latents. - rolling_history = rolling_history.to(device=target_device, dtype=latent_samples.dtype) + rolling_history = rolling_history.to(device=target_device, dtype=torch.float32) if image_latent_prefix is not None: - image_latent_prefix = image_latent_prefix.to(device=target_device, dtype=latent_samples.dtype) + image_latent_prefix = image_latent_prefix.to(device=target_device, dtype=torch.float32) + + history_output = history_full if history_full is not None else rolling_history + if "helios_history_latent_output" in latent: + history_output = _process_latent_in_preserve_zero_frames( + model, + latent["helios_history_latent_output"], + valid_mask=history_valid_mask, + ) + history_output = history_output.to(device=target_device, dtype=torch.float32) + if history_valid_mask is not None: + if not torch.is_tensor(history_valid_mask): + history_valid_mask = torch.tensor(history_valid_mask, device=target_device) + history_valid_mask = history_valid_mask.to(device=target_device) + if history_valid_mask.ndim == 2: + initial_generated_latent_frames = int(history_valid_mask.any(dim=0).sum().item()) + else: + initial_generated_latent_frames = int(history_valid_mask.reshape(-1).sum().item()) + else: + initial_generated_latent_frames = 0 + total_generated_latent_frames = initial_generated_latent_frames for chunk_idx in range(chunk_count): # Extract chunk from input latents @@ -1000,8 +1146,6 @@ class HeliosPyramidSampler(io.ComfyNode): # Prepare initial latent for this chunk if add_noise: - # Diffusers parity: each chunk denoises a fixed latent window size. - # Keep chunk temporal length constant and crop only after all chunks. noise_shape = ( latent_samples.shape[0], latent_samples.shape[1], @@ -1009,9 +1153,9 @@ class HeliosPyramidSampler(io.ComfyNode): latent_samples.shape[3], latent_samples.shape[4], ) - stage_latent = torch.randn(noise_shape, device=target_device, dtype=latent_samples.dtype, generator=noise_gen) + stage_latent = torch.randn(noise_shape, device=target_device, dtype=torch.float32, generator=noise_gen) else: - # Use actual input latents; pad final short chunk to fixed size like Diffusers windowing. + # Use actual input latents; pad final short chunk to fixed size. stage_latent = latent_chunk.clone() if stage_latent.shape[2] < chunk_t: if stage_latent.shape[2] == 0: @@ -1024,22 +1168,20 @@ class HeliosPyramidSampler(io.ComfyNode): latent_samples.shape[4], ), device=latent_samples.device, - dtype=latent_samples.dtype, + dtype=torch.float32, ) else: pad = stage_latent[:, :, -1:].repeat(1, 1, chunk_t - stage_latent.shape[2], 1, 1) stage_latent = torch.cat([stage_latent, pad], dim=2) + stage_latent = stage_latent.to(dtype=torch.float32) # Downsample to stage 0 resolution for _ in range(max(0, int(stage_count) - 1)): stage_latent = _downsample_latent_5d_bilinear_x2(stage_latent) - # Keep stage latents on model device for parity with Diffusers scheduler/noise path. + # Keep stage latents on model device for scheduler/noise path consistency. stage_latent = stage_latent.to(target_device) - # Diffusers parity: - # keep_first_frame=True and no image_latent_prefix on the first chunk - # should use an all-zero prefix frame, not history[:, :, :1]. chunk_prefix = image_latent_prefix if keep_first_frame and image_latent_prefix is None and chunk_idx == 0: chunk_prefix = torch.zeros( @@ -1065,6 +1207,10 @@ class HeliosPyramidSampler(io.ComfyNode): latents_history_short = _extract_condition_value(positive_chunk, "latents_history_short") latents_history_mid = _extract_condition_value(positive_chunk, "latents_history_mid") latents_history_long = _extract_condition_value(positive_chunk, "latents_history_long") + if debug_latent_stats: + print(f"[HeliosDebug][Sampler][chunk={chunk_idx}] latents_history_short: {_tensor_stats_str(latents_history_short)}") + print(f"[HeliosDebug][Sampler][chunk={chunk_idx}] latents_history_mid: {_tensor_stats_str(latents_history_mid)}") + print(f"[HeliosDebug][Sampler][chunk={chunk_idx}] latents_history_long: {_tensor_stats_str(latents_history_long)}") for stage_idx in range(stage_count): stage_latent = stage_latent.to(comfy.model_management.get_torch_device()) @@ -1099,8 +1245,7 @@ class HeliosPyramidSampler(io.ComfyNode): else: pass - # Keep parity with Diffusers pipeline order: - # stage timesteps are computed before upsampling/renoise for stage > 0. + # Stage timesteps are computed before upsampling/renoise for stage > 0. if stage_idx > 0: stage_latent = _upsample_latent_5d(stage_latent, scale=2) @@ -1188,8 +1333,7 @@ class HeliosPyramidSampler(io.ComfyNode): seed=noise_seed + chunk_idx * 100 + stage_idx, ) # sample_custom returns latent_format.process_out(samples); convert back to model-space - # so subsequent pyramid stages and history conditioning stay in the same latent space - # as Diffusers' internal denoising latents. + # so subsequent pyramid stages and history conditioning stay in the same latent space. stage_latent = model.model.process_latent_in(stage_latent) if stage_latent.shape[-2] != h or stage_latent.shape[-1] != w: @@ -1205,12 +1349,27 @@ class HeliosPyramidSampler(io.ComfyNode): rolling_history = torch.cat([rolling_history, stage_latent.to(rolling_history.device, rolling_history.dtype)], dim=2) keep_hist = max(1, sum(history_sizes_list)) rolling_history = rolling_history[:, :, -keep_hist:] + total_generated_latent_frames += stage_latent.shape[2] + history_output = torch.cat([history_output, stage_latent.to(history_output.device, history_output.dtype)], dim=2) - stage_latent = torch.cat(generated_chunks, dim=2)[:, :, :t] + include_history_in_output = _strict_bool(latent.get("helios_include_history_in_output", False), default=False) + if include_history_in_output and history_output is not None: + keep_t = max(0, int(total_generated_latent_frames)) + stage_latent = history_output[:, :, -keep_t:] if keep_t > 0 else history_output[:, :, :0] + elif len(generated_chunks) > 0: + stage_latent = torch.cat(generated_chunks, dim=2) + else: + stage_latent = torch.zeros((b, c, 0, h, w), device=target_device, dtype=torch.float32) out = latent.copy() out.pop("downscale_ratio_spacial", None) out["samples"] = model.model.process_latent_out(stage_latent) + out["helios_chunk_decode"] = True + out["helios_chunk_latent_frames"] = int(chunk_t) + out["helios_chunk_count"] = int(len(generated_chunks)) + out["helios_window_num_frames"] = int(window_num_frames) + out["helios_num_frames"] = int(num_frames) + out["helios_prefix_latent_frames"] = int(initial_generated_latent_frames if include_history_in_output else 0) if "x0" in x0_output: x0_out = model.model.process_latent_out(x0_output["x0"].cpu()) @@ -1222,6 +1381,60 @@ class HeliosPyramidSampler(io.ComfyNode): return io.NodeOutput(out, out_denoised) +class HeliosVAEDecode(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="HeliosVAEDecode", + category="latent", + inputs=[ + io.Latent.Input("samples"), + io.Vae.Input("vae"), + ], + outputs=[io.Image.Output(display_name="image")], + ) + + @classmethod + def execute(cls, samples, vae) -> io.NodeOutput: + latent = samples["samples"] + if latent.is_nested: + latent = latent.unbind()[0] + + helios_chunk_decode = bool(samples.get("helios_chunk_decode", False)) + helios_chunk_latent_frames = int(samples.get("helios_chunk_latent_frames", 0) or 0) + helios_prefix_latent_frames = int(samples.get("helios_prefix_latent_frames", 0) or 0) + + if ( + helios_chunk_decode + and latent.ndim == 5 + and helios_chunk_latent_frames > 0 + and latent.shape[2] > 0 + ): + decoded_chunks = [] + prefix_t = max(0, min(helios_prefix_latent_frames, latent.shape[2])) + + if prefix_t > 0: + decoded_chunks.append(vae.decode(latent[:, :, :prefix_t])) + + body = latent[:, :, prefix_t:] + for start in range(0, body.shape[2], helios_chunk_latent_frames): + chunk = body[:, :, start:start + helios_chunk_latent_frames] + if chunk.shape[2] == 0: + continue + decoded_chunks.append(vae.decode(chunk)) + + if len(decoded_chunks) > 0: + images = torch.cat(decoded_chunks, dim=1) + else: + images = vae.decode(latent) + else: + images = vae.decode(latent) + + if len(images.shape) == 5: + images = images.reshape(-1, images.shape[-3], images.shape[-2], images.shape[-1]) + return io.NodeOutput(images) + + class HeliosExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[io.ComfyNode]]: @@ -1231,6 +1444,7 @@ class HeliosExtension(ComfyExtension): HeliosVideoToVideo, HeliosHistoryConditioning, HeliosPyramidSampler, + HeliosVAEDecode, ] From c93de1268520f4044179252140c0e1fdcaa2f600 Mon Sep 17 00:00:00 2001 From: qqingzheng <2533221180@qq.com> Date: Sun, 8 Mar 2026 23:11:13 +0800 Subject: [PATCH 04/10] Remove debug latent stats functionality from HeliosImageToVideo and HeliosVideoToVideo classes --- comfy_extras/nodes_helios.py | 35 ----------------------------------- 1 file changed, 35 deletions(-) diff --git a/comfy_extras/nodes_helios.py b/comfy_extras/nodes_helios.py index 3d5f80e76..89a63c3fd 100644 --- a/comfy_extras/nodes_helios.py +++ b/comfy_extras/nodes_helios.py @@ -58,21 +58,6 @@ def _apply_helios_latent_space_noise(latent, sigma, generator=None): return _HELIOS_LATENT_FORMAT.process_out(noised_in).to(device=latent.device, dtype=latent.dtype) -def _tensor_stats_str(x): - if x is None: - return "None" - if not torch.is_tensor(x): - return f"non-tensor type={type(x)}" - if x.numel() == 0: - return f"shape={tuple(x.shape)} empty" - xf = x.detach().to(torch.float32) - return ( - f"shape={tuple(x.shape)} " - f"mean={xf.mean().item():.6f} std={xf.std(unbiased=False).item():.6f} " - f"min={xf.min().item():.6f} max={xf.max().item():.6f}" - ) - - def _parse_float_list(values, default): if values is None: return default @@ -557,7 +542,6 @@ class HeliosImageToVideo(io.ComfyNode): io.Float.Input("image_noise_sigma_max", default=0.135, min=0.0, max=1.0, step=0.0001, round=False, advanced=True), io.Int.Input("noise_seed", default=0, min=0, max=0xFFFFFFFFFFFFFFFF, advanced=True), io.Boolean.Input("include_history_in_output", default=False, advanced=True), - io.Boolean.Input("debug_latent_stats", default=False, advanced=True), ], outputs=[ io.Conditioning.Output(display_name="positive"), @@ -585,7 +569,6 @@ class HeliosImageToVideo(io.ComfyNode): image_noise_sigma_max=0.135, noise_seed=0, include_history_in_output=False, - debug_latent_stats=False, ) -> io.NodeOutput: video_noise_sigma_min = 0.111 video_noise_sigma_max = 0.135 @@ -643,10 +626,6 @@ class HeliosImageToVideo(io.ComfyNode): history_valid_mask[:, -1] = True if i2v_noise_gen is not None: noise_gen_state = i2v_noise_gen.get_state().clone() - if debug_latent_stats: - print(f"[HeliosDebug][I2V] image_latent_prefix: {_tensor_stats_str(image_latent_prefix)}") - print(f"[HeliosDebug][I2V] fake_latent: {_tensor_stats_str(fake_latent)}") - print(f"[HeliosDebug][I2V] history_latent: {_tensor_stats_str(history_latent)}") positive, negative = _set_helios_history_values(positive, negative, history_latent, sizes, keep_first_frame, prefix_latent=image_latent_prefix) return io.NodeOutput( @@ -660,7 +639,6 @@ class HeliosImageToVideo(io.ComfyNode): "helios_num_frames": int(length), "helios_noise_gen_state": noise_gen_state, "helios_include_history_in_output": _strict_bool(include_history_in_output, default=False), - "helios_debug_latent_stats": bool(debug_latent_stats), }, ) @@ -767,7 +745,6 @@ class HeliosVideoToVideo(io.ComfyNode): io.Float.Input("video_noise_sigma_max", default=0.135, min=0.0, max=1.0, step=0.0001, round=False, advanced=True), io.Int.Input("noise_seed", default=0, min=0, max=0xFFFFFFFFFFFFFFFF, advanced=True), io.Boolean.Input("include_history_in_output", default=True, advanced=True), - io.Boolean.Input("debug_latent_stats", default=False, advanced=True), ], outputs=[ io.Conditioning.Output(display_name="positive"), @@ -795,7 +772,6 @@ class HeliosVideoToVideo(io.ComfyNode): video_noise_sigma_max=0.135, noise_seed=0, include_history_in_output=True, - debug_latent_stats=False, ) -> io.NodeOutput: spacial_scale = vae.spacial_compression_encode() latent_channels = vae.latent_channels @@ -869,10 +845,6 @@ class HeliosVideoToVideo(io.ComfyNode): if len(noisy_chunks) > 0: vid_latent = torch.cat(noisy_chunks, dim=2) noise_gen_state = g.get_state().clone() - if debug_latent_stats: - print(f"[HeliosDebug][V2V] first_frame_latent: {_tensor_stats_str(first_frame_latent)}") - print(f"[HeliosDebug][V2V] video_latent: {_tensor_stats_str(vid_latent)}") - vid_latent = comfy.utils.repeat_to_batch_size(vid_latent, batch_size) image_latent_prefix = comfy.utils.repeat_to_batch_size(first_frame_latent, batch_size) video_frames = vid_latent.shape[2] @@ -900,7 +872,6 @@ class HeliosVideoToVideo(io.ComfyNode): "helios_noise_gen_state": noise_gen_state, # Keep initial history segment and generated chunks together in sampler output. "helios_include_history_in_output": _strict_bool(include_history_in_output, default=True), - "helios_debug_latent_stats": bool(debug_latent_stats), }, ) @@ -1042,7 +1013,6 @@ class HeliosPyramidSampler(io.ComfyNode): noise_gen.set_state(noise_gen_state) except Exception: pass - debug_latent_stats = bool(latent.get("helios_debug_latent_stats", False)) image_latent_prefix = latent.get("helios_image_latent_prefix", None) history_valid_mask = latent.get("helios_history_valid_mask", None) @@ -1207,11 +1177,6 @@ class HeliosPyramidSampler(io.ComfyNode): latents_history_short = _extract_condition_value(positive_chunk, "latents_history_short") latents_history_mid = _extract_condition_value(positive_chunk, "latents_history_mid") latents_history_long = _extract_condition_value(positive_chunk, "latents_history_long") - if debug_latent_stats: - print(f"[HeliosDebug][Sampler][chunk={chunk_idx}] latents_history_short: {_tensor_stats_str(latents_history_short)}") - print(f"[HeliosDebug][Sampler][chunk={chunk_idx}] latents_history_mid: {_tensor_stats_str(latents_history_mid)}") - print(f"[HeliosDebug][Sampler][chunk={chunk_idx}] latents_history_long: {_tensor_stats_str(latents_history_long)}") - for stage_idx in range(stage_count): stage_latent = stage_latent.to(comfy.model_management.get_torch_device()) sigmas = _helios_stage_sigmas( From b476d5e62b0c00e40e7e87b9a490d99d66661e4f Mon Sep 17 00:00:00 2001 From: qqingzheng <2533221180@qq.com> Date: Tue, 10 Mar 2026 15:36:37 +0800 Subject: [PATCH 05/10] Fix lint: whitespace and unused vars --- comfy/ldm/helios/model.py | 24 +++++++----------------- comfy_extras/nodes_helios.py | 1 - 2 files changed, 7 insertions(+), 18 deletions(-) diff --git a/comfy/ldm/helios/model.py b/comfy/ldm/helios/model.py index c1ea5f595..2faeea897 100644 --- a/comfy/ldm/helios/model.py +++ b/comfy/ldm/helios/model.py @@ -711,19 +711,9 @@ class HeliosModel(torch.nn.Module): ) f_long = self._rope_downsample_3d(f_long, (long_t, hs, ws), (4, 4, 4)) hidden_states = torch.cat([x_long, hidden_states], dim=1) - freqs = torch.cat([f_long, freqs], dim=1) + freqs = torch.cat([f_long, freqs], dim=1) history_context_length = hidden_states.shape[1] - original_context_length - mismatch = hidden_states.shape[1] != freqs.shape[1] - summary_key = ( - int(post_t), - int(post_h), - int(post_w), - int(original_context_length), - int(hidden_states.shape[1]), - int(freqs.shape[1]), - int(history_context_length), - ) if timestep.ndim == 0: timestep = timestep.unsqueeze(0) @@ -770,28 +760,28 @@ class HeliosModel(torch.nn.Module): def unpatchify(self, x, grid_sizes): """ Unpatchify the output from proj_out back to video format. - + Args: x: [batch, num_patches, out_dim * prod(patch_size)] grid_sizes: (num_frames, height, width) in patch space - + Returns: [batch, out_dim, num_frames, height, width] in pixel space """ b = x.shape[0] post_t, post_h, post_w = grid_sizes p_t, p_h, p_w = self.patch_size - + # Reshape: [B, T*H*W, out_dim*p_t*p_h*p_w] -> [B, T, H, W, p_t, p_h, p_w, out_dim] # Use -1 to let PyTorch infer the channel dimension (out_dim) hidden_states = x.reshape(b, post_t, post_h, post_w, p_t, p_h, p_w, -1) - + # Permute: [B, T, H, W, p_t, p_h, p_w, C] -> [B, C, T, p_t, H, p_h, W, p_w] hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6) - + # Flatten patches: [B, C, T, p_t, H, p_h, W, p_w] -> [B, C, T*p_t, H*p_h, W*p_w] output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3) - + return output def _rope_downsample_3d(self, freqs, grid_sizes, kernel_size): b, _, one, d, i2, j2 = freqs.shape diff --git a/comfy_extras/nodes_helios.py b/comfy_extras/nodes_helios.py index 89a63c3fd..2fcf0a64f 100644 --- a/comfy_extras/nodes_helios.py +++ b/comfy_extras/nodes_helios.py @@ -412,7 +412,6 @@ def _helios_dmd_sample( for i in range(len(sigmas) - 1): sigma = sigmas[i] - sigma_next = sigmas[i + 1] timestep = all_timesteps[i] if i < len(all_timesteps) else i denoised = model(x, sigma * s_in, **extra_args) From c25df83b8ad43e76601ae4698cb96c33dff7f109 Mon Sep 17 00:00:00 2001 From: qqingzheng <2533221180@qq.com> Date: Tue, 10 Mar 2026 19:11:59 +0800 Subject: [PATCH 06/10] Fix Helios norm2 fallback and history RoPE guards; simplify sampler knobs --- comfy/ldm/helios/model.py | 73 +++++++++++++++++++++--------------- comfy_extras/nodes_helios.py | 63 +++++-------------------------- 2 files changed, 52 insertions(+), 84 deletions(-) diff --git a/comfy/ldm/helios/model.py b/comfy/ldm/helios/model.py index 2faeea897..911d45831 100644 --- a/comfy/ldm/helios/model.py +++ b/comfy/ldm/helios/model.py @@ -228,13 +228,14 @@ class HeliosAttentionBlock(nn.Module): operation_settings=operation_settings, ) + self.cross_attn_norm = bool(cross_attn_norm) self.norm2 = (operation_settings.get("operations").LayerNorm( dim, eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"), - ) if cross_attn_norm else nn.Identity()) + ) if self.cross_attn_norm else nn.Identity()) self.attn2 = HeliosSelfAttention( dim, num_heads, @@ -309,14 +310,17 @@ class HeliosAttentionBlock(nn.Module): if self.guidance_cross_attn and original_context_length is not None: history_seq_len = x.shape[1] - original_context_length history_x, x_main = torch.split(x, [history_seq_len, original_context_length], dim=1) - # norm2 has elementwise_affine=True, manually do FP32LayerNorm behavior - norm_x_main = torch.nn.functional.layer_norm( - x_main.float(), - self.norm2.normalized_shape, - self.norm2.weight.to(x_main.device).float() if self.norm2.weight is not None else None, - self.norm2.bias.to(x_main.device).float() if self.norm2.bias is not None else None, - self.norm2.eps, - ).type_as(x_main) + if self.cross_attn_norm: + # norm2 has elementwise_affine=True, manually do FP32LayerNorm behavior + norm_x_main = torch.nn.functional.layer_norm( + x_main.float(), + self.norm2.normalized_shape, + self.norm2.weight.to(x_main.device).float() if self.norm2.weight is not None else None, + self.norm2.bias.to(x_main.device).float() if self.norm2.bias is not None else None, + self.norm2.eps, + ).type_as(x_main) + else: + norm_x_main = x_main x_main = x_main + self.attn2( norm_x_main, context=context, @@ -324,14 +328,17 @@ class HeliosAttentionBlock(nn.Module): ) x = torch.cat([history_x, x_main], dim=1) else: - # norm2 has elementwise_affine=True, manually do FP32LayerNorm behavior - norm_x = torch.nn.functional.layer_norm( - x.float(), - self.norm2.normalized_shape, - self.norm2.weight.to(x.device).float() if self.norm2.weight is not None else None, - self.norm2.bias.to(x.device).float() if self.norm2.bias is not None else None, - self.norm2.eps, - ).type_as(x) + if self.cross_attn_norm: + # norm2 has elementwise_affine=True, manually do FP32LayerNorm behavior + norm_x = torch.nn.functional.layer_norm( + x.float(), + self.norm2.normalized_shape, + self.norm2.weight.to(x.device).float() if self.norm2.weight is not None else None, + self.norm2.bias.to(x.device).float() if self.norm2.bias is not None else None, + self.norm2.eps, + ).type_as(x) + else: + norm_x = x x = x + self.attn2(norm_x, context=context, transformer_options=transformer_options) # ffn @@ -673,45 +680,51 @@ class HeliosModel(torch.nn.Module): if latents_history_mid is not None and indices_latents_history_mid is not None: x_mid = self.patch_mid(pad_for_3d_conv(latents_history_mid, (2, 4, 4))) - _, _, tm, _, _ = x_mid.shape + _, _, tm, hm, wm = x_mid.shape x_mid = x_mid.flatten(2).transpose(1, 2) mid_t = indices_latents_history_mid.shape[1] + # patch_mid downsamples by 2 in (t, h, w); build RoPE on the pre-downsample grid. + mid_h = hm * 2 + mid_w = wm * 2 f_mid = self.rope_encode( t=mid_t * self.patch_size[0], - h=hs * self.patch_size[1], - w=ws * self.patch_size[2], + h=mid_h * self.patch_size[1], + w=mid_w * self.patch_size[2], steps_t=mid_t, - steps_h=hs, - steps_w=ws, + steps_h=mid_h, + steps_w=mid_w, device=x_mid.device, dtype=x_mid.dtype, transformer_options=transformer_options, frame_indices=indices_latents_history_mid, ) - f_mid = self._rope_downsample_3d(f_mid, (mid_t, hs, ws), (2, 2, 2)) + f_mid = self._rope_downsample_3d(f_mid, (mid_t, mid_h, mid_w), (2, 2, 2)) hidden_states = torch.cat([x_mid, hidden_states], dim=1) freqs = torch.cat([f_mid, freqs], dim=1) if latents_history_long is not None and indices_latents_history_long is not None: x_long = self.patch_long(pad_for_3d_conv(latents_history_long, (4, 8, 8))) - _, _, tl, _, _ = x_long.shape + _, _, tl, hl, wl = x_long.shape x_long = x_long.flatten(2).transpose(1, 2) long_t = indices_latents_history_long.shape[1] + # patch_long downsamples by 4 in (t, h, w); build RoPE on the pre-downsample grid. + long_h = hl * 4 + long_w = wl * 4 f_long = self.rope_encode( t=long_t * self.patch_size[0], - h=hs * self.patch_size[1], - w=ws * self.patch_size[2], + h=long_h * self.patch_size[1], + w=long_w * self.patch_size[2], steps_t=long_t, - steps_h=hs, - steps_w=ws, + steps_h=long_h, + steps_w=long_w, device=x_long.device, dtype=x_long.dtype, transformer_options=transformer_options, frame_indices=indices_latents_history_long, ) - f_long = self._rope_downsample_3d(f_long, (long_t, hs, ws), (4, 4, 4)) + f_long = self._rope_downsample_3d(f_long, (long_t, long_h, long_w), (4, 4, 4)) hidden_states = torch.cat([x_long, hidden_states], dim=1) - freqs = torch.cat([f_long, freqs], dim=1) + freqs = torch.cat([f_long, freqs], dim=1) history_context_length = hidden_states.shape[1] - original_context_length diff --git a/comfy_extras/nodes_helios.py b/comfy_extras/nodes_helios.py index 2fcf0a64f..7f62b5003 100644 --- a/comfy_extras/nodes_helios.py +++ b/comfy_extras/nodes_helios.py @@ -914,7 +914,6 @@ class HeliosPyramidSampler(io.ComfyNode): category="sampling/video_models", inputs=[ io.Model.Input("model"), - io.Boolean.Input("add_noise", default=True, advanced=True), io.Int.Input("noise_seed", default=0, min=0, max=0xFFFFFFFFFFFFFFFF, control_after_generate=True), io.Float.Input("cfg", default=5.0, min=0.0, max=100.0, step=0.1, round=0.01), io.Conditioning.Input("positive"), @@ -931,7 +930,6 @@ class HeliosPyramidSampler(io.ComfyNode): io.Boolean.Input("cfg_zero_star", default=True, advanced=True), io.Boolean.Input("use_zero_init", default=True, advanced=True), io.Int.Input("zero_steps", default=1, min=0, max=10000, advanced=True), - io.Boolean.Input("skip_first_chunk", default=False, advanced=True), ], outputs=[ io.Latent.Output(display_name="output"), @@ -943,7 +941,6 @@ class HeliosPyramidSampler(io.ComfyNode): def execute( cls, model, - add_noise, noise_seed, cfg, positive, @@ -960,7 +957,6 @@ class HeliosPyramidSampler(io.ComfyNode): cfg_zero_star, use_zero_init, zero_steps, - skip_first_chunk, ) -> io.NodeOutput: # Keep these scheduler knobs internal (not exposed in node UI). shift = 1.0 @@ -975,8 +971,6 @@ class HeliosPyramidSampler(io.ComfyNode): latent = latent_image.copy() latent_samples = comfy.sample.fix_empty_latent_channels(model, latent["samples"], latent.get("downscale_ratio_spacial", None)) - if not add_noise: - latent_samples = _process_latent_in_preserve_zero_frames(model, latent_samples) stage_steps = _parse_int_list(pyramid_steps, [10, 10, 10]) stage_steps = [max(1, int(s)) for s in stage_steps] @@ -1069,19 +1063,6 @@ class HeliosPyramidSampler(io.ComfyNode): hist_len = max(1, sum(history_sizes_list)) rolling_history = torch.zeros((b, c, hist_len, h, w), device=latent_samples.device, dtype=latent_samples.dtype) - # When initial video latents are provided, seed history buffer - # with those latents before the first denoising chunk. - if not add_noise: - hist_len = max(1, sum(history_sizes_list)) - rolling_history = rolling_history.to(device=latent_samples.device, dtype=latent_samples.dtype) - video_latents = latent_samples - video_frames = video_latents.shape[2] - if video_frames < hist_len: - keep_frames = hist_len - video_frames - rolling_history = torch.cat([rolling_history[:, :, :keep_frames], video_latents], dim=2) - else: - rolling_history = video_latents[:, :, -hist_len:] - # Keep history/prefix on the same device/dtype as denoising latents. rolling_history = rolling_history.to(device=target_device, dtype=torch.float32) if image_latent_prefix is not None: @@ -1108,41 +1089,15 @@ class HeliosPyramidSampler(io.ComfyNode): total_generated_latent_frames = initial_generated_latent_frames for chunk_idx in range(chunk_count): - # Extract chunk from input latents - chunk_start = chunk_idx * chunk_t - chunk_end = min(chunk_start + chunk_t, t) - latent_chunk = latent_samples[:, :, chunk_start:chunk_end, :, :] - # Prepare initial latent for this chunk - if add_noise: - noise_shape = ( - latent_samples.shape[0], - latent_samples.shape[1], - chunk_t, - latent_samples.shape[3], - latent_samples.shape[4], - ) - stage_latent = torch.randn(noise_shape, device=target_device, dtype=torch.float32, generator=noise_gen) - else: - # Use actual input latents; pad final short chunk to fixed size. - stage_latent = latent_chunk.clone() - if stage_latent.shape[2] < chunk_t: - if stage_latent.shape[2] == 0: - stage_latent = torch.zeros( - ( - latent_samples.shape[0], - latent_samples.shape[1], - chunk_t, - latent_samples.shape[3], - latent_samples.shape[4], - ), - device=latent_samples.device, - dtype=torch.float32, - ) - else: - pad = stage_latent[:, :, -1:].repeat(1, 1, chunk_t - stage_latent.shape[2], 1, 1) - stage_latent = torch.cat([stage_latent, pad], dim=2) - stage_latent = stage_latent.to(dtype=torch.float32) + noise_shape = ( + latent_samples.shape[0], + latent_samples.shape[1], + chunk_t, + latent_samples.shape[3], + latent_samples.shape[4], + ) + stage_latent = torch.randn(noise_shape, device=target_device, dtype=torch.float32, generator=noise_gen) # Downsample to stage 0 resolution for _ in range(max(0, int(stage_count) - 1)): @@ -1308,7 +1263,7 @@ class HeliosPyramidSampler(io.ComfyNode): stage_latent = stage_latent[:, :, :, :h, :w] generated_chunks.append(stage_latent) - if keep_first_frame and ((chunk_idx == 0 and image_latent_prefix is None) or (skip_first_chunk and chunk_idx == 1)): + if keep_first_frame and (chunk_idx == 0 and image_latent_prefix is None): image_latent_prefix = stage_latent[:, :, :1] rolling_history = torch.cat([rolling_history, stage_latent.to(rolling_history.device, rolling_history.dtype)], dim=2) keep_hist = max(1, sum(history_sizes_list)) From 86c0755ee2bf8abd5dfe16500429592b23aa5861 Mon Sep 17 00:00:00 2001 From: qqingzheng <2533221180@qq.com> Date: Tue, 10 Mar 2026 19:54:05 +0800 Subject: [PATCH 07/10] Remove Helios include-history-in-output plumbing --- comfy_extras/nodes_helios.py | 46 +++--------------------------------- 1 file changed, 3 insertions(+), 43 deletions(-) diff --git a/comfy_extras/nodes_helios.py b/comfy_extras/nodes_helios.py index 7f62b5003..15feb7491 100644 --- a/comfy_extras/nodes_helios.py +++ b/comfy_extras/nodes_helios.py @@ -540,7 +540,6 @@ class HeliosImageToVideo(io.ComfyNode): io.Float.Input("image_noise_sigma_min", default=0.111, min=0.0, max=1.0, step=0.0001, round=False, advanced=True), io.Float.Input("image_noise_sigma_max", default=0.135, min=0.0, max=1.0, step=0.0001, round=False, advanced=True), io.Int.Input("noise_seed", default=0, min=0, max=0xFFFFFFFFFFFFFFFF, advanced=True), - io.Boolean.Input("include_history_in_output", default=False, advanced=True), ], outputs=[ io.Conditioning.Output(display_name="positive"), @@ -567,7 +566,6 @@ class HeliosImageToVideo(io.ComfyNode): image_noise_sigma_min=0.111, image_noise_sigma_max=0.135, noise_seed=0, - include_history_in_output=False, ) -> io.NodeOutput: video_noise_sigma_min = 0.111 video_noise_sigma_max = 0.135 @@ -637,7 +635,6 @@ class HeliosImageToVideo(io.ComfyNode): "helios_history_valid_mask": history_valid_mask, "helios_num_frames": int(length), "helios_noise_gen_state": noise_gen_state, - "helios_include_history_in_output": _strict_bool(include_history_in_output, default=False), }, ) @@ -743,7 +740,6 @@ class HeliosVideoToVideo(io.ComfyNode): io.Float.Input("video_noise_sigma_min", default=0.111, min=0.0, max=1.0, step=0.0001, round=False, advanced=True), io.Float.Input("video_noise_sigma_max", default=0.135, min=0.0, max=1.0, step=0.0001, round=False, advanced=True), io.Int.Input("noise_seed", default=0, min=0, max=0xFFFFFFFFFFFFFFFF, advanced=True), - io.Boolean.Input("include_history_in_output", default=True, advanced=True), ], outputs=[ io.Conditioning.Output(display_name="positive"), @@ -770,7 +766,6 @@ class HeliosVideoToVideo(io.ComfyNode): video_noise_sigma_min=0.111, video_noise_sigma_max=0.135, noise_seed=0, - include_history_in_output=True, ) -> io.NodeOutput: spacial_scale = vae.spacial_compression_encode() latent_channels = vae.latent_channels @@ -869,8 +864,6 @@ class HeliosVideoToVideo(io.ComfyNode): "helios_history_valid_mask": history_valid_mask, "helios_num_frames": int(length), "helios_noise_gen_state": noise_gen_state, - # Keep initial history segment and generated chunks together in sampler output. - "helios_include_history_in_output": _strict_bool(include_history_in_output, default=True), }, ) @@ -1011,13 +1004,11 @@ class HeliosPyramidSampler(io.ComfyNode): history_valid_mask = latent.get("helios_history_valid_mask", None) if history_valid_mask is None: raise ValueError("Helios sampler requires `helios_history_valid_mask` in latent input.") - history_full = None history_from_latent_applied = False if image_latent_prefix is not None: image_latent_prefix = model.model.process_latent_in(image_latent_prefix) if "helios_history_latent" in latent: history_in = _process_latent_in_preserve_zero_frames(model, latent["helios_history_latent"], valid_mask=history_valid_mask) - history_full = history_in positive, negative = _set_helios_history_values( positive, negative, @@ -1068,25 +1059,7 @@ class HeliosPyramidSampler(io.ComfyNode): if image_latent_prefix is not None: image_latent_prefix = image_latent_prefix.to(device=target_device, dtype=torch.float32) - history_output = history_full if history_full is not None else rolling_history - if "helios_history_latent_output" in latent: - history_output = _process_latent_in_preserve_zero_frames( - model, - latent["helios_history_latent_output"], - valid_mask=history_valid_mask, - ) - history_output = history_output.to(device=target_device, dtype=torch.float32) - if history_valid_mask is not None: - if not torch.is_tensor(history_valid_mask): - history_valid_mask = torch.tensor(history_valid_mask, device=target_device) - history_valid_mask = history_valid_mask.to(device=target_device) - if history_valid_mask.ndim == 2: - initial_generated_latent_frames = int(history_valid_mask.any(dim=0).sum().item()) - else: - initial_generated_latent_frames = int(history_valid_mask.reshape(-1).sum().item()) - else: - initial_generated_latent_frames = 0 - total_generated_latent_frames = initial_generated_latent_frames + # Always return only newly generated chunks; input history is used only for conditioning. for chunk_idx in range(chunk_count): # Prepare initial latent for this chunk @@ -1268,14 +1241,8 @@ class HeliosPyramidSampler(io.ComfyNode): rolling_history = torch.cat([rolling_history, stage_latent.to(rolling_history.device, rolling_history.dtype)], dim=2) keep_hist = max(1, sum(history_sizes_list)) rolling_history = rolling_history[:, :, -keep_hist:] - total_generated_latent_frames += stage_latent.shape[2] - history_output = torch.cat([history_output, stage_latent.to(history_output.device, history_output.dtype)], dim=2) - include_history_in_output = _strict_bool(latent.get("helios_include_history_in_output", False), default=False) - if include_history_in_output and history_output is not None: - keep_t = max(0, int(total_generated_latent_frames)) - stage_latent = history_output[:, :, -keep_t:] if keep_t > 0 else history_output[:, :, :0] - elif len(generated_chunks) > 0: + if len(generated_chunks) > 0: stage_latent = torch.cat(generated_chunks, dim=2) else: stage_latent = torch.zeros((b, c, 0, h, w), device=target_device, dtype=torch.float32) @@ -1288,7 +1255,6 @@ class HeliosPyramidSampler(io.ComfyNode): out["helios_chunk_count"] = int(len(generated_chunks)) out["helios_window_num_frames"] = int(window_num_frames) out["helios_num_frames"] = int(num_frames) - out["helios_prefix_latent_frames"] = int(initial_generated_latent_frames if include_history_in_output else 0) if "x0" in x0_output: x0_out = model.model.process_latent_out(x0_output["x0"].cpu()) @@ -1321,7 +1287,6 @@ class HeliosVAEDecode(io.ComfyNode): helios_chunk_decode = bool(samples.get("helios_chunk_decode", False)) helios_chunk_latent_frames = int(samples.get("helios_chunk_latent_frames", 0) or 0) - helios_prefix_latent_frames = int(samples.get("helios_prefix_latent_frames", 0) or 0) if ( helios_chunk_decode @@ -1330,12 +1295,7 @@ class HeliosVAEDecode(io.ComfyNode): and latent.shape[2] > 0 ): decoded_chunks = [] - prefix_t = max(0, min(helios_prefix_latent_frames, latent.shape[2])) - - if prefix_t > 0: - decoded_chunks.append(vae.decode(latent[:, :, :prefix_t])) - - body = latent[:, :, prefix_t:] + body = latent for start in range(0, body.shape[2], helios_chunk_latent_frames): chunk = body[:, :, start:start + helios_chunk_latent_frames] if chunk.shape[2] == 0: From c874ed00c8ce539cec43c34f3622266e9a297343 Mon Sep 17 00:00:00 2001 From: qqingzheng <2533221180@qq.com> Date: Tue, 10 Mar 2026 20:22:41 +0800 Subject: [PATCH 08/10] Remove HeliosHistoryConditioning and sampler denoised output --- comfy_extras/nodes_helios.py | 42 +----------------------------------- 1 file changed, 1 insertion(+), 41 deletions(-) diff --git a/comfy_extras/nodes_helios.py b/comfy_extras/nodes_helios.py index 15feb7491..ca4373c66 100644 --- a/comfy_extras/nodes_helios.py +++ b/comfy_extras/nodes_helios.py @@ -868,37 +868,6 @@ class HeliosVideoToVideo(io.ComfyNode): ) -class HeliosHistoryConditioning(io.ComfyNode): - @classmethod - def define_schema(cls): - return io.Schema( - node_id="HeliosHistoryConditioning", - category="conditioning/video_models", - inputs=[ - io.Conditioning.Input("positive"), - io.Conditioning.Input("negative"), - io.Latent.Input("history_latent"), - io.String.Input("history_sizes", default="16,2,1"), - io.Boolean.Input("keep_first_frame", default=True), - ], - outputs=[ - io.Conditioning.Output(display_name="positive"), - io.Conditioning.Output(display_name="negative"), - ], - ) - - @classmethod - def execute(cls, positive, negative, history_latent, history_sizes, keep_first_frame) -> io.NodeOutput: - latent = history_latent["samples"] - if latent is None or len(latent.shape) != 5: - return io.NodeOutput(positive, negative) - sizes = _parse_int_list(history_sizes, [16, 2, 1]) - sizes = sorted([max(0, int(v)) for v in sizes], reverse=True) - prefix = history_latent.get("helios_image_latent_prefix", None) - positive, negative = _set_helios_history_values(positive, negative, latent, sizes, keep_first_frame, prefix_latent=prefix) - return io.NodeOutput(positive, negative) - - class HeliosPyramidSampler(io.ComfyNode): @classmethod def define_schema(cls): @@ -926,7 +895,6 @@ class HeliosPyramidSampler(io.ComfyNode): ], outputs=[ io.Latent.Output(display_name="output"), - io.Latent.Output(display_name="denoised_output"), ], ) @@ -1256,14 +1224,7 @@ class HeliosPyramidSampler(io.ComfyNode): out["helios_window_num_frames"] = int(window_num_frames) out["helios_num_frames"] = int(num_frames) - if "x0" in x0_output: - x0_out = model.model.process_latent_out(x0_output["x0"].cpu()) - out_denoised = latent.copy() - out_denoised["samples"] = x0_out - else: - out_denoised = out - - return io.NodeOutput(out, out_denoised) + return io.NodeOutput(out) class HeliosVAEDecode(io.ComfyNode): @@ -1321,7 +1282,6 @@ class HeliosExtension(ComfyExtension): HeliosTextToVideo, HeliosImageToVideo, HeliosVideoToVideo, - HeliosHistoryConditioning, HeliosPyramidSampler, HeliosVAEDecode, ] From a5c328871d38f2332bdf173c0ec646b1348b296d Mon Sep 17 00:00:00 2001 From: qqingzheng <2533221180@qq.com> Date: Thu, 12 Mar 2026 15:45:06 +0800 Subject: [PATCH 09/10] Refactor Helios to reuse WAN text encoder, latent format, and VAE --- comfy/latent_formats.py | 8 ------- comfy/sd.py | 1 - comfy/supported_models.py | 14 ++++++++---- comfy/text_encoders/helios.py | 41 ----------------------------------- comfy_extras/nodes_helios.py | 2 +- 5 files changed, 11 insertions(+), 55 deletions(-) delete mode 100644 comfy/text_encoders/helios.py diff --git a/comfy/latent_formats.py b/comfy/latent_formats.py index 91db60ab5..6a57bca1c 100644 --- a/comfy/latent_formats.py +++ b/comfy/latent_formats.py @@ -783,11 +783,3 @@ class ZImagePixelSpace(ChromaRadiance): No VAE encoding/decoding — the model operates directly on RGB pixels. """ pass - -class Helios(Wan21): - """Helios video model latent format - - Helios uses the same latent format as Wan21 (same VAE architecture). - Inherits latents_mean, latents_std, and processing methods from Wan21. - """ - pass diff --git a/comfy/sd.py b/comfy/sd.py index cb442439d..3f8eabb46 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -48,7 +48,6 @@ import comfy.text_encoders.hunyuan_video import comfy.text_encoders.cosmos import comfy.text_encoders.lumina2 import comfy.text_encoders.wan -import comfy.text_encoders.helios import comfy.text_encoders.hidream import comfy.text_encoders.ace import comfy.text_encoders.omnigen2 diff --git a/comfy/supported_models.py b/comfy/supported_models.py index b0fb3ce3d..5f58e0a9f 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -17,7 +17,6 @@ import comfy.text_encoders.hunyuan_video import comfy.text_encoders.cosmos import comfy.text_encoders.lumina2 import comfy.text_encoders.wan -import comfy.text_encoders.helios import comfy.text_encoders.ace import comfy.text_encoders.omnigen2 import comfy.text_encoders.qwen_image @@ -1143,7 +1142,7 @@ class Helios(supported_models_base.BASE): } unet_extra_config = {} - latent_format = latent_formats.Helios + latent_format = latent_formats.Wan21 memory_usage_factor = 1.8 supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32] @@ -1159,8 +1158,15 @@ class Helios(supported_models_base.BASE): def clip_target(self, state_dict={}): pref = self.text_encoder_key_prefix[0] - t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}umt5xxl.transformer.".format(pref)) - return supported_models_base.ClipTarget(comfy.text_encoders.helios.HeliosT5Tokenizer, comfy.text_encoders.helios.te(**t5_detect)) + t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect( + state_dict, + "{}umt5xxl.transformer.".format(pref), + ) + # Directly reuse WAN text encoder stack; no Helios-specific TE. + return supported_models_base.ClipTarget( + comfy.text_encoders.wan.WanT5Tokenizer, + comfy.text_encoders.wan.te(**t5_detect), + ) class WAN21_T2V(supported_models_base.BASE): unet_config = { diff --git a/comfy/text_encoders/helios.py b/comfy/text_encoders/helios.py deleted file mode 100644 index dc4b38b13..000000000 --- a/comfy/text_encoders/helios.py +++ /dev/null @@ -1,41 +0,0 @@ -from comfy import sd1_clip -from .spiece_tokenizer import SPieceTokenizer -import comfy.text_encoders.t5 -import os - - -class UMT5XXlModel(sd1_clip.SDClipModel): - def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, model_options={}): - textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "umt5_config_xxl.json") - super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=True, zero_out_masked=True, model_options=model_options) - - -class UMT5XXlTokenizer(sd1_clip.SDTokenizer): - def __init__(self, embedding_directory=None, tokenizer_data={}): - tokenizer = tokenizer_data.get("spiece_model", None) - super().__init__(tokenizer, pad_with_end=False, embedding_size=4096, embedding_key="umt5xxl", tokenizer_class=SPieceTokenizer, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=512, pad_token=0, tokenizer_data=tokenizer_data) - - def state_dict(self): - return {"spiece_model": self.tokenizer.serialize_model()} - - -class HeliosT5Tokenizer(sd1_clip.SD1Tokenizer): - def __init__(self, embedding_directory=None, tokenizer_data={}): - super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="umt5xxl", tokenizer=UMT5XXlTokenizer) - - -class HeliosT5Model(sd1_clip.SD1ClipModel): - def __init__(self, device="cpu", dtype=None, model_options={}, **kwargs): - super().__init__(device=device, dtype=dtype, model_options=model_options, name="umt5xxl", clip_model=UMT5XXlModel, **kwargs) - - -def te(dtype_t5=None, t5_quantization_metadata=None): - class HeliosTEModel(HeliosT5Model): - def __init__(self, device="cpu", dtype=None, model_options={}): - if t5_quantization_metadata is not None: - model_options = model_options.copy() - model_options["quantization_metadata"] = t5_quantization_metadata - if dtype_t5 is not None: - dtype = dtype_t5 - super().__init__(device=device, dtype=dtype, model_options=model_options) - return HeliosTEModel diff --git a/comfy_extras/nodes_helios.py b/comfy_extras/nodes_helios.py index ca4373c66..2a356be68 100644 --- a/comfy_extras/nodes_helios.py +++ b/comfy_extras/nodes_helios.py @@ -42,7 +42,7 @@ def _parse_int_list(values, default): return out if len(out) > 0 else default -_HELIOS_LATENT_FORMAT = comfy.latent_formats.Helios() +_HELIOS_LATENT_FORMAT = comfy.latent_formats.Wan21() def _apply_helios_latent_space_noise(latent, sigma, generator=None): From 9b6e0b067711e592441412da6771366e1f4756fd Mon Sep 17 00:00:00 2001 From: qqingzheng <2533221180@qq.com> Date: Thu, 12 Mar 2026 22:36:11 +0800 Subject: [PATCH 10/10] Reuse wan text encoder --- comfy/sd.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/comfy/sd.py b/comfy/sd.py index 3f8eabb46..4f7530c50 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -1338,8 +1338,9 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip clip_target.tokenizer = comfy.text_encoders.wan.WanT5Tokenizer tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None) elif clip_type == CLIPType.HELIOS: - clip_target.clip = comfy.text_encoders.helios.te(**t5xxl_detect(clip_data)) - clip_target.tokenizer = comfy.text_encoders.helios.HeliosT5Tokenizer + # Helios reuses the WAN UMT5-XXL text encoder stack. + clip_target.clip = comfy.text_encoders.wan.te(**t5xxl_detect(clip_data)) + clip_target.tokenizer = comfy.text_encoders.wan.WanT5Tokenizer tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None) elif clip_type == CLIPType.HIDREAM: clip_target.clip = comfy.text_encoders.hidream.hidream_clip(**t5xxl_detect(clip_data),