mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-06-24 16:59:29 +08:00
156 lines
5.9 KiB
Python
156 lines
5.9 KiB
Python
import math
|
|
|
|
import einops
|
|
import torch
|
|
import torch.nn.functional as F
|
|
from comfy.ldm.cascade.common import LayerNorm2d_op
|
|
from torch import nn
|
|
|
|
from typing import List, Optional, Tuple, Union
|
|
|
|
from ..utils import perspective_projection
|
|
from .transformer import MLP
|
|
|
|
class CameraEncoder(nn.Module):
    """Fuse Fourier-encoded camera rays into image patch embeddings.

    The ray map is downsampled to the patch grid, each position is extended with a
    constant 1 to form a 3-vector, Fourier-encoded, concatenated with the image
    embeddings along the channel axis, and projected back to ``embed_dim`` channels
    by a 1x1 conv followed by a 2D LayerNorm.
    """

    def __init__(self, embed_dim: int, patch_size: int = 14, device=None, dtype=None, operations=None):
        super().__init__()
        self.patch_size = patch_size
        self.embed_dim = embed_dim
        self.camera = FourierPositionEncoding(n=3, num_bands=16, max_resolution=64)

        # Input channels = image embedding + ray Fourier features (99 for n=3 and 16 bands).
        # Derived from the encoder instead of hard-coded so the two cannot drift apart;
        # the resulting weight shape is identical to the previous literal `embed_dim + 99`.
        self.conv = operations.Conv2d(embed_dim + self.camera.channels, embed_dim, kernel_size=1, bias=False, device=device, dtype=dtype)
        self.norm = LayerNorm2d_op(operations)(embed_dim, device=device, dtype=dtype)

    def forward(self, img_embeddings: torch.Tensor, rays: torch.Tensor):
        """
        Args:
            img_embeddings: [B, embed_dim, h, w] patch embeddings.
            rays: per-pixel ray map; presumably [B, 2, h * patch_size, w * patch_size]
                (after downsampling it must match the (h, w) patch grid) — TODO confirm.

        Returns:
            Tensor with ``embed_dim`` channels on the same (h, w) grid.
        """
        B, _, _h, _w = img_embeddings.shape

        # Downsample the ray map to the patch grid.
        scale = 1 / self.patch_size
        rays = F.interpolate(rays, scale_factor=(scale, scale), mode="bilinear", align_corners=False, antialias=True)
        rays = rays.permute(0, 2, 3, 1).contiguous()  # [b, h, w, 2]
        # Append a constant-1 channel so every position becomes a 3-vector (x, y, 1).
        rays = torch.cat([rays, torch.ones_like(rays[..., :1])], dim=-1)
        rays_embeddings = self.camera(pos=rays.reshape(B, -1, 3))  # (B, h*w, self.camera.channels)
        rays_embeddings = einops.rearrange(rays_embeddings, "b (h w) c -> b c h w", h=_h, w=_w).contiguous()

        z = torch.cat([img_embeddings, rays_embeddings], dim=1)
        return self.norm(self.conv(z))
|
|
|
|
|
|
class FourierPositionEncoding(nn.Module):
    """Sin/cos Fourier features for ray positions.

    Thin module wrapper around ``_generate_fourier_features`` that fixes the number
    of bands and uses the same maximum resolution for each of the ``n`` input dims.
    """

    def __init__(self, n: int, num_bands: int, max_resolution: int):
        super().__init__()
        self.num_bands = num_bands
        # One resolution entry per input dimension, all sharing the same value.
        self.max_resolution = [max_resolution for _ in range(n)]

    @property
    def channels(self):
        """Width of the encoding's last axis: raw coordinates plus a sin and a cos per band."""
        dims = len(self.max_resolution)
        return dims + 2 * dims * self.num_bands

    def forward(self, pos: torch.Tensor):
        return _generate_fourier_features(
            pos,
            num_bands=self.num_bands,
            max_resolution=self.max_resolution,
        )
