diff --git a/comfy/ldm/chroma/model.py b/comfy/ldm/chroma/model.py
index 2e8ef0687..3114b6a35 100644
--- a/comfy/ldm/chroma/model.py
+++ b/comfy/ldm/chroma/model.py
@@ -183,7 +183,7 @@ class Chroma(nn.Module):
         transformer_options["total_blocks"] = len(self.double_blocks)
         transformer_options["block_type"] = "double"
         for i, block in enumerate(self.double_blocks):
-            transformer_options["block_index"] = i
+            transformer_options["block_index"] = torch.tensor(i, dtype=torch.uint8, device="cpu")
             if i not in self.skip_mmdit:
                 double_mod = (
                     self.get_modulations(mod_vectors, "double_img", idx=i),
@@ -229,7 +229,7 @@ class Chroma(nn.Module):
         transformer_options["total_blocks"] = len(self.single_blocks)
         transformer_options["block_type"] = "single"
         for i, block in enumerate(self.single_blocks):
-            transformer_options["block_index"] = i
+            transformer_options["block_index"] = torch.tensor(i, dtype=torch.uint8, device="cpu")
             if i not in self.skip_dit:
                 single_mod = self.get_modulations(mod_vectors, "single", idx=i)
                 if ("single_block", i) in blocks_replace:
diff --git a/comfy/ldm/flux/model.py b/comfy/ldm/flux/model.py
index f40c2a7a9..e4b1bf356 100644
--- a/comfy/ldm/flux/model.py
+++ b/comfy/ldm/flux/model.py
@@ -186,7 +186,7 @@ class Flux(nn.Module):
         transformer_options["total_blocks"] = len(self.double_blocks)
         transformer_options["block_type"] = "double"
         for i, block in enumerate(self.double_blocks):
-            transformer_options["block_index"] = i
+            transformer_options["block_index"] = torch.tensor(i, dtype=torch.uint8, device="cpu")
             if ("double_block", i) in blocks_replace:
                 def block_wrap(args):
                     out = {}
@@ -233,7 +233,7 @@ class Flux(nn.Module):
         transformer_options["total_blocks"] = len(self.single_blocks)
         transformer_options["block_type"] = "single"
         for i, block in enumerate(self.single_blocks):
-            transformer_options["block_index"] = i
+            transformer_options["block_index"] = torch.tensor(i, dtype=torch.uint8, device="cpu")
             if ("single_block", i) in blocks_replace:
                 def block_wrap(args):
                     out = {}
diff --git a/comfy/ldm/hunyuan_video/model.py b/comfy/ldm/hunyuan_video/model.py
index 55ab550f8..2509e5134 100644
--- a/comfy/ldm/hunyuan_video/model.py
+++ b/comfy/ldm/hunyuan_video/model.py
@@ -393,7 +393,7 @@ class HunyuanVideo(nn.Module):
         transformer_options["total_blocks"] = len(self.double_blocks)
         transformer_options["block_type"] = "double"
         for i, block in enumerate(self.double_blocks):
-            transformer_options["block_index"] = i
+            transformer_options["block_index"] = torch.tensor(i, dtype=torch.uint8, device="cpu")
             if ("double_block", i) in blocks_replace:
                 def block_wrap(args):
                     out = {}
@@ -418,7 +418,7 @@ class HunyuanVideo(nn.Module):
         transformer_options["total_blocks"] = len(self.single_blocks)
         transformer_options["block_type"] = "single"
         for i, block in enumerate(self.single_blocks):
-            transformer_options["block_index"] = i
+            transformer_options["block_index"] = torch.tensor(i, dtype=torch.uint8, device="cpu")
             if ("single_block", i) in blocks_replace:
                 def block_wrap(args):
                     out = {}
diff --git a/comfy/ldm/kandinsky5/model.py b/comfy/ldm/kandinsky5/model.py
index 1509de2f8..464c1be3b 100644
--- a/comfy/ldm/kandinsky5/model.py
+++ b/comfy/ldm/kandinsky5/model.py
@@ -375,7 +375,7 @@ class Kandinsky5(nn.Module):
         transformer_options["total_blocks"] = len(self.visual_transformer_blocks)
         transformer_options["block_type"] = "double"
         for i, block in enumerate(self.visual_transformer_blocks):
-            transformer_options["block_index"] = i
+            transformer_options["block_index"] = torch.tensor(i, dtype=torch.uint8, device="cpu")
             if ("double_block", i) in blocks_replace:
                 def block_wrap(args):
                     return block(x=args["x"], context=args["context"], time_embed=args["time_embed"], freqs=args["freqs"], transformer_options=args.get("transformer_options"))
diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py
index a8800ded0..dfa278ca5 100644
--- a/comfy/ldm/modules/attention.py
+++ b/comfy/ldm/modules/attention.py
@@ -921,7 +921,7 @@ class SpatialTransformer(nn.Module):
         if self.use_linear:
             x = self.proj_in(x)
         for i, block in enumerate(self.transformer_blocks):
-            transformer_options["block_index"] = i
+            transformer_options["block_index"] = torch.tensor(i, dtype=torch.uint8, device="cpu")
             x = block(x, context=context[i], transformer_options=transformer_options)
         if self.use_linear:
             x = self.proj_out(x)
@@ -1067,7 +1067,7 @@ class SpatialVideoTransformer(SpatialTransformer):
         for it_, (block, mix_block) in enumerate(
             zip(self.transformer_blocks, self.time_stack)
         ):
-            transformer_options["block_index"] = it_
+            transformer_options["block_index"] = torch.tensor(it_, dtype=torch.uint8, device="cpu")
             x = block(
                 x,
                 context=spatial_context,
diff --git a/comfy/ldm/qwen_image/model.py b/comfy/ldm/qwen_image/model.py
index 8c75670cd..40b5e8e6a 100644
--- a/comfy/ldm/qwen_image/model.py
+++ b/comfy/ldm/qwen_image/model.py
@@ -442,7 +442,7 @@ class QwenImageTransformer2DModel(nn.Module):
         transformer_options["total_blocks"] = len(self.transformer_blocks)
         transformer_options["block_type"] = "double"
         for i, block in enumerate(self.transformer_blocks):
-            transformer_options["block_index"] = i
+            transformer_options["block_index"] = torch.tensor(i, dtype=torch.uint8, device="cpu")
             if ("double_block", i) in blocks_replace:
                 def block_wrap(args):
                     out = {}