From ef2c560e13d216df9fee48691437327b17456e29 Mon Sep 17 00:00:00 2001 From: Haoming Date: Sun, 23 Nov 2025 15:04:25 +0800 Subject: [PATCH 1/4] block info --- comfy/ldm/wan/model.py | 21 ++++++++++++++++++--- 1 file changed, 18 insertions(+), 3 deletions(-) diff --git a/comfy/ldm/wan/model.py b/comfy/ldm/wan/model.py index a9d5e10d9..4216ce831 100644 --- a/comfy/ldm/wan/model.py +++ b/comfy/ldm/wan/model.py @@ -568,7 +568,10 @@ class WanModel(torch.nn.Module): patches_replace = transformer_options.get("patches_replace", {}) blocks_replace = patches_replace.get("dit", {}) + transformer_options["total_blocks"] = len(self.blocks) + transformer_options["block_type"] = "double" for i, block in enumerate(self.blocks): + transformer_options["block_index"] = i if ("double_block", i) in blocks_replace: def block_wrap(args): out = {} @@ -763,7 +766,10 @@ class VaceWanModel(WanModel): patches_replace = transformer_options.get("patches_replace", {}) blocks_replace = patches_replace.get("dit", {}) + transformer_options["total_blocks"] = len(self.blocks) + transformer_options["block_type"] = "double" for i, block in enumerate(self.blocks): + transformer_options["block_index"] = i if ("double_block", i) in blocks_replace: def block_wrap(args): out = {} @@ -862,7 +868,10 @@ class CameraWanModel(WanModel): patches_replace = transformer_options.get("patches_replace", {}) blocks_replace = patches_replace.get("dit", {}) + transformer_options["total_blocks"] = len(self.blocks) + transformer_options["block_type"] = "double" for i, block in enumerate(self.blocks): + transformer_options["block_index"] = i if ("double_block", i) in blocks_replace: def block_wrap(args): out = {} @@ -1326,16 +1335,19 @@ class WanModel_S2V(WanModel): patches_replace = transformer_options.get("patches_replace", {}) blocks_replace = patches_replace.get("dit", {}) + transformer_options["total_blocks"] = len(self.blocks) + transformer_options["block_type"] = "double" for i, block in enumerate(self.blocks): + transformer_options["block_index"] = i if ("double_block", i) in blocks_replace: def block_wrap(args): out = {} - out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"]) + out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], transformer_options=args["transformer_options"]) return out - out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs}, {"original_block": block_wrap}) + out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs, "transformer_options": transformer_options}, {"original_block": block_wrap}) x = out["img"] else: - x = block(x, e=e0, freqs=freqs, context=context) + x = block(x, e=e0, freqs=freqs, context=context, transformer_options=transformer_options) if audio_emb is not None: x = self.audio_injector(x, i, audio_emb, audio_emb_global, seq_len) # head @@ -1574,7 +1586,10 @@ class HumoWanModel(WanModel): patches_replace = transformer_options.get("patches_replace", {}) blocks_replace = patches_replace.get("dit", {}) + transformer_options["total_blocks"] = len(self.blocks) + transformer_options["block_type"] = "double" for i, block in enumerate(self.blocks): + transformer_options["block_index"] = i if ("double_block", i) in blocks_replace: def block_wrap(args): out = {} From 9e19885a6f85e050caf8d7e87af6587e7130ebe0 Mon Sep 17 00:00:00 2001 From: Haoming Date: Sun, 23 Nov 2025 15:10:18 +0800 Subject: [PATCH 2/4] animate --- comfy/ldm/wan/model_animate.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/comfy/ldm/wan/model_animate.py b/comfy/ldm/wan/model_animate.py index 7c87835d4..84d7adec4 100644 --- a/comfy/ldm/wan/model_animate.py +++ b/comfy/ldm/wan/model_animate.py @@ -523,7 +523,10 @@ class AnimateWanModel(WanModel): patches_replace = transformer_options.get("patches_replace", {}) blocks_replace = patches_replace.get("dit", {}) + transformer_options["total_blocks"] = len(self.blocks) + transformer_options["block_type"] = "double" for i, block in enumerate(self.blocks): + transformer_options["block_index"] = i if ("double_block", i) in blocks_replace: def block_wrap(args): out = {} From 8ec0e3e6e3314773c60bc0b88c1a33d9e860dcbf Mon Sep 17 00:00:00 2001 From: Haoming Date: Tue, 9 Dec 2025 11:46:20 +0800 Subject: [PATCH 3/4] tensor --- comfy/ldm/wan/model.py | 10 +++++----- comfy/ldm/wan/model_animate.py | 2 +- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/comfy/ldm/wan/model.py b/comfy/ldm/wan/model.py index 4216ce831..ae0625df7 100644 --- a/comfy/ldm/wan/model.py +++ b/comfy/ldm/wan/model.py @@ -571,7 +571,7 @@ class WanModel(torch.nn.Module): transformer_options["total_blocks"] = len(self.blocks) transformer_options["block_type"] = "double" for i, block in enumerate(self.blocks): - transformer_options["block_index"] = i + transformer_options["block_index"] = torch.tensor(i, dtype=torch.uint8, device=x.device) if ("double_block", i) in blocks_replace: def block_wrap(args): out = {} @@ -769,7 +769,7 @@ class VaceWanModel(WanModel): transformer_options["total_blocks"] = len(self.blocks) transformer_options["block_type"] = "double" for i, block in enumerate(self.blocks): - transformer_options["block_index"] = i + transformer_options["block_index"] = torch.tensor(i, dtype=torch.uint8, device=x.device) if ("double_block", i) in blocks_replace: def block_wrap(args): out = {} @@ -871,7 +871,7 @@ class CameraWanModel(WanModel): transformer_options["total_blocks"] = len(self.blocks) transformer_options["block_type"] = "double" for i, block in enumerate(self.blocks): - transformer_options["block_index"] = i + transformer_options["block_index"] = torch.tensor(i, dtype=torch.uint8, device=x.device) if ("double_block", i) in blocks_replace: def block_wrap(args): out = {} @@ -1338,7 +1338,7 @@ class WanModel_S2V(WanModel): transformer_options["total_blocks"] = len(self.blocks) transformer_options["block_type"] = "double" for i, block in enumerate(self.blocks): - transformer_options["block_index"] = i + transformer_options["block_index"] = torch.tensor(i, dtype=torch.uint8, device=x.device) if ("double_block", i) in blocks_replace: def block_wrap(args): out = {} @@ -1589,7 +1589,7 @@ class HumoWanModel(WanModel): transformer_options["total_blocks"] = len(self.blocks) transformer_options["block_type"] = "double" for i, block in enumerate(self.blocks): - transformer_options["block_index"] = i + transformer_options["block_index"] = torch.tensor(i, dtype=torch.uint8, device=x.device) if ("double_block", i) in blocks_replace: def block_wrap(args): out = {} diff --git a/comfy/ldm/wan/model_animate.py b/comfy/ldm/wan/model_animate.py index 84d7adec4..fee8c8df9 100644 --- a/comfy/ldm/wan/model_animate.py +++ b/comfy/ldm/wan/model_animate.py @@ -526,7 +526,7 @@ class AnimateWanModel(WanModel): transformer_options["total_blocks"] = len(self.blocks) transformer_options["block_type"] = "double" for i, block in enumerate(self.blocks): - transformer_options["block_index"] = i + transformer_options["block_index"] = torch.tensor(i, dtype=torch.uint8, device=x.device) if ("double_block", i) in blocks_replace: def block_wrap(args): out = {} From bb1513dc3c6076327edb9e0dee4d834b66a911d7 Mon Sep 17 00:00:00 2001 From: Haoming Date: Thu, 11 Dec 2025 10:41:18 +0800 Subject: [PATCH 4/4] device --- comfy/ldm/wan/model.py | 10 +++++----- comfy/ldm/wan/model_animate.py | 2 +- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/comfy/ldm/wan/model.py b/comfy/ldm/wan/model.py index ae0625df7..72f99462f 100644 --- a/comfy/ldm/wan/model.py +++ b/comfy/ldm/wan/model.py @@ -571,7 +571,7 @@ class WanModel(torch.nn.Module): transformer_options["total_blocks"] = len(self.blocks) transformer_options["block_type"] = "double" for i, block in enumerate(self.blocks): - transformer_options["block_index"] = torch.tensor(i, dtype=torch.uint8, device=x.device) + transformer_options["block_index"] = torch.tensor(i, dtype=torch.uint8, device="cpu") if ("double_block", i) in blocks_replace: def block_wrap(args): out = {} @@ -769,7 +769,7 @@ class VaceWanModel(WanModel): transformer_options["total_blocks"] = len(self.blocks) transformer_options["block_type"] = "double" for i, block in enumerate(self.blocks): - transformer_options["block_index"] = torch.tensor(i, dtype=torch.uint8, device=x.device) + transformer_options["block_index"] = torch.tensor(i, dtype=torch.uint8, device="cpu") if ("double_block", i) in blocks_replace: def block_wrap(args): out = {} @@ -871,7 +871,7 @@ class CameraWanModel(WanModel): transformer_options["total_blocks"] = len(self.blocks) transformer_options["block_type"] = "double" for i, block in enumerate(self.blocks): - transformer_options["block_index"] = torch.tensor(i, dtype=torch.uint8, device=x.device) + transformer_options["block_index"] = torch.tensor(i, dtype=torch.uint8, device="cpu") if ("double_block", i) in blocks_replace: def block_wrap(args): out = {} @@ -1338,7 +1338,7 @@ class WanModel_S2V(WanModel): transformer_options["total_blocks"] = len(self.blocks) transformer_options["block_type"] = "double" for i, block in enumerate(self.blocks): - transformer_options["block_index"] = torch.tensor(i, dtype=torch.uint8, device=x.device) + transformer_options["block_index"] = torch.tensor(i, dtype=torch.uint8, device="cpu") if ("double_block", i) in blocks_replace: def block_wrap(args): out = {} @@ -1589,7 +1589,7 @@ class HumoWanModel(WanModel): transformer_options["total_blocks"] = len(self.blocks) transformer_options["block_type"] = "double" for i, block in enumerate(self.blocks): - transformer_options["block_index"] = torch.tensor(i, dtype=torch.uint8, device=x.device) + transformer_options["block_index"] = torch.tensor(i, dtype=torch.uint8, device="cpu") if ("double_block", i) in blocks_replace: def block_wrap(args): out = {} diff --git a/comfy/ldm/wan/model_animate.py b/comfy/ldm/wan/model_animate.py index fee8c8df9..9dcd141a3 100644 --- a/comfy/ldm/wan/model_animate.py +++ b/comfy/ldm/wan/model_animate.py @@ -526,7 +526,7 @@ class AnimateWanModel(WanModel): transformer_options["total_blocks"] = len(self.blocks) transformer_options["block_type"] = "double" for i, block in enumerate(self.blocks): - transformer_options["block_index"] = torch.tensor(i, dtype=torch.uint8, device=x.device) + transformer_options["block_index"] = torch.tensor(i, dtype=torch.uint8, device="cpu") if ("double_block", i) in blocks_replace: def block_wrap(args): out = {}