Some cleanup

This commit is contained in:
kijai 2026-05-13 22:06:51 +03:00
parent 719e6facf9
commit 10837132a4
2 changed files with 15 additions and 28 deletions

View File

@ -2,7 +2,6 @@
from __future__ import annotations from __future__ import annotations
from functools import partial
from typing import Optional, Tuple from typing import Optional, Tuple
import numpy as np import numpy as np
@ -53,34 +52,24 @@ def depth_map_to_point_map(depth: torch.Tensor, intrinsics: torch.Tensor) -> tor
return rays * depth.unsqueeze(-1) return rays * depth.unsqueeze(-1)
def _solve_optimal_focal_shift(uv: np.ndarray, xyz: np.ndarray) -> Tuple[float, float]: def _solve_optimal_shift(uv: np.ndarray, xyz: np.ndarray,
focal: Optional[float] = None) -> Tuple[float, float]:
"""LM-solve for z-shift; when focal is None, also recovers the optimal focal."""
uv = uv.reshape(-1, 2) uv = uv.reshape(-1, 2)
xy = xyz[..., :2].reshape(-1, 2) xy = xyz[..., :2].reshape(-1, 2)
z = xyz[..., 2].reshape(-1) z = xyz[..., 2].reshape(-1)
def fn(uv_, xy_, z_, shift): def fn(shift):
xy_proj = xy_ / (z_ + shift)[:, None] xy_proj = xy / (z + shift)[:, None]
f = (xy_proj * uv_).sum() / np.square(xy_proj).sum() f = focal if focal is not None else (xy_proj * uv).sum() / np.square(xy_proj).sum()
return (f * xy_proj - uv_).ravel() return (f * xy_proj - uv).ravel()
sol = least_squares(partial(fn, uv, xy, z), x0=0.0, ftol=1e-3, method="lm") sol = least_squares(fn, x0=0.0, ftol=1e-3, method="lm")
optim_shift = float(np.asarray(sol["x"]).squeeze()) shift = float(np.asarray(sol["x"]).squeeze())
xy_proj = xy / (z + optim_shift)[:, None] if focal is None:
optim_focal = float((xy_proj * uv).sum() / np.square(xy_proj).sum()) xy_proj = xy / (z + shift)[:, None]
return optim_shift, optim_focal focal = float((xy_proj * uv).sum() / np.square(xy_proj).sum())
return shift, focal
def _solve_optimal_shift(uv: np.ndarray, xyz: np.ndarray, focal: float) -> float:
uv = uv.reshape(-1, 2)
xy = xyz[..., :2].reshape(-1, 2)
z = xyz[..., 2].reshape(-1)
def fn(uv_, xy_, z_, shift):
xy_proj = xy_ / (z_ + shift)[:, None]
return (focal * xy_proj - uv_).ravel()
sol = least_squares(partial(fn, uv, xy, z), x0=0.0, ftol=1e-3, method="lm")
return float(np.asarray(sol["x"]).squeeze())
def recover_focal_shift(points: torch.Tensor, mask: Optional[torch.Tensor] = None, def recover_focal_shift(points: torch.Tensor, mask: Optional[torch.Tensor] = None,
@ -125,10 +114,10 @@ def recover_focal_shift(points: torch.Tensor, mask: Optional[torch.Tensor] = Non
xyz_i = points_np[i][sel] xyz_i = points_np[i][sel]
uv_i = uv_np[sel] uv_i = uv_np[sel]
if focal_np is None: if focal_np is None:
shift_i, focal_i = _solve_optimal_focal_shift(uv_i, xyz_i) shift_i, focal_i = _solve_optimal_shift(uv_i, xyz_i)
out_focal.append(focal_i) out_focal.append(focal_i)
else: else:
shift_i = _solve_optimal_shift(uv_i, xyz_i, float(focal_np[i])) shift_i, _ = _solve_optimal_shift(uv_i, xyz_i, focal=float(focal_np[i]))
out_shift.append(shift_i) out_shift.append(shift_i)
shift_t = torch.tensor(out_shift, device=points.device, dtype=points.dtype).reshape(shape[:-3]) shift_t = torch.tensor(out_shift, device=points.device, dtype=points.dtype).reshape(shape[:-3])

View File

@ -88,7 +88,6 @@ class MoGeModelV1(nn.Module):
@classmethod @classmethod
def from_state_dict(cls, sd, dtype=None, device=None, operations=comfy.ops.manual_cast): def from_state_dict(cls, sd, dtype=None, device=None, operations=comfy.ops.manual_cast):
"""Detect the v1 head config from sd, build a model, and load weights.""" """Detect the v1 head config from sd, build a model, and load weights."""
sd = _remap_state_dict(sd)
n_up = 1 + max(int(k.split(".")[2]) for k in sd if k.startswith("head.upsample_blocks.")) n_up = 1 + max(int(k.split(".")[2]) for k in sd if k.startswith("head.upsample_blocks."))
dim_upsample = [sd[f"head.upsample_blocks.{i}.0.0.weight"].shape[1] for i in range(n_up)] dim_upsample = [sd[f"head.upsample_blocks.{i}.0.0.weight"].shape[1] for i in range(n_up)]
# Each upsample stage is Sequential[upsampler, *res_blocks]; count res blocks at level 0. # Each upsample stage is Sequential[upsampler, *res_blocks]; count res blocks at level 0.
@ -157,7 +156,6 @@ class MoGeModelV2(nn.Module):
@classmethod @classmethod
def from_state_dict(cls, sd, dtype=None, device=None, operations=comfy.ops.manual_cast): def from_state_dict(cls, sd, dtype=None, device=None, operations=comfy.ops.manual_cast):
"""Detect the v2 encoder/neck/heads config from sd, build a model, and load weights.""" """Detect the v2 encoder/neck/heads config from sd, build a model, and load weights."""
sd = _remap_state_dict(sd)
backbone = _detect_dinov2(sd, prefix="encoder.backbone.") backbone = _detect_dinov2(sd, prefix="encoder.backbone.")
depth = backbone["num_hidden_layers"] depth = backbone["num_hidden_layers"]
n = cls.intermediate_layers n = cls.intermediate_layers