diff --git a/comfy/context_windows.py b/comfy/context_windows.py index 4ace5ec13..b528f6327 100644 --- a/comfy/context_windows.py +++ b/comfy/context_windows.py @@ -619,7 +619,28 @@ def _sampler_sample_wrapper(executor, guider, sigmas, extra_args, callback, nois raise Exception("context_handler not found in sampler_sample_wrapper; this should never happen, something went wrong.") if not handler.freenoise: return executor(guider, sigmas, extra_args, callback, noise, *args, **kwargs) - noise = apply_freenoise(noise, handler.dim, handler.context_length, handler.context_overlap, extra_args["seed"]) + + # For packed multimodal tensors (e.g. LTXAV), noise is [B, 1, flat] and FreeNoise + # must only shuffle the video portion. Unpack, apply to video, repack. + latent_shapes = None + try: + latent_shapes = guider.conds['positive'][0]['model_conds']['latent_shapes'].cond + except (KeyError, IndexError, AttributeError): + pass + + if latent_shapes is not None and len(latent_shapes) > 1: + modalities = comfy.utils.unpack_latents(noise, latent_shapes) + video_total = latent_shapes[0][handler.dim] + modalities[0] = apply_freenoise(modalities[0], handler.dim, handler.context_length, handler.context_overlap, extra_args["seed"]) + for i in range(1, len(modalities)): + mod_total = latent_shapes[i][handler.dim] + ratio = mod_total / video_total if video_total > 0 else 1 + mod_ctx_len = max(round(handler.context_length * ratio), 1) + mod_ctx_overlap = max(round(handler.context_overlap * ratio), 0) + modalities[i] = apply_freenoise(modalities[i], handler.dim, mod_ctx_len, mod_ctx_overlap, extra_args["seed"]) + noise, _ = comfy.utils.pack_latents(modalities) + else: + noise = apply_freenoise(noise, handler.dim, handler.context_length, handler.context_overlap, extra_args["seed"]) return executor(guider, sigmas, extra_args, callback, noise, *args, **kwargs) diff --git a/comfy/model_base.py b/comfy/model_base.py index ae2ce2eb0..893beb85a 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -1197,17 +1197,18 @@ class LTXAV(BaseModel): return result video_total = latent_shapes[0][dim] - audio_total = latent_shapes[1][dim] + video_window_len = len(primary_indices) - # Proportional mapping — video and audio cover same real-time duration - v_start, v_end = min(primary_indices), max(primary_indices) + 1 - a_start = round(v_start * audio_total / video_total) - a_end = round(v_end * audio_total / video_total) - audio_indices = list(range(a_start, min(a_end, audio_total))) - if not audio_indices: - audio_indices = [min(a_start, audio_total - 1)] + for i in range(1, len(latent_shapes)): + mod_total = latent_shapes[i][dim] + # Length proportional to video window frame count (not index span) + mod_window_len = max(round(video_window_len * mod_total / video_total), 1) + # Anchor to end of video range + v_end = max(primary_indices) + 1 + mod_end = min(round(v_end * mod_total / video_total), mod_total) + mod_start = max(mod_end - mod_window_len, 0) + result.append(list(range(mod_start, min(mod_start + mod_window_len, mod_total)))) - result.append(audio_indices) return result def get_guide_frame_count(self, x, conds):