This commit is contained in:
Haoming 2025-12-09 11:44:12 +08:00
parent 3b0368aa34
commit 846a019fc2
6 changed files with 10 additions and 10 deletions

View File

@ -183,7 +183,7 @@ class Chroma(nn.Module):
transformer_options["total_blocks"] = len(self.double_blocks)
transformer_options["block_type"] = "double"
for i, block in enumerate(self.double_blocks):
transformer_options["block_index"] = i
transformer_options["block_index"] = torch.tensor(i, dtype=torch.uint8, device=img.device)
if i not in self.skip_mmdit:
double_mod = (
self.get_modulations(mod_vectors, "double_img", idx=i),
@ -229,7 +229,7 @@ class Chroma(nn.Module):
transformer_options["total_blocks"] = len(self.single_blocks)
transformer_options["block_type"] = "single"
for i, block in enumerate(self.single_blocks):
transformer_options["block_index"] = i
transformer_options["block_index"] = torch.tensor(i, dtype=torch.uint8, device=img.device)
if i not in self.skip_dit:
single_mod = self.get_modulations(mod_vectors, "single", idx=i)
if ("single_block", i) in blocks_replace:

View File

@ -186,7 +186,7 @@ class Flux(nn.Module):
transformer_options["total_blocks"] = len(self.double_blocks)
transformer_options["block_type"] = "double"
for i, block in enumerate(self.double_blocks):
transformer_options["block_index"] = i
transformer_options["block_index"] = torch.tensor(i, dtype=torch.uint8, device=img.device)
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
@ -233,7 +233,7 @@ class Flux(nn.Module):
transformer_options["total_blocks"] = len(self.single_blocks)
transformer_options["block_type"] = "single"
for i, block in enumerate(self.single_blocks):
transformer_options["block_index"] = i
transformer_options["block_index"] = torch.tensor(i, dtype=torch.uint8, device=img.device)
if ("single_block", i) in blocks_replace:
def block_wrap(args):
out = {}

View File

@ -392,7 +392,7 @@ class HunyuanVideo(nn.Module):
transformer_options["total_blocks"] = len(self.double_blocks)
transformer_options["block_type"] = "double"
for i, block in enumerate(self.double_blocks):
transformer_options["block_index"] = i
transformer_options["block_index"] = torch.tensor(i, dtype=torch.uint8, device=img.device)
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
@ -417,7 +417,7 @@ class HunyuanVideo(nn.Module):
transformer_options["total_blocks"] = len(self.single_blocks)
transformer_options["block_type"] = "single"
for i, block in enumerate(self.single_blocks):
transformer_options["block_index"] = i
transformer_options["block_index"] = torch.tensor(i, dtype=torch.uint8, device=img.device)
if ("single_block", i) in blocks_replace:
def block_wrap(args):
out = {}

View File

@ -375,7 +375,7 @@ class Kandinsky5(nn.Module):
transformer_options["total_blocks"] = len(self.visual_transformer_blocks)
transformer_options["block_type"] = "double"
for i, block in enumerate(self.visual_transformer_blocks):
transformer_options["block_index"] = i
transformer_options["block_index"] = torch.tensor(i, dtype=torch.uint8, device=visual_embed.device)
if ("double_block", i) in blocks_replace:
def block_wrap(args):
return block(x=args["x"], context=args["context"], time_embed=args["time_embed"], freqs=args["freqs"], transformer_options=args.get("transformer_options"))

View File

@ -921,7 +921,7 @@ class SpatialTransformer(nn.Module):
if self.use_linear:
x = self.proj_in(x)
for i, block in enumerate(self.transformer_blocks):
transformer_options["block_index"] = i
transformer_options["block_index"] = torch.tensor(i, dtype=torch.uint8, device=x.device)
x = block(x, context=context[i], transformer_options=transformer_options)
if self.use_linear:
x = self.proj_out(x)
@ -1067,7 +1067,7 @@ class SpatialVideoTransformer(SpatialTransformer):
for it_, (block, mix_block) in enumerate(
zip(self.transformer_blocks, self.time_stack)
):
transformer_options["block_index"] = it_
transformer_options["block_index"] = torch.tensor(it_, dtype=torch.uint8, device=x.device)
x = block(
x,
context=spatial_context,

View File

@ -442,7 +442,7 @@ class QwenImageTransformer2DModel(nn.Module):
transformer_options["total_blocks"] = len(self.transformer_blocks)
transformer_options["block_type"] = "double"
for i, block in enumerate(self.transformer_blocks):
transformer_options["block_index"] = i
transformer_options["block_index"] = torch.tensor(i, dtype=torch.uint8, device=hidden_states.device)
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}