removed redundant branch

nolan4 2025-10-27 22:00:47 -07:00
parent 6c09121070
commit 79c30e1630


@@ -153,34 +153,12 @@ class Attention(nn.Module):
         txt_query = self.norm_added_q(txt_query)
         txt_key = self.norm_added_k(txt_key)
 
-        # Handle both tuple (EliGen) and single tensor (standard) RoPE formats
-        if isinstance(image_rotary_emb, tuple):
-            # EliGen path: Apply RoPE BEFORE concatenation (research-accurate)
-            # txt/img query/key are in [b, s, h, d] format, compatible with apply_rotary_emb
-            img_rope, txt_rope = image_rotary_emb
-
-            # Add heads dimension to RoPE tensors for broadcasting
-            # Shape: [s, features, 2, 2] -> [s, 1, features, 2, 2]
-            # Also convert to match query dtype (e.g., bfloat16)
-            txt_rope = txt_rope.unsqueeze(1).to(dtype=txt_query.dtype)
-            img_rope = img_rope.unsqueeze(1).to(dtype=img_query.dtype)
-
-            # Apply RoPE separately to text and image streams
-            txt_query = apply_rotary_emb(txt_query, txt_rope)
-            txt_key = apply_rotary_emb(txt_key, txt_rope)
-            img_query = apply_rotary_emb(img_query, img_rope)
-            img_key = apply_rotary_emb(img_key, img_rope)
-
-            # Now concatenate
-            joint_query = torch.cat([txt_query, img_query], dim=1)
-            joint_key = torch.cat([txt_key, img_key], dim=1)
-            joint_value = torch.cat([txt_value, img_value], dim=1)
-        else:
-            # Standard path: Concatenate first, then apply RoPE
-            joint_query = torch.cat([txt_query, img_query], dim=1)
-            joint_key = torch.cat([txt_key, img_key], dim=1)
-            joint_value = torch.cat([txt_value, img_value], dim=1)
-
-            joint_query = apply_rotary_emb(joint_query, image_rotary_emb)
-            joint_key = apply_rotary_emb(joint_key, image_rotary_emb)
+        # Concatenate text and image streams
+        joint_query = torch.cat([txt_query, img_query], dim=1)
+        joint_key = torch.cat([txt_key, img_key], dim=1)
+        joint_value = torch.cat([txt_value, img_value], dim=1)
+
+        # Apply RoPE to concatenated queries and keys
+        joint_query = apply_rotary_emb(joint_query, image_rotary_emb)
+        joint_key = apply_rotary_emb(joint_key, image_rotary_emb)
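Why the branch was redundant: RoPE rotates each sequence position independently, so applying it per stream and then concatenating along the sequence dimension gives the same result as concatenating first and applying one joint table. A minimal sketch of that equivalence; rope_2x2 and random_rope are hypothetical stand-ins for the repo's apply_rotary_emb, assuming the [s, 1, d//2, 2, 2] rotation-matrix format the comments describe:

import torch

def random_rope(s, d):
    # Hypothetical helper: per-position 2x2 rotation matrices, [s, 1, d//2, 2, 2]
    theta = torch.rand(s, d // 2) * 6.283
    c, si = theta.cos(), theta.sin()
    rot = torch.stack([torch.stack([c, -si], -1), torch.stack([si, c], -1)], -2)
    return rot.unsqueeze(1)

def rope_2x2(x, rope):
    # Hypothetical stand-in for apply_rotary_emb: x is [b, s, h, d]
    b, s, h, d = x.shape
    v = x.reshape(b, s, h, d // 2, 2)                     # split channels into pairs
    out = (rope.unsqueeze(0) * v.unsqueeze(-2)).sum(-1)   # rotate each pair
    return out.reshape(b, s, h, d)

b, h, d, s_txt, s_img = 2, 4, 8, 3, 5
txt_q, img_q = torch.randn(b, s_txt, h, d), torch.randn(b, s_img, h, d)
txt_rope, img_rope = random_rope(s_txt, d), random_rope(s_img, d)

# Old branch: RoPE per stream, then concatenate along the sequence dim
old = torch.cat([rope_2x2(txt_q, txt_rope), rope_2x2(img_q, img_rope)], dim=1)
# New path: concatenate RoPE tables and streams, apply once
new = rope_2x2(torch.cat([txt_q, img_q], dim=1), torch.cat([txt_rope, img_rope], dim=0))
assert torch.allclose(old, new, atol=1e-6)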
@@ -483,7 +461,9 @@ class QwenImageTransformer2DModel(nn.Module):
             logger.debug(f"[EliGen Model] RoPE shapes - img: {img_rope.shape}, txt: {txt_rotary_emb.shape}")
 
-            image_rotary_emb = (img_rope, txt_rotary_emb)
+            # Concatenate text and image RoPE embeddings
+            # Convert to latent dtype to match queries/keys
+            image_rotary_emb = torch.cat([txt_rotary_emb, img_rope], dim=0).unsqueeze(1).to(dtype=latents.dtype)
 
             # Prepare spatial masks
             repeat_dim = latents.shape[1]
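For concreteness, the shape bookkeeping the new line performs, with hypothetical sizes (77 text tokens, 1024 image patches, 64 feature pairs):

import torch

txt_rotary_emb = torch.randn(77, 64, 2, 2)    # [txt_seq, features, 2, 2]
img_rope = torch.randn(1024, 64, 2, 2)        # [img_seq, features, 2, 2]

# Text first, matching joint_* = cat([txt, img], dim=1) in Attention;
# unsqueeze adds the broadcast heads dim; dtype matches the queries/keys.
image_rotary_emb = (
    torch.cat([txt_rotary_emb, img_rope], dim=0)  # [1101, 64, 2, 2]
    .unsqueeze(1)                                 # [1101, 1, 64, 2, 2]
    .to(dtype=torch.bfloat16)
)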
@@ -524,7 +504,7 @@ class QwenImageTransformer2DModel(nn.Module):
                 )
                 patched_masks.append(patched_mask)
 
-            # SECTION 5: Build attention mask matrix
+            # Build attention mask matrix
             attention_mask = torch.ones(
                 (batch_size, total_seq_len, total_seq_len),
                 dtype=torch.bool
@@ -539,7 +519,7 @@ class QwenImageTransformer2DModel(nn.Module):
             for length in seq_lens:
                 cumsum.append(cumsum[-1] + length)
 
-            # RULE 1: Spatial restriction (prompt <-> image)
+            # Spatial restriction (prompt <-> image)
             for i in range(N):
                 prompt_start = cumsum[i]
                 prompt_end = cumsum[i+1]
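The offset bookkeeping is easiest to see on a toy layout; the lengths here are hypothetical (two entity prompts plus the image patches):

import torch

seq_lens = [7, 5, 20]            # hypothetical: prompt 0, prompt 1, image patches
total_seq_len = sum(seq_lens)    # 32
cumsum = [0]
for length in seq_lens:
    cumsum.append(cumsum[-1] + length)
# cumsum == [0, 7, 12, 32]; segment i occupies rows/cols [cumsum[i], cumsum[i+1])

attention_mask = torch.ones((1, total_seq_len, total_seq_len), dtype=torch.bool)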
@@ -558,7 +538,7 @@ class QwenImageTransformer2DModel(nn.Module):
                 # - Image patches can only be updated by prompts that own them
                 attention_mask[:, image_start:image_end, prompt_start:prompt_end] = image_mask.transpose(1, 2)
 
-            # RULE 2: Entity isolation
+            # Entity isolation
             for i in range(N):
                 for j in range(N):
                     if i == j:
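On the same toy layout, entity isolation blanks every prompt-to-other-prompt block, so each entity prompt attends only to itself and, via the spatial rule, to its own image region:

import torch

cumsum = [0, 7, 12, 32]          # offsets from the previous sketch
attention_mask = torch.ones((1, 32, 32), dtype=torch.bool)
N = 2
for i in range(N):
    for j in range(N):
        if i == j:
            continue
        attention_mask[:, cumsum[i]:cumsum[i+1], cumsum[j]:cumsum[j+1]] = False
# attention_mask[0, 0:7, 7:12] and attention_mask[0, 7:12, 0:7] are now all False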
@@ -567,7 +547,7 @@ class QwenImageTransformer2DModel(nn.Module):
                     start_j, end_j = cumsum[j], cumsum[j+1]
                     attention_mask[:, start_i:end_i, start_j:end_j] = False
 
-            # SECTION 6: Convert to additive bias and handle CFG batching
+            # Convert to additive bias and handle CFG batching
             attention_mask = attention_mask.float()
             num_valid_connections = (attention_mask == 1).sum().item()
             attention_mask[attention_mask == 0] = float('-inf')
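The boolean mask then becomes an additive pre-softmax bias: allowed pairs contribute 0, blocked pairs -inf. A sketch on the toy layout above; the zeroing of valid entries and the CFG batch repeat are assumptions from context, not shown in this hunk:

import torch

# Toy mask from the sketches above: prompts 0 and 1 isolated from each other
attention_mask = torch.ones((1, 32, 32), dtype=torch.bool)
attention_mask[:, 0:7, 7:12] = False
attention_mask[:, 7:12, 0:7] = False

bias = attention_mask.float()
num_valid = (bias == 1).sum().item()      # diagnostic count, as in the hunk
bias[bias == 0] = float('-inf')           # blocked pairs never win the softmax
bias[bias == 1] = 0.0                     # allowed pairs are left untouched
# Hypothetical CFG handling: reuse the same bias for cond/uncond halves
bias = bias.repeat(2, 1, 1)               # [2, 32, 32]
# scores = q @ k.transpose(-2, -1) / d**0.5 + bias, then softmax over the last dim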