diff --git a/comfy/ldm/lumina/model.py b/comfy/ldm/lumina/model.py index 123c191a9..0ff8539f7 100644 --- a/comfy/ldm/lumina/model.py +++ b/comfy/ldm/lumina/model.py @@ -628,7 +628,7 @@ class NextDiT(nn.Module): transformer_options["total_blocks"] = len(self.layers) transformer_options["block_type"] = "double" for i, layer in enumerate(self.layers): - transformer_options["block_index"] = torch.tensor(i, dtype=torch.uint8, device=img.device) + transformer_options["block_index"] = torch.tensor(i, dtype=torch.uint8, device="cpu") img = layer(img, mask, freqs_cis, adaln_input, transformer_options=transformer_options) if "double_block" in patches: for p in patches["double_block"]: