From d59d6fb7a0796a6adfa9efcd55f80b4477e4020a Mon Sep 17 00:00:00 2001
From: ozbayb <17261091+ozbayb@users.noreply.github.com>
Date: Sun, 12 Apr 2026 15:43:36 -0600
Subject: [PATCH] LTX2 context windows - Skip VRAM estimate clamp for packed
 latents

---
 comfy/context_windows.py | 15 +++++++++------
 1 file changed, 9 insertions(+), 6 deletions(-)

diff --git a/comfy/context_windows.py b/comfy/context_windows.py
index a9f456426..89963699c 100644
--- a/comfy/context_windows.py
+++ b/comfy/context_windows.py
@@ -656,19 +656,22 @@ class IndexListContextHandler(ContextHandlerABC):
             callback(self, x_in, sub_conds_out, sub_conds, window, window_idx, total_windows, timestep, conds_final, counts_final, biases_final)
 
 
-def _prepare_sampling_wrapper(executor, model, noise_shape: torch.Tensor, *args, **kwargs):
-    # limit noise_shape length to context_length for more accurate vram use estimation
+def _prepare_sampling_wrapper(executor, model, noise_shape: torch.Tensor, conds, *args, **kwargs):
+    # Scale noise_shape to a single context window so VRAM estimation budgets per-window.
     model_options = kwargs.get("model_options", None)
     if model_options is None:
         raise Exception("model_options not found in prepare_sampling_wrapper; this should never happen, something went wrong.")
     handler: IndexListContextHandler = model_options.get("context_handler", None)
     if handler is not None:
         noise_shape = list(noise_shape)
-        # Guard: only clamp when dim is within bounds and the value is meaningful
-        # (packed multimodal tensors have noise_shape=[B,1,flat] where flat is not frame count)
-        if handler.dim < len(noise_shape) and noise_shape[handler.dim] > handler.context_length:
+        is_packed = len(noise_shape) == 3 and noise_shape[1] == 1
+        if is_packed:
+            # TODO: latent_shapes cond isn't attached yet at this point, so we can't compute a
+            # per-window flat latent here. Skipping the clamp over-estimates but prevents immediate OOM.
+            pass
+        elif handler.dim < len(noise_shape) and noise_shape[handler.dim] > handler.context_length:
             noise_shape[handler.dim] = min(noise_shape[handler.dim], handler.context_length)
-    return executor(model, noise_shape, *args, **kwargs)
+    return executor(model, noise_shape, conds, *args, **kwargs)
 
 
 def create_prepare_sampling_wrapper(model: ModelPatcher):