Fail loud on Trellis invalid batch metadata

This commit is contained in:
John Pollock 2026-04-20 15:50:40 -05:00
parent a752dd4736
commit 0b99c8c44a
4 changed files with 28 additions and 0 deletions

View File

@ -14,6 +14,10 @@ def prepare_noise_inner(latent_image, generator, noise_inds=None):
noise_inds = np.arange(latent_image.size(0), dtype=np.int64) noise_inds = np.arange(latent_image.size(0), dtype=np.int64)
else: else:
noise_inds = np.asarray(noise_inds, dtype=np.int64) noise_inds = np.asarray(noise_inds, dtype=np.int64)
if noise_inds.shape[0] != latent_image.size(0):
raise ValueError(
f"Trellis2 noise_inds length {noise_inds.shape[0]} does not match latent batch {latent_image.size(0)}"
)
base_seed = int(generator.initial_seed()) base_seed = int(generator.initial_seed())
unique_inds = np.unique(noise_inds) unique_inds = np.unique(noise_inds)

View File

@ -800,6 +800,11 @@ class EmptyTextureLatentTrellis2(IO.ComfyNode):
elif isinstance(structure_or_coords, torch.Tensor) and structure_or_coords.ndim == 2: elif isinstance(structure_or_coords, torch.Tensor) and structure_or_coords.ndim == 2:
coords = structure_or_coords.int() coords = structure_or_coords.int()
else:
raise ValueError(
"structure_or_coords must be a voxel input with data.ndim == 4, "
f'a dict containing "coords", or a 2D torch.Tensor; got {type(structure_or_coords).__name__}'
)
shape_batch_index = normalize_batch_index(shape_latent.get("batch_index")) shape_batch_index = normalize_batch_index(shape_latent.get("batch_index"))
shape_latent = shape_latent["samples"] shape_latent = shape_latent["samples"]

View File

@ -224,6 +224,15 @@ class TestTrellisBatchSemantics(unittest.TestCase):
13, 13,
) )
def test_empty_texture_latent_rejects_invalid_structure_input(self):
with self.assertRaises(ValueError):
nodes_trellis2.EmptyTextureLatentTrellis2.execute(
"bad-input",
{"samples": torch.zeros(1, 32, 2, 1)},
DummyCloneModel(),
13,
)
def test_flatten_batched_sparse_latent_validates_coord_counts(self): def test_flatten_batched_sparse_latent_validates_coord_counts(self):
samples = torch.zeros(2, 32, 3, 1) samples = torch.zeros(2, 32, 3, 1)
coords = torch.tensor( coords = torch.tensor(

View File

@ -42,6 +42,16 @@ class TestPrepareNoiseInnerTrellis(unittest.TestCase):
self.assertTrue(torch.equal(noise[1:2, :, :5, :], expected1)) self.assertTrue(torch.equal(noise[1:2, :, :5, :], expected1))
self.assertTrue(torch.equal(noise[0, :, 3:, :], torch.zeros_like(noise[0, :, 3:, :]))) self.assertTrue(torch.equal(noise[0, :, 3:, :], torch.zeros_like(noise[0, :, 3:, :])))
def test_coord_counts_noise_inds_length_must_match_batch(self):
latent = torch.zeros(2, 4, 5, 1)
latent.trellis_coord_counts = torch.tensor([3, 5], dtype=torch.int64)
generator = torch.Generator(device="cpu")
generator.manual_seed(456)
with self.assertRaises(ValueError):
comfy.sample.prepare_noise_inner(latent, generator, noise_inds=[7])
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()