mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-07-03 13:19:23 +08:00
Optimize NAF memory use
This commit is contained in:
parent
2333d6bc40
commit
3aae4bf741
@ -50,8 +50,9 @@ def na2d_pure(
|
||||
dilation: Tuple[int, int], # (Dh, Dw) stride within the unrolled K/V grid; also the LR→HR upsample factor.
|
||||
scale: float, # 1 / sqrt(d_qk) scaling for the Q·K scores.
|
||||
tile: int = 128, # Spatial tile size (output positions per tile)
|
||||
v_chunk: int = 64 # Sub-divide d_v into chunks of this size when computing attn·V. None disables chunking.
|
||||
) -> torch.Tensor: # [B, H, W, n_heads, d_v] attended features.
|
||||
v_chunk: int = 64, # Sub-divide d_v into chunks of this size when computing attn·V. None disables chunking.
|
||||
output: torch.Tensor = None, # Pre-allocated [B, n_heads, d_v, H, W] buffer (may be on CPU).
|
||||
) -> torch.Tensor: # [B, n_heads, d_v, H, W] (caller views as BCHW).
|
||||
"""Neighborhood attention in pure torch via F.unfold + per-tile slicing.
|
||||
|
||||
K and V are passed at LR resolution and upsampled (nearest-exact) per-tile only
|
||||
@ -67,8 +68,7 @@ def na2d_pure(
|
||||
Dh, Dw = dilation
|
||||
pad_h, pad_w = (Kh // 2) * Dh, (Kw // 2) * Dw
|
||||
|
||||
q_ = q.permute(0, 3, 4, 1, 2).contiguous() # [B, n, d_qk, H, W]
|
||||
out = torch.empty((B, n, d_v, H, W), device=q.device, dtype=q.dtype)
|
||||
out = output if output is not None else torch.empty((B, n, d_v, H, W), device=q.device, dtype=q.dtype)
|
||||
|
||||
th = min(tile, H) if tile else H
|
||||
tw = min(tile, W) if tile else W
|
||||
@ -104,7 +104,8 @@ def na2d_pure(
|
||||
KK = Kh * Kw
|
||||
k_w = F.unfold(k_tile, kernel_size=(Kh, Kw), dilation=(Dh, Dw), padding=0)
|
||||
k_w = k_w.view(B, n, d_qk, KK, t_h * t_w).permute(0, 1, 4, 3, 2) # [B, n, t, KK, d_qk]
|
||||
q_tile = q_[:, :, :, h0:h1, w0:w1].permute(0, 1, 3, 4, 2).reshape(B, n, t_h * t_w, 1, d_qk)
|
||||
# q is [B, H, W, n, d_qk]; per-tile slice + permute -> [B, n, t_h*t_w, 1, d_qk].
|
||||
q_tile = q[:, h0:h1, w0:w1].permute(0, 3, 1, 2, 4).reshape(B, n, t_h * t_w, 1, d_qk)
|
||||
scores = torch.matmul(q_tile, k_w.transpose(-1, -2)) * scale
|
||||
attn = scores.softmax(dim=-1)
|
||||
del k_w, scores, q_tile, k_tile
|
||||
@ -120,7 +121,7 @@ def na2d_pure(
|
||||
del v_w, out_chunk
|
||||
del attn, v_tile
|
||||
|
||||
return out.permute(0, 3, 4, 1, 2).contiguous() # [B, H, W, n, d_v]
|
||||
return out # [B, n, d_v, H, W] — sole caller (CrossAttention) views it as BCHW directly.
|
||||
|
||||
|
||||
class CrossAttention(nn.Module):
|
||||
@ -140,9 +141,8 @@ class CrossAttention(nn.Module):
|
||||
B, C, H, W = x.shape
|
||||
return x.view(B, num_heads, C // num_heads, H, W).permute(0, 3, 4, 1, 2).contiguous()
|
||||
|
||||
def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
|
||||
# q is [B, C, Hq, Wq] at HR; k and v are at LR (Hk, Wk). We KEEP k and v at LR
|
||||
# na2d_pure upsamples only the tile slice it needs.
|
||||
def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
|
||||
output_device=None) -> torch.Tensor:
|
||||
hq, wq = q.shape[-2:]
|
||||
hk, wk = k.shape[-2:]
|
||||
dilation = (hq // hk, wq // wk)
|
||||
@ -150,9 +150,13 @@ class CrossAttention(nn.Module):
|
||||
q = q.view(B, self.num_heads, C // self.num_heads, hq, wq).permute(0, 3, 4, 1, 2).contiguous()
|
||||
k_lr = self._split_heads_lr(k, self.num_heads).to(q.dtype)
|
||||
v_lr = self._split_heads_lr(v, self.num_heads).to(q.dtype)
|
||||
out = na2d_pure(q, k_lr, v_lr, self.kernel_size, dilation, self.scale)
|
||||
# [B, H, W, n, d] -> [B, n*d, H, W]
|
||||
return out.permute(0, 3, 4, 1, 2).contiguous().view(B, -1, hq, wq)
|
||||
out_buf = None
|
||||
if output_device is not None:
|
||||
n = self.num_heads
|
||||
d_v = v.shape[1] // n
|
||||
out_buf = torch.empty(B, n, d_v, hq, wq, device=output_device, dtype=q.dtype)
|
||||
out = na2d_pure(q, k_lr, v_lr, self.kernel_size, dilation, self.scale, output=out_buf)
|
||||
return out.view(B, -1, hq, wq)
|
||||
|
||||
|
||||
# RoPE positional embedding
|
||||
@ -260,8 +264,6 @@ class ImageEncoder(nn.Module):
|
||||
return x
|
||||
|
||||
|
||||
# Top-level NAF model.
|
||||
|
||||
class NAF(nn.Module):
|
||||
"""NAF feature upsampler."""
|
||||
|
||||
@ -281,21 +283,10 @@ class NAF(nn.Module):
|
||||
self,
|
||||
image: torch.Tensor, # [B, 3, H_img, W_img] in [0, 1].
|
||||
features: torch.Tensor, # [B, C, H_feat, W_feat] low-resolution features (any C).
|
||||
output_size: Tuple[int, int] # (H_out, W_out) target spatial resolution for the upsampled features.
|
||||
output_size: Tuple[int, int], # (H_out, W_out) target spatial resolution for the upsampled features.
|
||||
output_device=None,
|
||||
) -> torch.Tensor: # [B, C, H_out, W_out] upsampled features.
|
||||
"""Upsample low-res feature map to output_size, guided by the image."""
|
||||
q = self.image_encoder(image, output_size=output_size)
|
||||
k = F.adaptive_avg_pool2d(q, output_size=features.shape[-2:])
|
||||
return self.upsampler(q, k, features)
|
||||
|
||||
|
||||
def build_naf_from_state_dict(state_dict: dict) -> NAF:
|
||||
"""Instantiate NAF with the default hyperparams and load the given state_dict.
|
||||
|
||||
The published NAF release uses the default constructor (dim=256, heads_attn=4,
|
||||
heads_rope=4, kernel_size=9, rope_base=100, img_layers=2)."""
|
||||
model = NAF()
|
||||
missing, unexpected = model.load_state_dict(state_dict, strict=False)
|
||||
if unexpected:
|
||||
raise ValueError(f"Unexpected keys in NAF state_dict: {sorted(unexpected)[:8]}...")
|
||||
return model
|
||||
return self.upsampler(q, k, features, output_device=output_device)
|
||||
|
||||
@ -2,12 +2,13 @@ from typing_extensions import override
|
||||
from comfy_api.latest import ComfyExtension, IO, Types, UI, io
|
||||
from comfy.ldm.trellis2.vae import SparseTensor
|
||||
from comfy.ldm.trellis2.model import build_proj_transform_matrix, _project_points_to_image, compute_stage_proj_feats
|
||||
from comfy.ldm.trellis2.naf.model import build_naf_from_state_dict
|
||||
from comfy.ldm.trellis2.naf.model import NAF
|
||||
|
||||
from comfy_extras.nodes_mesh_postprocess import pack_variable_mesh_batch
|
||||
from server import PromptServer
|
||||
import comfy.latent_formats
|
||||
import comfy.model_management
|
||||
import comfy.model_patcher
|
||||
import comfy.utils
|
||||
import folder_paths
|
||||
from PIL import Image
|
||||
@ -923,15 +924,19 @@ class Pixal3DConditioning(IO.ComfyNode):
|
||||
def _naf_hr(lr_feat, composites, image_size, naf_target):
|
||||
if naf_model is None or naf_target is None:
|
||||
return None
|
||||
target_dtype = lr_feat.dtype
|
||||
if next(naf_model.parameters()).dtype != target_dtype:
|
||||
naf_model.to(dtype=target_dtype)
|
||||
imgs = torch.cat([
|
||||
comfy.utils.common_upscale(c, image_size, image_size, "lanczos", "disabled")
|
||||
for c in composites
|
||||
], dim=0).to(torch_device).to(target_dtype)
|
||||
hr = naf_model(imgs, lr_feat.to(torch_device).to(target_dtype), naf_target)
|
||||
return hr.to(device)
|
||||
comfy.model_management.load_model_gpu(naf_model)
|
||||
inner = naf_model.model
|
||||
target_dtype = comfy.model_management.text_encoder_dtype(torch_device)
|
||||
if next(inner.parameters()).dtype != target_dtype:
|
||||
inner.to(dtype=target_dtype)
|
||||
hrs = []
|
||||
for i, c in enumerate(composites):
|
||||
img_i = comfy.utils.common_upscale(c, image_size, image_size, "lanczos", "disabled")\
|
||||
.to(torch_device).to(target_dtype)
|
||||
lr_i = lr_feat[i:i + 1].to(torch_device).to(target_dtype)
|
||||
hr_i = inner(img_i, lr_i, naf_target, output_device=device)
|
||||
hrs.append(hr_i)
|
||||
return torch.cat(hrs, dim=0)
|
||||
|
||||
hr_shape_512 = _naf_hr(fm_512_dino, composite_list, 512, (512, 512))
|
||||
hr_shape_1024 = _naf_hr(fm_1024_dino, composite_list, 1024, (512, 512))
|
||||
@ -1148,10 +1153,16 @@ class LoadNAFModel(IO.ComfyNode):
|
||||
def execute(cls, naf_name) -> IO.NodeOutput:
|
||||
path = folder_paths.get_full_path_or_raise("upscale_models", naf_name)
|
||||
sd = comfy.utils.load_torch_file(path, safe_load=True)
|
||||
model = build_naf_from_state_dict(sd)
|
||||
device = comfy.model_management.get_torch_device()
|
||||
model = model.to(device).eval()
|
||||
return IO.NodeOutput(model)
|
||||
model = NAF().eval()
|
||||
_, unexpected = model.load_state_dict(sd, strict=False)
|
||||
if unexpected:
|
||||
raise ValueError(f"Unexpected keys in NAF state_dict: {sorted(unexpected)[:8]}...")
|
||||
patcher = comfy.model_patcher.CoreModelPatcher(
|
||||
model,
|
||||
load_device=comfy.model_management.get_torch_device(),
|
||||
offload_device=comfy.model_management.unet_offload_device(),
|
||||
)
|
||||
return IO.NodeOutput(patcher)
|
||||
|
||||
|
||||
class GetMeshInfo(IO.ComfyNode):
|
||||
|
||||
Loading…
Reference in New Issue
Block a user