mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-06-15 12:29:33 +08:00
Rename gaussian -> splat, improve some tooltips
This commit is contained in:
parent
dd4c7d7661
commit
4a8143f063
@ -5,7 +5,7 @@ from comfy_api.internal.singleton import ProxiedSingleton
|
||||
from comfy_api.internal.async_to_sync import create_sync_class
|
||||
from ._input import ImageInput, AudioInput, MaskInput, LatentInput, VideoInput
|
||||
from ._input_impl import VideoFromFile, VideoFromComponents
|
||||
from ._util import VideoCodec, VideoContainer, VideoComponents, MESH, VOXEL, GAUSSIAN, File3D
|
||||
from ._util import VideoCodec, VideoContainer, VideoComponents, MESH, VOXEL, SPLAT, File3D
|
||||
from . import _io_public as io
|
||||
from . import _ui_public as ui
|
||||
from comfy_execution.utils import get_executing_context
|
||||
@ -143,7 +143,7 @@ class Types:
|
||||
VideoComponents = VideoComponents
|
||||
MESH = MESH
|
||||
VOXEL = VOXEL
|
||||
GAUSSIAN = GAUSSIAN
|
||||
SPLAT = SPLAT
|
||||
File3D = File3D
|
||||
|
||||
|
||||
|
||||
@ -28,7 +28,7 @@ if TYPE_CHECKING:
|
||||
from comfy_api.internal import (_ComfyNodeInternal, _NodeOutputInternal, classproperty, copy_class, first_real_override, is_class,
|
||||
prune_dict, shallow_clone_class)
|
||||
from comfy_execution.graph_utils import ExecutionBlocker
|
||||
from ._util import MESH, VOXEL, GAUSSIAN, SVG as _SVG, File3D
|
||||
from ._util import MESH, VOXEL, SPLAT, SVG as _SVG, File3D
|
||||
|
||||
|
||||
class FolderType(str, Enum):
|
||||
@ -684,9 +684,9 @@ class Voxel(ComfyTypeIO):
|
||||
class Mesh(ComfyTypeIO):
|
||||
Type = MESH
|
||||
|
||||
@comfytype(io_type="GAUSSIAN")
|
||||
class Gaussian(ComfyTypeIO):
|
||||
Type = GAUSSIAN
|
||||
@comfytype(io_type="SPLAT")
|
||||
class Splat(ComfyTypeIO):
|
||||
Type = SPLAT
|
||||
|
||||
|
||||
@comfytype(io_type="FILE_3D")
|
||||
@ -2324,7 +2324,7 @@ __all__ = [
|
||||
"LossMap",
|
||||
"Voxel",
|
||||
"Mesh",
|
||||
"Gaussian",
|
||||
"Splat",
|
||||
"File3DAny",
|
||||
"File3DGLB",
|
||||
"File3DGLTF",
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from .video_types import VideoContainer, VideoCodec, VideoComponents
|
||||
from .geometry_types import VOXEL, MESH, GAUSSIAN, File3D
|
||||
from .geometry_types import VOXEL, MESH, SPLAT, File3D
|
||||
from .image_types import SVG
|
||||
|
||||
__all__ = [
|
||||
@ -9,7 +9,7 @@ __all__ = [
|
||||
"VideoComponents",
|
||||
"VOXEL",
|
||||
"MESH",
|
||||
"GAUSSIAN",
|
||||
"SPLAT",
|
||||
"File3D",
|
||||
"SVG",
|
||||
]
|
||||
|
||||
@ -11,7 +11,7 @@ class VOXEL:
|
||||
self.data = data
|
||||
|
||||
|
||||
class GAUSSIAN:
|
||||
class SPLAT:
|
||||
"""A batch of 3D Gaussian splats in render-ready (activated, world-space) form.
|
||||
|
||||
Tensors are (B, N, ...) and zero-padded to a common N across the batch; `counts` (B,) holds the
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
# Generic utility nodes for the GAUSSIAN type (3D gaussian splats)
|
||||
# Generic utility nodes for the SPLAT type (3D gaussian splats)
|
||||
|
||||
import gzip
|
||||
import logging
|
||||
@ -17,6 +17,7 @@ import comfy.model_management
|
||||
import comfy.utils
|
||||
from comfy_api.latest import ComfyExtension, IO, Types
|
||||
from comfy_extras.nodes_save_3d import pack_variable_mesh_batch
|
||||
from server import PromptServer
|
||||
|
||||
_C0 = 0.28209479177387814 # SH band-0 constant: DC coefficient -> base RGB
|
||||
|
||||
@ -29,7 +30,7 @@ def _linear_to_srgb(c):
|
||||
return torch.where(c <= 0.0031308, c * 12.92, 1.055 * c.clamp_min(0) ** (1 / 2.4) - 0.055)
|
||||
|
||||
|
||||
def _real_len(g: Types.GAUSSIAN, i: int) -> int:
|
||||
def _real_len(g: Types.SPLAT, i: int) -> int:
|
||||
# Real splat count of batch item i (honors variable-length `counts`).
|
||||
return int(g.counts[i].item()) if g.counts is not None else g.positions.shape[1]
|
||||
|
||||
@ -52,7 +53,7 @@ def _gaussian_ply_bytes(positions, scales, rotations, opacities, sh) -> bytes:
|
||||
xyz = positions.cpu().numpy().astype(np.float32)
|
||||
n = xyz.shape[0]
|
||||
if n == 0:
|
||||
raise ValueError("GaussianToFile3D: gaussian is empty")
|
||||
raise ValueError("SplatToFile3D: gaussian is empty")
|
||||
normals = np.zeros_like(xyz)
|
||||
f = sh.cpu().numpy().astype(np.float32) # (N, K, 3)
|
||||
f_dc = f[:, 0, :] # (N, 3)
|
||||
@ -90,7 +91,7 @@ def _gaussian_ksplat_bytes(positions, scales, rotations, opacities, sh) -> bytes
|
||||
xyz = positions.cpu().numpy().astype(np.float32)
|
||||
n = xyz.shape[0]
|
||||
if n == 0:
|
||||
raise ValueError("GaussianToFile3D: gaussian is empty")
|
||||
raise ValueError("SplatToFile3D: gaussian is empty")
|
||||
scale = scales.cpu().numpy().astype(np.float32)
|
||||
rot = rotations.cpu().numpy().astype(np.float32) # wxyz, mirrors the .ply rot order
|
||||
rot = rot / np.linalg.norm(rot, axis=1, keepdims=True).clip(1e-12)
|
||||
@ -145,7 +146,7 @@ def _gaussian_spz_bytes(positions, scales, rotations, opacities, sh) -> bytes:
|
||||
xyz = positions.cpu().numpy().astype(np.float32)
|
||||
n = xyz.shape[0]
|
||||
if n == 0:
|
||||
raise ValueError("GaussianToFile3D: gaussian is empty")
|
||||
raise ValueError("SplatToFile3D: gaussian is empty")
|
||||
|
||||
# Positions: fixed point, masked to 24 bits, little-endian 3-byte words.
|
||||
fixed = 1 << _SPZ_FRACTIONAL_BITS
|
||||
@ -202,7 +203,7 @@ def _norm_quat(q):
|
||||
def _parse_ply_gaussian(data: bytes):
|
||||
end = data.find(b'end_header')
|
||||
if end < 0:
|
||||
raise ValueError("File3DToGaussian: not a PLY (missing end_header)")
|
||||
raise ValueError("File3DToSplat: not a PLY (missing end_header)")
|
||||
header = data[:end].decode('ascii', 'replace')
|
||||
body = end + len(b'end_header')
|
||||
body += 2 if data[body:body + 2] == b'\r\n' else 1
|
||||
@ -212,14 +213,14 @@ def _parse_ply_gaussian(data: bytes):
|
||||
if not p:
|
||||
continue
|
||||
if p[0] == 'format' and p[1] != 'binary_little_endian':
|
||||
raise ValueError(f"File3DToGaussian: unsupported PLY format '{p[1]}' (need binary_little_endian)")
|
||||
raise ValueError(f"File3DToSplat: unsupported PLY format '{p[1]}' (need binary_little_endian)")
|
||||
if p[0] == 'element':
|
||||
in_vertex = p[1] == 'vertex'
|
||||
if in_vertex:
|
||||
count = int(p[2])
|
||||
elif p[0] == 'property' and in_vertex:
|
||||
if p[1] == 'list':
|
||||
raise ValueError("File3DToGaussian: PLY vertex has list properties (unsupported)")
|
||||
raise ValueError("File3DToSplat: PLY vertex has list properties (unsupported)")
|
||||
props.append((p[2], '<' + _PLY_DTYPES[p[1]]))
|
||||
arr = np.frombuffer(data, np.dtype(props), count=count, offset=body)
|
||||
names = arr.dtype.names
|
||||
@ -257,7 +258,7 @@ def _parse_ply_gaussian(data: bytes):
|
||||
def _parse_splat_gaussian(data: bytes):
|
||||
# antimatter15 .splat: 32-byte records (f32 xyz, f32 scale, u8 rgba, u8 quat as (b-128)/128 wxyz).
|
||||
if len(data) % 32 != 0:
|
||||
raise ValueError("File3DToGaussian: .splat size is not a multiple of 32 bytes")
|
||||
raise ValueError("File3DToSplat: .splat size is not a multiple of 32 bytes")
|
||||
rec = np.frombuffer(data, np.dtype([('xyz', '<f4', 3), ('scale', '<f4', 3),
|
||||
('rgba', 'u1', 4), ('quat', 'u1', 4)]))
|
||||
rgba = rec['rgba'].astype(np.float32) / 255.0
|
||||
@ -270,11 +271,11 @@ def _parse_ksplat_gaussian(data: bytes):
|
||||
# mkkellogg SplatBuffer: 4096-byte header, N section headers, then per-section splat data. Supports
|
||||
# levels 0 (float) / 1 (half + bucketed positions) / 2 (half, uint8 SH). SH is skipped (base color kept).
|
||||
if data[0] != 0:
|
||||
raise ValueError(f"File3DToGaussian: unsupported .ksplat version {data[0]}.{data[1]}")
|
||||
raise ValueError(f"File3DToSplat: unsupported .ksplat version {data[0]}.{data[1]}")
|
||||
max_sections = struct.unpack_from('<I', data, 4)[0]
|
||||
level = struct.unpack_from('<H', data, 20)[0]
|
||||
if level not in _KSPLAT_COMPRESSION:
|
||||
raise ValueError(f"File3DToGaussian: invalid .ksplat compression level {level}")
|
||||
raise ValueError(f"File3DToSplat: invalid .ksplat compression level {level}")
|
||||
bc, bs, br, bcol, bshc, default_range = _KSPLAT_COMPRESSION[level]
|
||||
|
||||
parts = []
|
||||
@ -321,7 +322,7 @@ def _parse_ksplat_gaussian(data: bytes):
|
||||
base += bytes_per_splat * sec_max + buckets_store
|
||||
|
||||
if not parts:
|
||||
raise ValueError("File3DToGaussian: .ksplat has no splats")
|
||||
raise ValueError("File3DToSplat: .ksplat has no splats")
|
||||
return tuple(np.concatenate([p[i] for p in parts]) for i in range(5))
|
||||
|
||||
|
||||
@ -329,7 +330,7 @@ def _parse_spz_gaussian(data: bytes):
|
||||
# Niantic .spz (gzip-wrapped), versions 1-3. Base color only (SH skipped). See spark's SpzReader.
|
||||
raw = gzip.decompress(data)
|
||||
if struct.unpack_from('<I', raw, 0)[0] != _SPZ_MAGIC:
|
||||
raise ValueError("File3DToGaussian: invalid .spz (bad magic)")
|
||||
raise ValueError("File3DToSplat: invalid .spz (bad magic)")
|
||||
version = struct.unpack_from('<I', raw, 4)[0]
|
||||
n = struct.unpack_from('<I', raw, 8)[0]
|
||||
frac_bits = raw[13]
|
||||
@ -345,7 +346,7 @@ def _parse_spz_gaussian(data: bytes):
|
||||
xyz = (v / (1 << frac_bits)).astype(np.float32)
|
||||
off += n * 9
|
||||
else:
|
||||
raise ValueError(f"File3DToGaussian: unsupported .spz version {version}")
|
||||
raise ValueError(f"File3DToSplat: unsupported .spz version {version}")
|
||||
|
||||
alpha = np.frombuffer(raw, np.uint8, count=n, offset=off).astype(np.float32) / 255.0
|
||||
off += n
|
||||
@ -394,10 +395,10 @@ def _detect_splat_format(data: bytes) -> str:
|
||||
return "ksplat"
|
||||
if len(data) % 32 == 0:
|
||||
return "splat"
|
||||
raise ValueError("File3DToGaussian: could not determine splat format from contents")
|
||||
raise ValueError("File3DToSplat: could not determine splat format from contents")
|
||||
|
||||
|
||||
def _gaussian_item(g: Types.GAUSSIAN, i: int, device):
|
||||
def _gaussian_item(g: Types.SPLAT, i: int, device):
|
||||
# Slice batch item i to its real length, as float32 torch tensors on `device` (SH DC -> base RGB).
|
||||
end = _real_len(g, i)
|
||||
to = lambda a: a.to(device=device, dtype=torch.float32)
|
||||
@ -461,49 +462,48 @@ def _mat_to_quat(m):
|
||||
return q / q.norm(dim=-1, keepdim=True).clamp_min(1e-12)
|
||||
|
||||
|
||||
class GaussianToFile3D(IO.ComfyNode):
|
||||
class SplatToFile3D(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="GaussianToFile3D",
|
||||
display_name="Create 3D File (from Gaussian)",
|
||||
node_id="SplatToFile3D",
|
||||
display_name="Create 3D File (from Splat)",
|
||||
search_aliases=["gaussian to ply", "splat to file", "export gaussian"],
|
||||
category="3d/gaussian",
|
||||
description="Serialize a gaussian splat to an in-memory File3D, for Save 3D Model / Preview 3D. "
|
||||
"ply keeps full SH (standard 3DGS); ksplat and spz are compact viewer formats (base "
|
||||
description="Serialize a gaussian splat to a File3D object for Save / Preview 3D nodes. "
|
||||
"ply keeps full spherical harmonics (standard 3DGS); ksplat and spz are compact viewer formats (base "
|
||||
"color only). Single splat only - feed one batch item at a time.",
|
||||
inputs=[
|
||||
IO.Gaussian.Input("gaussian"),
|
||||
IO.Splat.Input("splat"),
|
||||
IO.Combo.Input("format", options=["ply", "ksplat", "spz"],
|
||||
tooltip="ply: standard 3DGS with full spherical harmonics. ksplat: mkkellogg "
|
||||
"SplatBuffer (level 0, uncompressed). spz: Niantic gzip-compressed "
|
||||
"(~10x smaller). ksplat/spz keep base color only - view-dependent SH "
|
||||
"is dropped."),
|
||||
"(~10x smaller). ksplat/spz keep base color only - view-dependent spherical harmonics is dropped."),
|
||||
],
|
||||
outputs=[IO.File3DAny.Output(display_name="model_3d")],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, gaussian, format="ply") -> IO.NodeOutput:
|
||||
if gaussian.positions.shape[0] > 1:
|
||||
logging.warning("GaussianToFile3D: got a batch of %d; converting only the first splat (File3D is a "
|
||||
"single file).", gaussian.positions.shape[0])
|
||||
end = _real_len(gaussian, 0)
|
||||
def execute(cls, splat, format="ply") -> IO.NodeOutput:
|
||||
if splat.positions.shape[0] > 1:
|
||||
logging.warning("SplatToFile3D: got a batch of %d; converting only the first splat (File3D is a "
|
||||
"single file).", splat.positions.shape[0])
|
||||
end = _real_len(splat, 0)
|
||||
writer = {"ksplat": _gaussian_ksplat_bytes, "spz": _gaussian_spz_bytes}.get(format, _gaussian_ply_bytes)
|
||||
data = writer(gaussian.positions[0, :end], gaussian.scales[0, :end],
|
||||
gaussian.rotations[0, :end], gaussian.opacities[0, :end], gaussian.sh[0, :end])
|
||||
data = writer(splat.positions[0, :end], splat.scales[0, :end],
|
||||
splat.rotations[0, :end], splat.opacities[0, :end], splat.sh[0, :end])
|
||||
return IO.NodeOutput(Types.File3D(BytesIO(data), file_format=format))
|
||||
|
||||
|
||||
class File3DToGaussian(IO.ComfyNode):
|
||||
class File3DToSplat(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="File3DToGaussian",
|
||||
display_name="Get Gaussian Splat",
|
||||
node_id="File3DToSplat",
|
||||
display_name="Get Splat",
|
||||
search_aliases=["load splat", "ply to gaussian", "import gaussian", "file to splat"],
|
||||
category="3d/gaussian",
|
||||
description="Parse a splat File3D (.ply / .splat / .ksplat / .spz) into a GAUSSIAN. Inverse of "
|
||||
description="Parse a splat File3D (.ply / .splat / .ksplat / .spz) into a gaussian. Inverse of "
|
||||
"Create 3D File (from Gaussian). ply carries full spherical harmonics; the others are base "
|
||||
"color only. Format is auto-detected from the file contents.",
|
||||
inputs=[
|
||||
@ -513,7 +513,7 @@ class File3DToGaussian(IO.ComfyNode):
|
||||
tooltip="A gaussian-splat 3D file",
|
||||
),
|
||||
],
|
||||
outputs=[IO.Gaussian.Output(display_name="gaussian")],
|
||||
outputs=[IO.Splat.Output(display_name="splat")],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@ -524,14 +524,14 @@ class File3DToGaussian(IO.ComfyNode):
|
||||
xyz, scale, rot, opacity, sh = parser(data)
|
||||
|
||||
t = lambda a: torch.from_numpy(np.ascontiguousarray(a)).float()
|
||||
gaussian = Types.GAUSSIAN(
|
||||
splat = Types.SPLAT(
|
||||
t(xyz)[None], # (1, N, 3)
|
||||
t(scale)[None], # (1, N, 3) linear
|
||||
t(rot)[None], # (1, N, 4) wxyz
|
||||
t(opacity).reshape(1, -1, 1), # (1, N, 1)
|
||||
t(sh)[None], # (1, N, K, 3)
|
||||
)
|
||||
return IO.NodeOutput(gaussian)
|
||||
return IO.NodeOutput(splat)
|
||||
|
||||
|
||||
def _view_matrix_t(yaw_deg, pitch_deg, device):
|
||||
@ -572,7 +572,7 @@ def _gauss_blur(x, sigma, dev):
|
||||
|
||||
def _render_gaussian(xyz, rgb, opacity, scale, rot, width, height, fov, splat_scale, bg, sharpen=1.0,
|
||||
headlight_shading=0.0, render_style="color", camera_info=None,
|
||||
yaw=35.0, pitch=30.0, zoom=1.0):
|
||||
yaw=35.0, pitch=30.0, zoom=1.0, distance=0.0):
|
||||
# Perspective-correct anisotropic gaussian-splat rasterizer. Each splat is weighted by its 3D Gaussian's
|
||||
# peak along each pixel's ray (AAA / Hahlbohm), composited front-to-back across depth slabs. `render_style`
|
||||
# selects the image: color / clay / depth / normal. Returns (image HxWx3, coverage mask HxW) on CPU.
|
||||
@ -609,7 +609,8 @@ def _render_gaussian(xyz, rgb, opacity, scale, rot, width, height, fov, splat_sc
|
||||
fov = fov if fov > 0 else 35.0 # fov=0 -> default 35
|
||||
center = xyz.mean(0)
|
||||
extent = (xyz - center).norm(dim=-1).quantile(0.99).clamp_min(1e-4) # ignore outlier floaters
|
||||
dist = extent / (math.tan(math.radians(fov) / 2) * 0.9) / max(zoom, 1e-3)
|
||||
base = distance if distance > 0 else extent / (math.tan(math.radians(fov) / 2) * 0.9) # absolute dist, else auto-frame
|
||||
dist = base / max(zoom, 1e-3)
|
||||
W = _view_matrix_t(yaw, pitch, dev)
|
||||
cam = (xyz - center) @ W.T + torch.tensor([0.0, 0.0, dist], device=dev)
|
||||
yflip = 1.0
|
||||
@ -775,19 +776,20 @@ def _render_gaussian(xyz, rgb, opacity, scale, rot, width, height, fov, splat_sc
|
||||
return img.clamp(0, 1).cpu(), covg.clamp(0, 1).cpu()
|
||||
|
||||
|
||||
class RenderGaussian(IO.ComfyNode):
|
||||
class RenderSplat(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="RenderGaussian",
|
||||
display_name="Render Gaussian Splat",
|
||||
node_id="RenderSplat",
|
||||
display_name="Render Splat",
|
||||
search_aliases=["splat to image", "render splat", "gaussian turntable"],
|
||||
category="3d/gaussian",
|
||||
description="Render a gaussian splat to an image with an anisotropic EWA rasterizer (oriented "
|
||||
"elliptical splats, antialiased, depth-sorted front-to-back). frames>1 sweeps yaw a full "
|
||||
"360 turn, producing an image batch (turntable) you can pipe into a video node.",
|
||||
"elliptical splats, antialiased, depth-sorted front-to-back). Set frames greater than 1 "
|
||||
"to sweep the camera yaw through a full 360° rotation, producing a batch of images "
|
||||
"(a turntable) that you can feed into a video node.",
|
||||
inputs=[
|
||||
IO.Gaussian.Input("gaussian"),
|
||||
IO.Splat.Input("splat"),
|
||||
IO.Int.Input("width", default=1024, min=64, max=2048, step=8),
|
||||
IO.Int.Input("height", default=1024, min=64, max=2048, step=8),
|
||||
IO.Int.Input("frames", default=1, min=-240, max=240,
|
||||
@ -796,11 +798,13 @@ class RenderGaussian(IO.ComfyNode):
|
||||
IO.Float.Input("yaw", default=35.0, min=-360.0, max=360.0, step=1.0),
|
||||
IO.Float.Input("pitch", default=30.0, min=-89.0, max=89.0, step=1.0),
|
||||
IO.Float.Input("zoom", default=1.0, min=0.1, max=5.0, step=0.05,
|
||||
tooltip="Camera dolly: >1 zooms in, <1 out. Without camera_info, 1.0 frames the whole "
|
||||
"splat (~10% margin); with camera_info, 1.0 is exactly the supplied camera."),
|
||||
tooltip="Camera dolly: >1 zooms in, <1 out. With camera_info or distance, 1.0 is exactly "
|
||||
"that camera; otherwise 1.0 frames the whole splat (~10% margin)."),
|
||||
IO.Float.Input("distance", default=0.0, min=0.0, max=1000.0, step=0.01,
|
||||
tooltip="Absolute camera distance for the yaw/pitch orbit (from Get Camera Info). "
|
||||
"0 = auto-frame the whole splat. Ignored when camera_info is connected."),
|
||||
IO.Float.Input("fov", default=0.0, min=0.0, max=120.0, step=1.0,
|
||||
tooltip="Vertical field of view in degrees. 0 = auto: 35, or taken from camera_info "
|
||||
"when connected. Any value >0 overrides (including over camera_info)."),
|
||||
tooltip="Vertical field of view in degrees. 0 = camera_info if provided, otherwise defaults to 35. Any value above 0 overrides the camera_info FoV."),
|
||||
IO.Float.Input("splat_scale", default=1.0, min=0.1, max=5.0, step=0.05, advanced=True,
|
||||
tooltip="Multiplier on each splat's projected footprint (lower = crisper points, "
|
||||
"higher = softer/fuller surface)."),
|
||||
@ -832,8 +836,8 @@ class RenderGaussian(IO.ComfyNode):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, gaussian, width, height, yaw, pitch, frames, zoom, fov, splat_scale, sharpen,
|
||||
headlight_shading, opacity_threshold, background, render_style,
|
||||
def execute(cls, splat, width, height, yaw, pitch, frames, zoom, fov, splat_scale, sharpen,
|
||||
headlight_shading, opacity_threshold, background, render_style, distance=0.0,
|
||||
camera_info=None, bg_image=None) -> IO.NodeOutput:
|
||||
bg = _hex_to_rgb(background)
|
||||
bg_imgs = None
|
||||
@ -844,17 +848,17 @@ class RenderGaussian(IO.ComfyNode):
|
||||
orbit_dir = -1.0 if frames < 0 else 1.0 # sign = orbit direction
|
||||
if camera_info is not None:
|
||||
if n_frames > 1:
|
||||
logging.warning("RenderGaussian: camera_info is a fixed camera; ignoring frames=%d (no orbit sweep).", frames)
|
||||
logging.warning("RenderSplat: camera_info is a fixed camera; ignoring frames=%d (no orbit sweep).", frames)
|
||||
n_frames = 1
|
||||
if str(camera_info.get("cameraType", "")).lower().startswith("ortho"):
|
||||
logging.warning("RenderGaussian: orthographic camera_info is rendered with a perspective camera.")
|
||||
logging.warning("RenderSplat: orthographic camera_info is rendered with a perspective camera.")
|
||||
imgs, masks = [], []
|
||||
device = comfy.model_management.get_torch_device() # render device; splat stays in torch here -> no roundtrip
|
||||
total = gaussian.positions.shape[0] * n_frames
|
||||
total = splat.positions.shape[0] * n_frames
|
||||
pbar = comfy.utils.ProgressBar(total) if total > 1 else None
|
||||
k = 0
|
||||
for i in range(gaussian.positions.shape[0]):
|
||||
xyz, rgb, opacity, scale, rot = _gaussian_item(gaussian, i, device)
|
||||
for i in range(splat.positions.shape[0]):
|
||||
xyz, rgb, opacity, scale, rot = _gaussian_item(splat, i, device)
|
||||
if opacity_threshold > 0:
|
||||
keep = opacity >= opacity_threshold
|
||||
xyz, rgb, opacity, scale, rot = xyz[keep], rgb[keep], opacity[keep], scale[keep], rot[keep]
|
||||
@ -863,7 +867,8 @@ class RenderGaussian(IO.ComfyNode):
|
||||
bg_k = bg_imgs[k % bg_imgs.shape[0]] if bg_imgs is not None else bg # per-frame plate, or solid colour
|
||||
img, mask = _render_gaussian(xyz, rgb, opacity, scale, rot, width, height, fov, splat_scale, bg_k,
|
||||
sharpen=sharpen, headlight_shading=headlight_shading,
|
||||
render_style=render_style, camera_info=camera_info, yaw=y, pitch=pitch, zoom=zoom)
|
||||
render_style=render_style, camera_info=camera_info, yaw=y, pitch=pitch,
|
||||
zoom=zoom, distance=distance)
|
||||
imgs.append(img)
|
||||
masks.append(mask)
|
||||
k += 1
|
||||
@ -872,19 +877,18 @@ class RenderGaussian(IO.ComfyNode):
|
||||
return IO.NodeOutput(torch.stack(imgs), torch.stack(masks))
|
||||
|
||||
|
||||
class TransformGaussian(IO.ComfyNode):
|
||||
class TransformSplat(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="TransformGaussian",
|
||||
display_name="Transform Gaussian Splat",
|
||||
node_id="TransformSplat",
|
||||
display_name="Transform Splat",
|
||||
search_aliases=["move splat", "rotate splat", "scale splat", "gaussian transform"],
|
||||
category="3d/gaussian",
|
||||
description="Translate, rotate (Euler XYZ degrees) and scale (per-axis) a gaussian splat. Positions, "
|
||||
"per-splat rotations and scales transform consistently; non-uniform scale re-derives each "
|
||||
"splat's covariance (eigendecomposition) so the ellipsoids deform correctly.",
|
||||
description="Translate, rotate, and scale a gaussian splat."
|
||||
"Non-uniform scale also reshapes every individual splat, slower process.",
|
||||
inputs=[
|
||||
IO.Gaussian.Input("gaussian"),
|
||||
IO.Splat.Input("splat"),
|
||||
IO.Float.Input("translate_x", default=0.0, min=-100.0, max=100.0, step=0.01),
|
||||
IO.Float.Input("translate_y", default=0.0, min=-100.0, max=100.0, step=0.01),
|
||||
IO.Float.Input("translate_z", default=0.0, min=-100.0, max=100.0, step=0.01),
|
||||
@ -895,13 +899,13 @@ class TransformGaussian(IO.ComfyNode):
|
||||
IO.Float.Input("scale_y", default=1.0, min=0.01, max=100.0, step=0.01),
|
||||
IO.Float.Input("scale_z", default=1.0, min=0.01, max=100.0, step=0.01),
|
||||
],
|
||||
outputs=[IO.Gaussian.Output(display_name="gaussian")],
|
||||
outputs=[IO.Splat.Output(display_name="splat")],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, gaussian, translate_x, translate_y, translate_z,
|
||||
def execute(cls, splat, translate_x, translate_y, translate_z,
|
||||
rotate_x, rotate_y, rotate_z, scale_x, scale_y, scale_z) -> IO.NodeOutput:
|
||||
pos = gaussian.positions
|
||||
pos = splat.positions
|
||||
dev, dt = pos.device, pos.dtype
|
||||
q_rot = _euler_to_quat(rotate_x, rotate_y, rotate_z).to(device=dev, dtype=dt)
|
||||
R = _quat_to_mat(q_rot[None])[0] # (3, 3) node rotation
|
||||
@ -911,51 +915,45 @@ class TransformGaussian(IO.ComfyNode):
|
||||
|
||||
positions = pos @ A.T + t # rotate, scale per-axis, then translate
|
||||
if scale_x == scale_y == scale_z: # uniform: rotation/scale factor out cleanly
|
||||
scales = gaussian.scales * scale_x
|
||||
rotations = _quat_mul(q_rot.expand_as(gaussian.rotations), gaussian.rotations)
|
||||
scales = splat.scales * scale_x
|
||||
rotations = _quat_mul(q_rot.expand_as(splat.rotations), splat.rotations)
|
||||
rotations = rotations / rotations.norm(dim=-1, keepdim=True).clamp_min(1e-12)
|
||||
else: # non-uniform: transform Sigma = A R s^2 R^T A^T, re-extract
|
||||
rg = _quat_to_mat(gaussian.rotations.reshape(-1, 4)) # (M,3,3) per-splat rotation
|
||||
s2 = gaussian.scales.reshape(-1, 3).square()
|
||||
cov = (rg * s2[:, None, :]) @ rg.transpose(-1, -2) # Sigma
|
||||
cov = A @ cov @ A.T # A Sigma A^T (A broadcast over splats)
|
||||
rg = _quat_to_mat(splat.rotations.reshape(-1, 4)) # (M,3,3) per-splat rotation
|
||||
s2 = splat.scales.reshape(-1, 3).square()
|
||||
cov = (rg * s2[:, None, :]) @ rg.transpose(-1, -2) # Sigma
|
||||
cov = A @ cov @ A.T # A Sigma A^T (A broadcast over splats)
|
||||
lam, V = torch.linalg.eigh(cov) # symmetric -> eigenvalues (asc), orthonormal axes
|
||||
V = V * torch.where(torch.linalg.det(V) < 0, -1.0, 1.0)[..., None, None] # keep a proper rotation
|
||||
scales = lam.clamp_min(0).sqrt().reshape(gaussian.scales.shape)
|
||||
rotations = _mat_to_quat(V).reshape(gaussian.rotations.shape)
|
||||
out = Types.GAUSSIAN(positions, scales, rotations, gaussian.opacities, gaussian.sh,
|
||||
counts=getattr(gaussian, "counts", None))
|
||||
scales = lam.clamp_min(0).sqrt().reshape(splat.scales.shape)
|
||||
rotations = _mat_to_quat(V).reshape(splat.rotations.shape)
|
||||
out = Types.SPLAT(positions, scales, rotations, splat.opacities, splat.sh,
|
||||
counts=getattr(splat, "counts", None))
|
||||
return IO.NodeOutput(out)
|
||||
|
||||
|
||||
class GaussianInfo(IO.ComfyNode):
|
||||
class GetSplatCount(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="GaussianInfo",
|
||||
display_name="Gaussian Splat Info",
|
||||
search_aliases=["splat stats", "gaussian count", "splat info"],
|
||||
node_id="GetSplatCount",
|
||||
display_name="Get Splat Count",
|
||||
search_aliases=["splat count", "gaussian count", "number of splats", "splat info"],
|
||||
category="3d/gaussian",
|
||||
description="Report per-splat stats: count, bounding box, and opacity/scale ranges.",
|
||||
inputs=[IO.Gaussian.Input("gaussian")],
|
||||
outputs=[IO.String.Output(display_name="info")],
|
||||
description="Returns the number of splats (summed across the batch) and shows it on the node.",
|
||||
inputs=[IO.Splat.Input("splat")],
|
||||
outputs=[IO.Splat.Output(display_name="splat"),
|
||||
IO.Int.Output(display_name="count"),
|
||||
],
|
||||
hidden=[IO.Hidden.unique_id],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, gaussian) -> IO.NodeOutput:
|
||||
lines = []
|
||||
for i in range(gaussian.positions.shape[0]):
|
||||
xyz, _, opacity, scale, _ = _gaussian_item(gaussian, i, torch.device("cpu"))
|
||||
lo, hi = xyz.amin(0), xyz.amax(0)
|
||||
fmt = lambda v: "[" + ", ".join(f"{x:.3f}" for x in v) + "]"
|
||||
lines.append(
|
||||
f"gaussian[{i}]: count={xyz.shape[0]}\n"
|
||||
f" aabb min={fmt(lo)} max={fmt(hi)} size={fmt(hi - lo)}\n"
|
||||
f" opacity mean={opacity.mean():.3f} min={opacity.min():.3f} max={opacity.max():.3f}\n"
|
||||
f" scale mean={scale.mean():.4f} min={scale.min():.4f} max={scale.max():.4f}"
|
||||
)
|
||||
info = "\n".join(lines)
|
||||
return IO.NodeOutput(info, ui={"text": [info]})
|
||||
def execute(cls, splat) -> IO.NodeOutput:
|
||||
count = sum(_real_len(splat, i) for i in range(splat.positions.shape[0]))
|
||||
if cls.hidden.unique_id: # show the count inline on the node
|
||||
PromptServer.instance.send_progress_text(f"{count:,} splats", cls.hidden.unique_id)
|
||||
return IO.NodeOutput(splat, count)
|
||||
|
||||
|
||||
def _pad_stack(items, n):
|
||||
@ -967,15 +965,15 @@ def _pad_stack(items, n):
|
||||
return out
|
||||
|
||||
|
||||
def _merge_gaussians(gaussians: list) -> Types.GAUSSIAN:
|
||||
# Concatenate GAUSSIAN batches along the splat dimension (per item), padding SH to the highest degree.
|
||||
def _merge_gaussians(gaussians: list) -> Types.SPLAT:
|
||||
# Concatenate SPLAT batches along the splat dimension (per item), padding SH to the highest degree.
|
||||
gs = [g for g in gaussians if g is not None]
|
||||
if not gs:
|
||||
raise ValueError("MergeGaussian: no gaussians to merge")
|
||||
raise ValueError("MergeSplat: no gaussians to merge")
|
||||
b = gs[0].positions.shape[0]
|
||||
for g in gs:
|
||||
if g.positions.shape[0] != b:
|
||||
raise ValueError(f"MergeGaussian: batch size mismatch ({b} vs {g.positions.shape[0]}).")
|
||||
raise ValueError(f"MergeSplat: batch size mismatch ({b} vs {g.positions.shape[0]}).")
|
||||
max_k = max(g.sh.shape[2] for g in gs)
|
||||
|
||||
pos_b, scl_b, rot_b, op_b, sh_b, lengths = [], [], [], [], [], []
|
||||
@ -1002,32 +1000,31 @@ def _merge_gaussians(gaussians: list) -> Types.GAUSSIAN:
|
||||
counts = None
|
||||
if len(set(lengths)) > 1:
|
||||
counts = torch.tensor(lengths, device=gs[0].positions.device, dtype=torch.int64)
|
||||
return Types.GAUSSIAN(_pad_stack(pos_b, n), _pad_stack(scl_b, n), _pad_stack(rot_b, n),
|
||||
return Types.SPLAT(_pad_stack(pos_b, n), _pad_stack(scl_b, n), _pad_stack(rot_b, n),
|
||||
_pad_stack(op_b, n), _pad_stack(sh_b, n), counts=counts)
|
||||
|
||||
|
||||
class MergeGaussian(IO.ComfyNode):
|
||||
class MergeSplat(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
# Autogrow: a gaussian0/gaussian1/... input list that grows a fresh slot as you connect splats.
|
||||
gaussians = IO.Autogrow.TemplatePrefix(IO.Gaussian.Input("gaussian"), prefix="gaussian", min=2, max=32)
|
||||
# Autogrow: a splat0/splat1/... input list that grows a fresh slot as you connect splats.
|
||||
splats = IO.Autogrow.TemplatePrefix(IO.Splat.Input("splat"), prefix="splat", min=2, max=32)
|
||||
return IO.Schema(
|
||||
node_id="MergeGaussian",
|
||||
display_name="Merge Gaussian Splats",
|
||||
node_id="MergeSplat",
|
||||
display_name="Merge Splats",
|
||||
search_aliases=["union splat", "densify gaussian", "combine splat", "merge gaussian"],
|
||||
category="3d/gaussian",
|
||||
description="Concatenate any number of gaussian splats into one (per batch item). Because the "
|
||||
"TripoSplat decoder samples points stochastically, unioning several decodes of the same "
|
||||
"latent at different seeds densifies the surface - feed them here, then mesh the result.",
|
||||
inputs=[IO.Autogrow.Input("gaussians", template=gaussians)],
|
||||
outputs=[IO.Gaussian.Output(display_name="gaussian")],
|
||||
description="Concatenate any number of gaussian splats into one. Unioning several decodes of the same "
|
||||
"latent at different seeds densifies the surface, this can improve surface quality when meshing.",
|
||||
inputs=[IO.Autogrow.Input("splats", template=splats)],
|
||||
outputs=[IO.Splat.Output(display_name="splat")],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, gaussians: IO.Autogrow.Type) -> IO.NodeOutput:
|
||||
gs = [v for v in gaussians.values() if v is not None]
|
||||
def execute(cls, splats: IO.Autogrow.Type) -> IO.NodeOutput:
|
||||
gs = [v for v in splats.values() if v is not None]
|
||||
if not gs:
|
||||
raise ValueError("MergeGaussian: connect at least one gaussian splat.")
|
||||
raise ValueError("MergeSplat: connect at least one splat.")
|
||||
return IO.NodeOutput(_merge_gaussians(gs))
|
||||
|
||||
|
||||
@ -1232,7 +1229,7 @@ def _taubin_smooth(verts, faces, iters, lam=0.5, mu=-0.53):
|
||||
return np.ascontiguousarray(v.astype(np.float32))
|
||||
|
||||
|
||||
def _gaussian_to_mesh(g: Types.GAUSSIAN, i, res, kernel, taubin, level_bias, min_component, min_opacity, color_sharpen, device, progress=None):
|
||||
def _gaussian_to_mesh(g: Types.SPLAT, i, res, kernel, taubin, level_bias, min_component, min_opacity, color_sharpen, device, progress=None):
|
||||
# Mesh one splat: density + colour grids -> Surface Nets -> floater removal -> Taubin smoothing ->
|
||||
# volume-sampled colours. Returns (verts, faces int64, colors in [0,1]), or None if no surface.
|
||||
rep = progress if progress is not None else (lambda *_: None)
|
||||
@ -1297,12 +1294,12 @@ def _gaussian_to_mesh(g: Types.GAUSSIAN, i, res, kernel, taubin, level_bias, min
|
||||
return (torch.from_numpy(verts), torch.from_numpy(faces), torch.from_numpy(col))
|
||||
|
||||
|
||||
class GaussianToMesh(IO.ComfyNode):
|
||||
class SplatToMesh(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="GaussianToMesh",
|
||||
display_name="Gaussian Splat to Mesh",
|
||||
node_id="SplatToMesh",
|
||||
display_name="Extract Mesh from Splat",
|
||||
search_aliases=["splat to mesh", "gaussian surface nets", "splat surface", "mesh splat"],
|
||||
category="3d/gaussian",
|
||||
description="Extract a coloured triangle MESH from a gaussian splat. Each splat is rasterized into a "
|
||||
@ -1310,7 +1307,7 @@ class GaussianToMesh(IO.ComfyNode):
|
||||
"tiny floaters are dropped, and vertices are coloured from their nearest gaussians. Denser "
|
||||
"splats give more detail - union several decodes with Merge Gaussian Splats first.",
|
||||
inputs=[
|
||||
IO.Gaussian.Input("gaussian"),
|
||||
IO.Splat.Input("splat"),
|
||||
IO.Int.Input("resolution", default=512, min=64, max=1024, step=16,
|
||||
tooltip="Density-grid resolution along the longest axis. Higher = finer surface, "
|
||||
"more VRAM/time (grows with resolution^3)."),
|
||||
@ -1318,7 +1315,7 @@ class GaussianToMesh(IO.ComfyNode):
|
||||
tooltip="Max splat half-width in voxels. Each gaussian is rasterized over a window "
|
||||
"sized to its own 3-sigma, capped here - small surfels stay cheap, large ones "
|
||||
"aren't truncated. Raise if sparse splats leave gaps."),
|
||||
IO.Int.Input("smooth", default=0, min=0, max=60,
|
||||
IO.Int.Input("smooth", default=0, min=0, max=60, advanced = True,
|
||||
tooltip="Taubin mesh-smoothing iterations. Smooths the surface without shrinking it "
|
||||
"(volume-preserving), unlike blurring the density. 0 = raw surface."),
|
||||
IO.Float.Input("level", default=0.6, min=0.3, max=2.0, step=0.05,
|
||||
@ -1338,18 +1335,18 @@ class GaussianToMesh(IO.ComfyNode):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, gaussian, resolution, kernel, smooth, level, min_component, min_opacity, color_sharpen) -> IO.NodeOutput:
|
||||
def execute(cls, splat, resolution, kernel, smooth, level, min_component, min_opacity, color_sharpen) -> IO.NodeOutput:
|
||||
device = comfy.model_management.get_torch_device()
|
||||
b = gaussian.positions.shape[0]
|
||||
b = splat.positions.shape[0]
|
||||
prec = 1000 # each splat owns a 0..prec block of the bar; its callback advances within that block
|
||||
pbar = comfy.utils.ProgressBar(b * prec)
|
||||
|
||||
verts_l, faces_l, colors_l = [], [], []
|
||||
for i in range(b):
|
||||
cb = lambda f, base=i * prec: pbar.update_absolute(base + int(min(max(f, 0.0), 1.0) * prec))
|
||||
res = _gaussian_to_mesh(gaussian, i, resolution, kernel, smooth, level, min_component, min_opacity, color_sharpen, device, cb)
|
||||
res = _gaussian_to_mesh(splat, i, resolution, kernel, smooth, level, min_component, min_opacity, color_sharpen, device, cb)
|
||||
if res is None:
|
||||
logging.warning("GaussianToMesh: splat %d produced no surface; emitting an empty mesh.", i)
|
||||
logging.warning("SplatToMesh: splat %d produced no surface; emitting an empty mesh.", i)
|
||||
v, f, c = torch.zeros((0, 3)), torch.zeros((0, 3), dtype=torch.int64), torch.zeros((0, 3))
|
||||
else:
|
||||
v, f, c = res
|
||||
@ -1364,8 +1361,8 @@ class GaussianToMesh(IO.ComfyNode):
|
||||
class GaussianExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||
return [GaussianToFile3D, File3DToGaussian, RenderGaussian, TransformGaussian, GaussianInfo,
|
||||
MergeGaussian, GaussianToMesh]
|
||||
return [SplatToFile3D, File3DToSplat, RenderSplat, TransformSplat, GetSplatCount,
|
||||
MergeSplat, SplatToMesh]
|
||||
|
||||
|
||||
async def comfy_entrypoint() -> GaussianExtension:
|
||||
Loading…
Reference in New Issue
Block a user