From 0814c1f63f5f98484ea4e39bf4bd810e3666c078 Mon Sep 17 00:00:00 2001
From: Rattus
Date: Sun, 2 Nov 2025 02:36:40 +1000
Subject: [PATCH] qwen: Implement transformer block prefetching

---
 comfy/ldm/qwen_image/model.py | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/comfy/ldm/qwen_image/model.py b/comfy/ldm/qwen_image/model.py
index e5d0d17c1..05edd985b 100644
--- a/comfy/ldm/qwen_image/model.py
+++ b/comfy/ldm/qwen_image/model.py
@@ -387,6 +387,9 @@ class QwenImageTransformer2DModel(nn.Module):
         hidden_states, img_ids, orig_shape = self.process_img(x)
         num_embeds = hidden_states.shape[1]
 
+        prefetch_queue = comfy.ops.make_prefetch_queue(list(self.transformer_blocks))
+        comfy.ops.prefetch_queue_pop(prefetch_queue, x.device, None)
+
         if ref_latents is not None:
             h = 0
             w = 0
@@ -436,6 +439,7 @@ class QwenImageTransformer2DModel(nn.Module):
 
         blocks_replace = patches_replace.get("dit", {})
         for i, block in enumerate(self.transformer_blocks):
+            comfy.ops.prefetch_queue_pop(prefetch_queue, x.device, block)
             if ("double_block", i) in blocks_replace:
                 def block_wrap(args):
                     out = {}
@@ -467,6 +471,8 @@ class QwenImageTransformer2DModel(nn.Module):
                     if add is not None:
                         hidden_states[:, :add.shape[1]] += add
 
+        comfy.ops.prefetch_queue_pop(prefetch_queue, x.device, block)
+
         hidden_states = self.norm_out(hidden_states, temb)
         hidden_states = self.proj_out(hidden_states)
 
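
For readers unfamiliar with the new helpers: the patch relies on
comfy.ops.make_prefetch_queue and comfy.ops.prefetch_queue_pop, whose
bodies are not shown here. Below is a minimal sketch of how such a pair
could be implemented, purely to illustrate the overlap technique. It is
not ComfyUI's actual comfy.ops implementation; the names _prefetch_stream
and _ready_events are invented for this sketch, and the real helpers
presumably also handle offloading finished blocks back off the device,
pinned-memory staging, and non-CUDA backends.

import collections

import torch

_prefetch_stream = None  # side stream used only for weight copies (sketch)
_ready_events = {}       # block -> event recorded once its copy finished

def make_prefetch_queue(blocks):
    # Blocks whose weights have not yet been prefetched to the device.
    return collections.deque(blocks)

def prefetch_queue_pop(queue, device, current_block):
    global _prefetch_stream
    if device.type != "cuda":
        return  # this sketch only overlaps copies on CUDA devices
    if _prefetch_stream is None:
        _prefetch_stream = torch.cuda.Stream(device=device)
    if queue:
        nxt = queue.popleft()
        with torch.cuda.stream(_prefetch_stream):
            # Asynchronous host-to-device copy of the next block's weights.
            # This only truly overlaps with compute if the source tensors
            # live in pinned memory.
            nxt.to(device, non_blocking=True)
        event = torch.cuda.Event()
        event.record(_prefetch_stream)
        _ready_events[nxt] = event
    event = _ready_events.pop(current_block, None)
    if event is not None:
        # Make the compute stream (not the host) wait until the weights of
        # the block we are about to run have actually arrived.
        torch.cuda.current_stream(device).wait_event(event)

Under this reading, the call pattern in the patch makes sense: the pop
before the loop (with block=None) warms the pipeline by starting block 0's
copy while the embedding code runs; each in-loop pop waits, on the GPU
rather than the host, for the current block's weights and kicks off the
copy of the next block; the trailing pop after the loop drains the
remaining bookkeeping for the final block.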