diff --git a/comfy/ldm/lightricks/av_model.py b/comfy/ldm/lightricks/av_model.py index 6f2ba41ef..55d25d897 100644 --- a/comfy/ldm/lightricks/av_model.py +++ b/comfy/ldm/lightricks/av_model.py @@ -16,6 +16,7 @@ from comfy.ldm.lightricks.model import ( from comfy.ldm.lightricks.symmetric_patchifier import AudioPatchifier from comfy.ldm.lightricks.embeddings_connector import Embeddings1DConnector import comfy.ldm.common_dit +import comfy.ops class CompressedTimestep: """Store video timestep embeddings in compressed form using per-frame indexing.""" @@ -907,9 +908,11 @@ class LTXAVModel(LTXVModel): """Process transformer blocks for LTXAV.""" patches_replace = transformer_options.get("patches_replace", {}) blocks_replace = patches_replace.get("dit", {}) + prefetch_queue = comfy.ops.make_prefetch_queue(list(self.transformer_blocks)) # Process transformer blocks for i, block in enumerate(self.transformer_blocks): + comfy.ops.prefetch_queue_pop(prefetch_queue, vx.device, block) if ("double_block", i) in blocks_replace: def block_wrap(args): @@ -982,6 +985,8 @@ class LTXAVModel(LTXVModel): a_prompt_timestep=a_prompt_timestep, ) + comfy.ops.prefetch_queue_pop(prefetch_queue, vx.device, None) + return [vx, ax] def _process_output(self, x, embedded_timestep, keyframe_idxs, **kwargs): diff --git a/comfy/ops.py b/comfy/ops.py index 2e0b45be4..34b70a1bf 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -176,6 +176,7 @@ def cast_prefetch_all(module, device): comfy_modules.append(s) offload_stream = cast_modules_with_vbar(comfy_modules, None, device, None, True) + comfy.model_management.sync_stream(device, offload_stream) return (offload_stream, (module, comfy_modules)) @@ -193,20 +194,13 @@ def prefetch_queue_pop(queue, device, module): offload_stream.wait_stream(comfy.model_management.current_stream(device)) uncast_prefetch_all(prefetch_state) - active = queue[0] - if active is not None: - offload_stream, prefetch_state = active - assert prefetch_state[0] is module - if offload_stream is not None: - comfy.model_management.sync_stream(device, offload_stream) - - prefetch = queue[1] + prefetch = queue[0] if prefetch is not None: - queue[1] = cast_prefetch_all(prefetch, device) + queue[0] = cast_prefetch_all(prefetch, device) def make_prefetch_queue(queue): - queue = [None, None] + queue + [None, None] + queue = [None] + queue + [None] comfy.model_management.PREFETCH_QUEUES.append(queue) return queue