convert nodes_eps.py to V3 schema (#10172)

This commit is contained in:
Alexander Piskun 2025-10-03 21:45:02 +03:00 committed by GitHub
parent 3e68bc342c
commit d7aa414141
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1,4 +1,9 @@
class EpsilonScaling: from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
class EpsilonScaling(io.ComfyNode):
""" """
Implements the Epsilon Scaling method from 'Elucidating the Exposure Bias in Diffusion Models' Implements the Epsilon Scaling method from 'Elucidating the Exposure Bias in Diffusion Models'
(https://arxiv.org/abs/2308.15321v6). (https://arxiv.org/abs/2308.15321v6).
@ -8,26 +13,28 @@ class EpsilonScaling:
recommended by the paper for its practicality and effectiveness. recommended by the paper for its practicality and effectiveness.
""" """
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return { return io.Schema(
"required": { node_id="Epsilon Scaling",
"model": ("MODEL",), category="model_patches/unet",
"scaling_factor": ("FLOAT", { inputs=[
"default": 1.005, io.Model.Input("model"),
"min": 0.5, io.Float.Input(
"max": 1.5, "scaling_factor",
"step": 0.001, default=1.005,
"display": "number" min=0.5,
}), max=1.5,
} step=0.001,
} display_mode=io.NumberDisplay.number,
),
],
outputs=[
io.Model.Output(),
],
)
RETURN_TYPES = ("MODEL",) @classmethod
FUNCTION = "patch" def execute(cls, model, scaling_factor) -> io.NodeOutput:
CATEGORY = "model_patches/unet"
def patch(self, model, scaling_factor):
# Prevent division by zero, though the UI's min value should prevent this. # Prevent division by zero, though the UI's min value should prevent this.
if scaling_factor == 0: if scaling_factor == 0:
scaling_factor = 1e-9 scaling_factor = 1e-9
@ -53,8 +60,15 @@ class EpsilonScaling:
model_clone.set_model_sampler_post_cfg_function(epsilon_scaling_function) model_clone.set_model_sampler_post_cfg_function(epsilon_scaling_function)
return (model_clone,) return io.NodeOutput(model_clone)
NODE_CLASS_MAPPINGS = {
"Epsilon Scaling": EpsilonScaling class EpsilonScalingExtension(ComfyExtension):
} @override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
EpsilonScaling,
]
async def comfy_entrypoint() -> EpsilonScalingExtension:
return EpsilonScalingExtension()