Make tensor int again

This commit is contained in:
Jedrzej Kosinski 2025-12-15 17:30:55 -08:00
parent f332332495
commit afb6191099

View File

@ -628,7 +628,7 @@ class NextDiT(nn.Module):
transformer_options["total_blocks"] = len(self.layers)
transformer_options["block_type"] = "double"
for i, layer in enumerate(self.layers):
transformer_options["block_index"] = torch.tensor(i, dtype=torch.uint8, device="cpu")
transformer_options["block_index"] = i
img = layer(img, mask, freqs_cis, adaln_input, transformer_options=transformer_options)
if "double_block" in patches:
for p in patches["double_block"]: