mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-15 00:30:55 +08:00
qwen: Implement transformer block prefetching
This commit is contained in:
parent
e279e1f26e
commit
0814c1f63f
@ -387,6 +387,9 @@ class QwenImageTransformer2DModel(nn.Module):
|
||||
hidden_states, img_ids, orig_shape = self.process_img(x)
|
||||
num_embeds = hidden_states.shape[1]
|
||||
|
||||
prefetch_queue = comfy.ops.make_prefetch_queue(list(self.transformer_blocks))
|
||||
comfy.ops.prefetch_queue_pop(prefetch_queue, x.device, None)
|
||||
|
||||
if ref_latents is not None:
|
||||
h = 0
|
||||
w = 0
|
||||
@ -436,6 +439,7 @@ class QwenImageTransformer2DModel(nn.Module):
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
|
||||
for i, block in enumerate(self.transformer_blocks):
|
||||
comfy.ops.prefetch_queue_pop(prefetch_queue, x.device, block)
|
||||
if ("double_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
@ -467,6 +471,8 @@ class QwenImageTransformer2DModel(nn.Module):
|
||||
if add is not None:
|
||||
hidden_states[:, :add.shape[1]] += add
|
||||
|
||||
comfy.ops.prefetch_queue_pop(prefetch_queue, x.device, block)
|
||||
|
||||
hidden_states = self.norm_out(hidden_states, temb)
|
||||
hidden_states = self.proj_out(hidden_states)
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user