mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-24 13:20:19 +08:00
Make omni stuff work on regular z image for easier testing. (#11985)
Some checks failed
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Waiting to run
Execution Tests / test (macos-latest) (push) Waiting to run
Execution Tests / test (ubuntu-latest) (push) Waiting to run
Execution Tests / test (windows-latest) (push) Waiting to run
Test server launches without errors / test (push) Waiting to run
Unit Tests / test (macos-latest) (push) Waiting to run
Unit Tests / test (ubuntu-latest) (push) Waiting to run
Unit Tests / test (windows-2022) (push) Waiting to run
Build package / Build Test (3.10) (push) Has been cancelled
Build package / Build Test (3.11) (push) Has been cancelled
Build package / Build Test (3.12) (push) Has been cancelled
Build package / Build Test (3.13) (push) Has been cancelled
Build package / Build Test (3.14) (push) Has been cancelled
Some checks failed
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Waiting to run
Execution Tests / test (macos-latest) (push) Waiting to run
Execution Tests / test (ubuntu-latest) (push) Waiting to run
Execution Tests / test (windows-latest) (push) Waiting to run
Test server launches without errors / test (push) Waiting to run
Unit Tests / test (macos-latest) (push) Waiting to run
Unit Tests / test (ubuntu-latest) (push) Waiting to run
Unit Tests / test (windows-2022) (push) Waiting to run
Build package / Build Test (3.10) (push) Has been cancelled
Build package / Build Test (3.11) (push) Has been cancelled
Build package / Build Test (3.12) (push) Has been cancelled
Build package / Build Test (3.13) (push) Has been cancelled
Build package / Build Test (3.14) (push) Has been cancelled
This commit is contained in:
parent
4edb87aa50
commit
8ccc0c94fa
@ -657,7 +657,7 @@ class NextDiT(nn.Module):
|
|||||||
device = x.device
|
device = x.device
|
||||||
embeds, freqs_cis, cap_feats_len = self.embed_cap(cap_feats, offset=offset, bsz=bsz, device=device, dtype=x.dtype)
|
embeds, freqs_cis, cap_feats_len = self.embed_cap(cap_feats, offset=offset, bsz=bsz, device=device, dtype=x.dtype)
|
||||||
|
|
||||||
if not omni:
|
if (not omni) or self.siglip_embedder is None:
|
||||||
cap_feats_len = embeds[0].shape[1] + offset
|
cap_feats_len = embeds[0].shape[1] + offset
|
||||||
embeds += (None,)
|
embeds += (None,)
|
||||||
freqs_cis += (None,)
|
freqs_cis += (None,)
|
||||||
@ -675,8 +675,9 @@ class NextDiT(nn.Module):
|
|||||||
siglip_feats, pad_extra = pad_zimage(siglip_feats, self.siglip_pad_token, self.pad_tokens_multiple) # TODO: double check
|
siglip_feats, pad_extra = pad_zimage(siglip_feats, self.siglip_pad_token, self.pad_tokens_multiple) # TODO: double check
|
||||||
siglip_pos_ids = torch.nn.functional.pad(siglip_pos_ids, (0, 0, 0, pad_extra))
|
siglip_pos_ids = torch.nn.functional.pad(siglip_pos_ids, (0, 0, 0, pad_extra))
|
||||||
else:
|
else:
|
||||||
siglip_feats = self.siglip_pad_token.to(device=device, dtype=x.dtype, copy=True).unsqueeze(0).repeat(bsz, self.pad_tokens_multiple, 1)
|
if self.siglip_pad_token is not None:
|
||||||
siglip_pos_ids = torch.zeros((bsz, siglip_feats.shape[1], 3), dtype=torch.float32, device=device)
|
siglip_feats = self.siglip_pad_token.to(device=device, dtype=x.dtype, copy=True).unsqueeze(0).repeat(bsz, self.pad_tokens_multiple, 1)
|
||||||
|
siglip_pos_ids = torch.zeros((bsz, siglip_feats.shape[1], 3), dtype=torch.float32, device=device)
|
||||||
|
|
||||||
if siglip_feats is None:
|
if siglip_feats is None:
|
||||||
embeds += (None,)
|
embeds += (None,)
|
||||||
@ -724,8 +725,9 @@ class NextDiT(nn.Module):
|
|||||||
|
|
||||||
out = self.embed_all(ref, ref_con, sig_feat, offset=start_t, omni=omni, transformer_options=transformer_options)
|
out = self.embed_all(ref, ref_con, sig_feat, offset=start_t, omni=omni, transformer_options=transformer_options)
|
||||||
for i, e in enumerate(out[0]):
|
for i, e in enumerate(out[0]):
|
||||||
embeds[i].append(comfy.utils.repeat_to_batch_size(e, bsz))
|
if e is not None:
|
||||||
freqs_cis[i].append(out[1][i])
|
embeds[i].append(comfy.utils.repeat_to_batch_size(e, bsz))
|
||||||
|
freqs_cis[i].append(out[1][i])
|
||||||
start_t = out[2]
|
start_t = out[2]
|
||||||
leftover_cap = ref_contexts[len(ref_latents):]
|
leftover_cap = ref_contexts[len(ref_latents):]
|
||||||
|
|
||||||
@ -759,7 +761,7 @@ class NextDiT(nn.Module):
|
|||||||
feats = (cap_feats,)
|
feats = (cap_feats,)
|
||||||
fc = (cap_freqs_cis,)
|
fc = (cap_freqs_cis,)
|
||||||
|
|
||||||
if omni:
|
if omni and len(embeds[1]) > 0:
|
||||||
siglip_mask = None
|
siglip_mask = None
|
||||||
siglip_feats_combined = torch.cat(embeds[1], dim=1)
|
siglip_feats_combined = torch.cat(embeds[1], dim=1)
|
||||||
siglip_feats_freqs_cis = torch.cat(freqs_cis[1], dim=1)
|
siglip_feats_freqs_cis = torch.cat(freqs_cis[1], dim=1)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user