mirror of https://github.com/comfyanonymous/ComfyUI.git
synced 2026-03-22 17:43:33 +08:00
Fix Helios norm2 fallback and history RoPE guards; simplify sampler knobs
This commit is contained in:
parent ddd4030c06
commit c25df83b8a
@@ -228,13 +228,14 @@ class HeliosAttentionBlock(nn.Module):
|
|||||||
operation_settings=operation_settings,
|
operation_settings=operation_settings,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.cross_attn_norm = bool(cross_attn_norm)
|
||||||
self.norm2 = (operation_settings.get("operations").LayerNorm(
|
self.norm2 = (operation_settings.get("operations").LayerNorm(
|
||||||
dim,
|
dim,
|
||||||
eps,
|
eps,
|
||||||
elementwise_affine=True,
|
elementwise_affine=True,
|
||||||
device=operation_settings.get("device"),
|
device=operation_settings.get("device"),
|
||||||
dtype=operation_settings.get("dtype"),
|
dtype=operation_settings.get("dtype"),
|
||||||
) if cross_attn_norm else nn.Identity())
|
) if self.cross_attn_norm else nn.Identity())
|
||||||
self.attn2 = HeliosSelfAttention(
|
self.attn2 = HeliosSelfAttention(
|
||||||
dim,
|
dim,
|
||||||
num_heads,
|
num_heads,
|
||||||
@@ -309,14 +310,17 @@ class HeliosAttentionBlock(nn.Module):
|
|||||||
if self.guidance_cross_attn and original_context_length is not None:
|
if self.guidance_cross_attn and original_context_length is not None:
|
||||||
history_seq_len = x.shape[1] - original_context_length
|
history_seq_len = x.shape[1] - original_context_length
|
||||||
history_x, x_main = torch.split(x, [history_seq_len, original_context_length], dim=1)
|
history_x, x_main = torch.split(x, [history_seq_len, original_context_length], dim=1)
|
||||||
# norm2 has elementwise_affine=True, manually do FP32LayerNorm behavior
|
if self.cross_attn_norm:
|
||||||
norm_x_main = torch.nn.functional.layer_norm(
|
# norm2 has elementwise_affine=True, manually do FP32LayerNorm behavior
|
||||||
x_main.float(),
|
norm_x_main = torch.nn.functional.layer_norm(
|
||||||
self.norm2.normalized_shape,
|
x_main.float(),
|
||||||
self.norm2.weight.to(x_main.device).float() if self.norm2.weight is not None else None,
|
self.norm2.normalized_shape,
|
||||||
self.norm2.bias.to(x_main.device).float() if self.norm2.bias is not None else None,
|
self.norm2.weight.to(x_main.device).float() if self.norm2.weight is not None else None,
|
||||||
self.norm2.eps,
|
self.norm2.bias.to(x_main.device).float() if self.norm2.bias is not None else None,
|
||||||
).type_as(x_main)
|
self.norm2.eps,
|
||||||
|
).type_as(x_main)
|
||||||
|
else:
|
||||||
|
norm_x_main = x_main
|
||||||
x_main = x_main + self.attn2(
|
x_main = x_main + self.attn2(
|
||||||
norm_x_main,
|
norm_x_main,
|
||||||
context=context,
|
context=context,
|
||||||
@@ -324,14 +328,17 @@ class HeliosAttentionBlock(nn.Module):
|
|||||||
)
|
)
|
||||||
x = torch.cat([history_x, x_main], dim=1)
|
x = torch.cat([history_x, x_main], dim=1)
|
||||||
else:
|
else:
|
||||||
# norm2 has elementwise_affine=True, manually do FP32LayerNorm behavior
|
if self.cross_attn_norm:
|
||||||
norm_x = torch.nn.functional.layer_norm(
|
# norm2 has elementwise_affine=True, manually do FP32LayerNorm behavior
|
||||||
x.float(),
|
norm_x = torch.nn.functional.layer_norm(
|
||||||
self.norm2.normalized_shape,
|
x.float(),
|
||||||
self.norm2.weight.to(x.device).float() if self.norm2.weight is not None else None,
|
self.norm2.normalized_shape,
|
||||||
self.norm2.bias.to(x.device).float() if self.norm2.bias is not None else None,
|
self.norm2.weight.to(x.device).float() if self.norm2.weight is not None else None,
|
||||||
self.norm2.eps,
|
self.norm2.bias.to(x.device).float() if self.norm2.bias is not None else None,
|
||||||
).type_as(x)
|
self.norm2.eps,
|
||||||
|
).type_as(x)
|
||||||
|
else:
|
||||||
|
norm_x = x
|
||||||
x = x + self.attn2(norm_x, context=context, transformer_options=transformer_options)
|
x = x + self.attn2(norm_x, context=context, transformer_options=transformer_options)
|
||||||
|
|
||||||
# ffn
|
# ffn
|
||||||
@@ -673,45 +680,51 @@ class HeliosModel(torch.nn.Module):
|
|||||||
|
|
||||||
if latents_history_mid is not None and indices_latents_history_mid is not None:
|
if latents_history_mid is not None and indices_latents_history_mid is not None:
|
||||||
x_mid = self.patch_mid(pad_for_3d_conv(latents_history_mid, (2, 4, 4)))
|
x_mid = self.patch_mid(pad_for_3d_conv(latents_history_mid, (2, 4, 4)))
|
||||||
_, _, tm, _, _ = x_mid.shape
|
_, _, tm, hm, wm = x_mid.shape
|
||||||
x_mid = x_mid.flatten(2).transpose(1, 2)
|
x_mid = x_mid.flatten(2).transpose(1, 2)
|
||||||
mid_t = indices_latents_history_mid.shape[1]
|
mid_t = indices_latents_history_mid.shape[1]
|
||||||
|
# patch_mid downsamples by 2 in (t, h, w); build RoPE on the pre-downsample grid.
|
||||||
|
mid_h = hm * 2
|
||||||
|
mid_w = wm * 2
|
||||||
f_mid = self.rope_encode(
|
f_mid = self.rope_encode(
|
||||||
t=mid_t * self.patch_size[0],
|
t=mid_t * self.patch_size[0],
|
||||||
h=hs * self.patch_size[1],
|
h=mid_h * self.patch_size[1],
|
||||||
w=ws * self.patch_size[2],
|
w=mid_w * self.patch_size[2],
|
||||||
steps_t=mid_t,
|
steps_t=mid_t,
|
||||||
steps_h=hs,
|
steps_h=mid_h,
|
||||||
steps_w=ws,
|
steps_w=mid_w,
|
||||||
device=x_mid.device,
|
device=x_mid.device,
|
||||||
dtype=x_mid.dtype,
|
dtype=x_mid.dtype,
|
||||||
transformer_options=transformer_options,
|
transformer_options=transformer_options,
|
||||||
frame_indices=indices_latents_history_mid,
|
frame_indices=indices_latents_history_mid,
|
||||||
)
|
)
|
||||||
f_mid = self._rope_downsample_3d(f_mid, (mid_t, hs, ws), (2, 2, 2))
|
f_mid = self._rope_downsample_3d(f_mid, (mid_t, mid_h, mid_w), (2, 2, 2))
|
||||||
hidden_states = torch.cat([x_mid, hidden_states], dim=1)
|
hidden_states = torch.cat([x_mid, hidden_states], dim=1)
|
||||||
freqs = torch.cat([f_mid, freqs], dim=1)
|
freqs = torch.cat([f_mid, freqs], dim=1)
|
||||||
|
|
||||||
if latents_history_long is not None and indices_latents_history_long is not None:
|
if latents_history_long is not None and indices_latents_history_long is not None:
|
||||||
x_long = self.patch_long(pad_for_3d_conv(latents_history_long, (4, 8, 8)))
|
x_long = self.patch_long(pad_for_3d_conv(latents_history_long, (4, 8, 8)))
|
||||||
_, _, tl, _, _ = x_long.shape
|
_, _, tl, hl, wl = x_long.shape
|
||||||
x_long = x_long.flatten(2).transpose(1, 2)
|
x_long = x_long.flatten(2).transpose(1, 2)
|
||||||
long_t = indices_latents_history_long.shape[1]
|
long_t = indices_latents_history_long.shape[1]
|
||||||
|
# patch_long downsamples by 4 in (t, h, w); build RoPE on the pre-downsample grid.
|
||||||
|
long_h = hl * 4
|
||||||
|
long_w = wl * 4
|
||||||
f_long = self.rope_encode(
|
f_long = self.rope_encode(
|
||||||
t=long_t * self.patch_size[0],
|
t=long_t * self.patch_size[0],
|
||||||
h=hs * self.patch_size[1],
|
h=long_h * self.patch_size[1],
|
||||||
w=ws * self.patch_size[2],
|
w=long_w * self.patch_size[2],
|
||||||
steps_t=long_t,
|
steps_t=long_t,
|
||||||
steps_h=hs,
|
steps_h=long_h,
|
||||||
steps_w=ws,
|
steps_w=long_w,
|
||||||
device=x_long.device,
|
device=x_long.device,
|
||||||
dtype=x_long.dtype,
|
dtype=x_long.dtype,
|
||||||
transformer_options=transformer_options,
|
transformer_options=transformer_options,
|
||||||
frame_indices=indices_latents_history_long,
|
frame_indices=indices_latents_history_long,
|
||||||
)
|
)
|
||||||
f_long = self._rope_downsample_3d(f_long, (long_t, hs, ws), (4, 4, 4))
|
f_long = self._rope_downsample_3d(f_long, (long_t, long_h, long_w), (4, 4, 4))
|
||||||
hidden_states = torch.cat([x_long, hidden_states], dim=1)
|
hidden_states = torch.cat([x_long, hidden_states], dim=1)
|
||||||
freqs = torch.cat([f_long, freqs], dim=1)
|
freqs = torch.cat([f_long, freqs], dim=1)
|
||||||
|
|
||||||
history_context_length = hidden_states.shape[1] - original_context_length
|
history_context_length = hidden_states.shape[1] - original_context_length
|
||||||
|
|
||||||
|
|||||||
@@ -914,7 +914,6 @@ class HeliosPyramidSampler(io.ComfyNode):
|
|||||||
category="sampling/video_models",
|
category="sampling/video_models",
|
||||||
inputs=[
|
inputs=[
|
||||||
io.Model.Input("model"),
|
io.Model.Input("model"),
|
||||||
io.Boolean.Input("add_noise", default=True, advanced=True),
|
|
||||||
io.Int.Input("noise_seed", default=0, min=0, max=0xFFFFFFFFFFFFFFFF, control_after_generate=True),
|
io.Int.Input("noise_seed", default=0, min=0, max=0xFFFFFFFFFFFFFFFF, control_after_generate=True),
|
||||||
io.Float.Input("cfg", default=5.0, min=0.0, max=100.0, step=0.1, round=0.01),
|
io.Float.Input("cfg", default=5.0, min=0.0, max=100.0, step=0.1, round=0.01),
|
||||||
io.Conditioning.Input("positive"),
|
io.Conditioning.Input("positive"),
|
||||||
@@ -931,7 +930,6 @@ class HeliosPyramidSampler(io.ComfyNode):
|
|||||||
io.Boolean.Input("cfg_zero_star", default=True, advanced=True),
|
io.Boolean.Input("cfg_zero_star", default=True, advanced=True),
|
||||||
io.Boolean.Input("use_zero_init", default=True, advanced=True),
|
io.Boolean.Input("use_zero_init", default=True, advanced=True),
|
||||||
io.Int.Input("zero_steps", default=1, min=0, max=10000, advanced=True),
|
io.Int.Input("zero_steps", default=1, min=0, max=10000, advanced=True),
|
||||||
io.Boolean.Input("skip_first_chunk", default=False, advanced=True),
|
|
||||||
],
|
],
|
||||||
outputs=[
|
outputs=[
|
||||||
io.Latent.Output(display_name="output"),
|
io.Latent.Output(display_name="output"),
|
||||||
@@ -943,7 +941,6 @@ class HeliosPyramidSampler(io.ComfyNode):
|
|||||||
def execute(
|
def execute(
|
||||||
cls,
|
cls,
|
||||||
model,
|
model,
|
||||||
add_noise,
|
|
||||||
noise_seed,
|
noise_seed,
|
||||||
cfg,
|
cfg,
|
||||||
positive,
|
positive,
|
||||||
@@ -960,7 +957,6 @@ class HeliosPyramidSampler(io.ComfyNode):
|
|||||||
cfg_zero_star,
|
cfg_zero_star,
|
||||||
use_zero_init,
|
use_zero_init,
|
||||||
zero_steps,
|
zero_steps,
|
||||||
skip_first_chunk,
|
|
||||||
) -> io.NodeOutput:
|
) -> io.NodeOutput:
|
||||||
# Keep these scheduler knobs internal (not exposed in node UI).
|
# Keep these scheduler knobs internal (not exposed in node UI).
|
||||||
shift = 1.0
|
shift = 1.0
|
||||||
@@ -975,8 +971,6 @@ class HeliosPyramidSampler(io.ComfyNode):
|
|||||||
|
|
||||||
latent = latent_image.copy()
|
latent = latent_image.copy()
|
||||||
latent_samples = comfy.sample.fix_empty_latent_channels(model, latent["samples"], latent.get("downscale_ratio_spacial", None))
|
latent_samples = comfy.sample.fix_empty_latent_channels(model, latent["samples"], latent.get("downscale_ratio_spacial", None))
|
||||||
if not add_noise:
|
|
||||||
latent_samples = _process_latent_in_preserve_zero_frames(model, latent_samples)
|
|
||||||
|
|
||||||
stage_steps = _parse_int_list(pyramid_steps, [10, 10, 10])
|
stage_steps = _parse_int_list(pyramid_steps, [10, 10, 10])
|
||||||
stage_steps = [max(1, int(s)) for s in stage_steps]
|
stage_steps = [max(1, int(s)) for s in stage_steps]
|
||||||
@@ -1069,19 +1063,6 @@ class HeliosPyramidSampler(io.ComfyNode):
|
|||||||
hist_len = max(1, sum(history_sizes_list))
|
hist_len = max(1, sum(history_sizes_list))
|
||||||
rolling_history = torch.zeros((b, c, hist_len, h, w), device=latent_samples.device, dtype=latent_samples.dtype)
|
rolling_history = torch.zeros((b, c, hist_len, h, w), device=latent_samples.device, dtype=latent_samples.dtype)
|
||||||
|
|
||||||
# When initial video latents are provided, seed history buffer
|
|
||||||
# with those latents before the first denoising chunk.
|
|
||||||
if not add_noise:
|
|
||||||
hist_len = max(1, sum(history_sizes_list))
|
|
||||||
rolling_history = rolling_history.to(device=latent_samples.device, dtype=latent_samples.dtype)
|
|
||||||
video_latents = latent_samples
|
|
||||||
video_frames = video_latents.shape[2]
|
|
||||||
if video_frames < hist_len:
|
|
||||||
keep_frames = hist_len - video_frames
|
|
||||||
rolling_history = torch.cat([rolling_history[:, :, :keep_frames], video_latents], dim=2)
|
|
||||||
else:
|
|
||||||
rolling_history = video_latents[:, :, -hist_len:]
|
|
||||||
|
|
||||||
# Keep history/prefix on the same device/dtype as denoising latents.
|
# Keep history/prefix on the same device/dtype as denoising latents.
|
||||||
rolling_history = rolling_history.to(device=target_device, dtype=torch.float32)
|
rolling_history = rolling_history.to(device=target_device, dtype=torch.float32)
|
||||||
if image_latent_prefix is not None:
|
if image_latent_prefix is not None:
|
||||||
@@ -1108,41 +1089,15 @@ class HeliosPyramidSampler(io.ComfyNode):
|
|||||||
total_generated_latent_frames = initial_generated_latent_frames
|
total_generated_latent_frames = initial_generated_latent_frames
|
||||||
|
|
||||||
for chunk_idx in range(chunk_count):
|
for chunk_idx in range(chunk_count):
|
||||||
# Extract chunk from input latents
|
|
||||||
chunk_start = chunk_idx * chunk_t
|
|
||||||
chunk_end = min(chunk_start + chunk_t, t)
|
|
||||||
latent_chunk = latent_samples[:, :, chunk_start:chunk_end, :, :]
|
|
||||||
|
|
||||||
# Prepare initial latent for this chunk
|
# Prepare initial latent for this chunk
|
||||||
if add_noise:
|
noise_shape = (
|
||||||
noise_shape = (
|
latent_samples.shape[0],
|
||||||
latent_samples.shape[0],
|
latent_samples.shape[1],
|
||||||
latent_samples.shape[1],
|
chunk_t,
|
||||||
chunk_t,
|
latent_samples.shape[3],
|
||||||
latent_samples.shape[3],
|
latent_samples.shape[4],
|
||||||
latent_samples.shape[4],
|
)
|
||||||
)
|
stage_latent = torch.randn(noise_shape, device=target_device, dtype=torch.float32, generator=noise_gen)
|
||||||
stage_latent = torch.randn(noise_shape, device=target_device, dtype=torch.float32, generator=noise_gen)
|
|
||||||
else:
|
|
||||||
# Use actual input latents; pad final short chunk to fixed size.
|
|
||||||
stage_latent = latent_chunk.clone()
|
|
||||||
if stage_latent.shape[2] < chunk_t:
|
|
||||||
if stage_latent.shape[2] == 0:
|
|
||||||
stage_latent = torch.zeros(
|
|
||||||
(
|
|
||||||
latent_samples.shape[0],
|
|
||||||
latent_samples.shape[1],
|
|
||||||
chunk_t,
|
|
||||||
latent_samples.shape[3],
|
|
||||||
latent_samples.shape[4],
|
|
||||||
),
|
|
||||||
device=latent_samples.device,
|
|
||||||
dtype=torch.float32,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
pad = stage_latent[:, :, -1:].repeat(1, 1, chunk_t - stage_latent.shape[2], 1, 1)
|
|
||||||
stage_latent = torch.cat([stage_latent, pad], dim=2)
|
|
||||||
stage_latent = stage_latent.to(dtype=torch.float32)
|
|
||||||
|
|
||||||
# Downsample to stage 0 resolution
|
# Downsample to stage 0 resolution
|
||||||
for _ in range(max(0, int(stage_count) - 1)):
|
for _ in range(max(0, int(stage_count) - 1)):
|
||||||
@@ -1308,7 +1263,7 @@ class HeliosPyramidSampler(io.ComfyNode):
|
|||||||
stage_latent = stage_latent[:, :, :, :h, :w]
|
stage_latent = stage_latent[:, :, :, :h, :w]
|
||||||
|
|
||||||
generated_chunks.append(stage_latent)
|
generated_chunks.append(stage_latent)
|
||||||
if keep_first_frame and ((chunk_idx == 0 and image_latent_prefix is None) or (skip_first_chunk and chunk_idx == 1)):
|
if keep_first_frame and (chunk_idx == 0 and image_latent_prefix is None):
|
||||||
image_latent_prefix = stage_latent[:, :, :1]
|
image_latent_prefix = stage_latent[:, :, :1]
|
||||||
rolling_history = torch.cat([rolling_history, stage_latent.to(rolling_history.device, rolling_history.dtype)], dim=2)
|
rolling_history = torch.cat([rolling_history, stage_latent.to(rolling_history.device, rolling_history.dtype)], dim=2)
|
||||||
keep_hist = max(1, sum(history_sizes_list))
|
keep_hist = max(1, sum(history_sizes_list))
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user