From 8ec0e3e6e3314773c60bc0b88c1a33d9e860dcbf Mon Sep 17 00:00:00 2001 From: Haoming Date: Tue, 9 Dec 2025 11:46:20 +0800 Subject: [PATCH] tensor --- comfy/ldm/wan/model.py | 10 +++++----- comfy/ldm/wan/model_animate.py | 2 +- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/comfy/ldm/wan/model.py b/comfy/ldm/wan/model.py index 4216ce831..ae0625df7 100644 --- a/comfy/ldm/wan/model.py +++ b/comfy/ldm/wan/model.py @@ -571,7 +571,7 @@ class WanModel(torch.nn.Module): transformer_options["total_blocks"] = len(self.blocks) transformer_options["block_type"] = "double" for i, block in enumerate(self.blocks): - transformer_options["block_index"] = i + transformer_options["block_index"] = torch.tensor(i, dtype=torch.uint8, device=x.device) if ("double_block", i) in blocks_replace: def block_wrap(args): out = {} @@ -769,7 +769,7 @@ class VaceWanModel(WanModel): transformer_options["total_blocks"] = len(self.blocks) transformer_options["block_type"] = "double" for i, block in enumerate(self.blocks): - transformer_options["block_index"] = i + transformer_options["block_index"] = torch.tensor(i, dtype=torch.uint8, device=x.device) if ("double_block", i) in blocks_replace: def block_wrap(args): out = {} @@ -871,7 +871,7 @@ class CameraWanModel(WanModel): transformer_options["total_blocks"] = len(self.blocks) transformer_options["block_type"] = "double" for i, block in enumerate(self.blocks): - transformer_options["block_index"] = i + transformer_options["block_index"] = torch.tensor(i, dtype=torch.uint8, device=x.device) if ("double_block", i) in blocks_replace: def block_wrap(args): out = {} @@ -1338,7 +1338,7 @@ class WanModel_S2V(WanModel): transformer_options["total_blocks"] = len(self.blocks) transformer_options["block_type"] = "double" for i, block in enumerate(self.blocks): - transformer_options["block_index"] = i + transformer_options["block_index"] = torch.tensor(i, dtype=torch.uint8, device=x.device) if ("double_block", i) in blocks_replace: def block_wrap(args): out = {} @@ -1589,7 +1589,7 @@ class HumoWanModel(WanModel): transformer_options["total_blocks"] = len(self.blocks) transformer_options["block_type"] = "double" for i, block in enumerate(self.blocks): - transformer_options["block_index"] = i + transformer_options["block_index"] = torch.tensor(i, dtype=torch.uint8, device=x.device) if ("double_block", i) in blocks_replace: def block_wrap(args): out = {} diff --git a/comfy/ldm/wan/model_animate.py b/comfy/ldm/wan/model_animate.py index 84d7adec4..fee8c8df9 100644 --- a/comfy/ldm/wan/model_animate.py +++ b/comfy/ldm/wan/model_animate.py @@ -526,7 +526,7 @@ class AnimateWanModel(WanModel): transformer_options["total_blocks"] = len(self.blocks) transformer_options["block_type"] = "double" for i, block in enumerate(self.blocks): - transformer_options["block_index"] = i + transformer_options["block_index"] = torch.tensor(i, dtype=torch.uint8, device=x.device) if ("double_block", i) in blocks_replace: def block_wrap(args): out = {}