convert nodes_tcfg.py to V3 schema (#9942)

This commit is contained in:
Alexander Piskun 2025-09-27 00:13:05 +03:00 committed by GitHub
parent 80718908a9
commit a061b06321
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1,8 +1,9 @@
# TCFG: Tangential Damping Classifier-free Guidance - (arXiv: https://arxiv.org/abs/2503.18137)
from typing_extensions import override
import torch
from comfy.comfy_types import IO, ComfyNodeABC, InputTypeDict
from comfy_api.latest import ComfyExtension, io
def score_tangential_damping(cond_score: torch.Tensor, uncond_score: torch.Tensor) -> torch.Tensor:
@ -26,23 +27,24 @@ def score_tangential_damping(cond_score: torch.Tensor, uncond_score: torch.Tenso
return uncond_score_td.reshape_as(uncond_score).to(uncond_score.dtype)
class TCFG(ComfyNodeABC):
class TCFG(io.ComfyNode):
@classmethod
def INPUT_TYPES(cls) -> InputTypeDict:
return {
"required": {
"model": (IO.MODEL, {}),
}
}
def define_schema(cls):
return io.Schema(
node_id="TCFG",
display_name="Tangential Damping CFG",
category="advanced/guidance",
description="TCFG Tangential Damping CFG (2503.18137)\n\nRefine the uncond (negative) to align with the cond (positive) for improving quality.",
inputs=[
io.Model.Input("model"),
],
outputs=[
io.Model.Output(display_name="patched_model"),
],
)
RETURN_TYPES = (IO.MODEL,)
RETURN_NAMES = ("patched_model",)
FUNCTION = "patch"
CATEGORY = "advanced/guidance"
DESCRIPTION = "TCFG Tangential Damping CFG (2503.18137)\n\nRefine the uncond (negative) to align with the cond (positive) for improving quality."
def patch(self, model):
@classmethod
def execute(cls, model):
m = model.clone()
def tangential_damping_cfg(args):
@ -59,13 +61,16 @@ class TCFG(ComfyNodeABC):
return [cond_pred, uncond_pred_td] + conds_out[2:]
m.set_model_sampler_pre_cfg_function(tangential_damping_cfg)
return (m,)
return io.NodeOutput(m)
NODE_CLASS_MAPPINGS = {
"TCFG": TCFG,
}
class TcfgExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
TCFG,
]
NODE_DISPLAY_NAME_MAPPINGS = {
"TCFG": "Tangential Damping CFG",
}
async def comfy_entrypoint() -> TcfgExtension:
return TcfgExtension()