diff --git a/comfy/ldm/trellis2/model.py b/comfy/ldm/trellis2/model.py index 1c5d6c3ec..e8ed39aed 100644 --- a/comfy/ldm/trellis2/model.py +++ b/comfy/ldm/trellis2/model.py @@ -786,6 +786,7 @@ class Trellis2(nn.Module): # 32 -> 512px path, 64 -> 1024px path. uses_1024_conditioning = self.img2shape.resolution == 64 coords = transformer_options.get("coords", None) + coord_counts = transformer_options.get("coord_counts") mode = transformer_options.get("generation_mode", "structure_generation") is_512_run = False timestep = timestep.to(self.dtype) @@ -811,15 +812,33 @@ class Trellis2(nn.Module): cond = context shape_rule = sigmas < self.guidance_interval[0] or sigmas > self.guidance_interval[1] txt_rule = sigmas < self.guidance_interval_txt[0] or sigmas > self.guidance_interval_txt[1] + dense_out = None + cond_or_uncond = transformer_options.get("cond_or_uncond") or [] + + def cond_group_indices(batch_groups): + if len(cond_or_uncond) == batch_groups: + cond_groups = [i for i, marker in enumerate(cond_or_uncond) if marker == 0] + if len(cond_groups) > 0: + return cond_groups + return [batch_groups - 1] if not_struct_mode: orig_bsz = x.shape[0] rule = txt_rule if mode == "texture_generation" else shape_rule - if rule and orig_bsz > 1: - x_eval = x[1].unsqueeze(0) - t_eval = timestep[1].unsqueeze(0) if timestep.shape[0] > 1 else timestep - c_eval = cond + logical_batch = coord_counts.shape[0] if coord_counts is not None else 1 + if rule and orig_bsz > logical_batch: + batch_groups = orig_bsz // logical_batch + selected_groups = cond_group_indices(batch_groups) + x_groups = x.reshape(batch_groups, logical_batch, *x.shape[1:]) + x_eval = x_groups[selected_groups].reshape(-1, *x.shape[1:]) + if timestep.shape[0] > 1: + t_groups = timestep.reshape(batch_groups, logical_batch, *timestep.shape[1:]) + t_eval = t_groups[selected_groups].reshape(-1, *timestep.shape[1:]) + else: + t_eval = timestep + c_groups = context.reshape(batch_groups, logical_batch, *context.shape[1:]) + c_eval = c_groups[selected_groups].reshape(-1, *context.shape[1:]) else: x_eval = x t_eval = timestep @@ -828,23 +847,107 @@ class Trellis2(nn.Module): B, N, C = x_eval.shape if mode in ["shape_generation", "texture_generation"]: - feats_flat = x_eval.reshape(-1, C) + if coord_counts is not None: + logical_batch = coord_counts.shape[0] + if B % logical_batch != 0: + raise ValueError( + f"Trellis2 coord_counts batch {logical_batch} doesn't divide latent batch {B}" + ) + if int(coord_counts.sum().item()) != coords.shape[0]: + raise ValueError( + f"Trellis2 coord_counts total {int(coord_counts.sum().item())} does not match coords rows {coords.shape[0]}" + ) + batch_ids = coords[:, 0].to(torch.int64) + order = torch.argsort(batch_ids, stable=True) + sorted_coords = coords.index_select(0, order) + sorted_batch_ids = batch_ids.index_select(0, order) + offsets = coord_counts.cumsum(0) - coord_counts + coords_by_batch = [] + for i in range(logical_batch): + count = int(coord_counts[i].item()) + start = int(offsets[i].item()) + coords_i = sorted_coords[start:start + count] + ids_i = sorted_batch_ids[start:start + count] + if coords_i.shape[0] != count or not torch.all(ids_i == i): + raise ValueError( + f"Trellis2 coords rows for batch {i} expected {count}, got {coords_i.shape[0]}" + ) + coords_by_batch.append(coords_i) + repeat_factor = B // logical_batch + sparse_outs = [] + active_coord_counts = [] + for rep in range(repeat_factor): + for i in range(logical_batch): + out_index = rep * logical_batch + i + count = int(coord_counts[i].item()) + if count > N: + raise ValueError( + f"Trellis2 coord count {count} exceeds latent token dimension {N} for batch {i}" + ) + coords_i = coords_by_batch[i].clone() + coords_i[:, 0] = 0 + feats_i = x_eval[out_index, :count].clone() + x_st_i = SparseTensor(feats=feats_i, coords=coords_i.to(torch.int32)) + t_i = t_eval[out_index].unsqueeze(0).clone() if t_eval.shape[0] > 1 else t_eval + c_i = c_eval[out_index].unsqueeze(0).clone() if c_eval.shape[0] > 1 else c_eval - # inflate coords [N, 4] -> [B*N, 4] - coords_list = [] - for i in range(B): - c = coords.clone() - c[:, 0] = i - coords_list.append(c) + if mode == "shape_generation": + if is_512_run: + sparse_out = self.img2shape_512(x_st_i, t_i, c_i) + else: + sparse_out = self.img2shape(x_st_i, t_i, c_i) + else: + slat = transformer_options.get("shape_slat") + if slat is None: + raise ValueError("shape_slat can't be None") + if slat.ndim == 3: + if slat.shape[0] != logical_batch: + raise ValueError( + f"shape_slat batch {slat.shape[0]} doesn't match coord_counts batch {logical_batch}" + ) + if slat.shape[1] < count: + raise ValueError( + f"shape_slat tokens {slat.shape[1]} can't cover coord count {count} for batch {i}" + ) + slat_feats = slat[i, :count].to(x_st_i.device) + else: + slat_feats = slat[:count].to(x_st_i.device) + x_st_i = x_st_i.replace(feats=torch.cat([x_st_i.feats, slat_feats], dim=-1)) + sparse_out = self.shape2txt(x_st_i, t_i, c_i) - batched_coords = torch.cat(coords_list, dim=0) + sparse_outs.append(sparse_out.feats) + active_coord_counts.append(count) + + out_channels = sparse_outs[0].shape[-1] + padded = sparse_outs[0].new_zeros((B, N, out_channels)) + for out_index, (count, feats_i) in enumerate(zip(active_coord_counts, sparse_outs)): + padded[out_index, :count] = feats_i + dense_out = padded.transpose(1, 2).unsqueeze(-1) + elif coords.shape[0] == N: + feats_flat = x_eval.reshape(-1, C) + coords_list = [] + for i in range(B): + c = coords.clone() + c[:, 0] = i + coords_list.append(c) + batched_coords = torch.cat(coords_list, dim=0) + elif coords.shape[0] == B * N: + feats_flat = x_eval.reshape(-1, C) + batched_coords = coords + else: + raise ValueError( + f"Trellis2 expected coords rows {N} or {B * N}, got {coords.shape[0]}" + ) else: batched_coords = coords feats_flat = x_eval - x_st = SparseTensor(feats=feats_flat, coords=batched_coords.to(torch.int32)) + if dense_out is None: + x_st = SparseTensor(feats=feats_flat, coords=batched_coords.to(torch.int32)) - if mode == "shape_generation": + if dense_out is not None: + out = dense_out + elif mode == "shape_generation": if is_512_run: out = self.img2shape_512(x_st, t_eval, c_eval) else: @@ -856,23 +959,96 @@ class Trellis2(nn.Module): if slat is None: raise ValueError("shape_slat can't be None") - base_slat_feats = slat[:N] - slat_feats_batched = base_slat_feats.repeat(B, 1).to(x_st.device) + if slat.ndim == 3: + if coord_counts is not None: + logical_batch = coord_counts.shape[0] + if slat.shape[0] != logical_batch: + raise ValueError( + f"shape_slat batch {slat.shape[0]} doesn't match coord_counts batch {logical_batch}" + ) + if B % logical_batch != 0: + raise ValueError( + f"Trellis2 coord_counts batch {logical_batch} doesn't divide latent batch {B}" + ) + repeat_factor = B // logical_batch + slat_list = [] + for _ in range(repeat_factor): + for i in range(logical_batch): + count = int(coord_counts[i].item()) + if slat.shape[1] < count: + raise ValueError( + f"shape_slat tokens {slat.shape[1]} can't cover coord count {count} for batch {i}" + ) + slat_list.append(slat[i, :count]) + slat_feats_batched = torch.cat(slat_list, dim=0).to(x_st.device) + else: + if slat.shape[0] != B: + raise ValueError(f"shape_slat batch {slat.shape[0]} doesn't match latent batch {B}") + if slat.shape[1] != N: + raise ValueError(f"shape_slat tokens {slat.shape[1]} doesn't match latent tokens {N}") + slat_feats_batched = slat.reshape(B * N, -1).to(x_st.device) + else: + base_slat_feats = slat[:N] + slat_feats_batched = base_slat_feats.repeat(B, 1).to(x_st.device) x_st = x_st.replace(feats=torch.cat([x_st.feats, slat_feats_batched], dim=-1)) out = self.shape2txt(x_st, t_eval, c_eval) else: # structure orig_bsz = x.shape[0] - if shape_rule and orig_bsz > 1: - half = orig_bsz // 2 - x = x[half:] - timestep = timestep[half:] if timestep.shape[0] > 1 else timestep - out = self.structure_model(x, timestep, cond if shape_rule and orig_bsz > 1 else context) - if shape_rule and orig_bsz > 1: - out = out.repeat(2, 1, 1, 1, 1) + batch_groups = len(cond_or_uncond) if len(cond_or_uncond) > 0 and orig_bsz % len(cond_or_uncond) == 0 else 1 + logical_batch = orig_bsz // batch_groups + if logical_batch > 1: + x_groups = x.reshape(batch_groups, logical_batch, *x.shape[1:]) + if timestep.shape[0] > 1: + t_groups = timestep.reshape(batch_groups, logical_batch, *timestep.shape[1:]) + else: + t_groups = timestep + c_groups = context.reshape(batch_groups, logical_batch, *context.shape[1:]) + + if shape_rule and batch_groups > 1: + selected_group_indices = cond_group_indices(batch_groups) + else: + selected_group_indices = list(range(batch_groups)) + + out_groups = [] + for sample_index in range(logical_batch): + if shape_rule and batch_groups > 1: + x_i = x_groups[selected_group_indices, sample_index] + if timestep.shape[0] > 1: + t_i = t_groups[selected_group_indices, sample_index] + else: + t_i = timestep + c_i = c_groups[selected_group_indices, sample_index] + else: + x_i = x_groups[selected_group_indices, sample_index] + if timestep.shape[0] > 1: + t_i = t_groups[selected_group_indices, sample_index] + else: + t_i = timestep + c_i = c_groups[selected_group_indices, sample_index] + out_groups.append(self.structure_model(x_i, t_i, c_i)) + + out = out_groups[0].new_zeros((orig_bsz, *out_groups[0].shape[1:])) + for sample_index, out_sample in enumerate(out_groups): + if shape_rule and batch_groups > 1: + repeated = out_sample[0] + for group_index in range(batch_groups): + out[group_index * logical_batch + sample_index] = repeated + else: + for local_group_index, group_index in enumerate(selected_group_indices): + out[group_index * logical_batch + sample_index] = out_sample[local_group_index] + else: + if shape_rule and orig_bsz > 1: + half = orig_bsz // 2 + x = x[half:] + timestep = timestep[half:] if timestep.shape[0] > 1 else timestep + out = self.structure_model(x, timestep, cond if shape_rule and orig_bsz > 1 else context) + if shape_rule and orig_bsz > 1: + out = out.repeat(2, 1, 1, 1, 1) if not_struct_mode: - out = out.feats - out = out.view(B, N, -1).transpose(1, 2).unsqueeze(-1) - if rule and orig_bsz > 1: - out = out.repeat(orig_bsz, 1, 1, 1) + if dense_out is None: + out = out.feats + out = out.view(B, N, -1).transpose(1, 2).unsqueeze(-1) + if rule and orig_bsz > B: + out = out.repeat(orig_bsz // B, 1, 1, 1) return out diff --git a/comfy/sample.py b/comfy/sample.py index 653829582..878c4e984 100644 --- a/comfy/sample.py +++ b/comfy/sample.py @@ -7,6 +7,50 @@ import logging import comfy.nested_tensor def prepare_noise_inner(latent_image, generator, noise_inds=None): + coord_counts = getattr(latent_image, "trellis_coord_counts", None) + if coord_counts is not None: + if coord_counts.ndim != 1: + raise ValueError(f"Trellis2 coord_counts must be 1D, got shape {tuple(coord_counts.shape)}") + if coord_counts.shape[0] != latent_image.size(0): + raise ValueError( + f"Trellis2 coord_counts length {coord_counts.shape[0]} does not match latent batch {latent_image.size(0)}" + ) + if (coord_counts < 0).any() or (coord_counts > latent_image.size(2)).any(): + raise ValueError( + f"Trellis2 coord_counts must be within [0, {latent_image.size(2)}], got {coord_counts.tolist()}" + ) + noise = torch.zeros(latent_image.size(), dtype=torch.float32, layout=latent_image.layout, device="cpu") + if noise_inds is None: + noise_inds = np.arange(latent_image.size(0), dtype=np.int64) + else: + noise_inds = np.asarray(noise_inds, dtype=np.int64) + if noise_inds.shape[0] != latent_image.size(0): + raise ValueError( + f"Trellis2 noise_inds length {noise_inds.shape[0]} does not match latent batch {latent_image.size(0)}" + ) + + base_seed = int(generator.initial_seed()) + unique_inds = np.unique(noise_inds) + sample_noises = {} + for noise_index in unique_inds.tolist(): + rows = np.flatnonzero(noise_inds == noise_index) + max_count = max(int(coord_counts[row].item()) for row in rows.tolist()) + local_generator = torch.Generator(device="cpu") + local_generator.manual_seed(base_seed + int(noise_index)) + sample_noises[int(noise_index)] = torch.randn( + [1, latent_image.size(1), max_count, latent_image.size(3)], + dtype=torch.float32, + layout=latent_image.layout, + generator=local_generator, + device="cpu", + ) + + for batch_index, noise_index in enumerate(noise_inds.tolist()): + count = int(coord_counts[batch_index].item()) + sample_noise = sample_noises[int(noise_index)] + noise[batch_index:batch_index + 1, :, :count, :] = sample_noise[:, :, :count, :] + return noise.to(dtype=latent_image.dtype) + if noise_inds is None: return torch.randn(latent_image.size(), dtype=torch.float32, layout=latent_image.layout, generator=generator, device="cpu").to(dtype=latent_image.dtype) diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index 8121e261b..56ec3e736 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -96,6 +96,114 @@ def shape_norm(shape_latent, coords): samples = samples * std + mean return samples + +def infer_batched_coord_layout(coords): + if coords.ndim != 2 or coords.shape[1] != 4: + raise ValueError(f"Expected Trellis2 coords with shape [N, 4], got {tuple(coords.shape)}") + + if coords.shape[0] == 0: + raise ValueError("Trellis2 coords can't be empty") + + batch_ids = coords[:, 0].to(torch.int64) + if (batch_ids < 0).any(): + raise ValueError(f"Trellis2 batch ids must be non-negative, got {batch_ids.unique(sorted=True).tolist()}") + batch_size = int(batch_ids.max().item()) + 1 + counts = torch.bincount(batch_ids, minlength=batch_size) + + if (counts == 0).any(): + raise ValueError(f"Non-contiguous Trellis2 batch ids in coords: {batch_ids.unique(sorted=True).tolist()}") + + max_tokens = int(counts.max().item()) + return batch_size, counts, max_tokens + + +def split_batched_coords(coords, coord_counts): + if coord_counts.ndim != 1: + raise ValueError(f"Trellis2 coord_counts must be 1D, got shape {tuple(coord_counts.shape)}") + if (coord_counts < 0).any(): + raise ValueError(f"Trellis2 coord_counts must be non-negative, got {coord_counts.tolist()}") + if int(coord_counts.sum().item()) != coords.shape[0]: + raise ValueError( + f"Trellis2 coord_counts total {int(coord_counts.sum().item())} does not match coords rows {coords.shape[0]}" + ) + + batch_ids = coords[:, 0].to(torch.int64) + order = torch.argsort(batch_ids, stable=True) + sorted_coords = coords.index_select(0, order) + sorted_batch_ids = batch_ids.index_select(0, order) + + offsets = coord_counts.cumsum(0) - coord_counts + items = [] + for i in range(coord_counts.shape[0]): + count = int(coord_counts[i].item()) + start = int(offsets[i].item()) + coords_i = sorted_coords[start:start + count] + ids_i = sorted_batch_ids[start:start + count] + if coords_i.shape[0] != count or not torch.all(ids_i == i): + raise ValueError(f"Trellis2 coords rows for batch {i} expected {count}, got {coords_i.shape[0]}") + items.append(coords_i) + return items + + +def normalize_batch_index(batch_index): + if batch_index is None: + return None + if isinstance(batch_index, int): + return [int(batch_index)] + return list(batch_index) + + +def resolve_sample_indices(batch_index, batch_size): + sample_indices = normalize_batch_index(batch_index) + if sample_indices is None: + return list(range(batch_size)) + if len(sample_indices) != batch_size: + raise ValueError( + f"Trellis2 batch_index length {len(sample_indices)} does not match batch size {batch_size}" + ) + return sample_indices + + +def resolve_singleton_sample_index(batch_index): + sample_indices = normalize_batch_index(batch_index) + if sample_indices is None: + return 0 + if len(sample_indices) != 1: + raise ValueError( + f"Trellis2 batch_index must be an int or single-element iterable for singleton coords, got {sample_indices}" + ) + return int(sample_indices[0]) + + +def flatten_batched_sparse_latent(samples, coords, coord_counts): + samples = samples.squeeze(-1).transpose(1, 2) + if coord_counts is None: + return samples.reshape(-1, samples.shape[-1]), coords + + coords_items = split_batched_coords(coords, coord_counts) + feat_list = [] + coord_list = [] + for i, coords_i in enumerate(coords_items): + count = int(coord_counts[i].item()) + feat_list.append(samples[i, :count]) + coord_list.append(coords_i) + + return torch.cat(feat_list, dim=0), torch.cat(coord_list, dim=0) + + +def split_batched_sparse_latent(samples, coords, coord_counts): + samples = samples.squeeze(-1).transpose(1, 2) + if coord_counts is None: + return [(samples.reshape(-1, samples.shape[-1]), coords)] + + coords_items = split_batched_coords(coords, coord_counts) + items = [] + for i, coords_i in enumerate(coords_items): + count = int(coord_counts[i].item()) + items.append((samples[i, :count], coords_i)) + return items + + def paint_mesh_with_voxels(mesh, voxel_coords, voxel_colors, resolution): """ Generic function to paint a mesh using nearest-neighbor colors from a sparse voxel field. @@ -169,12 +277,32 @@ class VaeDecodeShapeTrellis(IO.ComfyNode): vae = vae.first_stage_model coords = samples["coords"] + coord_counts = samples.get("coord_counts") samples = samples["samples"] - samples = samples.squeeze(-1).transpose(1, 2).reshape(-1, 32).to(device) - samples = shape_norm(samples, coords) + if coord_counts is None: + samples, coords = flatten_batched_sparse_latent(samples, coords, coord_counts) + samples = shape_norm(samples.to(device), coords.to(device)) + mesh, subs = vae.decode_shape_slat(samples, resolution) + else: + split_items = split_batched_sparse_latent(samples, coords, coord_counts) + mesh = [] + subs_per_sample = [] + for feats_i, coords_i in split_items: + coords_i = coords_i.to(device).clone() + coords_i[:, 0] = 0 + sample_i = shape_norm(feats_i.to(device), coords_i) + mesh_i, subs_i = vae.decode_shape_slat(sample_i, resolution) + mesh.append(mesh_i[0]) + subs_per_sample.append(subs_i) + + subs = [] + for stage_index in range(len(subs_per_sample[0])): + stage_tensors = [sample_subs[stage_index] for sample_subs in subs_per_sample] + feats_list = [stage_tensor.feats for stage_tensor in stage_tensors] + coords_list = [stage_tensor.coords for stage_tensor in stage_tensors] + subs.append(SparseTensor.from_tensor_list(feats_list, coords_list)) - mesh, subs = vae.decode_shape_slat(samples, resolution) face_list = [m.faces for m in mesh] vert_list = [m.vertices for m in mesh] if all(v.shape == vert_list[0].shape for v in vert_list) and all(f.shape == face_list[0].shape for f in face_list): @@ -210,12 +338,14 @@ class VaeDecodeTextureTrellis(IO.ComfyNode): vae = vae.first_stage_model coords = samples["coords"] + coord_counts = samples.get("coord_counts") samples = samples["samples"] - samples = samples.squeeze(-1).transpose(1, 2).reshape(-1, 32).to(device) + samples, coords = flatten_batched_sparse_latent(samples, coords, coord_counts) + samples = samples.to(device) std = tex_slat_normalization["std"].to(samples) mean = tex_slat_normalization["mean"].to(samples) - samples = SparseTensor(feats = samples, coords=coords) + samples = SparseTensor(feats = samples, coords=coords.to(device)) samples = samples * std + mean voxel = vae.decode_tex_slat(samples, shape_subs) @@ -271,9 +401,16 @@ class VaeDecodeStructureTrellis2(IO.ComfyNode): load_device = comfy.model_management.get_torch_device() offload_device = comfy.model_management.vae_offload_device() decoder = decoder.to(load_device) + batch_index = normalize_batch_index(samples.get("batch_index")) samples = samples["samples"] samples = samples.to(load_device) - decoded = decoder(samples)>0 + if samples.shape[0] > 1: + decoded_items = [] + for i in range(samples.shape[0]): + decoded_items.append(decoder(samples[i:i + 1]) > 0) + decoded = torch.cat(decoded_items, dim=0) + else: + decoded = decoder(samples) > 0 decoder.to(offload_device) current_res = decoded.shape[2] @@ -281,6 +418,8 @@ class VaeDecodeStructureTrellis2(IO.ComfyNode): ratio = current_res // resolution decoded = torch.nn.functional.max_pool3d(decoded.float(), ratio, ratio, 0) > 0.5 out = Types.VOXEL(decoded.squeeze(1).float()) + if batch_index is not None: + out.batch_index = normalize_batch_index(batch_index) return IO.NodeOutput(out) class Trellis2UpsampleCascade(IO.ComfyNode): @@ -305,32 +444,93 @@ class Trellis2UpsampleCascade(IO.ComfyNode): device = comfy.model_management.get_torch_device() comfy.model_management.load_model_gpu(vae.patcher) - feats = shape_latent_512["samples"].squeeze(-1).transpose(1, 2).reshape(-1, 32).to(device) - coords_512 = shape_latent_512["coords"].to(device) - - slat = shape_norm(feats, coords_512) - + coord_counts = shape_latent_512.get("coord_counts") + batch_index = normalize_batch_index(shape_latent_512.get("batch_index")) decoder = vae.first_stage_model.shape_dec - - slat.feats = slat.feats.to(next(decoder.parameters()).dtype) - hr_coords = decoder.upsample(slat, upsample_times=4) - lr_resolution = 512 - hr_resolution = int(target_resolution) + target_resolution = int(target_resolution) + if coord_counts is None: + feats, coords_512 = flatten_batched_sparse_latent( + shape_latent_512["samples"], + shape_latent_512["coords"], + coord_counts, + ) + feats = feats.to(device) + coords_512 = coords_512.to(device) + slat = shape_norm(feats, coords_512) + slat.feats = slat.feats.to(next(decoder.parameters()).dtype) + hr_coords = decoder.upsample(slat, upsample_times=4) + + hr_resolution = target_resolution + while True: + quant_coords = torch.cat([ + hr_coords[:, :1], + ((hr_coords[:, 1:] + 0.5) / lr_resolution * (hr_resolution // 16)).int(), + ], dim=1) + final_coords = quant_coords.unique(dim=0) + num_tokens = final_coords.shape[0] + + if num_tokens < max_tokens or hr_resolution <= 1024: + break + hr_resolution -= 128 + + return IO.NodeOutput(final_coords,) + + items = split_batched_sparse_latent( + shape_latent_512["samples"], + shape_latent_512["coords"], + coord_counts, + ) + decoder_dtype = next(decoder.parameters()).dtype + + sample_hr_coords = [] + for feats_i, coords_i in items: + feats_i = feats_i.to(device) + coords_i = coords_i.to(device).clone() + coords_i[:, 0] = 0 + slat_i = shape_norm(feats_i, coords_i) + slat_i.feats = slat_i.feats.to(decoder_dtype) + sample_hr_coords.append(decoder.upsample(slat_i, upsample_times=4)) + + hr_resolution = target_resolution while True: - quant_coords = torch.cat([ - hr_coords[:, :1], - ((hr_coords[:, 1:] + 0.5) / lr_resolution * (hr_resolution // 16)).int(), - ], dim=1) - final_coords = quant_coords.unique(dim=0) - num_tokens = final_coords.shape[0] - - if num_tokens < max_tokens or hr_resolution <= 1024: + exceeds_limit = False + for hr_coords_i in sample_hr_coords: + quant_coords_i = torch.cat([ + hr_coords_i[:, :1], + ((hr_coords_i[:, 1:] + 0.5) / lr_resolution * (hr_resolution // 16)).int(), + ], dim=1) + if quant_coords_i.unique(dim=0).shape[0] >= max_tokens: + exceeds_limit = True + break + if not exceeds_limit or hr_resolution <= 1024: break hr_resolution -= 128 - return IO.NodeOutput(final_coords,) + final_coords_list = [] + output_coord_counts = [] + for sample_offset, hr_coords_i in enumerate(sample_hr_coords): + quant_coords_i = torch.cat([ + hr_coords_i[:, :1], + ((hr_coords_i[:, 1:] + 0.5) / lr_resolution * (hr_resolution // 16)).int(), + ], dim=1) + final_coords_i = quant_coords_i.unique(dim=0) + final_coords_i = final_coords_i.clone() + final_coords_i[:, 0] = sample_offset + final_coords_list.append(final_coords_i) + output_coord_counts.append(int(final_coords_i.shape[0])) + + normalized_batch_index = normalize_batch_index(batch_index) + output = { + "coords": torch.cat(final_coords_list, dim=0), + "coord_counts": torch.tensor(output_coord_counts, dtype=torch.int64), + "resolutions": torch.full((len(final_coords_list),), int(hr_resolution), dtype=torch.int64), + } + if normalized_batch_index is not None: + output["batch_index"] = normalized_batch_index + + return IO.NodeOutput(output,) dino_mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1) dino_std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1) @@ -484,7 +684,8 @@ class EmptyShapeLatentTrellis2(IO.ComfyNode): category="latent/3d", inputs=[ IO.AnyType.Input("structure_or_coords"), - IO.Model.Input("model") + IO.Model.Input("model"), + IO.Int.Input("seed", default=0, min=0, max=0xffffffffffffffff), ], outputs=[ IO.Latent.Output(), @@ -493,14 +694,25 @@ class EmptyShapeLatentTrellis2(IO.ComfyNode): ) @classmethod - def execute(cls, structure_or_coords, model): + def execute(cls, structure_or_coords, model, seed): # to accept the upscaled coords is_512_pass = False + coord_counts = None + coord_resolutions = None + batch_index = None if hasattr(structure_or_coords, "data") and structure_or_coords.data.ndim == 4: decoded = structure_or_coords.data.unsqueeze(1) coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int() is_512_pass = True + batch_index = normalize_batch_index(getattr(structure_or_coords, "batch_index", None)) + + elif isinstance(structure_or_coords, dict): + coords = structure_or_coords["coords"].int() + coord_counts = structure_or_coords.get("coord_counts") + coord_resolutions = structure_or_coords.get("resolutions") + batch_index = normalize_batch_index(structure_or_coords.get("batch_index")) + is_512_pass = False elif isinstance(structure_or_coords, torch.Tensor) and structure_or_coords.ndim == 2: coords = structure_or_coords.int() @@ -509,8 +721,31 @@ class EmptyShapeLatentTrellis2(IO.ComfyNode): else: raise ValueError(f"Invalid input to EmptyShapeLatent: {type(structure_or_coords)}") in_channels = 32 - # image like format - latent = torch.randn(1, in_channels, coords.shape[0], 1) + batch_size, inferred_coord_counts, max_tokens = infer_batched_coord_layout(coords) + if coord_counts is not None: + coord_counts = coord_counts.to(dtype=torch.int64, device=coords.device) + if coord_counts.shape != inferred_coord_counts.shape or not torch.equal(coord_counts, inferred_coord_counts): + raise ValueError( + f"Trellis2 coord_counts metadata {coord_counts.tolist()} does not match coords layout {inferred_coord_counts.tolist()}" + ) + else: + coord_counts = inferred_coord_counts + if batch_size == 1: + sample_index = resolve_singleton_sample_index(batch_index) + generator = torch.Generator(device="cpu") + generator.manual_seed(int(seed) + sample_index) + latent = torch.randn(1, in_channels, coords.shape[0], 1, generator=generator) + else: + sample_indices = resolve_sample_indices(batch_index, batch_size) + latent = torch.zeros(batch_size, in_channels, max_tokens, 1) + for i, sample_index in enumerate(sample_indices): + count = int(coord_counts[i].item()) + generator = torch.Generator(device="cpu") + generator.manual_seed(int(seed) + int(sample_index)) + latent_i = torch.randn(1, in_channels, count, 1, generator=generator) + latent[i, :, :count] = latent_i[0] + if coord_counts is not None: + latent.trellis_coord_counts = coord_counts.clone() model = model.clone() model.model_options = model.model_options.copy() if "transformer_options" in model.model_options: @@ -519,11 +754,20 @@ class EmptyShapeLatentTrellis2(IO.ComfyNode): model.model_options["transformer_options"] = {} model.model_options["transformer_options"]["coords"] = coords + if coord_counts is not None: + model.model_options["transformer_options"]["coord_counts"] = coord_counts if is_512_pass: model.model_options["transformer_options"]["generation_mode"] = "shape_generation_512" else: model.model_options["transformer_options"]["generation_mode"] = "shape_generation" - return IO.NodeOutput({"samples": latent, "coords": coords, "type": "trellis2"}, model) + output = {"samples": latent, "coords": coords, "type": "trellis2"} + if batch_index is not None: + output["batch_index"] = normalize_batch_index(batch_index) + if coord_counts is not None: + output["coord_counts"] = coord_counts + if coord_resolutions is not None: + output["resolutions"] = coord_resolutions + return IO.NodeOutput(output, model) class EmptyTextureLatentTrellis2(IO.ComfyNode): @classmethod @@ -534,7 +778,8 @@ class EmptyTextureLatentTrellis2(IO.ComfyNode): inputs=[ IO.Voxel.Input("structure_or_coords"), IO.Latent.Input("shape_latent"), - IO.Model.Input("model") + IO.Model.Input("model"), + IO.Int.Input("seed", default=0, min=0, max=0xffffffffffffffff), ], outputs=[ IO.Latent.Output(), @@ -543,20 +788,68 @@ class EmptyTextureLatentTrellis2(IO.ComfyNode): ) @classmethod - def execute(cls, structure_or_coords, shape_latent, model): + def execute(cls, structure_or_coords, shape_latent, model, seed): channels = 32 + coord_counts = None + batch_index = None if hasattr(structure_or_coords, "data") and structure_or_coords.data.ndim == 4: decoded = structure_or_coords.data.unsqueeze(1) coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int() + batch_index = normalize_batch_index(getattr(structure_or_coords, "batch_index", None)) + + elif isinstance(structure_or_coords, dict): + coords = structure_or_coords["coords"].int() + coord_counts = structure_or_coords.get("coord_counts") + batch_index = normalize_batch_index(structure_or_coords.get("batch_index")) elif isinstance(structure_or_coords, torch.Tensor) and structure_or_coords.ndim == 2: coords = structure_or_coords.int() + else: + raise ValueError( + "structure_or_coords must be a voxel input with data.ndim == 4, " + f'a dict containing "coords", or a 2D torch.Tensor; got {type(structure_or_coords).__name__}' + ) + shape_batch_index = normalize_batch_index(shape_latent.get("batch_index")) + if batch_index is None: + batch_index = shape_batch_index shape_latent = shape_latent["samples"] + batch_size, inferred_coord_counts, max_tokens = infer_batched_coord_layout(coords) + if coord_counts is not None: + coord_counts = coord_counts.to(dtype=torch.int64, device=coords.device) + if coord_counts.shape != inferred_coord_counts.shape or not torch.equal(coord_counts, inferred_coord_counts): + raise ValueError( + f"Trellis2 coord_counts metadata {coord_counts.tolist()} does not match coords layout {inferred_coord_counts.tolist()}" + ) + else: + coord_counts = inferred_coord_counts if shape_latent.ndim == 4: - shape_latent = shape_latent.squeeze(-1).transpose(1, 2).reshape(-1, channels) + if shape_latent.shape[0] != batch_size: + raise ValueError( + f"shape_latent batch {shape_latent.shape[0]} doesn't match coords batch {batch_size}" + ) + shape_latent = shape_latent.squeeze(-1).transpose(1, 2) + if shape_latent.shape[1] < max_tokens: + raise ValueError( + f"shape_latent tokens {shape_latent.shape[1]} can't cover coords max tokens {max_tokens}" + ) - latent = torch.randn(1, channels, coords.shape[0], 1) + if batch_size == 1: + sample_index = resolve_singleton_sample_index(batch_index) + generator = torch.Generator(device="cpu") + generator.manual_seed(int(seed) + sample_index) + latent = torch.randn(1, channels, coords.shape[0], 1, generator=generator) + else: + sample_indices = resolve_sample_indices(batch_index, batch_size) + latent = torch.zeros(batch_size, channels, max_tokens, 1) + for i, sample_index in enumerate(sample_indices): + count = int(coord_counts[i].item()) + generator = torch.Generator(device="cpu") + generator.manual_seed(int(seed) + int(sample_index)) + latent_i = torch.randn(1, channels, count, 1, generator=generator) + latent[i, :, :count] = latent_i[0] + if coord_counts is not None: + latent.trellis_coord_counts = coord_counts.clone() model = model.clone() model.model_options = model.model_options.copy() if "transformer_options" in model.model_options: @@ -565,9 +858,16 @@ class EmptyTextureLatentTrellis2(IO.ComfyNode): model.model_options["transformer_options"] = {} model.model_options["transformer_options"]["coords"] = coords + if coord_counts is not None: + model.model_options["transformer_options"]["coord_counts"] = coord_counts model.model_options["transformer_options"]["generation_mode"] = "texture_generation" model.model_options["transformer_options"]["shape_slat"] = shape_latent - return IO.NodeOutput({"samples": latent, "coords": coords, "type": "trellis2"}, model) + output = {"samples": latent, "coords": coords, "type": "trellis2"} + if batch_index is not None: + output["batch_index"] = normalize_batch_index(batch_index) + if coord_counts is not None: + output["coord_counts"] = coord_counts + return IO.NodeOutput(output, model) class EmptyStructureLatentTrellis2(IO.ComfyNode): @@ -578,17 +878,30 @@ class EmptyStructureLatentTrellis2(IO.ComfyNode): category="latent/3d", inputs=[ IO.Int.Input("batch_size", default=1, min=1, max=4096, tooltip="The number of latent images in the batch."), + IO.Int.Input("batch_index_start", default=0, min=0, max=4096, tooltip="Starting sample index for per-sample sampler noise."), + IO.Int.Input("seed", default=0, min=0, max=0xffffffffffffffff), ], outputs=[ IO.Latent.Output(), ] ) @classmethod - def execute(cls, batch_size): + def execute(cls, batch_size, batch_index_start, seed): in_channels = 8 resolution = 16 - latent = torch.randn(batch_size, in_channels, resolution, resolution, resolution) - return IO.NodeOutput({"samples": latent, "type": "trellis2"}) + sample_indices = [int(batch_index_start) + i for i in range(batch_size)] + latent = torch.zeros(batch_size, in_channels, resolution, resolution, resolution) + for i, sample_index in enumerate(sample_indices): + generator = torch.Generator(device="cpu") + generator.manual_seed(int(seed) + sample_index) + latent[i] = torch.randn(1, in_channels, resolution, resolution, resolution, generator=generator)[0] + output = { + "samples": latent, + "type": "trellis2", + } + if batch_size > 1 or batch_index_start != 0: + output["batch_index"] = sample_indices + return IO.NodeOutput(output) def simplify_fn(vertices, faces, colors=None, target=100000): if vertices.ndim == 3: diff --git a/tests-unit/comfy_extras_test/nodes_trellis2_test.py b/tests-unit/comfy_extras_test/nodes_trellis2_test.py index 920eca471..49e872bc7 100644 --- a/tests-unit/comfy_extras_test/nodes_trellis2_test.py +++ b/tests-unit/comfy_extras_test/nodes_trellis2_test.py @@ -123,5 +123,203 @@ class TestRunConditioningRestore(unittest.TestCase): self.assertFalse(hasattr(inner_model, "image_size")) +class DummyCloneModel: + def __init__(self): + self.model_options = {} + + def clone(self): + cloned = DummyCloneModel() + cloned.model_options = self.model_options.copy() + return cloned + + +class TestTrellisBatchSemantics(unittest.TestCase): + def test_empty_structure_latent_is_deterministic_and_propagates_sample_indices(self): + batch_output = nodes_trellis2.EmptyStructureLatentTrellis2.execute(2, 0, 17)[0] + single_output = nodes_trellis2.EmptyStructureLatentTrellis2.execute(1, 5, 17)[0] + + expected_batch = torch.zeros(2, 8, 16, 16, 16) + expected_batch[0] = torch.randn(1, 8, 16, 16, 16, generator=torch.Generator(device="cpu").manual_seed(17))[0] + expected_batch[1] = torch.randn(1, 8, 16, 16, 16, generator=torch.Generator(device="cpu").manual_seed(18))[0] + expected_single = torch.randn(1, 8, 16, 16, 16, generator=torch.Generator(device="cpu").manual_seed(22)) + + self.assertTrue(torch.equal(batch_output["samples"], expected_batch)) + self.assertEqual(batch_output["batch_index"], [0, 1]) + self.assertTrue(torch.equal(single_output["samples"], expected_single)) + self.assertEqual(single_output["batch_index"], [5]) + + def test_empty_shape_latent_is_deterministic_and_propagates_batch_index(self): + coords = torch.tensor( + [ + [1, 5, 5, 5], + [0, 1, 1, 1], + [1, 6, 6, 6], + [0, 2, 2, 2], + [1, 7, 7, 7], + ], + dtype=torch.int32, + ) + structure = { + "coords": coords, + "coord_counts": torch.tensor([2, 3], dtype=torch.int64), + "batch_index": [4, 9], + } + + output, _ = nodes_trellis2.EmptyShapeLatentTrellis2.execute(structure, DummyCloneModel(), 23) + + expected = torch.zeros(2, 32, 3, 1) + expected[0, :, :2, :] = torch.randn(1, 32, 2, 1, generator=torch.Generator(device="cpu").manual_seed(27))[0] + expected[1, :, :3, :] = torch.randn(1, 32, 3, 1, generator=torch.Generator(device="cpu").manual_seed(32))[0] + + self.assertTrue(torch.equal(output["samples"], expected)) + self.assertTrue(torch.equal(output["coord_counts"], torch.tensor([2, 3], dtype=torch.int64))) + self.assertEqual(output["batch_index"], [4, 9]) + + def test_empty_shape_latent_keeps_singleton_coord_counts(self): + structure = { + "coords": torch.tensor( + [ + [0, 1, 1, 1], + [0, 2, 2, 2], + ], + dtype=torch.int32, + ), + } + + output, _ = nodes_trellis2.EmptyShapeLatentTrellis2.execute(structure, DummyCloneModel(), 11) + + self.assertTrue(torch.equal(output["coord_counts"], torch.tensor([2], dtype=torch.int64))) + + def test_empty_shape_latent_rejects_multi_index_singleton(self): + structure = { + "coords": torch.tensor( + [ + [0, 1, 1, 1], + [0, 2, 2, 2], + ], + dtype=torch.int32, + ), + "batch_index": [5, 6], + } + + with self.assertRaises(ValueError): + nodes_trellis2.EmptyShapeLatentTrellis2.execute(structure, DummyCloneModel(), 11) + + def test_empty_texture_latent_rejects_multi_index_singleton(self): + coords = torch.tensor( + [ + [0, 1, 1, 1], + [0, 2, 2, 2], + ], + dtype=torch.int32, + ) + structure = {"coords": coords, "batch_index": [7, 8]} + shape_latent = {"samples": torch.zeros(1, 32, 2, 1)} + + with self.assertRaises(ValueError): + nodes_trellis2.EmptyTextureLatentTrellis2.execute( + structure, + shape_latent, + DummyCloneModel(), + 13, + ) + + def test_empty_texture_latent_rejects_invalid_structure_input(self): + with self.assertRaises(ValueError): + nodes_trellis2.EmptyTextureLatentTrellis2.execute( + "bad-input", + {"samples": torch.zeros(1, 32, 2, 1)}, + DummyCloneModel(), + 13, + ) + + def test_empty_texture_latent_uses_shape_batch_index_for_seed_fallback(self): + coords = torch.tensor( + [ + [0, 1, 1, 1], + [1, 2, 2, 2], + [1, 3, 3, 3], + ], + dtype=torch.int32, + ) + structure = {"coords": coords} + shape_latent = { + "samples": torch.zeros(2, 32, 2, 1), + "batch_index": [4, 9], + } + + output, _ = nodes_trellis2.EmptyTextureLatentTrellis2.execute( + structure, + shape_latent, + DummyCloneModel(), + 13, + ) + + expected = torch.zeros(2, 32, 2, 1) + expected[0, :, :1, :] = torch.randn(1, 32, 1, 1, generator=torch.Generator(device="cpu").manual_seed(17))[0] + expected[1, :, :2, :] = torch.randn(1, 32, 2, 1, generator=torch.Generator(device="cpu").manual_seed(22))[0] + + self.assertTrue(torch.equal(output["samples"], expected)) + self.assertEqual(output["batch_index"], [4, 9]) + + def test_flatten_batched_sparse_latent_validates_coord_counts(self): + samples = torch.zeros(2, 32, 3, 1) + coords = torch.tensor( + [ + [0, 1, 1, 1], + [1, 2, 2, 2], + [1, 3, 3, 3], + ], + dtype=torch.int32, + ) + coord_counts = torch.tensor([2, 1], dtype=torch.int64) + + with self.assertRaises(ValueError): + nodes_trellis2.flatten_batched_sparse_latent(samples, coords, coord_counts) + + def test_infer_batched_coord_layout_rejects_negative_batch_ids(self): + coords = torch.tensor( + [ + [-1, 1, 1, 1], + [0, 2, 2, 2], + ], + dtype=torch.int32, + ) + + with self.assertRaises(ValueError): + nodes_trellis2.infer_batched_coord_layout(coords) + + def test_split_batched_coords_validates_total_count(self): + coords = torch.tensor( + [ + [0, 1, 1, 1], + [1, 2, 2, 2], + [1, 3, 3, 3], + ], + dtype=torch.int32, + ) + coord_counts = torch.tensor([1, 1], dtype=torch.int64) + + with self.assertRaises(ValueError): + nodes_trellis2.split_batched_coords(coords, coord_counts) + + def test_empty_shape_latent_preserves_resolutions_key(self): + structure = { + "coords": torch.tensor( + [ + [0, 1, 1, 1], + [0, 2, 2, 2], + ], + dtype=torch.int32, + ), + "resolutions": torch.tensor([1024], dtype=torch.int64), + } + + output, model = nodes_trellis2.EmptyShapeLatentTrellis2.execute(structure, DummyCloneModel(), 11) + + self.assertTrue(torch.equal(output["resolutions"], torch.tensor([1024], dtype=torch.int64))) + self.assertNotIn("coord_resolutions", model.model_options["transformer_options"]) + + if __name__ == "__main__": unittest.main() diff --git a/tests-unit/comfy_test/sample_test.py b/tests-unit/comfy_test/sample_test.py new file mode 100644 index 000000000..227659994 --- /dev/null +++ b/tests-unit/comfy_test/sample_test.py @@ -0,0 +1,76 @@ +import unittest + +import torch + +import comfy.sample + + +class TestPrepareNoiseInnerTrellis(unittest.TestCase): + def test_coord_counts_noise_matches_per_index_prefix_draws(self): + latent = torch.zeros(2, 4, 5, 1) + latent.trellis_coord_counts = torch.tensor([3, 5], dtype=torch.int64) + + generator = torch.Generator(device="cpu") + generator.manual_seed(123) + noise = comfy.sample.prepare_noise_inner(latent, generator) + + expected = torch.zeros_like(noise, dtype=torch.float32) + row0 = torch.Generator(device="cpu") + row0.manual_seed(123) + expected[0, :, :3, :] = torch.randn(1, 4, 3, 1, generator=row0)[0] + row1 = torch.Generator(device="cpu") + row1.manual_seed(124) + expected[1] = torch.randn(1, 4, 5, 1, generator=row1)[0] + + self.assertTrue(torch.equal(noise.float(), expected)) + self.assertTrue(torch.equal(noise[0, :, 3:, :], torch.zeros_like(noise[0, :, 3:, :]))) + + def test_coord_counts_noise_inds_share_prefixes_for_duplicates(self): + latent = torch.zeros(2, 4, 5, 1) + latent.trellis_coord_counts = torch.tensor([3, 5], dtype=torch.int64) + + generator = torch.Generator(device="cpu") + generator.manual_seed(456) + noise = comfy.sample.prepare_noise_inner(latent, generator, noise_inds=[7, 7]) + + replay = torch.Generator(device="cpu") + replay.manual_seed(463) + expected1 = torch.randn(1, 4, 5, 1, generator=replay) + expected0 = expected1[:, :, :3, :] + + self.assertTrue(torch.equal(noise[0:1, :, :3, :], expected0)) + self.assertTrue(torch.equal(noise[1:2, :, :5, :], expected1)) + self.assertTrue(torch.equal(noise[0, :, 3:, :], torch.zeros_like(noise[0, :, 3:, :]))) + + def test_coord_counts_noise_inds_length_must_match_batch(self): + latent = torch.zeros(2, 4, 5, 1) + latent.trellis_coord_counts = torch.tensor([3, 5], dtype=torch.int64) + + generator = torch.Generator(device="cpu") + generator.manual_seed(456) + + with self.assertRaises(ValueError): + comfy.sample.prepare_noise_inner(latent, generator, noise_inds=[7]) + + def test_coord_counts_metadata_must_match_batch_and_bounds(self): + generator = torch.Generator(device="cpu") + generator.manual_seed(456) + + latent = torch.zeros(2, 4, 5, 1) + latent.trellis_coord_counts = torch.tensor([[3, 5]], dtype=torch.int64) + with self.assertRaises(ValueError): + comfy.sample.prepare_noise_inner(latent, generator) + + latent = torch.zeros(2, 4, 5, 1) + latent.trellis_coord_counts = torch.tensor([3], dtype=torch.int64) + with self.assertRaises(ValueError): + comfy.sample.prepare_noise_inner(latent, generator) + + latent = torch.zeros(2, 4, 5, 1) + latent.trellis_coord_counts = torch.tensor([3, 6], dtype=torch.int64) + with self.assertRaises(ValueError): + comfy.sample.prepare_noise_inner(latent, generator) + + +if __name__ == "__main__": + unittest.main()