mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-07-03 21:20:49 +08:00
Refactor NAF model usage
This commit is contained in:
parent
f66b3165e4
commit
f3c8db510f
@ -10,6 +10,7 @@ import comfy.utils
|
||||
import comfy.clip_model
|
||||
import comfy.image_encoders.dino2
|
||||
import comfy.image_encoders.dino3
|
||||
from comfy.image_encoders.naf import NAF
|
||||
|
||||
class Output:
|
||||
def __getitem__(self, key):
|
||||
@ -53,6 +54,7 @@ class ClipVisionModel():
|
||||
self.model.eval()
|
||||
|
||||
self.patcher = comfy.model_patcher.CoreModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
|
||||
self.naf = None
|
||||
|
||||
def load_sd(self, sd):
|
||||
return self.model.load_state_dict(sd, strict=False, assign=self.patcher.is_dynamic())
|
||||
@ -155,6 +157,13 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False):
|
||||
for k in keys:
|
||||
if k not in u:
|
||||
sd.pop(k)
|
||||
# NAF feature upsampler ships bundled into the DINOv3 file under the `naf.` prefix.
|
||||
naf_keys = [k for k in sd if k.startswith("naf.")]
|
||||
if naf_keys:
|
||||
naf_sd = {k[len("naf."):]: sd.pop(k) for k in naf_keys}
|
||||
naf = NAF().eval()
|
||||
naf.load_state_dict(naf_sd, strict=False)
|
||||
clip.naf = comfy.model_patcher.CoreModelPatcher(naf, load_device=clip.load_device, offload_device=comfy.model_management.text_encoder_offload_device())
|
||||
return clip
|
||||
|
||||
def load(ckpt_path):
|
||||
|
||||
@ -2,15 +2,12 @@ from typing_extensions import override
|
||||
from comfy_api.latest import ComfyExtension, IO, Types, UI, io
|
||||
from comfy.ldm.trellis2.vae import SparseTensor
|
||||
from comfy.ldm.trellis2.model import build_proj_transform_matrix, compute_stage_proj_feats
|
||||
from comfy.ldm.trellis2.naf.model import NAF
|
||||
|
||||
from comfy_extras.nodes_mesh_postprocess import pack_variable_mesh_batch
|
||||
from server import PromptServer
|
||||
import comfy.latent_formats
|
||||
import comfy.model_management
|
||||
import comfy.model_patcher
|
||||
import comfy.utils
|
||||
import folder_paths
|
||||
from PIL import Image
|
||||
import logging
|
||||
import numpy as np
|
||||
@ -18,7 +15,6 @@ import math
|
||||
import torch
|
||||
|
||||
ShapeSubdivides = io.Custom("SHAPE_SUBDIVIDES")
|
||||
NAFModel = io.Custom("NAF_MODEL")
|
||||
|
||||
|
||||
shape_slat_format = comfy.latent_formats.Trellis2ShapeSLAT()
|
||||
@ -864,12 +860,6 @@ class Pixal3DConditioning(IO.ComfyNode):
|
||||
"Wire a MoGeGeometryToFOV (axis='horizontal', unit='degrees') "
|
||||
"output here for a MoGe-derived FOV.",
|
||||
),
|
||||
NAFModel.Input(
|
||||
"naf_model",
|
||||
optional=True,
|
||||
tooltip="Optional NAF feature upsampler. Required for shape/texture stages "
|
||||
"to match upstream's trained feature distribution.",
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.Conditioning.Output(display_name="positive"),
|
||||
@ -878,7 +868,8 @@ class Pixal3DConditioning(IO.ComfyNode):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, clip_vision_model, image, mask, camera_angle_x, naf_model=None) -> IO.NodeOutput:
|
||||
def execute(cls, clip_vision_model, image, mask, camera_angle_x) -> IO.NodeOutput:
|
||||
naf_model = getattr(clip_vision_model, "naf", None)
|
||||
if image.ndim == 3:
|
||||
image = image.unsqueeze(0)
|
||||
if mask.ndim == 2:
|
||||
@ -987,40 +978,6 @@ class Pixal3DConditioning(IO.ComfyNode):
|
||||
return IO.NodeOutput(positive, negative)
|
||||
|
||||
|
||||
class LoadNAFModel(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="LoadNAFModel",
|
||||
display_name="Load NAF Model",
|
||||
category="loaders",
|
||||
inputs=[
|
||||
IO.Combo.Input(
|
||||
"naf_name",
|
||||
options=folder_paths.get_filename_list("upscale_models"),
|
||||
tooltip="NAF safetensors checkpoint (e.g. naf_release.safetensors).",
|
||||
),
|
||||
],
|
||||
outputs=[NAFModel.Output(display_name="naf_model")],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, naf_name) -> IO.NodeOutput:
|
||||
path = folder_paths.get_full_path_or_raise("upscale_models", naf_name)
|
||||
sd = comfy.utils.load_torch_file(path, safe_load=True)
|
||||
model = NAF().eval()
|
||||
_, unexpected = model.load_state_dict(sd, strict=False)
|
||||
if unexpected:
|
||||
raise ValueError(f"Unexpected keys in NAF state_dict: {sorted(unexpected)[:8]}...")
|
||||
patcher = comfy.model_patcher.CoreModelPatcher(
|
||||
model,
|
||||
load_device=comfy.model_management.get_torch_device(),
|
||||
offload_device=comfy.model_management.unet_offload_device(),
|
||||
)
|
||||
return IO.NodeOutput(patcher)
|
||||
|
||||
|
||||
class GetMeshInfo(IO.ComfyNode):
|
||||
"""Report vertex / face counts and attributes for a MESH, displayed on the
|
||||
node (and as a string output). Counts are comma-formatted since meshes can
|
||||
@ -1096,7 +1053,6 @@ class Trellis2Extension(ComfyExtension):
|
||||
return [
|
||||
Trellis2Conditioning,
|
||||
Pixal3DConditioning,
|
||||
LoadNAFModel,
|
||||
Trellis2ShapeStage,
|
||||
EmptyTrellis2LatentStructure,
|
||||
Trellis2TextureStage,
|
||||
|
||||
Loading…
Reference in New Issue
Block a user