From 265b4f0fa1eba7e00950df4e4acb8ce312584aed Mon Sep 17 00:00:00 2001 From: Haoming Date: Wed, 31 Dec 2025 15:30:10 +0800 Subject: [PATCH 1/5] init --- comfy_extras/nodes_sage3.py | 78 +++++++++++++++++++++++++++++++++++++ nodes.py | 1 + 2 files changed, 79 insertions(+) create mode 100644 comfy_extras/nodes_sage3.py diff --git a/comfy_extras/nodes_sage3.py b/comfy_extras/nodes_sage3.py new file mode 100644 index 000000000..7fe5c114c --- /dev/null +++ b/comfy_extras/nodes_sage3.py @@ -0,0 +1,78 @@ +from typing import Callable + +import torch +from typing_extensions import override + +from comfy.ldm.modules.attention import get_attention_function +from comfy.model_patcher import ModelPatcher +from comfy_api.latest import ComfyExtension, io +from server import PromptServer + + +class Sage3PatchModel(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="Sage3PatchModel", + display_name="Patch SageAttention 3", + description="Apply `attention3_sage` to the middle blocks and steps, while using optimized_attention for the first/last blocks and steps", + category="_for_testing", + inputs=[ + io.Model.Input("model"), + ], + outputs=[io.Model.Output()], + is_experimental=True, + ) + + @classmethod + def execute(cls, model: ModelPatcher) -> io.NodeOutput: + sage3: Callable | None = get_attention_function("sage3", default=None) + + if sage3 is None: + PromptServer.instance.send_progress_text( + "`sageattn3` is not installed / available...", + cls.hidden.unique_id, + ) + return io.NodeOutput(model) + + def attention_override(func: Callable, *args, **kwargs): + transformer_options: dict = kwargs.get("transformer_options", {}) + + block_index: int = transformer_options.get("block_index", 0) + total_blocks: int = transformer_options.get("total_blocks", 1) + + if block_index == 0 or block_index >= (total_blocks - 1): + return func(*args, **kwargs) + + sample_sigmas: torch.Tensor = transformer_options["sample_sigmas"] + sigmas: torch.Tensor = transformer_options["sigmas"] + + total_steps: int = sample_sigmas.size(0) + step: int = 0 + + for i in range(total_steps): + if torch.allclose(sample_sigmas[i], sigmas): + step = i + break + + if step == 0 or step >= (total_steps - 1): + return func(*args, **kwargs) + + return sage3(*args, **kwargs) + + model = model.clone() + model.model_options["transformer_options"][ + "optimized_attention_override" + ] = attention_override + + return io.NodeOutput(model) + + +class Sage3Extension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [Sage3PatchModel] + + +async def comfy_entrypoint(): + return Sage3Extension() diff --git a/nodes.py b/nodes.py index d9e4ebd91..f320c247b 100644 --- a/nodes.py +++ b/nodes.py @@ -2360,6 +2360,7 @@ async def init_builtin_extra_nodes(): "nodes_nop.py", "nodes_kandinsky5.py", "nodes_wanmove.py", + "nodes_sage3.py", ] import_failed = [] From 43e950985632cf7725f4c8f68b56b0ea1aecc918 Mon Sep 17 00:00:00 2001 From: Haoming Date: Tue, 6 Jan 2026 13:54:27 +0800 Subject: [PATCH 2/5] desc --- comfy_extras/nodes_sage3.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy_extras/nodes_sage3.py b/comfy_extras/nodes_sage3.py index 7fe5c114c..50d4d810e 100644 --- a/comfy_extras/nodes_sage3.py +++ b/comfy_extras/nodes_sage3.py @@ -15,7 +15,7 @@ class Sage3PatchModel(io.ComfyNode): return io.Schema( node_id="Sage3PatchModel", display_name="Patch SageAttention 3", - description="Apply `attention3_sage` to the middle blocks and steps, while using optimized_attention for the first/last blocks and steps", + description="Patch the model to use `attention3_sage` during the middle blocks and steps, keeping the default attention function for the first/last blocks and steps", category="_for_testing", inputs=[ io.Model.Input("model"), From 68deb84b7a48406511430ef509a285b56c6a0c43 Mon Sep 17 00:00:00 2001 From: Haoming Date: Sun, 15 Feb 2026 12:37:18 +0800 Subject: [PATCH 3/5] text --- comfy_extras/nodes_sage3.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/comfy_extras/nodes_sage3.py b/comfy_extras/nodes_sage3.py index 50d4d810e..8954631c7 100644 --- a/comfy_extras/nodes_sage3.py +++ b/comfy_extras/nodes_sage3.py @@ -21,6 +21,7 @@ class Sage3PatchModel(io.ComfyNode): io.Model.Input("model"), ], outputs=[io.Model.Output()], + hidden=[io.Hidden.unique_id], is_experimental=True, ) @@ -29,10 +30,11 @@ class Sage3PatchModel(io.ComfyNode): sage3: Callable | None = get_attention_function("sage3", default=None) if sage3 is None: - PromptServer.instance.send_progress_text( - "`sageattn3` is not installed / available...", - cls.hidden.unique_id, - ) + if cls.hidden.unique_id: + PromptServer.instance.send_progress_text( + "`sageattn3` is not installed / available...", + cls.hidden.unique_id, + ) return io.NodeOutput(model) def attention_override(func: Callable, *args, **kwargs): From 19f59e5f01fd59f06971ba7b36a364dfaf167112 Mon Sep 17 00:00:00 2001 From: Haoming Date: Mon, 23 Mar 2026 21:07:49 +0800 Subject: [PATCH 4/5] add optional skip controls --- comfy_extras/nodes_sage3.py | 78 +++++++++++++++++++++++++++++++------ 1 file changed, 67 insertions(+), 11 deletions(-) diff --git a/comfy_extras/nodes_sage3.py b/comfy_extras/nodes_sage3.py index 8954631c7..8310f1d9a 100644 --- a/comfy_extras/nodes_sage3.py +++ b/comfy_extras/nodes_sage3.py @@ -15,10 +15,52 @@ class Sage3PatchModel(io.ComfyNode): return io.Schema( node_id="Sage3PatchModel", display_name="Patch SageAttention 3", - description="Patch the model to use `attention3_sage` during the middle blocks and steps, keeping the default attention function for the first/last blocks and steps", + description="Patch the model to use `attention3_sage` during the selected blocks and steps", category="_for_testing", inputs=[ io.Model.Input("model"), + io.Int.Input( + "skip_early_block", + tooltip="Use the default attention function for the first few Blocks", + default=1, + min=0, + max=99, + step=1, + display_mode=io.NumberDisplay.number, + optional=True, + advanced=True, + ), + io.Int.Input( + "skip_last_block", + tooltip="Use the default attention function for the last few Blocks", + default=1, + min=0, + max=99, + step=1, + optional=True, + advanced=True, + ), + io.Int.Input( + "skip_early_step", + tooltip="Use the default attention function for the first few Steps", + default=1, + min=0, + max=99, + step=1, + display_mode=io.NumberDisplay.number, + optional=True, + advanced=True, + ), + io.Int.Input( + "skip_last_step", + tooltip="Use the default attention function for the last few Steps", + default=1, + min=0, + max=99, + step=1, + optional=True, + advanced=True, + ), ], outputs=[io.Model.Output()], hidden=[io.Hidden.unique_id], @@ -26,13 +68,20 @@ class Sage3PatchModel(io.ComfyNode): ) @classmethod - def execute(cls, model: ModelPatcher) -> io.NodeOutput: + def execute( + cls, + model: ModelPatcher, + skip_early_block: int = 1, + skip_last_block: int = 1, + skip_early_step: int = 1, + skip_last_step: int = 1, + ) -> io.NodeOutput: sage3: Callable | None = get_attention_function("sage3", default=None) if sage3 is None: if cls.hidden.unique_id: PromptServer.instance.send_progress_text( - "`sageattn3` is not installed / available...", + 'To use the "Patch SageAttention 3" node, the `sageattn3` package must be installed first', cls.hidden.unique_id, ) return io.NodeOutput(model) @@ -40,24 +89,31 @@ class Sage3PatchModel(io.ComfyNode): def attention_override(func: Callable, *args, **kwargs): transformer_options: dict = kwargs.get("transformer_options", {}) - block_index: int = transformer_options.get("block_index", 0) - total_blocks: int = transformer_options.get("total_blocks", 1) + total_blocks: int = transformer_options.get("total_blocks", -1) + block_index: int = transformer_options.get("block_index", -1) # [0, N) - if block_index == 0 or block_index >= (total_blocks - 1): + if total_blocks == -1 or not ( + skip_early_block <= block_index < total_blocks - skip_last_block + ): return func(*args, **kwargs) - sample_sigmas: torch.Tensor = transformer_options["sample_sigmas"] - sigmas: torch.Tensor = transformer_options["sigmas"] + sample_sigmas: torch.Tensor = transformer_options.get("sample_sigmas", None) + sigmas: torch.Tensor = transformer_options.get("sigmas", None) - total_steps: int = sample_sigmas.size(0) - step: int = 0 + if sample_sigmas is None or sigmas is None: + return func(*args, **kwargs) + + total_steps: int = sample_sigmas.size(0) - 1 + step: int = -1 # [0, N) for i in range(total_steps): if torch.allclose(sample_sigmas[i], sigmas): step = i break - if step == 0 or step >= (total_steps - 1): + if step == -1 or not ( + skip_early_step <= step < total_steps - skip_last_step + ): return func(*args, **kwargs) return sage3(*args, **kwargs) From 6f70a2cb99b4fe819ab1e9cedcd9e4234346b91a Mon Sep 17 00:00:00 2001 From: Haoming Date: Fri, 27 Mar 2026 10:54:24 +0800 Subject: [PATCH 5/5] optimize --- comfy_extras/nodes_sage3.py | 46 ++++++++++++++++++++++--------------- 1 file changed, 28 insertions(+), 18 deletions(-) diff --git a/comfy_extras/nodes_sage3.py b/comfy_extras/nodes_sage3.py index 8310f1d9a..a9523e92b 100644 --- a/comfy_extras/nodes_sage3.py +++ b/comfy_extras/nodes_sage3.py @@ -86,39 +86,49 @@ class Sage3PatchModel(io.ComfyNode): ) return io.NodeOutput(model) - def attention_override(func: Callable, *args, **kwargs): - transformer_options: dict = kwargs.get("transformer_options", {}) + def sage_wrapper(model_function, kwargs: dict): + # parse the current step on every model call instead of every attention call - total_blocks: int = transformer_options.get("total_blocks", -1) - block_index: int = transformer_options.get("block_index", -1) # [0, N) + x, timestep, c = kwargs["input"], kwargs["timestep"], kwargs["c"] - if total_blocks == -1 or not ( - skip_early_block <= block_index < total_blocks - skip_last_block - ): - return func(*args, **kwargs) + transformer_options: dict = c.get("transformer_options", {}) sample_sigmas: torch.Tensor = transformer_options.get("sample_sigmas", None) sigmas: torch.Tensor = transformer_options.get("sigmas", None) if sample_sigmas is None or sigmas is None: - return func(*args, **kwargs) + transformer_options["_sage3"] = False + return model_function(x, timestep, **c) - total_steps: int = sample_sigmas.size(0) - 1 - step: int = -1 # [0, N) + mask: torch.Tensor = (sample_sigmas == sigmas).nonzero(as_tuple=True)[0] - for i in range(total_steps): - if torch.allclose(sample_sigmas[i], sigmas): - step = i - break + total_steps: int = sample_sigmas.size(0) + step: int = mask.item() if mask.numel() > 0 else -1 # [0, N) - if step == -1 or not ( + transformer_options["_sage3"] = step > -1 and ( skip_early_step <= step < total_steps - skip_last_step - ): + ) + + return model_function(x, timestep, **c) + + def attention_override(func: Callable, *args, **kwargs): + transformer_options: dict = kwargs.get("transformer_options", {}) + + if not transformer_options.get("_sage3", False): return func(*args, **kwargs) - return sage3(*args, **kwargs) + total_blocks: int = transformer_options.get("total_blocks", -1) + block_index: int = transformer_options.get("block_index", -1) # [0, N) + + if total_blocks > -1 and ( + skip_early_block <= block_index < total_blocks - skip_last_block + ): + return sage3(*args, **kwargs) + else: + return func(*args, **kwargs) model = model.clone() + model.set_model_unet_function_wrapper(sage_wrapper) model.model_options["transformer_options"][ "optimized_attention_override" ] = attention_override