mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-27 10:52:31 +08:00
Validate Trellis coord_counts noise metadata
This commit is contained in:
parent
90ebb50f00
commit
33caec301a
@ -9,6 +9,16 @@ import comfy.nested_tensor
|
||||
def prepare_noise_inner(latent_image, generator, noise_inds=None):
|
||||
coord_counts = getattr(latent_image, "trellis_coord_counts", None)
|
||||
if coord_counts is not None:
|
||||
if coord_counts.ndim != 1:
|
||||
raise ValueError(f"Trellis2 coord_counts must be 1D, got shape {tuple(coord_counts.shape)}")
|
||||
if coord_counts.shape[0] != latent_image.size(0):
|
||||
raise ValueError(
|
||||
f"Trellis2 coord_counts length {coord_counts.shape[0]} does not match latent batch {latent_image.size(0)}"
|
||||
)
|
||||
if (coord_counts < 0).any() or (coord_counts > latent_image.size(2)).any():
|
||||
raise ValueError(
|
||||
f"Trellis2 coord_counts must be within [0, {latent_image.size(2)}], got {coord_counts.tolist()}"
|
||||
)
|
||||
noise = torch.zeros(latent_image.size(), dtype=torch.float32, layout=latent_image.layout, device="cpu")
|
||||
if noise_inds is None:
|
||||
noise_inds = np.arange(latent_image.size(0), dtype=np.int64)
|
||||
|
||||
@ -52,6 +52,25 @@ class TestPrepareNoiseInnerTrellis(unittest.TestCase):
|
||||
with self.assertRaises(ValueError):
|
||||
comfy.sample.prepare_noise_inner(latent, generator, noise_inds=[7])
|
||||
|
||||
def test_coord_counts_metadata_must_match_batch_and_bounds(self):
|
||||
generator = torch.Generator(device="cpu")
|
||||
generator.manual_seed(456)
|
||||
|
||||
latent = torch.zeros(2, 4, 5, 1)
|
||||
latent.trellis_coord_counts = torch.tensor([[3, 5]], dtype=torch.int64)
|
||||
with self.assertRaises(ValueError):
|
||||
comfy.sample.prepare_noise_inner(latent, generator)
|
||||
|
||||
latent = torch.zeros(2, 4, 5, 1)
|
||||
latent.trellis_coord_counts = torch.tensor([3], dtype=torch.int64)
|
||||
with self.assertRaises(ValueError):
|
||||
comfy.sample.prepare_noise_inner(latent, generator)
|
||||
|
||||
latent = torch.zeros(2, 4, 5, 1)
|
||||
latent.trellis_coord_counts = torch.tensor([3, 6], dtype=torch.int64)
|
||||
with self.assertRaises(ValueError):
|
||||
comfy.sample.prepare_noise_inner(latent, generator)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
Loading…
Reference in New Issue
Block a user