qwen: Implement transformer block prefetching

Rattus 2025-11-02 02:36:40 +10:00
parent e279e1f26e
commit 0814c1f63f


@@ -387,6 +387,9 @@ class QwenImageTransformer2DModel(nn.Module):
         hidden_states, img_ids, orig_shape = self.process_img(x)
         num_embeds = hidden_states.shape[1]
 
+        prefetch_queue = comfy.ops.make_prefetch_queue(list(self.transformer_blocks))
+        comfy.ops.prefetch_queue_pop(prefetch_queue, x.device, None)
+
         if ref_latents is not None:
             h = 0
             w = 0
@@ -436,6 +439,7 @@ class QwenImageTransformer2DModel(nn.Module):
         blocks_replace = patches_replace.get("dit", {})
 
         for i, block in enumerate(self.transformer_blocks):
+            comfy.ops.prefetch_queue_pop(prefetch_queue, x.device, block)
             if ("double_block", i) in blocks_replace:
                 def block_wrap(args):
                     out = {}
@@ -467,6 +471,8 @@ class QwenImageTransformer2DModel(nn.Module):
                     if add is not None:
                         hidden_states[:, :add.shape[1]] += add
 
+        comfy.ops.prefetch_queue_pop(prefetch_queue, x.device, block)
+
         hidden_states = self.norm_out(hidden_states, temb)
         hidden_states = self.proj_out(hidden_states)
 
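
Note: the diff only shows the call sites; make_prefetch_queue and prefetch_queue_pop live in comfy.ops and are not part of this commit. The call pattern is standard weight prefetching: the pre-loop pop(..., None) starts copying the first block's weights, each in-loop pop(..., block) waits for block i's weights to arrive while kicking off the copy for block i+1, and the post-loop pop syncs on the final block. The following is a minimal sketch of how such a queue could be built with PyTorch CUDA streams; the _PrefetchQueue class, its event bookkeeping, and the pinned-memory caveat are illustrative assumptions, not comfy.ops internals.

import collections
import torch

# Sketch only: comfy.ops' real prefetch queue is not shown in this commit.
class _PrefetchQueue:
    def __init__(self, blocks):
        self.pending = collections.deque(blocks)   # blocks not yet on-device
        self.copy_stream = None                    # side stream for H2D copies
        self.events = {}                           # block -> copy-done event

def make_prefetch_queue(blocks):
    return _PrefetchQueue(blocks)

def prefetch_queue_pop(queue, device, current_block):
    # Wait for current_block's weights to finish arriving, then start the
    # next block's weight copy on a side stream so it overlaps compute.
    if device.type != "cuda":
        if queue.pending:                          # no async copy: move eagerly
            queue.pending.popleft().to(device)
        return

    if queue.copy_stream is None:
        queue.copy_stream = torch.cuda.Stream(device)

    event = queue.events.pop(current_block, None)
    if event is not None:
        # The compute stream must not run the block before its weights land.
        torch.cuda.current_stream(device).wait_event(event)

    if queue.pending:
        nxt = queue.pending.popleft()
        with torch.cuda.stream(queue.copy_stream):
            # True overlap requires pinned host memory; from pageable memory
            # the copy silently degrades to a synchronous transfer.
            nxt.to(device, non_blocking=True)
            queue.events[nxt] = torch.cuda.Event()
            queue.events[nxt].record(queue.copy_stream)

The point of the pattern is to hide host-to-device transfer latency behind each block's forward pass, which matters when the full block list does not stay resident in VRAM.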