mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-07-03 21:20:49 +08:00
293 lines
14 KiB
Python
293 lines
14 KiB
Python
"""NAF (Neighborhood Attention Filtering) feature upsampler.
|
||
|
||
Vendored from valeoai/NAF (Apache-2.0):
|
||
https://github.com/valeoai/NAF — src/model/naf.py + src/layers/{convolutions,attentions,rope}.py
|
||
Used by Pixal3D's shape/texture conditioning to produce
|
||
the 2x-upsampled half of the 2048-channel proj feature map.
|
||
"""
|
||
|
||
import math
|
||
from typing import Tuple
|
||
|
||
import torch
|
||
import torch.nn as nn
|
||
import torch.nn.functional as F
|
||
|
||
|
||
# Pure-torch neighborhood attention (replaces natten.na2d / na2d_qk + na2d_av).
|
||
|
||
def upsample_lr_slice(src_lr: torch.Tensor, lr_dh: int, lr_dw: int,
                      hr_h_range: Tuple[int, int], hr_w_range: Tuple[int, int]) -> torch.Tensor:
    """Nearest-exact upsample of just the LR window that backs an HR rectangle.

    ``src_lr`` is laid out as [B, h_lr, w_lr, n, C]. Only the low-resolution
    rows/columns that cover the half-open HR ranges ``hr_h_range`` x
    ``hr_w_range`` are sliced out, folded to [B * n, C, lh, lw], upsampled by
    the integer factors (``lr_dh``, ``lr_dw``) and cropped back to the
    requested rectangle.

    Returns a [B * n, C, h_end - h_start, w_end - w_start] tensor. Nothing is
    padded: the caller must request ranges that lie inside the HR grid.
    """
    batch = src_lr.size(0)
    heads = src_lr.size(-2)
    channels = src_lr.size(-1)
    top, bottom = hr_h_range
    left, right = hr_w_range

    # With an integer factor D, nearest-exact sends HR index p to LR index p // D,
    # so the covering LR window is [start // D, (end - 1) // D + 1).
    row_lo = top // lr_dh
    row_hi = (bottom - 1) // lr_dh + 1
    col_lo = left // lr_dw
    col_hi = (right - 1) // lr_dw + 1

    window = src_lr[:, row_lo:row_hi, col_lo:col_hi]
    rows, cols = window.shape[1], window.shape[2]

    # [B, lh, lw, n, C] -> [B * n, C, lh, lw] so interpolate sees a plain BCHW batch.
    planes = window.permute(0, 3, 4, 1, 2).reshape(batch * heads, channels, rows, cols).contiguous()
    upsampled = F.interpolate(planes, scale_factor=(lr_dh, lr_dw), mode="nearest-exact")

    # The upsampled window starts on an LR-cell boundary; skip the rows/cols
    # that precede the requested HR start.
    skip_h = top - row_lo * lr_dh
    skip_w = left - col_lo * lr_dw
    return upsampled[:, :, skip_h:skip_h + (bottom - top),
                     skip_w:skip_w + (right - left)]
