From 0b99c8c44acf964b9989b71439826d2582363238 Mon Sep 17 00:00:00 2001 From: John Pollock Date: Mon, 20 Apr 2026 15:50:40 -0500 Subject: [PATCH] Fail loud on Trellis invalid batch metadata --- comfy/sample.py | 4 ++++ comfy_extras/nodes_trellis2.py | 5 +++++ tests-unit/comfy_extras_test/nodes_trellis2_test.py | 9 +++++++++ tests-unit/comfy_test/sample_test.py | 10 ++++++++++ 4 files changed, 28 insertions(+) diff --git a/comfy/sample.py b/comfy/sample.py index 6fba221ed..8626269a1 100644 --- a/comfy/sample.py +++ b/comfy/sample.py @@ -14,6 +14,10 @@ def prepare_noise_inner(latent_image, generator, noise_inds=None): noise_inds = np.arange(latent_image.size(0), dtype=np.int64) else: noise_inds = np.asarray(noise_inds, dtype=np.int64) + if noise_inds.shape[0] != latent_image.size(0): + raise ValueError( + f"Trellis2 noise_inds length {noise_inds.shape[0]} does not match latent batch {latent_image.size(0)}" + ) base_seed = int(generator.initial_seed()) unique_inds = np.unique(noise_inds) diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index ce184a946..328cec6e7 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -800,6 +800,11 @@ class EmptyTextureLatentTrellis2(IO.ComfyNode): elif isinstance(structure_or_coords, torch.Tensor) and structure_or_coords.ndim == 2: coords = structure_or_coords.int() + else: + raise ValueError( + "structure_or_coords must be a voxel input with data.ndim == 4, " + f'a dict containing "coords", or a 2D torch.Tensor; got {type(structure_or_coords).__name__}' + ) shape_batch_index = normalize_batch_index(shape_latent.get("batch_index")) shape_latent = shape_latent["samples"] diff --git a/tests-unit/comfy_extras_test/nodes_trellis2_test.py b/tests-unit/comfy_extras_test/nodes_trellis2_test.py index 196a88343..43647e793 100644 --- a/tests-unit/comfy_extras_test/nodes_trellis2_test.py +++ b/tests-unit/comfy_extras_test/nodes_trellis2_test.py @@ -224,6 +224,15 @@ class TestTrellisBatchSemantics(unittest.TestCase): 13, ) + def test_empty_texture_latent_rejects_invalid_structure_input(self): + with self.assertRaises(ValueError): + nodes_trellis2.EmptyTextureLatentTrellis2.execute( + "bad-input", + {"samples": torch.zeros(1, 32, 2, 1)}, + DummyCloneModel(), + 13, + ) + def test_flatten_batched_sparse_latent_validates_coord_counts(self): samples = torch.zeros(2, 32, 3, 1) coords = torch.tensor( diff --git a/tests-unit/comfy_test/sample_test.py b/tests-unit/comfy_test/sample_test.py index ad154aca8..e76e65266 100644 --- a/tests-unit/comfy_test/sample_test.py +++ b/tests-unit/comfy_test/sample_test.py @@ -42,6 +42,16 @@ class TestPrepareNoiseInnerTrellis(unittest.TestCase): self.assertTrue(torch.equal(noise[1:2, :, :5, :], expected1)) self.assertTrue(torch.equal(noise[0, :, 3:, :], torch.zeros_like(noise[0, :, 3:, :]))) + def test_coord_counts_noise_inds_length_must_match_batch(self): + latent = torch.zeros(2, 4, 5, 1) + latent.trellis_coord_counts = torch.tensor([3, 5], dtype=torch.int64) + + generator = torch.Generator(device="cpu") + generator.manual_seed(456) + + with self.assertRaises(ValueError): + comfy.sample.prepare_noise_inner(latent, generator, noise_inds=[7]) + if __name__ == "__main__": unittest.main()