mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-14 19:17:32 +08:00
Some cleanup
This commit is contained in:
parent
719e6facf9
commit
10837132a4
@ -2,7 +2,6 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from functools import partial
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
@ -53,34 +52,24 @@ def depth_map_to_point_map(depth: torch.Tensor, intrinsics: torch.Tensor) -> tor
|
||||
return rays * depth.unsqueeze(-1)
|
||||
|
||||
|
||||
def _solve_optimal_focal_shift(uv: np.ndarray, xyz: np.ndarray) -> Tuple[float, float]:
|
||||
def _solve_optimal_shift(uv: np.ndarray, xyz: np.ndarray,
|
||||
focal: Optional[float] = None) -> Tuple[float, float]:
|
||||
"""LM-solve for z-shift; when focal is None, also recovers the optimal focal."""
|
||||
uv = uv.reshape(-1, 2)
|
||||
xy = xyz[..., :2].reshape(-1, 2)
|
||||
z = xyz[..., 2].reshape(-1)
|
||||
|
||||
def fn(uv_, xy_, z_, shift):
|
||||
xy_proj = xy_ / (z_ + shift)[:, None]
|
||||
f = (xy_proj * uv_).sum() / np.square(xy_proj).sum()
|
||||
return (f * xy_proj - uv_).ravel()
|
||||
def fn(shift):
|
||||
xy_proj = xy / (z + shift)[:, None]
|
||||
f = focal if focal is not None else (xy_proj * uv).sum() / np.square(xy_proj).sum()
|
||||
return (f * xy_proj - uv).ravel()
|
||||
|
||||
sol = least_squares(partial(fn, uv, xy, z), x0=0.0, ftol=1e-3, method="lm")
|
||||
optim_shift = float(np.asarray(sol["x"]).squeeze())
|
||||
xy_proj = xy / (z + optim_shift)[:, None]
|
||||
optim_focal = float((xy_proj * uv).sum() / np.square(xy_proj).sum())
|
||||
return optim_shift, optim_focal
|
||||
|
||||
|
||||
def _solve_optimal_shift(uv: np.ndarray, xyz: np.ndarray, focal: float) -> float:
|
||||
uv = uv.reshape(-1, 2)
|
||||
xy = xyz[..., :2].reshape(-1, 2)
|
||||
z = xyz[..., 2].reshape(-1)
|
||||
|
||||
def fn(uv_, xy_, z_, shift):
|
||||
xy_proj = xy_ / (z_ + shift)[:, None]
|
||||
return (focal * xy_proj - uv_).ravel()
|
||||
|
||||
sol = least_squares(partial(fn, uv, xy, z), x0=0.0, ftol=1e-3, method="lm")
|
||||
return float(np.asarray(sol["x"]).squeeze())
|
||||
sol = least_squares(fn, x0=0.0, ftol=1e-3, method="lm")
|
||||
shift = float(np.asarray(sol["x"]).squeeze())
|
||||
if focal is None:
|
||||
xy_proj = xy / (z + shift)[:, None]
|
||||
focal = float((xy_proj * uv).sum() / np.square(xy_proj).sum())
|
||||
return shift, focal
|
||||
|
||||
|
||||
def recover_focal_shift(points: torch.Tensor, mask: Optional[torch.Tensor] = None,
|
||||
@ -125,10 +114,10 @@ def recover_focal_shift(points: torch.Tensor, mask: Optional[torch.Tensor] = Non
|
||||
xyz_i = points_np[i][sel]
|
||||
uv_i = uv_np[sel]
|
||||
if focal_np is None:
|
||||
shift_i, focal_i = _solve_optimal_focal_shift(uv_i, xyz_i)
|
||||
shift_i, focal_i = _solve_optimal_shift(uv_i, xyz_i)
|
||||
out_focal.append(focal_i)
|
||||
else:
|
||||
shift_i = _solve_optimal_shift(uv_i, xyz_i, float(focal_np[i]))
|
||||
shift_i, _ = _solve_optimal_shift(uv_i, xyz_i, focal=float(focal_np[i]))
|
||||
out_shift.append(shift_i)
|
||||
|
||||
shift_t = torch.tensor(out_shift, device=points.device, dtype=points.dtype).reshape(shape[:-3])
|
||||
|
||||
@ -88,7 +88,6 @@ class MoGeModelV1(nn.Module):
|
||||
@classmethod
|
||||
def from_state_dict(cls, sd, dtype=None, device=None, operations=comfy.ops.manual_cast):
|
||||
"""Detect the v1 head config from sd, build a model, and load weights."""
|
||||
sd = _remap_state_dict(sd)
|
||||
n_up = 1 + max(int(k.split(".")[2]) for k in sd if k.startswith("head.upsample_blocks."))
|
||||
dim_upsample = [sd[f"head.upsample_blocks.{i}.0.0.weight"].shape[1] for i in range(n_up)]
|
||||
# Each upsample stage is Sequential[upsampler, *res_blocks]; count res blocks at level 0.
|
||||
@ -157,7 +156,6 @@ class MoGeModelV2(nn.Module):
|
||||
@classmethod
|
||||
def from_state_dict(cls, sd, dtype=None, device=None, operations=comfy.ops.manual_cast):
|
||||
"""Detect the v2 encoder/neck/heads config from sd, build a model, and load weights."""
|
||||
sd = _remap_state_dict(sd)
|
||||
backbone = _detect_dinov2(sd, prefix="encoder.backbone.")
|
||||
depth = backbone["num_hidden_layers"]
|
||||
n = cls.intermediate_layers
|
||||
|
||||
Loading…
Reference in New Issue
Block a user