mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-20 15:32:32 +08:00
LTX2 context windows - Skip VRAM estimate clamp for packed latents
This commit is contained in:
parent
b348c7fa61
commit
d59d6fb7a0
@@ -656,19 +656,22 @@ class IndexListContextHandler(ContextHandlerABC):
|
||||
callback(self, x_in, sub_conds_out, sub_conds, window, window_idx, total_windows, timestep, conds_final, counts_final, biases_final)
|
||||
|
||||
|
||||
def _prepare_sampling_wrapper(executor, model, noise_shape: torch.Tensor, *args, **kwargs):
|
||||
# limit noise_shape length to context_length for more accurate vram use estimation
|
||||
def _prepare_sampling_wrapper(executor, model, noise_shape: torch.Tensor, conds, *args, **kwargs):
|
||||
# Scale noise_shape to a single context window so VRAM estimation budgets per-window.
|
||||
model_options = kwargs.get("model_options", None)
|
||||
if model_options is None:
|
||||
raise Exception("model_options not found in prepare_sampling_wrapper; this should never happen, something went wrong.")
|
||||
handler: IndexListContextHandler = model_options.get("context_handler", None)
|
||||
if handler is not None:
|
||||
noise_shape = list(noise_shape)
|
||||
# Guard: only clamp when dim is within bounds and the value is meaningful
|
||||
# (packed multimodal tensors have noise_shape=[B,1,flat] where flat is not frame count)
|
||||
if handler.dim < len(noise_shape) and noise_shape[handler.dim] > handler.context_length:
|
||||
is_packed = len(noise_shape) == 3 and noise_shape[1] == 1
|
||||
if is_packed:
|
||||
# TODO: latent_shapes cond isn't attached yet at this point, so we can't compute a
|
||||
# per-window flat latent here. Skipping the clamp over-estimates but prevents immediate OOM.
|
||||
pass
|
||||
elif handler.dim < len(noise_shape) and noise_shape[handler.dim] > handler.context_length:
|
||||
noise_shape[handler.dim] = min(noise_shape[handler.dim], handler.context_length)
|
||||
return executor(model, noise_shape, *args, **kwargs)
|
||||
return executor(model, noise_shape, conds, *args, **kwargs)
|
||||
|
||||
|
||||
def create_prepare_sampling_wrapper(model: ModelPatcher):
|
||||
|
||||
Loading…
Reference in New Issue
Block a user