import math
import einops
import torch
import torch.nn.functional as F
from comfy.ldm.cascade.common import LayerNorm2d_op
from torch import nn
from typing import List, Optional, Tuple, Union

from ..utils import perspective_projection
from .transformer import MLP


class CameraEncoder(nn.Module):
    """Fuse image patch embeddings with Fourier-encoded camera rays.

    The rays are downsampled by ``1 / patch_size``, extended with a constant
    third coordinate of ones, Fourier-encoded, concatenated with the image
    embeddings along the channel axis, and projected back to ``embed_dim``
    channels by a bias-free 1x1 convolution followed by a 2D layer norm.
    """

    def __init__(self, embed_dim: int, patch_size: int = 14, device=None, dtype=None, operations=None):
        """
        Args:
            embed_dim: channel count of the image embeddings (and of the output).
            patch_size: rays are resized by a factor of ``1 / patch_size``.
            device, dtype: forwarded to the conv and norm layers.
            operations: object providing the ``Conv2d`` layer class; it is also
                passed to ``LayerNorm2d_op`` to obtain the norm layer class.
        """
        super().__init__()
        self.patch_size = patch_size
        self.embed_dim = embed_dim
        self.camera = FourierPositionEncoding(n=3, num_bands=16, max_resolution=64)
        # Derive the ray-embedding width from the encoder itself rather than
        # hard-coding it (3 + 2 * 16 * 3 = 99 for the settings above).
        self.conv = operations.Conv2d(
            embed_dim + self.camera.channels,
            embed_dim,
            kernel_size=1,
            bias=False,
            device=device,
            dtype=dtype,
        )
        self.norm = LayerNorm2d_op(operations)(embed_dim, device=device, dtype=dtype)

    def forward(self, img_embeddings: torch.Tensor, rays: torch.Tensor):
        """
        Args:
            img_embeddings: 4D tensor [B, C, h, w].
            rays: 4D tensor resized with ``F.interpolate``; it is treated as
                [B, 2, H, W] below -- NOTE(review): after resizing, its spatial
                size must equal (h, w) for the rearrange to succeed.

        Returns:
            The output of norm(conv(cat(img_embeddings, ray embeddings))).
        """
        batch, _, grid_h, grid_w = img_embeddings.shape

        scale = 1 / self.patch_size
        rays = F.interpolate(
            rays,
            scale_factor=(scale, scale),
            mode="bilinear",
            align_corners=False,
            antialias=True,
        )
        rays = rays.permute(0, 2, 3, 1).contiguous()  # [b, h, w, 2]
        # Append a constant 1 as the third ray coordinate.
        rays = torch.cat([rays, torch.ones_like(rays[..., :1])], dim=-1)

        # (bs, N, channels): rays fourier embedding
        rays_embeddings = self.camera(pos=rays.reshape(batch, -1, 3))
        rays_embeddings = einops.rearrange(
            rays_embeddings, "b (h w) c -> b c h w", h=grid_h, w=grid_w
        ).contiguous()

        z = torch.cat([img_embeddings, rays_embeddings], dim=1)
        return self.norm(self.conv(z))


class FourierPositionEncoding(nn.Module):
    """Sin/cos Fourier features for ray positions"""

    def __init__(self, n: int, num_bands: int, max_resolution: int):
        """
        Args:
            n: number of position coordinates.
            num_bands: number of frequency bands per coordinate.
            max_resolution: value replicated ``n`` times; each band range ends
                at ``max_resolution / 2``.
        """
        super().__init__()
        self.num_bands = num_bands
        self.max_resolution = [max_resolution] * n

    @property
    def channels(self):
        """Size of the last dimension of the tensor returned by ``forward``."""
        num_dims = len(self.max_resolution)
        # sin + cos for every (coordinate, band) pair, plus the raw coordinates.
        return 2 * self.num_bands * num_dims + num_dims

    def forward(self, pos: torch.Tensor):
        """Encode ``pos`` with ``_generate_fourier_features``."""
        return _generate_fourier_features(
            pos, num_bands=self.num_bands, max_resolution=self.max_resolution
        )
