mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-14 12:32:31 +08:00
add optional skip controls
This commit is contained in:
parent
471cc76c77
commit
19f59e5f01
@ -15,10 +15,52 @@ class Sage3PatchModel(io.ComfyNode):
|
|||||||
return io.Schema(
|
return io.Schema(
|
||||||
node_id="Sage3PatchModel",
|
node_id="Sage3PatchModel",
|
||||||
display_name="Patch SageAttention 3",
|
display_name="Patch SageAttention 3",
|
||||||
description="Patch the model to use `attention3_sage` during the middle blocks and steps, keeping the default attention function for the first/last blocks and steps",
|
description="Patch the model to use `attention3_sage` during the selected blocks and steps",
|
||||||
category="_for_testing",
|
category="_for_testing",
|
||||||
inputs=[
|
inputs=[
|
||||||
io.Model.Input("model"),
|
io.Model.Input("model"),
|
||||||
|
io.Int.Input(
|
||||||
|
"skip_early_block",
|
||||||
|
tooltip="Use the default attention function for the first few Blocks",
|
||||||
|
default=1,
|
||||||
|
min=0,
|
||||||
|
max=99,
|
||||||
|
step=1,
|
||||||
|
display_mode=io.NumberDisplay.number,
|
||||||
|
optional=True,
|
||||||
|
advanced=True,
|
||||||
|
),
|
||||||
|
io.Int.Input(
|
||||||
|
"skip_last_block",
|
||||||
|
tooltip="Use the default attention function for the last few Blocks",
|
||||||
|
default=1,
|
||||||
|
min=0,
|
||||||
|
max=99,
|
||||||
|
step=1,
|
||||||
|
optional=True,
|
||||||
|
advanced=True,
|
||||||
|
),
|
||||||
|
io.Int.Input(
|
||||||
|
"skip_early_step",
|
||||||
|
tooltip="Use the default attention function for the first few Steps",
|
||||||
|
default=1,
|
||||||
|
min=0,
|
||||||
|
max=99,
|
||||||
|
step=1,
|
||||||
|
display_mode=io.NumberDisplay.number,
|
||||||
|
optional=True,
|
||||||
|
advanced=True,
|
||||||
|
),
|
||||||
|
io.Int.Input(
|
||||||
|
"skip_last_step",
|
||||||
|
tooltip="Use the default attention function for the last few Steps",
|
||||||
|
default=1,
|
||||||
|
min=0,
|
||||||
|
max=99,
|
||||||
|
step=1,
|
||||||
|
optional=True,
|
||||||
|
advanced=True,
|
||||||
|
),
|
||||||
],
|
],
|
||||||
outputs=[io.Model.Output()],
|
outputs=[io.Model.Output()],
|
||||||
hidden=[io.Hidden.unique_id],
|
hidden=[io.Hidden.unique_id],
|
||||||
@ -26,13 +68,20 @@ class Sage3PatchModel(io.ComfyNode):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def execute(cls, model: ModelPatcher) -> io.NodeOutput:
|
def execute(
|
||||||
|
cls,
|
||||||
|
model: ModelPatcher,
|
||||||
|
skip_early_block: int = 1,
|
||||||
|
skip_last_block: int = 1,
|
||||||
|
skip_early_step: int = 1,
|
||||||
|
skip_last_step: int = 1,
|
||||||
|
) -> io.NodeOutput:
|
||||||
sage3: Callable | None = get_attention_function("sage3", default=None)
|
sage3: Callable | None = get_attention_function("sage3", default=None)
|
||||||
|
|
||||||
if sage3 is None:
|
if sage3 is None:
|
||||||
if cls.hidden.unique_id:
|
if cls.hidden.unique_id:
|
||||||
PromptServer.instance.send_progress_text(
|
PromptServer.instance.send_progress_text(
|
||||||
"`sageattn3` is not installed / available...",
|
'To use the "Patch SageAttention 3" node, the `sageattn3` package must be installed first',
|
||||||
cls.hidden.unique_id,
|
cls.hidden.unique_id,
|
||||||
)
|
)
|
||||||
return io.NodeOutput(model)
|
return io.NodeOutput(model)
|
||||||
@ -40,24 +89,31 @@ class Sage3PatchModel(io.ComfyNode):
|
|||||||
def attention_override(func: Callable, *args, **kwargs):
|
def attention_override(func: Callable, *args, **kwargs):
|
||||||
transformer_options: dict = kwargs.get("transformer_options", {})
|
transformer_options: dict = kwargs.get("transformer_options", {})
|
||||||
|
|
||||||
block_index: int = transformer_options.get("block_index", 0)
|
total_blocks: int = transformer_options.get("total_blocks", -1)
|
||||||
total_blocks: int = transformer_options.get("total_blocks", 1)
|
block_index: int = transformer_options.get("block_index", -1) # [0, N)
|
||||||
|
|
||||||
if block_index == 0 or block_index >= (total_blocks - 1):
|
if total_blocks == -1 or not (
|
||||||
|
skip_early_block <= block_index < total_blocks - skip_last_block
|
||||||
|
):
|
||||||
return func(*args, **kwargs)
|
return func(*args, **kwargs)
|
||||||
|
|
||||||
sample_sigmas: torch.Tensor = transformer_options["sample_sigmas"]
|
sample_sigmas: torch.Tensor = transformer_options.get("sample_sigmas", None)
|
||||||
sigmas: torch.Tensor = transformer_options["sigmas"]
|
sigmas: torch.Tensor = transformer_options.get("sigmas", None)
|
||||||
|
|
||||||
total_steps: int = sample_sigmas.size(0)
|
if sample_sigmas is None or sigmas is None:
|
||||||
step: int = 0
|
return func(*args, **kwargs)
|
||||||
|
|
||||||
|
total_steps: int = sample_sigmas.size(0) - 1
|
||||||
|
step: int = -1 # [0, N)
|
||||||
|
|
||||||
for i in range(total_steps):
|
for i in range(total_steps):
|
||||||
if torch.allclose(sample_sigmas[i], sigmas):
|
if torch.allclose(sample_sigmas[i], sigmas):
|
||||||
step = i
|
step = i
|
||||||
break
|
break
|
||||||
|
|
||||||
if step == 0 or step >= (total_steps - 1):
|
if step == -1 or not (
|
||||||
|
skip_early_step <= step < total_steps - skip_last_step
|
||||||
|
):
|
||||||
return func(*args, **kwargs)
|
return func(*args, **kwargs)
|
||||||
|
|
||||||
return sage3(*args, **kwargs)
|
return sage3(*args, **kwargs)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user