From 10837132a4754cab41b6bbfd9077f978e2ac4203 Mon Sep 17 00:00:00 2001
From: kijai <40791699+kijai@users.noreply.github.com>
Date: Wed, 13 May 2026 22:06:51 +0300
Subject: [PATCH] Some cleanup

---
Reviewer notes (text between the "---" line and the first "diff --git" is
commentary and is not part of the commit message or the applied change):

* geometry.py: `_solve_optimal_focal_shift` and the old three-argument
  `_solve_optimal_shift` are merged into one `_solve_optimal_shift` with an
  optional `focal` argument. The merged function returns a `(shift, focal)`
  tuple in both modes, so the fixed-focal call site in `recover_focal_shift`
  now unpacks `shift_i, _`.
* geometry.py: the residual function `fn` closes over `uv`, `xy`, `z` and
  `focal` instead of receiving them through `functools.partial`, which is why
  the `partial` import is removed.
* model.py: `sd = _remap_state_dict(sd)` is removed from both
  `MoGeModelV1.from_state_dict` and `MoGeModelV2.from_state_dict`.
  NOTE(review): this diff does not show where (or whether) the remap happens
  now -- confirm callers pass already-remapped keys, and check whether
  `_remap_state_dict` still has any users.

 comfy/ldm/moge/geometry.py | 41 ++++++++++++++------------------------
 comfy/ldm/moge/model.py    |  2 --
 2 files changed, 15 insertions(+), 28 deletions(-)

diff --git a/comfy/ldm/moge/geometry.py b/comfy/ldm/moge/geometry.py
index 9612bd5af..7fdc97871 100644
--- a/comfy/ldm/moge/geometry.py
+++ b/comfy/ldm/moge/geometry.py
@@ -2,7 +2,6 @@
 
 from __future__ import annotations
 
-from functools import partial
 from typing import Optional, Tuple
 
 import numpy as np
@@ -53,34 +52,24 @@ def depth_map_to_point_map(depth: torch.Tensor, intrinsics: torch.Tensor) -> tor
     return rays * depth.unsqueeze(-1)
 
 
-def _solve_optimal_focal_shift(uv: np.ndarray, xyz: np.ndarray) -> Tuple[float, float]:
+def _solve_optimal_shift(uv: np.ndarray, xyz: np.ndarray,
+                         focal: Optional[float] = None) -> Tuple[float, float]:
+    """LM-solve for z-shift; when focal is None, also recovers the optimal focal."""
     uv = uv.reshape(-1, 2)
     xy = xyz[..., :2].reshape(-1, 2)
     z = xyz[..., 2].reshape(-1)
 
-    def fn(uv_, xy_, z_, shift):
-        xy_proj = xy_ / (z_ + shift)[:, None]
-        f = (xy_proj * uv_).sum() / np.square(xy_proj).sum()
-        return (f * xy_proj - uv_).ravel()
+    def fn(shift):
+        xy_proj = xy / (z + shift)[:, None]
+        f = focal if focal is not None else (xy_proj * uv).sum() / np.square(xy_proj).sum()
+        return (f * xy_proj - uv).ravel()
 
-    sol = least_squares(partial(fn, uv, xy, z), x0=0.0, ftol=1e-3, method="lm")
-    optim_shift = float(np.asarray(sol["x"]).squeeze())
-    xy_proj = xy / (z + optim_shift)[:, None]
-    optim_focal = float((xy_proj * uv).sum() / np.square(xy_proj).sum())
-    return optim_shift, optim_focal
-
-
-def _solve_optimal_shift(uv: np.ndarray, xyz: np.ndarray, focal: float) -> float:
-    uv = uv.reshape(-1, 2)
-    xy = xyz[..., :2].reshape(-1, 2)
-    z = xyz[..., 2].reshape(-1)
-
-    def fn(uv_, xy_, z_, shift):
-        xy_proj = xy_ / (z_ + shift)[:, None]
-        return (focal * xy_proj - uv_).ravel()
-
-    sol = least_squares(partial(fn, uv, xy, z), x0=0.0, ftol=1e-3, method="lm")
-    return float(np.asarray(sol["x"]).squeeze())
+    sol = least_squares(fn, x0=0.0, ftol=1e-3, method="lm")
+    shift = float(np.asarray(sol["x"]).squeeze())
+    if focal is None:
+        xy_proj = xy / (z + shift)[:, None]
+        focal = float((xy_proj * uv).sum() / np.square(xy_proj).sum())
+    return shift, focal
 
 
 def recover_focal_shift(points: torch.Tensor, mask: Optional[torch.Tensor] = None,
@@ -125,10 +114,10 @@ def recover_focal_shift(points: torch.Tensor, mask: Optional[torch.Tensor] = Non
         xyz_i = points_np[i][sel]
         uv_i = uv_np[sel]
         if focal_np is None:
-            shift_i, focal_i = _solve_optimal_focal_shift(uv_i, xyz_i)
+            shift_i, focal_i = _solve_optimal_shift(uv_i, xyz_i)
             out_focal.append(focal_i)
         else:
-            shift_i = _solve_optimal_shift(uv_i, xyz_i, float(focal_np[i]))
+            shift_i, _ = _solve_optimal_shift(uv_i, xyz_i, focal=float(focal_np[i]))
         out_shift.append(shift_i)
 
     shift_t = torch.tensor(out_shift, device=points.device, dtype=points.dtype).reshape(shape[:-3])
diff --git a/comfy/ldm/moge/model.py b/comfy/ldm/moge/model.py
index 34246a8d9..6876c4af2 100644
--- a/comfy/ldm/moge/model.py
+++ b/comfy/ldm/moge/model.py
@@ -88,7 +88,6 @@ class MoGeModelV1(nn.Module):
     @classmethod
     def from_state_dict(cls, sd, dtype=None, device=None, operations=comfy.ops.manual_cast):
         """Detect the v1 head config from sd, build a model, and load weights."""
-        sd = _remap_state_dict(sd)
         n_up = 1 + max(int(k.split(".")[2]) for k in sd if k.startswith("head.upsample_blocks."))
         dim_upsample = [sd[f"head.upsample_blocks.{i}.0.0.weight"].shape[1] for i in range(n_up)]
         # Each upsample stage is Sequential[upsampler, *res_blocks]; count res blocks at level 0.
@@ -157,7 +156,6 @@ class MoGeModelV2(nn.Module):
     @classmethod
     def from_state_dict(cls, sd, dtype=None, device=None, operations=comfy.ops.manual_cast):
         """Detect the v2 encoder/neck/heads config from sd, build a model, and load weights."""
-        sd = _remap_state_dict(sd)
         backbone = _detect_dinov2(sd, prefix="encoder.backbone.")
         depth = backbone["num_hidden_layers"]
         n = cls.intermediate_layers