From 2e843f309d550551ff3a8ecfb0fb9b8e627b7549 Mon Sep 17 00:00:00 2001 From: Rattus Date: Tue, 4 Nov 2025 08:34:28 +1000 Subject: [PATCH] wan: Implement block level prefetching --- comfy/ldm/wan/model.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/comfy/ldm/wan/model.py b/comfy/ldm/wan/model.py index a9d5e10d9..fa2e90c9f 100644 --- a/comfy/ldm/wan/model.py +++ b/comfy/ldm/wan/model.py @@ -538,6 +538,8 @@ class WanModel(torch.nn.Module): List[Tensor]: List of denoised video tensors with original input shapes [C_out, F, H / 8, W / 8] """ + prefetch_queue = comfy.ops.make_prefetch_queue(list(self.blocks)) + comfy.ops.prefetch_queue_pop(prefetch_queue, x.device, None) # embeddings x = self.patch_embedding(x.float()).to(x.dtype) grid_sizes = x.shape[2:] @@ -569,6 +571,7 @@ class WanModel(torch.nn.Module): patches_replace = transformer_options.get("patches_replace", {}) blocks_replace = patches_replace.get("dit", {}) for i, block in enumerate(self.blocks): + comfy.ops.prefetch_queue_pop(prefetch_queue, x.device, block) if ("double_block", i) in blocks_replace: def block_wrap(args): out = {} @@ -578,6 +581,7 @@ class WanModel(torch.nn.Module): x = out["img"] else: x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len, transformer_options=transformer_options) + comfy.ops.prefetch_queue_pop(prefetch_queue, x.device, block) # head x = self.head(x, e)