ltxv2: Implement block prefetching

This commit is contained in:
Rattus 2026-04-27 22:49:03 +10:00
parent 74261f12f2
commit 9242551e9b
2 changed files with 9 additions and 10 deletions

View File

@ -16,6 +16,7 @@ from comfy.ldm.lightricks.model import (
from comfy.ldm.lightricks.symmetric_patchifier import AudioPatchifier
from comfy.ldm.lightricks.embeddings_connector import Embeddings1DConnector
import comfy.ldm.common_dit
import comfy.ops
class CompressedTimestep:
"""Store video timestep embeddings in compressed form using per-frame indexing."""
@ -907,9 +908,11 @@ class LTXAVModel(LTXVModel):
"""Process transformer blocks for LTXAV."""
patches_replace = transformer_options.get("patches_replace", {})
blocks_replace = patches_replace.get("dit", {})
prefetch_queue = comfy.ops.make_prefetch_queue(list(self.transformer_blocks))
# Process transformer blocks
for i, block in enumerate(self.transformer_blocks):
comfy.ops.prefetch_queue_pop(prefetch_queue, vx.device, block)
if ("double_block", i) in blocks_replace:
def block_wrap(args):
@ -982,6 +985,8 @@ class LTXAVModel(LTXVModel):
a_prompt_timestep=a_prompt_timestep,
)
comfy.ops.prefetch_queue_pop(prefetch_queue, vx.device, None)
return [vx, ax]
def _process_output(self, x, embedded_timestep, keyframe_idxs, **kwargs):

View File

@ -176,6 +176,7 @@ def cast_prefetch_all(module, device):
comfy_modules.append(s)
offload_stream = cast_modules_with_vbar(comfy_modules, None, device, None, True)
comfy.model_management.sync_stream(device, offload_stream)
return (offload_stream, (module, comfy_modules))
@ -193,20 +194,13 @@ def prefetch_queue_pop(queue, device, module):
offload_stream.wait_stream(comfy.model_management.current_stream(device))
uncast_prefetch_all(prefetch_state)
active = queue[0]
if active is not None:
offload_stream, prefetch_state = active
assert prefetch_state[0] is module
if offload_stream is not None:
comfy.model_management.sync_stream(device, offload_stream)
prefetch = queue[1]
prefetch = queue[0]
if prefetch is not None:
queue[1] = cast_prefetch_all(prefetch, device)
queue[0] = cast_prefetch_all(prefetch, device)
def make_prefetch_queue(queue):
queue = [None, None] + queue + [None, None]
queue = [None] + queue + [None]
comfy.model_management.PREFETCH_QUEUES.append(queue)
return queue