block info

This commit is contained in:
Haoming 2025-12-10 15:14:01 +08:00
parent f668c2e3c9
commit a1c7584122

View File

@ -625,7 +625,10 @@ class NextDiT(nn.Module):
img, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(x, cap_feats, cap_mask, t, num_tokens, transformer_options=transformer_options)
freqs_cis = freqs_cis.to(img.device)
transformer_options["total_blocks"] = len(self.layers)
transformer_options["block_type"] = "double"
for i, layer in enumerate(self.layers):
transformer_options["block_index"] = torch.tensor(i, dtype=torch.uint8, device=img.device)
img = layer(img, mask, freqs_cis, adaln_input, transformer_options=transformer_options)
if "double_block" in patches:
for p in patches["double_block"]: