Mirror of https://github.com/comfyanonymous/ComfyUI.git, synced 2025-12-20 11:32:58 +08:00
Commit 79c30e1630 (parent 6c09121070): removed redundant branch
@@ -153,34 +153,12 @@ class Attention(nn.Module):
        txt_query = self.norm_added_q(txt_query)
        txt_key = self.norm_added_k(txt_key)

        # Handle both tuple (EliGen) and single tensor (standard) RoPE formats
        if isinstance(image_rotary_emb, tuple):
            # EliGen path: Apply RoPE BEFORE concatenation (research-accurate)
            # txt/img query/key are in [b, s, h, d] format, compatible with apply_rotary_emb
            img_rope, txt_rope = image_rotary_emb

            # Add heads dimension to RoPE tensors for broadcasting
            # Shape: [s, features, 2, 2] -> [s, 1, features, 2, 2]
            # Also convert to match query dtype (e.g., bfloat16)
            txt_rope = txt_rope.unsqueeze(1).to(dtype=txt_query.dtype)
            img_rope = img_rope.unsqueeze(1).to(dtype=img_query.dtype)

            # Apply RoPE separately to text and image streams
            txt_query = apply_rotary_emb(txt_query, txt_rope)
            txt_key = apply_rotary_emb(txt_key, txt_rope)
            img_query = apply_rotary_emb(img_query, img_rope)
            img_key = apply_rotary_emb(img_key, img_rope)

            # Now concatenate
            joint_query = torch.cat([txt_query, img_query], dim=1)
            joint_key = torch.cat([txt_key, img_key], dim=1)
            joint_value = torch.cat([txt_value, img_value], dim=1)
        else:
            # Standard path: Concatenate first, then apply RoPE
            # Concatenate text and image streams
            joint_query = torch.cat([txt_query, img_query], dim=1)
            joint_key = torch.cat([txt_key, img_key], dim=1)
            joint_value = torch.cat([txt_value, img_value], dim=1)

            # Apply RoPE to concatenated queries and keys
            joint_query = apply_rotary_emb(joint_query, image_rotary_emb)
            joint_key = apply_rotary_emb(joint_key, image_rotary_emb)

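The branch removed above relied on apply_rotary_emb broadcasting a [s, 1, features, 2, 2] frequency tensor against [b, s, h, d] queries/keys. As a reference for how that broadcast works, here is a minimal, self-contained sketch of such a rotation; rotate_stream and the toy shapes are illustrative assumptions, not this repository's apply_rotary_emb.

    import torch

    def rotate_stream(x: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
        # x:     [b, s, h, d]          query or key
        # freqs: [s, 1, d//2, 2, 2]    per-position 2x2 rotation matrices, heads dim kept at 1 for broadcasting
        x_pairs = x.float().reshape(*x.shape[:-1], -1, 1, 2)                     # [b, s, h, d//2, 1, 2]
        out = freqs[..., 0] * x_pairs[..., 0] + freqs[..., 1] * x_pairs[..., 1]  # [b, s, h, d//2, 2]
        return out.reshape(*x.shape).type_as(x)

    # Toy example: 6 tokens, 2 heads, head_dim 8 -> 4 rotation pairs per token
    angles = torch.arange(6, dtype=torch.float32).unsqueeze(1).repeat(1, 4)      # [s, d//2]
    rot = torch.stack([
        torch.stack([angles.cos(), -angles.sin()], dim=-1),
        torch.stack([angles.sin(),  angles.cos()], dim=-1),
    ], dim=-2)                                                                   # [s, d//2, 2, 2]
    freqs = rot.unsqueeze(1)                                                     # [s, 1, d//2, 2, 2], as in the diff
    q = torch.randn(1, 6, 2, 8)                                                  # [b, s, h, d]
    print(rotate_stream(q, freqs).shape)                                         # torch.Size([1, 6, 2, 8])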
@@ -483,7 +461,9 @@ class QwenImageTransformer2DModel(nn.Module):

        logger.debug(f"[EliGen Model] RoPE shapes - img: {img_rope.shape}, txt: {txt_rotary_emb.shape}")

        image_rotary_emb = (img_rope, txt_rotary_emb)
        # Concatenate text and image RoPE embeddings
        # Convert to latent dtype to match queries/keys
        image_rotary_emb = torch.cat([txt_rotary_emb, img_rope], dim=0).unsqueeze(1).to(dtype=latents.dtype)

        # Prepare spatial masks
        repeat_dim = latents.shape[1]

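For reference, the new concatenation produces one RoPE tensor covering text tokens followed by image patches, matching the order of torch.cat([txt_query, img_query], dim=1) in the attention code. A shape-only sketch under assumed toy sizes (the real lengths and feature dimension come from the model):

    import torch

    txt_rotary_emb = torch.randn(10, 4, 2, 2)    # assumed: 10 text tokens, 4 rotation pairs
    img_rope = torch.randn(24, 4, 2, 2)          # assumed: 24 image patches
    image_rotary_emb = torch.cat([txt_rotary_emb, img_rope], dim=0).unsqueeze(1).to(dtype=torch.bfloat16)
    print(image_rotary_emb.shape, image_rotary_emb.dtype)  # torch.Size([34, 1, 4, 2, 2]) torch.bfloat16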
@@ -524,7 +504,7 @@ class QwenImageTransformer2DModel(nn.Module):

            )
            patched_masks.append(patched_mask)

        # SECTION 5: Build attention mask matrix
        # Build attention mask matrix
        attention_mask = torch.ones(
            (batch_size, total_seq_len, total_seq_len),
            dtype=torch.bool
@@ -539,7 +519,7 @@ class QwenImageTransformer2DModel(nn.Module):
        for length in seq_lens:
            cumsum.append(cumsum[-1] + length)

        # RULE 1: Spatial restriction (prompt <-> image)
        # Spatial restriction (prompt <-> image)
        for i in range(N):
            prompt_start = cumsum[i]
            prompt_end = cumsum[i+1]
@@ -558,7 +538,7 @@ class QwenImageTransformer2DModel(nn.Module):
            # - Image patches can only be updated by prompts that own them
            attention_mask[:, image_start:image_end, prompt_start:prompt_end] = image_mask.transpose(1, 2)

        # RULE 2: Entity isolation
        # Entity isolation
        for i in range(N):
            for j in range(N):
                if i == j:
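The spatial rule above restricts a joint [prompts + image] attention matrix so that each entity prompt only exchanges information with the image patches it owns. The following is a self-contained toy reconstruction of that rule under assumed sizes and mask layout; variable names follow the diff where they appear there, everything else is illustrative.

    import torch

    # Toy: N = 2 entity prompts of lengths 3 and 2, plus 6 image patches
    seq_lens = [3, 2]
    num_patches = 6
    batch_size, N = 1, len(seq_lens)
    total_seq_len = sum(seq_lens) + num_patches

    # Assumed per-entity spatial mask over patches: [b, N, num_patches], True = patch owned by entity
    image_mask = torch.tensor([[[1, 1, 1, 0, 0, 0],
                                [0, 0, 0, 1, 1, 1]]], dtype=torch.bool)

    cumsum = [0]
    for length in seq_lens:
        cumsum.append(cumsum[-1] + length)
    image_start, image_end = cumsum[-1], total_seq_len

    attention_mask = torch.ones((batch_size, total_seq_len, total_seq_len), dtype=torch.bool)
    for i in range(N):
        prompt_start, prompt_end = cumsum[i], cumsum[i + 1]
        # Prompt i attends only to its own patches; those patches attend back only to prompt i
        attention_mask[:, prompt_start:prompt_end, image_start:image_end] = image_mask[:, i:i + 1, :]
        attention_mask[:, image_start:image_end, prompt_start:prompt_end] = image_mask[:, i:i + 1, :].transpose(1, 2)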
@@ -567,7 +547,7 @@ class QwenImageTransformer2DModel(nn.Module):
                start_j, end_j = cumsum[j], cumsum[j+1]
                attention_mask[:, start_i:end_i, start_j:end_j] = False

        # SECTION 6: Convert to additive bias and handle CFG batching
        # Convert to additive bias and handle CFG batching
        attention_mask = attention_mask.float()
        num_valid_connections = (attention_mask == 1).sum().item()
        attention_mask[attention_mask == 0] = float('-inf')

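Entity isolation then blanks prompt-to-prompt blocks, and the boolean mask becomes an additive bias (blocked pairs -> -inf). A self-contained toy sketch under assumed sizes; the final step of zeroing the allowed entries is an assumption, since the visible hunk ends before it.

    import torch

    # Toy: prompt boundaries as built by the cumsum loop in the diff
    cumsum = [0, 3, 5]                    # N = 2 prompts of lengths 3 and 2
    N, batch_size, total_seq_len = 2, 1, 11
    attention_mask = torch.ones((batch_size, total_seq_len, total_seq_len), dtype=torch.bool)

    # Entity isolation: prompt i never attends to prompt j (i != j)
    for i in range(N):
        for j in range(N):
            if i == j:
                continue
            start_i, end_i = cumsum[i], cumsum[i + 1]
            start_j, end_j = cumsum[j], cumsum[j + 1]
            attention_mask[:, start_i:end_i, start_j:end_j] = False

    # Convert to an additive bias: blocked -> -inf, allowed -> 0.0
    bias = attention_mask.float()
    num_valid_connections = (bias == 1).sum().item()
    bias[bias == 0] = float('-inf')
    bias[bias == 1] = 0.0                 # assumed follow-up step, not shown in the hunk
    print(num_valid_connections, (bias == float('-inf')).sum().item())  # 109 12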