wan: Implement block-level prefetching

This commit is contained in:
Rattus 2025-11-04 08:34:28 +10:00
parent 0814c1f63f
commit 2e843f309d

View File

@@ -538,6 +538,8 @@ class WanModel(torch.nn.Module):
List[Tensor]:
List of denoised video tensors with original input shapes [C_out, F, H / 8, W / 8]
"""
prefetch_queue = comfy.ops.make_prefetch_queue(list(self.blocks))
comfy.ops.prefetch_queue_pop(prefetch_queue, x.device, None)
# embeddings
x = self.patch_embedding(x.float()).to(x.dtype)
grid_sizes = x.shape[2:]
@@ -569,6 +571,7 @@ class WanModel(torch.nn.Module):
patches_replace = transformer_options.get("patches_replace", {})
blocks_replace = patches_replace.get("dit", {})
for i, block in enumerate(self.blocks):
comfy.ops.prefetch_queue_pop(prefetch_queue, x.device, block)
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
@@ -578,6 +581,7 @@ class WanModel(torch.nn.Module):
x = out["img"]
else:
x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len, transformer_options=transformer_options)
comfy.ops.prefetch_queue_pop(prefetch_queue, x.device, block)
# head
x = self.head(x, e)