From 8e73678dae6e5763bc860d6f98566243a494f9c2 Mon Sep 17 00:00:00 2001 From: Terry Jia Date: Tue, 24 Mar 2026 17:47:28 -0400 Subject: [PATCH] CURVE node (#12757) * CURVE node * remove curve to sigmas node * feat: add CurveInput ABC with MonotoneCubicCurve implementation (#12986) CurveInput is an abstract base class so future curve representations (bezier, LUT-based, analytical functions) can be added without breaking downstream nodes that type-check against CurveInput. MonotoneCubicCurve is the concrete implementation that: - Mirrors frontend createMonotoneInterpolator (curveUtils.ts) exactly - Pre-computes slopes as numpy arrays at construction time - Provides vectorised interp_array() using numpy for batch evaluation - interp() for single-value evaluation - to_lut() for generating lookup tables CurveEditor node wraps raw widget points in MonotoneCubicCurve. * linear curve * refactor: move CurveEditor to comfy_extras/nodes_curve.py with V3 schema * feat: add HISTOGRAM type and histogram support to CurveEditor * code improve --------- Co-authored-by: Christian Byrne --- comfy_api/input/__init__.py | 8 + comfy_api/latest/_input/__init__.py | 5 + comfy_api/latest/_input/curve_types.py | 219 +++++++++++++++++++++++++ comfy_api/latest/_io.py | 20 ++- comfy_extras/nodes_curve.py | 42 +++++ nodes.py | 1 + 6 files changed, 292 insertions(+), 3 deletions(-) create mode 100644 comfy_api/latest/_input/curve_types.py create mode 100644 comfy_extras/nodes_curve.py diff --git a/comfy_api/input/__init__.py b/comfy_api/input/__init__.py index 68ff78270..16d4acfd1 100644 --- a/comfy_api/input/__init__.py +++ b/comfy_api/input/__init__.py @@ -5,6 +5,10 @@ from comfy_api.latest._input import ( MaskInput, LatentInput, VideoInput, + CurvePoint, + CurveInput, + MonotoneCubicCurve, + LinearCurve, ) __all__ = [ @@ -13,4 +17,8 @@ __all__ = [ "MaskInput", "LatentInput", "VideoInput", + "CurvePoint", + "CurveInput", + "MonotoneCubicCurve", + "LinearCurve", ] diff --git a/comfy_api/latest/_input/__init__.py b/comfy_api/latest/_input/__init__.py index 14f0e72f4..05cd3d40a 100644 --- a/comfy_api/latest/_input/__init__.py +++ b/comfy_api/latest/_input/__init__.py @@ -1,4 +1,5 @@ from .basic_types import ImageInput, AudioInput, MaskInput, LatentInput +from .curve_types import CurvePoint, CurveInput, MonotoneCubicCurve, LinearCurve from .video_types import VideoInput __all__ = [ @@ -7,4 +8,8 @@ __all__ = [ "VideoInput", "MaskInput", "LatentInput", + "CurvePoint", + "CurveInput", + "MonotoneCubicCurve", + "LinearCurve", ] diff --git a/comfy_api/latest/_input/curve_types.py b/comfy_api/latest/_input/curve_types.py new file mode 100644 index 000000000..b6dd7adf9 --- /dev/null +++ b/comfy_api/latest/_input/curve_types.py @@ -0,0 +1,219 @@ +from __future__ import annotations + +import logging +import math +from abc import ABC, abstractmethod +import numpy as np + +logger = logging.getLogger(__name__) + + +CurvePoint = tuple[float, float] + + +class CurveInput(ABC): + """Abstract base class for curve inputs. + + Subclasses represent different curve representations (control-point + interpolation, analytical functions, LUT-based, etc.) while exposing a + uniform evaluation interface to downstream nodes. + """ + + @property + @abstractmethod + def points(self) -> list[CurvePoint]: + """The control points that define this curve.""" + + @abstractmethod + def interp(self, x: float) -> float: + """Evaluate the curve at a single *x* value in [0, 1].""" + + def interp_array(self, xs: np.ndarray) -> np.ndarray: + """Vectorised evaluation over a numpy array of x values. + + Subclasses should override this for better performance. The default + falls back to scalar ``interp`` calls. + """ + return np.fromiter((self.interp(float(x)) for x in xs), dtype=np.float64, count=len(xs)) + + def to_lut(self, size: int = 256) -> np.ndarray: + """Generate a float64 lookup table of *size* evenly-spaced samples in [0, 1].""" + return self.interp_array(np.linspace(0.0, 1.0, size)) + + @staticmethod + def from_raw(data) -> CurveInput: + """Convert raw curve data (dict or point list) to a CurveInput instance. + + Accepts: + - A ``CurveInput`` instance (returned as-is). + - A dict with ``"points"`` and optional ``"interpolation"`` keys. + - A bare list/sequence of ``(x, y)`` pairs (defaults to monotone cubic). + """ + if isinstance(data, CurveInput): + return data + if isinstance(data, dict): + raw_points = data["points"] + interpolation = data.get("interpolation", "monotone_cubic") + else: + raw_points = data + interpolation = "monotone_cubic" + points = [(float(x), float(y)) for x, y in raw_points] + if interpolation == "linear": + return LinearCurve(points) + if interpolation != "monotone_cubic": + logger.warning("Unknown curve interpolation %r, falling back to monotone_cubic", interpolation) + return MonotoneCubicCurve(points) + + +class MonotoneCubicCurve(CurveInput): + """Monotone cubic Hermite interpolation over control points. + + Mirrors the frontend ``createMonotoneInterpolator`` in + ``ComfyUI_frontend/src/components/curve/curveUtils.ts`` so that + backend evaluation matches the editor preview exactly. + + All heavy work (sorting, slope computation) happens once at construction. + ``interp_array`` is fully vectorised with numpy. + """ + + def __init__(self, control_points: list[CurvePoint]): + sorted_pts = sorted(control_points, key=lambda p: p[0]) + self._points = [(float(x), float(y)) for x, y in sorted_pts] + self._xs = np.array([p[0] for p in self._points], dtype=np.float64) + self._ys = np.array([p[1] for p in self._points], dtype=np.float64) + self._slopes = self._compute_slopes() + + @property + def points(self) -> list[CurvePoint]: + return list(self._points) + + def _compute_slopes(self) -> np.ndarray: + xs, ys = self._xs, self._ys + n = len(xs) + if n < 2: + return np.zeros(n, dtype=np.float64) + + dx = np.diff(xs) + dy = np.diff(ys) + dx_safe = np.where(dx == 0, 1.0, dx) + deltas = np.where(dx == 0, 0.0, dy / dx_safe) + + slopes = np.empty(n, dtype=np.float64) + slopes[0] = deltas[0] + slopes[-1] = deltas[-1] + for i in range(1, n - 1): + if deltas[i - 1] * deltas[i] <= 0: + slopes[i] = 0.0 + else: + slopes[i] = (deltas[i - 1] + deltas[i]) / 2 + + for i in range(n - 1): + if deltas[i] == 0: + slopes[i] = 0.0 + slopes[i + 1] = 0.0 + else: + alpha = slopes[i] / deltas[i] + beta = slopes[i + 1] / deltas[i] + s = alpha * alpha + beta * beta + if s > 9: + t = 3 / math.sqrt(s) + slopes[i] = t * alpha * deltas[i] + slopes[i + 1] = t * beta * deltas[i] + return slopes + + def interp(self, x: float) -> float: + xs, ys, slopes = self._xs, self._ys, self._slopes + n = len(xs) + if n == 0: + return 0.0 + if n == 1: + return float(ys[0]) + if x <= xs[0]: + return float(ys[0]) + if x >= xs[-1]: + return float(ys[-1]) + + hi = int(np.searchsorted(xs, x, side='right')) + hi = min(hi, n - 1) + lo = hi - 1 + + dx = xs[hi] - xs[lo] + if dx == 0: + return float(ys[lo]) + + t = (x - xs[lo]) / dx + t2 = t * t + t3 = t2 * t + h00 = 2 * t3 - 3 * t2 + 1 + h10 = t3 - 2 * t2 + t + h01 = -2 * t3 + 3 * t2 + h11 = t3 - t2 + return float(h00 * ys[lo] + h10 * dx * slopes[lo] + h01 * ys[hi] + h11 * dx * slopes[hi]) + + def interp_array(self, xs_in: np.ndarray) -> np.ndarray: + """Fully vectorised evaluation using numpy.""" + xs, ys, slopes = self._xs, self._ys, self._slopes + n = len(xs) + if n == 0: + return np.zeros_like(xs_in, dtype=np.float64) + if n == 1: + return np.full_like(xs_in, ys[0], dtype=np.float64) + + hi = np.searchsorted(xs, xs_in, side='right').clip(1, n - 1) + lo = hi - 1 + + dx = xs[hi] - xs[lo] + dx_safe = np.where(dx == 0, 1.0, dx) + t = np.where(dx == 0, 0.0, (xs_in - xs[lo]) / dx_safe) + t2 = t * t + t3 = t2 * t + + h00 = 2 * t3 - 3 * t2 + 1 + h10 = t3 - 2 * t2 + t + h01 = -2 * t3 + 3 * t2 + h11 = t3 - t2 + + result = h00 * ys[lo] + h10 * dx * slopes[lo] + h01 * ys[hi] + h11 * dx * slopes[hi] + result = np.where(xs_in <= xs[0], ys[0], result) + result = np.where(xs_in >= xs[-1], ys[-1], result) + return result + + def __repr__(self) -> str: + return f"MonotoneCubicCurve(points={self._points})" + + +class LinearCurve(CurveInput): + """Piecewise linear interpolation over control points. + + Mirrors the frontend ``createLinearInterpolator`` in + ``ComfyUI_frontend/src/components/curve/curveUtils.ts``. + """ + + def __init__(self, control_points: list[CurvePoint]): + sorted_pts = sorted(control_points, key=lambda p: p[0]) + self._points = [(float(x), float(y)) for x, y in sorted_pts] + self._xs = np.array([p[0] for p in self._points], dtype=np.float64) + self._ys = np.array([p[1] for p in self._points], dtype=np.float64) + + @property + def points(self) -> list[CurvePoint]: + return list(self._points) + + def interp(self, x: float) -> float: + xs, ys = self._xs, self._ys + n = len(xs) + if n == 0: + return 0.0 + if n == 1: + return float(ys[0]) + return float(np.interp(x, xs, ys)) + + def interp_array(self, xs_in: np.ndarray) -> np.ndarray: + if len(self._xs) == 0: + return np.zeros_like(xs_in, dtype=np.float64) + if len(self._xs) == 1: + return np.full_like(xs_in, self._ys[0], dtype=np.float64) + return np.interp(xs_in, self._xs, self._ys) + + def __repr__(self) -> str: + return f"LinearCurve(points={self._points})" diff --git a/comfy_api/latest/_io.py b/comfy_api/latest/_io.py index 7ca8f4e0c..1cbc8ed26 100644 --- a/comfy_api/latest/_io.py +++ b/comfy_api/latest/_io.py @@ -23,7 +23,7 @@ if TYPE_CHECKING: from comfy.samplers import CFGGuider, Sampler from comfy.sd import CLIP, VAE from comfy.sd import StyleModel as StyleModel_ - from comfy_api.input import VideoInput + from comfy_api.input import VideoInput, CurveInput as CurveInput_ from comfy_api.internal import (_ComfyNodeInternal, _NodeOutputInternal, classproperty, copy_class, first_real_override, is_class, prune_dict, shallow_clone_class) from comfy_execution.graph_utils import ExecutionBlocker @@ -1242,8 +1242,9 @@ class BoundingBox(ComfyTypeIO): @comfytype(io_type="CURVE") class Curve(ComfyTypeIO): - CurvePoint = tuple[float, float] - Type = list[CurvePoint] + from comfy_api.input import CurvePoint + if TYPE_CHECKING: + Type = CurveInput_ class Input(WidgetInput): def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, @@ -1252,6 +1253,18 @@ class Curve(ComfyTypeIO): if default is None: self.default = [(0.0, 0.0), (1.0, 1.0)] + def as_dict(self): + d = super().as_dict() + if self.default is not None: + d["default"] = {"points": [list(p) for p in self.default], "interpolation": "monotone_cubic"} + return d + + +@comfytype(io_type="HISTOGRAM") +class Histogram(ComfyTypeIO): + """A histogram represented as a list of bin counts.""" + Type = list[int] + DYNAMIC_INPUT_LOOKUP: dict[str, Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]] = {} def register_dynamic_input_func(io_type: str, func: Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]): @@ -2240,5 +2253,6 @@ __all__ = [ "PriceBadge", "BoundingBox", "Curve", + "Histogram", "NodeReplace", ] diff --git a/comfy_extras/nodes_curve.py b/comfy_extras/nodes_curve.py new file mode 100644 index 000000000..9016a84f9 --- /dev/null +++ b/comfy_extras/nodes_curve.py @@ -0,0 +1,42 @@ +from __future__ import annotations + +from comfy_api.latest import ComfyExtension, io +from comfy_api.input import CurveInput +from typing_extensions import override + + +class CurveEditor(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="CurveEditor", + display_name="Curve Editor", + category="utils", + inputs=[ + io.Curve.Input("curve"), + io.Histogram.Input("histogram", optional=True), + ], + outputs=[ + io.Curve.Output("curve"), + ], + ) + + @classmethod + def execute(cls, curve, histogram=None) -> io.NodeOutput: + result = CurveInput.from_raw(curve) + + ui = {} + if histogram is not None: + ui["histogram"] = histogram if isinstance(histogram, list) else list(histogram) + + return io.NodeOutput(result, ui=ui) if ui else io.NodeOutput(result) + + +class CurveExtension(ComfyExtension): + @override + async def get_node_list(self): + return [CurveEditor] + + +async def comfy_entrypoint(): + return CurveExtension() diff --git a/nodes.py b/nodes.py index 2c4650a20..79874d051 100644 --- a/nodes.py +++ b/nodes.py @@ -2455,6 +2455,7 @@ async def init_builtin_extra_nodes(): "nodes_sdpose.py", "nodes_math.py", "nodes_painter.py", + "nodes_curve.py", ] import_failed = []