diff --git a/comfy_extras/nodes_sage3.py b/comfy_extras/nodes_sage3.py index 8954631c7..8310f1d9a 100644 --- a/comfy_extras/nodes_sage3.py +++ b/comfy_extras/nodes_sage3.py @@ -15,10 +15,52 @@ class Sage3PatchModel(io.ComfyNode): return io.Schema( node_id="Sage3PatchModel", display_name="Patch SageAttention 3", - description="Patch the model to use `attention3_sage` during the middle blocks and steps, keeping the default attention function for the first/last blocks and steps", + description="Patch the model to use `attention3_sage` during the selected blocks and steps", category="_for_testing", inputs=[ io.Model.Input("model"), + io.Int.Input( + "skip_early_block", + tooltip="Use the default attention function for the first few Blocks", + default=1, + min=0, + max=99, + step=1, + display_mode=io.NumberDisplay.number, + optional=True, + advanced=True, + ), + io.Int.Input( + "skip_last_block", + tooltip="Use the default attention function for the last few Blocks", + default=1, + min=0, + max=99, + step=1, + optional=True, + advanced=True, + ), + io.Int.Input( + "skip_early_step", + tooltip="Use the default attention function for the first few Steps", + default=1, + min=0, + max=99, + step=1, + display_mode=io.NumberDisplay.number, + optional=True, + advanced=True, + ), + io.Int.Input( + "skip_last_step", + tooltip="Use the default attention function for the last few Steps", + default=1, + min=0, + max=99, + step=1, + optional=True, + advanced=True, + ), ], outputs=[io.Model.Output()], hidden=[io.Hidden.unique_id], @@ -26,13 +68,20 @@ class Sage3PatchModel(io.ComfyNode): ) @classmethod - def execute(cls, model: ModelPatcher) -> io.NodeOutput: + def execute( + cls, + model: ModelPatcher, + skip_early_block: int = 1, + skip_last_block: int = 1, + skip_early_step: int = 1, + skip_last_step: int = 1, + ) -> io.NodeOutput: sage3: Callable | None = get_attention_function("sage3", default=None) if sage3 is None: if cls.hidden.unique_id: PromptServer.instance.send_progress_text( - "`sageattn3` is not installed / available...", + 'To use the "Patch SageAttention 3" node, the `sageattn3` package must be installed first', cls.hidden.unique_id, ) return io.NodeOutput(model) @@ -40,24 +89,31 @@ class Sage3PatchModel(io.ComfyNode): def attention_override(func: Callable, *args, **kwargs): transformer_options: dict = kwargs.get("transformer_options", {}) - block_index: int = transformer_options.get("block_index", 0) - total_blocks: int = transformer_options.get("total_blocks", 1) + total_blocks: int = transformer_options.get("total_blocks", -1) + block_index: int = transformer_options.get("block_index", -1) # [0, N) - if block_index == 0 or block_index >= (total_blocks - 1): + if total_blocks == -1 or not ( + skip_early_block <= block_index < total_blocks - skip_last_block + ): return func(*args, **kwargs) - sample_sigmas: torch.Tensor = transformer_options["sample_sigmas"] - sigmas: torch.Tensor = transformer_options["sigmas"] + sample_sigmas: torch.Tensor = transformer_options.get("sample_sigmas", None) + sigmas: torch.Tensor = transformer_options.get("sigmas", None) - total_steps: int = sample_sigmas.size(0) - step: int = 0 + if sample_sigmas is None or sigmas is None: + return func(*args, **kwargs) + + total_steps: int = sample_sigmas.size(0) - 1 + step: int = -1 # [0, N) for i in range(total_steps): if torch.allclose(sample_sigmas[i], sigmas): step = i break - if step == 0 or step >= (total_steps - 1): + if step == -1 or not ( + skip_early_step <= step < total_steps - skip_last_step + ): return func(*args, **kwargs) return sage3(*args, **kwargs)