mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-15 03:27:24 +08:00
205 lines
9.5 KiB
Python
205 lines
9.5 KiB
Python
"""Building blocks for MoGe: residual conv stack, resamplers, MLP, DINOv2 encoder, v1 head."""
|
|
|
|
from __future__ import annotations
|
|
|
|
from typing import List, Optional, Sequence, Tuple, Union
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
|
|
import comfy.ops
|
|
from comfy.image_encoders.dino2 import Dinov2Model
|
|
|
|
from .geometry import normalized_view_plane_uv
|
|
|
|
|
|
def _conv2d(operations, c_in: int, c_out: int, k: int = 3, *, dtype=None, device=None):
|
|
return operations.Conv2d(c_in, c_out, kernel_size=k, padding=k // 2, padding_mode="replicate", dtype=dtype, device=device)
|
|
|
|
|
|
def _view_plane_uv_grid(batch: int, height: int, width: int, aspect_ratio: float, dtype, device) -> torch.Tensor:
|
|
"""Batched normalized view-plane UV grid as a (B, 2, H, W) tensor."""
|
|
uv = normalized_view_plane_uv(width, height, aspect_ratio=aspect_ratio, dtype=dtype, device=device)
|
|
return uv.permute(2, 0, 1).unsqueeze(0).expand(batch, -1, -1, -1)
|
|
|
|
|
|
def _concat_view_plane_uv(x: torch.Tensor, aspect_ratio: float) -> torch.Tensor:
|
|
"""Append a 2-channel normalized view-plane UV grid to x along the channel dim."""
|
|
uv = _view_plane_uv_grid(x.shape[0], x.shape[-2], x.shape[-1], aspect_ratio, x.dtype, x.device)
|
|
return torch.cat([x, uv], dim=1)
|
|
|
|
|
|
class ResidualConvBlock(nn.Module):
|
|
def __init__(self, channels: int, hidden_channels: Optional[int] = None, in_norm: str = "layer_norm", hidden_norm: str = "group_norm",
|
|
dtype=None, device=None, operations=comfy.ops.manual_cast):
|
|
super().__init__()
|
|
hidden_channels = hidden_channels if hidden_channels is not None else channels
|
|
|
|
in_norm_layer = operations.GroupNorm(1, channels, dtype=dtype, device=device) if in_norm == "layer_norm" else nn.Identity()
|
|
hidden_norm_layer = (operations.GroupNorm(max(hidden_channels // 32, 1), hidden_channels, dtype=dtype, device=device)
|
|
if hidden_norm == "group_norm" else nn.Identity())
|
|
|
|
self.layers = nn.Sequential(
|
|
in_norm_layer, nn.ReLU(), _conv2d(operations, channels, hidden_channels, dtype=dtype, device=device),
|
|
hidden_norm_layer, nn.ReLU(), _conv2d(operations, hidden_channels, channels, dtype=dtype, device=device),
|
|
)
|
|
|
|
def forward(self, x):
|
|
return self.layers(x) + x
|
|
|
|
|
|
class Resampler(nn.Sequential):
|
|
"""2x upsampler: ConvTranspose2d(2x2) or bilinear upsample, followed by a 3x3 conv."""
|
|
|
|
def __init__(self, in_channels: int, out_channels: int, type_: str, dtype=None, device=None, operations=comfy.ops.manual_cast):
|
|
if type_ == "conv_transpose":
|
|
up = operations.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2, dtype=dtype, device=device)
|
|
conv_in = out_channels
|
|
else: # "bilinear"
|
|
up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False)
|
|
conv_in = in_channels
|
|
super().__init__(up, _conv2d(operations, conv_in, out_channels, dtype=dtype, device=device))
|
|
|
|
|
|
class MLP(nn.Sequential):
|
|
def __init__(self, dims: Sequence[int], dtype=None, device=None, operations=comfy.ops.manual_cast):
|
|
layers = []
|
|
for d_in, d_out in zip(dims[:-2], dims[1:-1]):
|
|
layers.append(operations.Linear(d_in, d_out, dtype=dtype, device=device))
|
|
layers.append(nn.ReLU(inplace=True))
|
|
layers.append(operations.Linear(dims[-2], dims[-1], dtype=dtype, device=device))
|
|
super().__init__(*layers)
|
|
|
|
|
|
class ConvStack(nn.Module):
|
|
def __init__(self, dim_in: List[Optional[int]], dim_res_blocks: List[int], dim_out: List[Optional[int]], resamplers: List[str],
|
|
num_res_blocks: List[int], dim_times_res_block_hidden: int = 1, res_block_in_norm: str = "layer_norm", res_block_hidden_norm: str = "group_norm",
|
|
dtype=None, device=None, operations=comfy.ops.manual_cast):
|
|
super().__init__()
|
|
|
|
self.input_blocks = nn.ModuleList([
|
|
(_conv2d(operations, d_in, d_res, k=1, dtype=dtype, device=device)
|
|
if d_in is not None else nn.Identity())
|
|
for d_in, d_res in zip(dim_in, dim_res_blocks)
|
|
])
|
|
|
|
self.resamplers = nn.ModuleList([
|
|
Resampler(prev, succ, type_=r, dtype=dtype, device=device, operations=operations)
|
|
for prev, succ, r in zip(dim_res_blocks[:-1], dim_res_blocks[1:], resamplers)
|
|
])
|
|
|
|
self.res_blocks = nn.ModuleList([
|
|
nn.Sequential(*[
|
|
ResidualConvBlock(d_res, dim_times_res_block_hidden * d_res, in_norm=res_block_in_norm, hidden_norm=res_block_hidden_norm, dtype=dtype, device=device, operations=operations)
|
|
for _ in range(num_res_blocks[i])
|
|
])
|
|
for i, d_res in enumerate(dim_res_blocks)
|
|
])
|
|
|
|
self.output_blocks = nn.ModuleList([
|
|
(_conv2d(operations, d_res, d_out, k=1, dtype=dtype, device=device)
|
|
if d_out is not None else nn.Identity())
|
|
for d_out, d_res in zip(dim_out, dim_res_blocks)
|
|
])
|
|
|
|
def forward(self, in_features: List[Optional[torch.Tensor]]):
|
|
out_features = []
|
|
x = None
|
|
for i in range(len(self.res_blocks)):
|
|
feat = self.input_blocks[i](in_features[i]) if in_features[i] is not None else None
|
|
if i == 0:
|
|
x = feat
|
|
elif feat is not None:
|
|
x = x + feat
|
|
x = self.res_blocks[i](x)
|
|
out_features.append(self.output_blocks[i](x))
|
|
if i < len(self.res_blocks) - 1:
|
|
x = self.resamplers[i](x)
|
|
return out_features
|
|
|
|
|
|
class DINOv2Encoder(nn.Module):
|
|
"""Comfy DINOv2 backbone with per-layer 1x1 projection heads."""
|
|
|
|
def __init__(self, backbone: dict, intermediate_layers: List[int], dim_out: int, dtype=None, device=None, operations=comfy.ops.manual_cast):
|
|
super().__init__()
|
|
self.intermediate_layers = list(intermediate_layers)
|
|
dim_features = backbone["hidden_size"]
|
|
self.backbone = Dinov2Model(backbone, dtype, device, operations)
|
|
self.output_projections = nn.ModuleList([
|
|
_conv2d(operations, dim_features, dim_out, k=1, dtype=dtype, device=device)
|
|
for _ in range(len(self.intermediate_layers))
|
|
])
|
|
self.register_buffer("image_mean", torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
|
|
self.register_buffer("image_std", torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
|
|
|
|
def forward(self, image: torch.Tensor, token_rows: int, token_cols: int,
|
|
return_class_token: bool = False) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
|
image_14 = F.interpolate(image, (token_rows * 14, token_cols * 14), mode="bilinear", align_corners=False, antialias=True)
|
|
image_14 = (image_14 - self.image_mean) / self.image_std
|
|
feats = self.backbone.get_intermediate_layers(image_14, self.intermediate_layers, apply_norm=True)
|
|
x = torch.stack([
|
|
proj(feat.permute(0, 2, 1).unflatten(2, (token_rows, token_cols)).contiguous())
|
|
for proj, (feat, _cls) in zip(self.output_projections, feats)
|
|
], dim=1).sum(dim=1)
|
|
if return_class_token:
|
|
return x, feats[-1][1]
|
|
return x
|
|
|
|
|
|
class HeadV1(nn.Module):
|
|
"""v1 head: 4 backbone-feature projections -> shared upsample stack -> per-target output convs (points, mask)."""
|
|
|
|
NUM_FEATURES = 4
|
|
DIM_PROJ = 512
|
|
DIM_OUT = (3, 1) # 3 channels for points, 1 for mask
|
|
LAST_CONV_CHANNELS = 32
|
|
|
|
def __init__(self, dim_in: int, dim_upsample: List[int] = (256, 128, 128), num_res_blocks: int = 1, dim_times_res_block_hidden: int = 1,
|
|
dtype=None, device=None, operations=comfy.ops.manual_cast):
|
|
super().__init__()
|
|
self.projects = nn.ModuleList([
|
|
_conv2d(operations, dim_in, self.DIM_PROJ, k=1, dtype=dtype, device=device)
|
|
for _ in range(self.NUM_FEATURES)
|
|
])
|
|
def upsampler(in_ch, out_ch):
|
|
return nn.Sequential(
|
|
operations.ConvTranspose2d(in_ch, out_ch, kernel_size=2, stride=2, dtype=dtype, device=device),
|
|
_conv2d(operations, out_ch, out_ch, dtype=dtype, device=device),
|
|
)
|
|
|
|
in_chs = [self.DIM_PROJ] + list(dim_upsample[:-1])
|
|
self.upsample_blocks = nn.ModuleList([
|
|
nn.Sequential(
|
|
upsampler(in_ch + 2, out_ch),
|
|
*(ResidualConvBlock(out_ch, dim_times_res_block_hidden * out_ch, dtype=dtype, device=device, operations=operations)
|
|
for _ in range(num_res_blocks))
|
|
)
|
|
for in_ch, out_ch in zip(in_chs, dim_upsample)
|
|
])
|
|
self.output_block = nn.ModuleList([
|
|
nn.Sequential(
|
|
_conv2d(operations, dim_upsample[-1] + 2, self.LAST_CONV_CHANNELS, dtype=dtype, device=device),
|
|
nn.ReLU(inplace=True),
|
|
_conv2d(operations, self.LAST_CONV_CHANNELS, d_out, k=1, dtype=dtype, device=device),
|
|
)
|
|
for d_out in self.DIM_OUT
|
|
])
|
|
|
|
def forward(self, hidden_states, image: torch.Tensor):
|
|
img_h, img_w = image.shape[-2:]
|
|
patch_h, patch_w = img_h // 14, img_w // 14
|
|
aspect = img_w / img_h
|
|
x = torch.stack([
|
|
proj(feat.permute(0, 2, 1).unflatten(2, (patch_h, patch_w)).contiguous())
|
|
for proj, (feat, _cls) in zip(self.projects, hidden_states)
|
|
], dim=1).sum(dim=1)
|
|
|
|
for block in self.upsample_blocks:
|
|
x = block(_concat_view_plane_uv(x, aspect))
|
|
|
|
x = F.interpolate(x, (img_h, img_w), mode="bilinear", align_corners=False)
|
|
x = _concat_view_plane_uv(x, aspect)
|
|
return [block(x) for block in self.output_block]
|