mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-21 06:27:24 +08:00
ltxv2: Implement block prefetching
This commit is contained in:
parent
74261f12f2
commit
9242551e9b
@ -16,6 +16,7 @@ from comfy.ldm.lightricks.model import (
|
|||||||
from comfy.ldm.lightricks.symmetric_patchifier import AudioPatchifier
|
from comfy.ldm.lightricks.symmetric_patchifier import AudioPatchifier
|
||||||
from comfy.ldm.lightricks.embeddings_connector import Embeddings1DConnector
|
from comfy.ldm.lightricks.embeddings_connector import Embeddings1DConnector
|
||||||
import comfy.ldm.common_dit
|
import comfy.ldm.common_dit
|
||||||
|
import comfy.ops
|
||||||
|
|
||||||
class CompressedTimestep:
|
class CompressedTimestep:
|
||||||
"""Store video timestep embeddings in compressed form using per-frame indexing."""
|
"""Store video timestep embeddings in compressed form using per-frame indexing."""
|
||||||
@ -907,9 +908,11 @@ class LTXAVModel(LTXVModel):
|
|||||||
"""Process transformer blocks for LTXAV."""
|
"""Process transformer blocks for LTXAV."""
|
||||||
patches_replace = transformer_options.get("patches_replace", {})
|
patches_replace = transformer_options.get("patches_replace", {})
|
||||||
blocks_replace = patches_replace.get("dit", {})
|
blocks_replace = patches_replace.get("dit", {})
|
||||||
|
prefetch_queue = comfy.ops.make_prefetch_queue(list(self.transformer_blocks))
|
||||||
|
|
||||||
# Process transformer blocks
|
# Process transformer blocks
|
||||||
for i, block in enumerate(self.transformer_blocks):
|
for i, block in enumerate(self.transformer_blocks):
|
||||||
|
comfy.ops.prefetch_queue_pop(prefetch_queue, vx.device, block)
|
||||||
if ("double_block", i) in blocks_replace:
|
if ("double_block", i) in blocks_replace:
|
||||||
|
|
||||||
def block_wrap(args):
|
def block_wrap(args):
|
||||||
@ -982,6 +985,8 @@ class LTXAVModel(LTXVModel):
|
|||||||
a_prompt_timestep=a_prompt_timestep,
|
a_prompt_timestep=a_prompt_timestep,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
comfy.ops.prefetch_queue_pop(prefetch_queue, vx.device, None)
|
||||||
|
|
||||||
return [vx, ax]
|
return [vx, ax]
|
||||||
|
|
||||||
def _process_output(self, x, embedded_timestep, keyframe_idxs, **kwargs):
|
def _process_output(self, x, embedded_timestep, keyframe_idxs, **kwargs):
|
||||||
|
|||||||
14
comfy/ops.py
14
comfy/ops.py
@ -176,6 +176,7 @@ def cast_prefetch_all(module, device):
|
|||||||
comfy_modules.append(s)
|
comfy_modules.append(s)
|
||||||
|
|
||||||
offload_stream = cast_modules_with_vbar(comfy_modules, None, device, None, True)
|
offload_stream = cast_modules_with_vbar(comfy_modules, None, device, None, True)
|
||||||
|
comfy.model_management.sync_stream(device, offload_stream)
|
||||||
return (offload_stream, (module, comfy_modules))
|
return (offload_stream, (module, comfy_modules))
|
||||||
|
|
||||||
|
|
||||||
@ -193,20 +194,13 @@ def prefetch_queue_pop(queue, device, module):
|
|||||||
offload_stream.wait_stream(comfy.model_management.current_stream(device))
|
offload_stream.wait_stream(comfy.model_management.current_stream(device))
|
||||||
uncast_prefetch_all(prefetch_state)
|
uncast_prefetch_all(prefetch_state)
|
||||||
|
|
||||||
active = queue[0]
|
prefetch = queue[0]
|
||||||
if active is not None:
|
|
||||||
offload_stream, prefetch_state = active
|
|
||||||
assert prefetch_state[0] is module
|
|
||||||
if offload_stream is not None:
|
|
||||||
comfy.model_management.sync_stream(device, offload_stream)
|
|
||||||
|
|
||||||
prefetch = queue[1]
|
|
||||||
if prefetch is not None:
|
if prefetch is not None:
|
||||||
queue[1] = cast_prefetch_all(prefetch, device)
|
queue[0] = cast_prefetch_all(prefetch, device)
|
||||||
|
|
||||||
|
|
||||||
def make_prefetch_queue(queue):
|
def make_prefetch_queue(queue):
|
||||||
queue = [None, None] + queue + [None, None]
|
queue = [None] + queue + [None]
|
||||||
comfy.model_management.PREFETCH_QUEUES.append(queue)
|
comfy.model_management.PREFETCH_QUEUES.append(queue)
|
||||||
return queue
|
return queue
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user