From 4a8143f06397af2ab23a9efc9112765997971ec6 Mon Sep 17 00:00:00 2001 From: kijai <40791699+kijai@users.noreply.github.com> Date: Sun, 31 May 2026 02:04:43 +0300 Subject: [PATCH] Rename gaussian -> splat, improve some tooltips --- comfy_api/latest/__init__.py | 4 +- comfy_api/latest/_io.py | 10 +- comfy_api/latest/_util/__init__.py | 4 +- comfy_api/latest/_util/geometry_types.py | 2 +- ...es_gaussian.py => nodes_gaussian_splat.py} | 255 +++++++++--------- nodes.py | 2 +- 6 files changed, 137 insertions(+), 140 deletions(-) rename comfy_extras/{nodes_gaussian.py => nodes_gaussian_splat.py} (88%) diff --git a/comfy_api/latest/__init__.py b/comfy_api/latest/__init__.py index a4c02b8db..294ad425e 100644 --- a/comfy_api/latest/__init__.py +++ b/comfy_api/latest/__init__.py @@ -5,7 +5,7 @@ from comfy_api.internal.singleton import ProxiedSingleton from comfy_api.internal.async_to_sync import create_sync_class from ._input import ImageInput, AudioInput, MaskInput, LatentInput, VideoInput from ._input_impl import VideoFromFile, VideoFromComponents -from ._util import VideoCodec, VideoContainer, VideoComponents, MESH, VOXEL, GAUSSIAN, File3D +from ._util import VideoCodec, VideoContainer, VideoComponents, MESH, VOXEL, SPLAT, File3D from . import _io_public as io from . 
import _ui_public as ui from comfy_execution.utils import get_executing_context @@ -143,7 +143,7 @@ class Types: VideoComponents = VideoComponents MESH = MESH VOXEL = VOXEL - GAUSSIAN = GAUSSIAN + SPLAT = SPLAT File3D = File3D diff --git a/comfy_api/latest/_io.py b/comfy_api/latest/_io.py index 9a6b98692..a3aa508ce 100644 --- a/comfy_api/latest/_io.py +++ b/comfy_api/latest/_io.py @@ -28,7 +28,7 @@ if TYPE_CHECKING: from comfy_api.internal import (_ComfyNodeInternal, _NodeOutputInternal, classproperty, copy_class, first_real_override, is_class, prune_dict, shallow_clone_class) from comfy_execution.graph_utils import ExecutionBlocker -from ._util import MESH, VOXEL, GAUSSIAN, SVG as _SVG, File3D +from ._util import MESH, VOXEL, SPLAT, SVG as _SVG, File3D class FolderType(str, Enum): @@ -684,9 +684,9 @@ class Voxel(ComfyTypeIO): class Mesh(ComfyTypeIO): Type = MESH -@comfytype(io_type="GAUSSIAN") -class Gaussian(ComfyTypeIO): - Type = GAUSSIAN +@comfytype(io_type="SPLAT") +class Splat(ComfyTypeIO): + Type = SPLAT @comfytype(io_type="FILE_3D") @@ -2324,7 +2324,7 @@ __all__ = [ "LossMap", "Voxel", "Mesh", - "Gaussian", + "Splat", "File3DAny", "File3DGLB", "File3DGLTF", diff --git a/comfy_api/latest/_util/__init__.py b/comfy_api/latest/_util/__init__.py index acf615c39..b27f5a97e 100644 --- a/comfy_api/latest/_util/__init__.py +++ b/comfy_api/latest/_util/__init__.py @@ -1,5 +1,5 @@ from .video_types import VideoContainer, VideoCodec, VideoComponents -from .geometry_types import VOXEL, MESH, GAUSSIAN, File3D +from .geometry_types import VOXEL, MESH, SPLAT, File3D from .image_types import SVG __all__ = [ @@ -9,7 +9,7 @@ __all__ = [ "VideoComponents", "VOXEL", "MESH", - "GAUSSIAN", + "SPLAT", "File3D", "SVG", ] diff --git a/comfy_api/latest/_util/geometry_types.py b/comfy_api/latest/_util/geometry_types.py index a49c15536..84a18d69a 100644 --- a/comfy_api/latest/_util/geometry_types.py +++ b/comfy_api/latest/_util/geometry_types.py @@ -11,7 +11,7 @@ class VOXEL: self.data 
= data -class GAUSSIAN: +class SPLAT: """A batch of 3D Gaussian splats in render-ready (activated, world-space) form. Tensors are (B, N, ...) and zero-padded to a common N across the batch; `counts` (B,) holds the diff --git a/comfy_extras/nodes_gaussian.py b/comfy_extras/nodes_gaussian_splat.py similarity index 88% rename from comfy_extras/nodes_gaussian.py rename to comfy_extras/nodes_gaussian_splat.py index 3cf969abe..bc45db05e 100644 --- a/comfy_extras/nodes_gaussian.py +++ b/comfy_extras/nodes_gaussian_splat.py @@ -1,4 +1,4 @@ -# Generic utility nodes for the GAUSSIAN type (3D gaussian splats) +# Generic utility nodes for the SPLAT type (3D gaussian splats) import gzip import logging @@ -17,6 +17,7 @@ import comfy.model_management import comfy.utils from comfy_api.latest import ComfyExtension, IO, Types from comfy_extras.nodes_save_3d import pack_variable_mesh_batch +from server import PromptServer _C0 = 0.28209479177387814 # SH band-0 constant: DC coefficient -> base RGB @@ -29,7 +30,7 @@ def _linear_to_srgb(c): return torch.where(c <= 0.0031308, c * 12.92, 1.055 * c.clamp_min(0) ** (1 / 2.4) - 0.055) -def _real_len(g: Types.GAUSSIAN, i: int) -> int: +def _real_len(g: Types.SPLAT, i: int) -> int: # Real splat count of batch item i (honors variable-length `counts`). 
return int(g.counts[i].item()) if g.counts is not None else g.positions.shape[1] @@ -52,7 +53,7 @@ def _gaussian_ply_bytes(positions, scales, rotations, opacities, sh) -> bytes: xyz = positions.cpu().numpy().astype(np.float32) n = xyz.shape[0] if n == 0: - raise ValueError("GaussianToFile3D: gaussian is empty") + raise ValueError("SplatToFile3D: gaussian is empty") normals = np.zeros_like(xyz) f = sh.cpu().numpy().astype(np.float32) # (N, K, 3) f_dc = f[:, 0, :] # (N, 3) @@ -90,7 +91,7 @@ def _gaussian_ksplat_bytes(positions, scales, rotations, opacities, sh) -> bytes xyz = positions.cpu().numpy().astype(np.float32) n = xyz.shape[0] if n == 0: - raise ValueError("GaussianToFile3D: gaussian is empty") + raise ValueError("SplatToFile3D: gaussian is empty") scale = scales.cpu().numpy().astype(np.float32) rot = rotations.cpu().numpy().astype(np.float32) # wxyz, mirrors the .ply rot order rot = rot / np.linalg.norm(rot, axis=1, keepdims=True).clip(1e-12) @@ -145,7 +146,7 @@ def _gaussian_spz_bytes(positions, scales, rotations, opacities, sh) -> bytes: xyz = positions.cpu().numpy().astype(np.float32) n = xyz.shape[0] if n == 0: - raise ValueError("GaussianToFile3D: gaussian is empty") + raise ValueError("SplatToFile3D: gaussian is empty") # Positions: fixed point, masked to 24 bits, little-endian 3-byte words. 
fixed = 1 << _SPZ_FRACTIONAL_BITS @@ -202,7 +203,7 @@ def _norm_quat(q): def _parse_ply_gaussian(data: bytes): end = data.find(b'end_header') if end < 0: - raise ValueError("File3DToGaussian: not a PLY (missing end_header)") + raise ValueError("File3DToSplat: not a PLY (missing end_header)") header = data[:end].decode('ascii', 'replace') body = end + len(b'end_header') body += 2 if data[body:body + 2] == b'\r\n' else 1 @@ -212,14 +213,14 @@ def _parse_ply_gaussian(data: bytes): if not p: continue if p[0] == 'format' and p[1] != 'binary_little_endian': - raise ValueError(f"File3DToGaussian: unsupported PLY format '{p[1]}' (need binary_little_endian)") + raise ValueError(f"File3DToSplat: unsupported PLY format '{p[1]}' (need binary_little_endian)") if p[0] == 'element': in_vertex = p[1] == 'vertex' if in_vertex: count = int(p[2]) elif p[0] == 'property' and in_vertex: if p[1] == 'list': - raise ValueError("File3DToGaussian: PLY vertex has list properties (unsupported)") + raise ValueError("File3DToSplat: PLY vertex has list properties (unsupported)") props.append((p[2], '<' + _PLY_DTYPES[p[1]])) arr = np.frombuffer(data, np.dtype(props), count=count, offset=body) names = arr.dtype.names @@ -257,7 +258,7 @@ def _parse_ply_gaussian(data: bytes): def _parse_splat_gaussian(data: bytes): # antimatter15 .splat: 32-byte records (f32 xyz, f32 scale, u8 rgba, u8 quat as (b-128)/128 wxyz). 
if len(data) % 32 != 0: - raise ValueError("File3DToGaussian: .splat size is not a multiple of 32 bytes") + raise ValueError("File3DToSplat: .splat size is not a multiple of 32 bytes") rec = np.frombuffer(data, np.dtype([('xyz', ' str: return "ksplat" if len(data) % 32 == 0: return "splat" - raise ValueError("File3DToGaussian: could not determine splat format from contents") + raise ValueError("File3DToSplat: could not determine splat format from contents") -def _gaussian_item(g: Types.GAUSSIAN, i: int, device): +def _gaussian_item(g: Types.SPLAT, i: int, device): # Slice batch item i to its real length, as float32 torch tensors on `device` (SH DC -> base RGB). end = _real_len(g, i) to = lambda a: a.to(device=device, dtype=torch.float32) @@ -461,49 +462,48 @@ def _mat_to_quat(m): return q / q.norm(dim=-1, keepdim=True).clamp_min(1e-12) -class GaussianToFile3D(IO.ComfyNode): +class SplatToFile3D(IO.ComfyNode): @classmethod def define_schema(cls): return IO.Schema( - node_id="GaussianToFile3D", - display_name="Create 3D File (from Gaussian)", + node_id="SplatToFile3D", + display_name="Create 3D File (from Splat)", search_aliases=["gaussian to ply", "splat to file", "export gaussian"], category="3d/gaussian", - description="Serialize a gaussian splat to an in-memory File3D, for Save 3D Model / Preview 3D. " - "ply keeps full SH (standard 3DGS); ksplat and spz are compact viewer formats (base " + description="Serialize a gaussian splat to a File3D object for Save / Preview 3D nodes. " + "ply keeps full spherical harmonics (standard 3DGS); ksplat and spz are compact viewer formats (base " "color only). Single splat only - feed one batch item at a time.", inputs=[ - IO.Gaussian.Input("gaussian"), + IO.Splat.Input("splat"), IO.Combo.Input("format", options=["ply", "ksplat", "spz"], tooltip="ply: standard 3DGS with full spherical harmonics. ksplat: mkkellogg " "SplatBuffer (level 0, uncompressed). spz: Niantic gzip-compressed " - "(~10x smaller). 
ksplat/spz keep base color only - view-dependent SH " - "is dropped."), + "(~10x smaller). ksplat/spz keep base color only - view-dependent spherical harmonics is dropped."), ], outputs=[IO.File3DAny.Output(display_name="model_3d")], ) @classmethod - def execute(cls, gaussian, format="ply") -> IO.NodeOutput: - if gaussian.positions.shape[0] > 1: - logging.warning("GaussianToFile3D: got a batch of %d; converting only the first splat (File3D is a " - "single file).", gaussian.positions.shape[0]) - end = _real_len(gaussian, 0) + def execute(cls, splat, format="ply") -> IO.NodeOutput: + if splat.positions.shape[0] > 1: + logging.warning("SplatToFile3D: got a batch of %d; converting only the first splat (File3D is a " + "single file).", splat.positions.shape[0]) + end = _real_len(splat, 0) writer = {"ksplat": _gaussian_ksplat_bytes, "spz": _gaussian_spz_bytes}.get(format, _gaussian_ply_bytes) - data = writer(gaussian.positions[0, :end], gaussian.scales[0, :end], - gaussian.rotations[0, :end], gaussian.opacities[0, :end], gaussian.sh[0, :end]) + data = writer(splat.positions[0, :end], splat.scales[0, :end], + splat.rotations[0, :end], splat.opacities[0, :end], splat.sh[0, :end]) return IO.NodeOutput(Types.File3D(BytesIO(data), file_format=format)) -class File3DToGaussian(IO.ComfyNode): +class File3DToSplat(IO.ComfyNode): @classmethod def define_schema(cls): return IO.Schema( - node_id="File3DToGaussian", - display_name="Get Gaussian Splat", + node_id="File3DToSplat", + display_name="Get Splat", search_aliases=["load splat", "ply to gaussian", "import gaussian", "file to splat"], category="3d/gaussian", - description="Parse a splat File3D (.ply / .splat / .ksplat / .spz) into a GAUSSIAN. Inverse of " + description="Parse a splat File3D (.ply / .splat / .ksplat / .spz) into a gaussian. Inverse of " "Create 3D File (from Gaussian). ply carries full spherical harmonics; the others are base " "color only. 
Format is auto-detected from the file contents.", inputs=[ @@ -513,7 +513,7 @@ class File3DToGaussian(IO.ComfyNode): tooltip="A gaussian-splat 3D file", ), ], - outputs=[IO.Gaussian.Output(display_name="gaussian")], + outputs=[IO.Splat.Output(display_name="splat")], ) @classmethod @@ -524,14 +524,14 @@ class File3DToGaussian(IO.ComfyNode): xyz, scale, rot, opacity, sh = parser(data) t = lambda a: torch.from_numpy(np.ascontiguousarray(a)).float() - gaussian = Types.GAUSSIAN( + splat = Types.SPLAT( t(xyz)[None], # (1, N, 3) t(scale)[None], # (1, N, 3) linear t(rot)[None], # (1, N, 4) wxyz t(opacity).reshape(1, -1, 1), # (1, N, 1) t(sh)[None], # (1, N, K, 3) ) - return IO.NodeOutput(gaussian) + return IO.NodeOutput(splat) def _view_matrix_t(yaw_deg, pitch_deg, device): @@ -572,7 +572,7 @@ def _gauss_blur(x, sigma, dev): def _render_gaussian(xyz, rgb, opacity, scale, rot, width, height, fov, splat_scale, bg, sharpen=1.0, headlight_shading=0.0, render_style="color", camera_info=None, - yaw=35.0, pitch=30.0, zoom=1.0): + yaw=35.0, pitch=30.0, zoom=1.0, distance=0.0): # Perspective-correct anisotropic gaussian-splat rasterizer. Each splat is weighted by its 3D Gaussian's # peak along each pixel's ray (AAA / Hahlbohm), composited front-to-back across depth slabs. `render_style` # selects the image: color / clay / depth / normal. Returns (image HxWx3, coverage mask HxW) on CPU. 
@@ -609,7 +609,8 @@ def _render_gaussian(xyz, rgb, opacity, scale, rot, width, height, fov, splat_sc fov = fov if fov > 0 else 35.0 # fov=0 -> default 35 center = xyz.mean(0) extent = (xyz - center).norm(dim=-1).quantile(0.99).clamp_min(1e-4) # ignore outlier floaters - dist = extent / (math.tan(math.radians(fov) / 2) * 0.9) / max(zoom, 1e-3) + base = distance if distance > 0 else extent / (math.tan(math.radians(fov) / 2) * 0.9) # absolute dist, else auto-frame + dist = base / max(zoom, 1e-3) W = _view_matrix_t(yaw, pitch, dev) cam = (xyz - center) @ W.T + torch.tensor([0.0, 0.0, dist], device=dev) yflip = 1.0 @@ -775,19 +776,20 @@ def _render_gaussian(xyz, rgb, opacity, scale, rot, width, height, fov, splat_sc return img.clamp(0, 1).cpu(), covg.clamp(0, 1).cpu() -class RenderGaussian(IO.ComfyNode): +class RenderSplat(IO.ComfyNode): @classmethod def define_schema(cls): return IO.Schema( - node_id="RenderGaussian", - display_name="Render Gaussian Splat", + node_id="RenderSplat", + display_name="Render Splat", search_aliases=["splat to image", "render splat", "gaussian turntable"], category="3d/gaussian", description="Render a gaussian splat to an image with an anisotropic EWA rasterizer (oriented " - "elliptical splats, antialiased, depth-sorted front-to-back). frames>1 sweeps yaw a full " - "360 turn, producing an image batch (turntable) you can pipe into a video node.", + "elliptical splats, antialiased, depth-sorted front-to-back). 
Set frames greater than 1 " + "to sweep the camera yaw through a full 360° rotation, producing a batch of images " + "(a turntable) that you can feed into a video node.", inputs=[ - IO.Gaussian.Input("gaussian"), + IO.Splat.Input("splat"), IO.Int.Input("width", default=1024, min=64, max=2048, step=8), IO.Int.Input("height", default=1024, min=64, max=2048, step=8), IO.Int.Input("frames", default=1, min=-240, max=240, @@ -796,11 +798,13 @@ class RenderGaussian(IO.ComfyNode): IO.Float.Input("yaw", default=35.0, min=-360.0, max=360.0, step=1.0), IO.Float.Input("pitch", default=30.0, min=-89.0, max=89.0, step=1.0), IO.Float.Input("zoom", default=1.0, min=0.1, max=5.0, step=0.05, - tooltip="Camera dolly: >1 zooms in, <1 out. Without camera_info, 1.0 frames the whole " - "splat (~10% margin); with camera_info, 1.0 is exactly the supplied camera."), + tooltip="Camera dolly: >1 zooms in, <1 out. With camera_info or distance, 1.0 is exactly " + "that camera; otherwise 1.0 frames the whole splat (~10% margin)."), + IO.Float.Input("distance", default=0.0, min=0.0, max=1000.0, step=0.01, + tooltip="Absolute camera distance for the yaw/pitch orbit (from Get Camera Info). " + "0 = auto-frame the whole splat. Ignored when camera_info is connected."), IO.Float.Input("fov", default=0.0, min=0.0, max=120.0, step=1.0, - tooltip="Vertical field of view in degrees. 0 = auto: 35, or taken from camera_info " - "when connected. Any value >0 overrides (including over camera_info)."), + tooltip="Vertical field of view in degrees. 0 = camera_info if provided, otherwise defaults to 35. 
Any value above 0 overrides the camera_info FoV."), IO.Float.Input("splat_scale", default=1.0, min=0.1, max=5.0, step=0.05, advanced=True, tooltip="Multiplier on each splat's projected footprint (lower = crisper points, " "higher = softer/fuller surface)."), @@ -832,8 +836,8 @@ class RenderGaussian(IO.ComfyNode): ) @classmethod - def execute(cls, gaussian, width, height, yaw, pitch, frames, zoom, fov, splat_scale, sharpen, - headlight_shading, opacity_threshold, background, render_style, + def execute(cls, splat, width, height, yaw, pitch, frames, zoom, fov, splat_scale, sharpen, + headlight_shading, opacity_threshold, background, render_style, distance=0.0, camera_info=None, bg_image=None) -> IO.NodeOutput: bg = _hex_to_rgb(background) bg_imgs = None @@ -844,17 +848,17 @@ class RenderGaussian(IO.ComfyNode): orbit_dir = -1.0 if frames < 0 else 1.0 # sign = orbit direction if camera_info is not None: if n_frames > 1: - logging.warning("RenderGaussian: camera_info is a fixed camera; ignoring frames=%d (no orbit sweep).", frames) + logging.warning("RenderSplat: camera_info is a fixed camera; ignoring frames=%d (no orbit sweep).", frames) n_frames = 1 if str(camera_info.get("cameraType", "")).lower().startswith("ortho"): - logging.warning("RenderGaussian: orthographic camera_info is rendered with a perspective camera.") + logging.warning("RenderSplat: orthographic camera_info is rendered with a perspective camera.") imgs, masks = [], [] device = comfy.model_management.get_torch_device() # render device; splat stays in torch here -> no roundtrip - total = gaussian.positions.shape[0] * n_frames + total = splat.positions.shape[0] * n_frames pbar = comfy.utils.ProgressBar(total) if total > 1 else None k = 0 - for i in range(gaussian.positions.shape[0]): - xyz, rgb, opacity, scale, rot = _gaussian_item(gaussian, i, device) + for i in range(splat.positions.shape[0]): + xyz, rgb, opacity, scale, rot = _gaussian_item(splat, i, device) if opacity_threshold > 0: keep = opacity 
>= opacity_threshold xyz, rgb, opacity, scale, rot = xyz[keep], rgb[keep], opacity[keep], scale[keep], rot[keep] @@ -863,7 +867,8 @@ class RenderGaussian(IO.ComfyNode): bg_k = bg_imgs[k % bg_imgs.shape[0]] if bg_imgs is not None else bg # per-frame plate, or solid colour img, mask = _render_gaussian(xyz, rgb, opacity, scale, rot, width, height, fov, splat_scale, bg_k, sharpen=sharpen, headlight_shading=headlight_shading, - render_style=render_style, camera_info=camera_info, yaw=y, pitch=pitch, zoom=zoom) + render_style=render_style, camera_info=camera_info, yaw=y, pitch=pitch, + zoom=zoom, distance=distance) imgs.append(img) masks.append(mask) k += 1 @@ -872,19 +877,18 @@ class RenderGaussian(IO.ComfyNode): return IO.NodeOutput(torch.stack(imgs), torch.stack(masks)) -class TransformGaussian(IO.ComfyNode): +class TransformSplat(IO.ComfyNode): @classmethod def define_schema(cls): return IO.Schema( - node_id="TransformGaussian", - display_name="Transform Gaussian Splat", + node_id="TransformSplat", + display_name="Transform Splat", search_aliases=["move splat", "rotate splat", "scale splat", "gaussian transform"], category="3d/gaussian", - description="Translate, rotate (Euler XYZ degrees) and scale (per-axis) a gaussian splat. Positions, " - "per-splat rotations and scales transform consistently; non-uniform scale re-derives each " - "splat's covariance (eigendecomposition) so the ellipsoids deform correctly.", + description="Translate, rotate, and scale a gaussian splat. "
+ "Non-uniform scale also reshapes every individual splat, slower process.", inputs=[ - IO.Gaussian.Input("gaussian"), + IO.Splat.Input("splat"), IO.Float.Input("translate_x", default=0.0, min=-100.0, max=100.0, step=0.01), IO.Float.Input("translate_y", default=0.0, min=-100.0, max=100.0, step=0.01), IO.Float.Input("translate_z", default=0.0, min=-100.0, max=100.0, step=0.01), @@ -895,13 +899,13 @@ class TransformGaussian(IO.ComfyNode): IO.Float.Input("scale_y", default=1.0, min=0.01, max=100.0, step=0.01), IO.Float.Input("scale_z", default=1.0, min=0.01, max=100.0, step=0.01), ], - outputs=[IO.Gaussian.Output(display_name="gaussian")], + outputs=[IO.Splat.Output(display_name="splat")], ) @classmethod - def execute(cls, gaussian, translate_x, translate_y, translate_z, + def execute(cls, splat, translate_x, translate_y, translate_z, rotate_x, rotate_y, rotate_z, scale_x, scale_y, scale_z) -> IO.NodeOutput: - pos = gaussian.positions + pos = splat.positions dev, dt = pos.device, pos.dtype q_rot = _euler_to_quat(rotate_x, rotate_y, rotate_z).to(device=dev, dtype=dt) R = _quat_to_mat(q_rot[None])[0] # (3, 3) node rotation @@ -911,51 +915,45 @@ class TransformGaussian(IO.ComfyNode): positions = pos @ A.T + t # rotate, scale per-axis, then translate if scale_x == scale_y == scale_z: # uniform: rotation/scale factor out cleanly - scales = gaussian.scales * scale_x - rotations = _quat_mul(q_rot.expand_as(gaussian.rotations), gaussian.rotations) + scales = splat.scales * scale_x + rotations = _quat_mul(q_rot.expand_as(splat.rotations), splat.rotations) rotations = rotations / rotations.norm(dim=-1, keepdim=True).clamp_min(1e-12) else: # non-uniform: transform Sigma = A R s^2 R^T A^T, re-extract - rg = _quat_to_mat(gaussian.rotations.reshape(-1, 4)) # (M,3,3) per-splat rotation - s2 = gaussian.scales.reshape(-1, 3).square() - cov = (rg * s2[:, None, :]) @ rg.transpose(-1, -2) # Sigma - cov = A @ cov @ A.T # A Sigma A^T (A broadcast over splats) + rg = 
_quat_to_mat(splat.rotations.reshape(-1, 4)) # (M,3,3) per-splat rotation + s2 = splat.scales.reshape(-1, 3).square() + cov = (rg * s2[:, None, :]) @ rg.transpose(-1, -2) # Sigma + cov = A @ cov @ A.T # A Sigma A^T (A broadcast over splats) lam, V = torch.linalg.eigh(cov) # symmetric -> eigenvalues (asc), orthonormal axes V = V * torch.where(torch.linalg.det(V) < 0, -1.0, 1.0)[..., None, None] # keep a proper rotation - scales = lam.clamp_min(0).sqrt().reshape(gaussian.scales.shape) - rotations = _mat_to_quat(V).reshape(gaussian.rotations.shape) - out = Types.GAUSSIAN(positions, scales, rotations, gaussian.opacities, gaussian.sh, - counts=getattr(gaussian, "counts", None)) + scales = lam.clamp_min(0).sqrt().reshape(splat.scales.shape) + rotations = _mat_to_quat(V).reshape(splat.rotations.shape) + out = Types.SPLAT(positions, scales, rotations, splat.opacities, splat.sh, + counts=getattr(splat, "counts", None)) return IO.NodeOutput(out) -class GaussianInfo(IO.ComfyNode): +class GetSplatCount(IO.ComfyNode): @classmethod def define_schema(cls): return IO.Schema( - node_id="GaussianInfo", - display_name="Gaussian Splat Info", - search_aliases=["splat stats", "gaussian count", "splat info"], + node_id="GetSplatCount", + display_name="Get Splat Count", + search_aliases=["splat count", "gaussian count", "number of splats", "splat info"], category="3d/gaussian", - description="Report per-splat stats: count, bounding box, and opacity/scale ranges.", - inputs=[IO.Gaussian.Input("gaussian")], - outputs=[IO.String.Output(display_name="info")], + description="Returns the number of splats (summed across the batch) and shows it on the node.", + inputs=[IO.Splat.Input("splat")], + outputs=[IO.Splat.Output(display_name="splat"), + IO.Int.Output(display_name="count"), + ], + hidden=[IO.Hidden.unique_id], ) @classmethod - def execute(cls, gaussian) -> IO.NodeOutput: - lines = [] - for i in range(gaussian.positions.shape[0]): - xyz, _, opacity, scale, _ = _gaussian_item(gaussian, i, 
torch.device("cpu")) - lo, hi = xyz.amin(0), xyz.amax(0) - fmt = lambda v: "[" + ", ".join(f"{x:.3f}" for x in v) + "]" - lines.append( - f"gaussian[{i}]: count={xyz.shape[0]}\n" - f" aabb min={fmt(lo)} max={fmt(hi)} size={fmt(hi - lo)}\n" - f" opacity mean={opacity.mean():.3f} min={opacity.min():.3f} max={opacity.max():.3f}\n" - f" scale mean={scale.mean():.4f} min={scale.min():.4f} max={scale.max():.4f}" - ) - info = "\n".join(lines) - return IO.NodeOutput(info, ui={"text": [info]}) + def execute(cls, splat) -> IO.NodeOutput: + count = sum(_real_len(splat, i) for i in range(splat.positions.shape[0])) + if cls.hidden.unique_id: # show the count inline on the node + PromptServer.instance.send_progress_text(f"{count:,} splats", cls.hidden.unique_id) + return IO.NodeOutput(splat, count) def _pad_stack(items, n): @@ -967,15 +965,15 @@ def _pad_stack(items, n): return out -def _merge_gaussians(gaussians: list) -> Types.GAUSSIAN: - # Concatenate GAUSSIAN batches along the splat dimension (per item), padding SH to the highest degree. +def _merge_gaussians(gaussians: list) -> Types.SPLAT: + # Concatenate SPLAT batches along the splat dimension (per item), padding SH to the highest degree. 
gs = [g for g in gaussians if g is not None] if not gs: - raise ValueError("MergeGaussian: no gaussians to merge") + raise ValueError("MergeSplat: no gaussians to merge") b = gs[0].positions.shape[0] for g in gs: if g.positions.shape[0] != b: - raise ValueError(f"MergeGaussian: batch size mismatch ({b} vs {g.positions.shape[0]}).") + raise ValueError(f"MergeSplat: batch size mismatch ({b} vs {g.positions.shape[0]}).") max_k = max(g.sh.shape[2] for g in gs) pos_b, scl_b, rot_b, op_b, sh_b, lengths = [], [], [], [], [], [] @@ -1002,32 +1000,31 @@ def _merge_gaussians(gaussians: list) -> Types.GAUSSIAN: counts = None if len(set(lengths)) > 1: counts = torch.tensor(lengths, device=gs[0].positions.device, dtype=torch.int64) - return Types.GAUSSIAN(_pad_stack(pos_b, n), _pad_stack(scl_b, n), _pad_stack(rot_b, n), + return Types.SPLAT(_pad_stack(pos_b, n), _pad_stack(scl_b, n), _pad_stack(rot_b, n), _pad_stack(op_b, n), _pad_stack(sh_b, n), counts=counts) -class MergeGaussian(IO.ComfyNode): +class MergeSplat(IO.ComfyNode): @classmethod def define_schema(cls): - # Autogrow: a gaussian0/gaussian1/... input list that grows a fresh slot as you connect splats. - gaussians = IO.Autogrow.TemplatePrefix(IO.Gaussian.Input("gaussian"), prefix="gaussian", min=2, max=32) + # Autogrow: a splat0/splat1/... input list that grows a fresh slot as you connect splats. + splats = IO.Autogrow.TemplatePrefix(IO.Splat.Input("splat"), prefix="splat", min=2, max=32) return IO.Schema( - node_id="MergeGaussian", - display_name="Merge Gaussian Splats", + node_id="MergeSplat", + display_name="Merge Splats", search_aliases=["union splat", "densify gaussian", "combine splat", "merge gaussian"], category="3d/gaussian", - description="Concatenate any number of gaussian splats into one (per batch item). 
Because the " - "TripoSplat decoder samples points stochastically, unioning several decodes of the same " - "latent at different seeds densifies the surface - feed them here, then mesh the result.", - inputs=[IO.Autogrow.Input("gaussians", template=gaussians)], - outputs=[IO.Gaussian.Output(display_name="gaussian")], + description="Concatenate any number of gaussian splats into one. Unioning several decodes of the same " + "latent at different seeds densifies the surface, this can improve surface quality when meshing.", + inputs=[IO.Autogrow.Input("splats", template=splats)], + outputs=[IO.Splat.Output(display_name="splat")], ) @classmethod - def execute(cls, gaussians: IO.Autogrow.Type) -> IO.NodeOutput: - gs = [v for v in gaussians.values() if v is not None] + def execute(cls, splats: IO.Autogrow.Type) -> IO.NodeOutput: + gs = [v for v in splats.values() if v is not None] if not gs: - raise ValueError("MergeGaussian: connect at least one gaussian splat.") + raise ValueError("MergeSplat: connect at least one splat.") return IO.NodeOutput(_merge_gaussians(gs)) @@ -1232,7 +1229,7 @@ def _taubin_smooth(verts, faces, iters, lam=0.5, mu=-0.53): return np.ascontiguousarray(v.astype(np.float32)) -def _gaussian_to_mesh(g: Types.GAUSSIAN, i, res, kernel, taubin, level_bias, min_component, min_opacity, color_sharpen, device, progress=None): +def _gaussian_to_mesh(g: Types.SPLAT, i, res, kernel, taubin, level_bias, min_component, min_opacity, color_sharpen, device, progress=None): # Mesh one splat: density + colour grids -> Surface Nets -> floater removal -> Taubin smoothing -> # volume-sampled colours. Returns (verts, faces int64, colors in [0,1]), or None if no surface. 
rep = progress if progress is not None else (lambda *_: None) @@ -1297,12 +1294,12 @@ def _gaussian_to_mesh(g: Types.GAUSSIAN, i, res, kernel, taubin, level_bias, min return (torch.from_numpy(verts), torch.from_numpy(faces), torch.from_numpy(col)) -class GaussianToMesh(IO.ComfyNode): +class SplatToMesh(IO.ComfyNode): @classmethod def define_schema(cls): return IO.Schema( - node_id="GaussianToMesh", - display_name="Gaussian Splat to Mesh", + node_id="SplatToMesh", + display_name="Extract Mesh from Splat", search_aliases=["splat to mesh", "gaussian surface nets", "splat surface", "mesh splat"], category="3d/gaussian", description="Extract a coloured triangle MESH from a gaussian splat. Each splat is rasterized into a " @@ -1310,7 +1307,7 @@ class GaussianToMesh(IO.ComfyNode): "tiny floaters are dropped, and vertices are coloured from their nearest gaussians. Denser " "splats give more detail - union several decodes with Merge Gaussian Splats first.", inputs=[ - IO.Gaussian.Input("gaussian"), + IO.Splat.Input("splat"), IO.Int.Input("resolution", default=512, min=64, max=1024, step=16, tooltip="Density-grid resolution along the longest axis. Higher = finer surface, " "more VRAM/time (grows with resolution^3)."), @@ -1318,7 +1315,7 @@ class GaussianToMesh(IO.ComfyNode): tooltip="Max splat half-width in voxels. Each gaussian is rasterized over a window " "sized to its own 3-sigma, capped here - small surfels stay cheap, large ones " "aren't truncated. Raise if sparse splats leave gaps."), - IO.Int.Input("smooth", default=0, min=0, max=60, + IO.Int.Input("smooth", default=0, min=0, max=60, advanced = True, tooltip="Taubin mesh-smoothing iterations. Smooths the surface without shrinking it " "(volume-preserving), unlike blurring the density. 
0 = raw surface."), IO.Float.Input("level", default=0.6, min=0.3, max=2.0, step=0.05, @@ -1338,18 +1335,18 @@ class GaussianToMesh(IO.ComfyNode): ) @classmethod - def execute(cls, gaussian, resolution, kernel, smooth, level, min_component, min_opacity, color_sharpen) -> IO.NodeOutput: + def execute(cls, splat, resolution, kernel, smooth, level, min_component, min_opacity, color_sharpen) -> IO.NodeOutput: device = comfy.model_management.get_torch_device() - b = gaussian.positions.shape[0] + b = splat.positions.shape[0] prec = 1000 # each splat owns a 0..prec block of the bar; its callback advances within that block pbar = comfy.utils.ProgressBar(b * prec) verts_l, faces_l, colors_l = [], [], [] for i in range(b): cb = lambda f, base=i * prec: pbar.update_absolute(base + int(min(max(f, 0.0), 1.0) * prec)) - res = _gaussian_to_mesh(gaussian, i, resolution, kernel, smooth, level, min_component, min_opacity, color_sharpen, device, cb) + res = _gaussian_to_mesh(splat, i, resolution, kernel, smooth, level, min_component, min_opacity, color_sharpen, device, cb) if res is None: - logging.warning("GaussianToMesh: splat %d produced no surface; emitting an empty mesh.", i) + logging.warning("SplatToMesh: splat %d produced no surface; emitting an empty mesh.", i) v, f, c = torch.zeros((0, 3)), torch.zeros((0, 3), dtype=torch.int64), torch.zeros((0, 3)) else: v, f, c = res @@ -1364,8 +1361,8 @@ class GaussianToMesh(IO.ComfyNode): class GaussianExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[IO.ComfyNode]]: - return [GaussianToFile3D, File3DToGaussian, RenderGaussian, TransformGaussian, GaussianInfo, - MergeGaussian, GaussianToMesh] + return [SplatToFile3D, File3DToSplat, RenderSplat, TransformSplat, GetSplatCount, + MergeSplat, SplatToMesh] async def comfy_entrypoint() -> GaussianExtension: diff --git a/nodes.py b/nodes.py index 7464d4465..5678bc22d 100644 --- a/nodes.py +++ b/nodes.py @@ -2455,7 +2455,7 @@ async def init_builtin_extra_nodes(): 
"nodes_save_3d.py", "nodes_moge.py", "nodes_mediapipe.py", - "nodes_gaussian.py", + "nodes_gaussian_splat.py", ] import_failed = []