ComfyUI/comfy_extras/sam3d_body/rasterizer.py
2026-05-26 02:15:15 +03:00

468 lines
18 KiB
Python

"""Pure-PyTorch rasterizer for SAM 3D Body meshes.
Algorithm: forward triangle rasterizer with hard z-buffer. Per-face screen
bbox cull → faces sorted by bbox size and chunked under a fixed pixel
budget → inside-test via edge functions, barycentric interpolation, depth
test via `scatter_reduce_(amin)`.
"""
import hashlib
from typing import Sequence

import numpy as np
import torch

import comfy.model_management

from .utils import jet_colormap
# Shader presets that colour the mesh from the per-vertex "canonical" colour
# table (see `_build_vcolor`) instead of from geometry and lighting alone.
_CANONICAL_PRESETS = {"rainbow", "rainbow_face_normal", "rainbow_face_semantic"}
# Memo of `rainbow_colors_from_canonical` results. Entries are inserted and
# evicted (oldest first, once more than 32 are held) by that function.
_rainbow_cache: dict = {}
def rainbow_colors_from_canonical(
    positions: np.ndarray,
    tilt_x_deg: float = 0.0,
    tilt_z_deg: float = 0.0,
) -> np.ndarray:
    """Compute per-vertex jet-colormap RGB from canonical (T-pose, Y-up) vertices.

    Every vertex is projected onto a unit "ramp" axis, the projection is
    min-max normalised to [0, 1], scaled by 0.98 and fed to `jet_colormap`.
    With both tilts at zero the axis is +Y.

    Results are memoised in `_rainbow_cache`. The cache key is derived from
    the *contents* of `positions` (shape, dtype and a digest of the raw bytes)
    plus the tilts rounded to 3 decimals. Keying on `id(positions)` is unsafe:
    ids are recycled once an array is garbage-collected, so a later, different
    mesh could silently receive colours computed for an earlier one.

    Args:
        positions: (N_v, 3) canonical vertex positions, Y-up (head at max Y).
        tilt_x_deg: rotation of the jet axis around X (in degrees). Positive
            biases the ramp toward +Z (front).
        tilt_z_deg: rotation of the jet axis around Z (in degrees). Positive
            biases the ramp toward +X (right, in body frame).

    Returns:
        (N_v, 3) float32 RGB in [0, 1]. The returned array is the cached
        object itself; copy it before mutating.
    """
    positions = np.asarray(positions)
    # Content digest: stable across re-created arrays holding the same data
    # and immune to id() reuse. 16 bytes keeps the key small.
    digest = hashlib.blake2b(
        np.ascontiguousarray(positions).tobytes(), digest_size=16
    ).digest()
    key = (
        positions.shape,
        positions.dtype.str,
        digest,
        round(float(tilt_x_deg), 3),
        round(float(tilt_z_deg), 3),
    )
    cached = _rainbow_cache.get(key)
    if cached is not None:
        return cached

    theta_x = np.deg2rad(tilt_x_deg)
    theta_z = np.deg2rad(tilt_z_deg)
    axis = np.array([
        np.sin(theta_z),
        np.cos(theta_z) * np.cos(theta_x),
        np.cos(theta_z) * np.sin(theta_x),
    ], dtype=np.float32)

    s = positions @ axis
    # Guard the span so a flat mesh (all projections equal) does not divide by 0.
    s = (s - s.min()) / max(float(s.max() - s.min()), 1e-8)
    s = np.clip(s * 0.98, 0.0, 1.0).astype(np.float32)
    colors = jet_colormap(s)

    _rainbow_cache[key] = colors
    # Bounded FIFO: dicts preserve insertion order, so the first key is oldest.
    if len(_rainbow_cache) > 32:
        _rainbow_cache.pop(next(iter(_rainbow_cache)))
    return colors
