From ffe3503370598e16f882fccd74b60f0a349a1d81 Mon Sep 17 00:00:00 2001 From: nolan4 Date: Wed, 22 Oct 2025 22:20:43 -0700 Subject: [PATCH 01/12] appearing functional without rigorous testing --- comfy/ldm/qwen_image/model.py | 536 ++++++++++++++++++++++++++++++++-- comfy/model_base.py | 14 + comfy_extras/nodes_qwen.py | 176 +++++++++++ 3 files changed, 708 insertions(+), 18 deletions(-) diff --git a/comfy/ldm/qwen_image/model.py b/comfy/ldm/qwen_image/model.py index b9f60c2b7..47fa7a5f6 100644 --- a/comfy/ldm/qwen_image/model.py +++ b/comfy/ldm/qwen_image/model.py @@ -2,8 +2,9 @@ import torch import torch.nn as nn import torch.nn.functional as F +import math from typing import Optional, Tuple -from einops import repeat +from einops import repeat, rearrange from comfy.ldm.lightricks.model import TimestepEmbedding, Timesteps from comfy.ldm.modules.attention import optimized_attention_masked @@ -11,6 +12,118 @@ from comfy.ldm.flux.layers import EmbedND import comfy.ldm.common_dit import comfy.patcher_extension + +class QwenEmbedRope(nn.Module): + """Research-accurate RoPE implementation for EliGen. + + This class matches the research pipeline's QwenEmbedRope exactly. + Returns a tuple (img_freqs, txt_freqs) for separate image and text RoPE. + """ + def __init__(self, theta: int, axes_dim: list, scale_rope=False): + super().__init__() + self.theta = theta + self.axes_dim = axes_dim + pos_index = torch.arange(4096) + neg_index = torch.arange(4096).flip(0) * -1 - 1 + self.pos_freqs = torch.cat([ + self.rope_params(pos_index, self.axes_dim[0], self.theta), + self.rope_params(pos_index, self.axes_dim[1], self.theta), + self.rope_params(pos_index, self.axes_dim[2], self.theta), + ], dim=1) + self.neg_freqs = torch.cat([ + self.rope_params(neg_index, self.axes_dim[0], self.theta), + self.rope_params(neg_index, self.axes_dim[1], self.theta), + self.rope_params(neg_index, self.axes_dim[2], self.theta), + ], dim=1) + self.rope_cache = {} + self.scale_rope = scale_rope + + def rope_params(self, index, dim, theta=10000): + """ + Args: + index: [0, 1, 2, 3] 1D Tensor representing the position index of the token + """ + assert dim % 2 == 0 + freqs = torch.outer( + index, + 1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float32).div(dim)) + ) + freqs = torch.polar(torch.ones_like(freqs), freqs) + return freqs + + def _expand_pos_freqs_if_needed(self, video_fhw, txt_seq_lens): + if isinstance(video_fhw, list): + video_fhw = tuple(max([i[j] for i in video_fhw]) for j in range(3)) + _, height, width = video_fhw + if self.scale_rope: + max_vid_index = max(height // 2, width // 2) + else: + max_vid_index = max(height, width) + required_len = max_vid_index + max(txt_seq_lens) + cur_max_len = self.pos_freqs.shape[0] + if required_len <= cur_max_len: + return + + new_max_len = math.ceil(required_len / 512) * 512 + pos_index = torch.arange(new_max_len) + neg_index = torch.arange(new_max_len).flip(0) * -1 - 1 + self.pos_freqs = torch.cat([ + self.rope_params(pos_index, self.axes_dim[0], self.theta), + self.rope_params(pos_index, self.axes_dim[1], self.theta), + self.rope_params(pos_index, self.axes_dim[2], self.theta), + ], dim=1) + self.neg_freqs = torch.cat([ + self.rope_params(neg_index, self.axes_dim[0], self.theta), + self.rope_params(neg_index, self.axes_dim[1], self.theta), + self.rope_params(neg_index, self.axes_dim[2], self.theta), + ], dim=1) + return + + def forward(self, video_fhw, txt_seq_lens, device): + self._expand_pos_freqs_if_needed(video_fhw, txt_seq_lens) + if self.pos_freqs.device != 
device: + self.pos_freqs = self.pos_freqs.to(device) + self.neg_freqs = self.neg_freqs.to(device) + + vid_freqs = [] + max_vid_index = 0 + for idx, fhw in enumerate(video_fhw): + frame, height, width = fhw + rope_key = f"{idx}_{height}_{width}" + + if rope_key not in self.rope_cache: + seq_lens = frame * height * width + freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1) + freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1) + freqs_frame = freqs_pos[0][idx : idx + frame].view(frame, 1, 1, -1).expand(frame, height, width, -1) + if self.scale_rope: + freqs_height = torch.cat( + [freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]], dim=0 + ) + freqs_height = freqs_height.view(1, height, 1, -1).expand(frame, height, width, -1) + freqs_width = torch.cat([freqs_neg[2][-(width - width // 2) :], freqs_pos[2][: width // 2]], dim=0) + freqs_width = freqs_width.view(1, 1, width, -1).expand(frame, height, width, -1) + + else: + freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1) + freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1) + + freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1) + self.rope_cache[rope_key] = freqs.clone().contiguous() + vid_freqs.append(self.rope_cache[rope_key]) + + if self.scale_rope: + max_vid_index = max(height // 2, width // 2, max_vid_index) + else: + max_vid_index = max(height, width, max_vid_index) + + max_len = max(txt_seq_lens) + txt_freqs = self.pos_freqs[max_vid_index : max_vid_index + max_len, ...] + vid_freqs = torch.cat(vid_freqs, dim=0) + + return vid_freqs, txt_freqs + + class GELU(nn.Module): def __init__(self, dim_in: int, dim_out: int, approximate: str = "none", bias: bool = True, dtype=None, device=None, operations=None): super().__init__() @@ -59,6 +172,24 @@ def apply_rotary_emb(x, freqs_cis): return t_out.reshape(*x.shape) +def apply_rotary_emb_qwen(x: torch.Tensor, freqs_cis: torch.Tensor): + """ + Research-accurate RoPE application for QwenEmbedRope. 
+ + Args: + x: Input tensor with shape [b, h, s, d] (batch, heads, sequence, dim) + freqs_cis: Complex frequency tensor with shape [s, features] from QwenEmbedRope + + Returns: + Rotated tensor with same shape as input + """ + # x shape: [b, h, s, d] + # freqs_cis shape: [s, features] where features = d (complex numbers) + x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2)) + x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3) + return x_out.type_as(x) + + class QwenTimestepProjEmbeddings(nn.Module): def __init__(self, embedding_dim, pooled_projection_dim, dtype=None, device=None, operations=None): super().__init__() @@ -149,18 +280,89 @@ class Attention(nn.Module): txt_query = self.norm_added_q(txt_query) txt_key = self.norm_added_k(txt_key) - joint_query = torch.cat([txt_query, img_query], dim=1) - joint_key = torch.cat([txt_key, img_key], dim=1) - joint_value = torch.cat([txt_value, img_value], dim=1) + # Handle both tuple (EliGen) and single tensor (standard) RoPE formats + if isinstance(image_rotary_emb, tuple): + # EliGen path: Apply RoPE BEFORE concatenation (research-accurate) + # txt/img query/key are currently [b, s, h, d], need to rearrange to [b, h, s, d] + img_rope, txt_rope = image_rotary_emb - joint_query = apply_rotary_emb(joint_query, image_rotary_emb) - joint_key = apply_rotary_emb(joint_key, image_rotary_emb) + # Rearrange to [b, h, s, d] for apply_rotary_emb_qwen + txt_query = txt_query.permute(0, 2, 1, 3) # [b, s, h, d] -> [b, h, s, d] + txt_key = txt_key.permute(0, 2, 1, 3) + img_query = img_query.permute(0, 2, 1, 3) + img_key = img_key.permute(0, 2, 1, 3) - joint_query = joint_query.flatten(start_dim=2) - joint_key = joint_key.flatten(start_dim=2) - joint_value = joint_value.flatten(start_dim=2) + # Apply RoPE separately to text and image using research function + txt_query = apply_rotary_emb_qwen(txt_query, txt_rope) + txt_key = apply_rotary_emb_qwen(txt_key, txt_rope) + img_query = apply_rotary_emb_qwen(img_query, img_rope) + img_key = apply_rotary_emb_qwen(img_key, img_rope) - joint_hidden_states = optimized_attention_masked(joint_query, joint_key, joint_value, self.heads, attention_mask, transformer_options=transformer_options) + # Rearrange back to [b, s, h, d] + txt_query = txt_query.permute(0, 2, 1, 3) + txt_key = txt_key.permute(0, 2, 1, 3) + img_query = img_query.permute(0, 2, 1, 3) + img_key = img_key.permute(0, 2, 1, 3) + + # Now concatenate + joint_query = torch.cat([txt_query, img_query], dim=1) + joint_key = torch.cat([txt_key, img_key], dim=1) + joint_value = torch.cat([txt_value, img_value], dim=1) + else: + # Standard path: Concatenate first, then apply RoPE + joint_query = torch.cat([txt_query, img_query], dim=1) + joint_key = torch.cat([txt_key, img_key], dim=1) + joint_value = torch.cat([txt_value, img_value], dim=1) + + joint_query = apply_rotary_emb(joint_query, image_rotary_emb) + joint_key = apply_rotary_emb(joint_key, image_rotary_emb) + + # Check if we have an EliGen mask - if so, use PyTorch SDPA directly (research-accurate) + has_eligen_mask = False + effective_mask = attention_mask + if transformer_options is not None: + eligen_mask = transformer_options.get("eligen_attention_mask", None) + if eligen_mask is not None: + has_eligen_mask = True + effective_mask = eligen_mask + + # Validate shape + expected_seq = joint_query.shape[1] + if eligen_mask.shape[-1] != expected_seq: + raise ValueError(f"EliGen mask shape {eligen_mask.shape} doesn't match sequence length {expected_seq}") + + if has_eligen_mask: + # 
EliGen path: Use PyTorch SDPA directly (matches research implementation exactly) + # Don't flatten - keep in [b, s, h, d] format for SDPA + # Reshape to [b, h, s, d] for SDPA + joint_query = joint_query.permute(0, 2, 1, 3) # [b, s, h, d] -> [b, h, s, d] + joint_key = joint_key.permute(0, 2, 1, 3) + joint_value = joint_value.permute(0, 2, 1, 3) + + import os + if os.environ.get("ELIGEN_DEBUG"): + print(f"[EliGen Debug Attention] Using PyTorch SDPA directly") + print(f" - Query shape: {joint_query.shape}") + print(f" - Mask shape: {effective_mask.shape}") + print(f" - Mask min/max: {effective_mask.min()} / {effective_mask.max():.2f}") + + # Apply SDPA with mask (research-accurate) + joint_hidden_states = torch.nn.functional.scaled_dot_product_attention( + joint_query, joint_key, joint_value, + attn_mask=effective_mask, + dropout_p=0.0, + is_causal=False + ) + + # Reshape back: [b, h, s, d] -> [b, s, h*d] + joint_hidden_states = joint_hidden_states.permute(0, 2, 1, 3).flatten(start_dim=2) + else: + # Standard path: Use ComfyUI's optimized attention + joint_query = joint_query.flatten(start_dim=2) + joint_key = joint_key.flatten(start_dim=2) + joint_value = joint_value.flatten(start_dim=2) + + joint_hidden_states = optimized_attention_masked(joint_query, joint_key, joint_value, self.heads, effective_mask, transformer_options=transformer_options) txt_attn_output = joint_hidden_states[:, :seq_txt, :] img_attn_output = joint_hidden_states[:, seq_txt:, :] @@ -310,6 +512,8 @@ class QwenImageTransformer2DModel(nn.Module): self.inner_dim = num_attention_heads * attention_head_dim self.pe_embedder = EmbedND(dim=attention_head_dim, theta=10000, axes_dim=list(axes_dims_rope)) + # Add research-accurate RoPE for EliGen (returns tuple of img_freqs, txt_freqs) + self.pos_embed = QwenEmbedRope(theta=10000, axes_dim=[16, 56, 56], scale_rope=True) self.time_text_embed = QwenTimestepProjEmbeddings( embedding_dim=self.inner_dim, @@ -359,6 +563,235 @@ class QwenImageTransformer2DModel(nn.Module): img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0) - (w_len // 2) return hidden_states, repeat(img_ids, "h w c -> b (h w) c", b=bs), orig_shape + def process_entity_masks(self, latents, prompt_emb, prompt_emb_mask, entity_prompt_emb, + entity_prompt_emb_mask, entity_masks, height, width, image): + """ + Process entity masks and build spatial attention mask for EliGen. + + This method: + 1. Concatenates entity + global prompts + 2. Builds RoPE embeddings for concatenated text using ComfyUI's pe_embedder + 3. 
Creates attention mask enforcing spatial restrictions + + Args: + latents: [B, 16, H, W] + prompt_emb: [1, seq_len, 3584] - Global prompt + prompt_emb_mask: [1, seq_len] + entity_prompt_emb: List[[1, L_i, 3584]] - Entity prompts + entity_prompt_emb_mask: List[[1, L_i]] + entity_masks: [1, N, 1, H/8, W/8] + height: int + width: int + image: [B, patches, 64] - Patchified latents + + Returns: + all_prompt_emb: [1, total_seq, 3584] + image_rotary_emb: RoPE embeddings + attention_mask: [1, 1, total_seq, total_seq] + """ + + # SECTION 1: Concatenate entity + global prompts + all_prompt_emb = entity_prompt_emb + [prompt_emb] + all_prompt_emb = [self.txt_in(self.txt_norm(p)) for p in all_prompt_emb] + all_prompt_emb = torch.cat(all_prompt_emb, dim=1) + + # SECTION 2: Build RoPE position embeddings (RESEARCH-ACCURATE using QwenEmbedRope) + # Calculate img_shapes for RoPE (batch, height//16, width//16 for images in latent space after patchifying) + img_shapes = [(latents.shape[0], height//16, width//16)] + + # Calculate sequence lengths for entities and global prompt (RESEARCH-ACCURATE) + # Research code: seq_lens = [mask_.sum(dim=1).item() for mask_ in entity_prompt_emb_mask] + [prompt_emb_mask.sum(dim=1).item()] + entity_seq_lens = [int(mask.sum(dim=1).item()) for mask in entity_prompt_emb_mask] + + # Handle None case in ComfyUI (None means no padding, all tokens valid) + if prompt_emb_mask is not None: + global_seq_len = int(prompt_emb_mask.sum(dim=1).item()) + else: + # No mask = no padding, use full sequence length + global_seq_len = int(prompt_emb.shape[1]) + + # Get base image RoPE using global prompt length (returns tuple: (img_freqs, txt_freqs)) + # RESEARCH: image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens, device=latents.device) + txt_seq_lens = [global_seq_len] + image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens, device=latents.device) + + # Create SEPARATE RoPE embeddings for each entity (EXACTLY like research) + # RESEARCH: entity_rotary_emb = [self.pos_embed(img_shapes, entity_seq_len, device=latents.device)[1] for entity_seq_len in entity_seq_lens] + entity_rotary_emb = [] + + import os + debug = os.environ.get("ELIGEN_DEBUG") + + for i, entity_seq_len in enumerate(entity_seq_lens): + # Pass as list for compatibility with research API + entity_rope = self.pos_embed(img_shapes, [entity_seq_len], device=latents.device)[1] + entity_rotary_emb.append(entity_rope) + if debug: + print(f"[EliGen Debug RoPE] Entity {i} RoPE shape: {entity_rope.shape}, seq_len: {entity_seq_len}") + + if debug: + print(f"[EliGen Debug RoPE] Global RoPE shape: {image_rotary_emb[1].shape}, seq_len: {global_seq_len}") + print(f"[EliGen Debug RoPE] Attempting to concatenate {len(entity_rotary_emb)} entity RoPEs + 1 global RoPE") + + # Concatenate entity RoPEs with global RoPE along sequence dimension (EXACTLY like research) + # QwenEmbedRope returns 2D tensors with shape [seq_len, features] + # Entity ropes: [entity_seq_len, features] + # Global rope: [global_seq_len, features] + # Concatenate along dim=0 to get [total_seq_len, features] + # RESEARCH: txt_rotary_emb = torch.cat(entity_rotary_emb + [image_rotary_emb[1]], dim=0) + txt_rotary_emb = torch.cat(entity_rotary_emb + [image_rotary_emb[1]], dim=0) + + # Replace text part of tuple (EXACTLY like research) + # RESEARCH: image_rotary_emb = (image_rotary_emb[0], txt_rotary_emb) + image_rotary_emb = (image_rotary_emb[0], txt_rotary_emb) + + # Debug output for RoPE embeddings + import os + if os.environ.get("ELIGEN_DEBUG"): + print(f"[EliGen 
Debug RoPE] Number of entities: {len(entity_seq_lens)}") + print(f"[EliGen Debug RoPE] Entity sequence lengths: {entity_seq_lens}") + print(f"[EliGen Debug RoPE] Global sequence length: {global_seq_len}") + print(f"[EliGen Debug RoPE] img_rotary_emb (tuple[0]) shape: {image_rotary_emb[0].shape}") + print(f"[EliGen Debug RoPE] txt_rotary_emb (tuple[1]) shape: {image_rotary_emb[1].shape}") + print(f"[EliGen Debug RoPE] Total text seq length: {sum(entity_seq_lens) + global_seq_len}") + + # SECTION 3: Prepare spatial masks + repeat_dim = latents.shape[1] # 16 + max_masks = entity_masks.shape[1] # N entities + entity_masks = entity_masks.repeat(1, 1, repeat_dim, 1, 1) + + # Pad masks to match padded latent dimensions (same as process_img does) + # entity_masks shape: [1, N, 16, H/8, W/8] + # Need to pad to match orig_shape which is [B, 16, padded_H/8, padded_W/8] + padded_h = height // 8 + padded_w = width // 8 + if entity_masks.shape[3] != padded_h or entity_masks.shape[4] != padded_w: + # Validate masks aren't larger than expected (would cause negative padding) + assert entity_masks.shape[3] <= padded_h and entity_masks.shape[4] <= padded_w, \ + f"Entity masks {entity_masks.shape[3]}x{entity_masks.shape[4]} larger than padded dims {padded_h}x{padded_w}" + + # Pad each entity mask + pad_h = padded_h - entity_masks.shape[3] + pad_w = padded_w - entity_masks.shape[4] + entity_masks = torch.nn.functional.pad(entity_masks, (0, pad_w, 0, pad_h), mode='constant', value=0) + + entity_masks = [entity_masks[:, i, None].squeeze(1) for i in range(max_masks)] + + # Add global mask (all True) - must be same size as padded entity masks + global_mask = torch.ones((entity_masks[0].shape[0], entity_masks[0].shape[1], padded_h, padded_w), + device=latents.device, dtype=latents.dtype) + entity_masks = entity_masks + [global_mask] + + # SECTION 4: Patchify masks + N = len(entity_masks) + batch_size = int(entity_masks[0].shape[0]) + seq_lens = entity_seq_lens + [global_seq_len] + total_seq_len = int(sum(seq_lens) + image.shape[1]) + + # Debug: Check mask dimensions + import os + if os.environ.get("ELIGEN_DEBUG"): + print(f"[EliGen Debug Patchify] entity_masks[0] shape: {entity_masks[0].shape}") + print(f"[EliGen Debug Patchify] height={height}, width={width}, height//16={height//16}, width//16={width//16}") + print(f"[EliGen Debug Patchify] Expected mask size: {height//16 * 2} x {width//16 * 2} = {(height//16) * 2} x {(width//16) * 2}") + + patched_masks = [] + for i in range(N): + patched_mask = rearrange( + entity_masks[i], + "B C (H P) (W Q) -> B (H W) (C P Q)", + H=height//16, W=width//16, P=2, Q=2 + ) + patched_masks.append(patched_mask) + + # SECTION 5: Build attention mask matrix + attention_mask = torch.ones( + (batch_size, total_seq_len, total_seq_len), + dtype=torch.bool + ).to(device=entity_masks[0].device) + + # Calculate positions + image_start = int(sum(seq_lens)) + image_end = int(total_seq_len) + cumsum = [0] + single_image_seq = int(image_end - image_start) + + for length in seq_lens: + cumsum.append(cumsum[-1] + length) + + # RULE 1: Spatial restriction (prompt <-> image) + for i in range(N): + prompt_start = cumsum[i] + prompt_end = cumsum[i+1] + + # Create binary mask for which image patches this entity can attend to + image_mask = torch.sum(patched_masks[i], dim=-1) > 0 + image_mask = image_mask.unsqueeze(1).repeat(1, seq_lens[i], 1) + + # Always repeat mask to match image sequence length (matches DiffSynth line 480) + repeat_time = single_image_seq // image_mask.shape[-1] + image_mask = 
image_mask.repeat(1, 1, repeat_time) + + # Bidirectional restriction: + # - Entity prompt can only attend to its masked image regions + attention_mask[:, prompt_start:prompt_end, image_start:image_end] = image_mask + # - Image patches can only be updated by prompts that own them + attention_mask[:, image_start:image_end, prompt_start:prompt_end] = image_mask.transpose(1, 2) + + # RULE 2: Entity isolation + for i in range(N): + for j in range(N): + if i == j: + continue + start_i, end_i = cumsum[i], cumsum[i+1] + start_j, end_j = cumsum[j], cumsum[j+1] + attention_mask[:, start_i:end_i, start_j:end_j] = False + + # SECTION 6: Convert to additive bias + attention_mask = attention_mask.float() + attention_mask[attention_mask == 0] = float('-inf') + attention_mask[attention_mask == 1] = 0 + attention_mask = attention_mask.to(device=latents.device, dtype=latents.dtype).unsqueeze(1) + + if debug: + print(f"\n[EliGen Debug Mask Values]") + print(f" Token ranges:") + for i in range(len(seq_lens)): + if i < len(seq_lens) - 1: + print(f" - Entity {i} tokens: {cumsum[i]}-{cumsum[i+1]-1} (length: {seq_lens[i]})") + else: + print(f" - Global tokens: {cumsum[i]}-{cumsum[i+1]-1} (length: {seq_lens[i]})") + print(f" - Image tokens: {sum(seq_lens)}-{total_seq_len-1}") + + print(f"\n Checking Entity 0 connections:") + # Entity 0 to itself (should be 0) + e0_to_e0 = attention_mask[0, 0, cumsum[0]:cumsum[1], cumsum[0]:cumsum[1]] + print(f" - Entity0->Entity0: {(e0_to_e0 == 0).sum()}/{e0_to_e0.numel()} allowed") + + # Entity 0 to Entity 1 (should be -inf) + if len(seq_lens) > 2: + e0_to_e1 = attention_mask[0, 0, cumsum[0]:cumsum[1], cumsum[1]:cumsum[2]] + print(f" - Entity0->Entity1: {(e0_to_e1 == float('-inf')).sum()}/{e0_to_e1.numel()} blocked") + + # Entity 0 to Global (should be -inf) + e0_to_global = attention_mask[0, 0, cumsum[0]:cumsum[1], cumsum[-2]:cumsum[-1]] + print(f" - Entity0->Global: {(e0_to_global == float('-inf')).sum()}/{e0_to_global.numel()} blocked") + + # Entity 0 to Image (should be partially blocked based on mask) + e0_to_img = attention_mask[0, 0, cumsum[0]:cumsum[1], image_start:] + print(f" - Entity0->Image: {(e0_to_img == 0).sum()}/{e0_to_img.numel()} allowed, {(e0_to_img == float('-inf')).sum()} blocked") + + # Image to Entity 0 (should match Entity 0 to Image, transposed) + img_to_e0 = attention_mask[0, 0, image_start:, cumsum[0]:cumsum[1]] + print(f" - Image->Entity0: {(img_to_e0 == 0).sum()}/{img_to_e0.numel()} allowed") + + # Global to Image (should be fully allowed) + global_to_img = attention_mask[0, 0, cumsum[-2]:cumsum[-1], image_start:] + print(f"\n Checking Global connections:") + print(f" - Global->Image: {(global_to_img == 0).sum()}/{global_to_img.numel()} allowed") + + return all_prompt_emb, image_rotary_emb, attention_mask + def forward(self, x, timestep, context, attention_mask=None, guidance=None, ref_latents=None, transformer_options={}, **kwargs): return comfy.patcher_extension.WrapperExecutor.new_class_executor( self._forward, @@ -410,15 +843,82 @@ class QwenImageTransformer2DModel(nn.Module): hidden_states = torch.cat([hidden_states, kontext], dim=1) img_ids = torch.cat([img_ids, kontext_ids], dim=1) - txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size) // 2, ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size) // 2)) - txt_ids = torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 3) - ids = torch.cat((txt_ids, img_ids), dim=1) - image_rotary_emb = 
self.pe_embedder(ids).squeeze(1).unsqueeze(2).to(x.dtype) - del ids, txt_ids, img_ids + # Extract entity data from kwargs + entity_prompt_emb = kwargs.get("entity_prompt_emb", None) + entity_prompt_emb_mask = kwargs.get("entity_prompt_emb_mask", None) + entity_masks = kwargs.get("entity_masks", None) - hidden_states = self.img_in(hidden_states) - encoder_hidden_states = self.txt_norm(encoder_hidden_states) - encoder_hidden_states = self.txt_in(encoder_hidden_states) + # import pdb; pdb.set_trace() + + + # Debug logging (set ELIGEN_DEBUG=1 environment variable to enable) + import os + if os.environ.get("ELIGEN_DEBUG"): + if entity_prompt_emb is not None: + print(f"[EliGen Debug] Entity data found!") + print(f" - entity_prompt_emb type: {type(entity_prompt_emb)}, len: {len(entity_prompt_emb) if isinstance(entity_prompt_emb, list) else 'N/A'}") + print(f" - entity_masks shape: {entity_masks.shape if entity_masks is not None else 'None'}") + print(f" - Number of entities: {entity_masks.shape[1] if entity_masks is not None else 'Unknown'}") + # Check if this is positive or negative conditioning + cond_or_uncond = transformer_options.get("cond_or_uncond", []) if transformer_options else [] + print(f" - Conditioning type: {['uncond' if c == 1 else 'cond' for c in cond_or_uncond]}") + else: + print(f"[EliGen Debug] No entity data in kwargs. Keys: {list(kwargs.keys())}") + + # Branch: EliGen vs Standard path + # Only apply EliGen to POSITIVE conditioning (cond_or_uncond contains 0) + # Negative conditioning should use standard path + cond_or_uncond = transformer_options.get("cond_or_uncond", []) if transformer_options else [] + is_positive_cond = 0 in cond_or_uncond # 0 = conditional/positive, 1 = unconditional/negative + + if entity_prompt_emb is not None and entity_masks is not None and entity_prompt_emb_mask is not None and is_positive_cond: + # EliGen path - process entity masks (POSITIVE CONDITIONING ONLY) + # Note: Use padded dimensions from orig_shape, not original latent dimensions + # orig_shape is from process_img which pads to patch_size + height = int(orig_shape[-2] * 8) # Padded latent height -> pixel height (ensure int) + width = int(orig_shape[-1] * 8) # Padded latent width -> pixel width (ensure int) + + if os.environ.get("ELIGEN_DEBUG"): + print(f"[EliGen Debug] Original latent shape: {x.shape}") + print(f"[EliGen Debug] Padded latent shape (orig_shape): {orig_shape}") + print(f"[EliGen Debug] Calculated pixel dimensions: {height}x{width}") + print(f"[EliGen Debug] Expected patches: {height//16}x{width//16}") + + # Call process_entity_masks to get concatenated text, RoPE, and attention mask + encoder_hidden_states, image_rotary_emb, eligen_attention_mask = self.process_entity_masks( + latents=x, + prompt_emb=encoder_hidden_states, + prompt_emb_mask=encoder_hidden_states_mask, + entity_prompt_emb=entity_prompt_emb, + entity_prompt_emb_mask=entity_prompt_emb_mask, + entity_masks=entity_masks, + height=height, + width=width, + image=hidden_states + ) + + # Apply image projection (text already processed in process_entity_masks) + hidden_states = self.img_in(hidden_states) + + # Store attention mask in transformer_options for the attention layers + if transformer_options is None: + transformer_options = {} + transformer_options["eligen_attention_mask"] = eligen_attention_mask + + # Clean up + del img_ids + + else: + # Standard path - existing code + txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size) // 2, ((x.shape[-2] + (self.patch_size // 2)) // 
self.patch_size) // 2)) + txt_ids = torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 3) + ids = torch.cat((txt_ids, img_ids), dim=1) + image_rotary_emb = self.pe_embedder(ids).squeeze(1).unsqueeze(2).to(x.dtype) + del ids, txt_ids, img_ids + + hidden_states = self.img_in(hidden_states) + encoder_hidden_states = self.txt_norm(encoder_hidden_states) + encoder_hidden_states = self.txt_in(encoder_hidden_states) if guidance is not None: guidance = guidance * 1000 diff --git a/comfy/model_base.py b/comfy/model_base.py index 8274c7dea..050e10c98 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -1462,6 +1462,20 @@ class QwenImage(BaseModel): ref_latents_method = kwargs.get("reference_latents_method", None) if ref_latents_method is not None: out['ref_latents_method'] = comfy.conds.CONDConstant(ref_latents_method) + + # Handle EliGen entity data + entity_prompt_emb = kwargs.get("entity_prompt_emb", None) + if entity_prompt_emb is not None: + out['entity_prompt_emb'] = entity_prompt_emb # Already wrapped in CONDList by node + + entity_prompt_emb_mask = kwargs.get("entity_prompt_emb_mask", None) + if entity_prompt_emb_mask is not None: + out['entity_prompt_emb_mask'] = entity_prompt_emb_mask # Already wrapped in CONDList by node + + entity_masks = kwargs.get("entity_masks", None) + if entity_masks is not None: + out['entity_masks'] = entity_masks # Already wrapped in CONDRegular by node + return out def extra_conds_shapes(self, **kwargs): diff --git a/comfy_extras/nodes_qwen.py b/comfy_extras/nodes_qwen.py index 525239ae5..184fdfcff 100644 --- a/comfy_extras/nodes_qwen.py +++ b/comfy_extras/nodes_qwen.py @@ -1,6 +1,8 @@ import node_helpers import comfy.utils +import comfy.conds import math +import torch from typing_extensions import override from comfy_api.latest import ComfyExtension, io @@ -104,12 +106,186 @@ class TextEncodeQwenImageEditPlus(io.ComfyNode): return io.NodeOutput(conditioning) +class TextEncodeQwenImageEliGen(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="TextEncodeQwenImageEliGen", + category="advanced/conditioning", + inputs=[ + io.Clip.Input("clip"), + io.Conditioning.Input("global_conditioning"), + io.Latent.Input("latent"), + io.Image.Input("entity_mask_1", optional=True), + io.String.Input("entity_prompt_1", multiline=True, dynamic_prompts=True, default=""), + io.Image.Input("entity_mask_2", optional=True), + io.String.Input("entity_prompt_2", multiline=True, dynamic_prompts=True, default=""), + io.Image.Input("entity_mask_3", optional=True), + io.String.Input("entity_prompt_3", multiline=True, dynamic_prompts=True, default=""), + ], + outputs=[ + io.Conditioning.Output(), + ], + ) + + @classmethod + def execute(cls, clip, global_conditioning, latent, entity_prompt_1="", entity_mask_1=None, + entity_prompt_2="", entity_mask_2=None, entity_prompt_3="", entity_mask_3=None) -> io.NodeOutput: + + # Extract dimensions from latent tensor + # latent["samples"] shape: [batch, channels, latent_h, latent_w] + latent_samples = latent["samples"] + unpadded_latent_height = latent_samples.shape[2] # Unpadded latent space + unpadded_latent_width = latent_samples.shape[3] # Unpadded latent space + + # Calculate padded dimensions (same logic as model's pad_to_patch_size with patch_size=2) + # The model pads latents to be multiples of patch_size (2 for Qwen) + patch_size = 2 + pad_h = (patch_size - unpadded_latent_height % patch_size) % patch_size + pad_w = (patch_size - 
unpadded_latent_width % patch_size) % patch_size + latent_height = unpadded_latent_height + pad_h # Padded latent dimensions + latent_width = unpadded_latent_width + pad_w # Padded latent dimensions + + height = latent_height * 8 # Convert to pixel space for logging + width = latent_width * 8 + + if pad_h > 0 or pad_w > 0: + print(f"[EliGen] Latent padding detected: {unpadded_latent_height}x{unpadded_latent_width} → {latent_height}x{latent_width}") + print(f"[EliGen] Target generation dimensions: {height}x{width} pixels ({latent_height}x{latent_width} latent)") + + # Collect entity prompts and masks + entity_prompts = [entity_prompt_1, entity_prompt_2, entity_prompt_3] + entity_masks_raw = [entity_mask_1, entity_mask_2, entity_mask_3] + + # Filter out entities with empty prompts or missing masks + valid_entities = [] + for prompt, mask in zip(entity_prompts, entity_masks_raw): + if prompt.strip() and mask is not None: + valid_entities.append((prompt, mask)) + + # Log warning if some entities were skipped + total_prompts_provided = len([p for p in entity_prompts if p.strip()]) + if len(valid_entities) < total_prompts_provided: + print(f"[EliGen] Warning: Only {len(valid_entities)} of {total_prompts_provided} entity prompts have valid masks") + + # If no valid entities, return standard conditioning + if len(valid_entities) == 0: + return io.NodeOutput(global_conditioning) + + # Encode each entity prompt separately + entity_prompt_emb_list = [] + entity_prompt_emb_mask_list = [] + + for entity_prompt, _ in valid_entities: + entity_tokens = clip.tokenize(entity_prompt) + entity_cond = clip.encode_from_tokens_scheduled(entity_tokens) + + # Extract embeddings and masks from conditioning + # Conditioning format: [[cond_tensor, extra_dict], ...] + entity_prompt_emb = entity_cond[0][0] # The embedding tensor directly [1, seq_len, 3584] + extra_dict = entity_cond[0][1] # Metadata dict (pooled_output, attention_mask, etc.) 
+ + # Extract attention mask from metadata dict + entity_prompt_emb_mask = extra_dict.get("attention_mask", None) + + # If no attention mask in extra_dict, create one (all True) + if entity_prompt_emb_mask is None: + seq_len = entity_prompt_emb.shape[1] + entity_prompt_emb_mask = torch.ones((entity_prompt_emb.shape[0], seq_len), + dtype=torch.bool, device=entity_prompt_emb.device) + + entity_prompt_emb_list.append(entity_prompt_emb) + entity_prompt_emb_mask_list.append(entity_prompt_emb_mask) + + # Process spatial masks to latent space + processed_masks = [] + for i, (_, mask) in enumerate(valid_entities): + # mask is expected to be [batch, height, width, channels] or [batch, height, width] + mask_tensor = mask + + # Log original mask dimensions + original_shape = mask_tensor.shape + if len(original_shape) == 3: + orig_h, orig_w = original_shape[0], original_shape[1] + elif len(original_shape) == 4: + orig_h, orig_w = original_shape[1], original_shape[2] + else: + orig_h, orig_w = original_shape[-2], original_shape[-1] + + print(f"[EliGen] Entity {i+1} mask: {orig_h}x{orig_w} → will resize to {latent_height}x{latent_width} latent") + + # Ensure mask is in [batch, channels, height, width] format for upscale + if len(mask_tensor.shape) == 3: + # [height, width, channels] -> [1, height, width, channels] (add batch dimension) + mask_tensor = mask_tensor.unsqueeze(0) + elif len(mask_tensor.shape) == 4 and mask_tensor.shape[-1] in [1, 3, 4]: + # [batch, height, width, channels] -> [batch, channels, height, width] + mask_tensor = mask_tensor.movedim(-1, 1) + + # Take only first channel if multiple channels + if mask_tensor.shape[1] > 1: + mask_tensor = mask_tensor[:, 0:1, :, :] + + # Resize to latent space dimensions using nearest neighbor + resized_mask = comfy.utils.common_upscale( + mask_tensor, + latent_width, + latent_height, + upscale_method="nearest-exact", + crop="disabled" + ) + + # Threshold to binary (0 or 1) + # Use > 0 instead of > 0.5 to preserve edge pixels from nearest-neighbor downsampling + resized_mask = (resized_mask > 0).float() + + # Log how many pixels are active in the mask + active_pixels = (resized_mask > 0).sum().item() + total_pixels = resized_mask.numel() + print(f"[EliGen] Entity {i+1} mask coverage: {active_pixels}/{total_pixels} pixels ({100*active_pixels/total_pixels:.1f}%)") + + processed_masks.append(resized_mask) + + # Stack masks: [batch, num_entities, 1, latent_height, latent_width] + # No padding - handle dynamic number of entities + entity_masks_tensor = torch.stack(processed_masks, dim=1) + + # Extract global prompt embedding and mask from conditioning + # Conditioning format: [[cond_tensor, extra_dict]] + global_prompt_emb = global_conditioning[0][0] # The embedding tensor directly + global_extra_dict = global_conditioning[0][1] # Metadata dict + + global_prompt_emb_mask = global_extra_dict.get("attention_mask", None) + + # If no attention mask, create one (all True) + if global_prompt_emb_mask is None: + global_prompt_emb_mask = torch.ones((global_prompt_emb.shape[0], global_prompt_emb.shape[1]), + dtype=torch.bool, device=global_prompt_emb.device) + + # Attach entity data to conditioning using conditioning_set_values + # Wrap lists in CONDList so they can be properly concatenated during CFG + entity_data = { + "entity_prompt_emb": comfy.conds.CONDList(entity_prompt_emb_list), + "entity_prompt_emb_mask": comfy.conds.CONDList(entity_prompt_emb_mask_list), + "entity_masks": comfy.conds.CONDRegular(entity_masks_tensor), + } + + conditioning_with_entities = 
node_helpers.conditioning_set_values( + global_conditioning, + entity_data, + append=True + ) + + return io.NodeOutput(conditioning_with_entities) + + class QwenExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[io.ComfyNode]]: return [ TextEncodeQwenImageEdit, TextEncodeQwenImageEditPlus, + TextEncodeQwenImageEliGen, ] From 1d9124203f313a22ac7715d44dc22950fc8875e9 Mon Sep 17 00:00:00 2001 From: nolan4 Date: Thu, 23 Oct 2025 22:56:32 -0700 Subject: [PATCH 02/12] fixed application of entity-specific RoPE embeddings --- comfy/ldm/qwen_image/model.py | 140 +++++++--------------------------- comfy/model_base.py | 21 ++--- comfy_extras/nodes_qwen.py | 31 ++++---- 3 files changed, 47 insertions(+), 145 deletions(-) diff --git a/comfy/ldm/qwen_image/model.py b/comfy/ldm/qwen_image/model.py index 47fa7a5f6..7ac45c9a9 100644 --- a/comfy/ldm/qwen_image/model.py +++ b/comfy/ldm/qwen_image/model.py @@ -339,12 +339,10 @@ class Attention(nn.Module): joint_key = joint_key.permute(0, 2, 1, 3) joint_value = joint_value.permute(0, 2, 1, 3) - import os - if os.environ.get("ELIGEN_DEBUG"): - print(f"[EliGen Debug Attention] Using PyTorch SDPA directly") - print(f" - Query shape: {joint_query.shape}") - print(f" - Mask shape: {effective_mask.shape}") - print(f" - Mask min/max: {effective_mask.min()} / {effective_mask.max():.2f}") + print(f"[EliGen Debug Attention] Using PyTorch SDPA directly") + print(f" - Query shape: {joint_query.shape}") + print(f" - Mask shape: {effective_mask.shape}") + print(f" - Mask min/max: {effective_mask.min()} / {effective_mask.max():.2f}") # Apply SDPA with mask (research-accurate) joint_hidden_states = torch.nn.functional.scaled_dot_product_attention( @@ -592,15 +590,14 @@ class QwenImageTransformer2DModel(nn.Module): # SECTION 1: Concatenate entity + global prompts all_prompt_emb = entity_prompt_emb + [prompt_emb] - all_prompt_emb = [self.txt_in(self.txt_norm(p)) for p in all_prompt_emb] + all_prompt_emb = [self.txt_in(self.txt_norm(local_prompt_emb)) for local_prompt_emb in all_prompt_emb] all_prompt_emb = torch.cat(all_prompt_emb, dim=1) - # SECTION 2: Build RoPE position embeddings (RESEARCH-ACCURATE using QwenEmbedRope) + # SECTION 2: Build RoPE position embeddings # Calculate img_shapes for RoPE (batch, height//16, width//16 for images in latent space after patchifying) img_shapes = [(latents.shape[0], height//16, width//16)] - # Calculate sequence lengths for entities and global prompt (RESEARCH-ACCURATE) - # Research code: seq_lens = [mask_.sum(dim=1).item() for mask_ in entity_prompt_emb_mask] + [prompt_emb_mask.sum(dim=1).item()] + # Calculate sequence lengths for entities and global prompt entity_seq_lens = [int(mask.sum(dim=1).item()) for mask in entity_prompt_emb_mask] # Handle None case in ComfyUI (None means no padding, all tokens valid) @@ -611,56 +608,27 @@ class QwenImageTransformer2DModel(nn.Module): global_seq_len = int(prompt_emb.shape[1]) # Get base image RoPE using global prompt length (returns tuple: (img_freqs, txt_freqs)) - # RESEARCH: image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens, device=latents.device) txt_seq_lens = [global_seq_len] image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens, device=latents.device) - # Create SEPARATE RoPE embeddings for each entity (EXACTLY like research) - # RESEARCH: entity_rotary_emb = [self.pos_embed(img_shapes, entity_seq_len, device=latents.device)[1] for entity_seq_len in entity_seq_lens] - entity_rotary_emb = [] + # Create SEPARATE RoPE embeddings for each entity + # 
Each entity gets its own positional encoding based on its sequence length + entity_rotary_emb = [self.pos_embed(img_shapes, [entity_seq_len], device=latents.device)[1] + for entity_seq_len in entity_seq_lens] - import os - debug = os.environ.get("ELIGEN_DEBUG") - - for i, entity_seq_len in enumerate(entity_seq_lens): - # Pass as list for compatibility with research API - entity_rope = self.pos_embed(img_shapes, [entity_seq_len], device=latents.device)[1] - entity_rotary_emb.append(entity_rope) - if debug: - print(f"[EliGen Debug RoPE] Entity {i} RoPE shape: {entity_rope.shape}, seq_len: {entity_seq_len}") - - if debug: - print(f"[EliGen Debug RoPE] Global RoPE shape: {image_rotary_emb[1].shape}, seq_len: {global_seq_len}") - print(f"[EliGen Debug RoPE] Attempting to concatenate {len(entity_rotary_emb)} entity RoPEs + 1 global RoPE") - - # Concatenate entity RoPEs with global RoPE along sequence dimension (EXACTLY like research) - # QwenEmbedRope returns 2D tensors with shape [seq_len, features] - # Entity ropes: [entity_seq_len, features] - # Global rope: [global_seq_len, features] - # Concatenate along dim=0 to get [total_seq_len, features] - # RESEARCH: txt_rotary_emb = torch.cat(entity_rotary_emb + [image_rotary_emb[1]], dim=0) + # Concatenate entity RoPEs with global RoPE along sequence dimension + # Result: [entity1_seq, entity2_seq, ..., global_seq] concatenated txt_rotary_emb = torch.cat(entity_rotary_emb + [image_rotary_emb[1]], dim=0) - # Replace text part of tuple (EXACTLY like research) - # RESEARCH: image_rotary_emb = (image_rotary_emb[0], txt_rotary_emb) + # Replace text part of tuple with concatenated entity + global RoPE image_rotary_emb = (image_rotary_emb[0], txt_rotary_emb) - # Debug output for RoPE embeddings - import os - if os.environ.get("ELIGEN_DEBUG"): - print(f"[EliGen Debug RoPE] Number of entities: {len(entity_seq_lens)}") - print(f"[EliGen Debug RoPE] Entity sequence lengths: {entity_seq_lens}") - print(f"[EliGen Debug RoPE] Global sequence length: {global_seq_len}") - print(f"[EliGen Debug RoPE] img_rotary_emb (tuple[0]) shape: {image_rotary_emb[0].shape}") - print(f"[EliGen Debug RoPE] txt_rotary_emb (tuple[1]) shape: {image_rotary_emb[1].shape}") - print(f"[EliGen Debug RoPE] Total text seq length: {sum(entity_seq_lens) + global_seq_len}") - # SECTION 3: Prepare spatial masks repeat_dim = latents.shape[1] # 16 max_masks = entity_masks.shape[1] # N entities entity_masks = entity_masks.repeat(1, 1, repeat_dim, 1, 1) - # Pad masks to match padded latent dimensions (same as process_img does) + # Pad masks to match padded latent dimensions # entity_masks shape: [1, N, 16, H/8, W/8] # Need to pad to match orig_shape which is [B, 16, padded_H/8, padded_W/8] padded_h = height // 8 @@ -688,13 +656,6 @@ class QwenImageTransformer2DModel(nn.Module): seq_lens = entity_seq_lens + [global_seq_len] total_seq_len = int(sum(seq_lens) + image.shape[1]) - # Debug: Check mask dimensions - import os - if os.environ.get("ELIGEN_DEBUG"): - print(f"[EliGen Debug Patchify] entity_masks[0] shape: {entity_masks[0].shape}") - print(f"[EliGen Debug Patchify] height={height}, width={width}, height//16={height//16}, width//16={width//16}") - print(f"[EliGen Debug Patchify] Expected mask size: {height//16 * 2} x {width//16 * 2} = {(height//16) * 2} x {(width//16) * 2}") - patched_masks = [] for i in range(N): patched_mask = rearrange( @@ -753,43 +714,6 @@ class QwenImageTransformer2DModel(nn.Module): attention_mask[attention_mask == 1] = 0 attention_mask = 
attention_mask.to(device=latents.device, dtype=latents.dtype).unsqueeze(1) - if debug: - print(f"\n[EliGen Debug Mask Values]") - print(f" Token ranges:") - for i in range(len(seq_lens)): - if i < len(seq_lens) - 1: - print(f" - Entity {i} tokens: {cumsum[i]}-{cumsum[i+1]-1} (length: {seq_lens[i]})") - else: - print(f" - Global tokens: {cumsum[i]}-{cumsum[i+1]-1} (length: {seq_lens[i]})") - print(f" - Image tokens: {sum(seq_lens)}-{total_seq_len-1}") - - print(f"\n Checking Entity 0 connections:") - # Entity 0 to itself (should be 0) - e0_to_e0 = attention_mask[0, 0, cumsum[0]:cumsum[1], cumsum[0]:cumsum[1]] - print(f" - Entity0->Entity0: {(e0_to_e0 == 0).sum()}/{e0_to_e0.numel()} allowed") - - # Entity 0 to Entity 1 (should be -inf) - if len(seq_lens) > 2: - e0_to_e1 = attention_mask[0, 0, cumsum[0]:cumsum[1], cumsum[1]:cumsum[2]] - print(f" - Entity0->Entity1: {(e0_to_e1 == float('-inf')).sum()}/{e0_to_e1.numel()} blocked") - - # Entity 0 to Global (should be -inf) - e0_to_global = attention_mask[0, 0, cumsum[0]:cumsum[1], cumsum[-2]:cumsum[-1]] - print(f" - Entity0->Global: {(e0_to_global == float('-inf')).sum()}/{e0_to_global.numel()} blocked") - - # Entity 0 to Image (should be partially blocked based on mask) - e0_to_img = attention_mask[0, 0, cumsum[0]:cumsum[1], image_start:] - print(f" - Entity0->Image: {(e0_to_img == 0).sum()}/{e0_to_img.numel()} allowed, {(e0_to_img == float('-inf')).sum()} blocked") - - # Image to Entity 0 (should match Entity 0 to Image, transposed) - img_to_e0 = attention_mask[0, 0, image_start:, cumsum[0]:cumsum[1]] - print(f" - Image->Entity0: {(img_to_e0 == 0).sum()}/{img_to_e0.numel()} allowed") - - # Global to Image (should be fully allowed) - global_to_img = attention_mask[0, 0, cumsum[-2]:cumsum[-1], image_start:] - print(f"\n Checking Global connections:") - print(f" - Global->Image: {(global_to_img == 0).sum()}/{global_to_img.numel()} allowed") - return all_prompt_emb, image_rotary_emb, attention_mask def forward(self, x, timestep, context, attention_mask=None, guidance=None, ref_latents=None, transformer_options={}, **kwargs): @@ -848,22 +772,17 @@ class QwenImageTransformer2DModel(nn.Module): entity_prompt_emb_mask = kwargs.get("entity_prompt_emb_mask", None) entity_masks = kwargs.get("entity_masks", None) - # import pdb; pdb.set_trace() - - # Debug logging (set ELIGEN_DEBUG=1 environment variable to enable) - import os - if os.environ.get("ELIGEN_DEBUG"): - if entity_prompt_emb is not None: - print(f"[EliGen Debug] Entity data found!") - print(f" - entity_prompt_emb type: {type(entity_prompt_emb)}, len: {len(entity_prompt_emb) if isinstance(entity_prompt_emb, list) else 'N/A'}") - print(f" - entity_masks shape: {entity_masks.shape if entity_masks is not None else 'None'}") - print(f" - Number of entities: {entity_masks.shape[1] if entity_masks is not None else 'Unknown'}") - # Check if this is positive or negative conditioning - cond_or_uncond = transformer_options.get("cond_or_uncond", []) if transformer_options else [] - print(f" - Conditioning type: {['uncond' if c == 1 else 'cond' for c in cond_or_uncond]}") - else: - print(f"[EliGen Debug] No entity data in kwargs. 
Keys: {list(kwargs.keys())}") + if entity_prompt_emb is not None: + print(f"[EliGen Debug] Entity data found!") + print(f" - entity_prompt_emb type: {type(entity_prompt_emb)}, len: {len(entity_prompt_emb) if isinstance(entity_prompt_emb, list) else 'N/A'}") + print(f" - entity_masks shape: {entity_masks.shape if entity_masks is not None else 'None'}") + print(f" - Number of entities: {entity_masks.shape[1] if entity_masks is not None else 'Unknown'}") + # Check if this is positive or negative conditioning + cond_or_uncond = transformer_options.get("cond_or_uncond", []) if transformer_options else [] + print(f" - Conditioning type: {['uncond' if c == 1 else 'cond' for c in cond_or_uncond]}") + else: + print(f"[EliGen Debug] No entity data in kwargs. Keys: {list(kwargs.keys())}") # Branch: EliGen vs Standard path # Only apply EliGen to POSITIVE conditioning (cond_or_uncond contains 0) @@ -878,11 +797,10 @@ class QwenImageTransformer2DModel(nn.Module): height = int(orig_shape[-2] * 8) # Padded latent height -> pixel height (ensure int) width = int(orig_shape[-1] * 8) # Padded latent width -> pixel width (ensure int) - if os.environ.get("ELIGEN_DEBUG"): - print(f"[EliGen Debug] Original latent shape: {x.shape}") - print(f"[EliGen Debug] Padded latent shape (orig_shape): {orig_shape}") - print(f"[EliGen Debug] Calculated pixel dimensions: {height}x{width}") - print(f"[EliGen Debug] Expected patches: {height//16}x{width//16}") + print(f"[EliGen Debug] Original latent shape: {x.shape}") + print(f"[EliGen Debug] Padded latent shape (orig_shape): {orig_shape}") + print(f"[EliGen Debug] Calculated pixel dimensions: {height}x{width}") + print(f"[EliGen Debug] Expected patches: {height//16}x{width//16}") # Call process_entity_masks to get concatenated text, RoPE, and attention mask encoder_hidden_states, image_rotary_emb, eligen_attention_mask = self.process_entity_masks( diff --git a/comfy/model_base.py b/comfy/model_base.py index 050e10c98..869fd75bd 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -119,7 +119,6 @@ def convert_tensor(extra, dtype, device): extra = comfy.model_management.cast_to_device(extra, device, None) return extra - class BaseModel(torch.nn.Module): def __init__(self, model_config, model_type=ModelType.EPS, device=None, unet_model=UNetModel): super().__init__() @@ -381,7 +380,6 @@ class BaseModel(torch.nn.Module): def extra_conds_shapes(self, **kwargs): return {} - def unclip_adm(unclip_conditioning, device, noise_augmentor, noise_augment_merge=0.0, seed=None): adm_inputs = [] weights = [] @@ -477,7 +475,6 @@ class SDXL(BaseModel): flat = torch.flatten(torch.cat(out)).unsqueeze(dim=0).repeat(clip_pooled.shape[0], 1) return torch.cat((clip_pooled.to(flat.device), flat), dim=1) - class SVD_img2vid(BaseModel): def __init__(self, model_config, model_type=ModelType.V_PREDICTION_EDM, device=None): super().__init__(model_config, model_type, device=device) @@ -554,7 +551,6 @@ class SV3D_p(SVD_img2vid): out = list(map(lambda a: utils.resize_to_batch_size(a, noise.shape[0]), out)) return torch.cat(out, dim=1) - class Stable_Zero123(BaseModel): def __init__(self, model_config, model_type=ModelType.EPS, device=None, cc_projection_weight=None, cc_projection_bias=None): super().__init__(model_config, model_type, device=device) @@ -638,13 +634,11 @@ class IP2P: image = utils.resize_to_batch_size(image, noise.shape[0]) return self.process_ip2p_image_in(image) - class SD15_instructpix2pix(IP2P, BaseModel): def __init__(self, model_config, model_type=ModelType.EPS, device=None): 
super().__init__(model_config, model_type, device=device) self.process_ip2p_image_in = lambda image: image - class SDXL_instructpix2pix(IP2P, SDXL): def __init__(self, model_config, model_type=ModelType.EPS, device=None): super().__init__(model_config, model_type, device=device) @@ -694,7 +688,6 @@ class StableCascade_C(BaseModel): out['clip_text'] = comfy.conds.CONDCrossAttn(cross_attn) return out - class StableCascade_B(BaseModel): def __init__(self, model_config, model_type=ModelType.STABLE_CASCADE, device=None): super().__init__(model_config, model_type, device=device, unet_model=StageB) @@ -714,7 +707,6 @@ class StableCascade_B(BaseModel): out["sca"] = comfy.conds.CONDRegular(torch.zeros((1,))) return out - class SD3(BaseModel): def __init__(self, model_config, model_type=ModelType.FLOW, device=None): super().__init__(model_config, model_type, device=device, unet_model=OpenAISignatureMMDITWrapper) @@ -729,7 +721,6 @@ class SD3(BaseModel): out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn) return out - class AuraFlow(BaseModel): def __init__(self, model_config, model_type=ModelType.FLOW, device=None): super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.aura.mmdit.MMDiT) @@ -741,7 +732,6 @@ class AuraFlow(BaseModel): out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn) return out - class StableAudio1(BaseModel): def __init__(self, model_config, seconds_start_embedder_weights, seconds_total_embedder_weights, model_type=ModelType.V_PREDICTION_CONTINUOUS, device=None): super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.audio.dit.AudioDiffusionTransformer) @@ -780,7 +770,6 @@ class StableAudio1(BaseModel): sd["{}{}".format(k, l)] = s[l] return sd - class HunyuanDiT(BaseModel): def __init__(self, model_config, model_type=ModelType.V_PREDICTION, device=None): super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hydit.models.HunYuanDiT) @@ -914,7 +903,6 @@ class Flux(BaseModel): out['ref_latents'] = list([1, 16, sum(map(lambda a: math.prod(a.size()), ref_latents)) // 16]) return out - class GenmoMochi(BaseModel): def __init__(self, model_config, model_type=ModelType.FLOW, device=None): super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.genmo.joint_model.asymm_models_joint.AsymmDiTJoint) @@ -1166,7 +1154,6 @@ class WAN21(BaseModel): return out - class WAN21_Vace(WAN21): def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None): super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.VaceWanModel) @@ -1466,15 +1453,17 @@ class QwenImage(BaseModel): # Handle EliGen entity data entity_prompt_emb = kwargs.get("entity_prompt_emb", None) if entity_prompt_emb is not None: - out['entity_prompt_emb'] = entity_prompt_emb # Already wrapped in CONDList by node + out['entity_prompt_emb'] = comfy.conds.CONDList(entity_prompt_emb) entity_prompt_emb_mask = kwargs.get("entity_prompt_emb_mask", None) if entity_prompt_emb_mask is not None: - out['entity_prompt_emb_mask'] = entity_prompt_emb_mask # Already wrapped in CONDList by node + out['entity_prompt_emb_mask'] = comfy.conds.CONDList(entity_prompt_emb_mask) entity_masks = kwargs.get("entity_masks", None) if entity_masks is not None: - out['entity_masks'] = entity_masks # Already wrapped in CONDRegular by node + out['entity_masks'] = comfy.conds.CONDRegular(entity_masks) + + # import pdb; pdb.set_trace() return out diff --git a/comfy_extras/nodes_qwen.py 
b/comfy_extras/nodes_qwen.py index 184fdfcff..d8ebbf462 100644 --- a/comfy_extras/nodes_qwen.py +++ b/comfy_extras/nodes_qwen.py @@ -176,17 +176,12 @@ class TextEncodeQwenImageEliGen(io.ComfyNode): entity_prompt_emb_list = [] entity_prompt_emb_mask_list = [] - for entity_prompt, _ in valid_entities: + for entity_prompt, _ in valid_entities: # mask not used at this point entity_tokens = clip.tokenize(entity_prompt) - entity_cond = clip.encode_from_tokens_scheduled(entity_tokens) - - # Extract embeddings and masks from conditioning - # Conditioning format: [[cond_tensor, extra_dict], ...] - entity_prompt_emb = entity_cond[0][0] # The embedding tensor directly [1, seq_len, 3584] - extra_dict = entity_cond[0][1] # Metadata dict (pooled_output, attention_mask, etc.) - - # Extract attention mask from metadata dict - entity_prompt_emb_mask = extra_dict.get("attention_mask", None) + entity_cond_dict = clip.encode_from_tokens(entity_tokens, return_pooled=True, return_dict=True) + + entity_prompt_emb = entity_cond_dict["cond"] + entity_prompt_emb_mask = entity_cond_dict.get("attention_mask", None) # If no attention mask in extra_dict, create one (all True) if entity_prompt_emb_mask is None: @@ -194,11 +189,12 @@ class TextEncodeQwenImageEliGen(io.ComfyNode): entity_prompt_emb_mask = torch.ones((entity_prompt_emb.shape[0], seq_len), dtype=torch.bool, device=entity_prompt_emb.device) + entity_prompt_emb_list.append(entity_prompt_emb) entity_prompt_emb_mask_list.append(entity_prompt_emb_mask) # Process spatial masks to latent space - processed_masks = [] + processed_entity_masks = [] for i, (_, mask) in enumerate(valid_entities): # mask is expected to be [batch, height, width, channels] or [batch, height, width] mask_tensor = mask @@ -244,11 +240,11 @@ class TextEncodeQwenImageEliGen(io.ComfyNode): total_pixels = resized_mask.numel() print(f"[EliGen] Entity {i+1} mask coverage: {active_pixels}/{total_pixels} pixels ({100*active_pixels/total_pixels:.1f}%)") - processed_masks.append(resized_mask) + processed_entity_masks.append(resized_mask) - # Stack masks: [batch, num_entities, 1, latent_height, latent_width] + # Stack masks: [batch, num_entities, 1, latent_height, latent_width] (1 is selected channel) # No padding - handle dynamic number of entities - entity_masks_tensor = torch.stack(processed_masks, dim=1) + entity_masks_tensor = torch.stack(processed_entity_masks, dim=1) # Extract global prompt embedding and mask from conditioning # Conditioning format: [[cond_tensor, extra_dict]] @@ -263,11 +259,10 @@ class TextEncodeQwenImageEliGen(io.ComfyNode): dtype=torch.bool, device=global_prompt_emb.device) # Attach entity data to conditioning using conditioning_set_values - # Wrap lists in CONDList so they can be properly concatenated during CFG entity_data = { - "entity_prompt_emb": comfy.conds.CONDList(entity_prompt_emb_list), - "entity_prompt_emb_mask": comfy.conds.CONDList(entity_prompt_emb_mask_list), - "entity_masks": comfy.conds.CONDRegular(entity_masks_tensor), + "entity_prompt_emb": entity_prompt_emb_list, + "entity_prompt_emb_mask": entity_prompt_emb_mask_list, + "entity_masks": entity_masks_tensor, } conditioning_with_entities = node_helpers.conditioning_set_values( From 0f4a141faf05b7390df8446e55fa095705ffdebd Mon Sep 17 00:00:00 2001 From: nolan4 Date: Fri, 24 Oct 2025 17:37:26 -0700 Subject: [PATCH 03/12] working using comfyUI's optimized attention and rotary embedding funcs --- comfy/ldm/qwen_image/model.py | 132 ++++++++++------------------------ comfy/model_base.py | 3 - 
comfy_extras/nodes_qwen.py | 2 +- 3 files changed, 38 insertions(+), 99 deletions(-) diff --git a/comfy/ldm/qwen_image/model.py b/comfy/ldm/qwen_image/model.py index 7ac45c9a9..896a22e19 100644 --- a/comfy/ldm/qwen_image/model.py +++ b/comfy/ldm/qwen_image/model.py @@ -14,9 +14,8 @@ import comfy.patcher_extension class QwenEmbedRope(nn.Module): - """Research-accurate RoPE implementation for EliGen. - - This class matches the research pipeline's QwenEmbedRope exactly. + """RoPE implementation for EliGen. + https://github.com/modelscope/DiffSynth-Studio/blob/538017177a6136f45f57cdf0b7c4e0d7e1f8b50d/diffsynth/models/qwen_image_dit.py#L61 Returns a tuple (img_freqs, txt_freqs) for separate image and text RoPE. """ def __init__(self, theta: int, axes_dim: list, scale_rope=False): @@ -42,14 +41,23 @@ class QwenEmbedRope(nn.Module): """ Args: index: [0, 1, 2, 3] 1D Tensor representing the position index of the token + + Returns: + Real-valued 2x2 rotation matrix format [..., 2, 2] compatible with apply_rotary_emb """ assert dim % 2 == 0 freqs = torch.outer( index, 1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float32).div(dim)) ) - freqs = torch.polar(torch.ones_like(freqs), freqs) - return freqs + # Convert to real-valued rotation matrix format (matches Flux rope() output) + # Rotation matrix: [[cos, -sin], [sin, cos]] + cos_freqs = torch.cos(freqs) + sin_freqs = torch.sin(freqs) + # Stack as rotation matrix: [cos, -sin, sin, cos] then reshape to [..., 2, 2] + out = torch.stack([cos_freqs, -sin_freqs, sin_freqs, cos_freqs], dim=-1) + out = out.reshape(*freqs.shape, 2, 2) + return out def _expand_pos_freqs_if_needed(self, video_fhw, txt_seq_lens): if isinstance(video_fhw, list): @@ -108,7 +116,7 @@ class QwenEmbedRope(nn.Module): freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1) freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1) - freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1) + freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1, 2, 2) self.rope_cache[rope_key] = freqs.clone().contiguous() vid_freqs.append(self.rope_cache[rope_key]) @@ -166,30 +174,11 @@ class FeedForward(nn.Module): def apply_rotary_emb(x, freqs_cis): if x.shape[1] == 0: return x - t_ = x.reshape(*x.shape[:-1], -1, 1, 2) t_out = freqs_cis[..., 0] * t_[..., 0] + freqs_cis[..., 1] * t_[..., 1] return t_out.reshape(*x.shape) -def apply_rotary_emb_qwen(x: torch.Tensor, freqs_cis: torch.Tensor): - """ - Research-accurate RoPE application for QwenEmbedRope. 
- - Args: - x: Input tensor with shape [b, h, s, d] (batch, heads, sequence, dim) - freqs_cis: Complex frequency tensor with shape [s, features] from QwenEmbedRope - - Returns: - Rotated tensor with same shape as input - """ - # x shape: [b, h, s, d] - # freqs_cis shape: [s, features] where features = d (complex numbers) - x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2)) - x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3) - return x_out.type_as(x) - - class QwenTimestepProjEmbeddings(nn.Module): def __init__(self, embedding_dim, pooled_projection_dim, dtype=None, device=None, operations=None): super().__init__() @@ -280,29 +269,26 @@ class Attention(nn.Module): txt_query = self.norm_added_q(txt_query) txt_key = self.norm_added_k(txt_key) + ### NEW + ################################################# + # Handle both tuple (EliGen) and single tensor (standard) RoPE formats if isinstance(image_rotary_emb, tuple): # EliGen path: Apply RoPE BEFORE concatenation (research-accurate) - # txt/img query/key are currently [b, s, h, d], need to rearrange to [b, h, s, d] + # txt/img query/key are in [b, s, h, d] format, compatible with apply_rotary_emb img_rope, txt_rope = image_rotary_emb - # Rearrange to [b, h, s, d] for apply_rotary_emb_qwen - txt_query = txt_query.permute(0, 2, 1, 3) # [b, s, h, d] -> [b, h, s, d] - txt_key = txt_key.permute(0, 2, 1, 3) - img_query = img_query.permute(0, 2, 1, 3) - img_key = img_key.permute(0, 2, 1, 3) + # Add heads dimension to RoPE tensors for broadcasting + # Shape: [s, features, 2, 2] -> [s, 1, features, 2, 2] + # Also convert to match query dtype (e.g., bfloat16) + txt_rope = txt_rope.unsqueeze(1).to(dtype=txt_query.dtype) + img_rope = img_rope.unsqueeze(1).to(dtype=img_query.dtype) - # Apply RoPE separately to text and image using research function - txt_query = apply_rotary_emb_qwen(txt_query, txt_rope) - txt_key = apply_rotary_emb_qwen(txt_key, txt_rope) - img_query = apply_rotary_emb_qwen(img_query, img_rope) - img_key = apply_rotary_emb_qwen(img_key, img_rope) - - # Rearrange back to [b, s, h, d] - txt_query = txt_query.permute(0, 2, 1, 3) - txt_key = txt_key.permute(0, 2, 1, 3) - img_query = img_query.permute(0, 2, 1, 3) - img_key = img_key.permute(0, 2, 1, 3) + # Apply RoPE separately to text and image streams + txt_query = apply_rotary_emb(txt_query, txt_rope) + txt_key = apply_rotary_emb(txt_key, txt_rope) + img_query = apply_rotary_emb(img_query, img_rope) + img_key = apply_rotary_emb(img_key, img_rope) # Now concatenate joint_query = torch.cat([txt_query, img_query], dim=1) @@ -317,50 +303,25 @@ class Attention(nn.Module): joint_query = apply_rotary_emb(joint_query, image_rotary_emb) joint_key = apply_rotary_emb(joint_key, image_rotary_emb) - # Check if we have an EliGen mask - if so, use PyTorch SDPA directly (research-accurate) - has_eligen_mask = False effective_mask = attention_mask if transformer_options is not None: eligen_mask = transformer_options.get("eligen_attention_mask", None) if eligen_mask is not None: - has_eligen_mask = True effective_mask = eligen_mask # Validate shape expected_seq = joint_query.shape[1] if eligen_mask.shape[-1] != expected_seq: raise ValueError(f"EliGen mask shape {eligen_mask.shape} doesn't match sequence length {expected_seq}") + + ################################################# - if has_eligen_mask: - # EliGen path: Use PyTorch SDPA directly (matches research implementation exactly) - # Don't flatten - keep in [b, s, h, d] format for SDPA - # Reshape to [b, h, s, d] for 
SDPA - joint_query = joint_query.permute(0, 2, 1, 3) # [b, s, h, d] -> [b, h, s, d] - joint_key = joint_key.permute(0, 2, 1, 3) - joint_value = joint_value.permute(0, 2, 1, 3) + # Standard path: Use ComfyUI's optimized attention + joint_query = joint_query.flatten(start_dim=2) + joint_key = joint_key.flatten(start_dim=2) + joint_value = joint_value.flatten(start_dim=2) - print(f"[EliGen Debug Attention] Using PyTorch SDPA directly") - print(f" - Query shape: {joint_query.shape}") - print(f" - Mask shape: {effective_mask.shape}") - print(f" - Mask min/max: {effective_mask.min()} / {effective_mask.max():.2f}") - - # Apply SDPA with mask (research-accurate) - joint_hidden_states = torch.nn.functional.scaled_dot_product_attention( - joint_query, joint_key, joint_value, - attn_mask=effective_mask, - dropout_p=0.0, - is_causal=False - ) - - # Reshape back: [b, h, s, d] -> [b, s, h*d] - joint_hidden_states = joint_hidden_states.permute(0, 2, 1, 3).flatten(start_dim=2) - else: - # Standard path: Use ComfyUI's optimized attention - joint_query = joint_query.flatten(start_dim=2) - joint_key = joint_key.flatten(start_dim=2) - joint_value = joint_value.flatten(start_dim=2) - - joint_hidden_states = optimized_attention_masked(joint_query, joint_key, joint_value, self.heads, effective_mask, transformer_options=transformer_options) + joint_hidden_states = optimized_attention_masked(joint_query, joint_key, joint_value, self.heads, effective_mask, transformer_options=transformer_options) txt_attn_output = joint_hidden_states[:, :seq_txt, :] img_attn_output = joint_hidden_states[:, seq_txt:, :] @@ -482,7 +443,7 @@ class LastLayer(nn.Module): x = torch.addcmul(shift[:, None, :], self.norm(x), (1 + scale)[:, None, :]) return x - +### NEW changes class QwenImageTransformer2DModel(nn.Module): def __init__( self, @@ -564,6 +525,7 @@ class QwenImageTransformer2DModel(nn.Module): def process_entity_masks(self, latents, prompt_emb, prompt_emb_mask, entity_prompt_emb, entity_prompt_emb_mask, entity_masks, height, width, image): """ + https://github.com/modelscope/DiffSynth-Studio/blob/538017177a6136f45f57cdf0b7c4e0d7e1f8b50d/diffsynth/models/qwen_image_dit.py#L434 Process entity masks and build spatial attention mask for EliGen. 
This method: @@ -634,7 +596,6 @@ class QwenImageTransformer2DModel(nn.Module): padded_h = height // 8 padded_w = width // 8 if entity_masks.shape[3] != padded_h or entity_masks.shape[4] != padded_w: - # Validate masks aren't larger than expected (would cause negative padding) assert entity_masks.shape[3] <= padded_h and entity_masks.shape[4] <= padded_w, \ f"Entity masks {entity_masks.shape[3]}x{entity_masks.shape[4]} larger than padded dims {padded_h}x{padded_w}" @@ -772,36 +733,17 @@ class QwenImageTransformer2DModel(nn.Module): entity_prompt_emb_mask = kwargs.get("entity_prompt_emb_mask", None) entity_masks = kwargs.get("entity_masks", None) - # Debug logging (set ELIGEN_DEBUG=1 environment variable to enable) - if entity_prompt_emb is not None: - print(f"[EliGen Debug] Entity data found!") - print(f" - entity_prompt_emb type: {type(entity_prompt_emb)}, len: {len(entity_prompt_emb) if isinstance(entity_prompt_emb, list) else 'N/A'}") - print(f" - entity_masks shape: {entity_masks.shape if entity_masks is not None else 'None'}") - print(f" - Number of entities: {entity_masks.shape[1] if entity_masks is not None else 'Unknown'}") - # Check if this is positive or negative conditioning - cond_or_uncond = transformer_options.get("cond_or_uncond", []) if transformer_options else [] - print(f" - Conditioning type: {['uncond' if c == 1 else 'cond' for c in cond_or_uncond]}") - else: - print(f"[EliGen Debug] No entity data in kwargs. Keys: {list(kwargs.keys())}") - # Branch: EliGen vs Standard path # Only apply EliGen to POSITIVE conditioning (cond_or_uncond contains 0) - # Negative conditioning should use standard path cond_or_uncond = transformer_options.get("cond_or_uncond", []) if transformer_options else [] is_positive_cond = 0 in cond_or_uncond # 0 = conditional/positive, 1 = unconditional/negative if entity_prompt_emb is not None and entity_masks is not None and entity_prompt_emb_mask is not None and is_positive_cond: # EliGen path - process entity masks (POSITIVE CONDITIONING ONLY) - # Note: Use padded dimensions from orig_shape, not original latent dimensions # orig_shape is from process_img which pads to patch_size height = int(orig_shape[-2] * 8) # Padded latent height -> pixel height (ensure int) width = int(orig_shape[-1] * 8) # Padded latent width -> pixel width (ensure int) - print(f"[EliGen Debug] Original latent shape: {x.shape}") - print(f"[EliGen Debug] Padded latent shape (orig_shape): {orig_shape}") - print(f"[EliGen Debug] Calculated pixel dimensions: {height}x{width}") - print(f"[EliGen Debug] Expected patches: {height//16}x{width//16}") - # Call process_entity_masks to get concatenated text, RoPE, and attention mask encoder_hidden_states, image_rotary_emb, eligen_attention_mask = self.process_entity_masks( latents=x, diff --git a/comfy/model_base.py b/comfy/model_base.py index 869fd75bd..9a4010843 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -1462,9 +1462,6 @@ class QwenImage(BaseModel): entity_masks = kwargs.get("entity_masks", None) if entity_masks is not None: out['entity_masks'] = comfy.conds.CONDRegular(entity_masks) - - # import pdb; pdb.set_trace() - return out def extra_conds_shapes(self, **kwargs): diff --git a/comfy_extras/nodes_qwen.py b/comfy_extras/nodes_qwen.py index d8ebbf462..5ac48dc36 100644 --- a/comfy_extras/nodes_qwen.py +++ b/comfy_extras/nodes_qwen.py @@ -105,7 +105,7 @@ class TextEncodeQwenImageEditPlus(io.ComfyNode): conditioning = node_helpers.conditioning_set_values(conditioning, {"reference_latents": ref_latents}, append=True) 
return io.NodeOutput(conditioning) - +################ NEW class TextEncodeQwenImageEliGen(io.ComfyNode): @classmethod def define_schema(cls): From b0ade4bb85ad6834449b374ac3422e239da109d0 Mon Sep 17 00:00:00 2001 From: nolan4 Date: Fri, 24 Oct 2025 18:22:39 -0700 Subject: [PATCH 04/12] mask instead of image inputs for qwen eligen pipeline --- comfy_extras/nodes_qwen.py | 46 ++++++++++++++++++-------------------- 1 file changed, 22 insertions(+), 24 deletions(-) diff --git a/comfy_extras/nodes_qwen.py b/comfy_extras/nodes_qwen.py index 5ac48dc36..d90707a49 100644 --- a/comfy_extras/nodes_qwen.py +++ b/comfy_extras/nodes_qwen.py @@ -116,11 +116,11 @@ class TextEncodeQwenImageEliGen(io.ComfyNode): io.Clip.Input("clip"), io.Conditioning.Input("global_conditioning"), io.Latent.Input("latent"), - io.Image.Input("entity_mask_1", optional=True), + io.Mask.Input("entity_mask_1", optional=True), io.String.Input("entity_prompt_1", multiline=True, dynamic_prompts=True, default=""), - io.Image.Input("entity_mask_2", optional=True), + io.Mask.Input("entity_mask_2", optional=True), io.String.Input("entity_prompt_2", multiline=True, dynamic_prompts=True, default=""), - io.Image.Input("entity_mask_3", optional=True), + io.Mask.Input("entity_mask_3", optional=True), io.String.Input("entity_prompt_3", multiline=True, dynamic_prompts=True, default=""), ], outputs=[ @@ -196,31 +196,26 @@ class TextEncodeQwenImageEliGen(io.ComfyNode): # Process spatial masks to latent space processed_entity_masks = [] for i, (_, mask) in enumerate(valid_entities): - # mask is expected to be [batch, height, width, channels] or [batch, height, width] + # MASK type format: [batch, height, width] (no channel dimension) + # This is different from IMAGE type which is [batch, height, width, channels] mask_tensor = mask # Log original mask dimensions original_shape = mask_tensor.shape - if len(original_shape) == 3: + if len(original_shape) == 2: + # [height, width] - single mask without batch orig_h, orig_w = original_shape[0], original_shape[1] - elif len(original_shape) == 4: + # Add batch dimension: [1, height, width] + mask_tensor = mask_tensor.unsqueeze(0) + elif len(original_shape) == 3: + # [batch, height, width] - standard MASK format orig_h, orig_w = original_shape[1], original_shape[2] else: - orig_h, orig_w = original_shape[-2], original_shape[-1] + raise ValueError(f"Unexpected mask shape: {original_shape}. 
Expected [H, W] or [B, H, W]") - print(f"[EliGen] Entity {i+1} mask: {orig_h}x{orig_w} → will resize to {latent_height}x{latent_width} latent") - - # Ensure mask is in [batch, channels, height, width] format for upscale - if len(mask_tensor.shape) == 3: - # [height, width, channels] -> [1, height, width, channels] (add batch dimension) - mask_tensor = mask_tensor.unsqueeze(0) - elif len(mask_tensor.shape) == 4 and mask_tensor.shape[-1] in [1, 3, 4]: - # [batch, height, width, channels] -> [batch, channels, height, width] - mask_tensor = mask_tensor.movedim(-1, 1) - - # Take only first channel if multiple channels - if mask_tensor.shape[1] > 1: - mask_tensor = mask_tensor[:, 0:1, :, :] + # Convert MASK format [batch, height, width] to [batch, 1, height, width] for common_upscale + # common_upscale expects [batch, channels, height, width] + mask_tensor = mask_tensor.unsqueeze(1) # Add channel dimension: [batch, 1, height, width] # Resize to latent space dimensions using nearest neighbor resized_mask = comfy.utils.common_upscale( @@ -238,13 +233,16 @@ class TextEncodeQwenImageEliGen(io.ComfyNode): # Log how many pixels are active in the mask active_pixels = (resized_mask > 0).sum().item() total_pixels = resized_mask.numel() - print(f"[EliGen] Entity {i+1} mask coverage: {active_pixels}/{total_pixels} pixels ({100*active_pixels/total_pixels:.1f}%)") processed_entity_masks.append(resized_mask) - # Stack masks: [batch, num_entities, 1, latent_height, latent_width] (1 is selected channel) - # No padding - handle dynamic number of entities - entity_masks_tensor = torch.stack(processed_entity_masks, dim=1) + # Stack masks: [batch, num_entities, 1, latent_height, latent_width] + # Each item in processed_entity_masks has shape [1, 1, H, W] (batch=1, channel=1) + # We need to remove batch dim, stack, then add it back + # Option 1: Squeeze batch dim from each mask + processed_no_batch = [m.squeeze(0) for m in processed_entity_masks] # Each: [1, H, W] + entity_masks_tensor = torch.stack(processed_no_batch, dim=0) # [num_entities, 1, H, W] + entity_masks_tensor = entity_masks_tensor.unsqueeze(0) # [1, num_entities, 1, H, W] # Extract global prompt embedding and mask from conditioning # Conditioning format: [[cond_tensor, extra_dict]] From b22226562807300dafc062d0d316f95f177c6c37 Mon Sep 17 00:00:00 2001 From: nolan4 Date: Fri, 24 Oct 2025 19:22:26 -0700 Subject: [PATCH 05/12] qwen eligen batch size > 1 fix --- comfy/ldm/qwen_image/model.py | 90 ++++++++++++++++++------ comfy_extras/nodes_qwen.py | 124 ++++++++++++++++++++++++++++++---- 2 files changed, 179 insertions(+), 35 deletions(-) diff --git a/comfy/ldm/qwen_image/model.py b/comfy/ldm/qwen_image/model.py index 896a22e19..42553154e 100644 --- a/comfy/ldm/qwen_image/model.py +++ b/comfy/ldm/qwen_image/model.py @@ -3,6 +3,7 @@ import torch import torch.nn as nn import torch.nn.functional as F import math +import logging from typing import Optional, Tuple from einops import repeat, rearrange @@ -12,6 +13,8 @@ from comfy.ldm.flux.layers import EmbedND import comfy.ldm.common_dit import comfy.patcher_extension +logger = logging.getLogger(__name__) + class QwenEmbedRope(nn.Module): """RoPE implementation for EliGen. 
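As a reading aid for the mask handling introduced in the node changes above: ComfyUI MASK inputs arrive as [H, W] or [batch, height, width], get resized to the latent grid, and are stacked into the [1, num_entities, 1, latent_height, latent_width] layout that process_entity_masks consumes. The sketch below is illustrative only and substitutes torch.nn.functional.interpolate for the node's comfy.utils.common_upscale call; it is not the node's exact implementation.

import torch
import torch.nn.functional as F

def stack_entity_masks(masks, latent_height, latent_width):
    """Sketch of the intended shape handling for EliGen entity masks.

    masks: list of MASK tensors shaped [H, W] or [B, H, W], values in [0, 1].
    Returns: tensor shaped [1, num_entities, 1, latent_height, latent_width].
    """
    processed = []
    for mask in masks:
        if mask.dim() == 2:                      # [H, W] -> add batch dim
            mask = mask.unsqueeze(0)
        elif mask.dim() != 3:
            raise ValueError(f"Expected [H, W] or [B, H, W], got {tuple(mask.shape)}")
        mask = mask[:1].unsqueeze(1)             # keep first batch element: [1, 1, H, W]
        mask = F.interpolate(mask, size=(latent_height, latent_width), mode="nearest")
        processed.append(mask.squeeze(0))        # [1, latent_h, latent_w]
    stacked = torch.stack(processed, dim=0)      # [num_entities, 1, latent_h, latent_w]
    return stacked.unsqueeze(0)                  # [1, num_entities, 1, latent_h, latent_w]

For example, two 512x512 masks targeting a 64x64 latent grid produce a [1, 2, 1, 64, 64] tensor, matching the shape the node logs after stacking.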
@@ -269,9 +272,6 @@ class Attention(nn.Module): txt_query = self.norm_added_q(txt_query) txt_key = self.norm_added_k(txt_key) - ### NEW - ################################################# - # Handle both tuple (EliGen) and single tensor (standard) RoPE formats if isinstance(image_rotary_emb, tuple): # EliGen path: Apply RoPE BEFORE concatenation (research-accurate) @@ -303,6 +303,7 @@ class Attention(nn.Module): joint_query = apply_rotary_emb(joint_query, image_rotary_emb) joint_key = apply_rotary_emb(joint_key, image_rotary_emb) + # Apply EliGen attention mask if present effective_mask = attention_mask if transformer_options is not None: eligen_mask = transformer_options.get("eligen_attention_mask", None) @@ -312,11 +313,12 @@ class Attention(nn.Module): # Validate shape expected_seq = joint_query.shape[1] if eligen_mask.shape[-1] != expected_seq: - raise ValueError(f"EliGen mask shape {eligen_mask.shape} doesn't match sequence length {expected_seq}") - - ################################################# + raise ValueError( + f"EliGen attention mask shape mismatch: {eligen_mask.shape} " + f"doesn't match sequence length {expected_seq}" + ) - # Standard path: Use ComfyUI's optimized attention + # Use ComfyUI's optimized attention joint_query = joint_query.flatten(start_dim=2) joint_key = joint_key.flatten(start_dim=2) joint_value = joint_value.flatten(start_dim=2) @@ -443,8 +445,12 @@ class LastLayer(nn.Module): x = torch.addcmul(shift[:, None, :], self.norm(x), (1 + scale)[:, None, :]) return x -### NEW changes class QwenImageTransformer2DModel(nn.Module): + # Constants for EliGen processing + LATENT_TO_PIXEL_RATIO = 8 # Latents are 8x downsampled from pixel space + PATCH_TO_LATENT_RATIO = 2 # 2x2 patches in latent space + PATCH_TO_PIXEL_RATIO = 16 # Combined: 2x2 patches on 8x downsampled latents = 16x in pixel space + def __init__( self, patch_size: int = 2, @@ -540,8 +546,8 @@ class QwenImageTransformer2DModel(nn.Module): entity_prompt_emb: List[[1, L_i, 3584]] - Entity prompts entity_prompt_emb_mask: List[[1, L_i]] entity_masks: [1, N, 1, H/8, W/8] - height: int - width: int + height: int (padded pixel height) + width: int (padded pixel width) image: [B, patches, 64] - Patchified latents Returns: @@ -549,6 +555,17 @@ class QwenImageTransformer2DModel(nn.Module): image_rotary_emb: RoPE embeddings attention_mask: [1, 1, total_seq, total_seq] """ + num_entities = len(entity_prompt_emb) + batch_size = latents.shape[0] + logger.debug( + f"[EliGen Model] Processing {num_entities} entities for {height}x{width}px image " + f"(latents: {latents.shape}, batch_size: {batch_size})" + ) + + # Validate batch consistency (all batches should have same sequence lengths) + # This is a ComfyUI requirement - batched prompts must have uniform padding + if batch_size > 1: + logger.debug(f"[EliGen Model] Batch size > 1 detected ({batch_size} batches), ensuring RoPE compatibility") # SECTION 1: Concatenate entity + global prompts all_prompt_emb = entity_prompt_emb + [prompt_emb] @@ -556,45 +573,63 @@ class QwenImageTransformer2DModel(nn.Module): all_prompt_emb = torch.cat(all_prompt_emb, dim=1) # SECTION 2: Build RoPE position embeddings - # Calculate img_shapes for RoPE (batch, height//16, width//16 for images in latent space after patchifying) - img_shapes = [(latents.shape[0], height//16, width//16)] + # For EliGen, we create RoPE for ONE batch element's dimensions + # The queries/keys have shape [batch, seq, heads, dim], and RoPE broadcasts across batch dim + patch_h = height // self.PATCH_TO_PIXEL_RATIO 
+ patch_w = width // self.PATCH_TO_PIXEL_RATIO + + # Create RoPE for a single image (frame=1 for images, not video) + # This will broadcast across all batch elements automatically + img_shapes_single = [(1, patch_h, patch_w)] # Calculate sequence lengths for entities and global prompt - entity_seq_lens = [int(mask.sum(dim=1).item()) for mask in entity_prompt_emb_mask] + # Use [0] to get first batch element (all batches should have same sequence lengths) + entity_seq_lens = [int(mask.sum(dim=1)[0].item()) for mask in entity_prompt_emb_mask] # Handle None case in ComfyUI (None means no padding, all tokens valid) if prompt_emb_mask is not None: - global_seq_len = int(prompt_emb_mask.sum(dim=1).item()) + global_seq_len = int(prompt_emb_mask.sum(dim=1)[0].item()) else: # No mask = no padding, use full sequence length global_seq_len = int(prompt_emb.shape[1]) # Get base image RoPE using global prompt length (returns tuple: (img_freqs, txt_freqs)) + # We pass a single shape, not repeated for batch, because RoPE will broadcast txt_seq_lens = [global_seq_len] - image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens, device=latents.device) + image_rotary_emb = self.pos_embed(img_shapes_single, txt_seq_lens, device=latents.device) # Create SEPARATE RoPE embeddings for each entity # Each entity gets its own positional encoding based on its sequence length - entity_rotary_emb = [self.pos_embed(img_shapes, [entity_seq_len], device=latents.device)[1] + # We only need to create these once since they're the same for all batch elements + entity_rotary_emb = [self.pos_embed([(1, patch_h, patch_w)], [entity_seq_len], device=latents.device)[1] for entity_seq_len in entity_seq_lens] # Concatenate entity RoPEs with global RoPE along sequence dimension # Result: [entity1_seq, entity2_seq, ..., global_seq] concatenated + # This creates the RoPE for ONE batch element's sequence + # Note: We DON'T repeat for batch_size because the queries/keys have shape [batch, seq, ...] + # and PyTorch will broadcast the RoPE [seq, ...] 
across the batch dimension automatically txt_rotary_emb = torch.cat(entity_rotary_emb + [image_rotary_emb[1]], dim=0) + logger.debug( + f"[EliGen Model] RoPE created for single batch element - " + f"img: {image_rotary_emb[0].shape}, txt: {txt_rotary_emb.shape} " + f"(both will broadcast across batch_size={batch_size})" + ) + # Replace text part of tuple with concatenated entity + global RoPE image_rotary_emb = (image_rotary_emb[0], txt_rotary_emb) # SECTION 3: Prepare spatial masks - repeat_dim = latents.shape[1] # 16 + repeat_dim = latents.shape[1] # 16 (latent channels) max_masks = entity_masks.shape[1] # N entities entity_masks = entity_masks.repeat(1, 1, repeat_dim, 1, 1) # Pad masks to match padded latent dimensions # entity_masks shape: [1, N, 16, H/8, W/8] # Need to pad to match orig_shape which is [B, 16, padded_H/8, padded_W/8] - padded_h = height // 8 - padded_w = width // 8 + padded_h = height // self.LATENT_TO_PIXEL_RATIO + padded_w = width // self.LATENT_TO_PIXEL_RATIO if entity_masks.shape[3] != padded_h or entity_masks.shape[4] != padded_w: assert entity_masks.shape[3] <= padded_h and entity_masks.shape[4] <= padded_w, \ f"Entity masks {entity_masks.shape[3]}x{entity_masks.shape[4]} larger than padded dims {padded_h}x{padded_w}" @@ -602,6 +637,7 @@ class QwenImageTransformer2DModel(nn.Module): # Pad each entity mask pad_h = padded_h - entity_masks.shape[3] pad_w = padded_w - entity_masks.shape[4] + logger.debug(f"[EliGen Model] Padding entity masks by ({pad_h}, {pad_w}) to match latent dimensions") entity_masks = torch.nn.functional.pad(entity_masks, (0, pad_w, 0, pad_h), mode='constant', value=0) entity_masks = [entity_masks[:, i, None].squeeze(1) for i in range(max_masks)] @@ -617,12 +653,20 @@ class QwenImageTransformer2DModel(nn.Module): seq_lens = entity_seq_lens + [global_seq_len] total_seq_len = int(sum(seq_lens) + image.shape[1]) + logger.debug( + f"[EliGen Model] Building attention mask: " + f"total_seq={total_seq_len} (entities: {entity_seq_lens}, global: {global_seq_len}, image: {image.shape[1]})" + ) + patched_masks = [] for i in range(N): patched_mask = rearrange( entity_masks[i], "B C (H P) (W Q) -> B (H W) (C P Q)", - H=height//16, W=width//16, P=2, Q=2 + H=height // self.PATCH_TO_PIXEL_RATIO, + W=width // self.PATCH_TO_PIXEL_RATIO, + P=self.PATCH_TO_LATENT_RATIO, + Q=self.PATCH_TO_LATENT_RATIO ) patched_masks.append(patched_mask) @@ -671,10 +715,16 @@ class QwenImageTransformer2DModel(nn.Module): # SECTION 6: Convert to additive bias attention_mask = attention_mask.float() + num_valid_connections = (attention_mask == 1).sum().item() attention_mask[attention_mask == 0] = float('-inf') attention_mask[attention_mask == 1] = 0 attention_mask = attention_mask.to(device=latents.device, dtype=latents.dtype).unsqueeze(1) + logger.debug( + f"[EliGen Model] Attention mask created: shape={attention_mask.shape}, " + f"valid_connections={num_valid_connections}/{total_seq_len * total_seq_len}" + ) + return all_prompt_emb, image_rotary_emb, attention_mask def forward(self, x, timestep, context, attention_mask=None, guidance=None, ref_latents=None, transformer_options={}, **kwargs): diff --git a/comfy_extras/nodes_qwen.py b/comfy_extras/nodes_qwen.py index d90707a49..f59c84d54 100644 --- a/comfy_extras/nodes_qwen.py +++ b/comfy_extras/nodes_qwen.py @@ -3,9 +3,13 @@ import comfy.utils import comfy.conds import math import torch +import logging +from typing import Optional from typing_extensions import override from comfy_api.latest import ComfyExtension, io +logger = 
logging.getLogger(__name__) + class TextEncodeQwenImageEdit(io.ComfyNode): @classmethod @@ -105,8 +109,31 @@ class TextEncodeQwenImageEditPlus(io.ComfyNode): conditioning = node_helpers.conditioning_set_values(conditioning, {"reference_latents": ref_latents}, append=True) return io.NodeOutput(conditioning) -################ NEW class TextEncodeQwenImageEliGen(io.ComfyNode): + """ + Entity-Level Image Generation (EliGen) conditioning node for Qwen Image model. + + Allows specifying different prompts for different spatial regions using masks. + Each entity (mask + prompt pair) will only influence its masked region through + spatial attention masking. + + Features: + - Supports up to 3 entities per generation + - Spatial attention masks prevent cross-entity contamination + - Separate RoPE embeddings per entity (research-accurate) + - Falls back to standard generation if no entities provided + + Usage: + 1. Create spatial masks using LoadImageMask (white=entity, black=background) + 2. Use 'red', 'green', or 'blue' channel (NOT 'alpha' - it gets inverted) + 3. Provide entity-specific prompts for each masked region + + Based on DiffSynth Studio: https://github.com/modelscope/DiffSynth-Studio + """ + + # Qwen Image model uses 2x2 patches on latents (which are 8x downsampled from pixels) + PATCH_SIZE = 2 + @classmethod def define_schema(cls): return io.Schema( @@ -129,8 +156,18 @@ class TextEncodeQwenImageEliGen(io.ComfyNode): ) @classmethod - def execute(cls, clip, global_conditioning, latent, entity_prompt_1="", entity_mask_1=None, - entity_prompt_2="", entity_mask_2=None, entity_prompt_3="", entity_mask_3=None) -> io.NodeOutput: + def execute( + cls, + clip, + global_conditioning, + latent, + entity_prompt_1: str = "", + entity_mask_1: Optional[torch.Tensor] = None, + entity_prompt_2: str = "", + entity_mask_2: Optional[torch.Tensor] = None, + entity_prompt_3: str = "", + entity_mask_3: Optional[torch.Tensor] = None + ) -> io.NodeOutput: # Extract dimensions from latent tensor # latent["samples"] shape: [batch, channels, latent_h, latent_w] @@ -139,10 +176,9 @@ class TextEncodeQwenImageEliGen(io.ComfyNode): unpadded_latent_width = latent_samples.shape[3] # Unpadded latent space # Calculate padded dimensions (same logic as model's pad_to_patch_size with patch_size=2) - # The model pads latents to be multiples of patch_size (2 for Qwen) - patch_size = 2 - pad_h = (patch_size - unpadded_latent_height % patch_size) % patch_size - pad_w = (patch_size - unpadded_latent_width % patch_size) % patch_size + # The model pads latents to be multiples of PATCH_SIZE + pad_h = (cls.PATCH_SIZE - unpadded_latent_height % cls.PATCH_SIZE) % cls.PATCH_SIZE + pad_w = (cls.PATCH_SIZE - unpadded_latent_width % cls.PATCH_SIZE) % cls.PATCH_SIZE latent_height = unpadded_latent_height + pad_h # Padded latent dimensions latent_width = unpadded_latent_width + pad_w # Padded latent dimensions @@ -150,8 +186,8 @@ class TextEncodeQwenImageEliGen(io.ComfyNode): width = latent_width * 8 if pad_h > 0 or pad_w > 0: - print(f"[EliGen] Latent padding detected: {unpadded_latent_height}x{unpadded_latent_width} → {latent_height}x{latent_width}") - print(f"[EliGen] Target generation dimensions: {height}x{width} pixels ({latent_height}x{latent_width} latent)") + logger.debug(f"[EliGen] Latent padding detected: {unpadded_latent_height}x{unpadded_latent_width} → {latent_height}x{latent_width}") + logger.debug(f"[EliGen] Target generation dimensions: {height}x{width} pixels ({latent_height}x{latent_width} latent)") # Collect entity prompts 
and masks entity_prompts = [entity_prompt_1, entity_prompt_2, entity_prompt_3] @@ -166,7 +202,7 @@ class TextEncodeQwenImageEliGen(io.ComfyNode): # Log warning if some entities were skipped total_prompts_provided = len([p for p in entity_prompts if p.strip()]) if len(valid_entities) < total_prompts_provided: - print(f"[EliGen] Warning: Only {len(valid_entities)} of {total_prompts_provided} entity prompts have valid masks") + logger.warning(f"[EliGen] Only {len(valid_entities)} of {total_prompts_provided} entity prompts have valid masks") # If no valid entities, return standard conditioning if len(valid_entities) == 0: @@ -200,7 +236,37 @@ class TextEncodeQwenImageEliGen(io.ComfyNode): # This is different from IMAGE type which is [batch, height, width, channels] mask_tensor = mask - # Log original mask dimensions + # Validate mask dtype + if mask_tensor.dtype not in [torch.float32, torch.float16, torch.bfloat16]: + raise TypeError( + f"Entity {i+1} mask has invalid dtype {mask_tensor.dtype}. " + f"Expected float32, float16, or bfloat16. " + f"Ensure you're using LoadImageMask node, not LoadImage." + ) + + # Log original mask statistics + logger.debug( + f"[EliGen] Entity {i+1} input mask: shape={mask_tensor.shape}, " + f"dtype={mask_tensor.dtype}, min={mask_tensor.min():.4f}, max={mask_tensor.max():.4f}" + ) + + # Check for all-zero masks (common error when wrong channel selected) + if mask_tensor.max() == 0.0: + raise ValueError( + f"Entity {i+1} mask is all zeros! This usually means:\n" + f" 1. Wrong channel selected in LoadImageMask (use 'red', 'green', or 'blue', NOT 'alpha')\n" + f" 2. Your mask image is completely black\n" + f" 3. The mask file failed to load" + ) + + # Check for constant masks (no variation) + if mask_tensor.min() == mask_tensor.max() and mask_tensor.max() > 0: + logger.warning( + f"[EliGen] Entity {i+1} mask has no variation (all pixels = {mask_tensor.min():.4f}). " + f"This entity will affect the entire image." + ) + + # Extract original dimensions original_shape = mask_tensor.shape if len(original_shape) == 2: # [height, width] - single mask without batch @@ -211,7 +277,20 @@ class TextEncodeQwenImageEliGen(io.ComfyNode): # [batch, height, width] - standard MASK format orig_h, orig_w = original_shape[1], original_shape[2] else: - raise ValueError(f"Unexpected mask shape: {original_shape}. Expected [H, W] or [B, H, W]") + raise ValueError( + f"Entity {i+1} has unexpected mask shape: {original_shape}. " + f"Expected [H, W] or [B, H, W]. Got {len(original_shape)} dimensions." + ) + + # Log size mismatch if mask doesn't match expected latent dimensions + expected_h, expected_w = latent_height * 8, latent_width * 8 + if orig_h != expected_h or orig_w != expected_w: + logger.info( + f"[EliGen] Entity {i+1} mask size mismatch: {orig_h}x{orig_w} vs expected {expected_h}x{expected_w}. " + f"Will resize to {latent_height}x{latent_width} latent space." 
+ ) + else: + logger.debug(f"[EliGen] Entity {i+1} mask: {orig_h}x{orig_w} → will resize to {latent_height}x{latent_width} latent") # Convert MASK format [batch, height, width] to [batch, 1, height, width] for common_upscale # common_upscale expects [batch, channels, height, width] @@ -233,17 +312,32 @@ class TextEncodeQwenImageEliGen(io.ComfyNode): # Log how many pixels are active in the mask active_pixels = (resized_mask > 0).sum().item() total_pixels = resized_mask.numel() + coverage_pct = 100 * active_pixels / total_pixels if total_pixels > 0 else 0 + + if active_pixels == 0: + raise ValueError( + f"Entity {i+1} mask has no active pixels after resizing to latent space! " + f"Original mask may have been too small or all black." + ) + + logger.debug( + f"[EliGen] Entity {i+1} mask coverage: {active_pixels}/{total_pixels} pixels ({coverage_pct:.1f}%)" + ) processed_entity_masks.append(resized_mask) # Stack masks: [batch, num_entities, 1, latent_height, latent_width] # Each item in processed_entity_masks has shape [1, 1, H, W] (batch=1, channel=1) # We need to remove batch dim, stack, then add it back - # Option 1: Squeeze batch dim from each mask - processed_no_batch = [m.squeeze(0) for m in processed_entity_masks] # Each: [1, H, W] - entity_masks_tensor = torch.stack(processed_no_batch, dim=0) # [num_entities, 1, H, W] + processed_entity_masks_no_batch = [m.squeeze(0) for m in processed_entity_masks] # Each: [1, H, W] + entity_masks_tensor = torch.stack(processed_entity_masks_no_batch, dim=0) # [num_entities, 1, H, W] entity_masks_tensor = entity_masks_tensor.unsqueeze(0) # [1, num_entities, 1, H, W] + logger.debug( + f"[EliGen] Stacked {len(valid_entities)} entity masks into tensor: " + f"shape={entity_masks_tensor.shape} (expected: [1, {len(valid_entities)}, 1, {latent_height}, {latent_width}])" + ) + # Extract global prompt embedding and mask from conditioning # Conditioning format: [[cond_tensor, extra_dict]] global_prompt_emb = global_conditioning[0][0] # The embedding tensor directly From 99a25a3dc4f313e9422fc431540aa83a46a6a15b Mon Sep 17 00:00:00 2001 From: nolan4 Date: Mon, 27 Oct 2025 09:54:41 -0700 Subject: [PATCH 06/12] Fix whitespace lint error --- comfy_extras/nodes_qwen.py | 1 - 1 file changed, 1 deletion(-) diff --git a/comfy_extras/nodes_qwen.py b/comfy_extras/nodes_qwen.py index f59c84d54..03ca00f73 100644 --- a/comfy_extras/nodes_qwen.py +++ b/comfy_extras/nodes_qwen.py @@ -215,7 +215,6 @@ class TextEncodeQwenImageEliGen(io.ComfyNode): for entity_prompt, _ in valid_entities: # mask not used at this point entity_tokens = clip.tokenize(entity_prompt) entity_cond_dict = clip.encode_from_tokens(entity_tokens, return_pooled=True, return_dict=True) - entity_prompt_emb = entity_cond_dict["cond"] entity_prompt_emb_mask = entity_cond_dict.get("attention_mask", None) From 6c0912107033747015d65b035a0ab62bbf513b96 Mon Sep 17 00:00:00 2001 From: nolan4 Date: Mon, 27 Oct 2025 20:19:12 -0700 Subject: [PATCH 07/12] replace QwenEmbedRope with existing ComfyUI rope --- comfy/ldm/qwen_image/model.py | 328 +++++++++++----------------------- 1 file changed, 109 insertions(+), 219 deletions(-) diff --git a/comfy/ldm/qwen_image/model.py b/comfy/ldm/qwen_image/model.py index 42553154e..45996e23b 100644 --- a/comfy/ldm/qwen_image/model.py +++ b/comfy/ldm/qwen_image/model.py @@ -16,125 +16,6 @@ import comfy.patcher_extension logger = logging.getLogger(__name__) -class QwenEmbedRope(nn.Module): - """RoPE implementation for EliGen. 
- https://github.com/modelscope/DiffSynth-Studio/blob/538017177a6136f45f57cdf0b7c4e0d7e1f8b50d/diffsynth/models/qwen_image_dit.py#L61 - Returns a tuple (img_freqs, txt_freqs) for separate image and text RoPE. - """ - def __init__(self, theta: int, axes_dim: list, scale_rope=False): - super().__init__() - self.theta = theta - self.axes_dim = axes_dim - pos_index = torch.arange(4096) - neg_index = torch.arange(4096).flip(0) * -1 - 1 - self.pos_freqs = torch.cat([ - self.rope_params(pos_index, self.axes_dim[0], self.theta), - self.rope_params(pos_index, self.axes_dim[1], self.theta), - self.rope_params(pos_index, self.axes_dim[2], self.theta), - ], dim=1) - self.neg_freqs = torch.cat([ - self.rope_params(neg_index, self.axes_dim[0], self.theta), - self.rope_params(neg_index, self.axes_dim[1], self.theta), - self.rope_params(neg_index, self.axes_dim[2], self.theta), - ], dim=1) - self.rope_cache = {} - self.scale_rope = scale_rope - - def rope_params(self, index, dim, theta=10000): - """ - Args: - index: [0, 1, 2, 3] 1D Tensor representing the position index of the token - - Returns: - Real-valued 2x2 rotation matrix format [..., 2, 2] compatible with apply_rotary_emb - """ - assert dim % 2 == 0 - freqs = torch.outer( - index, - 1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float32).div(dim)) - ) - # Convert to real-valued rotation matrix format (matches Flux rope() output) - # Rotation matrix: [[cos, -sin], [sin, cos]] - cos_freqs = torch.cos(freqs) - sin_freqs = torch.sin(freqs) - # Stack as rotation matrix: [cos, -sin, sin, cos] then reshape to [..., 2, 2] - out = torch.stack([cos_freqs, -sin_freqs, sin_freqs, cos_freqs], dim=-1) - out = out.reshape(*freqs.shape, 2, 2) - return out - - def _expand_pos_freqs_if_needed(self, video_fhw, txt_seq_lens): - if isinstance(video_fhw, list): - video_fhw = tuple(max([i[j] for i in video_fhw]) for j in range(3)) - _, height, width = video_fhw - if self.scale_rope: - max_vid_index = max(height // 2, width // 2) - else: - max_vid_index = max(height, width) - required_len = max_vid_index + max(txt_seq_lens) - cur_max_len = self.pos_freqs.shape[0] - if required_len <= cur_max_len: - return - - new_max_len = math.ceil(required_len / 512) * 512 - pos_index = torch.arange(new_max_len) - neg_index = torch.arange(new_max_len).flip(0) * -1 - 1 - self.pos_freqs = torch.cat([ - self.rope_params(pos_index, self.axes_dim[0], self.theta), - self.rope_params(pos_index, self.axes_dim[1], self.theta), - self.rope_params(pos_index, self.axes_dim[2], self.theta), - ], dim=1) - self.neg_freqs = torch.cat([ - self.rope_params(neg_index, self.axes_dim[0], self.theta), - self.rope_params(neg_index, self.axes_dim[1], self.theta), - self.rope_params(neg_index, self.axes_dim[2], self.theta), - ], dim=1) - return - - def forward(self, video_fhw, txt_seq_lens, device): - self._expand_pos_freqs_if_needed(video_fhw, txt_seq_lens) - if self.pos_freqs.device != device: - self.pos_freqs = self.pos_freqs.to(device) - self.neg_freqs = self.neg_freqs.to(device) - - vid_freqs = [] - max_vid_index = 0 - for idx, fhw in enumerate(video_fhw): - frame, height, width = fhw - rope_key = f"{idx}_{height}_{width}" - - if rope_key not in self.rope_cache: - seq_lens = frame * height * width - freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1) - freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1) - freqs_frame = freqs_pos[0][idx : idx + frame].view(frame, 1, 1, -1).expand(frame, height, width, -1) - if self.scale_rope: - freqs_height = torch.cat( - 
[freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]], dim=0 - ) - freqs_height = freqs_height.view(1, height, 1, -1).expand(frame, height, width, -1) - freqs_width = torch.cat([freqs_neg[2][-(width - width // 2) :], freqs_pos[2][: width // 2]], dim=0) - freqs_width = freqs_width.view(1, 1, width, -1).expand(frame, height, width, -1) - - else: - freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1) - freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1) - - freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1, 2, 2) - self.rope_cache[rope_key] = freqs.clone().contiguous() - vid_freqs.append(self.rope_cache[rope_key]) - - if self.scale_rope: - max_vid_index = max(height // 2, width // 2, max_vid_index) - else: - max_vid_index = max(height, width, max_vid_index) - - max_len = max(txt_seq_lens) - txt_freqs = self.pos_freqs[max_vid_index : max_vid_index + max_len, ...] - vid_freqs = torch.cat(vid_freqs, dim=0) - - return vid_freqs, txt_freqs - - class GELU(nn.Module): def __init__(self, dim_in: int, dim_out: int, approximate: str = "none", bias: bool = True, dtype=None, device=None, operations=None): super().__init__() @@ -477,9 +358,7 @@ class QwenImageTransformer2DModel(nn.Module): self.inner_dim = num_attention_heads * attention_head_dim self.pe_embedder = EmbedND(dim=attention_head_dim, theta=10000, axes_dim=list(axes_dims_rope)) - # Add research-accurate RoPE for EliGen (returns tuple of img_freqs, txt_freqs) - self.pos_embed = QwenEmbedRope(theta=10000, axes_dim=[16, 56, 56], scale_rope=True) - + self.time_text_embed = QwenTimestepProjEmbeddings( embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim, @@ -529,134 +408,109 @@ class QwenImageTransformer2DModel(nn.Module): return hidden_states, repeat(img_ids, "h w c -> b (h w) c", b=bs), orig_shape def process_entity_masks(self, latents, prompt_emb, prompt_emb_mask, entity_prompt_emb, - entity_prompt_emb_mask, entity_masks, height, width, image): + entity_prompt_emb_mask, entity_masks, height, width, image, + cond_or_uncond=None, batch_size=None): """ - https://github.com/modelscope/DiffSynth-Studio/blob/538017177a6136f45f57cdf0b7c4e0d7e1f8b50d/diffsynth/models/qwen_image_dit.py#L434 Process entity masks and build spatial attention mask for EliGen. - This method: - 1. Concatenates entity + global prompts - 2. Builds RoPE embeddings for concatenated text using ComfyUI's pe_embedder - 3. Creates attention mask enforcing spatial restrictions + Concatenates entity+global prompts, builds RoPE embeddings, creates attention mask + enforcing spatial restrictions, and handles CFG batching with separate masks. 
- Args: - latents: [B, 16, H, W] - prompt_emb: [1, seq_len, 3584] - Global prompt - prompt_emb_mask: [1, seq_len] - entity_prompt_emb: List[[1, L_i, 3584]] - Entity prompts - entity_prompt_emb_mask: List[[1, L_i]] - entity_masks: [1, N, 1, H/8, W/8] - height: int (padded pixel height) - width: int (padded pixel width) - image: [B, patches, 64] - Patchified latents - - Returns: - all_prompt_emb: [1, total_seq, 3584] - image_rotary_emb: RoPE embeddings - attention_mask: [1, 1, total_seq, total_seq] + Based on: https://github.com/modelscope/DiffSynth-Studio """ num_entities = len(entity_prompt_emb) - batch_size = latents.shape[0] + actual_batch_size = latents.shape[0] + + has_positive = cond_or_uncond and 0 in cond_or_uncond + has_negative = cond_or_uncond and 1 in cond_or_uncond + is_cfg_batched = has_positive and has_negative + logger.debug( - f"[EliGen Model] Processing {num_entities} entities for {height}x{width}px image " - f"(latents: {latents.shape}, batch_size: {batch_size})" + f"[EliGen Model] Processing {num_entities} entities for {height}x{width}px, " + f"batch_size={actual_batch_size}, CFG_batched={is_cfg_batched}" ) - # Validate batch consistency (all batches should have same sequence lengths) - # This is a ComfyUI requirement - batched prompts must have uniform padding - if batch_size > 1: - logger.debug(f"[EliGen Model] Batch size > 1 detected ({batch_size} batches), ensuring RoPE compatibility") - - # SECTION 1: Concatenate entity + global prompts + # Concatenate entity + global prompts all_prompt_emb = entity_prompt_emb + [prompt_emb] all_prompt_emb = [self.txt_in(self.txt_norm(local_prompt_emb)) for local_prompt_emb in all_prompt_emb] all_prompt_emb = torch.cat(all_prompt_emb, dim=1) - # SECTION 2: Build RoPE position embeddings - # For EliGen, we create RoPE for ONE batch element's dimensions - # The queries/keys have shape [batch, seq, heads, dim], and RoPE broadcasts across batch dim + # Build RoPE embeddings patch_h = height // self.PATCH_TO_PIXEL_RATIO patch_w = width // self.PATCH_TO_PIXEL_RATIO - # Create RoPE for a single image (frame=1 for images, not video) - # This will broadcast across all batch elements automatically - img_shapes_single = [(1, patch_h, patch_w)] - - # Calculate sequence lengths for entities and global prompt - # Use [0] to get first batch element (all batches should have same sequence lengths) entity_seq_lens = [int(mask.sum(dim=1)[0].item()) for mask in entity_prompt_emb_mask] - # Handle None case in ComfyUI (None means no padding, all tokens valid) if prompt_emb_mask is not None: global_seq_len = int(prompt_emb_mask.sum(dim=1)[0].item()) else: - # No mask = no padding, use full sequence length global_seq_len = int(prompt_emb.shape[1]) - # Get base image RoPE using global prompt length (returns tuple: (img_freqs, txt_freqs)) - # We pass a single shape, not repeated for batch, because RoPE will broadcast - txt_seq_lens = [global_seq_len] - image_rotary_emb = self.pos_embed(img_shapes_single, txt_seq_lens, device=latents.device) + max_vid_index = max(patch_h // 2, patch_w // 2) - # Create SEPARATE RoPE embeddings for each entity - # Each entity gets its own positional encoding based on its sequence length - # We only need to create these once since they're the same for all batch elements - entity_rotary_emb = [self.pos_embed([(1, patch_h, patch_w)], [entity_seq_len], device=latents.device)[1] - for entity_seq_len in entity_seq_lens] + # Generate per-entity text RoPE (each entity starts from same offset) + entity_txt_embs = [] + for entity_seq_len 
in entity_seq_lens: + entity_ids = torch.arange( + max_vid_index, + max_vid_index + entity_seq_len, + device=latents.device + ).reshape(1, -1, 1).repeat(1, 1, 3) - # Concatenate entity RoPEs with global RoPE along sequence dimension - # Result: [entity1_seq, entity2_seq, ..., global_seq] concatenated - # This creates the RoPE for ONE batch element's sequence - # Note: We DON'T repeat for batch_size because the queries/keys have shape [batch, seq, ...] - # and PyTorch will broadcast the RoPE [seq, ...] across the batch dimension automatically - txt_rotary_emb = torch.cat(entity_rotary_emb + [image_rotary_emb[1]], dim=0) + entity_rope = self.pe_embedder(entity_ids).squeeze(1).squeeze(0) + entity_txt_embs.append(entity_rope) - logger.debug( - f"[EliGen Model] RoPE created for single batch element - " - f"img: {image_rotary_emb[0].shape}, txt: {txt_rotary_emb.shape} " - f"(both will broadcast across batch_size={batch_size})" - ) + # Generate global text RoPE + global_ids = torch.arange( + max_vid_index, + max_vid_index + global_seq_len, + device=latents.device + ).reshape(1, -1, 1).repeat(1, 1, 3) + global_rope = self.pe_embedder(global_ids).squeeze(1).squeeze(0) - # Replace text part of tuple with concatenated entity + global RoPE - image_rotary_emb = (image_rotary_emb[0], txt_rotary_emb) + txt_rotary_emb = torch.cat(entity_txt_embs + [global_rope], dim=0) - # SECTION 3: Prepare spatial masks - repeat_dim = latents.shape[1] # 16 (latent channels) - max_masks = entity_masks.shape[1] # N entities + h_coords = torch.arange(-(patch_h - patch_h // 2), patch_h // 2, device=latents.device) + w_coords = torch.arange(-(patch_w - patch_w // 2), patch_w // 2, device=latents.device) + + img_ids = torch.zeros((patch_h, patch_w, 3), device=latents.device) + img_ids[:, :, 0] = 0 + img_ids[:, :, 1] = h_coords.unsqueeze(1) + img_ids[:, :, 2] = w_coords.unsqueeze(0) + img_ids = img_ids.reshape(1, -1, 3) + + img_rope = self.pe_embedder(img_ids).squeeze(1).squeeze(0) + + logger.debug(f"[EliGen Model] RoPE shapes - img: {img_rope.shape}, txt: {txt_rotary_emb.shape}") + + image_rotary_emb = (img_rope, txt_rotary_emb) + + # Prepare spatial masks + repeat_dim = latents.shape[1] + max_masks = entity_masks.shape[1] entity_masks = entity_masks.repeat(1, 1, repeat_dim, 1, 1) - # Pad masks to match padded latent dimensions - # entity_masks shape: [1, N, 16, H/8, W/8] - # Need to pad to match orig_shape which is [B, 16, padded_H/8, padded_W/8] padded_h = height // self.LATENT_TO_PIXEL_RATIO padded_w = width // self.LATENT_TO_PIXEL_RATIO if entity_masks.shape[3] != padded_h or entity_masks.shape[4] != padded_w: - assert entity_masks.shape[3] <= padded_h and entity_masks.shape[4] <= padded_w, \ - f"Entity masks {entity_masks.shape[3]}x{entity_masks.shape[4]} larger than padded dims {padded_h}x{padded_w}" - - # Pad each entity mask pad_h = padded_h - entity_masks.shape[3] pad_w = padded_w - entity_masks.shape[4] - logger.debug(f"[EliGen Model] Padding entity masks by ({pad_h}, {pad_w}) to match latent dimensions") + logger.debug(f"[EliGen Model] Padding masks by ({pad_h}, {pad_w})") entity_masks = torch.nn.functional.pad(entity_masks, (0, pad_w, 0, pad_h), mode='constant', value=0) entity_masks = [entity_masks[:, i, None].squeeze(1) for i in range(max_masks)] - # Add global mask (all True) - must be same size as padded entity masks global_mask = torch.ones((entity_masks[0].shape[0], entity_masks[0].shape[1], padded_h, padded_w), device=latents.device, dtype=latents.dtype) entity_masks = entity_masks + [global_mask] - # 
SECTION 4: Patchify masks + # Patchify masks N = len(entity_masks) batch_size = int(entity_masks[0].shape[0]) seq_lens = entity_seq_lens + [global_seq_len] total_seq_len = int(sum(seq_lens) + image.shape[1]) - logger.debug( - f"[EliGen Model] Building attention mask: " - f"total_seq={total_seq_len} (entities: {entity_seq_lens}, global: {global_seq_len}, image: {image.shape[1]})" - ) + logger.debug(f"[EliGen Model] total_seq={total_seq_len}") patched_masks = [] for i in range(N): @@ -694,7 +548,7 @@ class QwenImageTransformer2DModel(nn.Module): image_mask = torch.sum(patched_masks[i], dim=-1) > 0 image_mask = image_mask.unsqueeze(1).repeat(1, seq_lens[i], 1) - # Always repeat mask to match image sequence length (matches DiffSynth line 480) + # Always repeat mask to match image sequence length repeat_time = single_image_seq // image_mask.shape[-1] image_mask = image_mask.repeat(1, 1, repeat_time) @@ -713,12 +567,44 @@ class QwenImageTransformer2DModel(nn.Module): start_j, end_j = cumsum[j], cumsum[j+1] attention_mask[:, start_i:end_i, start_j:end_j] = False - # SECTION 6: Convert to additive bias + # SECTION 6: Convert to additive bias and handle CFG batching attention_mask = attention_mask.float() num_valid_connections = (attention_mask == 1).sum().item() attention_mask[attention_mask == 0] = float('-inf') attention_mask[attention_mask == 1] = 0 - attention_mask = attention_mask.to(device=latents.device, dtype=latents.dtype).unsqueeze(1) + attention_mask = attention_mask.to(device=latents.device, dtype=latents.dtype) + + # Handle CFG batching: Create separate masks for positive and negative + if is_cfg_batched and actual_batch_size > 1: + # CFG batch: [positive, negative] - need different masks for each + # Positive gets entity constraints, negative gets standard attention (all zeros) + + logger.debug( + f"[EliGen Model] CFG batched detected - creating separate masks. " + f"Positive (index 0) gets entity mask, Negative (index 1) gets standard mask" + ) + + # Create standard attention mask (all zeros = no constraints) + standard_mask = torch.zeros_like(attention_mask) + + # Stack masks according to cond_or_uncond order + mask_list = [] + for cond_type in cond_or_uncond: + if cond_type == 0: # Positive - use entity mask + mask_list.append(attention_mask[0:1]) # Take first (and only) entity mask + else: # Negative - use standard mask + mask_list.append(standard_mask[0:1]) + + # Concatenate masks to match batch + attention_mask = torch.cat(mask_list, dim=0) + + logger.debug( + f"[EliGen Model] Created {len(mask_list)} masks for CFG batch. 
" + f"Final shape: {attention_mask.shape}" + ) + + # Add head dimension: [B, 1, seq, seq] + attention_mask = attention_mask.unsqueeze(1) logger.debug( f"[EliGen Model] Attention mask created: shape={attention_mask.shape}, " @@ -778,23 +664,28 @@ class QwenImageTransformer2DModel(nn.Module): hidden_states = torch.cat([hidden_states, kontext], dim=1) img_ids = torch.cat([img_ids, kontext_ids], dim=1) - # Extract entity data from kwargs + # Extract EliGen entity data entity_prompt_emb = kwargs.get("entity_prompt_emb", None) entity_prompt_emb_mask = kwargs.get("entity_prompt_emb_mask", None) entity_masks = kwargs.get("entity_masks", None) - # Branch: EliGen vs Standard path - # Only apply EliGen to POSITIVE conditioning (cond_or_uncond contains 0) + # Detect batch composition for CFG handling cond_or_uncond = transformer_options.get("cond_or_uncond", []) if transformer_options else [] - is_positive_cond = 0 in cond_or_uncond # 0 = conditional/positive, 1 = unconditional/negative + is_positive_cond = 0 in cond_or_uncond + is_negative_cond = 1 in cond_or_uncond + batch_size = x.shape[0] + + if entity_prompt_emb is not None: + logger.debug( + f"[EliGen Forward] batch_size={batch_size}, cond_or_uncond={cond_or_uncond}, " + f"has_positive={is_positive_cond}, has_negative={is_negative_cond}" + ) if entity_prompt_emb is not None and entity_masks is not None and entity_prompt_emb_mask is not None and is_positive_cond: - # EliGen path - process entity masks (POSITIVE CONDITIONING ONLY) - # orig_shape is from process_img which pads to patch_size - height = int(orig_shape[-2] * 8) # Padded latent height -> pixel height (ensure int) - width = int(orig_shape[-1] * 8) # Padded latent width -> pixel width (ensure int) + # EliGen path + height = int(orig_shape[-2] * 8) + width = int(orig_shape[-1] * 8) - # Call process_entity_masks to get concatenated text, RoPE, and attention mask encoder_hidden_states, image_rotary_emb, eligen_attention_mask = self.process_entity_masks( latents=x, prompt_emb=encoder_hidden_states, @@ -804,22 +695,21 @@ class QwenImageTransformer2DModel(nn.Module): entity_masks=entity_masks, height=height, width=width, - image=hidden_states + image=hidden_states, + cond_or_uncond=cond_or_uncond, + batch_size=batch_size ) - # Apply image projection (text already processed in process_entity_masks) hidden_states = self.img_in(hidden_states) - # Store attention mask in transformer_options for the attention layers if transformer_options is None: transformer_options = {} transformer_options["eligen_attention_mask"] = eligen_attention_mask - # Clean up del img_ids else: - # Standard path - existing code + # Standard path txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size) // 2, ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size) // 2)) txt_ids = torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 3) ids = torch.cat((txt_ids, img_ids), dim=1) From 79c30e16300704a2b36816c93e03a0a7cd30d7d0 Mon Sep 17 00:00:00 2001 From: nolan4 Date: Mon, 27 Oct 2025 22:00:47 -0700 Subject: [PATCH 08/12] removed redundant branch --- comfy/ldm/qwen_image/model.py | 48 ++++++++++------------------------- 1 file changed, 14 insertions(+), 34 deletions(-) diff --git a/comfy/ldm/qwen_image/model.py b/comfy/ldm/qwen_image/model.py index 45996e23b..66cabab43 100644 --- a/comfy/ldm/qwen_image/model.py +++ b/comfy/ldm/qwen_image/model.py @@ -153,36 +153,14 @@ class Attention(nn.Module): txt_query = self.norm_added_q(txt_query) 
txt_key = self.norm_added_k(txt_key) - # Handle both tuple (EliGen) and single tensor (standard) RoPE formats - if isinstance(image_rotary_emb, tuple): - # EliGen path: Apply RoPE BEFORE concatenation (research-accurate) - # txt/img query/key are in [b, s, h, d] format, compatible with apply_rotary_emb - img_rope, txt_rope = image_rotary_emb + # Concatenate text and image streams + joint_query = torch.cat([txt_query, img_query], dim=1) + joint_key = torch.cat([txt_key, img_key], dim=1) + joint_value = torch.cat([txt_value, img_value], dim=1) - # Add heads dimension to RoPE tensors for broadcasting - # Shape: [s, features, 2, 2] -> [s, 1, features, 2, 2] - # Also convert to match query dtype (e.g., bfloat16) - txt_rope = txt_rope.unsqueeze(1).to(dtype=txt_query.dtype) - img_rope = img_rope.unsqueeze(1).to(dtype=img_query.dtype) - - # Apply RoPE separately to text and image streams - txt_query = apply_rotary_emb(txt_query, txt_rope) - txt_key = apply_rotary_emb(txt_key, txt_rope) - img_query = apply_rotary_emb(img_query, img_rope) - img_key = apply_rotary_emb(img_key, img_rope) - - # Now concatenate - joint_query = torch.cat([txt_query, img_query], dim=1) - joint_key = torch.cat([txt_key, img_key], dim=1) - joint_value = torch.cat([txt_value, img_value], dim=1) - else: - # Standard path: Concatenate first, then apply RoPE - joint_query = torch.cat([txt_query, img_query], dim=1) - joint_key = torch.cat([txt_key, img_key], dim=1) - joint_value = torch.cat([txt_value, img_value], dim=1) - - joint_query = apply_rotary_emb(joint_query, image_rotary_emb) - joint_key = apply_rotary_emb(joint_key, image_rotary_emb) + # Apply RoPE to concatenated queries and keys + joint_query = apply_rotary_emb(joint_query, image_rotary_emb) + joint_key = apply_rotary_emb(joint_key, image_rotary_emb) # Apply EliGen attention mask if present effective_mask = attention_mask @@ -483,7 +461,9 @@ class QwenImageTransformer2DModel(nn.Module): logger.debug(f"[EliGen Model] RoPE shapes - img: {img_rope.shape}, txt: {txt_rotary_emb.shape}") - image_rotary_emb = (img_rope, txt_rotary_emb) + # Concatenate text and image RoPE embeddings + # Convert to latent dtype to match queries/keys + image_rotary_emb = torch.cat([txt_rotary_emb, img_rope], dim=0).unsqueeze(1).to(dtype=latents.dtype) # Prepare spatial masks repeat_dim = latents.shape[1] @@ -524,7 +504,7 @@ class QwenImageTransformer2DModel(nn.Module): ) patched_masks.append(patched_mask) - # SECTION 5: Build attention mask matrix + # Build attention mask matrix attention_mask = torch.ones( (batch_size, total_seq_len, total_seq_len), dtype=torch.bool @@ -539,7 +519,7 @@ class QwenImageTransformer2DModel(nn.Module): for length in seq_lens: cumsum.append(cumsum[-1] + length) - # RULE 1: Spatial restriction (prompt <-> image) + # Spatial restriction (prompt <-> image) for i in range(N): prompt_start = cumsum[i] prompt_end = cumsum[i+1] @@ -558,7 +538,7 @@ class QwenImageTransformer2DModel(nn.Module): # - Image patches can only be updated by prompts that own them attention_mask[:, image_start:image_end, prompt_start:prompt_end] = image_mask.transpose(1, 2) - # RULE 2: Entity isolation + # Entity isolation for i in range(N): for j in range(N): if i == j: @@ -567,7 +547,7 @@ class QwenImageTransformer2DModel(nn.Module): start_j, end_j = cumsum[j], cumsum[j+1] attention_mask[:, start_i:end_i, start_j:end_j] = False - # SECTION 6: Convert to additive bias and handle CFG batching + # Convert to additive bias and handle CFG batching attention_mask = attention_mask.float() 
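        # Note: converting the boolean mask to an additive bias means allowed
        # query/key pairs contribute 0 to the attention logits while blocked pairs
        # contribute -inf, so softmax(Q @ K^T / sqrt(d) + bias) gives blocked keys
        # exactly zero weight. The count below is taken before the in-place rewrite,
        # while "1" entries can still be identified.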
num_valid_connections = (attention_mask == 1).sum().item() attention_mask[attention_mask == 0] = float('-inf') From 2d550102fca3cd549d693c990bc2750a4b8aa4b7 Mon Sep 17 00:00:00 2001 From: nolan4 Date: Tue, 28 Oct 2025 19:05:34 -0700 Subject: [PATCH 09/12] resolved Ruff lint errors --- comfy/ldm/qwen_image/model.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/comfy/ldm/qwen_image/model.py b/comfy/ldm/qwen_image/model.py index 66cabab43..ffa4743dd 100644 --- a/comfy/ldm/qwen_image/model.py +++ b/comfy/ldm/qwen_image/model.py @@ -2,7 +2,6 @@ import torch import torch.nn as nn import torch.nn.functional as F -import math import logging from typing import Optional, Tuple from einops import repeat, rearrange @@ -336,7 +335,6 @@ class QwenImageTransformer2DModel(nn.Module): self.inner_dim = num_attention_heads * attention_head_dim self.pe_embedder = EmbedND(dim=attention_head_dim, theta=10000, axes_dim=list(axes_dims_rope)) - self.time_text_embed = QwenTimestepProjEmbeddings( embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim, @@ -560,8 +558,8 @@ class QwenImageTransformer2DModel(nn.Module): # Positive gets entity constraints, negative gets standard attention (all zeros) logger.debug( - f"[EliGen Model] CFG batched detected - creating separate masks. " - f"Positive (index 0) gets entity mask, Negative (index 1) gets standard mask" + "[EliGen Model] CFG batched detected - creating separate masks. " + "Positive (index 0) gets entity mask, Negative (index 1) gets standard mask" ) # Create standard attention mask (all zeros = no constraints) From 16adfe2153e03b437e39bb81ab3909d25ffdde6c Mon Sep 17 00:00:00 2001 From: nolan4 Date: Tue, 4 Nov 2025 15:46:52 -0800 Subject: [PATCH 10/12] restored whitespace and fixed logging --- comfy/ldm/qwen_image/model.py | 18 ++++++++---------- comfy/model_base.py | 13 +++++++++++++ comfy_extras/nodes_qwen.py | 20 +++++++++----------- 3 files changed, 30 insertions(+), 21 deletions(-) diff --git a/comfy/ldm/qwen_image/model.py b/comfy/ldm/qwen_image/model.py index ffa4743dd..461dde58f 100644 --- a/comfy/ldm/qwen_image/model.py +++ b/comfy/ldm/qwen_image/model.py @@ -12,8 +12,6 @@ from comfy.ldm.flux.layers import EmbedND import comfy.ldm.common_dit import comfy.patcher_extension -logger = logging.getLogger(__name__) - class GELU(nn.Module): def __init__(self, dim_in: int, dim_out: int, approximate: str = "none", bias: bool = True, dtype=None, device=None, operations=None): @@ -401,7 +399,7 @@ class QwenImageTransformer2DModel(nn.Module): has_negative = cond_or_uncond and 1 in cond_or_uncond is_cfg_batched = has_positive and has_negative - logger.debug( + logging.debug( f"[EliGen Model] Processing {num_entities} entities for {height}x{width}px, " f"batch_size={actual_batch_size}, CFG_batched={is_cfg_batched}" ) @@ -457,7 +455,7 @@ class QwenImageTransformer2DModel(nn.Module): img_rope = self.pe_embedder(img_ids).squeeze(1).squeeze(0) - logger.debug(f"[EliGen Model] RoPE shapes - img: {img_rope.shape}, txt: {txt_rotary_emb.shape}") + logging.debug(f"[EliGen Model] RoPE shapes - img: {img_rope.shape}, txt: {txt_rotary_emb.shape}") # Concatenate text and image RoPE embeddings # Convert to latent dtype to match queries/keys @@ -473,7 +471,7 @@ class QwenImageTransformer2DModel(nn.Module): if entity_masks.shape[3] != padded_h or entity_masks.shape[4] != padded_w: pad_h = padded_h - entity_masks.shape[3] pad_w = padded_w - entity_masks.shape[4] - logger.debug(f"[EliGen Model] Padding masks by ({pad_h}, {pad_w})") + 
logging.debug(f"[EliGen Model] Padding masks by ({pad_h}, {pad_w})") entity_masks = torch.nn.functional.pad(entity_masks, (0, pad_w, 0, pad_h), mode='constant', value=0) entity_masks = [entity_masks[:, i, None].squeeze(1) for i in range(max_masks)] @@ -488,7 +486,7 @@ class QwenImageTransformer2DModel(nn.Module): seq_lens = entity_seq_lens + [global_seq_len] total_seq_len = int(sum(seq_lens) + image.shape[1]) - logger.debug(f"[EliGen Model] total_seq={total_seq_len}") + logging.debug(f"[EliGen Model] total_seq={total_seq_len}") patched_masks = [] for i in range(N): @@ -557,7 +555,7 @@ class QwenImageTransformer2DModel(nn.Module): # CFG batch: [positive, negative] - need different masks for each # Positive gets entity constraints, negative gets standard attention (all zeros) - logger.debug( + logging.debug( "[EliGen Model] CFG batched detected - creating separate masks. " "Positive (index 0) gets entity mask, Negative (index 1) gets standard mask" ) @@ -576,7 +574,7 @@ class QwenImageTransformer2DModel(nn.Module): # Concatenate masks to match batch attention_mask = torch.cat(mask_list, dim=0) - logger.debug( + logging.debug( f"[EliGen Model] Created {len(mask_list)} masks for CFG batch. " f"Final shape: {attention_mask.shape}" ) @@ -584,7 +582,7 @@ class QwenImageTransformer2DModel(nn.Module): # Add head dimension: [B, 1, seq, seq] attention_mask = attention_mask.unsqueeze(1) - logger.debug( + logging.debug( f"[EliGen Model] Attention mask created: shape={attention_mask.shape}, " f"valid_connections={num_valid_connections}/{total_seq_len * total_seq_len}" ) @@ -654,7 +652,7 @@ class QwenImageTransformer2DModel(nn.Module): batch_size = x.shape[0] if entity_prompt_emb is not None: - logger.debug( + logging.debug( f"[EliGen Forward] batch_size={batch_size}, cond_or_uncond={cond_or_uncond}, " f"has_positive={is_positive_cond}, has_negative={is_negative_cond}" ) diff --git a/comfy/model_base.py b/comfy/model_base.py index 9a4010843..ab30a6f97 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -119,6 +119,7 @@ def convert_tensor(extra, dtype, device): extra = comfy.model_management.cast_to_device(extra, device, None) return extra + class BaseModel(torch.nn.Module): def __init__(self, model_config, model_type=ModelType.EPS, device=None, unet_model=UNetModel): super().__init__() @@ -380,6 +381,7 @@ class BaseModel(torch.nn.Module): def extra_conds_shapes(self, **kwargs): return {} + def unclip_adm(unclip_conditioning, device, noise_augmentor, noise_augment_merge=0.0, seed=None): adm_inputs = [] weights = [] @@ -475,6 +477,7 @@ class SDXL(BaseModel): flat = torch.flatten(torch.cat(out)).unsqueeze(dim=0).repeat(clip_pooled.shape[0], 1) return torch.cat((clip_pooled.to(flat.device), flat), dim=1) + class SVD_img2vid(BaseModel): def __init__(self, model_config, model_type=ModelType.V_PREDICTION_EDM, device=None): super().__init__(model_config, model_type, device=device) @@ -551,6 +554,7 @@ class SV3D_p(SVD_img2vid): out = list(map(lambda a: utils.resize_to_batch_size(a, noise.shape[0]), out)) return torch.cat(out, dim=1) + class Stable_Zero123(BaseModel): def __init__(self, model_config, model_type=ModelType.EPS, device=None, cc_projection_weight=None, cc_projection_bias=None): super().__init__(model_config, model_type, device=device) @@ -634,11 +638,13 @@ class IP2P: image = utils.resize_to_batch_size(image, noise.shape[0]) return self.process_ip2p_image_in(image) + class SD15_instructpix2pix(IP2P, BaseModel): def __init__(self, model_config, model_type=ModelType.EPS, device=None): 
super().__init__(model_config, model_type, device=device) self.process_ip2p_image_in = lambda image: image + class SDXL_instructpix2pix(IP2P, SDXL): def __init__(self, model_config, model_type=ModelType.EPS, device=None): super().__init__(model_config, model_type, device=device) @@ -688,6 +694,7 @@ class StableCascade_C(BaseModel): out['clip_text'] = comfy.conds.CONDCrossAttn(cross_attn) return out + class StableCascade_B(BaseModel): def __init__(self, model_config, model_type=ModelType.STABLE_CASCADE, device=None): super().__init__(model_config, model_type, device=device, unet_model=StageB) @@ -707,6 +714,7 @@ class StableCascade_B(BaseModel): out["sca"] = comfy.conds.CONDRegular(torch.zeros((1,))) return out + class SD3(BaseModel): def __init__(self, model_config, model_type=ModelType.FLOW, device=None): super().__init__(model_config, model_type, device=device, unet_model=OpenAISignatureMMDITWrapper) @@ -721,6 +729,7 @@ class SD3(BaseModel): out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn) return out + class AuraFlow(BaseModel): def __init__(self, model_config, model_type=ModelType.FLOW, device=None): super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.aura.mmdit.MMDiT) @@ -732,6 +741,7 @@ class AuraFlow(BaseModel): out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn) return out + class StableAudio1(BaseModel): def __init__(self, model_config, seconds_start_embedder_weights, seconds_total_embedder_weights, model_type=ModelType.V_PREDICTION_CONTINUOUS, device=None): super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.audio.dit.AudioDiffusionTransformer) @@ -770,6 +780,7 @@ class StableAudio1(BaseModel): sd["{}{}".format(k, l)] = s[l] return sd + class HunyuanDiT(BaseModel): def __init__(self, model_config, model_type=ModelType.V_PREDICTION, device=None): super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hydit.models.HunYuanDiT) @@ -903,6 +914,7 @@ class Flux(BaseModel): out['ref_latents'] = list([1, 16, sum(map(lambda a: math.prod(a.size()), ref_latents)) // 16]) return out + class GenmoMochi(BaseModel): def __init__(self, model_config, model_type=ModelType.FLOW, device=None): super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.genmo.joint_model.asymm_models_joint.AsymmDiTJoint) @@ -1154,6 +1166,7 @@ class WAN21(BaseModel): return out + class WAN21_Vace(WAN21): def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None): super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.VaceWanModel) diff --git a/comfy_extras/nodes_qwen.py b/comfy_extras/nodes_qwen.py index 03ca00f73..8671d60ae 100644 --- a/comfy_extras/nodes_qwen.py +++ b/comfy_extras/nodes_qwen.py @@ -8,8 +8,6 @@ from typing import Optional from typing_extensions import override from comfy_api.latest import ComfyExtension, io -logger = logging.getLogger(__name__) - class TextEncodeQwenImageEdit(io.ComfyNode): @classmethod @@ -186,8 +184,8 @@ class TextEncodeQwenImageEliGen(io.ComfyNode): width = latent_width * 8 if pad_h > 0 or pad_w > 0: - logger.debug(f"[EliGen] Latent padding detected: {unpadded_latent_height}x{unpadded_latent_width} → {latent_height}x{latent_width}") - logger.debug(f"[EliGen] Target generation dimensions: {height}x{width} pixels ({latent_height}x{latent_width} latent)") + logging.debug(f"[EliGen] Latent padding detected: {unpadded_latent_height}x{unpadded_latent_width} → {latent_height}x{latent_width}") + 
logging.debug(f"[EliGen] Target generation dimensions: {height}x{width} pixels ({latent_height}x{latent_width} latent)") # Collect entity prompts and masks entity_prompts = [entity_prompt_1, entity_prompt_2, entity_prompt_3] @@ -202,7 +200,7 @@ class TextEncodeQwenImageEliGen(io.ComfyNode): # Log warning if some entities were skipped total_prompts_provided = len([p for p in entity_prompts if p.strip()]) if len(valid_entities) < total_prompts_provided: - logger.warning(f"[EliGen] Only {len(valid_entities)} of {total_prompts_provided} entity prompts have valid masks") + logging.warning(f"[EliGen] Only {len(valid_entities)} of {total_prompts_provided} entity prompts have valid masks") # If no valid entities, return standard conditioning if len(valid_entities) == 0: @@ -244,7 +242,7 @@ class TextEncodeQwenImageEliGen(io.ComfyNode): ) # Log original mask statistics - logger.debug( + logging.debug( f"[EliGen] Entity {i+1} input mask: shape={mask_tensor.shape}, " f"dtype={mask_tensor.dtype}, min={mask_tensor.min():.4f}, max={mask_tensor.max():.4f}" ) @@ -260,7 +258,7 @@ class TextEncodeQwenImageEliGen(io.ComfyNode): # Check for constant masks (no variation) if mask_tensor.min() == mask_tensor.max() and mask_tensor.max() > 0: - logger.warning( + logging.warning( f"[EliGen] Entity {i+1} mask has no variation (all pixels = {mask_tensor.min():.4f}). " f"This entity will affect the entire image." ) @@ -284,12 +282,12 @@ class TextEncodeQwenImageEliGen(io.ComfyNode): # Log size mismatch if mask doesn't match expected latent dimensions expected_h, expected_w = latent_height * 8, latent_width * 8 if orig_h != expected_h or orig_w != expected_w: - logger.info( + logging.info( f"[EliGen] Entity {i+1} mask size mismatch: {orig_h}x{orig_w} vs expected {expected_h}x{expected_w}. " f"Will resize to {latent_height}x{latent_width} latent space." ) else: - logger.debug(f"[EliGen] Entity {i+1} mask: {orig_h}x{orig_w} → will resize to {latent_height}x{latent_width} latent") + logging.debug(f"[EliGen] Entity {i+1} mask: {orig_h}x{orig_w} → will resize to {latent_height}x{latent_width} latent") # Convert MASK format [batch, height, width] to [batch, 1, height, width] for common_upscale # common_upscale expects [batch, channels, height, width] @@ -319,7 +317,7 @@ class TextEncodeQwenImageEliGen(io.ComfyNode): f"Original mask may have been too small or all black." 
) - logger.debug( + logging.debug( f"[EliGen] Entity {i+1} mask coverage: {active_pixels}/{total_pixels} pixels ({coverage_pct:.1f}%)" ) @@ -332,7 +330,7 @@ class TextEncodeQwenImageEliGen(io.ComfyNode): entity_masks_tensor = torch.stack(processed_entity_masks_no_batch, dim=0) # [num_entities, 1, H, W] entity_masks_tensor = entity_masks_tensor.unsqueeze(0) # [1, num_entities, 1, H, W] - logger.debug( + logging.debug( f"[EliGen] Stacked {len(valid_entities)} entity masks into tensor: " f"shape={entity_masks_tensor.shape} (expected: [1, {len(valid_entities)}, 1, {latent_height}, {latent_width}])" ) From 9792606847d7799ae98f2156c73e915e58a52c44 Mon Sep 17 00:00:00 2001 From: nolan4 Date: Tue, 4 Nov 2025 19:55:24 -0800 Subject: [PATCH 11/12] added attention_mask to QwenImageTransformerBlock.forward() --- comfy/ldm/qwen_image/model.py | 58 +++++++++++++++++++++-------------- 1 file changed, 35 insertions(+), 23 deletions(-) diff --git a/comfy/ldm/qwen_image/model.py b/comfy/ldm/qwen_image/model.py index 461dde58f..76ad3646e 100644 --- a/comfy/ldm/qwen_image/model.py +++ b/comfy/ldm/qwen_image/model.py @@ -159,27 +159,21 @@ class Attention(nn.Module): joint_query = apply_rotary_emb(joint_query, image_rotary_emb) joint_key = apply_rotary_emb(joint_key, image_rotary_emb) - # Apply EliGen attention mask if present - effective_mask = attention_mask - if transformer_options is not None: - eligen_mask = transformer_options.get("eligen_attention_mask", None) - if eligen_mask is not None: - effective_mask = eligen_mask - - # Validate shape - expected_seq = joint_query.shape[1] - if eligen_mask.shape[-1] != expected_seq: - raise ValueError( - f"EliGen attention mask shape mismatch: {eligen_mask.shape} " - f"doesn't match sequence length {expected_seq}" - ) + # Validate attention mask shape if provided + if attention_mask is not None: + expected_seq = joint_query.shape[1] + if attention_mask.shape[-1] != expected_seq: + raise ValueError( + f"Attention mask shape mismatch: {attention_mask.shape} " + f"doesn't match sequence length {expected_seq}" + ) # Use ComfyUI's optimized attention joint_query = joint_query.flatten(start_dim=2) joint_key = joint_key.flatten(start_dim=2) joint_value = joint_value.flatten(start_dim=2) - joint_hidden_states = optimized_attention_masked(joint_query, joint_key, joint_value, self.heads, effective_mask, transformer_options=transformer_options) + joint_hidden_states = optimized_attention_masked(joint_query, joint_key, joint_value, self.heads, attention_mask, transformer_options=transformer_options) txt_attn_output = joint_hidden_states[:, :seq_txt, :] img_attn_output = joint_hidden_states[:, seq_txt:, :] @@ -246,6 +240,7 @@ class QwenImageTransformerBlock(nn.Module): encoder_hidden_states_mask: torch.Tensor, temb: torch.Tensor, image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, transformer_options={}, ) -> Tuple[torch.Tensor, torch.Tensor]: img_mod_params = self.img_mod(temb) @@ -262,6 +257,7 @@ class QwenImageTransformerBlock(nn.Module): hidden_states=img_modulated, encoder_hidden_states=txt_modulated, encoder_hidden_states_mask=encoder_hidden_states_mask, + attention_mask=attention_mask, image_rotary_emb=image_rotary_emb, transformer_options=transformer_options, ) @@ -640,6 +636,9 @@ class QwenImageTransformer2DModel(nn.Module): hidden_states = torch.cat([hidden_states, kontext], dim=1) img_ids = torch.cat([img_ids, kontext_ids], dim=1) + # Initialize attention mask (None for standard generation) + 
eligen_attention_mask = None + # Extract EliGen entity data entity_prompt_emb = kwargs.get("entity_prompt_emb", None) entity_prompt_emb_mask = kwargs.get("entity_prompt_emb_mask", None) @@ -659,8 +658,8 @@ class QwenImageTransformer2DModel(nn.Module): if entity_prompt_emb is not None and entity_masks is not None and entity_prompt_emb_mask is not None and is_positive_cond: # EliGen path - height = int(orig_shape[-2] * 8) - width = int(orig_shape[-1] * 8) + height = int(orig_shape[-2] * self.LATENT_TO_PIXEL_RATIO) + width = int(orig_shape[-1] * self.LATENT_TO_PIXEL_RATIO) encoder_hidden_states, image_rotary_emb, eligen_attention_mask = self.process_entity_masks( latents=x, @@ -678,10 +677,6 @@ class QwenImageTransformer2DModel(nn.Module): hidden_states = self.img_in(hidden_states) - if transformer_options is None: - transformer_options = {} - transformer_options["eligen_attention_mask"] = eligen_attention_mask - del img_ids else: @@ -713,9 +708,25 @@ class QwenImageTransformer2DModel(nn.Module): if ("double_block", i) in blocks_replace: def block_wrap(args): out = {} - out["txt"], out["img"] = block(hidden_states=args["img"], encoder_hidden_states=args["txt"], encoder_hidden_states_mask=encoder_hidden_states_mask, temb=args["vec"], image_rotary_emb=args["pe"], transformer_options=args["transformer_options"]) + out["txt"], out["img"] = block( + hidden_states=args["img"], + encoder_hidden_states=args["txt"], + encoder_hidden_states_mask=args.get("encoder_hidden_states_mask"), + temb=args["vec"], + image_rotary_emb=args["pe"], + attention_mask=args.get("attention_mask"), + transformer_options=args["transformer_options"] + ) return out - out = blocks_replace[("double_block", i)]({"img": hidden_states, "txt": encoder_hidden_states, "vec": temb, "pe": image_rotary_emb, "transformer_options": transformer_options}, {"original_block": block_wrap}) + out = blocks_replace[("double_block", i)]({ + "img": hidden_states, + "txt": encoder_hidden_states, + "encoder_hidden_states_mask": encoder_hidden_states_mask, + "attention_mask": eligen_attention_mask, + "vec": temb, + "pe": image_rotary_emb, + "transformer_options": transformer_options + }, {"original_block": block_wrap}) hidden_states = out["img"] encoder_hidden_states = out["txt"] else: @@ -725,6 +736,7 @@ class QwenImageTransformer2DModel(nn.Module): encoder_hidden_states_mask=encoder_hidden_states_mask, temb=temb, image_rotary_emb=image_rotary_emb, + attention_mask=eligen_attention_mask, transformer_options=transformer_options, ) From 65935d512f980695306c82b75a6b067bdb77dd85 Mon Sep 17 00:00:00 2001 From: nolan4 Date: Sun, 9 Nov 2025 16:10:07 -0800 Subject: [PATCH 12/12] Increase EliGen entity limit to 8 --- comfy_extras/nodes_qwen.py | 28 ++++++++++++++++++++++++---- 1 file changed, 24 insertions(+), 4 deletions(-) diff --git a/comfy_extras/nodes_qwen.py b/comfy_extras/nodes_qwen.py index 8671d60ae..9ad258add 100644 --- a/comfy_extras/nodes_qwen.py +++ b/comfy_extras/nodes_qwen.py @@ -116,7 +116,7 @@ class TextEncodeQwenImageEliGen(io.ComfyNode): spatial attention masking. 
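    Each entity mask is resized to the latent grid and combined into a joint attention
    mask, so an entity prompt only attends to (and updates) the image patches inside its
    own mask, entity prompts are isolated from one another, and the global prompt still
    provides overall context.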
Features: - - Supports up to 3 entities per generation + - Supports up to 8 entities per generation - Spatial attention masks prevent cross-entity contamination - Separate RoPE embeddings per entity (research-accurate) - Falls back to standard generation if no entities provided @@ -147,6 +147,16 @@ class TextEncodeQwenImageEliGen(io.ComfyNode): io.String.Input("entity_prompt_2", multiline=True, dynamic_prompts=True, default=""), io.Mask.Input("entity_mask_3", optional=True), io.String.Input("entity_prompt_3", multiline=True, dynamic_prompts=True, default=""), + io.Mask.Input("entity_mask_4", optional=True), + io.String.Input("entity_prompt_4", multiline=True, dynamic_prompts=True, default=""), + io.Mask.Input("entity_mask_5", optional=True), + io.String.Input("entity_prompt_5", multiline=True, dynamic_prompts=True, default=""), + io.Mask.Input("entity_mask_6", optional=True), + io.String.Input("entity_prompt_6", multiline=True, dynamic_prompts=True, default=""), + io.Mask.Input("entity_mask_7", optional=True), + io.String.Input("entity_prompt_7", multiline=True, dynamic_prompts=True, default=""), + io.Mask.Input("entity_mask_8", optional=True), + io.String.Input("entity_prompt_8", multiline=True, dynamic_prompts=True, default=""), ], outputs=[ io.Conditioning.Output(), @@ -164,7 +174,17 @@ class TextEncodeQwenImageEliGen(io.ComfyNode): entity_prompt_2: str = "", entity_mask_2: Optional[torch.Tensor] = None, entity_prompt_3: str = "", - entity_mask_3: Optional[torch.Tensor] = None + entity_mask_3: Optional[torch.Tensor] = None, + entity_prompt_4: str = "", + entity_mask_4: Optional[torch.Tensor] = None, + entity_prompt_5: str = "", + entity_mask_5: Optional[torch.Tensor] = None, + entity_prompt_6: str = "", + entity_mask_6: Optional[torch.Tensor] = None, + entity_prompt_7: str = "", + entity_mask_7: Optional[torch.Tensor] = None, + entity_prompt_8: str = "", + entity_mask_8: Optional[torch.Tensor] = None ) -> io.NodeOutput: # Extract dimensions from latent tensor @@ -188,8 +208,8 @@ class TextEncodeQwenImageEliGen(io.ComfyNode): logging.debug(f"[EliGen] Target generation dimensions: {height}x{width} pixels ({latent_height}x{latent_width} latent)") # Collect entity prompts and masks - entity_prompts = [entity_prompt_1, entity_prompt_2, entity_prompt_3] - entity_masks_raw = [entity_mask_1, entity_mask_2, entity_mask_3] + entity_prompts = [entity_prompt_1, entity_prompt_2, entity_prompt_3, entity_prompt_4, entity_prompt_5, entity_prompt_6, entity_prompt_7, entity_prompt_8] + entity_masks_raw = [entity_mask_1, entity_mask_2, entity_mask_3, entity_mask_4, entity_mask_5, entity_mask_6, entity_mask_7, entity_mask_8] # Filter out entities with empty prompts or missing masks valid_entities = []
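        # Hedged sketch (illustrative; "_sketch_pairs" is not used by the node): the
        # filtering below keeps an entity only when its prompt is non-blank and a mask
        # tensor was actually connected.
        _sketch_pairs = [(p, m) for p, m in zip(entity_prompts, entity_masks_raw)
                         if p.strip() and m is not None]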