From 79c30e16300704a2b36816c93e03a0a7cd30d7d0 Mon Sep 17 00:00:00 2001
From: nolan4
Date: Mon, 27 Oct 2025 22:00:47 -0700
Subject: [PATCH] removed redundant branch

---
 comfy/ldm/qwen_image/model.py | 48 ++++++++++-------------------------
 1 file changed, 14 insertions(+), 34 deletions(-)

diff --git a/comfy/ldm/qwen_image/model.py b/comfy/ldm/qwen_image/model.py
index 45996e23b..66cabab43 100644
--- a/comfy/ldm/qwen_image/model.py
+++ b/comfy/ldm/qwen_image/model.py
@@ -153,36 +153,14 @@ class Attention(nn.Module):
         txt_query = self.norm_added_q(txt_query)
         txt_key = self.norm_added_k(txt_key)
 
-        # Handle both tuple (EliGen) and single tensor (standard) RoPE formats
-        if isinstance(image_rotary_emb, tuple):
-            # EliGen path: Apply RoPE BEFORE concatenation (research-accurate)
-            # txt/img query/key are in [b, s, h, d] format, compatible with apply_rotary_emb
-            img_rope, txt_rope = image_rotary_emb
+        # Concatenate text and image streams
+        joint_query = torch.cat([txt_query, img_query], dim=1)
+        joint_key = torch.cat([txt_key, img_key], dim=1)
+        joint_value = torch.cat([txt_value, img_value], dim=1)
 
-            # Add heads dimension to RoPE tensors for broadcasting
-            # Shape: [s, features, 2, 2] -> [s, 1, features, 2, 2]
-            # Also convert to match query dtype (e.g., bfloat16)
-            txt_rope = txt_rope.unsqueeze(1).to(dtype=txt_query.dtype)
-            img_rope = img_rope.unsqueeze(1).to(dtype=img_query.dtype)
-
-            # Apply RoPE separately to text and image streams
-            txt_query = apply_rotary_emb(txt_query, txt_rope)
-            txt_key = apply_rotary_emb(txt_key, txt_rope)
-            img_query = apply_rotary_emb(img_query, img_rope)
-            img_key = apply_rotary_emb(img_key, img_rope)
-
-            # Now concatenate
-            joint_query = torch.cat([txt_query, img_query], dim=1)
-            joint_key = torch.cat([txt_key, img_key], dim=1)
-            joint_value = torch.cat([txt_value, img_value], dim=1)
-        else:
-            # Standard path: Concatenate first, then apply RoPE
-            joint_query = torch.cat([txt_query, img_query], dim=1)
-            joint_key = torch.cat([txt_key, img_key], dim=1)
-            joint_value = torch.cat([txt_value, img_value], dim=1)
-
-            joint_query = apply_rotary_emb(joint_query, image_rotary_emb)
-            joint_key = apply_rotary_emb(joint_key, image_rotary_emb)
+        # Apply RoPE to concatenated queries and keys
+        joint_query = apply_rotary_emb(joint_query, image_rotary_emb)
+        joint_key = apply_rotary_emb(joint_key, image_rotary_emb)
 
         # Apply EliGen attention mask if present
         effective_mask = attention_mask
@@ -483,7 +461,9 @@ class QwenImageTransformer2DModel(nn.Module):
 
             logger.debug(f"[EliGen Model] RoPE shapes - img: {img_rope.shape}, txt: {txt_rotary_emb.shape}")
 
-            image_rotary_emb = (img_rope, txt_rotary_emb)
+            # Concatenate text and image RoPE embeddings
+            # Convert to latent dtype to match queries/keys
+            image_rotary_emb = torch.cat([txt_rotary_emb, img_rope], dim=0).unsqueeze(1).to(dtype=latents.dtype)
 
             # Prepare spatial masks
             repeat_dim = latents.shape[1]
@@ -524,7 +504,7 @@ class QwenImageTransformer2DModel(nn.Module):
                 )
                 patched_masks.append(patched_mask)
 
-            # SECTION 5: Build attention mask matrix
+            # Build attention mask matrix
             attention_mask = torch.ones(
                 (batch_size, total_seq_len, total_seq_len),
                 dtype=torch.bool
@@ -539,7 +519,7 @@ class QwenImageTransformer2DModel(nn.Module):
             for length in seq_lens:
                 cumsum.append(cumsum[-1] + length)
 
-            # RULE 1: Spatial restriction (prompt <-> image)
+            # Spatial restriction (prompt <-> image)
             for i in range(N):
                 prompt_start = cumsum[i]
                 prompt_end = cumsum[i+1]
@@ -558,7 +538,7 @@ class QwenImageTransformer2DModel(nn.Module):
                 # - Image patches can only be updated by prompts that own them
                 attention_mask[:, image_start:image_end, prompt_start:prompt_end] = image_mask.transpose(1, 2)
 
-            # RULE 2: Entity isolation
+            # Entity isolation
            for i in range(N):
                 for j in range(N):
                     if i == j:
@@ -567,7 +547,7 @@ class QwenImageTransformer2DModel(nn.Module):
                     start_j, end_j = cumsum[j], cumsum[j+1]
                     attention_mask[:, start_i:end_i, start_j:end_j] = False
 
-            # SECTION 6: Convert to additive bias and handle CFG batching
+            # Convert to additive bias and handle CFG batching
             attention_mask = attention_mask.float()
             num_valid_connections = (attention_mask == 1).sum().item()
             attention_mask[attention_mask == 0] = float('-inf')
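
Note on why the branch is safe to remove: apply_rotary_emb rotates each
sequence position using only that position's rotation matrix, so rotating
the text and image streams separately and then concatenating yields the
same tensor as concatenating the streams (and their rotary tables, as the
model-side hunk above now does) and rotating once. Below is a minimal
numeric check of that equivalence; rope_stand_in is a simplified
illustrative stand-in for apply_rotary_emb (not ComfyUI's implementation),
using the [s, 1, features, 2, 2] rotation-matrix layout described in the
removed comments, and all shapes are hypothetical:

    import torch

    def rope_stand_in(x, freqs):
        # x: [b, s, h, d]; freqs: [s, 1, d//2, 2, 2] rotation matrices.
        # Pair up feature channels and apply the 2x2 rotation per position.
        x_ = x.float().reshape(*x.shape[:-1], -1, 1, 2)
        out = freqs[..., 0] * x_[..., 0] + freqs[..., 1] * x_[..., 1]
        return out.reshape(*x.shape).to(x.dtype)

    b, h, d, txt_len, img_len = 2, 4, 8, 5, 7
    txt_q = torch.randn(b, txt_len, h, d)
    img_q = torch.randn(b, img_len, h, d)
    txt_rope = torch.randn(txt_len, 1, d // 2, 2, 2)
    img_rope = torch.randn(img_len, 1, d // 2, 2, 2)

    # Removed EliGen path: rotate each stream, then concatenate
    separate = torch.cat([rope_stand_in(txt_q, txt_rope),
                          rope_stand_in(img_q, img_rope)], dim=1)

    # Retained path: concatenate streams and rotary tables, rotate once
    joint = rope_stand_in(torch.cat([txt_q, img_q], dim=1),
                          torch.cat([txt_rope, img_rope], dim=0))

    assert torch.allclose(separate, joint)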
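
For reference, the masking scheme that the later hunks only re-comment can
be sketched in isolation: each entity prompt may attend only to the image
patches it owns (and vice versa), distinct prompts are isolated from each
other, and the boolean mask is then converted to an additive bias. The
sizes and the two-entity ownership split below are hypothetical; the real
code derives the spatial masks from patched entity masks and also handles
CFG batching afterward:

    import torch

    N = 2                      # entity prompts (hypothetical)
    seq_lens = [4, 3]          # tokens per prompt (hypothetical)
    img_len = 6                # image tokens (hypothetical)
    batch = 1
    total_seq_len = sum(seq_lens) + img_len

    # Which image tokens each entity owns: [N, 1, img_len]
    spatial = torch.zeros(N, 1, img_len, dtype=torch.bool)
    spatial[0, :, :3] = True   # entity 0 owns the first three patches
    spatial[1, :, 3:] = True   # entity 1 owns the rest

    # Build attention mask matrix (True = attention allowed)
    attention_mask = torch.ones(batch, total_seq_len, total_seq_len, dtype=torch.bool)

    cumsum = [0]
    for length in seq_lens:
        cumsum.append(cumsum[-1] + length)
    image_start, image_end = cumsum[-1], total_seq_len

    # Spatial restriction (prompt <-> image)
    for i in range(N):
        p0, p1 = cumsum[i], cumsum[i + 1]
        attention_mask[:, p0:p1, image_start:image_end] = spatial[i]
        attention_mask[:, image_start:image_end, p0:p1] = spatial[i].transpose(0, 1)

    # Entity isolation (prompt i may not attend to prompt j)
    for i in range(N):
        for j in range(N):
            if i == j:
                continue
            attention_mask[:, cumsum[i]:cumsum[i+1], cumsum[j]:cumsum[j+1]] = False

    # Convert to additive bias: 0.0 where allowed, -inf where blocked
    bias = torch.zeros(attention_mask.shape)
    bias[~attention_mask] = float('-inf')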