diff --git a/comfy_extras/nodes_gaussian_splat.py b/comfy_extras/nodes_gaussian_splat.py index 251ae65ec..a44ec5e86 100644 --- a/comfy_extras/nodes_gaussian_splat.py +++ b/comfy_extras/nodes_gaussian_splat.py @@ -568,6 +568,77 @@ def _camera_basis(camera_info, dev): return eye, target, W[0], W[1], W[2] +def _lookat_quat_wxyz(position, target, dev): + # three.js lookAt in world frame: camera local +Z = (eye - target), up = world +Y. Returns wxyz. + z = position - target + z = z / z.norm().clamp_min(1e-8) + up0 = torch.tensor([0.0, 1.0, 0.0], device=dev) + if z.dot(up0).abs() > 0.999: # looking straight up/down + up0 = torch.tensor([0.0, 0.0, 1.0], device=dev) + x = torch.linalg.cross(up0, z) + x = x / x.norm().clamp_min(1e-8) + y = torch.linalg.cross(z, x) + R = torch.stack([x, y, z], dim=1) # columns = camera world axes + return _mat_to_quat(R[None])[0] + + +def _lookat_camera_info(position, target, fov, dev, zoom=1.0, camera_type="perspective", roll=0.0): + # Build a camera_info from a world-space (right-handed, Y-up) eye + look-at target; up = world +Y. + pos = torch.as_tensor(position, dtype=torch.float32, device=dev) + tgt = torch.as_tensor(target, dtype=torch.float32, device=dev) + q = _lookat_quat_wxyz(pos, tgt, dev) + if roll: # roll about the view axis (camera local Z) + a = math.radians(roll) + qz = torch.tensor([math.cos(a / 2), 0.0, 0.0, math.sin(a / 2)], device=dev) + q = _quat_mul(q[None], qz[None])[0] + xyz = lambda v: {"x": float(v[0]), "y": float(v[1]), "z": float(v[2])} + return {"position": xyz(pos), "target": xyz(tgt), + "quaternion": {"x": float(q[1]), "y": float(q[2]), "z": float(q[3]), "w": float(q[0])}, + "fov": float(fov), "cameraType": str(camera_type), "zoom": float(zoom)} + + +def _quat_camera_info(position, quat_xyzw, fov, dev, zoom=1.0, camera_type="perspective"): + # camera_info from an explicit world position + camera-rotation quaternion (three.js: looks down local -Z). + pos = torch.as_tensor(position, dtype=torch.float32, device=dev) + qx, qy, qz, qw = (float(c) for c in quat_xyzw) + qwxyz = torch.tensor([qw, qx, qy, qz], dtype=torch.float32, device=dev) + qwxyz = qwxyz / qwxyz.norm().clamp_min(1e-8) + R = _quat_to_mat(qwxyz[None])[0] + tgt = pos - R[:, 2] # look one unit down local -Z + xyz = lambda v: {"x": float(v[0]), "y": float(v[1]), "z": float(v[2])} + return {"position": xyz(pos), "target": xyz(tgt), + "quaternion": {"x": float(qwxyz[1]), "y": float(qwxyz[2]), "z": float(qwxyz[3]), "w": float(qwxyz[0])}, + "fov": float(fov), "cameraType": str(camera_type), "zoom": float(zoom)} + + +def _orbit_camera_info(yaw, pitch, distance, fov, pivot_splat, dev): + # Orbit helper for RenderSplat's default camera: yaw/pitch about `pivot_splat` (splat frame) at `distance`. + # World<->splat is the (x,-y,-z) map, so _camera_basis recovers exactly _view_matrix_t(yaw, pitch). + y, p = math.radians(yaw), math.radians(pitch) + cy, sy, cp, sp = math.cos(y), math.sin(y), math.cos(p), math.sin(p) + fwd_splat = torch.tensor([-cp * sy, sp, cp * cy], device=dev) # == _view_matrix_t(yaw, pitch)[2] + m = lambda v: torch.stack([v[0], -v[1], -v[2]]) # splat<->world (its own inverse) + return _lookat_camera_info(m(pivot_splat - distance * fwd_splat), m(pivot_splat), fov, dev) + + +def _orbit_camera_info_yaw(camera_info, angle_deg, dev): + # Turntable: rigidly rotate a camera_info about world +Y around its target by angle_deg. Returns a new dict. + a = math.radians(angle_deg) + ca, sa = math.cos(a), math.sin(a) + v = lambda d: torch.tensor([float(d.get("x", 0.0)), float(d.get("y", 0.0)), float(d.get("z", 0.0))], device=dev) + pos, tgt = v(camera_info.get("position", {})), v(camera_info.get("target", {})) + Ry = torch.tensor([[ca, 0.0, sa], [0.0, 1.0, 0.0], [-sa, 0.0, ca]], device=dev) + new_pos = tgt + Ry @ (pos - tgt) + q = camera_info.get("quaternion") or {} + qcur = torch.tensor([float(q.get("w", 1.0)), float(q.get("x", 0.0)), + float(q.get("y", 0.0)), float(q.get("z", 0.0))], device=dev) + qy = torch.tensor([math.cos(a / 2), 0.0, math.sin(a / 2), 0.0], device=dev) # world +Y rotation + qn = _quat_mul(qy[None], qcur[None])[0] + xyz = lambda t: {"x": float(t[0]), "y": float(t[1]), "z": float(t[2])} + return {**camera_info, "position": xyz(new_pos), + "quaternion": {"x": float(qn[1]), "y": float(qn[2]), "z": float(qn[3]), "w": float(qn[0])}} + + def _gauss_blur(x, sigma, dev): # Separable Gaussian blur of (1, C, H, W). Used to denoise the screen-space normal map. r = max(1, int(round(3 * sigma))) @@ -579,9 +650,8 @@ def _gauss_blur(x, sigma, dev): return x -def _render_gaussian(xyz, rgb, opacity, scale, rot, width, height, fov, splat_scale, bg, sharpen=1.0, - headlight_shading=0.0, render_style="color", camera_info=None, - yaw=35.0, pitch=30.0, zoom=1.0, distance=0.0): +def _render_gaussian(xyz, rgb, opacity, scale, rot, width, height, splat_scale, bg, camera_info, + sharpen=1.0, headlight_shading=0.0, render_style="color"): # Perspective-correct anisotropic gaussian-splat rasterizer. Each splat is weighted by its 3D Gaussian's # peak along each pixel's ray (AAA / Hahlbohm), composited front-to-back across depth slabs. `render_style` # selects the image: color / clay / depth / normal. Returns (image HxWx3, coverage mask HxW) on CPU. @@ -605,24 +675,12 @@ def _render_gaussian(xyz, rgb, opacity, scale, rot, width, height, fov, splat_sc if xyz.shape[0] == 0: # empty input (e.g. all culled by opacity_threshold) return background_only() - if camera_info is not None: - eye, target, right, up, fwd = _camera_basis(camera_info, dev) - d = (target - eye).norm().clamp_min(1e-6) - eye = target - fwd * (d / max(zoom, 1e-3)) # zoom is relative to camera_info: 1.0 = as authored - W = torch.stack([right, up, fwd], 0) # rows = camera axes (world -> camera) - cam = (xyz - eye) @ W.T - cam_fov = float(camera_info.get("fov", 0) or 0) # fov=0 -> take it from the camera (or 35 if absent) - fov = fov if fov > 0 else (cam_fov if cam_fov > 0 else 35.0) - yflip = 1.0 # match the orbit path (splat +Y is visually down) - else: - fov = fov if fov > 0 else 35.0 # fov=0 -> default 35 - center = xyz.mean(0) - extent = (xyz - center).norm(dim=-1).quantile(0.99).clamp_min(1e-4) # ignore outlier floaters - base = distance if distance > 0 else extent / (math.tan(math.radians(fov) / 2) * 0.9) # absolute dist, else auto-frame - dist = base / max(zoom, 1e-3) - W = _view_matrix_t(yaw, pitch, dev) - cam = (xyz - center) @ W.T + torch.tensor([0.0, 0.0, dist], device=dev) - yflip = 1.0 + eye, _, right, up, fwd = _camera_basis(camera_info, dev) # all camera state comes from camera_info + W = torch.stack([right, up, fwd], 0) # rows = camera axes (world -> camera) + cam = (xyz - eye) @ W.T + fov = float(camera_info.get("fov", 0) or 0) or 35.0 + zoom = float(camera_info.get("zoom", 1.0) or 1.0) # three.js digital zoom: scales the focal length + yflip = 1.0 # splat +Y is image-down xc, yc, zc = cam.unbind(-1) keep = zc > 1e-2 @@ -632,7 +690,7 @@ def _render_gaussian(xyz, rgb, opacity, scale, rot, width, height, fov, splat_sc if render_style == "clay": rgb = torch.full_like(rgb, 0.75) # neutral albedo -> shading shows pure geometry - f = (min(width, height) / 2) / math.tan(math.radians(fov) / 2) # fov over the smaller axis -> object fits + f = (min(width, height) / 2) / math.tan(math.radians(fov) / 2) * zoom # fov over the smaller axis, x camera zoom invz = 1.0 / zc cx0, cy0 = width / 2, height / 2 @@ -794,26 +852,16 @@ class RenderSplat(IO.ComfyNode): search_aliases=["splat to image", "render splat", "gaussian turntable"], category="3d/gaussian", description="Render a gaussian splat to an image with an anisotropic EWA rasterizer (oriented " - "elliptical splats, antialiased, depth-sorted front-to-back). Set frames greater than 1 " - "to sweep the camera yaw through a full 360° rotation, producing a batch of images " - "(a turntable) that you can feed into a video node.", + "elliptical splats, antialiased, depth-sorted front-to-back). The camera comes from a " + "camera_info input (Load3D / Preview3D, or a Create Camera Info node); leave it empty to " + "auto-frame the splat. Set frames greater than 1 for a turntable batch to feed a video node.", inputs=[ IO.Splat.Input("splat"), IO.Int.Input("width", default=1024, min=64, max=2048, step=8), IO.Int.Input("height", default=1024, min=64, max=2048, step=8), IO.Int.Input("frames", default=1, min=-240, max=240, - tooltip="+/-1 = single still at the given yaw; magnitude >1 = orbit, yaw swept over a " - "full turn. Negative orbits the opposite direction. 0 = single still."), - IO.Float.Input("yaw", default=35.0, min=-360.0, max=360.0, step=1.0), - IO.Float.Input("pitch", default=30.0, min=-89.0, max=89.0, step=1.0), - IO.Float.Input("zoom", default=1.0, min=0.1, max=5.0, step=0.05, - tooltip="Camera dolly: >1 zooms in, <1 out. With camera_info or distance, 1.0 is exactly " - "that camera; otherwise 1.0 frames the whole splat (~10% margin)."), - IO.Float.Input("distance", default=0.0, min=0.0, max=1000.0, step=0.01, - tooltip="Absolute camera distance for the yaw/pitch orbit (from Get Camera Info). " - "0 = auto-frame the whole splat. Ignored when camera_info is connected."), - IO.Float.Input("fov", default=0.0, min=0.0, max=120.0, step=1.0, - tooltip="Vertical field of view in degrees. 0 = camera_info if provided, otherwise defaults to 35. Any value above 0 overrides the camera_info FoV."), + tooltip="+/-1 = single still; magnitude >1 = turntable, the camera orbited over a full " + "360 turn (works with any camera_info). Negative orbits the other way. 0 = single still."), IO.Float.Input("splat_scale", default=1.0, min=0.1, max=5.0, step=0.05, advanced=True, tooltip="Multiplier on each splat's projected footprint (lower = crisper points, " "higher = softer/fuller surface)."), @@ -837,17 +885,16 @@ class RenderSplat(IO.ComfyNode): "background colour). Resized to the render size; a batch is used per frame, " "a single image for all. color/clay only."), IO.Load3DCamera.Input("camera_info", optional=True, - tooltip="Render from this exact camera (e.g. from Load3D / Preview3D) " - "instead of the yaw/pitch orbit. Disables the turntable sweep."), + tooltip="Camera to render from - a Load3D / Preview3D camera or a Create Camera " + "Info node. If empty, the splat is auto-framed from a default 3/4 view."), ], outputs=[IO.Image.Output(display_name="image"), IO.Mask.Output(display_name="mask")], ) @classmethod - def execute(cls, splat, width, height, yaw, pitch, frames, zoom, fov, splat_scale, sharpen, - headlight_shading, opacity_threshold, background, render_style, distance=0.0, - camera_info=None, bg_image=None) -> IO.NodeOutput: + def execute(cls, splat, width, height, frames, splat_scale, sharpen, headlight_shading, + opacity_threshold, background, render_style, camera_info=None, bg_image=None) -> IO.NodeOutput: bg = _hex_to_rgb(background) bg_imgs = None if bg_image is not None: # resize the plate(s) to the render size: (B,H,W,3) @@ -855,12 +902,8 @@ class RenderSplat(IO.ComfyNode): bg_imgs = bi.movedim(1, -1).clamp(0, 1) n_frames = abs(int(frames)) or 1 # magnitude = frame count (0 -> single still) orbit_dir = -1.0 if frames < 0 else 1.0 # sign = orbit direction - if camera_info is not None: - if n_frames > 1: - logging.warning("RenderSplat: camera_info is a fixed camera; ignoring frames=%d (no orbit sweep).", frames) - n_frames = 1 - if str(camera_info.get("cameraType", "")).lower().startswith("ortho"): - logging.warning("RenderSplat: orthographic camera_info is rendered with a perspective camera.") + if camera_info is not None and str(camera_info.get("cameraType", "")).lower().startswith("ortho"): + logging.warning("RenderSplat: orthographic camera_info is rendered with a perspective camera.") imgs, masks = [], [] device = comfy.model_management.get_torch_device() # render device; splat stays in torch here -> no roundtrip total = splat.positions.shape[0] * n_frames @@ -871,13 +914,20 @@ class RenderSplat(IO.ComfyNode): if opacity_threshold > 0: keep = opacity >= opacity_threshold xyz, rgb, opacity, scale, rot = xyz[keep], rgb[keep], opacity[keep], scale[keep], rot[keep] + base_cam = camera_info + if base_cam is None: # no camera -> default 3/4 view, auto-framed on the splat + center = xyz.mean(0) if xyz.shape[0] else torch.zeros(3, device=device) + extent = ((xyz - center).norm(dim=-1).quantile(0.99).clamp_min(1e-4) if xyz.shape[0] + else torch.tensor(1.0, device=device)) + dist = float(extent / (math.tan(math.radians(35.0) / 2) * 0.9)) + base_cam = _orbit_camera_info(35.0, 30.0, dist, 35.0, center, device) for fr in range(n_frames): - y = yaw + (orbit_dir * 360.0 * fr / n_frames if n_frames > 1 else 0.0) + cam_fr = (base_cam if n_frames == 1 + else _orbit_camera_info_yaw(base_cam, orbit_dir * 360.0 * fr / n_frames, device)) bg_k = bg_imgs[k % bg_imgs.shape[0]] if bg_imgs is not None else bg # per-frame plate, or solid colour - img, mask = _render_gaussian(xyz, rgb, opacity, scale, rot, width, height, fov, splat_scale, bg_k, + img, mask = _render_gaussian(xyz, rgb, opacity, scale, rot, width, height, splat_scale, bg_k, cam_fr, sharpen=sharpen, headlight_shading=headlight_shading, - render_style=render_style, camera_info=camera_info, yaw=y, pitch=pitch, - zoom=zoom, distance=distance) + render_style=render_style) imgs.append(img) masks.append(mask) k += 1 @@ -886,6 +936,81 @@ class RenderSplat(IO.ComfyNode): return IO.NodeOutput(torch.stack(imgs), torch.stack(masks)) +class CreateCameraInfo(IO.ComfyNode): # TODO: move to better file + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="CreateCameraInfo", + display_name="Create Camera Info", + search_aliases=["camera position", "make camera info", "orbit camera", "look at camera"], + category="3d/camera", + description="Build a camera_info" + "Mode 'orbit' aims with yaw/pitch/distance around the target; " + "'look_at' places the camera at world position. Coordinates are the viewer's world space (right-handed,Y-up).", + inputs=[ + IO.DynamicCombo.Input("mode", options=[ + IO.DynamicCombo.Option("orbit", [ + IO.Float.Input("yaw", default=35.0, min=-360.0, max=360.0, step=1.0), + IO.Float.Input("pitch", default=30.0, min=-89.0, max=89.0, step=1.0), + IO.Float.Input("distance", default=4.0, min=0.01, max=1000.0, step=0.01, + tooltip="Camera distance from the target."), + ]), + IO.DynamicCombo.Option("look_at", [ + IO.Float.Input("position_x", default=4.0, min=-1000.0, max=1000.0, step=0.01, + tooltip="Camera position in world space (right-handed, Y-up)."), + IO.Float.Input("position_y", default=4.0, min=-1000.0, max=1000.0, step=0.01), + IO.Float.Input("position_z", default=4.0, min=-1000.0, max=1000.0, step=0.01), + ]), + IO.DynamicCombo.Option("quaternion", [ + IO.Float.Input("position_x", default=4.0, min=-1000.0, max=1000.0, step=0.01, + tooltip="Camera position in world space (right-handed, Y-up)."), + IO.Float.Input("position_y", default=4.0, min=-1000.0, max=1000.0, step=0.01), + IO.Float.Input("position_z", default=4.0, min=-1000.0, max=1000.0, step=0.01), + IO.Float.Input("quat_x", default=0.0, min=-1.0, max=1.0, step=0.001), + IO.Float.Input("quat_y", default=0.0, min=-1.0, max=1.0, step=0.001), + IO.Float.Input("quat_z", default=0.0, min=-1.0, max=1.0, step=0.001), + IO.Float.Input("quat_w", default=1.0, min=-1.0, max=1.0, step=0.001, + tooltip="Camera world-rotation quaternion (three.js: looks down local -Z). Normalized for you."), + ]), + ], tooltip="How to define the camera: orbit angles, an explicit position, or a position + quaternion."), + IO.Float.Input("target_x", default=0.0, min=-1000.0, max=1000.0, step=0.01, advanced=True, + tooltip="Look-at point (orbit pivot / aim). In orbit mode, move it to pan/translate the " + "whole camera. Ignored in quaternion mode. Defaults to the origin."), + IO.Float.Input("target_y", default=0.0, min=-1000.0, max=1000.0, step=0.01, advanced=True), + IO.Float.Input("target_z", default=0.0, min=-1000.0, max=1000.0, step=0.01, advanced=True), + IO.Float.Input("roll", default=0.0, min=-180.0, max=180.0, step=1.0, + tooltip="Camera roll about the view axis, degrees."), + IO.Float.Input("fov", default=35.0, min=1.0, max=120.0, step=1.0, + tooltip="Vertical field of view in degrees."), + IO.Float.Input("zoom", default=1.0, min=0.01, max=100.0, step=0.01, + tooltip="Digital zoom (focal-length multiplier). >1 zooms in without moving the camera."), + IO.Combo.Input("camera_type", options=["perspective", "orthographic"], + tooltip="orthographic is currently rendered as perspective by Render Splat."), + ], + outputs=[IO.Load3DCamera.Output(display_name="camera_info")], + ) + + @classmethod + def execute(cls, mode, target_x, target_y, target_z, roll, fov, + zoom=1.0, camera_type="perspective") -> IO.NodeOutput: + dev = comfy.model_management.get_torch_device() + kind = mode["mode"] + if kind == "quaternion": # explicit world position + camera rotation + position = [mode["position_x"], mode["position_y"], mode["position_z"]] + quat = [mode["quat_x"], mode["quat_y"], mode["quat_z"], mode["quat_w"]] + return IO.NodeOutput(_quat_camera_info(position, quat, fov, dev, zoom=zoom, camera_type=camera_type)) + target = [target_x, target_y, target_z] # orbit pivot / aim; move it to pan the whole camera + if kind == "orbit": # yaw/pitch/distance about the target (world Y-up) + y, p = math.radians(mode["yaw"]), math.radians(mode["pitch"]) + cy, sy, cp, sp = math.cos(y), math.sin(y), math.cos(p), math.sin(p) + d = mode["distance"] + position = [target_x + d * cp * sy, target_y + d * sp, target_z + d * cp * cy] + else: # look_at: explicit world-space camera position + position = [mode["position_x"], mode["position_y"], mode["position_z"]] + return IO.NodeOutput(_lookat_camera_info(position, target, fov, dev, + zoom=zoom, camera_type=camera_type, roll=roll)) + + class TransformSplat(IO.ComfyNode): @classmethod def define_schema(cls): @@ -931,7 +1056,7 @@ class TransformSplat(IO.ComfyNode): rg = _quat_to_mat(splat.rotations.reshape(-1, 4)) # (M,3,3) per-splat rotation s2 = splat.scales.reshape(-1, 3).square() cov = (rg * s2[:, None, :]) @ rg.transpose(-1, -2) # Sigma - cov = A @ cov @ A.T # A Sigma A^T (A broadcast over splats) + cov = A @ cov @ A.T # A Sigma A^T (A broadcast over splats) lam, V = torch.linalg.eigh(cov) # symmetric -> eigenvalues (asc), orthonormal axes V = V * torch.where(torch.linalg.det(V) < 0, -1.0, 1.0)[..., None, None] # keep a proper rotation scales = lam.clamp_min(0).sqrt().reshape(splat.scales.shape) @@ -1370,8 +1495,8 @@ class SplatToMesh(IO.ComfyNode): class GaussianExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[IO.ComfyNode]]: - return [SplatToFile3D, File3DToSplat, RenderSplat, TransformSplat, GetSplatCount, - MergeSplat, SplatToMesh] + return [SplatToFile3D, File3DToSplat, RenderSplat, CreateCameraInfo, TransformSplat, + GetSplatCount, MergeSplat, SplatToMesh] async def comfy_entrypoint() -> GaussianExtension: