ux improvements

This commit is contained in:
Yousef Rafat 2026-05-14 18:46:33 +03:00
parent 693a34f447
commit efc5141fb0

View File

@ -302,7 +302,7 @@ class VaeDecodeTextureTrellis(IO.ComfyNode):
node_id="VaeDecodeTextureTrellis", node_id="VaeDecodeTextureTrellis",
category="latent/3d", category="latent/3d",
inputs=[ inputs=[
IO.Mesh.Input("shape_mesh"), IO.Mesh.Input("mesh"),
IO.Latent.Input("samples"), IO.Latent.Input("samples"),
IO.Vae.Input("vae"), IO.Vae.Input("vae"),
IO.AnyType.Input("shape_subs"), IO.AnyType.Input("shape_subs"),
@ -314,8 +314,8 @@ class VaeDecodeTextureTrellis(IO.ComfyNode):
) )
@classmethod @classmethod
def execute(cls, shape_mesh, samples, vae, shape_subs, resolution): def execute(cls, mesh, samples, vae, shape_subs, resolution):
shape_mesh = mesh
sample_tensor = samples["samples"] sample_tensor = samples["samples"]
resolution = int(resolution) resolution = int(resolution)
device = comfy.model_management.get_torch_device() device = comfy.model_management.get_torch_device()
@ -549,7 +549,6 @@ class Trellis2Conditioning(IO.ComfyNode):
IO.ClipVision.Input("clip_vision_model"), IO.ClipVision.Input("clip_vision_model"),
IO.Image.Input("image"), IO.Image.Input("image"),
IO.Mask.Input("mask"), IO.Mask.Input("mask"),
IO.Combo.Input("background_color", options=["black", "gray", "white"], default="black")
], ],
outputs=[ outputs=[
IO.Conditioning.Output(display_name="positive"), IO.Conditioning.Output(display_name="positive"),
@ -558,7 +557,7 @@ class Trellis2Conditioning(IO.ComfyNode):
) )
@classmethod @classmethod
def execute(cls, clip_vision_model, image, mask, background_color) -> IO.NodeOutput: def execute(cls, clip_vision_model, image, mask) -> IO.NodeOutput:
# Normalize to batched form so per-image conditioning loop below is uniform. # Normalize to batched form so per-image conditioning loop below is uniform.
if image.ndim == 3: if image.ndim == 3:
image = image.unsqueeze(0) image = image.unsqueeze(0)
@ -617,8 +616,7 @@ class Trellis2Conditioning(IO.ComfyNode):
logging.warning("Mask for the image is empty. Trellis2 requires an image with a mask for the best mesh quality.") logging.warning("Mask for the image is empty. Trellis2 requires an image with a mask for the best mesh quality.")
cropped_np = rgba_np.astype(np.float32) / 255.0 cropped_np = rgba_np.astype(np.float32) / 255.0
bg_colors = {"black":[0.0, 0.0, 0.0], "gray":[0.5, 0.5, 0.5], "white":[1.0, 1.0, 1.0]} bg_rgb = np.array([0.0, 0.0, 0.0], dtype=np.float32)
bg_rgb = np.array(bg_colors.get(background_color, [0.0, 0.0, 0.0]), dtype=np.float32)
fg = cropped_np[:, :, :3] fg = cropped_np[:, :, :3]
alpha_float = cropped_np[:, :, 3:4] alpha_float = cropped_np[:, :, 3:4]
@ -649,17 +647,17 @@ class EmptyTrellis2ShapeLatent(IO.ComfyNode):
node_id="EmptyTrellis2ShapeLatent", node_id="EmptyTrellis2ShapeLatent",
category="latent/3d", category="latent/3d",
inputs=[ inputs=[
IO.AnyType.Input("structure_or_coords"),
IO.Model.Input("model"), IO.Model.Input("model"),
IO.AnyType.Input("structure_or_coords"),
], ],
outputs=[ outputs=[
IO.Model.Output(),
IO.Latent.Output(), IO.Latent.Output(),
IO.Model.Output()
] ]
) )
@classmethod @classmethod
def execute(cls, structure_or_coords, model): def execute(cls, model, structure_or_coords):
# to accept the upscaled coords # to accept the upscaled coords
is_512_pass = False is_512_pass = False
@ -690,7 +688,7 @@ class EmptyTrellis2ShapeLatent(IO.ComfyNode):
model.model_options["transformer_options"]["generation_mode"] = "shape_generation_512" model.model_options["transformer_options"]["generation_mode"] = "shape_generation_512"
else: else:
model.model_options["transformer_options"]["generation_mode"] = "shape_generation" model.model_options["transformer_options"]["generation_mode"] = "shape_generation"
return IO.NodeOutput({"samples": latent, "coords": coords, "coords_counts": counts, "type": "trellis2"}, model) return IO.NodeOutput(model, {"samples": latent, "coords": coords, "coords_counts": counts, "type": "trellis2"})
class EmptyTrellis2LatentTexture(IO.ComfyNode): class EmptyTrellis2LatentTexture(IO.ComfyNode):
@classmethod @classmethod
@ -699,18 +697,18 @@ class EmptyTrellis2LatentTexture(IO.ComfyNode):
node_id="EmptyTrellis2LatentTexture", node_id="EmptyTrellis2LatentTexture",
category="latent/3d", category="latent/3d",
inputs=[ inputs=[
IO.Model.Input("model"),
IO.Voxel.Input("structure_or_coords"), IO.Voxel.Input("structure_or_coords"),
IO.Latent.Input("shape_latent"), IO.Latent.Input("shape_latent"),
IO.Model.Input("model"),
], ],
outputs=[ outputs=[
IO.Model.Output(),
IO.Latent.Output(), IO.Latent.Output(),
IO.Model.Output()
] ]
) )
@classmethod @classmethod
def execute(cls, structure_or_coords, shape_latent, model): def execute(cls, model, structure_or_coords, shape_latent):
channels = 32 channels = 32
if hasattr(structure_or_coords, "data") and structure_or_coords.data.ndim == 4: if hasattr(structure_or_coords, "data") and structure_or_coords.data.ndim == 4:
decoded = structure_or_coords.data.unsqueeze(1) decoded = structure_or_coords.data.unsqueeze(1)
@ -736,7 +734,7 @@ class EmptyTrellis2LatentTexture(IO.ComfyNode):
model.model_options["transformer_options"]["coords"] = coords model.model_options["transformer_options"]["coords"] = coords
model.model_options["transformer_options"]["generation_mode"] = "texture_generation" model.model_options["transformer_options"]["generation_mode"] = "texture_generation"
model.model_options["transformer_options"]["shape_slat"] = shape_latent model.model_options["transformer_options"]["shape_slat"] = shape_latent
return IO.NodeOutput({"samples": latent, "coords": coords, "coords_counts": counts, "type": "trellis2"}, model) return IO.NodeOutput(model, {"samples": latent, "coords": coords, "coords_counts": counts, "type": "trellis2"})
class EmptyTrellis2LatentStructure(IO.ComfyNode): class EmptyTrellis2LatentStructure(IO.ComfyNode):
@ -1346,13 +1344,12 @@ class PostProcessMesh(IO.ComfyNode):
"Set to 0 to disable hole filling.")) "Set to 0 to disable hole filling."))
], ],
outputs=[ outputs=[
IO.Mesh.Output("output_mesh"), IO.Mesh.Output("mesh"),
] ]
) )
@classmethod @classmethod
def execute(cls, mesh, target_face_count, fill_holes_perimeter): def execute(cls, mesh, target_face_count, fill_holes_perimeter):
# input should be comfy.NestedTensor
mesh = copy.deepcopy(mesh) mesh = copy.deepcopy(mesh)
def process_single(v, f, c, bar): def process_single(v, f, c, bar):
@ -1368,7 +1365,6 @@ class PostProcessMesh(IO.ComfyNode):
bar.update(1) bar.update(1)
return v, f, c return v, f, c
# Check if batch is Jagged (List) or Uniform (3D Tensor)
is_list = isinstance(mesh.vertices, list) is_list = isinstance(mesh.vertices, list)
is_batched_tensor = not is_list and mesh.vertices.ndim == 3 is_batched_tensor = not is_list and mesh.vertices.ndim == 3