Harden Trellis sparse metadata validation

This commit is contained in:
John Pollock 2026-04-20 14:46:23 -05:00
parent 7d98cc1305
commit a752dd4736
3 changed files with 108 additions and 7 deletions

View File

@ -853,6 +853,10 @@ class Trellis2(nn.Module):
raise ValueError(
f"Trellis2 coord_counts batch {logical_batch} doesn't divide latent batch {B}"
)
if int(coord_counts.sum().item()) != coords.shape[0]:
raise ValueError(
f"Trellis2 coord_counts total {int(coord_counts.sum().item())} does not match coords rows {coords.shape[0]}"
)
batch_ids = coords[:, 0].to(torch.int64)
order = torch.argsort(batch_ids, stable=True)
sorted_coords = coords.index_select(0, order)

View File

@ -105,6 +105,8 @@ def infer_batched_coord_layout(coords):
raise ValueError("Trellis2 coords can't be empty")
batch_ids = coords[:, 0].to(torch.int64)
if (batch_ids < 0).any():
raise ValueError(f"Trellis2 batch ids must be non-negative, got {batch_ids.unique(sorted=True).tolist()}")
batch_size = int(batch_ids.max().item()) + 1
counts = torch.bincount(batch_ids, minlength=batch_size)
@ -116,6 +118,15 @@ def infer_batched_coord_layout(coords):
def split_batched_coords(coords, coord_counts):
if coord_counts.ndim != 1:
raise ValueError(f"Trellis2 coord_counts must be 1D, got shape {tuple(coord_counts.shape)}")
if (coord_counts < 0).any():
raise ValueError(f"Trellis2 coord_counts must be non-negative, got {coord_counts.tolist()}")
if int(coord_counts.sum().item()) != coords.shape[0]:
raise ValueError(
f"Trellis2 coord_counts total {int(coord_counts.sum().item())} does not match coords rows {coords.shape[0]}"
)
batch_ids = coords[:, 0].to(torch.int64)
order = torch.argsort(batch_ids, stable=True)
sorted_coords = coords.index_select(0, order)
@ -153,6 +164,17 @@ def resolve_sample_indices(batch_index, batch_size):
return sample_indices
def resolve_singleton_sample_index(batch_index):
    """Resolve *batch_index* to the one sample index used for singleton coords.

    Returns 0 when no index was supplied; otherwise the sole normalized
    index as an ``int``. Raises ``ValueError`` when more than one index
    is provided, since singleton coords describe exactly one sample.
    """
    indices = normalize_batch_index(batch_index)
    if indices is None:
        # No explicit index: default to the first (only) sample.
        return 0
    if len(indices) == 1:
        return int(indices[0])
    raise ValueError(
        f"Trellis2 batch_index must be an int or single-element iterable for singleton coords, got {indices}"
    )
def flatten_batched_sparse_latent(samples, coords, coord_counts):
samples = samples.squeeze(-1).transpose(1, 2)
if coord_counts is None:
@ -705,9 +727,9 @@ class EmptyShapeLatentTrellis2(IO.ComfyNode):
else:
coord_counts = inferred_coord_counts
if batch_size == 1:
sample_indices = normalize_batch_index(batch_index) or [0]
sample_index = resolve_singleton_sample_index(batch_index)
generator = torch.Generator(device="cpu")
generator.manual_seed(int(seed) + int(sample_indices[0]))
generator.manual_seed(int(seed) + sample_index)
latent = torch.randn(1, in_channels, coords.shape[0], 1, generator=generator)
else:
sample_indices = resolve_sample_indices(batch_index, batch_size)
@ -730,8 +752,6 @@ class EmptyShapeLatentTrellis2(IO.ComfyNode):
model.model_options["transformer_options"]["coords"] = coords
if coord_counts is not None:
model.model_options["transformer_options"]["coord_counts"] = coord_counts
if coord_resolutions is not None:
model.model_options["transformer_options"]["coord_resolutions"] = coord_resolutions
if is_512_pass:
model.model_options["transformer_options"]["generation_mode"] = "shape_generation_512"
else:
@ -742,7 +762,7 @@ class EmptyShapeLatentTrellis2(IO.ComfyNode):
if coord_counts is not None:
output["coord_counts"] = coord_counts
if coord_resolutions is not None:
output["coord_resolutions"] = coord_resolutions
output["resolutions"] = coord_resolutions
return IO.NodeOutput(output, model)
class EmptyTextureLatentTrellis2(IO.ComfyNode):
@ -804,9 +824,9 @@ class EmptyTextureLatentTrellis2(IO.ComfyNode):
)
if batch_size == 1:
sample_indices = normalize_batch_index(batch_index) or [0]
sample_index = resolve_singleton_sample_index(batch_index)
generator = torch.Generator(device="cpu")
generator.manual_seed(int(seed) + int(sample_indices[0]))
generator.manual_seed(int(seed) + sample_index)
latent = torch.randn(1, channels, coords.shape[0], 1, generator=generator)
else:
sample_indices = resolve_sample_indices(batch_index, batch_size)

