From 7d98cc1305612becdf0baa734997f84eb296a49d Mon Sep 17 00:00:00 2001 From: John Pollock Date: Mon, 20 Apr 2026 14:29:07 -0500 Subject: [PATCH] Fix Trellis seeded sparse batch semantics --- comfy/ldm/trellis2/model.py | 183 +++++++----------- comfy/sample.py | 34 ++-- comfy_extras/nodes_trellis2.py | 161 ++++++++++----- .../comfy_extras_test/nodes_trellis2_test.py | 83 ++++++++ tests-unit/comfy_test/sample_test.py | 47 +++++ 5 files changed, 333 insertions(+), 175 deletions(-) create mode 100644 tests-unit/comfy_test/sample_test.py diff --git a/comfy/ldm/trellis2/model.py b/comfy/ldm/trellis2/model.py index f61c50629..15939e5c6 100644 --- a/comfy/ldm/trellis2/model.py +++ b/comfy/ldm/trellis2/model.py @@ -813,6 +813,14 @@ class Trellis2(nn.Module): shape_rule = sigmas < self.guidance_interval[0] or sigmas > self.guidance_interval[1] txt_rule = sigmas < self.guidance_interval_txt[0] or sigmas > self.guidance_interval_txt[1] dense_out = None + cond_or_uncond = transformer_options.get("cond_or_uncond") or [] + + def cond_group_indices(batch_groups): + if len(cond_or_uncond) == batch_groups: + cond_groups = [i for i, marker in enumerate(cond_or_uncond) if marker == 0] + if len(cond_groups) > 0: + return cond_groups + return [batch_groups - 1] if not_struct_mode: orig_bsz = x.shape[0] @@ -820,10 +828,17 @@ class Trellis2(nn.Module): logical_batch = coord_counts.shape[0] if coord_counts is not None else 1 if rule and orig_bsz > logical_batch: - half = orig_bsz // 2 - x_eval = x[half:] - t_eval = timestep[half:] if timestep.shape[0] > 1 else timestep - c_eval = cond + batch_groups = orig_bsz // logical_batch + selected_groups = cond_group_indices(batch_groups) + x_groups = x.reshape(batch_groups, logical_batch, *x.shape[1:]) + x_eval = x_groups[selected_groups].reshape(-1, *x.shape[1:]) + if timestep.shape[0] > 1: + t_groups = timestep.reshape(batch_groups, logical_batch, *timestep.shape[1:]) + t_eval = t_groups[selected_groups].reshape(-1, *timestep.shape[1:]) + else: + t_eval = timestep + c_groups = context.reshape(batch_groups, logical_batch, *context.shape[1:]) + c_eval = c_groups[selected_groups].reshape(-1, *context.shape[1:]) else: x_eval = x t_eval = timestep @@ -838,113 +853,62 @@ class Trellis2(nn.Module): raise ValueError( f"Trellis2 coord_counts batch {logical_batch} doesn't divide latent batch {B}" ) + batch_ids = coords[:, 0].to(torch.int64) + order = torch.argsort(batch_ids, stable=True) + sorted_coords = coords.index_select(0, order) + sorted_batch_ids = batch_ids.index_select(0, order) + offsets = coord_counts.cumsum(0) - coord_counts + coords_by_batch = [] + for i in range(logical_batch): + count = int(coord_counts[i].item()) + start = int(offsets[i].item()) + coords_i = sorted_coords[start:start + count] + ids_i = sorted_batch_ids[start:start + count] + if coords_i.shape[0] != count or not torch.all(ids_i == i): + raise ValueError( + f"Trellis2 coords rows for batch {i} expected {count}, got {coords_i.shape[0]}" + ) + coords_by_batch.append(coords_i) repeat_factor = B // logical_batch sparse_outs = [] active_coord_counts = [] - if mode == "shape_generation" and repeat_factor > 1: - grouped_outs = [] - grouped_counts = [] + for rep in range(repeat_factor): for i in range(logical_batch): + out_index = rep * logical_batch + i count = int(coord_counts[i].item()) - coords_i = coords[coords[:, 0] == i].clone() - if coords_i.shape[0] != count: - raise ValueError( - f"Trellis2 coords rows for batch {i} expected {count}, got {coords_i.shape[0]}" - ) + coords_i = coords_by_batch[i].clone() + coords_i[:, 0] = 0 + feats_i = x_eval[out_index, :count].clone() + x_st_i = SparseTensor(feats=feats_i, coords=coords_i.to(torch.int32)) + t_i = t_eval[out_index].unsqueeze(0).clone() if t_eval.shape[0] > 1 else t_eval + c_i = c_eval[out_index].unsqueeze(0).clone() if c_eval.shape[0] > 1 else c_eval - feat_batches = [] - coord_batches = [] - index_batch = [] - for rep in range(repeat_factor): - out_index = rep * logical_batch + i - feat_batches.append(x_eval[out_index, :count]) - coords_rep = coords_i.clone() - coords_rep[:, 0] = rep - coord_batches.append(coords_rep) - index_batch.append(out_index) - - x_st_i = SparseTensor( - feats=torch.cat(feat_batches, dim=0), - coords=torch.cat(coord_batches, dim=0).to(torch.int32), - ) - index_tensor = torch.tensor(index_batch, device=x_eval.device, dtype=torch.long) - if t_eval.shape[0] > 1: - t_i = t_eval.index_select(0, index_tensor) - else: - t_i = t_eval - if c_eval.shape[0] > 1: - c_i = c_eval.index_select(0, index_tensor) - else: - c_i = c_eval - - if is_512_run: - sparse_out = self.img2shape_512(x_st_i, t_i, c_i) - else: - sparse_out = self.img2shape(x_st_i, t_i, c_i) - - feats_group, coords_group = sparse_out.to_tensor_list() - if len(feats_group) != repeat_factor: - raise ValueError( - f"Trellis2 expected {repeat_factor} sparse output groups for batch {i}, got {len(feats_group)}" - ) - for rep, (feats_rep, coords_rep) in enumerate(zip(feats_group, coords_group)): - if feats_rep.shape[0] != count: - raise ValueError( - f"Trellis2 sparse output rows for batch {i} rep {rep} expected {count}, got {feats_rep.shape[0]}" - ) - if coords_rep.shape[0] != count: - raise ValueError( - f"Trellis2 sparse output coords for batch {i} rep {rep} expected {count}, got {coords_rep.shape[0]}" - ) - grouped_outs.append(feats_group) - grouped_counts.append(count) - - for rep in range(repeat_factor): - for i in range(logical_batch): - sparse_outs.append(grouped_outs[i][rep]) - active_coord_counts.append(grouped_counts[i]) - else: - for rep in range(repeat_factor): - for i in range(logical_batch): - out_index = rep * logical_batch + i - count = int(coord_counts[i].item()) - coords_i = coords[coords[:, 0] == i].clone() - if coords_i.shape[0] != count: - raise ValueError( - f"Trellis2 coords rows for batch {i} expected {count}, got {coords_i.shape[0]}" - ) - coords_i[:, 0] = 0 - feats_i = x_eval[out_index, :count] - x_st_i = SparseTensor(feats=feats_i, coords=coords_i.to(torch.int32)) - t_i = t_eval[out_index].unsqueeze(0) if t_eval.shape[0] > 1 else t_eval - c_i = c_eval[out_index].unsqueeze(0) if c_eval.shape[0] > 1 else c_eval - - if mode == "shape_generation": - if is_512_run: - sparse_out = self.img2shape_512(x_st_i, t_i, c_i) - else: - sparse_out = self.img2shape(x_st_i, t_i, c_i) + if mode == "shape_generation": + if is_512_run: + sparse_out = self.img2shape_512(x_st_i, t_i, c_i) else: - slat = transformer_options.get("shape_slat") - if slat is None: - raise ValueError("shape_slat can't be None") - if slat.ndim == 3: - if slat.shape[0] != logical_batch: - raise ValueError( - f"shape_slat batch {slat.shape[0]} doesn't match coord_counts batch {logical_batch}" - ) - if slat.shape[1] < count: - raise ValueError( - f"shape_slat tokens {slat.shape[1]} can't cover coord count {count} for batch {i}" - ) - slat_feats = slat[i, :count].to(x_st_i.device) - else: - slat_feats = slat[:count].to(x_st_i.device) - x_st_i = x_st_i.replace(feats=torch.cat([x_st_i.feats, slat_feats], dim=-1)) - sparse_out = self.shape2txt(x_st_i, t_i, c_i) + sparse_out = self.img2shape(x_st_i, t_i, c_i) + else: + slat = transformer_options.get("shape_slat") + if slat is None: + raise ValueError("shape_slat can't be None") + if slat.ndim == 3: + if slat.shape[0] != logical_batch: + raise ValueError( + f"shape_slat batch {slat.shape[0]} doesn't match coord_counts batch {logical_batch}" + ) + if slat.shape[1] < count: + raise ValueError( + f"shape_slat tokens {slat.shape[1]} can't cover coord count {count} for batch {i}" + ) + slat_feats = slat[i, :count].to(x_st_i.device) + else: + slat_feats = slat[:count].to(x_st_i.device) + x_st_i = x_st_i.replace(feats=torch.cat([x_st_i.feats, slat_feats], dim=-1)) + sparse_out = self.shape2txt(x_st_i, t_i, c_i) - sparse_outs.append(sparse_out.feats) - active_coord_counts.append(count) + sparse_outs.append(sparse_out.feats) + active_coord_counts.append(count) out_channels = sparse_outs[0].shape[-1] padded = sparse_outs[0].new_zeros((B, N, out_channels)) @@ -1022,7 +986,6 @@ class Trellis2(nn.Module): out = self.shape2txt(x_st, t_eval, c_eval) else: # structure orig_bsz = x.shape[0] - cond_or_uncond = transformer_options.get("cond_or_uncond") or [] batch_groups = len(cond_or_uncond) if len(cond_or_uncond) > 0 and orig_bsz % len(cond_or_uncond) == 0 else 1 logical_batch = orig_bsz // batch_groups if logical_batch > 1: @@ -1034,23 +997,19 @@ class Trellis2(nn.Module): c_groups = context.reshape(batch_groups, logical_batch, *context.shape[1:]) if shape_rule and batch_groups > 1: - selected_group_indices = [batch_groups - 1] + selected_group_indices = cond_group_indices(batch_groups) else: selected_group_indices = list(range(batch_groups)) out_groups = [] for sample_index in range(logical_batch): if shape_rule and batch_groups > 1: - half = orig_bsz // 2 - x_i = x[half + sample_index].unsqueeze(0) + x_i = x_groups[selected_group_indices, sample_index] if timestep.shape[0] > 1: - t_i = timestep[half + sample_index].unsqueeze(0) + t_i = t_groups[selected_group_indices, sample_index] else: t_i = timestep - if cond.shape[0] > 1: - c_i = cond[sample_index].unsqueeze(0) - else: - c_i = cond + c_i = c_groups[selected_group_indices, sample_index] else: x_i = x_groups[selected_group_indices, sample_index] if timestep.shape[0] > 1: diff --git a/comfy/sample.py b/comfy/sample.py index 7251aa799..6fba221ed 100644 --- a/comfy/sample.py +++ b/comfy/sample.py @@ -15,32 +15,26 @@ def prepare_noise_inner(latent_image, generator, noise_inds=None): else: noise_inds = np.asarray(noise_inds, dtype=np.int64) + base_seed = int(generator.initial_seed()) unique_inds = np.unique(noise_inds) - first_indices = {int(unique_index): int(np.flatnonzero(noise_inds == unique_index)[0]) for unique_index in unique_inds.tolist()} - index_states = {} - for unique_index in sorted(first_indices): - index_states[unique_index] = generator.get_state().clone() - count = int(coord_counts[first_indices[unique_index]].item()) - torch.randn( - [1, latent_image.size(1), count, latent_image.size(3)], - dtype=torch.float32, - layout=latent_image.layout, - generator=generator, - device="cpu", - ) - - for batch_index, noise_index in enumerate(noise_inds.tolist()): - count = int(coord_counts[batch_index].item()) + sample_noises = {} + for noise_index in unique_inds.tolist(): + rows = np.flatnonzero(noise_inds == noise_index) + max_count = max(int(coord_counts[row].item()) for row in rows.tolist()) local_generator = torch.Generator(device="cpu") - local_generator.set_state(index_states[int(noise_index)].clone()) - sample_noise = torch.randn( - [1, latent_image.size(1), count, latent_image.size(3)], + local_generator.manual_seed(base_seed + int(noise_index)) + sample_noises[int(noise_index)] = torch.randn( + [1, latent_image.size(1), max_count, latent_image.size(3)], dtype=torch.float32, layout=latent_image.layout, generator=local_generator, device="cpu", ) - noise[batch_index:batch_index + 1, :, :count, :] = sample_noise + + for batch_index, noise_index in enumerate(noise_inds.tolist()): + count = int(coord_counts[batch_index].item()) + sample_noise = sample_noises[int(noise_index)] + noise[batch_index:batch_index + 1, :, :count, :] = sample_noise[:, :, :count, :] return noise.to(dtype=latent_image.dtype) if noise_inds is None: @@ -76,6 +70,8 @@ def prepare_noise(latent_image, seed, noise_inds=None): def fix_empty_latent_channels(model, latent_image, downscale_ratio_spacial=None): if latent_image.is_nested: return latent_image + if getattr(latent_image, "trellis_skip_empty_fix", False): + return latent_image latent_format = model.get_model_object("latent_format") #Resize the empty latent image so it has the right number of channels if torch.count_nonzero(latent_image) == 0: if latent_format.latent_channels != latent_image.shape[1]: diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index 621cc9586..6556ed176 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -115,18 +115,54 @@ def infer_batched_coord_layout(coords): return batch_size, counts, max_tokens +def split_batched_coords(coords, coord_counts): + batch_ids = coords[:, 0].to(torch.int64) + order = torch.argsort(batch_ids, stable=True) + sorted_coords = coords.index_select(0, order) + sorted_batch_ids = batch_ids.index_select(0, order) + + offsets = coord_counts.cumsum(0) - coord_counts + items = [] + for i in range(coord_counts.shape[0]): + count = int(coord_counts[i].item()) + start = int(offsets[i].item()) + coords_i = sorted_coords[start:start + count] + ids_i = sorted_batch_ids[start:start + count] + if coords_i.shape[0] != count or not torch.all(ids_i == i): + raise ValueError(f"Trellis2 coords rows for batch {i} expected {count}, got {coords_i.shape[0]}") + items.append(coords_i) + return items + + +def normalize_batch_index(batch_index): + if batch_index is None: + return None + if isinstance(batch_index, int): + return [int(batch_index)] + return list(batch_index) + + +def resolve_sample_indices(batch_index, batch_size): + sample_indices = normalize_batch_index(batch_index) + if sample_indices is None: + return list(range(batch_size)) + if len(sample_indices) != batch_size: + raise ValueError( + f"Trellis2 batch_index length {len(sample_indices)} does not match batch size {batch_size}" + ) + return sample_indices + + def flatten_batched_sparse_latent(samples, coords, coord_counts): samples = samples.squeeze(-1).transpose(1, 2) if coord_counts is None: return samples.reshape(-1, samples.shape[-1]), coords + coords_items = split_batched_coords(coords, coord_counts) feat_list = [] coord_list = [] - for i in range(coord_counts.shape[0]): + for i, coords_i in enumerate(coords_items): count = int(coord_counts[i].item()) - coords_i = coords[coords[:, 0] == i] - if coords_i.shape[0] != count: - raise ValueError(f"Trellis2 coords rows for batch {i} expected {count}, got {coords_i.shape[0]}") feat_list.append(samples[i, :count]) coord_list.append(coords_i) @@ -138,12 +174,10 @@ def split_batched_sparse_latent(samples, coords, coord_counts): if coord_counts is None: return [(samples.reshape(-1, samples.shape[-1]), coords)] + coords_items = split_batched_coords(coords, coord_counts) items = [] - for i in range(coord_counts.shape[0]): + for i, coords_i in enumerate(coords_items): count = int(coord_counts[i].item()) - coords_i = coords[coords[:, 0] == i] - if coords_i.shape[0] != count: - raise ValueError(f"Trellis2 coords rows for batch {i} expected {count}, got {coords_i.shape[0]}") items.append((samples[i, :count], coords_i)) return items @@ -345,6 +379,7 @@ class VaeDecodeStructureTrellis2(IO.ComfyNode): load_device = comfy.model_management.get_torch_device() offload_device = comfy.model_management.vae_offload_device() decoder = decoder.to(load_device) + batch_index = normalize_batch_index(samples.get("batch_index")) samples = samples["samples"] samples = samples.to(load_device) if samples.shape[0] > 1: @@ -361,6 +396,8 @@ class VaeDecodeStructureTrellis2(IO.ComfyNode): ratio = current_res // resolution decoded = torch.nn.functional.max_pool3d(decoded.float(), ratio, ratio, 0) > 0.5 out = Types.VOXEL(decoded.squeeze(1).float()) + if batch_index is not None: + out.batch_index = normalize_batch_index(batch_index) return IO.NodeOutput(out) class Trellis2UpsampleCascade(IO.ComfyNode): @@ -386,6 +423,7 @@ class Trellis2UpsampleCascade(IO.ComfyNode): comfy.model_management.load_model_gpu(vae.patcher) coord_counts = shape_latent_512.get("coord_counts") + batch_index = normalize_batch_index(shape_latent_512.get("batch_index")) decoder = vae.first_stage_model.shape_dec lr_resolution = 512 target_resolution = int(target_resolution) @@ -424,40 +462,48 @@ class Trellis2UpsampleCascade(IO.ComfyNode): ) decoder_dtype = next(decoder.parameters()).dtype - final_coords_list = [] - output_resolutions = [] - output_coord_counts = [] - for batch_index, (feats_i, coords_i) in enumerate(items): + sample_hr_coords = [] + for feats_i, coords_i in items: feats_i = feats_i.to(device) coords_i = coords_i.to(device).clone() coords_i[:, 0] = 0 slat_i = shape_norm(feats_i, coords_i) slat_i.feats = slat_i.feats.to(decoder_dtype) - hr_coords_i = decoder.upsample(slat_i, upsample_times=4) + sample_hr_coords.append(decoder.upsample(slat_i, upsample_times=4)) - hr_resolution = target_resolution - while True: + hr_resolution = target_resolution + while True: + exceeds_limit = False + for hr_coords_i in sample_hr_coords: quant_coords_i = torch.cat([ hr_coords_i[:, :1], ((hr_coords_i[:, 1:] + 0.5) / lr_resolution * (hr_resolution // 16)).int(), ], dim=1) - final_coords_i = quant_coords_i.unique(dim=0) - num_tokens = final_coords_i.shape[0] - - if num_tokens < max_tokens or hr_resolution <= 1024: + if quant_coords_i.unique(dim=0).shape[0] >= max_tokens: + exceeds_limit = True break - hr_resolution -= 128 + if not exceeds_limit or hr_resolution <= 1024: + break + hr_resolution -= 128 + final_coords_list = [] + output_coord_counts = [] + for sample_offset, hr_coords_i in enumerate(sample_hr_coords): + quant_coords_i = torch.cat([ + hr_coords_i[:, :1], + ((hr_coords_i[:, 1:] + 0.5) / lr_resolution * (hr_resolution // 16)).int(), + ], dim=1) + final_coords_i = quant_coords_i.unique(dim=0) final_coords_i = final_coords_i.clone() - final_coords_i[:, 0] = batch_index + final_coords_i[:, 0] = sample_offset final_coords_list.append(final_coords_i) - output_resolutions.append(int(hr_resolution)) output_coord_counts.append(int(final_coords_i.shape[0])) return IO.NodeOutput({ "coords": torch.cat(final_coords_list, dim=0), "coord_counts": torch.tensor(output_coord_counts, dtype=torch.int64), - "resolutions": torch.tensor(output_resolutions, dtype=torch.int64), + "resolutions": torch.full((len(final_coords_list),), int(hr_resolution), dtype=torch.int64), + "batch_index": normalize_batch_index(batch_index), },) dino_mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1) @@ -612,7 +658,8 @@ class EmptyShapeLatentTrellis2(IO.ComfyNode): category="latent/3d", inputs=[ IO.AnyType.Input("structure_or_coords"), - IO.Model.Input("model") + IO.Model.Input("model"), + IO.Int.Input("seed", default=0, min=0, max=0xffffffffffffffff), ], outputs=[ IO.Latent.Output(), @@ -621,21 +668,24 @@ class EmptyShapeLatentTrellis2(IO.ComfyNode): ) @classmethod - def execute(cls, structure_or_coords, model): + def execute(cls, structure_or_coords, model, seed): # to accept the upscaled coords is_512_pass = False coord_counts = None coord_resolutions = None + batch_index = None if hasattr(structure_or_coords, "data") and structure_or_coords.data.ndim == 4: decoded = structure_or_coords.data.unsqueeze(1) coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int() is_512_pass = True + batch_index = normalize_batch_index(getattr(structure_or_coords, "batch_index", None)) elif isinstance(structure_or_coords, dict): coords = structure_or_coords["coords"].int() coord_counts = structure_or_coords.get("coord_counts") coord_resolutions = structure_or_coords.get("resolutions") + batch_index = normalize_batch_index(structure_or_coords.get("batch_index")) is_512_pass = False elif isinstance(structure_or_coords, torch.Tensor) and structure_or_coords.ndim == 2: @@ -655,15 +705,17 @@ class EmptyShapeLatentTrellis2(IO.ComfyNode): else: coord_counts = inferred_coord_counts if batch_size == 1: - coord_counts = None - latent = torch.randn(1, in_channels, coords.shape[0], 1) + sample_indices = normalize_batch_index(batch_index) or [0] + generator = torch.Generator(device="cpu") + generator.manual_seed(int(seed) + int(sample_indices[0])) + latent = torch.randn(1, in_channels, coords.shape[0], 1, generator=generator) else: + sample_indices = resolve_sample_indices(batch_index, batch_size) latent = torch.zeros(batch_size, in_channels, max_tokens, 1) - base_state = torch.random.get_rng_state() - for i in range(batch_size): + for i, sample_index in enumerate(sample_indices): count = int(coord_counts[i].item()) generator = torch.Generator(device="cpu") - generator.set_state(base_state.clone()) + generator.manual_seed(int(seed) + int(sample_index)) latent_i = torch.randn(1, in_channels, count, 1, generator=generator) latent[i, :, :count] = latent_i[0] if coord_counts is not None: @@ -685,11 +737,12 @@ class EmptyShapeLatentTrellis2(IO.ComfyNode): else: model.model_options["transformer_options"]["generation_mode"] = "shape_generation" output = {"samples": latent, "coords": coords, "type": "trellis2"} + if batch_index is not None: + output["batch_index"] = normalize_batch_index(batch_index) if coord_counts is not None: output["coord_counts"] = coord_counts if coord_resolutions is not None: output["coord_resolutions"] = coord_resolutions - output["batch_index"] = [0] * batch_size return IO.NodeOutput(output, model) class EmptyTextureLatentTrellis2(IO.ComfyNode): @@ -701,7 +754,8 @@ class EmptyTextureLatentTrellis2(IO.ComfyNode): inputs=[ IO.Voxel.Input("structure_or_coords"), IO.Latent.Input("shape_latent"), - IO.Model.Input("model") + IO.Model.Input("model"), + IO.Int.Input("seed", default=0, min=0, max=0xffffffffffffffff), ], outputs=[ IO.Latent.Output(), @@ -710,20 +764,24 @@ class EmptyTextureLatentTrellis2(IO.ComfyNode): ) @classmethod - def execute(cls, structure_or_coords, shape_latent, model): + def execute(cls, structure_or_coords, shape_latent, model, seed): channels = 32 coord_counts = None + batch_index = None if hasattr(structure_or_coords, "data") and structure_or_coords.data.ndim == 4: decoded = structure_or_coords.data.unsqueeze(1) coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int() + batch_index = normalize_batch_index(getattr(structure_or_coords, "batch_index", None)) elif isinstance(structure_or_coords, dict): coords = structure_or_coords["coords"].int() coord_counts = structure_or_coords.get("coord_counts") + batch_index = normalize_batch_index(structure_or_coords.get("batch_index")) elif isinstance(structure_or_coords, torch.Tensor) and structure_or_coords.ndim == 2: coords = structure_or_coords.int() + shape_batch_index = normalize_batch_index(shape_latent.get("batch_index")) shape_latent = shape_latent["samples"] batch_size, inferred_coord_counts, max_tokens = infer_batched_coord_layout(coords) if coord_counts is not None: @@ -746,19 +804,23 @@ class EmptyTextureLatentTrellis2(IO.ComfyNode): ) if batch_size == 1: - coord_counts = None - latent = torch.randn(1, channels, coords.shape[0], 1) + sample_indices = normalize_batch_index(batch_index) or [0] + generator = torch.Generator(device="cpu") + generator.manual_seed(int(seed) + int(sample_indices[0])) + latent = torch.randn(1, channels, coords.shape[0], 1, generator=generator) else: + sample_indices = resolve_sample_indices(batch_index, batch_size) latent = torch.zeros(batch_size, channels, max_tokens, 1) - base_state = torch.random.get_rng_state() - for i in range(batch_size): + for i, sample_index in enumerate(sample_indices): count = int(coord_counts[i].item()) generator = torch.Generator(device="cpu") - generator.set_state(base_state.clone()) + generator.manual_seed(int(seed) + int(sample_index)) latent_i = torch.randn(1, channels, count, 1, generator=generator) latent[i, :, :count] = latent_i[0] if coord_counts is not None: latent.trellis_coord_counts = coord_counts.clone() + if batch_index is None: + batch_index = shape_batch_index model = model.clone() model.model_options = model.model_options.copy() if "transformer_options" in model.model_options: @@ -772,9 +834,10 @@ class EmptyTextureLatentTrellis2(IO.ComfyNode): model.model_options["transformer_options"]["generation_mode"] = "texture_generation" model.model_options["transformer_options"]["shape_slat"] = shape_latent output = {"samples": latent, "coords": coords, "type": "trellis2"} + if batch_index is not None: + output["batch_index"] = normalize_batch_index(batch_index) if coord_counts is not None: output["coord_counts"] = coord_counts - output["batch_index"] = [0] * batch_size return IO.NodeOutput(output, model) @@ -786,19 +849,29 @@ class EmptyStructureLatentTrellis2(IO.ComfyNode): category="latent/3d", inputs=[ IO.Int.Input("batch_size", default=1, min=1, max=4096, tooltip="The number of latent images in the batch."), + IO.Int.Input("batch_index_start", default=0, min=0, max=4096, tooltip="Starting sample index for per-sample sampler noise."), + IO.Int.Input("seed", default=0, min=0, max=0xffffffffffffffff), ], outputs=[ IO.Latent.Output(), ] ) @classmethod - def execute(cls, batch_size): + def execute(cls, batch_size, batch_index_start, seed): in_channels = 8 resolution = 16 - latent = torch.randn(1, in_channels, resolution, resolution, resolution).repeat(batch_size, 1, 1, 1, 1) - output = {"samples": latent, "type": "trellis2"} - if batch_size > 1: - output["batch_index"] = [0] * batch_size + sample_indices = [int(batch_index_start) + i for i in range(batch_size)] + latent = torch.zeros(batch_size, in_channels, resolution, resolution, resolution) + for i, sample_index in enumerate(sample_indices): + generator = torch.Generator(device="cpu") + generator.manual_seed(int(seed) + sample_index) + latent[i] = torch.randn(1, in_channels, resolution, resolution, resolution, generator=generator)[0] + output = { + "samples": latent, + "type": "trellis2", + } + if batch_size > 1 or batch_index_start != 0: + output["batch_index"] = sample_indices return IO.NodeOutput(output) def simplify_fn(vertices, faces, colors=None, target=100000): diff --git a/tests-unit/comfy_extras_test/nodes_trellis2_test.py b/tests-unit/comfy_extras_test/nodes_trellis2_test.py index 920eca471..95f64d031 100644 --- a/tests-unit/comfy_extras_test/nodes_trellis2_test.py +++ b/tests-unit/comfy_extras_test/nodes_trellis2_test.py @@ -123,5 +123,88 @@ class TestRunConditioningRestore(unittest.TestCase): self.assertFalse(hasattr(inner_model, "image_size")) +class DummyCloneModel: + def __init__(self): + self.model_options = {} + + def clone(self): + cloned = DummyCloneModel() + cloned.model_options = self.model_options.copy() + return cloned + + +class TestTrellisBatchSemantics(unittest.TestCase): + def test_empty_structure_latent_is_deterministic_and_propagates_sample_indices(self): + batch_output = nodes_trellis2.EmptyStructureLatentTrellis2.execute(2, 0, 17)[0] + single_output = nodes_trellis2.EmptyStructureLatentTrellis2.execute(1, 5, 17)[0] + + expected_batch = torch.zeros(2, 8, 16, 16, 16) + expected_batch[0] = torch.randn(1, 8, 16, 16, 16, generator=torch.Generator(device="cpu").manual_seed(17))[0] + expected_batch[1] = torch.randn(1, 8, 16, 16, 16, generator=torch.Generator(device="cpu").manual_seed(18))[0] + expected_single = torch.randn(1, 8, 16, 16, 16, generator=torch.Generator(device="cpu").manual_seed(22)) + + self.assertTrue(torch.equal(batch_output["samples"], expected_batch)) + self.assertEqual(batch_output["batch_index"], [0, 1]) + self.assertTrue(torch.equal(single_output["samples"], expected_single)) + self.assertEqual(single_output["batch_index"], [5]) + + def test_empty_shape_latent_is_deterministic_and_propagates_batch_index(self): + coords = torch.tensor( + [ + [1, 5, 5, 5], + [0, 1, 1, 1], + [1, 6, 6, 6], + [0, 2, 2, 2], + [1, 7, 7, 7], + ], + dtype=torch.int32, + ) + structure = { + "coords": coords, + "coord_counts": torch.tensor([2, 3], dtype=torch.int64), + "batch_index": [4, 9], + } + + output, _ = nodes_trellis2.EmptyShapeLatentTrellis2.execute(structure, DummyCloneModel(), 23) + + expected = torch.zeros(2, 32, 3, 1) + expected[0, :, :2, :] = torch.randn(1, 32, 2, 1, generator=torch.Generator(device="cpu").manual_seed(27))[0] + expected[1, :, :3, :] = torch.randn(1, 32, 3, 1, generator=torch.Generator(device="cpu").manual_seed(32))[0] + + self.assertTrue(torch.equal(output["samples"], expected)) + self.assertTrue(torch.equal(output["coord_counts"], torch.tensor([2, 3], dtype=torch.int64))) + self.assertEqual(output["batch_index"], [4, 9]) + + def test_empty_shape_latent_keeps_singleton_coord_counts(self): + structure = { + "coords": torch.tensor( + [ + [0, 1, 1, 1], + [0, 2, 2, 2], + ], + dtype=torch.int32, + ), + } + + output, _ = nodes_trellis2.EmptyShapeLatentTrellis2.execute(structure, DummyCloneModel(), 11) + + self.assertTrue(torch.equal(output["coord_counts"], torch.tensor([2], dtype=torch.int64))) + + def test_flatten_batched_sparse_latent_validates_coord_counts(self): + samples = torch.zeros(2, 32, 3, 1) + coords = torch.tensor( + [ + [0, 1, 1, 1], + [1, 2, 2, 2], + [1, 3, 3, 3], + ], + dtype=torch.int32, + ) + coord_counts = torch.tensor([2, 1], dtype=torch.int64) + + with self.assertRaises(ValueError): + nodes_trellis2.flatten_batched_sparse_latent(samples, coords, coord_counts) + + if __name__ == "__main__": unittest.main() diff --git a/tests-unit/comfy_test/sample_test.py b/tests-unit/comfy_test/sample_test.py new file mode 100644 index 000000000..ad154aca8 --- /dev/null +++ b/tests-unit/comfy_test/sample_test.py @@ -0,0 +1,47 @@ +import unittest + +import torch + +import comfy.sample + + +class TestPrepareNoiseInnerTrellis(unittest.TestCase): + def test_coord_counts_noise_matches_per_index_prefix_draws(self): + latent = torch.zeros(2, 4, 5, 1) + latent.trellis_coord_counts = torch.tensor([3, 5], dtype=torch.int64) + + generator = torch.Generator(device="cpu") + generator.manual_seed(123) + noise = comfy.sample.prepare_noise_inner(latent, generator) + + expected = torch.zeros_like(noise, dtype=torch.float32) + row0 = torch.Generator(device="cpu") + row0.manual_seed(123) + expected[0, :, :3, :] = torch.randn(1, 4, 3, 1, generator=row0)[0] + row1 = torch.Generator(device="cpu") + row1.manual_seed(124) + expected[1] = torch.randn(1, 4, 5, 1, generator=row1)[0] + + self.assertTrue(torch.equal(noise.float(), expected)) + self.assertTrue(torch.equal(noise[0, :, 3:, :], torch.zeros_like(noise[0, :, 3:, :]))) + + def test_coord_counts_noise_inds_share_prefixes_for_duplicates(self): + latent = torch.zeros(2, 4, 5, 1) + latent.trellis_coord_counts = torch.tensor([3, 5], dtype=torch.int64) + + generator = torch.Generator(device="cpu") + generator.manual_seed(456) + noise = comfy.sample.prepare_noise_inner(latent, generator, noise_inds=[7, 7]) + + replay = torch.Generator(device="cpu") + replay.manual_seed(463) + expected1 = torch.randn(1, 4, 5, 1, generator=replay) + expected0 = expected1[:, :, :3, :] + + self.assertTrue(torch.equal(noise[0:1, :, :3, :], expected0)) + self.assertTrue(torch.equal(noise[1:2, :, :5, :], expected1)) + self.assertTrue(torch.equal(noise[0, :, 3:, :], torch.zeros_like(noise[0, :, 3:, :]))) + + +if __name__ == "__main__": + unittest.main()