Refactor render camera use

This commit is contained in:
kijai 2026-05-31 03:14:15 +03:00
parent fbd3ab6417
commit 52ecc98fa0

View File

@ -568,6 +568,77 @@ def _camera_basis(camera_info, dev):
return eye, target, W[0], W[1], W[2]
def _lookat_quat_wxyz(position, target, dev):
    # three.js lookAt in world frame: camera local +Z = (eye - target), up = world +Y. Returns wxyz.
    #   position: 3-vector tensor, the camera eye.
    #   target:   3-vector tensor, the point looked at.
    #   dev:      torch device for the helper vectors built here.
    z = position - target
    dist = z.norm()
    if float(dist) < 1e-8:
        # Eye and target coincide, so there is no view direction: a "normalised" z would be all zeros and the
        # rotation matrix singular. Fall back to +Z like three.js Matrix4.lookAt (yields the identity rotation).
        z = torch.tensor([0.0, 0.0, 1.0], dtype=z.dtype, device=dev)
    else:
        z = z / dist
    # Helper vectors follow z's dtype so a non-float32 input does not trip a dtype mismatch in dot/cross.
    up0 = torch.tensor([0.0, 1.0, 0.0], dtype=z.dtype, device=dev)
    if z.dot(up0).abs() > 0.999:  # looking straight up/down: +Y is (nearly) parallel to z, use +Z as the up hint
        up0 = torch.tensor([0.0, 0.0, 1.0], dtype=z.dtype, device=dev)
    x = torch.linalg.cross(up0, z)
    x = x / x.norm().clamp_min(1e-8)
    y = torch.linalg.cross(z, x)  # z and x are unit and orthogonal, so y needs no renormalisation
    R = torch.stack([x, y, z], dim=1)  # columns = camera world axes
    return _mat_to_quat(R[None])[0]  # add/remove a batch dim around the matrix->quaternion helper
def _lookat_camera_info(position, target, fov, dev, zoom=1.0, camera_type="perspective", roll=0.0):
    # Assemble a camera_info dict for a camera sitting at `position` and aimed at `target`, both given in
    # world space (right-handed, Y-up) with world +Y as the up hint. `roll` (degrees) spins the camera about
    # its own view axis; fov / zoom / camera_type are copied into the dict as float / float / str.
    eye = torch.as_tensor(position, dtype=torch.float32, device=dev)
    aim = torch.as_tensor(target, dtype=torch.float32, device=dev)

    quat = _lookat_quat_wxyz(eye, aim, dev)  # wxyz
    if roll:
        # Right-multiply by a rotation about camera-local Z so the roll is applied in the camera's own frame.
        half = math.radians(roll) / 2
        roll_q = torch.tensor([math.cos(half), 0.0, 0.0, math.sin(half)], device=dev)
        quat = _quat_mul(quat[None], roll_q[None])[0]

    def as_xyz(v):
        return {"x": float(v[0]), "y": float(v[1]), "z": float(v[2])}

    info = {"position": as_xyz(eye), "target": as_xyz(aim)}
    # Stored with x/y/z/w keys; `quat` itself is in wxyz order, hence the index shuffle.
    info["quaternion"] = {"x": float(quat[1]), "y": float(quat[2]), "z": float(quat[3]), "w": float(quat[0])}
    info["fov"] = float(fov)
    info["cameraType"] = str(camera_type)
    info["zoom"] = float(zoom)
    return info
def _quat_camera_info(position, quat_xyzw, fov, dev, zoom=1.0, camera_type="perspective"):
    # camera_info from an explicit world position + camera-rotation quaternion (three.js: looks down local -Z).
    #   position:  3 numbers, camera eye in world space.
    #   quat_xyzw: 4 numbers in (x, y, z, w) order; normalised here.
    #   Returns a dict with position / target / quaternion / fov / cameraType / zoom; the target is synthesised
    #   one unit in front of the camera.
    pos = torch.as_tensor(position, dtype=torch.float32, device=dev)
    qx, qy, qz, qw = (float(c) for c in quat_xyzw)
    qwxyz = torch.tensor([qw, qx, qy, qz], dtype=torch.float32, device=dev)
    norm = qwxyz.norm()
    if float(norm) < 1e-8:
        # A zero-length quaternion has no direction and cannot be normalised; emitting (0,0,0,0) would hand
        # downstream code an invalid rotation. Use the identity instead (same fallback as three.js normalize()).
        qwxyz = torch.tensor([1.0, 0.0, 0.0, 0.0], dtype=torch.float32, device=dev)
    else:
        qwxyz = qwxyz / norm
    R = _quat_to_mat(qwxyz[None])[0]
    tgt = pos - R[:, 2]  # look one unit down local -Z (third column of R = camera +Z in world)
    xyz = lambda v: {"x": float(v[0]), "y": float(v[1]), "z": float(v[2])}
    return {"position": xyz(pos), "target": xyz(tgt),
            "quaternion": {"x": float(qwxyz[1]), "y": float(qwxyz[2]), "z": float(qwxyz[3]), "w": float(qwxyz[0])},
            "fov": float(fov), "cameraType": str(camera_type), "zoom": float(zoom)}
