mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-30 16:20:17 +08:00
Compare commits
5 Commits
a2cb9850e0
...
7ab00cc93e
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7ab00cc93e | ||
|
|
034fac7054 | ||
|
|
027c862453 | ||
|
|
43e9509856 | ||
|
|
265b4f0fa1 |
78
comfy_extras/nodes_sage3.py
Normal file
78
comfy_extras/nodes_sage3.py
Normal file
@ -0,0 +1,78 @@
|
||||
from typing import Callable
|
||||
|
||||
import torch
|
||||
from typing_extensions import override
|
||||
|
||||
from comfy.ldm.modules.attention import get_attention_function
|
||||
from comfy.model_patcher import ModelPatcher
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
from server import PromptServer
|
||||
|
||||
|
||||
class Sage3PatchModel(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="Sage3PatchModel",
|
||||
display_name="Patch SageAttention 3",
|
||||
description="Patch the model to use `attention3_sage` during the middle blocks and steps, keeping the default attention function for the first/last blocks and steps",
|
||||
category="_for_testing",
|
||||
inputs=[
|
||||
io.Model.Input("model"),
|
||||
],
|
||||
outputs=[io.Model.Output()],
|
||||
is_experimental=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, model: ModelPatcher) -> io.NodeOutput:
|
||||
sage3: Callable | None = get_attention_function("sage3", default=None)
|
||||
|
||||
if sage3 is None:
|
||||
PromptServer.instance.send_progress_text(
|
||||
"`sageattn3` is not installed / available...",
|
||||
cls.hidden.unique_id,
|
||||
)
|
||||
return io.NodeOutput(model)
|
||||
|
||||
def attention_override(func: Callable, *args, **kwargs):
|
||||
transformer_options: dict = kwargs.get("transformer_options", {})
|
||||
|
||||
block_index: int = transformer_options.get("block_index", 0)
|
||||
total_blocks: int = transformer_options.get("total_blocks", 1)
|
||||
|
||||
if block_index == 0 or block_index >= (total_blocks - 1):
|
||||
return func(*args, **kwargs)
|
||||
|
||||
sample_sigmas: torch.Tensor = transformer_options["sample_sigmas"]
|
||||
sigmas: torch.Tensor = transformer_options["sigmas"]
|
||||
|
||||
total_steps: int = sample_sigmas.size(0)
|
||||
step: int = 0
|
||||
|
||||
for i in range(total_steps):
|
||||
if torch.allclose(sample_sigmas[i], sigmas):
|
||||
step = i
|
||||
break
|
||||
|
||||
if step == 0 or step >= (total_steps - 1):
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return sage3(*args, **kwargs)
|
||||
|
||||
model = model.clone()
|
||||
model.model_options["transformer_options"][
|
||||
"optimized_attention_override"
|
||||
] = attention_override
|
||||
|
||||
return io.NodeOutput(model)
|
||||
|
||||
|
||||
class Sage3Extension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||
return [Sage3PatchModel]
|
||||
|
||||
|
||||
async def comfy_entrypoint():
|
||||
return Sage3Extension()
|
||||
31
nodes.py
31
nodes.py
@ -5,6 +5,7 @@ import torch
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
import glob
|
||||
import hashlib
|
||||
import inspect
|
||||
import traceback
|
||||
@ -2371,6 +2372,7 @@ async def init_builtin_extra_nodes():
|
||||
"nodes_nop.py",
|
||||
"nodes_kandinsky5.py",
|
||||
"nodes_wanmove.py",
|
||||
"nodes_sage3.py",
|
||||
"nodes_image_compare.py",
|
||||
]
|
||||
|
||||
@ -2384,35 +2386,12 @@ async def init_builtin_extra_nodes():
|
||||
|
||||
async def init_builtin_api_nodes():
|
||||
api_nodes_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_api_nodes")
|
||||
api_nodes_files = [
|
||||
"nodes_ideogram.py",
|
||||
"nodes_openai.py",
|
||||
"nodes_minimax.py",
|
||||
"nodes_veo2.py",
|
||||
"nodes_kling.py",
|
||||
"nodes_bfl.py",
|
||||
"nodes_bytedance.py",
|
||||
"nodes_ltxv.py",
|
||||
"nodes_luma.py",
|
||||
"nodes_recraft.py",
|
||||
"nodes_pixverse.py",
|
||||
"nodes_stability.py",
|
||||
"nodes_runway.py",
|
||||
"nodes_sora.py",
|
||||
"nodes_topaz.py",
|
||||
"nodes_tripo.py",
|
||||
"nodes_meshy.py",
|
||||
"nodes_moonvalley.py",
|
||||
"nodes_rodin.py",
|
||||
"nodes_gemini.py",
|
||||
"nodes_vidu.py",
|
||||
"nodes_wan.py",
|
||||
]
|
||||
api_nodes_files = sorted(glob.glob(os.path.join(api_nodes_dir, "nodes_*.py")))
|
||||
|
||||
import_failed = []
|
||||
for node_file in api_nodes_files:
|
||||
if not await load_custom_node(os.path.join(api_nodes_dir, node_file), module_parent="comfy_api_nodes"):
|
||||
import_failed.append(node_file)
|
||||
if not await load_custom_node(node_file, module_parent="comfy_api_nodes"):
|
||||
import_failed.append(os.path.basename(node_file))
|
||||
|
||||
return import_failed
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user