mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-14 19:17:32 +08:00
initial MoGe support
This commit is contained in:
parent
aa9d2fc713
commit
919a74f819
@ -106,6 +106,7 @@ class Dino2Encoder(torch.nn.Module):
|
||||
class Dino2PatchEmbeddings(torch.nn.Module):
|
||||
def __init__(self, dim, num_channels=3, patch_size=14, image_size=518, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.patch_size = patch_size
|
||||
self.projection = operations.Conv2d(
|
||||
in_channels=num_channels,
|
||||
out_channels=dim,
|
||||
@ -125,17 +126,37 @@ class Dino2Embeddings(torch.nn.Module):
|
||||
super().__init__()
|
||||
patch_size = 14
|
||||
image_size = 518
|
||||
self.patch_size = patch_size
|
||||
|
||||
self.patch_embeddings = Dino2PatchEmbeddings(dim, patch_size=patch_size, image_size=image_size, dtype=dtype, device=device, operations=operations)
|
||||
self.position_embeddings = torch.nn.Parameter(torch.empty(1, (image_size // patch_size) ** 2 + 1, dim, dtype=dtype, device=device))
|
||||
self.cls_token = torch.nn.Parameter(torch.empty(1, 1, dim, dtype=dtype, device=device))
|
||||
self.cls_token = torch.nn.Parameter(torch.empty(1, 1, dim, dtype=dtype, device=device)) # mask_token is a pre-training param, kept only so strict loading accepts the key.
|
||||
self.mask_token = torch.nn.Parameter(torch.empty(1, dim, dtype=dtype, device=device))
|
||||
|
||||
def interpolate_pos_encoding(self, x, h_pixels, w_pixels):
|
||||
pos_embed = comfy.model_management.cast_to_device(self.position_embeddings, x.device, torch.float32)
|
||||
|
||||
class_pos = pos_embed[:, 0:1]
|
||||
patch_pos = pos_embed[:, 1:]
|
||||
N = patch_pos.shape[1]
|
||||
M = int(N ** 0.5)
|
||||
h0 = h_pixels // self.patch_size
|
||||
w0 = w_pixels // self.patch_size
|
||||
scale_factor = ((h0 + 0.1) / M, (w0 + 0.1) / M) # +0.1 matches upstream DINOv2's FP-rounding workaround so the interpolate output size lands on (h0, w0).
|
||||
|
||||
patch_pos = patch_pos.reshape(1, M, M, -1).permute(0, 3, 1, 2)
|
||||
patch_pos = torch.nn.functional.interpolate(patch_pos, scale_factor=scale_factor, mode="bicubic", antialias=False)
|
||||
patch_pos = patch_pos.permute(0, 2, 3, 1).flatten(1, 2)
|
||||
return torch.cat((class_pos, patch_pos), dim=1).to(x.dtype)
|
||||
|
||||
def forward(self, pixel_values):
|
||||
x = self.patch_embeddings(pixel_values)
|
||||
# TODO: mask_token?
|
||||
x = torch.cat((self.cls_token.to(device=x.device, dtype=x.dtype).expand(x.shape[0], -1, -1), x), dim=1)
|
||||
x = x + comfy.model_management.cast_to_device(self.position_embeddings, x.device, x.dtype)
|
||||
if x.shape[1] - 1 == self.position_embeddings.shape[1] - 1:
|
||||
x = x + comfy.model_management.cast_to_device(self.position_embeddings, x.device, x.dtype)
|
||||
else:
|
||||
h, w = pixel_values.shape[-2:]
|
||||
x = x + self.interpolate_pos_encoding(x, h, w)
|
||||
return x
|
||||
|
||||
|
||||
@ -158,3 +179,21 @@ class Dinov2Model(torch.nn.Module):
|
||||
x = self.layernorm(x)
|
||||
pooled_output = x[:, 0, :]
|
||||
return x, i, pooled_output, None
|
||||
|
||||
def get_intermediate_layers(self, pixel_values, indices, apply_norm=True):
|
||||
x = self.embeddings(pixel_values)
|
||||
optimized_attention = optimized_attention_for_device(x.device, False, small_input=True)
|
||||
n_layers = len(self.encoder.layer)
|
||||
resolved = [(i if i >= 0 else n_layers + i) for i in indices]
|
||||
target = set(resolved)
|
||||
max_idx = max(resolved)
|
||||
n_skip = 1 # skip cls token
|
||||
cache = {}
|
||||
for i, layer in enumerate(self.encoder.layer):
|
||||
x = layer(x, optimized_attention)
|
||||
if i in target:
|
||||
normed = self.layernorm(x) if apply_norm else x
|
||||
cache[i] = (normed[:, n_skip:], normed[:, 0])
|
||||
if i >= max_idx:
|
||||
break
|
||||
return [cache[i] for i in resolved]
|
||||
|
||||
0
comfy/ldm/moge/__init__.py
Normal file
0
comfy/ldm/moge/__init__.py
Normal file
200
comfy/ldm/moge/geometry.py
Normal file
200
comfy/ldm/moge/geometry.py
Normal file
@ -0,0 +1,200 @@
|
||||
"""Pure-torch + scipy geometry helpers for MoGe inference and mesh export."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from functools import partial
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from scipy.optimize import least_squares
|
||||
|
||||
def normalized_view_plane_uv(width: int, height: int, aspect_ratio: Optional[float] = None,
|
||||
dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None) -> torch.Tensor:
|
||||
"""Normalized view-plane UV coordinates with corners at +/-(W, H)/diagonal."""
|
||||
if aspect_ratio is None:
|
||||
aspect_ratio = width / height
|
||||
span_x = aspect_ratio / (1 + aspect_ratio ** 2) ** 0.5
|
||||
span_y = 1.0 / (1 + aspect_ratio ** 2) ** 0.5
|
||||
u = torch.linspace(-span_x * (width - 1) / width, span_x * (width - 1) / width, width, dtype=dtype, device=device)
|
||||
v = torch.linspace(-span_y * (height - 1) / height, span_y * (height - 1) / height, height, dtype=dtype, device=device)
|
||||
u, v = torch.meshgrid(u, v, indexing="xy")
|
||||
return torch.stack([u, v], dim=-1)
|
||||
|
||||
|
||||
def intrinsics_from_focal_center(fx: torch.Tensor, fy: torch.Tensor, cx: torch.Tensor, cy: torch.Tensor) -> torch.Tensor:
|
||||
"""Assemble (..., 3, 3) intrinsics from broadcastable fx, fy, cx, cy."""
|
||||
fx, fy, cx, cy = [torch.as_tensor(v) for v in (fx, fy, cx, cy)]
|
||||
fx, fy, cx, cy = torch.broadcast_tensors(fx, fy, cx, cy)
|
||||
zero = torch.zeros_like(fx)
|
||||
one = torch.ones_like(fx)
|
||||
return torch.stack([
|
||||
torch.stack([fx, zero, cx], dim=-1),
|
||||
torch.stack([zero, fy, cy], dim=-1),
|
||||
torch.stack([zero, zero, one], dim=-1),
|
||||
], dim=-2)
|
||||
|
||||
|
||||
def depth_map_to_point_map(depth: torch.Tensor, intrinsics: torch.Tensor) -> torch.Tensor:
|
||||
"""Back-project a (..., H, W) depth map through K^-1 to (..., H, W, 3) camera-space points.
|
||||
|
||||
Intrinsics use normalized image coords (x in [0, 1] left->right, y in [0, 1] top->bottom).
|
||||
"""
|
||||
H, W = depth.shape[-2:]
|
||||
device, dtype = depth.device, depth.dtype
|
||||
u = (torch.arange(W, dtype=dtype, device=device) + 0.5) / W
|
||||
v = (torch.arange(H, dtype=dtype, device=device) + 0.5) / H
|
||||
grid_v, grid_u = torch.meshgrid(v, u, indexing="ij")
|
||||
pix = torch.stack([grid_u, grid_v, torch.ones_like(grid_u)], dim=-1)
|
||||
K_inv = torch.linalg.inv(intrinsics)
|
||||
rays = torch.einsum("...ij,hwj->...hwi", K_inv, pix)
|
||||
return rays * depth.unsqueeze(-1)
|
||||
|
||||
|
||||
def _solve_optimal_focal_shift(uv: np.ndarray, xyz: np.ndarray) -> Tuple[float, float]:
|
||||
uv = uv.reshape(-1, 2)
|
||||
xy = xyz[..., :2].reshape(-1, 2)
|
||||
z = xyz[..., 2].reshape(-1)
|
||||
|
||||
def fn(uv_, xy_, z_, shift):
|
||||
xy_proj = xy_ / (z_ + shift)[:, None]
|
||||
f = (xy_proj * uv_).sum() / np.square(xy_proj).sum()
|
||||
return (f * xy_proj - uv_).ravel()
|
||||
|
||||
sol = least_squares(partial(fn, uv, xy, z), x0=0.0, ftol=1e-3, method="lm")
|
||||
optim_shift = float(np.asarray(sol["x"]).squeeze())
|
||||
xy_proj = xy / (z + optim_shift)[:, None]
|
||||
optim_focal = float((xy_proj * uv).sum() / np.square(xy_proj).sum())
|
||||
return optim_shift, optim_focal
|
||||
|
||||
|
||||
def _solve_optimal_shift(uv: np.ndarray, xyz: np.ndarray, focal: float) -> float:
|
||||
uv = uv.reshape(-1, 2)
|
||||
xy = xyz[..., :2].reshape(-1, 2)
|
||||
z = xyz[..., 2].reshape(-1)
|
||||
|
||||
def fn(uv_, xy_, z_, shift):
|
||||
xy_proj = xy_ / (z_ + shift)[:, None]
|
||||
return (focal * xy_proj - uv_).ravel()
|
||||
|
||||
sol = least_squares(partial(fn, uv, xy, z), x0=0.0, ftol=1e-3, method="lm")
|
||||
return float(np.asarray(sol["x"]).squeeze())
|
||||
|
||||
|
||||
def recover_focal_shift(points: torch.Tensor, mask: Optional[torch.Tensor] = None,
|
||||
focal: Optional[torch.Tensor] = None, downsample_size: Tuple[int, int] = (64, 64)
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Recover the focal length and z-shift that turn ``points`` into a metric point map.
|
||||
|
||||
Optical center is at the image center; returned focal is relative to half the image diagonal.
|
||||
Returns ``(focal, shift)`` on the same device/dtype as ``points``.
|
||||
"""
|
||||
shape = points.shape
|
||||
H, W = shape[-3], shape[-2]
|
||||
points_b = points.reshape(-1, H, W, 3)
|
||||
mask_b = None if mask is None else mask.reshape(-1, H, W)
|
||||
focal_b = None if focal is None else focal.reshape(-1)
|
||||
|
||||
uv = normalized_view_plane_uv(W, H, dtype=points.dtype, device=points.device)
|
||||
|
||||
points_lr = F.interpolate(points_b.permute(0, 3, 1, 2), downsample_size, mode="nearest").permute(0, 2, 3, 1)
|
||||
uv_lr = F.interpolate(uv.unsqueeze(0).permute(0, 3, 1, 2), downsample_size, mode="nearest").squeeze(0).permute(1, 2, 0)
|
||||
mask_lr = None
|
||||
if mask_b is not None:
|
||||
mask_lr = F.interpolate(mask_b.to(torch.float32).unsqueeze(1), downsample_size, mode="nearest").squeeze(1) > 0
|
||||
|
||||
uv_np = uv_lr.detach().cpu().numpy()
|
||||
points_np = points_lr.detach().cpu().numpy()
|
||||
mask_np = None if mask_lr is None else mask_lr.detach().cpu().numpy()
|
||||
focal_np = None if focal_b is None else focal_b.detach().cpu().numpy()
|
||||
|
||||
out_focal: list = []
|
||||
out_shift: list = []
|
||||
for i in range(points_b.shape[0]):
|
||||
if mask_np is None:
|
||||
xyz_i = points_np[i].reshape(-1, 3)
|
||||
uv_i = uv_np.reshape(-1, 2)
|
||||
else:
|
||||
sel = mask_np[i]
|
||||
if sel.sum() < 2:
|
||||
out_focal.append(1.0)
|
||||
out_shift.append(0.0)
|
||||
continue
|
||||
xyz_i = points_np[i][sel]
|
||||
uv_i = uv_np[sel]
|
||||
if focal_np is None:
|
||||
shift_i, focal_i = _solve_optimal_focal_shift(uv_i, xyz_i)
|
||||
out_focal.append(focal_i)
|
||||
else:
|
||||
shift_i = _solve_optimal_shift(uv_i, xyz_i, float(focal_np[i]))
|
||||
out_shift.append(shift_i)
|
||||
|
||||
shift_t = torch.tensor(out_shift, device=points.device, dtype=points.dtype).reshape(shape[:-3])
|
||||
if focal is None:
|
||||
focal_t = torch.tensor(out_focal, device=points.device, dtype=points.dtype).reshape(shape[:-3])
|
||||
else:
|
||||
focal_t = focal.reshape(shape[:-3])
|
||||
return focal_t, shift_t
|
||||
|
||||
|
||||
def depth_map_edge(depth: torch.Tensor, atol: Optional[float] = None, rtol: Optional[float] = None, kernel_size: int = 3) -> torch.Tensor:
|
||||
"""Per-pixel boolean: True where the local depth window's max-min span exceeds atol or rtol*depth."""
|
||||
shape = depth.shape
|
||||
d = depth.reshape(-1, 1, *shape[-2:])
|
||||
pad = kernel_size // 2
|
||||
diff = F.max_pool2d(d, kernel_size, stride=1, padding=pad) + F.max_pool2d(-d, kernel_size, stride=1, padding=pad)
|
||||
edge = torch.zeros_like(d, dtype=torch.bool)
|
||||
if atol is not None:
|
||||
edge |= diff > atol
|
||||
if rtol is not None:
|
||||
edge |= (diff / d.clamp_min(1e-6)).nan_to_num_() > rtol
|
||||
return edge.reshape(*shape)
|
||||
|
||||
|
||||
def triangulate_grid_mesh(points: torch.Tensor, mask: Optional[torch.Tensor] = None, decimation: int = 1, discontinuity_threshold: float = 0.04,
|
||||
depth: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""Triangulate a (H, W, 3) point map into ``(vertices, faces, uvs)`` on CPU.
|
||||
|
||||
Vertices: pixels with finite coords (passing optional ``mask``). Quads with four valid corners
|
||||
become two triangles. ``depth`` overrides the scalar used for the rtol edge check; pass radial
|
||||
depth for panoramas (the default ``points[..., 2]`` goes negative below the equator).
|
||||
"""
|
||||
points = points.detach().cpu()
|
||||
finite = torch.isfinite(points).all(dim=-1)
|
||||
if mask is None:
|
||||
mask = finite
|
||||
else:
|
||||
mask = mask.detach().cpu().to(torch.bool) & finite
|
||||
|
||||
if discontinuity_threshold > 0:
|
||||
d = depth.detach().cpu() if depth is not None else points[..., 2]
|
||||
# Replace inf with 0 so max-pool doesn't poison neighbourhoods (mask above already excludes those pixels).
|
||||
d_finite = torch.where(finite, d, torch.zeros_like(d))
|
||||
edge = depth_map_edge(d_finite, rtol=discontinuity_threshold)
|
||||
mask = mask & ~edge
|
||||
|
||||
if decimation > 1:
|
||||
points = points[::decimation, ::decimation].contiguous()
|
||||
mask = mask[::decimation, ::decimation].contiguous()
|
||||
H, W = points.shape[:2]
|
||||
|
||||
flat_mask = mask.reshape(-1)
|
||||
idx = torch.full((H * W,), -1, dtype=torch.long)
|
||||
n_valid = int(flat_mask.sum().item())
|
||||
idx[flat_mask] = torch.arange(n_valid, dtype=torch.long)
|
||||
idx = idx.reshape(H, W)
|
||||
|
||||
vertices = points.reshape(-1, 3)[flat_mask].contiguous()
|
||||
|
||||
yy, xx = torch.meshgrid(torch.arange(H), torch.arange(W), indexing="ij")
|
||||
u = xx.float() / max(W - 1, 1)
|
||||
v = yy.float() / max(H - 1, 1)
|
||||
uvs = torch.stack([u, v], dim=-1).reshape(-1, 2)[flat_mask].contiguous()
|
||||
|
||||
a, b, c, d = idx[:-1, :-1], idx[:-1, 1:], idx[1:, 1:], idx[1:, :-1]
|
||||
quad_ok = (a >= 0) & (b >= 0) & (c >= 0) & (d >= 0)
|
||||
a, b, c, d = a[quad_ok], b[quad_ok], c[quad_ok], d[quad_ok]
|
||||
faces = torch.cat([torch.stack([a, b, c], dim=-1), torch.stack([a, c, d], dim=-1)], dim=0).contiguous()
|
||||
return vertices, faces, uvs
|
||||
349
comfy/ldm/moge/model.py
Normal file
349
comfy/ldm/moge/model.py
Normal file
@ -0,0 +1,349 @@
|
||||
"""MoGe v1 / v2 inference modules and a state-dict-driven builder.
|
||||
|
||||
V1: DINOv2 backbone + multi-output head (points, mask).
|
||||
V2: DINOv2 encoder + neck + per-output heads (points, mask, normal, optional metric-scale MLP).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from numbers import Number
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
import comfy.ops
|
||||
import comfy.model_management
|
||||
import comfy.model_patcher
|
||||
|
||||
from comfy.image_encoders.dino2 import Dinov2Model
|
||||
|
||||
from .geometry import depth_map_to_point_map, intrinsics_from_focal_center, recover_focal_shift
|
||||
from .modules import ConvStack, DINOv2Encoder, HeadV1, MLP, _view_plane_uv_grid
|
||||
|
||||
|
||||
def _remap_points(points: torch.Tensor) -> torch.Tensor:
|
||||
"""Apply the ``exp`` remap: z -> exp(z), xy stays linear and gets scaled by the new z."""
|
||||
xy, z = points.split([2, 1], dim=-1)
|
||||
z = torch.exp(z)
|
||||
return torch.cat([xy * z, z], dim=-1)
|
||||
|
||||
|
||||
def _detect_dinov2(sd: dict, prefix: str) -> Dict[str, Any]:
|
||||
# All shipped MoGe checkpoints use plain DINOv2
|
||||
hidden = sd[prefix + "embeddings.cls_token"].shape[-1]
|
||||
layer_prefix = prefix + "encoder.layer."
|
||||
depth = 1 + max(int(k[len(layer_prefix):].split(".")[0]) for k in sd if k.startswith(layer_prefix))
|
||||
return {
|
||||
"hidden_size": hidden,
|
||||
"num_attention_heads": hidden // 64,
|
||||
"num_hidden_layers": depth,
|
||||
"layer_norm_eps": 1e-6,
|
||||
"use_swiglu_ffn": False,
|
||||
}
|
||||
|
||||
|
||||
class MoGeModelV1(nn.Module):
|
||||
"""MoGe v1: DINOv2 backbone + HeadV1 (points, mask)."""
|
||||
|
||||
image_mean: torch.Tensor
|
||||
image_std: torch.Tensor
|
||||
|
||||
intermediate_layers = 4
|
||||
num_tokens_range: Tuple[Number, Number] = (1200, 2500)
|
||||
mask_threshold = 0.5
|
||||
|
||||
def __init__(self, backbone: Dict[str, Any], dim_upsample: List[int] = (256, 128, 128),
|
||||
num_res_blocks: int = 1, dim_times_res_block_hidden: int = 1,
|
||||
dtype=None, device=None, operations=comfy.ops.manual_cast):
|
||||
super().__init__()
|
||||
self.backbone = Dinov2Model(backbone, dtype, device, operations)
|
||||
self.head = HeadV1(dim_in=backbone["hidden_size"], dim_upsample=list(dim_upsample),
|
||||
num_res_blocks=num_res_blocks, dim_times_res_block_hidden=dim_times_res_block_hidden,
|
||||
dtype=dtype, device=device, operations=operations)
|
||||
self.register_buffer("image_mean", torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
|
||||
self.register_buffer("image_std", torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
|
||||
|
||||
def forward(self, image: torch.Tensor, num_tokens: int) -> Dict[str, torch.Tensor]:
|
||||
H, W = image.shape[-2:]
|
||||
resize = ((num_tokens * 14 ** 2) / (H * W)) ** 0.5
|
||||
rh, rw = int(H * resize), int(W * resize)
|
||||
x = F.interpolate(image, (rh, rw), mode="bicubic", align_corners=False, antialias=True)
|
||||
x = (x - self.image_mean) / self.image_std
|
||||
x14 = F.interpolate(x, (rh // 14 * 14, rw // 14 * 14), mode="bilinear", align_corners=False, antialias=True)
|
||||
|
||||
n_layers = len(self.backbone.encoder.layer)
|
||||
indices = list(range(n_layers - self.intermediate_layers, n_layers))
|
||||
feats = self.backbone.get_intermediate_layers(x14, indices, apply_norm=True)
|
||||
|
||||
points, mask = self.head(feats, x)
|
||||
points = F.interpolate(points.float(), (H, W), mode="bilinear", align_corners=False)
|
||||
points = _remap_points(points.permute(0, 2, 3, 1))
|
||||
|
||||
mask = F.interpolate(mask.float(), (H, W), mode="bilinear", align_corners=False).squeeze(1)
|
||||
|
||||
return {"points": points, "mask": mask}
|
||||
|
||||
@classmethod
|
||||
def from_state_dict(cls, sd, dtype=None, device=None, operations=comfy.ops.manual_cast):
|
||||
"""Detect the v1 head config from sd, build a model, and load weights."""
|
||||
sd = _remap_state_dict(sd)
|
||||
n_up = 1 + max(int(k.split(".")[2]) for k in sd if k.startswith("head.upsample_blocks."))
|
||||
dim_upsample = [sd[f"head.upsample_blocks.{i}.0.0.weight"].shape[1] for i in range(n_up)]
|
||||
# Each upsample stage is Sequential[upsampler, *res_blocks]; count res blocks at level 0.
|
||||
num_res_blocks = max({int(k.split(".")[3]) for k in sd if k.startswith("head.upsample_blocks.0.")})
|
||||
hidden_out = sd["head.upsample_blocks.0.1.layers.2.weight"].shape[0]
|
||||
dim_times = max(hidden_out // dim_upsample[0], 1)
|
||||
model = cls(backbone=_detect_dinov2(sd, prefix="backbone."),
|
||||
dim_upsample=dim_upsample, num_res_blocks=num_res_blocks, dim_times_res_block_hidden=dim_times,
|
||||
dtype=dtype, device=device, operations=operations)
|
||||
model.load_state_dict(sd, strict=True)
|
||||
return model
|
||||
|
||||
|
||||
class MoGeModelV2(nn.Module):
|
||||
"""MoGe v2: DINOv2 encoder + neck + per-output heads (points/mask/normal/metric-scale)."""
|
||||
|
||||
intermediate_layers = 4
|
||||
num_tokens_range: Tuple[Number, Number] = (1200, 3600)
|
||||
|
||||
def __init__(self,
|
||||
encoder: Dict[str, Any],
|
||||
neck: Dict[str, Any],
|
||||
points_head: Dict[str, Any],
|
||||
mask_head: Dict[str, Any],
|
||||
scale_head: Dict[str, Any],
|
||||
normal_head: Optional[Dict[str, Any]] = None,
|
||||
dtype=None, device=None, operations=comfy.ops.manual_cast):
|
||||
super().__init__()
|
||||
self.encoder = DINOv2Encoder(**encoder, dtype=dtype, device=device, operations=operations)
|
||||
self.neck = ConvStack(**neck, dtype=dtype, device=device, operations=operations)
|
||||
self.points_head = ConvStack(**points_head, dtype=dtype, device=device, operations=operations)
|
||||
self.mask_head = ConvStack(**mask_head, dtype=dtype, device=device, operations=operations)
|
||||
self.scale_head = MLP(**scale_head, dtype=dtype, device=device, operations=operations)
|
||||
if normal_head is not None:
|
||||
self.normal_head = ConvStack(**normal_head, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
def forward(self, image: torch.Tensor, num_tokens: int) -> Dict[str, torch.Tensor]:
|
||||
B, _, H, W = image.shape
|
||||
device, dtype = image.device, image.dtype
|
||||
aspect_ratio = W / H
|
||||
base_h = round((num_tokens / aspect_ratio) ** 0.5)
|
||||
base_w = round((num_tokens * aspect_ratio) ** 0.5)
|
||||
|
||||
feat_top, cls_token = self.encoder(image, base_h, base_w, return_class_token=True)
|
||||
|
||||
# 5-level pyramid: feat at level 0 concatenated with UV, other levels UV-only.
|
||||
levels = [_view_plane_uv_grid(B, base_h * (2 ** L), base_w * (2 ** L), aspect_ratio, dtype, device)
|
||||
for L in range(5)]
|
||||
levels[0] = torch.cat([feat_top, levels[0]], dim=1)
|
||||
|
||||
feats = self.neck(levels)
|
||||
|
||||
def _resize(v):
|
||||
return F.interpolate(v, (H, W), mode="bilinear", align_corners=False)
|
||||
|
||||
points = _remap_points(_resize(self.points_head(feats)[-1]).permute(0, 2, 3, 1))
|
||||
mask = _resize(self.mask_head(feats)[-1]).squeeze(1).sigmoid()
|
||||
metric_scale = self.scale_head(cls_token).squeeze(1).exp()
|
||||
|
||||
result = {"points": points, "mask": mask, "metric_scale": metric_scale}
|
||||
if hasattr(self, "normal_head"):
|
||||
normal = _resize(self.normal_head(feats)[-1])
|
||||
result["normal"] = F.normalize(normal.permute(0, 2, 3, 1), dim=-1)
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def from_state_dict(cls, sd, dtype=None, device=None, operations=comfy.ops.manual_cast):
|
||||
"""Detect the v2 encoder/neck/heads config from ``sd``, build a model, and load weights."""
|
||||
sd = _remap_state_dict(sd)
|
||||
backbone = _detect_dinov2(sd, prefix="encoder.backbone.")
|
||||
depth = backbone["num_hidden_layers"]
|
||||
n = cls.intermediate_layers
|
||||
encoder = {
|
||||
"backbone": backbone,
|
||||
"intermediate_layers": [(depth // n) * (i + 1) - 1 for i in range(n)],
|
||||
"dim_out": sd["encoder.output_projections.0.weight"].shape[0],
|
||||
}
|
||||
# scale_head is an MLP: Sequential of [Linear, ReLU, ..., Linear]; Linear weight is (out, in).
|
||||
scale_idxs = sorted({int(k.split(".")[1]) for k in sd if k.startswith("scale_head.")})
|
||||
scale_first = sd[f"scale_head.{scale_idxs[0]}.weight"]
|
||||
cfg: Dict[str, Any] = {
|
||||
"encoder": encoder,
|
||||
"neck": cls._detect_convstack(sd, "neck."),
|
||||
"points_head": cls._detect_convstack(sd, "points_head."),
|
||||
"mask_head": cls._detect_convstack(sd, "mask_head."),
|
||||
"scale_head": {"dims": [scale_first.shape[1]] + [sd[f"scale_head.{i}.weight"].shape[0] for i in scale_idxs]},
|
||||
}
|
||||
if any(k.startswith("normal_head.") for k in sd):
|
||||
cfg["normal_head"] = cls._detect_convstack(sd, "normal_head.")
|
||||
model = cls(**cfg, dtype=dtype, device=device, operations=operations)
|
||||
model.load_state_dict(sd, strict=True)
|
||||
return model
|
||||
|
||||
@staticmethod
|
||||
def _detect_convstack(sd: dict, prefix: str) -> Dict[str, Any]:
|
||||
"""Reconstruct a ConvStack config from the keys under ``prefix``"""
|
||||
in_keys = [k for k in sd if k.startswith(f"{prefix}input_blocks.") and k.endswith(".weight")]
|
||||
n = 1 + max(int(k[len(f"{prefix}input_blocks."):].split(".")[0]) for k in in_keys)
|
||||
|
||||
in_shapes = [sd[f"{prefix}input_blocks.{i}.weight"].shape for i in range(n)]
|
||||
has_out = lambda i: f"{prefix}output_blocks.{i}.weight" in sd
|
||||
has_norm = f"{prefix}res_blocks.0.0.layers.0.weight" in sd
|
||||
|
||||
def num_res_at(i):
|
||||
rb_prefix = f"{prefix}res_blocks.{i}."
|
||||
return len({int(k[len(rb_prefix):].split(".")[0]) for k in sd if k.startswith(rb_prefix)})
|
||||
|
||||
return {
|
||||
"dim_in": [s[1] for s in in_shapes],
|
||||
"dim_res_blocks": [s[0] for s in in_shapes],
|
||||
"dim_out": [sd[f"{prefix}output_blocks.{i}.weight"].shape[0] if has_out(i) else None for i in range(n)],
|
||||
"num_res_blocks": [num_res_at(i) for i in range(n)],
|
||||
"resamplers": ["conv_transpose" if f"{prefix}resamplers.{i}.0.weight" in sd else "bilinear"
|
||||
for i in range(n - 1)],
|
||||
"res_block_in_norm": "layer_norm" if has_norm else "none",
|
||||
"res_block_hidden_norm": "group_norm" if has_norm else "none",
|
||||
}
|
||||
|
||||
|
||||
# Translate the Meta-style DINOv2 keys MoGe ships to the naming ComfyUI DINOv2 port expects,
|
||||
# and split each fused qkv tensor into Q/K/V.
|
||||
_DINOV2_TOPLEVEL_RENAMES = {
|
||||
"patch_embed.proj.weight": "embeddings.patch_embeddings.projection.weight",
|
||||
"patch_embed.proj.bias": "embeddings.patch_embeddings.projection.bias",
|
||||
"cls_token": "embeddings.cls_token",
|
||||
"pos_embed": "embeddings.position_embeddings",
|
||||
"register_tokens": "embeddings.register_tokens",
|
||||
"mask_token": "embeddings.mask_token",
|
||||
"norm.weight": "layernorm.weight",
|
||||
"norm.bias": "layernorm.bias",
|
||||
}
|
||||
_DINOV2_BLOCK_RENAMES = [
|
||||
("ls1.gamma", "layer_scale1.lambda1"),
|
||||
("ls2.gamma", "layer_scale2.lambda1"),
|
||||
("attn.proj.", "attention.output.dense."),
|
||||
("mlp.w12.", "mlp.weights_in."),
|
||||
("mlp.w3.", "mlp.weights_out."),
|
||||
]
|
||||
|
||||
|
||||
def _remap_state_dict(sd: dict) -> dict:
|
||||
"""Unwrap the upstream ``{"model": ..., "model_config": ...}`` envelope and remap DINOv2 keys"""
|
||||
if "model" in sd and "model_config" in sd:
|
||||
sd = sd["model"]
|
||||
prefix = "encoder.backbone." if any(k.startswith("encoder.backbone.") for k in sd) else "backbone."
|
||||
out: dict = {}
|
||||
for k, v in sd.items():
|
||||
if not k.startswith(prefix):
|
||||
out[k] = v
|
||||
continue
|
||||
rel = k[len(prefix):]
|
||||
if rel in _DINOV2_TOPLEVEL_RENAMES:
|
||||
out[prefix + _DINOV2_TOPLEVEL_RENAMES[rel]] = v
|
||||
continue
|
||||
if not rel.startswith("blocks."):
|
||||
out[k] = v
|
||||
continue
|
||||
_, idx, sub = rel.split(".", 2)
|
||||
if sub in ("attn.qkv.weight", "attn.qkv.bias"):
|
||||
tail = sub.rsplit(".", 1)[1]
|
||||
q, kw, vw = v.chunk(3, dim=0)
|
||||
base = f"{prefix}encoder.layer.{idx}.attention.attention"
|
||||
out[f"{base}.query.{tail}"] = q
|
||||
out[f"{base}.key.{tail}"] = kw
|
||||
out[f"{base}.value.{tail}"] = vw
|
||||
continue
|
||||
for old, new in _DINOV2_BLOCK_RENAMES:
|
||||
sub = sub.replace(old, new)
|
||||
out[f"{prefix}encoder.layer.{idx}.{sub}"] = v
|
||||
return out
|
||||
|
||||
|
||||
def build_from_state_dict(sd: dict, dtype=None, device=None, operations=comfy.ops.manual_cast) -> nn.Module:
|
||||
"""Dispatch to v1 or v2 based on the DINOv2 backbone prefix."""
|
||||
sd = _remap_state_dict(sd)
|
||||
cls = MoGeModelV2 if any(k.startswith("encoder.backbone.") for k in sd) else MoGeModelV1
|
||||
return cls.from_state_dict(sd, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
|
||||
class MoGeModel:
|
||||
"""Loaded MoGe model + ComfyUI memory management."""
|
||||
|
||||
def __init__(self, state_dict: dict):
|
||||
self.load_device = comfy.model_management.text_encoder_device()
|
||||
offload_device = comfy.model_management.text_encoder_offload_device()
|
||||
self.dtype = comfy.model_management.text_encoder_dtype(self.load_device)
|
||||
|
||||
self.model = build_from_state_dict(state_dict, dtype=self.dtype, device=offload_device, operations=comfy.ops.manual_cast).eval()
|
||||
self.patcher = comfy.model_patcher.CoreModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
|
||||
self.version = "v2" if hasattr(self.model, "encoder") else "v1"
|
||||
self.mask_threshold = float(getattr(self.model, "mask_threshold", 0.5))
|
||||
nt = getattr(self.model, "num_tokens_range", (1200, 2500 if self.version == "v1" else 3600))
|
||||
self.num_tokens_range = (int(nt[0]), int(nt[1]))
|
||||
|
||||
def infer(self, image: torch.Tensor, num_tokens: Optional[int] = None,
|
||||
resolution_level: int = 9, fov_x: Optional[Union[Number, torch.Tensor]] = None,
|
||||
force_projection: bool = True, apply_mask: bool = True,
|
||||
apply_metric_scale: bool = True
|
||||
) -> Dict[str, torch.Tensor]:
|
||||
"""Run a single MoGe forward + post-process pass. ``image`` is (B, 3, H, W) in [0, 1]."""
|
||||
comfy.model_management.load_model_gpu(self.patcher)
|
||||
image = image.to(device=self.load_device, dtype=self.dtype)
|
||||
H, W = image.shape[-2:]
|
||||
aspect_ratio = W / H
|
||||
|
||||
if num_tokens is None:
|
||||
lo, hi = self.num_tokens_range
|
||||
num_tokens = int(lo + (resolution_level / 9) * (hi - lo))
|
||||
|
||||
out = self.model.forward(image, num_tokens=num_tokens)
|
||||
points = out["points"].float() # recover_focal_shift goes through scipy on CPU; needs fp32.
|
||||
mask_binary = out["mask"] > self.mask_threshold
|
||||
normal = out.get("normal")
|
||||
metric_scale = out.get("metric_scale")
|
||||
|
||||
diag = (1 + aspect_ratio ** 2) ** 0.5
|
||||
|
||||
def focal_from_fov_deg(deg):
|
||||
fov = torch.as_tensor(deg, device=points.device, dtype=points.dtype)
|
||||
return aspect_ratio / diag / torch.tan(torch.deg2rad(fov / 2))
|
||||
|
||||
if fov_x is None:
|
||||
focal, shift = recover_focal_shift(points, mask_binary)
|
||||
# Fall back to 60 deg FoV when the least-squares solver flips the focal sign.
|
||||
bad = ~torch.isfinite(focal) | (focal <= 0)
|
||||
if bool(bad.any()):
|
||||
focal = torch.where(bad, focal_from_fov_deg(60.0), focal)
|
||||
_, shift = recover_focal_shift(points, mask_binary, focal=focal)
|
||||
else:
|
||||
focal = focal_from_fov_deg(fov_x).expand(points.shape[0])
|
||||
_, shift = recover_focal_shift(points, mask_binary, focal=focal)
|
||||
|
||||
f_diag = focal / 2 * diag
|
||||
half = torch.tensor(0.5, device=points.device, dtype=points.dtype)
|
||||
intrinsics = intrinsics_from_focal_center(f_diag / aspect_ratio, f_diag, half, half)
|
||||
points[..., 2] = points[..., 2] + shift[..., None, None]
|
||||
# v2 only: filter mask by depth>0 to drop metric-scale negative-depth artifacts.
|
||||
if self.version == "v2":
|
||||
mask_binary = mask_binary & (points[..., 2] > 0)
|
||||
depth = points[..., 2].clone()
|
||||
|
||||
if force_projection:
|
||||
points = depth_map_to_point_map(depth, intrinsics=intrinsics)
|
||||
|
||||
if apply_metric_scale and metric_scale is not None:
|
||||
points = points * metric_scale[:, None, None, None]
|
||||
depth = depth * metric_scale[:, None, None]
|
||||
|
||||
if apply_mask:
|
||||
points = torch.where(mask_binary[..., None], points, torch.full_like(points, float("inf")))
|
||||
depth = torch.where(mask_binary, depth, torch.full_like(depth, float("inf")))
|
||||
if normal is not None:
|
||||
normal = torch.where(mask_binary[..., None], normal, torch.zeros_like(normal))
|
||||
|
||||
result = {"points": points, "depth": depth, "intrinsics": intrinsics, "mask": mask_binary}
|
||||
if normal is not None:
|
||||
result["normal"] = normal
|
||||
return result
|
||||
204
comfy/ldm/moge/modules.py
Normal file
204
comfy/ldm/moge/modules.py
Normal file
@ -0,0 +1,204 @@
|
||||
"""Building blocks for MoGe: residual conv stack, resamplers, MLP, DINOv2 encoder, v1 head."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import List, Optional, Sequence, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
import comfy.ops
|
||||
from comfy.image_encoders.dino2 import Dinov2Model
|
||||
|
||||
from .geometry import normalized_view_plane_uv
|
||||
|
||||
|
||||
def _conv2d(operations, c_in: int, c_out: int, k: int = 3, *, dtype=None, device=None):
|
||||
return operations.Conv2d(c_in, c_out, kernel_size=k, padding=k // 2, padding_mode="replicate", dtype=dtype, device=device)
|
||||
|
||||
|
||||
def _view_plane_uv_grid(batch: int, height: int, width: int, aspect_ratio: float, dtype, device) -> torch.Tensor:
|
||||
"""Batched normalized view-plane UV grid as a (B, 2, H, W) tensor."""
|
||||
uv = normalized_view_plane_uv(width, height, aspect_ratio=aspect_ratio, dtype=dtype, device=device)
|
||||
return uv.permute(2, 0, 1).unsqueeze(0).expand(batch, -1, -1, -1)
|
||||
|
||||
|
||||
def _concat_view_plane_uv(x: torch.Tensor, aspect_ratio: float) -> torch.Tensor:
|
||||
"""Append a 2-channel normalized view-plane UV grid to x along the channel dim."""
|
||||
uv = _view_plane_uv_grid(x.shape[0], x.shape[-2], x.shape[-1], aspect_ratio, x.dtype, x.device)
|
||||
return torch.cat([x, uv], dim=1)
|
||||
|
||||
|
||||
class ResidualConvBlock(nn.Module):
|
||||
def __init__(self, channels: int, hidden_channels: Optional[int] = None, in_norm: str = "layer_norm", hidden_norm: str = "group_norm",
|
||||
dtype=None, device=None, operations=comfy.ops.manual_cast):
|
||||
super().__init__()
|
||||
hidden_channels = hidden_channels if hidden_channels is not None else channels
|
||||
|
||||
in_norm_layer = operations.GroupNorm(1, channels) if in_norm == "layer_norm" else nn.Identity()
|
||||
hidden_norm_layer = (operations.GroupNorm(max(hidden_channels // 32, 1), hidden_channels)
|
||||
if hidden_norm == "group_norm" else nn.Identity())
|
||||
|
||||
self.layers = nn.Sequential(
|
||||
in_norm_layer, nn.ReLU(), _conv2d(operations, channels, hidden_channels, dtype=dtype, device=device),
|
||||
hidden_norm_layer, nn.ReLU(), _conv2d(operations, hidden_channels, channels, dtype=dtype, device=device),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.layers(x) + x
|
||||
|
||||
|
||||
class Resampler(nn.Sequential):
|
||||
"""2x upsampler: ConvTranspose2d(2x2) or bilinear upsample, followed by a 3x3 conv."""
|
||||
|
||||
def __init__(self, in_channels: int, out_channels: int, type_: str, dtype=None, device=None, operations=comfy.ops.manual_cast):
|
||||
if type_ == "conv_transpose":
|
||||
up = operations.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2, dtype=dtype, device=device)
|
||||
conv_in = out_channels
|
||||
else: # "bilinear"
|
||||
up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False)
|
||||
conv_in = in_channels
|
||||
super().__init__(up, _conv2d(operations, conv_in, out_channels, dtype=dtype, device=device))
|
||||
|
||||
|
||||
class MLP(nn.Sequential):
|
||||
def __init__(self, dims: Sequence[int], dtype=None, device=None, operations=comfy.ops.manual_cast):
|
||||
layers = []
|
||||
for d_in, d_out in zip(dims[:-2], dims[1:-1]):
|
||||
layers.append(operations.Linear(d_in, d_out, dtype=dtype, device=device))
|
||||
layers.append(nn.ReLU(inplace=True))
|
||||
layers.append(operations.Linear(dims[-2], dims[-1], dtype=dtype, device=device))
|
||||
super().__init__(*layers)
|
||||
|
||||
|
||||
class ConvStack(nn.Module):
|
||||
def __init__(self, dim_in: List[Optional[int]], dim_res_blocks: List[int], dim_out: List[Optional[int]], resamplers: List[str],
|
||||
num_res_blocks: List[int], dim_times_res_block_hidden: int = 1, res_block_in_norm: str = "layer_norm", res_block_hidden_norm: str = "group_norm",
|
||||
dtype=None, device=None, operations=comfy.ops.manual_cast):
|
||||
super().__init__()
|
||||
|
||||
self.input_blocks = nn.ModuleList([
|
||||
(_conv2d(operations, d_in, d_res, k=1, dtype=dtype, device=device)
|
||||
if d_in is not None else nn.Identity())
|
||||
for d_in, d_res in zip(dim_in, dim_res_blocks)
|
||||
])
|
||||
|
||||
self.resamplers = nn.ModuleList([
|
||||
Resampler(prev, succ, type_=r, dtype=dtype, device=device, operations=operations)
|
||||
for prev, succ, r in zip(dim_res_blocks[:-1], dim_res_blocks[1:], resamplers)
|
||||
])
|
||||
|
||||
self.res_blocks = nn.ModuleList([
|
||||
nn.Sequential(*[
|
||||
ResidualConvBlock(d_res, dim_times_res_block_hidden * d_res, in_norm=res_block_in_norm, hidden_norm=res_block_hidden_norm, dtype=dtype, device=device, operations=operations)
|
||||
for _ in range(num_res_blocks[i])
|
||||
])
|
||||
for i, d_res in enumerate(dim_res_blocks)
|
||||
])
|
||||
|
||||
self.output_blocks = nn.ModuleList([
|
||||
(_conv2d(operations, d_res, d_out, k=1, dtype=dtype, device=device)
|
||||
if d_out is not None else nn.Identity())
|
||||
for d_out, d_res in zip(dim_out, dim_res_blocks)
|
||||
])
|
||||
|
||||
def forward(self, in_features: List[Optional[torch.Tensor]]):
|
||||
out_features = []
|
||||
x = None
|
||||
for i in range(len(self.res_blocks)):
|
||||
feat = self.input_blocks[i](in_features[i]) if in_features[i] is not None else None
|
||||
if i == 0:
|
||||
x = feat
|
||||
elif feat is not None:
|
||||
x = x + feat
|
||||
x = self.res_blocks[i](x)
|
||||
out_features.append(self.output_blocks[i](x))
|
||||
if i < len(self.res_blocks) - 1:
|
||||
x = self.resamplers[i](x)
|
||||
return out_features
|
||||
|
||||
|
||||
class DINOv2Encoder(nn.Module):
|
||||
"""Comfy DINOv2 backbone with per-layer 1x1 projection heads."""
|
||||
|
||||
def __init__(self, backbone: dict, intermediate_layers: List[int], dim_out: int, dtype=None, device=None, operations=comfy.ops.manual_cast):
|
||||
super().__init__()
|
||||
self.intermediate_layers = list(intermediate_layers)
|
||||
dim_features = backbone["hidden_size"]
|
||||
self.backbone = Dinov2Model(backbone, dtype, device, operations)
|
||||
self.output_projections = nn.ModuleList([
|
||||
_conv2d(operations, dim_features, dim_out, k=1, dtype=dtype, device=device)
|
||||
for _ in range(len(self.intermediate_layers))
|
||||
])
|
||||
self.register_buffer("image_mean", torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
|
||||
self.register_buffer("image_std", torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
|
||||
|
||||
def forward(self, image: torch.Tensor, token_rows: int, token_cols: int,
|
||||
return_class_token: bool = False) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
||||
image_14 = F.interpolate(image, (token_rows * 14, token_cols * 14), mode="bilinear", align_corners=False, antialias=True)
|
||||
image_14 = (image_14 - self.image_mean) / self.image_std
|
||||
feats = self.backbone.get_intermediate_layers(image_14, self.intermediate_layers, apply_norm=True)
|
||||
x = torch.stack([
|
||||
proj(feat.permute(0, 2, 1).unflatten(2, (token_rows, token_cols)).contiguous())
|
||||
for proj, (feat, _cls) in zip(self.output_projections, feats)
|
||||
], dim=1).sum(dim=1)
|
||||
if return_class_token:
|
||||
return x, feats[-1][1]
|
||||
return x
|
||||
|
||||
|
||||
class HeadV1(nn.Module):
|
||||
"""v1 head: 4 backbone-feature projections -> shared upsample stack -> per-target output convs (points, mask)."""
|
||||
|
||||
NUM_FEATURES = 4
|
||||
DIM_PROJ = 512
|
||||
DIM_OUT = (3, 1) # 3 channels for points, 1 for mask
|
||||
LAST_CONV_CHANNELS = 32
|
||||
|
||||
def __init__(self, dim_in: int, dim_upsample: List[int] = (256, 128, 128), num_res_blocks: int = 1, dim_times_res_block_hidden: int = 1,
|
||||
dtype=None, device=None, operations=comfy.ops.manual_cast):
|
||||
super().__init__()
|
||||
self.projects = nn.ModuleList([
|
||||
_conv2d(operations, dim_in, self.DIM_PROJ, k=1, dtype=dtype, device=device)
|
||||
for _ in range(self.NUM_FEATURES)
|
||||
])
|
||||
def upsampler(in_ch, out_ch):
|
||||
return nn.Sequential(
|
||||
operations.ConvTranspose2d(in_ch, out_ch, kernel_size=2, stride=2, dtype=dtype, device=device),
|
||||
_conv2d(operations, out_ch, out_ch, dtype=dtype, device=device),
|
||||
)
|
||||
|
||||
in_chs = [self.DIM_PROJ] + list(dim_upsample[:-1])
|
||||
self.upsample_blocks = nn.ModuleList([
|
||||
nn.Sequential(
|
||||
upsampler(in_ch + 2, out_ch),
|
||||
*(ResidualConvBlock(out_ch, dim_times_res_block_hidden * out_ch, dtype=dtype, device=device, operations=operations)
|
||||
for _ in range(num_res_blocks))
|
||||
)
|
||||
for in_ch, out_ch in zip(in_chs, dim_upsample)
|
||||
])
|
||||
self.output_block = nn.ModuleList([
|
||||
nn.Sequential(
|
||||
_conv2d(operations, dim_upsample[-1] + 2, self.LAST_CONV_CHANNELS, dtype=dtype, device=device),
|
||||
nn.ReLU(inplace=True),
|
||||
_conv2d(operations, self.LAST_CONV_CHANNELS, d_out, k=1, dtype=dtype, device=device),
|
||||
)
|
||||
for d_out in self.DIM_OUT
|
||||
])
|
||||
|
||||
def forward(self, hidden_states, image: torch.Tensor):
|
||||
img_h, img_w = image.shape[-2:]
|
||||
patch_h, patch_w = img_h // 14, img_w // 14
|
||||
aspect = img_w / img_h
|
||||
x = torch.stack([
|
||||
proj(feat.permute(0, 2, 1).unflatten(2, (patch_h, patch_w)).contiguous())
|
||||
for proj, (feat, _cls) in zip(self.projects, hidden_states)
|
||||
], dim=1).sum(dim=1)
|
||||
|
||||
for block in self.upsample_blocks:
|
||||
x = block(_concat_view_plane_uv(x, aspect))
|
||||
|
||||
x = F.interpolate(x, (img_h, img_w), mode="bilinear", align_corners=False)
|
||||
x = _concat_view_plane_uv(x, aspect)
|
||||
return [block(x) for block in self.output_block]
|
||||
315
comfy/ldm/moge/panorama.py
Normal file
315
comfy/ldm/moge/panorama.py
Normal file
@ -0,0 +1,315 @@
|
||||
"""Panorama (equirectangular) inference helpers for MoGe.
|
||||
|
||||
Splits an equirect into 12 perspective views via an icosahedron camera rig, runs
|
||||
the model per view, and stitches per-view distance maps back into a single
|
||||
equirect distance map via a multi-scale Poisson + gradient sparse solve.
|
||||
Image sampling uses ``F.grid_sample`` (GPU); the sparse solve uses ``lsmr`` (CPU).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Callable, List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
def _icosahedron_directions() -> np.ndarray:
|
||||
"""12 icosahedron-vertex directions (non-normalised, matching upstream's vertex order)."""
|
||||
A = (1.0 + np.sqrt(5.0)) / 2.0
|
||||
return np.array([
|
||||
[0, 1, A], [0, -1, A], [0, 1, -A], [0, -1, -A],
|
||||
[1, A, 0], [-1, A, 0], [1, -A, 0], [-1, -A, 0],
|
||||
[A, 0, 1], [A, 0, -1], [-A, 0, 1], [-A, 0, -1],
|
||||
], dtype=np.float32)
|
||||
|
||||
|
||||
def _intrinsics_from_fov(fov_x_rad: float, fov_y_rad: float) -> np.ndarray:
|
||||
"""Normalised-image (unit-square) K matrix."""
|
||||
fx = 0.5 / np.tan(fov_x_rad / 2)
|
||||
fy = 0.5 / np.tan(fov_y_rad / 2)
|
||||
return np.array([[fx, 0, 0.5], [0, fy, 0.5], [0, 0, 1]], dtype=np.float32)
|
||||
|
||||
|
||||
def _extrinsics_look_at(eye: np.ndarray, target: np.ndarray, up: np.ndarray) -> np.ndarray:
|
||||
"""OpenCV-convention world->camera extrinsics for an array of look-at targets (N, 4, 4)."""
|
||||
eye = np.asarray(eye, dtype=np.float32)
|
||||
target = np.asarray(target, dtype=np.float32)
|
||||
up = np.asarray(up, dtype=np.float32)
|
||||
if target.ndim == 1:
|
||||
target = target[None]
|
||||
|
||||
fwd = target - eye
|
||||
fwd = fwd / np.linalg.norm(fwd, axis=-1, keepdims=True).clip(1e-12)
|
||||
right = np.cross(fwd, up)
|
||||
right_norm = np.linalg.norm(right, axis=-1, keepdims=True)
|
||||
# Fall back to an arbitrary perpendicular if forward is parallel to up.
|
||||
parallel = right_norm.squeeze(-1) < 1e-6
|
||||
if parallel.any():
|
||||
alt_up = np.array([1, 0, 0], dtype=np.float32)
|
||||
right = np.where(parallel[:, None], np.cross(fwd, alt_up), right)
|
||||
right_norm = np.linalg.norm(right, axis=-1, keepdims=True)
|
||||
right = right / right_norm.clip(1e-12)
|
||||
new_up = np.cross(fwd, right)
|
||||
|
||||
R = np.stack([right, new_up, fwd], axis=-2)
|
||||
t = -np.einsum("nij,j->ni", R, eye)
|
||||
E = np.zeros((R.shape[0], 4, 4), dtype=np.float32)
|
||||
E[:, :3, :3] = R
|
||||
E[:, :3, 3] = t
|
||||
E[:, 3, 3] = 1.0
|
||||
return E
|
||||
|
||||
|
||||
def get_panorama_cameras() -> Tuple[np.ndarray, List[np.ndarray]]:
|
||||
"""Returns (extrinsics (12, 4, 4), [intrinsics] * 12) for icosahedron views at 90 deg FoV."""
|
||||
targets = _icosahedron_directions()
|
||||
eye = np.zeros(3, dtype=np.float32)
|
||||
up = np.array([0, 0, 1], dtype=np.float32)
|
||||
extrinsics = _extrinsics_look_at(eye, targets, up)
|
||||
K = _intrinsics_from_fov(np.deg2rad(90.0), np.deg2rad(90.0))
|
||||
return extrinsics, [K] * len(targets)
|
||||
|
||||
|
||||
def spherical_uv_to_directions(uv: np.ndarray) -> np.ndarray:
|
||||
"""Equirect UV in [0, 1] -> 3D unit-direction (Z up)."""
|
||||
theta = (1 - uv[..., 0]) * (2 * np.pi)
|
||||
phi = uv[..., 1] * np.pi
|
||||
return np.stack([
|
||||
np.sin(phi) * np.cos(theta),
|
||||
np.sin(phi) * np.sin(theta),
|
||||
np.cos(phi),
|
||||
], axis=-1).astype(np.float32)
|
||||
|
||||
|
||||
def directions_to_spherical_uv(directions: np.ndarray) -> np.ndarray:
|
||||
"""3D direction -> equirect UV in [0, 1]."""
|
||||
n = np.linalg.norm(directions, axis=-1, keepdims=True).clip(1e-12)
|
||||
d = directions / n
|
||||
u = 1 - np.arctan2(d[..., 1], d[..., 0]) / (2 * np.pi) % 1.0
|
||||
v = np.arccos(d[..., 2].clip(-1, 1)) / np.pi
|
||||
return np.stack([u, v], axis=-1).astype(np.float32)
|
||||
|
||||
|
||||
def _uv_grid(H: int, W: int) -> np.ndarray:
|
||||
"""Pixel-center UV grid in [0, 1]; (H, W, 2)."""
|
||||
u = (np.arange(W, dtype=np.float32) + 0.5) / W
|
||||
v = (np.arange(H, dtype=np.float32) + 0.5) / H
|
||||
return np.stack(np.meshgrid(u, v, indexing="xy"), axis=-1)
|
||||
|
||||
|
||||
def _unproject_cv(uv: np.ndarray, depth: np.ndarray,
|
||||
extrinsics: np.ndarray, intrinsics: np.ndarray) -> np.ndarray:
|
||||
"""Back-project pixels into world coords (OpenCV convention)."""
|
||||
pix = np.concatenate([uv, np.ones_like(uv[..., :1])], axis=-1)
|
||||
K_inv = np.linalg.inv(intrinsics)
|
||||
cam = pix @ K_inv.T * depth[..., None]
|
||||
cam_h = np.concatenate([cam, np.ones_like(cam[..., :1])], axis=-1)
|
||||
E_inv = np.linalg.inv(extrinsics)
|
||||
return (cam_h @ E_inv.T)[..., :3]
|
||||
|
||||
|
||||
def _project_cv(points: np.ndarray, extrinsics: np.ndarray, intrinsics: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
|
||||
"""World coords -> (uv, depth) in the camera (OpenCV convention)."""
|
||||
pts_h = np.concatenate([points, np.ones_like(points[..., :1])], axis=-1)
|
||||
cam = pts_h @ extrinsics.T
|
||||
cam_xyz = cam[..., :3]
|
||||
depth = cam_xyz[..., 2]
|
||||
proj = cam_xyz @ intrinsics.T
|
||||
uv = proj[..., :2] / proj[..., 2:3].clip(1e-12)
|
||||
return uv.astype(np.float32), depth.astype(np.float32)
|
||||
|
||||
|
||||
def _grid_sample_uv(img_bchw: torch.Tensor, uv: torch.Tensor, mode: str = "bilinear") -> torch.Tensor:
|
||||
"""Sample img_bchw at UV-in-[0,1] coords ``uv`` of shape (B, H, W, 2); replicate-border."""
|
||||
grid = uv * 2.0 - 1.0
|
||||
return F.grid_sample(img_bchw, grid, mode=mode, padding_mode="border", align_corners=False)
|
||||
|
||||
|
||||
def split_panorama_image(image: torch.Tensor, extrinsics: np.ndarray, intrinsics: List[np.ndarray], resolution: int) -> torch.Tensor:
|
||||
"""(3, Hp, Wp) equirect on any device -> (N, 3, R, R) perspective crops on the same device."""
|
||||
device = image.device
|
||||
N = len(extrinsics)
|
||||
uv = _uv_grid(resolution, resolution)
|
||||
sample_uvs = []
|
||||
for i in range(N):
|
||||
world = _unproject_cv(uv, np.ones(uv.shape[:-1], dtype=np.float32), extrinsics[i], intrinsics[i])
|
||||
sample_uvs.append(directions_to_spherical_uv(world))
|
||||
sample_uvs = np.stack(sample_uvs, axis=0)
|
||||
|
||||
img_bchw = image.unsqueeze(0).expand(N, -1, -1, -1).contiguous()
|
||||
sample_uvs_t = torch.from_numpy(sample_uvs).to(device=device, dtype=image.dtype)
|
||||
return _grid_sample_uv(img_bchw, sample_uvs_t, mode="bilinear")
|
||||
|
||||
|
||||
def _poisson_equation(W: int, H: int, wrap_x: bool = False, wrap_y: bool = False):
|
||||
"""Sparse Laplacian operator over the H x W grid."""
|
||||
from scipy.sparse import csr_array
|
||||
grid_index = np.arange(H * W).reshape(H, W)
|
||||
grid_index = np.pad(grid_index, ((0, 0), (1, 1)), mode="wrap" if wrap_x else "edge")
|
||||
grid_index = np.pad(grid_index, ((1, 1), (0, 0)), mode="wrap" if wrap_y else "edge")
|
||||
|
||||
data = np.array([[-4, 1, 1, 1, 1]], dtype=np.float32).repeat(H * W, axis=0).reshape(-1)
|
||||
indices = np.stack([
|
||||
grid_index[1:-1, 1:-1],
|
||||
grid_index[:-2, 1:-1], grid_index[2:, 1:-1],
|
||||
grid_index[1:-1, :-2], grid_index[1:-1, 2:],
|
||||
], axis=-1).reshape(-1)
|
||||
indptr = np.arange(0, H * W * 5 + 1, 5)
|
||||
return csr_array((data, indices, indptr), shape=(H * W, H * W))
|
||||
|
||||
|
||||
def _grad_equation(W: int, H: int, wrap_x: bool = False, wrap_y: bool = False):
|
||||
"""Sparse forward-difference operator over the H x W grid."""
|
||||
from scipy.sparse import csr_array
|
||||
grid_index = np.arange(W * H).reshape(H, W)
|
||||
if wrap_x:
|
||||
grid_index = np.pad(grid_index, ((0, 0), (0, 1)), mode="wrap")
|
||||
if wrap_y:
|
||||
grid_index = np.pad(grid_index, ((0, 1), (0, 0)), mode="wrap")
|
||||
|
||||
data = np.concatenate([
|
||||
np.concatenate([
|
||||
np.ones((grid_index.shape[0], grid_index.shape[1] - 1), dtype=np.float32).reshape(-1, 1),
|
||||
-np.ones((grid_index.shape[0], grid_index.shape[1] - 1), dtype=np.float32).reshape(-1, 1),
|
||||
], axis=1).reshape(-1),
|
||||
np.concatenate([
|
||||
np.ones((grid_index.shape[0] - 1, grid_index.shape[1]), dtype=np.float32).reshape(-1, 1),
|
||||
-np.ones((grid_index.shape[0] - 1, grid_index.shape[1]), dtype=np.float32).reshape(-1, 1),
|
||||
], axis=1).reshape(-1),
|
||||
])
|
||||
indices = np.concatenate([
|
||||
np.concatenate([grid_index[:, :-1].reshape(-1, 1), grid_index[:, 1:].reshape(-1, 1)], axis=1).reshape(-1),
|
||||
np.concatenate([grid_index[:-1, :].reshape(-1, 1), grid_index[1:, :].reshape(-1, 1)], axis=1).reshape(-1),
|
||||
])
|
||||
nx = grid_index.shape[0] * (grid_index.shape[1] - 1)
|
||||
ny = (grid_index.shape[0] - 1) * grid_index.shape[1]
|
||||
indptr = np.arange(0, nx * 2 + ny * 2 + 1, 2)
|
||||
return csr_array((data, indices, indptr), shape=(nx + ny, H * W))
|
||||
|
||||
|
||||
def _scipy_remap_bilinear(img: np.ndarray, sample_pixels: np.ndarray, mode: str = "bilinear") -> np.ndarray:
|
||||
"""Bilinear/nearest sampling at fractional pixel coords; out-of-range clamps to nearest border."""
|
||||
from scipy.ndimage import map_coordinates
|
||||
H, W = img.shape[:2]
|
||||
yy = np.clip(sample_pixels[..., 1], 0, H - 1)
|
||||
xx = np.clip(sample_pixels[..., 0], 0, W - 1)
|
||||
order = 1 if mode == "bilinear" else 0
|
||||
if img.ndim == 2:
|
||||
return map_coordinates(img, [yy, xx], order=order, mode="nearest").astype(img.dtype)
|
||||
out = np.stack([
|
||||
map_coordinates(img[..., c], [yy, xx], order=order, mode="nearest")
|
||||
for c in range(img.shape[-1])
|
||||
], axis=-1)
|
||||
return out.astype(img.dtype)
|
||||
|
||||
|
||||
def merge_panorama_depth(width: int, height: int,
|
||||
distance_maps: List[np.ndarray], pred_masks: List[np.ndarray],
|
||||
extrinsics: List[np.ndarray], intrinsics: List[np.ndarray],
|
||||
on_view: Optional[Callable[[], None]] = None,
|
||||
on_solve_start: Optional[Callable[[int, int], None]] = None,
|
||||
on_solve_end: Optional[Callable[[int, int], None]] = None,
|
||||
) -> Tuple[np.ndarray, np.ndarray]:
|
||||
"""Stitch per-view distance maps into a single equirect distance map.
|
||||
|
||||
Recursive multi-scale solve: solves at half resolution first and uses that as the lsmr init
|
||||
for the full-resolution solve. Optional callbacks fire per view processed and around each
|
||||
lsmr solve so callers can drive a progress bar.
|
||||
"""
|
||||
from scipy.ndimage import convolve
|
||||
from scipy.sparse import vstack
|
||||
from scipy.sparse.linalg import lsmr
|
||||
|
||||
if max(width, height) > 256:
|
||||
coarse_depth, _ = merge_panorama_depth(width // 2, height // 2,
|
||||
distance_maps, pred_masks, extrinsics, intrinsics,
|
||||
on_view=on_view,
|
||||
on_solve_start=on_solve_start,
|
||||
on_solve_end=on_solve_end)
|
||||
t = torch.from_numpy(coarse_depth).unsqueeze(0).unsqueeze(0)
|
||||
t = F.interpolate(t, size=(height, width), mode="bilinear", align_corners=False)
|
||||
depth_init = t.squeeze().numpy().astype(np.float32)
|
||||
else:
|
||||
depth_init = None
|
||||
|
||||
spherical_directions = spherical_uv_to_directions(_uv_grid(height, width))
|
||||
|
||||
pano_log_grad_maps, pano_grad_masks = [], []
|
||||
pano_log_lap_maps, pano_lap_masks = [], []
|
||||
pano_pred_masks: List[np.ndarray] = []
|
||||
|
||||
for i in range(len(distance_maps)):
|
||||
proj_uv, proj_depth = _project_cv(spherical_directions, extrinsics[i], intrinsics[i])
|
||||
proj_valid = (proj_depth > 0) & (proj_uv > 0).all(axis=-1) & (proj_uv < 1).all(axis=-1)
|
||||
|
||||
Hd, Wd = distance_maps[i].shape[:2]
|
||||
proj_pixels = np.clip(proj_uv, 0, 1) * np.array([Wd - 1, Hd - 1], dtype=np.float32)
|
||||
|
||||
log_dist = np.log(np.clip(distance_maps[i], 1e-6, None))
|
||||
sampled = _scipy_remap_bilinear(log_dist, proj_pixels, mode="bilinear")
|
||||
pano_log = np.where(proj_valid, sampled, 0.0).astype(np.float32)
|
||||
|
||||
sampled_mask = _scipy_remap_bilinear(pred_masks[i].astype(np.uint8), proj_pixels, mode="nearest")
|
||||
pano_pred = proj_valid & (sampled_mask > 0)
|
||||
|
||||
# Equirect wraps horizontally but not vertically: wrap pad along x, edge pad along y.
|
||||
padded = np.pad(pano_log, ((0, 0), (0, 1)), mode="wrap")
|
||||
gx, gy = padded[:, :-1] - padded[:, 1:], padded[:-1, :] - padded[1:, :]
|
||||
padded_m = np.pad(pano_pred, ((0, 0), (0, 1)), mode="wrap")
|
||||
mx, my = padded_m[:, :-1] & padded_m[:, 1:], padded_m[:-1, :] & padded_m[1:, :]
|
||||
pano_log_grad_maps.append((gx, gy))
|
||||
pano_grad_masks.append((mx, my))
|
||||
|
||||
padded = np.pad(pano_log, ((1, 1), (0, 0)), mode="edge")
|
||||
padded = np.pad(padded, ((0, 0), (1, 1)), mode="wrap")
|
||||
lap_kernel = np.array([[0, 1, 0], [1, -4, 1], [0, 1, 0]], dtype=np.float32)
|
||||
lap = convolve(padded, lap_kernel)[1:-1, 1:-1]
|
||||
padded_m = np.pad(pano_pred, ((1, 1), (0, 0)), mode="edge")
|
||||
padded_m = np.pad(padded_m, ((0, 0), (1, 1)), mode="wrap")
|
||||
m_kernel = np.array([[0, 1, 0], [1, 1, 1], [0, 1, 0]], dtype=np.uint8)
|
||||
lap_mask = convolve(padded_m.astype(np.uint8), m_kernel)[1:-1, 1:-1] == 5
|
||||
pano_log_lap_maps.append(lap)
|
||||
pano_lap_masks.append(lap_mask)
|
||||
pano_pred_masks.append(pano_pred)
|
||||
|
||||
if on_view is not None:
|
||||
on_view()
|
||||
|
||||
gx = np.stack([m[0] for m in pano_log_grad_maps], axis=0)
|
||||
gy = np.stack([m[1] for m in pano_log_grad_maps], axis=0)
|
||||
mx = np.stack([m[0] for m in pano_grad_masks], axis=0)
|
||||
my = np.stack([m[1] for m in pano_grad_masks], axis=0)
|
||||
gx_avg = (gx * mx).sum(axis=0) / mx.sum(axis=0).clip(1e-3)
|
||||
gy_avg = (gy * my).sum(axis=0) / my.sum(axis=0).clip(1e-3)
|
||||
|
||||
laps = np.stack(pano_log_lap_maps, axis=0)
|
||||
lap_masks = np.stack(pano_lap_masks, axis=0)
|
||||
lap_avg = (laps * lap_masks).sum(axis=0) / lap_masks.sum(axis=0).clip(1e-3)
|
||||
|
||||
grad_x_mask = mx.any(axis=0).reshape(-1)
|
||||
grad_y_mask = my.any(axis=0).reshape(-1)
|
||||
grad_mask = np.concatenate([grad_x_mask, grad_y_mask])
|
||||
lap_mask_flat = lap_masks.any(axis=0).reshape(-1)
|
||||
|
||||
A = vstack([
|
||||
_grad_equation(width, height, wrap_x=True, wrap_y=False)[grad_mask],
|
||||
_poisson_equation(width, height, wrap_x=True, wrap_y=False)[lap_mask_flat],
|
||||
])
|
||||
b = np.concatenate([
|
||||
gx_avg.reshape(-1)[grad_x_mask],
|
||||
gy_avg.reshape(-1)[grad_y_mask],
|
||||
lap_avg.reshape(-1)[lap_mask_flat],
|
||||
])
|
||||
x0 = np.log(np.clip(depth_init, 1e-6, None)).reshape(-1) if depth_init is not None else None
|
||||
|
||||
if on_solve_start is not None:
|
||||
on_solve_start(width, height)
|
||||
x, *_ = lsmr(A, b, atol=1e-5, btol=1e-5, x0=x0, show=False)
|
||||
if on_solve_end is not None:
|
||||
on_solve_end(width, height)
|
||||
|
||||
pano_depth = np.exp(x).reshape(height, width).astype(np.float32)
|
||||
pano_mask = np.any(pano_pred_masks, axis=0)
|
||||
return pano_depth, pano_mask
|
||||
94
comfy/ldm/moge/state_dict.py
Normal file
94
comfy/ldm/moge/state_dict.py
Normal file
@ -0,0 +1,94 @@
|
||||
"""Translate MoGe checkpoint keys to the layouts our nn.Modules use.
|
||||
|
||||
MoGe checkpoints embed DINOv2 with the original Meta naming
|
||||
(``backbone.blocks.{i}.attn.qkv.weight``, ``ls1.gamma``, ``mlp.w12``, ...).
|
||||
The shared ``comfy.image_encoders.dino2.Dinov2Model`` uses HF naming
|
||||
(``encoder.layer.{i}.attention.attention.{query,key,value}.weight``,
|
||||
``layer_scale1.lambda1``, ``mlp.weights_in``, ...). We rewrite keys at load
|
||||
time and split the fused ``qkv`` weight into separate Q/K/V tensors.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
|
||||
|
||||
_DINOV2_TOPLEVEL_RENAMES = {
|
||||
"patch_embed.proj.weight": "embeddings.patch_embeddings.projection.weight",
|
||||
"patch_embed.proj.bias": "embeddings.patch_embeddings.projection.bias",
|
||||
"cls_token": "embeddings.cls_token",
|
||||
"pos_embed": "embeddings.position_embeddings",
|
||||
"register_tokens": "embeddings.register_tokens",
|
||||
"mask_token": "embeddings.mask_token",
|
||||
"norm.weight": "layernorm.weight",
|
||||
"norm.bias": "layernorm.bias",
|
||||
}
|
||||
|
||||
_BLOCK_SUFFIX_RENAMES = [
|
||||
("ls1.gamma", "layer_scale1.lambda1"),
|
||||
("ls2.gamma", "layer_scale2.lambda1"),
|
||||
("attn.proj.", "attention.output.dense."),
|
||||
("mlp.w12.", "mlp.weights_in."),
|
||||
("mlp.w3.", "mlp.weights_out."),
|
||||
]
|
||||
|
||||
_BLOCK_RE = re.compile(r"^blocks\.(\d+)\.(.+)$")
|
||||
|
||||
|
||||
def remap_dinov2_keys(sd: dict, src_prefix: str = "") -> dict:
|
||||
"""Rewrite Meta-style DINOv2 keys under ``src_prefix`` to comfy/HF naming.
|
||||
|
||||
Splits each fused ``attn.qkv.{weight,bias}`` into separate
|
||||
``attention.attention.{query,key,value}.{weight,bias}`` tensors using a
|
||||
chunk along the leading dim.
|
||||
|
||||
Keys that do not start with ``src_prefix`` are returned unchanged.
|
||||
"""
|
||||
out: dict = {}
|
||||
for k, v in sd.items():
|
||||
if not k.startswith(src_prefix):
|
||||
out[k] = v
|
||||
continue
|
||||
rel = k[len(src_prefix):]
|
||||
|
||||
# Top-level (cls token, pos embed, patch embed, mask token, register tokens, final norm).
|
||||
if rel in _DINOV2_TOPLEVEL_RENAMES:
|
||||
out[src_prefix + _DINOV2_TOPLEVEL_RENAMES[rel]] = v
|
||||
continue
|
||||
|
||||
m = _BLOCK_RE.match(rel)
|
||||
if not m:
|
||||
out[k] = v
|
||||
continue
|
||||
|
||||
i, sub = m.group(1), m.group(2)
|
||||
|
||||
# Split fused qkv into separate q / k / v tensors.
|
||||
if sub == "attn.qkv.weight" or sub == "attn.qkv.bias":
|
||||
q, kw, vw = v.chunk(3, dim=0)
|
||||
tail = sub.rsplit(".", 1)[1] # weight / bias
|
||||
base = "{}encoder.layer.{}.attention.attention".format(src_prefix, i)
|
||||
out["{}.query.{}".format(base, tail)] = q
|
||||
out["{}.key.{}".format(base, tail)] = kw
|
||||
out["{}.value.{}".format(base, tail)] = vw
|
||||
continue
|
||||
|
||||
for old, new in _BLOCK_SUFFIX_RENAMES:
|
||||
sub = sub.replace(old, new)
|
||||
out["{}encoder.layer.{}.{}".format(src_prefix, i, sub)] = v
|
||||
|
||||
return out
|
||||
|
||||
|
||||
def remap_moge_state_dict(sd: dict) -> dict:
|
||||
"""Convert a full MoGe checkpoint state dict to the layout our modules expect.
|
||||
|
||||
- v1 backbone lives under ``backbone.`` -> rewrite that subtree.
|
||||
- v2 backbone lives under ``encoder.backbone.`` -> rewrite that subtree.
|
||||
|
||||
Everything else (heads, neck, projections, image_mean/std buffers) keeps
|
||||
its original key names and passes through unchanged.
|
||||
"""
|
||||
if any(k.startswith("encoder.backbone.") for k in sd):
|
||||
return remap_dinov2_keys(sd, src_prefix="encoder.backbone.")
|
||||
return remap_dinov2_keys(sd, src_prefix="backbone.")
|
||||
163
comfy/moge.py
Normal file
163
comfy/moge.py
Normal file
@ -0,0 +1,163 @@
|
||||
"""High-level loader and inference wrapper for MoGe v1 / v2 checkpoints.
|
||||
|
||||
Mirrors the structure of :mod:`comfy.clip_vision`: owns the ``nn.Module`` and
|
||||
a :class:`comfy.model_patcher.CoreModelPatcher`, exposes a
|
||||
:meth:`MoGeModel.infer` that runs preprocessing, forward, and post-processing.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from numbers import Number
|
||||
from typing import Dict, Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
import comfy.model_management
|
||||
import comfy.model_patcher
|
||||
import comfy.ops
|
||||
import comfy.utils
|
||||
|
||||
from .ldm.moge.geometry import (
|
||||
depth_map_to_point_map,
|
||||
intrinsics_from_focal_center,
|
||||
recover_focal_shift,
|
||||
)
|
||||
from .ldm.moge.model import detect_and_build
|
||||
from .ldm.moge.state_dict import remap_moge_state_dict
|
||||
|
||||
|
||||
class MoGeModel:
|
||||
"""Loaded MoGe model + ComfyUI memory management."""
|
||||
|
||||
def __init__(self, state_dict: dict):
|
||||
self.load_device = comfy.model_management.text_encoder_device()
|
||||
offload_device = comfy.model_management.text_encoder_offload_device()
|
||||
self.dtype = comfy.model_management.text_encoder_dtype(self.load_device)
|
||||
|
||||
sd = remap_moge_state_dict(state_dict)
|
||||
self.model = detect_and_build(sd, dtype=self.dtype, device=offload_device,
|
||||
operations=comfy.ops.manual_cast)
|
||||
self.model.load_state_dict(sd, strict=True)
|
||||
self.model.eval()
|
||||
self.patcher = comfy.model_patcher.CoreModelPatcher(
|
||||
self.model, load_device=self.load_device, offload_device=offload_device
|
||||
)
|
||||
self.version = "v2" if hasattr(self.model, "encoder") else "v1"
|
||||
self.mask_threshold = float(getattr(self.model, "mask_threshold", 0.5))
|
||||
nt = getattr(self.model, "num_tokens_range", (1200, 2500 if self.version == "v1" else 3600))
|
||||
self.num_tokens_range = (int(nt[0]), int(nt[1]))
|
||||
|
||||
@torch.inference_mode()
|
||||
def infer(self,
|
||||
image: torch.Tensor,
|
||||
num_tokens: Optional[int] = None,
|
||||
resolution_level: int = 9,
|
||||
fov_x: Optional[Union[Number, torch.Tensor]] = None,
|
||||
force_projection: bool = True,
|
||||
apply_mask: bool = True) -> Dict[str, torch.Tensor]:
|
||||
"""Run a single MoGe forward + post-process pass.
|
||||
|
||||
``image`` must already be ``(B, 3, H, W)`` in ``[0, 1]`` on any device.
|
||||
Returns a dict with at least ``points``, ``depth``, ``intrinsics``,
|
||||
``mask``; v2 checkpoints additionally produce ``normal``.
|
||||
"""
|
||||
comfy.model_management.load_model_gpu(self.patcher)
|
||||
device = self.load_device
|
||||
image = image.to(device=device, dtype=self.dtype)
|
||||
|
||||
if image.dim() == 3:
|
||||
image = image.unsqueeze(0)
|
||||
H, W = image.shape[-2:]
|
||||
aspect_ratio = W / H
|
||||
|
||||
if num_tokens is None:
|
||||
lo, hi = self.num_tokens_range
|
||||
num_tokens = int(lo + (resolution_level / 9) * (hi - lo))
|
||||
|
||||
out = self.model.forward(image, num_tokens=num_tokens)
|
||||
points = out.get("points")
|
||||
normal = out.get("normal")
|
||||
mask = out.get("mask")
|
||||
metric_scale = out.get("metric_scale")
|
||||
|
||||
# Post-processing always runs in fp32 for numerical stability.
|
||||
if points is not None: points = points.float()
|
||||
if normal is not None: normal = normal.float()
|
||||
if mask is not None: mask = mask.float()
|
||||
if metric_scale is not None: metric_scale = metric_scale.float()
|
||||
|
||||
mask_binary = (mask > self.mask_threshold) if mask is not None else None
|
||||
|
||||
depth = None
|
||||
intrinsics = None
|
||||
if points is not None:
|
||||
if fov_x is None:
|
||||
focal, shift = recover_focal_shift(points, mask_binary)
|
||||
# The unconstrained least-squares solver inside recover_focal_shift
|
||||
# can converge to a degenerate solution where (z + shift) is
|
||||
# negative for most pixels, which flips the sign of the
|
||||
# estimated focal. Detect that and fall back to a sensible
|
||||
# 60-degree-FoV default rather than emitting garbage geometry.
|
||||
bad = ~torch.isfinite(focal) | (focal <= 0)
|
||||
if bool(bad.any()):
|
||||
default_fov = 60.0
|
||||
fov_t = torch.as_tensor(default_fov, device=points.device, dtype=points.dtype)
|
||||
fallback_focal = aspect_ratio / (1 + aspect_ratio ** 2) ** 0.5 \
|
||||
/ torch.tan(torch.deg2rad(fov_t / 2))
|
||||
fallback_focal = fallback_focal.expand_as(focal).clone()
|
||||
focal = torch.where(bad, fallback_focal, focal)
|
||||
_, shift = recover_focal_shift(points, mask_binary, focal=focal)
|
||||
else:
|
||||
fov_t = torch.as_tensor(fov_x, device=points.device, dtype=points.dtype)
|
||||
focal = aspect_ratio / (1 + aspect_ratio ** 2) ** 0.5 / torch.tan(torch.deg2rad(fov_t / 2))
|
||||
if focal.ndim == 0:
|
||||
focal = focal[None].expand(points.shape[0])
|
||||
_, shift = recover_focal_shift(points, mask_binary, focal=focal)
|
||||
fx = focal / 2 * (1 + aspect_ratio ** 2) ** 0.5 / aspect_ratio
|
||||
fy = focal / 2 * (1 + aspect_ratio ** 2) ** 0.5
|
||||
half = torch.tensor(0.5, device=points.device, dtype=points.dtype)
|
||||
intrinsics = intrinsics_from_focal_center(fx, fy, half, half)
|
||||
points[..., 2] = points[..., 2] + shift[..., None, None]
|
||||
# v2 upstream additionally filters mask by depth > 0 as a safeguard
|
||||
# against negative-depth artifacts from the metric-scale path; v1
|
||||
# does not, and applying it there can cut out the foreground when
|
||||
# shift recovery overshoots slightly.
|
||||
if mask_binary is not None and self.version == "v2":
|
||||
mask_binary = mask_binary & (points[..., 2] > 0)
|
||||
depth = points[..., 2].clone()
|
||||
|
||||
if force_projection and depth is not None and intrinsics is not None:
|
||||
points = depth_map_to_point_map(depth, intrinsics=intrinsics)
|
||||
|
||||
if metric_scale is not None:
|
||||
if points is not None:
|
||||
points = points * metric_scale[:, None, None, None]
|
||||
if depth is not None:
|
||||
depth = depth * metric_scale[:, None, None]
|
||||
|
||||
if apply_mask and mask_binary is not None:
|
||||
if points is not None:
|
||||
points = torch.where(mask_binary[..., None], points,
|
||||
torch.full_like(points, float("inf")))
|
||||
if depth is not None:
|
||||
depth = torch.where(mask_binary, depth,
|
||||
torch.full_like(depth, float("inf")))
|
||||
if normal is not None:
|
||||
normal = torch.where(mask_binary[..., None], normal, torch.zeros_like(normal))
|
||||
|
||||
result = {
|
||||
"points": points,
|
||||
"depth": depth,
|
||||
"intrinsics": intrinsics,
|
||||
"mask": mask_binary,
|
||||
"normal": normal,
|
||||
}
|
||||
return {k: v for k, v in result.items() if v is not None}
|
||||
|
||||
|
||||
def load(ckpt_path: str) -> MoGeModel:
|
||||
"""Load a MoGe ``.pt`` / ``.safetensors`` checkpoint into a :class:`MoGeModel`."""
|
||||
sd = comfy.utils.load_torch_file(ckpt_path, safe_load=True)
|
||||
if isinstance(sd, dict) and "model" in sd and "model_config" in sd:
|
||||
sd = sd["model"]
|
||||
return MoGeModel(sd)
|
||||
@ -12,9 +12,19 @@ class VOXEL:
|
||||
|
||||
|
||||
class MESH:
|
||||
def __init__(self, vertices: torch.Tensor, faces: torch.Tensor):
|
||||
def __init__(self, vertices: torch.Tensor, faces: torch.Tensor,
|
||||
uvs: torch.Tensor | None = None,
|
||||
vertex_colors: torch.Tensor | None = None,
|
||||
texture: torch.Tensor | None = None):
|
||||
# vertices: (B, N, 3), faces: (B, M, 3). Optional fields:
|
||||
# - uvs: (B, N, 2) per-vertex texture coordinates.
|
||||
# - vertex_colors: (B, N, 3 or 4) per-vertex colors in [0, 1].
|
||||
# - texture: (B, H, W, 3) baseColor texture image in [0, 1] (comfy IMAGE format).
|
||||
self.vertices = vertices
|
||||
self.faces = faces
|
||||
self.uvs = uvs
|
||||
self.vertex_colors = vertex_colors
|
||||
self.texture = texture
|
||||
|
||||
|
||||
class File3D:
|
||||
|
||||
@ -1,12 +1,8 @@
|
||||
import torch
|
||||
import os
|
||||
import json
|
||||
import struct
|
||||
import numpy as np
|
||||
from comfy.ldm.modules.diffusionmodules.mmdit import get_1d_sincos_pos_embed_from_grid_torch
|
||||
import folder_paths
|
||||
import comfy.model_management
|
||||
from comfy.cli_args import args
|
||||
from comfy_extras.nodes_save_3d import pack_variable_mesh_batch
|
||||
from typing_extensions import override
|
||||
from comfy_api.latest import ComfyExtension, IO, Types
|
||||
from comfy_api.latest._util import MESH, VOXEL # only for backward compatibility if someone import it from this file (will be removed later) # noqa
|
||||
@ -444,7 +440,9 @@ class VoxelToMeshBasic(IO.ComfyNode):
|
||||
vertices.append(v)
|
||||
faces.append(f)
|
||||
|
||||
return IO.NodeOutput(Types.MESH(torch.stack(vertices), torch.stack(faces)))
|
||||
if vertices and all(v.shape == vertices[0].shape for v in vertices) and all(f.shape == faces[0].shape for f in faces):
|
||||
return IO.NodeOutput(Types.MESH(torch.stack(vertices), torch.stack(faces)))
|
||||
return IO.NodeOutput(pack_variable_mesh_batch(vertices, faces))
|
||||
|
||||
decode = execute # TODO: remove
|
||||
|
||||
@ -481,206 +479,13 @@ class VoxelToMesh(IO.ComfyNode):
|
||||
vertices.append(v)
|
||||
faces.append(f)
|
||||
|
||||
return IO.NodeOutput(Types.MESH(torch.stack(vertices), torch.stack(faces)))
|
||||
if vertices and all(v.shape == vertices[0].shape for v in vertices) and all(f.shape == faces[0].shape for f in faces):
|
||||
return IO.NodeOutput(Types.MESH(torch.stack(vertices), torch.stack(faces)))
|
||||
return IO.NodeOutput(pack_variable_mesh_batch(vertices, faces))
|
||||
|
||||
decode = execute # TODO: remove
|
||||
|
||||
|
||||
def save_glb(vertices, faces, filepath, metadata=None):
|
||||
"""
|
||||
Save PyTorch tensor vertices and faces as a GLB file without external dependencies.
|
||||
|
||||
Parameters:
|
||||
vertices: torch.Tensor of shape (N, 3) - The vertex coordinates
|
||||
faces: torch.Tensor of shape (M, 3) - The face indices (triangle faces)
|
||||
filepath: str - Output filepath (should end with .glb)
|
||||
"""
|
||||
|
||||
# Convert tensors to numpy arrays
|
||||
vertices_np = vertices.cpu().numpy().astype(np.float32)
|
||||
faces_np = faces.cpu().numpy().astype(np.uint32)
|
||||
|
||||
vertices_buffer = vertices_np.tobytes()
|
||||
indices_buffer = faces_np.tobytes()
|
||||
|
||||
def pad_to_4_bytes(buffer):
|
||||
padding_length = (4 - (len(buffer) % 4)) % 4
|
||||
return buffer + b'\x00' * padding_length
|
||||
|
||||
vertices_buffer_padded = pad_to_4_bytes(vertices_buffer)
|
||||
indices_buffer_padded = pad_to_4_bytes(indices_buffer)
|
||||
|
||||
buffer_data = vertices_buffer_padded + indices_buffer_padded
|
||||
|
||||
vertices_byte_length = len(vertices_buffer)
|
||||
vertices_byte_offset = 0
|
||||
indices_byte_length = len(indices_buffer)
|
||||
indices_byte_offset = len(vertices_buffer_padded)
|
||||
|
||||
gltf = {
|
||||
"asset": {"version": "2.0", "generator": "ComfyUI"},
|
||||
"buffers": [
|
||||
{
|
||||
"byteLength": len(buffer_data)
|
||||
}
|
||||
],
|
||||
"bufferViews": [
|
||||
{
|
||||
"buffer": 0,
|
||||
"byteOffset": vertices_byte_offset,
|
||||
"byteLength": vertices_byte_length,
|
||||
"target": 34962 # ARRAY_BUFFER
|
||||
},
|
||||
{
|
||||
"buffer": 0,
|
||||
"byteOffset": indices_byte_offset,
|
||||
"byteLength": indices_byte_length,
|
||||
"target": 34963 # ELEMENT_ARRAY_BUFFER
|
||||
}
|
||||
],
|
||||
"accessors": [
|
||||
{
|
||||
"bufferView": 0,
|
||||
"byteOffset": 0,
|
||||
"componentType": 5126, # FLOAT
|
||||
"count": len(vertices_np),
|
||||
"type": "VEC3",
|
||||
"max": vertices_np.max(axis=0).tolist(),
|
||||
"min": vertices_np.min(axis=0).tolist()
|
||||
},
|
||||
{
|
||||
"bufferView": 1,
|
||||
"byteOffset": 0,
|
||||
"componentType": 5125, # UNSIGNED_INT
|
||||
"count": faces_np.size,
|
||||
"type": "SCALAR"
|
||||
}
|
||||
],
|
||||
"meshes": [
|
||||
{
|
||||
"primitives": [
|
||||
{
|
||||
"attributes": {
|
||||
"POSITION": 0
|
||||
},
|
||||
"indices": 1,
|
||||
"mode": 4 # TRIANGLES
|
||||
}
|
||||
]
|
||||
}
|
||||
],
|
||||
"nodes": [
|
||||
{
|
||||
"mesh": 0
|
||||
}
|
||||
],
|
||||
"scenes": [
|
||||
{
|
||||
"nodes": [0]
|
||||
}
|
||||
],
|
||||
"scene": 0
|
||||
}
|
||||
|
||||
if metadata is not None:
|
||||
gltf["asset"]["extras"] = metadata
|
||||
|
||||
# Convert the JSON to bytes
|
||||
gltf_json = json.dumps(gltf).encode('utf8')
|
||||
|
||||
def pad_json_to_4_bytes(buffer):
|
||||
padding_length = (4 - (len(buffer) % 4)) % 4
|
||||
return buffer + b' ' * padding_length
|
||||
|
||||
gltf_json_padded = pad_json_to_4_bytes(gltf_json)
|
||||
|
||||
# Create the GLB header
|
||||
# Magic glTF
|
||||
glb_header = struct.pack('<4sII', b'glTF', 2, 12 + 8 + len(gltf_json_padded) + 8 + len(buffer_data))
|
||||
|
||||
# Create JSON chunk header (chunk type 0)
|
||||
json_chunk_header = struct.pack('<II', len(gltf_json_padded), 0x4E4F534A) # "JSON" in little endian
|
||||
|
||||
# Create BIN chunk header (chunk type 1)
|
||||
bin_chunk_header = struct.pack('<II', len(buffer_data), 0x004E4942) # "BIN\0" in little endian
|
||||
|
||||
# Write the GLB file
|
||||
with open(filepath, 'wb') as f:
|
||||
f.write(glb_header)
|
||||
f.write(json_chunk_header)
|
||||
f.write(gltf_json_padded)
|
||||
f.write(bin_chunk_header)
|
||||
f.write(buffer_data)
|
||||
|
||||
return filepath
|
||||
|
||||
|
||||
class SaveGLB(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="SaveGLB",
|
||||
display_name="Save 3D Model",
|
||||
search_aliases=["export 3d model", "save mesh"],
|
||||
category="3d",
|
||||
essentials_category="Basics",
|
||||
is_output_node=True,
|
||||
inputs=[
|
||||
IO.MultiType.Input(
|
||||
IO.Mesh.Input("mesh"),
|
||||
types=[
|
||||
IO.File3DGLB,
|
||||
IO.File3DGLTF,
|
||||
IO.File3DOBJ,
|
||||
IO.File3DFBX,
|
||||
IO.File3DSTL,
|
||||
IO.File3DUSDZ,
|
||||
IO.File3DAny,
|
||||
],
|
||||
tooltip="Mesh or 3D file to save",
|
||||
),
|
||||
IO.String.Input("filename_prefix", default="3d/ComfyUI"),
|
||||
],
|
||||
hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo]
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, mesh: Types.MESH | Types.File3D, filename_prefix: str) -> IO.NodeOutput:
|
||||
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, folder_paths.get_output_directory())
|
||||
results = []
|
||||
|
||||
metadata = {}
|
||||
if not args.disable_metadata:
|
||||
if cls.hidden.prompt is not None:
|
||||
metadata["prompt"] = json.dumps(cls.hidden.prompt)
|
||||
if cls.hidden.extra_pnginfo is not None:
|
||||
for x in cls.hidden.extra_pnginfo:
|
||||
metadata[x] = json.dumps(cls.hidden.extra_pnginfo[x])
|
||||
|
||||
if isinstance(mesh, Types.File3D):
|
||||
# Handle File3D input - save BytesIO data to output folder
|
||||
ext = mesh.format or "glb"
|
||||
f = f"{filename}_{counter:05}_.{ext}"
|
||||
mesh.save_to(os.path.join(full_output_folder, f))
|
||||
results.append({
|
||||
"filename": f,
|
||||
"subfolder": subfolder,
|
||||
"type": "output"
|
||||
})
|
||||
else:
|
||||
# Handle Mesh input - save vertices and faces as GLB
|
||||
for i in range(mesh.vertices.shape[0]):
|
||||
f = f"{filename}_{counter:05}_.glb"
|
||||
save_glb(mesh.vertices[i], mesh.faces[i], os.path.join(full_output_folder, f), metadata)
|
||||
results.append({
|
||||
"filename": f,
|
||||
"subfolder": subfolder,
|
||||
"type": "output"
|
||||
})
|
||||
counter += 1
|
||||
return IO.NodeOutput(ui={"3d": results})
|
||||
|
||||
|
||||
class Hunyuan3dExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||
@ -691,7 +496,6 @@ class Hunyuan3dExtension(ComfyExtension):
|
||||
VAEDecodeHunyuan3D,
|
||||
VoxelToMeshBasic,
|
||||
VoxelToMesh,
|
||||
SaveGLB,
|
||||
]
|
||||
|
||||
|
||||
|
||||
445
comfy_extras/nodes_moge.py
Normal file
445
comfy_extras/nodes_moge.py
Normal file
@ -0,0 +1,445 @@
|
||||
"""ComfyUI nodes for the native MoGe (Monocular Geometry Estimation) integration."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
import comfy.utils
|
||||
import folder_paths
|
||||
from comfy_api.latest import ComfyExtension, Types, io
|
||||
from typing_extensions import override
|
||||
|
||||
from comfy.ldm.moge.model import MoGeModel
|
||||
from comfy.ldm.moge.geometry import triangulate_grid_mesh
|
||||
|
||||
MoGeModelType = io.Custom("MOGE_MODEL")
|
||||
MoGeGeometry = io.Custom("MOGE_GEOMETRY")
|
||||
|
||||
|
||||
@dataclass
|
||||
class _MoGeGeometryPayload:
|
||||
points: Optional[torch.Tensor] # (B, H, W, 3)
|
||||
depth: Optional[torch.Tensor] # (B, H, W)
|
||||
intrinsics: Optional[torch.Tensor] # (B, 3, 3)
|
||||
mask: Optional[torch.Tensor] # (B, H, W) bool
|
||||
normal: Optional[torch.Tensor] # (B, H, W, 3) or None for v1
|
||||
image: torch.Tensor # (B, H, W, 3) in [0, 1], CPU
|
||||
|
||||
|
||||
def _turbo(x: torch.Tensor) -> torch.Tensor:
|
||||
"""Anton Mikhailov polynomial approximation of the turbo colormap."""
|
||||
x = x.clamp(0.0, 1.0)
|
||||
x2 = x * x
|
||||
x3 = x2 * x
|
||||
x4 = x2 * x2
|
||||
x5 = x4 * x
|
||||
r = 0.13572138 + 4.61539260*x - 42.66032258*x2 + 132.13108234*x3 - 152.94239396*x4 + 59.28637943*x5
|
||||
g = 0.09140261 + 2.19418839*x + 4.84296658*x2 - 14.18503333*x3 + 4.27729857*x4 + 2.82956604*x5
|
||||
b = 0.10667330 + 12.64194608*x - 60.58204836*x2 + 110.36276771*x3 - 89.90310912*x4 + 27.34824973*x5
|
||||
return torch.stack([r, g, b], dim=-1).clamp(0.0, 1.0)
|
||||
|
||||
|
||||
def _normals_from_points(points: torch.Tensor) -> torch.Tensor:
|
||||
"""Camera-space surface normals from a (B, H, W, 3) point map (v1 fallback)."""
|
||||
finite = torch.isfinite(points).all(dim=-1)
|
||||
pts = torch.where(finite.unsqueeze(-1), points, torch.zeros_like(points))
|
||||
dx = pts[..., :, 2:, :] - pts[..., :, :-2, :]
|
||||
dy = pts[..., 2:, :, :] - pts[..., :-2, :, :]
|
||||
dx = torch.nn.functional.pad(dx.permute(0, 3, 1, 2), (1, 1, 0, 0)).permute(0, 2, 3, 1)
|
||||
dy = torch.nn.functional.pad(dy.permute(0, 3, 1, 2), (0, 0, 1, 1)).permute(0, 2, 3, 1)
|
||||
n = torch.cross(dx, dy, dim=-1)
|
||||
n = torch.nn.functional.normalize(n, dim=-1)
|
||||
return torch.where(finite.unsqueeze(-1), n, torch.zeros_like(n))
|
||||
|
||||
|
||||
def _screen_normals_from_depth(depth: torch.Tensor) -> torch.Tensor:
|
||||
"""Screen-space surface normals (X right, Y down, Z into scene)."""
|
||||
finite = torch.isfinite(depth) & (depth > 0)
|
||||
d = torch.where(finite, depth, torch.zeros_like(depth))
|
||||
H, W = d.shape[-2:]
|
||||
d4d = d.unsqueeze(1)
|
||||
# Scale gradients to normalized image coords so a 45 deg tilt lands as a 45 deg normal regardless of resolution.
|
||||
dz_dx = (d4d[..., :, 2:] - d4d[..., :, :-2]) * (W / 2.0)
|
||||
dz_dy = (d4d[..., 2:, :] - d4d[..., :-2, :]) * (H / 2.0)
|
||||
dz_dx = torch.nn.functional.pad(dz_dx, (1, 1, 0, 0)).squeeze(1)
|
||||
dz_dy = torch.nn.functional.pad(dz_dy, (0, 0, 1, 1)).squeeze(1)
|
||||
n = torch.stack([-dz_dx, -dz_dy, torch.ones_like(d)], dim=-1)
|
||||
n = torch.nn.functional.normalize(n, dim=-1)
|
||||
return torch.where(finite.unsqueeze(-1), n, torch.zeros_like(n))
|
||||
|
||||
|
||||
def _normalize_disparity(depth: torch.Tensor) -> torch.Tensor:
|
||||
"""Per-batch normalize 1/depth to [0, 1] using 0.1/99.9 percentile clipping."""
|
||||
out = torch.zeros_like(depth)
|
||||
for i in range(depth.shape[0]):
|
||||
d = depth[i]
|
||||
valid = torch.isfinite(d) & (d > 0)
|
||||
if not valid.any():
|
||||
continue
|
||||
disp = torch.where(valid, 1.0 / d.clamp_min(1e-6), torch.zeros_like(d))
|
||||
disp_valid = disp[valid]
|
||||
lo = torch.quantile(disp_valid, 0.001)
|
||||
hi = torch.quantile(disp_valid, 0.999)
|
||||
scale = (hi - lo).clamp_min(1e-6)
|
||||
norm = ((disp - lo) / scale).clamp(0.0, 1.0)
|
||||
out[i] = torch.where(valid, norm, torch.zeros_like(norm))
|
||||
return out
|
||||
|
||||
|
||||
class LoadMoGeModel(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="LoadMoGeModel",
|
||||
display_name="Load MoGe Model",
|
||||
category="loaders",
|
||||
inputs=[
|
||||
io.Combo.Input("model_name", options=folder_paths.get_filename_list("moge")),
|
||||
],
|
||||
outputs=[MoGeModelType.Output()],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, model_name) -> io.NodeOutput:
|
||||
path = folder_paths.get_full_path_or_raise("moge", model_name)
|
||||
sd = comfy.utils.load_torch_file(path, safe_load=True)
|
||||
return io.NodeOutput(MoGeModel(sd))
|
||||
|
||||
|
||||
class MoGePanoramaInference(io.ComfyNode):
|
||||
"""Equirectangular panorama inference: split into 12 perspective views, run
|
||||
MoGe at fov_x=90 on each, merge via multi-scale Poisson + gradient solve.
|
||||
v2's predicted normals and metric scale are ignored (per-view scales would not align across seams).
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="MoGePanoramaInference",
|
||||
display_name="MoGe Panorama Inference",
|
||||
category="image/geometry",
|
||||
inputs=[
|
||||
MoGeModelType.Input("moge_model"),
|
||||
io.Image.Input("image", tooltip="Equirectangular panorama (any aspect)."),
|
||||
io.Int.Input("resolution_level", default=9, min=0, max=9,
|
||||
tooltip="Per-view detail (0 = fast, 9 = slow)."),
|
||||
io.Int.Input("split_resolution", default=512, min=256, max=1024,
|
||||
tooltip="Resolution of each perspective split."),
|
||||
io.Int.Input("merge_resolution", default=1920, min=256, max=8192,
|
||||
tooltip="Long-side resolution of the merged equirect distance map."),
|
||||
io.Int.Input("batch_size", default=4, min=1, max=12,
|
||||
tooltip="Views per inference batch (12 splits total)."),
|
||||
],
|
||||
outputs=[MoGeGeometry.Output(display_name="geometry")],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, moge_model, image, resolution_level,
|
||||
split_resolution, merge_resolution, batch_size) -> io.NodeOutput:
|
||||
from comfy.ldm.moge.panorama import (
|
||||
get_panorama_cameras, split_panorama_image, merge_panorama_depth,
|
||||
spherical_uv_to_directions, _uv_grid,
|
||||
)
|
||||
import comfy.model_management as cmm
|
||||
import numpy as np
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
if image.shape[0] != 1:
|
||||
raise ValueError(f"MoGePanoramaInference takes a single image (got batch of {image.shape[0]})")
|
||||
|
||||
H, W = int(image.shape[1]), int(image.shape[2])
|
||||
scale = min(merge_resolution / max(H, W), 1.0)
|
||||
merge_h, merge_w = max(int(H * scale), 32), max(int(W * scale), 32)
|
||||
|
||||
extrinsics, intrinsics = get_panorama_cameras()
|
||||
|
||||
cmm.load_model_gpu(moge_model.patcher)
|
||||
device = moge_model.load_device
|
||||
img_chw = image[0].movedim(-1, -3).to(device=device, dtype=moge_model.dtype)
|
||||
splits = split_panorama_image(img_chw, extrinsics, intrinsics, split_resolution)
|
||||
|
||||
n_views = splits.shape[0]
|
||||
|
||||
# Weight each lsmr solve by 4^level so the final-resolution solve doesn't leave the bar idle.
|
||||
merge_levels: list[tuple[int, int]] = []
|
||||
w_, h_ = merge_w, merge_h
|
||||
while True:
|
||||
merge_levels.append((w_, h_))
|
||||
if max(w_, h_) <= 256:
|
||||
break
|
||||
w_, h_ = w_ // 2, h_ // 2
|
||||
merge_levels.reverse()
|
||||
|
||||
solve_weight = {wh: 4 ** i for i, wh in enumerate(merge_levels)}
|
||||
n_merge_view_units = n_views * len(merge_levels)
|
||||
n_merge_solve_units = sum(solve_weight.values())
|
||||
|
||||
pbar = comfy.utils.ProgressBar(n_views + n_merge_view_units + n_merge_solve_units)
|
||||
done = 0
|
||||
|
||||
distance_maps: list = []
|
||||
masks: list = []
|
||||
with tqdm(total=n_views, desc="MoGe panorama inference") as tq:
|
||||
for i in range(0, n_views, batch_size):
|
||||
batch = splits[i:i + batch_size]
|
||||
# apply_metric_scale=False: per-view scales would not align across overlap seams.
|
||||
result = moge_model.infer(batch, resolution_level=resolution_level,
|
||||
fov_x=90.0, force_projection=True,
|
||||
apply_mask=False, apply_metric_scale=False)
|
||||
distance_maps.extend(list(result["points"].float().norm(dim=-1).cpu().numpy()))
|
||||
masks.extend(list(result["mask"].cpu().numpy()))
|
||||
n = batch.shape[0]
|
||||
done += n
|
||||
pbar.update_absolute(done)
|
||||
tq.update(n)
|
||||
|
||||
with tqdm(total=n_merge_view_units + n_merge_solve_units, desc="MoGe panorama merge: views") as tq:
|
||||
def _on_merge_view():
|
||||
nonlocal done
|
||||
done += 1
|
||||
pbar.update_absolute(done)
|
||||
tq.update(1)
|
||||
|
||||
def _on_solve_start(w, h):
|
||||
tq.set_description(f"MoGe panorama merge: solving {w}x{h}")
|
||||
|
||||
def _on_solve_end(w, h):
|
||||
nonlocal done
|
||||
weight = solve_weight[(w, h)]
|
||||
done += weight
|
||||
pbar.update_absolute(done)
|
||||
tq.update(weight)
|
||||
tq.set_description("MoGe panorama merge: views")
|
||||
|
||||
pano_depth, pano_mask = merge_panorama_depth(
|
||||
merge_w, merge_h, distance_maps, masks, list(extrinsics), intrinsics,
|
||||
on_view=_on_merge_view, on_solve_start=_on_solve_start, on_solve_end=_on_solve_end)
|
||||
|
||||
if (merge_h, merge_w) != (H, W):
|
||||
t = torch.from_numpy(pano_depth).unsqueeze(0).unsqueeze(0)
|
||||
pano_depth = torch.nn.functional.interpolate(t, size=(H, W), mode="bilinear",
|
||||
align_corners=False).squeeze().numpy().astype(np.float32)
|
||||
t = torch.from_numpy(pano_mask.astype(np.uint8)).unsqueeze(0).unsqueeze(0).float()
|
||||
pano_mask = (torch.nn.functional.interpolate(t, size=(H, W), mode="nearest").squeeze().numpy() > 0)
|
||||
|
||||
# Pixels uncovered by any view's predicted foreground are unconstrained in the lsmr solve
|
||||
# and stay at log_depth=0 (depth=1) -- without this push-out they form a sphere shell
|
||||
# woven through the foreground; here we lift them to a far skybox radius instead.
|
||||
if pano_mask.any() and not pano_mask.all():
|
||||
far = float(np.quantile(pano_depth[pano_mask], 0.95)) * 5.0
|
||||
pano_depth = np.where(pano_mask, pano_depth, far).astype(np.float32)
|
||||
|
||||
uv = _uv_grid(H, W)
|
||||
directions = spherical_uv_to_directions(uv)
|
||||
points_np = directions * pano_depth[..., None]
|
||||
|
||||
points = torch.from_numpy(points_np).unsqueeze(0).float()
|
||||
depth = torch.from_numpy(pano_depth).unsqueeze(0).float()
|
||||
mask = torch.from_numpy(pano_mask).unsqueeze(0)
|
||||
|
||||
# Points stay in MoGe spherical coords; MoGePointMapToMesh applies the spherical->glTF rotation
|
||||
# after triangulation -- rotating before would scramble the rtol depth-edge check.
|
||||
geometry = _MoGeGeometryPayload(
|
||||
points=points, depth=depth, intrinsics=None, mask=mask, normal=None,
|
||||
image=image.detach().cpu(),
|
||||
)
|
||||
return io.NodeOutput(geometry)
|
||||
|
||||
|
||||
class MoGeInference(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="MoGeInference",
|
||||
display_name="MoGe Inference",
|
||||
category="image/geometry",
|
||||
inputs=[
|
||||
MoGeModelType.Input("moge_model"),
|
||||
io.Image.Input("image"),
|
||||
io.Int.Input("resolution_level", default=9, min=0, max=9,
|
||||
tooltip="0 = fastest, 9 = most detail."),
|
||||
io.Float.Input("fov_x_degrees", default=0.0, min=0.0, max=170.0, step=0.1,
|
||||
tooltip="Override horizontal FoV. 0 = auto."),
|
||||
io.Int.Input("batch_size", default=4, min=1, max=64,
|
||||
tooltip="Images per inference call. Lower if you OOM on a long video / image set."),
|
||||
io.Boolean.Input("force_projection", default=True),
|
||||
io.Boolean.Input("apply_mask", default=True,
|
||||
tooltip="Set masked-out points/depth to inf."),
|
||||
],
|
||||
outputs=[MoGeGeometry.Output(display_name="geometry")],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, moge_model, image, resolution_level, fov_x_degrees,
|
||||
batch_size, force_projection, apply_mask) -> io.NodeOutput:
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
bchw = image.movedim(-1, -3).contiguous()
|
||||
B = bchw.shape[0]
|
||||
fov = None if fov_x_degrees <= 0 else float(fov_x_degrees)
|
||||
|
||||
pbar = comfy.utils.ProgressBar(B)
|
||||
chunks: list[dict] = []
|
||||
with tqdm(total=B, desc="MoGe inference") as tq:
|
||||
for i in range(0, B, batch_size):
|
||||
chunk = bchw[i:i + batch_size]
|
||||
chunks.append(moge_model.infer(chunk, resolution_level=resolution_level, fov_x=fov,
|
||||
force_projection=force_projection, apply_mask=apply_mask))
|
||||
pbar.update_absolute(min(i + batch_size, B))
|
||||
tq.update(chunk.shape[0])
|
||||
|
||||
def stack(field):
|
||||
vals = [c[field] for c in chunks if field in c]
|
||||
return torch.cat(vals, dim=0) if vals else None
|
||||
|
||||
geometry = _MoGeGeometryPayload(
|
||||
points=stack("points"),
|
||||
depth=stack("depth"),
|
||||
intrinsics=stack("intrinsics"),
|
||||
mask=stack("mask"),
|
||||
normal=stack("normal"),
|
||||
image=image.detach().cpu(),
|
||||
)
|
||||
return io.NodeOutput(geometry)
|
||||
|
||||
|
||||
_RENDER_MODES = ["depth", "depth_colored", "normal", "normal_screen", "mask"]
|
||||
|
||||
|
||||
class MoGeRender(io.ComfyNode):
|
||||
"""Render a visualization or mask from a MOGE_GEOMETRY packet."""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="MoGeRender",
|
||||
display_name="MoGe Render",
|
||||
category="image/geometry",
|
||||
inputs=[
|
||||
MoGeGeometry.Input("geometry"),
|
||||
io.Combo.Input("output", options=_RENDER_MODES, default="depth_colored"),
|
||||
],
|
||||
outputs=[io.Image.Output()],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, geometry, output) -> io.NodeOutput:
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
# Pick the input tensor for the chosen mode and validate availability.
|
||||
if output in ("depth", "depth_colored", "normal_screen"):
|
||||
if geometry.depth is None:
|
||||
raise ValueError("MoGeGeometry has no depth output.")
|
||||
src = geometry.depth
|
||||
elif output == "normal":
|
||||
if geometry.normal is not None:
|
||||
src = geometry.normal
|
||||
elif geometry.points is not None:
|
||||
src = geometry.points
|
||||
else:
|
||||
raise ValueError("MoGeGeometry has neither normals nor points to derive normals from.")
|
||||
elif output == "mask":
|
||||
if geometry.mask is None:
|
||||
raise ValueError("MoGeGeometry has no mask output.")
|
||||
src = geometry.mask
|
||||
else:
|
||||
raise ValueError(f"Unknown output mode: {output}")
|
||||
|
||||
import comfy.model_management as cmm
|
||||
|
||||
B = src.shape[0]
|
||||
pbar = comfy.utils.ProgressBar(B)
|
||||
out: list[torch.Tensor] = []
|
||||
with tqdm(total=B, desc=f"MoGe render: {output}") as tq:
|
||||
for i in range(B):
|
||||
slc = src[i:i + 1].float()
|
||||
if output in ("depth", "depth_colored"):
|
||||
d = _normalize_disparity(slc)
|
||||
out.append(_turbo(d) if output == "depth_colored"
|
||||
else d.unsqueeze(-1).expand(*d.shape, 3).contiguous())
|
||||
elif output == "normal":
|
||||
n = slc if geometry.normal is not None else _normals_from_points(slc)
|
||||
out.append((n * 0.5 + 0.5).clamp(0.0, 1.0))
|
||||
elif output == "normal_screen":
|
||||
n = _screen_normals_from_depth(slc)
|
||||
out.append((n * 0.5 + 0.5).clamp(0.0, 1.0))
|
||||
elif output == "mask":
|
||||
out.append(slc.unsqueeze(-1).expand(*slc.shape, 3).contiguous())
|
||||
pbar.update_absolute(i + 1)
|
||||
tq.update(1)
|
||||
result = torch.cat(out, dim=0).to(device=cmm.intermediate_device(), dtype=cmm.intermediate_dtype())
|
||||
return io.NodeOutput(result)
|
||||
|
||||
|
||||
class MoGePointMapToMesh(io.ComfyNode):
|
||||
"""Triangulate one image of a MoGe point map into a Types.MESH (UVs + texture)."""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="MoGePointMapToMesh",
|
||||
display_name="MoGe Point Map to Mesh",
|
||||
category="3d",
|
||||
inputs=[
|
||||
MoGeGeometry.Input("geometry"),
|
||||
io.Int.Input("batch_index", default=0, min=0, max=64,
|
||||
tooltip="Which image of a batched MoGe geometry to mesh. Per-image vertex counts "
|
||||
"differ, so batches can't be stacked into a single MESH."),
|
||||
io.Int.Input("decimation", default=1, min=1, max=8,
|
||||
tooltip="Vertex stride; 1 = full resolution."),
|
||||
io.Float.Input("discontinuity_threshold", default=0.04, min=0.0, max=1.0, step=0.01,
|
||||
tooltip="Drop pixels whose 3x3 depth span exceeds this fraction. 0 = off."),
|
||||
io.Boolean.Input("texture", default=True,
|
||||
tooltip="Carry the source image through as the baseColor texture."),
|
||||
],
|
||||
outputs=[io.Mesh.Output()],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, geometry, batch_index, decimation, discontinuity_threshold, texture) -> io.NodeOutput:
|
||||
if geometry.points is None:
|
||||
raise ValueError("MoGeGeometry has no points output.")
|
||||
B = geometry.points.shape[0]
|
||||
if batch_index >= B:
|
||||
raise ValueError(f"batch_index {batch_index} out of range; geometry has batch size {B}.")
|
||||
|
||||
# Pass geometry.depth so the rtol edge check sees radial depth -- for panoramas
|
||||
# points[..., 2] = cos(phi)*r goes negative below the equator and the rtol clamp would drop the bottom half.
|
||||
edge_depth = geometry.depth[batch_index] if geometry.depth is not None else None
|
||||
verts, faces, uvs = triangulate_grid_mesh(
|
||||
geometry.points[batch_index], decimation=decimation,
|
||||
discontinuity_threshold=discontinuity_threshold, depth=edge_depth,
|
||||
)
|
||||
if verts.shape[0] == 0 or faces.shape[0] == 0:
|
||||
raise ValueError("MoGe produced an empty mesh; try discontinuity_threshold=0 or apply_mask=False.")
|
||||
|
||||
if geometry.intrinsics is None:
|
||||
# Panorama: rotate MoGe spherical (Z up) -> glTF (Y up, Z back). Pure rotation
|
||||
# preserves the natural inward winding (correct for inside-the-sphere viewing).
|
||||
verts = verts[:, [1, 2, 0]].contiguous()
|
||||
else:
|
||||
# Perspective MoGe (X right, Y down, Z forward) -> glTF; face flip keeps winding CCW after the Y/Z flip.
|
||||
verts = verts * torch.tensor([1.0, -1.0, -1.0], dtype=verts.dtype)
|
||||
faces = faces[:, [0, 2, 1]].contiguous()
|
||||
|
||||
tex = geometry.image[batch_index:batch_index + 1] if texture and geometry.image is not None else None
|
||||
mesh = Types.MESH(
|
||||
vertices=verts.unsqueeze(0),
|
||||
faces=faces.unsqueeze(0),
|
||||
uvs=uvs.unsqueeze(0),
|
||||
texture=tex,
|
||||
)
|
||||
return io.NodeOutput(mesh)
|
||||
|
||||
|
||||
class MoGeExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||
return [LoadMoGeModel, MoGeInference, MoGePanoramaInference, MoGeRender, MoGePointMapToMesh]
|
||||
|
||||
|
||||
async def comfy_entrypoint() -> MoGeExtension:
|
||||
return MoGeExtension()
|
||||
380
comfy_extras/nodes_save_3d.py
Normal file
380
comfy_extras/nodes_save_3d.py
Normal file
@ -0,0 +1,380 @@
|
||||
"""Save-side 3D nodes: mesh packing/slicing helpers + GLB writer + SaveGLB node.
|
||||
|
||||
Pairs with nodes_load_3d.py (load-side counterpart).
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import struct
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from typing_extensions import override
|
||||
|
||||
import folder_paths
|
||||
from comfy.cli_args import args
|
||||
from comfy_api.latest import ComfyExtension, IO, Types
|
||||
|
||||
|
||||
def pack_variable_mesh_batch(vertices, faces, colors=None, uvs=None):
|
||||
# Pack lists of (Nᵢ, *) vertex/face/color/uv tensors into padded batched tensors,
|
||||
# stashing per-item lengths as runtime attrs so consumers can recover the real slice.
|
||||
# uvs are 1:1 with vertices, so they're padded to max_vertices and read with vertex_counts.
|
||||
batch_size = len(vertices)
|
||||
max_vertices = max(v.shape[0] for v in vertices)
|
||||
max_faces = max(f.shape[0] for f in faces)
|
||||
|
||||
packed_vertices = vertices[0].new_zeros((batch_size, max_vertices, vertices[0].shape[1]))
|
||||
packed_faces = faces[0].new_zeros((batch_size, max_faces, faces[0].shape[1]))
|
||||
vertex_counts = torch.tensor([v.shape[0] for v in vertices], device=vertices[0].device, dtype=torch.int64)
|
||||
face_counts = torch.tensor([f.shape[0] for f in faces], device=faces[0].device, dtype=torch.int64)
|
||||
|
||||
for i, (v, f) in enumerate(zip(vertices, faces)):
|
||||
packed_vertices[i, :v.shape[0]] = v
|
||||
packed_faces[i, :f.shape[0]] = f
|
||||
|
||||
packed_colors = None
|
||||
color_counts = None
|
||||
if colors is not None:
|
||||
max_colors = max(c.shape[0] for c in colors)
|
||||
packed_colors = colors[0].new_zeros((batch_size, max_colors, colors[0].shape[1]))
|
||||
color_counts = torch.tensor([c.shape[0] for c in colors], device=colors[0].device, dtype=torch.int64)
|
||||
for i, c in enumerate(colors):
|
||||
packed_colors[i, :c.shape[0]] = c
|
||||
|
||||
packed_uvs = None
|
||||
if uvs is not None:
|
||||
packed_uvs = uvs[0].new_zeros((batch_size, max_vertices, uvs[0].shape[1]))
|
||||
for i, u in enumerate(uvs):
|
||||
packed_uvs[i, :u.shape[0]] = u
|
||||
|
||||
mesh = Types.MESH(packed_vertices, packed_faces, uvs=packed_uvs, vertex_colors=packed_colors)
|
||||
mesh.vertex_counts = vertex_counts
|
||||
mesh.face_counts = face_counts
|
||||
if color_counts is not None:
|
||||
mesh.color_counts = color_counts
|
||||
return mesh
|
||||
|
||||
|
||||
def get_mesh_batch_item(mesh, index):
|
||||
# Returns (vertices, faces, colors) for batch index, slicing to real lengths
|
||||
# if pack_variable_mesh_batch added per-item counts.
|
||||
if hasattr(mesh, "vertex_counts"):
|
||||
vertex_count = int(mesh.vertex_counts[index].item())
|
||||
face_count = int(mesh.face_counts[index].item())
|
||||
vertices = mesh.vertices[index, :vertex_count]
|
||||
faces = mesh.faces[index, :face_count]
|
||||
colors = None
|
||||
v_colors = getattr(mesh, "vertex_colors", None)
|
||||
if v_colors is not None:
|
||||
if hasattr(mesh, "color_counts"):
|
||||
color_count = int(mesh.color_counts[index].item())
|
||||
colors = v_colors[index, :color_count]
|
||||
else:
|
||||
colors = v_colors[index, :vertex_count]
|
||||
return vertices, faces, colors
|
||||
|
||||
colors = None
|
||||
v_colors = getattr(mesh, "vertex_colors", None)
|
||||
if v_colors is not None:
|
||||
colors = v_colors[index]
|
||||
return mesh.vertices[index], mesh.faces[index], colors
|
||||
|
||||
|
||||
def save_glb(vertices, faces, filepath, metadata=None,
|
||||
uvs=None, vertex_colors=None, texture_image=None):
|
||||
"""
|
||||
Save PyTorch tensor vertices and faces as a GLB file without external dependencies.
|
||||
|
||||
Parameters:
|
||||
vertices: torch.Tensor of shape (N, 3) - The vertex coordinates
|
||||
faces: torch.Tensor of shape (M, 3) - The face indices (triangle faces)
|
||||
filepath: str - Output filepath (should end with .glb)
|
||||
metadata: dict - Optional asset.extras metadata
|
||||
uvs: torch.Tensor of shape (N, 2) - Optional per-vertex texture coordinates
|
||||
vertex_colors: torch.Tensor of shape (N, 3) or (N, 4) - Optional per-vertex colors in [0, 1]
|
||||
texture_image: PIL.Image - Optional baseColor texture, embedded as PNG
|
||||
"""
|
||||
|
||||
# Convert tensors to numpy arrays
|
||||
vertices_np = vertices.cpu().numpy().astype(np.float32)
|
||||
faces_np = faces.cpu().numpy().astype(np.uint32)
|
||||
uvs_np = uvs.cpu().numpy().astype(np.float32) if uvs is not None else None
|
||||
colors_np = vertex_colors.cpu().numpy().astype(np.float32) if vertex_colors is not None else None
|
||||
if colors_np is not None:
|
||||
colors_np = np.clip(colors_np, 0.0, 1.0)
|
||||
texture_png_bytes = None
|
||||
if texture_image is not None:
|
||||
import io as _io
|
||||
buf = _io.BytesIO()
|
||||
texture_image.save(buf, format="PNG")
|
||||
texture_png_bytes = buf.getvalue()
|
||||
|
||||
vertices_buffer = vertices_np.tobytes()
|
||||
indices_buffer = faces_np.tobytes()
|
||||
uvs_buffer = uvs_np.tobytes() if uvs_np is not None else b""
|
||||
colors_buffer = colors_np.tobytes() if colors_np is not None else b""
|
||||
texture_buffer = texture_png_bytes if texture_png_bytes is not None else b""
|
||||
|
||||
def pad_to_4_bytes(buffer):
|
||||
padding_length = (4 - (len(buffer) % 4)) % 4
|
||||
return buffer + b'\x00' * padding_length
|
||||
|
||||
vertices_buffer_padded = pad_to_4_bytes(vertices_buffer)
|
||||
indices_buffer_padded = pad_to_4_bytes(indices_buffer)
|
||||
uvs_buffer_padded = pad_to_4_bytes(uvs_buffer)
|
||||
colors_buffer_padded = pad_to_4_bytes(colors_buffer)
|
||||
texture_buffer_padded = pad_to_4_bytes(texture_buffer)
|
||||
|
||||
buffer_data = (vertices_buffer_padded + indices_buffer_padded
|
||||
+ uvs_buffer_padded + colors_buffer_padded + texture_buffer_padded)
|
||||
|
||||
vertices_byte_length = len(vertices_buffer)
|
||||
vertices_byte_offset = 0
|
||||
indices_byte_length = len(indices_buffer)
|
||||
indices_byte_offset = len(vertices_buffer_padded)
|
||||
uvs_byte_offset = indices_byte_offset + len(indices_buffer_padded)
|
||||
colors_byte_offset = uvs_byte_offset + len(uvs_buffer_padded)
|
||||
texture_byte_offset = colors_byte_offset + len(colors_buffer_padded)
|
||||
|
||||
buffer_views = [
|
||||
{
|
||||
"buffer": 0,
|
||||
"byteOffset": vertices_byte_offset,
|
||||
"byteLength": vertices_byte_length,
|
||||
"target": 34962 # ARRAY_BUFFER
|
||||
},
|
||||
{
|
||||
"buffer": 0,
|
||||
"byteOffset": indices_byte_offset,
|
||||
"byteLength": indices_byte_length,
|
||||
"target": 34963 # ELEMENT_ARRAY_BUFFER
|
||||
}
|
||||
]
|
||||
accessors = [
|
||||
{
|
||||
"bufferView": 0,
|
||||
"byteOffset": 0,
|
||||
"componentType": 5126, # FLOAT
|
||||
"count": len(vertices_np),
|
||||
"type": "VEC3",
|
||||
"max": vertices_np.max(axis=0).tolist(),
|
||||
"min": vertices_np.min(axis=0).tolist()
|
||||
},
|
||||
{
|
||||
"bufferView": 1,
|
||||
"byteOffset": 0,
|
||||
"componentType": 5125, # UNSIGNED_INT
|
||||
"count": faces_np.size,
|
||||
"type": "SCALAR"
|
||||
}
|
||||
]
|
||||
primitive_attributes = {"POSITION": 0}
|
||||
|
||||
if uvs_np is not None and len(uvs_np) > 0:
|
||||
buffer_views.append({
|
||||
"buffer": 0,
|
||||
"byteOffset": uvs_byte_offset,
|
||||
"byteLength": len(uvs_buffer),
|
||||
"target": 34962
|
||||
})
|
||||
accessor_idx = len(accessors)
|
||||
accessors.append({
|
||||
"bufferView": len(buffer_views) - 1,
|
||||
"byteOffset": 0,
|
||||
"componentType": 5126,
|
||||
"count": len(uvs_np),
|
||||
"type": "VEC2",
|
||||
})
|
||||
primitive_attributes["TEXCOORD_0"] = accessor_idx
|
||||
|
||||
if colors_np is not None and len(colors_np) > 0:
|
||||
buffer_views.append({
|
||||
"buffer": 0,
|
||||
"byteOffset": colors_byte_offset,
|
||||
"byteLength": len(colors_buffer),
|
||||
"target": 34962
|
||||
})
|
||||
accessor_idx = len(accessors)
|
||||
accessors.append({
|
||||
"bufferView": len(buffer_views) - 1,
|
||||
"byteOffset": 0,
|
||||
"componentType": 5126,
|
||||
"count": len(colors_np),
|
||||
"type": "VEC3" if colors_np.shape[1] == 3 else "VEC4",
|
||||
})
|
||||
primitive_attributes["COLOR_0"] = accessor_idx
|
||||
|
||||
primitive = {
|
||||
"attributes": primitive_attributes,
|
||||
"indices": 1,
|
||||
"mode": 4 # TRIANGLES
|
||||
}
|
||||
|
||||
images = []
|
||||
textures = []
|
||||
samplers = []
|
||||
materials = []
|
||||
if texture_png_bytes is not None and "TEXCOORD_0" in primitive_attributes:
|
||||
buffer_views.append({
|
||||
"buffer": 0,
|
||||
"byteOffset": texture_byte_offset,
|
||||
"byteLength": len(texture_buffer),
|
||||
})
|
||||
images.append({"bufferView": len(buffer_views) - 1, "mimeType": "image/png"})
|
||||
samplers.append({"magFilter": 9729, "minFilter": 9729, "wrapS": 33071, "wrapT": 33071})
|
||||
textures.append({"source": 0, "sampler": 0})
|
||||
materials.append({
|
||||
"pbrMetallicRoughness": {
|
||||
"baseColorTexture": {"index": 0, "texCoord": 0},
|
||||
"metallicFactor": 0.0,
|
||||
"roughnessFactor": 1.0,
|
||||
},
|
||||
"doubleSided": True,
|
||||
})
|
||||
primitive["material"] = 0
|
||||
|
||||
gltf = {
|
||||
"asset": {"version": "2.0", "generator": "ComfyUI"},
|
||||
"buffers": [{"byteLength": len(buffer_data)}],
|
||||
"bufferViews": buffer_views,
|
||||
"accessors": accessors,
|
||||
"meshes": [{"primitives": [primitive]}],
|
||||
"nodes": [{"mesh": 0}],
|
||||
"scenes": [{"nodes": [0]}],
|
||||
"scene": 0,
|
||||
}
|
||||
if images:
|
||||
gltf["images"] = images
|
||||
if samplers:
|
||||
gltf["samplers"] = samplers
|
||||
if textures:
|
||||
gltf["textures"] = textures
|
||||
if materials:
|
||||
gltf["materials"] = materials
|
||||
|
||||
if metadata is not None:
|
||||
gltf["asset"]["extras"] = metadata
|
||||
|
||||
# Convert the JSON to bytes
|
||||
gltf_json = json.dumps(gltf).encode('utf8')
|
||||
|
||||
def pad_json_to_4_bytes(buffer):
|
||||
padding_length = (4 - (len(buffer) % 4)) % 4
|
||||
return buffer + b' ' * padding_length
|
||||
|
||||
gltf_json_padded = pad_json_to_4_bytes(gltf_json)
|
||||
|
||||
# Create the GLB header
|
||||
# Magic glTF
|
||||
glb_header = struct.pack('<4sII', b'glTF', 2, 12 + 8 + len(gltf_json_padded) + 8 + len(buffer_data))
|
||||
|
||||
# Create JSON chunk header (chunk type 0)
|
||||
json_chunk_header = struct.pack('<II', len(gltf_json_padded), 0x4E4F534A) # "JSON" in little endian
|
||||
|
||||
# Create BIN chunk header (chunk type 1)
|
||||
bin_chunk_header = struct.pack('<II', len(buffer_data), 0x004E4942) # "BIN\0" in little endian
|
||||
|
||||
# Write the GLB file
|
||||
with open(filepath, 'wb') as f:
|
||||
f.write(glb_header)
|
||||
f.write(json_chunk_header)
|
||||
f.write(gltf_json_padded)
|
||||
f.write(bin_chunk_header)
|
||||
f.write(buffer_data)
|
||||
|
||||
return filepath
|
||||
|
||||
|
||||
class SaveGLB(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="SaveGLB",
|
||||
display_name="Save 3D Model",
|
||||
search_aliases=["export 3d model", "save mesh"],
|
||||
category="3d",
|
||||
essentials_category="Basics",
|
||||
is_output_node=True,
|
||||
inputs=[
|
||||
IO.MultiType.Input(
|
||||
IO.Mesh.Input("mesh"),
|
||||
types=[
|
||||
IO.File3DGLB,
|
||||
IO.File3DGLTF,
|
||||
IO.File3DOBJ,
|
||||
IO.File3DFBX,
|
||||
IO.File3DSTL,
|
||||
IO.File3DUSDZ,
|
||||
IO.File3DAny,
|
||||
],
|
||||
tooltip="Mesh or 3D file to save",
|
||||
),
|
||||
IO.String.Input("filename_prefix", default="3d/ComfyUI"),
|
||||
],
|
||||
hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo]
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, mesh: Types.MESH | Types.File3D, filename_prefix: str) -> IO.NodeOutput:
|
||||
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, folder_paths.get_output_directory())
|
||||
results = []
|
||||
|
||||
metadata = {}
|
||||
if not args.disable_metadata:
|
||||
if cls.hidden.prompt is not None:
|
||||
metadata["prompt"] = json.dumps(cls.hidden.prompt)
|
||||
if cls.hidden.extra_pnginfo is not None:
|
||||
for x in cls.hidden.extra_pnginfo:
|
||||
metadata[x] = json.dumps(cls.hidden.extra_pnginfo[x])
|
||||
|
||||
if isinstance(mesh, Types.File3D):
|
||||
# Handle File3D input - save BytesIO data to output folder
|
||||
ext = mesh.format or "glb"
|
||||
f = f"{filename}_{counter:05}_.{ext}"
|
||||
mesh.save_to(os.path.join(full_output_folder, f))
|
||||
results.append({
|
||||
"filename": f,
|
||||
"subfolder": subfolder,
|
||||
"type": "output"
|
||||
})
|
||||
else:
|
||||
# Handle Mesh input - save vertices and faces as GLB; carry optional UVs / colors / texture.
|
||||
uvs_b = getattr(mesh, "uvs", None)
|
||||
texture_b = getattr(mesh, "texture", None)
|
||||
for i in range(mesh.vertices.shape[0]):
|
||||
vertices_i, faces_i, v_colors = get_mesh_batch_item(mesh, i)
|
||||
if vertices_i.shape[0] == 0 or faces_i.shape[0] == 0:
|
||||
logging.warning(f"SaveGLB: skipping empty mesh at batch index {i}")
|
||||
continue
|
||||
uvs_i = None
|
||||
if uvs_b is not None:
|
||||
uvs_i = uvs_b[i, :vertices_i.shape[0]] if hasattr(mesh, "vertex_counts") else uvs_b[i]
|
||||
tex_img = None
|
||||
if texture_b is not None:
|
||||
from PIL import Image
|
||||
arr = (texture_b[i].clamp(0.0, 1.0).cpu().numpy() * 255).astype(np.uint8)
|
||||
tex_img = Image.fromarray(arr, mode="RGB")
|
||||
f = f"{filename}_{counter:05}_.glb"
|
||||
save_glb(vertices_i, faces_i, os.path.join(full_output_folder, f), metadata,
|
||||
uvs=uvs_i,
|
||||
vertex_colors=v_colors,
|
||||
texture_image=tex_img)
|
||||
results.append({
|
||||
"filename": f,
|
||||
"subfolder": subfolder,
|
||||
"type": "output"
|
||||
})
|
||||
counter += 1
|
||||
return IO.NodeOutput(ui={"3d": results})
|
||||
|
||||
|
||||
class Save3DExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||
return [SaveGLB]
|
||||
|
||||
|
||||
async def comfy_entrypoint() -> Save3DExtension:
|
||||
return Save3DExtension()
|
||||
@ -56,6 +56,8 @@ folder_names_and_paths["background_removal"] = ([os.path.join(models_dir, "backg
|
||||
|
||||
folder_names_and_paths["frame_interpolation"] = ([os.path.join(models_dir, "frame_interpolation")], supported_pt_extensions)
|
||||
|
||||
folder_names_and_paths["moge"] = ([os.path.join(models_dir, "moge")], supported_pt_extensions)
|
||||
|
||||
folder_names_and_paths["optical_flow"] = ([os.path.join(models_dir, "optical_flow")], supported_pt_extensions)
|
||||
|
||||
output_directory = os.path.join(base_path, "output")
|
||||
|
||||
Loading…
Reference in New Issue
Block a user