def _orbit_camera_info(yaw, pitch, distance, fov, pivot_splat, dev):
    # Default-camera helper for RenderSplat: place the eye `distance` behind `pivot_splat` along the
    # yaw/pitch forward direction (all in the splat frame), then express eye and pivot in world space and
    # build a look-at camera_info from them. yaw / pitch are in degrees.
    yaw_r, pitch_r = math.radians(yaw), math.radians(pitch)
    cos_y, sin_y = math.cos(yaw_r), math.sin(yaw_r)
    cos_p, sin_p = math.cos(pitch_r), math.sin(pitch_r)
    # Forward axis in the splat frame; per the original note this equals _view_matrix_t(yaw, pitch)[2].
    fwd_splat = torch.tensor([-cos_p * sin_y, sin_p, cos_p * cos_y], device=dev)

    def to_world(v):
        # splat <-> world is the (x, -y, -z) flip, which is its own inverse
        return torch.stack([v[0], -v[1], -v[2]])

    eye_splat = pivot_splat - distance * fwd_splat
    return _lookat_camera_info(to_world(eye_splat), to_world(pivot_splat), fov, dev)
def _orbit_camera_info_yaw(camera_info, angle_deg, dev):
    # Turntable step: swing the camera rigidly about the world +Y axis through its target by `angle_deg`.
    # The input dict is left untouched; a copy with updated "position" and "quaternion" is returned, every
    # other key (target, fov, zoom, ...) is carried over as-is.
    theta = math.radians(angle_deg)
    cos_t, sin_t = math.cos(theta), math.sin(theta)

    def vec3(d):
        # {"x","y","z"} dict -> tensor, missing components default to 0
        return torch.tensor([float(d.get("x", 0.0)), float(d.get("y", 0.0)), float(d.get("z", 0.0))], device=dev)

    eye = vec3(camera_info.get("position", {}))
    pivot = vec3(camera_info.get("target", {}))
    rot_y = torch.tensor([[cos_t, 0.0, sin_t], [0.0, 1.0, 0.0], [-sin_t, 0.0, cos_t]], device=dev)
    moved = pivot + rot_y @ (eye - pivot)  # rotate the eye's offset from the pivot, then re-attach it

    stored = camera_info.get("quaternion") or {}  # absent / None -> identity rotation
    q_old = torch.tensor([float(stored.get("w", 1.0)), float(stored.get("x", 0.0)),
                          float(stored.get("y", 0.0)), float(stored.get("z", 0.0))], device=dev)
    q_turn = torch.tensor([math.cos(theta / 2), 0.0, math.sin(theta / 2), 0.0], device=dev)  # world +Y rotation
    q_new = _quat_mul(q_turn[None], q_old[None])[0]  # left-multiply: the turn is applied in the world frame

    updated = dict(camera_info)
    updated["position"] = {"x": float(moved[0]), "y": float(moved[1]), "z": float(moved[2])}
    updated["quaternion"] = {"x": float(q_new[1]), "y": float(q_new[2]), "z": float(q_new[3]), "w": float(q_new[0])}
    return updated
