Enhance Helios model with latent space noise application and debugging options

This commit is contained in:
qqingzheng 2026-03-08 23:08:57 +08:00
parent d93133ee53
commit 26f44ab770
2 changed files with 287 additions and 85 deletions

View File

@ -652,8 +652,8 @@ class HeliosModel(torch.nn.Module):
)
original_context_length = hidden_states.shape[1]
if (latents_history_short is not None and indices_latents_history_short is not None and hasattr(self, "patch_short")):
x_short = self.patch_short(latents_history_short).to(hidden_states.dtype)
if latents_history_short is not None and indices_latents_history_short is not None:
x_short = self.patch_short(latents_history_short)
_, _, ts, hs, ws = x_short.shape
x_short = x_short.flatten(2).transpose(1, 2)
f_short = self.rope_encode(
@ -671,57 +671,45 @@ class HeliosModel(torch.nn.Module):
hidden_states = torch.cat([x_short, hidden_states], dim=1)
freqs = torch.cat([f_short, freqs], dim=1)
if (latents_history_mid is not None and indices_latents_history_mid is not None and hasattr(self, "patch_mid")):
x_mid = self.patch_mid(pad_for_3d_conv(latents_history_mid, (2, 4, 4))).to(hidden_states.dtype)
_, _, tm, hm, wm = x_mid.shape
if latents_history_mid is not None and indices_latents_history_mid is not None:
x_mid = self.patch_mid(pad_for_3d_conv(latents_history_mid, (2, 4, 4)))
_, _, tm, _, _ = x_mid.shape
x_mid = x_mid.flatten(2).transpose(1, 2)
mid_t = indices_latents_history_mid.shape[1]
if ("hs" in locals()) and ("ws" in locals()):
mid_h, mid_w = hs, ws
else:
mid_h, mid_w = hm * 2, wm * 2
f_mid = self.rope_encode(
t=mid_t * self.patch_size[0],
h=mid_h * self.patch_size[1],
w=mid_w * self.patch_size[2],
h=hs * self.patch_size[1],
w=ws * self.patch_size[2],
steps_t=mid_t,
steps_h=mid_h,
steps_w=mid_w,
steps_h=hs,
steps_w=ws,
device=x_mid.device,
dtype=x_mid.dtype,
transformer_options=transformer_options,
frame_indices=indices_latents_history_mid,
)
f_mid = self._rope_downsample_3d(f_mid, (mid_t, mid_h, mid_w), (2, 2, 2))
if f_mid.shape[1] != x_mid.shape[1]:
f_mid = f_mid[:, :x_mid.shape[1]]
f_mid = self._rope_downsample_3d(f_mid, (mid_t, hs, ws), (2, 2, 2))
hidden_states = torch.cat([x_mid, hidden_states], dim=1)
freqs = torch.cat([f_mid, freqs], dim=1)
if (latents_history_long is not None and indices_latents_history_long is not None and hasattr(self, "patch_long")):
x_long = self.patch_long(pad_for_3d_conv(latents_history_long, (4, 8, 8))).to(hidden_states.dtype)
_, _, tl, hl, wl = x_long.shape
if latents_history_long is not None and indices_latents_history_long is not None:
x_long = self.patch_long(pad_for_3d_conv(latents_history_long, (4, 8, 8)))
_, _, tl, _, _ = x_long.shape
x_long = x_long.flatten(2).transpose(1, 2)
long_t = indices_latents_history_long.shape[1]
if ("hs" in locals()) and ("ws" in locals()):
long_h, long_w = hs, ws
else:
long_h, long_w = hl * 4, wl * 4
f_long = self.rope_encode(
t=long_t * self.patch_size[0],
h=long_h * self.patch_size[1],
w=long_w * self.patch_size[2],
h=hs * self.patch_size[1],
w=ws * self.patch_size[2],
steps_t=long_t,
steps_h=long_h,
steps_w=long_w,
steps_h=hs,
steps_w=ws,
device=x_long.device,
dtype=x_long.dtype,
transformer_options=transformer_options,
frame_indices=indices_latents_history_long,
)
f_long = self._rope_downsample_3d(f_long, (long_t, long_h, long_w), (4, 4, 4))
if f_long.shape[1] != x_long.shape[1]:
f_long = f_long[:, :x_long.shape[1]]
f_long = self._rope_downsample_3d(f_long, (long_t, hs, ws), (4, 4, 4))
hidden_states = torch.cat([x_long, hidden_states], dim=1)
freqs = torch.cat([f_long, freqs], dim=1)

View File

@ -7,6 +7,7 @@ import comfy.model_patcher
import comfy.sample
import comfy.samplers
import comfy.utils
import comfy.latent_formats
import latent_preview
import node_helpers
@ -41,6 +42,37 @@ def _parse_int_list(values, default):
return out if len(out) > 0 else default
# Shared Helios latent-format object; _apply_helios_latent_space_noise calls its
# process_in / process_out to move latents into and out of model latent space.
_HELIOS_LATENT_FORMAT = comfy.latent_formats.Helios()
def _apply_helios_latent_space_noise(latent, sigma, generator=None):
    """Blend Gaussian noise into ``latent`` inside the Helios model latent space.

    The input is converted with ``_HELIOS_LATENT_FORMAT.process_in``, mixed with
    freshly drawn noise as ``sigma * noise + (1 - sigma) * x``, and converted
    back with ``_HELIOS_LATENT_FORMAT.process_out``.

    Args:
        latent: Tensor to perturb.
        sigma: Noise weight; must broadcast against the converted latent.
        generator: Optional ``torch.Generator`` forwarded to ``torch.randn``.

    Returns:
        The noised latent after ``process_out``, moved to ``latent``'s device
        and dtype.
    """
    model_space = _HELIOS_LATENT_FORMAT.process_in(latent)
    # Noise is drawn with the converted latent's shape, device and dtype.
    gaussian = torch.randn(
        model_space.shape,
        generator=generator,
        dtype=model_space.dtype,
        device=model_space.device,
    )
    blended = sigma * gaussian + (1.0 - sigma) * model_space
    vae_space = _HELIOS_LATENT_FORMAT.process_out(blended)
    return vae_space.to(device=latent.device, dtype=latent.dtype)
def _tensor_stats_str(x):
    """Return a one-line debug summary of ``x``.

    ``None`` yields ``"None"``, non-tensors report their type, and empty
    tensors report only their shape. Any other tensor reports its shape plus
    mean, population standard deviation, min and max, computed on a detached
    float32 copy and formatted to six decimals.
    """
    if x is None:
        return "None"
    if not torch.is_tensor(x):
        return f"non-tensor type={type(x)}"

    shape = tuple(x.shape)
    if x.numel() == 0:
        return f"shape={shape} empty"

    # Cast once so the statistics are computed in float32 regardless of input dtype.
    values = x.detach().to(torch.float32)
    mean = values.mean().item()
    std = values.std(unbiased=False).item()
    low = values.min().item()
    high = values.max().item()
    return f"shape={shape} mean={mean:.6f} std={std:.6f} min={low:.6f} max={high:.6f}"
def _parse_float_list(values, default):
if values is None:
return default
@ -65,6 +97,15 @@ def _parse_float_list(values, default):
return out if len(out) > 0 else default
def _strict_bool(value, default=False):
    """Coerce ``value`` to a bool, accepting only bools and ints.

    Bools are returned unchanged and ints map to ``value != 0``. Anything else
    (for example a float such as 0.135 left over from a stale workflow) is
    rejected and ``bool(default)`` is returned instead.
    """
    # bool is a subclass of int, so one check covers both; for a real bool
    # ``value != 0`` evaluates to the bool itself.
    if isinstance(value, int):
        return value != 0
    return bool(default)
def _extract_condition_value(conditioning, key):
for c in conditioning:
if len(c) < 2:
@ -94,8 +135,9 @@ def _process_latent_in_preserve_zero_frames(model, latent, valid_mask=None):
return latent
if nonzero.shape[0] != latent.shape[2]:
# Keep behavior safe when mask length does not match temporal length.
nonzero = torch.zeros((latent.shape[2],), device=latent.device, dtype=torch.bool)
raise ValueError(
f"Helios history mask length mismatch: mask_t={nonzero.shape[0]} latent_t={latent.shape[2]}"
)
converted = model.model.process_latent_in(latent)
out = latent.clone()
@ -133,7 +175,7 @@ def _prepare_stage0_latent(batch, channels, frames, height, width, stage_count,
def _downsample_latent_for_stage0(latent, stage_count):
"""Downsample latent to stage 0 resolution (like Diffusers does)"""
"""Downsample latent to stage 0 resolution."""
stage_latent = latent
for _ in range(max(0, int(stage_count) - 1)):
stage_latent = _downsample_latent_5d_bilinear_x2(stage_latent)
@ -154,7 +196,7 @@ def _sample_block_noise_like(latent, gamma, patch_size=(1, 2, 2), generator=None
block_number = b * c * t * h_blocks * w_blocks
if generator is not None:
# Exact Diffusers sampling path (MultivariateNormal.sample), while consuming
# Exact sampling path (MultivariateNormal.sample), while consuming
# from an explicit generator by temporarily swapping default RNG state.
with torch.random.fork_rng(devices=[latent.device] if latent.device.type == "cuda" else []):
if latent.device.type == "cuda":
@ -231,7 +273,7 @@ def _helios_stage_tables(stage_count, stage_range, gamma, num_train_timesteps=10
tmax = min(float(sigmas[int(start_ratio * num_train_timesteps)].item() * num_train_timesteps), 999.0)
tmin = float(sigmas[min(int(end_ratio * num_train_timesteps), num_train_timesteps - 1)].item() * num_train_timesteps)
timesteps_per_stage[i] = torch.linspace(tmax, tmin, num_train_timesteps + 1)[:-1]
# Fixed: Use same sigma range [0.999, 0] for all stages like Diffusers
# Fixed: use the same sigma range [0.999, 0] for all stages.
sigmas_per_stage[i] = torch.linspace(0.999, 0.0, num_train_timesteps + 1)[:-1]
@ -302,21 +344,18 @@ def _build_cfg_zero_star_pre_cfg(stage_idx, zero_steps, use_zero_init):
state["i"] += 1
return conds_out
denoised_text = conds_out[0] # apply_model 返回的 denoised
denoised_text = conds_out[0]
denoised_uncond = conds_out[1]
cfg = float(args.get("cond_scale", 1.0))
x = args["input"] # 当前的 noisy latent
sigma = args["sigma"] # 当前的 sigma
x = args["input"]
sigma = args["sigma"]
# Key fix: convert denoised to flow
# denoised = x - flow * sigma => flow = (x - denoised) / sigma
sigma_reshaped = sigma.reshape(sigma.shape[0], *([1] * (denoised_text.ndim - 1)))
sigma_safe = torch.clamp(sigma_reshaped, min=1e-8)
flow_text = (x - denoised_text) / sigma_safe
flow_uncond = (x - denoised_uncond) / sigma_safe
# Apply CFG Zero Star in flow space
positive_flat = flow_text.reshape(flow_text.shape[0], -1)
negative_flat = flow_uncond.reshape(flow_uncond.shape[0], -1)
alpha = _optimized_scale(positive_flat, negative_flat)
@ -327,11 +366,9 @@ def _build_cfg_zero_star_pre_cfg(stage_idx, zero_steps, use_zero_init):
else:
flow_final = flow_uncond * alpha + cfg * (flow_text - flow_uncond * alpha)
# Convert flow back to denoised
denoised_final = x - flow_final * sigma_safe
state["i"] += 1
# Return identical cond/uncond so downstream cfg_function keeps `final` unchanged.
return [denoised_final, denoised_final]
return pre_cfg_fn
@ -519,6 +556,8 @@ class HeliosImageToVideo(io.ComfyNode):
io.Float.Input("image_noise_sigma_min", default=0.111, min=0.0, max=1.0, step=0.0001, round=False, advanced=True),
io.Float.Input("image_noise_sigma_max", default=0.135, min=0.0, max=1.0, step=0.0001, round=False, advanced=True),
io.Int.Input("noise_seed", default=0, min=0, max=0xFFFFFFFFFFFFFFFF, advanced=True),
io.Boolean.Input("include_history_in_output", default=False, advanced=True),
io.Boolean.Input("debug_latent_stats", default=False, advanced=True),
],
outputs=[
io.Conditioning.Output(display_name="positive"),
@ -545,7 +584,11 @@ class HeliosImageToVideo(io.ComfyNode):
image_noise_sigma_min=0.111,
image_noise_sigma_max=0.135,
noise_seed=0,
include_history_in_output=False,
debug_latent_stats=False,
) -> io.NodeOutput:
video_noise_sigma_min = 0.111
video_noise_sigma_max = 0.135
spacial_scale = vae.spacial_compression_encode()
latent_channels = vae.latent_channels
latent_t = ((length - 1) // 4) + 1
@ -560,10 +603,11 @@ class HeliosImageToVideo(io.ComfyNode):
history_valid_mask = torch.zeros((batch_size, hist_len), device=latent.device, dtype=torch.bool)
image_latent_prefix = None
i2v_noise_gen = None
noise_gen_state = None
if start_image is not None:
image = comfy.utils.common_upscale(start_image[:1].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
img_latent = vae.encode(image[:, :, :, :3])
img_latent = vae.encode(image[:, :, :, :3]).to(device=latent.device, dtype=torch.float32)
img_latent = comfy.utils.repeat_to_batch_size(img_latent, batch_size)
image_latent_prefix = img_latent[:, :, :1]
@ -571,33 +615,38 @@ class HeliosImageToVideo(io.ComfyNode):
i2v_noise_gen = torch.Generator(device=img_latent.device)
i2v_noise_gen.manual_seed(int(noise_seed))
sigma = (
torch.rand((img_latent.shape[0], 1, 1, 1, 1), device=img_latent.device, generator=i2v_noise_gen, dtype=img_latent.dtype)
torch.rand((1,), device=img_latent.device, generator=i2v_noise_gen, dtype=img_latent.dtype).view(1, 1, 1, 1, 1)
* (float(image_noise_sigma_max) - float(image_noise_sigma_min))
+ float(image_noise_sigma_min)
)
image_latent_prefix = sigma * torch.randn_like(image_latent_prefix, generator=i2v_noise_gen) + (1.0 - sigma) * image_latent_prefix
image_latent_prefix = _apply_helios_latent_space_noise(image_latent_prefix, sigma, generator=i2v_noise_gen)
min_frames = max(1, (int(num_latent_frames_per_chunk) - 1) * 4 + 1)
fake_video = image.repeat(min_frames, 1, 1, 1)
fake_latents_full = vae.encode(fake_video)
fake_latents_full = vae.encode(fake_video).to(device=latent.device, dtype=torch.float32)
fake_latent = comfy.utils.repeat_to_batch_size(fake_latents_full[:, :, -1:], batch_size)
# Diffusers parity for I2V:
# when adding noise to image latents, fake_image_latents used for history are also noised.
if add_noise_to_image_latents:
if i2v_noise_gen is None:
i2v_noise_gen = torch.Generator(device=fake_latent.device)
i2v_noise_gen.manual_seed(int(noise_seed))
# Keep backward compatibility with existing I2V node inputs:
# this node exposes only image sigma controls, while fake history
# latents follow the video-noise path in Diffusers.
# this node exposes only image sigma controls; fake history latents
# follow the video-noise defaults.
fake_sigma = (
torch.rand((fake_latent.shape[0], 1, 1, 1, 1), device=fake_latent.device, generator=i2v_noise_gen, dtype=fake_latent.dtype)
* (float(image_noise_sigma_max) - float(image_noise_sigma_min))
+ float(image_noise_sigma_min)
torch.rand((1,), device=fake_latent.device, generator=i2v_noise_gen, dtype=fake_latent.dtype).view(1, 1, 1, 1, 1)
* (float(video_noise_sigma_max) - float(video_noise_sigma_min))
+ float(video_noise_sigma_min)
)
fake_latent = fake_sigma * torch.randn_like(fake_latent, generator=i2v_noise_gen) + (1.0 - fake_sigma) * fake_latent
fake_latent = _apply_helios_latent_space_noise(fake_latent, fake_sigma, generator=i2v_noise_gen)
history_latent[:, :, -1:] = fake_latent
history_valid_mask[:, -1] = True
if i2v_noise_gen is not None:
noise_gen_state = i2v_noise_gen.get_state().clone()
if debug_latent_stats:
print(f"[HeliosDebug][I2V] image_latent_prefix: {_tensor_stats_str(image_latent_prefix)}")
print(f"[HeliosDebug][I2V] fake_latent: {_tensor_stats_str(fake_latent)}")
print(f"[HeliosDebug][I2V] history_latent: {_tensor_stats_str(history_latent)}")
positive, negative = _set_helios_history_values(positive, negative, history_latent, sizes, keep_first_frame, prefix_latent=image_latent_prefix)
return io.NodeOutput(
@ -608,6 +657,10 @@ class HeliosImageToVideo(io.ComfyNode):
"helios_history_latent": history_latent,
"helios_image_latent_prefix": image_latent_prefix,
"helios_history_valid_mask": history_valid_mask,
"helios_num_frames": int(length),
"helios_noise_gen_state": noise_gen_state,
"helios_include_history_in_output": _strict_bool(include_history_in_output, default=False),
"helios_debug_latent_stats": bool(debug_latent_stats),
},
)
@ -686,6 +739,7 @@ class HeliosTextToVideo(io.ComfyNode):
"helios_history_latent": history_latent,
"helios_image_latent_prefix": None,
"helios_history_valid_mask": history_valid_mask,
"helios_num_frames": int(length),
},
)
@ -707,10 +761,13 @@ class HeliosVideoToVideo(io.ComfyNode):
io.Image.Input("video", optional=True),
io.String.Input("history_sizes", default="16,2,1", advanced=True),
io.Boolean.Input("keep_first_frame", default=True, advanced=True),
io.Int.Input("num_latent_frames_per_chunk", default=9, min=1, max=256, advanced=True),
io.Boolean.Input("add_noise_to_video_latents", default=True, advanced=True),
io.Float.Input("video_noise_sigma_min", default=0.111, min=0.0, max=1.0, step=0.0001, round=False, advanced=True),
io.Float.Input("video_noise_sigma_max", default=0.135, min=0.0, max=1.0, step=0.0001, round=False, advanced=True),
io.Int.Input("noise_seed", default=0, min=0, max=0xFFFFFFFFFFFFFFFF, advanced=True),
io.Boolean.Input("include_history_in_output", default=True, advanced=True),
io.Boolean.Input("debug_latent_stats", default=False, advanced=True),
],
outputs=[
io.Conditioning.Output(display_name="positive"),
@ -732,10 +789,13 @@ class HeliosVideoToVideo(io.ComfyNode):
video=None,
history_sizes="16,2,1",
keep_first_frame=True,
num_latent_frames_per_chunk=9,
add_noise_to_video_latents=True,
video_noise_sigma_min=0.111,
video_noise_sigma_max=0.135,
noise_seed=0,
include_history_in_output=True,
debug_latent_stats=False,
) -> io.NodeOutput:
spacial_scale = vae.spacial_compression_encode()
latent_channels = vae.latent_channels
@ -750,29 +810,81 @@ class HeliosVideoToVideo(io.ComfyNode):
history_latent = torch.zeros([batch_size, latent_channels, hist_len, latent.shape[-2], latent.shape[-1]], device=latent.device, dtype=latent.dtype)
history_valid_mask = torch.zeros((batch_size, hist_len), device=latent.device, dtype=torch.bool)
image_latent_prefix = None
noise_gen_state = None
history_latent_output = history_latent
if video is not None:
video = comfy.utils.common_upscale(video[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
vid_latent = vae.encode(video[:, :, :, :3])
num_frames = int(video.shape[0])
min_frames = max(1, (int(num_latent_frames_per_chunk) - 1) * 4 + 1)
num_chunks = num_frames // min_frames
if num_chunks == 0:
raise ValueError(
f"Video must have at least {min_frames} frames (got {num_frames} frames). "
f"Required: (num_latent_frames_per_chunk - 1) * 4 + 1 = ({int(num_latent_frames_per_chunk)} - 1) * 4 + 1 = {min_frames}"
)
first_frame = video[:1]
first_frame_latent = vae.encode(first_frame[:, :, :, :3]).to(device=latent.device, dtype=torch.float32)
total_valid_frames = num_chunks * min_frames
start_frame = num_frames - total_valid_frames
latents_chunks = []
for i in range(num_chunks):
chunk_start = start_frame + i * min_frames
chunk_end = chunk_start + min_frames
video_chunk = video[chunk_start:chunk_end]
chunk_latents = vae.encode(video_chunk[:, :, :, :3]).to(device=latent.device, dtype=torch.float32)
latents_chunks.append(chunk_latents)
vid_latent = torch.cat(latents_chunks, dim=2)
vid_latent_clean = vid_latent.clone()
if add_noise_to_video_latents:
g = torch.Generator(device=vid_latent.device)
g.manual_seed(int(noise_seed))
frame_sigmas = (
torch.rand((1, 1, vid_latent.shape[2], 1, 1), device=vid_latent.device, generator=g, dtype=vid_latent.dtype)
image_sigma = (
torch.rand((1,), device=first_frame_latent.device, generator=g, dtype=first_frame_latent.dtype).view(1, 1, 1, 1, 1)
* (float(video_noise_sigma_max) - float(video_noise_sigma_min))
+ float(video_noise_sigma_min)
)
vid_latent = frame_sigmas * torch.randn_like(vid_latent, generator=g) + (1.0 - frame_sigmas) * vid_latent
vid_latent = vid_latent[:, :, :hist_len]
first_frame_latent = _apply_helios_latent_space_noise(first_frame_latent, image_sigma, generator=g)
noisy_chunks = []
num_latent_chunks = max(1, vid_latent.shape[2] // int(num_latent_frames_per_chunk))
for i in range(num_latent_chunks):
chunk_start = i * int(num_latent_frames_per_chunk)
chunk_end = chunk_start + int(num_latent_frames_per_chunk)
latent_chunk = vid_latent[:, :, chunk_start:chunk_end, :, :]
if latent_chunk.shape[2] == 0:
continue
chunk_frames = latent_chunk.shape[2]
frame_sigmas = (
torch.rand((chunk_frames,), device=vid_latent.device, generator=g, dtype=vid_latent.dtype)
* (float(video_noise_sigma_max) - float(video_noise_sigma_min))
+ float(video_noise_sigma_min)
).view(1, 1, chunk_frames, 1, 1)
noisy_chunk = _apply_helios_latent_space_noise(latent_chunk, frame_sigmas, generator=g)
noisy_chunks.append(noisy_chunk)
if len(noisy_chunks) > 0:
vid_latent = torch.cat(noisy_chunks, dim=2)
noise_gen_state = g.get_state().clone()
if debug_latent_stats:
print(f"[HeliosDebug][V2V] first_frame_latent: {_tensor_stats_str(first_frame_latent)}")
print(f"[HeliosDebug][V2V] video_latent: {_tensor_stats_str(vid_latent)}")
vid_latent = comfy.utils.repeat_to_batch_size(vid_latent, batch_size)
if vid_latent.shape[2] < hist_len:
keep_frames = hist_len - vid_latent.shape[2]
image_latent_prefix = comfy.utils.repeat_to_batch_size(first_frame_latent, batch_size)
video_frames = vid_latent.shape[2]
if video_frames < hist_len:
keep_frames = hist_len - video_frames
history_latent = torch.cat([history_latent[:, :, :keep_frames], vid_latent], dim=2)
history_latent_output = torch.cat([history_latent_output[:, :, :keep_frames], comfy.utils.repeat_to_batch_size(vid_latent_clean, batch_size)], dim=2)
history_valid_mask[:, keep_frames:] = True
else:
history_latent = vid_latent[:, :, -hist_len:]
history_valid_mask[:] = True
image_latent_prefix = history_latent[:, :, :1]
history_latent = vid_latent
history_latent_output = comfy.utils.repeat_to_batch_size(vid_latent_clean, batch_size)
history_valid_mask = torch.ones((batch_size, video_frames), device=latent.device, dtype=torch.bool)
positive, negative = _set_helios_history_values(positive, negative, history_latent, sizes, keep_first_frame, prefix_latent=image_latent_prefix)
return io.NodeOutput(
@ -781,8 +893,14 @@ class HeliosVideoToVideo(io.ComfyNode):
{
"samples": latent,
"helios_history_latent": history_latent,
"helios_history_latent_output": history_latent_output,
"helios_image_latent_prefix": image_latent_prefix,
"helios_history_valid_mask": history_valid_mask,
"helios_num_frames": int(length),
"helios_noise_gen_state": noise_gen_state,
# Keep initial history segment and generated chunks together in sampler output.
"helios_include_history_in_output": _strict_bool(include_history_in_output, default=True),
"helios_debug_latent_stats": bool(debug_latent_stats),
},
)
@ -894,7 +1012,6 @@ class HeliosPyramidSampler(io.ComfyNode):
stage_steps = [max(1, int(s)) for s in stage_steps]
stage_count = len(stage_steps)
history_sizes_list = sorted([max(0, int(v)) for v in _parse_int_list(history_sizes, [16, 2, 1])], reverse=True)
# Diffusers parity: if not keeping first frame, fold prefix slot into short history size.
if not keep_first_frame and len(history_sizes_list) > 0:
history_sizes_list[-1] += 1
@ -912,21 +1029,32 @@ class HeliosPyramidSampler(io.ComfyNode):
b, c, t, h, w = latent_samples.shape
chunk_t = max(1, int(num_latent_frames_per_chunk))
chunk_count = max(1, (t + chunk_t - 1) // chunk_t)
num_frames = int(latent.get("helios_num_frames", max(1, (int(t) - 1) * 4 + 1)))
window_num_frames = (chunk_t - 1) * 4 + 1
chunk_count = max(1, (num_frames + window_num_frames - 1) // window_num_frames)
euler_sampler = comfy.samplers.KSAMPLER(_helios_euler_sample)
target_device = comfy.model_management.get_torch_device()
noise_gen = torch.Generator(device=target_device)
noise_gen.manual_seed(int(noise_seed))
noise_gen_state = latent.get("helios_noise_gen_state", None)
if noise_gen_state is not None:
try:
noise_gen.set_state(noise_gen_state)
except Exception:
pass
debug_latent_stats = bool(latent.get("helios_debug_latent_stats", False))
image_latent_prefix = latent.get("helios_image_latent_prefix", None)
history_valid_mask = latent.get("helios_history_valid_mask", None)
if history_valid_mask is None:
raise ValueError("Helios sampler requires `helios_history_valid_mask` in latent input.")
history_full = None
history_from_latent_applied = False
if image_latent_prefix is not None:
image_latent_prefix = model.model.process_latent_in(image_latent_prefix)
if "helios_history_latent" in latent:
history_in = _process_latent_in_preserve_zero_frames(model, latent["helios_history_latent"], valid_mask=history_valid_mask)
history_full = history_in
positive, negative = _set_helios_history_values(
positive,
negative,
@ -959,8 +1087,6 @@ class HeliosPyramidSampler(io.ComfyNode):
x0_output = {}
generated_chunks = []
if latents_history_short is not None and latents_history_mid is not None and latents_history_long is not None:
# Diffusers parity: `history_latents` storage does NOT include the keep_first_frame prefix slot.
# `latents_history_short` in conditioning may include [prefix + short_base], so strip prefix here.
short_base_size = history_sizes_list[-1] if len(history_sizes_list) > 0 else latents_history_short.shape[2]
if keep_first_frame and latents_history_short.shape[2] > short_base_size:
short_for_history = latents_history_short[:, :, -short_base_size:]
@ -974,7 +1100,7 @@ class HeliosPyramidSampler(io.ComfyNode):
hist_len = max(1, sum(history_sizes_list))
rolling_history = torch.zeros((b, c, hist_len, h, w), device=latent_samples.device, dtype=latent_samples.dtype)
# Align with Diffusers behavior: when initial video latents are provided, seed history buffer
# When initial video latents are provided, seed history buffer
# with those latents before the first denoising chunk.
if not add_noise:
hist_len = max(1, sum(history_sizes_list))
@ -988,9 +1114,29 @@ class HeliosPyramidSampler(io.ComfyNode):
rolling_history = video_latents[:, :, -hist_len:]
# Keep history/prefix on the same device/dtype as denoising latents.
rolling_history = rolling_history.to(device=target_device, dtype=latent_samples.dtype)
rolling_history = rolling_history.to(device=target_device, dtype=torch.float32)
if image_latent_prefix is not None:
image_latent_prefix = image_latent_prefix.to(device=target_device, dtype=latent_samples.dtype)
image_latent_prefix = image_latent_prefix.to(device=target_device, dtype=torch.float32)
history_output = history_full if history_full is not None else rolling_history
if "helios_history_latent_output" in latent:
history_output = _process_latent_in_preserve_zero_frames(
model,
latent["helios_history_latent_output"],
valid_mask=history_valid_mask,
)
history_output = history_output.to(device=target_device, dtype=torch.float32)
if history_valid_mask is not None:
if not torch.is_tensor(history_valid_mask):
history_valid_mask = torch.tensor(history_valid_mask, device=target_device)
history_valid_mask = history_valid_mask.to(device=target_device)
if history_valid_mask.ndim == 2:
initial_generated_latent_frames = int(history_valid_mask.any(dim=0).sum().item())
else:
initial_generated_latent_frames = int(history_valid_mask.reshape(-1).sum().item())
else:
initial_generated_latent_frames = 0
total_generated_latent_frames = initial_generated_latent_frames
for chunk_idx in range(chunk_count):
# Extract chunk from input latents
@ -1000,8 +1146,6 @@ class HeliosPyramidSampler(io.ComfyNode):
# Prepare initial latent for this chunk
if add_noise:
# Diffusers parity: each chunk denoises a fixed latent window size.
# Keep chunk temporal length constant and crop only after all chunks.
noise_shape = (
latent_samples.shape[0],
latent_samples.shape[1],
@ -1009,9 +1153,9 @@ class HeliosPyramidSampler(io.ComfyNode):
latent_samples.shape[3],
latent_samples.shape[4],
)
stage_latent = torch.randn(noise_shape, device=target_device, dtype=latent_samples.dtype, generator=noise_gen)
stage_latent = torch.randn(noise_shape, device=target_device, dtype=torch.float32, generator=noise_gen)
else:
# Use actual input latents; pad final short chunk to fixed size like Diffusers windowing.
# Use actual input latents; pad final short chunk to fixed size.
stage_latent = latent_chunk.clone()
if stage_latent.shape[2] < chunk_t:
if stage_latent.shape[2] == 0:
@ -1024,22 +1168,20 @@ class HeliosPyramidSampler(io.ComfyNode):
latent_samples.shape[4],
),
device=latent_samples.device,
dtype=latent_samples.dtype,
dtype=torch.float32,
)
else:
pad = stage_latent[:, :, -1:].repeat(1, 1, chunk_t - stage_latent.shape[2], 1, 1)
stage_latent = torch.cat([stage_latent, pad], dim=2)
stage_latent = stage_latent.to(dtype=torch.float32)
# Downsample to stage 0 resolution
for _ in range(max(0, int(stage_count) - 1)):
stage_latent = _downsample_latent_5d_bilinear_x2(stage_latent)
# Keep stage latents on model device for parity with Diffusers scheduler/noise path.
# Keep stage latents on model device for scheduler/noise path consistency.
stage_latent = stage_latent.to(target_device)
# Diffusers parity:
# keep_first_frame=True and no image_latent_prefix on the first chunk
# should use an all-zero prefix frame, not history[:, :, :1].
chunk_prefix = image_latent_prefix
if keep_first_frame and image_latent_prefix is None and chunk_idx == 0:
chunk_prefix = torch.zeros(
@ -1065,6 +1207,10 @@ class HeliosPyramidSampler(io.ComfyNode):
latents_history_short = _extract_condition_value(positive_chunk, "latents_history_short")
latents_history_mid = _extract_condition_value(positive_chunk, "latents_history_mid")
latents_history_long = _extract_condition_value(positive_chunk, "latents_history_long")
if debug_latent_stats:
print(f"[HeliosDebug][Sampler][chunk={chunk_idx}] latents_history_short: {_tensor_stats_str(latents_history_short)}")
print(f"[HeliosDebug][Sampler][chunk={chunk_idx}] latents_history_mid: {_tensor_stats_str(latents_history_mid)}")
print(f"[HeliosDebug][Sampler][chunk={chunk_idx}] latents_history_long: {_tensor_stats_str(latents_history_long)}")
for stage_idx in range(stage_count):
stage_latent = stage_latent.to(comfy.model_management.get_torch_device())
@ -1099,8 +1245,7 @@ class HeliosPyramidSampler(io.ComfyNode):
else:
pass
# Keep parity with Diffusers pipeline order:
# stage timesteps are computed before upsampling/renoise for stage > 0.
# Stage timesteps are computed before upsampling/renoise for stage > 0.
if stage_idx > 0:
stage_latent = _upsample_latent_5d(stage_latent, scale=2)
@ -1188,8 +1333,7 @@ class HeliosPyramidSampler(io.ComfyNode):
seed=noise_seed + chunk_idx * 100 + stage_idx,
)
# sample_custom returns latent_format.process_out(samples); convert back to model-space
# so subsequent pyramid stages and history conditioning stay in the same latent space
# as Diffusers' internal denoising latents.
# so subsequent pyramid stages and history conditioning stay in the same latent space.
stage_latent = model.model.process_latent_in(stage_latent)
if stage_latent.shape[-2] != h or stage_latent.shape[-1] != w:
@ -1205,12 +1349,27 @@ class HeliosPyramidSampler(io.ComfyNode):
rolling_history = torch.cat([rolling_history, stage_latent.to(rolling_history.device, rolling_history.dtype)], dim=2)
keep_hist = max(1, sum(history_sizes_list))
rolling_history = rolling_history[:, :, -keep_hist:]
total_generated_latent_frames += stage_latent.shape[2]
history_output = torch.cat([history_output, stage_latent.to(history_output.device, history_output.dtype)], dim=2)
stage_latent = torch.cat(generated_chunks, dim=2)[:, :, :t]
include_history_in_output = _strict_bool(latent.get("helios_include_history_in_output", False), default=False)
if include_history_in_output and history_output is not None:
keep_t = max(0, int(total_generated_latent_frames))
stage_latent = history_output[:, :, -keep_t:] if keep_t > 0 else history_output[:, :, :0]
elif len(generated_chunks) > 0:
stage_latent = torch.cat(generated_chunks, dim=2)
else:
stage_latent = torch.zeros((b, c, 0, h, w), device=target_device, dtype=torch.float32)
out = latent.copy()
out.pop("downscale_ratio_spacial", None)
out["samples"] = model.model.process_latent_out(stage_latent)
out["helios_chunk_decode"] = True
out["helios_chunk_latent_frames"] = int(chunk_t)
out["helios_chunk_count"] = int(len(generated_chunks))
out["helios_window_num_frames"] = int(window_num_frames)
out["helios_num_frames"] = int(num_frames)
out["helios_prefix_latent_frames"] = int(initial_generated_latent_frames if include_history_in_output else 0)
if "x0" in x0_output:
x0_out = model.model.process_latent_out(x0_output["x0"].cpu())
@ -1222,6 +1381,60 @@ class HeliosPyramidSampler(io.ComfyNode):
return io.NodeOutput(out, out_denoised)
class HeliosVAEDecode(io.ComfyNode):
    """VAE decode node that can decode Helios sampler output chunk by chunk.

    When the latent dict carries ``helios_chunk_decode`` and a positive
    ``helios_chunk_latent_frames``, a 5-D latent is split along dim 2 into an
    optional prefix (``helios_prefix_latent_frames`` frames) followed by
    fixed-size chunks. Each piece is decoded separately and the results are
    concatenated along dim 1. Otherwise the whole latent is decoded in one call.
    """

    @classmethod
    def define_schema(cls):
        """Describe the node: a latent and a VAE in, a single image output."""
        node_inputs = [
            io.Latent.Input("samples"),
            io.Vae.Input("vae"),
        ]
        node_outputs = [io.Image.Output(display_name="image")]
        return io.Schema(
            node_id="HeliosVAEDecode",
            category="latent",
            inputs=node_inputs,
            outputs=node_outputs,
        )

    @classmethod
    def execute(cls, samples, vae) -> io.NodeOutput:
        """Decode ``samples["samples"]`` with ``vae`` and return the images.

        Args:
            samples: Latent dict; ``"samples"`` is required, the ``helios_*``
                keys are optional and default to "no chunked decode".
            vae: Object exposing ``decode(latent)``.

        Returns:
            ``io.NodeOutput`` wrapping the decoded images; a 5-D decode result
            is flattened to 4-D by merging all leading dims.
        """
        latent = samples["samples"]
        if latent.is_nested:
            latent = latent.unbind()[0]

        chunked = bool(samples.get("helios_chunk_decode", False))
        frames_per_chunk = int(samples.get("helios_chunk_latent_frames", 0) or 0)
        prefix_frames = int(samples.get("helios_prefix_latent_frames", 0) or 0)

        pieces = []
        if chunked and latent.ndim == 5 and frames_per_chunk > 0 and latent.shape[2] > 0:
            total_t = latent.shape[2]
            # Clamp the prefix to the frames actually present.
            split_at = max(0, min(prefix_frames, total_t))

            # Slice bounds along dim 2: optional prefix first, then fixed-size chunks.
            bounds = [(0, split_at)] if split_at > 0 else []
            bounds.extend(
                (begin, begin + frames_per_chunk)
                for begin in range(split_at, total_t, frames_per_chunk)
            )
            pieces = [vae.decode(latent[:, :, lo:hi]) for lo, hi in bounds]

        if pieces:
            # NOTE(review): assumes vae.decode puts frames on dim 1 - confirm against the VAE.
            images = torch.cat(pieces, dim=1)
        else:
            images = vae.decode(latent)

        if images.ndim == 5:
            images = images.reshape(-1, *images.shape[-3:])

        return io.NodeOutput(images)
class HeliosExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
@ -1231,6 +1444,7 @@ class HeliosExtension(ComfyExtension):
HeliosVideoToVideo,
HeliosHistoryConditioning,
HeliosPyramidSampler,
HeliosVAEDecode,
]