From 6f70a2cb99b4fe819ab1e9cedcd9e4234346b91a Mon Sep 17 00:00:00 2001 From: Haoming Date: Fri, 27 Mar 2026 10:54:24 +0800 Subject: [PATCH] optimize --- comfy_extras/nodes_sage3.py | 46 ++++++++++++++++++++++--------------- 1 file changed, 28 insertions(+), 18 deletions(-) diff --git a/comfy_extras/nodes_sage3.py b/comfy_extras/nodes_sage3.py index 8310f1d9a..a9523e92b 100644 --- a/comfy_extras/nodes_sage3.py +++ b/comfy_extras/nodes_sage3.py @@ -86,39 +86,49 @@ class Sage3PatchModel(io.ComfyNode): ) return io.NodeOutput(model) - def attention_override(func: Callable, *args, **kwargs): - transformer_options: dict = kwargs.get("transformer_options", {}) + def sage_wrapper(model_function, kwargs: dict): + # parse the current step on every model call instead of every attention call - total_blocks: int = transformer_options.get("total_blocks", -1) - block_index: int = transformer_options.get("block_index", -1) # [0, N) + x, timestep, c = kwargs["input"], kwargs["timestep"], kwargs["c"] - if total_blocks == -1 or not ( - skip_early_block <= block_index < total_blocks - skip_last_block - ): - return func(*args, **kwargs) + transformer_options: dict = c.get("transformer_options", {}) sample_sigmas: torch.Tensor = transformer_options.get("sample_sigmas", None) sigmas: torch.Tensor = transformer_options.get("sigmas", None) if sample_sigmas is None or sigmas is None: - return func(*args, **kwargs) + transformer_options["_sage3"] = False + return model_function(x, timestep, **c) - total_steps: int = sample_sigmas.size(0) - 1 - step: int = -1 # [0, N) + mask: torch.Tensor = (sample_sigmas == sigmas).nonzero(as_tuple=True)[0] - for i in range(total_steps): - if torch.allclose(sample_sigmas[i], sigmas): - step = i - break + total_steps: int = sample_sigmas.size(0) + step: int = mask.item() if mask.numel() > 0 else -1 # [0, N) - if step == -1 or not ( + transformer_options["_sage3"] = step > -1 and ( skip_early_step <= step < total_steps - skip_last_step - ): + ) + + return model_function(x, timestep, **c) + + def attention_override(func: Callable, *args, **kwargs): + transformer_options: dict = kwargs.get("transformer_options", {}) + + if not transformer_options.get("_sage3", False): return func(*args, **kwargs) - return sage3(*args, **kwargs) + total_blocks: int = transformer_options.get("total_blocks", -1) + block_index: int = transformer_options.get("block_index", -1) # [0, N) + + if total_blocks > -1 and ( + skip_early_block <= block_index < total_blocks - skip_last_block + ): + return sage3(*args, **kwargs) + else: + return func(*args, **kwargs) model = model.clone() + model.set_model_unet_function_wrapper(sage_wrapper) model.model_options["transformer_options"][ "optimized_attention_override" ] = attention_override