mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-27 19:02:31 +08:00
Harden Trellis sparse latent seeding
This commit is contained in:
parent
0b99c8c44a
commit
90ebb50f00
@ -880,6 +880,10 @@ class Trellis2(nn.Module):
|
|||||||
for i in range(logical_batch):
|
for i in range(logical_batch):
|
||||||
out_index = rep * logical_batch + i
|
out_index = rep * logical_batch + i
|
||||||
count = int(coord_counts[i].item())
|
count = int(coord_counts[i].item())
|
||||||
|
if count > N:
|
||||||
|
raise ValueError(
|
||||||
|
f"Trellis2 coord count {count} exceeds latent token dimension {N} for batch {i}"
|
||||||
|
)
|
||||||
coords_i = coords_by_batch[i].clone()
|
coords_i = coords_by_batch[i].clone()
|
||||||
coords_i[:, 0] = 0
|
coords_i[:, 0] = 0
|
||||||
feats_i = x_eval[out_index, :count].clone()
|
feats_i = x_eval[out_index, :count].clone()
|
||||||
|
|||||||
@ -74,8 +74,6 @@ def prepare_noise(latent_image, seed, noise_inds=None):
|
|||||||
def fix_empty_latent_channels(model, latent_image, downscale_ratio_spacial=None):
|
def fix_empty_latent_channels(model, latent_image, downscale_ratio_spacial=None):
|
||||||
if latent_image.is_nested:
|
if latent_image.is_nested:
|
||||||
return latent_image
|
return latent_image
|
||||||
if getattr(latent_image, "trellis_skip_empty_fix", False):
|
|
||||||
return latent_image
|
|
||||||
latent_format = model.get_model_object("latent_format") #Resize the empty latent image so it has the right number of channels
|
latent_format = model.get_model_object("latent_format") #Resize the empty latent image so it has the right number of channels
|
||||||
if torch.count_nonzero(latent_image) == 0:
|
if torch.count_nonzero(latent_image) == 0:
|
||||||
if latent_format.latent_channels != latent_image.shape[1]:
|
if latent_format.latent_channels != latent_image.shape[1]:
|
||||||
|
|||||||
@ -807,6 +807,8 @@ class EmptyTextureLatentTrellis2(IO.ComfyNode):
|
|||||||
)
|
)
|
||||||
|
|
||||||
shape_batch_index = normalize_batch_index(shape_latent.get("batch_index"))
|
shape_batch_index = normalize_batch_index(shape_latent.get("batch_index"))
|
||||||
|
if batch_index is None:
|
||||||
|
batch_index = shape_batch_index
|
||||||
shape_latent = shape_latent["samples"]
|
shape_latent = shape_latent["samples"]
|
||||||
batch_size, inferred_coord_counts, max_tokens = infer_batched_coord_layout(coords)
|
batch_size, inferred_coord_counts, max_tokens = infer_batched_coord_layout(coords)
|
||||||
if coord_counts is not None:
|
if coord_counts is not None:
|
||||||
@ -844,8 +846,6 @@ class EmptyTextureLatentTrellis2(IO.ComfyNode):
|
|||||||
latent[i, :, :count] = latent_i[0]
|
latent[i, :, :count] = latent_i[0]
|
||||||
if coord_counts is not None:
|
if coord_counts is not None:
|
||||||
latent.trellis_coord_counts = coord_counts.clone()
|
latent.trellis_coord_counts = coord_counts.clone()
|
||||||
if batch_index is None:
|
|
||||||
batch_index = shape_batch_index
|
|
||||||
model = model.clone()
|
model = model.clone()
|
||||||
model.model_options = model.model_options.copy()
|
model.model_options = model.model_options.copy()
|
||||||
if "transformer_options" in model.model_options:
|
if "transformer_options" in model.model_options:
|
||||||
|
|||||||
@ -233,6 +233,35 @@ class TestTrellisBatchSemantics(unittest.TestCase):
|
|||||||
13,
|
13,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_empty_texture_latent_uses_shape_batch_index_for_seed_fallback(self):
|
||||||
|
coords = torch.tensor(
|
||||||
|
[
|
||||||
|
[0, 1, 1, 1],
|
||||||
|
[1, 2, 2, 2],
|
||||||
|
[1, 3, 3, 3],
|
||||||
|
],
|
||||||
|
dtype=torch.int32,
|
||||||
|
)
|
||||||
|
structure = {"coords": coords}
|
||||||
|
shape_latent = {
|
||||||
|
"samples": torch.zeros(2, 32, 2, 1),
|
||||||
|
"batch_index": [4, 9],
|
||||||
|
}
|
||||||
|
|
||||||
|
output, _ = nodes_trellis2.EmptyTextureLatentTrellis2.execute(
|
||||||
|
structure,
|
||||||
|
shape_latent,
|
||||||
|
DummyCloneModel(),
|
||||||
|
13,
|
||||||
|
)
|
||||||
|
|
||||||
|
expected = torch.zeros(2, 32, 2, 1)
|
||||||
|
expected[0, :, :1, :] = torch.randn(1, 32, 1, 1, generator=torch.Generator(device="cpu").manual_seed(17))[0]
|
||||||
|
expected[1, :, :2, :] = torch.randn(1, 32, 2, 1, generator=torch.Generator(device="cpu").manual_seed(22))[0]
|
||||||
|
|
||||||
|
self.assertTrue(torch.equal(output["samples"], expected))
|
||||||
|
self.assertEqual(output["batch_index"], [4, 9])
|
||||||
|
|
||||||
def test_flatten_batched_sparse_latent_validates_coord_counts(self):
|
def test_flatten_batched_sparse_latent_validates_coord_counts(self):
|
||||||
samples = torch.zeros(2, 32, 3, 1)
|
samples = torch.zeros(2, 32, 3, 1)
|
||||||
coords = torch.tensor(
|
coords = torch.tensor(
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user