mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-19 06:52:31 +08:00
Merge upstream/master, keep local README.md
This commit is contained in:
commit
d4a90e2efb
File diff suppressed because one or more lines are too long
@ -55,6 +55,7 @@ total_vram = 0
|
|||||||
|
|
||||||
# Training Related State
|
# Training Related State
|
||||||
in_training = False
|
in_training = False
|
||||||
|
training_fp8_bwd = False
|
||||||
|
|
||||||
|
|
||||||
def get_supported_float8_types():
|
def get_supported_float8_types():
|
||||||
|
|||||||
65
comfy/ops.py
65
comfy/ops.py
@ -777,8 +777,16 @@ from .quant_ops import (
|
|||||||
|
|
||||||
|
|
||||||
class QuantLinearFunc(torch.autograd.Function):
|
class QuantLinearFunc(torch.autograd.Function):
|
||||||
"""Custom autograd function for quantized linear: quantized forward, compute_dtype backward.
|
"""Custom autograd function for quantized linear: quantized forward, optionally FP8 backward.
|
||||||
Handles any input rank by flattening to 2D for matmul and restoring shape after.
|
|
||||||
|
When training_fp8_bwd is enabled:
|
||||||
|
- Forward: quantize input per layout (FP8/NVFP4), use quantized matmul
|
||||||
|
- Backward: all matmuls use FP8 tensor cores via torch.mm dispatch
|
||||||
|
- Cached input is FP8 (half the memory of bf16)
|
||||||
|
|
||||||
|
When training_fp8_bwd is disabled:
|
||||||
|
- Forward: quantize input per layout, use quantized matmul
|
||||||
|
- Backward: dequantize weight to compute_dtype, use standard matmul
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@ -786,7 +794,7 @@ class QuantLinearFunc(torch.autograd.Function):
|
|||||||
input_shape = input_float.shape
|
input_shape = input_float.shape
|
||||||
inp = input_float.detach().flatten(0, -2) # zero-cost view to 2D
|
inp = input_float.detach().flatten(0, -2) # zero-cost view to 2D
|
||||||
|
|
||||||
# Quantize input (same as inference path)
|
# Quantize input for forward (same layout as weight)
|
||||||
if layout_type is not None:
|
if layout_type is not None:
|
||||||
q_input = QuantizedTensor.from_float(inp, layout_type, scale=input_scale)
|
q_input = QuantizedTensor.from_float(inp, layout_type, scale=input_scale)
|
||||||
else:
|
else:
|
||||||
@ -797,43 +805,68 @@ class QuantLinearFunc(torch.autograd.Function):
|
|||||||
|
|
||||||
output = torch.nn.functional.linear(q_input, w, b)
|
output = torch.nn.functional.linear(q_input, w, b)
|
||||||
|
|
||||||
# Restore original input shape
|
# Unflatten output to match original input shape
|
||||||
if len(input_shape) > 2:
|
if len(input_shape) > 2:
|
||||||
output = output.unflatten(0, input_shape[:-1])
|
output = output.unflatten(0, input_shape[:-1])
|
||||||
|
|
||||||
ctx.save_for_backward(input_float, weight)
|
# Save for backward
|
||||||
ctx.input_shape = input_shape
|
ctx.input_shape = input_shape
|
||||||
ctx.has_bias = bias is not None
|
ctx.has_bias = bias is not None
|
||||||
ctx.compute_dtype = compute_dtype
|
ctx.compute_dtype = compute_dtype
|
||||||
ctx.weight_requires_grad = weight.requires_grad
|
ctx.weight_requires_grad = weight.requires_grad
|
||||||
|
ctx.fp8_bwd = comfy.model_management.training_fp8_bwd
|
||||||
|
|
||||||
|
if ctx.fp8_bwd:
|
||||||
|
# Cache FP8 quantized input — half the memory of bf16
|
||||||
|
if isinstance(q_input, QuantizedTensor) and layout_type.startswith('TensorCoreFP8'):
|
||||||
|
ctx.q_input = q_input # already FP8, reuse
|
||||||
|
else:
|
||||||
|
# NVFP4 or other layout — quantize input to FP8 for backward
|
||||||
|
ctx.q_input = QuantizedTensor.from_float(inp, "TensorCoreFP8E4M3Layout")
|
||||||
|
ctx.save_for_backward(weight)
|
||||||
|
else:
|
||||||
|
ctx.q_input = None
|
||||||
|
ctx.save_for_backward(input_float, weight)
|
||||||
|
|
||||||
return output
|
return output
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@torch.autograd.function.once_differentiable
|
@torch.autograd.function.once_differentiable
|
||||||
def backward(ctx, grad_output):
|
def backward(ctx, grad_output):
|
||||||
input_float, weight = ctx.saved_tensors
|
|
||||||
compute_dtype = ctx.compute_dtype
|
compute_dtype = ctx.compute_dtype
|
||||||
grad_2d = grad_output.flatten(0, -2).to(compute_dtype)
|
grad_2d = grad_output.flatten(0, -2).to(compute_dtype)
|
||||||
|
|
||||||
# Dequantize weight to compute dtype for backward matmul
|
# Value casting — only difference between fp8 and non-fp8 paths
|
||||||
if isinstance(weight, QuantizedTensor):
|
if ctx.fp8_bwd:
|
||||||
weight_f = weight.dequantize().to(compute_dtype)
|
weight, = ctx.saved_tensors
|
||||||
|
# Wrap as FP8 QuantizedTensors → torch.mm dispatches to _scaled_mm
|
||||||
|
grad_mm = QuantizedTensor.from_float(grad_2d, "TensorCoreFP8E5M2Layout")
|
||||||
|
if isinstance(weight, QuantizedTensor) and weight._layout_cls.startswith("TensorCoreFP8"):
|
||||||
|
weight_mm = weight
|
||||||
|
elif isinstance(weight, QuantizedTensor):
|
||||||
|
weight_mm = QuantizedTensor.from_float(weight.dequantize().to(compute_dtype), "TensorCoreFP8E4M3Layout")
|
||||||
|
else:
|
||||||
|
weight_mm = QuantizedTensor.from_float(weight.to(compute_dtype), "TensorCoreFP8E4M3Layout")
|
||||||
|
input_mm = ctx.q_input
|
||||||
else:
|
else:
|
||||||
weight_f = weight.to(compute_dtype)
|
input_float, weight = ctx.saved_tensors
|
||||||
|
# Standard tensors → torch.mm does regular matmul
|
||||||
|
grad_mm = grad_2d
|
||||||
|
if isinstance(weight, QuantizedTensor):
|
||||||
|
weight_mm = weight.dequantize().to(compute_dtype)
|
||||||
|
else:
|
||||||
|
weight_mm = weight.to(compute_dtype)
|
||||||
|
input_mm = input_float.flatten(0, -2).to(compute_dtype) if ctx.weight_requires_grad else None
|
||||||
|
|
||||||
# grad_input = grad_output @ weight
|
# Computation — same for both paths, dispatch handles the rest
|
||||||
grad_input = torch.mm(grad_2d, weight_f)
|
grad_input = torch.mm(grad_mm, weight_mm)
|
||||||
if len(ctx.input_shape) > 2:
|
if len(ctx.input_shape) > 2:
|
||||||
grad_input = grad_input.unflatten(0, ctx.input_shape[:-1])
|
grad_input = grad_input.unflatten(0, ctx.input_shape[:-1])
|
||||||
|
|
||||||
# grad_weight (only if weight requires grad, typically frozen for quantized training)
|
|
||||||
grad_weight = None
|
grad_weight = None
|
||||||
if ctx.weight_requires_grad:
|
if ctx.weight_requires_grad:
|
||||||
input_f = input_float.flatten(0, -2).to(compute_dtype)
|
grad_weight = torch.mm(grad_mm.t(), input_mm)
|
||||||
grad_weight = torch.mm(grad_2d.t(), input_f)
|
|
||||||
|
|
||||||
# grad_bias
|
|
||||||
grad_bias = None
|
grad_bias = None
|
||||||
if ctx.has_bias:
|
if ctx.has_bias:
|
||||||
grad_bias = grad_2d.sum(dim=0)
|
grad_bias = grad_2d.sum(dim=0)
|
||||||
|
|||||||
@ -5,6 +5,10 @@ from comfy_api.latest._input import (
|
|||||||
MaskInput,
|
MaskInput,
|
||||||
LatentInput,
|
LatentInput,
|
||||||
VideoInput,
|
VideoInput,
|
||||||
|
CurvePoint,
|
||||||
|
CurveInput,
|
||||||
|
MonotoneCubicCurve,
|
||||||
|
LinearCurve,
|
||||||
)
|
)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
@ -13,4 +17,8 @@ __all__ = [
|
|||||||
"MaskInput",
|
"MaskInput",
|
||||||
"LatentInput",
|
"LatentInput",
|
||||||
"VideoInput",
|
"VideoInput",
|
||||||
|
"CurvePoint",
|
||||||
|
"CurveInput",
|
||||||
|
"MonotoneCubicCurve",
|
||||||
|
"LinearCurve",
|
||||||
]
|
]
|
||||||
|
|||||||
@ -1,4 +1,5 @@
|
|||||||
from .basic_types import ImageInput, AudioInput, MaskInput, LatentInput
|
from .basic_types import ImageInput, AudioInput, MaskInput, LatentInput
|
||||||
|
from .curve_types import CurvePoint, CurveInput, MonotoneCubicCurve, LinearCurve
|
||||||
from .video_types import VideoInput
|
from .video_types import VideoInput
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
@ -7,4 +8,8 @@ __all__ = [
|
|||||||
"VideoInput",
|
"VideoInput",
|
||||||
"MaskInput",
|
"MaskInput",
|
||||||
"LatentInput",
|
"LatentInput",
|
||||||
|
"CurvePoint",
|
||||||
|
"CurveInput",
|
||||||
|
"MonotoneCubicCurve",
|
||||||
|
"LinearCurve",
|
||||||
]
|
]
|
||||||
|
|||||||
219
comfy_api/latest/_input/curve_types.py
Normal file
219
comfy_api/latest/_input/curve_types.py
Normal file
@ -0,0 +1,219 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import math
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
CurvePoint = tuple[float, float]
|
||||||
|
|
||||||
|
|
||||||
|
class CurveInput(ABC):
|
||||||
|
"""Abstract base class for curve inputs.
|
||||||
|
|
||||||
|
Subclasses represent different curve representations (control-point
|
||||||
|
interpolation, analytical functions, LUT-based, etc.) while exposing a
|
||||||
|
uniform evaluation interface to downstream nodes.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@property
|
||||||
|
@abstractmethod
|
||||||
|
def points(self) -> list[CurvePoint]:
|
||||||
|
"""The control points that define this curve."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def interp(self, x: float) -> float:
|
||||||
|
"""Evaluate the curve at a single *x* value in [0, 1]."""
|
||||||
|
|
||||||
|
def interp_array(self, xs: np.ndarray) -> np.ndarray:
|
||||||
|
"""Vectorised evaluation over a numpy array of x values.
|
||||||
|
|
||||||
|
Subclasses should override this for better performance. The default
|
||||||
|
falls back to scalar ``interp`` calls.
|
||||||
|
"""
|
||||||
|
return np.fromiter((self.interp(float(x)) for x in xs), dtype=np.float64, count=len(xs))
|
||||||
|
|
||||||
|
def to_lut(self, size: int = 256) -> np.ndarray:
|
||||||
|
"""Generate a float64 lookup table of *size* evenly-spaced samples in [0, 1]."""
|
||||||
|
return self.interp_array(np.linspace(0.0, 1.0, size))
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def from_raw(data) -> CurveInput:
|
||||||
|
"""Convert raw curve data (dict or point list) to a CurveInput instance.
|
||||||
|
|
||||||
|
Accepts:
|
||||||
|
- A ``CurveInput`` instance (returned as-is).
|
||||||
|
- A dict with ``"points"`` and optional ``"interpolation"`` keys.
|
||||||
|
- A bare list/sequence of ``(x, y)`` pairs (defaults to monotone cubic).
|
||||||
|
"""
|
||||||
|
if isinstance(data, CurveInput):
|
||||||
|
return data
|
||||||
|
if isinstance(data, dict):
|
||||||
|
raw_points = data["points"]
|
||||||
|
interpolation = data.get("interpolation", "monotone_cubic")
|
||||||
|
else:
|
||||||
|
raw_points = data
|
||||||
|
interpolation = "monotone_cubic"
|
||||||
|
points = [(float(x), float(y)) for x, y in raw_points]
|
||||||
|
if interpolation == "linear":
|
||||||
|
return LinearCurve(points)
|
||||||
|
if interpolation != "monotone_cubic":
|
||||||
|
logger.warning("Unknown curve interpolation %r, falling back to monotone_cubic", interpolation)
|
||||||
|
return MonotoneCubicCurve(points)
|
||||||
|
|
||||||
|
|
||||||
|
class MonotoneCubicCurve(CurveInput):
|
||||||
|
"""Monotone cubic Hermite interpolation over control points.
|
||||||
|
|
||||||
|
Mirrors the frontend ``createMonotoneInterpolator`` in
|
||||||
|
``ComfyUI_frontend/src/components/curve/curveUtils.ts`` so that
|
||||||
|
backend evaluation matches the editor preview exactly.
|
||||||
|
|
||||||
|
All heavy work (sorting, slope computation) happens once at construction.
|
||||||
|
``interp_array`` is fully vectorised with numpy.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, control_points: list[CurvePoint]):
|
||||||
|
sorted_pts = sorted(control_points, key=lambda p: p[0])
|
||||||
|
self._points = [(float(x), float(y)) for x, y in sorted_pts]
|
||||||
|
self._xs = np.array([p[0] for p in self._points], dtype=np.float64)
|
||||||
|
self._ys = np.array([p[1] for p in self._points], dtype=np.float64)
|
||||||
|
self._slopes = self._compute_slopes()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def points(self) -> list[CurvePoint]:
|
||||||
|
return list(self._points)
|
||||||
|
|
||||||
|
def _compute_slopes(self) -> np.ndarray:
|
||||||
|
xs, ys = self._xs, self._ys
|
||||||
|
n = len(xs)
|
||||||
|
if n < 2:
|
||||||
|
return np.zeros(n, dtype=np.float64)
|
||||||
|
|
||||||
|
dx = np.diff(xs)
|
||||||
|
dy = np.diff(ys)
|
||||||
|
dx_safe = np.where(dx == 0, 1.0, dx)
|
||||||
|
deltas = np.where(dx == 0, 0.0, dy / dx_safe)
|
||||||
|
|
||||||
|
slopes = np.empty(n, dtype=np.float64)
|
||||||
|
slopes[0] = deltas[0]
|
||||||
|
slopes[-1] = deltas[-1]
|
||||||
|
for i in range(1, n - 1):
|
||||||
|
if deltas[i - 1] * deltas[i] <= 0:
|
||||||
|
slopes[i] = 0.0
|
||||||
|
else:
|
||||||
|
slopes[i] = (deltas[i - 1] + deltas[i]) / 2
|
||||||
|
|
||||||
|
for i in range(n - 1):
|
||||||
|
if deltas[i] == 0:
|
||||||
|
slopes[i] = 0.0
|
||||||
|
slopes[i + 1] = 0.0
|
||||||
|
else:
|
||||||
|
alpha = slopes[i] / deltas[i]
|
||||||
|
beta = slopes[i + 1] / deltas[i]
|
||||||
|
s = alpha * alpha + beta * beta
|
||||||
|
if s > 9:
|
||||||
|
t = 3 / math.sqrt(s)
|
||||||
|
slopes[i] = t * alpha * deltas[i]
|
||||||
|
slopes[i + 1] = t * beta * deltas[i]
|
||||||
|
return slopes
|
||||||
|
|
||||||
|
def interp(self, x: float) -> float:
|
||||||
|
xs, ys, slopes = self._xs, self._ys, self._slopes
|
||||||
|
n = len(xs)
|
||||||
|
if n == 0:
|
||||||
|
return 0.0
|
||||||
|
if n == 1:
|
||||||
|
return float(ys[0])
|
||||||
|
if x <= xs[0]:
|
||||||
|
return float(ys[0])
|
||||||
|
if x >= xs[-1]:
|
||||||
|
return float(ys[-1])
|
||||||
|
|
||||||
|
hi = int(np.searchsorted(xs, x, side='right'))
|
||||||
|
hi = min(hi, n - 1)
|
||||||
|
lo = hi - 1
|
||||||
|
|
||||||
|
dx = xs[hi] - xs[lo]
|
||||||
|
if dx == 0:
|
||||||
|
return float(ys[lo])
|
||||||
|
|
||||||
|
t = (x - xs[lo]) / dx
|
||||||
|
t2 = t * t
|
||||||
|
t3 = t2 * t
|
||||||
|
h00 = 2 * t3 - 3 * t2 + 1
|
||||||
|
h10 = t3 - 2 * t2 + t
|
||||||
|
h01 = -2 * t3 + 3 * t2
|
||||||
|
h11 = t3 - t2
|
||||||
|
return float(h00 * ys[lo] + h10 * dx * slopes[lo] + h01 * ys[hi] + h11 * dx * slopes[hi])
|
||||||
|
|
||||||
|
def interp_array(self, xs_in: np.ndarray) -> np.ndarray:
|
||||||
|
"""Fully vectorised evaluation using numpy."""
|
||||||
|
xs, ys, slopes = self._xs, self._ys, self._slopes
|
||||||
|
n = len(xs)
|
||||||
|
if n == 0:
|
||||||
|
return np.zeros_like(xs_in, dtype=np.float64)
|
||||||
|
if n == 1:
|
||||||
|
return np.full_like(xs_in, ys[0], dtype=np.float64)
|
||||||
|
|
||||||
|
hi = np.searchsorted(xs, xs_in, side='right').clip(1, n - 1)
|
||||||
|
lo = hi - 1
|
||||||
|
|
||||||
|
dx = xs[hi] - xs[lo]
|
||||||
|
dx_safe = np.where(dx == 0, 1.0, dx)
|
||||||
|
t = np.where(dx == 0, 0.0, (xs_in - xs[lo]) / dx_safe)
|
||||||
|
t2 = t * t
|
||||||
|
t3 = t2 * t
|
||||||
|
|
||||||
|
h00 = 2 * t3 - 3 * t2 + 1
|
||||||
|
h10 = t3 - 2 * t2 + t
|
||||||
|
h01 = -2 * t3 + 3 * t2
|
||||||
|
h11 = t3 - t2
|
||||||
|
|
||||||
|
result = h00 * ys[lo] + h10 * dx * slopes[lo] + h01 * ys[hi] + h11 * dx * slopes[hi]
|
||||||
|
result = np.where(xs_in <= xs[0], ys[0], result)
|
||||||
|
result = np.where(xs_in >= xs[-1], ys[-1], result)
|
||||||
|
return result
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
return f"MonotoneCubicCurve(points={self._points})"
|
||||||
|
|
||||||
|
|
||||||
|
class LinearCurve(CurveInput):
|
||||||
|
"""Piecewise linear interpolation over control points.
|
||||||
|
|
||||||
|
Mirrors the frontend ``createLinearInterpolator`` in
|
||||||
|
``ComfyUI_frontend/src/components/curve/curveUtils.ts``.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, control_points: list[CurvePoint]):
|
||||||
|
sorted_pts = sorted(control_points, key=lambda p: p[0])
|
||||||
|
self._points = [(float(x), float(y)) for x, y in sorted_pts]
|
||||||
|
self._xs = np.array([p[0] for p in self._points], dtype=np.float64)
|
||||||
|
self._ys = np.array([p[1] for p in self._points], dtype=np.float64)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def points(self) -> list[CurvePoint]:
|
||||||
|
return list(self._points)
|
||||||
|
|
||||||
|
def interp(self, x: float) -> float:
|
||||||
|
xs, ys = self._xs, self._ys
|
||||||
|
n = len(xs)
|
||||||
|
if n == 0:
|
||||||
|
return 0.0
|
||||||
|
if n == 1:
|
||||||
|
return float(ys[0])
|
||||||
|
return float(np.interp(x, xs, ys))
|
||||||
|
|
||||||
|
def interp_array(self, xs_in: np.ndarray) -> np.ndarray:
|
||||||
|
if len(self._xs) == 0:
|
||||||
|
return np.zeros_like(xs_in, dtype=np.float64)
|
||||||
|
if len(self._xs) == 1:
|
||||||
|
return np.full_like(xs_in, self._ys[0], dtype=np.float64)
|
||||||
|
return np.interp(xs_in, self._xs, self._ys)
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
return f"LinearCurve(points={self._points})"
|
||||||
@ -23,7 +23,7 @@ if TYPE_CHECKING:
|
|||||||
from comfy.samplers import CFGGuider, Sampler
|
from comfy.samplers import CFGGuider, Sampler
|
||||||
from comfy.sd import CLIP, VAE
|
from comfy.sd import CLIP, VAE
|
||||||
from comfy.sd import StyleModel as StyleModel_
|
from comfy.sd import StyleModel as StyleModel_
|
||||||
from comfy_api.input import VideoInput
|
from comfy_api.input import VideoInput, CurveInput as CurveInput_
|
||||||
from comfy_api.internal import (_ComfyNodeInternal, _NodeOutputInternal, classproperty, copy_class, first_real_override, is_class,
|
from comfy_api.internal import (_ComfyNodeInternal, _NodeOutputInternal, classproperty, copy_class, first_real_override, is_class,
|
||||||
prune_dict, shallow_clone_class)
|
prune_dict, shallow_clone_class)
|
||||||
from comfy_execution.graph_utils import ExecutionBlocker
|
from comfy_execution.graph_utils import ExecutionBlocker
|
||||||
@ -1242,8 +1242,9 @@ class BoundingBox(ComfyTypeIO):
|
|||||||
|
|
||||||
@comfytype(io_type="CURVE")
|
@comfytype(io_type="CURVE")
|
||||||
class Curve(ComfyTypeIO):
|
class Curve(ComfyTypeIO):
|
||||||
CurvePoint = tuple[float, float]
|
from comfy_api.input import CurvePoint
|
||||||
Type = list[CurvePoint]
|
if TYPE_CHECKING:
|
||||||
|
Type = CurveInput_
|
||||||
|
|
||||||
class Input(WidgetInput):
|
class Input(WidgetInput):
|
||||||
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None,
|
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None,
|
||||||
@ -1252,6 +1253,18 @@ class Curve(ComfyTypeIO):
|
|||||||
if default is None:
|
if default is None:
|
||||||
self.default = [(0.0, 0.0), (1.0, 1.0)]
|
self.default = [(0.0, 0.0), (1.0, 1.0)]
|
||||||
|
|
||||||
|
def as_dict(self):
|
||||||
|
d = super().as_dict()
|
||||||
|
if self.default is not None:
|
||||||
|
d["default"] = {"points": [list(p) for p in self.default], "interpolation": "monotone_cubic"}
|
||||||
|
return d
|
||||||
|
|
||||||
|
|
||||||
|
@comfytype(io_type="HISTOGRAM")
|
||||||
|
class Histogram(ComfyTypeIO):
|
||||||
|
"""A histogram represented as a list of bin counts."""
|
||||||
|
Type = list[int]
|
||||||
|
|
||||||
|
|
||||||
DYNAMIC_INPUT_LOOKUP: dict[str, Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]] = {}
|
DYNAMIC_INPUT_LOOKUP: dict[str, Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]] = {}
|
||||||
def register_dynamic_input_func(io_type: str, func: Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]):
|
def register_dynamic_input_func(io_type: str, func: Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]):
|
||||||
@ -2240,5 +2253,6 @@ __all__ = [
|
|||||||
"PriceBadge",
|
"PriceBadge",
|
||||||
"BoundingBox",
|
"BoundingBox",
|
||||||
"Curve",
|
"Curve",
|
||||||
|
"Histogram",
|
||||||
"NodeReplace",
|
"NodeReplace",
|
||||||
]
|
]
|
||||||
|
|||||||
@ -29,13 +29,21 @@ class ImageEditRequest(BaseModel):
|
|||||||
class VideoGenerationRequest(BaseModel):
|
class VideoGenerationRequest(BaseModel):
|
||||||
model: str = Field(...)
|
model: str = Field(...)
|
||||||
prompt: str = Field(...)
|
prompt: str = Field(...)
|
||||||
image: InputUrlObject | None = Field(...)
|
image: InputUrlObject | None = Field(None)
|
||||||
|
reference_images: list[InputUrlObject] | None = Field(None)
|
||||||
duration: int = Field(...)
|
duration: int = Field(...)
|
||||||
aspect_ratio: str | None = Field(...)
|
aspect_ratio: str | None = Field(...)
|
||||||
resolution: str = Field(...)
|
resolution: str = Field(...)
|
||||||
seed: int = Field(...)
|
seed: int = Field(...)
|
||||||
|
|
||||||
|
|
||||||
|
class VideoExtensionRequest(BaseModel):
|
||||||
|
prompt: str = Field(...)
|
||||||
|
video: InputUrlObject = Field(...)
|
||||||
|
duration: int = Field(default=6)
|
||||||
|
model: str | None = Field(default=None)
|
||||||
|
|
||||||
|
|
||||||
class VideoEditRequest(BaseModel):
|
class VideoEditRequest(BaseModel):
|
||||||
model: str = Field(...)
|
model: str = Field(...)
|
||||||
prompt: str = Field(...)
|
prompt: str = Field(...)
|
||||||
|
|||||||
@ -8,6 +8,7 @@ from comfy_api_nodes.apis.grok import (
|
|||||||
ImageGenerationResponse,
|
ImageGenerationResponse,
|
||||||
InputUrlObject,
|
InputUrlObject,
|
||||||
VideoEditRequest,
|
VideoEditRequest,
|
||||||
|
VideoExtensionRequest,
|
||||||
VideoGenerationRequest,
|
VideoGenerationRequest,
|
||||||
VideoGenerationResponse,
|
VideoGenerationResponse,
|
||||||
VideoStatusResponse,
|
VideoStatusResponse,
|
||||||
@ -21,6 +22,7 @@ from comfy_api_nodes.util import (
|
|||||||
poll_op,
|
poll_op,
|
||||||
sync_op,
|
sync_op,
|
||||||
tensor_to_base64_string,
|
tensor_to_base64_string,
|
||||||
|
upload_images_to_comfyapi,
|
||||||
upload_video_to_comfyapi,
|
upload_video_to_comfyapi,
|
||||||
validate_string,
|
validate_string,
|
||||||
validate_video_duration,
|
validate_video_duration,
|
||||||
@ -33,6 +35,13 @@ def _extract_grok_price(response) -> float | None:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_grok_video_price(response) -> float | None:
|
||||||
|
price = _extract_grok_price(response)
|
||||||
|
if price is not None:
|
||||||
|
return price * 1.43
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
class GrokImageNode(IO.ComfyNode):
|
class GrokImageNode(IO.ComfyNode):
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -354,6 +363,8 @@ class GrokVideoNode(IO.ComfyNode):
|
|||||||
seed: int,
|
seed: int,
|
||||||
image: Input.Image | None = None,
|
image: Input.Image | None = None,
|
||||||
) -> IO.NodeOutput:
|
) -> IO.NodeOutput:
|
||||||
|
if model == "grok-imagine-video-beta":
|
||||||
|
model = "grok-imagine-video"
|
||||||
image_url = None
|
image_url = None
|
||||||
if image is not None:
|
if image is not None:
|
||||||
if get_number_of_images(image) != 1:
|
if get_number_of_images(image) != 1:
|
||||||
@ -462,6 +473,244 @@ class GrokVideoEditNode(IO.ComfyNode):
|
|||||||
return IO.NodeOutput(await download_url_to_video_output(response.video.url))
|
return IO.NodeOutput(await download_url_to_video_output(response.video.url))
|
||||||
|
|
||||||
|
|
||||||
|
class GrokVideoReferenceNode(IO.ComfyNode):
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls):
|
||||||
|
return IO.Schema(
|
||||||
|
node_id="GrokVideoReferenceNode",
|
||||||
|
display_name="Grok Reference-to-Video",
|
||||||
|
category="api node/video/Grok",
|
||||||
|
description="Generate video guided by reference images as style and content references.",
|
||||||
|
inputs=[
|
||||||
|
IO.String.Input(
|
||||||
|
"prompt",
|
||||||
|
multiline=True,
|
||||||
|
tooltip="Text description of the desired video.",
|
||||||
|
),
|
||||||
|
IO.DynamicCombo.Input(
|
||||||
|
"model",
|
||||||
|
options=[
|
||||||
|
IO.DynamicCombo.Option(
|
||||||
|
"grok-imagine-video",
|
||||||
|
[
|
||||||
|
IO.Autogrow.Input(
|
||||||
|
"reference_images",
|
||||||
|
template=IO.Autogrow.TemplatePrefix(
|
||||||
|
IO.Image.Input("image"),
|
||||||
|
prefix="reference_",
|
||||||
|
min=1,
|
||||||
|
max=7,
|
||||||
|
),
|
||||||
|
tooltip="Up to 7 reference images to guide the video generation.",
|
||||||
|
),
|
||||||
|
IO.Combo.Input(
|
||||||
|
"resolution",
|
||||||
|
options=["480p", "720p"],
|
||||||
|
tooltip="The resolution of the output video.",
|
||||||
|
),
|
||||||
|
IO.Combo.Input(
|
||||||
|
"aspect_ratio",
|
||||||
|
options=["16:9", "4:3", "3:2", "1:1", "2:3", "3:4", "9:16"],
|
||||||
|
tooltip="The aspect ratio of the output video.",
|
||||||
|
),
|
||||||
|
IO.Int.Input(
|
||||||
|
"duration",
|
||||||
|
default=6,
|
||||||
|
min=2,
|
||||||
|
max=10,
|
||||||
|
step=1,
|
||||||
|
tooltip="The duration of the output video in seconds.",
|
||||||
|
display_mode=IO.NumberDisplay.slider,
|
||||||
|
),
|
||||||
|
],
|
||||||
|
),
|
||||||
|
],
|
||||||
|
tooltip="The model to use for video generation.",
|
||||||
|
),
|
||||||
|
IO.Int.Input(
|
||||||
|
"seed",
|
||||||
|
default=0,
|
||||||
|
min=0,
|
||||||
|
max=2147483647,
|
||||||
|
step=1,
|
||||||
|
display_mode=IO.NumberDisplay.number,
|
||||||
|
control_after_generate=True,
|
||||||
|
tooltip="Seed to determine if node should re-run; "
|
||||||
|
"actual results are nondeterministic regardless of seed.",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
IO.Video.Output(),
|
||||||
|
],
|
||||||
|
hidden=[
|
||||||
|
IO.Hidden.auth_token_comfy_org,
|
||||||
|
IO.Hidden.api_key_comfy_org,
|
||||||
|
IO.Hidden.unique_id,
|
||||||
|
],
|
||||||
|
is_api_node=True,
|
||||||
|
price_badge=IO.PriceBadge(
|
||||||
|
depends_on=IO.PriceBadgeDepends(
|
||||||
|
widgets=["model.duration", "model.resolution"],
|
||||||
|
input_groups=["model.reference_images"],
|
||||||
|
),
|
||||||
|
expr="""
|
||||||
|
(
|
||||||
|
$res := $lookup(widgets, "model.resolution");
|
||||||
|
$dur := $lookup(widgets, "model.duration");
|
||||||
|
$refs := inputGroups["model.reference_images"];
|
||||||
|
$rate := $res = "720p" ? 0.07 : 0.05;
|
||||||
|
$price := ($rate * $dur + 0.002 * $refs) * 1.43;
|
||||||
|
{"type":"usd","usd": $price}
|
||||||
|
)
|
||||||
|
""",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def execute(
|
||||||
|
cls,
|
||||||
|
prompt: str,
|
||||||
|
model: dict,
|
||||||
|
seed: int,
|
||||||
|
) -> IO.NodeOutput:
|
||||||
|
validate_string(prompt, strip_whitespace=True, min_length=1)
|
||||||
|
ref_image_urls = await upload_images_to_comfyapi(
|
||||||
|
cls,
|
||||||
|
list(model["reference_images"].values()),
|
||||||
|
mime_type="image/png",
|
||||||
|
wait_label="Uploading base images",
|
||||||
|
max_images=7,
|
||||||
|
)
|
||||||
|
initial_response = await sync_op(
|
||||||
|
cls,
|
||||||
|
ApiEndpoint(path="/proxy/xai/v1/videos/generations", method="POST"),
|
||||||
|
data=VideoGenerationRequest(
|
||||||
|
model=model["model"],
|
||||||
|
reference_images=[InputUrlObject(url=i) for i in ref_image_urls],
|
||||||
|
prompt=prompt,
|
||||||
|
resolution=model["resolution"],
|
||||||
|
duration=model["duration"],
|
||||||
|
aspect_ratio=model["aspect_ratio"],
|
||||||
|
seed=seed,
|
||||||
|
),
|
||||||
|
response_model=VideoGenerationResponse,
|
||||||
|
)
|
||||||
|
response = await poll_op(
|
||||||
|
cls,
|
||||||
|
ApiEndpoint(path=f"/proxy/xai/v1/videos/{initial_response.request_id}"),
|
||||||
|
status_extractor=lambda r: r.status if r.status is not None else "complete",
|
||||||
|
response_model=VideoStatusResponse,
|
||||||
|
price_extractor=_extract_grok_video_price,
|
||||||
|
)
|
||||||
|
return IO.NodeOutput(await download_url_to_video_output(response.video.url))
|
||||||
|
|
||||||
|
|
||||||
|
class GrokVideoExtendNode(IO.ComfyNode):
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls):
|
||||||
|
return IO.Schema(
|
||||||
|
node_id="GrokVideoExtendNode",
|
||||||
|
display_name="Grok Video Extend",
|
||||||
|
category="api node/video/Grok",
|
||||||
|
description="Extend an existing video with a seamless continuation based on a text prompt.",
|
||||||
|
inputs=[
|
||||||
|
IO.String.Input(
|
||||||
|
"prompt",
|
||||||
|
multiline=True,
|
||||||
|
tooltip="Text description of what should happen next in the video.",
|
||||||
|
),
|
||||||
|
IO.Video.Input("video", tooltip="Source video to extend. MP4 format, 2-15 seconds."),
|
||||||
|
IO.DynamicCombo.Input(
|
||||||
|
"model",
|
||||||
|
options=[
|
||||||
|
IO.DynamicCombo.Option(
|
||||||
|
"grok-imagine-video",
|
||||||
|
[
|
||||||
|
IO.Int.Input(
|
||||||
|
"duration",
|
||||||
|
default=8,
|
||||||
|
min=2,
|
||||||
|
max=10,
|
||||||
|
step=1,
|
||||||
|
tooltip="Length of the extension in seconds.",
|
||||||
|
display_mode=IO.NumberDisplay.slider,
|
||||||
|
),
|
||||||
|
],
|
||||||
|
),
|
||||||
|
],
|
||||||
|
tooltip="The model to use for video extension.",
|
||||||
|
),
|
||||||
|
IO.Int.Input(
|
||||||
|
"seed",
|
||||||
|
default=0,
|
||||||
|
min=0,
|
||||||
|
max=2147483647,
|
||||||
|
step=1,
|
||||||
|
display_mode=IO.NumberDisplay.number,
|
||||||
|
control_after_generate=True,
|
||||||
|
tooltip="Seed to determine if node should re-run; "
|
||||||
|
"actual results are nondeterministic regardless of seed.",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
IO.Video.Output(),
|
||||||
|
],
|
||||||
|
hidden=[
|
||||||
|
IO.Hidden.auth_token_comfy_org,
|
||||||
|
IO.Hidden.api_key_comfy_org,
|
||||||
|
IO.Hidden.unique_id,
|
||||||
|
],
|
||||||
|
is_api_node=True,
|
||||||
|
price_badge=IO.PriceBadge(
|
||||||
|
depends_on=IO.PriceBadgeDepends(widgets=["model.duration"]),
|
||||||
|
expr="""
|
||||||
|
(
|
||||||
|
$dur := $lookup(widgets, "model.duration");
|
||||||
|
{
|
||||||
|
"type": "range_usd",
|
||||||
|
"min_usd": (0.02 + 0.05 * $dur) * 1.43,
|
||||||
|
"max_usd": (0.15 + 0.05 * $dur) * 1.43
|
||||||
|
}
|
||||||
|
)
|
||||||
|
""",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def execute(
|
||||||
|
cls,
|
||||||
|
prompt: str,
|
||||||
|
video: Input.Video,
|
||||||
|
model: dict,
|
||||||
|
seed: int,
|
||||||
|
) -> IO.NodeOutput:
|
||||||
|
validate_string(prompt, strip_whitespace=True, min_length=1)
|
||||||
|
validate_video_duration(video, min_duration=2, max_duration=15)
|
||||||
|
video_size = get_fs_object_size(video.get_stream_source())
|
||||||
|
if video_size > 50 * 1024 * 1024:
|
||||||
|
raise ValueError(f"Video size ({video_size / 1024 / 1024:.1f}MB) exceeds 50MB limit.")
|
||||||
|
initial_response = await sync_op(
|
||||||
|
cls,
|
||||||
|
ApiEndpoint(path="/proxy/xai/v1/videos/extensions", method="POST"),
|
||||||
|
data=VideoExtensionRequest(
|
||||||
|
prompt=prompt,
|
||||||
|
video=InputUrlObject(url=await upload_video_to_comfyapi(cls, video)),
|
||||||
|
duration=model["duration"],
|
||||||
|
),
|
||||||
|
response_model=VideoGenerationResponse,
|
||||||
|
)
|
||||||
|
response = await poll_op(
|
||||||
|
cls,
|
||||||
|
ApiEndpoint(path=f"/proxy/xai/v1/videos/{initial_response.request_id}"),
|
||||||
|
status_extractor=lambda r: r.status if r.status is not None else "complete",
|
||||||
|
response_model=VideoStatusResponse,
|
||||||
|
price_extractor=_extract_grok_video_price,
|
||||||
|
)
|
||||||
|
return IO.NodeOutput(await download_url_to_video_output(response.video.url))
|
||||||
|
|
||||||
|
|
||||||
class GrokExtension(ComfyExtension):
|
class GrokExtension(ComfyExtension):
|
||||||
@override
|
@override
|
||||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||||
@ -469,7 +718,9 @@ class GrokExtension(ComfyExtension):
|
|||||||
GrokImageNode,
|
GrokImageNode,
|
||||||
GrokImageEditNode,
|
GrokImageEditNode,
|
||||||
GrokVideoNode,
|
GrokVideoNode,
|
||||||
|
GrokVideoReferenceNode,
|
||||||
GrokVideoEditNode,
|
GrokVideoEditNode,
|
||||||
|
GrokVideoExtendNode,
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
42
comfy_extras/nodes_curve.py
Normal file
42
comfy_extras/nodes_curve.py
Normal file
@ -0,0 +1,42 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from comfy_api.latest import ComfyExtension, io
|
||||||
|
from comfy_api.input import CurveInput
|
||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
|
|
||||||
|
class CurveEditor(io.ComfyNode):
    """Curve editor node: validates/parses a raw curve and passes it through,
    optionally forwarding a histogram to the UI for display."""

    @classmethod
    def define_schema(cls):
        # One required curve input plus an optional histogram overlay.
        schema_inputs = [
            io.Curve.Input("curve"),
            io.Histogram.Input("histogram", optional=True),
        ]
        schema_outputs = [
            io.Curve.Output("curve"),
        ]
        return io.Schema(
            node_id="CurveEditor",
            display_name="Curve Editor",
            category="utils",
            inputs=schema_inputs,
            outputs=schema_outputs,
        )

    @classmethod
    def execute(cls, curve, histogram=None) -> io.NodeOutput:
        parsed = CurveInput.from_raw(curve)

        # Only attach UI data when a histogram was actually supplied,
        # normalizing any non-list sequence to a plain list.
        if histogram is None:
            return io.NodeOutput(parsed)
        hist_list = histogram if isinstance(histogram, list) else list(histogram)
        return io.NodeOutput(parsed, ui={"histogram": hist_list})
|
||||||
|
|
||||||
|
|
||||||
|
class CurveExtension(ComfyExtension):
    """Extension wrapper registering the curve-editing node with ComfyUI."""

    @override
    async def get_node_list(self):
        # Single node exposed by this extension.
        return [CurveEditor]
|
||||||
|
|
||||||
|
|
||||||
|
async def comfy_entrypoint():
    """Entry point ComfyUI calls to load this extension."""
    return CurveExtension()
|
||||||
79
comfy_extras/nodes_number_convert.py
Normal file
79
comfy_extras/nodes_number_convert.py
Normal file
@ -0,0 +1,79 @@
|
|||||||
|
"""Number Convert node for unified numeric type conversion.
|
||||||
|
|
||||||
|
Provides a single node that converts INT, FLOAT, STRING, and BOOL
|
||||||
|
inputs into FLOAT and INT outputs.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import math
|
||||||
|
|
||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
|
from comfy_api.latest import ComfyExtension, io
|
||||||
|
|
||||||
|
|
||||||
|
class NumberConvertNode(io.ComfyNode):
    """Converts various types to numeric FLOAT and INT outputs.

    Accepts INT, FLOAT, STRING, or BOOL and emits the value both as a
    float and as an int (truncated toward zero). Raises ValueError for
    empty/unparseable strings and non-finite values, TypeError for any
    other input type.
    """

    @classmethod
    def define_schema(cls) -> io.Schema:
        aliases = [
            "int to float", "float to int", "number convert",
            "int2float", "float2int", "cast", "parse number",
            "string to number", "bool to int",
        ]
        return io.Schema(
            node_id="ComfyNumberConvert",
            display_name="Number Convert",
            category="math",
            search_aliases=aliases,
            inputs=[
                io.MultiType.Input(
                    "value",
                    [io.Int, io.Float, io.String, io.Boolean],
                    display_name="value",
                ),
            ],
            outputs=[
                io.Float.Output(display_name="FLOAT"),
                io.Int.Output(display_name="INT"),
            ],
        )

    @classmethod
    def execute(cls, value) -> io.NodeOutput:
        # bool must be tested before int/float: bool is an int subclass.
        if isinstance(value, bool):
            as_float = 1.0 if value else 0.0
        elif isinstance(value, (int, float)):
            as_float = float(value)
        elif isinstance(value, str):
            stripped = value.strip()
            if not stripped:
                raise ValueError("Cannot convert empty string to number.")
            try:
                as_float = float(stripped)
            except ValueError:
                # from None: the original parse failure adds no information.
                raise ValueError(
                    f"Cannot convert string to number: {value!r}"
                ) from None
        else:
            raise TypeError(
                f"Unsupported input type: {type(value).__name__}"
            )

        # Reject inf/nan (e.g. from the string "inf") so int() below
        # cannot fail and downstream nodes get a usable number.
        if not math.isfinite(as_float):
            raise ValueError(
                f"Cannot convert non-finite value to number: {as_float}"
            )

        # INT output truncates toward zero (Python int() semantics).
        return io.NodeOutput(as_float, int(as_float))
|
||||||
|
|
||||||
|
|
||||||
|
class NumberConvertExtension(ComfyExtension):
    """Extension wrapper exposing the Number Convert node."""

    @override
    async def get_node_list(self) -> list[type[io.ComfyNode]]:
        # Single node exposed by this extension.
        return [NumberConvertNode]
|
||||||
|
|
||||||
|
|
||||||
|
async def comfy_entrypoint() -> NumberConvertExtension:
    """Entry point ComfyUI calls to load this extension."""
    return NumberConvertExtension()
|
||||||
@ -1030,6 +1030,11 @@ class TrainLoraNode(io.ComfyNode):
|
|||||||
default="bf16",
|
default="bf16",
|
||||||
tooltip="The dtype to use for lora.",
|
tooltip="The dtype to use for lora.",
|
||||||
),
|
),
|
||||||
|
io.Boolean.Input(
|
||||||
|
"quantized_backward",
|
||||||
|
default=False,
|
||||||
|
tooltip="When using training_dtype 'none' and training on quantized model, doing backward with quantized matmul when enabled.",
|
||||||
|
),
|
||||||
io.Combo.Input(
|
io.Combo.Input(
|
||||||
"algorithm",
|
"algorithm",
|
||||||
options=list(adapter_maps.keys()),
|
options=list(adapter_maps.keys()),
|
||||||
@ -1097,6 +1102,7 @@ class TrainLoraNode(io.ComfyNode):
|
|||||||
seed,
|
seed,
|
||||||
training_dtype,
|
training_dtype,
|
||||||
lora_dtype,
|
lora_dtype,
|
||||||
|
quantized_backward,
|
||||||
algorithm,
|
algorithm,
|
||||||
gradient_checkpointing,
|
gradient_checkpointing,
|
||||||
checkpoint_depth,
|
checkpoint_depth,
|
||||||
@ -1117,6 +1123,7 @@ class TrainLoraNode(io.ComfyNode):
|
|||||||
seed = seed[0]
|
seed = seed[0]
|
||||||
training_dtype = training_dtype[0]
|
training_dtype = training_dtype[0]
|
||||||
lora_dtype = lora_dtype[0]
|
lora_dtype = lora_dtype[0]
|
||||||
|
quantized_backward = quantized_backward[0]
|
||||||
algorithm = algorithm[0]
|
algorithm = algorithm[0]
|
||||||
gradient_checkpointing = gradient_checkpointing[0]
|
gradient_checkpointing = gradient_checkpointing[0]
|
||||||
offloading = offloading[0]
|
offloading = offloading[0]
|
||||||
@ -1125,6 +1132,8 @@ class TrainLoraNode(io.ComfyNode):
|
|||||||
bucket_mode = bucket_mode[0]
|
bucket_mode = bucket_mode[0]
|
||||||
bypass_mode = bypass_mode[0]
|
bypass_mode = bypass_mode[0]
|
||||||
|
|
||||||
|
comfy.model_management.training_fp8_bwd = quantized_backward
|
||||||
|
|
||||||
# Process latents based on mode
|
# Process latents based on mode
|
||||||
if bucket_mode:
|
if bucket_mode:
|
||||||
latents = _process_latents_bucket_mode(latents)
|
latents = _process_latents_bucket_mode(latents)
|
||||||
|
|||||||
2
nodes.py
2
nodes.py
@ -2454,7 +2454,9 @@ async def init_builtin_extra_nodes():
|
|||||||
"nodes_nag.py",
|
"nodes_nag.py",
|
||||||
"nodes_sdpose.py",
|
"nodes_sdpose.py",
|
||||||
"nodes_math.py",
|
"nodes_math.py",
|
||||||
|
"nodes_number_convert.py",
|
||||||
"nodes_painter.py",
|
"nodes_painter.py",
|
||||||
|
"nodes_curve.py",
|
||||||
]
|
]
|
||||||
|
|
||||||
import_failed = []
|
import_failed = []
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
comfyui-frontend-package==1.41.21
|
comfyui-frontend-package==1.42.8
|
||||||
comfyui-workflow-templates==0.9.26
|
comfyui-workflow-templates==0.9.36
|
||||||
comfyui-embedded-docs==0.4.3
|
comfyui-embedded-docs==0.4.3
|
||||||
torch
|
torch
|
||||||
torchsde
|
torchsde
|
||||||
|
|||||||
123
tests-unit/comfy_extras_test/nodes_number_convert_test.py
Normal file
123
tests-unit/comfy_extras_test/nodes_number_convert_test.py
Normal file
@ -0,0 +1,123 @@
|
|||||||
|
import pytest
from unittest.mock import patch, MagicMock

# Stub the heavy project modules so nodes_number_convert can be imported
# without a full ComfyUI runtime environment.
mock_nodes = MagicMock()
mock_nodes.MAX_RESOLUTION = 16384
mock_server = MagicMock()

# Patch sys.modules only for the duration of the import; the imported
# node does not keep references to these mocked modules afterwards.
with patch.dict("sys.modules", {"nodes": mock_nodes, "server": mock_server}):
    from comfy_extras.nodes_number_convert import NumberConvertNode
|
||||||
|
|
||||||
|
|
||||||
|
class TestNumberConvertExecute:
    """Tests for NumberConvertNode.execute: FLOAT/INT conversion semantics,
    string parsing, and error paths."""

    @staticmethod
    def _exec(value) -> object:
        return NumberConvertNode.execute(value)

    # ----- INT input -----

    def test_int_input(self):
        out = self._exec(42)
        assert out[0] == 42.0
        assert out[1] == 42

    def test_int_zero(self):
        out = self._exec(0)
        assert out[0] == 0.0
        assert out[1] == 0

    def test_int_negative(self):
        out = self._exec(-7)
        assert out[0] == -7.0
        assert out[1] == -7

    # ----- FLOAT input -----

    def test_float_input(self):
        out = self._exec(3.14)
        assert out[0] == 3.14
        assert out[1] == 3

    def test_float_truncation_toward_zero(self):
        out = self._exec(-2.9)
        assert out[0] == -2.9
        assert out[1] == -2  # int() truncates toward zero, not floor

    def test_float_output_type(self):
        out = self._exec(5)
        assert isinstance(out[0], float)

    def test_int_output_type(self):
        out = self._exec(5.7)
        assert isinstance(out[1], int)

    # ----- BOOL input -----

    def test_bool_true(self):
        out = self._exec(True)
        assert out[0] == 1.0
        assert out[1] == 1

    def test_bool_false(self):
        out = self._exec(False)
        assert out[0] == 0.0
        assert out[1] == 0

    # ----- STRING input -----

    def test_string_integer(self):
        out = self._exec("42")
        assert out[0] == 42.0
        assert out[1] == 42

    def test_string_float(self):
        out = self._exec("3.14")
        assert out[0] == 3.14
        assert out[1] == 3

    def test_string_negative(self):
        out = self._exec("-5.5")
        assert out[0] == -5.5
        assert out[1] == -5

    def test_string_with_whitespace(self):
        out = self._exec(" 7.0 ")
        assert out[0] == 7.0
        assert out[1] == 7

    def test_string_scientific_notation(self):
        out = self._exec("1e3")
        assert out[0] == 1000.0
        assert out[1] == 1000

    # ----- STRING error paths -----

    def test_empty_string_raises(self):
        with pytest.raises(ValueError, match="Cannot convert empty string"):
            self._exec("")

    def test_whitespace_only_string_raises(self):
        with pytest.raises(ValueError, match="Cannot convert empty string"):
            self._exec(" ")

    def test_non_numeric_string_raises(self):
        with pytest.raises(ValueError, match="Cannot convert string to number"):
            self._exec("abc")

    def test_string_inf_raises(self):
        with pytest.raises(ValueError, match="non-finite"):
            self._exec("inf")

    def test_string_nan_raises(self):
        with pytest.raises(ValueError, match="non-finite"):
            self._exec("nan")

    def test_string_negative_inf_raises(self):
        with pytest.raises(ValueError, match="non-finite"):
            self._exec("-inf")

    # ----- Unsupported type -----

    def test_unsupported_type_raises(self):
        with pytest.raises(TypeError, match="Unsupported input type"):
            self._exec([1, 2, 3])
|
||||||
Loading…
Reference in New Issue
Block a user