mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-06-03 04:47:29 +08:00
270 lines
13 KiB
Python
270 lines
13 KiB
Python
# TripoSplat nodes: image -> 3D gaussian splat
|
|
|
|
import logging
|
|
|
|
import torch
|
|
import torch.nn.functional as F
|
|
from typing_extensions import override
|
|
|
|
import comfy.model_management
|
|
import comfy.nested_tensor
|
|
import comfy.patcher_extension
|
|
import comfy.utils
|
|
from comfy_api.latest import ComfyExtension, IO, Types
|
|
|
|
|
|
# Layout of the fixed noise target built by TripoSplatConditioning: a shape-code sequence of
# _Q_TOKEN_LENGTH tokens with _LATENT_CHANNELS channels each, plus one camera token of
# _CAM_CHANNELS channels.
_Q_TOKEN_LENGTH = 1 << 13  # 8192
_LATENT_CHANNELS = 16
_CAM_CHANNELS = 5

# Per-channel RGB statistics used to normalize images before the DINOv3 encoder.
_DINOV3_MEAN = [0.485, 0.456, 0.406]
_DINOV3_STD = [0.229, 0.224, 0.225]

# Inclusive bounds applied to the decoder's requested gaussian count.
_NUM_GAUSSIANS_MIN = 1 << 15  # 32768
_NUM_GAUSSIANS_MAX = 1 << 20  # 1048576
|
|
|
|
|
|
def _preprocess(image: torch.Tensor, mask: torch.Tensor, erode_radius: int, size: int) -> torch.Tensor:
    """Crop one image around its mask and composite it on black as a ``size`` x ``size`` square.

    Matches the original preprocessing pipeline:
    resize so the short side equals ``size`` -> erode alpha -> alpha bounding box ->
    square crop 1.2x the box's longer extent, centered on the box -> resize to ``size`` ->
    composite on black.

    Args:
        image: one image laid out (H, W, C) with values in [0, 1]; only the first 3
            channels are used.
        mask: the matching alpha matte, (H, W), values in [0, 1].
        erode_radius: erosion radius in pixels, applied after the initial resize; 0 disables it.
        size: side length of the square output, in pixels.

    Returns:
        A (1, size, size, 3) tensor: RGB premultiplied by the cropped alpha.

    Raises:
        ValueError: if no foreground pixel (alpha > 0) remains after erosion.
    """
    rgb = image[..., :3].clamp(0, 1).movedim(-1, 0)  # (3, H, W)
    alpha = mask.clamp(0, 1)[None]  # (1, H, W)
    rgba = torch.cat([rgb, alpha], 0)[None]  # (1, 4, H, W)

    # Scale so the shorter side becomes `size`.
    h, w = rgba.shape[-2:]
    s = size / min(w, h)
    rgba = comfy.utils.common_upscale(rgba, max(1, round(w * s)), max(1, round(h * s)), "lanczos", "disabled").clamp(0, 1)

    a = rgba[:, 3:4]
    if erode_radius > 0:
        # min filter over a (2r+1) window == morphological erosion of the alpha matte.
        a = -F.max_pool2d(-a, 2 * erode_radius + 1, stride=1, padding=erode_radius)
        rgba = torch.cat([rgba[:, :3], a], 1)

    ys, xs = torch.nonzero(a[0, 0] > 0, as_tuple=True)
    if xs.numel() == 0:
        raise ValueError("TripoSplatPreprocessImage: mask is empty (no foreground pixels).")
    x0, x1 = int(xs.min()), int(xs.max())
    y0, y1 = int(ys.min()), int(ys.max())
    cx, cy = (x0 + x1) / 2, (y0 + y1) / 2
    half = max(x1 - x0, y1 - y0) / 2 * 1.2
    left, upper, right, lower = int(cx - half), int(cy - half), int(cx + half), int(cy + half)
    # A lone foreground pixel gives half == 0 and therefore a 0x0 crop, which cannot be
    # resized; keep the crop at least one pixel wide and tall. For any larger box the
    # corners already differ by >= 1 pixel, so this never changes the normal case.
    right = max(right, left + 1)
    lower = max(lower, upper + 1)

    H, W = rgba.shape[-2:]
    crop = rgba.new_zeros((1, 4, lower - upper, right - left))  # out-of-bounds stays 0, matching PIL.crop
    sx0, sy0, sx1, sy1 = max(left, 0), max(upper, 0), min(right, W), min(lower, H)
    if sx1 > sx0 and sy1 > sy0:
        crop[:, :, sy0 - upper:sy1 - upper, sx0 - left:sx1 - left] = rgba[:, :, sy0:sy1, sx0:sx1]

    crop = comfy.utils.common_upscale(crop, size, size, "lanczos", "disabled").clamp(0, 1)
    out = (crop[:, :3] * crop[:, 3:4])[0].movedim(0, -1)  # composite over black == rgb * alpha
    return out.unsqueeze(0)  # (1, size, size, 3)
