diff --git a/comfy/ldm/trellis2/naf/model.py b/comfy/ldm/trellis2/naf/model.py index 1b085e7e7..4b446f4fd 100644 --- a/comfy/ldm/trellis2/naf/model.py +++ b/comfy/ldm/trellis2/naf/model.py @@ -50,8 +50,9 @@ def na2d_pure( dilation: Tuple[int, int], # (Dh, Dw) stride within the unrolled K/V grid; also the LR→HR upsample factor. scale: float, # 1 / sqrt(d_qk) scaling for the Q·K scores. tile: int = 128, # Spatial tile size (output positions per tile) - v_chunk: int = 64 # Sub-divide d_v into chunks of this size when computing attn·V. None disables chunking. - ) -> torch.Tensor: # [B, H, W, n_heads, d_v] attended features. + v_chunk: int = 64, # Sub-divide d_v into chunks of this size when computing attn·V. None disables chunking. + output: torch.Tensor = None, # Pre-allocated [B, n_heads, d_v, H, W] buffer (may be on CPU). + ) -> torch.Tensor: # [B, n_heads, d_v, H, W] (caller views as BCHW). """Neighborhood attention in pure torch via F.unfold + per-tile slicing. K and V are passed at LR resolution and upsampled (nearest-exact) per-tile only @@ -67,8 +68,7 @@ def na2d_pure( Dh, Dw = dilation pad_h, pad_w = (Kh // 2) * Dh, (Kw // 2) * Dw - q_ = q.permute(0, 3, 4, 1, 2).contiguous() # [B, n, d_qk, H, W] - out = torch.empty((B, n, d_v, H, W), device=q.device, dtype=q.dtype) + out = output if output is not None else torch.empty((B, n, d_v, H, W), device=q.device, dtype=q.dtype) th = min(tile, H) if tile else H tw = min(tile, W) if tile else W @@ -104,7 +104,8 @@ def na2d_pure( KK = Kh * Kw k_w = F.unfold(k_tile, kernel_size=(Kh, Kw), dilation=(Dh, Dw), padding=0) k_w = k_w.view(B, n, d_qk, KK, t_h * t_w).permute(0, 1, 4, 3, 2) # [B, n, t, KK, d_qk] - q_tile = q_[:, :, :, h0:h1, w0:w1].permute(0, 1, 3, 4, 2).reshape(B, n, t_h * t_w, 1, d_qk) + # q is [B, H, W, n, d_qk]; per-tile slice + permute -> [B, n, t_h*t_w, 1, d_qk]. + q_tile = q[:, h0:h1, w0:w1].permute(0, 3, 1, 2, 4).reshape(B, n, t_h * t_w, 1, d_qk) scores = torch.matmul(q_tile, k_w.transpose(-1, -2)) * scale attn = scores.softmax(dim=-1) del k_w, scores, q_tile, k_tile @@ -120,7 +121,7 @@ def na2d_pure( del v_w, out_chunk del attn, v_tile - return out.permute(0, 3, 4, 1, 2).contiguous() # [B, H, W, n, d_v] + return out # [B, n, d_v, H, W] — sole caller (CrossAttention) views it as BCHW directly. class CrossAttention(nn.Module): @@ -140,9 +141,8 @@ class CrossAttention(nn.Module): B, C, H, W = x.shape return x.view(B, num_heads, C // num_heads, H, W).permute(0, 3, 4, 1, 2).contiguous() - def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor: - # q is [B, C, Hq, Wq] at HR; k and v are at LR (Hk, Wk). We KEEP k and v at LR - # na2d_pure upsamples only the tile slice it needs. + def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, + output_device=None) -> torch.Tensor: hq, wq = q.shape[-2:] hk, wk = k.shape[-2:] dilation = (hq // hk, wq // wk) @@ -150,9 +150,13 @@ class CrossAttention(nn.Module): q = q.view(B, self.num_heads, C // self.num_heads, hq, wq).permute(0, 3, 4, 1, 2).contiguous() k_lr = self._split_heads_lr(k, self.num_heads).to(q.dtype) v_lr = self._split_heads_lr(v, self.num_heads).to(q.dtype) - out = na2d_pure(q, k_lr, v_lr, self.kernel_size, dilation, self.scale) - # [B, H, W, n, d] -> [B, n*d, H, W] - return out.permute(0, 3, 4, 1, 2).contiguous().view(B, -1, hq, wq) + out_buf = None + if output_device is not None: + n = self.num_heads + d_v = v.shape[1] // n + out_buf = torch.empty(B, n, d_v, hq, wq, device=output_device, dtype=q.dtype) + out = na2d_pure(q, k_lr, v_lr, self.kernel_size, dilation, self.scale, output=out_buf) + return out.view(B, -1, hq, wq) # RoPE positional embedding @@ -260,8 +264,6 @@ class ImageEncoder(nn.Module): return x -# Top-level NAF model. - class NAF(nn.Module): """NAF feature upsampler.""" @@ -281,21 +283,10 @@ class NAF(nn.Module): self, image: torch.Tensor, # [B, 3, H_img, W_img] in [0, 1]. features: torch.Tensor, # [B, C, H_feat, W_feat] low-resolution features (any C). - output_size: Tuple[int, int] # (H_out, W_out) target spatial resolution for the upsampled features. + output_size: Tuple[int, int], # (H_out, W_out) target spatial resolution for the upsampled features. + output_device=None, ) -> torch.Tensor: # [B, C, H_out, W_out] upsampled features. """Upsample low-res feature map to output_size, guided by the image.""" q = self.image_encoder(image, output_size=output_size) k = F.adaptive_avg_pool2d(q, output_size=features.shape[-2:]) - return self.upsampler(q, k, features) - - -def build_naf_from_state_dict(state_dict: dict) -> NAF: - """Instantiate NAF with the default hyperparams and load the given state_dict. - - The published NAF release uses the default constructor (dim=256, heads_attn=4, - heads_rope=4, kernel_size=9, rope_base=100, img_layers=2).""" - model = NAF() - missing, unexpected = model.load_state_dict(state_dict, strict=False) - if unexpected: - raise ValueError(f"Unexpected keys in NAF state_dict: {sorted(unexpected)[:8]}...") - return model + return self.upsampler(q, k, features, output_device=output_device) diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index fb82166df..fe5ab8e88 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -2,12 +2,13 @@ from typing_extensions import override from comfy_api.latest import ComfyExtension, IO, Types, UI, io from comfy.ldm.trellis2.vae import SparseTensor from comfy.ldm.trellis2.model import build_proj_transform_matrix, _project_points_to_image, compute_stage_proj_feats -from comfy.ldm.trellis2.naf.model import build_naf_from_state_dict +from comfy.ldm.trellis2.naf.model import NAF from comfy_extras.nodes_mesh_postprocess import pack_variable_mesh_batch from server import PromptServer import comfy.latent_formats import comfy.model_management +import comfy.model_patcher import comfy.utils import folder_paths from PIL import Image @@ -923,15 +924,19 @@ class Pixal3DConditioning(IO.ComfyNode): def _naf_hr(lr_feat, composites, image_size, naf_target): if naf_model is None or naf_target is None: return None - target_dtype = lr_feat.dtype - if next(naf_model.parameters()).dtype != target_dtype: - naf_model.to(dtype=target_dtype) - imgs = torch.cat([ - comfy.utils.common_upscale(c, image_size, image_size, "lanczos", "disabled") - for c in composites - ], dim=0).to(torch_device).to(target_dtype) - hr = naf_model(imgs, lr_feat.to(torch_device).to(target_dtype), naf_target) - return hr.to(device) + comfy.model_management.load_model_gpu(naf_model) + inner = naf_model.model + target_dtype = comfy.model_management.text_encoder_dtype(torch_device) + if next(inner.parameters()).dtype != target_dtype: + inner.to(dtype=target_dtype) + hrs = [] + for i, c in enumerate(composites): + img_i = comfy.utils.common_upscale(c, image_size, image_size, "lanczos", "disabled")\ + .to(torch_device).to(target_dtype) + lr_i = lr_feat[i:i + 1].to(torch_device).to(target_dtype) + hr_i = inner(img_i, lr_i, naf_target, output_device=device) + hrs.append(hr_i) + return torch.cat(hrs, dim=0) hr_shape_512 = _naf_hr(fm_512_dino, composite_list, 512, (512, 512)) hr_shape_1024 = _naf_hr(fm_1024_dino, composite_list, 1024, (512, 512)) @@ -1148,10 +1153,16 @@ class LoadNAFModel(IO.ComfyNode): def execute(cls, naf_name) -> IO.NodeOutput: path = folder_paths.get_full_path_or_raise("upscale_models", naf_name) sd = comfy.utils.load_torch_file(path, safe_load=True) - model = build_naf_from_state_dict(sd) - device = comfy.model_management.get_torch_device() - model = model.to(device).eval() - return IO.NodeOutput(model) + model = NAF().eval() + _, unexpected = model.load_state_dict(sd, strict=False) + if unexpected: + raise ValueError(f"Unexpected keys in NAF state_dict: {sorted(unexpected)[:8]}...") + patcher = comfy.model_patcher.CoreModelPatcher( + model, + load_device=comfy.model_management.get_torch_device(), + offload_device=comfy.model_management.unet_offload_device(), + ) + return IO.NodeOutput(patcher) class GetMeshInfo(IO.ComfyNode):