From 846a019fc2d487c14158beacd0f0b353c83e5d75 Mon Sep 17 00:00:00 2001
From: Haoming
Date: Tue, 9 Dec 2025 11:44:12 +0800
Subject: [PATCH 1/2] tensor

---
 comfy/ldm/chroma/model.py        | 4 ++--
 comfy/ldm/flux/model.py          | 4 ++--
 comfy/ldm/hunyuan_video/model.py | 4 ++--
 comfy/ldm/kandinsky5/model.py    | 2 +-
 comfy/ldm/modules/attention.py   | 4 ++--
 comfy/ldm/qwen_image/model.py    | 2 +-
 6 files changed, 10 insertions(+), 10 deletions(-)

diff --git a/comfy/ldm/chroma/model.py b/comfy/ldm/chroma/model.py
index 2e8ef0687..05a82e63f 100644
--- a/comfy/ldm/chroma/model.py
+++ b/comfy/ldm/chroma/model.py
@@ -183,7 +183,7 @@ class Chroma(nn.Module):
         transformer_options["total_blocks"] = len(self.double_blocks)
         transformer_options["block_type"] = "double"
         for i, block in enumerate(self.double_blocks):
-            transformer_options["block_index"] = i
+            transformer_options["block_index"] = torch.tensor(i, dtype=torch.uint8, device=img.device)
             if i not in self.skip_mmdit:
                 double_mod = (
                     self.get_modulations(mod_vectors, "double_img", idx=i),
@@ -229,7 +229,7 @@ class Chroma(nn.Module):
         transformer_options["total_blocks"] = len(self.single_blocks)
         transformer_options["block_type"] = "single"
         for i, block in enumerate(self.single_blocks):
-            transformer_options["block_index"] = i
+            transformer_options["block_index"] = torch.tensor(i, dtype=torch.uint8, device=img.device)
             if i not in self.skip_dit:
                 single_mod = self.get_modulations(mod_vectors, "single", idx=i)
                 if ("single_block", i) in blocks_replace:
diff --git a/comfy/ldm/flux/model.py b/comfy/ldm/flux/model.py
index f40c2a7a9..e098c4f8f 100644
--- a/comfy/ldm/flux/model.py
+++ b/comfy/ldm/flux/model.py
@@ -186,7 +186,7 @@ class Flux(nn.Module):
         transformer_options["total_blocks"] = len(self.double_blocks)
         transformer_options["block_type"] = "double"
         for i, block in enumerate(self.double_blocks):
-            transformer_options["block_index"] = i
+            transformer_options["block_index"] = torch.tensor(i, dtype=torch.uint8, device=img.device)
             if ("double_block", i) in blocks_replace:
                 def block_wrap(args):
                     out = {}
@@ -233,7 +233,7 @@ class Flux(nn.Module):
         transformer_options["total_blocks"] = len(self.single_blocks)
         transformer_options["block_type"] = "single"
         for i, block in enumerate(self.single_blocks):
-            transformer_options["block_index"] = i
+            transformer_options["block_index"] = torch.tensor(i, dtype=torch.uint8, device=img.device)
             if ("single_block", i) in blocks_replace:
                 def block_wrap(args):
                     out = {}
diff --git a/comfy/ldm/hunyuan_video/model.py b/comfy/ldm/hunyuan_video/model.py
index 2749c53f5..1f6c07d6b 100644
--- a/comfy/ldm/hunyuan_video/model.py
+++ b/comfy/ldm/hunyuan_video/model.py
@@ -392,7 +392,7 @@ class HunyuanVideo(nn.Module):
         transformer_options["total_blocks"] = len(self.double_blocks)
         transformer_options["block_type"] = "double"
         for i, block in enumerate(self.double_blocks):
-            transformer_options["block_index"] = i
+            transformer_options["block_index"] = torch.tensor(i, dtype=torch.uint8, device=img.device)
             if ("double_block", i) in blocks_replace:
                 def block_wrap(args):
                     out = {}
@@ -417,7 +417,7 @@ class HunyuanVideo(nn.Module):
         transformer_options["total_blocks"] = len(self.single_blocks)
         transformer_options["block_type"] = "single"
         for i, block in enumerate(self.single_blocks):
-            transformer_options["block_index"] = i
+            transformer_options["block_index"] = torch.tensor(i, dtype=torch.uint8, device=img.device)
             if ("single_block", i) in blocks_replace:
                 def block_wrap(args):
                     out = {}
diff --git a/comfy/ldm/kandinsky5/model.py b/comfy/ldm/kandinsky5/model.py
index 1509de2f8..1feb0adc5 100644
--- a/comfy/ldm/kandinsky5/model.py
+++ b/comfy/ldm/kandinsky5/model.py
@@ -375,7 +375,7 @@ class Kandinsky5(nn.Module):
         transformer_options["total_blocks"] = len(self.visual_transformer_blocks)
         transformer_options["block_type"] = "double"
         for i, block in enumerate(self.visual_transformer_blocks):
-            transformer_options["block_index"] = i
+            transformer_options["block_index"] = torch.tensor(i, dtype=torch.uint8, device=visual_embed.device)
             if ("double_block", i) in blocks_replace:
                 def block_wrap(args):
                     return block(x=args["x"], context=args["context"], time_embed=args["time_embed"], freqs=args["freqs"], transformer_options=args.get("transformer_options"))
diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py
index a8800ded0..ea803ba5c 100644
--- a/comfy/ldm/modules/attention.py
+++ b/comfy/ldm/modules/attention.py
@@ -921,7 +921,7 @@ class SpatialTransformer(nn.Module):
         if self.use_linear:
             x = self.proj_in(x)
         for i, block in enumerate(self.transformer_blocks):
-            transformer_options["block_index"] = i
+            transformer_options["block_index"] = torch.tensor(i, dtype=torch.uint8, device=x.device)
             x = block(x, context=context[i], transformer_options=transformer_options)
         if self.use_linear:
             x = self.proj_out(x)
@@ -1067,7 +1067,7 @@ class SpatialVideoTransformer(SpatialTransformer):
         for it_, (block, mix_block) in enumerate(
             zip(self.transformer_blocks, self.time_stack)
         ):
-            transformer_options["block_index"] = it_
+            transformer_options["block_index"] = torch.tensor(it_, dtype=torch.uint8, device=x.device)
             x = block(
                 x,
                 context=spatial_context,
diff --git a/comfy/ldm/qwen_image/model.py b/comfy/ldm/qwen_image/model.py
index 8c75670cd..7d075481f 100644
--- a/comfy/ldm/qwen_image/model.py
+++ b/comfy/ldm/qwen_image/model.py
@@ -442,7 +442,7 @@ class QwenImageTransformer2DModel(nn.Module):
         transformer_options["total_blocks"] = len(self.transformer_blocks)
         transformer_options["block_type"] = "double"
         for i, block in enumerate(self.transformer_blocks):
-            transformer_options["block_index"] = i
+            transformer_options["block_index"] = torch.tensor(i, dtype=torch.uint8, device=hidden_states.device)
             if ("double_block", i) in blocks_replace:
                 def block_wrap(args):
                     out = {}

From 5884783f90727b3213e66a03e5468e71eb54b78e Mon Sep 17 00:00:00 2001
From: Haoming
Date: Thu, 11 Dec 2025 10:44:00 +0800
Subject: [PATCH 2/2] device

---
 comfy/ldm/chroma/model.py        | 4 ++--
 comfy/ldm/flux/model.py          | 4 ++--
 comfy/ldm/hunyuan_video/model.py | 4 ++--
 comfy/ldm/kandinsky5/model.py    | 2 +-
 comfy/ldm/modules/attention.py   | 4 ++--
 comfy/ldm/qwen_image/model.py    | 2 +-
 6 files changed, 10 insertions(+), 10 deletions(-)

diff --git a/comfy/ldm/chroma/model.py b/comfy/ldm/chroma/model.py
index 05a82e63f..3114b6a35 100644
--- a/comfy/ldm/chroma/model.py
+++ b/comfy/ldm/chroma/model.py
@@ -183,7 +183,7 @@ class Chroma(nn.Module):
         transformer_options["total_blocks"] = len(self.double_blocks)
         transformer_options["block_type"] = "double"
         for i, block in enumerate(self.double_blocks):
-            transformer_options["block_index"] = torch.tensor(i, dtype=torch.uint8, device=img.device)
+            transformer_options["block_index"] = torch.tensor(i, dtype=torch.uint8, device="cpu")
             if i not in self.skip_mmdit:
                 double_mod = (
                     self.get_modulations(mod_vectors, "double_img", idx=i),
@@ -229,7 +229,7 @@ class Chroma(nn.Module):
         transformer_options["total_blocks"] = len(self.single_blocks)
         transformer_options["block_type"] = "single"
         for i, block in enumerate(self.single_blocks):
-            transformer_options["block_index"] = torch.tensor(i, dtype=torch.uint8, device=img.device)
+            transformer_options["block_index"] = torch.tensor(i, dtype=torch.uint8, device="cpu")
             if i not in self.skip_dit:
                 single_mod = self.get_modulations(mod_vectors, "single", idx=i)
                 if ("single_block", i) in blocks_replace:
diff --git a/comfy/ldm/flux/model.py b/comfy/ldm/flux/model.py
index e098c4f8f..e4b1bf356 100644
--- a/comfy/ldm/flux/model.py
+++ b/comfy/ldm/flux/model.py
@@ -186,7 +186,7 @@ class Flux(nn.Module):
         transformer_options["total_blocks"] = len(self.double_blocks)
         transformer_options["block_type"] = "double"
         for i, block in enumerate(self.double_blocks):
-            transformer_options["block_index"] = torch.tensor(i, dtype=torch.uint8, device=img.device)
+            transformer_options["block_index"] = torch.tensor(i, dtype=torch.uint8, device="cpu")
             if ("double_block", i) in blocks_replace:
                 def block_wrap(args):
                     out = {}
@@ -233,7 +233,7 @@ class Flux(nn.Module):
         transformer_options["total_blocks"] = len(self.single_blocks)
         transformer_options["block_type"] = "single"
         for i, block in enumerate(self.single_blocks):
-            transformer_options["block_index"] = torch.tensor(i, dtype=torch.uint8, device=img.device)
+            transformer_options["block_index"] = torch.tensor(i, dtype=torch.uint8, device="cpu")
             if ("single_block", i) in blocks_replace:
                 def block_wrap(args):
                     out = {}
diff --git a/comfy/ldm/hunyuan_video/model.py b/comfy/ldm/hunyuan_video/model.py
index 1f6c07d6b..49f3320fb 100644
--- a/comfy/ldm/hunyuan_video/model.py
+++ b/comfy/ldm/hunyuan_video/model.py
@@ -392,7 +392,7 @@ class HunyuanVideo(nn.Module):
         transformer_options["total_blocks"] = len(self.double_blocks)
         transformer_options["block_type"] = "double"
         for i, block in enumerate(self.double_blocks):
-            transformer_options["block_index"] = torch.tensor(i, dtype=torch.uint8, device=img.device)
+            transformer_options["block_index"] = torch.tensor(i, dtype=torch.uint8, device="cpu")
             if ("double_block", i) in blocks_replace:
                 def block_wrap(args):
                     out = {}
@@ -417,7 +417,7 @@ class HunyuanVideo(nn.Module):
         transformer_options["total_blocks"] = len(self.single_blocks)
         transformer_options["block_type"] = "single"
         for i, block in enumerate(self.single_blocks):
-            transformer_options["block_index"] = torch.tensor(i, dtype=torch.uint8, device=img.device)
+            transformer_options["block_index"] = torch.tensor(i, dtype=torch.uint8, device="cpu")
             if ("single_block", i) in blocks_replace:
                 def block_wrap(args):
                     out = {}
diff --git a/comfy/ldm/kandinsky5/model.py b/comfy/ldm/kandinsky5/model.py
index 1feb0adc5..464c1be3b 100644
--- a/comfy/ldm/kandinsky5/model.py
+++ b/comfy/ldm/kandinsky5/model.py
@@ -375,7 +375,7 @@ class Kandinsky5(nn.Module):
         transformer_options["total_blocks"] = len(self.visual_transformer_blocks)
         transformer_options["block_type"] = "double"
         for i, block in enumerate(self.visual_transformer_blocks):
-            transformer_options["block_index"] = torch.tensor(i, dtype=torch.uint8, device=visual_embed.device)
+            transformer_options["block_index"] = torch.tensor(i, dtype=torch.uint8, device="cpu")
             if ("double_block", i) in blocks_replace:
                 def block_wrap(args):
                     return block(x=args["x"], context=args["context"], time_embed=args["time_embed"], freqs=args["freqs"], transformer_options=args.get("transformer_options"))
diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py
index ea803ba5c..dfa278ca5 100644
--- a/comfy/ldm/modules/attention.py
+++ b/comfy/ldm/modules/attention.py
@@ -921,7 +921,7 @@ class SpatialTransformer(nn.Module):
         if self.use_linear:
             x = self.proj_in(x)
         for i, block in enumerate(self.transformer_blocks):
-            transformer_options["block_index"] = torch.tensor(i, dtype=torch.uint8, device=x.device)
+            transformer_options["block_index"] = torch.tensor(i, dtype=torch.uint8, device="cpu")
             x = block(x, context=context[i], transformer_options=transformer_options)
         if self.use_linear:
             x = self.proj_out(x)
@@ -1067,7 +1067,7 @@ class SpatialVideoTransformer(SpatialTransformer):
         for it_, (block, mix_block) in enumerate(
             zip(self.transformer_blocks, self.time_stack)
         ):
-            transformer_options["block_index"] = torch.tensor(it_, dtype=torch.uint8, device=x.device)
+            transformer_options["block_index"] = torch.tensor(it_, dtype=torch.uint8, device="cpu")
             x = block(
                 x,
                 context=spatial_context,
diff --git a/comfy/ldm/qwen_image/model.py b/comfy/ldm/qwen_image/model.py
index 7d075481f..40b5e8e6a 100644
--- a/comfy/ldm/qwen_image/model.py
+++ b/comfy/ldm/qwen_image/model.py
@@ -442,7 +442,7 @@ class QwenImageTransformer2DModel(nn.Module):
         transformer_options["total_blocks"] = len(self.transformer_blocks)
         transformer_options["block_type"] = "double"
         for i, block in enumerate(self.transformer_blocks):
-            transformer_options["block_index"] = torch.tensor(i, dtype=torch.uint8, device=hidden_states.device)
+            transformer_options["block_index"] = torch.tensor(i, dtype=torch.uint8, device="cpu")
             if ("double_block", i) in blocks_replace:
                 def block_wrap(args):
                     out = {}
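
Reviewer note (an illustrative sketch, not part of the patch series): after PATCH 1/2,
transformer_options["block_index"] is a 0-dim torch.uint8 tensor rather than a Python
int, and PATCH 2/2 pins it to the CPU so the per-block torch.tensor(...) call never
touches the compute device. Comparisons such as `idx == 3` still work in an `if`
(a 0-dim bool tensor is truthy/falsy), but any consumer that hashes, slices, or
string-formats with the index needs an explicit conversion. The callback below is
hypothetical; the name my_block_patch and its signature are assumptions, not ComfyUI API:

    import torch

    def my_block_patch(img, txt, transformer_options):
        # Hypothetical consumer of the loop-populated transformer_options dict.
        # "block_index" is now a 0-dim torch.uint8 tensor on the CPU.
        idx = transformer_options["block_index"]
        if isinstance(idx, torch.Tensor):
            # .item() is cheap here: the tensor lives on the CPU, so no
            # device synchronization is triggered.
            idx = int(idx.item())
        # "block_type" and "total_blocks" remain plain Python values.
        if transformer_options.get("block_type") == "double" and idx == 0:
            pass  # e.g. act only on the first double block
        return img, txt

One caveat worth keeping in mind: torch.uint8 caps the representable index at 255.
None of the touched models come close to that many blocks, but it is a silent ceiling.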