mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-27 10:52:31 +08:00
Harden Trellis sparse latent seeding
This commit is contained in:
parent
0b99c8c44a
commit
90ebb50f00
@ -880,6 +880,10 @@ class Trellis2(nn.Module):
|
||||
for i in range(logical_batch):
|
||||
out_index = rep * logical_batch + i
|
||||
count = int(coord_counts[i].item())
|
||||
if count > N:
|
||||
raise ValueError(
|
||||
f"Trellis2 coord count {count} exceeds latent token dimension {N} for batch {i}"
|
||||
)
|
||||
coords_i = coords_by_batch[i].clone()
|
||||
coords_i[:, 0] = 0
|
||||
feats_i = x_eval[out_index, :count].clone()
|
||||
|
||||
@ -74,8 +74,6 @@ def prepare_noise(latent_image, seed, noise_inds=None):
|
||||
def fix_empty_latent_channels(model, latent_image, downscale_ratio_spacial=None):
|
||||
if latent_image.is_nested:
|
||||
return latent_image
|
||||
if getattr(latent_image, "trellis_skip_empty_fix", False):
|
||||
return latent_image
|
||||
latent_format = model.get_model_object("latent_format") #Resize the empty latent image so it has the right number of channels
|
||||
if torch.count_nonzero(latent_image) == 0:
|
||||
if latent_format.latent_channels != latent_image.shape[1]:
|
||||
|
||||
@ -807,6 +807,8 @@ class EmptyTextureLatentTrellis2(IO.ComfyNode):
|
||||
)
|
||||
|
||||
shape_batch_index = normalize_batch_index(shape_latent.get("batch_index"))
|
||||
if batch_index is None:
|
||||
batch_index = shape_batch_index
|
||||
shape_latent = shape_latent["samples"]
|
||||
batch_size, inferred_coord_counts, max_tokens = infer_batched_coord_layout(coords)
|
||||
if coord_counts is not None:
|
||||
@ -844,8 +846,6 @@ class EmptyTextureLatentTrellis2(IO.ComfyNode):
|
||||
latent[i, :, :count] = latent_i[0]
|
||||
if coord_counts is not None:
|
||||
latent.trellis_coord_counts = coord_counts.clone()
|
||||
if batch_index is None:
|
||||
batch_index = shape_batch_index
|
||||
model = model.clone()
|
||||
model.model_options = model.model_options.copy()
|
||||
if "transformer_options" in model.model_options:
|
||||
|
||||
@ -233,6 +233,35 @@ class TestTrellisBatchSemantics(unittest.TestCase):
|
||||
13,
|
||||
)
|
||||
|
||||
def test_empty_texture_latent_uses_shape_batch_index_for_seed_fallback(self):
|
||||
coords = torch.tensor(
|
||||
[
|
||||
[0, 1, 1, 1],
|
||||
[1, 2, 2, 2],
|
||||
[1, 3, 3, 3],
|
||||
],
|
||||
dtype=torch.int32,
|
||||
)
|
||||
structure = {"coords": coords}
|
||||
shape_latent = {
|
||||
"samples": torch.zeros(2, 32, 2, 1),
|
||||
"batch_index": [4, 9],
|
||||
}
|
||||
|
||||
output, _ = nodes_trellis2.EmptyTextureLatentTrellis2.execute(
|
||||
structure,
|
||||
shape_latent,
|
||||
DummyCloneModel(),
|
||||
13,
|
||||
)
|
||||
|
||||
expected = torch.zeros(2, 32, 2, 1)
|
||||
expected[0, :, :1, :] = torch.randn(1, 32, 1, 1, generator=torch.Generator(device="cpu").manual_seed(17))[0]
|
||||
expected[1, :, :2, :] = torch.randn(1, 32, 2, 1, generator=torch.Generator(device="cpu").manual_seed(22))[0]
|
||||
|
||||
self.assertTrue(torch.equal(output["samples"], expected))
|
||||
self.assertEqual(output["batch_index"], [4, 9])
|
||||
|
||||
def test_flatten_batched_sparse_latent_validates_coord_counts(self):
|
||||
samples = torch.zeros(2, 32, 3, 1)
|
||||
coords = torch.tensor(
|
||||
|
||||
Loading…
Reference in New Issue
Block a user