def _vertex_normals(verts: torch.Tensor, faces: torch.Tensor) -> torch.Tensor:
"""Area-weighted per-vertex normals; matches `_compute_vertex_normals`."""
v0 = verts[faces[:, 0]]
v1 = verts[faces[:, 1]]
v2 = verts[faces[:, 2]]
fn = torch.cross(v1 - v0, v2 - v0, dim=-1)
vn = torch.zeros_like(verts)
vn.index_add_(0, faces[:, 0], fn)
vn.index_add_(0, faces[:, 1], fn)
vn.index_add_(0, faces[:, 2], fn)
return vn / vn.norm(dim=-1, keepdim=True).clamp(min=1e-8)
def _build_vcolor(canonical_colors, shader_preset, tilt_x, tilt_z):
    """Build the per-vertex RGB table for the canonical-colour shader presets.

    Mirrors the canonical_colors -> per-vertex RGB pipeline in
    `rasterizer.render_pose_data`. Starts from a private copy of the rainbow
    ramp; for the two `rainbow_face_*` presets, and only when a non-empty
    `face_mask` is present, some vertices are then overwritten:

    * `rainbow_face_normal`: masked vertices take `canonical_colors["norm"]`.
    * `rainbow_face_semantic`: every vertex whose `face_region_rgb` row has a
      positive channel sum takes that row.

    Returns a numpy float32 (V, 3) table.
    """
    verts = np.asarray(canonical_colors.get("positions"), dtype=np.float32)
    # Copy: the ramp array is shared with the rainbow cache and is edited below.
    table = rainbow_colors_from_canonical(
        verts, tilt_x_deg=tilt_x, tilt_z_deg=tilt_z
    ).copy()

    if shader_preset not in ("rainbow_face_normal", "rainbow_face_semantic"):
        return table

    face_mask = canonical_colors.get("face_mask")
    if face_mask is None or not face_mask.any():
        return table

    if shader_preset == "rainbow_face_normal":
        normal_rgb = np.asarray(canonical_colors["norm"], dtype=np.float32)
        table[face_mask] = normal_rgb[face_mask]
        return table

    # rainbow_face_semantic
    region_rgb = np.asarray(canonical_colors["face_region_rgb"], dtype=np.float32)
    has_region = region_rgb.sum(axis=1) > 0
    table[has_region] = region_rgb[has_region]
    return table