def _gauss_blur(x, sigma, dev):
# Separable Gaussian blur of (1, C, H, W). Used to denoise the screen-space normal map.
r = max(1, int(round(3 * sigma)))
@ -579,9 +650,8 @@ def _gauss_blur(x, sigma, dev):
return x
def _render_gaussian(xyz, rgb, opacity, scale, rot, width, height, fov, splat_scale, bg, sharpen=1.0,
headlight_shading=0.0, render_style="color", camera_info=None,
yaw=35.0, pitch=30.0, zoom=1.0, distance=0.0):
def _render_gaussian(xyz, rgb, opacity, scale, rot, width, height, splat_scale, bg, camera_info,
sharpen=1.0, headlight_shading=0.0, render_style="color"):
# Perspective-correct anisotropic gaussian-splat rasterizer. Each splat is weighted by its 3D Gaussian's
# peak along each pixel's ray (AAA / Hahlbohm), composited front-to-back across depth slabs. `render_style`
# selects the image: color / clay / depth / normal. Returns (image HxWx3, coverage mask HxW) on CPU.
@ -605,24 +675,12 @@ def _render_gaussian(xyz, rgb, opacity, scale, rot, width, height, fov, splat_sc
if xyz.shape[0] == 0: # empty input (e.g. all culled by opacity_threshold)
return background_only()
if camera_info is not None:
eye, target, right, up, fwd = _camera_basis(camera_info, dev)
d = (target - eye).norm().clamp_min(1e-6)
eye = target - fwd * (d / max(zoom, 1e-3)) # zoom is relative to camera_info: 1.0 = as authored
W = torch.stack([right, up, fwd], 0) # rows = camera axes (world -> camera)
cam = (xyz - eye) @ W.T
cam_fov = float(camera_info.get("fov", 0) or 0) # fov=0 -> take it from the camera (or 35 if absent)
fov = fov if fov > 0 else (cam_fov if cam_fov > 0 else 35.0)
yflip = 1.0 # match the orbit path (splat +Y is visually down)
else:
fov = fov if fov > 0 else 35.0 # fov=0 -> default 35
center = xyz.mean(0)
extent = (xyz - center).norm(dim=-1).quantile(0.99).clamp_min(1e-4) # ignore outlier floaters
base = distance if distance > 0 else extent / (math.tan(math.radians(fov) / 2) * 0.9) # absolute dist, else auto-frame
dist = base / max(zoom, 1e-3)
W = _view_matrix_t(yaw, pitch, dev)
cam = (xyz - center) @ W.T + torch.tensor([0.0, 0.0, dist], device=dev)
yflip = 1.0
eye, _, right, up, fwd = _camera_basis(camera_info, dev) # all camera state comes from camera_info
W = torch.stack([right, up, fwd], 0) # rows = camera axes (world -> camera)
cam = (xyz - eye) @ W.T
fov = float(camera_info.get("fov", 0) or 0) or 35.0
zoom = float(camera_info.get("zoom", 1.0) or 1.0) # three.js digital zoom: scales the focal length
yflip = 1.0 # splat +Y is image-down
xc, yc, zc = cam.unbind(-1)
keep = zc > 1e-2
@ -632,7 +690,7 @@ def _render_gaussian(xyz, rgb, opacity, scale, rot, width, height, fov, splat_sc
if render_style == "clay":
rgb = torch.full_like(rgb, 0.75) # neutral albedo -> shading shows pure geometry
f = (min(width, height) / 2) / math.tan(math.radians(fov) / 2) # fov over the smaller axis -> object fits
f = (min(width, height) / 2) / math.tan(math.radians(fov) / 2) * zoom # fov over the smaller axis, x camera zoom
invz = 1.0 / zc
cx0, cy0 = width / 2, height / 2
@ -794,26 +852,16 @@ class RenderSplat(IO.ComfyNode):
search_aliases=["splat to image", "render splat", "gaussian turntable"],
category="3d/gaussian",
description="Render a gaussian splat to an image with an anisotropic EWA rasterizer (oriented "
"elliptical splats, antialiased, depth-sorted front-to-back). Set frames greater than 1 "
"to sweep the camera yaw through a full 360° rotation, producing a batch of images "
"(a turntable) that you can feed into a video node.",
"elliptical splats, antialiased, depth-sorted front-to-back). The camera comes from a "
"camera_info input (Load3D / Preview3D, or a Create Camera Info node); leave it empty to "
"auto-frame the splat. Set frames greater than 1 for a turntable batch to feed a video node.",
inputs=[
IO.Splat.Input("splat"),
IO.Int.Input("width", default=1024, min=64, max=2048, step=8),
IO.Int.Input("height", default=1024, min=64, max=2048, step=8),
IO.Int.Input("frames", default=1, min=-240, max=240,
tooltip="+/-1 = single still at the given yaw; magnitude >1 = orbit, yaw swept over a "
"full turn. Negative orbits the opposite direction. 0 = single still."),
IO.Float.Input("yaw", default=35.0, min=-360.0, max=360.0, step=1.0),
IO.Float.Input("pitch", default=30.0, min=-89.0, max=89.0, step=1.0),
IO.Float.Input("zoom", default=1.0, min=0.1, max=5.0, step=0.05,
tooltip="Camera dolly: >1 zooms in, <1 out. With camera_info or distance, 1.0 is exactly "
"that camera; otherwise 1.0 frames the whole splat (~10% margin)."),
IO.Float.Input("distance", default=0.0, min=0.0, max=1000.0, step=0.01,
tooltip="Absolute camera distance for the yaw/pitch orbit (from Get Camera Info). "
"0 = auto-frame the whole splat. Ignored when camera_info is connected."),
IO.Float.Input("fov", default=0.0, min=0.0, max=120.0, step=1.0,
tooltip="Vertical field of view in degrees. 0 = camera_info if provided, otherwise defaults to 35. Any value above 0 overrides the camera_info FoV."),
tooltip="+/-1 = single still; magnitude >1 = turntable, the camera orbited over a full "
"360 turn (works with any camera_info). Negative orbits the other way. 0 = single still."),
IO.Float.Input("splat_scale", default=1.0, min=0.1, max=5.0, step=0.05, advanced=True,
tooltip="Multiplier on each splat's projected footprint (lower = crisper points, "
"higher = softer/fuller surface)."),
@ -837,17 +885,16 @@ class RenderSplat(IO.ComfyNode):
"background colour). Resized to the render size; a batch is used per frame, "
"a single image for all. color/clay only."),
IO.Load3DCamera.Input("camera_info", optional=True,
tooltip="Render from this exact camera (e.g. from Load3D / Preview3D) "
"instead of the yaw/pitch orbit. Disables the turntable sweep."),
tooltip="Camera to render from - a Load3D / Preview3D camera or a Create Camera "
"Info node. If empty, the splat is auto-framed from a default 3/4 view."),
],
outputs=[IO.Image.Output(display_name="image"),
IO.Mask.Output(display_name="mask")],
)
@classmethod
def execute(cls, splat, width, height, yaw, pitch, frames, zoom, fov, splat_scale, sharpen,
headlight_shading, opacity_threshold, background, render_style, distance=0.0,
camera_info=None, bg_image=None) -> IO.NodeOutput:
def execute(cls, splat, width, height, frames, splat_scale, sharpen, headlight_shading,
opacity_threshold, background, render_style, camera_info=None, bg_image=None) -> IO.NodeOutput:
bg = _hex_to_rgb(background)
bg_imgs = None
if bg_image is not None: # resize the plate(s) to the render size: (B,H,W,3)
@ -855,12 +902,8 @@ class RenderSplat(IO.ComfyNode):
bg_imgs = bi.movedim(1, -1).clamp(0, 1)
n_frames = abs(int(frames)) or 1 # magnitude = frame count (0 -> single still)
orbit_dir = -1.0 if frames < 0 else 1.0 # sign = orbit direction
if camera_info is not None:
if n_frames > 1:
logging.warning("RenderSplat: camera_info is a fixed camera; ignoring frames=%d (no orbit sweep).", frames)
n_frames = 1
if str(camera_info.get("cameraType", "")).lower().startswith("ortho"):
logging.warning("RenderSplat: orthographic camera_info is rendered with a perspective camera.")
if camera_info is not None and str(camera_info.get("cameraType", "")).lower().startswith("ortho"):
logging.warning("RenderSplat: orthographic camera_info is rendered with a perspective camera.")
imgs, masks = [], []
device = comfy.model_management.get_torch_device() # render device; splat stays in torch here -> no roundtrip
total = splat.positions.shape[0] * n_frames
@ -871,13 +914,20 @@ class RenderSplat(IO.ComfyNode):
if opacity_threshold > 0:
keep = opacity >= opacity_threshold
xyz, rgb, opacity, scale, rot = xyz[keep], rgb[keep], opacity[keep], scale[keep], rot[keep]
base_cam = camera_info
if base_cam is None: # no camera -> default 3/4 view, auto-framed on the splat
center = xyz.mean(0) if xyz.shape[0] else torch.zeros(3, device=device)
extent = ((xyz - center).norm(dim=-1).quantile(0.99).clamp_min(1e-4) if xyz.shape[0]
else torch.tensor(1.0, device=device))
dist = float(extent / (math.tan(math.radians(35.0) / 2) * 0.9))
base_cam = _orbit_camera_info(35.0, 30.0, dist, 35.0, center, device)
for fr in range(n_frames):
y = yaw + (orbit_dir * 360.0 * fr / n_frames if n_frames > 1 else 0.0)
cam_fr = (base_cam if n_frames == 1
else _orbit_camera_info_yaw(base_cam, orbit_dir * 360.0 * fr / n_frames, device))
bg_k = bg_imgs[k % bg_imgs.shape[0]] if bg_imgs is not None else bg # per-frame plate, or solid colour
img, mask = _render_gaussian(xyz, rgb, opacity, scale, rot, width, height, fov, splat_scale, bg_k,
img, mask = _render_gaussian(xyz, rgb, opacity, scale, rot, width, height, splat_scale, bg_k, cam_fr,
sharpen=sharpen, headlight_shading=headlight_shading,
render_style=render_style, camera_info=camera_info, yaw=y, pitch=pitch,
zoom=zoom, distance=distance)
render_style=render_style)
imgs.append(img)
masks.append(mask)
k += 1
@ -886,6 +936,81 @@ class RenderSplat(IO.ComfyNode):
return IO.NodeOutput(torch.stack(imgs), torch.stack(masks))
class CreateCameraInfo(IO.ComfyNode):  # TODO: move to better file
    # Node that builds a camera_info dict from orbit angles, a look-at position, or a position + quaternion.

    @classmethod
    def define_schema(cls):
        # `mode` is a dynamic combo: the selected option contributes its own inputs (yaw/pitch/distance,
        # position_*, or position_* + quat_*). target / roll / fov / zoom / camera_type are always present.
        return IO.Schema(
            node_id="CreateCameraInfo",
            display_name="Create Camera Info",
            search_aliases=["camera position", "make camera info", "orbit camera", "look at camera"],
            category="3d/camera",
            # Adjacent literals are concatenated verbatim, so each piece carries its own trailing space.
            description="Build a camera_info. "
                        "Mode 'orbit' aims with yaw/pitch/distance around the target; "
                        "'look_at' places the camera at a world position aimed at the target; "
                        "'quaternion' uses a world position plus an explicit rotation. "
                        "Coordinates are the viewer's world space (right-handed, Y-up).",
            inputs=[
                IO.DynamicCombo.Input("mode", options=[
                    IO.DynamicCombo.Option("orbit", [
                        IO.Float.Input("yaw", default=35.0, min=-360.0, max=360.0, step=1.0),
                        IO.Float.Input("pitch", default=30.0, min=-89.0, max=89.0, step=1.0),
                        IO.Float.Input("distance", default=4.0, min=0.01, max=1000.0, step=0.01,
                                       tooltip="Camera distance from the target."),
                    ]),
                    IO.DynamicCombo.Option("look_at", [
                        IO.Float.Input("position_x", default=4.0, min=-1000.0, max=1000.0, step=0.01,
                                       tooltip="Camera position in world space (right-handed, Y-up)."),
                        IO.Float.Input("position_y", default=4.0, min=-1000.0, max=1000.0, step=0.01),
                        IO.Float.Input("position_z", default=4.0, min=-1000.0, max=1000.0, step=0.01),
                    ]),
                    IO.DynamicCombo.Option("quaternion", [
                        IO.Float.Input("position_x", default=4.0, min=-1000.0, max=1000.0, step=0.01,
                                       tooltip="Camera position in world space (right-handed, Y-up)."),
                        IO.Float.Input("position_y", default=4.0, min=-1000.0, max=1000.0, step=0.01),
                        IO.Float.Input("position_z", default=4.0, min=-1000.0, max=1000.0, step=0.01),
                        IO.Float.Input("quat_x", default=0.0, min=-1.0, max=1.0, step=0.001),
                        IO.Float.Input("quat_y", default=0.0, min=-1.0, max=1.0, step=0.001),
                        IO.Float.Input("quat_z", default=0.0, min=-1.0, max=1.0, step=0.001),
                        IO.Float.Input("quat_w", default=1.0, min=-1.0, max=1.0, step=0.001,
                                       tooltip="Camera world-rotation quaternion (three.js: looks down local -Z). Normalized for you."),
                    ]),
                ], tooltip="How to define the camera: orbit angles, an explicit position, or a position + quaternion."),
                IO.Float.Input("target_x", default=0.0, min=-1000.0, max=1000.0, step=0.01, advanced=True,
                               tooltip="Look-at point (orbit pivot / aim). In orbit mode, move it to pan/translate the "
                                       "whole camera. Ignored in quaternion mode. Defaults to the origin."),
                IO.Float.Input("target_y", default=0.0, min=-1000.0, max=1000.0, step=0.01, advanced=True),
                IO.Float.Input("target_z", default=0.0, min=-1000.0, max=1000.0, step=0.01, advanced=True),
                IO.Float.Input("roll", default=0.0, min=-180.0, max=180.0, step=1.0,
                               tooltip="Camera roll about the view axis, degrees. Ignored in quaternion mode."),
                IO.Float.Input("fov", default=35.0, min=1.0, max=120.0, step=1.0,
                               tooltip="Vertical field of view in degrees."),
                IO.Float.Input("zoom", default=1.0, min=0.01, max=100.0, step=0.01,
                               tooltip="Digital zoom (focal-length multiplier). >1 zooms in without moving the camera."),
                IO.Combo.Input("camera_type", options=["perspective", "orthographic"],
                               tooltip="orthographic is currently rendered as perspective by Render Splat."),
            ],
            outputs=[IO.Load3DCamera.Output(display_name="camera_info")],
        )

    @classmethod
    def execute(cls, mode, target_x, target_y, target_z, roll, fov,
                zoom=1.0, camera_type="perspective") -> IO.NodeOutput:
        # `mode` is a mapping: mode["mode"] names the selected option and the remaining keys hold that
        # option's inputs. Returns a NodeOutput wrapping the camera_info dict.
        dev = comfy.model_management.get_torch_device()
        kind = mode["mode"]
        if kind == "quaternion":  # explicit world position + camera rotation; target_* and roll are not used
            position = [mode["position_x"], mode["position_y"], mode["position_z"]]
            quat = [mode["quat_x"], mode["quat_y"], mode["quat_z"], mode["quat_w"]]
            return IO.NodeOutput(_quat_camera_info(position, quat, fov, dev, zoom=zoom, camera_type=camera_type))
        target = [target_x, target_y, target_z]  # orbit pivot / aim; move it to pan the whole camera
        if kind == "orbit":  # yaw/pitch/distance about the target (world Y-up)
            y, p = math.radians(mode["yaw"]), math.radians(mode["pitch"])
            cy, sy, cp, sp = math.cos(y), math.sin(y), math.cos(p), math.sin(p)
            d = mode["distance"]
            # Spherical offset from the target: pitch lifts the eye along +Y, yaw swings it in the XZ plane.
            position = [target_x + d * cp * sy, target_y + d * sp, target_z + d * cp * cy]
        else:  # look_at: explicit world-space camera position
            position = [mode["position_x"], mode["position_y"], mode["position_z"]]
        return IO.NodeOutput(_lookat_camera_info(position, target, fov, dev,
                                                 zoom=zoom, camera_type=camera_type, roll=roll))
