Harden Trellis sparse latent seeding

This commit is contained in:
John Pollock 2026-04-20 16:05:10 -05:00
parent 0b99c8c44a
commit 90ebb50f00
4 changed files with 35 additions and 4 deletions

View File

@ -880,6 +880,10 @@ class Trellis2(nn.Module):
for i in range(logical_batch): for i in range(logical_batch):
out_index = rep * logical_batch + i out_index = rep * logical_batch + i
count = int(coord_counts[i].item()) count = int(coord_counts[i].item())
if count > N:
raise ValueError(
f"Trellis2 coord count {count} exceeds latent token dimension {N} for batch {i}"
)
coords_i = coords_by_batch[i].clone() coords_i = coords_by_batch[i].clone()
coords_i[:, 0] = 0 coords_i[:, 0] = 0
feats_i = x_eval[out_index, :count].clone() feats_i = x_eval[out_index, :count].clone()

View File

@ -74,8 +74,6 @@ def prepare_noise(latent_image, seed, noise_inds=None):
def fix_empty_latent_channels(model, latent_image, downscale_ratio_spacial=None): def fix_empty_latent_channels(model, latent_image, downscale_ratio_spacial=None):
if latent_image.is_nested: if latent_image.is_nested:
return latent_image return latent_image
if getattr(latent_image, "trellis_skip_empty_fix", False):
return latent_image
latent_format = model.get_model_object("latent_format") #Resize the empty latent image so it has the right number of channels latent_format = model.get_model_object("latent_format") #Resize the empty latent image so it has the right number of channels
if torch.count_nonzero(latent_image) == 0: if torch.count_nonzero(latent_image) == 0:
if latent_format.latent_channels != latent_image.shape[1]: if latent_format.latent_channels != latent_image.shape[1]:

View File

@ -807,6 +807,8 @@ class EmptyTextureLatentTrellis2(IO.ComfyNode):
) )
shape_batch_index = normalize_batch_index(shape_latent.get("batch_index")) shape_batch_index = normalize_batch_index(shape_latent.get("batch_index"))
if batch_index is None:
batch_index = shape_batch_index
shape_latent = shape_latent["samples"] shape_latent = shape_latent["samples"]
batch_size, inferred_coord_counts, max_tokens = infer_batched_coord_layout(coords) batch_size, inferred_coord_counts, max_tokens = infer_batched_coord_layout(coords)
if coord_counts is not None: if coord_counts is not None:
@ -844,8 +846,6 @@ class EmptyTextureLatentTrellis2(IO.ComfyNode):
latent[i, :, :count] = latent_i[0] latent[i, :, :count] = latent_i[0]
if coord_counts is not None: if coord_counts is not None:
latent.trellis_coord_counts = coord_counts.clone() latent.trellis_coord_counts = coord_counts.clone()
if batch_index is None:
batch_index = shape_batch_index
model = model.clone() model = model.clone()
model.model_options = model.model_options.copy() model.model_options = model.model_options.copy()
if "transformer_options" in model.model_options: if "transformer_options" in model.model_options:

View File

@ -233,6 +233,35 @@ class TestTrellisBatchSemantics(unittest.TestCase):
13, 13,
) )
def test_empty_texture_latent_uses_shape_batch_index_for_seed_fallback(self):
coords = torch.tensor(
[
[0, 1, 1, 1],
[1, 2, 2, 2],
[1, 3, 3, 3],
],
dtype=torch.int32,
)
structure = {"coords": coords}
shape_latent = {
"samples": torch.zeros(2, 32, 2, 1),
"batch_index": [4, 9],
}
output, _ = nodes_trellis2.EmptyTextureLatentTrellis2.execute(
structure,
shape_latent,
DummyCloneModel(),
13,
)
expected = torch.zeros(2, 32, 2, 1)
expected[0, :, :1, :] = torch.randn(1, 32, 1, 1, generator=torch.Generator(device="cpu").manual_seed(17))[0]
expected[1, :, :2, :] = torch.randn(1, 32, 2, 1, generator=torch.Generator(device="cpu").manual_seed(22))[0]
self.assertTrue(torch.equal(output["samples"], expected))
self.assertEqual(output["batch_index"], [4, 9])
def test_flatten_batched_sparse_latent_validates_coord_counts(self): def test_flatten_batched_sparse_latent_validates_coord_counts(self):
samples = torch.zeros(2, 32, 3, 1) samples = torch.zeros(2, 32, 3, 1)
coords = torch.tensor( coords = torch.tensor(