mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-03-13 13:17:45 +08:00
Add pre attention and post input patches to qwen image model. (#12879)
Some checks are pending
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Waiting to run
Execution Tests / test (macos-latest) (push) Waiting to run
Execution Tests / test (ubuntu-latest) (push) Waiting to run
Execution Tests / test (windows-latest) (push) Waiting to run
Test server launches without errors / test (push) Waiting to run
Unit Tests / test (macos-latest) (push) Waiting to run
Unit Tests / test (ubuntu-latest) (push) Waiting to run
Unit Tests / test (windows-2022) (push) Waiting to run
Some checks are pending
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Waiting to run
Execution Tests / test (macos-latest) (push) Waiting to run
Execution Tests / test (ubuntu-latest) (push) Waiting to run
Execution Tests / test (windows-latest) (push) Waiting to run
Test server launches without errors / test (push) Waiting to run
Unit Tests / test (macos-latest) (push) Waiting to run
Unit Tests / test (ubuntu-latest) (push) Waiting to run
Unit Tests / test (windows-2022) (push) Waiting to run
This commit is contained in:
parent
3ad36d6be6
commit
9642e4407b
@ -149,6 +149,9 @@ class Attention(nn.Module):
|
|||||||
seq_img = hidden_states.shape[1]
|
seq_img = hidden_states.shape[1]
|
||||||
seq_txt = encoder_hidden_states.shape[1]
|
seq_txt = encoder_hidden_states.shape[1]
|
||||||
|
|
||||||
|
transformer_patches = transformer_options.get("patches", {})
|
||||||
|
extra_options = transformer_options.copy()
|
||||||
|
|
||||||
# Project and reshape to BHND format (batch, heads, seq, dim)
|
# Project and reshape to BHND format (batch, heads, seq, dim)
|
||||||
img_query = self.to_q(hidden_states).view(batch_size, seq_img, self.heads, -1).transpose(1, 2).contiguous()
|
img_query = self.to_q(hidden_states).view(batch_size, seq_img, self.heads, -1).transpose(1, 2).contiguous()
|
||||||
img_key = self.to_k(hidden_states).view(batch_size, seq_img, self.heads, -1).transpose(1, 2).contiguous()
|
img_key = self.to_k(hidden_states).view(batch_size, seq_img, self.heads, -1).transpose(1, 2).contiguous()
|
||||||
@ -167,15 +170,22 @@ class Attention(nn.Module):
|
|||||||
joint_key = torch.cat([txt_key, img_key], dim=2)
|
joint_key = torch.cat([txt_key, img_key], dim=2)
|
||||||
joint_value = torch.cat([txt_value, img_value], dim=2)
|
joint_value = torch.cat([txt_value, img_value], dim=2)
|
||||||
|
|
||||||
joint_query = apply_rope1(joint_query, image_rotary_emb)
|
|
||||||
joint_key = apply_rope1(joint_key, image_rotary_emb)
|
|
||||||
|
|
||||||
if encoder_hidden_states_mask is not None:
|
if encoder_hidden_states_mask is not None:
|
||||||
attn_mask = torch.zeros((batch_size, 1, seq_txt + seq_img), dtype=hidden_states.dtype, device=hidden_states.device)
|
attn_mask = torch.zeros((batch_size, 1, seq_txt + seq_img), dtype=hidden_states.dtype, device=hidden_states.device)
|
||||||
attn_mask[:, 0, :seq_txt] = encoder_hidden_states_mask
|
attn_mask[:, 0, :seq_txt] = encoder_hidden_states_mask
|
||||||
else:
|
else:
|
||||||
attn_mask = None
|
attn_mask = None
|
||||||
|
|
||||||
|
extra_options["img_slice"] = [txt_query.shape[2], joint_query.shape[2]]
|
||||||
|
if "attn1_patch" in transformer_patches:
|
||||||
|
patch = transformer_patches["attn1_patch"]
|
||||||
|
for p in patch:
|
||||||
|
out = p(joint_query, joint_key, joint_value, pe=image_rotary_emb, attn_mask=encoder_hidden_states_mask, extra_options=extra_options)
|
||||||
|
joint_query, joint_key, joint_value, image_rotary_emb, encoder_hidden_states_mask = out.get("q", joint_query), out.get("k", joint_key), out.get("v", joint_value), out.get("pe", image_rotary_emb), out.get("attn_mask", encoder_hidden_states_mask)
|
||||||
|
|
||||||
|
joint_query = apply_rope1(joint_query, image_rotary_emb)
|
||||||
|
joint_key = apply_rope1(joint_key, image_rotary_emb)
|
||||||
|
|
||||||
joint_hidden_states = optimized_attention_masked(joint_query, joint_key, joint_value, self.heads,
|
joint_hidden_states = optimized_attention_masked(joint_query, joint_key, joint_value, self.heads,
|
||||||
attn_mask, transformer_options=transformer_options,
|
attn_mask, transformer_options=transformer_options,
|
||||||
skip_reshape=True)
|
skip_reshape=True)
|
||||||
@ -444,6 +454,7 @@ class QwenImageTransformer2DModel(nn.Module):
|
|||||||
|
|
||||||
timestep_zero_index = None
|
timestep_zero_index = None
|
||||||
if ref_latents is not None:
|
if ref_latents is not None:
|
||||||
|
ref_num_tokens = []
|
||||||
h = 0
|
h = 0
|
||||||
w = 0
|
w = 0
|
||||||
index = 0
|
index = 0
|
||||||
@ -474,16 +485,16 @@ class QwenImageTransformer2DModel(nn.Module):
|
|||||||
kontext, kontext_ids, _ = self.process_img(ref, index=index, h_offset=h_offset, w_offset=w_offset)
|
kontext, kontext_ids, _ = self.process_img(ref, index=index, h_offset=h_offset, w_offset=w_offset)
|
||||||
hidden_states = torch.cat([hidden_states, kontext], dim=1)
|
hidden_states = torch.cat([hidden_states, kontext], dim=1)
|
||||||
img_ids = torch.cat([img_ids, kontext_ids], dim=1)
|
img_ids = torch.cat([img_ids, kontext_ids], dim=1)
|
||||||
|
ref_num_tokens.append(kontext.shape[1])
|
||||||
if timestep_zero:
|
if timestep_zero:
|
||||||
if index > 0:
|
if index > 0:
|
||||||
timestep = torch.cat([timestep, timestep * 0], dim=0)
|
timestep = torch.cat([timestep, timestep * 0], dim=0)
|
||||||
timestep_zero_index = num_embeds
|
timestep_zero_index = num_embeds
|
||||||
|
transformer_options = transformer_options.copy()
|
||||||
|
transformer_options["reference_image_num_tokens"] = ref_num_tokens
|
||||||
|
|
||||||
txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size) // 2, ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size) // 2))
|
txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size) // 2, ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size) // 2))
|
||||||
txt_ids = torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 3)
|
txt_ids = torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 3)
|
||||||
ids = torch.cat((txt_ids, img_ids), dim=1)
|
|
||||||
image_rotary_emb = self.pe_embedder(ids).to(x.dtype).contiguous()
|
|
||||||
del ids, txt_ids, img_ids
|
|
||||||
|
|
||||||
hidden_states = self.img_in(hidden_states)
|
hidden_states = self.img_in(hidden_states)
|
||||||
encoder_hidden_states = self.txt_norm(encoder_hidden_states)
|
encoder_hidden_states = self.txt_norm(encoder_hidden_states)
|
||||||
@ -495,6 +506,18 @@ class QwenImageTransformer2DModel(nn.Module):
|
|||||||
patches = transformer_options.get("patches", {})
|
patches = transformer_options.get("patches", {})
|
||||||
blocks_replace = patches_replace.get("dit", {})
|
blocks_replace = patches_replace.get("dit", {})
|
||||||
|
|
||||||
|
if "post_input" in patches:
|
||||||
|
for p in patches["post_input"]:
|
||||||
|
out = p({"img": hidden_states, "txt": encoder_hidden_states, "img_ids": img_ids, "txt_ids": txt_ids, "transformer_options": transformer_options})
|
||||||
|
hidden_states = out["img"]
|
||||||
|
encoder_hidden_states = out["txt"]
|
||||||
|
img_ids = out["img_ids"]
|
||||||
|
txt_ids = out["txt_ids"]
|
||||||
|
|
||||||
|
ids = torch.cat((txt_ids, img_ids), dim=1)
|
||||||
|
image_rotary_emb = self.pe_embedder(ids).to(x.dtype).contiguous()
|
||||||
|
del ids, txt_ids, img_ids
|
||||||
|
|
||||||
transformer_options["total_blocks"] = len(self.transformer_blocks)
|
transformer_options["total_blocks"] = len(self.transformer_blocks)
|
||||||
transformer_options["block_type"] = "double"
|
transformer_options["block_type"] = "double"
|
||||||
for i, block in enumerate(self.transformer_blocks):
|
for i, block in enumerate(self.transformer_blocks):
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user