mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-15 04:52:31 +08:00
Merge upstream/master, keep local README.md
This commit is contained in:
commit
d4a90e2efb
File diff suppressed because one or more lines are too long
@ -55,6 +55,7 @@ total_vram = 0
|
||||
|
||||
# Training Related State
|
||||
in_training = False
|
||||
training_fp8_bwd = False
|
||||
|
||||
|
||||
def get_supported_float8_types():
|
||||
|
||||
65
comfy/ops.py
65
comfy/ops.py
@ -777,8 +777,16 @@ from .quant_ops import (
|
||||
|
||||
|
||||
class QuantLinearFunc(torch.autograd.Function):
|
||||
"""Custom autograd function for quantized linear: quantized forward, compute_dtype backward.
|
||||
Handles any input rank by flattening to 2D for matmul and restoring shape after.
|
||||
"""Custom autograd function for quantized linear: quantized forward, optionally FP8 backward.
|
||||
|
||||
When training_fp8_bwd is enabled:
|
||||
- Forward: quantize input per layout (FP8/NVFP4), use quantized matmul
|
||||
- Backward: all matmuls use FP8 tensor cores via torch.mm dispatch
|
||||
- Cached input is FP8 (half the memory of bf16)
|
||||
|
||||
When training_fp8_bwd is disabled:
|
||||
- Forward: quantize input per layout, use quantized matmul
|
||||
- Backward: dequantize weight to compute_dtype, use standard matmul
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
@ -786,7 +794,7 @@ class QuantLinearFunc(torch.autograd.Function):
|
||||
input_shape = input_float.shape
|
||||
inp = input_float.detach().flatten(0, -2) # zero-cost view to 2D
|
||||
|
||||
# Quantize input (same as inference path)
|
||||
# Quantize input for forward (same layout as weight)
|
||||
if layout_type is not None:
|
||||
q_input = QuantizedTensor.from_float(inp, layout_type, scale=input_scale)
|
||||
else:
|
||||
@ -797,43 +805,68 @@ class QuantLinearFunc(torch.autograd.Function):
|
||||
|
||||
output = torch.nn.functional.linear(q_input, w, b)
|
||||
|
||||
# Restore original input shape
|
||||
# Unflatten output to match original input shape
|
||||
if len(input_shape) > 2:
|
||||
output = output.unflatten(0, input_shape[:-1])
|
||||
|
||||
ctx.save_for_backward(input_float, weight)
|
||||
# Save for backward
|
||||
ctx.input_shape = input_shape
|
||||
ctx.has_bias = bias is not None
|
||||
ctx.compute_dtype = compute_dtype
|
||||
ctx.weight_requires_grad = weight.requires_grad
|
||||
ctx.fp8_bwd = comfy.model_management.training_fp8_bwd
|
||||
|
||||
if ctx.fp8_bwd:
|
||||
# Cache FP8 quantized input — half the memory of bf16
|
||||
if isinstance(q_input, QuantizedTensor) and layout_type.startswith('TensorCoreFP8'):
|
||||
ctx.q_input = q_input # already FP8, reuse
|
||||
else:
|
||||
# NVFP4 or other layout — quantize input to FP8 for backward
|
||||
ctx.q_input = QuantizedTensor.from_float(inp, "TensorCoreFP8E4M3Layout")
|
||||
ctx.save_for_backward(weight)
|
||||
else:
|
||||
ctx.q_input = None
|
||||
ctx.save_for_backward(input_float, weight)
|
||||
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
@torch.autograd.function.once_differentiable
|
||||
def backward(ctx, grad_output):
|
||||
input_float, weight = ctx.saved_tensors
|
||||
compute_dtype = ctx.compute_dtype
|
||||
grad_2d = grad_output.flatten(0, -2).to(compute_dtype)
|
||||
|
||||
# Dequantize weight to compute dtype for backward matmul
|
||||
if isinstance(weight, QuantizedTensor):
|
||||
weight_f = weight.dequantize().to(compute_dtype)
|
||||
# Value casting — only difference between fp8 and non-fp8 paths
|
||||
if ctx.fp8_bwd:
|
||||
weight, = ctx.saved_tensors
|
||||
# Wrap as FP8 QuantizedTensors → torch.mm dispatches to _scaled_mm
|
||||
grad_mm = QuantizedTensor.from_float(grad_2d, "TensorCoreFP8E5M2Layout")
|
||||
if isinstance(weight, QuantizedTensor) and weight._layout_cls.startswith("TensorCoreFP8"):
|
||||
weight_mm = weight
|
||||
elif isinstance(weight, QuantizedTensor):
|
||||
weight_mm = QuantizedTensor.from_float(weight.dequantize().to(compute_dtype), "TensorCoreFP8E4M3Layout")
|
||||
else:
|
||||
weight_mm = QuantizedTensor.from_float(weight.to(compute_dtype), "TensorCoreFP8E4M3Layout")
|
||||
input_mm = ctx.q_input
|
||||
else:
|
||||
weight_f = weight.to(compute_dtype)
|
||||
input_float, weight = ctx.saved_tensors
|
||||
# Standard tensors → torch.mm does regular matmul
|
||||
grad_mm = grad_2d
|
||||
if isinstance(weight, QuantizedTensor):
|
||||
weight_mm = weight.dequantize().to(compute_dtype)
|
||||
else:
|
||||
weight_mm = weight.to(compute_dtype)
|
||||
input_mm = input_float.flatten(0, -2).to(compute_dtype) if ctx.weight_requires_grad else None
|
||||
|
||||
# grad_input = grad_output @ weight
|
||||
grad_input = torch.mm(grad_2d, weight_f)
|
||||
# Computation — same for both paths, dispatch handles the rest
|
||||
grad_input = torch.mm(grad_mm, weight_mm)
|
||||
if len(ctx.input_shape) > 2:
|
||||
grad_input = grad_input.unflatten(0, ctx.input_shape[:-1])
|
||||
|
||||
# grad_weight (only if weight requires grad, typically frozen for quantized training)
|
||||
grad_weight = None
|
||||
if ctx.weight_requires_grad:
|
||||
input_f = input_float.flatten(0, -2).to(compute_dtype)
|
||||
grad_weight = torch.mm(grad_2d.t(), input_f)
|
||||
grad_weight = torch.mm(grad_mm.t(), input_mm)
|
||||
|
||||
# grad_bias
|
||||
grad_bias = None
|
||||
if ctx.has_bias:
|
||||
grad_bias = grad_2d.sum(dim=0)
|
||||
|
||||
@ -5,6 +5,10 @@ from comfy_api.latest._input import (
|
||||
MaskInput,
|
||||
LatentInput,
|
||||
VideoInput,
|
||||
CurvePoint,
|
||||
CurveInput,
|
||||
MonotoneCubicCurve,
|
||||
LinearCurve,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
@ -13,4 +17,8 @@ __all__ = [
|
||||
"MaskInput",
|
||||
"LatentInput",
|
||||
"VideoInput",
|
||||
"CurvePoint",
|
||||
"CurveInput",
|
||||
"MonotoneCubicCurve",
|
||||
"LinearCurve",
|
||||
]
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
from .basic_types import ImageInput, AudioInput, MaskInput, LatentInput
|
||||
from .curve_types import CurvePoint, CurveInput, MonotoneCubicCurve, LinearCurve
|
||||
from .video_types import VideoInput
|
||||
|
||||
__all__ = [
|
||||
@ -7,4 +8,8 @@ __all__ = [
|
||||
"VideoInput",
|
||||
"MaskInput",
|
||||
"LatentInput",
|
||||
"CurvePoint",
|
||||
"CurveInput",
|
||||
"MonotoneCubicCurve",
|
||||
"LinearCurve",
|
||||
]
|
||||
|
||||
219
comfy_api/latest/_input/curve_types.py
Normal file
219
comfy_api/latest/_input/curve_types.py
Normal file
@ -0,0 +1,219 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import math
|
||||
from abc import ABC, abstractmethod
|
||||
import numpy as np
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
CurvePoint = tuple[float, float]
|
||||
|
||||
|
||||
class CurveInput(ABC):
|
||||
"""Abstract base class for curve inputs.
|
||||
|
||||
Subclasses represent different curve representations (control-point
|
||||
interpolation, analytical functions, LUT-based, etc.) while exposing a
|
||||
uniform evaluation interface to downstream nodes.
|
||||
"""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def points(self) -> list[CurvePoint]:
|
||||
"""The control points that define this curve."""
|
||||
|
||||
@abstractmethod
|
||||
def interp(self, x: float) -> float:
|
||||
"""Evaluate the curve at a single *x* value in [0, 1]."""
|
||||
|
||||
def interp_array(self, xs: np.ndarray) -> np.ndarray:
|
||||
"""Vectorised evaluation over a numpy array of x values.
|
||||
|
||||
Subclasses should override this for better performance. The default
|
||||
falls back to scalar ``interp`` calls.
|
||||
"""
|
||||
return np.fromiter((self.interp(float(x)) for x in xs), dtype=np.float64, count=len(xs))
|
||||
|
||||
def to_lut(self, size: int = 256) -> np.ndarray:
|
||||
"""Generate a float64 lookup table of *size* evenly-spaced samples in [0, 1]."""
|
||||
return self.interp_array(np.linspace(0.0, 1.0, size))
|
||||
|
||||
@staticmethod
|
||||
def from_raw(data) -> CurveInput:
|
||||
"""Convert raw curve data (dict or point list) to a CurveInput instance.
|
||||
|
||||
Accepts:
|
||||
- A ``CurveInput`` instance (returned as-is).
|
||||
- A dict with ``"points"`` and optional ``"interpolation"`` keys.
|
||||
- A bare list/sequence of ``(x, y)`` pairs (defaults to monotone cubic).
|
||||
"""
|
||||
if isinstance(data, CurveInput):
|
||||
return data
|
||||
if isinstance(data, dict):
|
||||
raw_points = data["points"]
|
||||
interpolation = data.get("interpolation", "monotone_cubic")
|
||||
else:
|
||||
raw_points = data
|
||||
interpolation = "monotone_cubic"
|
||||
points = [(float(x), float(y)) for x, y in raw_points]
|
||||
if interpolation == "linear":
|
||||
return LinearCurve(points)
|
||||
if interpolation != "monotone_cubic":
|
||||
logger.warning("Unknown curve interpolation %r, falling back to monotone_cubic", interpolation)
|
||||
return MonotoneCubicCurve(points)
|
||||
|
||||
|
||||
class MonotoneCubicCurve(CurveInput):
|
||||
"""Monotone cubic Hermite interpolation over control points.
|
||||
|
||||
Mirrors the frontend ``createMonotoneInterpolator`` in
|
||||
``ComfyUI_frontend/src/components/curve/curveUtils.ts`` so that
|
||||
backend evaluation matches the editor preview exactly.
|
||||
|
||||
All heavy work (sorting, slope computation) happens once at construction.
|
||||
``interp_array`` is fully vectorised with numpy.
|
||||
"""
|
||||
|
||||
def __init__(self, control_points: list[CurvePoint]):
|
||||
sorted_pts = sorted(control_points, key=lambda p: p[0])
|
||||
self._points = [(float(x), float(y)) for x, y in sorted_pts]
|
||||
self._xs = np.array([p[0] for p in self._points], dtype=np.float64)
|
||||
self._ys = np.array([p[1] for p in self._points], dtype=np.float64)
|
||||
self._slopes = self._compute_slopes()
|
||||
|
||||
@property
|
||||
def points(self) -> list[CurvePoint]:
|
||||
return list(self._points)
|
||||
|
||||
def _compute_slopes(self) -> np.ndarray:
|
||||
xs, ys = self._xs, self._ys
|
||||
n = len(xs)
|
||||
if n < 2:
|
||||
return np.zeros(n, dtype=np.float64)
|
||||
|
||||
dx = np.diff(xs)
|
||||
dy = np.diff(ys)
|
||||
dx_safe = np.where(dx == 0, 1.0, dx)
|
||||
deltas = np.where(dx == 0, 0.0, dy / dx_safe)
|
||||
|
||||
slopes = np.empty(n, dtype=np.float64)
|
||||
slopes[0] = deltas[0]
|
||||
slopes[-1] = deltas[-1]
|
||||
for i in range(1, n - 1):
|
||||
if deltas[i - 1] * deltas[i] <= 0:
|
||||
slopes[i] = 0.0
|
||||
else:
|
||||
slopes[i] = (deltas[i - 1] + deltas[i]) / 2
|
||||
|
||||
for i in range(n - 1):
|
||||
if deltas[i] == 0:
|
||||
slopes[i] = 0.0
|
||||
slopes[i + 1] = 0.0
|
||||
else:
|
||||
alpha = slopes[i] / deltas[i]
|
||||
beta = slopes[i + 1] / deltas[i]
|
||||
s = alpha * alpha + beta * beta
|
||||
if s > 9:
|
||||
t = 3 / math.sqrt(s)
|
||||
slopes[i] = t * alpha * deltas[i]
|
||||
slopes[i + 1] = t * beta * deltas[i]
|
||||
return slopes
|
||||
|
||||
def interp(self, x: float) -> float:
|
||||
xs, ys, slopes = self._xs, self._ys, self._slopes
|
||||
n = len(xs)
|
||||
if n == 0:
|
||||
return 0.0
|
||||
if n == 1:
|
||||
return float(ys[0])
|
||||
if x <= xs[0]:
|
||||
return float(ys[0])
|
||||
if x >= xs[-1]:
|
||||
return float(ys[-1])
|
||||
|
||||
hi = int(np.searchsorted(xs, x, side='right'))
|
||||
hi = min(hi, n - 1)
|
||||
lo = hi - 1
|
||||
|
||||
dx = xs[hi] - xs[lo]
|
||||
if dx == 0:
|
||||
return float(ys[lo])
|
||||
|
||||
t = (x - xs[lo]) / dx
|
||||
t2 = t * t
|
||||
t3 = t2 * t
|
||||
h00 = 2 * t3 - 3 * t2 + 1
|
||||
h10 = t3 - 2 * t2 + t
|
||||
h01 = -2 * t3 + 3 * t2
|
||||
h11 = t3 - t2
|
||||
return float(h00 * ys[lo] + h10 * dx * slopes[lo] + h01 * ys[hi] + h11 * dx * slopes[hi])
|
||||
|
||||
def interp_array(self, xs_in: np.ndarray) -> np.ndarray:
|
||||
"""Fully vectorised evaluation using numpy."""
|
||||
xs, ys, slopes = self._xs, self._ys, self._slopes
|
||||
n = len(xs)
|
||||
if n == 0:
|
||||
return np.zeros_like(xs_in, dtype=np.float64)
|
||||
if n == 1:
|
||||
return np.full_like(xs_in, ys[0], dtype=np.float64)
|
||||
|
||||
hi = np.searchsorted(xs, xs_in, side='right').clip(1, n - 1)
|
||||
lo = hi - 1
|
||||
|
||||
dx = xs[hi] - xs[lo]
|
||||
dx_safe = np.where(dx == 0, 1.0, dx)
|
||||
t = np.where(dx == 0, 0.0, (xs_in - xs[lo]) / dx_safe)
|
||||
t2 = t * t
|
||||
t3 = t2 * t
|
||||
|
||||
h00 = 2 * t3 - 3 * t2 + 1
|
||||
h10 = t3 - 2 * t2 + t
|
||||
h01 = -2 * t3 + 3 * t2
|
||||
h11 = t3 - t2
|
||||
|
||||
result = h00 * ys[lo] + h10 * dx * slopes[lo] + h01 * ys[hi] + h11 * dx * slopes[hi]
|
||||
result = np.where(xs_in <= xs[0], ys[0], result)
|
||||
result = np.where(xs_in >= xs[-1], ys[-1], result)
|
||||
return result
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"MonotoneCubicCurve(points={self._points})"
|
||||
|
||||
|
||||
class LinearCurve(CurveInput):
|
||||
"""Piecewise linear interpolation over control points.
|
||||
|
||||
Mirrors the frontend ``createLinearInterpolator`` in
|
||||
``ComfyUI_frontend/src/components/curve/curveUtils.ts``.
|
||||
"""
|
||||
|
||||
def __init__(self, control_points: list[CurvePoint]):
|
||||
sorted_pts = sorted(control_points, key=lambda p: p[0])
|
||||
self._points = [(float(x), float(y)) for x, y in sorted_pts]
|
||||
self._xs = np.array([p[0] for p in self._points], dtype=np.float64)
|
||||
self._ys = np.array([p[1] for p in self._points], dtype=np.float64)
|
||||
|
||||
@property
|
||||
def points(self) -> list[CurvePoint]:
|
||||
return list(self._points)
|
||||
|
||||
def interp(self, x: float) -> float:
|
||||
xs, ys = self._xs, self._ys
|
||||
n = len(xs)
|
||||
if n == 0:
|
||||
return 0.0
|
||||
if n == 1:
|
||||
return float(ys[0])
|
||||
return float(np.interp(x, xs, ys))
|
||||
|
||||
def interp_array(self, xs_in: np.ndarray) -> np.ndarray:
|
||||
if len(self._xs) == 0:
|
||||
return np.zeros_like(xs_in, dtype=np.float64)
|
||||
if len(self._xs) == 1:
|
||||
return np.full_like(xs_in, self._ys[0], dtype=np.float64)
|
||||
return np.interp(xs_in, self._xs, self._ys)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"LinearCurve(points={self._points})"
|
||||
@ -23,7 +23,7 @@ if TYPE_CHECKING:
|
||||
from comfy.samplers import CFGGuider, Sampler
|
||||
from comfy.sd import CLIP, VAE
|
||||
from comfy.sd import StyleModel as StyleModel_
|
||||
from comfy_api.input import VideoInput
|
||||
from comfy_api.input import VideoInput, CurveInput as CurveInput_
|
||||
from comfy_api.internal import (_ComfyNodeInternal, _NodeOutputInternal, classproperty, copy_class, first_real_override, is_class,
|
||||
prune_dict, shallow_clone_class)
|
||||
from comfy_execution.graph_utils import ExecutionBlocker
|
||||
@ -1242,8 +1242,9 @@ class BoundingBox(ComfyTypeIO):
|
||||
|
||||
@comfytype(io_type="CURVE")
|
||||
class Curve(ComfyTypeIO):
|
||||
CurvePoint = tuple[float, float]
|
||||
Type = list[CurvePoint]
|
||||
from comfy_api.input import CurvePoint
|
||||
if TYPE_CHECKING:
|
||||
Type = CurveInput_
|
||||
|
||||
class Input(WidgetInput):
|
||||
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None,
|
||||
@ -1252,6 +1253,18 @@ class Curve(ComfyTypeIO):
|
||||
if default is None:
|
||||
self.default = [(0.0, 0.0), (1.0, 1.0)]
|
||||
|
||||
def as_dict(self):
|
||||
d = super().as_dict()
|
||||
if self.default is not None:
|
||||
d["default"] = {"points": [list(p) for p in self.default], "interpolation": "monotone_cubic"}
|
||||
return d
|
||||
|
||||
|
||||
@comfytype(io_type="HISTOGRAM")
|
||||
class Histogram(ComfyTypeIO):
|
||||
"""A histogram represented as a list of bin counts."""
|
||||
Type = list[int]
|
||||
|
||||
|
||||
DYNAMIC_INPUT_LOOKUP: dict[str, Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]] = {}
|
||||
def register_dynamic_input_func(io_type: str, func: Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]):
|
||||
@ -2240,5 +2253,6 @@ __all__ = [
|
||||
"PriceBadge",
|
||||
"BoundingBox",
|
||||
"Curve",
|
||||
"Histogram",
|
||||
"NodeReplace",
|
||||
]
|
||||
|
||||
@ -29,13 +29,21 @@ class ImageEditRequest(BaseModel):
|
||||
class VideoGenerationRequest(BaseModel):
|
||||
model: str = Field(...)
|
||||
prompt: str = Field(...)
|
||||
image: InputUrlObject | None = Field(...)
|
||||
image: InputUrlObject | None = Field(None)
|
||||
reference_images: list[InputUrlObject] | None = Field(None)
|
||||
duration: int = Field(...)
|
||||
aspect_ratio: str | None = Field(...)
|
||||
resolution: str = Field(...)
|
||||
seed: int = Field(...)
|
||||
|
||||
|
||||
class VideoExtensionRequest(BaseModel):
|
||||
prompt: str = Field(...)
|
||||
video: InputUrlObject = Field(...)
|
||||
duration: int = Field(default=6)
|
||||
model: str | None = Field(default=None)
|
||||
|
||||
|
||||
class VideoEditRequest(BaseModel):
|
||||
model: str = Field(...)
|
||||
prompt: str = Field(...)
|
||||
|
||||
@ -8,6 +8,7 @@ from comfy_api_nodes.apis.grok import (
|
||||
ImageGenerationResponse,
|
||||
InputUrlObject,
|
||||
VideoEditRequest,
|
||||
VideoExtensionRequest,
|
||||
VideoGenerationRequest,
|
||||
VideoGenerationResponse,
|
||||
VideoStatusResponse,
|
||||
@ -21,6 +22,7 @@ from comfy_api_nodes.util import (
|
||||
poll_op,
|
||||
sync_op,
|
||||
tensor_to_base64_string,
|
||||
upload_images_to_comfyapi,
|
||||
upload_video_to_comfyapi,
|
||||
validate_string,
|
||||
validate_video_duration,
|
||||
@ -33,6 +35,13 @@ def _extract_grok_price(response) -> float | None:
|
||||
return None
|
||||
|
||||
|
||||
def _extract_grok_video_price(response) -> float | None:
|
||||
price = _extract_grok_price(response)
|
||||
if price is not None:
|
||||
return price * 1.43
|
||||
return None
|
||||
|
||||
|
||||
class GrokImageNode(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
@ -354,6 +363,8 @@ class GrokVideoNode(IO.ComfyNode):
|
||||
seed: int,
|
||||
image: Input.Image | None = None,
|
||||
) -> IO.NodeOutput:
|
||||
if model == "grok-imagine-video-beta":
|
||||
model = "grok-imagine-video"
|
||||
image_url = None
|
||||
if image is not None:
|
||||
if get_number_of_images(image) != 1:
|
||||
@ -462,6 +473,244 @@ class GrokVideoEditNode(IO.ComfyNode):
|
||||
return IO.NodeOutput(await download_url_to_video_output(response.video.url))
|
||||
|
||||
|
||||
class GrokVideoReferenceNode(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="GrokVideoReferenceNode",
|
||||
display_name="Grok Reference-to-Video",
|
||||
category="api node/video/Grok",
|
||||
description="Generate video guided by reference images as style and content references.",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
tooltip="Text description of the desired video.",
|
||||
),
|
||||
IO.DynamicCombo.Input(
|
||||
"model",
|
||||
options=[
|
||||
IO.DynamicCombo.Option(
|
||||
"grok-imagine-video",
|
||||
[
|
||||
IO.Autogrow.Input(
|
||||
"reference_images",
|
||||
template=IO.Autogrow.TemplatePrefix(
|
||||
IO.Image.Input("image"),
|
||||
prefix="reference_",
|
||||
min=1,
|
||||
max=7,
|
||||
),
|
||||
tooltip="Up to 7 reference images to guide the video generation.",
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"resolution",
|
||||
options=["480p", "720p"],
|
||||
tooltip="The resolution of the output video.",
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"aspect_ratio",
|
||||
options=["16:9", "4:3", "3:2", "1:1", "2:3", "3:4", "9:16"],
|
||||
tooltip="The aspect ratio of the output video.",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"duration",
|
||||
default=6,
|
||||
min=2,
|
||||
max=10,
|
||||
step=1,
|
||||
tooltip="The duration of the output video in seconds.",
|
||||
display_mode=IO.NumberDisplay.slider,
|
||||
),
|
||||
],
|
||||
),
|
||||
],
|
||||
tooltip="The model to use for video generation.",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=2147483647,
|
||||
step=1,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
control_after_generate=True,
|
||||
tooltip="Seed to determine if node should re-run; "
|
||||
"actual results are nondeterministic regardless of seed.",
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.Video.Output(),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(
|
||||
widgets=["model.duration", "model.resolution"],
|
||||
input_groups=["model.reference_images"],
|
||||
),
|
||||
expr="""
|
||||
(
|
||||
$res := $lookup(widgets, "model.resolution");
|
||||
$dur := $lookup(widgets, "model.duration");
|
||||
$refs := inputGroups["model.reference_images"];
|
||||
$rate := $res = "720p" ? 0.07 : 0.05;
|
||||
$price := ($rate * $dur + 0.002 * $refs) * 1.43;
|
||||
{"type":"usd","usd": $price}
|
||||
)
|
||||
""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
prompt: str,
|
||||
model: dict,
|
||||
seed: int,
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, strip_whitespace=True, min_length=1)
|
||||
ref_image_urls = await upload_images_to_comfyapi(
|
||||
cls,
|
||||
list(model["reference_images"].values()),
|
||||
mime_type="image/png",
|
||||
wait_label="Uploading base images",
|
||||
max_images=7,
|
||||
)
|
||||
initial_response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/xai/v1/videos/generations", method="POST"),
|
||||
data=VideoGenerationRequest(
|
||||
model=model["model"],
|
||||
reference_images=[InputUrlObject(url=i) for i in ref_image_urls],
|
||||
prompt=prompt,
|
||||
resolution=model["resolution"],
|
||||
duration=model["duration"],
|
||||
aspect_ratio=model["aspect_ratio"],
|
||||
seed=seed,
|
||||
),
|
||||
response_model=VideoGenerationResponse,
|
||||
)
|
||||
response = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"/proxy/xai/v1/videos/{initial_response.request_id}"),
|
||||
status_extractor=lambda r: r.status if r.status is not None else "complete",
|
||||
response_model=VideoStatusResponse,
|
||||
price_extractor=_extract_grok_video_price,
|
||||
)
|
||||
return IO.NodeOutput(await download_url_to_video_output(response.video.url))
|
||||
|
||||
|
||||
class GrokVideoExtendNode(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="GrokVideoExtendNode",
|
||||
display_name="Grok Video Extend",
|
||||
category="api node/video/Grok",
|
||||
description="Extend an existing video with a seamless continuation based on a text prompt.",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
tooltip="Text description of what should happen next in the video.",
|
||||
),
|
||||
IO.Video.Input("video", tooltip="Source video to extend. MP4 format, 2-15 seconds."),
|
||||
IO.DynamicCombo.Input(
|
||||
"model",
|
||||
options=[
|
||||
IO.DynamicCombo.Option(
|
||||
"grok-imagine-video",
|
||||
[
|
||||
IO.Int.Input(
|
||||
"duration",
|
||||
default=8,
|
||||
min=2,
|
||||
max=10,
|
||||
step=1,
|
||||
tooltip="Length of the extension in seconds.",
|
||||
display_mode=IO.NumberDisplay.slider,
|
||||
),
|
||||
],
|
||||
),
|
||||
],
|
||||
tooltip="The model to use for video extension.",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=2147483647,
|
||||
step=1,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
control_after_generate=True,
|
||||
tooltip="Seed to determine if node should re-run; "
|
||||
"actual results are nondeterministic regardless of seed.",
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.Video.Output(),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["model.duration"]),
|
||||
expr="""
|
||||
(
|
||||
$dur := $lookup(widgets, "model.duration");
|
||||
{
|
||||
"type": "range_usd",
|
||||
"min_usd": (0.02 + 0.05 * $dur) * 1.43,
|
||||
"max_usd": (0.15 + 0.05 * $dur) * 1.43
|
||||
}
|
||||
)
|
||||
""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
prompt: str,
|
||||
video: Input.Video,
|
||||
model: dict,
|
||||
seed: int,
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, strip_whitespace=True, min_length=1)
|
||||
validate_video_duration(video, min_duration=2, max_duration=15)
|
||||
video_size = get_fs_object_size(video.get_stream_source())
|
||||
if video_size > 50 * 1024 * 1024:
|
||||
raise ValueError(f"Video size ({video_size / 1024 / 1024:.1f}MB) exceeds 50MB limit.")
|
||||
initial_response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/xai/v1/videos/extensions", method="POST"),
|
||||
data=VideoExtensionRequest(
|
||||
prompt=prompt,
|
||||
video=InputUrlObject(url=await upload_video_to_comfyapi(cls, video)),
|
||||
duration=model["duration"],
|
||||
),
|
||||
response_model=VideoGenerationResponse,
|
||||
)
|
||||
response = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"/proxy/xai/v1/videos/{initial_response.request_id}"),
|
||||
status_extractor=lambda r: r.status if r.status is not None else "complete",
|
||||
response_model=VideoStatusResponse,
|
||||
price_extractor=_extract_grok_video_price,
|
||||
)
|
||||
return IO.NodeOutput(await download_url_to_video_output(response.video.url))
|
||||
|
||||
|
||||
class GrokExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||
@ -469,7 +718,9 @@ class GrokExtension(ComfyExtension):
|
||||
GrokImageNode,
|
||||
GrokImageEditNode,
|
||||
GrokVideoNode,
|
||||
GrokVideoReferenceNode,
|
||||
GrokVideoEditNode,
|
||||
GrokVideoExtendNode,
|
||||
]
|
||||
|
||||
|
||||
|
||||
42
comfy_extras/nodes_curve.py
Normal file
42
comfy_extras/nodes_curve.py
Normal file
@ -0,0 +1,42 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
from comfy_api.input import CurveInput
|
||||
from typing_extensions import override
|
||||
|
||||
|
||||
class CurveEditor(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="CurveEditor",
|
||||
display_name="Curve Editor",
|
||||
category="utils",
|
||||
inputs=[
|
||||
io.Curve.Input("curve"),
|
||||
io.Histogram.Input("histogram", optional=True),
|
||||
],
|
||||
outputs=[
|
||||
io.Curve.Output("curve"),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, curve, histogram=None) -> io.NodeOutput:
|
||||
result = CurveInput.from_raw(curve)
|
||||
|
||||
ui = {}
|
||||
if histogram is not None:
|
||||
ui["histogram"] = histogram if isinstance(histogram, list) else list(histogram)
|
||||
|
||||
return io.NodeOutput(result, ui=ui) if ui else io.NodeOutput(result)
|
||||
|
||||
|
||||
class CurveExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self):
|
||||
return [CurveEditor]
|
||||
|
||||
|
||||
async def comfy_entrypoint():
|
||||
return CurveExtension()
|
||||
79
comfy_extras/nodes_number_convert.py
Normal file
79
comfy_extras/nodes_number_convert.py
Normal file
@ -0,0 +1,79 @@
|
||||
"""Number Convert node for unified numeric type conversion.
|
||||
|
||||
Provides a single node that converts INT, FLOAT, STRING, and BOOL
|
||||
inputs into FLOAT and INT outputs.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
|
||||
|
||||
class NumberConvertNode(io.ComfyNode):
|
||||
"""Converts various types to numeric FLOAT and INT outputs."""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls) -> io.Schema:
|
||||
return io.Schema(
|
||||
node_id="ComfyNumberConvert",
|
||||
display_name="Number Convert",
|
||||
category="math",
|
||||
search_aliases=[
|
||||
"int to float", "float to int", "number convert",
|
||||
"int2float", "float2int", "cast", "parse number",
|
||||
"string to number", "bool to int",
|
||||
],
|
||||
inputs=[
|
||||
io.MultiType.Input(
|
||||
"value",
|
||||
[io.Int, io.Float, io.String, io.Boolean],
|
||||
display_name="value",
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
io.Float.Output(display_name="FLOAT"),
|
||||
io.Int.Output(display_name="INT"),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, value) -> io.NodeOutput:
|
||||
if isinstance(value, bool):
|
||||
float_val = 1.0 if value else 0.0
|
||||
elif isinstance(value, (int, float)):
|
||||
float_val = float(value)
|
||||
elif isinstance(value, str):
|
||||
text = value.strip()
|
||||
if not text:
|
||||
raise ValueError("Cannot convert empty string to number.")
|
||||
try:
|
||||
float_val = float(text)
|
||||
except ValueError:
|
||||
raise ValueError(
|
||||
f"Cannot convert string to number: {value!r}"
|
||||
) from None
|
||||
else:
|
||||
raise TypeError(
|
||||
f"Unsupported input type: {type(value).__name__}"
|
||||
)
|
||||
|
||||
if not math.isfinite(float_val):
|
||||
raise ValueError(
|
||||
f"Cannot convert non-finite value to number: {float_val}"
|
||||
)
|
||||
|
||||
return io.NodeOutput(float_val, int(float_val))
|
||||
|
||||
|
||||
class NumberConvertExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||
return [NumberConvertNode]
|
||||
|
||||
|
||||
async def comfy_entrypoint() -> NumberConvertExtension:
|
||||
return NumberConvertExtension()
|
||||
@ -1030,6 +1030,11 @@ class TrainLoraNode(io.ComfyNode):
|
||||
default="bf16",
|
||||
tooltip="The dtype to use for lora.",
|
||||
),
|
||||
io.Boolean.Input(
|
||||
"quantized_backward",
|
||||
default=False,
|
||||
tooltip="When using training_dtype 'none' and training on quantized model, doing backward with quantized matmul when enabled.",
|
||||
),
|
||||
io.Combo.Input(
|
||||
"algorithm",
|
||||
options=list(adapter_maps.keys()),
|
||||
@ -1097,6 +1102,7 @@ class TrainLoraNode(io.ComfyNode):
|
||||
seed,
|
||||
training_dtype,
|
||||
lora_dtype,
|
||||
quantized_backward,
|
||||
algorithm,
|
||||
gradient_checkpointing,
|
||||
checkpoint_depth,
|
||||
@ -1117,6 +1123,7 @@ class TrainLoraNode(io.ComfyNode):
|
||||
seed = seed[0]
|
||||
training_dtype = training_dtype[0]
|
||||
lora_dtype = lora_dtype[0]
|
||||
quantized_backward = quantized_backward[0]
|
||||
algorithm = algorithm[0]
|
||||
gradient_checkpointing = gradient_checkpointing[0]
|
||||
offloading = offloading[0]
|
||||
@ -1125,6 +1132,8 @@ class TrainLoraNode(io.ComfyNode):
|
||||
bucket_mode = bucket_mode[0]
|
||||
bypass_mode = bypass_mode[0]
|
||||
|
||||
comfy.model_management.training_fp8_bwd = quantized_backward
|
||||
|
||||
# Process latents based on mode
|
||||
if bucket_mode:
|
||||
latents = _process_latents_bucket_mode(latents)
|
||||
|
||||
2
nodes.py
2
nodes.py
@ -2454,7 +2454,9 @@ async def init_builtin_extra_nodes():
|
||||
"nodes_nag.py",
|
||||
"nodes_sdpose.py",
|
||||
"nodes_math.py",
|
||||
"nodes_number_convert.py",
|
||||
"nodes_painter.py",
|
||||
"nodes_curve.py",
|
||||
]
|
||||
|
||||
import_failed = []
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
comfyui-frontend-package==1.41.21
|
||||
comfyui-workflow-templates==0.9.26
|
||||
comfyui-frontend-package==1.42.8
|
||||
comfyui-workflow-templates==0.9.36
|
||||
comfyui-embedded-docs==0.4.3
|
||||
torch
|
||||
torchsde
|
||||
|
||||
123
tests-unit/comfy_extras_test/nodes_number_convert_test.py
Normal file
123
tests-unit/comfy_extras_test/nodes_number_convert_test.py
Normal file
@ -0,0 +1,123 @@
|
||||
import pytest
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
mock_nodes = MagicMock()
|
||||
mock_nodes.MAX_RESOLUTION = 16384
|
||||
mock_server = MagicMock()
|
||||
|
||||
with patch.dict("sys.modules", {"nodes": mock_nodes, "server": mock_server}):
|
||||
from comfy_extras.nodes_number_convert import NumberConvertNode
|
||||
|
||||
|
||||
class TestNumberConvertExecute:
|
||||
@staticmethod
|
||||
def _exec(value) -> object:
|
||||
return NumberConvertNode.execute(value)
|
||||
|
||||
# --- INT input ---
|
||||
|
||||
def test_int_input(self):
|
||||
result = self._exec(42)
|
||||
assert result[0] == 42.0
|
||||
assert result[1] == 42
|
||||
|
||||
def test_int_zero(self):
|
||||
result = self._exec(0)
|
||||
assert result[0] == 0.0
|
||||
assert result[1] == 0
|
||||
|
||||
def test_int_negative(self):
|
||||
result = self._exec(-7)
|
||||
assert result[0] == -7.0
|
||||
assert result[1] == -7
|
||||
|
||||
# --- FLOAT input ---
|
||||
|
||||
def test_float_input(self):
|
||||
result = self._exec(3.14)
|
||||
assert result[0] == 3.14
|
||||
assert result[1] == 3
|
||||
|
||||
def test_float_truncation_toward_zero(self):
|
||||
result = self._exec(-2.9)
|
||||
assert result[0] == -2.9
|
||||
assert result[1] == -2 # int() truncates toward zero, not floor
|
||||
|
||||
def test_float_output_type(self):
|
||||
result = self._exec(5)
|
||||
assert isinstance(result[0], float)
|
||||
|
||||
def test_int_output_type(self):
|
||||
result = self._exec(5.7)
|
||||
assert isinstance(result[1], int)
|
||||
|
||||
# --- BOOL input ---
|
||||
|
||||
def test_bool_true(self):
|
||||
result = self._exec(True)
|
||||
assert result[0] == 1.0
|
||||
assert result[1] == 1
|
||||
|
||||
def test_bool_false(self):
|
||||
result = self._exec(False)
|
||||
assert result[0] == 0.0
|
||||
assert result[1] == 0
|
||||
|
||||
# --- STRING input ---
|
||||
|
||||
def test_string_integer(self):
|
||||
result = self._exec("42")
|
||||
assert result[0] == 42.0
|
||||
assert result[1] == 42
|
||||
|
||||
def test_string_float(self):
|
||||
result = self._exec("3.14")
|
||||
assert result[0] == 3.14
|
||||
assert result[1] == 3
|
||||
|
||||
def test_string_negative(self):
|
||||
result = self._exec("-5.5")
|
||||
assert result[0] == -5.5
|
||||
assert result[1] == -5
|
||||
|
||||
def test_string_with_whitespace(self):
|
||||
result = self._exec(" 7.0 ")
|
||||
assert result[0] == 7.0
|
||||
assert result[1] == 7
|
||||
|
||||
def test_string_scientific_notation(self):
|
||||
result = self._exec("1e3")
|
||||
assert result[0] == 1000.0
|
||||
assert result[1] == 1000
|
||||
|
||||
# --- STRING error paths ---
|
||||
|
||||
def test_empty_string_raises(self):
|
||||
with pytest.raises(ValueError, match="Cannot convert empty string"):
|
||||
self._exec("")
|
||||
|
||||
def test_whitespace_only_string_raises(self):
|
||||
with pytest.raises(ValueError, match="Cannot convert empty string"):
|
||||
self._exec(" ")
|
||||
|
||||
def test_non_numeric_string_raises(self):
|
||||
with pytest.raises(ValueError, match="Cannot convert string to number"):
|
||||
self._exec("abc")
|
||||
|
||||
def test_string_inf_raises(self):
|
||||
with pytest.raises(ValueError, match="non-finite"):
|
||||
self._exec("inf")
|
||||
|
||||
def test_string_nan_raises(self):
|
||||
with pytest.raises(ValueError, match="non-finite"):
|
||||
self._exec("nan")
|
||||
|
||||
def test_string_negative_inf_raises(self):
|
||||
with pytest.raises(ValueError, match="non-finite"):
|
||||
self._exec("-inf")
|
||||
|
||||
# --- Unsupported type ---
|
||||
|
||||
def test_unsupported_type_raises(self):
|
||||
with pytest.raises(TypeError, match="Unsupported input type"):
|
||||
self._exec([1, 2, 3])
|
||||
Loading…
Reference in New Issue
Block a user