def _rasterize_chunk(
fv_pix: torch.Tensor, # (Fc, 3, 2) — pixel coords (sub-pixel float)
fv_z: torch.Tensor, # (Fc, 3) — image-frame z (smaller=closer)
bb_min_x: torch.Tensor, bb_max_x: torch.Tensor, # (Fc,) clamped int bboxes
bb_min_y: torch.Tensor, bb_max_y: torch.Tensor,
max_sx: int, max_sy: int,
W: int,
):
"""Rasterize a chunk of faces at pixel centers. Returns flat tensors of
inside fragments: (pixel_idx, depth, face_local, bary).
"""
device = fv_pix.device
if max_sx == 0 or max_sy == 0:
return None
sx = bb_max_x - bb_min_x
sy = bb_max_y - bb_min_y
px_off = torch.arange(max_sx, device=device)
py_off = torch.arange(max_sy, device=device)
# Pixel-center sample positions, broadcast to (Fc, max_sy, max_sx).
P_x = (bb_min_x[:, None, None] + px_off[None, None, :]).float() + 0.5
P_y = (bb_min_y[:, None, None] + py_off[None, :, None]).float() + 0.5
in_bb = (px_off[None, None, :] < sx[:, None, None]) & \
(py_off[None, :, None] < sy[:, None, None])
Ax = fv_pix[:, 0, 0][:, None, None]
Ay = fv_pix[:, 0, 1][:, None, None]
Bx = fv_pix[:, 1, 0][:, None, None]
By = fv_pix[:, 1, 1][:, None, None]
Cx = fv_pix[:, 2, 0][:, None, None]
Cy = fv_pix[:, 2, 1][:, None, None]
area2 = (Bx - Ax) * (Cy - Ay) - (By - Ay) * (Cx - Ax) # (Fc, 1, 1)
e_a = (Bx - P_x) * (Cy - P_y) - (By - P_y) * (Cx - P_x)
e_b = (Cx - P_x) * (Ay - P_y) - (Cy - P_y) * (Ax - P_x)
e_c = (Ax - P_x) * (By - P_y) - (Ay - P_y) * (Bx - P_x)
# Same-sign-as-area2 inside test (no back-face culling — match either winding).
nondegen = area2.abs() > 1e-6 # threshold rejects near-degenerate triangles
inside = (e_a * area2 >= 0) & (e_b * area2 >= 0) & (e_c * area2 >= 0)
inside = inside & in_bb & nondegen
if not inside.any():
return None
inv_a2 = torch.where(nondegen, 1.0 / area2, torch.zeros_like(area2))
w_a = e_a * inv_a2
w_b = e_b * inv_a2
w_c = e_c * inv_a2
z_a = fv_z[:, 0, None, None]
z_b = fv_z[:, 1, None, None]
z_c = fv_z[:, 2, None, None]
z_grid = w_a * z_a + w_b * z_b + w_c * z_c
fi, yi, xi = inside.nonzero(as_tuple=True)
px_pixel = bb_min_x[fi] + xi
py_pixel = bb_min_y[fi] + yi
pixel_idx = (py_pixel * W + px_pixel).long()
z_flat = z_grid[fi, yi, xi]
bary_flat = torch.stack([w_a[fi, yi, xi], w_b[fi, yi, xi], w_c[fi, yi, xi]], dim=-1)
return pixel_idx, z_flat, fi.long(), bary_flat
def _rasterize_person(
    verts_world: torch.Tensor, faces: torch.Tensor,
    focal: float, W: int, H: int,
    z_buf: torch.Tensor, color_buf: torch.Tensor, mask_buf: torch.Tensor,
    shade_fn,
):
    """Rasterize one mesh into the shared flat depth / colour / mask buffers.

    Vertices are pinhole-projected around the image centre. A face is drawn
    only if all three of its vertices lie in front of the near plane
    (z > 0.05) and its clamped screen bbox is non-empty; faces straddling the
    near plane are dropped rather than clipped. Surviving faces are sorted by
    bbox size and processed in chunks so the dense per-chunk sample grid
    stays within a fixed pixel budget.

    `z_buf` (H*W,), `color_buf` (H*W, 3) and `mask_buf` (H*W,) are updated in
    place; nothing is returned. `shade_fn(face_idx, bary)` is called once per
    chunk with the winning fragments and must return their RGB rows.
    """
    # Project image-frame verts to pixel coords. Verts at/behind the near
    # plane get a clamped divisor so the maths stays finite; faces using
    # them are rejected below.
    near_z = 0.05
    depth_v = verts_world[:, 2]
    in_front = depth_v > near_z
    divisor = depth_v.clamp(min=near_z)
    screen_x = 0.5 * W + focal * verts_world[:, 0] / divisor
    screen_y = 0.5 * H + focal * verts_world[:, 1] / divisor

    face_pix = torch.stack([screen_x, screen_y], dim=-1)[faces]  # (F, 3, 2)
    face_z = verts_world[faces][..., 2]  # (F, 3)
    face_in_front = in_front[faces].all(dim=-1)

    # Integer screen bboxes (max is exclusive), clamped to the image.
    xs = face_pix[..., 0]
    ys = face_pix[..., 1]
    bb_min_x = xs.amin(dim=-1).floor().long().clamp(min=0, max=W)
    bb_max_x = (xs.amax(dim=-1).ceil().long() + 1).clamp(min=0, max=W)
    bb_min_y = ys.amin(dim=-1).floor().long().clamp(min=0, max=H)
    bb_max_y = (ys.amax(dim=-1).ceil().long() + 1).clamp(min=0, max=H)
    widths = bb_max_x - bb_min_x
    heights = bb_max_y - bb_min_y

    drawable = face_in_front & (widths > 0) & (heights > 0)
    kept = torch.where(drawable)[0]
    if kept.numel() == 0:
        return

    # Sort kept faces by max bbox dimension so chunks stay similarly-sized.
    extent = torch.maximum(widths, heights)[kept]
    by_size = torch.argsort(extent)
    kept = kept[by_size]
    total = kept.numel()
    # Host-side copies drive the Python chunking loop.
    widths_host = widths[kept].tolist()
    heights_host = heights[kept].tolist()
    extent_host = extent[by_size].tolist()

    PIXEL_BUDGET = 4_000_000
    MAX_CHUNK = 8192
    start = 0
    while start < total:
        stop = min(start + MAX_CHUNK, total)
        # Shrink the chunk until (faces x largest bbox squared) fits the
        # budget. Faces are sorted ascending, so the last one is the largest.
        # A single face is always accepted, even when it alone exceeds it.
        while stop > start + 1:
            largest = extent_host[stop - 1]
            if (stop - start) * largest * largest <= PIXEL_BUDGET:
                break
            stop = max(start + 1, stop - max(1, (stop - start) // 4))

        batch = kept[start:stop]
        batch_w = max(widths_host[start:stop])
        batch_h = max(heights_host[start:stop])
        start = stop

        fragments = _rasterize_chunk(
            face_pix[batch], face_z[batch],
            bb_min_x[batch], bb_max_x[batch],
            bb_min_y[batch], bb_max_y[batch],
            batch_w, batch_h, W,
        )
        if fragments is None:
            continue
        frag_pixel, frag_z, frag_local, frag_bary = fragments
        frag_face = batch[frag_local]

        # Depth test: snapshot the previous depths, fold this chunk in with a
        # min-reduction, then keep the fragments that equal the new minimum
        # AND actually lowered the stored depth.
        previous = z_buf[frag_pixel].clone()
        z_buf.scatter_reduce_(0, frag_pixel, frag_z, reduce='amin', include_self=True)
        current = z_buf[frag_pixel]
        wins = (frag_z == current) & (current < previous)
        if not wins.any():
            continue

        # Several fragments can tie for the same pixel's new minimum; a
        # stable sort by pixel followed by "first of each run" keeps exactly
        # one, so shade_fn runs once per winning pixel.
        win_pixel = frag_pixel[wins]
        win_face = frag_face[wins]
        win_bary = frag_bary[wins]
        perm = torch.argsort(win_pixel, stable=True)
        sorted_pixel = win_pixel[perm]
        run_start = torch.ones_like(sorted_pixel, dtype=torch.bool)
        run_start[1:] = sorted_pixel[1:] != sorted_pixel[:-1]
        chosen = perm[run_start]

        out_pixel = win_pixel[chosen]
        color_buf[out_pixel] = shade_fn(win_face[chosen], win_bary[chosen])
        mask_buf[out_pixel] = True
def _make_shade_fn(
shader_preset, composite,
view_normals_v, view_pos_v, vcolor_v, faces,
base_color, light_dir, pastel_mix,
):
device = view_normals_v.device
base_color_t = torch.as_tensor(base_color, dtype=torch.float32, device=device)
light_dir_t = torch.as_tensor(light_dir, dtype=torch.float32, device=device)
# Light-vector constants — normalized once per render call.
l_unit = -light_dir_t
l_unit = l_unit / l_unit.norm().clamp(min=1e-8)
if pastel_mix <= 0.0:
apply_pastel = lambda rgb: rgb
else:
pm = float(pastel_mix)
apply_pastel = lambda rgb: rgb * (1.0 - pm) + pm
def gather_n(face_idx, bary):
n = (view_normals_v[faces[face_idx]] * bary.unsqueeze(-1)).sum(dim=1)
return n / n.norm(dim=-1, keepdim=True).clamp(min=1e-8)
def gather_pos(face_idx, bary):
return (view_pos_v[faces[face_idx]] * bary.unsqueeze(-1)).sum(dim=1)
def gather_color(face_idx, bary):
return (vcolor_v[faces[face_idx]] * bary.unsqueeze(-1)).sum(dim=1)
if composite == "silhouette":
return lambda fi, ba: torch.ones((fi.shape[0], 3), device=device)
if shader_preset == "normals":
# View-space surface normal encoded as RGB (OpenGL Y+ convention).
# +X right → R; +Y up → G; +Z toward viewer → B. Each face shows mostly
# one channel, matching standard normal-map visualization.
def shade(face_idx, bary):
n = gather_n(face_idx, bary)
return apply_pastel(((n + 1.0) * 0.5).clamp(0.0, 1.0))
return shade
use_vcolor = (shader_preset in _CANONICAL_PRESETS) and (vcolor_v is not None)
if not use_vcolor:
# default.frag: ambient + diffuse + rim
def shade(face_idx, bary):
n = gather_n(face_idx, bary)
v = -gather_pos(face_idx, bary)
v = v / v.norm(dim=-1, keepdim=True).clamp(min=1e-8)
ndotl = (n * l_unit).sum(dim=-1).clamp(min=0)
ndotv = (n * v).sum(dim=-1).clamp(min=0)
rim = (1.0 - ndotv).pow(3.0)
lit = 0.25 * base_color_t \
+ 0.75 * base_color_t * ndotl.unsqueeze(-1) \
+ 0.35 * rim.unsqueeze(-1)
return apply_pastel(lit)
return shade
if shader_preset == "rainbow":
def shade(face_idx, bary):
base = gather_color(face_idx, bary)
n = gather_n(face_idx, bary)
ndotl = (n * l_unit).sum(dim=-1).clamp(min=0)
return apply_pastel(base * (0.65 + 0.35 * ndotl).unsqueeze(-1))
return shade
# rainbow_face_* → rainbow_lit.frag. All light-direction & half-vector
# constants depend only on the (constant) light_dir, so precompute them.
key_l = l_unit
fill_l = torch.stack([-key_l[0], key_l[1].abs(), -key_l[2]])
view_dir = torch.tensor([0.0, 0.0, 1.0], device=device)
h = key_l + view_dir
h = h / h.norm().clamp(min=1e-8)
def shade(face_idx, bary):
base = gather_color(face_idx, bary)
n = gather_n(face_idx, bary)
key_ndotl = (n * key_l).sum(dim=-1).clamp(min=0)
fill_ndotl = (n * fill_l).sum(dim=-1).clamp(min=0)
rim = (1.0 - n[..., 2].clamp(min=0)).pow(2.5) * 0.30
shade_val = (0.45 + 0.45 * key_ndotl + 0.15 * fill_ndotl + rim * 0.5).clamp(min=0.0, max=1.25)
ndoth = (n * h).sum(dim=-1).clamp(min=0)
spec = ndoth.pow(48) * 0.12
lit = base * shade_val.unsqueeze(-1) + spec.unsqueeze(-1)
return apply_pastel(lit)
return shade
def _background_to_device(background, device) -> torch.Tensor:
    """Return `background` as a float32 tensor on `device`, clamped to [0, 1].

    Accepts a numpy array or a torch tensor. Used for both the empty-frame
    and the composite path so the two always treat the background alike.
    """
    if isinstance(background, np.ndarray):
        bg = torch.as_tensor(background, dtype=torch.float32, device=device)
    else:
        # `.to` is a no-op when device and dtype already match.
        bg = background.to(device=device, dtype=torch.float32)
    return bg.clamp(0.0, 1.0)


def render_pose_data_torch(
    pose_data: dict,
    frame_idx: int,
    W: int,
    H: int,
    background=None,  # Optional[np.ndarray | torch.Tensor] (H, W, 3) fp32 [0, 1]
    composite: str = "over",
    opacity: float = 1.0,
    shader_preset: str = "default",
    base_color: Sequence[float] = (0.68, 0.71, 0.78),
    light_dir: Sequence[float] = (0.4, -0.7, -0.6),
    rainbow_tilt_x_deg: float = 0.0,
    rainbow_tilt_z_deg: float = 0.0,
    person_brightness_falloff: float = 0.6,
) -> torch.Tensor:
    """Render one frame of persons from `pose_data` at resolution WxH.

    Args:
        pose_data: dict with "frames" (per-frame lists of person dicts holding
            "pred_vertices", "pred_cam_t" and optionally "focal_length"),
            "faces", and optionally "canonical_colors".
        frame_idx: frame to render; an index past the end renders as empty.
        W, H: output width and height in pixels.
        background: optional (H, W, 3) image, used only when
            `composite == "over"`. It is clamped to [0, 1] before use.
        composite: "over" blends the render onto `background`; "silhouette"
            shades every covered pixel white.
        opacity: scales mesh coverage when blending over the background.
        shader_preset: "default", "normals", "depth", or one of the
            canonical-colour presets. Canonical presets fall back to
            "default" when colour data is missing or its vertex count does
            not match the mesh.
        base_color, light_dir: inputs of the lit shaders.
        rainbow_tilt_x_deg, rainbow_tilt_z_deg: tilt of the rainbow ramp axis.
        person_brightness_falloff: clamped to [0, 1]; person `k` (by position
            in the frame's list) is blended toward white by `1 - falloff**k`.

    Returns an (H, W, 3) float32 torch.Tensor on the comfy compute device,
    ready to be stacked into the node's IMAGE output without a CPU round-trip.
    """
    device = comfy.model_management.get_torch_device()
    frames = pose_data["frames"]
    persons = frames[frame_idx] if frame_idx < len(frames) else []
    over_background = composite == "over" and background is not None

    if len(persons) == 0:
        if over_background:
            return _background_to_device(background, device)
        return torch.zeros((H, W, 3), device=device, dtype=torch.float32)

    faces = torch.as_tensor(np.asarray(pose_data["faces"], dtype=np.int64), device=device)

    canonical_colors = pose_data.get("canonical_colors")
    using_canonical = shader_preset in _CANONICAL_PRESETS
    if using_canonical and canonical_colors is None:
        shader_preset = "default"
        using_canonical = False

    vcolor = None
    if using_canonical:
        vcolor_np = _build_vcolor(canonical_colors, shader_preset,
                                  rainbow_tilt_x_deg, rainbow_tilt_z_deg)
        vcolor = torch.as_tensor(vcolor_np, dtype=torch.float32, device=device)

    falloff = max(0.0, min(1.0, float(person_brightness_falloff)))
    person_pastel = [0.0 if k == 0 else (1.0 - falloff ** k) for k in range(len(persons))]

    # Back-to-front draw order (largest camera-space z first). The shared
    # z-buffer resolves visibility on its own; the order only fixes the
    # sequence in which persons are submitted.
    order = sorted(range(len(persons)),
                   key=lambda i: -float(np.asarray(persons[i]["pred_cam_t"]).reshape(-1)[2]))

    HW = H * W
    z_buf = torch.full((HW,), float('inf'), device=device, dtype=torch.float32)
    color_buf = torch.zeros((HW, 3), device=device, dtype=torch.float32)
    mask_buf = torch.zeros(HW, device=device, dtype=torch.bool)

    for idx in order:
        p = persons[idx]
        verts_np = np.asarray(p["pred_vertices"], dtype=np.float32).reshape(-1, 3)
        cam_t = np.asarray(p["pred_cam_t"], dtype=np.float32).reshape(3)
        verts_world = torch.as_tensor(verts_np + cam_t[None, :],
                                      device=device, dtype=torch.float32)
        focal = float(np.asarray(p.get("focal_length", 5000.0)).reshape(-1)[0])

        # Image-frame (+Y down, +Z forward) → view-space (+Y up, -Z forward)
        # for shading, matching what the GL-style shader math expects.
        view_pos_v = torch.stack(
            [verts_world[:, 0], -verts_world[:, 1], -verts_world[:, 2]], dim=-1,
        )
        normals_world = _vertex_normals(verts_world, faces)
        view_normals_v = torch.stack(
            [normals_world[:, 0], -normals_world[:, 1], -normals_world[:, 2]], dim=-1,
        )

        # Drop the colour table for a person whose vertex count differs.
        vcolor_p = vcolor if (vcolor is not None and vcolor.shape[0] == verts_world.shape[0]) else None
        # Only canonical-vcolor shaders need vcolor; geometric shaders
        # ('normals', 'depth') and the lit default work without it.
        if shader_preset in _CANONICAL_PRESETS and vcolor_p is None:
            effective_preset = "default"
        else:
            effective_preset = shader_preset

        shade_fn = _make_shade_fn(
            effective_preset, composite,
            view_normals_v, view_pos_v, vcolor_p, faces,
            base_color, light_dir, person_pastel[idx],
        )
        _rasterize_person(
            verts_world, faces, focal, W, H,
            z_buf, color_buf, mask_buf, shade_fn,
        )

    # Stay on GPU through readback + composite.
    if shader_preset == "depth":
        # z_buf already holds linear image-frame z (smaller=closer; +inf where
        # no mesh covers). Normalize within the rendered mesh's range:
        # near=white, far=black, background=black.
        mask_2d = mask_buf.reshape(H, W)
        z_2d = z_buf.reshape(H, W)
        if mask_2d.any():
            zin = z_2d[mask_2d]
            zmin = zin.min()
            zr = (zin.max() - zmin).clamp(min=1e-6)
            norm = torch.where(mask_2d, 1.0 - (z_2d - zmin) / zr, torch.zeros_like(z_2d))
        else:
            norm = torch.zeros((H, W), device=device, dtype=torch.float32)
        rendered = torch.stack([norm, norm, norm], dim=-1)
        mask_f = mask_2d.float()
    else:
        rendered = color_buf.reshape(H, W, 3).clamp(0.0, 1.0)
        mask_f = mask_buf.reshape(H, W).float()

    if over_background:
        # Same clamped background as the empty-frame path, so frames with
        # and without persons share one output range.
        bg = _background_to_device(background, device)
        a = mask_f.unsqueeze(-1)
        if opacity != 1.0:
            a = a * float(opacity)
        rendered = torch.lerp(bg, rendered, a)
    return rendered