diff --git a/comfy/ldm/qwen_image/model.py b/comfy/ldm/qwen_image/model.py
index 8c75670cd..23ccb87d3 100644
--- a/comfy/ldm/qwen_image/model.py
+++ b/comfy/ldm/qwen_image/model.py
@@ -2,8 +2,9 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+import logging
 from typing import Optional, Tuple
-from einops import repeat
+from einops import repeat, rearrange

 from comfy.ldm.lightricks.model import TimestepEmbedding, Timesteps
 from comfy.ldm.modules.attention import optimized_attention_masked
@@ -54,7 +55,6 @@ class FeedForward(nn.Module):
 def apply_rotary_emb(x, freqs_cis):
     if x.shape[1] == 0:
         return x
-
     t_ = x.reshape(*x.shape[:-1], -1, 1, 2)
     t_out = freqs_cis[..., 0] * t_[..., 0] + freqs_cis[..., 1] * t_[..., 1]
     return t_out.reshape(*x.shape)
@@ -229,6 +229,7 @@ class QwenImageTransformerBlock(nn.Module):
         encoder_hidden_states_mask: torch.Tensor,
         temb: torch.Tensor,
         image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        attention_mask: Optional[torch.Tensor] = None,
         transformer_options={},
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         img_mod_params = self.img_mod(temb)
@@ -245,6 +246,7 @@ class QwenImageTransformerBlock(nn.Module):
             hidden_states=img_modulated,
             encoder_hidden_states=txt_modulated,
             encoder_hidden_states_mask=encoder_hidden_states_mask,
+            attention_mask=attention_mask,
             image_rotary_emb=image_rotary_emb,
             transformer_options=transformer_options,
         )
@@ -288,8 +290,12 @@ class LastLayer(nn.Module):
         x = torch.addcmul(shift[:, None, :], self.norm(x), (1 + scale)[:, None, :])
         return x

-
 class QwenImageTransformer2DModel(nn.Module):
+    # Constants for EliGen processing
+    LATENT_TO_PIXEL_RATIO = 8    # Latents are 8x downsampled from pixel space
+    PATCH_TO_LATENT_RATIO = 2    # 2x2 patches in latent space
+    PATCH_TO_PIXEL_RATIO = 16    # Combined: 2x2 patches on 8x downsampled latents = 16x in pixel space
+
     def __init__(
         self,
         patch_size: int = 2,
@@ -316,7 +322,6 @@ class QwenImageTransformer2DModel(nn.Module):
         self.inner_dim = num_attention_heads * attention_head_dim

         self.pe_embedder = EmbedND(dim=attention_head_dim, theta=10000, axes_dim=list(axes_dims_rope))
-
         self.time_text_embed = QwenTimestepProjEmbeddings(
             embedding_dim=self.inner_dim,
             pooled_projection_dim=pooled_projection_dim,
@@ -365,6 +370,214 @@ class QwenImageTransformer2DModel(nn.Module):
         img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0) - (w_len // 2)
         return hidden_states, repeat(img_ids, "h w c -> b (h w) c", b=bs), orig_shape

+    def process_entity_masks(self, latents, prompt_emb, prompt_emb_mask, entity_prompt_emb,
+                             entity_prompt_emb_mask, entity_masks, height, width, image,
+                             cond_or_uncond=None, batch_size=None):
+        """
+        Process entity masks and build spatial attention mask for EliGen.
+
+        Concatenates entity+global prompts, builds RoPE embeddings, creates attention mask
+        enforcing spatial restrictions, and handles CFG batching with separate masks.
+
+        Based on: https://github.com/modelscope/DiffSynth-Studio
+        """
+        num_entities = len(entity_prompt_emb)
+        actual_batch_size = latents.shape[0]
+
+        has_positive = cond_or_uncond and 0 in cond_or_uncond
+        has_negative = cond_or_uncond and 1 in cond_or_uncond
+        is_cfg_batched = has_positive and has_negative
+
+        logging.debug(
+            f"[EliGen Model] Processing {num_entities} entities for {height}x{width}px, "
+            f"batch_size={actual_batch_size}, CFG_batched={is_cfg_batched}"
+        )
+
+        # Concatenate entity + global prompts
+        all_prompt_emb = entity_prompt_emb + [prompt_emb]
+        all_prompt_emb = [self.txt_in(self.txt_norm(local_prompt_emb)) for local_prompt_emb in all_prompt_emb]
+        all_prompt_emb = torch.cat(all_prompt_emb, dim=1)
+
+        # Build RoPE embeddings
+        patch_h = height // self.PATCH_TO_PIXEL_RATIO
+        patch_w = width // self.PATCH_TO_PIXEL_RATIO
+
+        entity_seq_lens = [int(mask.sum(dim=1)[0].item()) for mask in entity_prompt_emb_mask]
+
+        if prompt_emb_mask is not None:
+            global_seq_len = int(prompt_emb_mask.sum(dim=1)[0].item())
+        else:
+            global_seq_len = int(prompt_emb.shape[1])
+
+        max_vid_index = max(patch_h // 2, patch_w // 2)
+
+        # Generate per-entity text RoPE (each entity starts from the same offset)
+        entity_txt_embs = []
+        for entity_seq_len in entity_seq_lens:
+            entity_ids = torch.arange(
+                max_vid_index,
+                max_vid_index + entity_seq_len,
+                device=latents.device
+            ).reshape(1, -1, 1).repeat(1, 1, 3)
+
+            entity_rope = self.pe_embedder(entity_ids)  # Keep shape [1, 1, seq, dim, 2, 2]
+            entity_txt_embs.append(entity_rope)
+
+        # Generate global text RoPE
+        global_ids = torch.arange(
+            max_vid_index,
+            max_vid_index + global_seq_len,
+            device=latents.device
+        ).reshape(1, -1, 1).repeat(1, 1, 3)
+        global_rope = self.pe_embedder(global_ids)  # Keep shape [1, 1, seq, dim, 2, 2]
+
+        txt_rotary_emb = torch.cat(entity_txt_embs + [global_rope], dim=2)  # Concatenate on the sequence dimension
+
+        h_coords = torch.arange(-(patch_h - patch_h // 2), patch_h // 2, device=latents.device)
+        w_coords = torch.arange(-(patch_w - patch_w // 2), patch_w // 2, device=latents.device)
+
+        img_ids = torch.zeros((patch_h, patch_w, 3), device=latents.device)
+        img_ids[:, :, 0] = 0
+        img_ids[:, :, 1] = h_coords.unsqueeze(1)
+        img_ids[:, :, 2] = w_coords.unsqueeze(0)
+        img_ids = img_ids.reshape(1, -1, 3)
+
+        img_rope = self.pe_embedder(img_ids)  # Keep shape [1, 1, seq, dim, 2, 2]
+
+        logging.debug(f"[EliGen Model] RoPE shapes - img: {img_rope.shape}, txt: {txt_rotary_emb.shape}")
+
+        # Concatenate text and image RoPE embeddings on the sequence dimension
+        # Shape will be [1, 1, total_seq, dim, 2, 2] where total_seq = txt_seq + img_seq
+        image_rotary_emb = torch.cat([txt_rotary_emb, img_rope], dim=2).to(dtype=latents.dtype)
+
+        # Prepare spatial masks
+        repeat_dim = latents.shape[1]
+        max_masks = entity_masks.shape[1]
+        entity_masks = entity_masks.repeat(1, 1, repeat_dim, 1, 1)
+
+        padded_h = height // self.LATENT_TO_PIXEL_RATIO
+        padded_w = width // self.LATENT_TO_PIXEL_RATIO
+        if entity_masks.shape[3] != padded_h or entity_masks.shape[4] != padded_w:
+            pad_h = padded_h - entity_masks.shape[3]
+            pad_w = padded_w - entity_masks.shape[4]
+            logging.debug(f"[EliGen Model] Padding masks by ({pad_h}, {pad_w})")
+            entity_masks = torch.nn.functional.pad(entity_masks, (0, pad_w, 0, pad_h), mode='constant', value=0)
+
+        entity_masks = [entity_masks[:, i, None].squeeze(1) for i in range(max_masks)]
+
+        global_mask = torch.ones((entity_masks[0].shape[0], entity_masks[0].shape[1], padded_h, padded_w),
+                                 device=latents.device,
+                                 dtype=latents.dtype)
+        entity_masks = entity_masks + [global_mask]
+
+        # Patchify masks
+        N = len(entity_masks)
+        batch_size = int(entity_masks[0].shape[0])
+        seq_lens = entity_seq_lens + [global_seq_len]
+        total_seq_len = int(sum(seq_lens) + image.shape[1])
+
+        logging.debug(f"[EliGen Model] total_seq={total_seq_len}")
+
+        patched_masks = []
+        for i in range(N):
+            patched_mask = rearrange(
+                entity_masks[i],
+                "B C (H P) (W Q) -> B (H W) (C P Q)",
+                H=height // self.PATCH_TO_PIXEL_RATIO,
+                W=width // self.PATCH_TO_PIXEL_RATIO,
+                P=self.PATCH_TO_LATENT_RATIO,
+                Q=self.PATCH_TO_LATENT_RATIO
+            )
+            patched_masks.append(patched_mask)
+
+        # Build attention mask matrix
+        attention_mask = torch.ones(
+            (batch_size, total_seq_len, total_seq_len),
+            dtype=torch.bool
+        ).to(device=entity_masks[0].device)
+
+        # Calculate positions
+        image_start = int(sum(seq_lens))
+        image_end = int(total_seq_len)
+        cumsum = [0]
+        single_image_seq = int(image_end - image_start)
+
+        for length in seq_lens:
+            cumsum.append(cumsum[-1] + length)
+
+        # Spatial restriction (prompt <-> image)
+        for i in range(N):
+            prompt_start = cumsum[i]
+            prompt_end = cumsum[i+1]
+
+            # Create binary mask for which image patches this entity can attend to
+            image_mask = torch.sum(patched_masks[i], dim=-1) > 0
+            image_mask = image_mask.unsqueeze(1).repeat(1, seq_lens[i], 1)
+
+            # Repeat the mask to match the image sequence length
+            repeat_time = single_image_seq // image_mask.shape[-1]
+            image_mask = image_mask.repeat(1, 1, repeat_time)
+
+            # Bidirectional restriction:
+            # - Entity prompt can only attend to its masked image regions
+            attention_mask[:, prompt_start:prompt_end, image_start:image_end] = image_mask
+            # - Image patches can only be updated by prompts that own them
+            attention_mask[:, image_start:image_end, prompt_start:prompt_end] = image_mask.transpose(1, 2)
+
+        # Entity isolation
+        for i in range(N):
+            for j in range(N):
+                if i == j:
+                    continue
+                start_i, end_i = cumsum[i], cumsum[i+1]
+                start_j, end_j = cumsum[j], cumsum[j+1]
+                attention_mask[:, start_i:end_i, start_j:end_j] = False
+
+        # Convert to additive bias and handle CFG batching
+        attention_mask = attention_mask.float()
+        num_valid_connections = (attention_mask == 1).sum().item()
+        attention_mask[attention_mask == 0] = float('-inf')
+        attention_mask[attention_mask == 1] = 0
+        attention_mask = attention_mask.to(device=latents.device, dtype=latents.dtype)
+
+        # Handle CFG batching: create separate masks for positive and negative
+        if is_cfg_batched and actual_batch_size > 1:
+            # CFG batch: [positive, negative] - need different masks for each
+            # Positive gets entity constraints, negative gets standard attention (all zeros)
+
+            logging.debug(
+                "[EliGen Model] CFG batching detected - creating separate masks. "
+                "Positive (index 0) gets entity mask, Negative (index 1) gets standard mask"
+            )
+
+            # Create standard attention mask (all zeros = no constraints)
+            standard_mask = torch.zeros_like(attention_mask)
+
+            # Stack masks according to cond_or_uncond order
+            mask_list = []
+            for cond_type in cond_or_uncond:
+                if cond_type == 0:  # Positive - use entity mask
+                    mask_list.append(attention_mask[0:1])  # Take first (and only) entity mask
+                else:  # Negative - use standard mask
+                    mask_list.append(standard_mask[0:1])
+
+            # Concatenate masks to match batch
+            attention_mask = torch.cat(mask_list, dim=0)
+
+            logging.debug(
+                f"[EliGen Model] Created {len(mask_list)} masks for CFG batch. "
" + f"Final shape: {attention_mask.shape}" + ) + + # Add head dimension: [B, 1, seq, seq] + attention_mask = attention_mask.unsqueeze(1) + + logging.debug( + f"[EliGen Model] Attention mask created: shape={attention_mask.shape}, " + f"valid_connections={num_valid_connections}/{total_seq_len * total_seq_len}" + ) + + return all_prompt_emb, image_rotary_emb, attention_mask + def forward(self, x, timestep, context, attention_mask=None, guidance=None, ref_latents=None, transformer_options={}, **kwargs): return comfy.patcher_extension.WrapperExecutor.new_class_executor( self._forward, @@ -416,15 +629,60 @@ class QwenImageTransformer2DModel(nn.Module): hidden_states = torch.cat([hidden_states, kontext], dim=1) img_ids = torch.cat([img_ids, kontext_ids], dim=1) - txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size) // 2, ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size) // 2)) - txt_ids = torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 3) - ids = torch.cat((txt_ids, img_ids), dim=1) - image_rotary_emb = self.pe_embedder(ids).to(x.dtype).contiguous() - del ids, txt_ids, img_ids + # Initialize attention mask (None for standard generation) + eligen_attention_mask = None - hidden_states = self.img_in(hidden_states) - encoder_hidden_states = self.txt_norm(encoder_hidden_states) - encoder_hidden_states = self.txt_in(encoder_hidden_states) + # Extract EliGen entity data + entity_prompt_emb = kwargs.get("entity_prompt_emb", None) + entity_prompt_emb_mask = kwargs.get("entity_prompt_emb_mask", None) + entity_masks = kwargs.get("entity_masks", None) + + # Detect batch composition for CFG handling + cond_or_uncond = transformer_options.get("cond_or_uncond", []) if transformer_options else [] + is_positive_cond = 0 in cond_or_uncond + is_negative_cond = 1 in cond_or_uncond + batch_size = x.shape[0] + + if entity_prompt_emb is not None: + logging.debug( + f"[EliGen Forward] batch_size={batch_size}, cond_or_uncond={cond_or_uncond}, " + f"has_positive={is_positive_cond}, has_negative={is_negative_cond}" + ) + + if entity_prompt_emb is not None and entity_masks is not None and entity_prompt_emb_mask is not None and is_positive_cond: + # EliGen path + height = int(orig_shape[-2] * self.LATENT_TO_PIXEL_RATIO) + width = int(orig_shape[-1] * self.LATENT_TO_PIXEL_RATIO) + + encoder_hidden_states, image_rotary_emb, eligen_attention_mask = self.process_entity_masks( + latents=x, + prompt_emb=encoder_hidden_states, + prompt_emb_mask=encoder_hidden_states_mask, + entity_prompt_emb=entity_prompt_emb, + entity_prompt_emb_mask=entity_prompt_emb_mask, + entity_masks=entity_masks, + height=height, + width=width, + image=hidden_states, + cond_or_uncond=cond_or_uncond, + batch_size=batch_size + ) + + hidden_states = self.img_in(hidden_states) + + del img_ids + + else: + # Standard path + txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size) // 2, ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size) // 2)) + txt_ids = torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 3) + ids = torch.cat((txt_ids, img_ids), dim=1) + image_rotary_emb = self.pe_embedder(ids).to(x.dtype).contiguous() + del ids, txt_ids, img_ids + + hidden_states = self.img_in(hidden_states) + encoder_hidden_states = self.txt_norm(encoder_hidden_states) + encoder_hidden_states = self.txt_in(encoder_hidden_states) if guidance is not None: guidance = guidance * 1000 @@ 
diff --git a/comfy/model_base.py b/comfy/model_base.py
index 6b8a8454d..8cbd6de3d 100644
--- a/comfy/model_base.py
+++ b/comfy/model_base.py
@@ -1484,6 +1484,19 @@ class QwenImage(BaseModel):
         ref_latents_method = kwargs.get("reference_latents_method", None)
         if ref_latents_method is not None:
             out['ref_latents_method'] = comfy.conds.CONDConstant(ref_latents_method)
+
+        # Handle EliGen entity data
+        entity_prompt_emb = kwargs.get("entity_prompt_emb", None)
+        if entity_prompt_emb is not None:
+            out['entity_prompt_emb'] = comfy.conds.CONDList(entity_prompt_emb)
+
+        entity_prompt_emb_mask = kwargs.get("entity_prompt_emb_mask", None)
+        if entity_prompt_emb_mask is not None:
+            out['entity_prompt_emb_mask'] = comfy.conds.CONDList(entity_prompt_emb_mask)
+
+        entity_masks = kwargs.get("entity_masks", None)
+        if entity_masks is not None:
+            out['entity_masks'] = comfy.conds.CONDRegular(entity_masks)
         return out

     def extra_conds_shapes(self, **kwargs):
diff --git a/comfy_extras/nodes_qwen.py b/comfy_extras/nodes_qwen.py
index 525239ae5..9ad258add 100644
--- a/comfy_extras/nodes_qwen.py
+++ b/comfy_extras/nodes_qwen.py
@@ -1,6 +1,10 @@
 import node_helpers
 import comfy.utils
+import comfy.conds
 import math
+import torch
+import logging
+from typing import Optional

 from typing_extensions import override
 from comfy_api.latest import ComfyExtension, io
@@ -103,6 +107,281 @@ class TextEncodeQwenImageEditPlus(io.ComfyNode):
         conditioning = node_helpers.conditioning_set_values(conditioning, {"reference_latents": ref_latents}, append=True)
         return io.NodeOutput(conditioning)

+class TextEncodeQwenImageEliGen(io.ComfyNode):
+    """
+    Entity-Level Image Generation (EliGen) conditioning node for the Qwen Image model.
+
+    Allows specifying different prompts for different spatial regions using masks.
+    Each entity (mask + prompt pair) will only influence its masked region through
+    spatial attention masking.
+
+    Features:
+    - Supports up to 8 entities per generation
+    - Spatial attention masks prevent cross-entity contamination
+    - Separate RoPE embeddings per entity (matches the reference implementation)
+    - Falls back to standard generation if no entities are provided
+
+    Usage:
+    1. Create spatial masks using LoadImageMask (white=entity, black=background)
+    2. Use the 'red', 'green', or 'blue' channel (NOT 'alpha' - it gets inverted)
+    3. Provide entity-specific prompts for each masked region
+
+    Based on DiffSynth Studio: https://github.com/modelscope/DiffSynth-Studio
+    """
+
+    # Qwen Image model uses 2x2 patches on latents (which are 8x downsampled from pixels)
+    PATCH_SIZE = 2
+
+    @classmethod
+    def define_schema(cls):
+        return io.Schema(
+            node_id="TextEncodeQwenImageEliGen",
+            category="advanced/conditioning",
+            inputs=[
+                io.Clip.Input("clip"),
+                io.Conditioning.Input("global_conditioning"),
+                io.Latent.Input("latent"),
+                io.Mask.Input("entity_mask_1", optional=True),
+                io.String.Input("entity_prompt_1", multiline=True, dynamic_prompts=True, default=""),
+                io.Mask.Input("entity_mask_2", optional=True),
+                io.String.Input("entity_prompt_2", multiline=True, dynamic_prompts=True, default=""),
+                io.Mask.Input("entity_mask_3", optional=True),
+                io.String.Input("entity_prompt_3", multiline=True, dynamic_prompts=True, default=""),
+                io.Mask.Input("entity_mask_4", optional=True),
+                io.String.Input("entity_prompt_4", multiline=True, dynamic_prompts=True, default=""),
+                io.Mask.Input("entity_mask_5", optional=True),
+                io.String.Input("entity_prompt_5", multiline=True, dynamic_prompts=True, default=""),
+                io.Mask.Input("entity_mask_6", optional=True),
+                io.String.Input("entity_prompt_6", multiline=True, dynamic_prompts=True, default=""),
+                io.Mask.Input("entity_mask_7", optional=True),
+                io.String.Input("entity_prompt_7", multiline=True, dynamic_prompts=True, default=""),
+                io.Mask.Input("entity_mask_8", optional=True),
+                io.String.Input("entity_prompt_8", multiline=True, dynamic_prompts=True, default=""),
+            ],
+            outputs=[
+                io.Conditioning.Output(),
+            ],
+        )
+
+    @classmethod
+    def execute(
+        cls,
+        clip,
+        global_conditioning,
+        latent,
+        entity_prompt_1: str = "",
+        entity_mask_1: Optional[torch.Tensor] = None,
+        entity_prompt_2: str = "",
+        entity_mask_2: Optional[torch.Tensor] = None,
+        entity_prompt_3: str = "",
+        entity_mask_3: Optional[torch.Tensor] = None,
+        entity_prompt_4: str = "",
+        entity_mask_4: Optional[torch.Tensor] = None,
+        entity_prompt_5: str = "",
+        entity_mask_5: Optional[torch.Tensor] = None,
+        entity_prompt_6: str = "",
+        entity_mask_6: Optional[torch.Tensor] = None,
+        entity_prompt_7: str = "",
+        entity_mask_7: Optional[torch.Tensor] = None,
+        entity_prompt_8: str = "",
+        entity_mask_8: Optional[torch.Tensor] = None
+    ) -> io.NodeOutput:
+
+        # Extract dimensions from the latent tensor
+        # latent["samples"] shape: [batch, channels, latent_h, latent_w]
+        latent_samples = latent["samples"]
+        unpadded_latent_height = latent_samples.shape[2]  # Unpadded latent space
+        unpadded_latent_width = latent_samples.shape[3]  # Unpadded latent space
+
+        # Calculate padded dimensions (same logic as the model's pad_to_patch_size with patch_size=2)
+        # The model pads latents to be multiples of PATCH_SIZE
+        pad_h = (cls.PATCH_SIZE - unpadded_latent_height % cls.PATCH_SIZE) % cls.PATCH_SIZE
+        pad_w = (cls.PATCH_SIZE - unpadded_latent_width % cls.PATCH_SIZE) % cls.PATCH_SIZE
+        latent_height = unpadded_latent_height + pad_h  # Padded latent dimensions
+        latent_width = unpadded_latent_width + pad_w  # Padded latent dimensions
+
+        height = latent_height * 8  # Convert to pixel space for logging
+        width = latent_width * 8
+
+        if pad_h > 0 or pad_w > 0:
+            logging.debug(f"[EliGen] Latent padding detected: {unpadded_latent_height}x{unpadded_latent_width} → {latent_height}x{latent_width}")
+        logging.debug(f"[EliGen] Target generation dimensions: {height}x{width} pixels ({latent_height}x{latent_width} latent)")
+
+        # Collect entity prompts and masks
+        entity_prompts = [entity_prompt_1, entity_prompt_2, entity_prompt_3, entity_prompt_4, entity_prompt_5, entity_prompt_6, entity_prompt_7, entity_prompt_8]
+        entity_masks_raw = [entity_mask_1, entity_mask_2, entity_mask_3, entity_mask_4, entity_mask_5, entity_mask_6, entity_mask_7, entity_mask_8]
+
+        # Filter out entities with empty prompts or missing masks
+        valid_entities = []
+        for prompt, mask in zip(entity_prompts, entity_masks_raw):
+            if prompt.strip() and mask is not None:
+                valid_entities.append((prompt, mask))
+
+        # Log a warning if some entities were skipped
+        total_prompts_provided = len([p for p in entity_prompts if p.strip()])
+        if len(valid_entities) < total_prompts_provided:
+            logging.warning(f"[EliGen] Only {len(valid_entities)} of {total_prompts_provided} entity prompts have valid masks")
+
+        # If no valid entities, return standard conditioning
+        if len(valid_entities) == 0:
+            return io.NodeOutput(global_conditioning)
+
+        # Encode each entity prompt separately
+        entity_prompt_emb_list = []
+        entity_prompt_emb_mask_list = []
+
+        for entity_prompt, _ in valid_entities:  # mask not used at this point
+            entity_tokens = clip.tokenize(entity_prompt)
+            entity_cond_dict = clip.encode_from_tokens(entity_tokens, return_pooled=True, return_dict=True)
+            entity_prompt_emb = entity_cond_dict["cond"]
+            entity_prompt_emb_mask = entity_cond_dict.get("attention_mask", None)
+
+            # If the encoder did not return an attention mask, create one (all True)
+            if entity_prompt_emb_mask is None:
+                seq_len = entity_prompt_emb.shape[1]
+                entity_prompt_emb_mask = torch.ones((entity_prompt_emb.shape[0], seq_len),
+                                                    dtype=torch.bool, device=entity_prompt_emb.device)
+
+            entity_prompt_emb_list.append(entity_prompt_emb)
+            entity_prompt_emb_mask_list.append(entity_prompt_emb_mask)
+
+        # Process spatial masks to latent space
+        processed_entity_masks = []
+        for i, (_, mask) in enumerate(valid_entities):
+            # MASK type format: [batch, height, width] (no channel dimension)
+            # This is different from the IMAGE type, which is [batch, height, width, channels]
+            mask_tensor = mask
+
+            # Validate mask dtype
+            if mask_tensor.dtype not in [torch.float32, torch.float16, torch.bfloat16]:
+                raise TypeError(
+                    f"Entity {i+1} mask has invalid dtype {mask_tensor.dtype}. "
+                    f"Expected float32, float16, or bfloat16. "
+                    f"Ensure you're using the LoadImageMask node, not LoadImage."
+                )
+
+            # Log original mask statistics
+            logging.debug(
+                f"[EliGen] Entity {i+1} input mask: shape={mask_tensor.shape}, "
+                f"dtype={mask_tensor.dtype}, min={mask_tensor.min():.4f}, max={mask_tensor.max():.4f}"
+            )
+
+            # Check for all-zero masks (common error when the wrong channel is selected)
+            if mask_tensor.max() == 0.0:
+                raise ValueError(
+                    f"Entity {i+1} mask is all zeros! This usually means:\n"
+                    f"  1. Wrong channel selected in LoadImageMask (use 'red', 'green', or 'blue', NOT 'alpha')\n"
+                    f"  2. Your mask image is completely black\n"
+                    f"  3. The mask file failed to load"
+                )
+
+            # Check for constant masks (no variation)
+            if mask_tensor.min() == mask_tensor.max() and mask_tensor.max() > 0:
+                logging.warning(
+                    f"[EliGen] Entity {i+1} mask has no variation (all pixels = {mask_tensor.min():.4f}). "
+                    f"This entity will affect the entire image."
+                )
+
+            # Extract original dimensions
+            original_shape = mask_tensor.shape
+            if len(original_shape) == 2:
+                # [height, width] - single mask without batch
+                orig_h, orig_w = original_shape[0], original_shape[1]
+                # Add batch dimension: [1, height, width]
+                mask_tensor = mask_tensor.unsqueeze(0)
+            elif len(original_shape) == 3:
+                # [batch, height, width] - standard MASK format
+                orig_h, orig_w = original_shape[1], original_shape[2]
+            else:
+                raise ValueError(
+                    f"Entity {i+1} has unexpected mask shape: {original_shape}. "
+                    f"Expected [H, W] or [B, H, W]. Got {len(original_shape)} dimensions."
+                )
+
+            # Log a size mismatch if the mask doesn't match the expected pixel dimensions
+            expected_h, expected_w = latent_height * 8, latent_width * 8
+            if orig_h != expected_h or orig_w != expected_w:
+                logging.info(
+                    f"[EliGen] Entity {i+1} mask size mismatch: {orig_h}x{orig_w} vs expected {expected_h}x{expected_w}. "
+                    f"Will resize to {latent_height}x{latent_width} latent space."
+                )
+            else:
+                logging.debug(f"[EliGen] Entity {i+1} mask: {orig_h}x{orig_w} → will resize to {latent_height}x{latent_width} latent")
+
+            # Convert MASK format [batch, height, width] to [batch, 1, height, width] for common_upscale
+            # common_upscale expects [batch, channels, height, width]
+            mask_tensor = mask_tensor.unsqueeze(1)  # Add channel dimension: [batch, 1, height, width]
+
+            # Resize to latent space dimensions using nearest neighbor
+            resized_mask = comfy.utils.common_upscale(
+                mask_tensor,
+                latent_width,
+                latent_height,
+                upscale_method="nearest-exact",
+                crop="disabled"
+            )
+
+            # Threshold to binary (0 or 1)
+            # Use > 0 instead of > 0.5 to preserve edge pixels from nearest-neighbor downsampling
+            resized_mask = (resized_mask > 0).float()
+
+            # Log how many pixels are active in the mask
+            active_pixels = (resized_mask > 0).sum().item()
+            total_pixels = resized_mask.numel()
+            coverage_pct = 100 * active_pixels / total_pixels if total_pixels > 0 else 0
+
+            if active_pixels == 0:
+                raise ValueError(
+                    f"Entity {i+1} mask has no active pixels after resizing to latent space! "
+                    f"Original mask may have been too small or all black."
+                )
+
+            logging.debug(
+                f"[EliGen] Entity {i+1} mask coverage: {active_pixels}/{total_pixels} pixels ({coverage_pct:.1f}%)"
+            )
+
+            processed_entity_masks.append(resized_mask)
+
+        # Stack masks: [batch, num_entities, 1, latent_height, latent_width]
+        # Each item in processed_entity_masks has shape [1, 1, H, W] (batch=1, channel=1)
+        # We need to remove batch dim, stack, then add it back
+        processed_entity_masks_no_batch = [m.squeeze(0) for m in processed_entity_masks]  # Each: [1, H, W]
+        entity_masks_tensor = torch.stack(processed_entity_masks_no_batch, dim=0)  # [num_entities, 1, H, W]
+        entity_masks_tensor = entity_masks_tensor.unsqueeze(0)  # [1, num_entities, 1, H, W]
+
+        logging.debug(
+            f"[EliGen] Stacked {len(valid_entities)} entity masks into tensor: "
+            f"shape={entity_masks_tensor.shape} (expected: [1, {len(valid_entities)}, 1, {latent_height}, {latent_width}])"
+        )
+
+        # Extract global prompt embedding and mask from conditioning
+        # Conditioning format: [[cond_tensor, extra_dict]]
+        global_prompt_emb = global_conditioning[0][0]  # The embedding tensor directly
+        global_extra_dict = global_conditioning[0][1]  # Metadata dict
+
+        global_prompt_emb_mask = global_extra_dict.get("attention_mask", None)
+
+        # If no attention mask, create one (all True)
+        if global_prompt_emb_mask is None:
+            global_prompt_emb_mask = torch.ones((global_prompt_emb.shape[0], global_prompt_emb.shape[1]),
+                                                dtype=torch.bool, device=global_prompt_emb.device)
+
+        # Attach entity data to conditioning using conditioning_set_values
+        entity_data = {
+            "entity_prompt_emb": entity_prompt_emb_list,
+            "entity_prompt_emb_mask": entity_prompt_emb_mask_list,
+            "entity_masks": entity_masks_tensor,
+        }
+
+        conditioning_with_entities = node_helpers.conditioning_set_values(
+            global_conditioning,
+            entity_data,
+            append=True
+        )
+
+        return io.NodeOutput(conditioning_with_entities)
+

 class QwenExtension(ComfyExtension):
     @override
@@ -110,6 +389,7 @@ class QwenExtension(ComfyExtension):
         return [
             TextEncodeQwenImageEdit,
             TextEncodeQwenImageEditPlus,
+            TextEncodeQwenImageEliGen,
         ]
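
Reviewer note (not part of the patch): the node-side mask preparation reduces to resizing each MASK input to the latent grid, binarizing it, and stacking the results into the [1, num_entities, 1, h, w] tensor the model expects. A minimal stand-alone sketch, using torch.nn.functional.interpolate as a stand-in for comfy.utils.common_upscale and made-up sizes:

import torch
import torch.nn.functional as F

latent_h, latent_w = 64, 64          # e.g. a 512x512 generation, 8x downsampled
masks = [torch.zeros(1, 512, 512), torch.zeros(1, 512, 512)]  # MASK inputs: [B, H, W]
masks[0][:, :, :256] = 1.0           # entity 1: left half of the image
masks[1][:, :, 256:] = 1.0           # entity 2: right half of the image

processed = []
for m in masks:
    m = m.unsqueeze(1)                                               # [B, 1, H, W]
    m = F.interpolate(m, size=(latent_h, latent_w), mode="nearest")  # to latent resolution
    processed.append((m > 0).float())                                # binarize, keep edge pixels

# [1, num_entities, 1, latent_h, latent_w], matching what process_entity_masks consumes.
entity_masks_tensor = torch.stack([m.squeeze(0) for m in processed], dim=0).unsqueeze(0)
print(entity_masks_tensor.shape)  # torch.Size([1, 2, 1, 64, 64])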