mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-03-30 05:23:37 +08:00
optimize
This commit is contained in:
parent
cfcd8f230b
commit
6f70a2cb99
@ -86,39 +86,49 @@ class Sage3PatchModel(io.ComfyNode):
|
|||||||
)
|
)
|
||||||
return io.NodeOutput(model)
|
return io.NodeOutput(model)
|
||||||
|
|
||||||
def attention_override(func: Callable, *args, **kwargs):
|
def sage_wrapper(model_function, kwargs: dict):
|
||||||
transformer_options: dict = kwargs.get("transformer_options", {})
|
# parse the current step on every model call instead of every attention call
|
||||||
|
|
||||||
total_blocks: int = transformer_options.get("total_blocks", -1)
|
x, timestep, c = kwargs["input"], kwargs["timestep"], kwargs["c"]
|
||||||
block_index: int = transformer_options.get("block_index", -1) # [0, N)
|
|
||||||
|
|
||||||
if total_blocks == -1 or not (
|
transformer_options: dict = c.get("transformer_options", {})
|
||||||
skip_early_block <= block_index < total_blocks - skip_last_block
|
|
||||||
):
|
|
||||||
return func(*args, **kwargs)
|
|
||||||
|
|
||||||
sample_sigmas: torch.Tensor = transformer_options.get("sample_sigmas", None)
|
sample_sigmas: torch.Tensor = transformer_options.get("sample_sigmas", None)
|
||||||
sigmas: torch.Tensor = transformer_options.get("sigmas", None)
|
sigmas: torch.Tensor = transformer_options.get("sigmas", None)
|
||||||
|
|
||||||
if sample_sigmas is None or sigmas is None:
|
if sample_sigmas is None or sigmas is None:
|
||||||
return func(*args, **kwargs)
|
transformer_options["_sage3"] = False
|
||||||
|
return model_function(x, timestep, **c)
|
||||||
|
|
||||||
total_steps: int = sample_sigmas.size(0) - 1
|
mask: torch.Tensor = (sample_sigmas == sigmas).nonzero(as_tuple=True)[0]
|
||||||
step: int = -1 # [0, N)
|
|
||||||
|
|
||||||
for i in range(total_steps):
|
total_steps: int = sample_sigmas.size(0)
|
||||||
if torch.allclose(sample_sigmas[i], sigmas):
|
step: int = mask.item() if mask.numel() > 0 else -1 # [0, N)
|
||||||
step = i
|
|
||||||
break
|
|
||||||
|
|
||||||
if step == -1 or not (
|
transformer_options["_sage3"] = step > -1 and (
|
||||||
skip_early_step <= step < total_steps - skip_last_step
|
skip_early_step <= step < total_steps - skip_last_step
|
||||||
):
|
)
|
||||||
|
|
||||||
|
return model_function(x, timestep, **c)
|
||||||
|
|
||||||
|
def attention_override(func: Callable, *args, **kwargs):
|
||||||
|
transformer_options: dict = kwargs.get("transformer_options", {})
|
||||||
|
|
||||||
|
if not transformer_options.get("_sage3", False):
|
||||||
return func(*args, **kwargs)
|
return func(*args, **kwargs)
|
||||||
|
|
||||||
return sage3(*args, **kwargs)
|
total_blocks: int = transformer_options.get("total_blocks", -1)
|
||||||
|
block_index: int = transformer_options.get("block_index", -1) # [0, N)
|
||||||
|
|
||||||
|
if total_blocks > -1 and (
|
||||||
|
skip_early_block <= block_index < total_blocks - skip_last_block
|
||||||
|
):
|
||||||
|
return sage3(*args, **kwargs)
|
||||||
|
else:
|
||||||
|
return func(*args, **kwargs)
|
||||||
|
|
||||||
model = model.clone()
|
model = model.clone()
|
||||||
|
model.set_model_unet_function_wrapper(sage_wrapper)
|
||||||
model.model_options["transformer_options"][
|
model.model_options["transformer_options"][
|
||||||
"optimized_attention_override"
|
"optimized_attention_override"
|
||||||
] = attention_override
|
] = attention_override
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user