This commit is contained in:
Haoming 2025-12-11 10:40:11 +08:00
parent a1c7584122
commit f332332495

View File

@ -628,7 +628,7 @@ class NextDiT(nn.Module):
transformer_options["total_blocks"] = len(self.layers)
transformer_options["block_type"] = "double"
for i, layer in enumerate(self.layers):
transformer_options["block_index"] = torch.tensor(i, dtype=torch.uint8, device=img.device)
transformer_options["block_index"] = torch.tensor(i, dtype=torch.uint8, device="cpu")
img = layer(img, mask, freqs_cis, adaln_input, transformer_options=transformer_options)
if "double_block" in patches:
for p in patches["double_block"]: