mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-15 00:30:55 +08:00
wan: Implement block level prefetching
This commit is contained in:
parent
0814c1f63f
commit
2e843f309d
@ -538,6 +538,8 @@ class WanModel(torch.nn.Module):
|
||||
List[Tensor]:
|
||||
List of denoised video tensors with original input shapes [C_out, F, H / 8, W / 8]
|
||||
"""
|
||||
prefetch_queue = comfy.ops.make_prefetch_queue(list(self.blocks))
|
||||
comfy.ops.prefetch_queue_pop(prefetch_queue, x.device, None)
|
||||
# embeddings
|
||||
x = self.patch_embedding(x.float()).to(x.dtype)
|
||||
grid_sizes = x.shape[2:]
|
||||
@ -569,6 +571,7 @@ class WanModel(torch.nn.Module):
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
for i, block in enumerate(self.blocks):
|
||||
comfy.ops.prefetch_queue_pop(prefetch_queue, x.device, block)
|
||||
if ("double_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
@ -578,6 +581,7 @@ class WanModel(torch.nn.Module):
|
||||
x = out["img"]
|
||||
else:
|
||||
x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len, transformer_options=transformer_options)
|
||||
comfy.ops.prefetch_queue_pop(prefetch_queue, x.device, block)
|
||||
|
||||
# head
|
||||
x = self.head(x, e)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user