Rename gaussian -> splat, improve some tooltips

This commit is contained in:
kijai 2026-05-31 02:04:43 +03:00
parent dd4c7d7661
commit 4a8143f063
6 changed files with 137 additions and 140 deletions

View File

@ -5,7 +5,7 @@ from comfy_api.internal.singleton import ProxiedSingleton
from comfy_api.internal.async_to_sync import create_sync_class
from ._input import ImageInput, AudioInput, MaskInput, LatentInput, VideoInput
from ._input_impl import VideoFromFile, VideoFromComponents
from ._util import VideoCodec, VideoContainer, VideoComponents, MESH, VOXEL, GAUSSIAN, File3D
from ._util import VideoCodec, VideoContainer, VideoComponents, MESH, VOXEL, SPLAT, File3D
from . import _io_public as io
from . import _ui_public as ui
from comfy_execution.utils import get_executing_context
@ -143,7 +143,7 @@ class Types:
VideoComponents = VideoComponents
MESH = MESH
VOXEL = VOXEL
GAUSSIAN = GAUSSIAN
SPLAT = SPLAT
File3D = File3D

View File

@ -28,7 +28,7 @@ if TYPE_CHECKING:
from comfy_api.internal import (_ComfyNodeInternal, _NodeOutputInternal, classproperty, copy_class, first_real_override, is_class,
prune_dict, shallow_clone_class)
from comfy_execution.graph_utils import ExecutionBlocker
from ._util import MESH, VOXEL, GAUSSIAN, SVG as _SVG, File3D
from ._util import MESH, VOXEL, SPLAT, SVG as _SVG, File3D
class FolderType(str, Enum):
@ -684,9 +684,9 @@ class Voxel(ComfyTypeIO):
class Mesh(ComfyTypeIO):
Type = MESH
@comfytype(io_type="GAUSSIAN")
class Gaussian(ComfyTypeIO):
Type = GAUSSIAN
@comfytype(io_type="SPLAT")
class Splat(ComfyTypeIO):
Type = SPLAT
@comfytype(io_type="FILE_3D")
@ -2324,7 +2324,7 @@ __all__ = [
"LossMap",
"Voxel",
"Mesh",
"Gaussian",
"Splat",
"File3DAny",
"File3DGLB",
"File3DGLTF",

View File

@ -1,5 +1,5 @@
from .video_types import VideoContainer, VideoCodec, VideoComponents
from .geometry_types import VOXEL, MESH, GAUSSIAN, File3D
from .geometry_types import VOXEL, MESH, SPLAT, File3D
from .image_types import SVG
__all__ = [
@ -9,7 +9,7 @@ __all__ = [
"VideoComponents",
"VOXEL",
"MESH",
"GAUSSIAN",
"SPLAT",
"File3D",
"SVG",
]

View File

@ -11,7 +11,7 @@ class VOXEL:
self.data = data
class GAUSSIAN:
class SPLAT:
"""A batch of 3D Gaussian splats in render-ready (activated, world-space) form.
Tensors are (B, N, ...) and zero-padded to a common N across the batch; `counts` (B,) holds the

View File

@ -1,4 +1,4 @@
# Generic utility nodes for the GAUSSIAN type (3D gaussian splats)
# Generic utility nodes for the SPLAT type (3D gaussian splats)
import gzip
import logging
@ -17,6 +17,7 @@ import comfy.model_management
import comfy.utils
from comfy_api.latest import ComfyExtension, IO, Types
from comfy_extras.nodes_save_3d import pack_variable_mesh_batch
from server import PromptServer
_C0 = 0.28209479177387814 # SH band-0 constant: DC coefficient -> base RGB
@ -29,7 +30,7 @@ def _linear_to_srgb(c):
return torch.where(c <= 0.0031308, c * 12.92, 1.055 * c.clamp_min(0) ** (1 / 2.4) - 0.055)
def _real_len(g: Types.GAUSSIAN, i: int) -> int:
def _real_len(g: Types.SPLAT, i: int) -> int:
# Real splat count of batch item i (honors variable-length `counts`).
return int(g.counts[i].item()) if g.counts is not None else g.positions.shape[1]
@ -52,7 +53,7 @@ def _gaussian_ply_bytes(positions, scales, rotations, opacities, sh) -> bytes:
xyz = positions.cpu().numpy().astype(np.float32)
n = xyz.shape[0]
if n == 0:
raise ValueError("GaussianToFile3D: gaussian is empty")
raise ValueError("SplatToFile3D: gaussian is empty")
normals = np.zeros_like(xyz)
f = sh.cpu().numpy().astype(np.float32) # (N, K, 3)
f_dc = f[:, 0, :] # (N, 3)
@ -90,7 +91,7 @@ def _gaussian_ksplat_bytes(positions, scales, rotations, opacities, sh) -> bytes
xyz = positions.cpu().numpy().astype(np.float32)
n = xyz.shape[0]
if n == 0:
raise ValueError("GaussianToFile3D: gaussian is empty")
raise ValueError("SplatToFile3D: gaussian is empty")
scale = scales.cpu().numpy().astype(np.float32)
rot = rotations.cpu().numpy().astype(np.float32) # wxyz, mirrors the .ply rot order
rot = rot / np.linalg.norm(rot, axis=1, keepdims=True).clip(1e-12)
@ -145,7 +146,7 @@ def _gaussian_spz_bytes(positions, scales, rotations, opacities, sh) -> bytes:
xyz = positions.cpu().numpy().astype(np.float32)
n = xyz.shape[0]
if n == 0:
raise ValueError("GaussianToFile3D: gaussian is empty")
raise ValueError("SplatToFile3D: gaussian is empty")
# Positions: fixed point, masked to 24 bits, little-endian 3-byte words.
fixed = 1 << _SPZ_FRACTIONAL_BITS
@ -202,7 +203,7 @@ def _norm_quat(q):
def _parse_ply_gaussian(data: bytes):
end = data.find(b'end_header')
if end < 0:
raise ValueError("File3DToGaussian: not a PLY (missing end_header)")
raise ValueError("File3DToSplat: not a PLY (missing end_header)")
header = data[:end].decode('ascii', 'replace')
body = end + len(b'end_header')
body += 2 if data[body:body + 2] == b'\r\n' else 1
@ -212,14 +213,14 @@ def _parse_ply_gaussian(data: bytes):
if not p:
continue
if p[0] == 'format' and p[1] != 'binary_little_endian':
raise ValueError(f"File3DToGaussian: unsupported PLY format '{p[1]}' (need binary_little_endian)")
raise ValueError(f"File3DToSplat: unsupported PLY format '{p[1]}' (need binary_little_endian)")
if p[0] == 'element':
in_vertex = p[1] == 'vertex'
if in_vertex:
count = int(p[2])
elif p[0] == 'property' and in_vertex:
if p[1] == 'list':
raise ValueError("File3DToGaussian: PLY vertex has list properties (unsupported)")
raise ValueError("File3DToSplat: PLY vertex has list properties (unsupported)")
props.append((p[2], '<' + _PLY_DTYPES[p[1]]))
arr = np.frombuffer(data, np.dtype(props), count=count, offset=body)
names = arr.dtype.names
@ -257,7 +258,7 @@ def _parse_ply_gaussian(data: bytes):
def _parse_splat_gaussian(data: bytes):
# antimatter15 .splat: 32-byte records (f32 xyz, f32 scale, u8 rgba, u8 quat as (b-128)/128 wxyz).
if len(data) % 32 != 0:
raise ValueError("File3DToGaussian: .splat size is not a multiple of 32 bytes")
raise ValueError("File3DToSplat: .splat size is not a multiple of 32 bytes")
rec = np.frombuffer(data, np.dtype([('xyz', '<f4', 3), ('scale', '<f4', 3),
('rgba', 'u1', 4), ('quat', 'u1', 4)]))
rgba = rec['rgba'].astype(np.float32) / 255.0
@ -270,11 +271,11 @@ def _parse_ksplat_gaussian(data: bytes):
# mkkellogg SplatBuffer: 4096-byte header, N section headers, then per-section splat data. Supports
# levels 0 (float) / 1 (half + bucketed positions) / 2 (half, uint8 SH). SH is skipped (base color kept).
if data[0] != 0:
raise ValueError(f"File3DToGaussian: unsupported .ksplat version {data[0]}.{data[1]}")
raise ValueError(f"File3DToSplat: unsupported .ksplat version {data[0]}.{data[1]}")
max_sections = struct.unpack_from('<I', data, 4)[0]
level = struct.unpack_from('<H', data, 20)[0]
if level not in _KSPLAT_COMPRESSION:
raise ValueError(f"File3DToGaussian: invalid .ksplat compression level {level}")
raise ValueError(f"File3DToSplat: invalid .ksplat compression level {level}")
bc, bs, br, bcol, bshc, default_range = _KSPLAT_COMPRESSION[level]
parts = []
@ -321,7 +322,7 @@ def _parse_ksplat_gaussian(data: bytes):
base += bytes_per_splat * sec_max + buckets_store
if not parts:
raise ValueError("File3DToGaussian: .ksplat has no splats")
raise ValueError("File3DToSplat: .ksplat has no splats")
return tuple(np.concatenate([p[i] for p in parts]) for i in range(5))
@ -329,7 +330,7 @@ def _parse_spz_gaussian(data: bytes):
# Niantic .spz (gzip-wrapped), versions 1-3. Base color only (SH skipped). See spark's SpzReader.
raw = gzip.decompress(data)
if struct.unpack_from('<I', raw, 0)[0] != _SPZ_MAGIC:
raise ValueError("File3DToGaussian: invalid .spz (bad magic)")
raise ValueError("File3DToSplat: invalid .spz (bad magic)")
version = struct.unpack_from('<I', raw, 4)[0]
n = struct.unpack_from('<I', raw, 8)[0]
frac_bits = raw[13]
@ -345,7 +346,7 @@ def _parse_spz_gaussian(data: bytes):
xyz = (v / (1 << frac_bits)).astype(np.float32)
off += n * 9
else:
raise ValueError(f"File3DToGaussian: unsupported .spz version {version}")
raise ValueError(f"File3DToSplat: unsupported .spz version {version}")
alpha = np.frombuffer(raw, np.uint8, count=n, offset=off).astype(np.float32) / 255.0
off += n
@ -394,10 +395,10 @@ def _detect_splat_format(data: bytes) -> str:
return "ksplat"
if len(data) % 32 == 0:
return "splat"
raise ValueError("File3DToGaussian: could not determine splat format from contents")
raise ValueError("File3DToSplat: could not determine splat format from contents")
def _gaussian_item(g: Types.GAUSSIAN, i: int, device):
def _gaussian_item(g: Types.SPLAT, i: int, device):
# Slice batch item i to its real length, as float32 torch tensors on `device` (SH DC -> base RGB).
end = _real_len(g, i)
to = lambda a: a.to(device=device, dtype=torch.float32)
@ -461,49 +462,48 @@ def _mat_to_quat(m):
return q / q.norm(dim=-1, keepdim=True).clamp_min(1e-12)
class GaussianToFile3D(IO.ComfyNode):
class SplatToFile3D(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="GaussianToFile3D",
display_name="Create 3D File (from Gaussian)",
node_id="SplatToFile3D",
display_name="Create 3D File (from Splat)",
search_aliases=["gaussian to ply", "splat to file", "export gaussian"],
category="3d/gaussian",
description="Serialize a gaussian splat to an in-memory File3D, for Save 3D Model / Preview 3D. "
"ply keeps full SH (standard 3DGS); ksplat and spz are compact viewer formats (base "
description="Serialize a gaussian splat to a File3D object for Save / Preview 3D nodes. "
"ply keeps full spherical harmonics (standard 3DGS); ksplat and spz are compact viewer formats (base "
"color only). Single splat only - feed one batch item at a time.",
inputs=[
IO.Gaussian.Input("gaussian"),
IO.Splat.Input("splat"),
IO.Combo.Input("format", options=["ply", "ksplat", "spz"],
tooltip="ply: standard 3DGS with full spherical harmonics. ksplat: mkkellogg "
"SplatBuffer (level 0, uncompressed). spz: Niantic gzip-compressed "
"(~10x smaller). ksplat/spz keep base color only - view-dependent SH "
"is dropped."),
"(~10x smaller). ksplat/spz keep base color only - view-dependent spherical harmonics is dropped."),
],
outputs=[IO.File3DAny.Output(display_name="model_3d")],
)
@classmethod
def execute(cls, gaussian, format="ply") -> IO.NodeOutput:
if gaussian.positions.shape[0] > 1:
logging.warning("GaussianToFile3D: got a batch of %d; converting only the first splat (File3D is a "
"single file).", gaussian.positions.shape[0])
end = _real_len(gaussian, 0)
def execute(cls, splat, format="ply") -> IO.NodeOutput:
if splat.positions.shape[0] > 1:
logging.warning("SplatToFile3D: got a batch of %d; converting only the first splat (File3D is a "
"single file).", splat.positions.shape[0])
end = _real_len(splat, 0)
writer = {"ksplat": _gaussian_ksplat_bytes, "spz": _gaussian_spz_bytes}.get(format, _gaussian_ply_bytes)
data = writer(gaussian.positions[0, :end], gaussian.scales[0, :end],
gaussian.rotations[0, :end], gaussian.opacities[0, :end], gaussian.sh[0, :end])
data = writer(splat.positions[0, :end], splat.scales[0, :end],
splat.rotations[0, :end], splat.opacities[0, :end], splat.sh[0, :end])
return IO.NodeOutput(Types.File3D(BytesIO(data), file_format=format))
class File3DToGaussian(IO.ComfyNode):
class File3DToSplat(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="File3DToGaussian",
display_name="Get Gaussian Splat",
node_id="File3DToSplat",
display_name="Get Splat",
search_aliases=["load splat", "ply to gaussian", "import gaussian", "file to splat"],
category="3d/gaussian",
description="Parse a splat File3D (.ply / .splat / .ksplat / .spz) into a GAUSSIAN. Inverse of "
description="Parse a splat File3D (.ply / .splat / .ksplat / .spz) into a gaussian. Inverse of "
"Create 3D File (from Gaussian). ply carries full spherical harmonics; the others are base "
"color only. Format is auto-detected from the file contents.",
inputs=[
@ -513,7 +513,7 @@ class File3DToGaussian(IO.ComfyNode):
tooltip="A gaussian-splat 3D file",
),
],
outputs=[IO.Gaussian.Output(display_name="gaussian")],
outputs=[IO.Splat.Output(display_name="splat")],
)
@classmethod
@ -524,14 +524,14 @@ class File3DToGaussian(IO.ComfyNode):
xyz, scale, rot, opacity, sh = parser(data)
t = lambda a: torch.from_numpy(np.ascontiguousarray(a)).float()
gaussian = Types.GAUSSIAN(
splat = Types.SPLAT(
t(xyz)[None], # (1, N, 3)
t(scale)[None], # (1, N, 3) linear
t(rot)[None], # (1, N, 4) wxyz
t(opacity).reshape(1, -1, 1), # (1, N, 1)
t(sh)[None], # (1, N, K, 3)
)
return IO.NodeOutput(gaussian)
return IO.NodeOutput(splat)
def _view_matrix_t(yaw_deg, pitch_deg, device):
@ -572,7 +572,7 @@ def _gauss_blur(x, sigma, dev):
def _render_gaussian(xyz, rgb, opacity, scale, rot, width, height, fov, splat_scale, bg, sharpen=1.0,
headlight_shading=0.0, render_style="color", camera_info=None,
yaw=35.0, pitch=30.0, zoom=1.0):
yaw=35.0, pitch=30.0, zoom=1.0, distance=0.0):
# Perspective-correct anisotropic gaussian-splat rasterizer. Each splat is weighted by its 3D Gaussian's
# peak along each pixel's ray (AAA / Hahlbohm), composited front-to-back across depth slabs. `render_style`
# selects the image: color / clay / depth / normal. Returns (image HxWx3, coverage mask HxW) on CPU.
@ -609,7 +609,8 @@ def _render_gaussian(xyz, rgb, opacity, scale, rot, width, height, fov, splat_sc
fov = fov if fov > 0 else 35.0 # fov=0 -> default 35
center = xyz.mean(0)
extent = (xyz - center).norm(dim=-1).quantile(0.99).clamp_min(1e-4) # ignore outlier floaters
dist = extent / (math.tan(math.radians(fov) / 2) * 0.9) / max(zoom, 1e-3)
base = distance if distance > 0 else extent / (math.tan(math.radians(fov) / 2) * 0.9) # absolute dist, else auto-frame
dist = base / max(zoom, 1e-3)
W = _view_matrix_t(yaw, pitch, dev)
cam = (xyz - center) @ W.T + torch.tensor([0.0, 0.0, dist], device=dev)
yflip = 1.0
@ -775,19 +776,20 @@ def _render_gaussian(xyz, rgb, opacity, scale, rot, width, height, fov, splat_sc
return img.clamp(0, 1).cpu(), covg.clamp(0, 1).cpu()
class RenderGaussian(IO.ComfyNode):
class RenderSplat(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="RenderGaussian",
display_name="Render Gaussian Splat",
node_id="RenderSplat",
display_name="Render Splat",
search_aliases=["splat to image", "render splat", "gaussian turntable"],
category="3d/gaussian",
description="Render a gaussian splat to an image with an anisotropic EWA rasterizer (oriented "
"elliptical splats, antialiased, depth-sorted front-to-back). frames>1 sweeps yaw a full "
"360 turn, producing an image batch (turntable) you can pipe into a video node.",
"elliptical splats, antialiased, depth-sorted front-to-back). Set frames greater than 1 "
"to sweep the camera yaw through a full 360° rotation, producing a batch of images "
"(a turntable) that you can feed into a video node.",
inputs=[
IO.Gaussian.Input("gaussian"),
IO.Splat.Input("splat"),
IO.Int.Input("width", default=1024, min=64, max=2048, step=8),
IO.Int.Input("height", default=1024, min=64, max=2048, step=8),
IO.Int.Input("frames", default=1, min=-240, max=240,
@ -796,11 +798,13 @@ class RenderGaussian(IO.ComfyNode):
IO.Float.Input("yaw", default=35.0, min=-360.0, max=360.0, step=1.0),
IO.Float.Input("pitch", default=30.0, min=-89.0, max=89.0, step=1.0),
IO.Float.Input("zoom", default=1.0, min=0.1, max=5.0, step=0.05,
tooltip="Camera dolly: >1 zooms in, <1 out. Without camera_info, 1.0 frames the whole "
"splat (~10% margin); with camera_info, 1.0 is exactly the supplied camera."),
tooltip="Camera dolly: >1 zooms in, <1 out. With camera_info or distance, 1.0 is exactly "
"that camera; otherwise 1.0 frames the whole splat (~10% margin)."),
IO.Float.Input("distance", default=0.0, min=0.0, max=1000.0, step=0.01,
tooltip="Absolute camera distance for the yaw/pitch orbit (from Get Camera Info). "
"0 = auto-frame the whole splat. Ignored when camera_info is connected."),
IO.Float.Input("fov", default=0.0, min=0.0, max=120.0, step=1.0,
tooltip="Vertical field of view in degrees. 0 = auto: 35, or taken from camera_info "
"when connected. Any value >0 overrides (including over camera_info)."),
tooltip="Vertical field of view in degrees. 0 = camera_info if provided, otherwise defaults to 35. Any value above 0 overrides the camera_info FoV."),
IO.Float.Input("splat_scale", default=1.0, min=0.1, max=5.0, step=0.05, advanced=True,
tooltip="Multiplier on each splat's projected footprint (lower = crisper points, "
"higher = softer/fuller surface)."),
@ -832,8 +836,8 @@ class RenderGaussian(IO.ComfyNode):
)
@classmethod
def execute(cls, gaussian, width, height, yaw, pitch, frames, zoom, fov, splat_scale, sharpen,
headlight_shading, opacity_threshold, background, render_style,
def execute(cls, splat, width, height, yaw, pitch, frames, zoom, fov, splat_scale, sharpen,
headlight_shading, opacity_threshold, background, render_style, distance=0.0,
camera_info=None, bg_image=None) -> IO.NodeOutput:
bg = _hex_to_rgb(background)
bg_imgs = None
@ -844,17 +848,17 @@ class RenderGaussian(IO.ComfyNode):
orbit_dir = -1.0 if frames < 0 else 1.0 # sign = orbit direction
if camera_info is not None:
if n_frames > 1:
logging.warning("RenderGaussian: camera_info is a fixed camera; ignoring frames=%d (no orbit sweep).", frames)
logging.warning("RenderSplat: camera_info is a fixed camera; ignoring frames=%d (no orbit sweep).", frames)
n_frames = 1
if str(camera_info.get("cameraType", "")).lower().startswith("ortho"):
logging.warning("RenderGaussian: orthographic camera_info is rendered with a perspective camera.")
logging.warning("RenderSplat: orthographic camera_info is rendered with a perspective camera.")
imgs, masks = [], []
device = comfy.model_management.get_torch_device() # render device; splat stays in torch here -> no roundtrip
total = gaussian.positions.shape[0] * n_frames
total = splat.positions.shape[0] * n_frames
pbar = comfy.utils.ProgressBar(total) if total > 1 else None
k = 0
for i in range(gaussian.positions.shape[0]):
xyz, rgb, opacity, scale, rot = _gaussian_item(gaussian, i, device)
for i in range(splat.positions.shape[0]):
xyz, rgb, opacity, scale, rot = _gaussian_item(splat, i, device)
if opacity_threshold > 0:
keep = opacity >= opacity_threshold
xyz, rgb, opacity, scale, rot = xyz[keep], rgb[keep], opacity[keep], scale[keep], rot[keep]
@ -863,7 +867,8 @@ class RenderGaussian(IO.ComfyNode):
bg_k = bg_imgs[k % bg_imgs.shape[0]] if bg_imgs is not None else bg # per-frame plate, or solid colour
img, mask = _render_gaussian(xyz, rgb, opacity, scale, rot, width, height, fov, splat_scale, bg_k,
sharpen=sharpen, headlight_shading=headlight_shading,
render_style=render_style, camera_info=camera_info, yaw=y, pitch=pitch, zoom=zoom)
render_style=render_style, camera_info=camera_info, yaw=y, pitch=pitch,
zoom=zoom, distance=distance)
imgs.append(img)
masks.append(mask)
k += 1
@ -872,19 +877,18 @@ class RenderGaussian(IO.ComfyNode):
return IO.NodeOutput(torch.stack(imgs), torch.stack(masks))
class TransformGaussian(IO.ComfyNode):
class TransformSplat(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="TransformGaussian",
display_name="Transform Gaussian Splat",
node_id="TransformSplat",
display_name="Transform Splat",
search_aliases=["move splat", "rotate splat", "scale splat", "gaussian transform"],
category="3d/gaussian",
description="Translate, rotate (Euler XYZ degrees) and scale (per-axis) a gaussian splat. Positions, "
"per-splat rotations and scales transform consistently; non-uniform scale re-derives each "
"splat's covariance (eigendecomposition) so the ellipsoids deform correctly.",
description="Translate, rotate, and scale a gaussian splat."
"Non-uniform scale also reshapes every individual splat, slower process.",
inputs=[
IO.Gaussian.Input("gaussian"),
IO.Splat.Input("splat"),
IO.Float.Input("translate_x", default=0.0, min=-100.0, max=100.0, step=0.01),
IO.Float.Input("translate_y", default=0.0, min=-100.0, max=100.0, step=0.01),
IO.Float.Input("translate_z", default=0.0, min=-100.0, max=100.0, step=0.01),
@ -895,13 +899,13 @@ class TransformGaussian(IO.ComfyNode):
IO.Float.Input("scale_y", default=1.0, min=0.01, max=100.0, step=0.01),
IO.Float.Input("scale_z", default=1.0, min=0.01, max=100.0, step=0.01),
],
outputs=[IO.Gaussian.Output(display_name="gaussian")],
outputs=[IO.Splat.Output(display_name="splat")],
)
@classmethod
def execute(cls, gaussian, translate_x, translate_y, translate_z,
def execute(cls, splat, translate_x, translate_y, translate_z,
rotate_x, rotate_y, rotate_z, scale_x, scale_y, scale_z) -> IO.NodeOutput:
pos = gaussian.positions
pos = splat.positions
dev, dt = pos.device, pos.dtype
q_rot = _euler_to_quat(rotate_x, rotate_y, rotate_z).to(device=dev, dtype=dt)
R = _quat_to_mat(q_rot[None])[0] # (3, 3) node rotation
@ -911,51 +915,45 @@ class TransformGaussian(IO.ComfyNode):
positions = pos @ A.T + t # rotate, scale per-axis, then translate
if scale_x == scale_y == scale_z: # uniform: rotation/scale factor out cleanly
scales = gaussian.scales * scale_x
rotations = _quat_mul(q_rot.expand_as(gaussian.rotations), gaussian.rotations)
scales = splat.scales * scale_x
rotations = _quat_mul(q_rot.expand_as(splat.rotations), splat.rotations)
rotations = rotations / rotations.norm(dim=-1, keepdim=True).clamp_min(1e-12)
else: # non-uniform: transform Sigma = A R s^2 R^T A^T, re-extract
rg = _quat_to_mat(gaussian.rotations.reshape(-1, 4)) # (M,3,3) per-splat rotation
s2 = gaussian.scales.reshape(-1, 3).square()
cov = (rg * s2[:, None, :]) @ rg.transpose(-1, -2) # Sigma
cov = A @ cov @ A.T # A Sigma A^T (A broadcast over splats)
rg = _quat_to_mat(splat.rotations.reshape(-1, 4)) # (M,3,3) per-splat rotation
s2 = splat.scales.reshape(-1, 3).square()
cov = (rg * s2[:, None, :]) @ rg.transpose(-1, -2) # Sigma
cov = A @ cov @ A.T # A Sigma A^T (A broadcast over splats)
lam, V = torch.linalg.eigh(cov) # symmetric -> eigenvalues (asc), orthonormal axes
V = V * torch.where(torch.linalg.det(V) < 0, -1.0, 1.0)[..., None, None] # keep a proper rotation
scales = lam.clamp_min(0).sqrt().reshape(gaussian.scales.shape)
rotations = _mat_to_quat(V).reshape(gaussian.rotations.shape)
out = Types.GAUSSIAN(positions, scales, rotations, gaussian.opacities, gaussian.sh,
counts=getattr(gaussian, "counts", None))
scales = lam.clamp_min(0).sqrt().reshape(splat.scales.shape)
rotations = _mat_to_quat(V).reshape(splat.rotations.shape)
out = Types.SPLAT(positions, scales, rotations, splat.opacities, splat.sh,
counts=getattr(splat, "counts", None))
return IO.NodeOutput(out)
class GaussianInfo(IO.ComfyNode):
class GetSplatCount(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="GaussianInfo",
display_name="Gaussian Splat Info",
search_aliases=["splat stats", "gaussian count", "splat info"],
node_id="GetSplatCount",
display_name="Get Splat Count",
search_aliases=["splat count", "gaussian count", "number of splats", "splat info"],
category="3d/gaussian",
description="Report per-splat stats: count, bounding box, and opacity/scale ranges.",
inputs=[IO.Gaussian.Input("gaussian")],
outputs=[IO.String.Output(display_name="info")],
description="Returns the number of splats (summed across the batch) and shows it on the node.",
inputs=[IO.Splat.Input("splat")],
outputs=[IO.Splat.Output(display_name="splat"),
IO.Int.Output(display_name="count"),
],
hidden=[IO.Hidden.unique_id],
)
@classmethod
def execute(cls, gaussian) -> IO.NodeOutput:
lines = []
for i in range(gaussian.positions.shape[0]):
xyz, _, opacity, scale, _ = _gaussian_item(gaussian, i, torch.device("cpu"))
lo, hi = xyz.amin(0), xyz.amax(0)
fmt = lambda v: "[" + ", ".join(f"{x:.3f}" for x in v) + "]"
lines.append(
f"gaussian[{i}]: count={xyz.shape[0]}\n"
f" aabb min={fmt(lo)} max={fmt(hi)} size={fmt(hi - lo)}\n"
f" opacity mean={opacity.mean():.3f} min={opacity.min():.3f} max={opacity.max():.3f}\n"
f" scale mean={scale.mean():.4f} min={scale.min():.4f} max={scale.max():.4f}"
)
info = "\n".join(lines)
return IO.NodeOutput(info, ui={"text": [info]})
def execute(cls, splat) -> IO.NodeOutput:
count = sum(_real_len(splat, i) for i in range(splat.positions.shape[0]))
if cls.hidden.unique_id: # show the count inline on the node
PromptServer.instance.send_progress_text(f"{count:,} splats", cls.hidden.unique_id)
return IO.NodeOutput(splat, count)
def _pad_stack(items, n):
@ -967,15 +965,15 @@ def _pad_stack(items, n):
return out
def _merge_gaussians(gaussians: list) -> Types.GAUSSIAN:
# Concatenate GAUSSIAN batches along the splat dimension (per item), padding SH to the highest degree.
def _merge_gaussians(gaussians: list) -> Types.SPLAT:
# Concatenate SPLAT batches along the splat dimension (per item), padding SH to the highest degree.
gs = [g for g in gaussians if g is not None]
if not gs:
raise ValueError("MergeGaussian: no gaussians to merge")
raise ValueError("MergeSplat: no gaussians to merge")
b = gs[0].positions.shape[0]
for g in gs:
if g.positions.shape[0] != b:
raise ValueError(f"MergeGaussian: batch size mismatch ({b} vs {g.positions.shape[0]}).")
raise ValueError(f"MergeSplat: batch size mismatch ({b} vs {g.positions.shape[0]}).")
max_k = max(g.sh.shape[2] for g in gs)
pos_b, scl_b, rot_b, op_b, sh_b, lengths = [], [], [], [], [], []
@ -1002,32 +1000,31 @@ def _merge_gaussians(gaussians: list) -> Types.GAUSSIAN:
counts = None
if len(set(lengths)) > 1:
counts = torch.tensor(lengths, device=gs[0].positions.device, dtype=torch.int64)
return Types.GAUSSIAN(_pad_stack(pos_b, n), _pad_stack(scl_b, n), _pad_stack(rot_b, n),
return Types.SPLAT(_pad_stack(pos_b, n), _pad_stack(scl_b, n), _pad_stack(rot_b, n),
_pad_stack(op_b, n), _pad_stack(sh_b, n), counts=counts)
class MergeGaussian(IO.ComfyNode):
class MergeSplat(IO.ComfyNode):
@classmethod
def define_schema(cls):
# Autogrow: a gaussian0/gaussian1/... input list that grows a fresh slot as you connect splats.
gaussians = IO.Autogrow.TemplatePrefix(IO.Gaussian.Input("gaussian"), prefix="gaussian", min=2, max=32)
# Autogrow: a splat0/splat1/... input list that grows a fresh slot as you connect splats.
splats = IO.Autogrow.TemplatePrefix(IO.Splat.Input("splat"), prefix="splat", min=2, max=32)
return IO.Schema(
node_id="MergeGaussian",
display_name="Merge Gaussian Splats",
node_id="MergeSplat",
display_name="Merge Splats",
search_aliases=["union splat", "densify gaussian", "combine splat", "merge gaussian"],
category="3d/gaussian",
description="Concatenate any number of gaussian splats into one (per batch item). Because the "
"TripoSplat decoder samples points stochastically, unioning several decodes of the same "
"latent at different seeds densifies the surface - feed them here, then mesh the result.",
inputs=[IO.Autogrow.Input("gaussians", template=gaussians)],
outputs=[IO.Gaussian.Output(display_name="gaussian")],
description="Concatenate any number of gaussian splats into one. Unioning several decodes of the same "
"latent at different seeds densifies the surface, this can improve surface quality when meshing.",
inputs=[IO.Autogrow.Input("splats", template=splats)],
outputs=[IO.Splat.Output(display_name="splat")],
)
@classmethod
def execute(cls, gaussians: IO.Autogrow.Type) -> IO.NodeOutput:
gs = [v for v in gaussians.values() if v is not None]
def execute(cls, splats: IO.Autogrow.Type) -> IO.NodeOutput:
gs = [v for v in splats.values() if v is not None]
if not gs:
raise ValueError("MergeGaussian: connect at least one gaussian splat.")
raise ValueError("MergeSplat: connect at least one splat.")
return IO.NodeOutput(_merge_gaussians(gs))
@ -1232,7 +1229,7 @@ def _taubin_smooth(verts, faces, iters, lam=0.5, mu=-0.53):
return np.ascontiguousarray(v.astype(np.float32))
def _gaussian_to_mesh(g: Types.GAUSSIAN, i, res, kernel, taubin, level_bias, min_component, min_opacity, color_sharpen, device, progress=None):
def _gaussian_to_mesh(g: Types.SPLAT, i, res, kernel, taubin, level_bias, min_component, min_opacity, color_sharpen, device, progress=None):
# Mesh one splat: density + colour grids -> Surface Nets -> floater removal -> Taubin smoothing ->
# volume-sampled colours. Returns (verts, faces int64, colors in [0,1]), or None if no surface.
rep = progress if progress is not None else (lambda *_: None)
@ -1297,12 +1294,12 @@ def _gaussian_to_mesh(g: Types.GAUSSIAN, i, res, kernel, taubin, level_bias, min
return (torch.from_numpy(verts), torch.from_numpy(faces), torch.from_numpy(col))
class GaussianToMesh(IO.ComfyNode):
class SplatToMesh(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="GaussianToMesh",
display_name="Gaussian Splat to Mesh",
node_id="SplatToMesh",
display_name="Extract Mesh from Splat",
search_aliases=["splat to mesh", "gaussian surface nets", "splat surface", "mesh splat"],
category="3d/gaussian",
description="Extract a coloured triangle MESH from a gaussian splat. Each splat is rasterized into a "
@ -1310,7 +1307,7 @@ class GaussianToMesh(IO.ComfyNode):
"tiny floaters are dropped, and vertices are coloured from their nearest gaussians. Denser "
"splats give more detail - union several decodes with Merge Gaussian Splats first.",
inputs=[
IO.Gaussian.Input("gaussian"),
IO.Splat.Input("splat"),
IO.Int.Input("resolution", default=512, min=64, max=1024, step=16,
tooltip="Density-grid resolution along the longest axis. Higher = finer surface, "
"more VRAM/time (grows with resolution^3)."),
@ -1318,7 +1315,7 @@ class GaussianToMesh(IO.ComfyNode):
tooltip="Max splat half-width in voxels. Each gaussian is rasterized over a window "
"sized to its own 3-sigma, capped here - small surfels stay cheap, large ones "
"aren't truncated. Raise if sparse splats leave gaps."),
IO.Int.Input("smooth", default=0, min=0, max=60,
IO.Int.Input("smooth", default=0, min=0, max=60, advanced = True,
tooltip="Taubin mesh-smoothing iterations. Smooths the surface without shrinking it "
"(volume-preserving), unlike blurring the density. 0 = raw surface."),
IO.Float.Input("level", default=0.6, min=0.3, max=2.0, step=0.05,
@ -1338,18 +1335,18 @@ class GaussianToMesh(IO.ComfyNode):
)
@classmethod
def execute(cls, gaussian, resolution, kernel, smooth, level, min_component, min_opacity, color_sharpen) -> IO.NodeOutput:
def execute(cls, splat, resolution, kernel, smooth, level, min_component, min_opacity, color_sharpen) -> IO.NodeOutput:
device = comfy.model_management.get_torch_device()
b = gaussian.positions.shape[0]
b = splat.positions.shape[0]
prec = 1000 # each splat owns a 0..prec block of the bar; its callback advances within that block
pbar = comfy.utils.ProgressBar(b * prec)
verts_l, faces_l, colors_l = [], [], []
for i in range(b):
cb = lambda f, base=i * prec: pbar.update_absolute(base + int(min(max(f, 0.0), 1.0) * prec))
res = _gaussian_to_mesh(gaussian, i, resolution, kernel, smooth, level, min_component, min_opacity, color_sharpen, device, cb)
res = _gaussian_to_mesh(splat, i, resolution, kernel, smooth, level, min_component, min_opacity, color_sharpen, device, cb)
if res is None:
logging.warning("GaussianToMesh: splat %d produced no surface; emitting an empty mesh.", i)
logging.warning("SplatToMesh: splat %d produced no surface; emitting an empty mesh.", i)
v, f, c = torch.zeros((0, 3)), torch.zeros((0, 3), dtype=torch.int64), torch.zeros((0, 3))
else:
v, f, c = res
@ -1364,8 +1361,8 @@ class GaussianToMesh(IO.ComfyNode):
class GaussianExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
return [GaussianToFile3D, File3DToGaussian, RenderGaussian, TransformGaussian, GaussianInfo,
MergeGaussian, GaussianToMesh]
return [SplatToFile3D, File3DToSplat, RenderSplat, TransformSplat, GetSplatCount,
MergeSplat, SplatToMesh]
async def comfy_entrypoint() -> GaussianExtension:

View File

@ -2455,7 +2455,7 @@ async def init_builtin_extra_nodes():
"nodes_save_3d.py",
"nodes_moge.py",
"nodes_mediapipe.py",
"nodes_gaussian.py",
"nodes_gaussian_splat.py",
]
import_failed = []