mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-03-17 23:25:05 +08:00
Enhance Helios model with latent space noise application and debugging options
This commit is contained in:
parent
d93133ee53
commit
26f44ab770
@ -652,8 +652,8 @@ class HeliosModel(torch.nn.Module):
|
||||
)
|
||||
original_context_length = hidden_states.shape[1]
|
||||
|
||||
if (latents_history_short is not None and indices_latents_history_short is not None and hasattr(self, "patch_short")):
|
||||
x_short = self.patch_short(latents_history_short).to(hidden_states.dtype)
|
||||
if latents_history_short is not None and indices_latents_history_short is not None:
|
||||
x_short = self.patch_short(latents_history_short)
|
||||
_, _, ts, hs, ws = x_short.shape
|
||||
x_short = x_short.flatten(2).transpose(1, 2)
|
||||
f_short = self.rope_encode(
|
||||
@ -671,57 +671,45 @@ class HeliosModel(torch.nn.Module):
|
||||
hidden_states = torch.cat([x_short, hidden_states], dim=1)
|
||||
freqs = torch.cat([f_short, freqs], dim=1)
|
||||
|
||||
if (latents_history_mid is not None and indices_latents_history_mid is not None and hasattr(self, "patch_mid")):
|
||||
x_mid = self.patch_mid(pad_for_3d_conv(latents_history_mid, (2, 4, 4))).to(hidden_states.dtype)
|
||||
_, _, tm, hm, wm = x_mid.shape
|
||||
if latents_history_mid is not None and indices_latents_history_mid is not None:
|
||||
x_mid = self.patch_mid(pad_for_3d_conv(latents_history_mid, (2, 4, 4)))
|
||||
_, _, tm, _, _ = x_mid.shape
|
||||
x_mid = x_mid.flatten(2).transpose(1, 2)
|
||||
mid_t = indices_latents_history_mid.shape[1]
|
||||
if ("hs" in locals()) and ("ws" in locals()):
|
||||
mid_h, mid_w = hs, ws
|
||||
else:
|
||||
mid_h, mid_w = hm * 2, wm * 2
|
||||
f_mid = self.rope_encode(
|
||||
t=mid_t * self.patch_size[0],
|
||||
h=mid_h * self.patch_size[1],
|
||||
w=mid_w * self.patch_size[2],
|
||||
h=hs * self.patch_size[1],
|
||||
w=ws * self.patch_size[2],
|
||||
steps_t=mid_t,
|
||||
steps_h=mid_h,
|
||||
steps_w=mid_w,
|
||||
steps_h=hs,
|
||||
steps_w=ws,
|
||||
device=x_mid.device,
|
||||
dtype=x_mid.dtype,
|
||||
transformer_options=transformer_options,
|
||||
frame_indices=indices_latents_history_mid,
|
||||
)
|
||||
f_mid = self._rope_downsample_3d(f_mid, (mid_t, mid_h, mid_w), (2, 2, 2))
|
||||
if f_mid.shape[1] != x_mid.shape[1]:
|
||||
f_mid = f_mid[:, :x_mid.shape[1]]
|
||||
f_mid = self._rope_downsample_3d(f_mid, (mid_t, hs, ws), (2, 2, 2))
|
||||
hidden_states = torch.cat([x_mid, hidden_states], dim=1)
|
||||
freqs = torch.cat([f_mid, freqs], dim=1)
|
||||
|
||||
if (latents_history_long is not None and indices_latents_history_long is not None and hasattr(self, "patch_long")):
|
||||
x_long = self.patch_long(pad_for_3d_conv(latents_history_long, (4, 8, 8))).to(hidden_states.dtype)
|
||||
_, _, tl, hl, wl = x_long.shape
|
||||
if latents_history_long is not None and indices_latents_history_long is not None:
|
||||
x_long = self.patch_long(pad_for_3d_conv(latents_history_long, (4, 8, 8)))
|
||||
_, _, tl, _, _ = x_long.shape
|
||||
x_long = x_long.flatten(2).transpose(1, 2)
|
||||
long_t = indices_latents_history_long.shape[1]
|
||||
if ("hs" in locals()) and ("ws" in locals()):
|
||||
long_h, long_w = hs, ws
|
||||
else:
|
||||
long_h, long_w = hl * 4, wl * 4
|
||||
f_long = self.rope_encode(
|
||||
t=long_t * self.patch_size[0],
|
||||
h=long_h * self.patch_size[1],
|
||||
w=long_w * self.patch_size[2],
|
||||
h=hs * self.patch_size[1],
|
||||
w=ws * self.patch_size[2],
|
||||
steps_t=long_t,
|
||||
steps_h=long_h,
|
||||
steps_w=long_w,
|
||||
steps_h=hs,
|
||||
steps_w=ws,
|
||||
device=x_long.device,
|
||||
dtype=x_long.dtype,
|
||||
transformer_options=transformer_options,
|
||||
frame_indices=indices_latents_history_long,
|
||||
)
|
||||
f_long = self._rope_downsample_3d(f_long, (long_t, long_h, long_w), (4, 4, 4))
|
||||
if f_long.shape[1] != x_long.shape[1]:
|
||||
f_long = f_long[:, :x_long.shape[1]]
|
||||
f_long = self._rope_downsample_3d(f_long, (long_t, hs, ws), (4, 4, 4))
|
||||
hidden_states = torch.cat([x_long, hidden_states], dim=1)
|
||||
freqs = torch.cat([f_long, freqs], dim=1)
|
||||
|
||||
|
||||
@ -7,6 +7,7 @@ import comfy.model_patcher
|
||||
import comfy.sample
|
||||
import comfy.samplers
|
||||
import comfy.utils
|
||||
import comfy.latent_formats
|
||||
import latent_preview
|
||||
import node_helpers
|
||||
|
||||
@ -41,6 +42,37 @@ def _parse_int_list(values, default):
|
||||
return out if len(out) > 0 else default
|
||||
|
||||
|
||||
_HELIOS_LATENT_FORMAT = comfy.latent_formats.Helios()
|
||||
|
||||
|
||||
def _apply_helios_latent_space_noise(latent, sigma, generator=None):
|
||||
"""Apply noise in Helios model latent space, then map back to VAE latent space."""
|
||||
latent_in = _HELIOS_LATENT_FORMAT.process_in(latent)
|
||||
noise = torch.randn(
|
||||
latent_in.shape,
|
||||
device=latent_in.device,
|
||||
dtype=latent_in.dtype,
|
||||
generator=generator,
|
||||
)
|
||||
noised_in = sigma * noise + (1.0 - sigma) * latent_in
|
||||
return _HELIOS_LATENT_FORMAT.process_out(noised_in).to(device=latent.device, dtype=latent.dtype)
|
||||
|
||||
|
||||
def _tensor_stats_str(x):
|
||||
if x is None:
|
||||
return "None"
|
||||
if not torch.is_tensor(x):
|
||||
return f"non-tensor type={type(x)}"
|
||||
if x.numel() == 0:
|
||||
return f"shape={tuple(x.shape)} empty"
|
||||
xf = x.detach().to(torch.float32)
|
||||
return (
|
||||
f"shape={tuple(x.shape)} "
|
||||
f"mean={xf.mean().item():.6f} std={xf.std(unbiased=False).item():.6f} "
|
||||
f"min={xf.min().item():.6f} max={xf.max().item():.6f}"
|
||||
)
|
||||
|
||||
|
||||
def _parse_float_list(values, default):
|
||||
if values is None:
|
||||
return default
|
||||
@ -65,6 +97,15 @@ def _parse_float_list(values, default):
|
||||
return out if len(out) > 0 else default
|
||||
|
||||
|
||||
def _strict_bool(value, default=False):
|
||||
if isinstance(value, bool):
|
||||
return value
|
||||
if isinstance(value, int):
|
||||
return value != 0
|
||||
# Reject non-bool numerics from stale workflows (e.g. 0.135).
|
||||
return bool(default)
|
||||
|
||||
|
||||
def _extract_condition_value(conditioning, key):
|
||||
for c in conditioning:
|
||||
if len(c) < 2:
|
||||
@ -94,8 +135,9 @@ def _process_latent_in_preserve_zero_frames(model, latent, valid_mask=None):
|
||||
return latent
|
||||
|
||||
if nonzero.shape[0] != latent.shape[2]:
|
||||
# Keep behavior safe when mask length does not match temporal length.
|
||||
nonzero = torch.zeros((latent.shape[2],), device=latent.device, dtype=torch.bool)
|
||||
raise ValueError(
|
||||
f"Helios history mask length mismatch: mask_t={nonzero.shape[0]} latent_t={latent.shape[2]}"
|
||||
)
|
||||
|
||||
converted = model.model.process_latent_in(latent)
|
||||
out = latent.clone()
|
||||
@ -133,7 +175,7 @@ def _prepare_stage0_latent(batch, channels, frames, height, width, stage_count,
|
||||
|
||||
|
||||
def _downsample_latent_for_stage0(latent, stage_count):
|
||||
"""Downsample latent to stage 0 resolution (like Diffusers does)"""
|
||||
"""Downsample latent to stage 0 resolution."""
|
||||
stage_latent = latent
|
||||
for _ in range(max(0, int(stage_count) - 1)):
|
||||
stage_latent = _downsample_latent_5d_bilinear_x2(stage_latent)
|
||||
@ -154,7 +196,7 @@ def _sample_block_noise_like(latent, gamma, patch_size=(1, 2, 2), generator=None
|
||||
block_number = b * c * t * h_blocks * w_blocks
|
||||
|
||||
if generator is not None:
|
||||
# Exact Diffusers sampling path (MultivariateNormal.sample), while consuming
|
||||
# Exact sampling path (MultivariateNormal.sample), while consuming
|
||||
# from an explicit generator by temporarily swapping default RNG state.
|
||||
with torch.random.fork_rng(devices=[latent.device] if latent.device.type == "cuda" else []):
|
||||
if latent.device.type == "cuda":
|
||||
@ -231,7 +273,7 @@ def _helios_stage_tables(stage_count, stage_range, gamma, num_train_timesteps=10
|
||||
tmax = min(float(sigmas[int(start_ratio * num_train_timesteps)].item() * num_train_timesteps), 999.0)
|
||||
tmin = float(sigmas[min(int(end_ratio * num_train_timesteps), num_train_timesteps - 1)].item() * num_train_timesteps)
|
||||
timesteps_per_stage[i] = torch.linspace(tmax, tmin, num_train_timesteps + 1)[:-1]
|
||||
# Fixed: Use same sigma range [0.999, 0] for all stages like Diffusers
|
||||
# Fixed: use the same sigma range [0.999, 0] for all stages.
|
||||
sigmas_per_stage[i] = torch.linspace(0.999, 0.0, num_train_timesteps + 1)[:-1]
|
||||
|
||||
|
||||
@ -302,21 +344,18 @@ def _build_cfg_zero_star_pre_cfg(stage_idx, zero_steps, use_zero_init):
|
||||
state["i"] += 1
|
||||
return conds_out
|
||||
|
||||
denoised_text = conds_out[0] # apply_model 返回的 denoised
|
||||
denoised_text = conds_out[0]
|
||||
denoised_uncond = conds_out[1]
|
||||
cfg = float(args.get("cond_scale", 1.0))
|
||||
x = args["input"] # 当前的 noisy latent
|
||||
sigma = args["sigma"] # 当前的 sigma
|
||||
x = args["input"]
|
||||
sigma = args["sigma"]
|
||||
|
||||
# 关键修复:将 denoised 转换为 flow
|
||||
# denoised = x - flow * sigma => flow = (x - denoised) / sigma
|
||||
sigma_reshaped = sigma.reshape(sigma.shape[0], *([1] * (denoised_text.ndim - 1)))
|
||||
sigma_safe = torch.clamp(sigma_reshaped, min=1e-8)
|
||||
|
||||
flow_text = (x - denoised_text) / sigma_safe
|
||||
flow_uncond = (x - denoised_uncond) / sigma_safe
|
||||
|
||||
# 在 flow 空间做 CFG Zero Star
|
||||
positive_flat = flow_text.reshape(flow_text.shape[0], -1)
|
||||
negative_flat = flow_uncond.reshape(flow_uncond.shape[0], -1)
|
||||
alpha = _optimized_scale(positive_flat, negative_flat)
|
||||
@ -327,11 +366,9 @@ def _build_cfg_zero_star_pre_cfg(stage_idx, zero_steps, use_zero_init):
|
||||
else:
|
||||
flow_final = flow_uncond * alpha + cfg * (flow_text - flow_uncond * alpha)
|
||||
|
||||
# 将 flow 转回 denoised
|
||||
denoised_final = x - flow_final * sigma_safe
|
||||
|
||||
state["i"] += 1
|
||||
# Return identical cond/uncond so downstream cfg_function keeps `final` unchanged.
|
||||
return [denoised_final, denoised_final]
|
||||
|
||||
return pre_cfg_fn
|
||||
@ -519,6 +556,8 @@ class HeliosImageToVideo(io.ComfyNode):
|
||||
io.Float.Input("image_noise_sigma_min", default=0.111, min=0.0, max=1.0, step=0.0001, round=False, advanced=True),
|
||||
io.Float.Input("image_noise_sigma_max", default=0.135, min=0.0, max=1.0, step=0.0001, round=False, advanced=True),
|
||||
io.Int.Input("noise_seed", default=0, min=0, max=0xFFFFFFFFFFFFFFFF, advanced=True),
|
||||
io.Boolean.Input("include_history_in_output", default=False, advanced=True),
|
||||
io.Boolean.Input("debug_latent_stats", default=False, advanced=True),
|
||||
],
|
||||
outputs=[
|
||||
io.Conditioning.Output(display_name="positive"),
|
||||
@ -545,7 +584,11 @@ class HeliosImageToVideo(io.ComfyNode):
|
||||
image_noise_sigma_min=0.111,
|
||||
image_noise_sigma_max=0.135,
|
||||
noise_seed=0,
|
||||
include_history_in_output=False,
|
||||
debug_latent_stats=False,
|
||||
) -> io.NodeOutput:
|
||||
video_noise_sigma_min = 0.111
|
||||
video_noise_sigma_max = 0.135
|
||||
spacial_scale = vae.spacial_compression_encode()
|
||||
latent_channels = vae.latent_channels
|
||||
latent_t = ((length - 1) // 4) + 1
|
||||
@ -560,10 +603,11 @@ class HeliosImageToVideo(io.ComfyNode):
|
||||
history_valid_mask = torch.zeros((batch_size, hist_len), device=latent.device, dtype=torch.bool)
|
||||
image_latent_prefix = None
|
||||
i2v_noise_gen = None
|
||||
noise_gen_state = None
|
||||
|
||||
if start_image is not None:
|
||||
image = comfy.utils.common_upscale(start_image[:1].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
|
||||
img_latent = vae.encode(image[:, :, :, :3])
|
||||
img_latent = vae.encode(image[:, :, :, :3]).to(device=latent.device, dtype=torch.float32)
|
||||
img_latent = comfy.utils.repeat_to_batch_size(img_latent, batch_size)
|
||||
image_latent_prefix = img_latent[:, :, :1]
|
||||
|
||||
@ -571,33 +615,38 @@ class HeliosImageToVideo(io.ComfyNode):
|
||||
i2v_noise_gen = torch.Generator(device=img_latent.device)
|
||||
i2v_noise_gen.manual_seed(int(noise_seed))
|
||||
sigma = (
|
||||
torch.rand((img_latent.shape[0], 1, 1, 1, 1), device=img_latent.device, generator=i2v_noise_gen, dtype=img_latent.dtype)
|
||||
torch.rand((1,), device=img_latent.device, generator=i2v_noise_gen, dtype=img_latent.dtype).view(1, 1, 1, 1, 1)
|
||||
* (float(image_noise_sigma_max) - float(image_noise_sigma_min))
|
||||
+ float(image_noise_sigma_min)
|
||||
)
|
||||
image_latent_prefix = sigma * torch.randn_like(image_latent_prefix, generator=i2v_noise_gen) + (1.0 - sigma) * image_latent_prefix
|
||||
image_latent_prefix = _apply_helios_latent_space_noise(image_latent_prefix, sigma, generator=i2v_noise_gen)
|
||||
|
||||
min_frames = max(1, (int(num_latent_frames_per_chunk) - 1) * 4 + 1)
|
||||
fake_video = image.repeat(min_frames, 1, 1, 1)
|
||||
fake_latents_full = vae.encode(fake_video)
|
||||
fake_latents_full = vae.encode(fake_video).to(device=latent.device, dtype=torch.float32)
|
||||
fake_latent = comfy.utils.repeat_to_batch_size(fake_latents_full[:, :, -1:], batch_size)
|
||||
# Diffusers parity for I2V:
|
||||
# when adding noise to image latents, fake_image_latents used for history are also noised.
|
||||
if add_noise_to_image_latents:
|
||||
if i2v_noise_gen is None:
|
||||
i2v_noise_gen = torch.Generator(device=fake_latent.device)
|
||||
i2v_noise_gen.manual_seed(int(noise_seed))
|
||||
# Keep backward compatibility with existing I2V node inputs:
|
||||
# this node exposes only image sigma controls, while fake history
|
||||
# latents follow the video-noise path in Diffusers.
|
||||
# this node exposes only image sigma controls; fake history latents
|
||||
# follow the video-noise defaults.
|
||||
fake_sigma = (
|
||||
torch.rand((fake_latent.shape[0], 1, 1, 1, 1), device=fake_latent.device, generator=i2v_noise_gen, dtype=fake_latent.dtype)
|
||||
* (float(image_noise_sigma_max) - float(image_noise_sigma_min))
|
||||
+ float(image_noise_sigma_min)
|
||||
torch.rand((1,), device=fake_latent.device, generator=i2v_noise_gen, dtype=fake_latent.dtype).view(1, 1, 1, 1, 1)
|
||||
* (float(video_noise_sigma_max) - float(video_noise_sigma_min))
|
||||
+ float(video_noise_sigma_min)
|
||||
)
|
||||
fake_latent = fake_sigma * torch.randn_like(fake_latent, generator=i2v_noise_gen) + (1.0 - fake_sigma) * fake_latent
|
||||
fake_latent = _apply_helios_latent_space_noise(fake_latent, fake_sigma, generator=i2v_noise_gen)
|
||||
history_latent[:, :, -1:] = fake_latent
|
||||
history_valid_mask[:, -1] = True
|
||||
if i2v_noise_gen is not None:
|
||||
noise_gen_state = i2v_noise_gen.get_state().clone()
|
||||
if debug_latent_stats:
|
||||
print(f"[HeliosDebug][I2V] image_latent_prefix: {_tensor_stats_str(image_latent_prefix)}")
|
||||
print(f"[HeliosDebug][I2V] fake_latent: {_tensor_stats_str(fake_latent)}")
|
||||
print(f"[HeliosDebug][I2V] history_latent: {_tensor_stats_str(history_latent)}")
|
||||
|
||||
positive, negative = _set_helios_history_values(positive, negative, history_latent, sizes, keep_first_frame, prefix_latent=image_latent_prefix)
|
||||
return io.NodeOutput(
|
||||
@ -608,6 +657,10 @@ class HeliosImageToVideo(io.ComfyNode):
|
||||
"helios_history_latent": history_latent,
|
||||
"helios_image_latent_prefix": image_latent_prefix,
|
||||
"helios_history_valid_mask": history_valid_mask,
|
||||
"helios_num_frames": int(length),
|
||||
"helios_noise_gen_state": noise_gen_state,
|
||||
"helios_include_history_in_output": _strict_bool(include_history_in_output, default=False),
|
||||
"helios_debug_latent_stats": bool(debug_latent_stats),
|
||||
},
|
||||
)
|
||||
|
||||
@ -686,6 +739,7 @@ class HeliosTextToVideo(io.ComfyNode):
|
||||
"helios_history_latent": history_latent,
|
||||
"helios_image_latent_prefix": None,
|
||||
"helios_history_valid_mask": history_valid_mask,
|
||||
"helios_num_frames": int(length),
|
||||
},
|
||||
)
|
||||
|
||||
@ -707,10 +761,13 @@ class HeliosVideoToVideo(io.ComfyNode):
|
||||
io.Image.Input("video", optional=True),
|
||||
io.String.Input("history_sizes", default="16,2,1", advanced=True),
|
||||
io.Boolean.Input("keep_first_frame", default=True, advanced=True),
|
||||
io.Int.Input("num_latent_frames_per_chunk", default=9, min=1, max=256, advanced=True),
|
||||
io.Boolean.Input("add_noise_to_video_latents", default=True, advanced=True),
|
||||
io.Float.Input("video_noise_sigma_min", default=0.111, min=0.0, max=1.0, step=0.0001, round=False, advanced=True),
|
||||
io.Float.Input("video_noise_sigma_max", default=0.135, min=0.0, max=1.0, step=0.0001, round=False, advanced=True),
|
||||
io.Int.Input("noise_seed", default=0, min=0, max=0xFFFFFFFFFFFFFFFF, advanced=True),
|
||||
io.Boolean.Input("include_history_in_output", default=True, advanced=True),
|
||||
io.Boolean.Input("debug_latent_stats", default=False, advanced=True),
|
||||
],
|
||||
outputs=[
|
||||
io.Conditioning.Output(display_name="positive"),
|
||||
@ -732,10 +789,13 @@ class HeliosVideoToVideo(io.ComfyNode):
|
||||
video=None,
|
||||
history_sizes="16,2,1",
|
||||
keep_first_frame=True,
|
||||
num_latent_frames_per_chunk=9,
|
||||
add_noise_to_video_latents=True,
|
||||
video_noise_sigma_min=0.111,
|
||||
video_noise_sigma_max=0.135,
|
||||
noise_seed=0,
|
||||
include_history_in_output=True,
|
||||
debug_latent_stats=False,
|
||||
) -> io.NodeOutput:
|
||||
spacial_scale = vae.spacial_compression_encode()
|
||||
latent_channels = vae.latent_channels
|
||||
@ -750,29 +810,81 @@ class HeliosVideoToVideo(io.ComfyNode):
|
||||
history_latent = torch.zeros([batch_size, latent_channels, hist_len, latent.shape[-2], latent.shape[-1]], device=latent.device, dtype=latent.dtype)
|
||||
history_valid_mask = torch.zeros((batch_size, hist_len), device=latent.device, dtype=torch.bool)
|
||||
image_latent_prefix = None
|
||||
noise_gen_state = None
|
||||
history_latent_output = history_latent
|
||||
|
||||
if video is not None:
|
||||
video = comfy.utils.common_upscale(video[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
|
||||
vid_latent = vae.encode(video[:, :, :, :3])
|
||||
num_frames = int(video.shape[0])
|
||||
min_frames = max(1, (int(num_latent_frames_per_chunk) - 1) * 4 + 1)
|
||||
num_chunks = num_frames // min_frames
|
||||
if num_chunks == 0:
|
||||
raise ValueError(
|
||||
f"Video must have at least {min_frames} frames (got {num_frames} frames). "
|
||||
f"Required: (num_latent_frames_per_chunk - 1) * 4 + 1 = ({int(num_latent_frames_per_chunk)} - 1) * 4 + 1 = {min_frames}"
|
||||
)
|
||||
|
||||
first_frame = video[:1]
|
||||
first_frame_latent = vae.encode(first_frame[:, :, :, :3]).to(device=latent.device, dtype=torch.float32)
|
||||
|
||||
total_valid_frames = num_chunks * min_frames
|
||||
start_frame = num_frames - total_valid_frames
|
||||
latents_chunks = []
|
||||
for i in range(num_chunks):
|
||||
chunk_start = start_frame + i * min_frames
|
||||
chunk_end = chunk_start + min_frames
|
||||
video_chunk = video[chunk_start:chunk_end]
|
||||
chunk_latents = vae.encode(video_chunk[:, :, :, :3]).to(device=latent.device, dtype=torch.float32)
|
||||
latents_chunks.append(chunk_latents)
|
||||
vid_latent = torch.cat(latents_chunks, dim=2)
|
||||
vid_latent_clean = vid_latent.clone()
|
||||
|
||||
if add_noise_to_video_latents:
|
||||
g = torch.Generator(device=vid_latent.device)
|
||||
g.manual_seed(int(noise_seed))
|
||||
frame_sigmas = (
|
||||
torch.rand((1, 1, vid_latent.shape[2], 1, 1), device=vid_latent.device, generator=g, dtype=vid_latent.dtype)
|
||||
|
||||
image_sigma = (
|
||||
torch.rand((1,), device=first_frame_latent.device, generator=g, dtype=first_frame_latent.dtype).view(1, 1, 1, 1, 1)
|
||||
* (float(video_noise_sigma_max) - float(video_noise_sigma_min))
|
||||
+ float(video_noise_sigma_min)
|
||||
)
|
||||
vid_latent = frame_sigmas * torch.randn_like(vid_latent, generator=g) + (1.0 - frame_sigmas) * vid_latent
|
||||
vid_latent = vid_latent[:, :, :hist_len]
|
||||
first_frame_latent = _apply_helios_latent_space_noise(first_frame_latent, image_sigma, generator=g)
|
||||
|
||||
noisy_chunks = []
|
||||
num_latent_chunks = max(1, vid_latent.shape[2] // int(num_latent_frames_per_chunk))
|
||||
for i in range(num_latent_chunks):
|
||||
chunk_start = i * int(num_latent_frames_per_chunk)
|
||||
chunk_end = chunk_start + int(num_latent_frames_per_chunk)
|
||||
latent_chunk = vid_latent[:, :, chunk_start:chunk_end, :, :]
|
||||
if latent_chunk.shape[2] == 0:
|
||||
continue
|
||||
chunk_frames = latent_chunk.shape[2]
|
||||
frame_sigmas = (
|
||||
torch.rand((chunk_frames,), device=vid_latent.device, generator=g, dtype=vid_latent.dtype)
|
||||
* (float(video_noise_sigma_max) - float(video_noise_sigma_min))
|
||||
+ float(video_noise_sigma_min)
|
||||
).view(1, 1, chunk_frames, 1, 1)
|
||||
noisy_chunk = _apply_helios_latent_space_noise(latent_chunk, frame_sigmas, generator=g)
|
||||
noisy_chunks.append(noisy_chunk)
|
||||
if len(noisy_chunks) > 0:
|
||||
vid_latent = torch.cat(noisy_chunks, dim=2)
|
||||
noise_gen_state = g.get_state().clone()
|
||||
if debug_latent_stats:
|
||||
print(f"[HeliosDebug][V2V] first_frame_latent: {_tensor_stats_str(first_frame_latent)}")
|
||||
print(f"[HeliosDebug][V2V] video_latent: {_tensor_stats_str(vid_latent)}")
|
||||
|
||||
vid_latent = comfy.utils.repeat_to_batch_size(vid_latent, batch_size)
|
||||
if vid_latent.shape[2] < hist_len:
|
||||
keep_frames = hist_len - vid_latent.shape[2]
|
||||
image_latent_prefix = comfy.utils.repeat_to_batch_size(first_frame_latent, batch_size)
|
||||
video_frames = vid_latent.shape[2]
|
||||
if video_frames < hist_len:
|
||||
keep_frames = hist_len - video_frames
|
||||
history_latent = torch.cat([history_latent[:, :, :keep_frames], vid_latent], dim=2)
|
||||
history_latent_output = torch.cat([history_latent_output[:, :, :keep_frames], comfy.utils.repeat_to_batch_size(vid_latent_clean, batch_size)], dim=2)
|
||||
history_valid_mask[:, keep_frames:] = True
|
||||
else:
|
||||
history_latent = vid_latent[:, :, -hist_len:]
|
||||
history_valid_mask[:] = True
|
||||
image_latent_prefix = history_latent[:, :, :1]
|
||||
history_latent = vid_latent
|
||||
history_latent_output = comfy.utils.repeat_to_batch_size(vid_latent_clean, batch_size)
|
||||
history_valid_mask = torch.ones((batch_size, video_frames), device=latent.device, dtype=torch.bool)
|
||||
|
||||
positive, negative = _set_helios_history_values(positive, negative, history_latent, sizes, keep_first_frame, prefix_latent=image_latent_prefix)
|
||||
return io.NodeOutput(
|
||||
@ -781,8 +893,14 @@ class HeliosVideoToVideo(io.ComfyNode):
|
||||
{
|
||||
"samples": latent,
|
||||
"helios_history_latent": history_latent,
|
||||
"helios_history_latent_output": history_latent_output,
|
||||
"helios_image_latent_prefix": image_latent_prefix,
|
||||
"helios_history_valid_mask": history_valid_mask,
|
||||
"helios_num_frames": int(length),
|
||||
"helios_noise_gen_state": noise_gen_state,
|
||||
# Keep initial history segment and generated chunks together in sampler output.
|
||||
"helios_include_history_in_output": _strict_bool(include_history_in_output, default=True),
|
||||
"helios_debug_latent_stats": bool(debug_latent_stats),
|
||||
},
|
||||
)
|
||||
|
||||
@ -894,7 +1012,6 @@ class HeliosPyramidSampler(io.ComfyNode):
|
||||
stage_steps = [max(1, int(s)) for s in stage_steps]
|
||||
stage_count = len(stage_steps)
|
||||
history_sizes_list = sorted([max(0, int(v)) for v in _parse_int_list(history_sizes, [16, 2, 1])], reverse=True)
|
||||
# Diffusers parity: if not keeping first frame, fold prefix slot into short history size.
|
||||
if not keep_first_frame and len(history_sizes_list) > 0:
|
||||
history_sizes_list[-1] += 1
|
||||
|
||||
@ -912,21 +1029,32 @@ class HeliosPyramidSampler(io.ComfyNode):
|
||||
|
||||
b, c, t, h, w = latent_samples.shape
|
||||
chunk_t = max(1, int(num_latent_frames_per_chunk))
|
||||
chunk_count = max(1, (t + chunk_t - 1) // chunk_t)
|
||||
num_frames = int(latent.get("helios_num_frames", max(1, (int(t) - 1) * 4 + 1)))
|
||||
window_num_frames = (chunk_t - 1) * 4 + 1
|
||||
chunk_count = max(1, (num_frames + window_num_frames - 1) // window_num_frames)
|
||||
euler_sampler = comfy.samplers.KSAMPLER(_helios_euler_sample)
|
||||
target_device = comfy.model_management.get_torch_device()
|
||||
noise_gen = torch.Generator(device=target_device)
|
||||
noise_gen.manual_seed(int(noise_seed))
|
||||
noise_gen_state = latent.get("helios_noise_gen_state", None)
|
||||
if noise_gen_state is not None:
|
||||
try:
|
||||
noise_gen.set_state(noise_gen_state)
|
||||
except Exception:
|
||||
pass
|
||||
debug_latent_stats = bool(latent.get("helios_debug_latent_stats", False))
|
||||
|
||||
image_latent_prefix = latent.get("helios_image_latent_prefix", None)
|
||||
history_valid_mask = latent.get("helios_history_valid_mask", None)
|
||||
if history_valid_mask is None:
|
||||
raise ValueError("Helios sampler requires `helios_history_valid_mask` in latent input.")
|
||||
history_full = None
|
||||
history_from_latent_applied = False
|
||||
if image_latent_prefix is not None:
|
||||
image_latent_prefix = model.model.process_latent_in(image_latent_prefix)
|
||||
if "helios_history_latent" in latent:
|
||||
history_in = _process_latent_in_preserve_zero_frames(model, latent["helios_history_latent"], valid_mask=history_valid_mask)
|
||||
history_full = history_in
|
||||
positive, negative = _set_helios_history_values(
|
||||
positive,
|
||||
negative,
|
||||
@ -959,8 +1087,6 @@ class HeliosPyramidSampler(io.ComfyNode):
|
||||
x0_output = {}
|
||||
generated_chunks = []
|
||||
if latents_history_short is not None and latents_history_mid is not None and latents_history_long is not None:
|
||||
# Diffusers parity: `history_latents` storage does NOT include the keep_first_frame prefix slot.
|
||||
# `latents_history_short` in conditioning may include [prefix + short_base], so strip prefix here.
|
||||
short_base_size = history_sizes_list[-1] if len(history_sizes_list) > 0 else latents_history_short.shape[2]
|
||||
if keep_first_frame and latents_history_short.shape[2] > short_base_size:
|
||||
short_for_history = latents_history_short[:, :, -short_base_size:]
|
||||
@ -974,7 +1100,7 @@ class HeliosPyramidSampler(io.ComfyNode):
|
||||
hist_len = max(1, sum(history_sizes_list))
|
||||
rolling_history = torch.zeros((b, c, hist_len, h, w), device=latent_samples.device, dtype=latent_samples.dtype)
|
||||
|
||||
# Align with Diffusers behavior: when initial video latents are provided, seed history buffer
|
||||
# When initial video latents are provided, seed history buffer
|
||||
# with those latents before the first denoising chunk.
|
||||
if not add_noise:
|
||||
hist_len = max(1, sum(history_sizes_list))
|
||||
@ -988,9 +1114,29 @@ class HeliosPyramidSampler(io.ComfyNode):
|
||||
rolling_history = video_latents[:, :, -hist_len:]
|
||||
|
||||
# Keep history/prefix on the same device/dtype as denoising latents.
|
||||
rolling_history = rolling_history.to(device=target_device, dtype=latent_samples.dtype)
|
||||
rolling_history = rolling_history.to(device=target_device, dtype=torch.float32)
|
||||
if image_latent_prefix is not None:
|
||||
image_latent_prefix = image_latent_prefix.to(device=target_device, dtype=latent_samples.dtype)
|
||||
image_latent_prefix = image_latent_prefix.to(device=target_device, dtype=torch.float32)
|
||||
|
||||
history_output = history_full if history_full is not None else rolling_history
|
||||
if "helios_history_latent_output" in latent:
|
||||
history_output = _process_latent_in_preserve_zero_frames(
|
||||
model,
|
||||
latent["helios_history_latent_output"],
|
||||
valid_mask=history_valid_mask,
|
||||
)
|
||||
history_output = history_output.to(device=target_device, dtype=torch.float32)
|
||||
if history_valid_mask is not None:
|
||||
if not torch.is_tensor(history_valid_mask):
|
||||
history_valid_mask = torch.tensor(history_valid_mask, device=target_device)
|
||||
history_valid_mask = history_valid_mask.to(device=target_device)
|
||||
if history_valid_mask.ndim == 2:
|
||||
initial_generated_latent_frames = int(history_valid_mask.any(dim=0).sum().item())
|
||||
else:
|
||||
initial_generated_latent_frames = int(history_valid_mask.reshape(-1).sum().item())
|
||||
else:
|
||||
initial_generated_latent_frames = 0
|
||||
total_generated_latent_frames = initial_generated_latent_frames
|
||||
|
||||
for chunk_idx in range(chunk_count):
|
||||
# Extract chunk from input latents
|
||||
@ -1000,8 +1146,6 @@ class HeliosPyramidSampler(io.ComfyNode):
|
||||
|
||||
# Prepare initial latent for this chunk
|
||||
if add_noise:
|
||||
# Diffusers parity: each chunk denoises a fixed latent window size.
|
||||
# Keep chunk temporal length constant and crop only after all chunks.
|
||||
noise_shape = (
|
||||
latent_samples.shape[0],
|
||||
latent_samples.shape[1],
|
||||
@ -1009,9 +1153,9 @@ class HeliosPyramidSampler(io.ComfyNode):
|
||||
latent_samples.shape[3],
|
||||
latent_samples.shape[4],
|
||||
)
|
||||
stage_latent = torch.randn(noise_shape, device=target_device, dtype=latent_samples.dtype, generator=noise_gen)
|
||||
stage_latent = torch.randn(noise_shape, device=target_device, dtype=torch.float32, generator=noise_gen)
|
||||
else:
|
||||
# Use actual input latents; pad final short chunk to fixed size like Diffusers windowing.
|
||||
# Use actual input latents; pad final short chunk to fixed size.
|
||||
stage_latent = latent_chunk.clone()
|
||||
if stage_latent.shape[2] < chunk_t:
|
||||
if stage_latent.shape[2] == 0:
|
||||
@ -1024,22 +1168,20 @@ class HeliosPyramidSampler(io.ComfyNode):
|
||||
latent_samples.shape[4],
|
||||
),
|
||||
device=latent_samples.device,
|
||||
dtype=latent_samples.dtype,
|
||||
dtype=torch.float32,
|
||||
)
|
||||
else:
|
||||
pad = stage_latent[:, :, -1:].repeat(1, 1, chunk_t - stage_latent.shape[2], 1, 1)
|
||||
stage_latent = torch.cat([stage_latent, pad], dim=2)
|
||||
stage_latent = stage_latent.to(dtype=torch.float32)
|
||||
|
||||
# Downsample to stage 0 resolution
|
||||
for _ in range(max(0, int(stage_count) - 1)):
|
||||
stage_latent = _downsample_latent_5d_bilinear_x2(stage_latent)
|
||||
|
||||
# Keep stage latents on model device for parity with Diffusers scheduler/noise path.
|
||||
# Keep stage latents on model device for scheduler/noise path consistency.
|
||||
stage_latent = stage_latent.to(target_device)
|
||||
|
||||
# Diffusers parity:
|
||||
# keep_first_frame=True and no image_latent_prefix on the first chunk
|
||||
# should use an all-zero prefix frame, not history[:, :, :1].
|
||||
chunk_prefix = image_latent_prefix
|
||||
if keep_first_frame and image_latent_prefix is None and chunk_idx == 0:
|
||||
chunk_prefix = torch.zeros(
|
||||
@ -1065,6 +1207,10 @@ class HeliosPyramidSampler(io.ComfyNode):
|
||||
latents_history_short = _extract_condition_value(positive_chunk, "latents_history_short")
|
||||
latents_history_mid = _extract_condition_value(positive_chunk, "latents_history_mid")
|
||||
latents_history_long = _extract_condition_value(positive_chunk, "latents_history_long")
|
||||
if debug_latent_stats:
|
||||
print(f"[HeliosDebug][Sampler][chunk={chunk_idx}] latents_history_short: {_tensor_stats_str(latents_history_short)}")
|
||||
print(f"[HeliosDebug][Sampler][chunk={chunk_idx}] latents_history_mid: {_tensor_stats_str(latents_history_mid)}")
|
||||
print(f"[HeliosDebug][Sampler][chunk={chunk_idx}] latents_history_long: {_tensor_stats_str(latents_history_long)}")
|
||||
|
||||
for stage_idx in range(stage_count):
|
||||
stage_latent = stage_latent.to(comfy.model_management.get_torch_device())
|
||||
@ -1099,8 +1245,7 @@ class HeliosPyramidSampler(io.ComfyNode):
|
||||
else:
|
||||
pass
|
||||
|
||||
# Keep parity with Diffusers pipeline order:
|
||||
# stage timesteps are computed before upsampling/renoise for stage > 0.
|
||||
# Stage timesteps are computed before upsampling/renoise for stage > 0.
|
||||
if stage_idx > 0:
|
||||
stage_latent = _upsample_latent_5d(stage_latent, scale=2)
|
||||
|
||||
@ -1188,8 +1333,7 @@ class HeliosPyramidSampler(io.ComfyNode):
|
||||
seed=noise_seed + chunk_idx * 100 + stage_idx,
|
||||
)
|
||||
# sample_custom returns latent_format.process_out(samples); convert back to model-space
|
||||
# so subsequent pyramid stages and history conditioning stay in the same latent space
|
||||
# as Diffusers' internal denoising latents.
|
||||
# so subsequent pyramid stages and history conditioning stay in the same latent space.
|
||||
stage_latent = model.model.process_latent_in(stage_latent)
|
||||
|
||||
if stage_latent.shape[-2] != h or stage_latent.shape[-1] != w:
|
||||
@ -1205,12 +1349,27 @@ class HeliosPyramidSampler(io.ComfyNode):
|
||||
rolling_history = torch.cat([rolling_history, stage_latent.to(rolling_history.device, rolling_history.dtype)], dim=2)
|
||||
keep_hist = max(1, sum(history_sizes_list))
|
||||
rolling_history = rolling_history[:, :, -keep_hist:]
|
||||
total_generated_latent_frames += stage_latent.shape[2]
|
||||
history_output = torch.cat([history_output, stage_latent.to(history_output.device, history_output.dtype)], dim=2)
|
||||
|
||||
stage_latent = torch.cat(generated_chunks, dim=2)[:, :, :t]
|
||||
include_history_in_output = _strict_bool(latent.get("helios_include_history_in_output", False), default=False)
|
||||
if include_history_in_output and history_output is not None:
|
||||
keep_t = max(0, int(total_generated_latent_frames))
|
||||
stage_latent = history_output[:, :, -keep_t:] if keep_t > 0 else history_output[:, :, :0]
|
||||
elif len(generated_chunks) > 0:
|
||||
stage_latent = torch.cat(generated_chunks, dim=2)
|
||||
else:
|
||||
stage_latent = torch.zeros((b, c, 0, h, w), device=target_device, dtype=torch.float32)
|
||||
|
||||
out = latent.copy()
|
||||
out.pop("downscale_ratio_spacial", None)
|
||||
out["samples"] = model.model.process_latent_out(stage_latent)
|
||||
out["helios_chunk_decode"] = True
|
||||
out["helios_chunk_latent_frames"] = int(chunk_t)
|
||||
out["helios_chunk_count"] = int(len(generated_chunks))
|
||||
out["helios_window_num_frames"] = int(window_num_frames)
|
||||
out["helios_num_frames"] = int(num_frames)
|
||||
out["helios_prefix_latent_frames"] = int(initial_generated_latent_frames if include_history_in_output else 0)
|
||||
|
||||
if "x0" in x0_output:
|
||||
x0_out = model.model.process_latent_out(x0_output["x0"].cpu())
|
||||
@ -1222,6 +1381,60 @@ class HeliosPyramidSampler(io.ComfyNode):
|
||||
return io.NodeOutput(out, out_denoised)
|
||||
|
||||
|
||||
class HeliosVAEDecode(io.ComfyNode):
    """Decode Helios video latents into images, optionally window-by-window.

    When the sampler marked the latent dict with ``helios_chunk_decode``,
    the temporal axis (dim 2) is decoded in windows of
    ``helios_chunk_latent_frames`` frames, after an optional history prefix
    of ``helios_prefix_latent_frames`` frames is decoded on its own.
    Otherwise the whole latent is decoded in a single pass.
    """

    @classmethod
    def define_schema(cls):
        """Declare the node schema: a latent and a VAE in, an image out."""
        return io.Schema(
            node_id="HeliosVAEDecode",
            category="latent",
            inputs=[
                io.Latent.Input("samples"),
                io.Vae.Input("vae"),
            ],
            outputs=[io.Image.Output(display_name="image")],
        )

    @classmethod
    def execute(cls, samples, vae) -> io.NodeOutput:
        """Run the VAE decode; returns the decoded frames as a 4-D image batch."""
        lat = samples["samples"]
        # Nested latents carry one tensor per sample; only the first is decoded.
        if lat.is_nested:
            lat = lat.unbind()[0]

        chunked = bool(samples.get("helios_chunk_decode", False))
        # `or 0` also guards against an explicit None stored under the key.
        chunk_frames = int(samples.get("helios_chunk_latent_frames", 0) or 0)
        prefix_frames = int(samples.get("helios_prefix_latent_frames", 0) or 0)

        # Chunked decode only makes sense for a non-empty 5-D (B, C, T, H, W)
        # latent with a positive window size — TODO confirm layout with sampler.
        use_windows = (
            chunked
            and lat.ndim == 5
            and chunk_frames > 0
            and lat.shape[2] > 0
        )
        if use_windows:
            pieces = []
            prefix_t = max(0, min(prefix_frames, lat.shape[2]))

            # History prefix frames are decoded as their own segment.
            if prefix_t > 0:
                pieces.append(vae.decode(lat[:, :, :prefix_t]))

            tail = lat[:, :, prefix_t:]
            for begin in range(0, tail.shape[2], chunk_frames):
                piece = tail[:, :, begin:begin + chunk_frames]
                if piece.shape[2] == 0:
                    continue
                pieces.append(vae.decode(piece))

            # Fall back to a single full decode if no window produced output.
            images = torch.cat(pieces, dim=1) if pieces else vae.decode(lat)
        else:
            images = vae.decode(lat)

        # Flatten a (B, T, H, W, C)-style 5-D decode into one frame batch.
        if len(images.shape) == 5:
            images = images.reshape(-1, images.shape[-3], images.shape[-2], images.shape[-1])
        return io.NodeOutput(images)
|
||||
|
||||
|
||||
class HeliosExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||
@ -1231,6 +1444,7 @@ class HeliosExtension(ComfyExtension):
|
||||
HeliosVideoToVideo,
|
||||
HeliosHistoryConditioning,
|
||||
HeliosPyramidSampler,
|
||||
HeliosVAEDecode,
|
||||
]
|
||||
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user