|
||
|
||
|
||
def na2d_pure(
    q: torch.Tensor,  # [B, H, W, n_heads, d_qk] at HR.
    k_lr: torch.Tensor,  # [B, h_lr, w_lr, n_heads, d_qk] at LR
    v_lr: torch.Tensor,  # [B, h_lr, w_lr, n_heads, d_v] at LR
    kernel_size: Tuple[int, int],  # (Kh, Kw) attention window; both entries must be positive and odd.
    dilation: Tuple[int, int],  # (Dh, Dw) stride within the unrolled K/V grid; also the LR→HR upsample factor (>= 1).
    scale: float,  # 1 / sqrt(d_qk) scaling for the Q·K scores.
    tile: int = 128,  # Spatial tile size (output positions per tile); falsy disables tiling.
    v_chunk: int = 64,  # Sub-divide d_v into chunks of this size when computing attn·V. None disables chunking.
    output: torch.Tensor = None,  # Pre-allocated [B, n_heads, d_v, H, W] buffer (may be on CPU).
) -> torch.Tensor:  # [B, n_heads, d_v, H, W] (caller views as BCHW).
    """Neighborhood attention in pure torch via F.unfold + per-tile slicing.

    K and V are passed at LR resolution and upsampled (nearest-exact) per-tile only
    for the slice the unfold needs. Avoids the [B, n*d, H, W] HR allocations for K
    (512 MB) and V (2 GB) at tex_1024 fp16. Spatial tiling bounds the per-tile
    F.unfold blob; `v_chunk` further slices d_v so attn·V is computed in C-sized
    chunks (attn is reused, computed once from Q/K).

    Results are written tile by tile into `output` when given (otherwise into a
    fresh buffer on q's device/dtype), and that buffer is returned.

    Raises:
        ValueError: if a kernel size is not a positive odd integer or a dilation
            is < 1. The padding arithmetic below only yields one window per
            output position for odd kernels; previously such inputs failed later
            with an opaque reshape / division error.

    NOTE(review): out-of-bounds neighbours are zero-padded, so they get a score
    of 0 (not -inf) and still take part in the softmax with V = 0. Confirm this
    border behaviour is intended relative to the reference implementation.
    """
    B, H, W, n, d_qk = q.shape
    d_v = v_lr.shape[-1]
    Kh, Kw = kernel_size
    Dh, Dw = dilation
    if Kh < 1 or Kw < 1 or Kh % 2 == 0 or Kw % 2 == 0:
        raise ValueError(f"kernel_size entries must be positive odd integers, got {(Kh, Kw)}")
    if Dh < 1 or Dw < 1:
        raise ValueError(f"dilation entries must be >= 1, got {(Dh, Dw)}")
    pad_h, pad_w = (Kh // 2) * Dh, (Kw // 2) * Dw
    KK = Kh * Kw  # neighbours per output position (81 for a 9x9 window); loop-invariant

    out = output if output is not None else torch.empty((B, n, d_v, H, W), device=q.device, dtype=q.dtype)

    th = min(tile, H) if tile else H
    tw = min(tile, W) if tile else W
    chunk = v_chunk if (v_chunk and v_chunk < d_v) else d_v

    for h0 in range(0, H, th):
        for w0 in range(0, W, tw):
            h1, w1 = min(h0 + th, H), min(w0 + tw, W)
            t_h, t_w = h1 - h0, w1 - w0

            # Padded HR region the unfold needs (kernel span = (K-1)*D + 1).
            # The in-bounds part is fetched from LR; the remainder is zero-padded.
            h_src_start = max(0, h0 - pad_h)
            h_src_end = min(H, h1 + pad_h)
            w_src_start = max(0, w0 - pad_w)
            w_src_end = min(W, w1 + pad_w)
            pad_top = max(0, pad_h - h0)
            pad_bot = max(0, (h1 + pad_h) - H)
            pad_lft = max(0, pad_w - w0)
            pad_rgt = max(0, (w1 + pad_w) - W)

            # Upsample only the tile region from k_lr / v_lr -> [B*n, d, h, w].
            k_tile = upsample_lr_slice(k_lr, Dh, Dw,
                                       (h_src_start, h_src_end),
                                       (w_src_start, w_src_end))
            v_tile = upsample_lr_slice(v_lr, Dh, Dw,
                                       (h_src_start, h_src_end),
                                       (w_src_start, w_src_end))
            if pad_top or pad_bot or pad_lft or pad_rgt:
                k_tile = F.pad(k_tile, [pad_lft, pad_rgt, pad_top, pad_bot])
                v_tile = F.pad(v_tile, [pad_lft, pad_rgt, pad_top, pad_bot])

            # Q·K → attention weights (small: KK values per output position).
            k_w = F.unfold(k_tile, kernel_size=(Kh, Kw), dilation=(Dh, Dw), padding=0)
            k_w = k_w.view(B, n, d_qk, KK, t_h * t_w).permute(0, 1, 4, 3, 2)  # [B, n, t, KK, d_qk]
            # q is [B, H, W, n, d_qk]; per-tile slice + permute -> [B, n, t_h*t_w, 1, d_qk].
            q_tile = q[:, h0:h1, w0:w1].permute(0, 3, 1, 2, 4).reshape(B, n, t_h * t_w, 1, d_qk)
            scores = torch.matmul(q_tile, k_w.transpose(-1, -2)) * scale
            attn = scores.softmax(dim=-1)
            # Drop the K-side temporaries before the V unfold allocates.
            del k_w, scores, q_tile, k_tile

            # attn · V, chunked over d_v.
            for c0 in range(0, d_v, chunk):
                c1 = min(c0 + chunk, d_v)
                v_w = F.unfold(v_tile[:, c0:c1], kernel_size=(Kh, Kw), dilation=(Dh, Dw), padding=0)  # [B*n, (c1-c0)*KK, t]
                v_w = v_w.view(B, n, c1 - c0, KK, t_h * t_w).permute(0, 1, 4, 3, 2)
                out_chunk = torch.matmul(attn, v_w).squeeze(-2)  # [B, n, t, c1-c0]
                out_chunk = out_chunk.view(B, n, t_h, t_w, c1 - c0).permute(0, 1, 4, 2, 3)
                # Slice assignment copies into `out`, which may live on another device.
                out[:, :, c0:c1, h0:h1, w0:w1] = out_chunk
                del v_w, out_chunk
            del attn, v_tile

    return out  # [B, n, d_v, H, W] — sole caller (CrossAttention) views it as BCHW directly.
|
||
|
||
|
||
class CrossAttention(nn.Module):
    """Window-restricted cross-attention. No learnable parameters; the model's
    capacity lives entirely in the ImageEncoder convs.

    Queries live on the high-resolution grid, keys/values on a low-resolution
    grid; the integer ratio between the two grids is used as the dilation of
    the neighborhood window.
    """

    def __init__(self, dim: int, num_heads: int, kernel_size: Tuple[int, int] = (9, 9)):
        super().__init__()
        assert dim % num_heads == 0, "dim must be divisible by num_heads"
        self.num_heads = num_heads
        self.kernel_size = kernel_size
        # Standard 1/sqrt(d_head) scaling applied to the Q·K scores.
        self.scale = (dim // num_heads) ** -0.5

    @staticmethod
    def _split_heads_lr(x: torch.Tensor, num_heads: int) -> torch.Tensor:
        """[B, n*d, h, w] -> [B, h, w, n, d] at the input resolution (no upsample)."""
        B, C, H, W = x.shape
        return x.view(B, num_heads, C // num_heads, H, W).permute(0, 3, 4, 1, 2).contiguous()

    def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                output_device=None) -> torch.Tensor:
        """Attend from HR queries to the LR key/value neighborhood.

        Args:
            q: [B, C, hq, wq] queries on the high-resolution grid.
            k: [B, C, hk, wk] keys on the low-resolution grid.
            v: [B, C_v, hk, wk] values on the same grid as ``k``; cast to q's dtype.
            output_device: if given, the result buffer is allocated on this
                device and filled tile by tile.

        Returns:
            [B, C_v, hq, wq] tensor (on ``output_device`` when one is given).

        Raises:
            ValueError: if ``v`` is not on the same spatial grid as ``k``, or
                if the query grid is not a positive integer multiple of the
                key grid. The floor-divided dilation would otherwise index past
                the LR grid and fail with an opaque reshape error, or silently
                read only part of ``v``.
        """
        hq, wq = q.shape[-2:]
        hk, wk = k.shape[-2:]
        if tuple(v.shape[-2:]) != (hk, wk):
            raise ValueError(
                f"k and v must share spatial size, got k={(hk, wk)} and v={tuple(v.shape[-2:])}"
            )
        if min(hk, wk) < 1 or hq < hk or wq < wk or hq % hk or wq % wk:
            raise ValueError(
                f"query size {(hq, wq)} must be a positive integer multiple of key size {(hk, wk)}"
            )
        dilation = (hq // hk, wq // wk)
        B, C, _, _ = q.shape
        # [B, n*d, hq, wq] -> [B, hq, wq, n, d], the layout na2d_pure slices per tile.
        q = q.view(B, self.num_heads, C // self.num_heads, hq, wq).permute(0, 3, 4, 1, 2).contiguous()
        k_lr = self._split_heads_lr(k, self.num_heads).to(q.dtype)
        v_lr = self._split_heads_lr(v, self.num_heads).to(q.dtype)
        out_buf = None
        if output_device is not None:
            n = self.num_heads
            d_v = v.shape[1] // n
            out_buf = torch.empty(B, n, d_v, hq, wq, device=output_device, dtype=q.dtype)
        out = na2d_pure(q, k_lr, v_lr, self.kernel_size, dilation, self.scale, output=out_buf)
        # [B, n, d_v, hq, wq] -> [B, n*d_v, hq, wq]
        return out.view(B, -1, hq, wq)
|
||
|
||
|
||
# RoPE positional embedding
|
||
|
||
def rope_rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Swap the two halves of the last dimension, negating the second: [a, b] -> [-b, a]."""
    first, second = torch.chunk(x, 2, dim=-1)
    return torch.cat((second.neg(), first), dim=-1)
|
||
|
||
|
||
class RoPE(nn.Module):
    """2-D rotary position embedding applied per head to a BCHW feature map.

    The rotation frequencies come from the ``periods`` buffer, which is left
    uninitialised here and is expected to be filled from the checkpoint.
    """

    def __init__(self, embed_dim: int, num_heads: int, base: float = 100.0):
        super().__init__()
        # Each head's channels are split into 4 groups: (h, w) angles, tiled twice.
        assert embed_dim % (4 * num_heads) == 0
        self.num_heads = num_heads
        self.D_head = embed_dim // num_heads
        self.base = base
        self.register_buffer("periods", torch.empty(self.D_head // 4), persistent=True)  # loaded from the checkpoint
        self._cached_key = None
        self._cached_cos_sin = None

    def _cos_sin(self, H: int, W: int, dtype: torch.dtype):
        """Return the (cos, sin) tables of shape [H*W, D_head] for an H x W grid.

        The tables depend only on (H, W), the output dtype and the device the
        ``periods`` buffer lives on, so the most recent result is cached —
        saves the meshgrid / angle / cos / sin / tile / flatten on every
        forward. The device is part of the cache key: the cached tensors are
        plain attributes, so they do not follow the module through
        ``.to(device)`` and must be rebuilt after a move.

        NOTE: the cache does not track the *values* of ``periods``; it assumes
        they are fixed once loaded from the checkpoint.
        """
        device = self.periods.device
        key = (H, W, dtype, device)
        if self._cached_key == key and self._cached_cos_sin is not None:
            return self._cached_cos_sin
        # Pixel-centre coordinates, normalised to (0, 1) and then to (-1, 1).
        coords_h = torch.arange(0.5, H, device=device, dtype=torch.float32) / H
        coords_w = torch.arange(0.5, W, device=device, dtype=torch.float32) / W
        coords = torch.stack(torch.meshgrid(coords_h, coords_w, indexing="ij"), dim=-1)  # [H, W, 2]
        coords = coords.flatten(0, 1) * 2.0 - 1.0  # [HW, 2]
        angles = 2 * math.pi * coords[:, :, None] / self.periods.to(coords.dtype)[None, None, :]  # [HW, 2, D//4]
        angles = angles.flatten(1, 2).tile(2)  # [HW, D]
        cos = torch.cos(angles).to(dtype)
        sin = torch.sin(angles).to(dtype)
        self._cached_cos_sin = (cos, sin)
        self._cached_key = key
        return cos, sin

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Rotate ``x`` ([B, n*D_head, H, W]) by its position; the shape is preserved."""
        B, C, H, W = x.shape
        n = self.num_heads
        D = C // n
        # [B, n*D, H, W] -> [B, n, HW, D] so the tables broadcast over batch and heads.
        x = x.view(B, n, D, H, W).permute(0, 1, 3, 4, 2).reshape(B, n, H * W, D)
        cos, sin = self._cos_sin(H, W, x.dtype)
        x = (x * cos) + (rope_rotate_half(x) * sin)
        x = x.view(B, n, H, W, D).permute(0, 1, 4, 2, 3).reshape(B, n * D, H, W)
        return x
|
||
|
||
|
||
# Image encoder
|
||
|
||
class EncBlock(nn.Module):
    """Two (GroupNorm -> SiLU -> reflect-padded Conv2d) stages, no residual path."""

    def __init__(self, channels: int, kernel_size: int, num_groups: int = 8):
        super().__init__()
        # Both convs keep the channel count and, via kernel_size // 2 padding,
        # the spatial size for odd kernels.
        conv_kwargs = dict(kernel_size=kernel_size, padding=kernel_size // 2,
                           padding_mode="reflect", bias=True)
        self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=channels)
        self.conv1 = nn.Conv2d(channels, channels, **conv_kwargs)
        self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=channels)
        self.conv2 = nn.Conv2d(channels, channels, **conv_kwargs)
        self.activation_fn = nn.SiLU()

    def forward(self, x):
        stages = ((self.norm1, self.conv1), (self.norm2, self.conv2))
        for norm, conv in stages:
            x = conv(self.activation_fn(norm(x)))
        # Deliberately returned as-is: there is no skip connection.
        return x
|
||
|
||
|
||
def _encoder(in_dim: int, hidden_dim: int, kernel_size: int = 1, ks_res: int = 1, num_layers: int = 2) -> nn.Sequential:
    """Build a conv stack: one reflect-padded stem conv followed by ``num_layers`` EncBlocks.

    The stem maps ``in_dim`` -> ``hidden_dim`` with ``kernel_size``; every
    EncBlock keeps ``hidden_dim`` channels and uses ``ks_res`` as its kernel.
    """
    stem = nn.Conv2d(in_dim, hidden_dim, kernel_size=kernel_size, padding=kernel_size // 2,
                     padding_mode="reflect", bias=True)
    layers = [stem]
    for _ in range(num_layers):
        layers.append(EncBlock(hidden_dim, kernel_size=ks_res))
    return nn.Sequential(*layers)
|
||
|
||
|
||
class ImageEncoder(nn.Module):
    """Encode an image with two parallel conv stacks (1x1 and 3x3 kernels).

    Each stack produces ``out_channels // 2`` channels; the two outputs are
    concatenated, average-pooled to the requested size and RoPE-embedded.
    """

    def __init__(self, in_channels: int = 3, out_channels: int = 256,
                 heads_rope: int = 4, rope_base: float = 100.0, img_layers: int = 2):
        super().__init__()
        branch_channels = out_channels // 2
        self.encoder = _encoder(in_channels, branch_channels, kernel_size=1, ks_res=1, num_layers=img_layers)
        self.sem_encoder = _encoder(in_channels, branch_channels, kernel_size=3, ks_res=3, num_layers=img_layers)
        self.rope = RoPE(embed_dim=out_channels, num_heads=heads_rope, base=rope_base)

    def forward(self, x: torch.Tensor, output_size: Tuple[int, int]) -> torch.Tensor:
        """Return RoPE-embedded features of spatial size ``output_size`` for image ``x``."""
        target_h, target_w = output_size
        in_h, in_w = x.shape[-2], x.shape[-1]
        max_h, max_w = 4 * target_h, 4 * target_w
        # Cap the conv-stack input at 4x the target resolution per axis;
        # anything larger is bilinearly shrunk first.
        if in_h > max_h or in_w > max_w:
            x = F.interpolate(x, size=(min(in_h, max_h), min(in_w, max_w)),
                              mode="bilinear", align_corners=False)
        pointwise_feats = self.encoder(x)
        context_feats = self.sem_encoder(x)
        merged = torch.cat([pointwise_feats, context_feats], dim=1)
        pooled = F.adaptive_avg_pool2d(merged, output_size=output_size)
        return self.rope(pooled)
|
||
|
||
|
||
class NAF(nn.Module):
    """NAF feature upsampler: image-guided neighborhood attention over low-res features."""

    def __init__(
        self,
        dim: int = 256,  # channel width of the ImageEncoder output (queries/keys)
        heads_attn: int = 4,  # number of heads in the windowed cross-attention
        heads_rope: int = 4,  # number of heads used by the RoPE position encoding
        kernel_size: int = 9,  # side length of the square attention window
        rope_base: float = 100.0,  # base forwarded to RoPE
        img_layers: int = 2,  # EncBlocks per conv stack in the ImageEncoder
    ):
        super().__init__()
        self.image_encoder = ImageEncoder(
            in_channels=3,
            out_channels=dim,
            heads_rope=heads_rope,
            rope_base=rope_base,
            img_layers=img_layers,
        )
        window = (kernel_size, kernel_size)
        self.upsampler = CrossAttention(dim=dim, num_heads=heads_attn, kernel_size=window)

    def forward(
        self,
        image: torch.Tensor,  # [B, 3, H_img, W_img] in [0, 1].
        features: torch.Tensor,  # [B, C, H_feat, W_feat] low-resolution features (any C).
        output_size: Tuple[int, int],  # (H_out, W_out) spatial size of the result.
        output_device=None,  # forwarded to the cross-attention for its output buffer
    ) -> torch.Tensor:  # [B, C, H_out, W_out] upsampled features.
        """Upsample ``features`` to ``output_size`` using ``image`` as guidance.

        Queries are the image encoding at the output resolution; keys are the
        same encoding average-pooled down to the feature resolution; the
        features themselves act as values.
        """
        queries = self.image_encoder(image, output_size=output_size)
        feat_hw = features.shape[-2:]
        keys = F.adaptive_avg_pool2d(queries, output_size=feat_hw)
        return self.upsampler(queries, keys, features, output_device=output_device)
|