|
|
|
|
|
|
class TripoSplatPreprocessImage(IO.ComfyNode):
    """Node: crop each image around its mask and composite it on a square black canvas."""

    @classmethod
    def define_schema(cls):
        return IO.Schema(
            node_id="TripoSplatPreprocessImage",
            display_name="TripoSplat Preprocess Image",
            category="3d/conditioning",
            description="Crop center each image to a square canvas on a black background and add padding.",
            inputs=[
                IO.Image.Input("image"),
                IO.Mask.Input("mask"),
                IO.Int.Input("erode_radius", default=1, min=0, max=16,
                             tooltip="Erode the alpha matte by this pixel radius before cropping (avoids border bleed)."),
                IO.Int.Input("size", default=1024, min=256, max=4096, step=16,
                             tooltip="Square image size. The model is trained at 1024; other sizes run but are off-distribution."),
            ],
            outputs=[IO.Image.Output(display_name="image")],
        )

    @classmethod
    def execute(cls, image, mask, erode_radius, size) -> IO.NodeOutput:
        """Preprocess every image in the batch with its matching mask.

        ``size`` is rounded down to a multiple of 16 (minimum 16). The mask may be batched
        (B, H, W) or a single (H, W) matte. It is repeated to the image batch size and
        bilinearly resized to the image resolution when those differ.

        Returns:
            NodeOutput holding an IMAGE batch of shape (B, size, size, 3).
        """
        size = max(16, (int(size) // 16) * 16)  # DINOv3 patch / Flux2 VAE stride is 16
        # Normalize to (B, H, W): an unbatched (H, W) mask would otherwise have its height
        # compared against the image batch size below.
        mask = mask.reshape((-1, mask.shape[-2], mask.shape[-1]))
        if mask.shape[0] != image.shape[0]:
            mask = comfy.utils.repeat_to_batch_size(mask, image.shape[0])
        if tuple(mask.shape[1:]) != tuple(image.shape[1:3]):
            mask = F.interpolate(mask[:, None].float(), size=tuple(image.shape[1:3]), mode="bilinear", align_corners=False)[:, 0]
        prepared = torch.cat([_preprocess(image[i], mask[i], erode_radius, size) for i in range(image.shape[0])], dim=0)
        return IO.NodeOutput(prepared)
|
|
|
|
|
|
class TripoSplatConditioning(IO.ComfyNode):
    """Node: encode an image into TripoSplat conditioning plus the fixed-shape noise target."""

    @classmethod
    def define_schema(cls):
        return IO.Schema(
            node_id="TripoSplatConditioning",
            display_name="TripoSplat Conditioning",
            category="3d/conditioning",
            description="Encode the image with DINOv3 and the Flux2 VAE into TripoSplat positive/negative "
                        "conditioning, and create the fixed size noise target (latent + camera) for the KSampler",
            inputs=[
                IO.ClipVision.Input("clip_vision", tooltip="DINOv3 ViT-H/16+ image encoder"),
                IO.Vae.Input("vae", tooltip="Flux2 VAE"),
                IO.Image.Input("image"),
            ],
            outputs=[
                IO.Conditioning.Output(display_name="positive"),
                IO.Conditioning.Output(display_name="negative"),
                IO.Latent.Output(display_name="latent", tooltip="The fixed size noise target (latent +camera)."),
            ],
        )

    @classmethod
    def execute(cls, clip_vision, vae, image) -> IO.NodeOutput:
        """Build positive/negative conditioning and an all-zero latent for ``image``.

        Positive conditioning has two parts:
        - a layer-normed DINOv3 token sequence;
        - the VAE latent of the image, carried as ``reference_latents``.

        Negative conditioning is all zeros with the same shapes.

        The latent is a NestedTensor of two zero tensors:
        - a ``(B, _Q_TOKEN_LENGTH, _LATENT_CHANNELS)`` shape code;
        - a ``(B, 1, _CAM_CHANNELS)`` camera token.
        """
        # The DINOv3 normalization below is 3-channel; drop any extra (e.g. alpha) channel so
        # it broadcasts and the VAE sees RGB, like the [..., :3] slice in _preprocess.
        image = image[..., :3]
        out_device = comfy.model_management.intermediate_device()

        # feature1: DINOv3 token sequence (cls + registers + patches), ImageNet-normalized, with a final non-affine layer norm on top
        comfy.model_management.load_model_gpu(clip_vision.patcher)
        device = clip_vision.load_device
        model_dtype = next(clip_vision.model.parameters()).dtype
        img = image.movedim(-1, 1).to(device)  # (B,3,H,W) in [0,1]
        mean = torch.tensor(_DINOV3_MEAN, device=device).view(1, 3, 1, 1)
        std = torch.tensor(_DINOV3_STD, device=device).view(1, 3, 1, 1)
        img = (img - mean) / std
        seq = clip_vision.model(pixel_values=img.to(model_dtype))[0]
        feature1 = F.layer_norm(seq.float(), seq.shape[-1:]).to(out_device)

        # Second conditioning: the Flux2 VAE latent of the image, carried as a standard reference_latents entry
        ref = vae.encode(image).to(out_device)
        b = ref.shape[0]

        positive = [[feature1, {"reference_latents": [ref]}]]
        negative = [[torch.zeros_like(feature1), {"reference_latents": [torch.zeros_like(ref)]}]]

        # Fixed noise target: the latent is a constant-shape (8192, 16) shape-code + a (1, 5) camera token
        latent_seq = torch.zeros([b, _Q_TOKEN_LENGTH, _LATENT_CHANNELS], device=out_device)
        camera = torch.zeros([b, 1, _CAM_CHANNELS], device=out_device)
        samples = comfy.nested_tensor.NestedTensor((latent_seq, camera))
        return IO.NodeOutput(positive, negative, {"samples": samples})
|
|
|
|
|
|
class VAEDecodeTripoSplat(IO.ComfyNode):
    """Node: decode a sampled TripoSplat latent into a batched gaussian SPLAT."""

    @classmethod
    def define_schema(cls):
        node_inputs = [
            IO.Latent.Input("samples"),
            IO.Vae.Input("vae", tooltip="TripoSplat VAE decoder"),
            IO.Int.Input(
                "num_gaussians",
                default=262144,
                min=_NUM_GAUSSIANS_MIN,
                max=_NUM_GAUSSIANS_MAX,
                step=32,
                tooltip=(
                    "Number of gaussians to produce (rounded to a multiple of 32). "
                    "262144 matches the octree's point density; higher oversamples the same points "
                    "(denser, but no new detail) and costs proportionally more VRAM/time."
                ),
            ),
            IO.Int.Input(
                "seed",
                default=0,
                min=0,
                max=0xffffffffffffffff,
                tooltip="Seeds the octree point sampler (global RNG) for deterministic decodes.",
            ),
        ]
        return IO.Schema(
            node_id="VAEDecodeTripoSplat",
            display_name="TripoSplat Decode",
            category="3d/latent",
            description=(
                "Decode the sampled TripoSplat latent into a 3D gaussian splat. "
                "Modify the number of gaussians to vary the density."
            ),
            inputs=node_inputs,
            outputs=[IO.Splat.Output(display_name="splat")],
        )

    @classmethod
    def execute(cls, samples, vae, num_gaussians, seed) -> IO.NodeOutput:
        """Decode ``samples["samples"]`` into a SPLAT.

        Steps:
        - If the samples are nested, only the first (latent) stream is decoded; the camera
          stream is ignored.
        - The gaussian count is clamped to ``[_NUM_GAUSSIANS_MIN, _NUM_GAUSSIANS_MAX]``, then
          rounded to the nearest multiple of the decoder's ``gaussians_per_point``.
        - Point sampling uses a CPU ``torch.Generator`` seeded with ``seed``.
        - Per-item render tensors are stacked along a new leading batch dimension.
        """
        packed = samples["samples"]
        if getattr(packed, "is_nested", False):
            latent = packed.unbind()[0]  # latent stream; the camera stream is dropped
        else:
            latent = packed

        decoder = vae.first_stage_model
        per_point = decoder.gaussians_per_point
        count = max(_NUM_GAUSSIANS_MIN, min(_NUM_GAUSSIANS_MAX, int(num_gaussians)))
        if count % per_point:
            count = round(count / per_point) * per_point

        # Rough memory estimate handed to the model loader; the 4x / 10x factors come from
        # the original code. NOTE(review): confirm how these factors were derived.
        bytes_per_elem = comfy.model_management.dtype_size(vae.vae_dtype)
        token_count = latent.shape[1]
        point_count = count // per_point
        memory_required = (token_count * 4 + point_count * 10) * decoder.gs.model_channels * bytes_per_elem
        comfy.model_management.load_models_gpu([vae.patcher], memory_required=memory_required)

        latent = latent.to(device=vae.device, dtype=vae.vae_dtype)
        rng = torch.Generator(device="cpu").manual_seed(seed)
        per_item = [g.render_tensors() for g in decoder.decode(latent, num_gaussians=count, generator=rng)]
        positions, scales, rotations, opacities, sh = [torch.stack(field) for field in zip(*per_item)]
        return IO.NodeOutput(Types.SPLAT(positions, scales, rotations, opacities, sh))
|
|
|
|
|
|
class TripoSplatSamplingPreview(IO.ComfyNode):
    """Node: wrap a model's sampling so each step emits a decoded gaussian-splat preview image."""

    @classmethod
    def define_schema(cls):
        node_inputs = [
            IO.Model.Input("model"),
            IO.Vae.Input("vae", tooltip="TripoSplat VAE decoder"),
            IO.Int.Input("octree_level", default=5, min=2, max=8, advanced=True,
                         tooltip="Octree depth for the preview decode (lower = cheaper/coarser)."),
            IO.Int.Input("num_gaussians", default=16384, min=1024, max=262144, step=32,
                         tooltip="Number of gaussians to produce for the preview (rounded to a multiple of 32)."),
            IO.Float.Input("yaw", default=90.0, min=-360.0, max=360.0, step=1.0,
                           tooltip="Preview camera yaw in degrees.", advanced=True),
            IO.Float.Input("pitch", default=15.0, min=-89.0, max=89.0, step=1.0,
                           tooltip="Preview camera pitch in degrees.", advanced=True),
            IO.Int.Input("point_size", default=3, min=1, max=16,
                         tooltip=("Maximum splat radius in pixels. Each gaussian is sized from its scale and capped here; "
                                  "lower = finer/pointier, higher = chunkier.")),
        ]
        return IO.Schema(
            node_id="TripoSplatSamplingPreview",
            display_name="TripoSplat Sampling Preview",
            category="3d/latent",
            description=("Patch the TripoSplat model for the standard Ksampler node to show a live decoded "
                         "gaussian splat preview at each step."),
            inputs=node_inputs,
            outputs=[IO.Model.Output()],
        )

    @classmethod
    def execute(cls, model, vae, octree_level, num_gaussians, yaw, pitch, point_size) -> IO.NodeOutput:
        """Return a clone of ``model`` with an OUTER_SAMPLE wrapper installed for live previews.

        At each step the wrapper:
        - calls the sampler's own callback first;
        - loads the VAE (once per sampling run);
        - decodes x0 to a preview image and pushes it to a ProgressBar.

        The first exception logs a warning and turns the preview off for the rest of the run.
        """
        # Imported at execution time rather than at module import.
        from comfy.ldm.triposplat.preview import decode_x0_to_image

        preview_cfg = dict(gaussians=num_gaussians, level=octree_level, yaw=yaw, pitch=pitch,
                           point_size=point_size)

        decoder = vae.first_stage_model
        token_count = model.model.diffusion_model.q_token_length
        point_count = num_gaussians // decoder.gaussians_per_point
        bytes_per_elem = comfy.model_management.dtype_size(vae.vae_dtype)
        preview_memory = (token_count * 4 + point_count * 10) * decoder.gs.model_channels * bytes_per_elem

        # Position of `callback` in outer_sample(noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)
        callback_pos = 5

        def outer_sample_wrapper(executor, *args, **kwargs):
            args = list(args)
            passed_positionally = len(args) > callback_pos
            user_callback = args[callback_pos] if passed_positionally else kwargs.get("callback")
            # Per-run state shared with the step callback.
            preview_enabled = True
            progress_bar = None
            vae_ready = False

            def callback(step, x0, x, total_steps):
                nonlocal preview_enabled, progress_bar, vae_ready
                if user_callback is not None:
                    user_callback(step, x0, x, total_steps)
                if not preview_enabled:
                    return
                try:
                    if not vae_ready:
                        comfy.model_management.load_models_gpu([vae.patcher], memory_required=preview_memory)
                        vae_ready = True
                    img = decode_x0_to_image(vae, x0, preview_cfg)
                    if progress_bar is None:
                        progress_bar = comfy.utils.ProgressBar(total_steps)
                    progress_bar.update_absolute(step + 1, total_steps, ("JPEG", img, 512))
                except Exception as e:
                    # Best effort: a broken preview must not abort sampling.
                    logging.warning("TripoSplatSamplingPreview: preview failed, disabling ({})".format(e))
                    preview_enabled = False

            if passed_positionally:
                args[callback_pos] = callback
            else:
                kwargs["callback"] = callback
            return executor(*args, **kwargs)

        patched = model.clone()
        patched.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.OUTER_SAMPLE, "triposplat_sampling_preview", outer_sample_wrapper)
        return IO.NodeOutput(patched)
|
|
|
|
|
|
class TripoSplatExtension(ComfyExtension):
    """ComfyUI extension that registers the TripoSplat nodes."""

    @override
    async def get_node_list(self) -> list[type[IO.ComfyNode]]:
        """Return the node classes provided by this extension."""
        node_classes = (
            TripoSplatPreprocessImage,
            TripoSplatConditioning,
            VAEDecodeTripoSplat,
            TripoSplatSamplingPreview,
        )
        return list(node_classes)
|
|
|
|
|
|
async def comfy_entrypoint() -> TripoSplatExtension:
    """Module entry point: build and return the TripoSplat extension instance."""
    extension = TripoSplatExtension()
    return extension
|