|
|
|
|
|
|
def _generate_fourier_features(pos: torch.Tensor, num_bands: int, max_resolution: List[int], min_freq: float = 1.0):
    """Build sin/cos Fourier features for a batch of positions.

    Args:
        pos: [B, N, D] positions, where D == len(max_resolution).
        num_bands: number of frequencies per input dimension.
        max_resolution: per-dimension resolution; dimension ``d`` uses ``num_bands``
            frequencies linearly spaced in ``[min_freq, max_resolution[d] / 2]``.
        min_freq: lowest frequency for every dimension.

    Returns:
        [B, N, D + 2 * D * num_bands] tensor laid out as
        ``[pos, sin(pi * pos * f), cos(pi * pos * f)]``. The frequency features are
        ordered dimension-major (all bands of dim 0, then all bands of dim 1, ...).
        Empty batches (B == 0) and empty point sets (N == 0) are supported.
    """
    # [D, num_bands] frequency table, one row per input dimension.
    freq_bands = torch.stack([torch.linspace(start=min_freq, end=res / 2, steps=num_bands, device=pos.device, dtype=pos.dtype) for res in max_resolution], dim=0)

    # Broadcast [B, N, D, 1] * [D, num_bands] -> [B, N, D, num_bands] in one op, which
    # replaces the per-sample Python loop + torch.stack (that failed for B == 0).
    # flatten() merges the last two axes; reshape(b, n, -1) cannot infer -1 when the
    # tensor has zero elements.
    per_pos_features = (pos[:, :, :, None] * freq_bands).flatten(start_dim=2)

    # Sin-Cos
    scaled = math.pi * per_pos_features
    per_pos_features = torch.cat([torch.sin(scaled), torch.cos(scaled)], dim=-1)

    # Concat with initial pos
    return torch.cat([pos, per_pos_features], dim=-1)
