diff --git a/comfy_api/latest/__init__.py b/comfy_api/latest/__init__.py index e0a585b10..294ad425e 100644 --- a/comfy_api/latest/__init__.py +++ b/comfy_api/latest/__init__.py @@ -5,7 +5,7 @@ from comfy_api.internal.singleton import ProxiedSingleton from comfy_api.internal.async_to_sync import create_sync_class from ._input import ImageInput, AudioInput, MaskInput, LatentInput, VideoInput from ._input_impl import VideoFromFile, VideoFromComponents -from ._util import VideoCodec, VideoContainer, VideoComponents, MESH, VOXEL, File3D +from ._util import VideoCodec, VideoContainer, VideoComponents, MESH, VOXEL, SPLAT, File3D from . import _io_public as io from . import _ui_public as ui from comfy_execution.utils import get_executing_context @@ -143,6 +143,7 @@ class Types: VideoComponents = VideoComponents MESH = MESH VOXEL = VOXEL + SPLAT = SPLAT File3D = File3D diff --git a/comfy_api/latest/_io.py b/comfy_api/latest/_io.py index e03bafcde..a3aa508ce 100644 --- a/comfy_api/latest/_io.py +++ b/comfy_api/latest/_io.py @@ -28,7 +28,7 @@ if TYPE_CHECKING: from comfy_api.internal import (_ComfyNodeInternal, _NodeOutputInternal, classproperty, copy_class, first_real_override, is_class, prune_dict, shallow_clone_class) from comfy_execution.graph_utils import ExecutionBlocker -from ._util import MESH, VOXEL, SVG as _SVG, File3D +from ._util import MESH, VOXEL, SPLAT, SVG as _SVG, File3D class FolderType(str, Enum): @@ -684,6 +684,10 @@ class Voxel(ComfyTypeIO): class Mesh(ComfyTypeIO): Type = MESH +@comfytype(io_type="SPLAT") +class Splat(ComfyTypeIO): + Type = SPLAT + @comfytype(io_type="FILE_3D") class File3DAny(ComfyTypeIO): @@ -2320,6 +2324,7 @@ __all__ = [ "LossMap", "Voxel", "Mesh", + "Splat", "File3DAny", "File3DGLB", "File3DGLTF", diff --git a/comfy_api/latest/_util/__init__.py b/comfy_api/latest/_util/__init__.py index 115baf392..b27f5a97e 100644 --- a/comfy_api/latest/_util/__init__.py +++ b/comfy_api/latest/_util/__init__.py @@ -1,5 +1,5 @@ from .video_types import VideoContainer, VideoCodec, VideoComponents -from .geometry_types import VOXEL, MESH, File3D +from .geometry_types import VOXEL, MESH, SPLAT, File3D from .image_types import SVG __all__ = [ @@ -9,6 +9,7 @@ __all__ = [ "VideoComponents", "VOXEL", "MESH", + "SPLAT", "File3D", "SVG", ] diff --git a/comfy_api/latest/_util/geometry_types.py b/comfy_api/latest/_util/geometry_types.py index cdde60b10..84a18d69a 100644 --- a/comfy_api/latest/_util/geometry_types.py +++ b/comfy_api/latest/_util/geometry_types.py @@ -11,13 +11,32 @@ class VOXEL: self.data = data +class SPLAT: + """A batch of 3D Gaussian splats in render-ready (activated, world-space) form. + + Tensors are (B, N, ...) and zero-padded to a common N across the batch; `counts` (B,) holds the + real per-item lengths (None when rows are uniform and no slicing is needed). SH coefficients are + stored as (B, N, K, 3) with K = (sh_degree + 1)**2; the DC (diffuse) term is sh[..., 0, :]. + """ + + def __init__(self, positions: torch.Tensor, scales: torch.Tensor, rotations: torch.Tensor, + opacities: torch.Tensor, sh: torch.Tensor, counts: torch.Tensor | None = None): + self.positions = positions # (B, N, 3) world-space centers + self.scales = scales # (B, N, 3) linear (positive) per-axis std + self.rotations = rotations # (B, N, 4) quaternion wxyz (normalized) + self.opacities = opacities # (B, N, 1) in [0, 1] + self.sh = sh # (B, N, K, 3) spherical-harmonic color coefficients + self.counts = counts # (B,) real lengths, or None + + class MESH: def __init__(self, vertices: torch.Tensor, faces: torch.Tensor, uvs: torch.Tensor | None = None, vertex_colors: torch.Tensor | None = None, texture: torch.Tensor | None = None, vertex_counts: torch.Tensor | None = None, - face_counts: torch.Tensor | None = None): + face_counts: torch.Tensor | None = None, + unlit: bool = False): assert (vertex_counts is None) == (face_counts is None), \ "vertex_counts and face_counts must be provided together (both or neither)" @@ -30,6 +49,8 @@ class MESH: # these hold the real per-item lengths (B,). None means rows are uniform and no slicing is needed. self.vertex_counts = vertex_counts self.face_counts = face_counts + # Render flat / emissive (no scene lighting) when saved, e.g. for gaussian-splat-derived meshes. + self.unlit = unlit class File3D: diff --git a/comfy_extras/nodes_gaussian_splat.py b/comfy_extras/nodes_gaussian_splat.py new file mode 100644 index 000000000..7fb878b8b --- /dev/null +++ b/comfy_extras/nodes_gaussian_splat.py @@ -0,0 +1,1663 @@ +# Generic utility nodes for the SPLAT type (3D gaussian splats) + +import gzip +import logging +import math +import struct +from io import BytesIO + +import numpy as np +import torch +from typing_extensions import override +from scipy.ndimage import map_coordinates, minimum as _ndi_minimum, maximum as _ndi_maximum +from scipy.sparse import coo_matrix +from scipy.sparse.csgraph import connected_components + +import comfy.model_management +import comfy.utils +from comfy_api.latest import ComfyExtension, IO, Types +from comfy_extras.nodes_save_3d import pack_variable_mesh_batch +from server import PromptServer + +_C0 = 0.28209479177387814 # SH band-0 constant: DC coefficient -> base RGB + + +def _srgb_to_linear(c): + return torch.where(c <= 0.04045, c / 12.92, ((c.clamp_min(0) + 0.055) / 1.055) ** 2.4) + + +def _linear_to_srgb(c): + return torch.where(c <= 0.0031308, c * 12.92, 1.055 * c.clamp_min(0) ** (1 / 2.4) - 0.055) + + +def _real_len(g: Types.SPLAT, i: int) -> int: + # Real splat count of batch item i (honors variable-length `counts`). + return int(g.counts[i].item()) if g.counts is not None else g.positions.shape[1] + + +def _hex_to_rgb(h: str) -> tuple[float, float, float]: + # "#RRGGBB" -> (r,g,b) in [0,1]; falls back to black. + h = h.lstrip("#") + if len(h) != 6: + return (0.0, 0.0, 0.0) + return tuple(int(h[i:i + 2], 16) / 255.0 for i in (0, 2, 4)) + + +def _quantile(x, q): + # torch.quantile errors above 2**24 elements; stride-subsample large inputs for the estimate. + lim = 1 << 24 + if x.numel() > lim: + x = x[:: x.numel() // lim + 1] + return torch.quantile(x, q) + + +def _gaussian_ply_bytes(positions, scales, rotations, opacities, sh) -> bytes: + """Serialize render-ready gaussian tensors as a binary 3DGS .ply. + + positions (N,3) world; scales (N,3) linear; rotations (N,4) quat wxyz; opacities (N,1) in [0,1]; + sh (N,K,3) SH coefficients. Activated values are inverted to the standard 3D gaussian splat storage convention + (log scale, logit opacity). + """ + xyz = positions.cpu().numpy().astype(np.float32) + n = xyz.shape[0] + if n == 0: + raise ValueError("SplatToFile3D: gaussian is empty") + normals = np.zeros_like(xyz) + f = sh.cpu().numpy().astype(np.float32) # (N, K, 3) + f_dc = f[:, 0, :] # (N, 3) + f_rest = f[:, 1:, :].transpose(0, 2, 1).reshape(n, -1) # (N, 3*(K-1)) channel-major + op = opacities.cpu().numpy().astype(np.float32).reshape(n, 1).clip(1e-6, 1 - 1e-6) + op = np.log(op / (1.0 - op)) # inverse sigmoid (logit) + scale = np.log(scales.cpu().numpy().astype(np.float32).clip(min=1e-8)) + rot = rotations.cpu().numpy().astype(np.float32) # (N, 4) + + attrs = (['x', 'y', 'z', 'nx', 'ny', 'nz'] + + [f'f_dc_{i}' for i in range(3)] + + [f'f_rest_{i}' for i in range(f_rest.shape[1])] + + ['opacity'] + [f'scale_{i}' for i in range(3)] + [f'rot_{i}' for i in range(4)]) + elements = np.empty(n, dtype=[(a, 'f4') for a in attrs]) + elements[:] = list(map(tuple, np.concatenate([xyz, normals, f_dc, f_rest, op, scale, rot], axis=1))) + + header = "ply\nformat binary_little_endian 1.0\n" + f"element vertex {n}\n" + header += "".join(f"property float {a}\n" for a in attrs) + "end_header\n" + return header.encode('ascii') + elements.tobytes() + + +# .ksplat (mkkellogg SplatBuffer) level 0, SH degree 0: 4096-byte header, one 1024-byte section header, +# then N 44-byte records. Bucketing/quantization only exist at levels >= 1. See SplatBuffer.js. +_KSPLAT_HEADER_BYTES = 4096 +_KSPLAT_SECTION_HEADER_BYTES = 1024 +_KSPLAT_BYTES_PER_SPLAT = 44 # center 12 + scale 12 + rotation 16 + color(RGBA u8) 4 +_KSPLAT_VERSION = (0, 1) # SplatBuffer CurrentMajor/MinorVersion + + +def _gaussian_ksplat_bytes(positions, scales, rotations, opacities, sh) -> bytes: + """Serialize gaussian tensors as a level-0, SH degree-0 .ksplat (linear scale, opacity in color alpha). + + positions (N,3) world; scales (N,3) linear; rotations (N,4) wxyz; opacities (N,1) in [0,1]; sh (N,K,3). + """ + xyz = positions.cpu().numpy().astype(np.float32) + n = xyz.shape[0] + if n == 0: + raise ValueError("SplatToFile3D: gaussian is empty") + scale = scales.cpu().numpy().astype(np.float32) + rot = rotations.cpu().numpy().astype(np.float32) # wxyz, mirrors the .ply rot order + rot = rot / np.linalg.norm(rot, axis=1, keepdims=True).clip(1e-12) + rgb = np.clip(sh[:, 0, :].cpu().numpy().astype(np.float32) * _C0 + 0.5, 0, 1) + op = opacities.cpu().numpy().astype(np.float32).reshape(n, 1).clip(0, 1) + rgba = np.round(np.concatenate([rgb, op], axis=1) * 255.0).astype(np.uint8) # (N, 4) RGBA + + # 44-byte record: float center(3) + scale(3) + rot(4), then uint8 rgba(4). + floats = np.concatenate([xyz, scale, rot], axis=1).astype(' bytes: + """Serialize gaussian tensors as a gzip-compressed .spz (Niantic v2, SH degree 0, base color only). + + positions (N,3) world; scales (N,3) linear; rotations (N,4) wxyz; opacities (N,1) in [0,1]; sh (N,K,3). + """ + xyz = positions.cpu().numpy().astype(np.float32) + n = xyz.shape[0] + if n == 0: + raise ValueError("SplatToFile3D: gaussian is empty") + + # Positions: fixed point, masked to 24 bits, little-endian 3-byte words. + fixed = 1 << _SPZ_FRACTIONAL_BITS + qi = np.clip(np.round(xyz * fixed), -(1 << 23), (1 << 23) - 1).astype(np.int32) + qu = (qi & 0xFFFFFF).astype(np.uint32) + pos = np.stack([qu & 0xFF, (qu >> 8) & 0xFF, (qu >> 16) & 0xFF], axis=-1).reshape(n, 9).astype(np.uint8) + + alpha = np.round(opacities.cpu().numpy().astype(np.float32).reshape(n) * 255.0).clip(0, 255).astype(np.uint8) + + rgb = sh[:, 0, :].cpu().numpy().astype(np.float32) * _C0 + 0.5 + col = np.round(((rgb - 0.5) / _SPZ_COLOR_SCALE + 0.5) * 255.0).clip(0, 255).astype(np.uint8) # (N,3) + + sln = np.log(scales.cpu().numpy().astype(np.float32).clip(min=1e-9)) + scb = np.round((sln + 10.0) * 16.0).clip(0, 255).astype(np.uint8) # (N,3) inverts exp(b/16-10) + + rot = rotations.cpu().numpy().astype(np.float32) # wxyz + rot = rot / np.linalg.norm(rot, axis=1, keepdims=True).clip(1e-12) + rot[rot[:, 0] < 0] *= -1.0 # canonical w >= 0 (w dropped on decode) + rotb = np.round((rot[:, 1:4] + 1.0) * 127.5).clip(0, 255).astype(np.uint8) # (N,3) x,y,z + + header = bytearray(16) + struct.pack_into(' (positions, scales linear, rotations wxyz, opacities [0,1], sh (N,K,3)) ---- +# Inverse of the writers above and of spark's loaders. ksplat/splat/spz carry base color only (SH degree 0 +# -> K=1); .ply round-trips full SH. None of the formats flip axes, so import is the identity of export. +_PLY_DTYPES = {'char': 'i1', 'uchar': 'u1', 'short': 'i2', 'ushort': 'u2', 'int': 'i4', 'uint': 'u4', + 'float': 'f4', 'double': 'f8', 'int8': 'i1', 'uint8': 'u1', 'int16': 'i2', 'uint16': 'u2', + 'int32': 'i4', 'uint32': 'u4', 'float32': 'f4', 'float64': 'f8'} +_KSPLAT_COMPRESSION = { # level -> (bytesPerCenter, scale, rotation, color, shComponent, defaultScaleRange) + 0: (12, 12, 16, 4, 4, 1), 1: (6, 6, 8, 4, 2, 32767), 2: (6, 6, 8, 4, 1, 32767)} +_KSPLAT_SH_COMPONENTS = {0: 0, 1: 9, 2: 24, 3: 45} + + +def _rgb_to_sh_dc(rgb): + return ((np.asarray(rgb, np.float32) - 0.5) / _C0)[:, None, :] # (N,3) base color -> (N,1,3) SH DC + + +def _norm_quat(q): + return q / np.linalg.norm(q, axis=1, keepdims=True).clip(1e-12) + + +def _parse_ply_gaussian(data: bytes): + end = data.find(b'end_header') + if end < 0: + raise ValueError("File3DToSplat: not a PLY (missing end_header)") + header = data[:end].decode('ascii', 'replace') + body = end + len(b'end_header') + body += 2 if data[body:body + 2] == b'\r\n' else 1 + count, props, in_vertex = 0, [], False + for line in header.splitlines(): + p = line.split() + if not p: + continue + if p[0] == 'format' and p[1] != 'binary_little_endian': + raise ValueError(f"File3DToSplat: unsupported PLY format '{p[1]}' (need binary_little_endian)") + if p[0] == 'element': + in_vertex = p[1] == 'vertex' + if in_vertex: + count = int(p[2]) + elif p[0] == 'property' and in_vertex: + if p[1] == 'list': + raise ValueError("File3DToSplat: PLY vertex has list properties (unsupported)") + props.append((p[2], '<' + _PLY_DTYPES[p[1]])) + arr = np.frombuffer(data, np.dtype(props), count=count, offset=body) + names = arr.dtype.names + c = lambda k: arr[k].astype(np.float32) + n = count + + xyz = np.stack([c('x'), c('y'), c('z')], 1) + if 'scale_0' in names: + scale = np.exp(np.stack([c('scale_0'), c('scale_1'), c('scale_2')], 1)) # 3DGS stores log scale + else: + scale = np.full((n, 3), 0.01, np.float32) + if 'rot_0' in names: + rot = _norm_quat(np.stack([c('rot_0'), c('rot_1'), c('rot_2'), c('rot_3')], 1)) # wxyz + else: + rot = np.tile(np.array([1, 0, 0, 0], np.float32), (n, 1)) + opacity = 1.0 / (1.0 + np.exp(-c('opacity'))) if 'opacity' in names else np.ones(n, np.float32) + + if 'f_dc_0' in names: + dc = np.stack([c('f_dc_0'), c('f_dc_1'), c('f_dc_2')], 1) # (N,3) + rest = sorted((k for k in names if k.startswith('f_rest_')), key=lambda s: int(s.split('_')[-1])) + if rest: + r = np.stack([c(k) for k in rest], 1) # (N, 3*(K-1)) channel-major + kk = r.shape[1] // 3 + 1 + r = r.reshape(n, 3, kk - 1).transpose(0, 2, 1) # -> (N, K-1, 3) + sh = np.concatenate([dc[:, None, :], r], 1) + else: + sh = dc[:, None, :] + elif 'red' in names: + sh = _rgb_to_sh_dc(np.stack([c('red'), c('green'), c('blue')], 1) / 255.0) + else: + sh = np.zeros((n, 1, 3), np.float32) + return xyz, scale, rot, opacity, sh + + +def _parse_splat_gaussian(data: bytes): + # antimatter15 .splat: 32-byte records (f32 xyz, f32 scale, u8 rgba, u8 quat as (b-128)/128 wxyz). + if len(data) % 32 != 0: + raise ValueError("File3DToSplat: .splat size is not a multiple of 32 bytes") + rec = np.frombuffer(data, np.dtype([('xyz', ' 0: + ct, ft = (' full_splats: + lengths = np.frombuffer(data, '> 30) & 3 + q = np.zeros((n, 4), np.float32) # x,y,z,w + remaining, sumsq = combined.copy(), np.zeros(n, np.float64) + for comp in (3, 2, 1, 0): + active = comp != largest + value = (remaining & 0x1FF).astype(np.float64) + sign = (remaining >> 9) & 1 + remaining = np.where(active, remaining >> 10, remaining) + val = (1.0 / math.sqrt(2)) * (value / 0x1FF) + val = np.where(sign == 1, -val, val) + q[active, comp] = val[active] + sumsq += np.where(active, val * val, 0.0) + q[np.arange(n), largest] = np.sqrt(np.clip(1.0 - sumsq, 0, None)) + rot = _norm_quat(np.stack([q[:, 3], q[:, 0], q[:, 1], q[:, 2]], 1)) # xyzw -> wxyz + else: + qb = np.frombuffer(raw, np.uint8, count=n * 3, offset=off).reshape(n, 3).astype(np.float32) + xq = qb / 127.5 - 1.0 + w = np.sqrt(np.clip(1.0 - (xq ** 2).sum(1), 0, None)) + rot = _norm_quat(np.concatenate([w[:, None], xq], 1)) # wxyz + return xyz, scale, rot, alpha, _rgb_to_sh_dc(rgb) + + +_GAUSSIAN_PARSERS = {"ply": _parse_ply_gaussian, "splat": _parse_splat_gaussian, + "ksplat": _parse_ksplat_gaussian, "spz": _parse_spz_gaussian} + + +def _detect_splat_format(data: bytes) -> str: + if data[:3] == b'ply': + return "ply" + if data[:2] == b'\x1f\x8b': # gzip -> spz + return "spz" + if len(data) >= 2 and data[0] == 0 and data[1] >= 1: # ksplat version 0.x header + return "ksplat" + if len(data) % 32 == 0: + return "splat" + raise ValueError("File3DToSplat: could not determine splat format from contents") + + +def _gaussian_item(g: Types.SPLAT, i: int, device): + # Slice batch item i to its real length, as float32 torch tensors on `device` (SH DC -> base RGB). + end = _real_len(g, i) + to = lambda a: a.to(device=device, dtype=torch.float32) + xyz = to(g.positions[i, :end]) + rgb = (to(g.sh[i, :end, 0, :]) * _C0 + 0.5).clamp(0, 1) + opacity = to(g.opacities[i, :end]).reshape(-1) + scale = to(g.scales[i, :end]) + rot = to(g.rotations[i, :end]) + return xyz, rgb, opacity, scale, rot + + +def _quat_to_mat(q): + # q: (N, 4) wxyz, normalized -> (N, 3, 3) + q = q / q.norm(dim=-1, keepdim=True).clamp_min(1e-12) + w, x, y, z = q.unbind(-1) + return torch.stack([ + 1 - 2 * (y * y + z * z), 2 * (x * y - w * z), 2 * (x * z + w * y), + 2 * (x * y + w * z), 1 - 2 * (x * x + z * z), 2 * (y * z - w * x), + 2 * (x * z - w * y), 2 * (y * z + w * x), 1 - 2 * (x * x + y * y), + ], dim=-1).reshape(-1, 3, 3) + + +def _quat_mul(a, b): + # Hamilton product a (x) b, wxyz. + aw, ax, ay, az = a.unbind(-1) + bw, bx, by, bz = b.unbind(-1) + return torch.stack([ + aw * bw - ax * bx - ay * by - az * bz, + aw * bx + ax * bw + ay * bz - az * by, + aw * by - ax * bz + ay * bw + az * bx, + aw * bz + ax * by - ay * bx + az * bw, + ], dim=-1) + + +def _euler_to_quat(rx, ry, rz): + # Degrees, applied as Rz @ Ry @ Rx (rotate about X, then Y, then Z in world). Returns wxyz. + c, s = np.cos(np.radians([rx, ry, rz]) / 2.0), np.sin(np.radians([rx, ry, rz]) / 2.0) + qx = torch.tensor([c[0], s[0], 0.0, 0.0], dtype=torch.float32) + qy = torch.tensor([c[1], 0.0, s[1], 0.0], dtype=torch.float32) + qz = torch.tensor([c[2], 0.0, 0.0, s[2]], dtype=torch.float32) + return _quat_mul(_quat_mul(qz, qy), qx) + + +def _mat_to_quat(m): + # Rotation matrix (..., 3, 3) -> quaternion (..., 4) wxyz. Batched; builds the four candidate quaternions + # and keeps the one with the largest component (numerically stable across all rotations). + m00, m11, m22 = m[..., 0, 0], m[..., 1, 1], m[..., 2, 2] + m21, m12 = m[..., 2, 1], m[..., 1, 2] + m02, m20 = m[..., 0, 2], m[..., 2, 0] + m10, m01 = m[..., 1, 0], m[..., 0, 1] + q2 = torch.stack([1 + m00 + m11 + m22, 1 + m00 - m11 - m22, + 1 - m00 + m11 - m22, 1 - m00 - m11 + m22], -1) # 4 * (w^2, x^2, y^2, z^2) + cand = torch.stack([ + torch.stack([q2[..., 0], m21 - m12, m02 - m20, m10 - m01], -1), + torch.stack([m21 - m12, q2[..., 1], m10 + m01, m02 + m20], -1), + torch.stack([m02 - m20, m10 + m01, q2[..., 2], m12 + m21], -1), + torch.stack([m10 - m01, m02 + m20, m12 + m21, q2[..., 3]], -1), + ], -2) # (...,4,4) candidates, rows = wxyz + sel = q2.argmax(-1) + q = torch.gather(cand, -2, sel[..., None, None].expand(sel.shape + (1, 4)))[..., 0, :] + return q / q.norm(dim=-1, keepdim=True).clamp_min(1e-12) + + +class SplatToFile3D(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="SplatToFile3D", + display_name="Create 3D File (from Splat)", + search_aliases=["gaussian to ply", "splat to file", "export gaussian"], + category="3d/splat", + description="Serialize a gaussian splat to a File3D object for Save / Preview 3D nodes. " + "Supports one item per batch only.", + inputs=[ + IO.Splat.Input("splat"), + IO.Combo.Input("format", options=["ply", "ksplat", "spz"], # TODO: add "splat" when we have a writer for it + tooltip="ply: standard 3D Gaussian Splat with full spherical harmonics. " + "ksplat: mkkellogg SplatBuffer (level 0, uncompressed), base color only " + "spz: Niantic gzip-compressed (~10x smaller), base color only " + ), + ], + outputs=[IO.File3DAny.Output(display_name="model_3d")], + ) + + @classmethod + def execute(cls, splat, format="ply") -> IO.NodeOutput: + if splat.positions.shape[0] > 1: + logging.warning("SplatToFile3D supports one item per batch only. Got %d; using first.", splat.positions.shape[0]) + end = _real_len(splat, 0) + writer = {"ksplat": _gaussian_ksplat_bytes, "spz": _gaussian_spz_bytes}.get(format, _gaussian_ply_bytes) + data = writer(splat.positions[0, :end], splat.scales[0, :end], + splat.rotations[0, :end], splat.opacities[0, :end], splat.sh[0, :end]) + return IO.NodeOutput(Types.File3D(BytesIO(data), file_format=format)) + + +class File3DToSplat(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="File3DToSplat", + display_name="Get Splat", + search_aliases=["load splat", "ply to splat", "import splat", "file to splat"], + category="3d/splat", + description="Parse a splat File3D into a gaussian splat. Inverse of Create 3D File (from Splat). " + "Supported format: PLY, SPLAT, KSPLAT, SPZ. PLY carries full spherical harmonics, " + "the other formats are base color only. Format is auto-detected from the file contents.", + inputs=[ + IO.MultiType.Input( + IO.File3DAny.Input("model_3d"), + types=[IO.File3DPLY, IO.File3DSPLAT, IO.File3DKSPLAT, IO.File3DSPZ], + tooltip="A gaussian splat 3D file", + ), + ], + outputs=[IO.Splat.Output(display_name="splat")], + ) + + @classmethod + def execute(cls, model_3d: Types.File3D) -> IO.NodeOutput: + data = model_3d.get_bytes() + fmt = (model_3d.format or "").lower() + parser = _GAUSSIAN_PARSERS.get(fmt) or _GAUSSIAN_PARSERS[_detect_splat_format(data)] + xyz, scale, rot, opacity, sh = parser(data) + + t = lambda a: torch.from_numpy(np.ascontiguousarray(a)).float() + splat = Types.SPLAT( + t(xyz)[None], # (1, N, 3) + t(scale)[None], # (1, N, 3) linear + t(rot)[None], # (1, N, 4) wxyz + t(opacity).reshape(1, -1, 1), # (1, N, 1) + t(sh)[None], # (1, N, K, 3) + ) + return IO.NodeOutput(splat) + + +def _view_matrix_t(yaw_deg, pitch_deg, device): + y, p = math.radians(yaw_deg), math.radians(pitch_deg) + cy, sy, cp, sp = math.cos(y), math.sin(y), math.cos(p), math.sin(p) + Ry = torch.tensor([[cy, 0, sy], [0, 1, 0], [-sy, 0, cy]], device=device) + Rx = torch.tensor([[1, 0, 0], [0, cp, -sp], [0, sp, cp]], device=device) + return Rx @ Ry + + +def _camera_basis(camera_info, dev): + # Look-at basis in the splat frame, named by their projection rows: right = image +x, up = image +y + # (down, since yflip=1), fwd = view/depth axis (eye -> scene). Load3D is three.js (right-handed, Y-up, + # camera looks down -Z); the splat is 3DGS (Y-down, Z-forward). World -> splat is a 180 deg rotation + # about X: (x, y, z) -> (x, -y, -z) (det +1, no mirror, no axis swap). + pos, tgt = camera_info.get("position", {}), camera_info.get("target", {}) + m = lambda d: torch.tensor([float(d.get("x", 0.0)), -float(d.get("y", 0.0)), -float(d.get("z", 0.0))], device=dev) + eye, target = m(pos), m(tgt) + mv = lambda v: torch.stack([v[0], -v[1], -v[2]]) # same world->splat map, for direction vectors + n = lambda v: v / v.norm().clamp_min(1e-8) + q = camera_info.get("quaternion") + if q: # exact camera world rotation (incl. roll) + qwxyz = torch.tensor([float(q.get("w", 1.0)), float(q.get("x", 0.0)), + float(q.get("y", 0.0)), float(q.get("z", 0.0))], device=dev) + R = _quat_to_mat(qwxyz[None])[0] # columns = camera world axes; looks down local -Z + right = n(mv(R[:, 0])) # camera +X -> image right + up = n(mv(-R[:, 1])) # camera +Y is image up; image-down row is its negative + fwd = n(mv(-R[:, 2])) # camera looks down local -Z -> view direction + return eye, target, right, up, fwd + fwd = n(target - eye) # no quaternion: orbit-consistent, roll-free + yaw = math.degrees(math.atan2(-float(fwd[0]), float(fwd[2]))) + pitch = math.degrees(math.asin(max(-1.0, min(1.0, float(fwd[1]))))) + W = _view_matrix_t(yaw, pitch, dev) + return eye, target, W[0], W[1], W[2] + + +def _lookat_quat_wxyz(position, target, dev): + # three.js lookAt in world frame: camera local +Z = (eye - target), up = world +Y. Returns wxyz. + z = position - target + z = z / z.norm().clamp_min(1e-8) + up0 = torch.tensor([0.0, 1.0, 0.0], device=dev) + if z.dot(up0).abs() > 0.999: # looking straight up/down + up0 = torch.tensor([0.0, 0.0, 1.0], device=dev) + x = torch.linalg.cross(up0, z) + x = x / x.norm().clamp_min(1e-8) + y = torch.linalg.cross(z, x) + R = torch.stack([x, y, z], dim=1) # columns = camera world axes + return _mat_to_quat(R[None])[0] + + +def _lookat_camera_info(position, target, fov, dev, zoom=1.0, camera_type="perspective", roll=0.0): + # Build a camera_info from a world-space (right-handed, Y-up) eye + look-at target; up = world +Y. + pos = torch.as_tensor(position, dtype=torch.float32, device=dev) + tgt = torch.as_tensor(target, dtype=torch.float32, device=dev) + q = _lookat_quat_wxyz(pos, tgt, dev) + if roll: # roll about the view axis (camera local Z) + a = math.radians(roll) + qz = torch.tensor([math.cos(a / 2), 0.0, 0.0, math.sin(a / 2)], device=dev) + q = _quat_mul(q[None], qz[None])[0] + xyz = lambda v: {"x": float(v[0]), "y": float(v[1]), "z": float(v[2])} + return {"position": xyz(pos), "target": xyz(tgt), + "quaternion": {"x": float(q[1]), "y": float(q[2]), "z": float(q[3]), "w": float(q[0])}, + "fov": float(fov), "cameraType": str(camera_type), "zoom": float(zoom)} + + +def _quat_camera_info(position, quat_xyzw, fov, dev, zoom=1.0, camera_type="perspective"): + # camera_info from an explicit world position + camera-rotation quaternion (three.js: looks down local -Z). + pos = torch.as_tensor(position, dtype=torch.float32, device=dev) + qx, qy, qz, qw = (float(c) for c in quat_xyzw) + qwxyz = torch.tensor([qw, qx, qy, qz], dtype=torch.float32, device=dev) + qwxyz = qwxyz / qwxyz.norm().clamp_min(1e-8) + R = _quat_to_mat(qwxyz[None])[0] + tgt = pos - R[:, 2] # look one unit down local -Z + xyz = lambda v: {"x": float(v[0]), "y": float(v[1]), "z": float(v[2])} + return {"position": xyz(pos), "target": xyz(tgt), + "quaternion": {"x": float(qwxyz[1]), "y": float(qwxyz[2]), "z": float(qwxyz[3]), "w": float(qwxyz[0])}, + "fov": float(fov), "cameraType": str(camera_type), "zoom": float(zoom)} + + +def _orbit_camera_info(yaw, pitch, distance, fov, pivot_splat, dev): + # Orbit helper for RenderSplat's default camera: yaw/pitch about `pivot_splat` (splat frame) at `distance`. + # World<->splat is the (x,-y,-z) map, so _camera_basis recovers exactly _view_matrix_t(yaw, pitch). + y, p = math.radians(yaw), math.radians(pitch) + cy, sy, cp, sp = math.cos(y), math.sin(y), math.cos(p), math.sin(p) + fwd_splat = torch.tensor([-cp * sy, sp, cp * cy], device=dev) # == _view_matrix_t(yaw, pitch)[2] + m = lambda v: torch.stack([v[0], -v[1], -v[2]]) # splat<->world (its own inverse) + return _lookat_camera_info(m(pivot_splat - distance * fwd_splat), m(pivot_splat), fov, dev) + + +def _orbit_camera_info_yaw(camera_info, angle_deg, dev): + # Turntable: rigidly rotate a camera_info about world +Y around its target by angle_deg. Returns a new dict. + a = math.radians(angle_deg) + ca, sa = math.cos(a), math.sin(a) + v = lambda d: torch.tensor([float(d.get("x", 0.0)), float(d.get("y", 0.0)), float(d.get("z", 0.0))], device=dev) + pos, tgt = v(camera_info.get("position", {})), v(camera_info.get("target", {})) + Ry = torch.tensor([[ca, 0.0, sa], [0.0, 1.0, 0.0], [-sa, 0.0, ca]], device=dev) + new_pos = tgt + Ry @ (pos - tgt) + q = camera_info.get("quaternion") or {} + qcur = torch.tensor([float(q.get("w", 1.0)), float(q.get("x", 0.0)), + float(q.get("y", 0.0)), float(q.get("z", 0.0))], device=dev) + qy = torch.tensor([math.cos(a / 2), 0.0, math.sin(a / 2), 0.0], device=dev) # world +Y rotation + qn = _quat_mul(qy[None], qcur[None])[0] + xyz = lambda t: {"x": float(t[0]), "y": float(t[1]), "z": float(t[2])} + return {**camera_info, "position": xyz(new_pos), + "quaternion": {"x": float(qn[1]), "y": float(qn[2]), "z": float(qn[3]), "w": float(qn[0])}} + + +def _gauss_blur(x, sigma, dev): + # Separable Gaussian blur of (1, C, H, W). Used to denoise the screen-space normal map. + r = max(1, int(round(3 * sigma))) + k = torch.exp(-0.5 * (torch.arange(-r, r + 1, device=dev, dtype=torch.float32) / sigma) ** 2) + k = k / k.sum() + c = x.shape[1] + x = torch.nn.functional.conv2d(x, k.view(1, 1, 1, -1).expand(c, 1, 1, -1), padding=(0, r), groups=c) + x = torch.nn.functional.conv2d(x, k.view(1, 1, -1, 1).expand(c, 1, -1, 1), padding=(r, 0), groups=c) + return x + + +def _render_gaussian(xyz, rgb, opacity, scale, rot, width, height, splat_scale, bg, camera_info, + sharpen=1.0, headlight_shading=0.0, render_style="color"): + # Perspective-correct anisotropic gaussian splat rasterizer. Each splat is weighted by its 3D Gaussian's + # peak along each pixel's ray (AAA / Hahlbohm), composited front-to-back across depth slabs. `render_style` + # selects the image: color / clay / depth / normal. Returns (image HxWx3, coverage mask HxW) on CPU. + dev = comfy.model_management.get_torch_device() + t = lambda a: torch.as_tensor(a, dtype=torch.float32, device=dev) + idev, idtype = comfy.model_management.intermediate_device(), comfy.model_management.intermediate_dtype() + xyz, rgb, opacity = t(xyz), t(rgb).clamp(0, 1), t(opacity).reshape(-1) + scale, rot = t(scale) * float(splat_scale), t(rot) + do_linear = render_style == "color" # colour blends in linear light, re-encoded at the end + if do_linear: + rgb = _srgb_to_linear(rgb) + flat = width * height + bg_t = t(bg) + bg_comp = _srgb_to_linear(bg_t) if do_linear else bg_t # background blended in the same space as the splats + need_depth = render_style == "depth" + need_normal = render_style in ("normal", "clay") or headlight_shading > 0 + + def background_only(): # no splats to rasterize -> just the background + empty mask + img = bg_t.expand(height, width, 3) if render_style == "color" else torch.zeros(height, width, 3, device=dev) + return img.to(idev, idtype), torch.zeros(height, width, device=idev, dtype=idtype) + + if xyz.shape[0] == 0: # empty input (e.g. all culled by opacity_threshold) + return background_only() + + eye, target, right, up, fwd = _camera_basis(camera_info, dev) # all camera state comes from camera_info + W = torch.stack([right, up, fwd], 0) # rows = camera axes (world -> camera) + cam = (xyz - eye) @ W.T + fov = float(camera_info.get("fov", 0) or 0) or 35.0 + zoom = float(camera_info.get("zoom", 1.0) or 1.0) # three.js digital zoom: scales the focal length + is_ortho = str(camera_info.get("cameraType", "")).lower().startswith("ortho") + xc, yc, zc = cam.unbind(-1) + + keep = zc > 1e-2 + xc, yc, zc, rgb, opacity, scale, rot = (a[keep] for a in (xc, yc, zc, rgb, opacity, scale, rot)) + if xc.shape[0] == 0: # nothing in front of the camera -> background only + return background_only() + if render_style == "clay": + rgb = torch.full_like(rgb, 0.75) # neutral albedo -> shading shows pure geometry + + f = (min(width, height) / 2) / math.tan(math.radians(fov) / 2) * zoom # fov over the smaller axis, x camera zoom + cx0, cy0 = width / 2, height / 2 + + # Camera-space 3D covariance per splat: Sigma = (W Rq) diag(scale^2) (W Rq)^T, plus a tiny relative + # regularizer for a stable inverse (a pixel-size Mip low-pass would over-thicken flat surfels and blur). + Mw = W[None] @ _quat_to_mat(rot) # (N,3,3) world -> camera + cam_cov = (Mw * scale.square()[:, None, :]) @ Mw.transpose(1, 2) + cam_cov = cam_cov + (cam_cov.diagonal(dim1=-2, dim2=-1).mean(-1) * 1e-3)[:, None, None] * torch.eye(3, device=dev) + + # Perspective-correct weighting: peak of the 3D Gaussian along each pixel ray. Precompute Si, Si@mu, mu^T Si mu. + mu = torch.stack([xc, yc, zc], -1) + si = torch.linalg.inv(cam_cov) + simu = (si @ mu[:, :, None])[:, :, 0] # (N,3) + musimu = (mu * simu).sum(-1) # (N,) + s00, s01, s02 = si[:, 0, 0], si[:, 0, 1], si[:, 0, 2] + s11, s12, s22 = si[:, 1, 1], si[:, 1, 2], si[:, 2, 2] + simu0, simu1, simu2 = simu.unbind(-1) + if need_normal: # surfel normal = thinnest axis, oriented toward camera + nrm = Mw[torch.arange(Mw.shape[0], device=dev), :, scale.argmin(-1)] # (N,3) camera-space normal + nrm = nrm * torch.where(nrm[:, 2:3] > 0, -1.0, 1.0) # flip so nz <= 0 (faces camera) + + # Screen centre (exact) + footprint radius from the affine 2D projection (used only to size the kernel). + # The image is +y-down, so the projection's y row is unflipped - it matches the splat frame's +Y. + jm = torch.zeros(xc.shape[0], 2, 3, device=dev) + if is_ortho: # parallel projection: screen = s * (xc, yc) + s = f / float((target - eye).norm().clamp_min(1e-6)) # pixels per world unit at the target plane + cx, cy = cx0 + s * xc, cy0 + s * yc + jm[:, 0, 0] = s + jm[:, 1, 1] = s + else: # perspective: screen = f * (xc, yc) / zc + invz = 1.0 / zc + cx, cy = cx0 + f * xc * invz, cy0 + f * yc * invz + jm[:, 0, 0], jm[:, 0, 2] = f * invz, -f * xc * invz.square() + jm[:, 1, 1], jm[:, 1, 2] = f * invz, -f * yc * invz.square() + cov2 = jm @ cam_cov @ jm.transpose(1, 2) + a, b, c = cov2[:, 0, 0], cov2[:, 0, 1], cov2[:, 1, 1] + max_eig = (a + c) * 0.5 + (((a - c) * 0.5).square() + b * b).clamp_min(0).sqrt() + radius = 3.0 * max_eig.clamp_min(1e-8).sqrt() + K = int(min(max(24, min(width, height) // 16), max(2, math.ceil(_quantile(radius, 0.995).item())))) + + # Per-splat kernel size: bucket splats by radius into a coarse ladder of window sizes (global K stays the cap) so + # small splats (the bulk of it) use a small window. + levels = [L for L in (16, 64, 256) if L < K] + [K] + levels_t = torch.tensor(levels, device=dev, dtype=torch.float32) + grids = [] + for L in levels: + rng = torch.arange(-L, L + 1, device=dev, dtype=torch.float32) + gy, gx = torch.meshgrid(rng, rng, indexing="ij") + grids.append((gx.reshape(-1), gy.reshape(-1))) + blevel = torch.bucketize(radius * (4.0 / 3.0), levels_t).clamp_(max=len(levels) - 1) # window >= ~4 sigma + + n = zc.shape[0] + ns = int(min(256, max(1, n // 1000))) # depth slabs: 1 per ~1000 splats, capped + nl = len(levels) + order = torch.argsort(zc) # front (small zc) -> back -> defines the slabs + bounds = torch.linspace(0, n, ns + 1, device=dev).round().long() + rank = torch.empty(n, dtype=torch.long, device=dev) + rank[order] = torch.arange(n, device=dev) # depth rank of each splat + slab_id = (torch.searchsorted(bounds, rank, right=True) - 1).clamp_(0, ns - 1) + key = slab_id * nl + blevel # group by slab, then kernel level (order-free within) + order = torch.argsort(key) + key = key[order] + + cxr, cyr = cx[order].round(), cy[order].round() + s00, s01, s02 = s00[order], s01[order], s02[order] + s11, s12, s22 = s11[order], s12[order], s22[order] + s01b, s02b, s12b = s01 * 2, s02 * 2, s12 * 2 # doubled cross terms for the fused quadratic forms + simu0, simu1, simu2, musimu = simu0[order], simu1[order], simu2[order], musimu[order] + opacity, rgb = opacity[order], rgb[order] + zc_o = zc[order] if need_depth else None + nrm_o = nrm[order] if need_normal else None + mux_o, muy_o, muz_o = (xc[order], yc[order], zc[order]) if is_ortho else (None, None, None) + + # Pack the per-splat scalars into one tensor so each chunk slices once + common = [cxr, cyr, s00, s11, s22, s01b, s02b, s12b, opacity] + pstack = torch.stack(common + ([s02, s12, mux_o, muy_o, muz_o] if is_ortho else [simu0, simu1, simu2, musimu])) + + # Precompute the (slab, level) run table on-GPU and pull it to the CPU once + starts = torch.cat([torch.zeros(1, dtype=torch.long, device=dev), (key[1:] != key[:-1]).nonzero().flatten() + 1]) + ks = key[starts] + run_lo = starts.tolist() + [n] + run_lev = (ks % nl).tolist() + run_slab = torch.div(ks, nl, rounding_mode="floor").tolist() + slab_runs = [[] for _ in range(ns)] + for r in range(len(run_lev)): + slab_runs[run_slab[r]].append((run_lo[r], run_lo[r + 1], run_lev[r])) + + def splat(lo, hi, ox, oy): # -> pixel idx (m,M), alpha (m,M); weight = 3D Gaussian peak along each pixel's ray + cols = pstack[:, lo:hi, None].unbind(0) + cxr_, cyr_, a00, a11, a22, b01, b02, b12, opa = cols[:9] # a* = Si components; b* = 2 * cross terms + px = cxr_ + ox[None, :] + py = cyr_ + oy[None, :] + valid = (px >= 0) & (px < width) & (py >= 0) & (py < height) + if is_ortho: # parallel ray (0,0,1) from screen point (X, Y, 0); rz constant per splat + c02, c12, mx, my, mz = cols[9:] + rx = (px - cx0) / s - mx + ry = (py - cy0) / s - my + rz = -mz + a22rz = a22 * rz + inx = torch.addcmul(b02 * rz, a00, rx).addcmul_(b01, ry) # a00 rx + b01 ry + b02 rz + rSr = torch.addcmul(a22rz * rz, rx, inx).addcmul_(ry, torch.addcmul(b12 * rz, a11, ry)) + dsr = torch.addcmul(a22rz, c02, rx).addcmul_(c12, ry) + q = torch.addcdiv(rSr, dsr * dsr, a22.clamp_min(1e-12), value=-1).clamp_min_(0) + else: # perspective ray (dx,dy,1) through the camera origin + su0, su1, su2, mus = cols[9:] + dx, dy = (px - cx0) / f, (py - cy0) / f + dsid = torch.addcmul(a22, dx, torch.addcmul(b02, a00, dx)) # a22 + dx*(a00 dx + b02) + dsid = dsid.addcmul_(dy, torch.addcmul(b12, a11, dy)) # + dy*(a11 dy + b12) + dsid = dsid.addcmul_(b01 * dx, dy) # + (2 s01) dx dy + dsimu = torch.addcmul(su2, dx, su0).addcmul_(dy, su1) + q = torch.addcdiv(mus, dsimu * dsimu, dsid.clamp_min(1e-12), value=-1).clamp_min_(0) + alpha = (opa * torch.exp(-0.5 * q) * valid).clamp_(0, 0.999) + idx = py.long().clamp(0, height - 1) * width + px.long().clamp(0, width - 1) + return idx, alpha + + # Front-to-back compositing over the depth slabs set up above. Within a slab the accumulation is a pure + # sum (order-independent), so splats are grouped by kernel level and each level uses its own tight window. + sharp = sharpen != 1.0 # winner-take-more colour blend: dominant splat shows more + cacc = torch.zeros((flat, 3), device=dev) + trans = torch.ones((flat,), device=dev) + a_buf = torch.zeros((flat,), device=dev) # sum alpha -> colour/depth/normal weight (alpha-weighted mean) + tau_buf = torch.zeros((flat,), device=dev) # sum -ln(1-alpha) -> slab opacity = 1-prod(1-alpha) + crgb = torch.zeros((flat, 3), device=dev) # sum alpha^p * rgb -> slab colour + wbuf = torch.zeros((flat,), device=dev) if sharp else None # sum alpha^p -> colour normalizer (sharp only) + dacc = torch.zeros((flat,), device=dev) if need_depth else None # front-weighted depth + nacc = torch.zeros((flat, 3), device=dev) if need_normal else None # front-weighted camera-space normal + zslab = torch.zeros((flat,), device=dev) if need_depth else None + nslab = torch.zeros((flat, 3), device=dev) if need_normal else None + stale = 0 # consecutive fully-occluded slabs -> early-out + for si in range(ns): + runs = slab_runs[si] + if not runs: + continue + a_buf.zero_() + tau_buf.zero_() + crgb.zero_() + if sharp: + wbuf.zero_() + if need_depth: + zslab.zero_() + if need_normal: + nslab.zero_() + for r_lo, r_hi, li in runs: # contiguous same-kernel-level runs in this slab + ox, oy = grids[li] + ch = max(2048, 10_000_000 // ox.shape[0]) # splats/chunk, bounded by this level's kernel size + for lo in range(r_lo, r_hi, ch): + hi = min(lo + ch, r_hi) + idx, alpha = splat(lo, hi, ox, oy) + idx, af = idx.reshape(-1), alpha.reshape(-1) + a_buf.index_add_(0, idx, af) + tau_buf.index_add_(0, idx, (-torch.log1p(-alpha)).reshape(-1)) # -ln(1-alpha), correct opacity merge + apw = alpha.pow(sharpen) if sharp else alpha # bias colour toward the highest-alpha splat + crgb.index_add_(0, idx, (apw[:, :, None] * rgb[lo:hi, None, :]).reshape(-1, 3)) + if sharp: + wbuf.index_add_(0, idx, apw.reshape(-1)) + if need_depth: + zslab.index_add_(0, idx, (alpha * zc_o[lo:hi, None]).reshape(-1)) + if need_normal: + nslab.index_add_(0, idx, (alpha[:, :, None] * nrm_o[lo:hi, None, :]).reshape(-1, 3)) + slab_a = 1 - torch.exp(-tau_buf) # 1 - prod(1-alpha): true opacity of the slab's splats + front = trans * slab_a + denom = wbuf if sharp else a_buf + cacc.addcmul_(front[:, None], crgb / denom.clamp_min(1e-8)[:, None]) # cacc += front * (crgb/denom) + if need_depth or need_normal: + ainv = a_buf.clamp_min(1e-8) # alpha-weighted-mean normalizer (depth/normal only) + if need_depth: + dacc.addcmul_(front, zslab / ainv) + if need_normal: + nacc.addcmul_(front[:, None], nslab / ainv[:, None]) + trans.mul_(1 - slab_a) + if si % 8 == 7: # checkpoint every 8 slabs (a per-slab GPU sync would cost more) + if float(front.max()) < 1e-3: # this checkpoint slab is fully occluded by what is in front + stale += 1 + if stale >= 2: # two occluded checkpoints running -> the rest are too -> stop + break + else: + stale = 0 + + cov = 1 - trans + covg = cov.reshape(height, width) + covm = covg > 0.5 if render_style in ("depth", "normal") else None # silhouette mask (depth/normal styles only) + depth_map = (dacc / cov.clamp_min(1e-6)).reshape(height, width) if need_depth else None + nrm_map = None + if need_normal: + # Per-splat surfel normals are jittery, so do a masked blur + nb = nacc.reshape(height, width, 3).permute(2, 0, 1)[None] + cb = cov.reshape(1, 1, height, width) + nb, cb = _gauss_blur(nb, 1.2, dev), _gauss_blur(cb, 1.2, dev) + normal = (nb / cb.clamp_min(1e-6))[0].permute(1, 2, 0) + nrm_map = normal / normal.norm(dim=-1, keepdim=True).clamp_min(1e-6) + + if render_style == "depth": # near = bright, far = dark, 0 off-object + d = torch.zeros(height, width, device=dev) + if bool(covm.any()): + lo, hi = depth_map[covm].min(), depth_map[covm].max() + d = torch.where(covm, ((hi - depth_map) / (hi - lo).clamp_min(1e-6)).clamp(0, 1), d) + img = d[:, :, None].expand(height, width, 3) + elif render_style == "normal": # OpenGL normal map: +X right, +Y up, +Z to viewer + enc = (nrm_map * t([1.0, -1.0, -1.0]) * 0.5 + 0.5).clamp(0, 1) + img = enc * covm[:, :, None] + else: # color / clay + img = cacc.reshape(height, width, 3) + if render_style == "clay": # studio key light + ambient -> sculpted matte look + kl = t([-0.4, -0.7, -0.6]) # key from screen upper-left, angled toward the viewer + kl = kl / kl.norm() + hl = (0.5 * (nrm_map * kl).sum(-1) + 0.5).clamp(0, 1) # half-Lambert: soft terminator, no harsh dark side + img = img * (0.35 + 0.65 * hl * hl)[:, :, None] # ambient floor + diffuse key + elif headlight_shading > 0: # camera headlight: darken faces turned from view + k = float(headlight_shading) + ndotl = (-nrm_map[:, :, 2]).clamp(0, 1) + img = img * (1 - 0.6 * k + 0.6 * k * ndotl)[:, :, None] + img = img.addcmul_(trans.reshape(height, width, 1), bg_comp) + if do_linear: # back to display space after linear compositing + img = _linear_to_srgb(img) + return img.clamp(0, 1).to(idev, idtype), covg.clamp(0, 1).to(idev, idtype) + + +class RenderSplat(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="RenderSplat", + display_name="Render Splat", + search_aliases=["splat to image", "render splat", "gaussian turntable"], + category="3d/splat", + description="Render a gaussian splat as an image with an anisotropic EWA rasterizer (oriented " + "elliptical splats, antialiased, depth-sorted front-to-back). The camera comes from a " + "camera_info input (Load / Preview 3D, or a Create Camera Info node); leave it empty to " + "auto-frame the splat. Set frames greater than 1 for a turntable batch of images to feed a Video node.", + inputs=[ + IO.Splat.Input("splat"), + IO.Int.Input("width", default=1024, min=64, max=2048, step=8), + IO.Int.Input("height", default=1024, min=64, max=2048, step=8), + IO.Int.Input("frames", default=1, min=-240, max=240, + tooltip="-1, 0, 1 = single still image; >1 = turntable, the camera orbits over a full " + "360 turn (works with any camera_info). Negative value orbits the other way."), + IO.Float.Input("splat_scale", default=1.0, min=0.1, max=5.0, step=0.05, advanced=True, + tooltip="Multiplier on each splat's projected footprint (lower = crisper points, " + "higher = softer/fuller surface)."), + IO.Float.Input("sharpen", default=2.0, min=1.0, max=8.0, step=0.5, + tooltip="Sharpen overlapping splats: 1.0 = physically-correct blend; higher biases " + "each pixel toward its dominant (nearest) splat for crisper texture, without " + "shrinking splats or opening gaps. Non-physical above 1."), + IO.Float.Input("headlight_shading", default=0.0, min=0.0, max=3.0, step=0.05, advanced=True, + tooltip="Diffuse shading from a light at the camera (headlight), using the splat surfel " + "normals: darkens surfaces that turn away from view to reveal form/curvature. " + "0 = flat albedo, 1 = strongest shading."), + IO.Float.Input("opacity_threshold", default=0.0, min=0.0, max=1.0, step=0.01, advanced=True, + tooltip="Cull gaussians with opacity below this (removes faint floaters)."), + IO.Combo.Input("render_style", options=["color", "clay", "depth", "normal"], + tooltip="What the image output shows: color, clay (neutral-albedo shaded), " + "depth (near=bright), normal (OpenGL normal map)."), + IO.Color.Input("background", default="#000000"), + IO.Image.Input("bg_image", optional=True, + tooltip="Optional background plate composited behind the splat (overrides the solid " + "background colour). Resized to the render size; a batch is used per frame, " + "a single image for all. color/clay only."), + IO.Load3DCamera.Input("camera_info", optional=True, + tooltip="Camera to render from - a Load3D / Preview3D camera or a Create Camera " + "Info node. If empty, the splat is auto-framed from a default 3/4 view."), + ], + outputs=[IO.Image.Output(display_name="image"), IO.Mask.Output(display_name="mask")], + ) + + @classmethod + def execute(cls, splat, width, height, frames, splat_scale, sharpen, headlight_shading, + opacity_threshold, background, render_style, camera_info=None, bg_image=None) -> IO.NodeOutput: + bg = _hex_to_rgb(background) + bg_imgs = None + if bg_image is not None: # resize the plate(s) to the render size: (B,H,W,3) + bi = comfy.utils.common_upscale(bg_image.movedim(-1, 1), width, height, "bicubic", "disabled") + bg_imgs = bi.movedim(1, -1).clamp(0, 1) + n_frames = abs(int(frames)) or 1 # magnitude = frame count (0 -> single still) + orbit_dir = -1.0 if frames < 0 else 1.0 # sign = orbit direction + imgs, masks = [], [] + device = comfy.model_management.get_torch_device() + total = splat.positions.shape[0] * n_frames + pbar = comfy.utils.ProgressBar(total) if total > 1 else None + k = 0 + for i in range(splat.positions.shape[0]): + xyz, rgb, opacity, scale, rot = _gaussian_item(splat, i, device) + if opacity_threshold > 0: + keep = opacity >= opacity_threshold + xyz, rgb, opacity, scale, rot = xyz[keep], rgb[keep], opacity[keep], scale[keep], rot[keep] + base_cam = camera_info + if base_cam is None: # no camera -> default 3/4 view, auto-framed on the splat + center = xyz.mean(0) if xyz.shape[0] else torch.zeros(3, device=device) + extent = (_quantile((xyz - center).norm(dim=-1), 0.99).clamp_min(1e-4) if xyz.shape[0] + else torch.tensor(1.0, device=device)) + dist = float(extent / (math.tan(math.radians(35.0) / 2) * 0.9)) + base_cam = _orbit_camera_info(35.0, 30.0, dist, 35.0, center, device) + for fr in range(n_frames): + cam_fr = (base_cam if n_frames == 1 + else _orbit_camera_info_yaw(base_cam, orbit_dir * 360.0 * fr / n_frames, device)) + bg_k = bg_imgs[k % bg_imgs.shape[0]] if bg_imgs is not None else bg # per-frame plate, or solid colour + img, mask = _render_gaussian(xyz, rgb, opacity, scale, rot, width, height, splat_scale, bg_k, cam_fr, + sharpen=sharpen, headlight_shading=headlight_shading, + render_style=render_style) + imgs.append(img) + masks.append(mask) + k += 1 + if pbar is not None: + pbar.update(1) + return IO.NodeOutput(torch.stack(imgs), torch.stack(masks)) + + +class CreateCameraInfo(IO.ComfyNode): # TODO: move to better file + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="CreateCameraInfo", + display_name="Create Camera Info", + search_aliases=["camera position", "make camera info", "orbit camera", "look at camera"], + category="3d", + description="Build a camera_info" + "Mode 'orbit' aims with yaw/pitch/distance around the target; " + "'look_at' places the camera at world position. Coordinates are the viewer's world space (right-handed,Y-up).", + inputs=[ + IO.DynamicCombo.Input("mode", options=[ + IO.DynamicCombo.Option("orbit", [ + IO.Float.Input("yaw", default=35.0, min=-360.0, max=360.0, step=1.0), + IO.Float.Input("pitch", default=30.0, min=-89.0, max=89.0, step=1.0), + IO.Float.Input("distance", default=4.0, min=0.01, max=1000.0, step=0.01, + tooltip="Camera distance from the target."), + ]), + IO.DynamicCombo.Option("look_at", [ + IO.Float.Input("position_x", default=4.0, min=-1000.0, max=1000.0, step=0.01, + tooltip="Camera position in world space (right-handed, Y-up)."), + IO.Float.Input("position_y", default=4.0, min=-1000.0, max=1000.0, step=0.01), + IO.Float.Input("position_z", default=4.0, min=-1000.0, max=1000.0, step=0.01), + ]), + IO.DynamicCombo.Option("quaternion", [ + IO.Float.Input("position_x", default=4.0, min=-1000.0, max=1000.0, step=0.01, + tooltip="Camera position in world space (right-handed, Y-up)."), + IO.Float.Input("position_y", default=4.0, min=-1000.0, max=1000.0, step=0.01), + IO.Float.Input("position_z", default=4.0, min=-1000.0, max=1000.0, step=0.01), + IO.Float.Input("quat_x", default=0.0, min=-1.0, max=1.0, step=0.001), + IO.Float.Input("quat_y", default=0.0, min=-1.0, max=1.0, step=0.001), + IO.Float.Input("quat_z", default=0.0, min=-1.0, max=1.0, step=0.001), + IO.Float.Input("quat_w", default=1.0, min=-1.0, max=1.0, step=0.001, + tooltip="Camera world-rotation quaternion (three.js: looks down local -Z). Normalized for you."), + ]), + ], tooltip="How to define the camera: orbit angles, an explicit position, or a position + quaternion."), + IO.Float.Input("target_x", default=0.0, min=-1000.0, max=1000.0, step=0.01, advanced=True, + tooltip="Look-at point (orbit pivot / aim). In orbit mode, move it to pan/translate the " + "whole camera. Ignored in quaternion mode. Defaults to the origin."), + IO.Float.Input("target_y", default=0.0, min=-1000.0, max=1000.0, step=0.01, advanced=True), + IO.Float.Input("target_z", default=0.0, min=-1000.0, max=1000.0, step=0.01, advanced=True), + IO.Float.Input("roll", default=0.0, min=-180.0, max=180.0, step=1.0, + tooltip="Camera roll about the view axis, degrees."), + IO.Float.Input("fov", default=35.0, min=1.0, max=120.0, step=1.0, + tooltip="Vertical field of view in degrees."), + IO.Float.Input("zoom", default=1.0, min=0.01, max=100.0, step=0.01, + tooltip="Digital zoom (focal-length multiplier). >1 zooms in without moving the camera."), + IO.Combo.Input("camera_type", options=["perspective", "orthographic"], + tooltip="Projection used by Render Splat: perspective (foreshortening) or orthographic (parallel)."), + ], + outputs=[IO.Load3DCamera.Output(display_name="camera_info")], + ) + + @classmethod + def execute(cls, mode, target_x, target_y, target_z, roll, fov, zoom=1.0, camera_type="perspective") -> IO.NodeOutput: + dev = comfy.model_management.get_torch_device() + kind = mode["mode"] + if kind == "quaternion": # explicit world position + camera rotation + position = [mode["position_x"], mode["position_y"], mode["position_z"]] + quat = [mode["quat_x"], mode["quat_y"], mode["quat_z"], mode["quat_w"]] + return IO.NodeOutput(_quat_camera_info(position, quat, fov, dev, zoom=zoom, camera_type=camera_type)) + target = [target_x, target_y, target_z] # orbit pivot / aim; move it to pan the whole camera + if kind == "orbit": # yaw/pitch/distance about the target (world Y-up) + y, p = math.radians(mode["yaw"]), math.radians(mode["pitch"]) + cy, sy, cp, sp = math.cos(y), math.sin(y), math.cos(p), math.sin(p) + d = mode["distance"] + position = [target_x + d * cp * sy, target_y + d * sp, target_z + d * cp * cy] + else: # look_at: explicit world-space camera position + position = [mode["position_x"], mode["position_y"], mode["position_z"]] + return IO.NodeOutput(_lookat_camera_info(position, target, fov, dev, zoom=zoom, camera_type=camera_type, roll=roll)) + + +class TransformSplat(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="TransformSplat", + display_name="Transform Splat", + search_aliases=["move splat", "rotate splat", "scale splat", "gaussian transform"], + category="3d/splat", + description="Translate, rotate, and scale a gaussian splat. " + "Non-uniform scale also reshapes every individual splat, slower process.", + inputs=[ + IO.Splat.Input("splat"), + IO.Float.Input("translate_x", default=0.0, min=-100.0, max=100.0, step=0.01), + IO.Float.Input("translate_y", default=0.0, min=-100.0, max=100.0, step=0.01), + IO.Float.Input("translate_z", default=0.0, min=-100.0, max=100.0, step=0.01), + IO.Float.Input("rotate_x", default=0.0, min=-360.0, max=360.0, step=1.0), + IO.Float.Input("rotate_y", default=0.0, min=-360.0, max=360.0, step=1.0), + IO.Float.Input("rotate_z", default=0.0, min=-360.0, max=360.0, step=1.0), + IO.Float.Input("scale_x", default=1.0, min=0.01, max=100.0, step=0.01), + IO.Float.Input("scale_y", default=1.0, min=0.01, max=100.0, step=0.01), + IO.Float.Input("scale_z", default=1.0, min=0.01, max=100.0, step=0.01), + ], + outputs=[IO.Splat.Output(display_name="splat")], + ) + + @classmethod + def execute(cls, splat, translate_x, translate_y, translate_z, + rotate_x, rotate_y, rotate_z, scale_x, scale_y, scale_z) -> IO.NodeOutput: + pos = splat.positions + dev, dt = pos.device, pos.dtype + q_rot = _euler_to_quat(rotate_x, rotate_y, rotate_z).to(device=dev, dtype=dt) + R = _quat_to_mat(q_rot[None])[0] # (3, 3) node rotation + D = torch.tensor([scale_x, scale_y, scale_z], dtype=dt, device=dev) + A = D[:, None] * R # diag(D) @ R: per-axis scale after rotation + t = torch.tensor([translate_x, translate_y, translate_z], dtype=dt, device=dev) + + positions = pos @ A.T + t # rotate, scale per-axis, then translate + if scale_x == scale_y == scale_z: # uniform: rotation/scale factor out cleanly + scales = splat.scales * scale_x + rotations = _quat_mul(q_rot.expand_as(splat.rotations), splat.rotations) + rotations = rotations / rotations.norm(dim=-1, keepdim=True).clamp_min(1e-12) + else: # non-uniform: transform Sigma = A R s^2 R^T A^T, re-extract + rg = _quat_to_mat(splat.rotations.reshape(-1, 4)) # (M,3,3) per-splat rotation + s2 = splat.scales.reshape(-1, 3).square() + cov = (rg * s2[:, None, :]) @ rg.transpose(-1, -2) # Sigma + cov = A @ cov @ A.T # A Sigma A^T (A broadcast over splats) + lam, V = torch.linalg.eigh(cov) # symmetric -> eigenvalues (asc), orthonormal axes + V = V * torch.where(torch.linalg.det(V) < 0, -1.0, 1.0)[..., None, None] # keep a proper rotation + scales = lam.clamp_min(0).sqrt().reshape(splat.scales.shape) + rotations = _mat_to_quat(V).reshape(splat.rotations.shape) + out = Types.SPLAT(positions, scales, rotations, splat.opacities, splat.sh, + counts=getattr(splat, "counts", None)) + return IO.NodeOutput(out) + + +class GetSplatCount(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="GetSplatCount", + display_name="Get Splat Count", + search_aliases=["splat count", "gaussian count", "number of splats", "splat info"], + category="3d/splat", + description="Returns the number of splats summed across the batch.", + inputs=[IO.Splat.Input("splat")], + outputs=[IO.Splat.Output(display_name="splat"), + IO.Int.Output(display_name="count"), + ], + hidden=[IO.Hidden.unique_id], + ) + + @classmethod + def execute(cls, splat) -> IO.NodeOutput: + count = sum(_real_len(splat, i) for i in range(splat.positions.shape[0])) + if cls.hidden.unique_id: # show the count inline on the node + PromptServer.instance.send_progress_text(f"{count:,} splats", cls.hidden.unique_id) + return IO.NodeOutput(splat, count) + + +def _pad_stack(items, n): + # Stack a list of (Lᵢ, *tail) tensors into (B, n, *tail), zero-padding each row up to n. + tail = items[0].shape[1:] + out = items[0].new_zeros((len(items), n, *tail)) + for i, t in enumerate(items): + out[i, :t.shape[0]] = t + return out + + +def _merge_gaussians(gaussians: list) -> Types.SPLAT: + # Concatenate SPLAT batches along the splat dimension (per item), padding SH to the highest degree. + gs = [g for g in gaussians if g is not None] + if not gs: + raise ValueError("MergeSplat: no gaussians to merge") + b = gs[0].positions.shape[0] + for g in gs: + if g.positions.shape[0] != b: + raise ValueError(f"MergeSplat: batch size mismatch ({b} vs {g.positions.shape[0]}).") + max_k = max(g.sh.shape[2] for g in gs) + + pos_b, scl_b, rot_b, op_b, sh_b, lengths = [], [], [], [], [], [] + for i in range(b): + pos_i, scl_i, rot_i, op_i, sh_i = [], [], [], [], [] + for g in gs: + end = _real_len(g, i) + pos_i.append(g.positions[i, :end]) + scl_i.append(g.scales[i, :end]) + rot_i.append(g.rotations[i, :end]) + op_i.append(g.opacities[i, :end]) + sh = g.sh[i, :end] # (end, K, 3) + if sh.shape[1] < max_k: # zero-pad lower-degree SH + sh = torch.cat([sh, sh.new_zeros(sh.shape[0], max_k - sh.shape[1], sh.shape[2])], dim=1) + sh_i.append(sh) + pos_b.append(torch.cat(pos_i)) + scl_b.append(torch.cat(scl_i)) + rot_b.append(torch.cat(rot_i)) + op_b.append(torch.cat(op_i)) + sh_b.append(torch.cat(sh_i)) + lengths.append(pos_b[-1].shape[0]) + + n = max(lengths) + counts = None + if len(set(lengths)) > 1: + counts = torch.tensor(lengths, device=gs[0].positions.device, dtype=torch.int64) + return Types.SPLAT(_pad_stack(pos_b, n), _pad_stack(scl_b, n), _pad_stack(rot_b, n), + _pad_stack(op_b, n), _pad_stack(sh_b, n), counts=counts) + + +class MergeSplat(IO.ComfyNode): + @classmethod + def define_schema(cls): + # Autogrow: a splat0/splat1/... input list that grows a fresh slot as you connect splats. + splats = IO.Autogrow.TemplatePrefix(IO.Splat.Input("splat"), prefix="splat", min=2, max=32) + return IO.Schema( + node_id="MergeSplat", + display_name="Merge Splats", + search_aliases=["union splat", "densify gaussian", "combine splat", "merge gaussian"], + category="3d/splat", + description="Concatenate any number of gaussian splats into one. Unioning several decodes of the same " + "latent at different seeds densifies the surface, this can improve surface quality when meshing.", + inputs=[IO.Autogrow.Input("splats", template=splats)], + outputs=[IO.Splat.Output(display_name="splat")], + ) + + @classmethod + def execute(cls, splats: IO.Autogrow.Type) -> IO.NodeOutput: + gs = [v for v in splats.values() if v is not None] + if not gs: + raise ValueError("MergeSplat: connect at least one splat.") + return IO.NodeOutput(_merge_gaussians(gs)) + + +def _inverse_covariance(scale, quat): + # Per-splat Sigma^-1 = R diag(1/s^2) R^T. scale (N,3) linear std, quat (N,4) wxyz -> (N,3,3). + q = quat / quat.norm(dim=1, keepdim=True).clamp_min(1e-12) + w, x, y, z = q.unbind(-1) + R = torch.stack([ + 1 - 2 * (y * y + z * z), 2 * (x * y - w * z), 2 * (x * z + w * y), + 2 * (x * y + w * z), 1 - 2 * (x * x + z * z), 2 * (y * z - w * x), + 2 * (x * z - w * y), 2 * (y * z + w * x), 1 - 2 * (x * x + y * y), + ], dim=1).reshape(-1, 3, 3) + inv_s2 = 1.0 / scale.clamp_min(1e-8) ** 2 # (N, 3) + return torch.einsum("nij,nj,nkj->nik", R, inv_s2, R) + + +def _splat_density(xyz, opacity, scale, quat, rgb, res, kernel, device, color_sharpen=1.0, chunk=4096, progress=None, + col_dtype=torch.float16): + # Splat each gaussian as its oriented-covariance disk (3-sigma, opacity-weighted) into a density grid, + # plus a colour volume. Each gaussian uses a voxel window sized to its OWN 3-sigma (capped at `kernel`). + # Colour is weighted by w^color_sharpen: >1 biases each voxel toward its dominant gaussian (crisper + # texture). Returns (density, colour numerator, colour normaliser, origin, voxel). + pad = 4.0 * scale.median() + lo = xyz.amin(0) - pad + hi = xyz.amax(0) + pad + voxel = ((hi - lo).max() / res).clamp_min(1e-8) + dx, dy, dz = (torch.ceil((hi - lo) / voxel).long() + 1).tolist() + + sinv = _inverse_covariance(scale, quat) + kreq = torch.ceil(3.0 * scale.amax(-1) / voxel).long().clamp(1, int(kernel)) # per-gaussian half-width + sharp = color_sharpen != 1.0 + vol = torch.zeros(dx * dy * dz, device=device) # Sum(w) density (surface) + colvol = torch.zeros(dx * dy * dz, 3, device=device, dtype=col_dtype) # Sum(w^p * rgb) colour numerator + wcol = torch.zeros(dx * dy * dz, device=device, dtype=col_dtype) if sharp else None # Sum(w^p) normaliser (p>1) + n, done = xyz.shape[0], 0 + for k in range(1, int(kernel) + 1): + sel = (kreq == k).nonzero(as_tuple=True)[0] + if sel.numel() == 0: + continue + rng = torch.arange(-k, k + 1, device=device, dtype=torch.float32) + off = torch.stack(torch.meshgrid(rng, rng, rng, indexing="ij"), -1).reshape(-1, 3) # (M, 3) + for st in range(0, sel.numel(), chunk): + gi = sel[st:st + chunk] + cc = xyz[gi] + idx = ((cc - lo) / voxel).round()[:, None, :] + off[None] # (b, M, 3) voxel coords + d = (lo + idx * voxel) - cc[:, None, :] # world offset to voxel center + quad = torch.einsum("bmi,bij,bmj->bm", d, sinv[gi], d) + wgt = opacity[gi, None] * torch.exp(-0.5 * quad) + wgt = torch.where(quad < 9.0, wgt, torch.zeros_like(wgt)) # clip beyond 3 sigma + ii = idx.long() + ix = ii[..., 0].clamp(0, dx - 1) + iy = ii[..., 1].clamp(0, dy - 1) + iz = ii[..., 2].clamp(0, dz - 1) + flat = (ix * (dy * dz) + iy * dz + iz).reshape(-1) + vol.index_add_(0, flat, wgt.reshape(-1)) + wp = wgt.pow(color_sharpen) if sharp else wgt # winner-take-more colour weight + colvol.index_add_(0, flat, (wp[..., None] * rgb[gi, None, :]).reshape(-1, 3).to(col_dtype)) + if sharp: + wcol.index_add_(0, flat, wp.reshape(-1).to(col_dtype)) + done += gi.numel() + if progress is not None: + progress(min(1.0, done / max(1, n))) + colnorm = (wcol if sharp else vol).reshape(dx, dy, dz) # p==1 -> Sum(w) == density + return vol.reshape(dx, dy, dz), colvol.reshape(dx, dy, dz, 3), colnorm, lo.cpu().numpy(), float(voxel) + + +def _connected_components_gpu(faces, nv): + # FastSV connected components: grandparent hooking + shortcutting, ~O(log nv) iterations. + # Returns per-vertex component labels (min node id, not densified). + a = torch.cat([faces[:, 0], faces[:, 1]]) # 2F edge endpoints: (v0,v1),(v1,v2) + b = torch.cat([faces[:, 1], faces[:, 2]]) + f = torch.arange(nv, device=faces.device) + while True: + gp = f[f] # grandparent + ga, gb = gp[a], gp[b] + new = f.clone() + new.scatter_reduce_(0, f[a], gb, "amin", include_self=True) # stochastic hooking onto roots + new.scatter_reduce_(0, f[b], ga, "amin", include_self=True) + new.scatter_reduce_(0, a, gb, "amin", include_self=True) # aggressive hooking, both directions + new.scatter_reduce_(0, b, ga, "amin", include_self=True) + new = new[new] # shortcut (path compression) + if torch.equal(new, f): + return f + f = new + + +def _clean_components_gpu(verts, faces, min_verts, device): + # GPU port of _clean_components: FastSV components + scatter reductions. Byte-identical to the numpy path + vt = torch.as_tensor(verts, device=device) + ft = torch.as_tensor(faces, device=device) + nv = vt.shape[0] + _, label = torch.unique(_connected_components_gpu(ft, nv), return_inverse=True) # dense 0..ncomp-1 + ncomp = int(label.max()) + 1 + flabel = label[ft[:, 0]] # component id per face + keep = torch.bincount(label, minlength=ncomp) >= min_verts # per-component vertex-count gate + if int(keep.sum()) > 1: + fcount = torch.bincount(flabel, minlength=ncomp) + largest = int(torch.where(keep, fcount, fcount.new_tensor(-1)).argmax()) + v0, v1, v2 = vt[ft[:, 0]], vt[ft[:, 1]], vt[ft[:, 2]] + cvol = torch.zeros(ncomp, device=device).scatter_add_(0, flabel, (v0 * torch.linalg.cross(v1, v2)).sum(-1)) + idx3 = label[:, None].expand(-1, 3) # per-component vertex bbox + cmin = torch.full((ncomp, 3), float("inf"), device=device).scatter_reduce_(0, idx3, vt, "amin", include_self=True) + cmax = torch.full((ncomp, 3), float("-inf"), device=device).scatter_reduce_(0, idx3, vt, "amax", include_self=True) + tol = 1e-4 * (cmax[largest] - cmin[largest]).max() + enclosed = (cmin >= cmin[largest] - tol).all(1) & (cmax <= cmax[largest] + tol).all(1) + inner = enclosed & (torch.sign(cvol) != torch.sign(cvol[largest])) & (torch.arange(ncomp, device=device) != largest) + keep &= ~inner + faces_k = ft[keep[flabel]] + if faces_k.shape[0] == 0: + return verts[:0], faces[:0] + used = torch.unique(faces_k) # sorted, matches np.unique + remap = torch.full((nv,), -1, dtype=torch.int64, device=device) + remap[used] = torch.arange(used.shape[0], device=device) + return vt[used].cpu().numpy(), remap[faces_k].cpu().numpy() + + +def _clean_components(verts, faces, min_verts, device=None): + # Drop floaters (components with < min_verts vertices) and inner shells - the surfel shell density + # extracts a double wall (outer + inner cavity surface). GPU path (FastSV CC + scatter reductions, ~13x + # faster) when an accelerator has headroom; else numpy/scipy. Both produce byte-identical output. + if device is not None and not comfy.model_management.is_device_cpu(device) and \ + comfy.model_management.get_free_memory(device) > 10 * faces.size * 8: # peak ~8.4x faces bytes + return _clean_components_gpu(verts, faces, min_verts, device) + nv = len(verts) + e = np.concatenate([faces[:, [0, 1]], faces[:, [1, 2]], faces[:, [0, 2]]], 0) + ncomp, label = connected_components(coo_matrix((np.ones(len(e)), (e[:, 0], e[:, 1])), shape=(nv, nv)), directed=False) + flabel = label[faces[:, 0]] # component id per face + keep = np.bincount(label, minlength=ncomp) >= min_verts # per-component vertex-count gate + if keep.sum() > 1: + fcount = np.bincount(flabel, minlength=ncomp) + largest = np.where(keep, fcount, -1).argmax() + v0, v1, v2 = verts[faces[:, 0]], verts[faces[:, 1]], verts[faces[:, 2]] + cvol = np.bincount(flabel, weights=np.einsum("ij,ij->i", v0, np.cross(v1, v2)), minlength=ncomp) # 6*signed vol + cidx = np.arange(ncomp) # per-component vertex bbox via ndimage (~6x faster than ufunc.at) + cmin = np.stack([_ndi_minimum(verts[:, a], label, cidx) for a in range(3)], 1) + cmax = np.stack([_ndi_maximum(verts[:, a], label, cidx) for a in range(3)], 1) + tol = 1e-4 * (cmax[largest] - cmin[largest]).max() + enclosed = (cmin >= cmin[largest] - tol).all(1) & (cmax <= cmax[largest] + tol).all(1) + inner = enclosed & (np.sign(cvol) != np.sign(cvol[largest])) & (np.arange(ncomp) != largest) + keep &= ~inner + faces = faces[keep[flabel]] + if len(faces) == 0: + return verts[:0], faces + used = np.unique(faces) + remap = np.full(nv, -1, np.int64) + remap[used] = np.arange(len(used)) + return verts[used], remap[faces] + + +def _surface_nets(vol, level, voxel, origin, device): + # Vectorized Surface Nets: one dual vertex per sign-changing cell at its edge-crossing mean, quads wound CCW-outward. + # Returns verts (V,3), faces (F,3). + vol = vol.to(device=device, dtype=torch.float32) + dx, dy, dz = vol.shape + origin_t = torch.as_tensor(origin, device=device, dtype=torch.float32) + empty = (np.zeros((0, 3), np.float32), np.zeros((0, 3), np.int64)) + if dx < 2 or dy < 2 or dz < 2: + return empty + + # Active = cells whose 8 corners aren't all in/all out. + inside = vol >= level # (dx,dy,dz) bool + cs8 = [inside[ox:ox + dx - 1, oy:oy + dy - 1, oz:oz + dz - 1] + for ox, oy, oz in ((0, 0, 0), (1, 0, 0), (0, 1, 0), (1, 1, 0), + (0, 0, 1), (1, 0, 1), (0, 1, 1), (1, 1, 1))] + any_in = cs8[0] | cs8[1] | cs8[2] | cs8[3] | cs8[4] | cs8[5] | cs8[6] | cs8[7] + all_in = cs8[0] & cs8[1] & cs8[2] & cs8[3] & cs8[4] & cs8[5] & cs8[6] & cs8[7] + active = any_in & ~all_in # (cx,cy,cz) straddling cells + nv = int(active.sum()) + if nv == 0: + return empty + + # Active cells only (a thin shell): each dual vertex = mean of its 12 edges' zero-crossings. + del any_in, all_in, cs8 # corner bool grids no longer needed + ac = active.nonzero(as_tuple=False) # (nv,3) cell min-corner indices + offs = torch.tensor([[0, 0, 0], [1, 0, 0], [0, 1, 0], [1, 1, 0], + [0, 0, 1], [1, 0, 1], [0, 1, 1], [1, 1, 1]], device=device) + offf = offs.to(torch.float32) + edges = torch.tensor([[0, 1], [0, 2], [0, 4], [1, 3], [1, 5], [2, 3], + [2, 6], [3, 7], [4, 5], [4, 6], [5, 7], [6, 7]], device=device) + e0, e1 = edges[:, 0], edges[:, 1] + oe0, oe1 = offf[e0], offf[e1] # (12,3) edge endpoints + + cstep = 1 << 18 # chunk to bound peak memory (CPU RAM too) + loc = [] + for st in range(0, nv, cstep): + ci = ac[st:st + cstep, None, :] + offs[None] # (m,8,3) + cval = vol[ci[..., 0], ci[..., 1], ci[..., 2]] # (m,8) corner values + csl = cval >= level + v0, v1 = cval[:, e0], cval[:, e1] # (m,12) + cross = (csl[:, e0] != csl[:, e1])[..., None].to(torch.float32) + denom = v1 - v0 + t = torch.where(denom.abs() > 1e-12, (level - v0) / denom, torch.full_like(denom, 0.5)).clamp(0, 1) + pts = torch.lerp(oe0, oe1, t[..., None]) # (m,12,3) local crossings (fused interp) + loc.append((pts * cross).sum(1) / cross.sum(1).clamp_min(1.0)) # (m,3) in [0,1] + local = torch.cat(loc, 0) if len(loc) > 1 else loc[0] # (nv,3) + verts = origin_t + (ac.to(torch.float32) + local) * voxel # world space + del loc, local, ac + + vid = torch.full((dx - 1, dy - 1, dz - 1), -1, dtype=torch.int32, device=device) + vid[active] = torch.arange(nv, dtype=torch.int32, device=device) + del active + + # Each straddling grid edge -> one quad from its 4 cells; `sol` (low-end sign) picks outward winding. + faces = [] + + def emit(cr, sol, a, b, d, c): + valid = cr & (a >= 0) & (b >= 0) & (c >= 0) & (d >= 0) + if not bool(valid.any()): + return + a, b, c, d, sol = a[valid], b[valid], c[valid], d[valid], sol[valid] + p2, p4 = torch.where(sol, b, c), torch.where(sol, c, b) # reverse quad winding where ~sol + faces.append(torch.stack([a, p2, d], 1)) + faces.append(torch.stack([a, d, p4], 1)) + + a = inside[0:dx - 1, 1:dy - 1, 1:dz - 1] + emit(a != inside[1:dx, 1:dy - 1, 1:dz - 1], a, + vid[:, 0:dy - 2, 0:dz - 2], vid[:, 1:dy - 1, 0:dz - 2], + vid[:, 1:dy - 1, 1:dz - 1], vid[:, 0:dy - 2, 1:dz - 1]) + a = inside[1:dx - 1, 0:dy - 1, 1:dz - 1] + emit(a != inside[1:dx - 1, 1:dy, 1:dz - 1], a, + vid[0:dx - 2, :, 0:dz - 2], vid[0:dx - 2, :, 1:dz - 1], + vid[1:dx - 1, :, 1:dz - 1], vid[1:dx - 1, :, 0:dz - 2]) + a = inside[1:dx - 1, 1:dy - 1, 0:dz - 1] + emit(a != inside[1:dx - 1, 1:dy - 1, 1:dz], a, + vid[0:dx - 2, 0:dy - 2, :], vid[1:dx - 1, 0:dy - 2, :], + vid[1:dx - 1, 1:dy - 1, :], vid[0:dx - 2, 1:dy - 1, :]) + + if not faces: + return empty + return verts.cpu().numpy().astype(np.float32), torch.cat(faces, 0).cpu().numpy().astype(np.int64) + + +def _otsu_level(values, bins=256): + # Otsu threshold: the density value that best splits inside/outside (max between-class variance). + hist, edges = np.histogram(values, bins=bins) + hist = hist.astype(np.float64) + centers = (edges[:-1] + edges[1:]) * 0.5 + w = np.cumsum(hist) # background-class weight at each split + mu = np.cumsum(hist * centers) + wf = w[-1] - w # foreground-class weight + mb = mu / np.where(w > 0, w, 1.0) + mf = (mu[-1] - mu) / np.where(wf > 0, wf, 1.0) + var_b = w * wf * (mb - mf) ** 2 # between-class variance + var_b[(w <= 0) | (wf <= 0)] = -1.0 + return float(centers[int(np.argmax(var_b))]) + + +def _taubin_smooth(verts, faces, iters, lam=0.5, mu=-0.53): + # Taubin lambda|mu smoothing: low-pass the mesh surface without the shrinkage of a Laplacian blur + # (the mu inflation pass cancels the lambda pass's volume loss). Uniform (umbrella) weights. + if iters <= 0 or len(verts) == 0 or len(faces) == 0: + return verts + nv = len(verts) + e = np.concatenate([faces[:, [0, 1]], faces[:, [1, 2]], faces[:, [0, 2]]], 0) + e = np.concatenate([e, e[:, ::-1]], 0) # symmetric adjacency + adj = coo_matrix((np.ones(len(e), np.float32), (e[:, 0], e[:, 1])), shape=(nv, nv)).tocsr() + adj.data[:] = 1.0 + deg = np.clip(np.asarray(adj.sum(1)).ravel(), 1.0, None).astype(np.float32)[:, None] + v = verts.astype(np.float32) # fp32 matvec: ~2x faster, sub-micron drift on unit-scale verts + for _ in range(int(iters)): + for fac in (lam, mu): + v = v + np.float32(fac) * ((adj @ v) / deg - v) # fac * (mean(neighbours) - v) + return np.ascontiguousarray(v) + + +def _sample_vertex_colours_gpu(colvol, colnorm, verts, origin, voxel, device): + # GPU trilinear sampling of the colour numerator (3ch) and normaliser (1ch) at vertex grid-coords + # reproduces scipy map_coordinates(order=1, mode='nearest'). Returns col (V,3) numpy. + dx, dy, dz = colnorm.shape + vt = torch.as_tensor(verts, device=device, dtype=torch.float32) + org = torch.as_tensor(origin, device=device, dtype=torch.float32) + gi = (vt - org) / voxel # (V,3) grid-index coords (x,y,z) + size = torch.tensor([dx, dy, dz], device=device, dtype=torch.float32) + g = 2.0 * gi / (size - 1).clamp_min(1.0) - 1.0 # -> [-1,1] (align_corners) + grid = torch.stack([g[:, 2], g[:, 1], g[:, 0]], -1)[None, None, None] # (1,1,1,V,3): grid_sample order (W=z,H=y,D=x) + + def samp(v): # (dx,dy,dz,C) cpu fp16 -> (C,V) fp32 on device + inp = v.to(device).permute(3, 0, 1, 2)[None].float() + o = torch.nn.functional.grid_sample(inp, grid, mode="bilinear", padding_mode="border", align_corners=True) + return o[0, :, 0, 0, :] + num = samp(colvol) # (3,V) + den = samp(colnorm[..., None]) # (1,V) + return (num / den.clamp_min(1e-8)).T.cpu().numpy() # (V,3) + + +def _gaussian_to_mesh(g: Types.SPLAT, i, res, kernel, taubin, level_bias, min_component, min_opacity, color_sharpen, device, progress=None): + # Mesh one splat: density + colour grids -> Surface Nets -> floater removal -> Taubin smoothing -> + # volume-sampled colours. Returns (verts, faces int64, colors in [0,1]), or None if no surface. + rep = progress if progress is not None else (lambda *_: None) + + end = _real_len(g, i) + xyz = g.positions[i, :end].to(device=device, dtype=torch.float32) + scale = g.scales[i, :end].to(device=device, dtype=torch.float32) + quat = g.rotations[i, :end].to(device=device, dtype=torch.float32) + opacity = g.opacities[i, :end].reshape(-1).to(device=device, dtype=torch.float32) + rgb = (g.sh[i, :end, 0, :].to(device=device, dtype=torch.float32) * _C0 + 0.5).clamp(0, 1) + + keep = opacity >= min_opacity + xyz, scale, quat, opacity, rgb = xyz[keep], scale[keep], quat[keep], opacity[keep], rgb[keep] + if xyz.shape[0] == 0: + return None + + vol, colvol, colnorm, origin, voxel = _splat_density(xyz, opacity, scale, quat, rgb, res, kernel, device, + color_sharpen=color_sharpen, + progress=lambda f: rep(0.25 * f)) # density build: 0 -> 25% + # Colour: sample on the GPU (grid_sample) when there's headroom + colour_gpu = not comfy.model_management.is_device_cpu(device) and comfy.model_management.get_free_memory(device) > 6 * vol.numel() * 4 + if colour_gpu: + colvol_cpu, colnorm_cpu = colvol.cpu(), colnorm.half().cpu() # park colours (fp16) off-GPU during meshing + colvol_np = colnorm_np = None + else: + colvol_np = colvol.cpu().numpy().astype(np.float32) # Sum(w^p * rgb) colour numerator (fp16 grid -> fp32) + colnorm_np = colnorm.cpu().numpy().astype(np.float32) # Sum(w^p) colour normaliser + del colvol, colnorm # free the colour grids before iso-surfacing + rep(0.40) + + vmin, vmax = float(vol.min()), float(vol.max()) + occ = vol[vol > vmax * 1e-3] # occupied voxels (skip the empty-space peak) + if occ.numel() == 0: + return None + # Otsu picks the inside/outside split principledly; `level_bias` nudges it (1.0 = auto). Clamp strictly + # inside the data range so a bias can't push the iso off the histogram. + level = min(max(_otsu_level(occ.cpu().numpy()) * level_bias, vmin + 1e-6 * (vmax - vmin)), + vmax - 1e-6 * (vmax - vmin)) + + # Iso-surface on the accelerator when there's headroom: ~15x faster than CPU, identical output. Chunked + # Surface Nets peaks at ~3-3.5x the density grid, so fall back to CPU for large grids / tight VRAM. + sn_dev = device + if not comfy.model_management.is_device_cpu(device) and comfy.model_management.get_free_memory(device) < 6 * vol.numel() * 4: + sn_dev = torch.device("cpu") + vol = vol.cpu() + verts, faces = _surface_nets(vol, level, voxel, origin, sn_dev) + del vol + rep(0.55) + if min_component > 0 and len(faces) > 0: + verts, faces = _clean_components(verts, faces, min_component, device) + if len(verts) == 0 or len(faces) == 0: + return None + + # Taubin smooths the blocky iso without shrinking it (unlike blurring the density, which rounds features). + verts = _taubin_smooth(verts, faces, taubin) + rep(0.7) + + # Colour each vertex from the co-splatted colour volume: trilinearly sample the numerator Sum(w^p*rgb) + # and normaliser Sum(w^p) separately, then divide. Normalising AFTER interpolation keeps zero-density + # edge voxels from pulling colours toward black, and matches the gaussians that formed the surface. + if colour_gpu: + col = _sample_vertex_colours_gpu(colvol_cpu, colnorm_cpu, verts, origin, voxel, device) + else: + coords = ((verts - origin) / voxel).T # (3, V) grid-index coords, matching volume axes + num = np.stack([map_coordinates(colvol_np[..., c], coords, order=1, mode="nearest") for c in range(3)], -1) + den = map_coordinates(colnorm_np, coords, order=1, mode="nearest") + col = num / np.clip(den, 1e-8, None)[:, None] + rep(1.0) + + # The unlit material's COLOR_0 is linear and the viewer sRGB-encodes it on output; the splat colours + # are display (sRGB) values, so convert sRGB -> linear here to land at the same brightness as the splat. + col = np.clip(col, 0, 1) + col = np.where(col <= 0.04045, col / 12.92, ((col + 0.055) / 1.055) ** 2.4).astype(np.float32) + + # Splat +Y is glTF's -Y: rotate 180 deg about X (negate Y,Z) to land upright. Proper rotation, so + # winding is kept; done after colouring (which works in the splat frame). + verts = np.ascontiguousarray(verts * np.array([1.0, -1.0, -1.0], dtype=np.float32)) + return (torch.from_numpy(verts), torch.from_numpy(faces), torch.from_numpy(col)) + + +class SplatToMesh(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="SplatToMesh", + display_name="Extract Mesh from Splat", + search_aliases=["splat to mesh", "gaussian surface nets", "splat surface", "mesh splat"], + category="3d/splat", + description="Extract a coloured mesh from a gaussian splat.", + inputs=[ + IO.Splat.Input("splat"), + IO.Int.Input("resolution", default=384, min=64, max=768, step=16, + tooltip="Density-grid resolution along the longest axis. Higher = finer surface, " + "more VRAM/time (grows with resolution^3)."), + IO.Int.Input("kernel", default=5, min=1, max=8, + tooltip="Max splat half-width in voxels. Each gaussian is rasterized over a window " + "sized to its own 3-sigma, capped here - small surfels stay cheap, large ones " + "aren't truncated. Raise if sparse splats leave gaps."), + IO.Int.Input("smooth", default=0, min=0, max=60, advanced = True, + tooltip="Taubin mesh-smoothing iterations. Smooths the surface without shrinking it " + "(volume-preserving), unlike blurring the density. 0 = raw surface."), + IO.Float.Input("level", default=0.4, min=0.0, max=2.0, step=0.01, + tooltip="Iso-surface level. Auto-picked by Otsu; this biases it (1.0 = auto, lower = " + "fatter/more-connected surface, higher = thinner/tighter)."), + IO.Int.Input("min_component", default=500, min=0, max=100000, step=50, advanced=True, + tooltip="Drop connected components smaller than this many vertices (0 = keep all). " + "Removes detached floater blobs and the inner shell of the double wall."), + IO.Float.Input("min_opacity", default=0.02, min=0.0, max=1.0, step=0.01, advanced=True, + tooltip="Ignore gaussians fainter than this before meshing."), + IO.Float.Input("color_sharpen", default=2.0, min=1.0, max=8.0, step=0.5, + tooltip="Crisp up the vertex texture: 1.0 = physically-correct blend; higher biases " + "each voxel's colour toward its dominant gaussian instead of averaging " + "neighbours (de-smears the texture). Colour only - geometry is unchanged."), + ], + outputs=[IO.Mesh.Output(display_name="mesh")], + ) + + @classmethod + def execute(cls, splat, resolution, kernel, smooth, level, min_component, min_opacity, color_sharpen) -> IO.NodeOutput: + device = comfy.model_management.get_torch_device() + b = splat.positions.shape[0] + prec = 1000 # each splat owns a 0..prec block of the bar; its callback advances within that block + pbar = comfy.utils.ProgressBar(b * prec) + + verts_l, faces_l, colors_l = [], [], [] + for i in range(b): + cb = lambda f, base=i * prec: pbar.update_absolute(base + int(min(max(f, 0.0), 1.0) * prec)) + res = _gaussian_to_mesh(splat, i, resolution, kernel, smooth, level, min_component, min_opacity, color_sharpen, device, cb) + if res is None: + logging.warning("SplatToMesh: splat %d produced no surface; emitting an empty mesh.", i) + v, f, c = torch.zeros((0, 3)), torch.zeros((0, 3), dtype=torch.int64), torch.zeros((0, 3)) + else: + v, f, c = res + verts_l.append(v) + faces_l.append(f) + colors_l.append(c) + pbar.update_absolute((i + 1) * prec) # snap to block end (covers empty / early-out splats) + # unlit: render flat (emissive-like) so SaveGLB matches the splat instead of lighting/washing it. + return IO.NodeOutput(pack_variable_mesh_batch(verts_l, faces_l, colors=colors_l, unlit=True)) + + +class GaussianExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[IO.ComfyNode]]: + return [SplatToFile3D, File3DToSplat, RenderSplat, CreateCameraInfo, TransformSplat, + GetSplatCount, MergeSplat, SplatToMesh] + + +async def comfy_entrypoint() -> GaussianExtension: + return GaussianExtension() diff --git a/comfy_extras/nodes_save_3d.py b/comfy_extras/nodes_save_3d.py index c03524246..a91549e7f 100644 --- a/comfy_extras/nodes_save_3d.py +++ b/comfy_extras/nodes_save_3d.py @@ -16,7 +16,7 @@ from comfy.cli_args import args from comfy_api.latest import ComfyExtension, IO, Types -def pack_variable_mesh_batch(vertices, faces, colors=None, uvs=None, texture=None): +def pack_variable_mesh_batch(vertices, faces, colors=None, uvs=None, texture=None, unlit=False): # Pack lists of (Nᵢ, *) vertex/face/color/uv tensors into padded batched tensors, # stashing per-item lengths as runtime attrs so consumers can recover the real slice. # colors and uvs are 1:1 with vertices, so they're padded to max_vertices and read with vertex_counts. @@ -54,7 +54,7 @@ def pack_variable_mesh_batch(vertices, faces, colors=None, uvs=None, texture=Non return Types.MESH(packed_vertices, packed_faces, uvs=packed_uvs, vertex_colors=packed_colors, texture=texture, - vertex_counts=vertex_counts, face_counts=face_counts) + vertex_counts=vertex_counts, face_counts=face_counts, unlit=unlit) def get_mesh_batch_item(mesh, index): @@ -77,7 +77,7 @@ def get_mesh_batch_item(mesh, index): def save_glb(vertices, faces, filepath, metadata=None, - uvs=None, vertex_colors=None, texture_image=None): + uvs=None, vertex_colors=None, texture_image=None, unlit=False): """ Save PyTorch tensor vertices and faces as a GLB file without external dependencies. @@ -234,6 +234,17 @@ def save_glb(vertices, faces, filepath, metadata=None, textures = [] samplers = [] materials = [] + extensions_used = [] + if unlit and texture_png_bytes is None: + # Flat, light-independent shading (KHR_materials_unlit): COLOR_0 is shown as-is, matching how a + # gaussian splat renders (emissive). Without this the viewer lights the mesh and washes the colours. + materials.append({ + "pbrMetallicRoughness": {"baseColorFactor": [1.0, 1.0, 1.0, 1.0], "metallicFactor": 0.0, "roughnessFactor": 1.0}, + "extensions": {"KHR_materials_unlit": {}}, + "doubleSided": True, + }) + extensions_used.append("KHR_materials_unlit") + primitive["material"] = 0 if texture_png_bytes is not None and "TEXCOORD_0" in primitive_attributes: buffer_views.append({ "buffer": 0, @@ -271,6 +282,8 @@ def save_glb(vertices, faces, filepath, metadata=None, gltf["textures"] = textures if materials: gltf["materials"] = materials + if extensions_used: + gltf["extensionsUsed"] = extensions_used if metadata: gltf["asset"]["extras"] = metadata @@ -376,7 +389,8 @@ class SaveGLB(IO.ComfyNode): save_glb(vertices_i, faces_i, os.path.join(full_output_folder, f), metadata, uvs=uvs_i, vertex_colors=v_colors, - texture_image=tex_img) + texture_image=tex_img, + unlit=getattr(mesh, "unlit", False)) results.append({ "filename": f, "subfolder": subfolder, diff --git a/nodes.py b/nodes.py index 528bf316f..5678bc22d 100644 --- a/nodes.py +++ b/nodes.py @@ -2455,6 +2455,7 @@ async def init_builtin_extra_nodes(): "nodes_save_3d.py", "nodes_moge.py", "nodes_mediapipe.py", + "nodes_gaussian_splat.py", ] import_failed = []