mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-12-16 17:42:58 +08:00
convert nodes_tcfg.py to V3 schema (#9942)
This commit is contained in:
parent
80718908a9
commit
a061b06321
@ -1,8 +1,9 @@
|
|||||||
# TCFG: Tangential Damping Classifier-free Guidance - (arXiv: https://arxiv.org/abs/2503.18137)
|
# TCFG: Tangential Damping Classifier-free Guidance - (arXiv: https://arxiv.org/abs/2503.18137)
|
||||||
|
|
||||||
|
from typing_extensions import override
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from comfy.comfy_types import IO, ComfyNodeABC, InputTypeDict
|
from comfy_api.latest import ComfyExtension, io
|
||||||
|
|
||||||
|
|
||||||
def score_tangential_damping(cond_score: torch.Tensor, uncond_score: torch.Tensor) -> torch.Tensor:
|
def score_tangential_damping(cond_score: torch.Tensor, uncond_score: torch.Tensor) -> torch.Tensor:
|
||||||
@ -26,23 +27,24 @@ def score_tangential_damping(cond_score: torch.Tensor, uncond_score: torch.Tenso
|
|||||||
return uncond_score_td.reshape_as(uncond_score).to(uncond_score.dtype)
|
return uncond_score_td.reshape_as(uncond_score).to(uncond_score.dtype)
|
||||||
|
|
||||||
|
|
||||||
class TCFG(ComfyNodeABC):
|
class TCFG(io.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(cls) -> InputTypeDict:
|
def define_schema(cls):
|
||||||
return {
|
return io.Schema(
|
||||||
"required": {
|
node_id="TCFG",
|
||||||
"model": (IO.MODEL, {}),
|
display_name="Tangential Damping CFG",
|
||||||
}
|
category="advanced/guidance",
|
||||||
}
|
description="TCFG – Tangential Damping CFG (2503.18137)\n\nRefine the uncond (negative) to align with the cond (positive) for improving quality.",
|
||||||
|
inputs=[
|
||||||
|
io.Model.Input("model"),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.Model.Output(display_name="patched_model"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
RETURN_TYPES = (IO.MODEL,)
|
@classmethod
|
||||||
RETURN_NAMES = ("patched_model",)
|
def execute(cls, model):
|
||||||
FUNCTION = "patch"
|
|
||||||
|
|
||||||
CATEGORY = "advanced/guidance"
|
|
||||||
DESCRIPTION = "TCFG – Tangential Damping CFG (2503.18137)\n\nRefine the uncond (negative) to align with the cond (positive) for improving quality."
|
|
||||||
|
|
||||||
def patch(self, model):
|
|
||||||
m = model.clone()
|
m = model.clone()
|
||||||
|
|
||||||
def tangential_damping_cfg(args):
|
def tangential_damping_cfg(args):
|
||||||
@ -59,13 +61,16 @@ class TCFG(ComfyNodeABC):
|
|||||||
return [cond_pred, uncond_pred_td] + conds_out[2:]
|
return [cond_pred, uncond_pred_td] + conds_out[2:]
|
||||||
|
|
||||||
m.set_model_sampler_pre_cfg_function(tangential_damping_cfg)
|
m.set_model_sampler_pre_cfg_function(tangential_damping_cfg)
|
||||||
return (m,)
|
return io.NodeOutput(m)
|
||||||
|
|
||||||
|
|
||||||
NODE_CLASS_MAPPINGS = {
|
class TcfgExtension(ComfyExtension):
|
||||||
"TCFG": TCFG,
|
@override
|
||||||
}
|
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||||
|
return [
|
||||||
|
TCFG,
|
||||||
|
]
|
||||||
|
|
||||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
|
||||||
"TCFG": "Tangential Damping CFG",
|
async def comfy_entrypoint() -> TcfgExtension:
|
||||||
}
|
return TcfgExtension()
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user