def _generate_fourier_features(pos: torch.Tensor, num_bands: int, max_resolution: List[int], min_freq: float = 1.0): b, n = pos.shape[:2] freq_bands = torch.stack([torch.linspace(start=min_freq, end=res / 2, steps=num_bands, device=pos.device, dtype=pos.dtype) for res in max_resolution], dim=0) per_pos_features = torch.stack([pos[i, :, :][:, :, None] * freq_bands[None, :, :] for i in range(b)], 0) per_pos_features = per_pos_features.reshape(b, n, -1) # Sin-Cos per_pos_features = torch.cat([torch.sin(math.pi * per_pos_features), torch.cos(math.pi * per_pos_features)], dim=-1) # Concat with initial pos per_pos_features = torch.cat([pos, per_pos_features], dim=-1) return per_pos_features class PerspectiveHead(nn.Module): """ Predict camera translation (s, tx, ty) and perform full-perspective 2D reprojection (CLIFF/CameraHMR setup). """ def __init__(self, input_dim: int, img_size: Union[int, Tuple[int, int]], # model input size (W, H) mlp_depth: int = 1, mlp_channel_div_factor: int = 8, default_scale_factor: float = 1.0, device=None, dtype=None, operations=None ): super().__init__() # Metadata to compute 3D skeleton and 2D reprojection self.img_size = img_size if isinstance(img_size, tuple) else (img_size, img_size) self.ncam = 3 # (s, tx, ty) self.default_scale_factor = default_scale_factor self.proj = MLP( input_dim=input_dim, hidden_dim=input_dim // mlp_channel_div_factor, output_dim=self.ncam, num_layers=mlp_depth, device=device, dtype=dtype, operations=operations, ) def forward(self, x: torch.Tensor, init_estimate: Optional[torch.Tensor] = None): """ Args: x: pose token with shape [B, C], usually C=DECODER.DIM init_estimate: [B, self.ncam] """ pred_cam = self.proj(x) if init_estimate is not None: pred_cam = pred_cam + init_estimate return pred_cam def perspective_projection( self, points_3d: torch.Tensor, pred_cam: torch.Tensor, bbox_center: torch.Tensor, # [N, 2], in original image space (w, h) bbox_size: torch.Tensor, # [N,], in original image space img_size: 
torch.Tensor, cam_int: torch.Tensor, # [B, 3, 3] use_intrin_center: bool = False, ): batch_size = points_3d.shape[0] pred_cam = pred_cam.clone() pred_cam[..., [0, 2]] *= -1 # Camera system difference # Compute camera translation: (scale, x, y) --> (x, y, depth) # depth ~= f / s, Note that f is in the NDC space s, tx, ty = pred_cam[:, 0], pred_cam[:, 1], pred_cam[:, 2] bs = bbox_size * s * self.default_scale_factor + 1e-8 focal_length = cam_int[:, 0, 0] tz = 2 * focal_length / bs if not use_intrin_center: cx = 2 * (bbox_center[:, 0] - (img_size[:, 0] / 2)) / bs cy = 2 * (bbox_center[:, 1] - (img_size[:, 1] / 2)) / bs else: cx = 2 * (bbox_center[:, 0] - (cam_int[:, 0, 2])) / bs cy = 2 * (bbox_center[:, 1] - (cam_int[:, 1, 2])) / bs pred_cam_t = torch.stack([tx + cx, ty + cy, tz], dim=-1) # Compute camera translation j3d_cam = points_3d + pred_cam_t.unsqueeze(1) # Projection to the image plane, note that the projection output is in original image space now. j2d = perspective_projection(j3d_cam, cam_int) return { "pred_keypoints_2d": j2d.reshape(batch_size, -1, 2), "pred_keypoints_2d_depth": j3d_cam.reshape(batch_size, -1, 3)[:, :, 2], "pred_cam_t": pred_cam_t, "focal_length": focal_length, }