diff --git a/comfy_extras/nodes_custom_sampler.py b/comfy_extras/nodes_custom_sampler.py index 1e957c09b..2fbfd4139 100644 --- a/comfy_extras/nodes_custom_sampler.py +++ b/comfy_extras/nodes_custom_sampler.py @@ -1049,6 +1049,48 @@ class ManualSigmas(io.ComfyNode): sigmas = torch.FloatTensor(sigmas) return io.NodeOutput(sigmas) +class CurveToSigmas(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="CurveToSigmas", + display_name="Curve to Sigmas", + category="sampling/custom_sampling/schedulers", + inputs=[ + io.Curve.Input("curve", default=[[0.0, 1.0], [1.0, 0.0]]), + io.Model.Input("model", optional=True), + io.Int.Input("steps", default=20, min=1, max=10000), + io.Float.Input("sigma_max", default=14.614642, min=0.0, max=5000.0, step=0.01, round=False, optional=True), + ], + outputs=[io.Sigmas.Output()] + ) + + @classmethod + def execute(cls, curve, steps, sigma_max=14.614642, model=None) -> io.NodeOutput: + points = sorted(curve, key=lambda p: p[0]) + model_sampling = model.get_model_object("model_sampling") if model is not None else None + + sigmas = [] + for i in range(steps + 1): + t = i / steps + y = points[0][1] if t < points[0][0] else points[-1][1] + for j in range(len(points) - 1): + if points[j][0] <= t <= points[j + 1][0]: + x0, y0 = points[j] + x1, y1 = points[j + 1] + y = y0 if x1 == x0 else y0 + (y1 - y0) * (t - x0) / (x1 - x0) + break + if model_sampling is not None: + sigmas.append(float(model_sampling.percent_to_sigma(1.0 - y))) + else: + sigmas.append(y * sigma_max) + + sigmas[-1] = 0.0 + return io.NodeOutput(torch.FloatTensor(sigmas)) + + get_sigmas = execute + + class CustomSamplersExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[io.ComfyNode]]: @@ -1088,6 +1130,7 @@ class CustomSamplersExtension(ComfyExtension): AddNoise, SamplerCustomAdvanced, ManualSigmas, + CurveToSigmas, ] diff --git a/nodes.py b/nodes.py index 5be9b16f9..aba96a49d 100644 --- a/nodes.py +++ b/nodes.py @@ -2034,6 +2034,24 @@ class ImagePadForOutpaint: return (new_image, mask.unsqueeze(0)) +class CurveEditor: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "curve": ("CURVE", {"default": [[0, 0], [1, 1]]}), + } + } + + RETURN_TYPES = ("CURVE",) + RETURN_NAMES = ("curve",) + FUNCTION = "execute" + CATEGORY = "utils" + + def execute(self, curve): + return (curve,) + + NODE_CLASS_MAPPINGS = { "KSampler": KSampler, "CheckpointLoaderSimple": CheckpointLoaderSimple, @@ -2102,6 +2120,7 @@ NODE_CLASS_MAPPINGS = { "ConditioningZeroOut": ConditioningZeroOut, "ConditioningSetTimestepRange": ConditioningSetTimestepRange, "LoraLoaderModelOnly": LoraLoaderModelOnly, + "CurveEditor": CurveEditor, } NODE_DISPLAY_NAME_MAPPINGS = { @@ -2170,6 +2189,7 @@ NODE_DISPLAY_NAME_MAPPINGS = { # _for_testing "VAEDecodeTiled": "VAE Decode (Tiled)", "VAEEncodeTiled": "VAE Encode (Tiled)", + "CurveEditor": "Curve Editor", } EXTENSION_WEB_DIRS = {}