Refactor NAF model usage

This commit is contained in:
kijai 2026-07-01 01:36:14 +03:00
parent f66b3165e4
commit f3c8db510f
3 changed files with 11 additions and 46 deletions

View File

@ -10,6 +10,7 @@ import comfy.utils
import comfy.clip_model
import comfy.image_encoders.dino2
import comfy.image_encoders.dino3
from comfy.image_encoders.naf import NAF
class Output:
def __getitem__(self, key):
@ -53,6 +54,7 @@ class ClipVisionModel():
self.model.eval()
self.patcher = comfy.model_patcher.CoreModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
self.naf = None
def load_sd(self, sd):
return self.model.load_state_dict(sd, strict=False, assign=self.patcher.is_dynamic())
@ -155,6 +157,13 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False):
for k in keys:
if k not in u:
sd.pop(k)
# NAF feature upsampler ships bundled into the DINOv3 file under the `naf.` prefix.
naf_keys = [k for k in sd if k.startswith("naf.")]
if naf_keys:
naf_sd = {k[len("naf."):]: sd.pop(k) for k in naf_keys}
naf = NAF().eval()
naf.load_state_dict(naf_sd, strict=False)
clip.naf = comfy.model_patcher.CoreModelPatcher(naf, load_device=clip.load_device, offload_device=comfy.model_management.text_encoder_offload_device())
return clip
def load(ckpt_path):

View File

@ -2,15 +2,12 @@ from typing_extensions import override
from comfy_api.latest import ComfyExtension, IO, Types, UI, io
from comfy.ldm.trellis2.vae import SparseTensor
from comfy.ldm.trellis2.model import build_proj_transform_matrix, compute_stage_proj_feats
from comfy.ldm.trellis2.naf.model import NAF
from comfy_extras.nodes_mesh_postprocess import pack_variable_mesh_batch
from server import PromptServer
import comfy.latent_formats
import comfy.model_management
import comfy.model_patcher
import comfy.utils
import folder_paths
from PIL import Image
import logging
import numpy as np
@ -18,7 +15,6 @@ import math
import torch
ShapeSubdivides = io.Custom("SHAPE_SUBDIVIDES")
NAFModel = io.Custom("NAF_MODEL")
shape_slat_format = comfy.latent_formats.Trellis2ShapeSLAT()
@ -864,12 +860,6 @@ class Pixal3DConditioning(IO.ComfyNode):
"Wire a MoGeGeometryToFOV (axis='horizontal', unit='degrees') "
"output here for a MoGe-derived FOV.",
),
NAFModel.Input(
"naf_model",
optional=True,
tooltip="Optional NAF feature upsampler. Required for shape/texture stages "
"to match upstream's trained feature distribution.",
),
],
outputs=[
IO.Conditioning.Output(display_name="positive"),
@ -878,7 +868,8 @@ class Pixal3DConditioning(IO.ComfyNode):
)
@classmethod
def execute(cls, clip_vision_model, image, mask, camera_angle_x, naf_model=None) -> IO.NodeOutput:
def execute(cls, clip_vision_model, image, mask, camera_angle_x) -> IO.NodeOutput:
naf_model = getattr(clip_vision_model, "naf", None)
if image.ndim == 3:
image = image.unsqueeze(0)
if mask.ndim == 2:
@ -987,40 +978,6 @@ class Pixal3DConditioning(IO.ComfyNode):
return IO.NodeOutput(positive, negative)
class LoadNAFModel(IO.ComfyNode):
    """Loader node that reads a NAF checkpoint from the ``upscale_models``
    folder and returns it wrapped in a ``CoreModelPatcher``."""

    @classmethod
    def define_schema(cls):
        """Describe the node: one combo input listing the files found under
        ``upscale_models`` and one ``NAF_MODEL`` output."""
        checkpoint_choices = folder_paths.get_filename_list("upscale_models")
        name_input = IO.Combo.Input(
            "naf_name",
            options=checkpoint_choices,
            tooltip="NAF safetensors checkpoint (e.g. naf_release.safetensors).",
        )
        model_output = NAFModel.Output(display_name="naf_model")
        return IO.Schema(
            node_id="LoadNAFModel",
            display_name="Load NAF Model",
            category="loaders",
            inputs=[name_input],
            outputs=[model_output],
        )

    @classmethod
    def execute(cls, naf_name) -> IO.NodeOutput:
        """Load the selected checkpoint into a fresh ``NAF`` module.

        Raises:
            ValueError: if the checkpoint contains keys the ``NAF`` module
                does not expect.
        """
        checkpoint_path = folder_paths.get_full_path_or_raise("upscale_models", naf_name)
        state_dict = comfy.utils.load_torch_file(checkpoint_path, safe_load=True)

        naf = NAF().eval()
        # Non-strict load: missing keys are discarded here, but any
        # unexpected key is treated as a hard error just below.
        _missing, unexpected = naf.load_state_dict(state_dict, strict=False)
        if unexpected:
            raise ValueError(f"Unexpected keys in NAF state_dict: {sorted(unexpected)[:8]}...")

        compute_device = comfy.model_management.get_torch_device()
        idle_device = comfy.model_management.unet_offload_device()
        wrapped = comfy.model_patcher.CoreModelPatcher(
            naf,
            load_device=compute_device,
            offload_device=idle_device,
        )
        return IO.NodeOutput(wrapped)
class GetMeshInfo(IO.ComfyNode):
"""Report vertex / face counts and attributes for a MESH, displayed on the
node (and as a string output). Counts are comma-formatted since meshes can
@ -1096,7 +1053,6 @@ class Trellis2Extension(ComfyExtension):
return [
Trellis2Conditioning,
Pixal3DConditioning,
LoadNAFModel,
Trellis2ShapeStage,
EmptyTrellis2LatentStructure,
Trellis2TextureStage,