From 9d0f678f6f51ae707a34dc5c3fddf8dd1c7d74af Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Fri, 8 May 2026 19:03:06 +0300 Subject: [PATCH] removing seeds from node display --- comfy/ldm/trellis2/model.py | 22 ++-- comfy_extras/nodes_trellis2.py | 211 +++++---------------------------- 2 files changed, 41 insertions(+), 192 deletions(-) diff --git a/comfy/ldm/trellis2/model.py b/comfy/ldm/trellis2/model.py index a54e4ca94..14810d56d 100644 --- a/comfy/ldm/trellis2/model.py +++ b/comfy/ldm/trellis2/model.py @@ -802,6 +802,11 @@ class Trellis2(nn.Module): mode = "structure_generation" not_struct_mode = False + if not not_struct_mode: + bsz = x.size(0) + x = x[:, :8] + x = x.view(bsz, 8, 16, 16, 16) + if is_1024 and not_struct_mode and not is_512_run: context = embeds @@ -821,7 +826,7 @@ class Trellis2(nn.Module): orig_bsz = x.shape[0] rule = txt_rule if mode == "texture_generation" else shape_rule - # 1. CFG Bypass Slicing + # CFG Bypass Slicing if rule and orig_bsz > 1: half = orig_bsz // 2 x_eval = x[half:] @@ -834,7 +839,7 @@ class Trellis2(nn.Module): B, N, C = x_eval.shape - # 2. Vectorized SparseTensor Construction (NO FOR LOOPS!) + # Vectorized SparseTensor Construction if mode in ["shape_generation", "texture_generation"]: if coord_counts is not None: logical_batch = coord_counts.shape[0] @@ -880,14 +885,14 @@ class Trellis2(nn.Module): if slat is None: raise ValueError("shape_slat can't be None") - slat_feats = slat.feats + slat_feats = slat # Duplicate shape context if CFG is active if coord_counts is not None and B > coord_counts.shape[0]: slat_feats = torch.cat([slat_feats, slat_feats], dim=0) elif coord_counts is None: - slat_feats = slat.feats[:N].repeat(B, 1) + slat_feats = slat_feats[:N].repeat(B, 1) - x_st = x_st.replace(feats=torch.cat([x_st.feats, slat_feats], dim=-1)) + x_st = x_st.replace(feats=torch.cat([x_st.feats, slat_feats.to(x_st.feats.device)], dim=-1)) out = self.shape2txt(x_st, t_eval, c_eval) else: # structure @@ -901,9 +906,6 @@ class Trellis2(nn.Module): else: out = self.structure_model(x, timestep, context) - # ================================================== - # RE-PAD AND FORMAT OUTPUT - # ================================================== if not_struct_mode: if mask is not None: # Instantly scatter the valid tokens back into a padded rectangular tensor @@ -916,7 +918,7 @@ class Trellis2(nn.Module): if rule and orig_bsz > 1: out_tensor = out_tensor.repeat(2, 1, 1, 1) return out_tensor - #else: - # out = torch.nn.functional.pad(out, (0, 0, 0, 0, 0, 0, 24, 0)) + else: + out = torch.nn.functional.pad(out, (0, 0, 0, 0, 0, 0, 0, 24)) return out diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index 704f6f32f..e65fd9787 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -159,37 +159,6 @@ def split_batched_coords(coords, coord_counts): items.append(coords_i) return items - -def normalize_batch_index(batch_index): - if batch_index is None: - return None - if isinstance(batch_index, int): - return [int(batch_index)] - return list(batch_index) - - -def resolve_sample_indices(batch_index, batch_size): - sample_indices = normalize_batch_index(batch_index) - if sample_indices is None: - return list(range(batch_size)) - if len(sample_indices) != batch_size: - raise ValueError( - f"Trellis2 batch_index length {len(sample_indices)} does not match batch size {batch_size}" - ) - return sample_indices - - -def resolve_singleton_sample_index(batch_index): - 
sample_indices = normalize_batch_index(batch_index) - if sample_indices is None: - return 0 - if len(sample_indices) != 1: - raise ValueError( - f"Trellis2 batch_index must be an int or single-element iterable for singleton coords, got {sample_indices}" - ) - return int(sample_indices[0]) - - def flatten_batched_sparse_latent(samples, coords, coord_counts): samples = samples.squeeze(-1).transpose(1, 2) if coord_counts is None: @@ -218,7 +187,6 @@ def split_batched_sparse_latent(samples, coords, coord_counts): items.append((samples[i, :count], coords_i)) return items - def paint_mesh_with_voxels(mesh, voxel_coords, voxel_colors, resolution): """ Generic function to paint a mesh using nearest-neighbor colors from a sparse voxel field. @@ -232,15 +200,15 @@ def paint_mesh_with_voxels(mesh, voxel_coords, voxel_colors, resolution): # map voxels voxel_pos = voxel_coords.to(device).float() * voxel_size + origin verts = mesh.vertices.to(device).squeeze(0) - voxel_colors = voxel_colors.cpu() + voxel_colors = voxel_colors.to(device) - voxel_pos_np = voxel_pos.cpu().numpy() - verts_np = verts.cpu().numpy() + voxel_pos_np = voxel_pos.numpy() + verts_np = verts.numpy() tree = scipy.spatial.cKDTree(voxel_pos_np) # nearest neighbour k=1 - _, nearest_idx_np = tree.query(verts_np, k=1, workers=1) + _, nearest_idx_np = tree.query(verts_np, k=1, workers=-1) nearest_idx = torch.from_numpy(nearest_idx_np).long() v_colors = voxel_colors[nearest_idx] @@ -253,7 +221,7 @@ def paint_mesh_with_voxels(mesh, voxel_coords, voxel_colors, resolution): final_colors = linear_colors.unsqueeze(0) - out_mesh = copy.copy(mesh) + out_mesh = copy.deepcopy(mesh) out_mesh.colors = final_colors return out_mesh @@ -411,10 +379,10 @@ class VaeDecodeStructureTrellis2(IO.ComfyNode): def execute(cls, samples, vae, resolution): resolution = int(resolution) sample_tensor = samples["samples"] + sample_tensor = sample_tensor[:, :8] batch_number = prepare_trellis_vae_for_decode(vae, sample_tensor.shape) decoder = vae.first_stage_model.struct_dec load_device = comfy.model_management.get_torch_device() - batch_index = normalize_batch_index(samples.get("batch_index")) decoded_batches = [] for start in range(0, sample_tensor.shape[0], batch_number): sample_chunk = sample_tensor[start:start + batch_number].to(load_device) @@ -426,8 +394,6 @@ class VaeDecodeStructureTrellis2(IO.ComfyNode): ratio = current_res // resolution decoded = torch.nn.functional.max_pool3d(decoded.float(), ratio, ratio, 0) > 0.5 out = Types.VOXEL(decoded.squeeze(1).float()) - if batch_index is not None: - out.batch_index = normalize_batch_index(batch_index) return IO.NodeOutput(out) class Trellis2UpsampleCascade(IO.ComfyNode): @@ -453,7 +419,6 @@ class Trellis2UpsampleCascade(IO.ComfyNode): prepare_trellis_vae_for_decode(vae, shape_latent_512["samples"].shape) coord_counts = shape_latent_512.get("coord_counts") - batch_index = normalize_batch_index(shape_latent_512.get("batch_index")) decoder = vae.first_stage_model.shape_dec lr_resolution = 512 target_resolution = int(target_resolution) @@ -529,14 +494,11 @@ class Trellis2UpsampleCascade(IO.ComfyNode): final_coords_list.append(final_coords_i) output_coord_counts.append(int(final_coords_i.shape[0])) - normalized_batch_index = normalize_batch_index(batch_index) output = { "coords": torch.cat(final_coords_list, dim=0), "coord_counts": torch.tensor(output_coord_counts, dtype=torch.int64), "resolutions": torch.full((len(final_coords_list),), int(hr_resolution), dtype=torch.int64), } - if normalized_batch_index is not None: - 
output["batch_index"] = normalized_batch_index return IO.NodeOutput(output,) @@ -547,8 +509,6 @@ def run_conditioning(model, cropped_img_tensor, include_1024=True): model_internal = model.model device = comfy.model_management.intermediate_device() torch_device = comfy.model_management.get_torch_device() - had_image_size = hasattr(model_internal, "image_size") - original_image_size = getattr(model_internal, "image_size", None) def prepare_tensor(pil_img, size): resized_pil = pil_img.resize((size, size), Image.Resampling.LANCZOS) @@ -556,21 +516,15 @@ def run_conditioning(model, cropped_img_tensor, include_1024=True): img_t = torch.from_numpy(img_np).permute(2, 0, 1).unsqueeze(0).to(torch_device) return (img_t - dino_mean.to(torch_device)) / dino_std.to(torch_device) - cond_1024 = None - try: - model_internal.image_size = 512 - input_512 = prepare_tensor(cropped_img_tensor, 512) - cond_512 = model_internal(input_512, skip_norm_elementwise=True)[0] + model_internal.image_size = 512 + input_512 = prepare_tensor(cropped_img_tensor, 512) + cond_512 = model_internal(input_512, skip_norm_elementwise=True)[0] - if include_1024: - model_internal.image_size = 1024 - input_1024 = prepare_tensor(cropped_img_tensor, 1024) - cond_1024 = model_internal(input_1024, skip_norm_elementwise=True)[0] - finally: - if not had_image_size: - delattr(model_internal, "image_size") - else: - model_internal.image_size = original_image_size + cond_1024 = None + if include_1024: + model_internal.image_size = 1024 + input_1024 = prepare_tensor(cropped_img_tensor, 1024) + cond_1024 = model_internal(input_1024, skip_norm_elementwise=True)[0] conditioning = { 'cond_512': cond_512.to(device), @@ -580,7 +534,6 @@ def run_conditioning(model, cropped_img_tensor, include_1024=True): conditioning['cond_1024'] = cond_1024.to(device) return conditioning - class Trellis2Conditioning(IO.ComfyNode): @classmethod def define_schema(cls): @@ -693,7 +646,6 @@ class EmptyShapeLatentTrellis2(IO.ComfyNode): inputs=[ IO.AnyType.Input("structure_or_coords"), IO.Model.Input("model"), - IO.Int.Input("seed", default=0, min=0, max=0xffffffffffffffff), ], outputs=[ IO.Latent.Output(), @@ -702,58 +654,25 @@ class EmptyShapeLatentTrellis2(IO.ComfyNode): ) @classmethod - def execute(cls, structure_or_coords, model, seed): + def execute(cls, structure_or_coords, model): # to accept the upscaled coords is_512_pass = False - coord_counts = None - coord_resolutions = None - batch_index = None if hasattr(structure_or_coords, "data") and structure_or_coords.data.ndim == 4: decoded = structure_or_coords.data.unsqueeze(1) coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int() is_512_pass = True - batch_index = normalize_batch_index(getattr(structure_or_coords, "batch_index", None)) - - elif isinstance(structure_or_coords, dict): - coords = structure_or_coords["coords"].int() - coord_counts = structure_or_coords.get("coord_counts") - coord_resolutions = structure_or_coords.get("resolutions") - batch_index = normalize_batch_index(structure_or_coords.get("batch_index")) - is_512_pass = False elif isinstance(structure_or_coords, torch.Tensor) and structure_or_coords.ndim == 2: coords = structure_or_coords.int() is_512_pass = False - else: raise ValueError(f"Invalid input to EmptyShapeLatent: {type(structure_or_coords)}") + + batch_size, counts, max_tokens = infer_batched_coord_layout(coords) in_channels = 32 - batch_size, inferred_coord_counts, max_tokens = infer_batched_coord_layout(coords) - if coord_counts is not None: - coord_counts = 
coord_counts.to(dtype=torch.int64, device=coords.device) - if coord_counts.shape != inferred_coord_counts.shape or not torch.equal(coord_counts, inferred_coord_counts): - raise ValueError( - f"Trellis2 coord_counts metadata {coord_counts.tolist()} does not match coords layout {inferred_coord_counts.tolist()}" - ) - else: - coord_counts = inferred_coord_counts - if batch_size == 1: - sample_index = resolve_singleton_sample_index(batch_index) - generator = torch.Generator(device="cpu") - generator.manual_seed(int(seed) + sample_index) - latent = torch.randn(1, in_channels, coords.shape[0], 1, generator=generator) - else: - sample_indices = resolve_sample_indices(batch_index, batch_size) - latent = torch.zeros(batch_size, in_channels, max_tokens, 1) - for i, sample_index in enumerate(sample_indices): - count = int(coord_counts[i].item()) - generator = torch.Generator(device="cpu") - generator.manual_seed(int(seed) + int(sample_index)) - latent_i = torch.randn(1, in_channels, count, 1, generator=generator) - latent[i, :, :count] = latent_i[0] - if coord_counts is not None: - latent.trellis_coord_counts = coord_counts.clone() + # image like format + latent = torch.zeros(batch_size, in_channels, max_tokens, 1) model = model.clone() model.model_options = model.model_options.copy() if "transformer_options" in model.model_options: @@ -762,20 +681,11 @@ class EmptyShapeLatentTrellis2(IO.ComfyNode): model.model_options["transformer_options"] = {} model.model_options["transformer_options"]["coords"] = coords - if coord_counts is not None: - model.model_options["transformer_options"]["coord_counts"] = coord_counts if is_512_pass: model.model_options["transformer_options"]["generation_mode"] = "shape_generation_512" else: model.model_options["transformer_options"]["generation_mode"] = "shape_generation" - output = {"samples": latent, "coords": coords, "type": "trellis2"} - if batch_index is not None: - output["batch_index"] = normalize_batch_index(batch_index) - if coord_counts is not None: - output["coord_counts"] = coord_counts - if coord_resolutions is not None: - output["resolutions"] = coord_resolutions - return IO.NodeOutput(output, model) + return IO.NodeOutput({"samples": latent, "coords": coords, "coords_counts": counts, "type": "trellis2"}, model) class EmptyTextureLatentTrellis2(IO.ComfyNode): @classmethod @@ -787,7 +697,6 @@ class EmptyTextureLatentTrellis2(IO.ComfyNode): IO.Voxel.Input("structure_or_coords"), IO.Latent.Input("shape_latent"), IO.Model.Input("model"), - IO.Int.Input("seed", default=0, min=0, max=0xffffffffffffffff), ], outputs=[ IO.Latent.Output(), @@ -796,68 +705,22 @@ class EmptyTextureLatentTrellis2(IO.ComfyNode): ) @classmethod - def execute(cls, structure_or_coords, shape_latent, model, seed): + def execute(cls, structure_or_coords, shape_latent, model): channels = 32 - coord_counts = None - batch_index = None if hasattr(structure_or_coords, "data") and structure_or_coords.data.ndim == 4: decoded = structure_or_coords.data.unsqueeze(1) coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int() - batch_index = normalize_batch_index(getattr(structure_or_coords, "batch_index", None)) - - elif isinstance(structure_or_coords, dict): - coords = structure_or_coords["coords"].int() - coord_counts = structure_or_coords.get("coord_counts") - batch_index = normalize_batch_index(structure_or_coords.get("batch_index")) elif isinstance(structure_or_coords, torch.Tensor) and structure_or_coords.ndim == 2: coords = structure_or_coords.int() - else: - raise ValueError( - 
"structure_or_coords must be a voxel input with data.ndim == 4, " - f'a dict containing "coords", or a 2D torch.Tensor; got {type(structure_or_coords).__name__}' - ) - shape_batch_index = normalize_batch_index(shape_latent.get("batch_index")) - if batch_index is None: - batch_index = shape_batch_index + batch_size, counts, max_tokens = infer_batched_coord_layout(coords) + shape_latent = shape_latent["samples"] - batch_size, inferred_coord_counts, max_tokens = infer_batched_coord_layout(coords) - if coord_counts is not None: - coord_counts = coord_counts.to(dtype=torch.int64, device=coords.device) - if coord_counts.shape != inferred_coord_counts.shape or not torch.equal(coord_counts, inferred_coord_counts): - raise ValueError( - f"Trellis2 coord_counts metadata {coord_counts.tolist()} does not match coords layout {inferred_coord_counts.tolist()}" - ) - else: - coord_counts = inferred_coord_counts if shape_latent.ndim == 4: - if shape_latent.shape[0] != batch_size: - raise ValueError( - f"shape_latent batch {shape_latent.shape[0]} doesn't match coords batch {batch_size}" - ) - shape_latent = shape_latent.squeeze(-1).transpose(1, 2) - if shape_latent.shape[1] < max_tokens: - raise ValueError( - f"shape_latent tokens {shape_latent.shape[1]} can't cover coords max tokens {max_tokens}" - ) + shape_latent = shape_latent.squeeze(-1).transpose(1, 2).reshape(-1, channels) - if batch_size == 1: - sample_index = resolve_singleton_sample_index(batch_index) - generator = torch.Generator(device="cpu") - generator.manual_seed(int(seed) + sample_index) - latent = torch.randn(1, channels, coords.shape[0], 1, generator=generator) - else: - sample_indices = resolve_sample_indices(batch_index, batch_size) - latent = torch.zeros(batch_size, channels, max_tokens, 1) - for i, sample_index in enumerate(sample_indices): - count = int(coord_counts[i].item()) - generator = torch.Generator(device="cpu") - generator.manual_seed(int(seed) + int(sample_index)) - latent_i = torch.randn(1, channels, count, 1, generator=generator) - latent[i, :, :count] = latent_i[0] - if coord_counts is not None: - latent.trellis_coord_counts = coord_counts.clone() + latent = torch.zeros(batch_size, channels, max_tokens, 1) model = model.clone() model.model_options = model.model_options.copy() if "transformer_options" in model.model_options: @@ -866,16 +729,9 @@ class EmptyTextureLatentTrellis2(IO.ComfyNode): model.model_options["transformer_options"] = {} model.model_options["transformer_options"]["coords"] = coords - if coord_counts is not None: - model.model_options["transformer_options"]["coord_counts"] = coord_counts model.model_options["transformer_options"]["generation_mode"] = "texture_generation" model.model_options["transformer_options"]["shape_slat"] = shape_latent - output = {"samples": latent, "coords": coords, "type": "trellis2"} - if batch_index is not None: - output["batch_index"] = normalize_batch_index(batch_index) - if coord_counts is not None: - output["coord_counts"] = coord_counts - return IO.NodeOutput(output, model) + return IO.NodeOutput({"samples": latent, "coords": coords, "coords_counts": counts, "type": "trellis2"}, model) class EmptyStructureLatentTrellis2(IO.ComfyNode): @@ -886,29 +742,20 @@ class EmptyStructureLatentTrellis2(IO.ComfyNode): category="latent/3d", inputs=[ IO.Int.Input("batch_size", default=1, min=1, max=4096, tooltip="The number of latent images in the batch."), - IO.Int.Input("batch_index_start", default=0, min=0, max=4096, tooltip="Starting sample index for per-sample sampler noise."), - 
IO.Int.Input("seed", default=0, min=0, max=0xffffffffffffffff), ], outputs=[ IO.Latent.Output(), ] ) @classmethod - def execute(cls, batch_size, batch_index_start, seed): + def execute(cls, batch_size): in_channels = 8 resolution = 16 - sample_indices = [int(batch_index_start) + i for i in range(batch_size)] latent = torch.zeros(batch_size, in_channels, resolution, resolution, resolution) - for i, sample_index in enumerate(sample_indices): - generator = torch.Generator(device="cpu") - generator.manual_seed(int(seed) + sample_index) - latent[i] = torch.randn(1, in_channels, resolution, resolution, resolution, generator=generator)[0] output = { "samples": latent, "type": "trellis2", } - if batch_size > 1 or batch_index_start != 0: - output["batch_index"] = sample_indices return IO.NodeOutput(output) def simplify_fn(vertices, faces, colors=None, target=100000, max_edge_length=None): @@ -939,7 +786,7 @@ def simplify_fn(vertices, faces, colors=None, target=100000, max_edge_length=Non else None ) - out_v, out_f, out_c = _qem_simplify_robust( + out_v, out_f, out_c = _qem_simplify( verts_np, faces_np, colors_np, target, device, max_edge_length ) @@ -952,7 +799,7 @@ def simplify_fn(vertices, faces, colors=None, target=100000, max_edge_length=Non ) return final_v, final_f, final_c -def _qem_simplify_robust(verts_np, faces_np, colors_np, target_faces, device, max_edge_length=None): +def _qem_simplify(verts_np, faces_np, colors_np, target_faces, device, max_edge_length=None): verts = torch.from_numpy(verts_np).to(device=device, dtype=torch.float64) faces = torch.from_numpy(faces_np).to(device=device, dtype=torch.int64) colors = (