diff --git a/comfy/ldm/trellis2/model.py b/comfy/ldm/trellis2/model.py index 15939e5c6..7cf3e728e 100644 --- a/comfy/ldm/trellis2/model.py +++ b/comfy/ldm/trellis2/model.py @@ -853,6 +853,10 @@ class Trellis2(nn.Module): raise ValueError( f"Trellis2 coord_counts batch {logical_batch} doesn't divide latent batch {B}" ) + if int(coord_counts.sum().item()) != coords.shape[0]: + raise ValueError( + f"Trellis2 coord_counts total {int(coord_counts.sum().item())} does not match coords rows {coords.shape[0]}" + ) batch_ids = coords[:, 0].to(torch.int64) order = torch.argsort(batch_ids, stable=True) sorted_coords = coords.index_select(0, order) diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index 6556ed176..ce184a946 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -105,6 +105,8 @@ def infer_batched_coord_layout(coords): raise ValueError("Trellis2 coords can't be empty") batch_ids = coords[:, 0].to(torch.int64) + if (batch_ids < 0).any(): + raise ValueError(f"Trellis2 batch ids must be non-negative, got {batch_ids.unique(sorted=True).tolist()}") batch_size = int(batch_ids.max().item()) + 1 counts = torch.bincount(batch_ids, minlength=batch_size) @@ -116,6 +118,15 @@ def infer_batched_coord_layout(coords): def split_batched_coords(coords, coord_counts): + if coord_counts.ndim != 1: + raise ValueError(f"Trellis2 coord_counts must be 1D, got shape {tuple(coord_counts.shape)}") + if (coord_counts < 0).any(): + raise ValueError(f"Trellis2 coord_counts must be non-negative, got {coord_counts.tolist()}") + if int(coord_counts.sum().item()) != coords.shape[0]: + raise ValueError( + f"Trellis2 coord_counts total {int(coord_counts.sum().item())} does not match coords rows {coords.shape[0]}" + ) + batch_ids = coords[:, 0].to(torch.int64) order = torch.argsort(batch_ids, stable=True) sorted_coords = coords.index_select(0, order) @@ -153,6 +164,17 @@ def resolve_sample_indices(batch_index, batch_size): return sample_indices +def resolve_singleton_sample_index(batch_index): + sample_indices = normalize_batch_index(batch_index) + if sample_indices is None: + return 0 + if len(sample_indices) != 1: + raise ValueError( + f"Trellis2 batch_index must be an int or single-element iterable for singleton coords, got {sample_indices}" + ) + return int(sample_indices[0]) + + def flatten_batched_sparse_latent(samples, coords, coord_counts): samples = samples.squeeze(-1).transpose(1, 2) if coord_counts is None: @@ -705,9 +727,9 @@ class EmptyShapeLatentTrellis2(IO.ComfyNode): else: coord_counts = inferred_coord_counts if batch_size == 1: - sample_indices = normalize_batch_index(batch_index) or [0] + sample_index = resolve_singleton_sample_index(batch_index) generator = torch.Generator(device="cpu") - generator.manual_seed(int(seed) + int(sample_indices[0])) + generator.manual_seed(int(seed) + sample_index) latent = torch.randn(1, in_channels, coords.shape[0], 1, generator=generator) else: sample_indices = resolve_sample_indices(batch_index, batch_size) @@ -730,8 +752,6 @@ class EmptyShapeLatentTrellis2(IO.ComfyNode): model.model_options["transformer_options"]["coords"] = coords if coord_counts is not None: model.model_options["transformer_options"]["coord_counts"] = coord_counts - if coord_resolutions is not None: - model.model_options["transformer_options"]["coord_resolutions"] = coord_resolutions if is_512_pass: model.model_options["transformer_options"]["generation_mode"] = "shape_generation_512" else: @@ -742,7 +762,7 @@ class EmptyShapeLatentTrellis2(IO.ComfyNode): if coord_counts is not None: output["coord_counts"] = coord_counts if coord_resolutions is not None: - output["coord_resolutions"] = coord_resolutions + output["resolutions"] = coord_resolutions return IO.NodeOutput(output, model) class EmptyTextureLatentTrellis2(IO.ComfyNode): @@ -804,9 +824,9 @@ class EmptyTextureLatentTrellis2(IO.ComfyNode): ) if batch_size == 1: - sample_indices = normalize_batch_index(batch_index) or [0] + sample_index = resolve_singleton_sample_index(batch_index) generator = torch.Generator(device="cpu") - generator.manual_seed(int(seed) + int(sample_indices[0])) + generator.manual_seed(int(seed) + sample_index) latent = torch.randn(1, channels, coords.shape[0], 1, generator=generator) else: sample_indices = resolve_sample_indices(batch_index, batch_size) diff --git a/tests-unit/comfy_extras_test/nodes_trellis2_test.py b/tests-unit/comfy_extras_test/nodes_trellis2_test.py index 95f64d031..196a88343 100644 --- a/tests-unit/comfy_extras_test/nodes_trellis2_test.py +++ b/tests-unit/comfy_extras_test/nodes_trellis2_test.py @@ -190,6 +190,40 @@ class TestTrellisBatchSemantics(unittest.TestCase): self.assertTrue(torch.equal(output["coord_counts"], torch.tensor([2], dtype=torch.int64))) + def test_empty_shape_latent_rejects_multi_index_singleton(self): + structure = { + "coords": torch.tensor( + [ + [0, 1, 1, 1], + [0, 2, 2, 2], + ], + dtype=torch.int32, + ), + "batch_index": [5, 6], + } + + with self.assertRaises(ValueError): + nodes_trellis2.EmptyShapeLatentTrellis2.execute(structure, DummyCloneModel(), 11) + + def test_empty_texture_latent_rejects_multi_index_singleton(self): + coords = torch.tensor( + [ + [0, 1, 1, 1], + [0, 2, 2, 2], + ], + dtype=torch.int32, + ) + structure = {"coords": coords, "batch_index": [7, 8]} + shape_latent = {"samples": torch.zeros(1, 32, 2, 1)} + + with self.assertRaises(ValueError): + nodes_trellis2.EmptyTextureLatentTrellis2.execute( + structure, + shape_latent, + DummyCloneModel(), + 13, + ) + def test_flatten_batched_sparse_latent_validates_coord_counts(self): samples = torch.zeros(2, 32, 3, 1) coords = torch.tensor( @@ -205,6 +239,49 @@ class TestTrellisBatchSemantics(unittest.TestCase): with self.assertRaises(ValueError): nodes_trellis2.flatten_batched_sparse_latent(samples, coords, coord_counts) + def test_infer_batched_coord_layout_rejects_negative_batch_ids(self): + coords = torch.tensor( + [ + [-1, 1, 1, 1], + [0, 2, 2, 2], + ], + dtype=torch.int32, + ) + + with self.assertRaises(ValueError): + nodes_trellis2.infer_batched_coord_layout(coords) + + def test_split_batched_coords_validates_total_count(self): + coords = torch.tensor( + [ + [0, 1, 1, 1], + [1, 2, 2, 2], + [1, 3, 3, 3], + ], + dtype=torch.int32, + ) + coord_counts = torch.tensor([1, 1], dtype=torch.int64) + + with self.assertRaises(ValueError): + nodes_trellis2.split_batched_coords(coords, coord_counts) + + def test_empty_shape_latent_preserves_resolutions_key(self): + structure = { + "coords": torch.tensor( + [ + [0, 1, 1, 1], + [0, 2, 2, 2], + ], + dtype=torch.int32, + ), + "resolutions": torch.tensor([1024], dtype=torch.int64), + } + + output, model = nodes_trellis2.EmptyShapeLatentTrellis2.execute(structure, DummyCloneModel(), 11) + + self.assertTrue(torch.equal(output["resolutions"], torch.tensor([1024], dtype=torch.int64))) + self.assertNotIn("coord_resolutions", model.model_options["transformer_options"]) + if __name__ == "__main__": unittest.main()