class TransformSplat(IO.ComfyNode):
@classmethod
def define_schema(cls):
@ -931,7 +1056,7 @@ class TransformSplat(IO.ComfyNode):
rg = _quat_to_mat(splat.rotations.reshape(-1, 4)) # (M,3,3) per-splat rotation
s2 = splat.scales.reshape(-1, 3).square()
cov = (rg * s2[:, None, :]) @ rg.transpose(-1, -2) # Sigma
cov = A @ cov @ A.T # A Sigma A^T (A broadcast over splats)
cov = A @ cov @ A.T # A Sigma A^T (A broadcast over splats)
lam, V = torch.linalg.eigh(cov) # symmetric -> eigenvalues (asc), orthonormal axes
V = V * torch.where(torch.linalg.det(V) < 0, -1.0, 1.0)[..., None, None] # keep a proper rotation
scales = lam.clamp_min(0).sqrt().reshape(splat.scales.shape)
@ -1370,8 +1495,8 @@ class SplatToMesh(IO.ComfyNode):
class GaussianExtension(ComfyExtension):
    # Extension object that lists the gaussian-splat node classes defined in this module.

    @override
    async def get_node_list(self) -> list[type[IO.ComfyNode]]:
        # Single return: a stale list without CreateCameraInfo used to precede this one, which left the
        # new node unregistered and this statement unreachable.
        return [SplatToFile3D, File3DToSplat, RenderSplat, CreateCameraInfo, TransformSplat,
                GetSplatCount, MergeSplat, SplatToMesh]
async def comfy_entrypoint() -> GaussianExtension: