diff --git a/comfy/ldm/helios/model.py b/comfy/ldm/helios/model.py index 6fd37b875..c1ea5f595 100644 --- a/comfy/ldm/helios/model.py +++ b/comfy/ldm/helios/model.py @@ -652,8 +652,8 @@ class HeliosModel(torch.nn.Module): ) original_context_length = hidden_states.shape[1] - if (latents_history_short is not None and indices_latents_history_short is not None and hasattr(self, "patch_short")): - x_short = self.patch_short(latents_history_short).to(hidden_states.dtype) + if latents_history_short is not None and indices_latents_history_short is not None: + x_short = self.patch_short(latents_history_short) _, _, ts, hs, ws = x_short.shape x_short = x_short.flatten(2).transpose(1, 2) f_short = self.rope_encode( @@ -671,57 +671,45 @@ class HeliosModel(torch.nn.Module): hidden_states = torch.cat([x_short, hidden_states], dim=1) freqs = torch.cat([f_short, freqs], dim=1) - if (latents_history_mid is not None and indices_latents_history_mid is not None and hasattr(self, "patch_mid")): - x_mid = self.patch_mid(pad_for_3d_conv(latents_history_mid, (2, 4, 4))).to(hidden_states.dtype) - _, _, tm, hm, wm = x_mid.shape + if latents_history_mid is not None and indices_latents_history_mid is not None: + x_mid = self.patch_mid(pad_for_3d_conv(latents_history_mid, (2, 4, 4))) + _, _, tm, _, _ = x_mid.shape x_mid = x_mid.flatten(2).transpose(1, 2) mid_t = indices_latents_history_mid.shape[1] - if ("hs" in locals()) and ("ws" in locals()): - mid_h, mid_w = hs, ws - else: - mid_h, mid_w = hm * 2, wm * 2 f_mid = self.rope_encode( t=mid_t * self.patch_size[0], - h=mid_h * self.patch_size[1], - w=mid_w * self.patch_size[2], + h=hs * self.patch_size[1], + w=ws * self.patch_size[2], steps_t=mid_t, - steps_h=mid_h, - steps_w=mid_w, + steps_h=hs, + steps_w=ws, device=x_mid.device, dtype=x_mid.dtype, transformer_options=transformer_options, frame_indices=indices_latents_history_mid, ) - f_mid = self._rope_downsample_3d(f_mid, (mid_t, mid_h, mid_w), (2, 2, 2)) - if f_mid.shape[1] != x_mid.shape[1]: - f_mid = f_mid[:, :x_mid.shape[1]] + f_mid = self._rope_downsample_3d(f_mid, (mid_t, hs, ws), (2, 2, 2)) hidden_states = torch.cat([x_mid, hidden_states], dim=1) freqs = torch.cat([f_mid, freqs], dim=1) - if (latents_history_long is not None and indices_latents_history_long is not None and hasattr(self, "patch_long")): - x_long = self.patch_long(pad_for_3d_conv(latents_history_long, (4, 8, 8))).to(hidden_states.dtype) - _, _, tl, hl, wl = x_long.shape + if latents_history_long is not None and indices_latents_history_long is not None: + x_long = self.patch_long(pad_for_3d_conv(latents_history_long, (4, 8, 8))) + _, _, tl, _, _ = x_long.shape x_long = x_long.flatten(2).transpose(1, 2) long_t = indices_latents_history_long.shape[1] - if ("hs" in locals()) and ("ws" in locals()): - long_h, long_w = hs, ws - else: - long_h, long_w = hl * 4, wl * 4 f_long = self.rope_encode( t=long_t * self.patch_size[0], - h=long_h * self.patch_size[1], - w=long_w * self.patch_size[2], + h=hs * self.patch_size[1], + w=ws * self.patch_size[2], steps_t=long_t, - steps_h=long_h, - steps_w=long_w, + steps_h=hs, + steps_w=ws, device=x_long.device, dtype=x_long.dtype, transformer_options=transformer_options, frame_indices=indices_latents_history_long, ) - f_long = self._rope_downsample_3d(f_long, (long_t, long_h, long_w), (4, 4, 4)) - if f_long.shape[1] != x_long.shape[1]: - f_long = f_long[:, :x_long.shape[1]] + f_long = self._rope_downsample_3d(f_long, (long_t, hs, ws), (4, 4, 4)) hidden_states = torch.cat([x_long, hidden_states], dim=1) freqs = torch.cat([f_long, freqs], dim=1) diff --git a/comfy_extras/nodes_helios.py b/comfy_extras/nodes_helios.py index 13894082a..3d5f80e76 100644 --- a/comfy_extras/nodes_helios.py +++ b/comfy_extras/nodes_helios.py @@ -7,6 +7,7 @@ import comfy.model_patcher import comfy.sample import comfy.samplers import comfy.utils +import comfy.latent_formats import latent_preview import node_helpers @@ -41,6 +42,37 @@ def _parse_int_list(values, default): return out if len(out) > 0 else default +_HELIOS_LATENT_FORMAT = comfy.latent_formats.Helios() + + +def _apply_helios_latent_space_noise(latent, sigma, generator=None): + """Apply noise in Helios model latent space, then map back to VAE latent space.""" + latent_in = _HELIOS_LATENT_FORMAT.process_in(latent) + noise = torch.randn( + latent_in.shape, + device=latent_in.device, + dtype=latent_in.dtype, + generator=generator, + ) + noised_in = sigma * noise + (1.0 - sigma) * latent_in + return _HELIOS_LATENT_FORMAT.process_out(noised_in).to(device=latent.device, dtype=latent.dtype) + + +def _tensor_stats_str(x): + if x is None: + return "None" + if not torch.is_tensor(x): + return f"non-tensor type={type(x)}" + if x.numel() == 0: + return f"shape={tuple(x.shape)} empty" + xf = x.detach().to(torch.float32) + return ( + f"shape={tuple(x.shape)} " + f"mean={xf.mean().item():.6f} std={xf.std(unbiased=False).item():.6f} " + f"min={xf.min().item():.6f} max={xf.max().item():.6f}" + ) + + def _parse_float_list(values, default): if values is None: return default @@ -65,6 +97,15 @@ def _parse_float_list(values, default): return out if len(out) > 0 else default +def _strict_bool(value, default=False): + if isinstance(value, bool): + return value + if isinstance(value, int): + return value != 0 + # Reject non-bool numerics from stale workflows (e.g. 0.135). + return bool(default) + + def _extract_condition_value(conditioning, key): for c in conditioning: if len(c) < 2: @@ -94,8 +135,9 @@ def _process_latent_in_preserve_zero_frames(model, latent, valid_mask=None): return latent if nonzero.shape[0] != latent.shape[2]: - # Keep behavior safe when mask length does not match temporal length. - nonzero = torch.zeros((latent.shape[2],), device=latent.device, dtype=torch.bool) + raise ValueError( + f"Helios history mask length mismatch: mask_t={nonzero.shape[0]} latent_t={latent.shape[2]}" + ) converted = model.model.process_latent_in(latent) out = latent.clone() @@ -133,7 +175,7 @@ def _prepare_stage0_latent(batch, channels, frames, height, width, stage_count, def _downsample_latent_for_stage0(latent, stage_count): - """Downsample latent to stage 0 resolution (like Diffusers does)""" + """Downsample latent to stage 0 resolution.""" stage_latent = latent for _ in range(max(0, int(stage_count) - 1)): stage_latent = _downsample_latent_5d_bilinear_x2(stage_latent) @@ -154,7 +196,7 @@ def _sample_block_noise_like(latent, gamma, patch_size=(1, 2, 2), generator=None block_number = b * c * t * h_blocks * w_blocks if generator is not None: - # Exact Diffusers sampling path (MultivariateNormal.sample), while consuming + # Exact sampling path (MultivariateNormal.sample), while consuming # from an explicit generator by temporarily swapping default RNG state. with torch.random.fork_rng(devices=[latent.device] if latent.device.type == "cuda" else []): if latent.device.type == "cuda": @@ -231,7 +273,7 @@ def _helios_stage_tables(stage_count, stage_range, gamma, num_train_timesteps=10 tmax = min(float(sigmas[int(start_ratio * num_train_timesteps)].item() * num_train_timesteps), 999.0) tmin = float(sigmas[min(int(end_ratio * num_train_timesteps), num_train_timesteps - 1)].item() * num_train_timesteps) timesteps_per_stage[i] = torch.linspace(tmax, tmin, num_train_timesteps + 1)[:-1] - # Fixed: Use same sigma range [0.999, 0] for all stages like Diffusers + # Fixed: use the same sigma range [0.999, 0] for all stages. sigmas_per_stage[i] = torch.linspace(0.999, 0.0, num_train_timesteps + 1)[:-1] @@ -302,21 +344,18 @@ def _build_cfg_zero_star_pre_cfg(stage_idx, zero_steps, use_zero_init): state["i"] += 1 return conds_out - denoised_text = conds_out[0] # apply_model 返回的 denoised + denoised_text = conds_out[0] denoised_uncond = conds_out[1] cfg = float(args.get("cond_scale", 1.0)) - x = args["input"] # 当前的 noisy latent - sigma = args["sigma"] # 当前的 sigma + x = args["input"] + sigma = args["sigma"] - # 关键修复:将 denoised 转换为 flow - # denoised = x - flow * sigma => flow = (x - denoised) / sigma sigma_reshaped = sigma.reshape(sigma.shape[0], *([1] * (denoised_text.ndim - 1))) sigma_safe = torch.clamp(sigma_reshaped, min=1e-8) flow_text = (x - denoised_text) / sigma_safe flow_uncond = (x - denoised_uncond) / sigma_safe - # 在 flow 空间做 CFG Zero Star positive_flat = flow_text.reshape(flow_text.shape[0], -1) negative_flat = flow_uncond.reshape(flow_uncond.shape[0], -1) alpha = _optimized_scale(positive_flat, negative_flat) @@ -327,11 +366,9 @@ def _build_cfg_zero_star_pre_cfg(stage_idx, zero_steps, use_zero_init): else: flow_final = flow_uncond * alpha + cfg * (flow_text - flow_uncond * alpha) - # 将 flow 转回 denoised denoised_final = x - flow_final * sigma_safe state["i"] += 1 - # Return identical cond/uncond so downstream cfg_function keeps `final` unchanged. return [denoised_final, denoised_final] return pre_cfg_fn @@ -519,6 +556,8 @@ class HeliosImageToVideo(io.ComfyNode): io.Float.Input("image_noise_sigma_min", default=0.111, min=0.0, max=1.0, step=0.0001, round=False, advanced=True), io.Float.Input("image_noise_sigma_max", default=0.135, min=0.0, max=1.0, step=0.0001, round=False, advanced=True), io.Int.Input("noise_seed", default=0, min=0, max=0xFFFFFFFFFFFFFFFF, advanced=True), + io.Boolean.Input("include_history_in_output", default=False, advanced=True), + io.Boolean.Input("debug_latent_stats", default=False, advanced=True), ], outputs=[ io.Conditioning.Output(display_name="positive"), @@ -545,7 +584,11 @@ class HeliosImageToVideo(io.ComfyNode): image_noise_sigma_min=0.111, image_noise_sigma_max=0.135, noise_seed=0, + include_history_in_output=False, + debug_latent_stats=False, ) -> io.NodeOutput: + video_noise_sigma_min = 0.111 + video_noise_sigma_max = 0.135 spacial_scale = vae.spacial_compression_encode() latent_channels = vae.latent_channels latent_t = ((length - 1) // 4) + 1 @@ -560,10 +603,11 @@ class HeliosImageToVideo(io.ComfyNode): history_valid_mask = torch.zeros((batch_size, hist_len), device=latent.device, dtype=torch.bool) image_latent_prefix = None i2v_noise_gen = None + noise_gen_state = None if start_image is not None: image = comfy.utils.common_upscale(start_image[:1].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) - img_latent = vae.encode(image[:, :, :, :3]) + img_latent = vae.encode(image[:, :, :, :3]).to(device=latent.device, dtype=torch.float32) img_latent = comfy.utils.repeat_to_batch_size(img_latent, batch_size) image_latent_prefix = img_latent[:, :, :1] @@ -571,33 +615,38 @@ class HeliosImageToVideo(io.ComfyNode): i2v_noise_gen = torch.Generator(device=img_latent.device) i2v_noise_gen.manual_seed(int(noise_seed)) sigma = ( - torch.rand((img_latent.shape[0], 1, 1, 1, 1), device=img_latent.device, generator=i2v_noise_gen, dtype=img_latent.dtype) + torch.rand((1,), device=img_latent.device, generator=i2v_noise_gen, dtype=img_latent.dtype).view(1, 1, 1, 1, 1) * (float(image_noise_sigma_max) - float(image_noise_sigma_min)) + float(image_noise_sigma_min) ) - image_latent_prefix = sigma * torch.randn_like(image_latent_prefix, generator=i2v_noise_gen) + (1.0 - sigma) * image_latent_prefix + image_latent_prefix = _apply_helios_latent_space_noise(image_latent_prefix, sigma, generator=i2v_noise_gen) min_frames = max(1, (int(num_latent_frames_per_chunk) - 1) * 4 + 1) fake_video = image.repeat(min_frames, 1, 1, 1) - fake_latents_full = vae.encode(fake_video) + fake_latents_full = vae.encode(fake_video).to(device=latent.device, dtype=torch.float32) fake_latent = comfy.utils.repeat_to_batch_size(fake_latents_full[:, :, -1:], batch_size) - # Diffusers parity for I2V: # when adding noise to image latents, fake_image_latents used for history are also noised. if add_noise_to_image_latents: if i2v_noise_gen is None: i2v_noise_gen = torch.Generator(device=fake_latent.device) i2v_noise_gen.manual_seed(int(noise_seed)) # Keep backward compatibility with existing I2V node inputs: - # this node exposes only image sigma controls, while fake history - # latents follow the video-noise path in Diffusers. + # this node exposes only image sigma controls; fake history latents + # follow the video-noise defaults. fake_sigma = ( - torch.rand((fake_latent.shape[0], 1, 1, 1, 1), device=fake_latent.device, generator=i2v_noise_gen, dtype=fake_latent.dtype) - * (float(image_noise_sigma_max) - float(image_noise_sigma_min)) - + float(image_noise_sigma_min) + torch.rand((1,), device=fake_latent.device, generator=i2v_noise_gen, dtype=fake_latent.dtype).view(1, 1, 1, 1, 1) + * (float(video_noise_sigma_max) - float(video_noise_sigma_min)) + + float(video_noise_sigma_min) ) - fake_latent = fake_sigma * torch.randn_like(fake_latent, generator=i2v_noise_gen) + (1.0 - fake_sigma) * fake_latent + fake_latent = _apply_helios_latent_space_noise(fake_latent, fake_sigma, generator=i2v_noise_gen) history_latent[:, :, -1:] = fake_latent history_valid_mask[:, -1] = True + if i2v_noise_gen is not None: + noise_gen_state = i2v_noise_gen.get_state().clone() + if debug_latent_stats: + print(f"[HeliosDebug][I2V] image_latent_prefix: {_tensor_stats_str(image_latent_prefix)}") + print(f"[HeliosDebug][I2V] fake_latent: {_tensor_stats_str(fake_latent)}") + print(f"[HeliosDebug][I2V] history_latent: {_tensor_stats_str(history_latent)}") positive, negative = _set_helios_history_values(positive, negative, history_latent, sizes, keep_first_frame, prefix_latent=image_latent_prefix) return io.NodeOutput( @@ -608,6 +657,10 @@ class HeliosImageToVideo(io.ComfyNode): "helios_history_latent": history_latent, "helios_image_latent_prefix": image_latent_prefix, "helios_history_valid_mask": history_valid_mask, + "helios_num_frames": int(length), + "helios_noise_gen_state": noise_gen_state, + "helios_include_history_in_output": _strict_bool(include_history_in_output, default=False), + "helios_debug_latent_stats": bool(debug_latent_stats), }, ) @@ -686,6 +739,7 @@ class HeliosTextToVideo(io.ComfyNode): "helios_history_latent": history_latent, "helios_image_latent_prefix": None, "helios_history_valid_mask": history_valid_mask, + "helios_num_frames": int(length), }, ) @@ -707,10 +761,13 @@ class HeliosVideoToVideo(io.ComfyNode): io.Image.Input("video", optional=True), io.String.Input("history_sizes", default="16,2,1", advanced=True), io.Boolean.Input("keep_first_frame", default=True, advanced=True), + io.Int.Input("num_latent_frames_per_chunk", default=9, min=1, max=256, advanced=True), io.Boolean.Input("add_noise_to_video_latents", default=True, advanced=True), io.Float.Input("video_noise_sigma_min", default=0.111, min=0.0, max=1.0, step=0.0001, round=False, advanced=True), io.Float.Input("video_noise_sigma_max", default=0.135, min=0.0, max=1.0, step=0.0001, round=False, advanced=True), io.Int.Input("noise_seed", default=0, min=0, max=0xFFFFFFFFFFFFFFFF, advanced=True), + io.Boolean.Input("include_history_in_output", default=True, advanced=True), + io.Boolean.Input("debug_latent_stats", default=False, advanced=True), ], outputs=[ io.Conditioning.Output(display_name="positive"), @@ -732,10 +789,13 @@ class HeliosVideoToVideo(io.ComfyNode): video=None, history_sizes="16,2,1", keep_first_frame=True, + num_latent_frames_per_chunk=9, add_noise_to_video_latents=True, video_noise_sigma_min=0.111, video_noise_sigma_max=0.135, noise_seed=0, + include_history_in_output=True, + debug_latent_stats=False, ) -> io.NodeOutput: spacial_scale = vae.spacial_compression_encode() latent_channels = vae.latent_channels @@ -750,29 +810,81 @@ class HeliosVideoToVideo(io.ComfyNode): history_latent = torch.zeros([batch_size, latent_channels, hist_len, latent.shape[-2], latent.shape[-1]], device=latent.device, dtype=latent.dtype) history_valid_mask = torch.zeros((batch_size, hist_len), device=latent.device, dtype=torch.bool) image_latent_prefix = None + noise_gen_state = None + history_latent_output = history_latent if video is not None: video = comfy.utils.common_upscale(video[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) - vid_latent = vae.encode(video[:, :, :, :3]) + num_frames = int(video.shape[0]) + min_frames = max(1, (int(num_latent_frames_per_chunk) - 1) * 4 + 1) + num_chunks = num_frames // min_frames + if num_chunks == 0: + raise ValueError( + f"Video must have at least {min_frames} frames (got {num_frames} frames). " + f"Required: (num_latent_frames_per_chunk - 1) * 4 + 1 = ({int(num_latent_frames_per_chunk)} - 1) * 4 + 1 = {min_frames}" + ) + + first_frame = video[:1] + first_frame_latent = vae.encode(first_frame[:, :, :, :3]).to(device=latent.device, dtype=torch.float32) + + total_valid_frames = num_chunks * min_frames + start_frame = num_frames - total_valid_frames + latents_chunks = [] + for i in range(num_chunks): + chunk_start = start_frame + i * min_frames + chunk_end = chunk_start + min_frames + video_chunk = video[chunk_start:chunk_end] + chunk_latents = vae.encode(video_chunk[:, :, :, :3]).to(device=latent.device, dtype=torch.float32) + latents_chunks.append(chunk_latents) + vid_latent = torch.cat(latents_chunks, dim=2) + vid_latent_clean = vid_latent.clone() + if add_noise_to_video_latents: g = torch.Generator(device=vid_latent.device) g.manual_seed(int(noise_seed)) - frame_sigmas = ( - torch.rand((1, 1, vid_latent.shape[2], 1, 1), device=vid_latent.device, generator=g, dtype=vid_latent.dtype) + + image_sigma = ( + torch.rand((1,), device=first_frame_latent.device, generator=g, dtype=first_frame_latent.dtype).view(1, 1, 1, 1, 1) * (float(video_noise_sigma_max) - float(video_noise_sigma_min)) + float(video_noise_sigma_min) ) - vid_latent = frame_sigmas * torch.randn_like(vid_latent, generator=g) + (1.0 - frame_sigmas) * vid_latent - vid_latent = vid_latent[:, :, :hist_len] + first_frame_latent = _apply_helios_latent_space_noise(first_frame_latent, image_sigma, generator=g) + + noisy_chunks = [] + num_latent_chunks = max(1, vid_latent.shape[2] // int(num_latent_frames_per_chunk)) + for i in range(num_latent_chunks): + chunk_start = i * int(num_latent_frames_per_chunk) + chunk_end = chunk_start + int(num_latent_frames_per_chunk) + latent_chunk = vid_latent[:, :, chunk_start:chunk_end, :, :] + if latent_chunk.shape[2] == 0: + continue + chunk_frames = latent_chunk.shape[2] + frame_sigmas = ( + torch.rand((chunk_frames,), device=vid_latent.device, generator=g, dtype=vid_latent.dtype) + * (float(video_noise_sigma_max) - float(video_noise_sigma_min)) + + float(video_noise_sigma_min) + ).view(1, 1, chunk_frames, 1, 1) + noisy_chunk = _apply_helios_latent_space_noise(latent_chunk, frame_sigmas, generator=g) + noisy_chunks.append(noisy_chunk) + if len(noisy_chunks) > 0: + vid_latent = torch.cat(noisy_chunks, dim=2) + noise_gen_state = g.get_state().clone() + if debug_latent_stats: + print(f"[HeliosDebug][V2V] first_frame_latent: {_tensor_stats_str(first_frame_latent)}") + print(f"[HeliosDebug][V2V] video_latent: {_tensor_stats_str(vid_latent)}") + vid_latent = comfy.utils.repeat_to_batch_size(vid_latent, batch_size) - if vid_latent.shape[2] < hist_len: - keep_frames = hist_len - vid_latent.shape[2] + image_latent_prefix = comfy.utils.repeat_to_batch_size(first_frame_latent, batch_size) + video_frames = vid_latent.shape[2] + if video_frames < hist_len: + keep_frames = hist_len - video_frames history_latent = torch.cat([history_latent[:, :, :keep_frames], vid_latent], dim=2) + history_latent_output = torch.cat([history_latent_output[:, :, :keep_frames], comfy.utils.repeat_to_batch_size(vid_latent_clean, batch_size)], dim=2) history_valid_mask[:, keep_frames:] = True else: - history_latent = vid_latent[:, :, -hist_len:] - history_valid_mask[:] = True - image_latent_prefix = history_latent[:, :, :1] + history_latent = vid_latent + history_latent_output = comfy.utils.repeat_to_batch_size(vid_latent_clean, batch_size) + history_valid_mask = torch.ones((batch_size, video_frames), device=latent.device, dtype=torch.bool) positive, negative = _set_helios_history_values(positive, negative, history_latent, sizes, keep_first_frame, prefix_latent=image_latent_prefix) return io.NodeOutput( @@ -781,8 +893,14 @@ class HeliosVideoToVideo(io.ComfyNode): { "samples": latent, "helios_history_latent": history_latent, + "helios_history_latent_output": history_latent_output, "helios_image_latent_prefix": image_latent_prefix, "helios_history_valid_mask": history_valid_mask, + "helios_num_frames": int(length), + "helios_noise_gen_state": noise_gen_state, + # Keep initial history segment and generated chunks together in sampler output. + "helios_include_history_in_output": _strict_bool(include_history_in_output, default=True), + "helios_debug_latent_stats": bool(debug_latent_stats), }, ) @@ -894,7 +1012,6 @@ class HeliosPyramidSampler(io.ComfyNode): stage_steps = [max(1, int(s)) for s in stage_steps] stage_count = len(stage_steps) history_sizes_list = sorted([max(0, int(v)) for v in _parse_int_list(history_sizes, [16, 2, 1])], reverse=True) - # Diffusers parity: if not keeping first frame, fold prefix slot into short history size. if not keep_first_frame and len(history_sizes_list) > 0: history_sizes_list[-1] += 1 @@ -912,21 +1029,32 @@ class HeliosPyramidSampler(io.ComfyNode): b, c, t, h, w = latent_samples.shape chunk_t = max(1, int(num_latent_frames_per_chunk)) - chunk_count = max(1, (t + chunk_t - 1) // chunk_t) + num_frames = int(latent.get("helios_num_frames", max(1, (int(t) - 1) * 4 + 1))) + window_num_frames = (chunk_t - 1) * 4 + 1 + chunk_count = max(1, (num_frames + window_num_frames - 1) // window_num_frames) euler_sampler = comfy.samplers.KSAMPLER(_helios_euler_sample) target_device = comfy.model_management.get_torch_device() noise_gen = torch.Generator(device=target_device) noise_gen.manual_seed(int(noise_seed)) + noise_gen_state = latent.get("helios_noise_gen_state", None) + if noise_gen_state is not None: + try: + noise_gen.set_state(noise_gen_state) + except Exception: + pass + debug_latent_stats = bool(latent.get("helios_debug_latent_stats", False)) image_latent_prefix = latent.get("helios_image_latent_prefix", None) history_valid_mask = latent.get("helios_history_valid_mask", None) if history_valid_mask is None: raise ValueError("Helios sampler requires `helios_history_valid_mask` in latent input.") + history_full = None history_from_latent_applied = False if image_latent_prefix is not None: image_latent_prefix = model.model.process_latent_in(image_latent_prefix) if "helios_history_latent" in latent: history_in = _process_latent_in_preserve_zero_frames(model, latent["helios_history_latent"], valid_mask=history_valid_mask) + history_full = history_in positive, negative = _set_helios_history_values( positive, negative, @@ -959,8 +1087,6 @@ class HeliosPyramidSampler(io.ComfyNode): x0_output = {} generated_chunks = [] if latents_history_short is not None and latents_history_mid is not None and latents_history_long is not None: - # Diffusers parity: `history_latents` storage does NOT include the keep_first_frame prefix slot. - # `latents_history_short` in conditioning may include [prefix + short_base], so strip prefix here. short_base_size = history_sizes_list[-1] if len(history_sizes_list) > 0 else latents_history_short.shape[2] if keep_first_frame and latents_history_short.shape[2] > short_base_size: short_for_history = latents_history_short[:, :, -short_base_size:] @@ -974,7 +1100,7 @@ class HeliosPyramidSampler(io.ComfyNode): hist_len = max(1, sum(history_sizes_list)) rolling_history = torch.zeros((b, c, hist_len, h, w), device=latent_samples.device, dtype=latent_samples.dtype) - # Align with Diffusers behavior: when initial video latents are provided, seed history buffer + # When initial video latents are provided, seed history buffer # with those latents before the first denoising chunk. if not add_noise: hist_len = max(1, sum(history_sizes_list)) @@ -988,9 +1114,29 @@ class HeliosPyramidSampler(io.ComfyNode): rolling_history = video_latents[:, :, -hist_len:] # Keep history/prefix on the same device/dtype as denoising latents. - rolling_history = rolling_history.to(device=target_device, dtype=latent_samples.dtype) + rolling_history = rolling_history.to(device=target_device, dtype=torch.float32) if image_latent_prefix is not None: - image_latent_prefix = image_latent_prefix.to(device=target_device, dtype=latent_samples.dtype) + image_latent_prefix = image_latent_prefix.to(device=target_device, dtype=torch.float32) + + history_output = history_full if history_full is not None else rolling_history + if "helios_history_latent_output" in latent: + history_output = _process_latent_in_preserve_zero_frames( + model, + latent["helios_history_latent_output"], + valid_mask=history_valid_mask, + ) + history_output = history_output.to(device=target_device, dtype=torch.float32) + if history_valid_mask is not None: + if not torch.is_tensor(history_valid_mask): + history_valid_mask = torch.tensor(history_valid_mask, device=target_device) + history_valid_mask = history_valid_mask.to(device=target_device) + if history_valid_mask.ndim == 2: + initial_generated_latent_frames = int(history_valid_mask.any(dim=0).sum().item()) + else: + initial_generated_latent_frames = int(history_valid_mask.reshape(-1).sum().item()) + else: + initial_generated_latent_frames = 0 + total_generated_latent_frames = initial_generated_latent_frames for chunk_idx in range(chunk_count): # Extract chunk from input latents @@ -1000,8 +1146,6 @@ class HeliosPyramidSampler(io.ComfyNode): # Prepare initial latent for this chunk if add_noise: - # Diffusers parity: each chunk denoises a fixed latent window size. - # Keep chunk temporal length constant and crop only after all chunks. noise_shape = ( latent_samples.shape[0], latent_samples.shape[1], @@ -1009,9 +1153,9 @@ class HeliosPyramidSampler(io.ComfyNode): latent_samples.shape[3], latent_samples.shape[4], ) - stage_latent = torch.randn(noise_shape, device=target_device, dtype=latent_samples.dtype, generator=noise_gen) + stage_latent = torch.randn(noise_shape, device=target_device, dtype=torch.float32, generator=noise_gen) else: - # Use actual input latents; pad final short chunk to fixed size like Diffusers windowing. + # Use actual input latents; pad final short chunk to fixed size. stage_latent = latent_chunk.clone() if stage_latent.shape[2] < chunk_t: if stage_latent.shape[2] == 0: @@ -1024,22 +1168,20 @@ class HeliosPyramidSampler(io.ComfyNode): latent_samples.shape[4], ), device=latent_samples.device, - dtype=latent_samples.dtype, + dtype=torch.float32, ) else: pad = stage_latent[:, :, -1:].repeat(1, 1, chunk_t - stage_latent.shape[2], 1, 1) stage_latent = torch.cat([stage_latent, pad], dim=2) + stage_latent = stage_latent.to(dtype=torch.float32) # Downsample to stage 0 resolution for _ in range(max(0, int(stage_count) - 1)): stage_latent = _downsample_latent_5d_bilinear_x2(stage_latent) - # Keep stage latents on model device for parity with Diffusers scheduler/noise path. + # Keep stage latents on model device for scheduler/noise path consistency. stage_latent = stage_latent.to(target_device) - # Diffusers parity: - # keep_first_frame=True and no image_latent_prefix on the first chunk - # should use an all-zero prefix frame, not history[:, :, :1]. chunk_prefix = image_latent_prefix if keep_first_frame and image_latent_prefix is None and chunk_idx == 0: chunk_prefix = torch.zeros( @@ -1065,6 +1207,10 @@ class HeliosPyramidSampler(io.ComfyNode): latents_history_short = _extract_condition_value(positive_chunk, "latents_history_short") latents_history_mid = _extract_condition_value(positive_chunk, "latents_history_mid") latents_history_long = _extract_condition_value(positive_chunk, "latents_history_long") + if debug_latent_stats: + print(f"[HeliosDebug][Sampler][chunk={chunk_idx}] latents_history_short: {_tensor_stats_str(latents_history_short)}") + print(f"[HeliosDebug][Sampler][chunk={chunk_idx}] latents_history_mid: {_tensor_stats_str(latents_history_mid)}") + print(f"[HeliosDebug][Sampler][chunk={chunk_idx}] latents_history_long: {_tensor_stats_str(latents_history_long)}") for stage_idx in range(stage_count): stage_latent = stage_latent.to(comfy.model_management.get_torch_device()) @@ -1099,8 +1245,7 @@ class HeliosPyramidSampler(io.ComfyNode): else: pass - # Keep parity with Diffusers pipeline order: - # stage timesteps are computed before upsampling/renoise for stage > 0. + # Stage timesteps are computed before upsampling/renoise for stage > 0. if stage_idx > 0: stage_latent = _upsample_latent_5d(stage_latent, scale=2) @@ -1188,8 +1333,7 @@ class HeliosPyramidSampler(io.ComfyNode): seed=noise_seed + chunk_idx * 100 + stage_idx, ) # sample_custom returns latent_format.process_out(samples); convert back to model-space - # so subsequent pyramid stages and history conditioning stay in the same latent space - # as Diffusers' internal denoising latents. + # so subsequent pyramid stages and history conditioning stay in the same latent space. stage_latent = model.model.process_latent_in(stage_latent) if stage_latent.shape[-2] != h or stage_latent.shape[-1] != w: @@ -1205,12 +1349,27 @@ class HeliosPyramidSampler(io.ComfyNode): rolling_history = torch.cat([rolling_history, stage_latent.to(rolling_history.device, rolling_history.dtype)], dim=2) keep_hist = max(1, sum(history_sizes_list)) rolling_history = rolling_history[:, :, -keep_hist:] + total_generated_latent_frames += stage_latent.shape[2] + history_output = torch.cat([history_output, stage_latent.to(history_output.device, history_output.dtype)], dim=2) - stage_latent = torch.cat(generated_chunks, dim=2)[:, :, :t] + include_history_in_output = _strict_bool(latent.get("helios_include_history_in_output", False), default=False) + if include_history_in_output and history_output is not None: + keep_t = max(0, int(total_generated_latent_frames)) + stage_latent = history_output[:, :, -keep_t:] if keep_t > 0 else history_output[:, :, :0] + elif len(generated_chunks) > 0: + stage_latent = torch.cat(generated_chunks, dim=2) + else: + stage_latent = torch.zeros((b, c, 0, h, w), device=target_device, dtype=torch.float32) out = latent.copy() out.pop("downscale_ratio_spacial", None) out["samples"] = model.model.process_latent_out(stage_latent) + out["helios_chunk_decode"] = True + out["helios_chunk_latent_frames"] = int(chunk_t) + out["helios_chunk_count"] = int(len(generated_chunks)) + out["helios_window_num_frames"] = int(window_num_frames) + out["helios_num_frames"] = int(num_frames) + out["helios_prefix_latent_frames"] = int(initial_generated_latent_frames if include_history_in_output else 0) if "x0" in x0_output: x0_out = model.model.process_latent_out(x0_output["x0"].cpu()) @@ -1222,6 +1381,60 @@ class HeliosPyramidSampler(io.ComfyNode): return io.NodeOutput(out, out_denoised) +class HeliosVAEDecode(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="HeliosVAEDecode", + category="latent", + inputs=[ + io.Latent.Input("samples"), + io.Vae.Input("vae"), + ], + outputs=[io.Image.Output(display_name="image")], + ) + + @classmethod + def execute(cls, samples, vae) -> io.NodeOutput: + latent = samples["samples"] + if latent.is_nested: + latent = latent.unbind()[0] + + helios_chunk_decode = bool(samples.get("helios_chunk_decode", False)) + helios_chunk_latent_frames = int(samples.get("helios_chunk_latent_frames", 0) or 0) + helios_prefix_latent_frames = int(samples.get("helios_prefix_latent_frames", 0) or 0) + + if ( + helios_chunk_decode + and latent.ndim == 5 + and helios_chunk_latent_frames > 0 + and latent.shape[2] > 0 + ): + decoded_chunks = [] + prefix_t = max(0, min(helios_prefix_latent_frames, latent.shape[2])) + + if prefix_t > 0: + decoded_chunks.append(vae.decode(latent[:, :, :prefix_t])) + + body = latent[:, :, prefix_t:] + for start in range(0, body.shape[2], helios_chunk_latent_frames): + chunk = body[:, :, start:start + helios_chunk_latent_frames] + if chunk.shape[2] == 0: + continue + decoded_chunks.append(vae.decode(chunk)) + + if len(decoded_chunks) > 0: + images = torch.cat(decoded_chunks, dim=1) + else: + images = vae.decode(latent) + else: + images = vae.decode(latent) + + if len(images.shape) == 5: + images = images.reshape(-1, images.shape[-3], images.shape[-2], images.shape[-1]) + return io.NodeOutput(images) + + class HeliosExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[io.ComfyNode]]: @@ -1231,6 +1444,7 @@ class HeliosExtension(ComfyExtension): HeliosVideoToVideo, HeliosHistoryConditioning, HeliosPyramidSampler, + HeliosVAEDecode, ]