View File

@ -190,6 +190,40 @@ class TestTrellisBatchSemantics(unittest.TestCase):
self.assertTrue(torch.equal(output["coord_counts"], torch.tensor([2], dtype=torch.int64)))
def test_empty_shape_latent_rejects_multi_index_singleton(self):
    """A multi-element batch_index must be rejected when coords hold one sample."""
    coords = torch.tensor([[0, 1, 1, 1], [0, 2, 2, 2]], dtype=torch.int32)
    structure = {"coords": coords, "batch_index": [5, 6]}
    with self.assertRaises(ValueError):
        nodes_trellis2.EmptyShapeLatentTrellis2.execute(structure, DummyCloneModel(), 11)
def test_empty_texture_latent_rejects_multi_index_singleton(self):
    """Texture latent creation must reject >1 batch index for singleton coords."""
    structure = {
        "coords": torch.tensor([[0, 1, 1, 1], [0, 2, 2, 2]], dtype=torch.int32),
        "batch_index": [7, 8],
    }
    shape_latent = {"samples": torch.zeros(1, 32, 2, 1)}
    with self.assertRaises(ValueError):
        nodes_trellis2.EmptyTextureLatentTrellis2.execute(
            structure, shape_latent, DummyCloneModel(), 13
        )
def test_flatten_batched_sparse_latent_validates_coord_counts(self):
samples = torch.zeros(2, 32, 3, 1)
coords = torch.tensor(
@ -205,6 +239,49 @@ class TestTrellisBatchSemantics(unittest.TestCase):
with self.assertRaises(ValueError):
nodes_trellis2.flatten_batched_sparse_latent(samples, coords, coord_counts)
def test_infer_batched_coord_layout_rejects_negative_batch_ids(self):
    """Negative ids in the batch column must raise a ValueError."""
    bad_coords = torch.tensor([[-1, 1, 1, 1], [0, 2, 2, 2]], dtype=torch.int32)
    with self.assertRaises(ValueError):
        nodes_trellis2.infer_batched_coord_layout(bad_coords)
def test_split_batched_coords_validates_total_count(self):
    """coord_counts whose sum disagrees with the coords row count must raise."""
    rows = [[0, 1, 1, 1], [1, 2, 2, 2], [1, 3, 3, 3]]
    coords = torch.tensor(rows, dtype=torch.int32)
    # Sums to 2 while coords has 3 rows — must be rejected.
    short_counts = torch.tensor([1, 1], dtype=torch.int64)
    with self.assertRaises(ValueError):
        nodes_trellis2.split_batched_coords(coords, short_counts)
def test_empty_shape_latent_preserves_resolutions_key(self):
    """Resolutions must surface under "resolutions", never "coord_resolutions"."""
    structure = {
        "coords": torch.tensor([[0, 1, 1, 1], [0, 2, 2, 2]], dtype=torch.int32),
        "resolutions": torch.tensor([1024], dtype=torch.int64),
    }
    output, model = nodes_trellis2.EmptyShapeLatentTrellis2.execute(structure, DummyCloneModel(), 11)
    self.assertTrue(torch.equal(output["resolutions"], torch.tensor([1024], dtype=torch.int64)))
    self.assertNotIn("coord_resolutions", model.model_options["transformer_options"])
# Allow running this test module directly (e.g. `python this_file.py`).
if __name__ == "__main__":
    unittest.main()