From 90ebb50f00bf89ed8c947a0e4ed4ed0803981ea1 Mon Sep 17 00:00:00 2001 From: John Pollock Date: Mon, 20 Apr 2026 16:05:10 -0500 Subject: [PATCH] Harden Trellis sparse latent seeding --- comfy/ldm/trellis2/model.py | 4 +++ comfy/sample.py | 2 -- comfy_extras/nodes_trellis2.py | 4 +-- .../comfy_extras_test/nodes_trellis2_test.py | 29 +++++++++++++++++++ 4 files changed, 35 insertions(+), 4 deletions(-) diff --git a/comfy/ldm/trellis2/model.py b/comfy/ldm/trellis2/model.py index 7cf3e728e..e8ed39aed 100644 --- a/comfy/ldm/trellis2/model.py +++ b/comfy/ldm/trellis2/model.py @@ -880,6 +880,10 @@ class Trellis2(nn.Module): for i in range(logical_batch): out_index = rep * logical_batch + i count = int(coord_counts[i].item()) + if count > N: + raise ValueError( + f"Trellis2 coord count {count} exceeds latent token dimension {N} for batch {i}" + ) coords_i = coords_by_batch[i].clone() coords_i[:, 0] = 0 feats_i = x_eval[out_index, :count].clone() diff --git a/comfy/sample.py b/comfy/sample.py index 8626269a1..a4ce5f56f 100644 --- a/comfy/sample.py +++ b/comfy/sample.py @@ -74,8 +74,6 @@ def prepare_noise(latent_image, seed, noise_inds=None): def fix_empty_latent_channels(model, latent_image, downscale_ratio_spacial=None): if latent_image.is_nested: return latent_image - if getattr(latent_image, "trellis_skip_empty_fix", False): - return latent_image latent_format = model.get_model_object("latent_format") #Resize the empty latent image so it has the right number of channels if torch.count_nonzero(latent_image) == 0: if latent_format.latent_channels != latent_image.shape[1]: diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index 328cec6e7..d345641b1 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -807,6 +807,8 @@ class EmptyTextureLatentTrellis2(IO.ComfyNode): ) shape_batch_index = normalize_batch_index(shape_latent.get("batch_index")) + if batch_index is None: + batch_index = shape_batch_index shape_latent = shape_latent["samples"] batch_size, inferred_coord_counts, max_tokens = infer_batched_coord_layout(coords) if coord_counts is not None: @@ -844,8 +846,6 @@ class EmptyTextureLatentTrellis2(IO.ComfyNode): latent[i, :, :count] = latent_i[0] if coord_counts is not None: latent.trellis_coord_counts = coord_counts.clone() - if batch_index is None: - batch_index = shape_batch_index model = model.clone() model.model_options = model.model_options.copy() if "transformer_options" in model.model_options: diff --git a/tests-unit/comfy_extras_test/nodes_trellis2_test.py b/tests-unit/comfy_extras_test/nodes_trellis2_test.py index 43647e793..49e872bc7 100644 --- a/tests-unit/comfy_extras_test/nodes_trellis2_test.py +++ b/tests-unit/comfy_extras_test/nodes_trellis2_test.py @@ -233,6 +233,35 @@ class TestTrellisBatchSemantics(unittest.TestCase): 13, ) + def test_empty_texture_latent_uses_shape_batch_index_for_seed_fallback(self): + coords = torch.tensor( + [ + [0, 1, 1, 1], + [1, 2, 2, 2], + [1, 3, 3, 3], + ], + dtype=torch.int32, + ) + structure = {"coords": coords} + shape_latent = { + "samples": torch.zeros(2, 32, 2, 1), + "batch_index": [4, 9], + } + + output, _ = nodes_trellis2.EmptyTextureLatentTrellis2.execute( + structure, + shape_latent, + DummyCloneModel(), + 13, + ) + + expected = torch.zeros(2, 32, 2, 1) + expected[0, :, :1, :] = torch.randn(1, 32, 1, 1, generator=torch.Generator(device="cpu").manual_seed(17))[0] + expected[1, :, :2, :] = torch.randn(1, 32, 2, 1, generator=torch.Generator(device="cpu").manual_seed(22))[0] + + self.assertTrue(torch.equal(output["samples"], expected)) + self.assertEqual(output["batch_index"], [4, 9]) + def test_flatten_batched_sparse_latent_validates_coord_counts(self): samples = torch.zeros(2, 32, 3, 1) coords = torch.tensor(