convert nodes_tcfg.py to V3 schema (#9942)

This commit is contained in:
Alexander Piskun 2025-09-27 00:13:05 +03:00 committed by GitHub
parent 80718908a9
commit a061b06321
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1,8 +1,9 @@
# TCFG: Tangential Damping Classifier-free Guidance - (arXiv: https://arxiv.org/abs/2503.18137) # TCFG: Tangential Damping Classifier-free Guidance - (arXiv: https://arxiv.org/abs/2503.18137)
from typing_extensions import override
import torch import torch
from comfy.comfy_types import IO, ComfyNodeABC, InputTypeDict from comfy_api.latest import ComfyExtension, io
def score_tangential_damping(cond_score: torch.Tensor, uncond_score: torch.Tensor) -> torch.Tensor: def score_tangential_damping(cond_score: torch.Tensor, uncond_score: torch.Tensor) -> torch.Tensor:
@ -26,23 +27,24 @@ def score_tangential_damping(cond_score: torch.Tensor, uncond_score: torch.Tenso
return uncond_score_td.reshape_as(uncond_score).to(uncond_score.dtype) return uncond_score_td.reshape_as(uncond_score).to(uncond_score.dtype)
class TCFG(ComfyNodeABC): class TCFG(io.ComfyNode):
@classmethod @classmethod
def INPUT_TYPES(cls) -> InputTypeDict: def define_schema(cls):
return { return io.Schema(
"required": { node_id="TCFG",
"model": (IO.MODEL, {}), display_name="Tangential Damping CFG",
} category="advanced/guidance",
} description="TCFG Tangential Damping CFG (2503.18137)\n\nRefine the uncond (negative) to align with the cond (positive) for improving quality.",
inputs=[
io.Model.Input("model"),
],
outputs=[
io.Model.Output(display_name="patched_model"),
],
)
RETURN_TYPES = (IO.MODEL,) @classmethod
RETURN_NAMES = ("patched_model",) def execute(cls, model):
FUNCTION = "patch"
CATEGORY = "advanced/guidance"
DESCRIPTION = "TCFG Tangential Damping CFG (2503.18137)\n\nRefine the uncond (negative) to align with the cond (positive) for improving quality."
def patch(self, model):
m = model.clone() m = model.clone()
def tangential_damping_cfg(args): def tangential_damping_cfg(args):
@ -59,13 +61,16 @@ class TCFG(ComfyNodeABC):
return [cond_pred, uncond_pred_td] + conds_out[2:] return [cond_pred, uncond_pred_td] + conds_out[2:]
m.set_model_sampler_pre_cfg_function(tangential_damping_cfg) m.set_model_sampler_pre_cfg_function(tangential_damping_cfg)
return (m,) return io.NodeOutput(m)
NODE_CLASS_MAPPINGS = { class TcfgExtension(ComfyExtension):
"TCFG": TCFG, @override
} async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
TCFG,
]
NODE_DISPLAY_NAME_MAPPINGS = {
"TCFG": "Tangential Damping CFG", async def comfy_entrypoint() -> TcfgExtension:
} return TcfgExtension()