mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-03-28 20:43:32 +08:00
CURVE node (#12757)
* CURVE node * remove curve to sigmas node * feat: add CurveInput ABC with MonotoneCubicCurve implementation (#12986) CurveInput is an abstract base class so future curve representations (bezier, LUT-based, analytical functions) can be added without breaking downstream nodes that type-check against CurveInput. MonotoneCubicCurve is the concrete implementation that: - Mirrors frontend createMonotoneInterpolator (curveUtils.ts) exactly - Pre-computes slopes as numpy arrays at construction time - Provides vectorised interp_array() using numpy for batch evaluation - interp() for single-value evaluation - to_lut() for generating lookup tables CurveEditor node wraps raw widget points in MonotoneCubicCurve. * linear curve * refactor: move CurveEditor to comfy_extras/nodes_curve.py with V3 schema * feat: add HISTOGRAM type and histogram support to CurveEditor * code improve --------- Co-authored-by: Christian Byrne <cbyrne@comfy.org>
This commit is contained in:
parent
c2862b24af
commit
8e73678dae
@ -5,6 +5,10 @@ from comfy_api.latest._input import (
|
|||||||
MaskInput,
|
MaskInput,
|
||||||
LatentInput,
|
LatentInput,
|
||||||
VideoInput,
|
VideoInput,
|
||||||
|
CurvePoint,
|
||||||
|
CurveInput,
|
||||||
|
MonotoneCubicCurve,
|
||||||
|
LinearCurve,
|
||||||
)
|
)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
@ -13,4 +17,8 @@ __all__ = [
|
|||||||
"MaskInput",
|
"MaskInput",
|
||||||
"LatentInput",
|
"LatentInput",
|
||||||
"VideoInput",
|
"VideoInput",
|
||||||
|
"CurvePoint",
|
||||||
|
"CurveInput",
|
||||||
|
"MonotoneCubicCurve",
|
||||||
|
"LinearCurve",
|
||||||
]
|
]
|
||||||
|
|||||||
@ -1,4 +1,5 @@
|
|||||||
from .basic_types import ImageInput, AudioInput, MaskInput, LatentInput
|
from .basic_types import ImageInput, AudioInput, MaskInput, LatentInput
|
||||||
|
from .curve_types import CurvePoint, CurveInput, MonotoneCubicCurve, LinearCurve
|
||||||
from .video_types import VideoInput
|
from .video_types import VideoInput
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
@ -7,4 +8,8 @@ __all__ = [
|
|||||||
"VideoInput",
|
"VideoInput",
|
||||||
"MaskInput",
|
"MaskInput",
|
||||||
"LatentInput",
|
"LatentInput",
|
||||||
|
"CurvePoint",
|
||||||
|
"CurveInput",
|
||||||
|
"MonotoneCubicCurve",
|
||||||
|
"LinearCurve",
|
||||||
]
|
]
|
||||||
|
|||||||
219
comfy_api/latest/_input/curve_types.py
Normal file
219
comfy_api/latest/_input/curve_types.py
Normal file
@ -0,0 +1,219 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import math
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
CurvePoint = tuple[float, float]
|
||||||
|
|
||||||
|
|
||||||
|
class CurveInput(ABC):
|
||||||
|
"""Abstract base class for curve inputs.
|
||||||
|
|
||||||
|
Subclasses represent different curve representations (control-point
|
||||||
|
interpolation, analytical functions, LUT-based, etc.) while exposing a
|
||||||
|
uniform evaluation interface to downstream nodes.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@property
|
||||||
|
@abstractmethod
|
||||||
|
def points(self) -> list[CurvePoint]:
|
||||||
|
"""The control points that define this curve."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def interp(self, x: float) -> float:
|
||||||
|
"""Evaluate the curve at a single *x* value in [0, 1]."""
|
||||||
|
|
||||||
|
def interp_array(self, xs: np.ndarray) -> np.ndarray:
|
||||||
|
"""Vectorised evaluation over a numpy array of x values.
|
||||||
|
|
||||||
|
Subclasses should override this for better performance. The default
|
||||||
|
falls back to scalar ``interp`` calls.
|
||||||
|
"""
|
||||||
|
return np.fromiter((self.interp(float(x)) for x in xs), dtype=np.float64, count=len(xs))
|
||||||
|
|
||||||
|
def to_lut(self, size: int = 256) -> np.ndarray:
|
||||||
|
"""Generate a float64 lookup table of *size* evenly-spaced samples in [0, 1]."""
|
||||||
|
return self.interp_array(np.linspace(0.0, 1.0, size))
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def from_raw(data) -> CurveInput:
|
||||||
|
"""Convert raw curve data (dict or point list) to a CurveInput instance.
|
||||||
|
|
||||||
|
Accepts:
|
||||||
|
- A ``CurveInput`` instance (returned as-is).
|
||||||
|
- A dict with ``"points"`` and optional ``"interpolation"`` keys.
|
||||||
|
- A bare list/sequence of ``(x, y)`` pairs (defaults to monotone cubic).
|
||||||
|
"""
|
||||||
|
if isinstance(data, CurveInput):
|
||||||
|
return data
|
||||||
|
if isinstance(data, dict):
|
||||||
|
raw_points = data["points"]
|
||||||
|
interpolation = data.get("interpolation", "monotone_cubic")
|
||||||
|
else:
|
||||||
|
raw_points = data
|
||||||
|
interpolation = "monotone_cubic"
|
||||||
|
points = [(float(x), float(y)) for x, y in raw_points]
|
||||||
|
if interpolation == "linear":
|
||||||
|
return LinearCurve(points)
|
||||||
|
if interpolation != "monotone_cubic":
|
||||||
|
logger.warning("Unknown curve interpolation %r, falling back to monotone_cubic", interpolation)
|
||||||
|
return MonotoneCubicCurve(points)
|
||||||
|
|
||||||
|
|
||||||
|
class MonotoneCubicCurve(CurveInput):
|
||||||
|
"""Monotone cubic Hermite interpolation over control points.
|
||||||
|
|
||||||
|
Mirrors the frontend ``createMonotoneInterpolator`` in
|
||||||
|
``ComfyUI_frontend/src/components/curve/curveUtils.ts`` so that
|
||||||
|
backend evaluation matches the editor preview exactly.
|
||||||
|
|
||||||
|
All heavy work (sorting, slope computation) happens once at construction.
|
||||||
|
``interp_array`` is fully vectorised with numpy.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, control_points: list[CurvePoint]):
|
||||||
|
sorted_pts = sorted(control_points, key=lambda p: p[0])
|
||||||
|
self._points = [(float(x), float(y)) for x, y in sorted_pts]
|
||||||
|
self._xs = np.array([p[0] for p in self._points], dtype=np.float64)
|
||||||
|
self._ys = np.array([p[1] for p in self._points], dtype=np.float64)
|
||||||
|
self._slopes = self._compute_slopes()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def points(self) -> list[CurvePoint]:
|
||||||
|
return list(self._points)
|
||||||
|
|
||||||
|
def _compute_slopes(self) -> np.ndarray:
|
||||||
|
xs, ys = self._xs, self._ys
|
||||||
|
n = len(xs)
|
||||||
|
if n < 2:
|
||||||
|
return np.zeros(n, dtype=np.float64)
|
||||||
|
|
||||||
|
dx = np.diff(xs)
|
||||||
|
dy = np.diff(ys)
|
||||||
|
dx_safe = np.where(dx == 0, 1.0, dx)
|
||||||
|
deltas = np.where(dx == 0, 0.0, dy / dx_safe)
|
||||||
|
|
||||||
|
slopes = np.empty(n, dtype=np.float64)
|
||||||
|
slopes[0] = deltas[0]
|
||||||
|
slopes[-1] = deltas[-1]
|
||||||
|
for i in range(1, n - 1):
|
||||||
|
if deltas[i - 1] * deltas[i] <= 0:
|
||||||
|
slopes[i] = 0.0
|
||||||
|
else:
|
||||||
|
slopes[i] = (deltas[i - 1] + deltas[i]) / 2
|
||||||
|
|
||||||
|
for i in range(n - 1):
|
||||||
|
if deltas[i] == 0:
|
||||||
|
slopes[i] = 0.0
|
||||||
|
slopes[i + 1] = 0.0
|
||||||
|
else:
|
||||||
|
alpha = slopes[i] / deltas[i]
|
||||||
|
beta = slopes[i + 1] / deltas[i]
|
||||||
|
s = alpha * alpha + beta * beta
|
||||||
|
if s > 9:
|
||||||
|
t = 3 / math.sqrt(s)
|
||||||
|
slopes[i] = t * alpha * deltas[i]
|
||||||
|
slopes[i + 1] = t * beta * deltas[i]
|
||||||
|
return slopes
|
||||||
|
|
||||||
|
def interp(self, x: float) -> float:
|
||||||
|
xs, ys, slopes = self._xs, self._ys, self._slopes
|
||||||
|
n = len(xs)
|
||||||
|
if n == 0:
|
||||||
|
return 0.0
|
||||||
|
if n == 1:
|
||||||
|
return float(ys[0])
|
||||||
|
if x <= xs[0]:
|
||||||
|
return float(ys[0])
|
||||||
|
if x >= xs[-1]:
|
||||||
|
return float(ys[-1])
|
||||||
|
|
||||||
|
hi = int(np.searchsorted(xs, x, side='right'))
|
||||||
|
hi = min(hi, n - 1)
|
||||||
|
lo = hi - 1
|
||||||
|
|
||||||
|
dx = xs[hi] - xs[lo]
|
||||||
|
if dx == 0:
|
||||||
|
return float(ys[lo])
|
||||||
|
|
||||||
|
t = (x - xs[lo]) / dx
|
||||||
|
t2 = t * t
|
||||||
|
t3 = t2 * t
|
||||||
|
h00 = 2 * t3 - 3 * t2 + 1
|
||||||
|
h10 = t3 - 2 * t2 + t
|
||||||
|
h01 = -2 * t3 + 3 * t2
|
||||||
|
h11 = t3 - t2
|
||||||
|
return float(h00 * ys[lo] + h10 * dx * slopes[lo] + h01 * ys[hi] + h11 * dx * slopes[hi])
|
||||||
|
|
||||||
|
def interp_array(self, xs_in: np.ndarray) -> np.ndarray:
|
||||||
|
"""Fully vectorised evaluation using numpy."""
|
||||||
|
xs, ys, slopes = self._xs, self._ys, self._slopes
|
||||||
|
n = len(xs)
|
||||||
|
if n == 0:
|
||||||
|
return np.zeros_like(xs_in, dtype=np.float64)
|
||||||
|
if n == 1:
|
||||||
|
return np.full_like(xs_in, ys[0], dtype=np.float64)
|
||||||
|
|
||||||
|
hi = np.searchsorted(xs, xs_in, side='right').clip(1, n - 1)
|
||||||
|
lo = hi - 1
|
||||||
|
|
||||||
|
dx = xs[hi] - xs[lo]
|
||||||
|
dx_safe = np.where(dx == 0, 1.0, dx)
|
||||||
|
t = np.where(dx == 0, 0.0, (xs_in - xs[lo]) / dx_safe)
|
||||||
|
t2 = t * t
|
||||||
|
t3 = t2 * t
|
||||||
|
|
||||||
|
h00 = 2 * t3 - 3 * t2 + 1
|
||||||
|
h10 = t3 - 2 * t2 + t
|
||||||
|
h01 = -2 * t3 + 3 * t2
|
||||||
|
h11 = t3 - t2
|
||||||
|
|
||||||
|
result = h00 * ys[lo] + h10 * dx * slopes[lo] + h01 * ys[hi] + h11 * dx * slopes[hi]
|
||||||
|
result = np.where(xs_in <= xs[0], ys[0], result)
|
||||||
|
result = np.where(xs_in >= xs[-1], ys[-1], result)
|
||||||
|
return result
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
return f"MonotoneCubicCurve(points={self._points})"
|
||||||
|
|
||||||
|
|
||||||
|
class LinearCurve(CurveInput):
|
||||||
|
"""Piecewise linear interpolation over control points.
|
||||||
|
|
||||||
|
Mirrors the frontend ``createLinearInterpolator`` in
|
||||||
|
``ComfyUI_frontend/src/components/curve/curveUtils.ts``.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, control_points: list[CurvePoint]):
|
||||||
|
sorted_pts = sorted(control_points, key=lambda p: p[0])
|
||||||
|
self._points = [(float(x), float(y)) for x, y in sorted_pts]
|
||||||
|
self._xs = np.array([p[0] for p in self._points], dtype=np.float64)
|
||||||
|
self._ys = np.array([p[1] for p in self._points], dtype=np.float64)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def points(self) -> list[CurvePoint]:
|
||||||
|
return list(self._points)
|
||||||
|
|
||||||
|
def interp(self, x: float) -> float:
|
||||||
|
xs, ys = self._xs, self._ys
|
||||||
|
n = len(xs)
|
||||||
|
if n == 0:
|
||||||
|
return 0.0
|
||||||
|
if n == 1:
|
||||||
|
return float(ys[0])
|
||||||
|
return float(np.interp(x, xs, ys))
|
||||||
|
|
||||||
|
def interp_array(self, xs_in: np.ndarray) -> np.ndarray:
|
||||||
|
if len(self._xs) == 0:
|
||||||
|
return np.zeros_like(xs_in, dtype=np.float64)
|
||||||
|
if len(self._xs) == 1:
|
||||||
|
return np.full_like(xs_in, self._ys[0], dtype=np.float64)
|
||||||
|
return np.interp(xs_in, self._xs, self._ys)
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
return f"LinearCurve(points={self._points})"
|
||||||
@ -23,7 +23,7 @@ if TYPE_CHECKING:
|
|||||||
from comfy.samplers import CFGGuider, Sampler
|
from comfy.samplers import CFGGuider, Sampler
|
||||||
from comfy.sd import CLIP, VAE
|
from comfy.sd import CLIP, VAE
|
||||||
from comfy.sd import StyleModel as StyleModel_
|
from comfy.sd import StyleModel as StyleModel_
|
||||||
from comfy_api.input import VideoInput
|
from comfy_api.input import VideoInput, CurveInput as CurveInput_
|
||||||
from comfy_api.internal import (_ComfyNodeInternal, _NodeOutputInternal, classproperty, copy_class, first_real_override, is_class,
|
from comfy_api.internal import (_ComfyNodeInternal, _NodeOutputInternal, classproperty, copy_class, first_real_override, is_class,
|
||||||
prune_dict, shallow_clone_class)
|
prune_dict, shallow_clone_class)
|
||||||
from comfy_execution.graph_utils import ExecutionBlocker
|
from comfy_execution.graph_utils import ExecutionBlocker
|
||||||
@ -1242,8 +1242,9 @@ class BoundingBox(ComfyTypeIO):
|
|||||||
|
|
||||||
@comfytype(io_type="CURVE")
|
@comfytype(io_type="CURVE")
|
||||||
class Curve(ComfyTypeIO):
|
class Curve(ComfyTypeIO):
|
||||||
CurvePoint = tuple[float, float]
|
from comfy_api.input import CurvePoint
|
||||||
Type = list[CurvePoint]
|
if TYPE_CHECKING:
|
||||||
|
Type = CurveInput_
|
||||||
|
|
||||||
class Input(WidgetInput):
|
class Input(WidgetInput):
|
||||||
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None,
|
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None,
|
||||||
@ -1252,6 +1253,18 @@ class Curve(ComfyTypeIO):
|
|||||||
if default is None:
|
if default is None:
|
||||||
self.default = [(0.0, 0.0), (1.0, 1.0)]
|
self.default = [(0.0, 0.0), (1.0, 1.0)]
|
||||||
|
|
||||||
|
def as_dict(self):
|
||||||
|
d = super().as_dict()
|
||||||
|
if self.default is not None:
|
||||||
|
d["default"] = {"points": [list(p) for p in self.default], "interpolation": "monotone_cubic"}
|
||||||
|
return d
|
||||||
|
|
||||||
|
|
||||||
|
@comfytype(io_type="HISTOGRAM")
|
||||||
|
class Histogram(ComfyTypeIO):
|
||||||
|
"""A histogram represented as a list of bin counts."""
|
||||||
|
Type = list[int]
|
||||||
|
|
||||||
|
|
||||||
DYNAMIC_INPUT_LOOKUP: dict[str, Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]] = {}
|
DYNAMIC_INPUT_LOOKUP: dict[str, Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]] = {}
|
||||||
def register_dynamic_input_func(io_type: str, func: Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]):
|
def register_dynamic_input_func(io_type: str, func: Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]):
|
||||||
@ -2240,5 +2253,6 @@ __all__ = [
|
|||||||
"PriceBadge",
|
"PriceBadge",
|
||||||
"BoundingBox",
|
"BoundingBox",
|
||||||
"Curve",
|
"Curve",
|
||||||
|
"Histogram",
|
||||||
"NodeReplace",
|
"NodeReplace",
|
||||||
]
|
]
|
||||||
|
|||||||
42
comfy_extras/nodes_curve.py
Normal file
42
comfy_extras/nodes_curve.py
Normal file
@ -0,0 +1,42 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from comfy_api.latest import ComfyExtension, io
|
||||||
|
from comfy_api.input import CurveInput
|
||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
|
|
||||||
|
class CurveEditor(io.ComfyNode):
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls):
|
||||||
|
return io.Schema(
|
||||||
|
node_id="CurveEditor",
|
||||||
|
display_name="Curve Editor",
|
||||||
|
category="utils",
|
||||||
|
inputs=[
|
||||||
|
io.Curve.Input("curve"),
|
||||||
|
io.Histogram.Input("histogram", optional=True),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.Curve.Output("curve"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def execute(cls, curve, histogram=None) -> io.NodeOutput:
|
||||||
|
result = CurveInput.from_raw(curve)
|
||||||
|
|
||||||
|
ui = {}
|
||||||
|
if histogram is not None:
|
||||||
|
ui["histogram"] = histogram if isinstance(histogram, list) else list(histogram)
|
||||||
|
|
||||||
|
return io.NodeOutput(result, ui=ui) if ui else io.NodeOutput(result)
|
||||||
|
|
||||||
|
|
||||||
|
class CurveExtension(ComfyExtension):
|
||||||
|
@override
|
||||||
|
async def get_node_list(self):
|
||||||
|
return [CurveEditor]
|
||||||
|
|
||||||
|
|
||||||
|
async def comfy_entrypoint():
|
||||||
|
return CurveExtension()
|
||||||
Loading…
Reference in New Issue
Block a user