from typing import Callable

import torch
from typing_extensions import override

from comfy.ldm.modules.attention import get_attention_function
from comfy.model_patcher import ModelPatcher
from comfy_api.latest import ComfyExtension, io
from server import PromptServer


class Sage3PatchModel(io.ComfyNode):
    """Node that patches a model to use SageAttention 3 for attention.

    The override is only active inside a configurable window of transformer
    blocks and sampling steps; outside that window (or whenever the required
    transformer_options metadata is missing) the model's default attention
    function is used unchanged.
    """

    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="Sage3PatchModel",
            display_name="Patch SageAttention 3",
            description="Patch the model to use `attention3_sage` during the selected blocks and steps",
            category="_for_testing",
            inputs=[
                io.Model.Input("model"),
                io.Int.Input(
                    "skip_early_block",
                    tooltip="Use the default attention function for the first few Blocks",
                    default=1,
                    min=0,
                    max=99,
                    step=1,
                    display_mode=io.NumberDisplay.number,
                    optional=True,
                    advanced=True,
                ),
                io.Int.Input(
                    "skip_last_block",
                    tooltip="Use the default attention function for the last few Blocks",
                    default=1,
                    min=0,
                    max=99,
                    step=1,
                    # Explicit display_mode for consistency with skip_early_block.
                    display_mode=io.NumberDisplay.number,
                    optional=True,
                    advanced=True,
                ),
                io.Int.Input(
                    "skip_early_step",
                    tooltip="Use the default attention function for the first few Steps",
                    default=1,
                    min=0,
                    max=99,
                    step=1,
                    display_mode=io.NumberDisplay.number,
                    optional=True,
                    advanced=True,
                ),
                io.Int.Input(
                    "skip_last_step",
                    tooltip="Use the default attention function for the last few Steps",
                    default=1,
                    min=0,
                    max=99,
                    step=1,
                    # Explicit display_mode for consistency with skip_early_step.
                    display_mode=io.NumberDisplay.number,
                    optional=True,
                    advanced=True,
                ),
            ],
            outputs=[io.Model.Output()],
            hidden=[io.Hidden.unique_id],
            is_experimental=True,
        )

    @classmethod
    def execute(
        cls,
        model: ModelPatcher,
        skip_early_block: int = 1,
        skip_last_block: int = 1,
        skip_early_step: int = 1,
        skip_last_step: int = 1,
    ) -> io.NodeOutput:
        """Clone *model* and install a SageAttention-3 attention override.

        Args:
            model: The model to patch (cloned; the input is not mutated).
            skip_early_block / skip_last_block: Number of leading / trailing
                transformer blocks that keep the default attention.
            skip_early_step / skip_last_step: Number of leading / trailing
                sampling steps that keep the default attention.

        Returns:
            NodeOutput wrapping the patched clone, or the unpatched input
            model when the `sage3` attention function is not available.
        """
        sage3: Callable | None = get_attention_function("sage3", default=None)
        if sage3 is None:
            # sageattn3 is not installed: tell the user via the progress text
            # channel (when we know our node id) and pass the model through.
            if cls.hidden.unique_id:
                PromptServer.instance.send_progress_text(
                    'To use the "Patch SageAttention 3" node, the `sageattn3` package must be installed first',
                    cls.hidden.unique_id,
                )
            return io.NodeOutput(model)

        def attention_override(func: Callable, *args, **kwargs):
            # `func` is the attention implementation the model would have used;
            # every early-exit path below delegates to it unchanged.
            transformer_options: dict = kwargs.get("transformer_options", {})

            # Gate on block position: only blocks inside
            # [skip_early_block, total_blocks - skip_last_block) use sage3.
            total_blocks: int = transformer_options.get("total_blocks", -1)
            block_index: int = transformer_options.get("block_index", -1)  # [0, N)
            if total_blocks == -1 or not (
                skip_early_block <= block_index < total_blocks - skip_last_block
            ):
                return func(*args, **kwargs)

            sample_sigmas: torch.Tensor = transformer_options.get("sample_sigmas", None)
            sigmas: torch.Tensor = transformer_options.get("sigmas", None)
            if sample_sigmas is None or sigmas is None:
                return func(*args, **kwargs)

            # Recover the current step index by matching the active sigma
            # against the sampling schedule (sample_sigmas has one more entry
            # than there are steps).
            total_steps: int = sample_sigmas.size(0) - 1
            step: int = next(
                (
                    i
                    for i in range(total_steps)
                    if torch.allclose(sample_sigmas[i], sigmas)
                ),
                -1,
            )  # [0, N), or -1 when no schedule entry matches

            # Gate on step position: only steps inside
            # [skip_early_step, total_steps - skip_last_step) use sage3.
            if step == -1 or not (
                skip_early_step <= step < total_steps - skip_last_step
            ):
                return func(*args, **kwargs)

            return sage3(*args, **kwargs)

        # Patch a clone so the caller's model object is left untouched.
        model = model.clone()
        model.model_options["transformer_options"][
            "optimized_attention_override"
        ] = attention_override
        return io.NodeOutput(model)


class Sage3Extension(ComfyExtension):
    """Extension that registers the SageAttention 3 patch node."""

    @override
    async def get_node_list(self) -> list[type[io.ComfyNode]]:
        return [Sage3PatchModel]


async def comfy_entrypoint():
    """ComfyUI entry point: return the extension instance."""
    return Sage3Extension()