From 919a74f8191ae5761781289a43d78825ce38f44b Mon Sep 17 00:00:00 2001 From: kijai <40791699+kijai@users.noreply.github.com> Date: Tue, 12 May 2026 16:09:24 +0300 Subject: [PATCH 1/4] initial MoGe support --- comfy/image_encoders/dino2.py | 45 ++- comfy/ldm/moge/__init__.py | 0 comfy/ldm/moge/geometry.py | 200 ++++++++++ comfy/ldm/moge/model.py | 349 ++++++++++++++++++ comfy/ldm/moge/modules.py | 204 +++++++++++ comfy/ldm/moge/panorama.py | 315 ++++++++++++++++ comfy/ldm/moge/state_dict.py | 94 +++++ comfy/moge.py | 163 +++++++++ comfy_api/latest/_util/geometry_types.py | 12 +- comfy_extras/nodes_hunyuan3d.py | 210 +---------- comfy_extras/nodes_moge.py | 445 +++++++++++++++++++++++ comfy_extras/nodes_save_3d.py | 380 +++++++++++++++++++ folder_paths.py | 2 + nodes.py | 2 + 14 files changed, 2214 insertions(+), 207 deletions(-) create mode 100644 comfy/ldm/moge/__init__.py create mode 100644 comfy/ldm/moge/geometry.py create mode 100644 comfy/ldm/moge/model.py create mode 100644 comfy/ldm/moge/modules.py create mode 100644 comfy/ldm/moge/panorama.py create mode 100644 comfy/ldm/moge/state_dict.py create mode 100644 comfy/moge.py create mode 100644 comfy_extras/nodes_moge.py create mode 100644 comfy_extras/nodes_save_3d.py diff --git a/comfy/image_encoders/dino2.py b/comfy/image_encoders/dino2.py index 9b6dace9d..ee86f8309 100644 --- a/comfy/image_encoders/dino2.py +++ b/comfy/image_encoders/dino2.py @@ -106,6 +106,7 @@ class Dino2Encoder(torch.nn.Module): class Dino2PatchEmbeddings(torch.nn.Module): def __init__(self, dim, num_channels=3, patch_size=14, image_size=518, dtype=None, device=None, operations=None): super().__init__() + self.patch_size = patch_size self.projection = operations.Conv2d( in_channels=num_channels, out_channels=dim, @@ -125,17 +126,37 @@ class Dino2Embeddings(torch.nn.Module): super().__init__() patch_size = 14 image_size = 518 + self.patch_size = patch_size self.patch_embeddings = Dino2PatchEmbeddings(dim, patch_size=patch_size, image_size=image_size, dtype=dtype, device=device, operations=operations) self.position_embeddings = torch.nn.Parameter(torch.empty(1, (image_size // patch_size) ** 2 + 1, dim, dtype=dtype, device=device)) - self.cls_token = torch.nn.Parameter(torch.empty(1, 1, dim, dtype=dtype, device=device)) + self.cls_token = torch.nn.Parameter(torch.empty(1, 1, dim, dtype=dtype, device=device)) # mask_token is a pre-training param, kept only so strict loading accepts the key. self.mask_token = torch.nn.Parameter(torch.empty(1, dim, dtype=dtype, device=device)) + def interpolate_pos_encoding(self, x, h_pixels, w_pixels): + pos_embed = comfy.model_management.cast_to_device(self.position_embeddings, x.device, torch.float32) + + class_pos = pos_embed[:, 0:1] + patch_pos = pos_embed[:, 1:] + N = patch_pos.shape[1] + M = int(N ** 0.5) + h0 = h_pixels // self.patch_size + w0 = w_pixels // self.patch_size + scale_factor = ((h0 + 0.1) / M, (w0 + 0.1) / M) # +0.1 matches upstream DINOv2's FP-rounding workaround so the interpolate output size lands on (h0, w0). + + patch_pos = patch_pos.reshape(1, M, M, -1).permute(0, 3, 1, 2) + patch_pos = torch.nn.functional.interpolate(patch_pos, scale_factor=scale_factor, mode="bicubic", antialias=False) + patch_pos = patch_pos.permute(0, 2, 3, 1).flatten(1, 2) + return torch.cat((class_pos, patch_pos), dim=1).to(x.dtype) + def forward(self, pixel_values): x = self.patch_embeddings(pixel_values) - # TODO: mask_token? x = torch.cat((self.cls_token.to(device=x.device, dtype=x.dtype).expand(x.shape[0], -1, -1), x), dim=1) - x = x + comfy.model_management.cast_to_device(self.position_embeddings, x.device, x.dtype) + if x.shape[1] - 1 == self.position_embeddings.shape[1] - 1: + x = x + comfy.model_management.cast_to_device(self.position_embeddings, x.device, x.dtype) + else: + h, w = pixel_values.shape[-2:] + x = x + self.interpolate_pos_encoding(x, h, w) return x @@ -158,3 +179,21 @@ class Dinov2Model(torch.nn.Module): x = self.layernorm(x) pooled_output = x[:, 0, :] return x, i, pooled_output, None + + def get_intermediate_layers(self, pixel_values, indices, apply_norm=True): + x = self.embeddings(pixel_values) + optimized_attention = optimized_attention_for_device(x.device, False, small_input=True) + n_layers = len(self.encoder.layer) + resolved = [(i if i >= 0 else n_layers + i) for i in indices] + target = set(resolved) + max_idx = max(resolved) + n_skip = 1 # skip cls token + cache = {} + for i, layer in enumerate(self.encoder.layer): + x = layer(x, optimized_attention) + if i in target: + normed = self.layernorm(x) if apply_norm else x + cache[i] = (normed[:, n_skip:], normed[:, 0]) + if i >= max_idx: + break + return [cache[i] for i in resolved] diff --git a/comfy/ldm/moge/__init__.py b/comfy/ldm/moge/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/comfy/ldm/moge/geometry.py b/comfy/ldm/moge/geometry.py new file mode 100644 index 000000000..3174bd613 --- /dev/null +++ b/comfy/ldm/moge/geometry.py @@ -0,0 +1,200 @@ +"""Pure-torch + scipy geometry helpers for MoGe inference and mesh export.""" + +from __future__ import annotations + +from functools import partial +from typing import Optional, Tuple + +import numpy as np +import torch +import torch.nn.functional as F + +from scipy.optimize import least_squares + +def normalized_view_plane_uv(width: int, height: int, aspect_ratio: Optional[float] = None, + dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None) -> torch.Tensor: + """Normalized view-plane UV coordinates with corners at +/-(W, H)/diagonal.""" + if aspect_ratio is None: + aspect_ratio = width / height + span_x = aspect_ratio / (1 + aspect_ratio ** 2) ** 0.5 + span_y = 1.0 / (1 + aspect_ratio ** 2) ** 0.5 + u = torch.linspace(-span_x * (width - 1) / width, span_x * (width - 1) / width, width, dtype=dtype, device=device) + v = torch.linspace(-span_y * (height - 1) / height, span_y * (height - 1) / height, height, dtype=dtype, device=device) + u, v = torch.meshgrid(u, v, indexing="xy") + return torch.stack([u, v], dim=-1) + + +def intrinsics_from_focal_center(fx: torch.Tensor, fy: torch.Tensor, cx: torch.Tensor, cy: torch.Tensor) -> torch.Tensor: + """Assemble (..., 3, 3) intrinsics from broadcastable fx, fy, cx, cy.""" + fx, fy, cx, cy = [torch.as_tensor(v) for v in (fx, fy, cx, cy)] + fx, fy, cx, cy = torch.broadcast_tensors(fx, fy, cx, cy) + zero = torch.zeros_like(fx) + one = torch.ones_like(fx) + return torch.stack([ + torch.stack([fx, zero, cx], dim=-1), + torch.stack([zero, fy, cy], dim=-1), + torch.stack([zero, zero, one], dim=-1), + ], dim=-2) + + +def depth_map_to_point_map(depth: torch.Tensor, intrinsics: torch.Tensor) -> torch.Tensor: + """Back-project a (..., H, W) depth map through K^-1 to (..., H, W, 3) camera-space points. + + Intrinsics use normalized image coords (x in [0, 1] left->right, y in [0, 1] top->bottom). + """ + H, W = depth.shape[-2:] + device, dtype = depth.device, depth.dtype + u = (torch.arange(W, dtype=dtype, device=device) + 0.5) / W + v = (torch.arange(H, dtype=dtype, device=device) + 0.5) / H + grid_v, grid_u = torch.meshgrid(v, u, indexing="ij") + pix = torch.stack([grid_u, grid_v, torch.ones_like(grid_u)], dim=-1) + K_inv = torch.linalg.inv(intrinsics) + rays = torch.einsum("...ij,hwj->...hwi", K_inv, pix) + return rays * depth.unsqueeze(-1) + + +def _solve_optimal_focal_shift(uv: np.ndarray, xyz: np.ndarray) -> Tuple[float, float]: + uv = uv.reshape(-1, 2) + xy = xyz[..., :2].reshape(-1, 2) + z = xyz[..., 2].reshape(-1) + + def fn(uv_, xy_, z_, shift): + xy_proj = xy_ / (z_ + shift)[:, None] + f = (xy_proj * uv_).sum() / np.square(xy_proj).sum() + return (f * xy_proj - uv_).ravel() + + sol = least_squares(partial(fn, uv, xy, z), x0=0.0, ftol=1e-3, method="lm") + optim_shift = float(np.asarray(sol["x"]).squeeze()) + xy_proj = xy / (z + optim_shift)[:, None] + optim_focal = float((xy_proj * uv).sum() / np.square(xy_proj).sum()) + return optim_shift, optim_focal + + +def _solve_optimal_shift(uv: np.ndarray, xyz: np.ndarray, focal: float) -> float: + uv = uv.reshape(-1, 2) + xy = xyz[..., :2].reshape(-1, 2) + z = xyz[..., 2].reshape(-1) + + def fn(uv_, xy_, z_, shift): + xy_proj = xy_ / (z_ + shift)[:, None] + return (focal * xy_proj - uv_).ravel() + + sol = least_squares(partial(fn, uv, xy, z), x0=0.0, ftol=1e-3, method="lm") + return float(np.asarray(sol["x"]).squeeze()) + + +def recover_focal_shift(points: torch.Tensor, mask: Optional[torch.Tensor] = None, + focal: Optional[torch.Tensor] = None, downsample_size: Tuple[int, int] = (64, 64) + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Recover the focal length and z-shift that turn ``points`` into a metric point map. + + Optical center is at the image center; returned focal is relative to half the image diagonal. + Returns ``(focal, shift)`` on the same device/dtype as ``points``. + """ + shape = points.shape + H, W = shape[-3], shape[-2] + points_b = points.reshape(-1, H, W, 3) + mask_b = None if mask is None else mask.reshape(-1, H, W) + focal_b = None if focal is None else focal.reshape(-1) + + uv = normalized_view_plane_uv(W, H, dtype=points.dtype, device=points.device) + + points_lr = F.interpolate(points_b.permute(0, 3, 1, 2), downsample_size, mode="nearest").permute(0, 2, 3, 1) + uv_lr = F.interpolate(uv.unsqueeze(0).permute(0, 3, 1, 2), downsample_size, mode="nearest").squeeze(0).permute(1, 2, 0) + mask_lr = None + if mask_b is not None: + mask_lr = F.interpolate(mask_b.to(torch.float32).unsqueeze(1), downsample_size, mode="nearest").squeeze(1) > 0 + + uv_np = uv_lr.detach().cpu().numpy() + points_np = points_lr.detach().cpu().numpy() + mask_np = None if mask_lr is None else mask_lr.detach().cpu().numpy() + focal_np = None if focal_b is None else focal_b.detach().cpu().numpy() + + out_focal: list = [] + out_shift: list = [] + for i in range(points_b.shape[0]): + if mask_np is None: + xyz_i = points_np[i].reshape(-1, 3) + uv_i = uv_np.reshape(-1, 2) + else: + sel = mask_np[i] + if sel.sum() < 2: + out_focal.append(1.0) + out_shift.append(0.0) + continue + xyz_i = points_np[i][sel] + uv_i = uv_np[sel] + if focal_np is None: + shift_i, focal_i = _solve_optimal_focal_shift(uv_i, xyz_i) + out_focal.append(focal_i) + else: + shift_i = _solve_optimal_shift(uv_i, xyz_i, float(focal_np[i])) + out_shift.append(shift_i) + + shift_t = torch.tensor(out_shift, device=points.device, dtype=points.dtype).reshape(shape[:-3]) + if focal is None: + focal_t = torch.tensor(out_focal, device=points.device, dtype=points.dtype).reshape(shape[:-3]) + else: + focal_t = focal.reshape(shape[:-3]) + return focal_t, shift_t + + +def depth_map_edge(depth: torch.Tensor, atol: Optional[float] = None, rtol: Optional[float] = None, kernel_size: int = 3) -> torch.Tensor: + """Per-pixel boolean: True where the local depth window's max-min span exceeds atol or rtol*depth.""" + shape = depth.shape + d = depth.reshape(-1, 1, *shape[-2:]) + pad = kernel_size // 2 + diff = F.max_pool2d(d, kernel_size, stride=1, padding=pad) + F.max_pool2d(-d, kernel_size, stride=1, padding=pad) + edge = torch.zeros_like(d, dtype=torch.bool) + if atol is not None: + edge |= diff > atol + if rtol is not None: + edge |= (diff / d.clamp_min(1e-6)).nan_to_num_() > rtol + return edge.reshape(*shape) + + +def triangulate_grid_mesh(points: torch.Tensor, mask: Optional[torch.Tensor] = None, decimation: int = 1, discontinuity_threshold: float = 0.04, + depth: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Triangulate a (H, W, 3) point map into ``(vertices, faces, uvs)`` on CPU. + + Vertices: pixels with finite coords (passing optional ``mask``). Quads with four valid corners + become two triangles. ``depth`` overrides the scalar used for the rtol edge check; pass radial + depth for panoramas (the default ``points[..., 2]`` goes negative below the equator). + """ + points = points.detach().cpu() + finite = torch.isfinite(points).all(dim=-1) + if mask is None: + mask = finite + else: + mask = mask.detach().cpu().to(torch.bool) & finite + + if discontinuity_threshold > 0: + d = depth.detach().cpu() if depth is not None else points[..., 2] + # Replace inf with 0 so max-pool doesn't poison neighbourhoods (mask above already excludes those pixels). + d_finite = torch.where(finite, d, torch.zeros_like(d)) + edge = depth_map_edge(d_finite, rtol=discontinuity_threshold) + mask = mask & ~edge + + if decimation > 1: + points = points[::decimation, ::decimation].contiguous() + mask = mask[::decimation, ::decimation].contiguous() + H, W = points.shape[:2] + + flat_mask = mask.reshape(-1) + idx = torch.full((H * W,), -1, dtype=torch.long) + n_valid = int(flat_mask.sum().item()) + idx[flat_mask] = torch.arange(n_valid, dtype=torch.long) + idx = idx.reshape(H, W) + + vertices = points.reshape(-1, 3)[flat_mask].contiguous() + + yy, xx = torch.meshgrid(torch.arange(H), torch.arange(W), indexing="ij") + u = xx.float() / max(W - 1, 1) + v = yy.float() / max(H - 1, 1) + uvs = torch.stack([u, v], dim=-1).reshape(-1, 2)[flat_mask].contiguous() + + a, b, c, d = idx[:-1, :-1], idx[:-1, 1:], idx[1:, 1:], idx[1:, :-1] + quad_ok = (a >= 0) & (b >= 0) & (c >= 0) & (d >= 0) + a, b, c, d = a[quad_ok], b[quad_ok], c[quad_ok], d[quad_ok] + faces = torch.cat([torch.stack([a, b, c], dim=-1), torch.stack([a, c, d], dim=-1)], dim=0).contiguous() + return vertices, faces, uvs diff --git a/comfy/ldm/moge/model.py b/comfy/ldm/moge/model.py new file mode 100644 index 000000000..fe340f5e1 --- /dev/null +++ b/comfy/ldm/moge/model.py @@ -0,0 +1,349 @@ +"""MoGe v1 / v2 inference modules and a state-dict-driven builder. + +V1: DINOv2 backbone + multi-output head (points, mask). +V2: DINOv2 encoder + neck + per-output heads (points, mask, normal, optional metric-scale MLP). +""" + +from __future__ import annotations + +from numbers import Number +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F + +import comfy.ops +import comfy.model_management +import comfy.model_patcher + +from comfy.image_encoders.dino2 import Dinov2Model + +from .geometry import depth_map_to_point_map, intrinsics_from_focal_center, recover_focal_shift +from .modules import ConvStack, DINOv2Encoder, HeadV1, MLP, _view_plane_uv_grid + + +def _remap_points(points: torch.Tensor) -> torch.Tensor: + """Apply the ``exp`` remap: z -> exp(z), xy stays linear and gets scaled by the new z.""" + xy, z = points.split([2, 1], dim=-1) + z = torch.exp(z) + return torch.cat([xy * z, z], dim=-1) + + +def _detect_dinov2(sd: dict, prefix: str) -> Dict[str, Any]: + # All shipped MoGe checkpoints use plain DINOv2 + hidden = sd[prefix + "embeddings.cls_token"].shape[-1] + layer_prefix = prefix + "encoder.layer." + depth = 1 + max(int(k[len(layer_prefix):].split(".")[0]) for k in sd if k.startswith(layer_prefix)) + return { + "hidden_size": hidden, + "num_attention_heads": hidden // 64, + "num_hidden_layers": depth, + "layer_norm_eps": 1e-6, + "use_swiglu_ffn": False, + } + + +class MoGeModelV1(nn.Module): + """MoGe v1: DINOv2 backbone + HeadV1 (points, mask).""" + + image_mean: torch.Tensor + image_std: torch.Tensor + + intermediate_layers = 4 + num_tokens_range: Tuple[Number, Number] = (1200, 2500) + mask_threshold = 0.5 + + def __init__(self, backbone: Dict[str, Any], dim_upsample: List[int] = (256, 128, 128), + num_res_blocks: int = 1, dim_times_res_block_hidden: int = 1, + dtype=None, device=None, operations=comfy.ops.manual_cast): + super().__init__() + self.backbone = Dinov2Model(backbone, dtype, device, operations) + self.head = HeadV1(dim_in=backbone["hidden_size"], dim_upsample=list(dim_upsample), + num_res_blocks=num_res_blocks, dim_times_res_block_hidden=dim_times_res_block_hidden, + dtype=dtype, device=device, operations=operations) + self.register_buffer("image_mean", torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)) + self.register_buffer("image_std", torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)) + + def forward(self, image: torch.Tensor, num_tokens: int) -> Dict[str, torch.Tensor]: + H, W = image.shape[-2:] + resize = ((num_tokens * 14 ** 2) / (H * W)) ** 0.5 + rh, rw = int(H * resize), int(W * resize) + x = F.interpolate(image, (rh, rw), mode="bicubic", align_corners=False, antialias=True) + x = (x - self.image_mean) / self.image_std + x14 = F.interpolate(x, (rh // 14 * 14, rw // 14 * 14), mode="bilinear", align_corners=False, antialias=True) + + n_layers = len(self.backbone.encoder.layer) + indices = list(range(n_layers - self.intermediate_layers, n_layers)) + feats = self.backbone.get_intermediate_layers(x14, indices, apply_norm=True) + + points, mask = self.head(feats, x) + points = F.interpolate(points.float(), (H, W), mode="bilinear", align_corners=False) + points = _remap_points(points.permute(0, 2, 3, 1)) + + mask = F.interpolate(mask.float(), (H, W), mode="bilinear", align_corners=False).squeeze(1) + + return {"points": points, "mask": mask} + + @classmethod + def from_state_dict(cls, sd, dtype=None, device=None, operations=comfy.ops.manual_cast): + """Detect the v1 head config from sd, build a model, and load weights.""" + sd = _remap_state_dict(sd) + n_up = 1 + max(int(k.split(".")[2]) for k in sd if k.startswith("head.upsample_blocks.")) + dim_upsample = [sd[f"head.upsample_blocks.{i}.0.0.weight"].shape[1] for i in range(n_up)] + # Each upsample stage is Sequential[upsampler, *res_blocks]; count res blocks at level 0. + num_res_blocks = max({int(k.split(".")[3]) for k in sd if k.startswith("head.upsample_blocks.0.")}) + hidden_out = sd["head.upsample_blocks.0.1.layers.2.weight"].shape[0] + dim_times = max(hidden_out // dim_upsample[0], 1) + model = cls(backbone=_detect_dinov2(sd, prefix="backbone."), + dim_upsample=dim_upsample, num_res_blocks=num_res_blocks, dim_times_res_block_hidden=dim_times, + dtype=dtype, device=device, operations=operations) + model.load_state_dict(sd, strict=True) + return model + + +class MoGeModelV2(nn.Module): + """MoGe v2: DINOv2 encoder + neck + per-output heads (points/mask/normal/metric-scale).""" + + intermediate_layers = 4 + num_tokens_range: Tuple[Number, Number] = (1200, 3600) + + def __init__(self, + encoder: Dict[str, Any], + neck: Dict[str, Any], + points_head: Dict[str, Any], + mask_head: Dict[str, Any], + scale_head: Dict[str, Any], + normal_head: Optional[Dict[str, Any]] = None, + dtype=None, device=None, operations=comfy.ops.manual_cast): + super().__init__() + self.encoder = DINOv2Encoder(**encoder, dtype=dtype, device=device, operations=operations) + self.neck = ConvStack(**neck, dtype=dtype, device=device, operations=operations) + self.points_head = ConvStack(**points_head, dtype=dtype, device=device, operations=operations) + self.mask_head = ConvStack(**mask_head, dtype=dtype, device=device, operations=operations) + self.scale_head = MLP(**scale_head, dtype=dtype, device=device, operations=operations) + if normal_head is not None: + self.normal_head = ConvStack(**normal_head, dtype=dtype, device=device, operations=operations) + + def forward(self, image: torch.Tensor, num_tokens: int) -> Dict[str, torch.Tensor]: + B, _, H, W = image.shape + device, dtype = image.device, image.dtype + aspect_ratio = W / H + base_h = round((num_tokens / aspect_ratio) ** 0.5) + base_w = round((num_tokens * aspect_ratio) ** 0.5) + + feat_top, cls_token = self.encoder(image, base_h, base_w, return_class_token=True) + + # 5-level pyramid: feat at level 0 concatenated with UV, other levels UV-only. + levels = [_view_plane_uv_grid(B, base_h * (2 ** L), base_w * (2 ** L), aspect_ratio, dtype, device) + for L in range(5)] + levels[0] = torch.cat([feat_top, levels[0]], dim=1) + + feats = self.neck(levels) + + def _resize(v): + return F.interpolate(v, (H, W), mode="bilinear", align_corners=False) + + points = _remap_points(_resize(self.points_head(feats)[-1]).permute(0, 2, 3, 1)) + mask = _resize(self.mask_head(feats)[-1]).squeeze(1).sigmoid() + metric_scale = self.scale_head(cls_token).squeeze(1).exp() + + result = {"points": points, "mask": mask, "metric_scale": metric_scale} + if hasattr(self, "normal_head"): + normal = _resize(self.normal_head(feats)[-1]) + result["normal"] = F.normalize(normal.permute(0, 2, 3, 1), dim=-1) + return result + + @classmethod + def from_state_dict(cls, sd, dtype=None, device=None, operations=comfy.ops.manual_cast): + """Detect the v2 encoder/neck/heads config from ``sd``, build a model, and load weights.""" + sd = _remap_state_dict(sd) + backbone = _detect_dinov2(sd, prefix="encoder.backbone.") + depth = backbone["num_hidden_layers"] + n = cls.intermediate_layers + encoder = { + "backbone": backbone, + "intermediate_layers": [(depth // n) * (i + 1) - 1 for i in range(n)], + "dim_out": sd["encoder.output_projections.0.weight"].shape[0], + } + # scale_head is an MLP: Sequential of [Linear, ReLU, ..., Linear]; Linear weight is (out, in). + scale_idxs = sorted({int(k.split(".")[1]) for k in sd if k.startswith("scale_head.")}) + scale_first = sd[f"scale_head.{scale_idxs[0]}.weight"] + cfg: Dict[str, Any] = { + "encoder": encoder, + "neck": cls._detect_convstack(sd, "neck."), + "points_head": cls._detect_convstack(sd, "points_head."), + "mask_head": cls._detect_convstack(sd, "mask_head."), + "scale_head": {"dims": [scale_first.shape[1]] + [sd[f"scale_head.{i}.weight"].shape[0] for i in scale_idxs]}, + } + if any(k.startswith("normal_head.") for k in sd): + cfg["normal_head"] = cls._detect_convstack(sd, "normal_head.") + model = cls(**cfg, dtype=dtype, device=device, operations=operations) + model.load_state_dict(sd, strict=True) + return model + + @staticmethod + def _detect_convstack(sd: dict, prefix: str) -> Dict[str, Any]: + """Reconstruct a ConvStack config from the keys under ``prefix``""" + in_keys = [k for k in sd if k.startswith(f"{prefix}input_blocks.") and k.endswith(".weight")] + n = 1 + max(int(k[len(f"{prefix}input_blocks."):].split(".")[0]) for k in in_keys) + + in_shapes = [sd[f"{prefix}input_blocks.{i}.weight"].shape for i in range(n)] + has_out = lambda i: f"{prefix}output_blocks.{i}.weight" in sd + has_norm = f"{prefix}res_blocks.0.0.layers.0.weight" in sd + + def num_res_at(i): + rb_prefix = f"{prefix}res_blocks.{i}." + return len({int(k[len(rb_prefix):].split(".")[0]) for k in sd if k.startswith(rb_prefix)}) + + return { + "dim_in": [s[1] for s in in_shapes], + "dim_res_blocks": [s[0] for s in in_shapes], + "dim_out": [sd[f"{prefix}output_blocks.{i}.weight"].shape[0] if has_out(i) else None for i in range(n)], + "num_res_blocks": [num_res_at(i) for i in range(n)], + "resamplers": ["conv_transpose" if f"{prefix}resamplers.{i}.0.weight" in sd else "bilinear" + for i in range(n - 1)], + "res_block_in_norm": "layer_norm" if has_norm else "none", + "res_block_hidden_norm": "group_norm" if has_norm else "none", + } + + +# Translate the Meta-style DINOv2 keys MoGe ships to the naming ComfyUI DINOv2 port expects, +# and split each fused qkv tensor into Q/K/V. +_DINOV2_TOPLEVEL_RENAMES = { + "patch_embed.proj.weight": "embeddings.patch_embeddings.projection.weight", + "patch_embed.proj.bias": "embeddings.patch_embeddings.projection.bias", + "cls_token": "embeddings.cls_token", + "pos_embed": "embeddings.position_embeddings", + "register_tokens": "embeddings.register_tokens", + "mask_token": "embeddings.mask_token", + "norm.weight": "layernorm.weight", + "norm.bias": "layernorm.bias", +} +_DINOV2_BLOCK_RENAMES = [ + ("ls1.gamma", "layer_scale1.lambda1"), + ("ls2.gamma", "layer_scale2.lambda1"), + ("attn.proj.", "attention.output.dense."), + ("mlp.w12.", "mlp.weights_in."), + ("mlp.w3.", "mlp.weights_out."), +] + + +def _remap_state_dict(sd: dict) -> dict: + """Unwrap the upstream ``{"model": ..., "model_config": ...}`` envelope and remap DINOv2 keys""" + if "model" in sd and "model_config" in sd: + sd = sd["model"] + prefix = "encoder.backbone." if any(k.startswith("encoder.backbone.") for k in sd) else "backbone." + out: dict = {} + for k, v in sd.items(): + if not k.startswith(prefix): + out[k] = v + continue + rel = k[len(prefix):] + if rel in _DINOV2_TOPLEVEL_RENAMES: + out[prefix + _DINOV2_TOPLEVEL_RENAMES[rel]] = v + continue + if not rel.startswith("blocks."): + out[k] = v + continue + _, idx, sub = rel.split(".", 2) + if sub in ("attn.qkv.weight", "attn.qkv.bias"): + tail = sub.rsplit(".", 1)[1] + q, kw, vw = v.chunk(3, dim=0) + base = f"{prefix}encoder.layer.{idx}.attention.attention" + out[f"{base}.query.{tail}"] = q + out[f"{base}.key.{tail}"] = kw + out[f"{base}.value.{tail}"] = vw + continue + for old, new in _DINOV2_BLOCK_RENAMES: + sub = sub.replace(old, new) + out[f"{prefix}encoder.layer.{idx}.{sub}"] = v + return out + + +def build_from_state_dict(sd: dict, dtype=None, device=None, operations=comfy.ops.manual_cast) -> nn.Module: + """Dispatch to v1 or v2 based on the DINOv2 backbone prefix.""" + sd = _remap_state_dict(sd) + cls = MoGeModelV2 if any(k.startswith("encoder.backbone.") for k in sd) else MoGeModelV1 + return cls.from_state_dict(sd, dtype=dtype, device=device, operations=operations) + + +class MoGeModel: + """Loaded MoGe model + ComfyUI memory management.""" + + def __init__(self, state_dict: dict): + self.load_device = comfy.model_management.text_encoder_device() + offload_device = comfy.model_management.text_encoder_offload_device() + self.dtype = comfy.model_management.text_encoder_dtype(self.load_device) + + self.model = build_from_state_dict(state_dict, dtype=self.dtype, device=offload_device, operations=comfy.ops.manual_cast).eval() + self.patcher = comfy.model_patcher.CoreModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device) + self.version = "v2" if hasattr(self.model, "encoder") else "v1" + self.mask_threshold = float(getattr(self.model, "mask_threshold", 0.5)) + nt = getattr(self.model, "num_tokens_range", (1200, 2500 if self.version == "v1" else 3600)) + self.num_tokens_range = (int(nt[0]), int(nt[1])) + + def infer(self, image: torch.Tensor, num_tokens: Optional[int] = None, + resolution_level: int = 9, fov_x: Optional[Union[Number, torch.Tensor]] = None, + force_projection: bool = True, apply_mask: bool = True, + apply_metric_scale: bool = True + ) -> Dict[str, torch.Tensor]: + """Run a single MoGe forward + post-process pass. ``image`` is (B, 3, H, W) in [0, 1].""" + comfy.model_management.load_model_gpu(self.patcher) + image = image.to(device=self.load_device, dtype=self.dtype) + H, W = image.shape[-2:] + aspect_ratio = W / H + + if num_tokens is None: + lo, hi = self.num_tokens_range + num_tokens = int(lo + (resolution_level / 9) * (hi - lo)) + + out = self.model.forward(image, num_tokens=num_tokens) + points = out["points"].float() # recover_focal_shift goes through scipy on CPU; needs fp32. + mask_binary = out["mask"] > self.mask_threshold + normal = out.get("normal") + metric_scale = out.get("metric_scale") + + diag = (1 + aspect_ratio ** 2) ** 0.5 + + def focal_from_fov_deg(deg): + fov = torch.as_tensor(deg, device=points.device, dtype=points.dtype) + return aspect_ratio / diag / torch.tan(torch.deg2rad(fov / 2)) + + if fov_x is None: + focal, shift = recover_focal_shift(points, mask_binary) + # Fall back to 60 deg FoV when the least-squares solver flips the focal sign. + bad = ~torch.isfinite(focal) | (focal <= 0) + if bool(bad.any()): + focal = torch.where(bad, focal_from_fov_deg(60.0), focal) + _, shift = recover_focal_shift(points, mask_binary, focal=focal) + else: + focal = focal_from_fov_deg(fov_x).expand(points.shape[0]) + _, shift = recover_focal_shift(points, mask_binary, focal=focal) + + f_diag = focal / 2 * diag + half = torch.tensor(0.5, device=points.device, dtype=points.dtype) + intrinsics = intrinsics_from_focal_center(f_diag / aspect_ratio, f_diag, half, half) + points[..., 2] = points[..., 2] + shift[..., None, None] + # v2 only: filter mask by depth>0 to drop metric-scale negative-depth artifacts. + if self.version == "v2": + mask_binary = mask_binary & (points[..., 2] > 0) + depth = points[..., 2].clone() + + if force_projection: + points = depth_map_to_point_map(depth, intrinsics=intrinsics) + + if apply_metric_scale and metric_scale is not None: + points = points * metric_scale[:, None, None, None] + depth = depth * metric_scale[:, None, None] + + if apply_mask: + points = torch.where(mask_binary[..., None], points, torch.full_like(points, float("inf"))) + depth = torch.where(mask_binary, depth, torch.full_like(depth, float("inf"))) + if normal is not None: + normal = torch.where(mask_binary[..., None], normal, torch.zeros_like(normal)) + + result = {"points": points, "depth": depth, "intrinsics": intrinsics, "mask": mask_binary} + if normal is not None: + result["normal"] = normal + return result diff --git a/comfy/ldm/moge/modules.py b/comfy/ldm/moge/modules.py new file mode 100644 index 000000000..ff7b0878a --- /dev/null +++ b/comfy/ldm/moge/modules.py @@ -0,0 +1,204 @@ +"""Building blocks for MoGe: residual conv stack, resamplers, MLP, DINOv2 encoder, v1 head.""" + +from __future__ import annotations + +from typing import List, Optional, Sequence, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F + +import comfy.ops +from comfy.image_encoders.dino2 import Dinov2Model + +from .geometry import normalized_view_plane_uv + + +def _conv2d(operations, c_in: int, c_out: int, k: int = 3, *, dtype=None, device=None): + return operations.Conv2d(c_in, c_out, kernel_size=k, padding=k // 2, padding_mode="replicate", dtype=dtype, device=device) + + +def _view_plane_uv_grid(batch: int, height: int, width: int, aspect_ratio: float, dtype, device) -> torch.Tensor: + """Batched normalized view-plane UV grid as a (B, 2, H, W) tensor.""" + uv = normalized_view_plane_uv(width, height, aspect_ratio=aspect_ratio, dtype=dtype, device=device) + return uv.permute(2, 0, 1).unsqueeze(0).expand(batch, -1, -1, -1) + + +def _concat_view_plane_uv(x: torch.Tensor, aspect_ratio: float) -> torch.Tensor: + """Append a 2-channel normalized view-plane UV grid to x along the channel dim.""" + uv = _view_plane_uv_grid(x.shape[0], x.shape[-2], x.shape[-1], aspect_ratio, x.dtype, x.device) + return torch.cat([x, uv], dim=1) + + +class ResidualConvBlock(nn.Module): + def __init__(self, channels: int, hidden_channels: Optional[int] = None, in_norm: str = "layer_norm", hidden_norm: str = "group_norm", + dtype=None, device=None, operations=comfy.ops.manual_cast): + super().__init__() + hidden_channels = hidden_channels if hidden_channels is not None else channels + + in_norm_layer = operations.GroupNorm(1, channels) if in_norm == "layer_norm" else nn.Identity() + hidden_norm_layer = (operations.GroupNorm(max(hidden_channels // 32, 1), hidden_channels) + if hidden_norm == "group_norm" else nn.Identity()) + + self.layers = nn.Sequential( + in_norm_layer, nn.ReLU(), _conv2d(operations, channels, hidden_channels, dtype=dtype, device=device), + hidden_norm_layer, nn.ReLU(), _conv2d(operations, hidden_channels, channels, dtype=dtype, device=device), + ) + + def forward(self, x): + return self.layers(x) + x + + +class Resampler(nn.Sequential): + """2x upsampler: ConvTranspose2d(2x2) or bilinear upsample, followed by a 3x3 conv.""" + + def __init__(self, in_channels: int, out_channels: int, type_: str, dtype=None, device=None, operations=comfy.ops.manual_cast): + if type_ == "conv_transpose": + up = operations.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2, dtype=dtype, device=device) + conv_in = out_channels + else: # "bilinear" + up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False) + conv_in = in_channels + super().__init__(up, _conv2d(operations, conv_in, out_channels, dtype=dtype, device=device)) + + +class MLP(nn.Sequential): + def __init__(self, dims: Sequence[int], dtype=None, device=None, operations=comfy.ops.manual_cast): + layers = [] + for d_in, d_out in zip(dims[:-2], dims[1:-1]): + layers.append(operations.Linear(d_in, d_out, dtype=dtype, device=device)) + layers.append(nn.ReLU(inplace=True)) + layers.append(operations.Linear(dims[-2], dims[-1], dtype=dtype, device=device)) + super().__init__(*layers) + + +class ConvStack(nn.Module): + def __init__(self, dim_in: List[Optional[int]], dim_res_blocks: List[int], dim_out: List[Optional[int]], resamplers: List[str], + num_res_blocks: List[int], dim_times_res_block_hidden: int = 1, res_block_in_norm: str = "layer_norm", res_block_hidden_norm: str = "group_norm", + dtype=None, device=None, operations=comfy.ops.manual_cast): + super().__init__() + + self.input_blocks = nn.ModuleList([ + (_conv2d(operations, d_in, d_res, k=1, dtype=dtype, device=device) + if d_in is not None else nn.Identity()) + for d_in, d_res in zip(dim_in, dim_res_blocks) + ]) + + self.resamplers = nn.ModuleList([ + Resampler(prev, succ, type_=r, dtype=dtype, device=device, operations=operations) + for prev, succ, r in zip(dim_res_blocks[:-1], dim_res_blocks[1:], resamplers) + ]) + + self.res_blocks = nn.ModuleList([ + nn.Sequential(*[ + ResidualConvBlock(d_res, dim_times_res_block_hidden * d_res, in_norm=res_block_in_norm, hidden_norm=res_block_hidden_norm, dtype=dtype, device=device, operations=operations) + for _ in range(num_res_blocks[i]) + ]) + for i, d_res in enumerate(dim_res_blocks) + ]) + + self.output_blocks = nn.ModuleList([ + (_conv2d(operations, d_res, d_out, k=1, dtype=dtype, device=device) + if d_out is not None else nn.Identity()) + for d_out, d_res in zip(dim_out, dim_res_blocks) + ]) + + def forward(self, in_features: List[Optional[torch.Tensor]]): + out_features = [] + x = None + for i in range(len(self.res_blocks)): + feat = self.input_blocks[i](in_features[i]) if in_features[i] is not None else None + if i == 0: + x = feat + elif feat is not None: + x = x + feat + x = self.res_blocks[i](x) + out_features.append(self.output_blocks[i](x)) + if i < len(self.res_blocks) - 1: + x = self.resamplers[i](x) + return out_features + + +class DINOv2Encoder(nn.Module): + """Comfy DINOv2 backbone with per-layer 1x1 projection heads.""" + + def __init__(self, backbone: dict, intermediate_layers: List[int], dim_out: int, dtype=None, device=None, operations=comfy.ops.manual_cast): + super().__init__() + self.intermediate_layers = list(intermediate_layers) + dim_features = backbone["hidden_size"] + self.backbone = Dinov2Model(backbone, dtype, device, operations) + self.output_projections = nn.ModuleList([ + _conv2d(operations, dim_features, dim_out, k=1, dtype=dtype, device=device) + for _ in range(len(self.intermediate_layers)) + ]) + self.register_buffer("image_mean", torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)) + self.register_buffer("image_std", torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)) + + def forward(self, image: torch.Tensor, token_rows: int, token_cols: int, + return_class_token: bool = False) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + image_14 = F.interpolate(image, (token_rows * 14, token_cols * 14), mode="bilinear", align_corners=False, antialias=True) + image_14 = (image_14 - self.image_mean) / self.image_std + feats = self.backbone.get_intermediate_layers(image_14, self.intermediate_layers, apply_norm=True) + x = torch.stack([ + proj(feat.permute(0, 2, 1).unflatten(2, (token_rows, token_cols)).contiguous()) + for proj, (feat, _cls) in zip(self.output_projections, feats) + ], dim=1).sum(dim=1) + if return_class_token: + return x, feats[-1][1] + return x + + +class HeadV1(nn.Module): + """v1 head: 4 backbone-feature projections -> shared upsample stack -> per-target output convs (points, mask).""" + + NUM_FEATURES = 4 + DIM_PROJ = 512 + DIM_OUT = (3, 1) # 3 channels for points, 1 for mask + LAST_CONV_CHANNELS = 32 + + def __init__(self, dim_in: int, dim_upsample: List[int] = (256, 128, 128), num_res_blocks: int = 1, dim_times_res_block_hidden: int = 1, + dtype=None, device=None, operations=comfy.ops.manual_cast): + super().__init__() + self.projects = nn.ModuleList([ + _conv2d(operations, dim_in, self.DIM_PROJ, k=1, dtype=dtype, device=device) + for _ in range(self.NUM_FEATURES) + ]) + def upsampler(in_ch, out_ch): + return nn.Sequential( + operations.ConvTranspose2d(in_ch, out_ch, kernel_size=2, stride=2, dtype=dtype, device=device), + _conv2d(operations, out_ch, out_ch, dtype=dtype, device=device), + ) + + in_chs = [self.DIM_PROJ] + list(dim_upsample[:-1]) + self.upsample_blocks = nn.ModuleList([ + nn.Sequential( + upsampler(in_ch + 2, out_ch), + *(ResidualConvBlock(out_ch, dim_times_res_block_hidden * out_ch, dtype=dtype, device=device, operations=operations) + for _ in range(num_res_blocks)) + ) + for in_ch, out_ch in zip(in_chs, dim_upsample) + ]) + self.output_block = nn.ModuleList([ + nn.Sequential( + _conv2d(operations, dim_upsample[-1] + 2, self.LAST_CONV_CHANNELS, dtype=dtype, device=device), + nn.ReLU(inplace=True), + _conv2d(operations, self.LAST_CONV_CHANNELS, d_out, k=1, dtype=dtype, device=device), + ) + for d_out in self.DIM_OUT + ]) + + def forward(self, hidden_states, image: torch.Tensor): + img_h, img_w = image.shape[-2:] + patch_h, patch_w = img_h // 14, img_w // 14 + aspect = img_w / img_h + x = torch.stack([ + proj(feat.permute(0, 2, 1).unflatten(2, (patch_h, patch_w)).contiguous()) + for proj, (feat, _cls) in zip(self.projects, hidden_states) + ], dim=1).sum(dim=1) + + for block in self.upsample_blocks: + x = block(_concat_view_plane_uv(x, aspect)) + + x = F.interpolate(x, (img_h, img_w), mode="bilinear", align_corners=False) + x = _concat_view_plane_uv(x, aspect) + return [block(x) for block in self.output_block] diff --git a/comfy/ldm/moge/panorama.py b/comfy/ldm/moge/panorama.py new file mode 100644 index 000000000..76b2a4daf --- /dev/null +++ b/comfy/ldm/moge/panorama.py @@ -0,0 +1,315 @@ +"""Panorama (equirectangular) inference helpers for MoGe. + +Splits an equirect into 12 perspective views via an icosahedron camera rig, runs +the model per view, and stitches per-view distance maps back into a single +equirect distance map via a multi-scale Poisson + gradient sparse solve. +Image sampling uses ``F.grid_sample`` (GPU); the sparse solve uses ``lsmr`` (CPU). +""" + +from __future__ import annotations + +from typing import Callable, List, Optional, Tuple + +import numpy as np +import torch +import torch.nn.functional as F + + +def _icosahedron_directions() -> np.ndarray: + """12 icosahedron-vertex directions (non-normalised, matching upstream's vertex order).""" + A = (1.0 + np.sqrt(5.0)) / 2.0 + return np.array([ + [0, 1, A], [0, -1, A], [0, 1, -A], [0, -1, -A], + [1, A, 0], [-1, A, 0], [1, -A, 0], [-1, -A, 0], + [A, 0, 1], [A, 0, -1], [-A, 0, 1], [-A, 0, -1], + ], dtype=np.float32) + + +def _intrinsics_from_fov(fov_x_rad: float, fov_y_rad: float) -> np.ndarray: + """Normalised-image (unit-square) K matrix.""" + fx = 0.5 / np.tan(fov_x_rad / 2) + fy = 0.5 / np.tan(fov_y_rad / 2) + return np.array([[fx, 0, 0.5], [0, fy, 0.5], [0, 0, 1]], dtype=np.float32) + + +def _extrinsics_look_at(eye: np.ndarray, target: np.ndarray, up: np.ndarray) -> np.ndarray: + """OpenCV-convention world->camera extrinsics for an array of look-at targets (N, 4, 4).""" + eye = np.asarray(eye, dtype=np.float32) + target = np.asarray(target, dtype=np.float32) + up = np.asarray(up, dtype=np.float32) + if target.ndim == 1: + target = target[None] + + fwd = target - eye + fwd = fwd / np.linalg.norm(fwd, axis=-1, keepdims=True).clip(1e-12) + right = np.cross(fwd, up) + right_norm = np.linalg.norm(right, axis=-1, keepdims=True) + # Fall back to an arbitrary perpendicular if forward is parallel to up. + parallel = right_norm.squeeze(-1) < 1e-6 + if parallel.any(): + alt_up = np.array([1, 0, 0], dtype=np.float32) + right = np.where(parallel[:, None], np.cross(fwd, alt_up), right) + right_norm = np.linalg.norm(right, axis=-1, keepdims=True) + right = right / right_norm.clip(1e-12) + new_up = np.cross(fwd, right) + + R = np.stack([right, new_up, fwd], axis=-2) + t = -np.einsum("nij,j->ni", R, eye) + E = np.zeros((R.shape[0], 4, 4), dtype=np.float32) + E[:, :3, :3] = R + E[:, :3, 3] = t + E[:, 3, 3] = 1.0 + return E + + +def get_panorama_cameras() -> Tuple[np.ndarray, List[np.ndarray]]: + """Returns (extrinsics (12, 4, 4), [intrinsics] * 12) for icosahedron views at 90 deg FoV.""" + targets = _icosahedron_directions() + eye = np.zeros(3, dtype=np.float32) + up = np.array([0, 0, 1], dtype=np.float32) + extrinsics = _extrinsics_look_at(eye, targets, up) + K = _intrinsics_from_fov(np.deg2rad(90.0), np.deg2rad(90.0)) + return extrinsics, [K] * len(targets) + + +def spherical_uv_to_directions(uv: np.ndarray) -> np.ndarray: + """Equirect UV in [0, 1] -> 3D unit-direction (Z up).""" + theta = (1 - uv[..., 0]) * (2 * np.pi) + phi = uv[..., 1] * np.pi + return np.stack([ + np.sin(phi) * np.cos(theta), + np.sin(phi) * np.sin(theta), + np.cos(phi), + ], axis=-1).astype(np.float32) + + +def directions_to_spherical_uv(directions: np.ndarray) -> np.ndarray: + """3D direction -> equirect UV in [0, 1].""" + n = np.linalg.norm(directions, axis=-1, keepdims=True).clip(1e-12) + d = directions / n + u = 1 - np.arctan2(d[..., 1], d[..., 0]) / (2 * np.pi) % 1.0 + v = np.arccos(d[..., 2].clip(-1, 1)) / np.pi + return np.stack([u, v], axis=-1).astype(np.float32) + + +def _uv_grid(H: int, W: int) -> np.ndarray: + """Pixel-center UV grid in [0, 1]; (H, W, 2).""" + u = (np.arange(W, dtype=np.float32) + 0.5) / W + v = (np.arange(H, dtype=np.float32) + 0.5) / H + return np.stack(np.meshgrid(u, v, indexing="xy"), axis=-1) + + +def _unproject_cv(uv: np.ndarray, depth: np.ndarray, + extrinsics: np.ndarray, intrinsics: np.ndarray) -> np.ndarray: + """Back-project pixels into world coords (OpenCV convention).""" + pix = np.concatenate([uv, np.ones_like(uv[..., :1])], axis=-1) + K_inv = np.linalg.inv(intrinsics) + cam = pix @ K_inv.T * depth[..., None] + cam_h = np.concatenate([cam, np.ones_like(cam[..., :1])], axis=-1) + E_inv = np.linalg.inv(extrinsics) + return (cam_h @ E_inv.T)[..., :3] + + +def _project_cv(points: np.ndarray, extrinsics: np.ndarray, intrinsics: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: + """World coords -> (uv, depth) in the camera (OpenCV convention).""" + pts_h = np.concatenate([points, np.ones_like(points[..., :1])], axis=-1) + cam = pts_h @ extrinsics.T + cam_xyz = cam[..., :3] + depth = cam_xyz[..., 2] + proj = cam_xyz @ intrinsics.T + uv = proj[..., :2] / proj[..., 2:3].clip(1e-12) + return uv.astype(np.float32), depth.astype(np.float32) + + +def _grid_sample_uv(img_bchw: torch.Tensor, uv: torch.Tensor, mode: str = "bilinear") -> torch.Tensor: + """Sample img_bchw at UV-in-[0,1] coords ``uv`` of shape (B, H, W, 2); replicate-border.""" + grid = uv * 2.0 - 1.0 + return F.grid_sample(img_bchw, grid, mode=mode, padding_mode="border", align_corners=False) + + +def split_panorama_image(image: torch.Tensor, extrinsics: np.ndarray, intrinsics: List[np.ndarray], resolution: int) -> torch.Tensor: + """(3, Hp, Wp) equirect on any device -> (N, 3, R, R) perspective crops on the same device.""" + device = image.device + N = len(extrinsics) + uv = _uv_grid(resolution, resolution) + sample_uvs = [] + for i in range(N): + world = _unproject_cv(uv, np.ones(uv.shape[:-1], dtype=np.float32), extrinsics[i], intrinsics[i]) + sample_uvs.append(directions_to_spherical_uv(world)) + sample_uvs = np.stack(sample_uvs, axis=0) + + img_bchw = image.unsqueeze(0).expand(N, -1, -1, -1).contiguous() + sample_uvs_t = torch.from_numpy(sample_uvs).to(device=device, dtype=image.dtype) + return _grid_sample_uv(img_bchw, sample_uvs_t, mode="bilinear") + + +def _poisson_equation(W: int, H: int, wrap_x: bool = False, wrap_y: bool = False): + """Sparse Laplacian operator over the H x W grid.""" + from scipy.sparse import csr_array + grid_index = np.arange(H * W).reshape(H, W) + grid_index = np.pad(grid_index, ((0, 0), (1, 1)), mode="wrap" if wrap_x else "edge") + grid_index = np.pad(grid_index, ((1, 1), (0, 0)), mode="wrap" if wrap_y else "edge") + + data = np.array([[-4, 1, 1, 1, 1]], dtype=np.float32).repeat(H * W, axis=0).reshape(-1) + indices = np.stack([ + grid_index[1:-1, 1:-1], + grid_index[:-2, 1:-1], grid_index[2:, 1:-1], + grid_index[1:-1, :-2], grid_index[1:-1, 2:], + ], axis=-1).reshape(-1) + indptr = np.arange(0, H * W * 5 + 1, 5) + return csr_array((data, indices, indptr), shape=(H * W, H * W)) + + +def _grad_equation(W: int, H: int, wrap_x: bool = False, wrap_y: bool = False): + """Sparse forward-difference operator over the H x W grid.""" + from scipy.sparse import csr_array + grid_index = np.arange(W * H).reshape(H, W) + if wrap_x: + grid_index = np.pad(grid_index, ((0, 0), (0, 1)), mode="wrap") + if wrap_y: + grid_index = np.pad(grid_index, ((0, 1), (0, 0)), mode="wrap") + + data = np.concatenate([ + np.concatenate([ + np.ones((grid_index.shape[0], grid_index.shape[1] - 1), dtype=np.float32).reshape(-1, 1), + -np.ones((grid_index.shape[0], grid_index.shape[1] - 1), dtype=np.float32).reshape(-1, 1), + ], axis=1).reshape(-1), + np.concatenate([ + np.ones((grid_index.shape[0] - 1, grid_index.shape[1]), dtype=np.float32).reshape(-1, 1), + -np.ones((grid_index.shape[0] - 1, grid_index.shape[1]), dtype=np.float32).reshape(-1, 1), + ], axis=1).reshape(-1), + ]) + indices = np.concatenate([ + np.concatenate([grid_index[:, :-1].reshape(-1, 1), grid_index[:, 1:].reshape(-1, 1)], axis=1).reshape(-1), + np.concatenate([grid_index[:-1, :].reshape(-1, 1), grid_index[1:, :].reshape(-1, 1)], axis=1).reshape(-1), + ]) + nx = grid_index.shape[0] * (grid_index.shape[1] - 1) + ny = (grid_index.shape[0] - 1) * grid_index.shape[1] + indptr = np.arange(0, nx * 2 + ny * 2 + 1, 2) + return csr_array((data, indices, indptr), shape=(nx + ny, H * W)) + + +def _scipy_remap_bilinear(img: np.ndarray, sample_pixels: np.ndarray, mode: str = "bilinear") -> np.ndarray: + """Bilinear/nearest sampling at fractional pixel coords; out-of-range clamps to nearest border.""" + from scipy.ndimage import map_coordinates + H, W = img.shape[:2] + yy = np.clip(sample_pixels[..., 1], 0, H - 1) + xx = np.clip(sample_pixels[..., 0], 0, W - 1) + order = 1 if mode == "bilinear" else 0 + if img.ndim == 2: + return map_coordinates(img, [yy, xx], order=order, mode="nearest").astype(img.dtype) + out = np.stack([ + map_coordinates(img[..., c], [yy, xx], order=order, mode="nearest") + for c in range(img.shape[-1]) + ], axis=-1) + return out.astype(img.dtype) + + +def merge_panorama_depth(width: int, height: int, + distance_maps: List[np.ndarray], pred_masks: List[np.ndarray], + extrinsics: List[np.ndarray], intrinsics: List[np.ndarray], + on_view: Optional[Callable[[], None]] = None, + on_solve_start: Optional[Callable[[int, int], None]] = None, + on_solve_end: Optional[Callable[[int, int], None]] = None, + ) -> Tuple[np.ndarray, np.ndarray]: + """Stitch per-view distance maps into a single equirect distance map. + + Recursive multi-scale solve: solves at half resolution first and uses that as the lsmr init + for the full-resolution solve. Optional callbacks fire per view processed and around each + lsmr solve so callers can drive a progress bar. + """ + from scipy.ndimage import convolve + from scipy.sparse import vstack + from scipy.sparse.linalg import lsmr + + if max(width, height) > 256: + coarse_depth, _ = merge_panorama_depth(width // 2, height // 2, + distance_maps, pred_masks, extrinsics, intrinsics, + on_view=on_view, + on_solve_start=on_solve_start, + on_solve_end=on_solve_end) + t = torch.from_numpy(coarse_depth).unsqueeze(0).unsqueeze(0) + t = F.interpolate(t, size=(height, width), mode="bilinear", align_corners=False) + depth_init = t.squeeze().numpy().astype(np.float32) + else: + depth_init = None + + spherical_directions = spherical_uv_to_directions(_uv_grid(height, width)) + + pano_log_grad_maps, pano_grad_masks = [], [] + pano_log_lap_maps, pano_lap_masks = [], [] + pano_pred_masks: List[np.ndarray] = [] + + for i in range(len(distance_maps)): + proj_uv, proj_depth = _project_cv(spherical_directions, extrinsics[i], intrinsics[i]) + proj_valid = (proj_depth > 0) & (proj_uv > 0).all(axis=-1) & (proj_uv < 1).all(axis=-1) + + Hd, Wd = distance_maps[i].shape[:2] + proj_pixels = np.clip(proj_uv, 0, 1) * np.array([Wd - 1, Hd - 1], dtype=np.float32) + + log_dist = np.log(np.clip(distance_maps[i], 1e-6, None)) + sampled = _scipy_remap_bilinear(log_dist, proj_pixels, mode="bilinear") + pano_log = np.where(proj_valid, sampled, 0.0).astype(np.float32) + + sampled_mask = _scipy_remap_bilinear(pred_masks[i].astype(np.uint8), proj_pixels, mode="nearest") + pano_pred = proj_valid & (sampled_mask > 0) + + # Equirect wraps horizontally but not vertically: wrap pad along x, edge pad along y. + padded = np.pad(pano_log, ((0, 0), (0, 1)), mode="wrap") + gx, gy = padded[:, :-1] - padded[:, 1:], padded[:-1, :] - padded[1:, :] + padded_m = np.pad(pano_pred, ((0, 0), (0, 1)), mode="wrap") + mx, my = padded_m[:, :-1] & padded_m[:, 1:], padded_m[:-1, :] & padded_m[1:, :] + pano_log_grad_maps.append((gx, gy)) + pano_grad_masks.append((mx, my)) + + padded = np.pad(pano_log, ((1, 1), (0, 0)), mode="edge") + padded = np.pad(padded, ((0, 0), (1, 1)), mode="wrap") + lap_kernel = np.array([[0, 1, 0], [1, -4, 1], [0, 1, 0]], dtype=np.float32) + lap = convolve(padded, lap_kernel)[1:-1, 1:-1] + padded_m = np.pad(pano_pred, ((1, 1), (0, 0)), mode="edge") + padded_m = np.pad(padded_m, ((0, 0), (1, 1)), mode="wrap") + m_kernel = np.array([[0, 1, 0], [1, 1, 1], [0, 1, 0]], dtype=np.uint8) + lap_mask = convolve(padded_m.astype(np.uint8), m_kernel)[1:-1, 1:-1] == 5 + pano_log_lap_maps.append(lap) + pano_lap_masks.append(lap_mask) + pano_pred_masks.append(pano_pred) + + if on_view is not None: + on_view() + + gx = np.stack([m[0] for m in pano_log_grad_maps], axis=0) + gy = np.stack([m[1] for m in pano_log_grad_maps], axis=0) + mx = np.stack([m[0] for m in pano_grad_masks], axis=0) + my = np.stack([m[1] for m in pano_grad_masks], axis=0) + gx_avg = (gx * mx).sum(axis=0) / mx.sum(axis=0).clip(1e-3) + gy_avg = (gy * my).sum(axis=0) / my.sum(axis=0).clip(1e-3) + + laps = np.stack(pano_log_lap_maps, axis=0) + lap_masks = np.stack(pano_lap_masks, axis=0) + lap_avg = (laps * lap_masks).sum(axis=0) / lap_masks.sum(axis=0).clip(1e-3) + + grad_x_mask = mx.any(axis=0).reshape(-1) + grad_y_mask = my.any(axis=0).reshape(-1) + grad_mask = np.concatenate([grad_x_mask, grad_y_mask]) + lap_mask_flat = lap_masks.any(axis=0).reshape(-1) + + A = vstack([ + _grad_equation(width, height, wrap_x=True, wrap_y=False)[grad_mask], + _poisson_equation(width, height, wrap_x=True, wrap_y=False)[lap_mask_flat], + ]) + b = np.concatenate([ + gx_avg.reshape(-1)[grad_x_mask], + gy_avg.reshape(-1)[grad_y_mask], + lap_avg.reshape(-1)[lap_mask_flat], + ]) + x0 = np.log(np.clip(depth_init, 1e-6, None)).reshape(-1) if depth_init is not None else None + + if on_solve_start is not None: + on_solve_start(width, height) + x, *_ = lsmr(A, b, atol=1e-5, btol=1e-5, x0=x0, show=False) + if on_solve_end is not None: + on_solve_end(width, height) + + pano_depth = np.exp(x).reshape(height, width).astype(np.float32) + pano_mask = np.any(pano_pred_masks, axis=0) + return pano_depth, pano_mask diff --git a/comfy/ldm/moge/state_dict.py b/comfy/ldm/moge/state_dict.py new file mode 100644 index 000000000..a241d720e --- /dev/null +++ b/comfy/ldm/moge/state_dict.py @@ -0,0 +1,94 @@ +"""Translate MoGe checkpoint keys to the layouts our nn.Modules use. + +MoGe checkpoints embed DINOv2 with the original Meta naming +(``backbone.blocks.{i}.attn.qkv.weight``, ``ls1.gamma``, ``mlp.w12``, ...). +The shared ``comfy.image_encoders.dino2.Dinov2Model`` uses HF naming +(``encoder.layer.{i}.attention.attention.{query,key,value}.weight``, +``layer_scale1.lambda1``, ``mlp.weights_in``, ...). We rewrite keys at load +time and split the fused ``qkv`` weight into separate Q/K/V tensors. +""" + +from __future__ import annotations + +import re + + +_DINOV2_TOPLEVEL_RENAMES = { + "patch_embed.proj.weight": "embeddings.patch_embeddings.projection.weight", + "patch_embed.proj.bias": "embeddings.patch_embeddings.projection.bias", + "cls_token": "embeddings.cls_token", + "pos_embed": "embeddings.position_embeddings", + "register_tokens": "embeddings.register_tokens", + "mask_token": "embeddings.mask_token", + "norm.weight": "layernorm.weight", + "norm.bias": "layernorm.bias", +} + +_BLOCK_SUFFIX_RENAMES = [ + ("ls1.gamma", "layer_scale1.lambda1"), + ("ls2.gamma", "layer_scale2.lambda1"), + ("attn.proj.", "attention.output.dense."), + ("mlp.w12.", "mlp.weights_in."), + ("mlp.w3.", "mlp.weights_out."), +] + +_BLOCK_RE = re.compile(r"^blocks\.(\d+)\.(.+)$") + + +def remap_dinov2_keys(sd: dict, src_prefix: str = "") -> dict: + """Rewrite Meta-style DINOv2 keys under ``src_prefix`` to comfy/HF naming. + + Splits each fused ``attn.qkv.{weight,bias}`` into separate + ``attention.attention.{query,key,value}.{weight,bias}`` tensors using a + chunk along the leading dim. + + Keys that do not start with ``src_prefix`` are returned unchanged. + """ + out: dict = {} + for k, v in sd.items(): + if not k.startswith(src_prefix): + out[k] = v + continue + rel = k[len(src_prefix):] + + # Top-level (cls token, pos embed, patch embed, mask token, register tokens, final norm). + if rel in _DINOV2_TOPLEVEL_RENAMES: + out[src_prefix + _DINOV2_TOPLEVEL_RENAMES[rel]] = v + continue + + m = _BLOCK_RE.match(rel) + if not m: + out[k] = v + continue + + i, sub = m.group(1), m.group(2) + + # Split fused qkv into separate q / k / v tensors. + if sub == "attn.qkv.weight" or sub == "attn.qkv.bias": + q, kw, vw = v.chunk(3, dim=0) + tail = sub.rsplit(".", 1)[1] # weight / bias + base = "{}encoder.layer.{}.attention.attention".format(src_prefix, i) + out["{}.query.{}".format(base, tail)] = q + out["{}.key.{}".format(base, tail)] = kw + out["{}.value.{}".format(base, tail)] = vw + continue + + for old, new in _BLOCK_SUFFIX_RENAMES: + sub = sub.replace(old, new) + out["{}encoder.layer.{}.{}".format(src_prefix, i, sub)] = v + + return out + + +def remap_moge_state_dict(sd: dict) -> dict: + """Convert a full MoGe checkpoint state dict to the layout our modules expect. + + - v1 backbone lives under ``backbone.`` -> rewrite that subtree. + - v2 backbone lives under ``encoder.backbone.`` -> rewrite that subtree. + + Everything else (heads, neck, projections, image_mean/std buffers) keeps + its original key names and passes through unchanged. + """ + if any(k.startswith("encoder.backbone.") for k in sd): + return remap_dinov2_keys(sd, src_prefix="encoder.backbone.") + return remap_dinov2_keys(sd, src_prefix="backbone.") diff --git a/comfy/moge.py b/comfy/moge.py new file mode 100644 index 000000000..a04d9f149 --- /dev/null +++ b/comfy/moge.py @@ -0,0 +1,163 @@ +"""High-level loader and inference wrapper for MoGe v1 / v2 checkpoints. + +Mirrors the structure of :mod:`comfy.clip_vision`: owns the ``nn.Module`` and +a :class:`comfy.model_patcher.CoreModelPatcher`, exposes a +:meth:`MoGeModel.infer` that runs preprocessing, forward, and post-processing. +""" + +from __future__ import annotations + +from numbers import Number +from typing import Dict, Optional, Union + +import torch + +import comfy.model_management +import comfy.model_patcher +import comfy.ops +import comfy.utils + +from .ldm.moge.geometry import ( + depth_map_to_point_map, + intrinsics_from_focal_center, + recover_focal_shift, +) +from .ldm.moge.model import detect_and_build +from .ldm.moge.state_dict import remap_moge_state_dict + + +class MoGeModel: + """Loaded MoGe model + ComfyUI memory management.""" + + def __init__(self, state_dict: dict): + self.load_device = comfy.model_management.text_encoder_device() + offload_device = comfy.model_management.text_encoder_offload_device() + self.dtype = comfy.model_management.text_encoder_dtype(self.load_device) + + sd = remap_moge_state_dict(state_dict) + self.model = detect_and_build(sd, dtype=self.dtype, device=offload_device, + operations=comfy.ops.manual_cast) + self.model.load_state_dict(sd, strict=True) + self.model.eval() + self.patcher = comfy.model_patcher.CoreModelPatcher( + self.model, load_device=self.load_device, offload_device=offload_device + ) + self.version = "v2" if hasattr(self.model, "encoder") else "v1" + self.mask_threshold = float(getattr(self.model, "mask_threshold", 0.5)) + nt = getattr(self.model, "num_tokens_range", (1200, 2500 if self.version == "v1" else 3600)) + self.num_tokens_range = (int(nt[0]), int(nt[1])) + + @torch.inference_mode() + def infer(self, + image: torch.Tensor, + num_tokens: Optional[int] = None, + resolution_level: int = 9, + fov_x: Optional[Union[Number, torch.Tensor]] = None, + force_projection: bool = True, + apply_mask: bool = True) -> Dict[str, torch.Tensor]: + """Run a single MoGe forward + post-process pass. + + ``image`` must already be ``(B, 3, H, W)`` in ``[0, 1]`` on any device. + Returns a dict with at least ``points``, ``depth``, ``intrinsics``, + ``mask``; v2 checkpoints additionally produce ``normal``. + """ + comfy.model_management.load_model_gpu(self.patcher) + device = self.load_device + image = image.to(device=device, dtype=self.dtype) + + if image.dim() == 3: + image = image.unsqueeze(0) + H, W = image.shape[-2:] + aspect_ratio = W / H + + if num_tokens is None: + lo, hi = self.num_tokens_range + num_tokens = int(lo + (resolution_level / 9) * (hi - lo)) + + out = self.model.forward(image, num_tokens=num_tokens) + points = out.get("points") + normal = out.get("normal") + mask = out.get("mask") + metric_scale = out.get("metric_scale") + + # Post-processing always runs in fp32 for numerical stability. + if points is not None: points = points.float() + if normal is not None: normal = normal.float() + if mask is not None: mask = mask.float() + if metric_scale is not None: metric_scale = metric_scale.float() + + mask_binary = (mask > self.mask_threshold) if mask is not None else None + + depth = None + intrinsics = None + if points is not None: + if fov_x is None: + focal, shift = recover_focal_shift(points, mask_binary) + # The unconstrained least-squares solver inside recover_focal_shift + # can converge to a degenerate solution where (z + shift) is + # negative for most pixels, which flips the sign of the + # estimated focal. Detect that and fall back to a sensible + # 60-degree-FoV default rather than emitting garbage geometry. + bad = ~torch.isfinite(focal) | (focal <= 0) + if bool(bad.any()): + default_fov = 60.0 + fov_t = torch.as_tensor(default_fov, device=points.device, dtype=points.dtype) + fallback_focal = aspect_ratio / (1 + aspect_ratio ** 2) ** 0.5 \ + / torch.tan(torch.deg2rad(fov_t / 2)) + fallback_focal = fallback_focal.expand_as(focal).clone() + focal = torch.where(bad, fallback_focal, focal) + _, shift = recover_focal_shift(points, mask_binary, focal=focal) + else: + fov_t = torch.as_tensor(fov_x, device=points.device, dtype=points.dtype) + focal = aspect_ratio / (1 + aspect_ratio ** 2) ** 0.5 / torch.tan(torch.deg2rad(fov_t / 2)) + if focal.ndim == 0: + focal = focal[None].expand(points.shape[0]) + _, shift = recover_focal_shift(points, mask_binary, focal=focal) + fx = focal / 2 * (1 + aspect_ratio ** 2) ** 0.5 / aspect_ratio + fy = focal / 2 * (1 + aspect_ratio ** 2) ** 0.5 + half = torch.tensor(0.5, device=points.device, dtype=points.dtype) + intrinsics = intrinsics_from_focal_center(fx, fy, half, half) + points[..., 2] = points[..., 2] + shift[..., None, None] + # v2 upstream additionally filters mask by depth > 0 as a safeguard + # against negative-depth artifacts from the metric-scale path; v1 + # does not, and applying it there can cut out the foreground when + # shift recovery overshoots slightly. + if mask_binary is not None and self.version == "v2": + mask_binary = mask_binary & (points[..., 2] > 0) + depth = points[..., 2].clone() + + if force_projection and depth is not None and intrinsics is not None: + points = depth_map_to_point_map(depth, intrinsics=intrinsics) + + if metric_scale is not None: + if points is not None: + points = points * metric_scale[:, None, None, None] + if depth is not None: + depth = depth * metric_scale[:, None, None] + + if apply_mask and mask_binary is not None: + if points is not None: + points = torch.where(mask_binary[..., None], points, + torch.full_like(points, float("inf"))) + if depth is not None: + depth = torch.where(mask_binary, depth, + torch.full_like(depth, float("inf"))) + if normal is not None: + normal = torch.where(mask_binary[..., None], normal, torch.zeros_like(normal)) + + result = { + "points": points, + "depth": depth, + "intrinsics": intrinsics, + "mask": mask_binary, + "normal": normal, + } + return {k: v for k, v in result.items() if v is not None} + + +def load(ckpt_path: str) -> MoGeModel: + """Load a MoGe ``.pt`` / ``.safetensors`` checkpoint into a :class:`MoGeModel`.""" + sd = comfy.utils.load_torch_file(ckpt_path, safe_load=True) + if isinstance(sd, dict) and "model" in sd and "model_config" in sd: + sd = sd["model"] + return MoGeModel(sd) diff --git a/comfy_api/latest/_util/geometry_types.py b/comfy_api/latest/_util/geometry_types.py index b586fceb3..e1b58ff72 100644 --- a/comfy_api/latest/_util/geometry_types.py +++ b/comfy_api/latest/_util/geometry_types.py @@ -12,9 +12,19 @@ class VOXEL: class MESH: - def __init__(self, vertices: torch.Tensor, faces: torch.Tensor): + def __init__(self, vertices: torch.Tensor, faces: torch.Tensor, + uvs: torch.Tensor | None = None, + vertex_colors: torch.Tensor | None = None, + texture: torch.Tensor | None = None): + # vertices: (B, N, 3), faces: (B, M, 3). Optional fields: + # - uvs: (B, N, 2) per-vertex texture coordinates. + # - vertex_colors: (B, N, 3 or 4) per-vertex colors in [0, 1]. + # - texture: (B, H, W, 3) baseColor texture image in [0, 1] (comfy IMAGE format). self.vertices = vertices self.faces = faces + self.uvs = uvs + self.vertex_colors = vertex_colors + self.texture = texture class File3D: diff --git a/comfy_extras/nodes_hunyuan3d.py b/comfy_extras/nodes_hunyuan3d.py index bf18ecb88..3a08a828a 100644 --- a/comfy_extras/nodes_hunyuan3d.py +++ b/comfy_extras/nodes_hunyuan3d.py @@ -1,12 +1,8 @@ import torch -import os -import json -import struct import numpy as np from comfy.ldm.modules.diffusionmodules.mmdit import get_1d_sincos_pos_embed_from_grid_torch -import folder_paths import comfy.model_management -from comfy.cli_args import args +from comfy_extras.nodes_save_3d import pack_variable_mesh_batch from typing_extensions import override from comfy_api.latest import ComfyExtension, IO, Types from comfy_api.latest._util import MESH, VOXEL # only for backward compatibility if someone import it from this file (will be removed later) # noqa @@ -444,7 +440,9 @@ class VoxelToMeshBasic(IO.ComfyNode): vertices.append(v) faces.append(f) - return IO.NodeOutput(Types.MESH(torch.stack(vertices), torch.stack(faces))) + if vertices and all(v.shape == vertices[0].shape for v in vertices) and all(f.shape == faces[0].shape for f in faces): + return IO.NodeOutput(Types.MESH(torch.stack(vertices), torch.stack(faces))) + return IO.NodeOutput(pack_variable_mesh_batch(vertices, faces)) decode = execute # TODO: remove @@ -481,206 +479,13 @@ class VoxelToMesh(IO.ComfyNode): vertices.append(v) faces.append(f) - return IO.NodeOutput(Types.MESH(torch.stack(vertices), torch.stack(faces))) + if vertices and all(v.shape == vertices[0].shape for v in vertices) and all(f.shape == faces[0].shape for f in faces): + return IO.NodeOutput(Types.MESH(torch.stack(vertices), torch.stack(faces))) + return IO.NodeOutput(pack_variable_mesh_batch(vertices, faces)) decode = execute # TODO: remove -def save_glb(vertices, faces, filepath, metadata=None): - """ - Save PyTorch tensor vertices and faces as a GLB file without external dependencies. - - Parameters: - vertices: torch.Tensor of shape (N, 3) - The vertex coordinates - faces: torch.Tensor of shape (M, 3) - The face indices (triangle faces) - filepath: str - Output filepath (should end with .glb) - """ - - # Convert tensors to numpy arrays - vertices_np = vertices.cpu().numpy().astype(np.float32) - faces_np = faces.cpu().numpy().astype(np.uint32) - - vertices_buffer = vertices_np.tobytes() - indices_buffer = faces_np.tobytes() - - def pad_to_4_bytes(buffer): - padding_length = (4 - (len(buffer) % 4)) % 4 - return buffer + b'\x00' * padding_length - - vertices_buffer_padded = pad_to_4_bytes(vertices_buffer) - indices_buffer_padded = pad_to_4_bytes(indices_buffer) - - buffer_data = vertices_buffer_padded + indices_buffer_padded - - vertices_byte_length = len(vertices_buffer) - vertices_byte_offset = 0 - indices_byte_length = len(indices_buffer) - indices_byte_offset = len(vertices_buffer_padded) - - gltf = { - "asset": {"version": "2.0", "generator": "ComfyUI"}, - "buffers": [ - { - "byteLength": len(buffer_data) - } - ], - "bufferViews": [ - { - "buffer": 0, - "byteOffset": vertices_byte_offset, - "byteLength": vertices_byte_length, - "target": 34962 # ARRAY_BUFFER - }, - { - "buffer": 0, - "byteOffset": indices_byte_offset, - "byteLength": indices_byte_length, - "target": 34963 # ELEMENT_ARRAY_BUFFER - } - ], - "accessors": [ - { - "bufferView": 0, - "byteOffset": 0, - "componentType": 5126, # FLOAT - "count": len(vertices_np), - "type": "VEC3", - "max": vertices_np.max(axis=0).tolist(), - "min": vertices_np.min(axis=0).tolist() - }, - { - "bufferView": 1, - "byteOffset": 0, - "componentType": 5125, # UNSIGNED_INT - "count": faces_np.size, - "type": "SCALAR" - } - ], - "meshes": [ - { - "primitives": [ - { - "attributes": { - "POSITION": 0 - }, - "indices": 1, - "mode": 4 # TRIANGLES - } - ] - } - ], - "nodes": [ - { - "mesh": 0 - } - ], - "scenes": [ - { - "nodes": [0] - } - ], - "scene": 0 - } - - if metadata is not None: - gltf["asset"]["extras"] = metadata - - # Convert the JSON to bytes - gltf_json = json.dumps(gltf).encode('utf8') - - def pad_json_to_4_bytes(buffer): - padding_length = (4 - (len(buffer) % 4)) % 4 - return buffer + b' ' * padding_length - - gltf_json_padded = pad_json_to_4_bytes(gltf_json) - - # Create the GLB header - # Magic glTF - glb_header = struct.pack('<4sII', b'glTF', 2, 12 + 8 + len(gltf_json_padded) + 8 + len(buffer_data)) - - # Create JSON chunk header (chunk type 0) - json_chunk_header = struct.pack('