convert nodes_fresca.py to V3 schema (#9951)

This commit is contained in:
Alexander Piskun 2025-09-27 00:12:04 +03:00 committed by GitHub
parent 76eb1d72c3
commit 7ea173c187
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1,6 +1,8 @@
# Code based on https://github.com/WikiChao/FreSca (MIT License)
import torch
import torch.fft as fft
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
def Fourier_filter(x, scale_low=1.0, scale_high=1.5, freq_cutoff=20):
@ -51,25 +53,31 @@ def Fourier_filter(x, scale_low=1.0, scale_high=1.5, freq_cutoff=20):
return x_filtered
class FreSca:
class FreSca(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"model": ("MODEL",),
"scale_low": ("FLOAT", {"default": 1.0, "min": 0, "max": 10, "step": 0.01,
"tooltip": "Scaling factor for low-frequency components"}),
"scale_high": ("FLOAT", {"default": 1.25, "min": 0, "max": 10, "step": 0.01,
"tooltip": "Scaling factor for high-frequency components"}),
"freq_cutoff": ("INT", {"default": 20, "min": 1, "max": 10000, "step": 1,
"tooltip": "Number of frequency indices around center to consider as low-frequency"}),
}
}
RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"
CATEGORY = "_for_testing"
DESCRIPTION = "Applies frequency-dependent scaling to the guidance"
def patch(self, model, scale_low, scale_high, freq_cutoff):
def define_schema(cls):
return io.Schema(
node_id="FreSca",
display_name="FreSca",
category="_for_testing",
description="Applies frequency-dependent scaling to the guidance",
inputs=[
io.Model.Input("model"),
io.Float.Input("scale_low", default=1.0, min=0, max=10, step=0.01,
tooltip="Scaling factor for low-frequency components"),
io.Float.Input("scale_high", default=1.25, min=0, max=10, step=0.01,
tooltip="Scaling factor for high-frequency components"),
io.Int.Input("freq_cutoff", default=20, min=1, max=10000, step=1,
tooltip="Number of frequency indices around center to consider as low-frequency"),
],
outputs=[
io.Model.Output(),
],
is_experimental=True,
)
@classmethod
def execute(cls, model, scale_low, scale_high, freq_cutoff):
def custom_cfg_function(args):
conds_out = args["conds_out"]
if len(conds_out) <= 1 or None in args["conds"][:2]:
@ -91,13 +99,16 @@ class FreSca:
m = model.clone()
m.set_model_sampler_pre_cfg_function(custom_cfg_function)
return (m,)
return io.NodeOutput(m)
NODE_CLASS_MAPPINGS = {
"FreSca": FreSca,
}
class FreScaExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
FreSca,
]
NODE_DISPLAY_NAME_MAPPINGS = {
"FreSca": "FreSca",
}
async def comfy_entrypoint() -> FreScaExtension:
return FreScaExtension()