Working version using ComfyUI's optimized attention and rotary embedding funcs

nolan4 2025-10-24 17:37:26 -07:00
parent 1d9124203f
commit 0f4a141faf
3 changed files with 38 additions and 99 deletions

View File

@@ -14,9 +14,8 @@ import comfy.patcher_extension
class QwenEmbedRope(nn.Module):
"""Research-accurate RoPE implementation for EliGen.
This class matches the research pipeline's QwenEmbedRope exactly.
"""RoPE implementation for EliGen.
https://github.com/modelscope/DiffSynth-Studio/blob/538017177a6136f45f57cdf0b7c4e0d7e1f8b50d/diffsynth/models/qwen_image_dit.py#L61
Returns a tuple (img_freqs, txt_freqs) for separate image and text RoPE.
"""
def __init__(self, theta: int, axes_dim: list, scale_rope=False):
@@ -42,14 +41,23 @@ class QwenEmbedRope(nn.Module):
"""
Args:
index: [0, 1, 2, 3] 1D Tensor representing the position index of the token
Returns:
Real-valued 2x2 rotation matrix format [..., 2, 2] compatible with apply_rotary_emb
"""
assert dim % 2 == 0
freqs = torch.outer(
index,
1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float32).div(dim))
)
freqs = torch.polar(torch.ones_like(freqs), freqs)
return freqs
# Convert to real-valued rotation matrix format (matches Flux rope() output)
# Rotation matrix: [[cos, -sin], [sin, cos]]
cos_freqs = torch.cos(freqs)
sin_freqs = torch.sin(freqs)
# Stack as rotation matrix: [cos, -sin, sin, cos] then reshape to [..., 2, 2]
out = torch.stack([cos_freqs, -sin_freqs, sin_freqs, cos_freqs], dim=-1)
out = out.reshape(*freqs.shape, 2, 2)
return out
def _expand_pos_freqs_if_needed(self, video_fhw, txt_seq_lens):
if isinstance(video_fhw, list):
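
Note (not part of the diff): a minimal standalone sketch checking that the real-valued [..., 2, 2] rotation-matrix format built above encodes the same rotation as the torch.polar complex form it replaces; all sizes are hypothetical.

import torch

dim, theta = 8, 10000
index = torch.arange(4, dtype=torch.float32)
angles = torch.outer(index, 1.0 / torch.pow(theta, torch.arange(0, dim, 2).float() / dim))

freqs_complex = torch.polar(torch.ones_like(angles), angles)  # old complex path
cos, sin = torch.cos(angles), torch.sin(angles)
rot = torch.stack([cos, -sin, sin, cos], dim=-1).reshape(*angles.shape, 2, 2)  # new path

x = torch.randn(4, dim)
out_complex = torch.view_as_real(torch.view_as_complex(x.reshape(4, -1, 2)) * freqs_complex).flatten(1)
pairs = x.reshape(4, -1, 1, 2)  # same channel pairing apply_rotary_emb uses
out_real = (rot[..., 0] * pairs[..., 0] + rot[..., 1] * pairs[..., 1]).reshape(4, dim)
torch.testing.assert_close(out_complex, out_real)  # both formats rotate identically
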
@@ -108,7 +116,7 @@ class QwenEmbedRope(nn.Module):
freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1)
freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1)
freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1)
freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1, 2, 2)
self.rope_cache[rope_key] = freqs.clone().contiguous()
vid_freqs.append(self.rope_cache[rope_key])
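
Shape check (hypothetical sizes, not part of the diff): the per-axis frequency tables stay flattened in the last dim, are broadcast over (frame, height, width), concatenated, and only then folded back into 2x2 blocks.

import torch

frame, height, width = 1, 32, 32
seq_lens = frame * height * width
pf, ph, pw = 8, 28, 28  # assumed pairs per axis; (pf + ph + pw) * 2 = head_dim = 128

def axis(n, pairs):  # stand-in for a cached freqs_pos[i] table
    return torch.randn(n, pairs, 2, 2).reshape(n, -1)

ff = axis(frame, pf).view(frame, 1, 1, -1).expand(frame, height, width, -1)
fh = axis(height, ph).view(1, height, 1, -1).expand(frame, height, width, -1)
fw = axis(width, pw).view(1, 1, width, -1).expand(frame, height, width, -1)
freqs = torch.cat([ff, fh, fw], dim=-1).reshape(seq_lens, -1, 2, 2)
print(freqs.shape)  # torch.Size([1024, 64, 2, 2]): one 2x2 rotation per channel pair
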
@@ -166,30 +174,11 @@ class FeedForward(nn.Module):
def apply_rotary_emb(x, freqs_cis):
if x.shape[1] == 0:
return x
t_ = x.reshape(*x.shape[:-1], -1, 1, 2)
t_out = freqs_cis[..., 0] * t_[..., 0] + freqs_cis[..., 1] * t_[..., 1]
return t_out.reshape(*x.shape)
def apply_rotary_emb_qwen(x: torch.Tensor, freqs_cis: torch.Tensor):
"""
Research-accurate RoPE application for QwenEmbedRope.
Args:
x: Input tensor with shape [b, h, s, d] (batch, heads, sequence, dim)
freqs_cis: Complex frequency tensor with shape [s, features] from QwenEmbedRope
Returns:
Rotated tensor with same shape as input
"""
# x shape: [b, h, s, d]
# freqs_cis shape: [s, features] where features = d (complex numbers)
x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3)
return x_out.type_as(x)
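
For reference, the complex-valued helper removed here reduces to this standalone sketch (hypothetical sizes):

import torch

b, h, s, d = 1, 2, 4, 8
x = torch.randn(b, h, s, d)
freqs_cis = torch.polar(torch.ones(s, d // 2), torch.randn(s, d // 2))  # complex [s, d/2]

x_c = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))  # pair up channels
out = torch.view_as_real(x_c * freqs_cis).flatten(3).type_as(x)  # rotate, then unpack
print(out.shape)  # torch.Size([1, 2, 4, 8])
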
class QwenTimestepProjEmbeddings(nn.Module):
def __init__(self, embedding_dim, pooled_projection_dim, dtype=None, device=None, operations=None):
super().__init__()
@@ -280,29 +269,26 @@ class Attention(nn.Module):
txt_query = self.norm_added_q(txt_query)
txt_key = self.norm_added_k(txt_key)
### NEW
#################################################
# Handle both tuple (EliGen) and single tensor (standard) RoPE formats
if isinstance(image_rotary_emb, tuple):
# EliGen path: Apply RoPE BEFORE concatenation (research-accurate)
# txt/img query/key are currently [b, s, h, d], need to rearrange to [b, h, s, d]
# txt/img query/key are in [b, s, h, d] format, compatible with apply_rotary_emb
img_rope, txt_rope = image_rotary_emb
# Rearrange to [b, h, s, d] for apply_rotary_emb_qwen
txt_query = txt_query.permute(0, 2, 1, 3) # [b, s, h, d] -> [b, h, s, d]
txt_key = txt_key.permute(0, 2, 1, 3)
img_query = img_query.permute(0, 2, 1, 3)
img_key = img_key.permute(0, 2, 1, 3)
# Add heads dimension to RoPE tensors for broadcasting
# Shape: [s, features, 2, 2] -> [s, 1, features, 2, 2]
# Also convert to match query dtype (e.g., bfloat16)
txt_rope = txt_rope.unsqueeze(1).to(dtype=txt_query.dtype)
img_rope = img_rope.unsqueeze(1).to(dtype=img_query.dtype)
# Apply RoPE separately to text and image using research function
txt_query = apply_rotary_emb_qwen(txt_query, txt_rope)
txt_key = apply_rotary_emb_qwen(txt_key, txt_rope)
img_query = apply_rotary_emb_qwen(img_query, img_rope)
img_key = apply_rotary_emb_qwen(img_key, img_rope)
# Rearrange back to [b, s, h, d]
txt_query = txt_query.permute(0, 2, 1, 3)
txt_key = txt_key.permute(0, 2, 1, 3)
img_query = img_query.permute(0, 2, 1, 3)
img_key = img_key.permute(0, 2, 1, 3)
# Apply RoPE separately to text and image streams
txt_query = apply_rotary_emb(txt_query, txt_rope)
txt_key = apply_rotary_emb(txt_key, txt_rope)
img_query = apply_rotary_emb(img_query, img_rope)
img_key = apply_rotary_emb(img_key, img_rope)
# Now concatenate
joint_query = torch.cat([txt_query, img_query], dim=1)
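
Shape-level sketch of this branch (hypothetical sizes; apply_rotary_emb is the function defined earlier in this file): RoPE is applied per stream in [b, s, h, d] layout, with a broadcast heads axis inserted into the freqs, and only then are the streams concatenated along the sequence dim, text first.

import torch

b, h, d = 1, 24, 128
s_txt, s_img = 77, 4096  # hypothetical sequence lengths

def apply_rotary_emb(x, freqs_cis):  # same math as the version above
    t_ = x.reshape(*x.shape[:-1], -1, 1, 2)
    t_out = freqs_cis[..., 0] * t_[..., 0] + freqs_cis[..., 1] * t_[..., 1]
    return t_out.reshape(*x.shape)

txt_q, img_q = torch.randn(b, s_txt, h, d), torch.randn(b, s_img, h, d)
txt_rope = torch.randn(s_txt, d // 2, 2, 2).unsqueeze(1)  # [s, 1, pairs, 2, 2]
img_rope = torch.randn(s_img, d // 2, 2, 2).unsqueeze(1)

joint_q = torch.cat([apply_rotary_emb(txt_q, txt_rope),
                     apply_rotary_emb(img_q, img_rope)], dim=1)
print(joint_q.shape)  # torch.Size([1, 4173, 24, 128])
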
@@ -317,13 +303,10 @@ class Attention(nn.Module):
joint_query = apply_rotary_emb(joint_query, image_rotary_emb)
joint_key = apply_rotary_emb(joint_key, image_rotary_emb)
# Check if we have an EliGen mask - if so, use PyTorch SDPA directly (research-accurate)
has_eligen_mask = False
effective_mask = attention_mask
if transformer_options is not None:
eligen_mask = transformer_options.get("eligen_attention_mask", None)
if eligen_mask is not None:
has_eligen_mask = True
effective_mask = eligen_mask
# Validate shape
@@ -331,30 +314,8 @@
if eligen_mask.shape[-1] != expected_seq:
raise ValueError(f"EliGen mask shape {eligen_mask.shape} doesn't match sequence length {expected_seq}")
if has_eligen_mask:
# EliGen path: Use PyTorch SDPA directly (matches research implementation exactly)
# Don't flatten - keep in [b, s, h, d] format for SDPA
# Reshape to [b, h, s, d] for SDPA
joint_query = joint_query.permute(0, 2, 1, 3) # [b, s, h, d] -> [b, h, s, d]
joint_key = joint_key.permute(0, 2, 1, 3)
joint_value = joint_value.permute(0, 2, 1, 3)
#################################################
print(f"[EliGen Debug Attention] Using PyTorch SDPA directly")
print(f" - Query shape: {joint_query.shape}")
print(f" - Mask shape: {effective_mask.shape}")
print(f" - Mask min/max: {effective_mask.min()} / {effective_mask.max():.2f}")
# Apply SDPA with mask (research-accurate)
joint_hidden_states = torch.nn.functional.scaled_dot_product_attention(
joint_query, joint_key, joint_value,
attn_mask=effective_mask,
dropout_p=0.0,
is_causal=False
)
# Reshape back: [b, h, s, d] -> [b, s, h*d]
joint_hidden_states = joint_hidden_states.permute(0, 2, 1, 3).flatten(start_dim=2)
else:
# Standard path: Use ComfyUI's optimized attention
joint_query = joint_query.flatten(start_dim=2)
joint_key = joint_key.flatten(start_dim=2)
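
For context on the SDPA branch deleted above: torch.nn.functional.scaled_dot_product_attention accepts an additive float mask broadcastable against [b, heads, s, s]. A minimal sketch with hypothetical sizes:

import torch
import torch.nn.functional as F

b, heads, s, d = 1, 4, 16, 32
q, k, v = (torch.randn(b, heads, s, d) for _ in range(3))

mask = torch.zeros(b, 1, s, s)  # 0.0 = attend, -inf = block
mask[..., :s // 2, s // 2:] = float("-inf")  # e.g. keep one token region from seeing another
out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
print(out.shape)  # torch.Size([1, 4, 16, 32])
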
@@ -482,7 +443,7 @@ class LastLayer(nn.Module):
x = torch.addcmul(shift[:, None, :], self.norm(x), (1 + scale)[:, None, :])
return x
### NEW changes
class QwenImageTransformer2DModel(nn.Module):
def __init__(
self,
@@ -564,6 +525,7 @@ class QwenImageTransformer2DModel(nn.Module):
def process_entity_masks(self, latents, prompt_emb, prompt_emb_mask, entity_prompt_emb,
entity_prompt_emb_mask, entity_masks, height, width, image):
"""
https://github.com/modelscope/DiffSynth-Studio/blob/538017177a6136f45f57cdf0b7c4e0d7e1f8b50d/diffsynth/models/qwen_image_dit.py#L434
Process entity masks and build spatial attention mask for EliGen.
This method:
@@ -634,7 +596,6 @@ class QwenImageTransformer2DModel(nn.Module):
padded_h = height // 8
padded_w = width // 8
if entity_masks.shape[3] != padded_h or entity_masks.shape[4] != padded_w:
# Validate masks aren't larger than expected (would cause negative padding)
assert entity_masks.shape[3] <= padded_h and entity_masks.shape[4] <= padded_w, \
f"Entity masks {entity_masks.shape[3]}x{entity_masks.shape[4]} larger than padded dims {padded_h}x{padded_w}"
@@ -772,36 +733,17 @@ class QwenImageTransformer2DModel(nn.Module):
entity_prompt_emb_mask = kwargs.get("entity_prompt_emb_mask", None)
entity_masks = kwargs.get("entity_masks", None)
# Debug logging (set ELIGEN_DEBUG=1 environment variable to enable)
if entity_prompt_emb is not None:
print(f"[EliGen Debug] Entity data found!")
print(f" - entity_prompt_emb type: {type(entity_prompt_emb)}, len: {len(entity_prompt_emb) if isinstance(entity_prompt_emb, list) else 'N/A'}")
print(f" - entity_masks shape: {entity_masks.shape if entity_masks is not None else 'None'}")
print(f" - Number of entities: {entity_masks.shape[1] if entity_masks is not None else 'Unknown'}")
# Check if this is positive or negative conditioning
cond_or_uncond = transformer_options.get("cond_or_uncond", []) if transformer_options else []
print(f" - Conditioning type: {['uncond' if c == 1 else 'cond' for c in cond_or_uncond]}")
else:
print(f"[EliGen Debug] No entity data in kwargs. Keys: {list(kwargs.keys())}")
# Branch: EliGen vs Standard path
# Only apply EliGen to POSITIVE conditioning (cond_or_uncond contains 0)
# Negative conditioning should use standard path
cond_or_uncond = transformer_options.get("cond_or_uncond", []) if transformer_options else []
is_positive_cond = 0 in cond_or_uncond # 0 = conditional/positive, 1 = unconditional/negative
if entity_prompt_emb is not None and entity_masks is not None and entity_prompt_emb_mask is not None and is_positive_cond:
# EliGen path - process entity masks (POSITIVE CONDITIONING ONLY)
# Note: Use padded dimensions from orig_shape, not original latent dimensions
# orig_shape is from process_img which pads to patch_size
height = int(orig_shape[-2] * 8) # Padded latent height -> pixel height (ensure int)
width = int(orig_shape[-1] * 8) # Padded latent width -> pixel width (ensure int)
print(f"[EliGen Debug] Original latent shape: {x.shape}")
print(f"[EliGen Debug] Padded latent shape (orig_shape): {orig_shape}")
print(f"[EliGen Debug] Calculated pixel dimensions: {height}x{width}")
print(f"[EliGen Debug] Expected patches: {height//16}x{width//16}")
# Call process_entity_masks to get concatenated text, RoPE, and attention mask
encoder_hidden_states, image_rotary_emb, eligen_attention_mask = self.process_entity_masks(
latents=x,

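Worked example of the dimension bookkeeping in the EliGen branch above (assuming the usual 8x VAE downscale and 16-pixel patch side, per the comments in the code):

orig_shape = (1, 16, 64, 64)  # hypothetical padded latent [b, c, h, w]
height, width = orig_shape[-2] * 8, orig_shape[-1] * 8  # 512 x 512 pixels
patches = (height // 16) * (width // 16)  # 32 * 32 = 1024 image tokens
print(height, width, patches)  # 512 512 1024
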
View File

@@ -1462,9 +1462,6 @@ class QwenImage(BaseModel):
entity_masks = kwargs.get("entity_masks", None)
if entity_masks is not None:
out['entity_masks'] = comfy.conds.CONDRegular(entity_masks)
# import pdb; pdb.set_trace()
return out
def extra_conds_shapes(self, **kwargs):

View File

@@ -105,7 +105,7 @@ class TextEncodeQwenImageEditPlus(io.ComfyNode):
conditioning = node_helpers.conditioning_set_values(conditioning, {"reference_latents": ref_latents}, append=True)
return io.NodeOutput(conditioning)
################ NEW
class TextEncodeQwenImageEliGen(io.ComfyNode):
@classmethod
def define_schema(cls):