mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-12-15 01:07:03 +08:00
convert nodes_tcfg.py to V3 schema (#9942)
This commit is contained in:
parent
80718908a9
commit
a061b06321
@ -1,8 +1,9 @@
|
||||
# TCFG: Tangential Damping Classifier-free Guidance - (arXiv: https://arxiv.org/abs/2503.18137)
|
||||
|
||||
from typing_extensions import override
|
||||
import torch
|
||||
|
||||
from comfy.comfy_types import IO, ComfyNodeABC, InputTypeDict
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
|
||||
|
||||
def score_tangential_damping(cond_score: torch.Tensor, uncond_score: torch.Tensor) -> torch.Tensor:
|
||||
@ -26,23 +27,24 @@ def score_tangential_damping(cond_score: torch.Tensor, uncond_score: torch.Tenso
|
||||
return uncond_score_td.reshape_as(uncond_score).to(uncond_score.dtype)
|
||||
|
||||
|
||||
class TCFG(ComfyNodeABC):
|
||||
class TCFG(io.ComfyNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls) -> InputTypeDict:
|
||||
return {
|
||||
"required": {
|
||||
"model": (IO.MODEL, {}),
|
||||
}
|
||||
}
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="TCFG",
|
||||
display_name="Tangential Damping CFG",
|
||||
category="advanced/guidance",
|
||||
description="TCFG – Tangential Damping CFG (2503.18137)\n\nRefine the uncond (negative) to align with the cond (positive) for improving quality.",
|
||||
inputs=[
|
||||
io.Model.Input("model"),
|
||||
],
|
||||
outputs=[
|
||||
io.Model.Output(display_name="patched_model"),
|
||||
],
|
||||
)
|
||||
|
||||
RETURN_TYPES = (IO.MODEL,)
|
||||
RETURN_NAMES = ("patched_model",)
|
||||
FUNCTION = "patch"
|
||||
|
||||
CATEGORY = "advanced/guidance"
|
||||
DESCRIPTION = "TCFG – Tangential Damping CFG (2503.18137)\n\nRefine the uncond (negative) to align with the cond (positive) for improving quality."
|
||||
|
||||
def patch(self, model):
|
||||
@classmethod
|
||||
def execute(cls, model):
|
||||
m = model.clone()
|
||||
|
||||
def tangential_damping_cfg(args):
|
||||
@ -59,13 +61,16 @@ class TCFG(ComfyNodeABC):
|
||||
return [cond_pred, uncond_pred_td] + conds_out[2:]
|
||||
|
||||
m.set_model_sampler_pre_cfg_function(tangential_damping_cfg)
|
||||
return (m,)
|
||||
return io.NodeOutput(m)
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"TCFG": TCFG,
|
||||
}
|
||||
class TcfgExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||
return [
|
||||
TCFG,
|
||||
]
|
||||
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"TCFG": "Tangential Damping CFG",
|
||||
}
|
||||
|
||||
async def comfy_entrypoint() -> TcfgExtension:
|
||||
return TcfgExtension()
|
||||
|
||||
Loading…
Reference in New Issue
Block a user