LTX2 context windows - Skip VRAM estimate clamp for packed latents

This commit is contained in:
ozbayb 2026-04-12 15:43:36 -06:00
parent b348c7fa61
commit d59d6fb7a0

View File

@ -656,19 +656,22 @@ class IndexListContextHandler(ContextHandlerABC):
callback(self, x_in, sub_conds_out, sub_conds, window, window_idx, total_windows, timestep, conds_final, counts_final, biases_final) callback(self, x_in, sub_conds_out, sub_conds, window, window_idx, total_windows, timestep, conds_final, counts_final, biases_final)
def _prepare_sampling_wrapper(executor, model, noise_shape: torch.Tensor, *args, **kwargs): def _prepare_sampling_wrapper(executor, model, noise_shape: torch.Tensor, conds, *args, **kwargs):
# limit noise_shape length to context_length for more accurate vram use estimation # Scale noise_shape to a single context window so VRAM estimation budgets per-window.
model_options = kwargs.get("model_options", None) model_options = kwargs.get("model_options", None)
if model_options is None: if model_options is None:
raise Exception("model_options not found in prepare_sampling_wrapper; this should never happen, something went wrong.") raise Exception("model_options not found in prepare_sampling_wrapper; this should never happen, something went wrong.")
handler: IndexListContextHandler = model_options.get("context_handler", None) handler: IndexListContextHandler = model_options.get("context_handler", None)
if handler is not None: if handler is not None:
noise_shape = list(noise_shape) noise_shape = list(noise_shape)
# Guard: only clamp when dim is within bounds and the value is meaningful is_packed = len(noise_shape) == 3 and noise_shape[1] == 1
# (packed multimodal tensors have noise_shape=[B,1,flat] where flat is not frame count) if is_packed:
if handler.dim < len(noise_shape) and noise_shape[handler.dim] > handler.context_length: # TODO: latent_shapes cond isn't attached yet at this point, so we can't compute a
# per-window flat latent here. Skipping the clamp over-estimates but prevents immediate OOM.
pass
elif handler.dim < len(noise_shape) and noise_shape[handler.dim] > handler.context_length:
noise_shape[handler.dim] = min(noise_shape[handler.dim], handler.context_length) noise_shape[handler.dim] = min(noise_shape[handler.dim], handler.context_length)
return executor(model, noise_shape, *args, **kwargs) return executor(model, noise_shape, conds, *args, **kwargs)
def create_prepare_sampling_wrapper(model: ModelPatcher): def create_prepare_sampling_wrapper(model: ModelPatcher):