This commit is contained in:
Haoming 2026-03-27 10:54:24 +08:00
parent cfcd8f230b
commit 6f70a2cb99

View File

@ -86,39 +86,49 @@ class Sage3PatchModel(io.ComfyNode):
) )
return io.NodeOutput(model) return io.NodeOutput(model)
def attention_override(func: Callable, *args, **kwargs): def sage_wrapper(model_function, kwargs: dict):
transformer_options: dict = kwargs.get("transformer_options", {}) # parse the current step on every model call instead of every attention call
total_blocks: int = transformer_options.get("total_blocks", -1) x, timestep, c = kwargs["input"], kwargs["timestep"], kwargs["c"]
block_index: int = transformer_options.get("block_index", -1) # [0, N)
if total_blocks == -1 or not ( transformer_options: dict = c.get("transformer_options", {})
skip_early_block <= block_index < total_blocks - skip_last_block
):
return func(*args, **kwargs)
sample_sigmas: torch.Tensor = transformer_options.get("sample_sigmas", None) sample_sigmas: torch.Tensor = transformer_options.get("sample_sigmas", None)
sigmas: torch.Tensor = transformer_options.get("sigmas", None) sigmas: torch.Tensor = transformer_options.get("sigmas", None)
if sample_sigmas is None or sigmas is None: if sample_sigmas is None or sigmas is None:
return func(*args, **kwargs) transformer_options["_sage3"] = False
return model_function(x, timestep, **c)
total_steps: int = sample_sigmas.size(0) - 1 mask: torch.Tensor = (sample_sigmas == sigmas).nonzero(as_tuple=True)[0]
step: int = -1 # [0, N)
for i in range(total_steps): total_steps: int = sample_sigmas.size(0)
if torch.allclose(sample_sigmas[i], sigmas): step: int = mask.item() if mask.numel() > 0 else -1 # [0, N)
step = i
break
if step == -1 or not ( transformer_options["_sage3"] = step > -1 and (
skip_early_step <= step < total_steps - skip_last_step skip_early_step <= step < total_steps - skip_last_step
): )
return model_function(x, timestep, **c)
def attention_override(func: Callable, *args, **kwargs):
transformer_options: dict = kwargs.get("transformer_options", {})
if not transformer_options.get("_sage3", False):
return func(*args, **kwargs) return func(*args, **kwargs)
return sage3(*args, **kwargs) total_blocks: int = transformer_options.get("total_blocks", -1)
block_index: int = transformer_options.get("block_index", -1) # [0, N)
if total_blocks > -1 and (
skip_early_block <= block_index < total_blocks - skip_last_block
):
return sage3(*args, **kwargs)
else:
return func(*args, **kwargs)
model = model.clone() model = model.clone()
model.set_model_unet_function_wrapper(sage_wrapper)
model.model_options["transformer_options"][ model.model_options["transformer_options"][
"optimized_attention_override" "optimized_attention_override"
] = attention_override ] = attention_override