Fix Helios norm2 fallback and history RoPE guards; simplify sampler knobs

This commit is contained in:
qqingzheng 2026-03-10 19:11:59 +08:00
parent ddd4030c06
commit c25df83b8a
2 changed files with 52 additions and 84 deletions

View File

@ -228,13 +228,14 @@ class HeliosAttentionBlock(nn.Module):
operation_settings=operation_settings,
)
self.cross_attn_norm = bool(cross_attn_norm)
self.norm2 = (operation_settings.get("operations").LayerNorm(
dim,
eps,
elementwise_affine=True,
device=operation_settings.get("device"),
dtype=operation_settings.get("dtype"),
) if cross_attn_norm else nn.Identity())
) if self.cross_attn_norm else nn.Identity())
self.attn2 = HeliosSelfAttention(
dim,
num_heads,
@ -309,14 +310,17 @@ class HeliosAttentionBlock(nn.Module):
if self.guidance_cross_attn and original_context_length is not None:
history_seq_len = x.shape[1] - original_context_length
history_x, x_main = torch.split(x, [history_seq_len, original_context_length], dim=1)
# norm2 has elementwise_affine=True, manually do FP32LayerNorm behavior
norm_x_main = torch.nn.functional.layer_norm(
x_main.float(),
self.norm2.normalized_shape,
self.norm2.weight.to(x_main.device).float() if self.norm2.weight is not None else None,
self.norm2.bias.to(x_main.device).float() if self.norm2.bias is not None else None,
self.norm2.eps,
).type_as(x_main)
if self.cross_attn_norm:
# norm2 has elementwise_affine=True, manually do FP32LayerNorm behavior
norm_x_main = torch.nn.functional.layer_norm(
x_main.float(),
self.norm2.normalized_shape,
self.norm2.weight.to(x_main.device).float() if self.norm2.weight is not None else None,
self.norm2.bias.to(x_main.device).float() if self.norm2.bias is not None else None,
self.norm2.eps,
).type_as(x_main)
else:
norm_x_main = x_main
x_main = x_main + self.attn2(
norm_x_main,
context=context,
@ -324,14 +328,17 @@ class HeliosAttentionBlock(nn.Module):
)
x = torch.cat([history_x, x_main], dim=1)
else:
# norm2 has elementwise_affine=True, manually do FP32LayerNorm behavior
norm_x = torch.nn.functional.layer_norm(
x.float(),
self.norm2.normalized_shape,
self.norm2.weight.to(x.device).float() if self.norm2.weight is not None else None,
self.norm2.bias.to(x.device).float() if self.norm2.bias is not None else None,
self.norm2.eps,
).type_as(x)
if self.cross_attn_norm:
# norm2 has elementwise_affine=True, manually do FP32LayerNorm behavior
norm_x = torch.nn.functional.layer_norm(
x.float(),
self.norm2.normalized_shape,
self.norm2.weight.to(x.device).float() if self.norm2.weight is not None else None,
self.norm2.bias.to(x.device).float() if self.norm2.bias is not None else None,
self.norm2.eps,
).type_as(x)
else:
norm_x = x
x = x + self.attn2(norm_x, context=context, transformer_options=transformer_options)
# ffn
@ -673,45 +680,51 @@ class HeliosModel(torch.nn.Module):
if latents_history_mid is not None and indices_latents_history_mid is not None:
x_mid = self.patch_mid(pad_for_3d_conv(latents_history_mid, (2, 4, 4)))
_, _, tm, _, _ = x_mid.shape
_, _, tm, hm, wm = x_mid.shape
x_mid = x_mid.flatten(2).transpose(1, 2)
mid_t = indices_latents_history_mid.shape[1]
# patch_mid downsamples by 2 in (t, h, w); build RoPE on the pre-downsample grid.
mid_h = hm * 2
mid_w = wm * 2
f_mid = self.rope_encode(
t=mid_t * self.patch_size[0],
h=hs * self.patch_size[1],
w=ws * self.patch_size[2],
h=mid_h * self.patch_size[1],
w=mid_w * self.patch_size[2],
steps_t=mid_t,
steps_h=hs,
steps_w=ws,
steps_h=mid_h,
steps_w=mid_w,
device=x_mid.device,
dtype=x_mid.dtype,
transformer_options=transformer_options,
frame_indices=indices_latents_history_mid,
)
f_mid = self._rope_downsample_3d(f_mid, (mid_t, hs, ws), (2, 2, 2))
f_mid = self._rope_downsample_3d(f_mid, (mid_t, mid_h, mid_w), (2, 2, 2))
hidden_states = torch.cat([x_mid, hidden_states], dim=1)
freqs = torch.cat([f_mid, freqs], dim=1)
if latents_history_long is not None and indices_latents_history_long is not None:
x_long = self.patch_long(pad_for_3d_conv(latents_history_long, (4, 8, 8)))
_, _, tl, _, _ = x_long.shape
_, _, tl, hl, wl = x_long.shape
x_long = x_long.flatten(2).transpose(1, 2)
long_t = indices_latents_history_long.shape[1]
# patch_long downsamples by 4 in (t, h, w); build RoPE on the pre-downsample grid.
long_h = hl * 4
long_w = wl * 4
f_long = self.rope_encode(
t=long_t * self.patch_size[0],
h=hs * self.patch_size[1],
w=ws * self.patch_size[2],
h=long_h * self.patch_size[1],
w=long_w * self.patch_size[2],
steps_t=long_t,
steps_h=hs,
steps_w=ws,
steps_h=long_h,
steps_w=long_w,
device=x_long.device,
dtype=x_long.dtype,
transformer_options=transformer_options,
frame_indices=indices_latents_history_long,
)
f_long = self._rope_downsample_3d(f_long, (long_t, hs, ws), (4, 4, 4))
f_long = self._rope_downsample_3d(f_long, (long_t, long_h, long_w), (4, 4, 4))
hidden_states = torch.cat([x_long, hidden_states], dim=1)
freqs = torch.cat([f_long, freqs], dim=1)
freqs = torch.cat([f_long, freqs], dim=1)
history_context_length = hidden_states.shape[1] - original_context_length

View File

@ -914,7 +914,6 @@ class HeliosPyramidSampler(io.ComfyNode):
category="sampling/video_models",
inputs=[
io.Model.Input("model"),
io.Boolean.Input("add_noise", default=True, advanced=True),
io.Int.Input("noise_seed", default=0, min=0, max=0xFFFFFFFFFFFFFFFF, control_after_generate=True),
io.Float.Input("cfg", default=5.0, min=0.0, max=100.0, step=0.1, round=0.01),
io.Conditioning.Input("positive"),
@ -931,7 +930,6 @@ class HeliosPyramidSampler(io.ComfyNode):
io.Boolean.Input("cfg_zero_star", default=True, advanced=True),
io.Boolean.Input("use_zero_init", default=True, advanced=True),
io.Int.Input("zero_steps", default=1, min=0, max=10000, advanced=True),
io.Boolean.Input("skip_first_chunk", default=False, advanced=True),
],
outputs=[
io.Latent.Output(display_name="output"),
@ -943,7 +941,6 @@ class HeliosPyramidSampler(io.ComfyNode):
def execute(
cls,
model,
add_noise,
noise_seed,
cfg,
positive,
@ -960,7 +957,6 @@ class HeliosPyramidSampler(io.ComfyNode):
cfg_zero_star,
use_zero_init,
zero_steps,
skip_first_chunk,
) -> io.NodeOutput:
# Keep these scheduler knobs internal (not exposed in node UI).
shift = 1.0
@ -975,8 +971,6 @@ class HeliosPyramidSampler(io.ComfyNode):
latent = latent_image.copy()
latent_samples = comfy.sample.fix_empty_latent_channels(model, latent["samples"], latent.get("downscale_ratio_spacial", None))
if not add_noise:
latent_samples = _process_latent_in_preserve_zero_frames(model, latent_samples)
stage_steps = _parse_int_list(pyramid_steps, [10, 10, 10])
stage_steps = [max(1, int(s)) for s in stage_steps]
@ -1069,19 +1063,6 @@ class HeliosPyramidSampler(io.ComfyNode):
hist_len = max(1, sum(history_sizes_list))
rolling_history = torch.zeros((b, c, hist_len, h, w), device=latent_samples.device, dtype=latent_samples.dtype)
# When initial video latents are provided, seed history buffer
# with those latents before the first denoising chunk.
if not add_noise:
hist_len = max(1, sum(history_sizes_list))
rolling_history = rolling_history.to(device=latent_samples.device, dtype=latent_samples.dtype)
video_latents = latent_samples
video_frames = video_latents.shape[2]
if video_frames < hist_len:
keep_frames = hist_len - video_frames
rolling_history = torch.cat([rolling_history[:, :, :keep_frames], video_latents], dim=2)
else:
rolling_history = video_latents[:, :, -hist_len:]
# Keep history/prefix on the same device/dtype as denoising latents.
rolling_history = rolling_history.to(device=target_device, dtype=torch.float32)
if image_latent_prefix is not None:
@ -1108,41 +1089,15 @@ class HeliosPyramidSampler(io.ComfyNode):
total_generated_latent_frames = initial_generated_latent_frames
for chunk_idx in range(chunk_count):
# Extract chunk from input latents
chunk_start = chunk_idx * chunk_t
chunk_end = min(chunk_start + chunk_t, t)
latent_chunk = latent_samples[:, :, chunk_start:chunk_end, :, :]
# Prepare initial latent for this chunk
if add_noise:
noise_shape = (
latent_samples.shape[0],
latent_samples.shape[1],
chunk_t,
latent_samples.shape[3],
latent_samples.shape[4],
)
stage_latent = torch.randn(noise_shape, device=target_device, dtype=torch.float32, generator=noise_gen)
else:
# Use actual input latents; pad final short chunk to fixed size.
stage_latent = latent_chunk.clone()
if stage_latent.shape[2] < chunk_t:
if stage_latent.shape[2] == 0:
stage_latent = torch.zeros(
(
latent_samples.shape[0],
latent_samples.shape[1],
chunk_t,
latent_samples.shape[3],
latent_samples.shape[4],
),
device=latent_samples.device,
dtype=torch.float32,
)
else:
pad = stage_latent[:, :, -1:].repeat(1, 1, chunk_t - stage_latent.shape[2], 1, 1)
stage_latent = torch.cat([stage_latent, pad], dim=2)
stage_latent = stage_latent.to(dtype=torch.float32)
noise_shape = (
latent_samples.shape[0],
latent_samples.shape[1],
chunk_t,
latent_samples.shape[3],
latent_samples.shape[4],
)
stage_latent = torch.randn(noise_shape, device=target_device, dtype=torch.float32, generator=noise_gen)
# Downsample to stage 0 resolution
for _ in range(max(0, int(stage_count) - 1)):
@ -1308,7 +1263,7 @@ class HeliosPyramidSampler(io.ComfyNode):
stage_latent = stage_latent[:, :, :, :h, :w]
generated_chunks.append(stage_latent)
if keep_first_frame and ((chunk_idx == 0 and image_latent_prefix is None) or (skip_first_chunk and chunk_idx == 1)):
if keep_first_frame and (chunk_idx == 0 and image_latent_prefix is None):
image_latent_prefix = stage_latent[:, :, :1]
rolling_history = torch.cat([rolling_history, stage_latent.to(rolling_history.device, rolling_history.dtype)], dim=2)
keep_hist = max(1, sum(history_sizes_list))