ComfyUI/comfy_extras/nodes_sage3.py
2026-03-23 21:07:49 +08:00

137 lines
4.5 KiB
Python

from typing import Callable
import torch
from typing_extensions import override
from comfy.ldm.modules.attention import get_attention_function
from comfy.model_patcher import ModelPatcher
from comfy_api.latest import ComfyExtension, io
from server import PromptServer
class Sage3PatchModel(io.ComfyNode):
    """Experimental node that patches a model so the `sage3` attention kernel
    runs for a configurable middle window of transformer blocks and sampling
    steps, falling back to the default attention everywhere else."""

    @classmethod
    def define_schema(cls):
        """Describe this node's inputs, outputs, and metadata for the graph UI."""
        return io.Schema(
            node_id="Sage3PatchModel",
            display_name="Patch SageAttention 3",
            description="Patch the model to use `attention3_sage` during the selected blocks and steps",
            category="_for_testing",
            inputs=[
                io.Model.Input("model"),
                io.Int.Input(
                    "skip_early_block",
                    tooltip="Use the default attention function for the first few Blocks",
                    default=1,
                    min=0,
                    max=99,
                    step=1,
                    display_mode=io.NumberDisplay.number,
                    optional=True,
                    advanced=True,
                ),
                io.Int.Input(
                    "skip_last_block",
                    tooltip="Use the default attention function for the last few Blocks",
                    default=1,
                    min=0,
                    max=99,
                    step=1,
                    optional=True,
                    advanced=True,
                ),
                io.Int.Input(
                    "skip_early_step",
                    tooltip="Use the default attention function for the first few Steps",
                    default=1,
                    min=0,
                    max=99,
                    step=1,
                    display_mode=io.NumberDisplay.number,
                    optional=True,
                    advanced=True,
                ),
                io.Int.Input(
                    "skip_last_step",
                    tooltip="Use the default attention function for the last few Steps",
                    default=1,
                    min=0,
                    max=99,
                    step=1,
                    optional=True,
                    advanced=True,
                ),
            ],
            outputs=[io.Model.Output()],
            hidden=[io.Hidden.unique_id],
            is_experimental=True,
        )

    @classmethod
    def execute(
        cls,
        model: ModelPatcher,
        skip_early_block: int = 1,
        skip_last_block: int = 1,
        skip_early_step: int = 1,
        skip_last_step: int = 1,
    ) -> io.NodeOutput:
        """Clone `model` and install an attention override that dispatches to
        `sage3` only inside the configured block/step window.

        If the `sage3` attention function is unavailable, a progress-text
        warning is sent to the UI (when a node id is known) and the model is
        returned unpatched.
        """
        sage3: Callable | None = get_attention_function("sage3", default=None)
        if sage3 is None:
            # sageattn3 is not installed: surface a hint in the UI and pass through.
            if cls.hidden.unique_id:
                PromptServer.instance.send_progress_text(
                    'To use the "Patch SageAttention 3" node, the `sageattn3` package must be installed first',
                    cls.hidden.unique_id,
                )
            return io.NodeOutput(model)

        def _sage3_if_in_window(func: Callable, *args, **kwargs):
            # Fall back to the default attention `func` unless both the current
            # block index and the current sampling step lie inside the window.
            opts: dict = kwargs.get("transformer_options", {})
            n_blocks: int = opts.get("total_blocks", -1)
            block_idx: int = opts.get("block_index", -1)  # [0, N)
            block_in_window = (
                n_blocks != -1
                and skip_early_block <= block_idx < n_blocks - skip_last_block
            )
            if not block_in_window:
                return func(*args, **kwargs)
            all_sigmas: torch.Tensor = opts.get("sample_sigmas", None)
            cur_sigma: torch.Tensor = opts.get("sigmas", None)
            if all_sigmas is None or cur_sigma is None:
                return func(*args, **kwargs)
            n_steps: int = all_sigmas.size(0) - 1
            # Recover the current step index by matching the active sigma
            # against the full sampling schedule.
            step: int = next(
                (i for i in range(n_steps) if torch.allclose(all_sigmas[i], cur_sigma)),
                -1,
            )  # [0, N)
            if step == -1 or not (
                skip_early_step <= step < n_steps - skip_last_step
            ):
                return func(*args, **kwargs)
            return sage3(*args, **kwargs)

        patched = model.clone()
        patched.model_options["transformer_options"][
            "optimized_attention_override"
        ] = _sage3_if_in_window
        return io.NodeOutput(patched)
class Sage3Extension(ComfyExtension):
    """Extension wrapper that registers the SageAttention 3 patch node."""

    @override
    async def get_node_list(self) -> list[type[io.ComfyNode]]:
        """Return every node class provided by this extension."""
        nodes: list[type[io.ComfyNode]] = [Sage3PatchModel]
        return nodes
async def comfy_entrypoint():
    """Module entry point: hand ComfyUI an instance of this extension."""
    return Sage3Extension()