diff --git a/comfy/ldm/qwen_image/model.py b/comfy/ldm/qwen_image/model.py
index 7ac45c9a9..896a22e19 100644
--- a/comfy/ldm/qwen_image/model.py
+++ b/comfy/ldm/qwen_image/model.py
@@ -14,9 +14,8 @@
 import comfy.patcher_extension
 
 class QwenEmbedRope(nn.Module):
-    """Research-accurate RoPE implementation for EliGen.
-
-    This class matches the research pipeline's QwenEmbedRope exactly.
+    """RoPE implementation for EliGen.
+    https://github.com/modelscope/DiffSynth-Studio/blob/538017177a6136f45f57cdf0b7c4e0d7e1f8b50d/diffsynth/models/qwen_image_dit.py#L61
     Returns a tuple (img_freqs, txt_freqs) for separate image and text RoPE.
     """
     def __init__(self, theta: int, axes_dim: list, scale_rope=False):
@@ -42,14 +41,23 @@ class QwenEmbedRope(nn.Module):
         """
         Args:
             index: [0, 1, 2, 3] 1D Tensor representing the position index of the token
+
+        Returns:
+            Real-valued 2x2 rotation matrix format [..., 2, 2] compatible with apply_rotary_emb
         """
         assert dim % 2 == 0
         freqs = torch.outer(
             index,
             1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float32).div(dim))
         )
-        freqs = torch.polar(torch.ones_like(freqs), freqs)
-        return freqs
+        # Convert to real-valued rotation matrix format (matches Flux rope() output)
+        # Rotation matrix: [[cos, -sin], [sin, cos]]
+        cos_freqs = torch.cos(freqs)
+        sin_freqs = torch.sin(freqs)
+        # Stack as rotation matrix: [cos, -sin, sin, cos] then reshape to [..., 2, 2]
+        out = torch.stack([cos_freqs, -sin_freqs, sin_freqs, cos_freqs], dim=-1)
+        out = out.reshape(*freqs.shape, 2, 2)
+        return out
 
     def _expand_pos_freqs_if_needed(self, video_fhw, txt_seq_lens):
         if isinstance(video_fhw, list):
@@ -108,7 +116,7 @@ class QwenEmbedRope(nn.Module):
             freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1)
             freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1)
 
-            freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1)
+            freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1, 2, 2)
             self.rope_cache[rope_key] = freqs.clone().contiguous()
 
         vid_freqs.append(self.rope_cache[rope_key])
@@ -166,30 +174,11 @@ class FeedForward(nn.Module):
 
 def apply_rotary_emb(x, freqs_cis):
     if x.shape[1] == 0:
         return x
-
     t_ = x.reshape(*x.shape[:-1], -1, 1, 2)
     t_out = freqs_cis[..., 0] * t_[..., 0] + freqs_cis[..., 1] * t_[..., 1]
     return t_out.reshape(*x.shape)
 
-
-def apply_rotary_emb_qwen(x: torch.Tensor, freqs_cis: torch.Tensor):
-    """
-    Research-accurate RoPE application for QwenEmbedRope.
-
-    Args:
-        x: Input tensor with shape [b, h, s, d] (batch, heads, sequence, dim)
-        freqs_cis: Complex frequency tensor with shape [s, features] from QwenEmbedRope
-
-    Returns:
-        Rotated tensor with same shape as input
-    """
-    # x shape: [b, h, s, d]
-    # freqs_cis shape: [s, features] where features = d (complex numbers)
-    x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
-    x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3)
-    return x_out.type_as(x)
-
-
 class QwenTimestepProjEmbeddings(nn.Module):
     def __init__(self, embedding_dim, pooled_projection_dim, dtype=None, device=None, operations=None):
         super().__init__()
@@ -280,29 +269,24 @@ class Attention(nn.Module):
         txt_query = self.norm_added_q(txt_query)
         txt_key = self.norm_added_k(txt_key)
 
+        # Handle both tuple (EliGen) and single tensor (standard) RoPE formats
         if isinstance(image_rotary_emb, tuple):
             # EliGen path: Apply RoPE BEFORE concatenation (research-accurate)
-            # txt/img query/key are currently [b, s, h, d], need to rearrange to [b, h, s, d]
+            # txt/img query/key are in [b, s, h, d] format, compatible with apply_rotary_emb
             img_rope, txt_rope = image_rotary_emb
 
-            # Rearrange to [b, h, s, d] for apply_rotary_emb_qwen
-            txt_query = txt_query.permute(0, 2, 1, 3)  # [b, s, h, d] -> [b, h, s, d]
-            txt_key = txt_key.permute(0, 2, 1, 3)
-            img_query = img_query.permute(0, 2, 1, 3)
-            img_key = img_key.permute(0, 2, 1, 3)
+            # Add heads dimension to RoPE tensors for broadcasting
+            # Shape: [s, features, 2, 2] -> [s, 1, features, 2, 2]
+            # Also convert to match query dtype (e.g., bfloat16)
+            txt_rope = txt_rope.unsqueeze(1).to(dtype=txt_query.dtype)
+            img_rope = img_rope.unsqueeze(1).to(dtype=img_query.dtype)
 
-            # Apply RoPE separately to text and image using research function
-            txt_query = apply_rotary_emb_qwen(txt_query, txt_rope)
-            txt_key = apply_rotary_emb_qwen(txt_key, txt_rope)
-            img_query = apply_rotary_emb_qwen(img_query, img_rope)
-            img_key = apply_rotary_emb_qwen(img_key, img_rope)
-
-            # Rearrange back to [b, s, h, d]
-            txt_query = txt_query.permute(0, 2, 1, 3)
-            txt_key = txt_key.permute(0, 2, 1, 3)
-            img_query = img_query.permute(0, 2, 1, 3)
-            img_key = img_key.permute(0, 2, 1, 3)
+            # Apply RoPE separately to text and image streams
+            txt_query = apply_rotary_emb(txt_query, txt_rope)
+            txt_key = apply_rotary_emb(txt_key, txt_rope)
+            img_query = apply_rotary_emb(img_query, img_rope)
+            img_key = apply_rotary_emb(img_key, img_rope)
 
             # Now concatenate
             joint_query = torch.cat([txt_query, img_query], dim=1)
@@ -317,50 +303,23 @@ class Attention(nn.Module):
             joint_query = apply_rotary_emb(joint_query, image_rotary_emb)
             joint_key = apply_rotary_emb(joint_key, image_rotary_emb)
 
-        # Check if we have an EliGen mask - if so, use PyTorch SDPA directly (research-accurate)
-        has_eligen_mask = False
         effective_mask = attention_mask
         if transformer_options is not None:
             eligen_mask = transformer_options.get("eligen_attention_mask", None)
             if eligen_mask is not None:
-                has_eligen_mask = True
                 effective_mask = eligen_mask
                 # Validate shape
                 expected_seq = joint_query.shape[1]
                 if eligen_mask.shape[-1] != expected_seq:
                     raise ValueError(f"EliGen mask shape {eligen_mask.shape} doesn't match sequence length {expected_seq}")
 
-        if has_eligen_mask:
-            # EliGen path: Use PyTorch SDPA directly (matches research implementation exactly)
-            # Don't flatten - keep in [b, s, h, d] format for SDPA
-            # Reshape to [b, h, s, d] for SDPA
-            joint_query = joint_query.permute(0, 2, 1, 3)  # [b, s, h, d] -> [b, h, s, d]
-            joint_key = joint_key.permute(0, 2, 1, 3)
-            joint_value = joint_value.permute(0, 2, 1, 3)
+        # Standard path: Use ComfyUI's optimized attention
+        joint_query = joint_query.flatten(start_dim=2)
+        joint_key = joint_key.flatten(start_dim=2)
+        joint_value = joint_value.flatten(start_dim=2)
 
-            print(f"[EliGen Debug Attention] Using PyTorch SDPA directly")
-            print(f" - Query shape: {joint_query.shape}")
-            print(f" - Mask shape: {effective_mask.shape}")
-            print(f" - Mask min/max: {effective_mask.min()} / {effective_mask.max():.2f}")
-
-            # Apply SDPA with mask (research-accurate)
-            joint_hidden_states = torch.nn.functional.scaled_dot_product_attention(
-                joint_query, joint_key, joint_value,
-                attn_mask=effective_mask,
-                dropout_p=0.0,
-                is_causal=False
-            )
-
-            # Reshape back: [b, h, s, d] -> [b, s, h*d]
-            joint_hidden_states = joint_hidden_states.permute(0, 2, 1, 3).flatten(start_dim=2)
-        else:
-            # Standard path: Use ComfyUI's optimized attention
-            joint_query = joint_query.flatten(start_dim=2)
-            joint_key = joint_key.flatten(start_dim=2)
-            joint_value = joint_value.flatten(start_dim=2)
-
-            joint_hidden_states = optimized_attention_masked(joint_query, joint_key, joint_value, self.heads, effective_mask, transformer_options=transformer_options)
+        joint_hidden_states = optimized_attention_masked(joint_query, joint_key, joint_value, self.heads, effective_mask, transformer_options=transformer_options)
 
         txt_attn_output = joint_hidden_states[:, :seq_txt, :]
         img_attn_output = joint_hidden_states[:, seq_txt:, :]
@@ -564,6 +525,7 @@ class QwenImageTransformer2DModel(nn.Module):
 
     def process_entity_masks(self, latents, prompt_emb, prompt_emb_mask, entity_prompt_emb, entity_prompt_emb_mask, entity_masks, height, width, image):
         """
+        https://github.com/modelscope/DiffSynth-Studio/blob/538017177a6136f45f57cdf0b7c4e0d7e1f8b50d/diffsynth/models/qwen_image_dit.py#L434
         Process entity masks and build spatial attention mask for EliGen.
 
         This method:
@@ -634,7 +596,6 @@ class QwenImageTransformer2DModel(nn.Module):
         padded_h = height // 8
         padded_w = width // 8
         if entity_masks.shape[3] != padded_h or entity_masks.shape[4] != padded_w:
-            # Validate masks aren't larger than expected (would cause negative padding)
             assert entity_masks.shape[3] <= padded_h and entity_masks.shape[4] <= padded_w, \
                 f"Entity masks {entity_masks.shape[3]}x{entity_masks.shape[4]} larger than padded dims {padded_h}x{padded_w}"
 
@@ -772,36 +733,17 @@ class QwenImageTransformer2DModel(nn.Module):
         entity_prompt_emb_mask = kwargs.get("entity_prompt_emb_mask", None)
         entity_masks = kwargs.get("entity_masks", None)
 
-        # Debug logging (set ELIGEN_DEBUG=1 environment variable to enable)
-        if entity_prompt_emb is not None:
-            print(f"[EliGen Debug] Entity data found!")
-            print(f" - entity_prompt_emb type: {type(entity_prompt_emb)}, len: {len(entity_prompt_emb) if isinstance(entity_prompt_emb, list) else 'N/A'}")
-            print(f" - entity_masks shape: {entity_masks.shape if entity_masks is not None else 'None'}")
-            print(f" - Number of entities: {entity_masks.shape[1] if entity_masks is not None else 'Unknown'}")
-            # Check if this is positive or negative conditioning
-            cond_or_uncond = transformer_options.get("cond_or_uncond", []) if transformer_options else []
-            print(f" - Conditioning type: {['uncond' if c == 1 else 'cond' for c in cond_or_uncond]}")
-        else:
-            print(f"[EliGen Debug] No entity data in kwargs. Keys: {list(kwargs.keys())}")
-
         # Branch: EliGen vs Standard path
         # Only apply EliGen to POSITIVE conditioning (cond_or_uncond contains 0)
-        # Negative conditioning should use standard path
        cond_or_uncond = transformer_options.get("cond_or_uncond", []) if transformer_options else []
         is_positive_cond = 0 in cond_or_uncond  # 0 = conditional/positive, 1 = unconditional/negative
 
         if entity_prompt_emb is not None and entity_masks is not None and entity_prompt_emb_mask is not None and is_positive_cond:
             # EliGen path - process entity masks (POSITIVE CONDITIONING ONLY)
-            # Note: Use padded dimensions from orig_shape, not original latent dimensions
             # orig_shape is from process_img which pads to patch_size
             height = int(orig_shape[-2] * 8)  # Padded latent height -> pixel height (ensure int)
             width = int(orig_shape[-1] * 8)  # Padded latent width -> pixel width (ensure int)
 
-            print(f"[EliGen Debug] Original latent shape: {x.shape}")
-            print(f"[EliGen Debug] Padded latent shape (orig_shape): {orig_shape}")
-            print(f"[EliGen Debug] Calculated pixel dimensions: {height}x{width}")
-            print(f"[EliGen Debug] Expected patches: {height//16}x{width//16}")
-
             # Call process_entity_masks to get concatenated text, RoPE, and attention mask
             encoder_hidden_states, image_rotary_emb, eligen_attention_mask = self.process_entity_masks(
                 latents=x,
diff --git a/comfy/model_base.py b/comfy/model_base.py
index 869fd75bd..9a4010843 100644
--- a/comfy/model_base.py
+++ b/comfy/model_base.py
@@ -1462,9 +1462,6 @@ class QwenImage(BaseModel):
         entity_masks = kwargs.get("entity_masks", None)
         if entity_masks is not None:
             out['entity_masks'] = comfy.conds.CONDRegular(entity_masks)
-
-        # import pdb; pdb.set_trace()
-
         return out
 
     def extra_conds_shapes(self, **kwargs):
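
Note on the RoPE format change (not part of the patch): the model.py diff replaces QwenEmbedRope's complex torch.polar output with the real-valued [..., 2, 2] rotation-matrix format that the shared apply_rotary_emb already consumes, which is what allows apply_rotary_emb_qwen to be deleted. The two formats encode the same rotation. A minimal standalone check of that equivalence, using only the expressions visible in the diff (the variable names below are illustrative, not from the patch):

    import torch

    torch.manual_seed(0)
    b, h, s, d, theta = 1, 2, 8, 16, 10000
    index = torch.arange(s, dtype=torch.float32)

    # Shared frequency table, as built by the diff's rope helper
    freqs = torch.outer(index, 1.0 / torch.pow(theta, torch.arange(0, d, 2).to(torch.float32).div(d)))

    x = torch.randn(b, h, s, d)

    # Old path: complex e^{i*theta}, applied as in the removed apply_rotary_emb_qwen
    fc = torch.polar(torch.ones_like(freqs), freqs)              # [s, d/2], complex
    xc = torch.view_as_complex(x.reshape(*x.shape[:-1], -1, 2))  # [b, h, s, d/2]
    old = torch.view_as_real(xc * fc).flatten(3)                 # [b, h, s, d]

    # New path: real [[cos, -sin], [sin, cos]] matrices, applied as in apply_rotary_emb
    cos_f, sin_f = torch.cos(freqs), torch.sin(freqs)
    fm = torch.stack([cos_f, -sin_f, sin_f, cos_f], dim=-1).reshape(s, -1, 2, 2)
    t_ = x.reshape(*x.shape[:-1], -1, 1, 2)
    new = (fm[..., 0] * t_[..., 0] + fm[..., 1] * t_[..., 1]).reshape(*x.shape)

    print(torch.allclose(old, new, atol=1e-5))  # True

Because apply_rotary_emb broadcasts freqs_cis against x, the same tensors also work in the [b, s, h, d] layout used in Attention once a heads dimension is unsqueezed in, which is exactly what the diff does before the per-stream calls.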
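
Note on the attention consolidation (not part of the patch): the removed has_eligen_mask branch and the surviving optimized_attention_masked call compute the same masked attention; they differ only in tensor layout ([b, h, s, d] SDPA versus flattened [b, s, h*d]). The sketch below uses a pure-torch stand-in, masked_sdpa, for ComfyUI's internal optimized_attention_masked (a hypothetical equivalent for illustration, not the actual implementation) to show why the explicit SDPA branch was redundant:

    import torch
    import torch.nn.functional as F

    torch.manual_seed(0)
    b, heads, s, dh = 1, 4, 10, 8
    q, k, v = (torch.randn(b, s, heads, dh) for _ in range(3))

    # Additive float mask: 0 = attend, -inf = blocked, broadcastable to [b, heads, s, s]
    # (the removed min/max debug print suggests the EliGen mask is in this format)
    mask = torch.zeros(1, 1, s, s)
    mask[..., :3, 5:] = float("-inf")

    # Removed branch: permute to [b, heads, s, dh], run SDPA, permute back and flatten
    ref = F.scaled_dot_product_attention(
        q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3),
        attn_mask=mask,
    ).permute(0, 2, 1, 3).flatten(start_dim=2)

    def masked_sdpa(q, k, v, heads, mask):
        # Stand-in for optimized_attention_masked: [b, s, heads*dh] in and out
        b, s, _ = q.shape
        q, k, v = (t.view(b, s, heads, -1).transpose(1, 2) for t in (q, k, v))
        out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
        return out.transpose(1, 2).reshape(b, s, -1)

    out = masked_sdpa(q.flatten(start_dim=2), k.flatten(start_dim=2),
                      v.flatten(start_dim=2), heads, mask)
    print(torch.allclose(ref, out, atol=1e-6))  # True

Since the mask is purely additive, it can be routed unchanged through any SDPA-style backend that accepts attn_mask, which is why a single optimized_attention_masked call now serves both the standard and EliGen paths.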