Validate Trellis coord_counts noise metadata

This commit is contained in:
John Pollock 2026-04-20 16:36:48 -05:00
parent 90ebb50f00
commit 33caec301a
2 changed files with 29 additions and 0 deletions

View File

@ -9,6 +9,16 @@ import comfy.nested_tensor
def prepare_noise_inner(latent_image, generator, noise_inds=None):
coord_counts = getattr(latent_image, "trellis_coord_counts", None)
if coord_counts is not None:
if coord_counts.ndim != 1:
raise ValueError(f"Trellis2 coord_counts must be 1D, got shape {tuple(coord_counts.shape)}")
if coord_counts.shape[0] != latent_image.size(0):
raise ValueError(
f"Trellis2 coord_counts length {coord_counts.shape[0]} does not match latent batch {latent_image.size(0)}"
)
if (coord_counts < 0).any() or (coord_counts > latent_image.size(2)).any():
raise ValueError(
f"Trellis2 coord_counts must be within [0, {latent_image.size(2)}], got {coord_counts.tolist()}"
)
noise = torch.zeros(latent_image.size(), dtype=torch.float32, layout=latent_image.layout, device="cpu")
if noise_inds is None:
noise_inds = np.arange(latent_image.size(0), dtype=np.int64)

View File

@ -52,6 +52,25 @@ class TestPrepareNoiseInnerTrellis(unittest.TestCase):
with self.assertRaises(ValueError):
comfy.sample.prepare_noise_inner(latent, generator, noise_inds=[7])
def test_coord_counts_metadata_must_match_batch_and_bounds(self):
generator = torch.Generator(device="cpu")
generator.manual_seed(456)
latent = torch.zeros(2, 4, 5, 1)
latent.trellis_coord_counts = torch.tensor([[3, 5]], dtype=torch.int64)
with self.assertRaises(ValueError):
comfy.sample.prepare_noise_inner(latent, generator)
latent = torch.zeros(2, 4, 5, 1)
latent.trellis_coord_counts = torch.tensor([3], dtype=torch.int64)
with self.assertRaises(ValueError):
comfy.sample.prepare_noise_inner(latent, generator)
latent = torch.zeros(2, 4, 5, 1)
latent.trellis_coord_counts = torch.tensor([3, 6], dtype=torch.int64)
with self.assertRaises(ValueError):
comfy.sample.prepare_noise_inner(latent, generator)
if __name__ == "__main__":
unittest.main()