Optimize NAF memory use

This commit is contained in:
kijai 2026-06-30 21:47:43 +03:00
parent 2333d6bc40
commit 3aae4bf741
2 changed files with 44 additions and 42 deletions

View File

@ -50,8 +50,9 @@ def na2d_pure(
dilation: Tuple[int, int], # (Dh, Dw) stride within the unrolled K/V grid; also the LR→HR upsample factor.
scale: float, # 1 / sqrt(d_qk) scaling for the Q·K scores.
tile: int = 128, # Spatial tile size (output positions per tile)
v_chunk: int = 64 # Sub-divide d_v into chunks of this size when computing attn·V. None disables chunking.
) -> torch.Tensor: # [B, H, W, n_heads, d_v] attended features.
v_chunk: int = 64, # Sub-divide d_v into chunks of this size when computing attn·V. None disables chunking.
output: torch.Tensor = None, # Pre-allocated [B, n_heads, d_v, H, W] buffer (may be on CPU).
) -> torch.Tensor: # [B, n_heads, d_v, H, W] (caller views as BCHW).
"""Neighborhood attention in pure torch via F.unfold + per-tile slicing.
K and V are passed at LR resolution and upsampled (nearest-exact) per-tile only
@ -67,8 +68,7 @@ def na2d_pure(
Dh, Dw = dilation
pad_h, pad_w = (Kh // 2) * Dh, (Kw // 2) * Dw
q_ = q.permute(0, 3, 4, 1, 2).contiguous() # [B, n, d_qk, H, W]
out = torch.empty((B, n, d_v, H, W), device=q.device, dtype=q.dtype)
out = output if output is not None else torch.empty((B, n, d_v, H, W), device=q.device, dtype=q.dtype)
th = min(tile, H) if tile else H
tw = min(tile, W) if tile else W
@ -104,7 +104,8 @@ def na2d_pure(
KK = Kh * Kw
k_w = F.unfold(k_tile, kernel_size=(Kh, Kw), dilation=(Dh, Dw), padding=0)
k_w = k_w.view(B, n, d_qk, KK, t_h * t_w).permute(0, 1, 4, 3, 2) # [B, n, t, KK, d_qk]
q_tile = q_[:, :, :, h0:h1, w0:w1].permute(0, 1, 3, 4, 2).reshape(B, n, t_h * t_w, 1, d_qk)
# q is [B, H, W, n, d_qk]; per-tile slice + permute -> [B, n, t_h*t_w, 1, d_qk].
q_tile = q[:, h0:h1, w0:w1].permute(0, 3, 1, 2, 4).reshape(B, n, t_h * t_w, 1, d_qk)
scores = torch.matmul(q_tile, k_w.transpose(-1, -2)) * scale
attn = scores.softmax(dim=-1)
del k_w, scores, q_tile, k_tile
@ -120,7 +121,7 @@ def na2d_pure(
del v_w, out_chunk
del attn, v_tile
return out.permute(0, 3, 4, 1, 2).contiguous() # [B, H, W, n, d_v]
return out # [B, n, d_v, H, W] — sole caller (CrossAttention) views it as BCHW directly.
class CrossAttention(nn.Module):
@ -140,9 +141,8 @@ class CrossAttention(nn.Module):
B, C, H, W = x.shape
return x.view(B, num_heads, C // num_heads, H, W).permute(0, 3, 4, 1, 2).contiguous()
def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
# q is [B, C, Hq, Wq] at HR; k and v are at LR (Hk, Wk). We KEEP k and v at LR:
# na2d_pure upsamples only the tile slice it needs.
def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
output_device=None) -> torch.Tensor:
hq, wq = q.shape[-2:]
hk, wk = k.shape[-2:]
dilation = (hq // hk, wq // wk)
@ -150,9 +150,13 @@ class CrossAttention(nn.Module):
q = q.view(B, self.num_heads, C // self.num_heads, hq, wq).permute(0, 3, 4, 1, 2).contiguous()
k_lr = self._split_heads_lr(k, self.num_heads).to(q.dtype)
v_lr = self._split_heads_lr(v, self.num_heads).to(q.dtype)
out = na2d_pure(q, k_lr, v_lr, self.kernel_size, dilation, self.scale)
# [B, H, W, n, d] -> [B, n*d, H, W]
return out.permute(0, 3, 4, 1, 2).contiguous().view(B, -1, hq, wq)
out_buf = None
if output_device is not None:
n = self.num_heads
d_v = v.shape[1] // n
out_buf = torch.empty(B, n, d_v, hq, wq, device=output_device, dtype=q.dtype)
out = na2d_pure(q, k_lr, v_lr, self.kernel_size, dilation, self.scale, output=out_buf)
return out.view(B, -1, hq, wq)
# RoPE positional embedding
@ -260,8 +264,6 @@ class ImageEncoder(nn.Module):
return x
# Top-level NAF model.
class NAF(nn.Module):
"""NAF feature upsampler."""
@ -281,21 +283,10 @@ class NAF(nn.Module):
self,
image: torch.Tensor, # [B, 3, H_img, W_img] in [0, 1].
features: torch.Tensor, # [B, C, H_feat, W_feat] low-resolution features (any C).
output_size: Tuple[int, int] # (H_out, W_out) target spatial resolution for the upsampled features.
output_size: Tuple[int, int], # (H_out, W_out) target spatial resolution for the upsampled features.
output_device=None,
) -> torch.Tensor: # [B, C, H_out, W_out] upsampled features.
"""Upsample low-res feature map to output_size, guided by the image."""
q = self.image_encoder(image, output_size=output_size)
k = F.adaptive_avg_pool2d(q, output_size=features.shape[-2:])
return self.upsampler(q, k, features)
def build_naf_from_state_dict(state_dict: dict) -> NAF:
"""Instantiate NAF with the default hyperparams and load the given state_dict.
The published NAF release uses the default constructor (dim=256, heads_attn=4,
heads_rope=4, kernel_size=9, rope_base=100, img_layers=2)."""
model = NAF()
missing, unexpected = model.load_state_dict(state_dict, strict=False)
if unexpected:
raise ValueError(f"Unexpected keys in NAF state_dict: {sorted(unexpected)[:8]}...")
return model
return self.upsampler(q, k, features, output_device=output_device)

View File

@ -2,12 +2,13 @@ from typing_extensions import override
from comfy_api.latest import ComfyExtension, IO, Types, UI, io
from comfy.ldm.trellis2.vae import SparseTensor
from comfy.ldm.trellis2.model import build_proj_transform_matrix, _project_points_to_image, compute_stage_proj_feats
from comfy.ldm.trellis2.naf.model import build_naf_from_state_dict
from comfy.ldm.trellis2.naf.model import NAF
from comfy_extras.nodes_mesh_postprocess import pack_variable_mesh_batch
from server import PromptServer
import comfy.latent_formats
import comfy.model_management
import comfy.model_patcher
import comfy.utils
import folder_paths
from PIL import Image
@ -923,15 +924,19 @@ class Pixal3DConditioning(IO.ComfyNode):
def _naf_hr(lr_feat, composites, image_size, naf_target):
if naf_model is None or naf_target is None:
return None
target_dtype = lr_feat.dtype
if next(naf_model.parameters()).dtype != target_dtype:
naf_model.to(dtype=target_dtype)
imgs = torch.cat([
comfy.utils.common_upscale(c, image_size, image_size, "lanczos", "disabled")
for c in composites
], dim=0).to(torch_device).to(target_dtype)
hr = naf_model(imgs, lr_feat.to(torch_device).to(target_dtype), naf_target)
return hr.to(device)
comfy.model_management.load_model_gpu(naf_model)
inner = naf_model.model
target_dtype = comfy.model_management.text_encoder_dtype(torch_device)
if next(inner.parameters()).dtype != target_dtype:
inner.to(dtype=target_dtype)
hrs = []
for i, c in enumerate(composites):
img_i = comfy.utils.common_upscale(c, image_size, image_size, "lanczos", "disabled")\
.to(torch_device).to(target_dtype)
lr_i = lr_feat[i:i + 1].to(torch_device).to(target_dtype)
hr_i = inner(img_i, lr_i, naf_target, output_device=device)
hrs.append(hr_i)
return torch.cat(hrs, dim=0)
hr_shape_512 = _naf_hr(fm_512_dino, composite_list, 512, (512, 512))
hr_shape_1024 = _naf_hr(fm_1024_dino, composite_list, 1024, (512, 512))
@ -1148,10 +1153,16 @@ class LoadNAFModel(IO.ComfyNode):
def execute(cls, naf_name) -> IO.NodeOutput:
path = folder_paths.get_full_path_or_raise("upscale_models", naf_name)
sd = comfy.utils.load_torch_file(path, safe_load=True)
model = build_naf_from_state_dict(sd)
device = comfy.model_management.get_torch_device()
model = model.to(device).eval()
return IO.NodeOutput(model)
model = NAF().eval()
_, unexpected = model.load_state_dict(sd, strict=False)
if unexpected:
raise ValueError(f"Unexpected keys in NAF state_dict: {sorted(unexpected)[:8]}...")
patcher = comfy.model_patcher.CoreModelPatcher(
model,
load_device=comfy.model_management.get_torch_device(),
offload_device=comfy.model_management.unet_offload_device(),
)
return IO.NodeOutput(patcher)
class GetMeshInfo(IO.ComfyNode):