Make omni stuff work on regular z image for easier testing. (#11985)

comfyanonymous 2026-01-19 21:32:00 -08:00 committed by GitHub
parent 4edb87aa50
commit 8ccc0c94fa

@@ -657,7 +657,7 @@ class NextDiT(nn.Module):
         device = x.device
         embeds, freqs_cis, cap_feats_len = self.embed_cap(cap_feats, offset=offset, bsz=bsz, device=device, dtype=x.dtype)
-        if not omni:
+        if (not omni) or self.siglip_embedder is None:
             cap_feats_len = embeds[0].shape[1] + offset
             embeds += (None,)
             freqs_cis += (None,)
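
A minimal sketch of what the new guard buys (not the repository code; TinyDiT, the layer sizes, and the tensor shapes are made up for illustration): a checkpoint that never loaded the optional SigLIP embedder now falls back to the plain caption path even when the omni code path is requested.

# Hypothetical illustration of the `(not omni) or self.siglip_embedder is None` guard.
import torch
import torch.nn as nn

class TinyDiT(nn.Module):
    def __init__(self, dim=8, with_siglip=False):
        super().__init__()
        self.cap_proj = nn.Linear(dim, dim)
        # Optional component: absent on a "regular" text-to-image checkpoint.
        self.siglip_embedder = nn.Linear(dim, dim) if with_siglip else None

    def embed_cap(self, cap_feats, omni):
        embeds = (self.cap_proj(cap_feats),)
        # Mirrors the commit's guard: skip the omni branch when the embedder
        # was never loaded, even if omni=True was requested.
        if (not omni) or self.siglip_embedder is None:
            embeds += (None,)
        return embeds

model = TinyDiT(with_siglip=False)
out = model.embed_cap(torch.randn(1, 4, 8), omni=True)
print(out[1] is None)  # True: the omni request degrades to the text-only path
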
@@ -675,8 +675,9 @@ class NextDiT(nn.Module):
             siglip_feats, pad_extra = pad_zimage(siglip_feats, self.siglip_pad_token, self.pad_tokens_multiple)  # TODO: double check
             siglip_pos_ids = torch.nn.functional.pad(siglip_pos_ids, (0, 0, 0, pad_extra))
         else:
-            siglip_feats = self.siglip_pad_token.to(device=device, dtype=x.dtype, copy=True).unsqueeze(0).repeat(bsz, self.pad_tokens_multiple, 1)
-            siglip_pos_ids = torch.zeros((bsz, siglip_feats.shape[1], 3), dtype=torch.float32, device=device)
+            if self.siglip_pad_token is not None:
+                siglip_feats = self.siglip_pad_token.to(device=device, dtype=x.dtype, copy=True).unsqueeze(0).repeat(bsz, self.pad_tokens_multiple, 1)
+                siglip_pos_ids = torch.zeros((bsz, siglip_feats.shape[1], 3), dtype=torch.float32, device=device)

         if siglip_feats is None:
             embeds += (None,)
@@ -724,8 +725,9 @@ class NextDiT(nn.Module):
             out = self.embed_all(ref, ref_con, sig_feat, offset=start_t, omni=omni, transformer_options=transformer_options)
             for i, e in enumerate(out[0]):
-                embeds[i].append(comfy.utils.repeat_to_batch_size(e, bsz))
-                freqs_cis[i].append(out[1][i])
+                if e is not None:
+                    embeds[i].append(comfy.utils.repeat_to_batch_size(e, bsz))
+                    freqs_cis[i].append(out[1][i])
             start_t = out[2]

         leftover_cap = ref_contexts[len(ref_latents):]
@@ -759,7 +761,7 @@ class NextDiT(nn.Module):
         feats = (cap_feats,)
         fc = (cap_freqs_cis,)
-        if omni:
+        if omni and len(embeds[1]) > 0:
             siglip_mask = None
             siglip_feats_combined = torch.cat(embeds[1], dim=1)
             siglip_feats_freqs_cis = torch.cat(freqs_cis[1], dim=1)
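
A minimal sketch of why the `len(embeds[1]) > 0` check matters in the last hunk (variable names here are illustrative, not the repository's): with the earlier guards in place, no SigLIP embeddings may ever be appended, and torch.cat on an empty list raises, so the omni concatenation has to be skipped in that case.

# Hypothetical illustration of the empty-list guard before torch.cat.
import torch

siglip_chunks = []  # nothing was appended because every embedding came back None

if len(siglip_chunks) > 0:
    siglip_feats_combined = torch.cat(siglip_chunks, dim=1)
else:
    siglip_feats_combined = None  # fall back to the caption-only path

print(siglip_feats_combined)  # None; torch.cat([]) would have raised RuntimeError
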