This commit is contained in:
Haoming 2026-03-27 10:54:24 +08:00
parent cfcd8f230b
commit 6f70a2cb99

View File

@ -86,39 +86,49 @@ class Sage3PatchModel(io.ComfyNode):
)
return io.NodeOutput(model)
def sage_wrapper(model_function, kwargs: dict):
    """Model-function wrapper: decide once per model call whether SageAttention
    v3 should be active for the current sampling step.

    Parses the current step from the sigma schedule here — instead of on every
    attention call — and records the decision in
    ``transformer_options["_sage3"]`` for ``attention_override`` to read.
    Always forwards to the wrapped ``model_function`` unchanged.

    Args:
        model_function: the original model call being wrapped.
        kwargs: ComfyUI wrapper payload with ``"input"``, ``"timestep"`` and
            the conditioning dict ``"c"``.

    Returns:
        Whatever ``model_function`` returns.
    """
    x, timestep, c = kwargs["input"], kwargs["timestep"], kwargs["c"]
    transformer_options: dict = c.get("transformer_options", {})
    sample_sigmas = transformer_options.get("sample_sigmas", None)
    sigmas = transformer_options.get("sigmas", None)
    if sample_sigmas is None or sigmas is None:
        # No schedule info available -> cannot determine the step; disable.
        transformer_options["_sage3"] = False
        return model_function(x, timestep, **c)
    # Locate the current step as the index of the current sigma in the
    # full schedule (vectorized lookup instead of a Python loop).
    total_steps: int = sample_sigmas.size(0)
    mask = (sample_sigmas == sigmas).nonzero(as_tuple=True)[0]
    step: int = mask.item() if mask.numel() > 0 else -1  # [0, N)
    # Enable Sage3 only when the step falls inside the configured window.
    transformer_options["_sage3"] = step > -1 and (
        skip_early_step <= step < total_steps - skip_last_step
    )
    return model_function(x, timestep, **c)
def attention_override(func: Callable, *args, **kwargs):
    """Attention override: route to SageAttention v3 when enabled.

    Reads the per-step flag ``transformer_options["_sage3"]`` (set once per
    model call by ``sage_wrapper``) and bails out early when it is off. When
    on, Sage3 is used only for block indices inside the configured window;
    all other calls fall back to the original attention ``func``.

    Args:
        func: the original optimized-attention callable.
        *args, **kwargs: forwarded verbatim to either ``sage3`` or ``func``.

    Returns:
        The attention output from ``sage3`` or ``func``.
    """
    transformer_options: dict = kwargs.get("transformer_options", {})
    # Cheap early exit: the expensive step parsing already happened upstream.
    if not transformer_options.get("_sage3", False):
        return func(*args, **kwargs)
    total_blocks: int = transformer_options.get("total_blocks", -1)
    block_index: int = transformer_options.get("block_index", -1)  # [0, N)
    if total_blocks > -1 and (
        skip_early_block <= block_index < total_blocks - skip_last_block
    ):
        return sage3(*args, **kwargs)
    else:
        return func(*args, **kwargs)
# Install both hooks on a clone so the caller's original model is untouched:
# - sage_wrapper runs once per model call and records the per-step decision,
# - attention_override consults that decision on every attention call.
model = model.clone()
model.set_model_unet_function_wrapper(sage_wrapper)
model.model_options["transformer_options"][
    "optimized_attention_override"
] = attention_override