diff --git a/comfy/ldm/qwen_image/model.py b/comfy/ldm/qwen_image/model.py index 6eb744286..0862f72f7 100644 --- a/comfy/ldm/qwen_image/model.py +++ b/comfy/ldm/qwen_image/model.py @@ -149,6 +149,9 @@ class Attention(nn.Module): seq_img = hidden_states.shape[1] seq_txt = encoder_hidden_states.shape[1] + transformer_patches = transformer_options.get("patches", {}) + extra_options = transformer_options.copy() + # Project and reshape to BHND format (batch, heads, seq, dim) img_query = self.to_q(hidden_states).view(batch_size, seq_img, self.heads, -1).transpose(1, 2).contiguous() img_key = self.to_k(hidden_states).view(batch_size, seq_img, self.heads, -1).transpose(1, 2).contiguous() @@ -167,15 +170,22 @@ class Attention(nn.Module): joint_key = torch.cat([txt_key, img_key], dim=2) joint_value = torch.cat([txt_value, img_value], dim=2) - joint_query = apply_rope1(joint_query, image_rotary_emb) - joint_key = apply_rope1(joint_key, image_rotary_emb) - if encoder_hidden_states_mask is not None: attn_mask = torch.zeros((batch_size, 1, seq_txt + seq_img), dtype=hidden_states.dtype, device=hidden_states.device) attn_mask[:, 0, :seq_txt] = encoder_hidden_states_mask else: attn_mask = None + extra_options["img_slice"] = [txt_query.shape[2], joint_query.shape[2]] + if "attn1_patch" in transformer_patches: + patch = transformer_patches["attn1_patch"] + for p in patch: + out = p(joint_query, joint_key, joint_value, pe=image_rotary_emb, attn_mask=encoder_hidden_states_mask, extra_options=extra_options) + joint_query, joint_key, joint_value, image_rotary_emb, encoder_hidden_states_mask = out.get("q", joint_query), out.get("k", joint_key), out.get("v", joint_value), out.get("pe", image_rotary_emb), out.get("attn_mask", encoder_hidden_states_mask) + + joint_query = apply_rope1(joint_query, image_rotary_emb) + joint_key = apply_rope1(joint_key, image_rotary_emb) + joint_hidden_states = optimized_attention_masked(joint_query, joint_key, joint_value, self.heads, attn_mask, transformer_options=transformer_options, skip_reshape=True) @@ -444,6 +454,7 @@ class QwenImageTransformer2DModel(nn.Module): timestep_zero_index = None if ref_latents is not None: + ref_num_tokens = [] h = 0 w = 0 index = 0 @@ -474,16 +485,16 @@ class QwenImageTransformer2DModel(nn.Module): kontext, kontext_ids, _ = self.process_img(ref, index=index, h_offset=h_offset, w_offset=w_offset) hidden_states = torch.cat([hidden_states, kontext], dim=1) img_ids = torch.cat([img_ids, kontext_ids], dim=1) + ref_num_tokens.append(kontext.shape[1]) if timestep_zero: if index > 0: timestep = torch.cat([timestep, timestep * 0], dim=0) timestep_zero_index = num_embeds + transformer_options = transformer_options.copy() + transformer_options["reference_image_num_tokens"] = ref_num_tokens txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size) // 2, ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size) // 2)) txt_ids = torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 3) - ids = torch.cat((txt_ids, img_ids), dim=1) - image_rotary_emb = self.pe_embedder(ids).to(x.dtype).contiguous() - del ids, txt_ids, img_ids hidden_states = self.img_in(hidden_states) encoder_hidden_states = self.txt_norm(encoder_hidden_states) @@ -495,6 +506,18 @@ class QwenImageTransformer2DModel(nn.Module): patches = transformer_options.get("patches", {}) blocks_replace = patches_replace.get("dit", {}) + if "post_input" in patches: + for p in patches["post_input"]: + out = p({"img": hidden_states, "txt": encoder_hidden_states, "img_ids": img_ids, "txt_ids": txt_ids, "transformer_options": transformer_options}) + hidden_states = out["img"] + encoder_hidden_states = out["txt"] + img_ids = out["img_ids"] + txt_ids = out["txt_ids"] + + ids = torch.cat((txt_ids, img_ids), dim=1) + image_rotary_emb = self.pe_embedder(ids).to(x.dtype).contiguous() + del ids, txt_ids, img_ids + transformer_options["total_blocks"] = len(self.transformer_blocks) transformer_options["block_type"] = "double" for i, block in enumerate(self.transformer_blocks):