|
|
|
|
|
|
class PerspectiveHead(nn.Module):
    """
    Predict camera translation (s, tx, ty) and perform full-perspective 2D reprojection (CLIFF/CameraHMR setup).
    """

    def __init__(self, input_dim: int, img_size: Union[int, Tuple[int, int]],  # model input size (W, H)
                 mlp_depth: int = 1, mlp_channel_div_factor: int = 8, default_scale_factor: float = 1.0,
                 device=None, dtype=None, operations=None
                 ):
        """
        Args:
            input_dim: channel size of the pose token passed to ``forward``.
            img_size: model input size (W, H). An int means a square input; a list such
                as ``[W, H]`` is also accepted and stored as a tuple.
            mlp_depth: number of layers of the camera-regression MLP.
            mlp_channel_div_factor: MLP hidden width is ``input_dim // mlp_channel_div_factor``.
            default_scale_factor: multiplier applied to the predicted scale in
                ``perspective_projection``.
            device, dtype, operations: forwarded to the MLP.
        """
        super().__init__()

        # Metadata to compute 3D skeleton and 2D reprojection
        self.img_size = self._normalize_img_size(img_size)
        self.ncam = 3  # (s, tx, ty)
        self.default_scale_factor = default_scale_factor

        # Regresses the ncam camera parameters from the pose token.
        self.proj = MLP(
            input_dim=input_dim,
            hidden_dim=input_dim // mlp_channel_div_factor,
            output_dim=self.ncam,
            num_layers=mlp_depth,
            device=device,
            dtype=dtype,
            operations=operations,
        )

    @staticmethod
    def _normalize_img_size(img_size):
        """Return img_size as a tuple; a tuple/list is converted as-is, a scalar becomes (size, size)."""
        # Lists (common when sizes come from JSON/YAML configs) used to end up as
        # ([W, H], [W, H]); treat them like tuples instead.
        if isinstance(img_size, (tuple, list)):
            return tuple(img_size)
        return (img_size, img_size)

    def forward(self, x: torch.Tensor, init_estimate: Optional[torch.Tensor] = None):
        """
        Args:
            x: pose token with shape [B, C], usually C=DECODER.DIM
            init_estimate: [B, self.ncam]; when given, the prediction is a residual on top of it.

        Returns:
            [B, self.ncam] camera prediction (s, tx, ty).
        """
        pred_cam = self.proj(x)
        if init_estimate is not None:
            pred_cam = pred_cam + init_estimate

        return pred_cam

    def perspective_projection(
        self,
        points_3d: torch.Tensor,
        pred_cam: torch.Tensor,
        bbox_center: torch.Tensor,  # [N, 2], in original image space (w, h)
        bbox_size: torch.Tensor,  # [N,], in original image space
        img_size: torch.Tensor,
        cam_int: torch.Tensor,  # [B, 3, 3]
        use_intrin_center: bool = False,
    ):
        """Convert a predicted (s, tx, ty) camera into a translation and project points.

        Args:
            points_3d: 3D points; presumably [B, N_pts, 3] since a per-batch translation
                is broadcast onto them as [B, 1, 3] — TODO confirm.
            pred_cam: [B, 3] camera prediction (s, tx, ty), e.g. the output of ``forward``.
                Not modified (a copy is taken).
            bbox_center: [N, 2] box centre in original image space (w, h); presumably N == B.
            bbox_size: [N] box size in original image space.
            img_size: per-sample image size; column 0 pairs with the w coordinate and
                column 1 with h. Only used when ``use_intrin_center`` is False.
            cam_int: [B, 3, 3] intrinsics. ``[:, 0, 0]`` is used as the focal length and
                ``[:, 0, 2]``/``[:, 1, 2]`` as the centre when ``use_intrin_center`` is True.
            use_intrin_center: measure the box offset from the intrinsics centre instead
                of the image centre.

        Returns:
            dict with ``pred_keypoints_2d`` [B, -1, 2], ``pred_keypoints_2d_depth`` [B, -1],
            ``pred_cam_t`` [B, 3] and ``focal_length`` [B].
        """
        batch_size = points_3d.shape[0]
        # Work on a copy so the caller's tensor is not modified by the in-place negation.
        pred_cam = pred_cam.clone()
        pred_cam[..., [0, 2]] *= -1  # Camera system difference (negates s and ty)

        # Compute camera translation: (scale, x, y) --> (x, y, depth)
        # depth ~= f / s, Note that f is in the NDC space
        s, tx, ty = pred_cam[:, 0], pred_cam[:, 1], pred_cam[:, 2]
        # Scaled box size; the epsilon keeps the divisions below off an exact zero.
        bs = bbox_size * s * self.default_scale_factor + 1e-8
        focal_length = cam_int[:, 0, 0]
        tz = 2 * focal_length / bs

        # Offset of the box centre from the reference centre, normalised by bs.
        if not use_intrin_center:
            cx = 2 * (bbox_center[:, 0] - (img_size[:, 0] / 2)) / bs
            cy = 2 * (bbox_center[:, 1] - (img_size[:, 1] / 2)) / bs
        else:
            cx = 2 * (bbox_center[:, 0] - (cam_int[:, 0, 2])) / bs
            cy = 2 * (bbox_center[:, 1] - (cam_int[:, 1, 2])) / bs

        pred_cam_t = torch.stack([tx + cx, ty + cy, tz], dim=-1)

        # Compute camera translation
        j3d_cam = points_3d + pred_cam_t.unsqueeze(1)

        # Projection to the image plane, note that the projection output is in original image space now.
        # `perspective_projection` here is the module-level helper imported from ..utils:
        # a method's own name is not in scope inside its body, so this is not recursion.
        j2d = perspective_projection(j3d_cam, cam_int)

        return {
            "pred_keypoints_2d": j2d.reshape(batch_size, -1, 2),
            "pred_keypoints_2d_depth": j3d_cam.reshape(batch_size, -1, 3)[:, :, 2],
            "pred_cam_t": pred_cam_t, "focal_length": focal_length,
        }
|