diff --git a/comfy/clip_vision.py b/comfy/clip_vision.py index 2c8b4d5d6..57968ed85 100644 --- a/comfy/clip_vision.py +++ b/comfy/clip_vision.py @@ -10,6 +10,7 @@ import comfy.utils import comfy.clip_model import comfy.image_encoders.dino2 import comfy.image_encoders.dino3 +from comfy.image_encoders.naf import NAF class Output: def __getitem__(self, key): @@ -53,6 +54,7 @@ class ClipVisionModel(): self.model.eval() self.patcher = comfy.model_patcher.CoreModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device) + self.naf = None def load_sd(self, sd): return self.model.load_state_dict(sd, strict=False, assign=self.patcher.is_dynamic()) @@ -155,6 +157,13 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False): for k in keys: if k not in u: sd.pop(k) + # NAF feature upsampler ships bundled into the DINOv3 file under the `naf.` prefix. + naf_keys = [k for k in sd if k.startswith("naf.")] + if naf_keys: + naf_sd = {k[len("naf."):]: sd.pop(k) for k in naf_keys} + naf = NAF().eval() + naf.load_state_dict(naf_sd, strict=False) + clip.naf = comfy.model_patcher.CoreModelPatcher(naf, load_device=clip.load_device, offload_device=comfy.model_management.text_encoder_offload_device()) return clip def load(ckpt_path): diff --git a/comfy/ldm/trellis2/naf/model.py b/comfy/image_encoders/naf.py similarity index 100% rename from comfy/ldm/trellis2/naf/model.py rename to comfy/image_encoders/naf.py diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index 9e496e4da..3e5256e0c 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -2,15 +2,12 @@ from typing_extensions import override from comfy_api.latest import ComfyExtension, IO, Types, UI, io from comfy.ldm.trellis2.vae import SparseTensor from comfy.ldm.trellis2.model import build_proj_transform_matrix, compute_stage_proj_feats -from comfy.ldm.trellis2.naf.model import NAF from comfy_extras.nodes_mesh_postprocess import pack_variable_mesh_batch from server import PromptServer import comfy.latent_formats import comfy.model_management -import comfy.model_patcher import comfy.utils -import folder_paths from PIL import Image import logging import numpy as np @@ -18,7 +15,6 @@ import math import torch ShapeSubdivides = io.Custom("SHAPE_SUBDIVIDES") -NAFModel = io.Custom("NAF_MODEL") shape_slat_format = comfy.latent_formats.Trellis2ShapeSLAT() @@ -864,12 +860,6 @@ class Pixal3DConditioning(IO.ComfyNode): "Wire a MoGeGeometryToFOV (axis='horizontal', unit='degrees') " "output here for a MoGe-derived FOV.", ), - NAFModel.Input( - "naf_model", - optional=True, - tooltip="Optional NAF feature upsampler. Required for shape/texture stages " - "to match upstream's trained feature distribution.", - ), ], outputs=[ IO.Conditioning.Output(display_name="positive"), @@ -878,7 +868,8 @@ class Pixal3DConditioning(IO.ComfyNode): ) @classmethod - def execute(cls, clip_vision_model, image, mask, camera_angle_x, naf_model=None) -> IO.NodeOutput: + def execute(cls, clip_vision_model, image, mask, camera_angle_x) -> IO.NodeOutput: + naf_model = getattr(clip_vision_model, "naf", None) if image.ndim == 3: image = image.unsqueeze(0) if mask.ndim == 2: @@ -987,40 +978,6 @@ class Pixal3DConditioning(IO.ComfyNode): return IO.NodeOutput(positive, negative) -class LoadNAFModel(IO.ComfyNode): - - @classmethod - def define_schema(cls): - return IO.Schema( - node_id="LoadNAFModel", - display_name="Load NAF Model", - category="loaders", - inputs=[ - IO.Combo.Input( - "naf_name", - options=folder_paths.get_filename_list("upscale_models"), - tooltip="NAF safetensors checkpoint (e.g. naf_release.safetensors).", - ), - ], - outputs=[NAFModel.Output(display_name="naf_model")], - ) - - @classmethod - def execute(cls, naf_name) -> IO.NodeOutput: - path = folder_paths.get_full_path_or_raise("upscale_models", naf_name) - sd = comfy.utils.load_torch_file(path, safe_load=True) - model = NAF().eval() - _, unexpected = model.load_state_dict(sd, strict=False) - if unexpected: - raise ValueError(f"Unexpected keys in NAF state_dict: {sorted(unexpected)[:8]}...") - patcher = comfy.model_patcher.CoreModelPatcher( - model, - load_device=comfy.model_management.get_torch_device(), - offload_device=comfy.model_management.unet_offload_device(), - ) - return IO.NodeOutput(patcher) - - class GetMeshInfo(IO.ComfyNode): """Report vertex / face counts and attributes for a MESH, displayed on the node (and as a string output). Counts are comma-formatted since meshes can @@ -1096,7 +1053,6 @@ class Trellis2Extension(ComfyExtension): return [ Trellis2Conditioning, Pixal3DConditioning, - LoadNAFModel, Trellis2ShapeStage, EmptyTrellis2LatentStructure, Trellis2TextureStage,