ComfyUI/comfy/image_encoders/naf.py
2026-07-01 01:36:14 +03:00

293 lines
14 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""NAF (Neighborhood Attention Filtering) feature upsampler.
Vendored from valeoai/NAF (Apache-2.0):
https://github.com/valeoai/NAF — src/model/naf.py + src/layers/{convolutions,attentions,rope}.py
Used by Pixal3D's shape/texture conditioning to produce
the 2x-upsampled half of the 2048-channel proj feature map.
"""
import math
from typing import Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
# Pure-torch neighborhood attention (replaces natten.na2d / na2d_qk + na2d_av).
def upsample_lr_slice(src_lr: torch.Tensor, lr_dh: int, lr_dw: int,
                      hr_h_range: Tuple[int, int], hr_w_range: Tuple[int, int]) -> torch.Tensor:
    """Nearest-exact upsample of only the LR cells behind one HR window.

    ``src_lr`` is laid out as [B, h_lr, w_lr, n, C]. The LR rows/columns that
    back the half-open HR window ``hr_h_range`` x ``hr_w_range`` are sliced,
    rearranged to [B*n, C, h, w], upsampled by (``lr_dh``, ``lr_dw``) and
    cropped to the requested window.

    Returns a [B*n, C, hr_h_end - hr_h_start, hr_w_end - hr_w_start] tensor.
    Nothing is padded here: if the window reaches past the LR grid the result
    is simply smaller than requested.
    """
    batch, heads, channels = src_lr.shape[0], src_lr.shape[-2], src_lr.shape[-1]
    (top, bottom), (left, right) = hr_h_range, hr_w_range

    # With an integer factor D, nearest-exact sends HR index p to LR index
    # p // D, so HR span [start, end) is backed by LR span
    # [start // D, (end - 1) // D + 1).
    row_lo, row_hi = top // lr_dh, (bottom - 1) // lr_dh + 1
    col_lo, col_hi = left // lr_dw, (right - 1) // lr_dw + 1

    window = src_lr[:, row_lo:row_hi, col_lo:col_hi]
    rows, cols = window.shape[1:3]
    # Fold heads into the batch axis so F.interpolate sees a plain 4-D tensor.
    stacked = window.permute(0, 3, 4, 1, 2).reshape(batch * heads, channels, rows, cols).contiguous()
    upsampled = F.interpolate(stacked, scale_factor=(lr_dh, lr_dw), mode="nearest-exact")

    # The upsampled patch starts on an LR cell boundary; drop the HR rows/cols
    # that precede the requested window inside that first cell.
    skip_rows = top % lr_dh
    skip_cols = left % lr_dw
    return upsampled[:, :,
                     skip_rows:skip_rows + (bottom - top),
                     skip_cols:skip_cols + (right - left)]
def na2d_pure(
    q: torch.Tensor,               # [B, H, W, n_heads, d_qk] at HR.
    k_lr: torch.Tensor,            # [B, h_lr, w_lr, n_heads, d_qk] at LR
    v_lr: torch.Tensor,            # [B, h_lr, w_lr, n_heads, d_v] at LR
    kernel_size: Tuple[int, int],  # (Kh, Kw) attention window; both must be odd.
    dilation: Tuple[int, int],     # (Dh, Dw) stride within the unrolled K/V grid; also the LR→HR upsample factor.
    scale: float,                  # 1 / sqrt(d_qk) scaling for the Q·K scores.
    tile: int = 128,               # Spatial tile size (output positions per tile); falsy = one tile.
    v_chunk: int = 64,             # Sub-divide d_v into chunks of this size when computing attn·V. None disables chunking.
    output: torch.Tensor = None,   # Pre-allocated [B, n_heads, d_v, H, W] buffer (may be on CPU).
) -> torch.Tensor:                 # [B, n_heads, d_v, H, W] (caller views as BCHW).
    """Neighborhood attention in pure torch via F.unfold + per-tile slicing.

    K and V are passed at LR resolution and upsampled (nearest-exact) per-tile only
    for the slice the unfold needs. Avoids the [B, n*d, H, W] HR allocations for K
    (512 MB) and V (2 GB) at tex_1024 fp16. Spatial tiling bounds the per-tile
    F.unfold blob; `v_chunk` further slices d_v so attn·V is computed in C-sized
    chunks (attn is reused, computed once from Q/K).

    NOTE: window taps that fall outside the HR grid are zero-padded. Their
    score is therefore 0 (not -inf), so they still receive softmax weight while
    contributing a zero value.

    Returns:
        ``output`` if one was given, otherwise a new tensor on ``q``'s device
        and dtype, shaped [B, n_heads, d_v, H, W].

    Raises:
        ValueError: if a kernel dimension is not a positive odd number, a
            dilation is < 1, or the LR grid times the dilation does not cover
            the HR grid. These inputs previously failed later with an opaque
            shape / division error.
    """
    B, H, W, n, d_qk = q.shape
    d_v = v_lr.shape[-1]
    Kh, Kw = kernel_size
    Dh, Dw = dilation

    # The padding below is symmetric ((K // 2) * D per side), which only yields
    # exactly one window per output position when K is odd.
    if Kh < 1 or Kw < 1 or Kh % 2 == 0 or Kw % 2 == 0:
        raise ValueError(f"kernel_size must be positive and odd, got {tuple(kernel_size)}")
    if Dh < 1 or Dw < 1:
        raise ValueError(f"dilation must be >= 1, got {tuple(dilation)}")
    for label, lr in (("k_lr", k_lr), ("v_lr", v_lr)):
        if H > lr.shape[1] * Dh or W > lr.shape[2] * Dw:
            raise ValueError(
                f"{label} grid {tuple(lr.shape[1:3])} upsampled by {tuple(dilation)} "
                f"does not cover the query grid {(H, W)}"
            )

    pad_h, pad_w = (Kh // 2) * Dh, (Kw // 2) * Dw
    KK = Kh * Kw  # taps per window; loop-invariant
    out = output if output is not None else torch.empty((B, n, d_v, H, W), device=q.device, dtype=q.dtype)
    th = min(tile, H) if tile else H
    tw = min(tile, W) if tile else W
    chunk = v_chunk if (v_chunk and v_chunk < d_v) else d_v

    for h0 in range(0, H, th):
        h1 = min(h0 + th, H)
        t_h = h1 - h0
        # Padded HR rows the unfold needs (kernel span = (K-1)*D + 1), clamped
        # to the grid; whatever was clamped away is re-added as zero padding.
        h_src_start = max(0, h0 - pad_h)
        h_src_end = min(H, h1 + pad_h)
        pad_top = max(0, pad_h - h0)
        pad_bot = max(0, (h1 + pad_h) - H)

        for w0 in range(0, W, tw):
            w1 = min(w0 + tw, W)
            t_w = w1 - w0
            w_src_start = max(0, w0 - pad_w)
            w_src_end = min(W, w1 + pad_w)
            pad_lft = max(0, pad_w - w0)
            pad_rgt = max(0, (w1 + pad_w) - W)

            # Upsample only the tile region from k_lr / v_lr.
            k_tile = upsample_lr_slice(k_lr, Dh, Dw,
                                       (h_src_start, h_src_end),
                                       (w_src_start, w_src_end))
            v_tile = upsample_lr_slice(v_lr, Dh, Dw,
                                       (h_src_start, h_src_end),
                                       (w_src_start, w_src_end))
            if pad_top or pad_bot or pad_lft or pad_rgt:
                k_tile = F.pad(k_tile, [pad_lft, pad_rgt, pad_top, pad_bot])
                v_tile = F.pad(v_tile, [pad_lft, pad_rgt, pad_top, pad_bot])

            # Q·K → attention weights (small: KK entries per output position).
            k_w = F.unfold(k_tile, kernel_size=(Kh, Kw), dilation=(Dh, Dw), padding=0)
            k_w = k_w.view(B, n, d_qk, KK, t_h * t_w).permute(0, 1, 4, 3, 2)  # [B, n, t, KK, d_qk]
            # q is [B, H, W, n, d_qk]; per-tile slice + permute -> [B, n, t_h*t_w, 1, d_qk].
            q_tile = q[:, h0:h1, w0:w1].permute(0, 3, 1, 2, 4).reshape(B, n, t_h * t_w, 1, d_qk)
            scores = torch.matmul(q_tile, k_w.transpose(-1, -2)) * scale
            attn = scores.softmax(dim=-1)  # [B, n, t, 1, KK]
            # Drop the K-side temporaries before the V unfold allocates.
            del k_w, scores, q_tile, k_tile

            # attn · V, chunked over d_v.
            for c0 in range(0, d_v, chunk):
                c1 = min(c0 + chunk, d_v)
                v_w = F.unfold(v_tile[:, c0:c1], kernel_size=(Kh, Kw),
                               dilation=(Dh, Dw), padding=0)  # [B*n, (c1-c0)*KK, t]
                v_w = v_w.view(B, n, c1 - c0, KK, t_h * t_w).permute(0, 1, 4, 3, 2)
                out_chunk = torch.matmul(attn, v_w).squeeze(-2)  # [B, n, t, c1-c0]
                out_chunk = out_chunk.view(B, n, t_h, t_w, c1 - c0).permute(0, 1, 4, 2, 3)
                # Assignment also performs the device copy when `output` lives elsewhere.
                out[:, :, c0:c1, h0:h1, w0:w1] = out_chunk
                del v_w, out_chunk
            del attn, v_tile
    return out  # [B, n, d_v, H, W]; CrossAttention views it as [B, n*d_v, H, W].
class CrossAttention(nn.Module):
    """Windowed cross-attention from HR queries to LR keys/values.

    The module owns no learnable parameters; it only reshapes its inputs into
    per-head layout and delegates to ``na2d_pure``.
    """

    def __init__(self, dim: int, num_heads: int, kernel_size: Tuple[int, int] = (9, 9)):
        super().__init__()
        assert dim % num_heads == 0, "dim must be divisible by num_heads"
        self.num_heads = num_heads
        self.kernel_size = kernel_size
        # 1 / sqrt(head_dim) score scaling.
        self.scale = (dim // num_heads) ** -0.5

    @staticmethod
    def _split_heads_lr(x: torch.Tensor, num_heads: int) -> torch.Tensor:
        """Rearrange [B, n*d, h, w] into contiguous [B, h, w, n, d] (resolution unchanged)."""
        batch, channels, height, width = x.shape
        per_head = x.view(batch, num_heads, channels // num_heads, height, width)
        return per_head.permute(0, 3, 4, 1, 2).contiguous()

    def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                output_device=None) -> torch.Tensor:
        """Attend from ``q`` ([B, C, hq, wq]) to ``k``/``v`` given at LR.

        The per-axis dilation (and LR→HR factor) is the integer ratio between
        the query and key spatial sizes. When ``output_device`` is given, the
        result buffer is allocated there and filled tile by tile.

        Returns a [B, n_heads * d_v, hq, wq] tensor.
        """
        hq, wq = q.shape[-2:]
        hk, wk = k.shape[-2:]
        batch = q.shape[0]
        heads = self.num_heads

        # All three inputs go through the same head split; K/V follow Q's dtype.
        q_heads = self._split_heads_lr(q, heads)
        k_lr = self._split_heads_lr(k, heads).to(q_heads.dtype)
        v_lr = self._split_heads_lr(v, heads).to(q_heads.dtype)

        if output_device is None:
            out_buf = None
        else:
            out_buf = torch.empty(batch, heads, v.shape[1] // heads, hq, wq,
                                  device=output_device, dtype=q_heads.dtype)

        out = na2d_pure(q_heads, k_lr, v_lr, self.kernel_size,
                        (hq // hk, wq // wk), self.scale, output=out_buf)
        return out.view(batch, -1, hq, wq)
# RoPE positional embedding
def rope_rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Split the last axis into two halves (a, b) and return (-b, a)."""
    front, back = torch.chunk(x, 2, dim=-1)
    return torch.cat((back.neg(), front), dim=-1)
