removed redundant branch
This commit is contained in:
parent 6c09121070
commit 79c30e1630
@@ -153,36 +153,14 @@ class Attention(nn.Module):
         txt_query = self.norm_added_q(txt_query)
         txt_key = self.norm_added_k(txt_key)

-        # Handle both tuple (EliGen) and single tensor (standard) RoPE formats
-        if isinstance(image_rotary_emb, tuple):
-            # EliGen path: Apply RoPE BEFORE concatenation (research-accurate)
-            # txt/img query/key are in [b, s, h, d] format, compatible with apply_rotary_emb
-            img_rope, txt_rope = image_rotary_emb
-
-            # Add heads dimension to RoPE tensors for broadcasting
-            # Shape: [s, features, 2, 2] -> [s, 1, features, 2, 2]
-            # Also convert to match query dtype (e.g., bfloat16)
-            txt_rope = txt_rope.unsqueeze(1).to(dtype=txt_query.dtype)
-            img_rope = img_rope.unsqueeze(1).to(dtype=img_query.dtype)
-
-            # Apply RoPE separately to text and image streams
-            txt_query = apply_rotary_emb(txt_query, txt_rope)
-            txt_key = apply_rotary_emb(txt_key, txt_rope)
-            img_query = apply_rotary_emb(img_query, img_rope)
-            img_key = apply_rotary_emb(img_key, img_rope)
-
-            # Now concatenate
-            joint_query = torch.cat([txt_query, img_query], dim=1)
-            joint_key = torch.cat([txt_key, img_key], dim=1)
-            joint_value = torch.cat([txt_value, img_value], dim=1)
-        else:
-            # Standard path: Concatenate first, then apply RoPE
-            joint_query = torch.cat([txt_query, img_query], dim=1)
-            joint_key = torch.cat([txt_key, img_key], dim=1)
-            joint_value = torch.cat([txt_value, img_value], dim=1)
-
-            joint_query = apply_rotary_emb(joint_query, image_rotary_emb)
-            joint_key = apply_rotary_emb(joint_key, image_rotary_emb)
+        # Concatenate text and image streams
+        joint_query = torch.cat([txt_query, img_query], dim=1)
+        joint_key = torch.cat([txt_key, img_key], dim=1)
+        joint_value = torch.cat([txt_value, img_value], dim=1)
+
+        # Apply RoPE to concatenated queries and keys
+        joint_query = apply_rotary_emb(joint_query, image_rotary_emb)
+        joint_key = apply_rotary_emb(joint_key, image_rotary_emb)

         # Apply EliGen attention mask if present
         effective_mask = attention_mask
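Note on why the branch was redundant: once the model supplies a single RoPE table already concatenated in text-then-image order (see the second hunk below), rotating each stream separately and then concatenating gives the same result as concatenating first and rotating once, because RoPE acts independently per position. A minimal sketch of that equivalence with toy shapes; toy_rotary is an illustrative stand-in, not the repo's apply_rotary_emb:

import torch

def toy_rotary(x, cos, sin):
    # x: [b, s, h, d]; cos/sin: [s, d/2], broadcast over batch and heads
    x1, x2 = x[..., 0::2], x[..., 1::2]
    cos = cos[None, :, None, :]
    sin = sin[None, :, None, :]
    out = torch.empty_like(x)
    out[..., 0::2] = x1 * cos - x2 * sin
    out[..., 1::2] = x1 * sin + x2 * cos
    return out

b, h, d = 2, 4, 8
txt_len, img_len = 3, 5
txt_q = torch.randn(b, txt_len, h, d)
img_q = torch.randn(b, img_len, h, d)

# One rotation per position, text first then image (hypothetical values).
angles = torch.linspace(0, 1, txt_len + img_len)[:, None] * torch.arange(d // 2)[None, :]
cos, sin = angles.cos(), angles.sin()

# Removed branch: rotate each stream with its own slice, then concatenate.
q_split = torch.cat([toy_rotary(txt_q, cos[:txt_len], sin[:txt_len]),
                     toy_rotary(img_q, cos[txt_len:], sin[txt_len:])], dim=1)

# Surviving path: concatenate first, rotate once with the full table.
q_joint = toy_rotary(torch.cat([txt_q, img_q], dim=1), cos, sin)

assert torch.allclose(q_split, q_joint)  # identical when the table order matches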
@@ -483,7 +461,9 @@ class QwenImageTransformer2DModel(nn.Module):

         logger.debug(f"[EliGen Model] RoPE shapes - img: {img_rope.shape}, txt: {txt_rotary_emb.shape}")

-        image_rotary_emb = (img_rope, txt_rotary_emb)
+        # Concatenate text and image RoPE embeddings
+        # Convert to latent dtype to match queries/keys
+        image_rotary_emb = torch.cat([txt_rotary_emb, img_rope], dim=0).unsqueeze(1).to(dtype=latents.dtype)

         # Prepare spatial masks
         repeat_dim = latents.shape[1]
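Caller-side note: the single-tensor form folds the unsqueeze/dtype handling that used to live inside Attention into one place. A short sketch with made-up sizes (the real txt_rotary_emb and img_rope come from the model's positional-embedding code and are assumed here to be [s, features, 2, 2] tables, per the removed comments):

import torch

txt_len, img_len, feats = 3, 5, 8
txt_rotary_emb = torch.randn(txt_len, feats, 2, 2)   # [s_txt, features, 2, 2]
img_rope = torch.randn(img_len, feats, 2, 2)         # [s_img, features, 2, 2]
latents = torch.randn(1, img_len, 16, dtype=torch.bfloat16)

# Text first, then image, matching the q/k concatenation order in Attention;
# unsqueeze(1) adds the heads axis for broadcasting, and the dtype cast avoids
# converting against bfloat16 queries inside every transformer block.
image_rotary_emb = torch.cat([txt_rotary_emb, img_rope], dim=0).unsqueeze(1).to(dtype=latents.dtype)
print(image_rotary_emb.shape, image_rotary_emb.dtype)  # torch.Size([8, 1, 8, 2, 2]) torch.bfloat16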
@@ -524,7 +504,7 @@ class QwenImageTransformer2DModel(nn.Module):
             )
             patched_masks.append(patched_mask)

-        # SECTION 5: Build attention mask matrix
+        # Build attention mask matrix
         attention_mask = torch.ones(
             (batch_size, total_seq_len, total_seq_len),
             dtype=torch.bool
@@ -539,7 +519,7 @@ class QwenImageTransformer2DModel(nn.Module):
         for length in seq_lens:
             cumsum.append(cumsum[-1] + length)

-        # RULE 1: Spatial restriction (prompt <-> image)
+        # Spatial restriction (prompt <-> image)
         for i in range(N):
             prompt_start = cumsum[i]
             prompt_end = cumsum[i+1]
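For context on the spatial-restriction rule: seq_lens holds one length per entity prompt followed by the image patch count, cumsum turns those into block offsets, and each prompt block is gated against the image block by that entity's spatial mask. A self-contained toy sketch; the image_masks values, their ownership pattern, and the forward-direction assignment are illustrative assumptions, not this file's exact code:

import torch

seq_lens = [4, 4, 20]          # two entity prompts, then the image patches
N = 2
batch_size = 1
total_seq_len = sum(seq_lens)

cumsum = [0]
for length in seq_lens:
    cumsum.append(cumsum[-1] + length)

attention_mask = torch.ones((batch_size, total_seq_len, total_seq_len), dtype=torch.bool)

image_start, image_end = cumsum[N], cumsum[N + 1]
image_masks = [torch.zeros(batch_size, 1, seq_lens[-1], dtype=torch.bool) for _ in range(N)]
image_masks[0][..., :10] = True     # entity 0 owns the first half of the patches
image_masks[1][..., 10:] = True     # entity 1 owns the second half

for i in range(N):
    prompt_start, prompt_end = cumsum[i], cumsum[i + 1]
    image_mask = image_masks[i]     # [b, 1, img_len], broadcasts over the prompt rows
    # Prompt i may attend only to its own patches, and those patches only to prompt i.
    attention_mask[:, prompt_start:prompt_end, image_start:image_end] = image_mask
    attention_mask[:, image_start:image_end, prompt_start:prompt_end] = image_mask.transpose(1, 2)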
@@ -558,7 +538,7 @@ class QwenImageTransformer2DModel(nn.Module):
             # - Image patches can only be updated by prompts that own them
             attention_mask[:, image_start:image_end, prompt_start:prompt_end] = image_mask.transpose(1, 2)

-        # RULE 2: Entity isolation
+        # Entity isolation
         for i in range(N):
             for j in range(N):
                 if i == j:
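The entity-isolation rule blanks the prompt-to-prompt blocks for distinct entities so their text tokens cannot attend to each other. A tiny self-contained sketch of just that loop, reusing the toy offsets from the note above:

import torch

N = 2
cumsum = [0, 4, 8, 28]                                  # toy offsets: two prompts, then the image
attention_mask = torch.ones((1, 28, 28), dtype=torch.bool)

# Block prompt i from attending to prompt j for every i != j.
for i in range(N):
    for j in range(N):
        if i == j:
            continue
        start_i, end_i = cumsum[i], cumsum[i + 1]
        start_j, end_j = cumsum[j], cumsum[j + 1]
        attention_mask[:, start_i:end_i, start_j:end_j] = False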
@@ -567,7 +547,7 @@ class QwenImageTransformer2DModel(nn.Module):
                 start_j, end_j = cumsum[j], cumsum[j+1]
                 attention_mask[:, start_i:end_i, start_j:end_j] = False

-        # SECTION 6: Convert to additive bias and handle CFG batching
+        # Convert to additive bias and handle CFG batching
         attention_mask = attention_mask.float()
         num_valid_connections = (attention_mask == 1).sum().item()
         attention_mask[attention_mask == 0] = float('-inf')
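Finally, the boolean mask becomes an additive bias: blocked pairs go to -inf so they contribute nothing after softmax. A sketch of that conversion on a 2x2 toy mask; zeroing the allowed entries and the CFG-batch duplication are not shown in this hunk and are assumptions here:

import torch

attention_mask = torch.tensor([[[True, False],
                                [True, True]]])               # toy [b, s, s] boolean mask

attention_mask = attention_mask.float()
num_valid_connections = (attention_mask == 1).sum().item()    # 3 allowed pairs
attention_mask[attention_mask == 0] = float('-inf')
attention_mask[attention_mask == 1] = 0.0                     # assumed: allowed pairs get zero bias

logits = torch.randn(1, 2, 2) + attention_mask                # added to the attention logits
print(num_valid_connections, logits.softmax(dim=-1))          # blocked position gets probability 0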