From 33caec301a6f1a6ab4e802555e80a0e0c5e5c83c Mon Sep 17 00:00:00 2001 From: John Pollock Date: Mon, 20 Apr 2026 16:36:48 -0500 Subject: [PATCH] Validate Trellis coord_counts noise metadata --- comfy/sample.py | 10 ++++++++++ tests-unit/comfy_test/sample_test.py | 19 +++++++++++++++++++ 2 files changed, 29 insertions(+) diff --git a/comfy/sample.py b/comfy/sample.py index a4ce5f56f..878c4e984 100644 --- a/comfy/sample.py +++ b/comfy/sample.py @@ -9,6 +9,16 @@ import comfy.nested_tensor def prepare_noise_inner(latent_image, generator, noise_inds=None): coord_counts = getattr(latent_image, "trellis_coord_counts", None) if coord_counts is not None: + if coord_counts.ndim != 1: + raise ValueError(f"Trellis2 coord_counts must be 1D, got shape {tuple(coord_counts.shape)}") + if coord_counts.shape[0] != latent_image.size(0): + raise ValueError( + f"Trellis2 coord_counts length {coord_counts.shape[0]} does not match latent batch {latent_image.size(0)}" + ) + if (coord_counts < 0).any() or (coord_counts > latent_image.size(2)).any(): + raise ValueError( + f"Trellis2 coord_counts must be within [0, {latent_image.size(2)}], got {coord_counts.tolist()}" + ) noise = torch.zeros(latent_image.size(), dtype=torch.float32, layout=latent_image.layout, device="cpu") if noise_inds is None: noise_inds = np.arange(latent_image.size(0), dtype=np.int64) diff --git a/tests-unit/comfy_test/sample_test.py b/tests-unit/comfy_test/sample_test.py index e76e65266..227659994 100644 --- a/tests-unit/comfy_test/sample_test.py +++ b/tests-unit/comfy_test/sample_test.py @@ -52,6 +52,25 @@ class TestPrepareNoiseInnerTrellis(unittest.TestCase): with self.assertRaises(ValueError): comfy.sample.prepare_noise_inner(latent, generator, noise_inds=[7]) + def test_coord_counts_metadata_must_match_batch_and_bounds(self): + generator = torch.Generator(device="cpu") + generator.manual_seed(456) + + latent = torch.zeros(2, 4, 5, 1) + latent.trellis_coord_counts = torch.tensor([[3, 5]], dtype=torch.int64) + with self.assertRaises(ValueError): + comfy.sample.prepare_noise_inner(latent, generator) + + latent = torch.zeros(2, 4, 5, 1) + latent.trellis_coord_counts = torch.tensor([3], dtype=torch.int64) + with self.assertRaises(ValueError): + comfy.sample.prepare_noise_inner(latent, generator) + + latent = torch.zeros(2, 4, 5, 1) + latent.trellis_coord_counts = torch.tensor([3, 6], dtype=torch.int64) + with self.assertRaises(ValueError): + comfy.sample.prepare_noise_inner(latent, generator) + if __name__ == "__main__": unittest.main()