convert nodes_tomesd.py to V3 schema (#10180)

This commit is contained in:
Alexander Piskun 2025-10-03 21:50:38 +03:00 committed by GitHub
parent 8c26d7bbe6
commit 5c8e986e27
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1,7 +1,9 @@
#Taken from: https://github.com/dbolya/tomesd #Taken from: https://github.com/dbolya/tomesd
import torch import torch
from typing import Tuple, Callable from typing import Tuple, Callable, Optional
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
import math import math
def do_nothing(x: torch.Tensor, mode:str=None): def do_nothing(x: torch.Tensor, mode:str=None):
@ -144,33 +146,45 @@ def get_functions(x, ratio, original_shape):
class TomePatchModel: class TomePatchModel(io.ComfyNode):
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return {"required": { "model": ("MODEL",), return io.Schema(
"ratio": ("FLOAT", {"default": 0.3, "min": 0.0, "max": 1.0, "step": 0.01}), node_id="TomePatchModel",
}} category="model_patches/unet",
RETURN_TYPES = ("MODEL",) inputs=[
FUNCTION = "patch" io.Model.Input("model"),
io.Float.Input("ratio", default=0.3, min=0.0, max=1.0, step=0.01),
],
outputs=[io.Model.Output()],
)
CATEGORY = "model_patches/unet" @classmethod
def execute(cls, model, ratio) -> io.NodeOutput:
def patch(self, model, ratio): u: Optional[Callable] = None
self.u = None
def tomesd_m(q, k, v, extra_options): def tomesd_m(q, k, v, extra_options):
nonlocal u
#NOTE: In the reference code get_functions takes x (input of the transformer block) as the argument instead of q #NOTE: In the reference code get_functions takes x (input of the transformer block) as the argument instead of q
#however from my basic testing it seems that using q instead gives better results #however from my basic testing it seems that using q instead gives better results
m, self.u = get_functions(q, ratio, extra_options["original_shape"]) m, u = get_functions(q, ratio, extra_options["original_shape"])
return m(q), k, v return m(q), k, v
def tomesd_u(n, extra_options): def tomesd_u(n, extra_options):
return self.u(n) nonlocal u
return u(n)
m = model.clone() m = model.clone()
m.set_model_attn1_patch(tomesd_m) m.set_model_attn1_patch(tomesd_m)
m.set_model_attn1_output_patch(tomesd_u) m.set_model_attn1_output_patch(tomesd_u)
return (m, ) return io.NodeOutput(m)
NODE_CLASS_MAPPINGS = { class TomePatchModelExtension(ComfyExtension):
"TomePatchModel": TomePatchModel, @override
} async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
TomePatchModel,
]
async def comfy_entrypoint() -> TomePatchModelExtension:
return TomePatchModelExtension()