ux improvements

This commit is contained in:
Yousef Rafat 2026-05-14 18:46:33 +03:00
parent 693a34f447
commit efc5141fb0

View File

@ -302,7 +302,7 @@ class VaeDecodeTextureTrellis(IO.ComfyNode):
node_id="VaeDecodeTextureTrellis",
category="latent/3d",
inputs=[
IO.Mesh.Input("shape_mesh"),
IO.Mesh.Input("mesh"),
IO.Latent.Input("samples"),
IO.Vae.Input("vae"),
IO.AnyType.Input("shape_subs"),
@ -314,8 +314,8 @@ class VaeDecodeTextureTrellis(IO.ComfyNode):
)
@classmethod
def execute(cls, shape_mesh, samples, vae, shape_subs, resolution):
def execute(cls, mesh, samples, vae, shape_subs, resolution):
shape_mesh = mesh
sample_tensor = samples["samples"]
resolution = int(resolution)
device = comfy.model_management.get_torch_device()
@ -549,7 +549,6 @@ class Trellis2Conditioning(IO.ComfyNode):
IO.ClipVision.Input("clip_vision_model"),
IO.Image.Input("image"),
IO.Mask.Input("mask"),
IO.Combo.Input("background_color", options=["black", "gray", "white"], default="black")
],
outputs=[
IO.Conditioning.Output(display_name="positive"),
@ -558,7 +557,7 @@ class Trellis2Conditioning(IO.ComfyNode):
)
@classmethod
def execute(cls, clip_vision_model, image, mask, background_color) -> IO.NodeOutput:
def execute(cls, clip_vision_model, image, mask) -> IO.NodeOutput:
# Normalize to batched form so per-image conditioning loop below is uniform.
if image.ndim == 3:
image = image.unsqueeze(0)
@ -617,8 +616,7 @@ class Trellis2Conditioning(IO.ComfyNode):
logging.warning("Mask for the image is empty. Trellis2 requires an image with a mask for the best mesh quality.")
cropped_np = rgba_np.astype(np.float32) / 255.0
bg_colors = {"black":[0.0, 0.0, 0.0], "gray":[0.5, 0.5, 0.5], "white":[1.0, 1.0, 1.0]}
bg_rgb = np.array(bg_colors.get(background_color, [0.0, 0.0, 0.0]), dtype=np.float32)
bg_rgb = np.array([0.0, 0.0, 0.0], dtype=np.float32)
fg = cropped_np[:, :, :3]
alpha_float = cropped_np[:, :, 3:4]
@ -649,17 +647,17 @@ class EmptyTrellis2ShapeLatent(IO.ComfyNode):
node_id="EmptyTrellis2ShapeLatent",
category="latent/3d",
inputs=[
IO.AnyType.Input("structure_or_coords"),
IO.Model.Input("model"),
IO.AnyType.Input("structure_or_coords"),
],
outputs=[
IO.Model.Output(),
IO.Latent.Output(),
IO.Model.Output()
]
)
@classmethod
def execute(cls, structure_or_coords, model):
def execute(cls, model, structure_or_coords):
# to accept the upscaled coords
is_512_pass = False
@ -690,7 +688,7 @@ class EmptyTrellis2ShapeLatent(IO.ComfyNode):
model.model_options["transformer_options"]["generation_mode"] = "shape_generation_512"
else:
model.model_options["transformer_options"]["generation_mode"] = "shape_generation"
return IO.NodeOutput({"samples": latent, "coords": coords, "coords_counts": counts, "type": "trellis2"}, model)
return IO.NodeOutput(model, {"samples": latent, "coords": coords, "coords_counts": counts, "type": "trellis2"})
class EmptyTrellis2LatentTexture(IO.ComfyNode):
@classmethod
@ -699,18 +697,18 @@ class EmptyTrellis2LatentTexture(IO.ComfyNode):
node_id="EmptyTrellis2LatentTexture",
category="latent/3d",
inputs=[
IO.Model.Input("model"),
IO.Voxel.Input("structure_or_coords"),
IO.Latent.Input("shape_latent"),
IO.Model.Input("model"),
],
outputs=[
IO.Model.Output(),
IO.Latent.Output(),
IO.Model.Output()
]
)
@classmethod
def execute(cls, structure_or_coords, shape_latent, model):
def execute(cls, model, structure_or_coords, shape_latent):
channels = 32
if hasattr(structure_or_coords, "data") and structure_or_coords.data.ndim == 4:
decoded = structure_or_coords.data.unsqueeze(1)
@ -736,7 +734,7 @@ class EmptyTrellis2LatentTexture(IO.ComfyNode):
model.model_options["transformer_options"]["coords"] = coords
model.model_options["transformer_options"]["generation_mode"] = "texture_generation"
model.model_options["transformer_options"]["shape_slat"] = shape_latent
return IO.NodeOutput({"samples": latent, "coords": coords, "coords_counts": counts, "type": "trellis2"}, model)
return IO.NodeOutput(model, {"samples": latent, "coords": coords, "coords_counts": counts, "type": "trellis2"})
class EmptyTrellis2LatentStructure(IO.ComfyNode):
@ -1346,13 +1344,12 @@ class PostProcessMesh(IO.ComfyNode):
"Set to 0 to disable hole filling."))
],
outputs=[
IO.Mesh.Output("output_mesh"),
IO.Mesh.Output("mesh"),
]
)
@classmethod
def execute(cls, mesh, target_face_count, fill_holes_perimeter):
# input should be comfy.NestedTensor
mesh = copy.deepcopy(mesh)
def process_single(v, f, c, bar):
@ -1368,7 +1365,6 @@ class PostProcessMesh(IO.ComfyNode):
bar.update(1)
return v, f, c
# Check if batch is Jagged (List) or Uniform (3D Tensor)
is_list = isinstance(mesh.vertices, list)
is_batched_tensor = not is_list and mesh.vertices.ndim == 3