mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-19 05:27:24 +08:00
ltxv2: Implement block prefetching
This commit is contained in:
parent
74261f12f2
commit
9242551e9b
@ -16,6 +16,7 @@ from comfy.ldm.lightricks.model import (
|
||||
from comfy.ldm.lightricks.symmetric_patchifier import AudioPatchifier
|
||||
from comfy.ldm.lightricks.embeddings_connector import Embeddings1DConnector
|
||||
import comfy.ldm.common_dit
|
||||
import comfy.ops
|
||||
|
||||
class CompressedTimestep:
|
||||
"""Store video timestep embeddings in compressed form using per-frame indexing."""
|
||||
@ -907,9 +908,11 @@ class LTXAVModel(LTXVModel):
|
||||
"""Process transformer blocks for LTXAV."""
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
prefetch_queue = comfy.ops.make_prefetch_queue(list(self.transformer_blocks))
|
||||
|
||||
# Process transformer blocks
|
||||
for i, block in enumerate(self.transformer_blocks):
|
||||
comfy.ops.prefetch_queue_pop(prefetch_queue, vx.device, block)
|
||||
if ("double_block", i) in blocks_replace:
|
||||
|
||||
def block_wrap(args):
|
||||
@ -982,6 +985,8 @@ class LTXAVModel(LTXVModel):
|
||||
a_prompt_timestep=a_prompt_timestep,
|
||||
)
|
||||
|
||||
comfy.ops.prefetch_queue_pop(prefetch_queue, vx.device, None)
|
||||
|
||||
return [vx, ax]
|
||||
|
||||
def _process_output(self, x, embedded_timestep, keyframe_idxs, **kwargs):
|
||||
|
||||
14
comfy/ops.py
14
comfy/ops.py
@ -176,6 +176,7 @@ def cast_prefetch_all(module, device):
|
||||
comfy_modules.append(s)
|
||||
|
||||
offload_stream = cast_modules_with_vbar(comfy_modules, None, device, None, True)
|
||||
comfy.model_management.sync_stream(device, offload_stream)
|
||||
return (offload_stream, (module, comfy_modules))
|
||||
|
||||
|
||||
@ -193,20 +194,13 @@ def prefetch_queue_pop(queue, device, module):
|
||||
offload_stream.wait_stream(comfy.model_management.current_stream(device))
|
||||
uncast_prefetch_all(prefetch_state)
|
||||
|
||||
active = queue[0]
|
||||
if active is not None:
|
||||
offload_stream, prefetch_state = active
|
||||
assert prefetch_state[0] is module
|
||||
if offload_stream is not None:
|
||||
comfy.model_management.sync_stream(device, offload_stream)
|
||||
|
||||
prefetch = queue[1]
|
||||
prefetch = queue[0]
|
||||
if prefetch is not None:
|
||||
queue[1] = cast_prefetch_all(prefetch, device)
|
||||
queue[0] = cast_prefetch_all(prefetch, device)
|
||||
|
||||
|
||||
def make_prefetch_queue(queue):
|
||||
queue = [None, None] + queue + [None, None]
|
||||
queue = [None] + queue + [None]
|
||||
comfy.model_management.PREFETCH_QUEUES.append(queue)
|
||||
return queue
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user