mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-12-16 01:37:04 +08:00
convert nodes_tomesd.py to V3 schema (#10180)
This commit is contained in:
parent
8c26d7bbe6
commit
5c8e986e27
@ -1,7 +1,9 @@
|
|||||||
#Taken from: https://github.com/dbolya/tomesd
|
#Taken from: https://github.com/dbolya/tomesd
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from typing import Tuple, Callable
|
from typing import Tuple, Callable, Optional
|
||||||
|
from typing_extensions import override
|
||||||
|
from comfy_api.latest import ComfyExtension, io
|
||||||
import math
|
import math
|
||||||
|
|
||||||
def do_nothing(x: torch.Tensor, mode:str=None):
|
def do_nothing(x: torch.Tensor, mode:str=None):
|
||||||
@ -144,33 +146,45 @@ def get_functions(x, ratio, original_shape):
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
class TomePatchModel:
|
class TomePatchModel(io.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls):
|
||||||
return {"required": { "model": ("MODEL",),
|
return io.Schema(
|
||||||
"ratio": ("FLOAT", {"default": 0.3, "min": 0.0, "max": 1.0, "step": 0.01}),
|
node_id="TomePatchModel",
|
||||||
}}
|
category="model_patches/unet",
|
||||||
RETURN_TYPES = ("MODEL",)
|
inputs=[
|
||||||
FUNCTION = "patch"
|
io.Model.Input("model"),
|
||||||
|
io.Float.Input("ratio", default=0.3, min=0.0, max=1.0, step=0.01),
|
||||||
|
],
|
||||||
|
outputs=[io.Model.Output()],
|
||||||
|
)
|
||||||
|
|
||||||
CATEGORY = "model_patches/unet"
|
@classmethod
|
||||||
|
def execute(cls, model, ratio) -> io.NodeOutput:
|
||||||
def patch(self, model, ratio):
|
u: Optional[Callable] = None
|
||||||
self.u = None
|
|
||||||
def tomesd_m(q, k, v, extra_options):
|
def tomesd_m(q, k, v, extra_options):
|
||||||
|
nonlocal u
|
||||||
#NOTE: In the reference code get_functions takes x (input of the transformer block) as the argument instead of q
|
#NOTE: In the reference code get_functions takes x (input of the transformer block) as the argument instead of q
|
||||||
#however from my basic testing it seems that using q instead gives better results
|
#however from my basic testing it seems that using q instead gives better results
|
||||||
m, self.u = get_functions(q, ratio, extra_options["original_shape"])
|
m, u = get_functions(q, ratio, extra_options["original_shape"])
|
||||||
return m(q), k, v
|
return m(q), k, v
|
||||||
def tomesd_u(n, extra_options):
|
def tomesd_u(n, extra_options):
|
||||||
return self.u(n)
|
nonlocal u
|
||||||
|
return u(n)
|
||||||
|
|
||||||
m = model.clone()
|
m = model.clone()
|
||||||
m.set_model_attn1_patch(tomesd_m)
|
m.set_model_attn1_patch(tomesd_m)
|
||||||
m.set_model_attn1_output_patch(tomesd_u)
|
m.set_model_attn1_output_patch(tomesd_u)
|
||||||
return (m, )
|
return io.NodeOutput(m)
|
||||||
|
|
||||||
|
|
||||||
NODE_CLASS_MAPPINGS = {
|
class TomePatchModelExtension(ComfyExtension):
|
||||||
"TomePatchModel": TomePatchModel,
|
@override
|
||||||
}
|
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||||
|
return [
|
||||||
|
TomePatchModel,
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
async def comfy_entrypoint() -> TomePatchModelExtension:
|
||||||
|
return TomePatchModelExtension()
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user