mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-21 20:00:17 +08:00
qwen: Implement transformer block prefetching
This commit is contained in:
parent
e279e1f26e
commit
0814c1f63f
@ -387,6 +387,9 @@ class QwenImageTransformer2DModel(nn.Module):
|
|||||||
hidden_states, img_ids, orig_shape = self.process_img(x)
|
hidden_states, img_ids, orig_shape = self.process_img(x)
|
||||||
num_embeds = hidden_states.shape[1]
|
num_embeds = hidden_states.shape[1]
|
||||||
|
|
||||||
|
prefetch_queue = comfy.ops.make_prefetch_queue(list(self.transformer_blocks))
|
||||||
|
comfy.ops.prefetch_queue_pop(prefetch_queue, x.device, None)
|
||||||
|
|
||||||
if ref_latents is not None:
|
if ref_latents is not None:
|
||||||
h = 0
|
h = 0
|
||||||
w = 0
|
w = 0
|
||||||
@ -436,6 +439,7 @@ class QwenImageTransformer2DModel(nn.Module):
|
|||||||
blocks_replace = patches_replace.get("dit", {})
|
blocks_replace = patches_replace.get("dit", {})
|
||||||
|
|
||||||
for i, block in enumerate(self.transformer_blocks):
|
for i, block in enumerate(self.transformer_blocks):
|
||||||
|
comfy.ops.prefetch_queue_pop(prefetch_queue, x.device, block)
|
||||||
if ("double_block", i) in blocks_replace:
|
if ("double_block", i) in blocks_replace:
|
||||||
def block_wrap(args):
|
def block_wrap(args):
|
||||||
out = {}
|
out = {}
|
||||||
@ -467,6 +471,8 @@ class QwenImageTransformer2DModel(nn.Module):
|
|||||||
if add is not None:
|
if add is not None:
|
||||||
hidden_states[:, :add.shape[1]] += add
|
hidden_states[:, :add.shape[1]] += add
|
||||||
|
|
||||||
|
comfy.ops.prefetch_queue_pop(prefetch_queue, x.device, block)
|
||||||
|
|
||||||
hidden_states = self.norm_out(hidden_states, temb)
|
hidden_states = self.norm_out(hidden_states, temb)
|
||||||
hidden_states = self.proj_out(hidden_states)
|
hidden_states = self.proj_out(hidden_states)
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user