working using ComfyUI's optimized attention and rotary embedding funcs

commit 0f4a141faf (parent 1d9124203f)
@@ -14,9 +14,8 @@ import comfy.patcher_extension

class QwenEmbedRope(nn.Module):
-    """Research-accurate RoPE implementation for EliGen.
-
-    This class matches the research pipeline's QwenEmbedRope exactly.
+    """RoPE implementation for EliGen.
+    https://github.com/modelscope/DiffSynth-Studio/blob/538017177a6136f45f57cdf0b7c4e0d7e1f8b50d/diffsynth/models/qwen_image_dit.py#L61
    Returns a tuple (img_freqs, txt_freqs) for separate image and text RoPE.
    """
    def __init__(self, theta: int, axes_dim: list, scale_rope=False):
@@ -42,14 +41,23 @@ class QwenEmbedRope(nn.Module):
        """
        Args:
            index: [0, 1, 2, 3] 1D Tensor representing the position index of the token
+
+        Returns:
+            Real-valued 2x2 rotation matrix format [..., 2, 2] compatible with apply_rotary_emb
        """
        assert dim % 2 == 0
        freqs = torch.outer(
            index,
            1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float32).div(dim))
        )
-        freqs = torch.polar(torch.ones_like(freqs), freqs)
-        return freqs
+        # Convert to real-valued rotation matrix format (matches Flux rope() output)
+        # Rotation matrix: [[cos, -sin], [sin, cos]]
+        cos_freqs = torch.cos(freqs)
+        sin_freqs = torch.sin(freqs)
+        # Stack as rotation matrix: [cos, -sin, sin, cos] then reshape to [..., 2, 2]
+        out = torch.stack([cos_freqs, -sin_freqs, sin_freqs, cos_freqs], dim=-1)
+        out = out.reshape(*freqs.shape, 2, 2)
+        return out

    def _expand_pos_freqs_if_needed(self, video_fhw, txt_seq_lens):
        if isinstance(video_fhw, list):
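This is the core change of the commit: instead of complex `torch.polar` frequencies consumed by a bespoke apply function, the RoPE table is materialized as real 2x2 rotation matrices, the same format ComfyUI's generic `apply_rotary_emb` (and Flux's `rope()`) already consumes. A minimal standalone sketch, not part of the commit, checking that the two representations rotate identically:

```python
import torch

def rope_matrices(index, dim, theta=10000):
    # Same construction as the new code in the hunk above.
    freqs = torch.outer(index, 1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float32).div(dim)))
    cos_f, sin_f = torch.cos(freqs), torch.sin(freqs)
    out = torch.stack([cos_f, -sin_f, sin_f, cos_f], dim=-1)
    return out.reshape(*freqs.shape, 2, 2)  # [len(index), dim/2, 2, 2]

index = torch.arange(8, dtype=torch.float32)
dim = 16
mats = rope_matrices(index, dim)

# Apply via the rotation-matrix contraction used by apply_rotary_emb.
x = torch.randn(8, dim)
t_ = x.reshape(8, -1, 1, 2)
new = (mats[..., 0] * t_[..., 0] + mats[..., 1] * t_[..., 1]).reshape(8, dim)

# Old complex path: torch.polar frequencies times pairs viewed as complex numbers.
angles = torch.outer(index, 1.0 / torch.pow(torch.tensor(10000.0), torch.arange(0, dim, 2).to(torch.float32).div(dim)))
old = torch.view_as_real(torch.view_as_complex(x.reshape(8, -1, 2)) * torch.polar(torch.ones_like(angles), angles)).reshape(8, dim)

assert torch.allclose(new, old, atol=1e-5)
```

Each 2x2 block is the rotation [[cosθ, -sinθ], [sinθ, cosθ]], so the contraction over the last axis reproduces the complex multiplication by e^{iθ} exactly.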
@@ -108,7 +116,7 @@ class QwenEmbedRope(nn.Module):
            freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1)
            freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1)

-            freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1)
+            freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1, 2, 2)
            self.rope_cache[rope_key] = freqs.clone().contiguous()
            vid_freqs.append(self.rope_cache[rope_key])
@@ -166,30 +174,11 @@ class FeedForward(nn.Module):
def apply_rotary_emb(x, freqs_cis):
    if x.shape[1] == 0:
        return x

    t_ = x.reshape(*x.shape[:-1], -1, 1, 2)
    t_out = freqs_cis[..., 0] * t_[..., 0] + freqs_cis[..., 1] * t_[..., 1]
    return t_out.reshape(*x.shape)


-def apply_rotary_emb_qwen(x: torch.Tensor, freqs_cis: torch.Tensor):
-    """
-    Research-accurate RoPE application for QwenEmbedRope.
-
-    Args:
-        x: Input tensor with shape [b, h, s, d] (batch, heads, sequence, dim)
-        freqs_cis: Complex frequency tensor with shape [s, features] from QwenEmbedRope
-
-    Returns:
-        Rotated tensor with same shape as input
-    """
-    # x shape: [b, h, s, d]
-    # freqs_cis shape: [s, features] where features = d (complex numbers)
-    x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
-    x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3)
-    return x_out.type_as(x)
-
-
class QwenTimestepProjEmbeddings(nn.Module):
    def __init__(self, embedding_dim, pooled_projection_dim, dtype=None, device=None, operations=None):
        super().__init__()
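With the rotation matrices in real form, `apply_rotary_emb` broadcasts over whatever leading layout the tensors have, which is why the dedicated `apply_rotary_emb_qwen` and its `[b, s, h, d] -> [b, h, s, d]` round-trip can be deleted. A sketch with hypothetical sizes showing the broadcast directly on `[b, s, h, d]` tensors:

```python
import torch

def apply_rotary_emb(x, freqs_cis):
    # Same contraction as in the diff above.
    if x.shape[1] == 0:
        return x
    t_ = x.reshape(*x.shape[:-1], -1, 1, 2)
    t_out = freqs_cis[..., 0] * t_[..., 0] + freqs_cis[..., 1] * t_[..., 1]
    return t_out.reshape(*x.shape)

# Hypothetical sizes: batch 2, 77 tokens, 24 heads, head_dim 128.
b, s, h, d = 2, 77, 24, 128
q = torch.randn(b, s, h, d)
rope = torch.randn(s, d // 2, 2, 2)  # [s, features, 2, 2] from QwenEmbedRope
rope = rope.unsqueeze(1)             # [s, 1, features, 2, 2], as in the Attention hunk below

out = apply_rotary_emb(q, rope)
print(out.shape)  # torch.Size([2, 77, 24, 128])
```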
@@ -280,29 +269,26 @@ class Attention(nn.Module):
        txt_query = self.norm_added_q(txt_query)
        txt_key = self.norm_added_k(txt_key)

        ### NEW
        #################################################

        # Handle both tuple (EliGen) and single tensor (standard) RoPE formats
        if isinstance(image_rotary_emb, tuple):
            # EliGen path: Apply RoPE BEFORE concatenation (research-accurate)
-            # txt/img query/key are currently [b, s, h, d], need to rearrange to [b, h, s, d]
+            # txt/img query/key are in [b, s, h, d] format, compatible with apply_rotary_emb
            img_rope, txt_rope = image_rotary_emb

-            # Rearrange to [b, h, s, d] for apply_rotary_emb_qwen
-            txt_query = txt_query.permute(0, 2, 1, 3)  # [b, s, h, d] -> [b, h, s, d]
-            txt_key = txt_key.permute(0, 2, 1, 3)
-            img_query = img_query.permute(0, 2, 1, 3)
-            img_key = img_key.permute(0, 2, 1, 3)
-            # Apply RoPE separately to text and image using research function
-            txt_query = apply_rotary_emb_qwen(txt_query, txt_rope)
-            txt_key = apply_rotary_emb_qwen(txt_key, txt_rope)
-            img_query = apply_rotary_emb_qwen(img_query, img_rope)
-            img_key = apply_rotary_emb_qwen(img_key, img_rope)
-            # Rearrange back to [b, s, h, d]
-            txt_query = txt_query.permute(0, 2, 1, 3)
-            txt_key = txt_key.permute(0, 2, 1, 3)
-            img_query = img_query.permute(0, 2, 1, 3)
-            img_key = img_key.permute(0, 2, 1, 3)
+            # Add heads dimension to RoPE tensors for broadcasting
+            # Shape: [s, features, 2, 2] -> [s, 1, features, 2, 2]
+            # Also convert to match query dtype (e.g., bfloat16)
+            txt_rope = txt_rope.unsqueeze(1).to(dtype=txt_query.dtype)
+            img_rope = img_rope.unsqueeze(1).to(dtype=img_query.dtype)
+
+            # Apply RoPE separately to text and image streams
+            txt_query = apply_rotary_emb(txt_query, txt_rope)
+            txt_key = apply_rotary_emb(txt_key, txt_rope)
+            img_query = apply_rotary_emb(img_query, img_rope)
+            img_key = apply_rotary_emb(img_key, img_rope)

            # Now concatenate
            joint_query = torch.cat([txt_query, img_query], dim=1)
@@ -317,13 +303,10 @@ class Attention(nn.Module):
            joint_query = apply_rotary_emb(joint_query, image_rotary_emb)
            joint_key = apply_rotary_emb(joint_key, image_rotary_emb)

-        # Check if we have an EliGen mask - if so, use PyTorch SDPA directly (research-accurate)
-        has_eligen_mask = False
        effective_mask = attention_mask
        if transformer_options is not None:
            eligen_mask = transformer_options.get("eligen_attention_mask", None)
            if eligen_mask is not None:
-                has_eligen_mask = True
                effective_mask = eligen_mask

                # Validate shape
@@ -331,30 +314,8 @@ class Attention(nn.Module):
                if eligen_mask.shape[-1] != expected_seq:
                    raise ValueError(f"EliGen mask shape {eligen_mask.shape} doesn't match sequence length {expected_seq}")

-        if has_eligen_mask:
-            # EliGen path: Use PyTorch SDPA directly (matches research implementation exactly)
-            # Don't flatten - keep in [b, s, h, d] format for SDPA
-            # Reshape to [b, h, s, d] for SDPA
-            joint_query = joint_query.permute(0, 2, 1, 3)  # [b, s, h, d] -> [b, h, s, d]
-            joint_key = joint_key.permute(0, 2, 1, 3)
-            joint_value = joint_value.permute(0, 2, 1, 3)
-            #################################################
-
-            print(f"[EliGen Debug Attention] Using PyTorch SDPA directly")
-            print(f"  - Query shape: {joint_query.shape}")
-            print(f"  - Mask shape: {effective_mask.shape}")
-            print(f"  - Mask min/max: {effective_mask.min()} / {effective_mask.max():.2f}")
-
-            # Apply SDPA with mask (research-accurate)
-            joint_hidden_states = torch.nn.functional.scaled_dot_product_attention(
-                joint_query, joint_key, joint_value,
-                attn_mask=effective_mask,
-                dropout_p=0.0,
-                is_causal=False
-            )
-
-            # Reshape back: [b, h, s, d] -> [b, s, h*d]
-            joint_hidden_states = joint_hidden_states.permute(0, 2, 1, 3).flatten(start_dim=2)
-        else:
-            # Standard path: Use ComfyUI's optimized attention
        joint_query = joint_query.flatten(start_dim=2)
        joint_key = joint_key.flatten(start_dim=2)
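The deleted branch documents the mask semantics that remain in play: `eligen_attention_mask` is an additive float mask over the joint text+image sequence, 0 where attention is allowed and -inf where an entity's text tokens must not attend outside their region; after this commit, `effective_mask` is presumably handed to ComfyUI's optimized attention instead of a raw SDPA call. A standalone sketch of those semantics with hypothetical sizes:

```python
import torch
import torch.nn.functional as F

b, heads, s, d = 1, 4, 6, 8
q, k, v = (torch.randn(b, heads, s, d) for _ in range(3))

# Additive mask, broadcastable to [b, heads, s, s].
mask = torch.zeros(b, 1, s, s)
mask[:, :, :2, 4:] = float("-inf")  # block the first 2 tokens from the last 2 positions

out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
print(out.shape)  # torch.Size([1, 4, 6, 8])
```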
@@ -482,7 +443,7 @@ class LastLayer(nn.Module):
        x = torch.addcmul(shift[:, None, :], self.norm(x), (1 + scale)[:, None, :])
        return x


### NEW changes
class QwenImageTransformer2DModel(nn.Module):
    def __init__(
        self,
@@ -564,6 +525,7 @@ class QwenImageTransformer2DModel(nn.Module):
    def process_entity_masks(self, latents, prompt_emb, prompt_emb_mask, entity_prompt_emb,
                             entity_prompt_emb_mask, entity_masks, height, width, image):
        """
+        https://github.com/modelscope/DiffSynth-Studio/blob/538017177a6136f45f57cdf0b7c4e0d7e1f8b50d/diffsynth/models/qwen_image_dit.py#L434
        Process entity masks and build spatial attention mask for EliGen.

        This method:
@@ -634,7 +596,6 @@ class QwenImageTransformer2DModel(nn.Module):
        padded_h = height // 8
        padded_w = width // 8
        if entity_masks.shape[3] != padded_h or entity_masks.shape[4] != padded_w:
            # Validate masks aren't larger than expected (would cause negative padding)
            assert entity_masks.shape[3] <= padded_h and entity_masks.shape[4] <= padded_w, \
                f"Entity masks {entity_masks.shape[3]}x{entity_masks.shape[4]} larger than padded dims {padded_h}x{padded_w}"
@@ -772,36 +733,17 @@ class QwenImageTransformer2DModel(nn.Module):
        entity_prompt_emb_mask = kwargs.get("entity_prompt_emb_mask", None)
        entity_masks = kwargs.get("entity_masks", None)

-        # Debug logging (set ELIGEN_DEBUG=1 environment variable to enable)
-        if entity_prompt_emb is not None:
-            print(f"[EliGen Debug] Entity data found!")
-            print(f"  - entity_prompt_emb type: {type(entity_prompt_emb)}, len: {len(entity_prompt_emb) if isinstance(entity_prompt_emb, list) else 'N/A'}")
-            print(f"  - entity_masks shape: {entity_masks.shape if entity_masks is not None else 'None'}")
-            print(f"  - Number of entities: {entity_masks.shape[1] if entity_masks is not None else 'Unknown'}")
-            # Check if this is positive or negative conditioning
-            cond_or_uncond = transformer_options.get("cond_or_uncond", []) if transformer_options else []
-            print(f"  - Conditioning type: {['uncond' if c == 1 else 'cond' for c in cond_or_uncond]}")
-        else:
-            print(f"[EliGen Debug] No entity data in kwargs. Keys: {list(kwargs.keys())}")
-
        # Branch: EliGen vs Standard path
        # Only apply EliGen to POSITIVE conditioning (cond_or_uncond contains 0)
        # Negative conditioning should use standard path
        cond_or_uncond = transformer_options.get("cond_or_uncond", []) if transformer_options else []
        is_positive_cond = 0 in cond_or_uncond  # 0 = conditional/positive, 1 = unconditional/negative

        if entity_prompt_emb is not None and entity_masks is not None and entity_prompt_emb_mask is not None and is_positive_cond:
            # EliGen path - process entity masks (POSITIVE CONDITIONING ONLY)
            # Note: Use padded dimensions from orig_shape, not original latent dimensions
            # orig_shape is from process_img which pads to patch_size
            height = int(orig_shape[-2] * 8)  # Padded latent height -> pixel height (ensure int)
            width = int(orig_shape[-1] * 8)  # Padded latent width -> pixel width (ensure int)

-            print(f"[EliGen Debug] Original latent shape: {x.shape}")
-            print(f"[EliGen Debug] Padded latent shape (orig_shape): {orig_shape}")
-            print(f"[EliGen Debug] Calculated pixel dimensions: {height}x{width}")
-            print(f"[EliGen Debug] Expected patches: {height//16}x{width//16}")
-
            # Call process_entity_masks to get concatenated text, RoPE, and attention mask
            encoder_hidden_states, image_rotary_emb, eligen_attention_mask = self.process_entity_masks(
                latents=x,
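The `*8` and `//16` factors encode the pipeline's scale assumptions (8x VAE downscale, 2x2 latent patchification), which the removed debug prints used to spell out. A worked example with hypothetical shapes:

```python
# Assumed 8x VAE downscale and 2x2 latent patches, matching the factors above.
orig_shape = (1, 16, 1, 128, 128)      # padded latent [b, c, t, h, w], hypothetical
height = int(orig_shape[-2] * 8)       # 1024 pixel rows
width = int(orig_shape[-1] * 8)        # 1024 pixel cols
patches = (height // 16, width // 16)  # (64, 64) image tokens per axis
print(height, width, patches)
```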
@@ -1462,9 +1462,6 @@ class QwenImage(BaseModel):
        entity_masks = kwargs.get("entity_masks", None)
        if entity_masks is not None:
            out['entity_masks'] = comfy.conds.CONDRegular(entity_masks)

-        # import pdb; pdb.set_trace()
-
        return out

    def extra_conds_shapes(self, **kwargs):
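`extra_conds` is ComfyUI's hook for threading extra tensors from the conditioning dict into the model call; wrapping them lets the sampler batch them like any other cond. A toy stand-in for the pattern (simplified; the real wrapper is `comfy.conds.CONDRegular`):

```python
import torch

class CONDRegular:
    # Simplified stand-in for comfy.conds.CONDRegular, which also
    # handles per-step batching/concatenation of the wrapped tensor.
    def __init__(self, cond):
        self.cond = cond

kwargs = {"entity_masks": torch.ones(1, 2, 1, 64, 64)}  # hypothetical input
out = {}
entity_masks = kwargs.get("entity_masks", None)
if entity_masks is not None:
    out["entity_masks"] = CONDRegular(entity_masks)
```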
@@ -105,7 +105,7 @@ class TextEncodeQwenImageEditPlus(io.ComfyNode):
        conditioning = node_helpers.conditioning_set_values(conditioning, {"reference_latents": ref_latents}, append=True)
        return io.NodeOutput(conditioning)


################ NEW
class TextEncodeQwenImageEliGen(io.ComfyNode):
    @classmethod
    def define_schema(cls):