class RoPE(nn.Module):
    """2-D rotary position embedding applied per head to a [B, n*D_head, H, W] map.

    The rotation frequencies come from the ``periods`` buffer (D_head // 4
    entries), which is left uninitialised here and is expected to be filled
    from the checkpoint.
    """

    def __init__(self, embed_dim: int, num_heads: int, base: float = 100.0):
        super().__init__()
        assert embed_dim % (4 * num_heads) == 0
        self.num_heads = num_heads
        self.D_head = embed_dim // num_heads
        self.base = base
        self.register_buffer("periods", torch.empty(self.D_head // 4), persistent=True)  # loaded from the checkpoint
        # Single-entry cache for the last (H, W, dtype, device) cos/sin tables.
        self._cached_key = None
        self._cached_cos_sin = None

    def _cos_sin(self, H: int, W: int, dtype: torch.dtype):
        """Return the (cos, sin) tables, each [H*W, D_head], for an HxW grid.

        The tables depend only on (H, W), the output dtype, the device of
        ``periods`` and the period values themselves, so the most recent pair
        is cached. The device is part of the cache key: the cached tensors are
        plain attributes, so ``Module.to()`` does not move them, and without
        the device in the key a module relocated after its first forward would
        keep returning tables on the old device.

        Assumes ``periods`` is not modified in place after the first call
        (fixed once loaded from the checkpoint).
        """
        device = self.periods.device
        key = (H, W, dtype, device)
        if self._cached_key == key and self._cached_cos_sin is not None:
            return self._cached_cos_sin

        # Pixel-centre coordinates, normalised to (-1, 1) along each axis.
        coords_h = torch.arange(0.5, H, device=device, dtype=torch.float32) / H
        coords_w = torch.arange(0.5, W, device=device, dtype=torch.float32) / W
        coords = torch.stack(torch.meshgrid(coords_h, coords_w, indexing="ij"), dim=-1)  # [H, W, 2]
        coords = coords.flatten(0, 1) * 2.0 - 1.0  # [HW, 2]
        angles = 2 * math.pi * coords[:, :, None] / self.periods.to(coords.dtype)[None, None, :]  # [HW, 2, D//4]
        # (row, col) angles side by side, then repeated so both halves used by
        # rope_rotate_half share the same angle.
        angles = angles.flatten(1, 2).tile(2)  # [HW, D]
        cos = torch.cos(angles).to(dtype)
        sin = torch.sin(angles).to(dtype)

        self._cached_cos_sin = (cos, sin)
        self._cached_key = key
        return cos, sin

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Rotate ``x`` ([B, n*D_head, H, W]) by its position; shape is preserved."""
        B, C, H, W = x.shape
        n = self.num_heads
        D = C // n
        # Channels-last per head so cos/sin ([HW, D]) broadcast over B and n.
        x = x.view(B, n, D, H, W).permute(0, 1, 3, 4, 2).reshape(B, n, H * W, D)
        cos, sin = self._cos_sin(H, W, x.dtype)
        x = (x * cos) + (rope_rotate_half(x) * sin)
        x = x.view(B, n, H, W, D).permute(0, 1, 4, 2, 3).reshape(B, n * D, H, W)
        return x
# Image encoder
class EncBlock(nn.Module):
    """Two (GroupNorm -> SiLU -> reflect-padded Conv2d) stages, no residual path.

    Channel count and spatial size are preserved for odd ``kernel_size``.
    """

    def __init__(self, channels: int, kernel_size: int, num_groups: int = 8):
        super().__init__()
        conv_kwargs = dict(kernel_size=kernel_size, padding=kernel_size // 2,
                           padding_mode="reflect", bias=True)
        # Attribute names and creation order define the state-dict keys.
        self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=channels)
        self.conv1 = nn.Conv2d(channels, channels, **conv_kwargs)
        self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=channels)
        self.conv2 = nn.Conv2d(channels, channels, **conv_kwargs)
        self.activation_fn = nn.SiLU()

    def forward(self, x):
        """Run both norm/activation/conv stages in sequence."""
        stages = ((self.norm1, self.conv1), (self.norm2, self.conv2))
        for norm, conv in stages:
            x = conv(self.activation_fn(norm(x)))
        return x  # no skip connection
def _encoder(in_dim: int, hidden_dim: int, kernel_size: int = 1, ks_res: int = 1, num_layers: int = 2) -> nn.Sequential:
return nn.Sequential(
nn.Conv2d(in_dim, hidden_dim, kernel_size=kernel_size, padding=kernel_size // 2, padding_mode="reflect", bias=True),
*[EncBlock(hidden_dim, kernel_size=ks_res) for _ in range(num_layers)],
)
class ImageEncoder(nn.Module):
    """Guidance-image encoder.

    Two parallel conv stacks (1x1 kernels and 3x3 kernels) each emit
    ``out_channels // 2`` channels; their outputs are concatenated,
    average-pooled to the requested size and position-encoded with RoPE.
    """

    def __init__(self, in_channels: int = 3, out_channels: int = 256,
                 heads_rope: int = 4, rope_base: float = 100.0, img_layers: int = 2):
        super().__init__()
        branch_dim = out_channels // 2
        self.encoder = _encoder(in_channels, branch_dim, kernel_size=1, ks_res=1, num_layers=img_layers)
        self.sem_encoder = _encoder(in_channels, branch_dim, kernel_size=3, ks_res=3, num_layers=img_layers)
        self.rope = RoPE(embed_dim=out_channels, num_heads=heads_rope, base=rope_base)

    def forward(self, x: torch.Tensor, output_size: Tuple[int, int]) -> torch.Tensor:
        """Encode ``x`` and return features at ``output_size`` (H_out, W_out)."""
        target_h, target_w = output_size
        in_h, in_w = x.shape[-2:]
        limit_h, limit_w = 4 * target_h, 4 * target_w

        # Cap the conv-stack input at 4x the target per axis; an axis already
        # within the cap keeps its size.
        if in_h > limit_h or in_w > limit_w:
            capped = (min(in_h, limit_h), min(in_w, limit_w))
            x = F.interpolate(x, size=capped, mode="bilinear", align_corners=False)

        both = torch.cat((self.encoder(x), self.sem_encoder(x)), dim=1)
        pooled = F.adaptive_avg_pool2d(both, output_size=output_size)
        return self.rope(pooled)
class NAF(nn.Module):
    """NAF feature upsampler: image-guided windowed attention over LR features.

    Args:
        dim: internal channel dimension of the ImageEncoder.
        heads_attn: attention heads in the windowed cross-attention.
        heads_rope: heads for the RoPE position encoding.
        kernel_size: side of the square neighborhood-attention window.
        rope_base: base for the RoPE frequency periods.
        img_layers: number of EncBlocks in each conv stack.
    """

    def __init__(
        self,
        dim: int = 256,
        heads_attn: int = 4,
        heads_rope: int = 4,
        kernel_size: int = 9,
        rope_base: float = 100.0,
        img_layers: int = 2
    ):
        super().__init__()
        self.image_encoder = ImageEncoder(
            in_channels=3,
            out_channels=dim,
            heads_rope=heads_rope,
            rope_base=rope_base,
            img_layers=img_layers,
        )
        self.upsampler = CrossAttention(
            dim=dim,
            num_heads=heads_attn,
            kernel_size=(kernel_size, kernel_size),
        )

    def forward(
        self,
        image: torch.Tensor,
        features: torch.Tensor,
        output_size: Tuple[int, int],
        output_device=None,
    ) -> torch.Tensor:
        """Upsample a low-res feature map to ``output_size``, guided by the image.

        Args:
            image: [B, 3, H_img, W_img] guidance image in [0, 1].
            features: [B, C, H_feat, W_feat] low-resolution features (any C).
            output_size: (H_out, W_out) target spatial resolution.
            output_device: optional device on which the result buffer is allocated.

        Returns:
            [B, C, H_out, W_out] upsampled features.
        """
        # Queries: image encoding at the target resolution.
        queries = self.image_encoder(image, output_size=output_size)
        # Keys: the same encoding pooled down to the feature grid.
        feat_hw = features.shape[-2:]
        keys = F.adaptive_avg_pool2d(queries, output_size=feat_hw)
        # Values: the features themselves.
        return self.upsampler(queries, keys, features, output_device=output_device)