From b22226562807300dafc062d0d316f95f177c6c37 Mon Sep 17 00:00:00 2001 From: nolan4 Date: Fri, 24 Oct 2025 19:22:26 -0700 Subject: [PATCH] qwen eligen batch size > 1 fix --- comfy/ldm/qwen_image/model.py | 90 ++++++++++++++++++------ comfy_extras/nodes_qwen.py | 124 ++++++++++++++++++++++++++++++---- 2 files changed, 179 insertions(+), 35 deletions(-) diff --git a/comfy/ldm/qwen_image/model.py b/comfy/ldm/qwen_image/model.py index 896a22e19..42553154e 100644 --- a/comfy/ldm/qwen_image/model.py +++ b/comfy/ldm/qwen_image/model.py @@ -3,6 +3,7 @@ import torch import torch.nn as nn import torch.nn.functional as F import math +import logging from typing import Optional, Tuple from einops import repeat, rearrange @@ -12,6 +13,8 @@ from comfy.ldm.flux.layers import EmbedND import comfy.ldm.common_dit import comfy.patcher_extension +logger = logging.getLogger(__name__) + class QwenEmbedRope(nn.Module): """RoPE implementation for EliGen. @@ -269,9 +272,6 @@ class Attention(nn.Module): txt_query = self.norm_added_q(txt_query) txt_key = self.norm_added_k(txt_key) - ### NEW - ################################################# - # Handle both tuple (EliGen) and single tensor (standard) RoPE formats if isinstance(image_rotary_emb, tuple): # EliGen path: Apply RoPE BEFORE concatenation (research-accurate) @@ -303,6 +303,7 @@ class Attention(nn.Module): joint_query = apply_rotary_emb(joint_query, image_rotary_emb) joint_key = apply_rotary_emb(joint_key, image_rotary_emb) + # Apply EliGen attention mask if present effective_mask = attention_mask if transformer_options is not None: eligen_mask = transformer_options.get("eligen_attention_mask", None) @@ -312,11 +313,12 @@ class Attention(nn.Module): # Validate shape expected_seq = joint_query.shape[1] if eligen_mask.shape[-1] != expected_seq: - raise ValueError(f"EliGen mask shape {eligen_mask.shape} doesn't match sequence length {expected_seq}") - - ################################################# + raise ValueError( + f"EliGen attention mask shape mismatch: {eligen_mask.shape} " + f"doesn't match sequence length {expected_seq}" + ) - # Standard path: Use ComfyUI's optimized attention + # Use ComfyUI's optimized attention joint_query = joint_query.flatten(start_dim=2) joint_key = joint_key.flatten(start_dim=2) joint_value = joint_value.flatten(start_dim=2) @@ -443,8 +445,12 @@ class LastLayer(nn.Module): x = torch.addcmul(shift[:, None, :], self.norm(x), (1 + scale)[:, None, :]) return x -### NEW changes class QwenImageTransformer2DModel(nn.Module): + # Constants for EliGen processing + LATENT_TO_PIXEL_RATIO = 8 # Latents are 8x downsampled from pixel space + PATCH_TO_LATENT_RATIO = 2 # 2x2 patches in latent space + PATCH_TO_PIXEL_RATIO = 16 # Combined: 2x2 patches on 8x downsampled latents = 16x in pixel space + def __init__( self, patch_size: int = 2, @@ -540,8 +546,8 @@ class QwenImageTransformer2DModel(nn.Module): entity_prompt_emb: List[[1, L_i, 3584]] - Entity prompts entity_prompt_emb_mask: List[[1, L_i]] entity_masks: [1, N, 1, H/8, W/8] - height: int - width: int + height: int (padded pixel height) + width: int (padded pixel width) image: [B, patches, 64] - Patchified latents Returns: @@ -549,6 +555,17 @@ class QwenImageTransformer2DModel(nn.Module): image_rotary_emb: RoPE embeddings attention_mask: [1, 1, total_seq, total_seq] """ + num_entities = len(entity_prompt_emb) + batch_size = latents.shape[0] + logger.debug( + f"[EliGen Model] Processing {num_entities} entities for {height}x{width}px image " + f"(latents: {latents.shape}, batch_size: {batch_size})" + ) + + # Validate batch consistency (all batches should have same sequence lengths) + # This is a ComfyUI requirement - batched prompts must have uniform padding + if batch_size > 1: + logger.debug(f"[EliGen Model] Batch size > 1 detected ({batch_size} batches), ensuring RoPE compatibility") # SECTION 1: Concatenate entity + global prompts all_prompt_emb = entity_prompt_emb + [prompt_emb] @@ -556,45 +573,63 @@ class QwenImageTransformer2DModel(nn.Module): all_prompt_emb = torch.cat(all_prompt_emb, dim=1) # SECTION 2: Build RoPE position embeddings - # Calculate img_shapes for RoPE (batch, height//16, width//16 for images in latent space after patchifying) - img_shapes = [(latents.shape[0], height//16, width//16)] + # For EliGen, we create RoPE for ONE batch element's dimensions + # The queries/keys have shape [batch, seq, heads, dim], and RoPE broadcasts across batch dim + patch_h = height // self.PATCH_TO_PIXEL_RATIO + patch_w = width // self.PATCH_TO_PIXEL_RATIO + + # Create RoPE for a single image (frame=1 for images, not video) + # This will broadcast across all batch elements automatically + img_shapes_single = [(1, patch_h, patch_w)] # Calculate sequence lengths for entities and global prompt - entity_seq_lens = [int(mask.sum(dim=1).item()) for mask in entity_prompt_emb_mask] + # Use [0] to get first batch element (all batches should have same sequence lengths) + entity_seq_lens = [int(mask.sum(dim=1)[0].item()) for mask in entity_prompt_emb_mask] # Handle None case in ComfyUI (None means no padding, all tokens valid) if prompt_emb_mask is not None: - global_seq_len = int(prompt_emb_mask.sum(dim=1).item()) + global_seq_len = int(prompt_emb_mask.sum(dim=1)[0].item()) else: # No mask = no padding, use full sequence length global_seq_len = int(prompt_emb.shape[1]) # Get base image RoPE using global prompt length (returns tuple: (img_freqs, txt_freqs)) + # We pass a single shape, not repeated for batch, because RoPE will broadcast txt_seq_lens = [global_seq_len] - image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens, device=latents.device) + image_rotary_emb = self.pos_embed(img_shapes_single, txt_seq_lens, device=latents.device) # Create SEPARATE RoPE embeddings for each entity # Each entity gets its own positional encoding based on its sequence length - entity_rotary_emb = [self.pos_embed(img_shapes, [entity_seq_len], device=latents.device)[1] + # We only need to create these once since they're the same for all batch elements + entity_rotary_emb = [self.pos_embed([(1, patch_h, patch_w)], [entity_seq_len], device=latents.device)[1] for entity_seq_len in entity_seq_lens] # Concatenate entity RoPEs with global RoPE along sequence dimension # Result: [entity1_seq, entity2_seq, ..., global_seq] concatenated + # This creates the RoPE for ONE batch element's sequence + # Note: We DON'T repeat for batch_size because the queries/keys have shape [batch, seq, ...] + # and PyTorch will broadcast the RoPE [seq, ...] across the batch dimension automatically txt_rotary_emb = torch.cat(entity_rotary_emb + [image_rotary_emb[1]], dim=0) + logger.debug( + f"[EliGen Model] RoPE created for single batch element - " + f"img: {image_rotary_emb[0].shape}, txt: {txt_rotary_emb.shape} " + f"(both will broadcast across batch_size={batch_size})" + ) + # Replace text part of tuple with concatenated entity + global RoPE image_rotary_emb = (image_rotary_emb[0], txt_rotary_emb) # SECTION 3: Prepare spatial masks - repeat_dim = latents.shape[1] # 16 + repeat_dim = latents.shape[1] # 16 (latent channels) max_masks = entity_masks.shape[1] # N entities entity_masks = entity_masks.repeat(1, 1, repeat_dim, 1, 1) # Pad masks to match padded latent dimensions # entity_masks shape: [1, N, 16, H/8, W/8] # Need to pad to match orig_shape which is [B, 16, padded_H/8, padded_W/8] - padded_h = height // 8 - padded_w = width // 8 + padded_h = height // self.LATENT_TO_PIXEL_RATIO + padded_w = width // self.LATENT_TO_PIXEL_RATIO if entity_masks.shape[3] != padded_h or entity_masks.shape[4] != padded_w: assert entity_masks.shape[3] <= padded_h and entity_masks.shape[4] <= padded_w, \ f"Entity masks {entity_masks.shape[3]}x{entity_masks.shape[4]} larger than padded dims {padded_h}x{padded_w}" @@ -602,6 +637,7 @@ class QwenImageTransformer2DModel(nn.Module): # Pad each entity mask pad_h = padded_h - entity_masks.shape[3] pad_w = padded_w - entity_masks.shape[4] + logger.debug(f"[EliGen Model] Padding entity masks by ({pad_h}, {pad_w}) to match latent dimensions") entity_masks = torch.nn.functional.pad(entity_masks, (0, pad_w, 0, pad_h), mode='constant', value=0) entity_masks = [entity_masks[:, i, None].squeeze(1) for i in range(max_masks)] @@ -617,12 +653,20 @@ class QwenImageTransformer2DModel(nn.Module): seq_lens = entity_seq_lens + [global_seq_len] total_seq_len = int(sum(seq_lens) + image.shape[1]) + logger.debug( + f"[EliGen Model] Building attention mask: " + f"total_seq={total_seq_len} (entities: {entity_seq_lens}, global: {global_seq_len}, image: {image.shape[1]})" + ) + patched_masks = [] for i in range(N): patched_mask = rearrange( entity_masks[i], "B C (H P) (W Q) -> B (H W) (C P Q)", - H=height//16, W=width//16, P=2, Q=2 + H=height // self.PATCH_TO_PIXEL_RATIO, + W=width // self.PATCH_TO_PIXEL_RATIO, + P=self.PATCH_TO_LATENT_RATIO, + Q=self.PATCH_TO_LATENT_RATIO ) patched_masks.append(patched_mask) @@ -671,10 +715,16 @@ class QwenImageTransformer2DModel(nn.Module): # SECTION 6: Convert to additive bias attention_mask = attention_mask.float() + num_valid_connections = (attention_mask == 1).sum().item() attention_mask[attention_mask == 0] = float('-inf') attention_mask[attention_mask == 1] = 0 attention_mask = attention_mask.to(device=latents.device, dtype=latents.dtype).unsqueeze(1) + logger.debug( + f"[EliGen Model] Attention mask created: shape={attention_mask.shape}, " + f"valid_connections={num_valid_connections}/{total_seq_len * total_seq_len}" + ) + return all_prompt_emb, image_rotary_emb, attention_mask def forward(self, x, timestep, context, attention_mask=None, guidance=None, ref_latents=None, transformer_options={}, **kwargs): diff --git a/comfy_extras/nodes_qwen.py b/comfy_extras/nodes_qwen.py index d90707a49..f59c84d54 100644 --- a/comfy_extras/nodes_qwen.py +++ b/comfy_extras/nodes_qwen.py @@ -3,9 +3,13 @@ import comfy.utils import comfy.conds import math import torch +import logging +from typing import Optional from typing_extensions import override from comfy_api.latest import ComfyExtension, io +logger = logging.getLogger(__name__) + class TextEncodeQwenImageEdit(io.ComfyNode): @classmethod @@ -105,8 +109,31 @@ class TextEncodeQwenImageEditPlus(io.ComfyNode): conditioning = node_helpers.conditioning_set_values(conditioning, {"reference_latents": ref_latents}, append=True) return io.NodeOutput(conditioning) -################ NEW class TextEncodeQwenImageEliGen(io.ComfyNode): + """ + Entity-Level Image Generation (EliGen) conditioning node for Qwen Image model. + + Allows specifying different prompts for different spatial regions using masks. + Each entity (mask + prompt pair) will only influence its masked region through + spatial attention masking. + + Features: + - Supports up to 3 entities per generation + - Spatial attention masks prevent cross-entity contamination + - Separate RoPE embeddings per entity (research-accurate) + - Falls back to standard generation if no entities provided + + Usage: + 1. Create spatial masks using LoadImageMask (white=entity, black=background) + 2. Use 'red', 'green', or 'blue' channel (NOT 'alpha' - it gets inverted) + 3. Provide entity-specific prompts for each masked region + + Based on DiffSynth Studio: https://github.com/modelscope/DiffSynth-Studio + """ + + # Qwen Image model uses 2x2 patches on latents (which are 8x downsampled from pixels) + PATCH_SIZE = 2 + @classmethod def define_schema(cls): return io.Schema( @@ -129,8 +156,18 @@ class TextEncodeQwenImageEliGen(io.ComfyNode): ) @classmethod - def execute(cls, clip, global_conditioning, latent, entity_prompt_1="", entity_mask_1=None, - entity_prompt_2="", entity_mask_2=None, entity_prompt_3="", entity_mask_3=None) -> io.NodeOutput: + def execute( + cls, + clip, + global_conditioning, + latent, + entity_prompt_1: str = "", + entity_mask_1: Optional[torch.Tensor] = None, + entity_prompt_2: str = "", + entity_mask_2: Optional[torch.Tensor] = None, + entity_prompt_3: str = "", + entity_mask_3: Optional[torch.Tensor] = None + ) -> io.NodeOutput: # Extract dimensions from latent tensor # latent["samples"] shape: [batch, channels, latent_h, latent_w] @@ -139,10 +176,9 @@ class TextEncodeQwenImageEliGen(io.ComfyNode): unpadded_latent_width = latent_samples.shape[3] # Unpadded latent space # Calculate padded dimensions (same logic as model's pad_to_patch_size with patch_size=2) - # The model pads latents to be multiples of patch_size (2 for Qwen) - patch_size = 2 - pad_h = (patch_size - unpadded_latent_height % patch_size) % patch_size - pad_w = (patch_size - unpadded_latent_width % patch_size) % patch_size + # The model pads latents to be multiples of PATCH_SIZE + pad_h = (cls.PATCH_SIZE - unpadded_latent_height % cls.PATCH_SIZE) % cls.PATCH_SIZE + pad_w = (cls.PATCH_SIZE - unpadded_latent_width % cls.PATCH_SIZE) % cls.PATCH_SIZE latent_height = unpadded_latent_height + pad_h # Padded latent dimensions latent_width = unpadded_latent_width + pad_w # Padded latent dimensions @@ -150,8 +186,8 @@ class TextEncodeQwenImageEliGen(io.ComfyNode): width = latent_width * 8 if pad_h > 0 or pad_w > 0: - print(f"[EliGen] Latent padding detected: {unpadded_latent_height}x{unpadded_latent_width} → {latent_height}x{latent_width}") - print(f"[EliGen] Target generation dimensions: {height}x{width} pixels ({latent_height}x{latent_width} latent)") + logger.debug(f"[EliGen] Latent padding detected: {unpadded_latent_height}x{unpadded_latent_width} → {latent_height}x{latent_width}") + logger.debug(f"[EliGen] Target generation dimensions: {height}x{width} pixels ({latent_height}x{latent_width} latent)") # Collect entity prompts and masks entity_prompts = [entity_prompt_1, entity_prompt_2, entity_prompt_3] @@ -166,7 +202,7 @@ class TextEncodeQwenImageEliGen(io.ComfyNode): # Log warning if some entities were skipped total_prompts_provided = len([p for p in entity_prompts if p.strip()]) if len(valid_entities) < total_prompts_provided: - print(f"[EliGen] Warning: Only {len(valid_entities)} of {total_prompts_provided} entity prompts have valid masks") + logger.warning(f"[EliGen] Only {len(valid_entities)} of {total_prompts_provided} entity prompts have valid masks") # If no valid entities, return standard conditioning if len(valid_entities) == 0: @@ -200,7 +236,37 @@ class TextEncodeQwenImageEliGen(io.ComfyNode): # This is different from IMAGE type which is [batch, height, width, channels] mask_tensor = mask - # Log original mask dimensions + # Validate mask dtype + if mask_tensor.dtype not in [torch.float32, torch.float16, torch.bfloat16]: + raise TypeError( + f"Entity {i+1} mask has invalid dtype {mask_tensor.dtype}. " + f"Expected float32, float16, or bfloat16. " + f"Ensure you're using LoadImageMask node, not LoadImage." + ) + + # Log original mask statistics + logger.debug( + f"[EliGen] Entity {i+1} input mask: shape={mask_tensor.shape}, " + f"dtype={mask_tensor.dtype}, min={mask_tensor.min():.4f}, max={mask_tensor.max():.4f}" + ) + + # Check for all-zero masks (common error when wrong channel selected) + if mask_tensor.max() == 0.0: + raise ValueError( + f"Entity {i+1} mask is all zeros! This usually means:\n" + f" 1. Wrong channel selected in LoadImageMask (use 'red', 'green', or 'blue', NOT 'alpha')\n" + f" 2. Your mask image is completely black\n" + f" 3. The mask file failed to load" + ) + + # Check for constant masks (no variation) + if mask_tensor.min() == mask_tensor.max() and mask_tensor.max() > 0: + logger.warning( + f"[EliGen] Entity {i+1} mask has no variation (all pixels = {mask_tensor.min():.4f}). " + f"This entity will affect the entire image." + ) + + # Extract original dimensions original_shape = mask_tensor.shape if len(original_shape) == 2: # [height, width] - single mask without batch @@ -211,7 +277,20 @@ class TextEncodeQwenImageEliGen(io.ComfyNode): # [batch, height, width] - standard MASK format orig_h, orig_w = original_shape[1], original_shape[2] else: - raise ValueError(f"Unexpected mask shape: {original_shape}. Expected [H, W] or [B, H, W]") + raise ValueError( + f"Entity {i+1} has unexpected mask shape: {original_shape}. " + f"Expected [H, W] or [B, H, W]. Got {len(original_shape)} dimensions." + ) + + # Log size mismatch if mask doesn't match expected latent dimensions + expected_h, expected_w = latent_height * 8, latent_width * 8 + if orig_h != expected_h or orig_w != expected_w: + logger.info( + f"[EliGen] Entity {i+1} mask size mismatch: {orig_h}x{orig_w} vs expected {expected_h}x{expected_w}. " + f"Will resize to {latent_height}x{latent_width} latent space." + ) + else: + logger.debug(f"[EliGen] Entity {i+1} mask: {orig_h}x{orig_w} → will resize to {latent_height}x{latent_width} latent") # Convert MASK format [batch, height, width] to [batch, 1, height, width] for common_upscale # common_upscale expects [batch, channels, height, width] @@ -233,17 +312,32 @@ class TextEncodeQwenImageEliGen(io.ComfyNode): # Log how many pixels are active in the mask active_pixels = (resized_mask > 0).sum().item() total_pixels = resized_mask.numel() + coverage_pct = 100 * active_pixels / total_pixels if total_pixels > 0 else 0 + + if active_pixels == 0: + raise ValueError( + f"Entity {i+1} mask has no active pixels after resizing to latent space! " + f"Original mask may have been too small or all black." + ) + + logger.debug( + f"[EliGen] Entity {i+1} mask coverage: {active_pixels}/{total_pixels} pixels ({coverage_pct:.1f}%)" + ) processed_entity_masks.append(resized_mask) # Stack masks: [batch, num_entities, 1, latent_height, latent_width] # Each item in processed_entity_masks has shape [1, 1, H, W] (batch=1, channel=1) # We need to remove batch dim, stack, then add it back - # Option 1: Squeeze batch dim from each mask - processed_no_batch = [m.squeeze(0) for m in processed_entity_masks] # Each: [1, H, W] - entity_masks_tensor = torch.stack(processed_no_batch, dim=0) # [num_entities, 1, H, W] + processed_entity_masks_no_batch = [m.squeeze(0) for m in processed_entity_masks] # Each: [1, H, W] + entity_masks_tensor = torch.stack(processed_entity_masks_no_batch, dim=0) # [num_entities, 1, H, W] entity_masks_tensor = entity_masks_tensor.unsqueeze(0) # [1, num_entities, 1, H, W] + logger.debug( + f"[EliGen] Stacked {len(valid_entities)} entity masks into tensor: " + f"shape={entity_masks_tensor.shape} (expected: [1, {len(valid_entities)}, 1, {latent_height}, {latent_width}])" + ) + # Extract global prompt embedding and mask from conditioning # Conditioning format: [[cond_tensor, extra_dict]] global_prompt_emb = global_conditioning[0][0] # The embedding tensor directly