mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-27 19:02:31 +08:00
Harden Trellis sparse metadata validation
This commit is contained in:
parent
7d98cc1305
commit
a752dd4736
@ -853,6 +853,10 @@ class Trellis2(nn.Module):
|
||||
raise ValueError(
|
||||
f"Trellis2 coord_counts batch {logical_batch} doesn't divide latent batch {B}"
|
||||
)
|
||||
if int(coord_counts.sum().item()) != coords.shape[0]:
|
||||
raise ValueError(
|
||||
f"Trellis2 coord_counts total {int(coord_counts.sum().item())} does not match coords rows {coords.shape[0]}"
|
||||
)
|
||||
batch_ids = coords[:, 0].to(torch.int64)
|
||||
order = torch.argsort(batch_ids, stable=True)
|
||||
sorted_coords = coords.index_select(0, order)
|
||||
|
||||
@ -105,6 +105,8 @@ def infer_batched_coord_layout(coords):
|
||||
raise ValueError("Trellis2 coords can't be empty")
|
||||
|
||||
batch_ids = coords[:, 0].to(torch.int64)
|
||||
if (batch_ids < 0).any():
|
||||
raise ValueError(f"Trellis2 batch ids must be non-negative, got {batch_ids.unique(sorted=True).tolist()}")
|
||||
batch_size = int(batch_ids.max().item()) + 1
|
||||
counts = torch.bincount(batch_ids, minlength=batch_size)
|
||||
|
||||
@ -116,6 +118,15 @@ def infer_batched_coord_layout(coords):
|
||||
|
||||
|
||||
def split_batched_coords(coords, coord_counts):
|
||||
if coord_counts.ndim != 1:
|
||||
raise ValueError(f"Trellis2 coord_counts must be 1D, got shape {tuple(coord_counts.shape)}")
|
||||
if (coord_counts < 0).any():
|
||||
raise ValueError(f"Trellis2 coord_counts must be non-negative, got {coord_counts.tolist()}")
|
||||
if int(coord_counts.sum().item()) != coords.shape[0]:
|
||||
raise ValueError(
|
||||
f"Trellis2 coord_counts total {int(coord_counts.sum().item())} does not match coords rows {coords.shape[0]}"
|
||||
)
|
||||
|
||||
batch_ids = coords[:, 0].to(torch.int64)
|
||||
order = torch.argsort(batch_ids, stable=True)
|
||||
sorted_coords = coords.index_select(0, order)
|
||||
@ -153,6 +164,17 @@ def resolve_sample_indices(batch_index, batch_size):
|
||||
return sample_indices
|
||||
|
||||
|
||||
def resolve_singleton_sample_index(batch_index):
|
||||
sample_indices = normalize_batch_index(batch_index)
|
||||
if sample_indices is None:
|
||||
return 0
|
||||
if len(sample_indices) != 1:
|
||||
raise ValueError(
|
||||
f"Trellis2 batch_index must be an int or single-element iterable for singleton coords, got {sample_indices}"
|
||||
)
|
||||
return int(sample_indices[0])
|
||||
|
||||
|
||||
def flatten_batched_sparse_latent(samples, coords, coord_counts):
|
||||
samples = samples.squeeze(-1).transpose(1, 2)
|
||||
if coord_counts is None:
|
||||
@ -705,9 +727,9 @@ class EmptyShapeLatentTrellis2(IO.ComfyNode):
|
||||
else:
|
||||
coord_counts = inferred_coord_counts
|
||||
if batch_size == 1:
|
||||
sample_indices = normalize_batch_index(batch_index) or [0]
|
||||
sample_index = resolve_singleton_sample_index(batch_index)
|
||||
generator = torch.Generator(device="cpu")
|
||||
generator.manual_seed(int(seed) + int(sample_indices[0]))
|
||||
generator.manual_seed(int(seed) + sample_index)
|
||||
latent = torch.randn(1, in_channels, coords.shape[0], 1, generator=generator)
|
||||
else:
|
||||
sample_indices = resolve_sample_indices(batch_index, batch_size)
|
||||
@ -730,8 +752,6 @@ class EmptyShapeLatentTrellis2(IO.ComfyNode):
|
||||
model.model_options["transformer_options"]["coords"] = coords
|
||||
if coord_counts is not None:
|
||||
model.model_options["transformer_options"]["coord_counts"] = coord_counts
|
||||
if coord_resolutions is not None:
|
||||
model.model_options["transformer_options"]["coord_resolutions"] = coord_resolutions
|
||||
if is_512_pass:
|
||||
model.model_options["transformer_options"]["generation_mode"] = "shape_generation_512"
|
||||
else:
|
||||
@ -742,7 +762,7 @@ class EmptyShapeLatentTrellis2(IO.ComfyNode):
|
||||
if coord_counts is not None:
|
||||
output["coord_counts"] = coord_counts
|
||||
if coord_resolutions is not None:
|
||||
output["coord_resolutions"] = coord_resolutions
|
||||
output["resolutions"] = coord_resolutions
|
||||
return IO.NodeOutput(output, model)
|
||||
|
||||
class EmptyTextureLatentTrellis2(IO.ComfyNode):
|
||||
@ -804,9 +824,9 @@ class EmptyTextureLatentTrellis2(IO.ComfyNode):
|
||||
)
|
||||
|
||||
if batch_size == 1:
|
||||
sample_indices = normalize_batch_index(batch_index) or [0]
|
||||
sample_index = resolve_singleton_sample_index(batch_index)
|
||||
generator = torch.Generator(device="cpu")
|
||||
generator.manual_seed(int(seed) + int(sample_indices[0]))
|
||||
generator.manual_seed(int(seed) + sample_index)
|
||||
latent = torch.randn(1, channels, coords.shape[0], 1, generator=generator)
|
||||
else:
|
||||
sample_indices = resolve_sample_indices(batch_index, batch_size)
|
||||
|
||||
@ -190,6 +190,40 @@ class TestTrellisBatchSemantics(unittest.TestCase):
|
||||
|
||||
self.assertTrue(torch.equal(output["coord_counts"], torch.tensor([2], dtype=torch.int64)))
|
||||
|
||||
def test_empty_shape_latent_rejects_multi_index_singleton(self):
|
||||
structure = {
|
||||
"coords": torch.tensor(
|
||||
[
|
||||
[0, 1, 1, 1],
|
||||
[0, 2, 2, 2],
|
||||
],
|
||||
dtype=torch.int32,
|
||||
),
|
||||
"batch_index": [5, 6],
|
||||
}
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
nodes_trellis2.EmptyShapeLatentTrellis2.execute(structure, DummyCloneModel(), 11)
|
||||
|
||||
def test_empty_texture_latent_rejects_multi_index_singleton(self):
|
||||
coords = torch.tensor(
|
||||
[
|
||||
[0, 1, 1, 1],
|
||||
[0, 2, 2, 2],
|
||||
],
|
||||
dtype=torch.int32,
|
||||
)
|
||||
structure = {"coords": coords, "batch_index": [7, 8]}
|
||||
shape_latent = {"samples": torch.zeros(1, 32, 2, 1)}
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
nodes_trellis2.EmptyTextureLatentTrellis2.execute(
|
||||
structure,
|
||||
shape_latent,
|
||||
DummyCloneModel(),
|
||||
13,
|
||||
)
|
||||
|
||||
def test_flatten_batched_sparse_latent_validates_coord_counts(self):
|
||||
samples = torch.zeros(2, 32, 3, 1)
|
||||
coords = torch.tensor(
|
||||
@ -205,6 +239,49 @@ class TestTrellisBatchSemantics(unittest.TestCase):
|
||||
with self.assertRaises(ValueError):
|
||||
nodes_trellis2.flatten_batched_sparse_latent(samples, coords, coord_counts)
|
||||
|
||||
def test_infer_batched_coord_layout_rejects_negative_batch_ids(self):
|
||||
coords = torch.tensor(
|
||||
[
|
||||
[-1, 1, 1, 1],
|
||||
[0, 2, 2, 2],
|
||||
],
|
||||
dtype=torch.int32,
|
||||
)
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
nodes_trellis2.infer_batched_coord_layout(coords)
|
||||
|
||||
def test_split_batched_coords_validates_total_count(self):
|
||||
coords = torch.tensor(
|
||||
[
|
||||
[0, 1, 1, 1],
|
||||
[1, 2, 2, 2],
|
||||
[1, 3, 3, 3],
|
||||
],
|
||||
dtype=torch.int32,
|
||||
)
|
||||
coord_counts = torch.tensor([1, 1], dtype=torch.int64)
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
nodes_trellis2.split_batched_coords(coords, coord_counts)
|
||||
|
||||
def test_empty_shape_latent_preserves_resolutions_key(self):
|
||||
structure = {
|
||||
"coords": torch.tensor(
|
||||
[
|
||||
[0, 1, 1, 1],
|
||||
[0, 2, 2, 2],
|
||||
],
|
||||
dtype=torch.int32,
|
||||
),
|
||||
"resolutions": torch.tensor([1024], dtype=torch.int64),
|
||||
}
|
||||
|
||||
output, model = nodes_trellis2.EmptyShapeLatentTrellis2.execute(structure, DummyCloneModel(), 11)
|
||||
|
||||
self.assertTrue(torch.equal(output["resolutions"], torch.tensor([1024], dtype=torch.int64)))
|
||||
self.assertNotIn("coord_resolutions", model.model_options["transformer_options"])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
Loading…
Reference in New Issue
Block a user