This commit is contained in:
Haoming 2025-12-11 10:41:18 +08:00
parent 8ec0e3e6e3
commit bb1513dc3c
2 changed files with 6 additions and 6 deletions

View File

@ -571,7 +571,7 @@ class WanModel(torch.nn.Module):
transformer_options["total_blocks"] = len(self.blocks)
transformer_options["block_type"] = "double"
for i, block in enumerate(self.blocks):
transformer_options["block_index"] = torch.tensor(i, dtype=torch.uint8, device=x.device)
transformer_options["block_index"] = torch.tensor(i, dtype=torch.uint8, device="cpu")
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
@ -769,7 +769,7 @@ class VaceWanModel(WanModel):
transformer_options["total_blocks"] = len(self.blocks)
transformer_options["block_type"] = "double"
for i, block in enumerate(self.blocks):
transformer_options["block_index"] = torch.tensor(i, dtype=torch.uint8, device=x.device)
transformer_options["block_index"] = torch.tensor(i, dtype=torch.uint8, device="cpu")
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
@ -871,7 +871,7 @@ class CameraWanModel(WanModel):
transformer_options["total_blocks"] = len(self.blocks)
transformer_options["block_type"] = "double"
for i, block in enumerate(self.blocks):
transformer_options["block_index"] = torch.tensor(i, dtype=torch.uint8, device=x.device)
transformer_options["block_index"] = torch.tensor(i, dtype=torch.uint8, device="cpu")
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
@ -1338,7 +1338,7 @@ class WanModel_S2V(WanModel):
transformer_options["total_blocks"] = len(self.blocks)
transformer_options["block_type"] = "double"
for i, block in enumerate(self.blocks):
transformer_options["block_index"] = torch.tensor(i, dtype=torch.uint8, device=x.device)
transformer_options["block_index"] = torch.tensor(i, dtype=torch.uint8, device="cpu")
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
@ -1589,7 +1589,7 @@ class HumoWanModel(WanModel):
transformer_options["total_blocks"] = len(self.blocks)
transformer_options["block_type"] = "double"
for i, block in enumerate(self.blocks):
transformer_options["block_index"] = torch.tensor(i, dtype=torch.uint8, device=x.device)
transformer_options["block_index"] = torch.tensor(i, dtype=torch.uint8, device="cpu")
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}

View File

@ -526,7 +526,7 @@ class AnimateWanModel(WanModel):
transformer_options["total_blocks"] = len(self.blocks)
transformer_options["block_type"] = "double"
for i, block in enumerate(self.blocks):
transformer_options["block_index"] = torch.tensor(i, dtype=torch.uint8, device=x.device)
transformer_options["block_index"] = torch.tensor(i, dtype=torch.uint8, device="cpu")
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}