From c25df83b8ad43e76601ae4698cb96c33dff7f109 Mon Sep 17 00:00:00 2001
From: qqingzheng <2533221180@qq.com>
Date: Tue, 10 Mar 2026 19:11:59 +0800
Subject: [PATCH] Fix Helios norm2 fallback and history RoPE guards; simplify sampler knobs

---
 comfy/ldm/helios/model.py    | 73 +++++++++++++++++++++---------------
 comfy_extras/nodes_helios.py | 63 +++++-------------------------
 2 files changed, 52 insertions(+), 84 deletions(-)

diff --git a/comfy/ldm/helios/model.py b/comfy/ldm/helios/model.py
index 2faeea897..911d45831 100644
--- a/comfy/ldm/helios/model.py
+++ b/comfy/ldm/helios/model.py
@@ -228,13 +228,14 @@ class HeliosAttentionBlock(nn.Module):
             operation_settings=operation_settings,
         )
 
+        self.cross_attn_norm = bool(cross_attn_norm)
         self.norm2 = (operation_settings.get("operations").LayerNorm(
             dim,
             eps,
             elementwise_affine=True,
             device=operation_settings.get("device"),
             dtype=operation_settings.get("dtype"),
-        ) if cross_attn_norm else nn.Identity())
+        ) if self.cross_attn_norm else nn.Identity())
         self.attn2 = HeliosSelfAttention(
             dim,
             num_heads,
@@ -309,14 +310,17 @@
         if self.guidance_cross_attn and original_context_length is not None:
             history_seq_len = x.shape[1] - original_context_length
             history_x, x_main = torch.split(x, [history_seq_len, original_context_length], dim=1)
-            # norm2 has elementwise_affine=True, manually do FP32LayerNorm behavior
-            norm_x_main = torch.nn.functional.layer_norm(
-                x_main.float(),
-                self.norm2.normalized_shape,
-                self.norm2.weight.to(x_main.device).float() if self.norm2.weight is not None else None,
-                self.norm2.bias.to(x_main.device).float() if self.norm2.bias is not None else None,
-                self.norm2.eps,
-            ).type_as(x_main)
+            if self.cross_attn_norm:
+                # norm2 has elementwise_affine=True, manually do FP32LayerNorm behavior
+                norm_x_main = torch.nn.functional.layer_norm(
+                    x_main.float(),
+                    self.norm2.normalized_shape,
+                    self.norm2.weight.to(x_main.device).float() if self.norm2.weight is not None else None,
+                    self.norm2.bias.to(x_main.device).float() if self.norm2.bias is not None else None,
+                    self.norm2.eps,
+                ).type_as(x_main)
+            else:
+                norm_x_main = x_main
             x_main = x_main + self.attn2(
                 norm_x_main,
                 context=context,
@@ -324,14 +328,17 @@
             )
             x = torch.cat([history_x, x_main], dim=1)
         else:
-            # norm2 has elementwise_affine=True, manually do FP32LayerNorm behavior
-            norm_x = torch.nn.functional.layer_norm(
-                x.float(),
-                self.norm2.normalized_shape,
-                self.norm2.weight.to(x.device).float() if self.norm2.weight is not None else None,
-                self.norm2.bias.to(x.device).float() if self.norm2.bias is not None else None,
-                self.norm2.eps,
-            ).type_as(x)
+            if self.cross_attn_norm:
+                # norm2 has elementwise_affine=True, manually do FP32LayerNorm behavior
+                norm_x = torch.nn.functional.layer_norm(
+                    x.float(),
+                    self.norm2.normalized_shape,
+                    self.norm2.weight.to(x.device).float() if self.norm2.weight is not None else None,
+                    self.norm2.bias.to(x.device).float() if self.norm2.bias is not None else None,
+                    self.norm2.eps,
+                ).type_as(x)
+            else:
+                norm_x = x
             x = x + self.attn2(norm_x, context=context, transformer_options=transformer_options)
 
         # ffn
@@ -673,45 +680,51 @@ class HeliosModel(torch.nn.Module):
 
         if latents_history_mid is not None and indices_latents_history_mid is not None:
             x_mid = self.patch_mid(pad_for_3d_conv(latents_history_mid, (2, 4, 4)))
-            _, _, tm, _, _ = x_mid.shape
+            _, _, tm, hm, wm = x_mid.shape
             x_mid = x_mid.flatten(2).transpose(1, 2)
             mid_t = indices_latents_history_mid.shape[1]
+            # patch_mid downsamples by 2 in (t, h, w); build RoPE on the pre-downsample grid.
+            mid_h = hm * 2
+            mid_w = wm * 2
             f_mid = self.rope_encode(
                 t=mid_t * self.patch_size[0],
-                h=hs * self.patch_size[1],
-                w=ws * self.patch_size[2],
+                h=mid_h * self.patch_size[1],
+                w=mid_w * self.patch_size[2],
                 steps_t=mid_t,
-                steps_h=hs,
-                steps_w=ws,
+                steps_h=mid_h,
+                steps_w=mid_w,
                 device=x_mid.device,
                 dtype=x_mid.dtype,
                 transformer_options=transformer_options,
                 frame_indices=indices_latents_history_mid,
             )
-            f_mid = self._rope_downsample_3d(f_mid, (mid_t, hs, ws), (2, 2, 2))
+            f_mid = self._rope_downsample_3d(f_mid, (mid_t, mid_h, mid_w), (2, 2, 2))
             hidden_states = torch.cat([x_mid, hidden_states], dim=1)
             freqs = torch.cat([f_mid, freqs], dim=1)
 
         if latents_history_long is not None and indices_latents_history_long is not None:
             x_long = self.patch_long(pad_for_3d_conv(latents_history_long, (4, 8, 8)))
-            _, _, tl, _, _ = x_long.shape
+            _, _, tl, hl, wl = x_long.shape
             x_long = x_long.flatten(2).transpose(1, 2)
             long_t = indices_latents_history_long.shape[1]
+            # patch_long downsamples by 4 in (t, h, w); build RoPE on the pre-downsample grid.
+            long_h = hl * 4
+            long_w = wl * 4
             f_long = self.rope_encode(
                 t=long_t * self.patch_size[0],
-                h=hs * self.patch_size[1],
-                w=ws * self.patch_size[2],
+                h=long_h * self.patch_size[1],
+                w=long_w * self.patch_size[2],
                 steps_t=long_t,
-                steps_h=hs,
-                steps_w=ws,
+                steps_h=long_h,
+                steps_w=long_w,
                 device=x_long.device,
                 dtype=x_long.dtype,
                 transformer_options=transformer_options,
                 frame_indices=indices_latents_history_long,
            )
-            f_long = self._rope_downsample_3d(f_long, (long_t, hs, ws), (4, 4, 4))
+            f_long = self._rope_downsample_3d(f_long, (long_t, long_h, long_w), (4, 4, 4))
             hidden_states = torch.cat([x_long, hidden_states], dim=1)
-            freqs = torch.cat([f_long, freqs], dim=1)
+            freqs = torch.cat([f_long, freqs], dim=1)
 
         history_context_length = hidden_states.shape[1] - original_context_length
diff --git a/comfy_extras/nodes_helios.py b/comfy_extras/nodes_helios.py
index 2fcf0a64f..7f62b5003 100644
--- a/comfy_extras/nodes_helios.py
+++ b/comfy_extras/nodes_helios.py
@@ -914,7 +914,6 @@ class HeliosPyramidSampler(io.ComfyNode):
             category="sampling/video_models",
             inputs=[
                 io.Model.Input("model"),
-                io.Boolean.Input("add_noise", default=True, advanced=True),
                 io.Int.Input("noise_seed", default=0, min=0, max=0xFFFFFFFFFFFFFFFF, control_after_generate=True),
                 io.Float.Input("cfg", default=5.0, min=0.0, max=100.0, step=0.1, round=0.01),
                 io.Conditioning.Input("positive"),
@@ -931,7 +930,6 @@
                 io.Boolean.Input("cfg_zero_star", default=True, advanced=True),
                 io.Boolean.Input("use_zero_init", default=True, advanced=True),
                 io.Int.Input("zero_steps", default=1, min=0, max=10000, advanced=True),
-                io.Boolean.Input("skip_first_chunk", default=False, advanced=True),
             ],
             outputs=[
                 io.Latent.Output(display_name="output"),
@@ -943,7 +941,6 @@
     def execute(
         cls,
         model,
-        add_noise,
         noise_seed,
         cfg,
         positive,
@@ -960,7 +957,6 @@
         cfg_zero_star,
         use_zero_init,
         zero_steps,
-        skip_first_chunk,
     ) -> io.NodeOutput:
         # Keep these scheduler knobs internal (not exposed in node UI).
         shift = 1.0
@@ -975,8 +971,6 @@
 
         latent = latent_image.copy()
         latent_samples = comfy.sample.fix_empty_latent_channels(model, latent["samples"], latent.get("downscale_ratio_spacial", None))
-        if not add_noise:
-            latent_samples = _process_latent_in_preserve_zero_frames(model, latent_samples)
 
         stage_steps = _parse_int_list(pyramid_steps, [10, 10, 10])
         stage_steps = [max(1, int(s)) for s in stage_steps]
@@ -1069,19 +1063,6 @@
         hist_len = max(1, sum(history_sizes_list))
         rolling_history = torch.zeros((b, c, hist_len, h, w), device=latent_samples.device, dtype=latent_samples.dtype)
-        # When initial video latents are provided, seed history buffer
-        # with those latents before the first denoising chunk.
-        if not add_noise:
-            hist_len = max(1, sum(history_sizes_list))
-            rolling_history = rolling_history.to(device=latent_samples.device, dtype=latent_samples.dtype)
-            video_latents = latent_samples
-            video_frames = video_latents.shape[2]
-            if video_frames < hist_len:
-                keep_frames = hist_len - video_frames
-                rolling_history = torch.cat([rolling_history[:, :, :keep_frames], video_latents], dim=2)
-            else:
-                rolling_history = video_latents[:, :, -hist_len:]
-
         # Keep history/prefix on the same device/dtype as denoising latents.
         rolling_history = rolling_history.to(device=target_device, dtype=torch.float32)
 
         if image_latent_prefix is not None:
@@ -1108,41 +1089,15 @@
         total_generated_latent_frames = initial_generated_latent_frames
 
         for chunk_idx in range(chunk_count):
-            # Extract chunk from input latents
-            chunk_start = chunk_idx * chunk_t
-            chunk_end = min(chunk_start + chunk_t, t)
-            latent_chunk = latent_samples[:, :, chunk_start:chunk_end, :, :]
-
             # Prepare initial latent for this chunk
-            if add_noise:
-                noise_shape = (
-                    latent_samples.shape[0],
-                    latent_samples.shape[1],
-                    chunk_t,
-                    latent_samples.shape[3],
-                    latent_samples.shape[4],
-                )
-                stage_latent = torch.randn(noise_shape, device=target_device, dtype=torch.float32, generator=noise_gen)
-            else:
-                # Use actual input latents; pad final short chunk to fixed size.
-                stage_latent = latent_chunk.clone()
-                if stage_latent.shape[2] < chunk_t:
-                    if stage_latent.shape[2] == 0:
-                        stage_latent = torch.zeros(
-                            (
-                                latent_samples.shape[0],
-                                latent_samples.shape[1],
-                                chunk_t,
-                                latent_samples.shape[3],
-                                latent_samples.shape[4],
-                            ),
-                            device=latent_samples.device,
-                            dtype=torch.float32,
-                        )
-                    else:
-                        pad = stage_latent[:, :, -1:].repeat(1, 1, chunk_t - stage_latent.shape[2], 1, 1)
-                        stage_latent = torch.cat([stage_latent, pad], dim=2)
-                stage_latent = stage_latent.to(dtype=torch.float32)
+            noise_shape = (
+                latent_samples.shape[0],
+                latent_samples.shape[1],
+                chunk_t,
+                latent_samples.shape[3],
+                latent_samples.shape[4],
+            )
+            stage_latent = torch.randn(noise_shape, device=target_device, dtype=torch.float32, generator=noise_gen)
 
             # Downsample to stage 0 resolution
             for _ in range(max(0, int(stage_count) - 1)):
@@ -1308,7 +1263,7 @@
                 stage_latent = stage_latent[:, :, :, :h, :w]
             generated_chunks.append(stage_latent)
 
-            if keep_first_frame and ((chunk_idx == 0 and image_latent_prefix is None) or (skip_first_chunk and chunk_idx == 1)):
+            if keep_first_frame and (chunk_idx == 0 and image_latent_prefix is None):
                 image_latent_prefix = stage_latent[:, :, :1]
             rolling_history = torch.cat([rolling_history, stage_latent.to(rolling_history.device, rolling_history.dtype)], dim=2)
             keep_hist = max(1, sum(history_sizes_list))