diff --git a/comfy/ldm/lumina/model.py b/comfy/ldm/lumina/model.py index c47df49ca..123c191a9 100644 --- a/comfy/ldm/lumina/model.py +++ b/comfy/ldm/lumina/model.py @@ -625,7 +625,10 @@ class NextDiT(nn.Module): img, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(x, cap_feats, cap_mask, t, num_tokens, transformer_options=transformer_options) freqs_cis = freqs_cis.to(img.device) + transformer_options["total_blocks"] = len(self.layers) + transformer_options["block_type"] = "double" for i, layer in enumerate(self.layers): + transformer_options["block_index"] = torch.tensor(i, dtype=torch.uint8, device=img.device) img = layer(img, mask, freqs_cis, adaln_input, transformer_options=transformer_options) if "double_block" in patches: for p in patches["double_block"]: