From bb1513dc3c6076327edb9e0dee4d834b66a911d7 Mon Sep 17 00:00:00 2001 From: Haoming Date: Thu, 11 Dec 2025 10:41:18 +0800 Subject: [PATCH] device --- comfy/ldm/wan/model.py | 10 +++++----- comfy/ldm/wan/model_animate.py | 2 +- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/comfy/ldm/wan/model.py b/comfy/ldm/wan/model.py index ae0625df7..72f99462f 100644 --- a/comfy/ldm/wan/model.py +++ b/comfy/ldm/wan/model.py @@ -571,7 +571,7 @@ class WanModel(torch.nn.Module): transformer_options["total_blocks"] = len(self.blocks) transformer_options["block_type"] = "double" for i, block in enumerate(self.blocks): - transformer_options["block_index"] = torch.tensor(i, dtype=torch.uint8, device=x.device) + transformer_options["block_index"] = torch.tensor(i, dtype=torch.uint8, device="cpu") if ("double_block", i) in blocks_replace: def block_wrap(args): out = {} @@ -769,7 +769,7 @@ class VaceWanModel(WanModel): transformer_options["total_blocks"] = len(self.blocks) transformer_options["block_type"] = "double" for i, block in enumerate(self.blocks): - transformer_options["block_index"] = torch.tensor(i, dtype=torch.uint8, device=x.device) + transformer_options["block_index"] = torch.tensor(i, dtype=torch.uint8, device="cpu") if ("double_block", i) in blocks_replace: def block_wrap(args): out = {} @@ -871,7 +871,7 @@ class CameraWanModel(WanModel): transformer_options["total_blocks"] = len(self.blocks) transformer_options["block_type"] = "double" for i, block in enumerate(self.blocks): - transformer_options["block_index"] = torch.tensor(i, dtype=torch.uint8, device=x.device) + transformer_options["block_index"] = torch.tensor(i, dtype=torch.uint8, device="cpu") if ("double_block", i) in blocks_replace: def block_wrap(args): out = {} @@ -1338,7 +1338,7 @@ class WanModel_S2V(WanModel): transformer_options["total_blocks"] = len(self.blocks) transformer_options["block_type"] = "double" for i, block in enumerate(self.blocks): - transformer_options["block_index"] = torch.tensor(i, dtype=torch.uint8, device=x.device) + transformer_options["block_index"] = torch.tensor(i, dtype=torch.uint8, device="cpu") if ("double_block", i) in blocks_replace: def block_wrap(args): out = {} @@ -1589,7 +1589,7 @@ class HumoWanModel(WanModel): transformer_options["total_blocks"] = len(self.blocks) transformer_options["block_type"] = "double" for i, block in enumerate(self.blocks): - transformer_options["block_index"] = torch.tensor(i, dtype=torch.uint8, device=x.device) + transformer_options["block_index"] = torch.tensor(i, dtype=torch.uint8, device="cpu") if ("double_block", i) in blocks_replace: def block_wrap(args): out = {} diff --git a/comfy/ldm/wan/model_animate.py b/comfy/ldm/wan/model_animate.py index fee8c8df9..9dcd141a3 100644 --- a/comfy/ldm/wan/model_animate.py +++ b/comfy/ldm/wan/model_animate.py @@ -526,7 +526,7 @@ class AnimateWanModel(WanModel): transformer_options["total_blocks"] = len(self.blocks) transformer_options["block_type"] = "double" for i, block in enumerate(self.blocks): - transformer_options["block_index"] = torch.tensor(i, dtype=torch.uint8, device=x.device) + transformer_options["block_index"] = torch.tensor(i, dtype=torch.uint8, device="cpu") if ("double_block", i) in blocks_replace: def block_wrap(args): out = {}