Compare commits

...

2 Commits

Author SHA1 Message Date
kijai
10837132a4 Some cleanup 2026-05-13 22:06:51 +03:00
kijai
719e6facf9 Update modules.py 2026-05-13 21:58:16 +03:00
3 changed files with 17 additions and 30 deletions

View File

@ -2,7 +2,6 @@
from __future__ import annotations
from functools import partial
from typing import Optional, Tuple
import numpy as np
@ -53,34 +52,24 @@ def depth_map_to_point_map(depth: torch.Tensor, intrinsics: torch.Tensor) -> tor
return rays * depth.unsqueeze(-1)
def _solve_optimal_focal_shift(uv: np.ndarray, xyz: np.ndarray) -> Tuple[float, float]:
def _solve_optimal_shift(uv: np.ndarray, xyz: np.ndarray,
focal: Optional[float] = None) -> Tuple[float, float]:
"""LM-solve for z-shift; when focal is None, also recovers the optimal focal."""
uv = uv.reshape(-1, 2)
xy = xyz[..., :2].reshape(-1, 2)
z = xyz[..., 2].reshape(-1)
def fn(uv_, xy_, z_, shift):
xy_proj = xy_ / (z_ + shift)[:, None]
f = (xy_proj * uv_).sum() / np.square(xy_proj).sum()
return (f * xy_proj - uv_).ravel()
def fn(shift):
xy_proj = xy / (z + shift)[:, None]
f = focal if focal is not None else (xy_proj * uv).sum() / np.square(xy_proj).sum()
return (f * xy_proj - uv).ravel()
sol = least_squares(partial(fn, uv, xy, z), x0=0.0, ftol=1e-3, method="lm")
optim_shift = float(np.asarray(sol["x"]).squeeze())
xy_proj = xy / (z + optim_shift)[:, None]
optim_focal = float((xy_proj * uv).sum() / np.square(xy_proj).sum())
return optim_shift, optim_focal
def _solve_optimal_shift(uv: np.ndarray, xyz: np.ndarray, focal: float) -> float:
uv = uv.reshape(-1, 2)
xy = xyz[..., :2].reshape(-1, 2)
z = xyz[..., 2].reshape(-1)
def fn(uv_, xy_, z_, shift):
xy_proj = xy_ / (z_ + shift)[:, None]
return (focal * xy_proj - uv_).ravel()
sol = least_squares(partial(fn, uv, xy, z), x0=0.0, ftol=1e-3, method="lm")
return float(np.asarray(sol["x"]).squeeze())
sol = least_squares(fn, x0=0.0, ftol=1e-3, method="lm")
shift = float(np.asarray(sol["x"]).squeeze())
if focal is None:
xy_proj = xy / (z + shift)[:, None]
focal = float((xy_proj * uv).sum() / np.square(xy_proj).sum())
return shift, focal
def recover_focal_shift(points: torch.Tensor, mask: Optional[torch.Tensor] = None,
@ -125,10 +114,10 @@ def recover_focal_shift(points: torch.Tensor, mask: Optional[torch.Tensor] = Non
xyz_i = points_np[i][sel]
uv_i = uv_np[sel]
if focal_np is None:
shift_i, focal_i = _solve_optimal_focal_shift(uv_i, xyz_i)
shift_i, focal_i = _solve_optimal_shift(uv_i, xyz_i)
out_focal.append(focal_i)
else:
shift_i = _solve_optimal_shift(uv_i, xyz_i, float(focal_np[i]))
shift_i, _ = _solve_optimal_shift(uv_i, xyz_i, focal=float(focal_np[i]))
out_shift.append(shift_i)
shift_t = torch.tensor(out_shift, device=points.device, dtype=points.dtype).reshape(shape[:-3])

View File

@ -88,7 +88,6 @@ class MoGeModelV1(nn.Module):
@classmethod
def from_state_dict(cls, sd, dtype=None, device=None, operations=comfy.ops.manual_cast):
"""Detect the v1 head config from sd, build a model, and load weights."""
sd = _remap_state_dict(sd)
n_up = 1 + max(int(k.split(".")[2]) for k in sd if k.startswith("head.upsample_blocks."))
dim_upsample = [sd[f"head.upsample_blocks.{i}.0.0.weight"].shape[1] for i in range(n_up)]
# Each upsample stage is Sequential[upsampler, *res_blocks]; count res blocks at level 0.
@ -157,7 +156,6 @@ class MoGeModelV2(nn.Module):
@classmethod
def from_state_dict(cls, sd, dtype=None, device=None, operations=comfy.ops.manual_cast):
"""Detect the v2 encoder/neck/heads config from sd, build a model, and load weights."""
sd = _remap_state_dict(sd)
backbone = _detect_dinov2(sd, prefix="encoder.backbone.")
depth = backbone["num_hidden_layers"]
n = cls.intermediate_layers

View File

@ -36,8 +36,8 @@ class ResidualConvBlock(nn.Module):
super().__init__()
hidden_channels = hidden_channels if hidden_channels is not None else channels
in_norm_layer = operations.GroupNorm(1, channels) if in_norm == "layer_norm" else nn.Identity()
hidden_norm_layer = (operations.GroupNorm(max(hidden_channels // 32, 1), hidden_channels)
in_norm_layer = operations.GroupNorm(1, channels, dtype=dtype, device=device) if in_norm == "layer_norm" else nn.Identity()
hidden_norm_layer = (operations.GroupNorm(max(hidden_channels // 32, 1), hidden_channels, dtype=dtype, device=device)
if hidden_norm == "group_norm" else nn.Identity())
self.layers = nn.Sequential(